diff --git a/pkg/cmd/codespace/common.go b/pkg/cmd/codespace/common.go index 5e5bc7f64..402fcd7ea 100644 --- a/pkg/cmd/codespace/common.go +++ b/pkg/cmd/codespace/common.go @@ -19,6 +19,7 @@ import ( "github.com/cli/cli/v2/pkg/iostreams" "github.com/cli/cli/v2/pkg/liveshare" "github.com/spf13/cobra" + "golang.org/x/crypto/ssh" "golang.org/x/term" ) @@ -60,8 +61,18 @@ func (a *App) StopProgressIndicator() { a.io.StopProgressIndicator() } +type liveshareSession interface { + Close() error + GetSharedServers(context.Context) ([]*liveshare.Port, error) + KeepAlive(string) + OpenStreamingChannel(context.Context, liveshare.ChannelID) (ssh.Channel, error) + StartJupyterServer(context.Context) (int, string, error) + StartSharing(context.Context, string, int) (liveshare.ChannelID, error) + StartSSHServer(context.Context) (int, string, error) +} + // Connects to a codespace using Live Share and returns that session -func startLiveShareSession(ctx context.Context, codespace *api.Codespace, a *App, debug bool, debugFile string) (session *liveshare.Session, err error) { +func startLiveShareSession(ctx context.Context, codespace *api.Codespace, a *App, debug bool, debugFile string) (session liveshareSession, err error) { // While connecting, ensure in the background that the user has keys installed. // That lets us report a more useful error message if they don't. authkeys := make(chan error, 1) diff --git a/pkg/liveshare/client.go b/pkg/liveshare/client.go index b67e1b1cf..33a8e7691 100644 --- a/pkg/liveshare/client.go +++ b/pkg/liveshare/client.go @@ -20,7 +20,6 @@ import ( "time" "github.com/opentracing/opentracing-go" - "golang.org/x/crypto/ssh" ) type logger interface { @@ -136,41 +135,3 @@ type joinWorkspaceArgs struct { type joinWorkspaceResult struct { SessionNumber int `json:"sessionNumber"` } - -// A channelID is an identifier for an exposed port on a remote -// container that may be used to open an SSH channel to it. -type channelID struct { - name, condition string -} - -func (s *Session) openStreamingChannel(ctx context.Context, id channelID) (ssh.Channel, error) { - type getStreamArgs struct { - StreamName string `json:"streamName"` - Condition string `json:"condition"` - } - args := getStreamArgs{ - StreamName: id.name, - Condition: id.condition, - } - var streamID string - if err := s.rpc.do(ctx, "streamManager.getStream", args, &streamID); err != nil { - return nil, fmt.Errorf("error getting stream id: %w", err) - } - - span, ctx := opentracing.StartSpanFromContext(ctx, "Session.OpenChannel+SendRequest") - defer span.Finish() - _ = ctx // ctx is not currently used - - channel, reqs, err := s.ssh.conn.OpenChannel("session", nil) - if err != nil { - return nil, fmt.Errorf("error opening ssh channel for transport: %w", err) - } - go ssh.DiscardRequests(reqs) - - requestType := fmt.Sprintf("stream-transport-%s", streamID) - if _, err = channel.SendRequest(requestType, true, nil); err != nil { - return nil, fmt.Errorf("error sending channel request: %w", err) - } - - return channel, nil -} diff --git a/pkg/liveshare/port_forwarder.go b/pkg/liveshare/port_forwarder.go index ba2c7ff40..f042eeaea 100644 --- a/pkg/liveshare/port_forwarder.go +++ b/pkg/liveshare/port_forwarder.go @@ -7,12 +7,19 @@ import ( "net" "github.com/opentracing/opentracing-go" + "golang.org/x/crypto/ssh" ) +type portForwardingSession interface { + StartSharing(context.Context, string, int) (ChannelID, error) + OpenStreamingChannel(context.Context, ChannelID) (ssh.Channel, error) + KeepAlive(string) +} + // A PortForwarder forwards TCP traffic over a Live Share session from a port on a remote // container to a local destination such as a network port or Go reader/writer. type PortForwarder struct { - session *Session + session portForwardingSession name string remotePort int keepAlive bool @@ -22,7 +29,7 @@ type PortForwarder struct { // remote port and Live Share session. The name describes the purpose // of the remote port or service. The keepAlive flag indicates whether // the session should be kept alive with port forwarding traffic. -func NewPortForwarder(session *Session, name string, remotePort int, keepAlive bool) *PortForwarder { +func NewPortForwarder(session portForwardingSession, name string, remotePort int, keepAlive bool) *PortForwarder { return &PortForwarder{ session: session, name: name, @@ -92,8 +99,8 @@ func (fwd *PortForwarder) Forward(ctx context.Context, conn io.ReadWriteCloser) return awaitError(ctx, errc) } -func (fwd *PortForwarder) shareRemotePort(ctx context.Context) (channelID, error) { - id, err := fwd.session.startSharing(ctx, fwd.name, fwd.remotePort) +func (fwd *PortForwarder) shareRemotePort(ctx context.Context) (ChannelID, error) { + id, err := fwd.session.StartSharing(ctx, fwd.name, fwd.remotePort) if err != nil { err = fmt.Errorf("failed to share remote port %d: %w", fwd.remotePort, err) } @@ -110,35 +117,39 @@ func awaitError(ctx context.Context, errc <-chan error) error { } } +type trafficMonitorSession interface { + KeepAlive(string) +} + // trafficMonitor implements io.Reader. It keeps the session alive by notifying // it of the traffic type during Read operations. type trafficMonitor struct { reader io.Reader - session *Session + session trafficMonitorSession trafficType string } // newTrafficMonitor returns a new trafficMonitor for the specified // session and traffic type. It wraps the provided io.Reader with its own // Read method. -func newTrafficMonitor(reader io.Reader, session *Session, trafficType string) *trafficMonitor { +func newTrafficMonitor(reader io.Reader, session trafficMonitorSession, trafficType string) *trafficMonitor { return &trafficMonitor{reader, session, trafficType} } func (t *trafficMonitor) Read(p []byte) (n int, err error) { - t.session.keepAlive(t.trafficType) + t.session.KeepAlive(t.trafficType) return t.reader.Read(p) } // handleConnection handles forwarding for a single accepted connection, then closes it. -func (fwd *PortForwarder) handleConnection(ctx context.Context, id channelID, conn io.ReadWriteCloser) (err error) { +func (fwd *PortForwarder) handleConnection(ctx context.Context, id ChannelID, conn io.ReadWriteCloser) (err error) { span, ctx := opentracing.StartSpanFromContext(ctx, "PortForwarder.handleConnection") defer span.Finish() defer safeClose(conn, &err) - channel, err := fwd.session.openStreamingChannel(ctx, id) + channel, err := fwd.session.OpenStreamingChannel(ctx, id) if err != nil { return fmt.Errorf("error opening streaming channel for new connection: %w", err) } diff --git a/pkg/liveshare/ports.go b/pkg/liveshare/ports.go index 851fe3190..6cef1584b 100644 --- a/pkg/liveshare/ports.go +++ b/pkg/liveshare/ports.go @@ -6,7 +6,6 @@ import ( "fmt" "github.com/sourcegraph/jsonrpc2" - "golang.org/x/sync/errgroup" ) // Port describes a port exposed by the container. @@ -30,37 +29,6 @@ const ( PortChangeKindUpdate PortChangeKind = "update" ) -// startSharing tells the Live Share host to start sharing the specified port from the container. -// The sessionName describes the purpose of the remote port or service. -// It returns an identifier that can be used to open an SSH channel to the remote port. -func (s *Session) startSharing(ctx context.Context, sessionName string, port int) (channelID, error) { - args := []interface{}{port, sessionName, fmt.Sprintf("http://localhost:%d", port)} - g, ctx := errgroup.WithContext(ctx) - - g.Go(func() error { - startNotification, err := s.WaitForPortNotification(ctx, port, PortChangeKindStart) - if err != nil { - return fmt.Errorf("error while waiting for port notification: %w", err) - - } - if !startNotification.Success { - return fmt.Errorf("error while starting port sharing: %s", startNotification.ErrorDetail) - } - return nil // success - }) - - var response Port - g.Go(func() error { - return s.rpc.do(ctx, "serverSharing.startSharing", args, &response) - }) - - if err := g.Wait(); err != nil { - return channelID{}, err - } - - return channelID{response.StreamName, response.StreamCondition}, nil -} - type PortNotification struct { Success bool // Helps us disambiguate between the SharingSucceeded/SharingFailed events // The following are properties included in the SharingSucceeded/SharingFailed events sent by the server sharing service in the Codespace diff --git a/pkg/liveshare/session.go b/pkg/liveshare/session.go index e2648c1c8..33531bffd 100644 --- a/pkg/liveshare/session.go +++ b/pkg/liveshare/session.go @@ -5,8 +5,18 @@ import ( "fmt" "strconv" "time" + + "github.com/opentracing/opentracing-go" + "golang.org/x/crypto/ssh" + "golang.org/x/sync/errgroup" ) +// A ChannelID is an identifier for an exposed port on a remote +// container that may be used to open an SSH channel to it. +type ChannelID struct { + name, condition string +} + // A Session represents the session between a connected Live Share client and server. type Session struct { ssh *sshSession @@ -91,7 +101,7 @@ func (s *Session) StartJupyterServer(ctx context.Context) (int, string, error) { // heartbeat runs until context cancellation, periodically checking whether there is a // reason to keep the connection alive, and if so, notifying the Live Share host to do so. // Heartbeat ensures it does not send more than one request every "interval" to ratelimit -// how many keepAlives we send at a time. +// how many KeepAlives we send at a time. func (s *Session) heartbeat(ctx context.Context, interval time.Duration) { ticker := time.NewTicker(interval) defer ticker.Stop() @@ -118,9 +128,9 @@ func (s *Session) notifyHostOfActivity(ctx context.Context, activity string) err return s.rpc.do(ctx, "ICodespaceHostService.notifyCodespaceOfClientActivity", params, nil) } -// keepAlive accepts a reason that is retained if there is no active reason +// KeepAlive accepts a reason that is retained if there is no active reason // to send to the server. -func (s *Session) keepAlive(reason string) { +func (s *Session) KeepAlive(reason string) { select { case s.keepAliveReason <- reason: default: @@ -128,3 +138,66 @@ func (s *Session) keepAlive(reason string) { // so we can ignore this one } } + +// StartSharing tells the Live Share host to start sharing the specified port from the container. +// The sessionName describes the purpose of the remote port or service. +// It returns an identifier that can be used to open an SSH channel to the remote port. +func (s *Session) StartSharing(ctx context.Context, sessionName string, port int) (ChannelID, error) { + args := []interface{}{port, sessionName, fmt.Sprintf("http://localhost:%d", port)} + g, ctx := errgroup.WithContext(ctx) + + g.Go(func() error { + startNotification, err := s.WaitForPortNotification(ctx, port, PortChangeKindStart) + if err != nil { + return fmt.Errorf("error while waiting for port notification: %w", err) + + } + if !startNotification.Success { + return fmt.Errorf("error while starting port sharing: %s", startNotification.ErrorDetail) + } + return nil // success + }) + + var response Port + g.Go(func() error { + return s.rpc.do(ctx, "serverSharing.startSharing", args, &response) + }) + + if err := g.Wait(); err != nil { + return ChannelID{}, err + } + + return ChannelID{response.StreamName, response.StreamCondition}, nil +} + +func (s *Session) OpenStreamingChannel(ctx context.Context, id ChannelID) (ssh.Channel, error) { + type getStreamArgs struct { + StreamName string `json:"streamName"` + Condition string `json:"condition"` + } + args := getStreamArgs{ + StreamName: id.name, + Condition: id.condition, + } + var streamID string + if err := s.rpc.do(ctx, "streamManager.getStream", args, &streamID); err != nil { + return nil, fmt.Errorf("error getting stream id: %w", err) + } + + span, ctx := opentracing.StartSpanFromContext(ctx, "Session.OpenChannel+SendRequest") + defer span.Finish() + _ = ctx // ctx is not currently used + + channel, reqs, err := s.ssh.conn.OpenChannel("session", nil) + if err != nil { + return nil, fmt.Errorf("error opening ssh channel for transport: %w", err) + } + go ssh.DiscardRequests(reqs) + + requestType := fmt.Sprintf("stream-transport-%s", streamID) + if _, err = channel.SendRequest(requestType, true, nil); err != nil { + return nil, fmt.Errorf("error sending channel request: %w", err) + } + + return channel, nil +} diff --git a/pkg/liveshare/session_test.go b/pkg/liveshare/session_test.go index c06a4f06a..cfe8ccd11 100644 --- a/pkg/liveshare/session_test.go +++ b/pkg/liveshare/session_test.go @@ -103,7 +103,7 @@ func TestServerStartSharing(t *testing.T) { done := make(chan error) go func() { - streamID, err := session.startSharing(ctx, serverProtocol, serverPort) + streamID, err := session.StartSharing(ctx, serverProtocol, serverPort) if err != nil { done <- fmt.Errorf("error sharing server: %w", err) } @@ -247,10 +247,10 @@ func TestInvalidHostKey(t *testing.T) { func TestKeepAliveNonBlocking(t *testing.T) { session := &Session{keepAliveReason: make(chan string, 1)} for i := 0; i < 2; i++ { - session.keepAlive("io") + session.KeepAlive("io") } - // if keepAlive blocks, we'll never reach this and timeout the test + // if KeepAlive blocks, we'll never reach this and timeout the test // timing out } @@ -367,10 +367,10 @@ func TestSessionHeartbeat(t *testing.T) { go session.heartbeat(ctx, 50*time.Millisecond) go func() { - session.keepAlive("input") + session.KeepAlive("input") wg.Wait() wg.Add(1) - session.keepAlive("input") + session.KeepAlive("input") wg.Wait() done <- struct{}{} }() @@ -380,7 +380,7 @@ func TestSessionHeartbeat(t *testing.T) { t.Errorf("error from server: %v", err) case <-done: activityCount := strings.Count(logger.String(), "input") - // by design keepAlive can drop requests, and therefore there is zero guarantee + // by design KeepAlive can drop requests, and therefore there is zero guarantee // that we actually get two requests if the network happened to be slow (rarely) // during testing. if activityCount != 1 && activityCount != 2 {