Fix data race in StartSharing
This commit is contained in:
parent
e9cb521bfd
commit
87b15aa264
6 changed files with 64 additions and 41 deletions
13
client.go
13
client.go
|
|
@ -86,6 +86,12 @@ type joinWorkspaceResult struct {
|
|||
SessionNumber int `json:"sessionNumber"`
|
||||
}
|
||||
|
||||
// A channelID is an identifier for an exposed port on a remote
|
||||
// container that may be used to open an SSH channel to it.
|
||||
type channelID struct {
|
||||
name, condition string
|
||||
}
|
||||
|
||||
func (c *Client) joinWorkspace(ctx context.Context, rpc *rpcClient) (*joinWorkspaceResult, error) {
|
||||
args := joinWorkspaceArgs{
|
||||
ID: c.connection.SessionID,
|
||||
|
|
@ -104,8 +110,11 @@ func (c *Client) joinWorkspace(ctx context.Context, rpc *rpcClient) (*joinWorksp
|
|||
return &result, nil
|
||||
}
|
||||
|
||||
func (s *Session) openStreamingChannel(ctx context.Context, streamName, condition string) (ssh.Channel, error) {
|
||||
args := getStreamArgs{streamName, condition}
|
||||
func (s *Session) openStreamingChannel(ctx context.Context, id channelID) (ssh.Channel, error) {
|
||||
args := getStreamArgs{
|
||||
StreamName: id.name,
|
||||
Condition: id.condition,
|
||||
}
|
||||
var streamID string
|
||||
if err := s.rpc.do(ctx, "streamManager.getStream", args, &streamID); err != nil {
|
||||
return nil, fmt.Errorf("error getting stream id: %v", err)
|
||||
|
|
|
|||
|
|
@ -9,23 +9,34 @@ import (
|
|||
|
||||
// A PortForwarder forwards TCP traffic between a local TCP port and a LiveShare session.
|
||||
type PortForwarder struct {
|
||||
session *Session
|
||||
port int
|
||||
session *Session
|
||||
name string
|
||||
localPort, remotePort int
|
||||
}
|
||||
|
||||
// NewPortForwarder creates a new PortForwarder for a given Live Share session and local TCP port.
|
||||
func NewPortForwarder(session *Session, port int) *PortForwarder {
|
||||
// NewPortForwarder creates a new PortForwarder that forwards traffic
|
||||
// between the local port and the container's remote port over the
|
||||
// specified Live Share session. The name describes the purpose of the
|
||||
// remote port or service.
|
||||
func NewPortForwarder(session *Session, name string, localPort, remotePort int) *PortForwarder {
|
||||
return &PortForwarder{
|
||||
session: session,
|
||||
port: port,
|
||||
session: session,
|
||||
name: name,
|
||||
localPort: localPort,
|
||||
remotePort: remotePort,
|
||||
}
|
||||
}
|
||||
|
||||
// Forward enables port forwarding. It accepts and handles TCP
|
||||
// connections until it encounters the first error, which may include
|
||||
// context cancellation. Its result is non-nil.
|
||||
func (l *PortForwarder) Forward(ctx context.Context) (err error) {
|
||||
listen, err := net.Listen("tcp", fmt.Sprintf(":%d", l.port))
|
||||
func (fwd *PortForwarder) Forward(ctx context.Context) (err error) {
|
||||
id, err := fwd.shareRemotePort(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
listen, err := net.Listen("tcp", fmt.Sprintf(":%d", fwd.localPort))
|
||||
if err != nil {
|
||||
return fmt.Errorf("error listening on TCP port: %v", err)
|
||||
}
|
||||
|
|
@ -49,7 +60,7 @@ func (l *PortForwarder) Forward(ctx context.Context) (err error) {
|
|||
}
|
||||
|
||||
go func() {
|
||||
if err := l.handleConnection(ctx, conn); err != nil {
|
||||
if err := fwd.handleConnection(ctx, id, conn); err != nil {
|
||||
sendError(err)
|
||||
}
|
||||
}()
|
||||
|
|
@ -60,17 +71,30 @@ func (l *PortForwarder) Forward(ctx context.Context) (err error) {
|
|||
}
|
||||
|
||||
// ForwardWithConn handles port forwarding for a single connection.
|
||||
func (l *PortForwarder) ForwardWithConn(ctx context.Context, conn io.ReadWriteCloser) error {
|
||||
func (fwd *PortForwarder) ForwardWithConn(ctx context.Context, conn io.ReadWriteCloser) error {
|
||||
id, err := fwd.shareRemotePort(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Create buffered channel so that send doesn't get stuck after context cancellation.
|
||||
errc := make(chan error, 1)
|
||||
go func() {
|
||||
if err := l.handleConnection(ctx, conn); err != nil {
|
||||
if err := fwd.handleConnection(ctx, id, conn); err != nil {
|
||||
errc <- err
|
||||
}
|
||||
}()
|
||||
return awaitError(ctx, errc)
|
||||
}
|
||||
|
||||
func (fwd *PortForwarder) shareRemotePort(ctx context.Context) (channelID, error) {
|
||||
id, err := fwd.session.startSharing(ctx, fwd.name, fwd.remotePort)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("failed to share remote port %d: %v", fwd.remotePort, err)
|
||||
}
|
||||
return id, nil
|
||||
}
|
||||
|
||||
func awaitError(ctx context.Context, errc <-chan error) error {
|
||||
select {
|
||||
case err := <-errc:
|
||||
|
|
@ -81,10 +105,10 @@ func awaitError(ctx context.Context, errc <-chan error) error {
|
|||
}
|
||||
|
||||
// handleConnection handles forwarding for a single accepted connection, then closes it.
|
||||
func (l *PortForwarder) handleConnection(ctx context.Context, conn io.ReadWriteCloser) (err error) {
|
||||
func (fwd *PortForwarder) handleConnection(ctx context.Context, id channelID, conn io.ReadWriteCloser) (err error) {
|
||||
defer safeClose(conn, &err)
|
||||
|
||||
channel, err := l.session.openStreamingChannel(ctx, l.session.streamName, l.session.streamCondition)
|
||||
channel, err := fwd.session.openStreamingChannel(ctx, id)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error opening streaming channel for new connection: %v", err)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -20,7 +20,7 @@ func TestNewPortForwarder(t *testing.T) {
|
|||
t.Errorf("create mock client: %v", err)
|
||||
}
|
||||
defer testServer.Close()
|
||||
pf := NewPortForwarder(session, 80)
|
||||
pf := NewPortForwarder(session, "ssh", 81, 80)
|
||||
if pf == nil {
|
||||
t.Error("port forwarder is nil")
|
||||
}
|
||||
|
|
@ -48,14 +48,11 @@ func TestPortForwarderStart(t *testing.T) {
|
|||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
pf := NewPortForwarder(session, 8000)
|
||||
done := make(chan error)
|
||||
|
||||
done := make(chan error)
|
||||
go func() {
|
||||
if err := session.StartSharing(ctx, "http", 8000); err != nil {
|
||||
done <- fmt.Errorf("start sharing: %v", err)
|
||||
}
|
||||
done <- pf.Forward(ctx)
|
||||
const name, local, remote = "ssh", 8000, 8000
|
||||
done <- NewPortForwarder(session, name, local, remote).Forward(ctx)
|
||||
}()
|
||||
|
||||
go func() {
|
||||
|
|
|
|||
24
session.go
24
session.go
|
|
@ -9,11 +9,6 @@ import (
|
|||
type Session struct {
|
||||
ssh *sshSession
|
||||
rpc *rpcClient
|
||||
|
||||
// TODO(adonovan): fix: avoid data race of state accessed by
|
||||
// multiple calls to StartSharing and concurrent calls to
|
||||
// PortForwarder. Perhaps combine the two operations in the API?
|
||||
streamName, streamCondition string
|
||||
}
|
||||
|
||||
// Port describes a port exposed by the container.
|
||||
|
|
@ -31,20 +26,17 @@ type Port struct {
|
|||
// TODO(adonovan): fix possible typo in field name, and audit others.
|
||||
}
|
||||
|
||||
// StartSharing tells the Live Share host to start sharing the specified port from the container.
|
||||
// The sessionName describes the purpose of the port or service.
|
||||
func (s *Session) StartSharing(ctx context.Context, sessionName string, port int) error {
|
||||
// startSharing tells the Live Share host to start sharing the specified port from the container.
|
||||
// The sessionName describes the purpose of the remote port or service.
|
||||
// It returns an identifier that can be used to open an SSH channel to the remote port.
|
||||
func (s *Session) startSharing(ctx context.Context, sessionName string, port int) (channelID, error) {
|
||||
args := []interface{}{port, sessionName, fmt.Sprintf("http://localhost:%d", port)}
|
||||
var response Port
|
||||
if err := s.rpc.do(ctx, "serverSharing.startSharing", []interface{}{
|
||||
port, sessionName, fmt.Sprintf("http://localhost:%d", port),
|
||||
}, &response); err != nil {
|
||||
return err
|
||||
if err := s.rpc.do(ctx, "serverSharing.startSharing", args, &response); err != nil {
|
||||
return channelID{}, err
|
||||
}
|
||||
|
||||
s.streamName = response.StreamName
|
||||
s.streamCondition = response.StreamCondition
|
||||
|
||||
return nil
|
||||
return channelID{response.StreamName, response.StreamCondition}, nil
|
||||
}
|
||||
|
||||
// GetSharedServers returns a description of each container port
|
||||
|
|
|
|||
|
|
@ -82,10 +82,11 @@ func TestServerStartSharing(t *testing.T) {
|
|||
|
||||
done := make(chan error)
|
||||
go func() {
|
||||
if err := session.StartSharing(ctx, serverProtocol, serverPort); err != nil {
|
||||
streamID, err := session.startSharing(ctx, serverProtocol, serverPort)
|
||||
if err != nil {
|
||||
done <- fmt.Errorf("error sharing server: %v", err)
|
||||
}
|
||||
if session.streamName == "" || session.streamCondition == "" {
|
||||
if streamID.name == "" || streamID.condition == "" {
|
||||
done <- errors.New("stream name or condition is blank")
|
||||
}
|
||||
done <- nil
|
||||
|
|
|
|||
|
|
@ -75,7 +75,7 @@ func (t TerminalCommand) Run(ctx context.Context) (io.ReadCloser, error) {
|
|||
}
|
||||
<-started
|
||||
|
||||
channel, err := t.terminal.session.openStreamingChannel(ctx, result.StreamName, result.StreamCondition)
|
||||
channel, err := t.terminal.session.openStreamingChannel(ctx, channelID{result.StreamName, result.StreamCondition})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error opening streaming channel: %v", err)
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue