From 7332aa428c4db7b87c4280063b89a4d10763cb3c Mon Sep 17 00:00:00 2001 From: Jose Garcia Date: Thu, 22 Jul 2021 00:45:45 +0000 Subject: [PATCH 01/10] Large refactor and solidifying of APIs before tests --- api.go | 130 ---------------------------------------------- client.go | 50 +++++++++++------- client_test.go | 24 +++++++++ connection.go | 43 +++++++++++++++ liveshare.go | 77 --------------------------- port_forwarder.go | 10 ++-- rpc.go | 2 - server.go | 8 +-- session.go | 60 --------------------- socket.go | 105 +++++++++++++++++++++++++++++++++++++ ssh.go | 16 +++--- terminal.go | 8 +-- websocket.go | 105 ------------------------------------- 13 files changed, 224 insertions(+), 414 deletions(-) delete mode 100644 api.go create mode 100644 client_test.go create mode 100644 connection.go delete mode 100644 liveshare.go delete mode 100644 session.go create mode 100644 socket.go delete mode 100644 websocket.go diff --git a/api.go b/api.go deleted file mode 100644 index 55b6e6e93..000000000 --- a/api.go +++ /dev/null @@ -1,130 +0,0 @@ -package liveshare - -import ( - "encoding/json" - "fmt" - "io/ioutil" - "net/http" - "strings" -) - -type api struct { - client *Client - httpClient *http.Client - serviceURI string - workspaceID string -} - -func newAPI(client *Client) *api { - serviceURI := client.liveShare.Configuration.LiveShareEndpoint - if !strings.HasSuffix(client.liveShare.Configuration.LiveShareEndpoint, "/") { - serviceURI = client.liveShare.Configuration.LiveShareEndpoint + "/" - } - - if !strings.Contains(serviceURI, "api/v1.2") { - serviceURI = serviceURI + "api/v1.2" - } - - serviceURI = strings.TrimSuffix(serviceURI, "/") - - return &api{client, &http.Client{}, serviceURI, strings.ToUpper(client.liveShare.Configuration.WorkspaceID)} -} - -type workspaceAccessResponse struct { - SessionToken string `json:"sessionToken"` - CreatedAt string `json:"createdAt"` - UpdatedAt string `json:"updatedAt"` - Name string `json:"name"` - OwnerID string `json:"ownerId"` - JoinLink string `json:"joinLink"` - ConnectLinks []string `json:"connectLinks"` - RelayLink string `json:"relayLink"` - RelaySas string `json:"relaySas"` - HostPublicKeys []string `json:"hostPublicKeys"` - ConversationID string `json:"conversationId"` - AssociatedUserIDs map[string]string `json:"associatedUserIds"` - AreAnonymousGuestsAllowed bool `json:"areAnonymousGuestsAllowed"` - IsHostConnected bool `json:"isHostConnected"` - ExpiresAt string `json:"expiresAt"` - InvitationLinks []string `json:"invitationLinks"` - ID string `json:"id"` -} - -func (a *api) workspaceAccess() (*workspaceAccessResponse, error) { - url := fmt.Sprintf("%s/workspace/%s/user", a.serviceURI, a.workspaceID) - - req, err := http.NewRequest(http.MethodPut, url, nil) - if err != nil { - return nil, fmt.Errorf("error creating request: %v", err) - } - - a.setDefaultHeaders(req) - resp, err := a.httpClient.Do(req) - if err != nil { - return nil, fmt.Errorf("error making request: %v", err) - } - - b, err := ioutil.ReadAll(resp.Body) - if err != nil { - return nil, fmt.Errorf("error reading response body: %v", err) - } - - var response workspaceAccessResponse - if err := json.Unmarshal(b, &response); err != nil { - return nil, fmt.Errorf("error unmarshaling response into json: %v", err) - } - - return &response, nil -} - -func (a *api) setDefaultHeaders(req *http.Request) { - req.Header.Set("Authorization", "Bearer "+a.client.liveShare.Configuration.Token) - req.Header.Set("Cache-Control", "no-cache") - req.Header.Set("Content-Type", "application/json") -} - -type workspaceInfoResponse struct { - CreatedAt string `json:"createdAt"` - UpdatedAt string `json:"updatedAt"` - Name string `json:"name"` - OwnerID string `json:"ownerId"` - JoinLink string `json:"joinLink"` - ConnectLinks []string `json:"connectLinks"` - RelayLink string `json:"relayLink"` - RelaySas string `json:"relaySas"` - HostPublicKeys []string `json:"hostPublicKeys"` - ConversationID string `json:"conversationId"` - AssociatedUserIDs map[string]string - AreAnonymousGuestsAllowed bool `json:"areAnonymousGuestsAllowed"` - IsHostConnected bool `json:"isHostConnected"` - ExpiresAt string `json:"expiresAt"` - InvitationLinks []string `json:"invitationLinks"` - ID string `json:"id"` -} - -func (a *api) workspaceInfo() (*workspaceInfoResponse, error) { - url := fmt.Sprintf("%s/workspace/%s", a.serviceURI, a.workspaceID) - - req, err := http.NewRequest(http.MethodGet, url, nil) - if err != nil { - return nil, fmt.Errorf("error creating request: %v", err) - } - - a.setDefaultHeaders(req) - resp, err := a.httpClient.Do(req) - if err != nil { - return nil, fmt.Errorf("error making request: %v", err) - } - - b, err := ioutil.ReadAll(resp.Body) - if err != nil { - return nil, fmt.Errorf("error reading response body: %v", err) - } - - var response workspaceInfoResponse - if err := json.Unmarshal(b, &response); err != nil { - return nil, fmt.Errorf("error unmarshaling response into json: %v", err) - } - - return &response, nil -} diff --git a/client.go b/client.go index 456a2c321..1ad90b336 100644 --- a/client.go +++ b/client.go @@ -8,31 +8,44 @@ import ( ) type Client struct { - liveShare *LiveShare - session *session + connection Connection + sshSession *sshSession rpc *rpc } -// NewClient is a function ... -func (l *LiveShare) NewClient() *Client { - return &Client{liveShare: l} +type ClientOption func(*Client) error + +func NewClient(opts ...ClientOption) (*Client, error) { + client := new(Client) + + for _, o := range opts { + if err := o(client); err != nil { + return nil, err + } + } + + return client, nil +} + +func WithConnection(connection Connection) ClientOption { + return func(c *Client) error { + if err := connection.validate(); err != nil { + return err + } + + c.connection = connection + return nil + } } func (c *Client) Join(ctx context.Context) (err error) { - api := newAPI(c) - - c.session = newSession(api) - if err := c.session.init(ctx); err != nil { - return fmt.Errorf("error creating session: %v", err) - } - - websocket := newWebsocket(c.session) - if err := websocket.connect(ctx); err != nil { + clientSocket := newSocket(c.connection) + if err := clientSocket.connect(ctx); err != nil { return fmt.Errorf("error connecting websocket: %v", err) } - c.sshSession = newSSH(c.session, websocket) + c.sshSession = newSSH(c.connection.SessionToken, clientSocket) if err := c.sshSession.connect(ctx); err != nil { return fmt.Errorf("error connecting to ssh session: %v", err) } @@ -69,9 +82,9 @@ type joinWorkspaceResult struct { func (c *Client) joinWorkspace(ctx context.Context) (*joinWorkspaceResult, error) { args := joinWorkspaceArgs{ - ID: c.session.workspaceInfo.ID, + ID: c.connection.SessionID, ConnectionMode: "local", - JoiningUserSessionToken: c.session.workspaceAccess.SessionToken, + JoiningUserSessionToken: c.connection.SessionToken, ClientCapabilities: clientCapabilities{ IsNonInteractive: false, }, @@ -99,8 +112,7 @@ func (c *Client) openStreamingChannel(ctx context.Context, streamName, condition go ssh.DiscardRequests(reqs) requestType := fmt.Sprintf("stream-transport-%s", streamID) - _, err = channel.SendRequest(requestType, true, nil) - if err != nil { + if _, err = channel.SendRequest(requestType, true, nil); err != nil { return nil, fmt.Errorf("error sending channel request: %v", err) } diff --git a/client_test.go b/client_test.go new file mode 100644 index 000000000..8d118974d --- /dev/null +++ b/client_test.go @@ -0,0 +1,24 @@ +package liveshare + +import ( + "testing" +) + +func TestClientJoin(t *testing.T) { + // connection := Connection{ + // SessionID: "session-id", + // SessionToken: "session-token", + // RelayEndpoint: "relay-endpoint", + // RelaySAS: "relay-sas", + // } + + // client, err := NewClient(WithConnection(connection)) + // if err != nil { + // t.Errorf("error creating client: %v", err) + // } + + // ctx := context.Background() + // if err := client.Join(ctx); err != nil { + // t.Errorf("error joining client: %v", err) + // } +} diff --git a/connection.go b/connection.go new file mode 100644 index 000000000..a97935d3b --- /dev/null +++ b/connection.go @@ -0,0 +1,43 @@ +package liveshare + +import ( + "errors" + "net/url" + "strings" +) + +type Connection struct { + SessionID string `json:"sessionId"` + SessionToken string `json:"sessionToken"` + RelaySAS string `json:"relaySas"` + RelayEndpoint string `json:"relayEndpoint"` +} + +func (r Connection) validate() error { + if r.SessionID == "" { + return errors.New("connection sessionID is required") + } + + if r.SessionToken == "" { + return errors.New("connection sessionToken is required") + } + + if r.RelaySAS == "" { + return errors.New("connection relaySas is required") + } + + if r.RelayEndpoint == "" { + return errors.New("connection relayEndpoint is required") + } + + return nil +} + +func (r Connection) uri(action string) string { + sas := url.QueryEscape(r.RelaySAS) + uri := r.RelayEndpoint + uri = strings.Replace(uri, "sb:", "wss:", -1) + uri = strings.Replace(uri, ".net/", ".net:443/$hc/", 1) + uri = uri + "?sb-hc-action=" + action + "&sb-hc-token=" + sas + return uri +} diff --git a/liveshare.go b/liveshare.go deleted file mode 100644 index 3c4be5c05..000000000 --- a/liveshare.go +++ /dev/null @@ -1,77 +0,0 @@ -package liveshare - -import ( - "errors" - "fmt" - "strings" -) - -type LiveShare struct { - Configuration *Configuration -} - -func New(opts ...Option) (*LiveShare, error) { - configuration := NewConfiguration() - - for _, o := range opts { - if err := o(configuration); err != nil { - return nil, fmt.Errorf("error configuring liveshare: %v", err) - } - } - - if err := configuration.Validate(); err != nil { - return nil, fmt.Errorf("error validating configuration: %v", err) - } - - return &LiveShare{Configuration: configuration}, nil -} - -type Option func(configuration *Configuration) error - -func WithWorkspaceID(id string) Option { - return func(configuration *Configuration) error { - configuration.WorkspaceID = id - return nil - } -} - -func WithLiveShareEndpoint(liveShareEndpoint string) Option { - return func(configuration *Configuration) error { - configuration.LiveShareEndpoint = liveShareEndpoint - return nil - } -} - -func WithToken(token string) Option { - return func(configuration *Configuration) error { - configuration.Token = token - return nil - } -} - -type Configuration struct { - WorkspaceID, LiveShareEndpoint, Token string -} - -func NewConfiguration() *Configuration { - return &Configuration{ - LiveShareEndpoint: "https://prod.liveshare.vsengsaas.visualstudio.com", - } -} - -func (c *Configuration) Validate() error { - errs := []string{} - if c.WorkspaceID == "" { - errs = append(errs, "WorkspaceID is required") - } - - if c.Token == "" { - errs = append(errs, "Token is required") - } - - if len(errs) > 0 { - return errors.New(strings.Join(errs, ", ")) - } - - return nil -} diff --git a/port_forwarder.go b/port_forwarder.go index 20382c208..1227493b2 100644 --- a/port_forwarder.go +++ b/port_forwarder.go @@ -11,18 +11,18 @@ import ( "golang.org/x/crypto/ssh" ) -type LocalPortForwarder struct { +type PortForwarder struct { client *Client server *Server port int channels []ssh.Channel } -func NewLocalPortForwarder(client *Client, server *Server, port int) *LocalPortForwarder { - return &LocalPortForwarder{client, server, port, []ssh.Channel{}} +func NewPortForwarder(client *Client, server *Server, port int) *PortForwarder { + return &PortForwarder{client, server, port, []ssh.Channel{}} } -func (l *LocalPortForwarder) Start(ctx context.Context) error { +func (l *PortForwarder) Start(ctx context.Context) error { ln, err := net.Listen("tcp", ":"+strconv.Itoa(l.port)) if err != nil { return fmt.Errorf("error listening on tcp port: %v", err) @@ -42,7 +42,7 @@ func (l *LocalPortForwarder) Start(ctx context.Context) error { return nil } -func (l *LocalPortForwarder) handleConnection(ctx context.Context, conn net.Conn) { +func (l *PortForwarder) handleConnection(ctx context.Context, conn net.Conn) { channel, err := l.client.openStreamingChannel(ctx, l.server.streamName, l.server.streamCondition) if err != nil { log.Println("errrr handle Connect") diff --git a/rpc.go b/rpc.go index de427cda9..d40046471 100644 --- a/rpc.go +++ b/rpc.go @@ -78,7 +78,5 @@ func (r *rpcHandler) Handle(ctx context.Context, conn *jsonrpc2.Conn, req *jsonr r.eventHandlers[req.Method] = []chan *jsonrpc2.Request{} }() - } else { - // TODO(josebalius): Handle } } diff --git a/server.go b/server.go index 71ec9d4dd..6f17d5ac5 100644 --- a/server.go +++ b/server.go @@ -13,12 +13,12 @@ type Server struct { streamName, streamCondition string } -func (c *Client) NewServer() (*Server, error) { - if !c.hasJoined() { - return nil, errors.New("LiveShareClient must join before creating server") +func NewServer(client *Client) (*Server, error) { + if !client.hasJoined() { + return nil, errors.New("client must join before creating server") } - return &Server{client: c}, nil + return &Server{client: client}, nil } type Port struct { diff --git a/session.go b/session.go deleted file mode 100644 index d0492a10e..000000000 --- a/session.go +++ /dev/null @@ -1,60 +0,0 @@ -package liveshare - -import ( - "context" - "fmt" - "net/url" - "strings" - - "golang.org/x/sync/errgroup" -) - -type session struct { - api *api - - workspaceAccess *workspaceAccessResponse - workspaceInfo *workspaceInfoResponse -} - -func newSession(api *api) *session { - return &session{api: api} -} - -func (s *session) init(ctx context.Context) error { - g, ctx := errgroup.WithContext(ctx) - - g.Go(func() error { - workspaceAccess, err := s.api.workspaceAccess() - if err != nil { - return fmt.Errorf("error getting workspace access: %v", err) - } - s.workspaceAccess = workspaceAccess - return nil - }) - - g.Go(func() error { - workspaceInfo, err := s.api.workspaceInfo() - if err != nil { - return fmt.Errorf("error getting workspace info: %v", err) - } - s.workspaceInfo = workspaceInfo - return nil - }) - - if err := g.Wait(); err != nil { - return err - } - - return nil -} - -// Reference: -// https://github.com/Azure/azure-relay-node/blob/7b57225365df3010163bf4b9e640868a02737eb6/hyco-ws/index.js#L107-L137 -func (s *session) relayURI(action string) string { - relaySas := url.QueryEscape(s.workspaceAccess.RelaySas) - relayURI := s.workspaceAccess.RelayLink - relayURI = strings.Replace(relayURI, "sb:", "wss:", -1) - relayURI = strings.Replace(relayURI, ".net/", ".net:443/$hc/", 1) - relayURI = relayURI + "?sb-hc-action=" + action + "&sb-hc-token=" + relaySas - return relayURI -} diff --git a/socket.go b/socket.go new file mode 100644 index 000000000..c3f75b9db --- /dev/null +++ b/socket.go @@ -0,0 +1,105 @@ +package liveshare + +import ( + "context" + "errors" + "io" + "net" + "sync" + "time" + + "github.com/gorilla/websocket" +) + +type socket struct { + addr string + conn *websocket.Conn + readMutex sync.Mutex + writeMutex sync.Mutex + reader io.Reader +} + +func newSocket(clientConn Connection) *socket { + return &socket{addr: clientConn.uri("connect")} +} + +func (s *socket) connect(ctx context.Context) error { + ws, _, err := websocket.DefaultDialer.Dial(s.addr, nil) + if err != nil { + return err + } + s.conn = ws + return nil +} + +func (s *socket) Read(b []byte) (int, error) { + s.readMutex.Lock() + defer s.readMutex.Unlock() + + if s.reader == nil { + messageType, reader, err := s.conn.NextReader() + if err != nil { + return 0, err + } + + if messageType != websocket.BinaryMessage { + return 0, errors.New("unexpected websocket message type") + } + + s.reader = reader + } + + bytesRead, err := s.reader.Read(b) + if err != nil { + s.reader = nil + + if err == io.EOF { + err = nil + } + } + + return bytesRead, err +} + +func (s *socket) Write(b []byte) (int, error) { + s.writeMutex.Lock() + defer s.writeMutex.Unlock() + + nextWriter, err := s.conn.NextWriter(websocket.BinaryMessage) + if err != nil { + return 0, err + } + + bytesWritten, err := nextWriter.Write(b) + nextWriter.Close() + + return bytesWritten, err +} + +func (s *socket) Close() error { + return s.conn.Close() +} + +func (s *socket) LocalAddr() net.Addr { + return s.conn.LocalAddr() +} + +func (s *socket) RemoteAddr() net.Addr { + return s.conn.RemoteAddr() +} + +func (s *socket) SetDeadline(t time.Time) error { + if err := s.SetReadDeadline(t); err != nil { + return err + } + + return s.SetWriteDeadline(t) +} + +func (s *socket) SetReadDeadline(t time.Time) error { + return s.conn.SetReadDeadline(t) +} + +func (s *socket) SetWriteDeadline(t time.Time) error { + return s.conn.SetWriteDeadline(t) +} diff --git a/ssh.go b/ssh.go index 9ae32ed7c..3ea2d2777 100644 --- a/ssh.go +++ b/ssh.go @@ -12,22 +12,22 @@ import ( type sshSession struct { *ssh.Session - session *session - socket net.Conn - conn ssh.Conn - reader io.Reader - writer io.Writer + token string + socket net.Conn + conn ssh.Conn + reader io.Reader + writer io.Writer } -func newSSH(session *session, socket net.Conn) *sshSession { - return &sshSession{session: session, socket: socket} +func newSSH(token string, socket net.Conn) *sshSession { + return &sshSession{token: token, socket: socket} } func (s *sshSession) connect(ctx context.Context) error { clientConfig := ssh.ClientConfig{ User: "", Auth: []ssh.AuthMethod{ - ssh.Password(s.session.workspaceAccess.SessionToken), + ssh.Password(s.token), }, HostKeyAlgorithms: []string{"rsa-sha2-512", "rsa-sha2-256"}, HostKeyCallback: ssh.InsecureIgnoreHostKey(), diff --git a/terminal.go b/terminal.go index 631e75912..1621559a1 100644 --- a/terminal.go +++ b/terminal.go @@ -13,13 +13,13 @@ type Terminal struct { client *Client } -func (c *Client) NewTerminal() (*Terminal, error) { - if !c.hasJoined() { - return nil, errors.New("LiveShareClient must join before creating terminal") +func NewTerminal(client *Client) (*Terminal, error) { + if !client.hasJoined() { + return nil, errors.New("client must join before creating terminal") } return &Terminal{ - client: c, + client: client, }, nil } diff --git a/websocket.go b/websocket.go deleted file mode 100644 index ae163e6e2..000000000 --- a/websocket.go +++ /dev/null @@ -1,105 +0,0 @@ -package liveshare - -import ( - "context" - "errors" - "io" - "net" - "sync" - "time" - - gorillawebsocket "github.com/gorilla/websocket" -) - -type websocket struct { - session *session - conn *gorillawebsocket.Conn - readMutex sync.Mutex - writeMutex sync.Mutex - reader io.Reader -} - -func newWebsocket(session *session) *websocket { - return &websocket{session: session} -} - -func (w *websocket) connect(ctx context.Context) error { - ws, _, err := gorillawebsocket.DefaultDialer.Dial(w.session.relayURI("connect"), nil) - if err != nil { - return err - } - w.conn = ws - return nil -} - -func (w *websocket) Read(b []byte) (int, error) { - w.readMutex.Lock() - defer w.readMutex.Unlock() - - if w.reader == nil { - messageType, reader, err := w.conn.NextReader() - if err != nil { - return 0, err - } - - if messageType != gorillawebsocket.BinaryMessage { - return 0, errors.New("unexpected websocket message type") - } - - w.reader = reader - } - - bytesRead, err := w.reader.Read(b) - if err != nil { - w.reader = nil - - if err == io.EOF { - err = nil - } - } - - return bytesRead, err -} - -func (w *websocket) Write(b []byte) (int, error) { - w.writeMutex.Lock() - defer w.writeMutex.Unlock() - - nextWriter, err := w.conn.NextWriter(gorillawebsocket.BinaryMessage) - if err != nil { - return 0, err - } - - bytesWritten, err := nextWriter.Write(b) - nextWriter.Close() - - return bytesWritten, err -} - -func (w *websocket) Close() error { - return w.conn.Close() -} - -func (w *websocket) LocalAddr() net.Addr { - return w.conn.LocalAddr() -} - -func (w *websocket) RemoteAddr() net.Addr { - return w.conn.RemoteAddr() -} - -func (w *websocket) SetDeadline(t time.Time) error { - if err := w.SetReadDeadline(t); err != nil { - return err - } - - return w.SetWriteDeadline(t) -} - -func (w *websocket) SetReadDeadline(t time.Time) error { - return w.conn.SetReadDeadline(t) -} - -func (w *websocket) SetWriteDeadline(t time.Time) error { - return w.conn.SetWriteDeadline(t) -} From fddcd876b0b6e50959b0530662c50dda1b0079c1 Mon Sep 17 00:00:00 2001 From: Jose Garcia Date: Thu, 22 Jul 2021 01:02:03 +0000 Subject: [PATCH 02/10] Some more cleanup to the port forwarder and connection --- connection.go | 16 ++++++++-------- port_forwarder.go | 28 +++++++++++++--------------- 2 files changed, 21 insertions(+), 23 deletions(-) diff --git a/connection.go b/connection.go index a97935d3b..eda050f63 100644 --- a/connection.go +++ b/connection.go @@ -7,27 +7,27 @@ import ( ) type Connection struct { - SessionID string `json:"sessionId"` - SessionToken string `json:"sessionToken"` - RelaySAS string `json:"relaySas"` - RelayEndpoint string `json:"relayEndpoint"` + SessionID string + SessionToken string + RelaySAS string + RelayEndpoint string } func (r Connection) validate() error { if r.SessionID == "" { - return errors.New("connection sessionID is required") + return errors.New("connection SessionID is required") } if r.SessionToken == "" { - return errors.New("connection sessionToken is required") + return errors.New("connection SessionToken is required") } if r.RelaySAS == "" { - return errors.New("connection relaySas is required") + return errors.New("connection RelaySAS is required") } if r.RelayEndpoint == "" { - return errors.New("connection relayEndpoint is required") + return errors.New("connection RelayEndpoint is required") } return nil diff --git a/port_forwarder.go b/port_forwarder.go index 1227493b2..8d42f3c05 100644 --- a/port_forwarder.go +++ b/port_forwarder.go @@ -4,22 +4,24 @@ import ( "context" "fmt" "io" - "log" "net" "strconv" - - "golang.org/x/crypto/ssh" ) type PortForwarder struct { - client *Client - server *Server - port int - channels []ssh.Channel + client *Client + server *Server + port int + errCh chan error } func NewPortForwarder(client *Client, server *Server, port int) *PortForwarder { - return &PortForwarder{client, server, port, []ssh.Channel{}} + return &PortForwarder{ + client: client, + server: server, + port: port, + errCh: make(chan error), + } } func (l *PortForwarder) Start(ctx context.Context) error { @@ -37,22 +39,18 @@ func (l *PortForwarder) Start(ctx context.Context) error { go l.handleConnection(ctx, conn) } - // clean up after ourselves - return nil } func (l *PortForwarder) handleConnection(ctx context.Context, conn net.Conn) { channel, err := l.client.openStreamingChannel(ctx, l.server.streamName, l.server.streamCondition) if err != nil { - log.Println("errrr handle Connect") - log.Println(err) // TODO(josebalius) handle this somehow + l.errCh <- fmt.Errorf("error opening streaming channel for new connection: %v", err) + return } - l.channels = append(l.channels, channel) copyConn := func(writer io.Writer, reader io.Reader) { - _, err := io.Copy(writer, reader) - if err != nil { + if _, err := io.Copy(writer, reader); err != nil { channel.Close() conn.Close() } From a99d0f5495a575c0694307dbbe606210b16830d7 Mon Sep 17 00:00:00 2001 From: Jose Garcia Date: Thu, 22 Jul 2021 01:07:06 +0000 Subject: [PATCH 03/10] Better naming for rpc client and ssh session --- client.go | 14 +++++++------- rpc.go | 10 +++++----- ssh.go | 2 +- 3 files changed, 13 insertions(+), 13 deletions(-) diff --git a/client.go b/client.go index 1ad90b336..d0e6e84a8 100644 --- a/client.go +++ b/client.go @@ -10,8 +10,8 @@ import ( type Client struct { connection Connection - sshSession *sshSession - rpc *rpc + ssh *sshSession + rpc *rpcClient } type ClientOption func(*Client) error @@ -45,12 +45,12 @@ func (c *Client) Join(ctx context.Context) (err error) { return fmt.Errorf("error connecting websocket: %v", err) } - c.sshSession = newSSH(c.connection.SessionToken, clientSocket) - if err := c.sshSession.connect(ctx); err != nil { + c.ssh = newSshSession(c.connection.SessionToken, clientSocket) + if err := c.ssh.connect(ctx); err != nil { return fmt.Errorf("error connecting to ssh session: %v", err) } - c.rpc = newRPC(c.sshSession) + c.rpc = newRpcClient(c.ssh) c.rpc.connect(ctx) _, err = c.joinWorkspace(ctx) @@ -62,7 +62,7 @@ func (c *Client) Join(ctx context.Context) (err error) { } func (c *Client) hasJoined() bool { - return c.sshSession != nil && c.rpc != nil + return c.ssh != nil && c.rpc != nil } type clientCapabilities struct { @@ -105,7 +105,7 @@ func (c *Client) openStreamingChannel(ctx context.Context, streamName, condition return nil, fmt.Errorf("error getting stream id: %v", err) } - channel, reqs, err := c.sshSession.conn.OpenChannel("session", nil) + channel, reqs, err := c.ssh.conn.OpenChannel("session", nil) if err != nil { return nil, fmt.Errorf("error opening ssh channel for transport: %v", err) } diff --git a/rpc.go b/rpc.go index d40046471..d624bbd74 100644 --- a/rpc.go +++ b/rpc.go @@ -9,22 +9,22 @@ import ( "github.com/sourcegraph/jsonrpc2" ) -type rpc struct { +type rpcClient struct { *jsonrpc2.Conn conn io.ReadWriteCloser handler *rpcHandler } -func newRPC(conn io.ReadWriteCloser) *rpc { - return &rpc{conn: conn, handler: newRPCHandler()} +func newRpcClient(conn io.ReadWriteCloser) *rpcClient { + return &rpcClient{conn: conn, handler: newRPCHandler()} } -func (r *rpc) connect(ctx context.Context) { +func (r *rpcClient) connect(ctx context.Context) { stream := jsonrpc2.NewBufferedStream(r.conn, jsonrpc2.VSCodeObjectCodec{}) r.Conn = jsonrpc2.NewConn(ctx, stream, r.handler) } -func (r *rpc) do(ctx context.Context, method string, args interface{}, result interface{}) error { +func (r *rpcClient) do(ctx context.Context, method string, args interface{}, result interface{}) error { waiter, err := r.Conn.DispatchCall(ctx, method, args) if err != nil { return fmt.Errorf("error on dispatch call: %v", err) diff --git a/ssh.go b/ssh.go index 3ea2d2777..e22cd69d1 100644 --- a/ssh.go +++ b/ssh.go @@ -19,7 +19,7 @@ type sshSession struct { writer io.Writer } -func newSSH(token string, socket net.Conn) *sshSession { +func newSshSession(token string, socket net.Conn) *sshSession { return &sshSession{token: token, socket: socket} } From b9cd9af7fa83ad2fd7cca4727d5adc1be51fa384 Mon Sep 17 00:00:00 2001 From: Jose Garcia Date: Fri, 23 Jul 2021 01:17:32 +0000 Subject: [PATCH 04/10] Start of tests and comments --- client.go | 5 ++++ client_test.go | 73 +++++++++++++++++++++++++++++++++++++++++++++-- connection.go | 1 + port_forwarder.go | 3 ++ 4 files changed, 79 insertions(+), 3 deletions(-) diff --git a/client.go b/client.go index d0e6e84a8..435dd6775 100644 --- a/client.go +++ b/client.go @@ -7,6 +7,7 @@ import ( "golang.org/x/crypto/ssh" ) +// A Client capable of joining a liveshare connection type Client struct { connection Connection @@ -14,8 +15,10 @@ type Client struct { rpc *rpcClient } +// A ClientOption is a function that modifies a client type ClientOption func(*Client) error +// NewClient accepts a range of options, applies them and returns a client func NewClient(opts ...ClientOption) (*Client, error) { client := new(Client) @@ -28,6 +31,7 @@ func NewClient(opts ...ClientOption) (*Client, error) { return client, nil } +// WithConnection is a ClientOption that accepts a Connection func WithConnection(connection Connection) ClientOption { return func(c *Client) error { if err := connection.validate(); err != nil { @@ -39,6 +43,7 @@ func WithConnection(connection Connection) ClientOption { } } +// Join is a method that joins the client to the liveshare session func (c *Client) Join(ctx context.Context) (err error) { clientSocket := newSocket(c.connection) if err := clientSocket.connect(ctx); err != nil { diff --git a/client_test.go b/client_test.go index 8d118974d..86f732637 100644 --- a/client_test.go +++ b/client_test.go @@ -1,22 +1,89 @@ package liveshare import ( + "fmt" + "net/http" + "net/http/httptest" "testing" + + "github.com/gorilla/websocket" ) +func TestNewClient(t *testing.T) { + client, err := NewClient() + if err != nil { + t.Errorf("error creating new client: %v", err) + } + if client == nil { + t.Error("client is nil") + } +} + +func TestNewClientValidConnection(t *testing.T) { + connection := Connection{"1", "2", "3", "4"} + + client, err := NewClient(WithConnection(connection)) + if err != nil { + t.Errorf("error creating new client: %v", err) + } + if client == nil { + t.Error("client is nil") + } +} + +func TestNewClientWithInvalidConnection(t *testing.T) { + connection := Connection{} + + if _, err := NewClient(WithConnection(connection)); err == nil { + t.Error("err is nil") + } +} + +var upgrader = websocket.Upgrader{} + +func newMockLiveShareServer() *httptest.Server { + endpoint := func(w http.ResponseWriter, req *http.Request) { + c, err := upgrader.Upgrade(w, req, nil) + if err != nil { + fmt.Println(err) + return + } + defer c.Close() + + for { + mt, message, err := c.ReadMessage() + if err != nil { + fmt.Println(err) + break + } + + err = c.WriteMessage(mt, message) + if err != nil { + fmt.Println(err) + break + } + + } + } + + return httptest.NewTLSServer(http.HandlerFunc(endpoint)) +} + func TestClientJoin(t *testing.T) { + // server := newMockLiveShareServer() + // defer server.Close() + // connection := Connection{ // SessionID: "session-id", // SessionToken: "session-token", - // RelayEndpoint: "relay-endpoint", // RelaySAS: "relay-sas", + // RelayEndpoint: "sb" + strings.TrimPrefix(server.URL, "https"), // } // client, err := NewClient(WithConnection(connection)) // if err != nil { - // t.Errorf("error creating client: %v", err) + // t.Errorf("error creating new client: %v", err) // } - // ctx := context.Background() // if err := client.Join(ctx); err != nil { // t.Errorf("error joining client: %v", err) diff --git a/connection.go b/connection.go index eda050f63..c1a4632c8 100644 --- a/connection.go +++ b/connection.go @@ -6,6 +6,7 @@ import ( "strings" ) +// A Connection represents a set of values necessary to join a liveshare connection type Connection struct { SessionID string SessionToken string diff --git a/port_forwarder.go b/port_forwarder.go index 8d42f3c05..6d459b4d6 100644 --- a/port_forwarder.go +++ b/port_forwarder.go @@ -8,6 +8,7 @@ import ( "strconv" ) +// A PortForwader can forward ports from a remote liveshare host to localhost type PortForwarder struct { client *Client server *Server @@ -15,6 +16,7 @@ type PortForwarder struct { errCh chan error } +// NewPortForwarder creates a new PortForwader with a given client, server and port func NewPortForwarder(client *Client, server *Server, port int) *PortForwarder { return &PortForwarder{ client: client, @@ -24,6 +26,7 @@ func NewPortForwarder(client *Client, server *Server, port int) *PortForwarder { } } +// Start is a method to start forwarding the server to a localhost port func (l *PortForwarder) Start(ctx context.Context) error { ln, err := net.Listen("tcp", ":"+strconv.Itoa(l.port)) if err != nil { From 9132a28e9cf2a09359a80e104a558c77a0c0abea Mon Sep 17 00:00:00 2001 From: Jose Garcia Date: Fri, 23 Jul 2021 19:15:54 +0000 Subject: [PATCH 05/10] Checking point after continuing to flesh out mock server --- client.go | 11 ++- client_test.go | 100 +++++++++++----------- socket.go | 17 +++- test/server.go | 226 +++++++++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 299 insertions(+), 55 deletions(-) create mode 100644 test/server.go diff --git a/client.go b/client.go index 435dd6775..a8a1e3864 100644 --- a/client.go +++ b/client.go @@ -2,6 +2,7 @@ package liveshare import ( "context" + "crypto/tls" "fmt" "golang.org/x/crypto/ssh" @@ -10,6 +11,7 @@ import ( // A Client capable of joining a liveshare connection type Client struct { connection Connection + tlsConfig *tls.Config ssh *sshSession rpc *rpcClient @@ -43,9 +45,16 @@ func WithConnection(connection Connection) ClientOption { } } +func WithTLSConfig(tlsConfig *tls.Config) ClientOption { + return func(c *Client) error { + c.tlsConfig = tlsConfig + return nil + } +} + // Join is a method that joins the client to the liveshare session func (c *Client) Join(ctx context.Context) (err error) { - clientSocket := newSocket(c.connection) + clientSocket := newSocket(c.connection, c.tlsConfig) if err := clientSocket.connect(ctx); err != nil { return fmt.Errorf("error connecting websocket: %v", err) } diff --git a/client_test.go b/client_test.go index 86f732637..bf77e3dce 100644 --- a/client_test.go +++ b/client_test.go @@ -1,12 +1,14 @@ package liveshare import ( + "context" + "crypto/tls" "fmt" - "net/http" - "net/http/httptest" + "strings" "testing" - "github.com/gorilla/websocket" + livesharetest "github.com/github/go-liveshare/test" + "github.com/sourcegraph/jsonrpc2" ) func TestNewClient(t *testing.T) { @@ -39,53 +41,51 @@ func TestNewClientWithInvalidConnection(t *testing.T) { } } -var upgrader = websocket.Upgrader{} - -func newMockLiveShareServer() *httptest.Server { - endpoint := func(w http.ResponseWriter, req *http.Request) { - c, err := upgrader.Upgrade(w, req, nil) - if err != nil { - fmt.Println(err) - return - } - defer c.Close() - - for { - mt, message, err := c.ReadMessage() - if err != nil { - fmt.Println(err) - break - } - - err = c.WriteMessage(mt, message) - if err != nil { - fmt.Println(err) - break - } - - } +func TestClientJoin(t *testing.T) { + sessionToken := "session-token" + joinWorkspace := func(req *jsonrpc2.Request) (interface{}, error) { + return 1, nil } - return httptest.NewTLSServer(http.HandlerFunc(endpoint)) -} - -func TestClientJoin(t *testing.T) { - // server := newMockLiveShareServer() - // defer server.Close() - - // connection := Connection{ - // SessionID: "session-id", - // SessionToken: "session-token", - // RelaySAS: "relay-sas", - // RelayEndpoint: "sb" + strings.TrimPrefix(server.URL, "https"), - // } - - // client, err := NewClient(WithConnection(connection)) - // if err != nil { - // t.Errorf("error creating new client: %v", err) - // } - // ctx := context.Background() - // if err := client.Join(ctx); err != nil { - // t.Errorf("error joining client: %v", err) - // } + server, err := livesharetest.NewServer( + livesharetest.WithPassword(sessionToken), + livesharetest.WithService("workspace.joinWorkspace", joinWorkspace), + ) + if err != nil { + t.Errorf("error creating liveshare server: %v", err) + } + defer server.Close() + + ctx := context.Background() + connection := Connection{ + SessionID: "session-id", + SessionToken: sessionToken, + RelaySAS: "relay-sas", + RelayEndpoint: "sb" + strings.TrimPrefix(server.URL(), "https"), + } + + tlsConfig := WithTLSConfig(&tls.Config{InsecureSkipVerify: true}) + client, err := NewClient(WithConnection(connection), tlsConfig) + if err != nil { + t.Errorf("error creating new client: %v", err) + } + + clientErr := make(chan error) + go func() { + if err := client.Join(ctx); err != nil { + clientErr <- fmt.Errorf("error joining client: %v", err) + return + } + + ctx.Done() + }() + + select { + case err := <-server.Err(): + t.Errorf("error from server: %v", err) + case err := <-clientErr: + t.Errorf("error from client: %v", err) + case <-ctx.Done(): + return + } } diff --git a/socket.go b/socket.go index c3f75b9db..e4f80a0cf 100644 --- a/socket.go +++ b/socket.go @@ -2,9 +2,11 @@ package liveshare import ( "context" + "crypto/tls" "errors" "io" "net" + "net/http" "sync" "time" @@ -12,19 +14,26 @@ import ( ) type socket struct { - addr string + addr string + tlsConfig *tls.Config + conn *websocket.Conn readMutex sync.Mutex writeMutex sync.Mutex reader io.Reader } -func newSocket(clientConn Connection) *socket { - return &socket{addr: clientConn.uri("connect")} +func newSocket(clientConn Connection, tlsConfig *tls.Config) *socket { + return &socket{addr: clientConn.uri("connect"), tlsConfig: tlsConfig} } func (s *socket) connect(ctx context.Context) error { - ws, _, err := websocket.DefaultDialer.Dial(s.addr, nil) + dialer := websocket.Dialer{ + Proxy: http.ProxyFromEnvironment, + HandshakeTimeout: 45 * time.Second, + TLSClientConfig: s.tlsConfig, + } + ws, _, err := dialer.Dial(s.addr, nil) if err != nil { return err } diff --git a/test/server.go b/test/server.go new file mode 100644 index 000000000..ed8666cce --- /dev/null +++ b/test/server.go @@ -0,0 +1,226 @@ +package livesharetest + +import ( + "context" + "errors" + "fmt" + "io" + "io/ioutil" + "net/http" + "net/http/httptest" + "path/filepath" + "sync" + "time" + + "github.com/gorilla/websocket" + "github.com/sourcegraph/jsonrpc2" + "golang.org/x/crypto/ssh" +) + +type Server struct { + password string + services map[string]RpcHandleFunc + + sshConfig *ssh.ServerConfig + httptestServer *httptest.Server + errCh chan error +} + +func NewServer(opts ...ServerOption) (*Server, error) { + server := new(Server) + + for _, o := range opts { + if err := o(server); err != nil { + return nil, err + } + } + + server.sshConfig = &ssh.ServerConfig{ + PasswordCallback: sshPasswordCallback(server.password), + } + b, err := ioutil.ReadFile(filepath.Join("test", "private.key")) + if err != nil { + return nil, fmt.Errorf("error reading private.key: %v", err) + } + privateKey, err := ssh.ParsePrivateKey(b) + if err != nil { + return nil, fmt.Errorf("error parsing key: %v", err) + } + server.sshConfig.AddHostKey(privateKey) + + server.errCh = make(chan error) + server.httptestServer = httptest.NewTLSServer(http.HandlerFunc(newConnection(server))) + return server, nil +} + +type ServerOption func(*Server) error + +func WithPassword(password string) ServerOption { + return func(s *Server) error { + s.password = password + return nil + } +} + +func WithService(serviceName string, handler RpcHandleFunc) ServerOption { + return func(s *Server) error { + if s.services == nil { + s.services = make(map[string]RpcHandleFunc) + } + + s.services[serviceName] = handler + return nil + } +} + +func sshPasswordCallback(serverPassword string) func(ssh.ConnMetadata, []byte) (*ssh.Permissions, error) { + return func(conn ssh.ConnMetadata, password []byte) (*ssh.Permissions, error) { + if string(password) == serverPassword { + return nil, nil + } + return nil, errors.New("password rejected") + } +} + +func (s *Server) Close() { + s.httptestServer.Close() +} + +func (s *Server) URL() string { + return s.httptestServer.URL +} + +func (s *Server) Err() <-chan error { + return s.errCh +} + +var upgrader = websocket.Upgrader{} + +func newConnection(server *Server) http.HandlerFunc { + return func(w http.ResponseWriter, req *http.Request) { + c, err := upgrader.Upgrade(w, req, nil) + if err != nil { + server.errCh <- fmt.Errorf("error upgrading connection: %v", err) + return + } + defer c.Close() + + socketConn := newSocketConn(c) + _, chans, reqs, err := ssh.NewServerConn(socketConn, server.sshConfig) + if err != nil { + server.errCh <- fmt.Errorf("error creating new ssh conn: %v", err) + return + } + go ssh.DiscardRequests(reqs) + + for newChannel := range chans { + ch, reqs, err := newChannel.Accept() + if err != nil { + server.errCh <- fmt.Errorf("error accepting new channel: %v", err) + return + } + go ssh.DiscardRequests(reqs) + go handleNewChannel(server, ch) + } + } +} + +func handleNewChannel(server *Server, channel ssh.Channel) { + stream := jsonrpc2.NewBufferedStream(channel, jsonrpc2.VSCodeObjectCodec{}) + jsonrpc2.NewConn(context.Background(), stream, newRpcHandler(server)) +} + +type RpcHandleFunc func(req *jsonrpc2.Request) (interface{}, error) + +type rpcHandler struct { + server *Server +} + +func newRpcHandler(server *Server) *rpcHandler { + return &rpcHandler{server} +} + +func (r *rpcHandler) Handle(ctx context.Context, conn *jsonrpc2.Conn, req *jsonrpc2.Request) { + handler, found := r.server.services[req.Method] + if !found { + r.server.errCh <- fmt.Errorf("RPC Method: '%v' not serviced", req.Method) + return + } + + result, err := handler(req) + if err != nil { + r.server.errCh <- fmt.Errorf("error handling: '%v': %v", req.Method, err) + return + } + + if err := conn.Reply(ctx, req.ID, result); err != nil { + r.server.errCh <- fmt.Errorf("error replying: %v", err) + } +} + +type socketConn struct { + *websocket.Conn + + reader io.Reader + writeMutex sync.Mutex + readMutex sync.Mutex +} + +func newSocketConn(conn *websocket.Conn) *socketConn { + return &socketConn{Conn: conn} +} + +func (s *socketConn) Read(b []byte) (int, error) { + s.readMutex.Lock() + defer s.readMutex.Unlock() + + if s.reader == nil { + msgType, r, err := s.Conn.NextReader() + if err != nil { + return 0, fmt.Errorf("error getting next reader: %v", err) + } + if msgType != websocket.BinaryMessage { + return 0, fmt.Errorf("invalid message type") + } + s.reader = r + } + + bytesRead, err := s.reader.Read(b) + if err != nil { + s.reader = nil + + if err == io.EOF { + err = nil + } + } + + return bytesRead, err +} + +func (s *socketConn) Write(b []byte) (int, error) { + s.writeMutex.Lock() + defer s.writeMutex.Unlock() + + w, err := s.Conn.NextWriter(websocket.BinaryMessage) + if err != nil { + return 0, fmt.Errorf("error getting next writer: %v", err) + } + + n, err := w.Write(b) + if err != nil { + return 0, fmt.Errorf("error writing: %v", err) + } + + if err := w.Close(); err != nil { + return 0, fmt.Errorf("error closing writer: %v", err) + } + + return n, nil +} + +func (s *socketConn) SetDeadline(deadline time.Time) error { + if err := s.Conn.SetReadDeadline(deadline); err != nil { + return err + } + return s.Conn.SetWriteDeadline(deadline) +} From fcfb10cb56e6aaaae9c745bc1ee00f48897220f9 Mon Sep 17 00:00:00 2001 From: Jose Garcia Date: Fri, 23 Jul 2021 20:24:50 +0000 Subject: [PATCH 06/10] Working test for Client.Join --- client_test.go | 31 +++++++++++++++---------------- 1 file changed, 15 insertions(+), 16 deletions(-) diff --git a/client_test.go b/client_test.go index bf77e3dce..fdf566fc0 100644 --- a/client_test.go +++ b/client_test.go @@ -42,27 +42,26 @@ func TestNewClientWithInvalidConnection(t *testing.T) { } func TestClientJoin(t *testing.T) { - sessionToken := "session-token" + connection := Connection{ + SessionID: "session-id", + SessionToken: "session-token", + RelaySAS: "relay-sas", + } joinWorkspace := func(req *jsonrpc2.Request) (interface{}, error) { - return 1, nil + return joinWorkspaceResult{1}, nil } server, err := livesharetest.NewServer( - livesharetest.WithPassword(sessionToken), + livesharetest.WithPassword(connection.SessionToken), livesharetest.WithService("workspace.joinWorkspace", joinWorkspace), ) if err != nil { t.Errorf("error creating liveshare server: %v", err) } defer server.Close() + connection.RelayEndpoint = "sb" + strings.TrimPrefix(server.URL(), "https") ctx := context.Background() - connection := Connection{ - SessionID: "session-id", - SessionToken: sessionToken, - RelaySAS: "relay-sas", - RelayEndpoint: "sb" + strings.TrimPrefix(server.URL(), "https"), - } tlsConfig := WithTLSConfig(&tls.Config{InsecureSkipVerify: true}) client, err := NewClient(WithConnection(connection), tlsConfig) @@ -70,22 +69,22 @@ func TestClientJoin(t *testing.T) { t.Errorf("error creating new client: %v", err) } - clientErr := make(chan error) + done := make(chan error) go func() { if err := client.Join(ctx); err != nil { - clientErr <- fmt.Errorf("error joining client: %v", err) + done <- fmt.Errorf("error joining client: %v", err) return } - ctx.Done() + done <- nil }() select { case err := <-server.Err(): t.Errorf("error from server: %v", err) - case err := <-clientErr: - t.Errorf("error from client: %v", err) - case <-ctx.Done(): - return + case err := <-done: + if err != nil { + t.Errorf("error from client: %v", err) + } } } From 91114d35c3d04245a58f78ebf2feb6bb5edde4e2 Mon Sep 17 00:00:00 2001 From: Jose Garcia Date: Sat, 24 Jul 2021 03:44:20 +0000 Subject: [PATCH 07/10] More tests --- client_test.go | 19 +++++ port_forwarder_test.go | 1 + server_test.go | 186 +++++++++++++++++++++++++++++++++++++++++ test/server.go | 16 ++++ 4 files changed, 222 insertions(+) create mode 100644 port_forwarder_test.go create mode 100644 server_test.go diff --git a/client_test.go b/client_test.go index fdf566fc0..110c7e3b9 100644 --- a/client_test.go +++ b/client_test.go @@ -3,6 +3,8 @@ package liveshare import ( "context" "crypto/tls" + "encoding/json" + "errors" "fmt" "strings" "testing" @@ -48,12 +50,29 @@ func TestClientJoin(t *testing.T) { RelaySAS: "relay-sas", } joinWorkspace := func(req *jsonrpc2.Request) (interface{}, error) { + var joinWorkspaceReq joinWorkspaceArgs + if err := json.Unmarshal(*req.Params, &joinWorkspaceReq); err != nil { + return nil, fmt.Errorf("error unmarshaling req: %v", err) + } + if joinWorkspaceReq.ID != connection.SessionID { + return nil, errors.New("connection session id does not match") + } + if joinWorkspaceReq.ConnectionMode != "local" { + return nil, errors.New("connection mode is not local") + } + if joinWorkspaceReq.JoiningUserSessionToken != connection.SessionToken { + return nil, errors.New("connection user token does not match") + } + if joinWorkspaceReq.ClientCapabilities.IsNonInteractive != false { + return nil, errors.New("non interactive is not false") + } return joinWorkspaceResult{1}, nil } server, err := livesharetest.NewServer( livesharetest.WithPassword(connection.SessionToken), livesharetest.WithService("workspace.joinWorkspace", joinWorkspace), + livesharetest.WithRelaySAS(connection.RelaySAS), ) if err != nil { t.Errorf("error creating liveshare server: %v", err) diff --git a/port_forwarder_test.go b/port_forwarder_test.go new file mode 100644 index 000000000..e3e219705 --- /dev/null +++ b/port_forwarder_test.go @@ -0,0 +1 @@ +package liveshare diff --git a/server_test.go b/server_test.go new file mode 100644 index 000000000..cc2b9adbd --- /dev/null +++ b/server_test.go @@ -0,0 +1,186 @@ +package liveshare + +import ( + "context" + "crypto/tls" + "encoding/json" + "errors" + "fmt" + "strings" + "testing" + + livesharetest "github.com/github/go-liveshare/test" + "github.com/sourcegraph/jsonrpc2" +) + +func TestNewServerWithNotJoinedClient(t *testing.T) { + client, err := NewClient() + if err != nil { + t.Errorf("error creating new client: %v", err) + } + if _, err := NewServer(client); err == nil { + t.Error("expected error") + } +} + +func newMockJoinedClient(opts ...livesharetest.ServerOption) (*livesharetest.Server, *Client, error) { + connection := Connection{ + SessionID: "session-id", + SessionToken: "session-token", + RelaySAS: "relay-sas", + } + joinWorkspace := func(req *jsonrpc2.Request) (interface{}, error) { + return joinWorkspaceResult{1}, nil + } + opts = append( + opts, + livesharetest.WithPassword(connection.SessionToken), + livesharetest.WithService("workspace.joinWorkspace", joinWorkspace), + ) + testServer, err := livesharetest.NewServer( + opts..., + ) + connection.RelayEndpoint = "sb" + strings.TrimPrefix(testServer.URL(), "https") + tlsConfig := WithTLSConfig(&tls.Config{InsecureSkipVerify: true}) + client, err := NewClient(WithConnection(connection), tlsConfig) + if err != nil { + return nil, nil, fmt.Errorf("error creating new client: %v", err) + } + ctx := context.Background() + if err := client.Join(ctx); err != nil { + return nil, nil, fmt.Errorf("error joining client: %v", err) + } + return testServer, client, nil +} + +func TestNewServer(t *testing.T) { + testServer, client, err := newMockJoinedClient() + defer testServer.Close() + if err != nil { + t.Errorf("error creating mock joined client: %v", err) + } + server, err := NewServer(client) + if err != nil { + t.Errorf("error creating new server: %v", err) + } + if server == nil { + t.Error("server is nil") + } +} + +func TestServerStartSharing(t *testing.T) { + serverPort, serverProtocol := 2222, "sshd" + startSharing := func(req *jsonrpc2.Request) (interface{}, error) { + var args []interface{} + if err := json.Unmarshal(*req.Params, &args); err != nil { + return nil, fmt.Errorf("error unmarshaling request: %v", err) + } + if len(args) < 3 { + return nil, errors.New("not enough arguments to start sharing") + } + if port, ok := args[0].(float64); !ok { + return nil, errors.New("port argument is not an int") + } else if port != float64(serverPort) { + return nil, errors.New("port does not match serverPort") + } + if protocol, ok := args[1].(string); !ok { + return nil, errors.New("protocol argument is not a string") + } else if protocol != serverProtocol { + return nil, errors.New("protocol does not match serverProtocol") + } + if browseURL, ok := args[2].(string); !ok { + return nil, errors.New("browse url is not a string") + } else if browseURL != fmt.Sprintf("http://localhost:%v", serverPort) { + return nil, errors.New("browseURL does not match expected") + } + return Port{StreamName: "stream-name", StreamCondition: "stream-condition"}, nil + } + testServer, client, err := newMockJoinedClient( + livesharetest.WithService("serverSharing.startSharing", startSharing), + ) + defer testServer.Close() + if err != nil { + t.Errorf("error creating mock joined client: %v", err) + } + server, err := NewServer(client) + if err != nil { + t.Errorf("error creating new server: %v", err) + } + ctx := context.Background() + + done := make(chan error) + go func() { + if err := server.StartSharing(ctx, serverProtocol, serverPort); err != nil { + done <- fmt.Errorf("error sharing server: %v", err) + } + if server.streamName == "" || server.streamCondition == "" { + done <- errors.New("stream name or condition is blank") + } + done <- nil + }() + + select { + case err := <-testServer.Err(): + t.Errorf("error from server: %v", err) + case err := <-done: + if err != nil { + t.Errorf("error from client: %v", err) + } + } +} + +func TestServerGetSharedServers(t *testing.T) { + sharedServer := Port{ + SourcePort: 2222, + StreamName: "stream-name", + StreamCondition: "stream-condition", + } + getSharedServers := func(req *jsonrpc2.Request) (interface{}, error) { + return Ports{&sharedServer}, nil + } + testServer, client, err := newMockJoinedClient( + livesharetest.WithService("serverSharing.getSharedServers", getSharedServers), + ) + if err != nil { + t.Errorf("error creating new mock client: %v", err) + } + defer testServer.Close() + server, err := NewServer(client) + if err != nil { + t.Errorf("error creating new server: %v", err) + } + ctx := context.Background() + done := make(chan error) + go func() { + ports, err := server.GetSharedServers(ctx) + if err != nil { + done <- fmt.Errorf("error getting shared servers: %v", err) + } + if len(ports) < 1 { + done <- errors.New("not enough ports returned") + } + if ports[0].SourcePort != sharedServer.SourcePort { + done <- errors.New("source port does not match") + } + if ports[0].StreamName != sharedServer.StreamName { + done <- errors.New("stream name does not match") + } + if ports[0].StreamCondition != sharedServer.StreamCondition { + done <- errors.New("stream condiion does not match") + } + done <- nil + }() + + select { + case err := <-testServer.Err(): + t.Errorf("error from server: %v", err) + case err := <-done: + if err != nil { + t.Errorf("error from client: %v", err) + } + } +} + +func TestServerUpdateSharedVisibility(t *testing.T) { + +} diff --git a/test/server.go b/test/server.go index ed8666cce..abb7ac96a 100644 --- a/test/server.go +++ b/test/server.go @@ -20,6 +20,7 @@ import ( type Server struct { password string services map[string]RpcHandleFunc + relaySAS string sshConfig *ssh.ServerConfig httptestServer *httptest.Server @@ -73,6 +74,13 @@ func WithService(serviceName string, handler RpcHandleFunc) ServerOption { } } +func WithRelaySAS(sas string) ServerOption { + return func(s *Server) error { + s.relaySAS = sas + return nil + } +} + func sshPasswordCallback(serverPassword string) func(ssh.ConnMetadata, []byte) (*ssh.Permissions, error) { return func(conn ssh.ConnMetadata, password []byte) (*ssh.Permissions, error) { if string(password) == serverPassword { @@ -98,6 +106,14 @@ var upgrader = websocket.Upgrader{} func newConnection(server *Server) http.HandlerFunc { return func(w http.ResponseWriter, req *http.Request) { + if server.relaySAS != "" { + // validate the sas key + sasParam := req.URL.Query().Get("sb-hc-token") + if sasParam != server.relaySAS { + server.errCh <- errors.New("error validating sas") + return + } + } c, err := upgrader.Upgrade(w, req, nil) if err != nil { server.errCh <- fmt.Errorf("error upgrading connection: %v", err) From 98282ba4b51085e03965672b1b124cc708bc6e82 Mon Sep 17 00:00:00 2001 From: Jose Garcia Date: Mon, 26 Jul 2021 14:31:00 +0000 Subject: [PATCH 08/10] Update shared visibility tests --- rpc.go | 5 ----- server_test.go | 32 +++++++++++++++++++++++++++++++- 2 files changed, 31 insertions(+), 6 deletions(-) diff --git a/rpc.go b/rpc.go index d624bbd74..8abd0e98f 100644 --- a/rpc.go +++ b/rpc.go @@ -30,11 +30,6 @@ func (r *rpcClient) do(ctx context.Context, method string, args interface{}, res return fmt.Errorf("error on dispatch call: %v", err) } - // caller doesn't care about result, so lets ignore it - if result == nil { - return nil - } - return waiter.Wait(ctx, result) } diff --git a/server_test.go b/server_test.go index cc2b9adbd..8a736b6f5 100644 --- a/server_test.go +++ b/server_test.go @@ -182,5 +182,35 @@ func TestServerGetSharedServers(t *testing.T) { } func TestServerUpdateSharedVisibility(t *testing.T) { - + updateSharedVisibility := func(req *jsonrpc2.Request) error { + return nil + } + testServer, client, err := newMockJoinedClient( + livesharetest.WithService("serverSharing.updateSharedServerVisibility", updateSharedVisibility), + ) + if err != nil { + t.Errorf("creating new mock client: %v", err) + } + defer testServer.Close() + server, err := NewServer(client) + if err != nil { + t.Errorf("creating server: %v", err) + } + ctx := context.Background() + done := make(chan error) + go func() { + if err := server.UpdateSharedVisibility(ctx, 80, true); err != nil { + done <- err + return + } + done <- nil + }() + select { + case err := <-testServer.Err(): + t.Errorf("error from server: %v", err) + case err := <-done: + if err != nil { + t.Errorf("error from client: %v", err) + } + } } From 892f73221c69d21f77d53e77eddd16f40c20c4ba Mon Sep 17 00:00:00 2001 From: Jose Garcia Date: Mon, 26 Jul 2021 14:39:52 +0000 Subject: [PATCH 09/10] Update shared visibility finalized tests --- server_test.go | 25 +++++++++++++++++++++++-- 1 file changed, 23 insertions(+), 2 deletions(-) diff --git a/server_test.go b/server_test.go index 8a736b6f5..7c9ee4288 100644 --- a/server_test.go +++ b/server_test.go @@ -182,8 +182,29 @@ func TestServerGetSharedServers(t *testing.T) { } func TestServerUpdateSharedVisibility(t *testing.T) { - updateSharedVisibility := func(req *jsonrpc2.Request) error { - return nil + updateSharedVisibility := func(rpcReq *jsonrpc2.Request) (interface{}, error) { + var req []interface{} + if err := json.Unmarshal(*rpcReq.Params, &req); err != nil { + return nil, fmt.Errorf("unmarshal req: %v", err) + } + if len(req) < 2 { + return nil, errors.New("request arguments is less than 2") + } + if port, ok := req[0].(float64); ok { + if port != 80.0 { + return nil, errors.New("port param is not expected value") + } + } else { + return nil, errors.New("port param is not a float64") + } + if public, ok := req[1].(bool); ok { + if public != true { + return nil, errors.New("pulic param is not expected value") + } + } else { + return nil, errors.New("public param is not a bool") + } + return nil, nil } testServer, client, err := newMockJoinedClient( livesharetest.WithService("serverSharing.updateSharedServerVisibility", updateSharedVisibility), From 0ab67badfad20a67a73bb170647fa115538b2995 Mon Sep 17 00:00:00 2001 From: Jose Garcia Date: Tue, 27 Jul 2021 23:19:55 +0000 Subject: [PATCH 10/10] Final changes to finish this refactor --- connection_test.go | 33 +++++++++++++ port_forwarder.go | 4 ++ port_forwarder_test.go | 102 +++++++++++++++++++++++++++++++++++++++++ server.go | 8 ++++ server_test.go | 10 ++-- socket.go | 20 ++------ test/server.go | 54 ++++++++++++++++++++-- 7 files changed, 206 insertions(+), 25 deletions(-) create mode 100644 connection_test.go diff --git a/connection_test.go b/connection_test.go new file mode 100644 index 000000000..e952290be --- /dev/null +++ b/connection_test.go @@ -0,0 +1,33 @@ +package liveshare + +import "testing" + +func TestConnectionValid(t *testing.T) { + conn := Connection{"sess-id", "sess-token", "sas", "endpoint"} + if err := conn.validate(); err != nil { + t.Error(err) + } +} + +func TestConnectionInvalid(t *testing.T) { + conn := Connection{"", "sess-token", "sas", "endpoint"} + if err := conn.validate(); err == nil { + t.Error(err) + } + conn = Connection{"sess-id", "", "sas", "endpoint"} + if err := conn.validate(); err == nil { + t.Error(err) + } + conn = Connection{"sess-id", "sess-token", "", "endpoint"} + if err := conn.validate(); err == nil { + t.Error(err) + } + conn = Connection{"sess-id", "sess-token", "sas", ""} + if err := conn.validate(); err == nil { + t.Error(err) + } + conn = Connection{"", "", "", ""} + if err := conn.validate(); err == nil { + t.Error(err) + } +} diff --git a/port_forwarder.go b/port_forwarder.go index 6d459b4d6..0a049d586 100644 --- a/port_forwarder.go +++ b/port_forwarder.go @@ -54,8 +54,12 @@ func (l *PortForwarder) handleConnection(ctx context.Context, conn net.Conn) { copyConn := func(writer io.Writer, reader io.Reader) { if _, err := io.Copy(writer, reader); err != nil { + fmt.Println(err) channel.Close() conn.Close() + if err != io.EOF { + l.errCh <- fmt.Errorf("tunnel connection: %v", err) + } } } diff --git a/port_forwarder_test.go b/port_forwarder_test.go index e3e219705..33a33b39b 100644 --- a/port_forwarder_test.go +++ b/port_forwarder_test.go @@ -1 +1,103 @@ package liveshare + +import ( + "bytes" + "context" + "errors" + "fmt" + "io" + "net" + "testing" + "time" + + livesharetest "github.com/github/go-liveshare/test" + "github.com/sourcegraph/jsonrpc2" +) + +func TestNewPortForwarder(t *testing.T) { + testServer, client, err := makeMockJoinedClient() + if err != nil { + t.Errorf("create mock client: %v", err) + } + defer testServer.Close() + server, err := NewServer(client) + if err != nil { + t.Errorf("create new server: %v", err) + } + pf := NewPortForwarder(client, server, 80) + if pf == nil { + t.Error("port forwarder is nil") + } +} + +func TestPortForwarderStart(t *testing.T) { + streamName, streamCondition := "stream-name", "stream-condition" + serverSharing := func(req *jsonrpc2.Request) (interface{}, error) { + return Port{StreamName: streamName, StreamCondition: streamCondition}, nil + } + getStream := func(req *jsonrpc2.Request) (interface{}, error) { + return "stream-id", nil + } + + stream := bytes.NewBufferString("stream-data") + testServer, client, err := makeMockJoinedClient( + livesharetest.WithService("serverSharing.startSharing", serverSharing), + livesharetest.WithService("streamManager.getStream", getStream), + livesharetest.WithStream("stream-id", stream), + ) + if err != nil { + t.Errorf("create mock client: %v", err) + } + defer testServer.Close() + + server, err := NewServer(client) + if err != nil { + t.Errorf("create new server: %v", err) + } + + ctx, _ := context.WithCancel(context.Background()) + pf := NewPortForwarder(client, server, 8000) + done := make(chan error) + + go func() { + if err := server.StartSharing(ctx, "http", 8000); err != nil { + done <- fmt.Errorf("start sharing: %v", err) + } + if err := pf.Start(ctx); err != nil { + done <- err + } + done <- nil + }() + + go func() { + var conn net.Conn + retries := 0 + for conn == nil && retries < 2 { + conn, err = net.DialTimeout("tcp", ":8000", 2*time.Second) + time.Sleep(1 * time.Second) + } + if conn == nil { + done <- errors.New("failed to connect to forwarded port") + } + b := make([]byte, len("stream-data")) + if _, err := conn.Read(b); err != nil && err != io.EOF { + done <- fmt.Errorf("reading stream: %v", err) + } + if string(b) != "stream-data" { + done <- fmt.Errorf("stream data is not expected value, got: %v", string(b)) + } + if _, err := conn.Write([]byte("new-data")); err != nil { + done <- fmt.Errorf("writing to stream: %v", err) + } + done <- nil + }() + + select { + case err := <-testServer.Err(): + t.Errorf("error from server: %v", err) + case err := <-done: + if err != nil { + t.Errorf("error from client: %v", err) + } + } +} diff --git a/server.go b/server.go index 6f17d5ac5..7e8c8b1cb 100644 --- a/server.go +++ b/server.go @@ -7,12 +7,14 @@ import ( "strconv" ) +// A Server represents the liveshare host and container server type Server struct { client *Client port int streamName, streamCondition string } +// NewServer creates a new Server with a given Client func NewServer(client *Client) (*Server, error) { if !client.hasJoined() { return nil, errors.New("client must join before creating server") @@ -21,6 +23,7 @@ func NewServer(client *Client) (*Server, error) { return &Server{client: client}, nil } +// Port represents an open port on the container type Port struct { SourcePort int `json:"sourcePort"` DestinationPort int `json:"destinationPort"` @@ -33,6 +36,7 @@ type Port struct { HasTSLHandshakePassed bool `json:"hasTSLHandshakePassed"` } +// StartSharing tells the liveshare host to start sharing the port from the container func (s *Server) StartSharing(ctx context.Context, protocol string, port int) error { s.port = port @@ -49,8 +53,10 @@ func (s *Server) StartSharing(ctx context.Context, protocol string, port int) er return nil } +// Ports is a slice of Port pointers type Ports []*Port +// GetSharedServers returns a list of available/open ports from the container func (s *Server) GetSharedServers(ctx context.Context) (Ports, error) { var response Ports if err := s.client.rpc.do(ctx, "serverSharing.getSharedServers", []string{}, &response); err != nil { @@ -60,6 +66,8 @@ func (s *Server) GetSharedServers(ctx context.Context) (Ports, error) { return response, nil } +// UpdateSharedVisibility controls port permissions and whether it can be accessed publicly +// via the Browse URL func (s *Server) UpdateSharedVisibility(ctx context.Context, port int, public bool) error { if err := s.client.rpc.do(ctx, "serverSharing.updateSharedServerVisibility", []interface{}{port, public}, nil); err != nil { return err diff --git a/server_test.go b/server_test.go index 7c9ee4288..b91fbfddc 100644 --- a/server_test.go +++ b/server_test.go @@ -23,7 +23,7 @@ func TestNewServerWithNotJoinedClient(t *testing.T) { } } -func newMockJoinedClient(opts ...livesharetest.ServerOption) (*livesharetest.Server, *Client, error) { +func makeMockJoinedClient(opts ...livesharetest.ServerOption) (*livesharetest.Server, *Client, error) { connection := Connection{ SessionID: "session-id", SessionToken: "session-token", @@ -54,7 +54,7 @@ func newMockJoinedClient(opts ...livesharetest.ServerOption) (*livesharetest.Ser } func TestNewServer(t *testing.T) { - testServer, client, err := newMockJoinedClient() + testServer, client, err := makeMockJoinedClient() defer testServer.Close() if err != nil { t.Errorf("error creating mock joined client: %v", err) @@ -95,7 +95,7 @@ func TestServerStartSharing(t *testing.T) { } return Port{StreamName: "stream-name", StreamCondition: "stream-condition"}, nil } - testServer, client, err := newMockJoinedClient( + testServer, client, err := makeMockJoinedClient( livesharetest.WithService("serverSharing.startSharing", startSharing), ) defer testServer.Close() @@ -138,7 +138,7 @@ func TestServerGetSharedServers(t *testing.T) { getSharedServers := func(req *jsonrpc2.Request) (interface{}, error) { return Ports{&sharedServer}, nil } - testServer, client, err := newMockJoinedClient( + testServer, client, err := makeMockJoinedClient( livesharetest.WithService("serverSharing.getSharedServers", getSharedServers), ) if err != nil { @@ -206,7 +206,7 @@ func TestServerUpdateSharedVisibility(t *testing.T) { } return nil, nil } - testServer, client, err := newMockJoinedClient( + testServer, client, err := makeMockJoinedClient( livesharetest.WithService("serverSharing.updateSharedServerVisibility", updateSharedVisibility), ) if err != nil { diff --git a/socket.go b/socket.go index e4f80a0cf..8744eeb96 100644 --- a/socket.go +++ b/socket.go @@ -3,11 +3,9 @@ package liveshare import ( "context" "crypto/tls" - "errors" "io" "net" "net/http" - "sync" "time" "github.com/gorilla/websocket" @@ -17,10 +15,8 @@ type socket struct { addr string tlsConfig *tls.Config - conn *websocket.Conn - readMutex sync.Mutex - writeMutex sync.Mutex - reader io.Reader + conn *websocket.Conn + reader io.Reader } func newSocket(clientConn Connection, tlsConfig *tls.Config) *socket { @@ -42,19 +38,12 @@ func (s *socket) connect(ctx context.Context) error { } func (s *socket) Read(b []byte) (int, error) { - s.readMutex.Lock() - defer s.readMutex.Unlock() - if s.reader == nil { - messageType, reader, err := s.conn.NextReader() + _, reader, err := s.conn.NextReader() if err != nil { return 0, err } - if messageType != websocket.BinaryMessage { - return 0, errors.New("unexpected websocket message type") - } - s.reader = reader } @@ -71,9 +60,6 @@ func (s *socket) Read(b []byte) (int, error) { } func (s *socket) Write(b []byte) (int, error) { - s.writeMutex.Lock() - defer s.writeMutex.Unlock() - nextWriter, err := s.conn.NextWriter(websocket.BinaryMessage) if err != nil { return 0, err diff --git a/test/server.go b/test/server.go index abb7ac96a..a52d31ab9 100644 --- a/test/server.go +++ b/test/server.go @@ -9,6 +9,7 @@ import ( "net/http" "net/http/httptest" "path/filepath" + "strings" "sync" "time" @@ -21,6 +22,7 @@ type Server struct { password string services map[string]RpcHandleFunc relaySAS string + streams map[string]io.ReadWriter sshConfig *ssh.ServerConfig httptestServer *httptest.Server @@ -50,7 +52,7 @@ func NewServer(opts ...ServerOption) (*Server, error) { server.sshConfig.AddHostKey(privateKey) server.errCh = make(chan error) - server.httptestServer = httptest.NewTLSServer(http.HandlerFunc(newConnection(server))) + server.httptestServer = httptest.NewTLSServer(http.HandlerFunc(makeConnection(server))) return server, nil } @@ -81,6 +83,16 @@ func WithRelaySAS(sas string) ServerOption { } } +func WithStream(name string, stream io.ReadWriter) ServerOption { + return func(s *Server) error { + if s.streams == nil { + s.streams = make(map[string]io.ReadWriter) + } + s.streams[name] = stream + return nil + } +} + func sshPasswordCallback(serverPassword string) func(ssh.ConnMetadata, []byte) (*ssh.Permissions, error) { return func(conn ssh.ConnMetadata, password []byte) (*ssh.Permissions, error) { if string(password) == serverPassword { @@ -104,7 +116,7 @@ func (s *Server) Err() <-chan error { var upgrader = websocket.Upgrader{} -func newConnection(server *Server) http.HandlerFunc { +func makeConnection(server *Server) http.HandlerFunc { return func(w http.ResponseWriter, req *http.Request) { if server.relaySAS != "" { // validate the sas key @@ -135,12 +147,48 @@ func newConnection(server *Server) http.HandlerFunc { server.errCh <- fmt.Errorf("error accepting new channel: %v", err) return } - go ssh.DiscardRequests(reqs) + go handleNewRequests(server, ch, reqs) go handleNewChannel(server, ch) } } } +func handleNewRequests(server *Server, channel ssh.Channel, reqs <-chan *ssh.Request) { + for req := range reqs { + if req.WantReply { + if err := req.Reply(true, nil); err != nil { + server.errCh <- fmt.Errorf("error replying to channel request: %v", err) + } + } + if strings.HasPrefix(req.Type, "stream-transport") { + forwardStream(server, req.Type, channel) + } + } +} + +func forwardStream(server *Server, streamName string, channel ssh.Channel) { + simpleStreamName := strings.TrimPrefix(streamName, "stream-transport-") + stream, found := server.streams[simpleStreamName] + if !found { + server.errCh <- fmt.Errorf("stream '%v' not found", simpleStreamName) + return + } + + copy := func(dst io.Writer, src io.Reader) { + if _, err := io.Copy(dst, src); err != nil { + fmt.Println(err) + server.errCh <- fmt.Errorf("io copy: %v", err) + return + } + } + + go copy(stream, channel) + go copy(channel, stream) + + for { + } +} + func handleNewChannel(server *Server, channel ssh.Channel) { stream := jsonrpc2.NewBufferedStream(channel, jsonrpc2.VSCodeObjectCodec{}) jsonrpc2.NewConn(context.Background(), stream, newRpcHandler(server))