Merge pull request #4461 from cli/jg/liveshare-keepalive
codespace/liveshare keepalive: keep LS sessions alive with PF traffic
This commit is contained in:
commit
a033b85fa2
15 changed files with 433 additions and 32 deletions
|
|
@ -15,6 +15,13 @@ type logger interface {
|
|||
Println(v ...interface{}) (int, error)
|
||||
}
|
||||
|
||||
// TODO(josebalius): clean this up once we standardrize
|
||||
// logging for codespaces
|
||||
type liveshareLogger interface {
|
||||
Println(v ...interface{})
|
||||
Printf(f string, v ...interface{})
|
||||
}
|
||||
|
||||
func connectionReady(codespace *api.Codespace) bool {
|
||||
return codespace.Connection.SessionID != "" &&
|
||||
codespace.Connection.SessionToken != "" &&
|
||||
|
|
@ -30,7 +37,7 @@ type apiClient interface {
|
|||
|
||||
// ConnectToLiveshare waits for a Codespace to become running,
|
||||
// and connects to it using a Live Share session.
|
||||
func ConnectToLiveshare(ctx context.Context, log logger, apiClient apiClient, codespace *api.Codespace) (*liveshare.Session, error) {
|
||||
func ConnectToLiveshare(ctx context.Context, log logger, sessionLogger liveshareLogger, apiClient apiClient, codespace *api.Codespace) (*liveshare.Session, error) {
|
||||
var startedCodespace bool
|
||||
if codespace.Environment.State != api.CodespaceEnvironmentStateAvailable {
|
||||
startedCodespace = true
|
||||
|
|
@ -67,10 +74,12 @@ func ConnectToLiveshare(ctx context.Context, log logger, apiClient apiClient, co
|
|||
log.Println("Connecting to your codespace...")
|
||||
|
||||
return liveshare.Connect(ctx, liveshare.Options{
|
||||
ClientName: "gh",
|
||||
SessionID: codespace.Connection.SessionID,
|
||||
SessionToken: codespace.Connection.SessionToken,
|
||||
RelaySAS: codespace.Connection.RelaySAS,
|
||||
RelayEndpoint: codespace.Connection.RelayEndpoint,
|
||||
HostPublicKeys: codespace.Connection.HostPublicKeys,
|
||||
Logger: sessionLogger,
|
||||
})
|
||||
}
|
||||
|
|
|
|||
|
|
@ -5,6 +5,8 @@ import (
|
|||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"log"
|
||||
"net"
|
||||
"strings"
|
||||
"time"
|
||||
|
|
@ -36,8 +38,10 @@ type PostCreateState struct {
|
|||
// PollPostCreateStates watches for state changes in a codespace,
|
||||
// and calls the supplied poller for each batch of state changes.
|
||||
// It runs until it encounters an error, including cancellation of the context.
|
||||
func PollPostCreateStates(ctx context.Context, log logger, apiClient apiClient, codespace *api.Codespace, poller func([]PostCreateState)) (err error) {
|
||||
session, err := ConnectToLiveshare(ctx, log, apiClient, codespace)
|
||||
func PollPostCreateStates(ctx context.Context, logger logger, apiClient apiClient, codespace *api.Codespace, poller func([]PostCreateState)) (err error) {
|
||||
noopLogger := log.New(ioutil.Discard, "", 0)
|
||||
|
||||
session, err := ConnectToLiveshare(ctx, logger, noopLogger, apiClient, codespace)
|
||||
if err != nil {
|
||||
return fmt.Errorf("connect to Live Share: %w", err)
|
||||
}
|
||||
|
|
@ -54,7 +58,7 @@ func PollPostCreateStates(ctx context.Context, log logger, apiClient apiClient,
|
|||
}
|
||||
localPort := listen.Addr().(*net.TCPAddr).Port
|
||||
|
||||
log.Println("Fetching SSH Details...")
|
||||
logger.Println("Fetching SSH Details...")
|
||||
remoteSSHServerPort, sshUser, err := session.StartSSHServer(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error getting ssh server details: %w", err)
|
||||
|
|
@ -62,7 +66,7 @@ func PollPostCreateStates(ctx context.Context, log logger, apiClient apiClient,
|
|||
|
||||
tunnelClosed := make(chan error, 1) // buffered to avoid sender stuckness
|
||||
go func() {
|
||||
fwd := liveshare.NewPortForwarder(session, "sshd", remoteSSHServerPort)
|
||||
fwd := liveshare.NewPortForwarder(session, "sshd", remoteSSHServerPort, false)
|
||||
tunnelClosed <- fwd.ForwardToListener(ctx, listen) // error is non-nil
|
||||
}()
|
||||
|
||||
|
|
|
|||
|
|
@ -7,6 +7,8 @@ import (
|
|||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"log"
|
||||
"os"
|
||||
"sort"
|
||||
"strings"
|
||||
|
|
@ -211,6 +213,10 @@ func noArgsConstraint(cmd *cobra.Command, args []string) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func noopLogger() *log.Logger {
|
||||
return log.New(ioutil.Discard, "", 0)
|
||||
}
|
||||
|
||||
type codespace struct {
|
||||
*api.Codespace
|
||||
}
|
||||
|
|
|
|||
|
|
@ -44,6 +44,7 @@ func TestDelete(t *testing.T) {
|
|||
},
|
||||
},
|
||||
wantDeleted: []string{"hubot-robawt-abc"},
|
||||
wantStdout: "Codespace deleted.\n",
|
||||
},
|
||||
{
|
||||
name: "by repo",
|
||||
|
|
@ -65,6 +66,7 @@ func TestDelete(t *testing.T) {
|
|||
},
|
||||
},
|
||||
wantDeleted: []string{"monalisa-spoonknife-123", "monalisa-spoonknife-c4f3"},
|
||||
wantStdout: "Codespaces deleted.\n",
|
||||
},
|
||||
{
|
||||
name: "unused",
|
||||
|
|
@ -87,6 +89,7 @@ func TestDelete(t *testing.T) {
|
|||
},
|
||||
},
|
||||
wantDeleted: []string{"hubot-robawt-abc", "monalisa-spoonknife-c4f3"},
|
||||
wantStdout: "Codespaces deleted.\n",
|
||||
},
|
||||
{
|
||||
name: "deletion failed",
|
||||
|
|
@ -148,6 +151,7 @@ func TestDelete(t *testing.T) {
|
|||
"Codespace hubot-robawt-abc has unsaved changes. OK to delete?": true,
|
||||
},
|
||||
wantDeleted: []string{"hubot-robawt-abc", "monalisa-spoonknife-c4f3"},
|
||||
wantStdout: "Codespaces deleted.\n",
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
|
|
|
|||
|
|
@ -51,7 +51,7 @@ func (a *App) Logs(ctx context.Context, codespaceName string, follow bool) (err
|
|||
return fmt.Errorf("get or choose codespace: %w", err)
|
||||
}
|
||||
|
||||
session, err := codespaces.ConnectToLiveshare(ctx, a.logger, a.apiClient, codespace)
|
||||
session, err := codespaces.ConnectToLiveshare(ctx, a.logger, noopLogger(), a.apiClient, codespace)
|
||||
if err != nil {
|
||||
return fmt.Errorf("connecting to Live Share: %w", err)
|
||||
}
|
||||
|
|
@ -90,7 +90,7 @@ func (a *App) Logs(ctx context.Context, codespaceName string, follow bool) (err
|
|||
|
||||
tunnelClosed := make(chan error, 1)
|
||||
go func() {
|
||||
fwd := liveshare.NewPortForwarder(session, "sshd", remoteSSHServerPort)
|
||||
fwd := liveshare.NewPortForwarder(session, "sshd", remoteSSHServerPort, false)
|
||||
tunnelClosed <- fwd.ForwardToListener(ctx, listen) // error is non-nil
|
||||
}()
|
||||
|
||||
|
|
|
|||
|
|
@ -9,10 +9,14 @@ import (
|
|||
// NewLogger returns a Logger that will write to the given stdout/stderr writers.
|
||||
// Disable the Logger to prevent it from writing to stdout in a TTY environment.
|
||||
func NewLogger(stdout, stderr io.Writer, disabled bool) *Logger {
|
||||
enabled := !disabled
|
||||
if isTTY(stdout) && !enabled {
|
||||
enabled = false
|
||||
}
|
||||
return &Logger{
|
||||
out: stdout,
|
||||
errout: stderr,
|
||||
enabled: !disabled && isTTY(stdout),
|
||||
enabled: enabled,
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -60,7 +60,7 @@ func (a *App) ListPorts(ctx context.Context, codespaceName string, asJSON bool)
|
|||
|
||||
devContainerCh := getDevContainer(ctx, a.apiClient, codespace)
|
||||
|
||||
session, err := codespaces.ConnectToLiveshare(ctx, a.logger, a.apiClient, codespace)
|
||||
session, err := codespaces.ConnectToLiveshare(ctx, a.logger, noopLogger(), a.apiClient, codespace)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error connecting to Live Share: %w", err)
|
||||
}
|
||||
|
|
@ -194,7 +194,7 @@ func (a *App) UpdatePortVisibility(ctx context.Context, codespaceName, sourcePor
|
|||
return fmt.Errorf("error getting codespace: %w", err)
|
||||
}
|
||||
|
||||
session, err := codespaces.ConnectToLiveshare(ctx, a.logger, a.apiClient, codespace)
|
||||
session, err := codespaces.ConnectToLiveshare(ctx, a.logger, noopLogger(), a.apiClient, codespace)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error connecting to Live Share: %w", err)
|
||||
}
|
||||
|
|
@ -253,7 +253,7 @@ func (a *App) ForwardPorts(ctx context.Context, codespaceName string, ports []st
|
|||
return fmt.Errorf("error getting codespace: %w", err)
|
||||
}
|
||||
|
||||
session, err := codespaces.ConnectToLiveshare(ctx, a.logger, a.apiClient, codespace)
|
||||
session, err := codespaces.ConnectToLiveshare(ctx, a.logger, noopLogger(), a.apiClient, codespace)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error connecting to Live Share: %w", err)
|
||||
}
|
||||
|
|
@ -272,7 +272,7 @@ func (a *App) ForwardPorts(ctx context.Context, codespaceName string, ports []st
|
|||
defer listen.Close()
|
||||
a.logger.Printf("Forwarding ports: remote %d <=> local %d\n", pair.remote, pair.local)
|
||||
name := fmt.Sprintf("share-%d", pair.remote)
|
||||
fwd := liveshare.NewPortForwarder(session, name, pair.remote)
|
||||
fwd := liveshare.NewPortForwarder(session, name, pair.remote, false)
|
||||
return fwd.ForwardToListener(ctx, listen) // error always non-nil
|
||||
})
|
||||
}
|
||||
|
|
|
|||
|
|
@ -3,34 +3,46 @@ package codespace
|
|||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"log"
|
||||
"net"
|
||||
"os"
|
||||
|
||||
"github.com/cli/cli/v2/internal/codespaces"
|
||||
"github.com/cli/cli/v2/pkg/liveshare"
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
type sshOptions struct {
|
||||
codespace string
|
||||
profile string
|
||||
serverPort int
|
||||
debug bool
|
||||
debugFile string
|
||||
}
|
||||
|
||||
func newSSHCmd(app *App) *cobra.Command {
|
||||
var sshProfile, codespaceName string
|
||||
var sshServerPort int
|
||||
var opts sshOptions
|
||||
|
||||
sshCmd := &cobra.Command{
|
||||
Use: "ssh [flags] [--] [ssh-flags] [command]",
|
||||
Short: "SSH into a codespace",
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
return app.SSH(cmd.Context(), args, sshProfile, codespaceName, sshServerPort)
|
||||
return app.SSH(cmd.Context(), args, opts)
|
||||
},
|
||||
}
|
||||
|
||||
sshCmd.Flags().StringVarP(&sshProfile, "profile", "", "", "Name of the SSH profile to use")
|
||||
sshCmd.Flags().IntVarP(&sshServerPort, "server-port", "", 0, "SSH server port number (0 => pick unused)")
|
||||
sshCmd.Flags().StringVarP(&codespaceName, "codespace", "c", "", "Name of the codespace")
|
||||
sshCmd.Flags().StringVarP(&opts.profile, "profile", "", "", "Name of the SSH profile to use")
|
||||
sshCmd.Flags().IntVarP(&opts.serverPort, "server-port", "", 0, "SSH server port number (0 => pick unused)")
|
||||
sshCmd.Flags().StringVarP(&opts.codespace, "codespace", "c", "", "Name of the codespace")
|
||||
sshCmd.Flags().BoolVarP(&opts.debug, "debug", "d", false, "Log debug data to a file")
|
||||
sshCmd.Flags().StringVarP(&opts.debugFile, "debug-file", "", "", "Path of the file log to")
|
||||
|
||||
return sshCmd
|
||||
}
|
||||
|
||||
// SSH opens an ssh session or runs an ssh command in a codespace.
|
||||
func (a *App) SSH(ctx context.Context, sshArgs []string, sshProfile, codespaceName string, localSSHServerPort int) (err error) {
|
||||
func (a *App) SSH(ctx context.Context, sshArgs []string, opts sshOptions) (err error) {
|
||||
// Ensure all child tasks (e.g. port forwarding) terminate before return.
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
|
|
@ -45,12 +57,22 @@ func (a *App) SSH(ctx context.Context, sshArgs []string, sshProfile, codespaceNa
|
|||
authkeys <- checkAuthorizedKeys(ctx, a.apiClient, user.Login)
|
||||
}()
|
||||
|
||||
codespace, err := getOrChooseCodespace(ctx, a.apiClient, codespaceName)
|
||||
codespace, err := getOrChooseCodespace(ctx, a.apiClient, opts.codespace)
|
||||
if err != nil {
|
||||
return fmt.Errorf("get or choose codespace: %w", err)
|
||||
}
|
||||
|
||||
session, err := codespaces.ConnectToLiveshare(ctx, a.logger, a.apiClient, codespace)
|
||||
var debugLogger *fileLogger
|
||||
if opts.debug {
|
||||
debugLogger, err = newFileLogger(opts.debugFile)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error creating debug logger: %w", err)
|
||||
}
|
||||
defer safeClose(debugLogger, &err)
|
||||
a.logger.Println("Debug file located at: " + debugLogger.Name())
|
||||
}
|
||||
|
||||
session, err := codespaces.ConnectToLiveshare(ctx, a.logger, debugLogger, a.apiClient, codespace)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error connecting to Live Share: %w", err)
|
||||
}
|
||||
|
|
@ -66,6 +88,7 @@ func (a *App) SSH(ctx context.Context, sshArgs []string, sshProfile, codespaceNa
|
|||
return fmt.Errorf("error getting ssh server details: %w", err)
|
||||
}
|
||||
|
||||
localSSHServerPort := opts.serverPort
|
||||
usingCustomPort := localSSHServerPort != 0 // suppress log of command line in Shell
|
||||
|
||||
// Ensure local port is listening before client (Shell) connects.
|
||||
|
|
@ -76,7 +99,7 @@ func (a *App) SSH(ctx context.Context, sshArgs []string, sshProfile, codespaceNa
|
|||
defer listen.Close()
|
||||
localSSHServerPort = listen.Addr().(*net.TCPAddr).Port
|
||||
|
||||
connectDestination := sshProfile
|
||||
connectDestination := opts.profile
|
||||
if connectDestination == "" {
|
||||
connectDestination = fmt.Sprintf("%s@localhost", sshUser)
|
||||
}
|
||||
|
|
@ -84,7 +107,7 @@ func (a *App) SSH(ctx context.Context, sshArgs []string, sshProfile, codespaceNa
|
|||
a.logger.Println("Ready...")
|
||||
tunnelClosed := make(chan error, 1)
|
||||
go func() {
|
||||
fwd := liveshare.NewPortForwarder(session, "sshd", remoteSSHServerPort)
|
||||
fwd := liveshare.NewPortForwarder(session, "sshd", remoteSSHServerPort, true)
|
||||
tunnelClosed <- fwd.ForwardToListener(ctx, listen) // always non-nil
|
||||
}()
|
||||
|
||||
|
|
@ -103,3 +126,43 @@ func (a *App) SSH(ctx context.Context, sshArgs []string, sshProfile, codespaceNa
|
|||
return nil // success
|
||||
}
|
||||
}
|
||||
|
||||
// fileLogger is a wrapper around an log.Logger configured to write
|
||||
// to a file. It exports two additional methods to get the log file name
|
||||
// and close the file handle when the operation is finished.
|
||||
type fileLogger struct {
|
||||
*log.Logger
|
||||
|
||||
f *os.File
|
||||
}
|
||||
|
||||
// newFileLogger creates a new fileLogger. It returns an error if the file
|
||||
// cannot be created. The file is created on the specified path, if the path
|
||||
// is empty it is created in the temporary directory.
|
||||
func newFileLogger(file string) (fl *fileLogger, err error) {
|
||||
var f *os.File
|
||||
if file == "" {
|
||||
f, err = ioutil.TempFile("", "")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create tmp file: %w", err)
|
||||
}
|
||||
} else {
|
||||
f, err = os.Create(file)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return &fileLogger{
|
||||
Logger: log.New(f, "", log.LstdFlags),
|
||||
f: f,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (fl *fileLogger) Name() string {
|
||||
return fl.f.Name()
|
||||
}
|
||||
|
||||
func (fl *fileLogger) Close() error {
|
||||
return fl.f.Close()
|
||||
}
|
||||
|
|
|
|||
|
|
@ -17,23 +17,34 @@ import (
|
|||
"fmt"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/opentracing/opentracing-go"
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
type logger interface {
|
||||
Println(v ...interface{})
|
||||
Printf(f string, v ...interface{})
|
||||
}
|
||||
|
||||
// An Options specifies Live Share connection parameters.
|
||||
type Options struct {
|
||||
ClientName string // ClientName is the name of the connecting client.
|
||||
SessionID string
|
||||
SessionToken string // token for SSH session
|
||||
RelaySAS string
|
||||
RelayEndpoint string
|
||||
HostPublicKeys []string
|
||||
Logger logger // required
|
||||
TLSConfig *tls.Config // (optional)
|
||||
}
|
||||
|
||||
// uri returns a websocket URL for the specified options.
|
||||
func (opts *Options) uri(action string) (string, error) {
|
||||
if opts.ClientName == "" {
|
||||
return "", errors.New("ClientName is required")
|
||||
}
|
||||
if opts.SessionID == "" {
|
||||
return "", errors.New("SessionID is required")
|
||||
}
|
||||
|
|
@ -56,13 +67,17 @@ func (opts *Options) uri(action string) (string, error) {
|
|||
// options, and returns a session representing the connection.
|
||||
// The caller must call the session's Close method to end the session.
|
||||
func Connect(ctx context.Context, opts Options) (*Session, error) {
|
||||
span, ctx := opentracing.StartSpanFromContext(ctx, "Connect")
|
||||
defer span.Finish()
|
||||
|
||||
uri, err := opts.uri("connect")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
span, ctx := opentracing.StartSpanFromContext(ctx, "Connect")
|
||||
defer span.Finish()
|
||||
if opts.Logger == nil {
|
||||
return nil, errors.New("Logger is required")
|
||||
}
|
||||
|
||||
sock := newSocket(uri, opts.TLSConfig)
|
||||
if err := sock.connect(ctx); err != nil {
|
||||
|
|
@ -93,7 +108,16 @@ func Connect(ctx context.Context, opts Options) (*Session, error) {
|
|||
return nil, fmt.Errorf("error joining Live Share workspace: %w", err)
|
||||
}
|
||||
|
||||
return &Session{ssh: ssh, rpc: rpc}, nil
|
||||
s := &Session{
|
||||
ssh: ssh,
|
||||
rpc: rpc,
|
||||
clientName: opts.ClientName,
|
||||
keepAliveReason: make(chan string, 1),
|
||||
logger: opts.Logger,
|
||||
}
|
||||
go s.heartbeat(ctx, 1*time.Minute)
|
||||
|
||||
return s, nil
|
||||
}
|
||||
|
||||
type clientCapabilities struct {
|
||||
|
|
|
|||
|
|
@ -15,10 +15,12 @@ import (
|
|||
|
||||
func TestConnect(t *testing.T) {
|
||||
opts := Options{
|
||||
ClientName: "liveshare-client",
|
||||
SessionID: "session-id",
|
||||
SessionToken: "session-token",
|
||||
RelaySAS: "relay-sas",
|
||||
HostPublicKeys: []string{livesharetest.SSHPublicKey},
|
||||
Logger: newMockLogger(),
|
||||
}
|
||||
joinWorkspace := func(req *jsonrpc2.Request) (interface{}, error) {
|
||||
var joinWorkspaceReq joinWorkspaceArgs
|
||||
|
|
|
|||
|
|
@ -41,6 +41,7 @@ func checkBadOptions(t *testing.T, opts Options) {
|
|||
|
||||
func TestOptionsURI(t *testing.T) {
|
||||
opts := Options{
|
||||
ClientName: "liveshare-client",
|
||||
SessionID: "sess-id",
|
||||
SessionToken: "sess-token",
|
||||
RelaySAS: "sas",
|
||||
|
|
|
|||
|
|
@ -15,16 +15,19 @@ type PortForwarder struct {
|
|||
session *Session
|
||||
name string
|
||||
remotePort int
|
||||
keepAlive bool
|
||||
}
|
||||
|
||||
// NewPortForwarder returns a new PortForwarder for the specified
|
||||
// remote port and Live Share session. The name describes the purpose
|
||||
// of the remote port or service.
|
||||
func NewPortForwarder(session *Session, name string, remotePort int) *PortForwarder {
|
||||
// of the remote port or service. The keepAlive flag indicates whether
|
||||
// the session should be kept alive with port forwarding traffic.
|
||||
func NewPortForwarder(session *Session, name string, remotePort int, keepAlive bool) *PortForwarder {
|
||||
return &PortForwarder{
|
||||
session: session,
|
||||
name: name,
|
||||
remotePort: remotePort,
|
||||
keepAlive: keepAlive,
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -106,6 +109,27 @@ func awaitError(ctx context.Context, errc <-chan error) error {
|
|||
}
|
||||
}
|
||||
|
||||
// trafficMonitor implements io.Reader. It keeps the session alive by notifying
|
||||
// it of the traffic type during Read operations.
|
||||
type trafficMonitor struct {
|
||||
reader io.Reader
|
||||
|
||||
session *Session
|
||||
trafficType string
|
||||
}
|
||||
|
||||
// newTrafficMonitor returns a new trafficMonitor for the specified
|
||||
// session and traffic type. It wraps the provided io.Reader with its own
|
||||
// Read method.
|
||||
func newTrafficMonitor(reader io.Reader, session *Session, trafficType string) *trafficMonitor {
|
||||
return &trafficMonitor{reader, session, trafficType}
|
||||
}
|
||||
|
||||
func (t *trafficMonitor) Read(p []byte) (n int, err error) {
|
||||
t.session.keepAlive(t.trafficType)
|
||||
return t.reader.Read(p)
|
||||
}
|
||||
|
||||
// handleConnection handles forwarding for a single accepted connection, then closes it.
|
||||
func (fwd *PortForwarder) handleConnection(ctx context.Context, id channelID, conn io.ReadWriteCloser) (err error) {
|
||||
span, ctx := opentracing.StartSpanFromContext(ctx, "PortForwarder.handleConnection")
|
||||
|
|
@ -133,8 +157,21 @@ func (fwd *PortForwarder) handleConnection(ctx context.Context, id channelID, co
|
|||
_, err := io.Copy(w, r)
|
||||
errs <- err
|
||||
}
|
||||
go copyConn(conn, channel)
|
||||
go copyConn(channel, conn)
|
||||
|
||||
var (
|
||||
channelReader io.Reader = channel
|
||||
connReader io.Reader = conn
|
||||
)
|
||||
|
||||
// If the forwader has been configured to keep the session alive
|
||||
// it will monitor the I/O and notify the session of the traffic.
|
||||
if fwd.keepAlive {
|
||||
channelReader = newTrafficMonitor(channelReader, fwd.session, "output")
|
||||
connReader = newTrafficMonitor(connReader, fwd.session, "input")
|
||||
}
|
||||
|
||||
go copyConn(conn, channelReader)
|
||||
go copyConn(channel, connReader)
|
||||
|
||||
// Wait until context is cancelled or both copies are done.
|
||||
// Discard errors from io.Copy; they should not cause (e.g.) ForwardToListener to fail.
|
||||
|
|
|
|||
|
|
@ -20,7 +20,7 @@ func TestNewPortForwarder(t *testing.T) {
|
|||
t.Errorf("create mock client: %w", err)
|
||||
}
|
||||
defer testServer.Close()
|
||||
pf := NewPortForwarder(session, "ssh", 80)
|
||||
pf := NewPortForwarder(session, "ssh", 80, false)
|
||||
if pf == nil {
|
||||
t.Error("port forwarder is nil")
|
||||
}
|
||||
|
|
@ -58,7 +58,7 @@ func TestPortForwarderStart(t *testing.T) {
|
|||
done := make(chan error)
|
||||
go func() {
|
||||
const name, remote = "ssh", 8000
|
||||
done <- NewPortForwarder(session, name, remote).ForwardToListener(ctx, listen)
|
||||
done <- NewPortForwarder(session, name, remote, false).ForwardToListener(ctx, listen)
|
||||
}()
|
||||
|
||||
go func() {
|
||||
|
|
@ -93,3 +93,26 @@ func TestPortForwarderStart(t *testing.T) {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestPortForwarderTrafficMonitor(t *testing.T) {
|
||||
buf := bytes.NewBufferString("some-input")
|
||||
session := &Session{keepAliveReason: make(chan string, 1)}
|
||||
trafficType := "io"
|
||||
|
||||
tm := newTrafficMonitor(buf, session, trafficType)
|
||||
l := len(buf.Bytes())
|
||||
|
||||
bb := make([]byte, l)
|
||||
n, err := tm.Read(bb)
|
||||
if err != nil {
|
||||
t.Errorf("failed to read from traffic monitor: %w", err)
|
||||
}
|
||||
if n != l {
|
||||
t.Errorf("expected to read %d bytes, got %d", l, n)
|
||||
}
|
||||
|
||||
keepAliveReason := <-session.keepAliveReason
|
||||
if keepAliveReason != trafficType {
|
||||
t.Errorf("expected keep alive reason to be %s, got %s", trafficType, keepAliveReason)
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -4,12 +4,17 @@ import (
|
|||
"context"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"time"
|
||||
)
|
||||
|
||||
// A Session represents the session between a connected Live Share client and server.
|
||||
type Session struct {
|
||||
ssh *sshSession
|
||||
rpc *rpcClient
|
||||
|
||||
clientName string
|
||||
keepAliveReason chan string
|
||||
logger logger
|
||||
}
|
||||
|
||||
// Close should be called by users to clean up RPC and SSH resources whenever the session
|
||||
|
|
@ -97,3 +102,42 @@ func (s *Session) StartSSHServer(ctx context.Context) (int, string, error) {
|
|||
|
||||
return port, response.User, nil
|
||||
}
|
||||
|
||||
// heartbeat runs until context cancellation, periodically checking whether there is a
|
||||
// reason to keep the connection alive, and if so, notifying the Live Share host to do so.
|
||||
func (s *Session) heartbeat(ctx context.Context, interval time.Duration) {
|
||||
ticker := time.NewTicker(interval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
s.logger.Println("Heartbeat tick")
|
||||
reason := <-s.keepAliveReason
|
||||
s.logger.Println("Keep alive reason: " + reason)
|
||||
if err := s.notifyHostOfActivity(ctx, reason); err != nil {
|
||||
s.logger.Printf("Failed to notify host of activity: %s\n", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// notifyHostOfActivity notifies the Live Share host of client activity.
|
||||
func (s *Session) notifyHostOfActivity(ctx context.Context, activity string) error {
|
||||
activities := []string{activity}
|
||||
params := []interface{}{s.clientName, activities}
|
||||
return s.rpc.do(ctx, "ICodespaceHostService.notifyCodespaceOfClientActivity", params, nil)
|
||||
}
|
||||
|
||||
// keepAlive accepts a reason that is retained if there is no active reason
|
||||
// to send to the server.
|
||||
func (s *Session) keepAlive(reason string) {
|
||||
select {
|
||||
case s.keepAliveReason <- reason:
|
||||
default:
|
||||
// there is already an active keep alive reason
|
||||
// so we can ignore this one
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,18 +1,23 @@
|
|||
package liveshare
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
livesharetest "github.com/cli/cli/v2/pkg/liveshare/test"
|
||||
"github.com/sourcegraph/jsonrpc2"
|
||||
)
|
||||
|
||||
const mockClientName = "liveshare-client"
|
||||
|
||||
func makeMockSession(opts ...livesharetest.ServerOption) (*livesharetest.Server, *Session, error) {
|
||||
joinWorkspace := func(req *jsonrpc2.Request) (interface{}, error) {
|
||||
return joinWorkspaceResult{1}, nil
|
||||
|
|
@ -29,12 +34,14 @@ func makeMockSession(opts ...livesharetest.ServerOption) (*livesharetest.Server,
|
|||
}
|
||||
|
||||
session, err := Connect(context.Background(), Options{
|
||||
ClientName: mockClientName,
|
||||
SessionID: "session-id",
|
||||
SessionToken: sessionToken,
|
||||
RelayEndpoint: "sb" + strings.TrimPrefix(testServer.URL(), "https"),
|
||||
RelaySAS: "relay-sas",
|
||||
HostPublicKeys: []string{livesharetest.SSHPublicKey},
|
||||
TLSConfig: &tls.Config{InsecureSkipVerify: true},
|
||||
Logger: newMockLogger(),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("error connecting to Live Share: %w", err)
|
||||
|
|
@ -221,3 +228,176 @@ func TestInvalidHostKey(t *testing.T) {
|
|||
t.Error("expected invalid host key error, got: nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestKeepAliveNonBlocking(t *testing.T) {
|
||||
session := &Session{keepAliveReason: make(chan string, 1)}
|
||||
for i := 0; i < 2; i++ {
|
||||
session.keepAlive("io")
|
||||
}
|
||||
|
||||
// if keepAlive blocks, we'll never reach this and timeout the test
|
||||
// timing out
|
||||
}
|
||||
|
||||
func TestNotifyHostOfActivity(t *testing.T) {
|
||||
notifyHostOfActivity := func(rpcReq *jsonrpc2.Request) (interface{}, error) {
|
||||
var req []interface{}
|
||||
if err := json.Unmarshal(*rpcReq.Params, &req); err != nil {
|
||||
return nil, fmt.Errorf("unmarshal req: %w", err)
|
||||
}
|
||||
if len(req) < 2 {
|
||||
return nil, errors.New("request arguments is less than 2")
|
||||
}
|
||||
|
||||
if clientName, ok := req[0].(string); ok {
|
||||
if clientName != mockClientName {
|
||||
return nil, fmt.Errorf(
|
||||
"unexpected clientName param, expected: %q, got: %q", mockClientName, clientName,
|
||||
)
|
||||
}
|
||||
} else {
|
||||
return nil, errors.New("clientName param is not a string")
|
||||
}
|
||||
|
||||
if acs, ok := req[1].([]interface{}); ok {
|
||||
if fmt.Sprintf("%s", acs) != "[input]" {
|
||||
return nil, fmt.Errorf("unexpected activities param, expected: [input], got: %s", acs)
|
||||
}
|
||||
} else {
|
||||
return nil, errors.New("activities param is not a slice")
|
||||
}
|
||||
|
||||
return nil, nil
|
||||
}
|
||||
svc := livesharetest.WithService(
|
||||
"ICodespaceHostService.notifyCodespaceOfClientActivity", notifyHostOfActivity,
|
||||
)
|
||||
testServer, session, err := makeMockSession(svc)
|
||||
if err != nil {
|
||||
t.Errorf("creating mock session: %w", err)
|
||||
}
|
||||
defer testServer.Close()
|
||||
ctx := context.Background()
|
||||
done := make(chan error)
|
||||
go func() {
|
||||
done <- session.notifyHostOfActivity(ctx, "input")
|
||||
}()
|
||||
select {
|
||||
case err := <-testServer.Err():
|
||||
t.Errorf("error from server: %w", err)
|
||||
case err := <-done:
|
||||
if err != nil {
|
||||
t.Errorf("error from client: %w", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestSessionHeartbeat(t *testing.T) {
|
||||
var (
|
||||
requestsMu sync.Mutex
|
||||
requests int
|
||||
)
|
||||
notifyHostOfActivity := func(rpcReq *jsonrpc2.Request) (interface{}, error) {
|
||||
requestsMu.Lock()
|
||||
requests++
|
||||
requestsMu.Unlock()
|
||||
|
||||
var req []interface{}
|
||||
if err := json.Unmarshal(*rpcReq.Params, &req); err != nil {
|
||||
return nil, fmt.Errorf("unmarshal req: %w", err)
|
||||
}
|
||||
if len(req) < 2 {
|
||||
return nil, errors.New("request arguments is less than 2")
|
||||
}
|
||||
|
||||
if clientName, ok := req[0].(string); ok {
|
||||
if clientName != mockClientName {
|
||||
return nil, fmt.Errorf(
|
||||
"unexpected clientName param, expected: %q, got: %q", mockClientName, clientName,
|
||||
)
|
||||
}
|
||||
} else {
|
||||
return nil, errors.New("clientName param is not a string")
|
||||
}
|
||||
|
||||
if acs, ok := req[1].([]interface{}); ok {
|
||||
if fmt.Sprintf("%s", acs) != "[input]" {
|
||||
return nil, fmt.Errorf("unexpected activities param, expected: [input], got: %s", acs)
|
||||
}
|
||||
} else {
|
||||
return nil, errors.New("activities param is not a slice")
|
||||
}
|
||||
|
||||
return nil, nil
|
||||
}
|
||||
svc := livesharetest.WithService(
|
||||
"ICodespaceHostService.notifyCodespaceOfClientActivity", notifyHostOfActivity,
|
||||
)
|
||||
testServer, session, err := makeMockSession(svc)
|
||||
if err != nil {
|
||||
t.Errorf("creating mock session: %w", err)
|
||||
}
|
||||
defer testServer.Close()
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
done := make(chan struct{})
|
||||
|
||||
logger := newMockLogger()
|
||||
session.logger = logger
|
||||
|
||||
go session.heartbeat(ctx, 50*time.Millisecond)
|
||||
go func() {
|
||||
session.keepAlive("input")
|
||||
<-time.Tick(200 * time.Millisecond)
|
||||
session.keepAlive("input")
|
||||
<-time.Tick(100 * time.Millisecond)
|
||||
done <- struct{}{}
|
||||
}()
|
||||
|
||||
select {
|
||||
case err := <-testServer.Err():
|
||||
t.Errorf("error from server: %w", err)
|
||||
case <-done:
|
||||
activityCount := strings.Count(logger.String(), "input")
|
||||
if activityCount != 2 {
|
||||
t.Errorf("unexpected number of activities, expected: 2, got: %d", activityCount)
|
||||
}
|
||||
|
||||
requestsMu.Lock()
|
||||
rc := requests
|
||||
requestsMu.Unlock()
|
||||
if rc != 2 {
|
||||
t.Errorf("unexpected number of requests, expected: 2, got: %d", requests)
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
type mockLogger struct {
|
||||
sync.Mutex
|
||||
buf *bytes.Buffer
|
||||
}
|
||||
|
||||
func newMockLogger() *mockLogger {
|
||||
return &mockLogger{buf: new(bytes.Buffer)}
|
||||
}
|
||||
|
||||
func (m *mockLogger) Printf(format string, v ...interface{}) {
|
||||
m.Lock()
|
||||
defer m.Unlock()
|
||||
m.buf.WriteString(fmt.Sprintf(format, v...))
|
||||
}
|
||||
|
||||
func (m *mockLogger) Println(v ...interface{}) {
|
||||
m.Lock()
|
||||
defer m.Unlock()
|
||||
m.buf.WriteString(fmt.Sprintln(v...))
|
||||
}
|
||||
|
||||
func (m *mockLogger) String() string {
|
||||
m.Lock()
|
||||
defer m.Unlock()
|
||||
return m.buf.String()
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue