Extract LiveshareSession interface (#5725)

This will make it possible to inject a mock liveshare session for testing
This commit is contained in:
Greggory Rothmeier 2022-06-06 06:52:52 -07:00 committed by GitHub
parent 866eccc202
commit f15a8ca335
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 114 additions and 90 deletions

View file

@ -19,6 +19,7 @@ import (
"github.com/cli/cli/v2/pkg/iostreams"
"github.com/cli/cli/v2/pkg/liveshare"
"github.com/spf13/cobra"
"golang.org/x/crypto/ssh"
"golang.org/x/term"
)
@ -60,8 +61,18 @@ func (a *App) StopProgressIndicator() {
a.io.StopProgressIndicator()
}
type liveshareSession interface {
Close() error
GetSharedServers(context.Context) ([]*liveshare.Port, error)
KeepAlive(string)
OpenStreamingChannel(context.Context, liveshare.ChannelID) (ssh.Channel, error)
StartJupyterServer(context.Context) (int, string, error)
StartSharing(context.Context, string, int) (liveshare.ChannelID, error)
StartSSHServer(context.Context) (int, string, error)
}
// Connects to a codespace using Live Share and returns that session
func startLiveShareSession(ctx context.Context, codespace *api.Codespace, a *App, debug bool, debugFile string) (session *liveshare.Session, err error) {
func startLiveShareSession(ctx context.Context, codespace *api.Codespace, a *App, debug bool, debugFile string) (session liveshareSession, err error) {
// While connecting, ensure in the background that the user has keys installed.
// That lets us report a more useful error message if they don't.
authkeys := make(chan error, 1)

View file

@ -20,7 +20,6 @@ import (
"time"
"github.com/opentracing/opentracing-go"
"golang.org/x/crypto/ssh"
)
type logger interface {
@ -136,41 +135,3 @@ type joinWorkspaceArgs struct {
type joinWorkspaceResult struct {
SessionNumber int `json:"sessionNumber"`
}
// A channelID is an identifier for an exposed port on a remote
// container that may be used to open an SSH channel to it.
type channelID struct {
name, condition string
}
func (s *Session) openStreamingChannel(ctx context.Context, id channelID) (ssh.Channel, error) {
type getStreamArgs struct {
StreamName string `json:"streamName"`
Condition string `json:"condition"`
}
args := getStreamArgs{
StreamName: id.name,
Condition: id.condition,
}
var streamID string
if err := s.rpc.do(ctx, "streamManager.getStream", args, &streamID); err != nil {
return nil, fmt.Errorf("error getting stream id: %w", err)
}
span, ctx := opentracing.StartSpanFromContext(ctx, "Session.OpenChannel+SendRequest")
defer span.Finish()
_ = ctx // ctx is not currently used
channel, reqs, err := s.ssh.conn.OpenChannel("session", nil)
if err != nil {
return nil, fmt.Errorf("error opening ssh channel for transport: %w", err)
}
go ssh.DiscardRequests(reqs)
requestType := fmt.Sprintf("stream-transport-%s", streamID)
if _, err = channel.SendRequest(requestType, true, nil); err != nil {
return nil, fmt.Errorf("error sending channel request: %w", err)
}
return channel, nil
}

View file

@ -7,12 +7,19 @@ import (
"net"
"github.com/opentracing/opentracing-go"
"golang.org/x/crypto/ssh"
)
type portForwardingSession interface {
StartSharing(context.Context, string, int) (ChannelID, error)
OpenStreamingChannel(context.Context, ChannelID) (ssh.Channel, error)
KeepAlive(string)
}
// A PortForwarder forwards TCP traffic over a Live Share session from a port on a remote
// container to a local destination such as a network port or Go reader/writer.
type PortForwarder struct {
session *Session
session portForwardingSession
name string
remotePort int
keepAlive bool
@ -22,7 +29,7 @@ type PortForwarder struct {
// remote port and Live Share session. The name describes the purpose
// of the remote port or service. The keepAlive flag indicates whether
// the session should be kept alive with port forwarding traffic.
func NewPortForwarder(session *Session, name string, remotePort int, keepAlive bool) *PortForwarder {
func NewPortForwarder(session portForwardingSession, name string, remotePort int, keepAlive bool) *PortForwarder {
return &PortForwarder{
session: session,
name: name,
@ -92,8 +99,8 @@ func (fwd *PortForwarder) Forward(ctx context.Context, conn io.ReadWriteCloser)
return awaitError(ctx, errc)
}
func (fwd *PortForwarder) shareRemotePort(ctx context.Context) (channelID, error) {
id, err := fwd.session.startSharing(ctx, fwd.name, fwd.remotePort)
func (fwd *PortForwarder) shareRemotePort(ctx context.Context) (ChannelID, error) {
id, err := fwd.session.StartSharing(ctx, fwd.name, fwd.remotePort)
if err != nil {
err = fmt.Errorf("failed to share remote port %d: %w", fwd.remotePort, err)
}
@ -110,35 +117,39 @@ func awaitError(ctx context.Context, errc <-chan error) error {
}
}
type trafficMonitorSession interface {
KeepAlive(string)
}
// trafficMonitor implements io.Reader. It keeps the session alive by notifying
// it of the traffic type during Read operations.
type trafficMonitor struct {
reader io.Reader
session *Session
session trafficMonitorSession
trafficType string
}
// newTrafficMonitor returns a new trafficMonitor for the specified
// session and traffic type. It wraps the provided io.Reader with its own
// Read method.
func newTrafficMonitor(reader io.Reader, session *Session, trafficType string) *trafficMonitor {
func newTrafficMonitor(reader io.Reader, session trafficMonitorSession, trafficType string) *trafficMonitor {
return &trafficMonitor{reader, session, trafficType}
}
func (t *trafficMonitor) Read(p []byte) (n int, err error) {
t.session.keepAlive(t.trafficType)
t.session.KeepAlive(t.trafficType)
return t.reader.Read(p)
}
// handleConnection handles forwarding for a single accepted connection, then closes it.
func (fwd *PortForwarder) handleConnection(ctx context.Context, id channelID, conn io.ReadWriteCloser) (err error) {
func (fwd *PortForwarder) handleConnection(ctx context.Context, id ChannelID, conn io.ReadWriteCloser) (err error) {
span, ctx := opentracing.StartSpanFromContext(ctx, "PortForwarder.handleConnection")
defer span.Finish()
defer safeClose(conn, &err)
channel, err := fwd.session.openStreamingChannel(ctx, id)
channel, err := fwd.session.OpenStreamingChannel(ctx, id)
if err != nil {
return fmt.Errorf("error opening streaming channel for new connection: %w", err)
}

View file

@ -6,7 +6,6 @@ import (
"fmt"
"github.com/sourcegraph/jsonrpc2"
"golang.org/x/sync/errgroup"
)
// Port describes a port exposed by the container.
@ -30,37 +29,6 @@ const (
PortChangeKindUpdate PortChangeKind = "update"
)
// startSharing tells the Live Share host to start sharing the specified port from the container.
// The sessionName describes the purpose of the remote port or service.
// It returns an identifier that can be used to open an SSH channel to the remote port.
func (s *Session) startSharing(ctx context.Context, sessionName string, port int) (channelID, error) {
args := []interface{}{port, sessionName, fmt.Sprintf("http://localhost:%d", port)}
g, ctx := errgroup.WithContext(ctx)
g.Go(func() error {
startNotification, err := s.WaitForPortNotification(ctx, port, PortChangeKindStart)
if err != nil {
return fmt.Errorf("error while waiting for port notification: %w", err)
}
if !startNotification.Success {
return fmt.Errorf("error while starting port sharing: %s", startNotification.ErrorDetail)
}
return nil // success
})
var response Port
g.Go(func() error {
return s.rpc.do(ctx, "serverSharing.startSharing", args, &response)
})
if err := g.Wait(); err != nil {
return channelID{}, err
}
return channelID{response.StreamName, response.StreamCondition}, nil
}
type PortNotification struct {
Success bool // Helps us disambiguate between the SharingSucceeded/SharingFailed events
// The following are properties included in the SharingSucceeded/SharingFailed events sent by the server sharing service in the Codespace

View file

@ -5,8 +5,18 @@ import (
"fmt"
"strconv"
"time"
"github.com/opentracing/opentracing-go"
"golang.org/x/crypto/ssh"
"golang.org/x/sync/errgroup"
)
// A ChannelID is an identifier for an exposed port on a remote
// container that may be used to open an SSH channel to it.
type ChannelID struct {
name, condition string
}
// A Session represents the session between a connected Live Share client and server.
type Session struct {
ssh *sshSession
@ -91,7 +101,7 @@ func (s *Session) StartJupyterServer(ctx context.Context) (int, string, error) {
// heartbeat runs until context cancellation, periodically checking whether there is a
// reason to keep the connection alive, and if so, notifying the Live Share host to do so.
// Heartbeat ensures it does not send more than one request every "interval" to ratelimit
// how many keepAlives we send at a time.
// how many KeepAlives we send at a time.
func (s *Session) heartbeat(ctx context.Context, interval time.Duration) {
ticker := time.NewTicker(interval)
defer ticker.Stop()
@ -118,9 +128,9 @@ func (s *Session) notifyHostOfActivity(ctx context.Context, activity string) err
return s.rpc.do(ctx, "ICodespaceHostService.notifyCodespaceOfClientActivity", params, nil)
}
// keepAlive accepts a reason that is retained if there is no active reason
// KeepAlive accepts a reason that is retained if there is no active reason
// to send to the server.
func (s *Session) keepAlive(reason string) {
func (s *Session) KeepAlive(reason string) {
select {
case s.keepAliveReason <- reason:
default:
@ -128,3 +138,66 @@ func (s *Session) keepAlive(reason string) {
// so we can ignore this one
}
}
// StartSharing tells the Live Share host to start sharing the specified port from the container.
// The sessionName describes the purpose of the remote port or service.
// It returns an identifier that can be used to open an SSH channel to the remote port.
func (s *Session) StartSharing(ctx context.Context, sessionName string, port int) (ChannelID, error) {
args := []interface{}{port, sessionName, fmt.Sprintf("http://localhost:%d", port)}
g, ctx := errgroup.WithContext(ctx)
g.Go(func() error {
startNotification, err := s.WaitForPortNotification(ctx, port, PortChangeKindStart)
if err != nil {
return fmt.Errorf("error while waiting for port notification: %w", err)
}
if !startNotification.Success {
return fmt.Errorf("error while starting port sharing: %s", startNotification.ErrorDetail)
}
return nil // success
})
var response Port
g.Go(func() error {
return s.rpc.do(ctx, "serverSharing.startSharing", args, &response)
})
if err := g.Wait(); err != nil {
return ChannelID{}, err
}
return ChannelID{response.StreamName, response.StreamCondition}, nil
}
func (s *Session) OpenStreamingChannel(ctx context.Context, id ChannelID) (ssh.Channel, error) {
type getStreamArgs struct {
StreamName string `json:"streamName"`
Condition string `json:"condition"`
}
args := getStreamArgs{
StreamName: id.name,
Condition: id.condition,
}
var streamID string
if err := s.rpc.do(ctx, "streamManager.getStream", args, &streamID); err != nil {
return nil, fmt.Errorf("error getting stream id: %w", err)
}
span, ctx := opentracing.StartSpanFromContext(ctx, "Session.OpenChannel+SendRequest")
defer span.Finish()
_ = ctx // ctx is not currently used
channel, reqs, err := s.ssh.conn.OpenChannel("session", nil)
if err != nil {
return nil, fmt.Errorf("error opening ssh channel for transport: %w", err)
}
go ssh.DiscardRequests(reqs)
requestType := fmt.Sprintf("stream-transport-%s", streamID)
if _, err = channel.SendRequest(requestType, true, nil); err != nil {
return nil, fmt.Errorf("error sending channel request: %w", err)
}
return channel, nil
}

View file

@ -103,7 +103,7 @@ func TestServerStartSharing(t *testing.T) {
done := make(chan error)
go func() {
streamID, err := session.startSharing(ctx, serverProtocol, serverPort)
streamID, err := session.StartSharing(ctx, serverProtocol, serverPort)
if err != nil {
done <- fmt.Errorf("error sharing server: %w", err)
}
@ -247,10 +247,10 @@ func TestInvalidHostKey(t *testing.T) {
func TestKeepAliveNonBlocking(t *testing.T) {
session := &Session{keepAliveReason: make(chan string, 1)}
for i := 0; i < 2; i++ {
session.keepAlive("io")
session.KeepAlive("io")
}
// if keepAlive blocks, we'll never reach this and timeout the test
// if KeepAlive blocks, we'll never reach this and timeout the test
// timing out
}
@ -367,10 +367,10 @@ func TestSessionHeartbeat(t *testing.T) {
go session.heartbeat(ctx, 50*time.Millisecond)
go func() {
session.keepAlive("input")
session.KeepAlive("input")
wg.Wait()
wg.Add(1)
session.keepAlive("input")
session.KeepAlive("input")
wg.Wait()
done <- struct{}{}
}()
@ -380,7 +380,7 @@ func TestSessionHeartbeat(t *testing.T) {
t.Errorf("error from server: %v", err)
case <-done:
activityCount := strings.Count(logger.String(), "input")
// by design keepAlive can drop requests, and therefore there is zero guarantee
// by design KeepAlive can drop requests, and therefore there is zero guarantee
// that we actually get two requests if the network happened to be slow (rarely)
// during testing.
if activityCount != 1 && activityCount != 2 {