Migrate all Codespaces operations from Live Share to Dev Tunnels (#8149)
* Migrate all Codespaces operations from Live Share to Dev Tunnels * Remove Live Share references * Fix linting errors * Update comments, remove deps, add uint16 bound checks * Fix tests and move keep-alive logic to forwarder * Address comments * Updated mock port forwarder * Fix CodeQL error * Update comment * Update func name * Add missing connection close * Fix linting error * https -> http * Update defer * Fix tests
This commit is contained in:
parent
7d6fba0d7d
commit
64f4660ec7
31 changed files with 491 additions and 2211 deletions
1
.github/CODEOWNERS
vendored
1
.github/CODEOWNERS
vendored
|
|
@ -1,5 +1,4 @@
|
|||
* @cli/code-reviewers
|
||||
|
||||
pkg/cmd/codespace/ @cli/codespaces
|
||||
pkg/liveshare/ @cli/codespaces
|
||||
internal/codespaces/ @cli/codespaces
|
||||
|
|
|
|||
3
go.mod
3
go.mod
|
|
@ -27,12 +27,11 @@ require (
|
|||
github.com/mattn/go-colorable v0.1.13
|
||||
github.com/mattn/go-isatty v0.0.19
|
||||
github.com/mgutz/ansi v0.0.0-20200706080929-d51e80ef957d
|
||||
github.com/microsoft/dev-tunnels v0.0.21
|
||||
github.com/microsoft/dev-tunnels v0.0.25
|
||||
github.com/muhammadmuzzammil1998/jsonc v0.0.0-20201229145248-615b0916ca38
|
||||
github.com/opentracing/opentracing-go v1.1.0
|
||||
github.com/rivo/tview v0.0.0-20221029100920-c4a7e501810d
|
||||
github.com/shurcooL/githubv4 v0.0.0-20230704064427-599ae7bbf278
|
||||
github.com/sourcegraph/jsonrpc2 v0.1.0
|
||||
github.com/spf13/cobra v1.6.1
|
||||
github.com/spf13/pflag v1.0.5
|
||||
github.com/stretchr/testify v1.8.4
|
||||
|
|
|
|||
7
go.sum
7
go.sum
|
|
@ -67,7 +67,6 @@ github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 h1:El6M4kTTCOh6aBiKaU
|
|||
github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510/go.mod h1:pupxD2MaaD3pAXIBCelhxNneeOaAeabZDe5s4K6zSpQ=
|
||||
github.com/gorilla/css v1.0.0 h1:BQqNyPTi50JCFMTw/b67hByjMVXZRwGha6wxVGkeihY=
|
||||
github.com/gorilla/css v1.0.0/go.mod h1:Dn721qIggHpt4+EFCcTLTU/vk5ySda2ReITrtgBl60c=
|
||||
github.com/gorilla/websocket v1.4.1/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
|
||||
github.com/gorilla/websocket v1.4.2 h1:+/TMaTYc4QFitKJxsQ7Yye35DkWvkdLcvGKqM+x0Ufc=
|
||||
github.com/gorilla/websocket v1.4.2/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
|
||||
github.com/h2non/parth v0.0.0-20190131123155-b4df798d6542 h1:2VTzZjLZBgl62/EtslCrtky5vbi9dd7HrQPQIx6wqiw=
|
||||
|
|
@ -118,8 +117,8 @@ github.com/mgutz/ansi v0.0.0-20200706080929-d51e80ef957d/go.mod h1:01TrycV0kFyex
|
|||
github.com/microcosm-cc/bluemonday v1.0.21/go.mod h1:ytNkv4RrDrLJ2pqlsSI46O6IVXmZOBBD4SaJyDwwTkM=
|
||||
github.com/microcosm-cc/bluemonday v1.0.26 h1:xbqSvqzQMeEHCqMi64VAs4d8uy6Mequs3rQ0k/Khz58=
|
||||
github.com/microcosm-cc/bluemonday v1.0.26/go.mod h1:JyzOCs9gkyQyjs+6h10UEVSe02CGwkhd72Xdqh78TWs=
|
||||
github.com/microsoft/dev-tunnels v0.0.21 h1:p4QP7C5ZOyP9bGbmanRjPxUMckfi9Z41Gl+KY4C11w0=
|
||||
github.com/microsoft/dev-tunnels v0.0.21/go.mod h1:frU++12T/oqxckXkDpTuYa427ncguEOodSPZcGCCrzQ=
|
||||
github.com/microsoft/dev-tunnels v0.0.25 h1:UlMKUI+2O8cSu4RlB52ioSyn1LthYSVkJA+CSTsdKoA=
|
||||
github.com/microsoft/dev-tunnels v0.0.25/go.mod h1:frU++12T/oqxckXkDpTuYa427ncguEOodSPZcGCCrzQ=
|
||||
github.com/muesli/reflow v0.2.1-0.20210115123740-9e1d0d53df68/go.mod h1:Xk+z4oIWdQqJzsxyjgl3P22oYZnHdZ8FFTHAQQt5BMQ=
|
||||
github.com/muesli/reflow v0.3.0 h1:IFsN6K9NfGtjeggFP+68I4chLZV2yIKsXJFNZ+eWh6s=
|
||||
github.com/muesli/reflow v0.3.0/go.mod h1:pbwTDkVPibjO2kyvBQRBxTWEEGDGq0FlB1BIKtnHY/8=
|
||||
|
|
@ -150,8 +149,6 @@ github.com/shurcooL/githubv4 v0.0.0-20230704064427-599ae7bbf278 h1:kdEGVAV4sO46D
|
|||
github.com/shurcooL/githubv4 v0.0.0-20230704064427-599ae7bbf278/go.mod h1:zqMwyHmnN/eDOZOdiTohqIUKUrTFX62PNlu7IJdu0q8=
|
||||
github.com/shurcooL/graphql v0.0.0-20230722043721-ed46e5a46466 h1:17JxqqJY66GmZVHkmAsGEkcIu0oCe3AM420QDgGwZx0=
|
||||
github.com/shurcooL/graphql v0.0.0-20230722043721-ed46e5a46466/go.mod h1:9dIRpgIY7hVhoqfe0/FcYp0bpInZaT7dc3BYOprrIUE=
|
||||
github.com/sourcegraph/jsonrpc2 v0.1.0 h1:ohJHjZ+PcaLxDUjqk2NC3tIGsVa5bXThe1ZheSXOjuk=
|
||||
github.com/sourcegraph/jsonrpc2 v0.1.0/go.mod h1:ZafdZgk/axhT1cvZAPOhw+95nz2I/Ra5qMlU4gTRwIo=
|
||||
github.com/spf13/cobra v1.6.1 h1:o94oiPyS4KD1mPy2fmcYYHHfCxLqYjJOhGsCHFZtEzA=
|
||||
github.com/spf13/cobra v1.6.1/go.mod h1:IOw/AERYS7UzyrGinqmz6HLUo219MORXGxhbaJUqzrY=
|
||||
github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA=
|
||||
|
|
|
|||
|
|
@ -247,11 +247,6 @@ const (
|
|||
)
|
||||
|
||||
type CodespaceConnection struct {
|
||||
SessionID string `json:"sessionId"`
|
||||
SessionToken string `json:"sessionToken"`
|
||||
RelayEndpoint string `json:"relayEndpoint"`
|
||||
RelaySAS string `json:"relaySas"`
|
||||
HostPublicKeys []string `json:"hostPublicKeys"`
|
||||
TunnelProperties TunnelProperties `json:"tunnelProperties"`
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -11,30 +11,20 @@ import (
|
|||
"github.com/cenkalti/backoff/v4"
|
||||
"github.com/cli/cli/v2/internal/codespaces/api"
|
||||
"github.com/cli/cli/v2/internal/codespaces/connection"
|
||||
"github.com/cli/cli/v2/pkg/liveshare"
|
||||
)
|
||||
|
||||
func connectionReady(codespace *api.Codespace, usingDevTunnels bool) bool {
|
||||
func connectionReady(codespace *api.Codespace) bool {
|
||||
// If the codespace is not available, it is not ready
|
||||
if codespace.State != api.CodespaceStateAvailable {
|
||||
return false
|
||||
}
|
||||
|
||||
// If using Dev Tunnels, we need to check that we have all of the required tunnel properties
|
||||
if usingDevTunnels {
|
||||
return codespace.Connection.TunnelProperties.ConnectAccessToken != "" &&
|
||||
codespace.Connection.TunnelProperties.ManagePortsAccessToken != "" &&
|
||||
codespace.Connection.TunnelProperties.ServiceUri != "" &&
|
||||
codespace.Connection.TunnelProperties.TunnelId != "" &&
|
||||
codespace.Connection.TunnelProperties.ClusterId != "" &&
|
||||
codespace.Connection.TunnelProperties.Domain != ""
|
||||
}
|
||||
|
||||
// If not using Dev Tunnels, we need to check that we have all of the required Live Share properties
|
||||
return codespace.Connection.SessionID != "" &&
|
||||
codespace.Connection.SessionToken != "" &&
|
||||
codespace.Connection.RelayEndpoint != "" &&
|
||||
codespace.Connection.RelaySAS != ""
|
||||
return codespace.Connection.TunnelProperties.ConnectAccessToken != "" &&
|
||||
codespace.Connection.TunnelProperties.ManagePortsAccessToken != "" &&
|
||||
codespace.Connection.TunnelProperties.ServiceUri != "" &&
|
||||
codespace.Connection.TunnelProperties.TunnelId != "" &&
|
||||
codespace.Connection.TunnelProperties.ClusterId != "" &&
|
||||
codespace.Connection.TunnelProperties.Domain != ""
|
||||
}
|
||||
|
||||
type apiClient interface {
|
||||
|
|
@ -48,11 +38,6 @@ type progressIndicator interface {
|
|||
StopProgressIndicator()
|
||||
}
|
||||
|
||||
type logger interface {
|
||||
Println(v ...interface{})
|
||||
Printf(f string, v ...interface{})
|
||||
}
|
||||
|
||||
type TimeoutError struct {
|
||||
message string
|
||||
}
|
||||
|
|
@ -64,7 +49,7 @@ func (e *TimeoutError) Error() string {
|
|||
// GetCodespaceConnection waits until a codespace is able
|
||||
// to be connected to and initializes a connection to it.
|
||||
func GetCodespaceConnection(ctx context.Context, progress progressIndicator, apiClient apiClient, codespace *api.Codespace) (*connection.CodespaceConnection, error) {
|
||||
codespace, err := waitUntilCodespaceConnectionReady(ctx, progress, apiClient, codespace, true)
|
||||
codespace, err := waitUntilCodespaceConnectionReady(ctx, progress, apiClient, codespace)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
@ -80,29 +65,8 @@ func GetCodespaceConnection(ctx context.Context, progress progressIndicator, api
|
|||
return connection.NewCodespaceConnection(ctx, codespace, httpClient)
|
||||
}
|
||||
|
||||
// ConnectToLiveshare waits until a codespace is able to be
|
||||
// connected to and connects to it using a Live Share session.
|
||||
func ConnectToLiveshare(ctx context.Context, progress progressIndicator, sessionLogger logger, apiClient apiClient, codespace *api.Codespace) (*liveshare.Session, error) {
|
||||
codespace, err := waitUntilCodespaceConnectionReady(ctx, progress, apiClient, codespace, false)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
progress.StartProgressIndicatorWithLabel("Connecting to codespace")
|
||||
defer progress.StopProgressIndicator()
|
||||
|
||||
return liveshare.Connect(ctx, liveshare.Options{
|
||||
SessionID: codespace.Connection.SessionID,
|
||||
SessionToken: codespace.Connection.SessionToken,
|
||||
RelaySAS: codespace.Connection.RelaySAS,
|
||||
RelayEndpoint: codespace.Connection.RelayEndpoint,
|
||||
HostPublicKeys: codespace.Connection.HostPublicKeys,
|
||||
Logger: sessionLogger,
|
||||
})
|
||||
}
|
||||
|
||||
// waitUntilCodespaceConnectionReady waits for a Codespace to be running and is able to be connected to.
|
||||
func waitUntilCodespaceConnectionReady(ctx context.Context, progress progressIndicator, apiClient apiClient, codespace *api.Codespace, usingDevTunnels bool) (*api.Codespace, error) {
|
||||
func waitUntilCodespaceConnectionReady(ctx context.Context, progress progressIndicator, apiClient apiClient, codespace *api.Codespace) (*api.Codespace, error) {
|
||||
if codespace.State != api.CodespaceStateAvailable {
|
||||
progress.StartProgressIndicatorWithLabel("Starting codespace")
|
||||
defer progress.StopProgressIndicator()
|
||||
|
|
@ -111,7 +75,7 @@ func waitUntilCodespaceConnectionReady(ctx context.Context, progress progressInd
|
|||
}
|
||||
}
|
||||
|
||||
if !connectionReady(codespace, usingDevTunnels) {
|
||||
if !connectionReady(codespace) {
|
||||
expBackoff := backoff.NewExponentialBackOff()
|
||||
expBackoff.Multiplier = 1.1
|
||||
expBackoff.MaxInterval = 10 * time.Second
|
||||
|
|
@ -124,7 +88,7 @@ func waitUntilCodespaceConnectionReady(ctx context.Context, progress progressInd
|
|||
return backoff.Permanent(fmt.Errorf("error getting codespace: %w", err))
|
||||
}
|
||||
|
||||
if connectionReady(codespace, usingDevTunnels) {
|
||||
if connectionReady(codespace) {
|
||||
return nil
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@ package portforwarder
|
|||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"strings"
|
||||
|
||||
|
|
@ -22,101 +23,75 @@ const (
|
|||
PublicPortVisibility = "public"
|
||||
)
|
||||
|
||||
type PortForwarder struct {
|
||||
connection connection.CodespaceConnection
|
||||
const (
|
||||
trafficTypeInput = "input"
|
||||
trafficTypeOutput = "output"
|
||||
)
|
||||
|
||||
type ForwardPortOpts struct {
|
||||
Port int
|
||||
Internal bool
|
||||
KeepAlive bool
|
||||
Visibility string
|
||||
}
|
||||
|
||||
type CodespacesPortForwarder struct {
|
||||
connection connection.CodespaceConnection
|
||||
keepAliveReason chan string
|
||||
}
|
||||
|
||||
type PortForwarder interface {
|
||||
ForwardPortToListener(ctx context.Context, opts ForwardPortOpts, listener *net.TCPListener) error
|
||||
ForwardPort(ctx context.Context, opts ForwardPortOpts) error
|
||||
ConnectToForwardedPort(ctx context.Context, conn io.ReadWriteCloser, opts ForwardPortOpts) error
|
||||
ListPorts(ctx context.Context) ([]*tunnels.TunnelPort, error)
|
||||
UpdatePortVisibility(ctx context.Context, remotePort int, visibility string) error
|
||||
KeepAlive(reason string)
|
||||
GetKeepAliveReason() string
|
||||
CloseSSHConnection()
|
||||
}
|
||||
|
||||
// NewPortForwarder returns a new PortForwarder for the specified codespace.
|
||||
func NewPortForwarder(ctx context.Context, codespaceConnection *connection.CodespaceConnection) (fwd *PortForwarder, err error) {
|
||||
return &PortForwarder{
|
||||
connection: *codespaceConnection,
|
||||
func NewPortForwarder(ctx context.Context, codespaceConnection *connection.CodespaceConnection) (fwd PortForwarder, err error) {
|
||||
return &CodespacesPortForwarder{
|
||||
connection: *codespaceConnection,
|
||||
keepAliveReason: make(chan string, 1),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// ForwardAndConnectToPort forwards a port and connects to it via a local TCP port.
|
||||
func (fwd *PortForwarder) ForwardAndConnectToPort(ctx context.Context, remotePort uint16, listen *net.TCPListener, keepAlive bool, internal bool) error {
|
||||
return fwd.ForwardPort(ctx, remotePort, listen, keepAlive, true, internal, "")
|
||||
}
|
||||
|
||||
// ForwardPort forwards a port and optionally connects to it via a local TCP port.
|
||||
func (fwd *PortForwarder) ForwardPort(ctx context.Context, remotePort uint16, listen *net.TCPListener, keepAlive bool, connect bool, internal bool, visibility string) error {
|
||||
tunnelPort := tunnels.NewTunnelPort(remotePort, "", "", tunnels.TunnelProtocolHttp)
|
||||
|
||||
// If no visibility is provided, Dev Tunnels will use the default (private)
|
||||
if visibility != "" {
|
||||
// Check if the requested visibility is allowed
|
||||
allowed := false
|
||||
for _, allowedVisibility := range fwd.connection.AllowedPortPrivacySettings {
|
||||
if allowedVisibility == visibility {
|
||||
allowed = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// If the requested visibility is not allowed, return an error
|
||||
if !allowed {
|
||||
return fmt.Errorf("visibility %s is not allowed", visibility)
|
||||
}
|
||||
|
||||
accessControlEntries := visibilityToAccessControlEntries(visibility)
|
||||
if len(accessControlEntries) > 0 {
|
||||
tunnelPort.AccessControl = &tunnels.TunnelAccessControl{
|
||||
Entries: accessControlEntries,
|
||||
}
|
||||
}
|
||||
// ForwardPortToListener forwards the specified port to the given TCP listener.
|
||||
func (fwd *CodespacesPortForwarder) ForwardPortToListener(ctx context.Context, opts ForwardPortOpts, listener *net.TCPListener) error {
|
||||
err := fwd.ForwardPort(ctx, opts)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error forwarding port: %w", err)
|
||||
}
|
||||
|
||||
// Tag the port as internal or user forwarded so we know if it needs to be shown in the UI
|
||||
if internal {
|
||||
tunnelPort.Tags = []string{InternalPortTag}
|
||||
} else {
|
||||
tunnelPort.Tags = []string{UserForwardedPortTag}
|
||||
}
|
||||
|
||||
// Create the tunnel port
|
||||
_, err := fwd.connection.TunnelManager.CreateTunnelPort(ctx, fwd.connection.Tunnel, tunnelPort, fwd.connection.Options)
|
||||
if err != nil && !strings.Contains(err.Error(), "409") {
|
||||
return fmt.Errorf("create tunnel port failed: %v", err)
|
||||
}
|
||||
// Close the SSH connection when we're done
|
||||
defer fwd.CloseSSHConnection()
|
||||
|
||||
done := make(chan error)
|
||||
go func() {
|
||||
// Connect to the tunnel
|
||||
err = fwd.connection.TunnelClient.Connect(ctx, "")
|
||||
// Convert the port number to a uint16
|
||||
port, err := convertIntToUint16(opts.Port)
|
||||
if err != nil {
|
||||
done <- fmt.Errorf("connect failed: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Inform the host that we've forwarded the port locally
|
||||
err = fwd.connection.TunnelClient.RefreshPorts(ctx)
|
||||
if err != nil {
|
||||
done <- fmt.Errorf("refresh ports failed: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// If we don't want to connect to the port, exit early
|
||||
if !connect {
|
||||
done <- nil
|
||||
done <- fmt.Errorf("error converting port: %w", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Ensure the port is forwarded before connecting
|
||||
err = fwd.connection.TunnelClient.WaitForForwardedPort(ctx, remotePort)
|
||||
err = fwd.connection.TunnelClient.WaitForForwardedPort(ctx, port)
|
||||
if err != nil {
|
||||
done <- fmt.Errorf("wait for forwarded port failed: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Connect to the forwarded port via a local TCP port
|
||||
err = fwd.connection.TunnelClient.ConnectToForwardedPort(ctx, listen, remotePort)
|
||||
// Connect to the forwarded port
|
||||
err = fwd.connectListenerToForwardedPort(ctx, opts, listener)
|
||||
if err != nil {
|
||||
done <- fmt.Errorf("connect to forwarded port failed: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
done <- nil
|
||||
}()
|
||||
|
||||
select {
|
||||
case err := <-done:
|
||||
if err != nil {
|
||||
|
|
@ -128,8 +103,131 @@ func (fwd *PortForwarder) ForwardPort(ctx context.Context, remotePort uint16, li
|
|||
}
|
||||
}
|
||||
|
||||
// ForwardPort informs the host that we would like to forward the given port.
|
||||
func (fwd *CodespacesPortForwarder) ForwardPort(ctx context.Context, opts ForwardPortOpts) error {
|
||||
// Convert the port number to a uint16
|
||||
port, err := convertIntToUint16(opts.Port)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error converting port: %w", err)
|
||||
}
|
||||
|
||||
tunnelPort := tunnels.NewTunnelPort(port, "", "", tunnels.TunnelProtocolHttp)
|
||||
|
||||
// If no visibility is provided, Dev Tunnels will use the default (private)
|
||||
if opts.Visibility != "" {
|
||||
// Check if the requested visibility is allowed
|
||||
allowed := false
|
||||
for _, allowedVisibility := range fwd.connection.AllowedPortPrivacySettings {
|
||||
if allowedVisibility == opts.Visibility {
|
||||
allowed = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// If the requested visibility is not allowed, return an error
|
||||
if !allowed {
|
||||
return fmt.Errorf("visibility %s is not allowed", opts.Visibility)
|
||||
}
|
||||
|
||||
accessControlEntries := visibilityToAccessControlEntries(opts.Visibility)
|
||||
if len(accessControlEntries) > 0 {
|
||||
tunnelPort.AccessControl = &tunnels.TunnelAccessControl{
|
||||
Entries: accessControlEntries,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Tag the port as internal or user forwarded so we know if it needs to be shown in the UI
|
||||
if opts.Internal {
|
||||
tunnelPort.Tags = []string{InternalPortTag}
|
||||
} else {
|
||||
tunnelPort.Tags = []string{UserForwardedPortTag}
|
||||
}
|
||||
|
||||
// Create the tunnel port
|
||||
_, err = fwd.connection.TunnelManager.CreateTunnelPort(ctx, fwd.connection.Tunnel, tunnelPort, fwd.connection.Options)
|
||||
if err != nil && !strings.Contains(err.Error(), "409") {
|
||||
return fmt.Errorf("create tunnel port failed: %v", err)
|
||||
}
|
||||
|
||||
// Connect to the tunnel
|
||||
err = fwd.connection.TunnelClient.Connect(ctx, "")
|
||||
if err != nil {
|
||||
return fmt.Errorf("connect failed: %v", err)
|
||||
}
|
||||
|
||||
// Inform the host that we've forwarded the port locally
|
||||
err = fwd.connection.TunnelClient.RefreshPorts(ctx)
|
||||
if err != nil {
|
||||
fwd.CloseSSHConnection()
|
||||
return fmt.Errorf("refresh ports failed: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// connectListenerToForwardedPort connects to the forwarded port via a local TCP port.
|
||||
func (fwd *CodespacesPortForwarder) connectListenerToForwardedPort(ctx context.Context, opts ForwardPortOpts, listener *net.TCPListener) (err error) {
|
||||
errc := make(chan error, 1)
|
||||
sendError := func(err error) {
|
||||
// Use non-blocking send, to avoid goroutines getting
|
||||
// stuck in case of concurrent or sequential errors.
|
||||
select {
|
||||
case errc <- err:
|
||||
default:
|
||||
}
|
||||
}
|
||||
go func() {
|
||||
for {
|
||||
conn, err := listener.AcceptTCP()
|
||||
if err != nil {
|
||||
sendError(err)
|
||||
return
|
||||
}
|
||||
|
||||
// Connect to the forwarded port in a goroutine so we can accept new connections
|
||||
go func() {
|
||||
if err := fwd.ConnectToForwardedPort(ctx, conn, opts); err != nil {
|
||||
sendError(err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
}()
|
||||
|
||||
// Wait for an error or for the context to be cancelled
|
||||
select {
|
||||
case err := <-errc:
|
||||
return err
|
||||
case <-ctx.Done():
|
||||
return ctx.Err() // canceled
|
||||
}
|
||||
}
|
||||
|
||||
// ConnectToForwardedPort connects to the forwarded port via a given ReadWriteCloser.
|
||||
// Optionally, it detects traffic over the connection and sends activity signals to the server to keep the codespace from shutting down.
|
||||
func (fwd *CodespacesPortForwarder) ConnectToForwardedPort(ctx context.Context, conn io.ReadWriteCloser, opts ForwardPortOpts) error {
|
||||
// Create a traffic monitor to keep the session alive
|
||||
if opts.KeepAlive {
|
||||
conn = newTrafficMonitor(conn, fwd)
|
||||
}
|
||||
|
||||
// Convert the port number to a uint16
|
||||
port, err := convertIntToUint16(opts.Port)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error converting port: %w", err)
|
||||
}
|
||||
|
||||
// Connect to the forwarded port
|
||||
err = fwd.connection.TunnelClient.ConnectToForwardedPort(ctx, conn, port)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error connecting to forwarded port: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ListPorts fetches the list of ports that are currently forwarded.
|
||||
func (fwd *PortForwarder) ListPorts(ctx context.Context) (ports []*tunnels.TunnelPort, err error) {
|
||||
func (fwd *CodespacesPortForwarder) ListPorts(ctx context.Context) (ports []*tunnels.TunnelPort, err error) {
|
||||
ports, err = fwd.connection.TunnelManager.ListTunnelPorts(ctx, fwd.connection.Tunnel, fwd.connection.Options)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error listing ports: %w", err)
|
||||
|
|
@ -139,7 +237,7 @@ func (fwd *PortForwarder) ListPorts(ctx context.Context) (ports []*tunnels.Tunne
|
|||
}
|
||||
|
||||
// UpdatePortVisibility changes the visibility (private, org, public) of the specified port.
|
||||
func (fwd *PortForwarder) UpdatePortVisibility(ctx context.Context, remotePort int, visibility string) error {
|
||||
func (fwd *CodespacesPortForwarder) UpdatePortVisibility(ctx context.Context, remotePort int, visibility string) error {
|
||||
tunnelPort, err := fwd.connection.TunnelManager.GetTunnelPort(ctx, fwd.connection.Tunnel, remotePort, fwd.connection.Options)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error getting tunnel port: %w", err)
|
||||
|
|
@ -165,6 +263,9 @@ func (fwd *PortForwarder) UpdatePortVisibility(ctx context.Context, remotePort i
|
|||
return
|
||||
}
|
||||
|
||||
// Close the SSH connection when we're done
|
||||
defer fwd.CloseSSHConnection()
|
||||
|
||||
// Inform the host that we've deleted the port
|
||||
err = fwd.connection.TunnelClient.RefreshPorts(ctx)
|
||||
if err != nil {
|
||||
|
|
@ -172,6 +273,13 @@ func (fwd *PortForwarder) UpdatePortVisibility(ctx context.Context, remotePort i
|
|||
return
|
||||
}
|
||||
|
||||
// Re-forward the port with the updated visibility
|
||||
err = fwd.ForwardPort(ctx, ForwardPortOpts{Port: remotePort, Visibility: visibility})
|
||||
if err != nil {
|
||||
done <- fmt.Errorf("error forwarding port: %w", err)
|
||||
return
|
||||
}
|
||||
|
||||
done <- nil
|
||||
}()
|
||||
|
||||
|
|
@ -179,13 +287,10 @@ func (fwd *PortForwarder) UpdatePortVisibility(ctx context.Context, remotePort i
|
|||
select {
|
||||
case err := <-done:
|
||||
if err != nil {
|
||||
return fmt.Errorf("error connecting to tunnel: %w", err)
|
||||
}
|
||||
// If we fail to re-forward the port, we need to forward again with the original visibility so the port is still accessible
|
||||
_ = fwd.ForwardPort(ctx, ForwardPortOpts{Port: remotePort, Visibility: AccessControlEntriesToVisibility(tunnelPort.AccessControl.Entries)})
|
||||
|
||||
// Re-forward the port with the updated visibility
|
||||
err = fwd.ForwardPort(ctx, uint16(remotePort), nil, false, false, false, visibility)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error forwarding port: %w", err)
|
||||
return fmt.Errorf("error connecting to tunnel: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
|
|
@ -194,6 +299,27 @@ func (fwd *PortForwarder) UpdatePortVisibility(ctx context.Context, remotePort i
|
|||
}
|
||||
}
|
||||
|
||||
// KeepAlive accepts a reason that is retained if there is no active reason
|
||||
// to send to the server.
|
||||
func (fwd *CodespacesPortForwarder) KeepAlive(reason string) {
|
||||
select {
|
||||
case fwd.keepAliveReason <- reason:
|
||||
default:
|
||||
// there is already an active keep alive reason
|
||||
// so we can ignore this one
|
||||
}
|
||||
}
|
||||
|
||||
// GetKeepAliveReason fetches the keep alive reason from the channel and returns it.
|
||||
func (fwd *CodespacesPortForwarder) GetKeepAliveReason() string {
|
||||
return <-fwd.keepAliveReason
|
||||
}
|
||||
|
||||
// Close closes the port forwarder's tunnel client connection.
|
||||
func (fwd *CodespacesPortForwarder) CloseSSHConnection() {
|
||||
_ = fwd.connection.TunnelClient.Close()
|
||||
}
|
||||
|
||||
// AccessControlEntriesToVisibility converts the access control entries used by Dev Tunnels to a friendly visibility value.
|
||||
func AccessControlEntriesToVisibility(accessControlEntries []tunnels.TunnelAccessControlEntry) string {
|
||||
for _, entry := range accessControlEntries {
|
||||
|
|
@ -251,3 +377,45 @@ func IsInternalPort(port *tunnels.TunnelPort) bool {
|
|||
|
||||
return false
|
||||
}
|
||||
|
||||
// convertIntToUint16 converts the given int to a uint16.
|
||||
func convertIntToUint16(port int) (uint16, error) {
|
||||
var updatedPort uint16
|
||||
if port >= 0 && port <= 65535 {
|
||||
updatedPort = uint16(port)
|
||||
} else {
|
||||
return 0, fmt.Errorf("invalid port number: %d", port)
|
||||
}
|
||||
|
||||
return updatedPort, nil
|
||||
}
|
||||
|
||||
// trafficMonitor implements io.Reader. It keeps the session alive by notifying
|
||||
// it of the traffic type during Read operations.
|
||||
type trafficMonitor struct {
|
||||
rwc io.ReadWriteCloser
|
||||
fwd PortForwarder
|
||||
}
|
||||
|
||||
// newTrafficMonitor returns a trafficMonitor for the specified codespace connection.
|
||||
// It wraps the provided io.ReaderWriteCloser with its own Read/Write/Close methods.
|
||||
func newTrafficMonitor(rwc io.ReadWriteCloser, fwd PortForwarder) *trafficMonitor {
|
||||
return &trafficMonitor{rwc, fwd}
|
||||
}
|
||||
|
||||
// Read wraps the underlying ReadWriteCloser's Read method and keeps the session alive with the "input" traffic type.
|
||||
func (t *trafficMonitor) Read(p []byte) (n int, err error) {
|
||||
t.fwd.KeepAlive(trafficTypeInput)
|
||||
return t.rwc.Read(p)
|
||||
}
|
||||
|
||||
// Write wraps the underlying ReadWriteCloser's Write method and keeps the session alive with the "output" traffic type.
|
||||
func (t *trafficMonitor) Write(p []byte) (n int, err error) {
|
||||
t.fwd.KeepAlive(trafficTypeOutput)
|
||||
return t.rwc.Write(p)
|
||||
}
|
||||
|
||||
// Close closes the underlying ReadWriteCloser.
|
||||
func (t *trafficMonitor) Close() error {
|
||||
return t.rwc.Close()
|
||||
}
|
||||
|
|
|
|||
|
|
@ -12,10 +12,10 @@ import (
|
|||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/cli/cli/v2/internal/codespaces/portforwarder"
|
||||
"github.com/cli/cli/v2/internal/codespaces/rpc/codespace"
|
||||
"github.com/cli/cli/v2/internal/codespaces/rpc/jupyter"
|
||||
"github.com/cli/cli/v2/internal/codespaces/rpc/ssh"
|
||||
"github.com/cli/cli/v2/pkg/liveshare"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/credentials/insecure"
|
||||
"google.golang.org/grpc/metadata"
|
||||
|
|
@ -47,7 +47,7 @@ type Invoker interface {
|
|||
|
||||
type invoker struct {
|
||||
conn *grpc.ClientConn
|
||||
session liveshare.LiveshareSession
|
||||
fwd portforwarder.PortForwarder
|
||||
listener net.Listener
|
||||
jupyterClient jupyter.JupyterServerHostClient
|
||||
codespaceClient codespace.CodespaceHostClient
|
||||
|
|
@ -56,11 +56,11 @@ type invoker struct {
|
|||
}
|
||||
|
||||
// Connects to the internal RPC server and returns a new invoker for it
|
||||
func CreateInvoker(ctx context.Context, session liveshare.LiveshareSession) (Invoker, error) {
|
||||
func CreateInvoker(ctx context.Context, fwd portforwarder.PortForwarder) (Invoker, error) {
|
||||
ctx, cancel := context.WithTimeout(ctx, ConnectionTimeout)
|
||||
defer cancel()
|
||||
|
||||
invoker, err := connect(ctx, session)
|
||||
invoker, err := connect(ctx, fwd)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error connecting to internal server: %w", err)
|
||||
}
|
||||
|
|
@ -69,7 +69,7 @@ func CreateInvoker(ctx context.Context, session liveshare.LiveshareSession) (Inv
|
|||
}
|
||||
|
||||
// Finds a free port to listen on and creates a new RPC invoker that connects to that port
|
||||
func connect(ctx context.Context, session liveshare.LiveshareSession) (Invoker, error) {
|
||||
func connect(ctx context.Context, fwd portforwarder.PortForwarder) (Invoker, error) {
|
||||
listener, err := listenTCP()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
|
@ -77,7 +77,7 @@ func connect(ctx context.Context, session liveshare.LiveshareSession) (Invoker,
|
|||
localAddress := listener.Addr().String()
|
||||
|
||||
invoker := &invoker{
|
||||
session: session,
|
||||
fwd: fwd,
|
||||
listener: listener,
|
||||
}
|
||||
|
||||
|
|
@ -100,8 +100,12 @@ func connect(ctx context.Context, session liveshare.LiveshareSession) (Invoker,
|
|||
|
||||
// Tunnel the remote gRPC server port to the local port
|
||||
go func() {
|
||||
fwd := liveshare.NewPortForwarder(session, codespacesInternalSessionName, codespacesInternalPort, true)
|
||||
ch <- fwd.ForwardToListener(pfctx, listener)
|
||||
// Start forwarding the port locally
|
||||
opts := portforwarder.ForwardPortOpts{
|
||||
Port: codespacesInternalPort,
|
||||
Internal: true,
|
||||
}
|
||||
ch <- fwd.ForwardPortToListener(pfctx, opts, listener)
|
||||
}()
|
||||
|
||||
var conn *grpc.ClientConn
|
||||
|
|
@ -262,7 +266,7 @@ func (i *invoker) heartbeat(ctx context.Context, interval time.Duration) {
|
|||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
reason := i.session.GetKeepAliveReason()
|
||||
reason := i.fwd.GetKeepAliveReason()
|
||||
_ = i.notifyCodespaceOfClientActivity(ctx, reason)
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -72,7 +72,8 @@ func createTestInvoker(t *testing.T, server *mockServer) (Invoker, func(), error
|
|||
listener.Close()
|
||||
}
|
||||
|
||||
invoker, err := CreateInvoker(context.Background(), &rpctest.Session{})
|
||||
// Create a new invoker with a mock port forwarder
|
||||
invoker, err := CreateInvoker(context.Background(), rpctest.PortForwarder{})
|
||||
if err != nil {
|
||||
close()
|
||||
return nil, nil, fmt.Errorf("error connecting to internal server: %w", err)
|
||||
|
|
|
|||
|
|
@ -1,34 +0,0 @@
|
|||
package test
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net"
|
||||
)
|
||||
|
||||
type Channel struct {
|
||||
conn net.Conn
|
||||
}
|
||||
|
||||
func (c *Channel) Read(data []byte) (int, error) {
|
||||
return c.conn.Read(data)
|
||||
}
|
||||
|
||||
func (c *Channel) Write(data []byte) (int, error) {
|
||||
return c.conn.Write(data)
|
||||
}
|
||||
|
||||
func (c *Channel) Close() error {
|
||||
return c.conn.Close()
|
||||
}
|
||||
|
||||
func (c *Channel) CloseWrite() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Channel) SendRequest(name string, wantReply bool, payload []byte) (bool, error) {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func (c *Channel) Stderr() io.ReadWriter {
|
||||
return nil
|
||||
}
|
||||
78
internal/codespaces/rpc/test/port_forwarder.go
Normal file
78
internal/codespaces/rpc/test/port_forwarder.go
Normal file
|
|
@ -0,0 +1,78 @@
|
|||
package test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
|
||||
"github.com/cli/cli/v2/internal/codespaces/portforwarder"
|
||||
"github.com/microsoft/dev-tunnels/go/tunnels"
|
||||
)
|
||||
|
||||
type PortForwarder struct{}
|
||||
|
||||
// Close implements portforwarder.PortForwarder.
|
||||
func (PortForwarder) CloseSSHConnection() {
|
||||
panic("unimplemented")
|
||||
}
|
||||
|
||||
// ConnectToForwardedPort implements portforwarder.PortForwarder.
|
||||
func (PortForwarder) ConnectToForwardedPort(ctx context.Context, conn io.ReadWriteCloser, opts portforwarder.ForwardPortOpts) error {
|
||||
panic("unimplemented")
|
||||
}
|
||||
|
||||
// ForwardPort implements portforwarder.PortForwarder.
|
||||
func (PortForwarder) ForwardPort(ctx context.Context, opts portforwarder.ForwardPortOpts) error {
|
||||
panic("unimplemented")
|
||||
}
|
||||
|
||||
// GetKeepAliveReason implements portforwarder.PortForwarder.
|
||||
func (PortForwarder) GetKeepAliveReason() string {
|
||||
panic("unimplemented")
|
||||
}
|
||||
|
||||
// KeepAlive implements portforwarder.PortForwarder.
|
||||
func (PortForwarder) KeepAlive(reason string) {
|
||||
panic("unimplemented")
|
||||
}
|
||||
|
||||
// ForwardPortToListener implements portforwarder.PortForwarder.
|
||||
func (PortForwarder) ForwardPortToListener(ctx context.Context, opts portforwarder.ForwardPortOpts, listener *net.TCPListener) error {
|
||||
// Start forwarding the port locally
|
||||
hostConn, err := net.Dial("tcp", fmt.Sprintf("127.0.0.1:%d", opts.Port))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Accept the connection from the listener
|
||||
listenerConn, err := listener.Accept()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Copy data between the two connections
|
||||
go func() {
|
||||
_, _ = io.Copy(hostConn, listenerConn)
|
||||
hostConn.Close()
|
||||
}()
|
||||
go func() {
|
||||
_, _ = io.Copy(listenerConn, hostConn)
|
||||
listenerConn.Close()
|
||||
}()
|
||||
|
||||
// ForwardPortToListener typically blocks until the context is cancelled so we need to do the same
|
||||
<-ctx.Done()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ListPorts implements portforwarder.PortForwarder.
|
||||
func (PortForwarder) ListPorts(ctx context.Context) ([]*tunnels.TunnelPort, error) {
|
||||
panic("unimplemented")
|
||||
}
|
||||
|
||||
// UpdatePortVisibility implements portforwarder.PortForwarder.
|
||||
func (PortForwarder) UpdatePortVisibility(ctx context.Context, remotePort int, visibility string) error {
|
||||
panic("unimplemented")
|
||||
}
|
||||
|
|
@ -1,43 +0,0 @@
|
|||
package test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
|
||||
"github.com/cli/cli/v2/pkg/liveshare"
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
type Session struct {
|
||||
channel ssh.Channel
|
||||
}
|
||||
|
||||
func (*Session) Close() error {
|
||||
panic("unimplemented")
|
||||
}
|
||||
|
||||
func (*Session) GetSharedServers(context.Context) ([]*liveshare.Port, error) {
|
||||
panic("unimplemented")
|
||||
}
|
||||
|
||||
func (s *Session) KeepAlive(reason string) {
|
||||
}
|
||||
|
||||
func (s *Session) GetKeepAliveReason() string {
|
||||
return ""
|
||||
}
|
||||
|
||||
func (s *Session) StartSharing(ctx context.Context, sessionName string, port int) (liveshare.ChannelID, error) {
|
||||
conn, err := net.Dial("tcp", fmt.Sprintf("127.0.0.1:%d", port))
|
||||
if err != nil {
|
||||
return liveshare.ChannelID{}, err
|
||||
}
|
||||
s.channel = &Channel{conn}
|
||||
return liveshare.ChannelID{}, nil
|
||||
}
|
||||
|
||||
// Creates mock SSH channel connected to the mock gRPC server
|
||||
func (s *Session) OpenStreamingChannel(ctx context.Context, id liveshare.ChannelID) (ssh.Channel, error) {
|
||||
return s.channel, nil
|
||||
}
|
||||
|
|
@ -6,13 +6,12 @@ import (
|
|||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"time"
|
||||
|
||||
"github.com/cli/cli/v2/internal/codespaces/api"
|
||||
"github.com/cli/cli/v2/internal/codespaces/portforwarder"
|
||||
"github.com/cli/cli/v2/internal/codespaces/rpc"
|
||||
"github.com/cli/cli/v2/internal/text"
|
||||
"github.com/cli/cli/v2/pkg/liveshare"
|
||||
)
|
||||
|
||||
// PostCreateStateStatus is a string value representing the different statuses a state can have.
|
||||
|
|
@ -39,17 +38,15 @@ type PostCreateState struct {
|
|||
// and calls the supplied poller for each batch of state changes.
|
||||
// It runs until it encounters an error, including cancellation of the context.
|
||||
func PollPostCreateStates(ctx context.Context, progress progressIndicator, apiClient apiClient, codespace *api.Codespace, poller func([]PostCreateState)) (err error) {
|
||||
noopLogger := log.New(io.Discard, "", 0)
|
||||
|
||||
session, err := ConnectToLiveshare(ctx, progress, noopLogger, apiClient, codespace)
|
||||
codespaceConnection, err := GetCodespaceConnection(ctx, progress, apiClient, codespace)
|
||||
if err != nil {
|
||||
return fmt.Errorf("connect to codespace: %w", err)
|
||||
return fmt.Errorf("error connecting to codespace: %w", err)
|
||||
}
|
||||
|
||||
fwd, err := portforwarder.NewPortForwarder(ctx, codespaceConnection)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create port forwarder: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
if closeErr := session.Close(); err == nil {
|
||||
err = closeErr
|
||||
}
|
||||
}()
|
||||
|
||||
// Ensure local port is listening before client (getPostCreateOutput) connects.
|
||||
listen, localPort, err := ListenTCP(0, false)
|
||||
|
|
@ -58,7 +55,7 @@ func PollPostCreateStates(ctx context.Context, progress progressIndicator, apiCl
|
|||
}
|
||||
|
||||
progress.StartProgressIndicatorWithLabel("Fetching SSH Details")
|
||||
invoker, err := rpc.CreateInvoker(ctx, session)
|
||||
invoker, err := rpc.CreateInvoker(ctx, fwd)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
@ -73,8 +70,11 @@ func PollPostCreateStates(ctx context.Context, progress progressIndicator, apiCl
|
|||
progress.StartProgressIndicatorWithLabel("Fetching status")
|
||||
tunnelClosed := make(chan error, 1) // buffered to avoid sender stuckness
|
||||
go func() {
|
||||
fwd := liveshare.NewPortForwarder(session, "sshd", remoteSSHServerPort, false)
|
||||
tunnelClosed <- fwd.ForwardToListener(ctx, listen) // error is non-nil
|
||||
opts := portforwarder.ForwardPortOpts{
|
||||
Port: remoteSSHServerPort,
|
||||
Internal: true,
|
||||
}
|
||||
tunnelClosed <- fwd.ForwardPortToListener(ctx, opts, listen)
|
||||
}()
|
||||
|
||||
t := time.NewTicker(1 * time.Second)
|
||||
|
|
|
|||
|
|
@ -17,10 +17,8 @@ import (
|
|||
"github.com/AlecAivazis/survey/v2/terminal"
|
||||
clicontext "github.com/cli/cli/v2/context"
|
||||
"github.com/cli/cli/v2/internal/browser"
|
||||
"github.com/cli/cli/v2/internal/codespaces"
|
||||
"github.com/cli/cli/v2/internal/codespaces/api"
|
||||
"github.com/cli/cli/v2/pkg/iostreams"
|
||||
"github.com/cli/cli/v2/pkg/liveshare"
|
||||
"github.com/spf13/cobra"
|
||||
"golang.org/x/term"
|
||||
)
|
||||
|
|
@ -65,28 +63,6 @@ func (a *App) RunWithProgress(label string, run func() error) error {
|
|||
return a.io.RunWithProgress(label, run)
|
||||
}
|
||||
|
||||
// Connects to a codespace using Live Share and returns that session
|
||||
func startLiveShareSession(ctx context.Context, codespace *api.Codespace, a *App, debug bool, debugFile string) (session *liveshare.Session, err error) {
|
||||
liveshareLogger := noopLogger()
|
||||
if debug {
|
||||
debugLogger, err := newFileLogger(debugFile)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("couldn't create file logger: %w", err)
|
||||
}
|
||||
defer safeClose(debugLogger, &err)
|
||||
|
||||
liveshareLogger = debugLogger.Logger
|
||||
a.errLogger.Printf("Debug file located at: %s", debugLogger.Name())
|
||||
}
|
||||
|
||||
session, err = codespaces.ConnectToLiveshare(ctx, a, liveshareLogger, a.apiClient, codespace)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to connect to Live Share: %w", err)
|
||||
}
|
||||
|
||||
return session, nil
|
||||
}
|
||||
|
||||
//go:generate moq -fmt goimports -rm -skip-ensure -out mock_api.go . apiClient
|
||||
type apiClient interface {
|
||||
ServerURL() string
|
||||
|
|
@ -201,10 +177,6 @@ func noArgsConstraint(cmd *cobra.Command, args []string) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func noopLogger() *log.Logger {
|
||||
return log.New(io.Discard, "", 0)
|
||||
}
|
||||
|
||||
type codespace struct {
|
||||
*api.Codespace
|
||||
}
|
||||
|
|
|
|||
|
|
@ -7,8 +7,8 @@ import (
|
|||
"strings"
|
||||
|
||||
"github.com/cli/cli/v2/internal/codespaces"
|
||||
"github.com/cli/cli/v2/internal/codespaces/portforwarder"
|
||||
"github.com/cli/cli/v2/internal/codespaces/rpc"
|
||||
"github.com/cli/cli/v2/pkg/liveshare"
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
|
|
@ -39,11 +39,15 @@ func (a *App) Jupyter(ctx context.Context, selector *CodespaceSelector) (err err
|
|||
return err
|
||||
}
|
||||
|
||||
session, err := startLiveShareSession(ctx, codespace, a, false, "")
|
||||
codespaceConnection, err := codespaces.GetCodespaceConnection(ctx, a, a.apiClient, codespace)
|
||||
if err != nil {
|
||||
return err
|
||||
return fmt.Errorf("error connecting to codespace: %w", err)
|
||||
}
|
||||
|
||||
fwd, err := portforwarder.NewPortForwarder(ctx, codespaceConnection)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create port forwarder: %w", err)
|
||||
}
|
||||
defer safeClose(session, &err)
|
||||
|
||||
var (
|
||||
invoker rpc.Invoker
|
||||
|
|
@ -51,7 +55,7 @@ func (a *App) Jupyter(ctx context.Context, selector *CodespaceSelector) (err err
|
|||
serverUrl string
|
||||
)
|
||||
err = a.RunWithProgress("Starting JupyterLab on codespace", func() (err error) {
|
||||
invoker, err = rpc.CreateInvoker(ctx, session)
|
||||
invoker, err = rpc.CreateInvoker(ctx, fwd)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
|
@ -76,8 +80,10 @@ func (a *App) Jupyter(ctx context.Context, selector *CodespaceSelector) (err err
|
|||
|
||||
tunnelClosed := make(chan error, 1)
|
||||
go func() {
|
||||
fwd := liveshare.NewPortForwarder(session, "jupyter", serverPort, true)
|
||||
tunnelClosed <- fwd.ForwardToListener(ctx, listen) // always non-nil
|
||||
opts := portforwarder.ForwardPortOpts{
|
||||
Port: serverPort,
|
||||
}
|
||||
tunnelClosed <- fwd.ForwardPortToListener(ctx, opts, listen)
|
||||
}()
|
||||
|
||||
// Server URL contains an authentication token that must be preserved
|
||||
|
|
|
|||
|
|
@ -5,8 +5,8 @@ import (
|
|||
"fmt"
|
||||
|
||||
"github.com/cli/cli/v2/internal/codespaces"
|
||||
"github.com/cli/cli/v2/internal/codespaces/portforwarder"
|
||||
"github.com/cli/cli/v2/internal/codespaces/rpc"
|
||||
"github.com/cli/cli/v2/pkg/liveshare"
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
|
|
@ -42,11 +42,15 @@ func (a *App) Logs(ctx context.Context, selector *CodespaceSelector, follow bool
|
|||
return err
|
||||
}
|
||||
|
||||
session, err := startLiveShareSession(ctx, codespace, a, false, "")
|
||||
codespaceConnection, err := codespaces.GetCodespaceConnection(ctx, a, a.apiClient, codespace)
|
||||
if err != nil {
|
||||
return err
|
||||
return fmt.Errorf("error connecting to codespace: %w", err)
|
||||
}
|
||||
|
||||
fwd, err := portforwarder.NewPortForwarder(ctx, codespaceConnection)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create port forwarder: %w", err)
|
||||
}
|
||||
defer safeClose(session, &err)
|
||||
|
||||
// Ensure local port is listening before client (getPostCreateOutput) connects.
|
||||
listen, localPort, err := codespaces.ListenTCP(0, false)
|
||||
|
|
@ -57,7 +61,7 @@ func (a *App) Logs(ctx context.Context, selector *CodespaceSelector, follow bool
|
|||
|
||||
remoteSSHServerPort, sshUser := 0, ""
|
||||
err = a.RunWithProgress("Fetching SSH Details", func() (err error) {
|
||||
invoker, err := rpc.CreateInvoker(ctx, session)
|
||||
invoker, err := rpc.CreateInvoker(ctx, fwd)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
|
@ -85,8 +89,11 @@ func (a *App) Logs(ctx context.Context, selector *CodespaceSelector, follow bool
|
|||
|
||||
tunnelClosed := make(chan error, 1)
|
||||
go func() {
|
||||
fwd := liveshare.NewPortForwarder(session, "sshd", remoteSSHServerPort, false)
|
||||
tunnelClosed <- fwd.ForwardToListener(ctx, listen) // error is non-nil
|
||||
opts := portforwarder.ForwardPortOpts{
|
||||
Port: remoteSSHServerPort,
|
||||
Internal: true,
|
||||
}
|
||||
tunnelClosed <- fwd.ForwardPortToListener(ctx, opts, listen)
|
||||
}()
|
||||
|
||||
cmdDone := make(chan error, 1)
|
||||
|
|
|
|||
|
|
@ -345,7 +345,11 @@ func (a *App) ForwardPorts(ctx context.Context, selector *CodespaceSelector, por
|
|||
if err != nil {
|
||||
return fmt.Errorf("failed to create port forwarder: %w", err)
|
||||
}
|
||||
return fwd.ForwardAndConnectToPort(ctx, uint16(pair.remote), listen, false, false)
|
||||
|
||||
opts := portforwarder.ForwardPortOpts{
|
||||
Port: pair.remote,
|
||||
}
|
||||
return fwd.ForwardPortToListener(ctx, opts, listen)
|
||||
})
|
||||
}
|
||||
return group.Wait() // first error
|
||||
|
|
|
|||
|
|
@ -4,7 +4,9 @@ import (
|
|||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/cli/cli/v2/internal/codespaces"
|
||||
"github.com/cli/cli/v2/internal/codespaces/api"
|
||||
"github.com/cli/cli/v2/internal/codespaces/portforwarder"
|
||||
"github.com/cli/cli/v2/internal/codespaces/rpc"
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
|
@ -49,13 +51,17 @@ func (a *App) Rebuild(ctx context.Context, selector *CodespaceSelector, full boo
|
|||
return nil
|
||||
}
|
||||
|
||||
session, err := startLiveShareSession(ctx, codespace, a, false, "")
|
||||
codespaceConnection, err := codespaces.GetCodespaceConnection(ctx, a, a.apiClient, codespace)
|
||||
if err != nil {
|
||||
return fmt.Errorf("starting Live Share session: %w", err)
|
||||
return fmt.Errorf("error connecting to codespace: %w", err)
|
||||
}
|
||||
defer safeClose(session, &err)
|
||||
|
||||
invoker, err := rpc.CreateInvoker(ctx, session)
|
||||
fwd, err := portforwarder.NewPortForwarder(ctx, codespaceConnection)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create port forwarder: %w", err)
|
||||
}
|
||||
|
||||
invoker, err := rpc.CreateInvoker(ctx, fwd)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ import (
|
|||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"io"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path"
|
||||
|
|
@ -18,10 +18,10 @@ import (
|
|||
"github.com/MakeNowJust/heredoc"
|
||||
"github.com/cli/cli/v2/internal/codespaces"
|
||||
"github.com/cli/cli/v2/internal/codespaces/api"
|
||||
"github.com/cli/cli/v2/internal/codespaces/portforwarder"
|
||||
"github.com/cli/cli/v2/internal/codespaces/rpc"
|
||||
"github.com/cli/cli/v2/internal/config"
|
||||
"github.com/cli/cli/v2/pkg/cmdutil"
|
||||
"github.com/cli/cli/v2/pkg/liveshare"
|
||||
"github.com/cli/cli/v2/pkg/ssh"
|
||||
"github.com/cli/safeexec"
|
||||
"github.com/spf13/cobra"
|
||||
|
|
@ -144,6 +144,24 @@ func newSSHCmd(app *App) *cobra.Command {
|
|||
return sshCmd
|
||||
}
|
||||
|
||||
type combinedReadWriteHalfCloser struct {
|
||||
io.ReadCloser
|
||||
io.WriteCloser
|
||||
}
|
||||
|
||||
func (crwc *combinedReadWriteHalfCloser) Close() error {
|
||||
werr := crwc.WriteCloser.Close()
|
||||
rerr := crwc.ReadCloser.Close()
|
||||
if werr != nil {
|
||||
return werr
|
||||
}
|
||||
return rerr
|
||||
}
|
||||
|
||||
func (crwc *combinedReadWriteHalfCloser) CloseWrite() error {
|
||||
return crwc.WriteCloser.Close()
|
||||
}
|
||||
|
||||
// SSH opens an ssh session or runs an ssh command in a codespace.
|
||||
func (a *App) SSH(ctx context.Context, sshArgs []string, opts sshOptions) (err error) {
|
||||
// Ensure all child tasks (e.g. port forwarding) terminate before return.
|
||||
|
|
@ -175,11 +193,15 @@ func (a *App) SSH(ctx context.Context, sshArgs []string, opts sshOptions) (err e
|
|||
return err
|
||||
}
|
||||
|
||||
session, err := startLiveShareSession(ctx, codespace, a, opts.debug, opts.debugFile)
|
||||
codespaceConnection, err := codespaces.GetCodespaceConnection(ctx, a, a.apiClient, codespace)
|
||||
if err != nil {
|
||||
return err
|
||||
return fmt.Errorf("error connecting to codespace: %w", err)
|
||||
}
|
||||
|
||||
fwd, err := portforwarder.NewPortForwarder(ctx, codespaceConnection)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create port forwarder: %w", err)
|
||||
}
|
||||
defer safeClose(session, &err)
|
||||
|
||||
var (
|
||||
invoker rpc.Invoker
|
||||
|
|
@ -187,7 +209,7 @@ func (a *App) SSH(ctx context.Context, sshArgs []string, opts sshOptions) (err e
|
|||
sshUser string
|
||||
)
|
||||
err = a.RunWithProgress("Fetching SSH Details", func() (err error) {
|
||||
invoker, err = rpc.CreateInvoker(ctx, session)
|
||||
invoker, err = rpc.CreateInvoker(ctx, fwd)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
|
@ -203,9 +225,28 @@ func (a *App) SSH(ctx context.Context, sshArgs []string, opts sshOptions) (err e
|
|||
}
|
||||
|
||||
if opts.stdio {
|
||||
fwd := liveshare.NewPortForwarder(session, "sshd", remoteSSHServerPort, true)
|
||||
stdio := liveshare.NewReadWriteHalfCloser(os.Stdin, os.Stdout)
|
||||
err := fwd.Forward(ctx, stdio) // always non-nil
|
||||
stdio := &combinedReadWriteHalfCloser{os.Stdin, os.Stdout}
|
||||
opts := portforwarder.ForwardPortOpts{
|
||||
Port: remoteSSHServerPort,
|
||||
Internal: true,
|
||||
KeepAlive: true,
|
||||
}
|
||||
|
||||
// Forward the port
|
||||
err = fwd.ForwardPort(ctx, opts)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to forward port: %w", err)
|
||||
}
|
||||
|
||||
// Close the SSH connection when we're done
|
||||
defer fwd.CloseSSHConnection()
|
||||
|
||||
// Connect to the forwarded port
|
||||
err = fwd.ConnectToForwardedPort(ctx, stdio, opts)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to connect to forwarded port: %w", err)
|
||||
}
|
||||
|
||||
return fmt.Errorf("tunnel closed: %w", err)
|
||||
}
|
||||
|
||||
|
|
@ -227,8 +268,12 @@ func (a *App) SSH(ctx context.Context, sshArgs []string, opts sshOptions) (err e
|
|||
|
||||
tunnelClosed := make(chan error, 1)
|
||||
go func() {
|
||||
fwd := liveshare.NewPortForwarder(session, "sshd", remoteSSHServerPort, true)
|
||||
tunnelClosed <- fwd.ForwardToListener(ctx, listen) // always non-nil
|
||||
opts := portforwarder.ForwardPortOpts{
|
||||
Port: remoteSSHServerPort,
|
||||
Internal: true,
|
||||
KeepAlive: true,
|
||||
}
|
||||
tunnelClosed <- fwd.ForwardPortToListener(ctx, opts, listen)
|
||||
}()
|
||||
|
||||
shellClosed := make(chan error, 1)
|
||||
|
|
@ -526,27 +571,36 @@ func (a *App) printOpenSSHConfig(ctx context.Context, opts sshOptions) (err erro
|
|||
result := sshResult{}
|
||||
defer wg.Done()
|
||||
|
||||
session, err := codespaces.ConnectToLiveshare(ctx, a, noopLogger(), a.apiClient, cs)
|
||||
codespaceConnection, err := codespaces.GetCodespaceConnection(ctx, a, a.apiClient, cs)
|
||||
if err != nil {
|
||||
result.err = fmt.Errorf("error connecting to codespace: %w", err)
|
||||
} else {
|
||||
defer safeClose(session, &err)
|
||||
|
||||
invoker, err := rpc.CreateInvoker(ctx, session)
|
||||
if err != nil {
|
||||
result.err = fmt.Errorf("error connecting to codespace: %w", err)
|
||||
} else {
|
||||
defer safeClose(invoker, &err)
|
||||
|
||||
_, result.user, err = invoker.StartSSHServer(ctx)
|
||||
if err != nil {
|
||||
result.err = fmt.Errorf("error getting ssh server details: %w", err)
|
||||
} else {
|
||||
result.codespace = cs
|
||||
}
|
||||
}
|
||||
sshUsers <- result
|
||||
return
|
||||
}
|
||||
|
||||
fwd, err := portforwarder.NewPortForwarder(ctx, codespaceConnection)
|
||||
if err != nil {
|
||||
result.err = fmt.Errorf("failed to create port forwarder: %w", err)
|
||||
sshUsers <- result
|
||||
return
|
||||
}
|
||||
|
||||
invoker, err := rpc.CreateInvoker(ctx, fwd)
|
||||
if err != nil {
|
||||
result.err = fmt.Errorf("error connecting to codespace: %w", err)
|
||||
sshUsers <- result
|
||||
return
|
||||
}
|
||||
defer safeClose(invoker, &err)
|
||||
|
||||
_, result.user, err = invoker.StartSSHServer(ctx)
|
||||
if err != nil {
|
||||
result.err = fmt.Errorf("error getting ssh server details: %w", err)
|
||||
sshUsers <- result
|
||||
return
|
||||
}
|
||||
|
||||
result.codespace = cs
|
||||
sshUsers <- result
|
||||
}()
|
||||
}
|
||||
|
|
@ -722,43 +776,3 @@ func (a *App) Copy(ctx context.Context, args []string, opts cpOptions) error {
|
|||
}
|
||||
return a.SSH(ctx, nil, opts.sshOptions)
|
||||
}
|
||||
|
||||
// fileLogger is a wrapper around an log.Logger configured to write
|
||||
// to a file. It exports two additional methods to get the log file name
|
||||
// and close the file handle when the operation is finished.
|
||||
type fileLogger struct {
|
||||
*log.Logger
|
||||
|
||||
f *os.File
|
||||
}
|
||||
|
||||
// newFileLogger creates a new fileLogger. It returns an error if the file
|
||||
// cannot be created. The file is created on the specified path, if the path
|
||||
// is empty it is created in the temporary directory.
|
||||
func newFileLogger(file string) (fl *fileLogger, err error) {
|
||||
var f *os.File
|
||||
if file == "" {
|
||||
f, err = os.CreateTemp("", "")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create tmp file: %w", err)
|
||||
}
|
||||
} else {
|
||||
f, err = os.Create(file)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return &fileLogger{
|
||||
Logger: log.New(f, "", log.LstdFlags),
|
||||
f: f,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (fl *fileLogger) Name() string {
|
||||
return fl.f.Name()
|
||||
}
|
||||
|
||||
func (fl *fileLogger) Close() error {
|
||||
return fl.f.Close()
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,130 +0,0 @@
|
|||
// Package liveshare is a Go client library for the Visual Studio Live Share
|
||||
// service, which provides collaborative, distributed editing and debugging.
|
||||
// See https://docs.microsoft.com/en-us/visualstudio/liveshare for an overview.
|
||||
//
|
||||
// It provides the ability for a Go program to connect to a Live Share
|
||||
// workspace (Connect), to expose a TCP port on a remote host
|
||||
// (UpdateSharedVisibility), to start an SSH server listening on an
|
||||
// exposed port (StartSSHServer), and to forward connections between
|
||||
// the remote port and a local listening TCP port (ForwardToListener)
|
||||
// or a local Go reader/writer (Forward).
|
||||
package liveshare
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"strings"
|
||||
|
||||
"github.com/opentracing/opentracing-go"
|
||||
)
|
||||
|
||||
type logger interface {
|
||||
Println(v ...interface{})
|
||||
Printf(f string, v ...interface{})
|
||||
}
|
||||
|
||||
// An Options specifies Live Share connection parameters.
|
||||
type Options struct {
|
||||
SessionID string
|
||||
SessionToken string // token for SSH session
|
||||
RelaySAS string
|
||||
RelayEndpoint string
|
||||
HostPublicKeys []string
|
||||
Logger logger // required
|
||||
TLSConfig *tls.Config // (optional)
|
||||
}
|
||||
|
||||
// uri returns a websocket URL for the specified options.
|
||||
func (opts *Options) uri(action string) (string, error) {
|
||||
if opts.SessionID == "" {
|
||||
return "", errors.New("SessionID is required")
|
||||
}
|
||||
if opts.RelaySAS == "" {
|
||||
return "", errors.New("RelaySAS is required")
|
||||
}
|
||||
if opts.RelayEndpoint == "" {
|
||||
return "", errors.New("RelayEndpoint is required")
|
||||
}
|
||||
|
||||
sas := url.QueryEscape(opts.RelaySAS)
|
||||
uri := opts.RelayEndpoint
|
||||
|
||||
if strings.HasPrefix(uri, "http:") {
|
||||
uri = strings.Replace(uri, "http:", "ws:", 1)
|
||||
} else {
|
||||
uri = strings.Replace(uri, "sb:", "wss:", -1)
|
||||
}
|
||||
|
||||
uri = strings.Replace(uri, ".net/", ".net:443/$hc/", 1)
|
||||
uri = uri + "?sb-hc-action=" + action + "&sb-hc-token=" + sas
|
||||
return uri, nil
|
||||
}
|
||||
|
||||
// Connect connects to a Live Share workspace specified by the
|
||||
// options, and returns a session representing the connection.
|
||||
// The caller must call the session's Close method to end the session.
|
||||
func Connect(ctx context.Context, opts Options) (*Session, error) {
|
||||
span, ctx := opentracing.StartSpanFromContext(ctx, "Connect")
|
||||
defer span.Finish()
|
||||
|
||||
uri, err := opts.uri("connect")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
sock := newSocket(uri, opts.TLSConfig)
|
||||
if err := sock.connect(ctx); err != nil {
|
||||
return nil, fmt.Errorf("error connecting websocket: %w", err)
|
||||
}
|
||||
|
||||
if opts.SessionToken == "" {
|
||||
return nil, errors.New("SessionToken is required")
|
||||
}
|
||||
ssh := newSSHSession(opts.SessionToken, opts.HostPublicKeys, sock)
|
||||
if err := ssh.connect(ctx); err != nil {
|
||||
return nil, fmt.Errorf("error connecting to ssh session: %w", err)
|
||||
}
|
||||
|
||||
rpc := newRPCClient(ssh)
|
||||
rpc.connect(ctx)
|
||||
|
||||
args := joinWorkspaceArgs{
|
||||
ID: opts.SessionID,
|
||||
ConnectionMode: "local",
|
||||
JoiningUserSessionToken: opts.SessionToken,
|
||||
ClientCapabilities: clientCapabilities{
|
||||
IsNonInteractive: false,
|
||||
},
|
||||
}
|
||||
var result joinWorkspaceResult
|
||||
if err := rpc.do(ctx, "workspace.joinWorkspace", &args, &result); err != nil {
|
||||
return nil, fmt.Errorf("error joining Live Share workspace: %w", err)
|
||||
}
|
||||
|
||||
s := &Session{
|
||||
ssh: ssh,
|
||||
rpc: rpc,
|
||||
keepAliveReason: make(chan string, 1),
|
||||
logger: opts.Logger,
|
||||
}
|
||||
|
||||
return s, nil
|
||||
}
|
||||
|
||||
type clientCapabilities struct {
|
||||
IsNonInteractive bool `json:"isNonInteractive"`
|
||||
}
|
||||
|
||||
type joinWorkspaceArgs struct {
|
||||
ID string `json:"id"`
|
||||
ConnectionMode string `json:"connectionMode"`
|
||||
JoiningUserSessionToken string `json:"joiningUserSessionToken"`
|
||||
ClientCapabilities clientCapabilities `json:"clientCapabilities"`
|
||||
}
|
||||
|
||||
type joinWorkspaceResult struct {
|
||||
SessionNumber int `json:"sessionNumber"`
|
||||
}
|
||||
|
|
@ -1,73 +0,0 @@
|
|||
package liveshare
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
livesharetest "github.com/cli/cli/v2/pkg/liveshare/test"
|
||||
"github.com/sourcegraph/jsonrpc2"
|
||||
)
|
||||
|
||||
func TestConnect(t *testing.T) {
|
||||
opts := Options{
|
||||
SessionID: "session-id",
|
||||
SessionToken: "session-token",
|
||||
RelaySAS: "relay-sas",
|
||||
HostPublicKeys: []string{livesharetest.SSHPublicKey},
|
||||
Logger: newMockLogger(),
|
||||
}
|
||||
joinWorkspace := func(conn *jsonrpc2.Conn, req *jsonrpc2.Request) (interface{}, error) {
|
||||
var joinWorkspaceReq joinWorkspaceArgs
|
||||
if err := json.Unmarshal(*req.Params, &joinWorkspaceReq); err != nil {
|
||||
return nil, fmt.Errorf("error unmarshalling req: %w", err)
|
||||
}
|
||||
if joinWorkspaceReq.ID != opts.SessionID {
|
||||
return nil, errors.New("connection session id does not match")
|
||||
}
|
||||
if joinWorkspaceReq.ConnectionMode != "local" {
|
||||
return nil, errors.New("connection mode is not local")
|
||||
}
|
||||
if joinWorkspaceReq.JoiningUserSessionToken != opts.SessionToken {
|
||||
return nil, errors.New("connection user token does not match")
|
||||
}
|
||||
if joinWorkspaceReq.ClientCapabilities.IsNonInteractive != false {
|
||||
return nil, errors.New("non interactive is not false")
|
||||
}
|
||||
return joinWorkspaceResult{1}, nil
|
||||
}
|
||||
|
||||
server, err := livesharetest.NewServer(
|
||||
livesharetest.WithPassword(opts.SessionToken),
|
||||
livesharetest.WithService("workspace.joinWorkspace", joinWorkspace),
|
||||
livesharetest.WithRelaySAS(opts.RelaySAS),
|
||||
)
|
||||
if err != nil {
|
||||
t.Errorf("error creating Live Share server: %v", err)
|
||||
}
|
||||
defer server.Close()
|
||||
opts.RelayEndpoint = "sb" + strings.TrimPrefix(server.URL(), "https")
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
opts.TLSConfig = &tls.Config{InsecureSkipVerify: true}
|
||||
|
||||
done := make(chan error)
|
||||
go func() {
|
||||
_, err := Connect(ctx, opts) // ignore session
|
||||
done <- err
|
||||
}()
|
||||
|
||||
select {
|
||||
case err := <-server.Err():
|
||||
t.Errorf("error from server: %v", err)
|
||||
case err := <-done:
|
||||
if err != nil {
|
||||
t.Errorf("error from client: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -1,56 +0,0 @@
|
|||
package liveshare
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestBadOptions(t *testing.T) {
|
||||
goodOptions := Options{
|
||||
SessionID: "sess-id",
|
||||
SessionToken: "sess-token",
|
||||
RelaySAS: "sas",
|
||||
RelayEndpoint: "endpoint",
|
||||
}
|
||||
|
||||
opts := goodOptions
|
||||
opts.SessionID = ""
|
||||
checkBadOptions(t, opts)
|
||||
|
||||
opts = goodOptions
|
||||
opts.SessionToken = ""
|
||||
checkBadOptions(t, opts)
|
||||
|
||||
opts = goodOptions
|
||||
opts.RelaySAS = ""
|
||||
checkBadOptions(t, opts)
|
||||
|
||||
opts = goodOptions
|
||||
opts.RelayEndpoint = ""
|
||||
checkBadOptions(t, opts)
|
||||
|
||||
opts = Options{}
|
||||
checkBadOptions(t, opts)
|
||||
}
|
||||
|
||||
func checkBadOptions(t *testing.T, opts Options) {
|
||||
if _, err := Connect(context.Background(), opts); err == nil {
|
||||
t.Errorf("Connect(%+v): no error", opts)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOptionsURI(t *testing.T) {
|
||||
opts := Options{
|
||||
SessionID: "sess-id",
|
||||
SessionToken: "sess-token",
|
||||
RelaySAS: "sas",
|
||||
RelayEndpoint: "sb://endpoint/.net/liveshare",
|
||||
}
|
||||
uri, err := opts.uri("connect")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if uri != "wss://endpoint/.net:443/$hc/liveshare?sb-hc-action=connect&sb-hc-token=sas" {
|
||||
t.Errorf("uri is not correct, got: '%v'", uri)
|
||||
}
|
||||
}
|
||||
|
|
@ -1,241 +0,0 @@
|
|||
package liveshare
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
|
||||
"github.com/opentracing/opentracing-go"
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
type portForwardingSession interface {
|
||||
StartSharing(context.Context, string, int) (ChannelID, error)
|
||||
OpenStreamingChannel(context.Context, ChannelID) (ssh.Channel, error)
|
||||
KeepAlive(string)
|
||||
}
|
||||
|
||||
type ReadWriteHalfCloser interface {
|
||||
io.ReadWriteCloser
|
||||
CloseWrite() error
|
||||
}
|
||||
|
||||
type combinedReadWriteHalfCloser struct {
|
||||
io.ReadCloser
|
||||
io.WriteCloser
|
||||
}
|
||||
|
||||
func NewReadWriteHalfCloser(reader io.ReadCloser, writer io.WriteCloser) ReadWriteHalfCloser {
|
||||
return &combinedReadWriteHalfCloser{reader, writer}
|
||||
}
|
||||
|
||||
func (crwc *combinedReadWriteHalfCloser) Close() error {
|
||||
werr := crwc.WriteCloser.Close()
|
||||
rerr := crwc.ReadCloser.Close()
|
||||
if werr != nil {
|
||||
return werr
|
||||
}
|
||||
return rerr
|
||||
}
|
||||
|
||||
func (crwc *combinedReadWriteHalfCloser) CloseWrite() error {
|
||||
return crwc.WriteCloser.Close()
|
||||
}
|
||||
|
||||
// A PortForwarder forwards TCP traffic over a Live Share session from a port on a remote
|
||||
// container to a local destination such as a network port or Go reader/writer.
|
||||
type PortForwarder struct {
|
||||
session portForwardingSession
|
||||
name string
|
||||
remotePort int
|
||||
keepAlive bool
|
||||
}
|
||||
|
||||
// NewPortForwarder returns a new PortForwarder for the specified
|
||||
// remote port and Live Share session. The name describes the purpose
|
||||
// of the remote port or service. The keepAlive flag indicates whether
|
||||
// the session should be kept alive with port forwarding traffic.
|
||||
func NewPortForwarder(session portForwardingSession, name string, remotePort int, keepAlive bool) *PortForwarder {
|
||||
return &PortForwarder{
|
||||
session: session,
|
||||
name: name,
|
||||
remotePort: remotePort,
|
||||
keepAlive: keepAlive,
|
||||
}
|
||||
}
|
||||
|
||||
// ForwardToListener forwards traffic between the container's remote
|
||||
// port and a local port, which must already be listening for
|
||||
// connections. (Accepting a listener rather than a port number avoids
|
||||
// races against other processes opening ports, and against a client
|
||||
// connecting to the socket prematurely.)
|
||||
//
|
||||
// ForwardToListener accepts and handles connections on the local port
|
||||
// until it encounters the first error, which may include context
|
||||
// cancellation. Its error result is always non-nil. The caller is
|
||||
// responsible for closing the listening port.
|
||||
func (fwd *PortForwarder) ForwardToListener(ctx context.Context, listen *net.TCPListener) (err error) {
|
||||
id, err := fwd.shareRemotePort(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
errc := make(chan error, 1)
|
||||
sendError := func(err error) {
|
||||
// Use non-blocking send, to avoid goroutines getting
|
||||
// stuck in case of concurrent or sequential errors.
|
||||
select {
|
||||
case errc <- err:
|
||||
default:
|
||||
}
|
||||
}
|
||||
go func() {
|
||||
for {
|
||||
conn, err := listen.AcceptTCP()
|
||||
if err != nil {
|
||||
sendError(err)
|
||||
return
|
||||
}
|
||||
|
||||
go func() {
|
||||
if err := fwd.handleConnection(ctx, id, conn); err != nil {
|
||||
sendError(err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
}()
|
||||
|
||||
return awaitError(ctx, errc)
|
||||
}
|
||||
|
||||
// Forward forwards traffic between the container's remote port and
|
||||
// the specified read/write stream. On return, the stream is closed.
|
||||
func (fwd *PortForwarder) Forward(ctx context.Context, conn ReadWriteHalfCloser) error {
|
||||
id, err := fwd.shareRemotePort(ctx)
|
||||
if err != nil {
|
||||
conn.Close()
|
||||
return err
|
||||
}
|
||||
|
||||
// Create buffered channel so that send doesn't get stuck after context cancellation.
|
||||
errc := make(chan error, 1)
|
||||
go func() {
|
||||
errc <- fwd.handleConnection(ctx, id, conn)
|
||||
}()
|
||||
return awaitError(ctx, errc)
|
||||
}
|
||||
|
||||
func (fwd *PortForwarder) shareRemotePort(ctx context.Context) (ChannelID, error) {
|
||||
id, err := fwd.session.StartSharing(ctx, fwd.name, fwd.remotePort)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("failed to share remote port %d: %w", fwd.remotePort, err)
|
||||
}
|
||||
|
||||
return id, err
|
||||
}
|
||||
|
||||
func awaitError(ctx context.Context, errc <-chan error) error {
|
||||
select {
|
||||
case err := <-errc:
|
||||
return err
|
||||
case <-ctx.Done():
|
||||
return ctx.Err() // canceled
|
||||
}
|
||||
}
|
||||
|
||||
type trafficMonitorSession interface {
|
||||
KeepAlive(string)
|
||||
}
|
||||
|
||||
// trafficMonitor implements io.Reader. It keeps the session alive by notifying
|
||||
// it of the traffic type during Read operations.
|
||||
type trafficMonitor struct {
|
||||
reader io.Reader
|
||||
|
||||
session trafficMonitorSession
|
||||
trafficType string
|
||||
}
|
||||
|
||||
// newTrafficMonitor returns a new trafficMonitor for the specified
|
||||
// session and traffic type. It wraps the provided io.Reader with its own
|
||||
// Read method.
|
||||
func newTrafficMonitor(reader io.Reader, session trafficMonitorSession, trafficType string) *trafficMonitor {
|
||||
return &trafficMonitor{reader, session, trafficType}
|
||||
}
|
||||
|
||||
func (t *trafficMonitor) Read(p []byte) (n int, err error) {
|
||||
t.session.KeepAlive(t.trafficType)
|
||||
return t.reader.Read(p)
|
||||
}
|
||||
|
||||
// handleConnection handles forwarding for a single accepted connection, then closes it.
|
||||
func (fwd *PortForwarder) handleConnection(ctx context.Context, id ChannelID, conn ReadWriteHalfCloser) (err error) {
|
||||
span, ctx := opentracing.StartSpanFromContext(ctx, "PortForwarder.handleConnection")
|
||||
defer span.Finish()
|
||||
|
||||
defer safeClose(conn, &err)
|
||||
|
||||
channel, err := fwd.session.OpenStreamingChannel(ctx, id)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error opening streaming channel for new connection: %w", err)
|
||||
}
|
||||
// Ideally we would call safeClose again, but (*ssh.channel).Close
|
||||
// appears to have a bug that causes it return io.EOF spuriously
|
||||
// if its peer closed first; see github.com/golang/go/issues/38115.
|
||||
defer func() {
|
||||
closeErr := channel.Close()
|
||||
if err == nil && closeErr != io.EOF {
|
||||
err = closeErr
|
||||
}
|
||||
}()
|
||||
|
||||
// bi-directional copy of data.
|
||||
errs := make(chan error, 2)
|
||||
copyConn := func(w ReadWriteHalfCloser, r io.Reader) {
|
||||
_, err := io.Copy(w, r)
|
||||
errs <- err
|
||||
|
||||
// Ignore errors here, we call the full Close() later and catch that error
|
||||
_ = w.CloseWrite()
|
||||
}
|
||||
|
||||
var (
|
||||
channelReader io.Reader = channel
|
||||
connReader io.Reader = conn
|
||||
)
|
||||
|
||||
// If the forwader has been configured to keep the session alive
|
||||
// it will monitor the I/O and notify the session of the traffic.
|
||||
if fwd.keepAlive {
|
||||
channelReader = newTrafficMonitor(channelReader, fwd.session, "output")
|
||||
connReader = newTrafficMonitor(connReader, fwd.session, "input")
|
||||
}
|
||||
|
||||
go copyConn(conn, channelReader)
|
||||
go copyConn(channel, connReader)
|
||||
|
||||
// Wait until context is cancelled or both copies are done.
|
||||
// Discard errors from io.Copy; they should not cause (e.g.) ForwardToListener to fail.
|
||||
// TODO: how can we proxy errors from Copy so that each peer can distinguish an error from a short file?
|
||||
for i := 0; ; {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case <-errs:
|
||||
i++
|
||||
if i == 2 {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// safeClose reports the error (to *err) from closing the stream only
|
||||
// if no other error was previously reported.
|
||||
func safeClose(closer io.Closer, err *error) {
|
||||
closeErr := closer.Close()
|
||||
if *err == nil {
|
||||
*err = closeErr
|
||||
}
|
||||
}
|
||||
|
|
@ -1,153 +0,0 @@
|
|||
package liveshare
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
livesharetest "github.com/cli/cli/v2/pkg/liveshare/test"
|
||||
"github.com/sourcegraph/jsonrpc2"
|
||||
)
|
||||
|
||||
func TestNewPortForwarder(t *testing.T) {
|
||||
testServer, session, err := makeMockSession()
|
||||
if err != nil {
|
||||
t.Errorf("create mock client: %v", err)
|
||||
}
|
||||
defer testServer.Close()
|
||||
pf := NewPortForwarder(session, "ssh", 80, false)
|
||||
if pf == nil {
|
||||
t.Error("port forwarder is nil")
|
||||
}
|
||||
}
|
||||
|
||||
type portUpdateNotification struct {
|
||||
PortNotification
|
||||
conn *jsonrpc2.Conn
|
||||
}
|
||||
|
||||
func TestPortForwarderStart(t *testing.T) {
|
||||
if os.Getenv("GITHUB_ACTIONS") == "true" {
|
||||
t.Skip("fails intermittently in CI: https://github.com/cli/cli/issues/5338")
|
||||
}
|
||||
|
||||
streamName, streamCondition := "stream-name", "stream-condition"
|
||||
const port = 8000
|
||||
sendNotification := make(chan portUpdateNotification)
|
||||
serverSharing := func(conn *jsonrpc2.Conn, req *jsonrpc2.Request) (interface{}, error) {
|
||||
// Send the PortNotification that will be awaited on in session.StartSharing
|
||||
sendNotification <- portUpdateNotification{
|
||||
PortNotification: PortNotification{
|
||||
Port: port,
|
||||
ChangeKind: PortChangeKindStart,
|
||||
},
|
||||
conn: conn,
|
||||
}
|
||||
return Port{StreamName: streamName, StreamCondition: streamCondition}, nil
|
||||
}
|
||||
getStream := func(conn *jsonrpc2.Conn, req *jsonrpc2.Request) (interface{}, error) {
|
||||
return "stream-id", nil
|
||||
}
|
||||
|
||||
stream := bytes.NewBufferString("stream-data")
|
||||
testServer, session, err := makeMockSession(
|
||||
livesharetest.WithService("serverSharing.startSharing", serverSharing),
|
||||
livesharetest.WithService("streamManager.getStream", getStream),
|
||||
livesharetest.WithStream("stream-id", stream),
|
||||
)
|
||||
if err != nil {
|
||||
t.Errorf("create mock session: %v", err)
|
||||
}
|
||||
defer testServer.Close()
|
||||
|
||||
listen, err := net.Listen("tcp", "127.0.0.1:8000")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer listen.Close()
|
||||
tcpListener, ok := listen.(*net.TCPListener)
|
||||
if !ok {
|
||||
t.Fatal("net.Listen did not return a TCPListener")
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
go func() {
|
||||
notif := <-sendNotification
|
||||
_, _ = notif.conn.DispatchCall(context.Background(), "serverSharing.sharingSucceeded", notif)
|
||||
}()
|
||||
|
||||
done := make(chan error, 2)
|
||||
go func() {
|
||||
done <- NewPortForwarder(session, "ssh", port, false).ForwardToListener(ctx, tcpListener)
|
||||
}()
|
||||
|
||||
go func() {
|
||||
var conn net.Conn
|
||||
|
||||
// We retry DialTimeout in a loop to deal with a race in PortForwarder startup.
|
||||
for tries := 0; conn == nil && tries < 2; tries++ {
|
||||
conn, err = net.DialTimeout("tcp", ":8000", 2*time.Second)
|
||||
if conn == nil {
|
||||
time.Sleep(1 * time.Second)
|
||||
}
|
||||
}
|
||||
if conn == nil {
|
||||
done <- errors.New("failed to connect to forwarded port")
|
||||
return
|
||||
}
|
||||
b := make([]byte, len("stream-data"))
|
||||
if _, err := conn.Read(b); err != nil && err != io.EOF {
|
||||
done <- fmt.Errorf("reading stream: %w", err)
|
||||
return
|
||||
}
|
||||
if string(b) != "stream-data" {
|
||||
done <- fmt.Errorf("stream data is not expected value, got: %s", string(b))
|
||||
return
|
||||
}
|
||||
if _, err := conn.Write([]byte("new-data")); err != nil {
|
||||
done <- fmt.Errorf("writing to stream: %w", err)
|
||||
return
|
||||
}
|
||||
done <- nil
|
||||
}()
|
||||
|
||||
select {
|
||||
case err := <-testServer.Err():
|
||||
t.Errorf("error from server: %v", err)
|
||||
case err := <-done:
|
||||
if err != nil {
|
||||
t.Errorf("error from client: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestPortForwarderTrafficMonitor(t *testing.T) {
|
||||
buf := bytes.NewBufferString("some-input")
|
||||
session := &Session{keepAliveReason: make(chan string, 1)}
|
||||
trafficType := "io"
|
||||
|
||||
tm := newTrafficMonitor(buf, session, trafficType)
|
||||
l := len(buf.Bytes())
|
||||
|
||||
bb := make([]byte, l)
|
||||
n, err := tm.Read(bb)
|
||||
if err != nil {
|
||||
t.Errorf("failed to read from traffic monitor: %v", err)
|
||||
}
|
||||
if n != l {
|
||||
t.Errorf("expected to read %d bytes, got %d", l, n)
|
||||
}
|
||||
|
||||
keepAliveReason := <-session.keepAliveReason
|
||||
if keepAliveReason != trafficType {
|
||||
t.Errorf("expected keep alive reason to be %s, got %s", trafficType, keepAliveReason)
|
||||
}
|
||||
}
|
||||
|
|
@ -1,101 +0,0 @@
|
|||
package liveshare
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"github.com/sourcegraph/jsonrpc2"
|
||||
)
|
||||
|
||||
// Port describes a port exposed by the container.
|
||||
type Port struct {
|
||||
SourcePort int `json:"sourcePort"`
|
||||
DestinationPort int `json:"destinationPort"`
|
||||
SessionName string `json:"sessionName"`
|
||||
StreamName string `json:"streamName"`
|
||||
StreamCondition string `json:"streamCondition"`
|
||||
BrowseURL string `json:"browseUrl"`
|
||||
IsPublic bool `json:"isPublic"`
|
||||
IsTCPServerConnectionEstablished bool `json:"isTCPServerConnectionEstablished"`
|
||||
HasTLSHandshakePassed bool `json:"hasTLSHandshakePassed"`
|
||||
Privacy string `json:"privacy"`
|
||||
}
|
||||
|
||||
type PortChangeKind string
|
||||
|
||||
const (
|
||||
PortChangeKindStart PortChangeKind = "start"
|
||||
PortChangeKindUpdate PortChangeKind = "update"
|
||||
)
|
||||
|
||||
type PortNotification struct {
|
||||
Success bool // Helps us disambiguate between the SharingSucceeded/SharingFailed events
|
||||
// The following are properties included in the SharingSucceeded/SharingFailed events sent by the server sharing service in the Codespace
|
||||
Port int `json:"port"`
|
||||
ChangeKind PortChangeKind `json:"changeKind"`
|
||||
ErrorDetail string `json:"errorDetail"`
|
||||
StatusCode int `json:"statusCode"`
|
||||
}
|
||||
|
||||
// WaitForPortNotification waits for a port notification to be received. It returns the notification
|
||||
// or an error if the notification is not received before the context is cancelled or it fails
|
||||
// to parse the notification.
|
||||
func (s *Session) WaitForPortNotification(ctx context.Context, port int, notifType PortChangeKind) (*PortNotification, error) {
|
||||
// We use 1-buffered channels and non-blocking sends so that
|
||||
// no goroutine gets stuck.
|
||||
notificationCh := make(chan *PortNotification, 1)
|
||||
errCh := make(chan error, 1)
|
||||
|
||||
h := func(success bool) func(*jsonrpc2.Conn, *jsonrpc2.Request) {
|
||||
return func(conn *jsonrpc2.Conn, req *jsonrpc2.Request) {
|
||||
notification := new(PortNotification)
|
||||
if err := json.Unmarshal(*req.Params, ¬ification); err != nil {
|
||||
select {
|
||||
case errCh <- fmt.Errorf("error unmarshalling notification: %w", err):
|
||||
default:
|
||||
}
|
||||
return
|
||||
}
|
||||
notification.Success = success
|
||||
if notification.Port == port && notification.ChangeKind == notifType {
|
||||
select {
|
||||
case notificationCh <- notification:
|
||||
default:
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
deregisterSuccess := s.registerRequestHandler("serverSharing.sharingSucceeded", h(true))
|
||||
deregisterFailure := s.registerRequestHandler("serverSharing.sharingFailed", h(false))
|
||||
defer deregisterSuccess()
|
||||
defer deregisterFailure()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
case err := <-errCh:
|
||||
return nil, err
|
||||
case notification := <-notificationCh:
|
||||
return notification, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// GetSharedServers returns a description of each container port
|
||||
// shared by a prior call to StartSharing by some client.
|
||||
func (s *Session) GetSharedServers(ctx context.Context) ([]*Port, error) {
|
||||
var response []*Port
|
||||
if err := s.rpc.do(ctx, "serverSharing.getSharedServers", []string{}, &response); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return response, nil
|
||||
}
|
||||
|
||||
// UpdateSharedServerPrivacy controls port permissions and visibility scopes for who can access its URLs
|
||||
// in the browser.
|
||||
func (s *Session) UpdateSharedServerPrivacy(ctx context.Context, port int, visibility string) error {
|
||||
return s.rpc.do(ctx, "serverSharing.updateSharedServerPrivacy", []interface{}{port, visibility}, nil)
|
||||
}
|
||||
|
|
@ -1,87 +0,0 @@
|
|||
package liveshare
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/opentracing/opentracing-go"
|
||||
"github.com/sourcegraph/jsonrpc2"
|
||||
)
|
||||
|
||||
type rpcClient struct {
|
||||
*jsonrpc2.Conn
|
||||
conn io.ReadWriteCloser
|
||||
handlersMu sync.Mutex
|
||||
handlers map[string][]*handlerWrapper
|
||||
}
|
||||
|
||||
func newRPCClient(conn io.ReadWriteCloser) *rpcClient {
|
||||
return &rpcClient{conn: conn, handlers: make(map[string][]*handlerWrapper)}
|
||||
}
|
||||
|
||||
func (r *rpcClient) connect(ctx context.Context) {
|
||||
stream := jsonrpc2.NewBufferedStream(r.conn, jsonrpc2.VSCodeObjectCodec{})
|
||||
r.Conn = jsonrpc2.NewConn(ctx, stream, r)
|
||||
}
|
||||
|
||||
func (r *rpcClient) do(ctx context.Context, method string, args, result interface{}) error {
|
||||
span, ctx := opentracing.StartSpanFromContext(ctx, method)
|
||||
defer span.Finish()
|
||||
|
||||
waiter, err := r.Conn.DispatchCall(ctx, method, args)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error dispatching %q call: %w", method, err)
|
||||
}
|
||||
|
||||
// timeout for waiter in case a connection cannot be made
|
||||
waitCtx, cancel := context.WithTimeout(ctx, 2*time.Minute)
|
||||
defer cancel()
|
||||
|
||||
return waiter.Wait(waitCtx, result)
|
||||
}
|
||||
|
||||
type handler func(conn *jsonrpc2.Conn, req *jsonrpc2.Request)
|
||||
|
||||
type handlerWrapper struct {
|
||||
fn handler
|
||||
}
|
||||
|
||||
func (r *rpcClient) register(requestType string, fn handler) func() {
|
||||
r.handlersMu.Lock()
|
||||
defer r.handlersMu.Unlock()
|
||||
|
||||
h := &handlerWrapper{fn: fn}
|
||||
r.handlers[requestType] = append(r.handlers[requestType], h)
|
||||
|
||||
return func() {
|
||||
r.deregister(requestType, h)
|
||||
}
|
||||
}
|
||||
|
||||
func (r *rpcClient) deregister(requestType string, handler *handlerWrapper) {
|
||||
r.handlersMu.Lock()
|
||||
defer r.handlersMu.Unlock()
|
||||
|
||||
handlers := r.handlers[requestType]
|
||||
for i, h := range handlers {
|
||||
if h == handler {
|
||||
// Swap h with last element and pop.
|
||||
last := len(handlers) - 1
|
||||
handlers[i], handlers[last] = handlers[last], nil
|
||||
r.handlers[requestType] = handlers[:last]
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (r *rpcClient) Handle(ctx context.Context, conn *jsonrpc2.Conn, req *jsonrpc2.Request) {
|
||||
r.handlersMu.Lock()
|
||||
defer r.handlersMu.Unlock()
|
||||
|
||||
for _, handler := range r.handlers[req.Method] {
|
||||
go handler.fn(conn, req)
|
||||
}
|
||||
}
|
||||
|
|
@ -1,133 +0,0 @@
|
|||
package liveshare
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/opentracing/opentracing-go"
|
||||
"golang.org/x/crypto/ssh"
|
||||
"golang.org/x/sync/errgroup"
|
||||
)
|
||||
|
||||
// A ChannelID is an identifier for an exposed port on a remote
|
||||
// container that may be used to open an SSH channel to it.
|
||||
type ChannelID struct {
|
||||
name, condition string
|
||||
}
|
||||
|
||||
// Interface to allow the mocking of the liveshare session
|
||||
type LiveshareSession interface {
|
||||
Close() error
|
||||
GetSharedServers(context.Context) ([]*Port, error)
|
||||
KeepAlive(string)
|
||||
OpenStreamingChannel(context.Context, ChannelID) (ssh.Channel, error)
|
||||
StartSharing(context.Context, string, int) (ChannelID, error)
|
||||
GetKeepAliveReason() string
|
||||
}
|
||||
|
||||
// A Session represents the session between a connected Live Share client and server.
|
||||
type Session struct {
|
||||
ssh *sshSession
|
||||
rpc *rpcClient
|
||||
|
||||
keepAliveReason chan string
|
||||
logger logger
|
||||
}
|
||||
|
||||
// Close should be called by users to clean up RPC and SSH resources whenever the session
|
||||
// is no longer active.
|
||||
func (s *Session) Close() error {
|
||||
// Closing the RPC conn closes the underlying stream (SSH)
|
||||
// So we only need to close once
|
||||
if err := s.rpc.Close(); err != nil {
|
||||
s.ssh.Close() // close SSH and ignore error
|
||||
return fmt.Errorf("error while closing Live Share session: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Fetches the keep alive reason from the channel and returns it.
|
||||
func (s *Session) GetKeepAliveReason() string {
|
||||
return <-s.keepAliveReason
|
||||
}
|
||||
|
||||
// registerRequestHandler registers a handler for the given request type with the RPC
|
||||
// server and returns a callback function to deregister the handler
|
||||
func (s *Session) registerRequestHandler(requestType string, h handler) func() {
|
||||
return s.rpc.register(requestType, h)
|
||||
}
|
||||
|
||||
// KeepAlive accepts a reason that is retained if there is no active reason
|
||||
// to send to the server.
|
||||
func (s *Session) KeepAlive(reason string) {
|
||||
select {
|
||||
case s.keepAliveReason <- reason:
|
||||
default:
|
||||
// there is already an active keep alive reason
|
||||
// so we can ignore this one
|
||||
}
|
||||
}
|
||||
|
||||
// StartSharing tells the Live Share host to start sharing the specified port from the container.
|
||||
// The sessionName describes the purpose of the remote port or service.
|
||||
// It returns an identifier that can be used to open an SSH channel to the remote port.
|
||||
func (s *Session) StartSharing(ctx context.Context, sessionName string, port int) (ChannelID, error) {
|
||||
args := []interface{}{port, sessionName, fmt.Sprintf("http://localhost:%d", port)}
|
||||
g, ctx := errgroup.WithContext(ctx)
|
||||
|
||||
g.Go(func() error {
|
||||
startNotification, err := s.WaitForPortNotification(ctx, port, PortChangeKindStart)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error while waiting for port notification: %w", err)
|
||||
|
||||
}
|
||||
if !startNotification.Success {
|
||||
return fmt.Errorf("error while starting port sharing: %s", startNotification.ErrorDetail)
|
||||
}
|
||||
return nil // success
|
||||
})
|
||||
|
||||
var response Port
|
||||
g.Go(func() error {
|
||||
return s.rpc.do(ctx, "serverSharing.startSharing", args, &response)
|
||||
})
|
||||
|
||||
if err := g.Wait(); err != nil {
|
||||
return ChannelID{}, err
|
||||
}
|
||||
|
||||
return ChannelID{response.StreamName, response.StreamCondition}, nil
|
||||
}
|
||||
|
||||
func (s *Session) OpenStreamingChannel(ctx context.Context, id ChannelID) (ssh.Channel, error) {
|
||||
type getStreamArgs struct {
|
||||
StreamName string `json:"streamName"`
|
||||
Condition string `json:"condition"`
|
||||
}
|
||||
args := getStreamArgs{
|
||||
StreamName: id.name,
|
||||
Condition: id.condition,
|
||||
}
|
||||
var streamID string
|
||||
if err := s.rpc.do(ctx, "streamManager.getStream", args, &streamID); err != nil {
|
||||
return nil, fmt.Errorf("error getting stream id: %w", err)
|
||||
}
|
||||
|
||||
span, ctx := opentracing.StartSpanFromContext(ctx, "Session.OpenChannel+SendRequest")
|
||||
defer span.Finish()
|
||||
_ = ctx // ctx is not currently used
|
||||
|
||||
channel, reqs, err := s.ssh.conn.OpenChannel("session", nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error opening ssh channel for transport: %w", err)
|
||||
}
|
||||
go ssh.DiscardRequests(reqs)
|
||||
|
||||
requestType := fmt.Sprintf("stream-transport-%s", streamID)
|
||||
if _, err = channel.SendRequest(requestType, true, nil); err != nil {
|
||||
return nil, fmt.Errorf("error sending channel request: %w", err)
|
||||
}
|
||||
|
||||
return channel, nil
|
||||
}
|
||||
|
|
@ -1,278 +0,0 @@
|
|||
package liveshare
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
livesharetest "github.com/cli/cli/v2/pkg/liveshare/test"
|
||||
"github.com/sourcegraph/jsonrpc2"
|
||||
)
|
||||
|
||||
func makeMockSession(opts ...livesharetest.ServerOption) (*livesharetest.Server, *Session, error) {
|
||||
joinWorkspace := func(conn *jsonrpc2.Conn, req *jsonrpc2.Request) (interface{}, error) {
|
||||
return joinWorkspaceResult{1}, nil
|
||||
}
|
||||
const sessionToken = "session-token"
|
||||
opts = append(
|
||||
opts,
|
||||
livesharetest.WithPassword(sessionToken),
|
||||
livesharetest.WithService("workspace.joinWorkspace", joinWorkspace),
|
||||
)
|
||||
testServer, err := livesharetest.NewServer(opts...)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("error creating server: %w", err)
|
||||
}
|
||||
|
||||
session, err := Connect(context.Background(), Options{
|
||||
SessionID: "session-id",
|
||||
SessionToken: sessionToken,
|
||||
RelayEndpoint: "sb" + strings.TrimPrefix(testServer.URL(), "https"),
|
||||
RelaySAS: "relay-sas",
|
||||
HostPublicKeys: []string{livesharetest.SSHPublicKey},
|
||||
TLSConfig: &tls.Config{InsecureSkipVerify: true},
|
||||
Logger: newMockLogger(),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("error connecting to Live Share: %w", err)
|
||||
}
|
||||
return testServer, session, nil
|
||||
}
|
||||
|
||||
func TestServerStartSharing(t *testing.T) {
|
||||
serverPort, serverProtocol := 2222, "sshd"
|
||||
sendNotification := make(chan portUpdateNotification)
|
||||
startSharing := func(conn *jsonrpc2.Conn, req *jsonrpc2.Request) (interface{}, error) {
|
||||
var args []interface{}
|
||||
if err := json.Unmarshal(*req.Params, &args); err != nil {
|
||||
return nil, fmt.Errorf("error unmarshalling request: %w", err)
|
||||
}
|
||||
if len(args) < 3 {
|
||||
return nil, errors.New("not enough arguments to start sharing")
|
||||
}
|
||||
port, ok := args[0].(float64)
|
||||
if !ok {
|
||||
return nil, errors.New("port argument is not an int")
|
||||
}
|
||||
if port != float64(serverPort) {
|
||||
return nil, errors.New("port does not match serverPort")
|
||||
}
|
||||
if protocol, ok := args[1].(string); !ok {
|
||||
return nil, errors.New("protocol argument is not a string")
|
||||
} else if protocol != serverProtocol {
|
||||
return nil, errors.New("protocol does not match serverProtocol")
|
||||
}
|
||||
if browseURL, ok := args[2].(string); !ok {
|
||||
return nil, errors.New("browse url is not a string")
|
||||
} else if browseURL != fmt.Sprintf("http://localhost:%d", serverPort) {
|
||||
return nil, errors.New("browseURL does not match expected")
|
||||
}
|
||||
sendNotification <- portUpdateNotification{
|
||||
PortNotification: PortNotification{
|
||||
Port: int(port),
|
||||
ChangeKind: PortChangeKindStart,
|
||||
},
|
||||
conn: conn,
|
||||
}
|
||||
return Port{StreamName: "stream-name", StreamCondition: "stream-condition"}, nil
|
||||
}
|
||||
testServer, session, err := makeMockSession(
|
||||
livesharetest.WithService("serverSharing.startSharing", startSharing),
|
||||
)
|
||||
defer testServer.Close() //nolint:staticcheck // httptest.Server does not return errors on Close()
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("error creating mock session: %v", err)
|
||||
}
|
||||
ctx := context.Background()
|
||||
|
||||
go func() {
|
||||
notif := <-sendNotification
|
||||
_, _ = notif.conn.DispatchCall(context.Background(), "serverSharing.sharingSucceeded", notif)
|
||||
}()
|
||||
|
||||
done := make(chan error)
|
||||
go func() {
|
||||
streamID, err := session.StartSharing(ctx, serverProtocol, serverPort)
|
||||
if err != nil {
|
||||
done <- fmt.Errorf("error sharing server: %w", err)
|
||||
}
|
||||
if streamID.name == "" || streamID.condition == "" {
|
||||
done <- errors.New("stream name or condition is blank")
|
||||
}
|
||||
done <- nil
|
||||
}()
|
||||
|
||||
select {
|
||||
case err := <-testServer.Err():
|
||||
t.Errorf("error from server: %v", err)
|
||||
case err := <-done:
|
||||
if err != nil {
|
||||
t.Errorf("error from client: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestServerGetSharedServers(t *testing.T) {
|
||||
sharedServer := Port{
|
||||
SourcePort: 2222,
|
||||
StreamName: "stream-name",
|
||||
StreamCondition: "stream-condition",
|
||||
}
|
||||
getSharedServers := func(conn *jsonrpc2.Conn, req *jsonrpc2.Request) (interface{}, error) {
|
||||
return []*Port{&sharedServer}, nil
|
||||
}
|
||||
testServer, session, err := makeMockSession(
|
||||
livesharetest.WithService("serverSharing.getSharedServers", getSharedServers),
|
||||
)
|
||||
if err != nil {
|
||||
t.Errorf("error creating mock session: %v", err)
|
||||
}
|
||||
defer testServer.Close()
|
||||
ctx := context.Background()
|
||||
done := make(chan error)
|
||||
go func() {
|
||||
ports, err := session.GetSharedServers(ctx)
|
||||
if err != nil {
|
||||
done <- fmt.Errorf("error getting shared servers: %w", err)
|
||||
}
|
||||
if len(ports) < 1 {
|
||||
done <- errors.New("not enough ports returned")
|
||||
}
|
||||
if ports[0].SourcePort != sharedServer.SourcePort {
|
||||
done <- errors.New("source port does not match")
|
||||
}
|
||||
if ports[0].StreamName != sharedServer.StreamName {
|
||||
done <- errors.New("stream name does not match")
|
||||
}
|
||||
if ports[0].StreamCondition != sharedServer.StreamCondition {
|
||||
done <- errors.New("stream condiion does not match")
|
||||
}
|
||||
done <- nil
|
||||
}()
|
||||
|
||||
select {
|
||||
case err := <-testServer.Err():
|
||||
t.Errorf("error from server: %v", err)
|
||||
case err := <-done:
|
||||
if err != nil {
|
||||
t.Errorf("error from client: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestServerUpdateSharedServerPrivacy(t *testing.T) {
|
||||
updateSharedVisibility := func(conn *jsonrpc2.Conn, rpcReq *jsonrpc2.Request) (interface{}, error) {
|
||||
var req []interface{}
|
||||
if err := json.Unmarshal(*rpcReq.Params, &req); err != nil {
|
||||
return nil, fmt.Errorf("unmarshal req: %w", err)
|
||||
}
|
||||
if len(req) < 2 {
|
||||
return nil, errors.New("request arguments is less than 2")
|
||||
}
|
||||
if port, ok := req[0].(float64); ok {
|
||||
if port != 80.0 {
|
||||
return nil, errors.New("port param is not expected value")
|
||||
}
|
||||
} else {
|
||||
return nil, errors.New("port param is not a float64")
|
||||
}
|
||||
if privacy, ok := req[1].(string); ok {
|
||||
if privacy != "public" {
|
||||
return nil, fmt.Errorf("expected privacy param to be public but got %q", privacy)
|
||||
}
|
||||
} else {
|
||||
return nil, fmt.Errorf("expected privacy param to be a bool but go %T", req[1])
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
testServer, session, err := makeMockSession(
|
||||
livesharetest.WithService("serverSharing.updateSharedServerPrivacy", updateSharedVisibility),
|
||||
)
|
||||
if err != nil {
|
||||
t.Errorf("creating mock session: %v", err)
|
||||
}
|
||||
defer testServer.Close()
|
||||
ctx := context.Background()
|
||||
done := make(chan error)
|
||||
go func() {
|
||||
done <- session.UpdateSharedServerPrivacy(ctx, 80, "public")
|
||||
}()
|
||||
select {
|
||||
case err := <-testServer.Err():
|
||||
t.Errorf("error from server: %v", err)
|
||||
case err := <-done:
|
||||
if err != nil {
|
||||
t.Errorf("error from client: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestInvalidHostKey(t *testing.T) {
|
||||
joinWorkspace := func(conn *jsonrpc2.Conn, req *jsonrpc2.Request) (interface{}, error) {
|
||||
return joinWorkspaceResult{1}, nil
|
||||
}
|
||||
const sessionToken = "session-token"
|
||||
opts := []livesharetest.ServerOption{
|
||||
livesharetest.WithPassword(sessionToken),
|
||||
livesharetest.WithService("workspace.joinWorkspace", joinWorkspace),
|
||||
}
|
||||
testServer, err := livesharetest.NewServer(opts...)
|
||||
if err != nil {
|
||||
t.Errorf("error creating server: %v", err)
|
||||
}
|
||||
_, err = Connect(context.Background(), Options{
|
||||
SessionID: "session-id",
|
||||
SessionToken: sessionToken,
|
||||
RelayEndpoint: "sb" + strings.TrimPrefix(testServer.URL(), "https"),
|
||||
RelaySAS: "relay-sas",
|
||||
HostPublicKeys: []string{},
|
||||
TLSConfig: &tls.Config{InsecureSkipVerify: true},
|
||||
})
|
||||
if err == nil {
|
||||
t.Error("expected invalid host key error, got: nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestKeepAliveNonBlocking(t *testing.T) {
|
||||
session := &Session{keepAliveReason: make(chan string, 1)}
|
||||
for i := 0; i < 2; i++ {
|
||||
session.KeepAlive("io")
|
||||
}
|
||||
|
||||
// if KeepAlive blocks, we'll never reach this and timeout the test
|
||||
// timing out
|
||||
}
|
||||
|
||||
type mockLogger struct {
|
||||
sync.Mutex
|
||||
buf *bytes.Buffer
|
||||
}
|
||||
|
||||
func newMockLogger() *mockLogger {
|
||||
return &mockLogger{buf: new(bytes.Buffer)}
|
||||
}
|
||||
|
||||
func (m *mockLogger) Printf(format string, v ...interface{}) {
|
||||
m.Lock()
|
||||
defer m.Unlock()
|
||||
m.buf.WriteString(fmt.Sprintf(format, v...))
|
||||
}
|
||||
|
||||
func (m *mockLogger) Println(v ...interface{}) {
|
||||
m.Lock()
|
||||
defer m.Unlock()
|
||||
m.buf.WriteString(fmt.Sprintln(v...))
|
||||
}
|
||||
|
||||
func (m *mockLogger) String() string {
|
||||
m.Lock()
|
||||
defer m.Unlock()
|
||||
return m.buf.String()
|
||||
}
|
||||
|
|
@ -1,100 +0,0 @@
|
|||
package liveshare
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
)
|
||||
|
||||
type socket struct {
|
||||
addr string
|
||||
tlsConfig *tls.Config
|
||||
|
||||
conn *websocket.Conn
|
||||
reader io.Reader
|
||||
}
|
||||
|
||||
func newSocket(uri string, tlsConfig *tls.Config) *socket {
|
||||
return &socket{addr: uri, tlsConfig: tlsConfig}
|
||||
}
|
||||
|
||||
func (s *socket) connect(ctx context.Context) error {
|
||||
dialer := websocket.Dialer{
|
||||
Proxy: http.ProxyFromEnvironment,
|
||||
HandshakeTimeout: 45 * time.Second,
|
||||
TLSClientConfig: s.tlsConfig,
|
||||
}
|
||||
ws, _, err := dialer.Dial(s.addr, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
s.conn = ws
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *socket) Read(b []byte) (int, error) {
|
||||
if s.reader == nil {
|
||||
_, reader, err := s.conn.NextReader()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
s.reader = reader
|
||||
}
|
||||
|
||||
bytesRead, err := s.reader.Read(b)
|
||||
if err != nil {
|
||||
s.reader = nil
|
||||
|
||||
if err == io.EOF {
|
||||
err = nil
|
||||
}
|
||||
}
|
||||
|
||||
return bytesRead, err
|
||||
}
|
||||
|
||||
func (s *socket) Write(b []byte) (int, error) {
|
||||
nextWriter, err := s.conn.NextWriter(websocket.BinaryMessage)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
bytesWritten, err := nextWriter.Write(b)
|
||||
nextWriter.Close()
|
||||
|
||||
return bytesWritten, err
|
||||
}
|
||||
|
||||
func (s *socket) Close() error {
|
||||
return s.conn.Close()
|
||||
}
|
||||
|
||||
func (s *socket) LocalAddr() net.Addr {
|
||||
return s.conn.LocalAddr()
|
||||
}
|
||||
|
||||
func (s *socket) RemoteAddr() net.Addr {
|
||||
return s.conn.RemoteAddr()
|
||||
}
|
||||
|
||||
func (s *socket) SetDeadline(t time.Time) error {
|
||||
if err := s.SetReadDeadline(t); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return s.SetWriteDeadline(t)
|
||||
}
|
||||
|
||||
func (s *socket) SetReadDeadline(t time.Time) error {
|
||||
return s.conn.SetReadDeadline(t)
|
||||
}
|
||||
|
||||
func (s *socket) SetWriteDeadline(t time.Time) error {
|
||||
return s.conn.SetWriteDeadline(t)
|
||||
}
|
||||
|
|
@ -1,79 +0,0 @@
|
|||
package liveshare
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
type sshSession struct {
|
||||
*ssh.Session
|
||||
token string
|
||||
hostPublicKeys []string
|
||||
socket net.Conn
|
||||
conn ssh.Conn
|
||||
reader io.Reader
|
||||
writer io.Writer
|
||||
}
|
||||
|
||||
func newSSHSession(token string, hostPublicKeys []string, socket net.Conn) *sshSession {
|
||||
return &sshSession{token: token, hostPublicKeys: hostPublicKeys, socket: socket}
|
||||
}
|
||||
|
||||
func (s *sshSession) connect(ctx context.Context) error {
|
||||
clientConfig := ssh.ClientConfig{
|
||||
User: "",
|
||||
Auth: []ssh.AuthMethod{
|
||||
ssh.Password(s.token),
|
||||
},
|
||||
HostKeyAlgorithms: []string{"rsa-sha2-512", "rsa-sha2-256"},
|
||||
HostKeyCallback: func(hostname string, addr net.Addr, key ssh.PublicKey) error {
|
||||
encodedKey := base64.StdEncoding.EncodeToString(key.Marshal())
|
||||
for _, hpk := range s.hostPublicKeys {
|
||||
if encodedKey == hpk {
|
||||
return nil // we found a match for expected public key, safely return
|
||||
}
|
||||
}
|
||||
return errors.New("invalid host public key")
|
||||
},
|
||||
Timeout: 10 * time.Second,
|
||||
}
|
||||
|
||||
sshClientConn, chans, reqs, err := ssh.NewClientConn(s.socket, "", &clientConfig)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error creating ssh client connection: %w", err)
|
||||
}
|
||||
s.conn = sshClientConn
|
||||
|
||||
sshClient := ssh.NewClient(sshClientConn, chans, reqs)
|
||||
s.Session, err = sshClient.NewSession()
|
||||
if err != nil {
|
||||
return fmt.Errorf("error creating ssh client session: %w", err)
|
||||
}
|
||||
|
||||
s.reader, err = s.Session.StdoutPipe()
|
||||
if err != nil {
|
||||
return fmt.Errorf("error creating ssh session reader: %w", err)
|
||||
}
|
||||
|
||||
s.writer, err = s.Session.StdinPipe()
|
||||
if err != nil {
|
||||
return fmt.Errorf("error creating ssh session writer: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *sshSession) Read(p []byte) (n int, err error) {
|
||||
return s.reader.Read(p)
|
||||
}
|
||||
|
||||
func (s *sshSession) Write(p []byte) (n int, err error) {
|
||||
return s.writer.Write(p)
|
||||
}
|
||||
|
|
@ -1,349 +0,0 @@
|
|||
package livesharetest
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/sourcegraph/jsonrpc2"
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
const sshPrivateKey = `-----BEGIN RSA PRIVATE KEY-----
|
||||
MIICXgIBAAKBgQC6VU6XsMaTot9ogsGcJ+juvJOmDvvCZmgJRTRwKkW0u2BLz4yV
|
||||
rCzQcxaY4kaIuR80Y+1f0BLnZgh4pTREDR0T+p8hUsDSHim1ttKI8rK0hRtJ2qhY
|
||||
lR4qt7P51rPA4KFA9z9gDjTwQLbDq21QMC4+n4d8CL3xRVGtlUAMM3Kl3wIDAQAB
|
||||
AoGBAI8UemkYoSM06gBCh5D1RHQt8eKNltzL7g9QSNfoXeZOC7+q+/TiZPcbqLp0
|
||||
5lyOalu8b8Ym7J0rSE377Ypj13LyHMXS63e4wMiXv3qOl3GDhMLpypnJ8PwqR2b8
|
||||
IijL2jrpQfLu6IYqlteA+7e9aEexJa1RRwxYIyq6pG1IYpbhAkEA9nKgtj3Z6ZDC
|
||||
46IdqYzuUM9ZQdcw4AFr407+lub7tbWe5pYmaq3cT725IwLw081OAmnWJYFDMa/n
|
||||
IPl9YcZSPQJBAMGOMbPs/YPkQAsgNdIUlFtK3o41OrrwJuTRTvv0DsbqDV0LKOiC
|
||||
t8oAQQvjisH6Ew5OOhFyIFXtvZfzQMJppksCQQDWFd+cUICTUEise/Duj9maY3Uz
|
||||
J99ySGnTbZTlu8PfJuXhg3/d3ihrMPG6A1z3cPqaSBxaOj8H07mhQHn1zNU1AkEA
|
||||
hkl+SGPrO793g4CUdq2ahIA8SpO5rIsDoQtq7jlUq0MlhGFCv5Y5pydn+bSjx5MV
|
||||
933kocf5kUSBntPBIWElYwJAZTm5ghu0JtSE6t3km0iuj7NGAQSdb6mD8+O7C3CP
|
||||
FU3vi+4HlBysaT6IZ/HG+/dBsr4gYp4LGuS7DbaLuYw/uw==
|
||||
-----END RSA PRIVATE KEY-----`
|
||||
|
||||
const SSHPublicKey = `AAAAB3NzaC1yc2EAAAADAQABAAAAgQC6VU6XsMaTot9ogsGcJ+juvJOmDvvCZmgJRTRwKkW0u2BLz4yVrCzQcxaY4kaIuR80Y+1f0BLnZgh4pTREDR0T+p8hUsDSHim1ttKI8rK0hRtJ2qhYlR4qt7P51rPA4KFA9z9gDjTwQLbDq21QMC4+n4d8CL3xRVGtlUAMM3Kl3w==`
|
||||
|
||||
// Server represents a LiveShare relay host server.
|
||||
type Server struct {
|
||||
password string
|
||||
services map[string]RPCHandleFunc
|
||||
relaySAS string
|
||||
streams map[string]io.ReadWriter
|
||||
sshConfig *ssh.ServerConfig
|
||||
httptestServer *httptest.Server
|
||||
errCh chan error
|
||||
nonSecure bool
|
||||
}
|
||||
|
||||
// NewServer creates a new Server. ServerOptions can be passed to configure
|
||||
// the SSH password, backing service, secrets and more.
|
||||
func NewServer(opts ...ServerOption) (*Server, error) {
|
||||
server := new(Server)
|
||||
|
||||
for _, o := range opts {
|
||||
if err := o(server); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
server.sshConfig = &ssh.ServerConfig{
|
||||
PasswordCallback: sshPasswordCallback(server.password),
|
||||
}
|
||||
privateKey, err := ssh.ParsePrivateKey([]byte(sshPrivateKey))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error parsing key: %w", err)
|
||||
}
|
||||
server.sshConfig.AddHostKey(privateKey)
|
||||
|
||||
server.errCh = make(chan error, 1)
|
||||
|
||||
if server.nonSecure {
|
||||
server.httptestServer = httptest.NewServer(http.HandlerFunc(makeConnection(server)))
|
||||
} else {
|
||||
server.httptestServer = httptest.NewTLSServer(http.HandlerFunc(makeConnection(server)))
|
||||
}
|
||||
return server, nil
|
||||
}
|
||||
|
||||
// ServerOption is used to configure the Server.
|
||||
type ServerOption func(*Server) error
|
||||
|
||||
// WithPassword configures the Server password for SSH.
|
||||
func WithPassword(password string) ServerOption {
|
||||
return func(s *Server) error {
|
||||
s.password = password
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// WithNonSecure configures the Server as non-secure.
|
||||
func WithNonSecure() ServerOption {
|
||||
return func(s *Server) error {
|
||||
s.nonSecure = true
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// WithService accepts a mock RPC service for the Server to invoke.
|
||||
func WithService(serviceName string, handler RPCHandleFunc) ServerOption {
|
||||
return func(s *Server) error {
|
||||
if s.services == nil {
|
||||
s.services = make(map[string]RPCHandleFunc)
|
||||
}
|
||||
|
||||
s.services[serviceName] = handler
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// WithRelaySAS configures the relay SAS configuration key.
|
||||
func WithRelaySAS(sas string) ServerOption {
|
||||
return func(s *Server) error {
|
||||
s.relaySAS = sas
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// WithStream allows you to specify a mock data stream for the server.
|
||||
func WithStream(name string, stream io.ReadWriter) ServerOption {
|
||||
return func(s *Server) error {
|
||||
if s.streams == nil {
|
||||
s.streams = make(map[string]io.ReadWriter)
|
||||
}
|
||||
s.streams[name] = stream
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func sshPasswordCallback(serverPassword string) func(ssh.ConnMetadata, []byte) (*ssh.Permissions, error) {
|
||||
return func(conn ssh.ConnMetadata, password []byte) (*ssh.Permissions, error) {
|
||||
if string(password) == serverPassword {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, errors.New("password rejected")
|
||||
}
|
||||
}
|
||||
|
||||
// Close closes the underlying httptest Server.
|
||||
func (s *Server) Close() {
|
||||
s.httptestServer.Close()
|
||||
}
|
||||
|
||||
// URL returns the httptest Server url.
|
||||
func (s *Server) URL() string {
|
||||
return s.httptestServer.URL
|
||||
}
|
||||
|
||||
func (s *Server) Err() <-chan error {
|
||||
return s.errCh
|
||||
}
|
||||
|
||||
var upgrader = websocket.Upgrader{}
|
||||
|
||||
func makeConnection(server *Server) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, req *http.Request) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
if server.relaySAS != "" {
|
||||
// validate the sas key
|
||||
sasParam := req.URL.Query().Get("sb-hc-token")
|
||||
if sasParam != server.relaySAS {
|
||||
sendError(server.errCh, errors.New("error validating sas"))
|
||||
return
|
||||
}
|
||||
}
|
||||
c, err := upgrader.Upgrade(w, req, nil)
|
||||
if err != nil {
|
||||
sendError(server.errCh, fmt.Errorf("error upgrading connection: %w", err))
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
if err := c.Close(); err != nil {
|
||||
sendError(server.errCh, err)
|
||||
}
|
||||
}()
|
||||
|
||||
socketConn := newSocketConn(c)
|
||||
_, chans, reqs, err := ssh.NewServerConn(socketConn, server.sshConfig)
|
||||
if err != nil {
|
||||
sendError(server.errCh, fmt.Errorf("error creating new ssh conn: %w", err))
|
||||
return
|
||||
}
|
||||
go ssh.DiscardRequests(reqs)
|
||||
|
||||
if err := handleChannels(ctx, server, chans); err != nil {
|
||||
sendError(server.errCh, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// sendError does a non-blocking send of the error to the err channel.
|
||||
func sendError(errc chan<- error, err error) {
|
||||
select {
|
||||
case errc <- err:
|
||||
default:
|
||||
// channel is blocked with a previous error, so we ignore
|
||||
// this current error
|
||||
}
|
||||
}
|
||||
|
||||
// awaitError waits for the context to finish and returns its error (if any).
|
||||
// It also waits for an err to come through the err channel.
|
||||
func awaitError(ctx context.Context, errc <-chan error) error {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case err := <-errc:
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// handleChannels services the sshChannels channel. For each SSH channel received
|
||||
// it creates a go routine to service the channel's requests. It returns on the first
|
||||
// error encountered.
|
||||
func handleChannels(ctx context.Context, server *Server, sshChannels <-chan ssh.NewChannel) error {
|
||||
errc := make(chan error, 1)
|
||||
go func() {
|
||||
for sshCh := range sshChannels {
|
||||
ch, reqs, err := sshCh.Accept()
|
||||
if err != nil {
|
||||
sendError(errc, fmt.Errorf("failed to accept channel: %w", err))
|
||||
return
|
||||
}
|
||||
|
||||
go func() {
|
||||
if err := handleRequests(ctx, server, ch, reqs); err != nil {
|
||||
sendError(errc, fmt.Errorf("failed to handle requests: %w", err))
|
||||
}
|
||||
}()
|
||||
|
||||
handleChannel(server, ch)
|
||||
}
|
||||
}()
|
||||
return awaitError(ctx, errc)
|
||||
}
|
||||
|
||||
// handleRequests services the SSH channel requests channel. It replies to requests and
|
||||
// when stream transport requests are encountered, creates a go routine to create a
|
||||
// bi-directional data stream between the channel and server stream. It returns on the first error
|
||||
// encountered.
|
||||
func handleRequests(ctx context.Context, server *Server, channel ssh.Channel, reqs <-chan *ssh.Request) error {
|
||||
errc := make(chan error, 1)
|
||||
go func() {
|
||||
for req := range reqs {
|
||||
r := req
|
||||
if r.WantReply {
|
||||
if err := r.Reply(true, nil); err != nil {
|
||||
sendError(errc, fmt.Errorf("error replying to channel request: %w", err))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if strings.HasPrefix(r.Type, "stream-transport") {
|
||||
go func() {
|
||||
if err := forwardStream(ctx, server, r.Type, channel); err != nil {
|
||||
sendError(errc, fmt.Errorf("failed to forward stream: %w", err))
|
||||
}
|
||||
}()
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
return awaitError(ctx, errc)
|
||||
}
|
||||
|
||||
// concurrentStream is a concurrency safe io.ReadWriter.
|
||||
type concurrentStream struct {
|
||||
sync.RWMutex
|
||||
stream io.ReadWriter
|
||||
}
|
||||
|
||||
func newConcurrentStream(rw io.ReadWriter) *concurrentStream {
|
||||
return &concurrentStream{stream: rw}
|
||||
}
|
||||
|
||||
func (cs *concurrentStream) Read(b []byte) (int, error) {
|
||||
cs.RLock()
|
||||
defer cs.RUnlock()
|
||||
return cs.stream.Read(b)
|
||||
}
|
||||
|
||||
func (cs *concurrentStream) Write(b []byte) (int, error) {
|
||||
cs.Lock()
|
||||
defer cs.Unlock()
|
||||
return cs.stream.Write(b)
|
||||
}
|
||||
|
||||
// forwardStream does a bi-directional copy of the stream <-> with the SSH channel. The io.Copy
|
||||
// runs until an error is encountered.
|
||||
func forwardStream(ctx context.Context, server *Server, streamName string, channel ssh.Channel) (err error) {
|
||||
simpleStreamName := strings.TrimPrefix(streamName, "stream-transport-")
|
||||
stream, found := server.streams[simpleStreamName]
|
||||
if !found {
|
||||
return fmt.Errorf("stream '%s' not found", simpleStreamName)
|
||||
}
|
||||
defer func() {
|
||||
if closeErr := channel.Close(); err == nil && closeErr != io.EOF {
|
||||
err = closeErr
|
||||
}
|
||||
}()
|
||||
|
||||
errc := make(chan error, 2)
|
||||
copy := func(dst io.Writer, src io.Reader) {
|
||||
if _, err := io.Copy(dst, src); err != nil {
|
||||
errc <- err
|
||||
}
|
||||
}
|
||||
|
||||
csStream := newConcurrentStream(stream)
|
||||
go copy(csStream, channel)
|
||||
go copy(channel, csStream)
|
||||
|
||||
return awaitError(ctx, errc)
|
||||
}
|
||||
|
||||
func handleChannel(server *Server, channel ssh.Channel) {
|
||||
stream := jsonrpc2.NewBufferedStream(channel, jsonrpc2.VSCodeObjectCodec{})
|
||||
jsonrpc2.NewConn(context.Background(), stream, newRPCHandler(server))
|
||||
}
|
||||
|
||||
type RPCHandleFunc func(conn *jsonrpc2.Conn, req *jsonrpc2.Request) (interface{}, error)
|
||||
|
||||
type rpcHandler struct {
|
||||
server *Server
|
||||
}
|
||||
|
||||
func newRPCHandler(server *Server) *rpcHandler {
|
||||
return &rpcHandler{server}
|
||||
}
|
||||
|
||||
// Handle satisfies the jsonrpc2 pkg handler interface. It tries to find a mocked
|
||||
// RPC service method and if found, it invokes the handler and replies to the request.
|
||||
func (r *rpcHandler) Handle(ctx context.Context, conn *jsonrpc2.Conn, req *jsonrpc2.Request) {
|
||||
handler, found := r.server.services[req.Method]
|
||||
if !found {
|
||||
sendError(r.server.errCh, fmt.Errorf("RPC Method: '%s' not serviced", req.Method))
|
||||
return
|
||||
}
|
||||
|
||||
result, err := handler(conn, req)
|
||||
if err != nil {
|
||||
sendError(r.server.errCh, fmt.Errorf("error handling: '%s': %w", req.Method, err))
|
||||
return
|
||||
}
|
||||
|
||||
if err := conn.Reply(ctx, req.ID, result); err != nil {
|
||||
sendError(r.server.errCh, fmt.Errorf("error replying: %w", err))
|
||||
}
|
||||
}
|
||||
|
|
@ -1,77 +0,0 @@
|
|||
package livesharetest
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
)
|
||||
|
||||
type socketConn struct {
|
||||
*websocket.Conn
|
||||
|
||||
reader io.Reader
|
||||
writeMutex sync.Mutex
|
||||
readMutex sync.Mutex
|
||||
}
|
||||
|
||||
func newSocketConn(conn *websocket.Conn) *socketConn {
|
||||
return &socketConn{Conn: conn}
|
||||
}
|
||||
|
||||
func (s *socketConn) Read(b []byte) (int, error) {
|
||||
s.readMutex.Lock()
|
||||
defer s.readMutex.Unlock()
|
||||
|
||||
if s.reader == nil {
|
||||
msgType, r, err := s.Conn.NextReader()
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("error getting next reader: %w", err)
|
||||
}
|
||||
if msgType != websocket.BinaryMessage {
|
||||
return 0, fmt.Errorf("invalid message type")
|
||||
}
|
||||
s.reader = r
|
||||
}
|
||||
|
||||
bytesRead, err := s.reader.Read(b)
|
||||
if err != nil {
|
||||
s.reader = nil
|
||||
|
||||
if err == io.EOF {
|
||||
err = nil
|
||||
}
|
||||
}
|
||||
|
||||
return bytesRead, err
|
||||
}
|
||||
|
||||
func (s *socketConn) Write(b []byte) (int, error) {
|
||||
s.writeMutex.Lock()
|
||||
defer s.writeMutex.Unlock()
|
||||
|
||||
w, err := s.Conn.NextWriter(websocket.BinaryMessage)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("error getting next writer: %w", err)
|
||||
}
|
||||
|
||||
n, err := w.Write(b)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("error writing: %w", err)
|
||||
}
|
||||
|
||||
if err := w.Close(); err != nil {
|
||||
return 0, fmt.Errorf("error closing writer: %w", err)
|
||||
}
|
||||
|
||||
return n, nil
|
||||
}
|
||||
|
||||
func (s *socketConn) SetDeadline(deadline time.Time) error {
|
||||
if err := s.Conn.SetReadDeadline(deadline); err != nil {
|
||||
return err
|
||||
}
|
||||
return s.Conn.SetWriteDeadline(deadline)
|
||||
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue