From f669a10cf9869f4171e94582fc4326cd77ff4b07 Mon Sep 17 00:00:00 2001 From: Caleb Brose <5447118+cmbrose@users.noreply.github.com> Date: Wed, 25 Jan 2023 14:57:21 -0600 Subject: [PATCH] Fix race conditions in invoker_test (#6905) --- .../codespace/codespace_host_service.v1.pb.go | 2 +- .../codespace_host_service.v1.proto.mock.go | 168 +++++++++++ internal/codespaces/rpc/generate.md | 3 +- internal/codespaces/rpc/generate.sh | 15 +- internal/codespaces/rpc/invoker.go | 6 +- internal/codespaces/rpc/invoker_test.go | 281 ++++++++++++++---- .../jupyter_server_host_service.v1.pb.go | 2 +- ...pyter_server_host_service.v1.proto.mock.go | 118 ++++++++ .../rpc/ssh/ssh_server_host_service.v1.pb.go | 2 +- .../ssh_server_host_service.v1.proto.mock.go | 118 ++++++++ internal/codespaces/rpc/test/server.go | 117 -------- internal/codespaces/rpc/test/session.go | 2 +- 12 files changed, 640 insertions(+), 194 deletions(-) create mode 100644 internal/codespaces/rpc/codespace/codespace_host_service.v1.proto.mock.go create mode 100644 internal/codespaces/rpc/jupyter/jupyter_server_host_service.v1.proto.mock.go create mode 100644 internal/codespaces/rpc/ssh/ssh_server_host_service.v1.proto.mock.go delete mode 100644 internal/codespaces/rpc/test/server.go diff --git a/internal/codespaces/rpc/codespace/codespace_host_service.v1.pb.go b/internal/codespaces/rpc/codespace/codespace_host_service.v1.pb.go index 21b908838..6da7f9e39 100644 --- a/internal/codespaces/rpc/codespace/codespace_host_service.v1.pb.go +++ b/internal/codespaces/rpc/codespace/codespace_host_service.v1.pb.go @@ -1,6 +1,6 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: -// protoc-gen-go v1.28.0 +// protoc-gen-go v1.28.1 // protoc v3.21.12 // source: codespace/codespace_host_service.v1.proto diff --git a/internal/codespaces/rpc/codespace/codespace_host_service.v1.proto.mock.go b/internal/codespaces/rpc/codespace/codespace_host_service.v1.proto.mock.go new file mode 100644 index 000000000..246849fe0 --- /dev/null +++ b/internal/codespaces/rpc/codespace/codespace_host_service.v1.proto.mock.go @@ -0,0 +1,168 @@ +// Code generated by moq; DO NOT EDIT. +// github.com/matryer/moq + +package codespace + +import ( + context "context" + sync "sync" +) + +// Ensure, that CodespaceHostServerMock does implement CodespaceHostServer. +// If this is not the case, regenerate this file with moq. +var _ CodespaceHostServer = &CodespaceHostServerMock{} + +// CodespaceHostServerMock is a mock implementation of CodespaceHostServer. +// +// func TestSomethingThatUsesCodespaceHostServer(t *testing.T) { +// +// // make and configure a mocked CodespaceHostServer +// mockedCodespaceHostServer := &CodespaceHostServerMock{ +// NotifyCodespaceOfClientActivityFunc: func(contextMoqParam context.Context, notifyCodespaceOfClientActivityRequest *NotifyCodespaceOfClientActivityRequest) (*NotifyCodespaceOfClientActivityResponse, error) { +// panic("mock out the NotifyCodespaceOfClientActivity method") +// }, +// RebuildContainerAsyncFunc: func(contextMoqParam context.Context, rebuildContainerRequest *RebuildContainerRequest) (*RebuildContainerResponse, error) { +// panic("mock out the RebuildContainerAsync method") +// }, +// mustEmbedUnimplementedCodespaceHostServerFunc: func() { +// panic("mock out the mustEmbedUnimplementedCodespaceHostServer method") +// }, +// } +// +// // use mockedCodespaceHostServer in code that requires CodespaceHostServer +// // and then make assertions. +// +// } +type CodespaceHostServerMock struct { + // NotifyCodespaceOfClientActivityFunc mocks the NotifyCodespaceOfClientActivity method. + NotifyCodespaceOfClientActivityFunc func(contextMoqParam context.Context, notifyCodespaceOfClientActivityRequest *NotifyCodespaceOfClientActivityRequest) (*NotifyCodespaceOfClientActivityResponse, error) + + // RebuildContainerAsyncFunc mocks the RebuildContainerAsync method. + RebuildContainerAsyncFunc func(contextMoqParam context.Context, rebuildContainerRequest *RebuildContainerRequest) (*RebuildContainerResponse, error) + + // mustEmbedUnimplementedCodespaceHostServerFunc mocks the mustEmbedUnimplementedCodespaceHostServer method. + mustEmbedUnimplementedCodespaceHostServerFunc func() + + // calls tracks calls to the methods. + calls struct { + // NotifyCodespaceOfClientActivity holds details about calls to the NotifyCodespaceOfClientActivity method. + NotifyCodespaceOfClientActivity []struct { + // ContextMoqParam is the contextMoqParam argument value. + ContextMoqParam context.Context + // NotifyCodespaceOfClientActivityRequest is the notifyCodespaceOfClientActivityRequest argument value. + NotifyCodespaceOfClientActivityRequest *NotifyCodespaceOfClientActivityRequest + } + // RebuildContainerAsync holds details about calls to the RebuildContainerAsync method. + RebuildContainerAsync []struct { + // ContextMoqParam is the contextMoqParam argument value. + ContextMoqParam context.Context + // RebuildContainerRequest is the rebuildContainerRequest argument value. + RebuildContainerRequest *RebuildContainerRequest + } + // mustEmbedUnimplementedCodespaceHostServer holds details about calls to the mustEmbedUnimplementedCodespaceHostServer method. + mustEmbedUnimplementedCodespaceHostServer []struct { + } + } + lockNotifyCodespaceOfClientActivity sync.RWMutex + lockRebuildContainerAsync sync.RWMutex + lockmustEmbedUnimplementedCodespaceHostServer sync.RWMutex +} + +// NotifyCodespaceOfClientActivity calls NotifyCodespaceOfClientActivityFunc. +func (mock *CodespaceHostServerMock) NotifyCodespaceOfClientActivity(contextMoqParam context.Context, notifyCodespaceOfClientActivityRequest *NotifyCodespaceOfClientActivityRequest) (*NotifyCodespaceOfClientActivityResponse, error) { + if mock.NotifyCodespaceOfClientActivityFunc == nil { + panic("CodespaceHostServerMock.NotifyCodespaceOfClientActivityFunc: method is nil but CodespaceHostServer.NotifyCodespaceOfClientActivity was just called") + } + callInfo := struct { + ContextMoqParam context.Context + NotifyCodespaceOfClientActivityRequest *NotifyCodespaceOfClientActivityRequest + }{ + ContextMoqParam: contextMoqParam, + NotifyCodespaceOfClientActivityRequest: notifyCodespaceOfClientActivityRequest, + } + mock.lockNotifyCodespaceOfClientActivity.Lock() + mock.calls.NotifyCodespaceOfClientActivity = append(mock.calls.NotifyCodespaceOfClientActivity, callInfo) + mock.lockNotifyCodespaceOfClientActivity.Unlock() + return mock.NotifyCodespaceOfClientActivityFunc(contextMoqParam, notifyCodespaceOfClientActivityRequest) +} + +// NotifyCodespaceOfClientActivityCalls gets all the calls that were made to NotifyCodespaceOfClientActivity. +// Check the length with: +// +// len(mockedCodespaceHostServer.NotifyCodespaceOfClientActivityCalls()) +func (mock *CodespaceHostServerMock) NotifyCodespaceOfClientActivityCalls() []struct { + ContextMoqParam context.Context + NotifyCodespaceOfClientActivityRequest *NotifyCodespaceOfClientActivityRequest +} { + var calls []struct { + ContextMoqParam context.Context + NotifyCodespaceOfClientActivityRequest *NotifyCodespaceOfClientActivityRequest + } + mock.lockNotifyCodespaceOfClientActivity.RLock() + calls = mock.calls.NotifyCodespaceOfClientActivity + mock.lockNotifyCodespaceOfClientActivity.RUnlock() + return calls +} + +// RebuildContainerAsync calls RebuildContainerAsyncFunc. +func (mock *CodespaceHostServerMock) RebuildContainerAsync(contextMoqParam context.Context, rebuildContainerRequest *RebuildContainerRequest) (*RebuildContainerResponse, error) { + if mock.RebuildContainerAsyncFunc == nil { + panic("CodespaceHostServerMock.RebuildContainerAsyncFunc: method is nil but CodespaceHostServer.RebuildContainerAsync was just called") + } + callInfo := struct { + ContextMoqParam context.Context + RebuildContainerRequest *RebuildContainerRequest + }{ + ContextMoqParam: contextMoqParam, + RebuildContainerRequest: rebuildContainerRequest, + } + mock.lockRebuildContainerAsync.Lock() + mock.calls.RebuildContainerAsync = append(mock.calls.RebuildContainerAsync, callInfo) + mock.lockRebuildContainerAsync.Unlock() + return mock.RebuildContainerAsyncFunc(contextMoqParam, rebuildContainerRequest) +} + +// RebuildContainerAsyncCalls gets all the calls that were made to RebuildContainerAsync. +// Check the length with: +// +// len(mockedCodespaceHostServer.RebuildContainerAsyncCalls()) +func (mock *CodespaceHostServerMock) RebuildContainerAsyncCalls() []struct { + ContextMoqParam context.Context + RebuildContainerRequest *RebuildContainerRequest +} { + var calls []struct { + ContextMoqParam context.Context + RebuildContainerRequest *RebuildContainerRequest + } + mock.lockRebuildContainerAsync.RLock() + calls = mock.calls.RebuildContainerAsync + mock.lockRebuildContainerAsync.RUnlock() + return calls +} + +// mustEmbedUnimplementedCodespaceHostServer calls mustEmbedUnimplementedCodespaceHostServerFunc. +func (mock *CodespaceHostServerMock) mustEmbedUnimplementedCodespaceHostServer() { + if mock.mustEmbedUnimplementedCodespaceHostServerFunc == nil { + panic("CodespaceHostServerMock.mustEmbedUnimplementedCodespaceHostServerFunc: method is nil but CodespaceHostServer.mustEmbedUnimplementedCodespaceHostServer was just called") + } + callInfo := struct { + }{} + mock.lockmustEmbedUnimplementedCodespaceHostServer.Lock() + mock.calls.mustEmbedUnimplementedCodespaceHostServer = append(mock.calls.mustEmbedUnimplementedCodespaceHostServer, callInfo) + mock.lockmustEmbedUnimplementedCodespaceHostServer.Unlock() + mock.mustEmbedUnimplementedCodespaceHostServerFunc() +} + +// mustEmbedUnimplementedCodespaceHostServerCalls gets all the calls that were made to mustEmbedUnimplementedCodespaceHostServer. +// Check the length with: +// +// len(mockedCodespaceHostServer.mustEmbedUnimplementedCodespaceHostServerCalls()) +func (mock *CodespaceHostServerMock) mustEmbedUnimplementedCodespaceHostServerCalls() []struct { +} { + var calls []struct { + } + mock.lockmustEmbedUnimplementedCodespaceHostServer.RLock() + calls = mock.calls.mustEmbedUnimplementedCodespaceHostServer + mock.lockmustEmbedUnimplementedCodespaceHostServer.RUnlock() + return calls +} diff --git a/internal/codespaces/rpc/generate.md b/internal/codespaces/rpc/generate.md index 7ae1dcc1a..d0d6bbc9d 100644 --- a/internal/codespaces/rpc/generate.md +++ b/internal/codespaces/rpc/generate.md @@ -6,7 +6,8 @@ Instructions for generating and adding gRPC protocol buffers. 1. [Download `protoc`](https://grpc.io/docs/protoc-installation/) 2. [Download protocol compiler plugins for Go](https://grpc.io/docs/languages/go/quickstart/) -3. Run `./generate.sh` from the `internal/codespaces/grpc` directory +3. Install moq: `go install github.com/matryer/moq@latest` +4. Run `./generate.sh` from the `internal/codespaces/rpc` directory ## Add New Protocol Buffers diff --git a/internal/codespaces/rpc/generate.sh b/internal/codespaces/rpc/generate.sh index 159803bbe..4ba2f898a 100755 --- a/internal/codespaces/rpc/generate.sh +++ b/internal/codespaces/rpc/generate.sh @@ -15,14 +15,21 @@ if ! protoc-gen-go-grpc --version; then fi function generate { - local contract="$1" + local dir="$1" + local proto="$2" + + local contract="$dir/$proto" protoc --go_out=. --go_opt=paths=source_relative --go-grpc_out=. --go-grpc_opt=paths=source_relative "$contract" echo "Generated protocol buffers for $contract" + + services=$(cat "$contract" | grep -Eo "service .+ {" | awk '{print $2 "Server"}') + moq -out $contract.mock.go $dir $services + echo "Generated mock protocols for $contract" } -generate jupyter/jupyter_server_host_service.v1.proto -generate codespace/codespace_host_service.v1.proto -generate ssh/ssh_server_host_service.v1.proto +generate jupyter jupyter_server_host_service.v1.proto +generate codespace codespace_host_service.v1.proto +generate ssh ssh_server_host_service.v1.proto echo 'Done!' diff --git a/internal/codespaces/rpc/invoker.go b/internal/codespaces/rpc/invoker.go index fa3c2897a..bb2e25a55 100644 --- a/internal/codespaces/rpc/invoker.go +++ b/internal/codespaces/rpc/invoker.go @@ -130,6 +130,9 @@ func connect(ctx context.Context, session liveshare.LiveshareSession) (Invoker, invoker.codespaceClient = codespace.NewCodespaceHostClient(conn) invoker.sshClient = ssh.NewSshServerHostClient(conn) + // Send initial connection heartbeat (no need to throw if we fail to get a response from the server) + _ = invoker.notifyCodespaceOfClientActivity(ctx, connectedEventName) + // Start the activity heatbeats go invoker.heartbeat(pfctx, 1*time.Minute) @@ -253,9 +256,6 @@ func (i *invoker) heartbeat(ctx context.Context, interval time.Duration) { ticker := time.NewTicker(interval) defer ticker.Stop() - // Send initial connection heartbeat (no need to throw if we fail to get a response from the server) - _ = i.notifyCodespaceOfClientActivity(ctx, connectedEventName) - for { select { case <-ctx.Done(): diff --git a/internal/codespaces/rpc/invoker_test.go b/internal/codespaces/rpc/invoker_test.go index c01148cf1..ba3e13ac3 100644 --- a/internal/codespaces/rpc/invoker_test.go +++ b/internal/codespaces/rpc/invoker_test.go @@ -3,84 +3,159 @@ package rpc import ( "context" "fmt" - "log" - "os" + "net" + "strconv" "testing" + "github.com/cli/cli/v2/internal/codespaces/rpc/codespace" + "github.com/cli/cli/v2/internal/codespaces/rpc/jupyter" + "github.com/cli/cli/v2/internal/codespaces/rpc/ssh" rpctest "github.com/cli/cli/v2/internal/codespaces/rpc/test" + "google.golang.org/grpc" ) -func startServer(t *testing.T) { - t.Helper() - if os.Getenv("GITHUB_ACTIONS") == "true" { - t.Skip("fails intermittently in CI: https://github.com/cli/cli/issues/5663") +type mockServer struct { + jupyter.JupyterServerHostServerMock + codespace.CodespaceHostServerMock + ssh.SshServerHostServerMock +} + +func newMockServer() *mockServer { + server := &mockServer{} + + server.CodespaceHostServerMock.NotifyCodespaceOfClientActivityFunc = func(context.Context, *codespace.NotifyCodespaceOfClientActivityRequest) (*codespace.NotifyCodespaceOfClientActivityResponse, error) { + return &codespace.NotifyCodespaceOfClientActivityResponse{ + Message: "", + Result: true, + }, nil + } + + return server +} + +// runTestGrpcServer serves grpc requests over the provided Listener using the mockServer for mocked callbacks. +// It does not return until the Context is cancelled and the server fully shuts down. +func runTestGrpcServer(ctx context.Context, listener net.Listener, server *mockServer) error { + s := grpc.NewServer() + jupyter.RegisterJupyterServerHostServer(s, server) + codespace.RegisterCodespaceHostServer(s, server) + ssh.RegisterSshServerHostServer(s, server) + + ch := make(chan error, 1) + go func() { ch <- s.Serve(listener) }() + + select { + case <-ctx.Done(): + s.Stop() + <-ch + return nil + case err := <-ch: + return err + } +} + +// createTestInvoker is the main test setup function. It returns an Invoker using the provided mockServer, as well as a shutdown function. +// The Invoker does not need to be closed directly, that will be handled by the shutdown function. +func createTestInvoker(t *testing.T, server *mockServer) (Invoker, func(), error) { + listener, err := net.Listen("tcp", "127.0.0.1:16634") + if err != nil { + return nil, nil, fmt.Errorf("failed to listen: %w", err) } ctx, cancel := context.WithCancel(context.Background()) + ch := make(chan error) + go func() { ch <- runTestGrpcServer(ctx, listener, server) }() - // Start the gRPC server in the background - go func() { - err := rpctest.StartServer(ctx) - if err != nil && err != context.Canceled { - log.Println(fmt.Errorf("error starting test server: %v", err)) - } - }() - - // Stop the gRPC server when the test is done - t.Cleanup(func() { + close := func() { cancel() - }) -} - -func createTestInvoker(t *testing.T) Invoker { - t.Helper() - - // Clear the stored client activity - rpctest.NotifyReceivedActivity = "" + <-ch + listener.Close() + } invoker, err := CreateInvoker(context.Background(), &rpctest.Session{}) if err != nil { - t.Fatalf("error connecting to internal server: %v", err) + close() + return nil, nil, fmt.Errorf("error connecting to internal server: %w", err) } - t.Cleanup(func() { - testNotifyCodespaceOfClientActivity(t) + return invoker, func() { invoker.Close() - }) - - return invoker + close() + }, nil } // Test that the RPC invoker notifies the codespace of client activity on connection -func testNotifyCodespaceOfClientActivity(t *testing.T) { - if rpctest.NotifyReceivedActivity != connectedEventName { - t.Fatalf("expected %s, got %s", connectedEventName, rpctest.NotifyMessage) +func verifyNotifyCodespaceOfClientActivity(t *testing.T, server *mockServer) { + calls := server.CodespaceHostServerMock.NotifyCodespaceOfClientActivityCalls() + if len(calls) == 0 { + t.Fatalf("no client activity calls") } + + for _, call := range calls { + activities := call.NotifyCodespaceOfClientActivityRequest.GetClientActivities() + if activities[0] == connectedEventName { + return + } + } + + t.Fatalf("no activity named %s", connectedEventName) } // Test that the RPC invoker returns the correct port and URL when the JupyterLab server starts successfully func TestStartJupyterServerSuccess(t *testing.T) { - startServer(t) - invoker := createTestInvoker(t) + resp := jupyter.GetRunningServerResponse{ + Port: strconv.Itoa(1234), + ServerUrl: "http://localhost:1234?token=1234", + Message: "", + Result: true, + } + + server := newMockServer() + server.JupyterServerHostServerMock.GetRunningServerFunc = func(context.Context, *jupyter.GetRunningServerRequest) (*jupyter.GetRunningServerResponse, error) { + return &resp, nil + } + + invoker, stop, err := createTestInvoker(t, server) + if err != nil { + t.Fatalf("error connecting to internal server: %v", err) + } + defer stop() + port, url, err := invoker.StartJupyterServer(context.Background()) if err != nil { t.Fatalf("expected %v, got %v", nil, err) } - if port != rpctest.JupyterPort { - t.Fatalf("expected %d, got %d", rpctest.JupyterPort, port) + if strconv.Itoa(port) != resp.Port { + t.Fatalf("expected %s, got %d", resp.Port, port) } - if url != rpctest.JupyterServerUrl { - t.Fatalf("expected %s, got %s", rpctest.JupyterServerUrl, url) + if url != resp.ServerUrl { + t.Fatalf("expected %s, got %s", resp.ServerUrl, url) } + + verifyNotifyCodespaceOfClientActivity(t, server) } // Test that the RPC invoker returns an error when the JupyterLab server fails to start func TestStartJupyterServerFailure(t *testing.T) { - startServer(t) - invoker := createTestInvoker(t) - rpctest.JupyterMessage = "error message" - rpctest.JupyterResult = false - errorMessage := fmt.Sprintf("failed to start JupyterLab: %s", rpctest.JupyterMessage) + resp := jupyter.GetRunningServerResponse{ + Port: strconv.Itoa(1234), + ServerUrl: "http://localhost:1234?token=1234", + Message: "error message", + Result: false, + } + + server := newMockServer() + server.JupyterServerHostServerMock.GetRunningServerFunc = func(context.Context, *jupyter.GetRunningServerRequest) (*jupyter.GetRunningServerResponse, error) { + return &resp, nil + } + + invoker, stop, err := createTestInvoker(t, server) + if err != nil { + t.Fatalf("error connecting to internal server: %v", err) + } + defer stop() + + errorMessage := fmt.Sprintf("failed to start JupyterLab: %s", resp.Message) port, url, err := invoker.StartJupyterServer(context.Background()) if err.Error() != errorMessage { t.Fatalf("expected %v, got %v", errorMessage, err) @@ -91,35 +166,79 @@ func TestStartJupyterServerFailure(t *testing.T) { if url != "" { t.Fatalf("expected %s, got %s", "", url) } + + verifyNotifyCodespaceOfClientActivity(t, server) } // Test that the RPC invoker doesn't throw an error when requesting an incremental rebuild func TestRebuildContainerIncremental(t *testing.T) { - startServer(t) - invoker := createTestInvoker(t) - err := invoker.RebuildContainer(context.Background(), false) + resp := codespace.RebuildContainerResponse{ + RebuildContainer: true, + } + + server := newMockServer() + server.RebuildContainerAsyncFunc = func(context.Context, *codespace.RebuildContainerRequest) (*codespace.RebuildContainerResponse, error) { + return &resp, nil + } + + invoker, stop, err := createTestInvoker(t, server) + if err != nil { + t.Fatalf("error connecting to internal server: %v", err) + } + defer stop() + + err = invoker.RebuildContainer(context.Background(), false) if err != nil { t.Fatalf("expected %v, got %v", nil, err) } + + verifyNotifyCodespaceOfClientActivity(t, server) } // Test that the RPC invoker doesn't throw an error when requesting a full rebuild func TestRebuildContainerFull(t *testing.T) { - startServer(t) - invoker := createTestInvoker(t) - err := invoker.RebuildContainer(context.Background(), true) + resp := codespace.RebuildContainerResponse{ + RebuildContainer: true, + } + + server := newMockServer() + server.RebuildContainerAsyncFunc = func(context.Context, *codespace.RebuildContainerRequest) (*codespace.RebuildContainerResponse, error) { + return &resp, nil + } + + invoker, stop, err := createTestInvoker(t, server) + if err != nil { + t.Fatalf("error connecting to internal server: %v", err) + } + defer stop() + + err = invoker.RebuildContainer(context.Background(), true) if err != nil { t.Fatalf("expected %v, got %v", nil, err) } + + verifyNotifyCodespaceOfClientActivity(t, server) } // Test that the RPC invoker throws an error when the rebuild fails func TestRebuildContainerFailure(t *testing.T) { - startServer(t) - invoker := createTestInvoker(t) - rpctest.RebuildContainer = false + resp := codespace.RebuildContainerResponse{ + RebuildContainer: false, + } + + server := newMockServer() + server.RebuildContainerAsyncFunc = func(context.Context, *codespace.RebuildContainerRequest) (*codespace.RebuildContainerResponse, error) { + return &resp, nil + } + + invoker, stop, err := createTestInvoker(t, server) + if err != nil { + t.Fatalf("error connecting to internal server: %v", err) + } + defer stop() + errorMessage := "couldn't rebuild codespace" - err := invoker.RebuildContainer(context.Background(), true) + err = invoker.RebuildContainer(context.Background(), true) if err.Error() != errorMessage { t.Fatalf("expected %v, got %v", errorMessage, err) } @@ -127,27 +246,59 @@ func TestRebuildContainerFailure(t *testing.T) { // Test that the RPC invoker returns the correct port and user when the SSH server starts successfully func TestStartSSHServerSuccess(t *testing.T) { - startServer(t) - invoker := createTestInvoker(t) + resp := ssh.StartRemoteServerResponse{ + ServerPort: strconv.Itoa(1234), + User: "test", + Message: "", + Result: true, + } + + server := newMockServer() + server.StartRemoteServerAsyncFunc = func(context.Context, *ssh.StartRemoteServerRequest) (*ssh.StartRemoteServerResponse, error) { + return &resp, nil + } + + invoker, stop, err := createTestInvoker(t, server) + if err != nil { + t.Fatalf("error connecting to internal server: %v", err) + } + defer stop() + port, user, err := invoker.StartSSHServer(context.Background()) if err != nil { t.Fatalf("expected %v, got %v", nil, err) } - if port != rpctest.SshServerPort { - t.Fatalf("expected %d, got %d", rpctest.SshServerPort, port) + if strconv.Itoa(port) != resp.ServerPort { + t.Fatalf("expected %s, got %d", resp.ServerPort, port) } - if user != rpctest.SshUser { - t.Fatalf("expected %s, got %s", rpctest.SshUser, user) + if user != resp.User { + t.Fatalf("expected %s, got %s", resp.User, user) } + + verifyNotifyCodespaceOfClientActivity(t, server) } // Test that the RPC invoker returns an error when the SSH server fails to start func TestStartSSHServerFailure(t *testing.T) { - startServer(t) - invoker := createTestInvoker(t) - rpctest.SshMessage = "error message" - rpctest.SshResult = false - errorMessage := fmt.Sprintf("failed to start SSH server: %s", rpctest.SshMessage) + resp := ssh.StartRemoteServerResponse{ + ServerPort: strconv.Itoa(1234), + User: "test", + Message: "error message", + Result: false, + } + + server := newMockServer() + server.StartRemoteServerAsyncFunc = func(context.Context, *ssh.StartRemoteServerRequest) (*ssh.StartRemoteServerResponse, error) { + return &resp, nil + } + + invoker, stop, err := createTestInvoker(t, server) + if err != nil { + t.Fatalf("error connecting to internal server: %v", err) + } + defer stop() + + errorMessage := fmt.Sprintf("failed to start SSH server: %s", resp.Message) port, user, err := invoker.StartSSHServer(context.Background()) if err.Error() != errorMessage { t.Fatalf("expected %v, got %v", errorMessage, err) diff --git a/internal/codespaces/rpc/jupyter/jupyter_server_host_service.v1.pb.go b/internal/codespaces/rpc/jupyter/jupyter_server_host_service.v1.pb.go index 8e11c6a32..b8f400d3c 100644 --- a/internal/codespaces/rpc/jupyter/jupyter_server_host_service.v1.pb.go +++ b/internal/codespaces/rpc/jupyter/jupyter_server_host_service.v1.pb.go @@ -1,6 +1,6 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: -// protoc-gen-go v1.28.0 +// protoc-gen-go v1.28.1 // protoc v3.21.12 // source: jupyter/jupyter_server_host_service.v1.proto diff --git a/internal/codespaces/rpc/jupyter/jupyter_server_host_service.v1.proto.mock.go b/internal/codespaces/rpc/jupyter/jupyter_server_host_service.v1.proto.mock.go new file mode 100644 index 000000000..12ea0bb5b --- /dev/null +++ b/internal/codespaces/rpc/jupyter/jupyter_server_host_service.v1.proto.mock.go @@ -0,0 +1,118 @@ +// Code generated by moq; DO NOT EDIT. +// github.com/matryer/moq + +package jupyter + +import ( + context "context" + sync "sync" +) + +// Ensure, that JupyterServerHostServerMock does implement JupyterServerHostServer. +// If this is not the case, regenerate this file with moq. +var _ JupyterServerHostServer = &JupyterServerHostServerMock{} + +// JupyterServerHostServerMock is a mock implementation of JupyterServerHostServer. +// +// func TestSomethingThatUsesJupyterServerHostServer(t *testing.T) { +// +// // make and configure a mocked JupyterServerHostServer +// mockedJupyterServerHostServer := &JupyterServerHostServerMock{ +// GetRunningServerFunc: func(contextMoqParam context.Context, getRunningServerRequest *GetRunningServerRequest) (*GetRunningServerResponse, error) { +// panic("mock out the GetRunningServer method") +// }, +// mustEmbedUnimplementedJupyterServerHostServerFunc: func() { +// panic("mock out the mustEmbedUnimplementedJupyterServerHostServer method") +// }, +// } +// +// // use mockedJupyterServerHostServer in code that requires JupyterServerHostServer +// // and then make assertions. +// +// } +type JupyterServerHostServerMock struct { + // GetRunningServerFunc mocks the GetRunningServer method. + GetRunningServerFunc func(contextMoqParam context.Context, getRunningServerRequest *GetRunningServerRequest) (*GetRunningServerResponse, error) + + // mustEmbedUnimplementedJupyterServerHostServerFunc mocks the mustEmbedUnimplementedJupyterServerHostServer method. + mustEmbedUnimplementedJupyterServerHostServerFunc func() + + // calls tracks calls to the methods. + calls struct { + // GetRunningServer holds details about calls to the GetRunningServer method. + GetRunningServer []struct { + // ContextMoqParam is the contextMoqParam argument value. + ContextMoqParam context.Context + // GetRunningServerRequest is the getRunningServerRequest argument value. + GetRunningServerRequest *GetRunningServerRequest + } + // mustEmbedUnimplementedJupyterServerHostServer holds details about calls to the mustEmbedUnimplementedJupyterServerHostServer method. + mustEmbedUnimplementedJupyterServerHostServer []struct { + } + } + lockGetRunningServer sync.RWMutex + lockmustEmbedUnimplementedJupyterServerHostServer sync.RWMutex +} + +// GetRunningServer calls GetRunningServerFunc. +func (mock *JupyterServerHostServerMock) GetRunningServer(contextMoqParam context.Context, getRunningServerRequest *GetRunningServerRequest) (*GetRunningServerResponse, error) { + if mock.GetRunningServerFunc == nil { + panic("JupyterServerHostServerMock.GetRunningServerFunc: method is nil but JupyterServerHostServer.GetRunningServer was just called") + } + callInfo := struct { + ContextMoqParam context.Context + GetRunningServerRequest *GetRunningServerRequest + }{ + ContextMoqParam: contextMoqParam, + GetRunningServerRequest: getRunningServerRequest, + } + mock.lockGetRunningServer.Lock() + mock.calls.GetRunningServer = append(mock.calls.GetRunningServer, callInfo) + mock.lockGetRunningServer.Unlock() + return mock.GetRunningServerFunc(contextMoqParam, getRunningServerRequest) +} + +// GetRunningServerCalls gets all the calls that were made to GetRunningServer. +// Check the length with: +// +// len(mockedJupyterServerHostServer.GetRunningServerCalls()) +func (mock *JupyterServerHostServerMock) GetRunningServerCalls() []struct { + ContextMoqParam context.Context + GetRunningServerRequest *GetRunningServerRequest +} { + var calls []struct { + ContextMoqParam context.Context + GetRunningServerRequest *GetRunningServerRequest + } + mock.lockGetRunningServer.RLock() + calls = mock.calls.GetRunningServer + mock.lockGetRunningServer.RUnlock() + return calls +} + +// mustEmbedUnimplementedJupyterServerHostServer calls mustEmbedUnimplementedJupyterServerHostServerFunc. +func (mock *JupyterServerHostServerMock) mustEmbedUnimplementedJupyterServerHostServer() { + if mock.mustEmbedUnimplementedJupyterServerHostServerFunc == nil { + panic("JupyterServerHostServerMock.mustEmbedUnimplementedJupyterServerHostServerFunc: method is nil but JupyterServerHostServer.mustEmbedUnimplementedJupyterServerHostServer was just called") + } + callInfo := struct { + }{} + mock.lockmustEmbedUnimplementedJupyterServerHostServer.Lock() + mock.calls.mustEmbedUnimplementedJupyterServerHostServer = append(mock.calls.mustEmbedUnimplementedJupyterServerHostServer, callInfo) + mock.lockmustEmbedUnimplementedJupyterServerHostServer.Unlock() + mock.mustEmbedUnimplementedJupyterServerHostServerFunc() +} + +// mustEmbedUnimplementedJupyterServerHostServerCalls gets all the calls that were made to mustEmbedUnimplementedJupyterServerHostServer. +// Check the length with: +// +// len(mockedJupyterServerHostServer.mustEmbedUnimplementedJupyterServerHostServerCalls()) +func (mock *JupyterServerHostServerMock) mustEmbedUnimplementedJupyterServerHostServerCalls() []struct { +} { + var calls []struct { + } + mock.lockmustEmbedUnimplementedJupyterServerHostServer.RLock() + calls = mock.calls.mustEmbedUnimplementedJupyterServerHostServer + mock.lockmustEmbedUnimplementedJupyterServerHostServer.RUnlock() + return calls +} diff --git a/internal/codespaces/rpc/ssh/ssh_server_host_service.v1.pb.go b/internal/codespaces/rpc/ssh/ssh_server_host_service.v1.pb.go index c495eb781..3dd22f583 100644 --- a/internal/codespaces/rpc/ssh/ssh_server_host_service.v1.pb.go +++ b/internal/codespaces/rpc/ssh/ssh_server_host_service.v1.pb.go @@ -1,6 +1,6 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: -// protoc-gen-go v1.28.0 +// protoc-gen-go v1.28.1 // protoc v3.21.12 // source: ssh/ssh_server_host_service.v1.proto diff --git a/internal/codespaces/rpc/ssh/ssh_server_host_service.v1.proto.mock.go b/internal/codespaces/rpc/ssh/ssh_server_host_service.v1.proto.mock.go new file mode 100644 index 000000000..d11e99461 --- /dev/null +++ b/internal/codespaces/rpc/ssh/ssh_server_host_service.v1.proto.mock.go @@ -0,0 +1,118 @@ +// Code generated by moq; DO NOT EDIT. +// github.com/matryer/moq + +package ssh + +import ( + context "context" + sync "sync" +) + +// Ensure, that SshServerHostServerMock does implement SshServerHostServer. +// If this is not the case, regenerate this file with moq. +var _ SshServerHostServer = &SshServerHostServerMock{} + +// SshServerHostServerMock is a mock implementation of SshServerHostServer. +// +// func TestSomethingThatUsesSshServerHostServer(t *testing.T) { +// +// // make and configure a mocked SshServerHostServer +// mockedSshServerHostServer := &SshServerHostServerMock{ +// StartRemoteServerAsyncFunc: func(contextMoqParam context.Context, startRemoteServerRequest *StartRemoteServerRequest) (*StartRemoteServerResponse, error) { +// panic("mock out the StartRemoteServerAsync method") +// }, +// mustEmbedUnimplementedSshServerHostServerFunc: func() { +// panic("mock out the mustEmbedUnimplementedSshServerHostServer method") +// }, +// } +// +// // use mockedSshServerHostServer in code that requires SshServerHostServer +// // and then make assertions. +// +// } +type SshServerHostServerMock struct { + // StartRemoteServerAsyncFunc mocks the StartRemoteServerAsync method. + StartRemoteServerAsyncFunc func(contextMoqParam context.Context, startRemoteServerRequest *StartRemoteServerRequest) (*StartRemoteServerResponse, error) + + // mustEmbedUnimplementedSshServerHostServerFunc mocks the mustEmbedUnimplementedSshServerHostServer method. + mustEmbedUnimplementedSshServerHostServerFunc func() + + // calls tracks calls to the methods. + calls struct { + // StartRemoteServerAsync holds details about calls to the StartRemoteServerAsync method. + StartRemoteServerAsync []struct { + // ContextMoqParam is the contextMoqParam argument value. + ContextMoqParam context.Context + // StartRemoteServerRequest is the startRemoteServerRequest argument value. + StartRemoteServerRequest *StartRemoteServerRequest + } + // mustEmbedUnimplementedSshServerHostServer holds details about calls to the mustEmbedUnimplementedSshServerHostServer method. + mustEmbedUnimplementedSshServerHostServer []struct { + } + } + lockStartRemoteServerAsync sync.RWMutex + lockmustEmbedUnimplementedSshServerHostServer sync.RWMutex +} + +// StartRemoteServerAsync calls StartRemoteServerAsyncFunc. +func (mock *SshServerHostServerMock) StartRemoteServerAsync(contextMoqParam context.Context, startRemoteServerRequest *StartRemoteServerRequest) (*StartRemoteServerResponse, error) { + if mock.StartRemoteServerAsyncFunc == nil { + panic("SshServerHostServerMock.StartRemoteServerAsyncFunc: method is nil but SshServerHostServer.StartRemoteServerAsync was just called") + } + callInfo := struct { + ContextMoqParam context.Context + StartRemoteServerRequest *StartRemoteServerRequest + }{ + ContextMoqParam: contextMoqParam, + StartRemoteServerRequest: startRemoteServerRequest, + } + mock.lockStartRemoteServerAsync.Lock() + mock.calls.StartRemoteServerAsync = append(mock.calls.StartRemoteServerAsync, callInfo) + mock.lockStartRemoteServerAsync.Unlock() + return mock.StartRemoteServerAsyncFunc(contextMoqParam, startRemoteServerRequest) +} + +// StartRemoteServerAsyncCalls gets all the calls that were made to StartRemoteServerAsync. +// Check the length with: +// +// len(mockedSshServerHostServer.StartRemoteServerAsyncCalls()) +func (mock *SshServerHostServerMock) StartRemoteServerAsyncCalls() []struct { + ContextMoqParam context.Context + StartRemoteServerRequest *StartRemoteServerRequest +} { + var calls []struct { + ContextMoqParam context.Context + StartRemoteServerRequest *StartRemoteServerRequest + } + mock.lockStartRemoteServerAsync.RLock() + calls = mock.calls.StartRemoteServerAsync + mock.lockStartRemoteServerAsync.RUnlock() + return calls +} + +// mustEmbedUnimplementedSshServerHostServer calls mustEmbedUnimplementedSshServerHostServerFunc. +func (mock *SshServerHostServerMock) mustEmbedUnimplementedSshServerHostServer() { + if mock.mustEmbedUnimplementedSshServerHostServerFunc == nil { + panic("SshServerHostServerMock.mustEmbedUnimplementedSshServerHostServerFunc: method is nil but SshServerHostServer.mustEmbedUnimplementedSshServerHostServer was just called") + } + callInfo := struct { + }{} + mock.lockmustEmbedUnimplementedSshServerHostServer.Lock() + mock.calls.mustEmbedUnimplementedSshServerHostServer = append(mock.calls.mustEmbedUnimplementedSshServerHostServer, callInfo) + mock.lockmustEmbedUnimplementedSshServerHostServer.Unlock() + mock.mustEmbedUnimplementedSshServerHostServerFunc() +} + +// mustEmbedUnimplementedSshServerHostServerCalls gets all the calls that were made to mustEmbedUnimplementedSshServerHostServer. +// Check the length with: +// +// len(mockedSshServerHostServer.mustEmbedUnimplementedSshServerHostServerCalls()) +func (mock *SshServerHostServerMock) mustEmbedUnimplementedSshServerHostServerCalls() []struct { +} { + var calls []struct { + } + mock.lockmustEmbedUnimplementedSshServerHostServer.RLock() + calls = mock.calls.mustEmbedUnimplementedSshServerHostServer + mock.lockmustEmbedUnimplementedSshServerHostServer.RUnlock() + return calls +} diff --git a/internal/codespaces/rpc/test/server.go b/internal/codespaces/rpc/test/server.go deleted file mode 100644 index 07b80bb83..000000000 --- a/internal/codespaces/rpc/test/server.go +++ /dev/null @@ -1,117 +0,0 @@ -package test - -import ( - "context" - "fmt" - "net" - "strconv" - - "github.com/cli/cli/v2/internal/codespaces/rpc/codespace" - "github.com/cli/cli/v2/internal/codespaces/rpc/jupyter" - "github.com/cli/cli/v2/internal/codespaces/rpc/ssh" - "google.golang.org/grpc" -) - -const ( - ServerPort = 50051 -) - -// Mock responses for the `GetRunningServer` RPC method -var ( - JupyterPort = 1234 - JupyterServerUrl = "http://localhost:1234?token=1234" - JupyterMessage = "" - JupyterResult = true -) - -// Mock responses for the `RebuildContainerAsync` RPC method -var ( - RebuildContainer = true -) - -// Mock responses for the `NotifyCodespaceOfClientActivity` RPC method -// NotifyMessage is used to store the activity that was sent to the server -var ( - NotifyMessage = "" - NotifyResult = true - NotifyReceivedActivity = "" -) - -// Mock responses for the `StartRemoteServerAsync` RPC method -var ( - SshServerPort = 1234 - SshUser = "test" - SshMessage = "" - SshResult = true -) - -type server struct { - jupyter.UnimplementedJupyterServerHostServer - codespace.CodespaceHostServer - ssh.SshServerHostServer -} - -func (s *server) GetRunningServer(ctx context.Context, in *jupyter.GetRunningServerRequest) (*jupyter.GetRunningServerResponse, error) { - return &jupyter.GetRunningServerResponse{ - Port: strconv.Itoa(JupyterPort), - ServerUrl: JupyterServerUrl, - Message: JupyterMessage, - Result: JupyterResult, - }, nil -} - -func (s *server) RebuildContainerAsync(ctx context.Context, in *codespace.RebuildContainerRequest) (*codespace.RebuildContainerResponse, error) { - return &codespace.RebuildContainerResponse{ - RebuildContainer: RebuildContainer, - }, nil -} - -func (s *server) NotifyCodespaceOfClientActivity(ctx context.Context, in *codespace.NotifyCodespaceOfClientActivityRequest) (*codespace.NotifyCodespaceOfClientActivityResponse, error) { - // If there is at least one client activity, set NotifyReceivedActivity to the first one (should be "connected") - if len(in.GetClientActivities()) > 0 { - NotifyReceivedActivity = in.GetClientActivities()[0] - } - - return &codespace.NotifyCodespaceOfClientActivityResponse{ - Message: NotifyMessage, - Result: NotifyResult, - }, nil -} - -func (s *server) StartRemoteServerAsync(ctx context.Context, in *ssh.StartRemoteServerRequest) (*ssh.StartRemoteServerResponse, error) { - return &ssh.StartRemoteServerResponse{ - ServerPort: strconv.Itoa(SshServerPort), - User: SshUser, - Message: SshMessage, - Result: SshResult, - }, nil -} - -// Starts the mock gRPC server listening on port 50051 -func StartServer(ctx context.Context) error { - listener, err := net.Listen("tcp", fmt.Sprintf("127.0.0.1:%d", ServerPort)) - if err != nil { - return fmt.Errorf("failed to listen: %v", err) - } - defer listener.Close() - - s := grpc.NewServer() - jupyter.RegisterJupyterServerHostServer(s, &server{}) - codespace.RegisterCodespaceHostServer(s, &server{}) - ssh.RegisterSshServerHostServer(s, &server{}) - - ch := make(chan error, 1) - go func() { - if err := s.Serve(listener); err != nil { - ch <- fmt.Errorf("failed to serve: %v", err) - } - }() - - select { - case <-ctx.Done(): - s.Stop() - return ctx.Err() - case err := <-ch: - return err - } -} diff --git a/internal/codespaces/rpc/test/session.go b/internal/codespaces/rpc/test/session.go index 607451392..531d4c33f 100644 --- a/internal/codespaces/rpc/test/session.go +++ b/internal/codespaces/rpc/test/session.go @@ -29,7 +29,7 @@ func (s *Session) GetKeepAliveReason() string { } func (s *Session) StartSharing(ctx context.Context, sessionName string, port int) (liveshare.ChannelID, error) { - conn, err := net.Dial("tcp", fmt.Sprintf("127.0.0.1:%d", ServerPort)) + conn, err := net.Dial("tcp", fmt.Sprintf("127.0.0.1:%d", port)) if err != nil { return liveshare.ChannelID{}, err }