Merge pull request #1 from github/jg/refactor

Refactors most of the library to solidify some of the implementation with tests
This commit is contained in:
Jose Garcia 2021-07-27 19:23:37 -04:00 committed by GitHub
commit 39fe550aeb
17 changed files with 1019 additions and 443 deletions

130
api.go
View file

@ -1,130 +0,0 @@
package liveshare
import (
"encoding/json"
"fmt"
"io/ioutil"
"net/http"
"strings"
)
type api struct {
client *Client
httpClient *http.Client
serviceURI string
workspaceID string
}
func newAPI(client *Client) *api {
serviceURI := client.liveShare.Configuration.LiveShareEndpoint
if !strings.HasSuffix(client.liveShare.Configuration.LiveShareEndpoint, "/") {
serviceURI = client.liveShare.Configuration.LiveShareEndpoint + "/"
}
if !strings.Contains(serviceURI, "api/v1.2") {
serviceURI = serviceURI + "api/v1.2"
}
serviceURI = strings.TrimSuffix(serviceURI, "/")
return &api{client, &http.Client{}, serviceURI, strings.ToUpper(client.liveShare.Configuration.WorkspaceID)}
}
type workspaceAccessResponse struct {
SessionToken string `json:"sessionToken"`
CreatedAt string `json:"createdAt"`
UpdatedAt string `json:"updatedAt"`
Name string `json:"name"`
OwnerID string `json:"ownerId"`
JoinLink string `json:"joinLink"`
ConnectLinks []string `json:"connectLinks"`
RelayLink string `json:"relayLink"`
RelaySas string `json:"relaySas"`
HostPublicKeys []string `json:"hostPublicKeys"`
ConversationID string `json:"conversationId"`
AssociatedUserIDs map[string]string `json:"associatedUserIds"`
AreAnonymousGuestsAllowed bool `json:"areAnonymousGuestsAllowed"`
IsHostConnected bool `json:"isHostConnected"`
ExpiresAt string `json:"expiresAt"`
InvitationLinks []string `json:"invitationLinks"`
ID string `json:"id"`
}
func (a *api) workspaceAccess() (*workspaceAccessResponse, error) {
url := fmt.Sprintf("%s/workspace/%s/user", a.serviceURI, a.workspaceID)
req, err := http.NewRequest(http.MethodPut, url, nil)
if err != nil {
return nil, fmt.Errorf("error creating request: %v", err)
}
a.setDefaultHeaders(req)
resp, err := a.httpClient.Do(req)
if err != nil {
return nil, fmt.Errorf("error making request: %v", err)
}
b, err := ioutil.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("error reading response body: %v", err)
}
var response workspaceAccessResponse
if err := json.Unmarshal(b, &response); err != nil {
return nil, fmt.Errorf("error unmarshaling response into json: %v", err)
}
return &response, nil
}
func (a *api) setDefaultHeaders(req *http.Request) {
req.Header.Set("Authorization", "Bearer "+a.client.liveShare.Configuration.Token)
req.Header.Set("Cache-Control", "no-cache")
req.Header.Set("Content-Type", "application/json")
}
type workspaceInfoResponse struct {
CreatedAt string `json:"createdAt"`
UpdatedAt string `json:"updatedAt"`
Name string `json:"name"`
OwnerID string `json:"ownerId"`
JoinLink string `json:"joinLink"`
ConnectLinks []string `json:"connectLinks"`
RelayLink string `json:"relayLink"`
RelaySas string `json:"relaySas"`
HostPublicKeys []string `json:"hostPublicKeys"`
ConversationID string `json:"conversationId"`
AssociatedUserIDs map[string]string
AreAnonymousGuestsAllowed bool `json:"areAnonymousGuestsAllowed"`
IsHostConnected bool `json:"isHostConnected"`
ExpiresAt string `json:"expiresAt"`
InvitationLinks []string `json:"invitationLinks"`
ID string `json:"id"`
}
func (a *api) workspaceInfo() (*workspaceInfoResponse, error) {
url := fmt.Sprintf("%s/workspace/%s", a.serviceURI, a.workspaceID)
req, err := http.NewRequest(http.MethodGet, url, nil)
if err != nil {
return nil, fmt.Errorf("error creating request: %v", err)
}
a.setDefaultHeaders(req)
resp, err := a.httpClient.Do(req)
if err != nil {
return nil, fmt.Errorf("error making request: %v", err)
}
b, err := ioutil.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("error reading response body: %v", err)
}
var response workspaceInfoResponse
if err := json.Unmarshal(b, &response); err != nil {
return nil, fmt.Errorf("error unmarshaling response into json: %v", err)
}
return &response, nil
}

View file

@ -2,42 +2,69 @@ package liveshare
import (
"context"
"crypto/tls"
"fmt"
"golang.org/x/crypto/ssh"
)
// A Client capable of joining a liveshare connection
type Client struct {
liveShare *LiveShare
session *session
sshSession *sshSession
rpc *rpc
connection Connection
tlsConfig *tls.Config
ssh *sshSession
rpc *rpcClient
}
// NewClient is a function ...
func (l *LiveShare) NewClient() *Client {
return &Client{liveShare: l}
}
// A ClientOption is a function that modifies a client
type ClientOption func(*Client) error
func (c *Client) Join(ctx context.Context) (err error) {
api := newAPI(c)
// NewClient accepts a range of options, applies them and returns a client
func NewClient(opts ...ClientOption) (*Client, error) {
client := new(Client)
c.session = newSession(api)
if err := c.session.init(ctx); err != nil {
return fmt.Errorf("error creating session: %v", err)
for _, o := range opts {
if err := o(client); err != nil {
return nil, err
}
}
websocket := newWebsocket(c.session)
if err := websocket.connect(ctx); err != nil {
return client, nil
}
// WithConnection is a ClientOption that accepts a Connection
func WithConnection(connection Connection) ClientOption {
return func(c *Client) error {
if err := connection.validate(); err != nil {
return err
}
c.connection = connection
return nil
}
}
func WithTLSConfig(tlsConfig *tls.Config) ClientOption {
return func(c *Client) error {
c.tlsConfig = tlsConfig
return nil
}
}
// Join is a method that joins the client to the liveshare session
func (c *Client) Join(ctx context.Context) (err error) {
clientSocket := newSocket(c.connection, c.tlsConfig)
if err := clientSocket.connect(ctx); err != nil {
return fmt.Errorf("error connecting websocket: %v", err)
}
c.sshSession = newSSH(c.session, websocket)
if err := c.sshSession.connect(ctx); err != nil {
c.ssh = newSshSession(c.connection.SessionToken, clientSocket)
if err := c.ssh.connect(ctx); err != nil {
return fmt.Errorf("error connecting to ssh session: %v", err)
}
c.rpc = newRPC(c.sshSession)
c.rpc = newRpcClient(c.ssh)
c.rpc.connect(ctx)
_, err = c.joinWorkspace(ctx)
@ -49,7 +76,7 @@ func (c *Client) Join(ctx context.Context) (err error) {
}
func (c *Client) hasJoined() bool {
return c.sshSession != nil && c.rpc != nil
return c.ssh != nil && c.rpc != nil
}
type clientCapabilities struct {
@ -69,9 +96,9 @@ type joinWorkspaceResult struct {
func (c *Client) joinWorkspace(ctx context.Context) (*joinWorkspaceResult, error) {
args := joinWorkspaceArgs{
ID: c.session.workspaceInfo.ID,
ID: c.connection.SessionID,
ConnectionMode: "local",
JoiningUserSessionToken: c.session.workspaceAccess.SessionToken,
JoiningUserSessionToken: c.connection.SessionToken,
ClientCapabilities: clientCapabilities{
IsNonInteractive: false,
},
@ -92,15 +119,14 @@ func (c *Client) openStreamingChannel(ctx context.Context, streamName, condition
return nil, fmt.Errorf("error getting stream id: %v", err)
}
channel, reqs, err := c.sshSession.conn.OpenChannel("session", nil)
channel, reqs, err := c.ssh.conn.OpenChannel("session", nil)
if err != nil {
return nil, fmt.Errorf("error opening ssh channel for transport: %v", err)
}
go ssh.DiscardRequests(reqs)
requestType := fmt.Sprintf("stream-transport-%s", streamID)
_, err = channel.SendRequest(requestType, true, nil)
if err != nil {
if _, err = channel.SendRequest(requestType, true, nil); err != nil {
return nil, fmt.Errorf("error sending channel request: %v", err)
}

109
client_test.go Normal file
View file

@ -0,0 +1,109 @@
package liveshare
import (
"context"
"crypto/tls"
"encoding/json"
"errors"
"fmt"
"strings"
"testing"
livesharetest "github.com/github/go-liveshare/test"
"github.com/sourcegraph/jsonrpc2"
)
func TestNewClient(t *testing.T) {
client, err := NewClient()
if err != nil {
t.Errorf("error creating new client: %v", err)
}
if client == nil {
t.Error("client is nil")
}
}
func TestNewClientValidConnection(t *testing.T) {
connection := Connection{"1", "2", "3", "4"}
client, err := NewClient(WithConnection(connection))
if err != nil {
t.Errorf("error creating new client: %v", err)
}
if client == nil {
t.Error("client is nil")
}
}
func TestNewClientWithInvalidConnection(t *testing.T) {
connection := Connection{}
if _, err := NewClient(WithConnection(connection)); err == nil {
t.Error("err is nil")
}
}
func TestClientJoin(t *testing.T) {
connection := Connection{
SessionID: "session-id",
SessionToken: "session-token",
RelaySAS: "relay-sas",
}
joinWorkspace := func(req *jsonrpc2.Request) (interface{}, error) {
var joinWorkspaceReq joinWorkspaceArgs
if err := json.Unmarshal(*req.Params, &joinWorkspaceReq); err != nil {
return nil, fmt.Errorf("error unmarshaling req: %v", err)
}
if joinWorkspaceReq.ID != connection.SessionID {
return nil, errors.New("connection session id does not match")
}
if joinWorkspaceReq.ConnectionMode != "local" {
return nil, errors.New("connection mode is not local")
}
if joinWorkspaceReq.JoiningUserSessionToken != connection.SessionToken {
return nil, errors.New("connection user token does not match")
}
if joinWorkspaceReq.ClientCapabilities.IsNonInteractive != false {
return nil, errors.New("non interactive is not false")
}
return joinWorkspaceResult{1}, nil
}
server, err := livesharetest.NewServer(
livesharetest.WithPassword(connection.SessionToken),
livesharetest.WithService("workspace.joinWorkspace", joinWorkspace),
livesharetest.WithRelaySAS(connection.RelaySAS),
)
if err != nil {
t.Errorf("error creating liveshare server: %v", err)
}
defer server.Close()
connection.RelayEndpoint = "sb" + strings.TrimPrefix(server.URL(), "https")
ctx := context.Background()
tlsConfig := WithTLSConfig(&tls.Config{InsecureSkipVerify: true})
client, err := NewClient(WithConnection(connection), tlsConfig)
if err != nil {
t.Errorf("error creating new client: %v", err)
}
done := make(chan error)
go func() {
if err := client.Join(ctx); err != nil {
done <- fmt.Errorf("error joining client: %v", err)
return
}
done <- nil
}()
select {
case err := <-server.Err():
t.Errorf("error from server: %v", err)
case err := <-done:
if err != nil {
t.Errorf("error from client: %v", err)
}
}
}

44
connection.go Normal file
View file

@ -0,0 +1,44 @@
package liveshare
import (
"errors"
"net/url"
"strings"
)
// A Connection represents a set of values necessary to join a liveshare connection
type Connection struct {
SessionID string
SessionToken string
RelaySAS string
RelayEndpoint string
}
func (r Connection) validate() error {
if r.SessionID == "" {
return errors.New("connection SessionID is required")
}
if r.SessionToken == "" {
return errors.New("connection SessionToken is required")
}
if r.RelaySAS == "" {
return errors.New("connection RelaySAS is required")
}
if r.RelayEndpoint == "" {
return errors.New("connection RelayEndpoint is required")
}
return nil
}
func (r Connection) uri(action string) string {
sas := url.QueryEscape(r.RelaySAS)
uri := r.RelayEndpoint
uri = strings.Replace(uri, "sb:", "wss:", -1)
uri = strings.Replace(uri, ".net/", ".net:443/$hc/", 1)
uri = uri + "?sb-hc-action=" + action + "&sb-hc-token=" + sas
return uri
}

33
connection_test.go Normal file
View file

@ -0,0 +1,33 @@
package liveshare
import "testing"
func TestConnectionValid(t *testing.T) {
conn := Connection{"sess-id", "sess-token", "sas", "endpoint"}
if err := conn.validate(); err != nil {
t.Error(err)
}
}
func TestConnectionInvalid(t *testing.T) {
conn := Connection{"", "sess-token", "sas", "endpoint"}
if err := conn.validate(); err == nil {
t.Error(err)
}
conn = Connection{"sess-id", "", "sas", "endpoint"}
if err := conn.validate(); err == nil {
t.Error(err)
}
conn = Connection{"sess-id", "sess-token", "", "endpoint"}
if err := conn.validate(); err == nil {
t.Error(err)
}
conn = Connection{"sess-id", "sess-token", "sas", ""}
if err := conn.validate(); err == nil {
t.Error(err)
}
conn = Connection{"", "", "", ""}
if err := conn.validate(); err == nil {
t.Error(err)
}
}

View file

@ -1,77 +0,0 @@
package liveshare
import (
"errors"
"fmt"
"strings"
)
type LiveShare struct {
Configuration *Configuration
}
func New(opts ...Option) (*LiveShare, error) {
configuration := NewConfiguration()
for _, o := range opts {
if err := o(configuration); err != nil {
return nil, fmt.Errorf("error configuring liveshare: %v", err)
}
}
if err := configuration.Validate(); err != nil {
return nil, fmt.Errorf("error validating configuration: %v", err)
}
return &LiveShare{Configuration: configuration}, nil
}
type Option func(configuration *Configuration) error
func WithWorkspaceID(id string) Option {
return func(configuration *Configuration) error {
configuration.WorkspaceID = id
return nil
}
}
func WithLiveShareEndpoint(liveShareEndpoint string) Option {
return func(configuration *Configuration) error {
configuration.LiveShareEndpoint = liveShareEndpoint
return nil
}
}
func WithToken(token string) Option {
return func(configuration *Configuration) error {
configuration.Token = token
return nil
}
}
type Configuration struct {
WorkspaceID, LiveShareEndpoint, Token string
}
func NewConfiguration() *Configuration {
return &Configuration{
LiveShareEndpoint: "https://prod.liveshare.vsengsaas.visualstudio.com",
}
}
func (c *Configuration) Validate() error {
errs := []string{}
if c.WorkspaceID == "" {
errs = append(errs, "WorkspaceID is required")
}
if c.Token == "" {
errs = append(errs, "Token is required")
}
if len(errs) > 0 {
return errors.New(strings.Join(errs, ", "))
}
return nil
}

View file

@ -4,25 +4,30 @@ import (
"context"
"fmt"
"io"
"log"
"net"
"strconv"
"golang.org/x/crypto/ssh"
)
type LocalPortForwarder struct {
client *Client
server *Server
port int
channels []ssh.Channel
// A PortForwader can forward ports from a remote liveshare host to localhost
type PortForwarder struct {
client *Client
server *Server
port int
errCh chan error
}
func NewLocalPortForwarder(client *Client, server *Server, port int) *LocalPortForwarder {
return &LocalPortForwarder{client, server, port, []ssh.Channel{}}
// NewPortForwarder creates a new PortForwader with a given client, server and port
func NewPortForwarder(client *Client, server *Server, port int) *PortForwarder {
return &PortForwarder{
client: client,
server: server,
port: port,
errCh: make(chan error),
}
}
func (l *LocalPortForwarder) Start(ctx context.Context) error {
// Start is a method to start forwarding the server to a localhost port
func (l *PortForwarder) Start(ctx context.Context) error {
ln, err := net.Listen("tcp", ":"+strconv.Itoa(l.port))
if err != nil {
return fmt.Errorf("error listening on tcp port: %v", err)
@ -37,24 +42,24 @@ func (l *LocalPortForwarder) Start(ctx context.Context) error {
go l.handleConnection(ctx, conn)
}
// clean up after ourselves
return nil
}
func (l *LocalPortForwarder) handleConnection(ctx context.Context, conn net.Conn) {
func (l *PortForwarder) handleConnection(ctx context.Context, conn net.Conn) {
channel, err := l.client.openStreamingChannel(ctx, l.server.streamName, l.server.streamCondition)
if err != nil {
log.Println("errrr handle Connect")
log.Println(err) // TODO(josebalius) handle this somehow
l.errCh <- fmt.Errorf("error opening streaming channel for new connection: %v", err)
return
}
l.channels = append(l.channels, channel)
copyConn := func(writer io.Writer, reader io.Reader) {
_, err := io.Copy(writer, reader)
if err != nil {
if _, err := io.Copy(writer, reader); err != nil {
fmt.Println(err)
channel.Close()
conn.Close()
if err != io.EOF {
l.errCh <- fmt.Errorf("tunnel connection: %v", err)
}
}
}

103
port_forwarder_test.go Normal file
View file

@ -0,0 +1,103 @@
package liveshare
import (
"bytes"
"context"
"errors"
"fmt"
"io"
"net"
"testing"
"time"
livesharetest "github.com/github/go-liveshare/test"
"github.com/sourcegraph/jsonrpc2"
)
func TestNewPortForwarder(t *testing.T) {
testServer, client, err := makeMockJoinedClient()
if err != nil {
t.Errorf("create mock client: %v", err)
}
defer testServer.Close()
server, err := NewServer(client)
if err != nil {
t.Errorf("create new server: %v", err)
}
pf := NewPortForwarder(client, server, 80)
if pf == nil {
t.Error("port forwarder is nil")
}
}
func TestPortForwarderStart(t *testing.T) {
streamName, streamCondition := "stream-name", "stream-condition"
serverSharing := func(req *jsonrpc2.Request) (interface{}, error) {
return Port{StreamName: streamName, StreamCondition: streamCondition}, nil
}
getStream := func(req *jsonrpc2.Request) (interface{}, error) {
return "stream-id", nil
}
stream := bytes.NewBufferString("stream-data")
testServer, client, err := makeMockJoinedClient(
livesharetest.WithService("serverSharing.startSharing", serverSharing),
livesharetest.WithService("streamManager.getStream", getStream),
livesharetest.WithStream("stream-id", stream),
)
if err != nil {
t.Errorf("create mock client: %v", err)
}
defer testServer.Close()
server, err := NewServer(client)
if err != nil {
t.Errorf("create new server: %v", err)
}
ctx, _ := context.WithCancel(context.Background())
pf := NewPortForwarder(client, server, 8000)
done := make(chan error)
go func() {
if err := server.StartSharing(ctx, "http", 8000); err != nil {
done <- fmt.Errorf("start sharing: %v", err)
}
if err := pf.Start(ctx); err != nil {
done <- err
}
done <- nil
}()
go func() {
var conn net.Conn
retries := 0
for conn == nil && retries < 2 {
conn, err = net.DialTimeout("tcp", ":8000", 2*time.Second)
time.Sleep(1 * time.Second)
}
if conn == nil {
done <- errors.New("failed to connect to forwarded port")
}
b := make([]byte, len("stream-data"))
if _, err := conn.Read(b); err != nil && err != io.EOF {
done <- fmt.Errorf("reading stream: %v", err)
}
if string(b) != "stream-data" {
done <- fmt.Errorf("stream data is not expected value, got: %v", string(b))
}
if _, err := conn.Write([]byte("new-data")); err != nil {
done <- fmt.Errorf("writing to stream: %v", err)
}
done <- nil
}()
select {
case err := <-testServer.Err():
t.Errorf("error from server: %v", err)
case err := <-done:
if err != nil {
t.Errorf("error from client: %v", err)
}
}
}

17
rpc.go
View file

@ -9,32 +9,27 @@ import (
"github.com/sourcegraph/jsonrpc2"
)
type rpc struct {
type rpcClient struct {
*jsonrpc2.Conn
conn io.ReadWriteCloser
handler *rpcHandler
}
func newRPC(conn io.ReadWriteCloser) *rpc {
return &rpc{conn: conn, handler: newRPCHandler()}
func newRpcClient(conn io.ReadWriteCloser) *rpcClient {
return &rpcClient{conn: conn, handler: newRPCHandler()}
}
func (r *rpc) connect(ctx context.Context) {
func (r *rpcClient) connect(ctx context.Context) {
stream := jsonrpc2.NewBufferedStream(r.conn, jsonrpc2.VSCodeObjectCodec{})
r.Conn = jsonrpc2.NewConn(ctx, stream, r.handler)
}
func (r *rpc) do(ctx context.Context, method string, args interface{}, result interface{}) error {
func (r *rpcClient) do(ctx context.Context, method string, args interface{}, result interface{}) error {
waiter, err := r.Conn.DispatchCall(ctx, method, args)
if err != nil {
return fmt.Errorf("error on dispatch call: %v", err)
}
// caller doesn't care about result, so lets ignore it
if result == nil {
return nil
}
return waiter.Wait(ctx, result)
}
@ -78,7 +73,5 @@ func (r *rpcHandler) Handle(ctx context.Context, conn *jsonrpc2.Conn, req *jsonr
r.eventHandlers[req.Method] = []chan *jsonrpc2.Request{}
}()
} else {
// TODO(josebalius): Handle
}
}

View file

@ -7,20 +7,23 @@ import (
"strconv"
)
// A Server represents the liveshare host and container server
type Server struct {
client *Client
port int
streamName, streamCondition string
}
func (c *Client) NewServer() (*Server, error) {
if !c.hasJoined() {
return nil, errors.New("LiveShareClient must join before creating server")
// NewServer creates a new Server with a given Client
func NewServer(client *Client) (*Server, error) {
if !client.hasJoined() {
return nil, errors.New("client must join before creating server")
}
return &Server{client: c}, nil
return &Server{client: client}, nil
}
// Port represents an open port on the container
type Port struct {
SourcePort int `json:"sourcePort"`
DestinationPort int `json:"destinationPort"`
@ -33,6 +36,7 @@ type Port struct {
HasTSLHandshakePassed bool `json:"hasTSLHandshakePassed"`
}
// StartSharing tells the liveshare host to start sharing the port from the container
func (s *Server) StartSharing(ctx context.Context, protocol string, port int) error {
s.port = port
@ -49,8 +53,10 @@ func (s *Server) StartSharing(ctx context.Context, protocol string, port int) er
return nil
}
// Ports is a slice of Port pointers
type Ports []*Port
// GetSharedServers returns a list of available/open ports from the container
func (s *Server) GetSharedServers(ctx context.Context) (Ports, error) {
var response Ports
if err := s.client.rpc.do(ctx, "serverSharing.getSharedServers", []string{}, &response); err != nil {
@ -60,6 +66,8 @@ func (s *Server) GetSharedServers(ctx context.Context) (Ports, error) {
return response, nil
}
// UpdateSharedVisibility controls port permissions and whether it can be accessed publicly
// via the Browse URL
func (s *Server) UpdateSharedVisibility(ctx context.Context, port int, public bool) error {
if err := s.client.rpc.do(ctx, "serverSharing.updateSharedServerVisibility", []interface{}{port, public}, nil); err != nil {
return err

237
server_test.go Normal file
View file

@ -0,0 +1,237 @@
package liveshare
import (
"context"
"crypto/tls"
"encoding/json"
"errors"
"fmt"
"strings"
"testing"
livesharetest "github.com/github/go-liveshare/test"
"github.com/sourcegraph/jsonrpc2"
)
func TestNewServerWithNotJoinedClient(t *testing.T) {
client, err := NewClient()
if err != nil {
t.Errorf("error creating new client: %v", err)
}
if _, err := NewServer(client); err == nil {
t.Error("expected error")
}
}
func makeMockJoinedClient(opts ...livesharetest.ServerOption) (*livesharetest.Server, *Client, error) {
connection := Connection{
SessionID: "session-id",
SessionToken: "session-token",
RelaySAS: "relay-sas",
}
joinWorkspace := func(req *jsonrpc2.Request) (interface{}, error) {
return joinWorkspaceResult{1}, nil
}
opts = append(
opts,
livesharetest.WithPassword(connection.SessionToken),
livesharetest.WithService("workspace.joinWorkspace", joinWorkspace),
)
testServer, err := livesharetest.NewServer(
opts...,
)
connection.RelayEndpoint = "sb" + strings.TrimPrefix(testServer.URL(), "https")
tlsConfig := WithTLSConfig(&tls.Config{InsecureSkipVerify: true})
client, err := NewClient(WithConnection(connection), tlsConfig)
if err != nil {
return nil, nil, fmt.Errorf("error creating new client: %v", err)
}
ctx := context.Background()
if err := client.Join(ctx); err != nil {
return nil, nil, fmt.Errorf("error joining client: %v", err)
}
return testServer, client, nil
}
func TestNewServer(t *testing.T) {
testServer, client, err := makeMockJoinedClient()
defer testServer.Close()
if err != nil {
t.Errorf("error creating mock joined client: %v", err)
}
server, err := NewServer(client)
if err != nil {
t.Errorf("error creating new server: %v", err)
}
if server == nil {
t.Error("server is nil")
}
}
func TestServerStartSharing(t *testing.T) {
serverPort, serverProtocol := 2222, "sshd"
startSharing := func(req *jsonrpc2.Request) (interface{}, error) {
var args []interface{}
if err := json.Unmarshal(*req.Params, &args); err != nil {
return nil, fmt.Errorf("error unmarshaling request: %v", err)
}
if len(args) < 3 {
return nil, errors.New("not enough arguments to start sharing")
}
if port, ok := args[0].(float64); !ok {
return nil, errors.New("port argument is not an int")
} else if port != float64(serverPort) {
return nil, errors.New("port does not match serverPort")
}
if protocol, ok := args[1].(string); !ok {
return nil, errors.New("protocol argument is not a string")
} else if protocol != serverProtocol {
return nil, errors.New("protocol does not match serverProtocol")
}
if browseURL, ok := args[2].(string); !ok {
return nil, errors.New("browse url is not a string")
} else if browseURL != fmt.Sprintf("http://localhost:%v", serverPort) {
return nil, errors.New("browseURL does not match expected")
}
return Port{StreamName: "stream-name", StreamCondition: "stream-condition"}, nil
}
testServer, client, err := makeMockJoinedClient(
livesharetest.WithService("serverSharing.startSharing", startSharing),
)
defer testServer.Close()
if err != nil {
t.Errorf("error creating mock joined client: %v", err)
}
server, err := NewServer(client)
if err != nil {
t.Errorf("error creating new server: %v", err)
}
ctx := context.Background()
done := make(chan error)
go func() {
if err := server.StartSharing(ctx, serverProtocol, serverPort); err != nil {
done <- fmt.Errorf("error sharing server: %v", err)
}
if server.streamName == "" || server.streamCondition == "" {
done <- errors.New("stream name or condition is blank")
}
done <- nil
}()
select {
case err := <-testServer.Err():
t.Errorf("error from server: %v", err)
case err := <-done:
if err != nil {
t.Errorf("error from client: %v", err)
}
}
}
func TestServerGetSharedServers(t *testing.T) {
sharedServer := Port{
SourcePort: 2222,
StreamName: "stream-name",
StreamCondition: "stream-condition",
}
getSharedServers := func(req *jsonrpc2.Request) (interface{}, error) {
return Ports{&sharedServer}, nil
}
testServer, client, err := makeMockJoinedClient(
livesharetest.WithService("serverSharing.getSharedServers", getSharedServers),
)
if err != nil {
t.Errorf("error creating new mock client: %v", err)
}
defer testServer.Close()
server, err := NewServer(client)
if err != nil {
t.Errorf("error creating new server: %v", err)
}
ctx := context.Background()
done := make(chan error)
go func() {
ports, err := server.GetSharedServers(ctx)
if err != nil {
done <- fmt.Errorf("error getting shared servers: %v", err)
}
if len(ports) < 1 {
done <- errors.New("not enough ports returned")
}
if ports[0].SourcePort != sharedServer.SourcePort {
done <- errors.New("source port does not match")
}
if ports[0].StreamName != sharedServer.StreamName {
done <- errors.New("stream name does not match")
}
if ports[0].StreamCondition != sharedServer.StreamCondition {
done <- errors.New("stream condiion does not match")
}
done <- nil
}()
select {
case err := <-testServer.Err():
t.Errorf("error from server: %v", err)
case err := <-done:
if err != nil {
t.Errorf("error from client: %v", err)
}
}
}
func TestServerUpdateSharedVisibility(t *testing.T) {
updateSharedVisibility := func(rpcReq *jsonrpc2.Request) (interface{}, error) {
var req []interface{}
if err := json.Unmarshal(*rpcReq.Params, &req); err != nil {
return nil, fmt.Errorf("unmarshal req: %v", err)
}
if len(req) < 2 {
return nil, errors.New("request arguments is less than 2")
}
if port, ok := req[0].(float64); ok {
if port != 80.0 {
return nil, errors.New("port param is not expected value")
}
} else {
return nil, errors.New("port param is not a float64")
}
if public, ok := req[1].(bool); ok {
if public != true {
return nil, errors.New("pulic param is not expected value")
}
} else {
return nil, errors.New("public param is not a bool")
}
return nil, nil
}
testServer, client, err := makeMockJoinedClient(
livesharetest.WithService("serverSharing.updateSharedServerVisibility", updateSharedVisibility),
)
if err != nil {
t.Errorf("creating new mock client: %v", err)
}
defer testServer.Close()
server, err := NewServer(client)
if err != nil {
t.Errorf("creating server: %v", err)
}
ctx := context.Background()
done := make(chan error)
go func() {
if err := server.UpdateSharedVisibility(ctx, 80, true); err != nil {
done <- err
return
}
done <- nil
}()
select {
case err := <-testServer.Err():
t.Errorf("error from server: %v", err)
case err := <-done:
if err != nil {
t.Errorf("error from client: %v", err)
}
}
}

View file

@ -1,60 +0,0 @@
package liveshare
import (
"context"
"fmt"
"net/url"
"strings"
"golang.org/x/sync/errgroup"
)
type session struct {
api *api
workspaceAccess *workspaceAccessResponse
workspaceInfo *workspaceInfoResponse
}
func newSession(api *api) *session {
return &session{api: api}
}
func (s *session) init(ctx context.Context) error {
g, ctx := errgroup.WithContext(ctx)
g.Go(func() error {
workspaceAccess, err := s.api.workspaceAccess()
if err != nil {
return fmt.Errorf("error getting workspace access: %v", err)
}
s.workspaceAccess = workspaceAccess
return nil
})
g.Go(func() error {
workspaceInfo, err := s.api.workspaceInfo()
if err != nil {
return fmt.Errorf("error getting workspace info: %v", err)
}
s.workspaceInfo = workspaceInfo
return nil
})
if err := g.Wait(); err != nil {
return err
}
return nil
}
// Reference:
// https://github.com/Azure/azure-relay-node/blob/7b57225365df3010163bf4b9e640868a02737eb6/hyco-ws/index.js#L107-L137
func (s *session) relayURI(action string) string {
relaySas := url.QueryEscape(s.workspaceAccess.RelaySas)
relayURI := s.workspaceAccess.RelayLink
relayURI = strings.Replace(relayURI, "sb:", "wss:", -1)
relayURI = strings.Replace(relayURI, ".net/", ".net:443/$hc/", 1)
relayURI = relayURI + "?sb-hc-action=" + action + "&sb-hc-token=" + relaySas
return relayURI
}

100
socket.go Normal file
View file

@ -0,0 +1,100 @@
package liveshare
import (
"context"
"crypto/tls"
"io"
"net"
"net/http"
"time"
"github.com/gorilla/websocket"
)
type socket struct {
addr string
tlsConfig *tls.Config
conn *websocket.Conn
reader io.Reader
}
func newSocket(clientConn Connection, tlsConfig *tls.Config) *socket {
return &socket{addr: clientConn.uri("connect"), tlsConfig: tlsConfig}
}
func (s *socket) connect(ctx context.Context) error {
dialer := websocket.Dialer{
Proxy: http.ProxyFromEnvironment,
HandshakeTimeout: 45 * time.Second,
TLSClientConfig: s.tlsConfig,
}
ws, _, err := dialer.Dial(s.addr, nil)
if err != nil {
return err
}
s.conn = ws
return nil
}
func (s *socket) Read(b []byte) (int, error) {
if s.reader == nil {
_, reader, err := s.conn.NextReader()
if err != nil {
return 0, err
}
s.reader = reader
}
bytesRead, err := s.reader.Read(b)
if err != nil {
s.reader = nil
if err == io.EOF {
err = nil
}
}
return bytesRead, err
}
func (s *socket) Write(b []byte) (int, error) {
nextWriter, err := s.conn.NextWriter(websocket.BinaryMessage)
if err != nil {
return 0, err
}
bytesWritten, err := nextWriter.Write(b)
nextWriter.Close()
return bytesWritten, err
}
func (s *socket) Close() error {
return s.conn.Close()
}
func (s *socket) LocalAddr() net.Addr {
return s.conn.LocalAddr()
}
func (s *socket) RemoteAddr() net.Addr {
return s.conn.RemoteAddr()
}
func (s *socket) SetDeadline(t time.Time) error {
if err := s.SetReadDeadline(t); err != nil {
return err
}
return s.SetWriteDeadline(t)
}
func (s *socket) SetReadDeadline(t time.Time) error {
return s.conn.SetReadDeadline(t)
}
func (s *socket) SetWriteDeadline(t time.Time) error {
return s.conn.SetWriteDeadline(t)
}

16
ssh.go
View file

@ -12,22 +12,22 @@ import (
type sshSession struct {
*ssh.Session
session *session
socket net.Conn
conn ssh.Conn
reader io.Reader
writer io.Writer
token string
socket net.Conn
conn ssh.Conn
reader io.Reader
writer io.Writer
}
func newSSH(session *session, socket net.Conn) *sshSession {
return &sshSession{session: session, socket: socket}
func newSshSession(token string, socket net.Conn) *sshSession {
return &sshSession{token: token, socket: socket}
}
func (s *sshSession) connect(ctx context.Context) error {
clientConfig := ssh.ClientConfig{
User: "",
Auth: []ssh.AuthMethod{
ssh.Password(s.session.workspaceAccess.SessionToken),
ssh.Password(s.token),
},
HostKeyAlgorithms: []string{"rsa-sha2-512", "rsa-sha2-256"},
HostKeyCallback: ssh.InsecureIgnoreHostKey(),

View file

@ -13,13 +13,13 @@ type Terminal struct {
client *Client
}
func (c *Client) NewTerminal() (*Terminal, error) {
if !c.hasJoined() {
return nil, errors.New("LiveShareClient must join before creating terminal")
func NewTerminal(client *Client) (*Terminal, error) {
if !client.hasJoined() {
return nil, errors.New("client must join before creating terminal")
}
return &Terminal{
client: c,
client: client,
}, nil
}

290
test/server.go Normal file
View file

@ -0,0 +1,290 @@
package livesharetest
import (
"context"
"errors"
"fmt"
"io"
"io/ioutil"
"net/http"
"net/http/httptest"
"path/filepath"
"strings"
"sync"
"time"
"github.com/gorilla/websocket"
"github.com/sourcegraph/jsonrpc2"
"golang.org/x/crypto/ssh"
)
type Server struct {
password string
services map[string]RpcHandleFunc
relaySAS string
streams map[string]io.ReadWriter
sshConfig *ssh.ServerConfig
httptestServer *httptest.Server
errCh chan error
}
func NewServer(opts ...ServerOption) (*Server, error) {
server := new(Server)
for _, o := range opts {
if err := o(server); err != nil {
return nil, err
}
}
server.sshConfig = &ssh.ServerConfig{
PasswordCallback: sshPasswordCallback(server.password),
}
b, err := ioutil.ReadFile(filepath.Join("test", "private.key"))
if err != nil {
return nil, fmt.Errorf("error reading private.key: %v", err)
}
privateKey, err := ssh.ParsePrivateKey(b)
if err != nil {
return nil, fmt.Errorf("error parsing key: %v", err)
}
server.sshConfig.AddHostKey(privateKey)
server.errCh = make(chan error)
server.httptestServer = httptest.NewTLSServer(http.HandlerFunc(makeConnection(server)))
return server, nil
}
type ServerOption func(*Server) error
func WithPassword(password string) ServerOption {
return func(s *Server) error {
s.password = password
return nil
}
}
func WithService(serviceName string, handler RpcHandleFunc) ServerOption {
return func(s *Server) error {
if s.services == nil {
s.services = make(map[string]RpcHandleFunc)
}
s.services[serviceName] = handler
return nil
}
}
func WithRelaySAS(sas string) ServerOption {
return func(s *Server) error {
s.relaySAS = sas
return nil
}
}
func WithStream(name string, stream io.ReadWriter) ServerOption {
return func(s *Server) error {
if s.streams == nil {
s.streams = make(map[string]io.ReadWriter)
}
s.streams[name] = stream
return nil
}
}
func sshPasswordCallback(serverPassword string) func(ssh.ConnMetadata, []byte) (*ssh.Permissions, error) {
return func(conn ssh.ConnMetadata, password []byte) (*ssh.Permissions, error) {
if string(password) == serverPassword {
return nil, nil
}
return nil, errors.New("password rejected")
}
}
func (s *Server) Close() {
s.httptestServer.Close()
}
func (s *Server) URL() string {
return s.httptestServer.URL
}
func (s *Server) Err() <-chan error {
return s.errCh
}
var upgrader = websocket.Upgrader{}
func makeConnection(server *Server) http.HandlerFunc {
return func(w http.ResponseWriter, req *http.Request) {
if server.relaySAS != "" {
// validate the sas key
sasParam := req.URL.Query().Get("sb-hc-token")
if sasParam != server.relaySAS {
server.errCh <- errors.New("error validating sas")
return
}
}
c, err := upgrader.Upgrade(w, req, nil)
if err != nil {
server.errCh <- fmt.Errorf("error upgrading connection: %v", err)
return
}
defer c.Close()
socketConn := newSocketConn(c)
_, chans, reqs, err := ssh.NewServerConn(socketConn, server.sshConfig)
if err != nil {
server.errCh <- fmt.Errorf("error creating new ssh conn: %v", err)
return
}
go ssh.DiscardRequests(reqs)
for newChannel := range chans {
ch, reqs, err := newChannel.Accept()
if err != nil {
server.errCh <- fmt.Errorf("error accepting new channel: %v", err)
return
}
go handleNewRequests(server, ch, reqs)
go handleNewChannel(server, ch)
}
}
}
func handleNewRequests(server *Server, channel ssh.Channel, reqs <-chan *ssh.Request) {
for req := range reqs {
if req.WantReply {
if err := req.Reply(true, nil); err != nil {
server.errCh <- fmt.Errorf("error replying to channel request: %v", err)
}
}
if strings.HasPrefix(req.Type, "stream-transport") {
forwardStream(server, req.Type, channel)
}
}
}
func forwardStream(server *Server, streamName string, channel ssh.Channel) {
simpleStreamName := strings.TrimPrefix(streamName, "stream-transport-")
stream, found := server.streams[simpleStreamName]
if !found {
server.errCh <- fmt.Errorf("stream '%v' not found", simpleStreamName)
return
}
copy := func(dst io.Writer, src io.Reader) {
if _, err := io.Copy(dst, src); err != nil {
fmt.Println(err)
server.errCh <- fmt.Errorf("io copy: %v", err)
return
}
}
go copy(stream, channel)
go copy(channel, stream)
for {
}
}
func handleNewChannel(server *Server, channel ssh.Channel) {
stream := jsonrpc2.NewBufferedStream(channel, jsonrpc2.VSCodeObjectCodec{})
jsonrpc2.NewConn(context.Background(), stream, newRpcHandler(server))
}
type RpcHandleFunc func(req *jsonrpc2.Request) (interface{}, error)
type rpcHandler struct {
server *Server
}
func newRpcHandler(server *Server) *rpcHandler {
return &rpcHandler{server}
}
func (r *rpcHandler) Handle(ctx context.Context, conn *jsonrpc2.Conn, req *jsonrpc2.Request) {
handler, found := r.server.services[req.Method]
if !found {
r.server.errCh <- fmt.Errorf("RPC Method: '%v' not serviced", req.Method)
return
}
result, err := handler(req)
if err != nil {
r.server.errCh <- fmt.Errorf("error handling: '%v': %v", req.Method, err)
return
}
if err := conn.Reply(ctx, req.ID, result); err != nil {
r.server.errCh <- fmt.Errorf("error replying: %v", err)
}
}
type socketConn struct {
*websocket.Conn
reader io.Reader
writeMutex sync.Mutex
readMutex sync.Mutex
}
func newSocketConn(conn *websocket.Conn) *socketConn {
return &socketConn{Conn: conn}
}
func (s *socketConn) Read(b []byte) (int, error) {
s.readMutex.Lock()
defer s.readMutex.Unlock()
if s.reader == nil {
msgType, r, err := s.Conn.NextReader()
if err != nil {
return 0, fmt.Errorf("error getting next reader: %v", err)
}
if msgType != websocket.BinaryMessage {
return 0, fmt.Errorf("invalid message type")
}
s.reader = r
}
bytesRead, err := s.reader.Read(b)
if err != nil {
s.reader = nil
if err == io.EOF {
err = nil
}
}
return bytesRead, err
}
func (s *socketConn) Write(b []byte) (int, error) {
s.writeMutex.Lock()
defer s.writeMutex.Unlock()
w, err := s.Conn.NextWriter(websocket.BinaryMessage)
if err != nil {
return 0, fmt.Errorf("error getting next writer: %v", err)
}
n, err := w.Write(b)
if err != nil {
return 0, fmt.Errorf("error writing: %v", err)
}
if err := w.Close(); err != nil {
return 0, fmt.Errorf("error closing writer: %v", err)
}
return n, nil
}
func (s *socketConn) SetDeadline(deadline time.Time) error {
if err := s.Conn.SetReadDeadline(deadline); err != nil {
return err
}
return s.Conn.SetWriteDeadline(deadline)
}

View file

@ -1,105 +0,0 @@
package liveshare
import (
"context"
"errors"
"io"
"net"
"sync"
"time"
gorillawebsocket "github.com/gorilla/websocket"
)
type websocket struct {
session *session
conn *gorillawebsocket.Conn
readMutex sync.Mutex
writeMutex sync.Mutex
reader io.Reader
}
func newWebsocket(session *session) *websocket {
return &websocket{session: session}
}
func (w *websocket) connect(ctx context.Context) error {
ws, _, err := gorillawebsocket.DefaultDialer.Dial(w.session.relayURI("connect"), nil)
if err != nil {
return err
}
w.conn = ws
return nil
}
func (w *websocket) Read(b []byte) (int, error) {
w.readMutex.Lock()
defer w.readMutex.Unlock()
if w.reader == nil {
messageType, reader, err := w.conn.NextReader()
if err != nil {
return 0, err
}
if messageType != gorillawebsocket.BinaryMessage {
return 0, errors.New("unexpected websocket message type")
}
w.reader = reader
}
bytesRead, err := w.reader.Read(b)
if err != nil {
w.reader = nil
if err == io.EOF {
err = nil
}
}
return bytesRead, err
}
func (w *websocket) Write(b []byte) (int, error) {
w.writeMutex.Lock()
defer w.writeMutex.Unlock()
nextWriter, err := w.conn.NextWriter(gorillawebsocket.BinaryMessage)
if err != nil {
return 0, err
}
bytesWritten, err := nextWriter.Write(b)
nextWriter.Close()
return bytesWritten, err
}
func (w *websocket) Close() error {
return w.conn.Close()
}
func (w *websocket) LocalAddr() net.Addr {
return w.conn.LocalAddr()
}
func (w *websocket) RemoteAddr() net.Addr {
return w.conn.RemoteAddr()
}
func (w *websocket) SetDeadline(t time.Time) error {
if err := w.SetReadDeadline(t); err != nil {
return err
}
return w.SetWriteDeadline(t)
}
func (w *websocket) SetReadDeadline(t time.Time) error {
return w.conn.SetReadDeadline(t)
}
func (w *websocket) SetWriteDeadline(t time.Time) error {
return w.conn.SetWriteDeadline(t)
}