diff --git a/go.mod b/go.mod index 4a97a0e94..376480ecc 100644 --- a/go.mod +++ b/go.mod @@ -27,6 +27,7 @@ require ( github.com/mattn/go-colorable v0.1.13 github.com/mattn/go-isatty v0.0.19 github.com/mgutz/ansi v0.0.0-20200706080929-d51e80ef957d + github.com/microsoft/dev-tunnels v0.0.21 github.com/muhammadmuzzammil1998/jsonc v0.0.0-20201229145248-615b0916ca38 github.com/opentracing/opentracing-go v1.1.0 github.com/rivo/tview v0.0.0-20221029100920-c4a7e501810d @@ -75,6 +76,7 @@ require ( github.com/olekukonko/tablewriter v0.0.5 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/rivo/uniseg v0.4.4 // indirect + github.com/rodaine/table v1.0.1 // indirect github.com/russross/blackfriday/v2 v2.1.0 // indirect github.com/shurcooL/graphql v0.0.0-20230722043721-ed46e5a46466 // indirect github.com/stretchr/objx v0.5.0 // indirect diff --git a/go.sum b/go.sum index e2996530f..a6c982abb 100644 --- a/go.sum +++ b/go.sum @@ -117,6 +117,8 @@ github.com/mgutz/ansi v0.0.0-20200706080929-d51e80ef957d h1:5PJl274Y63IEHC+7izoQ github.com/mgutz/ansi v0.0.0-20200706080929-d51e80ef957d/go.mod h1:01TrycV0kFyexm33Z7vhZRXopbI8J3TDReVlkTgMUxE= github.com/microcosm-cc/bluemonday v1.0.21 h1:dNH3e4PSyE4vNX+KlRGHT5KrSvjeUkoNPwEORjffHJg= github.com/microcosm-cc/bluemonday v1.0.21/go.mod h1:ytNkv4RrDrLJ2pqlsSI46O6IVXmZOBBD4SaJyDwwTkM= +github.com/microsoft/dev-tunnels v0.0.21 h1:p4QP7C5ZOyP9bGbmanRjPxUMckfi9Z41Gl+KY4C11w0= +github.com/microsoft/dev-tunnels v0.0.21/go.mod h1:frU++12T/oqxckXkDpTuYa427ncguEOodSPZcGCCrzQ= github.com/muesli/reflow v0.2.1-0.20210115123740-9e1d0d53df68/go.mod h1:Xk+z4oIWdQqJzsxyjgl3P22oYZnHdZ8FFTHAQQt5BMQ= github.com/muesli/reflow v0.3.0 h1:IFsN6K9NfGtjeggFP+68I4chLZV2yIKsXJFNZ+eWh6s= github.com/muesli/reflow v0.3.0/go.mod h1:pbwTDkVPibjO2kyvBQRBxTWEEGDGq0FlB1BIKtnHY/8= @@ -139,6 +141,8 @@ github.com/rivo/uniseg v0.1.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJ github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc= github.com/rivo/uniseg v0.4.4 h1:8TfxU8dW6PdqD27gjM8MVNuicgxIjxpm4K7x4jp8sis= github.com/rivo/uniseg v0.4.4/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88= +github.com/rodaine/table v1.0.1 h1:U/VwCnUxlVYxw8+NJiLIuCxA/xa6jL38MY3FYysVWWQ= +github.com/rodaine/table v1.0.1/go.mod h1:UVEtfBsflpeEcD56nF4F5AocNFta0ZuolpSVdPtlmP4= github.com/russross/blackfriday/v2 v2.1.0 h1:JIOH55/0cWyOuilr9/qlrm0BSXldqnqwMsf35Ld67mk= github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= github.com/shurcooL/githubv4 v0.0.0-20230704064427-599ae7bbf278 h1:kdEGVAV4sO46DPtb8k793jiecUEhaX9ixoIBt41HEGU= diff --git a/internal/codespaces/api/api.go b/internal/codespaces/api/api.go index d6b21c304..dd2ba4033 100644 --- a/internal/codespaces/api/api.go +++ b/internal/codespaces/api/api.go @@ -201,6 +201,7 @@ type Codespace struct { GitStatus CodespaceGitStatus `json:"git_status"` Connection CodespaceConnection `json:"connection"` Machine CodespaceMachine `json:"machine"` + RuntimeConstraints RuntimeConstraints `json:"runtime_constraints"` VSCSTarget string `json:"vscs_target"` PendingOperation bool `json:"pending_operation"` PendingOperationDisabledReason string `json:"pending_operation_disabled_reason"` @@ -246,11 +247,25 @@ const ( ) type CodespaceConnection struct { - SessionID string `json:"sessionId"` - SessionToken string `json:"sessionToken"` - RelayEndpoint string `json:"relayEndpoint"` - RelaySAS string `json:"relaySas"` - HostPublicKeys []string `json:"hostPublicKeys"` + SessionID string `json:"sessionId"` + SessionToken string `json:"sessionToken"` + RelayEndpoint string `json:"relayEndpoint"` + RelaySAS string `json:"relaySas"` + HostPublicKeys []string `json:"hostPublicKeys"` + TunnelProperties TunnelProperties `json:"tunnelProperties"` +} + +type TunnelProperties struct { + ConnectAccessToken string `json:"connectAccessToken"` + ManagePortsAccessToken string `json:"managePortsAccessToken"` + ServiceUri string `json:"serviceUri"` + TunnelId string `json:"tunnelId"` + ClusterId string `json:"clusterId"` + Domain string `json:"domain"` +} + +type RuntimeConstraints struct { + AllowedPortPrivacySettings []string `json:"allowed_port_privacy_settings"` } // ListCodespaceFields is the list of exportable fields for a codespace when using the `gh cs list` command. @@ -1162,3 +1177,13 @@ func (a *API) withRetry(f func() (*http.Response, error)) (*http.Response, error return nil, fmt.Errorf("received response with status code %d", resp.StatusCode) }, backoff.WithMaxRetries(bo, 3)) } + +// HTTPClient returns the HTTP client used to make requests to the API. +func (a *API) HTTPClient() (*http.Client, error) { + httpClient, err := a.client() + if err != nil { + return nil, err + } + + return httpClient, nil +} diff --git a/internal/codespaces/codespaces.go b/internal/codespaces/codespaces.go index 8834f0e6c..3bcc2b404 100644 --- a/internal/codespaces/codespaces.go +++ b/internal/codespaces/codespaces.go @@ -5,24 +5,42 @@ import ( "errors" "fmt" "net" + "net/http" "time" "github.com/cenkalti/backoff/v4" "github.com/cli/cli/v2/internal/codespaces/api" + "github.com/cli/cli/v2/internal/codespaces/connection" "github.com/cli/cli/v2/pkg/liveshare" ) -func connectionReady(codespace *api.Codespace) bool { +func connectionReady(codespace *api.Codespace, usingDevTunnels bool) bool { + // If the codespace is not available, it is not ready + if codespace.State != api.CodespaceStateAvailable { + return false + } + + // If using Dev Tunnels, we need to check that we have all of the required tunnel properties + if usingDevTunnels { + return codespace.Connection.TunnelProperties.ConnectAccessToken != "" && + codespace.Connection.TunnelProperties.ManagePortsAccessToken != "" && + codespace.Connection.TunnelProperties.ServiceUri != "" && + codespace.Connection.TunnelProperties.TunnelId != "" && + codespace.Connection.TunnelProperties.ClusterId != "" && + codespace.Connection.TunnelProperties.Domain != "" + } + + // If not using Dev Tunnels, we need to check that we have all of the required Live Share properties return codespace.Connection.SessionID != "" && codespace.Connection.SessionToken != "" && codespace.Connection.RelayEndpoint != "" && - codespace.Connection.RelaySAS != "" && - codespace.State == api.CodespaceStateAvailable + codespace.Connection.RelaySAS != "" } type apiClient interface { GetCodespace(ctx context.Context, name string, includeConnection bool) (*api.Codespace, error) StartCodespace(ctx context.Context, name string) error + HTTPClient() (*http.Client, error) } type progressIndicator interface { @@ -43,9 +61,48 @@ func (e *TimeoutError) Error() string { return e.message } -// ConnectToLiveshare waits for a Codespace to become running, -// and connects to it using a Live Share session. +// GetCodespaceConnection waits until a codespace is able +// to be connected to and initializes a connection to it. +func GetCodespaceConnection(ctx context.Context, progress progressIndicator, apiClient apiClient, codespace *api.Codespace) (*connection.CodespaceConnection, error) { + codespace, err := waitUntilCodespaceConnectionReady(ctx, progress, apiClient, codespace, true) + if err != nil { + return nil, err + } + + progress.StartProgressIndicatorWithLabel("Connecting to codespace") + defer progress.StopProgressIndicator() + + httpClient, err := apiClient.HTTPClient() + if err != nil { + return nil, fmt.Errorf("error getting http client: %w", err) + } + + return connection.NewCodespaceConnection(ctx, codespace, httpClient) +} + +// ConnectToLiveshare waits until a codespace is able to be +// connected to and connects to it using a Live Share session. func ConnectToLiveshare(ctx context.Context, progress progressIndicator, sessionLogger logger, apiClient apiClient, codespace *api.Codespace) (*liveshare.Session, error) { + codespace, err := waitUntilCodespaceConnectionReady(ctx, progress, apiClient, codespace, false) + if err != nil { + return nil, err + } + + progress.StartProgressIndicatorWithLabel("Connecting to codespace") + defer progress.StopProgressIndicator() + + return liveshare.Connect(ctx, liveshare.Options{ + SessionID: codespace.Connection.SessionID, + SessionToken: codespace.Connection.SessionToken, + RelaySAS: codespace.Connection.RelaySAS, + RelayEndpoint: codespace.Connection.RelayEndpoint, + HostPublicKeys: codespace.Connection.HostPublicKeys, + Logger: sessionLogger, + }) +} + +// waitUntilCodespaceConnectionReady waits for a Codespace to be running and is able to be connected to. +func waitUntilCodespaceConnectionReady(ctx context.Context, progress progressIndicator, apiClient apiClient, codespace *api.Codespace, usingDevTunnels bool) (*api.Codespace, error) { if codespace.State != api.CodespaceStateAvailable { progress.StartProgressIndicatorWithLabel("Starting codespace") defer progress.StopProgressIndicator() @@ -54,7 +111,7 @@ func ConnectToLiveshare(ctx context.Context, progress progressIndicator, session } } - if !connectionReady(codespace) { + if !connectionReady(codespace, usingDevTunnels) { expBackoff := backoff.NewExponentialBackOff() expBackoff.Multiplier = 1.1 expBackoff.MaxInterval = 10 * time.Second @@ -67,7 +124,7 @@ func ConnectToLiveshare(ctx context.Context, progress progressIndicator, session return backoff.Permanent(fmt.Errorf("error getting codespace: %w", err)) } - if connectionReady(codespace) { + if connectionReady(codespace, usingDevTunnels) { return nil } @@ -83,17 +140,7 @@ func ConnectToLiveshare(ctx context.Context, progress progressIndicator, session } } - progress.StartProgressIndicatorWithLabel("Connecting to codespace") - defer progress.StopProgressIndicator() - - return liveshare.Connect(ctx, liveshare.Options{ - SessionID: codespace.Connection.SessionID, - SessionToken: codespace.Connection.SessionToken, - RelaySAS: codespace.Connection.RelaySAS, - RelayEndpoint: codespace.Connection.RelayEndpoint, - HostPublicKeys: codespace.Connection.HostPublicKeys, - Logger: sessionLogger, - }) + return codespace, nil } // ListenTCP starts a localhost tcp listener on 127.0.0.1 (unless allInterfaces is true) and returns the listener and bound port diff --git a/internal/codespaces/connection/connection.go b/internal/codespaces/connection/connection.go new file mode 100644 index 000000000..5eea89a76 --- /dev/null +++ b/internal/codespaces/connection/connection.go @@ -0,0 +1,116 @@ +package connection + +import ( + "context" + "fmt" + "io" + "log" + "net/http" + "net/url" + + "github.com/cli/cli/v2/internal/codespaces/api" + "github.com/microsoft/dev-tunnels/go/tunnels" +) + +const ( + clientName = "gh" +) + +type CodespaceConnection struct { + tunnelProperties api.TunnelProperties + TunnelManager *tunnels.Manager + TunnelClient *tunnels.Client + Options *tunnels.TunnelRequestOptions + Tunnel *tunnels.Tunnel + AllowedPortPrivacySettings []string +} + +// NewCodespaceConnection initializes a connection to a codespace. +// This connections allows for port forwarding which enables the +// use of most features of the codespace command. +func NewCodespaceConnection(ctx context.Context, codespace *api.Codespace, httpClient *http.Client) (connection *CodespaceConnection, err error) { + // Get the tunnel properties + tunnelProperties := codespace.Connection.TunnelProperties + + // Create the tunnel manager + tunnelManager, err := getTunnelManager(tunnelProperties, httpClient) + if err != nil { + return nil, fmt.Errorf("error getting tunnel management client: %w", err) + } + + // Calculate allowed port privacy settings + allowedPortPrivacySettings := codespace.RuntimeConstraints.AllowedPortPrivacySettings + + // Get the access tokens + connectToken := tunnelProperties.ConnectAccessToken + managementToken := tunnelProperties.ManagePortsAccessToken + + // Create the tunnel definition + tunnel := &tunnels.Tunnel{ + AccessTokens: map[tunnels.TunnelAccessScope]string{tunnels.TunnelAccessScopeConnect: connectToken, tunnels.TunnelAccessScopeManagePorts: managementToken}, + TunnelID: tunnelProperties.TunnelId, + ClusterID: tunnelProperties.ClusterId, + Domain: tunnelProperties.Domain, + } + + // Create options + options := &tunnels.TunnelRequestOptions{ + IncludePorts: true, + } + + // Create the tunnel client (not connected yet) + tunnelClient, err := getTunnelClient(ctx, tunnelManager, tunnel, options) + if err != nil { + return nil, fmt.Errorf("error getting tunnel client: %w", err) + } + + return &CodespaceConnection{ + tunnelProperties: tunnelProperties, + TunnelManager: tunnelManager, + TunnelClient: tunnelClient, + Options: options, + Tunnel: tunnel, + AllowedPortPrivacySettings: allowedPortPrivacySettings, + }, nil +} + +// getTunnelManager creates a tunnel manager for the given codespace. +// The tunnel manager is used to get the tunnel hosted in the codespace that we +// want to connect to and perform operations on ports (add, remove, list, etc.). +func getTunnelManager(tunnelProperties api.TunnelProperties, httpClient *http.Client) (tunnelManager *tunnels.Manager, err error) { + userAgent := []tunnels.UserAgent{{Name: clientName}} + url, err := url.Parse(tunnelProperties.ServiceUri) + if err != nil { + return nil, fmt.Errorf("error parsing tunnel service uri: %w", err) + } + + // Create the tunnel manager + tunnelManager, err = tunnels.NewManager(userAgent, nil, url, httpClient) + if err != nil { + return nil, fmt.Errorf("error creating tunnel manager: %w", err) + } + + return tunnelManager, nil +} + +// getTunnelClient creates a tunnel client for the given tunnel. +// The tunnel client is used to connect to the the tunnel and allows +// for ports to be forwarded locally. +func getTunnelClient(ctx context.Context, tunnelManager *tunnels.Manager, tunnel *tunnels.Tunnel, options *tunnels.TunnelRequestOptions) (tunnelClient *tunnels.Client, err error) { + // Get the tunnel that we want to connect to + codespaceTunnel, err := tunnelManager.GetTunnel(ctx, tunnel, options) + if err != nil { + return nil, fmt.Errorf("error getting tunnel: %w", err) + } + + // Copy the access tokens from the tunnel definition + codespaceTunnel.AccessTokens = tunnel.AccessTokens + + // We need to pass false for accept local connections because we don't want to automatically connect to all forwarded ports + tunnelClient, err = tunnels.NewClient(log.New(io.Discard, "", log.LstdFlags), codespaceTunnel, false) + if err != nil { + return nil, fmt.Errorf("error creating tunnel client: %w", err) + } + + return tunnelClient, nil +} diff --git a/internal/codespaces/connection/connection_test.go b/internal/codespaces/connection/connection_test.go new file mode 100644 index 000000000..e7ebd2788 --- /dev/null +++ b/internal/codespaces/connection/connection_test.go @@ -0,0 +1,75 @@ +package connection + +import ( + "context" + "reflect" + "testing" + + "github.com/cli/cli/v2/internal/codespaces/api" + "github.com/microsoft/dev-tunnels/go/tunnels" +) + +func TestNewCodespaceConnection(t *testing.T) { + ctx := context.Background() + + // Create a mock codespace + connection := api.CodespaceConnection{ + TunnelProperties: api.TunnelProperties{ + ConnectAccessToken: "connect-token", + ManagePortsAccessToken: "manage-ports-token", + ServiceUri: "http://global.rel.tunnels.api.visualstudio.com/", + TunnelId: "tunnel-id", + ClusterId: "usw2", + Domain: "domain.com", + }, + } + allowedPortPrivacySettings := []string{"public", "private"} + codespace := &api.Codespace{ + Connection: connection, + RuntimeConstraints: api.RuntimeConstraints{AllowedPortPrivacySettings: allowedPortPrivacySettings}, + } + + // Create the mock HTTP client + httpClient, err := NewMockHttpClient() + if err != nil { + t.Fatalf("NewHttpClient returned an error: %v", err) + } + + // Create the connection + conn, err := NewCodespaceConnection(ctx, codespace, httpClient) + if err != nil { + t.Fatalf("NewCodespaceConnection returned an error: %v", err) + } + + // Check that the connection was created successfully + if conn == nil { + t.Fatal("NewCodespaceConnection returned nil") + } + + // Verify that the connection contains the expected tunnel properties + if conn.tunnelProperties != connection.TunnelProperties { + t.Fatalf("NewCodespaceConnection returned a connection with unexpected tunnel properties: %+v", conn.tunnelProperties) + } + + // Verify that the connection contains the expected tunnel + expectedTunnel := &tunnels.Tunnel{ + AccessTokens: map[tunnels.TunnelAccessScope]string{tunnels.TunnelAccessScopeConnect: connection.TunnelProperties.ConnectAccessToken, tunnels.TunnelAccessScopeManagePorts: connection.TunnelProperties.ManagePortsAccessToken}, + TunnelID: connection.TunnelProperties.TunnelId, + ClusterID: connection.TunnelProperties.ClusterId, + Domain: connection.TunnelProperties.Domain, + } + if !reflect.DeepEqual(conn.Tunnel, expectedTunnel) { + t.Fatalf("NewCodespaceConnection returned a connection with unexpected tunnel: %+v", conn.Tunnel) + } + + // Verify that the connection contains the expected tunnel options + expectedOptions := &tunnels.TunnelRequestOptions{IncludePorts: true} + if !reflect.DeepEqual(conn.Options, expectedOptions) { + t.Fatalf("NewCodespaceConnection returned a connection with unexpected options: %+v", conn.Options) + } + + // Verify that the connection contains the expected allowed port privacy settings + if !reflect.DeepEqual(conn.AllowedPortPrivacySettings, allowedPortPrivacySettings) { + t.Fatalf("NewCodespaceConnection returned a connection with unexpected allowed port privacy settings: %+v", conn.AllowedPortPrivacySettings) + } +} diff --git a/internal/codespaces/connection/tunnels_api_server_mock.go b/internal/codespaces/connection/tunnels_api_server_mock.go new file mode 100644 index 000000000..cf8f05cfa --- /dev/null +++ b/internal/codespaces/connection/tunnels_api_server_mock.go @@ -0,0 +1,396 @@ +package connection + +import ( + "context" + "crypto/rand" + "crypto/rsa" + "crypto/x509" + "encoding/json" + "encoding/pem" + "fmt" + "io" + "log" + "net/http" + "net/http/httptest" + "net/url" + "regexp" + "strings" + "sync" + "time" + + "github.com/gorilla/websocket" + "github.com/microsoft/dev-tunnels/go/tunnels" + tunnelssh "github.com/microsoft/dev-tunnels/go/tunnels/ssh" + "github.com/microsoft/dev-tunnels/go/tunnels/ssh/messages" + "golang.org/x/crypto/ssh" +) + +func NewMockHttpClient() (*http.Client, error) { + accessToken := "tunnel access-token" + relayServer, err := newMockrelayServer(withAccessToken(accessToken)) + if err != nil { + return nil, fmt.Errorf("NewrelayServer returned an error: %w", err) + } + + hostURL := strings.Replace(relayServer.URL(), "http://", "ws://", 1) + mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + var response []byte + if r.URL.Path == "/api/v1/tunnels/tunnel-id" { + tunnel := &tunnels.Tunnel{ + AccessTokens: map[tunnels.TunnelAccessScope]string{ + tunnels.TunnelAccessScopeConnect: accessToken, + }, + Endpoints: []tunnels.TunnelEndpoint{ + { + HostID: "host1", + TunnelRelayTunnelEndpoint: tunnels.TunnelRelayTunnelEndpoint{ + ClientRelayURI: hostURL, + }, + }, + }, + } + + response, err = json.Marshal(*tunnel) + if err != nil { + log.Fatalf("json.Marshal returned an error: %v", err) + } + } else if strings.HasPrefix(r.URL.Path, "/api/v1/tunnels/tunnel-id/ports") { + // Use regex to check if the path ends with a number + match, err := regexp.MatchString(`\/\d+$`, r.URL.Path) + if err != nil { + log.Fatalf("regexp.MatchString returned an error: %v", err) + } + + // If the path ends with a number, it's a request for a specific port + if match || r.Method == http.MethodPost { + if r.Method == http.MethodDelete { + w.WriteHeader(http.StatusOK) + return + } + + tunnelPort := &tunnels.TunnelPort{ + AccessControl: &tunnels.TunnelAccessControl{ + Entries: []tunnels.TunnelAccessControlEntry{}, + }, + } + + // Convert the tunnel to JSON and write it to the response + response, err = json.Marshal(*tunnelPort) + if err != nil { + log.Fatalf("json.Marshal returned an error: %v", err) + } + } else { + // If the path doesn't end with a number and we aren't making a POST request, return an array of ports + tunnelPorts := []tunnels.TunnelPort{ + { + AccessControl: &tunnels.TunnelAccessControl{ + Entries: []tunnels.TunnelAccessControlEntry{}, + }, + }, + } + + response, err = json.Marshal(tunnelPorts) + if err != nil { + log.Fatalf("json.Marshal returned an error: %v", err) + } + } + + } else { + w.WriteHeader(http.StatusNotFound) + return + } + + // Write the response + _, _ = w.Write(response) + })) + + url, err := url.Parse(mockServer.URL) + if err != nil { + return nil, fmt.Errorf("url.Parse returned an error: %w", err) + } + return &http.Client{ + Transport: &http.Transport{ + Proxy: http.ProxyURL(url), + }, + }, nil +} + +type relayServer struct { + httpServer *httptest.Server + errc chan error + sshConfig *ssh.ServerConfig + channels map[string]channelHandler + accessToken string + + serverConn *ssh.ServerConn +} + +type relayServerOption func(*relayServer) +type channelHandler func(context.Context, ssh.NewChannel) error + +func newMockrelayServer(opts ...relayServerOption) (*relayServer, error) { + server := &relayServer{ + errc: make(chan error), + sshConfig: &ssh.ServerConfig{ + NoClientAuth: true, + }, + } + + // Create a private key with the crypto package + key, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + return nil, fmt.Errorf("failed to generate key: %w", err) + } + + privateKeyPEM := pem.EncodeToMemory( + &pem.Block{ + Type: "RSA PRIVATE KEY", + Bytes: x509.MarshalPKCS1PrivateKey(key), + }, + ) + + // Parse the private key + sshPrivateKey, err := ssh.ParsePrivateKey(privateKeyPEM) + if err != nil { + return nil, fmt.Errorf("failed to parse private key: %w", err) + } + + server.sshConfig.AddHostKey(ssh.Signer(sshPrivateKey)) + + server.httpServer = httptest.NewServer(http.HandlerFunc(makeConnection(server))) + + for _, opt := range opts { + opt(server) + } + + return server, nil +} + +func withAccessToken(accessToken string) func(*relayServer) { + return func(server *relayServer) { + server.accessToken = accessToken + } +} + +func (rs *relayServer) URL() string { + return rs.httpServer.URL +} + +func (rs *relayServer) Err() <-chan error { + return rs.errc +} + +func (rs *relayServer) sendError(err error) { + select { + case rs.errc <- err: + default: + // channel is blocked with a previous error, so we ignore this one + } +} + +func (rs *relayServer) ForwardPort(ctx context.Context, port uint16) error { + pfr := messages.NewPortForwardRequest("127.0.0.1", uint32(port)) + b, err := pfr.Marshal() + if err != nil { + return fmt.Errorf("error marshaling port forward request: %w", err) + } + + replied, data, err := rs.serverConn.SendRequest(messages.PortForwardRequestType, true, b) + if err != nil { + return fmt.Errorf("error sending port forward request: %w", err) + } + + if !replied { + return fmt.Errorf("port forward request not replied") + } + + if data == nil { + return fmt.Errorf("no data returned") + } + + return nil +} + +func makeConnection(server *relayServer) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + if server.accessToken != "" { + if r.Header.Get("Authorization") != server.accessToken { + server.sendError(fmt.Errorf("invalid access token")) + return + } + } + + upgrader := websocket.Upgrader{} + c, err := upgrader.Upgrade(w, r, nil) + if err != nil { + server.sendError(fmt.Errorf("error upgrading to websocket: %w", err)) + return + } + defer func() { + if err := c.Close(); err != nil { + server.sendError(fmt.Errorf("error closing websocket: %w", err)) + } + }() + + socketConn := newSocketConn(c) + serverConn, chans, reqs, err := ssh.NewServerConn(socketConn, server.sshConfig) + if err != nil { + server.sendError(fmt.Errorf("error creating ssh server conn: %w", err)) + return + } + + go handleRequests(ctx, convertRequests(reqs)) + + server.serverConn = serverConn + if err := handleChannels(ctx, server, chans); err != nil { + server.sendError(fmt.Errorf("error handling channels: %w", err)) + return + } + } +} + +func (sr *sshRequest) Type() string { + return sr.request.Type +} + +type sshRequest struct { + request *ssh.Request +} + +// Reply method for sshRequest to satisfy the tunnelssh.SSHRequest interface +func (sr *sshRequest) Reply(success bool, message []byte) error { + return sr.request.Reply(success, message) +} + +// convertRequests function +func convertRequests(reqs <-chan *ssh.Request) <-chan tunnelssh.SSHRequest { + out := make(chan tunnelssh.SSHRequest) + go func() { + for req := range reqs { + out <- &sshRequest{req} + } + close(out) + }() + return out +} + +func handleChannels(ctx context.Context, server *relayServer, chans <-chan ssh.NewChannel) error { + errc := make(chan error, 1) + go func() { + for ch := range chans { + if handler, ok := server.channels[ch.ChannelType()]; ok { + if err := handler(ctx, ch); err != nil { + errc <- err + return + } + } else { + // generic accept of the channel to not block + _, _, err := ch.Accept() + if err != nil { + errc <- fmt.Errorf("error accepting channel: %w", err) + return + } + } + } + }() + return awaitError(ctx, errc) +} + +func handleRequests(ctx context.Context, reqs <-chan tunnelssh.SSHRequest) { + for { + select { + case <-ctx.Done(): + return + case req, ok := <-reqs: + if !ok { + return + } + + if req.Type() == "RefreshPorts" { + _ = req.Reply(true, nil) + continue + } else { + _ = req.Reply(false, nil) + } + } + } +} + +func awaitError(ctx context.Context, errc <-chan error) error { + select { + case <-ctx.Done(): + return ctx.Err() + case err := <-errc: + return err + } +} + +type socketConn struct { + *websocket.Conn + + reader io.Reader + writeMutex sync.Mutex + readMutex sync.Mutex +} + +func newSocketConn(conn *websocket.Conn) *socketConn { + return &socketConn{Conn: conn} +} + +func (s *socketConn) Read(b []byte) (int, error) { + s.readMutex.Lock() + defer s.readMutex.Unlock() + + if s.reader == nil { + msgType, r, err := s.Conn.NextReader() + if err != nil { + return 0, fmt.Errorf("error getting next reader: %w", err) + } + if msgType != websocket.BinaryMessage { + return 0, fmt.Errorf("invalid message type") + } + s.reader = r + } + + bytesRead, err := s.reader.Read(b) + if err != nil { + s.reader = nil + + if err == io.EOF { + err = nil + } + } + + return bytesRead, err +} + +func (s *socketConn) Write(b []byte) (int, error) { + s.writeMutex.Lock() + defer s.writeMutex.Unlock() + + w, err := s.Conn.NextWriter(websocket.BinaryMessage) + if err != nil { + return 0, fmt.Errorf("error getting next writer: %w", err) + } + + n, err := w.Write(b) + if err != nil { + return 0, fmt.Errorf("error writing: %w", err) + } + + if err := w.Close(); err != nil { + return 0, fmt.Errorf("error closing writer: %w", err) + } + + return n, nil +} + +func (s *socketConn) SetDeadline(deadline time.Time) error { + if err := s.Conn.SetReadDeadline(deadline); err != nil { + return err + } + return s.Conn.SetWriteDeadline(deadline) +} diff --git a/internal/codespaces/portforwarder/port_forwarder.go b/internal/codespaces/portforwarder/port_forwarder.go new file mode 100644 index 000000000..ba510e4ff --- /dev/null +++ b/internal/codespaces/portforwarder/port_forwarder.go @@ -0,0 +1,253 @@ +package portforwarder + +import ( + "context" + "fmt" + "net" + "strings" + + "github.com/cli/cli/v2/internal/codespaces/connection" + "github.com/microsoft/dev-tunnels/go/tunnels" +) + +const ( + githubSubjectId = "1" + InternalPortTag = "InternalPort" + UserForwardedPortTag = "UserForwardedPort" +) + +const ( + PrivatePortVisibility = "private" + OrgPortVisibility = "org" + PublicPortVisibility = "public" +) + +type PortForwarder struct { + connection connection.CodespaceConnection +} + +// NewPortForwarder returns a new PortForwarder for the specified codespace. +func NewPortForwarder(ctx context.Context, codespaceConnection *connection.CodespaceConnection) (fwd *PortForwarder, err error) { + return &PortForwarder{ + connection: *codespaceConnection, + }, nil +} + +// ForwardAndConnectToPort forwards a port and connects to it via a local TCP port. +func (fwd *PortForwarder) ForwardAndConnectToPort(ctx context.Context, remotePort uint16, listen *net.TCPListener, keepAlive bool, internal bool) error { + return fwd.ForwardPort(ctx, remotePort, listen, keepAlive, true, internal, "") +} + +// ForwardPort forwards a port and optionally connects to it via a local TCP port. +func (fwd *PortForwarder) ForwardPort(ctx context.Context, remotePort uint16, listen *net.TCPListener, keepAlive bool, connect bool, internal bool, visibility string) error { + tunnelPort := tunnels.NewTunnelPort(remotePort, "", "", tunnels.TunnelProtocolHttp) + + // If no visibility is provided, Dev Tunnels will use the default (private) + if visibility != "" { + // Check if the requested visibility is allowed + allowed := false + for _, allowedVisibility := range fwd.connection.AllowedPortPrivacySettings { + if allowedVisibility == visibility { + allowed = true + break + } + } + + // If the requested visibility is not allowed, return an error + if !allowed { + return fmt.Errorf("visibility %s is not allowed", visibility) + } + + accessControlEntries := visibilityToAccessControlEntries(visibility) + if len(accessControlEntries) > 0 { + tunnelPort.AccessControl = &tunnels.TunnelAccessControl{ + Entries: accessControlEntries, + } + } + } + + // Tag the port as internal or user forwarded so we know if it needs to be shown in the UI + if internal { + tunnelPort.Tags = []string{InternalPortTag} + } else { + tunnelPort.Tags = []string{UserForwardedPortTag} + } + + // Create the tunnel port + _, err := fwd.connection.TunnelManager.CreateTunnelPort(ctx, fwd.connection.Tunnel, tunnelPort, fwd.connection.Options) + if err != nil && !strings.Contains(err.Error(), "409") { + return fmt.Errorf("create tunnel port failed: %v", err) + } + + done := make(chan error) + go func() { + // Connect to the tunnel + err = fwd.connection.TunnelClient.Connect(ctx, "") + if err != nil { + done <- fmt.Errorf("connect failed: %v", err) + return + } + + // Inform the host that we've forwarded the port locally + err = fwd.connection.TunnelClient.RefreshPorts(ctx) + if err != nil { + done <- fmt.Errorf("refresh ports failed: %v", err) + return + } + + // If we don't want to connect to the port, exit early + if !connect { + done <- nil + return + } + + // Ensure the port is forwarded before connecting + err = fwd.connection.TunnelClient.WaitForForwardedPort(ctx, remotePort) + if err != nil { + done <- fmt.Errorf("wait for forwarded port failed: %v", err) + return + } + + // Connect to the forwarded port via a local TCP port + err = fwd.connection.TunnelClient.ConnectToForwardedPort(ctx, listen, remotePort) + if err != nil { + done <- fmt.Errorf("connect to forwarded port failed: %v", err) + return + } + + done <- nil + }() + select { + case err := <-done: + if err != nil { + return fmt.Errorf("error connecting to tunnel: %w", err) + } + return nil + case <-ctx.Done(): + return nil + } +} + +// ListPorts fetches the list of ports that are currently forwarded. +func (fwd *PortForwarder) ListPorts(ctx context.Context) (ports []*tunnels.TunnelPort, err error) { + ports, err = fwd.connection.TunnelManager.ListTunnelPorts(ctx, fwd.connection.Tunnel, fwd.connection.Options) + if err != nil { + return nil, fmt.Errorf("error listing ports: %w", err) + } + + return ports, nil +} + +// UpdatePortVisibility changes the visibility (private, org, public) of the specified port. +func (fwd *PortForwarder) UpdatePortVisibility(ctx context.Context, remotePort int, visibility string) error { + tunnelPort, err := fwd.connection.TunnelManager.GetTunnelPort(ctx, fwd.connection.Tunnel, remotePort, fwd.connection.Options) + if err != nil { + return fmt.Errorf("error getting tunnel port: %w", err) + } + + // If the port visibility isn't changing, don't do anything + if AccessControlEntriesToVisibility(tunnelPort.AccessControl.Entries) == visibility { + return nil + } + + // Delete the existing tunnel port to update + err = fwd.connection.TunnelManager.DeleteTunnelPort(ctx, fwd.connection.Tunnel, uint16(remotePort), fwd.connection.Options) + if err != nil { + return fmt.Errorf("error deleting tunnel port: %w", err) + } + + done := make(chan error) + go func() { + // Connect to the tunnel + err = fwd.connection.TunnelClient.Connect(ctx, "") + if err != nil { + done <- fmt.Errorf("connect failed: %v", err) + return + } + + // Inform the host that we've deleted the port + err = fwd.connection.TunnelClient.RefreshPorts(ctx) + if err != nil { + done <- fmt.Errorf("refresh ports failed: %v", err) + return + } + + done <- nil + }() + + // Wait for the done channel to be closed + select { + case err := <-done: + if err != nil { + return fmt.Errorf("error connecting to tunnel: %w", err) + } + + // Re-forward the port with the updated visibility + err = fwd.ForwardPort(ctx, uint16(remotePort), nil, false, false, false, visibility) + if err != nil { + return fmt.Errorf("error forwarding port: %w", err) + } + + return nil + case <-ctx.Done(): + return nil + } +} + +// AccessControlEntriesToVisibility converts the access control entries used by Dev Tunnels to a friendly visibility value. +func AccessControlEntriesToVisibility(accessControlEntries []tunnels.TunnelAccessControlEntry) string { + for _, entry := range accessControlEntries { + // If we have the anonymous type (and we're not denying it), it's public + if (entry.Type == tunnels.TunnelAccessControlEntryTypeAnonymous) && (!entry.IsDeny) { + return PublicPortVisibility + } + + // If we have the organizations type (and we're not denying it), it's org + if (entry.Provider == string(tunnels.TunnelAuthenticationSchemeGitHub)) && (!entry.IsDeny) { + return OrgPortVisibility + } + } + + // Else, it's private + return PrivatePortVisibility +} + +// visibilityToAccessControlEntries converts the given visibility to access control entries that can be used by Dev Tunnels. +func visibilityToAccessControlEntries(visibility string) []tunnels.TunnelAccessControlEntry { + switch visibility { + case PublicPortVisibility: + return []tunnels.TunnelAccessControlEntry{{ + Type: tunnels.TunnelAccessControlEntryTypeAnonymous, + Subjects: []string{}, + Scopes: []string{string(tunnels.TunnelAccessScopeConnect)}, + }} + case OrgPortVisibility: + return []tunnels.TunnelAccessControlEntry{{ + Type: tunnels.TunnelAccessControlEntryTypeOrganizations, + Subjects: []string{githubSubjectId}, + Scopes: []string{ + string(tunnels.TunnelAccessScopeConnect), + }, + Provider: string(tunnels.TunnelAuthenticationSchemeGitHub), + }} + default: + // The tunnel manager doesn't accept empty access control entries, so we need to return a deny entry + return []tunnels.TunnelAccessControlEntry{{ + Type: tunnels.TunnelAccessControlEntryTypeOrganizations, + Subjects: []string{githubSubjectId}, + Scopes: []string{}, + IsDeny: true, + }} + } +} + +// IsInternalPort returns true if the port is internal. +func IsInternalPort(port *tunnels.TunnelPort) bool { + for _, tag := range port.Tags { + if strings.EqualFold(tag, InternalPortTag) { + return true + } + } + + return false +} diff --git a/internal/codespaces/portforwarder/port_forwarder_test.go b/internal/codespaces/portforwarder/port_forwarder_test.go new file mode 100644 index 000000000..d107afec4 --- /dev/null +++ b/internal/codespaces/portforwarder/port_forwarder_test.go @@ -0,0 +1,139 @@ +package portforwarder + +import ( + "context" + "testing" + + "github.com/cli/cli/v2/internal/codespaces/api" + "github.com/cli/cli/v2/internal/codespaces/connection" + "github.com/microsoft/dev-tunnels/go/tunnels" +) + +func TestNewPortForwarder(t *testing.T) { + ctx := context.Background() + + // Create a mock codespace + codespace := &api.Codespace{ + Connection: api.CodespaceConnection{ + TunnelProperties: api.TunnelProperties{ + ConnectAccessToken: "connect-token", + ManagePortsAccessToken: "manage-ports-token", + ServiceUri: "http://global.rel.tunnels.api.visualstudio.com/", + TunnelId: "tunnel-id", + ClusterId: "usw2", + Domain: "domain.com", + }, + }, + RuntimeConstraints: api.RuntimeConstraints{ + AllowedPortPrivacySettings: []string{"public", "private"}, + }, + } + + // Create the mock HTTP client + httpClient, err := connection.NewMockHttpClient() + if err != nil { + t.Fatalf("NewHttpClient returned an error: %v", err) + } + + // Call the function being tested + conn, err := connection.NewCodespaceConnection(ctx, codespace, httpClient) + if err != nil { + t.Fatalf("NewCodespaceConnection returned an error: %v", err) + } + + // Create the new port forwarder + portForwarder, err := NewPortForwarder(ctx, conn) + if err != nil { + t.Fatalf("NewPortForwarder returned an error: %v", err) + } + + // Check that the port forwarder was created successfully + if portForwarder == nil { + t.Fatal("NewPortForwarder returned nil") + } +} + +func TestAccessControlEntriesToVisibility(t *testing.T) { + publicAccessControlEntry := []tunnels.TunnelAccessControlEntry{{ + Type: tunnels.TunnelAccessControlEntryTypeAnonymous, + }} + orgAccessControlEntry := []tunnels.TunnelAccessControlEntry{{ + Provider: string(tunnels.TunnelAuthenticationSchemeGitHub), + }} + privateAccessControlEntry := []tunnels.TunnelAccessControlEntry{} + orgIsDenyAccessControlEntry := []tunnels.TunnelAccessControlEntry{{ + Provider: string(tunnels.TunnelAuthenticationSchemeGitHub), + IsDeny: true, + }} + + tests := []struct { + name string + accessControlEntries []tunnels.TunnelAccessControlEntry + expected string + }{ + { + name: "public", + accessControlEntries: publicAccessControlEntry, + expected: PublicPortVisibility, + }, + { + name: "org", + accessControlEntries: orgAccessControlEntry, + expected: OrgPortVisibility, + }, + { + name: "private", + accessControlEntries: privateAccessControlEntry, + expected: PrivatePortVisibility, + }, + { + name: "orgIsDeny", + accessControlEntries: orgIsDenyAccessControlEntry, + expected: PrivatePortVisibility, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + visibility := AccessControlEntriesToVisibility(test.accessControlEntries) + if visibility != test.expected { + t.Errorf("expected %q, got %q", test.expected, visibility) + } + }) + } +} + +func TestIsInternalPort(t *testing.T) { + internalPort := &tunnels.TunnelPort{ + Tags: []string{"InternalPort"}, + } + userForwardedPort := &tunnels.TunnelPort{ + Tags: []string{"UserForwardedPort"}, + } + + tests := []struct { + name string + port *tunnels.TunnelPort + expected bool + }{ + { + name: "internal", + port: internalPort, + expected: true, + }, + { + name: "user-forwarded", + port: userForwardedPort, + expected: false, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + isInternal := IsInternalPort(test.port) + if isInternal != test.expected { + t.Errorf("expected %v, got %v", test.expected, isInternal) + } + }) + } +} diff --git a/pkg/cmd/codespace/common.go b/pkg/cmd/codespace/common.go index 006669151..86493ce8e 100644 --- a/pkg/cmd/codespace/common.go +++ b/pkg/cmd/codespace/common.go @@ -8,6 +8,7 @@ import ( "fmt" "io" "log" + "net/http" "os" "sort" "strings" @@ -104,6 +105,7 @@ type apiClient interface { ListDevContainers(ctx context.Context, repoID int, branch string, limit int) (devcontainers []api.DevContainerEntry, err error) GetCodespaceRepoSuggestions(ctx context.Context, partialSearch string, params api.RepoSearchParameters) ([]string, error) GetCodespaceBillableOwner(ctx context.Context, nwo string) (*api.User, error) + HTTPClient() (*http.Client, error) } var errNoCodespaces = errors.New("you have no codespaces") diff --git a/pkg/cmd/codespace/mock_api.go b/pkg/cmd/codespace/mock_api.go index 7e211fd7c..aad15c025 100644 --- a/pkg/cmd/codespace/mock_api.go +++ b/pkg/cmd/codespace/mock_api.go @@ -5,6 +5,7 @@ package codespace import ( "context" + "net/http" "sync" codespacesAPI "github.com/cli/cli/v2/internal/codespaces/api" @@ -40,15 +41,15 @@ import ( // GetCodespacesMachinesFunc: func(ctx context.Context, repoID int, branch string, location string, devcontainerPath string) ([]*codespacesAPI.Machine, error) { // panic("mock out the GetCodespacesMachines method") // }, +// HTTPClientFunc: func() (*http.Client, error) { +// panic("mock out the HTTPClient method") +// }, // GetOrgMemberCodespaceFunc: func(ctx context.Context, orgName string, userName string, codespaceName string) (*codespacesAPI.Codespace, error) { // panic("mock out the GetOrgMemberCodespace method") // }, // GetRepositoryFunc: func(ctx context.Context, nwo string) (*codespacesAPI.Repository, error) { // panic("mock out the GetRepository method") // }, -// ServerURLFunc: func() string { -// panic("mock out the ServerURL method") -// }, // GetUserFunc: func(ctx context.Context) (*codespacesAPI.User, error) { // panic("mock out the GetUser method") // }, @@ -58,6 +59,9 @@ import ( // ListDevContainersFunc: func(ctx context.Context, repoID int, branch string, limit int) ([]codespacesAPI.DevContainerEntry, error) { // panic("mock out the ListDevContainers method") // }, +// ServerURLFunc: func() string { +// panic("mock out the ServerURL method") +// }, // StartCodespaceFunc: func(ctx context.Context, name string) error { // panic("mock out the StartCodespace method") // }, @@ -95,15 +99,15 @@ type apiClientMock struct { // GetCodespacesMachinesFunc mocks the GetCodespacesMachines method. GetCodespacesMachinesFunc func(ctx context.Context, repoID int, branch string, location string, devcontainerPath string) ([]*codespacesAPI.Machine, error) + // HTTPClientFunc mocks the HTTPClient method. + HTTPClientFunc func() (*http.Client, error) + // GetOrgMemberCodespaceFunc mocks the GetOrgMemberCodespace method. GetOrgMemberCodespaceFunc func(ctx context.Context, orgName string, userName string, codespaceName string) (*codespacesAPI.Codespace, error) // GetRepositoryFunc mocks the GetRepository method. GetRepositoryFunc func(ctx context.Context, nwo string) (*codespacesAPI.Repository, error) - // ServerURLFunc mocks the ServerURL method. - ServerURLFunc func() string - // GetUserFunc mocks the GetUser method. GetUserFunc func(ctx context.Context) (*codespacesAPI.User, error) @@ -113,6 +117,9 @@ type apiClientMock struct { // ListDevContainersFunc mocks the ListDevContainers method. ListDevContainersFunc func(ctx context.Context, repoID int, branch string, limit int) ([]codespacesAPI.DevContainerEntry, error) + // ServerURLFunc mocks the ServerURL method. + ServerURLFunc func() string + // StartCodespaceFunc mocks the StartCodespace method. StartCodespaceFunc func(ctx context.Context, name string) error @@ -195,6 +202,9 @@ type apiClientMock struct { // DevcontainerPath is the devcontainerPath argument value. DevcontainerPath string } + // HTTPClient holds details about calls to the HTTPClient method. + HTTPClient []struct { + } // GetOrgMemberCodespace holds details about calls to the GetOrgMemberCodespace method. GetOrgMemberCodespace []struct { // Ctx is the ctx argument value. @@ -213,9 +223,6 @@ type apiClientMock struct { // Nwo is the nwo argument value. Nwo string } - // ServerURL holds details about calls to the ServerURL method. - ServerURL []struct { - } // GetUser holds details about calls to the GetUser method. GetUser []struct { // Ctx is the ctx argument value. @@ -239,6 +246,9 @@ type apiClientMock struct { // Limit is the limit argument value. Limit int } + // ServerURL holds details about calls to the ServerURL method. + ServerURL []struct { + } // StartCodespace holds details about calls to the StartCodespace method. StartCodespace []struct { // Ctx is the ctx argument value. @@ -266,12 +276,13 @@ type apiClientMock struct { lockGetCodespaceRepoSuggestions sync.RWMutex lockGetCodespaceRepositoryContents sync.RWMutex lockGetCodespacesMachines sync.RWMutex + lockHTTPClient sync.RWMutex lockGetOrgMemberCodespace sync.RWMutex lockGetRepository sync.RWMutex - lockServerURL sync.RWMutex lockGetUser sync.RWMutex lockListCodespaces sync.RWMutex lockListDevContainers sync.RWMutex + lockServerURL sync.RWMutex lockStartCodespace sync.RWMutex lockStopCodespace sync.RWMutex } @@ -600,6 +611,33 @@ func (mock *apiClientMock) GetCodespacesMachinesCalls() []struct { return calls } +// HTTPClient calls HTTPClientFunc. +func (mock *apiClientMock) HTTPClient() (*http.Client, error) { + if mock.HTTPClientFunc == nil { + panic("apiClientMock.HTTPClientFunc: method is nil but apiClient.HTTPClient was just called") + } + callInfo := struct { + }{} + mock.lockHTTPClient.Lock() + mock.calls.HTTPClient = append(mock.calls.HTTPClient, callInfo) + mock.lockHTTPClient.Unlock() + return mock.HTTPClientFunc() +} + +// HTTPClientCalls gets all the calls that were made to HTTPClient. +// Check the length with: +// +// len(mockedapiClient.HTTPClientCalls()) +func (mock *apiClientMock) HTTPClientCalls() []struct { +} { + var calls []struct { + } + mock.lockHTTPClient.RLock() + calls = mock.calls.HTTPClient + mock.lockHTTPClient.RUnlock() + return calls +} + // GetOrgMemberCodespace calls GetOrgMemberCodespaceFunc. func (mock *apiClientMock) GetOrgMemberCodespace(ctx context.Context, orgName string, userName string, codespaceName string) (*codespacesAPI.Codespace, error) { if mock.GetOrgMemberCodespaceFunc == nil { @@ -680,33 +718,6 @@ func (mock *apiClientMock) GetRepositoryCalls() []struct { return calls } -// ServerURL calls ServerURLFunc. -func (mock *apiClientMock) ServerURL() string { - if mock.ServerURLFunc == nil { - panic("apiClientMock.ServerURLFunc: method is nil but apiClient.ServerURL was just called") - } - callInfo := struct { - }{} - mock.lockServerURL.Lock() - mock.calls.ServerURL = append(mock.calls.ServerURL, callInfo) - mock.lockServerURL.Unlock() - return mock.ServerURLFunc() -} - -// ServerURLCalls gets all the calls that were made to ServerURL. -// Check the length with: -// -// len(mockedapiClient.ServerURLCalls()) -func (mock *apiClientMock) ServerURLCalls() []struct { -} { - var calls []struct { - } - mock.lockServerURL.RLock() - calls = mock.calls.ServerURL - mock.lockServerURL.RUnlock() - return calls -} - // GetUser calls GetUserFunc. func (mock *apiClientMock) GetUser(ctx context.Context) (*codespacesAPI.User, error) { if mock.GetUserFunc == nil { @@ -819,6 +830,33 @@ func (mock *apiClientMock) ListDevContainersCalls() []struct { return calls } +// ServerURL calls ServerURLFunc. +func (mock *apiClientMock) ServerURL() string { + if mock.ServerURLFunc == nil { + panic("apiClientMock.ServerURLFunc: method is nil but apiClient.ServerURL was just called") + } + callInfo := struct { + }{} + mock.lockServerURL.Lock() + mock.calls.ServerURL = append(mock.calls.ServerURL, callInfo) + mock.lockServerURL.Unlock() + return mock.ServerURLFunc() +} + +// ServerURLCalls gets all the calls that were made to ServerURL. +// Check the length with: +// +// len(mockedapiClient.ServerURLCalls()) +func (mock *apiClientMock) ServerURLCalls() []struct { +} { + var calls []struct { + } + mock.lockServerURL.RLock() + calls = mock.calls.ServerURL + mock.lockServerURL.RUnlock() + return calls +} + // StartCodespace calls StartCodespaceFunc. func (mock *apiClientMock) StartCodespace(ctx context.Context, name string) error { if mock.StartCodespaceFunc == nil { diff --git a/pkg/cmd/codespace/ports.go b/pkg/cmd/codespace/ports.go index cbe48b7b3..b7eb96537 100644 --- a/pkg/cmd/codespace/ports.go +++ b/pkg/cmd/codespace/ports.go @@ -6,26 +6,21 @@ import ( "encoding/json" "errors" "fmt" - "net/http" "strconv" "strings" "time" "github.com/cli/cli/v2/internal/codespaces" "github.com/cli/cli/v2/internal/codespaces/api" + "github.com/cli/cli/v2/internal/codespaces/portforwarder" + "github.com/cli/cli/v2/internal/tableprinter" "github.com/cli/cli/v2/pkg/cmdutil" - "github.com/cli/cli/v2/pkg/liveshare" - "github.com/cli/cli/v2/utils" + "github.com/microsoft/dev-tunnels/go/tunnels" "github.com/muhammadmuzzammil1998/jsonc" "github.com/spf13/cobra" "golang.org/x/sync/errgroup" ) -const ( - vscodeServerPortName = "VSCodeServerInternal" - codespacesInternalPortName = "CodespacesInternal" -) - // newPortsCmd returns a Cobra "ports" command that displays a table of available ports, // according to the specified flags. func newPortsCmd(app *App) *cobra.Command { @@ -62,15 +57,19 @@ func (a *App) ListPorts(ctx context.Context, selector *CodespaceSelector, export devContainerCh := getDevContainer(ctx, a.apiClient, codespace) - session, err := startLiveShareSession(ctx, codespace, a, false, "") + codespaceConnection, err := codespaces.GetCodespaceConnection(ctx, a, a.apiClient, codespace) if err != nil { - return err + return fmt.Errorf("error connecting to codespace: %w", err) } - defer safeClose(session, &err) - var ports []*liveshare.Port + fwd, err := portforwarder.NewPortForwarder(ctx, codespaceConnection) + if err != nil { + return fmt.Errorf("failed to create port forwarder: %w", err) + } + + var ports []*tunnels.TunnelPort err = a.RunWithProgress("Fetching ports", func() (err error) { - ports, err = session.GetSharedServers(ctx) + ports, err = fwd.ListPorts(ctx) return }) if err != nil { @@ -87,9 +86,10 @@ func (a *App) ListPorts(ctx context.Context, selector *CodespaceSelector, export for _, p := range ports { // filter out internal ports from list - if strings.HasPrefix(p.SessionName, vscodeServerPortName) || strings.HasPrefix(p.SessionName, codespacesInternalPortName) { + if portforwarder.IsInternalPort(p) { continue } + portInfos = append(portInfos, &portInfo{ Port: p, codespace: codespace, @@ -107,40 +107,42 @@ func (a *App) ListPorts(ctx context.Context, selector *CodespaceSelector, export } cs := a.io.ColorScheme() - //nolint:staticcheck // SA1019: utils.NewTablePrinter is deprecated: use internal/tableprinter - tp := utils.NewTablePrinter(a.io) + tp := tableprinter.New(a.io) - if tp.IsTTY() { - tp.AddField("LABEL", nil, nil) - tp.AddField("PORT", nil, nil) - tp.AddField("VISIBILITY", nil, nil) - tp.AddField("BROWSE URL", nil, nil) + if a.io.IsStdoutTTY() { + tp.AddField("LABEL") + tp.AddField("PORT") + tp.AddField("VISIBILITY") + tp.AddField("BROWSE URL") tp.EndRow() } for _, port := range portInfos { - tp.AddField(port.Label(), nil, nil) - tp.AddField(strconv.Itoa(port.SourcePort), nil, cs.Yellow) - tp.AddField(port.Privacy, nil, nil) - tp.AddField(port.BrowseURL(), nil, nil) + // Convert the ACE to a friendly visibility string (private, org, public) + visibility := portforwarder.AccessControlEntriesToVisibility(port.Port.AccessControl.Entries) + + tp.AddField(port.Label()) + tp.AddField(cs.Yellow(fmt.Sprintf("%d", port.Port.PortNumber))) + tp.AddField(visibility) + tp.AddField(port.BrowseURL()) tp.EndRow() } return tp.Render() } type portInfo struct { - *liveshare.Port + Port *tunnels.TunnelPort codespace *api.Codespace devContainer *devContainer } func (pi *portInfo) BrowseURL() string { - return fmt.Sprintf("https://%s-%d.preview.app.github.dev", pi.codespace.Name, pi.Port.SourcePort) + return fmt.Sprintf("https://%s-%d.app.github.dev", pi.codespace.Name, pi.Port.PortNumber) } func (pi *portInfo) Label() string { if pi.devContainer != nil { - portStr := strconv.Itoa(pi.Port.SourcePort) + portStr := strconv.Itoa(int(pi.Port.PortNumber)) if attributes, ok := pi.devContainer.PortAttributes[portStr]; ok { return attributes.Label } @@ -150,7 +152,6 @@ func (pi *portInfo) Label() string { var portFields = []string{ "sourcePort", - // "destinationPort", // TODO(mislav): this appears to always be blank? "visibility", "label", "browseUrl", @@ -162,11 +163,9 @@ func (pi *portInfo) ExportData(fields []string) map[string]interface{} { for _, f := range fields { switch f { case "sourcePort": - data[f] = pi.Port.SourcePort - case "destinationPort": - data[f] = pi.Port.DestinationPort + data[f] = pi.Port.PortNumber case "visibility": - data[f] = pi.Port.Privacy + data[f] = portforwarder.AccessControlEntriesToVisibility(pi.Port.AccessControl.Entries) case "label": data[f] = pi.Label() case "browseUrl": @@ -235,30 +234,6 @@ func newPortsVisibilityCmd(app *App, selector *CodespaceSelector) *cobra.Command } } -type ErrUpdatingPortVisibility struct { - port int - visibility string - err error -} - -func newErrUpdatingPortVisibility(port int, visibility string, err error) *ErrUpdatingPortVisibility { - return &ErrUpdatingPortVisibility{ - port: port, - visibility: visibility, - err: err, - } -} - -func (e *ErrUpdatingPortVisibility) Error() string { - return fmt.Sprintf("error waiting for port %d to update to %s: %s", e.port, e.visibility, e.err) -} - -func (e *ErrUpdatingPortVisibility) Unwrap() error { - return e.err -} - -var errUpdatePortVisibilityForbidden = errors.New("organization admin has forbidden this privacy setting") - func (a *App) UpdatePortVisibility(ctx context.Context, selector *CodespaceSelector, args []string) (err error) { ports, err := a.parsePortVisibilities(args) if err != nil { @@ -270,47 +245,28 @@ func (a *App) UpdatePortVisibility(ctx context.Context, selector *CodespaceSelec return err } - session, err := codespaces.ConnectToLiveshare(ctx, a, noopLogger(), a.apiClient, codespace) + codespaceConnection, err := codespaces.GetCodespaceConnection(ctx, a, a.apiClient, codespace) if err != nil { return fmt.Errorf("error connecting to codespace: %w", err) } - defer safeClose(session, &err) + + fwd, err := portforwarder.NewPortForwarder(ctx, codespaceConnection) + if err != nil { + return fmt.Errorf("failed to create port forwarder: %w", err) + } // TODO: check if port visibility can be updated in parallel instead of sequentially for _, port := range ports { err := a.RunWithProgress(fmt.Sprintf("Updating port %d visibility to: %s", port.number, port.visibility), func() (err error) { // wait for success or failure - g, ctx := errgroup.WithContext(ctx) ctx, cancel := context.WithTimeout(ctx, 30*time.Second) defer cancel() - g.Go(func() error { - updateNotif, err := session.WaitForPortNotification(ctx, port.number, liveshare.PortChangeKindUpdate) - if err != nil { - return fmt.Errorf("error waiting for port %d to update: %w", port.number, err) - - } - if !updateNotif.Success { - if updateNotif.StatusCode == http.StatusForbidden { - return newErrUpdatingPortVisibility(port.number, port.visibility, errUpdatePortVisibilityForbidden) - } - return newErrUpdatingPortVisibility(port.number, port.visibility, errors.New(updateNotif.ErrorDetail)) - - } - return nil // success - }) - - g.Go(func() error { - err := session.UpdateSharedServerPrivacy(ctx, port.number, port.visibility) - if err != nil { - return fmt.Errorf("error updating port %d to %s: %w", port.number, port.visibility, err) - } - return nil - }) - - // wait for success or failure - err = g.Wait() - return + err = fwd.UpdatePortVisibility(ctx, port.number, port.visibility) + if err != nil { + return fmt.Errorf("error updating port %d to %s: %w", port.number, port.visibility, err) + } + return nil }) if err != nil { return err @@ -367,11 +323,10 @@ func (a *App) ForwardPorts(ctx context.Context, selector *CodespaceSelector, por return err } - session, err := codespaces.ConnectToLiveshare(ctx, a, noopLogger(), a.apiClient, codespace) + codespaceConnection, err := codespaces.GetCodespaceConnection(ctx, a, a.apiClient, codespace) if err != nil { return fmt.Errorf("error connecting to codespace: %w", err) } - defer safeClose(session, &err) // Run forwarding of all ports concurrently, aborting all of // them at the first failure, including cancellation of the context. @@ -386,9 +341,11 @@ func (a *App) ForwardPorts(ctx context.Context, selector *CodespaceSelector, por defer listen.Close() a.errLogger.Printf("Forwarding ports: remote %d <=> local %d", pair.remote, pair.local) - name := fmt.Sprintf("share-%d", pair.remote) - fwd := liveshare.NewPortForwarder(session, name, pair.remote, false) - return fwd.ForwardToListener(ctx, listen) // error always non-nil + fwd, err := portforwarder.NewPortForwarder(ctx, codespaceConnection) + if err != nil { + return fmt.Errorf("failed to create port forwarder: %w", err) + } + return fwd.ForwardAndConnectToPort(ctx, uint16(pair.remote), listen, false, false) }) } return group.Wait() // first error diff --git a/pkg/cmd/codespace/ports_test.go b/pkg/cmd/codespace/ports_test.go index bb7554238..034c15eb6 100644 --- a/pkg/cmd/codespace/ports_test.go +++ b/pkg/cmd/codespace/ports_test.go @@ -2,18 +2,34 @@ package codespace import ( "context" - "errors" "fmt" - "os" + "net/http" "testing" "github.com/cli/cli/v2/internal/codespaces/api" + "github.com/cli/cli/v2/internal/codespaces/connection" "github.com/cli/cli/v2/pkg/iostreams" - "github.com/cli/cli/v2/pkg/liveshare" - livesharetest "github.com/cli/cli/v2/pkg/liveshare/test" - "github.com/sourcegraph/jsonrpc2" ) +func TestListPorts(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + mockApi := GetMockApi(false) + ios, _, _, _ := iostreams.Test() + + a := &App{ + io: ios, + apiClient: mockApi, + } + + selector := &CodespaceSelector{api: a.apiClient, codespaceName: "codespace-name"} + err := a.ListPorts(ctx, selector, nil) + if err != nil { + t.Errorf("unexpected error: %v", err) + } +} + func TestPortsUpdateVisibilitySuccess(t *testing.T) { portVisibilities := []portVisibility{ { @@ -26,175 +42,35 @@ func TestPortsUpdateVisibilitySuccess(t *testing.T) { }, } - eventResponses := []string{ - "serverSharing.sharingSucceeded", - "serverSharing.sharingSucceeded", - } - - portsData := []liveshare.PortNotification{ - { - Success: true, - Port: 80, - ChangeKind: liveshare.PortChangeKindUpdate, - }, - { - Success: true, - Port: 9999, - ChangeKind: liveshare.PortChangeKindUpdate, - }, - } - - err := runUpdateVisibilityTest(t, portVisibilities, eventResponses, portsData) - + err := runUpdateVisibilityTest(t, portVisibilities, true) if err != nil { t.Errorf("unexpected error: %v", err) } } -func TestPortsUpdateVisibilityFailure403(t *testing.T) { - portVisibilities := []portVisibility{ - { - number: 80, - visibility: "org", - }, - { - number: 9999, - visibility: "public", - }, - } - - eventResponses := []string{ - "serverSharing.sharingSucceeded", - "serverSharing.sharingFailed", - } - - portsData := []liveshare.PortNotification{ - { - Success: true, - Port: 80, - ChangeKind: liveshare.PortChangeKindUpdate, - }, - { - Success: false, - Port: 9999, - ChangeKind: liveshare.PortChangeKindUpdate, - ErrorDetail: "test error", - StatusCode: 403, - }, - } - - err := runUpdateVisibilityTest(t, portVisibilities, eventResponses, portsData) - if err == nil { - t.Fatalf("runUpdateVisibilityTest succeeded unexpectedly") - } - - if errors.Unwrap(err) != errUpdatePortVisibilityForbidden { - t.Errorf("expected: %v, got: %v", errUpdatePortVisibilityForbidden, errors.Unwrap(err)) - } -} - func TestPortsUpdateVisibilityFailure(t *testing.T) { portVisibilities := []portVisibility{ - { - number: 80, - visibility: "org", - }, { number: 9999, visibility: "public", }, - } - - eventResponses := []string{ - "serverSharing.sharingSucceeded", - "serverSharing.sharingFailed", - } - - portsData := []liveshare.PortNotification{ { - Success: true, - Port: 80, - ChangeKind: liveshare.PortChangeKindUpdate, - }, - { - Success: false, - Port: 9999, - ChangeKind: liveshare.PortChangeKindUpdate, - ErrorDetail: "test error", + number: 80, + visibility: "org", }, } - err := runUpdateVisibilityTest(t, portVisibilities, eventResponses, portsData) + err := runUpdateVisibilityTest(t, portVisibilities, false) if err == nil { t.Fatalf("runUpdateVisibilityTest succeeded unexpectedly") } - - var expectedErr *ErrUpdatingPortVisibility - if !errors.As(err, &expectedErr) { - t.Errorf("expected: %v, got: %v", expectedErr, err) - } } -type joinWorkspaceResult struct { - SessionNumber int `json:"sessionNumber"` -} - -func runUpdateVisibilityTest(t *testing.T, portVisibilities []portVisibility, eventResponses []string, portsData []liveshare.PortNotification) error { - t.Helper() - if os.Getenv("GITHUB_ACTIONS") == "true" { - t.Skip("fails intermittently in CI: https://github.com/cli/cli/issues/5663") - } - - joinWorkspace := func(conn *jsonrpc2.Conn, req *jsonrpc2.Request) (interface{}, error) { - return joinWorkspaceResult{1}, nil - } - const sessionToken = "session-token" - - ch := make(chan *jsonrpc2.Conn, 1) - updateSharedVisibility := func(conn *jsonrpc2.Conn, rpcReq *jsonrpc2.Request) (interface{}, error) { - ch <- conn - return nil, nil - } - testServer, err := livesharetest.NewServer( - livesharetest.WithNonSecure(), - livesharetest.WithPassword(sessionToken), - livesharetest.WithService("workspace.joinWorkspace", joinWorkspace), - livesharetest.WithService("serverSharing.updateSharedServerPrivacy", updateSharedVisibility), - ) - if err != nil { - return fmt.Errorf("unable to create test server: %w", err) - } - +func runUpdateVisibilityTest(t *testing.T, portVisibilities []portVisibility, allowOrgPorts bool) error { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - go func() { - for i, pd := range portsData { - select { - case <-ctx.Done(): - return - case conn := <-ch: - _, _ = conn.DispatchCall(ctx, eventResponses[i], pd, nil) - } - } - }() - - mockApi := &apiClientMock{ - GetCodespaceFunc: func(ctx context.Context, codespaceName string, includeConnection bool) (*api.Codespace, error) { - return &api.Codespace{ - Name: "codespace-name", - State: api.CodespaceStateAvailable, - Connection: api.CodespaceConnection{ - SessionID: "session-id", - SessionToken: sessionToken, - RelayEndpoint: testServer.URL(), - RelaySAS: "relay-sas", - HostPublicKeys: []string{livesharetest.SSHPublicKey}, - }, - }, nil - }, - } - + mockApi := GetMockApi(allowOrgPorts) ios, _, _, _ := iostreams.Test() a := &App{ @@ -251,6 +127,44 @@ func TestPendingOperationDisallowsForwardPorts(t *testing.T) { } } +func GetMockApi(allowOrgPorts bool) *apiClientMock { + return &apiClientMock{ + GetCodespaceFunc: func(ctx context.Context, codespaceName string, includeConnection bool) (*api.Codespace, error) { + allowedPortPrivacySettings := []string{"public", "private"} + if allowOrgPorts { + allowedPortPrivacySettings = append(allowedPortPrivacySettings, "org") + } + + return &api.Codespace{ + Name: "codespace-name", + State: api.CodespaceStateAvailable, + Connection: api.CodespaceConnection{ + TunnelProperties: api.TunnelProperties{ + ConnectAccessToken: "tunnel access-token", + ManagePortsAccessToken: "manage-ports-token", + ServiceUri: "http://global.rel.tunnels.api.visualstudio.com/", + TunnelId: "tunnel-id", + ClusterId: "usw2", + Domain: "domain.com", + }, + }, + RuntimeConstraints: api.RuntimeConstraints{ + AllowedPortPrivacySettings: allowedPortPrivacySettings, + }, + }, nil + }, + StartCodespaceFunc: func(ctx context.Context, codespaceName string) error { + return nil + }, + GetCodespaceRepositoryContentsFunc: func(ctx context.Context, codespace *api.Codespace, path string) ([]byte, error) { + return nil, nil + }, + HTTPClientFunc: func() (*http.Client, error) { + return connection.NewMockHttpClient() + }, + } +} + func testingPortsApp() *App { disabledCodespace := &api.Codespace{ Name: "disabledCodespace",