Merge pull request #10 from github/rpc-fixes

fix two data races in rpcHandler
This commit is contained in:
Alan Donovan 2021-09-02 11:38:38 -04:00 committed by GitHub
commit 34c20cf105
5 changed files with 42 additions and 39 deletions

View file

@ -64,7 +64,7 @@ func (c *Client) Join(ctx context.Context) (err error) {
return fmt.Errorf("error connecting to ssh session: %v", err)
}
c.rpc = newRpcClient(c.ssh)
c.rpc = newRPCClient(c.ssh)
c.rpc.connect(ctx)
_, err = c.joinWorkspace(ctx)

View file

@ -55,7 +55,8 @@ func TestPortForwarderStart(t *testing.T) {
t.Errorf("create new server: %v", err)
}
ctx, _ := context.WithCancel(context.Background())
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
pf := NewPortForwarder(client, server, 8000)
done := make(chan error)

53
rpc.go
View file

@ -15,7 +15,7 @@ type rpcClient struct {
handler *rpcHandler
}
func newRpcClient(conn io.ReadWriteCloser) *rpcClient {
func newRPCClient(conn io.ReadWriteCloser) *rpcClient {
return &rpcClient{conn: conn, handler: newRPCHandler()}
}
@ -24,54 +24,45 @@ func (r *rpcClient) connect(ctx context.Context) {
r.Conn = jsonrpc2.NewConn(ctx, stream, r.handler)
}
func (r *rpcClient) do(ctx context.Context, method string, args interface{}, result interface{}) error {
func (r *rpcClient) do(ctx context.Context, method string, args, result interface{}) error {
waiter, err := r.Conn.DispatchCall(ctx, method, args)
if err != nil {
return fmt.Errorf("error on dispatch call: %v", err)
return fmt.Errorf("error dispatching %q call: %v", method, err)
}
return waiter.Wait(ctx, result)
}
type rpcHandlerFunc = func(*jsonrpc2.Request)
type rpcHandler struct {
mutex sync.RWMutex
eventHandlers map[string][]chan *jsonrpc2.Request
handlersMu sync.Mutex
handlers map[string][]rpcHandlerFunc
}
func newRPCHandler() *rpcHandler {
return &rpcHandler{
eventHandlers: make(map[string][]chan *jsonrpc2.Request),
handlers: make(map[string][]rpcHandlerFunc),
}
}
func (r *rpcHandler) registerEventHandler(eventMethod string) <-chan *jsonrpc2.Request {
r.mutex.Lock()
defer r.mutex.Unlock()
ch := make(chan *jsonrpc2.Request)
if _, ok := r.eventHandlers[eventMethod]; !ok {
r.eventHandlers[eventMethod] = []chan *jsonrpc2.Request{ch}
} else {
r.eventHandlers[eventMethod] = append(r.eventHandlers[eventMethod], ch)
}
return ch
// registerEventHandler registers a handler for the specified event.
// After the next occurrence of the event, the handler will be called,
// once, in its own goroutine.
func (r *rpcHandler) registerEventHandler(eventMethod string, h rpcHandlerFunc) {
r.handlersMu.Lock()
r.handlers[eventMethod] = append(r.handlers[eventMethod], h)
r.handlersMu.Unlock()
}
// Handle calls all registered handlers for the request, concurrently, each in its own goroutine.
func (r *rpcHandler) Handle(ctx context.Context, conn *jsonrpc2.Conn, req *jsonrpc2.Request) {
r.mutex.Lock()
defer r.mutex.Unlock()
r.handlersMu.Lock()
handlers := r.handlers[req.Method]
r.handlers[req.Method] = nil
r.handlersMu.Unlock()
if handlers, ok := r.eventHandlers[req.Method]; ok {
go func() {
for _, handler := range handlers {
select {
case handler <- req:
case <-ctx.Done():
break
}
}
r.eventHandlers[req.Method] = []chan *jsonrpc2.Request{}
}()
for _, h := range handlers {
go h(req)
}
}

View file

@ -10,12 +10,16 @@ import (
func TestRPCHandlerEvents(t *testing.T) {
rpcHandler := newRPCHandler()
eventCh := rpcHandler.registerEventHandler("somethingHappened")
eventCh := make(chan *jsonrpc2.Request)
rpcHandler.registerEventHandler("somethingHappened", func(req *jsonrpc2.Request) {
eventCh <- req
})
go func() {
time.Sleep(1 * time.Second)
rpcHandler.Handle(context.Background(), nil, &jsonrpc2.Request{Method: "somethingHappened"})
}()
ctx, _ := context.WithDeadline(context.Background(), time.Now().Add(5*time.Second))
ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(5*time.Second))
defer cancel()
select {
case event := <-eventCh:
if event.Method != "somethingHappened" {

View file

@ -6,6 +6,7 @@ import (
"fmt"
"io"
"github.com/sourcegraph/jsonrpc2"
"golang.org/x/crypto/ssh"
)
@ -71,12 +72,15 @@ func (t TerminalCommand) Run(ctx context.Context) (io.ReadCloser, error) {
ReadOnlyForGuests: false,
}
terminalStarted := t.terminal.client.rpc.handler.registerEventHandler("terminal.terminalStarted")
started := make(chan struct{})
t.terminal.client.rpc.handler.registerEventHandler("terminal.terminalStarted", func(*jsonrpc2.Request) {
close(started)
})
var result startTerminalResult
if err := t.terminal.client.rpc.do(ctx, "terminal.startTerminal", &args, &result); err != nil {
return nil, fmt.Errorf("error making terminal.startTerminal call: %v", err)
}
<-terminalStarted
<-started
channel, err := t.terminal.client.openStreamingChannel(ctx, result.StreamName, result.StreamCondition)
if err != nil {
@ -101,7 +105,10 @@ func (t terminalReadCloser) Read(b []byte) (int, error) {
}
func (t terminalReadCloser) Close() error {
terminalStopped := t.terminalCommand.terminal.client.rpc.handler.registerEventHandler("terminal.terminalStopped")
stopped := make(chan struct{})
t.terminalCommand.terminal.client.rpc.handler.registerEventHandler("terminal.terminalStopped", func(*jsonrpc2.Request) {
close(stopped)
})
if err := t.terminalCommand.terminal.client.rpc.do(context.Background(), "terminal.stopTerminal", []int{t.terminalID}, nil); err != nil {
return fmt.Errorf("error making terminal.stopTerminal call: %v", err)
}
@ -110,7 +117,7 @@ func (t terminalReadCloser) Close() error {
return fmt.Errorf("error closing channel: %v", err)
}
<-terminalStopped
<-stopped
return nil
}