diff --git a/internal/codespaces/codespaces.go b/internal/codespaces/codespaces.go index 625b1f596..2dc81ba64 100644 --- a/internal/codespaces/codespaces.go +++ b/internal/codespaces/codespaces.go @@ -8,7 +8,6 @@ import ( "github.com/cenkalti/backoff/v4" "github.com/cli/cli/v2/internal/codespaces/api" - "github.com/cli/cli/v2/internal/codespaces/rpc" "github.com/cli/cli/v2/pkg/liveshare" ) @@ -80,16 +79,3 @@ func ConnectToLiveshare(ctx context.Context, progress progressIndicator, session Logger: sessionLogger, }) } - -// Helper function to connect to the internal RPC server and return an RPC invoker for it -func CreateRPCInvoker(ctx context.Context, session *liveshare.Session, token string) (*rpc.Invoker, error) { - ctx, cancel := context.WithTimeout(ctx, rpc.ConnectionTimeout) - defer cancel() - - invoker, err := rpc.Connect(ctx, session, token) - if err != nil { - return nil, fmt.Errorf("error connecting to internal server: %w", err) - } - - return invoker, nil -} diff --git a/internal/codespaces/rpc/invoker.go b/internal/codespaces/rpc/invoker.go index 2aa2629da..6cb94f2c3 100644 --- a/internal/codespaces/rpc/invoker.go +++ b/internal/codespaces/rpc/invoker.go @@ -12,7 +12,6 @@ import ( "github.com/cli/cli/v2/internal/codespaces/rpc/jupyter" "github.com/cli/cli/v2/pkg/liveshare" - "golang.org/x/crypto/ssh" "google.golang.org/grpc" "google.golang.org/grpc/credentials/insecure" "google.golang.org/grpc/metadata" @@ -28,35 +27,45 @@ const ( codespacesInternalSessionName = "CodespacesInternal" ) -type liveshareSession interface { +type Invoker interface { Close() error - GetSharedServers(context.Context) ([]*liveshare.Port, error) - KeepAlive(string) - OpenStreamingChannel(context.Context, liveshare.ChannelID) (ssh.Channel, error) - StartSharing(context.Context, string, int) (liveshare.ChannelID, error) - StartSSHServer(context.Context) (int, string, error) - StartSSHServerWithOptions(context.Context, liveshare.StartSSHServerOptions) (int, string, error) - RebuildContainer(context.Context, bool) error + StartJupyterServer(ctx context.Context) (int, string, error) + RebuildContainer(ctx context.Context, full bool) error + StartSSHServer(ctx context.Context) (int, string, error) + StartSSHServerWithOptions(ctx context.Context, options liveshare.StartSSHServerOptions) (int, string, error) } -type Invoker struct { +type invoker struct { conn *grpc.ClientConn token string - session liveshareSession + session liveshare.LiveshareSession listener net.Listener jupyterClient jupyter.JupyterServerHostClient cancelPF context.CancelFunc } -// Finds a free port to listen on and creates a new gRPC client that connects to that port -func Connect(ctx context.Context, session liveshareSession, token string) (*Invoker, error) { +// Connects to the internal RPC server and returns a new invoker for it +func CreateInvoker(ctx context.Context, session liveshare.LiveshareSession, token string) (Invoker, error) { + ctx, cancel := context.WithTimeout(ctx, ConnectionTimeout) + defer cancel() + + invoker, err := connect(ctx, session, token) + if err != nil { + return nil, fmt.Errorf("error connecting to internal server: %w", err) + } + + return invoker, nil +} + +// Finds a free port to listen on and creates a new RPC invoker that connects to that port +func connect(ctx context.Context, session liveshare.LiveshareSession, token string) (Invoker, error) { listener, err := net.Listen("tcp", fmt.Sprintf("127.0.0.1:%d", 0)) if err != nil { return nil, fmt.Errorf("failed to listen to local port over tcp: %w", err) } localAddress := fmt.Sprintf("127.0.0.1:%d", listener.Addr().(*net.TCPAddr).Port) - invoker := &Invoker{ + invoker := &invoker{ token: token, session: session, listener: listener, @@ -113,12 +122,12 @@ func Connect(ctx context.Context, session liveshareSession, token string) (*Invo } // Closes the gRPC connection -func (g *Invoker) Close() error { - g.cancelPF() +func (i *invoker) Close() error { + i.cancelPF() // Closing the local listener effectively closes the gRPC connection - if err := g.listener.Close(); err != nil { - g.conn.Close() // If we fail to close the listener, explicitly close the gRPC connection and ignore any error + if err := i.listener.Close(); err != nil { + i.conn.Close() // If we fail to close the listener, explicitly close the gRPC connection and ignore any error return fmt.Errorf("failed to close local tcp port listener: %w", err) } @@ -126,17 +135,17 @@ func (g *Invoker) Close() error { } // Appends the authentication token to the gRPC context -func (g *Invoker) appendMetadata(ctx context.Context) context.Context { - return metadata.AppendToOutgoingContext(ctx, "Authorization", "Bearer "+g.token) +func (i *invoker) appendMetadata(ctx context.Context) context.Context { + return metadata.AppendToOutgoingContext(ctx, "Authorization", "Bearer "+i.token) } // Starts a remote JupyterLab server to allow the user to connect to the codespace via JupyterLab in their browser -func (g *Invoker) StartJupyterServer(ctx context.Context) (port int, serverUrl string, err error) { - ctx = g.appendMetadata(ctx) +func (i *invoker) StartJupyterServer(ctx context.Context) (port int, serverUrl string, err error) { + ctx = i.appendMetadata(ctx) ctx, cancel := context.WithTimeout(ctx, requestTimeout) defer cancel() - response, err := g.jupyterClient.GetRunningServer(ctx, &jupyter.GetRunningServerRequest{}) + response, err := i.jupyterClient.GetRunningServer(ctx, &jupyter.GetRunningServerRequest{}) if err != nil { return 0, "", fmt.Errorf("failed to invoke JupyterLab RPC: %w", err) } @@ -154,16 +163,16 @@ func (g *Invoker) StartJupyterServer(ctx context.Context) (port int, serverUrl s } // Rebuilds the container using cached layers by default or from scratch if full is true -func (g *Invoker) RebuildContainer(ctx context.Context, full bool) error { - return g.session.RebuildContainer(ctx, full) +func (i *invoker) RebuildContainer(ctx context.Context, full bool) error { + return i.session.RebuildContainer(ctx, full) } // Starts a remote SSH server to allow the user to connect to the codespace via SSH -func (g *Invoker) StartSSHServer(ctx context.Context) (int, string, error) { - return g.session.StartSSHServer(ctx) +func (i *invoker) StartSSHServer(ctx context.Context) (int, string, error) { + return i.session.StartSSHServer(ctx) } // Starts a remote SSH server to allow the user to connect to the codespace via SSH -func (g *Invoker) StartSSHServerWithOptions(ctx context.Context, options liveshare.StartSSHServerOptions) (int, string, error) { - return g.session.StartSSHServerWithOptions(ctx, options) +func (i *invoker) StartSSHServerWithOptions(ctx context.Context, options liveshare.StartSSHServerOptions) (int, string, error) { + return i.session.StartSSHServerWithOptions(ctx, options) } diff --git a/internal/codespaces/rpc/invoker_test.go b/internal/codespaces/rpc/invoker_test.go index 8114f20aa..f8f60d1ff 100644 --- a/internal/codespaces/rpc/invoker_test.go +++ b/internal/codespaces/rpc/invoker_test.go @@ -32,10 +32,10 @@ func startServer(t *testing.T) { }) } -func connect(t *testing.T) (invoker *Invoker) { +func createTestInvoker(t *testing.T) Invoker { t.Helper() - invoker, err := Connect(context.Background(), &rpctest.Session{}, "token") + invoker, err := CreateInvoker(context.Background(), &rpctest.Session{}, "token") //connect(context.Background(), &rpctest.Session{}, "token") if err != nil { t.Fatalf("error connecting to internal server: %v", err) } @@ -50,8 +50,7 @@ func connect(t *testing.T) (invoker *Invoker) { // Test that the RPC invoker returns the correct port and URL when the JupyterLab server starts successfully func TestStartJupyterServerSuccess(t *testing.T) { startServer(t) - invoker := connect(t) - + invoker := createTestInvoker(t) port, url, err := invoker.StartJupyterServer(context.Background()) if err != nil { t.Fatalf("expected %v, got %v", nil, err) @@ -67,7 +66,7 @@ func TestStartJupyterServerSuccess(t *testing.T) { // Test that the RPC invoker returns an error when the JupyterLab server fails to start func TestStartJupyterServerFailure(t *testing.T) { startServer(t) - invoker := connect(t) + invoker := createTestInvoker(t) rpctest.JupyterMessage = "error message" rpctest.JupyterResult = false errorMessage := fmt.Sprintf("failed to start JupyterLab: %s", rpctest.JupyterMessage) diff --git a/internal/codespaces/states.go b/internal/codespaces/states.go index 652f8c610..fc2508f74 100644 --- a/internal/codespaces/states.go +++ b/internal/codespaces/states.go @@ -11,6 +11,7 @@ import ( "time" "github.com/cli/cli/v2/internal/codespaces/api" + "github.com/cli/cli/v2/internal/codespaces/rpc" "github.com/cli/cli/v2/internal/text" "github.com/cli/cli/v2/pkg/liveshare" ) @@ -59,7 +60,7 @@ func PollPostCreateStates(ctx context.Context, progress progressIndicator, apiCl localPort := listen.Addr().(*net.TCPAddr).Port progress.StartProgressIndicatorWithLabel("Fetching SSH Details") - invoker, err := CreateRPCInvoker(ctx, session, "") + invoker, err := rpc.CreateInvoker(ctx, session, "") if err != nil { return err } diff --git a/pkg/cmd/codespace/jupyter.go b/pkg/cmd/codespace/jupyter.go index 0dab62569..285d34529 100644 --- a/pkg/cmd/codespace/jupyter.go +++ b/pkg/cmd/codespace/jupyter.go @@ -6,7 +6,7 @@ import ( "net" "strings" - "github.com/cli/cli/v2/internal/codespaces" + "github.com/cli/cli/v2/internal/codespaces/rpc" "github.com/cli/cli/v2/pkg/liveshare" "github.com/spf13/cobra" ) @@ -45,7 +45,7 @@ func (a *App) Jupyter(ctx context.Context, codespaceName string) (err error) { defer safeClose(session, &err) a.StartProgressIndicatorWithLabel("Starting JupyterLab on codespace") - invoker, err := codespaces.CreateRPCInvoker(ctx, session, "") + invoker, err := rpc.CreateInvoker(ctx, session, "") if err != nil { return err } diff --git a/pkg/cmd/codespace/logs.go b/pkg/cmd/codespace/logs.go index f9e9ebd63..04a458251 100644 --- a/pkg/cmd/codespace/logs.go +++ b/pkg/cmd/codespace/logs.go @@ -6,6 +6,7 @@ import ( "net" "github.com/cli/cli/v2/internal/codespaces" + "github.com/cli/cli/v2/internal/codespaces/rpc" "github.com/cli/cli/v2/pkg/liveshare" "github.com/spf13/cobra" ) @@ -56,7 +57,7 @@ func (a *App) Logs(ctx context.Context, codespaceName string, follow bool) (err localPort := listen.Addr().(*net.TCPAddr).Port a.StartProgressIndicatorWithLabel("Fetching SSH Details") - invoker, err := codespaces.CreateRPCInvoker(ctx, session, "") + invoker, err := rpc.CreateInvoker(ctx, session, "") if err != nil { return err } diff --git a/pkg/cmd/codespace/rebuild.go b/pkg/cmd/codespace/rebuild.go index 98abd825e..9c7fd831f 100644 --- a/pkg/cmd/codespace/rebuild.go +++ b/pkg/cmd/codespace/rebuild.go @@ -4,8 +4,8 @@ import ( "context" "fmt" - "github.com/cli/cli/v2/internal/codespaces" "github.com/cli/cli/v2/internal/codespaces/api" + "github.com/cli/cli/v2/internal/codespaces/rpc" "github.com/spf13/cobra" ) @@ -52,7 +52,7 @@ func (a *App) Rebuild(ctx context.Context, codespaceName string, full bool) (err } defer safeClose(session, &err) - invoker, err := codespaces.CreateRPCInvoker(ctx, session, "") + invoker, err := rpc.CreateInvoker(ctx, session, "") if err != nil { return err } diff --git a/pkg/cmd/codespace/ssh.go b/pkg/cmd/codespace/ssh.go index 10320192d..45c8d16a9 100644 --- a/pkg/cmd/codespace/ssh.go +++ b/pkg/cmd/codespace/ssh.go @@ -20,6 +20,7 @@ import ( "github.com/MakeNowJust/heredoc" "github.com/cli/cli/v2/internal/codespaces" "github.com/cli/cli/v2/internal/codespaces/api" + "github.com/cli/cli/v2/internal/codespaces/rpc" "github.com/cli/cli/v2/internal/config" "github.com/cli/cli/v2/pkg/cmdutil" "github.com/cli/cli/v2/pkg/liveshare" @@ -173,7 +174,7 @@ func (a *App) SSH(ctx context.Context, sshArgs []string, opts sshOptions) (err e defer safeClose(session, &err) a.StartProgressIndicatorWithLabel("Fetching SSH Details") - invoker, err := codespaces.CreateRPCInvoker(ctx, session, "") + invoker, err := rpc.CreateInvoker(ctx, session, "") if err != nil { return err } @@ -514,7 +515,7 @@ func (a *App) printOpenSSHConfig(ctx context.Context, opts sshOptions) (err erro } else { defer safeClose(session, &err) - invoker, err := codespaces.CreateRPCInvoker(ctx, session, "") + invoker, err := rpc.CreateInvoker(ctx, session, "") if err != nil { result.err = fmt.Errorf("error connecting to codespace: %w", err) } else { diff --git a/pkg/liveshare/session.go b/pkg/liveshare/session.go index 5e4c42f8a..a854d3874 100644 --- a/pkg/liveshare/session.go +++ b/pkg/liveshare/session.go @@ -19,6 +19,18 @@ type ChannelID struct { name, condition string } +// Interface to allow the mocking of the liveshare session +type LiveshareSession interface { + Close() error + GetSharedServers(context.Context) ([]*Port, error) + KeepAlive(string) + OpenStreamingChannel(context.Context, ChannelID) (ssh.Channel, error) + StartSharing(context.Context, string, int) (ChannelID, error) + StartSSHServer(context.Context) (int, string, error) + StartSSHServerWithOptions(context.Context, StartSSHServerOptions) (int, string, error) + RebuildContainer(context.Context, bool) error +} + // A Session represents the session between a connected Live Share client and server. type Session struct { ssh *sshSession