From 91114d35c3d04245a58f78ebf2feb6bb5edde4e2 Mon Sep 17 00:00:00 2001 From: Jose Garcia Date: Sat, 24 Jul 2021 03:44:20 +0000 Subject: [PATCH] More tests --- client_test.go | 19 +++++ port_forwarder_test.go | 1 + server_test.go | 186 +++++++++++++++++++++++++++++++++++++++++ test/server.go | 16 ++++ 4 files changed, 222 insertions(+) create mode 100644 port_forwarder_test.go create mode 100644 server_test.go diff --git a/client_test.go b/client_test.go index fdf566fc0..110c7e3b9 100644 --- a/client_test.go +++ b/client_test.go @@ -3,6 +3,8 @@ package liveshare import ( "context" "crypto/tls" + "encoding/json" + "errors" "fmt" "strings" "testing" @@ -48,12 +50,29 @@ func TestClientJoin(t *testing.T) { RelaySAS: "relay-sas", } joinWorkspace := func(req *jsonrpc2.Request) (interface{}, error) { + var joinWorkspaceReq joinWorkspaceArgs + if err := json.Unmarshal(*req.Params, &joinWorkspaceReq); err != nil { + return nil, fmt.Errorf("error unmarshaling req: %v", err) + } + if joinWorkspaceReq.ID != connection.SessionID { + return nil, errors.New("connection session id does not match") + } + if joinWorkspaceReq.ConnectionMode != "local" { + return nil, errors.New("connection mode is not local") + } + if joinWorkspaceReq.JoiningUserSessionToken != connection.SessionToken { + return nil, errors.New("connection user token does not match") + } + if joinWorkspaceReq.ClientCapabilities.IsNonInteractive != false { + return nil, errors.New("non interactive is not false") + } return joinWorkspaceResult{1}, nil } server, err := livesharetest.NewServer( livesharetest.WithPassword(connection.SessionToken), livesharetest.WithService("workspace.joinWorkspace", joinWorkspace), + livesharetest.WithRelaySAS(connection.RelaySAS), ) if err != nil { t.Errorf("error creating liveshare server: %v", err) diff --git a/port_forwarder_test.go b/port_forwarder_test.go new file mode 100644 index 000000000..e3e219705 --- /dev/null +++ b/port_forwarder_test.go @@ -0,0 +1 @@ +package liveshare diff --git a/server_test.go b/server_test.go new file mode 100644 index 000000000..cc2b9adbd --- /dev/null +++ b/server_test.go @@ -0,0 +1,186 @@ +package liveshare + +import ( + "context" + "crypto/tls" + "encoding/json" + "errors" + "fmt" + "strings" + "testing" + + livesharetest "github.com/github/go-liveshare/test" + "github.com/sourcegraph/jsonrpc2" +) + +func TestNewServerWithNotJoinedClient(t *testing.T) { + client, err := NewClient() + if err != nil { + t.Errorf("error creating new client: %v", err) + } + if _, err := NewServer(client); err == nil { + t.Error("expected error") + } +} + +func newMockJoinedClient(opts ...livesharetest.ServerOption) (*livesharetest.Server, *Client, error) { + connection := Connection{ + SessionID: "session-id", + SessionToken: "session-token", + RelaySAS: "relay-sas", + } + joinWorkspace := func(req *jsonrpc2.Request) (interface{}, error) { + return joinWorkspaceResult{1}, nil + } + opts = append( + opts, + livesharetest.WithPassword(connection.SessionToken), + livesharetest.WithService("workspace.joinWorkspace", joinWorkspace), + ) + testServer, err := livesharetest.NewServer( + opts..., + ) + connection.RelayEndpoint = "sb" + strings.TrimPrefix(testServer.URL(), "https") + tlsConfig := WithTLSConfig(&tls.Config{InsecureSkipVerify: true}) + client, err := NewClient(WithConnection(connection), tlsConfig) + if err != nil { + return nil, nil, fmt.Errorf("error creating new client: %v", err) + } + ctx := context.Background() + if err := client.Join(ctx); err != nil { + return nil, nil, fmt.Errorf("error joining client: %v", err) + } + return testServer, client, nil +} + +func TestNewServer(t *testing.T) { + testServer, client, err := newMockJoinedClient() + defer testServer.Close() + if err != nil { + t.Errorf("error creating mock joined client: %v", err) + } + server, err := NewServer(client) + if err != nil { + t.Errorf("error creating new server: %v", err) + } + if server == nil { + t.Error("server is nil") + } +} + +func TestServerStartSharing(t *testing.T) { + serverPort, serverProtocol := 2222, "sshd" + startSharing := func(req *jsonrpc2.Request) (interface{}, error) { + var args []interface{} + if err := json.Unmarshal(*req.Params, &args); err != nil { + return nil, fmt.Errorf("error unmarshaling request: %v", err) + } + if len(args) < 3 { + return nil, errors.New("not enough arguments to start sharing") + } + if port, ok := args[0].(float64); !ok { + return nil, errors.New("port argument is not an int") + } else if port != float64(serverPort) { + return nil, errors.New("port does not match serverPort") + } + if protocol, ok := args[1].(string); !ok { + return nil, errors.New("protocol argument is not a string") + } else if protocol != serverProtocol { + return nil, errors.New("protocol does not match serverProtocol") + } + if browseURL, ok := args[2].(string); !ok { + return nil, errors.New("browse url is not a string") + } else if browseURL != fmt.Sprintf("http://localhost:%v", serverPort) { + return nil, errors.New("browseURL does not match expected") + } + return Port{StreamName: "stream-name", StreamCondition: "stream-condition"}, nil + } + testServer, client, err := newMockJoinedClient( + livesharetest.WithService("serverSharing.startSharing", startSharing), + ) + defer testServer.Close() + if err != nil { + t.Errorf("error creating mock joined client: %v", err) + } + server, err := NewServer(client) + if err != nil { + t.Errorf("error creating new server: %v", err) + } + ctx := context.Background() + + done := make(chan error) + go func() { + if err := server.StartSharing(ctx, serverProtocol, serverPort); err != nil { + done <- fmt.Errorf("error sharing server: %v", err) + } + if server.streamName == "" || server.streamCondition == "" { + done <- errors.New("stream name or condition is blank") + } + done <- nil + }() + + select { + case err := <-testServer.Err(): + t.Errorf("error from server: %v", err) + case err := <-done: + if err != nil { + t.Errorf("error from client: %v", err) + } + } +} + +func TestServerGetSharedServers(t *testing.T) { + sharedServer := Port{ + SourcePort: 2222, + StreamName: "stream-name", + StreamCondition: "stream-condition", + } + getSharedServers := func(req *jsonrpc2.Request) (interface{}, error) { + return Ports{&sharedServer}, nil + } + testServer, client, err := newMockJoinedClient( + livesharetest.WithService("serverSharing.getSharedServers", getSharedServers), + ) + if err != nil { + t.Errorf("error creating new mock client: %v", err) + } + defer testServer.Close() + server, err := NewServer(client) + if err != nil { + t.Errorf("error creating new server: %v", err) + } + ctx := context.Background() + done := make(chan error) + go func() { + ports, err := server.GetSharedServers(ctx) + if err != nil { + done <- fmt.Errorf("error getting shared servers: %v", err) + } + if len(ports) < 1 { + done <- errors.New("not enough ports returned") + } + if ports[0].SourcePort != sharedServer.SourcePort { + done <- errors.New("source port does not match") + } + if ports[0].StreamName != sharedServer.StreamName { + done <- errors.New("stream name does not match") + } + if ports[0].StreamCondition != sharedServer.StreamCondition { + done <- errors.New("stream condiion does not match") + } + done <- nil + }() + + select { + case err := <-testServer.Err(): + t.Errorf("error from server: %v", err) + case err := <-done: + if err != nil { + t.Errorf("error from client: %v", err) + } + } +} + +func TestServerUpdateSharedVisibility(t *testing.T) { + +} diff --git a/test/server.go b/test/server.go index ed8666cce..abb7ac96a 100644 --- a/test/server.go +++ b/test/server.go @@ -20,6 +20,7 @@ import ( type Server struct { password string services map[string]RpcHandleFunc + relaySAS string sshConfig *ssh.ServerConfig httptestServer *httptest.Server @@ -73,6 +74,13 @@ func WithService(serviceName string, handler RpcHandleFunc) ServerOption { } } +func WithRelaySAS(sas string) ServerOption { + return func(s *Server) error { + s.relaySAS = sas + return nil + } +} + func sshPasswordCallback(serverPassword string) func(ssh.ConnMetadata, []byte) (*ssh.Permissions, error) { return func(conn ssh.ConnMetadata, password []byte) (*ssh.Permissions, error) { if string(password) == serverPassword { @@ -98,6 +106,14 @@ var upgrader = websocket.Upgrader{} func newConnection(server *Server) http.HandlerFunc { return func(w http.ResponseWriter, req *http.Request) { + if server.relaySAS != "" { + // validate the sas key + sasParam := req.URL.Query().Get("sb-hc-token") + if sasParam != server.relaySAS { + server.errCh <- errors.New("error validating sas") + return + } + } c, err := upgrader.Upgrade(w, req, nil) if err != nil { server.errCh <- fmt.Errorf("error upgrading connection: %v", err)