Tests for most of the new behavior
- Made the heartbeat interval configurable for easier testing - Moved span to the top of connect to capture the full execution
This commit is contained in:
parent
9c8351ecd8
commit
8f5d6bb672
6 changed files with 201 additions and 9 deletions
|
|
@ -17,6 +17,7 @@ import (
|
|||
"fmt"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/opentracing/opentracing-go"
|
||||
"golang.org/x/crypto/ssh"
|
||||
|
|
@ -76,6 +77,9 @@ func (opts *Options) uri(action string) (string, error) {
|
|||
// options, and returns a session representing the connection.
|
||||
// The caller must call the session's Close method to end the session.
|
||||
func Connect(ctx context.Context, opts Options) (*Session, error) {
|
||||
span, ctx := opentracing.StartSpanFromContext(ctx, "Connect")
|
||||
defer span.Finish()
|
||||
|
||||
uri, err := opts.uri("connect")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
|
@ -86,9 +90,6 @@ func Connect(ctx context.Context, opts Options) (*Session, error) {
|
|||
sessionLogger = opts.Logger
|
||||
}
|
||||
|
||||
span, ctx := opentracing.StartSpanFromContext(ctx, "Connect")
|
||||
defer span.Finish()
|
||||
|
||||
sock := newSocket(uri, opts.TLSConfig)
|
||||
if err := sock.connect(ctx); err != nil {
|
||||
return nil, fmt.Errorf("error connecting websocket: %w", err)
|
||||
|
|
@ -125,7 +126,7 @@ func Connect(ctx context.Context, opts Options) (*Session, error) {
|
|||
keepAliveReason: make(chan string, 1),
|
||||
logger: sessionLogger,
|
||||
}
|
||||
go s.heartbeat(ctx)
|
||||
go s.heartbeat(ctx, 1*time.Minute)
|
||||
|
||||
return s, nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -15,6 +15,7 @@ import (
|
|||
|
||||
func TestConnect(t *testing.T) {
|
||||
opts := Options{
|
||||
ClientName: "liveshare-client",
|
||||
SessionID: "session-id",
|
||||
SessionToken: "session-token",
|
||||
RelaySAS: "relay-sas",
|
||||
|
|
|
|||
|
|
@ -41,6 +41,7 @@ func checkBadOptions(t *testing.T, opts Options) {
|
|||
|
||||
func TestOptionsURI(t *testing.T) {
|
||||
opts := Options{
|
||||
ClientName: "liveshare-client",
|
||||
SessionID: "sess-id",
|
||||
SessionToken: "sess-token",
|
||||
RelaySAS: "sas",
|
||||
|
|
|
|||
|
|
@ -20,7 +20,7 @@ func TestNewPortForwarder(t *testing.T) {
|
|||
t.Errorf("create mock client: %w", err)
|
||||
}
|
||||
defer testServer.Close()
|
||||
pf := NewPortForwarder(session, "ssh", 80)
|
||||
pf := NewPortForwarder(session, "ssh", 80, false)
|
||||
if pf == nil {
|
||||
t.Error("port forwarder is nil")
|
||||
}
|
||||
|
|
@ -58,7 +58,7 @@ func TestPortForwarderStart(t *testing.T) {
|
|||
done := make(chan error)
|
||||
go func() {
|
||||
const name, remote = "ssh", 8000
|
||||
done <- NewPortForwarder(session, name, remote).ForwardToListener(ctx, listen)
|
||||
done <- NewPortForwarder(session, name, remote, false).ForwardToListener(ctx, listen)
|
||||
}()
|
||||
|
||||
go func() {
|
||||
|
|
@ -93,3 +93,26 @@ func TestPortForwarderStart(t *testing.T) {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestPortForwarderTrafficMonitor(t *testing.T) {
|
||||
buf := bytes.NewBufferString("some-input")
|
||||
session := &Session{keepAliveReason: make(chan string, 1)}
|
||||
trafficType := "io"
|
||||
|
||||
tm := newTrafficMonitor(buf, session, trafficType)
|
||||
l := len(buf.Bytes())
|
||||
|
||||
bb := make([]byte, l)
|
||||
n, err := tm.Read(bb)
|
||||
if err != nil {
|
||||
t.Errorf("failed to read from traffic monitor: %w", err)
|
||||
}
|
||||
if n != l {
|
||||
t.Errorf("expected to read %d bytes, got %d", l, n)
|
||||
}
|
||||
|
||||
keepAliveReason := <-session.keepAliveReason
|
||||
if keepAliveReason != trafficType {
|
||||
t.Errorf("expected keep alive reason to be %s, got %s", trafficType, keepAliveReason)
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -103,10 +103,10 @@ func (s *Session) StartSSHServer(ctx context.Context) (int, string, error) {
|
|||
return port, response.User, nil
|
||||
}
|
||||
|
||||
// heartbeat ticks every minute and sends a signal to the Live Share host to keep
|
||||
// heartbeat ticks every interval and sends a signal to the Live Share host to keep
|
||||
// the connection alive if there is a reason to do so.
|
||||
func (s *Session) heartbeat(ctx context.Context) {
|
||||
ticker := time.NewTicker(1 * time.Minute)
|
||||
func (s *Session) heartbeat(ctx context.Context, interval time.Duration) {
|
||||
ticker := time.NewTicker(interval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
package liveshare
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"encoding/json"
|
||||
|
|
@ -8,11 +9,14 @@ import (
|
|||
"fmt"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
livesharetest "github.com/cli/cli/v2/pkg/liveshare/test"
|
||||
"github.com/sourcegraph/jsonrpc2"
|
||||
)
|
||||
|
||||
const mockClientName = "liveshare-client"
|
||||
|
||||
func makeMockSession(opts ...livesharetest.ServerOption) (*livesharetest.Server, *Session, error) {
|
||||
joinWorkspace := func(req *jsonrpc2.Request) (interface{}, error) {
|
||||
return joinWorkspaceResult{1}, nil
|
||||
|
|
@ -29,6 +33,7 @@ func makeMockSession(opts ...livesharetest.ServerOption) (*livesharetest.Server,
|
|||
}
|
||||
|
||||
session, err := Connect(context.Background(), Options{
|
||||
ClientName: mockClientName,
|
||||
SessionID: "session-id",
|
||||
SessionToken: sessionToken,
|
||||
RelayEndpoint: "sb" + strings.TrimPrefix(testServer.URL(), "https"),
|
||||
|
|
@ -221,3 +226,164 @@ func TestInvalidHostKey(t *testing.T) {
|
|||
t.Error("expected invalid host key error, got: nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestKeepAliveNonBlocking(t *testing.T) {
|
||||
session := &Session{keepAliveReason: make(chan string, 1)}
|
||||
var i int
|
||||
for ; i < 2; i++ {
|
||||
session.keepAlive("io")
|
||||
}
|
||||
|
||||
// if keepAlive blocks, we'll never reach this and timeout the test
|
||||
// timing out
|
||||
if i != 2 {
|
||||
t.Errorf("unexpected iteration account, expected: 2, got: %d", i)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNotifyHostOfActivity(t *testing.T) {
|
||||
notifyHostOfActivity := func(rpcReq *jsonrpc2.Request) (interface{}, error) {
|
||||
var req []interface{}
|
||||
if err := json.Unmarshal(*rpcReq.Params, &req); err != nil {
|
||||
return nil, fmt.Errorf("unmarshal req: %w", err)
|
||||
}
|
||||
if len(req) < 2 {
|
||||
return nil, errors.New("request arguments is less than 2")
|
||||
}
|
||||
|
||||
if clientName, ok := req[0].(string); ok {
|
||||
if clientName != mockClientName {
|
||||
return nil, fmt.Errorf(
|
||||
"unexpected clientName param, expected: %q, got: %q", mockClientName, clientName,
|
||||
)
|
||||
}
|
||||
} else {
|
||||
return nil, errors.New("clientName param is not a string")
|
||||
}
|
||||
|
||||
if acs, ok := req[1].([]interface{}); ok {
|
||||
if fmt.Sprintf("%s", acs) != "[input]" {
|
||||
return nil, fmt.Errorf("unexpected activities param, expected: [input], got: %s", acs)
|
||||
}
|
||||
} else {
|
||||
return nil, errors.New("activities param is not a slice")
|
||||
}
|
||||
|
||||
return nil, nil
|
||||
}
|
||||
svc := livesharetest.WithService(
|
||||
"ICodespaceHostService.notifyCodespaceOfClientActivity", notifyHostOfActivity,
|
||||
)
|
||||
testServer, session, err := makeMockSession(svc)
|
||||
if err != nil {
|
||||
t.Errorf("creating mock session: %w", err)
|
||||
}
|
||||
defer testServer.Close()
|
||||
ctx := context.Background()
|
||||
done := make(chan error)
|
||||
go func() {
|
||||
done <- session.notifyHostOfActivity(ctx, "input")
|
||||
}()
|
||||
select {
|
||||
case err := <-testServer.Err():
|
||||
t.Errorf("error from server: %w", err)
|
||||
case err := <-done:
|
||||
if err != nil {
|
||||
t.Errorf("error from client: %w", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestSessionHeartbeat(t *testing.T) {
|
||||
var requests int
|
||||
notifyHostOfActivity := func(rpcReq *jsonrpc2.Request) (interface{}, error) {
|
||||
requests++
|
||||
|
||||
var req []interface{}
|
||||
if err := json.Unmarshal(*rpcReq.Params, &req); err != nil {
|
||||
return nil, fmt.Errorf("unmarshal req: %w", err)
|
||||
}
|
||||
if len(req) < 2 {
|
||||
return nil, errors.New("request arguments is less than 2")
|
||||
}
|
||||
|
||||
if clientName, ok := req[0].(string); ok {
|
||||
if clientName != mockClientName {
|
||||
return nil, fmt.Errorf(
|
||||
"unexpected clientName param, expected: %q, got: %q", mockClientName, clientName,
|
||||
)
|
||||
}
|
||||
} else {
|
||||
return nil, errors.New("clientName param is not a string")
|
||||
}
|
||||
|
||||
if acs, ok := req[1].([]interface{}); ok {
|
||||
if fmt.Sprintf("%s", acs) != "[input]" {
|
||||
return nil, fmt.Errorf("unexpected activities param, expected: [input], got: %s", acs)
|
||||
}
|
||||
} else {
|
||||
return nil, errors.New("activities param is not a slice")
|
||||
}
|
||||
|
||||
return nil, nil
|
||||
}
|
||||
svc := livesharetest.WithService(
|
||||
"ICodespaceHostService.notifyCodespaceOfClientActivity", notifyHostOfActivity,
|
||||
)
|
||||
testServer, session, err := makeMockSession(svc)
|
||||
if err != nil {
|
||||
t.Errorf("creating mock session: %w", err)
|
||||
}
|
||||
defer testServer.Close()
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
done := make(chan struct{})
|
||||
|
||||
logger := newMockLogger()
|
||||
session.logger = logger
|
||||
|
||||
go session.heartbeat(ctx, 50*time.Millisecond)
|
||||
go func() {
|
||||
session.keepAlive("input")
|
||||
<-time.Tick(100 * time.Millisecond)
|
||||
session.keepAlive("input")
|
||||
<-time.Tick(100 * time.Millisecond)
|
||||
done <- struct{}{}
|
||||
}()
|
||||
|
||||
select {
|
||||
case err := <-testServer.Err():
|
||||
t.Errorf("error from server: %w", err)
|
||||
case <-done:
|
||||
activityCount := strings.Count(logger.String(), "input")
|
||||
if activityCount != 2 {
|
||||
t.Errorf("unexpected number of activities, expected: 2, got: %d", activityCount)
|
||||
}
|
||||
if requests != 2 {
|
||||
t.Errorf("unexpected number of requests, expected: 2, got: %d", requests)
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
type mockLogger struct {
|
||||
buf *bytes.Buffer
|
||||
}
|
||||
|
||||
func newMockLogger() *mockLogger {
|
||||
return &mockLogger{new(bytes.Buffer)}
|
||||
}
|
||||
|
||||
func (m *mockLogger) Printf(format string, v ...interface{}) (int, error) {
|
||||
return m.buf.WriteString(fmt.Sprintf(format, v...))
|
||||
}
|
||||
|
||||
func (m *mockLogger) Println(v ...interface{}) (int, error) {
|
||||
return m.buf.WriteString(fmt.Sprintln(v...))
|
||||
}
|
||||
|
||||
func (m *mockLogger) String() string {
|
||||
return m.buf.String()
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue