diff --git a/internal/codespaces/grpc/client.go b/internal/codespaces/grpc/client.go index b4c5a0394..76119d659 100644 --- a/internal/codespaces/grpc/client.go +++ b/internal/codespaces/grpc/client.go @@ -19,9 +19,8 @@ import ( ) const ( - serverConnectionTimeout = 5 * time.Second - requestTimeout = 30 * time.Second - portConnectionTimeout = 30 * time.Second + ConnectionTimeout = 5 * time.Second + RequestTimeout = 30 * time.Second ) const ( @@ -34,6 +33,7 @@ type Client struct { token string listener net.Listener jupyterClient jupyter.JupyterServerHostClient + cancelPF context.CancelFunc } type liveshareSession interface { @@ -48,47 +48,67 @@ func Connect(ctx context.Context, session liveshareSession, token string) (*Clie if err != nil { return nil, fmt.Errorf("failed to listen to local port over tcp: %w", err) } - - // Tunnel the remote gRPC server port to the local port localAddress := fmt.Sprintf("127.0.0.1:%d", listener.Addr().(*net.TCPAddr).Port) - internalTunnelClosed := make(chan error, 1) - go func() { - fwd := liveshare.NewPortForwarder(session, codespacesInternalSessionName, codespacesInternalPort, true) - internalTunnelClosed <- fwd.ForwardToListener(ctx, listener) + + client := &Client{ + token: token, + listener: listener, + } + + // Create a cancelable context to be able to cancel background tasks + // if we encounter an error while connecting to the gRPC server + connectctx, cancel := context.WithCancel(context.Background()) + defer func() { + if err != nil { + cancel() + } }() - // Ping the port to ensure that it is fully forwarded before continuing - connctx, cancel := context.WithTimeout(ctx, portConnectionTimeout) - defer cancel() - err = liveshare.WaitForPortConnection(connctx, localAddress) - if err != nil { - return nil, fmt.Errorf("failed to connect to local port: %w", err) + ch := make(chan error, 2) // Buffered channel to ensure we don't block on the goroutine + + // Ensure we close the port forwarder if we encounter an error + // or once the gRPC connection is closed. pfcancel is retained + // to close the PF whenever we close the gRPC connection. + pfctx, pfcancel := context.WithCancel(connectctx) + client.cancelPF = pfcancel + + // Tunnel the remote gRPC server port to the local port + go func() { + fwd := liveshare.NewPortForwarder(session, codespacesInternalSessionName, codespacesInternalPort, true) + ch <- fwd.ForwardToListener(pfctx, listener) + }() + + var conn *grpc.ClientConn + go func() { + // Attempt to connect to the port + opts := []grpc.DialOption{ + grpc.WithTransportCredentials(insecure.NewCredentials()), + grpc.WithBlock(), + } + conn, err = grpc.DialContext(connectctx, localAddress, opts...) + ch <- err // nil if we successfully connected + }() + + // Wait for the connection to be established or for the context to be cancelled + select { + case <-ctx.Done(): + return nil, ctx.Err() + case err := <-ch: + if err != nil { + return nil, err + } } - // Attempt to connect to the port - opts := []grpc.DialOption{ - grpc.WithTransportCredentials(insecure.NewCredentials()), - grpc.WithBlock(), - } - ctx, cancel = context.WithTimeout(ctx, serverConnectionTimeout) - defer cancel() - conn, err := grpc.DialContext(ctx, localAddress, opts...) - if err != nil { - return nil, err - } + client.conn = conn + client.jupyterClient = jupyter.NewJupyterServerHostClient(conn) - g := &Client{ - conn: conn, - token: token, - listener: listener, - jupyterClient: jupyter.NewJupyterServerHostClient(conn), - } - - return g, nil + return client, nil } // Closes the gRPC connection func (g *Client) Close() error { + g.cancelPF() + // Closing the local listener effectively closes the gRPC connection if err := g.listener.Close(); err != nil { g.conn.Close() // If we fail to close the listener, explicitly close the gRPC connection and ignore any error @@ -105,9 +125,7 @@ func (g *Client) appendMetadata(ctx context.Context) context.Context { // Starts a remote JupyterLab server to allow the user to connect to the codespace via JupyterLab in their browser func (g *Client) StartJupyterServer(ctx context.Context) (port int, serverUrl string, err error) { - ctx, cancel := context.WithTimeout(ctx, requestTimeout) ctx = g.appendMetadata(ctx) - defer cancel() response, err := g.jupyterClient.GetRunningServer(ctx, &jupyter.GetRunningServerRequest{}) if err != nil { diff --git a/internal/codespaces/grpc/client_test.go b/internal/codespaces/grpc/client_test.go index 62a7b1723..d905c0e29 100644 --- a/internal/codespaces/grpc/client_test.go +++ b/internal/codespaces/grpc/client_test.go @@ -3,29 +3,35 @@ package grpc import ( "context" "fmt" - "os" + "log" "testing" - "github.com/cli/cli/v2/internal/codespaces/grpc/test" + grpctest "github.com/cli/cli/v2/internal/codespaces/grpc/test" ) -func TestMain(m *testing.M) { +func startServer(t *testing.T) { + t.Helper() + + ctx, cancel := context.WithCancel(context.Background()) + // Start the gRPC server in the background go func() { - err := test.StartServer() - if err != nil { - panic(err) + err := grpctest.StartServer(ctx) + if err != nil && err != context.Canceled { + log.Println(fmt.Errorf("error starting test server: %v", err)) } }() - m.Run() - os.Exit(0) + // Stop the gRPC server when the test is done + t.Cleanup(func() { + cancel() + }) } -func connect(t *testing.T) (ctx context.Context, client *Client) { +func connect(t *testing.T) (client *Client) { t.Helper() - ctx = context.Background() - client, err := Connect(ctx, &test.Session{}, "token") + + client, err := Connect(context.Background(), &grpctest.Session{}, "token") if err != nil { t.Fatalf("error connecting to internal server: %v", err) } @@ -34,31 +40,34 @@ func connect(t *testing.T) (ctx context.Context, client *Client) { client.Close() }) - return ctx, client + return client } // Test that the gRPC client returns the correct port and URL when the JupyterLab server starts successfully func TestStartJupyterServerSuccess(t *testing.T) { - ctx, client := connect(t) - port, url, err := client.StartJupyterServer(ctx) + startServer(t) + client := connect(t) + + port, url, err := client.StartJupyterServer(context.Background()) if err != nil { t.Fatalf("expected %v, got %v", nil, err) } - if port != test.JupyterPort { - t.Fatalf("expected %d, got %d", test.JupyterPort, port) + if port != grpctest.JupyterPort { + t.Fatalf("expected %d, got %d", grpctest.JupyterPort, port) } - if url != test.JupyterServerUrl { - t.Fatalf("expected %s, got %s", test.JupyterServerUrl, url) + if url != grpctest.JupyterServerUrl { + t.Fatalf("expected %s, got %s", grpctest.JupyterServerUrl, url) } } // Test that the gRPC client returns an error when the JupyterLab server fails to start func TestStartJupyterServerFailure(t *testing.T) { - ctx, client := connect(t) - test.JupyterMessage = "error message" - test.JupyterResult = false - errorMessage := fmt.Sprintf("failed to start JupyterLab: %s", test.JupyterMessage) - port, url, err := client.StartJupyterServer(ctx) + startServer(t) + client := connect(t) + grpctest.JupyterMessage = "error message" + grpctest.JupyterResult = false + errorMessage := fmt.Sprintf("failed to start JupyterLab: %s", grpctest.JupyterMessage) + port, url, err := client.StartJupyterServer(context.Background()) if err.Error() != errorMessage { t.Fatalf("expected %v, got %v", errorMessage, err) } diff --git a/internal/codespaces/grpc/test/server.go b/internal/codespaces/grpc/test/server.go index 50608a9fa..8af5efc29 100644 --- a/internal/codespaces/grpc/test/server.go +++ b/internal/codespaces/grpc/test/server.go @@ -35,7 +35,7 @@ func (s *server) GetRunningServer(ctx context.Context, in *jupyter.GetRunningSer } // Starts the mock gRPC server listening on port 50051 -func StartServer() error { +func StartServer(ctx context.Context) error { listener, err := net.Listen("tcp", fmt.Sprintf("127.0.0.1:%d", ServerPort)) if err != nil { return fmt.Errorf("failed to listen: %v", err) @@ -44,9 +44,19 @@ func StartServer() error { s := grpc.NewServer() jupyter.RegisterJupyterServerHostServer(s, &server{}) - if err := s.Serve(listener); err != nil { - return fmt.Errorf("failed to serve: %v", err) - } - return nil + ch := make(chan error, 1) + go func() { + if err := s.Serve(listener); err != nil { + ch <- fmt.Errorf("failed to serve: %v", err) + } + }() + + select { + case <-ctx.Done(): + s.Stop() + return ctx.Err() + case err := <-ch: + return err + } } diff --git a/internal/codespaces/grpc/test/session.go b/internal/codespaces/grpc/test/session.go index 70d81d41e..aba0f17ee 100644 --- a/internal/codespaces/grpc/test/session.go +++ b/internal/codespaces/grpc/test/session.go @@ -3,7 +3,6 @@ package test import ( "context" "fmt" - "log" "net" "github.com/cli/cli/v2/pkg/liveshare" @@ -11,23 +10,22 @@ import ( ) type Session struct { + channel ssh.Channel } func (s *Session) KeepAlive(reason string) { } func (s *Session) StartSharing(ctx context.Context, sessionName string, port int) (liveshare.ChannelID, error) { + conn, err := net.Dial("tcp", fmt.Sprintf("127.0.0.1:%d", ServerPort)) + if err != nil { + return liveshare.ChannelID{}, err + } + s.channel = &Channel{conn} return liveshare.ChannelID{}, nil } // Creates mock SSH channel connected to the mock gRPC server func (s *Session) OpenStreamingChannel(ctx context.Context, id liveshare.ChannelID) (ssh.Channel, error) { - dialer := net.Dialer{} - conn, err := dialer.Dial("tcp", fmt.Sprintf("127.0.0.1:%d", ServerPort)) - if err != nil { - log.Fatalf("failed to connect to the grpc server: %v", err) - } - return &Channel{ - conn: conn, - }, nil + return s.channel, nil } diff --git a/pkg/cmd/codespace/jupyter.go b/pkg/cmd/codespace/jupyter.go index c37a97ea7..5c19c2ab1 100644 --- a/pkg/cmd/codespace/jupyter.go +++ b/pkg/cmd/codespace/jupyter.go @@ -45,17 +45,17 @@ func (a *App) Jupyter(ctx context.Context, codespaceName string) (err error) { defer safeClose(session, &err) a.StartProgressIndicatorWithLabel("Starting JupyterLab on codespace") - client, err := grpc.Connect(ctx, session, codespace.Connection.SessionToken) + client, err := connectToGRPCServer(ctx, session, codespace.Connection.SessionToken) if err != nil { - return fmt.Errorf("error connecting to internal server: %w", err) + return fmt.Errorf("failed to connect to internal server: %w", err) } defer safeClose(client, &err) - serverPort, serverUrl, err := client.StartJupyterServer(ctx) - a.StopProgressIndicator() + serverPort, serverUrl, err := startJupyterServer(ctx, client) if err != nil { return fmt.Errorf("failed to start JupyterLab server: %w", err) } + a.StopProgressIndicator() // Pass 0 to pick a random port listen, err := net.Listen("tcp", fmt.Sprintf("127.0.0.1:%d", 0)) @@ -87,3 +87,27 @@ func (a *App) Jupyter(ctx context.Context, codespaceName string) (err error) { return nil // success } } + +func connectToGRPCServer(ctx context.Context, session liveshareSession, token string) (*grpc.Client, error) { + ctx, cancel := context.WithTimeout(ctx, grpc.ConnectionTimeout) + defer cancel() + + client, err := grpc.Connect(ctx, session, token) + if err != nil { + return nil, fmt.Errorf("error connecting to internal server: %w", err) + } + + return client, nil +} + +func startJupyterServer(ctx context.Context, client *grpc.Client) (int, string, error) { + ctx, cancel := context.WithTimeout(ctx, grpc.RequestTimeout) + defer cancel() + + serverPort, serverUrl, err := client.StartJupyterServer(ctx) + if err != nil { + return 0, "", fmt.Errorf("failed to start JupyterLab server: %w", err) + } + + return serverPort, serverUrl, nil +} diff --git a/pkg/liveshare/port_forwarder.go b/pkg/liveshare/port_forwarder.go index 923415c01..f042eeaea 100644 --- a/pkg/liveshare/port_forwarder.go +++ b/pkg/liveshare/port_forwarder.go @@ -99,41 +99,6 @@ func (fwd *PortForwarder) Forward(ctx context.Context, conn io.ReadWriteCloser) return awaitError(ctx, errc) } -// Loops until we can connect to the address or the context is canceled. -func WaitForPortConnection(ctx context.Context, address string) error { - for { - select { - case <-ctx.Done(): - return ctx.Err() - default: - err := connectToAddr(address) - if err != nil { - continue - } - - return nil // success - } - } -} - -// Connects to and pings a given address to ensure that the server is shared and the port is forwarded. -func connectToAddr(address string) error { - // Verify that the port can be connected to - conn, err := net.Dial("tcp", address) - if err != nil { - return err - } - defer conn.Close() - - // Send a ping and make sure it succeed - _, err = conn.Write([]byte("ping")) - if err != nil { - return err - } - - return nil -} - func (fwd *PortForwarder) shareRemotePort(ctx context.Context) (ChannelID, error) { id, err := fwd.session.StartSharing(ctx, fwd.name, fwd.remotePort) if err != nil {