diff --git a/client.go b/client.go index 19b0aff50..377ec2512 100644 --- a/client.go +++ b/client.go @@ -62,7 +62,7 @@ func (c *Client) JoinWorkspace(ctx context.Context) (*Session, error) { return nil, fmt.Errorf("error connecting to ssh session: %v", err) } - rpc := newRpcClient(ssh) + rpc := newRPCClient(ssh) rpc.connect(ctx) if _, err := c.joinWorkspace(ctx, rpc); err != nil { return nil, fmt.Errorf("error joining Live Share workspace: %v", err) diff --git a/rpc.go b/rpc.go index c58ab419d..237606fe0 100644 --- a/rpc.go +++ b/rpc.go @@ -15,7 +15,7 @@ type rpcClient struct { handler *rpcHandler } -func newRpcClient(conn io.ReadWriteCloser) *rpcClient { +func newRPCClient(conn io.ReadWriteCloser) *rpcClient { return &rpcClient{conn: conn, handler: newRPCHandler()} } @@ -25,54 +25,45 @@ func (r *rpcClient) connect(ctx context.Context) { r.Conn = jsonrpc2.NewConn(ctx, stream, r.handler) } -func (r *rpcClient) do(ctx context.Context, method string, args interface{}, result interface{}) error { +func (r *rpcClient) do(ctx context.Context, method string, args, result interface{}) error { waiter, err := r.Conn.DispatchCall(ctx, method, args) if err != nil { - return fmt.Errorf("error on dispatch call: %v", err) + return fmt.Errorf("error dispatching %q call: %v", method, err) } return waiter.Wait(ctx, result) } +type rpcHandlerFunc = func(*jsonrpc2.Request) + type rpcHandler struct { - mutex sync.RWMutex - eventHandlers map[string][]chan *jsonrpc2.Request + handlersMu sync.Mutex + handlers map[string][]rpcHandlerFunc } func newRPCHandler() *rpcHandler { return &rpcHandler{ - eventHandlers: make(map[string][]chan *jsonrpc2.Request), + handlers: make(map[string][]rpcHandlerFunc), } } -func (r *rpcHandler) registerEventHandler(eventMethod string) <-chan *jsonrpc2.Request { - r.mutex.Lock() - defer r.mutex.Unlock() - - ch := make(chan *jsonrpc2.Request) - if _, ok := r.eventHandlers[eventMethod]; !ok { - r.eventHandlers[eventMethod] = []chan *jsonrpc2.Request{ch} - } else { - r.eventHandlers[eventMethod] = append(r.eventHandlers[eventMethod], ch) - } - return ch +// registerEventHandler registers a handler for the specified event. +// After the next occurrence of the event, the handler will be called, +// once, in its own goroutine. +func (r *rpcHandler) registerEventHandler(eventMethod string, h rpcHandlerFunc) { + r.handlersMu.Lock() + r.handlers[eventMethod] = append(r.handlers[eventMethod], h) + r.handlersMu.Unlock() } +// Handle calls all registered handlers for the request, concurrently, each in its own goroutine. func (r *rpcHandler) Handle(ctx context.Context, conn *jsonrpc2.Conn, req *jsonrpc2.Request) { - r.mutex.Lock() - defer r.mutex.Unlock() + r.handlersMu.Lock() + handlers := r.handlers[req.Method] + r.handlers[req.Method] = nil + r.handlersMu.Unlock() - if handlers, ok := r.eventHandlers[req.Method]; ok { - go func() { - for _, handler := range handlers { - select { - case handler <- req: - case <-ctx.Done(): - break - } - } - - r.eventHandlers[req.Method] = []chan *jsonrpc2.Request{} - }() + for _, h := range handlers { + go h(req) } } diff --git a/rpc_test.go b/rpc_test.go index 7543152d1..cf9c4cf81 100644 --- a/rpc_test.go +++ b/rpc_test.go @@ -10,7 +10,10 @@ import ( func TestRPCHandlerEvents(t *testing.T) { rpcHandler := newRPCHandler() - eventCh := rpcHandler.registerEventHandler("somethingHappened") + eventCh := make(chan *jsonrpc2.Request) + rpcHandler.registerEventHandler("somethingHappened", func(req *jsonrpc2.Request) { + eventCh <- req + }) go func() { time.Sleep(1 * time.Second) rpcHandler.Handle(context.Background(), nil, &jsonrpc2.Request{Method: "somethingHappened"}) diff --git a/session.go b/session.go index d13bba9f1..ed87c6c2c 100644 --- a/session.go +++ b/session.go @@ -28,13 +28,14 @@ type Port struct { // TODO(adonovan): fix possible typo in field name, and audit others. } -// StartSharing tells the liveshare host to start sharing the port from the container -func (s *Session) StartSharing(ctx context.Context, protocol string, port int) error { +// StartSharing tells the Live Share host to start sharing the specified port from the container. +// The sessionName describes the purpose of the port or service. +func (s *Session) StartSharing(ctx context.Context, sessionName string, port int) error { s.port = port var response Port if err := s.rpc.do(ctx, "serverSharing.startSharing", []interface{}{ - port, protocol, fmt.Sprintf("http://localhost:%d", port), + port, sessionName, fmt.Sprintf("http://localhost:%d", port), }, &response); err != nil { return err } diff --git a/terminal.go b/terminal.go index 07532f426..96938ed89 100644 --- a/terminal.go +++ b/terminal.go @@ -5,6 +5,7 @@ import ( "fmt" "io" + "github.com/sourcegraph/jsonrpc2" "golang.org/x/crypto/ssh" ) @@ -64,12 +65,15 @@ func (t TerminalCommand) Run(ctx context.Context) (io.ReadCloser, error) { ReadOnlyForGuests: false, } - terminalStarted := t.terminal.session.rpc.handler.registerEventHandler("terminal.terminalStarted") + started := make(chan struct{}) + t.terminal.session.rpc.handler.registerEventHandler("terminal.terminalStarted", func(*jsonrpc2.Request) { + close(started) + }) var result startTerminalResult if err := t.terminal.session.rpc.do(ctx, "terminal.startTerminal", &args, &result); err != nil { return nil, fmt.Errorf("error making terminal.startTerminal call: %v", err) } - <-terminalStarted + <-started channel, err := t.terminal.session.openStreamingChannel(ctx, result.StreamName, result.StreamCondition) if err != nil { @@ -94,7 +98,10 @@ func (t terminalReadCloser) Read(b []byte) (int, error) { } func (t terminalReadCloser) Close() error { - terminalStopped := t.terminalCommand.terminal.session.rpc.handler.registerEventHandler("terminal.terminalStopped") + stopped := make(chan struct{}) + t.terminalCommand.terminal.session.rpc.handler.registerEventHandler("terminal.terminalStopped", func(*jsonrpc2.Request) { + close(stopped) + }) if err := t.terminalCommand.terminal.session.rpc.do(context.Background(), "terminal.stopTerminal", []int{t.terminalID}, nil); err != nil { return fmt.Errorf("error making terminal.stopTerminal call: %v", err) } @@ -103,7 +110,7 @@ func (t terminalReadCloser) Close() error { return fmt.Errorf("error closing channel: %v", err) } - <-terminalStopped + <-stopped return nil }