From 8eba57a9ed6ccf69c4944939cd587ff3e1403e70 Mon Sep 17 00:00:00 2001 From: Jose Garcia Date: Wed, 23 Jun 2021 20:00:24 -0400 Subject: [PATCH 01/68] initial commit --- api.go | 130 ++++++++++++++++++++++++++++++++++++++++++++++++ client.go | 28 +++++++++++ config.go | 56 +++++++++++++++++++++ example/main.go | 23 +++++++++ liveshare.go | 35 +++++++++++++ session.go | 44 ++++++++++++++++ ssh.go | 70 ++++++++++++++++++++++++++ 7 files changed, 386 insertions(+) create mode 100644 api.go create mode 100644 client.go create mode 100644 config.go create mode 100644 example/main.go create mode 100644 liveshare.go create mode 100644 session.go create mode 100644 ssh.go diff --git a/api.go b/api.go new file mode 100644 index 000000000..c7efb9830 --- /dev/null +++ b/api.go @@ -0,0 +1,130 @@ +package liveshare + +import ( + "encoding/json" + "fmt" + "io/ioutil" + "net/http" + "strings" +) + +type API struct { + Configuration *Configuration + HttpClient *http.Client + ServiceURI string + WorkspaceID string +} + +func NewAPI(configuration *Configuration) *API { + serviceURI := configuration.LiveShareEndpoint + if !strings.HasSuffix(configuration.LiveShareEndpoint, "/") { + serviceURI = configuration.LiveShareEndpoint + "/" + } + + if !strings.Contains(serviceURI, "api/v1.2") { + serviceURI = serviceURI + "api/v1.2" + } + + serviceURI = strings.TrimSuffix(serviceURI, "/") + + return &API{configuration, &http.Client{}, serviceURI, strings.ToUpper(configuration.WorkspaceID)} +} + +type WorkspaceAccessResponse struct { + SessionToken string `json:"sessionToken"` + CreatedAt string `json:"createdAt"` + UpdatedAt string `json:"updatedAt"` + Name string `json:"name"` + OwnerID string `json:"ownerId"` + JoinLink string `json:"joinLink"` + ConnectLinks []string `json:"connectLinks"` + RelayLink string `json:"relayLink"` + RelaySas string `json:"relaySas"` + HostPublicKeys []string `json:"hostPublicKeys"` + ConversationID string `json:"conversationId"` + AssociatedUserIDs []string `json:"associatedUserIds"` + AreAnonymousGuestsAllowed bool `json:"areAnonymousGuestsAllowed"` + IsHostConnected bool `json:"isHostConnected"` + ExpiresAt string `json:"expiresAt"` + InvitationLinks []string `json:"invitationLinks"` + ID string `json:"id"` +} + +func (a *API) WorkspaceAccess() (*WorkspaceAccessResponse, error) { + url := fmt.Sprintf("%s/workspace/%s/user", a.ServiceURI, a.WorkspaceID) + + req, err := http.NewRequest(http.MethodPut, url, nil) + if err != nil { + return nil, fmt.Errorf("error creating request: %v", err) + } + + a.setDefaultHeaders(req) + resp, err := a.HttpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("error making request: %v", err) + } + + b, err := ioutil.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("error reading response body: %v", err) + } + + var workspaceAccessResponse WorkspaceAccessResponse + if err := json.Unmarshal(b, &workspaceAccessResponse); err != nil { + return nil, fmt.Errorf("error unmarshaling response into json: %v", err) + } + + return &workspaceAccessResponse, nil +} + +func (a *API) setDefaultHeaders(req *http.Request) { + req.Header.Set("Authorization", "Bearer "+a.Configuration.Token) + req.Header.Set("Cache-Control", "no-cache") + req.Header.Set("Content-Type", "application/json") +} + +type WorkspaceInfoResponse struct { + CreatedAt string `json:"createdAt"` + UpdatedAt string `json:"updatedAt"` + Name string `json:"name"` + OwnerID string `json:"ownerId"` + JoinLink string `json:"joinLink"` + ConnectLinks []string `json:"connectLinks"` + RelayLink string `json:"relayLink"` + RelaySas string `json:"relaySas"` + HostPublicKeys []string `json:"hostPublicKeys"` + ConversationID string `json:"conversationId"` + AssociatedUserIDs []string `json:"associatedUserIds"` + AreAnonymousGuestsAllowed bool `json:"areAnonymousGuestsAllowed"` + IsHostConnected bool `json:"isHostConnected"` + ExpiresAt string `json:"expiresAt"` + InvitationLinks []string `json:"invitationLinks"` + ID string `json:"id"` +} + +func (a *API) WorkspaceInfo() (*WorkspaceInfoResponse, error) { + url := fmt.Sprintf("%s/workspace/%s", a.ServiceURI, a.WorkspaceID) + + req, err := http.NewRequest(http.MethodGet, url, nil) + if err != nil { + return nil, fmt.Errorf("error creating request: %v", err) + } + + a.setDefaultHeaders(req) + resp, err := a.HttpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("error making request: %v", err) + } + + b, err := ioutil.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("error reading response body: %v", err) + } + + var workspaceInfoResponse WorkspaceInfoResponse + if err := json.Unmarshal(b, &workspaceInfoResponse); err != nil { + return nil, fmt.Errorf("error unmarshaling response into json: %v", err) + } + + return &workspaceInfoResponse, nil +} diff --git a/client.go b/client.go new file mode 100644 index 000000000..a58a34c9b --- /dev/null +++ b/client.go @@ -0,0 +1,28 @@ +package liveshare + +import ( + "context" + "fmt" +) + +type Client struct { + Configuration *Configuration +} + +func NewClient(configuration *Configuration) *Client { + return &Client{configuration} +} + +func (c *Client) Join(ctx context.Context) error { + session, err := GetSession(ctx, c.Configuration) + if err != nil { + return fmt.Errorf("error getting session: %v", err) + } + + sshSession := NewSSHSession(session) + if err := sshSession.Connect(); err != nil { + return fmt.Errorf("error authenticating ssh session: %v", err) + } + + return nil +} diff --git a/config.go b/config.go new file mode 100644 index 000000000..74eb5b178 --- /dev/null +++ b/config.go @@ -0,0 +1,56 @@ +package liveshare + +import ( + "errors" + "strings" +) + +type Option func(configuration *Configuration) error + +func WithWorkspaceID(id string) Option { + return func(configuration *Configuration) error { + configuration.WorkspaceID = id + return nil + } +} + +func WithLiveShareEndpoint(liveShareEndpoint string) Option { + return func(configuration *Configuration) error { + configuration.LiveShareEndpoint = liveShareEndpoint + return nil + } +} + +func WithToken(token string) Option { + return func(configuration *Configuration) error { + configuration.Token = token + return nil + } +} + +type Configuration struct { + WorkspaceID, LiveShareEndpoint, Token string +} + +func NewConfiguration() *Configuration { + return &Configuration{ + LiveShareEndpoint: "https://prod.liveshare.vsengsaas.visualstudio.com", + } +} + +func (c *Configuration) Validate() error { + errs := []string{} + if c.WorkspaceID == "" { + errs = append(errs, "WorkspaceID is required") + } + + if c.Token == "" { + errs = append(errs, "Token is required") + } + + if len(errs) > 0 { + return errors.New(strings.Join(errs, ", ")) + } + + return nil +} diff --git a/example/main.go b/example/main.go new file mode 100644 index 000000000..79ab53377 --- /dev/null +++ b/example/main.go @@ -0,0 +1,23 @@ +package main + +import ( + "context" + "fmt" + "log" + + "github.com/josebalius/go-liveshare" +) + +func main() { + liveShare, err := liveshare.New( + liveshare.WithWorkspaceID("..."), + liveshare.WithToken("..."), + ) + if err != nil { + log.Fatal(fmt.Errorf("error creating liveshare: %v", err)) + } + + if err := liveShare.Connect(context.Background()); err != nil { + log.Fatal(fmt.Errorf("error connecting to liveshare: %v", err)) + } +} diff --git a/liveshare.go b/liveshare.go new file mode 100644 index 000000000..a8c8b69d6 --- /dev/null +++ b/liveshare.go @@ -0,0 +1,35 @@ +package liveshare + +import ( + "context" + "fmt" +) + +type LiveShare struct { + Configuration *Configuration +} + +func New(opts ...Option) (*LiveShare, error) { + configuration := NewConfiguration() + + for _, o := range opts { + if err := o(configuration); err != nil { + return nil, fmt.Errorf("error configuring liveshare: %v", err) + } + } + + if err := configuration.Validate(); err != nil { + return nil, fmt.Errorf("error validating configuration: %v", err) + } + + return &LiveShare{configuration}, nil +} + +func (l *LiveShare) Connect(ctx context.Context) error { + workspaceClient := NewClient(l.Configuration) + if err := workspaceClient.Join(ctx); err != nil { + return fmt.Errorf("error joining with workspace client: %v", err) + } + + return nil +} diff --git a/session.go b/session.go new file mode 100644 index 000000000..24a284ef2 --- /dev/null +++ b/session.go @@ -0,0 +1,44 @@ +package liveshare + +import ( + "context" + "fmt" + + "golang.org/x/sync/errgroup" +) + +type Session struct { + WorkspaceAccess *WorkspaceAccessResponse + WorkspaceInfo *WorkspaceInfoResponse +} + +func GetSession(ctx context.Context, configuration *Configuration) (*Session, error) { + api := NewAPI(configuration) + session := new(Session) + + g, ctx := errgroup.WithContext(ctx) + + g.Go(func() error { + workspaceAccess, err := api.WorkspaceAccess() + if err != nil { + return fmt.Errorf("error getting workspace access: %v", err) + } + session.WorkspaceAccess = workspaceAccess + return nil + }) + + g.Go(func() error { + workspaceInfo, err := api.WorkspaceInfo() + if err != nil { + return fmt.Errorf("error getting workspace info: %v", err) + } + session.WorkspaceInfo = workspaceInfo + return nil + }) + + if err := g.Wait(); err != nil { + return nil, err + } + + return session, nil +} diff --git a/ssh.go b/ssh.go new file mode 100644 index 000000000..af132b331 --- /dev/null +++ b/ssh.go @@ -0,0 +1,70 @@ +package liveshare + +import ( + "fmt" + "net" + "net/url" + "strings" + + "golang.org/x/crypto/ssh" + "golang.org/x/net/websocket" +) + +type SSHSession struct { + Session *Session + VersionExchangeError chan error +} + +func NewSSHSession(session *Session) *SSHSession { + return &SSHSession{ + Session: session, + } +} + +func (s *SSHSession) Connect() error { + socketStream, err := s.socketStream() + if err != nil { + return fmt.Errorf("error creating socket stream: %v", err) + } + + clientConfig := ssh.ClientConfig{ + User: "", + Auth: []ssh.AuthMethod{ + ssh.Password(s.Session.WorkspaceAccess.SessionToken), + }, + HostKeyCallback: func(hostname string, remote net.Addr, key ssh.PublicKey) error { + // TODO(josebalius): implement + return nil + }, + } + + sshClientConn, chans, reqs, err := ssh.NewClientConn(socketStream, "", &clientConfig) + if err != nil { + return fmt.Errorf("error creating ssh client connection: %v", err) + } + + fmt.Println(sshClientConn, chans, reqs) + + return nil +} + +// Reference: +// https://github.com/Azure/azure-relay-node/blob/7b57225365df3010163bf4b9e640868a02737eb6/hyco-ws/index.js#L107-L137 +func (s *SSHSession) relayURI(action string) string { + relaySas := url.QueryEscape(s.Session.WorkspaceAccess.RelaySas) + relayURI := s.Session.WorkspaceAccess.RelayLink + relayURI = strings.Replace(relayURI, "sb:", "wss:", -1) + relayURI = strings.Replace(relayURI, ".net/", ".net:443/$hc/", 1) + relayURI = relayURI + "?sb-hc-action=" + action + "&sb-hc-token=" + relaySas + return relayURI +} + +func (s *SSHSession) socketStream() (*websocket.Conn, error) { + uri := s.relayURI("connect") + ws, err := websocket.Dial(uri, "", uri) + if err != nil { + return nil, fmt.Errorf("error dialing relay connection: %v", err) + } + + return ws, nil +} From a8b1b87f7b33ab3d37fc3aed276d356d61631367 Mon Sep 17 00:00:00 2001 From: Jose Garcia Date: Thu, 24 Jun 2021 20:44:16 -0400 Subject: [PATCH 02/68] Start of RPC implementation, need to figure out format --- example/main.go | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/example/main.go b/example/main.go index 79ab53377..bbf540d7c 100644 --- a/example/main.go +++ b/example/main.go @@ -10,8 +10,8 @@ import ( func main() { liveShare, err := liveshare.New( - liveshare.WithWorkspaceID("..."), - liveshare.WithToken("..."), + liveshare.WithWorkspaceID(""), + liveshare.WithToken(""), ) if err != nil { log.Fatal(fmt.Errorf("error creating liveshare: %v", err)) @@ -20,4 +20,17 @@ func main() { if err := liveShare.Connect(context.Background()); err != nil { log.Fatal(fmt.Errorf("error connecting to liveshare: %v", err)) } + + terminal := liveShare.NewTerminal() + + cmd := terminal.NewCommand( + "/home/codespace/workspace", + "docker ps -aq --filter label=Type=codespaces --filter status=running", + ) + output, err := cmd.Run(context.Background()) + if err != nil { + log.Fatal(fmt.Errorf("error starting ssh server with liveshare: %v", err)) + } + + fmt.Println(string(output)) } From 897ab1598b3d2cd9dd08c3310f0938f5c2ce4264 Mon Sep 17 00:00:00 2001 From: Jose Garcia Date: Thu, 24 Jun 2021 20:45:03 -0400 Subject: [PATCH 03/68] RPC functionality started take two --- adapter.go | 100 +++++++++++++++++++++++++++++++++++++++++++++++++++ api.go | 40 +++++++++++---------- client.go | 9 ++--- liveshare.go | 63 ++++++++++++++++++++++++++++++-- ssh.go | 91 ++++++++++++++++++++++++++++++---------------- 5 files changed, 248 insertions(+), 55 deletions(-) create mode 100644 adapter.go diff --git a/adapter.go b/adapter.go new file mode 100644 index 000000000..fb3424734 --- /dev/null +++ b/adapter.go @@ -0,0 +1,100 @@ +package liveshare + +import ( + "errors" + "io" + "net" + "sync" + "time" + + "github.com/gorilla/websocket" +) + +type Adapter struct { + conn *websocket.Conn + readMutex sync.Mutex + writeMutex sync.Mutex + reader io.Reader +} + +func NewAdapter(conn *websocket.Conn) *Adapter { + return &Adapter{ + conn: conn, + } +} + +func (a *Adapter) Read(b []byte) (int, error) { + // Read() can be called concurrently, and we mutate some internal state here + a.readMutex.Lock() + defer a.readMutex.Unlock() + + if a.reader == nil { + messageType, reader, err := a.conn.NextReader() + if err != nil { + return 0, err + } + + if messageType != websocket.BinaryMessage { + return 0, errors.New("unexpected websocket message type") + } + + a.reader = reader + } + + bytesRead, err := a.reader.Read(b) + if err != nil { + a.reader = nil + + // EOF for the current Websocket frame, more will probably come so.. + if err == io.EOF { + // .. we must hide this from the caller since our semantics are a + // stream of bytes across many frames + err = nil + } + } + + return bytesRead, err +} + +func (a *Adapter) Write(b []byte) (int, error) { + a.writeMutex.Lock() + defer a.writeMutex.Unlock() + + nextWriter, err := a.conn.NextWriter(websocket.BinaryMessage) + if err != nil { + return 0, err + } + + bytesWritten, err := nextWriter.Write(b) + nextWriter.Close() + + return bytesWritten, err +} + +func (a *Adapter) Close() error { + return a.conn.Close() +} + +func (a *Adapter) LocalAddr() net.Addr { + return a.conn.LocalAddr() +} + +func (a *Adapter) RemoteAddr() net.Addr { + return a.conn.RemoteAddr() +} + +func (a *Adapter) SetDeadline(t time.Time) error { + if err := a.SetReadDeadline(t); err != nil { + return err + } + + return a.SetWriteDeadline(t) +} + +func (a *Adapter) SetReadDeadline(t time.Time) error { + return a.conn.SetReadDeadline(t) +} + +func (a *Adapter) SetWriteDeadline(t time.Time) error { + return a.conn.SetWriteDeadline(t) +} diff --git a/api.go b/api.go index c7efb9830..d8a4bebbd 100644 --- a/api.go +++ b/api.go @@ -5,6 +5,7 @@ import ( "fmt" "io/ioutil" "net/http" + "net/http/httputil" "strings" ) @@ -31,27 +32,28 @@ func NewAPI(configuration *Configuration) *API { } type WorkspaceAccessResponse struct { - SessionToken string `json:"sessionToken"` - CreatedAt string `json:"createdAt"` - UpdatedAt string `json:"updatedAt"` - Name string `json:"name"` - OwnerID string `json:"ownerId"` - JoinLink string `json:"joinLink"` - ConnectLinks []string `json:"connectLinks"` - RelayLink string `json:"relayLink"` - RelaySas string `json:"relaySas"` - HostPublicKeys []string `json:"hostPublicKeys"` - ConversationID string `json:"conversationId"` - AssociatedUserIDs []string `json:"associatedUserIds"` - AreAnonymousGuestsAllowed bool `json:"areAnonymousGuestsAllowed"` - IsHostConnected bool `json:"isHostConnected"` - ExpiresAt string `json:"expiresAt"` - InvitationLinks []string `json:"invitationLinks"` - ID string `json:"id"` + SessionToken string `json:"sessionToken"` + CreatedAt string `json:"createdAt"` + UpdatedAt string `json:"updatedAt"` + Name string `json:"name"` + OwnerID string `json:"ownerId"` + JoinLink string `json:"joinLink"` + ConnectLinks []string `json:"connectLinks"` + RelayLink string `json:"relayLink"` + RelaySas string `json:"relaySas"` + HostPublicKeys []string `json:"hostPublicKeys"` + ConversationID string `json:"conversationId"` + AssociatedUserIDs map[string]string `json:"associatedUserIds"` + AreAnonymousGuestsAllowed bool `json:"areAnonymousGuestsAllowed"` + IsHostConnected bool `json:"isHostConnected"` + ExpiresAt string `json:"expiresAt"` + InvitationLinks []string `json:"invitationLinks"` + ID string `json:"id"` } func (a *API) WorkspaceAccess() (*WorkspaceAccessResponse, error) { url := fmt.Sprintf("%s/workspace/%s/user", a.ServiceURI, a.WorkspaceID) + fmt.Println(url) req, err := http.NewRequest(http.MethodPut, url, nil) if err != nil { @@ -69,6 +71,8 @@ func (a *API) WorkspaceAccess() (*WorkspaceAccessResponse, error) { return nil, fmt.Errorf("error reading response body: %v", err) } + d, _ := httputil.DumpResponse(resp, true) + fmt.Println(string(d)) var workspaceAccessResponse WorkspaceAccessResponse if err := json.Unmarshal(b, &workspaceAccessResponse); err != nil { return nil, fmt.Errorf("error unmarshaling response into json: %v", err) @@ -94,7 +98,7 @@ type WorkspaceInfoResponse struct { RelaySas string `json:"relaySas"` HostPublicKeys []string `json:"hostPublicKeys"` ConversationID string `json:"conversationId"` - AssociatedUserIDs []string `json:"associatedUserIds"` + AssociatedUserIDs map[string]string AreAnonymousGuestsAllowed bool `json:"areAnonymousGuestsAllowed"` IsHostConnected bool `json:"isHostConnected"` ExpiresAt string `json:"expiresAt"` diff --git a/client.go b/client.go index a58a34c9b..0a89a125c 100644 --- a/client.go +++ b/client.go @@ -7,10 +7,11 @@ import ( type Client struct { Configuration *Configuration + SSHSession *SSHSession } func NewClient(configuration *Configuration) *Client { - return &Client{configuration} + return &Client{Configuration: configuration} } func (c *Client) Join(ctx context.Context) error { @@ -19,9 +20,9 @@ func (c *Client) Join(ctx context.Context) error { return fmt.Errorf("error getting session: %v", err) } - sshSession := NewSSHSession(session) - if err := sshSession.Connect(); err != nil { - return fmt.Errorf("error authenticating ssh session: %v", err) + c.SSHSession, err = NewSSH(session).NewSession() + if err != nil { + return fmt.Errorf("error connecting to ssh session: %v", err) } return nil diff --git a/liveshare.go b/liveshare.go index a8c8b69d6..174eac20f 100644 --- a/liveshare.go +++ b/liveshare.go @@ -3,10 +3,14 @@ package liveshare import ( "context" "fmt" + "net/rpc" ) type LiveShare struct { Configuration *Configuration + + workspaceClient *Client + terminal *Terminal } func New(opts ...Option) (*LiveShare, error) { @@ -22,14 +26,67 @@ func New(opts ...Option) (*LiveShare, error) { return nil, fmt.Errorf("error validating configuration: %v", err) } - return &LiveShare{configuration}, nil + return &LiveShare{Configuration: configuration}, nil } func (l *LiveShare) Connect(ctx context.Context) error { - workspaceClient := NewClient(l.Configuration) - if err := workspaceClient.Join(ctx); err != nil { + l.workspaceClient = NewClient(l.Configuration) + if err := l.workspaceClient.Join(ctx); err != nil { return fmt.Errorf("error joining with workspace client: %v", err) } return nil } + +type Terminal struct { + WorkspaceClient *Client + RPCClient *rpc.Client +} + +func (l *LiveShare) NewTerminal() *Terminal { + return &Terminal{ + WorkspaceClient: l.workspaceClient, + RPCClient: rpc.NewClient(l.workspaceClient.SSHSession), + } +} + +type TerminalCommand struct { + Terminal *Terminal + Cwd string + Cmd string +} + +func (t *Terminal) NewCommand(cwd, cmd string) TerminalCommand { + return TerminalCommand{t, cwd, cmd} +} + +type RunArgs struct { + Name string + Rows, Cols int + App string + Cwd string + CommandLine []string + ReadOnlyForGuests bool +} + +func (t TerminalCommand) Run(ctx context.Context) ([]byte, error) { + args := RunArgs{ + Name: "RunCommand", + Rows: 10, + Cols: 80, + App: "/bin/bash", + Cwd: t.Cwd, + CommandLine: []string{"-c", t.Cmd}, + ReadOnlyForGuests: false, + } + + var output []byte + runCall := t.Terminal.RPCClient.Go("terminal.startAsync", &args, &output, nil) + + runReply := <-runCall.Done + if runReply.Error != nil { + return nil, fmt.Errorf("error startAsync operation: %v", runReply.Error) + } + fmt.Printf("%+v\n\n", runReply) + return output, nil +} diff --git a/ssh.go b/ssh.go index af132b331..51906290e 100644 --- a/ssh.go +++ b/ssh.go @@ -2,29 +2,66 @@ package liveshare import ( "fmt" + "io" "net" "net/url" "strings" + "time" + "github.com/gorilla/websocket" "golang.org/x/crypto/ssh" - "golang.org/x/net/websocket" ) -type SSHSession struct { - Session *Session - VersionExchangeError chan error +type SSH struct { + Session *Session } -func NewSSHSession(session *Session) *SSHSession { - return &SSHSession{ +func NewSSH(session *Session) *SSH { + return &SSH{ Session: session, } } -func (s *SSHSession) Connect() error { +// Reference: +// https://github.com/Azure/azure-relay-node/blob/7b57225365df3010163bf4b9e640868a02737eb6/hyco-ws/index.js#L107-L137 +func (s *SSH) relayURI(action string) string { + relaySas := url.QueryEscape(s.Session.WorkspaceAccess.RelaySas) + relayURI := s.Session.WorkspaceAccess.RelayLink + relayURI = strings.Replace(relayURI, "sb:", "wss:", -1) + relayURI = strings.Replace(relayURI, ".net/", ".net:443/$hc/", 1) + relayURI = relayURI + "?sb-hc-action=" + action + "&sb-hc-token=" + relaySas + return relayURI +} + +func (s *SSH) socketStream() (net.Conn, error) { + uri := s.relayURI("connect") + + ws, _, err := websocket.DefaultDialer.Dial(uri, nil) + if err != nil { + return nil, fmt.Errorf("error dialing websocket connection: %v", err) + } + + return NewAdapter(ws), nil +} + +type SSHSession struct { + *ssh.Session + reader io.Reader + writer io.Writer +} + +func (s SSHSession) Read(p []byte) (n int, err error) { + return s.reader.Read(p) +} + +func (s SSHSession) Write(p []byte) (n int, err error) { + return s.writer.Write(p) +} + +func (s *SSH) NewSession() (*SSHSession, error) { socketStream, err := s.socketStream() if err != nil { - return fmt.Errorf("error creating socket stream: %v", err) + return nil, fmt.Errorf("error creating socket stream: %v", err) } clientConfig := ssh.ClientConfig{ @@ -36,35 +73,29 @@ func (s *SSHSession) Connect() error { // TODO(josebalius): implement return nil }, + Timeout: 10 * time.Second, } sshClientConn, chans, reqs, err := ssh.NewClientConn(socketStream, "", &clientConfig) if err != nil { - return fmt.Errorf("error creating ssh client connection: %v", err) + return nil, fmt.Errorf("error creating ssh client connection: %v", err) } - fmt.Println(sshClientConn, chans, reqs) - - return nil -} - -// Reference: -// https://github.com/Azure/azure-relay-node/blob/7b57225365df3010163bf4b9e640868a02737eb6/hyco-ws/index.js#L107-L137 -func (s *SSHSession) relayURI(action string) string { - relaySas := url.QueryEscape(s.Session.WorkspaceAccess.RelaySas) - relayURI := s.Session.WorkspaceAccess.RelayLink - relayURI = strings.Replace(relayURI, "sb:", "wss:", -1) - relayURI = strings.Replace(relayURI, ".net/", ".net:443/$hc/", 1) - relayURI = relayURI + "?sb-hc-action=" + action + "&sb-hc-token=" + relaySas - return relayURI -} - -func (s *SSHSession) socketStream() (*websocket.Conn, error) { - uri := s.relayURI("connect") - ws, err := websocket.Dial(uri, "", uri) + sshClient := ssh.NewClient(sshClientConn, chans, reqs) + sshSession, err := sshClient.NewSession() if err != nil { - return nil, fmt.Errorf("error dialing relay connection: %v", err) + return nil, fmt.Errorf("error creating ssh client session: %v", err) } - return ws, nil + reader, err := sshSession.StdoutPipe() + if err != nil { + return nil, fmt.Errorf("error creating ssh session reader: %v", err) + } + + writer, err := sshSession.StdinPipe() + if err != nil { + return nil, fmt.Errorf("error creating ssh session writer: %v", err) + } + + return &SSHSession{Session: sshSession, reader: reader, writer: writer}, nil } From 6cd0aa7a90c9737a8c9c0f113a301ebaf1264e5e Mon Sep 17 00:00:00 2001 From: Jose Garcia Date: Tue, 6 Jul 2021 09:12:43 -0400 Subject: [PATCH 04/68] Working albeit not imperfect implementation --- adapter.go | 100 --------------------------------------- api.go | 55 +++++++++++----------- client.go | 115 +++++++++++++++++++++++++++++++++++++++++---- example/main.go | 109 +++++++++++++++++++++++++++++++++++++------ liveshare.go | 67 -------------------------- port_forwarder.go | 63 +++++++++++++++++++++++++ rpc.go | 97 ++++++++++++++++++++++++++++++++++++++ rpc_test.go | 27 +++++++++++ server.go | 52 +++++++++++++++++++++ session.go | 40 +++++++++++----- ssh.go | 97 +++++++++++++------------------------- terminal.go | 116 ++++++++++++++++++++++++++++++++++++++++++++++ websocket.go | 105 +++++++++++++++++++++++++++++++++++++++++ 13 files changed, 746 insertions(+), 297 deletions(-) delete mode 100644 adapter.go create mode 100644 port_forwarder.go create mode 100644 rpc.go create mode 100644 rpc_test.go create mode 100644 server.go create mode 100644 terminal.go create mode 100644 websocket.go diff --git a/adapter.go b/adapter.go deleted file mode 100644 index fb3424734..000000000 --- a/adapter.go +++ /dev/null @@ -1,100 +0,0 @@ -package liveshare - -import ( - "errors" - "io" - "net" - "sync" - "time" - - "github.com/gorilla/websocket" -) - -type Adapter struct { - conn *websocket.Conn - readMutex sync.Mutex - writeMutex sync.Mutex - reader io.Reader -} - -func NewAdapter(conn *websocket.Conn) *Adapter { - return &Adapter{ - conn: conn, - } -} - -func (a *Adapter) Read(b []byte) (int, error) { - // Read() can be called concurrently, and we mutate some internal state here - a.readMutex.Lock() - defer a.readMutex.Unlock() - - if a.reader == nil { - messageType, reader, err := a.conn.NextReader() - if err != nil { - return 0, err - } - - if messageType != websocket.BinaryMessage { - return 0, errors.New("unexpected websocket message type") - } - - a.reader = reader - } - - bytesRead, err := a.reader.Read(b) - if err != nil { - a.reader = nil - - // EOF for the current Websocket frame, more will probably come so.. - if err == io.EOF { - // .. we must hide this from the caller since our semantics are a - // stream of bytes across many frames - err = nil - } - } - - return bytesRead, err -} - -func (a *Adapter) Write(b []byte) (int, error) { - a.writeMutex.Lock() - defer a.writeMutex.Unlock() - - nextWriter, err := a.conn.NextWriter(websocket.BinaryMessage) - if err != nil { - return 0, err - } - - bytesWritten, err := nextWriter.Write(b) - nextWriter.Close() - - return bytesWritten, err -} - -func (a *Adapter) Close() error { - return a.conn.Close() -} - -func (a *Adapter) LocalAddr() net.Addr { - return a.conn.LocalAddr() -} - -func (a *Adapter) RemoteAddr() net.Addr { - return a.conn.RemoteAddr() -} - -func (a *Adapter) SetDeadline(t time.Time) error { - if err := a.SetReadDeadline(t); err != nil { - return err - } - - return a.SetWriteDeadline(t) -} - -func (a *Adapter) SetReadDeadline(t time.Time) error { - return a.conn.SetReadDeadline(t) -} - -func (a *Adapter) SetWriteDeadline(t time.Time) error { - return a.conn.SetWriteDeadline(t) -} diff --git a/api.go b/api.go index d8a4bebbd..a101823df 100644 --- a/api.go +++ b/api.go @@ -5,21 +5,20 @@ import ( "fmt" "io/ioutil" "net/http" - "net/http/httputil" "strings" ) -type API struct { - Configuration *Configuration - HttpClient *http.Client - ServiceURI string - WorkspaceID string +type api struct { + client *Client + httpClient *http.Client + serviceURI string + workspaceID string } -func NewAPI(configuration *Configuration) *API { - serviceURI := configuration.LiveShareEndpoint - if !strings.HasSuffix(configuration.LiveShareEndpoint, "/") { - serviceURI = configuration.LiveShareEndpoint + "/" +func newAPI(client *Client) *api { + serviceURI := client.liveShare.Configuration.LiveShareEndpoint + if !strings.HasSuffix(client.liveShare.Configuration.LiveShareEndpoint, "/") { + serviceURI = client.liveShare.Configuration.LiveShareEndpoint + "/" } if !strings.Contains(serviceURI, "api/v1.2") { @@ -28,10 +27,10 @@ func NewAPI(configuration *Configuration) *API { serviceURI = strings.TrimSuffix(serviceURI, "/") - return &API{configuration, &http.Client{}, serviceURI, strings.ToUpper(configuration.WorkspaceID)} + return &api{client, &http.Client{}, serviceURI, strings.ToUpper(client.liveShare.Configuration.WorkspaceID)} } -type WorkspaceAccessResponse struct { +type workspaceAccessResponse struct { SessionToken string `json:"sessionToken"` CreatedAt string `json:"createdAt"` UpdatedAt string `json:"updatedAt"` @@ -51,8 +50,8 @@ type WorkspaceAccessResponse struct { ID string `json:"id"` } -func (a *API) WorkspaceAccess() (*WorkspaceAccessResponse, error) { - url := fmt.Sprintf("%s/workspace/%s/user", a.ServiceURI, a.WorkspaceID) +func (a *api) workspaceAccess() (*workspaceAccessResponse, error) { + url := fmt.Sprintf("%s/workspace/%s/user", a.serviceURI, a.workspaceID) fmt.Println(url) req, err := http.NewRequest(http.MethodPut, url, nil) @@ -61,7 +60,7 @@ func (a *API) WorkspaceAccess() (*WorkspaceAccessResponse, error) { } a.setDefaultHeaders(req) - resp, err := a.HttpClient.Do(req) + resp, err := a.httpClient.Do(req) if err != nil { return nil, fmt.Errorf("error making request: %v", err) } @@ -71,23 +70,21 @@ func (a *API) WorkspaceAccess() (*WorkspaceAccessResponse, error) { return nil, fmt.Errorf("error reading response body: %v", err) } - d, _ := httputil.DumpResponse(resp, true) - fmt.Println(string(d)) - var workspaceAccessResponse WorkspaceAccessResponse - if err := json.Unmarshal(b, &workspaceAccessResponse); err != nil { + var response workspaceAccessResponse + if err := json.Unmarshal(b, &response); err != nil { return nil, fmt.Errorf("error unmarshaling response into json: %v", err) } - return &workspaceAccessResponse, nil + return &response, nil } -func (a *API) setDefaultHeaders(req *http.Request) { - req.Header.Set("Authorization", "Bearer "+a.Configuration.Token) +func (a *api) setDefaultHeaders(req *http.Request) { + req.Header.Set("Authorization", "Bearer "+a.client.liveShare.Configuration.Token) req.Header.Set("Cache-Control", "no-cache") req.Header.Set("Content-Type", "application/json") } -type WorkspaceInfoResponse struct { +type workspaceInfoResponse struct { CreatedAt string `json:"createdAt"` UpdatedAt string `json:"updatedAt"` Name string `json:"name"` @@ -106,8 +103,8 @@ type WorkspaceInfoResponse struct { ID string `json:"id"` } -func (a *API) WorkspaceInfo() (*WorkspaceInfoResponse, error) { - url := fmt.Sprintf("%s/workspace/%s", a.ServiceURI, a.WorkspaceID) +func (a *api) workspaceInfo() (*workspaceInfoResponse, error) { + url := fmt.Sprintf("%s/workspace/%s", a.serviceURI, a.workspaceID) req, err := http.NewRequest(http.MethodGet, url, nil) if err != nil { @@ -115,7 +112,7 @@ func (a *API) WorkspaceInfo() (*WorkspaceInfoResponse, error) { } a.setDefaultHeaders(req) - resp, err := a.HttpClient.Do(req) + resp, err := a.httpClient.Do(req) if err != nil { return nil, fmt.Errorf("error making request: %v", err) } @@ -125,10 +122,10 @@ func (a *API) WorkspaceInfo() (*WorkspaceInfoResponse, error) { return nil, fmt.Errorf("error reading response body: %v", err) } - var workspaceInfoResponse WorkspaceInfoResponse - if err := json.Unmarshal(b, &workspaceInfoResponse); err != nil { + var response workspaceInfoResponse + if err := json.Unmarshal(b, &response); err != nil { return nil, fmt.Errorf("error unmarshaling response into json: %v", err) } - return &workspaceInfoResponse, nil + return &response, nil } diff --git a/client.go b/client.go index 0a89a125c..88c91ba01 100644 --- a/client.go +++ b/client.go @@ -3,27 +3,122 @@ package liveshare import ( "context" "fmt" + "log" + + "golang.org/x/crypto/ssh" ) type Client struct { - Configuration *Configuration - SSHSession *SSHSession + liveShare *LiveShare + session *session + sshSession *sshSession + rpc *rpc } -func NewClient(configuration *Configuration) *Client { - return &Client{Configuration: configuration} +// NewClient is a function ... +func (l *LiveShare) NewClient() *Client { + return &Client{liveShare: l} } -func (c *Client) Join(ctx context.Context) error { - session, err := GetSession(ctx, c.Configuration) - if err != nil { - return fmt.Errorf("error getting session: %v", err) +func (c *Client) Join(ctx context.Context) (err error) { + api := newAPI(c) + + c.session = newSession(api) + if err := c.session.init(ctx); err != nil { + return fmt.Errorf("error creating session: %v", err) } - c.SSHSession, err = NewSSH(session).NewSession() - if err != nil { + websocket := newWebsocket(c.session) + if err := websocket.connect(ctx); err != nil { + return fmt.Errorf("error connecting websocket: %v", err) + } + + c.sshSession = newSSH(c.session, websocket) + if err := c.sshSession.connect(ctx); err != nil { return fmt.Errorf("error connecting to ssh session: %v", err) } + c.rpc = newRPC(c.sshSession) + c.rpc.connect(ctx) + + _, err = c.joinWorkspace(ctx) + if err != nil { + return fmt.Errorf("error joining liveshare workspace: %v", err) + } + return nil } + +func (c *Client) hasJoined() bool { + return c.sshSession != nil && c.rpc != nil +} + +type clientCapabilities struct { + IsNonInteractive bool `json:"isNonInteractive"` +} + +type joinWorkspaceArgs struct { + ID string `json:"id"` + ConnectionMode string `json:"connectionMode"` + JoiningUserSessionToken string `json:"joiningUserSessionToken"` + ClientCapabilities clientCapabilities `json:"clientCapabilities"` +} + +type joinWorkspaceResult struct { + SessionNumber int `json:"sessionNumber"` +} + +func (c *Client) joinWorkspace(ctx context.Context) (*joinWorkspaceResult, error) { + args := joinWorkspaceArgs{ + ID: c.session.workspaceInfo.ID, + ConnectionMode: "local", + JoiningUserSessionToken: c.session.workspaceAccess.SessionToken, + ClientCapabilities: clientCapabilities{ + IsNonInteractive: false, + }, + } + + var result joinWorkspaceResult + if err := c.rpc.do(ctx, "workspace.joinWorkspace", &args, &result); err != nil { + return nil, fmt.Errorf("error making workspace.joinWorkspace call: %v", err) + } + + return &result, nil +} + +func (c *Client) openStreamingChannel(ctx context.Context, streamName, condition string) (ssh.Channel, error) { + args := getStreamArgs{streamName, condition} + var streamID string + if err := c.rpc.do(ctx, "streamManager.getStream", args, &streamID); err != nil { + return nil, fmt.Errorf("error getting stream id: %v", err) + } + + channel, reqs, err := c.sshSession.conn.OpenChannel("session", nil) + if err != nil { + return nil, fmt.Errorf("error opening ssh channel for transport: %v", err) + } + go c.processChannelRequests(ctx, reqs) + + requestType := fmt.Sprintf("stream-transport-%s", streamID) + acked, err := channel.SendRequest(requestType, true, nil) + if err != nil { + return nil, fmt.Errorf("error sending channel request: %v", err) + } + fmt.Println("ACKED: ", acked) + + return channel, nil +} + +func (c *Client) processChannelRequests(ctx context.Context, reqs <-chan *ssh.Request) { + for { + select { + case req := <-reqs: + if req != nil { + fmt.Printf("REQ: %+v\n\n", req) + log.Println("streaming channel requests are not supported") + } + case <-ctx.Done(): + break + } + } +} diff --git a/example/main.go b/example/main.go index bbf540d7c..b1fda8d32 100644 --- a/example/main.go +++ b/example/main.go @@ -1,36 +1,117 @@ package main import ( + "bufio" "context" + "flag" "fmt" "log" + "os" - "github.com/josebalius/go-liveshare" + "github.com/github/go-liveshare" ) +var workspaceIdFlag = flag.String("w", "", "workspace session id") + +func init() { + flag.Parse() +} + func main() { liveShare, err := liveshare.New( - liveshare.WithWorkspaceID(""), - liveshare.WithToken(""), + liveshare.WithWorkspaceID(*workspaceIdFlag), + liveshare.WithToken(os.Getenv("CODESPACE_TOKEN")), ) if err != nil { log.Fatal(fmt.Errorf("error creating liveshare: %v", err)) } - if err := liveShare.Connect(context.Background()); err != nil { - log.Fatal(fmt.Errorf("error connecting to liveshare: %v", err)) + ctx := context.Background() + liveShareClient := liveShare.NewClient() + if err := liveShareClient.Join(ctx); err != nil { + log.Fatal(fmt.Errorf("error joining liveshare with client: %v", err)) } - terminal := liveShare.NewTerminal() - - cmd := terminal.NewCommand( - "/home/codespace/workspace", - "docker ps -aq --filter label=Type=codespaces --filter status=running", - ) - output, err := cmd.Run(context.Background()) + terminal, err := liveShareClient.NewTerminal() if err != nil { - log.Fatal(fmt.Errorf("error starting ssh server with liveshare: %v", err)) + log.Fatal(fmt.Errorf("error creating liveshare terminal")) } - fmt.Println(string(output)) + containerID, err := getContainerID(ctx, terminal) + if err != nil { + log.Fatal(fmt.Errorf("error getting container id: %v", err)) + } + + if err := setupSSH(ctx, terminal, containerID); err != nil { + log.Fatal(fmt.Errorf("error setting up ssh: %v", err)) + } + + fmt.Println("Starting server...") + + server, err := liveShareClient.NewServer() + if err != nil { + log.Fatal(fmt.Errorf("error creating server: %v", err)) + } + + fmt.Println("Starting sharing...") + if err := server.StartSharing(ctx, "sshd", 2222); err != nil { + log.Fatal(fmt.Errorf("error server sharing: %v", err)) + } + + portForwarder := liveshare.NewLocalPortForwarder(liveShareClient, server, 2222) + + fmt.Println("Listening on port 2222") + if err := portForwarder.Start(ctx); err != nil { + log.Fatal(fmt.Errorf("error forwarding port: %v", err)) + } +} + +func setupSSH(ctx context.Context, terminal *liveshare.Terminal, containerID string) error { + cmd := terminal.NewCommand( + "/", + fmt.Sprintf("/usr/bin/docker exec -t %s /bin/bash -c \"echo -e \\\"testpwd1\\ntestpwd1\\n\\\" | sudo passwd codespace;/usr/local/share/ssh-init.sh\"", containerID), + ) + stream, err := cmd.Run(ctx) + if err != nil { + return fmt.Errorf("error running command: %v", err) + } + + scanner := bufio.NewScanner(stream) + scanner.Scan() + + fmt.Println("> Debug:", scanner.Text()) + if err := scanner.Err(); err != nil { + return fmt.Errorf("error scanning stream: %v", err) + } + + if err := stream.Close(); err != nil { + return fmt.Errorf("error closing stream: %v", err) + } + + return nil +} + +func getContainerID(ctx context.Context, terminal *liveshare.Terminal) (string, error) { + cmd := terminal.NewCommand( + "/", + "/usr/bin/docker ps -aq --filter label=Type=codespaces --filter status=running", + ) + stream, err := cmd.Run(ctx) + if err != nil { + return "", fmt.Errorf("error running command: %v", err) + } + + scanner := bufio.NewScanner(stream) + scanner.Scan() + + containerID := scanner.Text() + if err := scanner.Err(); err != nil { + return "", fmt.Errorf("error scanning stream: %v", err) + } + + if err := stream.Close(); err != nil { + return "", fmt.Errorf("error closing stream: %v", err) + } + + return containerID, nil } diff --git a/liveshare.go b/liveshare.go index 174eac20f..38222957a 100644 --- a/liveshare.go +++ b/liveshare.go @@ -1,16 +1,11 @@ package liveshare import ( - "context" "fmt" - "net/rpc" ) type LiveShare struct { Configuration *Configuration - - workspaceClient *Client - terminal *Terminal } func New(opts ...Option) (*LiveShare, error) { @@ -28,65 +23,3 @@ func New(opts ...Option) (*LiveShare, error) { return &LiveShare{Configuration: configuration}, nil } - -func (l *LiveShare) Connect(ctx context.Context) error { - l.workspaceClient = NewClient(l.Configuration) - if err := l.workspaceClient.Join(ctx); err != nil { - return fmt.Errorf("error joining with workspace client: %v", err) - } - - return nil -} - -type Terminal struct { - WorkspaceClient *Client - RPCClient *rpc.Client -} - -func (l *LiveShare) NewTerminal() *Terminal { - return &Terminal{ - WorkspaceClient: l.workspaceClient, - RPCClient: rpc.NewClient(l.workspaceClient.SSHSession), - } -} - -type TerminalCommand struct { - Terminal *Terminal - Cwd string - Cmd string -} - -func (t *Terminal) NewCommand(cwd, cmd string) TerminalCommand { - return TerminalCommand{t, cwd, cmd} -} - -type RunArgs struct { - Name string - Rows, Cols int - App string - Cwd string - CommandLine []string - ReadOnlyForGuests bool -} - -func (t TerminalCommand) Run(ctx context.Context) ([]byte, error) { - args := RunArgs{ - Name: "RunCommand", - Rows: 10, - Cols: 80, - App: "/bin/bash", - Cwd: t.Cwd, - CommandLine: []string{"-c", t.Cmd}, - ReadOnlyForGuests: false, - } - - var output []byte - runCall := t.Terminal.RPCClient.Go("terminal.startAsync", &args, &output, nil) - - runReply := <-runCall.Done - if runReply.Error != nil { - return nil, fmt.Errorf("error startAsync operation: %v", runReply.Error) - } - fmt.Printf("%+v\n\n", runReply) - return output, nil -} diff --git a/port_forwarder.go b/port_forwarder.go new file mode 100644 index 000000000..0ae5e1916 --- /dev/null +++ b/port_forwarder.go @@ -0,0 +1,63 @@ +package liveshare + +import ( + "context" + "fmt" + "io" + "log" + "net" + "strconv" + + "golang.org/x/crypto/ssh" +) + +type LocalPortForwarder struct { + client *Client + server *Server + port int + channels []ssh.Channel +} + +func NewLocalPortForwarder(client *Client, server *Server, port int) *LocalPortForwarder { + return &LocalPortForwarder{client, server, port, []ssh.Channel{}} +} + +func (l *LocalPortForwarder) Start(ctx context.Context) error { + ln, err := net.Listen("tcp", ":"+strconv.Itoa(l.port)) + if err != nil { + return fmt.Errorf("error listening on tcp port: %v", err) + } + + for { + conn, err := ln.Accept() + if err != nil { + return fmt.Errorf("error accepting incoming connection: %v", err) + } + + go l.handleConnection(ctx, conn) + } + + // clean up after ourselves + + return nil +} + +func (l *LocalPortForwarder) handleConnection(ctx context.Context, conn net.Conn) { + channel, err := l.client.openStreamingChannel(ctx, l.server.streamName, l.server.streamCondition) + if err != nil { + log.Println("errrr handle Connect") + log.Println(err) // TODO(josebalius) handle this somehow + } + l.channels = append(l.channels, channel) + + copyConn := func(writer io.Writer, reader io.Reader) { + _, err := io.Copy(writer, reader) + if err != nil { + log.Println("errrrr copyConn") + log.Println(err) //TODO(josebalius): handle this somehow + } + } + + go copyConn(conn, channel) + go copyConn(channel, conn) +} diff --git a/rpc.go b/rpc.go new file mode 100644 index 000000000..e90f71ba6 --- /dev/null +++ b/rpc.go @@ -0,0 +1,97 @@ +package liveshare + +import ( + "context" + "encoding/json" + "fmt" + "io" + "sync" + + "github.com/sourcegraph/jsonrpc2" +) + +type rpc struct { + *jsonrpc2.Conn + conn io.ReadWriteCloser + handler *rpcHandler +} + +func newRPC(conn io.ReadWriteCloser) *rpc { + return &rpc{conn: conn, handler: newRPCHandler()} +} + +func (r *rpc) connect(ctx context.Context) { + stream := jsonrpc2.NewBufferedStream(r.conn, jsonrpc2.VSCodeObjectCodec{}) + r.Conn = jsonrpc2.NewConn(ctx, stream, r.handler) +} + +func (r *rpc) do(ctx context.Context, method string, args interface{}, result interface{}) error { + b, _ := json.Marshal(args) + fmt.Println("rpc sent: ", method, string(b)) + waiter, err := r.Conn.DispatchCall(ctx, method, args) + if err != nil { + return fmt.Errorf("error on dispatch call: %v", err) + } + + // caller doesn't care about result, so lets ignore it + if result == nil { + return nil + } + + return waiter.Wait(ctx, result) +} + +type rpcHandler struct { + mutex sync.RWMutex + eventHandlers map[string][]chan *jsonrpc2.Request +} + +func newRPCHandler() *rpcHandler { + return &rpcHandler{ + eventHandlers: make(map[string][]chan *jsonrpc2.Request), + } +} + +func (r *rpcHandler) registerEventHandler(eventMethod string) <-chan *jsonrpc2.Request { + r.mutex.Lock() + defer r.mutex.Unlock() + + ch := make(chan *jsonrpc2.Request) + if _, ok := r.eventHandlers[eventMethod]; !ok { + r.eventHandlers[eventMethod] = []chan *jsonrpc2.Request{ch} + } else { + r.eventHandlers[eventMethod] = append(r.eventHandlers[eventMethod], ch) + } + return ch +} + +func (r *rpcHandler) Handle(ctx context.Context, conn *jsonrpc2.Conn, req *jsonrpc2.Request) { + r.mutex.Lock() + defer r.mutex.Unlock() + + fmt.Println("REQUEST") + fmt.Println("Method:", req.Method) + b, _ := req.MarshalJSON() + fmt.Println(string(b)) + fmt.Println("----") + fmt.Printf("%+v\n\n", r.eventHandlers) + if handlers, ok := r.eventHandlers[req.Method]; ok { + go func() { + for _, handler := range handlers { + select { + case handler <- req: + case <-ctx.Done(): + break + } + } + + r.eventHandlers[req.Method] = []chan *jsonrpc2.Request{} + }() + } else { + fmt.Println("UNHANDLED REQUEST") + fmt.Println("Method:", req.Method) + b, _ := req.MarshalJSON() + fmt.Println(string(b)) + fmt.Println("----") + } +} diff --git a/rpc_test.go b/rpc_test.go new file mode 100644 index 000000000..d16b32a4f --- /dev/null +++ b/rpc_test.go @@ -0,0 +1,27 @@ +package liveshare + +import ( + "context" + "testing" + "time" + + "github.com/sourcegraph/jsonrpc2" +) + +func TestRPCHandlerEvents(t *testing.T) { + rpcHandler := newRPCHandler() + eventCh := rpcHandler.registerEventHandler("somethingHappened") + go func() { + time.Sleep(1 * time.Second) + rpcHandler.Handle(context.Background(), nil, &jsonrpc2.Request{Method: "somethingHappened"}) + }() + ctx, _ := context.WithDeadline(context.Background(), time.Now().Add(5*time.Second)) + select { + case event := <-eventCh: + if event.Method != "somethingHappened" { + t.Error("event.Method is not the expect value") + } + case <-ctx.Done(): + t.Error("Test time out") + } +} diff --git a/server.go b/server.go new file mode 100644 index 000000000..65e03d584 --- /dev/null +++ b/server.go @@ -0,0 +1,52 @@ +package liveshare + +import ( + "context" + "errors" + "fmt" + "strconv" +) + +type Server struct { + client *Client + port int + streamName, streamCondition string +} + +func (c *Client) NewServer() (*Server, error) { + if !c.hasJoined() { + return nil, errors.New("LiveShareClient must join before creating server") + } + + return &Server{client: c}, nil +} + +type serverSharingResponse struct { + SourcePort int `json:"sourcePort"` + DestinationPort int `json:"destinationPort"` + SessionName string `json:"sessionName"` + StreamName string `json:"streamName"` + StreamCondition string `json:"streamCondition"` + BrowseURL string `json:"browseUrl"` + IsPublic bool `json:"isPublic"` + IsTCPServerConnectionEstablished bool `json:"isTCPServerConnectionEstablished"` + HasTSLHandshakePassed bool `json:"hasTSLHandshakePassed"` +} + +func (s *Server) StartSharing(ctx context.Context, protocol string, port int) error { + s.port = port + + sharingStarted := s.client.rpc.handler.registerEventHandler("serverSharing.sharingStarted") + var response serverSharingResponse + if err := s.client.rpc.do(ctx, "serverSharing.startSharing", []interface{}{ + port, protocol, fmt.Sprintf("http://localhost:%s", strconv.Itoa(port)), + }, &response); err != nil { + return err + } + <-sharingStarted + + s.streamName = response.StreamName + s.streamCondition = response.StreamCondition + + return nil +} diff --git a/session.go b/session.go index 24a284ef2..d0492a10e 100644 --- a/session.go +++ b/session.go @@ -3,42 +3,58 @@ package liveshare import ( "context" "fmt" + "net/url" + "strings" "golang.org/x/sync/errgroup" ) -type Session struct { - WorkspaceAccess *WorkspaceAccessResponse - WorkspaceInfo *WorkspaceInfoResponse +type session struct { + api *api + + workspaceAccess *workspaceAccessResponse + workspaceInfo *workspaceInfoResponse } -func GetSession(ctx context.Context, configuration *Configuration) (*Session, error) { - api := NewAPI(configuration) - session := new(Session) +func newSession(api *api) *session { + return &session{api: api} +} +func (s *session) init(ctx context.Context) error { g, ctx := errgroup.WithContext(ctx) g.Go(func() error { - workspaceAccess, err := api.WorkspaceAccess() + workspaceAccess, err := s.api.workspaceAccess() if err != nil { return fmt.Errorf("error getting workspace access: %v", err) } - session.WorkspaceAccess = workspaceAccess + s.workspaceAccess = workspaceAccess return nil }) g.Go(func() error { - workspaceInfo, err := api.WorkspaceInfo() + workspaceInfo, err := s.api.workspaceInfo() if err != nil { return fmt.Errorf("error getting workspace info: %v", err) } - session.WorkspaceInfo = workspaceInfo + s.workspaceInfo = workspaceInfo return nil }) if err := g.Wait(); err != nil { - return nil, err + return err } - return session, nil + return nil +} + +// Reference: +// https://github.com/Azure/azure-relay-node/blob/7b57225365df3010163bf4b9e640868a02737eb6/hyco-ws/index.js#L107-L137 +func (s *session) relayURI(action string) string { + relaySas := url.QueryEscape(s.workspaceAccess.RelaySas) + relayURI := s.workspaceAccess.RelayLink + relayURI = strings.Replace(relayURI, "sb:", "wss:", -1) + relayURI = strings.Replace(relayURI, ".net/", ".net:443/$hc/", 1) + relayURI = relayURI + "?sb-hc-action=" + action + "&sb-hc-token=" + relaySas + return relayURI } diff --git a/ssh.go b/ssh.go index 51906290e..9ae32ed7c 100644 --- a/ssh.go +++ b/ssh.go @@ -1,101 +1,68 @@ package liveshare import ( + "context" "fmt" "io" "net" - "net/url" - "strings" "time" - "github.com/gorilla/websocket" "golang.org/x/crypto/ssh" ) -type SSH struct { - Session *Session -} - -func NewSSH(session *Session) *SSH { - return &SSH{ - Session: session, - } -} - -// Reference: -// https://github.com/Azure/azure-relay-node/blob/7b57225365df3010163bf4b9e640868a02737eb6/hyco-ws/index.js#L107-L137 -func (s *SSH) relayURI(action string) string { - relaySas := url.QueryEscape(s.Session.WorkspaceAccess.RelaySas) - relayURI := s.Session.WorkspaceAccess.RelayLink - relayURI = strings.Replace(relayURI, "sb:", "wss:", -1) - relayURI = strings.Replace(relayURI, ".net/", ".net:443/$hc/", 1) - relayURI = relayURI + "?sb-hc-action=" + action + "&sb-hc-token=" + relaySas - return relayURI -} - -func (s *SSH) socketStream() (net.Conn, error) { - uri := s.relayURI("connect") - - ws, _, err := websocket.DefaultDialer.Dial(uri, nil) - if err != nil { - return nil, fmt.Errorf("error dialing websocket connection: %v", err) - } - - return NewAdapter(ws), nil -} - -type SSHSession struct { +type sshSession struct { *ssh.Session - reader io.Reader - writer io.Writer + session *session + socket net.Conn + conn ssh.Conn + reader io.Reader + writer io.Writer } -func (s SSHSession) Read(p []byte) (n int, err error) { - return s.reader.Read(p) +func newSSH(session *session, socket net.Conn) *sshSession { + return &sshSession{session: session, socket: socket} } -func (s SSHSession) Write(p []byte) (n int, err error) { - return s.writer.Write(p) -} - -func (s *SSH) NewSession() (*SSHSession, error) { - socketStream, err := s.socketStream() - if err != nil { - return nil, fmt.Errorf("error creating socket stream: %v", err) - } - +func (s *sshSession) connect(ctx context.Context) error { clientConfig := ssh.ClientConfig{ User: "", Auth: []ssh.AuthMethod{ - ssh.Password(s.Session.WorkspaceAccess.SessionToken), + ssh.Password(s.session.workspaceAccess.SessionToken), }, - HostKeyCallback: func(hostname string, remote net.Addr, key ssh.PublicKey) error { - // TODO(josebalius): implement - return nil - }, - Timeout: 10 * time.Second, + HostKeyAlgorithms: []string{"rsa-sha2-512", "rsa-sha2-256"}, + HostKeyCallback: ssh.InsecureIgnoreHostKey(), + Timeout: 10 * time.Second, } - sshClientConn, chans, reqs, err := ssh.NewClientConn(socketStream, "", &clientConfig) + sshClientConn, chans, reqs, err := ssh.NewClientConn(s.socket, "", &clientConfig) if err != nil { - return nil, fmt.Errorf("error creating ssh client connection: %v", err) + return fmt.Errorf("error creating ssh client connection: %v", err) } + s.conn = sshClientConn sshClient := ssh.NewClient(sshClientConn, chans, reqs) - sshSession, err := sshClient.NewSession() + s.Session, err = sshClient.NewSession() if err != nil { - return nil, fmt.Errorf("error creating ssh client session: %v", err) + return fmt.Errorf("error creating ssh client session: %v", err) } - reader, err := sshSession.StdoutPipe() + s.reader, err = s.Session.StdoutPipe() if err != nil { - return nil, fmt.Errorf("error creating ssh session reader: %v", err) + return fmt.Errorf("error creating ssh session reader: %v", err) } - writer, err := sshSession.StdinPipe() + s.writer, err = s.Session.StdinPipe() if err != nil { - return nil, fmt.Errorf("error creating ssh session writer: %v", err) + return fmt.Errorf("error creating ssh session writer: %v", err) } - return &SSHSession{Session: sshSession, reader: reader, writer: writer}, nil + return nil +} + +func (s *sshSession) Read(p []byte) (n int, err error) { + return s.reader.Read(p) +} + +func (s *sshSession) Write(p []byte) (n int, err error) { + return s.writer.Write(p) } diff --git a/terminal.go b/terminal.go new file mode 100644 index 000000000..631e75912 --- /dev/null +++ b/terminal.go @@ -0,0 +1,116 @@ +package liveshare + +import ( + "context" + "errors" + "fmt" + "io" + + "golang.org/x/crypto/ssh" +) + +type Terminal struct { + client *Client +} + +func (c *Client) NewTerminal() (*Terminal, error) { + if !c.hasJoined() { + return nil, errors.New("LiveShareClient must join before creating terminal") + } + + return &Terminal{ + client: c, + }, nil +} + +type TerminalCommand struct { + terminal *Terminal + cwd string + cmd string +} + +func (t *Terminal) NewCommand(cwd, cmd string) TerminalCommand { + return TerminalCommand{t, cwd, cmd} +} + +type runArgs struct { + Name string `json:"name"` + Rows int `json:"rows"` + Cols int `json:"cols"` + App string `json:"app"` + Cwd string `json:"cwd"` + CommandLine []string `json:"commandLine"` + ReadOnlyForGuests bool `json:"readOnlyForGuests"` +} + +type startTerminalResult struct { + ID int `json:"id"` + StreamName string `json:"streamName"` + StreamCondition string `json:"streamCondition"` + LocalPipeName string `json:"localPipeName"` + AppProcessID int `json:"appProcessId"` +} + +type getStreamArgs struct { + StreamName string `json:"streamName"` + Condition string `json:"condition"` +} + +type stopTerminalArgs struct { + ID int `json:"id"` +} + +func (t TerminalCommand) Run(ctx context.Context) (io.ReadCloser, error) { + args := runArgs{ + Name: "RunCommand", + Rows: 10, + Cols: 80, + App: "/bin/bash", + Cwd: t.cwd, + CommandLine: []string{"-c", t.cmd}, + ReadOnlyForGuests: false, + } + + terminalStarted := t.terminal.client.rpc.handler.registerEventHandler("terminal.terminalStarted") + var result startTerminalResult + if err := t.terminal.client.rpc.do(ctx, "terminal.startTerminal", &args, &result); err != nil { + return nil, fmt.Errorf("error making terminal.startTerminal call: %v", err) + } + <-terminalStarted + + channel, err := t.terminal.client.openStreamingChannel(ctx, result.StreamName, result.StreamCondition) + if err != nil { + return nil, fmt.Errorf("error opening streaming channel: %v", err) + } + + return t.newTerminalReadCloser(result.ID, channel), nil +} + +type terminalReadCloser struct { + terminalCommand TerminalCommand + terminalID int + channel ssh.Channel +} + +func (t TerminalCommand) newTerminalReadCloser(terminalID int, channel ssh.Channel) io.ReadCloser { + return terminalReadCloser{t, terminalID, channel} +} + +func (t terminalReadCloser) Read(b []byte) (int, error) { + return t.channel.Read(b) +} + +func (t terminalReadCloser) Close() error { + terminalStopped := t.terminalCommand.terminal.client.rpc.handler.registerEventHandler("terminal.terminalStopped") + if err := t.terminalCommand.terminal.client.rpc.do(context.Background(), "terminal.stopTerminal", []int{t.terminalID}, nil); err != nil { + return fmt.Errorf("error making terminal.stopTerminal call: %v", err) + } + + if err := t.channel.Close(); err != nil { + return fmt.Errorf("error closing channel: %v", err) + } + + <-terminalStopped + + return nil +} diff --git a/websocket.go b/websocket.go new file mode 100644 index 000000000..ae163e6e2 --- /dev/null +++ b/websocket.go @@ -0,0 +1,105 @@ +package liveshare + +import ( + "context" + "errors" + "io" + "net" + "sync" + "time" + + gorillawebsocket "github.com/gorilla/websocket" +) + +type websocket struct { + session *session + conn *gorillawebsocket.Conn + readMutex sync.Mutex + writeMutex sync.Mutex + reader io.Reader +} + +func newWebsocket(session *session) *websocket { + return &websocket{session: session} +} + +func (w *websocket) connect(ctx context.Context) error { + ws, _, err := gorillawebsocket.DefaultDialer.Dial(w.session.relayURI("connect"), nil) + if err != nil { + return err + } + w.conn = ws + return nil +} + +func (w *websocket) Read(b []byte) (int, error) { + w.readMutex.Lock() + defer w.readMutex.Unlock() + + if w.reader == nil { + messageType, reader, err := w.conn.NextReader() + if err != nil { + return 0, err + } + + if messageType != gorillawebsocket.BinaryMessage { + return 0, errors.New("unexpected websocket message type") + } + + w.reader = reader + } + + bytesRead, err := w.reader.Read(b) + if err != nil { + w.reader = nil + + if err == io.EOF { + err = nil + } + } + + return bytesRead, err +} + +func (w *websocket) Write(b []byte) (int, error) { + w.writeMutex.Lock() + defer w.writeMutex.Unlock() + + nextWriter, err := w.conn.NextWriter(gorillawebsocket.BinaryMessage) + if err != nil { + return 0, err + } + + bytesWritten, err := nextWriter.Write(b) + nextWriter.Close() + + return bytesWritten, err +} + +func (w *websocket) Close() error { + return w.conn.Close() +} + +func (w *websocket) LocalAddr() net.Addr { + return w.conn.LocalAddr() +} + +func (w *websocket) RemoteAddr() net.Addr { + return w.conn.RemoteAddr() +} + +func (w *websocket) SetDeadline(t time.Time) error { + if err := w.SetReadDeadline(t); err != nil { + return err + } + + return w.SetWriteDeadline(t) +} + +func (w *websocket) SetReadDeadline(t time.Time) error { + return w.conn.SetReadDeadline(t) +} + +func (w *websocket) SetWriteDeadline(t time.Time) error { + return w.conn.SetWriteDeadline(t) +} From 04a6383ccb778d0683bf63a113ac79b74d6d2b2b Mon Sep 17 00:00:00 2001 From: Jose Garcia Date: Wed, 7 Jul 2021 08:00:01 -0400 Subject: [PATCH 05/68] Tidy up go.mod --- example/main.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/example/main.go b/example/main.go index b1fda8d32..e9347bd14 100644 --- a/example/main.go +++ b/example/main.go @@ -7,6 +7,7 @@ import ( "fmt" "log" "os" + "time" "github.com/github/go-liveshare" ) @@ -88,6 +89,8 @@ func setupSSH(ctx context.Context, terminal *liveshare.Terminal, containerID str return fmt.Errorf("error closing stream: %v", err) } + time.Sleep(2 * time.Second) + return nil } From 53fd96d22ec4443ecd98b8b645f7ea0eee6e6d31 Mon Sep 17 00:00:00 2001 From: Jose Garcia Date: Wed, 14 Jul 2021 20:47:06 -0400 Subject: [PATCH 06/68] Some polish and module replacement --- api.go | 1 - client.go | 7 ++----- config.go | 56 ---------------------------------------------------- liveshare.go | 52 ++++++++++++++++++++++++++++++++++++++++++++++++ rpc.go | 15 +------------- server.go | 2 -- 6 files changed, 55 insertions(+), 78 deletions(-) delete mode 100644 config.go diff --git a/api.go b/api.go index a101823df..55b6e6e93 100644 --- a/api.go +++ b/api.go @@ -52,7 +52,6 @@ type workspaceAccessResponse struct { func (a *api) workspaceAccess() (*workspaceAccessResponse, error) { url := fmt.Sprintf("%s/workspace/%s/user", a.serviceURI, a.workspaceID) - fmt.Println(url) req, err := http.NewRequest(http.MethodPut, url, nil) if err != nil { diff --git a/client.go b/client.go index 88c91ba01..0af904b92 100644 --- a/client.go +++ b/client.go @@ -3,7 +3,6 @@ package liveshare import ( "context" "fmt" - "log" "golang.org/x/crypto/ssh" ) @@ -100,11 +99,10 @@ func (c *Client) openStreamingChannel(ctx context.Context, streamName, condition go c.processChannelRequests(ctx, reqs) requestType := fmt.Sprintf("stream-transport-%s", streamID) - acked, err := channel.SendRequest(requestType, true, nil) + _, err = channel.SendRequest(requestType, true, nil) if err != nil { return nil, fmt.Errorf("error sending channel request: %v", err) } - fmt.Println("ACKED: ", acked) return channel, nil } @@ -114,8 +112,7 @@ func (c *Client) processChannelRequests(ctx context.Context, reqs <-chan *ssh.Re select { case req := <-reqs: if req != nil { - fmt.Printf("REQ: %+v\n\n", req) - log.Println("streaming channel requests are not supported") + // TODO(josebalius): Handle } case <-ctx.Done(): break diff --git a/config.go b/config.go deleted file mode 100644 index 74eb5b178..000000000 --- a/config.go +++ /dev/null @@ -1,56 +0,0 @@ -package liveshare - -import ( - "errors" - "strings" -) - -type Option func(configuration *Configuration) error - -func WithWorkspaceID(id string) Option { - return func(configuration *Configuration) error { - configuration.WorkspaceID = id - return nil - } -} - -func WithLiveShareEndpoint(liveShareEndpoint string) Option { - return func(configuration *Configuration) error { - configuration.LiveShareEndpoint = liveShareEndpoint - return nil - } -} - -func WithToken(token string) Option { - return func(configuration *Configuration) error { - configuration.Token = token - return nil - } -} - -type Configuration struct { - WorkspaceID, LiveShareEndpoint, Token string -} - -func NewConfiguration() *Configuration { - return &Configuration{ - LiveShareEndpoint: "https://prod.liveshare.vsengsaas.visualstudio.com", - } -} - -func (c *Configuration) Validate() error { - errs := []string{} - if c.WorkspaceID == "" { - errs = append(errs, "WorkspaceID is required") - } - - if c.Token == "" { - errs = append(errs, "Token is required") - } - - if len(errs) > 0 { - return errors.New(strings.Join(errs, ", ")) - } - - return nil -} diff --git a/liveshare.go b/liveshare.go index 38222957a..3c4be5c05 100644 --- a/liveshare.go +++ b/liveshare.go @@ -1,7 +1,9 @@ package liveshare import ( + "errors" "fmt" + "strings" ) type LiveShare struct { @@ -23,3 +25,53 @@ func New(opts ...Option) (*LiveShare, error) { return &LiveShare{Configuration: configuration}, nil } + +type Option func(configuration *Configuration) error + +func WithWorkspaceID(id string) Option { + return func(configuration *Configuration) error { + configuration.WorkspaceID = id + return nil + } +} + +func WithLiveShareEndpoint(liveShareEndpoint string) Option { + return func(configuration *Configuration) error { + configuration.LiveShareEndpoint = liveShareEndpoint + return nil + } +} + +func WithToken(token string) Option { + return func(configuration *Configuration) error { + configuration.Token = token + return nil + } +} + +type Configuration struct { + WorkspaceID, LiveShareEndpoint, Token string +} + +func NewConfiguration() *Configuration { + return &Configuration{ + LiveShareEndpoint: "https://prod.liveshare.vsengsaas.visualstudio.com", + } +} + +func (c *Configuration) Validate() error { + errs := []string{} + if c.WorkspaceID == "" { + errs = append(errs, "WorkspaceID is required") + } + + if c.Token == "" { + errs = append(errs, "Token is required") + } + + if len(errs) > 0 { + return errors.New(strings.Join(errs, ", ")) + } + + return nil +} diff --git a/rpc.go b/rpc.go index e90f71ba6..de427cda9 100644 --- a/rpc.go +++ b/rpc.go @@ -2,7 +2,6 @@ package liveshare import ( "context" - "encoding/json" "fmt" "io" "sync" @@ -26,8 +25,6 @@ func (r *rpc) connect(ctx context.Context) { } func (r *rpc) do(ctx context.Context, method string, args interface{}, result interface{}) error { - b, _ := json.Marshal(args) - fmt.Println("rpc sent: ", method, string(b)) waiter, err := r.Conn.DispatchCall(ctx, method, args) if err != nil { return fmt.Errorf("error on dispatch call: %v", err) @@ -69,12 +66,6 @@ func (r *rpcHandler) Handle(ctx context.Context, conn *jsonrpc2.Conn, req *jsonr r.mutex.Lock() defer r.mutex.Unlock() - fmt.Println("REQUEST") - fmt.Println("Method:", req.Method) - b, _ := req.MarshalJSON() - fmt.Println(string(b)) - fmt.Println("----") - fmt.Printf("%+v\n\n", r.eventHandlers) if handlers, ok := r.eventHandlers[req.Method]; ok { go func() { for _, handler := range handlers { @@ -88,10 +79,6 @@ func (r *rpcHandler) Handle(ctx context.Context, conn *jsonrpc2.Conn, req *jsonr r.eventHandlers[req.Method] = []chan *jsonrpc2.Request{} }() } else { - fmt.Println("UNHANDLED REQUEST") - fmt.Println("Method:", req.Method) - b, _ := req.MarshalJSON() - fmt.Println(string(b)) - fmt.Println("----") + // TODO(josebalius): Handle } } diff --git a/server.go b/server.go index 65e03d584..b0f3996c9 100644 --- a/server.go +++ b/server.go @@ -36,14 +36,12 @@ type serverSharingResponse struct { func (s *Server) StartSharing(ctx context.Context, protocol string, port int) error { s.port = port - sharingStarted := s.client.rpc.handler.registerEventHandler("serverSharing.sharingStarted") var response serverSharingResponse if err := s.client.rpc.do(ctx, "serverSharing.startSharing", []interface{}{ port, protocol, fmt.Sprintf("http://localhost:%s", strconv.Itoa(port)), }, &response); err != nil { return err } - <-sharingStarted s.streamName = response.StreamName s.streamCondition = response.StreamCondition From 98bcdd16cfccafd7ef601067287012d8010a150f Mon Sep 17 00:00:00 2001 From: Jose Garcia Date: Fri, 16 Jul 2021 22:34:51 +0000 Subject: [PATCH 07/68] Support for GetSharedServers --- server.go | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/server.go b/server.go index b0f3996c9..68053d9f7 100644 --- a/server.go +++ b/server.go @@ -21,7 +21,7 @@ func (c *Client) NewServer() (*Server, error) { return &Server{client: c}, nil } -type serverSharingResponse struct { +type Port struct { SourcePort int `json:"sourcePort"` DestinationPort int `json:"destinationPort"` SessionName string `json:"sessionName"` @@ -36,7 +36,7 @@ type serverSharingResponse struct { func (s *Server) StartSharing(ctx context.Context, protocol string, port int) error { s.port = port - var response serverSharingResponse + var response Port if err := s.client.rpc.do(ctx, "serverSharing.startSharing", []interface{}{ port, protocol, fmt.Sprintf("http://localhost:%s", strconv.Itoa(port)), }, &response); err != nil { @@ -48,3 +48,14 @@ func (s *Server) StartSharing(ctx context.Context, protocol string, port int) er return nil } + +type Ports []*Port + +func (s *Server) GetSharedServers(ctx context.Context) (Ports, error) { + var response Ports + if err := s.client.rpc.do(ctx, "serverSharing.getSharedServers", []string{}, &response); err != nil { + return nil, err + } + + return response, nil +} From e373c91f8b2121a25eacf2f70f040dab14bff730 Mon Sep 17 00:00:00 2001 From: Jose Garcia Date: Sun, 18 Jul 2021 00:05:13 +0000 Subject: [PATCH 08/68] UpdateSharedServerVisibility API for Server --- server.go | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/server.go b/server.go index 68053d9f7..71ec9d4dd 100644 --- a/server.go +++ b/server.go @@ -59,3 +59,11 @@ func (s *Server) GetSharedServers(ctx context.Context) (Ports, error) { return response, nil } + +func (s *Server) UpdateSharedVisibility(ctx context.Context, port int, public bool) error { + if err := s.client.rpc.do(ctx, "serverSharing.updateSharedServerVisibility", []interface{}{port, public}, nil); err != nil { + return err + } + + return nil +} From 6d5726d78a665643f89514fc678d9ab1ccb1a138 Mon Sep 17 00:00:00 2001 From: Jose Garcia Date: Tue, 20 Jul 2021 11:59:14 +0000 Subject: [PATCH 09/68] Better way to discard requests & close channel/conn on disconnects --- client.go | 15 +-------------- port_forwarder.go | 4 ++-- 2 files changed, 3 insertions(+), 16 deletions(-) diff --git a/client.go b/client.go index 0af904b92..456a2c321 100644 --- a/client.go +++ b/client.go @@ -96,7 +96,7 @@ func (c *Client) openStreamingChannel(ctx context.Context, streamName, condition if err != nil { return nil, fmt.Errorf("error opening ssh channel for transport: %v", err) } - go c.processChannelRequests(ctx, reqs) + go ssh.DiscardRequests(reqs) requestType := fmt.Sprintf("stream-transport-%s", streamID) _, err = channel.SendRequest(requestType, true, nil) @@ -106,16 +106,3 @@ func (c *Client) openStreamingChannel(ctx context.Context, streamName, condition return channel, nil } - -func (c *Client) processChannelRequests(ctx context.Context, reqs <-chan *ssh.Request) { - for { - select { - case req := <-reqs: - if req != nil { - // TODO(josebalius): Handle - } - case <-ctx.Done(): - break - } - } -} diff --git a/port_forwarder.go b/port_forwarder.go index 0ae5e1916..20382c208 100644 --- a/port_forwarder.go +++ b/port_forwarder.go @@ -53,8 +53,8 @@ func (l *LocalPortForwarder) handleConnection(ctx context.Context, conn net.Conn copyConn := func(writer io.Writer, reader io.Reader) { _, err := io.Copy(writer, reader) if err != nil { - log.Println("errrrr copyConn") - log.Println(err) //TODO(josebalius): handle this somehow + channel.Close() + conn.Close() } } From 7332aa428c4db7b87c4280063b89a4d10763cb3c Mon Sep 17 00:00:00 2001 From: Jose Garcia Date: Thu, 22 Jul 2021 00:45:45 +0000 Subject: [PATCH 10/68] Large refactor and solidifying of APIs before tests --- api.go | 130 ---------------------------------------------- client.go | 50 +++++++++++------- client_test.go | 24 +++++++++ connection.go | 43 +++++++++++++++ liveshare.go | 77 --------------------------- port_forwarder.go | 10 ++-- rpc.go | 2 - server.go | 8 +-- session.go | 60 --------------------- socket.go | 105 +++++++++++++++++++++++++++++++++++++ ssh.go | 16 +++--- terminal.go | 8 +-- websocket.go | 105 ------------------------------------- 13 files changed, 224 insertions(+), 414 deletions(-) delete mode 100644 api.go create mode 100644 client_test.go create mode 100644 connection.go delete mode 100644 liveshare.go delete mode 100644 session.go create mode 100644 socket.go delete mode 100644 websocket.go diff --git a/api.go b/api.go deleted file mode 100644 index 55b6e6e93..000000000 --- a/api.go +++ /dev/null @@ -1,130 +0,0 @@ -package liveshare - -import ( - "encoding/json" - "fmt" - "io/ioutil" - "net/http" - "strings" -) - -type api struct { - client *Client - httpClient *http.Client - serviceURI string - workspaceID string -} - -func newAPI(client *Client) *api { - serviceURI := client.liveShare.Configuration.LiveShareEndpoint - if !strings.HasSuffix(client.liveShare.Configuration.LiveShareEndpoint, "/") { - serviceURI = client.liveShare.Configuration.LiveShareEndpoint + "/" - } - - if !strings.Contains(serviceURI, "api/v1.2") { - serviceURI = serviceURI + "api/v1.2" - } - - serviceURI = strings.TrimSuffix(serviceURI, "/") - - return &api{client, &http.Client{}, serviceURI, strings.ToUpper(client.liveShare.Configuration.WorkspaceID)} -} - -type workspaceAccessResponse struct { - SessionToken string `json:"sessionToken"` - CreatedAt string `json:"createdAt"` - UpdatedAt string `json:"updatedAt"` - Name string `json:"name"` - OwnerID string `json:"ownerId"` - JoinLink string `json:"joinLink"` - ConnectLinks []string `json:"connectLinks"` - RelayLink string `json:"relayLink"` - RelaySas string `json:"relaySas"` - HostPublicKeys []string `json:"hostPublicKeys"` - ConversationID string `json:"conversationId"` - AssociatedUserIDs map[string]string `json:"associatedUserIds"` - AreAnonymousGuestsAllowed bool `json:"areAnonymousGuestsAllowed"` - IsHostConnected bool `json:"isHostConnected"` - ExpiresAt string `json:"expiresAt"` - InvitationLinks []string `json:"invitationLinks"` - ID string `json:"id"` -} - -func (a *api) workspaceAccess() (*workspaceAccessResponse, error) { - url := fmt.Sprintf("%s/workspace/%s/user", a.serviceURI, a.workspaceID) - - req, err := http.NewRequest(http.MethodPut, url, nil) - if err != nil { - return nil, fmt.Errorf("error creating request: %v", err) - } - - a.setDefaultHeaders(req) - resp, err := a.httpClient.Do(req) - if err != nil { - return nil, fmt.Errorf("error making request: %v", err) - } - - b, err := ioutil.ReadAll(resp.Body) - if err != nil { - return nil, fmt.Errorf("error reading response body: %v", err) - } - - var response workspaceAccessResponse - if err := json.Unmarshal(b, &response); err != nil { - return nil, fmt.Errorf("error unmarshaling response into json: %v", err) - } - - return &response, nil -} - -func (a *api) setDefaultHeaders(req *http.Request) { - req.Header.Set("Authorization", "Bearer "+a.client.liveShare.Configuration.Token) - req.Header.Set("Cache-Control", "no-cache") - req.Header.Set("Content-Type", "application/json") -} - -type workspaceInfoResponse struct { - CreatedAt string `json:"createdAt"` - UpdatedAt string `json:"updatedAt"` - Name string `json:"name"` - OwnerID string `json:"ownerId"` - JoinLink string `json:"joinLink"` - ConnectLinks []string `json:"connectLinks"` - RelayLink string `json:"relayLink"` - RelaySas string `json:"relaySas"` - HostPublicKeys []string `json:"hostPublicKeys"` - ConversationID string `json:"conversationId"` - AssociatedUserIDs map[string]string - AreAnonymousGuestsAllowed bool `json:"areAnonymousGuestsAllowed"` - IsHostConnected bool `json:"isHostConnected"` - ExpiresAt string `json:"expiresAt"` - InvitationLinks []string `json:"invitationLinks"` - ID string `json:"id"` -} - -func (a *api) workspaceInfo() (*workspaceInfoResponse, error) { - url := fmt.Sprintf("%s/workspace/%s", a.serviceURI, a.workspaceID) - - req, err := http.NewRequest(http.MethodGet, url, nil) - if err != nil { - return nil, fmt.Errorf("error creating request: %v", err) - } - - a.setDefaultHeaders(req) - resp, err := a.httpClient.Do(req) - if err != nil { - return nil, fmt.Errorf("error making request: %v", err) - } - - b, err := ioutil.ReadAll(resp.Body) - if err != nil { - return nil, fmt.Errorf("error reading response body: %v", err) - } - - var response workspaceInfoResponse - if err := json.Unmarshal(b, &response); err != nil { - return nil, fmt.Errorf("error unmarshaling response into json: %v", err) - } - - return &response, nil -} diff --git a/client.go b/client.go index 456a2c321..1ad90b336 100644 --- a/client.go +++ b/client.go @@ -8,31 +8,44 @@ import ( ) type Client struct { - liveShare *LiveShare - session *session + connection Connection + sshSession *sshSession rpc *rpc } -// NewClient is a function ... -func (l *LiveShare) NewClient() *Client { - return &Client{liveShare: l} +type ClientOption func(*Client) error + +func NewClient(opts ...ClientOption) (*Client, error) { + client := new(Client) + + for _, o := range opts { + if err := o(client); err != nil { + return nil, err + } + } + + return client, nil +} + +func WithConnection(connection Connection) ClientOption { + return func(c *Client) error { + if err := connection.validate(); err != nil { + return err + } + + c.connection = connection + return nil + } } func (c *Client) Join(ctx context.Context) (err error) { - api := newAPI(c) - - c.session = newSession(api) - if err := c.session.init(ctx); err != nil { - return fmt.Errorf("error creating session: %v", err) - } - - websocket := newWebsocket(c.session) - if err := websocket.connect(ctx); err != nil { + clientSocket := newSocket(c.connection) + if err := clientSocket.connect(ctx); err != nil { return fmt.Errorf("error connecting websocket: %v", err) } - c.sshSession = newSSH(c.session, websocket) + c.sshSession = newSSH(c.connection.SessionToken, clientSocket) if err := c.sshSession.connect(ctx); err != nil { return fmt.Errorf("error connecting to ssh session: %v", err) } @@ -69,9 +82,9 @@ type joinWorkspaceResult struct { func (c *Client) joinWorkspace(ctx context.Context) (*joinWorkspaceResult, error) { args := joinWorkspaceArgs{ - ID: c.session.workspaceInfo.ID, + ID: c.connection.SessionID, ConnectionMode: "local", - JoiningUserSessionToken: c.session.workspaceAccess.SessionToken, + JoiningUserSessionToken: c.connection.SessionToken, ClientCapabilities: clientCapabilities{ IsNonInteractive: false, }, @@ -99,8 +112,7 @@ func (c *Client) openStreamingChannel(ctx context.Context, streamName, condition go ssh.DiscardRequests(reqs) requestType := fmt.Sprintf("stream-transport-%s", streamID) - _, err = channel.SendRequest(requestType, true, nil) - if err != nil { + if _, err = channel.SendRequest(requestType, true, nil); err != nil { return nil, fmt.Errorf("error sending channel request: %v", err) } diff --git a/client_test.go b/client_test.go new file mode 100644 index 000000000..8d118974d --- /dev/null +++ b/client_test.go @@ -0,0 +1,24 @@ +package liveshare + +import ( + "testing" +) + +func TestClientJoin(t *testing.T) { + // connection := Connection{ + // SessionID: "session-id", + // SessionToken: "session-token", + // RelayEndpoint: "relay-endpoint", + // RelaySAS: "relay-sas", + // } + + // client, err := NewClient(WithConnection(connection)) + // if err != nil { + // t.Errorf("error creating client: %v", err) + // } + + // ctx := context.Background() + // if err := client.Join(ctx); err != nil { + // t.Errorf("error joining client: %v", err) + // } +} diff --git a/connection.go b/connection.go new file mode 100644 index 000000000..a97935d3b --- /dev/null +++ b/connection.go @@ -0,0 +1,43 @@ +package liveshare + +import ( + "errors" + "net/url" + "strings" +) + +type Connection struct { + SessionID string `json:"sessionId"` + SessionToken string `json:"sessionToken"` + RelaySAS string `json:"relaySas"` + RelayEndpoint string `json:"relayEndpoint"` +} + +func (r Connection) validate() error { + if r.SessionID == "" { + return errors.New("connection sessionID is required") + } + + if r.SessionToken == "" { + return errors.New("connection sessionToken is required") + } + + if r.RelaySAS == "" { + return errors.New("connection relaySas is required") + } + + if r.RelayEndpoint == "" { + return errors.New("connection relayEndpoint is required") + } + + return nil +} + +func (r Connection) uri(action string) string { + sas := url.QueryEscape(r.RelaySAS) + uri := r.RelayEndpoint + uri = strings.Replace(uri, "sb:", "wss:", -1) + uri = strings.Replace(uri, ".net/", ".net:443/$hc/", 1) + uri = uri + "?sb-hc-action=" + action + "&sb-hc-token=" + sas + return uri +} diff --git a/liveshare.go b/liveshare.go deleted file mode 100644 index 3c4be5c05..000000000 --- a/liveshare.go +++ /dev/null @@ -1,77 +0,0 @@ -package liveshare - -import ( - "errors" - "fmt" - "strings" -) - -type LiveShare struct { - Configuration *Configuration -} - -func New(opts ...Option) (*LiveShare, error) { - configuration := NewConfiguration() - - for _, o := range opts { - if err := o(configuration); err != nil { - return nil, fmt.Errorf("error configuring liveshare: %v", err) - } - } - - if err := configuration.Validate(); err != nil { - return nil, fmt.Errorf("error validating configuration: %v", err) - } - - return &LiveShare{Configuration: configuration}, nil -} - -type Option func(configuration *Configuration) error - -func WithWorkspaceID(id string) Option { - return func(configuration *Configuration) error { - configuration.WorkspaceID = id - return nil - } -} - -func WithLiveShareEndpoint(liveShareEndpoint string) Option { - return func(configuration *Configuration) error { - configuration.LiveShareEndpoint = liveShareEndpoint - return nil - } -} - -func WithToken(token string) Option { - return func(configuration *Configuration) error { - configuration.Token = token - return nil - } -} - -type Configuration struct { - WorkspaceID, LiveShareEndpoint, Token string -} - -func NewConfiguration() *Configuration { - return &Configuration{ - LiveShareEndpoint: "https://prod.liveshare.vsengsaas.visualstudio.com", - } -} - -func (c *Configuration) Validate() error { - errs := []string{} - if c.WorkspaceID == "" { - errs = append(errs, "WorkspaceID is required") - } - - if c.Token == "" { - errs = append(errs, "Token is required") - } - - if len(errs) > 0 { - return errors.New(strings.Join(errs, ", ")) - } - - return nil -} diff --git a/port_forwarder.go b/port_forwarder.go index 20382c208..1227493b2 100644 --- a/port_forwarder.go +++ b/port_forwarder.go @@ -11,18 +11,18 @@ import ( "golang.org/x/crypto/ssh" ) -type LocalPortForwarder struct { +type PortForwarder struct { client *Client server *Server port int channels []ssh.Channel } -func NewLocalPortForwarder(client *Client, server *Server, port int) *LocalPortForwarder { - return &LocalPortForwarder{client, server, port, []ssh.Channel{}} +func NewPortForwarder(client *Client, server *Server, port int) *PortForwarder { + return &PortForwarder{client, server, port, []ssh.Channel{}} } -func (l *LocalPortForwarder) Start(ctx context.Context) error { +func (l *PortForwarder) Start(ctx context.Context) error { ln, err := net.Listen("tcp", ":"+strconv.Itoa(l.port)) if err != nil { return fmt.Errorf("error listening on tcp port: %v", err) @@ -42,7 +42,7 @@ func (l *LocalPortForwarder) Start(ctx context.Context) error { return nil } -func (l *LocalPortForwarder) handleConnection(ctx context.Context, conn net.Conn) { +func (l *PortForwarder) handleConnection(ctx context.Context, conn net.Conn) { channel, err := l.client.openStreamingChannel(ctx, l.server.streamName, l.server.streamCondition) if err != nil { log.Println("errrr handle Connect") diff --git a/rpc.go b/rpc.go index de427cda9..d40046471 100644 --- a/rpc.go +++ b/rpc.go @@ -78,7 +78,5 @@ func (r *rpcHandler) Handle(ctx context.Context, conn *jsonrpc2.Conn, req *jsonr r.eventHandlers[req.Method] = []chan *jsonrpc2.Request{} }() - } else { - // TODO(josebalius): Handle } } diff --git a/server.go b/server.go index 71ec9d4dd..6f17d5ac5 100644 --- a/server.go +++ b/server.go @@ -13,12 +13,12 @@ type Server struct { streamName, streamCondition string } -func (c *Client) NewServer() (*Server, error) { - if !c.hasJoined() { - return nil, errors.New("LiveShareClient must join before creating server") +func NewServer(client *Client) (*Server, error) { + if !client.hasJoined() { + return nil, errors.New("client must join before creating server") } - return &Server{client: c}, nil + return &Server{client: client}, nil } type Port struct { diff --git a/session.go b/session.go deleted file mode 100644 index d0492a10e..000000000 --- a/session.go +++ /dev/null @@ -1,60 +0,0 @@ -package liveshare - -import ( - "context" - "fmt" - "net/url" - "strings" - - "golang.org/x/sync/errgroup" -) - -type session struct { - api *api - - workspaceAccess *workspaceAccessResponse - workspaceInfo *workspaceInfoResponse -} - -func newSession(api *api) *session { - return &session{api: api} -} - -func (s *session) init(ctx context.Context) error { - g, ctx := errgroup.WithContext(ctx) - - g.Go(func() error { - workspaceAccess, err := s.api.workspaceAccess() - if err != nil { - return fmt.Errorf("error getting workspace access: %v", err) - } - s.workspaceAccess = workspaceAccess - return nil - }) - - g.Go(func() error { - workspaceInfo, err := s.api.workspaceInfo() - if err != nil { - return fmt.Errorf("error getting workspace info: %v", err) - } - s.workspaceInfo = workspaceInfo - return nil - }) - - if err := g.Wait(); err != nil { - return err - } - - return nil -} - -// Reference: -// https://github.com/Azure/azure-relay-node/blob/7b57225365df3010163bf4b9e640868a02737eb6/hyco-ws/index.js#L107-L137 -func (s *session) relayURI(action string) string { - relaySas := url.QueryEscape(s.workspaceAccess.RelaySas) - relayURI := s.workspaceAccess.RelayLink - relayURI = strings.Replace(relayURI, "sb:", "wss:", -1) - relayURI = strings.Replace(relayURI, ".net/", ".net:443/$hc/", 1) - relayURI = relayURI + "?sb-hc-action=" + action + "&sb-hc-token=" + relaySas - return relayURI -} diff --git a/socket.go b/socket.go new file mode 100644 index 000000000..c3f75b9db --- /dev/null +++ b/socket.go @@ -0,0 +1,105 @@ +package liveshare + +import ( + "context" + "errors" + "io" + "net" + "sync" + "time" + + "github.com/gorilla/websocket" +) + +type socket struct { + addr string + conn *websocket.Conn + readMutex sync.Mutex + writeMutex sync.Mutex + reader io.Reader +} + +func newSocket(clientConn Connection) *socket { + return &socket{addr: clientConn.uri("connect")} +} + +func (s *socket) connect(ctx context.Context) error { + ws, _, err := websocket.DefaultDialer.Dial(s.addr, nil) + if err != nil { + return err + } + s.conn = ws + return nil +} + +func (s *socket) Read(b []byte) (int, error) { + s.readMutex.Lock() + defer s.readMutex.Unlock() + + if s.reader == nil { + messageType, reader, err := s.conn.NextReader() + if err != nil { + return 0, err + } + + if messageType != websocket.BinaryMessage { + return 0, errors.New("unexpected websocket message type") + } + + s.reader = reader + } + + bytesRead, err := s.reader.Read(b) + if err != nil { + s.reader = nil + + if err == io.EOF { + err = nil + } + } + + return bytesRead, err +} + +func (s *socket) Write(b []byte) (int, error) { + s.writeMutex.Lock() + defer s.writeMutex.Unlock() + + nextWriter, err := s.conn.NextWriter(websocket.BinaryMessage) + if err != nil { + return 0, err + } + + bytesWritten, err := nextWriter.Write(b) + nextWriter.Close() + + return bytesWritten, err +} + +func (s *socket) Close() error { + return s.conn.Close() +} + +func (s *socket) LocalAddr() net.Addr { + return s.conn.LocalAddr() +} + +func (s *socket) RemoteAddr() net.Addr { + return s.conn.RemoteAddr() +} + +func (s *socket) SetDeadline(t time.Time) error { + if err := s.SetReadDeadline(t); err != nil { + return err + } + + return s.SetWriteDeadline(t) +} + +func (s *socket) SetReadDeadline(t time.Time) error { + return s.conn.SetReadDeadline(t) +} + +func (s *socket) SetWriteDeadline(t time.Time) error { + return s.conn.SetWriteDeadline(t) +} diff --git a/ssh.go b/ssh.go index 9ae32ed7c..3ea2d2777 100644 --- a/ssh.go +++ b/ssh.go @@ -12,22 +12,22 @@ import ( type sshSession struct { *ssh.Session - session *session - socket net.Conn - conn ssh.Conn - reader io.Reader - writer io.Writer + token string + socket net.Conn + conn ssh.Conn + reader io.Reader + writer io.Writer } -func newSSH(session *session, socket net.Conn) *sshSession { - return &sshSession{session: session, socket: socket} +func newSSH(token string, socket net.Conn) *sshSession { + return &sshSession{token: token, socket: socket} } func (s *sshSession) connect(ctx context.Context) error { clientConfig := ssh.ClientConfig{ User: "", Auth: []ssh.AuthMethod{ - ssh.Password(s.session.workspaceAccess.SessionToken), + ssh.Password(s.token), }, HostKeyAlgorithms: []string{"rsa-sha2-512", "rsa-sha2-256"}, HostKeyCallback: ssh.InsecureIgnoreHostKey(), diff --git a/terminal.go b/terminal.go index 631e75912..1621559a1 100644 --- a/terminal.go +++ b/terminal.go @@ -13,13 +13,13 @@ type Terminal struct { client *Client } -func (c *Client) NewTerminal() (*Terminal, error) { - if !c.hasJoined() { - return nil, errors.New("LiveShareClient must join before creating terminal") +func NewTerminal(client *Client) (*Terminal, error) { + if !client.hasJoined() { + return nil, errors.New("client must join before creating terminal") } return &Terminal{ - client: c, + client: client, }, nil } diff --git a/websocket.go b/websocket.go deleted file mode 100644 index ae163e6e2..000000000 --- a/websocket.go +++ /dev/null @@ -1,105 +0,0 @@ -package liveshare - -import ( - "context" - "errors" - "io" - "net" - "sync" - "time" - - gorillawebsocket "github.com/gorilla/websocket" -) - -type websocket struct { - session *session - conn *gorillawebsocket.Conn - readMutex sync.Mutex - writeMutex sync.Mutex - reader io.Reader -} - -func newWebsocket(session *session) *websocket { - return &websocket{session: session} -} - -func (w *websocket) connect(ctx context.Context) error { - ws, _, err := gorillawebsocket.DefaultDialer.Dial(w.session.relayURI("connect"), nil) - if err != nil { - return err - } - w.conn = ws - return nil -} - -func (w *websocket) Read(b []byte) (int, error) { - w.readMutex.Lock() - defer w.readMutex.Unlock() - - if w.reader == nil { - messageType, reader, err := w.conn.NextReader() - if err != nil { - return 0, err - } - - if messageType != gorillawebsocket.BinaryMessage { - return 0, errors.New("unexpected websocket message type") - } - - w.reader = reader - } - - bytesRead, err := w.reader.Read(b) - if err != nil { - w.reader = nil - - if err == io.EOF { - err = nil - } - } - - return bytesRead, err -} - -func (w *websocket) Write(b []byte) (int, error) { - w.writeMutex.Lock() - defer w.writeMutex.Unlock() - - nextWriter, err := w.conn.NextWriter(gorillawebsocket.BinaryMessage) - if err != nil { - return 0, err - } - - bytesWritten, err := nextWriter.Write(b) - nextWriter.Close() - - return bytesWritten, err -} - -func (w *websocket) Close() error { - return w.conn.Close() -} - -func (w *websocket) LocalAddr() net.Addr { - return w.conn.LocalAddr() -} - -func (w *websocket) RemoteAddr() net.Addr { - return w.conn.RemoteAddr() -} - -func (w *websocket) SetDeadline(t time.Time) error { - if err := w.SetReadDeadline(t); err != nil { - return err - } - - return w.SetWriteDeadline(t) -} - -func (w *websocket) SetReadDeadline(t time.Time) error { - return w.conn.SetReadDeadline(t) -} - -func (w *websocket) SetWriteDeadline(t time.Time) error { - return w.conn.SetWriteDeadline(t) -} From fddcd876b0b6e50959b0530662c50dda1b0079c1 Mon Sep 17 00:00:00 2001 From: Jose Garcia Date: Thu, 22 Jul 2021 01:02:03 +0000 Subject: [PATCH 11/68] Some more cleanup to the port forwarder and connection --- connection.go | 16 ++++++++-------- port_forwarder.go | 28 +++++++++++++--------------- 2 files changed, 21 insertions(+), 23 deletions(-) diff --git a/connection.go b/connection.go index a97935d3b..eda050f63 100644 --- a/connection.go +++ b/connection.go @@ -7,27 +7,27 @@ import ( ) type Connection struct { - SessionID string `json:"sessionId"` - SessionToken string `json:"sessionToken"` - RelaySAS string `json:"relaySas"` - RelayEndpoint string `json:"relayEndpoint"` + SessionID string + SessionToken string + RelaySAS string + RelayEndpoint string } func (r Connection) validate() error { if r.SessionID == "" { - return errors.New("connection sessionID is required") + return errors.New("connection SessionID is required") } if r.SessionToken == "" { - return errors.New("connection sessionToken is required") + return errors.New("connection SessionToken is required") } if r.RelaySAS == "" { - return errors.New("connection relaySas is required") + return errors.New("connection RelaySAS is required") } if r.RelayEndpoint == "" { - return errors.New("connection relayEndpoint is required") + return errors.New("connection RelayEndpoint is required") } return nil diff --git a/port_forwarder.go b/port_forwarder.go index 1227493b2..8d42f3c05 100644 --- a/port_forwarder.go +++ b/port_forwarder.go @@ -4,22 +4,24 @@ import ( "context" "fmt" "io" - "log" "net" "strconv" - - "golang.org/x/crypto/ssh" ) type PortForwarder struct { - client *Client - server *Server - port int - channels []ssh.Channel + client *Client + server *Server + port int + errCh chan error } func NewPortForwarder(client *Client, server *Server, port int) *PortForwarder { - return &PortForwarder{client, server, port, []ssh.Channel{}} + return &PortForwarder{ + client: client, + server: server, + port: port, + errCh: make(chan error), + } } func (l *PortForwarder) Start(ctx context.Context) error { @@ -37,22 +39,18 @@ func (l *PortForwarder) Start(ctx context.Context) error { go l.handleConnection(ctx, conn) } - // clean up after ourselves - return nil } func (l *PortForwarder) handleConnection(ctx context.Context, conn net.Conn) { channel, err := l.client.openStreamingChannel(ctx, l.server.streamName, l.server.streamCondition) if err != nil { - log.Println("errrr handle Connect") - log.Println(err) // TODO(josebalius) handle this somehow + l.errCh <- fmt.Errorf("error opening streaming channel for new connection: %v", err) + return } - l.channels = append(l.channels, channel) copyConn := func(writer io.Writer, reader io.Reader) { - _, err := io.Copy(writer, reader) - if err != nil { + if _, err := io.Copy(writer, reader); err != nil { channel.Close() conn.Close() } From a99d0f5495a575c0694307dbbe606210b16830d7 Mon Sep 17 00:00:00 2001 From: Jose Garcia Date: Thu, 22 Jul 2021 01:07:06 +0000 Subject: [PATCH 12/68] Better naming for rpc client and ssh session --- client.go | 14 +++++++------- rpc.go | 10 +++++----- ssh.go | 2 +- 3 files changed, 13 insertions(+), 13 deletions(-) diff --git a/client.go b/client.go index 1ad90b336..d0e6e84a8 100644 --- a/client.go +++ b/client.go @@ -10,8 +10,8 @@ import ( type Client struct { connection Connection - sshSession *sshSession - rpc *rpc + ssh *sshSession + rpc *rpcClient } type ClientOption func(*Client) error @@ -45,12 +45,12 @@ func (c *Client) Join(ctx context.Context) (err error) { return fmt.Errorf("error connecting websocket: %v", err) } - c.sshSession = newSSH(c.connection.SessionToken, clientSocket) - if err := c.sshSession.connect(ctx); err != nil { + c.ssh = newSshSession(c.connection.SessionToken, clientSocket) + if err := c.ssh.connect(ctx); err != nil { return fmt.Errorf("error connecting to ssh session: %v", err) } - c.rpc = newRPC(c.sshSession) + c.rpc = newRpcClient(c.ssh) c.rpc.connect(ctx) _, err = c.joinWorkspace(ctx) @@ -62,7 +62,7 @@ func (c *Client) Join(ctx context.Context) (err error) { } func (c *Client) hasJoined() bool { - return c.sshSession != nil && c.rpc != nil + return c.ssh != nil && c.rpc != nil } type clientCapabilities struct { @@ -105,7 +105,7 @@ func (c *Client) openStreamingChannel(ctx context.Context, streamName, condition return nil, fmt.Errorf("error getting stream id: %v", err) } - channel, reqs, err := c.sshSession.conn.OpenChannel("session", nil) + channel, reqs, err := c.ssh.conn.OpenChannel("session", nil) if err != nil { return nil, fmt.Errorf("error opening ssh channel for transport: %v", err) } diff --git a/rpc.go b/rpc.go index d40046471..d624bbd74 100644 --- a/rpc.go +++ b/rpc.go @@ -9,22 +9,22 @@ import ( "github.com/sourcegraph/jsonrpc2" ) -type rpc struct { +type rpcClient struct { *jsonrpc2.Conn conn io.ReadWriteCloser handler *rpcHandler } -func newRPC(conn io.ReadWriteCloser) *rpc { - return &rpc{conn: conn, handler: newRPCHandler()} +func newRpcClient(conn io.ReadWriteCloser) *rpcClient { + return &rpcClient{conn: conn, handler: newRPCHandler()} } -func (r *rpc) connect(ctx context.Context) { +func (r *rpcClient) connect(ctx context.Context) { stream := jsonrpc2.NewBufferedStream(r.conn, jsonrpc2.VSCodeObjectCodec{}) r.Conn = jsonrpc2.NewConn(ctx, stream, r.handler) } -func (r *rpc) do(ctx context.Context, method string, args interface{}, result interface{}) error { +func (r *rpcClient) do(ctx context.Context, method string, args interface{}, result interface{}) error { waiter, err := r.Conn.DispatchCall(ctx, method, args) if err != nil { return fmt.Errorf("error on dispatch call: %v", err) diff --git a/ssh.go b/ssh.go index 3ea2d2777..e22cd69d1 100644 --- a/ssh.go +++ b/ssh.go @@ -19,7 +19,7 @@ type sshSession struct { writer io.Writer } -func newSSH(token string, socket net.Conn) *sshSession { +func newSshSession(token string, socket net.Conn) *sshSession { return &sshSession{token: token, socket: socket} } From b9cd9af7fa83ad2fd7cca4727d5adc1be51fa384 Mon Sep 17 00:00:00 2001 From: Jose Garcia Date: Fri, 23 Jul 2021 01:17:32 +0000 Subject: [PATCH 13/68] Start of tests and comments --- client.go | 5 ++++ client_test.go | 73 +++++++++++++++++++++++++++++++++++++++++++++-- connection.go | 1 + port_forwarder.go | 3 ++ 4 files changed, 79 insertions(+), 3 deletions(-) diff --git a/client.go b/client.go index d0e6e84a8..435dd6775 100644 --- a/client.go +++ b/client.go @@ -7,6 +7,7 @@ import ( "golang.org/x/crypto/ssh" ) +// A Client capable of joining a liveshare connection type Client struct { connection Connection @@ -14,8 +15,10 @@ type Client struct { rpc *rpcClient } +// A ClientOption is a function that modifies a client type ClientOption func(*Client) error +// NewClient accepts a range of options, applies them and returns a client func NewClient(opts ...ClientOption) (*Client, error) { client := new(Client) @@ -28,6 +31,7 @@ func NewClient(opts ...ClientOption) (*Client, error) { return client, nil } +// WithConnection is a ClientOption that accepts a Connection func WithConnection(connection Connection) ClientOption { return func(c *Client) error { if err := connection.validate(); err != nil { @@ -39,6 +43,7 @@ func WithConnection(connection Connection) ClientOption { } } +// Join is a method that joins the client to the liveshare session func (c *Client) Join(ctx context.Context) (err error) { clientSocket := newSocket(c.connection) if err := clientSocket.connect(ctx); err != nil { diff --git a/client_test.go b/client_test.go index 8d118974d..86f732637 100644 --- a/client_test.go +++ b/client_test.go @@ -1,22 +1,89 @@ package liveshare import ( + "fmt" + "net/http" + "net/http/httptest" "testing" + + "github.com/gorilla/websocket" ) +func TestNewClient(t *testing.T) { + client, err := NewClient() + if err != nil { + t.Errorf("error creating new client: %v", err) + } + if client == nil { + t.Error("client is nil") + } +} + +func TestNewClientValidConnection(t *testing.T) { + connection := Connection{"1", "2", "3", "4"} + + client, err := NewClient(WithConnection(connection)) + if err != nil { + t.Errorf("error creating new client: %v", err) + } + if client == nil { + t.Error("client is nil") + } +} + +func TestNewClientWithInvalidConnection(t *testing.T) { + connection := Connection{} + + if _, err := NewClient(WithConnection(connection)); err == nil { + t.Error("err is nil") + } +} + +var upgrader = websocket.Upgrader{} + +func newMockLiveShareServer() *httptest.Server { + endpoint := func(w http.ResponseWriter, req *http.Request) { + c, err := upgrader.Upgrade(w, req, nil) + if err != nil { + fmt.Println(err) + return + } + defer c.Close() + + for { + mt, message, err := c.ReadMessage() + if err != nil { + fmt.Println(err) + break + } + + err = c.WriteMessage(mt, message) + if err != nil { + fmt.Println(err) + break + } + + } + } + + return httptest.NewTLSServer(http.HandlerFunc(endpoint)) +} + func TestClientJoin(t *testing.T) { + // server := newMockLiveShareServer() + // defer server.Close() + // connection := Connection{ // SessionID: "session-id", // SessionToken: "session-token", - // RelayEndpoint: "relay-endpoint", // RelaySAS: "relay-sas", + // RelayEndpoint: "sb" + strings.TrimPrefix(server.URL, "https"), // } // client, err := NewClient(WithConnection(connection)) // if err != nil { - // t.Errorf("error creating client: %v", err) + // t.Errorf("error creating new client: %v", err) // } - // ctx := context.Background() // if err := client.Join(ctx); err != nil { // t.Errorf("error joining client: %v", err) diff --git a/connection.go b/connection.go index eda050f63..c1a4632c8 100644 --- a/connection.go +++ b/connection.go @@ -6,6 +6,7 @@ import ( "strings" ) +// A Connection represents a set of values necessary to join a liveshare connection type Connection struct { SessionID string SessionToken string diff --git a/port_forwarder.go b/port_forwarder.go index 8d42f3c05..6d459b4d6 100644 --- a/port_forwarder.go +++ b/port_forwarder.go @@ -8,6 +8,7 @@ import ( "strconv" ) +// A PortForwader can forward ports from a remote liveshare host to localhost type PortForwarder struct { client *Client server *Server @@ -15,6 +16,7 @@ type PortForwarder struct { errCh chan error } +// NewPortForwarder creates a new PortForwader with a given client, server and port func NewPortForwarder(client *Client, server *Server, port int) *PortForwarder { return &PortForwarder{ client: client, @@ -24,6 +26,7 @@ func NewPortForwarder(client *Client, server *Server, port int) *PortForwarder { } } +// Start is a method to start forwarding the server to a localhost port func (l *PortForwarder) Start(ctx context.Context) error { ln, err := net.Listen("tcp", ":"+strconv.Itoa(l.port)) if err != nil { From 9132a28e9cf2a09359a80e104a558c77a0c0abea Mon Sep 17 00:00:00 2001 From: Jose Garcia Date: Fri, 23 Jul 2021 19:15:54 +0000 Subject: [PATCH 14/68] Checking point after continuing to flesh out mock server --- client.go | 11 ++- client_test.go | 100 +++++++++++----------- socket.go | 17 +++- test/server.go | 226 +++++++++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 299 insertions(+), 55 deletions(-) create mode 100644 test/server.go diff --git a/client.go b/client.go index 435dd6775..a8a1e3864 100644 --- a/client.go +++ b/client.go @@ -2,6 +2,7 @@ package liveshare import ( "context" + "crypto/tls" "fmt" "golang.org/x/crypto/ssh" @@ -10,6 +11,7 @@ import ( // A Client capable of joining a liveshare connection type Client struct { connection Connection + tlsConfig *tls.Config ssh *sshSession rpc *rpcClient @@ -43,9 +45,16 @@ func WithConnection(connection Connection) ClientOption { } } +func WithTLSConfig(tlsConfig *tls.Config) ClientOption { + return func(c *Client) error { + c.tlsConfig = tlsConfig + return nil + } +} + // Join is a method that joins the client to the liveshare session func (c *Client) Join(ctx context.Context) (err error) { - clientSocket := newSocket(c.connection) + clientSocket := newSocket(c.connection, c.tlsConfig) if err := clientSocket.connect(ctx); err != nil { return fmt.Errorf("error connecting websocket: %v", err) } diff --git a/client_test.go b/client_test.go index 86f732637..bf77e3dce 100644 --- a/client_test.go +++ b/client_test.go @@ -1,12 +1,14 @@ package liveshare import ( + "context" + "crypto/tls" "fmt" - "net/http" - "net/http/httptest" + "strings" "testing" - "github.com/gorilla/websocket" + livesharetest "github.com/github/go-liveshare/test" + "github.com/sourcegraph/jsonrpc2" ) func TestNewClient(t *testing.T) { @@ -39,53 +41,51 @@ func TestNewClientWithInvalidConnection(t *testing.T) { } } -var upgrader = websocket.Upgrader{} - -func newMockLiveShareServer() *httptest.Server { - endpoint := func(w http.ResponseWriter, req *http.Request) { - c, err := upgrader.Upgrade(w, req, nil) - if err != nil { - fmt.Println(err) - return - } - defer c.Close() - - for { - mt, message, err := c.ReadMessage() - if err != nil { - fmt.Println(err) - break - } - - err = c.WriteMessage(mt, message) - if err != nil { - fmt.Println(err) - break - } - - } +func TestClientJoin(t *testing.T) { + sessionToken := "session-token" + joinWorkspace := func(req *jsonrpc2.Request) (interface{}, error) { + return 1, nil } - return httptest.NewTLSServer(http.HandlerFunc(endpoint)) -} - -func TestClientJoin(t *testing.T) { - // server := newMockLiveShareServer() - // defer server.Close() - - // connection := Connection{ - // SessionID: "session-id", - // SessionToken: "session-token", - // RelaySAS: "relay-sas", - // RelayEndpoint: "sb" + strings.TrimPrefix(server.URL, "https"), - // } - - // client, err := NewClient(WithConnection(connection)) - // if err != nil { - // t.Errorf("error creating new client: %v", err) - // } - // ctx := context.Background() - // if err := client.Join(ctx); err != nil { - // t.Errorf("error joining client: %v", err) - // } + server, err := livesharetest.NewServer( + livesharetest.WithPassword(sessionToken), + livesharetest.WithService("workspace.joinWorkspace", joinWorkspace), + ) + if err != nil { + t.Errorf("error creating liveshare server: %v", err) + } + defer server.Close() + + ctx := context.Background() + connection := Connection{ + SessionID: "session-id", + SessionToken: sessionToken, + RelaySAS: "relay-sas", + RelayEndpoint: "sb" + strings.TrimPrefix(server.URL(), "https"), + } + + tlsConfig := WithTLSConfig(&tls.Config{InsecureSkipVerify: true}) + client, err := NewClient(WithConnection(connection), tlsConfig) + if err != nil { + t.Errorf("error creating new client: %v", err) + } + + clientErr := make(chan error) + go func() { + if err := client.Join(ctx); err != nil { + clientErr <- fmt.Errorf("error joining client: %v", err) + return + } + + ctx.Done() + }() + + select { + case err := <-server.Err(): + t.Errorf("error from server: %v", err) + case err := <-clientErr: + t.Errorf("error from client: %v", err) + case <-ctx.Done(): + return + } } diff --git a/socket.go b/socket.go index c3f75b9db..e4f80a0cf 100644 --- a/socket.go +++ b/socket.go @@ -2,9 +2,11 @@ package liveshare import ( "context" + "crypto/tls" "errors" "io" "net" + "net/http" "sync" "time" @@ -12,19 +14,26 @@ import ( ) type socket struct { - addr string + addr string + tlsConfig *tls.Config + conn *websocket.Conn readMutex sync.Mutex writeMutex sync.Mutex reader io.Reader } -func newSocket(clientConn Connection) *socket { - return &socket{addr: clientConn.uri("connect")} +func newSocket(clientConn Connection, tlsConfig *tls.Config) *socket { + return &socket{addr: clientConn.uri("connect"), tlsConfig: tlsConfig} } func (s *socket) connect(ctx context.Context) error { - ws, _, err := websocket.DefaultDialer.Dial(s.addr, nil) + dialer := websocket.Dialer{ + Proxy: http.ProxyFromEnvironment, + HandshakeTimeout: 45 * time.Second, + TLSClientConfig: s.tlsConfig, + } + ws, _, err := dialer.Dial(s.addr, nil) if err != nil { return err } diff --git a/test/server.go b/test/server.go new file mode 100644 index 000000000..ed8666cce --- /dev/null +++ b/test/server.go @@ -0,0 +1,226 @@ +package livesharetest + +import ( + "context" + "errors" + "fmt" + "io" + "io/ioutil" + "net/http" + "net/http/httptest" + "path/filepath" + "sync" + "time" + + "github.com/gorilla/websocket" + "github.com/sourcegraph/jsonrpc2" + "golang.org/x/crypto/ssh" +) + +type Server struct { + password string + services map[string]RpcHandleFunc + + sshConfig *ssh.ServerConfig + httptestServer *httptest.Server + errCh chan error +} + +func NewServer(opts ...ServerOption) (*Server, error) { + server := new(Server) + + for _, o := range opts { + if err := o(server); err != nil { + return nil, err + } + } + + server.sshConfig = &ssh.ServerConfig{ + PasswordCallback: sshPasswordCallback(server.password), + } + b, err := ioutil.ReadFile(filepath.Join("test", "private.key")) + if err != nil { + return nil, fmt.Errorf("error reading private.key: %v", err) + } + privateKey, err := ssh.ParsePrivateKey(b) + if err != nil { + return nil, fmt.Errorf("error parsing key: %v", err) + } + server.sshConfig.AddHostKey(privateKey) + + server.errCh = make(chan error) + server.httptestServer = httptest.NewTLSServer(http.HandlerFunc(newConnection(server))) + return server, nil +} + +type ServerOption func(*Server) error + +func WithPassword(password string) ServerOption { + return func(s *Server) error { + s.password = password + return nil + } +} + +func WithService(serviceName string, handler RpcHandleFunc) ServerOption { + return func(s *Server) error { + if s.services == nil { + s.services = make(map[string]RpcHandleFunc) + } + + s.services[serviceName] = handler + return nil + } +} + +func sshPasswordCallback(serverPassword string) func(ssh.ConnMetadata, []byte) (*ssh.Permissions, error) { + return func(conn ssh.ConnMetadata, password []byte) (*ssh.Permissions, error) { + if string(password) == serverPassword { + return nil, nil + } + return nil, errors.New("password rejected") + } +} + +func (s *Server) Close() { + s.httptestServer.Close() +} + +func (s *Server) URL() string { + return s.httptestServer.URL +} + +func (s *Server) Err() <-chan error { + return s.errCh +} + +var upgrader = websocket.Upgrader{} + +func newConnection(server *Server) http.HandlerFunc { + return func(w http.ResponseWriter, req *http.Request) { + c, err := upgrader.Upgrade(w, req, nil) + if err != nil { + server.errCh <- fmt.Errorf("error upgrading connection: %v", err) + return + } + defer c.Close() + + socketConn := newSocketConn(c) + _, chans, reqs, err := ssh.NewServerConn(socketConn, server.sshConfig) + if err != nil { + server.errCh <- fmt.Errorf("error creating new ssh conn: %v", err) + return + } + go ssh.DiscardRequests(reqs) + + for newChannel := range chans { + ch, reqs, err := newChannel.Accept() + if err != nil { + server.errCh <- fmt.Errorf("error accepting new channel: %v", err) + return + } + go ssh.DiscardRequests(reqs) + go handleNewChannel(server, ch) + } + } +} + +func handleNewChannel(server *Server, channel ssh.Channel) { + stream := jsonrpc2.NewBufferedStream(channel, jsonrpc2.VSCodeObjectCodec{}) + jsonrpc2.NewConn(context.Background(), stream, newRpcHandler(server)) +} + +type RpcHandleFunc func(req *jsonrpc2.Request) (interface{}, error) + +type rpcHandler struct { + server *Server +} + +func newRpcHandler(server *Server) *rpcHandler { + return &rpcHandler{server} +} + +func (r *rpcHandler) Handle(ctx context.Context, conn *jsonrpc2.Conn, req *jsonrpc2.Request) { + handler, found := r.server.services[req.Method] + if !found { + r.server.errCh <- fmt.Errorf("RPC Method: '%v' not serviced", req.Method) + return + } + + result, err := handler(req) + if err != nil { + r.server.errCh <- fmt.Errorf("error handling: '%v': %v", req.Method, err) + return + } + + if err := conn.Reply(ctx, req.ID, result); err != nil { + r.server.errCh <- fmt.Errorf("error replying: %v", err) + } +} + +type socketConn struct { + *websocket.Conn + + reader io.Reader + writeMutex sync.Mutex + readMutex sync.Mutex +} + +func newSocketConn(conn *websocket.Conn) *socketConn { + return &socketConn{Conn: conn} +} + +func (s *socketConn) Read(b []byte) (int, error) { + s.readMutex.Lock() + defer s.readMutex.Unlock() + + if s.reader == nil { + msgType, r, err := s.Conn.NextReader() + if err != nil { + return 0, fmt.Errorf("error getting next reader: %v", err) + } + if msgType != websocket.BinaryMessage { + return 0, fmt.Errorf("invalid message type") + } + s.reader = r + } + + bytesRead, err := s.reader.Read(b) + if err != nil { + s.reader = nil + + if err == io.EOF { + err = nil + } + } + + return bytesRead, err +} + +func (s *socketConn) Write(b []byte) (int, error) { + s.writeMutex.Lock() + defer s.writeMutex.Unlock() + + w, err := s.Conn.NextWriter(websocket.BinaryMessage) + if err != nil { + return 0, fmt.Errorf("error getting next writer: %v", err) + } + + n, err := w.Write(b) + if err != nil { + return 0, fmt.Errorf("error writing: %v", err) + } + + if err := w.Close(); err != nil { + return 0, fmt.Errorf("error closing writer: %v", err) + } + + return n, nil +} + +func (s *socketConn) SetDeadline(deadline time.Time) error { + if err := s.Conn.SetReadDeadline(deadline); err != nil { + return err + } + return s.Conn.SetWriteDeadline(deadline) +} From fcfb10cb56e6aaaae9c745bc1ee00f48897220f9 Mon Sep 17 00:00:00 2001 From: Jose Garcia Date: Fri, 23 Jul 2021 20:24:50 +0000 Subject: [PATCH 15/68] Working test for Client.Join --- client_test.go | 31 +++++++++++++++---------------- 1 file changed, 15 insertions(+), 16 deletions(-) diff --git a/client_test.go b/client_test.go index bf77e3dce..fdf566fc0 100644 --- a/client_test.go +++ b/client_test.go @@ -42,27 +42,26 @@ func TestNewClientWithInvalidConnection(t *testing.T) { } func TestClientJoin(t *testing.T) { - sessionToken := "session-token" + connection := Connection{ + SessionID: "session-id", + SessionToken: "session-token", + RelaySAS: "relay-sas", + } joinWorkspace := func(req *jsonrpc2.Request) (interface{}, error) { - return 1, nil + return joinWorkspaceResult{1}, nil } server, err := livesharetest.NewServer( - livesharetest.WithPassword(sessionToken), + livesharetest.WithPassword(connection.SessionToken), livesharetest.WithService("workspace.joinWorkspace", joinWorkspace), ) if err != nil { t.Errorf("error creating liveshare server: %v", err) } defer server.Close() + connection.RelayEndpoint = "sb" + strings.TrimPrefix(server.URL(), "https") ctx := context.Background() - connection := Connection{ - SessionID: "session-id", - SessionToken: sessionToken, - RelaySAS: "relay-sas", - RelayEndpoint: "sb" + strings.TrimPrefix(server.URL(), "https"), - } tlsConfig := WithTLSConfig(&tls.Config{InsecureSkipVerify: true}) client, err := NewClient(WithConnection(connection), tlsConfig) @@ -70,22 +69,22 @@ func TestClientJoin(t *testing.T) { t.Errorf("error creating new client: %v", err) } - clientErr := make(chan error) + done := make(chan error) go func() { if err := client.Join(ctx); err != nil { - clientErr <- fmt.Errorf("error joining client: %v", err) + done <- fmt.Errorf("error joining client: %v", err) return } - ctx.Done() + done <- nil }() select { case err := <-server.Err(): t.Errorf("error from server: %v", err) - case err := <-clientErr: - t.Errorf("error from client: %v", err) - case <-ctx.Done(): - return + case err := <-done: + if err != nil { + t.Errorf("error from client: %v", err) + } } } From 91114d35c3d04245a58f78ebf2feb6bb5edde4e2 Mon Sep 17 00:00:00 2001 From: Jose Garcia Date: Sat, 24 Jul 2021 03:44:20 +0000 Subject: [PATCH 16/68] More tests --- client_test.go | 19 +++++ port_forwarder_test.go | 1 + server_test.go | 186 +++++++++++++++++++++++++++++++++++++++++ test/server.go | 16 ++++ 4 files changed, 222 insertions(+) create mode 100644 port_forwarder_test.go create mode 100644 server_test.go diff --git a/client_test.go b/client_test.go index fdf566fc0..110c7e3b9 100644 --- a/client_test.go +++ b/client_test.go @@ -3,6 +3,8 @@ package liveshare import ( "context" "crypto/tls" + "encoding/json" + "errors" "fmt" "strings" "testing" @@ -48,12 +50,29 @@ func TestClientJoin(t *testing.T) { RelaySAS: "relay-sas", } joinWorkspace := func(req *jsonrpc2.Request) (interface{}, error) { + var joinWorkspaceReq joinWorkspaceArgs + if err := json.Unmarshal(*req.Params, &joinWorkspaceReq); err != nil { + return nil, fmt.Errorf("error unmarshaling req: %v", err) + } + if joinWorkspaceReq.ID != connection.SessionID { + return nil, errors.New("connection session id does not match") + } + if joinWorkspaceReq.ConnectionMode != "local" { + return nil, errors.New("connection mode is not local") + } + if joinWorkspaceReq.JoiningUserSessionToken != connection.SessionToken { + return nil, errors.New("connection user token does not match") + } + if joinWorkspaceReq.ClientCapabilities.IsNonInteractive != false { + return nil, errors.New("non interactive is not false") + } return joinWorkspaceResult{1}, nil } server, err := livesharetest.NewServer( livesharetest.WithPassword(connection.SessionToken), livesharetest.WithService("workspace.joinWorkspace", joinWorkspace), + livesharetest.WithRelaySAS(connection.RelaySAS), ) if err != nil { t.Errorf("error creating liveshare server: %v", err) diff --git a/port_forwarder_test.go b/port_forwarder_test.go new file mode 100644 index 000000000..e3e219705 --- /dev/null +++ b/port_forwarder_test.go @@ -0,0 +1 @@ +package liveshare diff --git a/server_test.go b/server_test.go new file mode 100644 index 000000000..cc2b9adbd --- /dev/null +++ b/server_test.go @@ -0,0 +1,186 @@ +package liveshare + +import ( + "context" + "crypto/tls" + "encoding/json" + "errors" + "fmt" + "strings" + "testing" + + livesharetest "github.com/github/go-liveshare/test" + "github.com/sourcegraph/jsonrpc2" +) + +func TestNewServerWithNotJoinedClient(t *testing.T) { + client, err := NewClient() + if err != nil { + t.Errorf("error creating new client: %v", err) + } + if _, err := NewServer(client); err == nil { + t.Error("expected error") + } +} + +func newMockJoinedClient(opts ...livesharetest.ServerOption) (*livesharetest.Server, *Client, error) { + connection := Connection{ + SessionID: "session-id", + SessionToken: "session-token", + RelaySAS: "relay-sas", + } + joinWorkspace := func(req *jsonrpc2.Request) (interface{}, error) { + return joinWorkspaceResult{1}, nil + } + opts = append( + opts, + livesharetest.WithPassword(connection.SessionToken), + livesharetest.WithService("workspace.joinWorkspace", joinWorkspace), + ) + testServer, err := livesharetest.NewServer( + opts..., + ) + connection.RelayEndpoint = "sb" + strings.TrimPrefix(testServer.URL(), "https") + tlsConfig := WithTLSConfig(&tls.Config{InsecureSkipVerify: true}) + client, err := NewClient(WithConnection(connection), tlsConfig) + if err != nil { + return nil, nil, fmt.Errorf("error creating new client: %v", err) + } + ctx := context.Background() + if err := client.Join(ctx); err != nil { + return nil, nil, fmt.Errorf("error joining client: %v", err) + } + return testServer, client, nil +} + +func TestNewServer(t *testing.T) { + testServer, client, err := newMockJoinedClient() + defer testServer.Close() + if err != nil { + t.Errorf("error creating mock joined client: %v", err) + } + server, err := NewServer(client) + if err != nil { + t.Errorf("error creating new server: %v", err) + } + if server == nil { + t.Error("server is nil") + } +} + +func TestServerStartSharing(t *testing.T) { + serverPort, serverProtocol := 2222, "sshd" + startSharing := func(req *jsonrpc2.Request) (interface{}, error) { + var args []interface{} + if err := json.Unmarshal(*req.Params, &args); err != nil { + return nil, fmt.Errorf("error unmarshaling request: %v", err) + } + if len(args) < 3 { + return nil, errors.New("not enough arguments to start sharing") + } + if port, ok := args[0].(float64); !ok { + return nil, errors.New("port argument is not an int") + } else if port != float64(serverPort) { + return nil, errors.New("port does not match serverPort") + } + if protocol, ok := args[1].(string); !ok { + return nil, errors.New("protocol argument is not a string") + } else if protocol != serverProtocol { + return nil, errors.New("protocol does not match serverProtocol") + } + if browseURL, ok := args[2].(string); !ok { + return nil, errors.New("browse url is not a string") + } else if browseURL != fmt.Sprintf("http://localhost:%v", serverPort) { + return nil, errors.New("browseURL does not match expected") + } + return Port{StreamName: "stream-name", StreamCondition: "stream-condition"}, nil + } + testServer, client, err := newMockJoinedClient( + livesharetest.WithService("serverSharing.startSharing", startSharing), + ) + defer testServer.Close() + if err != nil { + t.Errorf("error creating mock joined client: %v", err) + } + server, err := NewServer(client) + if err != nil { + t.Errorf("error creating new server: %v", err) + } + ctx := context.Background() + + done := make(chan error) + go func() { + if err := server.StartSharing(ctx, serverProtocol, serverPort); err != nil { + done <- fmt.Errorf("error sharing server: %v", err) + } + if server.streamName == "" || server.streamCondition == "" { + done <- errors.New("stream name or condition is blank") + } + done <- nil + }() + + select { + case err := <-testServer.Err(): + t.Errorf("error from server: %v", err) + case err := <-done: + if err != nil { + t.Errorf("error from client: %v", err) + } + } +} + +func TestServerGetSharedServers(t *testing.T) { + sharedServer := Port{ + SourcePort: 2222, + StreamName: "stream-name", + StreamCondition: "stream-condition", + } + getSharedServers := func(req *jsonrpc2.Request) (interface{}, error) { + return Ports{&sharedServer}, nil + } + testServer, client, err := newMockJoinedClient( + livesharetest.WithService("serverSharing.getSharedServers", getSharedServers), + ) + if err != nil { + t.Errorf("error creating new mock client: %v", err) + } + defer testServer.Close() + server, err := NewServer(client) + if err != nil { + t.Errorf("error creating new server: %v", err) + } + ctx := context.Background() + done := make(chan error) + go func() { + ports, err := server.GetSharedServers(ctx) + if err != nil { + done <- fmt.Errorf("error getting shared servers: %v", err) + } + if len(ports) < 1 { + done <- errors.New("not enough ports returned") + } + if ports[0].SourcePort != sharedServer.SourcePort { + done <- errors.New("source port does not match") + } + if ports[0].StreamName != sharedServer.StreamName { + done <- errors.New("stream name does not match") + } + if ports[0].StreamCondition != sharedServer.StreamCondition { + done <- errors.New("stream condiion does not match") + } + done <- nil + }() + + select { + case err := <-testServer.Err(): + t.Errorf("error from server: %v", err) + case err := <-done: + if err != nil { + t.Errorf("error from client: %v", err) + } + } +} + +func TestServerUpdateSharedVisibility(t *testing.T) { + +} diff --git a/test/server.go b/test/server.go index ed8666cce..abb7ac96a 100644 --- a/test/server.go +++ b/test/server.go @@ -20,6 +20,7 @@ import ( type Server struct { password string services map[string]RpcHandleFunc + relaySAS string sshConfig *ssh.ServerConfig httptestServer *httptest.Server @@ -73,6 +74,13 @@ func WithService(serviceName string, handler RpcHandleFunc) ServerOption { } } +func WithRelaySAS(sas string) ServerOption { + return func(s *Server) error { + s.relaySAS = sas + return nil + } +} + func sshPasswordCallback(serverPassword string) func(ssh.ConnMetadata, []byte) (*ssh.Permissions, error) { return func(conn ssh.ConnMetadata, password []byte) (*ssh.Permissions, error) { if string(password) == serverPassword { @@ -98,6 +106,14 @@ var upgrader = websocket.Upgrader{} func newConnection(server *Server) http.HandlerFunc { return func(w http.ResponseWriter, req *http.Request) { + if server.relaySAS != "" { + // validate the sas key + sasParam := req.URL.Query().Get("sb-hc-token") + if sasParam != server.relaySAS { + server.errCh <- errors.New("error validating sas") + return + } + } c, err := upgrader.Upgrade(w, req, nil) if err != nil { server.errCh <- fmt.Errorf("error upgrading connection: %v", err) From 98282ba4b51085e03965672b1b124cc708bc6e82 Mon Sep 17 00:00:00 2001 From: Jose Garcia Date: Mon, 26 Jul 2021 14:31:00 +0000 Subject: [PATCH 17/68] Update shared visibility tests --- rpc.go | 5 ----- server_test.go | 32 +++++++++++++++++++++++++++++++- 2 files changed, 31 insertions(+), 6 deletions(-) diff --git a/rpc.go b/rpc.go index d624bbd74..8abd0e98f 100644 --- a/rpc.go +++ b/rpc.go @@ -30,11 +30,6 @@ func (r *rpcClient) do(ctx context.Context, method string, args interface{}, res return fmt.Errorf("error on dispatch call: %v", err) } - // caller doesn't care about result, so lets ignore it - if result == nil { - return nil - } - return waiter.Wait(ctx, result) } diff --git a/server_test.go b/server_test.go index cc2b9adbd..8a736b6f5 100644 --- a/server_test.go +++ b/server_test.go @@ -182,5 +182,35 @@ func TestServerGetSharedServers(t *testing.T) { } func TestServerUpdateSharedVisibility(t *testing.T) { - + updateSharedVisibility := func(req *jsonrpc2.Request) error { + return nil + } + testServer, client, err := newMockJoinedClient( + livesharetest.WithService("serverSharing.updateSharedServerVisibility", updateSharedVisibility), + ) + if err != nil { + t.Errorf("creating new mock client: %v", err) + } + defer testServer.Close() + server, err := NewServer(client) + if err != nil { + t.Errorf("creating server: %v", err) + } + ctx := context.Background() + done := make(chan error) + go func() { + if err := server.UpdateSharedVisibility(ctx, 80, true); err != nil { + done <- err + return + } + done <- nil + }() + select { + case err := <-testServer.Err(): + t.Errorf("error from server: %v", err) + case err := <-done: + if err != nil { + t.Errorf("error from client: %v", err) + } + } } From 892f73221c69d21f77d53e77eddd16f40c20c4ba Mon Sep 17 00:00:00 2001 From: Jose Garcia Date: Mon, 26 Jul 2021 14:39:52 +0000 Subject: [PATCH 18/68] Update shared visibility finalized tests --- server_test.go | 25 +++++++++++++++++++++++-- 1 file changed, 23 insertions(+), 2 deletions(-) diff --git a/server_test.go b/server_test.go index 8a736b6f5..7c9ee4288 100644 --- a/server_test.go +++ b/server_test.go @@ -182,8 +182,29 @@ func TestServerGetSharedServers(t *testing.T) { } func TestServerUpdateSharedVisibility(t *testing.T) { - updateSharedVisibility := func(req *jsonrpc2.Request) error { - return nil + updateSharedVisibility := func(rpcReq *jsonrpc2.Request) (interface{}, error) { + var req []interface{} + if err := json.Unmarshal(*rpcReq.Params, &req); err != nil { + return nil, fmt.Errorf("unmarshal req: %v", err) + } + if len(req) < 2 { + return nil, errors.New("request arguments is less than 2") + } + if port, ok := req[0].(float64); ok { + if port != 80.0 { + return nil, errors.New("port param is not expected value") + } + } else { + return nil, errors.New("port param is not a float64") + } + if public, ok := req[1].(bool); ok { + if public != true { + return nil, errors.New("pulic param is not expected value") + } + } else { + return nil, errors.New("public param is not a bool") + } + return nil, nil } testServer, client, err := newMockJoinedClient( livesharetest.WithService("serverSharing.updateSharedServerVisibility", updateSharedVisibility), From 0ab67badfad20a67a73bb170647fa115538b2995 Mon Sep 17 00:00:00 2001 From: Jose Garcia Date: Tue, 27 Jul 2021 23:19:55 +0000 Subject: [PATCH 19/68] Final changes to finish this refactor --- connection_test.go | 33 +++++++++++++ port_forwarder.go | 4 ++ port_forwarder_test.go | 102 +++++++++++++++++++++++++++++++++++++++++ server.go | 8 ++++ server_test.go | 10 ++-- socket.go | 20 ++------ test/server.go | 54 ++++++++++++++++++++-- 7 files changed, 206 insertions(+), 25 deletions(-) create mode 100644 connection_test.go diff --git a/connection_test.go b/connection_test.go new file mode 100644 index 000000000..e952290be --- /dev/null +++ b/connection_test.go @@ -0,0 +1,33 @@ +package liveshare + +import "testing" + +func TestConnectionValid(t *testing.T) { + conn := Connection{"sess-id", "sess-token", "sas", "endpoint"} + if err := conn.validate(); err != nil { + t.Error(err) + } +} + +func TestConnectionInvalid(t *testing.T) { + conn := Connection{"", "sess-token", "sas", "endpoint"} + if err := conn.validate(); err == nil { + t.Error(err) + } + conn = Connection{"sess-id", "", "sas", "endpoint"} + if err := conn.validate(); err == nil { + t.Error(err) + } + conn = Connection{"sess-id", "sess-token", "", "endpoint"} + if err := conn.validate(); err == nil { + t.Error(err) + } + conn = Connection{"sess-id", "sess-token", "sas", ""} + if err := conn.validate(); err == nil { + t.Error(err) + } + conn = Connection{"", "", "", ""} + if err := conn.validate(); err == nil { + t.Error(err) + } +} diff --git a/port_forwarder.go b/port_forwarder.go index 6d459b4d6..0a049d586 100644 --- a/port_forwarder.go +++ b/port_forwarder.go @@ -54,8 +54,12 @@ func (l *PortForwarder) handleConnection(ctx context.Context, conn net.Conn) { copyConn := func(writer io.Writer, reader io.Reader) { if _, err := io.Copy(writer, reader); err != nil { + fmt.Println(err) channel.Close() conn.Close() + if err != io.EOF { + l.errCh <- fmt.Errorf("tunnel connection: %v", err) + } } } diff --git a/port_forwarder_test.go b/port_forwarder_test.go index e3e219705..33a33b39b 100644 --- a/port_forwarder_test.go +++ b/port_forwarder_test.go @@ -1 +1,103 @@ package liveshare + +import ( + "bytes" + "context" + "errors" + "fmt" + "io" + "net" + "testing" + "time" + + livesharetest "github.com/github/go-liveshare/test" + "github.com/sourcegraph/jsonrpc2" +) + +func TestNewPortForwarder(t *testing.T) { + testServer, client, err := makeMockJoinedClient() + if err != nil { + t.Errorf("create mock client: %v", err) + } + defer testServer.Close() + server, err := NewServer(client) + if err != nil { + t.Errorf("create new server: %v", err) + } + pf := NewPortForwarder(client, server, 80) + if pf == nil { + t.Error("port forwarder is nil") + } +} + +func TestPortForwarderStart(t *testing.T) { + streamName, streamCondition := "stream-name", "stream-condition" + serverSharing := func(req *jsonrpc2.Request) (interface{}, error) { + return Port{StreamName: streamName, StreamCondition: streamCondition}, nil + } + getStream := func(req *jsonrpc2.Request) (interface{}, error) { + return "stream-id", nil + } + + stream := bytes.NewBufferString("stream-data") + testServer, client, err := makeMockJoinedClient( + livesharetest.WithService("serverSharing.startSharing", serverSharing), + livesharetest.WithService("streamManager.getStream", getStream), + livesharetest.WithStream("stream-id", stream), + ) + if err != nil { + t.Errorf("create mock client: %v", err) + } + defer testServer.Close() + + server, err := NewServer(client) + if err != nil { + t.Errorf("create new server: %v", err) + } + + ctx, _ := context.WithCancel(context.Background()) + pf := NewPortForwarder(client, server, 8000) + done := make(chan error) + + go func() { + if err := server.StartSharing(ctx, "http", 8000); err != nil { + done <- fmt.Errorf("start sharing: %v", err) + } + if err := pf.Start(ctx); err != nil { + done <- err + } + done <- nil + }() + + go func() { + var conn net.Conn + retries := 0 + for conn == nil && retries < 2 { + conn, err = net.DialTimeout("tcp", ":8000", 2*time.Second) + time.Sleep(1 * time.Second) + } + if conn == nil { + done <- errors.New("failed to connect to forwarded port") + } + b := make([]byte, len("stream-data")) + if _, err := conn.Read(b); err != nil && err != io.EOF { + done <- fmt.Errorf("reading stream: %v", err) + } + if string(b) != "stream-data" { + done <- fmt.Errorf("stream data is not expected value, got: %v", string(b)) + } + if _, err := conn.Write([]byte("new-data")); err != nil { + done <- fmt.Errorf("writing to stream: %v", err) + } + done <- nil + }() + + select { + case err := <-testServer.Err(): + t.Errorf("error from server: %v", err) + case err := <-done: + if err != nil { + t.Errorf("error from client: %v", err) + } + } +} diff --git a/server.go b/server.go index 6f17d5ac5..7e8c8b1cb 100644 --- a/server.go +++ b/server.go @@ -7,12 +7,14 @@ import ( "strconv" ) +// A Server represents the liveshare host and container server type Server struct { client *Client port int streamName, streamCondition string } +// NewServer creates a new Server with a given Client func NewServer(client *Client) (*Server, error) { if !client.hasJoined() { return nil, errors.New("client must join before creating server") @@ -21,6 +23,7 @@ func NewServer(client *Client) (*Server, error) { return &Server{client: client}, nil } +// Port represents an open port on the container type Port struct { SourcePort int `json:"sourcePort"` DestinationPort int `json:"destinationPort"` @@ -33,6 +36,7 @@ type Port struct { HasTSLHandshakePassed bool `json:"hasTSLHandshakePassed"` } +// StartSharing tells the liveshare host to start sharing the port from the container func (s *Server) StartSharing(ctx context.Context, protocol string, port int) error { s.port = port @@ -49,8 +53,10 @@ func (s *Server) StartSharing(ctx context.Context, protocol string, port int) er return nil } +// Ports is a slice of Port pointers type Ports []*Port +// GetSharedServers returns a list of available/open ports from the container func (s *Server) GetSharedServers(ctx context.Context) (Ports, error) { var response Ports if err := s.client.rpc.do(ctx, "serverSharing.getSharedServers", []string{}, &response); err != nil { @@ -60,6 +66,8 @@ func (s *Server) GetSharedServers(ctx context.Context) (Ports, error) { return response, nil } +// UpdateSharedVisibility controls port permissions and whether it can be accessed publicly +// via the Browse URL func (s *Server) UpdateSharedVisibility(ctx context.Context, port int, public bool) error { if err := s.client.rpc.do(ctx, "serverSharing.updateSharedServerVisibility", []interface{}{port, public}, nil); err != nil { return err diff --git a/server_test.go b/server_test.go index 7c9ee4288..b91fbfddc 100644 --- a/server_test.go +++ b/server_test.go @@ -23,7 +23,7 @@ func TestNewServerWithNotJoinedClient(t *testing.T) { } } -func newMockJoinedClient(opts ...livesharetest.ServerOption) (*livesharetest.Server, *Client, error) { +func makeMockJoinedClient(opts ...livesharetest.ServerOption) (*livesharetest.Server, *Client, error) { connection := Connection{ SessionID: "session-id", SessionToken: "session-token", @@ -54,7 +54,7 @@ func newMockJoinedClient(opts ...livesharetest.ServerOption) (*livesharetest.Ser } func TestNewServer(t *testing.T) { - testServer, client, err := newMockJoinedClient() + testServer, client, err := makeMockJoinedClient() defer testServer.Close() if err != nil { t.Errorf("error creating mock joined client: %v", err) @@ -95,7 +95,7 @@ func TestServerStartSharing(t *testing.T) { } return Port{StreamName: "stream-name", StreamCondition: "stream-condition"}, nil } - testServer, client, err := newMockJoinedClient( + testServer, client, err := makeMockJoinedClient( livesharetest.WithService("serverSharing.startSharing", startSharing), ) defer testServer.Close() @@ -138,7 +138,7 @@ func TestServerGetSharedServers(t *testing.T) { getSharedServers := func(req *jsonrpc2.Request) (interface{}, error) { return Ports{&sharedServer}, nil } - testServer, client, err := newMockJoinedClient( + testServer, client, err := makeMockJoinedClient( livesharetest.WithService("serverSharing.getSharedServers", getSharedServers), ) if err != nil { @@ -206,7 +206,7 @@ func TestServerUpdateSharedVisibility(t *testing.T) { } return nil, nil } - testServer, client, err := newMockJoinedClient( + testServer, client, err := makeMockJoinedClient( livesharetest.WithService("serverSharing.updateSharedServerVisibility", updateSharedVisibility), ) if err != nil { diff --git a/socket.go b/socket.go index e4f80a0cf..8744eeb96 100644 --- a/socket.go +++ b/socket.go @@ -3,11 +3,9 @@ package liveshare import ( "context" "crypto/tls" - "errors" "io" "net" "net/http" - "sync" "time" "github.com/gorilla/websocket" @@ -17,10 +15,8 @@ type socket struct { addr string tlsConfig *tls.Config - conn *websocket.Conn - readMutex sync.Mutex - writeMutex sync.Mutex - reader io.Reader + conn *websocket.Conn + reader io.Reader } func newSocket(clientConn Connection, tlsConfig *tls.Config) *socket { @@ -42,19 +38,12 @@ func (s *socket) connect(ctx context.Context) error { } func (s *socket) Read(b []byte) (int, error) { - s.readMutex.Lock() - defer s.readMutex.Unlock() - if s.reader == nil { - messageType, reader, err := s.conn.NextReader() + _, reader, err := s.conn.NextReader() if err != nil { return 0, err } - if messageType != websocket.BinaryMessage { - return 0, errors.New("unexpected websocket message type") - } - s.reader = reader } @@ -71,9 +60,6 @@ func (s *socket) Read(b []byte) (int, error) { } func (s *socket) Write(b []byte) (int, error) { - s.writeMutex.Lock() - defer s.writeMutex.Unlock() - nextWriter, err := s.conn.NextWriter(websocket.BinaryMessage) if err != nil { return 0, err diff --git a/test/server.go b/test/server.go index abb7ac96a..a52d31ab9 100644 --- a/test/server.go +++ b/test/server.go @@ -9,6 +9,7 @@ import ( "net/http" "net/http/httptest" "path/filepath" + "strings" "sync" "time" @@ -21,6 +22,7 @@ type Server struct { password string services map[string]RpcHandleFunc relaySAS string + streams map[string]io.ReadWriter sshConfig *ssh.ServerConfig httptestServer *httptest.Server @@ -50,7 +52,7 @@ func NewServer(opts ...ServerOption) (*Server, error) { server.sshConfig.AddHostKey(privateKey) server.errCh = make(chan error) - server.httptestServer = httptest.NewTLSServer(http.HandlerFunc(newConnection(server))) + server.httptestServer = httptest.NewTLSServer(http.HandlerFunc(makeConnection(server))) return server, nil } @@ -81,6 +83,16 @@ func WithRelaySAS(sas string) ServerOption { } } +func WithStream(name string, stream io.ReadWriter) ServerOption { + return func(s *Server) error { + if s.streams == nil { + s.streams = make(map[string]io.ReadWriter) + } + s.streams[name] = stream + return nil + } +} + func sshPasswordCallback(serverPassword string) func(ssh.ConnMetadata, []byte) (*ssh.Permissions, error) { return func(conn ssh.ConnMetadata, password []byte) (*ssh.Permissions, error) { if string(password) == serverPassword { @@ -104,7 +116,7 @@ func (s *Server) Err() <-chan error { var upgrader = websocket.Upgrader{} -func newConnection(server *Server) http.HandlerFunc { +func makeConnection(server *Server) http.HandlerFunc { return func(w http.ResponseWriter, req *http.Request) { if server.relaySAS != "" { // validate the sas key @@ -135,12 +147,48 @@ func newConnection(server *Server) http.HandlerFunc { server.errCh <- fmt.Errorf("error accepting new channel: %v", err) return } - go ssh.DiscardRequests(reqs) + go handleNewRequests(server, ch, reqs) go handleNewChannel(server, ch) } } } +func handleNewRequests(server *Server, channel ssh.Channel, reqs <-chan *ssh.Request) { + for req := range reqs { + if req.WantReply { + if err := req.Reply(true, nil); err != nil { + server.errCh <- fmt.Errorf("error replying to channel request: %v", err) + } + } + if strings.HasPrefix(req.Type, "stream-transport") { + forwardStream(server, req.Type, channel) + } + } +} + +func forwardStream(server *Server, streamName string, channel ssh.Channel) { + simpleStreamName := strings.TrimPrefix(streamName, "stream-transport-") + stream, found := server.streams[simpleStreamName] + if !found { + server.errCh <- fmt.Errorf("stream '%v' not found", simpleStreamName) + return + } + + copy := func(dst io.Writer, src io.Reader) { + if _, err := io.Copy(dst, src); err != nil { + fmt.Println(err) + server.errCh <- fmt.Errorf("io copy: %v", err) + return + } + } + + go copy(stream, channel) + go copy(channel, stream) + + for { + } +} + func handleNewChannel(server *Server, channel ssh.Channel) { stream := jsonrpc2.NewBufferedStream(channel, jsonrpc2.VSCodeObjectCodec{}) jsonrpc2.NewConn(context.Background(), stream, newRpcHandler(server)) From 3a2ade23a4a154eb327c097214784aa95bd138c9 Mon Sep 17 00:00:00 2001 From: Jose Garcia Date: Wed, 28 Jul 2021 13:52:30 +0000 Subject: [PATCH 20/68] Connection test --- connection_test.go | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/connection_test.go b/connection_test.go index e952290be..f42ec4189 100644 --- a/connection_test.go +++ b/connection_test.go @@ -31,3 +31,11 @@ func TestConnectionInvalid(t *testing.T) { t.Error(err) } } + +func TestConnectionURI(t *testing.T) { + conn := Connection{"sess-id", "sess-token", "sas", "sb://endpoint/.net/liveshare"} + uri := conn.uri("connect") + if uri != "wss://endpoint/.net:443/$hc/liveshare?sb-hc-action=connect&sb-hc-token=sas" { + t.Errorf("uri is not correct, got: '%v'", uri) + } +} From ae29c3c1ea7358504650e0c6fd07dc4a57cbb2a0 Mon Sep 17 00:00:00 2001 From: Jose Garcia Date: Wed, 28 Jul 2021 13:55:33 +0000 Subject: [PATCH 21/68] Ignore EOF on terminal close --- terminal.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/terminal.go b/terminal.go index 1621559a1..c26d9fd9f 100644 --- a/terminal.go +++ b/terminal.go @@ -106,7 +106,7 @@ func (t terminalReadCloser) Close() error { return fmt.Errorf("error making terminal.stopTerminal call: %v", err) } - if err := t.channel.Close(); err != nil { + if err := t.channel.Close(); err != nil && err != io.EOF { return fmt.Errorf("error closing channel: %v", err) } From fbf0d286729dd355889a8caad0b80be71e4ae601 Mon Sep 17 00:00:00 2001 From: Jose Garcia Date: Fri, 6 Aug 2021 01:03:03 +0000 Subject: [PATCH 22/68] port forwarding err handling and test refactors --- example/main.go | 120 ---------------------------------------------- port_forwarder.go | 21 +++++--- test/server.go | 105 +++++++++++----------------------------- test/socket.go | 77 +++++++++++++++++++++++++++++ 4 files changed, 121 insertions(+), 202 deletions(-) delete mode 100644 example/main.go create mode 100644 test/socket.go diff --git a/example/main.go b/example/main.go deleted file mode 100644 index e9347bd14..000000000 --- a/example/main.go +++ /dev/null @@ -1,120 +0,0 @@ -package main - -import ( - "bufio" - "context" - "flag" - "fmt" - "log" - "os" - "time" - - "github.com/github/go-liveshare" -) - -var workspaceIdFlag = flag.String("w", "", "workspace session id") - -func init() { - flag.Parse() -} - -func main() { - liveShare, err := liveshare.New( - liveshare.WithWorkspaceID(*workspaceIdFlag), - liveshare.WithToken(os.Getenv("CODESPACE_TOKEN")), - ) - if err != nil { - log.Fatal(fmt.Errorf("error creating liveshare: %v", err)) - } - - ctx := context.Background() - liveShareClient := liveShare.NewClient() - if err := liveShareClient.Join(ctx); err != nil { - log.Fatal(fmt.Errorf("error joining liveshare with client: %v", err)) - } - - terminal, err := liveShareClient.NewTerminal() - if err != nil { - log.Fatal(fmt.Errorf("error creating liveshare terminal")) - } - - containerID, err := getContainerID(ctx, terminal) - if err != nil { - log.Fatal(fmt.Errorf("error getting container id: %v", err)) - } - - if err := setupSSH(ctx, terminal, containerID); err != nil { - log.Fatal(fmt.Errorf("error setting up ssh: %v", err)) - } - - fmt.Println("Starting server...") - - server, err := liveShareClient.NewServer() - if err != nil { - log.Fatal(fmt.Errorf("error creating server: %v", err)) - } - - fmt.Println("Starting sharing...") - if err := server.StartSharing(ctx, "sshd", 2222); err != nil { - log.Fatal(fmt.Errorf("error server sharing: %v", err)) - } - - portForwarder := liveshare.NewLocalPortForwarder(liveShareClient, server, 2222) - - fmt.Println("Listening on port 2222") - if err := portForwarder.Start(ctx); err != nil { - log.Fatal(fmt.Errorf("error forwarding port: %v", err)) - } -} - -func setupSSH(ctx context.Context, terminal *liveshare.Terminal, containerID string) error { - cmd := terminal.NewCommand( - "/", - fmt.Sprintf("/usr/bin/docker exec -t %s /bin/bash -c \"echo -e \\\"testpwd1\\ntestpwd1\\n\\\" | sudo passwd codespace;/usr/local/share/ssh-init.sh\"", containerID), - ) - stream, err := cmd.Run(ctx) - if err != nil { - return fmt.Errorf("error running command: %v", err) - } - - scanner := bufio.NewScanner(stream) - scanner.Scan() - - fmt.Println("> Debug:", scanner.Text()) - if err := scanner.Err(); err != nil { - return fmt.Errorf("error scanning stream: %v", err) - } - - if err := stream.Close(); err != nil { - return fmt.Errorf("error closing stream: %v", err) - } - - time.Sleep(2 * time.Second) - - return nil -} - -func getContainerID(ctx context.Context, terminal *liveshare.Terminal) (string, error) { - cmd := terminal.NewCommand( - "/", - "/usr/bin/docker ps -aq --filter label=Type=codespaces --filter status=running", - ) - stream, err := cmd.Run(ctx) - if err != nil { - return "", fmt.Errorf("error running command: %v", err) - } - - scanner := bufio.NewScanner(stream) - scanner.Scan() - - containerID := scanner.Text() - if err := scanner.Err(); err != nil { - return "", fmt.Errorf("error scanning stream: %v", err) - } - - if err := stream.Close(); err != nil { - return "", fmt.Errorf("error closing stream: %v", err) - } - - return containerID, nil -} diff --git a/port_forwarder.go b/port_forwarder.go index 0a049d586..3a73e3fce 100644 --- a/port_forwarder.go +++ b/port_forwarder.go @@ -33,13 +33,22 @@ func (l *PortForwarder) Start(ctx context.Context) error { return fmt.Errorf("error listening on tcp port: %v", err) } - for { - conn, err := ln.Accept() - if err != nil { - return fmt.Errorf("error accepting incoming connection: %v", err) - } + go func() { + for { + conn, err := ln.Accept() + if err != nil { + l.errCh <- fmt.Errorf("error accepting incoming connection: %v", err) + } - go l.handleConnection(ctx, conn) + go l.handleConnection(ctx, conn) + } + }() + + select { + case err := <-l.errCh: + return err + case <-ctx.Done(): + return ln.Close() } return nil diff --git a/test/server.go b/test/server.go index a52d31ab9..159a2a982 100644 --- a/test/server.go +++ b/test/server.go @@ -5,19 +5,43 @@ import ( "errors" "fmt" "io" - "io/ioutil" "net/http" "net/http/httptest" - "path/filepath" "strings" - "sync" - "time" "github.com/gorilla/websocket" "github.com/sourcegraph/jsonrpc2" "golang.org/x/crypto/ssh" ) +const sshPrivateKey = `-----BEGIN RSA PRIVATE KEY----- +MIIEogIBAAKCAQEAp/Jmzy/HaPNx5Bug09FX5Q/KGY4G9c4DfplhWrn31OQCqNiT +ZSLd46rdXC75liHzE7e5Ic0RJN61cYN9SNArjvEXx2vvs7szhwO7LonwPOvpYpUf +daayrgbr6S46plpx+hEZ1kO/6BqMgFuvnkIVThrEyx5b48ll8zgDABsYrKF8/p1V +SjGfb+bLwjn1NtnZF2prBG5P4ZtMR06HaPglLqBJhmc0ZMG5IZGUE7ew/VrPDqdC +f1v4XvvGiU4BLoKYy4QOhyrCGh9Uk/9u0Ea56M2bh4RqwhbpR8m7TYJZ0DVMLbGW +8C+4lCWp+xRyBNxAQh8qeQVCxYl02hPE4bXLGQIDAQABAoIBAEoVPk6UZ+UexhV2 +LnphNOFhFqgxI1bYWmhE5lHsCKuLLLUoW9RYDgL4gw6/1e7o6N3AxFRpre9Soj0B +YIl28k/qf6/DKAhjQnaDKdV8mVF2Swvmdesi7lyfxv6kGtD4wqApXPlMB2IuG94f +E5e+1MEQQ9DJgoU3eNZR1dj9GuRC3PyzPcNNJ2R/MMGFw3sOOVcLOgAukotoicuL +0SiL51rHPQu8a5/darH9EltN1GFeceJSDDhgqMP5T8Tp7g/c3//H6szon4H9W+uN +Z3UrImJ+teJjFOaVDqN93+J2eQSUk0lCPGQCd4U9I4AGDGyU6ucdcLQ58Aha9gmU +uQwkfKUCgYEA0UkuPOSDE9dbXe+yhsbOwMb1kKzJYgFDKjRTSP7D9BOMZu4YyASo +J95R4DWjePlDopafG2tNJoWX+CwUl7Uld1R3Ex6xHBa2B7hwZj860GZtr7D4mdWc +DTVjczAjp4P0K1MIFYQui1mVJterkjKuePiI6q/27L1c2jIa/39BWBcCgYEAzW8R +MFZamVw3eA2JYSpBuqhQgE5gX5IWrmVJZSUhpAQTNG/A4nxf7WGtjy9p99tm0RMb +ld05+sOmNLrzw8Pq8SBpFOd+MAca7lPLS1A2CoaAHbOqRqrzVcZ4EZ2jB3WjoLoq +yctwslGb9KmrhBCdcwT48aPAYUIJCZdqEen2xE8CgYBoMowvywGrvjwCH9X9njvP +5P7cAfrdrY04FQcmP5lmCtmLYZ267/6couaWv33dPBU9fMpIh3rI5BiOebvi8FBw +AgCq50v8lR4Z5+0mKvLoUSbpIy4SwTRJqzwRXHVT8LF/ZH6Q39egj4Bf716/kjYl +im/4kJVatsjk5a9lZ4EsDwKBgERkJ3rKJNtNggHrr8KzSLKVekdc0GTAw+BHRAny +NKLf4Gzij3pXIbBrhlZW2JZ1amNMUzCvN7AuFlUTsDeKL9saiSE2eCIRG3wgVVu7 +VmJmqJw6xgNEwkHaEvr6Wd4P4euOTtRjcB9NX/gxzDHpPiGelCoN8+vtCgkxaVSR +aV+tAoGAO4HtLOfBAVDNbVXa27aJAjQSUq8qfkwUNJNz+rwgpVQahfiVkyqAPCQM +IfRJxKWb0Wbt9ojw3AowK/k0d3LZA7FS41JSiiGKIllSGb+i7JKqKW7RHLA3VJ/E +Bq5TLNIbUzPVNVwRcGjUYpOhKU6EIw8phTJOvxnUC+g6MVqBP8U= +-----END RSA PRIVATE KEY-----` + type Server struct { password string services map[string]RpcHandleFunc @@ -41,11 +65,7 @@ func NewServer(opts ...ServerOption) (*Server, error) { server.sshConfig = &ssh.ServerConfig{ PasswordCallback: sshPasswordCallback(server.password), } - b, err := ioutil.ReadFile(filepath.Join("test", "private.key")) - if err != nil { - return nil, fmt.Errorf("error reading private.key: %v", err) - } - privateKey, err := ssh.ParsePrivateKey(b) + privateKey, err := ssh.ParsePrivateKey([]byte(sshPrivateKey)) if err != nil { return nil, fmt.Errorf("error parsing key: %v", err) } @@ -221,70 +241,3 @@ func (r *rpcHandler) Handle(ctx context.Context, conn *jsonrpc2.Conn, req *jsonr r.server.errCh <- fmt.Errorf("error replying: %v", err) } } - -type socketConn struct { - *websocket.Conn - - reader io.Reader - writeMutex sync.Mutex - readMutex sync.Mutex -} - -func newSocketConn(conn *websocket.Conn) *socketConn { - return &socketConn{Conn: conn} -} - -func (s *socketConn) Read(b []byte) (int, error) { - s.readMutex.Lock() - defer s.readMutex.Unlock() - - if s.reader == nil { - msgType, r, err := s.Conn.NextReader() - if err != nil { - return 0, fmt.Errorf("error getting next reader: %v", err) - } - if msgType != websocket.BinaryMessage { - return 0, fmt.Errorf("invalid message type") - } - s.reader = r - } - - bytesRead, err := s.reader.Read(b) - if err != nil { - s.reader = nil - - if err == io.EOF { - err = nil - } - } - - return bytesRead, err -} - -func (s *socketConn) Write(b []byte) (int, error) { - s.writeMutex.Lock() - defer s.writeMutex.Unlock() - - w, err := s.Conn.NextWriter(websocket.BinaryMessage) - if err != nil { - return 0, fmt.Errorf("error getting next writer: %v", err) - } - - n, err := w.Write(b) - if err != nil { - return 0, fmt.Errorf("error writing: %v", err) - } - - if err := w.Close(); err != nil { - return 0, fmt.Errorf("error closing writer: %v", err) - } - - return n, nil -} - -func (s *socketConn) SetDeadline(deadline time.Time) error { - if err := s.Conn.SetReadDeadline(deadline); err != nil { - return err - } - return s.Conn.SetWriteDeadline(deadline) -} diff --git a/test/socket.go b/test/socket.go new file mode 100644 index 000000000..9a2d92491 --- /dev/null +++ b/test/socket.go @@ -0,0 +1,77 @@ +package livesharetest + +import ( + "fmt" + "io" + "sync" + "time" + + "github.com/gorilla/websocket" +) + +type socketConn struct { + *websocket.Conn + + reader io.Reader + writeMutex sync.Mutex + readMutex sync.Mutex +} + +func newSocketConn(conn *websocket.Conn) *socketConn { + return &socketConn{Conn: conn} +} + +func (s *socketConn) Read(b []byte) (int, error) { + s.readMutex.Lock() + defer s.readMutex.Unlock() + + if s.reader == nil { + msgType, r, err := s.Conn.NextReader() + if err != nil { + return 0, fmt.Errorf("error getting next reader: %v", err) + } + if msgType != websocket.BinaryMessage { + return 0, fmt.Errorf("invalid message type") + } + s.reader = r + } + + bytesRead, err := s.reader.Read(b) + if err != nil { + s.reader = nil + + if err == io.EOF { + err = nil + } + } + + return bytesRead, err +} + +func (s *socketConn) Write(b []byte) (int, error) { + s.writeMutex.Lock() + defer s.writeMutex.Unlock() + + w, err := s.Conn.NextWriter(websocket.BinaryMessage) + if err != nil { + return 0, fmt.Errorf("error getting next writer: %v", err) + } + + n, err := w.Write(b) + if err != nil { + return 0, fmt.Errorf("error writing: %v", err) + } + + if err := w.Close(); err != nil { + return 0, fmt.Errorf("error closing writer: %v", err) + } + + return n, nil +} + +func (s *socketConn) SetDeadline(deadline time.Time) error { + if err := s.Conn.SetReadDeadline(deadline); err != nil { + return err + } + return s.Conn.SetWriteDeadline(deadline) +} From eb2a17645056c6b34638cf32c9b0d39e068e1e69 Mon Sep 17 00:00:00 2001 From: Jose Garcia Date: Sat, 7 Aug 2021 17:54:43 +0000 Subject: [PATCH 23/68] remove err print --- port_forwarder.go | 1 - 1 file changed, 1 deletion(-) diff --git a/port_forwarder.go b/port_forwarder.go index 3a73e3fce..06c164e8d 100644 --- a/port_forwarder.go +++ b/port_forwarder.go @@ -63,7 +63,6 @@ func (l *PortForwarder) handleConnection(ctx context.Context, conn net.Conn) { copyConn := func(writer io.Writer, reader io.Reader) { if _, err := io.Copy(writer, reader); err != nil { - fmt.Println(err) channel.Close() conn.Close() if err != io.EOF { From 269196c94f3ca7b2d8fb9efe001295bc161d6948 Mon Sep 17 00:00:00 2001 From: Jose Garcia Date: Wed, 18 Aug 2021 15:12:47 +0000 Subject: [PATCH 24/68] support existing connections for port forwarding --- port_forwarder.go | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/port_forwarder.go b/port_forwarder.go index 06c164e8d..e6eedf16c 100644 --- a/port_forwarder.go +++ b/port_forwarder.go @@ -54,7 +54,12 @@ func (l *PortForwarder) Start(ctx context.Context) error { return nil } -func (l *PortForwarder) handleConnection(ctx context.Context, conn net.Conn) { +func (l *PortForwarder) StartWithConn(ctx context.Context, conn io.ReadWriteCloser) error { + go l.handleConnection(ctx, conn) + return <-l.errCh +} + +func (l *PortForwarder) handleConnection(ctx context.Context, conn io.ReadWriteCloser) { channel, err := l.client.openStreamingChannel(ctx, l.server.streamName, l.server.streamCondition) if err != nil { l.errCh <- fmt.Errorf("error opening streaming channel for new connection: %v", err) From a89c17a564b9a9a9bbee261d3b4a158c4efff36d Mon Sep 17 00:00:00 2001 From: Edmundo Gonzalez <51725820+edgonmsft@users.noreply.github.com> Date: Thu, 26 Aug 2021 22:34:56 +0000 Subject: [PATCH 25/68] Adding sshRPC interface --- sshRpc.go | 34 ++++++++++++++++++++++++++++++++++ 1 file changed, 34 insertions(+) create mode 100644 sshRpc.go diff --git a/sshRpc.go b/sshRpc.go new file mode 100644 index 000000000..78ac90f82 --- /dev/null +++ b/sshRpc.go @@ -0,0 +1,34 @@ +package liveshare + +import ( + "context" + "errors" +) + +type SshRpc struct { + client *Client +} + +func NewSSHRpc(client *Client) (*SshRpc, error) { + if !client.hasJoined() { + return nil, errors.New("client must join before creating server") + } + return &SshRpc{client: client}, nil +} + +type SshServerStartResult struct { + Result bool `json:"result"` + ServerPort string `json:"serverPort"` + User string `json:"user"` + Message string `json:"message"` +} + +func (s *SshRpc) StartRemoteServer(ctx context.Context) (SshServerStartResult, error) { + var response SshServerStartResult + + if err := s.client.rpc.do(ctx, "ISshServerHostService.startRemoteServer", []string{}, &response); err != nil { + return response, err + } + + return response, nil +} From 18ab421b0846e7a152d431646b12702f14b26ba9 Mon Sep 17 00:00:00 2001 From: Edmundo Gonzalez <51725820+edgonmsft@users.noreply.github.com> Date: Thu, 26 Aug 2021 23:04:05 +0000 Subject: [PATCH 26/68] Rename to SSHServer --- sshRpc.go | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/sshRpc.go b/sshRpc.go index 78ac90f82..ec7d8dfd1 100644 --- a/sshRpc.go +++ b/sshRpc.go @@ -5,26 +5,26 @@ import ( "errors" ) -type SshRpc struct { +type SSHServer struct { client *Client } -func NewSSHRpc(client *Client) (*SshRpc, error) { +func NewSSHServer(client *Client) (*SSHServer, error) { if !client.hasJoined() { return nil, errors.New("client must join before creating server") } - return &SshRpc{client: client}, nil + return &SSHServer{client: client}, nil } -type SshServerStartResult struct { +type SSHServerStartResult struct { Result bool `json:"result"` ServerPort string `json:"serverPort"` User string `json:"user"` Message string `json:"message"` } -func (s *SshRpc) StartRemoteServer(ctx context.Context) (SshServerStartResult, error) { - var response SshServerStartResult +func (s *SSHServer) StartRemoteServer(ctx context.Context) (SSHServerStartResult, error) { + var response SSHServerStartResult if err := s.client.rpc.do(ctx, "ISshServerHostService.startRemoteServer", []string{}, &response); err != nil { return response, err From 0eb769d608552e54f6db6bdb53e70996893a6acb Mon Sep 17 00:00:00 2001 From: Edmundo Gonzalez <51725820+edgonmsft@users.noreply.github.com> Date: Thu, 26 Aug 2021 23:04:35 +0000 Subject: [PATCH 27/68] Rename File --- sshRpc.go => sshServer.go | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename sshRpc.go => sshServer.go (100%) diff --git a/sshRpc.go b/sshServer.go similarity index 100% rename from sshRpc.go rename to sshServer.go From 273782bcbcb06bb143d28f322fcc1e935e378737 Mon Sep 17 00:00:00 2001 From: Jose Garcia Date: Fri, 27 Aug 2021 11:49:21 +0000 Subject: [PATCH 28/68] rename file --- sshServer.go => ssh_server.go | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename sshServer.go => ssh_server.go (100%) diff --git a/sshServer.go b/ssh_server.go similarity index 100% rename from sshServer.go rename to ssh_server.go From 4af240d87da018b38c4765f31ece31b7aa2c8478 Mon Sep 17 00:00:00 2001 From: Alan Donovan Date: Mon, 30 Aug 2021 17:36:28 -0400 Subject: [PATCH 29/68] handle errors in port forwarding --- port_forwarder.go | 113 +++++++++++++++++++++++++++-------------- port_forwarder_test.go | 5 +- 2 files changed, 77 insertions(+), 41 deletions(-) diff --git a/port_forwarder.go b/port_forwarder.go index e6eedf16c..cc7b6ea1d 100644 --- a/port_forwarder.go +++ b/port_forwarder.go @@ -5,77 +5,116 @@ import ( "fmt" "io" "net" - "strconv" ) -// A PortForwader can forward ports from a remote liveshare host to localhost +// A PortForwarder forwards TCP traffic between a port on a remote +// LiveShare host and a local port. type PortForwarder struct { client *Client server *Server port int - errCh chan error } -// NewPortForwarder creates a new PortForwader with a given client, server and port +// NewPortForwarder creates a new PortForwarder that connects a given client, server and port. func NewPortForwarder(client *Client, server *Server, port int) *PortForwarder { return &PortForwarder{ client: client, server: server, port: port, - errCh: make(chan error), } } -// Start is a method to start forwarding the server to a localhost port -func (l *PortForwarder) Start(ctx context.Context) error { - ln, err := net.Listen("tcp", ":"+strconv.Itoa(l.port)) +// Forward enables port forwarding. It accepts and handles TCP +// connections until it encounters the first error, which may include +// context cancellation. Its result is non-nil. +func (l *PortForwarder) Forward(ctx context.Context) (err error) { + listen, err := net.Listen("tcp", fmt.Sprintf(":%d", l.port)) if err != nil { - return fmt.Errorf("error listening on tcp port: %v", err) + return fmt.Errorf("error listening on TCP port: %v", err) } + defer safeClose(listen, &err) + errc := make(chan error, 1) + sendError := func(err error) { + // Use non-blocking send, to avoid goroutines getting + // stuck in case of concurrent or sequential errors. + select { + case errc <- err: + default: + } + } go func() { for { - conn, err := ln.Accept() + conn, err := listen.Accept() if err != nil { - l.errCh <- fmt.Errorf("error accepting incoming connection: %v", err) + sendError(err) + return } - go l.handleConnection(ctx, conn) + go func() { + if err := l.handleConnection(ctx, conn); err != nil { + sendError(err) + } + }() } }() + return awaitError(ctx, errc) +} + +// ForwardWithConn handles port forwarding for a single connection. +func (l *PortForwarder) ForwardWithConn(ctx context.Context, conn io.ReadWriteCloser) error { + // Create buffered channel so that send doesn't get stuck after context cancellation. + errc := make(chan error, 1) + go func() { + if err := l.handleConnection(ctx, conn); err != nil { + errc <- err + } + }() + return awaitError(ctx, errc) +} + +func awaitError(ctx context.Context, errc <-chan error) error { select { - case err := <-l.errCh: + case err := <-errc: return err case <-ctx.Done(): - return ln.Close() + return ctx.Err() // canceled } +} +// handleConnection handles forwarding for a single accepted connection, then closes it. +func (l *PortForwarder) handleConnection(ctx context.Context, conn io.ReadWriteCloser) (err error) { + defer safeClose(conn, &err) + + channel, err := l.client.openStreamingChannel(ctx, l.server.streamName, l.server.streamCondition) + if err != nil { + return fmt.Errorf("error opening streaming channel for new connection: %v", err) + } + defer safeClose(channel, &err) + + errs := make(chan error, 2) + copyConn := func(w io.Writer, r io.Reader) { + _, err := io.Copy(w, r) + errs <- err + } + go copyConn(conn, channel) + go copyConn(channel, conn) + + // await result + for i := 0; i < 2; i++ { + if err := <-errs; err != nil && err != io.EOF { + return fmt.Errorf("tunnel connection: %v", err) + } + } return nil } -func (l *PortForwarder) StartWithConn(ctx context.Context, conn io.ReadWriteCloser) error { - go l.handleConnection(ctx, conn) - return <-l.errCh -} - -func (l *PortForwarder) handleConnection(ctx context.Context, conn io.ReadWriteCloser) { - channel, err := l.client.openStreamingChannel(ctx, l.server.streamName, l.server.streamCondition) - if err != nil { - l.errCh <- fmt.Errorf("error opening streaming channel for new connection: %v", err) - return +// safeClose reports the error (to *err) from closing the stream only +// if no other error was previously reported. +func safeClose(closer io.Closer, err *error) { + closeErr := closer.Close() + if *err == nil { + *err = closeErr } - - copyConn := func(writer io.Writer, reader io.Reader) { - if _, err := io.Copy(writer, reader); err != nil { - channel.Close() - conn.Close() - if err != io.EOF { - l.errCh <- fmt.Errorf("tunnel connection: %v", err) - } - } - } - - go copyConn(conn, channel) - go copyConn(channel, conn) } diff --git a/port_forwarder_test.go b/port_forwarder_test.go index 33a33b39b..3ae846937 100644 --- a/port_forwarder_test.go +++ b/port_forwarder_test.go @@ -63,10 +63,7 @@ func TestPortForwarderStart(t *testing.T) { if err := server.StartSharing(ctx, "http", 8000); err != nil { done <- fmt.Errorf("start sharing: %v", err) } - if err := pf.Start(ctx); err != nil { - done <- err - } - done <- nil + done <- pf.Forward(ctx) }() go func() { From b63972b62f2564dad26a922d346108c4f7683953 Mon Sep 17 00:00:00 2001 From: Alan Donovan Date: Tue, 31 Aug 2021 11:07:26 -0400 Subject: [PATCH 30/68] spell Live Share product name correctly in UI --- client.go | 2 +- client_test.go | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/client.go b/client.go index a8a1e3864..fe890f0bf 100644 --- a/client.go +++ b/client.go @@ -69,7 +69,7 @@ func (c *Client) Join(ctx context.Context) (err error) { _, err = c.joinWorkspace(ctx) if err != nil { - return fmt.Errorf("error joining liveshare workspace: %v", err) + return fmt.Errorf("error joining Live Share workspace: %v", err) } return nil diff --git a/client_test.go b/client_test.go index 110c7e3b9..f1591ed51 100644 --- a/client_test.go +++ b/client_test.go @@ -75,7 +75,7 @@ func TestClientJoin(t *testing.T) { livesharetest.WithRelaySAS(connection.RelaySAS), ) if err != nil { - t.Errorf("error creating liveshare server: %v", err) + t.Errorf("error creating Live Share server: %v", err) } defer server.Close() connection.RelayEndpoint = "sb" + strings.TrimPrefix(server.URL(), "https") From 55fa17d8bc3055ddd143ac0b4e70f8513c01ef70 Mon Sep 17 00:00:00 2001 From: Alan Donovan Date: Tue, 31 Aug 2021 17:30:40 -0400 Subject: [PATCH 31/68] wip --- client.go | 2 +- port_forwarder.go | 3 +-- port_forwarder_test.go | 3 ++- rpc.go | 26 +++++++++++++------------- rpc_test.go | 3 ++- 5 files changed, 19 insertions(+), 18 deletions(-) diff --git a/client.go b/client.go index a8a1e3864..628d557b5 100644 --- a/client.go +++ b/client.go @@ -64,7 +64,7 @@ func (c *Client) Join(ctx context.Context) (err error) { return fmt.Errorf("error connecting to ssh session: %v", err) } - c.rpc = newRpcClient(c.ssh) + c.rpc = newRPCClient(c.ssh) c.rpc.connect(ctx) _, err = c.joinWorkspace(ctx) diff --git a/port_forwarder.go b/port_forwarder.go index e6eedf16c..774fec863 100644 --- a/port_forwarder.go +++ b/port_forwarder.go @@ -48,10 +48,9 @@ func (l *PortForwarder) Start(ctx context.Context) error { case err := <-l.errCh: return err case <-ctx.Done(): + // TODO ctx.Error? return ln.Close() } - - return nil } func (l *PortForwarder) StartWithConn(ctx context.Context, conn io.ReadWriteCloser) error { diff --git a/port_forwarder_test.go b/port_forwarder_test.go index 33a33b39b..a3621c075 100644 --- a/port_forwarder_test.go +++ b/port_forwarder_test.go @@ -55,7 +55,8 @@ func TestPortForwarderStart(t *testing.T) { t.Errorf("create new server: %v", err) } - ctx, _ := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() pf := NewPortForwarder(client, server, 8000) done := make(chan error) diff --git a/rpc.go b/rpc.go index 8abd0e98f..3fea63b10 100644 --- a/rpc.go +++ b/rpc.go @@ -15,7 +15,7 @@ type rpcClient struct { handler *rpcHandler } -func newRpcClient(conn io.ReadWriteCloser) *rpcClient { +func newRPCClient(conn io.ReadWriteCloser) *rpcClient { return &rpcClient{conn: conn, handler: newRPCHandler()} } @@ -24,17 +24,17 @@ func (r *rpcClient) connect(ctx context.Context) { r.Conn = jsonrpc2.NewConn(ctx, stream, r.handler) } -func (r *rpcClient) do(ctx context.Context, method string, args interface{}, result interface{}) error { +func (r *rpcClient) do(ctx context.Context, method string, args, result interface{}) error { waiter, err := r.Conn.DispatchCall(ctx, method, args) if err != nil { - return fmt.Errorf("error on dispatch call: %v", err) + return fmt.Errorf("error dispatching %q call: %v", method, err) } return waiter.Wait(ctx, result) } type rpcHandler struct { - mutex sync.RWMutex + mutex sync.Mutex eventHandlers map[string][]chan *jsonrpc2.Request } @@ -44,34 +44,34 @@ func newRPCHandler() *rpcHandler { } } +// TODO: document obligations around chan. It appears to be used for at most one request. func (r *rpcHandler) registerEventHandler(eventMethod string) <-chan *jsonrpc2.Request { r.mutex.Lock() defer r.mutex.Unlock() ch := make(chan *jsonrpc2.Request) - if _, ok := r.eventHandlers[eventMethod]; !ok { - r.eventHandlers[eventMethod] = []chan *jsonrpc2.Request{ch} - } else { - r.eventHandlers[eventMethod] = append(r.eventHandlers[eventMethod], ch) - } + r.eventHandlers[eventMethod] = append(r.eventHandlers[eventMethod], ch) return ch } func (r *rpcHandler) Handle(ctx context.Context, conn *jsonrpc2.Conn, req *jsonrpc2.Request) { r.mutex.Lock() - defer r.mutex.Unlock() + handlers := r.eventHandlers[req.Method] + r.eventHandlers[req.Method] = nil + r.mutex.Unlock() - if handlers, ok := r.eventHandlers[req.Method]; ok { + if len(handlers) > 0 { go func() { + // Broadcast the request to each handler in sequence. + // TODO rethink this. needs function call. for _, handler := range handlers { select { case handler <- req: case <-ctx.Done(): + // TODO: ctx.Err break } } - - r.eventHandlers[req.Method] = []chan *jsonrpc2.Request{} }() } } diff --git a/rpc_test.go b/rpc_test.go index d16b32a4f..7543152d1 100644 --- a/rpc_test.go +++ b/rpc_test.go @@ -15,7 +15,8 @@ func TestRPCHandlerEvents(t *testing.T) { time.Sleep(1 * time.Second) rpcHandler.Handle(context.Background(), nil, &jsonrpc2.Request{Method: "somethingHappened"}) }() - ctx, _ := context.WithDeadline(context.Background(), time.Now().Add(5*time.Second)) + ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(5*time.Second)) + defer cancel() select { case event := <-eventCh: if event.Method != "somethingHappened" { From af38292f1e0a80e0ef6d996f0d67aee0452c7232 Mon Sep 17 00:00:00 2001 From: Alan Donovan Date: Wed, 1 Sep 2021 18:12:23 -0400 Subject: [PATCH 32/68] fix data races --- rpc.go | 47 +++++++++++++++++++---------------------------- rpc_test.go | 5 ++++- terminal.go | 15 +++++++++++---- 3 files changed, 34 insertions(+), 33 deletions(-) diff --git a/rpc.go b/rpc.go index 3fea63b10..d1e020d17 100644 --- a/rpc.go +++ b/rpc.go @@ -33,45 +33,36 @@ func (r *rpcClient) do(ctx context.Context, method string, args, result interfac return waiter.Wait(ctx, result) } +type rpcHandlerFunc = func(*jsonrpc2.Request) + type rpcHandler struct { - mutex sync.Mutex - eventHandlers map[string][]chan *jsonrpc2.Request + handlersMu sync.Mutex + handlers map[string][]rpcHandlerFunc } func newRPCHandler() *rpcHandler { return &rpcHandler{ - eventHandlers: make(map[string][]chan *jsonrpc2.Request), + handlers: make(map[string][]rpcHandlerFunc), } } -// TODO: document obligations around chan. It appears to be used for at most one request. -func (r *rpcHandler) registerEventHandler(eventMethod string) <-chan *jsonrpc2.Request { - r.mutex.Lock() - defer r.mutex.Unlock() - - ch := make(chan *jsonrpc2.Request) - r.eventHandlers[eventMethod] = append(r.eventHandlers[eventMethod], ch) - return ch +// registerEventHandler registers a handler for the specified event. +// After the next occurrence of the event, the handler will be called, +// once, in its own goroutine. +func (r *rpcHandler) registerEventHandler(eventMethod string, h rpcHandlerFunc) { + r.handlersMu.Lock() + r.handlers[eventMethod] = append(r.handlers[eventMethod], h) + r.handlersMu.Unlock() } +// Handle calls all registered handlers for the request, concurrently, each in its own goroutine. func (r *rpcHandler) Handle(ctx context.Context, conn *jsonrpc2.Conn, req *jsonrpc2.Request) { - r.mutex.Lock() - handlers := r.eventHandlers[req.Method] - r.eventHandlers[req.Method] = nil - r.mutex.Unlock() + r.handlersMu.Lock() + handlers := r.handlers[req.Method] + r.handlers[req.Method] = nil + r.handlersMu.Unlock() - if len(handlers) > 0 { - go func() { - // Broadcast the request to each handler in sequence. - // TODO rethink this. needs function call. - for _, handler := range handlers { - select { - case handler <- req: - case <-ctx.Done(): - // TODO: ctx.Err - break - } - } - }() + for _, h := range handlers { + go h(req) } } diff --git a/rpc_test.go b/rpc_test.go index 7543152d1..cf9c4cf81 100644 --- a/rpc_test.go +++ b/rpc_test.go @@ -10,7 +10,10 @@ import ( func TestRPCHandlerEvents(t *testing.T) { rpcHandler := newRPCHandler() - eventCh := rpcHandler.registerEventHandler("somethingHappened") + eventCh := make(chan *jsonrpc2.Request) + rpcHandler.registerEventHandler("somethingHappened", func(req *jsonrpc2.Request) { + eventCh <- req + }) go func() { time.Sleep(1 * time.Second) rpcHandler.Handle(context.Background(), nil, &jsonrpc2.Request{Method: "somethingHappened"}) diff --git a/terminal.go b/terminal.go index c26d9fd9f..32dd54248 100644 --- a/terminal.go +++ b/terminal.go @@ -6,6 +6,7 @@ import ( "fmt" "io" + "github.com/sourcegraph/jsonrpc2" "golang.org/x/crypto/ssh" ) @@ -71,12 +72,15 @@ func (t TerminalCommand) Run(ctx context.Context) (io.ReadCloser, error) { ReadOnlyForGuests: false, } - terminalStarted := t.terminal.client.rpc.handler.registerEventHandler("terminal.terminalStarted") + started := make(chan struct{}) + t.terminal.client.rpc.handler.registerEventHandler("terminal.terminalStarted", func(*jsonrpc2.Request) { + close(started) + }) var result startTerminalResult if err := t.terminal.client.rpc.do(ctx, "terminal.startTerminal", &args, &result); err != nil { return nil, fmt.Errorf("error making terminal.startTerminal call: %v", err) } - <-terminalStarted + <-started channel, err := t.terminal.client.openStreamingChannel(ctx, result.StreamName, result.StreamCondition) if err != nil { @@ -101,7 +105,10 @@ func (t terminalReadCloser) Read(b []byte) (int, error) { } func (t terminalReadCloser) Close() error { - terminalStopped := t.terminalCommand.terminal.client.rpc.handler.registerEventHandler("terminal.terminalStopped") + stopped := make(chan struct{}) + t.terminalCommand.terminal.client.rpc.handler.registerEventHandler("terminal.terminalStopped", func(*jsonrpc2.Request) { + close(stopped) + }) if err := t.terminalCommand.terminal.client.rpc.do(context.Background(), "terminal.stopTerminal", []int{t.terminalID}, nil); err != nil { return fmt.Errorf("error making terminal.stopTerminal call: %v", err) } @@ -110,7 +117,7 @@ func (t terminalReadCloser) Close() error { return fmt.Errorf("error closing channel: %v", err) } - <-terminalStopped + <-stopped return nil } From 4cceda1af02e3a097418baadd14048d331780c50 Mon Sep 17 00:00:00 2001 From: Alan Donovan Date: Thu, 2 Sep 2021 11:06:49 -0400 Subject: [PATCH 33/68] rename Server to Session and simplify API --- client.go | 48 +++++++++------------ client_test.go | 8 ++-- port_forwarder.go | 19 ++++----- port_forwarder_test.go | 21 +++------- rpc.go | 1 + server.go => session.go | 34 +++++---------- server_test.go => session_test.go | 70 +++++++------------------------ ssh.go | 2 +- ssh_server.go | 18 ++++---- terminal.go | 23 ++++------ 10 files changed, 82 insertions(+), 162 deletions(-) rename server.go => session.go (58%) rename server_test.go => session_test.go (74%) diff --git a/client.go b/client.go index fe890f0bf..140db3b02 100644 --- a/client.go +++ b/client.go @@ -8,13 +8,10 @@ import ( "golang.org/x/crypto/ssh" ) -// A Client capable of joining a liveshare connection +// A Client capable of joining a Live Share workspace. type Client struct { connection Connection tlsConfig *tls.Config - - ssh *sshSession - rpc *rpcClient } // A ClientOption is a function that modifies a client @@ -52,31 +49,26 @@ func WithTLSConfig(tlsConfig *tls.Config) ClientOption { } } -// Join is a method that joins the client to the liveshare session -func (c *Client) Join(ctx context.Context) (err error) { +// JoinWorkspace connects the client to the server's Live Share +// workspace and returns a session representing their connection. +func (c *Client) JoinWorkspace(ctx context.Context) (*Session, error) { clientSocket := newSocket(c.connection, c.tlsConfig) if err := clientSocket.connect(ctx); err != nil { - return fmt.Errorf("error connecting websocket: %v", err) + return nil, fmt.Errorf("error connecting websocket: %v", err) } - c.ssh = newSshSession(c.connection.SessionToken, clientSocket) - if err := c.ssh.connect(ctx); err != nil { - return fmt.Errorf("error connecting to ssh session: %v", err) + ssh := newSSHSession(c.connection.SessionToken, clientSocket) + if err := ssh.connect(ctx); err != nil { + return nil, fmt.Errorf("error connecting to ssh session: %v", err) } - c.rpc = newRpcClient(c.ssh) - c.rpc.connect(ctx) - - _, err = c.joinWorkspace(ctx) - if err != nil { - return fmt.Errorf("error joining Live Share workspace: %v", err) + rpc := newRpcClient(ssh) + rpc.connect(ctx) + if _, err := c.joinWorkspace(ctx, rpc); err != nil { + return nil, fmt.Errorf("error joining Live Share workspace: %v", err) } - return nil -} - -func (c *Client) hasJoined() bool { - return c.ssh != nil && c.rpc != nil + return &Session{ssh: ssh, rpc: rpc}, nil } type clientCapabilities struct { @@ -94,32 +86,32 @@ type joinWorkspaceResult struct { SessionNumber int `json:"sessionNumber"` } -func (c *Client) joinWorkspace(ctx context.Context) (*joinWorkspaceResult, error) { +func (client *Client) joinWorkspace(ctx context.Context, rpc *rpcClient) (*joinWorkspaceResult, error) { args := joinWorkspaceArgs{ - ID: c.connection.SessionID, + ID: client.connection.SessionID, ConnectionMode: "local", - JoiningUserSessionToken: c.connection.SessionToken, + JoiningUserSessionToken: client.connection.SessionToken, ClientCapabilities: clientCapabilities{ IsNonInteractive: false, }, } var result joinWorkspaceResult - if err := c.rpc.do(ctx, "workspace.joinWorkspace", &args, &result); err != nil { + if err := rpc.do(ctx, "workspace.joinWorkspace", &args, &result); err != nil { return nil, fmt.Errorf("error making workspace.joinWorkspace call: %v", err) } return &result, nil } -func (c *Client) openStreamingChannel(ctx context.Context, streamName, condition string) (ssh.Channel, error) { +func (session *Session) openStreamingChannel(ctx context.Context, streamName, condition string) (ssh.Channel, error) { args := getStreamArgs{streamName, condition} var streamID string - if err := c.rpc.do(ctx, "streamManager.getStream", args, &streamID); err != nil { + if err := session.rpc.do(ctx, "streamManager.getStream", args, &streamID); err != nil { return nil, fmt.Errorf("error getting stream id: %v", err) } - channel, reqs, err := c.ssh.conn.OpenChannel("session", nil) + channel, reqs, err := session.ssh.conn.OpenChannel("session", nil) if err != nil { return nil, fmt.Errorf("error opening ssh channel for transport: %v", err) } diff --git a/client_test.go b/client_test.go index f1591ed51..c1e61f6e8 100644 --- a/client_test.go +++ b/client_test.go @@ -43,7 +43,7 @@ func TestNewClientWithInvalidConnection(t *testing.T) { } } -func TestClientJoin(t *testing.T) { +func TestJoinSession(t *testing.T) { connection := Connection{ SessionID: "session-id", SessionToken: "session-token", @@ -90,10 +90,12 @@ func TestClientJoin(t *testing.T) { done := make(chan error) go func() { - if err := client.Join(ctx); err != nil { - done <- fmt.Errorf("error joining client: %v", err) + session, err := client.JoinWorkspace(ctx) + if err != nil { + done <- fmt.Errorf("error joining workspace: %v", err) return } + _ = session done <- nil }() diff --git a/port_forwarder.go b/port_forwarder.go index cc7b6ea1d..29dee58f9 100644 --- a/port_forwarder.go +++ b/port_forwarder.go @@ -7,20 +7,17 @@ import ( "net" ) -// A PortForwarder forwards TCP traffic between a port on a remote -// LiveShare host and a local port. +// A PortForwarder forwards TCP traffic between a local TCP port and a LiveShare session. type PortForwarder struct { - client *Client - server *Server - port int + session *Session + port int } -// NewPortForwarder creates a new PortForwarder that connects a given client, server and port. -func NewPortForwarder(client *Client, server *Server, port int) *PortForwarder { +// NewPortForwarder creates a new PortForwarder for a given Live Share session and local TCP port. +func NewPortForwarder(session *Session, port int) *PortForwarder { return &PortForwarder{ - client: client, - server: server, - port: port, + session: session, + port: port, } } @@ -87,7 +84,7 @@ func awaitError(ctx context.Context, errc <-chan error) error { func (l *PortForwarder) handleConnection(ctx context.Context, conn io.ReadWriteCloser) (err error) { defer safeClose(conn, &err) - channel, err := l.client.openStreamingChannel(ctx, l.server.streamName, l.server.streamCondition) + channel, err := l.session.openStreamingChannel(ctx, l.session.streamName, l.session.streamCondition) if err != nil { return fmt.Errorf("error opening streaming channel for new connection: %v", err) } diff --git a/port_forwarder_test.go b/port_forwarder_test.go index 3ae846937..6af5c7e70 100644 --- a/port_forwarder_test.go +++ b/port_forwarder_test.go @@ -15,16 +15,12 @@ import ( ) func TestNewPortForwarder(t *testing.T) { - testServer, client, err := makeMockJoinedClient() + testServer, session, err := makeMockSession() if err != nil { t.Errorf("create mock client: %v", err) } defer testServer.Close() - server, err := NewServer(client) - if err != nil { - t.Errorf("create new server: %v", err) - } - pf := NewPortForwarder(client, server, 80) + pf := NewPortForwarder(session, 80) if pf == nil { t.Error("port forwarder is nil") } @@ -40,27 +36,22 @@ func TestPortForwarderStart(t *testing.T) { } stream := bytes.NewBufferString("stream-data") - testServer, client, err := makeMockJoinedClient( + testServer, session, err := makeMockSession( livesharetest.WithService("serverSharing.startSharing", serverSharing), livesharetest.WithService("streamManager.getStream", getStream), livesharetest.WithStream("stream-id", stream), ) if err != nil { - t.Errorf("create mock client: %v", err) + t.Errorf("create mock session: %v", err) } defer testServer.Close() - server, err := NewServer(client) - if err != nil { - t.Errorf("create new server: %v", err) - } - ctx, _ := context.WithCancel(context.Background()) - pf := NewPortForwarder(client, server, 8000) + pf := NewPortForwarder(session, 8000) done := make(chan error) go func() { - if err := server.StartSharing(ctx, "http", 8000); err != nil { + if err := session.StartSharing(ctx, "http", 8000); err != nil { done <- fmt.Errorf("start sharing: %v", err) } done <- pf.Forward(ctx) diff --git a/rpc.go b/rpc.go index 8abd0e98f..c58ab419d 100644 --- a/rpc.go +++ b/rpc.go @@ -21,6 +21,7 @@ func newRpcClient(conn io.ReadWriteCloser) *rpcClient { func (r *rpcClient) connect(ctx context.Context) { stream := jsonrpc2.NewBufferedStream(r.conn, jsonrpc2.VSCodeObjectCodec{}) + // TODO(adonovan): fix: ensure r.Conn is eventually Closed! r.Conn = jsonrpc2.NewConn(ctx, stream, r.handler) } diff --git a/server.go b/session.go similarity index 58% rename from server.go rename to session.go index 7e8c8b1cb..b1a175df3 100644 --- a/server.go +++ b/session.go @@ -2,27 +2,18 @@ package liveshare import ( "context" - "errors" "fmt" "strconv" ) -// A Server represents the liveshare host and container server -type Server struct { - client *Client +// A Session represents the session between a connected Live Share client and server. +type Session struct { + ssh *sshSession + rpc *rpcClient port int streamName, streamCondition string } -// NewServer creates a new Server with a given Client -func NewServer(client *Client) (*Server, error) { - if !client.hasJoined() { - return nil, errors.New("client must join before creating server") - } - - return &Server{client: client}, nil -} - // Port represents an open port on the container type Port struct { SourcePort int `json:"sourcePort"` @@ -37,11 +28,11 @@ type Port struct { } // StartSharing tells the liveshare host to start sharing the port from the container -func (s *Server) StartSharing(ctx context.Context, protocol string, port int) error { +func (s *Session) StartSharing(ctx context.Context, protocol string, port int) error { s.port = port var response Port - if err := s.client.rpc.do(ctx, "serverSharing.startSharing", []interface{}{ + if err := s.rpc.do(ctx, "serverSharing.startSharing", []interface{}{ port, protocol, fmt.Sprintf("http://localhost:%s", strconv.Itoa(port)), }, &response); err != nil { return err @@ -53,13 +44,10 @@ func (s *Server) StartSharing(ctx context.Context, protocol string, port int) er return nil } -// Ports is a slice of Port pointers -type Ports []*Port - // GetSharedServers returns a list of available/open ports from the container -func (s *Server) GetSharedServers(ctx context.Context) (Ports, error) { - var response Ports - if err := s.client.rpc.do(ctx, "serverSharing.getSharedServers", []string{}, &response); err != nil { +func (s *Session) GetSharedServers(ctx context.Context) ([]*Port, error) { + var response []*Port + if err := s.rpc.do(ctx, "serverSharing.getSharedServers", []string{}, &response); err != nil { return nil, err } @@ -68,8 +56,8 @@ func (s *Server) GetSharedServers(ctx context.Context) (Ports, error) { // UpdateSharedVisibility controls port permissions and whether it can be accessed publicly // via the Browse URL -func (s *Server) UpdateSharedVisibility(ctx context.Context, port int, public bool) error { - if err := s.client.rpc.do(ctx, "serverSharing.updateSharedServerVisibility", []interface{}{port, public}, nil); err != nil { +func (s *Session) UpdateSharedVisibility(ctx context.Context, port int, public bool) error { + if err := s.rpc.do(ctx, "serverSharing.updateSharedServerVisibility", []interface{}{port, public}, nil); err != nil { return err } diff --git a/server_test.go b/session_test.go similarity index 74% rename from server_test.go rename to session_test.go index b91fbfddc..005eacfbd 100644 --- a/server_test.go +++ b/session_test.go @@ -13,17 +13,7 @@ import ( "github.com/sourcegraph/jsonrpc2" ) -func TestNewServerWithNotJoinedClient(t *testing.T) { - client, err := NewClient() - if err != nil { - t.Errorf("error creating new client: %v", err) - } - if _, err := NewServer(client); err == nil { - t.Error("expected error") - } -} - -func makeMockJoinedClient(opts ...livesharetest.ServerOption) (*livesharetest.Server, *Client, error) { +func makeMockSession(opts ...livesharetest.ServerOption) (*livesharetest.Server, *Session, error) { connection := Connection{ SessionID: "session-id", SessionToken: "session-token", @@ -47,25 +37,11 @@ func makeMockJoinedClient(opts ...livesharetest.ServerOption) (*livesharetest.Se return nil, nil, fmt.Errorf("error creating new client: %v", err) } ctx := context.Background() - if err := client.Join(ctx); err != nil { - return nil, nil, fmt.Errorf("error joining client: %v", err) - } - return testServer, client, nil -} - -func TestNewServer(t *testing.T) { - testServer, client, err := makeMockJoinedClient() - defer testServer.Close() + session, err := client.JoinWorkspace(ctx) if err != nil { - t.Errorf("error creating mock joined client: %v", err) - } - server, err := NewServer(client) - if err != nil { - t.Errorf("error creating new server: %v", err) - } - if server == nil { - t.Error("server is nil") + return nil, nil, fmt.Errorf("error joining workspace: %v", err) } + return testServer, session, nil } func TestServerStartSharing(t *testing.T) { @@ -95,25 +71,21 @@ func TestServerStartSharing(t *testing.T) { } return Port{StreamName: "stream-name", StreamCondition: "stream-condition"}, nil } - testServer, client, err := makeMockJoinedClient( + testServer, session, err := makeMockSession( livesharetest.WithService("serverSharing.startSharing", startSharing), ) defer testServer.Close() if err != nil { - t.Errorf("error creating mock joined client: %v", err) - } - server, err := NewServer(client) - if err != nil { - t.Errorf("error creating new server: %v", err) + t.Errorf("error creating mock session: %v", err) } ctx := context.Background() done := make(chan error) go func() { - if err := server.StartSharing(ctx, serverProtocol, serverPort); err != nil { + if err := session.StartSharing(ctx, serverProtocol, serverPort); err != nil { done <- fmt.Errorf("error sharing server: %v", err) } - if server.streamName == "" || server.streamCondition == "" { + if session.streamName == "" || session.streamCondition == "" { done <- errors.New("stream name or condition is blank") } done <- nil @@ -136,23 +108,19 @@ func TestServerGetSharedServers(t *testing.T) { StreamCondition: "stream-condition", } getSharedServers := func(req *jsonrpc2.Request) (interface{}, error) { - return Ports{&sharedServer}, nil + return []*Port{&sharedServer}, nil } - testServer, client, err := makeMockJoinedClient( + testServer, session, err := makeMockSession( livesharetest.WithService("serverSharing.getSharedServers", getSharedServers), ) if err != nil { - t.Errorf("error creating new mock client: %v", err) + t.Errorf("error creating mock session: %v", err) } defer testServer.Close() - server, err := NewServer(client) - if err != nil { - t.Errorf("error creating new server: %v", err) - } ctx := context.Background() done := make(chan error) go func() { - ports, err := server.GetSharedServers(ctx) + ports, err := session.GetSharedServers(ctx) if err != nil { done <- fmt.Errorf("error getting shared servers: %v", err) } @@ -206,25 +174,17 @@ func TestServerUpdateSharedVisibility(t *testing.T) { } return nil, nil } - testServer, client, err := makeMockJoinedClient( + testServer, session, err := makeMockSession( livesharetest.WithService("serverSharing.updateSharedServerVisibility", updateSharedVisibility), ) if err != nil { - t.Errorf("creating new mock client: %v", err) + t.Errorf("creating mock session: %v", err) } defer testServer.Close() - server, err := NewServer(client) - if err != nil { - t.Errorf("creating server: %v", err) - } ctx := context.Background() done := make(chan error) go func() { - if err := server.UpdateSharedVisibility(ctx, 80, true); err != nil { - done <- err - return - } - done <- nil + done <- session.UpdateSharedVisibility(ctx, 80, true) }() select { case err := <-testServer.Err(): diff --git a/ssh.go b/ssh.go index e22cd69d1..b68d400a1 100644 --- a/ssh.go +++ b/ssh.go @@ -19,7 +19,7 @@ type sshSession struct { writer io.Writer } -func newSshSession(token string, socket net.Conn) *sshSession { +func newSSHSession(token string, socket net.Conn) *sshSession { return &sshSession{token: token, socket: socket} } diff --git a/ssh_server.go b/ssh_server.go index ec7d8dfd1..03b45f25f 100644 --- a/ssh_server.go +++ b/ssh_server.go @@ -2,18 +2,14 @@ package liveshare import ( "context" - "errors" ) type SSHServer struct { - client *Client + session *Session } -func NewSSHServer(client *Client) (*SSHServer, error) { - if !client.hasJoined() { - return nil, errors.New("client must join before creating server") - } - return &SSHServer{client: client}, nil +func (session *Session) SSHServer() *SSHServer { + return &SSHServer{session: session} } type SSHServerStartResult struct { @@ -23,12 +19,12 @@ type SSHServerStartResult struct { Message string `json:"message"` } -func (s *SSHServer) StartRemoteServer(ctx context.Context) (SSHServerStartResult, error) { +func (s *SSHServer) StartRemoteServer(ctx context.Context) (*SSHServerStartResult, error) { var response SSHServerStartResult - if err := s.client.rpc.do(ctx, "ISshServerHostService.startRemoteServer", []string{}, &response); err != nil { - return response, err + if err := s.session.rpc.do(ctx, "ISshServerHostService.startRemoteServer", []string{}, &response); err != nil { + return nil, err } - return response, nil + return &response, nil } diff --git a/terminal.go b/terminal.go index c26d9fd9f..07532f426 100644 --- a/terminal.go +++ b/terminal.go @@ -2,7 +2,6 @@ package liveshare import ( "context" - "errors" "fmt" "io" @@ -10,17 +9,11 @@ import ( ) type Terminal struct { - client *Client + session *Session } -func NewTerminal(client *Client) (*Terminal, error) { - if !client.hasJoined() { - return nil, errors.New("client must join before creating terminal") - } - - return &Terminal{ - client: client, - }, nil +func NewTerminal(session *Session) *Terminal { + return &Terminal{session: session} } type TerminalCommand struct { @@ -71,14 +64,14 @@ func (t TerminalCommand) Run(ctx context.Context) (io.ReadCloser, error) { ReadOnlyForGuests: false, } - terminalStarted := t.terminal.client.rpc.handler.registerEventHandler("terminal.terminalStarted") + terminalStarted := t.terminal.session.rpc.handler.registerEventHandler("terminal.terminalStarted") var result startTerminalResult - if err := t.terminal.client.rpc.do(ctx, "terminal.startTerminal", &args, &result); err != nil { + if err := t.terminal.session.rpc.do(ctx, "terminal.startTerminal", &args, &result); err != nil { return nil, fmt.Errorf("error making terminal.startTerminal call: %v", err) } <-terminalStarted - channel, err := t.terminal.client.openStreamingChannel(ctx, result.StreamName, result.StreamCondition) + channel, err := t.terminal.session.openStreamingChannel(ctx, result.StreamName, result.StreamCondition) if err != nil { return nil, fmt.Errorf("error opening streaming channel: %v", err) } @@ -101,8 +94,8 @@ func (t terminalReadCloser) Read(b []byte) (int, error) { } func (t terminalReadCloser) Close() error { - terminalStopped := t.terminalCommand.terminal.client.rpc.handler.registerEventHandler("terminal.terminalStopped") - if err := t.terminalCommand.terminal.client.rpc.do(context.Background(), "terminal.stopTerminal", []int{t.terminalID}, nil); err != nil { + terminalStopped := t.terminalCommand.terminal.session.rpc.handler.registerEventHandler("terminal.terminalStopped") + if err := t.terminalCommand.terminal.session.rpc.do(context.Background(), "terminal.stopTerminal", []int{t.terminalID}, nil); err != nil { return fmt.Errorf("error making terminal.stopTerminal call: %v", err) } From 05a3d90a99b4f884a492c797f352458519d1252a Mon Sep 17 00:00:00 2001 From: Alan Donovan Date: Thu, 2 Sep 2021 11:39:29 -0400 Subject: [PATCH 34/68] more tweaks --- client.go | 12 ++++++------ port_forwarder_test.go | 3 ++- rpc_test.go | 3 ++- session.go | 5 +++-- 4 files changed, 13 insertions(+), 10 deletions(-) diff --git a/client.go b/client.go index 140db3b02..19b0aff50 100644 --- a/client.go +++ b/client.go @@ -86,11 +86,11 @@ type joinWorkspaceResult struct { SessionNumber int `json:"sessionNumber"` } -func (client *Client) joinWorkspace(ctx context.Context, rpc *rpcClient) (*joinWorkspaceResult, error) { +func (c *Client) joinWorkspace(ctx context.Context, rpc *rpcClient) (*joinWorkspaceResult, error) { args := joinWorkspaceArgs{ - ID: client.connection.SessionID, + ID: c.connection.SessionID, ConnectionMode: "local", - JoiningUserSessionToken: client.connection.SessionToken, + JoiningUserSessionToken: c.connection.SessionToken, ClientCapabilities: clientCapabilities{ IsNonInteractive: false, }, @@ -104,14 +104,14 @@ func (client *Client) joinWorkspace(ctx context.Context, rpc *rpcClient) (*joinW return &result, nil } -func (session *Session) openStreamingChannel(ctx context.Context, streamName, condition string) (ssh.Channel, error) { +func (s *Session) openStreamingChannel(ctx context.Context, streamName, condition string) (ssh.Channel, error) { args := getStreamArgs{streamName, condition} var streamID string - if err := session.rpc.do(ctx, "streamManager.getStream", args, &streamID); err != nil { + if err := s.rpc.do(ctx, "streamManager.getStream", args, &streamID); err != nil { return nil, fmt.Errorf("error getting stream id: %v", err) } - channel, reqs, err := session.ssh.conn.OpenChannel("session", nil) + channel, reqs, err := s.ssh.conn.OpenChannel("session", nil) if err != nil { return nil, fmt.Errorf("error opening ssh channel for transport: %v", err) } diff --git a/port_forwarder_test.go b/port_forwarder_test.go index 6af5c7e70..44ef59fe0 100644 --- a/port_forwarder_test.go +++ b/port_forwarder_test.go @@ -46,7 +46,8 @@ func TestPortForwarderStart(t *testing.T) { } defer testServer.Close() - ctx, _ := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() pf := NewPortForwarder(session, 8000) done := make(chan error) diff --git a/rpc_test.go b/rpc_test.go index d16b32a4f..7543152d1 100644 --- a/rpc_test.go +++ b/rpc_test.go @@ -15,7 +15,8 @@ func TestRPCHandlerEvents(t *testing.T) { time.Sleep(1 * time.Second) rpcHandler.Handle(context.Background(), nil, &jsonrpc2.Request{Method: "somethingHappened"}) }() - ctx, _ := context.WithDeadline(context.Background(), time.Now().Add(5*time.Second)) + ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(5*time.Second)) + defer cancel() select { case event := <-eventCh: if event.Method != "somethingHappened" { diff --git a/session.go b/session.go index b1a175df3..d13bba9f1 100644 --- a/session.go +++ b/session.go @@ -3,7 +3,6 @@ package liveshare import ( "context" "fmt" - "strconv" ) // A Session represents the session between a connected Live Share client and server. @@ -25,6 +24,8 @@ type Port struct { IsPublic bool `json:"isPublic"` IsTCPServerConnectionEstablished bool `json:"isTCPServerConnectionEstablished"` HasTSLHandshakePassed bool `json:"hasTSLHandshakePassed"` + // ^^^ + // TODO(adonovan): fix possible typo in field name, and audit others. } // StartSharing tells the liveshare host to start sharing the port from the container @@ -33,7 +34,7 @@ func (s *Session) StartSharing(ctx context.Context, protocol string, port int) e var response Port if err := s.rpc.do(ctx, "serverSharing.startSharing", []interface{}{ - port, protocol, fmt.Sprintf("http://localhost:%s", strconv.Itoa(port)), + port, protocol, fmt.Sprintf("http://localhost:%d", port), }, &response); err != nil { return err } From 6f45c7fa7dfd4553483d28242df77ca059a174d1 Mon Sep 17 00:00:00 2001 From: Alan Donovan Date: Thu, 2 Sep 2021 12:14:04 -0400 Subject: [PATCH 35/68] point out data races to be fixed --- session.go | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/session.go b/session.go index ed87c6c2c..d57906f26 100644 --- a/session.go +++ b/session.go @@ -7,13 +7,16 @@ import ( // A Session represents the session between a connected Live Share client and server. type Session struct { - ssh *sshSession - rpc *rpcClient - port int + ssh *sshSession + rpc *rpcClient + + // TODO(adonovan): fix: avoid data race of state accessed by + // multiple calls to StartSharing and concurrent calls to + // PortForwarder. Perhaps combine the two operations in the API? streamName, streamCondition string } -// Port represents an open port on the container +// Port describes a port exposed by the container. type Port struct { SourcePort int `json:"sourcePort"` DestinationPort int `json:"destinationPort"` @@ -31,8 +34,6 @@ type Port struct { // StartSharing tells the Live Share host to start sharing the specified port from the container. // The sessionName describes the purpose of the port or service. func (s *Session) StartSharing(ctx context.Context, sessionName string, port int) error { - s.port = port - var response Port if err := s.rpc.do(ctx, "serverSharing.startSharing", []interface{}{ port, sessionName, fmt.Sprintf("http://localhost:%d", port), @@ -46,7 +47,8 @@ func (s *Session) StartSharing(ctx context.Context, sessionName string, port int return nil } -// GetSharedServers returns a list of available/open ports from the container +// GetSharedServers returns a description of each container port +// shared by a prior call to StartSharing by some client. func (s *Session) GetSharedServers(ctx context.Context) ([]*Port, error) { var response []*Port if err := s.rpc.do(ctx, "serverSharing.getSharedServers", []string{}, &response); err != nil { From 87b15aa264e583688aa9b448ea57663b87a2b4cf Mon Sep 17 00:00:00 2001 From: Alan Donovan Date: Thu, 2 Sep 2021 14:03:48 -0400 Subject: [PATCH 36/68] Fix data race in StartSharing --- client.go | 13 +++++++++-- port_forwarder.go | 50 +++++++++++++++++++++++++++++++----------- port_forwarder_test.go | 11 ++++------ session.go | 24 +++++++------------- session_test.go | 5 +++-- terminal.go | 2 +- 6 files changed, 64 insertions(+), 41 deletions(-) diff --git a/client.go b/client.go index 377ec2512..0088662f7 100644 --- a/client.go +++ b/client.go @@ -86,6 +86,12 @@ type joinWorkspaceResult struct { SessionNumber int `json:"sessionNumber"` } +// A channelID is an identifier for an exposed port on a remote +// container that may be used to open an SSH channel to it. +type channelID struct { + name, condition string +} + func (c *Client) joinWorkspace(ctx context.Context, rpc *rpcClient) (*joinWorkspaceResult, error) { args := joinWorkspaceArgs{ ID: c.connection.SessionID, @@ -104,8 +110,11 @@ func (c *Client) joinWorkspace(ctx context.Context, rpc *rpcClient) (*joinWorksp return &result, nil } -func (s *Session) openStreamingChannel(ctx context.Context, streamName, condition string) (ssh.Channel, error) { - args := getStreamArgs{streamName, condition} +func (s *Session) openStreamingChannel(ctx context.Context, id channelID) (ssh.Channel, error) { + args := getStreamArgs{ + StreamName: id.name, + Condition: id.condition, + } var streamID string if err := s.rpc.do(ctx, "streamManager.getStream", args, &streamID); err != nil { return nil, fmt.Errorf("error getting stream id: %v", err) diff --git a/port_forwarder.go b/port_forwarder.go index 29dee58f9..4391ef55c 100644 --- a/port_forwarder.go +++ b/port_forwarder.go @@ -9,23 +9,34 @@ import ( // A PortForwarder forwards TCP traffic between a local TCP port and a LiveShare session. type PortForwarder struct { - session *Session - port int + session *Session + name string + localPort, remotePort int } -// NewPortForwarder creates a new PortForwarder for a given Live Share session and local TCP port. -func NewPortForwarder(session *Session, port int) *PortForwarder { +// NewPortForwarder creates a new PortForwarder that forwards traffic +// between the local port and the container's remote port over the +// specified Live Share session. The name describes the purpose of the +// remote port or service. +func NewPortForwarder(session *Session, name string, localPort, remotePort int) *PortForwarder { return &PortForwarder{ - session: session, - port: port, + session: session, + name: name, + localPort: localPort, + remotePort: remotePort, } } // Forward enables port forwarding. It accepts and handles TCP // connections until it encounters the first error, which may include // context cancellation. Its result is non-nil. -func (l *PortForwarder) Forward(ctx context.Context) (err error) { - listen, err := net.Listen("tcp", fmt.Sprintf(":%d", l.port)) +func (fwd *PortForwarder) Forward(ctx context.Context) (err error) { + id, err := fwd.shareRemotePort(ctx) + if err != nil { + return err + } + + listen, err := net.Listen("tcp", fmt.Sprintf(":%d", fwd.localPort)) if err != nil { return fmt.Errorf("error listening on TCP port: %v", err) } @@ -49,7 +60,7 @@ func (l *PortForwarder) Forward(ctx context.Context) (err error) { } go func() { - if err := l.handleConnection(ctx, conn); err != nil { + if err := fwd.handleConnection(ctx, id, conn); err != nil { sendError(err) } }() @@ -60,17 +71,30 @@ func (l *PortForwarder) Forward(ctx context.Context) (err error) { } // ForwardWithConn handles port forwarding for a single connection. -func (l *PortForwarder) ForwardWithConn(ctx context.Context, conn io.ReadWriteCloser) error { +func (fwd *PortForwarder) ForwardWithConn(ctx context.Context, conn io.ReadWriteCloser) error { + id, err := fwd.shareRemotePort(ctx) + if err != nil { + return err + } + // Create buffered channel so that send doesn't get stuck after context cancellation. errc := make(chan error, 1) go func() { - if err := l.handleConnection(ctx, conn); err != nil { + if err := fwd.handleConnection(ctx, id, conn); err != nil { errc <- err } }() return awaitError(ctx, errc) } +func (fwd *PortForwarder) shareRemotePort(ctx context.Context) (channelID, error) { + id, err := fwd.session.startSharing(ctx, fwd.name, fwd.remotePort) + if err != nil { + err = fmt.Errorf("failed to share remote port %d: %v", fwd.remotePort, err) + } + return id, nil +} + func awaitError(ctx context.Context, errc <-chan error) error { select { case err := <-errc: @@ -81,10 +105,10 @@ func awaitError(ctx context.Context, errc <-chan error) error { } // handleConnection handles forwarding for a single accepted connection, then closes it. -func (l *PortForwarder) handleConnection(ctx context.Context, conn io.ReadWriteCloser) (err error) { +func (fwd *PortForwarder) handleConnection(ctx context.Context, id channelID, conn io.ReadWriteCloser) (err error) { defer safeClose(conn, &err) - channel, err := l.session.openStreamingChannel(ctx, l.session.streamName, l.session.streamCondition) + channel, err := fwd.session.openStreamingChannel(ctx, id) if err != nil { return fmt.Errorf("error opening streaming channel for new connection: %v", err) } diff --git a/port_forwarder_test.go b/port_forwarder_test.go index 44ef59fe0..d47730995 100644 --- a/port_forwarder_test.go +++ b/port_forwarder_test.go @@ -20,7 +20,7 @@ func TestNewPortForwarder(t *testing.T) { t.Errorf("create mock client: %v", err) } defer testServer.Close() - pf := NewPortForwarder(session, 80) + pf := NewPortForwarder(session, "ssh", 81, 80) if pf == nil { t.Error("port forwarder is nil") } @@ -48,14 +48,11 @@ func TestPortForwarderStart(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - pf := NewPortForwarder(session, 8000) - done := make(chan error) + done := make(chan error) go func() { - if err := session.StartSharing(ctx, "http", 8000); err != nil { - done <- fmt.Errorf("start sharing: %v", err) - } - done <- pf.Forward(ctx) + const name, local, remote = "ssh", 8000, 8000 + done <- NewPortForwarder(session, name, local, remote).Forward(ctx) }() go func() { diff --git a/session.go b/session.go index d57906f26..0e3120cd7 100644 --- a/session.go +++ b/session.go @@ -9,11 +9,6 @@ import ( type Session struct { ssh *sshSession rpc *rpcClient - - // TODO(adonovan): fix: avoid data race of state accessed by - // multiple calls to StartSharing and concurrent calls to - // PortForwarder. Perhaps combine the two operations in the API? - streamName, streamCondition string } // Port describes a port exposed by the container. @@ -31,20 +26,17 @@ type Port struct { // TODO(adonovan): fix possible typo in field name, and audit others. } -// StartSharing tells the Live Share host to start sharing the specified port from the container. -// The sessionName describes the purpose of the port or service. -func (s *Session) StartSharing(ctx context.Context, sessionName string, port int) error { +// startSharing tells the Live Share host to start sharing the specified port from the container. +// The sessionName describes the purpose of the remote port or service. +// It returns an identifier that can be used to open an SSH channel to the remote port. +func (s *Session) startSharing(ctx context.Context, sessionName string, port int) (channelID, error) { + args := []interface{}{port, sessionName, fmt.Sprintf("http://localhost:%d", port)} var response Port - if err := s.rpc.do(ctx, "serverSharing.startSharing", []interface{}{ - port, sessionName, fmt.Sprintf("http://localhost:%d", port), - }, &response); err != nil { - return err + if err := s.rpc.do(ctx, "serverSharing.startSharing", args, &response); err != nil { + return channelID{}, err } - s.streamName = response.StreamName - s.streamCondition = response.StreamCondition - - return nil + return channelID{response.StreamName, response.StreamCondition}, nil } // GetSharedServers returns a description of each container port diff --git a/session_test.go b/session_test.go index 005eacfbd..54aab16c8 100644 --- a/session_test.go +++ b/session_test.go @@ -82,10 +82,11 @@ func TestServerStartSharing(t *testing.T) { done := make(chan error) go func() { - if err := session.StartSharing(ctx, serverProtocol, serverPort); err != nil { + streamID, err := session.startSharing(ctx, serverProtocol, serverPort) + if err != nil { done <- fmt.Errorf("error sharing server: %v", err) } - if session.streamName == "" || session.streamCondition == "" { + if streamID.name == "" || streamID.condition == "" { done <- errors.New("stream name or condition is blank") } done <- nil diff --git a/terminal.go b/terminal.go index 96938ed89..24a0f5121 100644 --- a/terminal.go +++ b/terminal.go @@ -75,7 +75,7 @@ func (t TerminalCommand) Run(ctx context.Context) (io.ReadCloser, error) { } <-started - channel, err := t.terminal.session.openStreamingChannel(ctx, result.StreamName, result.StreamCondition) + channel, err := t.terminal.session.openStreamingChannel(ctx, channelID{result.StreamName, result.StreamCondition}) if err != nil { return nil, fmt.Errorf("error opening streaming channel: %v", err) } From 94b91661cc68b200e30e809d8c26b41a7f37c1af Mon Sep 17 00:00:00 2001 From: Alan Donovan Date: Thu, 2 Sep 2021 14:30:19 -0400 Subject: [PATCH 37/68] don't forget to close conn in case of sharing error --- port_forwarder.go | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/port_forwarder.go b/port_forwarder.go index 4391ef55c..2d1217c24 100644 --- a/port_forwarder.go +++ b/port_forwarder.go @@ -18,6 +18,12 @@ type PortForwarder struct { // between the local port and the container's remote port over the // specified Live Share session. The name describes the purpose of the // remote port or service. +// +// TODO(adonovan): the localPort param is redundant wrt ForwardWithConn. +// Simpler: do away with the NewPortForwarder type altogether: +// +// - ForwardToLocalPort(ctx, session, name, remote, local) +// - ForwardToConnection(ctx, session, name, remote, conn) func NewPortForwarder(session *Session, name string, localPort, remotePort int) *PortForwarder { return &PortForwarder{ session: session, @@ -74,6 +80,7 @@ func (fwd *PortForwarder) Forward(ctx context.Context) (err error) { func (fwd *PortForwarder) ForwardWithConn(ctx context.Context, conn io.ReadWriteCloser) error { id, err := fwd.shareRemotePort(ctx) if err != nil { + conn.Close() return err } From 94319d4cfeaa6b6a0389e75c0401e265e2078e09 Mon Sep 17 00:00:00 2001 From: Alan Donovan Date: Thu, 2 Sep 2021 15:34:57 -0400 Subject: [PATCH 38/68] move localPort parameter to ForwardToLocalPort --- port_forwarder.go | 39 +++++++++++++++++---------------------- port_forwarder_test.go | 4 ++-- 2 files changed, 19 insertions(+), 24 deletions(-) diff --git a/port_forwarder.go b/port_forwarder.go index 2d1217c24..4a46cd4e6 100644 --- a/port_forwarder.go +++ b/port_forwarder.go @@ -7,42 +7,36 @@ import ( "net" ) -// A PortForwarder forwards TCP traffic between a local TCP port and a LiveShare session. +// A PortForwarder forwards TCP traffic over a LiveShare session from a port on a remote +// container to a local destination such as a network port or Go reader/writer. type PortForwarder struct { - session *Session - name string - localPort, remotePort int + session *Session + name string + remotePort int } -// NewPortForwarder creates a new PortForwarder that forwards traffic -// between the local port and the container's remote port over the -// specified Live Share session. The name describes the purpose of the -// remote port or service. -// -// TODO(adonovan): the localPort param is redundant wrt ForwardWithConn. -// Simpler: do away with the NewPortForwarder type altogether: -// -// - ForwardToLocalPort(ctx, session, name, remote, local) -// - ForwardToConnection(ctx, session, name, remote, conn) -func NewPortForwarder(session *Session, name string, localPort, remotePort int) *PortForwarder { +// NewPortForwarder returns a new PortForwarder for the specified +// remote port and Live Share session. The name describes the purpose +// of the remote port or service. +func NewPortForwarder(session *Session, name string, remotePort int) *PortForwarder { return &PortForwarder{ session: session, name: name, - localPort: localPort, remotePort: remotePort, } } -// Forward enables port forwarding. It accepts and handles TCP -// connections until it encounters the first error, which may include +// ForwardToLocalPort forwards traffic between the container's remote +// port and a local TCP port. It accepts and handles TCP connections +// on the local until it encounters the first error, which may include // context cancellation. Its result is non-nil. -func (fwd *PortForwarder) Forward(ctx context.Context) (err error) { +func (fwd *PortForwarder) ForwardToLocalPort(ctx context.Context, localPort int) (err error) { id, err := fwd.shareRemotePort(ctx) if err != nil { return err } - listen, err := net.Listen("tcp", fmt.Sprintf(":%d", fwd.localPort)) + listen, err := net.Listen("tcp", fmt.Sprintf(":%d", localPort)) if err != nil { return fmt.Errorf("error listening on TCP port: %v", err) } @@ -76,8 +70,9 @@ func (fwd *PortForwarder) Forward(ctx context.Context) (err error) { return awaitError(ctx, errc) } -// ForwardWithConn handles port forwarding for a single connection. -func (fwd *PortForwarder) ForwardWithConn(ctx context.Context, conn io.ReadWriteCloser) error { +// Forward forwards traffic between the container's remote port and +// the specified read/write stream. On return, the stream is closed. +func (fwd *PortForwarder) Forward(ctx context.Context, conn io.ReadWriteCloser) error { id, err := fwd.shareRemotePort(ctx) if err != nil { conn.Close() diff --git a/port_forwarder_test.go b/port_forwarder_test.go index d47730995..6ccb3d05e 100644 --- a/port_forwarder_test.go +++ b/port_forwarder_test.go @@ -20,7 +20,7 @@ func TestNewPortForwarder(t *testing.T) { t.Errorf("create mock client: %v", err) } defer testServer.Close() - pf := NewPortForwarder(session, "ssh", 81, 80) + pf := NewPortForwarder(session, "ssh", 80) if pf == nil { t.Error("port forwarder is nil") } @@ -52,7 +52,7 @@ func TestPortForwarderStart(t *testing.T) { done := make(chan error) go func() { const name, local, remote = "ssh", 8000, 8000 - done <- NewPortForwarder(session, name, local, remote).Forward(ctx) + done <- NewPortForwarder(session, name, remote).ForwardToLocalPort(ctx, local) }() go func() { From 4438b85e294e510edf97510ede486db175e8f084 Mon Sep 17 00:00:00 2001 From: Alan Donovan Date: Thu, 2 Sep 2021 15:41:36 -0400 Subject: [PATCH 39/68] comment tweaks --- port_forwarder.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/port_forwarder.go b/port_forwarder.go index 4a46cd4e6..f4895bb60 100644 --- a/port_forwarder.go +++ b/port_forwarder.go @@ -27,9 +27,9 @@ func NewPortForwarder(session *Session, name string, remotePort int) *PortForwar } // ForwardToLocalPort forwards traffic between the container's remote -// port and a local TCP port. It accepts and handles TCP connections -// on the local until it encounters the first error, which may include -// context cancellation. Its result is non-nil. +// port and a local TCP port. It accepts and handles connections on +// the local port until it encounters the first error, which may +// include context cancellation. Its error result is always non-nil. func (fwd *PortForwarder) ForwardToLocalPort(ctx context.Context, localPort int) (err error) { id, err := fwd.shareRemotePort(ctx) if err != nil { From 5bd0519ef32827e59d94003b995a62a8915f48d4 Mon Sep 17 00:00:00 2001 From: Alan Donovan Date: Thu, 2 Sep 2021 16:45:23 -0400 Subject: [PATCH 40/68] move Listen call into clients to avoid race --- port_forwarder.go | 26 ++++++++++++++++---------- port_forwarder_test.go | 10 ++++++++-- 2 files changed, 24 insertions(+), 12 deletions(-) diff --git a/port_forwarder.go b/port_forwarder.go index f4895bb60..fe0d7d80e 100644 --- a/port_forwarder.go +++ b/port_forwarder.go @@ -26,22 +26,28 @@ func NewPortForwarder(session *Session, name string, remotePort int) *PortForwar } } +// ListenTCP calls listen on the chosen local TCP port. Zero picks an arbitrary port. +// It is provided for the convenience of callers of ForwardToLocalPort. +func Listen(port int) (net.Listener, error) { + return net.Listen("tcp", fmt.Sprintf(":%d", port)) +} + // ForwardToLocalPort forwards traffic between the container's remote -// port and a local TCP port. It accepts and handles connections on -// the local port until it encounters the first error, which may -// include context cancellation. Its error result is always non-nil. -func (fwd *PortForwarder) ForwardToLocalPort(ctx context.Context, localPort int) (err error) { +// port and a local TCP port, which must already be listening for +// connections. (Accepting a listener rather than a port number avoids +// races against other processes opening ports, and against a client +// connecting to the socket prematurely.) +// +// ForwardToLocalPort accepts and handles connections on the local +// port until it encounters the first error, which may include context +// cancellation. Its error result is always non-nil. The caller is +// responsible for closing the listening port. +func (fwd *PortForwarder) ForwardToLocalPort(ctx context.Context, listen net.Listener) (err error) { id, err := fwd.shareRemotePort(ctx) if err != nil { return err } - listen, err := net.Listen("tcp", fmt.Sprintf(":%d", localPort)) - if err != nil { - return fmt.Errorf("error listening on TCP port: %v", err) - } - defer safeClose(listen, &err) - errc := make(chan error, 1) sendError := func(err error) { // Use non-blocking send, to avoid goroutines getting diff --git a/port_forwarder_test.go b/port_forwarder_test.go index 6ccb3d05e..68b658b6b 100644 --- a/port_forwarder_test.go +++ b/port_forwarder_test.go @@ -46,13 +46,19 @@ func TestPortForwarderStart(t *testing.T) { } defer testServer.Close() + listen, err := Listen(8000) // local port + if err != nil { + t.Fatal(err) + } + defer listen.Close() + ctx, cancel := context.WithCancel(context.Background()) defer cancel() done := make(chan error) go func() { - const name, local, remote = "ssh", 8000, 8000 - done <- NewPortForwarder(session, name, remote).ForwardToLocalPort(ctx, local) + const name, remote = "ssh", 8000 + done <- NewPortForwarder(session, name, remote).ForwardToLocalPort(ctx, listen) }() go func() { From e2552fbd2a049a8314d83e1b1357b6b5494267dd Mon Sep 17 00:00:00 2001 From: Alan Donovan Date: Fri, 3 Sep 2021 09:43:31 -0400 Subject: [PATCH 41/68] rename to ForwardToListener --- port_forwarder.go | 17 +++++++++-------- port_forwarder_test.go | 4 ++-- 2 files changed, 11 insertions(+), 10 deletions(-) diff --git a/port_forwarder.go b/port_forwarder.go index fe0d7d80e..593b70fe7 100644 --- a/port_forwarder.go +++ b/port_forwarder.go @@ -26,23 +26,24 @@ func NewPortForwarder(session *Session, name string, remotePort int) *PortForwar } } -// ListenTCP calls listen on the chosen local TCP port. Zero picks an arbitrary port. -// It is provided for the convenience of callers of ForwardToLocalPort. -func Listen(port int) (net.Listener, error) { +// ListenTCP calls listen on the chosen local TCP port. Zero picks an +// arbitrary port. It is provided for the convenience of callers of +// ForwardToListener. +func ListenTCP(port int) (net.Listener, error) { return net.Listen("tcp", fmt.Sprintf(":%d", port)) } -// ForwardToLocalPort forwards traffic between the container's remote -// port and a local TCP port, which must already be listening for +// ForwardToListener forwards traffic between the container's remote +// port and a local port, which must already be listening for // connections. (Accepting a listener rather than a port number avoids // races against other processes opening ports, and against a client // connecting to the socket prematurely.) // -// ForwardToLocalPort accepts and handles connections on the local -// port until it encounters the first error, which may include context +// ForwardToListener accepts and handles connections on the local port +// until it encounters the first error, which may include context // cancellation. Its error result is always non-nil. The caller is // responsible for closing the listening port. -func (fwd *PortForwarder) ForwardToLocalPort(ctx context.Context, listen net.Listener) (err error) { +func (fwd *PortForwarder) ForwardToListener(ctx context.Context, listen net.Listener) (err error) { id, err := fwd.shareRemotePort(ctx) if err != nil { return err diff --git a/port_forwarder_test.go b/port_forwarder_test.go index 68b658b6b..d6a4e7708 100644 --- a/port_forwarder_test.go +++ b/port_forwarder_test.go @@ -46,7 +46,7 @@ func TestPortForwarderStart(t *testing.T) { } defer testServer.Close() - listen, err := Listen(8000) // local port + listen, err := ListenTCP(8000) // local port if err != nil { t.Fatal(err) } @@ -58,7 +58,7 @@ func TestPortForwarderStart(t *testing.T) { done := make(chan error) go func() { const name, remote = "ssh", 8000 - done <- NewPortForwarder(session, name, remote).ForwardToLocalPort(ctx, listen) + done <- NewPortForwarder(session, name, remote).ForwardToListener(ctx, listen) }() go func() { From 50523c4f1087ea361b1216d894a26c7ca6fb7d46 Mon Sep 17 00:00:00 2001 From: Alan Donovan Date: Fri, 3 Sep 2021 14:39:47 -0400 Subject: [PATCH 42/68] remove ListenTCP and add workaround for ssh.channel.Close EOF --- port_forwarder.go | 17 +++++++++-------- port_forwarder_test.go | 2 +- 2 files changed, 10 insertions(+), 9 deletions(-) diff --git a/port_forwarder.go b/port_forwarder.go index 593b70fe7..dc91222ed 100644 --- a/port_forwarder.go +++ b/port_forwarder.go @@ -26,13 +26,6 @@ func NewPortForwarder(session *Session, name string, remotePort int) *PortForwar } } -// ListenTCP calls listen on the chosen local TCP port. Zero picks an -// arbitrary port. It is provided for the convenience of callers of -// ForwardToListener. -func ListenTCP(port int) (net.Listener, error) { - return net.Listen("tcp", fmt.Sprintf(":%d", port)) -} - // ForwardToListener forwards traffic between the container's remote // port and a local port, which must already be listening for // connections. (Accepting a listener rather than a port number avoids @@ -121,7 +114,15 @@ func (fwd *PortForwarder) handleConnection(ctx context.Context, id channelID, co if err != nil { return fmt.Errorf("error opening streaming channel for new connection: %v", err) } - defer safeClose(channel, &err) + // Ideally we would call safeClose again, but (*ssh.channel).Close + // appears to have a bug that causes it return io.EOF spuriously + // if its peer closed first; see github.com/golang/go/issues/38115. + defer func() { + closeErr := channel.Close() + if err == nil && closeErr != io.EOF { + err = closeErr + } + }() errs := make(chan error, 2) copyConn := func(w io.Writer, r io.Reader) { diff --git a/port_forwarder_test.go b/port_forwarder_test.go index d6a4e7708..c4245f513 100644 --- a/port_forwarder_test.go +++ b/port_forwarder_test.go @@ -46,7 +46,7 @@ func TestPortForwarderStart(t *testing.T) { } defer testServer.Close() - listen, err := ListenTCP(8000) // local port + listen, err := net.Listen("tcp", ":8000") if err != nil { t.Fatal(err) } From 72659a360334186804e5cfcf781d504104ca50a8 Mon Sep 17 00:00:00 2001 From: Alan Donovan Date: Wed, 8 Sep 2021 17:21:54 -0400 Subject: [PATCH 43/68] add lightstep instrumentation --- client.go | 7 +++++++ port_forwarder.go | 5 +++++ rpc.go | 4 ++++ 3 files changed, 16 insertions(+) diff --git a/client.go b/client.go index 0088662f7..566db6cd3 100644 --- a/client.go +++ b/client.go @@ -5,6 +5,7 @@ import ( "crypto/tls" "fmt" + "github.com/opentracing/opentracing-go" "golang.org/x/crypto/ssh" ) @@ -52,6 +53,9 @@ func WithTLSConfig(tlsConfig *tls.Config) ClientOption { // JoinWorkspace connects the client to the server's Live Share // workspace and returns a session representing their connection. func (c *Client) JoinWorkspace(ctx context.Context) (*Session, error) { + span, ctx := opentracing.StartSpanFromContext(ctx, "Client.JoinWorkspace") + defer span.Finish() + clientSocket := newSocket(c.connection, c.tlsConfig) if err := clientSocket.connect(ctx); err != nil { return nil, fmt.Errorf("error connecting websocket: %v", err) @@ -120,6 +124,9 @@ func (s *Session) openStreamingChannel(ctx context.Context, id channelID) (ssh.C return nil, fmt.Errorf("error getting stream id: %v", err) } + span, ctx := opentracing.StartSpanFromContext(ctx, "Session.OpenChannel+SendRequest") + defer span.Finish() + channel, reqs, err := s.ssh.conn.OpenChannel("session", nil) if err != nil { return nil, fmt.Errorf("error opening ssh channel for transport: %v", err) diff --git a/port_forwarder.go b/port_forwarder.go index dc91222ed..7d3363ba2 100644 --- a/port_forwarder.go +++ b/port_forwarder.go @@ -5,6 +5,8 @@ import ( "fmt" "io" "net" + + "github.com/opentracing/opentracing-go" ) // A PortForwarder forwards TCP traffic over a LiveShare session from a port on a remote @@ -108,6 +110,9 @@ func awaitError(ctx context.Context, errc <-chan error) error { // handleConnection handles forwarding for a single accepted connection, then closes it. func (fwd *PortForwarder) handleConnection(ctx context.Context, id channelID, conn io.ReadWriteCloser) (err error) { + span, ctx := opentracing.StartSpanFromContext(ctx, "PortForwarder.handleConnection") + defer span.Finish() + defer safeClose(conn, &err) channel, err := fwd.session.openStreamingChannel(ctx, id) diff --git a/rpc.go b/rpc.go index 237606fe0..10aa2c7eb 100644 --- a/rpc.go +++ b/rpc.go @@ -6,6 +6,7 @@ import ( "io" "sync" + "github.com/opentracing/opentracing-go" "github.com/sourcegraph/jsonrpc2" ) @@ -26,6 +27,9 @@ func (r *rpcClient) connect(ctx context.Context) { } func (r *rpcClient) do(ctx context.Context, method string, args, result interface{}) error { + span, ctx := opentracing.StartSpanFromContext(ctx, method) + defer span.Finish() + waiter, err := r.Conn.DispatchCall(ctx, method, args) if err != nil { return fmt.Errorf("error dispatching %q call: %v", method, err) From 8b0e8c990e68dc9b74ed3d40f4c73c9a24bacb7b Mon Sep 17 00:00:00 2001 From: Jose Garcia Date: Thu, 9 Sep 2021 17:31:18 +0000 Subject: [PATCH 44/68] ignore pf conn errors --- port_forwarder.go | 25 ++++++++----------------- 1 file changed, 8 insertions(+), 17 deletions(-) diff --git a/port_forwarder.go b/port_forwarder.go index dc91222ed..1351025cb 100644 --- a/port_forwarder.go +++ b/port_forwarder.go @@ -33,9 +33,7 @@ func NewPortForwarder(session *Session, name string, remotePort int) *PortForwar // connecting to the socket prematurely.) // // ForwardToListener accepts and handles connections on the local port -// until it encounters the first error, which may include context -// cancellation. Its error result is always non-nil. The caller is -// responsible for closing the listening port. +// until the context is cancelled. The caller is responsible for closing the listening port. func (fwd *PortForwarder) ForwardToListener(ctx context.Context, listen net.Listener) (err error) { id, err := fwd.shareRemotePort(ctx) if err != nil { @@ -124,21 +122,14 @@ func (fwd *PortForwarder) handleConnection(ctx context.Context, id channelID, co } }() - errs := make(chan error, 2) - copyConn := func(w io.Writer, r io.Reader) { - _, err := io.Copy(w, r) - errs <- err - } - go copyConn(conn, channel) - go copyConn(channel, conn) + // Bi-directional copy of data. + // If any individual connection has an error, we can safely ignore them + // and defer to connection clients to handle data loss as necessary. + go io.Copy(conn, channel) + go io.Copy(channel, conn) - // await result - for i := 0; i < 2; i++ { - if err := <-errs; err != nil && err != io.EOF { - return fmt.Errorf("tunnel connection: %v", err) - } - } - return nil + <-ctx.Done() + return ctx.Err() } // safeClose reports the error (to *err) from closing the stream only From 1ff5c514fb82458558757fc8f1fcdf6cc838afc0 Mon Sep 17 00:00:00 2001 From: Jose Garcia Date: Thu, 9 Sep 2021 18:35:05 +0000 Subject: [PATCH 45/68] fix erroneous ctx waiting and introduce back io.EOF handling --- port_forwarder.go | 27 +++++++++++++++++++++------ 1 file changed, 21 insertions(+), 6 deletions(-) diff --git a/port_forwarder.go b/port_forwarder.go index 1351025cb..f47d11565 100644 --- a/port_forwarder.go +++ b/port_forwarder.go @@ -123,13 +123,28 @@ func (fwd *PortForwarder) handleConnection(ctx context.Context, id channelID, co }() // Bi-directional copy of data. - // If any individual connection has an error, we can safely ignore them - // and defer to connection clients to handle data loss as necessary. - go io.Copy(conn, channel) - go io.Copy(channel, conn) + errs := make(chan error, 2) + copyConn := func(w io.Writer, r io.Reader) { + _, err := io.Copy(w, r) + errs <- err + } + go copyConn(conn, channel) + go copyConn(channel, conn) - <-ctx.Done() - return ctx.Err() + // wait until context is cancelled or we've received two io.EOF +Loop: + for i := 0; i < 2; i++ { + select { + case <-ctx.Done(): + return ctx.Err() + case err := <-errs: + if err != nil && err != io.EOF { + break Loop // non-EOF errors stop connection handling + } + } + } + + return nil } // safeClose reports the error (to *err) from closing the stream only From 920f793c6ddf001901308df947091c9d98563fdd Mon Sep 17 00:00:00 2001 From: Jose Garcia Date: Thu, 9 Sep 2021 19:33:16 +0000 Subject: [PATCH 46/68] pr feedback --- port_forwarder.go | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/port_forwarder.go b/port_forwarder.go index f47d11565..e8649c693 100644 --- a/port_forwarder.go +++ b/port_forwarder.go @@ -122,7 +122,7 @@ func (fwd *PortForwarder) handleConnection(ctx context.Context, id channelID, co } }() - // Bi-directional copy of data. + // bi-directional copy of data. errs := make(chan error, 2) copyConn := func(w io.Writer, r io.Reader) { _, err := io.Copy(w, r) @@ -131,20 +131,18 @@ func (fwd *PortForwarder) handleConnection(ctx context.Context, id channelID, co go copyConn(conn, channel) go copyConn(channel, conn) - // wait until context is cancelled or we've received two io.EOF -Loop: - for i := 0; i < 2; i++ { + // wait until context is cancelled or both copies are done + for i := 0; ; { select { case <-ctx.Done(): return ctx.Err() - case err := <-errs: - if err != nil && err != io.EOF { - break Loop // non-EOF errors stop connection handling + case <-errs: + i++ + if i == 2 { + return nil } } } - - return nil } // safeClose reports the error (to *err) from closing the stream only From efe519cb7af8015a0747ac20d91877627b21b8cc Mon Sep 17 00:00:00 2001 From: Jose Garcia Date: Thu, 9 Sep 2021 20:11:45 +0000 Subject: [PATCH 47/68] comments + fix Forward method --- port_forwarder.go | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/port_forwarder.go b/port_forwarder.go index e8649c693..be191c211 100644 --- a/port_forwarder.go +++ b/port_forwarder.go @@ -80,9 +80,7 @@ func (fwd *PortForwarder) Forward(ctx context.Context, conn io.ReadWriteCloser) // Create buffered channel so that send doesn't get stuck after context cancellation. errc := make(chan error, 1) go func() { - if err := fwd.handleConnection(ctx, id, conn); err != nil { - errc <- err - } + errc <- fwd.handleConnection(ctx, id, conn) }() return awaitError(ctx, errc) } @@ -131,7 +129,9 @@ func (fwd *PortForwarder) handleConnection(ctx context.Context, id channelID, co go copyConn(conn, channel) go copyConn(channel, conn) - // wait until context is cancelled or both copies are done + // Wait until context is cancelled or both copies are done. + // Discard errors from io.Copy; they should not cause (e.g.) ForwardToListener to fail. + // TODO: how can we proxy errors from Copy so that each peer can distinguish an error from a short file? for i := 0; ; { select { case <-ctx.Done(): From 272ea57b541c33846add2b9b493c0ca011985c93 Mon Sep 17 00:00:00 2001 From: Jose Garcia Date: Thu, 9 Sep 2021 21:00:09 +0000 Subject: [PATCH 48/68] revert comment update --- port_forwarder.go | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/port_forwarder.go b/port_forwarder.go index be191c211..8011d19fc 100644 --- a/port_forwarder.go +++ b/port_forwarder.go @@ -33,7 +33,9 @@ func NewPortForwarder(session *Session, name string, remotePort int) *PortForwar // connecting to the socket prematurely.) // // ForwardToListener accepts and handles connections on the local port -// until the context is cancelled. The caller is responsible for closing the listening port. +// until it encounters the first error, which may include context +// cancellation. Its error result is always non-nil. The caller is +// responsible for closing the listening port. func (fwd *PortForwarder) ForwardToListener(ctx context.Context, listen net.Listener) (err error) { id, err := fwd.shareRemotePort(ctx) if err != nil { From 5b23d87d47f4ffa6dfce6c10794e9e32ec5c6371 Mon Sep 17 00:00:00 2001 From: Alan Donovan Date: Fri, 10 Sep 2021 15:09:45 -0400 Subject: [PATCH 49/68] Remove Terminal, no longer needed by ghcs --- client.go | 4 ++ rpc.go | 41 +++---------------- rpc_test.go | 31 -------------- terminal.go | 116 ---------------------------------------------------- 4 files changed, 9 insertions(+), 183 deletions(-) delete mode 100644 rpc_test.go delete mode 100644 terminal.go diff --git a/client.go b/client.go index 566db6cd3..65e80a94a 100644 --- a/client.go +++ b/client.go @@ -115,6 +115,10 @@ func (c *Client) joinWorkspace(ctx context.Context, rpc *rpcClient) (*joinWorksp } func (s *Session) openStreamingChannel(ctx context.Context, id channelID) (ssh.Channel, error) { + type getStreamArgs struct { + StreamName string `json:"streamName"` + Condition string `json:"condition"` + } args := getStreamArgs{ StreamName: id.name, Condition: id.condition, diff --git a/rpc.go b/rpc.go index 10aa2c7eb..68e187ad6 100644 --- a/rpc.go +++ b/rpc.go @@ -4,7 +4,6 @@ import ( "context" "fmt" "io" - "sync" "github.com/opentracing/opentracing-go" "github.com/sourcegraph/jsonrpc2" @@ -12,18 +11,17 @@ import ( type rpcClient struct { *jsonrpc2.Conn - conn io.ReadWriteCloser - handler *rpcHandler + conn io.ReadWriteCloser } func newRPCClient(conn io.ReadWriteCloser) *rpcClient { - return &rpcClient{conn: conn, handler: newRPCHandler()} + return &rpcClient{conn: conn} } func (r *rpcClient) connect(ctx context.Context) { stream := jsonrpc2.NewBufferedStream(r.conn, jsonrpc2.VSCodeObjectCodec{}) // TODO(adonovan): fix: ensure r.Conn is eventually Closed! - r.Conn = jsonrpc2.NewConn(ctx, stream, r.handler) + r.Conn = jsonrpc2.NewConn(ctx, stream, nullHandler{}) } func (r *rpcClient) do(ctx context.Context, method string, args, result interface{}) error { @@ -38,36 +36,7 @@ func (r *rpcClient) do(ctx context.Context, method string, args, result interfac return waiter.Wait(ctx, result) } -type rpcHandlerFunc = func(*jsonrpc2.Request) +type nullHandler struct{} -type rpcHandler struct { - handlersMu sync.Mutex - handlers map[string][]rpcHandlerFunc -} - -func newRPCHandler() *rpcHandler { - return &rpcHandler{ - handlers: make(map[string][]rpcHandlerFunc), - } -} - -// registerEventHandler registers a handler for the specified event. -// After the next occurrence of the event, the handler will be called, -// once, in its own goroutine. -func (r *rpcHandler) registerEventHandler(eventMethod string, h rpcHandlerFunc) { - r.handlersMu.Lock() - r.handlers[eventMethod] = append(r.handlers[eventMethod], h) - r.handlersMu.Unlock() -} - -// Handle calls all registered handlers for the request, concurrently, each in its own goroutine. -func (r *rpcHandler) Handle(ctx context.Context, conn *jsonrpc2.Conn, req *jsonrpc2.Request) { - r.handlersMu.Lock() - handlers := r.handlers[req.Method] - r.handlers[req.Method] = nil - r.handlersMu.Unlock() - - for _, h := range handlers { - go h(req) - } +func (nullHandler) Handle(ctx context.Context, conn *jsonrpc2.Conn, req *jsonrpc2.Request) { } diff --git a/rpc_test.go b/rpc_test.go deleted file mode 100644 index cf9c4cf81..000000000 --- a/rpc_test.go +++ /dev/null @@ -1,31 +0,0 @@ -package liveshare - -import ( - "context" - "testing" - "time" - - "github.com/sourcegraph/jsonrpc2" -) - -func TestRPCHandlerEvents(t *testing.T) { - rpcHandler := newRPCHandler() - eventCh := make(chan *jsonrpc2.Request) - rpcHandler.registerEventHandler("somethingHappened", func(req *jsonrpc2.Request) { - eventCh <- req - }) - go func() { - time.Sleep(1 * time.Second) - rpcHandler.Handle(context.Background(), nil, &jsonrpc2.Request{Method: "somethingHappened"}) - }() - ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(5*time.Second)) - defer cancel() - select { - case event := <-eventCh: - if event.Method != "somethingHappened" { - t.Error("event.Method is not the expect value") - } - case <-ctx.Done(): - t.Error("Test time out") - } -} diff --git a/terminal.go b/terminal.go deleted file mode 100644 index 24a0f5121..000000000 --- a/terminal.go +++ /dev/null @@ -1,116 +0,0 @@ -package liveshare - -import ( - "context" - "fmt" - "io" - - "github.com/sourcegraph/jsonrpc2" - "golang.org/x/crypto/ssh" -) - -type Terminal struct { - session *Session -} - -func NewTerminal(session *Session) *Terminal { - return &Terminal{session: session} -} - -type TerminalCommand struct { - terminal *Terminal - cwd string - cmd string -} - -func (t *Terminal) NewCommand(cwd, cmd string) TerminalCommand { - return TerminalCommand{t, cwd, cmd} -} - -type runArgs struct { - Name string `json:"name"` - Rows int `json:"rows"` - Cols int `json:"cols"` - App string `json:"app"` - Cwd string `json:"cwd"` - CommandLine []string `json:"commandLine"` - ReadOnlyForGuests bool `json:"readOnlyForGuests"` -} - -type startTerminalResult struct { - ID int `json:"id"` - StreamName string `json:"streamName"` - StreamCondition string `json:"streamCondition"` - LocalPipeName string `json:"localPipeName"` - AppProcessID int `json:"appProcessId"` -} - -type getStreamArgs struct { - StreamName string `json:"streamName"` - Condition string `json:"condition"` -} - -type stopTerminalArgs struct { - ID int `json:"id"` -} - -func (t TerminalCommand) Run(ctx context.Context) (io.ReadCloser, error) { - args := runArgs{ - Name: "RunCommand", - Rows: 10, - Cols: 80, - App: "/bin/bash", - Cwd: t.cwd, - CommandLine: []string{"-c", t.cmd}, - ReadOnlyForGuests: false, - } - - started := make(chan struct{}) - t.terminal.session.rpc.handler.registerEventHandler("terminal.terminalStarted", func(*jsonrpc2.Request) { - close(started) - }) - var result startTerminalResult - if err := t.terminal.session.rpc.do(ctx, "terminal.startTerminal", &args, &result); err != nil { - return nil, fmt.Errorf("error making terminal.startTerminal call: %v", err) - } - <-started - - channel, err := t.terminal.session.openStreamingChannel(ctx, channelID{result.StreamName, result.StreamCondition}) - if err != nil { - return nil, fmt.Errorf("error opening streaming channel: %v", err) - } - - return t.newTerminalReadCloser(result.ID, channel), nil -} - -type terminalReadCloser struct { - terminalCommand TerminalCommand - terminalID int - channel ssh.Channel -} - -func (t TerminalCommand) newTerminalReadCloser(terminalID int, channel ssh.Channel) io.ReadCloser { - return terminalReadCloser{t, terminalID, channel} -} - -func (t terminalReadCloser) Read(b []byte) (int, error) { - return t.channel.Read(b) -} - -func (t terminalReadCloser) Close() error { - stopped := make(chan struct{}) - t.terminalCommand.terminal.session.rpc.handler.registerEventHandler("terminal.terminalStopped", func(*jsonrpc2.Request) { - close(stopped) - }) - if err := t.terminalCommand.terminal.session.rpc.do(context.Background(), "terminal.stopTerminal", []int{t.terminalID}, nil); err != nil { - return fmt.Errorf("error making terminal.stopTerminal call: %v", err) - } - - if err := t.channel.Close(); err != nil && err != io.EOF { - return fmt.Errorf("error closing channel: %v", err) - } - - <-stopped - - return nil -} From 497b45e4e2e41c2fca55f61995a289a6e62022e6 Mon Sep 17 00:00:00 2001 From: Jose Garcia Date: Tue, 14 Sep 2021 23:57:40 +0000 Subject: [PATCH 50/68] ssh server docs --- ssh_server.go | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/ssh_server.go b/ssh_server.go index 03b45f25f..ca66ec7de 100644 --- a/ssh_server.go +++ b/ssh_server.go @@ -4,14 +4,20 @@ import ( "context" ) +// A SSHServer handles starting the remote SSH server. +// If there is no SSH server available it installs one. type SSHServer struct { session *Session } +// SSHServer returns a new SSHServer from the LiveShare Session. func (session *Session) SSHServer() *SSHServer { return &SSHServer{session: session} } +// SSHServerStartResult contains whether or not the start of the SSH server was +// successful. If it succeeded the server port and user is included. If it failed, +// it contains an explanation message. type SSHServerStartResult struct { Result bool `json:"result"` ServerPort string `json:"serverPort"` @@ -19,6 +25,7 @@ type SSHServerStartResult struct { Message string `json:"message"` } +// StartRemoteServer starts or install the remote SSH server and returns the result. func (s *SSHServer) StartRemoteServer(ctx context.Context) (*SSHServerStartResult, error) { var response SSHServerStartResult From 8abff2af97688a47744066aac4cd632f22259b53 Mon Sep 17 00:00:00 2001 From: Jose Garcia Date: Wed, 15 Sep 2021 13:14:58 +0000 Subject: [PATCH 51/68] move StartSSHServer to Session --- port_forwarder.go | 2 +- session.go | 28 ++++++++++++++++++++++++++++ ssh_server.go | 37 ------------------------------------- 3 files changed, 29 insertions(+), 38 deletions(-) delete mode 100644 ssh_server.go diff --git a/port_forwarder.go b/port_forwarder.go index 400d6ac97..5dafd0c65 100644 --- a/port_forwarder.go +++ b/port_forwarder.go @@ -9,7 +9,7 @@ import ( "github.com/opentracing/opentracing-go" ) -// A PortForwarder forwards TCP traffic over a LiveShare session from a port on a remote +// A PortForwarder forwards TCP traffic over a Live Share session from a port on a remote // container to a local destination such as a network port or Go reader/writer. type PortForwarder struct { session *Session diff --git a/session.go b/session.go index 0e3120cd7..1d13ad58c 100644 --- a/session.go +++ b/session.go @@ -3,6 +3,7 @@ package liveshare import ( "context" "fmt" + "strconv" ) // A Session represents the session between a connected Live Share client and server. @@ -59,3 +60,30 @@ func (s *Session) UpdateSharedVisibility(ctx context.Context, port int, public b return nil } + +// StartSSHServer starts the SSHD server and returns the user and port for which to authenticate with. +// If there is no SSHD server installed on the server, it will attempt to install it. The installation +// process can take upwards of 20+ seconds. +func (s *Session) StartSSHServer(ctx context.Context) (string, int64, error) { + var response struct { + Result bool `json:"result"` + ServerPort string `json:"serverPort"` + User string `json:"user"` + Message string `json:"message"` + } + + if err := s.rpc.do(ctx, "ISshServerHostService.startRemoteServer", []string{}, &response); err != nil { + return "", 0, err + } + + if !response.Result { + return "", 0, fmt.Errorf("failed to start server: %s", response.Message) + } + + port, err := strconv.ParseInt(response.ServerPort, 10, 64) + if err != nil { + return "", 0, fmt.Errorf("failed to parse port: %w", err) + } + + return response.User, port, nil +} diff --git a/ssh_server.go b/ssh_server.go deleted file mode 100644 index ca66ec7de..000000000 --- a/ssh_server.go +++ /dev/null @@ -1,37 +0,0 @@ -package liveshare - -import ( - "context" -) - -// A SSHServer handles starting the remote SSH server. -// If there is no SSH server available it installs one. -type SSHServer struct { - session *Session -} - -// SSHServer returns a new SSHServer from the LiveShare Session. -func (session *Session) SSHServer() *SSHServer { - return &SSHServer{session: session} -} - -// SSHServerStartResult contains whether or not the start of the SSH server was -// successful. If it succeeded the server port and user is included. If it failed, -// it contains an explanation message. -type SSHServerStartResult struct { - Result bool `json:"result"` - ServerPort string `json:"serverPort"` - User string `json:"user"` - Message string `json:"message"` -} - -// StartRemoteServer starts or install the remote SSH server and returns the result. -func (s *SSHServer) StartRemoteServer(ctx context.Context) (*SSHServerStartResult, error) { - var response SSHServerStartResult - - if err := s.session.rpc.do(ctx, "ISshServerHostService.startRemoteServer", []string{}, &response); err != nil { - return nil, err - } - - return &response, nil -} From 20e618fd025e115726853db4b4fc37ec76295ebd Mon Sep 17 00:00:00 2001 From: Jose Garcia Date: Wed, 15 Sep 2021 13:49:03 +0000 Subject: [PATCH 52/68] pr feedback --- session.go | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/session.go b/session.go index 1d13ad58c..f427fac6d 100644 --- a/session.go +++ b/session.go @@ -61,10 +61,9 @@ func (s *Session) UpdateSharedVisibility(ctx context.Context, port int, public b return nil } -// StartSSHServer starts the SSHD server and returns the user and port for which to authenticate with. -// If there is no SSHD server installed on the server, it will attempt to install it. The installation -// process can take upwards of 20+ seconds. -func (s *Session) StartSSHServer(ctx context.Context) (string, int64, error) { +// StartsSSHServer starts an SSH server in the container, installing sshd if necessary, +// and returns the port on which it listens and the user name clients should provide. +func (s *Session) StartSSHServer(ctx context.Context) (int, string, error) { var response struct { Result bool `json:"result"` ServerPort string `json:"serverPort"` @@ -73,17 +72,17 @@ func (s *Session) StartSSHServer(ctx context.Context) (string, int64, error) { } if err := s.rpc.do(ctx, "ISshServerHostService.startRemoteServer", []string{}, &response); err != nil { - return "", 0, err + return 0, "", err } if !response.Result { - return "", 0, fmt.Errorf("failed to start server: %s", response.Message) + return 0, "", fmt.Errorf("failed to start server: %s", response.Message) } - port, err := strconv.ParseInt(response.ServerPort, 10, 64) + port, err := strconv.Atoi(response.ServerPort) if err != nil { - return "", 0, fmt.Errorf("failed to parse port: %w", err) + return 0, "", fmt.Errorf("failed to parse port: %w", err) } - return response.User, port, nil + return port, response.User, nil } From 57d04dc5f020ebefbe080e1fa6873dcded731d7a Mon Sep 17 00:00:00 2001 From: Jose Garcia Date: Mon, 20 Sep 2021 13:16:38 +0000 Subject: [PATCH 53/68] Allow clients to Close a Session, general tidy up - Allow clients to call Close on a Session to clean up resources - Switch to the %w verb for error wrapping - Fix typo on Port struct after verifying the server does not have a typo --- client.go | 14 +++++++------- port_forwarder.go | 4 ++-- rpc.go | 3 +-- session.go | 18 +++++++++++++++--- ssh.go | 8 ++++---- 5 files changed, 29 insertions(+), 18 deletions(-) diff --git a/client.go b/client.go index 65e80a94a..ba9d2f5e7 100644 --- a/client.go +++ b/client.go @@ -58,18 +58,18 @@ func (c *Client) JoinWorkspace(ctx context.Context) (*Session, error) { clientSocket := newSocket(c.connection, c.tlsConfig) if err := clientSocket.connect(ctx); err != nil { - return nil, fmt.Errorf("error connecting websocket: %v", err) + return nil, fmt.Errorf("error connecting websocket: %w", err) } ssh := newSSHSession(c.connection.SessionToken, clientSocket) if err := ssh.connect(ctx); err != nil { - return nil, fmt.Errorf("error connecting to ssh session: %v", err) + return nil, fmt.Errorf("error connecting to ssh session: %w", err) } rpc := newRPCClient(ssh) rpc.connect(ctx) if _, err := c.joinWorkspace(ctx, rpc); err != nil { - return nil, fmt.Errorf("error joining Live Share workspace: %v", err) + return nil, fmt.Errorf("error joining Live Share workspace: %w", err) } return &Session{ssh: ssh, rpc: rpc}, nil @@ -108,7 +108,7 @@ func (c *Client) joinWorkspace(ctx context.Context, rpc *rpcClient) (*joinWorksp var result joinWorkspaceResult if err := rpc.do(ctx, "workspace.joinWorkspace", &args, &result); err != nil { - return nil, fmt.Errorf("error making workspace.joinWorkspace call: %v", err) + return nil, fmt.Errorf("error making workspace.joinWorkspace call: %w", err) } return &result, nil @@ -125,7 +125,7 @@ func (s *Session) openStreamingChannel(ctx context.Context, id channelID) (ssh.C } var streamID string if err := s.rpc.do(ctx, "streamManager.getStream", args, &streamID); err != nil { - return nil, fmt.Errorf("error getting stream id: %v", err) + return nil, fmt.Errorf("error getting stream id: %w", err) } span, ctx := opentracing.StartSpanFromContext(ctx, "Session.OpenChannel+SendRequest") @@ -133,13 +133,13 @@ func (s *Session) openStreamingChannel(ctx context.Context, id channelID) (ssh.C channel, reqs, err := s.ssh.conn.OpenChannel("session", nil) if err != nil { - return nil, fmt.Errorf("error opening ssh channel for transport: %v", err) + return nil, fmt.Errorf("error opening ssh channel for transport: %w", err) } go ssh.DiscardRequests(reqs) requestType := fmt.Sprintf("stream-transport-%s", streamID) if _, err = channel.SendRequest(requestType, true, nil); err != nil { - return nil, fmt.Errorf("error sending channel request: %v", err) + return nil, fmt.Errorf("error sending channel request: %w", err) } return channel, nil diff --git a/port_forwarder.go b/port_forwarder.go index 5dafd0c65..56401cc4d 100644 --- a/port_forwarder.go +++ b/port_forwarder.go @@ -92,7 +92,7 @@ func (fwd *PortForwarder) Forward(ctx context.Context, conn io.ReadWriteCloser) func (fwd *PortForwarder) shareRemotePort(ctx context.Context) (channelID, error) { id, err := fwd.session.startSharing(ctx, fwd.name, fwd.remotePort) if err != nil { - err = fmt.Errorf("failed to share remote port %d: %v", fwd.remotePort, err) + err = fmt.Errorf("failed to share remote port %d: %w", fwd.remotePort, err) } return id, nil } @@ -115,7 +115,7 @@ func (fwd *PortForwarder) handleConnection(ctx context.Context, id channelID, co channel, err := fwd.session.openStreamingChannel(ctx, id) if err != nil { - return fmt.Errorf("error opening streaming channel for new connection: %v", err) + return fmt.Errorf("error opening streaming channel for new connection: %w", err) } // Ideally we would call safeClose again, but (*ssh.channel).Close // appears to have a bug that causes it return io.EOF spuriously diff --git a/rpc.go b/rpc.go index 68e187ad6..bfd214c89 100644 --- a/rpc.go +++ b/rpc.go @@ -20,7 +20,6 @@ func newRPCClient(conn io.ReadWriteCloser) *rpcClient { func (r *rpcClient) connect(ctx context.Context) { stream := jsonrpc2.NewBufferedStream(r.conn, jsonrpc2.VSCodeObjectCodec{}) - // TODO(adonovan): fix: ensure r.Conn is eventually Closed! r.Conn = jsonrpc2.NewConn(ctx, stream, nullHandler{}) } @@ -30,7 +29,7 @@ func (r *rpcClient) do(ctx context.Context, method string, args, result interfac waiter, err := r.Conn.DispatchCall(ctx, method, args) if err != nil { - return fmt.Errorf("error dispatching %q call: %v", method, err) + return fmt.Errorf("error dispatching %q call: %w", method, err) } return waiter.Wait(ctx, result) diff --git a/session.go b/session.go index f427fac6d..6a078da7e 100644 --- a/session.go +++ b/session.go @@ -12,6 +12,20 @@ type Session struct { rpc *rpcClient } +// Close should be called by users to clean up RPC and SSH resources whenever the session +// is no longer active. +func (s *Session) Close() error { + if err := s.rpc.Close(); err != nil { + return fmt.Errorf("failed to close RPC conn: %w", err) + } + + if err := s.ssh.Close(); err != nil { + return fmt.Errorf("failed to close SSH conn: %w", err) + } + + return nil +} + // Port describes a port exposed by the container. type Port struct { SourcePort int `json:"sourcePort"` @@ -22,9 +36,7 @@ type Port struct { BrowseURL string `json:"browseUrl"` IsPublic bool `json:"isPublic"` IsTCPServerConnectionEstablished bool `json:"isTCPServerConnectionEstablished"` - HasTSLHandshakePassed bool `json:"hasTSLHandshakePassed"` - // ^^^ - // TODO(adonovan): fix possible typo in field name, and audit others. + HasTLSHandshakePassed bool `json:"hasTLSHandshakePassed"` } // startSharing tells the Live Share host to start sharing the specified port from the container. diff --git a/ssh.go b/ssh.go index b68d400a1..15f67d2a4 100644 --- a/ssh.go +++ b/ssh.go @@ -36,24 +36,24 @@ func (s *sshSession) connect(ctx context.Context) error { sshClientConn, chans, reqs, err := ssh.NewClientConn(s.socket, "", &clientConfig) if err != nil { - return fmt.Errorf("error creating ssh client connection: %v", err) + return fmt.Errorf("error creating ssh client connection: %w", err) } s.conn = sshClientConn sshClient := ssh.NewClient(sshClientConn, chans, reqs) s.Session, err = sshClient.NewSession() if err != nil { - return fmt.Errorf("error creating ssh client session: %v", err) + return fmt.Errorf("error creating ssh client session: %w", err) } s.reader, err = s.Session.StdoutPipe() if err != nil { - return fmt.Errorf("error creating ssh session reader: %v", err) + return fmt.Errorf("error creating ssh session reader: %w", err) } s.writer, err = s.Session.StdinPipe() if err != nil { - return fmt.Errorf("error creating ssh session writer: %v", err) + return fmt.Errorf("error creating ssh session writer: %w", err) } return nil From 40886479ae42cff937e37febadeea4708451d4cb Mon Sep 17 00:00:00 2001 From: Jose Garcia Date: Mon, 20 Sep 2021 20:35:12 +0000 Subject: [PATCH 54/68] Close SSH even if RPC Close fails --- session.go | 1 + 1 file changed, 1 insertion(+) diff --git a/session.go b/session.go index 6a078da7e..5ea961d82 100644 --- a/session.go +++ b/session.go @@ -16,6 +16,7 @@ type Session struct { // is no longer active. func (s *Session) Close() error { if err := s.rpc.Close(); err != nil { + s.ssh.Close() // close SSH and ignore error return fmt.Errorf("failed to close RPC conn: %w", err) } From 23f6d449e0f6bcf9846dbc57b652f129db7e133d Mon Sep 17 00:00:00 2001 From: Jose Garcia Date: Mon, 20 Sep 2021 21:16:54 +0000 Subject: [PATCH 55/68] Close RPC conn only - Only close SSH if RPC fails. Closing RPC automatically closes the underlying stream which in this case is the SSH connection. - I thought about closing the SSH conn instead of RPC, but there is a bit more cleanup that the RPC library needs to do. --- session.go | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/session.go b/session.go index 5ea961d82..5202205a8 100644 --- a/session.go +++ b/session.go @@ -15,16 +15,14 @@ type Session struct { // Close should be called by users to clean up RPC and SSH resources whenever the session // is no longer active. func (s *Session) Close() error { - if err := s.rpc.Close(); err != nil { + // Closing the RPC conn closes the underlying stream (SSH) + // So we only need to close once + err := s.rpc.Close() + if err != nil { s.ssh.Close() // close SSH and ignore error - return fmt.Errorf("failed to close RPC conn: %w", err) } - if err := s.ssh.Close(); err != nil { - return fmt.Errorf("failed to close SSH conn: %w", err) - } - - return nil + return err } // Port describes a port exposed by the container. From 5f6b3a5eeed2c8d0ea3c9073df06dad33686af88 Mon Sep 17 00:00:00 2001 From: Jose Garcia Date: Tue, 21 Sep 2021 13:46:30 +0000 Subject: [PATCH 56/68] Add error context to Session.Close --- session.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/session.go b/session.go index 5202205a8..929e8605b 100644 --- a/session.go +++ b/session.go @@ -17,12 +17,12 @@ type Session struct { func (s *Session) Close() error { // Closing the RPC conn closes the underlying stream (SSH) // So we only need to close once - err := s.rpc.Close() - if err != nil { + if err := s.rpc.Close(); err != nil { s.ssh.Close() // close SSH and ignore error + return fmt.Errorf("error while closing Live Share session: %w", err) } - return err + return nil } // Port describes a port exposed by the container. From b3b675d108d02f32b24ad69b33f1dacdd5e85c1d Mon Sep 17 00:00:00 2001 From: Alan Donovan Date: Tue, 21 Sep 2021 12:44:30 -0400 Subject: [PATCH 57/68] Merge NewClient and JoinWorkspace into Connect --- client.go | 105 +++++++++++++++++++++++++----------------------- client_test.go | 46 ++------------------- connection.go | 2 +- session_test.go | 9 +---- 4 files changed, 61 insertions(+), 101 deletions(-) diff --git a/client.go b/client.go index ba9d2f5e7..3f9345ce4 100644 --- a/client.go +++ b/client.go @@ -1,3 +1,13 @@ +// Package liveshare is a Go client library for the Visual Studio Live Share +// service, which provides collaborative, distibuted editing and debugging. +// See https://docs.microsoft.com/en-us/visualstudio/liveshare for an overview. +// +// It provides the ability for a Go program to connect to a Live Share +// workspace (Connect), to expose a TCP port on a remote host +// (UpdateSharedVisibility), to start an SSH server listening on an +// exposed port (StartSSHServer), and to forward connections between +// the remote port and a local listening TCP port (ForwardToListener) +// or a local Go reader/writer (Forward). package liveshare import ( @@ -9,66 +19,79 @@ import ( "golang.org/x/crypto/ssh" ) -// A Client capable of joining a Live Share workspace. -type Client struct { +// A client capable of joining a Live Share workspace. +type client struct { connection Connection tlsConfig *tls.Config } -// A ClientOption is a function that modifies a client -type ClientOption func(*Client) error +// An Option updates the initial configuration state of a Live Share connection. +type Option func(*client) error -// NewClient accepts a range of options, applies them and returns a client -func NewClient(opts ...ClientOption) (*Client, error) { - client := new(Client) - - for _, o := range opts { - if err := o(client); err != nil { - return nil, err - } - } - - return client, nil -} - -// WithConnection is a ClientOption that accepts a Connection -func WithConnection(connection Connection) ClientOption { - return func(c *Client) error { +// WithConnection is a Option that accepts a Connection. +// +// TODO(adonovan): WithConnection is not optional, so it should not be +// not an Option. We should make Connection a mandatory parameter of +// Connect, at which point, why not just merge +// client+Option+Connection, rename it to Options, do away with the +// function mechanism, and express TLS config (etc) as public fields +// of Options with sensible zero values, like websocket.Dialer, etc? +func WithConnection(connection Connection) Option { + return func(cli *client) error { if err := connection.validate(); err != nil { return err } - c.connection = connection + cli.connection = connection return nil } } -func WithTLSConfig(tlsConfig *tls.Config) ClientOption { - return func(c *Client) error { - c.tlsConfig = tlsConfig +// WithTLSConfig returns a Connect option that sets the TLS configuration. +func WithTLSConfig(tlsConfig *tls.Config) Option { + return func(cli *client) error { + cli.tlsConfig = tlsConfig return nil } } -// JoinWorkspace connects the client to the server's Live Share -// workspace and returns a session representing their connection. -func (c *Client) JoinWorkspace(ctx context.Context) (*Session, error) { - span, ctx := opentracing.StartSpanFromContext(ctx, "Client.JoinWorkspace") +// Connect connects to a Live Share workspace specified by the +// options, and returns a session representing the connection. +// The caller must call the session's Close method to end the session. +func Connect(ctx context.Context, opts ...Option) (*Session, error) { + cli := new(client) + for _, opt := range opts { + if err := opt(cli); err != nil { + return nil, fmt.Errorf("error applying Live Share connect option: %w", err) + } + } + + span, ctx := opentracing.StartSpanFromContext(ctx, "Connect") defer span.Finish() - clientSocket := newSocket(c.connection, c.tlsConfig) - if err := clientSocket.connect(ctx); err != nil { + sock := newSocket(cli.connection, cli.tlsConfig) + if err := sock.connect(ctx); err != nil { return nil, fmt.Errorf("error connecting websocket: %w", err) } - ssh := newSSHSession(c.connection.SessionToken, clientSocket) + ssh := newSSHSession(cli.connection.SessionToken, sock) if err := ssh.connect(ctx); err != nil { return nil, fmt.Errorf("error connecting to ssh session: %w", err) } rpc := newRPCClient(ssh) rpc.connect(ctx) - if _, err := c.joinWorkspace(ctx, rpc); err != nil { + + args := joinWorkspaceArgs{ + ID: cli.connection.SessionID, + ConnectionMode: "local", + JoiningUserSessionToken: cli.connection.SessionToken, + ClientCapabilities: clientCapabilities{ + IsNonInteractive: false, + }, + } + var result joinWorkspaceResult + if err := rpc.do(ctx, "workspace.joinWorkspace", &args, &result); err != nil { return nil, fmt.Errorf("error joining Live Share workspace: %w", err) } @@ -96,24 +119,6 @@ type channelID struct { name, condition string } -func (c *Client) joinWorkspace(ctx context.Context, rpc *rpcClient) (*joinWorkspaceResult, error) { - args := joinWorkspaceArgs{ - ID: c.connection.SessionID, - ConnectionMode: "local", - JoiningUserSessionToken: c.connection.SessionToken, - ClientCapabilities: clientCapabilities{ - IsNonInteractive: false, - }, - } - - var result joinWorkspaceResult - if err := rpc.do(ctx, "workspace.joinWorkspace", &args, &result); err != nil { - return nil, fmt.Errorf("error making workspace.joinWorkspace call: %w", err) - } - - return &result, nil -} - func (s *Session) openStreamingChannel(ctx context.Context, id channelID) (ssh.Channel, error) { type getStreamArgs struct { StreamName string `json:"streamName"` diff --git a/client_test.go b/client_test.go index c1e61f6e8..369c53b28 100644 --- a/client_test.go +++ b/client_test.go @@ -13,37 +13,7 @@ import ( "github.com/sourcegraph/jsonrpc2" ) -func TestNewClient(t *testing.T) { - client, err := NewClient() - if err != nil { - t.Errorf("error creating new client: %v", err) - } - if client == nil { - t.Error("client is nil") - } -} - -func TestNewClientValidConnection(t *testing.T) { - connection := Connection{"1", "2", "3", "4"} - - client, err := NewClient(WithConnection(connection)) - if err != nil { - t.Errorf("error creating new client: %v", err) - } - if client == nil { - t.Error("client is nil") - } -} - -func TestNewClientWithInvalidConnection(t *testing.T) { - connection := Connection{} - - if _, err := NewClient(WithConnection(connection)); err == nil { - t.Error("err is nil") - } -} - -func TestJoinSession(t *testing.T) { +func TestConnect(t *testing.T) { connection := Connection{ SessionID: "session-id", SessionToken: "session-token", @@ -83,21 +53,11 @@ func TestJoinSession(t *testing.T) { ctx := context.Background() tlsConfig := WithTLSConfig(&tls.Config{InsecureSkipVerify: true}) - client, err := NewClient(WithConnection(connection), tlsConfig) - if err != nil { - t.Errorf("error creating new client: %v", err) - } done := make(chan error) go func() { - session, err := client.JoinWorkspace(ctx) - if err != nil { - done <- fmt.Errorf("error joining workspace: %v", err) - return - } - _ = session - - done <- nil + _, err := Connect(ctx, WithConnection(connection), tlsConfig) // ignore session + done <- err }() select { diff --git a/connection.go b/connection.go index c1a4632c8..f402e4bb9 100644 --- a/connection.go +++ b/connection.go @@ -6,7 +6,7 @@ import ( "strings" ) -// A Connection represents a set of values necessary to join a liveshare connection +// A Connection represents a set of values necessary to join a liveshare connection. type Connection struct { SessionID string SessionToken string diff --git a/session_test.go b/session_test.go index 54aab16c8..3be90cb0e 100644 --- a/session_test.go +++ b/session_test.go @@ -32,14 +32,9 @@ func makeMockSession(opts ...livesharetest.ServerOption) (*livesharetest.Server, ) connection.RelayEndpoint = "sb" + strings.TrimPrefix(testServer.URL(), "https") tlsConfig := WithTLSConfig(&tls.Config{InsecureSkipVerify: true}) - client, err := NewClient(WithConnection(connection), tlsConfig) + session, err := Connect(context.Background(), WithConnection(connection), tlsConfig) if err != nil { - return nil, nil, fmt.Errorf("error creating new client: %v", err) - } - ctx := context.Background() - session, err := client.JoinWorkspace(ctx) - if err != nil { - return nil, nil, fmt.Errorf("error joining workspace: %v", err) + return nil, nil, fmt.Errorf("error connecting to Live Share: %v", err) } return testServer, session, nil } From f8a8713520f031758a2b75dc70c5faaea2927ea5 Mon Sep 17 00:00:00 2001 From: Alan Donovan Date: Tue, 21 Sep 2021 15:23:02 -0400 Subject: [PATCH 58/68] refactor Options API --- client.go | 77 ++++++++++++++++++++++------------------------ client_test.go | 16 +++++----- connection.go | 44 -------------------------- connection_test.go | 41 ------------------------ options_test.go | 56 +++++++++++++++++++++++++++++++++ session_test.go | 22 ++++++------- socket.go | 4 +-- 7 files changed, 113 insertions(+), 147 deletions(-) delete mode 100644 connection.go delete mode 100644 connection_test.go create mode 100644 options_test.go diff --git a/client.go b/client.go index 3f9345ce4..b51e25ea6 100644 --- a/client.go +++ b/client.go @@ -13,68 +13,65 @@ package liveshare import ( "context" "crypto/tls" + "errors" "fmt" + "net/url" + "strings" "github.com/opentracing/opentracing-go" "golang.org/x/crypto/ssh" ) -// A client capable of joining a Live Share workspace. -type client struct { - connection Connection - tlsConfig *tls.Config +// An Options specifies Live Share connection parameters. +type Options struct { + SessionID string + SessionToken string // token for SSH session + RelaySAS string + RelayEndpoint string + TLSConfig *tls.Config // (optional) } -// An Option updates the initial configuration state of a Live Share connection. -type Option func(*client) error - -// WithConnection is a Option that accepts a Connection. -// -// TODO(adonovan): WithConnection is not optional, so it should not be -// not an Option. We should make Connection a mandatory parameter of -// Connect, at which point, why not just merge -// client+Option+Connection, rename it to Options, do away with the -// function mechanism, and express TLS config (etc) as public fields -// of Options with sensible zero values, like websocket.Dialer, etc? -func WithConnection(connection Connection) Option { - return func(cli *client) error { - if err := connection.validate(); err != nil { - return err - } - - cli.connection = connection - return nil +// uri returns a websocket URL for the specified options. +func (opts *Options) uri(action string) (string, error) { + if opts.SessionID == "" { + return "", errors.New("SessionID is required") } -} - -// WithTLSConfig returns a Connect option that sets the TLS configuration. -func WithTLSConfig(tlsConfig *tls.Config) Option { - return func(cli *client) error { - cli.tlsConfig = tlsConfig - return nil + if opts.RelaySAS == "" { + return "", errors.New("RelaySAS is required") } + if opts.RelayEndpoint == "" { + return "", errors.New("RelayEndpoint is required") + } + + sas := url.QueryEscape(opts.RelaySAS) + uri := opts.RelayEndpoint + uri = strings.Replace(uri, "sb:", "wss:", -1) + uri = strings.Replace(uri, ".net/", ".net:443/$hc/", 1) + uri = uri + "?sb-hc-action=" + action + "&sb-hc-token=" + sas + return uri, nil } // Connect connects to a Live Share workspace specified by the // options, and returns a session representing the connection. // The caller must call the session's Close method to end the session. -func Connect(ctx context.Context, opts ...Option) (*Session, error) { - cli := new(client) - for _, opt := range opts { - if err := opt(cli); err != nil { - return nil, fmt.Errorf("error applying Live Share connect option: %w", err) - } +func Connect(ctx context.Context, opts Options) (*Session, error) { + uri, err := opts.uri("connect") + if err != nil { + return nil, err } span, ctx := opentracing.StartSpanFromContext(ctx, "Connect") defer span.Finish() - sock := newSocket(cli.connection, cli.tlsConfig) + sock := newSocket(uri, opts.TLSConfig) if err := sock.connect(ctx); err != nil { return nil, fmt.Errorf("error connecting websocket: %w", err) } - ssh := newSSHSession(cli.connection.SessionToken, sock) + if opts.SessionToken == "" { + return nil, errors.New("SessionToken is required") + } + ssh := newSSHSession(opts.SessionToken, sock) if err := ssh.connect(ctx); err != nil { return nil, fmt.Errorf("error connecting to ssh session: %w", err) } @@ -83,9 +80,9 @@ func Connect(ctx context.Context, opts ...Option) (*Session, error) { rpc.connect(ctx) args := joinWorkspaceArgs{ - ID: cli.connection.SessionID, + ID: opts.SessionID, ConnectionMode: "local", - JoiningUserSessionToken: cli.connection.SessionToken, + JoiningUserSessionToken: opts.SessionToken, ClientCapabilities: clientCapabilities{ IsNonInteractive: false, }, diff --git a/client_test.go b/client_test.go index 369c53b28..2b95f738f 100644 --- a/client_test.go +++ b/client_test.go @@ -14,7 +14,7 @@ import ( ) func TestConnect(t *testing.T) { - connection := Connection{ + opts := Options{ SessionID: "session-id", SessionToken: "session-token", RelaySAS: "relay-sas", @@ -24,13 +24,13 @@ func TestConnect(t *testing.T) { if err := json.Unmarshal(*req.Params, &joinWorkspaceReq); err != nil { return nil, fmt.Errorf("error unmarshaling req: %v", err) } - if joinWorkspaceReq.ID != connection.SessionID { + if joinWorkspaceReq.ID != opts.SessionID { return nil, errors.New("connection session id does not match") } if joinWorkspaceReq.ConnectionMode != "local" { return nil, errors.New("connection mode is not local") } - if joinWorkspaceReq.JoiningUserSessionToken != connection.SessionToken { + if joinWorkspaceReq.JoiningUserSessionToken != opts.SessionToken { return nil, errors.New("connection user token does not match") } if joinWorkspaceReq.ClientCapabilities.IsNonInteractive != false { @@ -40,23 +40,23 @@ func TestConnect(t *testing.T) { } server, err := livesharetest.NewServer( - livesharetest.WithPassword(connection.SessionToken), + livesharetest.WithPassword(opts.SessionToken), livesharetest.WithService("workspace.joinWorkspace", joinWorkspace), - livesharetest.WithRelaySAS(connection.RelaySAS), + livesharetest.WithRelaySAS(opts.RelaySAS), ) if err != nil { t.Errorf("error creating Live Share server: %v", err) } defer server.Close() - connection.RelayEndpoint = "sb" + strings.TrimPrefix(server.URL(), "https") + opts.RelayEndpoint = "sb" + strings.TrimPrefix(server.URL(), "https") ctx := context.Background() - tlsConfig := WithTLSConfig(&tls.Config{InsecureSkipVerify: true}) + opts.TLSConfig = &tls.Config{InsecureSkipVerify: true} done := make(chan error) go func() { - _, err := Connect(ctx, WithConnection(connection), tlsConfig) // ignore session + _, err := Connect(ctx, opts) // ignore session done <- err }() diff --git a/connection.go b/connection.go deleted file mode 100644 index f402e4bb9..000000000 --- a/connection.go +++ /dev/null @@ -1,44 +0,0 @@ -package liveshare - -import ( - "errors" - "net/url" - "strings" -) - -// A Connection represents a set of values necessary to join a liveshare connection. -type Connection struct { - SessionID string - SessionToken string - RelaySAS string - RelayEndpoint string -} - -func (r Connection) validate() error { - if r.SessionID == "" { - return errors.New("connection SessionID is required") - } - - if r.SessionToken == "" { - return errors.New("connection SessionToken is required") - } - - if r.RelaySAS == "" { - return errors.New("connection RelaySAS is required") - } - - if r.RelayEndpoint == "" { - return errors.New("connection RelayEndpoint is required") - } - - return nil -} - -func (r Connection) uri(action string) string { - sas := url.QueryEscape(r.RelaySAS) - uri := r.RelayEndpoint - uri = strings.Replace(uri, "sb:", "wss:", -1) - uri = strings.Replace(uri, ".net/", ".net:443/$hc/", 1) - uri = uri + "?sb-hc-action=" + action + "&sb-hc-token=" + sas - return uri -} diff --git a/connection_test.go b/connection_test.go deleted file mode 100644 index f42ec4189..000000000 --- a/connection_test.go +++ /dev/null @@ -1,41 +0,0 @@ -package liveshare - -import "testing" - -func TestConnectionValid(t *testing.T) { - conn := Connection{"sess-id", "sess-token", "sas", "endpoint"} - if err := conn.validate(); err != nil { - t.Error(err) - } -} - -func TestConnectionInvalid(t *testing.T) { - conn := Connection{"", "sess-token", "sas", "endpoint"} - if err := conn.validate(); err == nil { - t.Error(err) - } - conn = Connection{"sess-id", "", "sas", "endpoint"} - if err := conn.validate(); err == nil { - t.Error(err) - } - conn = Connection{"sess-id", "sess-token", "", "endpoint"} - if err := conn.validate(); err == nil { - t.Error(err) - } - conn = Connection{"sess-id", "sess-token", "sas", ""} - if err := conn.validate(); err == nil { - t.Error(err) - } - conn = Connection{"", "", "", ""} - if err := conn.validate(); err == nil { - t.Error(err) - } -} - -func TestConnectionURI(t *testing.T) { - conn := Connection{"sess-id", "sess-token", "sas", "sb://endpoint/.net/liveshare"} - uri := conn.uri("connect") - if uri != "wss://endpoint/.net:443/$hc/liveshare?sb-hc-action=connect&sb-hc-token=sas" { - t.Errorf("uri is not correct, got: '%v'", uri) - } -} diff --git a/options_test.go b/options_test.go new file mode 100644 index 000000000..830c59104 --- /dev/null +++ b/options_test.go @@ -0,0 +1,56 @@ +package liveshare + +import ( + "context" + "testing" +) + +func TestBadOptions(t *testing.T) { + goodOptions := Options{ + SessionID: "sess-id", + SessionToken: "sess-token", + RelaySAS: "sas", + RelayEndpoint: "endpoint", + } + + opts := goodOptions + opts.SessionID = "" + checkBadOptions(t, opts) + + opts = goodOptions + opts.SessionToken = "" + checkBadOptions(t, opts) + + opts = goodOptions + opts.RelaySAS = "" + checkBadOptions(t, opts) + + opts = goodOptions + opts.RelayEndpoint = "" + checkBadOptions(t, opts) + + opts = Options{} + checkBadOptions(t, opts) +} + +func checkBadOptions(t *testing.T, opts Options) { + if _, err := Connect(context.Background(), opts); err == nil { + t.Errorf("Connect(%+v): no error", opts) + } +} + +func TestOptionsURI(t *testing.T) { + opts := Options{ + SessionID: "sess-id", + SessionToken: "sess-token", + RelaySAS: "sas", + RelayEndpoint: "sb://endpoint/.net/liveshare", + } + uri, err := opts.uri("connect") + if err != nil { + t.Fatal(err) + } + if uri != "wss://endpoint/.net:443/$hc/liveshare?sb-hc-action=connect&sb-hc-token=sas" { + t.Errorf("uri is not correct, got: '%v'", uri) + } +} diff --git a/session_test.go b/session_test.go index 3be90cb0e..cd0a7b474 100644 --- a/session_test.go +++ b/session_test.go @@ -14,25 +14,23 @@ import ( ) func makeMockSession(opts ...livesharetest.ServerOption) (*livesharetest.Server, *Session, error) { - connection := Connection{ - SessionID: "session-id", - SessionToken: "session-token", - RelaySAS: "relay-sas", - } joinWorkspace := func(req *jsonrpc2.Request) (interface{}, error) { return joinWorkspaceResult{1}, nil } + const sessionToken = "session-token" opts = append( opts, - livesharetest.WithPassword(connection.SessionToken), + livesharetest.WithPassword(sessionToken), livesharetest.WithService("workspace.joinWorkspace", joinWorkspace), ) - testServer, err := livesharetest.NewServer( - opts..., - ) - connection.RelayEndpoint = "sb" + strings.TrimPrefix(testServer.URL(), "https") - tlsConfig := WithTLSConfig(&tls.Config{InsecureSkipVerify: true}) - session, err := Connect(context.Background(), WithConnection(connection), tlsConfig) + testServer, err := livesharetest.NewServer(opts...) + session, err := Connect(context.Background(), Options{ + SessionID: "session-id", + SessionToken: sessionToken, + RelayEndpoint: "sb" + strings.TrimPrefix(testServer.URL(), "https"), + RelaySAS: "relay-sas", + TLSConfig: &tls.Config{InsecureSkipVerify: true}, + }) if err != nil { return nil, nil, fmt.Errorf("error connecting to Live Share: %v", err) } diff --git a/socket.go b/socket.go index 8744eeb96..f66436f65 100644 --- a/socket.go +++ b/socket.go @@ -19,8 +19,8 @@ type socket struct { reader io.Reader } -func newSocket(clientConn Connection, tlsConfig *tls.Config) *socket { - return &socket{addr: clientConn.uri("connect"), tlsConfig: tlsConfig} +func newSocket(uri string, tlsConfig *tls.Config) *socket { + return &socket{addr: uri, tlsConfig: tlsConfig} } func (s *socket) connect(ctx context.Context) error { From 6ca35d0e730d1adaecc1c7c79c9c4892e2138449 Mon Sep 17 00:00:00 2001 From: Jose Garcia Date: Thu, 23 Sep 2021 11:18:49 -0400 Subject: [PATCH 59/68] Moved files to liveshare dir --- client.go => liveshare/client.go | 0 client_test.go => liveshare/client_test.go | 0 options_test.go => liveshare/options_test.go | 0 port_forwarder.go => liveshare/port_forwarder.go | 0 port_forwarder_test.go => liveshare/port_forwarder_test.go | 0 rpc.go => liveshare/rpc.go | 0 session.go => liveshare/session.go | 0 session_test.go => liveshare/session_test.go | 0 socket.go => liveshare/socket.go | 0 ssh.go => liveshare/ssh.go | 0 {test => liveshare/test}/server.go | 0 {test => liveshare/test}/socket.go | 0 12 files changed, 0 insertions(+), 0 deletions(-) rename client.go => liveshare/client.go (100%) rename client_test.go => liveshare/client_test.go (100%) rename options_test.go => liveshare/options_test.go (100%) rename port_forwarder.go => liveshare/port_forwarder.go (100%) rename port_forwarder_test.go => liveshare/port_forwarder_test.go (100%) rename rpc.go => liveshare/rpc.go (100%) rename session.go => liveshare/session.go (100%) rename session_test.go => liveshare/session_test.go (100%) rename socket.go => liveshare/socket.go (100%) rename ssh.go => liveshare/ssh.go (100%) rename {test => liveshare/test}/server.go (100%) rename {test => liveshare/test}/socket.go (100%) diff --git a/client.go b/liveshare/client.go similarity index 100% rename from client.go rename to liveshare/client.go diff --git a/client_test.go b/liveshare/client_test.go similarity index 100% rename from client_test.go rename to liveshare/client_test.go diff --git a/options_test.go b/liveshare/options_test.go similarity index 100% rename from options_test.go rename to liveshare/options_test.go diff --git a/port_forwarder.go b/liveshare/port_forwarder.go similarity index 100% rename from port_forwarder.go rename to liveshare/port_forwarder.go diff --git a/port_forwarder_test.go b/liveshare/port_forwarder_test.go similarity index 100% rename from port_forwarder_test.go rename to liveshare/port_forwarder_test.go diff --git a/rpc.go b/liveshare/rpc.go similarity index 100% rename from rpc.go rename to liveshare/rpc.go diff --git a/session.go b/liveshare/session.go similarity index 100% rename from session.go rename to liveshare/session.go diff --git a/session_test.go b/liveshare/session_test.go similarity index 100% rename from session_test.go rename to liveshare/session_test.go diff --git a/socket.go b/liveshare/socket.go similarity index 100% rename from socket.go rename to liveshare/socket.go diff --git a/ssh.go b/liveshare/ssh.go similarity index 100% rename from ssh.go rename to liveshare/ssh.go diff --git a/test/server.go b/liveshare/test/server.go similarity index 100% rename from test/server.go rename to liveshare/test/server.go diff --git a/test/socket.go b/liveshare/test/socket.go similarity index 100% rename from test/socket.go rename to liveshare/test/socket.go From f4396e8f1a0b79630e81b233b802d49cd0172dad Mon Sep 17 00:00:00 2001 From: Jose Garcia Date: Thu, 23 Sep 2021 11:28:04 -0400 Subject: [PATCH 60/68] Inline go-liveshare with history --- cmd/ghcs/logs.go | 2 +- cmd/ghcs/ports.go | 2 +- cmd/ghcs/ssh.go | 2 +- internal/codespaces/codespaces.go | 2 +- internal/codespaces/states.go | 2 +- {liveshare => internal/liveshare}/client.go | 0 {liveshare => internal/liveshare}/client_test.go | 2 +- {liveshare => internal/liveshare}/options_test.go | 0 {liveshare => internal/liveshare}/port_forwarder.go | 0 {liveshare => internal/liveshare}/port_forwarder_test.go | 2 +- {liveshare => internal/liveshare}/rpc.go | 0 {liveshare => internal/liveshare}/session.go | 0 {liveshare => internal/liveshare}/session_test.go | 2 +- {liveshare => internal/liveshare}/socket.go | 0 {liveshare => internal/liveshare}/ssh.go | 0 {liveshare => internal/liveshare}/test/server.go | 0 {liveshare => internal/liveshare}/test/socket.go | 0 17 files changed, 8 insertions(+), 8 deletions(-) rename {liveshare => internal/liveshare}/client.go (100%) rename {liveshare => internal/liveshare}/client_test.go (96%) rename {liveshare => internal/liveshare}/options_test.go (100%) rename {liveshare => internal/liveshare}/port_forwarder.go (100%) rename {liveshare => internal/liveshare}/port_forwarder_test.go (97%) rename {liveshare => internal/liveshare}/rpc.go (100%) rename {liveshare => internal/liveshare}/session.go (100%) rename {liveshare => internal/liveshare}/session_test.go (98%) rename {liveshare => internal/liveshare}/socket.go (100%) rename {liveshare => internal/liveshare}/ssh.go (100%) rename {liveshare => internal/liveshare}/test/server.go (100%) rename {liveshare => internal/liveshare}/test/socket.go (100%) diff --git a/cmd/ghcs/logs.go b/cmd/ghcs/logs.go index 514c36966..4bc83001b 100644 --- a/cmd/ghcs/logs.go +++ b/cmd/ghcs/logs.go @@ -9,7 +9,7 @@ import ( "github.com/github/ghcs/cmd/ghcs/output" "github.com/github/ghcs/internal/api" "github.com/github/ghcs/internal/codespaces" - "github.com/github/go-liveshare" + "github.com/github/ghcs/internal/liveshare" "github.com/spf13/cobra" ) diff --git a/cmd/ghcs/ports.go b/cmd/ghcs/ports.go index 8a4f855fa..24ec7a6e8 100644 --- a/cmd/ghcs/ports.go +++ b/cmd/ghcs/ports.go @@ -14,7 +14,7 @@ import ( "github.com/github/ghcs/cmd/ghcs/output" "github.com/github/ghcs/internal/api" "github.com/github/ghcs/internal/codespaces" - "github.com/github/go-liveshare" + "github.com/github/ghcs/internal/liveshare" "github.com/muhammadmuzzammil1998/jsonc" "github.com/spf13/cobra" "golang.org/x/sync/errgroup" diff --git a/cmd/ghcs/ssh.go b/cmd/ghcs/ssh.go index 3a49e6ebc..bb771107a 100644 --- a/cmd/ghcs/ssh.go +++ b/cmd/ghcs/ssh.go @@ -9,7 +9,7 @@ import ( "github.com/github/ghcs/cmd/ghcs/output" "github.com/github/ghcs/internal/api" "github.com/github/ghcs/internal/codespaces" - "github.com/github/go-liveshare" + "github.com/github/ghcs/internal/liveshare" "github.com/spf13/cobra" ) diff --git a/internal/codespaces/codespaces.go b/internal/codespaces/codespaces.go index 43809bab9..1cd605abc 100644 --- a/internal/codespaces/codespaces.go +++ b/internal/codespaces/codespaces.go @@ -7,7 +7,7 @@ import ( "time" "github.com/github/ghcs/internal/api" - "github.com/github/go-liveshare" + "github.com/github/ghcs/internal/liveshare" ) type logger interface { diff --git a/internal/codespaces/states.go b/internal/codespaces/states.go index 31105d576..c7d61b41e 100644 --- a/internal/codespaces/states.go +++ b/internal/codespaces/states.go @@ -10,7 +10,7 @@ import ( "time" "github.com/github/ghcs/internal/api" - "github.com/github/go-liveshare" + "github.com/github/ghcs/internal/liveshare" ) // PostCreateStateStatus is a string value representing the different statuses a state can have. diff --git a/liveshare/client.go b/internal/liveshare/client.go similarity index 100% rename from liveshare/client.go rename to internal/liveshare/client.go diff --git a/liveshare/client_test.go b/internal/liveshare/client_test.go similarity index 96% rename from liveshare/client_test.go rename to internal/liveshare/client_test.go index 2b95f738f..55139d762 100644 --- a/liveshare/client_test.go +++ b/internal/liveshare/client_test.go @@ -9,7 +9,7 @@ import ( "strings" "testing" - livesharetest "github.com/github/go-liveshare/test" + livesharetest "github.com/github/ghcs/internal/liveshare/test" "github.com/sourcegraph/jsonrpc2" ) diff --git a/liveshare/options_test.go b/internal/liveshare/options_test.go similarity index 100% rename from liveshare/options_test.go rename to internal/liveshare/options_test.go diff --git a/liveshare/port_forwarder.go b/internal/liveshare/port_forwarder.go similarity index 100% rename from liveshare/port_forwarder.go rename to internal/liveshare/port_forwarder.go diff --git a/liveshare/port_forwarder_test.go b/internal/liveshare/port_forwarder_test.go similarity index 97% rename from liveshare/port_forwarder_test.go rename to internal/liveshare/port_forwarder_test.go index c4245f513..25b4b2c80 100644 --- a/liveshare/port_forwarder_test.go +++ b/internal/liveshare/port_forwarder_test.go @@ -10,7 +10,7 @@ import ( "testing" "time" - livesharetest "github.com/github/go-liveshare/test" + livesharetest "github.com/github/ghcs/internal/liveshare/test" "github.com/sourcegraph/jsonrpc2" ) diff --git a/liveshare/rpc.go b/internal/liveshare/rpc.go similarity index 100% rename from liveshare/rpc.go rename to internal/liveshare/rpc.go diff --git a/liveshare/session.go b/internal/liveshare/session.go similarity index 100% rename from liveshare/session.go rename to internal/liveshare/session.go diff --git a/liveshare/session_test.go b/internal/liveshare/session_test.go similarity index 98% rename from liveshare/session_test.go rename to internal/liveshare/session_test.go index cd0a7b474..1273c6f2b 100644 --- a/liveshare/session_test.go +++ b/internal/liveshare/session_test.go @@ -9,7 +9,7 @@ import ( "strings" "testing" - livesharetest "github.com/github/go-liveshare/test" + livesharetest "github.com/github/ghcs/internal/liveshare/test" "github.com/sourcegraph/jsonrpc2" ) diff --git a/liveshare/socket.go b/internal/liveshare/socket.go similarity index 100% rename from liveshare/socket.go rename to internal/liveshare/socket.go diff --git a/liveshare/ssh.go b/internal/liveshare/ssh.go similarity index 100% rename from liveshare/ssh.go rename to internal/liveshare/ssh.go diff --git a/liveshare/test/server.go b/internal/liveshare/test/server.go similarity index 100% rename from liveshare/test/server.go rename to internal/liveshare/test/server.go diff --git a/liveshare/test/socket.go b/internal/liveshare/test/socket.go similarity index 100% rename from liveshare/test/socket.go rename to internal/liveshare/test/socket.go From d0c65e549067426f80caf6dc5a99f99ffa4006cd Mon Sep 17 00:00:00 2001 From: Jose Garcia Date: Thu, 23 Sep 2021 11:36:27 -0400 Subject: [PATCH 61/68] Linter fixes --- internal/liveshare/session_test.go | 11 +++++++++- internal/liveshare/test/server.go | 32 +++++++++++++++--------------- 2 files changed, 26 insertions(+), 17 deletions(-) diff --git a/internal/liveshare/session_test.go b/internal/liveshare/session_test.go index 1273c6f2b..0ffdfe136 100644 --- a/internal/liveshare/session_test.go +++ b/internal/liveshare/session_test.go @@ -24,6 +24,10 @@ func makeMockSession(opts ...livesharetest.ServerOption) (*livesharetest.Server, livesharetest.WithService("workspace.joinWorkspace", joinWorkspace), ) testServer, err := livesharetest.NewServer(opts...) + if err != nil { + return nil, nil, fmt.Errorf("error creating server: %w", err) + } + session, err := Connect(context.Background(), Options{ SessionID: "session-id", SessionToken: sessionToken, @@ -67,7 +71,12 @@ func TestServerStartSharing(t *testing.T) { testServer, session, err := makeMockSession( livesharetest.WithService("serverSharing.startSharing", startSharing), ) - defer testServer.Close() + defer func() { + if err := testServer.Close(); err != nil { + t.Errorf("failed to close test server: %w", err) + } + }() + if err != nil { t.Errorf("error creating mock session: %v", err) } diff --git a/internal/liveshare/test/server.go b/internal/liveshare/test/server.go index 159a2a982..8f80d1bce 100644 --- a/internal/liveshare/test/server.go +++ b/internal/liveshare/test/server.go @@ -44,7 +44,7 @@ Bq5TLNIbUzPVNVwRcGjUYpOhKU6EIw8phTJOvxnUC+g6MVqBP8U= type Server struct { password string - services map[string]RpcHandleFunc + services map[string]RPCHandleFunc relaySAS string streams map[string]io.ReadWriter @@ -67,7 +67,7 @@ func NewServer(opts ...ServerOption) (*Server, error) { } privateKey, err := ssh.ParsePrivateKey([]byte(sshPrivateKey)) if err != nil { - return nil, fmt.Errorf("error parsing key: %v", err) + return nil, fmt.Errorf("error parsing key: %w", err) } server.sshConfig.AddHostKey(privateKey) @@ -85,10 +85,10 @@ func WithPassword(password string) ServerOption { } } -func WithService(serviceName string, handler RpcHandleFunc) ServerOption { +func WithService(serviceName string, handler RPCHandleFunc) ServerOption { return func(s *Server) error { if s.services == nil { - s.services = make(map[string]RpcHandleFunc) + s.services = make(map[string]RPCHandleFunc) } s.services[serviceName] = handler @@ -148,7 +148,7 @@ func makeConnection(server *Server) http.HandlerFunc { } c, err := upgrader.Upgrade(w, req, nil) if err != nil { - server.errCh <- fmt.Errorf("error upgrading connection: %v", err) + server.errCh <- fmt.Errorf("error upgrading connection: %w", err) return } defer c.Close() @@ -156,7 +156,7 @@ func makeConnection(server *Server) http.HandlerFunc { socketConn := newSocketConn(c) _, chans, reqs, err := ssh.NewServerConn(socketConn, server.sshConfig) if err != nil { - server.errCh <- fmt.Errorf("error creating new ssh conn: %v", err) + server.errCh <- fmt.Errorf("error creating new ssh conn: %w", err) return } go ssh.DiscardRequests(reqs) @@ -164,7 +164,7 @@ func makeConnection(server *Server) http.HandlerFunc { for newChannel := range chans { ch, reqs, err := newChannel.Accept() if err != nil { - server.errCh <- fmt.Errorf("error accepting new channel: %v", err) + server.errCh <- fmt.Errorf("error accepting new channel: %w", err) return } go handleNewRequests(server, ch, reqs) @@ -177,7 +177,7 @@ func handleNewRequests(server *Server, channel ssh.Channel, reqs <-chan *ssh.Req for req := range reqs { if req.WantReply { if err := req.Reply(true, nil); err != nil { - server.errCh <- fmt.Errorf("error replying to channel request: %v", err) + server.errCh <- fmt.Errorf("error replying to channel request: %w", err) } } if strings.HasPrefix(req.Type, "stream-transport") { @@ -190,14 +190,14 @@ func forwardStream(server *Server, streamName string, channel ssh.Channel) { simpleStreamName := strings.TrimPrefix(streamName, "stream-transport-") stream, found := server.streams[simpleStreamName] if !found { - server.errCh <- fmt.Errorf("stream '%v' not found", simpleStreamName) + server.errCh <- fmt.Errorf("stream '%w' not found", simpleStreamName) return } copy := func(dst io.Writer, src io.Reader) { if _, err := io.Copy(dst, src); err != nil { fmt.Println(err) - server.errCh <- fmt.Errorf("io copy: %v", err) + server.errCh <- fmt.Errorf("io copy: %w", err) return } } @@ -211,33 +211,33 @@ func forwardStream(server *Server, streamName string, channel ssh.Channel) { func handleNewChannel(server *Server, channel ssh.Channel) { stream := jsonrpc2.NewBufferedStream(channel, jsonrpc2.VSCodeObjectCodec{}) - jsonrpc2.NewConn(context.Background(), stream, newRpcHandler(server)) + jsonrpc2.NewConn(context.Background(), stream, newRPCHandler(server)) } -type RpcHandleFunc func(req *jsonrpc2.Request) (interface{}, error) +type RPCHandleFunc func(req *jsonrpc2.Request) (interface{}, error) type rpcHandler struct { server *Server } -func newRpcHandler(server *Server) *rpcHandler { +func newRPCHandler(server *Server) *rpcHandler { return &rpcHandler{server} } func (r *rpcHandler) Handle(ctx context.Context, conn *jsonrpc2.Conn, req *jsonrpc2.Request) { handler, found := r.server.services[req.Method] if !found { - r.server.errCh <- fmt.Errorf("RPC Method: '%v' not serviced", req.Method) + r.server.errCh <- fmt.Errorf("RPC Method: '%s' not serviced", req.Method) return } result, err := handler(req) if err != nil { - r.server.errCh <- fmt.Errorf("error handling: '%v': %v", req.Method, err) + r.server.errCh <- fmt.Errorf("error handling: '%s': %w", req.Method, err) return } if err := conn.Reply(ctx, req.ID, result); err != nil { - r.server.errCh <- fmt.Errorf("error replying: %v", err) + r.server.errCh <- fmt.Errorf("error replying: %w", err) } } From 958990cef83defd3c70278d3b4597c9165641128 Mon Sep 17 00:00:00 2001 From: Jose Garcia Date: Thu, 23 Sep 2021 11:47:52 -0400 Subject: [PATCH 62/68] More linter fixes --- internal/liveshare/session_test.go | 6 +----- internal/liveshare/test/server.go | 16 +++++++++------- internal/liveshare/test/socket.go | 6 +++--- 3 files changed, 13 insertions(+), 15 deletions(-) diff --git a/internal/liveshare/session_test.go b/internal/liveshare/session_test.go index 0ffdfe136..c830c33b1 100644 --- a/internal/liveshare/session_test.go +++ b/internal/liveshare/session_test.go @@ -71,11 +71,7 @@ func TestServerStartSharing(t *testing.T) { testServer, session, err := makeMockSession( livesharetest.WithService("serverSharing.startSharing", startSharing), ) - defer func() { - if err := testServer.Close(); err != nil { - t.Errorf("failed to close test server: %w", err) - } - }() + defer testServer.Close() if err != nil { t.Errorf("error creating mock session: %v", err) diff --git a/internal/liveshare/test/server.go b/internal/liveshare/test/server.go index 8f80d1bce..9b898dafb 100644 --- a/internal/liveshare/test/server.go +++ b/internal/liveshare/test/server.go @@ -138,6 +138,9 @@ var upgrader = websocket.Upgrader{} func makeConnection(server *Server) http.HandlerFunc { return func(w http.ResponseWriter, req *http.Request) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + if server.relaySAS != "" { // validate the sas key sasParam := req.URL.Query().Get("sb-hc-token") @@ -167,13 +170,13 @@ func makeConnection(server *Server) http.HandlerFunc { server.errCh <- fmt.Errorf("error accepting new channel: %w", err) return } - go handleNewRequests(server, ch, reqs) + go handleNewRequests(ctx, server, ch, reqs) go handleNewChannel(server, ch) } } } -func handleNewRequests(server *Server, channel ssh.Channel, reqs <-chan *ssh.Request) { +func handleNewRequests(ctx context.Context, server *Server, channel ssh.Channel, reqs <-chan *ssh.Request) { for req := range reqs { if req.WantReply { if err := req.Reply(true, nil); err != nil { @@ -181,16 +184,16 @@ func handleNewRequests(server *Server, channel ssh.Channel, reqs <-chan *ssh.Req } } if strings.HasPrefix(req.Type, "stream-transport") { - forwardStream(server, req.Type, channel) + forwardStream(ctx, server, req.Type, channel) } } } -func forwardStream(server *Server, streamName string, channel ssh.Channel) { +func forwardStream(ctx context.Context, server *Server, streamName string, channel ssh.Channel) { simpleStreamName := strings.TrimPrefix(streamName, "stream-transport-") stream, found := server.streams[simpleStreamName] if !found { - server.errCh <- fmt.Errorf("stream '%w' not found", simpleStreamName) + server.errCh <- fmt.Errorf("stream '%s' not found", simpleStreamName) return } @@ -205,8 +208,7 @@ func forwardStream(server *Server, streamName string, channel ssh.Channel) { go copy(stream, channel) go copy(channel, stream) - for { - } + <-ctx.Done() // TODO(josebalius): improve this } func handleNewChannel(server *Server, channel ssh.Channel) { diff --git a/internal/liveshare/test/socket.go b/internal/liveshare/test/socket.go index 9a2d92491..0a7a8baf0 100644 --- a/internal/liveshare/test/socket.go +++ b/internal/liveshare/test/socket.go @@ -28,7 +28,7 @@ func (s *socketConn) Read(b []byte) (int, error) { if s.reader == nil { msgType, r, err := s.Conn.NextReader() if err != nil { - return 0, fmt.Errorf("error getting next reader: %v", err) + return 0, fmt.Errorf("error getting next reader: %w", err) } if msgType != websocket.BinaryMessage { return 0, fmt.Errorf("invalid message type") @@ -54,7 +54,7 @@ func (s *socketConn) Write(b []byte) (int, error) { w, err := s.Conn.NextWriter(websocket.BinaryMessage) if err != nil { - return 0, fmt.Errorf("error getting next writer: %v", err) + return 0, fmt.Errorf("error getting next writer: %w", err) } n, err := w.Write(b) @@ -63,7 +63,7 @@ func (s *socketConn) Write(b []byte) (int, error) { } if err := w.Close(); err != nil { - return 0, fmt.Errorf("error closing writer: %v", err) + return 0, fmt.Errorf("error closing writer: %w", err) } return n, nil From fb53ccb06a1b21e06ce2849d98c2c79f14e4ba76 Mon Sep 17 00:00:00 2001 From: Jose Garcia Date: Thu, 23 Sep 2021 11:56:41 -0400 Subject: [PATCH 63/68] Linter fixes --- internal/liveshare/client_test.go | 8 ++++---- internal/liveshare/port_forwarder_test.go | 14 +++++++------- internal/liveshare/test/socket.go | 2 +- 3 files changed, 12 insertions(+), 12 deletions(-) diff --git a/internal/liveshare/client_test.go b/internal/liveshare/client_test.go index 55139d762..12ea903b6 100644 --- a/internal/liveshare/client_test.go +++ b/internal/liveshare/client_test.go @@ -22,7 +22,7 @@ func TestConnect(t *testing.T) { joinWorkspace := func(req *jsonrpc2.Request) (interface{}, error) { var joinWorkspaceReq joinWorkspaceArgs if err := json.Unmarshal(*req.Params, &joinWorkspaceReq); err != nil { - return nil, fmt.Errorf("error unmarshaling req: %v", err) + return nil, fmt.Errorf("error unmarshaling req: %w", err) } if joinWorkspaceReq.ID != opts.SessionID { return nil, errors.New("connection session id does not match") @@ -45,7 +45,7 @@ func TestConnect(t *testing.T) { livesharetest.WithRelaySAS(opts.RelaySAS), ) if err != nil { - t.Errorf("error creating Live Share server: %v", err) + t.Errorf("error creating Live Share server: %w", err) } defer server.Close() opts.RelayEndpoint = "sb" + strings.TrimPrefix(server.URL(), "https") @@ -62,10 +62,10 @@ func TestConnect(t *testing.T) { select { case err := <-server.Err(): - t.Errorf("error from server: %v", err) + t.Errorf("error from server: %w", err) case err := <-done: if err != nil { - t.Errorf("error from client: %v", err) + t.Errorf("error from client: %w", err) } } } diff --git a/internal/liveshare/port_forwarder_test.go b/internal/liveshare/port_forwarder_test.go index 25b4b2c80..64dfb5c88 100644 --- a/internal/liveshare/port_forwarder_test.go +++ b/internal/liveshare/port_forwarder_test.go @@ -17,7 +17,7 @@ import ( func TestNewPortForwarder(t *testing.T) { testServer, session, err := makeMockSession() if err != nil { - t.Errorf("create mock client: %v", err) + t.Errorf("create mock client: %w", err) } defer testServer.Close() pf := NewPortForwarder(session, "ssh", 80) @@ -42,7 +42,7 @@ func TestPortForwarderStart(t *testing.T) { livesharetest.WithStream("stream-id", stream), ) if err != nil { - t.Errorf("create mock session: %v", err) + t.Errorf("create mock session: %w", err) } defer testServer.Close() @@ -73,23 +73,23 @@ func TestPortForwarderStart(t *testing.T) { } b := make([]byte, len("stream-data")) if _, err := conn.Read(b); err != nil && err != io.EOF { - done <- fmt.Errorf("reading stream: %v", err) + done <- fmt.Errorf("reading stream: %w", err) } if string(b) != "stream-data" { - done <- fmt.Errorf("stream data is not expected value, got: %v", string(b)) + done <- fmt.Errorf("stream data is not expected value, got: %s", string(b)) } if _, err := conn.Write([]byte("new-data")); err != nil { - done <- fmt.Errorf("writing to stream: %v", err) + done <- fmt.Errorf("writing to stream: %w", err) } done <- nil }() select { case err := <-testServer.Err(): - t.Errorf("error from server: %v", err) + t.Errorf("error from server: %w", err) case err := <-done: if err != nil { - t.Errorf("error from client: %v", err) + t.Errorf("error from client: %w", err) } } } diff --git a/internal/liveshare/test/socket.go b/internal/liveshare/test/socket.go index 0a7a8baf0..00cd64a1b 100644 --- a/internal/liveshare/test/socket.go +++ b/internal/liveshare/test/socket.go @@ -59,7 +59,7 @@ func (s *socketConn) Write(b []byte) (int, error) { n, err := w.Write(b) if err != nil { - return 0, fmt.Errorf("error writing: %v", err) + return 0, fmt.Errorf("error writing: %w", err) } if err := w.Close(); err != nil { From c4114cc972ccbccabbd3b7e1152703ebab12a892 Mon Sep 17 00:00:00 2001 From: Jose Garcia Date: Thu, 23 Sep 2021 11:58:55 -0400 Subject: [PATCH 64/68] Linter fixes --- internal/liveshare/session_test.go | 30 +++++++++++++++--------------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/internal/liveshare/session_test.go b/internal/liveshare/session_test.go index c830c33b1..47bac3108 100644 --- a/internal/liveshare/session_test.go +++ b/internal/liveshare/session_test.go @@ -36,7 +36,7 @@ func makeMockSession(opts ...livesharetest.ServerOption) (*livesharetest.Server, TLSConfig: &tls.Config{InsecureSkipVerify: true}, }) if err != nil { - return nil, nil, fmt.Errorf("error connecting to Live Share: %v", err) + return nil, nil, fmt.Errorf("error connecting to Live Share: %w", err) } return testServer, session, nil } @@ -46,7 +46,7 @@ func TestServerStartSharing(t *testing.T) { startSharing := func(req *jsonrpc2.Request) (interface{}, error) { var args []interface{} if err := json.Unmarshal(*req.Params, &args); err != nil { - return nil, fmt.Errorf("error unmarshaling request: %v", err) + return nil, fmt.Errorf("error unmarshaling request: %w", err) } if len(args) < 3 { return nil, errors.New("not enough arguments to start sharing") @@ -63,7 +63,7 @@ func TestServerStartSharing(t *testing.T) { } if browseURL, ok := args[2].(string); !ok { return nil, errors.New("browse url is not a string") - } else if browseURL != fmt.Sprintf("http://localhost:%v", serverPort) { + } else if browseURL != fmt.Sprintf("http://localhost:%d", serverPort) { return nil, errors.New("browseURL does not match expected") } return Port{StreamName: "stream-name", StreamCondition: "stream-condition"}, nil @@ -74,7 +74,7 @@ func TestServerStartSharing(t *testing.T) { defer testServer.Close() if err != nil { - t.Errorf("error creating mock session: %v", err) + t.Errorf("error creating mock session: %w", err) } ctx := context.Background() @@ -82,7 +82,7 @@ func TestServerStartSharing(t *testing.T) { go func() { streamID, err := session.startSharing(ctx, serverProtocol, serverPort) if err != nil { - done <- fmt.Errorf("error sharing server: %v", err) + done <- fmt.Errorf("error sharing server: %w", err) } if streamID.name == "" || streamID.condition == "" { done <- errors.New("stream name or condition is blank") @@ -92,10 +92,10 @@ func TestServerStartSharing(t *testing.T) { select { case err := <-testServer.Err(): - t.Errorf("error from server: %v", err) + t.Errorf("error from server: %w", err) case err := <-done: if err != nil { - t.Errorf("error from client: %v", err) + t.Errorf("error from client: %w", err) } } } @@ -113,7 +113,7 @@ func TestServerGetSharedServers(t *testing.T) { livesharetest.WithService("serverSharing.getSharedServers", getSharedServers), ) if err != nil { - t.Errorf("error creating mock session: %v", err) + t.Errorf("error creating mock session: %w", err) } defer testServer.Close() ctx := context.Background() @@ -121,7 +121,7 @@ func TestServerGetSharedServers(t *testing.T) { go func() { ports, err := session.GetSharedServers(ctx) if err != nil { - done <- fmt.Errorf("error getting shared servers: %v", err) + done <- fmt.Errorf("error getting shared servers: %w", err) } if len(ports) < 1 { done <- errors.New("not enough ports returned") @@ -140,10 +140,10 @@ func TestServerGetSharedServers(t *testing.T) { select { case err := <-testServer.Err(): - t.Errorf("error from server: %v", err) + t.Errorf("error from server: %w", err) case err := <-done: if err != nil { - t.Errorf("error from client: %v", err) + t.Errorf("error from client: %w", err) } } } @@ -152,7 +152,7 @@ func TestServerUpdateSharedVisibility(t *testing.T) { updateSharedVisibility := func(rpcReq *jsonrpc2.Request) (interface{}, error) { var req []interface{} if err := json.Unmarshal(*rpcReq.Params, &req); err != nil { - return nil, fmt.Errorf("unmarshal req: %v", err) + return nil, fmt.Errorf("unmarshal req: %w", err) } if len(req) < 2 { return nil, errors.New("request arguments is less than 2") @@ -177,7 +177,7 @@ func TestServerUpdateSharedVisibility(t *testing.T) { livesharetest.WithService("serverSharing.updateSharedServerVisibility", updateSharedVisibility), ) if err != nil { - t.Errorf("creating mock session: %v", err) + t.Errorf("creating mock session: %w", err) } defer testServer.Close() ctx := context.Background() @@ -187,10 +187,10 @@ func TestServerUpdateSharedVisibility(t *testing.T) { }() select { case err := <-testServer.Err(): - t.Errorf("error from server: %v", err) + t.Errorf("error from server: %w", err) case err := <-done: if err != nil { - t.Errorf("error from client: %v", err) + t.Errorf("error from client: %w", err) } } } From b8f35f950ca104c88489a9dd0f4586cd2a47fa36 Mon Sep 17 00:00:00 2001 From: Jose Garcia Date: Thu, 23 Sep 2021 13:14:35 -0400 Subject: [PATCH 65/68] Linter fixes --- internal/liveshare/client.go | 2 +- internal/liveshare/port_forwarder.go | 2 +- internal/liveshare/session_test.go | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/internal/liveshare/client.go b/internal/liveshare/client.go index b51e25ea6..76f8146de 100644 --- a/internal/liveshare/client.go +++ b/internal/liveshare/client.go @@ -130,7 +130,7 @@ func (s *Session) openStreamingChannel(ctx context.Context, id channelID) (ssh.C return nil, fmt.Errorf("error getting stream id: %w", err) } - span, ctx := opentracing.StartSpanFromContext(ctx, "Session.OpenChannel+SendRequest") + span, _ := opentracing.StartSpanFromContext(ctx, "Session.OpenChannel+SendRequest") defer span.Finish() channel, reqs, err := s.ssh.conn.OpenChannel("session", nil) diff --git a/internal/liveshare/port_forwarder.go b/internal/liveshare/port_forwarder.go index 56401cc4d..fcc7ba767 100644 --- a/internal/liveshare/port_forwarder.go +++ b/internal/liveshare/port_forwarder.go @@ -94,7 +94,7 @@ func (fwd *PortForwarder) shareRemotePort(ctx context.Context) (channelID, error if err != nil { err = fmt.Errorf("failed to share remote port %d: %w", fwd.remotePort, err) } - return id, nil + return id, err } func awaitError(ctx context.Context, errc <-chan error) error { diff --git a/internal/liveshare/session_test.go b/internal/liveshare/session_test.go index 47bac3108..c9a1be567 100644 --- a/internal/liveshare/session_test.go +++ b/internal/liveshare/session_test.go @@ -71,7 +71,7 @@ func TestServerStartSharing(t *testing.T) { testServer, session, err := makeMockSession( livesharetest.WithService("serverSharing.startSharing", startSharing), ) - defer testServer.Close() + defer testServer.Close() //nolint:stylecheck // httptest.Server does not return errors on Close() if err != nil { t.Errorf("error creating mock session: %w", err) From 08bc181d79f30f344f361494ac490762453edb16 Mon Sep 17 00:00:00 2001 From: Jose Garcia Date: Thu, 23 Sep 2021 13:16:20 -0400 Subject: [PATCH 66/68] Linter fixes --- internal/liveshare/session_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/internal/liveshare/session_test.go b/internal/liveshare/session_test.go index c9a1be567..bca11d885 100644 --- a/internal/liveshare/session_test.go +++ b/internal/liveshare/session_test.go @@ -71,7 +71,7 @@ func TestServerStartSharing(t *testing.T) { testServer, session, err := makeMockSession( livesharetest.WithService("serverSharing.startSharing", startSharing), ) - defer testServer.Close() //nolint:stylecheck // httptest.Server does not return errors on Close() + defer testServer.Close() //nolint - httptest.Server does not return errors on Close() if err != nil { t.Errorf("error creating mock session: %w", err) From 65dcb0f428ff703135c43db9ce1246860905f469 Mon Sep 17 00:00:00 2001 From: Jose Garcia Date: Thu, 23 Sep 2021 13:22:20 -0400 Subject: [PATCH 67/68] Linter fixes --- internal/liveshare/session_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/internal/liveshare/session_test.go b/internal/liveshare/session_test.go index bca11d885..af41dd117 100644 --- a/internal/liveshare/session_test.go +++ b/internal/liveshare/session_test.go @@ -71,7 +71,7 @@ func TestServerStartSharing(t *testing.T) { testServer, session, err := makeMockSession( livesharetest.WithService("serverSharing.startSharing", startSharing), ) - defer testServer.Close() //nolint - httptest.Server does not return errors on Close() + defer testServer.Close() //nolint:staticcheck // httptest.Server does not return errors on Close() if err != nil { t.Errorf("error creating mock session: %w", err) From 5d6ea5029ed7cf2ab39abf1cb2fc37da5883b842 Mon Sep 17 00:00:00 2001 From: Jose Garcia Date: Thu, 23 Sep 2021 13:36:04 -0400 Subject: [PATCH 68/68] Linter fixes --- internal/liveshare/client.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/internal/liveshare/client.go b/internal/liveshare/client.go index 76f8146de..2b1f97831 100644 --- a/internal/liveshare/client.go +++ b/internal/liveshare/client.go @@ -130,8 +130,9 @@ func (s *Session) openStreamingChannel(ctx context.Context, id channelID) (ssh.C return nil, fmt.Errorf("error getting stream id: %w", err) } - span, _ := opentracing.StartSpanFromContext(ctx, "Session.OpenChannel+SendRequest") + span, ctx := opentracing.StartSpanFromContext(ctx, "Session.OpenChannel+SendRequest") defer span.Finish() + _ = ctx // ctx is not currently used channel, reqs, err := s.ssh.conn.OpenChannel("session", nil) if err != nil {