Extract LiveshareSession interface (#5725)
This will make it possible to inject a mock liveshare session for testing
This commit is contained in:
parent
866eccc202
commit
f15a8ca335
6 changed files with 114 additions and 90 deletions
|
|
@ -19,6 +19,7 @@ import (
|
|||
"github.com/cli/cli/v2/pkg/iostreams"
|
||||
"github.com/cli/cli/v2/pkg/liveshare"
|
||||
"github.com/spf13/cobra"
|
||||
"golang.org/x/crypto/ssh"
|
||||
"golang.org/x/term"
|
||||
)
|
||||
|
||||
|
|
@ -60,8 +61,18 @@ func (a *App) StopProgressIndicator() {
|
|||
a.io.StopProgressIndicator()
|
||||
}
|
||||
|
||||
type liveshareSession interface {
|
||||
Close() error
|
||||
GetSharedServers(context.Context) ([]*liveshare.Port, error)
|
||||
KeepAlive(string)
|
||||
OpenStreamingChannel(context.Context, liveshare.ChannelID) (ssh.Channel, error)
|
||||
StartJupyterServer(context.Context) (int, string, error)
|
||||
StartSharing(context.Context, string, int) (liveshare.ChannelID, error)
|
||||
StartSSHServer(context.Context) (int, string, error)
|
||||
}
|
||||
|
||||
// Connects to a codespace using Live Share and returns that session
|
||||
func startLiveShareSession(ctx context.Context, codespace *api.Codespace, a *App, debug bool, debugFile string) (session *liveshare.Session, err error) {
|
||||
func startLiveShareSession(ctx context.Context, codespace *api.Codespace, a *App, debug bool, debugFile string) (session liveshareSession, err error) {
|
||||
// While connecting, ensure in the background that the user has keys installed.
|
||||
// That lets us report a more useful error message if they don't.
|
||||
authkeys := make(chan error, 1)
|
||||
|
|
|
|||
|
|
@ -20,7 +20,6 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/opentracing/opentracing-go"
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
type logger interface {
|
||||
|
|
@ -136,41 +135,3 @@ type joinWorkspaceArgs struct {
|
|||
type joinWorkspaceResult struct {
|
||||
SessionNumber int `json:"sessionNumber"`
|
||||
}
|
||||
|
||||
// A channelID is an identifier for an exposed port on a remote
|
||||
// container that may be used to open an SSH channel to it.
|
||||
type channelID struct {
|
||||
name, condition string
|
||||
}
|
||||
|
||||
func (s *Session) openStreamingChannel(ctx context.Context, id channelID) (ssh.Channel, error) {
|
||||
type getStreamArgs struct {
|
||||
StreamName string `json:"streamName"`
|
||||
Condition string `json:"condition"`
|
||||
}
|
||||
args := getStreamArgs{
|
||||
StreamName: id.name,
|
||||
Condition: id.condition,
|
||||
}
|
||||
var streamID string
|
||||
if err := s.rpc.do(ctx, "streamManager.getStream", args, &streamID); err != nil {
|
||||
return nil, fmt.Errorf("error getting stream id: %w", err)
|
||||
}
|
||||
|
||||
span, ctx := opentracing.StartSpanFromContext(ctx, "Session.OpenChannel+SendRequest")
|
||||
defer span.Finish()
|
||||
_ = ctx // ctx is not currently used
|
||||
|
||||
channel, reqs, err := s.ssh.conn.OpenChannel("session", nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error opening ssh channel for transport: %w", err)
|
||||
}
|
||||
go ssh.DiscardRequests(reqs)
|
||||
|
||||
requestType := fmt.Sprintf("stream-transport-%s", streamID)
|
||||
if _, err = channel.SendRequest(requestType, true, nil); err != nil {
|
||||
return nil, fmt.Errorf("error sending channel request: %w", err)
|
||||
}
|
||||
|
||||
return channel, nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -7,12 +7,19 @@ import (
|
|||
"net"
|
||||
|
||||
"github.com/opentracing/opentracing-go"
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
type portForwardingSession interface {
|
||||
StartSharing(context.Context, string, int) (ChannelID, error)
|
||||
OpenStreamingChannel(context.Context, ChannelID) (ssh.Channel, error)
|
||||
KeepAlive(string)
|
||||
}
|
||||
|
||||
// A PortForwarder forwards TCP traffic over a Live Share session from a port on a remote
|
||||
// container to a local destination such as a network port or Go reader/writer.
|
||||
type PortForwarder struct {
|
||||
session *Session
|
||||
session portForwardingSession
|
||||
name string
|
||||
remotePort int
|
||||
keepAlive bool
|
||||
|
|
@ -22,7 +29,7 @@ type PortForwarder struct {
|
|||
// remote port and Live Share session. The name describes the purpose
|
||||
// of the remote port or service. The keepAlive flag indicates whether
|
||||
// the session should be kept alive with port forwarding traffic.
|
||||
func NewPortForwarder(session *Session, name string, remotePort int, keepAlive bool) *PortForwarder {
|
||||
func NewPortForwarder(session portForwardingSession, name string, remotePort int, keepAlive bool) *PortForwarder {
|
||||
return &PortForwarder{
|
||||
session: session,
|
||||
name: name,
|
||||
|
|
@ -92,8 +99,8 @@ func (fwd *PortForwarder) Forward(ctx context.Context, conn io.ReadWriteCloser)
|
|||
return awaitError(ctx, errc)
|
||||
}
|
||||
|
||||
func (fwd *PortForwarder) shareRemotePort(ctx context.Context) (channelID, error) {
|
||||
id, err := fwd.session.startSharing(ctx, fwd.name, fwd.remotePort)
|
||||
func (fwd *PortForwarder) shareRemotePort(ctx context.Context) (ChannelID, error) {
|
||||
id, err := fwd.session.StartSharing(ctx, fwd.name, fwd.remotePort)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("failed to share remote port %d: %w", fwd.remotePort, err)
|
||||
}
|
||||
|
|
@ -110,35 +117,39 @@ func awaitError(ctx context.Context, errc <-chan error) error {
|
|||
}
|
||||
}
|
||||
|
||||
type trafficMonitorSession interface {
|
||||
KeepAlive(string)
|
||||
}
|
||||
|
||||
// trafficMonitor implements io.Reader. It keeps the session alive by notifying
|
||||
// it of the traffic type during Read operations.
|
||||
type trafficMonitor struct {
|
||||
reader io.Reader
|
||||
|
||||
session *Session
|
||||
session trafficMonitorSession
|
||||
trafficType string
|
||||
}
|
||||
|
||||
// newTrafficMonitor returns a new trafficMonitor for the specified
|
||||
// session and traffic type. It wraps the provided io.Reader with its own
|
||||
// Read method.
|
||||
func newTrafficMonitor(reader io.Reader, session *Session, trafficType string) *trafficMonitor {
|
||||
func newTrafficMonitor(reader io.Reader, session trafficMonitorSession, trafficType string) *trafficMonitor {
|
||||
return &trafficMonitor{reader, session, trafficType}
|
||||
}
|
||||
|
||||
func (t *trafficMonitor) Read(p []byte) (n int, err error) {
|
||||
t.session.keepAlive(t.trafficType)
|
||||
t.session.KeepAlive(t.trafficType)
|
||||
return t.reader.Read(p)
|
||||
}
|
||||
|
||||
// handleConnection handles forwarding for a single accepted connection, then closes it.
|
||||
func (fwd *PortForwarder) handleConnection(ctx context.Context, id channelID, conn io.ReadWriteCloser) (err error) {
|
||||
func (fwd *PortForwarder) handleConnection(ctx context.Context, id ChannelID, conn io.ReadWriteCloser) (err error) {
|
||||
span, ctx := opentracing.StartSpanFromContext(ctx, "PortForwarder.handleConnection")
|
||||
defer span.Finish()
|
||||
|
||||
defer safeClose(conn, &err)
|
||||
|
||||
channel, err := fwd.session.openStreamingChannel(ctx, id)
|
||||
channel, err := fwd.session.OpenStreamingChannel(ctx, id)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error opening streaming channel for new connection: %w", err)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -6,7 +6,6 @@ import (
|
|||
"fmt"
|
||||
|
||||
"github.com/sourcegraph/jsonrpc2"
|
||||
"golang.org/x/sync/errgroup"
|
||||
)
|
||||
|
||||
// Port describes a port exposed by the container.
|
||||
|
|
@ -30,37 +29,6 @@ const (
|
|||
PortChangeKindUpdate PortChangeKind = "update"
|
||||
)
|
||||
|
||||
// startSharing tells the Live Share host to start sharing the specified port from the container.
|
||||
// The sessionName describes the purpose of the remote port or service.
|
||||
// It returns an identifier that can be used to open an SSH channel to the remote port.
|
||||
func (s *Session) startSharing(ctx context.Context, sessionName string, port int) (channelID, error) {
|
||||
args := []interface{}{port, sessionName, fmt.Sprintf("http://localhost:%d", port)}
|
||||
g, ctx := errgroup.WithContext(ctx)
|
||||
|
||||
g.Go(func() error {
|
||||
startNotification, err := s.WaitForPortNotification(ctx, port, PortChangeKindStart)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error while waiting for port notification: %w", err)
|
||||
|
||||
}
|
||||
if !startNotification.Success {
|
||||
return fmt.Errorf("error while starting port sharing: %s", startNotification.ErrorDetail)
|
||||
}
|
||||
return nil // success
|
||||
})
|
||||
|
||||
var response Port
|
||||
g.Go(func() error {
|
||||
return s.rpc.do(ctx, "serverSharing.startSharing", args, &response)
|
||||
})
|
||||
|
||||
if err := g.Wait(); err != nil {
|
||||
return channelID{}, err
|
||||
}
|
||||
|
||||
return channelID{response.StreamName, response.StreamCondition}, nil
|
||||
}
|
||||
|
||||
type PortNotification struct {
|
||||
Success bool // Helps us disambiguate between the SharingSucceeded/SharingFailed events
|
||||
// The following are properties included in the SharingSucceeded/SharingFailed events sent by the server sharing service in the Codespace
|
||||
|
|
|
|||
|
|
@ -5,8 +5,18 @@ import (
|
|||
"fmt"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/opentracing/opentracing-go"
|
||||
"golang.org/x/crypto/ssh"
|
||||
"golang.org/x/sync/errgroup"
|
||||
)
|
||||
|
||||
// A ChannelID is an identifier for an exposed port on a remote
|
||||
// container that may be used to open an SSH channel to it.
|
||||
type ChannelID struct {
|
||||
name, condition string
|
||||
}
|
||||
|
||||
// A Session represents the session between a connected Live Share client and server.
|
||||
type Session struct {
|
||||
ssh *sshSession
|
||||
|
|
@ -91,7 +101,7 @@ func (s *Session) StartJupyterServer(ctx context.Context) (int, string, error) {
|
|||
// heartbeat runs until context cancellation, periodically checking whether there is a
|
||||
// reason to keep the connection alive, and if so, notifying the Live Share host to do so.
|
||||
// Heartbeat ensures it does not send more than one request every "interval" to ratelimit
|
||||
// how many keepAlives we send at a time.
|
||||
// how many KeepAlives we send at a time.
|
||||
func (s *Session) heartbeat(ctx context.Context, interval time.Duration) {
|
||||
ticker := time.NewTicker(interval)
|
||||
defer ticker.Stop()
|
||||
|
|
@ -118,9 +128,9 @@ func (s *Session) notifyHostOfActivity(ctx context.Context, activity string) err
|
|||
return s.rpc.do(ctx, "ICodespaceHostService.notifyCodespaceOfClientActivity", params, nil)
|
||||
}
|
||||
|
||||
// keepAlive accepts a reason that is retained if there is no active reason
|
||||
// KeepAlive accepts a reason that is retained if there is no active reason
|
||||
// to send to the server.
|
||||
func (s *Session) keepAlive(reason string) {
|
||||
func (s *Session) KeepAlive(reason string) {
|
||||
select {
|
||||
case s.keepAliveReason <- reason:
|
||||
default:
|
||||
|
|
@ -128,3 +138,66 @@ func (s *Session) keepAlive(reason string) {
|
|||
// so we can ignore this one
|
||||
}
|
||||
}
|
||||
|
||||
// StartSharing tells the Live Share host to start sharing the specified port from the container.
|
||||
// The sessionName describes the purpose of the remote port or service.
|
||||
// It returns an identifier that can be used to open an SSH channel to the remote port.
|
||||
func (s *Session) StartSharing(ctx context.Context, sessionName string, port int) (ChannelID, error) {
|
||||
args := []interface{}{port, sessionName, fmt.Sprintf("http://localhost:%d", port)}
|
||||
g, ctx := errgroup.WithContext(ctx)
|
||||
|
||||
g.Go(func() error {
|
||||
startNotification, err := s.WaitForPortNotification(ctx, port, PortChangeKindStart)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error while waiting for port notification: %w", err)
|
||||
|
||||
}
|
||||
if !startNotification.Success {
|
||||
return fmt.Errorf("error while starting port sharing: %s", startNotification.ErrorDetail)
|
||||
}
|
||||
return nil // success
|
||||
})
|
||||
|
||||
var response Port
|
||||
g.Go(func() error {
|
||||
return s.rpc.do(ctx, "serverSharing.startSharing", args, &response)
|
||||
})
|
||||
|
||||
if err := g.Wait(); err != nil {
|
||||
return ChannelID{}, err
|
||||
}
|
||||
|
||||
return ChannelID{response.StreamName, response.StreamCondition}, nil
|
||||
}
|
||||
|
||||
func (s *Session) OpenStreamingChannel(ctx context.Context, id ChannelID) (ssh.Channel, error) {
|
||||
type getStreamArgs struct {
|
||||
StreamName string `json:"streamName"`
|
||||
Condition string `json:"condition"`
|
||||
}
|
||||
args := getStreamArgs{
|
||||
StreamName: id.name,
|
||||
Condition: id.condition,
|
||||
}
|
||||
var streamID string
|
||||
if err := s.rpc.do(ctx, "streamManager.getStream", args, &streamID); err != nil {
|
||||
return nil, fmt.Errorf("error getting stream id: %w", err)
|
||||
}
|
||||
|
||||
span, ctx := opentracing.StartSpanFromContext(ctx, "Session.OpenChannel+SendRequest")
|
||||
defer span.Finish()
|
||||
_ = ctx // ctx is not currently used
|
||||
|
||||
channel, reqs, err := s.ssh.conn.OpenChannel("session", nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error opening ssh channel for transport: %w", err)
|
||||
}
|
||||
go ssh.DiscardRequests(reqs)
|
||||
|
||||
requestType := fmt.Sprintf("stream-transport-%s", streamID)
|
||||
if _, err = channel.SendRequest(requestType, true, nil); err != nil {
|
||||
return nil, fmt.Errorf("error sending channel request: %w", err)
|
||||
}
|
||||
|
||||
return channel, nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -103,7 +103,7 @@ func TestServerStartSharing(t *testing.T) {
|
|||
|
||||
done := make(chan error)
|
||||
go func() {
|
||||
streamID, err := session.startSharing(ctx, serverProtocol, serverPort)
|
||||
streamID, err := session.StartSharing(ctx, serverProtocol, serverPort)
|
||||
if err != nil {
|
||||
done <- fmt.Errorf("error sharing server: %w", err)
|
||||
}
|
||||
|
|
@ -247,10 +247,10 @@ func TestInvalidHostKey(t *testing.T) {
|
|||
func TestKeepAliveNonBlocking(t *testing.T) {
|
||||
session := &Session{keepAliveReason: make(chan string, 1)}
|
||||
for i := 0; i < 2; i++ {
|
||||
session.keepAlive("io")
|
||||
session.KeepAlive("io")
|
||||
}
|
||||
|
||||
// if keepAlive blocks, we'll never reach this and timeout the test
|
||||
// if KeepAlive blocks, we'll never reach this and timeout the test
|
||||
// timing out
|
||||
}
|
||||
|
||||
|
|
@ -367,10 +367,10 @@ func TestSessionHeartbeat(t *testing.T) {
|
|||
|
||||
go session.heartbeat(ctx, 50*time.Millisecond)
|
||||
go func() {
|
||||
session.keepAlive("input")
|
||||
session.KeepAlive("input")
|
||||
wg.Wait()
|
||||
wg.Add(1)
|
||||
session.keepAlive("input")
|
||||
session.KeepAlive("input")
|
||||
wg.Wait()
|
||||
done <- struct{}{}
|
||||
}()
|
||||
|
|
@ -380,7 +380,7 @@ func TestSessionHeartbeat(t *testing.T) {
|
|||
t.Errorf("error from server: %v", err)
|
||||
case <-done:
|
||||
activityCount := strings.Count(logger.String(), "input")
|
||||
// by design keepAlive can drop requests, and therefore there is zero guarantee
|
||||
// by design KeepAlive can drop requests, and therefore there is zero guarantee
|
||||
// that we actually get two requests if the network happened to be slow (rarely)
|
||||
// during testing.
|
||||
if activityCount != 1 && activityCount != 2 {
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue