diff --git a/client.go b/client.go index 65e80a94a..ba9d2f5e7 100644 --- a/client.go +++ b/client.go @@ -58,18 +58,18 @@ func (c *Client) JoinWorkspace(ctx context.Context) (*Session, error) { clientSocket := newSocket(c.connection, c.tlsConfig) if err := clientSocket.connect(ctx); err != nil { - return nil, fmt.Errorf("error connecting websocket: %v", err) + return nil, fmt.Errorf("error connecting websocket: %w", err) } ssh := newSSHSession(c.connection.SessionToken, clientSocket) if err := ssh.connect(ctx); err != nil { - return nil, fmt.Errorf("error connecting to ssh session: %v", err) + return nil, fmt.Errorf("error connecting to ssh session: %w", err) } rpc := newRPCClient(ssh) rpc.connect(ctx) if _, err := c.joinWorkspace(ctx, rpc); err != nil { - return nil, fmt.Errorf("error joining Live Share workspace: %v", err) + return nil, fmt.Errorf("error joining Live Share workspace: %w", err) } return &Session{ssh: ssh, rpc: rpc}, nil @@ -108,7 +108,7 @@ func (c *Client) joinWorkspace(ctx context.Context, rpc *rpcClient) (*joinWorksp var result joinWorkspaceResult if err := rpc.do(ctx, "workspace.joinWorkspace", &args, &result); err != nil { - return nil, fmt.Errorf("error making workspace.joinWorkspace call: %v", err) + return nil, fmt.Errorf("error making workspace.joinWorkspace call: %w", err) } return &result, nil @@ -125,7 +125,7 @@ func (s *Session) openStreamingChannel(ctx context.Context, id channelID) (ssh.C } var streamID string if err := s.rpc.do(ctx, "streamManager.getStream", args, &streamID); err != nil { - return nil, fmt.Errorf("error getting stream id: %v", err) + return nil, fmt.Errorf("error getting stream id: %w", err) } span, ctx := opentracing.StartSpanFromContext(ctx, "Session.OpenChannel+SendRequest") @@ -133,13 +133,13 @@ func (s *Session) openStreamingChannel(ctx context.Context, id channelID) (ssh.C channel, reqs, err := s.ssh.conn.OpenChannel("session", nil) if err != nil { - return nil, fmt.Errorf("error opening ssh channel for transport: %v", err) + return nil, fmt.Errorf("error opening ssh channel for transport: %w", err) } go ssh.DiscardRequests(reqs) requestType := fmt.Sprintf("stream-transport-%s", streamID) if _, err = channel.SendRequest(requestType, true, nil); err != nil { - return nil, fmt.Errorf("error sending channel request: %v", err) + return nil, fmt.Errorf("error sending channel request: %w", err) } return channel, nil diff --git a/port_forwarder.go b/port_forwarder.go index 5dafd0c65..56401cc4d 100644 --- a/port_forwarder.go +++ b/port_forwarder.go @@ -92,7 +92,7 @@ func (fwd *PortForwarder) Forward(ctx context.Context, conn io.ReadWriteCloser) func (fwd *PortForwarder) shareRemotePort(ctx context.Context) (channelID, error) { id, err := fwd.session.startSharing(ctx, fwd.name, fwd.remotePort) if err != nil { - err = fmt.Errorf("failed to share remote port %d: %v", fwd.remotePort, err) + err = fmt.Errorf("failed to share remote port %d: %w", fwd.remotePort, err) } return id, nil } @@ -115,7 +115,7 @@ func (fwd *PortForwarder) handleConnection(ctx context.Context, id channelID, co channel, err := fwd.session.openStreamingChannel(ctx, id) if err != nil { - return fmt.Errorf("error opening streaming channel for new connection: %v", err) + return fmt.Errorf("error opening streaming channel for new connection: %w", err) } // Ideally we would call safeClose again, but (*ssh.channel).Close // appears to have a bug that causes it return io.EOF spuriously diff --git a/rpc.go b/rpc.go index 68e187ad6..bfd214c89 100644 --- a/rpc.go +++ b/rpc.go @@ -20,7 +20,6 @@ func newRPCClient(conn io.ReadWriteCloser) *rpcClient { func (r *rpcClient) connect(ctx context.Context) { stream := jsonrpc2.NewBufferedStream(r.conn, jsonrpc2.VSCodeObjectCodec{}) - // TODO(adonovan): fix: ensure r.Conn is eventually Closed! r.Conn = jsonrpc2.NewConn(ctx, stream, nullHandler{}) } @@ -30,7 +29,7 @@ func (r *rpcClient) do(ctx context.Context, method string, args, result interfac waiter, err := r.Conn.DispatchCall(ctx, method, args) if err != nil { - return fmt.Errorf("error dispatching %q call: %v", method, err) + return fmt.Errorf("error dispatching %q call: %w", method, err) } return waiter.Wait(ctx, result) diff --git a/session.go b/session.go index f427fac6d..6a078da7e 100644 --- a/session.go +++ b/session.go @@ -12,6 +12,20 @@ type Session struct { rpc *rpcClient } +// Close should be called by users to clean up RPC and SSH resources whenever the session +// is no longer active. +func (s *Session) Close() error { + if err := s.rpc.Close(); err != nil { + return fmt.Errorf("failed to close RPC conn: %w", err) + } + + if err := s.ssh.Close(); err != nil { + return fmt.Errorf("failed to close SSH conn: %w", err) + } + + return nil +} + // Port describes a port exposed by the container. type Port struct { SourcePort int `json:"sourcePort"` @@ -22,9 +36,7 @@ type Port struct { BrowseURL string `json:"browseUrl"` IsPublic bool `json:"isPublic"` IsTCPServerConnectionEstablished bool `json:"isTCPServerConnectionEstablished"` - HasTSLHandshakePassed bool `json:"hasTSLHandshakePassed"` - // ^^^ - // TODO(adonovan): fix possible typo in field name, and audit others. + HasTLSHandshakePassed bool `json:"hasTLSHandshakePassed"` } // startSharing tells the Live Share host to start sharing the specified port from the container. diff --git a/ssh.go b/ssh.go index b68d400a1..15f67d2a4 100644 --- a/ssh.go +++ b/ssh.go @@ -36,24 +36,24 @@ func (s *sshSession) connect(ctx context.Context) error { sshClientConn, chans, reqs, err := ssh.NewClientConn(s.socket, "", &clientConfig) if err != nil { - return fmt.Errorf("error creating ssh client connection: %v", err) + return fmt.Errorf("error creating ssh client connection: %w", err) } s.conn = sshClientConn sshClient := ssh.NewClient(sshClientConn, chans, reqs) s.Session, err = sshClient.NewSession() if err != nil { - return fmt.Errorf("error creating ssh client session: %v", err) + return fmt.Errorf("error creating ssh client session: %w", err) } s.reader, err = s.Session.StdoutPipe() if err != nil { - return fmt.Errorf("error creating ssh session reader: %v", err) + return fmt.Errorf("error creating ssh session reader: %w", err) } s.writer, err = s.Session.StdinPipe() if err != nil { - return fmt.Errorf("error creating ssh session writer: %v", err) + return fmt.Errorf("error creating ssh session writer: %w", err) } return nil