From 55f4fcf05c51896a98e6223615a4faaf4210e0b6 Mon Sep 17 00:00:00 2001 From: Jose Garcia Date: Thu, 7 Oct 2021 10:42:06 -0400 Subject: [PATCH 1/9] Live Share session activity detection - Session now accepts two new options: ClientName and Logger - Port forwarder now supports a keepAlive parameter which when true, instructs the PF to call the session's keepAlive method. - Port forwarder uses a new traffic monitor to detect I/O traffic and notify the session when applicable. - The SSH command introduces a new debug flag which enables the command to log to a new temporary file. The file path is printed to the user. --- internal/codespaces/codespaces.go | 5 ++- internal/codespaces/states.go | 4 +- pkg/cmd/codespace/logs.go | 4 +- pkg/cmd/codespace/output/logger.go | 6 ++- pkg/cmd/codespace/ports.go | 8 ++-- pkg/cmd/codespace/ssh.go | 72 +++++++++++++++++++++++++----- pkg/liveshare/client.go | 36 ++++++++++++++- pkg/liveshare/port_forwarder.go | 45 +++++++++++++++++-- pkg/liveshare/session.go | 45 +++++++++++++++++++ 9 files changed, 199 insertions(+), 26 deletions(-) diff --git a/internal/codespaces/codespaces.go b/internal/codespaces/codespaces.go index ab013409b..3c7be9c01 100644 --- a/internal/codespaces/codespaces.go +++ b/internal/codespaces/codespaces.go @@ -12,6 +12,7 @@ import ( type logger interface { Print(v ...interface{}) (int, error) + Printf(f string, v ...interface{}) (int, error) Println(v ...interface{}) (int, error) } @@ -30,7 +31,7 @@ type apiClient interface { // ConnectToLiveshare waits for a Codespace to become running, // and connects to it using a Live Share session. -func ConnectToLiveshare(ctx context.Context, log logger, apiClient apiClient, codespace *api.Codespace) (*liveshare.Session, error) { +func ConnectToLiveshare(ctx context.Context, log, sessionLogger logger, apiClient apiClient, codespace *api.Codespace) (*liveshare.Session, error) { var startedCodespace bool if codespace.Environment.State != api.CodespaceEnvironmentStateAvailable { startedCodespace = true @@ -67,10 +68,12 @@ func ConnectToLiveshare(ctx context.Context, log logger, apiClient apiClient, co log.Println("Connecting to your codespace...") return liveshare.Connect(ctx, liveshare.Options{ + ClientName: "gh", SessionID: codespace.Environment.Connection.SessionID, SessionToken: codespace.Environment.Connection.SessionToken, RelaySAS: codespace.Environment.Connection.RelaySAS, RelayEndpoint: codespace.Environment.Connection.RelayEndpoint, HostPublicKeys: codespace.Environment.Connection.HostPublicKeys, + Logger: sessionLogger, }) } diff --git a/internal/codespaces/states.go b/internal/codespaces/states.go index e8f197410..10d7dd4ca 100644 --- a/internal/codespaces/states.go +++ b/internal/codespaces/states.go @@ -37,7 +37,7 @@ type PostCreateState struct { // and calls the supplied poller for each batch of state changes. // It runs until it encounters an error, including cancellation of the context. func PollPostCreateStates(ctx context.Context, log logger, apiClient apiClient, codespace *api.Codespace, poller func([]PostCreateState)) (err error) { - session, err := ConnectToLiveshare(ctx, log, apiClient, codespace) + session, err := ConnectToLiveshare(ctx, log, nil, apiClient, codespace) if err != nil { return fmt.Errorf("connect to Live Share: %w", err) } @@ -62,7 +62,7 @@ func PollPostCreateStates(ctx context.Context, log logger, apiClient apiClient, tunnelClosed := make(chan error, 1) // buffered to avoid sender stuckness go func() { - fwd := liveshare.NewPortForwarder(session, "sshd", remoteSSHServerPort) + fwd := liveshare.NewPortForwarder(session, "sshd", remoteSSHServerPort, false) tunnelClosed <- fwd.ForwardToListener(ctx, listen) // error is non-nil }() diff --git a/pkg/cmd/codespace/logs.go b/pkg/cmd/codespace/logs.go index 9bfc4a967..3a7c6948f 100644 --- a/pkg/cmd/codespace/logs.go +++ b/pkg/cmd/codespace/logs.go @@ -51,7 +51,7 @@ func (a *App) Logs(ctx context.Context, codespaceName string, follow bool) (err return fmt.Errorf("get or choose codespace: %w", err) } - session, err := codespaces.ConnectToLiveshare(ctx, a.logger, a.apiClient, codespace) + session, err := codespaces.ConnectToLiveshare(ctx, a.logger, nil, a.apiClient, codespace) if err != nil { return fmt.Errorf("connecting to Live Share: %w", err) } @@ -90,7 +90,7 @@ func (a *App) Logs(ctx context.Context, codespaceName string, follow bool) (err tunnelClosed := make(chan error, 1) go func() { - fwd := liveshare.NewPortForwarder(session, "sshd", remoteSSHServerPort) + fwd := liveshare.NewPortForwarder(session, "sshd", remoteSSHServerPort, false) tunnelClosed <- fwd.ForwardToListener(ctx, listen) // error is non-nil }() diff --git a/pkg/cmd/codespace/output/logger.go b/pkg/cmd/codespace/output/logger.go index 6ad7513f1..fdefcad0f 100644 --- a/pkg/cmd/codespace/output/logger.go +++ b/pkg/cmd/codespace/output/logger.go @@ -9,10 +9,14 @@ import ( // NewLogger returns a Logger that will write to the given stdout/stderr writers. // Disable the Logger to prevent it from writing to stdout in a TTY environment. func NewLogger(stdout, stderr io.Writer, disabled bool) *Logger { + enabled := !disabled + if isTTY(stdout) && !enabled { + enabled = false + } return &Logger{ out: stdout, errout: stderr, - enabled: !disabled && isTTY(stdout), + enabled: enabled, } } diff --git a/pkg/cmd/codespace/ports.go b/pkg/cmd/codespace/ports.go index c36f46078..5b4b7fadd 100644 --- a/pkg/cmd/codespace/ports.go +++ b/pkg/cmd/codespace/ports.go @@ -60,7 +60,7 @@ func (a *App) ListPorts(ctx context.Context, codespaceName string, asJSON bool) devContainerCh := getDevContainer(ctx, a.apiClient, codespace) - session, err := codespaces.ConnectToLiveshare(ctx, a.logger, a.apiClient, codespace) + session, err := codespaces.ConnectToLiveshare(ctx, a.logger, nil, a.apiClient, codespace) if err != nil { return fmt.Errorf("error connecting to Live Share: %w", err) } @@ -194,7 +194,7 @@ func (a *App) UpdatePortVisibility(ctx context.Context, codespaceName, sourcePor return fmt.Errorf("error getting codespace: %w", err) } - session, err := codespaces.ConnectToLiveshare(ctx, a.logger, a.apiClient, codespace) + session, err := codespaces.ConnectToLiveshare(ctx, a.logger, nil, a.apiClient, codespace) if err != nil { return fmt.Errorf("error connecting to Live Share: %w", err) } @@ -253,7 +253,7 @@ func (a *App) ForwardPorts(ctx context.Context, codespaceName string, ports []st return fmt.Errorf("error getting codespace: %w", err) } - session, err := codespaces.ConnectToLiveshare(ctx, a.logger, a.apiClient, codespace) + session, err := codespaces.ConnectToLiveshare(ctx, a.logger, nil, a.apiClient, codespace) if err != nil { return fmt.Errorf("error connecting to Live Share: %w", err) } @@ -272,7 +272,7 @@ func (a *App) ForwardPorts(ctx context.Context, codespaceName string, ports []st defer listen.Close() a.logger.Printf("Forwarding ports: remote %d <=> local %d\n", pair.remote, pair.local) name := fmt.Sprintf("share-%d", pair.remote) - fwd := liveshare.NewPortForwarder(session, name, pair.remote) + fwd := liveshare.NewPortForwarder(session, name, pair.remote, false) return fwd.ForwardToListener(ctx, listen) // error always non-nil }) } diff --git a/pkg/cmd/codespace/ssh.go b/pkg/cmd/codespace/ssh.go index c90b542cb..56f62434d 100644 --- a/pkg/cmd/codespace/ssh.go +++ b/pkg/cmd/codespace/ssh.go @@ -3,34 +3,44 @@ package codespace import ( "context" "fmt" + "io/ioutil" "net" + "os" "github.com/cli/cli/v2/internal/codespaces" + "github.com/cli/cli/v2/pkg/cmd/codespace/output" "github.com/cli/cli/v2/pkg/liveshare" "github.com/spf13/cobra" ) +type sshOptions struct { + codespace string + profile string + serverPort int + debug bool +} + func newSSHCmd(app *App) *cobra.Command { - var sshProfile, codespaceName string - var sshServerPort int + var opts sshOptions sshCmd := &cobra.Command{ Use: "ssh [flags] [--] [ssh-flags] [command]", Short: "SSH into a codespace", RunE: func(cmd *cobra.Command, args []string) error { - return app.SSH(cmd.Context(), args, sshProfile, codespaceName, sshServerPort) + return app.SSH(cmd.Context(), args, opts) }, } - sshCmd.Flags().StringVarP(&sshProfile, "profile", "", "", "Name of the SSH profile to use") - sshCmd.Flags().IntVarP(&sshServerPort, "server-port", "", 0, "SSH server port number (0 => pick unused)") - sshCmd.Flags().StringVarP(&codespaceName, "codespace", "c", "", "Name of the codespace") + sshCmd.Flags().StringVarP(&opts.profile, "profile", "", "", "Name of the SSH profile to use") + sshCmd.Flags().IntVarP(&opts.serverPort, "server-port", "", 0, "SSH server port number (0 => pick unused)") + sshCmd.Flags().StringVarP(&opts.codespace, "codespace", "c", "", "Name of the codespace") + sshCmd.Flags().BoolVarP(&opts.debug, "debug", "d", false, "Log debug data to a file") return sshCmd } // SSH opens an ssh session or runs an ssh command in a codespace. -func (a *App) SSH(ctx context.Context, sshArgs []string, sshProfile, codespaceName string, localSSHServerPort int) (err error) { +func (a *App) SSH(ctx context.Context, sshArgs []string, opts sshOptions) (err error) { // Ensure all child tasks (e.g. port forwarding) terminate before return. ctx, cancel := context.WithCancel(ctx) defer cancel() @@ -45,12 +55,22 @@ func (a *App) SSH(ctx context.Context, sshArgs []string, sshProfile, codespaceNa authkeys <- checkAuthorizedKeys(ctx, a.apiClient, user.Login) }() - codespace, err := getOrChooseCodespace(ctx, a.apiClient, codespaceName) + codespace, err := getOrChooseCodespace(ctx, a.apiClient, opts.codespace) if err != nil { return fmt.Errorf("get or choose codespace: %w", err) } - session, err := codespaces.ConnectToLiveshare(ctx, a.logger, a.apiClient, codespace) + var debugLogger *fileLogger + if opts.debug { + debugLogger, err = newFileLogger() + if err != nil { + return fmt.Errorf("error creating debug logger: %w", err) + } + defer safeClose(debugLogger, &err) + a.logger.Println("Debug file located at: " + debugLogger.Name()) + } + + session, err := codespaces.ConnectToLiveshare(ctx, a.logger, debugLogger, a.apiClient, codespace) if err != nil { return fmt.Errorf("error connecting to Live Share: %w", err) } @@ -66,6 +86,7 @@ func (a *App) SSH(ctx context.Context, sshArgs []string, sshProfile, codespaceNa return fmt.Errorf("error getting ssh server details: %w", err) } + localSSHServerPort := opts.serverPort usingCustomPort := localSSHServerPort != 0 // suppress log of command line in Shell // Ensure local port is listening before client (Shell) connects. @@ -76,7 +97,7 @@ func (a *App) SSH(ctx context.Context, sshArgs []string, sshProfile, codespaceNa defer listen.Close() localSSHServerPort = listen.Addr().(*net.TCPAddr).Port - connectDestination := sshProfile + connectDestination := opts.profile if connectDestination == "" { connectDestination = fmt.Sprintf("%s@localhost", sshUser) } @@ -84,7 +105,7 @@ func (a *App) SSH(ctx context.Context, sshArgs []string, sshProfile, codespaceNa a.logger.Println("Ready...") tunnelClosed := make(chan error, 1) go func() { - fwd := liveshare.NewPortForwarder(session, "sshd", remoteSSHServerPort) + fwd := liveshare.NewPortForwarder(session, "sshd", remoteSSHServerPort, true) tunnelClosed <- fwd.ForwardToListener(ctx, listen) // always non-nil }() @@ -103,3 +124,32 @@ func (a *App) SSH(ctx context.Context, sshArgs []string, sshProfile, codespaceNa return nil // success } } + +// fileLogger is a wrapper around an output.Logger configured to write +// to a file. It exports two additional methods to get the log file name +// and close the file handle when the operation is finished. +type fileLogger struct { + // TODO(josebalius): should we use https://pkg.go.dev/log#New instead? + *output.Logger + + f *os.File +} + +func newFileLogger() (*fileLogger, error) { + f, err := ioutil.TempFile("", "gh-cs-ssh") + if err != nil { + return nil, err + } + return &fileLogger{ + Logger: output.NewLogger(f, f, false), + f: f, + }, nil +} + +func (fl *fileLogger) Name() string { + return fl.f.Name() +} + +func (fl *fileLogger) Close() error { + return fl.f.Close() +} diff --git a/pkg/liveshare/client.go b/pkg/liveshare/client.go index 913f19195..ccf57b08a 100644 --- a/pkg/liveshare/client.go +++ b/pkg/liveshare/client.go @@ -22,18 +22,38 @@ import ( "golang.org/x/crypto/ssh" ) +type logger interface { + Println(v ...interface{}) (int, error) + Printf(f string, v ...interface{}) (int, error) +} + +type noopLogger struct{} + +func (n noopLogger) Println(...interface{}) (int, error) { + return 0, nil +} + +func (n noopLogger) Printf(string, ...interface{}) (int, error) { + return 0, nil +} + // An Options specifies Live Share connection parameters. type Options struct { + ClientName string // ClientName is the name of the connecting client. SessionID string SessionToken string // token for SSH session RelaySAS string RelayEndpoint string HostPublicKeys []string TLSConfig *tls.Config // (optional) + Logger logger // (optional) } // uri returns a websocket URL for the specified options. func (opts *Options) uri(action string) (string, error) { + if opts.ClientName == "" { + return "", errors.New("ClientName is required") + } if opts.SessionID == "" { return "", errors.New("SessionID is required") } @@ -61,6 +81,11 @@ func Connect(ctx context.Context, opts Options) (*Session, error) { return nil, err } + var sessionLogger logger = noopLogger{} + if opts.Logger != nil { + sessionLogger = opts.Logger + } + span, ctx := opentracing.StartSpanFromContext(ctx, "Connect") defer span.Finish() @@ -93,7 +118,16 @@ func Connect(ctx context.Context, opts Options) (*Session, error) { return nil, fmt.Errorf("error joining Live Share workspace: %w", err) } - return &Session{ssh: ssh, rpc: rpc}, nil + s := &Session{ + ssh: ssh, + rpc: rpc, + clientName: opts.ClientName, + keepAliveReason: make(chan string, 1), + logger: sessionLogger, + } + go s.heartbeat(ctx) + + return s, nil } type clientCapabilities struct { diff --git a/pkg/liveshare/port_forwarder.go b/pkg/liveshare/port_forwarder.go index fcc7ba767..43b6da805 100644 --- a/pkg/liveshare/port_forwarder.go +++ b/pkg/liveshare/port_forwarder.go @@ -15,16 +15,19 @@ type PortForwarder struct { session *Session name string remotePort int + keepAlive bool } // NewPortForwarder returns a new PortForwarder for the specified // remote port and Live Share session. The name describes the purpose -// of the remote port or service. -func NewPortForwarder(session *Session, name string, remotePort int) *PortForwarder { +// of the remote port or service. The keepAlive flag indicates whether +// the session should be kept alive with port forwarding traffic. +func NewPortForwarder(session *Session, name string, remotePort int, keepAlive bool) *PortForwarder { return &PortForwarder{ session: session, name: name, remotePort: remotePort, + keepAlive: keepAlive, } } @@ -106,6 +109,27 @@ func awaitError(ctx context.Context, errc <-chan error) error { } } +// trafficMonitor implements io.Reader. It keeps the session alive by notifying +// it of the traffic type during Read operations. +type trafficMonitor struct { + reader io.Reader + + session *Session + trafficType string +} + +// newTrafficMonitor returns a new trafficMonitor for the specified +// session and traffic type. It wraps the provided io.Reader with its own +// Read method. +func newTrafficMonitor(reader io.Reader, session *Session, trafficType string) *trafficMonitor { + return &trafficMonitor{reader, session, trafficType} +} + +func (t *trafficMonitor) Read(p []byte) (n int, err error) { + t.session.keepAlive(t.trafficType) + return t.reader.Read(p) +} + // handleConnection handles forwarding for a single accepted connection, then closes it. func (fwd *PortForwarder) handleConnection(ctx context.Context, id channelID, conn io.ReadWriteCloser) (err error) { span, ctx := opentracing.StartSpanFromContext(ctx, "PortForwarder.handleConnection") @@ -133,8 +157,21 @@ func (fwd *PortForwarder) handleConnection(ctx context.Context, id channelID, co _, err := io.Copy(w, r) errs <- err } - go copyConn(conn, channel) - go copyConn(channel, conn) + + var ( + channelReader io.Reader = channel + connReader io.Reader = conn + ) + + // If we the port forwader has been configured to keep the session alive + // it will monitor the I/O and notify the session of the traffic. + if fwd.keepAlive { + channelReader = newTrafficMonitor(channelReader, fwd.session, "output") + connReader = newTrafficMonitor(connReader, fwd.session, "input") + } + + go copyConn(conn, channelReader) + go copyConn(channel, connReader) // Wait until context is cancelled or both copies are done. // Discard errors from io.Copy; they should not cause (e.g.) ForwardToListener to fail. diff --git a/pkg/liveshare/session.go b/pkg/liveshare/session.go index 929e8605b..329ea1a2e 100644 --- a/pkg/liveshare/session.go +++ b/pkg/liveshare/session.go @@ -4,12 +4,17 @@ import ( "context" "fmt" "strconv" + "time" ) // A Session represents the session between a connected Live Share client and server. type Session struct { ssh *sshSession rpc *rpcClient + + clientName string + keepAliveReason chan string + logger logger } // Close should be called by users to clean up RPC and SSH resources whenever the session @@ -97,3 +102,43 @@ func (s *Session) StartSSHServer(ctx context.Context) (int, string, error) { return port, response.User, nil } + +// heartbeat ticks every minute and sends a signal to the Live Share host to keep +// the connection alive if there is a reason to do so. +func (s *Session) heartbeat(ctx context.Context) { + ticker := time.NewTicker(1 * time.Minute) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + s.logger.Println("Running session heartbeat") + reason := <-s.keepAliveReason + s.logger.Println("Keep alive reason: " + reason) + if err := s.notifyHostOfActivity(ctx, reason); err != nil { + s.logger.Printf("Failed to notify host of activity: %s\n", err) + } + } + } + s.logger.Println("Ending session heartbeat") +} + +// notifyHostOfActivity notifies the Live Share host of client activity. +func (s *Session) notifyHostOfActivity(ctx context.Context, activity string) error { + activities := []string{activity} + params := []interface{}{s.clientName, activities} + return s.rpc.do(ctx, "ICodespaceHostService.notifyCodespaceOfClientActivity", params, nil) +} + +// keepAlive accepts a reason that is retained if there is no active reason +// to send to the server. +func (s *Session) keepAlive(reason string) { + select { + case s.keepAliveReason <- reason: + default: + // there is already an active keep alive reason + // so we can ignore this one + } +} From 8f5d6bb672e889ed8723a7f5bbc22da1c0a9ef12 Mon Sep 17 00:00:00 2001 From: Jose Garcia Date: Thu, 7 Oct 2021 15:14:42 -0400 Subject: [PATCH 2/9] Tests for most of the new behavior - Made the heartbeat interval configurable for easier testing - Moved span to the top of connect to capture the full execution --- pkg/liveshare/client.go | 9 +- pkg/liveshare/client_test.go | 1 + pkg/liveshare/options_test.go | 1 + pkg/liveshare/port_forwarder_test.go | 27 ++++- pkg/liveshare/session.go | 6 +- pkg/liveshare/session_test.go | 166 +++++++++++++++++++++++++++ 6 files changed, 201 insertions(+), 9 deletions(-) diff --git a/pkg/liveshare/client.go b/pkg/liveshare/client.go index ccf57b08a..c3e92004d 100644 --- a/pkg/liveshare/client.go +++ b/pkg/liveshare/client.go @@ -17,6 +17,7 @@ import ( "fmt" "net/url" "strings" + "time" "github.com/opentracing/opentracing-go" "golang.org/x/crypto/ssh" @@ -76,6 +77,9 @@ func (opts *Options) uri(action string) (string, error) { // options, and returns a session representing the connection. // The caller must call the session's Close method to end the session. func Connect(ctx context.Context, opts Options) (*Session, error) { + span, ctx := opentracing.StartSpanFromContext(ctx, "Connect") + defer span.Finish() + uri, err := opts.uri("connect") if err != nil { return nil, err @@ -86,9 +90,6 @@ func Connect(ctx context.Context, opts Options) (*Session, error) { sessionLogger = opts.Logger } - span, ctx := opentracing.StartSpanFromContext(ctx, "Connect") - defer span.Finish() - sock := newSocket(uri, opts.TLSConfig) if err := sock.connect(ctx); err != nil { return nil, fmt.Errorf("error connecting websocket: %w", err) @@ -125,7 +126,7 @@ func Connect(ctx context.Context, opts Options) (*Session, error) { keepAliveReason: make(chan string, 1), logger: sessionLogger, } - go s.heartbeat(ctx) + go s.heartbeat(ctx, 1*time.Minute) return s, nil } diff --git a/pkg/liveshare/client_test.go b/pkg/liveshare/client_test.go index 46807a22e..a775ba4af 100644 --- a/pkg/liveshare/client_test.go +++ b/pkg/liveshare/client_test.go @@ -15,6 +15,7 @@ import ( func TestConnect(t *testing.T) { opts := Options{ + ClientName: "liveshare-client", SessionID: "session-id", SessionToken: "session-token", RelaySAS: "relay-sas", diff --git a/pkg/liveshare/options_test.go b/pkg/liveshare/options_test.go index 830c59104..d244193b4 100644 --- a/pkg/liveshare/options_test.go +++ b/pkg/liveshare/options_test.go @@ -41,6 +41,7 @@ func checkBadOptions(t *testing.T, opts Options) { func TestOptionsURI(t *testing.T) { opts := Options{ + ClientName: "liveshare-client", SessionID: "sess-id", SessionToken: "sess-token", RelaySAS: "sas", diff --git a/pkg/liveshare/port_forwarder_test.go b/pkg/liveshare/port_forwarder_test.go index 624428dda..c5b61d430 100644 --- a/pkg/liveshare/port_forwarder_test.go +++ b/pkg/liveshare/port_forwarder_test.go @@ -20,7 +20,7 @@ func TestNewPortForwarder(t *testing.T) { t.Errorf("create mock client: %w", err) } defer testServer.Close() - pf := NewPortForwarder(session, "ssh", 80) + pf := NewPortForwarder(session, "ssh", 80, false) if pf == nil { t.Error("port forwarder is nil") } @@ -58,7 +58,7 @@ func TestPortForwarderStart(t *testing.T) { done := make(chan error) go func() { const name, remote = "ssh", 8000 - done <- NewPortForwarder(session, name, remote).ForwardToListener(ctx, listen) + done <- NewPortForwarder(session, name, remote, false).ForwardToListener(ctx, listen) }() go func() { @@ -93,3 +93,26 @@ func TestPortForwarderStart(t *testing.T) { } } } + +func TestPortForwarderTrafficMonitor(t *testing.T) { + buf := bytes.NewBufferString("some-input") + session := &Session{keepAliveReason: make(chan string, 1)} + trafficType := "io" + + tm := newTrafficMonitor(buf, session, trafficType) + l := len(buf.Bytes()) + + bb := make([]byte, l) + n, err := tm.Read(bb) + if err != nil { + t.Errorf("failed to read from traffic monitor: %w", err) + } + if n != l { + t.Errorf("expected to read %d bytes, got %d", l, n) + } + + keepAliveReason := <-session.keepAliveReason + if keepAliveReason != trafficType { + t.Errorf("expected keep alive reason to be %s, got %s", trafficType, keepAliveReason) + } +} diff --git a/pkg/liveshare/session.go b/pkg/liveshare/session.go index 329ea1a2e..4815ae77a 100644 --- a/pkg/liveshare/session.go +++ b/pkg/liveshare/session.go @@ -103,10 +103,10 @@ func (s *Session) StartSSHServer(ctx context.Context) (int, string, error) { return port, response.User, nil } -// heartbeat ticks every minute and sends a signal to the Live Share host to keep +// heartbeat ticks every interval and sends a signal to the Live Share host to keep // the connection alive if there is a reason to do so. -func (s *Session) heartbeat(ctx context.Context) { - ticker := time.NewTicker(1 * time.Minute) +func (s *Session) heartbeat(ctx context.Context, interval time.Duration) { + ticker := time.NewTicker(interval) defer ticker.Stop() for { diff --git a/pkg/liveshare/session_test.go b/pkg/liveshare/session_test.go index 7f0b573b5..fdbcab2b5 100644 --- a/pkg/liveshare/session_test.go +++ b/pkg/liveshare/session_test.go @@ -1,6 +1,7 @@ package liveshare import ( + "bytes" "context" "crypto/tls" "encoding/json" @@ -8,11 +9,14 @@ import ( "fmt" "strings" "testing" + "time" livesharetest "github.com/cli/cli/v2/pkg/liveshare/test" "github.com/sourcegraph/jsonrpc2" ) +const mockClientName = "liveshare-client" + func makeMockSession(opts ...livesharetest.ServerOption) (*livesharetest.Server, *Session, error) { joinWorkspace := func(req *jsonrpc2.Request) (interface{}, error) { return joinWorkspaceResult{1}, nil @@ -29,6 +33,7 @@ func makeMockSession(opts ...livesharetest.ServerOption) (*livesharetest.Server, } session, err := Connect(context.Background(), Options{ + ClientName: mockClientName, SessionID: "session-id", SessionToken: sessionToken, RelayEndpoint: "sb" + strings.TrimPrefix(testServer.URL(), "https"), @@ -221,3 +226,164 @@ func TestInvalidHostKey(t *testing.T) { t.Error("expected invalid host key error, got: nil") } } + +func TestKeepAliveNonBlocking(t *testing.T) { + session := &Session{keepAliveReason: make(chan string, 1)} + var i int + for ; i < 2; i++ { + session.keepAlive("io") + } + + // if keepAlive blocks, we'll never reach this and timeout the test + // timing out + if i != 2 { + t.Errorf("unexpected iteration account, expected: 2, got: %d", i) + } +} + +func TestNotifyHostOfActivity(t *testing.T) { + notifyHostOfActivity := func(rpcReq *jsonrpc2.Request) (interface{}, error) { + var req []interface{} + if err := json.Unmarshal(*rpcReq.Params, &req); err != nil { + return nil, fmt.Errorf("unmarshal req: %w", err) + } + if len(req) < 2 { + return nil, errors.New("request arguments is less than 2") + } + + if clientName, ok := req[0].(string); ok { + if clientName != mockClientName { + return nil, fmt.Errorf( + "unexpected clientName param, expected: %q, got: %q", mockClientName, clientName, + ) + } + } else { + return nil, errors.New("clientName param is not a string") + } + + if acs, ok := req[1].([]interface{}); ok { + if fmt.Sprintf("%s", acs) != "[input]" { + return nil, fmt.Errorf("unexpected activities param, expected: [input], got: %s", acs) + } + } else { + return nil, errors.New("activities param is not a slice") + } + + return nil, nil + } + svc := livesharetest.WithService( + "ICodespaceHostService.notifyCodespaceOfClientActivity", notifyHostOfActivity, + ) + testServer, session, err := makeMockSession(svc) + if err != nil { + t.Errorf("creating mock session: %w", err) + } + defer testServer.Close() + ctx := context.Background() + done := make(chan error) + go func() { + done <- session.notifyHostOfActivity(ctx, "input") + }() + select { + case err := <-testServer.Err(): + t.Errorf("error from server: %w", err) + case err := <-done: + if err != nil { + t.Errorf("error from client: %w", err) + } + } +} + +func TestSessionHeartbeat(t *testing.T) { + var requests int + notifyHostOfActivity := func(rpcReq *jsonrpc2.Request) (interface{}, error) { + requests++ + + var req []interface{} + if err := json.Unmarshal(*rpcReq.Params, &req); err != nil { + return nil, fmt.Errorf("unmarshal req: %w", err) + } + if len(req) < 2 { + return nil, errors.New("request arguments is less than 2") + } + + if clientName, ok := req[0].(string); ok { + if clientName != mockClientName { + return nil, fmt.Errorf( + "unexpected clientName param, expected: %q, got: %q", mockClientName, clientName, + ) + } + } else { + return nil, errors.New("clientName param is not a string") + } + + if acs, ok := req[1].([]interface{}); ok { + if fmt.Sprintf("%s", acs) != "[input]" { + return nil, fmt.Errorf("unexpected activities param, expected: [input], got: %s", acs) + } + } else { + return nil, errors.New("activities param is not a slice") + } + + return nil, nil + } + svc := livesharetest.WithService( + "ICodespaceHostService.notifyCodespaceOfClientActivity", notifyHostOfActivity, + ) + testServer, session, err := makeMockSession(svc) + if err != nil { + t.Errorf("creating mock session: %w", err) + } + defer testServer.Close() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + done := make(chan struct{}) + + logger := newMockLogger() + session.logger = logger + + go session.heartbeat(ctx, 50*time.Millisecond) + go func() { + session.keepAlive("input") + <-time.Tick(100 * time.Millisecond) + session.keepAlive("input") + <-time.Tick(100 * time.Millisecond) + done <- struct{}{} + }() + + select { + case err := <-testServer.Err(): + t.Errorf("error from server: %w", err) + case <-done: + activityCount := strings.Count(logger.String(), "input") + if activityCount != 2 { + t.Errorf("unexpected number of activities, expected: 2, got: %d", activityCount) + } + if requests != 2 { + t.Errorf("unexpected number of requests, expected: 2, got: %d", requests) + } + return + } +} + +type mockLogger struct { + buf *bytes.Buffer +} + +func newMockLogger() *mockLogger { + return &mockLogger{new(bytes.Buffer)} +} + +func (m *mockLogger) Printf(format string, v ...interface{}) (int, error) { + return m.buf.WriteString(fmt.Sprintf(format, v...)) +} + +func (m *mockLogger) Println(v ...interface{}) (int, error) { + return m.buf.WriteString(fmt.Sprintln(v...)) +} + +func (m *mockLogger) String() string { + return m.buf.String() +} From 7ba2fb4c0ed1a376dbb2cf76d2bfd6f6c810cc7a Mon Sep 17 00:00:00 2001 From: Jose Garcia Date: Thu, 7 Oct 2021 15:19:14 -0400 Subject: [PATCH 3/9] Make fileLogger more versatile --- pkg/cmd/codespace/ssh.go | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/pkg/cmd/codespace/ssh.go b/pkg/cmd/codespace/ssh.go index 56f62434d..1f183e9d1 100644 --- a/pkg/cmd/codespace/ssh.go +++ b/pkg/cmd/codespace/ssh.go @@ -62,7 +62,7 @@ func (a *App) SSH(ctx context.Context, sshArgs []string, opts sshOptions) (err e var debugLogger *fileLogger if opts.debug { - debugLogger, err = newFileLogger() + debugLogger, err = newFileLogger("gh-cs-ssh") if err != nil { return fmt.Errorf("error creating debug logger: %w", err) } @@ -135,8 +135,11 @@ type fileLogger struct { f *os.File } -func newFileLogger() (*fileLogger, error) { - f, err := ioutil.TempFile("", "gh-cs-ssh") +// newFileLogger creates a new fileLogger. It returns an error if the file +// cannot be created. The file is created in the operating system tmp directory +// under the name parameter. +func newFileLogger(name string) (*fileLogger, error) { + f, err := ioutil.TempFile("", name) if err != nil { return nil, err } From 2406f3f09a0d618a5d442e5b34b09dbd25adbaad Mon Sep 17 00:00:00 2001 From: Jose Garcia Date: Thu, 7 Oct 2021 15:32:07 -0400 Subject: [PATCH 4/9] Fix races and remove unreachable code --- pkg/liveshare/session.go | 3 +-- pkg/liveshare/session_test.go | 17 +++++++++++++++-- 2 files changed, 16 insertions(+), 4 deletions(-) diff --git a/pkg/liveshare/session.go b/pkg/liveshare/session.go index 4815ae77a..ca01d4cf3 100644 --- a/pkg/liveshare/session.go +++ b/pkg/liveshare/session.go @@ -114,7 +114,7 @@ func (s *Session) heartbeat(ctx context.Context, interval time.Duration) { case <-ctx.Done(): return case <-ticker.C: - s.logger.Println("Running session heartbeat") + s.logger.Println("Heartbeat tick") reason := <-s.keepAliveReason s.logger.Println("Keep alive reason: " + reason) if err := s.notifyHostOfActivity(ctx, reason); err != nil { @@ -122,7 +122,6 @@ func (s *Session) heartbeat(ctx context.Context, interval time.Duration) { } } } - s.logger.Println("Ending session heartbeat") } // notifyHostOfActivity notifies the Live Share host of client activity. diff --git a/pkg/liveshare/session_test.go b/pkg/liveshare/session_test.go index fdbcab2b5..e2ef891f2 100644 --- a/pkg/liveshare/session_test.go +++ b/pkg/liveshare/session_test.go @@ -8,6 +8,7 @@ import ( "errors" "fmt" "strings" + "sync" "testing" "time" @@ -295,9 +296,14 @@ func TestNotifyHostOfActivity(t *testing.T) { } func TestSessionHeartbeat(t *testing.T) { - var requests int + var ( + requestsMu sync.Mutex + requests int + ) notifyHostOfActivity := func(rpcReq *jsonrpc2.Request) (interface{}, error) { + requestsMu.Lock() requests++ + requestsMu.Unlock() var req []interface{} if err := json.Unmarshal(*rpcReq.Params, &req); err != nil { @@ -369,21 +375,28 @@ func TestSessionHeartbeat(t *testing.T) { } type mockLogger struct { + sync.Mutex buf *bytes.Buffer } func newMockLogger() *mockLogger { - return &mockLogger{new(bytes.Buffer)} + return &mockLogger{buf: new(bytes.Buffer)} } func (m *mockLogger) Printf(format string, v ...interface{}) (int, error) { + m.Lock() + defer m.Unlock() return m.buf.WriteString(fmt.Sprintf(format, v...)) } func (m *mockLogger) Println(v ...interface{}) (int, error) { + m.Lock() + defer m.Unlock() return m.buf.WriteString(fmt.Sprintln(v...)) } func (m *mockLogger) String() string { + m.Lock() + defer m.Unlock() return m.buf.String() } From 8a559ee12a0cd4b564b68564a0aef803c69f3d87 Mon Sep 17 00:00:00 2001 From: Jose Garcia Date: Thu, 7 Oct 2021 15:38:16 -0400 Subject: [PATCH 5/9] Fix unrelated tests --- pkg/cmd/codespace/delete_test.go | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/pkg/cmd/codespace/delete_test.go b/pkg/cmd/codespace/delete_test.go index 14896098a..68c7892de 100644 --- a/pkg/cmd/codespace/delete_test.go +++ b/pkg/cmd/codespace/delete_test.go @@ -44,6 +44,7 @@ func TestDelete(t *testing.T) { }, }, wantDeleted: []string{"hubot-robawt-abc"}, + wantStdout: "Codespace deleted.\n", }, { name: "by repo", @@ -65,6 +66,7 @@ func TestDelete(t *testing.T) { }, }, wantDeleted: []string{"monalisa-spoonknife-123", "monalisa-spoonknife-c4f3"}, + wantStdout: "Codespaces deleted.\n", }, { name: "unused", @@ -87,6 +89,7 @@ func TestDelete(t *testing.T) { }, }, wantDeleted: []string{"hubot-robawt-abc", "monalisa-spoonknife-c4f3"}, + wantStdout: "Codespaces deleted.\n", }, { name: "deletion failed", @@ -148,6 +151,7 @@ func TestDelete(t *testing.T) { "Codespace hubot-robawt-abc has unsaved changes. OK to delete?": true, }, wantDeleted: []string{"hubot-robawt-abc", "monalisa-spoonknife-c4f3"}, + wantStdout: "Codespaces deleted.\n", }, } for _, tt := range tests { From 97cbdca84a2e11f1d249cf34d2760ebf5c7a3faa Mon Sep 17 00:00:00 2001 From: Jose Garcia Date: Thu, 7 Oct 2021 15:45:55 -0400 Subject: [PATCH 6/9] Fix additional race in tests --- pkg/liveshare/session_test.go | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/pkg/liveshare/session_test.go b/pkg/liveshare/session_test.go index e2ef891f2..e6ebb3645 100644 --- a/pkg/liveshare/session_test.go +++ b/pkg/liveshare/session_test.go @@ -367,7 +367,11 @@ func TestSessionHeartbeat(t *testing.T) { if activityCount != 2 { t.Errorf("unexpected number of activities, expected: 2, got: %d", activityCount) } - if requests != 2 { + + requestsMu.Lock() + rc := requests + requestsMu.Unlock() + if rc != 2 { t.Errorf("unexpected number of requests, expected: 2, got: %d", requests) } return From 1ff58a3de734f93351111775e64a7fafd99cb683 Mon Sep 17 00:00:00 2001 From: Jose Garcia Date: Thu, 7 Oct 2021 16:39:43 -0400 Subject: [PATCH 7/9] Update docs, remove needless condition check --- pkg/liveshare/port_forwarder.go | 2 +- pkg/liveshare/session.go | 4 ++-- pkg/liveshare/session_test.go | 6 +----- 3 files changed, 4 insertions(+), 8 deletions(-) diff --git a/pkg/liveshare/port_forwarder.go b/pkg/liveshare/port_forwarder.go index 43b6da805..2649abd3c 100644 --- a/pkg/liveshare/port_forwarder.go +++ b/pkg/liveshare/port_forwarder.go @@ -163,7 +163,7 @@ func (fwd *PortForwarder) handleConnection(ctx context.Context, id channelID, co connReader io.Reader = conn ) - // If we the port forwader has been configured to keep the session alive + // If the forwader has been configured to keep the session alive // it will monitor the I/O and notify the session of the traffic. if fwd.keepAlive { channelReader = newTrafficMonitor(channelReader, fwd.session, "output") diff --git a/pkg/liveshare/session.go b/pkg/liveshare/session.go index ca01d4cf3..13558f911 100644 --- a/pkg/liveshare/session.go +++ b/pkg/liveshare/session.go @@ -103,8 +103,8 @@ func (s *Session) StartSSHServer(ctx context.Context) (int, string, error) { return port, response.User, nil } -// heartbeat ticks every interval and sends a signal to the Live Share host to keep -// the connection alive if there is a reason to do so. +// heartbeat runs until context cancellation, periodically checking whether there is a +// reason to keep the connection alive, and if so, notifying the Live Share host to do so. func (s *Session) heartbeat(ctx context.Context, interval time.Duration) { ticker := time.NewTicker(interval) defer ticker.Stop() diff --git a/pkg/liveshare/session_test.go b/pkg/liveshare/session_test.go index e6ebb3645..0e10644fb 100644 --- a/pkg/liveshare/session_test.go +++ b/pkg/liveshare/session_test.go @@ -230,16 +230,12 @@ func TestInvalidHostKey(t *testing.T) { func TestKeepAliveNonBlocking(t *testing.T) { session := &Session{keepAliveReason: make(chan string, 1)} - var i int - for ; i < 2; i++ { + for i := 0; i < 2; i++ { session.keepAlive("io") } // if keepAlive blocks, we'll never reach this and timeout the test // timing out - if i != 2 { - t.Errorf("unexpected iteration account, expected: 2, got: %d", i) - } } func TestNotifyHostOfActivity(t *testing.T) { From 1aefc7437834522f1bd1fbc3c1f7ff5cbf7fa801 Mon Sep 17 00:00:00 2001 From: Jose Garcia Date: Thu, 7 Oct 2021 16:48:09 -0400 Subject: [PATCH 8/9] Add more time between events --- pkg/liveshare/session_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pkg/liveshare/session_test.go b/pkg/liveshare/session_test.go index 0e10644fb..3be528fe8 100644 --- a/pkg/liveshare/session_test.go +++ b/pkg/liveshare/session_test.go @@ -349,7 +349,7 @@ func TestSessionHeartbeat(t *testing.T) { go session.heartbeat(ctx, 50*time.Millisecond) go func() { session.keepAlive("input") - <-time.Tick(100 * time.Millisecond) + <-time.Tick(200 * time.Millisecond) session.keepAlive("input") <-time.Tick(100 * time.Millisecond) done <- struct{}{} From 5170a2931f9cd8378fafeba89cab1db014a2de43 Mon Sep 17 00:00:00 2001 From: Jose Garcia Date: Tue, 12 Oct 2021 15:45:05 -0400 Subject: [PATCH 9/9] Switch to standard lib log.Logger & support dfile - --debug-file flag can now be used in conjuction with --debug to specify the debug file path - Push out logger concerns to callers of liveshare --- internal/codespaces/codespaces.go | 10 +++++++-- internal/codespaces/states.go | 10 ++++++--- pkg/cmd/codespace/common.go | 6 ++++++ pkg/cmd/codespace/logs.go | 2 +- pkg/cmd/codespace/ports.go | 6 +++--- pkg/cmd/codespace/ssh.go | 34 ++++++++++++++++++++----------- pkg/liveshare/client.go | 23 ++++++--------------- pkg/liveshare/client_test.go | 1 + pkg/liveshare/session_test.go | 9 ++++---- 9 files changed, 59 insertions(+), 42 deletions(-) diff --git a/internal/codespaces/codespaces.go b/internal/codespaces/codespaces.go index 3c7be9c01..c9475ad4d 100644 --- a/internal/codespaces/codespaces.go +++ b/internal/codespaces/codespaces.go @@ -12,10 +12,16 @@ import ( type logger interface { Print(v ...interface{}) (int, error) - Printf(f string, v ...interface{}) (int, error) Println(v ...interface{}) (int, error) } +// TODO(josebalius): clean this up once we standardrize +// logging for codespaces +type liveshareLogger interface { + Println(v ...interface{}) + Printf(f string, v ...interface{}) +} + func connectionReady(codespace *api.Codespace) bool { return codespace.Environment.Connection.SessionID != "" && codespace.Environment.Connection.SessionToken != "" && @@ -31,7 +37,7 @@ type apiClient interface { // ConnectToLiveshare waits for a Codespace to become running, // and connects to it using a Live Share session. -func ConnectToLiveshare(ctx context.Context, log, sessionLogger logger, apiClient apiClient, codespace *api.Codespace) (*liveshare.Session, error) { +func ConnectToLiveshare(ctx context.Context, log logger, sessionLogger liveshareLogger, apiClient apiClient, codespace *api.Codespace) (*liveshare.Session, error) { var startedCodespace bool if codespace.Environment.State != api.CodespaceEnvironmentStateAvailable { startedCodespace = true diff --git a/internal/codespaces/states.go b/internal/codespaces/states.go index 10d7dd4ca..00170d3ec 100644 --- a/internal/codespaces/states.go +++ b/internal/codespaces/states.go @@ -5,6 +5,8 @@ import ( "context" "encoding/json" "fmt" + "io/ioutil" + "log" "net" "strings" "time" @@ -36,8 +38,10 @@ type PostCreateState struct { // PollPostCreateStates watches for state changes in a codespace, // and calls the supplied poller for each batch of state changes. // It runs until it encounters an error, including cancellation of the context. -func PollPostCreateStates(ctx context.Context, log logger, apiClient apiClient, codespace *api.Codespace, poller func([]PostCreateState)) (err error) { - session, err := ConnectToLiveshare(ctx, log, nil, apiClient, codespace) +func PollPostCreateStates(ctx context.Context, logger logger, apiClient apiClient, codespace *api.Codespace, poller func([]PostCreateState)) (err error) { + noopLogger := log.New(ioutil.Discard, "", 0) + + session, err := ConnectToLiveshare(ctx, logger, noopLogger, apiClient, codespace) if err != nil { return fmt.Errorf("connect to Live Share: %w", err) } @@ -54,7 +58,7 @@ func PollPostCreateStates(ctx context.Context, log logger, apiClient apiClient, } localPort := listen.Addr().(*net.TCPAddr).Port - log.Println("Fetching SSH Details...") + logger.Println("Fetching SSH Details...") remoteSSHServerPort, sshUser, err := session.StartSSHServer(ctx) if err != nil { return fmt.Errorf("error getting ssh server details: %w", err) diff --git a/pkg/cmd/codespace/common.go b/pkg/cmd/codespace/common.go index d304f8d0d..d49995344 100644 --- a/pkg/cmd/codespace/common.go +++ b/pkg/cmd/codespace/common.go @@ -7,6 +7,8 @@ import ( "errors" "fmt" "io" + "io/ioutil" + "log" "os" "sort" "strings" @@ -211,6 +213,10 @@ func noArgsConstraint(cmd *cobra.Command, args []string) error { return nil } +func noopLogger() *log.Logger { + return log.New(ioutil.Discard, "", 0) +} + type codespace struct { *api.Codespace } diff --git a/pkg/cmd/codespace/logs.go b/pkg/cmd/codespace/logs.go index 3a7c6948f..317059918 100644 --- a/pkg/cmd/codespace/logs.go +++ b/pkg/cmd/codespace/logs.go @@ -51,7 +51,7 @@ func (a *App) Logs(ctx context.Context, codespaceName string, follow bool) (err return fmt.Errorf("get or choose codespace: %w", err) } - session, err := codespaces.ConnectToLiveshare(ctx, a.logger, nil, a.apiClient, codespace) + session, err := codespaces.ConnectToLiveshare(ctx, a.logger, noopLogger(), a.apiClient, codespace) if err != nil { return fmt.Errorf("connecting to Live Share: %w", err) } diff --git a/pkg/cmd/codespace/ports.go b/pkg/cmd/codespace/ports.go index 5b4b7fadd..898ede494 100644 --- a/pkg/cmd/codespace/ports.go +++ b/pkg/cmd/codespace/ports.go @@ -60,7 +60,7 @@ func (a *App) ListPorts(ctx context.Context, codespaceName string, asJSON bool) devContainerCh := getDevContainer(ctx, a.apiClient, codespace) - session, err := codespaces.ConnectToLiveshare(ctx, a.logger, nil, a.apiClient, codespace) + session, err := codespaces.ConnectToLiveshare(ctx, a.logger, noopLogger(), a.apiClient, codespace) if err != nil { return fmt.Errorf("error connecting to Live Share: %w", err) } @@ -194,7 +194,7 @@ func (a *App) UpdatePortVisibility(ctx context.Context, codespaceName, sourcePor return fmt.Errorf("error getting codespace: %w", err) } - session, err := codespaces.ConnectToLiveshare(ctx, a.logger, nil, a.apiClient, codespace) + session, err := codespaces.ConnectToLiveshare(ctx, a.logger, noopLogger(), a.apiClient, codespace) if err != nil { return fmt.Errorf("error connecting to Live Share: %w", err) } @@ -253,7 +253,7 @@ func (a *App) ForwardPorts(ctx context.Context, codespaceName string, ports []st return fmt.Errorf("error getting codespace: %w", err) } - session, err := codespaces.ConnectToLiveshare(ctx, a.logger, nil, a.apiClient, codespace) + session, err := codespaces.ConnectToLiveshare(ctx, a.logger, noopLogger(), a.apiClient, codespace) if err != nil { return fmt.Errorf("error connecting to Live Share: %w", err) } diff --git a/pkg/cmd/codespace/ssh.go b/pkg/cmd/codespace/ssh.go index 1f183e9d1..928bd044e 100644 --- a/pkg/cmd/codespace/ssh.go +++ b/pkg/cmd/codespace/ssh.go @@ -4,11 +4,11 @@ import ( "context" "fmt" "io/ioutil" + "log" "net" "os" "github.com/cli/cli/v2/internal/codespaces" - "github.com/cli/cli/v2/pkg/cmd/codespace/output" "github.com/cli/cli/v2/pkg/liveshare" "github.com/spf13/cobra" ) @@ -18,6 +18,7 @@ type sshOptions struct { profile string serverPort int debug bool + debugFile string } func newSSHCmd(app *App) *cobra.Command { @@ -35,6 +36,7 @@ func newSSHCmd(app *App) *cobra.Command { sshCmd.Flags().IntVarP(&opts.serverPort, "server-port", "", 0, "SSH server port number (0 => pick unused)") sshCmd.Flags().StringVarP(&opts.codespace, "codespace", "c", "", "Name of the codespace") sshCmd.Flags().BoolVarP(&opts.debug, "debug", "d", false, "Log debug data to a file") + sshCmd.Flags().StringVarP(&opts.debugFile, "debug-file", "", "", "Path of the file log to") return sshCmd } @@ -62,7 +64,7 @@ func (a *App) SSH(ctx context.Context, sshArgs []string, opts sshOptions) (err e var debugLogger *fileLogger if opts.debug { - debugLogger, err = newFileLogger("gh-cs-ssh") + debugLogger, err = newFileLogger(opts.debugFile) if err != nil { return fmt.Errorf("error creating debug logger: %w", err) } @@ -125,26 +127,34 @@ func (a *App) SSH(ctx context.Context, sshArgs []string, opts sshOptions) (err e } } -// fileLogger is a wrapper around an output.Logger configured to write +// fileLogger is a wrapper around an log.Logger configured to write // to a file. It exports two additional methods to get the log file name // and close the file handle when the operation is finished. type fileLogger struct { - // TODO(josebalius): should we use https://pkg.go.dev/log#New instead? - *output.Logger + *log.Logger f *os.File } // newFileLogger creates a new fileLogger. It returns an error if the file -// cannot be created. The file is created in the operating system tmp directory -// under the name parameter. -func newFileLogger(name string) (*fileLogger, error) { - f, err := ioutil.TempFile("", name) - if err != nil { - return nil, err +// cannot be created. The file is created on the specified path, if the path +// is empty it is created in the temporary directory. +func newFileLogger(file string) (fl *fileLogger, err error) { + var f *os.File + if file == "" { + f, err = ioutil.TempFile("", "") + if err != nil { + return nil, fmt.Errorf("failed to create tmp file: %w", err) + } + } else { + f, err = os.Create(file) + if err != nil { + return nil, err + } } + return &fileLogger{ - Logger: output.NewLogger(f, f, false), + Logger: log.New(f, "", log.LstdFlags), f: f, }, nil } diff --git a/pkg/liveshare/client.go b/pkg/liveshare/client.go index c3e92004d..840e99db9 100644 --- a/pkg/liveshare/client.go +++ b/pkg/liveshare/client.go @@ -24,18 +24,8 @@ import ( ) type logger interface { - Println(v ...interface{}) (int, error) - Printf(f string, v ...interface{}) (int, error) -} - -type noopLogger struct{} - -func (n noopLogger) Println(...interface{}) (int, error) { - return 0, nil -} - -func (n noopLogger) Printf(string, ...interface{}) (int, error) { - return 0, nil + Println(v ...interface{}) + Printf(f string, v ...interface{}) } // An Options specifies Live Share connection parameters. @@ -46,8 +36,8 @@ type Options struct { RelaySAS string RelayEndpoint string HostPublicKeys []string + Logger logger // required TLSConfig *tls.Config // (optional) - Logger logger // (optional) } // uri returns a websocket URL for the specified options. @@ -85,9 +75,8 @@ func Connect(ctx context.Context, opts Options) (*Session, error) { return nil, err } - var sessionLogger logger = noopLogger{} - if opts.Logger != nil { - sessionLogger = opts.Logger + if opts.Logger == nil { + return nil, errors.New("Logger is required") } sock := newSocket(uri, opts.TLSConfig) @@ -124,7 +113,7 @@ func Connect(ctx context.Context, opts Options) (*Session, error) { rpc: rpc, clientName: opts.ClientName, keepAliveReason: make(chan string, 1), - logger: sessionLogger, + logger: opts.Logger, } go s.heartbeat(ctx, 1*time.Minute) diff --git a/pkg/liveshare/client_test.go b/pkg/liveshare/client_test.go index a775ba4af..c6502d684 100644 --- a/pkg/liveshare/client_test.go +++ b/pkg/liveshare/client_test.go @@ -20,6 +20,7 @@ func TestConnect(t *testing.T) { SessionToken: "session-token", RelaySAS: "relay-sas", HostPublicKeys: []string{livesharetest.SSHPublicKey}, + Logger: newMockLogger(), } joinWorkspace := func(req *jsonrpc2.Request) (interface{}, error) { var joinWorkspaceReq joinWorkspaceArgs diff --git a/pkg/liveshare/session_test.go b/pkg/liveshare/session_test.go index 3be528fe8..998de6ac0 100644 --- a/pkg/liveshare/session_test.go +++ b/pkg/liveshare/session_test.go @@ -41,6 +41,7 @@ func makeMockSession(opts ...livesharetest.ServerOption) (*livesharetest.Server, RelaySAS: "relay-sas", HostPublicKeys: []string{livesharetest.SSHPublicKey}, TLSConfig: &tls.Config{InsecureSkipVerify: true}, + Logger: newMockLogger(), }) if err != nil { return nil, nil, fmt.Errorf("error connecting to Live Share: %w", err) @@ -383,16 +384,16 @@ func newMockLogger() *mockLogger { return &mockLogger{buf: new(bytes.Buffer)} } -func (m *mockLogger) Printf(format string, v ...interface{}) (int, error) { +func (m *mockLogger) Printf(format string, v ...interface{}) { m.Lock() defer m.Unlock() - return m.buf.WriteString(fmt.Sprintf(format, v...)) + m.buf.WriteString(fmt.Sprintf(format, v...)) } -func (m *mockLogger) Println(v ...interface{}) (int, error) { +func (m *mockLogger) Println(v ...interface{}) { m.Lock() defer m.Unlock() - return m.buf.WriteString(fmt.Sprintln(v...)) + m.buf.WriteString(fmt.Sprintln(v...)) } func (m *mockLogger) String() string {