From e11b43f8f6d01c768784dc594fad5a22b12c6c89 Mon Sep 17 00:00:00 2001 From: Jose Garcia Date: Tue, 11 Oct 2022 08:55:14 -0400 Subject: [PATCH 1/3] Fixes for handling the grpc client lifecyle --- internal/codespaces/grpc/client.go | 65 +++++++++++++++--------- internal/codespaces/grpc/client_test.go | 55 +++++++++++--------- internal/codespaces/grpc/test/server.go | 20 ++++++-- internal/codespaces/grpc/test/session.go | 16 +++--- pkg/cmd/codespace/jupyter.go | 31 +++++++++-- pkg/liveshare/port_forwarder.go | 35 ------------- 6 files changed, 123 insertions(+), 99 deletions(-) diff --git a/internal/codespaces/grpc/client.go b/internal/codespaces/grpc/client.go index b4c5a0394..5ced382f8 100644 --- a/internal/codespaces/grpc/client.go +++ b/internal/codespaces/grpc/client.go @@ -19,9 +19,8 @@ import ( ) const ( - serverConnectionTimeout = 5 * time.Second - requestTimeout = 30 * time.Second - portConnectionTimeout = 30 * time.Second + ConnectionTimeout = 5 * time.Second + RequestTimeout = 30 * time.Second ) const ( @@ -34,6 +33,7 @@ type Client struct { token string listener net.Listener jupyterClient jupyter.JupyterServerHostClient + cancelPF context.CancelFunc } type liveshareSession interface { @@ -49,32 +49,50 @@ func Connect(ctx context.Context, session liveshareSession, token string) (*Clie return nil, fmt.Errorf("failed to listen to local port over tcp: %w", err) } - // Tunnel the remote gRPC server port to the local port - localAddress := fmt.Sprintf("127.0.0.1:%d", listener.Addr().(*net.TCPAddr).Port) - internalTunnelClosed := make(chan error, 1) - go func() { - fwd := liveshare.NewPortForwarder(session, codespacesInternalSessionName, codespacesInternalPort, true) - internalTunnelClosed <- fwd.ForwardToListener(ctx, listener) + // Create a cancelable context to be able to cancel background tasks + // if we encounter an error while connecting to the gRPC server + connectctx, cancel := context.WithCancel(context.Background()) + defer func() { + if err != nil { + cancel() + } }() - // Ping the port to ensure that it is fully forwarded before continuing - connctx, cancel := context.WithTimeout(ctx, portConnectionTimeout) - defer cancel() - err = liveshare.WaitForPortConnection(connctx, localAddress) - if err != nil { - return nil, fmt.Errorf("failed to connect to local port: %w", err) - } + // Ensure we close the port forwarder if we encounter an error + // or once the gRPC connection is closed. pfcancel is retained + // to close the PF whenever we close the gRPC connection. + pfctx, pfcancel := context.WithCancel(connectctx) + + ch := make(chan error, 2) // Buffered channel to ensure we don't block on the goroutine + + // Tunnel the remote gRPC server port to the local port + localAddress := fmt.Sprintf("127.0.0.1:%d", listener.Addr().(*net.TCPAddr).Port) + go func() { + fwd := liveshare.NewPortForwarder(session, codespacesInternalSessionName, codespacesInternalPort, true) + ch <- fwd.ForwardToListener(pfctx, listener) + }() // Attempt to connect to the port opts := []grpc.DialOption{ grpc.WithTransportCredentials(insecure.NewCredentials()), grpc.WithBlock(), } - ctx, cancel = context.WithTimeout(ctx, serverConnectionTimeout) - defer cancel() - conn, err := grpc.DialContext(ctx, localAddress, opts...) - if err != nil { - return nil, err + + var conn *grpc.ClientConn + + go func() { + conn, err = grpc.DialContext(connectctx, localAddress, opts...) + ch <- err // nil if we successfully connected + }() + + // Wait for the connection to be established or for the context to be cancelled + select { + case <-ctx.Done(): + return nil, ctx.Err() + case err := <-ch: + if err != nil { + return nil, err + } } g := &Client{ @@ -82,6 +100,7 @@ func Connect(ctx context.Context, session liveshareSession, token string) (*Clie token: token, listener: listener, jupyterClient: jupyter.NewJupyterServerHostClient(conn), + cancelPF: pfcancel, } return g, nil @@ -89,6 +108,8 @@ func Connect(ctx context.Context, session liveshareSession, token string) (*Clie // Closes the gRPC connection func (g *Client) Close() error { + g.cancelPF() + // Closing the local listener effectively closes the gRPC connection if err := g.listener.Close(); err != nil { g.conn.Close() // If we fail to close the listener, explicitly close the gRPC connection and ignore any error @@ -105,9 +126,7 @@ func (g *Client) appendMetadata(ctx context.Context) context.Context { // Starts a remote JupyterLab server to allow the user to connect to the codespace via JupyterLab in their browser func (g *Client) StartJupyterServer(ctx context.Context) (port int, serverUrl string, err error) { - ctx, cancel := context.WithTimeout(ctx, requestTimeout) ctx = g.appendMetadata(ctx) - defer cancel() response, err := g.jupyterClient.GetRunningServer(ctx, &jupyter.GetRunningServerRequest{}) if err != nil { diff --git a/internal/codespaces/grpc/client_test.go b/internal/codespaces/grpc/client_test.go index 62a7b1723..d905c0e29 100644 --- a/internal/codespaces/grpc/client_test.go +++ b/internal/codespaces/grpc/client_test.go @@ -3,29 +3,35 @@ package grpc import ( "context" "fmt" - "os" + "log" "testing" - "github.com/cli/cli/v2/internal/codespaces/grpc/test" + grpctest "github.com/cli/cli/v2/internal/codespaces/grpc/test" ) -func TestMain(m *testing.M) { +func startServer(t *testing.T) { + t.Helper() + + ctx, cancel := context.WithCancel(context.Background()) + // Start the gRPC server in the background go func() { - err := test.StartServer() - if err != nil { - panic(err) + err := grpctest.StartServer(ctx) + if err != nil && err != context.Canceled { + log.Println(fmt.Errorf("error starting test server: %v", err)) } }() - m.Run() - os.Exit(0) + // Stop the gRPC server when the test is done + t.Cleanup(func() { + cancel() + }) } -func connect(t *testing.T) (ctx context.Context, client *Client) { +func connect(t *testing.T) (client *Client) { t.Helper() - ctx = context.Background() - client, err := Connect(ctx, &test.Session{}, "token") + + client, err := Connect(context.Background(), &grpctest.Session{}, "token") if err != nil { t.Fatalf("error connecting to internal server: %v", err) } @@ -34,31 +40,34 @@ func connect(t *testing.T) (ctx context.Context, client *Client) { client.Close() }) - return ctx, client + return client } // Test that the gRPC client returns the correct port and URL when the JupyterLab server starts successfully func TestStartJupyterServerSuccess(t *testing.T) { - ctx, client := connect(t) - port, url, err := client.StartJupyterServer(ctx) + startServer(t) + client := connect(t) + + port, url, err := client.StartJupyterServer(context.Background()) if err != nil { t.Fatalf("expected %v, got %v", nil, err) } - if port != test.JupyterPort { - t.Fatalf("expected %d, got %d", test.JupyterPort, port) + if port != grpctest.JupyterPort { + t.Fatalf("expected %d, got %d", grpctest.JupyterPort, port) } - if url != test.JupyterServerUrl { - t.Fatalf("expected %s, got %s", test.JupyterServerUrl, url) + if url != grpctest.JupyterServerUrl { + t.Fatalf("expected %s, got %s", grpctest.JupyterServerUrl, url) } } // Test that the gRPC client returns an error when the JupyterLab server fails to start func TestStartJupyterServerFailure(t *testing.T) { - ctx, client := connect(t) - test.JupyterMessage = "error message" - test.JupyterResult = false - errorMessage := fmt.Sprintf("failed to start JupyterLab: %s", test.JupyterMessage) - port, url, err := client.StartJupyterServer(ctx) + startServer(t) + client := connect(t) + grpctest.JupyterMessage = "error message" + grpctest.JupyterResult = false + errorMessage := fmt.Sprintf("failed to start JupyterLab: %s", grpctest.JupyterMessage) + port, url, err := client.StartJupyterServer(context.Background()) if err.Error() != errorMessage { t.Fatalf("expected %v, got %v", errorMessage, err) } diff --git a/internal/codespaces/grpc/test/server.go b/internal/codespaces/grpc/test/server.go index 50608a9fa..8af5efc29 100644 --- a/internal/codespaces/grpc/test/server.go +++ b/internal/codespaces/grpc/test/server.go @@ -35,7 +35,7 @@ func (s *server) GetRunningServer(ctx context.Context, in *jupyter.GetRunningSer } // Starts the mock gRPC server listening on port 50051 -func StartServer() error { +func StartServer(ctx context.Context) error { listener, err := net.Listen("tcp", fmt.Sprintf("127.0.0.1:%d", ServerPort)) if err != nil { return fmt.Errorf("failed to listen: %v", err) @@ -44,9 +44,19 @@ func StartServer() error { s := grpc.NewServer() jupyter.RegisterJupyterServerHostServer(s, &server{}) - if err := s.Serve(listener); err != nil { - return fmt.Errorf("failed to serve: %v", err) - } - return nil + ch := make(chan error, 1) + go func() { + if err := s.Serve(listener); err != nil { + ch <- fmt.Errorf("failed to serve: %v", err) + } + }() + + select { + case <-ctx.Done(): + s.Stop() + return ctx.Err() + case err := <-ch: + return err + } } diff --git a/internal/codespaces/grpc/test/session.go b/internal/codespaces/grpc/test/session.go index 70d81d41e..aba0f17ee 100644 --- a/internal/codespaces/grpc/test/session.go +++ b/internal/codespaces/grpc/test/session.go @@ -3,7 +3,6 @@ package test import ( "context" "fmt" - "log" "net" "github.com/cli/cli/v2/pkg/liveshare" @@ -11,23 +10,22 @@ import ( ) type Session struct { + channel ssh.Channel } func (s *Session) KeepAlive(reason string) { } func (s *Session) StartSharing(ctx context.Context, sessionName string, port int) (liveshare.ChannelID, error) { + conn, err := net.Dial("tcp", fmt.Sprintf("127.0.0.1:%d", ServerPort)) + if err != nil { + return liveshare.ChannelID{}, err + } + s.channel = &Channel{conn} return liveshare.ChannelID{}, nil } // Creates mock SSH channel connected to the mock gRPC server func (s *Session) OpenStreamingChannel(ctx context.Context, id liveshare.ChannelID) (ssh.Channel, error) { - dialer := net.Dialer{} - conn, err := dialer.Dial("tcp", fmt.Sprintf("127.0.0.1:%d", ServerPort)) - if err != nil { - log.Fatalf("failed to connect to the grpc server: %v", err) - } - return &Channel{ - conn: conn, - }, nil + return s.channel, nil } diff --git a/pkg/cmd/codespace/jupyter.go b/pkg/cmd/codespace/jupyter.go index c37a97ea7..928dc8871 100644 --- a/pkg/cmd/codespace/jupyter.go +++ b/pkg/cmd/codespace/jupyter.go @@ -45,17 +45,17 @@ func (a *App) Jupyter(ctx context.Context, codespaceName string) (err error) { defer safeClose(session, &err) a.StartProgressIndicatorWithLabel("Starting JupyterLab on codespace") - client, err := grpc.Connect(ctx, session, codespace.Connection.SessionToken) + client, err := connectToGRPCServer(ctx, session, codespace.Connection.SessionToken) if err != nil { - return fmt.Errorf("error connecting to internal server: %w", err) + return fmt.Errorf("failed to connect to internal server: %w", err) } defer safeClose(client, &err) - serverPort, serverUrl, err := client.StartJupyterServer(ctx) - a.StopProgressIndicator() + serverPort, serverUrl, err := startJupyterServer(ctx, client) if err != nil { return fmt.Errorf("failed to start JupyterLab server: %w", err) } + a.StopProgressIndicator() // Pass 0 to pick a random port listen, err := net.Listen("tcp", fmt.Sprintf("127.0.0.1:%d", 0)) @@ -87,3 +87,26 @@ func (a *App) Jupyter(ctx context.Context, codespaceName string) (err error) { return nil // success } } + +func connectToGRPCServer(ctx context.Context, session liveshareSession, token string) (*grpc.Client, error) { + ctx, _ = context.WithTimeout(ctx, grpc.ConnectionTimeout) + + client, err := grpc.Connect(ctx, session, token) + if err != nil { + return nil, fmt.Errorf("error connecting to internal server: %w", err) + } + + return client, nil +} + +func startJupyterServer(ctx context.Context, client *grpc.Client) (int, string, error) { + ctx, cancel := context.WithTimeout(ctx, grpc.RequestTimeout) + defer cancel() + + serverPort, serverUrl, err := client.StartJupyterServer(ctx) + if err != nil { + return 0, "", fmt.Errorf("failed to start JupyterLab server: %w", err) + } + + return serverPort, serverUrl, nil +} diff --git a/pkg/liveshare/port_forwarder.go b/pkg/liveshare/port_forwarder.go index 923415c01..f042eeaea 100644 --- a/pkg/liveshare/port_forwarder.go +++ b/pkg/liveshare/port_forwarder.go @@ -99,41 +99,6 @@ func (fwd *PortForwarder) Forward(ctx context.Context, conn io.ReadWriteCloser) return awaitError(ctx, errc) } -// Loops until we can connect to the address or the context is canceled. -func WaitForPortConnection(ctx context.Context, address string) error { - for { - select { - case <-ctx.Done(): - return ctx.Err() - default: - err := connectToAddr(address) - if err != nil { - continue - } - - return nil // success - } - } -} - -// Connects to and pings a given address to ensure that the server is shared and the port is forwarded. -func connectToAddr(address string) error { - // Verify that the port can be connected to - conn, err := net.Dial("tcp", address) - if err != nil { - return err - } - defer conn.Close() - - // Send a ping and make sure it succeed - _, err = conn.Write([]byte("ping")) - if err != nil { - return err - } - - return nil -} - func (fwd *PortForwarder) shareRemotePort(ctx context.Context) (ChannelID, error) { id, err := fwd.session.StartSharing(ctx, fwd.name, fwd.remotePort) if err != nil { From 9e13f6ba6b17b1aec11d7e0971e5748e7541925f Mon Sep 17 00:00:00 2001 From: Jose Garcia Date: Tue, 11 Oct 2022 09:14:23 -0400 Subject: [PATCH 2/3] cleanup connect --- internal/codespaces/grpc/client.go | 35 +++++++++++++++--------------- 1 file changed, 17 insertions(+), 18 deletions(-) diff --git a/internal/codespaces/grpc/client.go b/internal/codespaces/grpc/client.go index 5ced382f8..76119d659 100644 --- a/internal/codespaces/grpc/client.go +++ b/internal/codespaces/grpc/client.go @@ -48,6 +48,12 @@ func Connect(ctx context.Context, session liveshareSession, token string) (*Clie if err != nil { return nil, fmt.Errorf("failed to listen to local port over tcp: %w", err) } + localAddress := fmt.Sprintf("127.0.0.1:%d", listener.Addr().(*net.TCPAddr).Port) + + client := &Client{ + token: token, + listener: listener, + } // Create a cancelable context to be able to cancel background tasks // if we encounter an error while connecting to the gRPC server @@ -58,29 +64,27 @@ func Connect(ctx context.Context, session liveshareSession, token string) (*Clie } }() + ch := make(chan error, 2) // Buffered channel to ensure we don't block on the goroutine + // Ensure we close the port forwarder if we encounter an error // or once the gRPC connection is closed. pfcancel is retained // to close the PF whenever we close the gRPC connection. pfctx, pfcancel := context.WithCancel(connectctx) - - ch := make(chan error, 2) // Buffered channel to ensure we don't block on the goroutine + client.cancelPF = pfcancel // Tunnel the remote gRPC server port to the local port - localAddress := fmt.Sprintf("127.0.0.1:%d", listener.Addr().(*net.TCPAddr).Port) go func() { fwd := liveshare.NewPortForwarder(session, codespacesInternalSessionName, codespacesInternalPort, true) ch <- fwd.ForwardToListener(pfctx, listener) }() - // Attempt to connect to the port - opts := []grpc.DialOption{ - grpc.WithTransportCredentials(insecure.NewCredentials()), - grpc.WithBlock(), - } - var conn *grpc.ClientConn - go func() { + // Attempt to connect to the port + opts := []grpc.DialOption{ + grpc.WithTransportCredentials(insecure.NewCredentials()), + grpc.WithBlock(), + } conn, err = grpc.DialContext(connectctx, localAddress, opts...) ch <- err // nil if we successfully connected }() @@ -95,15 +99,10 @@ func Connect(ctx context.Context, session liveshareSession, token string) (*Clie } } - g := &Client{ - conn: conn, - token: token, - listener: listener, - jupyterClient: jupyter.NewJupyterServerHostClient(conn), - cancelPF: pfcancel, - } + client.conn = conn + client.jupyterClient = jupyter.NewJupyterServerHostClient(conn) - return g, nil + return client, nil } // Closes the gRPC connection From a356a1bef0bfb32546aa775f3cf2cd79b9bf7e6c Mon Sep 17 00:00:00 2001 From: Jose Garcia Date: Tue, 11 Oct 2022 13:15:34 -0400 Subject: [PATCH 3/3] no need to ignore cancel --- pkg/cmd/codespace/jupyter.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pkg/cmd/codespace/jupyter.go b/pkg/cmd/codespace/jupyter.go index 928dc8871..5c19c2ab1 100644 --- a/pkg/cmd/codespace/jupyter.go +++ b/pkg/cmd/codespace/jupyter.go @@ -89,7 +89,8 @@ func (a *App) Jupyter(ctx context.Context, codespaceName string) (err error) { } func connectToGRPCServer(ctx context.Context, session liveshareSession, token string) (*grpc.Client, error) { - ctx, _ = context.WithTimeout(ctx, grpc.ConnectionTimeout) + ctx, cancel := context.WithTimeout(ctx, grpc.ConnectionTimeout) + defer cancel() client, err := grpc.Connect(ctx, session, token) if err != nil {