From 87b15aa264e583688aa9b448ea57663b87a2b4cf Mon Sep 17 00:00:00 2001 From: Alan Donovan Date: Thu, 2 Sep 2021 14:03:48 -0400 Subject: [PATCH 1/4] Fix data race in StartSharing --- client.go | 13 +++++++++-- port_forwarder.go | 50 +++++++++++++++++++++++++++++++----------- port_forwarder_test.go | 11 ++++------ session.go | 24 +++++++------------- session_test.go | 5 +++-- terminal.go | 2 +- 6 files changed, 64 insertions(+), 41 deletions(-) diff --git a/client.go b/client.go index 377ec2512..0088662f7 100644 --- a/client.go +++ b/client.go @@ -86,6 +86,12 @@ type joinWorkspaceResult struct { SessionNumber int `json:"sessionNumber"` } +// A channelID is an identifier for an exposed port on a remote +// container that may be used to open an SSH channel to it. +type channelID struct { + name, condition string +} + func (c *Client) joinWorkspace(ctx context.Context, rpc *rpcClient) (*joinWorkspaceResult, error) { args := joinWorkspaceArgs{ ID: c.connection.SessionID, @@ -104,8 +110,11 @@ func (c *Client) joinWorkspace(ctx context.Context, rpc *rpcClient) (*joinWorksp return &result, nil } -func (s *Session) openStreamingChannel(ctx context.Context, streamName, condition string) (ssh.Channel, error) { - args := getStreamArgs{streamName, condition} +func (s *Session) openStreamingChannel(ctx context.Context, id channelID) (ssh.Channel, error) { + args := getStreamArgs{ + StreamName: id.name, + Condition: id.condition, + } var streamID string if err := s.rpc.do(ctx, "streamManager.getStream", args, &streamID); err != nil { return nil, fmt.Errorf("error getting stream id: %v", err) diff --git a/port_forwarder.go b/port_forwarder.go index 29dee58f9..4391ef55c 100644 --- a/port_forwarder.go +++ b/port_forwarder.go @@ -9,23 +9,34 @@ import ( // A PortForwarder forwards TCP traffic between a local TCP port and a LiveShare session. type PortForwarder struct { - session *Session - port int + session *Session + name string + localPort, remotePort int } -// NewPortForwarder creates a new PortForwarder for a given Live Share session and local TCP port. -func NewPortForwarder(session *Session, port int) *PortForwarder { +// NewPortForwarder creates a new PortForwarder that forwards traffic +// between the local port and the container's remote port over the +// specified Live Share session. The name describes the purpose of the +// remote port or service. +func NewPortForwarder(session *Session, name string, localPort, remotePort int) *PortForwarder { return &PortForwarder{ - session: session, - port: port, + session: session, + name: name, + localPort: localPort, + remotePort: remotePort, } } // Forward enables port forwarding. It accepts and handles TCP // connections until it encounters the first error, which may include // context cancellation. Its result is non-nil. -func (l *PortForwarder) Forward(ctx context.Context) (err error) { - listen, err := net.Listen("tcp", fmt.Sprintf(":%d", l.port)) +func (fwd *PortForwarder) Forward(ctx context.Context) (err error) { + id, err := fwd.shareRemotePort(ctx) + if err != nil { + return err + } + + listen, err := net.Listen("tcp", fmt.Sprintf(":%d", fwd.localPort)) if err != nil { return fmt.Errorf("error listening on TCP port: %v", err) } @@ -49,7 +60,7 @@ func (l *PortForwarder) Forward(ctx context.Context) (err error) { } go func() { - if err := l.handleConnection(ctx, conn); err != nil { + if err := fwd.handleConnection(ctx, id, conn); err != nil { sendError(err) } }() @@ -60,17 +71,30 @@ func (l *PortForwarder) Forward(ctx context.Context) (err error) { } // ForwardWithConn handles port forwarding for a single connection. -func (l *PortForwarder) ForwardWithConn(ctx context.Context, conn io.ReadWriteCloser) error { +func (fwd *PortForwarder) ForwardWithConn(ctx context.Context, conn io.ReadWriteCloser) error { + id, err := fwd.shareRemotePort(ctx) + if err != nil { + return err + } + // Create buffered channel so that send doesn't get stuck after context cancellation. errc := make(chan error, 1) go func() { - if err := l.handleConnection(ctx, conn); err != nil { + if err := fwd.handleConnection(ctx, id, conn); err != nil { errc <- err } }() return awaitError(ctx, errc) } +func (fwd *PortForwarder) shareRemotePort(ctx context.Context) (channelID, error) { + id, err := fwd.session.startSharing(ctx, fwd.name, fwd.remotePort) + if err != nil { + err = fmt.Errorf("failed to share remote port %d: %v", fwd.remotePort, err) + } + return id, nil +} + func awaitError(ctx context.Context, errc <-chan error) error { select { case err := <-errc: @@ -81,10 +105,10 @@ func awaitError(ctx context.Context, errc <-chan error) error { } // handleConnection handles forwarding for a single accepted connection, then closes it. -func (l *PortForwarder) handleConnection(ctx context.Context, conn io.ReadWriteCloser) (err error) { +func (fwd *PortForwarder) handleConnection(ctx context.Context, id channelID, conn io.ReadWriteCloser) (err error) { defer safeClose(conn, &err) - channel, err := l.session.openStreamingChannel(ctx, l.session.streamName, l.session.streamCondition) + channel, err := fwd.session.openStreamingChannel(ctx, id) if err != nil { return fmt.Errorf("error opening streaming channel for new connection: %v", err) } diff --git a/port_forwarder_test.go b/port_forwarder_test.go index 44ef59fe0..d47730995 100644 --- a/port_forwarder_test.go +++ b/port_forwarder_test.go @@ -20,7 +20,7 @@ func TestNewPortForwarder(t *testing.T) { t.Errorf("create mock client: %v", err) } defer testServer.Close() - pf := NewPortForwarder(session, 80) + pf := NewPortForwarder(session, "ssh", 81, 80) if pf == nil { t.Error("port forwarder is nil") } @@ -48,14 +48,11 @@ func TestPortForwarderStart(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - pf := NewPortForwarder(session, 8000) - done := make(chan error) + done := make(chan error) go func() { - if err := session.StartSharing(ctx, "http", 8000); err != nil { - done <- fmt.Errorf("start sharing: %v", err) - } - done <- pf.Forward(ctx) + const name, local, remote = "ssh", 8000, 8000 + done <- NewPortForwarder(session, name, local, remote).Forward(ctx) }() go func() { diff --git a/session.go b/session.go index d57906f26..0e3120cd7 100644 --- a/session.go +++ b/session.go @@ -9,11 +9,6 @@ import ( type Session struct { ssh *sshSession rpc *rpcClient - - // TODO(adonovan): fix: avoid data race of state accessed by - // multiple calls to StartSharing and concurrent calls to - // PortForwarder. Perhaps combine the two operations in the API? - streamName, streamCondition string } // Port describes a port exposed by the container. @@ -31,20 +26,17 @@ type Port struct { // TODO(adonovan): fix possible typo in field name, and audit others. } -// StartSharing tells the Live Share host to start sharing the specified port from the container. -// The sessionName describes the purpose of the port or service. -func (s *Session) StartSharing(ctx context.Context, sessionName string, port int) error { +// startSharing tells the Live Share host to start sharing the specified port from the container. +// The sessionName describes the purpose of the remote port or service. +// It returns an identifier that can be used to open an SSH channel to the remote port. +func (s *Session) startSharing(ctx context.Context, sessionName string, port int) (channelID, error) { + args := []interface{}{port, sessionName, fmt.Sprintf("http://localhost:%d", port)} var response Port - if err := s.rpc.do(ctx, "serverSharing.startSharing", []interface{}{ - port, sessionName, fmt.Sprintf("http://localhost:%d", port), - }, &response); err != nil { - return err + if err := s.rpc.do(ctx, "serverSharing.startSharing", args, &response); err != nil { + return channelID{}, err } - s.streamName = response.StreamName - s.streamCondition = response.StreamCondition - - return nil + return channelID{response.StreamName, response.StreamCondition}, nil } // GetSharedServers returns a description of each container port diff --git a/session_test.go b/session_test.go index 005eacfbd..54aab16c8 100644 --- a/session_test.go +++ b/session_test.go @@ -82,10 +82,11 @@ func TestServerStartSharing(t *testing.T) { done := make(chan error) go func() { - if err := session.StartSharing(ctx, serverProtocol, serverPort); err != nil { + streamID, err := session.startSharing(ctx, serverProtocol, serverPort) + if err != nil { done <- fmt.Errorf("error sharing server: %v", err) } - if session.streamName == "" || session.streamCondition == "" { + if streamID.name == "" || streamID.condition == "" { done <- errors.New("stream name or condition is blank") } done <- nil diff --git a/terminal.go b/terminal.go index 96938ed89..24a0f5121 100644 --- a/terminal.go +++ b/terminal.go @@ -75,7 +75,7 @@ func (t TerminalCommand) Run(ctx context.Context) (io.ReadCloser, error) { } <-started - channel, err := t.terminal.session.openStreamingChannel(ctx, result.StreamName, result.StreamCondition) + channel, err := t.terminal.session.openStreamingChannel(ctx, channelID{result.StreamName, result.StreamCondition}) if err != nil { return nil, fmt.Errorf("error opening streaming channel: %v", err) } From 94b91661cc68b200e30e809d8c26b41a7f37c1af Mon Sep 17 00:00:00 2001 From: Alan Donovan Date: Thu, 2 Sep 2021 14:30:19 -0400 Subject: [PATCH 2/4] don't forget to close conn in case of sharing error --- port_forwarder.go | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/port_forwarder.go b/port_forwarder.go index 4391ef55c..2d1217c24 100644 --- a/port_forwarder.go +++ b/port_forwarder.go @@ -18,6 +18,12 @@ type PortForwarder struct { // between the local port and the container's remote port over the // specified Live Share session. The name describes the purpose of the // remote port or service. +// +// TODO(adonovan): the localPort param is redundant wrt ForwardWithConn. +// Simpler: do away with the NewPortForwarder type altogether: +// +// - ForwardToLocalPort(ctx, session, name, remote, local) +// - ForwardToConnection(ctx, session, name, remote, conn) func NewPortForwarder(session *Session, name string, localPort, remotePort int) *PortForwarder { return &PortForwarder{ session: session, @@ -74,6 +80,7 @@ func (fwd *PortForwarder) Forward(ctx context.Context) (err error) { func (fwd *PortForwarder) ForwardWithConn(ctx context.Context, conn io.ReadWriteCloser) error { id, err := fwd.shareRemotePort(ctx) if err != nil { + conn.Close() return err } From 94319d4cfeaa6b6a0389e75c0401e265e2078e09 Mon Sep 17 00:00:00 2001 From: Alan Donovan Date: Thu, 2 Sep 2021 15:34:57 -0400 Subject: [PATCH 3/4] move localPort parameter to ForwardToLocalPort --- port_forwarder.go | 39 +++++++++++++++++---------------------- port_forwarder_test.go | 4 ++-- 2 files changed, 19 insertions(+), 24 deletions(-) diff --git a/port_forwarder.go b/port_forwarder.go index 2d1217c24..4a46cd4e6 100644 --- a/port_forwarder.go +++ b/port_forwarder.go @@ -7,42 +7,36 @@ import ( "net" ) -// A PortForwarder forwards TCP traffic between a local TCP port and a LiveShare session. +// A PortForwarder forwards TCP traffic over a LiveShare session from a port on a remote +// container to a local destination such as a network port or Go reader/writer. type PortForwarder struct { - session *Session - name string - localPort, remotePort int + session *Session + name string + remotePort int } -// NewPortForwarder creates a new PortForwarder that forwards traffic -// between the local port and the container's remote port over the -// specified Live Share session. The name describes the purpose of the -// remote port or service. -// -// TODO(adonovan): the localPort param is redundant wrt ForwardWithConn. -// Simpler: do away with the NewPortForwarder type altogether: -// -// - ForwardToLocalPort(ctx, session, name, remote, local) -// - ForwardToConnection(ctx, session, name, remote, conn) -func NewPortForwarder(session *Session, name string, localPort, remotePort int) *PortForwarder { +// NewPortForwarder returns a new PortForwarder for the specified +// remote port and Live Share session. The name describes the purpose +// of the remote port or service. +func NewPortForwarder(session *Session, name string, remotePort int) *PortForwarder { return &PortForwarder{ session: session, name: name, - localPort: localPort, remotePort: remotePort, } } -// Forward enables port forwarding. It accepts and handles TCP -// connections until it encounters the first error, which may include +// ForwardToLocalPort forwards traffic between the container's remote +// port and a local TCP port. It accepts and handles TCP connections +// on the local until it encounters the first error, which may include // context cancellation. Its result is non-nil. -func (fwd *PortForwarder) Forward(ctx context.Context) (err error) { +func (fwd *PortForwarder) ForwardToLocalPort(ctx context.Context, localPort int) (err error) { id, err := fwd.shareRemotePort(ctx) if err != nil { return err } - listen, err := net.Listen("tcp", fmt.Sprintf(":%d", fwd.localPort)) + listen, err := net.Listen("tcp", fmt.Sprintf(":%d", localPort)) if err != nil { return fmt.Errorf("error listening on TCP port: %v", err) } @@ -76,8 +70,9 @@ func (fwd *PortForwarder) Forward(ctx context.Context) (err error) { return awaitError(ctx, errc) } -// ForwardWithConn handles port forwarding for a single connection. -func (fwd *PortForwarder) ForwardWithConn(ctx context.Context, conn io.ReadWriteCloser) error { +// Forward forwards traffic between the container's remote port and +// the specified read/write stream. On return, the stream is closed. +func (fwd *PortForwarder) Forward(ctx context.Context, conn io.ReadWriteCloser) error { id, err := fwd.shareRemotePort(ctx) if err != nil { conn.Close() diff --git a/port_forwarder_test.go b/port_forwarder_test.go index d47730995..6ccb3d05e 100644 --- a/port_forwarder_test.go +++ b/port_forwarder_test.go @@ -20,7 +20,7 @@ func TestNewPortForwarder(t *testing.T) { t.Errorf("create mock client: %v", err) } defer testServer.Close() - pf := NewPortForwarder(session, "ssh", 81, 80) + pf := NewPortForwarder(session, "ssh", 80) if pf == nil { t.Error("port forwarder is nil") } @@ -52,7 +52,7 @@ func TestPortForwarderStart(t *testing.T) { done := make(chan error) go func() { const name, local, remote = "ssh", 8000, 8000 - done <- NewPortForwarder(session, name, local, remote).Forward(ctx) + done <- NewPortForwarder(session, name, remote).ForwardToLocalPort(ctx, local) }() go func() { From 4438b85e294e510edf97510ede486db175e8f084 Mon Sep 17 00:00:00 2001 From: Alan Donovan Date: Thu, 2 Sep 2021 15:41:36 -0400 Subject: [PATCH 4/4] comment tweaks --- port_forwarder.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/port_forwarder.go b/port_forwarder.go index 4a46cd4e6..f4895bb60 100644 --- a/port_forwarder.go +++ b/port_forwarder.go @@ -27,9 +27,9 @@ func NewPortForwarder(session *Session, name string, remotePort int) *PortForwar } // ForwardToLocalPort forwards traffic between the container's remote -// port and a local TCP port. It accepts and handles TCP connections -// on the local until it encounters the first error, which may include -// context cancellation. Its result is non-nil. +// port and a local TCP port. It accepts and handles connections on +// the local port until it encounters the first error, which may +// include context cancellation. Its error result is always non-nil. func (fwd *PortForwarder) ForwardToLocalPort(ctx context.Context, localPort int) (err error) { id, err := fwd.shareRemotePort(ctx) if err != nil {