Tests for most of the new behavior

- Made the heartbeat interval configurable for easier testing
- Moved span to the top of connect to capture the full execution
This commit is contained in:
Jose Garcia 2021-10-07 15:14:42 -04:00
parent 9c8351ecd8
commit 8f5d6bb672
6 changed files with 201 additions and 9 deletions

View file

@ -17,6 +17,7 @@ import (
"fmt"
"net/url"
"strings"
"time"
"github.com/opentracing/opentracing-go"
"golang.org/x/crypto/ssh"
@ -76,6 +77,9 @@ func (opts *Options) uri(action string) (string, error) {
// options, and returns a session representing the connection.
// The caller must call the session's Close method to end the session.
func Connect(ctx context.Context, opts Options) (*Session, error) {
span, ctx := opentracing.StartSpanFromContext(ctx, "Connect")
defer span.Finish()
uri, err := opts.uri("connect")
if err != nil {
return nil, err
@ -86,9 +90,6 @@ func Connect(ctx context.Context, opts Options) (*Session, error) {
sessionLogger = opts.Logger
}
span, ctx := opentracing.StartSpanFromContext(ctx, "Connect")
defer span.Finish()
sock := newSocket(uri, opts.TLSConfig)
if err := sock.connect(ctx); err != nil {
return nil, fmt.Errorf("error connecting websocket: %w", err)
@ -125,7 +126,7 @@ func Connect(ctx context.Context, opts Options) (*Session, error) {
keepAliveReason: make(chan string, 1),
logger: sessionLogger,
}
go s.heartbeat(ctx)
go s.heartbeat(ctx, 1*time.Minute)
return s, nil
}

View file

@ -15,6 +15,7 @@ import (
func TestConnect(t *testing.T) {
opts := Options{
ClientName: "liveshare-client",
SessionID: "session-id",
SessionToken: "session-token",
RelaySAS: "relay-sas",

View file

@ -41,6 +41,7 @@ func checkBadOptions(t *testing.T, opts Options) {
func TestOptionsURI(t *testing.T) {
opts := Options{
ClientName: "liveshare-client",
SessionID: "sess-id",
SessionToken: "sess-token",
RelaySAS: "sas",

View file

@ -20,7 +20,7 @@ func TestNewPortForwarder(t *testing.T) {
t.Errorf("create mock client: %w", err)
}
defer testServer.Close()
pf := NewPortForwarder(session, "ssh", 80)
pf := NewPortForwarder(session, "ssh", 80, false)
if pf == nil {
t.Error("port forwarder is nil")
}
@ -58,7 +58,7 @@ func TestPortForwarderStart(t *testing.T) {
done := make(chan error)
go func() {
const name, remote = "ssh", 8000
done <- NewPortForwarder(session, name, remote).ForwardToListener(ctx, listen)
done <- NewPortForwarder(session, name, remote, false).ForwardToListener(ctx, listen)
}()
go func() {
@ -93,3 +93,26 @@ func TestPortForwarderStart(t *testing.T) {
}
}
}
func TestPortForwarderTrafficMonitor(t *testing.T) {
buf := bytes.NewBufferString("some-input")
session := &Session{keepAliveReason: make(chan string, 1)}
trafficType := "io"
tm := newTrafficMonitor(buf, session, trafficType)
l := len(buf.Bytes())
bb := make([]byte, l)
n, err := tm.Read(bb)
if err != nil {
t.Errorf("failed to read from traffic monitor: %w", err)
}
if n != l {
t.Errorf("expected to read %d bytes, got %d", l, n)
}
keepAliveReason := <-session.keepAliveReason
if keepAliveReason != trafficType {
t.Errorf("expected keep alive reason to be %s, got %s", trafficType, keepAliveReason)
}
}

View file

@ -103,10 +103,10 @@ func (s *Session) StartSSHServer(ctx context.Context) (int, string, error) {
return port, response.User, nil
}
// heartbeat ticks every minute and sends a signal to the Live Share host to keep
// heartbeat ticks every interval and sends a signal to the Live Share host to keep
// the connection alive if there is a reason to do so.
func (s *Session) heartbeat(ctx context.Context) {
ticker := time.NewTicker(1 * time.Minute)
func (s *Session) heartbeat(ctx context.Context, interval time.Duration) {
ticker := time.NewTicker(interval)
defer ticker.Stop()
for {

View file

@ -1,6 +1,7 @@
package liveshare
import (
"bytes"
"context"
"crypto/tls"
"encoding/json"
@ -8,11 +9,14 @@ import (
"fmt"
"strings"
"testing"
"time"
livesharetest "github.com/cli/cli/v2/pkg/liveshare/test"
"github.com/sourcegraph/jsonrpc2"
)
const mockClientName = "liveshare-client"
func makeMockSession(opts ...livesharetest.ServerOption) (*livesharetest.Server, *Session, error) {
joinWorkspace := func(req *jsonrpc2.Request) (interface{}, error) {
return joinWorkspaceResult{1}, nil
@ -29,6 +33,7 @@ func makeMockSession(opts ...livesharetest.ServerOption) (*livesharetest.Server,
}
session, err := Connect(context.Background(), Options{
ClientName: mockClientName,
SessionID: "session-id",
SessionToken: sessionToken,
RelayEndpoint: "sb" + strings.TrimPrefix(testServer.URL(), "https"),
@ -221,3 +226,164 @@ func TestInvalidHostKey(t *testing.T) {
t.Error("expected invalid host key error, got: nil")
}
}
func TestKeepAliveNonBlocking(t *testing.T) {
session := &Session{keepAliveReason: make(chan string, 1)}
var i int
for ; i < 2; i++ {
session.keepAlive("io")
}
// if keepAlive blocks, we'll never reach this and timeout the test
// timing out
if i != 2 {
t.Errorf("unexpected iteration account, expected: 2, got: %d", i)
}
}
func TestNotifyHostOfActivity(t *testing.T) {
notifyHostOfActivity := func(rpcReq *jsonrpc2.Request) (interface{}, error) {
var req []interface{}
if err := json.Unmarshal(*rpcReq.Params, &req); err != nil {
return nil, fmt.Errorf("unmarshal req: %w", err)
}
if len(req) < 2 {
return nil, errors.New("request arguments is less than 2")
}
if clientName, ok := req[0].(string); ok {
if clientName != mockClientName {
return nil, fmt.Errorf(
"unexpected clientName param, expected: %q, got: %q", mockClientName, clientName,
)
}
} else {
return nil, errors.New("clientName param is not a string")
}
if acs, ok := req[1].([]interface{}); ok {
if fmt.Sprintf("%s", acs) != "[input]" {
return nil, fmt.Errorf("unexpected activities param, expected: [input], got: %s", acs)
}
} else {
return nil, errors.New("activities param is not a slice")
}
return nil, nil
}
svc := livesharetest.WithService(
"ICodespaceHostService.notifyCodespaceOfClientActivity", notifyHostOfActivity,
)
testServer, session, err := makeMockSession(svc)
if err != nil {
t.Errorf("creating mock session: %w", err)
}
defer testServer.Close()
ctx := context.Background()
done := make(chan error)
go func() {
done <- session.notifyHostOfActivity(ctx, "input")
}()
select {
case err := <-testServer.Err():
t.Errorf("error from server: %w", err)
case err := <-done:
if err != nil {
t.Errorf("error from client: %w", err)
}
}
}
func TestSessionHeartbeat(t *testing.T) {
var requests int
notifyHostOfActivity := func(rpcReq *jsonrpc2.Request) (interface{}, error) {
requests++
var req []interface{}
if err := json.Unmarshal(*rpcReq.Params, &req); err != nil {
return nil, fmt.Errorf("unmarshal req: %w", err)
}
if len(req) < 2 {
return nil, errors.New("request arguments is less than 2")
}
if clientName, ok := req[0].(string); ok {
if clientName != mockClientName {
return nil, fmt.Errorf(
"unexpected clientName param, expected: %q, got: %q", mockClientName, clientName,
)
}
} else {
return nil, errors.New("clientName param is not a string")
}
if acs, ok := req[1].([]interface{}); ok {
if fmt.Sprintf("%s", acs) != "[input]" {
return nil, fmt.Errorf("unexpected activities param, expected: [input], got: %s", acs)
}
} else {
return nil, errors.New("activities param is not a slice")
}
return nil, nil
}
svc := livesharetest.WithService(
"ICodespaceHostService.notifyCodespaceOfClientActivity", notifyHostOfActivity,
)
testServer, session, err := makeMockSession(svc)
if err != nil {
t.Errorf("creating mock session: %w", err)
}
defer testServer.Close()
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
done := make(chan struct{})
logger := newMockLogger()
session.logger = logger
go session.heartbeat(ctx, 50*time.Millisecond)
go func() {
session.keepAlive("input")
<-time.Tick(100 * time.Millisecond)
session.keepAlive("input")
<-time.Tick(100 * time.Millisecond)
done <- struct{}{}
}()
select {
case err := <-testServer.Err():
t.Errorf("error from server: %w", err)
case <-done:
activityCount := strings.Count(logger.String(), "input")
if activityCount != 2 {
t.Errorf("unexpected number of activities, expected: 2, got: %d", activityCount)
}
if requests != 2 {
t.Errorf("unexpected number of requests, expected: 2, got: %d", requests)
}
return
}
}
type mockLogger struct {
buf *bytes.Buffer
}
func newMockLogger() *mockLogger {
return &mockLogger{new(bytes.Buffer)}
}
func (m *mockLogger) Printf(format string, v ...interface{}) (int, error) {
return m.buf.WriteString(fmt.Sprintf(format, v...))
}
func (m *mockLogger) Println(v ...interface{}) (int, error) {
return m.buf.WriteString(fmt.Sprintln(v...))
}
func (m *mockLogger) String() string {
return m.buf.String()
}