* Migrate all Codespaces operations from Live Share to Dev Tunnels * Remove Live Share references * Fix linting errors * Update comments, remove deps, add uint16 bound checks * Fix tests and move keep-alive logic to forwarder * Address comments * Updated mock port forwarder * Fix CodeQL error * Update comment * Update func name * Add missing connection close * Fix linting error * https -> http * Update defer * Fix tests
313 lines
9 KiB
Go
313 lines
9 KiB
Go
package rpc
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"net"
|
|
"strconv"
|
|
"testing"
|
|
|
|
"github.com/cli/cli/v2/internal/codespaces/rpc/codespace"
|
|
"github.com/cli/cli/v2/internal/codespaces/rpc/jupyter"
|
|
"github.com/cli/cli/v2/internal/codespaces/rpc/ssh"
|
|
rpctest "github.com/cli/cli/v2/internal/codespaces/rpc/test"
|
|
"google.golang.org/grpc"
|
|
)
|
|
|
|
type mockServer struct {
|
|
jupyter.JupyterServerHostServerMock
|
|
codespace.CodespaceHostServerMock
|
|
ssh.SshServerHostServerMock
|
|
}
|
|
|
|
func newMockServer() *mockServer {
|
|
server := &mockServer{}
|
|
|
|
server.CodespaceHostServerMock.NotifyCodespaceOfClientActivityFunc = func(context.Context, *codespace.NotifyCodespaceOfClientActivityRequest) (*codespace.NotifyCodespaceOfClientActivityResponse, error) {
|
|
return &codespace.NotifyCodespaceOfClientActivityResponse{
|
|
Message: "",
|
|
Result: true,
|
|
}, nil
|
|
}
|
|
|
|
return server
|
|
}
|
|
|
|
// runTestGrpcServer serves grpc requests over the provided Listener using the mockServer for mocked callbacks.
|
|
// It does not return until the Context is cancelled and the server fully shuts down.
|
|
func runTestGrpcServer(ctx context.Context, listener net.Listener, server *mockServer) error {
|
|
s := grpc.NewServer()
|
|
jupyter.RegisterJupyterServerHostServer(s, server)
|
|
codespace.RegisterCodespaceHostServer(s, server)
|
|
ssh.RegisterSshServerHostServer(s, server)
|
|
|
|
ch := make(chan error, 1)
|
|
go func() { ch <- s.Serve(listener) }()
|
|
|
|
select {
|
|
case <-ctx.Done():
|
|
s.Stop()
|
|
<-ch
|
|
return nil
|
|
case err := <-ch:
|
|
return err
|
|
}
|
|
}
|
|
|
|
// createTestInvoker is the main test setup function. It returns an Invoker using the provided mockServer, as well as a shutdown function.
|
|
// The Invoker does not need to be closed directly, that will be handled by the shutdown function.
|
|
func createTestInvoker(t *testing.T, server *mockServer) (Invoker, func(), error) {
|
|
listener, err := net.Listen("tcp", "127.0.0.1:16634")
|
|
if err != nil {
|
|
return nil, nil, fmt.Errorf("failed to listen: %w", err)
|
|
}
|
|
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
ch := make(chan error)
|
|
go func() { ch <- runTestGrpcServer(ctx, listener, server) }()
|
|
|
|
close := func() {
|
|
cancel()
|
|
<-ch
|
|
listener.Close()
|
|
}
|
|
|
|
// Create a new invoker with a mock port forwarder
|
|
invoker, err := CreateInvoker(context.Background(), rpctest.PortForwarder{})
|
|
if err != nil {
|
|
close()
|
|
return nil, nil, fmt.Errorf("error connecting to internal server: %w", err)
|
|
}
|
|
|
|
return invoker, func() {
|
|
invoker.Close()
|
|
close()
|
|
}, nil
|
|
}
|
|
|
|
// Test that the RPC invoker notifies the codespace of client activity on connection
|
|
func verifyNotifyCodespaceOfClientActivity(t *testing.T, server *mockServer) {
|
|
calls := server.CodespaceHostServerMock.NotifyCodespaceOfClientActivityCalls()
|
|
if len(calls) == 0 {
|
|
t.Fatalf("no client activity calls")
|
|
}
|
|
|
|
for _, call := range calls {
|
|
activities := call.NotifyCodespaceOfClientActivityRequest.GetClientActivities()
|
|
if activities[0] == connectedEventName {
|
|
return
|
|
}
|
|
}
|
|
|
|
t.Fatalf("no activity named %s", connectedEventName)
|
|
}
|
|
|
|
// Test that the RPC invoker returns the correct port and URL when the JupyterLab server starts successfully
|
|
func TestStartJupyterServerSuccess(t *testing.T) {
|
|
resp := jupyter.GetRunningServerResponse{
|
|
Port: strconv.Itoa(1234),
|
|
ServerUrl: "http://localhost:1234?token=1234",
|
|
Message: "",
|
|
Result: true,
|
|
}
|
|
|
|
server := newMockServer()
|
|
server.JupyterServerHostServerMock.GetRunningServerFunc = func(context.Context, *jupyter.GetRunningServerRequest) (*jupyter.GetRunningServerResponse, error) {
|
|
return &resp, nil
|
|
}
|
|
|
|
invoker, stop, err := createTestInvoker(t, server)
|
|
if err != nil {
|
|
t.Fatalf("error connecting to internal server: %v", err)
|
|
}
|
|
defer stop()
|
|
|
|
port, url, err := invoker.StartJupyterServer(context.Background())
|
|
if err != nil {
|
|
t.Fatalf("expected %v, got %v", nil, err)
|
|
}
|
|
if strconv.Itoa(port) != resp.Port {
|
|
t.Fatalf("expected %s, got %d", resp.Port, port)
|
|
}
|
|
if url != resp.ServerUrl {
|
|
t.Fatalf("expected %s, got %s", resp.ServerUrl, url)
|
|
}
|
|
|
|
verifyNotifyCodespaceOfClientActivity(t, server)
|
|
}
|
|
|
|
// Test that the RPC invoker returns an error when the JupyterLab server fails to start
|
|
func TestStartJupyterServerFailure(t *testing.T) {
|
|
resp := jupyter.GetRunningServerResponse{
|
|
Port: strconv.Itoa(1234),
|
|
ServerUrl: "http://localhost:1234?token=1234",
|
|
Message: "error message",
|
|
Result: false,
|
|
}
|
|
|
|
server := newMockServer()
|
|
server.JupyterServerHostServerMock.GetRunningServerFunc = func(context.Context, *jupyter.GetRunningServerRequest) (*jupyter.GetRunningServerResponse, error) {
|
|
return &resp, nil
|
|
}
|
|
|
|
invoker, stop, err := createTestInvoker(t, server)
|
|
if err != nil {
|
|
t.Fatalf("error connecting to internal server: %v", err)
|
|
}
|
|
defer stop()
|
|
|
|
errorMessage := fmt.Sprintf("failed to start JupyterLab: %s", resp.Message)
|
|
port, url, err := invoker.StartJupyterServer(context.Background())
|
|
if err.Error() != errorMessage {
|
|
t.Fatalf("expected %v, got %v", errorMessage, err)
|
|
}
|
|
if port != 0 {
|
|
t.Fatalf("expected %d, got %d", 0, port)
|
|
}
|
|
if url != "" {
|
|
t.Fatalf("expected %s, got %s", "", url)
|
|
}
|
|
|
|
verifyNotifyCodespaceOfClientActivity(t, server)
|
|
}
|
|
|
|
// Test that the RPC invoker doesn't throw an error when requesting an incremental rebuild
|
|
func TestRebuildContainerIncremental(t *testing.T) {
|
|
resp := codespace.RebuildContainerResponse{
|
|
RebuildContainer: true,
|
|
}
|
|
|
|
server := newMockServer()
|
|
server.RebuildContainerAsyncFunc = func(context.Context, *codespace.RebuildContainerRequest) (*codespace.RebuildContainerResponse, error) {
|
|
return &resp, nil
|
|
}
|
|
|
|
invoker, stop, err := createTestInvoker(t, server)
|
|
if err != nil {
|
|
t.Fatalf("error connecting to internal server: %v", err)
|
|
}
|
|
defer stop()
|
|
|
|
err = invoker.RebuildContainer(context.Background(), false)
|
|
if err != nil {
|
|
t.Fatalf("expected %v, got %v", nil, err)
|
|
}
|
|
|
|
verifyNotifyCodespaceOfClientActivity(t, server)
|
|
}
|
|
|
|
// Test that the RPC invoker doesn't throw an error when requesting a full rebuild
|
|
func TestRebuildContainerFull(t *testing.T) {
|
|
resp := codespace.RebuildContainerResponse{
|
|
RebuildContainer: true,
|
|
}
|
|
|
|
server := newMockServer()
|
|
server.RebuildContainerAsyncFunc = func(context.Context, *codespace.RebuildContainerRequest) (*codespace.RebuildContainerResponse, error) {
|
|
return &resp, nil
|
|
}
|
|
|
|
invoker, stop, err := createTestInvoker(t, server)
|
|
if err != nil {
|
|
t.Fatalf("error connecting to internal server: %v", err)
|
|
}
|
|
defer stop()
|
|
|
|
err = invoker.RebuildContainer(context.Background(), true)
|
|
if err != nil {
|
|
t.Fatalf("expected %v, got %v", nil, err)
|
|
}
|
|
|
|
verifyNotifyCodespaceOfClientActivity(t, server)
|
|
}
|
|
|
|
// Test that the RPC invoker throws an error when the rebuild fails
|
|
func TestRebuildContainerFailure(t *testing.T) {
|
|
resp := codespace.RebuildContainerResponse{
|
|
RebuildContainer: false,
|
|
}
|
|
|
|
server := newMockServer()
|
|
server.RebuildContainerAsyncFunc = func(context.Context, *codespace.RebuildContainerRequest) (*codespace.RebuildContainerResponse, error) {
|
|
return &resp, nil
|
|
}
|
|
|
|
invoker, stop, err := createTestInvoker(t, server)
|
|
if err != nil {
|
|
t.Fatalf("error connecting to internal server: %v", err)
|
|
}
|
|
defer stop()
|
|
|
|
errorMessage := "couldn't rebuild codespace"
|
|
err = invoker.RebuildContainer(context.Background(), true)
|
|
if err.Error() != errorMessage {
|
|
t.Fatalf("expected %v, got %v", errorMessage, err)
|
|
}
|
|
}
|
|
|
|
// Test that the RPC invoker returns the correct port and user when the SSH server starts successfully
|
|
func TestStartSSHServerSuccess(t *testing.T) {
|
|
resp := ssh.StartRemoteServerResponse{
|
|
ServerPort: strconv.Itoa(1234),
|
|
User: "test",
|
|
Message: "",
|
|
Result: true,
|
|
}
|
|
|
|
server := newMockServer()
|
|
server.StartRemoteServerAsyncFunc = func(context.Context, *ssh.StartRemoteServerRequest) (*ssh.StartRemoteServerResponse, error) {
|
|
return &resp, nil
|
|
}
|
|
|
|
invoker, stop, err := createTestInvoker(t, server)
|
|
if err != nil {
|
|
t.Fatalf("error connecting to internal server: %v", err)
|
|
}
|
|
defer stop()
|
|
|
|
port, user, err := invoker.StartSSHServer(context.Background())
|
|
if err != nil {
|
|
t.Fatalf("expected %v, got %v", nil, err)
|
|
}
|
|
if strconv.Itoa(port) != resp.ServerPort {
|
|
t.Fatalf("expected %s, got %d", resp.ServerPort, port)
|
|
}
|
|
if user != resp.User {
|
|
t.Fatalf("expected %s, got %s", resp.User, user)
|
|
}
|
|
|
|
verifyNotifyCodespaceOfClientActivity(t, server)
|
|
}
|
|
|
|
// Test that the RPC invoker returns an error when the SSH server fails to start
|
|
func TestStartSSHServerFailure(t *testing.T) {
|
|
resp := ssh.StartRemoteServerResponse{
|
|
ServerPort: strconv.Itoa(1234),
|
|
User: "test",
|
|
Message: "error message",
|
|
Result: false,
|
|
}
|
|
|
|
server := newMockServer()
|
|
server.StartRemoteServerAsyncFunc = func(context.Context, *ssh.StartRemoteServerRequest) (*ssh.StartRemoteServerResponse, error) {
|
|
return &resp, nil
|
|
}
|
|
|
|
invoker, stop, err := createTestInvoker(t, server)
|
|
if err != nil {
|
|
t.Fatalf("error connecting to internal server: %v", err)
|
|
}
|
|
defer stop()
|
|
|
|
errorMessage := fmt.Sprintf("failed to start SSH server: %s", resp.Message)
|
|
port, user, err := invoker.StartSSHServer(context.Background())
|
|
if err.Error() != errorMessage {
|
|
t.Fatalf("expected %v, got %v", errorMessage, err)
|
|
}
|
|
if port != 0 {
|
|
t.Fatalf("expected %d, got %d", 0, port)
|
|
}
|
|
if user != "" {
|
|
t.Fatalf("expected %s, got %s", "", user)
|
|
}
|
|
}
|