diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index 4b9029313..cd5695abc 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -1,5 +1,4 @@ * @cli/code-reviewers pkg/cmd/codespace/ @cli/codespaces -pkg/liveshare/ @cli/codespaces internal/codespaces/ @cli/codespaces diff --git a/go.mod b/go.mod index 96bbc113a..cc56bda51 100644 --- a/go.mod +++ b/go.mod @@ -27,12 +27,11 @@ require ( github.com/mattn/go-colorable v0.1.13 github.com/mattn/go-isatty v0.0.19 github.com/mgutz/ansi v0.0.0-20200706080929-d51e80ef957d - github.com/microsoft/dev-tunnels v0.0.21 + github.com/microsoft/dev-tunnels v0.0.25 github.com/muhammadmuzzammil1998/jsonc v0.0.0-20201229145248-615b0916ca38 github.com/opentracing/opentracing-go v1.1.0 github.com/rivo/tview v0.0.0-20221029100920-c4a7e501810d github.com/shurcooL/githubv4 v0.0.0-20230704064427-599ae7bbf278 - github.com/sourcegraph/jsonrpc2 v0.1.0 github.com/spf13/cobra v1.6.1 github.com/spf13/pflag v1.0.5 github.com/stretchr/testify v1.8.4 diff --git a/go.sum b/go.sum index a6e30ac16..c1160707a 100644 --- a/go.sum +++ b/go.sum @@ -67,7 +67,6 @@ github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 h1:El6M4kTTCOh6aBiKaU github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510/go.mod h1:pupxD2MaaD3pAXIBCelhxNneeOaAeabZDe5s4K6zSpQ= github.com/gorilla/css v1.0.0 h1:BQqNyPTi50JCFMTw/b67hByjMVXZRwGha6wxVGkeihY= github.com/gorilla/css v1.0.0/go.mod h1:Dn721qIggHpt4+EFCcTLTU/vk5ySda2ReITrtgBl60c= -github.com/gorilla/websocket v1.4.1/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/gorilla/websocket v1.4.2 h1:+/TMaTYc4QFitKJxsQ7Yye35DkWvkdLcvGKqM+x0Ufc= github.com/gorilla/websocket v1.4.2/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/h2non/parth v0.0.0-20190131123155-b4df798d6542 h1:2VTzZjLZBgl62/EtslCrtky5vbi9dd7HrQPQIx6wqiw= @@ -118,8 +117,8 @@ github.com/mgutz/ansi v0.0.0-20200706080929-d51e80ef957d/go.mod h1:01TrycV0kFyex github.com/microcosm-cc/bluemonday v1.0.21/go.mod h1:ytNkv4RrDrLJ2pqlsSI46O6IVXmZOBBD4SaJyDwwTkM= github.com/microcosm-cc/bluemonday v1.0.26 h1:xbqSvqzQMeEHCqMi64VAs4d8uy6Mequs3rQ0k/Khz58= github.com/microcosm-cc/bluemonday v1.0.26/go.mod h1:JyzOCs9gkyQyjs+6h10UEVSe02CGwkhd72Xdqh78TWs= -github.com/microsoft/dev-tunnels v0.0.21 h1:p4QP7C5ZOyP9bGbmanRjPxUMckfi9Z41Gl+KY4C11w0= -github.com/microsoft/dev-tunnels v0.0.21/go.mod h1:frU++12T/oqxckXkDpTuYa427ncguEOodSPZcGCCrzQ= +github.com/microsoft/dev-tunnels v0.0.25 h1:UlMKUI+2O8cSu4RlB52ioSyn1LthYSVkJA+CSTsdKoA= +github.com/microsoft/dev-tunnels v0.0.25/go.mod h1:frU++12T/oqxckXkDpTuYa427ncguEOodSPZcGCCrzQ= github.com/muesli/reflow v0.2.1-0.20210115123740-9e1d0d53df68/go.mod h1:Xk+z4oIWdQqJzsxyjgl3P22oYZnHdZ8FFTHAQQt5BMQ= github.com/muesli/reflow v0.3.0 h1:IFsN6K9NfGtjeggFP+68I4chLZV2yIKsXJFNZ+eWh6s= github.com/muesli/reflow v0.3.0/go.mod h1:pbwTDkVPibjO2kyvBQRBxTWEEGDGq0FlB1BIKtnHY/8= @@ -150,8 +149,6 @@ github.com/shurcooL/githubv4 v0.0.0-20230704064427-599ae7bbf278 h1:kdEGVAV4sO46D github.com/shurcooL/githubv4 v0.0.0-20230704064427-599ae7bbf278/go.mod h1:zqMwyHmnN/eDOZOdiTohqIUKUrTFX62PNlu7IJdu0q8= github.com/shurcooL/graphql v0.0.0-20230722043721-ed46e5a46466 h1:17JxqqJY66GmZVHkmAsGEkcIu0oCe3AM420QDgGwZx0= github.com/shurcooL/graphql v0.0.0-20230722043721-ed46e5a46466/go.mod h1:9dIRpgIY7hVhoqfe0/FcYp0bpInZaT7dc3BYOprrIUE= -github.com/sourcegraph/jsonrpc2 v0.1.0 h1:ohJHjZ+PcaLxDUjqk2NC3tIGsVa5bXThe1ZheSXOjuk= -github.com/sourcegraph/jsonrpc2 v0.1.0/go.mod h1:ZafdZgk/axhT1cvZAPOhw+95nz2I/Ra5qMlU4gTRwIo= github.com/spf13/cobra v1.6.1 h1:o94oiPyS4KD1mPy2fmcYYHHfCxLqYjJOhGsCHFZtEzA= github.com/spf13/cobra v1.6.1/go.mod h1:IOw/AERYS7UzyrGinqmz6HLUo219MORXGxhbaJUqzrY= github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA= diff --git a/internal/codespaces/api/api.go b/internal/codespaces/api/api.go index dd2ba4033..302e0ee0f 100644 --- a/internal/codespaces/api/api.go +++ b/internal/codespaces/api/api.go @@ -247,11 +247,6 @@ const ( ) type CodespaceConnection struct { - SessionID string `json:"sessionId"` - SessionToken string `json:"sessionToken"` - RelayEndpoint string `json:"relayEndpoint"` - RelaySAS string `json:"relaySas"` - HostPublicKeys []string `json:"hostPublicKeys"` TunnelProperties TunnelProperties `json:"tunnelProperties"` } diff --git a/internal/codespaces/codespaces.go b/internal/codespaces/codespaces.go index 3bcc2b404..da2eacd44 100644 --- a/internal/codespaces/codespaces.go +++ b/internal/codespaces/codespaces.go @@ -11,30 +11,20 @@ import ( "github.com/cenkalti/backoff/v4" "github.com/cli/cli/v2/internal/codespaces/api" "github.com/cli/cli/v2/internal/codespaces/connection" - "github.com/cli/cli/v2/pkg/liveshare" ) -func connectionReady(codespace *api.Codespace, usingDevTunnels bool) bool { +func connectionReady(codespace *api.Codespace) bool { // If the codespace is not available, it is not ready if codespace.State != api.CodespaceStateAvailable { return false } - // If using Dev Tunnels, we need to check that we have all of the required tunnel properties - if usingDevTunnels { - return codespace.Connection.TunnelProperties.ConnectAccessToken != "" && - codespace.Connection.TunnelProperties.ManagePortsAccessToken != "" && - codespace.Connection.TunnelProperties.ServiceUri != "" && - codespace.Connection.TunnelProperties.TunnelId != "" && - codespace.Connection.TunnelProperties.ClusterId != "" && - codespace.Connection.TunnelProperties.Domain != "" - } - - // If not using Dev Tunnels, we need to check that we have all of the required Live Share properties - return codespace.Connection.SessionID != "" && - codespace.Connection.SessionToken != "" && - codespace.Connection.RelayEndpoint != "" && - codespace.Connection.RelaySAS != "" + return codespace.Connection.TunnelProperties.ConnectAccessToken != "" && + codespace.Connection.TunnelProperties.ManagePortsAccessToken != "" && + codespace.Connection.TunnelProperties.ServiceUri != "" && + codespace.Connection.TunnelProperties.TunnelId != "" && + codespace.Connection.TunnelProperties.ClusterId != "" && + codespace.Connection.TunnelProperties.Domain != "" } type apiClient interface { @@ -48,11 +38,6 @@ type progressIndicator interface { StopProgressIndicator() } -type logger interface { - Println(v ...interface{}) - Printf(f string, v ...interface{}) -} - type TimeoutError struct { message string } @@ -64,7 +49,7 @@ func (e *TimeoutError) Error() string { // GetCodespaceConnection waits until a codespace is able // to be connected to and initializes a connection to it. func GetCodespaceConnection(ctx context.Context, progress progressIndicator, apiClient apiClient, codespace *api.Codespace) (*connection.CodespaceConnection, error) { - codespace, err := waitUntilCodespaceConnectionReady(ctx, progress, apiClient, codespace, true) + codespace, err := waitUntilCodespaceConnectionReady(ctx, progress, apiClient, codespace) if err != nil { return nil, err } @@ -80,29 +65,8 @@ func GetCodespaceConnection(ctx context.Context, progress progressIndicator, api return connection.NewCodespaceConnection(ctx, codespace, httpClient) } -// ConnectToLiveshare waits until a codespace is able to be -// connected to and connects to it using a Live Share session. -func ConnectToLiveshare(ctx context.Context, progress progressIndicator, sessionLogger logger, apiClient apiClient, codespace *api.Codespace) (*liveshare.Session, error) { - codespace, err := waitUntilCodespaceConnectionReady(ctx, progress, apiClient, codespace, false) - if err != nil { - return nil, err - } - - progress.StartProgressIndicatorWithLabel("Connecting to codespace") - defer progress.StopProgressIndicator() - - return liveshare.Connect(ctx, liveshare.Options{ - SessionID: codespace.Connection.SessionID, - SessionToken: codespace.Connection.SessionToken, - RelaySAS: codespace.Connection.RelaySAS, - RelayEndpoint: codespace.Connection.RelayEndpoint, - HostPublicKeys: codespace.Connection.HostPublicKeys, - Logger: sessionLogger, - }) -} - // waitUntilCodespaceConnectionReady waits for a Codespace to be running and is able to be connected to. -func waitUntilCodespaceConnectionReady(ctx context.Context, progress progressIndicator, apiClient apiClient, codespace *api.Codespace, usingDevTunnels bool) (*api.Codespace, error) { +func waitUntilCodespaceConnectionReady(ctx context.Context, progress progressIndicator, apiClient apiClient, codespace *api.Codespace) (*api.Codespace, error) { if codespace.State != api.CodespaceStateAvailable { progress.StartProgressIndicatorWithLabel("Starting codespace") defer progress.StopProgressIndicator() @@ -111,7 +75,7 @@ func waitUntilCodespaceConnectionReady(ctx context.Context, progress progressInd } } - if !connectionReady(codespace, usingDevTunnels) { + if !connectionReady(codespace) { expBackoff := backoff.NewExponentialBackOff() expBackoff.Multiplier = 1.1 expBackoff.MaxInterval = 10 * time.Second @@ -124,7 +88,7 @@ func waitUntilCodespaceConnectionReady(ctx context.Context, progress progressInd return backoff.Permanent(fmt.Errorf("error getting codespace: %w", err)) } - if connectionReady(codespace, usingDevTunnels) { + if connectionReady(codespace) { return nil } diff --git a/internal/codespaces/portforwarder/port_forwarder.go b/internal/codespaces/portforwarder/port_forwarder.go index ba510e4ff..44838a6be 100644 --- a/internal/codespaces/portforwarder/port_forwarder.go +++ b/internal/codespaces/portforwarder/port_forwarder.go @@ -3,6 +3,7 @@ package portforwarder import ( "context" "fmt" + "io" "net" "strings" @@ -22,101 +23,75 @@ const ( PublicPortVisibility = "public" ) -type PortForwarder struct { - connection connection.CodespaceConnection +const ( + trafficTypeInput = "input" + trafficTypeOutput = "output" +) + +type ForwardPortOpts struct { + Port int + Internal bool + KeepAlive bool + Visibility string +} + +type CodespacesPortForwarder struct { + connection connection.CodespaceConnection + keepAliveReason chan string +} + +type PortForwarder interface { + ForwardPortToListener(ctx context.Context, opts ForwardPortOpts, listener *net.TCPListener) error + ForwardPort(ctx context.Context, opts ForwardPortOpts) error + ConnectToForwardedPort(ctx context.Context, conn io.ReadWriteCloser, opts ForwardPortOpts) error + ListPorts(ctx context.Context) ([]*tunnels.TunnelPort, error) + UpdatePortVisibility(ctx context.Context, remotePort int, visibility string) error + KeepAlive(reason string) + GetKeepAliveReason() string + CloseSSHConnection() } // NewPortForwarder returns a new PortForwarder for the specified codespace. -func NewPortForwarder(ctx context.Context, codespaceConnection *connection.CodespaceConnection) (fwd *PortForwarder, err error) { - return &PortForwarder{ - connection: *codespaceConnection, +func NewPortForwarder(ctx context.Context, codespaceConnection *connection.CodespaceConnection) (fwd PortForwarder, err error) { + return &CodespacesPortForwarder{ + connection: *codespaceConnection, + keepAliveReason: make(chan string, 1), }, nil } -// ForwardAndConnectToPort forwards a port and connects to it via a local TCP port. -func (fwd *PortForwarder) ForwardAndConnectToPort(ctx context.Context, remotePort uint16, listen *net.TCPListener, keepAlive bool, internal bool) error { - return fwd.ForwardPort(ctx, remotePort, listen, keepAlive, true, internal, "") -} - -// ForwardPort forwards a port and optionally connects to it via a local TCP port. -func (fwd *PortForwarder) ForwardPort(ctx context.Context, remotePort uint16, listen *net.TCPListener, keepAlive bool, connect bool, internal bool, visibility string) error { - tunnelPort := tunnels.NewTunnelPort(remotePort, "", "", tunnels.TunnelProtocolHttp) - - // If no visibility is provided, Dev Tunnels will use the default (private) - if visibility != "" { - // Check if the requested visibility is allowed - allowed := false - for _, allowedVisibility := range fwd.connection.AllowedPortPrivacySettings { - if allowedVisibility == visibility { - allowed = true - break - } - } - - // If the requested visibility is not allowed, return an error - if !allowed { - return fmt.Errorf("visibility %s is not allowed", visibility) - } - - accessControlEntries := visibilityToAccessControlEntries(visibility) - if len(accessControlEntries) > 0 { - tunnelPort.AccessControl = &tunnels.TunnelAccessControl{ - Entries: accessControlEntries, - } - } +// ForwardPortToListener forwards the specified port to the given TCP listener. +func (fwd *CodespacesPortForwarder) ForwardPortToListener(ctx context.Context, opts ForwardPortOpts, listener *net.TCPListener) error { + err := fwd.ForwardPort(ctx, opts) + if err != nil { + return fmt.Errorf("error forwarding port: %w", err) } - // Tag the port as internal or user forwarded so we know if it needs to be shown in the UI - if internal { - tunnelPort.Tags = []string{InternalPortTag} - } else { - tunnelPort.Tags = []string{UserForwardedPortTag} - } - - // Create the tunnel port - _, err := fwd.connection.TunnelManager.CreateTunnelPort(ctx, fwd.connection.Tunnel, tunnelPort, fwd.connection.Options) - if err != nil && !strings.Contains(err.Error(), "409") { - return fmt.Errorf("create tunnel port failed: %v", err) - } + // Close the SSH connection when we're done + defer fwd.CloseSSHConnection() done := make(chan error) go func() { - // Connect to the tunnel - err = fwd.connection.TunnelClient.Connect(ctx, "") + // Convert the port number to a uint16 + port, err := convertIntToUint16(opts.Port) if err != nil { - done <- fmt.Errorf("connect failed: %v", err) - return - } - - // Inform the host that we've forwarded the port locally - err = fwd.connection.TunnelClient.RefreshPorts(ctx) - if err != nil { - done <- fmt.Errorf("refresh ports failed: %v", err) - return - } - - // If we don't want to connect to the port, exit early - if !connect { - done <- nil + done <- fmt.Errorf("error converting port: %w", err) return } // Ensure the port is forwarded before connecting - err = fwd.connection.TunnelClient.WaitForForwardedPort(ctx, remotePort) + err = fwd.connection.TunnelClient.WaitForForwardedPort(ctx, port) if err != nil { done <- fmt.Errorf("wait for forwarded port failed: %v", err) return } - // Connect to the forwarded port via a local TCP port - err = fwd.connection.TunnelClient.ConnectToForwardedPort(ctx, listen, remotePort) + // Connect to the forwarded port + err = fwd.connectListenerToForwardedPort(ctx, opts, listener) if err != nil { done <- fmt.Errorf("connect to forwarded port failed: %v", err) - return } - - done <- nil }() + select { case err := <-done: if err != nil { @@ -128,8 +103,131 @@ func (fwd *PortForwarder) ForwardPort(ctx context.Context, remotePort uint16, li } } +// ForwardPort informs the host that we would like to forward the given port. +func (fwd *CodespacesPortForwarder) ForwardPort(ctx context.Context, opts ForwardPortOpts) error { + // Convert the port number to a uint16 + port, err := convertIntToUint16(opts.Port) + if err != nil { + return fmt.Errorf("error converting port: %w", err) + } + + tunnelPort := tunnels.NewTunnelPort(port, "", "", tunnels.TunnelProtocolHttp) + + // If no visibility is provided, Dev Tunnels will use the default (private) + if opts.Visibility != "" { + // Check if the requested visibility is allowed + allowed := false + for _, allowedVisibility := range fwd.connection.AllowedPortPrivacySettings { + if allowedVisibility == opts.Visibility { + allowed = true + break + } + } + + // If the requested visibility is not allowed, return an error + if !allowed { + return fmt.Errorf("visibility %s is not allowed", opts.Visibility) + } + + accessControlEntries := visibilityToAccessControlEntries(opts.Visibility) + if len(accessControlEntries) > 0 { + tunnelPort.AccessControl = &tunnels.TunnelAccessControl{ + Entries: accessControlEntries, + } + } + } + + // Tag the port as internal or user forwarded so we know if it needs to be shown in the UI + if opts.Internal { + tunnelPort.Tags = []string{InternalPortTag} + } else { + tunnelPort.Tags = []string{UserForwardedPortTag} + } + + // Create the tunnel port + _, err = fwd.connection.TunnelManager.CreateTunnelPort(ctx, fwd.connection.Tunnel, tunnelPort, fwd.connection.Options) + if err != nil && !strings.Contains(err.Error(), "409") { + return fmt.Errorf("create tunnel port failed: %v", err) + } + + // Connect to the tunnel + err = fwd.connection.TunnelClient.Connect(ctx, "") + if err != nil { + return fmt.Errorf("connect failed: %v", err) + } + + // Inform the host that we've forwarded the port locally + err = fwd.connection.TunnelClient.RefreshPorts(ctx) + if err != nil { + fwd.CloseSSHConnection() + return fmt.Errorf("refresh ports failed: %v", err) + } + + return nil +} + +// connectListenerToForwardedPort connects to the forwarded port via a local TCP port. +func (fwd *CodespacesPortForwarder) connectListenerToForwardedPort(ctx context.Context, opts ForwardPortOpts, listener *net.TCPListener) (err error) { + errc := make(chan error, 1) + sendError := func(err error) { + // Use non-blocking send, to avoid goroutines getting + // stuck in case of concurrent or sequential errors. + select { + case errc <- err: + default: + } + } + go func() { + for { + conn, err := listener.AcceptTCP() + if err != nil { + sendError(err) + return + } + + // Connect to the forwarded port in a goroutine so we can accept new connections + go func() { + if err := fwd.ConnectToForwardedPort(ctx, conn, opts); err != nil { + sendError(err) + } + }() + } + }() + + // Wait for an error or for the context to be cancelled + select { + case err := <-errc: + return err + case <-ctx.Done(): + return ctx.Err() // canceled + } +} + +// ConnectToForwardedPort connects to the forwarded port via a given ReadWriteCloser. +// Optionally, it detects traffic over the connection and sends activity signals to the server to keep the codespace from shutting down. +func (fwd *CodespacesPortForwarder) ConnectToForwardedPort(ctx context.Context, conn io.ReadWriteCloser, opts ForwardPortOpts) error { + // Create a traffic monitor to keep the session alive + if opts.KeepAlive { + conn = newTrafficMonitor(conn, fwd) + } + + // Convert the port number to a uint16 + port, err := convertIntToUint16(opts.Port) + if err != nil { + return fmt.Errorf("error converting port: %w", err) + } + + // Connect to the forwarded port + err = fwd.connection.TunnelClient.ConnectToForwardedPort(ctx, conn, port) + if err != nil { + return fmt.Errorf("error connecting to forwarded port: %w", err) + } + + return nil +} + // ListPorts fetches the list of ports that are currently forwarded. -func (fwd *PortForwarder) ListPorts(ctx context.Context) (ports []*tunnels.TunnelPort, err error) { +func (fwd *CodespacesPortForwarder) ListPorts(ctx context.Context) (ports []*tunnels.TunnelPort, err error) { ports, err = fwd.connection.TunnelManager.ListTunnelPorts(ctx, fwd.connection.Tunnel, fwd.connection.Options) if err != nil { return nil, fmt.Errorf("error listing ports: %w", err) @@ -139,7 +237,7 @@ func (fwd *PortForwarder) ListPorts(ctx context.Context) (ports []*tunnels.Tunne } // UpdatePortVisibility changes the visibility (private, org, public) of the specified port. -func (fwd *PortForwarder) UpdatePortVisibility(ctx context.Context, remotePort int, visibility string) error { +func (fwd *CodespacesPortForwarder) UpdatePortVisibility(ctx context.Context, remotePort int, visibility string) error { tunnelPort, err := fwd.connection.TunnelManager.GetTunnelPort(ctx, fwd.connection.Tunnel, remotePort, fwd.connection.Options) if err != nil { return fmt.Errorf("error getting tunnel port: %w", err) @@ -165,6 +263,9 @@ func (fwd *PortForwarder) UpdatePortVisibility(ctx context.Context, remotePort i return } + // Close the SSH connection when we're done + defer fwd.CloseSSHConnection() + // Inform the host that we've deleted the port err = fwd.connection.TunnelClient.RefreshPorts(ctx) if err != nil { @@ -172,6 +273,13 @@ func (fwd *PortForwarder) UpdatePortVisibility(ctx context.Context, remotePort i return } + // Re-forward the port with the updated visibility + err = fwd.ForwardPort(ctx, ForwardPortOpts{Port: remotePort, Visibility: visibility}) + if err != nil { + done <- fmt.Errorf("error forwarding port: %w", err) + return + } + done <- nil }() @@ -179,13 +287,10 @@ func (fwd *PortForwarder) UpdatePortVisibility(ctx context.Context, remotePort i select { case err := <-done: if err != nil { - return fmt.Errorf("error connecting to tunnel: %w", err) - } + // If we fail to re-forward the port, we need to forward again with the original visibility so the port is still accessible + _ = fwd.ForwardPort(ctx, ForwardPortOpts{Port: remotePort, Visibility: AccessControlEntriesToVisibility(tunnelPort.AccessControl.Entries)}) - // Re-forward the port with the updated visibility - err = fwd.ForwardPort(ctx, uint16(remotePort), nil, false, false, false, visibility) - if err != nil { - return fmt.Errorf("error forwarding port: %w", err) + return fmt.Errorf("error connecting to tunnel: %w", err) } return nil @@ -194,6 +299,27 @@ func (fwd *PortForwarder) UpdatePortVisibility(ctx context.Context, remotePort i } } +// KeepAlive accepts a reason that is retained if there is no active reason +// to send to the server. +func (fwd *CodespacesPortForwarder) KeepAlive(reason string) { + select { + case fwd.keepAliveReason <- reason: + default: + // there is already an active keep alive reason + // so we can ignore this one + } +} + +// GetKeepAliveReason fetches the keep alive reason from the channel and returns it. +func (fwd *CodespacesPortForwarder) GetKeepAliveReason() string { + return <-fwd.keepAliveReason +} + +// Close closes the port forwarder's tunnel client connection. +func (fwd *CodespacesPortForwarder) CloseSSHConnection() { + _ = fwd.connection.TunnelClient.Close() +} + // AccessControlEntriesToVisibility converts the access control entries used by Dev Tunnels to a friendly visibility value. func AccessControlEntriesToVisibility(accessControlEntries []tunnels.TunnelAccessControlEntry) string { for _, entry := range accessControlEntries { @@ -251,3 +377,45 @@ func IsInternalPort(port *tunnels.TunnelPort) bool { return false } + +// convertIntToUint16 converts the given int to a uint16. +func convertIntToUint16(port int) (uint16, error) { + var updatedPort uint16 + if port >= 0 && port <= 65535 { + updatedPort = uint16(port) + } else { + return 0, fmt.Errorf("invalid port number: %d", port) + } + + return updatedPort, nil +} + +// trafficMonitor implements io.Reader. It keeps the session alive by notifying +// it of the traffic type during Read operations. +type trafficMonitor struct { + rwc io.ReadWriteCloser + fwd PortForwarder +} + +// newTrafficMonitor returns a trafficMonitor for the specified codespace connection. +// It wraps the provided io.ReaderWriteCloser with its own Read/Write/Close methods. +func newTrafficMonitor(rwc io.ReadWriteCloser, fwd PortForwarder) *trafficMonitor { + return &trafficMonitor{rwc, fwd} +} + +// Read wraps the underlying ReadWriteCloser's Read method and keeps the session alive with the "input" traffic type. +func (t *trafficMonitor) Read(p []byte) (n int, err error) { + t.fwd.KeepAlive(trafficTypeInput) + return t.rwc.Read(p) +} + +// Write wraps the underlying ReadWriteCloser's Write method and keeps the session alive with the "output" traffic type. +func (t *trafficMonitor) Write(p []byte) (n int, err error) { + t.fwd.KeepAlive(trafficTypeOutput) + return t.rwc.Write(p) +} + +// Close closes the underlying ReadWriteCloser. +func (t *trafficMonitor) Close() error { + return t.rwc.Close() +} diff --git a/internal/codespaces/rpc/invoker.go b/internal/codespaces/rpc/invoker.go index 39ae5ab19..7701d6092 100644 --- a/internal/codespaces/rpc/invoker.go +++ b/internal/codespaces/rpc/invoker.go @@ -12,10 +12,10 @@ import ( "strings" "time" + "github.com/cli/cli/v2/internal/codespaces/portforwarder" "github.com/cli/cli/v2/internal/codespaces/rpc/codespace" "github.com/cli/cli/v2/internal/codespaces/rpc/jupyter" "github.com/cli/cli/v2/internal/codespaces/rpc/ssh" - "github.com/cli/cli/v2/pkg/liveshare" "google.golang.org/grpc" "google.golang.org/grpc/credentials/insecure" "google.golang.org/grpc/metadata" @@ -47,7 +47,7 @@ type Invoker interface { type invoker struct { conn *grpc.ClientConn - session liveshare.LiveshareSession + fwd portforwarder.PortForwarder listener net.Listener jupyterClient jupyter.JupyterServerHostClient codespaceClient codespace.CodespaceHostClient @@ -56,11 +56,11 @@ type invoker struct { } // Connects to the internal RPC server and returns a new invoker for it -func CreateInvoker(ctx context.Context, session liveshare.LiveshareSession) (Invoker, error) { +func CreateInvoker(ctx context.Context, fwd portforwarder.PortForwarder) (Invoker, error) { ctx, cancel := context.WithTimeout(ctx, ConnectionTimeout) defer cancel() - invoker, err := connect(ctx, session) + invoker, err := connect(ctx, fwd) if err != nil { return nil, fmt.Errorf("error connecting to internal server: %w", err) } @@ -69,7 +69,7 @@ func CreateInvoker(ctx context.Context, session liveshare.LiveshareSession) (Inv } // Finds a free port to listen on and creates a new RPC invoker that connects to that port -func connect(ctx context.Context, session liveshare.LiveshareSession) (Invoker, error) { +func connect(ctx context.Context, fwd portforwarder.PortForwarder) (Invoker, error) { listener, err := listenTCP() if err != nil { return nil, err @@ -77,7 +77,7 @@ func connect(ctx context.Context, session liveshare.LiveshareSession) (Invoker, localAddress := listener.Addr().String() invoker := &invoker{ - session: session, + fwd: fwd, listener: listener, } @@ -100,8 +100,12 @@ func connect(ctx context.Context, session liveshare.LiveshareSession) (Invoker, // Tunnel the remote gRPC server port to the local port go func() { - fwd := liveshare.NewPortForwarder(session, codespacesInternalSessionName, codespacesInternalPort, true) - ch <- fwd.ForwardToListener(pfctx, listener) + // Start forwarding the port locally + opts := portforwarder.ForwardPortOpts{ + Port: codespacesInternalPort, + Internal: true, + } + ch <- fwd.ForwardPortToListener(pfctx, opts, listener) }() var conn *grpc.ClientConn @@ -262,7 +266,7 @@ func (i *invoker) heartbeat(ctx context.Context, interval time.Duration) { case <-ctx.Done(): return case <-ticker.C: - reason := i.session.GetKeepAliveReason() + reason := i.fwd.GetKeepAliveReason() _ = i.notifyCodespaceOfClientActivity(ctx, reason) } } diff --git a/internal/codespaces/rpc/invoker_test.go b/internal/codespaces/rpc/invoker_test.go index ba3e13ac3..d9b271c54 100644 --- a/internal/codespaces/rpc/invoker_test.go +++ b/internal/codespaces/rpc/invoker_test.go @@ -72,7 +72,8 @@ func createTestInvoker(t *testing.T, server *mockServer) (Invoker, func(), error listener.Close() } - invoker, err := CreateInvoker(context.Background(), &rpctest.Session{}) + // Create a new invoker with a mock port forwarder + invoker, err := CreateInvoker(context.Background(), rpctest.PortForwarder{}) if err != nil { close() return nil, nil, fmt.Errorf("error connecting to internal server: %w", err) diff --git a/internal/codespaces/rpc/test/channel.go b/internal/codespaces/rpc/test/channel.go deleted file mode 100644 index eef42c4aa..000000000 --- a/internal/codespaces/rpc/test/channel.go +++ /dev/null @@ -1,34 +0,0 @@ -package test - -import ( - "io" - "net" -) - -type Channel struct { - conn net.Conn -} - -func (c *Channel) Read(data []byte) (int, error) { - return c.conn.Read(data) -} - -func (c *Channel) Write(data []byte) (int, error) { - return c.conn.Write(data) -} - -func (c *Channel) Close() error { - return c.conn.Close() -} - -func (c *Channel) CloseWrite() error { - return nil -} - -func (c *Channel) SendRequest(name string, wantReply bool, payload []byte) (bool, error) { - return false, nil -} - -func (c *Channel) Stderr() io.ReadWriter { - return nil -} diff --git a/internal/codespaces/rpc/test/port_forwarder.go b/internal/codespaces/rpc/test/port_forwarder.go new file mode 100644 index 000000000..4993ac0a1 --- /dev/null +++ b/internal/codespaces/rpc/test/port_forwarder.go @@ -0,0 +1,78 @@ +package test + +import ( + "context" + "fmt" + "io" + "net" + + "github.com/cli/cli/v2/internal/codespaces/portforwarder" + "github.com/microsoft/dev-tunnels/go/tunnels" +) + +type PortForwarder struct{} + +// Close implements portforwarder.PortForwarder. +func (PortForwarder) CloseSSHConnection() { + panic("unimplemented") +} + +// ConnectToForwardedPort implements portforwarder.PortForwarder. +func (PortForwarder) ConnectToForwardedPort(ctx context.Context, conn io.ReadWriteCloser, opts portforwarder.ForwardPortOpts) error { + panic("unimplemented") +} + +// ForwardPort implements portforwarder.PortForwarder. +func (PortForwarder) ForwardPort(ctx context.Context, opts portforwarder.ForwardPortOpts) error { + panic("unimplemented") +} + +// GetKeepAliveReason implements portforwarder.PortForwarder. +func (PortForwarder) GetKeepAliveReason() string { + panic("unimplemented") +} + +// KeepAlive implements portforwarder.PortForwarder. +func (PortForwarder) KeepAlive(reason string) { + panic("unimplemented") +} + +// ForwardPortToListener implements portforwarder.PortForwarder. +func (PortForwarder) ForwardPortToListener(ctx context.Context, opts portforwarder.ForwardPortOpts, listener *net.TCPListener) error { + // Start forwarding the port locally + hostConn, err := net.Dial("tcp", fmt.Sprintf("127.0.0.1:%d", opts.Port)) + if err != nil { + return err + } + + // Accept the connection from the listener + listenerConn, err := listener.Accept() + if err != nil { + return err + } + + // Copy data between the two connections + go func() { + _, _ = io.Copy(hostConn, listenerConn) + hostConn.Close() + }() + go func() { + _, _ = io.Copy(listenerConn, hostConn) + listenerConn.Close() + }() + + // ForwardPortToListener typically blocks until the context is cancelled so we need to do the same + <-ctx.Done() + + return nil +} + +// ListPorts implements portforwarder.PortForwarder. +func (PortForwarder) ListPorts(ctx context.Context) ([]*tunnels.TunnelPort, error) { + panic("unimplemented") +} + +// UpdatePortVisibility implements portforwarder.PortForwarder. +func (PortForwarder) UpdatePortVisibility(ctx context.Context, remotePort int, visibility string) error { + panic("unimplemented") +} diff --git a/internal/codespaces/rpc/test/session.go b/internal/codespaces/rpc/test/session.go deleted file mode 100644 index 531d4c33f..000000000 --- a/internal/codespaces/rpc/test/session.go +++ /dev/null @@ -1,43 +0,0 @@ -package test - -import ( - "context" - "fmt" - "net" - - "github.com/cli/cli/v2/pkg/liveshare" - "golang.org/x/crypto/ssh" -) - -type Session struct { - channel ssh.Channel -} - -func (*Session) Close() error { - panic("unimplemented") -} - -func (*Session) GetSharedServers(context.Context) ([]*liveshare.Port, error) { - panic("unimplemented") -} - -func (s *Session) KeepAlive(reason string) { -} - -func (s *Session) GetKeepAliveReason() string { - return "" -} - -func (s *Session) StartSharing(ctx context.Context, sessionName string, port int) (liveshare.ChannelID, error) { - conn, err := net.Dial("tcp", fmt.Sprintf("127.0.0.1:%d", port)) - if err != nil { - return liveshare.ChannelID{}, err - } - s.channel = &Channel{conn} - return liveshare.ChannelID{}, nil -} - -// Creates mock SSH channel connected to the mock gRPC server -func (s *Session) OpenStreamingChannel(ctx context.Context, id liveshare.ChannelID) (ssh.Channel, error) { - return s.channel, nil -} diff --git a/internal/codespaces/states.go b/internal/codespaces/states.go index 3a2872365..afbdf4673 100644 --- a/internal/codespaces/states.go +++ b/internal/codespaces/states.go @@ -6,13 +6,12 @@ import ( "encoding/json" "fmt" "io" - "log" "time" "github.com/cli/cli/v2/internal/codespaces/api" + "github.com/cli/cli/v2/internal/codespaces/portforwarder" "github.com/cli/cli/v2/internal/codespaces/rpc" "github.com/cli/cli/v2/internal/text" - "github.com/cli/cli/v2/pkg/liveshare" ) // PostCreateStateStatus is a string value representing the different statuses a state can have. @@ -39,17 +38,15 @@ type PostCreateState struct { // and calls the supplied poller for each batch of state changes. // It runs until it encounters an error, including cancellation of the context. func PollPostCreateStates(ctx context.Context, progress progressIndicator, apiClient apiClient, codespace *api.Codespace, poller func([]PostCreateState)) (err error) { - noopLogger := log.New(io.Discard, "", 0) - - session, err := ConnectToLiveshare(ctx, progress, noopLogger, apiClient, codespace) + codespaceConnection, err := GetCodespaceConnection(ctx, progress, apiClient, codespace) if err != nil { - return fmt.Errorf("connect to codespace: %w", err) + return fmt.Errorf("error connecting to codespace: %w", err) + } + + fwd, err := portforwarder.NewPortForwarder(ctx, codespaceConnection) + if err != nil { + return fmt.Errorf("failed to create port forwarder: %w", err) } - defer func() { - if closeErr := session.Close(); err == nil { - err = closeErr - } - }() // Ensure local port is listening before client (getPostCreateOutput) connects. listen, localPort, err := ListenTCP(0, false) @@ -58,7 +55,7 @@ func PollPostCreateStates(ctx context.Context, progress progressIndicator, apiCl } progress.StartProgressIndicatorWithLabel("Fetching SSH Details") - invoker, err := rpc.CreateInvoker(ctx, session) + invoker, err := rpc.CreateInvoker(ctx, fwd) if err != nil { return err } @@ -73,8 +70,11 @@ func PollPostCreateStates(ctx context.Context, progress progressIndicator, apiCl progress.StartProgressIndicatorWithLabel("Fetching status") tunnelClosed := make(chan error, 1) // buffered to avoid sender stuckness go func() { - fwd := liveshare.NewPortForwarder(session, "sshd", remoteSSHServerPort, false) - tunnelClosed <- fwd.ForwardToListener(ctx, listen) // error is non-nil + opts := portforwarder.ForwardPortOpts{ + Port: remoteSSHServerPort, + Internal: true, + } + tunnelClosed <- fwd.ForwardPortToListener(ctx, opts, listen) }() t := time.NewTicker(1 * time.Second) diff --git a/pkg/cmd/codespace/common.go b/pkg/cmd/codespace/common.go index 86493ce8e..3ad3463e6 100644 --- a/pkg/cmd/codespace/common.go +++ b/pkg/cmd/codespace/common.go @@ -17,10 +17,8 @@ import ( "github.com/AlecAivazis/survey/v2/terminal" clicontext "github.com/cli/cli/v2/context" "github.com/cli/cli/v2/internal/browser" - "github.com/cli/cli/v2/internal/codespaces" "github.com/cli/cli/v2/internal/codespaces/api" "github.com/cli/cli/v2/pkg/iostreams" - "github.com/cli/cli/v2/pkg/liveshare" "github.com/spf13/cobra" "golang.org/x/term" ) @@ -65,28 +63,6 @@ func (a *App) RunWithProgress(label string, run func() error) error { return a.io.RunWithProgress(label, run) } -// Connects to a codespace using Live Share and returns that session -func startLiveShareSession(ctx context.Context, codespace *api.Codespace, a *App, debug bool, debugFile string) (session *liveshare.Session, err error) { - liveshareLogger := noopLogger() - if debug { - debugLogger, err := newFileLogger(debugFile) - if err != nil { - return nil, fmt.Errorf("couldn't create file logger: %w", err) - } - defer safeClose(debugLogger, &err) - - liveshareLogger = debugLogger.Logger - a.errLogger.Printf("Debug file located at: %s", debugLogger.Name()) - } - - session, err = codespaces.ConnectToLiveshare(ctx, a, liveshareLogger, a.apiClient, codespace) - if err != nil { - return nil, fmt.Errorf("failed to connect to Live Share: %w", err) - } - - return session, nil -} - //go:generate moq -fmt goimports -rm -skip-ensure -out mock_api.go . apiClient type apiClient interface { ServerURL() string @@ -201,10 +177,6 @@ func noArgsConstraint(cmd *cobra.Command, args []string) error { return nil } -func noopLogger() *log.Logger { - return log.New(io.Discard, "", 0) -} - type codespace struct { *api.Codespace } diff --git a/pkg/cmd/codespace/jupyter.go b/pkg/cmd/codespace/jupyter.go index 91c798eda..3546837fa 100644 --- a/pkg/cmd/codespace/jupyter.go +++ b/pkg/cmd/codespace/jupyter.go @@ -7,8 +7,8 @@ import ( "strings" "github.com/cli/cli/v2/internal/codespaces" + "github.com/cli/cli/v2/internal/codespaces/portforwarder" "github.com/cli/cli/v2/internal/codespaces/rpc" - "github.com/cli/cli/v2/pkg/liveshare" "github.com/spf13/cobra" ) @@ -39,11 +39,15 @@ func (a *App) Jupyter(ctx context.Context, selector *CodespaceSelector) (err err return err } - session, err := startLiveShareSession(ctx, codespace, a, false, "") + codespaceConnection, err := codespaces.GetCodespaceConnection(ctx, a, a.apiClient, codespace) if err != nil { - return err + return fmt.Errorf("error connecting to codespace: %w", err) + } + + fwd, err := portforwarder.NewPortForwarder(ctx, codespaceConnection) + if err != nil { + return fmt.Errorf("failed to create port forwarder: %w", err) } - defer safeClose(session, &err) var ( invoker rpc.Invoker @@ -51,7 +55,7 @@ func (a *App) Jupyter(ctx context.Context, selector *CodespaceSelector) (err err serverUrl string ) err = a.RunWithProgress("Starting JupyterLab on codespace", func() (err error) { - invoker, err = rpc.CreateInvoker(ctx, session) + invoker, err = rpc.CreateInvoker(ctx, fwd) if err != nil { return } @@ -76,8 +80,10 @@ func (a *App) Jupyter(ctx context.Context, selector *CodespaceSelector) (err err tunnelClosed := make(chan error, 1) go func() { - fwd := liveshare.NewPortForwarder(session, "jupyter", serverPort, true) - tunnelClosed <- fwd.ForwardToListener(ctx, listen) // always non-nil + opts := portforwarder.ForwardPortOpts{ + Port: serverPort, + } + tunnelClosed <- fwd.ForwardPortToListener(ctx, opts, listen) }() // Server URL contains an authentication token that must be preserved diff --git a/pkg/cmd/codespace/logs.go b/pkg/cmd/codespace/logs.go index 3ec950849..13d5ce185 100644 --- a/pkg/cmd/codespace/logs.go +++ b/pkg/cmd/codespace/logs.go @@ -5,8 +5,8 @@ import ( "fmt" "github.com/cli/cli/v2/internal/codespaces" + "github.com/cli/cli/v2/internal/codespaces/portforwarder" "github.com/cli/cli/v2/internal/codespaces/rpc" - "github.com/cli/cli/v2/pkg/liveshare" "github.com/spf13/cobra" ) @@ -42,11 +42,15 @@ func (a *App) Logs(ctx context.Context, selector *CodespaceSelector, follow bool return err } - session, err := startLiveShareSession(ctx, codespace, a, false, "") + codespaceConnection, err := codespaces.GetCodespaceConnection(ctx, a, a.apiClient, codespace) if err != nil { - return err + return fmt.Errorf("error connecting to codespace: %w", err) + } + + fwd, err := portforwarder.NewPortForwarder(ctx, codespaceConnection) + if err != nil { + return fmt.Errorf("failed to create port forwarder: %w", err) } - defer safeClose(session, &err) // Ensure local port is listening before client (getPostCreateOutput) connects. listen, localPort, err := codespaces.ListenTCP(0, false) @@ -57,7 +61,7 @@ func (a *App) Logs(ctx context.Context, selector *CodespaceSelector, follow bool remoteSSHServerPort, sshUser := 0, "" err = a.RunWithProgress("Fetching SSH Details", func() (err error) { - invoker, err := rpc.CreateInvoker(ctx, session) + invoker, err := rpc.CreateInvoker(ctx, fwd) if err != nil { return } @@ -85,8 +89,11 @@ func (a *App) Logs(ctx context.Context, selector *CodespaceSelector, follow bool tunnelClosed := make(chan error, 1) go func() { - fwd := liveshare.NewPortForwarder(session, "sshd", remoteSSHServerPort, false) - tunnelClosed <- fwd.ForwardToListener(ctx, listen) // error is non-nil + opts := portforwarder.ForwardPortOpts{ + Port: remoteSSHServerPort, + Internal: true, + } + tunnelClosed <- fwd.ForwardPortToListener(ctx, opts, listen) }() cmdDone := make(chan error, 1) diff --git a/pkg/cmd/codespace/ports.go b/pkg/cmd/codespace/ports.go index b7eb96537..38688cd3a 100644 --- a/pkg/cmd/codespace/ports.go +++ b/pkg/cmd/codespace/ports.go @@ -345,7 +345,11 @@ func (a *App) ForwardPorts(ctx context.Context, selector *CodespaceSelector, por if err != nil { return fmt.Errorf("failed to create port forwarder: %w", err) } - return fwd.ForwardAndConnectToPort(ctx, uint16(pair.remote), listen, false, false) + + opts := portforwarder.ForwardPortOpts{ + Port: pair.remote, + } + return fwd.ForwardPortToListener(ctx, opts, listen) }) } return group.Wait() // first error diff --git a/pkg/cmd/codespace/rebuild.go b/pkg/cmd/codespace/rebuild.go index 17e00670b..464c58502 100644 --- a/pkg/cmd/codespace/rebuild.go +++ b/pkg/cmd/codespace/rebuild.go @@ -4,7 +4,9 @@ import ( "context" "fmt" + "github.com/cli/cli/v2/internal/codespaces" "github.com/cli/cli/v2/internal/codespaces/api" + "github.com/cli/cli/v2/internal/codespaces/portforwarder" "github.com/cli/cli/v2/internal/codespaces/rpc" "github.com/spf13/cobra" ) @@ -49,13 +51,17 @@ func (a *App) Rebuild(ctx context.Context, selector *CodespaceSelector, full boo return nil } - session, err := startLiveShareSession(ctx, codespace, a, false, "") + codespaceConnection, err := codespaces.GetCodespaceConnection(ctx, a, a.apiClient, codespace) if err != nil { - return fmt.Errorf("starting Live Share session: %w", err) + return fmt.Errorf("error connecting to codespace: %w", err) } - defer safeClose(session, &err) - invoker, err := rpc.CreateInvoker(ctx, session) + fwd, err := portforwarder.NewPortForwarder(ctx, codespaceConnection) + if err != nil { + return fmt.Errorf("failed to create port forwarder: %w", err) + } + + invoker, err := rpc.CreateInvoker(ctx, fwd) if err != nil { return err } diff --git a/pkg/cmd/codespace/ssh.go b/pkg/cmd/codespace/ssh.go index 40eebcc63..42571b52e 100644 --- a/pkg/cmd/codespace/ssh.go +++ b/pkg/cmd/codespace/ssh.go @@ -6,7 +6,7 @@ import ( "context" "errors" "fmt" - "log" + "io" "os" "os/exec" "path" @@ -18,10 +18,10 @@ import ( "github.com/MakeNowJust/heredoc" "github.com/cli/cli/v2/internal/codespaces" "github.com/cli/cli/v2/internal/codespaces/api" + "github.com/cli/cli/v2/internal/codespaces/portforwarder" "github.com/cli/cli/v2/internal/codespaces/rpc" "github.com/cli/cli/v2/internal/config" "github.com/cli/cli/v2/pkg/cmdutil" - "github.com/cli/cli/v2/pkg/liveshare" "github.com/cli/cli/v2/pkg/ssh" "github.com/cli/safeexec" "github.com/spf13/cobra" @@ -144,6 +144,24 @@ func newSSHCmd(app *App) *cobra.Command { return sshCmd } +type combinedReadWriteHalfCloser struct { + io.ReadCloser + io.WriteCloser +} + +func (crwc *combinedReadWriteHalfCloser) Close() error { + werr := crwc.WriteCloser.Close() + rerr := crwc.ReadCloser.Close() + if werr != nil { + return werr + } + return rerr +} + +func (crwc *combinedReadWriteHalfCloser) CloseWrite() error { + return crwc.WriteCloser.Close() +} + // SSH opens an ssh session or runs an ssh command in a codespace. func (a *App) SSH(ctx context.Context, sshArgs []string, opts sshOptions) (err error) { // Ensure all child tasks (e.g. port forwarding) terminate before return. @@ -175,11 +193,15 @@ func (a *App) SSH(ctx context.Context, sshArgs []string, opts sshOptions) (err e return err } - session, err := startLiveShareSession(ctx, codespace, a, opts.debug, opts.debugFile) + codespaceConnection, err := codespaces.GetCodespaceConnection(ctx, a, a.apiClient, codespace) if err != nil { - return err + return fmt.Errorf("error connecting to codespace: %w", err) + } + + fwd, err := portforwarder.NewPortForwarder(ctx, codespaceConnection) + if err != nil { + return fmt.Errorf("failed to create port forwarder: %w", err) } - defer safeClose(session, &err) var ( invoker rpc.Invoker @@ -187,7 +209,7 @@ func (a *App) SSH(ctx context.Context, sshArgs []string, opts sshOptions) (err e sshUser string ) err = a.RunWithProgress("Fetching SSH Details", func() (err error) { - invoker, err = rpc.CreateInvoker(ctx, session) + invoker, err = rpc.CreateInvoker(ctx, fwd) if err != nil { return } @@ -203,9 +225,28 @@ func (a *App) SSH(ctx context.Context, sshArgs []string, opts sshOptions) (err e } if opts.stdio { - fwd := liveshare.NewPortForwarder(session, "sshd", remoteSSHServerPort, true) - stdio := liveshare.NewReadWriteHalfCloser(os.Stdin, os.Stdout) - err := fwd.Forward(ctx, stdio) // always non-nil + stdio := &combinedReadWriteHalfCloser{os.Stdin, os.Stdout} + opts := portforwarder.ForwardPortOpts{ + Port: remoteSSHServerPort, + Internal: true, + KeepAlive: true, + } + + // Forward the port + err = fwd.ForwardPort(ctx, opts) + if err != nil { + return fmt.Errorf("failed to forward port: %w", err) + } + + // Close the SSH connection when we're done + defer fwd.CloseSSHConnection() + + // Connect to the forwarded port + err = fwd.ConnectToForwardedPort(ctx, stdio, opts) + if err != nil { + return fmt.Errorf("failed to connect to forwarded port: %w", err) + } + return fmt.Errorf("tunnel closed: %w", err) } @@ -227,8 +268,12 @@ func (a *App) SSH(ctx context.Context, sshArgs []string, opts sshOptions) (err e tunnelClosed := make(chan error, 1) go func() { - fwd := liveshare.NewPortForwarder(session, "sshd", remoteSSHServerPort, true) - tunnelClosed <- fwd.ForwardToListener(ctx, listen) // always non-nil + opts := portforwarder.ForwardPortOpts{ + Port: remoteSSHServerPort, + Internal: true, + KeepAlive: true, + } + tunnelClosed <- fwd.ForwardPortToListener(ctx, opts, listen) }() shellClosed := make(chan error, 1) @@ -526,27 +571,36 @@ func (a *App) printOpenSSHConfig(ctx context.Context, opts sshOptions) (err erro result := sshResult{} defer wg.Done() - session, err := codespaces.ConnectToLiveshare(ctx, a, noopLogger(), a.apiClient, cs) + codespaceConnection, err := codespaces.GetCodespaceConnection(ctx, a, a.apiClient, cs) if err != nil { result.err = fmt.Errorf("error connecting to codespace: %w", err) - } else { - defer safeClose(session, &err) - - invoker, err := rpc.CreateInvoker(ctx, session) - if err != nil { - result.err = fmt.Errorf("error connecting to codespace: %w", err) - } else { - defer safeClose(invoker, &err) - - _, result.user, err = invoker.StartSSHServer(ctx) - if err != nil { - result.err = fmt.Errorf("error getting ssh server details: %w", err) - } else { - result.codespace = cs - } - } + sshUsers <- result + return } + fwd, err := portforwarder.NewPortForwarder(ctx, codespaceConnection) + if err != nil { + result.err = fmt.Errorf("failed to create port forwarder: %w", err) + sshUsers <- result + return + } + + invoker, err := rpc.CreateInvoker(ctx, fwd) + if err != nil { + result.err = fmt.Errorf("error connecting to codespace: %w", err) + sshUsers <- result + return + } + defer safeClose(invoker, &err) + + _, result.user, err = invoker.StartSSHServer(ctx) + if err != nil { + result.err = fmt.Errorf("error getting ssh server details: %w", err) + sshUsers <- result + return + } + + result.codespace = cs sshUsers <- result }() } @@ -722,43 +776,3 @@ func (a *App) Copy(ctx context.Context, args []string, opts cpOptions) error { } return a.SSH(ctx, nil, opts.sshOptions) } - -// fileLogger is a wrapper around an log.Logger configured to write -// to a file. It exports two additional methods to get the log file name -// and close the file handle when the operation is finished. -type fileLogger struct { - *log.Logger - - f *os.File -} - -// newFileLogger creates a new fileLogger. It returns an error if the file -// cannot be created. The file is created on the specified path, if the path -// is empty it is created in the temporary directory. -func newFileLogger(file string) (fl *fileLogger, err error) { - var f *os.File - if file == "" { - f, err = os.CreateTemp("", "") - if err != nil { - return nil, fmt.Errorf("failed to create tmp file: %w", err) - } - } else { - f, err = os.Create(file) - if err != nil { - return nil, err - } - } - - return &fileLogger{ - Logger: log.New(f, "", log.LstdFlags), - f: f, - }, nil -} - -func (fl *fileLogger) Name() string { - return fl.f.Name() -} - -func (fl *fileLogger) Close() error { - return fl.f.Close() -} diff --git a/pkg/liveshare/client.go b/pkg/liveshare/client.go deleted file mode 100644 index cbfa5d458..000000000 --- a/pkg/liveshare/client.go +++ /dev/null @@ -1,130 +0,0 @@ -// Package liveshare is a Go client library for the Visual Studio Live Share -// service, which provides collaborative, distributed editing and debugging. -// See https://docs.microsoft.com/en-us/visualstudio/liveshare for an overview. -// -// It provides the ability for a Go program to connect to a Live Share -// workspace (Connect), to expose a TCP port on a remote host -// (UpdateSharedVisibility), to start an SSH server listening on an -// exposed port (StartSSHServer), and to forward connections between -// the remote port and a local listening TCP port (ForwardToListener) -// or a local Go reader/writer (Forward). -package liveshare - -import ( - "context" - "crypto/tls" - "errors" - "fmt" - "net/url" - "strings" - - "github.com/opentracing/opentracing-go" -) - -type logger interface { - Println(v ...interface{}) - Printf(f string, v ...interface{}) -} - -// An Options specifies Live Share connection parameters. -type Options struct { - SessionID string - SessionToken string // token for SSH session - RelaySAS string - RelayEndpoint string - HostPublicKeys []string - Logger logger // required - TLSConfig *tls.Config // (optional) -} - -// uri returns a websocket URL for the specified options. -func (opts *Options) uri(action string) (string, error) { - if opts.SessionID == "" { - return "", errors.New("SessionID is required") - } - if opts.RelaySAS == "" { - return "", errors.New("RelaySAS is required") - } - if opts.RelayEndpoint == "" { - return "", errors.New("RelayEndpoint is required") - } - - sas := url.QueryEscape(opts.RelaySAS) - uri := opts.RelayEndpoint - - if strings.HasPrefix(uri, "http:") { - uri = strings.Replace(uri, "http:", "ws:", 1) - } else { - uri = strings.Replace(uri, "sb:", "wss:", -1) - } - - uri = strings.Replace(uri, ".net/", ".net:443/$hc/", 1) - uri = uri + "?sb-hc-action=" + action + "&sb-hc-token=" + sas - return uri, nil -} - -// Connect connects to a Live Share workspace specified by the -// options, and returns a session representing the connection. -// The caller must call the session's Close method to end the session. -func Connect(ctx context.Context, opts Options) (*Session, error) { - span, ctx := opentracing.StartSpanFromContext(ctx, "Connect") - defer span.Finish() - - uri, err := opts.uri("connect") - if err != nil { - return nil, err - } - - sock := newSocket(uri, opts.TLSConfig) - if err := sock.connect(ctx); err != nil { - return nil, fmt.Errorf("error connecting websocket: %w", err) - } - - if opts.SessionToken == "" { - return nil, errors.New("SessionToken is required") - } - ssh := newSSHSession(opts.SessionToken, opts.HostPublicKeys, sock) - if err := ssh.connect(ctx); err != nil { - return nil, fmt.Errorf("error connecting to ssh session: %w", err) - } - - rpc := newRPCClient(ssh) - rpc.connect(ctx) - - args := joinWorkspaceArgs{ - ID: opts.SessionID, - ConnectionMode: "local", - JoiningUserSessionToken: opts.SessionToken, - ClientCapabilities: clientCapabilities{ - IsNonInteractive: false, - }, - } - var result joinWorkspaceResult - if err := rpc.do(ctx, "workspace.joinWorkspace", &args, &result); err != nil { - return nil, fmt.Errorf("error joining Live Share workspace: %w", err) - } - - s := &Session{ - ssh: ssh, - rpc: rpc, - keepAliveReason: make(chan string, 1), - logger: opts.Logger, - } - - return s, nil -} - -type clientCapabilities struct { - IsNonInteractive bool `json:"isNonInteractive"` -} - -type joinWorkspaceArgs struct { - ID string `json:"id"` - ConnectionMode string `json:"connectionMode"` - JoiningUserSessionToken string `json:"joiningUserSessionToken"` - ClientCapabilities clientCapabilities `json:"clientCapabilities"` -} - -type joinWorkspaceResult struct { - SessionNumber int `json:"sessionNumber"` -} diff --git a/pkg/liveshare/client_test.go b/pkg/liveshare/client_test.go deleted file mode 100644 index 4ccc42729..000000000 --- a/pkg/liveshare/client_test.go +++ /dev/null @@ -1,73 +0,0 @@ -package liveshare - -import ( - "context" - "crypto/tls" - "encoding/json" - "errors" - "fmt" - "strings" - "testing" - - livesharetest "github.com/cli/cli/v2/pkg/liveshare/test" - "github.com/sourcegraph/jsonrpc2" -) - -func TestConnect(t *testing.T) { - opts := Options{ - SessionID: "session-id", - SessionToken: "session-token", - RelaySAS: "relay-sas", - HostPublicKeys: []string{livesharetest.SSHPublicKey}, - Logger: newMockLogger(), - } - joinWorkspace := func(conn *jsonrpc2.Conn, req *jsonrpc2.Request) (interface{}, error) { - var joinWorkspaceReq joinWorkspaceArgs - if err := json.Unmarshal(*req.Params, &joinWorkspaceReq); err != nil { - return nil, fmt.Errorf("error unmarshalling req: %w", err) - } - if joinWorkspaceReq.ID != opts.SessionID { - return nil, errors.New("connection session id does not match") - } - if joinWorkspaceReq.ConnectionMode != "local" { - return nil, errors.New("connection mode is not local") - } - if joinWorkspaceReq.JoiningUserSessionToken != opts.SessionToken { - return nil, errors.New("connection user token does not match") - } - if joinWorkspaceReq.ClientCapabilities.IsNonInteractive != false { - return nil, errors.New("non interactive is not false") - } - return joinWorkspaceResult{1}, nil - } - - server, err := livesharetest.NewServer( - livesharetest.WithPassword(opts.SessionToken), - livesharetest.WithService("workspace.joinWorkspace", joinWorkspace), - livesharetest.WithRelaySAS(opts.RelaySAS), - ) - if err != nil { - t.Errorf("error creating Live Share server: %v", err) - } - defer server.Close() - opts.RelayEndpoint = "sb" + strings.TrimPrefix(server.URL(), "https") - - ctx := context.Background() - - opts.TLSConfig = &tls.Config{InsecureSkipVerify: true} - - done := make(chan error) - go func() { - _, err := Connect(ctx, opts) // ignore session - done <- err - }() - - select { - case err := <-server.Err(): - t.Errorf("error from server: %v", err) - case err := <-done: - if err != nil { - t.Errorf("error from client: %v", err) - } - } -} diff --git a/pkg/liveshare/options_test.go b/pkg/liveshare/options_test.go deleted file mode 100644 index 830c59104..000000000 --- a/pkg/liveshare/options_test.go +++ /dev/null @@ -1,56 +0,0 @@ -package liveshare - -import ( - "context" - "testing" -) - -func TestBadOptions(t *testing.T) { - goodOptions := Options{ - SessionID: "sess-id", - SessionToken: "sess-token", - RelaySAS: "sas", - RelayEndpoint: "endpoint", - } - - opts := goodOptions - opts.SessionID = "" - checkBadOptions(t, opts) - - opts = goodOptions - opts.SessionToken = "" - checkBadOptions(t, opts) - - opts = goodOptions - opts.RelaySAS = "" - checkBadOptions(t, opts) - - opts = goodOptions - opts.RelayEndpoint = "" - checkBadOptions(t, opts) - - opts = Options{} - checkBadOptions(t, opts) -} - -func checkBadOptions(t *testing.T, opts Options) { - if _, err := Connect(context.Background(), opts); err == nil { - t.Errorf("Connect(%+v): no error", opts) - } -} - -func TestOptionsURI(t *testing.T) { - opts := Options{ - SessionID: "sess-id", - SessionToken: "sess-token", - RelaySAS: "sas", - RelayEndpoint: "sb://endpoint/.net/liveshare", - } - uri, err := opts.uri("connect") - if err != nil { - t.Fatal(err) - } - if uri != "wss://endpoint/.net:443/$hc/liveshare?sb-hc-action=connect&sb-hc-token=sas" { - t.Errorf("uri is not correct, got: '%v'", uri) - } -} diff --git a/pkg/liveshare/port_forwarder.go b/pkg/liveshare/port_forwarder.go deleted file mode 100644 index 5f2742209..000000000 --- a/pkg/liveshare/port_forwarder.go +++ /dev/null @@ -1,241 +0,0 @@ -package liveshare - -import ( - "context" - "fmt" - "io" - "net" - - "github.com/opentracing/opentracing-go" - "golang.org/x/crypto/ssh" -) - -type portForwardingSession interface { - StartSharing(context.Context, string, int) (ChannelID, error) - OpenStreamingChannel(context.Context, ChannelID) (ssh.Channel, error) - KeepAlive(string) -} - -type ReadWriteHalfCloser interface { - io.ReadWriteCloser - CloseWrite() error -} - -type combinedReadWriteHalfCloser struct { - io.ReadCloser - io.WriteCloser -} - -func NewReadWriteHalfCloser(reader io.ReadCloser, writer io.WriteCloser) ReadWriteHalfCloser { - return &combinedReadWriteHalfCloser{reader, writer} -} - -func (crwc *combinedReadWriteHalfCloser) Close() error { - werr := crwc.WriteCloser.Close() - rerr := crwc.ReadCloser.Close() - if werr != nil { - return werr - } - return rerr -} - -func (crwc *combinedReadWriteHalfCloser) CloseWrite() error { - return crwc.WriteCloser.Close() -} - -// A PortForwarder forwards TCP traffic over a Live Share session from a port on a remote -// container to a local destination such as a network port or Go reader/writer. -type PortForwarder struct { - session portForwardingSession - name string - remotePort int - keepAlive bool -} - -// NewPortForwarder returns a new PortForwarder for the specified -// remote port and Live Share session. The name describes the purpose -// of the remote port or service. The keepAlive flag indicates whether -// the session should be kept alive with port forwarding traffic. -func NewPortForwarder(session portForwardingSession, name string, remotePort int, keepAlive bool) *PortForwarder { - return &PortForwarder{ - session: session, - name: name, - remotePort: remotePort, - keepAlive: keepAlive, - } -} - -// ForwardToListener forwards traffic between the container's remote -// port and a local port, which must already be listening for -// connections. (Accepting a listener rather than a port number avoids -// races against other processes opening ports, and against a client -// connecting to the socket prematurely.) -// -// ForwardToListener accepts and handles connections on the local port -// until it encounters the first error, which may include context -// cancellation. Its error result is always non-nil. The caller is -// responsible for closing the listening port. -func (fwd *PortForwarder) ForwardToListener(ctx context.Context, listen *net.TCPListener) (err error) { - id, err := fwd.shareRemotePort(ctx) - if err != nil { - return err - } - - errc := make(chan error, 1) - sendError := func(err error) { - // Use non-blocking send, to avoid goroutines getting - // stuck in case of concurrent or sequential errors. - select { - case errc <- err: - default: - } - } - go func() { - for { - conn, err := listen.AcceptTCP() - if err != nil { - sendError(err) - return - } - - go func() { - if err := fwd.handleConnection(ctx, id, conn); err != nil { - sendError(err) - } - }() - } - }() - - return awaitError(ctx, errc) -} - -// Forward forwards traffic between the container's remote port and -// the specified read/write stream. On return, the stream is closed. -func (fwd *PortForwarder) Forward(ctx context.Context, conn ReadWriteHalfCloser) error { - id, err := fwd.shareRemotePort(ctx) - if err != nil { - conn.Close() - return err - } - - // Create buffered channel so that send doesn't get stuck after context cancellation. - errc := make(chan error, 1) - go func() { - errc <- fwd.handleConnection(ctx, id, conn) - }() - return awaitError(ctx, errc) -} - -func (fwd *PortForwarder) shareRemotePort(ctx context.Context) (ChannelID, error) { - id, err := fwd.session.StartSharing(ctx, fwd.name, fwd.remotePort) - if err != nil { - err = fmt.Errorf("failed to share remote port %d: %w", fwd.remotePort, err) - } - - return id, err -} - -func awaitError(ctx context.Context, errc <-chan error) error { - select { - case err := <-errc: - return err - case <-ctx.Done(): - return ctx.Err() // canceled - } -} - -type trafficMonitorSession interface { - KeepAlive(string) -} - -// trafficMonitor implements io.Reader. It keeps the session alive by notifying -// it of the traffic type during Read operations. -type trafficMonitor struct { - reader io.Reader - - session trafficMonitorSession - trafficType string -} - -// newTrafficMonitor returns a new trafficMonitor for the specified -// session and traffic type. It wraps the provided io.Reader with its own -// Read method. -func newTrafficMonitor(reader io.Reader, session trafficMonitorSession, trafficType string) *trafficMonitor { - return &trafficMonitor{reader, session, trafficType} -} - -func (t *trafficMonitor) Read(p []byte) (n int, err error) { - t.session.KeepAlive(t.trafficType) - return t.reader.Read(p) -} - -// handleConnection handles forwarding for a single accepted connection, then closes it. -func (fwd *PortForwarder) handleConnection(ctx context.Context, id ChannelID, conn ReadWriteHalfCloser) (err error) { - span, ctx := opentracing.StartSpanFromContext(ctx, "PortForwarder.handleConnection") - defer span.Finish() - - defer safeClose(conn, &err) - - channel, err := fwd.session.OpenStreamingChannel(ctx, id) - if err != nil { - return fmt.Errorf("error opening streaming channel for new connection: %w", err) - } - // Ideally we would call safeClose again, but (*ssh.channel).Close - // appears to have a bug that causes it return io.EOF spuriously - // if its peer closed first; see github.com/golang/go/issues/38115. - defer func() { - closeErr := channel.Close() - if err == nil && closeErr != io.EOF { - err = closeErr - } - }() - - // bi-directional copy of data. - errs := make(chan error, 2) - copyConn := func(w ReadWriteHalfCloser, r io.Reader) { - _, err := io.Copy(w, r) - errs <- err - - // Ignore errors here, we call the full Close() later and catch that error - _ = w.CloseWrite() - } - - var ( - channelReader io.Reader = channel - connReader io.Reader = conn - ) - - // If the forwader has been configured to keep the session alive - // it will monitor the I/O and notify the session of the traffic. - if fwd.keepAlive { - channelReader = newTrafficMonitor(channelReader, fwd.session, "output") - connReader = newTrafficMonitor(connReader, fwd.session, "input") - } - - go copyConn(conn, channelReader) - go copyConn(channel, connReader) - - // Wait until context is cancelled or both copies are done. - // Discard errors from io.Copy; they should not cause (e.g.) ForwardToListener to fail. - // TODO: how can we proxy errors from Copy so that each peer can distinguish an error from a short file? - for i := 0; ; { - select { - case <-ctx.Done(): - return ctx.Err() - case <-errs: - i++ - if i == 2 { - return nil - } - } - } -} - -// safeClose reports the error (to *err) from closing the stream only -// if no other error was previously reported. -func safeClose(closer io.Closer, err *error) { - closeErr := closer.Close() - if *err == nil { - *err = closeErr - } -} diff --git a/pkg/liveshare/port_forwarder_test.go b/pkg/liveshare/port_forwarder_test.go deleted file mode 100644 index 61acde368..000000000 --- a/pkg/liveshare/port_forwarder_test.go +++ /dev/null @@ -1,153 +0,0 @@ -package liveshare - -import ( - "bytes" - "context" - "errors" - "fmt" - "io" - "net" - "os" - "testing" - "time" - - livesharetest "github.com/cli/cli/v2/pkg/liveshare/test" - "github.com/sourcegraph/jsonrpc2" -) - -func TestNewPortForwarder(t *testing.T) { - testServer, session, err := makeMockSession() - if err != nil { - t.Errorf("create mock client: %v", err) - } - defer testServer.Close() - pf := NewPortForwarder(session, "ssh", 80, false) - if pf == nil { - t.Error("port forwarder is nil") - } -} - -type portUpdateNotification struct { - PortNotification - conn *jsonrpc2.Conn -} - -func TestPortForwarderStart(t *testing.T) { - if os.Getenv("GITHUB_ACTIONS") == "true" { - t.Skip("fails intermittently in CI: https://github.com/cli/cli/issues/5338") - } - - streamName, streamCondition := "stream-name", "stream-condition" - const port = 8000 - sendNotification := make(chan portUpdateNotification) - serverSharing := func(conn *jsonrpc2.Conn, req *jsonrpc2.Request) (interface{}, error) { - // Send the PortNotification that will be awaited on in session.StartSharing - sendNotification <- portUpdateNotification{ - PortNotification: PortNotification{ - Port: port, - ChangeKind: PortChangeKindStart, - }, - conn: conn, - } - return Port{StreamName: streamName, StreamCondition: streamCondition}, nil - } - getStream := func(conn *jsonrpc2.Conn, req *jsonrpc2.Request) (interface{}, error) { - return "stream-id", nil - } - - stream := bytes.NewBufferString("stream-data") - testServer, session, err := makeMockSession( - livesharetest.WithService("serverSharing.startSharing", serverSharing), - livesharetest.WithService("streamManager.getStream", getStream), - livesharetest.WithStream("stream-id", stream), - ) - if err != nil { - t.Errorf("create mock session: %v", err) - } - defer testServer.Close() - - listen, err := net.Listen("tcp", "127.0.0.1:8000") - if err != nil { - t.Fatal(err) - } - defer listen.Close() - tcpListener, ok := listen.(*net.TCPListener) - if !ok { - t.Fatal("net.Listen did not return a TCPListener") - } - - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - go func() { - notif := <-sendNotification - _, _ = notif.conn.DispatchCall(context.Background(), "serverSharing.sharingSucceeded", notif) - }() - - done := make(chan error, 2) - go func() { - done <- NewPortForwarder(session, "ssh", port, false).ForwardToListener(ctx, tcpListener) - }() - - go func() { - var conn net.Conn - - // We retry DialTimeout in a loop to deal with a race in PortForwarder startup. - for tries := 0; conn == nil && tries < 2; tries++ { - conn, err = net.DialTimeout("tcp", ":8000", 2*time.Second) - if conn == nil { - time.Sleep(1 * time.Second) - } - } - if conn == nil { - done <- errors.New("failed to connect to forwarded port") - return - } - b := make([]byte, len("stream-data")) - if _, err := conn.Read(b); err != nil && err != io.EOF { - done <- fmt.Errorf("reading stream: %w", err) - return - } - if string(b) != "stream-data" { - done <- fmt.Errorf("stream data is not expected value, got: %s", string(b)) - return - } - if _, err := conn.Write([]byte("new-data")); err != nil { - done <- fmt.Errorf("writing to stream: %w", err) - return - } - done <- nil - }() - - select { - case err := <-testServer.Err(): - t.Errorf("error from server: %v", err) - case err := <-done: - if err != nil { - t.Errorf("error from client: %v", err) - } - } -} - -func TestPortForwarderTrafficMonitor(t *testing.T) { - buf := bytes.NewBufferString("some-input") - session := &Session{keepAliveReason: make(chan string, 1)} - trafficType := "io" - - tm := newTrafficMonitor(buf, session, trafficType) - l := len(buf.Bytes()) - - bb := make([]byte, l) - n, err := tm.Read(bb) - if err != nil { - t.Errorf("failed to read from traffic monitor: %v", err) - } - if n != l { - t.Errorf("expected to read %d bytes, got %d", l, n) - } - - keepAliveReason := <-session.keepAliveReason - if keepAliveReason != trafficType { - t.Errorf("expected keep alive reason to be %s, got %s", trafficType, keepAliveReason) - } -} diff --git a/pkg/liveshare/ports.go b/pkg/liveshare/ports.go deleted file mode 100644 index b39a4f630..000000000 --- a/pkg/liveshare/ports.go +++ /dev/null @@ -1,101 +0,0 @@ -package liveshare - -import ( - "context" - "encoding/json" - "fmt" - - "github.com/sourcegraph/jsonrpc2" -) - -// Port describes a port exposed by the container. -type Port struct { - SourcePort int `json:"sourcePort"` - DestinationPort int `json:"destinationPort"` - SessionName string `json:"sessionName"` - StreamName string `json:"streamName"` - StreamCondition string `json:"streamCondition"` - BrowseURL string `json:"browseUrl"` - IsPublic bool `json:"isPublic"` - IsTCPServerConnectionEstablished bool `json:"isTCPServerConnectionEstablished"` - HasTLSHandshakePassed bool `json:"hasTLSHandshakePassed"` - Privacy string `json:"privacy"` -} - -type PortChangeKind string - -const ( - PortChangeKindStart PortChangeKind = "start" - PortChangeKindUpdate PortChangeKind = "update" -) - -type PortNotification struct { - Success bool // Helps us disambiguate between the SharingSucceeded/SharingFailed events - // The following are properties included in the SharingSucceeded/SharingFailed events sent by the server sharing service in the Codespace - Port int `json:"port"` - ChangeKind PortChangeKind `json:"changeKind"` - ErrorDetail string `json:"errorDetail"` - StatusCode int `json:"statusCode"` -} - -// WaitForPortNotification waits for a port notification to be received. It returns the notification -// or an error if the notification is not received before the context is cancelled or it fails -// to parse the notification. -func (s *Session) WaitForPortNotification(ctx context.Context, port int, notifType PortChangeKind) (*PortNotification, error) { - // We use 1-buffered channels and non-blocking sends so that - // no goroutine gets stuck. - notificationCh := make(chan *PortNotification, 1) - errCh := make(chan error, 1) - - h := func(success bool) func(*jsonrpc2.Conn, *jsonrpc2.Request) { - return func(conn *jsonrpc2.Conn, req *jsonrpc2.Request) { - notification := new(PortNotification) - if err := json.Unmarshal(*req.Params, ¬ification); err != nil { - select { - case errCh <- fmt.Errorf("error unmarshalling notification: %w", err): - default: - } - return - } - notification.Success = success - if notification.Port == port && notification.ChangeKind == notifType { - select { - case notificationCh <- notification: - default: - } - } - } - } - deregisterSuccess := s.registerRequestHandler("serverSharing.sharingSucceeded", h(true)) - deregisterFailure := s.registerRequestHandler("serverSharing.sharingFailed", h(false)) - defer deregisterSuccess() - defer deregisterFailure() - - for { - select { - case <-ctx.Done(): - return nil, ctx.Err() - case err := <-errCh: - return nil, err - case notification := <-notificationCh: - return notification, nil - } - } -} - -// GetSharedServers returns a description of each container port -// shared by a prior call to StartSharing by some client. -func (s *Session) GetSharedServers(ctx context.Context) ([]*Port, error) { - var response []*Port - if err := s.rpc.do(ctx, "serverSharing.getSharedServers", []string{}, &response); err != nil { - return nil, err - } - - return response, nil -} - -// UpdateSharedServerPrivacy controls port permissions and visibility scopes for who can access its URLs -// in the browser. -func (s *Session) UpdateSharedServerPrivacy(ctx context.Context, port int, visibility string) error { - return s.rpc.do(ctx, "serverSharing.updateSharedServerPrivacy", []interface{}{port, visibility}, nil) -} diff --git a/pkg/liveshare/rpc.go b/pkg/liveshare/rpc.go deleted file mode 100644 index 639f538c9..000000000 --- a/pkg/liveshare/rpc.go +++ /dev/null @@ -1,87 +0,0 @@ -package liveshare - -import ( - "context" - "fmt" - "io" - "sync" - "time" - - "github.com/opentracing/opentracing-go" - "github.com/sourcegraph/jsonrpc2" -) - -type rpcClient struct { - *jsonrpc2.Conn - conn io.ReadWriteCloser - handlersMu sync.Mutex - handlers map[string][]*handlerWrapper -} - -func newRPCClient(conn io.ReadWriteCloser) *rpcClient { - return &rpcClient{conn: conn, handlers: make(map[string][]*handlerWrapper)} -} - -func (r *rpcClient) connect(ctx context.Context) { - stream := jsonrpc2.NewBufferedStream(r.conn, jsonrpc2.VSCodeObjectCodec{}) - r.Conn = jsonrpc2.NewConn(ctx, stream, r) -} - -func (r *rpcClient) do(ctx context.Context, method string, args, result interface{}) error { - span, ctx := opentracing.StartSpanFromContext(ctx, method) - defer span.Finish() - - waiter, err := r.Conn.DispatchCall(ctx, method, args) - if err != nil { - return fmt.Errorf("error dispatching %q call: %w", method, err) - } - - // timeout for waiter in case a connection cannot be made - waitCtx, cancel := context.WithTimeout(ctx, 2*time.Minute) - defer cancel() - - return waiter.Wait(waitCtx, result) -} - -type handler func(conn *jsonrpc2.Conn, req *jsonrpc2.Request) - -type handlerWrapper struct { - fn handler -} - -func (r *rpcClient) register(requestType string, fn handler) func() { - r.handlersMu.Lock() - defer r.handlersMu.Unlock() - - h := &handlerWrapper{fn: fn} - r.handlers[requestType] = append(r.handlers[requestType], h) - - return func() { - r.deregister(requestType, h) - } -} - -func (r *rpcClient) deregister(requestType string, handler *handlerWrapper) { - r.handlersMu.Lock() - defer r.handlersMu.Unlock() - - handlers := r.handlers[requestType] - for i, h := range handlers { - if h == handler { - // Swap h with last element and pop. - last := len(handlers) - 1 - handlers[i], handlers[last] = handlers[last], nil - r.handlers[requestType] = handlers[:last] - break - } - } -} - -func (r *rpcClient) Handle(ctx context.Context, conn *jsonrpc2.Conn, req *jsonrpc2.Request) { - r.handlersMu.Lock() - defer r.handlersMu.Unlock() - - for _, handler := range r.handlers[req.Method] { - go handler.fn(conn, req) - } -} diff --git a/pkg/liveshare/session.go b/pkg/liveshare/session.go deleted file mode 100644 index e5ec86703..000000000 --- a/pkg/liveshare/session.go +++ /dev/null @@ -1,133 +0,0 @@ -package liveshare - -import ( - "context" - "fmt" - - "github.com/opentracing/opentracing-go" - "golang.org/x/crypto/ssh" - "golang.org/x/sync/errgroup" -) - -// A ChannelID is an identifier for an exposed port on a remote -// container that may be used to open an SSH channel to it. -type ChannelID struct { - name, condition string -} - -// Interface to allow the mocking of the liveshare session -type LiveshareSession interface { - Close() error - GetSharedServers(context.Context) ([]*Port, error) - KeepAlive(string) - OpenStreamingChannel(context.Context, ChannelID) (ssh.Channel, error) - StartSharing(context.Context, string, int) (ChannelID, error) - GetKeepAliveReason() string -} - -// A Session represents the session between a connected Live Share client and server. -type Session struct { - ssh *sshSession - rpc *rpcClient - - keepAliveReason chan string - logger logger -} - -// Close should be called by users to clean up RPC and SSH resources whenever the session -// is no longer active. -func (s *Session) Close() error { - // Closing the RPC conn closes the underlying stream (SSH) - // So we only need to close once - if err := s.rpc.Close(); err != nil { - s.ssh.Close() // close SSH and ignore error - return fmt.Errorf("error while closing Live Share session: %w", err) - } - - return nil -} - -// Fetches the keep alive reason from the channel and returns it. -func (s *Session) GetKeepAliveReason() string { - return <-s.keepAliveReason -} - -// registerRequestHandler registers a handler for the given request type with the RPC -// server and returns a callback function to deregister the handler -func (s *Session) registerRequestHandler(requestType string, h handler) func() { - return s.rpc.register(requestType, h) -} - -// KeepAlive accepts a reason that is retained if there is no active reason -// to send to the server. -func (s *Session) KeepAlive(reason string) { - select { - case s.keepAliveReason <- reason: - default: - // there is already an active keep alive reason - // so we can ignore this one - } -} - -// StartSharing tells the Live Share host to start sharing the specified port from the container. -// The sessionName describes the purpose of the remote port or service. -// It returns an identifier that can be used to open an SSH channel to the remote port. -func (s *Session) StartSharing(ctx context.Context, sessionName string, port int) (ChannelID, error) { - args := []interface{}{port, sessionName, fmt.Sprintf("http://localhost:%d", port)} - g, ctx := errgroup.WithContext(ctx) - - g.Go(func() error { - startNotification, err := s.WaitForPortNotification(ctx, port, PortChangeKindStart) - if err != nil { - return fmt.Errorf("error while waiting for port notification: %w", err) - - } - if !startNotification.Success { - return fmt.Errorf("error while starting port sharing: %s", startNotification.ErrorDetail) - } - return nil // success - }) - - var response Port - g.Go(func() error { - return s.rpc.do(ctx, "serverSharing.startSharing", args, &response) - }) - - if err := g.Wait(); err != nil { - return ChannelID{}, err - } - - return ChannelID{response.StreamName, response.StreamCondition}, nil -} - -func (s *Session) OpenStreamingChannel(ctx context.Context, id ChannelID) (ssh.Channel, error) { - type getStreamArgs struct { - StreamName string `json:"streamName"` - Condition string `json:"condition"` - } - args := getStreamArgs{ - StreamName: id.name, - Condition: id.condition, - } - var streamID string - if err := s.rpc.do(ctx, "streamManager.getStream", args, &streamID); err != nil { - return nil, fmt.Errorf("error getting stream id: %w", err) - } - - span, ctx := opentracing.StartSpanFromContext(ctx, "Session.OpenChannel+SendRequest") - defer span.Finish() - _ = ctx // ctx is not currently used - - channel, reqs, err := s.ssh.conn.OpenChannel("session", nil) - if err != nil { - return nil, fmt.Errorf("error opening ssh channel for transport: %w", err) - } - go ssh.DiscardRequests(reqs) - - requestType := fmt.Sprintf("stream-transport-%s", streamID) - if _, err = channel.SendRequest(requestType, true, nil); err != nil { - return nil, fmt.Errorf("error sending channel request: %w", err) - } - - return channel, nil -} diff --git a/pkg/liveshare/session_test.go b/pkg/liveshare/session_test.go deleted file mode 100644 index ab7cf18bd..000000000 --- a/pkg/liveshare/session_test.go +++ /dev/null @@ -1,278 +0,0 @@ -package liveshare - -import ( - "bytes" - "context" - "crypto/tls" - "encoding/json" - "errors" - "fmt" - "strings" - "sync" - "testing" - - livesharetest "github.com/cli/cli/v2/pkg/liveshare/test" - "github.com/sourcegraph/jsonrpc2" -) - -func makeMockSession(opts ...livesharetest.ServerOption) (*livesharetest.Server, *Session, error) { - joinWorkspace := func(conn *jsonrpc2.Conn, req *jsonrpc2.Request) (interface{}, error) { - return joinWorkspaceResult{1}, nil - } - const sessionToken = "session-token" - opts = append( - opts, - livesharetest.WithPassword(sessionToken), - livesharetest.WithService("workspace.joinWorkspace", joinWorkspace), - ) - testServer, err := livesharetest.NewServer(opts...) - if err != nil { - return nil, nil, fmt.Errorf("error creating server: %w", err) - } - - session, err := Connect(context.Background(), Options{ - SessionID: "session-id", - SessionToken: sessionToken, - RelayEndpoint: "sb" + strings.TrimPrefix(testServer.URL(), "https"), - RelaySAS: "relay-sas", - HostPublicKeys: []string{livesharetest.SSHPublicKey}, - TLSConfig: &tls.Config{InsecureSkipVerify: true}, - Logger: newMockLogger(), - }) - if err != nil { - return nil, nil, fmt.Errorf("error connecting to Live Share: %w", err) - } - return testServer, session, nil -} - -func TestServerStartSharing(t *testing.T) { - serverPort, serverProtocol := 2222, "sshd" - sendNotification := make(chan portUpdateNotification) - startSharing := func(conn *jsonrpc2.Conn, req *jsonrpc2.Request) (interface{}, error) { - var args []interface{} - if err := json.Unmarshal(*req.Params, &args); err != nil { - return nil, fmt.Errorf("error unmarshalling request: %w", err) - } - if len(args) < 3 { - return nil, errors.New("not enough arguments to start sharing") - } - port, ok := args[0].(float64) - if !ok { - return nil, errors.New("port argument is not an int") - } - if port != float64(serverPort) { - return nil, errors.New("port does not match serverPort") - } - if protocol, ok := args[1].(string); !ok { - return nil, errors.New("protocol argument is not a string") - } else if protocol != serverProtocol { - return nil, errors.New("protocol does not match serverProtocol") - } - if browseURL, ok := args[2].(string); !ok { - return nil, errors.New("browse url is not a string") - } else if browseURL != fmt.Sprintf("http://localhost:%d", serverPort) { - return nil, errors.New("browseURL does not match expected") - } - sendNotification <- portUpdateNotification{ - PortNotification: PortNotification{ - Port: int(port), - ChangeKind: PortChangeKindStart, - }, - conn: conn, - } - return Port{StreamName: "stream-name", StreamCondition: "stream-condition"}, nil - } - testServer, session, err := makeMockSession( - livesharetest.WithService("serverSharing.startSharing", startSharing), - ) - defer testServer.Close() //nolint:staticcheck // httptest.Server does not return errors on Close() - - if err != nil { - t.Errorf("error creating mock session: %v", err) - } - ctx := context.Background() - - go func() { - notif := <-sendNotification - _, _ = notif.conn.DispatchCall(context.Background(), "serverSharing.sharingSucceeded", notif) - }() - - done := make(chan error) - go func() { - streamID, err := session.StartSharing(ctx, serverProtocol, serverPort) - if err != nil { - done <- fmt.Errorf("error sharing server: %w", err) - } - if streamID.name == "" || streamID.condition == "" { - done <- errors.New("stream name or condition is blank") - } - done <- nil - }() - - select { - case err := <-testServer.Err(): - t.Errorf("error from server: %v", err) - case err := <-done: - if err != nil { - t.Errorf("error from client: %v", err) - } - } -} - -func TestServerGetSharedServers(t *testing.T) { - sharedServer := Port{ - SourcePort: 2222, - StreamName: "stream-name", - StreamCondition: "stream-condition", - } - getSharedServers := func(conn *jsonrpc2.Conn, req *jsonrpc2.Request) (interface{}, error) { - return []*Port{&sharedServer}, nil - } - testServer, session, err := makeMockSession( - livesharetest.WithService("serverSharing.getSharedServers", getSharedServers), - ) - if err != nil { - t.Errorf("error creating mock session: %v", err) - } - defer testServer.Close() - ctx := context.Background() - done := make(chan error) - go func() { - ports, err := session.GetSharedServers(ctx) - if err != nil { - done <- fmt.Errorf("error getting shared servers: %w", err) - } - if len(ports) < 1 { - done <- errors.New("not enough ports returned") - } - if ports[0].SourcePort != sharedServer.SourcePort { - done <- errors.New("source port does not match") - } - if ports[0].StreamName != sharedServer.StreamName { - done <- errors.New("stream name does not match") - } - if ports[0].StreamCondition != sharedServer.StreamCondition { - done <- errors.New("stream condiion does not match") - } - done <- nil - }() - - select { - case err := <-testServer.Err(): - t.Errorf("error from server: %v", err) - case err := <-done: - if err != nil { - t.Errorf("error from client: %v", err) - } - } -} - -func TestServerUpdateSharedServerPrivacy(t *testing.T) { - updateSharedVisibility := func(conn *jsonrpc2.Conn, rpcReq *jsonrpc2.Request) (interface{}, error) { - var req []interface{} - if err := json.Unmarshal(*rpcReq.Params, &req); err != nil { - return nil, fmt.Errorf("unmarshal req: %w", err) - } - if len(req) < 2 { - return nil, errors.New("request arguments is less than 2") - } - if port, ok := req[0].(float64); ok { - if port != 80.0 { - return nil, errors.New("port param is not expected value") - } - } else { - return nil, errors.New("port param is not a float64") - } - if privacy, ok := req[1].(string); ok { - if privacy != "public" { - return nil, fmt.Errorf("expected privacy param to be public but got %q", privacy) - } - } else { - return nil, fmt.Errorf("expected privacy param to be a bool but go %T", req[1]) - } - return nil, nil - } - testServer, session, err := makeMockSession( - livesharetest.WithService("serverSharing.updateSharedServerPrivacy", updateSharedVisibility), - ) - if err != nil { - t.Errorf("creating mock session: %v", err) - } - defer testServer.Close() - ctx := context.Background() - done := make(chan error) - go func() { - done <- session.UpdateSharedServerPrivacy(ctx, 80, "public") - }() - select { - case err := <-testServer.Err(): - t.Errorf("error from server: %v", err) - case err := <-done: - if err != nil { - t.Errorf("error from client: %v", err) - } - } -} - -func TestInvalidHostKey(t *testing.T) { - joinWorkspace := func(conn *jsonrpc2.Conn, req *jsonrpc2.Request) (interface{}, error) { - return joinWorkspaceResult{1}, nil - } - const sessionToken = "session-token" - opts := []livesharetest.ServerOption{ - livesharetest.WithPassword(sessionToken), - livesharetest.WithService("workspace.joinWorkspace", joinWorkspace), - } - testServer, err := livesharetest.NewServer(opts...) - if err != nil { - t.Errorf("error creating server: %v", err) - } - _, err = Connect(context.Background(), Options{ - SessionID: "session-id", - SessionToken: sessionToken, - RelayEndpoint: "sb" + strings.TrimPrefix(testServer.URL(), "https"), - RelaySAS: "relay-sas", - HostPublicKeys: []string{}, - TLSConfig: &tls.Config{InsecureSkipVerify: true}, - }) - if err == nil { - t.Error("expected invalid host key error, got: nil") - } -} - -func TestKeepAliveNonBlocking(t *testing.T) { - session := &Session{keepAliveReason: make(chan string, 1)} - for i := 0; i < 2; i++ { - session.KeepAlive("io") - } - - // if KeepAlive blocks, we'll never reach this and timeout the test - // timing out -} - -type mockLogger struct { - sync.Mutex - buf *bytes.Buffer -} - -func newMockLogger() *mockLogger { - return &mockLogger{buf: new(bytes.Buffer)} -} - -func (m *mockLogger) Printf(format string, v ...interface{}) { - m.Lock() - defer m.Unlock() - m.buf.WriteString(fmt.Sprintf(format, v...)) -} - -func (m *mockLogger) Println(v ...interface{}) { - m.Lock() - defer m.Unlock() - m.buf.WriteString(fmt.Sprintln(v...)) -} - -func (m *mockLogger) String() string { - m.Lock() - defer m.Unlock() - return m.buf.String() -} diff --git a/pkg/liveshare/socket.go b/pkg/liveshare/socket.go deleted file mode 100644 index f66436f65..000000000 --- a/pkg/liveshare/socket.go +++ /dev/null @@ -1,100 +0,0 @@ -package liveshare - -import ( - "context" - "crypto/tls" - "io" - "net" - "net/http" - "time" - - "github.com/gorilla/websocket" -) - -type socket struct { - addr string - tlsConfig *tls.Config - - conn *websocket.Conn - reader io.Reader -} - -func newSocket(uri string, tlsConfig *tls.Config) *socket { - return &socket{addr: uri, tlsConfig: tlsConfig} -} - -func (s *socket) connect(ctx context.Context) error { - dialer := websocket.Dialer{ - Proxy: http.ProxyFromEnvironment, - HandshakeTimeout: 45 * time.Second, - TLSClientConfig: s.tlsConfig, - } - ws, _, err := dialer.Dial(s.addr, nil) - if err != nil { - return err - } - s.conn = ws - return nil -} - -func (s *socket) Read(b []byte) (int, error) { - if s.reader == nil { - _, reader, err := s.conn.NextReader() - if err != nil { - return 0, err - } - - s.reader = reader - } - - bytesRead, err := s.reader.Read(b) - if err != nil { - s.reader = nil - - if err == io.EOF { - err = nil - } - } - - return bytesRead, err -} - -func (s *socket) Write(b []byte) (int, error) { - nextWriter, err := s.conn.NextWriter(websocket.BinaryMessage) - if err != nil { - return 0, err - } - - bytesWritten, err := nextWriter.Write(b) - nextWriter.Close() - - return bytesWritten, err -} - -func (s *socket) Close() error { - return s.conn.Close() -} - -func (s *socket) LocalAddr() net.Addr { - return s.conn.LocalAddr() -} - -func (s *socket) RemoteAddr() net.Addr { - return s.conn.RemoteAddr() -} - -func (s *socket) SetDeadline(t time.Time) error { - if err := s.SetReadDeadline(t); err != nil { - return err - } - - return s.SetWriteDeadline(t) -} - -func (s *socket) SetReadDeadline(t time.Time) error { - return s.conn.SetReadDeadline(t) -} - -func (s *socket) SetWriteDeadline(t time.Time) error { - return s.conn.SetWriteDeadline(t) -} diff --git a/pkg/liveshare/ssh.go b/pkg/liveshare/ssh.go deleted file mode 100644 index e7de9055a..000000000 --- a/pkg/liveshare/ssh.go +++ /dev/null @@ -1,79 +0,0 @@ -package liveshare - -import ( - "context" - "encoding/base64" - "errors" - "fmt" - "io" - "net" - "time" - - "golang.org/x/crypto/ssh" -) - -type sshSession struct { - *ssh.Session - token string - hostPublicKeys []string - socket net.Conn - conn ssh.Conn - reader io.Reader - writer io.Writer -} - -func newSSHSession(token string, hostPublicKeys []string, socket net.Conn) *sshSession { - return &sshSession{token: token, hostPublicKeys: hostPublicKeys, socket: socket} -} - -func (s *sshSession) connect(ctx context.Context) error { - clientConfig := ssh.ClientConfig{ - User: "", - Auth: []ssh.AuthMethod{ - ssh.Password(s.token), - }, - HostKeyAlgorithms: []string{"rsa-sha2-512", "rsa-sha2-256"}, - HostKeyCallback: func(hostname string, addr net.Addr, key ssh.PublicKey) error { - encodedKey := base64.StdEncoding.EncodeToString(key.Marshal()) - for _, hpk := range s.hostPublicKeys { - if encodedKey == hpk { - return nil // we found a match for expected public key, safely return - } - } - return errors.New("invalid host public key") - }, - Timeout: 10 * time.Second, - } - - sshClientConn, chans, reqs, err := ssh.NewClientConn(s.socket, "", &clientConfig) - if err != nil { - return fmt.Errorf("error creating ssh client connection: %w", err) - } - s.conn = sshClientConn - - sshClient := ssh.NewClient(sshClientConn, chans, reqs) - s.Session, err = sshClient.NewSession() - if err != nil { - return fmt.Errorf("error creating ssh client session: %w", err) - } - - s.reader, err = s.Session.StdoutPipe() - if err != nil { - return fmt.Errorf("error creating ssh session reader: %w", err) - } - - s.writer, err = s.Session.StdinPipe() - if err != nil { - return fmt.Errorf("error creating ssh session writer: %w", err) - } - - return nil -} - -func (s *sshSession) Read(p []byte) (n int, err error) { - return s.reader.Read(p) -} - -func (s *sshSession) Write(p []byte) (n int, err error) { - return s.writer.Write(p) -} diff --git a/pkg/liveshare/test/server.go b/pkg/liveshare/test/server.go deleted file mode 100644 index 0b4f6a7ba..000000000 --- a/pkg/liveshare/test/server.go +++ /dev/null @@ -1,349 +0,0 @@ -package livesharetest - -import ( - "context" - "errors" - "fmt" - "io" - "net/http" - "net/http/httptest" - "strings" - "sync" - - "github.com/gorilla/websocket" - "github.com/sourcegraph/jsonrpc2" - "golang.org/x/crypto/ssh" -) - -const sshPrivateKey = `-----BEGIN RSA PRIVATE KEY----- -MIICXgIBAAKBgQC6VU6XsMaTot9ogsGcJ+juvJOmDvvCZmgJRTRwKkW0u2BLz4yV -rCzQcxaY4kaIuR80Y+1f0BLnZgh4pTREDR0T+p8hUsDSHim1ttKI8rK0hRtJ2qhY -lR4qt7P51rPA4KFA9z9gDjTwQLbDq21QMC4+n4d8CL3xRVGtlUAMM3Kl3wIDAQAB -AoGBAI8UemkYoSM06gBCh5D1RHQt8eKNltzL7g9QSNfoXeZOC7+q+/TiZPcbqLp0 -5lyOalu8b8Ym7J0rSE377Ypj13LyHMXS63e4wMiXv3qOl3GDhMLpypnJ8PwqR2b8 -IijL2jrpQfLu6IYqlteA+7e9aEexJa1RRwxYIyq6pG1IYpbhAkEA9nKgtj3Z6ZDC -46IdqYzuUM9ZQdcw4AFr407+lub7tbWe5pYmaq3cT725IwLw081OAmnWJYFDMa/n -IPl9YcZSPQJBAMGOMbPs/YPkQAsgNdIUlFtK3o41OrrwJuTRTvv0DsbqDV0LKOiC -t8oAQQvjisH6Ew5OOhFyIFXtvZfzQMJppksCQQDWFd+cUICTUEise/Duj9maY3Uz -J99ySGnTbZTlu8PfJuXhg3/d3ihrMPG6A1z3cPqaSBxaOj8H07mhQHn1zNU1AkEA -hkl+SGPrO793g4CUdq2ahIA8SpO5rIsDoQtq7jlUq0MlhGFCv5Y5pydn+bSjx5MV -933kocf5kUSBntPBIWElYwJAZTm5ghu0JtSE6t3km0iuj7NGAQSdb6mD8+O7C3CP -FU3vi+4HlBysaT6IZ/HG+/dBsr4gYp4LGuS7DbaLuYw/uw== ------END RSA PRIVATE KEY-----` - -const SSHPublicKey = `AAAAB3NzaC1yc2EAAAADAQABAAAAgQC6VU6XsMaTot9ogsGcJ+juvJOmDvvCZmgJRTRwKkW0u2BLz4yVrCzQcxaY4kaIuR80Y+1f0BLnZgh4pTREDR0T+p8hUsDSHim1ttKI8rK0hRtJ2qhYlR4qt7P51rPA4KFA9z9gDjTwQLbDq21QMC4+n4d8CL3xRVGtlUAMM3Kl3w==` - -// Server represents a LiveShare relay host server. -type Server struct { - password string - services map[string]RPCHandleFunc - relaySAS string - streams map[string]io.ReadWriter - sshConfig *ssh.ServerConfig - httptestServer *httptest.Server - errCh chan error - nonSecure bool -} - -// NewServer creates a new Server. ServerOptions can be passed to configure -// the SSH password, backing service, secrets and more. -func NewServer(opts ...ServerOption) (*Server, error) { - server := new(Server) - - for _, o := range opts { - if err := o(server); err != nil { - return nil, err - } - } - - server.sshConfig = &ssh.ServerConfig{ - PasswordCallback: sshPasswordCallback(server.password), - } - privateKey, err := ssh.ParsePrivateKey([]byte(sshPrivateKey)) - if err != nil { - return nil, fmt.Errorf("error parsing key: %w", err) - } - server.sshConfig.AddHostKey(privateKey) - - server.errCh = make(chan error, 1) - - if server.nonSecure { - server.httptestServer = httptest.NewServer(http.HandlerFunc(makeConnection(server))) - } else { - server.httptestServer = httptest.NewTLSServer(http.HandlerFunc(makeConnection(server))) - } - return server, nil -} - -// ServerOption is used to configure the Server. -type ServerOption func(*Server) error - -// WithPassword configures the Server password for SSH. -func WithPassword(password string) ServerOption { - return func(s *Server) error { - s.password = password - return nil - } -} - -// WithNonSecure configures the Server as non-secure. -func WithNonSecure() ServerOption { - return func(s *Server) error { - s.nonSecure = true - return nil - } -} - -// WithService accepts a mock RPC service for the Server to invoke. -func WithService(serviceName string, handler RPCHandleFunc) ServerOption { - return func(s *Server) error { - if s.services == nil { - s.services = make(map[string]RPCHandleFunc) - } - - s.services[serviceName] = handler - return nil - } -} - -// WithRelaySAS configures the relay SAS configuration key. -func WithRelaySAS(sas string) ServerOption { - return func(s *Server) error { - s.relaySAS = sas - return nil - } -} - -// WithStream allows you to specify a mock data stream for the server. -func WithStream(name string, stream io.ReadWriter) ServerOption { - return func(s *Server) error { - if s.streams == nil { - s.streams = make(map[string]io.ReadWriter) - } - s.streams[name] = stream - return nil - } -} - -func sshPasswordCallback(serverPassword string) func(ssh.ConnMetadata, []byte) (*ssh.Permissions, error) { - return func(conn ssh.ConnMetadata, password []byte) (*ssh.Permissions, error) { - if string(password) == serverPassword { - return nil, nil - } - return nil, errors.New("password rejected") - } -} - -// Close closes the underlying httptest Server. -func (s *Server) Close() { - s.httptestServer.Close() -} - -// URL returns the httptest Server url. -func (s *Server) URL() string { - return s.httptestServer.URL -} - -func (s *Server) Err() <-chan error { - return s.errCh -} - -var upgrader = websocket.Upgrader{} - -func makeConnection(server *Server) http.HandlerFunc { - return func(w http.ResponseWriter, req *http.Request) { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - if server.relaySAS != "" { - // validate the sas key - sasParam := req.URL.Query().Get("sb-hc-token") - if sasParam != server.relaySAS { - sendError(server.errCh, errors.New("error validating sas")) - return - } - } - c, err := upgrader.Upgrade(w, req, nil) - if err != nil { - sendError(server.errCh, fmt.Errorf("error upgrading connection: %w", err)) - return - } - defer func() { - if err := c.Close(); err != nil { - sendError(server.errCh, err) - } - }() - - socketConn := newSocketConn(c) - _, chans, reqs, err := ssh.NewServerConn(socketConn, server.sshConfig) - if err != nil { - sendError(server.errCh, fmt.Errorf("error creating new ssh conn: %w", err)) - return - } - go ssh.DiscardRequests(reqs) - - if err := handleChannels(ctx, server, chans); err != nil { - sendError(server.errCh, err) - } - } -} - -// sendError does a non-blocking send of the error to the err channel. -func sendError(errc chan<- error, err error) { - select { - case errc <- err: - default: - // channel is blocked with a previous error, so we ignore - // this current error - } -} - -// awaitError waits for the context to finish and returns its error (if any). -// It also waits for an err to come through the err channel. -func awaitError(ctx context.Context, errc <-chan error) error { - select { - case <-ctx.Done(): - return ctx.Err() - case err := <-errc: - return err - } -} - -// handleChannels services the sshChannels channel. For each SSH channel received -// it creates a go routine to service the channel's requests. It returns on the first -// error encountered. -func handleChannels(ctx context.Context, server *Server, sshChannels <-chan ssh.NewChannel) error { - errc := make(chan error, 1) - go func() { - for sshCh := range sshChannels { - ch, reqs, err := sshCh.Accept() - if err != nil { - sendError(errc, fmt.Errorf("failed to accept channel: %w", err)) - return - } - - go func() { - if err := handleRequests(ctx, server, ch, reqs); err != nil { - sendError(errc, fmt.Errorf("failed to handle requests: %w", err)) - } - }() - - handleChannel(server, ch) - } - }() - return awaitError(ctx, errc) -} - -// handleRequests services the SSH channel requests channel. It replies to requests and -// when stream transport requests are encountered, creates a go routine to create a -// bi-directional data stream between the channel and server stream. It returns on the first error -// encountered. -func handleRequests(ctx context.Context, server *Server, channel ssh.Channel, reqs <-chan *ssh.Request) error { - errc := make(chan error, 1) - go func() { - for req := range reqs { - r := req - if r.WantReply { - if err := r.Reply(true, nil); err != nil { - sendError(errc, fmt.Errorf("error replying to channel request: %w", err)) - return - } - } - - if strings.HasPrefix(r.Type, "stream-transport") { - go func() { - if err := forwardStream(ctx, server, r.Type, channel); err != nil { - sendError(errc, fmt.Errorf("failed to forward stream: %w", err)) - } - }() - } - } - }() - - return awaitError(ctx, errc) -} - -// concurrentStream is a concurrency safe io.ReadWriter. -type concurrentStream struct { - sync.RWMutex - stream io.ReadWriter -} - -func newConcurrentStream(rw io.ReadWriter) *concurrentStream { - return &concurrentStream{stream: rw} -} - -func (cs *concurrentStream) Read(b []byte) (int, error) { - cs.RLock() - defer cs.RUnlock() - return cs.stream.Read(b) -} - -func (cs *concurrentStream) Write(b []byte) (int, error) { - cs.Lock() - defer cs.Unlock() - return cs.stream.Write(b) -} - -// forwardStream does a bi-directional copy of the stream <-> with the SSH channel. The io.Copy -// runs until an error is encountered. -func forwardStream(ctx context.Context, server *Server, streamName string, channel ssh.Channel) (err error) { - simpleStreamName := strings.TrimPrefix(streamName, "stream-transport-") - stream, found := server.streams[simpleStreamName] - if !found { - return fmt.Errorf("stream '%s' not found", simpleStreamName) - } - defer func() { - if closeErr := channel.Close(); err == nil && closeErr != io.EOF { - err = closeErr - } - }() - - errc := make(chan error, 2) - copy := func(dst io.Writer, src io.Reader) { - if _, err := io.Copy(dst, src); err != nil { - errc <- err - } - } - - csStream := newConcurrentStream(stream) - go copy(csStream, channel) - go copy(channel, csStream) - - return awaitError(ctx, errc) -} - -func handleChannel(server *Server, channel ssh.Channel) { - stream := jsonrpc2.NewBufferedStream(channel, jsonrpc2.VSCodeObjectCodec{}) - jsonrpc2.NewConn(context.Background(), stream, newRPCHandler(server)) -} - -type RPCHandleFunc func(conn *jsonrpc2.Conn, req *jsonrpc2.Request) (interface{}, error) - -type rpcHandler struct { - server *Server -} - -func newRPCHandler(server *Server) *rpcHandler { - return &rpcHandler{server} -} - -// Handle satisfies the jsonrpc2 pkg handler interface. It tries to find a mocked -// RPC service method and if found, it invokes the handler and replies to the request. -func (r *rpcHandler) Handle(ctx context.Context, conn *jsonrpc2.Conn, req *jsonrpc2.Request) { - handler, found := r.server.services[req.Method] - if !found { - sendError(r.server.errCh, fmt.Errorf("RPC Method: '%s' not serviced", req.Method)) - return - } - - result, err := handler(conn, req) - if err != nil { - sendError(r.server.errCh, fmt.Errorf("error handling: '%s': %w", req.Method, err)) - return - } - - if err := conn.Reply(ctx, req.ID, result); err != nil { - sendError(r.server.errCh, fmt.Errorf("error replying: %w", err)) - } -} diff --git a/pkg/liveshare/test/socket.go b/pkg/liveshare/test/socket.go deleted file mode 100644 index 00cd64a1b..000000000 --- a/pkg/liveshare/test/socket.go +++ /dev/null @@ -1,77 +0,0 @@ -package livesharetest - -import ( - "fmt" - "io" - "sync" - "time" - - "github.com/gorilla/websocket" -) - -type socketConn struct { - *websocket.Conn - - reader io.Reader - writeMutex sync.Mutex - readMutex sync.Mutex -} - -func newSocketConn(conn *websocket.Conn) *socketConn { - return &socketConn{Conn: conn} -} - -func (s *socketConn) Read(b []byte) (int, error) { - s.readMutex.Lock() - defer s.readMutex.Unlock() - - if s.reader == nil { - msgType, r, err := s.Conn.NextReader() - if err != nil { - return 0, fmt.Errorf("error getting next reader: %w", err) - } - if msgType != websocket.BinaryMessage { - return 0, fmt.Errorf("invalid message type") - } - s.reader = r - } - - bytesRead, err := s.reader.Read(b) - if err != nil { - s.reader = nil - - if err == io.EOF { - err = nil - } - } - - return bytesRead, err -} - -func (s *socketConn) Write(b []byte) (int, error) { - s.writeMutex.Lock() - defer s.writeMutex.Unlock() - - w, err := s.Conn.NextWriter(websocket.BinaryMessage) - if err != nil { - return 0, fmt.Errorf("error getting next writer: %w", err) - } - - n, err := w.Write(b) - if err != nil { - return 0, fmt.Errorf("error writing: %w", err) - } - - if err := w.Close(); err != nil { - return 0, fmt.Errorf("error closing writer: %w", err) - } - - return n, nil -} - -func (s *socketConn) SetDeadline(deadline time.Time) error { - if err := s.Conn.SetReadDeadline(deadline); err != nil { - return err - } - return s.Conn.SetWriteDeadline(deadline) -}