diff --git a/pkg/cmd/codespace/ports.go b/pkg/cmd/codespace/ports.go index 737e824af..3b698494a 100644 --- a/pkg/cmd/codespace/ports.go +++ b/pkg/cmd/codespace/ports.go @@ -231,6 +231,28 @@ func newPortsVisibilityCmd(app *App) *cobra.Command { } } +type ErrUpdatingPortVisibility struct { + port int + visibility string + err error +} + +func newErrUpdatingPortVisibility(port int, visibility string, err error) *ErrUpdatingPortVisibility { + return &ErrUpdatingPortVisibility{ + port: port, + visibility: visibility, + err: err, + } +} + +func (e *ErrUpdatingPortVisibility) Error() string { + return fmt.Sprintf("error waiting for port %d to update to %s: %s", e.port, e.visibility, e.err) +} + +func (e *ErrUpdatingPortVisibility) Unwrap() error { + return e.err +} + func (a *App) UpdatePortVisibility(ctx context.Context, codespaceName string, args []string) (err error) { ports, err := a.parsePortVisibilities(args) if err != nil { @@ -251,6 +273,9 @@ func (a *App) UpdatePortVisibility(ctx context.Context, codespaceName string, ar } defer safeClose(session, &err) + success := session.RegisterEvent("sharingSucceeded") + failure := session.RegisterEvent("sharingFailed") + // TODO: check if port visibility can be updated in parallel instead of sequentially for _, port := range ports { a.StartProgressIndicatorWithLabel(fmt.Sprintf("Updating port %d visibility to: %s", port.number, port.visibility)) @@ -264,8 +289,8 @@ func (a *App) UpdatePortVisibility(ctx context.Context, codespaceName string, ar ctx, cancel := context.WithTimeout(ctx, 30*time.Second) defer cancel() - if err := a.waitForPortUpdate(ctx, session, port.number); err != nil { - return fmt.Errorf("error waiting for port %d to update to %s: %w", port.number, port.visibility, err) + if err := a.waitForPortUpdate(ctx, success, failure, session, port.number); err != nil { + return newErrUpdatingPortVisibility(port.number, port.visibility, err) } a.StopProgressIndicator() @@ -287,10 +312,9 @@ type portData struct { StatusCode int `json:"statusCode"` } -func (a *App) waitForPortUpdate(ctx context.Context, session *liveshare.Session, port int) error { - success := session.WaitForEvent("sharingSucceeded") - failure := session.WaitForEvent("sharingFailed") +var errUpdatePortVisibilityForbidden = errors.New("organization admin has forbidden this privacy setting") +func (a *App) waitForPortUpdate(ctx context.Context, success, failure chan []byte, session *liveshare.Session, port int) error { for { var pd portData select { @@ -309,7 +333,7 @@ func (a *App) waitForPortUpdate(ctx context.Context, session *liveshare.Session, } if pd.Port == port && pd.ChangeKind == portChangeKindUpdate { if pd.StatusCode == http.StatusForbidden { - return errors.New("organization admin has forbidden this privacy setting") + return errUpdatePortVisibilityForbidden } return errors.New(pd.ErrorDetail) } diff --git a/pkg/cmd/codespace/ports_test.go b/pkg/cmd/codespace/ports_test.go index a936223d5..9f4a3f839 100644 --- a/pkg/cmd/codespace/ports_test.go +++ b/pkg/cmd/codespace/ports_test.go @@ -3,6 +3,7 @@ package codespace import ( "context" "encoding/json" + "errors" "fmt" "testing" @@ -40,7 +41,7 @@ func TestPortsUpdateVisibilitySuccess(t *testing.T) { }, } - err := RunUpdateVisibilityTest(t, portVisibilities, eventResponses, portsData) + err := runUpdateVisibilityTest(portVisibilities, eventResponses, portsData) if err != nil { t.Errorf("unexpected error: %v", err) @@ -77,14 +78,13 @@ func TestPortsUpdateVisibilityFailure403(t *testing.T) { }, } - err := RunUpdateVisibilityTest(t, portVisibilities, eventResponses, portsData) + err := runUpdateVisibilityTest(portVisibilities, eventResponses, portsData) if err == nil { t.Errorf("unexpected error: %v", err) } - expectedErr := "error waiting for port 9999 to update to public: organization admin has forbidden this privacy setting" - if err.Error() != expectedErr { - t.Errorf("expected: %v, got: %v", expectedErr, err) + if errors.Unwrap(err) != errUpdatePortVisibilityForbidden { + t.Errorf("expected: %v, got: %v", errUpdatePortVisibilityForbidden, errors.Unwrap(err)) } } @@ -117,13 +117,13 @@ func TestPortsUpdateVisibilityFailure(t *testing.T) { }, } - err := RunUpdateVisibilityTest(t, portVisibilities, eventResponses, portsData) + err := runUpdateVisibilityTest(portVisibilities, eventResponses, portsData) if err == nil { t.Errorf("unexpected error: %v", err) } - expectedErr := "error waiting for port 9999 to update to public: test error" - if err.Error() != expectedErr { + var expectedErr *ErrUpdatingPortVisibility + if !errors.As(err, &expectedErr) { t.Errorf("expected: %v, got: %v", expectedErr, err) } } @@ -132,7 +132,7 @@ type joinWorkspaceResult struct { SessionNumber int `json:"sessionNumber"` } -func RunUpdateVisibilityTest(t *testing.T, portVisibilities []portVisibility, eventResponses []string, portsData []portData) error { +func runUpdateVisibilityTest(portVisibilities []portVisibility, eventResponses []string, portsData []portData) error { joinWorkspace := func(req *jsonrpc2.Request) (interface{}, error) { return joinWorkspaceResult{1}, nil } @@ -158,7 +158,7 @@ func RunUpdateVisibilityTest(t *testing.T, portVisibilities []portVisibility, ev livesharetest.WithService("serverSharing.updateSharedServerPrivacy", updateSharedVisibility), ) if err != nil { - t.Fatal(err) + return fmt.Errorf("unable to create test server: %w", err) } type rpcMessage struct { @@ -166,21 +166,25 @@ func RunUpdateVisibilityTest(t *testing.T, portVisibilities []portVisibility, ev Params portData } - for index, pd := range portsData { - go func(index int, pd portData) { - for { - select { - case <-ctx.Done(): - return - case <-ch: - testServer.WriteToObjectStream(rpcMessage{ - Method: eventResponses[index], - Params: pd, - }) + go func() { + var i int + for ; ; i++ { + select { + case <-ctx.Done(): + return + case <-ch: + pd := portsData[i] + // TODO: handle error + err := testServer.WriteToObjectStream(rpcMessage{ + Method: eventResponses[i], + Params: pd, + }) + if err != nil { + panic(err) } } - }(index, pd) - } + } + }() mockApi := &apiClientMock{ GetCodespaceFunc: func(ctx context.Context, codespaceName string, includeConnection bool) (*api.Codespace, error) { diff --git a/pkg/liveshare/rpc.go b/pkg/liveshare/rpc.go index 5187dc8de..e50e2576b 100644 --- a/pkg/liveshare/rpc.go +++ b/pkg/liveshare/rpc.go @@ -13,19 +13,17 @@ import ( type rpcClient struct { *jsonrpc2.Conn - conn io.ReadWriteCloser - - eventHandlersMu sync.RWMutex - eventHandlers map[string]chan []byte + conn io.ReadWriteCloser + requestHandler *requestHandler } func newRPCClient(conn io.ReadWriteCloser) *rpcClient { - return &rpcClient{conn: conn, eventHandlers: make(map[string]chan []byte)} + return &rpcClient{conn: conn, requestHandler: newRequestHandler()} } func (r *rpcClient) connect(ctx context.Context) { stream := jsonrpc2.NewBufferedStream(r.conn, jsonrpc2.VSCodeObjectCodec{}) - r.Conn = jsonrpc2.NewConn(ctx, stream, newRequestHandler(r)) + r.Conn = jsonrpc2.NewConn(ctx, stream, r.requestHandler) } func (r *rpcClient) do(ctx context.Context, method string, args, result interface{}) error { @@ -44,7 +42,16 @@ func (r *rpcClient) do(ctx context.Context, method string, args, result interfac return waiter.Wait(waitCtx, result) } -func (r *rpcClient) registerEventHandler(eventName string) chan []byte { +type requestHandler struct { + eventHandlersMu sync.RWMutex + eventHandlers map[string]chan []byte +} + +func newRequestHandler() *requestHandler { + return &requestHandler{eventHandlers: make(map[string]chan []byte)} +} + +func (r *requestHandler) registerEvent(eventName string) chan []byte { r.eventHandlersMu.Lock() defer r.eventHandlersMu.Unlock() @@ -57,23 +64,19 @@ func (r *rpcClient) registerEventHandler(eventName string) chan []byte { return ch } -func (r *rpcClient) eventHandler(eventName string) chan []byte { +func (r *requestHandler) eventHandler(eventName string) chan []byte { r.eventHandlersMu.RLock() defer r.eventHandlersMu.RUnlock() return r.eventHandlers[eventName] } -type requestHandler struct { - rpcClient *rpcClient -} - -func newRequestHandler(rpcClient *rpcClient) *requestHandler { - return &requestHandler{rpcClient: rpcClient} -} - -func (e *requestHandler) Handle(ctx context.Context, conn *jsonrpc2.Conn, req *jsonrpc2.Request) { - handler := e.rpcClient.eventHandler(req.Method) +func (r *requestHandler) Handle(ctx context.Context, conn *jsonrpc2.Conn, req *jsonrpc2.Request) { + fmt.Println(req.Method) + if req.Params != nil { + fmt.Println(string(*req.Params)) + } + handler := r.eventHandler(req.Method) if handler == nil { return // noop } diff --git a/pkg/liveshare/session.go b/pkg/liveshare/session.go index 715d24fc6..25e2143ab 100644 --- a/pkg/liveshare/session.go +++ b/pkg/liveshare/session.go @@ -78,8 +78,8 @@ func (s *Session) UpdateSharedServerPrivacy(ctx context.Context, port int, visib return nil } -func (s *Session) WaitForEvent(eventName string) chan []byte { - return s.rpc.registerEventHandler(eventName) +func (s *Session) RegisterEvent(eventName string) chan []byte { + return s.rpc.requestHandler.registerEvent(eventName) } // StartsSSHServer starts an SSH server in the container, installing sshd if necessary, diff --git a/pkg/liveshare/ssh.go b/pkg/liveshare/ssh.go index e7de9055a..ec32671be 100644 --- a/pkg/liveshare/ssh.go +++ b/pkg/liveshare/ssh.go @@ -50,6 +50,7 @@ func (s *sshSession) connect(ctx context.Context) error { return fmt.Errorf("error creating ssh client connection: %w", err) } s.conn = sshClientConn + go s.handleGlobalRequests(reqs) sshClient := ssh.NewClient(sshClientConn, chans, reqs) s.Session, err = sshClient.NewSession() @@ -70,6 +71,15 @@ func (s *sshSession) connect(ctx context.Context) error { return nil } +func (s *sshSession) handleGlobalRequests(incoming <-chan *ssh.Request) { + for r := range incoming { + fmt.Println(r.Type) + // This handles keepalive messages and matches + // the behaviour of OpenSSH. + r.Reply(false, nil) + } +} + func (s *sshSession) Read(p []byte) (n int, err error) { return s.reader.Read(p) }