Merge pull request #1 from github/jg/refactor
Refactors most of the library to solidify some of the implementation with tests
This commit is contained in:
commit
39fe550aeb
17 changed files with 1019 additions and 443 deletions
130
api.go
130
api.go
|
|
@ -1,130 +0,0 @@
|
|||
package liveshare
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type api struct {
|
||||
client *Client
|
||||
httpClient *http.Client
|
||||
serviceURI string
|
||||
workspaceID string
|
||||
}
|
||||
|
||||
func newAPI(client *Client) *api {
|
||||
serviceURI := client.liveShare.Configuration.LiveShareEndpoint
|
||||
if !strings.HasSuffix(client.liveShare.Configuration.LiveShareEndpoint, "/") {
|
||||
serviceURI = client.liveShare.Configuration.LiveShareEndpoint + "/"
|
||||
}
|
||||
|
||||
if !strings.Contains(serviceURI, "api/v1.2") {
|
||||
serviceURI = serviceURI + "api/v1.2"
|
||||
}
|
||||
|
||||
serviceURI = strings.TrimSuffix(serviceURI, "/")
|
||||
|
||||
return &api{client, &http.Client{}, serviceURI, strings.ToUpper(client.liveShare.Configuration.WorkspaceID)}
|
||||
}
|
||||
|
||||
type workspaceAccessResponse struct {
|
||||
SessionToken string `json:"sessionToken"`
|
||||
CreatedAt string `json:"createdAt"`
|
||||
UpdatedAt string `json:"updatedAt"`
|
||||
Name string `json:"name"`
|
||||
OwnerID string `json:"ownerId"`
|
||||
JoinLink string `json:"joinLink"`
|
||||
ConnectLinks []string `json:"connectLinks"`
|
||||
RelayLink string `json:"relayLink"`
|
||||
RelaySas string `json:"relaySas"`
|
||||
HostPublicKeys []string `json:"hostPublicKeys"`
|
||||
ConversationID string `json:"conversationId"`
|
||||
AssociatedUserIDs map[string]string `json:"associatedUserIds"`
|
||||
AreAnonymousGuestsAllowed bool `json:"areAnonymousGuestsAllowed"`
|
||||
IsHostConnected bool `json:"isHostConnected"`
|
||||
ExpiresAt string `json:"expiresAt"`
|
||||
InvitationLinks []string `json:"invitationLinks"`
|
||||
ID string `json:"id"`
|
||||
}
|
||||
|
||||
func (a *api) workspaceAccess() (*workspaceAccessResponse, error) {
|
||||
url := fmt.Sprintf("%s/workspace/%s/user", a.serviceURI, a.workspaceID)
|
||||
|
||||
req, err := http.NewRequest(http.MethodPut, url, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating request: %v", err)
|
||||
}
|
||||
|
||||
a.setDefaultHeaders(req)
|
||||
resp, err := a.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error making request: %v", err)
|
||||
}
|
||||
|
||||
b, err := ioutil.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error reading response body: %v", err)
|
||||
}
|
||||
|
||||
var response workspaceAccessResponse
|
||||
if err := json.Unmarshal(b, &response); err != nil {
|
||||
return nil, fmt.Errorf("error unmarshaling response into json: %v", err)
|
||||
}
|
||||
|
||||
return &response, nil
|
||||
}
|
||||
|
||||
func (a *api) setDefaultHeaders(req *http.Request) {
|
||||
req.Header.Set("Authorization", "Bearer "+a.client.liveShare.Configuration.Token)
|
||||
req.Header.Set("Cache-Control", "no-cache")
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
}
|
||||
|
||||
type workspaceInfoResponse struct {
|
||||
CreatedAt string `json:"createdAt"`
|
||||
UpdatedAt string `json:"updatedAt"`
|
||||
Name string `json:"name"`
|
||||
OwnerID string `json:"ownerId"`
|
||||
JoinLink string `json:"joinLink"`
|
||||
ConnectLinks []string `json:"connectLinks"`
|
||||
RelayLink string `json:"relayLink"`
|
||||
RelaySas string `json:"relaySas"`
|
||||
HostPublicKeys []string `json:"hostPublicKeys"`
|
||||
ConversationID string `json:"conversationId"`
|
||||
AssociatedUserIDs map[string]string
|
||||
AreAnonymousGuestsAllowed bool `json:"areAnonymousGuestsAllowed"`
|
||||
IsHostConnected bool `json:"isHostConnected"`
|
||||
ExpiresAt string `json:"expiresAt"`
|
||||
InvitationLinks []string `json:"invitationLinks"`
|
||||
ID string `json:"id"`
|
||||
}
|
||||
|
||||
func (a *api) workspaceInfo() (*workspaceInfoResponse, error) {
|
||||
url := fmt.Sprintf("%s/workspace/%s", a.serviceURI, a.workspaceID)
|
||||
|
||||
req, err := http.NewRequest(http.MethodGet, url, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating request: %v", err)
|
||||
}
|
||||
|
||||
a.setDefaultHeaders(req)
|
||||
resp, err := a.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error making request: %v", err)
|
||||
}
|
||||
|
||||
b, err := ioutil.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error reading response body: %v", err)
|
||||
}
|
||||
|
||||
var response workspaceInfoResponse
|
||||
if err := json.Unmarshal(b, &response); err != nil {
|
||||
return nil, fmt.Errorf("error unmarshaling response into json: %v", err)
|
||||
}
|
||||
|
||||
return &response, nil
|
||||
}
|
||||
74
client.go
74
client.go
|
|
@ -2,42 +2,69 @@ package liveshare
|
|||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
// A Client capable of joining a liveshare connection
|
||||
type Client struct {
|
||||
liveShare *LiveShare
|
||||
session *session
|
||||
sshSession *sshSession
|
||||
rpc *rpc
|
||||
connection Connection
|
||||
tlsConfig *tls.Config
|
||||
|
||||
ssh *sshSession
|
||||
rpc *rpcClient
|
||||
}
|
||||
|
||||
// NewClient is a function ...
|
||||
func (l *LiveShare) NewClient() *Client {
|
||||
return &Client{liveShare: l}
|
||||
}
|
||||
// A ClientOption is a function that modifies a client
|
||||
type ClientOption func(*Client) error
|
||||
|
||||
func (c *Client) Join(ctx context.Context) (err error) {
|
||||
api := newAPI(c)
|
||||
// NewClient accepts a range of options, applies them and returns a client
|
||||
func NewClient(opts ...ClientOption) (*Client, error) {
|
||||
client := new(Client)
|
||||
|
||||
c.session = newSession(api)
|
||||
if err := c.session.init(ctx); err != nil {
|
||||
return fmt.Errorf("error creating session: %v", err)
|
||||
for _, o := range opts {
|
||||
if err := o(client); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
websocket := newWebsocket(c.session)
|
||||
if err := websocket.connect(ctx); err != nil {
|
||||
return client, nil
|
||||
}
|
||||
|
||||
// WithConnection is a ClientOption that accepts a Connection
|
||||
func WithConnection(connection Connection) ClientOption {
|
||||
return func(c *Client) error {
|
||||
if err := connection.validate(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
c.connection = connection
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func WithTLSConfig(tlsConfig *tls.Config) ClientOption {
|
||||
return func(c *Client) error {
|
||||
c.tlsConfig = tlsConfig
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// Join is a method that joins the client to the liveshare session
|
||||
func (c *Client) Join(ctx context.Context) (err error) {
|
||||
clientSocket := newSocket(c.connection, c.tlsConfig)
|
||||
if err := clientSocket.connect(ctx); err != nil {
|
||||
return fmt.Errorf("error connecting websocket: %v", err)
|
||||
}
|
||||
|
||||
c.sshSession = newSSH(c.session, websocket)
|
||||
if err := c.sshSession.connect(ctx); err != nil {
|
||||
c.ssh = newSshSession(c.connection.SessionToken, clientSocket)
|
||||
if err := c.ssh.connect(ctx); err != nil {
|
||||
return fmt.Errorf("error connecting to ssh session: %v", err)
|
||||
}
|
||||
|
||||
c.rpc = newRPC(c.sshSession)
|
||||
c.rpc = newRpcClient(c.ssh)
|
||||
c.rpc.connect(ctx)
|
||||
|
||||
_, err = c.joinWorkspace(ctx)
|
||||
|
|
@ -49,7 +76,7 @@ func (c *Client) Join(ctx context.Context) (err error) {
|
|||
}
|
||||
|
||||
func (c *Client) hasJoined() bool {
|
||||
return c.sshSession != nil && c.rpc != nil
|
||||
return c.ssh != nil && c.rpc != nil
|
||||
}
|
||||
|
||||
type clientCapabilities struct {
|
||||
|
|
@ -69,9 +96,9 @@ type joinWorkspaceResult struct {
|
|||
|
||||
func (c *Client) joinWorkspace(ctx context.Context) (*joinWorkspaceResult, error) {
|
||||
args := joinWorkspaceArgs{
|
||||
ID: c.session.workspaceInfo.ID,
|
||||
ID: c.connection.SessionID,
|
||||
ConnectionMode: "local",
|
||||
JoiningUserSessionToken: c.session.workspaceAccess.SessionToken,
|
||||
JoiningUserSessionToken: c.connection.SessionToken,
|
||||
ClientCapabilities: clientCapabilities{
|
||||
IsNonInteractive: false,
|
||||
},
|
||||
|
|
@ -92,15 +119,14 @@ func (c *Client) openStreamingChannel(ctx context.Context, streamName, condition
|
|||
return nil, fmt.Errorf("error getting stream id: %v", err)
|
||||
}
|
||||
|
||||
channel, reqs, err := c.sshSession.conn.OpenChannel("session", nil)
|
||||
channel, reqs, err := c.ssh.conn.OpenChannel("session", nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error opening ssh channel for transport: %v", err)
|
||||
}
|
||||
go ssh.DiscardRequests(reqs)
|
||||
|
||||
requestType := fmt.Sprintf("stream-transport-%s", streamID)
|
||||
_, err = channel.SendRequest(requestType, true, nil)
|
||||
if err != nil {
|
||||
if _, err = channel.SendRequest(requestType, true, nil); err != nil {
|
||||
return nil, fmt.Errorf("error sending channel request: %v", err)
|
||||
}
|
||||
|
||||
|
|
|
|||
109
client_test.go
Normal file
109
client_test.go
Normal file
|
|
@ -0,0 +1,109 @@
|
|||
package liveshare
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
livesharetest "github.com/github/go-liveshare/test"
|
||||
"github.com/sourcegraph/jsonrpc2"
|
||||
)
|
||||
|
||||
func TestNewClient(t *testing.T) {
|
||||
client, err := NewClient()
|
||||
if err != nil {
|
||||
t.Errorf("error creating new client: %v", err)
|
||||
}
|
||||
if client == nil {
|
||||
t.Error("client is nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewClientValidConnection(t *testing.T) {
|
||||
connection := Connection{"1", "2", "3", "4"}
|
||||
|
||||
client, err := NewClient(WithConnection(connection))
|
||||
if err != nil {
|
||||
t.Errorf("error creating new client: %v", err)
|
||||
}
|
||||
if client == nil {
|
||||
t.Error("client is nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewClientWithInvalidConnection(t *testing.T) {
|
||||
connection := Connection{}
|
||||
|
||||
if _, err := NewClient(WithConnection(connection)); err == nil {
|
||||
t.Error("err is nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestClientJoin(t *testing.T) {
|
||||
connection := Connection{
|
||||
SessionID: "session-id",
|
||||
SessionToken: "session-token",
|
||||
RelaySAS: "relay-sas",
|
||||
}
|
||||
joinWorkspace := func(req *jsonrpc2.Request) (interface{}, error) {
|
||||
var joinWorkspaceReq joinWorkspaceArgs
|
||||
if err := json.Unmarshal(*req.Params, &joinWorkspaceReq); err != nil {
|
||||
return nil, fmt.Errorf("error unmarshaling req: %v", err)
|
||||
}
|
||||
if joinWorkspaceReq.ID != connection.SessionID {
|
||||
return nil, errors.New("connection session id does not match")
|
||||
}
|
||||
if joinWorkspaceReq.ConnectionMode != "local" {
|
||||
return nil, errors.New("connection mode is not local")
|
||||
}
|
||||
if joinWorkspaceReq.JoiningUserSessionToken != connection.SessionToken {
|
||||
return nil, errors.New("connection user token does not match")
|
||||
}
|
||||
if joinWorkspaceReq.ClientCapabilities.IsNonInteractive != false {
|
||||
return nil, errors.New("non interactive is not false")
|
||||
}
|
||||
return joinWorkspaceResult{1}, nil
|
||||
}
|
||||
|
||||
server, err := livesharetest.NewServer(
|
||||
livesharetest.WithPassword(connection.SessionToken),
|
||||
livesharetest.WithService("workspace.joinWorkspace", joinWorkspace),
|
||||
livesharetest.WithRelaySAS(connection.RelaySAS),
|
||||
)
|
||||
if err != nil {
|
||||
t.Errorf("error creating liveshare server: %v", err)
|
||||
}
|
||||
defer server.Close()
|
||||
connection.RelayEndpoint = "sb" + strings.TrimPrefix(server.URL(), "https")
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
tlsConfig := WithTLSConfig(&tls.Config{InsecureSkipVerify: true})
|
||||
client, err := NewClient(WithConnection(connection), tlsConfig)
|
||||
if err != nil {
|
||||
t.Errorf("error creating new client: %v", err)
|
||||
}
|
||||
|
||||
done := make(chan error)
|
||||
go func() {
|
||||
if err := client.Join(ctx); err != nil {
|
||||
done <- fmt.Errorf("error joining client: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
done <- nil
|
||||
}()
|
||||
|
||||
select {
|
||||
case err := <-server.Err():
|
||||
t.Errorf("error from server: %v", err)
|
||||
case err := <-done:
|
||||
if err != nil {
|
||||
t.Errorf("error from client: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
44
connection.go
Normal file
44
connection.go
Normal file
|
|
@ -0,0 +1,44 @@
|
|||
package liveshare
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/url"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// A Connection represents a set of values necessary to join a liveshare connection
|
||||
type Connection struct {
|
||||
SessionID string
|
||||
SessionToken string
|
||||
RelaySAS string
|
||||
RelayEndpoint string
|
||||
}
|
||||
|
||||
func (r Connection) validate() error {
|
||||
if r.SessionID == "" {
|
||||
return errors.New("connection SessionID is required")
|
||||
}
|
||||
|
||||
if r.SessionToken == "" {
|
||||
return errors.New("connection SessionToken is required")
|
||||
}
|
||||
|
||||
if r.RelaySAS == "" {
|
||||
return errors.New("connection RelaySAS is required")
|
||||
}
|
||||
|
||||
if r.RelayEndpoint == "" {
|
||||
return errors.New("connection RelayEndpoint is required")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r Connection) uri(action string) string {
|
||||
sas := url.QueryEscape(r.RelaySAS)
|
||||
uri := r.RelayEndpoint
|
||||
uri = strings.Replace(uri, "sb:", "wss:", -1)
|
||||
uri = strings.Replace(uri, ".net/", ".net:443/$hc/", 1)
|
||||
uri = uri + "?sb-hc-action=" + action + "&sb-hc-token=" + sas
|
||||
return uri
|
||||
}
|
||||
33
connection_test.go
Normal file
33
connection_test.go
Normal file
|
|
@ -0,0 +1,33 @@
|
|||
package liveshare
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestConnectionValid(t *testing.T) {
|
||||
conn := Connection{"sess-id", "sess-token", "sas", "endpoint"}
|
||||
if err := conn.validate(); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConnectionInvalid(t *testing.T) {
|
||||
conn := Connection{"", "sess-token", "sas", "endpoint"}
|
||||
if err := conn.validate(); err == nil {
|
||||
t.Error(err)
|
||||
}
|
||||
conn = Connection{"sess-id", "", "sas", "endpoint"}
|
||||
if err := conn.validate(); err == nil {
|
||||
t.Error(err)
|
||||
}
|
||||
conn = Connection{"sess-id", "sess-token", "", "endpoint"}
|
||||
if err := conn.validate(); err == nil {
|
||||
t.Error(err)
|
||||
}
|
||||
conn = Connection{"sess-id", "sess-token", "sas", ""}
|
||||
if err := conn.validate(); err == nil {
|
||||
t.Error(err)
|
||||
}
|
||||
conn = Connection{"", "", "", ""}
|
||||
if err := conn.validate(); err == nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}
|
||||
77
liveshare.go
77
liveshare.go
|
|
@ -1,77 +0,0 @@
|
|||
package liveshare
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type LiveShare struct {
|
||||
Configuration *Configuration
|
||||
}
|
||||
|
||||
func New(opts ...Option) (*LiveShare, error) {
|
||||
configuration := NewConfiguration()
|
||||
|
||||
for _, o := range opts {
|
||||
if err := o(configuration); err != nil {
|
||||
return nil, fmt.Errorf("error configuring liveshare: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
if err := configuration.Validate(); err != nil {
|
||||
return nil, fmt.Errorf("error validating configuration: %v", err)
|
||||
}
|
||||
|
||||
return &LiveShare{Configuration: configuration}, nil
|
||||
}
|
||||
|
||||
type Option func(configuration *Configuration) error
|
||||
|
||||
func WithWorkspaceID(id string) Option {
|
||||
return func(configuration *Configuration) error {
|
||||
configuration.WorkspaceID = id
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func WithLiveShareEndpoint(liveShareEndpoint string) Option {
|
||||
return func(configuration *Configuration) error {
|
||||
configuration.LiveShareEndpoint = liveShareEndpoint
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func WithToken(token string) Option {
|
||||
return func(configuration *Configuration) error {
|
||||
configuration.Token = token
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
type Configuration struct {
|
||||
WorkspaceID, LiveShareEndpoint, Token string
|
||||
}
|
||||
|
||||
func NewConfiguration() *Configuration {
|
||||
return &Configuration{
|
||||
LiveShareEndpoint: "https://prod.liveshare.vsengsaas.visualstudio.com",
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Configuration) Validate() error {
|
||||
errs := []string{}
|
||||
if c.WorkspaceID == "" {
|
||||
errs = append(errs, "WorkspaceID is required")
|
||||
}
|
||||
|
||||
if c.Token == "" {
|
||||
errs = append(errs, "Token is required")
|
||||
}
|
||||
|
||||
if len(errs) > 0 {
|
||||
return errors.New(strings.Join(errs, ", "))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
|
@ -4,25 +4,30 @@ import (
|
|||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net"
|
||||
"strconv"
|
||||
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
type LocalPortForwarder struct {
|
||||
client *Client
|
||||
server *Server
|
||||
port int
|
||||
channels []ssh.Channel
|
||||
// A PortForwader can forward ports from a remote liveshare host to localhost
|
||||
type PortForwarder struct {
|
||||
client *Client
|
||||
server *Server
|
||||
port int
|
||||
errCh chan error
|
||||
}
|
||||
|
||||
func NewLocalPortForwarder(client *Client, server *Server, port int) *LocalPortForwarder {
|
||||
return &LocalPortForwarder{client, server, port, []ssh.Channel{}}
|
||||
// NewPortForwarder creates a new PortForwader with a given client, server and port
|
||||
func NewPortForwarder(client *Client, server *Server, port int) *PortForwarder {
|
||||
return &PortForwarder{
|
||||
client: client,
|
||||
server: server,
|
||||
port: port,
|
||||
errCh: make(chan error),
|
||||
}
|
||||
}
|
||||
|
||||
func (l *LocalPortForwarder) Start(ctx context.Context) error {
|
||||
// Start is a method to start forwarding the server to a localhost port
|
||||
func (l *PortForwarder) Start(ctx context.Context) error {
|
||||
ln, err := net.Listen("tcp", ":"+strconv.Itoa(l.port))
|
||||
if err != nil {
|
||||
return fmt.Errorf("error listening on tcp port: %v", err)
|
||||
|
|
@ -37,24 +42,24 @@ func (l *LocalPortForwarder) Start(ctx context.Context) error {
|
|||
go l.handleConnection(ctx, conn)
|
||||
}
|
||||
|
||||
// clean up after ourselves
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (l *LocalPortForwarder) handleConnection(ctx context.Context, conn net.Conn) {
|
||||
func (l *PortForwarder) handleConnection(ctx context.Context, conn net.Conn) {
|
||||
channel, err := l.client.openStreamingChannel(ctx, l.server.streamName, l.server.streamCondition)
|
||||
if err != nil {
|
||||
log.Println("errrr handle Connect")
|
||||
log.Println(err) // TODO(josebalius) handle this somehow
|
||||
l.errCh <- fmt.Errorf("error opening streaming channel for new connection: %v", err)
|
||||
return
|
||||
}
|
||||
l.channels = append(l.channels, channel)
|
||||
|
||||
copyConn := func(writer io.Writer, reader io.Reader) {
|
||||
_, err := io.Copy(writer, reader)
|
||||
if err != nil {
|
||||
if _, err := io.Copy(writer, reader); err != nil {
|
||||
fmt.Println(err)
|
||||
channel.Close()
|
||||
conn.Close()
|
||||
if err != io.EOF {
|
||||
l.errCh <- fmt.Errorf("tunnel connection: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
103
port_forwarder_test.go
Normal file
103
port_forwarder_test.go
Normal file
|
|
@ -0,0 +1,103 @@
|
|||
package liveshare
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
livesharetest "github.com/github/go-liveshare/test"
|
||||
"github.com/sourcegraph/jsonrpc2"
|
||||
)
|
||||
|
||||
func TestNewPortForwarder(t *testing.T) {
|
||||
testServer, client, err := makeMockJoinedClient()
|
||||
if err != nil {
|
||||
t.Errorf("create mock client: %v", err)
|
||||
}
|
||||
defer testServer.Close()
|
||||
server, err := NewServer(client)
|
||||
if err != nil {
|
||||
t.Errorf("create new server: %v", err)
|
||||
}
|
||||
pf := NewPortForwarder(client, server, 80)
|
||||
if pf == nil {
|
||||
t.Error("port forwarder is nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPortForwarderStart(t *testing.T) {
|
||||
streamName, streamCondition := "stream-name", "stream-condition"
|
||||
serverSharing := func(req *jsonrpc2.Request) (interface{}, error) {
|
||||
return Port{StreamName: streamName, StreamCondition: streamCondition}, nil
|
||||
}
|
||||
getStream := func(req *jsonrpc2.Request) (interface{}, error) {
|
||||
return "stream-id", nil
|
||||
}
|
||||
|
||||
stream := bytes.NewBufferString("stream-data")
|
||||
testServer, client, err := makeMockJoinedClient(
|
||||
livesharetest.WithService("serverSharing.startSharing", serverSharing),
|
||||
livesharetest.WithService("streamManager.getStream", getStream),
|
||||
livesharetest.WithStream("stream-id", stream),
|
||||
)
|
||||
if err != nil {
|
||||
t.Errorf("create mock client: %v", err)
|
||||
}
|
||||
defer testServer.Close()
|
||||
|
||||
server, err := NewServer(client)
|
||||
if err != nil {
|
||||
t.Errorf("create new server: %v", err)
|
||||
}
|
||||
|
||||
ctx, _ := context.WithCancel(context.Background())
|
||||
pf := NewPortForwarder(client, server, 8000)
|
||||
done := make(chan error)
|
||||
|
||||
go func() {
|
||||
if err := server.StartSharing(ctx, "http", 8000); err != nil {
|
||||
done <- fmt.Errorf("start sharing: %v", err)
|
||||
}
|
||||
if err := pf.Start(ctx); err != nil {
|
||||
done <- err
|
||||
}
|
||||
done <- nil
|
||||
}()
|
||||
|
||||
go func() {
|
||||
var conn net.Conn
|
||||
retries := 0
|
||||
for conn == nil && retries < 2 {
|
||||
conn, err = net.DialTimeout("tcp", ":8000", 2*time.Second)
|
||||
time.Sleep(1 * time.Second)
|
||||
}
|
||||
if conn == nil {
|
||||
done <- errors.New("failed to connect to forwarded port")
|
||||
}
|
||||
b := make([]byte, len("stream-data"))
|
||||
if _, err := conn.Read(b); err != nil && err != io.EOF {
|
||||
done <- fmt.Errorf("reading stream: %v", err)
|
||||
}
|
||||
if string(b) != "stream-data" {
|
||||
done <- fmt.Errorf("stream data is not expected value, got: %v", string(b))
|
||||
}
|
||||
if _, err := conn.Write([]byte("new-data")); err != nil {
|
||||
done <- fmt.Errorf("writing to stream: %v", err)
|
||||
}
|
||||
done <- nil
|
||||
}()
|
||||
|
||||
select {
|
||||
case err := <-testServer.Err():
|
||||
t.Errorf("error from server: %v", err)
|
||||
case err := <-done:
|
||||
if err != nil {
|
||||
t.Errorf("error from client: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
17
rpc.go
17
rpc.go
|
|
@ -9,32 +9,27 @@ import (
|
|||
"github.com/sourcegraph/jsonrpc2"
|
||||
)
|
||||
|
||||
type rpc struct {
|
||||
type rpcClient struct {
|
||||
*jsonrpc2.Conn
|
||||
conn io.ReadWriteCloser
|
||||
handler *rpcHandler
|
||||
}
|
||||
|
||||
func newRPC(conn io.ReadWriteCloser) *rpc {
|
||||
return &rpc{conn: conn, handler: newRPCHandler()}
|
||||
func newRpcClient(conn io.ReadWriteCloser) *rpcClient {
|
||||
return &rpcClient{conn: conn, handler: newRPCHandler()}
|
||||
}
|
||||
|
||||
func (r *rpc) connect(ctx context.Context) {
|
||||
func (r *rpcClient) connect(ctx context.Context) {
|
||||
stream := jsonrpc2.NewBufferedStream(r.conn, jsonrpc2.VSCodeObjectCodec{})
|
||||
r.Conn = jsonrpc2.NewConn(ctx, stream, r.handler)
|
||||
}
|
||||
|
||||
func (r *rpc) do(ctx context.Context, method string, args interface{}, result interface{}) error {
|
||||
func (r *rpcClient) do(ctx context.Context, method string, args interface{}, result interface{}) error {
|
||||
waiter, err := r.Conn.DispatchCall(ctx, method, args)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error on dispatch call: %v", err)
|
||||
}
|
||||
|
||||
// caller doesn't care about result, so lets ignore it
|
||||
if result == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return waiter.Wait(ctx, result)
|
||||
}
|
||||
|
||||
|
|
@ -78,7 +73,5 @@ func (r *rpcHandler) Handle(ctx context.Context, conn *jsonrpc2.Conn, req *jsonr
|
|||
|
||||
r.eventHandlers[req.Method] = []chan *jsonrpc2.Request{}
|
||||
}()
|
||||
} else {
|
||||
// TODO(josebalius): Handle
|
||||
}
|
||||
}
|
||||
|
|
|
|||
16
server.go
16
server.go
|
|
@ -7,20 +7,23 @@ import (
|
|||
"strconv"
|
||||
)
|
||||
|
||||
// A Server represents the liveshare host and container server
|
||||
type Server struct {
|
||||
client *Client
|
||||
port int
|
||||
streamName, streamCondition string
|
||||
}
|
||||
|
||||
func (c *Client) NewServer() (*Server, error) {
|
||||
if !c.hasJoined() {
|
||||
return nil, errors.New("LiveShareClient must join before creating server")
|
||||
// NewServer creates a new Server with a given Client
|
||||
func NewServer(client *Client) (*Server, error) {
|
||||
if !client.hasJoined() {
|
||||
return nil, errors.New("client must join before creating server")
|
||||
}
|
||||
|
||||
return &Server{client: c}, nil
|
||||
return &Server{client: client}, nil
|
||||
}
|
||||
|
||||
// Port represents an open port on the container
|
||||
type Port struct {
|
||||
SourcePort int `json:"sourcePort"`
|
||||
DestinationPort int `json:"destinationPort"`
|
||||
|
|
@ -33,6 +36,7 @@ type Port struct {
|
|||
HasTSLHandshakePassed bool `json:"hasTSLHandshakePassed"`
|
||||
}
|
||||
|
||||
// StartSharing tells the liveshare host to start sharing the port from the container
|
||||
func (s *Server) StartSharing(ctx context.Context, protocol string, port int) error {
|
||||
s.port = port
|
||||
|
||||
|
|
@ -49,8 +53,10 @@ func (s *Server) StartSharing(ctx context.Context, protocol string, port int) er
|
|||
return nil
|
||||
}
|
||||
|
||||
// Ports is a slice of Port pointers
|
||||
type Ports []*Port
|
||||
|
||||
// GetSharedServers returns a list of available/open ports from the container
|
||||
func (s *Server) GetSharedServers(ctx context.Context) (Ports, error) {
|
||||
var response Ports
|
||||
if err := s.client.rpc.do(ctx, "serverSharing.getSharedServers", []string{}, &response); err != nil {
|
||||
|
|
@ -60,6 +66,8 @@ func (s *Server) GetSharedServers(ctx context.Context) (Ports, error) {
|
|||
return response, nil
|
||||
}
|
||||
|
||||
// UpdateSharedVisibility controls port permissions and whether it can be accessed publicly
|
||||
// via the Browse URL
|
||||
func (s *Server) UpdateSharedVisibility(ctx context.Context, port int, public bool) error {
|
||||
if err := s.client.rpc.do(ctx, "serverSharing.updateSharedServerVisibility", []interface{}{port, public}, nil); err != nil {
|
||||
return err
|
||||
|
|
|
|||
237
server_test.go
Normal file
237
server_test.go
Normal file
|
|
@ -0,0 +1,237 @@
|
|||
package liveshare
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
livesharetest "github.com/github/go-liveshare/test"
|
||||
"github.com/sourcegraph/jsonrpc2"
|
||||
)
|
||||
|
||||
func TestNewServerWithNotJoinedClient(t *testing.T) {
|
||||
client, err := NewClient()
|
||||
if err != nil {
|
||||
t.Errorf("error creating new client: %v", err)
|
||||
}
|
||||
if _, err := NewServer(client); err == nil {
|
||||
t.Error("expected error")
|
||||
}
|
||||
}
|
||||
|
||||
func makeMockJoinedClient(opts ...livesharetest.ServerOption) (*livesharetest.Server, *Client, error) {
|
||||
connection := Connection{
|
||||
SessionID: "session-id",
|
||||
SessionToken: "session-token",
|
||||
RelaySAS: "relay-sas",
|
||||
}
|
||||
joinWorkspace := func(req *jsonrpc2.Request) (interface{}, error) {
|
||||
return joinWorkspaceResult{1}, nil
|
||||
}
|
||||
opts = append(
|
||||
opts,
|
||||
livesharetest.WithPassword(connection.SessionToken),
|
||||
livesharetest.WithService("workspace.joinWorkspace", joinWorkspace),
|
||||
)
|
||||
testServer, err := livesharetest.NewServer(
|
||||
opts...,
|
||||
)
|
||||
connection.RelayEndpoint = "sb" + strings.TrimPrefix(testServer.URL(), "https")
|
||||
tlsConfig := WithTLSConfig(&tls.Config{InsecureSkipVerify: true})
|
||||
client, err := NewClient(WithConnection(connection), tlsConfig)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("error creating new client: %v", err)
|
||||
}
|
||||
ctx := context.Background()
|
||||
if err := client.Join(ctx); err != nil {
|
||||
return nil, nil, fmt.Errorf("error joining client: %v", err)
|
||||
}
|
||||
return testServer, client, nil
|
||||
}
|
||||
|
||||
func TestNewServer(t *testing.T) {
|
||||
testServer, client, err := makeMockJoinedClient()
|
||||
defer testServer.Close()
|
||||
if err != nil {
|
||||
t.Errorf("error creating mock joined client: %v", err)
|
||||
}
|
||||
server, err := NewServer(client)
|
||||
if err != nil {
|
||||
t.Errorf("error creating new server: %v", err)
|
||||
}
|
||||
if server == nil {
|
||||
t.Error("server is nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestServerStartSharing(t *testing.T) {
|
||||
serverPort, serverProtocol := 2222, "sshd"
|
||||
startSharing := func(req *jsonrpc2.Request) (interface{}, error) {
|
||||
var args []interface{}
|
||||
if err := json.Unmarshal(*req.Params, &args); err != nil {
|
||||
return nil, fmt.Errorf("error unmarshaling request: %v", err)
|
||||
}
|
||||
if len(args) < 3 {
|
||||
return nil, errors.New("not enough arguments to start sharing")
|
||||
}
|
||||
if port, ok := args[0].(float64); !ok {
|
||||
return nil, errors.New("port argument is not an int")
|
||||
} else if port != float64(serverPort) {
|
||||
return nil, errors.New("port does not match serverPort")
|
||||
}
|
||||
if protocol, ok := args[1].(string); !ok {
|
||||
return nil, errors.New("protocol argument is not a string")
|
||||
} else if protocol != serverProtocol {
|
||||
return nil, errors.New("protocol does not match serverProtocol")
|
||||
}
|
||||
if browseURL, ok := args[2].(string); !ok {
|
||||
return nil, errors.New("browse url is not a string")
|
||||
} else if browseURL != fmt.Sprintf("http://localhost:%v", serverPort) {
|
||||
return nil, errors.New("browseURL does not match expected")
|
||||
}
|
||||
return Port{StreamName: "stream-name", StreamCondition: "stream-condition"}, nil
|
||||
}
|
||||
testServer, client, err := makeMockJoinedClient(
|
||||
livesharetest.WithService("serverSharing.startSharing", startSharing),
|
||||
)
|
||||
defer testServer.Close()
|
||||
if err != nil {
|
||||
t.Errorf("error creating mock joined client: %v", err)
|
||||
}
|
||||
server, err := NewServer(client)
|
||||
if err != nil {
|
||||
t.Errorf("error creating new server: %v", err)
|
||||
}
|
||||
ctx := context.Background()
|
||||
|
||||
done := make(chan error)
|
||||
go func() {
|
||||
if err := server.StartSharing(ctx, serverProtocol, serverPort); err != nil {
|
||||
done <- fmt.Errorf("error sharing server: %v", err)
|
||||
}
|
||||
if server.streamName == "" || server.streamCondition == "" {
|
||||
done <- errors.New("stream name or condition is blank")
|
||||
}
|
||||
done <- nil
|
||||
}()
|
||||
|
||||
select {
|
||||
case err := <-testServer.Err():
|
||||
t.Errorf("error from server: %v", err)
|
||||
case err := <-done:
|
||||
if err != nil {
|
||||
t.Errorf("error from client: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestServerGetSharedServers(t *testing.T) {
|
||||
sharedServer := Port{
|
||||
SourcePort: 2222,
|
||||
StreamName: "stream-name",
|
||||
StreamCondition: "stream-condition",
|
||||
}
|
||||
getSharedServers := func(req *jsonrpc2.Request) (interface{}, error) {
|
||||
return Ports{&sharedServer}, nil
|
||||
}
|
||||
testServer, client, err := makeMockJoinedClient(
|
||||
livesharetest.WithService("serverSharing.getSharedServers", getSharedServers),
|
||||
)
|
||||
if err != nil {
|
||||
t.Errorf("error creating new mock client: %v", err)
|
||||
}
|
||||
defer testServer.Close()
|
||||
server, err := NewServer(client)
|
||||
if err != nil {
|
||||
t.Errorf("error creating new server: %v", err)
|
||||
}
|
||||
ctx := context.Background()
|
||||
done := make(chan error)
|
||||
go func() {
|
||||
ports, err := server.GetSharedServers(ctx)
|
||||
if err != nil {
|
||||
done <- fmt.Errorf("error getting shared servers: %v", err)
|
||||
}
|
||||
if len(ports) < 1 {
|
||||
done <- errors.New("not enough ports returned")
|
||||
}
|
||||
if ports[0].SourcePort != sharedServer.SourcePort {
|
||||
done <- errors.New("source port does not match")
|
||||
}
|
||||
if ports[0].StreamName != sharedServer.StreamName {
|
||||
done <- errors.New("stream name does not match")
|
||||
}
|
||||
if ports[0].StreamCondition != sharedServer.StreamCondition {
|
||||
done <- errors.New("stream condiion does not match")
|
||||
}
|
||||
done <- nil
|
||||
}()
|
||||
|
||||
select {
|
||||
case err := <-testServer.Err():
|
||||
t.Errorf("error from server: %v", err)
|
||||
case err := <-done:
|
||||
if err != nil {
|
||||
t.Errorf("error from client: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestServerUpdateSharedVisibility(t *testing.T) {
|
||||
updateSharedVisibility := func(rpcReq *jsonrpc2.Request) (interface{}, error) {
|
||||
var req []interface{}
|
||||
if err := json.Unmarshal(*rpcReq.Params, &req); err != nil {
|
||||
return nil, fmt.Errorf("unmarshal req: %v", err)
|
||||
}
|
||||
if len(req) < 2 {
|
||||
return nil, errors.New("request arguments is less than 2")
|
||||
}
|
||||
if port, ok := req[0].(float64); ok {
|
||||
if port != 80.0 {
|
||||
return nil, errors.New("port param is not expected value")
|
||||
}
|
||||
} else {
|
||||
return nil, errors.New("port param is not a float64")
|
||||
}
|
||||
if public, ok := req[1].(bool); ok {
|
||||
if public != true {
|
||||
return nil, errors.New("pulic param is not expected value")
|
||||
}
|
||||
} else {
|
||||
return nil, errors.New("public param is not a bool")
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
testServer, client, err := makeMockJoinedClient(
|
||||
livesharetest.WithService("serverSharing.updateSharedServerVisibility", updateSharedVisibility),
|
||||
)
|
||||
if err != nil {
|
||||
t.Errorf("creating new mock client: %v", err)
|
||||
}
|
||||
defer testServer.Close()
|
||||
server, err := NewServer(client)
|
||||
if err != nil {
|
||||
t.Errorf("creating server: %v", err)
|
||||
}
|
||||
ctx := context.Background()
|
||||
done := make(chan error)
|
||||
go func() {
|
||||
if err := server.UpdateSharedVisibility(ctx, 80, true); err != nil {
|
||||
done <- err
|
||||
return
|
||||
}
|
||||
done <- nil
|
||||
}()
|
||||
select {
|
||||
case err := <-testServer.Err():
|
||||
t.Errorf("error from server: %v", err)
|
||||
case err := <-done:
|
||||
if err != nil {
|
||||
t.Errorf("error from client: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
60
session.go
60
session.go
|
|
@ -1,60 +0,0 @@
|
|||
package liveshare
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"strings"
|
||||
|
||||
"golang.org/x/sync/errgroup"
|
||||
)
|
||||
|
||||
type session struct {
|
||||
api *api
|
||||
|
||||
workspaceAccess *workspaceAccessResponse
|
||||
workspaceInfo *workspaceInfoResponse
|
||||
}
|
||||
|
||||
func newSession(api *api) *session {
|
||||
return &session{api: api}
|
||||
}
|
||||
|
||||
func (s *session) init(ctx context.Context) error {
|
||||
g, ctx := errgroup.WithContext(ctx)
|
||||
|
||||
g.Go(func() error {
|
||||
workspaceAccess, err := s.api.workspaceAccess()
|
||||
if err != nil {
|
||||
return fmt.Errorf("error getting workspace access: %v", err)
|
||||
}
|
||||
s.workspaceAccess = workspaceAccess
|
||||
return nil
|
||||
})
|
||||
|
||||
g.Go(func() error {
|
||||
workspaceInfo, err := s.api.workspaceInfo()
|
||||
if err != nil {
|
||||
return fmt.Errorf("error getting workspace info: %v", err)
|
||||
}
|
||||
s.workspaceInfo = workspaceInfo
|
||||
return nil
|
||||
})
|
||||
|
||||
if err := g.Wait(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Reference:
|
||||
// https://github.com/Azure/azure-relay-node/blob/7b57225365df3010163bf4b9e640868a02737eb6/hyco-ws/index.js#L107-L137
|
||||
func (s *session) relayURI(action string) string {
|
||||
relaySas := url.QueryEscape(s.workspaceAccess.RelaySas)
|
||||
relayURI := s.workspaceAccess.RelayLink
|
||||
relayURI = strings.Replace(relayURI, "sb:", "wss:", -1)
|
||||
relayURI = strings.Replace(relayURI, ".net/", ".net:443/$hc/", 1)
|
||||
relayURI = relayURI + "?sb-hc-action=" + action + "&sb-hc-token=" + relaySas
|
||||
return relayURI
|
||||
}
|
||||
100
socket.go
Normal file
100
socket.go
Normal file
|
|
@ -0,0 +1,100 @@
|
|||
package liveshare
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
)
|
||||
|
||||
type socket struct {
|
||||
addr string
|
||||
tlsConfig *tls.Config
|
||||
|
||||
conn *websocket.Conn
|
||||
reader io.Reader
|
||||
}
|
||||
|
||||
func newSocket(clientConn Connection, tlsConfig *tls.Config) *socket {
|
||||
return &socket{addr: clientConn.uri("connect"), tlsConfig: tlsConfig}
|
||||
}
|
||||
|
||||
func (s *socket) connect(ctx context.Context) error {
|
||||
dialer := websocket.Dialer{
|
||||
Proxy: http.ProxyFromEnvironment,
|
||||
HandshakeTimeout: 45 * time.Second,
|
||||
TLSClientConfig: s.tlsConfig,
|
||||
}
|
||||
ws, _, err := dialer.Dial(s.addr, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
s.conn = ws
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *socket) Read(b []byte) (int, error) {
|
||||
if s.reader == nil {
|
||||
_, reader, err := s.conn.NextReader()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
s.reader = reader
|
||||
}
|
||||
|
||||
bytesRead, err := s.reader.Read(b)
|
||||
if err != nil {
|
||||
s.reader = nil
|
||||
|
||||
if err == io.EOF {
|
||||
err = nil
|
||||
}
|
||||
}
|
||||
|
||||
return bytesRead, err
|
||||
}
|
||||
|
||||
func (s *socket) Write(b []byte) (int, error) {
|
||||
nextWriter, err := s.conn.NextWriter(websocket.BinaryMessage)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
bytesWritten, err := nextWriter.Write(b)
|
||||
nextWriter.Close()
|
||||
|
||||
return bytesWritten, err
|
||||
}
|
||||
|
||||
func (s *socket) Close() error {
|
||||
return s.conn.Close()
|
||||
}
|
||||
|
||||
func (s *socket) LocalAddr() net.Addr {
|
||||
return s.conn.LocalAddr()
|
||||
}
|
||||
|
||||
func (s *socket) RemoteAddr() net.Addr {
|
||||
return s.conn.RemoteAddr()
|
||||
}
|
||||
|
||||
func (s *socket) SetDeadline(t time.Time) error {
|
||||
if err := s.SetReadDeadline(t); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return s.SetWriteDeadline(t)
|
||||
}
|
||||
|
||||
func (s *socket) SetReadDeadline(t time.Time) error {
|
||||
return s.conn.SetReadDeadline(t)
|
||||
}
|
||||
|
||||
func (s *socket) SetWriteDeadline(t time.Time) error {
|
||||
return s.conn.SetWriteDeadline(t)
|
||||
}
|
||||
16
ssh.go
16
ssh.go
|
|
@ -12,22 +12,22 @@ import (
|
|||
|
||||
type sshSession struct {
|
||||
*ssh.Session
|
||||
session *session
|
||||
socket net.Conn
|
||||
conn ssh.Conn
|
||||
reader io.Reader
|
||||
writer io.Writer
|
||||
token string
|
||||
socket net.Conn
|
||||
conn ssh.Conn
|
||||
reader io.Reader
|
||||
writer io.Writer
|
||||
}
|
||||
|
||||
func newSSH(session *session, socket net.Conn) *sshSession {
|
||||
return &sshSession{session: session, socket: socket}
|
||||
func newSshSession(token string, socket net.Conn) *sshSession {
|
||||
return &sshSession{token: token, socket: socket}
|
||||
}
|
||||
|
||||
func (s *sshSession) connect(ctx context.Context) error {
|
||||
clientConfig := ssh.ClientConfig{
|
||||
User: "",
|
||||
Auth: []ssh.AuthMethod{
|
||||
ssh.Password(s.session.workspaceAccess.SessionToken),
|
||||
ssh.Password(s.token),
|
||||
},
|
||||
HostKeyAlgorithms: []string{"rsa-sha2-512", "rsa-sha2-256"},
|
||||
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
|
||||
|
|
|
|||
|
|
@ -13,13 +13,13 @@ type Terminal struct {
|
|||
client *Client
|
||||
}
|
||||
|
||||
func (c *Client) NewTerminal() (*Terminal, error) {
|
||||
if !c.hasJoined() {
|
||||
return nil, errors.New("LiveShareClient must join before creating terminal")
|
||||
func NewTerminal(client *Client) (*Terminal, error) {
|
||||
if !client.hasJoined() {
|
||||
return nil, errors.New("client must join before creating terminal")
|
||||
}
|
||||
|
||||
return &Terminal{
|
||||
client: c,
|
||||
client: client,
|
||||
}, nil
|
||||
}
|
||||
|
||||
|
|
|
|||
290
test/server.go
Normal file
290
test/server.go
Normal file
|
|
@ -0,0 +1,290 @@
|
|||
package livesharetest
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/sourcegraph/jsonrpc2"
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
type Server struct {
|
||||
password string
|
||||
services map[string]RpcHandleFunc
|
||||
relaySAS string
|
||||
streams map[string]io.ReadWriter
|
||||
|
||||
sshConfig *ssh.ServerConfig
|
||||
httptestServer *httptest.Server
|
||||
errCh chan error
|
||||
}
|
||||
|
||||
func NewServer(opts ...ServerOption) (*Server, error) {
|
||||
server := new(Server)
|
||||
|
||||
for _, o := range opts {
|
||||
if err := o(server); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
server.sshConfig = &ssh.ServerConfig{
|
||||
PasswordCallback: sshPasswordCallback(server.password),
|
||||
}
|
||||
b, err := ioutil.ReadFile(filepath.Join("test", "private.key"))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error reading private.key: %v", err)
|
||||
}
|
||||
privateKey, err := ssh.ParsePrivateKey(b)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error parsing key: %v", err)
|
||||
}
|
||||
server.sshConfig.AddHostKey(privateKey)
|
||||
|
||||
server.errCh = make(chan error)
|
||||
server.httptestServer = httptest.NewTLSServer(http.HandlerFunc(makeConnection(server)))
|
||||
return server, nil
|
||||
}
|
||||
|
||||
type ServerOption func(*Server) error
|
||||
|
||||
func WithPassword(password string) ServerOption {
|
||||
return func(s *Server) error {
|
||||
s.password = password
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func WithService(serviceName string, handler RpcHandleFunc) ServerOption {
|
||||
return func(s *Server) error {
|
||||
if s.services == nil {
|
||||
s.services = make(map[string]RpcHandleFunc)
|
||||
}
|
||||
|
||||
s.services[serviceName] = handler
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func WithRelaySAS(sas string) ServerOption {
|
||||
return func(s *Server) error {
|
||||
s.relaySAS = sas
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func WithStream(name string, stream io.ReadWriter) ServerOption {
|
||||
return func(s *Server) error {
|
||||
if s.streams == nil {
|
||||
s.streams = make(map[string]io.ReadWriter)
|
||||
}
|
||||
s.streams[name] = stream
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func sshPasswordCallback(serverPassword string) func(ssh.ConnMetadata, []byte) (*ssh.Permissions, error) {
|
||||
return func(conn ssh.ConnMetadata, password []byte) (*ssh.Permissions, error) {
|
||||
if string(password) == serverPassword {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, errors.New("password rejected")
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) Close() {
|
||||
s.httptestServer.Close()
|
||||
}
|
||||
|
||||
func (s *Server) URL() string {
|
||||
return s.httptestServer.URL
|
||||
}
|
||||
|
||||
func (s *Server) Err() <-chan error {
|
||||
return s.errCh
|
||||
}
|
||||
|
||||
var upgrader = websocket.Upgrader{}
|
||||
|
||||
func makeConnection(server *Server) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, req *http.Request) {
|
||||
if server.relaySAS != "" {
|
||||
// validate the sas key
|
||||
sasParam := req.URL.Query().Get("sb-hc-token")
|
||||
if sasParam != server.relaySAS {
|
||||
server.errCh <- errors.New("error validating sas")
|
||||
return
|
||||
}
|
||||
}
|
||||
c, err := upgrader.Upgrade(w, req, nil)
|
||||
if err != nil {
|
||||
server.errCh <- fmt.Errorf("error upgrading connection: %v", err)
|
||||
return
|
||||
}
|
||||
defer c.Close()
|
||||
|
||||
socketConn := newSocketConn(c)
|
||||
_, chans, reqs, err := ssh.NewServerConn(socketConn, server.sshConfig)
|
||||
if err != nil {
|
||||
server.errCh <- fmt.Errorf("error creating new ssh conn: %v", err)
|
||||
return
|
||||
}
|
||||
go ssh.DiscardRequests(reqs)
|
||||
|
||||
for newChannel := range chans {
|
||||
ch, reqs, err := newChannel.Accept()
|
||||
if err != nil {
|
||||
server.errCh <- fmt.Errorf("error accepting new channel: %v", err)
|
||||
return
|
||||
}
|
||||
go handleNewRequests(server, ch, reqs)
|
||||
go handleNewChannel(server, ch)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func handleNewRequests(server *Server, channel ssh.Channel, reqs <-chan *ssh.Request) {
|
||||
for req := range reqs {
|
||||
if req.WantReply {
|
||||
if err := req.Reply(true, nil); err != nil {
|
||||
server.errCh <- fmt.Errorf("error replying to channel request: %v", err)
|
||||
}
|
||||
}
|
||||
if strings.HasPrefix(req.Type, "stream-transport") {
|
||||
forwardStream(server, req.Type, channel)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func forwardStream(server *Server, streamName string, channel ssh.Channel) {
|
||||
simpleStreamName := strings.TrimPrefix(streamName, "stream-transport-")
|
||||
stream, found := server.streams[simpleStreamName]
|
||||
if !found {
|
||||
server.errCh <- fmt.Errorf("stream '%v' not found", simpleStreamName)
|
||||
return
|
||||
}
|
||||
|
||||
copy := func(dst io.Writer, src io.Reader) {
|
||||
if _, err := io.Copy(dst, src); err != nil {
|
||||
fmt.Println(err)
|
||||
server.errCh <- fmt.Errorf("io copy: %v", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
go copy(stream, channel)
|
||||
go copy(channel, stream)
|
||||
|
||||
for {
|
||||
}
|
||||
}
|
||||
|
||||
func handleNewChannel(server *Server, channel ssh.Channel) {
|
||||
stream := jsonrpc2.NewBufferedStream(channel, jsonrpc2.VSCodeObjectCodec{})
|
||||
jsonrpc2.NewConn(context.Background(), stream, newRpcHandler(server))
|
||||
}
|
||||
|
||||
type RpcHandleFunc func(req *jsonrpc2.Request) (interface{}, error)
|
||||
|
||||
type rpcHandler struct {
|
||||
server *Server
|
||||
}
|
||||
|
||||
func newRpcHandler(server *Server) *rpcHandler {
|
||||
return &rpcHandler{server}
|
||||
}
|
||||
|
||||
func (r *rpcHandler) Handle(ctx context.Context, conn *jsonrpc2.Conn, req *jsonrpc2.Request) {
|
||||
handler, found := r.server.services[req.Method]
|
||||
if !found {
|
||||
r.server.errCh <- fmt.Errorf("RPC Method: '%v' not serviced", req.Method)
|
||||
return
|
||||
}
|
||||
|
||||
result, err := handler(req)
|
||||
if err != nil {
|
||||
r.server.errCh <- fmt.Errorf("error handling: '%v': %v", req.Method, err)
|
||||
return
|
||||
}
|
||||
|
||||
if err := conn.Reply(ctx, req.ID, result); err != nil {
|
||||
r.server.errCh <- fmt.Errorf("error replying: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
type socketConn struct {
|
||||
*websocket.Conn
|
||||
|
||||
reader io.Reader
|
||||
writeMutex sync.Mutex
|
||||
readMutex sync.Mutex
|
||||
}
|
||||
|
||||
func newSocketConn(conn *websocket.Conn) *socketConn {
|
||||
return &socketConn{Conn: conn}
|
||||
}
|
||||
|
||||
func (s *socketConn) Read(b []byte) (int, error) {
|
||||
s.readMutex.Lock()
|
||||
defer s.readMutex.Unlock()
|
||||
|
||||
if s.reader == nil {
|
||||
msgType, r, err := s.Conn.NextReader()
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("error getting next reader: %v", err)
|
||||
}
|
||||
if msgType != websocket.BinaryMessage {
|
||||
return 0, fmt.Errorf("invalid message type")
|
||||
}
|
||||
s.reader = r
|
||||
}
|
||||
|
||||
bytesRead, err := s.reader.Read(b)
|
||||
if err != nil {
|
||||
s.reader = nil
|
||||
|
||||
if err == io.EOF {
|
||||
err = nil
|
||||
}
|
||||
}
|
||||
|
||||
return bytesRead, err
|
||||
}
|
||||
|
||||
func (s *socketConn) Write(b []byte) (int, error) {
|
||||
s.writeMutex.Lock()
|
||||
defer s.writeMutex.Unlock()
|
||||
|
||||
w, err := s.Conn.NextWriter(websocket.BinaryMessage)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("error getting next writer: %v", err)
|
||||
}
|
||||
|
||||
n, err := w.Write(b)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("error writing: %v", err)
|
||||
}
|
||||
|
||||
if err := w.Close(); err != nil {
|
||||
return 0, fmt.Errorf("error closing writer: %v", err)
|
||||
}
|
||||
|
||||
return n, nil
|
||||
}
|
||||
|
||||
func (s *socketConn) SetDeadline(deadline time.Time) error {
|
||||
if err := s.Conn.SetReadDeadline(deadline); err != nil {
|
||||
return err
|
||||
}
|
||||
return s.Conn.SetWriteDeadline(deadline)
|
||||
}
|
||||
105
websocket.go
105
websocket.go
|
|
@ -1,105 +0,0 @@
|
|||
package liveshare
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
gorillawebsocket "github.com/gorilla/websocket"
|
||||
)
|
||||
|
||||
type websocket struct {
|
||||
session *session
|
||||
conn *gorillawebsocket.Conn
|
||||
readMutex sync.Mutex
|
||||
writeMutex sync.Mutex
|
||||
reader io.Reader
|
||||
}
|
||||
|
||||
func newWebsocket(session *session) *websocket {
|
||||
return &websocket{session: session}
|
||||
}
|
||||
|
||||
func (w *websocket) connect(ctx context.Context) error {
|
||||
ws, _, err := gorillawebsocket.DefaultDialer.Dial(w.session.relayURI("connect"), nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
w.conn = ws
|
||||
return nil
|
||||
}
|
||||
|
||||
func (w *websocket) Read(b []byte) (int, error) {
|
||||
w.readMutex.Lock()
|
||||
defer w.readMutex.Unlock()
|
||||
|
||||
if w.reader == nil {
|
||||
messageType, reader, err := w.conn.NextReader()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
if messageType != gorillawebsocket.BinaryMessage {
|
||||
return 0, errors.New("unexpected websocket message type")
|
||||
}
|
||||
|
||||
w.reader = reader
|
||||
}
|
||||
|
||||
bytesRead, err := w.reader.Read(b)
|
||||
if err != nil {
|
||||
w.reader = nil
|
||||
|
||||
if err == io.EOF {
|
||||
err = nil
|
||||
}
|
||||
}
|
||||
|
||||
return bytesRead, err
|
||||
}
|
||||
|
||||
func (w *websocket) Write(b []byte) (int, error) {
|
||||
w.writeMutex.Lock()
|
||||
defer w.writeMutex.Unlock()
|
||||
|
||||
nextWriter, err := w.conn.NextWriter(gorillawebsocket.BinaryMessage)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
bytesWritten, err := nextWriter.Write(b)
|
||||
nextWriter.Close()
|
||||
|
||||
return bytesWritten, err
|
||||
}
|
||||
|
||||
func (w *websocket) Close() error {
|
||||
return w.conn.Close()
|
||||
}
|
||||
|
||||
func (w *websocket) LocalAddr() net.Addr {
|
||||
return w.conn.LocalAddr()
|
||||
}
|
||||
|
||||
func (w *websocket) RemoteAddr() net.Addr {
|
||||
return w.conn.RemoteAddr()
|
||||
}
|
||||
|
||||
func (w *websocket) SetDeadline(t time.Time) error {
|
||||
if err := w.SetReadDeadline(t); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return w.SetWriteDeadline(t)
|
||||
}
|
||||
|
||||
func (w *websocket) SetReadDeadline(t time.Time) error {
|
||||
return w.conn.SetReadDeadline(t)
|
||||
}
|
||||
|
||||
func (w *websocket) SetWriteDeadline(t time.Time) error {
|
||||
return w.conn.SetWriteDeadline(t)
|
||||
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue