Forward codespace ports over Dev Tunnels
This commit is contained in:
parent
48b0d53d0e
commit
e059f32aa5
13 changed files with 1271 additions and 303 deletions
|
|
@ -201,6 +201,7 @@ type Codespace struct {
|
|||
GitStatus CodespaceGitStatus `json:"git_status"`
|
||||
Connection CodespaceConnection `json:"connection"`
|
||||
Machine CodespaceMachine `json:"machine"`
|
||||
RuntimeConstraints RuntimeConstraints `json:"runtime_constraints"`
|
||||
VSCSTarget string `json:"vscs_target"`
|
||||
PendingOperation bool `json:"pending_operation"`
|
||||
PendingOperationDisabledReason string `json:"pending_operation_disabled_reason"`
|
||||
|
|
@ -246,11 +247,25 @@ const (
|
|||
)
|
||||
|
||||
type CodespaceConnection struct {
|
||||
SessionID string `json:"sessionId"`
|
||||
SessionToken string `json:"sessionToken"`
|
||||
RelayEndpoint string `json:"relayEndpoint"`
|
||||
RelaySAS string `json:"relaySas"`
|
||||
HostPublicKeys []string `json:"hostPublicKeys"`
|
||||
SessionID string `json:"sessionId"`
|
||||
SessionToken string `json:"sessionToken"`
|
||||
RelayEndpoint string `json:"relayEndpoint"`
|
||||
RelaySAS string `json:"relaySas"`
|
||||
HostPublicKeys []string `json:"hostPublicKeys"`
|
||||
TunnelProperties TunnelProperties `json:"tunnelProperties"`
|
||||
}
|
||||
|
||||
type TunnelProperties struct {
|
||||
ConnectAccessToken string `json:"connectAccessToken"`
|
||||
ManagePortsAccessToken string `json:"managePortsAccessToken"`
|
||||
ServiceUri string `json:"serviceUri"`
|
||||
TunnelId string `json:"tunnelId"`
|
||||
ClusterId string `json:"clusterId"`
|
||||
Domain string `json:"domain"`
|
||||
}
|
||||
|
||||
type RuntimeConstraints struct {
|
||||
AllowedPortPrivacySettings []string `json:"allowed_port_privacy_settings"`
|
||||
}
|
||||
|
||||
// ListCodespaceFields is the list of exportable fields for a codespace when using the `gh cs list` command.
|
||||
|
|
@ -1162,3 +1177,13 @@ func (a *API) withRetry(f func() (*http.Response, error)) (*http.Response, error
|
|||
return nil, fmt.Errorf("received response with status code %d", resp.StatusCode)
|
||||
}, backoff.WithMaxRetries(bo, 3))
|
||||
}
|
||||
|
||||
// HTTPClient returns the HTTP client used to make requests to the API.
|
||||
func (a *API) HTTPClient() (*http.Client, error) {
|
||||
httpClient, err := a.client()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return httpClient, nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -5,24 +5,42 @@ import (
|
|||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/cenkalti/backoff/v4"
|
||||
"github.com/cli/cli/v2/internal/codespaces/api"
|
||||
"github.com/cli/cli/v2/internal/codespaces/connection"
|
||||
"github.com/cli/cli/v2/pkg/liveshare"
|
||||
)
|
||||
|
||||
func connectionReady(codespace *api.Codespace) bool {
|
||||
func connectionReady(codespace *api.Codespace, usingDevTunnels bool) bool {
|
||||
// If the codespace is not available, it is not ready
|
||||
if codespace.State != api.CodespaceStateAvailable {
|
||||
return false
|
||||
}
|
||||
|
||||
// If using Dev Tunnels, we need to check that we have all of the required tunnel properties
|
||||
if usingDevTunnels {
|
||||
return codespace.Connection.TunnelProperties.ConnectAccessToken != "" &&
|
||||
codespace.Connection.TunnelProperties.ManagePortsAccessToken != "" &&
|
||||
codespace.Connection.TunnelProperties.ServiceUri != "" &&
|
||||
codespace.Connection.TunnelProperties.TunnelId != "" &&
|
||||
codespace.Connection.TunnelProperties.ClusterId != "" &&
|
||||
codespace.Connection.TunnelProperties.Domain != ""
|
||||
}
|
||||
|
||||
// If not using Dev Tunnels, we need to check that we have all of the required Live Share properties
|
||||
return codespace.Connection.SessionID != "" &&
|
||||
codespace.Connection.SessionToken != "" &&
|
||||
codespace.Connection.RelayEndpoint != "" &&
|
||||
codespace.Connection.RelaySAS != "" &&
|
||||
codespace.State == api.CodespaceStateAvailable
|
||||
codespace.Connection.RelaySAS != ""
|
||||
}
|
||||
|
||||
type apiClient interface {
|
||||
GetCodespace(ctx context.Context, name string, includeConnection bool) (*api.Codespace, error)
|
||||
StartCodespace(ctx context.Context, name string) error
|
||||
HTTPClient() (*http.Client, error)
|
||||
}
|
||||
|
||||
type progressIndicator interface {
|
||||
|
|
@ -43,9 +61,48 @@ func (e *TimeoutError) Error() string {
|
|||
return e.message
|
||||
}
|
||||
|
||||
// ConnectToLiveshare waits for a Codespace to become running,
|
||||
// and connects to it using a Live Share session.
|
||||
// GetCodespaceConnection waits until a codespace is able
|
||||
// to be connected to and initializes a connection to it.
|
||||
func GetCodespaceConnection(ctx context.Context, progress progressIndicator, apiClient apiClient, codespace *api.Codespace) (*connection.CodespaceConnection, error) {
|
||||
codespace, err := waitUntilCodespaceConnectionReady(ctx, progress, apiClient, codespace, true)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
progress.StartProgressIndicatorWithLabel("Connecting to codespace")
|
||||
defer progress.StopProgressIndicator()
|
||||
|
||||
httpClient, err := apiClient.HTTPClient()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error getting http client: %w", err)
|
||||
}
|
||||
|
||||
return connection.NewCodespaceConnection(ctx, codespace, httpClient)
|
||||
}
|
||||
|
||||
// ConnectToLiveshare waits until a codespace is able to be
|
||||
// connected to and connects to it using a Live Share session.
|
||||
func ConnectToLiveshare(ctx context.Context, progress progressIndicator, sessionLogger logger, apiClient apiClient, codespace *api.Codespace) (*liveshare.Session, error) {
|
||||
codespace, err := waitUntilCodespaceConnectionReady(ctx, progress, apiClient, codespace, false)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
progress.StartProgressIndicatorWithLabel("Connecting to codespace")
|
||||
defer progress.StopProgressIndicator()
|
||||
|
||||
return liveshare.Connect(ctx, liveshare.Options{
|
||||
SessionID: codespace.Connection.SessionID,
|
||||
SessionToken: codespace.Connection.SessionToken,
|
||||
RelaySAS: codespace.Connection.RelaySAS,
|
||||
RelayEndpoint: codespace.Connection.RelayEndpoint,
|
||||
HostPublicKeys: codespace.Connection.HostPublicKeys,
|
||||
Logger: sessionLogger,
|
||||
})
|
||||
}
|
||||
|
||||
// waitUntilCodespaceConnectionReady waits for a Codespace to be running and is able to be connected to.
|
||||
func waitUntilCodespaceConnectionReady(ctx context.Context, progress progressIndicator, apiClient apiClient, codespace *api.Codespace, usingDevTunnels bool) (*api.Codespace, error) {
|
||||
if codespace.State != api.CodespaceStateAvailable {
|
||||
progress.StartProgressIndicatorWithLabel("Starting codespace")
|
||||
defer progress.StopProgressIndicator()
|
||||
|
|
@ -54,7 +111,7 @@ func ConnectToLiveshare(ctx context.Context, progress progressIndicator, session
|
|||
}
|
||||
}
|
||||
|
||||
if !connectionReady(codespace) {
|
||||
if !connectionReady(codespace, usingDevTunnels) {
|
||||
expBackoff := backoff.NewExponentialBackOff()
|
||||
expBackoff.Multiplier = 1.1
|
||||
expBackoff.MaxInterval = 10 * time.Second
|
||||
|
|
@ -67,7 +124,7 @@ func ConnectToLiveshare(ctx context.Context, progress progressIndicator, session
|
|||
return backoff.Permanent(fmt.Errorf("error getting codespace: %w", err))
|
||||
}
|
||||
|
||||
if connectionReady(codespace) {
|
||||
if connectionReady(codespace, usingDevTunnels) {
|
||||
return nil
|
||||
}
|
||||
|
||||
|
|
@ -83,17 +140,7 @@ func ConnectToLiveshare(ctx context.Context, progress progressIndicator, session
|
|||
}
|
||||
}
|
||||
|
||||
progress.StartProgressIndicatorWithLabel("Connecting to codespace")
|
||||
defer progress.StopProgressIndicator()
|
||||
|
||||
return liveshare.Connect(ctx, liveshare.Options{
|
||||
SessionID: codespace.Connection.SessionID,
|
||||
SessionToken: codespace.Connection.SessionToken,
|
||||
RelaySAS: codespace.Connection.RelaySAS,
|
||||
RelayEndpoint: codespace.Connection.RelayEndpoint,
|
||||
HostPublicKeys: codespace.Connection.HostPublicKeys,
|
||||
Logger: sessionLogger,
|
||||
})
|
||||
return codespace, nil
|
||||
}
|
||||
|
||||
// ListenTCP starts a localhost tcp listener on 127.0.0.1 (unless allInterfaces is true) and returns the listener and bound port
|
||||
|
|
|
|||
116
internal/codespaces/connection/connection.go
Normal file
116
internal/codespaces/connection/connection.go
Normal file
|
|
@ -0,0 +1,116 @@
|
|||
package connection
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"net/url"
|
||||
|
||||
"github.com/cli/cli/v2/internal/codespaces/api"
|
||||
"github.com/microsoft/dev-tunnels/go/tunnels"
|
||||
)
|
||||
|
||||
const (
|
||||
clientName = "gh"
|
||||
)
|
||||
|
||||
type CodespaceConnection struct {
|
||||
tunnelProperties api.TunnelProperties
|
||||
TunnelManager *tunnels.Manager
|
||||
TunnelClient *tunnels.Client
|
||||
Options *tunnels.TunnelRequestOptions
|
||||
Tunnel *tunnels.Tunnel
|
||||
AllowedPortPrivacySettings []string
|
||||
}
|
||||
|
||||
// NewCodespaceConnection initializes a connection to a codespace.
|
||||
// This connections allows for port forwarding which enables the
|
||||
// use of most features of the codespace command.
|
||||
func NewCodespaceConnection(ctx context.Context, codespace *api.Codespace, httpClient *http.Client) (connection *CodespaceConnection, err error) {
|
||||
// Get the tunnel properties
|
||||
tunnelProperties := codespace.Connection.TunnelProperties
|
||||
|
||||
// Create the tunnel manager
|
||||
tunnelManager, err := getTunnelManager(tunnelProperties, httpClient)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error getting tunnel management client: %w", err)
|
||||
}
|
||||
|
||||
// Calculate allowed port privacy settings
|
||||
allowedPortPrivacySettings := codespace.RuntimeConstraints.AllowedPortPrivacySettings
|
||||
|
||||
// Get the access tokens
|
||||
connectToken := tunnelProperties.ConnectAccessToken
|
||||
managementToken := tunnelProperties.ManagePortsAccessToken
|
||||
|
||||
// Create the tunnel definition
|
||||
tunnel := &tunnels.Tunnel{
|
||||
AccessTokens: map[tunnels.TunnelAccessScope]string{tunnels.TunnelAccessScopeConnect: connectToken, tunnels.TunnelAccessScopeManagePorts: managementToken},
|
||||
TunnelID: tunnelProperties.TunnelId,
|
||||
ClusterID: tunnelProperties.ClusterId,
|
||||
Domain: tunnelProperties.Domain,
|
||||
}
|
||||
|
||||
// Create options
|
||||
options := &tunnels.TunnelRequestOptions{
|
||||
IncludePorts: true,
|
||||
}
|
||||
|
||||
// Create the tunnel client (not connected yet)
|
||||
tunnelClient, err := getTunnelClient(ctx, tunnelManager, tunnel, options)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error getting tunnel client: %w", err)
|
||||
}
|
||||
|
||||
return &CodespaceConnection{
|
||||
tunnelProperties: tunnelProperties,
|
||||
TunnelManager: tunnelManager,
|
||||
TunnelClient: tunnelClient,
|
||||
Options: options,
|
||||
Tunnel: tunnel,
|
||||
AllowedPortPrivacySettings: allowedPortPrivacySettings,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// getTunnelManager creates a tunnel manager for the given codespace.
|
||||
// The tunnel manager is used to get the tunnel hosted in the codespace that we
|
||||
// want to connect to and perform operations on ports (add, remove, list, etc.).
|
||||
func getTunnelManager(tunnelProperties api.TunnelProperties, httpClient *http.Client) (tunnelManager *tunnels.Manager, err error) {
|
||||
userAgent := []tunnels.UserAgent{{Name: clientName}}
|
||||
url, err := url.Parse(tunnelProperties.ServiceUri)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error parsing tunnel service uri: %w", err)
|
||||
}
|
||||
|
||||
// Create the tunnel manager
|
||||
tunnelManager, err = tunnels.NewManager(userAgent, nil, url, httpClient)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating tunnel manager: %w", err)
|
||||
}
|
||||
|
||||
return tunnelManager, nil
|
||||
}
|
||||
|
||||
// getTunnelClient creates a tunnel client for the given tunnel.
|
||||
// The tunnel client is used to connect to the the tunnel and allows
|
||||
// for ports to be forwarded locally.
|
||||
func getTunnelClient(ctx context.Context, tunnelManager *tunnels.Manager, tunnel *tunnels.Tunnel, options *tunnels.TunnelRequestOptions) (tunnelClient *tunnels.Client, err error) {
|
||||
// Get the tunnel that we want to connect to
|
||||
codespaceTunnel, err := tunnelManager.GetTunnel(ctx, tunnel, options)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error getting tunnel: %w", err)
|
||||
}
|
||||
|
||||
// Copy the access tokens from the tunnel definition
|
||||
codespaceTunnel.AccessTokens = tunnel.AccessTokens
|
||||
|
||||
// We need to pass false for accept local connections because we don't want to automatically connect to all forwarded ports
|
||||
tunnelClient, err = tunnels.NewClient(log.New(io.Discard, "", log.LstdFlags), codespaceTunnel, false)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating tunnel client: %w", err)
|
||||
}
|
||||
|
||||
return tunnelClient, nil
|
||||
}
|
||||
75
internal/codespaces/connection/connection_test.go
Normal file
75
internal/codespaces/connection/connection_test.go
Normal file
|
|
@ -0,0 +1,75 @@
|
|||
package connection
|
||||
|
||||
import (
|
||||
"context"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/cli/cli/v2/internal/codespaces/api"
|
||||
"github.com/microsoft/dev-tunnels/go/tunnels"
|
||||
)
|
||||
|
||||
func TestNewCodespaceConnection(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
// Create a mock codespace
|
||||
connection := api.CodespaceConnection{
|
||||
TunnelProperties: api.TunnelProperties{
|
||||
ConnectAccessToken: "connect-token",
|
||||
ManagePortsAccessToken: "manage-ports-token",
|
||||
ServiceUri: "http://global.rel.tunnels.api.visualstudio.com/",
|
||||
TunnelId: "tunnel-id",
|
||||
ClusterId: "usw2",
|
||||
Domain: "domain.com",
|
||||
},
|
||||
}
|
||||
allowedPortPrivacySettings := []string{"public", "private"}
|
||||
codespace := &api.Codespace{
|
||||
Connection: connection,
|
||||
RuntimeConstraints: api.RuntimeConstraints{AllowedPortPrivacySettings: allowedPortPrivacySettings},
|
||||
}
|
||||
|
||||
// Create the mock HTTP client
|
||||
httpClient, err := NewMockHttpClient()
|
||||
if err != nil {
|
||||
t.Fatalf("NewHttpClient returned an error: %v", err)
|
||||
}
|
||||
|
||||
// Create the connection
|
||||
conn, err := NewCodespaceConnection(ctx, codespace, httpClient)
|
||||
if err != nil {
|
||||
t.Fatalf("NewCodespaceConnection returned an error: %v", err)
|
||||
}
|
||||
|
||||
// Check that the connection was created successfully
|
||||
if conn == nil {
|
||||
t.Fatal("NewCodespaceConnection returned nil")
|
||||
}
|
||||
|
||||
// Verify that the connection contains the expected tunnel properties
|
||||
if conn.tunnelProperties != connection.TunnelProperties {
|
||||
t.Fatalf("NewCodespaceConnection returned a connection with unexpected tunnel properties: %+v", conn.tunnelProperties)
|
||||
}
|
||||
|
||||
// Verify that the connection contains the expected tunnel
|
||||
expectedTunnel := &tunnels.Tunnel{
|
||||
AccessTokens: map[tunnels.TunnelAccessScope]string{tunnels.TunnelAccessScopeConnect: connection.TunnelProperties.ConnectAccessToken, tunnels.TunnelAccessScopeManagePorts: connection.TunnelProperties.ManagePortsAccessToken},
|
||||
TunnelID: connection.TunnelProperties.TunnelId,
|
||||
ClusterID: connection.TunnelProperties.ClusterId,
|
||||
Domain: connection.TunnelProperties.Domain,
|
||||
}
|
||||
if !reflect.DeepEqual(conn.Tunnel, expectedTunnel) {
|
||||
t.Fatalf("NewCodespaceConnection returned a connection with unexpected tunnel: %+v", conn.Tunnel)
|
||||
}
|
||||
|
||||
// Verify that the connection contains the expected tunnel options
|
||||
expectedOptions := &tunnels.TunnelRequestOptions{IncludePorts: true}
|
||||
if !reflect.DeepEqual(conn.Options, expectedOptions) {
|
||||
t.Fatalf("NewCodespaceConnection returned a connection with unexpected options: %+v", conn.Options)
|
||||
}
|
||||
|
||||
// Verify that the connection contains the expected allowed port privacy settings
|
||||
if !reflect.DeepEqual(conn.AllowedPortPrivacySettings, allowedPortPrivacySettings) {
|
||||
t.Fatalf("NewCodespaceConnection returned a connection with unexpected allowed port privacy settings: %+v", conn.AllowedPortPrivacySettings)
|
||||
}
|
||||
}
|
||||
396
internal/codespaces/connection/tunnels_api_server_mock.go
Normal file
396
internal/codespaces/connection/tunnels_api_server_mock.go
Normal file
|
|
@ -0,0 +1,396 @@
|
|||
package connection
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"crypto/x509"
|
||||
"encoding/json"
|
||||
"encoding/pem"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"regexp"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/microsoft/dev-tunnels/go/tunnels"
|
||||
tunnelssh "github.com/microsoft/dev-tunnels/go/tunnels/ssh"
|
||||
"github.com/microsoft/dev-tunnels/go/tunnels/ssh/messages"
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
func NewMockHttpClient() (*http.Client, error) {
|
||||
accessToken := "tunnel access-token"
|
||||
relayServer, err := newMockrelayServer(withAccessToken(accessToken))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("NewrelayServer returned an error: %w", err)
|
||||
}
|
||||
|
||||
hostURL := strings.Replace(relayServer.URL(), "http://", "ws://", 1)
|
||||
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
var response []byte
|
||||
if r.URL.Path == "/api/v1/tunnels/tunnel-id" {
|
||||
tunnel := &tunnels.Tunnel{
|
||||
AccessTokens: map[tunnels.TunnelAccessScope]string{
|
||||
tunnels.TunnelAccessScopeConnect: accessToken,
|
||||
},
|
||||
Endpoints: []tunnels.TunnelEndpoint{
|
||||
{
|
||||
HostID: "host1",
|
||||
TunnelRelayTunnelEndpoint: tunnels.TunnelRelayTunnelEndpoint{
|
||||
ClientRelayURI: hostURL,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
response, err = json.Marshal(*tunnel)
|
||||
if err != nil {
|
||||
log.Fatalf("json.Marshal returned an error: %v", err)
|
||||
}
|
||||
} else if strings.HasPrefix(r.URL.Path, "/api/v1/tunnels/tunnel-id/ports") {
|
||||
// Use regex to check if the path ends with a number
|
||||
match, err := regexp.MatchString(`\/\d+$`, r.URL.Path)
|
||||
if err != nil {
|
||||
log.Fatalf("regexp.MatchString returned an error: %v", err)
|
||||
}
|
||||
|
||||
// If the path ends with a number, it's a request for a specific port
|
||||
if match || r.Method == http.MethodPost {
|
||||
if r.Method == http.MethodDelete {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
return
|
||||
}
|
||||
|
||||
tunnelPort := &tunnels.TunnelPort{
|
||||
AccessControl: &tunnels.TunnelAccessControl{
|
||||
Entries: []tunnels.TunnelAccessControlEntry{},
|
||||
},
|
||||
}
|
||||
|
||||
// Convert the tunnel to JSON and write it to the response
|
||||
response, err = json.Marshal(*tunnelPort)
|
||||
if err != nil {
|
||||
log.Fatalf("json.Marshal returned an error: %v", err)
|
||||
}
|
||||
} else {
|
||||
// If the path doesn't end with a number and we aren't making a POST request, return an array of ports
|
||||
tunnelPorts := []tunnels.TunnelPort{
|
||||
{
|
||||
AccessControl: &tunnels.TunnelAccessControl{
|
||||
Entries: []tunnels.TunnelAccessControlEntry{},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
response, err = json.Marshal(tunnelPorts)
|
||||
if err != nil {
|
||||
log.Fatalf("json.Marshal returned an error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
} else {
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
|
||||
// Write the response
|
||||
_, _ = w.Write(response)
|
||||
}))
|
||||
|
||||
url, err := url.Parse(mockServer.URL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("url.Parse returned an error: %w", err)
|
||||
}
|
||||
return &http.Client{
|
||||
Transport: &http.Transport{
|
||||
Proxy: http.ProxyURL(url),
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
type relayServer struct {
|
||||
httpServer *httptest.Server
|
||||
errc chan error
|
||||
sshConfig *ssh.ServerConfig
|
||||
channels map[string]channelHandler
|
||||
accessToken string
|
||||
|
||||
serverConn *ssh.ServerConn
|
||||
}
|
||||
|
||||
type relayServerOption func(*relayServer)
|
||||
type channelHandler func(context.Context, ssh.NewChannel) error
|
||||
|
||||
func newMockrelayServer(opts ...relayServerOption) (*relayServer, error) {
|
||||
server := &relayServer{
|
||||
errc: make(chan error),
|
||||
sshConfig: &ssh.ServerConfig{
|
||||
NoClientAuth: true,
|
||||
},
|
||||
}
|
||||
|
||||
// Create a private key with the crypto package
|
||||
key, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to generate key: %w", err)
|
||||
}
|
||||
|
||||
privateKeyPEM := pem.EncodeToMemory(
|
||||
&pem.Block{
|
||||
Type: "RSA PRIVATE KEY",
|
||||
Bytes: x509.MarshalPKCS1PrivateKey(key),
|
||||
},
|
||||
)
|
||||
|
||||
// Parse the private key
|
||||
sshPrivateKey, err := ssh.ParsePrivateKey(privateKeyPEM)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse private key: %w", err)
|
||||
}
|
||||
|
||||
server.sshConfig.AddHostKey(ssh.Signer(sshPrivateKey))
|
||||
|
||||
server.httpServer = httptest.NewServer(http.HandlerFunc(makeConnection(server)))
|
||||
|
||||
for _, opt := range opts {
|
||||
opt(server)
|
||||
}
|
||||
|
||||
return server, nil
|
||||
}
|
||||
|
||||
func withAccessToken(accessToken string) func(*relayServer) {
|
||||
return func(server *relayServer) {
|
||||
server.accessToken = accessToken
|
||||
}
|
||||
}
|
||||
|
||||
func (rs *relayServer) URL() string {
|
||||
return rs.httpServer.URL
|
||||
}
|
||||
|
||||
func (rs *relayServer) Err() <-chan error {
|
||||
return rs.errc
|
||||
}
|
||||
|
||||
func (rs *relayServer) sendError(err error) {
|
||||
select {
|
||||
case rs.errc <- err:
|
||||
default:
|
||||
// channel is blocked with a previous error, so we ignore this one
|
||||
}
|
||||
}
|
||||
|
||||
func (rs *relayServer) ForwardPort(ctx context.Context, port uint16) error {
|
||||
pfr := messages.NewPortForwardRequest("127.0.0.1", uint32(port))
|
||||
b, err := pfr.Marshal()
|
||||
if err != nil {
|
||||
return fmt.Errorf("error marshaling port forward request: %w", err)
|
||||
}
|
||||
|
||||
replied, data, err := rs.serverConn.SendRequest(messages.PortForwardRequestType, true, b)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error sending port forward request: %w", err)
|
||||
}
|
||||
|
||||
if !replied {
|
||||
return fmt.Errorf("port forward request not replied")
|
||||
}
|
||||
|
||||
if data == nil {
|
||||
return fmt.Errorf("no data returned")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func makeConnection(server *relayServer) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
if server.accessToken != "" {
|
||||
if r.Header.Get("Authorization") != server.accessToken {
|
||||
server.sendError(fmt.Errorf("invalid access token"))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
upgrader := websocket.Upgrader{}
|
||||
c, err := upgrader.Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
server.sendError(fmt.Errorf("error upgrading to websocket: %w", err))
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
if err := c.Close(); err != nil {
|
||||
server.sendError(fmt.Errorf("error closing websocket: %w", err))
|
||||
}
|
||||
}()
|
||||
|
||||
socketConn := newSocketConn(c)
|
||||
serverConn, chans, reqs, err := ssh.NewServerConn(socketConn, server.sshConfig)
|
||||
if err != nil {
|
||||
server.sendError(fmt.Errorf("error creating ssh server conn: %w", err))
|
||||
return
|
||||
}
|
||||
|
||||
go handleRequests(ctx, convertRequests(reqs))
|
||||
|
||||
server.serverConn = serverConn
|
||||
if err := handleChannels(ctx, server, chans); err != nil {
|
||||
server.sendError(fmt.Errorf("error handling channels: %w", err))
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (sr *sshRequest) Type() string {
|
||||
return sr.request.Type
|
||||
}
|
||||
|
||||
type sshRequest struct {
|
||||
request *ssh.Request
|
||||
}
|
||||
|
||||
// Reply method for sshRequest to satisfy the tunnelssh.SSHRequest interface
|
||||
func (sr *sshRequest) Reply(success bool, message []byte) error {
|
||||
return sr.request.Reply(success, message)
|
||||
}
|
||||
|
||||
// convertRequests function
|
||||
func convertRequests(reqs <-chan *ssh.Request) <-chan tunnelssh.SSHRequest {
|
||||
out := make(chan tunnelssh.SSHRequest)
|
||||
go func() {
|
||||
for req := range reqs {
|
||||
out <- &sshRequest{req}
|
||||
}
|
||||
close(out)
|
||||
}()
|
||||
return out
|
||||
}
|
||||
|
||||
func handleChannels(ctx context.Context, server *relayServer, chans <-chan ssh.NewChannel) error {
|
||||
errc := make(chan error, 1)
|
||||
go func() {
|
||||
for ch := range chans {
|
||||
if handler, ok := server.channels[ch.ChannelType()]; ok {
|
||||
if err := handler(ctx, ch); err != nil {
|
||||
errc <- err
|
||||
return
|
||||
}
|
||||
} else {
|
||||
// generic accept of the channel to not block
|
||||
_, _, err := ch.Accept()
|
||||
if err != nil {
|
||||
errc <- fmt.Errorf("error accepting channel: %w", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
return awaitError(ctx, errc)
|
||||
}
|
||||
|
||||
func handleRequests(ctx context.Context, reqs <-chan tunnelssh.SSHRequest) {
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case req, ok := <-reqs:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
if req.Type() == "RefreshPorts" {
|
||||
_ = req.Reply(true, nil)
|
||||
continue
|
||||
} else {
|
||||
_ = req.Reply(false, nil)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func awaitError(ctx context.Context, errc <-chan error) error {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case err := <-errc:
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
type socketConn struct {
|
||||
*websocket.Conn
|
||||
|
||||
reader io.Reader
|
||||
writeMutex sync.Mutex
|
||||
readMutex sync.Mutex
|
||||
}
|
||||
|
||||
func newSocketConn(conn *websocket.Conn) *socketConn {
|
||||
return &socketConn{Conn: conn}
|
||||
}
|
||||
|
||||
func (s *socketConn) Read(b []byte) (int, error) {
|
||||
s.readMutex.Lock()
|
||||
defer s.readMutex.Unlock()
|
||||
|
||||
if s.reader == nil {
|
||||
msgType, r, err := s.Conn.NextReader()
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("error getting next reader: %w", err)
|
||||
}
|
||||
if msgType != websocket.BinaryMessage {
|
||||
return 0, fmt.Errorf("invalid message type")
|
||||
}
|
||||
s.reader = r
|
||||
}
|
||||
|
||||
bytesRead, err := s.reader.Read(b)
|
||||
if err != nil {
|
||||
s.reader = nil
|
||||
|
||||
if err == io.EOF {
|
||||
err = nil
|
||||
}
|
||||
}
|
||||
|
||||
return bytesRead, err
|
||||
}
|
||||
|
||||
func (s *socketConn) Write(b []byte) (int, error) {
|
||||
s.writeMutex.Lock()
|
||||
defer s.writeMutex.Unlock()
|
||||
|
||||
w, err := s.Conn.NextWriter(websocket.BinaryMessage)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("error getting next writer: %w", err)
|
||||
}
|
||||
|
||||
n, err := w.Write(b)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("error writing: %w", err)
|
||||
}
|
||||
|
||||
if err := w.Close(); err != nil {
|
||||
return 0, fmt.Errorf("error closing writer: %w", err)
|
||||
}
|
||||
|
||||
return n, nil
|
||||
}
|
||||
|
||||
func (s *socketConn) SetDeadline(deadline time.Time) error {
|
||||
if err := s.Conn.SetReadDeadline(deadline); err != nil {
|
||||
return err
|
||||
}
|
||||
return s.Conn.SetWriteDeadline(deadline)
|
||||
}
|
||||
253
internal/codespaces/portforwarder/port_forwarder.go
Normal file
253
internal/codespaces/portforwarder/port_forwarder.go
Normal file
|
|
@ -0,0 +1,253 @@
|
|||
package portforwarder
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"strings"
|
||||
|
||||
"github.com/cli/cli/v2/internal/codespaces/connection"
|
||||
"github.com/microsoft/dev-tunnels/go/tunnels"
|
||||
)
|
||||
|
||||
const (
|
||||
githubSubjectId = "1"
|
||||
InternalPortTag = "InternalPort"
|
||||
UserForwardedPortTag = "UserForwardedPort"
|
||||
)
|
||||
|
||||
const (
|
||||
PrivatePortVisibility = "private"
|
||||
OrgPortVisibility = "org"
|
||||
PublicPortVisibility = "public"
|
||||
)
|
||||
|
||||
type PortForwarder struct {
|
||||
connection connection.CodespaceConnection
|
||||
}
|
||||
|
||||
// NewPortForwarder returns a new PortForwarder for the specified codespace.
|
||||
func NewPortForwarder(ctx context.Context, codespaceConnection *connection.CodespaceConnection) (fwd *PortForwarder, err error) {
|
||||
return &PortForwarder{
|
||||
connection: *codespaceConnection,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// ForwardAndConnectToPort forwards a port and connects to it via a local TCP port.
|
||||
func (fwd *PortForwarder) ForwardAndConnectToPort(ctx context.Context, remotePort uint16, listen *net.TCPListener, keepAlive bool, internal bool) error {
|
||||
return fwd.ForwardPort(ctx, remotePort, listen, keepAlive, true, internal, "")
|
||||
}
|
||||
|
||||
// ForwardPort forwards a port and optionally connects to it via a local TCP port.
|
||||
func (fwd *PortForwarder) ForwardPort(ctx context.Context, remotePort uint16, listen *net.TCPListener, keepAlive bool, connect bool, internal bool, visibility string) error {
|
||||
tunnelPort := tunnels.NewTunnelPort(remotePort, "", "", tunnels.TunnelProtocolHttp)
|
||||
|
||||
// If no visibility is provided, Dev Tunnels will use the default (private)
|
||||
if visibility != "" {
|
||||
// Check if the requested visibility is allowed
|
||||
allowed := false
|
||||
for _, allowedVisibility := range fwd.connection.AllowedPortPrivacySettings {
|
||||
if allowedVisibility == visibility {
|
||||
allowed = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// If the requested visibility is not allowed, return an error
|
||||
if !allowed {
|
||||
return fmt.Errorf("visibility %s is not allowed", visibility)
|
||||
}
|
||||
|
||||
accessControlEntries := visibilityToAccessControlEntries(visibility)
|
||||
if len(accessControlEntries) > 0 {
|
||||
tunnelPort.AccessControl = &tunnels.TunnelAccessControl{
|
||||
Entries: accessControlEntries,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Tag the port as internal or user forwarded so we know if it needs to be shown in the UI
|
||||
if internal {
|
||||
tunnelPort.Tags = []string{InternalPortTag}
|
||||
} else {
|
||||
tunnelPort.Tags = []string{UserForwardedPortTag}
|
||||
}
|
||||
|
||||
// Create the tunnel port
|
||||
_, err := fwd.connection.TunnelManager.CreateTunnelPort(ctx, fwd.connection.Tunnel, tunnelPort, fwd.connection.Options)
|
||||
if err != nil && !strings.Contains(err.Error(), "409") {
|
||||
return fmt.Errorf("create tunnel port failed: %v", err)
|
||||
}
|
||||
|
||||
done := make(chan error)
|
||||
go func() {
|
||||
// Connect to the tunnel
|
||||
err = fwd.connection.TunnelClient.Connect(ctx, "")
|
||||
if err != nil {
|
||||
done <- fmt.Errorf("connect failed: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Inform the host that we've forwarded the port locally
|
||||
err = fwd.connection.TunnelClient.RefreshPorts(ctx)
|
||||
if err != nil {
|
||||
done <- fmt.Errorf("refresh ports failed: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// If we don't want to connect to the port, exit early
|
||||
if !connect {
|
||||
done <- nil
|
||||
return
|
||||
}
|
||||
|
||||
// Ensure the port is forwarded before connecting
|
||||
err = fwd.connection.TunnelClient.WaitForForwardedPort(ctx, remotePort)
|
||||
if err != nil {
|
||||
done <- fmt.Errorf("wait for forwarded port failed: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Connect to the forwarded port via a local TCP port
|
||||
err = fwd.connection.TunnelClient.ConnectToForwardedPort(ctx, listen, remotePort)
|
||||
if err != nil {
|
||||
done <- fmt.Errorf("connect to forwarded port failed: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
done <- nil
|
||||
}()
|
||||
select {
|
||||
case err := <-done:
|
||||
if err != nil {
|
||||
return fmt.Errorf("error connecting to tunnel: %w", err)
|
||||
}
|
||||
return nil
|
||||
case <-ctx.Done():
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// ListPorts fetches the list of ports that are currently forwarded.
|
||||
func (fwd *PortForwarder) ListPorts(ctx context.Context) (ports []*tunnels.TunnelPort, err error) {
|
||||
ports, err = fwd.connection.TunnelManager.ListTunnelPorts(ctx, fwd.connection.Tunnel, fwd.connection.Options)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error listing ports: %w", err)
|
||||
}
|
||||
|
||||
return ports, nil
|
||||
}
|
||||
|
||||
// UpdatePortVisibility changes the visibility (private, org, public) of the specified port.
|
||||
func (fwd *PortForwarder) UpdatePortVisibility(ctx context.Context, remotePort int, visibility string) error {
|
||||
tunnelPort, err := fwd.connection.TunnelManager.GetTunnelPort(ctx, fwd.connection.Tunnel, remotePort, fwd.connection.Options)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error getting tunnel port: %w", err)
|
||||
}
|
||||
|
||||
// If the port visibility isn't changing, don't do anything
|
||||
if AccessControlEntriesToVisibility(tunnelPort.AccessControl.Entries) == visibility {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Delete the existing tunnel port to update
|
||||
err = fwd.connection.TunnelManager.DeleteTunnelPort(ctx, fwd.connection.Tunnel, uint16(remotePort), fwd.connection.Options)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error deleting tunnel port: %w", err)
|
||||
}
|
||||
|
||||
done := make(chan error)
|
||||
go func() {
|
||||
// Connect to the tunnel
|
||||
err = fwd.connection.TunnelClient.Connect(ctx, "")
|
||||
if err != nil {
|
||||
done <- fmt.Errorf("connect failed: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Inform the host that we've deleted the port
|
||||
err = fwd.connection.TunnelClient.RefreshPorts(ctx)
|
||||
if err != nil {
|
||||
done <- fmt.Errorf("refresh ports failed: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
done <- nil
|
||||
}()
|
||||
|
||||
// Wait for the done channel to be closed
|
||||
select {
|
||||
case err := <-done:
|
||||
if err != nil {
|
||||
return fmt.Errorf("error connecting to tunnel: %w", err)
|
||||
}
|
||||
|
||||
// Re-forward the port with the updated visibility
|
||||
err = fwd.ForwardPort(ctx, uint16(remotePort), nil, false, false, false, visibility)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error forwarding port: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
case <-ctx.Done():
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// AccessControlEntriesToVisibility converts the access control entries used by Dev Tunnels to a friendly visibility value.
|
||||
func AccessControlEntriesToVisibility(accessControlEntries []tunnels.TunnelAccessControlEntry) string {
|
||||
for _, entry := range accessControlEntries {
|
||||
// If we have the anonymous type (and we're not denying it), it's public
|
||||
if (entry.Type == tunnels.TunnelAccessControlEntryTypeAnonymous) && (!entry.IsDeny) {
|
||||
return PublicPortVisibility
|
||||
}
|
||||
|
||||
// If we have the organizations type (and we're not denying it), it's org
|
||||
if (entry.Provider == string(tunnels.TunnelAuthenticationSchemeGitHub)) && (!entry.IsDeny) {
|
||||
return OrgPortVisibility
|
||||
}
|
||||
}
|
||||
|
||||
// Else, it's private
|
||||
return PrivatePortVisibility
|
||||
}
|
||||
|
||||
// visibilityToAccessControlEntries converts the given visibility to access control entries that can be used by Dev Tunnels.
|
||||
func visibilityToAccessControlEntries(visibility string) []tunnels.TunnelAccessControlEntry {
|
||||
switch visibility {
|
||||
case PublicPortVisibility:
|
||||
return []tunnels.TunnelAccessControlEntry{{
|
||||
Type: tunnels.TunnelAccessControlEntryTypeAnonymous,
|
||||
Subjects: []string{},
|
||||
Scopes: []string{string(tunnels.TunnelAccessScopeConnect)},
|
||||
}}
|
||||
case OrgPortVisibility:
|
||||
return []tunnels.TunnelAccessControlEntry{{
|
||||
Type: tunnels.TunnelAccessControlEntryTypeOrganizations,
|
||||
Subjects: []string{githubSubjectId},
|
||||
Scopes: []string{
|
||||
string(tunnels.TunnelAccessScopeConnect),
|
||||
},
|
||||
Provider: string(tunnels.TunnelAuthenticationSchemeGitHub),
|
||||
}}
|
||||
default:
|
||||
// The tunnel manager doesn't accept empty access control entries, so we need to return a deny entry
|
||||
return []tunnels.TunnelAccessControlEntry{{
|
||||
Type: tunnels.TunnelAccessControlEntryTypeOrganizations,
|
||||
Subjects: []string{githubSubjectId},
|
||||
Scopes: []string{},
|
||||
IsDeny: true,
|
||||
}}
|
||||
}
|
||||
}
|
||||
|
||||
// IsInternalPort returns true if the port is internal.
|
||||
func IsInternalPort(port *tunnels.TunnelPort) bool {
|
||||
for _, tag := range port.Tags {
|
||||
if strings.EqualFold(tag, InternalPortTag) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
139
internal/codespaces/portforwarder/port_forwarder_test.go
Normal file
139
internal/codespaces/portforwarder/port_forwarder_test.go
Normal file
|
|
@ -0,0 +1,139 @@
|
|||
package portforwarder
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/cli/cli/v2/internal/codespaces/api"
|
||||
"github.com/cli/cli/v2/internal/codespaces/connection"
|
||||
"github.com/microsoft/dev-tunnels/go/tunnels"
|
||||
)
|
||||
|
||||
func TestNewPortForwarder(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
// Create a mock codespace
|
||||
codespace := &api.Codespace{
|
||||
Connection: api.CodespaceConnection{
|
||||
TunnelProperties: api.TunnelProperties{
|
||||
ConnectAccessToken: "connect-token",
|
||||
ManagePortsAccessToken: "manage-ports-token",
|
||||
ServiceUri: "http://global.rel.tunnels.api.visualstudio.com/",
|
||||
TunnelId: "tunnel-id",
|
||||
ClusterId: "usw2",
|
||||
Domain: "domain.com",
|
||||
},
|
||||
},
|
||||
RuntimeConstraints: api.RuntimeConstraints{
|
||||
AllowedPortPrivacySettings: []string{"public", "private"},
|
||||
},
|
||||
}
|
||||
|
||||
// Create the mock HTTP client
|
||||
httpClient, err := connection.NewMockHttpClient()
|
||||
if err != nil {
|
||||
t.Fatalf("NewHttpClient returned an error: %v", err)
|
||||
}
|
||||
|
||||
// Call the function being tested
|
||||
conn, err := connection.NewCodespaceConnection(ctx, codespace, httpClient)
|
||||
if err != nil {
|
||||
t.Fatalf("NewCodespaceConnection returned an error: %v", err)
|
||||
}
|
||||
|
||||
// Create the new port forwarder
|
||||
portForwarder, err := NewPortForwarder(ctx, conn)
|
||||
if err != nil {
|
||||
t.Fatalf("NewPortForwarder returned an error: %v", err)
|
||||
}
|
||||
|
||||
// Check that the port forwarder was created successfully
|
||||
if portForwarder == nil {
|
||||
t.Fatal("NewPortForwarder returned nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAccessControlEntriesToVisibility(t *testing.T) {
|
||||
publicAccessControlEntry := []tunnels.TunnelAccessControlEntry{{
|
||||
Type: tunnels.TunnelAccessControlEntryTypeAnonymous,
|
||||
}}
|
||||
orgAccessControlEntry := []tunnels.TunnelAccessControlEntry{{
|
||||
Provider: string(tunnels.TunnelAuthenticationSchemeGitHub),
|
||||
}}
|
||||
privateAccessControlEntry := []tunnels.TunnelAccessControlEntry{}
|
||||
orgIsDenyAccessControlEntry := []tunnels.TunnelAccessControlEntry{{
|
||||
Provider: string(tunnels.TunnelAuthenticationSchemeGitHub),
|
||||
IsDeny: true,
|
||||
}}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
accessControlEntries []tunnels.TunnelAccessControlEntry
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "public",
|
||||
accessControlEntries: publicAccessControlEntry,
|
||||
expected: PublicPortVisibility,
|
||||
},
|
||||
{
|
||||
name: "org",
|
||||
accessControlEntries: orgAccessControlEntry,
|
||||
expected: OrgPortVisibility,
|
||||
},
|
||||
{
|
||||
name: "private",
|
||||
accessControlEntries: privateAccessControlEntry,
|
||||
expected: PrivatePortVisibility,
|
||||
},
|
||||
{
|
||||
name: "orgIsDeny",
|
||||
accessControlEntries: orgIsDenyAccessControlEntry,
|
||||
expected: PrivatePortVisibility,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
visibility := AccessControlEntriesToVisibility(test.accessControlEntries)
|
||||
if visibility != test.expected {
|
||||
t.Errorf("expected %q, got %q", test.expected, visibility)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsInternalPort(t *testing.T) {
|
||||
internalPort := &tunnels.TunnelPort{
|
||||
Tags: []string{"InternalPort"},
|
||||
}
|
||||
userForwardedPort := &tunnels.TunnelPort{
|
||||
Tags: []string{"UserForwardedPort"},
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
port *tunnels.TunnelPort
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "internal",
|
||||
port: internalPort,
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "user-forwarded",
|
||||
port: userForwardedPort,
|
||||
expected: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
isInternal := IsInternalPort(test.port)
|
||||
if isInternal != test.expected {
|
||||
t.Errorf("expected %v, got %v", test.expected, isInternal)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue