diff --git a/client.go b/client.go index 377ec2512..0088662f7 100644 --- a/client.go +++ b/client.go @@ -86,6 +86,12 @@ type joinWorkspaceResult struct { SessionNumber int `json:"sessionNumber"` } +// A channelID is an identifier for an exposed port on a remote +// container that may be used to open an SSH channel to it. +type channelID struct { + name, condition string +} + func (c *Client) joinWorkspace(ctx context.Context, rpc *rpcClient) (*joinWorkspaceResult, error) { args := joinWorkspaceArgs{ ID: c.connection.SessionID, @@ -104,8 +110,11 @@ func (c *Client) joinWorkspace(ctx context.Context, rpc *rpcClient) (*joinWorksp return &result, nil } -func (s *Session) openStreamingChannel(ctx context.Context, streamName, condition string) (ssh.Channel, error) { - args := getStreamArgs{streamName, condition} +func (s *Session) openStreamingChannel(ctx context.Context, id channelID) (ssh.Channel, error) { + args := getStreamArgs{ + StreamName: id.name, + Condition: id.condition, + } var streamID string if err := s.rpc.do(ctx, "streamManager.getStream", args, &streamID); err != nil { return nil, fmt.Errorf("error getting stream id: %v", err) diff --git a/port_forwarder.go b/port_forwarder.go index 29dee58f9..f4895bb60 100644 --- a/port_forwarder.go +++ b/port_forwarder.go @@ -7,25 +7,36 @@ import ( "net" ) -// A PortForwarder forwards TCP traffic between a local TCP port and a LiveShare session. +// A PortForwarder forwards TCP traffic over a LiveShare session from a port on a remote +// container to a local destination such as a network port or Go reader/writer. type PortForwarder struct { - session *Session - port int + session *Session + name string + remotePort int } -// NewPortForwarder creates a new PortForwarder for a given Live Share session and local TCP port. -func NewPortForwarder(session *Session, port int) *PortForwarder { +// NewPortForwarder returns a new PortForwarder for the specified +// remote port and Live Share session. The name describes the purpose +// of the remote port or service. +func NewPortForwarder(session *Session, name string, remotePort int) *PortForwarder { return &PortForwarder{ - session: session, - port: port, + session: session, + name: name, + remotePort: remotePort, } } -// Forward enables port forwarding. It accepts and handles TCP -// connections until it encounters the first error, which may include -// context cancellation. Its result is non-nil. -func (l *PortForwarder) Forward(ctx context.Context) (err error) { - listen, err := net.Listen("tcp", fmt.Sprintf(":%d", l.port)) +// ForwardToLocalPort forwards traffic between the container's remote +// port and a local TCP port. It accepts and handles connections on +// the local port until it encounters the first error, which may +// include context cancellation. Its error result is always non-nil. +func (fwd *PortForwarder) ForwardToLocalPort(ctx context.Context, localPort int) (err error) { + id, err := fwd.shareRemotePort(ctx) + if err != nil { + return err + } + + listen, err := net.Listen("tcp", fmt.Sprintf(":%d", localPort)) if err != nil { return fmt.Errorf("error listening on TCP port: %v", err) } @@ -49,7 +60,7 @@ func (l *PortForwarder) Forward(ctx context.Context) (err error) { } go func() { - if err := l.handleConnection(ctx, conn); err != nil { + if err := fwd.handleConnection(ctx, id, conn); err != nil { sendError(err) } }() @@ -59,18 +70,33 @@ func (l *PortForwarder) Forward(ctx context.Context) (err error) { return awaitError(ctx, errc) } -// ForwardWithConn handles port forwarding for a single connection. -func (l *PortForwarder) ForwardWithConn(ctx context.Context, conn io.ReadWriteCloser) error { +// Forward forwards traffic between the container's remote port and +// the specified read/write stream. On return, the stream is closed. +func (fwd *PortForwarder) Forward(ctx context.Context, conn io.ReadWriteCloser) error { + id, err := fwd.shareRemotePort(ctx) + if err != nil { + conn.Close() + return err + } + // Create buffered channel so that send doesn't get stuck after context cancellation. errc := make(chan error, 1) go func() { - if err := l.handleConnection(ctx, conn); err != nil { + if err := fwd.handleConnection(ctx, id, conn); err != nil { errc <- err } }() return awaitError(ctx, errc) } +func (fwd *PortForwarder) shareRemotePort(ctx context.Context) (channelID, error) { + id, err := fwd.session.startSharing(ctx, fwd.name, fwd.remotePort) + if err != nil { + err = fmt.Errorf("failed to share remote port %d: %v", fwd.remotePort, err) + } + return id, nil +} + func awaitError(ctx context.Context, errc <-chan error) error { select { case err := <-errc: @@ -81,10 +107,10 @@ func awaitError(ctx context.Context, errc <-chan error) error { } // handleConnection handles forwarding for a single accepted connection, then closes it. -func (l *PortForwarder) handleConnection(ctx context.Context, conn io.ReadWriteCloser) (err error) { +func (fwd *PortForwarder) handleConnection(ctx context.Context, id channelID, conn io.ReadWriteCloser) (err error) { defer safeClose(conn, &err) - channel, err := l.session.openStreamingChannel(ctx, l.session.streamName, l.session.streamCondition) + channel, err := fwd.session.openStreamingChannel(ctx, id) if err != nil { return fmt.Errorf("error opening streaming channel for new connection: %v", err) } diff --git a/port_forwarder_test.go b/port_forwarder_test.go index 44ef59fe0..6ccb3d05e 100644 --- a/port_forwarder_test.go +++ b/port_forwarder_test.go @@ -20,7 +20,7 @@ func TestNewPortForwarder(t *testing.T) { t.Errorf("create mock client: %v", err) } defer testServer.Close() - pf := NewPortForwarder(session, 80) + pf := NewPortForwarder(session, "ssh", 80) if pf == nil { t.Error("port forwarder is nil") } @@ -48,14 +48,11 @@ func TestPortForwarderStart(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - pf := NewPortForwarder(session, 8000) - done := make(chan error) + done := make(chan error) go func() { - if err := session.StartSharing(ctx, "http", 8000); err != nil { - done <- fmt.Errorf("start sharing: %v", err) - } - done <- pf.Forward(ctx) + const name, local, remote = "ssh", 8000, 8000 + done <- NewPortForwarder(session, name, remote).ForwardToLocalPort(ctx, local) }() go func() { diff --git a/session.go b/session.go index d57906f26..0e3120cd7 100644 --- a/session.go +++ b/session.go @@ -9,11 +9,6 @@ import ( type Session struct { ssh *sshSession rpc *rpcClient - - // TODO(adonovan): fix: avoid data race of state accessed by - // multiple calls to StartSharing and concurrent calls to - // PortForwarder. Perhaps combine the two operations in the API? - streamName, streamCondition string } // Port describes a port exposed by the container. @@ -31,20 +26,17 @@ type Port struct { // TODO(adonovan): fix possible typo in field name, and audit others. } -// StartSharing tells the Live Share host to start sharing the specified port from the container. -// The sessionName describes the purpose of the port or service. -func (s *Session) StartSharing(ctx context.Context, sessionName string, port int) error { +// startSharing tells the Live Share host to start sharing the specified port from the container. +// The sessionName describes the purpose of the remote port or service. +// It returns an identifier that can be used to open an SSH channel to the remote port. +func (s *Session) startSharing(ctx context.Context, sessionName string, port int) (channelID, error) { + args := []interface{}{port, sessionName, fmt.Sprintf("http://localhost:%d", port)} var response Port - if err := s.rpc.do(ctx, "serverSharing.startSharing", []interface{}{ - port, sessionName, fmt.Sprintf("http://localhost:%d", port), - }, &response); err != nil { - return err + if err := s.rpc.do(ctx, "serverSharing.startSharing", args, &response); err != nil { + return channelID{}, err } - s.streamName = response.StreamName - s.streamCondition = response.StreamCondition - - return nil + return channelID{response.StreamName, response.StreamCondition}, nil } // GetSharedServers returns a description of each container port diff --git a/session_test.go b/session_test.go index 005eacfbd..54aab16c8 100644 --- a/session_test.go +++ b/session_test.go @@ -82,10 +82,11 @@ func TestServerStartSharing(t *testing.T) { done := make(chan error) go func() { - if err := session.StartSharing(ctx, serverProtocol, serverPort); err != nil { + streamID, err := session.startSharing(ctx, serverProtocol, serverPort) + if err != nil { done <- fmt.Errorf("error sharing server: %v", err) } - if session.streamName == "" || session.streamCondition == "" { + if streamID.name == "" || streamID.condition == "" { done <- errors.New("stream name or condition is blank") } done <- nil diff --git a/terminal.go b/terminal.go index 96938ed89..24a0f5121 100644 --- a/terminal.go +++ b/terminal.go @@ -75,7 +75,7 @@ func (t TerminalCommand) Run(ctx context.Context) (io.ReadCloser, error) { } <-started - channel, err := t.terminal.session.openStreamingChannel(ctx, result.StreamName, result.StreamCondition) + channel, err := t.terminal.session.openStreamingChannel(ctx, channelID{result.StreamName, result.StreamCondition}) if err != nil { return nil, fmt.Errorf("error opening streaming channel: %v", err) }