diff --git a/client.go b/client.go index ba9d2f5e7..b51e25ea6 100644 --- a/client.go +++ b/client.go @@ -1,74 +1,94 @@ +// Package liveshare is a Go client library for the Visual Studio Live Share +// service, which provides collaborative, distibuted editing and debugging. +// See https://docs.microsoft.com/en-us/visualstudio/liveshare for an overview. +// +// It provides the ability for a Go program to connect to a Live Share +// workspace (Connect), to expose a TCP port on a remote host +// (UpdateSharedVisibility), to start an SSH server listening on an +// exposed port (StartSSHServer), and to forward connections between +// the remote port and a local listening TCP port (ForwardToListener) +// or a local Go reader/writer (Forward). package liveshare import ( "context" "crypto/tls" + "errors" "fmt" + "net/url" + "strings" "github.com/opentracing/opentracing-go" "golang.org/x/crypto/ssh" ) -// A Client capable of joining a Live Share workspace. -type Client struct { - connection Connection - tlsConfig *tls.Config +// An Options specifies Live Share connection parameters. +type Options struct { + SessionID string + SessionToken string // token for SSH session + RelaySAS string + RelayEndpoint string + TLSConfig *tls.Config // (optional) } -// A ClientOption is a function that modifies a client -type ClientOption func(*Client) error - -// NewClient accepts a range of options, applies them and returns a client -func NewClient(opts ...ClientOption) (*Client, error) { - client := new(Client) - - for _, o := range opts { - if err := o(client); err != nil { - return nil, err - } +// uri returns a websocket URL for the specified options. +func (opts *Options) uri(action string) (string, error) { + if opts.SessionID == "" { + return "", errors.New("SessionID is required") + } + if opts.RelaySAS == "" { + return "", errors.New("RelaySAS is required") + } + if opts.RelayEndpoint == "" { + return "", errors.New("RelayEndpoint is required") } - return client, nil + sas := url.QueryEscape(opts.RelaySAS) + uri := opts.RelayEndpoint + uri = strings.Replace(uri, "sb:", "wss:", -1) + uri = strings.Replace(uri, ".net/", ".net:443/$hc/", 1) + uri = uri + "?sb-hc-action=" + action + "&sb-hc-token=" + sas + return uri, nil } -// WithConnection is a ClientOption that accepts a Connection -func WithConnection(connection Connection) ClientOption { - return func(c *Client) error { - if err := connection.validate(); err != nil { - return err - } - - c.connection = connection - return nil +// Connect connects to a Live Share workspace specified by the +// options, and returns a session representing the connection. +// The caller must call the session's Close method to end the session. +func Connect(ctx context.Context, opts Options) (*Session, error) { + uri, err := opts.uri("connect") + if err != nil { + return nil, err } -} -func WithTLSConfig(tlsConfig *tls.Config) ClientOption { - return func(c *Client) error { - c.tlsConfig = tlsConfig - return nil - } -} - -// JoinWorkspace connects the client to the server's Live Share -// workspace and returns a session representing their connection. -func (c *Client) JoinWorkspace(ctx context.Context) (*Session, error) { - span, ctx := opentracing.StartSpanFromContext(ctx, "Client.JoinWorkspace") + span, ctx := opentracing.StartSpanFromContext(ctx, "Connect") defer span.Finish() - clientSocket := newSocket(c.connection, c.tlsConfig) - if err := clientSocket.connect(ctx); err != nil { + sock := newSocket(uri, opts.TLSConfig) + if err := sock.connect(ctx); err != nil { return nil, fmt.Errorf("error connecting websocket: %w", err) } - ssh := newSSHSession(c.connection.SessionToken, clientSocket) + if opts.SessionToken == "" { + return nil, errors.New("SessionToken is required") + } + ssh := newSSHSession(opts.SessionToken, sock) if err := ssh.connect(ctx); err != nil { return nil, fmt.Errorf("error connecting to ssh session: %w", err) } rpc := newRPCClient(ssh) rpc.connect(ctx) - if _, err := c.joinWorkspace(ctx, rpc); err != nil { + + args := joinWorkspaceArgs{ + ID: opts.SessionID, + ConnectionMode: "local", + JoiningUserSessionToken: opts.SessionToken, + ClientCapabilities: clientCapabilities{ + IsNonInteractive: false, + }, + } + var result joinWorkspaceResult + if err := rpc.do(ctx, "workspace.joinWorkspace", &args, &result); err != nil { return nil, fmt.Errorf("error joining Live Share workspace: %w", err) } @@ -96,24 +116,6 @@ type channelID struct { name, condition string } -func (c *Client) joinWorkspace(ctx context.Context, rpc *rpcClient) (*joinWorkspaceResult, error) { - args := joinWorkspaceArgs{ - ID: c.connection.SessionID, - ConnectionMode: "local", - JoiningUserSessionToken: c.connection.SessionToken, - ClientCapabilities: clientCapabilities{ - IsNonInteractive: false, - }, - } - - var result joinWorkspaceResult - if err := rpc.do(ctx, "workspace.joinWorkspace", &args, &result); err != nil { - return nil, fmt.Errorf("error making workspace.joinWorkspace call: %w", err) - } - - return &result, nil -} - func (s *Session) openStreamingChannel(ctx context.Context, id channelID) (ssh.Channel, error) { type getStreamArgs struct { StreamName string `json:"streamName"` diff --git a/client_test.go b/client_test.go index c1e61f6e8..2b95f738f 100644 --- a/client_test.go +++ b/client_test.go @@ -13,38 +13,8 @@ import ( "github.com/sourcegraph/jsonrpc2" ) -func TestNewClient(t *testing.T) { - client, err := NewClient() - if err != nil { - t.Errorf("error creating new client: %v", err) - } - if client == nil { - t.Error("client is nil") - } -} - -func TestNewClientValidConnection(t *testing.T) { - connection := Connection{"1", "2", "3", "4"} - - client, err := NewClient(WithConnection(connection)) - if err != nil { - t.Errorf("error creating new client: %v", err) - } - if client == nil { - t.Error("client is nil") - } -} - -func TestNewClientWithInvalidConnection(t *testing.T) { - connection := Connection{} - - if _, err := NewClient(WithConnection(connection)); err == nil { - t.Error("err is nil") - } -} - -func TestJoinSession(t *testing.T) { - connection := Connection{ +func TestConnect(t *testing.T) { + opts := Options{ SessionID: "session-id", SessionToken: "session-token", RelaySAS: "relay-sas", @@ -54,13 +24,13 @@ func TestJoinSession(t *testing.T) { if err := json.Unmarshal(*req.Params, &joinWorkspaceReq); err != nil { return nil, fmt.Errorf("error unmarshaling req: %v", err) } - if joinWorkspaceReq.ID != connection.SessionID { + if joinWorkspaceReq.ID != opts.SessionID { return nil, errors.New("connection session id does not match") } if joinWorkspaceReq.ConnectionMode != "local" { return nil, errors.New("connection mode is not local") } - if joinWorkspaceReq.JoiningUserSessionToken != connection.SessionToken { + if joinWorkspaceReq.JoiningUserSessionToken != opts.SessionToken { return nil, errors.New("connection user token does not match") } if joinWorkspaceReq.ClientCapabilities.IsNonInteractive != false { @@ -70,34 +40,24 @@ func TestJoinSession(t *testing.T) { } server, err := livesharetest.NewServer( - livesharetest.WithPassword(connection.SessionToken), + livesharetest.WithPassword(opts.SessionToken), livesharetest.WithService("workspace.joinWorkspace", joinWorkspace), - livesharetest.WithRelaySAS(connection.RelaySAS), + livesharetest.WithRelaySAS(opts.RelaySAS), ) if err != nil { t.Errorf("error creating Live Share server: %v", err) } defer server.Close() - connection.RelayEndpoint = "sb" + strings.TrimPrefix(server.URL(), "https") + opts.RelayEndpoint = "sb" + strings.TrimPrefix(server.URL(), "https") ctx := context.Background() - tlsConfig := WithTLSConfig(&tls.Config{InsecureSkipVerify: true}) - client, err := NewClient(WithConnection(connection), tlsConfig) - if err != nil { - t.Errorf("error creating new client: %v", err) - } + opts.TLSConfig = &tls.Config{InsecureSkipVerify: true} done := make(chan error) go func() { - session, err := client.JoinWorkspace(ctx) - if err != nil { - done <- fmt.Errorf("error joining workspace: %v", err) - return - } - _ = session - - done <- nil + _, err := Connect(ctx, opts) // ignore session + done <- err }() select { diff --git a/connection.go b/connection.go deleted file mode 100644 index c1a4632c8..000000000 --- a/connection.go +++ /dev/null @@ -1,44 +0,0 @@ -package liveshare - -import ( - "errors" - "net/url" - "strings" -) - -// A Connection represents a set of values necessary to join a liveshare connection -type Connection struct { - SessionID string - SessionToken string - RelaySAS string - RelayEndpoint string -} - -func (r Connection) validate() error { - if r.SessionID == "" { - return errors.New("connection SessionID is required") - } - - if r.SessionToken == "" { - return errors.New("connection SessionToken is required") - } - - if r.RelaySAS == "" { - return errors.New("connection RelaySAS is required") - } - - if r.RelayEndpoint == "" { - return errors.New("connection RelayEndpoint is required") - } - - return nil -} - -func (r Connection) uri(action string) string { - sas := url.QueryEscape(r.RelaySAS) - uri := r.RelayEndpoint - uri = strings.Replace(uri, "sb:", "wss:", -1) - uri = strings.Replace(uri, ".net/", ".net:443/$hc/", 1) - uri = uri + "?sb-hc-action=" + action + "&sb-hc-token=" + sas - return uri -} diff --git a/connection_test.go b/connection_test.go deleted file mode 100644 index f42ec4189..000000000 --- a/connection_test.go +++ /dev/null @@ -1,41 +0,0 @@ -package liveshare - -import "testing" - -func TestConnectionValid(t *testing.T) { - conn := Connection{"sess-id", "sess-token", "sas", "endpoint"} - if err := conn.validate(); err != nil { - t.Error(err) - } -} - -func TestConnectionInvalid(t *testing.T) { - conn := Connection{"", "sess-token", "sas", "endpoint"} - if err := conn.validate(); err == nil { - t.Error(err) - } - conn = Connection{"sess-id", "", "sas", "endpoint"} - if err := conn.validate(); err == nil { - t.Error(err) - } - conn = Connection{"sess-id", "sess-token", "", "endpoint"} - if err := conn.validate(); err == nil { - t.Error(err) - } - conn = Connection{"sess-id", "sess-token", "sas", ""} - if err := conn.validate(); err == nil { - t.Error(err) - } - conn = Connection{"", "", "", ""} - if err := conn.validate(); err == nil { - t.Error(err) - } -} - -func TestConnectionURI(t *testing.T) { - conn := Connection{"sess-id", "sess-token", "sas", "sb://endpoint/.net/liveshare"} - uri := conn.uri("connect") - if uri != "wss://endpoint/.net:443/$hc/liveshare?sb-hc-action=connect&sb-hc-token=sas" { - t.Errorf("uri is not correct, got: '%v'", uri) - } -} diff --git a/options_test.go b/options_test.go new file mode 100644 index 000000000..830c59104 --- /dev/null +++ b/options_test.go @@ -0,0 +1,56 @@ +package liveshare + +import ( + "context" + "testing" +) + +func TestBadOptions(t *testing.T) { + goodOptions := Options{ + SessionID: "sess-id", + SessionToken: "sess-token", + RelaySAS: "sas", + RelayEndpoint: "endpoint", + } + + opts := goodOptions + opts.SessionID = "" + checkBadOptions(t, opts) + + opts = goodOptions + opts.SessionToken = "" + checkBadOptions(t, opts) + + opts = goodOptions + opts.RelaySAS = "" + checkBadOptions(t, opts) + + opts = goodOptions + opts.RelayEndpoint = "" + checkBadOptions(t, opts) + + opts = Options{} + checkBadOptions(t, opts) +} + +func checkBadOptions(t *testing.T, opts Options) { + if _, err := Connect(context.Background(), opts); err == nil { + t.Errorf("Connect(%+v): no error", opts) + } +} + +func TestOptionsURI(t *testing.T) { + opts := Options{ + SessionID: "sess-id", + SessionToken: "sess-token", + RelaySAS: "sas", + RelayEndpoint: "sb://endpoint/.net/liveshare", + } + uri, err := opts.uri("connect") + if err != nil { + t.Fatal(err) + } + if uri != "wss://endpoint/.net:443/$hc/liveshare?sb-hc-action=connect&sb-hc-token=sas" { + t.Errorf("uri is not correct, got: '%v'", uri) + } +} diff --git a/session_test.go b/session_test.go index 54aab16c8..cd0a7b474 100644 --- a/session_test.go +++ b/session_test.go @@ -14,32 +14,25 @@ import ( ) func makeMockSession(opts ...livesharetest.ServerOption) (*livesharetest.Server, *Session, error) { - connection := Connection{ - SessionID: "session-id", - SessionToken: "session-token", - RelaySAS: "relay-sas", - } joinWorkspace := func(req *jsonrpc2.Request) (interface{}, error) { return joinWorkspaceResult{1}, nil } + const sessionToken = "session-token" opts = append( opts, - livesharetest.WithPassword(connection.SessionToken), + livesharetest.WithPassword(sessionToken), livesharetest.WithService("workspace.joinWorkspace", joinWorkspace), ) - testServer, err := livesharetest.NewServer( - opts..., - ) - connection.RelayEndpoint = "sb" + strings.TrimPrefix(testServer.URL(), "https") - tlsConfig := WithTLSConfig(&tls.Config{InsecureSkipVerify: true}) - client, err := NewClient(WithConnection(connection), tlsConfig) + testServer, err := livesharetest.NewServer(opts...) + session, err := Connect(context.Background(), Options{ + SessionID: "session-id", + SessionToken: sessionToken, + RelayEndpoint: "sb" + strings.TrimPrefix(testServer.URL(), "https"), + RelaySAS: "relay-sas", + TLSConfig: &tls.Config{InsecureSkipVerify: true}, + }) if err != nil { - return nil, nil, fmt.Errorf("error creating new client: %v", err) - } - ctx := context.Background() - session, err := client.JoinWorkspace(ctx) - if err != nil { - return nil, nil, fmt.Errorf("error joining workspace: %v", err) + return nil, nil, fmt.Errorf("error connecting to Live Share: %v", err) } return testServer, session, nil } diff --git a/socket.go b/socket.go index 8744eeb96..f66436f65 100644 --- a/socket.go +++ b/socket.go @@ -19,8 +19,8 @@ type socket struct { reader io.Reader } -func newSocket(clientConn Connection, tlsConfig *tls.Config) *socket { - return &socket{addr: clientConn.uri("connect"), tlsConfig: tlsConfig} +func newSocket(uri string, tlsConfig *tls.Config) *socket { + return &socket{addr: uri, tlsConfig: tlsConfig} } func (s *socket) connect(ctx context.Context) error {