Fix race conditions in invoker_test (#6905)

This commit is contained in:
Caleb Brose 2023-01-25 14:57:21 -06:00 committed by GitHub
parent fef4195004
commit f669a10cf9
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
12 changed files with 640 additions and 194 deletions

View file

@ -1,6 +1,6 @@
// Code generated by protoc-gen-go. DO NOT EDIT.
// versions:
// protoc-gen-go v1.28.0
// protoc-gen-go v1.28.1
// protoc v3.21.12
// source: codespace/codespace_host_service.v1.proto

View file

@ -0,0 +1,168 @@
// Code generated by moq; DO NOT EDIT.
// github.com/matryer/moq
package codespace
import (
context "context"
sync "sync"
)
// Ensure, that CodespaceHostServerMock does implement CodespaceHostServer.
// If this is not the case, regenerate this file with moq.
var _ CodespaceHostServer = &CodespaceHostServerMock{}
// CodespaceHostServerMock is a mock implementation of CodespaceHostServer.
//
// func TestSomethingThatUsesCodespaceHostServer(t *testing.T) {
//
// // make and configure a mocked CodespaceHostServer
// mockedCodespaceHostServer := &CodespaceHostServerMock{
// NotifyCodespaceOfClientActivityFunc: func(contextMoqParam context.Context, notifyCodespaceOfClientActivityRequest *NotifyCodespaceOfClientActivityRequest) (*NotifyCodespaceOfClientActivityResponse, error) {
// panic("mock out the NotifyCodespaceOfClientActivity method")
// },
// RebuildContainerAsyncFunc: func(contextMoqParam context.Context, rebuildContainerRequest *RebuildContainerRequest) (*RebuildContainerResponse, error) {
// panic("mock out the RebuildContainerAsync method")
// },
// mustEmbedUnimplementedCodespaceHostServerFunc: func() {
// panic("mock out the mustEmbedUnimplementedCodespaceHostServer method")
// },
// }
//
// // use mockedCodespaceHostServer in code that requires CodespaceHostServer
// // and then make assertions.
//
// }
type CodespaceHostServerMock struct {
// NotifyCodespaceOfClientActivityFunc mocks the NotifyCodespaceOfClientActivity method.
NotifyCodespaceOfClientActivityFunc func(contextMoqParam context.Context, notifyCodespaceOfClientActivityRequest *NotifyCodespaceOfClientActivityRequest) (*NotifyCodespaceOfClientActivityResponse, error)
// RebuildContainerAsyncFunc mocks the RebuildContainerAsync method.
RebuildContainerAsyncFunc func(contextMoqParam context.Context, rebuildContainerRequest *RebuildContainerRequest) (*RebuildContainerResponse, error)
// mustEmbedUnimplementedCodespaceHostServerFunc mocks the mustEmbedUnimplementedCodespaceHostServer method.
mustEmbedUnimplementedCodespaceHostServerFunc func()
// calls tracks calls to the methods.
calls struct {
// NotifyCodespaceOfClientActivity holds details about calls to the NotifyCodespaceOfClientActivity method.
NotifyCodespaceOfClientActivity []struct {
// ContextMoqParam is the contextMoqParam argument value.
ContextMoqParam context.Context
// NotifyCodespaceOfClientActivityRequest is the notifyCodespaceOfClientActivityRequest argument value.
NotifyCodespaceOfClientActivityRequest *NotifyCodespaceOfClientActivityRequest
}
// RebuildContainerAsync holds details about calls to the RebuildContainerAsync method.
RebuildContainerAsync []struct {
// ContextMoqParam is the contextMoqParam argument value.
ContextMoqParam context.Context
// RebuildContainerRequest is the rebuildContainerRequest argument value.
RebuildContainerRequest *RebuildContainerRequest
}
// mustEmbedUnimplementedCodespaceHostServer holds details about calls to the mustEmbedUnimplementedCodespaceHostServer method.
mustEmbedUnimplementedCodespaceHostServer []struct {
}
}
lockNotifyCodespaceOfClientActivity sync.RWMutex
lockRebuildContainerAsync sync.RWMutex
lockmustEmbedUnimplementedCodespaceHostServer sync.RWMutex
}
// NotifyCodespaceOfClientActivity calls NotifyCodespaceOfClientActivityFunc.
func (mock *CodespaceHostServerMock) NotifyCodespaceOfClientActivity(contextMoqParam context.Context, notifyCodespaceOfClientActivityRequest *NotifyCodespaceOfClientActivityRequest) (*NotifyCodespaceOfClientActivityResponse, error) {
if mock.NotifyCodespaceOfClientActivityFunc == nil {
panic("CodespaceHostServerMock.NotifyCodespaceOfClientActivityFunc: method is nil but CodespaceHostServer.NotifyCodespaceOfClientActivity was just called")
}
callInfo := struct {
ContextMoqParam context.Context
NotifyCodespaceOfClientActivityRequest *NotifyCodespaceOfClientActivityRequest
}{
ContextMoqParam: contextMoqParam,
NotifyCodespaceOfClientActivityRequest: notifyCodespaceOfClientActivityRequest,
}
mock.lockNotifyCodespaceOfClientActivity.Lock()
mock.calls.NotifyCodespaceOfClientActivity = append(mock.calls.NotifyCodespaceOfClientActivity, callInfo)
mock.lockNotifyCodespaceOfClientActivity.Unlock()
return mock.NotifyCodespaceOfClientActivityFunc(contextMoqParam, notifyCodespaceOfClientActivityRequest)
}
// NotifyCodespaceOfClientActivityCalls gets all the calls that were made to NotifyCodespaceOfClientActivity.
// Check the length with:
//
// len(mockedCodespaceHostServer.NotifyCodespaceOfClientActivityCalls())
func (mock *CodespaceHostServerMock) NotifyCodespaceOfClientActivityCalls() []struct {
ContextMoqParam context.Context
NotifyCodespaceOfClientActivityRequest *NotifyCodespaceOfClientActivityRequest
} {
var calls []struct {
ContextMoqParam context.Context
NotifyCodespaceOfClientActivityRequest *NotifyCodespaceOfClientActivityRequest
}
mock.lockNotifyCodespaceOfClientActivity.RLock()
calls = mock.calls.NotifyCodespaceOfClientActivity
mock.lockNotifyCodespaceOfClientActivity.RUnlock()
return calls
}
// RebuildContainerAsync calls RebuildContainerAsyncFunc.
func (mock *CodespaceHostServerMock) RebuildContainerAsync(contextMoqParam context.Context, rebuildContainerRequest *RebuildContainerRequest) (*RebuildContainerResponse, error) {
if mock.RebuildContainerAsyncFunc == nil {
panic("CodespaceHostServerMock.RebuildContainerAsyncFunc: method is nil but CodespaceHostServer.RebuildContainerAsync was just called")
}
callInfo := struct {
ContextMoqParam context.Context
RebuildContainerRequest *RebuildContainerRequest
}{
ContextMoqParam: contextMoqParam,
RebuildContainerRequest: rebuildContainerRequest,
}
mock.lockRebuildContainerAsync.Lock()
mock.calls.RebuildContainerAsync = append(mock.calls.RebuildContainerAsync, callInfo)
mock.lockRebuildContainerAsync.Unlock()
return mock.RebuildContainerAsyncFunc(contextMoqParam, rebuildContainerRequest)
}
// RebuildContainerAsyncCalls gets all the calls that were made to RebuildContainerAsync.
// Check the length with:
//
// len(mockedCodespaceHostServer.RebuildContainerAsyncCalls())
func (mock *CodespaceHostServerMock) RebuildContainerAsyncCalls() []struct {
ContextMoqParam context.Context
RebuildContainerRequest *RebuildContainerRequest
} {
var calls []struct {
ContextMoqParam context.Context
RebuildContainerRequest *RebuildContainerRequest
}
mock.lockRebuildContainerAsync.RLock()
calls = mock.calls.RebuildContainerAsync
mock.lockRebuildContainerAsync.RUnlock()
return calls
}
// mustEmbedUnimplementedCodespaceHostServer calls mustEmbedUnimplementedCodespaceHostServerFunc.
func (mock *CodespaceHostServerMock) mustEmbedUnimplementedCodespaceHostServer() {
if mock.mustEmbedUnimplementedCodespaceHostServerFunc == nil {
panic("CodespaceHostServerMock.mustEmbedUnimplementedCodespaceHostServerFunc: method is nil but CodespaceHostServer.mustEmbedUnimplementedCodespaceHostServer was just called")
}
callInfo := struct {
}{}
mock.lockmustEmbedUnimplementedCodespaceHostServer.Lock()
mock.calls.mustEmbedUnimplementedCodespaceHostServer = append(mock.calls.mustEmbedUnimplementedCodespaceHostServer, callInfo)
mock.lockmustEmbedUnimplementedCodespaceHostServer.Unlock()
mock.mustEmbedUnimplementedCodespaceHostServerFunc()
}
// mustEmbedUnimplementedCodespaceHostServerCalls gets all the calls that were made to mustEmbedUnimplementedCodespaceHostServer.
// Check the length with:
//
// len(mockedCodespaceHostServer.mustEmbedUnimplementedCodespaceHostServerCalls())
func (mock *CodespaceHostServerMock) mustEmbedUnimplementedCodespaceHostServerCalls() []struct {
} {
var calls []struct {
}
mock.lockmustEmbedUnimplementedCodespaceHostServer.RLock()
calls = mock.calls.mustEmbedUnimplementedCodespaceHostServer
mock.lockmustEmbedUnimplementedCodespaceHostServer.RUnlock()
return calls
}

View file

@ -6,7 +6,8 @@ Instructions for generating and adding gRPC protocol buffers.
1. [Download `protoc`](https://grpc.io/docs/protoc-installation/)
2. [Download protocol compiler plugins for Go](https://grpc.io/docs/languages/go/quickstart/)
3. Run `./generate.sh` from the `internal/codespaces/grpc` directory
3. Install moq: `go install github.com/matryer/moq@latest`
4. Run `./generate.sh` from the `internal/codespaces/rpc` directory
## Add New Protocol Buffers

View file

@ -15,14 +15,21 @@ if ! protoc-gen-go-grpc --version; then
fi
function generate {
local contract="$1"
local dir="$1"
local proto="$2"
local contract="$dir/$proto"
protoc --go_out=. --go_opt=paths=source_relative --go-grpc_out=. --go-grpc_opt=paths=source_relative "$contract"
echo "Generated protocol buffers for $contract"
services=$(cat "$contract" | grep -Eo "service .+ {" | awk '{print $2 "Server"}')
moq -out $contract.mock.go $dir $services
echo "Generated mock protocols for $contract"
}
generate jupyter/jupyter_server_host_service.v1.proto
generate codespace/codespace_host_service.v1.proto
generate ssh/ssh_server_host_service.v1.proto
generate jupyter jupyter_server_host_service.v1.proto
generate codespace codespace_host_service.v1.proto
generate ssh ssh_server_host_service.v1.proto
echo 'Done!'

View file

@ -130,6 +130,9 @@ func connect(ctx context.Context, session liveshare.LiveshareSession) (Invoker,
invoker.codespaceClient = codespace.NewCodespaceHostClient(conn)
invoker.sshClient = ssh.NewSshServerHostClient(conn)
// Send initial connection heartbeat (no need to throw if we fail to get a response from the server)
_ = invoker.notifyCodespaceOfClientActivity(ctx, connectedEventName)
// Start the activity heatbeats
go invoker.heartbeat(pfctx, 1*time.Minute)
@ -253,9 +256,6 @@ func (i *invoker) heartbeat(ctx context.Context, interval time.Duration) {
ticker := time.NewTicker(interval)
defer ticker.Stop()
// Send initial connection heartbeat (no need to throw if we fail to get a response from the server)
_ = i.notifyCodespaceOfClientActivity(ctx, connectedEventName)
for {
select {
case <-ctx.Done():

View file

@ -3,84 +3,159 @@ package rpc
import (
"context"
"fmt"
"log"
"os"
"net"
"strconv"
"testing"
"github.com/cli/cli/v2/internal/codespaces/rpc/codespace"
"github.com/cli/cli/v2/internal/codespaces/rpc/jupyter"
"github.com/cli/cli/v2/internal/codespaces/rpc/ssh"
rpctest "github.com/cli/cli/v2/internal/codespaces/rpc/test"
"google.golang.org/grpc"
)
func startServer(t *testing.T) {
t.Helper()
if os.Getenv("GITHUB_ACTIONS") == "true" {
t.Skip("fails intermittently in CI: https://github.com/cli/cli/issues/5663")
type mockServer struct {
jupyter.JupyterServerHostServerMock
codespace.CodespaceHostServerMock
ssh.SshServerHostServerMock
}
func newMockServer() *mockServer {
server := &mockServer{}
server.CodespaceHostServerMock.NotifyCodespaceOfClientActivityFunc = func(context.Context, *codespace.NotifyCodespaceOfClientActivityRequest) (*codespace.NotifyCodespaceOfClientActivityResponse, error) {
return &codespace.NotifyCodespaceOfClientActivityResponse{
Message: "",
Result: true,
}, nil
}
return server
}
// runTestGrpcServer serves grpc requests over the provided Listener using the mockServer for mocked callbacks.
// It does not return until the Context is cancelled and the server fully shuts down.
func runTestGrpcServer(ctx context.Context, listener net.Listener, server *mockServer) error {
s := grpc.NewServer()
jupyter.RegisterJupyterServerHostServer(s, server)
codespace.RegisterCodespaceHostServer(s, server)
ssh.RegisterSshServerHostServer(s, server)
ch := make(chan error, 1)
go func() { ch <- s.Serve(listener) }()
select {
case <-ctx.Done():
s.Stop()
<-ch
return nil
case err := <-ch:
return err
}
}
// createTestInvoker is the main test setup function. It returns an Invoker using the provided mockServer, as well as a shutdown function.
// The Invoker does not need to be closed directly, that will be handled by the shutdown function.
func createTestInvoker(t *testing.T, server *mockServer) (Invoker, func(), error) {
listener, err := net.Listen("tcp", "127.0.0.1:16634")
if err != nil {
return nil, nil, fmt.Errorf("failed to listen: %w", err)
}
ctx, cancel := context.WithCancel(context.Background())
ch := make(chan error)
go func() { ch <- runTestGrpcServer(ctx, listener, server) }()
// Start the gRPC server in the background
go func() {
err := rpctest.StartServer(ctx)
if err != nil && err != context.Canceled {
log.Println(fmt.Errorf("error starting test server: %v", err))
}
}()
// Stop the gRPC server when the test is done
t.Cleanup(func() {
close := func() {
cancel()
})
}
func createTestInvoker(t *testing.T) Invoker {
t.Helper()
// Clear the stored client activity
rpctest.NotifyReceivedActivity = ""
<-ch
listener.Close()
}
invoker, err := CreateInvoker(context.Background(), &rpctest.Session{})
if err != nil {
t.Fatalf("error connecting to internal server: %v", err)
close()
return nil, nil, fmt.Errorf("error connecting to internal server: %w", err)
}
t.Cleanup(func() {
testNotifyCodespaceOfClientActivity(t)
return invoker, func() {
invoker.Close()
})
return invoker
close()
}, nil
}
// Test that the RPC invoker notifies the codespace of client activity on connection
func testNotifyCodespaceOfClientActivity(t *testing.T) {
if rpctest.NotifyReceivedActivity != connectedEventName {
t.Fatalf("expected %s, got %s", connectedEventName, rpctest.NotifyMessage)
func verifyNotifyCodespaceOfClientActivity(t *testing.T, server *mockServer) {
calls := server.CodespaceHostServerMock.NotifyCodespaceOfClientActivityCalls()
if len(calls) == 0 {
t.Fatalf("no client activity calls")
}
for _, call := range calls {
activities := call.NotifyCodespaceOfClientActivityRequest.GetClientActivities()
if activities[0] == connectedEventName {
return
}
}
t.Fatalf("no activity named %s", connectedEventName)
}
// Test that the RPC invoker returns the correct port and URL when the JupyterLab server starts successfully
func TestStartJupyterServerSuccess(t *testing.T) {
startServer(t)
invoker := createTestInvoker(t)
resp := jupyter.GetRunningServerResponse{
Port: strconv.Itoa(1234),
ServerUrl: "http://localhost:1234?token=1234",
Message: "",
Result: true,
}
server := newMockServer()
server.JupyterServerHostServerMock.GetRunningServerFunc = func(context.Context, *jupyter.GetRunningServerRequest) (*jupyter.GetRunningServerResponse, error) {
return &resp, nil
}
invoker, stop, err := createTestInvoker(t, server)
if err != nil {
t.Fatalf("error connecting to internal server: %v", err)
}
defer stop()
port, url, err := invoker.StartJupyterServer(context.Background())
if err != nil {
t.Fatalf("expected %v, got %v", nil, err)
}
if port != rpctest.JupyterPort {
t.Fatalf("expected %d, got %d", rpctest.JupyterPort, port)
if strconv.Itoa(port) != resp.Port {
t.Fatalf("expected %s, got %d", resp.Port, port)
}
if url != rpctest.JupyterServerUrl {
t.Fatalf("expected %s, got %s", rpctest.JupyterServerUrl, url)
if url != resp.ServerUrl {
t.Fatalf("expected %s, got %s", resp.ServerUrl, url)
}
verifyNotifyCodespaceOfClientActivity(t, server)
}
// Test that the RPC invoker returns an error when the JupyterLab server fails to start
func TestStartJupyterServerFailure(t *testing.T) {
startServer(t)
invoker := createTestInvoker(t)
rpctest.JupyterMessage = "error message"
rpctest.JupyterResult = false
errorMessage := fmt.Sprintf("failed to start JupyterLab: %s", rpctest.JupyterMessage)
resp := jupyter.GetRunningServerResponse{
Port: strconv.Itoa(1234),
ServerUrl: "http://localhost:1234?token=1234",
Message: "error message",
Result: false,
}
server := newMockServer()
server.JupyterServerHostServerMock.GetRunningServerFunc = func(context.Context, *jupyter.GetRunningServerRequest) (*jupyter.GetRunningServerResponse, error) {
return &resp, nil
}
invoker, stop, err := createTestInvoker(t, server)
if err != nil {
t.Fatalf("error connecting to internal server: %v", err)
}
defer stop()
errorMessage := fmt.Sprintf("failed to start JupyterLab: %s", resp.Message)
port, url, err := invoker.StartJupyterServer(context.Background())
if err.Error() != errorMessage {
t.Fatalf("expected %v, got %v", errorMessage, err)
@ -91,35 +166,79 @@ func TestStartJupyterServerFailure(t *testing.T) {
if url != "" {
t.Fatalf("expected %s, got %s", "", url)
}
verifyNotifyCodespaceOfClientActivity(t, server)
}
// Test that the RPC invoker doesn't throw an error when requesting an incremental rebuild
func TestRebuildContainerIncremental(t *testing.T) {
startServer(t)
invoker := createTestInvoker(t)
err := invoker.RebuildContainer(context.Background(), false)
resp := codespace.RebuildContainerResponse{
RebuildContainer: true,
}
server := newMockServer()
server.RebuildContainerAsyncFunc = func(context.Context, *codespace.RebuildContainerRequest) (*codespace.RebuildContainerResponse, error) {
return &resp, nil
}
invoker, stop, err := createTestInvoker(t, server)
if err != nil {
t.Fatalf("error connecting to internal server: %v", err)
}
defer stop()
err = invoker.RebuildContainer(context.Background(), false)
if err != nil {
t.Fatalf("expected %v, got %v", nil, err)
}
verifyNotifyCodespaceOfClientActivity(t, server)
}
// Test that the RPC invoker doesn't throw an error when requesting a full rebuild
func TestRebuildContainerFull(t *testing.T) {
startServer(t)
invoker := createTestInvoker(t)
err := invoker.RebuildContainer(context.Background(), true)
resp := codespace.RebuildContainerResponse{
RebuildContainer: true,
}
server := newMockServer()
server.RebuildContainerAsyncFunc = func(context.Context, *codespace.RebuildContainerRequest) (*codespace.RebuildContainerResponse, error) {
return &resp, nil
}
invoker, stop, err := createTestInvoker(t, server)
if err != nil {
t.Fatalf("error connecting to internal server: %v", err)
}
defer stop()
err = invoker.RebuildContainer(context.Background(), true)
if err != nil {
t.Fatalf("expected %v, got %v", nil, err)
}
verifyNotifyCodespaceOfClientActivity(t, server)
}
// Test that the RPC invoker throws an error when the rebuild fails
func TestRebuildContainerFailure(t *testing.T) {
startServer(t)
invoker := createTestInvoker(t)
rpctest.RebuildContainer = false
resp := codespace.RebuildContainerResponse{
RebuildContainer: false,
}
server := newMockServer()
server.RebuildContainerAsyncFunc = func(context.Context, *codespace.RebuildContainerRequest) (*codespace.RebuildContainerResponse, error) {
return &resp, nil
}
invoker, stop, err := createTestInvoker(t, server)
if err != nil {
t.Fatalf("error connecting to internal server: %v", err)
}
defer stop()
errorMessage := "couldn't rebuild codespace"
err := invoker.RebuildContainer(context.Background(), true)
err = invoker.RebuildContainer(context.Background(), true)
if err.Error() != errorMessage {
t.Fatalf("expected %v, got %v", errorMessage, err)
}
@ -127,27 +246,59 @@ func TestRebuildContainerFailure(t *testing.T) {
// Test that the RPC invoker returns the correct port and user when the SSH server starts successfully
func TestStartSSHServerSuccess(t *testing.T) {
startServer(t)
invoker := createTestInvoker(t)
resp := ssh.StartRemoteServerResponse{
ServerPort: strconv.Itoa(1234),
User: "test",
Message: "",
Result: true,
}
server := newMockServer()
server.StartRemoteServerAsyncFunc = func(context.Context, *ssh.StartRemoteServerRequest) (*ssh.StartRemoteServerResponse, error) {
return &resp, nil
}
invoker, stop, err := createTestInvoker(t, server)
if err != nil {
t.Fatalf("error connecting to internal server: %v", err)
}
defer stop()
port, user, err := invoker.StartSSHServer(context.Background())
if err != nil {
t.Fatalf("expected %v, got %v", nil, err)
}
if port != rpctest.SshServerPort {
t.Fatalf("expected %d, got %d", rpctest.SshServerPort, port)
if strconv.Itoa(port) != resp.ServerPort {
t.Fatalf("expected %s, got %d", resp.ServerPort, port)
}
if user != rpctest.SshUser {
t.Fatalf("expected %s, got %s", rpctest.SshUser, user)
if user != resp.User {
t.Fatalf("expected %s, got %s", resp.User, user)
}
verifyNotifyCodespaceOfClientActivity(t, server)
}
// Test that the RPC invoker returns an error when the SSH server fails to start
func TestStartSSHServerFailure(t *testing.T) {
startServer(t)
invoker := createTestInvoker(t)
rpctest.SshMessage = "error message"
rpctest.SshResult = false
errorMessage := fmt.Sprintf("failed to start SSH server: %s", rpctest.SshMessage)
resp := ssh.StartRemoteServerResponse{
ServerPort: strconv.Itoa(1234),
User: "test",
Message: "error message",
Result: false,
}
server := newMockServer()
server.StartRemoteServerAsyncFunc = func(context.Context, *ssh.StartRemoteServerRequest) (*ssh.StartRemoteServerResponse, error) {
return &resp, nil
}
invoker, stop, err := createTestInvoker(t, server)
if err != nil {
t.Fatalf("error connecting to internal server: %v", err)
}
defer stop()
errorMessage := fmt.Sprintf("failed to start SSH server: %s", resp.Message)
port, user, err := invoker.StartSSHServer(context.Background())
if err.Error() != errorMessage {
t.Fatalf("expected %v, got %v", errorMessage, err)

View file

@ -1,6 +1,6 @@
// Code generated by protoc-gen-go. DO NOT EDIT.
// versions:
// protoc-gen-go v1.28.0
// protoc-gen-go v1.28.1
// protoc v3.21.12
// source: jupyter/jupyter_server_host_service.v1.proto

View file

@ -0,0 +1,118 @@
// Code generated by moq; DO NOT EDIT.
// github.com/matryer/moq
package jupyter
import (
context "context"
sync "sync"
)
// Ensure, that JupyterServerHostServerMock does implement JupyterServerHostServer.
// If this is not the case, regenerate this file with moq.
var _ JupyterServerHostServer = &JupyterServerHostServerMock{}
// JupyterServerHostServerMock is a mock implementation of JupyterServerHostServer.
//
// func TestSomethingThatUsesJupyterServerHostServer(t *testing.T) {
//
// // make and configure a mocked JupyterServerHostServer
// mockedJupyterServerHostServer := &JupyterServerHostServerMock{
// GetRunningServerFunc: func(contextMoqParam context.Context, getRunningServerRequest *GetRunningServerRequest) (*GetRunningServerResponse, error) {
// panic("mock out the GetRunningServer method")
// },
// mustEmbedUnimplementedJupyterServerHostServerFunc: func() {
// panic("mock out the mustEmbedUnimplementedJupyterServerHostServer method")
// },
// }
//
// // use mockedJupyterServerHostServer in code that requires JupyterServerHostServer
// // and then make assertions.
//
// }
type JupyterServerHostServerMock struct {
// GetRunningServerFunc mocks the GetRunningServer method.
GetRunningServerFunc func(contextMoqParam context.Context, getRunningServerRequest *GetRunningServerRequest) (*GetRunningServerResponse, error)
// mustEmbedUnimplementedJupyterServerHostServerFunc mocks the mustEmbedUnimplementedJupyterServerHostServer method.
mustEmbedUnimplementedJupyterServerHostServerFunc func()
// calls tracks calls to the methods.
calls struct {
// GetRunningServer holds details about calls to the GetRunningServer method.
GetRunningServer []struct {
// ContextMoqParam is the contextMoqParam argument value.
ContextMoqParam context.Context
// GetRunningServerRequest is the getRunningServerRequest argument value.
GetRunningServerRequest *GetRunningServerRequest
}
// mustEmbedUnimplementedJupyterServerHostServer holds details about calls to the mustEmbedUnimplementedJupyterServerHostServer method.
mustEmbedUnimplementedJupyterServerHostServer []struct {
}
}
lockGetRunningServer sync.RWMutex
lockmustEmbedUnimplementedJupyterServerHostServer sync.RWMutex
}
// GetRunningServer calls GetRunningServerFunc.
func (mock *JupyterServerHostServerMock) GetRunningServer(contextMoqParam context.Context, getRunningServerRequest *GetRunningServerRequest) (*GetRunningServerResponse, error) {
if mock.GetRunningServerFunc == nil {
panic("JupyterServerHostServerMock.GetRunningServerFunc: method is nil but JupyterServerHostServer.GetRunningServer was just called")
}
callInfo := struct {
ContextMoqParam context.Context
GetRunningServerRequest *GetRunningServerRequest
}{
ContextMoqParam: contextMoqParam,
GetRunningServerRequest: getRunningServerRequest,
}
mock.lockGetRunningServer.Lock()
mock.calls.GetRunningServer = append(mock.calls.GetRunningServer, callInfo)
mock.lockGetRunningServer.Unlock()
return mock.GetRunningServerFunc(contextMoqParam, getRunningServerRequest)
}
// GetRunningServerCalls gets all the calls that were made to GetRunningServer.
// Check the length with:
//
// len(mockedJupyterServerHostServer.GetRunningServerCalls())
func (mock *JupyterServerHostServerMock) GetRunningServerCalls() []struct {
ContextMoqParam context.Context
GetRunningServerRequest *GetRunningServerRequest
} {
var calls []struct {
ContextMoqParam context.Context
GetRunningServerRequest *GetRunningServerRequest
}
mock.lockGetRunningServer.RLock()
calls = mock.calls.GetRunningServer
mock.lockGetRunningServer.RUnlock()
return calls
}
// mustEmbedUnimplementedJupyterServerHostServer calls mustEmbedUnimplementedJupyterServerHostServerFunc.
func (mock *JupyterServerHostServerMock) mustEmbedUnimplementedJupyterServerHostServer() {
if mock.mustEmbedUnimplementedJupyterServerHostServerFunc == nil {
panic("JupyterServerHostServerMock.mustEmbedUnimplementedJupyterServerHostServerFunc: method is nil but JupyterServerHostServer.mustEmbedUnimplementedJupyterServerHostServer was just called")
}
callInfo := struct {
}{}
mock.lockmustEmbedUnimplementedJupyterServerHostServer.Lock()
mock.calls.mustEmbedUnimplementedJupyterServerHostServer = append(mock.calls.mustEmbedUnimplementedJupyterServerHostServer, callInfo)
mock.lockmustEmbedUnimplementedJupyterServerHostServer.Unlock()
mock.mustEmbedUnimplementedJupyterServerHostServerFunc()
}
// mustEmbedUnimplementedJupyterServerHostServerCalls gets all the calls that were made to mustEmbedUnimplementedJupyterServerHostServer.
// Check the length with:
//
// len(mockedJupyterServerHostServer.mustEmbedUnimplementedJupyterServerHostServerCalls())
func (mock *JupyterServerHostServerMock) mustEmbedUnimplementedJupyterServerHostServerCalls() []struct {
} {
var calls []struct {
}
mock.lockmustEmbedUnimplementedJupyterServerHostServer.RLock()
calls = mock.calls.mustEmbedUnimplementedJupyterServerHostServer
mock.lockmustEmbedUnimplementedJupyterServerHostServer.RUnlock()
return calls
}

View file

@ -1,6 +1,6 @@
// Code generated by protoc-gen-go. DO NOT EDIT.
// versions:
// protoc-gen-go v1.28.0
// protoc-gen-go v1.28.1
// protoc v3.21.12
// source: ssh/ssh_server_host_service.v1.proto

View file

@ -0,0 +1,118 @@
// Code generated by moq; DO NOT EDIT.
// github.com/matryer/moq
package ssh
import (
context "context"
sync "sync"
)
// Ensure, that SshServerHostServerMock does implement SshServerHostServer.
// If this is not the case, regenerate this file with moq.
var _ SshServerHostServer = &SshServerHostServerMock{}
// SshServerHostServerMock is a mock implementation of SshServerHostServer.
//
// func TestSomethingThatUsesSshServerHostServer(t *testing.T) {
//
// // make and configure a mocked SshServerHostServer
// mockedSshServerHostServer := &SshServerHostServerMock{
// StartRemoteServerAsyncFunc: func(contextMoqParam context.Context, startRemoteServerRequest *StartRemoteServerRequest) (*StartRemoteServerResponse, error) {
// panic("mock out the StartRemoteServerAsync method")
// },
// mustEmbedUnimplementedSshServerHostServerFunc: func() {
// panic("mock out the mustEmbedUnimplementedSshServerHostServer method")
// },
// }
//
// // use mockedSshServerHostServer in code that requires SshServerHostServer
// // and then make assertions.
//
// }
type SshServerHostServerMock struct {
// StartRemoteServerAsyncFunc mocks the StartRemoteServerAsync method.
StartRemoteServerAsyncFunc func(contextMoqParam context.Context, startRemoteServerRequest *StartRemoteServerRequest) (*StartRemoteServerResponse, error)
// mustEmbedUnimplementedSshServerHostServerFunc mocks the mustEmbedUnimplementedSshServerHostServer method.
mustEmbedUnimplementedSshServerHostServerFunc func()
// calls tracks calls to the methods.
calls struct {
// StartRemoteServerAsync holds details about calls to the StartRemoteServerAsync method.
StartRemoteServerAsync []struct {
// ContextMoqParam is the contextMoqParam argument value.
ContextMoqParam context.Context
// StartRemoteServerRequest is the startRemoteServerRequest argument value.
StartRemoteServerRequest *StartRemoteServerRequest
}
// mustEmbedUnimplementedSshServerHostServer holds details about calls to the mustEmbedUnimplementedSshServerHostServer method.
mustEmbedUnimplementedSshServerHostServer []struct {
}
}
lockStartRemoteServerAsync sync.RWMutex
lockmustEmbedUnimplementedSshServerHostServer sync.RWMutex
}
// StartRemoteServerAsync calls StartRemoteServerAsyncFunc.
func (mock *SshServerHostServerMock) StartRemoteServerAsync(contextMoqParam context.Context, startRemoteServerRequest *StartRemoteServerRequest) (*StartRemoteServerResponse, error) {
if mock.StartRemoteServerAsyncFunc == nil {
panic("SshServerHostServerMock.StartRemoteServerAsyncFunc: method is nil but SshServerHostServer.StartRemoteServerAsync was just called")
}
callInfo := struct {
ContextMoqParam context.Context
StartRemoteServerRequest *StartRemoteServerRequest
}{
ContextMoqParam: contextMoqParam,
StartRemoteServerRequest: startRemoteServerRequest,
}
mock.lockStartRemoteServerAsync.Lock()
mock.calls.StartRemoteServerAsync = append(mock.calls.StartRemoteServerAsync, callInfo)
mock.lockStartRemoteServerAsync.Unlock()
return mock.StartRemoteServerAsyncFunc(contextMoqParam, startRemoteServerRequest)
}
// StartRemoteServerAsyncCalls gets all the calls that were made to StartRemoteServerAsync.
// Check the length with:
//
// len(mockedSshServerHostServer.StartRemoteServerAsyncCalls())
func (mock *SshServerHostServerMock) StartRemoteServerAsyncCalls() []struct {
ContextMoqParam context.Context
StartRemoteServerRequest *StartRemoteServerRequest
} {
var calls []struct {
ContextMoqParam context.Context
StartRemoteServerRequest *StartRemoteServerRequest
}
mock.lockStartRemoteServerAsync.RLock()
calls = mock.calls.StartRemoteServerAsync
mock.lockStartRemoteServerAsync.RUnlock()
return calls
}
// mustEmbedUnimplementedSshServerHostServer calls mustEmbedUnimplementedSshServerHostServerFunc.
func (mock *SshServerHostServerMock) mustEmbedUnimplementedSshServerHostServer() {
if mock.mustEmbedUnimplementedSshServerHostServerFunc == nil {
panic("SshServerHostServerMock.mustEmbedUnimplementedSshServerHostServerFunc: method is nil but SshServerHostServer.mustEmbedUnimplementedSshServerHostServer was just called")
}
callInfo := struct {
}{}
mock.lockmustEmbedUnimplementedSshServerHostServer.Lock()
mock.calls.mustEmbedUnimplementedSshServerHostServer = append(mock.calls.mustEmbedUnimplementedSshServerHostServer, callInfo)
mock.lockmustEmbedUnimplementedSshServerHostServer.Unlock()
mock.mustEmbedUnimplementedSshServerHostServerFunc()
}
// mustEmbedUnimplementedSshServerHostServerCalls gets all the calls that were made to mustEmbedUnimplementedSshServerHostServer.
// Check the length with:
//
// len(mockedSshServerHostServer.mustEmbedUnimplementedSshServerHostServerCalls())
func (mock *SshServerHostServerMock) mustEmbedUnimplementedSshServerHostServerCalls() []struct {
} {
var calls []struct {
}
mock.lockmustEmbedUnimplementedSshServerHostServer.RLock()
calls = mock.calls.mustEmbedUnimplementedSshServerHostServer
mock.lockmustEmbedUnimplementedSshServerHostServer.RUnlock()
return calls
}

View file

@ -1,117 +0,0 @@
package test
import (
"context"
"fmt"
"net"
"strconv"
"github.com/cli/cli/v2/internal/codespaces/rpc/codespace"
"github.com/cli/cli/v2/internal/codespaces/rpc/jupyter"
"github.com/cli/cli/v2/internal/codespaces/rpc/ssh"
"google.golang.org/grpc"
)
const (
ServerPort = 50051
)
// Mock responses for the `GetRunningServer` RPC method
var (
JupyterPort = 1234
JupyterServerUrl = "http://localhost:1234?token=1234"
JupyterMessage = ""
JupyterResult = true
)
// Mock responses for the `RebuildContainerAsync` RPC method
var (
RebuildContainer = true
)
// Mock responses for the `NotifyCodespaceOfClientActivity` RPC method
// NotifyMessage is used to store the activity that was sent to the server
var (
NotifyMessage = ""
NotifyResult = true
NotifyReceivedActivity = ""
)
// Mock responses for the `StartRemoteServerAsync` RPC method
var (
SshServerPort = 1234
SshUser = "test"
SshMessage = ""
SshResult = true
)
type server struct {
jupyter.UnimplementedJupyterServerHostServer
codespace.CodespaceHostServer
ssh.SshServerHostServer
}
func (s *server) GetRunningServer(ctx context.Context, in *jupyter.GetRunningServerRequest) (*jupyter.GetRunningServerResponse, error) {
return &jupyter.GetRunningServerResponse{
Port: strconv.Itoa(JupyterPort),
ServerUrl: JupyterServerUrl,
Message: JupyterMessage,
Result: JupyterResult,
}, nil
}
func (s *server) RebuildContainerAsync(ctx context.Context, in *codespace.RebuildContainerRequest) (*codespace.RebuildContainerResponse, error) {
return &codespace.RebuildContainerResponse{
RebuildContainer: RebuildContainer,
}, nil
}
func (s *server) NotifyCodespaceOfClientActivity(ctx context.Context, in *codespace.NotifyCodespaceOfClientActivityRequest) (*codespace.NotifyCodespaceOfClientActivityResponse, error) {
// If there is at least one client activity, set NotifyReceivedActivity to the first one (should be "connected")
if len(in.GetClientActivities()) > 0 {
NotifyReceivedActivity = in.GetClientActivities()[0]
}
return &codespace.NotifyCodespaceOfClientActivityResponse{
Message: NotifyMessage,
Result: NotifyResult,
}, nil
}
func (s *server) StartRemoteServerAsync(ctx context.Context, in *ssh.StartRemoteServerRequest) (*ssh.StartRemoteServerResponse, error) {
return &ssh.StartRemoteServerResponse{
ServerPort: strconv.Itoa(SshServerPort),
User: SshUser,
Message: SshMessage,
Result: SshResult,
}, nil
}
// Starts the mock gRPC server listening on port 50051
func StartServer(ctx context.Context) error {
listener, err := net.Listen("tcp", fmt.Sprintf("127.0.0.1:%d", ServerPort))
if err != nil {
return fmt.Errorf("failed to listen: %v", err)
}
defer listener.Close()
s := grpc.NewServer()
jupyter.RegisterJupyterServerHostServer(s, &server{})
codespace.RegisterCodespaceHostServer(s, &server{})
ssh.RegisterSshServerHostServer(s, &server{})
ch := make(chan error, 1)
go func() {
if err := s.Serve(listener); err != nil {
ch <- fmt.Errorf("failed to serve: %v", err)
}
}()
select {
case <-ctx.Done():
s.Stop()
return ctx.Err()
case err := <-ch:
return err
}
}

View file

@ -29,7 +29,7 @@ func (s *Session) GetKeepAliveReason() string {
}
func (s *Session) StartSharing(ctx context.Context, sessionName string, port int) (liveshare.ChannelID, error) {
conn, err := net.Dial("tcp", fmt.Sprintf("127.0.0.1:%d", ServerPort))
conn, err := net.Dial("tcp", fmt.Sprintf("127.0.0.1:%d", port))
if err != nil {
return liveshare.ChannelID{}, err
}