From 8f5d6bb672e889ed8723a7f5bbc22da1c0a9ef12 Mon Sep 17 00:00:00 2001 From: Jose Garcia Date: Thu, 7 Oct 2021 15:14:42 -0400 Subject: [PATCH] Tests for most of the new behavior - Made the heartbeat interval configurable for easier testing - Moved span to the top of connect to capture the full execution --- pkg/liveshare/client.go | 9 +- pkg/liveshare/client_test.go | 1 + pkg/liveshare/options_test.go | 1 + pkg/liveshare/port_forwarder_test.go | 27 ++++- pkg/liveshare/session.go | 6 +- pkg/liveshare/session_test.go | 166 +++++++++++++++++++++++++++ 6 files changed, 201 insertions(+), 9 deletions(-) diff --git a/pkg/liveshare/client.go b/pkg/liveshare/client.go index ccf57b08a..c3e92004d 100644 --- a/pkg/liveshare/client.go +++ b/pkg/liveshare/client.go @@ -17,6 +17,7 @@ import ( "fmt" "net/url" "strings" + "time" "github.com/opentracing/opentracing-go" "golang.org/x/crypto/ssh" @@ -76,6 +77,9 @@ func (opts *Options) uri(action string) (string, error) { // options, and returns a session representing the connection. // The caller must call the session's Close method to end the session. func Connect(ctx context.Context, opts Options) (*Session, error) { + span, ctx := opentracing.StartSpanFromContext(ctx, "Connect") + defer span.Finish() + uri, err := opts.uri("connect") if err != nil { return nil, err @@ -86,9 +90,6 @@ func Connect(ctx context.Context, opts Options) (*Session, error) { sessionLogger = opts.Logger } - span, ctx := opentracing.StartSpanFromContext(ctx, "Connect") - defer span.Finish() - sock := newSocket(uri, opts.TLSConfig) if err := sock.connect(ctx); err != nil { return nil, fmt.Errorf("error connecting websocket: %w", err) @@ -125,7 +126,7 @@ func Connect(ctx context.Context, opts Options) (*Session, error) { keepAliveReason: make(chan string, 1), logger: sessionLogger, } - go s.heartbeat(ctx) + go s.heartbeat(ctx, 1*time.Minute) return s, nil } diff --git a/pkg/liveshare/client_test.go b/pkg/liveshare/client_test.go index 46807a22e..a775ba4af 100644 --- a/pkg/liveshare/client_test.go +++ b/pkg/liveshare/client_test.go @@ -15,6 +15,7 @@ import ( func TestConnect(t *testing.T) { opts := Options{ + ClientName: "liveshare-client", SessionID: "session-id", SessionToken: "session-token", RelaySAS: "relay-sas", diff --git a/pkg/liveshare/options_test.go b/pkg/liveshare/options_test.go index 830c59104..d244193b4 100644 --- a/pkg/liveshare/options_test.go +++ b/pkg/liveshare/options_test.go @@ -41,6 +41,7 @@ func checkBadOptions(t *testing.T, opts Options) { func TestOptionsURI(t *testing.T) { opts := Options{ + ClientName: "liveshare-client", SessionID: "sess-id", SessionToken: "sess-token", RelaySAS: "sas", diff --git a/pkg/liveshare/port_forwarder_test.go b/pkg/liveshare/port_forwarder_test.go index 624428dda..c5b61d430 100644 --- a/pkg/liveshare/port_forwarder_test.go +++ b/pkg/liveshare/port_forwarder_test.go @@ -20,7 +20,7 @@ func TestNewPortForwarder(t *testing.T) { t.Errorf("create mock client: %w", err) } defer testServer.Close() - pf := NewPortForwarder(session, "ssh", 80) + pf := NewPortForwarder(session, "ssh", 80, false) if pf == nil { t.Error("port forwarder is nil") } @@ -58,7 +58,7 @@ func TestPortForwarderStart(t *testing.T) { done := make(chan error) go func() { const name, remote = "ssh", 8000 - done <- NewPortForwarder(session, name, remote).ForwardToListener(ctx, listen) + done <- NewPortForwarder(session, name, remote, false).ForwardToListener(ctx, listen) }() go func() { @@ -93,3 +93,26 @@ func TestPortForwarderStart(t *testing.T) { } } } + +func TestPortForwarderTrafficMonitor(t *testing.T) { + buf := bytes.NewBufferString("some-input") + session := &Session{keepAliveReason: make(chan string, 1)} + trafficType := "io" + + tm := newTrafficMonitor(buf, session, trafficType) + l := len(buf.Bytes()) + + bb := make([]byte, l) + n, err := tm.Read(bb) + if err != nil { + t.Errorf("failed to read from traffic monitor: %w", err) + } + if n != l { + t.Errorf("expected to read %d bytes, got %d", l, n) + } + + keepAliveReason := <-session.keepAliveReason + if keepAliveReason != trafficType { + t.Errorf("expected keep alive reason to be %s, got %s", trafficType, keepAliveReason) + } +} diff --git a/pkg/liveshare/session.go b/pkg/liveshare/session.go index 329ea1a2e..4815ae77a 100644 --- a/pkg/liveshare/session.go +++ b/pkg/liveshare/session.go @@ -103,10 +103,10 @@ func (s *Session) StartSSHServer(ctx context.Context) (int, string, error) { return port, response.User, nil } -// heartbeat ticks every minute and sends a signal to the Live Share host to keep +// heartbeat ticks every interval and sends a signal to the Live Share host to keep // the connection alive if there is a reason to do so. -func (s *Session) heartbeat(ctx context.Context) { - ticker := time.NewTicker(1 * time.Minute) +func (s *Session) heartbeat(ctx context.Context, interval time.Duration) { + ticker := time.NewTicker(interval) defer ticker.Stop() for { diff --git a/pkg/liveshare/session_test.go b/pkg/liveshare/session_test.go index 7f0b573b5..fdbcab2b5 100644 --- a/pkg/liveshare/session_test.go +++ b/pkg/liveshare/session_test.go @@ -1,6 +1,7 @@ package liveshare import ( + "bytes" "context" "crypto/tls" "encoding/json" @@ -8,11 +9,14 @@ import ( "fmt" "strings" "testing" + "time" livesharetest "github.com/cli/cli/v2/pkg/liveshare/test" "github.com/sourcegraph/jsonrpc2" ) +const mockClientName = "liveshare-client" + func makeMockSession(opts ...livesharetest.ServerOption) (*livesharetest.Server, *Session, error) { joinWorkspace := func(req *jsonrpc2.Request) (interface{}, error) { return joinWorkspaceResult{1}, nil @@ -29,6 +33,7 @@ func makeMockSession(opts ...livesharetest.ServerOption) (*livesharetest.Server, } session, err := Connect(context.Background(), Options{ + ClientName: mockClientName, SessionID: "session-id", SessionToken: sessionToken, RelayEndpoint: "sb" + strings.TrimPrefix(testServer.URL(), "https"), @@ -221,3 +226,164 @@ func TestInvalidHostKey(t *testing.T) { t.Error("expected invalid host key error, got: nil") } } + +func TestKeepAliveNonBlocking(t *testing.T) { + session := &Session{keepAliveReason: make(chan string, 1)} + var i int + for ; i < 2; i++ { + session.keepAlive("io") + } + + // if keepAlive blocks, we'll never reach this and timeout the test + // timing out + if i != 2 { + t.Errorf("unexpected iteration account, expected: 2, got: %d", i) + } +} + +func TestNotifyHostOfActivity(t *testing.T) { + notifyHostOfActivity := func(rpcReq *jsonrpc2.Request) (interface{}, error) { + var req []interface{} + if err := json.Unmarshal(*rpcReq.Params, &req); err != nil { + return nil, fmt.Errorf("unmarshal req: %w", err) + } + if len(req) < 2 { + return nil, errors.New("request arguments is less than 2") + } + + if clientName, ok := req[0].(string); ok { + if clientName != mockClientName { + return nil, fmt.Errorf( + "unexpected clientName param, expected: %q, got: %q", mockClientName, clientName, + ) + } + } else { + return nil, errors.New("clientName param is not a string") + } + + if acs, ok := req[1].([]interface{}); ok { + if fmt.Sprintf("%s", acs) != "[input]" { + return nil, fmt.Errorf("unexpected activities param, expected: [input], got: %s", acs) + } + } else { + return nil, errors.New("activities param is not a slice") + } + + return nil, nil + } + svc := livesharetest.WithService( + "ICodespaceHostService.notifyCodespaceOfClientActivity", notifyHostOfActivity, + ) + testServer, session, err := makeMockSession(svc) + if err != nil { + t.Errorf("creating mock session: %w", err) + } + defer testServer.Close() + ctx := context.Background() + done := make(chan error) + go func() { + done <- session.notifyHostOfActivity(ctx, "input") + }() + select { + case err := <-testServer.Err(): + t.Errorf("error from server: %w", err) + case err := <-done: + if err != nil { + t.Errorf("error from client: %w", err) + } + } +} + +func TestSessionHeartbeat(t *testing.T) { + var requests int + notifyHostOfActivity := func(rpcReq *jsonrpc2.Request) (interface{}, error) { + requests++ + + var req []interface{} + if err := json.Unmarshal(*rpcReq.Params, &req); err != nil { + return nil, fmt.Errorf("unmarshal req: %w", err) + } + if len(req) < 2 { + return nil, errors.New("request arguments is less than 2") + } + + if clientName, ok := req[0].(string); ok { + if clientName != mockClientName { + return nil, fmt.Errorf( + "unexpected clientName param, expected: %q, got: %q", mockClientName, clientName, + ) + } + } else { + return nil, errors.New("clientName param is not a string") + } + + if acs, ok := req[1].([]interface{}); ok { + if fmt.Sprintf("%s", acs) != "[input]" { + return nil, fmt.Errorf("unexpected activities param, expected: [input], got: %s", acs) + } + } else { + return nil, errors.New("activities param is not a slice") + } + + return nil, nil + } + svc := livesharetest.WithService( + "ICodespaceHostService.notifyCodespaceOfClientActivity", notifyHostOfActivity, + ) + testServer, session, err := makeMockSession(svc) + if err != nil { + t.Errorf("creating mock session: %w", err) + } + defer testServer.Close() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + done := make(chan struct{}) + + logger := newMockLogger() + session.logger = logger + + go session.heartbeat(ctx, 50*time.Millisecond) + go func() { + session.keepAlive("input") + <-time.Tick(100 * time.Millisecond) + session.keepAlive("input") + <-time.Tick(100 * time.Millisecond) + done <- struct{}{} + }() + + select { + case err := <-testServer.Err(): + t.Errorf("error from server: %w", err) + case <-done: + activityCount := strings.Count(logger.String(), "input") + if activityCount != 2 { + t.Errorf("unexpected number of activities, expected: 2, got: %d", activityCount) + } + if requests != 2 { + t.Errorf("unexpected number of requests, expected: 2, got: %d", requests) + } + return + } +} + +type mockLogger struct { + buf *bytes.Buffer +} + +func newMockLogger() *mockLogger { + return &mockLogger{new(bytes.Buffer)} +} + +func (m *mockLogger) Printf(format string, v ...interface{}) (int, error) { + return m.buf.WriteString(fmt.Sprintf(format, v...)) +} + +func (m *mockLogger) Println(v ...interface{}) (int, error) { + return m.buf.WriteString(fmt.Sprintln(v...)) +} + +func (m *mockLogger) String() string { + return m.buf.String() +}