Merge pull request #4461 from cli/jg/liveshare-keepalive

codespace/liveshare keepalive: keep LS sessions alive with PF traffic
This commit is contained in:
Jose Garcia 2021-10-13 11:21:43 -04:00 committed by GitHub
commit a033b85fa2
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
15 changed files with 433 additions and 32 deletions

View file

@ -15,6 +15,13 @@ type logger interface {
Println(v ...interface{}) (int, error)
}
// TODO(josebalius): clean this up once we standardrize
// logging for codespaces
type liveshareLogger interface {
Println(v ...interface{})
Printf(f string, v ...interface{})
}
func connectionReady(codespace *api.Codespace) bool {
return codespace.Connection.SessionID != "" &&
codespace.Connection.SessionToken != "" &&
@ -30,7 +37,7 @@ type apiClient interface {
// ConnectToLiveshare waits for a Codespace to become running,
// and connects to it using a Live Share session.
func ConnectToLiveshare(ctx context.Context, log logger, apiClient apiClient, codespace *api.Codespace) (*liveshare.Session, error) {
func ConnectToLiveshare(ctx context.Context, log logger, sessionLogger liveshareLogger, apiClient apiClient, codespace *api.Codespace) (*liveshare.Session, error) {
var startedCodespace bool
if codespace.Environment.State != api.CodespaceEnvironmentStateAvailable {
startedCodespace = true
@ -67,10 +74,12 @@ func ConnectToLiveshare(ctx context.Context, log logger, apiClient apiClient, co
log.Println("Connecting to your codespace...")
return liveshare.Connect(ctx, liveshare.Options{
ClientName: "gh",
SessionID: codespace.Connection.SessionID,
SessionToken: codespace.Connection.SessionToken,
RelaySAS: codespace.Connection.RelaySAS,
RelayEndpoint: codespace.Connection.RelayEndpoint,
HostPublicKeys: codespace.Connection.HostPublicKeys,
Logger: sessionLogger,
})
}

View file

@ -5,6 +5,8 @@ import (
"context"
"encoding/json"
"fmt"
"io/ioutil"
"log"
"net"
"strings"
"time"
@ -36,8 +38,10 @@ type PostCreateState struct {
// PollPostCreateStates watches for state changes in a codespace,
// and calls the supplied poller for each batch of state changes.
// It runs until it encounters an error, including cancellation of the context.
func PollPostCreateStates(ctx context.Context, log logger, apiClient apiClient, codespace *api.Codespace, poller func([]PostCreateState)) (err error) {
session, err := ConnectToLiveshare(ctx, log, apiClient, codespace)
func PollPostCreateStates(ctx context.Context, logger logger, apiClient apiClient, codespace *api.Codespace, poller func([]PostCreateState)) (err error) {
noopLogger := log.New(ioutil.Discard, "", 0)
session, err := ConnectToLiveshare(ctx, logger, noopLogger, apiClient, codespace)
if err != nil {
return fmt.Errorf("connect to Live Share: %w", err)
}
@ -54,7 +58,7 @@ func PollPostCreateStates(ctx context.Context, log logger, apiClient apiClient,
}
localPort := listen.Addr().(*net.TCPAddr).Port
log.Println("Fetching SSH Details...")
logger.Println("Fetching SSH Details...")
remoteSSHServerPort, sshUser, err := session.StartSSHServer(ctx)
if err != nil {
return fmt.Errorf("error getting ssh server details: %w", err)
@ -62,7 +66,7 @@ func PollPostCreateStates(ctx context.Context, log logger, apiClient apiClient,
tunnelClosed := make(chan error, 1) // buffered to avoid sender stuckness
go func() {
fwd := liveshare.NewPortForwarder(session, "sshd", remoteSSHServerPort)
fwd := liveshare.NewPortForwarder(session, "sshd", remoteSSHServerPort, false)
tunnelClosed <- fwd.ForwardToListener(ctx, listen) // error is non-nil
}()

View file

@ -7,6 +7,8 @@ import (
"errors"
"fmt"
"io"
"io/ioutil"
"log"
"os"
"sort"
"strings"
@ -211,6 +213,10 @@ func noArgsConstraint(cmd *cobra.Command, args []string) error {
return nil
}
func noopLogger() *log.Logger {
return log.New(ioutil.Discard, "", 0)
}
type codespace struct {
*api.Codespace
}

View file

@ -44,6 +44,7 @@ func TestDelete(t *testing.T) {
},
},
wantDeleted: []string{"hubot-robawt-abc"},
wantStdout: "Codespace deleted.\n",
},
{
name: "by repo",
@ -65,6 +66,7 @@ func TestDelete(t *testing.T) {
},
},
wantDeleted: []string{"monalisa-spoonknife-123", "monalisa-spoonknife-c4f3"},
wantStdout: "Codespaces deleted.\n",
},
{
name: "unused",
@ -87,6 +89,7 @@ func TestDelete(t *testing.T) {
},
},
wantDeleted: []string{"hubot-robawt-abc", "monalisa-spoonknife-c4f3"},
wantStdout: "Codespaces deleted.\n",
},
{
name: "deletion failed",
@ -148,6 +151,7 @@ func TestDelete(t *testing.T) {
"Codespace hubot-robawt-abc has unsaved changes. OK to delete?": true,
},
wantDeleted: []string{"hubot-robawt-abc", "monalisa-spoonknife-c4f3"},
wantStdout: "Codespaces deleted.\n",
},
}
for _, tt := range tests {

View file

@ -51,7 +51,7 @@ func (a *App) Logs(ctx context.Context, codespaceName string, follow bool) (err
return fmt.Errorf("get or choose codespace: %w", err)
}
session, err := codespaces.ConnectToLiveshare(ctx, a.logger, a.apiClient, codespace)
session, err := codespaces.ConnectToLiveshare(ctx, a.logger, noopLogger(), a.apiClient, codespace)
if err != nil {
return fmt.Errorf("connecting to Live Share: %w", err)
}
@ -90,7 +90,7 @@ func (a *App) Logs(ctx context.Context, codespaceName string, follow bool) (err
tunnelClosed := make(chan error, 1)
go func() {
fwd := liveshare.NewPortForwarder(session, "sshd", remoteSSHServerPort)
fwd := liveshare.NewPortForwarder(session, "sshd", remoteSSHServerPort, false)
tunnelClosed <- fwd.ForwardToListener(ctx, listen) // error is non-nil
}()

View file

@ -9,10 +9,14 @@ import (
// NewLogger returns a Logger that will write to the given stdout/stderr writers.
// Disable the Logger to prevent it from writing to stdout in a TTY environment.
func NewLogger(stdout, stderr io.Writer, disabled bool) *Logger {
enabled := !disabled
if isTTY(stdout) && !enabled {
enabled = false
}
return &Logger{
out: stdout,
errout: stderr,
enabled: !disabled && isTTY(stdout),
enabled: enabled,
}
}

View file

@ -60,7 +60,7 @@ func (a *App) ListPorts(ctx context.Context, codespaceName string, asJSON bool)
devContainerCh := getDevContainer(ctx, a.apiClient, codespace)
session, err := codespaces.ConnectToLiveshare(ctx, a.logger, a.apiClient, codespace)
session, err := codespaces.ConnectToLiveshare(ctx, a.logger, noopLogger(), a.apiClient, codespace)
if err != nil {
return fmt.Errorf("error connecting to Live Share: %w", err)
}
@ -194,7 +194,7 @@ func (a *App) UpdatePortVisibility(ctx context.Context, codespaceName, sourcePor
return fmt.Errorf("error getting codespace: %w", err)
}
session, err := codespaces.ConnectToLiveshare(ctx, a.logger, a.apiClient, codespace)
session, err := codespaces.ConnectToLiveshare(ctx, a.logger, noopLogger(), a.apiClient, codespace)
if err != nil {
return fmt.Errorf("error connecting to Live Share: %w", err)
}
@ -253,7 +253,7 @@ func (a *App) ForwardPorts(ctx context.Context, codespaceName string, ports []st
return fmt.Errorf("error getting codespace: %w", err)
}
session, err := codespaces.ConnectToLiveshare(ctx, a.logger, a.apiClient, codespace)
session, err := codespaces.ConnectToLiveshare(ctx, a.logger, noopLogger(), a.apiClient, codespace)
if err != nil {
return fmt.Errorf("error connecting to Live Share: %w", err)
}
@ -272,7 +272,7 @@ func (a *App) ForwardPorts(ctx context.Context, codespaceName string, ports []st
defer listen.Close()
a.logger.Printf("Forwarding ports: remote %d <=> local %d\n", pair.remote, pair.local)
name := fmt.Sprintf("share-%d", pair.remote)
fwd := liveshare.NewPortForwarder(session, name, pair.remote)
fwd := liveshare.NewPortForwarder(session, name, pair.remote, false)
return fwd.ForwardToListener(ctx, listen) // error always non-nil
})
}

View file

@ -3,34 +3,46 @@ package codespace
import (
"context"
"fmt"
"io/ioutil"
"log"
"net"
"os"
"github.com/cli/cli/v2/internal/codespaces"
"github.com/cli/cli/v2/pkg/liveshare"
"github.com/spf13/cobra"
)
type sshOptions struct {
codespace string
profile string
serverPort int
debug bool
debugFile string
}
func newSSHCmd(app *App) *cobra.Command {
var sshProfile, codespaceName string
var sshServerPort int
var opts sshOptions
sshCmd := &cobra.Command{
Use: "ssh [flags] [--] [ssh-flags] [command]",
Short: "SSH into a codespace",
RunE: func(cmd *cobra.Command, args []string) error {
return app.SSH(cmd.Context(), args, sshProfile, codespaceName, sshServerPort)
return app.SSH(cmd.Context(), args, opts)
},
}
sshCmd.Flags().StringVarP(&sshProfile, "profile", "", "", "Name of the SSH profile to use")
sshCmd.Flags().IntVarP(&sshServerPort, "server-port", "", 0, "SSH server port number (0 => pick unused)")
sshCmd.Flags().StringVarP(&codespaceName, "codespace", "c", "", "Name of the codespace")
sshCmd.Flags().StringVarP(&opts.profile, "profile", "", "", "Name of the SSH profile to use")
sshCmd.Flags().IntVarP(&opts.serverPort, "server-port", "", 0, "SSH server port number (0 => pick unused)")
sshCmd.Flags().StringVarP(&opts.codespace, "codespace", "c", "", "Name of the codespace")
sshCmd.Flags().BoolVarP(&opts.debug, "debug", "d", false, "Log debug data to a file")
sshCmd.Flags().StringVarP(&opts.debugFile, "debug-file", "", "", "Path of the file log to")
return sshCmd
}
// SSH opens an ssh session or runs an ssh command in a codespace.
func (a *App) SSH(ctx context.Context, sshArgs []string, sshProfile, codespaceName string, localSSHServerPort int) (err error) {
func (a *App) SSH(ctx context.Context, sshArgs []string, opts sshOptions) (err error) {
// Ensure all child tasks (e.g. port forwarding) terminate before return.
ctx, cancel := context.WithCancel(ctx)
defer cancel()
@ -45,12 +57,22 @@ func (a *App) SSH(ctx context.Context, sshArgs []string, sshProfile, codespaceNa
authkeys <- checkAuthorizedKeys(ctx, a.apiClient, user.Login)
}()
codespace, err := getOrChooseCodespace(ctx, a.apiClient, codespaceName)
codespace, err := getOrChooseCodespace(ctx, a.apiClient, opts.codespace)
if err != nil {
return fmt.Errorf("get or choose codespace: %w", err)
}
session, err := codespaces.ConnectToLiveshare(ctx, a.logger, a.apiClient, codespace)
var debugLogger *fileLogger
if opts.debug {
debugLogger, err = newFileLogger(opts.debugFile)
if err != nil {
return fmt.Errorf("error creating debug logger: %w", err)
}
defer safeClose(debugLogger, &err)
a.logger.Println("Debug file located at: " + debugLogger.Name())
}
session, err := codespaces.ConnectToLiveshare(ctx, a.logger, debugLogger, a.apiClient, codespace)
if err != nil {
return fmt.Errorf("error connecting to Live Share: %w", err)
}
@ -66,6 +88,7 @@ func (a *App) SSH(ctx context.Context, sshArgs []string, sshProfile, codespaceNa
return fmt.Errorf("error getting ssh server details: %w", err)
}
localSSHServerPort := opts.serverPort
usingCustomPort := localSSHServerPort != 0 // suppress log of command line in Shell
// Ensure local port is listening before client (Shell) connects.
@ -76,7 +99,7 @@ func (a *App) SSH(ctx context.Context, sshArgs []string, sshProfile, codespaceNa
defer listen.Close()
localSSHServerPort = listen.Addr().(*net.TCPAddr).Port
connectDestination := sshProfile
connectDestination := opts.profile
if connectDestination == "" {
connectDestination = fmt.Sprintf("%s@localhost", sshUser)
}
@ -84,7 +107,7 @@ func (a *App) SSH(ctx context.Context, sshArgs []string, sshProfile, codespaceNa
a.logger.Println("Ready...")
tunnelClosed := make(chan error, 1)
go func() {
fwd := liveshare.NewPortForwarder(session, "sshd", remoteSSHServerPort)
fwd := liveshare.NewPortForwarder(session, "sshd", remoteSSHServerPort, true)
tunnelClosed <- fwd.ForwardToListener(ctx, listen) // always non-nil
}()
@ -103,3 +126,43 @@ func (a *App) SSH(ctx context.Context, sshArgs []string, sshProfile, codespaceNa
return nil // success
}
}
// fileLogger is a wrapper around an log.Logger configured to write
// to a file. It exports two additional methods to get the log file name
// and close the file handle when the operation is finished.
type fileLogger struct {
*log.Logger
f *os.File
}
// newFileLogger creates a new fileLogger. It returns an error if the file
// cannot be created. The file is created on the specified path, if the path
// is empty it is created in the temporary directory.
func newFileLogger(file string) (fl *fileLogger, err error) {
var f *os.File
if file == "" {
f, err = ioutil.TempFile("", "")
if err != nil {
return nil, fmt.Errorf("failed to create tmp file: %w", err)
}
} else {
f, err = os.Create(file)
if err != nil {
return nil, err
}
}
return &fileLogger{
Logger: log.New(f, "", log.LstdFlags),
f: f,
}, nil
}
func (fl *fileLogger) Name() string {
return fl.f.Name()
}
func (fl *fileLogger) Close() error {
return fl.f.Close()
}

View file

@ -17,23 +17,34 @@ import (
"fmt"
"net/url"
"strings"
"time"
"github.com/opentracing/opentracing-go"
"golang.org/x/crypto/ssh"
)
type logger interface {
Println(v ...interface{})
Printf(f string, v ...interface{})
}
// An Options specifies Live Share connection parameters.
type Options struct {
ClientName string // ClientName is the name of the connecting client.
SessionID string
SessionToken string // token for SSH session
RelaySAS string
RelayEndpoint string
HostPublicKeys []string
Logger logger // required
TLSConfig *tls.Config // (optional)
}
// uri returns a websocket URL for the specified options.
func (opts *Options) uri(action string) (string, error) {
if opts.ClientName == "" {
return "", errors.New("ClientName is required")
}
if opts.SessionID == "" {
return "", errors.New("SessionID is required")
}
@ -56,13 +67,17 @@ func (opts *Options) uri(action string) (string, error) {
// options, and returns a session representing the connection.
// The caller must call the session's Close method to end the session.
func Connect(ctx context.Context, opts Options) (*Session, error) {
span, ctx := opentracing.StartSpanFromContext(ctx, "Connect")
defer span.Finish()
uri, err := opts.uri("connect")
if err != nil {
return nil, err
}
span, ctx := opentracing.StartSpanFromContext(ctx, "Connect")
defer span.Finish()
if opts.Logger == nil {
return nil, errors.New("Logger is required")
}
sock := newSocket(uri, opts.TLSConfig)
if err := sock.connect(ctx); err != nil {
@ -93,7 +108,16 @@ func Connect(ctx context.Context, opts Options) (*Session, error) {
return nil, fmt.Errorf("error joining Live Share workspace: %w", err)
}
return &Session{ssh: ssh, rpc: rpc}, nil
s := &Session{
ssh: ssh,
rpc: rpc,
clientName: opts.ClientName,
keepAliveReason: make(chan string, 1),
logger: opts.Logger,
}
go s.heartbeat(ctx, 1*time.Minute)
return s, nil
}
type clientCapabilities struct {

View file

@ -15,10 +15,12 @@ import (
func TestConnect(t *testing.T) {
opts := Options{
ClientName: "liveshare-client",
SessionID: "session-id",
SessionToken: "session-token",
RelaySAS: "relay-sas",
HostPublicKeys: []string{livesharetest.SSHPublicKey},
Logger: newMockLogger(),
}
joinWorkspace := func(req *jsonrpc2.Request) (interface{}, error) {
var joinWorkspaceReq joinWorkspaceArgs

View file

@ -41,6 +41,7 @@ func checkBadOptions(t *testing.T, opts Options) {
func TestOptionsURI(t *testing.T) {
opts := Options{
ClientName: "liveshare-client",
SessionID: "sess-id",
SessionToken: "sess-token",
RelaySAS: "sas",

View file

@ -15,16 +15,19 @@ type PortForwarder struct {
session *Session
name string
remotePort int
keepAlive bool
}
// NewPortForwarder returns a new PortForwarder for the specified
// remote port and Live Share session. The name describes the purpose
// of the remote port or service.
func NewPortForwarder(session *Session, name string, remotePort int) *PortForwarder {
// of the remote port or service. The keepAlive flag indicates whether
// the session should be kept alive with port forwarding traffic.
func NewPortForwarder(session *Session, name string, remotePort int, keepAlive bool) *PortForwarder {
return &PortForwarder{
session: session,
name: name,
remotePort: remotePort,
keepAlive: keepAlive,
}
}
@ -106,6 +109,27 @@ func awaitError(ctx context.Context, errc <-chan error) error {
}
}
// trafficMonitor implements io.Reader. It keeps the session alive by notifying
// it of the traffic type during Read operations.
type trafficMonitor struct {
reader io.Reader
session *Session
trafficType string
}
// newTrafficMonitor returns a new trafficMonitor for the specified
// session and traffic type. It wraps the provided io.Reader with its own
// Read method.
func newTrafficMonitor(reader io.Reader, session *Session, trafficType string) *trafficMonitor {
return &trafficMonitor{reader, session, trafficType}
}
func (t *trafficMonitor) Read(p []byte) (n int, err error) {
t.session.keepAlive(t.trafficType)
return t.reader.Read(p)
}
// handleConnection handles forwarding for a single accepted connection, then closes it.
func (fwd *PortForwarder) handleConnection(ctx context.Context, id channelID, conn io.ReadWriteCloser) (err error) {
span, ctx := opentracing.StartSpanFromContext(ctx, "PortForwarder.handleConnection")
@ -133,8 +157,21 @@ func (fwd *PortForwarder) handleConnection(ctx context.Context, id channelID, co
_, err := io.Copy(w, r)
errs <- err
}
go copyConn(conn, channel)
go copyConn(channel, conn)
var (
channelReader io.Reader = channel
connReader io.Reader = conn
)
// If the forwader has been configured to keep the session alive
// it will monitor the I/O and notify the session of the traffic.
if fwd.keepAlive {
channelReader = newTrafficMonitor(channelReader, fwd.session, "output")
connReader = newTrafficMonitor(connReader, fwd.session, "input")
}
go copyConn(conn, channelReader)
go copyConn(channel, connReader)
// Wait until context is cancelled or both copies are done.
// Discard errors from io.Copy; they should not cause (e.g.) ForwardToListener to fail.

View file

@ -20,7 +20,7 @@ func TestNewPortForwarder(t *testing.T) {
t.Errorf("create mock client: %w", err)
}
defer testServer.Close()
pf := NewPortForwarder(session, "ssh", 80)
pf := NewPortForwarder(session, "ssh", 80, false)
if pf == nil {
t.Error("port forwarder is nil")
}
@ -58,7 +58,7 @@ func TestPortForwarderStart(t *testing.T) {
done := make(chan error)
go func() {
const name, remote = "ssh", 8000
done <- NewPortForwarder(session, name, remote).ForwardToListener(ctx, listen)
done <- NewPortForwarder(session, name, remote, false).ForwardToListener(ctx, listen)
}()
go func() {
@ -93,3 +93,26 @@ func TestPortForwarderStart(t *testing.T) {
}
}
}
func TestPortForwarderTrafficMonitor(t *testing.T) {
buf := bytes.NewBufferString("some-input")
session := &Session{keepAliveReason: make(chan string, 1)}
trafficType := "io"
tm := newTrafficMonitor(buf, session, trafficType)
l := len(buf.Bytes())
bb := make([]byte, l)
n, err := tm.Read(bb)
if err != nil {
t.Errorf("failed to read from traffic monitor: %w", err)
}
if n != l {
t.Errorf("expected to read %d bytes, got %d", l, n)
}
keepAliveReason := <-session.keepAliveReason
if keepAliveReason != trafficType {
t.Errorf("expected keep alive reason to be %s, got %s", trafficType, keepAliveReason)
}
}

View file

@ -4,12 +4,17 @@ import (
"context"
"fmt"
"strconv"
"time"
)
// A Session represents the session between a connected Live Share client and server.
type Session struct {
ssh *sshSession
rpc *rpcClient
clientName string
keepAliveReason chan string
logger logger
}
// Close should be called by users to clean up RPC and SSH resources whenever the session
@ -97,3 +102,42 @@ func (s *Session) StartSSHServer(ctx context.Context) (int, string, error) {
return port, response.User, nil
}
// heartbeat runs until context cancellation, periodically checking whether there is a
// reason to keep the connection alive, and if so, notifying the Live Share host to do so.
func (s *Session) heartbeat(ctx context.Context, interval time.Duration) {
ticker := time.NewTicker(interval)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
s.logger.Println("Heartbeat tick")
reason := <-s.keepAliveReason
s.logger.Println("Keep alive reason: " + reason)
if err := s.notifyHostOfActivity(ctx, reason); err != nil {
s.logger.Printf("Failed to notify host of activity: %s\n", err)
}
}
}
}
// notifyHostOfActivity notifies the Live Share host of client activity.
func (s *Session) notifyHostOfActivity(ctx context.Context, activity string) error {
activities := []string{activity}
params := []interface{}{s.clientName, activities}
return s.rpc.do(ctx, "ICodespaceHostService.notifyCodespaceOfClientActivity", params, nil)
}
// keepAlive accepts a reason that is retained if there is no active reason
// to send to the server.
func (s *Session) keepAlive(reason string) {
select {
case s.keepAliveReason <- reason:
default:
// there is already an active keep alive reason
// so we can ignore this one
}
}

View file

@ -1,18 +1,23 @@
package liveshare
import (
"bytes"
"context"
"crypto/tls"
"encoding/json"
"errors"
"fmt"
"strings"
"sync"
"testing"
"time"
livesharetest "github.com/cli/cli/v2/pkg/liveshare/test"
"github.com/sourcegraph/jsonrpc2"
)
const mockClientName = "liveshare-client"
func makeMockSession(opts ...livesharetest.ServerOption) (*livesharetest.Server, *Session, error) {
joinWorkspace := func(req *jsonrpc2.Request) (interface{}, error) {
return joinWorkspaceResult{1}, nil
@ -29,12 +34,14 @@ func makeMockSession(opts ...livesharetest.ServerOption) (*livesharetest.Server,
}
session, err := Connect(context.Background(), Options{
ClientName: mockClientName,
SessionID: "session-id",
SessionToken: sessionToken,
RelayEndpoint: "sb" + strings.TrimPrefix(testServer.URL(), "https"),
RelaySAS: "relay-sas",
HostPublicKeys: []string{livesharetest.SSHPublicKey},
TLSConfig: &tls.Config{InsecureSkipVerify: true},
Logger: newMockLogger(),
})
if err != nil {
return nil, nil, fmt.Errorf("error connecting to Live Share: %w", err)
@ -221,3 +228,176 @@ func TestInvalidHostKey(t *testing.T) {
t.Error("expected invalid host key error, got: nil")
}
}
func TestKeepAliveNonBlocking(t *testing.T) {
session := &Session{keepAliveReason: make(chan string, 1)}
for i := 0; i < 2; i++ {
session.keepAlive("io")
}
// if keepAlive blocks, we'll never reach this and timeout the test
// timing out
}
func TestNotifyHostOfActivity(t *testing.T) {
notifyHostOfActivity := func(rpcReq *jsonrpc2.Request) (interface{}, error) {
var req []interface{}
if err := json.Unmarshal(*rpcReq.Params, &req); err != nil {
return nil, fmt.Errorf("unmarshal req: %w", err)
}
if len(req) < 2 {
return nil, errors.New("request arguments is less than 2")
}
if clientName, ok := req[0].(string); ok {
if clientName != mockClientName {
return nil, fmt.Errorf(
"unexpected clientName param, expected: %q, got: %q", mockClientName, clientName,
)
}
} else {
return nil, errors.New("clientName param is not a string")
}
if acs, ok := req[1].([]interface{}); ok {
if fmt.Sprintf("%s", acs) != "[input]" {
return nil, fmt.Errorf("unexpected activities param, expected: [input], got: %s", acs)
}
} else {
return nil, errors.New("activities param is not a slice")
}
return nil, nil
}
svc := livesharetest.WithService(
"ICodespaceHostService.notifyCodespaceOfClientActivity", notifyHostOfActivity,
)
testServer, session, err := makeMockSession(svc)
if err != nil {
t.Errorf("creating mock session: %w", err)
}
defer testServer.Close()
ctx := context.Background()
done := make(chan error)
go func() {
done <- session.notifyHostOfActivity(ctx, "input")
}()
select {
case err := <-testServer.Err():
t.Errorf("error from server: %w", err)
case err := <-done:
if err != nil {
t.Errorf("error from client: %w", err)
}
}
}
func TestSessionHeartbeat(t *testing.T) {
var (
requestsMu sync.Mutex
requests int
)
notifyHostOfActivity := func(rpcReq *jsonrpc2.Request) (interface{}, error) {
requestsMu.Lock()
requests++
requestsMu.Unlock()
var req []interface{}
if err := json.Unmarshal(*rpcReq.Params, &req); err != nil {
return nil, fmt.Errorf("unmarshal req: %w", err)
}
if len(req) < 2 {
return nil, errors.New("request arguments is less than 2")
}
if clientName, ok := req[0].(string); ok {
if clientName != mockClientName {
return nil, fmt.Errorf(
"unexpected clientName param, expected: %q, got: %q", mockClientName, clientName,
)
}
} else {
return nil, errors.New("clientName param is not a string")
}
if acs, ok := req[1].([]interface{}); ok {
if fmt.Sprintf("%s", acs) != "[input]" {
return nil, fmt.Errorf("unexpected activities param, expected: [input], got: %s", acs)
}
} else {
return nil, errors.New("activities param is not a slice")
}
return nil, nil
}
svc := livesharetest.WithService(
"ICodespaceHostService.notifyCodespaceOfClientActivity", notifyHostOfActivity,
)
testServer, session, err := makeMockSession(svc)
if err != nil {
t.Errorf("creating mock session: %w", err)
}
defer testServer.Close()
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
done := make(chan struct{})
logger := newMockLogger()
session.logger = logger
go session.heartbeat(ctx, 50*time.Millisecond)
go func() {
session.keepAlive("input")
<-time.Tick(200 * time.Millisecond)
session.keepAlive("input")
<-time.Tick(100 * time.Millisecond)
done <- struct{}{}
}()
select {
case err := <-testServer.Err():
t.Errorf("error from server: %w", err)
case <-done:
activityCount := strings.Count(logger.String(), "input")
if activityCount != 2 {
t.Errorf("unexpected number of activities, expected: 2, got: %d", activityCount)
}
requestsMu.Lock()
rc := requests
requestsMu.Unlock()
if rc != 2 {
t.Errorf("unexpected number of requests, expected: 2, got: %d", requests)
}
return
}
}
type mockLogger struct {
sync.Mutex
buf *bytes.Buffer
}
func newMockLogger() *mockLogger {
return &mockLogger{buf: new(bytes.Buffer)}
}
func (m *mockLogger) Printf(format string, v ...interface{}) {
m.Lock()
defer m.Unlock()
m.buf.WriteString(fmt.Sprintf(format, v...))
}
func (m *mockLogger) Println(v ...interface{}) {
m.Lock()
defer m.Unlock()
m.buf.WriteString(fmt.Sprintln(v...))
}
func (m *mockLogger) String() string {
m.Lock()
defer m.Unlock()
return m.buf.String()
}