Create Invoker object

This commit is contained in:
David Gardiner 2023-01-03 13:36:05 -08:00
parent faabdc247b
commit 731ba682f2
17 changed files with 152 additions and 105 deletions

View file

@ -8,6 +8,7 @@ import (
"github.com/cenkalti/backoff/v4"
"github.com/cli/cli/v2/internal/codespaces/api"
"github.com/cli/cli/v2/internal/codespaces/rpc"
"github.com/cli/cli/v2/pkg/liveshare"
)
@ -79,3 +80,16 @@ func ConnectToLiveshare(ctx context.Context, progress progressIndicator, session
Logger: sessionLogger,
})
}
// Helper function to connect to the internal RPC server and return an RPC invoker for it
func CreateRPCInvoker(ctx context.Context, session *liveshare.Session, token string) (*rpc.Invoker, error) {
ctx, cancel := context.WithTimeout(ctx, rpc.ConnectionTimeout)
defer cancel()
invoker, err := rpc.Connect(ctx, session, token)
if err != nil {
return nil, fmt.Errorf("error connecting to internal server: %w", err)
}
return invoker, nil
}

View file

@ -1,4 +1,4 @@
package grpc
package rpc
// gRPC client implementation to be able to connect to the gRPC server and perform the following operations:
// - Start a remote JupyterLab server
@ -10,7 +10,7 @@ import (
"strconv"
"time"
"github.com/cli/cli/v2/internal/codespaces/rpc/grpc/jupyter"
"github.com/cli/cli/v2/internal/codespaces/rpc/jupyter"
"github.com/cli/cli/v2/pkg/liveshare"
"golang.org/x/crypto/ssh"
"google.golang.org/grpc"
@ -20,7 +20,7 @@ import (
const (
ConnectionTimeout = 5 * time.Second
RequestTimeout = 30 * time.Second
requestTimeout = 30 * time.Second
)
const (
@ -28,30 +28,37 @@ const (
codespacesInternalSessionName = "CodespacesInternal"
)
type Client struct {
type liveshareSession interface {
Close() error
GetSharedServers(context.Context) ([]*liveshare.Port, error)
KeepAlive(string)
OpenStreamingChannel(context.Context, liveshare.ChannelID) (ssh.Channel, error)
StartSharing(context.Context, string, int) (liveshare.ChannelID, error)
StartSSHServer(context.Context) (int, string, error)
StartSSHServerWithOptions(context.Context, liveshare.StartSSHServerOptions) (int, string, error)
RebuildContainer(context.Context, bool) error
}
type Invoker struct {
conn *grpc.ClientConn
token string
session liveshareSession
listener net.Listener
jupyterClient jupyter.JupyterServerHostClient
cancelPF context.CancelFunc
}
type liveshareSession interface {
KeepAlive(string)
OpenStreamingChannel(context.Context, liveshare.ChannelID) (ssh.Channel, error)
StartSharing(context.Context, string, int) (liveshare.ChannelID, error)
}
// Finds a free port to listen on and creates a new gRPC client that connects to that port
func Connect(ctx context.Context, session liveshareSession, token string) (*Client, error) {
func Connect(ctx context.Context, session liveshareSession, token string) (*Invoker, error) {
listener, err := net.Listen("tcp", fmt.Sprintf("127.0.0.1:%d", 0))
if err != nil {
return nil, fmt.Errorf("failed to listen to local port over tcp: %w", err)
}
localAddress := fmt.Sprintf("127.0.0.1:%d", listener.Addr().(*net.TCPAddr).Port)
client := &Client{
invoker := &Invoker{
token: token,
session: session,
listener: listener,
}
@ -70,7 +77,7 @@ func Connect(ctx context.Context, session liveshareSession, token string) (*Clie
// or once the gRPC connection is closed. pfcancel is retained
// to close the PF whenever we close the gRPC connection.
pfctx, pfcancel := context.WithCancel(connectctx)
client.cancelPF = pfcancel
invoker.cancelPF = pfcancel
// Tunnel the remote gRPC server port to the local port
go func() {
@ -99,14 +106,14 @@ func Connect(ctx context.Context, session liveshareSession, token string) (*Clie
}
}
client.conn = conn
client.jupyterClient = jupyter.NewJupyterServerHostClient(conn)
invoker.conn = conn
invoker.jupyterClient = jupyter.NewJupyterServerHostClient(conn)
return client, nil
return invoker, nil
}
// Closes the gRPC connection
func (g *Client) Close() error {
func (g *Invoker) Close() error {
g.cancelPF()
// Closing the local listener effectively closes the gRPC connection
@ -119,13 +126,15 @@ func (g *Client) Close() error {
}
// Appends the authentication token to the gRPC context
func (g *Client) appendMetadata(ctx context.Context) context.Context {
func (g *Invoker) appendMetadata(ctx context.Context) context.Context {
return metadata.AppendToOutgoingContext(ctx, "Authorization", "Bearer "+g.token)
}
// Starts a remote JupyterLab server to allow the user to connect to the codespace via JupyterLab in their browser
func (g *Client) StartJupyterServer(ctx context.Context) (port int, serverUrl string, err error) {
func (g *Invoker) StartJupyterServer(ctx context.Context) (port int, serverUrl string, err error) {
ctx = g.appendMetadata(ctx)
ctx, cancel := context.WithTimeout(ctx, requestTimeout)
defer cancel()
response, err := g.jupyterClient.GetRunningServer(ctx, &jupyter.GetRunningServerRequest{})
if err != nil {
@ -143,3 +152,18 @@ func (g *Client) StartJupyterServer(ctx context.Context) (port int, serverUrl st
return port, response.ServerUrl, err
}
// Rebuilds the container using cached layers by default or from scratch if full is true
func (g *Invoker) RebuildContainer(ctx context.Context, full bool) error {
return g.session.RebuildContainer(ctx, full)
}
// Starts a remote SSH server to allow the user to connect to the codespace via SSH
func (g *Invoker) StartSSHServer(ctx context.Context) (int, string, error) {
return g.session.StartSSHServer(ctx)
}
// Starts a remote SSH server to allow the user to connect to the codespace via SSH
func (g *Invoker) StartSSHServerWithOptions(ctx context.Context, options liveshare.StartSSHServerOptions) (int, string, error) {
return g.session.StartSSHServerWithOptions(ctx, options)
}

View file

@ -1,4 +1,4 @@
package grpc
package rpc
import (
"context"
@ -7,7 +7,7 @@ import (
"os"
"testing"
grpctest "github.com/cli/cli/v2/internal/codespaces/rpc/grpc/test"
rpctest "github.com/cli/cli/v2/internal/codespaces/rpc/test"
)
func startServer(t *testing.T) {
@ -20,7 +20,7 @@ func startServer(t *testing.T) {
// Start the gRPC server in the background
go func() {
err := grpctest.StartServer(ctx)
err := rpctest.StartServer(ctx)
if err != nil && err != context.Canceled {
log.Println(fmt.Errorf("error starting test server: %v", err))
}
@ -32,46 +32,46 @@ func startServer(t *testing.T) {
})
}
func connect(t *testing.T) (client *Client) {
func connect(t *testing.T) (invoker *Invoker) {
t.Helper()
client, err := Connect(context.Background(), &grpctest.Session{}, "token")
invoker, err := Connect(context.Background(), &rpctest.Session{}, "token")
if err != nil {
t.Fatalf("error connecting to internal server: %v", err)
}
t.Cleanup(func() {
client.Close()
invoker.Close()
})
return client
return invoker
}
// Test that the gRPC client returns the correct port and URL when the JupyterLab server starts successfully
// Test that the RPC invoker returns the correct port and URL when the JupyterLab server starts successfully
func TestStartJupyterServerSuccess(t *testing.T) {
startServer(t)
client := connect(t)
invoker := connect(t)
port, url, err := client.StartJupyterServer(context.Background())
port, url, err := invoker.StartJupyterServer(context.Background())
if err != nil {
t.Fatalf("expected %v, got %v", nil, err)
}
if port != grpctest.JupyterPort {
t.Fatalf("expected %d, got %d", grpctest.JupyterPort, port)
if port != rpctest.JupyterPort {
t.Fatalf("expected %d, got %d", rpctest.JupyterPort, port)
}
if url != grpctest.JupyterServerUrl {
t.Fatalf("expected %s, got %s", grpctest.JupyterServerUrl, url)
if url != rpctest.JupyterServerUrl {
t.Fatalf("expected %s, got %s", rpctest.JupyterServerUrl, url)
}
}
// Test that the gRPC client returns an error when the JupyterLab server fails to start
// Test that the RPC invoker returns an error when the JupyterLab server fails to start
func TestStartJupyterServerFailure(t *testing.T) {
startServer(t)
client := connect(t)
grpctest.JupyterMessage = "error message"
grpctest.JupyterResult = false
errorMessage := fmt.Sprintf("failed to start JupyterLab: %s", grpctest.JupyterMessage)
port, url, err := client.StartJupyterServer(context.Background())
invoker := connect(t)
rpctest.JupyterMessage = "error message"
rpctest.JupyterResult = false
errorMessage := fmt.Sprintf("failed to start JupyterLab: %s", rpctest.JupyterMessage)
port, url, err := invoker.StartJupyterServer(context.Background())
if err.Error() != errorMessage {
t.Fatalf("expected %v, got %v", errorMessage, err)
}

View file

@ -1,51 +0,0 @@
package rpc
import (
"context"
"fmt"
"github.com/cli/cli/v2/internal/codespaces/rpc/grpc"
"github.com/cli/cli/v2/pkg/liveshare"
)
// Helper function to connect to the GRPC server in the codespace
func connectToGRPCServer(ctx context.Context, session *liveshare.Session, token string) (*grpc.Client, error) {
ctx, cancel := context.WithTimeout(ctx, grpc.ConnectionTimeout)
defer cancel()
client, err := grpc.Connect(ctx, session, token)
if err != nil {
return nil, fmt.Errorf("error connecting to internal server: %w", err)
}
return client, nil
}
func RebuildContainer(ctx context.Context, session *liveshare.Session, full bool) error {
return session.RebuildContainer(ctx, full)
}
func StartSSHServer(ctx context.Context, session *liveshare.Session) (int, string, error) {
return session.StartSSHServer(ctx)
}
func StartSSHServerWithOptions(ctx context.Context, session *liveshare.Session, options liveshare.StartSSHServerOptions) (int, string, error) {
return session.StartSSHServerWithOptions(ctx, options)
}
func StartJupyterServer(ctx context.Context, session *liveshare.Session) (int, string, error) {
client, err := connectToGRPCServer(ctx, session, "")
if err != nil {
return 0, "", err
}
ctx, cancel := context.WithTimeout(ctx, grpc.RequestTimeout)
defer cancel()
serverPort, serverUrl, err := client.StartJupyterServer(ctx)
if err != nil {
return 0, "", err
}
return serverPort, serverUrl, nil
}

View file

@ -6,7 +6,7 @@ import (
"net"
"strconv"
"github.com/cli/cli/v2/internal/codespaces/rpc/grpc/jupyter"
"github.com/cli/cli/v2/internal/codespaces/rpc/jupyter"
"google.golang.org/grpc"
)

View file

@ -13,6 +13,26 @@ type Session struct {
channel ssh.Channel
}
func (*Session) Close() error {
panic("unimplemented")
}
func (*Session) GetSharedServers(context.Context) ([]*liveshare.Port, error) {
panic("unimplemented")
}
func (*Session) RebuildContainer(context.Context, bool) error {
panic("unimplemented")
}
func (*Session) StartSSHServer(context.Context) (int, string, error) {
panic("unimplemented")
}
func (*Session) StartSSHServerWithOptions(context.Context, liveshare.StartSSHServerOptions) (int, string, error) {
panic("unimplemented")
}
func (s *Session) KeepAlive(reason string) {
}

View file

@ -11,7 +11,6 @@ import (
"time"
"github.com/cli/cli/v2/internal/codespaces/api"
"github.com/cli/cli/v2/internal/codespaces/rpc"
"github.com/cli/cli/v2/internal/text"
"github.com/cli/cli/v2/pkg/liveshare"
)
@ -60,11 +59,17 @@ func PollPostCreateStates(ctx context.Context, progress progressIndicator, apiCl
localPort := listen.Addr().(*net.TCPAddr).Port
progress.StartProgressIndicatorWithLabel("Fetching SSH Details")
defer progress.StopProgressIndicator()
remoteSSHServerPort, sshUser, err := rpc.StartSSHServer(ctx, session)
invoker, err := CreateRPCInvoker(ctx, session, "")
if err != nil {
return err
}
defer safeClose(invoker, &err)
remoteSSHServerPort, sshUser, err := invoker.StartSSHServer(ctx)
if err != nil {
return fmt.Errorf("error getting ssh server details: %w", err)
}
progress.StopProgressIndicator()
progress.StartProgressIndicatorWithLabel("Fetching status")
tunnelClosed := make(chan error, 1) // buffered to avoid sender stuckness
@ -125,3 +130,9 @@ func getPostCreateOutput(ctx context.Context, tunnelPort int, user string) ([]Po
return output.Steps, nil
}
func safeClose(closer io.Closer, err *error) {
if closeErr := closer.Close(); *err == nil {
*err = closeErr
}
}

View file

@ -6,7 +6,7 @@ import (
"net"
"strings"
"github.com/cli/cli/v2/internal/codespaces/rpc"
"github.com/cli/cli/v2/internal/codespaces"
"github.com/cli/cli/v2/pkg/liveshare"
"github.com/spf13/cobra"
)
@ -45,7 +45,13 @@ func (a *App) Jupyter(ctx context.Context, codespaceName string) (err error) {
defer safeClose(session, &err)
a.StartProgressIndicatorWithLabel("Starting JupyterLab on codespace")
serverPort, serverUrl, err := rpc.StartJupyterServer(ctx, session)
invoker, err := codespaces.CreateRPCInvoker(ctx, session, "")
if err != nil {
return err
}
defer safeClose(invoker, &err)
serverPort, serverUrl, err := invoker.StartJupyterServer(ctx)
if err != nil {
return fmt.Errorf("failed to start JupyterLab server: %w", err)
}

View file

@ -6,7 +6,6 @@ import (
"net"
"github.com/cli/cli/v2/internal/codespaces"
"github.com/cli/cli/v2/internal/codespaces/rpc"
"github.com/cli/cli/v2/pkg/liveshare"
"github.com/spf13/cobra"
)
@ -57,7 +56,13 @@ func (a *App) Logs(ctx context.Context, codespaceName string, follow bool) (err
localPort := listen.Addr().(*net.TCPAddr).Port
a.StartProgressIndicatorWithLabel("Fetching SSH Details")
remoteSSHServerPort, sshUser, err := rpc.StartSSHServer(ctx, session)
invoker, err := codespaces.CreateRPCInvoker(ctx, session, "")
if err != nil {
return err
}
defer safeClose(invoker, &err)
remoteSSHServerPort, sshUser, err := invoker.StartSSHServer(ctx)
a.StopProgressIndicator()
if err != nil {
return fmt.Errorf("error getting ssh server details: %w", err)

View file

@ -4,8 +4,8 @@ import (
"context"
"fmt"
"github.com/cli/cli/v2/internal/codespaces"
"github.com/cli/cli/v2/internal/codespaces/api"
"github.com/cli/cli/v2/internal/codespaces/rpc"
"github.com/spf13/cobra"
)
@ -52,7 +52,13 @@ func (a *App) Rebuild(ctx context.Context, codespaceName string, full bool) (err
}
defer safeClose(session, &err)
err = rpc.RebuildContainer(ctx, session, full)
invoker, err := codespaces.CreateRPCInvoker(ctx, session, "")
if err != nil {
return err
}
defer safeClose(invoker, &err)
err = invoker.RebuildContainer(ctx, full)
if err != nil {
return fmt.Errorf("rebuilding codespace via session: %w", err)
}

View file

@ -20,7 +20,6 @@ import (
"github.com/MakeNowJust/heredoc"
"github.com/cli/cli/v2/internal/codespaces"
"github.com/cli/cli/v2/internal/codespaces/api"
"github.com/cli/cli/v2/internal/codespaces/rpc"
"github.com/cli/cli/v2/internal/config"
"github.com/cli/cli/v2/pkg/cmdutil"
"github.com/cli/cli/v2/pkg/liveshare"
@ -174,7 +173,13 @@ func (a *App) SSH(ctx context.Context, sshArgs []string, opts sshOptions) (err e
defer safeClose(session, &err)
a.StartProgressIndicatorWithLabel("Fetching SSH Details")
remoteSSHServerPort, sshUser, err := rpc.StartSSHServerWithOptions(ctx, session, startSSHOptions)
invoker, err := codespaces.CreateRPCInvoker(ctx, session, "")
if err != nil {
return err
}
defer safeClose(invoker, &err)
remoteSSHServerPort, sshUser, err := invoker.StartSSHServerWithOptions(ctx, startSSHOptions)
a.StopProgressIndicator()
if err != nil {
return fmt.Errorf("error getting ssh server details: %w", err)
@ -509,11 +514,18 @@ func (a *App) printOpenSSHConfig(ctx context.Context, opts sshOptions) (err erro
} else {
defer safeClose(session, &err)
_, result.user, err = rpc.StartSSHServer(ctx, session)
invoker, err := codespaces.CreateRPCInvoker(ctx, session, "")
if err != nil {
result.err = fmt.Errorf("error getting ssh server details: %w", err)
result.err = fmt.Errorf("error connecting to codespace: %w", err)
} else {
result.codespace = cs
defer safeClose(invoker, &err)
_, result.user, err = invoker.StartSSHServer(ctx)
if err != nil {
result.err = fmt.Errorf("error getting ssh server details: %w", err)
} else {
result.codespace = cs
}
}
}