refactor Options API

This commit is contained in:
Alan Donovan 2021-09-21 15:23:02 -04:00
parent b3b675d108
commit f8a8713520
7 changed files with 113 additions and 147 deletions

View file

@ -13,68 +13,65 @@ package liveshare
import (
"context"
"crypto/tls"
"errors"
"fmt"
"net/url"
"strings"
"github.com/opentracing/opentracing-go"
"golang.org/x/crypto/ssh"
)
// A client capable of joining a Live Share workspace.
type client struct {
connection Connection
tlsConfig *tls.Config
// An Options specifies Live Share connection parameters.
type Options struct {
SessionID string
SessionToken string // token for SSH session
RelaySAS string
RelayEndpoint string
TLSConfig *tls.Config // (optional)
}
// An Option updates the initial configuration state of a Live Share connection.
type Option func(*client) error
// WithConnection is a Option that accepts a Connection.
//
// TODO(adonovan): WithConnection is not optional, so it should not be
// not an Option. We should make Connection a mandatory parameter of
// Connect, at which point, why not just merge
// client+Option+Connection, rename it to Options, do away with the
// function mechanism, and express TLS config (etc) as public fields
// of Options with sensible zero values, like websocket.Dialer, etc?
func WithConnection(connection Connection) Option {
return func(cli *client) error {
if err := connection.validate(); err != nil {
return err
}
cli.connection = connection
return nil
// uri returns a websocket URL for the specified options.
func (opts *Options) uri(action string) (string, error) {
if opts.SessionID == "" {
return "", errors.New("SessionID is required")
}
}
// WithTLSConfig returns a Connect option that sets the TLS configuration.
func WithTLSConfig(tlsConfig *tls.Config) Option {
return func(cli *client) error {
cli.tlsConfig = tlsConfig
return nil
if opts.RelaySAS == "" {
return "", errors.New("RelaySAS is required")
}
if opts.RelayEndpoint == "" {
return "", errors.New("RelayEndpoint is required")
}
sas := url.QueryEscape(opts.RelaySAS)
uri := opts.RelayEndpoint
uri = strings.Replace(uri, "sb:", "wss:", -1)
uri = strings.Replace(uri, ".net/", ".net:443/$hc/", 1)
uri = uri + "?sb-hc-action=" + action + "&sb-hc-token=" + sas
return uri, nil
}
// Connect connects to a Live Share workspace specified by the
// options, and returns a session representing the connection.
// The caller must call the session's Close method to end the session.
func Connect(ctx context.Context, opts ...Option) (*Session, error) {
cli := new(client)
for _, opt := range opts {
if err := opt(cli); err != nil {
return nil, fmt.Errorf("error applying Live Share connect option: %w", err)
}
func Connect(ctx context.Context, opts Options) (*Session, error) {
uri, err := opts.uri("connect")
if err != nil {
return nil, err
}
span, ctx := opentracing.StartSpanFromContext(ctx, "Connect")
defer span.Finish()
sock := newSocket(cli.connection, cli.tlsConfig)
sock := newSocket(uri, opts.TLSConfig)
if err := sock.connect(ctx); err != nil {
return nil, fmt.Errorf("error connecting websocket: %w", err)
}
ssh := newSSHSession(cli.connection.SessionToken, sock)
if opts.SessionToken == "" {
return nil, errors.New("SessionToken is required")
}
ssh := newSSHSession(opts.SessionToken, sock)
if err := ssh.connect(ctx); err != nil {
return nil, fmt.Errorf("error connecting to ssh session: %w", err)
}
@ -83,9 +80,9 @@ func Connect(ctx context.Context, opts ...Option) (*Session, error) {
rpc.connect(ctx)
args := joinWorkspaceArgs{
ID: cli.connection.SessionID,
ID: opts.SessionID,
ConnectionMode: "local",
JoiningUserSessionToken: cli.connection.SessionToken,
JoiningUserSessionToken: opts.SessionToken,
ClientCapabilities: clientCapabilities{
IsNonInteractive: false,
},

View file

@ -14,7 +14,7 @@ import (
)
func TestConnect(t *testing.T) {
connection := Connection{
opts := Options{
SessionID: "session-id",
SessionToken: "session-token",
RelaySAS: "relay-sas",
@ -24,13 +24,13 @@ func TestConnect(t *testing.T) {
if err := json.Unmarshal(*req.Params, &joinWorkspaceReq); err != nil {
return nil, fmt.Errorf("error unmarshaling req: %v", err)
}
if joinWorkspaceReq.ID != connection.SessionID {
if joinWorkspaceReq.ID != opts.SessionID {
return nil, errors.New("connection session id does not match")
}
if joinWorkspaceReq.ConnectionMode != "local" {
return nil, errors.New("connection mode is not local")
}
if joinWorkspaceReq.JoiningUserSessionToken != connection.SessionToken {
if joinWorkspaceReq.JoiningUserSessionToken != opts.SessionToken {
return nil, errors.New("connection user token does not match")
}
if joinWorkspaceReq.ClientCapabilities.IsNonInteractive != false {
@ -40,23 +40,23 @@ func TestConnect(t *testing.T) {
}
server, err := livesharetest.NewServer(
livesharetest.WithPassword(connection.SessionToken),
livesharetest.WithPassword(opts.SessionToken),
livesharetest.WithService("workspace.joinWorkspace", joinWorkspace),
livesharetest.WithRelaySAS(connection.RelaySAS),
livesharetest.WithRelaySAS(opts.RelaySAS),
)
if err != nil {
t.Errorf("error creating Live Share server: %v", err)
}
defer server.Close()
connection.RelayEndpoint = "sb" + strings.TrimPrefix(server.URL(), "https")
opts.RelayEndpoint = "sb" + strings.TrimPrefix(server.URL(), "https")
ctx := context.Background()
tlsConfig := WithTLSConfig(&tls.Config{InsecureSkipVerify: true})
opts.TLSConfig = &tls.Config{InsecureSkipVerify: true}
done := make(chan error)
go func() {
_, err := Connect(ctx, WithConnection(connection), tlsConfig) // ignore session
_, err := Connect(ctx, opts) // ignore session
done <- err
}()

View file

@ -1,44 +0,0 @@
package liveshare
import (
"errors"
"net/url"
"strings"
)
// A Connection represents a set of values necessary to join a liveshare connection.
type Connection struct {
SessionID string
SessionToken string
RelaySAS string
RelayEndpoint string
}
func (r Connection) validate() error {
if r.SessionID == "" {
return errors.New("connection SessionID is required")
}
if r.SessionToken == "" {
return errors.New("connection SessionToken is required")
}
if r.RelaySAS == "" {
return errors.New("connection RelaySAS is required")
}
if r.RelayEndpoint == "" {
return errors.New("connection RelayEndpoint is required")
}
return nil
}
func (r Connection) uri(action string) string {
sas := url.QueryEscape(r.RelaySAS)
uri := r.RelayEndpoint
uri = strings.Replace(uri, "sb:", "wss:", -1)
uri = strings.Replace(uri, ".net/", ".net:443/$hc/", 1)
uri = uri + "?sb-hc-action=" + action + "&sb-hc-token=" + sas
return uri
}

View file

@ -1,41 +0,0 @@
package liveshare
import "testing"
func TestConnectionValid(t *testing.T) {
conn := Connection{"sess-id", "sess-token", "sas", "endpoint"}
if err := conn.validate(); err != nil {
t.Error(err)
}
}
func TestConnectionInvalid(t *testing.T) {
conn := Connection{"", "sess-token", "sas", "endpoint"}
if err := conn.validate(); err == nil {
t.Error(err)
}
conn = Connection{"sess-id", "", "sas", "endpoint"}
if err := conn.validate(); err == nil {
t.Error(err)
}
conn = Connection{"sess-id", "sess-token", "", "endpoint"}
if err := conn.validate(); err == nil {
t.Error(err)
}
conn = Connection{"sess-id", "sess-token", "sas", ""}
if err := conn.validate(); err == nil {
t.Error(err)
}
conn = Connection{"", "", "", ""}
if err := conn.validate(); err == nil {
t.Error(err)
}
}
func TestConnectionURI(t *testing.T) {
conn := Connection{"sess-id", "sess-token", "sas", "sb://endpoint/.net/liveshare"}
uri := conn.uri("connect")
if uri != "wss://endpoint/.net:443/$hc/liveshare?sb-hc-action=connect&sb-hc-token=sas" {
t.Errorf("uri is not correct, got: '%v'", uri)
}
}

56
options_test.go Normal file
View file

@ -0,0 +1,56 @@
package liveshare
import (
"context"
"testing"
)
func TestBadOptions(t *testing.T) {
goodOptions := Options{
SessionID: "sess-id",
SessionToken: "sess-token",
RelaySAS: "sas",
RelayEndpoint: "endpoint",
}
opts := goodOptions
opts.SessionID = ""
checkBadOptions(t, opts)
opts = goodOptions
opts.SessionToken = ""
checkBadOptions(t, opts)
opts = goodOptions
opts.RelaySAS = ""
checkBadOptions(t, opts)
opts = goodOptions
opts.RelayEndpoint = ""
checkBadOptions(t, opts)
opts = Options{}
checkBadOptions(t, opts)
}
func checkBadOptions(t *testing.T, opts Options) {
if _, err := Connect(context.Background(), opts); err == nil {
t.Errorf("Connect(%+v): no error", opts)
}
}
func TestOptionsURI(t *testing.T) {
opts := Options{
SessionID: "sess-id",
SessionToken: "sess-token",
RelaySAS: "sas",
RelayEndpoint: "sb://endpoint/.net/liveshare",
}
uri, err := opts.uri("connect")
if err != nil {
t.Fatal(err)
}
if uri != "wss://endpoint/.net:443/$hc/liveshare?sb-hc-action=connect&sb-hc-token=sas" {
t.Errorf("uri is not correct, got: '%v'", uri)
}
}

View file

@ -14,25 +14,23 @@ import (
)
func makeMockSession(opts ...livesharetest.ServerOption) (*livesharetest.Server, *Session, error) {
connection := Connection{
SessionID: "session-id",
SessionToken: "session-token",
RelaySAS: "relay-sas",
}
joinWorkspace := func(req *jsonrpc2.Request) (interface{}, error) {
return joinWorkspaceResult{1}, nil
}
const sessionToken = "session-token"
opts = append(
opts,
livesharetest.WithPassword(connection.SessionToken),
livesharetest.WithPassword(sessionToken),
livesharetest.WithService("workspace.joinWorkspace", joinWorkspace),
)
testServer, err := livesharetest.NewServer(
opts...,
)
connection.RelayEndpoint = "sb" + strings.TrimPrefix(testServer.URL(), "https")
tlsConfig := WithTLSConfig(&tls.Config{InsecureSkipVerify: true})
session, err := Connect(context.Background(), WithConnection(connection), tlsConfig)
testServer, err := livesharetest.NewServer(opts...)
session, err := Connect(context.Background(), Options{
SessionID: "session-id",
SessionToken: sessionToken,
RelayEndpoint: "sb" + strings.TrimPrefix(testServer.URL(), "https"),
RelaySAS: "relay-sas",
TLSConfig: &tls.Config{InsecureSkipVerify: true},
})
if err != nil {
return nil, nil, fmt.Errorf("error connecting to Live Share: %v", err)
}

View file

@ -19,8 +19,8 @@ type socket struct {
reader io.Reader
}
func newSocket(clientConn Connection, tlsConfig *tls.Config) *socket {
return &socket{addr: clientConn.uri("connect"), tlsConfig: tlsConfig}
func newSocket(uri string, tlsConfig *tls.Config) *socket {
return &socket{addr: uri, tlsConfig: tlsConfig}
}
func (s *socket) connect(ctx context.Context) error {