diff --git a/internal/codespaces/codespaces.go b/internal/codespaces/codespaces.go index c9831a961..55ec2a25a 100644 --- a/internal/codespaces/codespaces.go +++ b/internal/codespaces/codespaces.go @@ -15,6 +15,13 @@ type logger interface { Println(v ...interface{}) (int, error) } +// TODO(josebalius): clean this up once we standardrize +// logging for codespaces +type liveshareLogger interface { + Println(v ...interface{}) + Printf(f string, v ...interface{}) +} + func connectionReady(codespace *api.Codespace) bool { return codespace.Connection.SessionID != "" && codespace.Connection.SessionToken != "" && @@ -30,7 +37,7 @@ type apiClient interface { // ConnectToLiveshare waits for a Codespace to become running, // and connects to it using a Live Share session. -func ConnectToLiveshare(ctx context.Context, log logger, apiClient apiClient, codespace *api.Codespace) (*liveshare.Session, error) { +func ConnectToLiveshare(ctx context.Context, log logger, sessionLogger liveshareLogger, apiClient apiClient, codespace *api.Codespace) (*liveshare.Session, error) { var startedCodespace bool if codespace.Environment.State != api.CodespaceEnvironmentStateAvailable { startedCodespace = true @@ -67,10 +74,12 @@ func ConnectToLiveshare(ctx context.Context, log logger, apiClient apiClient, co log.Println("Connecting to your codespace...") return liveshare.Connect(ctx, liveshare.Options{ + ClientName: "gh", SessionID: codespace.Connection.SessionID, SessionToken: codespace.Connection.SessionToken, RelaySAS: codespace.Connection.RelaySAS, RelayEndpoint: codespace.Connection.RelayEndpoint, HostPublicKeys: codespace.Connection.HostPublicKeys, + Logger: sessionLogger, }) } diff --git a/internal/codespaces/states.go b/internal/codespaces/states.go index e8f197410..00170d3ec 100644 --- a/internal/codespaces/states.go +++ b/internal/codespaces/states.go @@ -5,6 +5,8 @@ import ( "context" "encoding/json" "fmt" + "io/ioutil" + "log" "net" "strings" "time" @@ -36,8 +38,10 @@ type PostCreateState struct { // PollPostCreateStates watches for state changes in a codespace, // and calls the supplied poller for each batch of state changes. // It runs until it encounters an error, including cancellation of the context. -func PollPostCreateStates(ctx context.Context, log logger, apiClient apiClient, codespace *api.Codespace, poller func([]PostCreateState)) (err error) { - session, err := ConnectToLiveshare(ctx, log, apiClient, codespace) +func PollPostCreateStates(ctx context.Context, logger logger, apiClient apiClient, codespace *api.Codespace, poller func([]PostCreateState)) (err error) { + noopLogger := log.New(ioutil.Discard, "", 0) + + session, err := ConnectToLiveshare(ctx, logger, noopLogger, apiClient, codespace) if err != nil { return fmt.Errorf("connect to Live Share: %w", err) } @@ -54,7 +58,7 @@ func PollPostCreateStates(ctx context.Context, log logger, apiClient apiClient, } localPort := listen.Addr().(*net.TCPAddr).Port - log.Println("Fetching SSH Details...") + logger.Println("Fetching SSH Details...") remoteSSHServerPort, sshUser, err := session.StartSSHServer(ctx) if err != nil { return fmt.Errorf("error getting ssh server details: %w", err) @@ -62,7 +66,7 @@ func PollPostCreateStates(ctx context.Context, log logger, apiClient apiClient, tunnelClosed := make(chan error, 1) // buffered to avoid sender stuckness go func() { - fwd := liveshare.NewPortForwarder(session, "sshd", remoteSSHServerPort) + fwd := liveshare.NewPortForwarder(session, "sshd", remoteSSHServerPort, false) tunnelClosed <- fwd.ForwardToListener(ctx, listen) // error is non-nil }() diff --git a/pkg/cmd/codespace/common.go b/pkg/cmd/codespace/common.go index b93b9a8d6..721a5c992 100644 --- a/pkg/cmd/codespace/common.go +++ b/pkg/cmd/codespace/common.go @@ -7,6 +7,8 @@ import ( "errors" "fmt" "io" + "io/ioutil" + "log" "os" "sort" "strings" @@ -211,6 +213,10 @@ func noArgsConstraint(cmd *cobra.Command, args []string) error { return nil } +func noopLogger() *log.Logger { + return log.New(ioutil.Discard, "", 0) +} + type codespace struct { *api.Codespace } diff --git a/pkg/cmd/codespace/delete_test.go b/pkg/cmd/codespace/delete_test.go index 0839c9ef4..0251e4529 100644 --- a/pkg/cmd/codespace/delete_test.go +++ b/pkg/cmd/codespace/delete_test.go @@ -44,6 +44,7 @@ func TestDelete(t *testing.T) { }, }, wantDeleted: []string{"hubot-robawt-abc"}, + wantStdout: "Codespace deleted.\n", }, { name: "by repo", @@ -65,6 +66,7 @@ func TestDelete(t *testing.T) { }, }, wantDeleted: []string{"monalisa-spoonknife-123", "monalisa-spoonknife-c4f3"}, + wantStdout: "Codespaces deleted.\n", }, { name: "unused", @@ -87,6 +89,7 @@ func TestDelete(t *testing.T) { }, }, wantDeleted: []string{"hubot-robawt-abc", "monalisa-spoonknife-c4f3"}, + wantStdout: "Codespaces deleted.\n", }, { name: "deletion failed", @@ -148,6 +151,7 @@ func TestDelete(t *testing.T) { "Codespace hubot-robawt-abc has unsaved changes. OK to delete?": true, }, wantDeleted: []string{"hubot-robawt-abc", "monalisa-spoonknife-c4f3"}, + wantStdout: "Codespaces deleted.\n", }, } for _, tt := range tests { diff --git a/pkg/cmd/codespace/logs.go b/pkg/cmd/codespace/logs.go index 9bfc4a967..317059918 100644 --- a/pkg/cmd/codespace/logs.go +++ b/pkg/cmd/codespace/logs.go @@ -51,7 +51,7 @@ func (a *App) Logs(ctx context.Context, codespaceName string, follow bool) (err return fmt.Errorf("get or choose codespace: %w", err) } - session, err := codespaces.ConnectToLiveshare(ctx, a.logger, a.apiClient, codespace) + session, err := codespaces.ConnectToLiveshare(ctx, a.logger, noopLogger(), a.apiClient, codespace) if err != nil { return fmt.Errorf("connecting to Live Share: %w", err) } @@ -90,7 +90,7 @@ func (a *App) Logs(ctx context.Context, codespaceName string, follow bool) (err tunnelClosed := make(chan error, 1) go func() { - fwd := liveshare.NewPortForwarder(session, "sshd", remoteSSHServerPort) + fwd := liveshare.NewPortForwarder(session, "sshd", remoteSSHServerPort, false) tunnelClosed <- fwd.ForwardToListener(ctx, listen) // error is non-nil }() diff --git a/pkg/cmd/codespace/output/logger.go b/pkg/cmd/codespace/output/logger.go index 6ad7513f1..fdefcad0f 100644 --- a/pkg/cmd/codespace/output/logger.go +++ b/pkg/cmd/codespace/output/logger.go @@ -9,10 +9,14 @@ import ( // NewLogger returns a Logger that will write to the given stdout/stderr writers. // Disable the Logger to prevent it from writing to stdout in a TTY environment. func NewLogger(stdout, stderr io.Writer, disabled bool) *Logger { + enabled := !disabled + if isTTY(stdout) && !enabled { + enabled = false + } return &Logger{ out: stdout, errout: stderr, - enabled: !disabled && isTTY(stdout), + enabled: enabled, } } diff --git a/pkg/cmd/codespace/ports.go b/pkg/cmd/codespace/ports.go index c36f46078..898ede494 100644 --- a/pkg/cmd/codespace/ports.go +++ b/pkg/cmd/codespace/ports.go @@ -60,7 +60,7 @@ func (a *App) ListPorts(ctx context.Context, codespaceName string, asJSON bool) devContainerCh := getDevContainer(ctx, a.apiClient, codespace) - session, err := codespaces.ConnectToLiveshare(ctx, a.logger, a.apiClient, codespace) + session, err := codespaces.ConnectToLiveshare(ctx, a.logger, noopLogger(), a.apiClient, codespace) if err != nil { return fmt.Errorf("error connecting to Live Share: %w", err) } @@ -194,7 +194,7 @@ func (a *App) UpdatePortVisibility(ctx context.Context, codespaceName, sourcePor return fmt.Errorf("error getting codespace: %w", err) } - session, err := codespaces.ConnectToLiveshare(ctx, a.logger, a.apiClient, codespace) + session, err := codespaces.ConnectToLiveshare(ctx, a.logger, noopLogger(), a.apiClient, codespace) if err != nil { return fmt.Errorf("error connecting to Live Share: %w", err) } @@ -253,7 +253,7 @@ func (a *App) ForwardPorts(ctx context.Context, codespaceName string, ports []st return fmt.Errorf("error getting codespace: %w", err) } - session, err := codespaces.ConnectToLiveshare(ctx, a.logger, a.apiClient, codespace) + session, err := codespaces.ConnectToLiveshare(ctx, a.logger, noopLogger(), a.apiClient, codespace) if err != nil { return fmt.Errorf("error connecting to Live Share: %w", err) } @@ -272,7 +272,7 @@ func (a *App) ForwardPorts(ctx context.Context, codespaceName string, ports []st defer listen.Close() a.logger.Printf("Forwarding ports: remote %d <=> local %d\n", pair.remote, pair.local) name := fmt.Sprintf("share-%d", pair.remote) - fwd := liveshare.NewPortForwarder(session, name, pair.remote) + fwd := liveshare.NewPortForwarder(session, name, pair.remote, false) return fwd.ForwardToListener(ctx, listen) // error always non-nil }) } diff --git a/pkg/cmd/codespace/ssh.go b/pkg/cmd/codespace/ssh.go index c90b542cb..928bd044e 100644 --- a/pkg/cmd/codespace/ssh.go +++ b/pkg/cmd/codespace/ssh.go @@ -3,34 +3,46 @@ package codespace import ( "context" "fmt" + "io/ioutil" + "log" "net" + "os" "github.com/cli/cli/v2/internal/codespaces" "github.com/cli/cli/v2/pkg/liveshare" "github.com/spf13/cobra" ) +type sshOptions struct { + codespace string + profile string + serverPort int + debug bool + debugFile string +} + func newSSHCmd(app *App) *cobra.Command { - var sshProfile, codespaceName string - var sshServerPort int + var opts sshOptions sshCmd := &cobra.Command{ Use: "ssh [flags] [--] [ssh-flags] [command]", Short: "SSH into a codespace", RunE: func(cmd *cobra.Command, args []string) error { - return app.SSH(cmd.Context(), args, sshProfile, codespaceName, sshServerPort) + return app.SSH(cmd.Context(), args, opts) }, } - sshCmd.Flags().StringVarP(&sshProfile, "profile", "", "", "Name of the SSH profile to use") - sshCmd.Flags().IntVarP(&sshServerPort, "server-port", "", 0, "SSH server port number (0 => pick unused)") - sshCmd.Flags().StringVarP(&codespaceName, "codespace", "c", "", "Name of the codespace") + sshCmd.Flags().StringVarP(&opts.profile, "profile", "", "", "Name of the SSH profile to use") + sshCmd.Flags().IntVarP(&opts.serverPort, "server-port", "", 0, "SSH server port number (0 => pick unused)") + sshCmd.Flags().StringVarP(&opts.codespace, "codespace", "c", "", "Name of the codespace") + sshCmd.Flags().BoolVarP(&opts.debug, "debug", "d", false, "Log debug data to a file") + sshCmd.Flags().StringVarP(&opts.debugFile, "debug-file", "", "", "Path of the file log to") return sshCmd } // SSH opens an ssh session or runs an ssh command in a codespace. -func (a *App) SSH(ctx context.Context, sshArgs []string, sshProfile, codespaceName string, localSSHServerPort int) (err error) { +func (a *App) SSH(ctx context.Context, sshArgs []string, opts sshOptions) (err error) { // Ensure all child tasks (e.g. port forwarding) terminate before return. ctx, cancel := context.WithCancel(ctx) defer cancel() @@ -45,12 +57,22 @@ func (a *App) SSH(ctx context.Context, sshArgs []string, sshProfile, codespaceNa authkeys <- checkAuthorizedKeys(ctx, a.apiClient, user.Login) }() - codespace, err := getOrChooseCodespace(ctx, a.apiClient, codespaceName) + codespace, err := getOrChooseCodespace(ctx, a.apiClient, opts.codespace) if err != nil { return fmt.Errorf("get or choose codespace: %w", err) } - session, err := codespaces.ConnectToLiveshare(ctx, a.logger, a.apiClient, codespace) + var debugLogger *fileLogger + if opts.debug { + debugLogger, err = newFileLogger(opts.debugFile) + if err != nil { + return fmt.Errorf("error creating debug logger: %w", err) + } + defer safeClose(debugLogger, &err) + a.logger.Println("Debug file located at: " + debugLogger.Name()) + } + + session, err := codespaces.ConnectToLiveshare(ctx, a.logger, debugLogger, a.apiClient, codespace) if err != nil { return fmt.Errorf("error connecting to Live Share: %w", err) } @@ -66,6 +88,7 @@ func (a *App) SSH(ctx context.Context, sshArgs []string, sshProfile, codespaceNa return fmt.Errorf("error getting ssh server details: %w", err) } + localSSHServerPort := opts.serverPort usingCustomPort := localSSHServerPort != 0 // suppress log of command line in Shell // Ensure local port is listening before client (Shell) connects. @@ -76,7 +99,7 @@ func (a *App) SSH(ctx context.Context, sshArgs []string, sshProfile, codespaceNa defer listen.Close() localSSHServerPort = listen.Addr().(*net.TCPAddr).Port - connectDestination := sshProfile + connectDestination := opts.profile if connectDestination == "" { connectDestination = fmt.Sprintf("%s@localhost", sshUser) } @@ -84,7 +107,7 @@ func (a *App) SSH(ctx context.Context, sshArgs []string, sshProfile, codespaceNa a.logger.Println("Ready...") tunnelClosed := make(chan error, 1) go func() { - fwd := liveshare.NewPortForwarder(session, "sshd", remoteSSHServerPort) + fwd := liveshare.NewPortForwarder(session, "sshd", remoteSSHServerPort, true) tunnelClosed <- fwd.ForwardToListener(ctx, listen) // always non-nil }() @@ -103,3 +126,43 @@ func (a *App) SSH(ctx context.Context, sshArgs []string, sshProfile, codespaceNa return nil // success } } + +// fileLogger is a wrapper around an log.Logger configured to write +// to a file. It exports two additional methods to get the log file name +// and close the file handle when the operation is finished. +type fileLogger struct { + *log.Logger + + f *os.File +} + +// newFileLogger creates a new fileLogger. It returns an error if the file +// cannot be created. The file is created on the specified path, if the path +// is empty it is created in the temporary directory. +func newFileLogger(file string) (fl *fileLogger, err error) { + var f *os.File + if file == "" { + f, err = ioutil.TempFile("", "") + if err != nil { + return nil, fmt.Errorf("failed to create tmp file: %w", err) + } + } else { + f, err = os.Create(file) + if err != nil { + return nil, err + } + } + + return &fileLogger{ + Logger: log.New(f, "", log.LstdFlags), + f: f, + }, nil +} + +func (fl *fileLogger) Name() string { + return fl.f.Name() +} + +func (fl *fileLogger) Close() error { + return fl.f.Close() +} diff --git a/pkg/liveshare/client.go b/pkg/liveshare/client.go index 913f19195..840e99db9 100644 --- a/pkg/liveshare/client.go +++ b/pkg/liveshare/client.go @@ -17,23 +17,34 @@ import ( "fmt" "net/url" "strings" + "time" "github.com/opentracing/opentracing-go" "golang.org/x/crypto/ssh" ) +type logger interface { + Println(v ...interface{}) + Printf(f string, v ...interface{}) +} + // An Options specifies Live Share connection parameters. type Options struct { + ClientName string // ClientName is the name of the connecting client. SessionID string SessionToken string // token for SSH session RelaySAS string RelayEndpoint string HostPublicKeys []string + Logger logger // required TLSConfig *tls.Config // (optional) } // uri returns a websocket URL for the specified options. func (opts *Options) uri(action string) (string, error) { + if opts.ClientName == "" { + return "", errors.New("ClientName is required") + } if opts.SessionID == "" { return "", errors.New("SessionID is required") } @@ -56,13 +67,17 @@ func (opts *Options) uri(action string) (string, error) { // options, and returns a session representing the connection. // The caller must call the session's Close method to end the session. func Connect(ctx context.Context, opts Options) (*Session, error) { + span, ctx := opentracing.StartSpanFromContext(ctx, "Connect") + defer span.Finish() + uri, err := opts.uri("connect") if err != nil { return nil, err } - span, ctx := opentracing.StartSpanFromContext(ctx, "Connect") - defer span.Finish() + if opts.Logger == nil { + return nil, errors.New("Logger is required") + } sock := newSocket(uri, opts.TLSConfig) if err := sock.connect(ctx); err != nil { @@ -93,7 +108,16 @@ func Connect(ctx context.Context, opts Options) (*Session, error) { return nil, fmt.Errorf("error joining Live Share workspace: %w", err) } - return &Session{ssh: ssh, rpc: rpc}, nil + s := &Session{ + ssh: ssh, + rpc: rpc, + clientName: opts.ClientName, + keepAliveReason: make(chan string, 1), + logger: opts.Logger, + } + go s.heartbeat(ctx, 1*time.Minute) + + return s, nil } type clientCapabilities struct { diff --git a/pkg/liveshare/client_test.go b/pkg/liveshare/client_test.go index 46807a22e..c6502d684 100644 --- a/pkg/liveshare/client_test.go +++ b/pkg/liveshare/client_test.go @@ -15,10 +15,12 @@ import ( func TestConnect(t *testing.T) { opts := Options{ + ClientName: "liveshare-client", SessionID: "session-id", SessionToken: "session-token", RelaySAS: "relay-sas", HostPublicKeys: []string{livesharetest.SSHPublicKey}, + Logger: newMockLogger(), } joinWorkspace := func(req *jsonrpc2.Request) (interface{}, error) { var joinWorkspaceReq joinWorkspaceArgs diff --git a/pkg/liveshare/options_test.go b/pkg/liveshare/options_test.go index 830c59104..d244193b4 100644 --- a/pkg/liveshare/options_test.go +++ b/pkg/liveshare/options_test.go @@ -41,6 +41,7 @@ func checkBadOptions(t *testing.T, opts Options) { func TestOptionsURI(t *testing.T) { opts := Options{ + ClientName: "liveshare-client", SessionID: "sess-id", SessionToken: "sess-token", RelaySAS: "sas", diff --git a/pkg/liveshare/port_forwarder.go b/pkg/liveshare/port_forwarder.go index fcc7ba767..2649abd3c 100644 --- a/pkg/liveshare/port_forwarder.go +++ b/pkg/liveshare/port_forwarder.go @@ -15,16 +15,19 @@ type PortForwarder struct { session *Session name string remotePort int + keepAlive bool } // NewPortForwarder returns a new PortForwarder for the specified // remote port and Live Share session. The name describes the purpose -// of the remote port or service. -func NewPortForwarder(session *Session, name string, remotePort int) *PortForwarder { +// of the remote port or service. The keepAlive flag indicates whether +// the session should be kept alive with port forwarding traffic. +func NewPortForwarder(session *Session, name string, remotePort int, keepAlive bool) *PortForwarder { return &PortForwarder{ session: session, name: name, remotePort: remotePort, + keepAlive: keepAlive, } } @@ -106,6 +109,27 @@ func awaitError(ctx context.Context, errc <-chan error) error { } } +// trafficMonitor implements io.Reader. It keeps the session alive by notifying +// it of the traffic type during Read operations. +type trafficMonitor struct { + reader io.Reader + + session *Session + trafficType string +} + +// newTrafficMonitor returns a new trafficMonitor for the specified +// session and traffic type. It wraps the provided io.Reader with its own +// Read method. +func newTrafficMonitor(reader io.Reader, session *Session, trafficType string) *trafficMonitor { + return &trafficMonitor{reader, session, trafficType} +} + +func (t *trafficMonitor) Read(p []byte) (n int, err error) { + t.session.keepAlive(t.trafficType) + return t.reader.Read(p) +} + // handleConnection handles forwarding for a single accepted connection, then closes it. func (fwd *PortForwarder) handleConnection(ctx context.Context, id channelID, conn io.ReadWriteCloser) (err error) { span, ctx := opentracing.StartSpanFromContext(ctx, "PortForwarder.handleConnection") @@ -133,8 +157,21 @@ func (fwd *PortForwarder) handleConnection(ctx context.Context, id channelID, co _, err := io.Copy(w, r) errs <- err } - go copyConn(conn, channel) - go copyConn(channel, conn) + + var ( + channelReader io.Reader = channel + connReader io.Reader = conn + ) + + // If the forwader has been configured to keep the session alive + // it will monitor the I/O and notify the session of the traffic. + if fwd.keepAlive { + channelReader = newTrafficMonitor(channelReader, fwd.session, "output") + connReader = newTrafficMonitor(connReader, fwd.session, "input") + } + + go copyConn(conn, channelReader) + go copyConn(channel, connReader) // Wait until context is cancelled or both copies are done. // Discard errors from io.Copy; they should not cause (e.g.) ForwardToListener to fail. diff --git a/pkg/liveshare/port_forwarder_test.go b/pkg/liveshare/port_forwarder_test.go index 624428dda..c5b61d430 100644 --- a/pkg/liveshare/port_forwarder_test.go +++ b/pkg/liveshare/port_forwarder_test.go @@ -20,7 +20,7 @@ func TestNewPortForwarder(t *testing.T) { t.Errorf("create mock client: %w", err) } defer testServer.Close() - pf := NewPortForwarder(session, "ssh", 80) + pf := NewPortForwarder(session, "ssh", 80, false) if pf == nil { t.Error("port forwarder is nil") } @@ -58,7 +58,7 @@ func TestPortForwarderStart(t *testing.T) { done := make(chan error) go func() { const name, remote = "ssh", 8000 - done <- NewPortForwarder(session, name, remote).ForwardToListener(ctx, listen) + done <- NewPortForwarder(session, name, remote, false).ForwardToListener(ctx, listen) }() go func() { @@ -93,3 +93,26 @@ func TestPortForwarderStart(t *testing.T) { } } } + +func TestPortForwarderTrafficMonitor(t *testing.T) { + buf := bytes.NewBufferString("some-input") + session := &Session{keepAliveReason: make(chan string, 1)} + trafficType := "io" + + tm := newTrafficMonitor(buf, session, trafficType) + l := len(buf.Bytes()) + + bb := make([]byte, l) + n, err := tm.Read(bb) + if err != nil { + t.Errorf("failed to read from traffic monitor: %w", err) + } + if n != l { + t.Errorf("expected to read %d bytes, got %d", l, n) + } + + keepAliveReason := <-session.keepAliveReason + if keepAliveReason != trafficType { + t.Errorf("expected keep alive reason to be %s, got %s", trafficType, keepAliveReason) + } +} diff --git a/pkg/liveshare/session.go b/pkg/liveshare/session.go index 929e8605b..13558f911 100644 --- a/pkg/liveshare/session.go +++ b/pkg/liveshare/session.go @@ -4,12 +4,17 @@ import ( "context" "fmt" "strconv" + "time" ) // A Session represents the session between a connected Live Share client and server. type Session struct { ssh *sshSession rpc *rpcClient + + clientName string + keepAliveReason chan string + logger logger } // Close should be called by users to clean up RPC and SSH resources whenever the session @@ -97,3 +102,42 @@ func (s *Session) StartSSHServer(ctx context.Context) (int, string, error) { return port, response.User, nil } + +// heartbeat runs until context cancellation, periodically checking whether there is a +// reason to keep the connection alive, and if so, notifying the Live Share host to do so. +func (s *Session) heartbeat(ctx context.Context, interval time.Duration) { + ticker := time.NewTicker(interval) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + s.logger.Println("Heartbeat tick") + reason := <-s.keepAliveReason + s.logger.Println("Keep alive reason: " + reason) + if err := s.notifyHostOfActivity(ctx, reason); err != nil { + s.logger.Printf("Failed to notify host of activity: %s\n", err) + } + } + } +} + +// notifyHostOfActivity notifies the Live Share host of client activity. +func (s *Session) notifyHostOfActivity(ctx context.Context, activity string) error { + activities := []string{activity} + params := []interface{}{s.clientName, activities} + return s.rpc.do(ctx, "ICodespaceHostService.notifyCodespaceOfClientActivity", params, nil) +} + +// keepAlive accepts a reason that is retained if there is no active reason +// to send to the server. +func (s *Session) keepAlive(reason string) { + select { + case s.keepAliveReason <- reason: + default: + // there is already an active keep alive reason + // so we can ignore this one + } +} diff --git a/pkg/liveshare/session_test.go b/pkg/liveshare/session_test.go index 7f0b573b5..998de6ac0 100644 --- a/pkg/liveshare/session_test.go +++ b/pkg/liveshare/session_test.go @@ -1,18 +1,23 @@ package liveshare import ( + "bytes" "context" "crypto/tls" "encoding/json" "errors" "fmt" "strings" + "sync" "testing" + "time" livesharetest "github.com/cli/cli/v2/pkg/liveshare/test" "github.com/sourcegraph/jsonrpc2" ) +const mockClientName = "liveshare-client" + func makeMockSession(opts ...livesharetest.ServerOption) (*livesharetest.Server, *Session, error) { joinWorkspace := func(req *jsonrpc2.Request) (interface{}, error) { return joinWorkspaceResult{1}, nil @@ -29,12 +34,14 @@ func makeMockSession(opts ...livesharetest.ServerOption) (*livesharetest.Server, } session, err := Connect(context.Background(), Options{ + ClientName: mockClientName, SessionID: "session-id", SessionToken: sessionToken, RelayEndpoint: "sb" + strings.TrimPrefix(testServer.URL(), "https"), RelaySAS: "relay-sas", HostPublicKeys: []string{livesharetest.SSHPublicKey}, TLSConfig: &tls.Config{InsecureSkipVerify: true}, + Logger: newMockLogger(), }) if err != nil { return nil, nil, fmt.Errorf("error connecting to Live Share: %w", err) @@ -221,3 +228,176 @@ func TestInvalidHostKey(t *testing.T) { t.Error("expected invalid host key error, got: nil") } } + +func TestKeepAliveNonBlocking(t *testing.T) { + session := &Session{keepAliveReason: make(chan string, 1)} + for i := 0; i < 2; i++ { + session.keepAlive("io") + } + + // if keepAlive blocks, we'll never reach this and timeout the test + // timing out +} + +func TestNotifyHostOfActivity(t *testing.T) { + notifyHostOfActivity := func(rpcReq *jsonrpc2.Request) (interface{}, error) { + var req []interface{} + if err := json.Unmarshal(*rpcReq.Params, &req); err != nil { + return nil, fmt.Errorf("unmarshal req: %w", err) + } + if len(req) < 2 { + return nil, errors.New("request arguments is less than 2") + } + + if clientName, ok := req[0].(string); ok { + if clientName != mockClientName { + return nil, fmt.Errorf( + "unexpected clientName param, expected: %q, got: %q", mockClientName, clientName, + ) + } + } else { + return nil, errors.New("clientName param is not a string") + } + + if acs, ok := req[1].([]interface{}); ok { + if fmt.Sprintf("%s", acs) != "[input]" { + return nil, fmt.Errorf("unexpected activities param, expected: [input], got: %s", acs) + } + } else { + return nil, errors.New("activities param is not a slice") + } + + return nil, nil + } + svc := livesharetest.WithService( + "ICodespaceHostService.notifyCodespaceOfClientActivity", notifyHostOfActivity, + ) + testServer, session, err := makeMockSession(svc) + if err != nil { + t.Errorf("creating mock session: %w", err) + } + defer testServer.Close() + ctx := context.Background() + done := make(chan error) + go func() { + done <- session.notifyHostOfActivity(ctx, "input") + }() + select { + case err := <-testServer.Err(): + t.Errorf("error from server: %w", err) + case err := <-done: + if err != nil { + t.Errorf("error from client: %w", err) + } + } +} + +func TestSessionHeartbeat(t *testing.T) { + var ( + requestsMu sync.Mutex + requests int + ) + notifyHostOfActivity := func(rpcReq *jsonrpc2.Request) (interface{}, error) { + requestsMu.Lock() + requests++ + requestsMu.Unlock() + + var req []interface{} + if err := json.Unmarshal(*rpcReq.Params, &req); err != nil { + return nil, fmt.Errorf("unmarshal req: %w", err) + } + if len(req) < 2 { + return nil, errors.New("request arguments is less than 2") + } + + if clientName, ok := req[0].(string); ok { + if clientName != mockClientName { + return nil, fmt.Errorf( + "unexpected clientName param, expected: %q, got: %q", mockClientName, clientName, + ) + } + } else { + return nil, errors.New("clientName param is not a string") + } + + if acs, ok := req[1].([]interface{}); ok { + if fmt.Sprintf("%s", acs) != "[input]" { + return nil, fmt.Errorf("unexpected activities param, expected: [input], got: %s", acs) + } + } else { + return nil, errors.New("activities param is not a slice") + } + + return nil, nil + } + svc := livesharetest.WithService( + "ICodespaceHostService.notifyCodespaceOfClientActivity", notifyHostOfActivity, + ) + testServer, session, err := makeMockSession(svc) + if err != nil { + t.Errorf("creating mock session: %w", err) + } + defer testServer.Close() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + done := make(chan struct{}) + + logger := newMockLogger() + session.logger = logger + + go session.heartbeat(ctx, 50*time.Millisecond) + go func() { + session.keepAlive("input") + <-time.Tick(200 * time.Millisecond) + session.keepAlive("input") + <-time.Tick(100 * time.Millisecond) + done <- struct{}{} + }() + + select { + case err := <-testServer.Err(): + t.Errorf("error from server: %w", err) + case <-done: + activityCount := strings.Count(logger.String(), "input") + if activityCount != 2 { + t.Errorf("unexpected number of activities, expected: 2, got: %d", activityCount) + } + + requestsMu.Lock() + rc := requests + requestsMu.Unlock() + if rc != 2 { + t.Errorf("unexpected number of requests, expected: 2, got: %d", requests) + } + return + } +} + +type mockLogger struct { + sync.Mutex + buf *bytes.Buffer +} + +func newMockLogger() *mockLogger { + return &mockLogger{buf: new(bytes.Buffer)} +} + +func (m *mockLogger) Printf(format string, v ...interface{}) { + m.Lock() + defer m.Unlock() + m.buf.WriteString(fmt.Sprintf(format, v...)) +} + +func (m *mockLogger) Println(v ...interface{}) { + m.Lock() + defer m.Unlock() + m.buf.WriteString(fmt.Sprintln(v...)) +} + +func (m *mockLogger) String() string { + m.Lock() + defer m.Unlock() + return m.buf.String() +}