Migrate all Codespaces operations from Live Share to Dev Tunnels (#8149)

* Migrate all Codespaces operations from Live Share to Dev Tunnels

* Remove Live Share references

* Fix linting errors

* Update comments, remove deps, add uint16 bound checks

* Fix tests and move keep-alive logic to forwarder

* Address comments

* Updated mock port forwarder

* Fix CodeQL error

* Update comment

* Update func name

* Add missing connection close

* Fix linting error

* https -> http

* Update defer

* Fix tests
This commit is contained in:
David Gardiner 2023-10-12 15:16:36 -07:00 committed by GitHub
parent 7d6fba0d7d
commit 64f4660ec7
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
31 changed files with 491 additions and 2211 deletions

1
.github/CODEOWNERS vendored
View file

@ -1,5 +1,4 @@
* @cli/code-reviewers
pkg/cmd/codespace/ @cli/codespaces
pkg/liveshare/ @cli/codespaces
internal/codespaces/ @cli/codespaces

3
go.mod
View file

@ -27,12 +27,11 @@ require (
github.com/mattn/go-colorable v0.1.13
github.com/mattn/go-isatty v0.0.19
github.com/mgutz/ansi v0.0.0-20200706080929-d51e80ef957d
github.com/microsoft/dev-tunnels v0.0.21
github.com/microsoft/dev-tunnels v0.0.25
github.com/muhammadmuzzammil1998/jsonc v0.0.0-20201229145248-615b0916ca38
github.com/opentracing/opentracing-go v1.1.0
github.com/rivo/tview v0.0.0-20221029100920-c4a7e501810d
github.com/shurcooL/githubv4 v0.0.0-20230704064427-599ae7bbf278
github.com/sourcegraph/jsonrpc2 v0.1.0
github.com/spf13/cobra v1.6.1
github.com/spf13/pflag v1.0.5
github.com/stretchr/testify v1.8.4

7
go.sum
View file

@ -67,7 +67,6 @@ github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 h1:El6M4kTTCOh6aBiKaU
github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510/go.mod h1:pupxD2MaaD3pAXIBCelhxNneeOaAeabZDe5s4K6zSpQ=
github.com/gorilla/css v1.0.0 h1:BQqNyPTi50JCFMTw/b67hByjMVXZRwGha6wxVGkeihY=
github.com/gorilla/css v1.0.0/go.mod h1:Dn721qIggHpt4+EFCcTLTU/vk5ySda2ReITrtgBl60c=
github.com/gorilla/websocket v1.4.1/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
github.com/gorilla/websocket v1.4.2 h1:+/TMaTYc4QFitKJxsQ7Yye35DkWvkdLcvGKqM+x0Ufc=
github.com/gorilla/websocket v1.4.2/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
github.com/h2non/parth v0.0.0-20190131123155-b4df798d6542 h1:2VTzZjLZBgl62/EtslCrtky5vbi9dd7HrQPQIx6wqiw=
@ -118,8 +117,8 @@ github.com/mgutz/ansi v0.0.0-20200706080929-d51e80ef957d/go.mod h1:01TrycV0kFyex
github.com/microcosm-cc/bluemonday v1.0.21/go.mod h1:ytNkv4RrDrLJ2pqlsSI46O6IVXmZOBBD4SaJyDwwTkM=
github.com/microcosm-cc/bluemonday v1.0.26 h1:xbqSvqzQMeEHCqMi64VAs4d8uy6Mequs3rQ0k/Khz58=
github.com/microcosm-cc/bluemonday v1.0.26/go.mod h1:JyzOCs9gkyQyjs+6h10UEVSe02CGwkhd72Xdqh78TWs=
github.com/microsoft/dev-tunnels v0.0.21 h1:p4QP7C5ZOyP9bGbmanRjPxUMckfi9Z41Gl+KY4C11w0=
github.com/microsoft/dev-tunnels v0.0.21/go.mod h1:frU++12T/oqxckXkDpTuYa427ncguEOodSPZcGCCrzQ=
github.com/microsoft/dev-tunnels v0.0.25 h1:UlMKUI+2O8cSu4RlB52ioSyn1LthYSVkJA+CSTsdKoA=
github.com/microsoft/dev-tunnels v0.0.25/go.mod h1:frU++12T/oqxckXkDpTuYa427ncguEOodSPZcGCCrzQ=
github.com/muesli/reflow v0.2.1-0.20210115123740-9e1d0d53df68/go.mod h1:Xk+z4oIWdQqJzsxyjgl3P22oYZnHdZ8FFTHAQQt5BMQ=
github.com/muesli/reflow v0.3.0 h1:IFsN6K9NfGtjeggFP+68I4chLZV2yIKsXJFNZ+eWh6s=
github.com/muesli/reflow v0.3.0/go.mod h1:pbwTDkVPibjO2kyvBQRBxTWEEGDGq0FlB1BIKtnHY/8=
@ -150,8 +149,6 @@ github.com/shurcooL/githubv4 v0.0.0-20230704064427-599ae7bbf278 h1:kdEGVAV4sO46D
github.com/shurcooL/githubv4 v0.0.0-20230704064427-599ae7bbf278/go.mod h1:zqMwyHmnN/eDOZOdiTohqIUKUrTFX62PNlu7IJdu0q8=
github.com/shurcooL/graphql v0.0.0-20230722043721-ed46e5a46466 h1:17JxqqJY66GmZVHkmAsGEkcIu0oCe3AM420QDgGwZx0=
github.com/shurcooL/graphql v0.0.0-20230722043721-ed46e5a46466/go.mod h1:9dIRpgIY7hVhoqfe0/FcYp0bpInZaT7dc3BYOprrIUE=
github.com/sourcegraph/jsonrpc2 v0.1.0 h1:ohJHjZ+PcaLxDUjqk2NC3tIGsVa5bXThe1ZheSXOjuk=
github.com/sourcegraph/jsonrpc2 v0.1.0/go.mod h1:ZafdZgk/axhT1cvZAPOhw+95nz2I/Ra5qMlU4gTRwIo=
github.com/spf13/cobra v1.6.1 h1:o94oiPyS4KD1mPy2fmcYYHHfCxLqYjJOhGsCHFZtEzA=
github.com/spf13/cobra v1.6.1/go.mod h1:IOw/AERYS7UzyrGinqmz6HLUo219MORXGxhbaJUqzrY=
github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA=

View file

@ -247,11 +247,6 @@ const (
)
type CodespaceConnection struct {
SessionID string `json:"sessionId"`
SessionToken string `json:"sessionToken"`
RelayEndpoint string `json:"relayEndpoint"`
RelaySAS string `json:"relaySas"`
HostPublicKeys []string `json:"hostPublicKeys"`
TunnelProperties TunnelProperties `json:"tunnelProperties"`
}

View file

@ -11,30 +11,20 @@ import (
"github.com/cenkalti/backoff/v4"
"github.com/cli/cli/v2/internal/codespaces/api"
"github.com/cli/cli/v2/internal/codespaces/connection"
"github.com/cli/cli/v2/pkg/liveshare"
)
func connectionReady(codespace *api.Codespace, usingDevTunnels bool) bool {
func connectionReady(codespace *api.Codespace) bool {
// If the codespace is not available, it is not ready
if codespace.State != api.CodespaceStateAvailable {
return false
}
// If using Dev Tunnels, we need to check that we have all of the required tunnel properties
if usingDevTunnels {
return codespace.Connection.TunnelProperties.ConnectAccessToken != "" &&
codespace.Connection.TunnelProperties.ManagePortsAccessToken != "" &&
codespace.Connection.TunnelProperties.ServiceUri != "" &&
codespace.Connection.TunnelProperties.TunnelId != "" &&
codespace.Connection.TunnelProperties.ClusterId != "" &&
codespace.Connection.TunnelProperties.Domain != ""
}
// If not using Dev Tunnels, we need to check that we have all of the required Live Share properties
return codespace.Connection.SessionID != "" &&
codespace.Connection.SessionToken != "" &&
codespace.Connection.RelayEndpoint != "" &&
codespace.Connection.RelaySAS != ""
return codespace.Connection.TunnelProperties.ConnectAccessToken != "" &&
codespace.Connection.TunnelProperties.ManagePortsAccessToken != "" &&
codespace.Connection.TunnelProperties.ServiceUri != "" &&
codespace.Connection.TunnelProperties.TunnelId != "" &&
codespace.Connection.TunnelProperties.ClusterId != "" &&
codespace.Connection.TunnelProperties.Domain != ""
}
type apiClient interface {
@ -48,11 +38,6 @@ type progressIndicator interface {
StopProgressIndicator()
}
type logger interface {
Println(v ...interface{})
Printf(f string, v ...interface{})
}
type TimeoutError struct {
message string
}
@ -64,7 +49,7 @@ func (e *TimeoutError) Error() string {
// GetCodespaceConnection waits until a codespace is able
// to be connected to and initializes a connection to it.
func GetCodespaceConnection(ctx context.Context, progress progressIndicator, apiClient apiClient, codespace *api.Codespace) (*connection.CodespaceConnection, error) {
codespace, err := waitUntilCodespaceConnectionReady(ctx, progress, apiClient, codespace, true)
codespace, err := waitUntilCodespaceConnectionReady(ctx, progress, apiClient, codespace)
if err != nil {
return nil, err
}
@ -80,29 +65,8 @@ func GetCodespaceConnection(ctx context.Context, progress progressIndicator, api
return connection.NewCodespaceConnection(ctx, codespace, httpClient)
}
// ConnectToLiveshare waits until a codespace is able to be
// connected to and connects to it using a Live Share session.
func ConnectToLiveshare(ctx context.Context, progress progressIndicator, sessionLogger logger, apiClient apiClient, codespace *api.Codespace) (*liveshare.Session, error) {
codespace, err := waitUntilCodespaceConnectionReady(ctx, progress, apiClient, codespace, false)
if err != nil {
return nil, err
}
progress.StartProgressIndicatorWithLabel("Connecting to codespace")
defer progress.StopProgressIndicator()
return liveshare.Connect(ctx, liveshare.Options{
SessionID: codespace.Connection.SessionID,
SessionToken: codespace.Connection.SessionToken,
RelaySAS: codespace.Connection.RelaySAS,
RelayEndpoint: codespace.Connection.RelayEndpoint,
HostPublicKeys: codespace.Connection.HostPublicKeys,
Logger: sessionLogger,
})
}
// waitUntilCodespaceConnectionReady waits for a Codespace to be running and is able to be connected to.
func waitUntilCodespaceConnectionReady(ctx context.Context, progress progressIndicator, apiClient apiClient, codespace *api.Codespace, usingDevTunnels bool) (*api.Codespace, error) {
func waitUntilCodespaceConnectionReady(ctx context.Context, progress progressIndicator, apiClient apiClient, codespace *api.Codespace) (*api.Codespace, error) {
if codespace.State != api.CodespaceStateAvailable {
progress.StartProgressIndicatorWithLabel("Starting codespace")
defer progress.StopProgressIndicator()
@ -111,7 +75,7 @@ func waitUntilCodespaceConnectionReady(ctx context.Context, progress progressInd
}
}
if !connectionReady(codespace, usingDevTunnels) {
if !connectionReady(codespace) {
expBackoff := backoff.NewExponentialBackOff()
expBackoff.Multiplier = 1.1
expBackoff.MaxInterval = 10 * time.Second
@ -124,7 +88,7 @@ func waitUntilCodespaceConnectionReady(ctx context.Context, progress progressInd
return backoff.Permanent(fmt.Errorf("error getting codespace: %w", err))
}
if connectionReady(codespace, usingDevTunnels) {
if connectionReady(codespace) {
return nil
}

View file

@ -3,6 +3,7 @@ package portforwarder
import (
"context"
"fmt"
"io"
"net"
"strings"
@ -22,101 +23,75 @@ const (
PublicPortVisibility = "public"
)
type PortForwarder struct {
connection connection.CodespaceConnection
const (
trafficTypeInput = "input"
trafficTypeOutput = "output"
)
type ForwardPortOpts struct {
Port int
Internal bool
KeepAlive bool
Visibility string
}
type CodespacesPortForwarder struct {
connection connection.CodespaceConnection
keepAliveReason chan string
}
type PortForwarder interface {
ForwardPortToListener(ctx context.Context, opts ForwardPortOpts, listener *net.TCPListener) error
ForwardPort(ctx context.Context, opts ForwardPortOpts) error
ConnectToForwardedPort(ctx context.Context, conn io.ReadWriteCloser, opts ForwardPortOpts) error
ListPorts(ctx context.Context) ([]*tunnels.TunnelPort, error)
UpdatePortVisibility(ctx context.Context, remotePort int, visibility string) error
KeepAlive(reason string)
GetKeepAliveReason() string
CloseSSHConnection()
}
// NewPortForwarder returns a new PortForwarder for the specified codespace.
func NewPortForwarder(ctx context.Context, codespaceConnection *connection.CodespaceConnection) (fwd *PortForwarder, err error) {
return &PortForwarder{
connection: *codespaceConnection,
func NewPortForwarder(ctx context.Context, codespaceConnection *connection.CodespaceConnection) (fwd PortForwarder, err error) {
return &CodespacesPortForwarder{
connection: *codespaceConnection,
keepAliveReason: make(chan string, 1),
}, nil
}
// ForwardAndConnectToPort forwards a port and connects to it via a local TCP port.
func (fwd *PortForwarder) ForwardAndConnectToPort(ctx context.Context, remotePort uint16, listen *net.TCPListener, keepAlive bool, internal bool) error {
return fwd.ForwardPort(ctx, remotePort, listen, keepAlive, true, internal, "")
}
// ForwardPort forwards a port and optionally connects to it via a local TCP port.
func (fwd *PortForwarder) ForwardPort(ctx context.Context, remotePort uint16, listen *net.TCPListener, keepAlive bool, connect bool, internal bool, visibility string) error {
tunnelPort := tunnels.NewTunnelPort(remotePort, "", "", tunnels.TunnelProtocolHttp)
// If no visibility is provided, Dev Tunnels will use the default (private)
if visibility != "" {
// Check if the requested visibility is allowed
allowed := false
for _, allowedVisibility := range fwd.connection.AllowedPortPrivacySettings {
if allowedVisibility == visibility {
allowed = true
break
}
}
// If the requested visibility is not allowed, return an error
if !allowed {
return fmt.Errorf("visibility %s is not allowed", visibility)
}
accessControlEntries := visibilityToAccessControlEntries(visibility)
if len(accessControlEntries) > 0 {
tunnelPort.AccessControl = &tunnels.TunnelAccessControl{
Entries: accessControlEntries,
}
}
// ForwardPortToListener forwards the specified port to the given TCP listener.
func (fwd *CodespacesPortForwarder) ForwardPortToListener(ctx context.Context, opts ForwardPortOpts, listener *net.TCPListener) error {
err := fwd.ForwardPort(ctx, opts)
if err != nil {
return fmt.Errorf("error forwarding port: %w", err)
}
// Tag the port as internal or user forwarded so we know if it needs to be shown in the UI
if internal {
tunnelPort.Tags = []string{InternalPortTag}
} else {
tunnelPort.Tags = []string{UserForwardedPortTag}
}
// Create the tunnel port
_, err := fwd.connection.TunnelManager.CreateTunnelPort(ctx, fwd.connection.Tunnel, tunnelPort, fwd.connection.Options)
if err != nil && !strings.Contains(err.Error(), "409") {
return fmt.Errorf("create tunnel port failed: %v", err)
}
// Close the SSH connection when we're done
defer fwd.CloseSSHConnection()
done := make(chan error)
go func() {
// Connect to the tunnel
err = fwd.connection.TunnelClient.Connect(ctx, "")
// Convert the port number to a uint16
port, err := convertIntToUint16(opts.Port)
if err != nil {
done <- fmt.Errorf("connect failed: %v", err)
return
}
// Inform the host that we've forwarded the port locally
err = fwd.connection.TunnelClient.RefreshPorts(ctx)
if err != nil {
done <- fmt.Errorf("refresh ports failed: %v", err)
return
}
// If we don't want to connect to the port, exit early
if !connect {
done <- nil
done <- fmt.Errorf("error converting port: %w", err)
return
}
// Ensure the port is forwarded before connecting
err = fwd.connection.TunnelClient.WaitForForwardedPort(ctx, remotePort)
err = fwd.connection.TunnelClient.WaitForForwardedPort(ctx, port)
if err != nil {
done <- fmt.Errorf("wait for forwarded port failed: %v", err)
return
}
// Connect to the forwarded port via a local TCP port
err = fwd.connection.TunnelClient.ConnectToForwardedPort(ctx, listen, remotePort)
// Connect to the forwarded port
err = fwd.connectListenerToForwardedPort(ctx, opts, listener)
if err != nil {
done <- fmt.Errorf("connect to forwarded port failed: %v", err)
return
}
done <- nil
}()
select {
case err := <-done:
if err != nil {
@ -128,8 +103,131 @@ func (fwd *PortForwarder) ForwardPort(ctx context.Context, remotePort uint16, li
}
}
// ForwardPort informs the host that we would like to forward the given port.
func (fwd *CodespacesPortForwarder) ForwardPort(ctx context.Context, opts ForwardPortOpts) error {
// Convert the port number to a uint16
port, err := convertIntToUint16(opts.Port)
if err != nil {
return fmt.Errorf("error converting port: %w", err)
}
tunnelPort := tunnels.NewTunnelPort(port, "", "", tunnels.TunnelProtocolHttp)
// If no visibility is provided, Dev Tunnels will use the default (private)
if opts.Visibility != "" {
// Check if the requested visibility is allowed
allowed := false
for _, allowedVisibility := range fwd.connection.AllowedPortPrivacySettings {
if allowedVisibility == opts.Visibility {
allowed = true
break
}
}
// If the requested visibility is not allowed, return an error
if !allowed {
return fmt.Errorf("visibility %s is not allowed", opts.Visibility)
}
accessControlEntries := visibilityToAccessControlEntries(opts.Visibility)
if len(accessControlEntries) > 0 {
tunnelPort.AccessControl = &tunnels.TunnelAccessControl{
Entries: accessControlEntries,
}
}
}
// Tag the port as internal or user forwarded so we know if it needs to be shown in the UI
if opts.Internal {
tunnelPort.Tags = []string{InternalPortTag}
} else {
tunnelPort.Tags = []string{UserForwardedPortTag}
}
// Create the tunnel port
_, err = fwd.connection.TunnelManager.CreateTunnelPort(ctx, fwd.connection.Tunnel, tunnelPort, fwd.connection.Options)
if err != nil && !strings.Contains(err.Error(), "409") {
return fmt.Errorf("create tunnel port failed: %v", err)
}
// Connect to the tunnel
err = fwd.connection.TunnelClient.Connect(ctx, "")
if err != nil {
return fmt.Errorf("connect failed: %v", err)
}
// Inform the host that we've forwarded the port locally
err = fwd.connection.TunnelClient.RefreshPorts(ctx)
if err != nil {
fwd.CloseSSHConnection()
return fmt.Errorf("refresh ports failed: %v", err)
}
return nil
}
// connectListenerToForwardedPort connects to the forwarded port via a local TCP port.
func (fwd *CodespacesPortForwarder) connectListenerToForwardedPort(ctx context.Context, opts ForwardPortOpts, listener *net.TCPListener) (err error) {
errc := make(chan error, 1)
sendError := func(err error) {
// Use non-blocking send, to avoid goroutines getting
// stuck in case of concurrent or sequential errors.
select {
case errc <- err:
default:
}
}
go func() {
for {
conn, err := listener.AcceptTCP()
if err != nil {
sendError(err)
return
}
// Connect to the forwarded port in a goroutine so we can accept new connections
go func() {
if err := fwd.ConnectToForwardedPort(ctx, conn, opts); err != nil {
sendError(err)
}
}()
}
}()
// Wait for an error or for the context to be cancelled
select {
case err := <-errc:
return err
case <-ctx.Done():
return ctx.Err() // canceled
}
}
// ConnectToForwardedPort connects to the forwarded port via a given ReadWriteCloser.
// Optionally, it detects traffic over the connection and sends activity signals to the server to keep the codespace from shutting down.
func (fwd *CodespacesPortForwarder) ConnectToForwardedPort(ctx context.Context, conn io.ReadWriteCloser, opts ForwardPortOpts) error {
// Create a traffic monitor to keep the session alive
if opts.KeepAlive {
conn = newTrafficMonitor(conn, fwd)
}
// Convert the port number to a uint16
port, err := convertIntToUint16(opts.Port)
if err != nil {
return fmt.Errorf("error converting port: %w", err)
}
// Connect to the forwarded port
err = fwd.connection.TunnelClient.ConnectToForwardedPort(ctx, conn, port)
if err != nil {
return fmt.Errorf("error connecting to forwarded port: %w", err)
}
return nil
}
// ListPorts fetches the list of ports that are currently forwarded.
func (fwd *PortForwarder) ListPorts(ctx context.Context) (ports []*tunnels.TunnelPort, err error) {
func (fwd *CodespacesPortForwarder) ListPorts(ctx context.Context) (ports []*tunnels.TunnelPort, err error) {
ports, err = fwd.connection.TunnelManager.ListTunnelPorts(ctx, fwd.connection.Tunnel, fwd.connection.Options)
if err != nil {
return nil, fmt.Errorf("error listing ports: %w", err)
@ -139,7 +237,7 @@ func (fwd *PortForwarder) ListPorts(ctx context.Context) (ports []*tunnels.Tunne
}
// UpdatePortVisibility changes the visibility (private, org, public) of the specified port.
func (fwd *PortForwarder) UpdatePortVisibility(ctx context.Context, remotePort int, visibility string) error {
func (fwd *CodespacesPortForwarder) UpdatePortVisibility(ctx context.Context, remotePort int, visibility string) error {
tunnelPort, err := fwd.connection.TunnelManager.GetTunnelPort(ctx, fwd.connection.Tunnel, remotePort, fwd.connection.Options)
if err != nil {
return fmt.Errorf("error getting tunnel port: %w", err)
@ -165,6 +263,9 @@ func (fwd *PortForwarder) UpdatePortVisibility(ctx context.Context, remotePort i
return
}
// Close the SSH connection when we're done
defer fwd.CloseSSHConnection()
// Inform the host that we've deleted the port
err = fwd.connection.TunnelClient.RefreshPorts(ctx)
if err != nil {
@ -172,6 +273,13 @@ func (fwd *PortForwarder) UpdatePortVisibility(ctx context.Context, remotePort i
return
}
// Re-forward the port with the updated visibility
err = fwd.ForwardPort(ctx, ForwardPortOpts{Port: remotePort, Visibility: visibility})
if err != nil {
done <- fmt.Errorf("error forwarding port: %w", err)
return
}
done <- nil
}()
@ -179,13 +287,10 @@ func (fwd *PortForwarder) UpdatePortVisibility(ctx context.Context, remotePort i
select {
case err := <-done:
if err != nil {
return fmt.Errorf("error connecting to tunnel: %w", err)
}
// If we fail to re-forward the port, we need to forward again with the original visibility so the port is still accessible
_ = fwd.ForwardPort(ctx, ForwardPortOpts{Port: remotePort, Visibility: AccessControlEntriesToVisibility(tunnelPort.AccessControl.Entries)})
// Re-forward the port with the updated visibility
err = fwd.ForwardPort(ctx, uint16(remotePort), nil, false, false, false, visibility)
if err != nil {
return fmt.Errorf("error forwarding port: %w", err)
return fmt.Errorf("error connecting to tunnel: %w", err)
}
return nil
@ -194,6 +299,27 @@ func (fwd *PortForwarder) UpdatePortVisibility(ctx context.Context, remotePort i
}
}
// KeepAlive accepts a reason that is retained if there is no active reason
// to send to the server.
func (fwd *CodespacesPortForwarder) KeepAlive(reason string) {
select {
case fwd.keepAliveReason <- reason:
default:
// there is already an active keep alive reason
// so we can ignore this one
}
}
// GetKeepAliveReason fetches the keep alive reason from the channel and returns it.
func (fwd *CodespacesPortForwarder) GetKeepAliveReason() string {
return <-fwd.keepAliveReason
}
// Close closes the port forwarder's tunnel client connection.
func (fwd *CodespacesPortForwarder) CloseSSHConnection() {
_ = fwd.connection.TunnelClient.Close()
}
// AccessControlEntriesToVisibility converts the access control entries used by Dev Tunnels to a friendly visibility value.
func AccessControlEntriesToVisibility(accessControlEntries []tunnels.TunnelAccessControlEntry) string {
for _, entry := range accessControlEntries {
@ -251,3 +377,45 @@ func IsInternalPort(port *tunnels.TunnelPort) bool {
return false
}
// convertIntToUint16 converts the given int to a uint16.
func convertIntToUint16(port int) (uint16, error) {
var updatedPort uint16
if port >= 0 && port <= 65535 {
updatedPort = uint16(port)
} else {
return 0, fmt.Errorf("invalid port number: %d", port)
}
return updatedPort, nil
}
// trafficMonitor implements io.Reader. It keeps the session alive by notifying
// it of the traffic type during Read operations.
type trafficMonitor struct {
rwc io.ReadWriteCloser
fwd PortForwarder
}
// newTrafficMonitor returns a trafficMonitor for the specified codespace connection.
// It wraps the provided io.ReaderWriteCloser with its own Read/Write/Close methods.
func newTrafficMonitor(rwc io.ReadWriteCloser, fwd PortForwarder) *trafficMonitor {
return &trafficMonitor{rwc, fwd}
}
// Read wraps the underlying ReadWriteCloser's Read method and keeps the session alive with the "input" traffic type.
func (t *trafficMonitor) Read(p []byte) (n int, err error) {
t.fwd.KeepAlive(trafficTypeInput)
return t.rwc.Read(p)
}
// Write wraps the underlying ReadWriteCloser's Write method and keeps the session alive with the "output" traffic type.
func (t *trafficMonitor) Write(p []byte) (n int, err error) {
t.fwd.KeepAlive(trafficTypeOutput)
return t.rwc.Write(p)
}
// Close closes the underlying ReadWriteCloser.
func (t *trafficMonitor) Close() error {
return t.rwc.Close()
}

View file

@ -12,10 +12,10 @@ import (
"strings"
"time"
"github.com/cli/cli/v2/internal/codespaces/portforwarder"
"github.com/cli/cli/v2/internal/codespaces/rpc/codespace"
"github.com/cli/cli/v2/internal/codespaces/rpc/jupyter"
"github.com/cli/cli/v2/internal/codespaces/rpc/ssh"
"github.com/cli/cli/v2/pkg/liveshare"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/metadata"
@ -47,7 +47,7 @@ type Invoker interface {
type invoker struct {
conn *grpc.ClientConn
session liveshare.LiveshareSession
fwd portforwarder.PortForwarder
listener net.Listener
jupyterClient jupyter.JupyterServerHostClient
codespaceClient codespace.CodespaceHostClient
@ -56,11 +56,11 @@ type invoker struct {
}
// Connects to the internal RPC server and returns a new invoker for it
func CreateInvoker(ctx context.Context, session liveshare.LiveshareSession) (Invoker, error) {
func CreateInvoker(ctx context.Context, fwd portforwarder.PortForwarder) (Invoker, error) {
ctx, cancel := context.WithTimeout(ctx, ConnectionTimeout)
defer cancel()
invoker, err := connect(ctx, session)
invoker, err := connect(ctx, fwd)
if err != nil {
return nil, fmt.Errorf("error connecting to internal server: %w", err)
}
@ -69,7 +69,7 @@ func CreateInvoker(ctx context.Context, session liveshare.LiveshareSession) (Inv
}
// Finds a free port to listen on and creates a new RPC invoker that connects to that port
func connect(ctx context.Context, session liveshare.LiveshareSession) (Invoker, error) {
func connect(ctx context.Context, fwd portforwarder.PortForwarder) (Invoker, error) {
listener, err := listenTCP()
if err != nil {
return nil, err
@ -77,7 +77,7 @@ func connect(ctx context.Context, session liveshare.LiveshareSession) (Invoker,
localAddress := listener.Addr().String()
invoker := &invoker{
session: session,
fwd: fwd,
listener: listener,
}
@ -100,8 +100,12 @@ func connect(ctx context.Context, session liveshare.LiveshareSession) (Invoker,
// Tunnel the remote gRPC server port to the local port
go func() {
fwd := liveshare.NewPortForwarder(session, codespacesInternalSessionName, codespacesInternalPort, true)
ch <- fwd.ForwardToListener(pfctx, listener)
// Start forwarding the port locally
opts := portforwarder.ForwardPortOpts{
Port: codespacesInternalPort,
Internal: true,
}
ch <- fwd.ForwardPortToListener(pfctx, opts, listener)
}()
var conn *grpc.ClientConn
@ -262,7 +266,7 @@ func (i *invoker) heartbeat(ctx context.Context, interval time.Duration) {
case <-ctx.Done():
return
case <-ticker.C:
reason := i.session.GetKeepAliveReason()
reason := i.fwd.GetKeepAliveReason()
_ = i.notifyCodespaceOfClientActivity(ctx, reason)
}
}

View file

@ -72,7 +72,8 @@ func createTestInvoker(t *testing.T, server *mockServer) (Invoker, func(), error
listener.Close()
}
invoker, err := CreateInvoker(context.Background(), &rpctest.Session{})
// Create a new invoker with a mock port forwarder
invoker, err := CreateInvoker(context.Background(), rpctest.PortForwarder{})
if err != nil {
close()
return nil, nil, fmt.Errorf("error connecting to internal server: %w", err)

View file

@ -1,34 +0,0 @@
package test
import (
"io"
"net"
)
type Channel struct {
conn net.Conn
}
func (c *Channel) Read(data []byte) (int, error) {
return c.conn.Read(data)
}
func (c *Channel) Write(data []byte) (int, error) {
return c.conn.Write(data)
}
func (c *Channel) Close() error {
return c.conn.Close()
}
func (c *Channel) CloseWrite() error {
return nil
}
func (c *Channel) SendRequest(name string, wantReply bool, payload []byte) (bool, error) {
return false, nil
}
func (c *Channel) Stderr() io.ReadWriter {
return nil
}

View file

@ -0,0 +1,78 @@
package test
import (
"context"
"fmt"
"io"
"net"
"github.com/cli/cli/v2/internal/codespaces/portforwarder"
"github.com/microsoft/dev-tunnels/go/tunnels"
)
type PortForwarder struct{}
// Close implements portforwarder.PortForwarder.
func (PortForwarder) CloseSSHConnection() {
panic("unimplemented")
}
// ConnectToForwardedPort implements portforwarder.PortForwarder.
func (PortForwarder) ConnectToForwardedPort(ctx context.Context, conn io.ReadWriteCloser, opts portforwarder.ForwardPortOpts) error {
panic("unimplemented")
}
// ForwardPort implements portforwarder.PortForwarder.
func (PortForwarder) ForwardPort(ctx context.Context, opts portforwarder.ForwardPortOpts) error {
panic("unimplemented")
}
// GetKeepAliveReason implements portforwarder.PortForwarder.
func (PortForwarder) GetKeepAliveReason() string {
panic("unimplemented")
}
// KeepAlive implements portforwarder.PortForwarder.
func (PortForwarder) KeepAlive(reason string) {
panic("unimplemented")
}
// ForwardPortToListener implements portforwarder.PortForwarder.
func (PortForwarder) ForwardPortToListener(ctx context.Context, opts portforwarder.ForwardPortOpts, listener *net.TCPListener) error {
// Start forwarding the port locally
hostConn, err := net.Dial("tcp", fmt.Sprintf("127.0.0.1:%d", opts.Port))
if err != nil {
return err
}
// Accept the connection from the listener
listenerConn, err := listener.Accept()
if err != nil {
return err
}
// Copy data between the two connections
go func() {
_, _ = io.Copy(hostConn, listenerConn)
hostConn.Close()
}()
go func() {
_, _ = io.Copy(listenerConn, hostConn)
listenerConn.Close()
}()
// ForwardPortToListener typically blocks until the context is cancelled so we need to do the same
<-ctx.Done()
return nil
}
// ListPorts implements portforwarder.PortForwarder.
func (PortForwarder) ListPorts(ctx context.Context) ([]*tunnels.TunnelPort, error) {
panic("unimplemented")
}
// UpdatePortVisibility implements portforwarder.PortForwarder.
func (PortForwarder) UpdatePortVisibility(ctx context.Context, remotePort int, visibility string) error {
panic("unimplemented")
}

View file

@ -1,43 +0,0 @@
package test
import (
"context"
"fmt"
"net"
"github.com/cli/cli/v2/pkg/liveshare"
"golang.org/x/crypto/ssh"
)
type Session struct {
channel ssh.Channel
}
func (*Session) Close() error {
panic("unimplemented")
}
func (*Session) GetSharedServers(context.Context) ([]*liveshare.Port, error) {
panic("unimplemented")
}
func (s *Session) KeepAlive(reason string) {
}
func (s *Session) GetKeepAliveReason() string {
return ""
}
func (s *Session) StartSharing(ctx context.Context, sessionName string, port int) (liveshare.ChannelID, error) {
conn, err := net.Dial("tcp", fmt.Sprintf("127.0.0.1:%d", port))
if err != nil {
return liveshare.ChannelID{}, err
}
s.channel = &Channel{conn}
return liveshare.ChannelID{}, nil
}
// Creates mock SSH channel connected to the mock gRPC server
func (s *Session) OpenStreamingChannel(ctx context.Context, id liveshare.ChannelID) (ssh.Channel, error) {
return s.channel, nil
}

View file

@ -6,13 +6,12 @@ import (
"encoding/json"
"fmt"
"io"
"log"
"time"
"github.com/cli/cli/v2/internal/codespaces/api"
"github.com/cli/cli/v2/internal/codespaces/portforwarder"
"github.com/cli/cli/v2/internal/codespaces/rpc"
"github.com/cli/cli/v2/internal/text"
"github.com/cli/cli/v2/pkg/liveshare"
)
// PostCreateStateStatus is a string value representing the different statuses a state can have.
@ -39,17 +38,15 @@ type PostCreateState struct {
// and calls the supplied poller for each batch of state changes.
// It runs until it encounters an error, including cancellation of the context.
func PollPostCreateStates(ctx context.Context, progress progressIndicator, apiClient apiClient, codespace *api.Codespace, poller func([]PostCreateState)) (err error) {
noopLogger := log.New(io.Discard, "", 0)
session, err := ConnectToLiveshare(ctx, progress, noopLogger, apiClient, codespace)
codespaceConnection, err := GetCodespaceConnection(ctx, progress, apiClient, codespace)
if err != nil {
return fmt.Errorf("connect to codespace: %w", err)
return fmt.Errorf("error connecting to codespace: %w", err)
}
fwd, err := portforwarder.NewPortForwarder(ctx, codespaceConnection)
if err != nil {
return fmt.Errorf("failed to create port forwarder: %w", err)
}
defer func() {
if closeErr := session.Close(); err == nil {
err = closeErr
}
}()
// Ensure local port is listening before client (getPostCreateOutput) connects.
listen, localPort, err := ListenTCP(0, false)
@ -58,7 +55,7 @@ func PollPostCreateStates(ctx context.Context, progress progressIndicator, apiCl
}
progress.StartProgressIndicatorWithLabel("Fetching SSH Details")
invoker, err := rpc.CreateInvoker(ctx, session)
invoker, err := rpc.CreateInvoker(ctx, fwd)
if err != nil {
return err
}
@ -73,8 +70,11 @@ func PollPostCreateStates(ctx context.Context, progress progressIndicator, apiCl
progress.StartProgressIndicatorWithLabel("Fetching status")
tunnelClosed := make(chan error, 1) // buffered to avoid sender stuckness
go func() {
fwd := liveshare.NewPortForwarder(session, "sshd", remoteSSHServerPort, false)
tunnelClosed <- fwd.ForwardToListener(ctx, listen) // error is non-nil
opts := portforwarder.ForwardPortOpts{
Port: remoteSSHServerPort,
Internal: true,
}
tunnelClosed <- fwd.ForwardPortToListener(ctx, opts, listen)
}()
t := time.NewTicker(1 * time.Second)

View file

@ -17,10 +17,8 @@ import (
"github.com/AlecAivazis/survey/v2/terminal"
clicontext "github.com/cli/cli/v2/context"
"github.com/cli/cli/v2/internal/browser"
"github.com/cli/cli/v2/internal/codespaces"
"github.com/cli/cli/v2/internal/codespaces/api"
"github.com/cli/cli/v2/pkg/iostreams"
"github.com/cli/cli/v2/pkg/liveshare"
"github.com/spf13/cobra"
"golang.org/x/term"
)
@ -65,28 +63,6 @@ func (a *App) RunWithProgress(label string, run func() error) error {
return a.io.RunWithProgress(label, run)
}
// Connects to a codespace using Live Share and returns that session
func startLiveShareSession(ctx context.Context, codespace *api.Codespace, a *App, debug bool, debugFile string) (session *liveshare.Session, err error) {
liveshareLogger := noopLogger()
if debug {
debugLogger, err := newFileLogger(debugFile)
if err != nil {
return nil, fmt.Errorf("couldn't create file logger: %w", err)
}
defer safeClose(debugLogger, &err)
liveshareLogger = debugLogger.Logger
a.errLogger.Printf("Debug file located at: %s", debugLogger.Name())
}
session, err = codespaces.ConnectToLiveshare(ctx, a, liveshareLogger, a.apiClient, codespace)
if err != nil {
return nil, fmt.Errorf("failed to connect to Live Share: %w", err)
}
return session, nil
}
//go:generate moq -fmt goimports -rm -skip-ensure -out mock_api.go . apiClient
type apiClient interface {
ServerURL() string
@ -201,10 +177,6 @@ func noArgsConstraint(cmd *cobra.Command, args []string) error {
return nil
}
func noopLogger() *log.Logger {
return log.New(io.Discard, "", 0)
}
type codespace struct {
*api.Codespace
}

View file

@ -7,8 +7,8 @@ import (
"strings"
"github.com/cli/cli/v2/internal/codespaces"
"github.com/cli/cli/v2/internal/codespaces/portforwarder"
"github.com/cli/cli/v2/internal/codespaces/rpc"
"github.com/cli/cli/v2/pkg/liveshare"
"github.com/spf13/cobra"
)
@ -39,11 +39,15 @@ func (a *App) Jupyter(ctx context.Context, selector *CodespaceSelector) (err err
return err
}
session, err := startLiveShareSession(ctx, codespace, a, false, "")
codespaceConnection, err := codespaces.GetCodespaceConnection(ctx, a, a.apiClient, codespace)
if err != nil {
return err
return fmt.Errorf("error connecting to codespace: %w", err)
}
fwd, err := portforwarder.NewPortForwarder(ctx, codespaceConnection)
if err != nil {
return fmt.Errorf("failed to create port forwarder: %w", err)
}
defer safeClose(session, &err)
var (
invoker rpc.Invoker
@ -51,7 +55,7 @@ func (a *App) Jupyter(ctx context.Context, selector *CodespaceSelector) (err err
serverUrl string
)
err = a.RunWithProgress("Starting JupyterLab on codespace", func() (err error) {
invoker, err = rpc.CreateInvoker(ctx, session)
invoker, err = rpc.CreateInvoker(ctx, fwd)
if err != nil {
return
}
@ -76,8 +80,10 @@ func (a *App) Jupyter(ctx context.Context, selector *CodespaceSelector) (err err
tunnelClosed := make(chan error, 1)
go func() {
fwd := liveshare.NewPortForwarder(session, "jupyter", serverPort, true)
tunnelClosed <- fwd.ForwardToListener(ctx, listen) // always non-nil
opts := portforwarder.ForwardPortOpts{
Port: serverPort,
}
tunnelClosed <- fwd.ForwardPortToListener(ctx, opts, listen)
}()
// Server URL contains an authentication token that must be preserved

View file

@ -5,8 +5,8 @@ import (
"fmt"
"github.com/cli/cli/v2/internal/codespaces"
"github.com/cli/cli/v2/internal/codespaces/portforwarder"
"github.com/cli/cli/v2/internal/codespaces/rpc"
"github.com/cli/cli/v2/pkg/liveshare"
"github.com/spf13/cobra"
)
@ -42,11 +42,15 @@ func (a *App) Logs(ctx context.Context, selector *CodespaceSelector, follow bool
return err
}
session, err := startLiveShareSession(ctx, codespace, a, false, "")
codespaceConnection, err := codespaces.GetCodespaceConnection(ctx, a, a.apiClient, codespace)
if err != nil {
return err
return fmt.Errorf("error connecting to codespace: %w", err)
}
fwd, err := portforwarder.NewPortForwarder(ctx, codespaceConnection)
if err != nil {
return fmt.Errorf("failed to create port forwarder: %w", err)
}
defer safeClose(session, &err)
// Ensure local port is listening before client (getPostCreateOutput) connects.
listen, localPort, err := codespaces.ListenTCP(0, false)
@ -57,7 +61,7 @@ func (a *App) Logs(ctx context.Context, selector *CodespaceSelector, follow bool
remoteSSHServerPort, sshUser := 0, ""
err = a.RunWithProgress("Fetching SSH Details", func() (err error) {
invoker, err := rpc.CreateInvoker(ctx, session)
invoker, err := rpc.CreateInvoker(ctx, fwd)
if err != nil {
return
}
@ -85,8 +89,11 @@ func (a *App) Logs(ctx context.Context, selector *CodespaceSelector, follow bool
tunnelClosed := make(chan error, 1)
go func() {
fwd := liveshare.NewPortForwarder(session, "sshd", remoteSSHServerPort, false)
tunnelClosed <- fwd.ForwardToListener(ctx, listen) // error is non-nil
opts := portforwarder.ForwardPortOpts{
Port: remoteSSHServerPort,
Internal: true,
}
tunnelClosed <- fwd.ForwardPortToListener(ctx, opts, listen)
}()
cmdDone := make(chan error, 1)

View file

@ -345,7 +345,11 @@ func (a *App) ForwardPorts(ctx context.Context, selector *CodespaceSelector, por
if err != nil {
return fmt.Errorf("failed to create port forwarder: %w", err)
}
return fwd.ForwardAndConnectToPort(ctx, uint16(pair.remote), listen, false, false)
opts := portforwarder.ForwardPortOpts{
Port: pair.remote,
}
return fwd.ForwardPortToListener(ctx, opts, listen)
})
}
return group.Wait() // first error

View file

@ -4,7 +4,9 @@ import (
"context"
"fmt"
"github.com/cli/cli/v2/internal/codespaces"
"github.com/cli/cli/v2/internal/codespaces/api"
"github.com/cli/cli/v2/internal/codespaces/portforwarder"
"github.com/cli/cli/v2/internal/codespaces/rpc"
"github.com/spf13/cobra"
)
@ -49,13 +51,17 @@ func (a *App) Rebuild(ctx context.Context, selector *CodespaceSelector, full boo
return nil
}
session, err := startLiveShareSession(ctx, codespace, a, false, "")
codespaceConnection, err := codespaces.GetCodespaceConnection(ctx, a, a.apiClient, codespace)
if err != nil {
return fmt.Errorf("starting Live Share session: %w", err)
return fmt.Errorf("error connecting to codespace: %w", err)
}
defer safeClose(session, &err)
invoker, err := rpc.CreateInvoker(ctx, session)
fwd, err := portforwarder.NewPortForwarder(ctx, codespaceConnection)
if err != nil {
return fmt.Errorf("failed to create port forwarder: %w", err)
}
invoker, err := rpc.CreateInvoker(ctx, fwd)
if err != nil {
return err
}

View file

@ -6,7 +6,7 @@ import (
"context"
"errors"
"fmt"
"log"
"io"
"os"
"os/exec"
"path"
@ -18,10 +18,10 @@ import (
"github.com/MakeNowJust/heredoc"
"github.com/cli/cli/v2/internal/codespaces"
"github.com/cli/cli/v2/internal/codespaces/api"
"github.com/cli/cli/v2/internal/codespaces/portforwarder"
"github.com/cli/cli/v2/internal/codespaces/rpc"
"github.com/cli/cli/v2/internal/config"
"github.com/cli/cli/v2/pkg/cmdutil"
"github.com/cli/cli/v2/pkg/liveshare"
"github.com/cli/cli/v2/pkg/ssh"
"github.com/cli/safeexec"
"github.com/spf13/cobra"
@ -144,6 +144,24 @@ func newSSHCmd(app *App) *cobra.Command {
return sshCmd
}
type combinedReadWriteHalfCloser struct {
io.ReadCloser
io.WriteCloser
}
func (crwc *combinedReadWriteHalfCloser) Close() error {
werr := crwc.WriteCloser.Close()
rerr := crwc.ReadCloser.Close()
if werr != nil {
return werr
}
return rerr
}
func (crwc *combinedReadWriteHalfCloser) CloseWrite() error {
return crwc.WriteCloser.Close()
}
// SSH opens an ssh session or runs an ssh command in a codespace.
func (a *App) SSH(ctx context.Context, sshArgs []string, opts sshOptions) (err error) {
// Ensure all child tasks (e.g. port forwarding) terminate before return.
@ -175,11 +193,15 @@ func (a *App) SSH(ctx context.Context, sshArgs []string, opts sshOptions) (err e
return err
}
session, err := startLiveShareSession(ctx, codespace, a, opts.debug, opts.debugFile)
codespaceConnection, err := codespaces.GetCodespaceConnection(ctx, a, a.apiClient, codespace)
if err != nil {
return err
return fmt.Errorf("error connecting to codespace: %w", err)
}
fwd, err := portforwarder.NewPortForwarder(ctx, codespaceConnection)
if err != nil {
return fmt.Errorf("failed to create port forwarder: %w", err)
}
defer safeClose(session, &err)
var (
invoker rpc.Invoker
@ -187,7 +209,7 @@ func (a *App) SSH(ctx context.Context, sshArgs []string, opts sshOptions) (err e
sshUser string
)
err = a.RunWithProgress("Fetching SSH Details", func() (err error) {
invoker, err = rpc.CreateInvoker(ctx, session)
invoker, err = rpc.CreateInvoker(ctx, fwd)
if err != nil {
return
}
@ -203,9 +225,28 @@ func (a *App) SSH(ctx context.Context, sshArgs []string, opts sshOptions) (err e
}
if opts.stdio {
fwd := liveshare.NewPortForwarder(session, "sshd", remoteSSHServerPort, true)
stdio := liveshare.NewReadWriteHalfCloser(os.Stdin, os.Stdout)
err := fwd.Forward(ctx, stdio) // always non-nil
stdio := &combinedReadWriteHalfCloser{os.Stdin, os.Stdout}
opts := portforwarder.ForwardPortOpts{
Port: remoteSSHServerPort,
Internal: true,
KeepAlive: true,
}
// Forward the port
err = fwd.ForwardPort(ctx, opts)
if err != nil {
return fmt.Errorf("failed to forward port: %w", err)
}
// Close the SSH connection when we're done
defer fwd.CloseSSHConnection()
// Connect to the forwarded port
err = fwd.ConnectToForwardedPort(ctx, stdio, opts)
if err != nil {
return fmt.Errorf("failed to connect to forwarded port: %w", err)
}
return fmt.Errorf("tunnel closed: %w", err)
}
@ -227,8 +268,12 @@ func (a *App) SSH(ctx context.Context, sshArgs []string, opts sshOptions) (err e
tunnelClosed := make(chan error, 1)
go func() {
fwd := liveshare.NewPortForwarder(session, "sshd", remoteSSHServerPort, true)
tunnelClosed <- fwd.ForwardToListener(ctx, listen) // always non-nil
opts := portforwarder.ForwardPortOpts{
Port: remoteSSHServerPort,
Internal: true,
KeepAlive: true,
}
tunnelClosed <- fwd.ForwardPortToListener(ctx, opts, listen)
}()
shellClosed := make(chan error, 1)
@ -526,27 +571,36 @@ func (a *App) printOpenSSHConfig(ctx context.Context, opts sshOptions) (err erro
result := sshResult{}
defer wg.Done()
session, err := codespaces.ConnectToLiveshare(ctx, a, noopLogger(), a.apiClient, cs)
codespaceConnection, err := codespaces.GetCodespaceConnection(ctx, a, a.apiClient, cs)
if err != nil {
result.err = fmt.Errorf("error connecting to codespace: %w", err)
} else {
defer safeClose(session, &err)
invoker, err := rpc.CreateInvoker(ctx, session)
if err != nil {
result.err = fmt.Errorf("error connecting to codespace: %w", err)
} else {
defer safeClose(invoker, &err)
_, result.user, err = invoker.StartSSHServer(ctx)
if err != nil {
result.err = fmt.Errorf("error getting ssh server details: %w", err)
} else {
result.codespace = cs
}
}
sshUsers <- result
return
}
fwd, err := portforwarder.NewPortForwarder(ctx, codespaceConnection)
if err != nil {
result.err = fmt.Errorf("failed to create port forwarder: %w", err)
sshUsers <- result
return
}
invoker, err := rpc.CreateInvoker(ctx, fwd)
if err != nil {
result.err = fmt.Errorf("error connecting to codespace: %w", err)
sshUsers <- result
return
}
defer safeClose(invoker, &err)
_, result.user, err = invoker.StartSSHServer(ctx)
if err != nil {
result.err = fmt.Errorf("error getting ssh server details: %w", err)
sshUsers <- result
return
}
result.codespace = cs
sshUsers <- result
}()
}
@ -722,43 +776,3 @@ func (a *App) Copy(ctx context.Context, args []string, opts cpOptions) error {
}
return a.SSH(ctx, nil, opts.sshOptions)
}
// fileLogger is a wrapper around an log.Logger configured to write
// to a file. It exports two additional methods to get the log file name
// and close the file handle when the operation is finished.
type fileLogger struct {
*log.Logger
f *os.File
}
// newFileLogger creates a new fileLogger. It returns an error if the file
// cannot be created. The file is created on the specified path, if the path
// is empty it is created in the temporary directory.
func newFileLogger(file string) (fl *fileLogger, err error) {
var f *os.File
if file == "" {
f, err = os.CreateTemp("", "")
if err != nil {
return nil, fmt.Errorf("failed to create tmp file: %w", err)
}
} else {
f, err = os.Create(file)
if err != nil {
return nil, err
}
}
return &fileLogger{
Logger: log.New(f, "", log.LstdFlags),
f: f,
}, nil
}
func (fl *fileLogger) Name() string {
return fl.f.Name()
}
func (fl *fileLogger) Close() error {
return fl.f.Close()
}

View file

@ -1,130 +0,0 @@
// Package liveshare is a Go client library for the Visual Studio Live Share
// service, which provides collaborative, distributed editing and debugging.
// See https://docs.microsoft.com/en-us/visualstudio/liveshare for an overview.
//
// It provides the ability for a Go program to connect to a Live Share
// workspace (Connect), to expose a TCP port on a remote host
// (UpdateSharedVisibility), to start an SSH server listening on an
// exposed port (StartSSHServer), and to forward connections between
// the remote port and a local listening TCP port (ForwardToListener)
// or a local Go reader/writer (Forward).
package liveshare
import (
"context"
"crypto/tls"
"errors"
"fmt"
"net/url"
"strings"
"github.com/opentracing/opentracing-go"
)
type logger interface {
Println(v ...interface{})
Printf(f string, v ...interface{})
}
// An Options specifies Live Share connection parameters.
type Options struct {
SessionID string
SessionToken string // token for SSH session
RelaySAS string
RelayEndpoint string
HostPublicKeys []string
Logger logger // required
TLSConfig *tls.Config // (optional)
}
// uri returns a websocket URL for the specified options.
func (opts *Options) uri(action string) (string, error) {
if opts.SessionID == "" {
return "", errors.New("SessionID is required")
}
if opts.RelaySAS == "" {
return "", errors.New("RelaySAS is required")
}
if opts.RelayEndpoint == "" {
return "", errors.New("RelayEndpoint is required")
}
sas := url.QueryEscape(opts.RelaySAS)
uri := opts.RelayEndpoint
if strings.HasPrefix(uri, "http:") {
uri = strings.Replace(uri, "http:", "ws:", 1)
} else {
uri = strings.Replace(uri, "sb:", "wss:", -1)
}
uri = strings.Replace(uri, ".net/", ".net:443/$hc/", 1)
uri = uri + "?sb-hc-action=" + action + "&sb-hc-token=" + sas
return uri, nil
}
// Connect connects to a Live Share workspace specified by the
// options, and returns a session representing the connection.
// The caller must call the session's Close method to end the session.
func Connect(ctx context.Context, opts Options) (*Session, error) {
span, ctx := opentracing.StartSpanFromContext(ctx, "Connect")
defer span.Finish()
uri, err := opts.uri("connect")
if err != nil {
return nil, err
}
sock := newSocket(uri, opts.TLSConfig)
if err := sock.connect(ctx); err != nil {
return nil, fmt.Errorf("error connecting websocket: %w", err)
}
if opts.SessionToken == "" {
return nil, errors.New("SessionToken is required")
}
ssh := newSSHSession(opts.SessionToken, opts.HostPublicKeys, sock)
if err := ssh.connect(ctx); err != nil {
return nil, fmt.Errorf("error connecting to ssh session: %w", err)
}
rpc := newRPCClient(ssh)
rpc.connect(ctx)
args := joinWorkspaceArgs{
ID: opts.SessionID,
ConnectionMode: "local",
JoiningUserSessionToken: opts.SessionToken,
ClientCapabilities: clientCapabilities{
IsNonInteractive: false,
},
}
var result joinWorkspaceResult
if err := rpc.do(ctx, "workspace.joinWorkspace", &args, &result); err != nil {
return nil, fmt.Errorf("error joining Live Share workspace: %w", err)
}
s := &Session{
ssh: ssh,
rpc: rpc,
keepAliveReason: make(chan string, 1),
logger: opts.Logger,
}
return s, nil
}
type clientCapabilities struct {
IsNonInteractive bool `json:"isNonInteractive"`
}
type joinWorkspaceArgs struct {
ID string `json:"id"`
ConnectionMode string `json:"connectionMode"`
JoiningUserSessionToken string `json:"joiningUserSessionToken"`
ClientCapabilities clientCapabilities `json:"clientCapabilities"`
}
type joinWorkspaceResult struct {
SessionNumber int `json:"sessionNumber"`
}

View file

@ -1,73 +0,0 @@
package liveshare
import (
"context"
"crypto/tls"
"encoding/json"
"errors"
"fmt"
"strings"
"testing"
livesharetest "github.com/cli/cli/v2/pkg/liveshare/test"
"github.com/sourcegraph/jsonrpc2"
)
func TestConnect(t *testing.T) {
opts := Options{
SessionID: "session-id",
SessionToken: "session-token",
RelaySAS: "relay-sas",
HostPublicKeys: []string{livesharetest.SSHPublicKey},
Logger: newMockLogger(),
}
joinWorkspace := func(conn *jsonrpc2.Conn, req *jsonrpc2.Request) (interface{}, error) {
var joinWorkspaceReq joinWorkspaceArgs
if err := json.Unmarshal(*req.Params, &joinWorkspaceReq); err != nil {
return nil, fmt.Errorf("error unmarshalling req: %w", err)
}
if joinWorkspaceReq.ID != opts.SessionID {
return nil, errors.New("connection session id does not match")
}
if joinWorkspaceReq.ConnectionMode != "local" {
return nil, errors.New("connection mode is not local")
}
if joinWorkspaceReq.JoiningUserSessionToken != opts.SessionToken {
return nil, errors.New("connection user token does not match")
}
if joinWorkspaceReq.ClientCapabilities.IsNonInteractive != false {
return nil, errors.New("non interactive is not false")
}
return joinWorkspaceResult{1}, nil
}
server, err := livesharetest.NewServer(
livesharetest.WithPassword(opts.SessionToken),
livesharetest.WithService("workspace.joinWorkspace", joinWorkspace),
livesharetest.WithRelaySAS(opts.RelaySAS),
)
if err != nil {
t.Errorf("error creating Live Share server: %v", err)
}
defer server.Close()
opts.RelayEndpoint = "sb" + strings.TrimPrefix(server.URL(), "https")
ctx := context.Background()
opts.TLSConfig = &tls.Config{InsecureSkipVerify: true}
done := make(chan error)
go func() {
_, err := Connect(ctx, opts) // ignore session
done <- err
}()
select {
case err := <-server.Err():
t.Errorf("error from server: %v", err)
case err := <-done:
if err != nil {
t.Errorf("error from client: %v", err)
}
}
}

View file

@ -1,56 +0,0 @@
package liveshare
import (
"context"
"testing"
)
func TestBadOptions(t *testing.T) {
goodOptions := Options{
SessionID: "sess-id",
SessionToken: "sess-token",
RelaySAS: "sas",
RelayEndpoint: "endpoint",
}
opts := goodOptions
opts.SessionID = ""
checkBadOptions(t, opts)
opts = goodOptions
opts.SessionToken = ""
checkBadOptions(t, opts)
opts = goodOptions
opts.RelaySAS = ""
checkBadOptions(t, opts)
opts = goodOptions
opts.RelayEndpoint = ""
checkBadOptions(t, opts)
opts = Options{}
checkBadOptions(t, opts)
}
func checkBadOptions(t *testing.T, opts Options) {
if _, err := Connect(context.Background(), opts); err == nil {
t.Errorf("Connect(%+v): no error", opts)
}
}
func TestOptionsURI(t *testing.T) {
opts := Options{
SessionID: "sess-id",
SessionToken: "sess-token",
RelaySAS: "sas",
RelayEndpoint: "sb://endpoint/.net/liveshare",
}
uri, err := opts.uri("connect")
if err != nil {
t.Fatal(err)
}
if uri != "wss://endpoint/.net:443/$hc/liveshare?sb-hc-action=connect&sb-hc-token=sas" {
t.Errorf("uri is not correct, got: '%v'", uri)
}
}

View file

@ -1,241 +0,0 @@
package liveshare
import (
"context"
"fmt"
"io"
"net"
"github.com/opentracing/opentracing-go"
"golang.org/x/crypto/ssh"
)
type portForwardingSession interface {
StartSharing(context.Context, string, int) (ChannelID, error)
OpenStreamingChannel(context.Context, ChannelID) (ssh.Channel, error)
KeepAlive(string)
}
type ReadWriteHalfCloser interface {
io.ReadWriteCloser
CloseWrite() error
}
type combinedReadWriteHalfCloser struct {
io.ReadCloser
io.WriteCloser
}
func NewReadWriteHalfCloser(reader io.ReadCloser, writer io.WriteCloser) ReadWriteHalfCloser {
return &combinedReadWriteHalfCloser{reader, writer}
}
func (crwc *combinedReadWriteHalfCloser) Close() error {
werr := crwc.WriteCloser.Close()
rerr := crwc.ReadCloser.Close()
if werr != nil {
return werr
}
return rerr
}
func (crwc *combinedReadWriteHalfCloser) CloseWrite() error {
return crwc.WriteCloser.Close()
}
// A PortForwarder forwards TCP traffic over a Live Share session from a port on a remote
// container to a local destination such as a network port or Go reader/writer.
type PortForwarder struct {
session portForwardingSession
name string
remotePort int
keepAlive bool
}
// NewPortForwarder returns a new PortForwarder for the specified
// remote port and Live Share session. The name describes the purpose
// of the remote port or service. The keepAlive flag indicates whether
// the session should be kept alive with port forwarding traffic.
func NewPortForwarder(session portForwardingSession, name string, remotePort int, keepAlive bool) *PortForwarder {
return &PortForwarder{
session: session,
name: name,
remotePort: remotePort,
keepAlive: keepAlive,
}
}
// ForwardToListener forwards traffic between the container's remote
// port and a local port, which must already be listening for
// connections. (Accepting a listener rather than a port number avoids
// races against other processes opening ports, and against a client
// connecting to the socket prematurely.)
//
// ForwardToListener accepts and handles connections on the local port
// until it encounters the first error, which may include context
// cancellation. Its error result is always non-nil. The caller is
// responsible for closing the listening port.
func (fwd *PortForwarder) ForwardToListener(ctx context.Context, listen *net.TCPListener) (err error) {
id, err := fwd.shareRemotePort(ctx)
if err != nil {
return err
}
errc := make(chan error, 1)
sendError := func(err error) {
// Use non-blocking send, to avoid goroutines getting
// stuck in case of concurrent or sequential errors.
select {
case errc <- err:
default:
}
}
go func() {
for {
conn, err := listen.AcceptTCP()
if err != nil {
sendError(err)
return
}
go func() {
if err := fwd.handleConnection(ctx, id, conn); err != nil {
sendError(err)
}
}()
}
}()
return awaitError(ctx, errc)
}
// Forward forwards traffic between the container's remote port and
// the specified read/write stream. On return, the stream is closed.
func (fwd *PortForwarder) Forward(ctx context.Context, conn ReadWriteHalfCloser) error {
id, err := fwd.shareRemotePort(ctx)
if err != nil {
conn.Close()
return err
}
// Create buffered channel so that send doesn't get stuck after context cancellation.
errc := make(chan error, 1)
go func() {
errc <- fwd.handleConnection(ctx, id, conn)
}()
return awaitError(ctx, errc)
}
func (fwd *PortForwarder) shareRemotePort(ctx context.Context) (ChannelID, error) {
id, err := fwd.session.StartSharing(ctx, fwd.name, fwd.remotePort)
if err != nil {
err = fmt.Errorf("failed to share remote port %d: %w", fwd.remotePort, err)
}
return id, err
}
func awaitError(ctx context.Context, errc <-chan error) error {
select {
case err := <-errc:
return err
case <-ctx.Done():
return ctx.Err() // canceled
}
}
type trafficMonitorSession interface {
KeepAlive(string)
}
// trafficMonitor implements io.Reader. It keeps the session alive by notifying
// it of the traffic type during Read operations.
type trafficMonitor struct {
reader io.Reader
session trafficMonitorSession
trafficType string
}
// newTrafficMonitor returns a new trafficMonitor for the specified
// session and traffic type. It wraps the provided io.Reader with its own
// Read method.
func newTrafficMonitor(reader io.Reader, session trafficMonitorSession, trafficType string) *trafficMonitor {
return &trafficMonitor{reader, session, trafficType}
}
func (t *trafficMonitor) Read(p []byte) (n int, err error) {
t.session.KeepAlive(t.trafficType)
return t.reader.Read(p)
}
// handleConnection handles forwarding for a single accepted connection, then closes it.
func (fwd *PortForwarder) handleConnection(ctx context.Context, id ChannelID, conn ReadWriteHalfCloser) (err error) {
span, ctx := opentracing.StartSpanFromContext(ctx, "PortForwarder.handleConnection")
defer span.Finish()
defer safeClose(conn, &err)
channel, err := fwd.session.OpenStreamingChannel(ctx, id)
if err != nil {
return fmt.Errorf("error opening streaming channel for new connection: %w", err)
}
// Ideally we would call safeClose again, but (*ssh.channel).Close
// appears to have a bug that causes it return io.EOF spuriously
// if its peer closed first; see github.com/golang/go/issues/38115.
defer func() {
closeErr := channel.Close()
if err == nil && closeErr != io.EOF {
err = closeErr
}
}()
// bi-directional copy of data.
errs := make(chan error, 2)
copyConn := func(w ReadWriteHalfCloser, r io.Reader) {
_, err := io.Copy(w, r)
errs <- err
// Ignore errors here, we call the full Close() later and catch that error
_ = w.CloseWrite()
}
var (
channelReader io.Reader = channel
connReader io.Reader = conn
)
// If the forwader has been configured to keep the session alive
// it will monitor the I/O and notify the session of the traffic.
if fwd.keepAlive {
channelReader = newTrafficMonitor(channelReader, fwd.session, "output")
connReader = newTrafficMonitor(connReader, fwd.session, "input")
}
go copyConn(conn, channelReader)
go copyConn(channel, connReader)
// Wait until context is cancelled or both copies are done.
// Discard errors from io.Copy; they should not cause (e.g.) ForwardToListener to fail.
// TODO: how can we proxy errors from Copy so that each peer can distinguish an error from a short file?
for i := 0; ; {
select {
case <-ctx.Done():
return ctx.Err()
case <-errs:
i++
if i == 2 {
return nil
}
}
}
}
// safeClose reports the error (to *err) from closing the stream only
// if no other error was previously reported.
func safeClose(closer io.Closer, err *error) {
closeErr := closer.Close()
if *err == nil {
*err = closeErr
}
}

View file

@ -1,153 +0,0 @@
package liveshare
import (
"bytes"
"context"
"errors"
"fmt"
"io"
"net"
"os"
"testing"
"time"
livesharetest "github.com/cli/cli/v2/pkg/liveshare/test"
"github.com/sourcegraph/jsonrpc2"
)
func TestNewPortForwarder(t *testing.T) {
testServer, session, err := makeMockSession()
if err != nil {
t.Errorf("create mock client: %v", err)
}
defer testServer.Close()
pf := NewPortForwarder(session, "ssh", 80, false)
if pf == nil {
t.Error("port forwarder is nil")
}
}
type portUpdateNotification struct {
PortNotification
conn *jsonrpc2.Conn
}
func TestPortForwarderStart(t *testing.T) {
if os.Getenv("GITHUB_ACTIONS") == "true" {
t.Skip("fails intermittently in CI: https://github.com/cli/cli/issues/5338")
}
streamName, streamCondition := "stream-name", "stream-condition"
const port = 8000
sendNotification := make(chan portUpdateNotification)
serverSharing := func(conn *jsonrpc2.Conn, req *jsonrpc2.Request) (interface{}, error) {
// Send the PortNotification that will be awaited on in session.StartSharing
sendNotification <- portUpdateNotification{
PortNotification: PortNotification{
Port: port,
ChangeKind: PortChangeKindStart,
},
conn: conn,
}
return Port{StreamName: streamName, StreamCondition: streamCondition}, nil
}
getStream := func(conn *jsonrpc2.Conn, req *jsonrpc2.Request) (interface{}, error) {
return "stream-id", nil
}
stream := bytes.NewBufferString("stream-data")
testServer, session, err := makeMockSession(
livesharetest.WithService("serverSharing.startSharing", serverSharing),
livesharetest.WithService("streamManager.getStream", getStream),
livesharetest.WithStream("stream-id", stream),
)
if err != nil {
t.Errorf("create mock session: %v", err)
}
defer testServer.Close()
listen, err := net.Listen("tcp", "127.0.0.1:8000")
if err != nil {
t.Fatal(err)
}
defer listen.Close()
tcpListener, ok := listen.(*net.TCPListener)
if !ok {
t.Fatal("net.Listen did not return a TCPListener")
}
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
go func() {
notif := <-sendNotification
_, _ = notif.conn.DispatchCall(context.Background(), "serverSharing.sharingSucceeded", notif)
}()
done := make(chan error, 2)
go func() {
done <- NewPortForwarder(session, "ssh", port, false).ForwardToListener(ctx, tcpListener)
}()
go func() {
var conn net.Conn
// We retry DialTimeout in a loop to deal with a race in PortForwarder startup.
for tries := 0; conn == nil && tries < 2; tries++ {
conn, err = net.DialTimeout("tcp", ":8000", 2*time.Second)
if conn == nil {
time.Sleep(1 * time.Second)
}
}
if conn == nil {
done <- errors.New("failed to connect to forwarded port")
return
}
b := make([]byte, len("stream-data"))
if _, err := conn.Read(b); err != nil && err != io.EOF {
done <- fmt.Errorf("reading stream: %w", err)
return
}
if string(b) != "stream-data" {
done <- fmt.Errorf("stream data is not expected value, got: %s", string(b))
return
}
if _, err := conn.Write([]byte("new-data")); err != nil {
done <- fmt.Errorf("writing to stream: %w", err)
return
}
done <- nil
}()
select {
case err := <-testServer.Err():
t.Errorf("error from server: %v", err)
case err := <-done:
if err != nil {
t.Errorf("error from client: %v", err)
}
}
}
func TestPortForwarderTrafficMonitor(t *testing.T) {
buf := bytes.NewBufferString("some-input")
session := &Session{keepAliveReason: make(chan string, 1)}
trafficType := "io"
tm := newTrafficMonitor(buf, session, trafficType)
l := len(buf.Bytes())
bb := make([]byte, l)
n, err := tm.Read(bb)
if err != nil {
t.Errorf("failed to read from traffic monitor: %v", err)
}
if n != l {
t.Errorf("expected to read %d bytes, got %d", l, n)
}
keepAliveReason := <-session.keepAliveReason
if keepAliveReason != trafficType {
t.Errorf("expected keep alive reason to be %s, got %s", trafficType, keepAliveReason)
}
}

View file

@ -1,101 +0,0 @@
package liveshare
import (
"context"
"encoding/json"
"fmt"
"github.com/sourcegraph/jsonrpc2"
)
// Port describes a port exposed by the container.
type Port struct {
SourcePort int `json:"sourcePort"`
DestinationPort int `json:"destinationPort"`
SessionName string `json:"sessionName"`
StreamName string `json:"streamName"`
StreamCondition string `json:"streamCondition"`
BrowseURL string `json:"browseUrl"`
IsPublic bool `json:"isPublic"`
IsTCPServerConnectionEstablished bool `json:"isTCPServerConnectionEstablished"`
HasTLSHandshakePassed bool `json:"hasTLSHandshakePassed"`
Privacy string `json:"privacy"`
}
type PortChangeKind string
const (
PortChangeKindStart PortChangeKind = "start"
PortChangeKindUpdate PortChangeKind = "update"
)
type PortNotification struct {
Success bool // Helps us disambiguate between the SharingSucceeded/SharingFailed events
// The following are properties included in the SharingSucceeded/SharingFailed events sent by the server sharing service in the Codespace
Port int `json:"port"`
ChangeKind PortChangeKind `json:"changeKind"`
ErrorDetail string `json:"errorDetail"`
StatusCode int `json:"statusCode"`
}
// WaitForPortNotification waits for a port notification to be received. It returns the notification
// or an error if the notification is not received before the context is cancelled or it fails
// to parse the notification.
func (s *Session) WaitForPortNotification(ctx context.Context, port int, notifType PortChangeKind) (*PortNotification, error) {
// We use 1-buffered channels and non-blocking sends so that
// no goroutine gets stuck.
notificationCh := make(chan *PortNotification, 1)
errCh := make(chan error, 1)
h := func(success bool) func(*jsonrpc2.Conn, *jsonrpc2.Request) {
return func(conn *jsonrpc2.Conn, req *jsonrpc2.Request) {
notification := new(PortNotification)
if err := json.Unmarshal(*req.Params, &notification); err != nil {
select {
case errCh <- fmt.Errorf("error unmarshalling notification: %w", err):
default:
}
return
}
notification.Success = success
if notification.Port == port && notification.ChangeKind == notifType {
select {
case notificationCh <- notification:
default:
}
}
}
}
deregisterSuccess := s.registerRequestHandler("serverSharing.sharingSucceeded", h(true))
deregisterFailure := s.registerRequestHandler("serverSharing.sharingFailed", h(false))
defer deregisterSuccess()
defer deregisterFailure()
for {
select {
case <-ctx.Done():
return nil, ctx.Err()
case err := <-errCh:
return nil, err
case notification := <-notificationCh:
return notification, nil
}
}
}
// GetSharedServers returns a description of each container port
// shared by a prior call to StartSharing by some client.
func (s *Session) GetSharedServers(ctx context.Context) ([]*Port, error) {
var response []*Port
if err := s.rpc.do(ctx, "serverSharing.getSharedServers", []string{}, &response); err != nil {
return nil, err
}
return response, nil
}
// UpdateSharedServerPrivacy controls port permissions and visibility scopes for who can access its URLs
// in the browser.
func (s *Session) UpdateSharedServerPrivacy(ctx context.Context, port int, visibility string) error {
return s.rpc.do(ctx, "serverSharing.updateSharedServerPrivacy", []interface{}{port, visibility}, nil)
}

View file

@ -1,87 +0,0 @@
package liveshare
import (
"context"
"fmt"
"io"
"sync"
"time"
"github.com/opentracing/opentracing-go"
"github.com/sourcegraph/jsonrpc2"
)
type rpcClient struct {
*jsonrpc2.Conn
conn io.ReadWriteCloser
handlersMu sync.Mutex
handlers map[string][]*handlerWrapper
}
func newRPCClient(conn io.ReadWriteCloser) *rpcClient {
return &rpcClient{conn: conn, handlers: make(map[string][]*handlerWrapper)}
}
func (r *rpcClient) connect(ctx context.Context) {
stream := jsonrpc2.NewBufferedStream(r.conn, jsonrpc2.VSCodeObjectCodec{})
r.Conn = jsonrpc2.NewConn(ctx, stream, r)
}
func (r *rpcClient) do(ctx context.Context, method string, args, result interface{}) error {
span, ctx := opentracing.StartSpanFromContext(ctx, method)
defer span.Finish()
waiter, err := r.Conn.DispatchCall(ctx, method, args)
if err != nil {
return fmt.Errorf("error dispatching %q call: %w", method, err)
}
// timeout for waiter in case a connection cannot be made
waitCtx, cancel := context.WithTimeout(ctx, 2*time.Minute)
defer cancel()
return waiter.Wait(waitCtx, result)
}
type handler func(conn *jsonrpc2.Conn, req *jsonrpc2.Request)
type handlerWrapper struct {
fn handler
}
func (r *rpcClient) register(requestType string, fn handler) func() {
r.handlersMu.Lock()
defer r.handlersMu.Unlock()
h := &handlerWrapper{fn: fn}
r.handlers[requestType] = append(r.handlers[requestType], h)
return func() {
r.deregister(requestType, h)
}
}
func (r *rpcClient) deregister(requestType string, handler *handlerWrapper) {
r.handlersMu.Lock()
defer r.handlersMu.Unlock()
handlers := r.handlers[requestType]
for i, h := range handlers {
if h == handler {
// Swap h with last element and pop.
last := len(handlers) - 1
handlers[i], handlers[last] = handlers[last], nil
r.handlers[requestType] = handlers[:last]
break
}
}
}
func (r *rpcClient) Handle(ctx context.Context, conn *jsonrpc2.Conn, req *jsonrpc2.Request) {
r.handlersMu.Lock()
defer r.handlersMu.Unlock()
for _, handler := range r.handlers[req.Method] {
go handler.fn(conn, req)
}
}

View file

@ -1,133 +0,0 @@
package liveshare
import (
"context"
"fmt"
"github.com/opentracing/opentracing-go"
"golang.org/x/crypto/ssh"
"golang.org/x/sync/errgroup"
)
// A ChannelID is an identifier for an exposed port on a remote
// container that may be used to open an SSH channel to it.
type ChannelID struct {
name, condition string
}
// Interface to allow the mocking of the liveshare session
type LiveshareSession interface {
Close() error
GetSharedServers(context.Context) ([]*Port, error)
KeepAlive(string)
OpenStreamingChannel(context.Context, ChannelID) (ssh.Channel, error)
StartSharing(context.Context, string, int) (ChannelID, error)
GetKeepAliveReason() string
}
// A Session represents the session between a connected Live Share client and server.
type Session struct {
ssh *sshSession
rpc *rpcClient
keepAliveReason chan string
logger logger
}
// Close should be called by users to clean up RPC and SSH resources whenever the session
// is no longer active.
func (s *Session) Close() error {
// Closing the RPC conn closes the underlying stream (SSH)
// So we only need to close once
if err := s.rpc.Close(); err != nil {
s.ssh.Close() // close SSH and ignore error
return fmt.Errorf("error while closing Live Share session: %w", err)
}
return nil
}
// Fetches the keep alive reason from the channel and returns it.
func (s *Session) GetKeepAliveReason() string {
return <-s.keepAliveReason
}
// registerRequestHandler registers a handler for the given request type with the RPC
// server and returns a callback function to deregister the handler
func (s *Session) registerRequestHandler(requestType string, h handler) func() {
return s.rpc.register(requestType, h)
}
// KeepAlive accepts a reason that is retained if there is no active reason
// to send to the server.
func (s *Session) KeepAlive(reason string) {
select {
case s.keepAliveReason <- reason:
default:
// there is already an active keep alive reason
// so we can ignore this one
}
}
// StartSharing tells the Live Share host to start sharing the specified port from the container.
// The sessionName describes the purpose of the remote port or service.
// It returns an identifier that can be used to open an SSH channel to the remote port.
func (s *Session) StartSharing(ctx context.Context, sessionName string, port int) (ChannelID, error) {
args := []interface{}{port, sessionName, fmt.Sprintf("http://localhost:%d", port)}
g, ctx := errgroup.WithContext(ctx)
g.Go(func() error {
startNotification, err := s.WaitForPortNotification(ctx, port, PortChangeKindStart)
if err != nil {
return fmt.Errorf("error while waiting for port notification: %w", err)
}
if !startNotification.Success {
return fmt.Errorf("error while starting port sharing: %s", startNotification.ErrorDetail)
}
return nil // success
})
var response Port
g.Go(func() error {
return s.rpc.do(ctx, "serverSharing.startSharing", args, &response)
})
if err := g.Wait(); err != nil {
return ChannelID{}, err
}
return ChannelID{response.StreamName, response.StreamCondition}, nil
}
func (s *Session) OpenStreamingChannel(ctx context.Context, id ChannelID) (ssh.Channel, error) {
type getStreamArgs struct {
StreamName string `json:"streamName"`
Condition string `json:"condition"`
}
args := getStreamArgs{
StreamName: id.name,
Condition: id.condition,
}
var streamID string
if err := s.rpc.do(ctx, "streamManager.getStream", args, &streamID); err != nil {
return nil, fmt.Errorf("error getting stream id: %w", err)
}
span, ctx := opentracing.StartSpanFromContext(ctx, "Session.OpenChannel+SendRequest")
defer span.Finish()
_ = ctx // ctx is not currently used
channel, reqs, err := s.ssh.conn.OpenChannel("session", nil)
if err != nil {
return nil, fmt.Errorf("error opening ssh channel for transport: %w", err)
}
go ssh.DiscardRequests(reqs)
requestType := fmt.Sprintf("stream-transport-%s", streamID)
if _, err = channel.SendRequest(requestType, true, nil); err != nil {
return nil, fmt.Errorf("error sending channel request: %w", err)
}
return channel, nil
}

View file

@ -1,278 +0,0 @@
package liveshare
import (
"bytes"
"context"
"crypto/tls"
"encoding/json"
"errors"
"fmt"
"strings"
"sync"
"testing"
livesharetest "github.com/cli/cli/v2/pkg/liveshare/test"
"github.com/sourcegraph/jsonrpc2"
)
func makeMockSession(opts ...livesharetest.ServerOption) (*livesharetest.Server, *Session, error) {
joinWorkspace := func(conn *jsonrpc2.Conn, req *jsonrpc2.Request) (interface{}, error) {
return joinWorkspaceResult{1}, nil
}
const sessionToken = "session-token"
opts = append(
opts,
livesharetest.WithPassword(sessionToken),
livesharetest.WithService("workspace.joinWorkspace", joinWorkspace),
)
testServer, err := livesharetest.NewServer(opts...)
if err != nil {
return nil, nil, fmt.Errorf("error creating server: %w", err)
}
session, err := Connect(context.Background(), Options{
SessionID: "session-id",
SessionToken: sessionToken,
RelayEndpoint: "sb" + strings.TrimPrefix(testServer.URL(), "https"),
RelaySAS: "relay-sas",
HostPublicKeys: []string{livesharetest.SSHPublicKey},
TLSConfig: &tls.Config{InsecureSkipVerify: true},
Logger: newMockLogger(),
})
if err != nil {
return nil, nil, fmt.Errorf("error connecting to Live Share: %w", err)
}
return testServer, session, nil
}
func TestServerStartSharing(t *testing.T) {
serverPort, serverProtocol := 2222, "sshd"
sendNotification := make(chan portUpdateNotification)
startSharing := func(conn *jsonrpc2.Conn, req *jsonrpc2.Request) (interface{}, error) {
var args []interface{}
if err := json.Unmarshal(*req.Params, &args); err != nil {
return nil, fmt.Errorf("error unmarshalling request: %w", err)
}
if len(args) < 3 {
return nil, errors.New("not enough arguments to start sharing")
}
port, ok := args[0].(float64)
if !ok {
return nil, errors.New("port argument is not an int")
}
if port != float64(serverPort) {
return nil, errors.New("port does not match serverPort")
}
if protocol, ok := args[1].(string); !ok {
return nil, errors.New("protocol argument is not a string")
} else if protocol != serverProtocol {
return nil, errors.New("protocol does not match serverProtocol")
}
if browseURL, ok := args[2].(string); !ok {
return nil, errors.New("browse url is not a string")
} else if browseURL != fmt.Sprintf("http://localhost:%d", serverPort) {
return nil, errors.New("browseURL does not match expected")
}
sendNotification <- portUpdateNotification{
PortNotification: PortNotification{
Port: int(port),
ChangeKind: PortChangeKindStart,
},
conn: conn,
}
return Port{StreamName: "stream-name", StreamCondition: "stream-condition"}, nil
}
testServer, session, err := makeMockSession(
livesharetest.WithService("serverSharing.startSharing", startSharing),
)
defer testServer.Close() //nolint:staticcheck // httptest.Server does not return errors on Close()
if err != nil {
t.Errorf("error creating mock session: %v", err)
}
ctx := context.Background()
go func() {
notif := <-sendNotification
_, _ = notif.conn.DispatchCall(context.Background(), "serverSharing.sharingSucceeded", notif)
}()
done := make(chan error)
go func() {
streamID, err := session.StartSharing(ctx, serverProtocol, serverPort)
if err != nil {
done <- fmt.Errorf("error sharing server: %w", err)
}
if streamID.name == "" || streamID.condition == "" {
done <- errors.New("stream name or condition is blank")
}
done <- nil
}()
select {
case err := <-testServer.Err():
t.Errorf("error from server: %v", err)
case err := <-done:
if err != nil {
t.Errorf("error from client: %v", err)
}
}
}
func TestServerGetSharedServers(t *testing.T) {
sharedServer := Port{
SourcePort: 2222,
StreamName: "stream-name",
StreamCondition: "stream-condition",
}
getSharedServers := func(conn *jsonrpc2.Conn, req *jsonrpc2.Request) (interface{}, error) {
return []*Port{&sharedServer}, nil
}
testServer, session, err := makeMockSession(
livesharetest.WithService("serverSharing.getSharedServers", getSharedServers),
)
if err != nil {
t.Errorf("error creating mock session: %v", err)
}
defer testServer.Close()
ctx := context.Background()
done := make(chan error)
go func() {
ports, err := session.GetSharedServers(ctx)
if err != nil {
done <- fmt.Errorf("error getting shared servers: %w", err)
}
if len(ports) < 1 {
done <- errors.New("not enough ports returned")
}
if ports[0].SourcePort != sharedServer.SourcePort {
done <- errors.New("source port does not match")
}
if ports[0].StreamName != sharedServer.StreamName {
done <- errors.New("stream name does not match")
}
if ports[0].StreamCondition != sharedServer.StreamCondition {
done <- errors.New("stream condiion does not match")
}
done <- nil
}()
select {
case err := <-testServer.Err():
t.Errorf("error from server: %v", err)
case err := <-done:
if err != nil {
t.Errorf("error from client: %v", err)
}
}
}
func TestServerUpdateSharedServerPrivacy(t *testing.T) {
updateSharedVisibility := func(conn *jsonrpc2.Conn, rpcReq *jsonrpc2.Request) (interface{}, error) {
var req []interface{}
if err := json.Unmarshal(*rpcReq.Params, &req); err != nil {
return nil, fmt.Errorf("unmarshal req: %w", err)
}
if len(req) < 2 {
return nil, errors.New("request arguments is less than 2")
}
if port, ok := req[0].(float64); ok {
if port != 80.0 {
return nil, errors.New("port param is not expected value")
}
} else {
return nil, errors.New("port param is not a float64")
}
if privacy, ok := req[1].(string); ok {
if privacy != "public" {
return nil, fmt.Errorf("expected privacy param to be public but got %q", privacy)
}
} else {
return nil, fmt.Errorf("expected privacy param to be a bool but go %T", req[1])
}
return nil, nil
}
testServer, session, err := makeMockSession(
livesharetest.WithService("serverSharing.updateSharedServerPrivacy", updateSharedVisibility),
)
if err != nil {
t.Errorf("creating mock session: %v", err)
}
defer testServer.Close()
ctx := context.Background()
done := make(chan error)
go func() {
done <- session.UpdateSharedServerPrivacy(ctx, 80, "public")
}()
select {
case err := <-testServer.Err():
t.Errorf("error from server: %v", err)
case err := <-done:
if err != nil {
t.Errorf("error from client: %v", err)
}
}
}
func TestInvalidHostKey(t *testing.T) {
joinWorkspace := func(conn *jsonrpc2.Conn, req *jsonrpc2.Request) (interface{}, error) {
return joinWorkspaceResult{1}, nil
}
const sessionToken = "session-token"
opts := []livesharetest.ServerOption{
livesharetest.WithPassword(sessionToken),
livesharetest.WithService("workspace.joinWorkspace", joinWorkspace),
}
testServer, err := livesharetest.NewServer(opts...)
if err != nil {
t.Errorf("error creating server: %v", err)
}
_, err = Connect(context.Background(), Options{
SessionID: "session-id",
SessionToken: sessionToken,
RelayEndpoint: "sb" + strings.TrimPrefix(testServer.URL(), "https"),
RelaySAS: "relay-sas",
HostPublicKeys: []string{},
TLSConfig: &tls.Config{InsecureSkipVerify: true},
})
if err == nil {
t.Error("expected invalid host key error, got: nil")
}
}
func TestKeepAliveNonBlocking(t *testing.T) {
session := &Session{keepAliveReason: make(chan string, 1)}
for i := 0; i < 2; i++ {
session.KeepAlive("io")
}
// if KeepAlive blocks, we'll never reach this and timeout the test
// timing out
}
type mockLogger struct {
sync.Mutex
buf *bytes.Buffer
}
func newMockLogger() *mockLogger {
return &mockLogger{buf: new(bytes.Buffer)}
}
func (m *mockLogger) Printf(format string, v ...interface{}) {
m.Lock()
defer m.Unlock()
m.buf.WriteString(fmt.Sprintf(format, v...))
}
func (m *mockLogger) Println(v ...interface{}) {
m.Lock()
defer m.Unlock()
m.buf.WriteString(fmt.Sprintln(v...))
}
func (m *mockLogger) String() string {
m.Lock()
defer m.Unlock()
return m.buf.String()
}

View file

@ -1,100 +0,0 @@
package liveshare
import (
"context"
"crypto/tls"
"io"
"net"
"net/http"
"time"
"github.com/gorilla/websocket"
)
type socket struct {
addr string
tlsConfig *tls.Config
conn *websocket.Conn
reader io.Reader
}
func newSocket(uri string, tlsConfig *tls.Config) *socket {
return &socket{addr: uri, tlsConfig: tlsConfig}
}
func (s *socket) connect(ctx context.Context) error {
dialer := websocket.Dialer{
Proxy: http.ProxyFromEnvironment,
HandshakeTimeout: 45 * time.Second,
TLSClientConfig: s.tlsConfig,
}
ws, _, err := dialer.Dial(s.addr, nil)
if err != nil {
return err
}
s.conn = ws
return nil
}
func (s *socket) Read(b []byte) (int, error) {
if s.reader == nil {
_, reader, err := s.conn.NextReader()
if err != nil {
return 0, err
}
s.reader = reader
}
bytesRead, err := s.reader.Read(b)
if err != nil {
s.reader = nil
if err == io.EOF {
err = nil
}
}
return bytesRead, err
}
func (s *socket) Write(b []byte) (int, error) {
nextWriter, err := s.conn.NextWriter(websocket.BinaryMessage)
if err != nil {
return 0, err
}
bytesWritten, err := nextWriter.Write(b)
nextWriter.Close()
return bytesWritten, err
}
func (s *socket) Close() error {
return s.conn.Close()
}
func (s *socket) LocalAddr() net.Addr {
return s.conn.LocalAddr()
}
func (s *socket) RemoteAddr() net.Addr {
return s.conn.RemoteAddr()
}
func (s *socket) SetDeadline(t time.Time) error {
if err := s.SetReadDeadline(t); err != nil {
return err
}
return s.SetWriteDeadline(t)
}
func (s *socket) SetReadDeadline(t time.Time) error {
return s.conn.SetReadDeadline(t)
}
func (s *socket) SetWriteDeadline(t time.Time) error {
return s.conn.SetWriteDeadline(t)
}

View file

@ -1,79 +0,0 @@
package liveshare
import (
"context"
"encoding/base64"
"errors"
"fmt"
"io"
"net"
"time"
"golang.org/x/crypto/ssh"
)
type sshSession struct {
*ssh.Session
token string
hostPublicKeys []string
socket net.Conn
conn ssh.Conn
reader io.Reader
writer io.Writer
}
func newSSHSession(token string, hostPublicKeys []string, socket net.Conn) *sshSession {
return &sshSession{token: token, hostPublicKeys: hostPublicKeys, socket: socket}
}
func (s *sshSession) connect(ctx context.Context) error {
clientConfig := ssh.ClientConfig{
User: "",
Auth: []ssh.AuthMethod{
ssh.Password(s.token),
},
HostKeyAlgorithms: []string{"rsa-sha2-512", "rsa-sha2-256"},
HostKeyCallback: func(hostname string, addr net.Addr, key ssh.PublicKey) error {
encodedKey := base64.StdEncoding.EncodeToString(key.Marshal())
for _, hpk := range s.hostPublicKeys {
if encodedKey == hpk {
return nil // we found a match for expected public key, safely return
}
}
return errors.New("invalid host public key")
},
Timeout: 10 * time.Second,
}
sshClientConn, chans, reqs, err := ssh.NewClientConn(s.socket, "", &clientConfig)
if err != nil {
return fmt.Errorf("error creating ssh client connection: %w", err)
}
s.conn = sshClientConn
sshClient := ssh.NewClient(sshClientConn, chans, reqs)
s.Session, err = sshClient.NewSession()
if err != nil {
return fmt.Errorf("error creating ssh client session: %w", err)
}
s.reader, err = s.Session.StdoutPipe()
if err != nil {
return fmt.Errorf("error creating ssh session reader: %w", err)
}
s.writer, err = s.Session.StdinPipe()
if err != nil {
return fmt.Errorf("error creating ssh session writer: %w", err)
}
return nil
}
func (s *sshSession) Read(p []byte) (n int, err error) {
return s.reader.Read(p)
}
func (s *sshSession) Write(p []byte) (n int, err error) {
return s.writer.Write(p)
}

View file

@ -1,349 +0,0 @@
package livesharetest
import (
"context"
"errors"
"fmt"
"io"
"net/http"
"net/http/httptest"
"strings"
"sync"
"github.com/gorilla/websocket"
"github.com/sourcegraph/jsonrpc2"
"golang.org/x/crypto/ssh"
)
const sshPrivateKey = `-----BEGIN RSA PRIVATE KEY-----
MIICXgIBAAKBgQC6VU6XsMaTot9ogsGcJ+juvJOmDvvCZmgJRTRwKkW0u2BLz4yV
rCzQcxaY4kaIuR80Y+1f0BLnZgh4pTREDR0T+p8hUsDSHim1ttKI8rK0hRtJ2qhY
lR4qt7P51rPA4KFA9z9gDjTwQLbDq21QMC4+n4d8CL3xRVGtlUAMM3Kl3wIDAQAB
AoGBAI8UemkYoSM06gBCh5D1RHQt8eKNltzL7g9QSNfoXeZOC7+q+/TiZPcbqLp0
5lyOalu8b8Ym7J0rSE377Ypj13LyHMXS63e4wMiXv3qOl3GDhMLpypnJ8PwqR2b8
IijL2jrpQfLu6IYqlteA+7e9aEexJa1RRwxYIyq6pG1IYpbhAkEA9nKgtj3Z6ZDC
46IdqYzuUM9ZQdcw4AFr407+lub7tbWe5pYmaq3cT725IwLw081OAmnWJYFDMa/n
IPl9YcZSPQJBAMGOMbPs/YPkQAsgNdIUlFtK3o41OrrwJuTRTvv0DsbqDV0LKOiC
t8oAQQvjisH6Ew5OOhFyIFXtvZfzQMJppksCQQDWFd+cUICTUEise/Duj9maY3Uz
J99ySGnTbZTlu8PfJuXhg3/d3ihrMPG6A1z3cPqaSBxaOj8H07mhQHn1zNU1AkEA
hkl+SGPrO793g4CUdq2ahIA8SpO5rIsDoQtq7jlUq0MlhGFCv5Y5pydn+bSjx5MV
933kocf5kUSBntPBIWElYwJAZTm5ghu0JtSE6t3km0iuj7NGAQSdb6mD8+O7C3CP
FU3vi+4HlBysaT6IZ/HG+/dBsr4gYp4LGuS7DbaLuYw/uw==
-----END RSA PRIVATE KEY-----`
const SSHPublicKey = `AAAAB3NzaC1yc2EAAAADAQABAAAAgQC6VU6XsMaTot9ogsGcJ+juvJOmDvvCZmgJRTRwKkW0u2BLz4yVrCzQcxaY4kaIuR80Y+1f0BLnZgh4pTREDR0T+p8hUsDSHim1ttKI8rK0hRtJ2qhYlR4qt7P51rPA4KFA9z9gDjTwQLbDq21QMC4+n4d8CL3xRVGtlUAMM3Kl3w==`
// Server represents a LiveShare relay host server.
type Server struct {
password string
services map[string]RPCHandleFunc
relaySAS string
streams map[string]io.ReadWriter
sshConfig *ssh.ServerConfig
httptestServer *httptest.Server
errCh chan error
nonSecure bool
}
// NewServer creates a new Server. ServerOptions can be passed to configure
// the SSH password, backing service, secrets and more.
func NewServer(opts ...ServerOption) (*Server, error) {
server := new(Server)
for _, o := range opts {
if err := o(server); err != nil {
return nil, err
}
}
server.sshConfig = &ssh.ServerConfig{
PasswordCallback: sshPasswordCallback(server.password),
}
privateKey, err := ssh.ParsePrivateKey([]byte(sshPrivateKey))
if err != nil {
return nil, fmt.Errorf("error parsing key: %w", err)
}
server.sshConfig.AddHostKey(privateKey)
server.errCh = make(chan error, 1)
if server.nonSecure {
server.httptestServer = httptest.NewServer(http.HandlerFunc(makeConnection(server)))
} else {
server.httptestServer = httptest.NewTLSServer(http.HandlerFunc(makeConnection(server)))
}
return server, nil
}
// ServerOption is used to configure the Server.
type ServerOption func(*Server) error
// WithPassword configures the Server password for SSH.
func WithPassword(password string) ServerOption {
return func(s *Server) error {
s.password = password
return nil
}
}
// WithNonSecure configures the Server as non-secure.
func WithNonSecure() ServerOption {
return func(s *Server) error {
s.nonSecure = true
return nil
}
}
// WithService accepts a mock RPC service for the Server to invoke.
func WithService(serviceName string, handler RPCHandleFunc) ServerOption {
return func(s *Server) error {
if s.services == nil {
s.services = make(map[string]RPCHandleFunc)
}
s.services[serviceName] = handler
return nil
}
}
// WithRelaySAS configures the relay SAS configuration key.
func WithRelaySAS(sas string) ServerOption {
return func(s *Server) error {
s.relaySAS = sas
return nil
}
}
// WithStream allows you to specify a mock data stream for the server.
func WithStream(name string, stream io.ReadWriter) ServerOption {
return func(s *Server) error {
if s.streams == nil {
s.streams = make(map[string]io.ReadWriter)
}
s.streams[name] = stream
return nil
}
}
func sshPasswordCallback(serverPassword string) func(ssh.ConnMetadata, []byte) (*ssh.Permissions, error) {
return func(conn ssh.ConnMetadata, password []byte) (*ssh.Permissions, error) {
if string(password) == serverPassword {
return nil, nil
}
return nil, errors.New("password rejected")
}
}
// Close closes the underlying httptest Server.
func (s *Server) Close() {
s.httptestServer.Close()
}
// URL returns the httptest Server url.
func (s *Server) URL() string {
return s.httptestServer.URL
}
func (s *Server) Err() <-chan error {
return s.errCh
}
var upgrader = websocket.Upgrader{}
func makeConnection(server *Server) http.HandlerFunc {
return func(w http.ResponseWriter, req *http.Request) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
if server.relaySAS != "" {
// validate the sas key
sasParam := req.URL.Query().Get("sb-hc-token")
if sasParam != server.relaySAS {
sendError(server.errCh, errors.New("error validating sas"))
return
}
}
c, err := upgrader.Upgrade(w, req, nil)
if err != nil {
sendError(server.errCh, fmt.Errorf("error upgrading connection: %w", err))
return
}
defer func() {
if err := c.Close(); err != nil {
sendError(server.errCh, err)
}
}()
socketConn := newSocketConn(c)
_, chans, reqs, err := ssh.NewServerConn(socketConn, server.sshConfig)
if err != nil {
sendError(server.errCh, fmt.Errorf("error creating new ssh conn: %w", err))
return
}
go ssh.DiscardRequests(reqs)
if err := handleChannels(ctx, server, chans); err != nil {
sendError(server.errCh, err)
}
}
}
// sendError does a non-blocking send of the error to the err channel.
func sendError(errc chan<- error, err error) {
select {
case errc <- err:
default:
// channel is blocked with a previous error, so we ignore
// this current error
}
}
// awaitError waits for the context to finish and returns its error (if any).
// It also waits for an err to come through the err channel.
func awaitError(ctx context.Context, errc <-chan error) error {
select {
case <-ctx.Done():
return ctx.Err()
case err := <-errc:
return err
}
}
// handleChannels services the sshChannels channel. For each SSH channel received
// it creates a go routine to service the channel's requests. It returns on the first
// error encountered.
func handleChannels(ctx context.Context, server *Server, sshChannels <-chan ssh.NewChannel) error {
errc := make(chan error, 1)
go func() {
for sshCh := range sshChannels {
ch, reqs, err := sshCh.Accept()
if err != nil {
sendError(errc, fmt.Errorf("failed to accept channel: %w", err))
return
}
go func() {
if err := handleRequests(ctx, server, ch, reqs); err != nil {
sendError(errc, fmt.Errorf("failed to handle requests: %w", err))
}
}()
handleChannel(server, ch)
}
}()
return awaitError(ctx, errc)
}
// handleRequests services the SSH channel requests channel. It replies to requests and
// when stream transport requests are encountered, creates a go routine to create a
// bi-directional data stream between the channel and server stream. It returns on the first error
// encountered.
func handleRequests(ctx context.Context, server *Server, channel ssh.Channel, reqs <-chan *ssh.Request) error {
errc := make(chan error, 1)
go func() {
for req := range reqs {
r := req
if r.WantReply {
if err := r.Reply(true, nil); err != nil {
sendError(errc, fmt.Errorf("error replying to channel request: %w", err))
return
}
}
if strings.HasPrefix(r.Type, "stream-transport") {
go func() {
if err := forwardStream(ctx, server, r.Type, channel); err != nil {
sendError(errc, fmt.Errorf("failed to forward stream: %w", err))
}
}()
}
}
}()
return awaitError(ctx, errc)
}
// concurrentStream is a concurrency safe io.ReadWriter.
type concurrentStream struct {
sync.RWMutex
stream io.ReadWriter
}
func newConcurrentStream(rw io.ReadWriter) *concurrentStream {
return &concurrentStream{stream: rw}
}
func (cs *concurrentStream) Read(b []byte) (int, error) {
cs.RLock()
defer cs.RUnlock()
return cs.stream.Read(b)
}
func (cs *concurrentStream) Write(b []byte) (int, error) {
cs.Lock()
defer cs.Unlock()
return cs.stream.Write(b)
}
// forwardStream does a bi-directional copy of the stream <-> with the SSH channel. The io.Copy
// runs until an error is encountered.
func forwardStream(ctx context.Context, server *Server, streamName string, channel ssh.Channel) (err error) {
simpleStreamName := strings.TrimPrefix(streamName, "stream-transport-")
stream, found := server.streams[simpleStreamName]
if !found {
return fmt.Errorf("stream '%s' not found", simpleStreamName)
}
defer func() {
if closeErr := channel.Close(); err == nil && closeErr != io.EOF {
err = closeErr
}
}()
errc := make(chan error, 2)
copy := func(dst io.Writer, src io.Reader) {
if _, err := io.Copy(dst, src); err != nil {
errc <- err
}
}
csStream := newConcurrentStream(stream)
go copy(csStream, channel)
go copy(channel, csStream)
return awaitError(ctx, errc)
}
func handleChannel(server *Server, channel ssh.Channel) {
stream := jsonrpc2.NewBufferedStream(channel, jsonrpc2.VSCodeObjectCodec{})
jsonrpc2.NewConn(context.Background(), stream, newRPCHandler(server))
}
type RPCHandleFunc func(conn *jsonrpc2.Conn, req *jsonrpc2.Request) (interface{}, error)
type rpcHandler struct {
server *Server
}
func newRPCHandler(server *Server) *rpcHandler {
return &rpcHandler{server}
}
// Handle satisfies the jsonrpc2 pkg handler interface. It tries to find a mocked
// RPC service method and if found, it invokes the handler and replies to the request.
func (r *rpcHandler) Handle(ctx context.Context, conn *jsonrpc2.Conn, req *jsonrpc2.Request) {
handler, found := r.server.services[req.Method]
if !found {
sendError(r.server.errCh, fmt.Errorf("RPC Method: '%s' not serviced", req.Method))
return
}
result, err := handler(conn, req)
if err != nil {
sendError(r.server.errCh, fmt.Errorf("error handling: '%s': %w", req.Method, err))
return
}
if err := conn.Reply(ctx, req.ID, result); err != nil {
sendError(r.server.errCh, fmt.Errorf("error replying: %w", err))
}
}

View file

@ -1,77 +0,0 @@
package livesharetest
import (
"fmt"
"io"
"sync"
"time"
"github.com/gorilla/websocket"
)
type socketConn struct {
*websocket.Conn
reader io.Reader
writeMutex sync.Mutex
readMutex sync.Mutex
}
func newSocketConn(conn *websocket.Conn) *socketConn {
return &socketConn{Conn: conn}
}
func (s *socketConn) Read(b []byte) (int, error) {
s.readMutex.Lock()
defer s.readMutex.Unlock()
if s.reader == nil {
msgType, r, err := s.Conn.NextReader()
if err != nil {
return 0, fmt.Errorf("error getting next reader: %w", err)
}
if msgType != websocket.BinaryMessage {
return 0, fmt.Errorf("invalid message type")
}
s.reader = r
}
bytesRead, err := s.reader.Read(b)
if err != nil {
s.reader = nil
if err == io.EOF {
err = nil
}
}
return bytesRead, err
}
func (s *socketConn) Write(b []byte) (int, error) {
s.writeMutex.Lock()
defer s.writeMutex.Unlock()
w, err := s.Conn.NextWriter(websocket.BinaryMessage)
if err != nil {
return 0, fmt.Errorf("error getting next writer: %w", err)
}
n, err := w.Write(b)
if err != nil {
return 0, fmt.Errorf("error writing: %w", err)
}
if err := w.Close(); err != nil {
return 0, fmt.Errorf("error closing writer: %w", err)
}
return n, nil
}
func (s *socketConn) SetDeadline(deadline time.Time) error {
if err := s.Conn.SetReadDeadline(deadline); err != nil {
return err
}
return s.Conn.SetWriteDeadline(deadline)
}