diff --git a/cmd/ghcs/logs.go b/cmd/ghcs/logs.go index 514c36966..4bc83001b 100644 --- a/cmd/ghcs/logs.go +++ b/cmd/ghcs/logs.go @@ -9,7 +9,7 @@ import ( "github.com/github/ghcs/cmd/ghcs/output" "github.com/github/ghcs/internal/api" "github.com/github/ghcs/internal/codespaces" - "github.com/github/go-liveshare" + "github.com/github/ghcs/internal/liveshare" "github.com/spf13/cobra" ) diff --git a/cmd/ghcs/ports.go b/cmd/ghcs/ports.go index 8a4f855fa..24ec7a6e8 100644 --- a/cmd/ghcs/ports.go +++ b/cmd/ghcs/ports.go @@ -14,7 +14,7 @@ import ( "github.com/github/ghcs/cmd/ghcs/output" "github.com/github/ghcs/internal/api" "github.com/github/ghcs/internal/codespaces" - "github.com/github/go-liveshare" + "github.com/github/ghcs/internal/liveshare" "github.com/muhammadmuzzammil1998/jsonc" "github.com/spf13/cobra" "golang.org/x/sync/errgroup" diff --git a/cmd/ghcs/ssh.go b/cmd/ghcs/ssh.go index 3a49e6ebc..bb771107a 100644 --- a/cmd/ghcs/ssh.go +++ b/cmd/ghcs/ssh.go @@ -9,7 +9,7 @@ import ( "github.com/github/ghcs/cmd/ghcs/output" "github.com/github/ghcs/internal/api" "github.com/github/ghcs/internal/codespaces" - "github.com/github/go-liveshare" + "github.com/github/ghcs/internal/liveshare" "github.com/spf13/cobra" ) diff --git a/internal/codespaces/codespaces.go b/internal/codespaces/codespaces.go index 43809bab9..1cd605abc 100644 --- a/internal/codespaces/codespaces.go +++ b/internal/codespaces/codespaces.go @@ -7,7 +7,7 @@ import ( "time" "github.com/github/ghcs/internal/api" - "github.com/github/go-liveshare" + "github.com/github/ghcs/internal/liveshare" ) type logger interface { diff --git a/internal/codespaces/states.go b/internal/codespaces/states.go index 31105d576..c7d61b41e 100644 --- a/internal/codespaces/states.go +++ b/internal/codespaces/states.go @@ -10,7 +10,7 @@ import ( "time" "github.com/github/ghcs/internal/api" - "github.com/github/go-liveshare" + "github.com/github/ghcs/internal/liveshare" ) // PostCreateStateStatus is a string value representing the different statuses a state can have. diff --git a/internal/liveshare/client.go b/internal/liveshare/client.go new file mode 100644 index 000000000..2b1f97831 --- /dev/null +++ b/internal/liveshare/client.go @@ -0,0 +1,149 @@ +// Package liveshare is a Go client library for the Visual Studio Live Share +// service, which provides collaborative, distibuted editing and debugging. +// See https://docs.microsoft.com/en-us/visualstudio/liveshare for an overview. +// +// It provides the ability for a Go program to connect to a Live Share +// workspace (Connect), to expose a TCP port on a remote host +// (UpdateSharedVisibility), to start an SSH server listening on an +// exposed port (StartSSHServer), and to forward connections between +// the remote port and a local listening TCP port (ForwardToListener) +// or a local Go reader/writer (Forward). +package liveshare + +import ( + "context" + "crypto/tls" + "errors" + "fmt" + "net/url" + "strings" + + "github.com/opentracing/opentracing-go" + "golang.org/x/crypto/ssh" +) + +// An Options specifies Live Share connection parameters. +type Options struct { + SessionID string + SessionToken string // token for SSH session + RelaySAS string + RelayEndpoint string + TLSConfig *tls.Config // (optional) +} + +// uri returns a websocket URL for the specified options. +func (opts *Options) uri(action string) (string, error) { + if opts.SessionID == "" { + return "", errors.New("SessionID is required") + } + if opts.RelaySAS == "" { + return "", errors.New("RelaySAS is required") + } + if opts.RelayEndpoint == "" { + return "", errors.New("RelayEndpoint is required") + } + + sas := url.QueryEscape(opts.RelaySAS) + uri := opts.RelayEndpoint + uri = strings.Replace(uri, "sb:", "wss:", -1) + uri = strings.Replace(uri, ".net/", ".net:443/$hc/", 1) + uri = uri + "?sb-hc-action=" + action + "&sb-hc-token=" + sas + return uri, nil +} + +// Connect connects to a Live Share workspace specified by the +// options, and returns a session representing the connection. +// The caller must call the session's Close method to end the session. +func Connect(ctx context.Context, opts Options) (*Session, error) { + uri, err := opts.uri("connect") + if err != nil { + return nil, err + } + + span, ctx := opentracing.StartSpanFromContext(ctx, "Connect") + defer span.Finish() + + sock := newSocket(uri, opts.TLSConfig) + if err := sock.connect(ctx); err != nil { + return nil, fmt.Errorf("error connecting websocket: %w", err) + } + + if opts.SessionToken == "" { + return nil, errors.New("SessionToken is required") + } + ssh := newSSHSession(opts.SessionToken, sock) + if err := ssh.connect(ctx); err != nil { + return nil, fmt.Errorf("error connecting to ssh session: %w", err) + } + + rpc := newRPCClient(ssh) + rpc.connect(ctx) + + args := joinWorkspaceArgs{ + ID: opts.SessionID, + ConnectionMode: "local", + JoiningUserSessionToken: opts.SessionToken, + ClientCapabilities: clientCapabilities{ + IsNonInteractive: false, + }, + } + var result joinWorkspaceResult + if err := rpc.do(ctx, "workspace.joinWorkspace", &args, &result); err != nil { + return nil, fmt.Errorf("error joining Live Share workspace: %w", err) + } + + return &Session{ssh: ssh, rpc: rpc}, nil +} + +type clientCapabilities struct { + IsNonInteractive bool `json:"isNonInteractive"` +} + +type joinWorkspaceArgs struct { + ID string `json:"id"` + ConnectionMode string `json:"connectionMode"` + JoiningUserSessionToken string `json:"joiningUserSessionToken"` + ClientCapabilities clientCapabilities `json:"clientCapabilities"` +} + +type joinWorkspaceResult struct { + SessionNumber int `json:"sessionNumber"` +} + +// A channelID is an identifier for an exposed port on a remote +// container that may be used to open an SSH channel to it. +type channelID struct { + name, condition string +} + +func (s *Session) openStreamingChannel(ctx context.Context, id channelID) (ssh.Channel, error) { + type getStreamArgs struct { + StreamName string `json:"streamName"` + Condition string `json:"condition"` + } + args := getStreamArgs{ + StreamName: id.name, + Condition: id.condition, + } + var streamID string + if err := s.rpc.do(ctx, "streamManager.getStream", args, &streamID); err != nil { + return nil, fmt.Errorf("error getting stream id: %w", err) + } + + span, ctx := opentracing.StartSpanFromContext(ctx, "Session.OpenChannel+SendRequest") + defer span.Finish() + _ = ctx // ctx is not currently used + + channel, reqs, err := s.ssh.conn.OpenChannel("session", nil) + if err != nil { + return nil, fmt.Errorf("error opening ssh channel for transport: %w", err) + } + go ssh.DiscardRequests(reqs) + + requestType := fmt.Sprintf("stream-transport-%s", streamID) + if _, err = channel.SendRequest(requestType, true, nil); err != nil { + return nil, fmt.Errorf("error sending channel request: %w", err) + } + + return channel, nil +} diff --git a/internal/liveshare/client_test.go b/internal/liveshare/client_test.go new file mode 100644 index 000000000..12ea903b6 --- /dev/null +++ b/internal/liveshare/client_test.go @@ -0,0 +1,71 @@ +package liveshare + +import ( + "context" + "crypto/tls" + "encoding/json" + "errors" + "fmt" + "strings" + "testing" + + livesharetest "github.com/github/ghcs/internal/liveshare/test" + "github.com/sourcegraph/jsonrpc2" +) + +func TestConnect(t *testing.T) { + opts := Options{ + SessionID: "session-id", + SessionToken: "session-token", + RelaySAS: "relay-sas", + } + joinWorkspace := func(req *jsonrpc2.Request) (interface{}, error) { + var joinWorkspaceReq joinWorkspaceArgs + if err := json.Unmarshal(*req.Params, &joinWorkspaceReq); err != nil { + return nil, fmt.Errorf("error unmarshaling req: %w", err) + } + if joinWorkspaceReq.ID != opts.SessionID { + return nil, errors.New("connection session id does not match") + } + if joinWorkspaceReq.ConnectionMode != "local" { + return nil, errors.New("connection mode is not local") + } + if joinWorkspaceReq.JoiningUserSessionToken != opts.SessionToken { + return nil, errors.New("connection user token does not match") + } + if joinWorkspaceReq.ClientCapabilities.IsNonInteractive != false { + return nil, errors.New("non interactive is not false") + } + return joinWorkspaceResult{1}, nil + } + + server, err := livesharetest.NewServer( + livesharetest.WithPassword(opts.SessionToken), + livesharetest.WithService("workspace.joinWorkspace", joinWorkspace), + livesharetest.WithRelaySAS(opts.RelaySAS), + ) + if err != nil { + t.Errorf("error creating Live Share server: %w", err) + } + defer server.Close() + opts.RelayEndpoint = "sb" + strings.TrimPrefix(server.URL(), "https") + + ctx := context.Background() + + opts.TLSConfig = &tls.Config{InsecureSkipVerify: true} + + done := make(chan error) + go func() { + _, err := Connect(ctx, opts) // ignore session + done <- err + }() + + select { + case err := <-server.Err(): + t.Errorf("error from server: %w", err) + case err := <-done: + if err != nil { + t.Errorf("error from client: %w", err) + } + } +} diff --git a/internal/liveshare/options_test.go b/internal/liveshare/options_test.go new file mode 100644 index 000000000..830c59104 --- /dev/null +++ b/internal/liveshare/options_test.go @@ -0,0 +1,56 @@ +package liveshare + +import ( + "context" + "testing" +) + +func TestBadOptions(t *testing.T) { + goodOptions := Options{ + SessionID: "sess-id", + SessionToken: "sess-token", + RelaySAS: "sas", + RelayEndpoint: "endpoint", + } + + opts := goodOptions + opts.SessionID = "" + checkBadOptions(t, opts) + + opts = goodOptions + opts.SessionToken = "" + checkBadOptions(t, opts) + + opts = goodOptions + opts.RelaySAS = "" + checkBadOptions(t, opts) + + opts = goodOptions + opts.RelayEndpoint = "" + checkBadOptions(t, opts) + + opts = Options{} + checkBadOptions(t, opts) +} + +func checkBadOptions(t *testing.T, opts Options) { + if _, err := Connect(context.Background(), opts); err == nil { + t.Errorf("Connect(%+v): no error", opts) + } +} + +func TestOptionsURI(t *testing.T) { + opts := Options{ + SessionID: "sess-id", + SessionToken: "sess-token", + RelaySAS: "sas", + RelayEndpoint: "sb://endpoint/.net/liveshare", + } + uri, err := opts.uri("connect") + if err != nil { + t.Fatal(err) + } + if uri != "wss://endpoint/.net:443/$hc/liveshare?sb-hc-action=connect&sb-hc-token=sas" { + t.Errorf("uri is not correct, got: '%v'", uri) + } +} diff --git a/internal/liveshare/port_forwarder.go b/internal/liveshare/port_forwarder.go new file mode 100644 index 000000000..fcc7ba767 --- /dev/null +++ b/internal/liveshare/port_forwarder.go @@ -0,0 +1,162 @@ +package liveshare + +import ( + "context" + "fmt" + "io" + "net" + + "github.com/opentracing/opentracing-go" +) + +// A PortForwarder forwards TCP traffic over a Live Share session from a port on a remote +// container to a local destination such as a network port or Go reader/writer. +type PortForwarder struct { + session *Session + name string + remotePort int +} + +// NewPortForwarder returns a new PortForwarder for the specified +// remote port and Live Share session. The name describes the purpose +// of the remote port or service. +func NewPortForwarder(session *Session, name string, remotePort int) *PortForwarder { + return &PortForwarder{ + session: session, + name: name, + remotePort: remotePort, + } +} + +// ForwardToListener forwards traffic between the container's remote +// port and a local port, which must already be listening for +// connections. (Accepting a listener rather than a port number avoids +// races against other processes opening ports, and against a client +// connecting to the socket prematurely.) +// +// ForwardToListener accepts and handles connections on the local port +// until it encounters the first error, which may include context +// cancellation. Its error result is always non-nil. The caller is +// responsible for closing the listening port. +func (fwd *PortForwarder) ForwardToListener(ctx context.Context, listen net.Listener) (err error) { + id, err := fwd.shareRemotePort(ctx) + if err != nil { + return err + } + + errc := make(chan error, 1) + sendError := func(err error) { + // Use non-blocking send, to avoid goroutines getting + // stuck in case of concurrent or sequential errors. + select { + case errc <- err: + default: + } + } + go func() { + for { + conn, err := listen.Accept() + if err != nil { + sendError(err) + return + } + + go func() { + if err := fwd.handleConnection(ctx, id, conn); err != nil { + sendError(err) + } + }() + } + }() + + return awaitError(ctx, errc) +} + +// Forward forwards traffic between the container's remote port and +// the specified read/write stream. On return, the stream is closed. +func (fwd *PortForwarder) Forward(ctx context.Context, conn io.ReadWriteCloser) error { + id, err := fwd.shareRemotePort(ctx) + if err != nil { + conn.Close() + return err + } + + // Create buffered channel so that send doesn't get stuck after context cancellation. + errc := make(chan error, 1) + go func() { + errc <- fwd.handleConnection(ctx, id, conn) + }() + return awaitError(ctx, errc) +} + +func (fwd *PortForwarder) shareRemotePort(ctx context.Context) (channelID, error) { + id, err := fwd.session.startSharing(ctx, fwd.name, fwd.remotePort) + if err != nil { + err = fmt.Errorf("failed to share remote port %d: %w", fwd.remotePort, err) + } + return id, err +} + +func awaitError(ctx context.Context, errc <-chan error) error { + select { + case err := <-errc: + return err + case <-ctx.Done(): + return ctx.Err() // canceled + } +} + +// handleConnection handles forwarding for a single accepted connection, then closes it. +func (fwd *PortForwarder) handleConnection(ctx context.Context, id channelID, conn io.ReadWriteCloser) (err error) { + span, ctx := opentracing.StartSpanFromContext(ctx, "PortForwarder.handleConnection") + defer span.Finish() + + defer safeClose(conn, &err) + + channel, err := fwd.session.openStreamingChannel(ctx, id) + if err != nil { + return fmt.Errorf("error opening streaming channel for new connection: %w", err) + } + // Ideally we would call safeClose again, but (*ssh.channel).Close + // appears to have a bug that causes it return io.EOF spuriously + // if its peer closed first; see github.com/golang/go/issues/38115. + defer func() { + closeErr := channel.Close() + if err == nil && closeErr != io.EOF { + err = closeErr + } + }() + + // bi-directional copy of data. + errs := make(chan error, 2) + copyConn := func(w io.Writer, r io.Reader) { + _, err := io.Copy(w, r) + errs <- err + } + go copyConn(conn, channel) + go copyConn(channel, conn) + + // Wait until context is cancelled or both copies are done. + // Discard errors from io.Copy; they should not cause (e.g.) ForwardToListener to fail. + // TODO: how can we proxy errors from Copy so that each peer can distinguish an error from a short file? + for i := 0; ; { + select { + case <-ctx.Done(): + return ctx.Err() + case <-errs: + i++ + if i == 2 { + return nil + } + } + } +} + +// safeClose reports the error (to *err) from closing the stream only +// if no other error was previously reported. +func safeClose(closer io.Closer, err *error) { + closeErr := closer.Close() + if *err == nil { + *err = closeErr + } +} diff --git a/internal/liveshare/port_forwarder_test.go b/internal/liveshare/port_forwarder_test.go new file mode 100644 index 000000000..64dfb5c88 --- /dev/null +++ b/internal/liveshare/port_forwarder_test.go @@ -0,0 +1,95 @@ +package liveshare + +import ( + "bytes" + "context" + "errors" + "fmt" + "io" + "net" + "testing" + "time" + + livesharetest "github.com/github/ghcs/internal/liveshare/test" + "github.com/sourcegraph/jsonrpc2" +) + +func TestNewPortForwarder(t *testing.T) { + testServer, session, err := makeMockSession() + if err != nil { + t.Errorf("create mock client: %w", err) + } + defer testServer.Close() + pf := NewPortForwarder(session, "ssh", 80) + if pf == nil { + t.Error("port forwarder is nil") + } +} + +func TestPortForwarderStart(t *testing.T) { + streamName, streamCondition := "stream-name", "stream-condition" + serverSharing := func(req *jsonrpc2.Request) (interface{}, error) { + return Port{StreamName: streamName, StreamCondition: streamCondition}, nil + } + getStream := func(req *jsonrpc2.Request) (interface{}, error) { + return "stream-id", nil + } + + stream := bytes.NewBufferString("stream-data") + testServer, session, err := makeMockSession( + livesharetest.WithService("serverSharing.startSharing", serverSharing), + livesharetest.WithService("streamManager.getStream", getStream), + livesharetest.WithStream("stream-id", stream), + ) + if err != nil { + t.Errorf("create mock session: %w", err) + } + defer testServer.Close() + + listen, err := net.Listen("tcp", ":8000") + if err != nil { + t.Fatal(err) + } + defer listen.Close() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + done := make(chan error) + go func() { + const name, remote = "ssh", 8000 + done <- NewPortForwarder(session, name, remote).ForwardToListener(ctx, listen) + }() + + go func() { + var conn net.Conn + retries := 0 + for conn == nil && retries < 2 { + conn, err = net.DialTimeout("tcp", ":8000", 2*time.Second) + time.Sleep(1 * time.Second) + } + if conn == nil { + done <- errors.New("failed to connect to forwarded port") + } + b := make([]byte, len("stream-data")) + if _, err := conn.Read(b); err != nil && err != io.EOF { + done <- fmt.Errorf("reading stream: %w", err) + } + if string(b) != "stream-data" { + done <- fmt.Errorf("stream data is not expected value, got: %s", string(b)) + } + if _, err := conn.Write([]byte("new-data")); err != nil { + done <- fmt.Errorf("writing to stream: %w", err) + } + done <- nil + }() + + select { + case err := <-testServer.Err(): + t.Errorf("error from server: %w", err) + case err := <-done: + if err != nil { + t.Errorf("error from client: %w", err) + } + } +} diff --git a/internal/liveshare/rpc.go b/internal/liveshare/rpc.go new file mode 100644 index 000000000..bfd214c89 --- /dev/null +++ b/internal/liveshare/rpc.go @@ -0,0 +1,41 @@ +package liveshare + +import ( + "context" + "fmt" + "io" + + "github.com/opentracing/opentracing-go" + "github.com/sourcegraph/jsonrpc2" +) + +type rpcClient struct { + *jsonrpc2.Conn + conn io.ReadWriteCloser +} + +func newRPCClient(conn io.ReadWriteCloser) *rpcClient { + return &rpcClient{conn: conn} +} + +func (r *rpcClient) connect(ctx context.Context) { + stream := jsonrpc2.NewBufferedStream(r.conn, jsonrpc2.VSCodeObjectCodec{}) + r.Conn = jsonrpc2.NewConn(ctx, stream, nullHandler{}) +} + +func (r *rpcClient) do(ctx context.Context, method string, args, result interface{}) error { + span, ctx := opentracing.StartSpanFromContext(ctx, method) + defer span.Finish() + + waiter, err := r.Conn.DispatchCall(ctx, method, args) + if err != nil { + return fmt.Errorf("error dispatching %q call: %w", method, err) + } + + return waiter.Wait(ctx, result) +} + +type nullHandler struct{} + +func (nullHandler) Handle(ctx context.Context, conn *jsonrpc2.Conn, req *jsonrpc2.Request) { +} diff --git a/internal/liveshare/session.go b/internal/liveshare/session.go new file mode 100644 index 000000000..929e8605b --- /dev/null +++ b/internal/liveshare/session.go @@ -0,0 +1,99 @@ +package liveshare + +import ( + "context" + "fmt" + "strconv" +) + +// A Session represents the session between a connected Live Share client and server. +type Session struct { + ssh *sshSession + rpc *rpcClient +} + +// Close should be called by users to clean up RPC and SSH resources whenever the session +// is no longer active. +func (s *Session) Close() error { + // Closing the RPC conn closes the underlying stream (SSH) + // So we only need to close once + if err := s.rpc.Close(); err != nil { + s.ssh.Close() // close SSH and ignore error + return fmt.Errorf("error while closing Live Share session: %w", err) + } + + return nil +} + +// Port describes a port exposed by the container. +type Port struct { + SourcePort int `json:"sourcePort"` + DestinationPort int `json:"destinationPort"` + SessionName string `json:"sessionName"` + StreamName string `json:"streamName"` + StreamCondition string `json:"streamCondition"` + BrowseURL string `json:"browseUrl"` + IsPublic bool `json:"isPublic"` + IsTCPServerConnectionEstablished bool `json:"isTCPServerConnectionEstablished"` + HasTLSHandshakePassed bool `json:"hasTLSHandshakePassed"` +} + +// startSharing tells the Live Share host to start sharing the specified port from the container. +// The sessionName describes the purpose of the remote port or service. +// It returns an identifier that can be used to open an SSH channel to the remote port. +func (s *Session) startSharing(ctx context.Context, sessionName string, port int) (channelID, error) { + args := []interface{}{port, sessionName, fmt.Sprintf("http://localhost:%d", port)} + var response Port + if err := s.rpc.do(ctx, "serverSharing.startSharing", args, &response); err != nil { + return channelID{}, err + } + + return channelID{response.StreamName, response.StreamCondition}, nil +} + +// GetSharedServers returns a description of each container port +// shared by a prior call to StartSharing by some client. +func (s *Session) GetSharedServers(ctx context.Context) ([]*Port, error) { + var response []*Port + if err := s.rpc.do(ctx, "serverSharing.getSharedServers", []string{}, &response); err != nil { + return nil, err + } + + return response, nil +} + +// UpdateSharedVisibility controls port permissions and whether it can be accessed publicly +// via the Browse URL +func (s *Session) UpdateSharedVisibility(ctx context.Context, port int, public bool) error { + if err := s.rpc.do(ctx, "serverSharing.updateSharedServerVisibility", []interface{}{port, public}, nil); err != nil { + return err + } + + return nil +} + +// StartsSSHServer starts an SSH server in the container, installing sshd if necessary, +// and returns the port on which it listens and the user name clients should provide. +func (s *Session) StartSSHServer(ctx context.Context) (int, string, error) { + var response struct { + Result bool `json:"result"` + ServerPort string `json:"serverPort"` + User string `json:"user"` + Message string `json:"message"` + } + + if err := s.rpc.do(ctx, "ISshServerHostService.startRemoteServer", []string{}, &response); err != nil { + return 0, "", err + } + + if !response.Result { + return 0, "", fmt.Errorf("failed to start server: %s", response.Message) + } + + port, err := strconv.Atoi(response.ServerPort) + if err != nil { + return 0, "", fmt.Errorf("failed to parse port: %w", err) + } + + return port, response.User, nil +} diff --git a/internal/liveshare/session_test.go b/internal/liveshare/session_test.go new file mode 100644 index 000000000..af41dd117 --- /dev/null +++ b/internal/liveshare/session_test.go @@ -0,0 +1,196 @@ +package liveshare + +import ( + "context" + "crypto/tls" + "encoding/json" + "errors" + "fmt" + "strings" + "testing" + + livesharetest "github.com/github/ghcs/internal/liveshare/test" + "github.com/sourcegraph/jsonrpc2" +) + +func makeMockSession(opts ...livesharetest.ServerOption) (*livesharetest.Server, *Session, error) { + joinWorkspace := func(req *jsonrpc2.Request) (interface{}, error) { + return joinWorkspaceResult{1}, nil + } + const sessionToken = "session-token" + opts = append( + opts, + livesharetest.WithPassword(sessionToken), + livesharetest.WithService("workspace.joinWorkspace", joinWorkspace), + ) + testServer, err := livesharetest.NewServer(opts...) + if err != nil { + return nil, nil, fmt.Errorf("error creating server: %w", err) + } + + session, err := Connect(context.Background(), Options{ + SessionID: "session-id", + SessionToken: sessionToken, + RelayEndpoint: "sb" + strings.TrimPrefix(testServer.URL(), "https"), + RelaySAS: "relay-sas", + TLSConfig: &tls.Config{InsecureSkipVerify: true}, + }) + if err != nil { + return nil, nil, fmt.Errorf("error connecting to Live Share: %w", err) + } + return testServer, session, nil +} + +func TestServerStartSharing(t *testing.T) { + serverPort, serverProtocol := 2222, "sshd" + startSharing := func(req *jsonrpc2.Request) (interface{}, error) { + var args []interface{} + if err := json.Unmarshal(*req.Params, &args); err != nil { + return nil, fmt.Errorf("error unmarshaling request: %w", err) + } + if len(args) < 3 { + return nil, errors.New("not enough arguments to start sharing") + } + if port, ok := args[0].(float64); !ok { + return nil, errors.New("port argument is not an int") + } else if port != float64(serverPort) { + return nil, errors.New("port does not match serverPort") + } + if protocol, ok := args[1].(string); !ok { + return nil, errors.New("protocol argument is not a string") + } else if protocol != serverProtocol { + return nil, errors.New("protocol does not match serverProtocol") + } + if browseURL, ok := args[2].(string); !ok { + return nil, errors.New("browse url is not a string") + } else if browseURL != fmt.Sprintf("http://localhost:%d", serverPort) { + return nil, errors.New("browseURL does not match expected") + } + return Port{StreamName: "stream-name", StreamCondition: "stream-condition"}, nil + } + testServer, session, err := makeMockSession( + livesharetest.WithService("serverSharing.startSharing", startSharing), + ) + defer testServer.Close() //nolint:staticcheck // httptest.Server does not return errors on Close() + + if err != nil { + t.Errorf("error creating mock session: %w", err) + } + ctx := context.Background() + + done := make(chan error) + go func() { + streamID, err := session.startSharing(ctx, serverProtocol, serverPort) + if err != nil { + done <- fmt.Errorf("error sharing server: %w", err) + } + if streamID.name == "" || streamID.condition == "" { + done <- errors.New("stream name or condition is blank") + } + done <- nil + }() + + select { + case err := <-testServer.Err(): + t.Errorf("error from server: %w", err) + case err := <-done: + if err != nil { + t.Errorf("error from client: %w", err) + } + } +} + +func TestServerGetSharedServers(t *testing.T) { + sharedServer := Port{ + SourcePort: 2222, + StreamName: "stream-name", + StreamCondition: "stream-condition", + } + getSharedServers := func(req *jsonrpc2.Request) (interface{}, error) { + return []*Port{&sharedServer}, nil + } + testServer, session, err := makeMockSession( + livesharetest.WithService("serverSharing.getSharedServers", getSharedServers), + ) + if err != nil { + t.Errorf("error creating mock session: %w", err) + } + defer testServer.Close() + ctx := context.Background() + done := make(chan error) + go func() { + ports, err := session.GetSharedServers(ctx) + if err != nil { + done <- fmt.Errorf("error getting shared servers: %w", err) + } + if len(ports) < 1 { + done <- errors.New("not enough ports returned") + } + if ports[0].SourcePort != sharedServer.SourcePort { + done <- errors.New("source port does not match") + } + if ports[0].StreamName != sharedServer.StreamName { + done <- errors.New("stream name does not match") + } + if ports[0].StreamCondition != sharedServer.StreamCondition { + done <- errors.New("stream condiion does not match") + } + done <- nil + }() + + select { + case err := <-testServer.Err(): + t.Errorf("error from server: %w", err) + case err := <-done: + if err != nil { + t.Errorf("error from client: %w", err) + } + } +} + +func TestServerUpdateSharedVisibility(t *testing.T) { + updateSharedVisibility := func(rpcReq *jsonrpc2.Request) (interface{}, error) { + var req []interface{} + if err := json.Unmarshal(*rpcReq.Params, &req); err != nil { + return nil, fmt.Errorf("unmarshal req: %w", err) + } + if len(req) < 2 { + return nil, errors.New("request arguments is less than 2") + } + if port, ok := req[0].(float64); ok { + if port != 80.0 { + return nil, errors.New("port param is not expected value") + } + } else { + return nil, errors.New("port param is not a float64") + } + if public, ok := req[1].(bool); ok { + if public != true { + return nil, errors.New("pulic param is not expected value") + } + } else { + return nil, errors.New("public param is not a bool") + } + return nil, nil + } + testServer, session, err := makeMockSession( + livesharetest.WithService("serverSharing.updateSharedServerVisibility", updateSharedVisibility), + ) + if err != nil { + t.Errorf("creating mock session: %w", err) + } + defer testServer.Close() + ctx := context.Background() + done := make(chan error) + go func() { + done <- session.UpdateSharedVisibility(ctx, 80, true) + }() + select { + case err := <-testServer.Err(): + t.Errorf("error from server: %w", err) + case err := <-done: + if err != nil { + t.Errorf("error from client: %w", err) + } + } +} diff --git a/internal/liveshare/socket.go b/internal/liveshare/socket.go new file mode 100644 index 000000000..f66436f65 --- /dev/null +++ b/internal/liveshare/socket.go @@ -0,0 +1,100 @@ +package liveshare + +import ( + "context" + "crypto/tls" + "io" + "net" + "net/http" + "time" + + "github.com/gorilla/websocket" +) + +type socket struct { + addr string + tlsConfig *tls.Config + + conn *websocket.Conn + reader io.Reader +} + +func newSocket(uri string, tlsConfig *tls.Config) *socket { + return &socket{addr: uri, tlsConfig: tlsConfig} +} + +func (s *socket) connect(ctx context.Context) error { + dialer := websocket.Dialer{ + Proxy: http.ProxyFromEnvironment, + HandshakeTimeout: 45 * time.Second, + TLSClientConfig: s.tlsConfig, + } + ws, _, err := dialer.Dial(s.addr, nil) + if err != nil { + return err + } + s.conn = ws + return nil +} + +func (s *socket) Read(b []byte) (int, error) { + if s.reader == nil { + _, reader, err := s.conn.NextReader() + if err != nil { + return 0, err + } + + s.reader = reader + } + + bytesRead, err := s.reader.Read(b) + if err != nil { + s.reader = nil + + if err == io.EOF { + err = nil + } + } + + return bytesRead, err +} + +func (s *socket) Write(b []byte) (int, error) { + nextWriter, err := s.conn.NextWriter(websocket.BinaryMessage) + if err != nil { + return 0, err + } + + bytesWritten, err := nextWriter.Write(b) + nextWriter.Close() + + return bytesWritten, err +} + +func (s *socket) Close() error { + return s.conn.Close() +} + +func (s *socket) LocalAddr() net.Addr { + return s.conn.LocalAddr() +} + +func (s *socket) RemoteAddr() net.Addr { + return s.conn.RemoteAddr() +} + +func (s *socket) SetDeadline(t time.Time) error { + if err := s.SetReadDeadline(t); err != nil { + return err + } + + return s.SetWriteDeadline(t) +} + +func (s *socket) SetReadDeadline(t time.Time) error { + return s.conn.SetReadDeadline(t) +} + +func (s *socket) SetWriteDeadline(t time.Time) error { + return s.conn.SetWriteDeadline(t) +} diff --git a/internal/liveshare/ssh.go b/internal/liveshare/ssh.go new file mode 100644 index 000000000..15f67d2a4 --- /dev/null +++ b/internal/liveshare/ssh.go @@ -0,0 +1,68 @@ +package liveshare + +import ( + "context" + "fmt" + "io" + "net" + "time" + + "golang.org/x/crypto/ssh" +) + +type sshSession struct { + *ssh.Session + token string + socket net.Conn + conn ssh.Conn + reader io.Reader + writer io.Writer +} + +func newSSHSession(token string, socket net.Conn) *sshSession { + return &sshSession{token: token, socket: socket} +} + +func (s *sshSession) connect(ctx context.Context) error { + clientConfig := ssh.ClientConfig{ + User: "", + Auth: []ssh.AuthMethod{ + ssh.Password(s.token), + }, + HostKeyAlgorithms: []string{"rsa-sha2-512", "rsa-sha2-256"}, + HostKeyCallback: ssh.InsecureIgnoreHostKey(), + Timeout: 10 * time.Second, + } + + sshClientConn, chans, reqs, err := ssh.NewClientConn(s.socket, "", &clientConfig) + if err != nil { + return fmt.Errorf("error creating ssh client connection: %w", err) + } + s.conn = sshClientConn + + sshClient := ssh.NewClient(sshClientConn, chans, reqs) + s.Session, err = sshClient.NewSession() + if err != nil { + return fmt.Errorf("error creating ssh client session: %w", err) + } + + s.reader, err = s.Session.StdoutPipe() + if err != nil { + return fmt.Errorf("error creating ssh session reader: %w", err) + } + + s.writer, err = s.Session.StdinPipe() + if err != nil { + return fmt.Errorf("error creating ssh session writer: %w", err) + } + + return nil +} + +func (s *sshSession) Read(p []byte) (n int, err error) { + return s.reader.Read(p) +} + +func (s *sshSession) Write(p []byte) (n int, err error) { + return s.writer.Write(p) +} diff --git a/internal/liveshare/test/server.go b/internal/liveshare/test/server.go new file mode 100644 index 000000000..9b898dafb --- /dev/null +++ b/internal/liveshare/test/server.go @@ -0,0 +1,245 @@ +package livesharetest + +import ( + "context" + "errors" + "fmt" + "io" + "net/http" + "net/http/httptest" + "strings" + + "github.com/gorilla/websocket" + "github.com/sourcegraph/jsonrpc2" + "golang.org/x/crypto/ssh" +) + +const sshPrivateKey = `-----BEGIN RSA PRIVATE KEY----- +MIIEogIBAAKCAQEAp/Jmzy/HaPNx5Bug09FX5Q/KGY4G9c4DfplhWrn31OQCqNiT +ZSLd46rdXC75liHzE7e5Ic0RJN61cYN9SNArjvEXx2vvs7szhwO7LonwPOvpYpUf +daayrgbr6S46plpx+hEZ1kO/6BqMgFuvnkIVThrEyx5b48ll8zgDABsYrKF8/p1V +SjGfb+bLwjn1NtnZF2prBG5P4ZtMR06HaPglLqBJhmc0ZMG5IZGUE7ew/VrPDqdC +f1v4XvvGiU4BLoKYy4QOhyrCGh9Uk/9u0Ea56M2bh4RqwhbpR8m7TYJZ0DVMLbGW +8C+4lCWp+xRyBNxAQh8qeQVCxYl02hPE4bXLGQIDAQABAoIBAEoVPk6UZ+UexhV2 +LnphNOFhFqgxI1bYWmhE5lHsCKuLLLUoW9RYDgL4gw6/1e7o6N3AxFRpre9Soj0B +YIl28k/qf6/DKAhjQnaDKdV8mVF2Swvmdesi7lyfxv6kGtD4wqApXPlMB2IuG94f +E5e+1MEQQ9DJgoU3eNZR1dj9GuRC3PyzPcNNJ2R/MMGFw3sOOVcLOgAukotoicuL +0SiL51rHPQu8a5/darH9EltN1GFeceJSDDhgqMP5T8Tp7g/c3//H6szon4H9W+uN +Z3UrImJ+teJjFOaVDqN93+J2eQSUk0lCPGQCd4U9I4AGDGyU6ucdcLQ58Aha9gmU +uQwkfKUCgYEA0UkuPOSDE9dbXe+yhsbOwMb1kKzJYgFDKjRTSP7D9BOMZu4YyASo +J95R4DWjePlDopafG2tNJoWX+CwUl7Uld1R3Ex6xHBa2B7hwZj860GZtr7D4mdWc +DTVjczAjp4P0K1MIFYQui1mVJterkjKuePiI6q/27L1c2jIa/39BWBcCgYEAzW8R +MFZamVw3eA2JYSpBuqhQgE5gX5IWrmVJZSUhpAQTNG/A4nxf7WGtjy9p99tm0RMb +ld05+sOmNLrzw8Pq8SBpFOd+MAca7lPLS1A2CoaAHbOqRqrzVcZ4EZ2jB3WjoLoq +yctwslGb9KmrhBCdcwT48aPAYUIJCZdqEen2xE8CgYBoMowvywGrvjwCH9X9njvP +5P7cAfrdrY04FQcmP5lmCtmLYZ267/6couaWv33dPBU9fMpIh3rI5BiOebvi8FBw +AgCq50v8lR4Z5+0mKvLoUSbpIy4SwTRJqzwRXHVT8LF/ZH6Q39egj4Bf716/kjYl +im/4kJVatsjk5a9lZ4EsDwKBgERkJ3rKJNtNggHrr8KzSLKVekdc0GTAw+BHRAny +NKLf4Gzij3pXIbBrhlZW2JZ1amNMUzCvN7AuFlUTsDeKL9saiSE2eCIRG3wgVVu7 +VmJmqJw6xgNEwkHaEvr6Wd4P4euOTtRjcB9NX/gxzDHpPiGelCoN8+vtCgkxaVSR +aV+tAoGAO4HtLOfBAVDNbVXa27aJAjQSUq8qfkwUNJNz+rwgpVQahfiVkyqAPCQM +IfRJxKWb0Wbt9ojw3AowK/k0d3LZA7FS41JSiiGKIllSGb+i7JKqKW7RHLA3VJ/E +Bq5TLNIbUzPVNVwRcGjUYpOhKU6EIw8phTJOvxnUC+g6MVqBP8U= +-----END RSA PRIVATE KEY-----` + +type Server struct { + password string + services map[string]RPCHandleFunc + relaySAS string + streams map[string]io.ReadWriter + + sshConfig *ssh.ServerConfig + httptestServer *httptest.Server + errCh chan error +} + +func NewServer(opts ...ServerOption) (*Server, error) { + server := new(Server) + + for _, o := range opts { + if err := o(server); err != nil { + return nil, err + } + } + + server.sshConfig = &ssh.ServerConfig{ + PasswordCallback: sshPasswordCallback(server.password), + } + privateKey, err := ssh.ParsePrivateKey([]byte(sshPrivateKey)) + if err != nil { + return nil, fmt.Errorf("error parsing key: %w", err) + } + server.sshConfig.AddHostKey(privateKey) + + server.errCh = make(chan error) + server.httptestServer = httptest.NewTLSServer(http.HandlerFunc(makeConnection(server))) + return server, nil +} + +type ServerOption func(*Server) error + +func WithPassword(password string) ServerOption { + return func(s *Server) error { + s.password = password + return nil + } +} + +func WithService(serviceName string, handler RPCHandleFunc) ServerOption { + return func(s *Server) error { + if s.services == nil { + s.services = make(map[string]RPCHandleFunc) + } + + s.services[serviceName] = handler + return nil + } +} + +func WithRelaySAS(sas string) ServerOption { + return func(s *Server) error { + s.relaySAS = sas + return nil + } +} + +func WithStream(name string, stream io.ReadWriter) ServerOption { + return func(s *Server) error { + if s.streams == nil { + s.streams = make(map[string]io.ReadWriter) + } + s.streams[name] = stream + return nil + } +} + +func sshPasswordCallback(serverPassword string) func(ssh.ConnMetadata, []byte) (*ssh.Permissions, error) { + return func(conn ssh.ConnMetadata, password []byte) (*ssh.Permissions, error) { + if string(password) == serverPassword { + return nil, nil + } + return nil, errors.New("password rejected") + } +} + +func (s *Server) Close() { + s.httptestServer.Close() +} + +func (s *Server) URL() string { + return s.httptestServer.URL +} + +func (s *Server) Err() <-chan error { + return s.errCh +} + +var upgrader = websocket.Upgrader{} + +func makeConnection(server *Server) http.HandlerFunc { + return func(w http.ResponseWriter, req *http.Request) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + if server.relaySAS != "" { + // validate the sas key + sasParam := req.URL.Query().Get("sb-hc-token") + if sasParam != server.relaySAS { + server.errCh <- errors.New("error validating sas") + return + } + } + c, err := upgrader.Upgrade(w, req, nil) + if err != nil { + server.errCh <- fmt.Errorf("error upgrading connection: %w", err) + return + } + defer c.Close() + + socketConn := newSocketConn(c) + _, chans, reqs, err := ssh.NewServerConn(socketConn, server.sshConfig) + if err != nil { + server.errCh <- fmt.Errorf("error creating new ssh conn: %w", err) + return + } + go ssh.DiscardRequests(reqs) + + for newChannel := range chans { + ch, reqs, err := newChannel.Accept() + if err != nil { + server.errCh <- fmt.Errorf("error accepting new channel: %w", err) + return + } + go handleNewRequests(ctx, server, ch, reqs) + go handleNewChannel(server, ch) + } + } +} + +func handleNewRequests(ctx context.Context, server *Server, channel ssh.Channel, reqs <-chan *ssh.Request) { + for req := range reqs { + if req.WantReply { + if err := req.Reply(true, nil); err != nil { + server.errCh <- fmt.Errorf("error replying to channel request: %w", err) + } + } + if strings.HasPrefix(req.Type, "stream-transport") { + forwardStream(ctx, server, req.Type, channel) + } + } +} + +func forwardStream(ctx context.Context, server *Server, streamName string, channel ssh.Channel) { + simpleStreamName := strings.TrimPrefix(streamName, "stream-transport-") + stream, found := server.streams[simpleStreamName] + if !found { + server.errCh <- fmt.Errorf("stream '%s' not found", simpleStreamName) + return + } + + copy := func(dst io.Writer, src io.Reader) { + if _, err := io.Copy(dst, src); err != nil { + fmt.Println(err) + server.errCh <- fmt.Errorf("io copy: %w", err) + return + } + } + + go copy(stream, channel) + go copy(channel, stream) + + <-ctx.Done() // TODO(josebalius): improve this +} + +func handleNewChannel(server *Server, channel ssh.Channel) { + stream := jsonrpc2.NewBufferedStream(channel, jsonrpc2.VSCodeObjectCodec{}) + jsonrpc2.NewConn(context.Background(), stream, newRPCHandler(server)) +} + +type RPCHandleFunc func(req *jsonrpc2.Request) (interface{}, error) + +type rpcHandler struct { + server *Server +} + +func newRPCHandler(server *Server) *rpcHandler { + return &rpcHandler{server} +} + +func (r *rpcHandler) Handle(ctx context.Context, conn *jsonrpc2.Conn, req *jsonrpc2.Request) { + handler, found := r.server.services[req.Method] + if !found { + r.server.errCh <- fmt.Errorf("RPC Method: '%s' not serviced", req.Method) + return + } + + result, err := handler(req) + if err != nil { + r.server.errCh <- fmt.Errorf("error handling: '%s': %w", req.Method, err) + return + } + + if err := conn.Reply(ctx, req.ID, result); err != nil { + r.server.errCh <- fmt.Errorf("error replying: %w", err) + } +} diff --git a/internal/liveshare/test/socket.go b/internal/liveshare/test/socket.go new file mode 100644 index 000000000..00cd64a1b --- /dev/null +++ b/internal/liveshare/test/socket.go @@ -0,0 +1,77 @@ +package livesharetest + +import ( + "fmt" + "io" + "sync" + "time" + + "github.com/gorilla/websocket" +) + +type socketConn struct { + *websocket.Conn + + reader io.Reader + writeMutex sync.Mutex + readMutex sync.Mutex +} + +func newSocketConn(conn *websocket.Conn) *socketConn { + return &socketConn{Conn: conn} +} + +func (s *socketConn) Read(b []byte) (int, error) { + s.readMutex.Lock() + defer s.readMutex.Unlock() + + if s.reader == nil { + msgType, r, err := s.Conn.NextReader() + if err != nil { + return 0, fmt.Errorf("error getting next reader: %w", err) + } + if msgType != websocket.BinaryMessage { + return 0, fmt.Errorf("invalid message type") + } + s.reader = r + } + + bytesRead, err := s.reader.Read(b) + if err != nil { + s.reader = nil + + if err == io.EOF { + err = nil + } + } + + return bytesRead, err +} + +func (s *socketConn) Write(b []byte) (int, error) { + s.writeMutex.Lock() + defer s.writeMutex.Unlock() + + w, err := s.Conn.NextWriter(websocket.BinaryMessage) + if err != nil { + return 0, fmt.Errorf("error getting next writer: %w", err) + } + + n, err := w.Write(b) + if err != nil { + return 0, fmt.Errorf("error writing: %w", err) + } + + if err := w.Close(); err != nil { + return 0, fmt.Errorf("error closing writer: %w", err) + } + + return n, nil +} + +func (s *socketConn) SetDeadline(deadline time.Time) error { + if err := s.Conn.SetReadDeadline(deadline); err != nil { + return err + } + return s.Conn.SetWriteDeadline(deadline) +}