Fix race conditions in invoker_test (#6905)
This commit is contained in:
parent
fef4195004
commit
f669a10cf9
12 changed files with 640 additions and 194 deletions
|
|
@ -1,6 +1,6 @@
|
|||
// Code generated by protoc-gen-go. DO NOT EDIT.
|
||||
// versions:
|
||||
// protoc-gen-go v1.28.0
|
||||
// protoc-gen-go v1.28.1
|
||||
// protoc v3.21.12
|
||||
// source: codespace/codespace_host_service.v1.proto
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,168 @@
|
|||
// Code generated by moq; DO NOT EDIT.
|
||||
// github.com/matryer/moq
|
||||
|
||||
package codespace
|
||||
|
||||
import (
|
||||
context "context"
|
||||
sync "sync"
|
||||
)
|
||||
|
||||
// Ensure, that CodespaceHostServerMock does implement CodespaceHostServer.
|
||||
// If this is not the case, regenerate this file with moq.
|
||||
var _ CodespaceHostServer = &CodespaceHostServerMock{}
|
||||
|
||||
// CodespaceHostServerMock is a mock implementation of CodespaceHostServer.
|
||||
//
|
||||
// func TestSomethingThatUsesCodespaceHostServer(t *testing.T) {
|
||||
//
|
||||
// // make and configure a mocked CodespaceHostServer
|
||||
// mockedCodespaceHostServer := &CodespaceHostServerMock{
|
||||
// NotifyCodespaceOfClientActivityFunc: func(contextMoqParam context.Context, notifyCodespaceOfClientActivityRequest *NotifyCodespaceOfClientActivityRequest) (*NotifyCodespaceOfClientActivityResponse, error) {
|
||||
// panic("mock out the NotifyCodespaceOfClientActivity method")
|
||||
// },
|
||||
// RebuildContainerAsyncFunc: func(contextMoqParam context.Context, rebuildContainerRequest *RebuildContainerRequest) (*RebuildContainerResponse, error) {
|
||||
// panic("mock out the RebuildContainerAsync method")
|
||||
// },
|
||||
// mustEmbedUnimplementedCodespaceHostServerFunc: func() {
|
||||
// panic("mock out the mustEmbedUnimplementedCodespaceHostServer method")
|
||||
// },
|
||||
// }
|
||||
//
|
||||
// // use mockedCodespaceHostServer in code that requires CodespaceHostServer
|
||||
// // and then make assertions.
|
||||
//
|
||||
// }
|
||||
type CodespaceHostServerMock struct {
|
||||
// NotifyCodespaceOfClientActivityFunc mocks the NotifyCodespaceOfClientActivity method.
|
||||
NotifyCodespaceOfClientActivityFunc func(contextMoqParam context.Context, notifyCodespaceOfClientActivityRequest *NotifyCodespaceOfClientActivityRequest) (*NotifyCodespaceOfClientActivityResponse, error)
|
||||
|
||||
// RebuildContainerAsyncFunc mocks the RebuildContainerAsync method.
|
||||
RebuildContainerAsyncFunc func(contextMoqParam context.Context, rebuildContainerRequest *RebuildContainerRequest) (*RebuildContainerResponse, error)
|
||||
|
||||
// mustEmbedUnimplementedCodespaceHostServerFunc mocks the mustEmbedUnimplementedCodespaceHostServer method.
|
||||
mustEmbedUnimplementedCodespaceHostServerFunc func()
|
||||
|
||||
// calls tracks calls to the methods.
|
||||
calls struct {
|
||||
// NotifyCodespaceOfClientActivity holds details about calls to the NotifyCodespaceOfClientActivity method.
|
||||
NotifyCodespaceOfClientActivity []struct {
|
||||
// ContextMoqParam is the contextMoqParam argument value.
|
||||
ContextMoqParam context.Context
|
||||
// NotifyCodespaceOfClientActivityRequest is the notifyCodespaceOfClientActivityRequest argument value.
|
||||
NotifyCodespaceOfClientActivityRequest *NotifyCodespaceOfClientActivityRequest
|
||||
}
|
||||
// RebuildContainerAsync holds details about calls to the RebuildContainerAsync method.
|
||||
RebuildContainerAsync []struct {
|
||||
// ContextMoqParam is the contextMoqParam argument value.
|
||||
ContextMoqParam context.Context
|
||||
// RebuildContainerRequest is the rebuildContainerRequest argument value.
|
||||
RebuildContainerRequest *RebuildContainerRequest
|
||||
}
|
||||
// mustEmbedUnimplementedCodespaceHostServer holds details about calls to the mustEmbedUnimplementedCodespaceHostServer method.
|
||||
mustEmbedUnimplementedCodespaceHostServer []struct {
|
||||
}
|
||||
}
|
||||
lockNotifyCodespaceOfClientActivity sync.RWMutex
|
||||
lockRebuildContainerAsync sync.RWMutex
|
||||
lockmustEmbedUnimplementedCodespaceHostServer sync.RWMutex
|
||||
}
|
||||
|
||||
// NotifyCodespaceOfClientActivity calls NotifyCodespaceOfClientActivityFunc.
|
||||
func (mock *CodespaceHostServerMock) NotifyCodespaceOfClientActivity(contextMoqParam context.Context, notifyCodespaceOfClientActivityRequest *NotifyCodespaceOfClientActivityRequest) (*NotifyCodespaceOfClientActivityResponse, error) {
|
||||
if mock.NotifyCodespaceOfClientActivityFunc == nil {
|
||||
panic("CodespaceHostServerMock.NotifyCodespaceOfClientActivityFunc: method is nil but CodespaceHostServer.NotifyCodespaceOfClientActivity was just called")
|
||||
}
|
||||
callInfo := struct {
|
||||
ContextMoqParam context.Context
|
||||
NotifyCodespaceOfClientActivityRequest *NotifyCodespaceOfClientActivityRequest
|
||||
}{
|
||||
ContextMoqParam: contextMoqParam,
|
||||
NotifyCodespaceOfClientActivityRequest: notifyCodespaceOfClientActivityRequest,
|
||||
}
|
||||
mock.lockNotifyCodespaceOfClientActivity.Lock()
|
||||
mock.calls.NotifyCodespaceOfClientActivity = append(mock.calls.NotifyCodespaceOfClientActivity, callInfo)
|
||||
mock.lockNotifyCodespaceOfClientActivity.Unlock()
|
||||
return mock.NotifyCodespaceOfClientActivityFunc(contextMoqParam, notifyCodespaceOfClientActivityRequest)
|
||||
}
|
||||
|
||||
// NotifyCodespaceOfClientActivityCalls gets all the calls that were made to NotifyCodespaceOfClientActivity.
|
||||
// Check the length with:
|
||||
//
|
||||
// len(mockedCodespaceHostServer.NotifyCodespaceOfClientActivityCalls())
|
||||
func (mock *CodespaceHostServerMock) NotifyCodespaceOfClientActivityCalls() []struct {
|
||||
ContextMoqParam context.Context
|
||||
NotifyCodespaceOfClientActivityRequest *NotifyCodespaceOfClientActivityRequest
|
||||
} {
|
||||
var calls []struct {
|
||||
ContextMoqParam context.Context
|
||||
NotifyCodespaceOfClientActivityRequest *NotifyCodespaceOfClientActivityRequest
|
||||
}
|
||||
mock.lockNotifyCodespaceOfClientActivity.RLock()
|
||||
calls = mock.calls.NotifyCodespaceOfClientActivity
|
||||
mock.lockNotifyCodespaceOfClientActivity.RUnlock()
|
||||
return calls
|
||||
}
|
||||
|
||||
// RebuildContainerAsync calls RebuildContainerAsyncFunc.
|
||||
func (mock *CodespaceHostServerMock) RebuildContainerAsync(contextMoqParam context.Context, rebuildContainerRequest *RebuildContainerRequest) (*RebuildContainerResponse, error) {
|
||||
if mock.RebuildContainerAsyncFunc == nil {
|
||||
panic("CodespaceHostServerMock.RebuildContainerAsyncFunc: method is nil but CodespaceHostServer.RebuildContainerAsync was just called")
|
||||
}
|
||||
callInfo := struct {
|
||||
ContextMoqParam context.Context
|
||||
RebuildContainerRequest *RebuildContainerRequest
|
||||
}{
|
||||
ContextMoqParam: contextMoqParam,
|
||||
RebuildContainerRequest: rebuildContainerRequest,
|
||||
}
|
||||
mock.lockRebuildContainerAsync.Lock()
|
||||
mock.calls.RebuildContainerAsync = append(mock.calls.RebuildContainerAsync, callInfo)
|
||||
mock.lockRebuildContainerAsync.Unlock()
|
||||
return mock.RebuildContainerAsyncFunc(contextMoqParam, rebuildContainerRequest)
|
||||
}
|
||||
|
||||
// RebuildContainerAsyncCalls gets all the calls that were made to RebuildContainerAsync.
|
||||
// Check the length with:
|
||||
//
|
||||
// len(mockedCodespaceHostServer.RebuildContainerAsyncCalls())
|
||||
func (mock *CodespaceHostServerMock) RebuildContainerAsyncCalls() []struct {
|
||||
ContextMoqParam context.Context
|
||||
RebuildContainerRequest *RebuildContainerRequest
|
||||
} {
|
||||
var calls []struct {
|
||||
ContextMoqParam context.Context
|
||||
RebuildContainerRequest *RebuildContainerRequest
|
||||
}
|
||||
mock.lockRebuildContainerAsync.RLock()
|
||||
calls = mock.calls.RebuildContainerAsync
|
||||
mock.lockRebuildContainerAsync.RUnlock()
|
||||
return calls
|
||||
}
|
||||
|
||||
// mustEmbedUnimplementedCodespaceHostServer calls mustEmbedUnimplementedCodespaceHostServerFunc.
|
||||
func (mock *CodespaceHostServerMock) mustEmbedUnimplementedCodespaceHostServer() {
|
||||
if mock.mustEmbedUnimplementedCodespaceHostServerFunc == nil {
|
||||
panic("CodespaceHostServerMock.mustEmbedUnimplementedCodespaceHostServerFunc: method is nil but CodespaceHostServer.mustEmbedUnimplementedCodespaceHostServer was just called")
|
||||
}
|
||||
callInfo := struct {
|
||||
}{}
|
||||
mock.lockmustEmbedUnimplementedCodespaceHostServer.Lock()
|
||||
mock.calls.mustEmbedUnimplementedCodespaceHostServer = append(mock.calls.mustEmbedUnimplementedCodespaceHostServer, callInfo)
|
||||
mock.lockmustEmbedUnimplementedCodespaceHostServer.Unlock()
|
||||
mock.mustEmbedUnimplementedCodespaceHostServerFunc()
|
||||
}
|
||||
|
||||
// mustEmbedUnimplementedCodespaceHostServerCalls gets all the calls that were made to mustEmbedUnimplementedCodespaceHostServer.
|
||||
// Check the length with:
|
||||
//
|
||||
// len(mockedCodespaceHostServer.mustEmbedUnimplementedCodespaceHostServerCalls())
|
||||
func (mock *CodespaceHostServerMock) mustEmbedUnimplementedCodespaceHostServerCalls() []struct {
|
||||
} {
|
||||
var calls []struct {
|
||||
}
|
||||
mock.lockmustEmbedUnimplementedCodespaceHostServer.RLock()
|
||||
calls = mock.calls.mustEmbedUnimplementedCodespaceHostServer
|
||||
mock.lockmustEmbedUnimplementedCodespaceHostServer.RUnlock()
|
||||
return calls
|
||||
}
|
||||
|
|
@ -6,7 +6,8 @@ Instructions for generating and adding gRPC protocol buffers.
|
|||
|
||||
1. [Download `protoc`](https://grpc.io/docs/protoc-installation/)
|
||||
2. [Download protocol compiler plugins for Go](https://grpc.io/docs/languages/go/quickstart/)
|
||||
3. Run `./generate.sh` from the `internal/codespaces/grpc` directory
|
||||
3. Install moq: `go install github.com/matryer/moq@latest`
|
||||
4. Run `./generate.sh` from the `internal/codespaces/rpc` directory
|
||||
|
||||
## Add New Protocol Buffers
|
||||
|
||||
|
|
|
|||
|
|
@ -15,14 +15,21 @@ if ! protoc-gen-go-grpc --version; then
|
|||
fi
|
||||
|
||||
function generate {
|
||||
local contract="$1"
|
||||
local dir="$1"
|
||||
local proto="$2"
|
||||
|
||||
local contract="$dir/$proto"
|
||||
|
||||
protoc --go_out=. --go_opt=paths=source_relative --go-grpc_out=. --go-grpc_opt=paths=source_relative "$contract"
|
||||
echo "Generated protocol buffers for $contract"
|
||||
|
||||
services=$(cat "$contract" | grep -Eo "service .+ {" | awk '{print $2 "Server"}')
|
||||
moq -out $contract.mock.go $dir $services
|
||||
echo "Generated mock protocols for $contract"
|
||||
}
|
||||
|
||||
generate jupyter/jupyter_server_host_service.v1.proto
|
||||
generate codespace/codespace_host_service.v1.proto
|
||||
generate ssh/ssh_server_host_service.v1.proto
|
||||
generate jupyter jupyter_server_host_service.v1.proto
|
||||
generate codespace codespace_host_service.v1.proto
|
||||
generate ssh ssh_server_host_service.v1.proto
|
||||
|
||||
echo 'Done!'
|
||||
|
|
|
|||
|
|
@ -130,6 +130,9 @@ func connect(ctx context.Context, session liveshare.LiveshareSession) (Invoker,
|
|||
invoker.codespaceClient = codespace.NewCodespaceHostClient(conn)
|
||||
invoker.sshClient = ssh.NewSshServerHostClient(conn)
|
||||
|
||||
// Send initial connection heartbeat (no need to throw if we fail to get a response from the server)
|
||||
_ = invoker.notifyCodespaceOfClientActivity(ctx, connectedEventName)
|
||||
|
||||
// Start the activity heatbeats
|
||||
go invoker.heartbeat(pfctx, 1*time.Minute)
|
||||
|
||||
|
|
@ -253,9 +256,6 @@ func (i *invoker) heartbeat(ctx context.Context, interval time.Duration) {
|
|||
ticker := time.NewTicker(interval)
|
||||
defer ticker.Stop()
|
||||
|
||||
// Send initial connection heartbeat (no need to throw if we fail to get a response from the server)
|
||||
_ = i.notifyCodespaceOfClientActivity(ctx, connectedEventName)
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
|
|
|
|||
|
|
@ -3,84 +3,159 @@ package rpc
|
|||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"net"
|
||||
"strconv"
|
||||
"testing"
|
||||
|
||||
"github.com/cli/cli/v2/internal/codespaces/rpc/codespace"
|
||||
"github.com/cli/cli/v2/internal/codespaces/rpc/jupyter"
|
||||
"github.com/cli/cli/v2/internal/codespaces/rpc/ssh"
|
||||
rpctest "github.com/cli/cli/v2/internal/codespaces/rpc/test"
|
||||
"google.golang.org/grpc"
|
||||
)
|
||||
|
||||
func startServer(t *testing.T) {
|
||||
t.Helper()
|
||||
if os.Getenv("GITHUB_ACTIONS") == "true" {
|
||||
t.Skip("fails intermittently in CI: https://github.com/cli/cli/issues/5663")
|
||||
type mockServer struct {
|
||||
jupyter.JupyterServerHostServerMock
|
||||
codespace.CodespaceHostServerMock
|
||||
ssh.SshServerHostServerMock
|
||||
}
|
||||
|
||||
func newMockServer() *mockServer {
|
||||
server := &mockServer{}
|
||||
|
||||
server.CodespaceHostServerMock.NotifyCodespaceOfClientActivityFunc = func(context.Context, *codespace.NotifyCodespaceOfClientActivityRequest) (*codespace.NotifyCodespaceOfClientActivityResponse, error) {
|
||||
return &codespace.NotifyCodespaceOfClientActivityResponse{
|
||||
Message: "",
|
||||
Result: true,
|
||||
}, nil
|
||||
}
|
||||
|
||||
return server
|
||||
}
|
||||
|
||||
// runTestGrpcServer serves grpc requests over the provided Listener using the mockServer for mocked callbacks.
|
||||
// It does not return until the Context is cancelled and the server fully shuts down.
|
||||
func runTestGrpcServer(ctx context.Context, listener net.Listener, server *mockServer) error {
|
||||
s := grpc.NewServer()
|
||||
jupyter.RegisterJupyterServerHostServer(s, server)
|
||||
codespace.RegisterCodespaceHostServer(s, server)
|
||||
ssh.RegisterSshServerHostServer(s, server)
|
||||
|
||||
ch := make(chan error, 1)
|
||||
go func() { ch <- s.Serve(listener) }()
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
s.Stop()
|
||||
<-ch
|
||||
return nil
|
||||
case err := <-ch:
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// createTestInvoker is the main test setup function. It returns an Invoker using the provided mockServer, as well as a shutdown function.
|
||||
// The Invoker does not need to be closed directly, that will be handled by the shutdown function.
|
||||
func createTestInvoker(t *testing.T, server *mockServer) (Invoker, func(), error) {
|
||||
listener, err := net.Listen("tcp", "127.0.0.1:16634")
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("failed to listen: %w", err)
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
ch := make(chan error)
|
||||
go func() { ch <- runTestGrpcServer(ctx, listener, server) }()
|
||||
|
||||
// Start the gRPC server in the background
|
||||
go func() {
|
||||
err := rpctest.StartServer(ctx)
|
||||
if err != nil && err != context.Canceled {
|
||||
log.Println(fmt.Errorf("error starting test server: %v", err))
|
||||
}
|
||||
}()
|
||||
|
||||
// Stop the gRPC server when the test is done
|
||||
t.Cleanup(func() {
|
||||
close := func() {
|
||||
cancel()
|
||||
})
|
||||
}
|
||||
|
||||
func createTestInvoker(t *testing.T) Invoker {
|
||||
t.Helper()
|
||||
|
||||
// Clear the stored client activity
|
||||
rpctest.NotifyReceivedActivity = ""
|
||||
<-ch
|
||||
listener.Close()
|
||||
}
|
||||
|
||||
invoker, err := CreateInvoker(context.Background(), &rpctest.Session{})
|
||||
if err != nil {
|
||||
t.Fatalf("error connecting to internal server: %v", err)
|
||||
close()
|
||||
return nil, nil, fmt.Errorf("error connecting to internal server: %w", err)
|
||||
}
|
||||
|
||||
t.Cleanup(func() {
|
||||
testNotifyCodespaceOfClientActivity(t)
|
||||
return invoker, func() {
|
||||
invoker.Close()
|
||||
})
|
||||
|
||||
return invoker
|
||||
close()
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Test that the RPC invoker notifies the codespace of client activity on connection
|
||||
func testNotifyCodespaceOfClientActivity(t *testing.T) {
|
||||
if rpctest.NotifyReceivedActivity != connectedEventName {
|
||||
t.Fatalf("expected %s, got %s", connectedEventName, rpctest.NotifyMessage)
|
||||
func verifyNotifyCodespaceOfClientActivity(t *testing.T, server *mockServer) {
|
||||
calls := server.CodespaceHostServerMock.NotifyCodespaceOfClientActivityCalls()
|
||||
if len(calls) == 0 {
|
||||
t.Fatalf("no client activity calls")
|
||||
}
|
||||
|
||||
for _, call := range calls {
|
||||
activities := call.NotifyCodespaceOfClientActivityRequest.GetClientActivities()
|
||||
if activities[0] == connectedEventName {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
t.Fatalf("no activity named %s", connectedEventName)
|
||||
}
|
||||
|
||||
// Test that the RPC invoker returns the correct port and URL when the JupyterLab server starts successfully
|
||||
func TestStartJupyterServerSuccess(t *testing.T) {
|
||||
startServer(t)
|
||||
invoker := createTestInvoker(t)
|
||||
resp := jupyter.GetRunningServerResponse{
|
||||
Port: strconv.Itoa(1234),
|
||||
ServerUrl: "http://localhost:1234?token=1234",
|
||||
Message: "",
|
||||
Result: true,
|
||||
}
|
||||
|
||||
server := newMockServer()
|
||||
server.JupyterServerHostServerMock.GetRunningServerFunc = func(context.Context, *jupyter.GetRunningServerRequest) (*jupyter.GetRunningServerResponse, error) {
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
invoker, stop, err := createTestInvoker(t, server)
|
||||
if err != nil {
|
||||
t.Fatalf("error connecting to internal server: %v", err)
|
||||
}
|
||||
defer stop()
|
||||
|
||||
port, url, err := invoker.StartJupyterServer(context.Background())
|
||||
if err != nil {
|
||||
t.Fatalf("expected %v, got %v", nil, err)
|
||||
}
|
||||
if port != rpctest.JupyterPort {
|
||||
t.Fatalf("expected %d, got %d", rpctest.JupyterPort, port)
|
||||
if strconv.Itoa(port) != resp.Port {
|
||||
t.Fatalf("expected %s, got %d", resp.Port, port)
|
||||
}
|
||||
if url != rpctest.JupyterServerUrl {
|
||||
t.Fatalf("expected %s, got %s", rpctest.JupyterServerUrl, url)
|
||||
if url != resp.ServerUrl {
|
||||
t.Fatalf("expected %s, got %s", resp.ServerUrl, url)
|
||||
}
|
||||
|
||||
verifyNotifyCodespaceOfClientActivity(t, server)
|
||||
}
|
||||
|
||||
// Test that the RPC invoker returns an error when the JupyterLab server fails to start
|
||||
func TestStartJupyterServerFailure(t *testing.T) {
|
||||
startServer(t)
|
||||
invoker := createTestInvoker(t)
|
||||
rpctest.JupyterMessage = "error message"
|
||||
rpctest.JupyterResult = false
|
||||
errorMessage := fmt.Sprintf("failed to start JupyterLab: %s", rpctest.JupyterMessage)
|
||||
resp := jupyter.GetRunningServerResponse{
|
||||
Port: strconv.Itoa(1234),
|
||||
ServerUrl: "http://localhost:1234?token=1234",
|
||||
Message: "error message",
|
||||
Result: false,
|
||||
}
|
||||
|
||||
server := newMockServer()
|
||||
server.JupyterServerHostServerMock.GetRunningServerFunc = func(context.Context, *jupyter.GetRunningServerRequest) (*jupyter.GetRunningServerResponse, error) {
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
invoker, stop, err := createTestInvoker(t, server)
|
||||
if err != nil {
|
||||
t.Fatalf("error connecting to internal server: %v", err)
|
||||
}
|
||||
defer stop()
|
||||
|
||||
errorMessage := fmt.Sprintf("failed to start JupyterLab: %s", resp.Message)
|
||||
port, url, err := invoker.StartJupyterServer(context.Background())
|
||||
if err.Error() != errorMessage {
|
||||
t.Fatalf("expected %v, got %v", errorMessage, err)
|
||||
|
|
@ -91,35 +166,79 @@ func TestStartJupyterServerFailure(t *testing.T) {
|
|||
if url != "" {
|
||||
t.Fatalf("expected %s, got %s", "", url)
|
||||
}
|
||||
|
||||
verifyNotifyCodespaceOfClientActivity(t, server)
|
||||
}
|
||||
|
||||
// Test that the RPC invoker doesn't throw an error when requesting an incremental rebuild
|
||||
func TestRebuildContainerIncremental(t *testing.T) {
|
||||
startServer(t)
|
||||
invoker := createTestInvoker(t)
|
||||
err := invoker.RebuildContainer(context.Background(), false)
|
||||
resp := codespace.RebuildContainerResponse{
|
||||
RebuildContainer: true,
|
||||
}
|
||||
|
||||
server := newMockServer()
|
||||
server.RebuildContainerAsyncFunc = func(context.Context, *codespace.RebuildContainerRequest) (*codespace.RebuildContainerResponse, error) {
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
invoker, stop, err := createTestInvoker(t, server)
|
||||
if err != nil {
|
||||
t.Fatalf("error connecting to internal server: %v", err)
|
||||
}
|
||||
defer stop()
|
||||
|
||||
err = invoker.RebuildContainer(context.Background(), false)
|
||||
if err != nil {
|
||||
t.Fatalf("expected %v, got %v", nil, err)
|
||||
}
|
||||
|
||||
verifyNotifyCodespaceOfClientActivity(t, server)
|
||||
}
|
||||
|
||||
// Test that the RPC invoker doesn't throw an error when requesting a full rebuild
|
||||
func TestRebuildContainerFull(t *testing.T) {
|
||||
startServer(t)
|
||||
invoker := createTestInvoker(t)
|
||||
err := invoker.RebuildContainer(context.Background(), true)
|
||||
resp := codespace.RebuildContainerResponse{
|
||||
RebuildContainer: true,
|
||||
}
|
||||
|
||||
server := newMockServer()
|
||||
server.RebuildContainerAsyncFunc = func(context.Context, *codespace.RebuildContainerRequest) (*codespace.RebuildContainerResponse, error) {
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
invoker, stop, err := createTestInvoker(t, server)
|
||||
if err != nil {
|
||||
t.Fatalf("error connecting to internal server: %v", err)
|
||||
}
|
||||
defer stop()
|
||||
|
||||
err = invoker.RebuildContainer(context.Background(), true)
|
||||
if err != nil {
|
||||
t.Fatalf("expected %v, got %v", nil, err)
|
||||
}
|
||||
|
||||
verifyNotifyCodespaceOfClientActivity(t, server)
|
||||
}
|
||||
|
||||
// Test that the RPC invoker throws an error when the rebuild fails
|
||||
func TestRebuildContainerFailure(t *testing.T) {
|
||||
startServer(t)
|
||||
invoker := createTestInvoker(t)
|
||||
rpctest.RebuildContainer = false
|
||||
resp := codespace.RebuildContainerResponse{
|
||||
RebuildContainer: false,
|
||||
}
|
||||
|
||||
server := newMockServer()
|
||||
server.RebuildContainerAsyncFunc = func(context.Context, *codespace.RebuildContainerRequest) (*codespace.RebuildContainerResponse, error) {
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
invoker, stop, err := createTestInvoker(t, server)
|
||||
if err != nil {
|
||||
t.Fatalf("error connecting to internal server: %v", err)
|
||||
}
|
||||
defer stop()
|
||||
|
||||
errorMessage := "couldn't rebuild codespace"
|
||||
err := invoker.RebuildContainer(context.Background(), true)
|
||||
err = invoker.RebuildContainer(context.Background(), true)
|
||||
if err.Error() != errorMessage {
|
||||
t.Fatalf("expected %v, got %v", errorMessage, err)
|
||||
}
|
||||
|
|
@ -127,27 +246,59 @@ func TestRebuildContainerFailure(t *testing.T) {
|
|||
|
||||
// Test that the RPC invoker returns the correct port and user when the SSH server starts successfully
|
||||
func TestStartSSHServerSuccess(t *testing.T) {
|
||||
startServer(t)
|
||||
invoker := createTestInvoker(t)
|
||||
resp := ssh.StartRemoteServerResponse{
|
||||
ServerPort: strconv.Itoa(1234),
|
||||
User: "test",
|
||||
Message: "",
|
||||
Result: true,
|
||||
}
|
||||
|
||||
server := newMockServer()
|
||||
server.StartRemoteServerAsyncFunc = func(context.Context, *ssh.StartRemoteServerRequest) (*ssh.StartRemoteServerResponse, error) {
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
invoker, stop, err := createTestInvoker(t, server)
|
||||
if err != nil {
|
||||
t.Fatalf("error connecting to internal server: %v", err)
|
||||
}
|
||||
defer stop()
|
||||
|
||||
port, user, err := invoker.StartSSHServer(context.Background())
|
||||
if err != nil {
|
||||
t.Fatalf("expected %v, got %v", nil, err)
|
||||
}
|
||||
if port != rpctest.SshServerPort {
|
||||
t.Fatalf("expected %d, got %d", rpctest.SshServerPort, port)
|
||||
if strconv.Itoa(port) != resp.ServerPort {
|
||||
t.Fatalf("expected %s, got %d", resp.ServerPort, port)
|
||||
}
|
||||
if user != rpctest.SshUser {
|
||||
t.Fatalf("expected %s, got %s", rpctest.SshUser, user)
|
||||
if user != resp.User {
|
||||
t.Fatalf("expected %s, got %s", resp.User, user)
|
||||
}
|
||||
|
||||
verifyNotifyCodespaceOfClientActivity(t, server)
|
||||
}
|
||||
|
||||
// Test that the RPC invoker returns an error when the SSH server fails to start
|
||||
func TestStartSSHServerFailure(t *testing.T) {
|
||||
startServer(t)
|
||||
invoker := createTestInvoker(t)
|
||||
rpctest.SshMessage = "error message"
|
||||
rpctest.SshResult = false
|
||||
errorMessage := fmt.Sprintf("failed to start SSH server: %s", rpctest.SshMessage)
|
||||
resp := ssh.StartRemoteServerResponse{
|
||||
ServerPort: strconv.Itoa(1234),
|
||||
User: "test",
|
||||
Message: "error message",
|
||||
Result: false,
|
||||
}
|
||||
|
||||
server := newMockServer()
|
||||
server.StartRemoteServerAsyncFunc = func(context.Context, *ssh.StartRemoteServerRequest) (*ssh.StartRemoteServerResponse, error) {
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
invoker, stop, err := createTestInvoker(t, server)
|
||||
if err != nil {
|
||||
t.Fatalf("error connecting to internal server: %v", err)
|
||||
}
|
||||
defer stop()
|
||||
|
||||
errorMessage := fmt.Sprintf("failed to start SSH server: %s", resp.Message)
|
||||
port, user, err := invoker.StartSSHServer(context.Background())
|
||||
if err.Error() != errorMessage {
|
||||
t.Fatalf("expected %v, got %v", errorMessage, err)
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
// Code generated by protoc-gen-go. DO NOT EDIT.
|
||||
// versions:
|
||||
// protoc-gen-go v1.28.0
|
||||
// protoc-gen-go v1.28.1
|
||||
// protoc v3.21.12
|
||||
// source: jupyter/jupyter_server_host_service.v1.proto
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,118 @@
|
|||
// Code generated by moq; DO NOT EDIT.
|
||||
// github.com/matryer/moq
|
||||
|
||||
package jupyter
|
||||
|
||||
import (
|
||||
context "context"
|
||||
sync "sync"
|
||||
)
|
||||
|
||||
// Ensure, that JupyterServerHostServerMock does implement JupyterServerHostServer.
|
||||
// If this is not the case, regenerate this file with moq.
|
||||
var _ JupyterServerHostServer = &JupyterServerHostServerMock{}
|
||||
|
||||
// JupyterServerHostServerMock is a mock implementation of JupyterServerHostServer.
|
||||
//
|
||||
// func TestSomethingThatUsesJupyterServerHostServer(t *testing.T) {
|
||||
//
|
||||
// // make and configure a mocked JupyterServerHostServer
|
||||
// mockedJupyterServerHostServer := &JupyterServerHostServerMock{
|
||||
// GetRunningServerFunc: func(contextMoqParam context.Context, getRunningServerRequest *GetRunningServerRequest) (*GetRunningServerResponse, error) {
|
||||
// panic("mock out the GetRunningServer method")
|
||||
// },
|
||||
// mustEmbedUnimplementedJupyterServerHostServerFunc: func() {
|
||||
// panic("mock out the mustEmbedUnimplementedJupyterServerHostServer method")
|
||||
// },
|
||||
// }
|
||||
//
|
||||
// // use mockedJupyterServerHostServer in code that requires JupyterServerHostServer
|
||||
// // and then make assertions.
|
||||
//
|
||||
// }
|
||||
type JupyterServerHostServerMock struct {
|
||||
// GetRunningServerFunc mocks the GetRunningServer method.
|
||||
GetRunningServerFunc func(contextMoqParam context.Context, getRunningServerRequest *GetRunningServerRequest) (*GetRunningServerResponse, error)
|
||||
|
||||
// mustEmbedUnimplementedJupyterServerHostServerFunc mocks the mustEmbedUnimplementedJupyterServerHostServer method.
|
||||
mustEmbedUnimplementedJupyterServerHostServerFunc func()
|
||||
|
||||
// calls tracks calls to the methods.
|
||||
calls struct {
|
||||
// GetRunningServer holds details about calls to the GetRunningServer method.
|
||||
GetRunningServer []struct {
|
||||
// ContextMoqParam is the contextMoqParam argument value.
|
||||
ContextMoqParam context.Context
|
||||
// GetRunningServerRequest is the getRunningServerRequest argument value.
|
||||
GetRunningServerRequest *GetRunningServerRequest
|
||||
}
|
||||
// mustEmbedUnimplementedJupyterServerHostServer holds details about calls to the mustEmbedUnimplementedJupyterServerHostServer method.
|
||||
mustEmbedUnimplementedJupyterServerHostServer []struct {
|
||||
}
|
||||
}
|
||||
lockGetRunningServer sync.RWMutex
|
||||
lockmustEmbedUnimplementedJupyterServerHostServer sync.RWMutex
|
||||
}
|
||||
|
||||
// GetRunningServer calls GetRunningServerFunc.
|
||||
func (mock *JupyterServerHostServerMock) GetRunningServer(contextMoqParam context.Context, getRunningServerRequest *GetRunningServerRequest) (*GetRunningServerResponse, error) {
|
||||
if mock.GetRunningServerFunc == nil {
|
||||
panic("JupyterServerHostServerMock.GetRunningServerFunc: method is nil but JupyterServerHostServer.GetRunningServer was just called")
|
||||
}
|
||||
callInfo := struct {
|
||||
ContextMoqParam context.Context
|
||||
GetRunningServerRequest *GetRunningServerRequest
|
||||
}{
|
||||
ContextMoqParam: contextMoqParam,
|
||||
GetRunningServerRequest: getRunningServerRequest,
|
||||
}
|
||||
mock.lockGetRunningServer.Lock()
|
||||
mock.calls.GetRunningServer = append(mock.calls.GetRunningServer, callInfo)
|
||||
mock.lockGetRunningServer.Unlock()
|
||||
return mock.GetRunningServerFunc(contextMoqParam, getRunningServerRequest)
|
||||
}
|
||||
|
||||
// GetRunningServerCalls gets all the calls that were made to GetRunningServer.
|
||||
// Check the length with:
|
||||
//
|
||||
// len(mockedJupyterServerHostServer.GetRunningServerCalls())
|
||||
func (mock *JupyterServerHostServerMock) GetRunningServerCalls() []struct {
|
||||
ContextMoqParam context.Context
|
||||
GetRunningServerRequest *GetRunningServerRequest
|
||||
} {
|
||||
var calls []struct {
|
||||
ContextMoqParam context.Context
|
||||
GetRunningServerRequest *GetRunningServerRequest
|
||||
}
|
||||
mock.lockGetRunningServer.RLock()
|
||||
calls = mock.calls.GetRunningServer
|
||||
mock.lockGetRunningServer.RUnlock()
|
||||
return calls
|
||||
}
|
||||
|
||||
// mustEmbedUnimplementedJupyterServerHostServer calls mustEmbedUnimplementedJupyterServerHostServerFunc.
|
||||
func (mock *JupyterServerHostServerMock) mustEmbedUnimplementedJupyterServerHostServer() {
|
||||
if mock.mustEmbedUnimplementedJupyterServerHostServerFunc == nil {
|
||||
panic("JupyterServerHostServerMock.mustEmbedUnimplementedJupyterServerHostServerFunc: method is nil but JupyterServerHostServer.mustEmbedUnimplementedJupyterServerHostServer was just called")
|
||||
}
|
||||
callInfo := struct {
|
||||
}{}
|
||||
mock.lockmustEmbedUnimplementedJupyterServerHostServer.Lock()
|
||||
mock.calls.mustEmbedUnimplementedJupyterServerHostServer = append(mock.calls.mustEmbedUnimplementedJupyterServerHostServer, callInfo)
|
||||
mock.lockmustEmbedUnimplementedJupyterServerHostServer.Unlock()
|
||||
mock.mustEmbedUnimplementedJupyterServerHostServerFunc()
|
||||
}
|
||||
|
||||
// mustEmbedUnimplementedJupyterServerHostServerCalls gets all the calls that were made to mustEmbedUnimplementedJupyterServerHostServer.
|
||||
// Check the length with:
|
||||
//
|
||||
// len(mockedJupyterServerHostServer.mustEmbedUnimplementedJupyterServerHostServerCalls())
|
||||
func (mock *JupyterServerHostServerMock) mustEmbedUnimplementedJupyterServerHostServerCalls() []struct {
|
||||
} {
|
||||
var calls []struct {
|
||||
}
|
||||
mock.lockmustEmbedUnimplementedJupyterServerHostServer.RLock()
|
||||
calls = mock.calls.mustEmbedUnimplementedJupyterServerHostServer
|
||||
mock.lockmustEmbedUnimplementedJupyterServerHostServer.RUnlock()
|
||||
return calls
|
||||
}
|
||||
|
|
@ -1,6 +1,6 @@
|
|||
// Code generated by protoc-gen-go. DO NOT EDIT.
|
||||
// versions:
|
||||
// protoc-gen-go v1.28.0
|
||||
// protoc-gen-go v1.28.1
|
||||
// protoc v3.21.12
|
||||
// source: ssh/ssh_server_host_service.v1.proto
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,118 @@
|
|||
// Code generated by moq; DO NOT EDIT.
|
||||
// github.com/matryer/moq
|
||||
|
||||
package ssh
|
||||
|
||||
import (
|
||||
context "context"
|
||||
sync "sync"
|
||||
)
|
||||
|
||||
// Ensure, that SshServerHostServerMock does implement SshServerHostServer.
|
||||
// If this is not the case, regenerate this file with moq.
|
||||
var _ SshServerHostServer = &SshServerHostServerMock{}
|
||||
|
||||
// SshServerHostServerMock is a mock implementation of SshServerHostServer.
|
||||
//
|
||||
// func TestSomethingThatUsesSshServerHostServer(t *testing.T) {
|
||||
//
|
||||
// // make and configure a mocked SshServerHostServer
|
||||
// mockedSshServerHostServer := &SshServerHostServerMock{
|
||||
// StartRemoteServerAsyncFunc: func(contextMoqParam context.Context, startRemoteServerRequest *StartRemoteServerRequest) (*StartRemoteServerResponse, error) {
|
||||
// panic("mock out the StartRemoteServerAsync method")
|
||||
// },
|
||||
// mustEmbedUnimplementedSshServerHostServerFunc: func() {
|
||||
// panic("mock out the mustEmbedUnimplementedSshServerHostServer method")
|
||||
// },
|
||||
// }
|
||||
//
|
||||
// // use mockedSshServerHostServer in code that requires SshServerHostServer
|
||||
// // and then make assertions.
|
||||
//
|
||||
// }
|
||||
type SshServerHostServerMock struct {
|
||||
// StartRemoteServerAsyncFunc mocks the StartRemoteServerAsync method.
|
||||
StartRemoteServerAsyncFunc func(contextMoqParam context.Context, startRemoteServerRequest *StartRemoteServerRequest) (*StartRemoteServerResponse, error)
|
||||
|
||||
// mustEmbedUnimplementedSshServerHostServerFunc mocks the mustEmbedUnimplementedSshServerHostServer method.
|
||||
mustEmbedUnimplementedSshServerHostServerFunc func()
|
||||
|
||||
// calls tracks calls to the methods.
|
||||
calls struct {
|
||||
// StartRemoteServerAsync holds details about calls to the StartRemoteServerAsync method.
|
||||
StartRemoteServerAsync []struct {
|
||||
// ContextMoqParam is the contextMoqParam argument value.
|
||||
ContextMoqParam context.Context
|
||||
// StartRemoteServerRequest is the startRemoteServerRequest argument value.
|
||||
StartRemoteServerRequest *StartRemoteServerRequest
|
||||
}
|
||||
// mustEmbedUnimplementedSshServerHostServer holds details about calls to the mustEmbedUnimplementedSshServerHostServer method.
|
||||
mustEmbedUnimplementedSshServerHostServer []struct {
|
||||
}
|
||||
}
|
||||
lockStartRemoteServerAsync sync.RWMutex
|
||||
lockmustEmbedUnimplementedSshServerHostServer sync.RWMutex
|
||||
}
|
||||
|
||||
// StartRemoteServerAsync calls StartRemoteServerAsyncFunc.
|
||||
func (mock *SshServerHostServerMock) StartRemoteServerAsync(contextMoqParam context.Context, startRemoteServerRequest *StartRemoteServerRequest) (*StartRemoteServerResponse, error) {
|
||||
if mock.StartRemoteServerAsyncFunc == nil {
|
||||
panic("SshServerHostServerMock.StartRemoteServerAsyncFunc: method is nil but SshServerHostServer.StartRemoteServerAsync was just called")
|
||||
}
|
||||
callInfo := struct {
|
||||
ContextMoqParam context.Context
|
||||
StartRemoteServerRequest *StartRemoteServerRequest
|
||||
}{
|
||||
ContextMoqParam: contextMoqParam,
|
||||
StartRemoteServerRequest: startRemoteServerRequest,
|
||||
}
|
||||
mock.lockStartRemoteServerAsync.Lock()
|
||||
mock.calls.StartRemoteServerAsync = append(mock.calls.StartRemoteServerAsync, callInfo)
|
||||
mock.lockStartRemoteServerAsync.Unlock()
|
||||
return mock.StartRemoteServerAsyncFunc(contextMoqParam, startRemoteServerRequest)
|
||||
}
|
||||
|
||||
// StartRemoteServerAsyncCalls gets all the calls that were made to StartRemoteServerAsync.
|
||||
// Check the length with:
|
||||
//
|
||||
// len(mockedSshServerHostServer.StartRemoteServerAsyncCalls())
|
||||
func (mock *SshServerHostServerMock) StartRemoteServerAsyncCalls() []struct {
|
||||
ContextMoqParam context.Context
|
||||
StartRemoteServerRequest *StartRemoteServerRequest
|
||||
} {
|
||||
var calls []struct {
|
||||
ContextMoqParam context.Context
|
||||
StartRemoteServerRequest *StartRemoteServerRequest
|
||||
}
|
||||
mock.lockStartRemoteServerAsync.RLock()
|
||||
calls = mock.calls.StartRemoteServerAsync
|
||||
mock.lockStartRemoteServerAsync.RUnlock()
|
||||
return calls
|
||||
}
|
||||
|
||||
// mustEmbedUnimplementedSshServerHostServer calls mustEmbedUnimplementedSshServerHostServerFunc.
|
||||
func (mock *SshServerHostServerMock) mustEmbedUnimplementedSshServerHostServer() {
|
||||
if mock.mustEmbedUnimplementedSshServerHostServerFunc == nil {
|
||||
panic("SshServerHostServerMock.mustEmbedUnimplementedSshServerHostServerFunc: method is nil but SshServerHostServer.mustEmbedUnimplementedSshServerHostServer was just called")
|
||||
}
|
||||
callInfo := struct {
|
||||
}{}
|
||||
mock.lockmustEmbedUnimplementedSshServerHostServer.Lock()
|
||||
mock.calls.mustEmbedUnimplementedSshServerHostServer = append(mock.calls.mustEmbedUnimplementedSshServerHostServer, callInfo)
|
||||
mock.lockmustEmbedUnimplementedSshServerHostServer.Unlock()
|
||||
mock.mustEmbedUnimplementedSshServerHostServerFunc()
|
||||
}
|
||||
|
||||
// mustEmbedUnimplementedSshServerHostServerCalls gets all the calls that were made to mustEmbedUnimplementedSshServerHostServer.
|
||||
// Check the length with:
|
||||
//
|
||||
// len(mockedSshServerHostServer.mustEmbedUnimplementedSshServerHostServerCalls())
|
||||
func (mock *SshServerHostServerMock) mustEmbedUnimplementedSshServerHostServerCalls() []struct {
|
||||
} {
|
||||
var calls []struct {
|
||||
}
|
||||
mock.lockmustEmbedUnimplementedSshServerHostServer.RLock()
|
||||
calls = mock.calls.mustEmbedUnimplementedSshServerHostServer
|
||||
mock.lockmustEmbedUnimplementedSshServerHostServer.RUnlock()
|
||||
return calls
|
||||
}
|
||||
|
|
@ -1,117 +0,0 @@
|
|||
package test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"strconv"
|
||||
|
||||
"github.com/cli/cli/v2/internal/codespaces/rpc/codespace"
|
||||
"github.com/cli/cli/v2/internal/codespaces/rpc/jupyter"
|
||||
"github.com/cli/cli/v2/internal/codespaces/rpc/ssh"
|
||||
"google.golang.org/grpc"
|
||||
)
|
||||
|
||||
const (
|
||||
ServerPort = 50051
|
||||
)
|
||||
|
||||
// Mock responses for the `GetRunningServer` RPC method
|
||||
var (
|
||||
JupyterPort = 1234
|
||||
JupyterServerUrl = "http://localhost:1234?token=1234"
|
||||
JupyterMessage = ""
|
||||
JupyterResult = true
|
||||
)
|
||||
|
||||
// Mock responses for the `RebuildContainerAsync` RPC method
|
||||
var (
|
||||
RebuildContainer = true
|
||||
)
|
||||
|
||||
// Mock responses for the `NotifyCodespaceOfClientActivity` RPC method
|
||||
// NotifyMessage is used to store the activity that was sent to the server
|
||||
var (
|
||||
NotifyMessage = ""
|
||||
NotifyResult = true
|
||||
NotifyReceivedActivity = ""
|
||||
)
|
||||
|
||||
// Mock responses for the `StartRemoteServerAsync` RPC method
|
||||
var (
|
||||
SshServerPort = 1234
|
||||
SshUser = "test"
|
||||
SshMessage = ""
|
||||
SshResult = true
|
||||
)
|
||||
|
||||
type server struct {
|
||||
jupyter.UnimplementedJupyterServerHostServer
|
||||
codespace.CodespaceHostServer
|
||||
ssh.SshServerHostServer
|
||||
}
|
||||
|
||||
func (s *server) GetRunningServer(ctx context.Context, in *jupyter.GetRunningServerRequest) (*jupyter.GetRunningServerResponse, error) {
|
||||
return &jupyter.GetRunningServerResponse{
|
||||
Port: strconv.Itoa(JupyterPort),
|
||||
ServerUrl: JupyterServerUrl,
|
||||
Message: JupyterMessage,
|
||||
Result: JupyterResult,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *server) RebuildContainerAsync(ctx context.Context, in *codespace.RebuildContainerRequest) (*codespace.RebuildContainerResponse, error) {
|
||||
return &codespace.RebuildContainerResponse{
|
||||
RebuildContainer: RebuildContainer,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *server) NotifyCodespaceOfClientActivity(ctx context.Context, in *codespace.NotifyCodespaceOfClientActivityRequest) (*codespace.NotifyCodespaceOfClientActivityResponse, error) {
|
||||
// If there is at least one client activity, set NotifyReceivedActivity to the first one (should be "connected")
|
||||
if len(in.GetClientActivities()) > 0 {
|
||||
NotifyReceivedActivity = in.GetClientActivities()[0]
|
||||
}
|
||||
|
||||
return &codespace.NotifyCodespaceOfClientActivityResponse{
|
||||
Message: NotifyMessage,
|
||||
Result: NotifyResult,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *server) StartRemoteServerAsync(ctx context.Context, in *ssh.StartRemoteServerRequest) (*ssh.StartRemoteServerResponse, error) {
|
||||
return &ssh.StartRemoteServerResponse{
|
||||
ServerPort: strconv.Itoa(SshServerPort),
|
||||
User: SshUser,
|
||||
Message: SshMessage,
|
||||
Result: SshResult,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Starts the mock gRPC server listening on port 50051
|
||||
func StartServer(ctx context.Context) error {
|
||||
listener, err := net.Listen("tcp", fmt.Sprintf("127.0.0.1:%d", ServerPort))
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to listen: %v", err)
|
||||
}
|
||||
defer listener.Close()
|
||||
|
||||
s := grpc.NewServer()
|
||||
jupyter.RegisterJupyterServerHostServer(s, &server{})
|
||||
codespace.RegisterCodespaceHostServer(s, &server{})
|
||||
ssh.RegisterSshServerHostServer(s, &server{})
|
||||
|
||||
ch := make(chan error, 1)
|
||||
go func() {
|
||||
if err := s.Serve(listener); err != nil {
|
||||
ch <- fmt.Errorf("failed to serve: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
s.Stop()
|
||||
return ctx.Err()
|
||||
case err := <-ch:
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
|
@ -29,7 +29,7 @@ func (s *Session) GetKeepAliveReason() string {
|
|||
}
|
||||
|
||||
func (s *Session) StartSharing(ctx context.Context, sessionName string, port int) (liveshare.ChannelID, error) {
|
||||
conn, err := net.Dial("tcp", fmt.Sprintf("127.0.0.1:%d", ServerPort))
|
||||
conn, err := net.Dial("tcp", fmt.Sprintf("127.0.0.1:%d", port))
|
||||
if err != nil {
|
||||
return liveshare.ChannelID{}, err
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue