From 7332aa428c4db7b87c4280063b89a4d10763cb3c Mon Sep 17 00:00:00 2001 From: Jose Garcia Date: Thu, 22 Jul 2021 00:45:45 +0000 Subject: [PATCH] Large refactor and solidifying of APIs before tests --- api.go | 130 ---------------------------------------------- client.go | 50 +++++++++++------- client_test.go | 24 +++++++++ connection.go | 43 +++++++++++++++ liveshare.go | 77 --------------------------- port_forwarder.go | 10 ++-- rpc.go | 2 - server.go | 8 +-- session.go | 60 --------------------- socket.go | 105 +++++++++++++++++++++++++++++++++++++ ssh.go | 16 +++--- terminal.go | 8 +-- websocket.go | 105 ------------------------------------- 13 files changed, 224 insertions(+), 414 deletions(-) delete mode 100644 api.go create mode 100644 client_test.go create mode 100644 connection.go delete mode 100644 liveshare.go delete mode 100644 session.go create mode 100644 socket.go delete mode 100644 websocket.go diff --git a/api.go b/api.go deleted file mode 100644 index 55b6e6e93..000000000 --- a/api.go +++ /dev/null @@ -1,130 +0,0 @@ -package liveshare - -import ( - "encoding/json" - "fmt" - "io/ioutil" - "net/http" - "strings" -) - -type api struct { - client *Client - httpClient *http.Client - serviceURI string - workspaceID string -} - -func newAPI(client *Client) *api { - serviceURI := client.liveShare.Configuration.LiveShareEndpoint - if !strings.HasSuffix(client.liveShare.Configuration.LiveShareEndpoint, "/") { - serviceURI = client.liveShare.Configuration.LiveShareEndpoint + "/" - } - - if !strings.Contains(serviceURI, "api/v1.2") { - serviceURI = serviceURI + "api/v1.2" - } - - serviceURI = strings.TrimSuffix(serviceURI, "/") - - return &api{client, &http.Client{}, serviceURI, strings.ToUpper(client.liveShare.Configuration.WorkspaceID)} -} - -type workspaceAccessResponse struct { - SessionToken string `json:"sessionToken"` - CreatedAt string `json:"createdAt"` - UpdatedAt string `json:"updatedAt"` - Name string `json:"name"` - OwnerID string `json:"ownerId"` - JoinLink string `json:"joinLink"` - ConnectLinks []string `json:"connectLinks"` - RelayLink string `json:"relayLink"` - RelaySas string `json:"relaySas"` - HostPublicKeys []string `json:"hostPublicKeys"` - ConversationID string `json:"conversationId"` - AssociatedUserIDs map[string]string `json:"associatedUserIds"` - AreAnonymousGuestsAllowed bool `json:"areAnonymousGuestsAllowed"` - IsHostConnected bool `json:"isHostConnected"` - ExpiresAt string `json:"expiresAt"` - InvitationLinks []string `json:"invitationLinks"` - ID string `json:"id"` -} - -func (a *api) workspaceAccess() (*workspaceAccessResponse, error) { - url := fmt.Sprintf("%s/workspace/%s/user", a.serviceURI, a.workspaceID) - - req, err := http.NewRequest(http.MethodPut, url, nil) - if err != nil { - return nil, fmt.Errorf("error creating request: %v", err) - } - - a.setDefaultHeaders(req) - resp, err := a.httpClient.Do(req) - if err != nil { - return nil, fmt.Errorf("error making request: %v", err) - } - - b, err := ioutil.ReadAll(resp.Body) - if err != nil { - return nil, fmt.Errorf("error reading response body: %v", err) - } - - var response workspaceAccessResponse - if err := json.Unmarshal(b, &response); err != nil { - return nil, fmt.Errorf("error unmarshaling response into json: %v", err) - } - - return &response, nil -} - -func (a *api) setDefaultHeaders(req *http.Request) { - req.Header.Set("Authorization", "Bearer "+a.client.liveShare.Configuration.Token) - req.Header.Set("Cache-Control", "no-cache") - req.Header.Set("Content-Type", "application/json") -} - -type workspaceInfoResponse struct { - CreatedAt string `json:"createdAt"` - UpdatedAt string `json:"updatedAt"` - Name string `json:"name"` - OwnerID string `json:"ownerId"` - JoinLink string `json:"joinLink"` - ConnectLinks []string `json:"connectLinks"` - RelayLink string `json:"relayLink"` - RelaySas string `json:"relaySas"` - HostPublicKeys []string `json:"hostPublicKeys"` - ConversationID string `json:"conversationId"` - AssociatedUserIDs map[string]string - AreAnonymousGuestsAllowed bool `json:"areAnonymousGuestsAllowed"` - IsHostConnected bool `json:"isHostConnected"` - ExpiresAt string `json:"expiresAt"` - InvitationLinks []string `json:"invitationLinks"` - ID string `json:"id"` -} - -func (a *api) workspaceInfo() (*workspaceInfoResponse, error) { - url := fmt.Sprintf("%s/workspace/%s", a.serviceURI, a.workspaceID) - - req, err := http.NewRequest(http.MethodGet, url, nil) - if err != nil { - return nil, fmt.Errorf("error creating request: %v", err) - } - - a.setDefaultHeaders(req) - resp, err := a.httpClient.Do(req) - if err != nil { - return nil, fmt.Errorf("error making request: %v", err) - } - - b, err := ioutil.ReadAll(resp.Body) - if err != nil { - return nil, fmt.Errorf("error reading response body: %v", err) - } - - var response workspaceInfoResponse - if err := json.Unmarshal(b, &response); err != nil { - return nil, fmt.Errorf("error unmarshaling response into json: %v", err) - } - - return &response, nil -} diff --git a/client.go b/client.go index 456a2c321..1ad90b336 100644 --- a/client.go +++ b/client.go @@ -8,31 +8,44 @@ import ( ) type Client struct { - liveShare *LiveShare - session *session + connection Connection + sshSession *sshSession rpc *rpc } -// NewClient is a function ... -func (l *LiveShare) NewClient() *Client { - return &Client{liveShare: l} +type ClientOption func(*Client) error + +func NewClient(opts ...ClientOption) (*Client, error) { + client := new(Client) + + for _, o := range opts { + if err := o(client); err != nil { + return nil, err + } + } + + return client, nil +} + +func WithConnection(connection Connection) ClientOption { + return func(c *Client) error { + if err := connection.validate(); err != nil { + return err + } + + c.connection = connection + return nil + } } func (c *Client) Join(ctx context.Context) (err error) { - api := newAPI(c) - - c.session = newSession(api) - if err := c.session.init(ctx); err != nil { - return fmt.Errorf("error creating session: %v", err) - } - - websocket := newWebsocket(c.session) - if err := websocket.connect(ctx); err != nil { + clientSocket := newSocket(c.connection) + if err := clientSocket.connect(ctx); err != nil { return fmt.Errorf("error connecting websocket: %v", err) } - c.sshSession = newSSH(c.session, websocket) + c.sshSession = newSSH(c.connection.SessionToken, clientSocket) if err := c.sshSession.connect(ctx); err != nil { return fmt.Errorf("error connecting to ssh session: %v", err) } @@ -69,9 +82,9 @@ type joinWorkspaceResult struct { func (c *Client) joinWorkspace(ctx context.Context) (*joinWorkspaceResult, error) { args := joinWorkspaceArgs{ - ID: c.session.workspaceInfo.ID, + ID: c.connection.SessionID, ConnectionMode: "local", - JoiningUserSessionToken: c.session.workspaceAccess.SessionToken, + JoiningUserSessionToken: c.connection.SessionToken, ClientCapabilities: clientCapabilities{ IsNonInteractive: false, }, @@ -99,8 +112,7 @@ func (c *Client) openStreamingChannel(ctx context.Context, streamName, condition go ssh.DiscardRequests(reqs) requestType := fmt.Sprintf("stream-transport-%s", streamID) - _, err = channel.SendRequest(requestType, true, nil) - if err != nil { + if _, err = channel.SendRequest(requestType, true, nil); err != nil { return nil, fmt.Errorf("error sending channel request: %v", err) } diff --git a/client_test.go b/client_test.go new file mode 100644 index 000000000..8d118974d --- /dev/null +++ b/client_test.go @@ -0,0 +1,24 @@ +package liveshare + +import ( + "testing" +) + +func TestClientJoin(t *testing.T) { + // connection := Connection{ + // SessionID: "session-id", + // SessionToken: "session-token", + // RelayEndpoint: "relay-endpoint", + // RelaySAS: "relay-sas", + // } + + // client, err := NewClient(WithConnection(connection)) + // if err != nil { + // t.Errorf("error creating client: %v", err) + // } + + // ctx := context.Background() + // if err := client.Join(ctx); err != nil { + // t.Errorf("error joining client: %v", err) + // } +} diff --git a/connection.go b/connection.go new file mode 100644 index 000000000..a97935d3b --- /dev/null +++ b/connection.go @@ -0,0 +1,43 @@ +package liveshare + +import ( + "errors" + "net/url" + "strings" +) + +type Connection struct { + SessionID string `json:"sessionId"` + SessionToken string `json:"sessionToken"` + RelaySAS string `json:"relaySas"` + RelayEndpoint string `json:"relayEndpoint"` +} + +func (r Connection) validate() error { + if r.SessionID == "" { + return errors.New("connection sessionID is required") + } + + if r.SessionToken == "" { + return errors.New("connection sessionToken is required") + } + + if r.RelaySAS == "" { + return errors.New("connection relaySas is required") + } + + if r.RelayEndpoint == "" { + return errors.New("connection relayEndpoint is required") + } + + return nil +} + +func (r Connection) uri(action string) string { + sas := url.QueryEscape(r.RelaySAS) + uri := r.RelayEndpoint + uri = strings.Replace(uri, "sb:", "wss:", -1) + uri = strings.Replace(uri, ".net/", ".net:443/$hc/", 1) + uri = uri + "?sb-hc-action=" + action + "&sb-hc-token=" + sas + return uri +} diff --git a/liveshare.go b/liveshare.go deleted file mode 100644 index 3c4be5c05..000000000 --- a/liveshare.go +++ /dev/null @@ -1,77 +0,0 @@ -package liveshare - -import ( - "errors" - "fmt" - "strings" -) - -type LiveShare struct { - Configuration *Configuration -} - -func New(opts ...Option) (*LiveShare, error) { - configuration := NewConfiguration() - - for _, o := range opts { - if err := o(configuration); err != nil { - return nil, fmt.Errorf("error configuring liveshare: %v", err) - } - } - - if err := configuration.Validate(); err != nil { - return nil, fmt.Errorf("error validating configuration: %v", err) - } - - return &LiveShare{Configuration: configuration}, nil -} - -type Option func(configuration *Configuration) error - -func WithWorkspaceID(id string) Option { - return func(configuration *Configuration) error { - configuration.WorkspaceID = id - return nil - } -} - -func WithLiveShareEndpoint(liveShareEndpoint string) Option { - return func(configuration *Configuration) error { - configuration.LiveShareEndpoint = liveShareEndpoint - return nil - } -} - -func WithToken(token string) Option { - return func(configuration *Configuration) error { - configuration.Token = token - return nil - } -} - -type Configuration struct { - WorkspaceID, LiveShareEndpoint, Token string -} - -func NewConfiguration() *Configuration { - return &Configuration{ - LiveShareEndpoint: "https://prod.liveshare.vsengsaas.visualstudio.com", - } -} - -func (c *Configuration) Validate() error { - errs := []string{} - if c.WorkspaceID == "" { - errs = append(errs, "WorkspaceID is required") - } - - if c.Token == "" { - errs = append(errs, "Token is required") - } - - if len(errs) > 0 { - return errors.New(strings.Join(errs, ", ")) - } - - return nil -} diff --git a/port_forwarder.go b/port_forwarder.go index 20382c208..1227493b2 100644 --- a/port_forwarder.go +++ b/port_forwarder.go @@ -11,18 +11,18 @@ import ( "golang.org/x/crypto/ssh" ) -type LocalPortForwarder struct { +type PortForwarder struct { client *Client server *Server port int channels []ssh.Channel } -func NewLocalPortForwarder(client *Client, server *Server, port int) *LocalPortForwarder { - return &LocalPortForwarder{client, server, port, []ssh.Channel{}} +func NewPortForwarder(client *Client, server *Server, port int) *PortForwarder { + return &PortForwarder{client, server, port, []ssh.Channel{}} } -func (l *LocalPortForwarder) Start(ctx context.Context) error { +func (l *PortForwarder) Start(ctx context.Context) error { ln, err := net.Listen("tcp", ":"+strconv.Itoa(l.port)) if err != nil { return fmt.Errorf("error listening on tcp port: %v", err) @@ -42,7 +42,7 @@ func (l *LocalPortForwarder) Start(ctx context.Context) error { return nil } -func (l *LocalPortForwarder) handleConnection(ctx context.Context, conn net.Conn) { +func (l *PortForwarder) handleConnection(ctx context.Context, conn net.Conn) { channel, err := l.client.openStreamingChannel(ctx, l.server.streamName, l.server.streamCondition) if err != nil { log.Println("errrr handle Connect") diff --git a/rpc.go b/rpc.go index de427cda9..d40046471 100644 --- a/rpc.go +++ b/rpc.go @@ -78,7 +78,5 @@ func (r *rpcHandler) Handle(ctx context.Context, conn *jsonrpc2.Conn, req *jsonr r.eventHandlers[req.Method] = []chan *jsonrpc2.Request{} }() - } else { - // TODO(josebalius): Handle } } diff --git a/server.go b/server.go index 71ec9d4dd..6f17d5ac5 100644 --- a/server.go +++ b/server.go @@ -13,12 +13,12 @@ type Server struct { streamName, streamCondition string } -func (c *Client) NewServer() (*Server, error) { - if !c.hasJoined() { - return nil, errors.New("LiveShareClient must join before creating server") +func NewServer(client *Client) (*Server, error) { + if !client.hasJoined() { + return nil, errors.New("client must join before creating server") } - return &Server{client: c}, nil + return &Server{client: client}, nil } type Port struct { diff --git a/session.go b/session.go deleted file mode 100644 index d0492a10e..000000000 --- a/session.go +++ /dev/null @@ -1,60 +0,0 @@ -package liveshare - -import ( - "context" - "fmt" - "net/url" - "strings" - - "golang.org/x/sync/errgroup" -) - -type session struct { - api *api - - workspaceAccess *workspaceAccessResponse - workspaceInfo *workspaceInfoResponse -} - -func newSession(api *api) *session { - return &session{api: api} -} - -func (s *session) init(ctx context.Context) error { - g, ctx := errgroup.WithContext(ctx) - - g.Go(func() error { - workspaceAccess, err := s.api.workspaceAccess() - if err != nil { - return fmt.Errorf("error getting workspace access: %v", err) - } - s.workspaceAccess = workspaceAccess - return nil - }) - - g.Go(func() error { - workspaceInfo, err := s.api.workspaceInfo() - if err != nil { - return fmt.Errorf("error getting workspace info: %v", err) - } - s.workspaceInfo = workspaceInfo - return nil - }) - - if err := g.Wait(); err != nil { - return err - } - - return nil -} - -// Reference: -// https://github.com/Azure/azure-relay-node/blob/7b57225365df3010163bf4b9e640868a02737eb6/hyco-ws/index.js#L107-L137 -func (s *session) relayURI(action string) string { - relaySas := url.QueryEscape(s.workspaceAccess.RelaySas) - relayURI := s.workspaceAccess.RelayLink - relayURI = strings.Replace(relayURI, "sb:", "wss:", -1) - relayURI = strings.Replace(relayURI, ".net/", ".net:443/$hc/", 1) - relayURI = relayURI + "?sb-hc-action=" + action + "&sb-hc-token=" + relaySas - return relayURI -} diff --git a/socket.go b/socket.go new file mode 100644 index 000000000..c3f75b9db --- /dev/null +++ b/socket.go @@ -0,0 +1,105 @@ +package liveshare + +import ( + "context" + "errors" + "io" + "net" + "sync" + "time" + + "github.com/gorilla/websocket" +) + +type socket struct { + addr string + conn *websocket.Conn + readMutex sync.Mutex + writeMutex sync.Mutex + reader io.Reader +} + +func newSocket(clientConn Connection) *socket { + return &socket{addr: clientConn.uri("connect")} +} + +func (s *socket) connect(ctx context.Context) error { + ws, _, err := websocket.DefaultDialer.Dial(s.addr, nil) + if err != nil { + return err + } + s.conn = ws + return nil +} + +func (s *socket) Read(b []byte) (int, error) { + s.readMutex.Lock() + defer s.readMutex.Unlock() + + if s.reader == nil { + messageType, reader, err := s.conn.NextReader() + if err != nil { + return 0, err + } + + if messageType != websocket.BinaryMessage { + return 0, errors.New("unexpected websocket message type") + } + + s.reader = reader + } + + bytesRead, err := s.reader.Read(b) + if err != nil { + s.reader = nil + + if err == io.EOF { + err = nil + } + } + + return bytesRead, err +} + +func (s *socket) Write(b []byte) (int, error) { + s.writeMutex.Lock() + defer s.writeMutex.Unlock() + + nextWriter, err := s.conn.NextWriter(websocket.BinaryMessage) + if err != nil { + return 0, err + } + + bytesWritten, err := nextWriter.Write(b) + nextWriter.Close() + + return bytesWritten, err +} + +func (s *socket) Close() error { + return s.conn.Close() +} + +func (s *socket) LocalAddr() net.Addr { + return s.conn.LocalAddr() +} + +func (s *socket) RemoteAddr() net.Addr { + return s.conn.RemoteAddr() +} + +func (s *socket) SetDeadline(t time.Time) error { + if err := s.SetReadDeadline(t); err != nil { + return err + } + + return s.SetWriteDeadline(t) +} + +func (s *socket) SetReadDeadline(t time.Time) error { + return s.conn.SetReadDeadline(t) +} + +func (s *socket) SetWriteDeadline(t time.Time) error { + return s.conn.SetWriteDeadline(t) +} diff --git a/ssh.go b/ssh.go index 9ae32ed7c..3ea2d2777 100644 --- a/ssh.go +++ b/ssh.go @@ -12,22 +12,22 @@ import ( type sshSession struct { *ssh.Session - session *session - socket net.Conn - conn ssh.Conn - reader io.Reader - writer io.Writer + token string + socket net.Conn + conn ssh.Conn + reader io.Reader + writer io.Writer } -func newSSH(session *session, socket net.Conn) *sshSession { - return &sshSession{session: session, socket: socket} +func newSSH(token string, socket net.Conn) *sshSession { + return &sshSession{token: token, socket: socket} } func (s *sshSession) connect(ctx context.Context) error { clientConfig := ssh.ClientConfig{ User: "", Auth: []ssh.AuthMethod{ - ssh.Password(s.session.workspaceAccess.SessionToken), + ssh.Password(s.token), }, HostKeyAlgorithms: []string{"rsa-sha2-512", "rsa-sha2-256"}, HostKeyCallback: ssh.InsecureIgnoreHostKey(), diff --git a/terminal.go b/terminal.go index 631e75912..1621559a1 100644 --- a/terminal.go +++ b/terminal.go @@ -13,13 +13,13 @@ type Terminal struct { client *Client } -func (c *Client) NewTerminal() (*Terminal, error) { - if !c.hasJoined() { - return nil, errors.New("LiveShareClient must join before creating terminal") +func NewTerminal(client *Client) (*Terminal, error) { + if !client.hasJoined() { + return nil, errors.New("client must join before creating terminal") } return &Terminal{ - client: c, + client: client, }, nil } diff --git a/websocket.go b/websocket.go deleted file mode 100644 index ae163e6e2..000000000 --- a/websocket.go +++ /dev/null @@ -1,105 +0,0 @@ -package liveshare - -import ( - "context" - "errors" - "io" - "net" - "sync" - "time" - - gorillawebsocket "github.com/gorilla/websocket" -) - -type websocket struct { - session *session - conn *gorillawebsocket.Conn - readMutex sync.Mutex - writeMutex sync.Mutex - reader io.Reader -} - -func newWebsocket(session *session) *websocket { - return &websocket{session: session} -} - -func (w *websocket) connect(ctx context.Context) error { - ws, _, err := gorillawebsocket.DefaultDialer.Dial(w.session.relayURI("connect"), nil) - if err != nil { - return err - } - w.conn = ws - return nil -} - -func (w *websocket) Read(b []byte) (int, error) { - w.readMutex.Lock() - defer w.readMutex.Unlock() - - if w.reader == nil { - messageType, reader, err := w.conn.NextReader() - if err != nil { - return 0, err - } - - if messageType != gorillawebsocket.BinaryMessage { - return 0, errors.New("unexpected websocket message type") - } - - w.reader = reader - } - - bytesRead, err := w.reader.Read(b) - if err != nil { - w.reader = nil - - if err == io.EOF { - err = nil - } - } - - return bytesRead, err -} - -func (w *websocket) Write(b []byte) (int, error) { - w.writeMutex.Lock() - defer w.writeMutex.Unlock() - - nextWriter, err := w.conn.NextWriter(gorillawebsocket.BinaryMessage) - if err != nil { - return 0, err - } - - bytesWritten, err := nextWriter.Write(b) - nextWriter.Close() - - return bytesWritten, err -} - -func (w *websocket) Close() error { - return w.conn.Close() -} - -func (w *websocket) LocalAddr() net.Addr { - return w.conn.LocalAddr() -} - -func (w *websocket) RemoteAddr() net.Addr { - return w.conn.RemoteAddr() -} - -func (w *websocket) SetDeadline(t time.Time) error { - if err := w.SetReadDeadline(t); err != nil { - return err - } - - return w.SetWriteDeadline(t) -} - -func (w *websocket) SetReadDeadline(t time.Time) error { - return w.conn.SetReadDeadline(t) -} - -func (w *websocket) SetWriteDeadline(t time.Time) error { - return w.conn.SetWriteDeadline(t) -}