diff --git a/api.go b/api.go deleted file mode 100644 index 55b6e6e93..000000000 --- a/api.go +++ /dev/null @@ -1,130 +0,0 @@ -package liveshare - -import ( - "encoding/json" - "fmt" - "io/ioutil" - "net/http" - "strings" -) - -type api struct { - client *Client - httpClient *http.Client - serviceURI string - workspaceID string -} - -func newAPI(client *Client) *api { - serviceURI := client.liveShare.Configuration.LiveShareEndpoint - if !strings.HasSuffix(client.liveShare.Configuration.LiveShareEndpoint, "/") { - serviceURI = client.liveShare.Configuration.LiveShareEndpoint + "/" - } - - if !strings.Contains(serviceURI, "api/v1.2") { - serviceURI = serviceURI + "api/v1.2" - } - - serviceURI = strings.TrimSuffix(serviceURI, "/") - - return &api{client, &http.Client{}, serviceURI, strings.ToUpper(client.liveShare.Configuration.WorkspaceID)} -} - -type workspaceAccessResponse struct { - SessionToken string `json:"sessionToken"` - CreatedAt string `json:"createdAt"` - UpdatedAt string `json:"updatedAt"` - Name string `json:"name"` - OwnerID string `json:"ownerId"` - JoinLink string `json:"joinLink"` - ConnectLinks []string `json:"connectLinks"` - RelayLink string `json:"relayLink"` - RelaySas string `json:"relaySas"` - HostPublicKeys []string `json:"hostPublicKeys"` - ConversationID string `json:"conversationId"` - AssociatedUserIDs map[string]string `json:"associatedUserIds"` - AreAnonymousGuestsAllowed bool `json:"areAnonymousGuestsAllowed"` - IsHostConnected bool `json:"isHostConnected"` - ExpiresAt string `json:"expiresAt"` - InvitationLinks []string `json:"invitationLinks"` - ID string `json:"id"` -} - -func (a *api) workspaceAccess() (*workspaceAccessResponse, error) { - url := fmt.Sprintf("%s/workspace/%s/user", a.serviceURI, a.workspaceID) - - req, err := http.NewRequest(http.MethodPut, url, nil) - if err != nil { - return nil, fmt.Errorf("error creating request: %v", err) - } - - a.setDefaultHeaders(req) - resp, err := a.httpClient.Do(req) - if err != nil { - return nil, fmt.Errorf("error making request: %v", err) - } - - b, err := ioutil.ReadAll(resp.Body) - if err != nil { - return nil, fmt.Errorf("error reading response body: %v", err) - } - - var response workspaceAccessResponse - if err := json.Unmarshal(b, &response); err != nil { - return nil, fmt.Errorf("error unmarshaling response into json: %v", err) - } - - return &response, nil -} - -func (a *api) setDefaultHeaders(req *http.Request) { - req.Header.Set("Authorization", "Bearer "+a.client.liveShare.Configuration.Token) - req.Header.Set("Cache-Control", "no-cache") - req.Header.Set("Content-Type", "application/json") -} - -type workspaceInfoResponse struct { - CreatedAt string `json:"createdAt"` - UpdatedAt string `json:"updatedAt"` - Name string `json:"name"` - OwnerID string `json:"ownerId"` - JoinLink string `json:"joinLink"` - ConnectLinks []string `json:"connectLinks"` - RelayLink string `json:"relayLink"` - RelaySas string `json:"relaySas"` - HostPublicKeys []string `json:"hostPublicKeys"` - ConversationID string `json:"conversationId"` - AssociatedUserIDs map[string]string - AreAnonymousGuestsAllowed bool `json:"areAnonymousGuestsAllowed"` - IsHostConnected bool `json:"isHostConnected"` - ExpiresAt string `json:"expiresAt"` - InvitationLinks []string `json:"invitationLinks"` - ID string `json:"id"` -} - -func (a *api) workspaceInfo() (*workspaceInfoResponse, error) { - url := fmt.Sprintf("%s/workspace/%s", a.serviceURI, a.workspaceID) - - req, err := http.NewRequest(http.MethodGet, url, nil) - if err != nil { - return nil, fmt.Errorf("error creating request: %v", err) - } - - a.setDefaultHeaders(req) - resp, err := a.httpClient.Do(req) - if err != nil { - return nil, fmt.Errorf("error making request: %v", err) - } - - b, err := ioutil.ReadAll(resp.Body) - if err != nil { - return nil, fmt.Errorf("error reading response body: %v", err) - } - - var response workspaceInfoResponse - if err := json.Unmarshal(b, &response); err != nil { - return nil, fmt.Errorf("error unmarshaling response into json: %v", err) - } - - return &response, nil -} diff --git a/client.go b/client.go index 456a2c321..a8a1e3864 100644 --- a/client.go +++ b/client.go @@ -2,42 +2,69 @@ package liveshare import ( "context" + "crypto/tls" "fmt" "golang.org/x/crypto/ssh" ) +// A Client capable of joining a liveshare connection type Client struct { - liveShare *LiveShare - session *session - sshSession *sshSession - rpc *rpc + connection Connection + tlsConfig *tls.Config + + ssh *sshSession + rpc *rpcClient } -// NewClient is a function ... -func (l *LiveShare) NewClient() *Client { - return &Client{liveShare: l} -} +// A ClientOption is a function that modifies a client +type ClientOption func(*Client) error -func (c *Client) Join(ctx context.Context) (err error) { - api := newAPI(c) +// NewClient accepts a range of options, applies them and returns a client +func NewClient(opts ...ClientOption) (*Client, error) { + client := new(Client) - c.session = newSession(api) - if err := c.session.init(ctx); err != nil { - return fmt.Errorf("error creating session: %v", err) + for _, o := range opts { + if err := o(client); err != nil { + return nil, err + } } - websocket := newWebsocket(c.session) - if err := websocket.connect(ctx); err != nil { + return client, nil +} + +// WithConnection is a ClientOption that accepts a Connection +func WithConnection(connection Connection) ClientOption { + return func(c *Client) error { + if err := connection.validate(); err != nil { + return err + } + + c.connection = connection + return nil + } +} + +func WithTLSConfig(tlsConfig *tls.Config) ClientOption { + return func(c *Client) error { + c.tlsConfig = tlsConfig + return nil + } +} + +// Join is a method that joins the client to the liveshare session +func (c *Client) Join(ctx context.Context) (err error) { + clientSocket := newSocket(c.connection, c.tlsConfig) + if err := clientSocket.connect(ctx); err != nil { return fmt.Errorf("error connecting websocket: %v", err) } - c.sshSession = newSSH(c.session, websocket) - if err := c.sshSession.connect(ctx); err != nil { + c.ssh = newSshSession(c.connection.SessionToken, clientSocket) + if err := c.ssh.connect(ctx); err != nil { return fmt.Errorf("error connecting to ssh session: %v", err) } - c.rpc = newRPC(c.sshSession) + c.rpc = newRpcClient(c.ssh) c.rpc.connect(ctx) _, err = c.joinWorkspace(ctx) @@ -49,7 +76,7 @@ func (c *Client) Join(ctx context.Context) (err error) { } func (c *Client) hasJoined() bool { - return c.sshSession != nil && c.rpc != nil + return c.ssh != nil && c.rpc != nil } type clientCapabilities struct { @@ -69,9 +96,9 @@ type joinWorkspaceResult struct { func (c *Client) joinWorkspace(ctx context.Context) (*joinWorkspaceResult, error) { args := joinWorkspaceArgs{ - ID: c.session.workspaceInfo.ID, + ID: c.connection.SessionID, ConnectionMode: "local", - JoiningUserSessionToken: c.session.workspaceAccess.SessionToken, + JoiningUserSessionToken: c.connection.SessionToken, ClientCapabilities: clientCapabilities{ IsNonInteractive: false, }, @@ -92,15 +119,14 @@ func (c *Client) openStreamingChannel(ctx context.Context, streamName, condition return nil, fmt.Errorf("error getting stream id: %v", err) } - channel, reqs, err := c.sshSession.conn.OpenChannel("session", nil) + channel, reqs, err := c.ssh.conn.OpenChannel("session", nil) if err != nil { return nil, fmt.Errorf("error opening ssh channel for transport: %v", err) } go ssh.DiscardRequests(reqs) requestType := fmt.Sprintf("stream-transport-%s", streamID) - _, err = channel.SendRequest(requestType, true, nil) - if err != nil { + if _, err = channel.SendRequest(requestType, true, nil); err != nil { return nil, fmt.Errorf("error sending channel request: %v", err) } diff --git a/client_test.go b/client_test.go new file mode 100644 index 000000000..110c7e3b9 --- /dev/null +++ b/client_test.go @@ -0,0 +1,109 @@ +package liveshare + +import ( + "context" + "crypto/tls" + "encoding/json" + "errors" + "fmt" + "strings" + "testing" + + livesharetest "github.com/github/go-liveshare/test" + "github.com/sourcegraph/jsonrpc2" +) + +func TestNewClient(t *testing.T) { + client, err := NewClient() + if err != nil { + t.Errorf("error creating new client: %v", err) + } + if client == nil { + t.Error("client is nil") + } +} + +func TestNewClientValidConnection(t *testing.T) { + connection := Connection{"1", "2", "3", "4"} + + client, err := NewClient(WithConnection(connection)) + if err != nil { + t.Errorf("error creating new client: %v", err) + } + if client == nil { + t.Error("client is nil") + } +} + +func TestNewClientWithInvalidConnection(t *testing.T) { + connection := Connection{} + + if _, err := NewClient(WithConnection(connection)); err == nil { + t.Error("err is nil") + } +} + +func TestClientJoin(t *testing.T) { + connection := Connection{ + SessionID: "session-id", + SessionToken: "session-token", + RelaySAS: "relay-sas", + } + joinWorkspace := func(req *jsonrpc2.Request) (interface{}, error) { + var joinWorkspaceReq joinWorkspaceArgs + if err := json.Unmarshal(*req.Params, &joinWorkspaceReq); err != nil { + return nil, fmt.Errorf("error unmarshaling req: %v", err) + } + if joinWorkspaceReq.ID != connection.SessionID { + return nil, errors.New("connection session id does not match") + } + if joinWorkspaceReq.ConnectionMode != "local" { + return nil, errors.New("connection mode is not local") + } + if joinWorkspaceReq.JoiningUserSessionToken != connection.SessionToken { + return nil, errors.New("connection user token does not match") + } + if joinWorkspaceReq.ClientCapabilities.IsNonInteractive != false { + return nil, errors.New("non interactive is not false") + } + return joinWorkspaceResult{1}, nil + } + + server, err := livesharetest.NewServer( + livesharetest.WithPassword(connection.SessionToken), + livesharetest.WithService("workspace.joinWorkspace", joinWorkspace), + livesharetest.WithRelaySAS(connection.RelaySAS), + ) + if err != nil { + t.Errorf("error creating liveshare server: %v", err) + } + defer server.Close() + connection.RelayEndpoint = "sb" + strings.TrimPrefix(server.URL(), "https") + + ctx := context.Background() + + tlsConfig := WithTLSConfig(&tls.Config{InsecureSkipVerify: true}) + client, err := NewClient(WithConnection(connection), tlsConfig) + if err != nil { + t.Errorf("error creating new client: %v", err) + } + + done := make(chan error) + go func() { + if err := client.Join(ctx); err != nil { + done <- fmt.Errorf("error joining client: %v", err) + return + } + + done <- nil + }() + + select { + case err := <-server.Err(): + t.Errorf("error from server: %v", err) + case err := <-done: + if err != nil { + t.Errorf("error from client: %v", err) + } + } +} diff --git a/connection.go b/connection.go new file mode 100644 index 000000000..c1a4632c8 --- /dev/null +++ b/connection.go @@ -0,0 +1,44 @@ +package liveshare + +import ( + "errors" + "net/url" + "strings" +) + +// A Connection represents a set of values necessary to join a liveshare connection +type Connection struct { + SessionID string + SessionToken string + RelaySAS string + RelayEndpoint string +} + +func (r Connection) validate() error { + if r.SessionID == "" { + return errors.New("connection SessionID is required") + } + + if r.SessionToken == "" { + return errors.New("connection SessionToken is required") + } + + if r.RelaySAS == "" { + return errors.New("connection RelaySAS is required") + } + + if r.RelayEndpoint == "" { + return errors.New("connection RelayEndpoint is required") + } + + return nil +} + +func (r Connection) uri(action string) string { + sas := url.QueryEscape(r.RelaySAS) + uri := r.RelayEndpoint + uri = strings.Replace(uri, "sb:", "wss:", -1) + uri = strings.Replace(uri, ".net/", ".net:443/$hc/", 1) + uri = uri + "?sb-hc-action=" + action + "&sb-hc-token=" + sas + return uri +} diff --git a/connection_test.go b/connection_test.go new file mode 100644 index 000000000..e952290be --- /dev/null +++ b/connection_test.go @@ -0,0 +1,33 @@ +package liveshare + +import "testing" + +func TestConnectionValid(t *testing.T) { + conn := Connection{"sess-id", "sess-token", "sas", "endpoint"} + if err := conn.validate(); err != nil { + t.Error(err) + } +} + +func TestConnectionInvalid(t *testing.T) { + conn := Connection{"", "sess-token", "sas", "endpoint"} + if err := conn.validate(); err == nil { + t.Error(err) + } + conn = Connection{"sess-id", "", "sas", "endpoint"} + if err := conn.validate(); err == nil { + t.Error(err) + } + conn = Connection{"sess-id", "sess-token", "", "endpoint"} + if err := conn.validate(); err == nil { + t.Error(err) + } + conn = Connection{"sess-id", "sess-token", "sas", ""} + if err := conn.validate(); err == nil { + t.Error(err) + } + conn = Connection{"", "", "", ""} + if err := conn.validate(); err == nil { + t.Error(err) + } +} diff --git a/liveshare.go b/liveshare.go deleted file mode 100644 index 3c4be5c05..000000000 --- a/liveshare.go +++ /dev/null @@ -1,77 +0,0 @@ -package liveshare - -import ( - "errors" - "fmt" - "strings" -) - -type LiveShare struct { - Configuration *Configuration -} - -func New(opts ...Option) (*LiveShare, error) { - configuration := NewConfiguration() - - for _, o := range opts { - if err := o(configuration); err != nil { - return nil, fmt.Errorf("error configuring liveshare: %v", err) - } - } - - if err := configuration.Validate(); err != nil { - return nil, fmt.Errorf("error validating configuration: %v", err) - } - - return &LiveShare{Configuration: configuration}, nil -} - -type Option func(configuration *Configuration) error - -func WithWorkspaceID(id string) Option { - return func(configuration *Configuration) error { - configuration.WorkspaceID = id - return nil - } -} - -func WithLiveShareEndpoint(liveShareEndpoint string) Option { - return func(configuration *Configuration) error { - configuration.LiveShareEndpoint = liveShareEndpoint - return nil - } -} - -func WithToken(token string) Option { - return func(configuration *Configuration) error { - configuration.Token = token - return nil - } -} - -type Configuration struct { - WorkspaceID, LiveShareEndpoint, Token string -} - -func NewConfiguration() *Configuration { - return &Configuration{ - LiveShareEndpoint: "https://prod.liveshare.vsengsaas.visualstudio.com", - } -} - -func (c *Configuration) Validate() error { - errs := []string{} - if c.WorkspaceID == "" { - errs = append(errs, "WorkspaceID is required") - } - - if c.Token == "" { - errs = append(errs, "Token is required") - } - - if len(errs) > 0 { - return errors.New(strings.Join(errs, ", ")) - } - - return nil -} diff --git a/port_forwarder.go b/port_forwarder.go index 20382c208..0a049d586 100644 --- a/port_forwarder.go +++ b/port_forwarder.go @@ -4,25 +4,30 @@ import ( "context" "fmt" "io" - "log" "net" "strconv" - - "golang.org/x/crypto/ssh" ) -type LocalPortForwarder struct { - client *Client - server *Server - port int - channels []ssh.Channel +// A PortForwader can forward ports from a remote liveshare host to localhost +type PortForwarder struct { + client *Client + server *Server + port int + errCh chan error } -func NewLocalPortForwarder(client *Client, server *Server, port int) *LocalPortForwarder { - return &LocalPortForwarder{client, server, port, []ssh.Channel{}} +// NewPortForwarder creates a new PortForwader with a given client, server and port +func NewPortForwarder(client *Client, server *Server, port int) *PortForwarder { + return &PortForwarder{ + client: client, + server: server, + port: port, + errCh: make(chan error), + } } -func (l *LocalPortForwarder) Start(ctx context.Context) error { +// Start is a method to start forwarding the server to a localhost port +func (l *PortForwarder) Start(ctx context.Context) error { ln, err := net.Listen("tcp", ":"+strconv.Itoa(l.port)) if err != nil { return fmt.Errorf("error listening on tcp port: %v", err) @@ -37,24 +42,24 @@ func (l *LocalPortForwarder) Start(ctx context.Context) error { go l.handleConnection(ctx, conn) } - // clean up after ourselves - return nil } -func (l *LocalPortForwarder) handleConnection(ctx context.Context, conn net.Conn) { +func (l *PortForwarder) handleConnection(ctx context.Context, conn net.Conn) { channel, err := l.client.openStreamingChannel(ctx, l.server.streamName, l.server.streamCondition) if err != nil { - log.Println("errrr handle Connect") - log.Println(err) // TODO(josebalius) handle this somehow + l.errCh <- fmt.Errorf("error opening streaming channel for new connection: %v", err) + return } - l.channels = append(l.channels, channel) copyConn := func(writer io.Writer, reader io.Reader) { - _, err := io.Copy(writer, reader) - if err != nil { + if _, err := io.Copy(writer, reader); err != nil { + fmt.Println(err) channel.Close() conn.Close() + if err != io.EOF { + l.errCh <- fmt.Errorf("tunnel connection: %v", err) + } } } diff --git a/port_forwarder_test.go b/port_forwarder_test.go new file mode 100644 index 000000000..33a33b39b --- /dev/null +++ b/port_forwarder_test.go @@ -0,0 +1,103 @@ +package liveshare + +import ( + "bytes" + "context" + "errors" + "fmt" + "io" + "net" + "testing" + "time" + + livesharetest "github.com/github/go-liveshare/test" + "github.com/sourcegraph/jsonrpc2" +) + +func TestNewPortForwarder(t *testing.T) { + testServer, client, err := makeMockJoinedClient() + if err != nil { + t.Errorf("create mock client: %v", err) + } + defer testServer.Close() + server, err := NewServer(client) + if err != nil { + t.Errorf("create new server: %v", err) + } + pf := NewPortForwarder(client, server, 80) + if pf == nil { + t.Error("port forwarder is nil") + } +} + +func TestPortForwarderStart(t *testing.T) { + streamName, streamCondition := "stream-name", "stream-condition" + serverSharing := func(req *jsonrpc2.Request) (interface{}, error) { + return Port{StreamName: streamName, StreamCondition: streamCondition}, nil + } + getStream := func(req *jsonrpc2.Request) (interface{}, error) { + return "stream-id", nil + } + + stream := bytes.NewBufferString("stream-data") + testServer, client, err := makeMockJoinedClient( + livesharetest.WithService("serverSharing.startSharing", serverSharing), + livesharetest.WithService("streamManager.getStream", getStream), + livesharetest.WithStream("stream-id", stream), + ) + if err != nil { + t.Errorf("create mock client: %v", err) + } + defer testServer.Close() + + server, err := NewServer(client) + if err != nil { + t.Errorf("create new server: %v", err) + } + + ctx, _ := context.WithCancel(context.Background()) + pf := NewPortForwarder(client, server, 8000) + done := make(chan error) + + go func() { + if err := server.StartSharing(ctx, "http", 8000); err != nil { + done <- fmt.Errorf("start sharing: %v", err) + } + if err := pf.Start(ctx); err != nil { + done <- err + } + done <- nil + }() + + go func() { + var conn net.Conn + retries := 0 + for conn == nil && retries < 2 { + conn, err = net.DialTimeout("tcp", ":8000", 2*time.Second) + time.Sleep(1 * time.Second) + } + if conn == nil { + done <- errors.New("failed to connect to forwarded port") + } + b := make([]byte, len("stream-data")) + if _, err := conn.Read(b); err != nil && err != io.EOF { + done <- fmt.Errorf("reading stream: %v", err) + } + if string(b) != "stream-data" { + done <- fmt.Errorf("stream data is not expected value, got: %v", string(b)) + } + if _, err := conn.Write([]byte("new-data")); err != nil { + done <- fmt.Errorf("writing to stream: %v", err) + } + done <- nil + }() + + select { + case err := <-testServer.Err(): + t.Errorf("error from server: %v", err) + case err := <-done: + if err != nil { + t.Errorf("error from client: %v", err) + } + } +} diff --git a/rpc.go b/rpc.go index de427cda9..8abd0e98f 100644 --- a/rpc.go +++ b/rpc.go @@ -9,32 +9,27 @@ import ( "github.com/sourcegraph/jsonrpc2" ) -type rpc struct { +type rpcClient struct { *jsonrpc2.Conn conn io.ReadWriteCloser handler *rpcHandler } -func newRPC(conn io.ReadWriteCloser) *rpc { - return &rpc{conn: conn, handler: newRPCHandler()} +func newRpcClient(conn io.ReadWriteCloser) *rpcClient { + return &rpcClient{conn: conn, handler: newRPCHandler()} } -func (r *rpc) connect(ctx context.Context) { +func (r *rpcClient) connect(ctx context.Context) { stream := jsonrpc2.NewBufferedStream(r.conn, jsonrpc2.VSCodeObjectCodec{}) r.Conn = jsonrpc2.NewConn(ctx, stream, r.handler) } -func (r *rpc) do(ctx context.Context, method string, args interface{}, result interface{}) error { +func (r *rpcClient) do(ctx context.Context, method string, args interface{}, result interface{}) error { waiter, err := r.Conn.DispatchCall(ctx, method, args) if err != nil { return fmt.Errorf("error on dispatch call: %v", err) } - // caller doesn't care about result, so lets ignore it - if result == nil { - return nil - } - return waiter.Wait(ctx, result) } @@ -78,7 +73,5 @@ func (r *rpcHandler) Handle(ctx context.Context, conn *jsonrpc2.Conn, req *jsonr r.eventHandlers[req.Method] = []chan *jsonrpc2.Request{} }() - } else { - // TODO(josebalius): Handle } } diff --git a/server.go b/server.go index 71ec9d4dd..7e8c8b1cb 100644 --- a/server.go +++ b/server.go @@ -7,20 +7,23 @@ import ( "strconv" ) +// A Server represents the liveshare host and container server type Server struct { client *Client port int streamName, streamCondition string } -func (c *Client) NewServer() (*Server, error) { - if !c.hasJoined() { - return nil, errors.New("LiveShareClient must join before creating server") +// NewServer creates a new Server with a given Client +func NewServer(client *Client) (*Server, error) { + if !client.hasJoined() { + return nil, errors.New("client must join before creating server") } - return &Server{client: c}, nil + return &Server{client: client}, nil } +// Port represents an open port on the container type Port struct { SourcePort int `json:"sourcePort"` DestinationPort int `json:"destinationPort"` @@ -33,6 +36,7 @@ type Port struct { HasTSLHandshakePassed bool `json:"hasTSLHandshakePassed"` } +// StartSharing tells the liveshare host to start sharing the port from the container func (s *Server) StartSharing(ctx context.Context, protocol string, port int) error { s.port = port @@ -49,8 +53,10 @@ func (s *Server) StartSharing(ctx context.Context, protocol string, port int) er return nil } +// Ports is a slice of Port pointers type Ports []*Port +// GetSharedServers returns a list of available/open ports from the container func (s *Server) GetSharedServers(ctx context.Context) (Ports, error) { var response Ports if err := s.client.rpc.do(ctx, "serverSharing.getSharedServers", []string{}, &response); err != nil { @@ -60,6 +66,8 @@ func (s *Server) GetSharedServers(ctx context.Context) (Ports, error) { return response, nil } +// UpdateSharedVisibility controls port permissions and whether it can be accessed publicly +// via the Browse URL func (s *Server) UpdateSharedVisibility(ctx context.Context, port int, public bool) error { if err := s.client.rpc.do(ctx, "serverSharing.updateSharedServerVisibility", []interface{}{port, public}, nil); err != nil { return err diff --git a/server_test.go b/server_test.go new file mode 100644 index 000000000..b91fbfddc --- /dev/null +++ b/server_test.go @@ -0,0 +1,237 @@ +package liveshare + +import ( + "context" + "crypto/tls" + "encoding/json" + "errors" + "fmt" + "strings" + "testing" + + livesharetest "github.com/github/go-liveshare/test" + "github.com/sourcegraph/jsonrpc2" +) + +func TestNewServerWithNotJoinedClient(t *testing.T) { + client, err := NewClient() + if err != nil { + t.Errorf("error creating new client: %v", err) + } + if _, err := NewServer(client); err == nil { + t.Error("expected error") + } +} + +func makeMockJoinedClient(opts ...livesharetest.ServerOption) (*livesharetest.Server, *Client, error) { + connection := Connection{ + SessionID: "session-id", + SessionToken: "session-token", + RelaySAS: "relay-sas", + } + joinWorkspace := func(req *jsonrpc2.Request) (interface{}, error) { + return joinWorkspaceResult{1}, nil + } + opts = append( + opts, + livesharetest.WithPassword(connection.SessionToken), + livesharetest.WithService("workspace.joinWorkspace", joinWorkspace), + ) + testServer, err := livesharetest.NewServer( + opts..., + ) + connection.RelayEndpoint = "sb" + strings.TrimPrefix(testServer.URL(), "https") + tlsConfig := WithTLSConfig(&tls.Config{InsecureSkipVerify: true}) + client, err := NewClient(WithConnection(connection), tlsConfig) + if err != nil { + return nil, nil, fmt.Errorf("error creating new client: %v", err) + } + ctx := context.Background() + if err := client.Join(ctx); err != nil { + return nil, nil, fmt.Errorf("error joining client: %v", err) + } + return testServer, client, nil +} + +func TestNewServer(t *testing.T) { + testServer, client, err := makeMockJoinedClient() + defer testServer.Close() + if err != nil { + t.Errorf("error creating mock joined client: %v", err) + } + server, err := NewServer(client) + if err != nil { + t.Errorf("error creating new server: %v", err) + } + if server == nil { + t.Error("server is nil") + } +} + +func TestServerStartSharing(t *testing.T) { + serverPort, serverProtocol := 2222, "sshd" + startSharing := func(req *jsonrpc2.Request) (interface{}, error) { + var args []interface{} + if err := json.Unmarshal(*req.Params, &args); err != nil { + return nil, fmt.Errorf("error unmarshaling request: %v", err) + } + if len(args) < 3 { + return nil, errors.New("not enough arguments to start sharing") + } + if port, ok := args[0].(float64); !ok { + return nil, errors.New("port argument is not an int") + } else if port != float64(serverPort) { + return nil, errors.New("port does not match serverPort") + } + if protocol, ok := args[1].(string); !ok { + return nil, errors.New("protocol argument is not a string") + } else if protocol != serverProtocol { + return nil, errors.New("protocol does not match serverProtocol") + } + if browseURL, ok := args[2].(string); !ok { + return nil, errors.New("browse url is not a string") + } else if browseURL != fmt.Sprintf("http://localhost:%v", serverPort) { + return nil, errors.New("browseURL does not match expected") + } + return Port{StreamName: "stream-name", StreamCondition: "stream-condition"}, nil + } + testServer, client, err := makeMockJoinedClient( + livesharetest.WithService("serverSharing.startSharing", startSharing), + ) + defer testServer.Close() + if err != nil { + t.Errorf("error creating mock joined client: %v", err) + } + server, err := NewServer(client) + if err != nil { + t.Errorf("error creating new server: %v", err) + } + ctx := context.Background() + + done := make(chan error) + go func() { + if err := server.StartSharing(ctx, serverProtocol, serverPort); err != nil { + done <- fmt.Errorf("error sharing server: %v", err) + } + if server.streamName == "" || server.streamCondition == "" { + done <- errors.New("stream name or condition is blank") + } + done <- nil + }() + + select { + case err := <-testServer.Err(): + t.Errorf("error from server: %v", err) + case err := <-done: + if err != nil { + t.Errorf("error from client: %v", err) + } + } +} + +func TestServerGetSharedServers(t *testing.T) { + sharedServer := Port{ + SourcePort: 2222, + StreamName: "stream-name", + StreamCondition: "stream-condition", + } + getSharedServers := func(req *jsonrpc2.Request) (interface{}, error) { + return Ports{&sharedServer}, nil + } + testServer, client, err := makeMockJoinedClient( + livesharetest.WithService("serverSharing.getSharedServers", getSharedServers), + ) + if err != nil { + t.Errorf("error creating new mock client: %v", err) + } + defer testServer.Close() + server, err := NewServer(client) + if err != nil { + t.Errorf("error creating new server: %v", err) + } + ctx := context.Background() + done := make(chan error) + go func() { + ports, err := server.GetSharedServers(ctx) + if err != nil { + done <- fmt.Errorf("error getting shared servers: %v", err) + } + if len(ports) < 1 { + done <- errors.New("not enough ports returned") + } + if ports[0].SourcePort != sharedServer.SourcePort { + done <- errors.New("source port does not match") + } + if ports[0].StreamName != sharedServer.StreamName { + done <- errors.New("stream name does not match") + } + if ports[0].StreamCondition != sharedServer.StreamCondition { + done <- errors.New("stream condiion does not match") + } + done <- nil + }() + + select { + case err := <-testServer.Err(): + t.Errorf("error from server: %v", err) + case err := <-done: + if err != nil { + t.Errorf("error from client: %v", err) + } + } +} + +func TestServerUpdateSharedVisibility(t *testing.T) { + updateSharedVisibility := func(rpcReq *jsonrpc2.Request) (interface{}, error) { + var req []interface{} + if err := json.Unmarshal(*rpcReq.Params, &req); err != nil { + return nil, fmt.Errorf("unmarshal req: %v", err) + } + if len(req) < 2 { + return nil, errors.New("request arguments is less than 2") + } + if port, ok := req[0].(float64); ok { + if port != 80.0 { + return nil, errors.New("port param is not expected value") + } + } else { + return nil, errors.New("port param is not a float64") + } + if public, ok := req[1].(bool); ok { + if public != true { + return nil, errors.New("pulic param is not expected value") + } + } else { + return nil, errors.New("public param is not a bool") + } + return nil, nil + } + testServer, client, err := makeMockJoinedClient( + livesharetest.WithService("serverSharing.updateSharedServerVisibility", updateSharedVisibility), + ) + if err != nil { + t.Errorf("creating new mock client: %v", err) + } + defer testServer.Close() + server, err := NewServer(client) + if err != nil { + t.Errorf("creating server: %v", err) + } + ctx := context.Background() + done := make(chan error) + go func() { + if err := server.UpdateSharedVisibility(ctx, 80, true); err != nil { + done <- err + return + } + done <- nil + }() + select { + case err := <-testServer.Err(): + t.Errorf("error from server: %v", err) + case err := <-done: + if err != nil { + t.Errorf("error from client: %v", err) + } + } +} diff --git a/session.go b/session.go deleted file mode 100644 index d0492a10e..000000000 --- a/session.go +++ /dev/null @@ -1,60 +0,0 @@ -package liveshare - -import ( - "context" - "fmt" - "net/url" - "strings" - - "golang.org/x/sync/errgroup" -) - -type session struct { - api *api - - workspaceAccess *workspaceAccessResponse - workspaceInfo *workspaceInfoResponse -} - -func newSession(api *api) *session { - return &session{api: api} -} - -func (s *session) init(ctx context.Context) error { - g, ctx := errgroup.WithContext(ctx) - - g.Go(func() error { - workspaceAccess, err := s.api.workspaceAccess() - if err != nil { - return fmt.Errorf("error getting workspace access: %v", err) - } - s.workspaceAccess = workspaceAccess - return nil - }) - - g.Go(func() error { - workspaceInfo, err := s.api.workspaceInfo() - if err != nil { - return fmt.Errorf("error getting workspace info: %v", err) - } - s.workspaceInfo = workspaceInfo - return nil - }) - - if err := g.Wait(); err != nil { - return err - } - - return nil -} - -// Reference: -// https://github.com/Azure/azure-relay-node/blob/7b57225365df3010163bf4b9e640868a02737eb6/hyco-ws/index.js#L107-L137 -func (s *session) relayURI(action string) string { - relaySas := url.QueryEscape(s.workspaceAccess.RelaySas) - relayURI := s.workspaceAccess.RelayLink - relayURI = strings.Replace(relayURI, "sb:", "wss:", -1) - relayURI = strings.Replace(relayURI, ".net/", ".net:443/$hc/", 1) - relayURI = relayURI + "?sb-hc-action=" + action + "&sb-hc-token=" + relaySas - return relayURI -} diff --git a/socket.go b/socket.go new file mode 100644 index 000000000..8744eeb96 --- /dev/null +++ b/socket.go @@ -0,0 +1,100 @@ +package liveshare + +import ( + "context" + "crypto/tls" + "io" + "net" + "net/http" + "time" + + "github.com/gorilla/websocket" +) + +type socket struct { + addr string + tlsConfig *tls.Config + + conn *websocket.Conn + reader io.Reader +} + +func newSocket(clientConn Connection, tlsConfig *tls.Config) *socket { + return &socket{addr: clientConn.uri("connect"), tlsConfig: tlsConfig} +} + +func (s *socket) connect(ctx context.Context) error { + dialer := websocket.Dialer{ + Proxy: http.ProxyFromEnvironment, + HandshakeTimeout: 45 * time.Second, + TLSClientConfig: s.tlsConfig, + } + ws, _, err := dialer.Dial(s.addr, nil) + if err != nil { + return err + } + s.conn = ws + return nil +} + +func (s *socket) Read(b []byte) (int, error) { + if s.reader == nil { + _, reader, err := s.conn.NextReader() + if err != nil { + return 0, err + } + + s.reader = reader + } + + bytesRead, err := s.reader.Read(b) + if err != nil { + s.reader = nil + + if err == io.EOF { + err = nil + } + } + + return bytesRead, err +} + +func (s *socket) Write(b []byte) (int, error) { + nextWriter, err := s.conn.NextWriter(websocket.BinaryMessage) + if err != nil { + return 0, err + } + + bytesWritten, err := nextWriter.Write(b) + nextWriter.Close() + + return bytesWritten, err +} + +func (s *socket) Close() error { + return s.conn.Close() +} + +func (s *socket) LocalAddr() net.Addr { + return s.conn.LocalAddr() +} + +func (s *socket) RemoteAddr() net.Addr { + return s.conn.RemoteAddr() +} + +func (s *socket) SetDeadline(t time.Time) error { + if err := s.SetReadDeadline(t); err != nil { + return err + } + + return s.SetWriteDeadline(t) +} + +func (s *socket) SetReadDeadline(t time.Time) error { + return s.conn.SetReadDeadline(t) +} + +func (s *socket) SetWriteDeadline(t time.Time) error { + return s.conn.SetWriteDeadline(t) +} diff --git a/ssh.go b/ssh.go index 9ae32ed7c..e22cd69d1 100644 --- a/ssh.go +++ b/ssh.go @@ -12,22 +12,22 @@ import ( type sshSession struct { *ssh.Session - session *session - socket net.Conn - conn ssh.Conn - reader io.Reader - writer io.Writer + token string + socket net.Conn + conn ssh.Conn + reader io.Reader + writer io.Writer } -func newSSH(session *session, socket net.Conn) *sshSession { - return &sshSession{session: session, socket: socket} +func newSshSession(token string, socket net.Conn) *sshSession { + return &sshSession{token: token, socket: socket} } func (s *sshSession) connect(ctx context.Context) error { clientConfig := ssh.ClientConfig{ User: "", Auth: []ssh.AuthMethod{ - ssh.Password(s.session.workspaceAccess.SessionToken), + ssh.Password(s.token), }, HostKeyAlgorithms: []string{"rsa-sha2-512", "rsa-sha2-256"}, HostKeyCallback: ssh.InsecureIgnoreHostKey(), diff --git a/terminal.go b/terminal.go index 631e75912..1621559a1 100644 --- a/terminal.go +++ b/terminal.go @@ -13,13 +13,13 @@ type Terminal struct { client *Client } -func (c *Client) NewTerminal() (*Terminal, error) { - if !c.hasJoined() { - return nil, errors.New("LiveShareClient must join before creating terminal") +func NewTerminal(client *Client) (*Terminal, error) { + if !client.hasJoined() { + return nil, errors.New("client must join before creating terminal") } return &Terminal{ - client: c, + client: client, }, nil } diff --git a/test/server.go b/test/server.go new file mode 100644 index 000000000..a52d31ab9 --- /dev/null +++ b/test/server.go @@ -0,0 +1,290 @@ +package livesharetest + +import ( + "context" + "errors" + "fmt" + "io" + "io/ioutil" + "net/http" + "net/http/httptest" + "path/filepath" + "strings" + "sync" + "time" + + "github.com/gorilla/websocket" + "github.com/sourcegraph/jsonrpc2" + "golang.org/x/crypto/ssh" +) + +type Server struct { + password string + services map[string]RpcHandleFunc + relaySAS string + streams map[string]io.ReadWriter + + sshConfig *ssh.ServerConfig + httptestServer *httptest.Server + errCh chan error +} + +func NewServer(opts ...ServerOption) (*Server, error) { + server := new(Server) + + for _, o := range opts { + if err := o(server); err != nil { + return nil, err + } + } + + server.sshConfig = &ssh.ServerConfig{ + PasswordCallback: sshPasswordCallback(server.password), + } + b, err := ioutil.ReadFile(filepath.Join("test", "private.key")) + if err != nil { + return nil, fmt.Errorf("error reading private.key: %v", err) + } + privateKey, err := ssh.ParsePrivateKey(b) + if err != nil { + return nil, fmt.Errorf("error parsing key: %v", err) + } + server.sshConfig.AddHostKey(privateKey) + + server.errCh = make(chan error) + server.httptestServer = httptest.NewTLSServer(http.HandlerFunc(makeConnection(server))) + return server, nil +} + +type ServerOption func(*Server) error + +func WithPassword(password string) ServerOption { + return func(s *Server) error { + s.password = password + return nil + } +} + +func WithService(serviceName string, handler RpcHandleFunc) ServerOption { + return func(s *Server) error { + if s.services == nil { + s.services = make(map[string]RpcHandleFunc) + } + + s.services[serviceName] = handler + return nil + } +} + +func WithRelaySAS(sas string) ServerOption { + return func(s *Server) error { + s.relaySAS = sas + return nil + } +} + +func WithStream(name string, stream io.ReadWriter) ServerOption { + return func(s *Server) error { + if s.streams == nil { + s.streams = make(map[string]io.ReadWriter) + } + s.streams[name] = stream + return nil + } +} + +func sshPasswordCallback(serverPassword string) func(ssh.ConnMetadata, []byte) (*ssh.Permissions, error) { + return func(conn ssh.ConnMetadata, password []byte) (*ssh.Permissions, error) { + if string(password) == serverPassword { + return nil, nil + } + return nil, errors.New("password rejected") + } +} + +func (s *Server) Close() { + s.httptestServer.Close() +} + +func (s *Server) URL() string { + return s.httptestServer.URL +} + +func (s *Server) Err() <-chan error { + return s.errCh +} + +var upgrader = websocket.Upgrader{} + +func makeConnection(server *Server) http.HandlerFunc { + return func(w http.ResponseWriter, req *http.Request) { + if server.relaySAS != "" { + // validate the sas key + sasParam := req.URL.Query().Get("sb-hc-token") + if sasParam != server.relaySAS { + server.errCh <- errors.New("error validating sas") + return + } + } + c, err := upgrader.Upgrade(w, req, nil) + if err != nil { + server.errCh <- fmt.Errorf("error upgrading connection: %v", err) + return + } + defer c.Close() + + socketConn := newSocketConn(c) + _, chans, reqs, err := ssh.NewServerConn(socketConn, server.sshConfig) + if err != nil { + server.errCh <- fmt.Errorf("error creating new ssh conn: %v", err) + return + } + go ssh.DiscardRequests(reqs) + + for newChannel := range chans { + ch, reqs, err := newChannel.Accept() + if err != nil { + server.errCh <- fmt.Errorf("error accepting new channel: %v", err) + return + } + go handleNewRequests(server, ch, reqs) + go handleNewChannel(server, ch) + } + } +} + +func handleNewRequests(server *Server, channel ssh.Channel, reqs <-chan *ssh.Request) { + for req := range reqs { + if req.WantReply { + if err := req.Reply(true, nil); err != nil { + server.errCh <- fmt.Errorf("error replying to channel request: %v", err) + } + } + if strings.HasPrefix(req.Type, "stream-transport") { + forwardStream(server, req.Type, channel) + } + } +} + +func forwardStream(server *Server, streamName string, channel ssh.Channel) { + simpleStreamName := strings.TrimPrefix(streamName, "stream-transport-") + stream, found := server.streams[simpleStreamName] + if !found { + server.errCh <- fmt.Errorf("stream '%v' not found", simpleStreamName) + return + } + + copy := func(dst io.Writer, src io.Reader) { + if _, err := io.Copy(dst, src); err != nil { + fmt.Println(err) + server.errCh <- fmt.Errorf("io copy: %v", err) + return + } + } + + go copy(stream, channel) + go copy(channel, stream) + + for { + } +} + +func handleNewChannel(server *Server, channel ssh.Channel) { + stream := jsonrpc2.NewBufferedStream(channel, jsonrpc2.VSCodeObjectCodec{}) + jsonrpc2.NewConn(context.Background(), stream, newRpcHandler(server)) +} + +type RpcHandleFunc func(req *jsonrpc2.Request) (interface{}, error) + +type rpcHandler struct { + server *Server +} + +func newRpcHandler(server *Server) *rpcHandler { + return &rpcHandler{server} +} + +func (r *rpcHandler) Handle(ctx context.Context, conn *jsonrpc2.Conn, req *jsonrpc2.Request) { + handler, found := r.server.services[req.Method] + if !found { + r.server.errCh <- fmt.Errorf("RPC Method: '%v' not serviced", req.Method) + return + } + + result, err := handler(req) + if err != nil { + r.server.errCh <- fmt.Errorf("error handling: '%v': %v", req.Method, err) + return + } + + if err := conn.Reply(ctx, req.ID, result); err != nil { + r.server.errCh <- fmt.Errorf("error replying: %v", err) + } +} + +type socketConn struct { + *websocket.Conn + + reader io.Reader + writeMutex sync.Mutex + readMutex sync.Mutex +} + +func newSocketConn(conn *websocket.Conn) *socketConn { + return &socketConn{Conn: conn} +} + +func (s *socketConn) Read(b []byte) (int, error) { + s.readMutex.Lock() + defer s.readMutex.Unlock() + + if s.reader == nil { + msgType, r, err := s.Conn.NextReader() + if err != nil { + return 0, fmt.Errorf("error getting next reader: %v", err) + } + if msgType != websocket.BinaryMessage { + return 0, fmt.Errorf("invalid message type") + } + s.reader = r + } + + bytesRead, err := s.reader.Read(b) + if err != nil { + s.reader = nil + + if err == io.EOF { + err = nil + } + } + + return bytesRead, err +} + +func (s *socketConn) Write(b []byte) (int, error) { + s.writeMutex.Lock() + defer s.writeMutex.Unlock() + + w, err := s.Conn.NextWriter(websocket.BinaryMessage) + if err != nil { + return 0, fmt.Errorf("error getting next writer: %v", err) + } + + n, err := w.Write(b) + if err != nil { + return 0, fmt.Errorf("error writing: %v", err) + } + + if err := w.Close(); err != nil { + return 0, fmt.Errorf("error closing writer: %v", err) + } + + return n, nil +} + +func (s *socketConn) SetDeadline(deadline time.Time) error { + if err := s.Conn.SetReadDeadline(deadline); err != nil { + return err + } + return s.Conn.SetWriteDeadline(deadline) +} diff --git a/websocket.go b/websocket.go deleted file mode 100644 index ae163e6e2..000000000 --- a/websocket.go +++ /dev/null @@ -1,105 +0,0 @@ -package liveshare - -import ( - "context" - "errors" - "io" - "net" - "sync" - "time" - - gorillawebsocket "github.com/gorilla/websocket" -) - -type websocket struct { - session *session - conn *gorillawebsocket.Conn - readMutex sync.Mutex - writeMutex sync.Mutex - reader io.Reader -} - -func newWebsocket(session *session) *websocket { - return &websocket{session: session} -} - -func (w *websocket) connect(ctx context.Context) error { - ws, _, err := gorillawebsocket.DefaultDialer.Dial(w.session.relayURI("connect"), nil) - if err != nil { - return err - } - w.conn = ws - return nil -} - -func (w *websocket) Read(b []byte) (int, error) { - w.readMutex.Lock() - defer w.readMutex.Unlock() - - if w.reader == nil { - messageType, reader, err := w.conn.NextReader() - if err != nil { - return 0, err - } - - if messageType != gorillawebsocket.BinaryMessage { - return 0, errors.New("unexpected websocket message type") - } - - w.reader = reader - } - - bytesRead, err := w.reader.Read(b) - if err != nil { - w.reader = nil - - if err == io.EOF { - err = nil - } - } - - return bytesRead, err -} - -func (w *websocket) Write(b []byte) (int, error) { - w.writeMutex.Lock() - defer w.writeMutex.Unlock() - - nextWriter, err := w.conn.NextWriter(gorillawebsocket.BinaryMessage) - if err != nil { - return 0, err - } - - bytesWritten, err := nextWriter.Write(b) - nextWriter.Close() - - return bytesWritten, err -} - -func (w *websocket) Close() error { - return w.conn.Close() -} - -func (w *websocket) LocalAddr() net.Addr { - return w.conn.LocalAddr() -} - -func (w *websocket) RemoteAddr() net.Addr { - return w.conn.RemoteAddr() -} - -func (w *websocket) SetDeadline(t time.Time) error { - if err := w.SetReadDeadline(t); err != nil { - return err - } - - return w.SetWriteDeadline(t) -} - -func (w *websocket) SetReadDeadline(t time.Time) error { - return w.conn.SetReadDeadline(t) -} - -func (w *websocket) SetWriteDeadline(t time.Time) error { - return w.conn.SetWriteDeadline(t) -}