diff --git a/client.go b/client.go index 3f9345ce4..b51e25ea6 100644 --- a/client.go +++ b/client.go @@ -13,68 +13,65 @@ package liveshare import ( "context" "crypto/tls" + "errors" "fmt" + "net/url" + "strings" "github.com/opentracing/opentracing-go" "golang.org/x/crypto/ssh" ) -// A client capable of joining a Live Share workspace. -type client struct { - connection Connection - tlsConfig *tls.Config +// An Options specifies Live Share connection parameters. +type Options struct { + SessionID string + SessionToken string // token for SSH session + RelaySAS string + RelayEndpoint string + TLSConfig *tls.Config // (optional) } -// An Option updates the initial configuration state of a Live Share connection. -type Option func(*client) error - -// WithConnection is a Option that accepts a Connection. -// -// TODO(adonovan): WithConnection is not optional, so it should not be -// not an Option. We should make Connection a mandatory parameter of -// Connect, at which point, why not just merge -// client+Option+Connection, rename it to Options, do away with the -// function mechanism, and express TLS config (etc) as public fields -// of Options with sensible zero values, like websocket.Dialer, etc? -func WithConnection(connection Connection) Option { - return func(cli *client) error { - if err := connection.validate(); err != nil { - return err - } - - cli.connection = connection - return nil +// uri returns a websocket URL for the specified options. +func (opts *Options) uri(action string) (string, error) { + if opts.SessionID == "" { + return "", errors.New("SessionID is required") } -} - -// WithTLSConfig returns a Connect option that sets the TLS configuration. -func WithTLSConfig(tlsConfig *tls.Config) Option { - return func(cli *client) error { - cli.tlsConfig = tlsConfig - return nil + if opts.RelaySAS == "" { + return "", errors.New("RelaySAS is required") } + if opts.RelayEndpoint == "" { + return "", errors.New("RelayEndpoint is required") + } + + sas := url.QueryEscape(opts.RelaySAS) + uri := opts.RelayEndpoint + uri = strings.Replace(uri, "sb:", "wss:", -1) + uri = strings.Replace(uri, ".net/", ".net:443/$hc/", 1) + uri = uri + "?sb-hc-action=" + action + "&sb-hc-token=" + sas + return uri, nil } // Connect connects to a Live Share workspace specified by the // options, and returns a session representing the connection. // The caller must call the session's Close method to end the session. -func Connect(ctx context.Context, opts ...Option) (*Session, error) { - cli := new(client) - for _, opt := range opts { - if err := opt(cli); err != nil { - return nil, fmt.Errorf("error applying Live Share connect option: %w", err) - } +func Connect(ctx context.Context, opts Options) (*Session, error) { + uri, err := opts.uri("connect") + if err != nil { + return nil, err } span, ctx := opentracing.StartSpanFromContext(ctx, "Connect") defer span.Finish() - sock := newSocket(cli.connection, cli.tlsConfig) + sock := newSocket(uri, opts.TLSConfig) if err := sock.connect(ctx); err != nil { return nil, fmt.Errorf("error connecting websocket: %w", err) } - ssh := newSSHSession(cli.connection.SessionToken, sock) + if opts.SessionToken == "" { + return nil, errors.New("SessionToken is required") + } + ssh := newSSHSession(opts.SessionToken, sock) if err := ssh.connect(ctx); err != nil { return nil, fmt.Errorf("error connecting to ssh session: %w", err) } @@ -83,9 +80,9 @@ func Connect(ctx context.Context, opts ...Option) (*Session, error) { rpc.connect(ctx) args := joinWorkspaceArgs{ - ID: cli.connection.SessionID, + ID: opts.SessionID, ConnectionMode: "local", - JoiningUserSessionToken: cli.connection.SessionToken, + JoiningUserSessionToken: opts.SessionToken, ClientCapabilities: clientCapabilities{ IsNonInteractive: false, }, diff --git a/client_test.go b/client_test.go index 369c53b28..2b95f738f 100644 --- a/client_test.go +++ b/client_test.go @@ -14,7 +14,7 @@ import ( ) func TestConnect(t *testing.T) { - connection := Connection{ + opts := Options{ SessionID: "session-id", SessionToken: "session-token", RelaySAS: "relay-sas", @@ -24,13 +24,13 @@ func TestConnect(t *testing.T) { if err := json.Unmarshal(*req.Params, &joinWorkspaceReq); err != nil { return nil, fmt.Errorf("error unmarshaling req: %v", err) } - if joinWorkspaceReq.ID != connection.SessionID { + if joinWorkspaceReq.ID != opts.SessionID { return nil, errors.New("connection session id does not match") } if joinWorkspaceReq.ConnectionMode != "local" { return nil, errors.New("connection mode is not local") } - if joinWorkspaceReq.JoiningUserSessionToken != connection.SessionToken { + if joinWorkspaceReq.JoiningUserSessionToken != opts.SessionToken { return nil, errors.New("connection user token does not match") } if joinWorkspaceReq.ClientCapabilities.IsNonInteractive != false { @@ -40,23 +40,23 @@ func TestConnect(t *testing.T) { } server, err := livesharetest.NewServer( - livesharetest.WithPassword(connection.SessionToken), + livesharetest.WithPassword(opts.SessionToken), livesharetest.WithService("workspace.joinWorkspace", joinWorkspace), - livesharetest.WithRelaySAS(connection.RelaySAS), + livesharetest.WithRelaySAS(opts.RelaySAS), ) if err != nil { t.Errorf("error creating Live Share server: %v", err) } defer server.Close() - connection.RelayEndpoint = "sb" + strings.TrimPrefix(server.URL(), "https") + opts.RelayEndpoint = "sb" + strings.TrimPrefix(server.URL(), "https") ctx := context.Background() - tlsConfig := WithTLSConfig(&tls.Config{InsecureSkipVerify: true}) + opts.TLSConfig = &tls.Config{InsecureSkipVerify: true} done := make(chan error) go func() { - _, err := Connect(ctx, WithConnection(connection), tlsConfig) // ignore session + _, err := Connect(ctx, opts) // ignore session done <- err }() diff --git a/connection.go b/connection.go deleted file mode 100644 index f402e4bb9..000000000 --- a/connection.go +++ /dev/null @@ -1,44 +0,0 @@ -package liveshare - -import ( - "errors" - "net/url" - "strings" -) - -// A Connection represents a set of values necessary to join a liveshare connection. -type Connection struct { - SessionID string - SessionToken string - RelaySAS string - RelayEndpoint string -} - -func (r Connection) validate() error { - if r.SessionID == "" { - return errors.New("connection SessionID is required") - } - - if r.SessionToken == "" { - return errors.New("connection SessionToken is required") - } - - if r.RelaySAS == "" { - return errors.New("connection RelaySAS is required") - } - - if r.RelayEndpoint == "" { - return errors.New("connection RelayEndpoint is required") - } - - return nil -} - -func (r Connection) uri(action string) string { - sas := url.QueryEscape(r.RelaySAS) - uri := r.RelayEndpoint - uri = strings.Replace(uri, "sb:", "wss:", -1) - uri = strings.Replace(uri, ".net/", ".net:443/$hc/", 1) - uri = uri + "?sb-hc-action=" + action + "&sb-hc-token=" + sas - return uri -} diff --git a/connection_test.go b/connection_test.go deleted file mode 100644 index f42ec4189..000000000 --- a/connection_test.go +++ /dev/null @@ -1,41 +0,0 @@ -package liveshare - -import "testing" - -func TestConnectionValid(t *testing.T) { - conn := Connection{"sess-id", "sess-token", "sas", "endpoint"} - if err := conn.validate(); err != nil { - t.Error(err) - } -} - -func TestConnectionInvalid(t *testing.T) { - conn := Connection{"", "sess-token", "sas", "endpoint"} - if err := conn.validate(); err == nil { - t.Error(err) - } - conn = Connection{"sess-id", "", "sas", "endpoint"} - if err := conn.validate(); err == nil { - t.Error(err) - } - conn = Connection{"sess-id", "sess-token", "", "endpoint"} - if err := conn.validate(); err == nil { - t.Error(err) - } - conn = Connection{"sess-id", "sess-token", "sas", ""} - if err := conn.validate(); err == nil { - t.Error(err) - } - conn = Connection{"", "", "", ""} - if err := conn.validate(); err == nil { - t.Error(err) - } -} - -func TestConnectionURI(t *testing.T) { - conn := Connection{"sess-id", "sess-token", "sas", "sb://endpoint/.net/liveshare"} - uri := conn.uri("connect") - if uri != "wss://endpoint/.net:443/$hc/liveshare?sb-hc-action=connect&sb-hc-token=sas" { - t.Errorf("uri is not correct, got: '%v'", uri) - } -} diff --git a/options_test.go b/options_test.go new file mode 100644 index 000000000..830c59104 --- /dev/null +++ b/options_test.go @@ -0,0 +1,56 @@ +package liveshare + +import ( + "context" + "testing" +) + +func TestBadOptions(t *testing.T) { + goodOptions := Options{ + SessionID: "sess-id", + SessionToken: "sess-token", + RelaySAS: "sas", + RelayEndpoint: "endpoint", + } + + opts := goodOptions + opts.SessionID = "" + checkBadOptions(t, opts) + + opts = goodOptions + opts.SessionToken = "" + checkBadOptions(t, opts) + + opts = goodOptions + opts.RelaySAS = "" + checkBadOptions(t, opts) + + opts = goodOptions + opts.RelayEndpoint = "" + checkBadOptions(t, opts) + + opts = Options{} + checkBadOptions(t, opts) +} + +func checkBadOptions(t *testing.T, opts Options) { + if _, err := Connect(context.Background(), opts); err == nil { + t.Errorf("Connect(%+v): no error", opts) + } +} + +func TestOptionsURI(t *testing.T) { + opts := Options{ + SessionID: "sess-id", + SessionToken: "sess-token", + RelaySAS: "sas", + RelayEndpoint: "sb://endpoint/.net/liveshare", + } + uri, err := opts.uri("connect") + if err != nil { + t.Fatal(err) + } + if uri != "wss://endpoint/.net:443/$hc/liveshare?sb-hc-action=connect&sb-hc-token=sas" { + t.Errorf("uri is not correct, got: '%v'", uri) + } +} diff --git a/session_test.go b/session_test.go index 3be90cb0e..cd0a7b474 100644 --- a/session_test.go +++ b/session_test.go @@ -14,25 +14,23 @@ import ( ) func makeMockSession(opts ...livesharetest.ServerOption) (*livesharetest.Server, *Session, error) { - connection := Connection{ - SessionID: "session-id", - SessionToken: "session-token", - RelaySAS: "relay-sas", - } joinWorkspace := func(req *jsonrpc2.Request) (interface{}, error) { return joinWorkspaceResult{1}, nil } + const sessionToken = "session-token" opts = append( opts, - livesharetest.WithPassword(connection.SessionToken), + livesharetest.WithPassword(sessionToken), livesharetest.WithService("workspace.joinWorkspace", joinWorkspace), ) - testServer, err := livesharetest.NewServer( - opts..., - ) - connection.RelayEndpoint = "sb" + strings.TrimPrefix(testServer.URL(), "https") - tlsConfig := WithTLSConfig(&tls.Config{InsecureSkipVerify: true}) - session, err := Connect(context.Background(), WithConnection(connection), tlsConfig) + testServer, err := livesharetest.NewServer(opts...) + session, err := Connect(context.Background(), Options{ + SessionID: "session-id", + SessionToken: sessionToken, + RelayEndpoint: "sb" + strings.TrimPrefix(testServer.URL(), "https"), + RelaySAS: "relay-sas", + TLSConfig: &tls.Config{InsecureSkipVerify: true}, + }) if err != nil { return nil, nil, fmt.Errorf("error connecting to Live Share: %v", err) } diff --git a/socket.go b/socket.go index 8744eeb96..f66436f65 100644 --- a/socket.go +++ b/socket.go @@ -19,8 +19,8 @@ type socket struct { reader io.Reader } -func newSocket(clientConn Connection, tlsConfig *tls.Config) *socket { - return &socket{addr: clientConn.uri("connect"), tlsConfig: tlsConfig} +func newSocket(uri string, tlsConfig *tls.Config) *socket { + return &socket{addr: uri, tlsConfig: tlsConfig} } func (s *socket) connect(ctx context.Context) error {