Merge pull request #189 from github/jg/inline-go-liveshare

Inline go-liveshare v0.20.0
This commit is contained in:
Jose Garcia 2021-09-23 13:45:11 -04:00 committed by GitHub
commit 4eb15134a4
17 changed files with 1364 additions and 5 deletions

View file

@ -9,7 +9,7 @@ import (
"github.com/github/ghcs/cmd/ghcs/output"
"github.com/github/ghcs/internal/api"
"github.com/github/ghcs/internal/codespaces"
"github.com/github/go-liveshare"
"github.com/github/ghcs/internal/liveshare"
"github.com/spf13/cobra"
)

View file

@ -14,7 +14,7 @@ import (
"github.com/github/ghcs/cmd/ghcs/output"
"github.com/github/ghcs/internal/api"
"github.com/github/ghcs/internal/codespaces"
"github.com/github/go-liveshare"
"github.com/github/ghcs/internal/liveshare"
"github.com/muhammadmuzzammil1998/jsonc"
"github.com/spf13/cobra"
"golang.org/x/sync/errgroup"

View file

@ -9,7 +9,7 @@ import (
"github.com/github/ghcs/cmd/ghcs/output"
"github.com/github/ghcs/internal/api"
"github.com/github/ghcs/internal/codespaces"
"github.com/github/go-liveshare"
"github.com/github/ghcs/internal/liveshare"
"github.com/spf13/cobra"
)

View file

@ -7,7 +7,7 @@ import (
"time"
"github.com/github/ghcs/internal/api"
"github.com/github/go-liveshare"
"github.com/github/ghcs/internal/liveshare"
)
type logger interface {

View file

@ -10,7 +10,7 @@ import (
"time"
"github.com/github/ghcs/internal/api"
"github.com/github/go-liveshare"
"github.com/github/ghcs/internal/liveshare"
)
// PostCreateStateStatus is a string value representing the different statuses a state can have.

View file

@ -0,0 +1,149 @@
// Package liveshare is a Go client library for the Visual Studio Live Share
// service, which provides collaborative, distibuted editing and debugging.
// See https://docs.microsoft.com/en-us/visualstudio/liveshare for an overview.
//
// It provides the ability for a Go program to connect to a Live Share
// workspace (Connect), to expose a TCP port on a remote host
// (UpdateSharedVisibility), to start an SSH server listening on an
// exposed port (StartSSHServer), and to forward connections between
// the remote port and a local listening TCP port (ForwardToListener)
// or a local Go reader/writer (Forward).
package liveshare
import (
"context"
"crypto/tls"
"errors"
"fmt"
"net/url"
"strings"
"github.com/opentracing/opentracing-go"
"golang.org/x/crypto/ssh"
)
// An Options specifies Live Share connection parameters.
type Options struct {
SessionID string
SessionToken string // token for SSH session
RelaySAS string
RelayEndpoint string
TLSConfig *tls.Config // (optional)
}
// uri returns a websocket URL for the specified options.
func (opts *Options) uri(action string) (string, error) {
if opts.SessionID == "" {
return "", errors.New("SessionID is required")
}
if opts.RelaySAS == "" {
return "", errors.New("RelaySAS is required")
}
if opts.RelayEndpoint == "" {
return "", errors.New("RelayEndpoint is required")
}
sas := url.QueryEscape(opts.RelaySAS)
uri := opts.RelayEndpoint
uri = strings.Replace(uri, "sb:", "wss:", -1)
uri = strings.Replace(uri, ".net/", ".net:443/$hc/", 1)
uri = uri + "?sb-hc-action=" + action + "&sb-hc-token=" + sas
return uri, nil
}
// Connect connects to a Live Share workspace specified by the
// options, and returns a session representing the connection.
// The caller must call the session's Close method to end the session.
func Connect(ctx context.Context, opts Options) (*Session, error) {
uri, err := opts.uri("connect")
if err != nil {
return nil, err
}
span, ctx := opentracing.StartSpanFromContext(ctx, "Connect")
defer span.Finish()
sock := newSocket(uri, opts.TLSConfig)
if err := sock.connect(ctx); err != nil {
return nil, fmt.Errorf("error connecting websocket: %w", err)
}
if opts.SessionToken == "" {
return nil, errors.New("SessionToken is required")
}
ssh := newSSHSession(opts.SessionToken, sock)
if err := ssh.connect(ctx); err != nil {
return nil, fmt.Errorf("error connecting to ssh session: %w", err)
}
rpc := newRPCClient(ssh)
rpc.connect(ctx)
args := joinWorkspaceArgs{
ID: opts.SessionID,
ConnectionMode: "local",
JoiningUserSessionToken: opts.SessionToken,
ClientCapabilities: clientCapabilities{
IsNonInteractive: false,
},
}
var result joinWorkspaceResult
if err := rpc.do(ctx, "workspace.joinWorkspace", &args, &result); err != nil {
return nil, fmt.Errorf("error joining Live Share workspace: %w", err)
}
return &Session{ssh: ssh, rpc: rpc}, nil
}
type clientCapabilities struct {
IsNonInteractive bool `json:"isNonInteractive"`
}
type joinWorkspaceArgs struct {
ID string `json:"id"`
ConnectionMode string `json:"connectionMode"`
JoiningUserSessionToken string `json:"joiningUserSessionToken"`
ClientCapabilities clientCapabilities `json:"clientCapabilities"`
}
type joinWorkspaceResult struct {
SessionNumber int `json:"sessionNumber"`
}
// A channelID is an identifier for an exposed port on a remote
// container that may be used to open an SSH channel to it.
type channelID struct {
name, condition string
}
func (s *Session) openStreamingChannel(ctx context.Context, id channelID) (ssh.Channel, error) {
type getStreamArgs struct {
StreamName string `json:"streamName"`
Condition string `json:"condition"`
}
args := getStreamArgs{
StreamName: id.name,
Condition: id.condition,
}
var streamID string
if err := s.rpc.do(ctx, "streamManager.getStream", args, &streamID); err != nil {
return nil, fmt.Errorf("error getting stream id: %w", err)
}
span, ctx := opentracing.StartSpanFromContext(ctx, "Session.OpenChannel+SendRequest")
defer span.Finish()
_ = ctx // ctx is not currently used
channel, reqs, err := s.ssh.conn.OpenChannel("session", nil)
if err != nil {
return nil, fmt.Errorf("error opening ssh channel for transport: %w", err)
}
go ssh.DiscardRequests(reqs)
requestType := fmt.Sprintf("stream-transport-%s", streamID)
if _, err = channel.SendRequest(requestType, true, nil); err != nil {
return nil, fmt.Errorf("error sending channel request: %w", err)
}
return channel, nil
}

View file

@ -0,0 +1,71 @@
package liveshare
import (
"context"
"crypto/tls"
"encoding/json"
"errors"
"fmt"
"strings"
"testing"
livesharetest "github.com/github/ghcs/internal/liveshare/test"
"github.com/sourcegraph/jsonrpc2"
)
func TestConnect(t *testing.T) {
opts := Options{
SessionID: "session-id",
SessionToken: "session-token",
RelaySAS: "relay-sas",
}
joinWorkspace := func(req *jsonrpc2.Request) (interface{}, error) {
var joinWorkspaceReq joinWorkspaceArgs
if err := json.Unmarshal(*req.Params, &joinWorkspaceReq); err != nil {
return nil, fmt.Errorf("error unmarshaling req: %w", err)
}
if joinWorkspaceReq.ID != opts.SessionID {
return nil, errors.New("connection session id does not match")
}
if joinWorkspaceReq.ConnectionMode != "local" {
return nil, errors.New("connection mode is not local")
}
if joinWorkspaceReq.JoiningUserSessionToken != opts.SessionToken {
return nil, errors.New("connection user token does not match")
}
if joinWorkspaceReq.ClientCapabilities.IsNonInteractive != false {
return nil, errors.New("non interactive is not false")
}
return joinWorkspaceResult{1}, nil
}
server, err := livesharetest.NewServer(
livesharetest.WithPassword(opts.SessionToken),
livesharetest.WithService("workspace.joinWorkspace", joinWorkspace),
livesharetest.WithRelaySAS(opts.RelaySAS),
)
if err != nil {
t.Errorf("error creating Live Share server: %w", err)
}
defer server.Close()
opts.RelayEndpoint = "sb" + strings.TrimPrefix(server.URL(), "https")
ctx := context.Background()
opts.TLSConfig = &tls.Config{InsecureSkipVerify: true}
done := make(chan error)
go func() {
_, err := Connect(ctx, opts) // ignore session
done <- err
}()
select {
case err := <-server.Err():
t.Errorf("error from server: %w", err)
case err := <-done:
if err != nil {
t.Errorf("error from client: %w", err)
}
}
}

View file

@ -0,0 +1,56 @@
package liveshare
import (
"context"
"testing"
)
func TestBadOptions(t *testing.T) {
goodOptions := Options{
SessionID: "sess-id",
SessionToken: "sess-token",
RelaySAS: "sas",
RelayEndpoint: "endpoint",
}
opts := goodOptions
opts.SessionID = ""
checkBadOptions(t, opts)
opts = goodOptions
opts.SessionToken = ""
checkBadOptions(t, opts)
opts = goodOptions
opts.RelaySAS = ""
checkBadOptions(t, opts)
opts = goodOptions
opts.RelayEndpoint = ""
checkBadOptions(t, opts)
opts = Options{}
checkBadOptions(t, opts)
}
func checkBadOptions(t *testing.T, opts Options) {
if _, err := Connect(context.Background(), opts); err == nil {
t.Errorf("Connect(%+v): no error", opts)
}
}
func TestOptionsURI(t *testing.T) {
opts := Options{
SessionID: "sess-id",
SessionToken: "sess-token",
RelaySAS: "sas",
RelayEndpoint: "sb://endpoint/.net/liveshare",
}
uri, err := opts.uri("connect")
if err != nil {
t.Fatal(err)
}
if uri != "wss://endpoint/.net:443/$hc/liveshare?sb-hc-action=connect&sb-hc-token=sas" {
t.Errorf("uri is not correct, got: '%v'", uri)
}
}

View file

@ -0,0 +1,162 @@
package liveshare
import (
"context"
"fmt"
"io"
"net"
"github.com/opentracing/opentracing-go"
)
// A PortForwarder forwards TCP traffic over a Live Share session from a port on a remote
// container to a local destination such as a network port or Go reader/writer.
type PortForwarder struct {
session *Session
name string
remotePort int
}
// NewPortForwarder returns a new PortForwarder for the specified
// remote port and Live Share session. The name describes the purpose
// of the remote port or service.
func NewPortForwarder(session *Session, name string, remotePort int) *PortForwarder {
return &PortForwarder{
session: session,
name: name,
remotePort: remotePort,
}
}
// ForwardToListener forwards traffic between the container's remote
// port and a local port, which must already be listening for
// connections. (Accepting a listener rather than a port number avoids
// races against other processes opening ports, and against a client
// connecting to the socket prematurely.)
//
// ForwardToListener accepts and handles connections on the local port
// until it encounters the first error, which may include context
// cancellation. Its error result is always non-nil. The caller is
// responsible for closing the listening port.
func (fwd *PortForwarder) ForwardToListener(ctx context.Context, listen net.Listener) (err error) {
id, err := fwd.shareRemotePort(ctx)
if err != nil {
return err
}
errc := make(chan error, 1)
sendError := func(err error) {
// Use non-blocking send, to avoid goroutines getting
// stuck in case of concurrent or sequential errors.
select {
case errc <- err:
default:
}
}
go func() {
for {
conn, err := listen.Accept()
if err != nil {
sendError(err)
return
}
go func() {
if err := fwd.handleConnection(ctx, id, conn); err != nil {
sendError(err)
}
}()
}
}()
return awaitError(ctx, errc)
}
// Forward forwards traffic between the container's remote port and
// the specified read/write stream. On return, the stream is closed.
func (fwd *PortForwarder) Forward(ctx context.Context, conn io.ReadWriteCloser) error {
id, err := fwd.shareRemotePort(ctx)
if err != nil {
conn.Close()
return err
}
// Create buffered channel so that send doesn't get stuck after context cancellation.
errc := make(chan error, 1)
go func() {
errc <- fwd.handleConnection(ctx, id, conn)
}()
return awaitError(ctx, errc)
}
func (fwd *PortForwarder) shareRemotePort(ctx context.Context) (channelID, error) {
id, err := fwd.session.startSharing(ctx, fwd.name, fwd.remotePort)
if err != nil {
err = fmt.Errorf("failed to share remote port %d: %w", fwd.remotePort, err)
}
return id, err
}
func awaitError(ctx context.Context, errc <-chan error) error {
select {
case err := <-errc:
return err
case <-ctx.Done():
return ctx.Err() // canceled
}
}
// handleConnection handles forwarding for a single accepted connection, then closes it.
func (fwd *PortForwarder) handleConnection(ctx context.Context, id channelID, conn io.ReadWriteCloser) (err error) {
span, ctx := opentracing.StartSpanFromContext(ctx, "PortForwarder.handleConnection")
defer span.Finish()
defer safeClose(conn, &err)
channel, err := fwd.session.openStreamingChannel(ctx, id)
if err != nil {
return fmt.Errorf("error opening streaming channel for new connection: %w", err)
}
// Ideally we would call safeClose again, but (*ssh.channel).Close
// appears to have a bug that causes it return io.EOF spuriously
// if its peer closed first; see github.com/golang/go/issues/38115.
defer func() {
closeErr := channel.Close()
if err == nil && closeErr != io.EOF {
err = closeErr
}
}()
// bi-directional copy of data.
errs := make(chan error, 2)
copyConn := func(w io.Writer, r io.Reader) {
_, err := io.Copy(w, r)
errs <- err
}
go copyConn(conn, channel)
go copyConn(channel, conn)
// Wait until context is cancelled or both copies are done.
// Discard errors from io.Copy; they should not cause (e.g.) ForwardToListener to fail.
// TODO: how can we proxy errors from Copy so that each peer can distinguish an error from a short file?
for i := 0; ; {
select {
case <-ctx.Done():
return ctx.Err()
case <-errs:
i++
if i == 2 {
return nil
}
}
}
}
// safeClose reports the error (to *err) from closing the stream only
// if no other error was previously reported.
func safeClose(closer io.Closer, err *error) {
closeErr := closer.Close()
if *err == nil {
*err = closeErr
}
}

View file

@ -0,0 +1,95 @@
package liveshare
import (
"bytes"
"context"
"errors"
"fmt"
"io"
"net"
"testing"
"time"
livesharetest "github.com/github/ghcs/internal/liveshare/test"
"github.com/sourcegraph/jsonrpc2"
)
func TestNewPortForwarder(t *testing.T) {
testServer, session, err := makeMockSession()
if err != nil {
t.Errorf("create mock client: %w", err)
}
defer testServer.Close()
pf := NewPortForwarder(session, "ssh", 80)
if pf == nil {
t.Error("port forwarder is nil")
}
}
func TestPortForwarderStart(t *testing.T) {
streamName, streamCondition := "stream-name", "stream-condition"
serverSharing := func(req *jsonrpc2.Request) (interface{}, error) {
return Port{StreamName: streamName, StreamCondition: streamCondition}, nil
}
getStream := func(req *jsonrpc2.Request) (interface{}, error) {
return "stream-id", nil
}
stream := bytes.NewBufferString("stream-data")
testServer, session, err := makeMockSession(
livesharetest.WithService("serverSharing.startSharing", serverSharing),
livesharetest.WithService("streamManager.getStream", getStream),
livesharetest.WithStream("stream-id", stream),
)
if err != nil {
t.Errorf("create mock session: %w", err)
}
defer testServer.Close()
listen, err := net.Listen("tcp", ":8000")
if err != nil {
t.Fatal(err)
}
defer listen.Close()
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
done := make(chan error)
go func() {
const name, remote = "ssh", 8000
done <- NewPortForwarder(session, name, remote).ForwardToListener(ctx, listen)
}()
go func() {
var conn net.Conn
retries := 0
for conn == nil && retries < 2 {
conn, err = net.DialTimeout("tcp", ":8000", 2*time.Second)
time.Sleep(1 * time.Second)
}
if conn == nil {
done <- errors.New("failed to connect to forwarded port")
}
b := make([]byte, len("stream-data"))
if _, err := conn.Read(b); err != nil && err != io.EOF {
done <- fmt.Errorf("reading stream: %w", err)
}
if string(b) != "stream-data" {
done <- fmt.Errorf("stream data is not expected value, got: %s", string(b))
}
if _, err := conn.Write([]byte("new-data")); err != nil {
done <- fmt.Errorf("writing to stream: %w", err)
}
done <- nil
}()
select {
case err := <-testServer.Err():
t.Errorf("error from server: %w", err)
case err := <-done:
if err != nil {
t.Errorf("error from client: %w", err)
}
}
}

41
internal/liveshare/rpc.go Normal file
View file

@ -0,0 +1,41 @@
package liveshare
import (
"context"
"fmt"
"io"
"github.com/opentracing/opentracing-go"
"github.com/sourcegraph/jsonrpc2"
)
type rpcClient struct {
*jsonrpc2.Conn
conn io.ReadWriteCloser
}
func newRPCClient(conn io.ReadWriteCloser) *rpcClient {
return &rpcClient{conn: conn}
}
func (r *rpcClient) connect(ctx context.Context) {
stream := jsonrpc2.NewBufferedStream(r.conn, jsonrpc2.VSCodeObjectCodec{})
r.Conn = jsonrpc2.NewConn(ctx, stream, nullHandler{})
}
func (r *rpcClient) do(ctx context.Context, method string, args, result interface{}) error {
span, ctx := opentracing.StartSpanFromContext(ctx, method)
defer span.Finish()
waiter, err := r.Conn.DispatchCall(ctx, method, args)
if err != nil {
return fmt.Errorf("error dispatching %q call: %w", method, err)
}
return waiter.Wait(ctx, result)
}
type nullHandler struct{}
func (nullHandler) Handle(ctx context.Context, conn *jsonrpc2.Conn, req *jsonrpc2.Request) {
}

View file

@ -0,0 +1,99 @@
package liveshare
import (
"context"
"fmt"
"strconv"
)
// A Session represents the session between a connected Live Share client and server.
type Session struct {
ssh *sshSession
rpc *rpcClient
}
// Close should be called by users to clean up RPC and SSH resources whenever the session
// is no longer active.
func (s *Session) Close() error {
// Closing the RPC conn closes the underlying stream (SSH)
// So we only need to close once
if err := s.rpc.Close(); err != nil {
s.ssh.Close() // close SSH and ignore error
return fmt.Errorf("error while closing Live Share session: %w", err)
}
return nil
}
// Port describes a port exposed by the container.
type Port struct {
SourcePort int `json:"sourcePort"`
DestinationPort int `json:"destinationPort"`
SessionName string `json:"sessionName"`
StreamName string `json:"streamName"`
StreamCondition string `json:"streamCondition"`
BrowseURL string `json:"browseUrl"`
IsPublic bool `json:"isPublic"`
IsTCPServerConnectionEstablished bool `json:"isTCPServerConnectionEstablished"`
HasTLSHandshakePassed bool `json:"hasTLSHandshakePassed"`
}
// startSharing tells the Live Share host to start sharing the specified port from the container.
// The sessionName describes the purpose of the remote port or service.
// It returns an identifier that can be used to open an SSH channel to the remote port.
func (s *Session) startSharing(ctx context.Context, sessionName string, port int) (channelID, error) {
args := []interface{}{port, sessionName, fmt.Sprintf("http://localhost:%d", port)}
var response Port
if err := s.rpc.do(ctx, "serverSharing.startSharing", args, &response); err != nil {
return channelID{}, err
}
return channelID{response.StreamName, response.StreamCondition}, nil
}
// GetSharedServers returns a description of each container port
// shared by a prior call to StartSharing by some client.
func (s *Session) GetSharedServers(ctx context.Context) ([]*Port, error) {
var response []*Port
if err := s.rpc.do(ctx, "serverSharing.getSharedServers", []string{}, &response); err != nil {
return nil, err
}
return response, nil
}
// UpdateSharedVisibility controls port permissions and whether it can be accessed publicly
// via the Browse URL
func (s *Session) UpdateSharedVisibility(ctx context.Context, port int, public bool) error {
if err := s.rpc.do(ctx, "serverSharing.updateSharedServerVisibility", []interface{}{port, public}, nil); err != nil {
return err
}
return nil
}
// StartsSSHServer starts an SSH server in the container, installing sshd if necessary,
// and returns the port on which it listens and the user name clients should provide.
func (s *Session) StartSSHServer(ctx context.Context) (int, string, error) {
var response struct {
Result bool `json:"result"`
ServerPort string `json:"serverPort"`
User string `json:"user"`
Message string `json:"message"`
}
if err := s.rpc.do(ctx, "ISshServerHostService.startRemoteServer", []string{}, &response); err != nil {
return 0, "", err
}
if !response.Result {
return 0, "", fmt.Errorf("failed to start server: %s", response.Message)
}
port, err := strconv.Atoi(response.ServerPort)
if err != nil {
return 0, "", fmt.Errorf("failed to parse port: %w", err)
}
return port, response.User, nil
}

View file

@ -0,0 +1,196 @@
package liveshare
import (
"context"
"crypto/tls"
"encoding/json"
"errors"
"fmt"
"strings"
"testing"
livesharetest "github.com/github/ghcs/internal/liveshare/test"
"github.com/sourcegraph/jsonrpc2"
)
func makeMockSession(opts ...livesharetest.ServerOption) (*livesharetest.Server, *Session, error) {
joinWorkspace := func(req *jsonrpc2.Request) (interface{}, error) {
return joinWorkspaceResult{1}, nil
}
const sessionToken = "session-token"
opts = append(
opts,
livesharetest.WithPassword(sessionToken),
livesharetest.WithService("workspace.joinWorkspace", joinWorkspace),
)
testServer, err := livesharetest.NewServer(opts...)
if err != nil {
return nil, nil, fmt.Errorf("error creating server: %w", err)
}
session, err := Connect(context.Background(), Options{
SessionID: "session-id",
SessionToken: sessionToken,
RelayEndpoint: "sb" + strings.TrimPrefix(testServer.URL(), "https"),
RelaySAS: "relay-sas",
TLSConfig: &tls.Config{InsecureSkipVerify: true},
})
if err != nil {
return nil, nil, fmt.Errorf("error connecting to Live Share: %w", err)
}
return testServer, session, nil
}
func TestServerStartSharing(t *testing.T) {
serverPort, serverProtocol := 2222, "sshd"
startSharing := func(req *jsonrpc2.Request) (interface{}, error) {
var args []interface{}
if err := json.Unmarshal(*req.Params, &args); err != nil {
return nil, fmt.Errorf("error unmarshaling request: %w", err)
}
if len(args) < 3 {
return nil, errors.New("not enough arguments to start sharing")
}
if port, ok := args[0].(float64); !ok {
return nil, errors.New("port argument is not an int")
} else if port != float64(serverPort) {
return nil, errors.New("port does not match serverPort")
}
if protocol, ok := args[1].(string); !ok {
return nil, errors.New("protocol argument is not a string")
} else if protocol != serverProtocol {
return nil, errors.New("protocol does not match serverProtocol")
}
if browseURL, ok := args[2].(string); !ok {
return nil, errors.New("browse url is not a string")
} else if browseURL != fmt.Sprintf("http://localhost:%d", serverPort) {
return nil, errors.New("browseURL does not match expected")
}
return Port{StreamName: "stream-name", StreamCondition: "stream-condition"}, nil
}
testServer, session, err := makeMockSession(
livesharetest.WithService("serverSharing.startSharing", startSharing),
)
defer testServer.Close() //nolint:staticcheck // httptest.Server does not return errors on Close()
if err != nil {
t.Errorf("error creating mock session: %w", err)
}
ctx := context.Background()
done := make(chan error)
go func() {
streamID, err := session.startSharing(ctx, serverProtocol, serverPort)
if err != nil {
done <- fmt.Errorf("error sharing server: %w", err)
}
if streamID.name == "" || streamID.condition == "" {
done <- errors.New("stream name or condition is blank")
}
done <- nil
}()
select {
case err := <-testServer.Err():
t.Errorf("error from server: %w", err)
case err := <-done:
if err != nil {
t.Errorf("error from client: %w", err)
}
}
}
func TestServerGetSharedServers(t *testing.T) {
sharedServer := Port{
SourcePort: 2222,
StreamName: "stream-name",
StreamCondition: "stream-condition",
}
getSharedServers := func(req *jsonrpc2.Request) (interface{}, error) {
return []*Port{&sharedServer}, nil
}
testServer, session, err := makeMockSession(
livesharetest.WithService("serverSharing.getSharedServers", getSharedServers),
)
if err != nil {
t.Errorf("error creating mock session: %w", err)
}
defer testServer.Close()
ctx := context.Background()
done := make(chan error)
go func() {
ports, err := session.GetSharedServers(ctx)
if err != nil {
done <- fmt.Errorf("error getting shared servers: %w", err)
}
if len(ports) < 1 {
done <- errors.New("not enough ports returned")
}
if ports[0].SourcePort != sharedServer.SourcePort {
done <- errors.New("source port does not match")
}
if ports[0].StreamName != sharedServer.StreamName {
done <- errors.New("stream name does not match")
}
if ports[0].StreamCondition != sharedServer.StreamCondition {
done <- errors.New("stream condiion does not match")
}
done <- nil
}()
select {
case err := <-testServer.Err():
t.Errorf("error from server: %w", err)
case err := <-done:
if err != nil {
t.Errorf("error from client: %w", err)
}
}
}
func TestServerUpdateSharedVisibility(t *testing.T) {
updateSharedVisibility := func(rpcReq *jsonrpc2.Request) (interface{}, error) {
var req []interface{}
if err := json.Unmarshal(*rpcReq.Params, &req); err != nil {
return nil, fmt.Errorf("unmarshal req: %w", err)
}
if len(req) < 2 {
return nil, errors.New("request arguments is less than 2")
}
if port, ok := req[0].(float64); ok {
if port != 80.0 {
return nil, errors.New("port param is not expected value")
}
} else {
return nil, errors.New("port param is not a float64")
}
if public, ok := req[1].(bool); ok {
if public != true {
return nil, errors.New("pulic param is not expected value")
}
} else {
return nil, errors.New("public param is not a bool")
}
return nil, nil
}
testServer, session, err := makeMockSession(
livesharetest.WithService("serverSharing.updateSharedServerVisibility", updateSharedVisibility),
)
if err != nil {
t.Errorf("creating mock session: %w", err)
}
defer testServer.Close()
ctx := context.Background()
done := make(chan error)
go func() {
done <- session.UpdateSharedVisibility(ctx, 80, true)
}()
select {
case err := <-testServer.Err():
t.Errorf("error from server: %w", err)
case err := <-done:
if err != nil {
t.Errorf("error from client: %w", err)
}
}
}

View file

@ -0,0 +1,100 @@
package liveshare
import (
"context"
"crypto/tls"
"io"
"net"
"net/http"
"time"
"github.com/gorilla/websocket"
)
type socket struct {
addr string
tlsConfig *tls.Config
conn *websocket.Conn
reader io.Reader
}
func newSocket(uri string, tlsConfig *tls.Config) *socket {
return &socket{addr: uri, tlsConfig: tlsConfig}
}
func (s *socket) connect(ctx context.Context) error {
dialer := websocket.Dialer{
Proxy: http.ProxyFromEnvironment,
HandshakeTimeout: 45 * time.Second,
TLSClientConfig: s.tlsConfig,
}
ws, _, err := dialer.Dial(s.addr, nil)
if err != nil {
return err
}
s.conn = ws
return nil
}
func (s *socket) Read(b []byte) (int, error) {
if s.reader == nil {
_, reader, err := s.conn.NextReader()
if err != nil {
return 0, err
}
s.reader = reader
}
bytesRead, err := s.reader.Read(b)
if err != nil {
s.reader = nil
if err == io.EOF {
err = nil
}
}
return bytesRead, err
}
func (s *socket) Write(b []byte) (int, error) {
nextWriter, err := s.conn.NextWriter(websocket.BinaryMessage)
if err != nil {
return 0, err
}
bytesWritten, err := nextWriter.Write(b)
nextWriter.Close()
return bytesWritten, err
}
func (s *socket) Close() error {
return s.conn.Close()
}
func (s *socket) LocalAddr() net.Addr {
return s.conn.LocalAddr()
}
func (s *socket) RemoteAddr() net.Addr {
return s.conn.RemoteAddr()
}
func (s *socket) SetDeadline(t time.Time) error {
if err := s.SetReadDeadline(t); err != nil {
return err
}
return s.SetWriteDeadline(t)
}
func (s *socket) SetReadDeadline(t time.Time) error {
return s.conn.SetReadDeadline(t)
}
func (s *socket) SetWriteDeadline(t time.Time) error {
return s.conn.SetWriteDeadline(t)
}

68
internal/liveshare/ssh.go Normal file
View file

@ -0,0 +1,68 @@
package liveshare
import (
"context"
"fmt"
"io"
"net"
"time"
"golang.org/x/crypto/ssh"
)
type sshSession struct {
*ssh.Session
token string
socket net.Conn
conn ssh.Conn
reader io.Reader
writer io.Writer
}
func newSSHSession(token string, socket net.Conn) *sshSession {
return &sshSession{token: token, socket: socket}
}
func (s *sshSession) connect(ctx context.Context) error {
clientConfig := ssh.ClientConfig{
User: "",
Auth: []ssh.AuthMethod{
ssh.Password(s.token),
},
HostKeyAlgorithms: []string{"rsa-sha2-512", "rsa-sha2-256"},
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
Timeout: 10 * time.Second,
}
sshClientConn, chans, reqs, err := ssh.NewClientConn(s.socket, "", &clientConfig)
if err != nil {
return fmt.Errorf("error creating ssh client connection: %w", err)
}
s.conn = sshClientConn
sshClient := ssh.NewClient(sshClientConn, chans, reqs)
s.Session, err = sshClient.NewSession()
if err != nil {
return fmt.Errorf("error creating ssh client session: %w", err)
}
s.reader, err = s.Session.StdoutPipe()
if err != nil {
return fmt.Errorf("error creating ssh session reader: %w", err)
}
s.writer, err = s.Session.StdinPipe()
if err != nil {
return fmt.Errorf("error creating ssh session writer: %w", err)
}
return nil
}
func (s *sshSession) Read(p []byte) (n int, err error) {
return s.reader.Read(p)
}
func (s *sshSession) Write(p []byte) (n int, err error) {
return s.writer.Write(p)
}

View file

@ -0,0 +1,245 @@
package livesharetest
import (
"context"
"errors"
"fmt"
"io"
"net/http"
"net/http/httptest"
"strings"
"github.com/gorilla/websocket"
"github.com/sourcegraph/jsonrpc2"
"golang.org/x/crypto/ssh"
)
const sshPrivateKey = `-----BEGIN RSA PRIVATE KEY-----
MIIEogIBAAKCAQEAp/Jmzy/HaPNx5Bug09FX5Q/KGY4G9c4DfplhWrn31OQCqNiT
ZSLd46rdXC75liHzE7e5Ic0RJN61cYN9SNArjvEXx2vvs7szhwO7LonwPOvpYpUf
daayrgbr6S46plpx+hEZ1kO/6BqMgFuvnkIVThrEyx5b48ll8zgDABsYrKF8/p1V
SjGfb+bLwjn1NtnZF2prBG5P4ZtMR06HaPglLqBJhmc0ZMG5IZGUE7ew/VrPDqdC
f1v4XvvGiU4BLoKYy4QOhyrCGh9Uk/9u0Ea56M2bh4RqwhbpR8m7TYJZ0DVMLbGW
8C+4lCWp+xRyBNxAQh8qeQVCxYl02hPE4bXLGQIDAQABAoIBAEoVPk6UZ+UexhV2
LnphNOFhFqgxI1bYWmhE5lHsCKuLLLUoW9RYDgL4gw6/1e7o6N3AxFRpre9Soj0B
YIl28k/qf6/DKAhjQnaDKdV8mVF2Swvmdesi7lyfxv6kGtD4wqApXPlMB2IuG94f
E5e+1MEQQ9DJgoU3eNZR1dj9GuRC3PyzPcNNJ2R/MMGFw3sOOVcLOgAukotoicuL
0SiL51rHPQu8a5/darH9EltN1GFeceJSDDhgqMP5T8Tp7g/c3//H6szon4H9W+uN
Z3UrImJ+teJjFOaVDqN93+J2eQSUk0lCPGQCd4U9I4AGDGyU6ucdcLQ58Aha9gmU
uQwkfKUCgYEA0UkuPOSDE9dbXe+yhsbOwMb1kKzJYgFDKjRTSP7D9BOMZu4YyASo
J95R4DWjePlDopafG2tNJoWX+CwUl7Uld1R3Ex6xHBa2B7hwZj860GZtr7D4mdWc
DTVjczAjp4P0K1MIFYQui1mVJterkjKuePiI6q/27L1c2jIa/39BWBcCgYEAzW8R
MFZamVw3eA2JYSpBuqhQgE5gX5IWrmVJZSUhpAQTNG/A4nxf7WGtjy9p99tm0RMb
ld05+sOmNLrzw8Pq8SBpFOd+MAca7lPLS1A2CoaAHbOqRqrzVcZ4EZ2jB3WjoLoq
yctwslGb9KmrhBCdcwT48aPAYUIJCZdqEen2xE8CgYBoMowvywGrvjwCH9X9njvP
5P7cAfrdrY04FQcmP5lmCtmLYZ267/6couaWv33dPBU9fMpIh3rI5BiOebvi8FBw
AgCq50v8lR4Z5+0mKvLoUSbpIy4SwTRJqzwRXHVT8LF/ZH6Q39egj4Bf716/kjYl
im/4kJVatsjk5a9lZ4EsDwKBgERkJ3rKJNtNggHrr8KzSLKVekdc0GTAw+BHRAny
NKLf4Gzij3pXIbBrhlZW2JZ1amNMUzCvN7AuFlUTsDeKL9saiSE2eCIRG3wgVVu7
VmJmqJw6xgNEwkHaEvr6Wd4P4euOTtRjcB9NX/gxzDHpPiGelCoN8+vtCgkxaVSR
aV+tAoGAO4HtLOfBAVDNbVXa27aJAjQSUq8qfkwUNJNz+rwgpVQahfiVkyqAPCQM
IfRJxKWb0Wbt9ojw3AowK/k0d3LZA7FS41JSiiGKIllSGb+i7JKqKW7RHLA3VJ/E
Bq5TLNIbUzPVNVwRcGjUYpOhKU6EIw8phTJOvxnUC+g6MVqBP8U=
-----END RSA PRIVATE KEY-----`
type Server struct {
password string
services map[string]RPCHandleFunc
relaySAS string
streams map[string]io.ReadWriter
sshConfig *ssh.ServerConfig
httptestServer *httptest.Server
errCh chan error
}
func NewServer(opts ...ServerOption) (*Server, error) {
server := new(Server)
for _, o := range opts {
if err := o(server); err != nil {
return nil, err
}
}
server.sshConfig = &ssh.ServerConfig{
PasswordCallback: sshPasswordCallback(server.password),
}
privateKey, err := ssh.ParsePrivateKey([]byte(sshPrivateKey))
if err != nil {
return nil, fmt.Errorf("error parsing key: %w", err)
}
server.sshConfig.AddHostKey(privateKey)
server.errCh = make(chan error)
server.httptestServer = httptest.NewTLSServer(http.HandlerFunc(makeConnection(server)))
return server, nil
}
type ServerOption func(*Server) error
func WithPassword(password string) ServerOption {
return func(s *Server) error {
s.password = password
return nil
}
}
func WithService(serviceName string, handler RPCHandleFunc) ServerOption {
return func(s *Server) error {
if s.services == nil {
s.services = make(map[string]RPCHandleFunc)
}
s.services[serviceName] = handler
return nil
}
}
func WithRelaySAS(sas string) ServerOption {
return func(s *Server) error {
s.relaySAS = sas
return nil
}
}
func WithStream(name string, stream io.ReadWriter) ServerOption {
return func(s *Server) error {
if s.streams == nil {
s.streams = make(map[string]io.ReadWriter)
}
s.streams[name] = stream
return nil
}
}
func sshPasswordCallback(serverPassword string) func(ssh.ConnMetadata, []byte) (*ssh.Permissions, error) {
return func(conn ssh.ConnMetadata, password []byte) (*ssh.Permissions, error) {
if string(password) == serverPassword {
return nil, nil
}
return nil, errors.New("password rejected")
}
}
func (s *Server) Close() {
s.httptestServer.Close()
}
func (s *Server) URL() string {
return s.httptestServer.URL
}
func (s *Server) Err() <-chan error {
return s.errCh
}
var upgrader = websocket.Upgrader{}
func makeConnection(server *Server) http.HandlerFunc {
return func(w http.ResponseWriter, req *http.Request) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
if server.relaySAS != "" {
// validate the sas key
sasParam := req.URL.Query().Get("sb-hc-token")
if sasParam != server.relaySAS {
server.errCh <- errors.New("error validating sas")
return
}
}
c, err := upgrader.Upgrade(w, req, nil)
if err != nil {
server.errCh <- fmt.Errorf("error upgrading connection: %w", err)
return
}
defer c.Close()
socketConn := newSocketConn(c)
_, chans, reqs, err := ssh.NewServerConn(socketConn, server.sshConfig)
if err != nil {
server.errCh <- fmt.Errorf("error creating new ssh conn: %w", err)
return
}
go ssh.DiscardRequests(reqs)
for newChannel := range chans {
ch, reqs, err := newChannel.Accept()
if err != nil {
server.errCh <- fmt.Errorf("error accepting new channel: %w", err)
return
}
go handleNewRequests(ctx, server, ch, reqs)
go handleNewChannel(server, ch)
}
}
}
func handleNewRequests(ctx context.Context, server *Server, channel ssh.Channel, reqs <-chan *ssh.Request) {
for req := range reqs {
if req.WantReply {
if err := req.Reply(true, nil); err != nil {
server.errCh <- fmt.Errorf("error replying to channel request: %w", err)
}
}
if strings.HasPrefix(req.Type, "stream-transport") {
forwardStream(ctx, server, req.Type, channel)
}
}
}
func forwardStream(ctx context.Context, server *Server, streamName string, channel ssh.Channel) {
simpleStreamName := strings.TrimPrefix(streamName, "stream-transport-")
stream, found := server.streams[simpleStreamName]
if !found {
server.errCh <- fmt.Errorf("stream '%s' not found", simpleStreamName)
return
}
copy := func(dst io.Writer, src io.Reader) {
if _, err := io.Copy(dst, src); err != nil {
fmt.Println(err)
server.errCh <- fmt.Errorf("io copy: %w", err)
return
}
}
go copy(stream, channel)
go copy(channel, stream)
<-ctx.Done() // TODO(josebalius): improve this
}
func handleNewChannel(server *Server, channel ssh.Channel) {
stream := jsonrpc2.NewBufferedStream(channel, jsonrpc2.VSCodeObjectCodec{})
jsonrpc2.NewConn(context.Background(), stream, newRPCHandler(server))
}
type RPCHandleFunc func(req *jsonrpc2.Request) (interface{}, error)
type rpcHandler struct {
server *Server
}
func newRPCHandler(server *Server) *rpcHandler {
return &rpcHandler{server}
}
func (r *rpcHandler) Handle(ctx context.Context, conn *jsonrpc2.Conn, req *jsonrpc2.Request) {
handler, found := r.server.services[req.Method]
if !found {
r.server.errCh <- fmt.Errorf("RPC Method: '%s' not serviced", req.Method)
return
}
result, err := handler(req)
if err != nil {
r.server.errCh <- fmt.Errorf("error handling: '%s': %w", req.Method, err)
return
}
if err := conn.Reply(ctx, req.ID, result); err != nil {
r.server.errCh <- fmt.Errorf("error replying: %w", err)
}
}

View file

@ -0,0 +1,77 @@
package livesharetest
import (
"fmt"
"io"
"sync"
"time"
"github.com/gorilla/websocket"
)
type socketConn struct {
*websocket.Conn
reader io.Reader
writeMutex sync.Mutex
readMutex sync.Mutex
}
func newSocketConn(conn *websocket.Conn) *socketConn {
return &socketConn{Conn: conn}
}
func (s *socketConn) Read(b []byte) (int, error) {
s.readMutex.Lock()
defer s.readMutex.Unlock()
if s.reader == nil {
msgType, r, err := s.Conn.NextReader()
if err != nil {
return 0, fmt.Errorf("error getting next reader: %w", err)
}
if msgType != websocket.BinaryMessage {
return 0, fmt.Errorf("invalid message type")
}
s.reader = r
}
bytesRead, err := s.reader.Read(b)
if err != nil {
s.reader = nil
if err == io.EOF {
err = nil
}
}
return bytesRead, err
}
func (s *socketConn) Write(b []byte) (int, error) {
s.writeMutex.Lock()
defer s.writeMutex.Unlock()
w, err := s.Conn.NextWriter(websocket.BinaryMessage)
if err != nil {
return 0, fmt.Errorf("error getting next writer: %w", err)
}
n, err := w.Write(b)
if err != nil {
return 0, fmt.Errorf("error writing: %w", err)
}
if err := w.Close(); err != nil {
return 0, fmt.Errorf("error closing writer: %w", err)
}
return n, nil
}
func (s *socketConn) SetDeadline(deadline time.Time) error {
if err := s.Conn.SetReadDeadline(deadline); err != nil {
return err
}
return s.Conn.SetWriteDeadline(deadline)
}