Merge pull request #12 from github/fix-share-data-race

Fix data race in StartSharing
This commit is contained in:
Alan Donovan 2021-09-02 15:44:49 -04:00 committed by GitHub
commit b4686935b9
6 changed files with 71 additions and 46 deletions

View file

@ -86,6 +86,12 @@ type joinWorkspaceResult struct {
SessionNumber int `json:"sessionNumber"`
}
// A channelID is an identifier for an exposed port on a remote
// container that may be used to open an SSH channel to it.
type channelID struct {
name, condition string
}
func (c *Client) joinWorkspace(ctx context.Context, rpc *rpcClient) (*joinWorkspaceResult, error) {
args := joinWorkspaceArgs{
ID: c.connection.SessionID,
@ -104,8 +110,11 @@ func (c *Client) joinWorkspace(ctx context.Context, rpc *rpcClient) (*joinWorksp
return &result, nil
}
func (s *Session) openStreamingChannel(ctx context.Context, streamName, condition string) (ssh.Channel, error) {
args := getStreamArgs{streamName, condition}
func (s *Session) openStreamingChannel(ctx context.Context, id channelID) (ssh.Channel, error) {
args := getStreamArgs{
StreamName: id.name,
Condition: id.condition,
}
var streamID string
if err := s.rpc.do(ctx, "streamManager.getStream", args, &streamID); err != nil {
return nil, fmt.Errorf("error getting stream id: %v", err)

View file

@ -7,25 +7,36 @@ import (
"net"
)
// A PortForwarder forwards TCP traffic between a local TCP port and a LiveShare session.
// A PortForwarder forwards TCP traffic over a LiveShare session from a port on a remote
// container to a local destination such as a network port or Go reader/writer.
type PortForwarder struct {
session *Session
port int
session *Session
name string
remotePort int
}
// NewPortForwarder creates a new PortForwarder for a given Live Share session and local TCP port.
func NewPortForwarder(session *Session, port int) *PortForwarder {
// NewPortForwarder returns a new PortForwarder for the specified
// remote port and Live Share session. The name describes the purpose
// of the remote port or service.
func NewPortForwarder(session *Session, name string, remotePort int) *PortForwarder {
return &PortForwarder{
session: session,
port: port,
session: session,
name: name,
remotePort: remotePort,
}
}
// Forward enables port forwarding. It accepts and handles TCP
// connections until it encounters the first error, which may include
// context cancellation. Its result is non-nil.
func (l *PortForwarder) Forward(ctx context.Context) (err error) {
listen, err := net.Listen("tcp", fmt.Sprintf(":%d", l.port))
// ForwardToLocalPort forwards traffic between the container's remote
// port and a local TCP port. It accepts and handles connections on
// the local port until it encounters the first error, which may
// include context cancellation. Its error result is always non-nil.
func (fwd *PortForwarder) ForwardToLocalPort(ctx context.Context, localPort int) (err error) {
id, err := fwd.shareRemotePort(ctx)
if err != nil {
return err
}
listen, err := net.Listen("tcp", fmt.Sprintf(":%d", localPort))
if err != nil {
return fmt.Errorf("error listening on TCP port: %v", err)
}
@ -49,7 +60,7 @@ func (l *PortForwarder) Forward(ctx context.Context) (err error) {
}
go func() {
if err := l.handleConnection(ctx, conn); err != nil {
if err := fwd.handleConnection(ctx, id, conn); err != nil {
sendError(err)
}
}()
@ -59,18 +70,33 @@ func (l *PortForwarder) Forward(ctx context.Context) (err error) {
return awaitError(ctx, errc)
}
// ForwardWithConn handles port forwarding for a single connection.
func (l *PortForwarder) ForwardWithConn(ctx context.Context, conn io.ReadWriteCloser) error {
// Forward forwards traffic between the container's remote port and
// the specified read/write stream. On return, the stream is closed.
func (fwd *PortForwarder) Forward(ctx context.Context, conn io.ReadWriteCloser) error {
id, err := fwd.shareRemotePort(ctx)
if err != nil {
conn.Close()
return err
}
// Create buffered channel so that send doesn't get stuck after context cancellation.
errc := make(chan error, 1)
go func() {
if err := l.handleConnection(ctx, conn); err != nil {
if err := fwd.handleConnection(ctx, id, conn); err != nil {
errc <- err
}
}()
return awaitError(ctx, errc)
}
func (fwd *PortForwarder) shareRemotePort(ctx context.Context) (channelID, error) {
id, err := fwd.session.startSharing(ctx, fwd.name, fwd.remotePort)
if err != nil {
err = fmt.Errorf("failed to share remote port %d: %v", fwd.remotePort, err)
}
return id, nil
}
func awaitError(ctx context.Context, errc <-chan error) error {
select {
case err := <-errc:
@ -81,10 +107,10 @@ func awaitError(ctx context.Context, errc <-chan error) error {
}
// handleConnection handles forwarding for a single accepted connection, then closes it.
func (l *PortForwarder) handleConnection(ctx context.Context, conn io.ReadWriteCloser) (err error) {
func (fwd *PortForwarder) handleConnection(ctx context.Context, id channelID, conn io.ReadWriteCloser) (err error) {
defer safeClose(conn, &err)
channel, err := l.session.openStreamingChannel(ctx, l.session.streamName, l.session.streamCondition)
channel, err := fwd.session.openStreamingChannel(ctx, id)
if err != nil {
return fmt.Errorf("error opening streaming channel for new connection: %v", err)
}

View file

@ -20,7 +20,7 @@ func TestNewPortForwarder(t *testing.T) {
t.Errorf("create mock client: %v", err)
}
defer testServer.Close()
pf := NewPortForwarder(session, 80)
pf := NewPortForwarder(session, "ssh", 80)
if pf == nil {
t.Error("port forwarder is nil")
}
@ -48,14 +48,11 @@ func TestPortForwarderStart(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
pf := NewPortForwarder(session, 8000)
done := make(chan error)
done := make(chan error)
go func() {
if err := session.StartSharing(ctx, "http", 8000); err != nil {
done <- fmt.Errorf("start sharing: %v", err)
}
done <- pf.Forward(ctx)
const name, local, remote = "ssh", 8000, 8000
done <- NewPortForwarder(session, name, remote).ForwardToLocalPort(ctx, local)
}()
go func() {

View file

@ -9,11 +9,6 @@ import (
type Session struct {
ssh *sshSession
rpc *rpcClient
// TODO(adonovan): fix: avoid data race of state accessed by
// multiple calls to StartSharing and concurrent calls to
// PortForwarder. Perhaps combine the two operations in the API?
streamName, streamCondition string
}
// Port describes a port exposed by the container.
@ -31,20 +26,17 @@ type Port struct {
// TODO(adonovan): fix possible typo in field name, and audit others.
}
// StartSharing tells the Live Share host to start sharing the specified port from the container.
// The sessionName describes the purpose of the port or service.
func (s *Session) StartSharing(ctx context.Context, sessionName string, port int) error {
// startSharing tells the Live Share host to start sharing the specified port from the container.
// The sessionName describes the purpose of the remote port or service.
// It returns an identifier that can be used to open an SSH channel to the remote port.
func (s *Session) startSharing(ctx context.Context, sessionName string, port int) (channelID, error) {
args := []interface{}{port, sessionName, fmt.Sprintf("http://localhost:%d", port)}
var response Port
if err := s.rpc.do(ctx, "serverSharing.startSharing", []interface{}{
port, sessionName, fmt.Sprintf("http://localhost:%d", port),
}, &response); err != nil {
return err
if err := s.rpc.do(ctx, "serverSharing.startSharing", args, &response); err != nil {
return channelID{}, err
}
s.streamName = response.StreamName
s.streamCondition = response.StreamCondition
return nil
return channelID{response.StreamName, response.StreamCondition}, nil
}
// GetSharedServers returns a description of each container port

View file

@ -82,10 +82,11 @@ func TestServerStartSharing(t *testing.T) {
done := make(chan error)
go func() {
if err := session.StartSharing(ctx, serverProtocol, serverPort); err != nil {
streamID, err := session.startSharing(ctx, serverProtocol, serverPort)
if err != nil {
done <- fmt.Errorf("error sharing server: %v", err)
}
if session.streamName == "" || session.streamCondition == "" {
if streamID.name == "" || streamID.condition == "" {
done <- errors.New("stream name or condition is blank")
}
done <- nil

View file

@ -75,7 +75,7 @@ func (t TerminalCommand) Run(ctx context.Context) (io.ReadCloser, error) {
}
<-started
channel, err := t.terminal.session.openStreamingChannel(ctx, result.StreamName, result.StreamCondition)
channel, err := t.terminal.session.openStreamingChannel(ctx, channelID{result.StreamName, result.StreamCondition})
if err != nil {
return nil, fmt.Errorf("error opening streaming channel: %v", err)
}