rename Server to Session and simplify API

This commit is contained in:
Alan Donovan 2021-09-02 11:06:49 -04:00
parent 9964a444b0
commit 4cceda1af0
10 changed files with 82 additions and 162 deletions

View file

@ -8,13 +8,10 @@ import (
"golang.org/x/crypto/ssh"
)
// A Client capable of joining a liveshare connection
// A Client capable of joining a Live Share workspace.
type Client struct {
connection Connection
tlsConfig *tls.Config
ssh *sshSession
rpc *rpcClient
}
// A ClientOption is a function that modifies a client
@ -52,31 +49,26 @@ func WithTLSConfig(tlsConfig *tls.Config) ClientOption {
}
}
// Join is a method that joins the client to the liveshare session
func (c *Client) Join(ctx context.Context) (err error) {
// JoinWorkspace connects the client to the server's Live Share
// workspace and returns a session representing their connection.
func (c *Client) JoinWorkspace(ctx context.Context) (*Session, error) {
clientSocket := newSocket(c.connection, c.tlsConfig)
if err := clientSocket.connect(ctx); err != nil {
return fmt.Errorf("error connecting websocket: %v", err)
return nil, fmt.Errorf("error connecting websocket: %v", err)
}
c.ssh = newSshSession(c.connection.SessionToken, clientSocket)
if err := c.ssh.connect(ctx); err != nil {
return fmt.Errorf("error connecting to ssh session: %v", err)
ssh := newSSHSession(c.connection.SessionToken, clientSocket)
if err := ssh.connect(ctx); err != nil {
return nil, fmt.Errorf("error connecting to ssh session: %v", err)
}
c.rpc = newRpcClient(c.ssh)
c.rpc.connect(ctx)
_, err = c.joinWorkspace(ctx)
if err != nil {
return fmt.Errorf("error joining Live Share workspace: %v", err)
rpc := newRpcClient(ssh)
rpc.connect(ctx)
if _, err := c.joinWorkspace(ctx, rpc); err != nil {
return nil, fmt.Errorf("error joining Live Share workspace: %v", err)
}
return nil
}
func (c *Client) hasJoined() bool {
return c.ssh != nil && c.rpc != nil
return &Session{ssh: ssh, rpc: rpc}, nil
}
type clientCapabilities struct {
@ -94,32 +86,32 @@ type joinWorkspaceResult struct {
SessionNumber int `json:"sessionNumber"`
}
func (c *Client) joinWorkspace(ctx context.Context) (*joinWorkspaceResult, error) {
func (client *Client) joinWorkspace(ctx context.Context, rpc *rpcClient) (*joinWorkspaceResult, error) {
args := joinWorkspaceArgs{
ID: c.connection.SessionID,
ID: client.connection.SessionID,
ConnectionMode: "local",
JoiningUserSessionToken: c.connection.SessionToken,
JoiningUserSessionToken: client.connection.SessionToken,
ClientCapabilities: clientCapabilities{
IsNonInteractive: false,
},
}
var result joinWorkspaceResult
if err := c.rpc.do(ctx, "workspace.joinWorkspace", &args, &result); err != nil {
if err := rpc.do(ctx, "workspace.joinWorkspace", &args, &result); err != nil {
return nil, fmt.Errorf("error making workspace.joinWorkspace call: %v", err)
}
return &result, nil
}
func (c *Client) openStreamingChannel(ctx context.Context, streamName, condition string) (ssh.Channel, error) {
func (session *Session) openStreamingChannel(ctx context.Context, streamName, condition string) (ssh.Channel, error) {
args := getStreamArgs{streamName, condition}
var streamID string
if err := c.rpc.do(ctx, "streamManager.getStream", args, &streamID); err != nil {
if err := session.rpc.do(ctx, "streamManager.getStream", args, &streamID); err != nil {
return nil, fmt.Errorf("error getting stream id: %v", err)
}
channel, reqs, err := c.ssh.conn.OpenChannel("session", nil)
channel, reqs, err := session.ssh.conn.OpenChannel("session", nil)
if err != nil {
return nil, fmt.Errorf("error opening ssh channel for transport: %v", err)
}

View file

@ -43,7 +43,7 @@ func TestNewClientWithInvalidConnection(t *testing.T) {
}
}
func TestClientJoin(t *testing.T) {
func TestJoinSession(t *testing.T) {
connection := Connection{
SessionID: "session-id",
SessionToken: "session-token",
@ -90,10 +90,12 @@ func TestClientJoin(t *testing.T) {
done := make(chan error)
go func() {
if err := client.Join(ctx); err != nil {
done <- fmt.Errorf("error joining client: %v", err)
session, err := client.JoinWorkspace(ctx)
if err != nil {
done <- fmt.Errorf("error joining workspace: %v", err)
return
}
_ = session
done <- nil
}()

View file

@ -7,20 +7,17 @@ import (
"net"
)
// A PortForwarder forwards TCP traffic between a port on a remote
// LiveShare host and a local port.
// A PortForwarder forwards TCP traffic between a local TCP port and a LiveShare session.
type PortForwarder struct {
client *Client
server *Server
port int
session *Session
port int
}
// NewPortForwarder creates a new PortForwarder that connects a given client, server and port.
func NewPortForwarder(client *Client, server *Server, port int) *PortForwarder {
// NewPortForwarder creates a new PortForwarder for a given Live Share session and local TCP port.
func NewPortForwarder(session *Session, port int) *PortForwarder {
return &PortForwarder{
client: client,
server: server,
port: port,
session: session,
port: port,
}
}
@ -87,7 +84,7 @@ func awaitError(ctx context.Context, errc <-chan error) error {
func (l *PortForwarder) handleConnection(ctx context.Context, conn io.ReadWriteCloser) (err error) {
defer safeClose(conn, &err)
channel, err := l.client.openStreamingChannel(ctx, l.server.streamName, l.server.streamCondition)
channel, err := l.session.openStreamingChannel(ctx, l.session.streamName, l.session.streamCondition)
if err != nil {
return fmt.Errorf("error opening streaming channel for new connection: %v", err)
}

View file

@ -15,16 +15,12 @@ import (
)
func TestNewPortForwarder(t *testing.T) {
testServer, client, err := makeMockJoinedClient()
testServer, session, err := makeMockSession()
if err != nil {
t.Errorf("create mock client: %v", err)
}
defer testServer.Close()
server, err := NewServer(client)
if err != nil {
t.Errorf("create new server: %v", err)
}
pf := NewPortForwarder(client, server, 80)
pf := NewPortForwarder(session, 80)
if pf == nil {
t.Error("port forwarder is nil")
}
@ -40,27 +36,22 @@ func TestPortForwarderStart(t *testing.T) {
}
stream := bytes.NewBufferString("stream-data")
testServer, client, err := makeMockJoinedClient(
testServer, session, err := makeMockSession(
livesharetest.WithService("serverSharing.startSharing", serverSharing),
livesharetest.WithService("streamManager.getStream", getStream),
livesharetest.WithStream("stream-id", stream),
)
if err != nil {
t.Errorf("create mock client: %v", err)
t.Errorf("create mock session: %v", err)
}
defer testServer.Close()
server, err := NewServer(client)
if err != nil {
t.Errorf("create new server: %v", err)
}
ctx, _ := context.WithCancel(context.Background())
pf := NewPortForwarder(client, server, 8000)
pf := NewPortForwarder(session, 8000)
done := make(chan error)
go func() {
if err := server.StartSharing(ctx, "http", 8000); err != nil {
if err := session.StartSharing(ctx, "http", 8000); err != nil {
done <- fmt.Errorf("start sharing: %v", err)
}
done <- pf.Forward(ctx)

1
rpc.go
View file

@ -21,6 +21,7 @@ func newRpcClient(conn io.ReadWriteCloser) *rpcClient {
func (r *rpcClient) connect(ctx context.Context) {
stream := jsonrpc2.NewBufferedStream(r.conn, jsonrpc2.VSCodeObjectCodec{})
// TODO(adonovan): fix: ensure r.Conn is eventually Closed!
r.Conn = jsonrpc2.NewConn(ctx, stream, r.handler)
}

View file

@ -2,27 +2,18 @@ package liveshare
import (
"context"
"errors"
"fmt"
"strconv"
)
// A Server represents the liveshare host and container server
type Server struct {
client *Client
// A Session represents the session between a connected Live Share client and server.
type Session struct {
ssh *sshSession
rpc *rpcClient
port int
streamName, streamCondition string
}
// NewServer creates a new Server with a given Client
func NewServer(client *Client) (*Server, error) {
if !client.hasJoined() {
return nil, errors.New("client must join before creating server")
}
return &Server{client: client}, nil
}
// Port represents an open port on the container
type Port struct {
SourcePort int `json:"sourcePort"`
@ -37,11 +28,11 @@ type Port struct {
}
// StartSharing tells the liveshare host to start sharing the port from the container
func (s *Server) StartSharing(ctx context.Context, protocol string, port int) error {
func (s *Session) StartSharing(ctx context.Context, protocol string, port int) error {
s.port = port
var response Port
if err := s.client.rpc.do(ctx, "serverSharing.startSharing", []interface{}{
if err := s.rpc.do(ctx, "serverSharing.startSharing", []interface{}{
port, protocol, fmt.Sprintf("http://localhost:%s", strconv.Itoa(port)),
}, &response); err != nil {
return err
@ -53,13 +44,10 @@ func (s *Server) StartSharing(ctx context.Context, protocol string, port int) er
return nil
}
// Ports is a slice of Port pointers
type Ports []*Port
// GetSharedServers returns a list of available/open ports from the container
func (s *Server) GetSharedServers(ctx context.Context) (Ports, error) {
var response Ports
if err := s.client.rpc.do(ctx, "serverSharing.getSharedServers", []string{}, &response); err != nil {
func (s *Session) GetSharedServers(ctx context.Context) ([]*Port, error) {
var response []*Port
if err := s.rpc.do(ctx, "serverSharing.getSharedServers", []string{}, &response); err != nil {
return nil, err
}
@ -68,8 +56,8 @@ func (s *Server) GetSharedServers(ctx context.Context) (Ports, error) {
// UpdateSharedVisibility controls port permissions and whether it can be accessed publicly
// via the Browse URL
func (s *Server) UpdateSharedVisibility(ctx context.Context, port int, public bool) error {
if err := s.client.rpc.do(ctx, "serverSharing.updateSharedServerVisibility", []interface{}{port, public}, nil); err != nil {
func (s *Session) UpdateSharedVisibility(ctx context.Context, port int, public bool) error {
if err := s.rpc.do(ctx, "serverSharing.updateSharedServerVisibility", []interface{}{port, public}, nil); err != nil {
return err
}

View file

@ -13,17 +13,7 @@ import (
"github.com/sourcegraph/jsonrpc2"
)
func TestNewServerWithNotJoinedClient(t *testing.T) {
client, err := NewClient()
if err != nil {
t.Errorf("error creating new client: %v", err)
}
if _, err := NewServer(client); err == nil {
t.Error("expected error")
}
}
func makeMockJoinedClient(opts ...livesharetest.ServerOption) (*livesharetest.Server, *Client, error) {
func makeMockSession(opts ...livesharetest.ServerOption) (*livesharetest.Server, *Session, error) {
connection := Connection{
SessionID: "session-id",
SessionToken: "session-token",
@ -47,25 +37,11 @@ func makeMockJoinedClient(opts ...livesharetest.ServerOption) (*livesharetest.Se
return nil, nil, fmt.Errorf("error creating new client: %v", err)
}
ctx := context.Background()
if err := client.Join(ctx); err != nil {
return nil, nil, fmt.Errorf("error joining client: %v", err)
}
return testServer, client, nil
}
func TestNewServer(t *testing.T) {
testServer, client, err := makeMockJoinedClient()
defer testServer.Close()
session, err := client.JoinWorkspace(ctx)
if err != nil {
t.Errorf("error creating mock joined client: %v", err)
}
server, err := NewServer(client)
if err != nil {
t.Errorf("error creating new server: %v", err)
}
if server == nil {
t.Error("server is nil")
return nil, nil, fmt.Errorf("error joining workspace: %v", err)
}
return testServer, session, nil
}
func TestServerStartSharing(t *testing.T) {
@ -95,25 +71,21 @@ func TestServerStartSharing(t *testing.T) {
}
return Port{StreamName: "stream-name", StreamCondition: "stream-condition"}, nil
}
testServer, client, err := makeMockJoinedClient(
testServer, session, err := makeMockSession(
livesharetest.WithService("serverSharing.startSharing", startSharing),
)
defer testServer.Close()
if err != nil {
t.Errorf("error creating mock joined client: %v", err)
}
server, err := NewServer(client)
if err != nil {
t.Errorf("error creating new server: %v", err)
t.Errorf("error creating mock session: %v", err)
}
ctx := context.Background()
done := make(chan error)
go func() {
if err := server.StartSharing(ctx, serverProtocol, serverPort); err != nil {
if err := session.StartSharing(ctx, serverProtocol, serverPort); err != nil {
done <- fmt.Errorf("error sharing server: %v", err)
}
if server.streamName == "" || server.streamCondition == "" {
if session.streamName == "" || session.streamCondition == "" {
done <- errors.New("stream name or condition is blank")
}
done <- nil
@ -136,23 +108,19 @@ func TestServerGetSharedServers(t *testing.T) {
StreamCondition: "stream-condition",
}
getSharedServers := func(req *jsonrpc2.Request) (interface{}, error) {
return Ports{&sharedServer}, nil
return []*Port{&sharedServer}, nil
}
testServer, client, err := makeMockJoinedClient(
testServer, session, err := makeMockSession(
livesharetest.WithService("serverSharing.getSharedServers", getSharedServers),
)
if err != nil {
t.Errorf("error creating new mock client: %v", err)
t.Errorf("error creating mock session: %v", err)
}
defer testServer.Close()
server, err := NewServer(client)
if err != nil {
t.Errorf("error creating new server: %v", err)
}
ctx := context.Background()
done := make(chan error)
go func() {
ports, err := server.GetSharedServers(ctx)
ports, err := session.GetSharedServers(ctx)
if err != nil {
done <- fmt.Errorf("error getting shared servers: %v", err)
}
@ -206,25 +174,17 @@ func TestServerUpdateSharedVisibility(t *testing.T) {
}
return nil, nil
}
testServer, client, err := makeMockJoinedClient(
testServer, session, err := makeMockSession(
livesharetest.WithService("serverSharing.updateSharedServerVisibility", updateSharedVisibility),
)
if err != nil {
t.Errorf("creating new mock client: %v", err)
t.Errorf("creating mock session: %v", err)
}
defer testServer.Close()
server, err := NewServer(client)
if err != nil {
t.Errorf("creating server: %v", err)
}
ctx := context.Background()
done := make(chan error)
go func() {
if err := server.UpdateSharedVisibility(ctx, 80, true); err != nil {
done <- err
return
}
done <- nil
done <- session.UpdateSharedVisibility(ctx, 80, true)
}()
select {
case err := <-testServer.Err():

2
ssh.go
View file

@ -19,7 +19,7 @@ type sshSession struct {
writer io.Writer
}
func newSshSession(token string, socket net.Conn) *sshSession {
func newSSHSession(token string, socket net.Conn) *sshSession {
return &sshSession{token: token, socket: socket}
}

View file

@ -2,18 +2,14 @@ package liveshare
import (
"context"
"errors"
)
type SSHServer struct {
client *Client
session *Session
}
func NewSSHServer(client *Client) (*SSHServer, error) {
if !client.hasJoined() {
return nil, errors.New("client must join before creating server")
}
return &SSHServer{client: client}, nil
func (session *Session) SSHServer() *SSHServer {
return &SSHServer{session: session}
}
type SSHServerStartResult struct {
@ -23,12 +19,12 @@ type SSHServerStartResult struct {
Message string `json:"message"`
}
func (s *SSHServer) StartRemoteServer(ctx context.Context) (SSHServerStartResult, error) {
func (s *SSHServer) StartRemoteServer(ctx context.Context) (*SSHServerStartResult, error) {
var response SSHServerStartResult
if err := s.client.rpc.do(ctx, "ISshServerHostService.startRemoteServer", []string{}, &response); err != nil {
return response, err
if err := s.session.rpc.do(ctx, "ISshServerHostService.startRemoteServer", []string{}, &response); err != nil {
return nil, err
}
return response, nil
return &response, nil
}

View file

@ -2,7 +2,6 @@ package liveshare
import (
"context"
"errors"
"fmt"
"io"
@ -10,17 +9,11 @@ import (
)
type Terminal struct {
client *Client
session *Session
}
func NewTerminal(client *Client) (*Terminal, error) {
if !client.hasJoined() {
return nil, errors.New("client must join before creating terminal")
}
return &Terminal{
client: client,
}, nil
func NewTerminal(session *Session) *Terminal {
return &Terminal{session: session}
}
type TerminalCommand struct {
@ -71,14 +64,14 @@ func (t TerminalCommand) Run(ctx context.Context) (io.ReadCloser, error) {
ReadOnlyForGuests: false,
}
terminalStarted := t.terminal.client.rpc.handler.registerEventHandler("terminal.terminalStarted")
terminalStarted := t.terminal.session.rpc.handler.registerEventHandler("terminal.terminalStarted")
var result startTerminalResult
if err := t.terminal.client.rpc.do(ctx, "terminal.startTerminal", &args, &result); err != nil {
if err := t.terminal.session.rpc.do(ctx, "terminal.startTerminal", &args, &result); err != nil {
return nil, fmt.Errorf("error making terminal.startTerminal call: %v", err)
}
<-terminalStarted
channel, err := t.terminal.client.openStreamingChannel(ctx, result.StreamName, result.StreamCondition)
channel, err := t.terminal.session.openStreamingChannel(ctx, result.StreamName, result.StreamCondition)
if err != nil {
return nil, fmt.Errorf("error opening streaming channel: %v", err)
}
@ -101,8 +94,8 @@ func (t terminalReadCloser) Read(b []byte) (int, error) {
}
func (t terminalReadCloser) Close() error {
terminalStopped := t.terminalCommand.terminal.client.rpc.handler.registerEventHandler("terminal.terminalStopped")
if err := t.terminalCommand.terminal.client.rpc.do(context.Background(), "terminal.stopTerminal", []int{t.terminalID}, nil); err != nil {
terminalStopped := t.terminalCommand.terminal.session.rpc.handler.registerEventHandler("terminal.terminalStopped")
if err := t.terminalCommand.terminal.session.rpc.do(context.Background(), "terminal.stopTerminal", []int{t.terminalID}, nil); err != nil {
return fmt.Errorf("error making terminal.stopTerminal call: %v", err)
}