diff --git a/adapter.go b/adapter.go deleted file mode 100644 index fb3424734..000000000 --- a/adapter.go +++ /dev/null @@ -1,100 +0,0 @@ -package liveshare - -import ( - "errors" - "io" - "net" - "sync" - "time" - - "github.com/gorilla/websocket" -) - -type Adapter struct { - conn *websocket.Conn - readMutex sync.Mutex - writeMutex sync.Mutex - reader io.Reader -} - -func NewAdapter(conn *websocket.Conn) *Adapter { - return &Adapter{ - conn: conn, - } -} - -func (a *Adapter) Read(b []byte) (int, error) { - // Read() can be called concurrently, and we mutate some internal state here - a.readMutex.Lock() - defer a.readMutex.Unlock() - - if a.reader == nil { - messageType, reader, err := a.conn.NextReader() - if err != nil { - return 0, err - } - - if messageType != websocket.BinaryMessage { - return 0, errors.New("unexpected websocket message type") - } - - a.reader = reader - } - - bytesRead, err := a.reader.Read(b) - if err != nil { - a.reader = nil - - // EOF for the current Websocket frame, more will probably come so.. - if err == io.EOF { - // .. we must hide this from the caller since our semantics are a - // stream of bytes across many frames - err = nil - } - } - - return bytesRead, err -} - -func (a *Adapter) Write(b []byte) (int, error) { - a.writeMutex.Lock() - defer a.writeMutex.Unlock() - - nextWriter, err := a.conn.NextWriter(websocket.BinaryMessage) - if err != nil { - return 0, err - } - - bytesWritten, err := nextWriter.Write(b) - nextWriter.Close() - - return bytesWritten, err -} - -func (a *Adapter) Close() error { - return a.conn.Close() -} - -func (a *Adapter) LocalAddr() net.Addr { - return a.conn.LocalAddr() -} - -func (a *Adapter) RemoteAddr() net.Addr { - return a.conn.RemoteAddr() -} - -func (a *Adapter) SetDeadline(t time.Time) error { - if err := a.SetReadDeadline(t); err != nil { - return err - } - - return a.SetWriteDeadline(t) -} - -func (a *Adapter) SetReadDeadline(t time.Time) error { - return a.conn.SetReadDeadline(t) -} - -func (a *Adapter) SetWriteDeadline(t time.Time) error { - return a.conn.SetWriteDeadline(t) -} diff --git a/api.go b/api.go index d8a4bebbd..a101823df 100644 --- a/api.go +++ b/api.go @@ -5,21 +5,20 @@ import ( "fmt" "io/ioutil" "net/http" - "net/http/httputil" "strings" ) -type API struct { - Configuration *Configuration - HttpClient *http.Client - ServiceURI string - WorkspaceID string +type api struct { + client *Client + httpClient *http.Client + serviceURI string + workspaceID string } -func NewAPI(configuration *Configuration) *API { - serviceURI := configuration.LiveShareEndpoint - if !strings.HasSuffix(configuration.LiveShareEndpoint, "/") { - serviceURI = configuration.LiveShareEndpoint + "/" +func newAPI(client *Client) *api { + serviceURI := client.liveShare.Configuration.LiveShareEndpoint + if !strings.HasSuffix(client.liveShare.Configuration.LiveShareEndpoint, "/") { + serviceURI = client.liveShare.Configuration.LiveShareEndpoint + "/" } if !strings.Contains(serviceURI, "api/v1.2") { @@ -28,10 +27,10 @@ func NewAPI(configuration *Configuration) *API { serviceURI = strings.TrimSuffix(serviceURI, "/") - return &API{configuration, &http.Client{}, serviceURI, strings.ToUpper(configuration.WorkspaceID)} + return &api{client, &http.Client{}, serviceURI, strings.ToUpper(client.liveShare.Configuration.WorkspaceID)} } -type WorkspaceAccessResponse struct { +type workspaceAccessResponse struct { SessionToken string `json:"sessionToken"` CreatedAt string `json:"createdAt"` UpdatedAt string `json:"updatedAt"` @@ -51,8 +50,8 @@ type WorkspaceAccessResponse struct { ID string `json:"id"` } -func (a *API) WorkspaceAccess() (*WorkspaceAccessResponse, error) { - url := fmt.Sprintf("%s/workspace/%s/user", a.ServiceURI, a.WorkspaceID) +func (a *api) workspaceAccess() (*workspaceAccessResponse, error) { + url := fmt.Sprintf("%s/workspace/%s/user", a.serviceURI, a.workspaceID) fmt.Println(url) req, err := http.NewRequest(http.MethodPut, url, nil) @@ -61,7 +60,7 @@ func (a *API) WorkspaceAccess() (*WorkspaceAccessResponse, error) { } a.setDefaultHeaders(req) - resp, err := a.HttpClient.Do(req) + resp, err := a.httpClient.Do(req) if err != nil { return nil, fmt.Errorf("error making request: %v", err) } @@ -71,23 +70,21 @@ func (a *API) WorkspaceAccess() (*WorkspaceAccessResponse, error) { return nil, fmt.Errorf("error reading response body: %v", err) } - d, _ := httputil.DumpResponse(resp, true) - fmt.Println(string(d)) - var workspaceAccessResponse WorkspaceAccessResponse - if err := json.Unmarshal(b, &workspaceAccessResponse); err != nil { + var response workspaceAccessResponse + if err := json.Unmarshal(b, &response); err != nil { return nil, fmt.Errorf("error unmarshaling response into json: %v", err) } - return &workspaceAccessResponse, nil + return &response, nil } -func (a *API) setDefaultHeaders(req *http.Request) { - req.Header.Set("Authorization", "Bearer "+a.Configuration.Token) +func (a *api) setDefaultHeaders(req *http.Request) { + req.Header.Set("Authorization", "Bearer "+a.client.liveShare.Configuration.Token) req.Header.Set("Cache-Control", "no-cache") req.Header.Set("Content-Type", "application/json") } -type WorkspaceInfoResponse struct { +type workspaceInfoResponse struct { CreatedAt string `json:"createdAt"` UpdatedAt string `json:"updatedAt"` Name string `json:"name"` @@ -106,8 +103,8 @@ type WorkspaceInfoResponse struct { ID string `json:"id"` } -func (a *API) WorkspaceInfo() (*WorkspaceInfoResponse, error) { - url := fmt.Sprintf("%s/workspace/%s", a.ServiceURI, a.WorkspaceID) +func (a *api) workspaceInfo() (*workspaceInfoResponse, error) { + url := fmt.Sprintf("%s/workspace/%s", a.serviceURI, a.workspaceID) req, err := http.NewRequest(http.MethodGet, url, nil) if err != nil { @@ -115,7 +112,7 @@ func (a *API) WorkspaceInfo() (*WorkspaceInfoResponse, error) { } a.setDefaultHeaders(req) - resp, err := a.HttpClient.Do(req) + resp, err := a.httpClient.Do(req) if err != nil { return nil, fmt.Errorf("error making request: %v", err) } @@ -125,10 +122,10 @@ func (a *API) WorkspaceInfo() (*WorkspaceInfoResponse, error) { return nil, fmt.Errorf("error reading response body: %v", err) } - var workspaceInfoResponse WorkspaceInfoResponse - if err := json.Unmarshal(b, &workspaceInfoResponse); err != nil { + var response workspaceInfoResponse + if err := json.Unmarshal(b, &response); err != nil { return nil, fmt.Errorf("error unmarshaling response into json: %v", err) } - return &workspaceInfoResponse, nil + return &response, nil } diff --git a/client.go b/client.go index 0a89a125c..88c91ba01 100644 --- a/client.go +++ b/client.go @@ -3,27 +3,122 @@ package liveshare import ( "context" "fmt" + "log" + + "golang.org/x/crypto/ssh" ) type Client struct { - Configuration *Configuration - SSHSession *SSHSession + liveShare *LiveShare + session *session + sshSession *sshSession + rpc *rpc } -func NewClient(configuration *Configuration) *Client { - return &Client{Configuration: configuration} +// NewClient is a function ... +func (l *LiveShare) NewClient() *Client { + return &Client{liveShare: l} } -func (c *Client) Join(ctx context.Context) error { - session, err := GetSession(ctx, c.Configuration) - if err != nil { - return fmt.Errorf("error getting session: %v", err) +func (c *Client) Join(ctx context.Context) (err error) { + api := newAPI(c) + + c.session = newSession(api) + if err := c.session.init(ctx); err != nil { + return fmt.Errorf("error creating session: %v", err) } - c.SSHSession, err = NewSSH(session).NewSession() - if err != nil { + websocket := newWebsocket(c.session) + if err := websocket.connect(ctx); err != nil { + return fmt.Errorf("error connecting websocket: %v", err) + } + + c.sshSession = newSSH(c.session, websocket) + if err := c.sshSession.connect(ctx); err != nil { return fmt.Errorf("error connecting to ssh session: %v", err) } + c.rpc = newRPC(c.sshSession) + c.rpc.connect(ctx) + + _, err = c.joinWorkspace(ctx) + if err != nil { + return fmt.Errorf("error joining liveshare workspace: %v", err) + } + return nil } + +func (c *Client) hasJoined() bool { + return c.sshSession != nil && c.rpc != nil +} + +type clientCapabilities struct { + IsNonInteractive bool `json:"isNonInteractive"` +} + +type joinWorkspaceArgs struct { + ID string `json:"id"` + ConnectionMode string `json:"connectionMode"` + JoiningUserSessionToken string `json:"joiningUserSessionToken"` + ClientCapabilities clientCapabilities `json:"clientCapabilities"` +} + +type joinWorkspaceResult struct { + SessionNumber int `json:"sessionNumber"` +} + +func (c *Client) joinWorkspace(ctx context.Context) (*joinWorkspaceResult, error) { + args := joinWorkspaceArgs{ + ID: c.session.workspaceInfo.ID, + ConnectionMode: "local", + JoiningUserSessionToken: c.session.workspaceAccess.SessionToken, + ClientCapabilities: clientCapabilities{ + IsNonInteractive: false, + }, + } + + var result joinWorkspaceResult + if err := c.rpc.do(ctx, "workspace.joinWorkspace", &args, &result); err != nil { + return nil, fmt.Errorf("error making workspace.joinWorkspace call: %v", err) + } + + return &result, nil +} + +func (c *Client) openStreamingChannel(ctx context.Context, streamName, condition string) (ssh.Channel, error) { + args := getStreamArgs{streamName, condition} + var streamID string + if err := c.rpc.do(ctx, "streamManager.getStream", args, &streamID); err != nil { + return nil, fmt.Errorf("error getting stream id: %v", err) + } + + channel, reqs, err := c.sshSession.conn.OpenChannel("session", nil) + if err != nil { + return nil, fmt.Errorf("error opening ssh channel for transport: %v", err) + } + go c.processChannelRequests(ctx, reqs) + + requestType := fmt.Sprintf("stream-transport-%s", streamID) + acked, err := channel.SendRequest(requestType, true, nil) + if err != nil { + return nil, fmt.Errorf("error sending channel request: %v", err) + } + fmt.Println("ACKED: ", acked) + + return channel, nil +} + +func (c *Client) processChannelRequests(ctx context.Context, reqs <-chan *ssh.Request) { + for { + select { + case req := <-reqs: + if req != nil { + fmt.Printf("REQ: %+v\n\n", req) + log.Println("streaming channel requests are not supported") + } + case <-ctx.Done(): + break + } + } +} diff --git a/example/main.go b/example/main.go index bbf540d7c..b1fda8d32 100644 --- a/example/main.go +++ b/example/main.go @@ -1,36 +1,117 @@ package main import ( + "bufio" "context" + "flag" "fmt" "log" + "os" - "github.com/josebalius/go-liveshare" + "github.com/github/go-liveshare" ) +var workspaceIdFlag = flag.String("w", "", "workspace session id") + +func init() { + flag.Parse() +} + func main() { liveShare, err := liveshare.New( - liveshare.WithWorkspaceID(""), - liveshare.WithToken(""), + liveshare.WithWorkspaceID(*workspaceIdFlag), + liveshare.WithToken(os.Getenv("CODESPACE_TOKEN")), ) if err != nil { log.Fatal(fmt.Errorf("error creating liveshare: %v", err)) } - if err := liveShare.Connect(context.Background()); err != nil { - log.Fatal(fmt.Errorf("error connecting to liveshare: %v", err)) + ctx := context.Background() + liveShareClient := liveShare.NewClient() + if err := liveShareClient.Join(ctx); err != nil { + log.Fatal(fmt.Errorf("error joining liveshare with client: %v", err)) } - terminal := liveShare.NewTerminal() - - cmd := terminal.NewCommand( - "/home/codespace/workspace", - "docker ps -aq --filter label=Type=codespaces --filter status=running", - ) - output, err := cmd.Run(context.Background()) + terminal, err := liveShareClient.NewTerminal() if err != nil { - log.Fatal(fmt.Errorf("error starting ssh server with liveshare: %v", err)) + log.Fatal(fmt.Errorf("error creating liveshare terminal")) } - fmt.Println(string(output)) + containerID, err := getContainerID(ctx, terminal) + if err != nil { + log.Fatal(fmt.Errorf("error getting container id: %v", err)) + } + + if err := setupSSH(ctx, terminal, containerID); err != nil { + log.Fatal(fmt.Errorf("error setting up ssh: %v", err)) + } + + fmt.Println("Starting server...") + + server, err := liveShareClient.NewServer() + if err != nil { + log.Fatal(fmt.Errorf("error creating server: %v", err)) + } + + fmt.Println("Starting sharing...") + if err := server.StartSharing(ctx, "sshd", 2222); err != nil { + log.Fatal(fmt.Errorf("error server sharing: %v", err)) + } + + portForwarder := liveshare.NewLocalPortForwarder(liveShareClient, server, 2222) + + fmt.Println("Listening on port 2222") + if err := portForwarder.Start(ctx); err != nil { + log.Fatal(fmt.Errorf("error forwarding port: %v", err)) + } +} + +func setupSSH(ctx context.Context, terminal *liveshare.Terminal, containerID string) error { + cmd := terminal.NewCommand( + "/", + fmt.Sprintf("/usr/bin/docker exec -t %s /bin/bash -c \"echo -e \\\"testpwd1\\ntestpwd1\\n\\\" | sudo passwd codespace;/usr/local/share/ssh-init.sh\"", containerID), + ) + stream, err := cmd.Run(ctx) + if err != nil { + return fmt.Errorf("error running command: %v", err) + } + + scanner := bufio.NewScanner(stream) + scanner.Scan() + + fmt.Println("> Debug:", scanner.Text()) + if err := scanner.Err(); err != nil { + return fmt.Errorf("error scanning stream: %v", err) + } + + if err := stream.Close(); err != nil { + return fmt.Errorf("error closing stream: %v", err) + } + + return nil +} + +func getContainerID(ctx context.Context, terminal *liveshare.Terminal) (string, error) { + cmd := terminal.NewCommand( + "/", + "/usr/bin/docker ps -aq --filter label=Type=codespaces --filter status=running", + ) + stream, err := cmd.Run(ctx) + if err != nil { + return "", fmt.Errorf("error running command: %v", err) + } + + scanner := bufio.NewScanner(stream) + scanner.Scan() + + containerID := scanner.Text() + if err := scanner.Err(); err != nil { + return "", fmt.Errorf("error scanning stream: %v", err) + } + + if err := stream.Close(); err != nil { + return "", fmt.Errorf("error closing stream: %v", err) + } + + return containerID, nil } diff --git a/liveshare.go b/liveshare.go index 174eac20f..38222957a 100644 --- a/liveshare.go +++ b/liveshare.go @@ -1,16 +1,11 @@ package liveshare import ( - "context" "fmt" - "net/rpc" ) type LiveShare struct { Configuration *Configuration - - workspaceClient *Client - terminal *Terminal } func New(opts ...Option) (*LiveShare, error) { @@ -28,65 +23,3 @@ func New(opts ...Option) (*LiveShare, error) { return &LiveShare{Configuration: configuration}, nil } - -func (l *LiveShare) Connect(ctx context.Context) error { - l.workspaceClient = NewClient(l.Configuration) - if err := l.workspaceClient.Join(ctx); err != nil { - return fmt.Errorf("error joining with workspace client: %v", err) - } - - return nil -} - -type Terminal struct { - WorkspaceClient *Client - RPCClient *rpc.Client -} - -func (l *LiveShare) NewTerminal() *Terminal { - return &Terminal{ - WorkspaceClient: l.workspaceClient, - RPCClient: rpc.NewClient(l.workspaceClient.SSHSession), - } -} - -type TerminalCommand struct { - Terminal *Terminal - Cwd string - Cmd string -} - -func (t *Terminal) NewCommand(cwd, cmd string) TerminalCommand { - return TerminalCommand{t, cwd, cmd} -} - -type RunArgs struct { - Name string - Rows, Cols int - App string - Cwd string - CommandLine []string - ReadOnlyForGuests bool -} - -func (t TerminalCommand) Run(ctx context.Context) ([]byte, error) { - args := RunArgs{ - Name: "RunCommand", - Rows: 10, - Cols: 80, - App: "/bin/bash", - Cwd: t.Cwd, - CommandLine: []string{"-c", t.Cmd}, - ReadOnlyForGuests: false, - } - - var output []byte - runCall := t.Terminal.RPCClient.Go("terminal.startAsync", &args, &output, nil) - - runReply := <-runCall.Done - if runReply.Error != nil { - return nil, fmt.Errorf("error startAsync operation: %v", runReply.Error) - } - fmt.Printf("%+v\n\n", runReply) - return output, nil -} diff --git a/port_forwarder.go b/port_forwarder.go new file mode 100644 index 000000000..0ae5e1916 --- /dev/null +++ b/port_forwarder.go @@ -0,0 +1,63 @@ +package liveshare + +import ( + "context" + "fmt" + "io" + "log" + "net" + "strconv" + + "golang.org/x/crypto/ssh" +) + +type LocalPortForwarder struct { + client *Client + server *Server + port int + channels []ssh.Channel +} + +func NewLocalPortForwarder(client *Client, server *Server, port int) *LocalPortForwarder { + return &LocalPortForwarder{client, server, port, []ssh.Channel{}} +} + +func (l *LocalPortForwarder) Start(ctx context.Context) error { + ln, err := net.Listen("tcp", ":"+strconv.Itoa(l.port)) + if err != nil { + return fmt.Errorf("error listening on tcp port: %v", err) + } + + for { + conn, err := ln.Accept() + if err != nil { + return fmt.Errorf("error accepting incoming connection: %v", err) + } + + go l.handleConnection(ctx, conn) + } + + // clean up after ourselves + + return nil +} + +func (l *LocalPortForwarder) handleConnection(ctx context.Context, conn net.Conn) { + channel, err := l.client.openStreamingChannel(ctx, l.server.streamName, l.server.streamCondition) + if err != nil { + log.Println("errrr handle Connect") + log.Println(err) // TODO(josebalius) handle this somehow + } + l.channels = append(l.channels, channel) + + copyConn := func(writer io.Writer, reader io.Reader) { + _, err := io.Copy(writer, reader) + if err != nil { + log.Println("errrrr copyConn") + log.Println(err) //TODO(josebalius): handle this somehow + } + } + + go copyConn(conn, channel) + go copyConn(channel, conn) +} diff --git a/rpc.go b/rpc.go new file mode 100644 index 000000000..e90f71ba6 --- /dev/null +++ b/rpc.go @@ -0,0 +1,97 @@ +package liveshare + +import ( + "context" + "encoding/json" + "fmt" + "io" + "sync" + + "github.com/sourcegraph/jsonrpc2" +) + +type rpc struct { + *jsonrpc2.Conn + conn io.ReadWriteCloser + handler *rpcHandler +} + +func newRPC(conn io.ReadWriteCloser) *rpc { + return &rpc{conn: conn, handler: newRPCHandler()} +} + +func (r *rpc) connect(ctx context.Context) { + stream := jsonrpc2.NewBufferedStream(r.conn, jsonrpc2.VSCodeObjectCodec{}) + r.Conn = jsonrpc2.NewConn(ctx, stream, r.handler) +} + +func (r *rpc) do(ctx context.Context, method string, args interface{}, result interface{}) error { + b, _ := json.Marshal(args) + fmt.Println("rpc sent: ", method, string(b)) + waiter, err := r.Conn.DispatchCall(ctx, method, args) + if err != nil { + return fmt.Errorf("error on dispatch call: %v", err) + } + + // caller doesn't care about result, so lets ignore it + if result == nil { + return nil + } + + return waiter.Wait(ctx, result) +} + +type rpcHandler struct { + mutex sync.RWMutex + eventHandlers map[string][]chan *jsonrpc2.Request +} + +func newRPCHandler() *rpcHandler { + return &rpcHandler{ + eventHandlers: make(map[string][]chan *jsonrpc2.Request), + } +} + +func (r *rpcHandler) registerEventHandler(eventMethod string) <-chan *jsonrpc2.Request { + r.mutex.Lock() + defer r.mutex.Unlock() + + ch := make(chan *jsonrpc2.Request) + if _, ok := r.eventHandlers[eventMethod]; !ok { + r.eventHandlers[eventMethod] = []chan *jsonrpc2.Request{ch} + } else { + r.eventHandlers[eventMethod] = append(r.eventHandlers[eventMethod], ch) + } + return ch +} + +func (r *rpcHandler) Handle(ctx context.Context, conn *jsonrpc2.Conn, req *jsonrpc2.Request) { + r.mutex.Lock() + defer r.mutex.Unlock() + + fmt.Println("REQUEST") + fmt.Println("Method:", req.Method) + b, _ := req.MarshalJSON() + fmt.Println(string(b)) + fmt.Println("----") + fmt.Printf("%+v\n\n", r.eventHandlers) + if handlers, ok := r.eventHandlers[req.Method]; ok { + go func() { + for _, handler := range handlers { + select { + case handler <- req: + case <-ctx.Done(): + break + } + } + + r.eventHandlers[req.Method] = []chan *jsonrpc2.Request{} + }() + } else { + fmt.Println("UNHANDLED REQUEST") + fmt.Println("Method:", req.Method) + b, _ := req.MarshalJSON() + fmt.Println(string(b)) + fmt.Println("----") + } +} diff --git a/rpc_test.go b/rpc_test.go new file mode 100644 index 000000000..d16b32a4f --- /dev/null +++ b/rpc_test.go @@ -0,0 +1,27 @@ +package liveshare + +import ( + "context" + "testing" + "time" + + "github.com/sourcegraph/jsonrpc2" +) + +func TestRPCHandlerEvents(t *testing.T) { + rpcHandler := newRPCHandler() + eventCh := rpcHandler.registerEventHandler("somethingHappened") + go func() { + time.Sleep(1 * time.Second) + rpcHandler.Handle(context.Background(), nil, &jsonrpc2.Request{Method: "somethingHappened"}) + }() + ctx, _ := context.WithDeadline(context.Background(), time.Now().Add(5*time.Second)) + select { + case event := <-eventCh: + if event.Method != "somethingHappened" { + t.Error("event.Method is not the expect value") + } + case <-ctx.Done(): + t.Error("Test time out") + } +} diff --git a/server.go b/server.go new file mode 100644 index 000000000..65e03d584 --- /dev/null +++ b/server.go @@ -0,0 +1,52 @@ +package liveshare + +import ( + "context" + "errors" + "fmt" + "strconv" +) + +type Server struct { + client *Client + port int + streamName, streamCondition string +} + +func (c *Client) NewServer() (*Server, error) { + if !c.hasJoined() { + return nil, errors.New("LiveShareClient must join before creating server") + } + + return &Server{client: c}, nil +} + +type serverSharingResponse struct { + SourcePort int `json:"sourcePort"` + DestinationPort int `json:"destinationPort"` + SessionName string `json:"sessionName"` + StreamName string `json:"streamName"` + StreamCondition string `json:"streamCondition"` + BrowseURL string `json:"browseUrl"` + IsPublic bool `json:"isPublic"` + IsTCPServerConnectionEstablished bool `json:"isTCPServerConnectionEstablished"` + HasTSLHandshakePassed bool `json:"hasTSLHandshakePassed"` +} + +func (s *Server) StartSharing(ctx context.Context, protocol string, port int) error { + s.port = port + + sharingStarted := s.client.rpc.handler.registerEventHandler("serverSharing.sharingStarted") + var response serverSharingResponse + if err := s.client.rpc.do(ctx, "serverSharing.startSharing", []interface{}{ + port, protocol, fmt.Sprintf("http://localhost:%s", strconv.Itoa(port)), + }, &response); err != nil { + return err + } + <-sharingStarted + + s.streamName = response.StreamName + s.streamCondition = response.StreamCondition + + return nil +} diff --git a/session.go b/session.go index 24a284ef2..d0492a10e 100644 --- a/session.go +++ b/session.go @@ -3,42 +3,58 @@ package liveshare import ( "context" "fmt" + "net/url" + "strings" "golang.org/x/sync/errgroup" ) -type Session struct { - WorkspaceAccess *WorkspaceAccessResponse - WorkspaceInfo *WorkspaceInfoResponse +type session struct { + api *api + + workspaceAccess *workspaceAccessResponse + workspaceInfo *workspaceInfoResponse } -func GetSession(ctx context.Context, configuration *Configuration) (*Session, error) { - api := NewAPI(configuration) - session := new(Session) +func newSession(api *api) *session { + return &session{api: api} +} +func (s *session) init(ctx context.Context) error { g, ctx := errgroup.WithContext(ctx) g.Go(func() error { - workspaceAccess, err := api.WorkspaceAccess() + workspaceAccess, err := s.api.workspaceAccess() if err != nil { return fmt.Errorf("error getting workspace access: %v", err) } - session.WorkspaceAccess = workspaceAccess + s.workspaceAccess = workspaceAccess return nil }) g.Go(func() error { - workspaceInfo, err := api.WorkspaceInfo() + workspaceInfo, err := s.api.workspaceInfo() if err != nil { return fmt.Errorf("error getting workspace info: %v", err) } - session.WorkspaceInfo = workspaceInfo + s.workspaceInfo = workspaceInfo return nil }) if err := g.Wait(); err != nil { - return nil, err + return err } - return session, nil + return nil +} + +// Reference: +// https://github.com/Azure/azure-relay-node/blob/7b57225365df3010163bf4b9e640868a02737eb6/hyco-ws/index.js#L107-L137 +func (s *session) relayURI(action string) string { + relaySas := url.QueryEscape(s.workspaceAccess.RelaySas) + relayURI := s.workspaceAccess.RelayLink + relayURI = strings.Replace(relayURI, "sb:", "wss:", -1) + relayURI = strings.Replace(relayURI, ".net/", ".net:443/$hc/", 1) + relayURI = relayURI + "?sb-hc-action=" + action + "&sb-hc-token=" + relaySas + return relayURI } diff --git a/ssh.go b/ssh.go index 51906290e..9ae32ed7c 100644 --- a/ssh.go +++ b/ssh.go @@ -1,101 +1,68 @@ package liveshare import ( + "context" "fmt" "io" "net" - "net/url" - "strings" "time" - "github.com/gorilla/websocket" "golang.org/x/crypto/ssh" ) -type SSH struct { - Session *Session -} - -func NewSSH(session *Session) *SSH { - return &SSH{ - Session: session, - } -} - -// Reference: -// https://github.com/Azure/azure-relay-node/blob/7b57225365df3010163bf4b9e640868a02737eb6/hyco-ws/index.js#L107-L137 -func (s *SSH) relayURI(action string) string { - relaySas := url.QueryEscape(s.Session.WorkspaceAccess.RelaySas) - relayURI := s.Session.WorkspaceAccess.RelayLink - relayURI = strings.Replace(relayURI, "sb:", "wss:", -1) - relayURI = strings.Replace(relayURI, ".net/", ".net:443/$hc/", 1) - relayURI = relayURI + "?sb-hc-action=" + action + "&sb-hc-token=" + relaySas - return relayURI -} - -func (s *SSH) socketStream() (net.Conn, error) { - uri := s.relayURI("connect") - - ws, _, err := websocket.DefaultDialer.Dial(uri, nil) - if err != nil { - return nil, fmt.Errorf("error dialing websocket connection: %v", err) - } - - return NewAdapter(ws), nil -} - -type SSHSession struct { +type sshSession struct { *ssh.Session - reader io.Reader - writer io.Writer + session *session + socket net.Conn + conn ssh.Conn + reader io.Reader + writer io.Writer } -func (s SSHSession) Read(p []byte) (n int, err error) { - return s.reader.Read(p) +func newSSH(session *session, socket net.Conn) *sshSession { + return &sshSession{session: session, socket: socket} } -func (s SSHSession) Write(p []byte) (n int, err error) { - return s.writer.Write(p) -} - -func (s *SSH) NewSession() (*SSHSession, error) { - socketStream, err := s.socketStream() - if err != nil { - return nil, fmt.Errorf("error creating socket stream: %v", err) - } - +func (s *sshSession) connect(ctx context.Context) error { clientConfig := ssh.ClientConfig{ User: "", Auth: []ssh.AuthMethod{ - ssh.Password(s.Session.WorkspaceAccess.SessionToken), + ssh.Password(s.session.workspaceAccess.SessionToken), }, - HostKeyCallback: func(hostname string, remote net.Addr, key ssh.PublicKey) error { - // TODO(josebalius): implement - return nil - }, - Timeout: 10 * time.Second, + HostKeyAlgorithms: []string{"rsa-sha2-512", "rsa-sha2-256"}, + HostKeyCallback: ssh.InsecureIgnoreHostKey(), + Timeout: 10 * time.Second, } - sshClientConn, chans, reqs, err := ssh.NewClientConn(socketStream, "", &clientConfig) + sshClientConn, chans, reqs, err := ssh.NewClientConn(s.socket, "", &clientConfig) if err != nil { - return nil, fmt.Errorf("error creating ssh client connection: %v", err) + return fmt.Errorf("error creating ssh client connection: %v", err) } + s.conn = sshClientConn sshClient := ssh.NewClient(sshClientConn, chans, reqs) - sshSession, err := sshClient.NewSession() + s.Session, err = sshClient.NewSession() if err != nil { - return nil, fmt.Errorf("error creating ssh client session: %v", err) + return fmt.Errorf("error creating ssh client session: %v", err) } - reader, err := sshSession.StdoutPipe() + s.reader, err = s.Session.StdoutPipe() if err != nil { - return nil, fmt.Errorf("error creating ssh session reader: %v", err) + return fmt.Errorf("error creating ssh session reader: %v", err) } - writer, err := sshSession.StdinPipe() + s.writer, err = s.Session.StdinPipe() if err != nil { - return nil, fmt.Errorf("error creating ssh session writer: %v", err) + return fmt.Errorf("error creating ssh session writer: %v", err) } - return &SSHSession{Session: sshSession, reader: reader, writer: writer}, nil + return nil +} + +func (s *sshSession) Read(p []byte) (n int, err error) { + return s.reader.Read(p) +} + +func (s *sshSession) Write(p []byte) (n int, err error) { + return s.writer.Write(p) } diff --git a/terminal.go b/terminal.go new file mode 100644 index 000000000..631e75912 --- /dev/null +++ b/terminal.go @@ -0,0 +1,116 @@ +package liveshare + +import ( + "context" + "errors" + "fmt" + "io" + + "golang.org/x/crypto/ssh" +) + +type Terminal struct { + client *Client +} + +func (c *Client) NewTerminal() (*Terminal, error) { + if !c.hasJoined() { + return nil, errors.New("LiveShareClient must join before creating terminal") + } + + return &Terminal{ + client: c, + }, nil +} + +type TerminalCommand struct { + terminal *Terminal + cwd string + cmd string +} + +func (t *Terminal) NewCommand(cwd, cmd string) TerminalCommand { + return TerminalCommand{t, cwd, cmd} +} + +type runArgs struct { + Name string `json:"name"` + Rows int `json:"rows"` + Cols int `json:"cols"` + App string `json:"app"` + Cwd string `json:"cwd"` + CommandLine []string `json:"commandLine"` + ReadOnlyForGuests bool `json:"readOnlyForGuests"` +} + +type startTerminalResult struct { + ID int `json:"id"` + StreamName string `json:"streamName"` + StreamCondition string `json:"streamCondition"` + LocalPipeName string `json:"localPipeName"` + AppProcessID int `json:"appProcessId"` +} + +type getStreamArgs struct { + StreamName string `json:"streamName"` + Condition string `json:"condition"` +} + +type stopTerminalArgs struct { + ID int `json:"id"` +} + +func (t TerminalCommand) Run(ctx context.Context) (io.ReadCloser, error) { + args := runArgs{ + Name: "RunCommand", + Rows: 10, + Cols: 80, + App: "/bin/bash", + Cwd: t.cwd, + CommandLine: []string{"-c", t.cmd}, + ReadOnlyForGuests: false, + } + + terminalStarted := t.terminal.client.rpc.handler.registerEventHandler("terminal.terminalStarted") + var result startTerminalResult + if err := t.terminal.client.rpc.do(ctx, "terminal.startTerminal", &args, &result); err != nil { + return nil, fmt.Errorf("error making terminal.startTerminal call: %v", err) + } + <-terminalStarted + + channel, err := t.terminal.client.openStreamingChannel(ctx, result.StreamName, result.StreamCondition) + if err != nil { + return nil, fmt.Errorf("error opening streaming channel: %v", err) + } + + return t.newTerminalReadCloser(result.ID, channel), nil +} + +type terminalReadCloser struct { + terminalCommand TerminalCommand + terminalID int + channel ssh.Channel +} + +func (t TerminalCommand) newTerminalReadCloser(terminalID int, channel ssh.Channel) io.ReadCloser { + return terminalReadCloser{t, terminalID, channel} +} + +func (t terminalReadCloser) Read(b []byte) (int, error) { + return t.channel.Read(b) +} + +func (t terminalReadCloser) Close() error { + terminalStopped := t.terminalCommand.terminal.client.rpc.handler.registerEventHandler("terminal.terminalStopped") + if err := t.terminalCommand.terminal.client.rpc.do(context.Background(), "terminal.stopTerminal", []int{t.terminalID}, nil); err != nil { + return fmt.Errorf("error making terminal.stopTerminal call: %v", err) + } + + if err := t.channel.Close(); err != nil { + return fmt.Errorf("error closing channel: %v", err) + } + + <-terminalStopped + + return nil +} diff --git a/websocket.go b/websocket.go new file mode 100644 index 000000000..ae163e6e2 --- /dev/null +++ b/websocket.go @@ -0,0 +1,105 @@ +package liveshare + +import ( + "context" + "errors" + "io" + "net" + "sync" + "time" + + gorillawebsocket "github.com/gorilla/websocket" +) + +type websocket struct { + session *session + conn *gorillawebsocket.Conn + readMutex sync.Mutex + writeMutex sync.Mutex + reader io.Reader +} + +func newWebsocket(session *session) *websocket { + return &websocket{session: session} +} + +func (w *websocket) connect(ctx context.Context) error { + ws, _, err := gorillawebsocket.DefaultDialer.Dial(w.session.relayURI("connect"), nil) + if err != nil { + return err + } + w.conn = ws + return nil +} + +func (w *websocket) Read(b []byte) (int, error) { + w.readMutex.Lock() + defer w.readMutex.Unlock() + + if w.reader == nil { + messageType, reader, err := w.conn.NextReader() + if err != nil { + return 0, err + } + + if messageType != gorillawebsocket.BinaryMessage { + return 0, errors.New("unexpected websocket message type") + } + + w.reader = reader + } + + bytesRead, err := w.reader.Read(b) + if err != nil { + w.reader = nil + + if err == io.EOF { + err = nil + } + } + + return bytesRead, err +} + +func (w *websocket) Write(b []byte) (int, error) { + w.writeMutex.Lock() + defer w.writeMutex.Unlock() + + nextWriter, err := w.conn.NextWriter(gorillawebsocket.BinaryMessage) + if err != nil { + return 0, err + } + + bytesWritten, err := nextWriter.Write(b) + nextWriter.Close() + + return bytesWritten, err +} + +func (w *websocket) Close() error { + return w.conn.Close() +} + +func (w *websocket) LocalAddr() net.Addr { + return w.conn.LocalAddr() +} + +func (w *websocket) RemoteAddr() net.Addr { + return w.conn.RemoteAddr() +} + +func (w *websocket) SetDeadline(t time.Time) error { + if err := w.SetReadDeadline(t); err != nil { + return err + } + + return w.SetWriteDeadline(t) +} + +func (w *websocket) SetReadDeadline(t time.Time) error { + return w.conn.SetReadDeadline(t) +} + +func (w *websocket) SetWriteDeadline(t time.Time) error { + return w.conn.SetWriteDeadline(t) +}