Merge pull request #1 from dmgardiner25/jg/fixes

Fixes races and timeout issue
This commit is contained in:
David Gardiner 2022-10-11 10:27:41 -07:00 committed by GitHub
commit b7d7674f78
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 136 additions and 112 deletions

View file

@ -19,9 +19,8 @@ import (
)
const (
serverConnectionTimeout = 5 * time.Second
requestTimeout = 30 * time.Second
portConnectionTimeout = 30 * time.Second
ConnectionTimeout = 5 * time.Second
RequestTimeout = 30 * time.Second
)
const (
@ -34,6 +33,7 @@ type Client struct {
token string
listener net.Listener
jupyterClient jupyter.JupyterServerHostClient
cancelPF context.CancelFunc
}
type liveshareSession interface {
@ -48,47 +48,67 @@ func Connect(ctx context.Context, session liveshareSession, token string) (*Clie
if err != nil {
return nil, fmt.Errorf("failed to listen to local port over tcp: %w", err)
}
// Tunnel the remote gRPC server port to the local port
localAddress := fmt.Sprintf("127.0.0.1:%d", listener.Addr().(*net.TCPAddr).Port)
internalTunnelClosed := make(chan error, 1)
go func() {
fwd := liveshare.NewPortForwarder(session, codespacesInternalSessionName, codespacesInternalPort, true)
internalTunnelClosed <- fwd.ForwardToListener(ctx, listener)
client := &Client{
token: token,
listener: listener,
}
// Create a cancelable context to be able to cancel background tasks
// if we encounter an error while connecting to the gRPC server
connectctx, cancel := context.WithCancel(context.Background())
defer func() {
if err != nil {
cancel()
}
}()
// Ping the port to ensure that it is fully forwarded before continuing
connctx, cancel := context.WithTimeout(ctx, portConnectionTimeout)
defer cancel()
err = liveshare.WaitForPortConnection(connctx, localAddress)
if err != nil {
return nil, fmt.Errorf("failed to connect to local port: %w", err)
ch := make(chan error, 2) // Buffered channel to ensure we don't block on the goroutine
// Ensure we close the port forwarder if we encounter an error
// or once the gRPC connection is closed. pfcancel is retained
// to close the PF whenever we close the gRPC connection.
pfctx, pfcancel := context.WithCancel(connectctx)
client.cancelPF = pfcancel
// Tunnel the remote gRPC server port to the local port
go func() {
fwd := liveshare.NewPortForwarder(session, codespacesInternalSessionName, codespacesInternalPort, true)
ch <- fwd.ForwardToListener(pfctx, listener)
}()
var conn *grpc.ClientConn
go func() {
// Attempt to connect to the port
opts := []grpc.DialOption{
grpc.WithTransportCredentials(insecure.NewCredentials()),
grpc.WithBlock(),
}
conn, err = grpc.DialContext(connectctx, localAddress, opts...)
ch <- err // nil if we successfully connected
}()
// Wait for the connection to be established or for the context to be cancelled
select {
case <-ctx.Done():
return nil, ctx.Err()
case err := <-ch:
if err != nil {
return nil, err
}
}
// Attempt to connect to the port
opts := []grpc.DialOption{
grpc.WithTransportCredentials(insecure.NewCredentials()),
grpc.WithBlock(),
}
ctx, cancel = context.WithTimeout(ctx, serverConnectionTimeout)
defer cancel()
conn, err := grpc.DialContext(ctx, localAddress, opts...)
if err != nil {
return nil, err
}
client.conn = conn
client.jupyterClient = jupyter.NewJupyterServerHostClient(conn)
g := &Client{
conn: conn,
token: token,
listener: listener,
jupyterClient: jupyter.NewJupyterServerHostClient(conn),
}
return g, nil
return client, nil
}
// Closes the gRPC connection
func (g *Client) Close() error {
g.cancelPF()
// Closing the local listener effectively closes the gRPC connection
if err := g.listener.Close(); err != nil {
g.conn.Close() // If we fail to close the listener, explicitly close the gRPC connection and ignore any error
@ -105,9 +125,7 @@ func (g *Client) appendMetadata(ctx context.Context) context.Context {
// Starts a remote JupyterLab server to allow the user to connect to the codespace via JupyterLab in their browser
func (g *Client) StartJupyterServer(ctx context.Context) (port int, serverUrl string, err error) {
ctx, cancel := context.WithTimeout(ctx, requestTimeout)
ctx = g.appendMetadata(ctx)
defer cancel()
response, err := g.jupyterClient.GetRunningServer(ctx, &jupyter.GetRunningServerRequest{})
if err != nil {

View file

@ -3,29 +3,35 @@ package grpc
import (
"context"
"fmt"
"os"
"log"
"testing"
"github.com/cli/cli/v2/internal/codespaces/grpc/test"
grpctest "github.com/cli/cli/v2/internal/codespaces/grpc/test"
)
func TestMain(m *testing.M) {
func startServer(t *testing.T) {
t.Helper()
ctx, cancel := context.WithCancel(context.Background())
// Start the gRPC server in the background
go func() {
err := test.StartServer()
if err != nil {
panic(err)
err := grpctest.StartServer(ctx)
if err != nil && err != context.Canceled {
log.Println(fmt.Errorf("error starting test server: %v", err))
}
}()
m.Run()
os.Exit(0)
// Stop the gRPC server when the test is done
t.Cleanup(func() {
cancel()
})
}
func connect(t *testing.T) (ctx context.Context, client *Client) {
func connect(t *testing.T) (client *Client) {
t.Helper()
ctx = context.Background()
client, err := Connect(ctx, &test.Session{}, "token")
client, err := Connect(context.Background(), &grpctest.Session{}, "token")
if err != nil {
t.Fatalf("error connecting to internal server: %v", err)
}
@ -34,31 +40,34 @@ func connect(t *testing.T) (ctx context.Context, client *Client) {
client.Close()
})
return ctx, client
return client
}
// Test that the gRPC client returns the correct port and URL when the JupyterLab server starts successfully
func TestStartJupyterServerSuccess(t *testing.T) {
ctx, client := connect(t)
port, url, err := client.StartJupyterServer(ctx)
startServer(t)
client := connect(t)
port, url, err := client.StartJupyterServer(context.Background())
if err != nil {
t.Fatalf("expected %v, got %v", nil, err)
}
if port != test.JupyterPort {
t.Fatalf("expected %d, got %d", test.JupyterPort, port)
if port != grpctest.JupyterPort {
t.Fatalf("expected %d, got %d", grpctest.JupyterPort, port)
}
if url != test.JupyterServerUrl {
t.Fatalf("expected %s, got %s", test.JupyterServerUrl, url)
if url != grpctest.JupyterServerUrl {
t.Fatalf("expected %s, got %s", grpctest.JupyterServerUrl, url)
}
}
// Test that the gRPC client returns an error when the JupyterLab server fails to start
func TestStartJupyterServerFailure(t *testing.T) {
ctx, client := connect(t)
test.JupyterMessage = "error message"
test.JupyterResult = false
errorMessage := fmt.Sprintf("failed to start JupyterLab: %s", test.JupyterMessage)
port, url, err := client.StartJupyterServer(ctx)
startServer(t)
client := connect(t)
grpctest.JupyterMessage = "error message"
grpctest.JupyterResult = false
errorMessage := fmt.Sprintf("failed to start JupyterLab: %s", grpctest.JupyterMessage)
port, url, err := client.StartJupyterServer(context.Background())
if err.Error() != errorMessage {
t.Fatalf("expected %v, got %v", errorMessage, err)
}

View file

@ -35,7 +35,7 @@ func (s *server) GetRunningServer(ctx context.Context, in *jupyter.GetRunningSer
}
// Starts the mock gRPC server listening on port 50051
func StartServer() error {
func StartServer(ctx context.Context) error {
listener, err := net.Listen("tcp", fmt.Sprintf("127.0.0.1:%d", ServerPort))
if err != nil {
return fmt.Errorf("failed to listen: %v", err)
@ -44,9 +44,19 @@ func StartServer() error {
s := grpc.NewServer()
jupyter.RegisterJupyterServerHostServer(s, &server{})
if err := s.Serve(listener); err != nil {
return fmt.Errorf("failed to serve: %v", err)
}
return nil
ch := make(chan error, 1)
go func() {
if err := s.Serve(listener); err != nil {
ch <- fmt.Errorf("failed to serve: %v", err)
}
}()
select {
case <-ctx.Done():
s.Stop()
return ctx.Err()
case err := <-ch:
return err
}
}

View file

@ -3,7 +3,6 @@ package test
import (
"context"
"fmt"
"log"
"net"
"github.com/cli/cli/v2/pkg/liveshare"
@ -11,23 +10,22 @@ import (
)
type Session struct {
channel ssh.Channel
}
func (s *Session) KeepAlive(reason string) {
}
func (s *Session) StartSharing(ctx context.Context, sessionName string, port int) (liveshare.ChannelID, error) {
conn, err := net.Dial("tcp", fmt.Sprintf("127.0.0.1:%d", ServerPort))
if err != nil {
return liveshare.ChannelID{}, err
}
s.channel = &Channel{conn}
return liveshare.ChannelID{}, nil
}
// Creates mock SSH channel connected to the mock gRPC server
func (s *Session) OpenStreamingChannel(ctx context.Context, id liveshare.ChannelID) (ssh.Channel, error) {
dialer := net.Dialer{}
conn, err := dialer.Dial("tcp", fmt.Sprintf("127.0.0.1:%d", ServerPort))
if err != nil {
log.Fatalf("failed to connect to the grpc server: %v", err)
}
return &Channel{
conn: conn,
}, nil
return s.channel, nil
}

View file

@ -45,17 +45,17 @@ func (a *App) Jupyter(ctx context.Context, codespaceName string) (err error) {
defer safeClose(session, &err)
a.StartProgressIndicatorWithLabel("Starting JupyterLab on codespace")
client, err := grpc.Connect(ctx, session, codespace.Connection.SessionToken)
client, err := connectToGRPCServer(ctx, session, codespace.Connection.SessionToken)
if err != nil {
return fmt.Errorf("error connecting to internal server: %w", err)
return fmt.Errorf("failed to connect to internal server: %w", err)
}
defer safeClose(client, &err)
serverPort, serverUrl, err := client.StartJupyterServer(ctx)
a.StopProgressIndicator()
serverPort, serverUrl, err := startJupyterServer(ctx, client)
if err != nil {
return fmt.Errorf("failed to start JupyterLab server: %w", err)
}
a.StopProgressIndicator()
// Pass 0 to pick a random port
listen, err := net.Listen("tcp", fmt.Sprintf("127.0.0.1:%d", 0))
@ -87,3 +87,27 @@ func (a *App) Jupyter(ctx context.Context, codespaceName string) (err error) {
return nil // success
}
}
func connectToGRPCServer(ctx context.Context, session liveshareSession, token string) (*grpc.Client, error) {
ctx, cancel := context.WithTimeout(ctx, grpc.ConnectionTimeout)
defer cancel()
client, err := grpc.Connect(ctx, session, token)
if err != nil {
return nil, fmt.Errorf("error connecting to internal server: %w", err)
}
return client, nil
}
func startJupyterServer(ctx context.Context, client *grpc.Client) (int, string, error) {
ctx, cancel := context.WithTimeout(ctx, grpc.RequestTimeout)
defer cancel()
serverPort, serverUrl, err := client.StartJupyterServer(ctx)
if err != nil {
return 0, "", fmt.Errorf("failed to start JupyterLab server: %w", err)
}
return serverPort, serverUrl, nil
}

View file

@ -99,41 +99,6 @@ func (fwd *PortForwarder) Forward(ctx context.Context, conn io.ReadWriteCloser)
return awaitError(ctx, errc)
}
// Loops until we can connect to the address or the context is canceled.
func WaitForPortConnection(ctx context.Context, address string) error {
for {
select {
case <-ctx.Done():
return ctx.Err()
default:
err := connectToAddr(address)
if err != nil {
continue
}
return nil // success
}
}
}
// Connects to and pings a given address to ensure that the server is shared and the port is forwarded.
func connectToAddr(address string) error {
// Verify that the port can be connected to
conn, err := net.Dial("tcp", address)
if err != nil {
return err
}
defer conn.Close()
// Send a ping and make sure it succeed
_, err = conn.Write([]byte("ping"))
if err != nil {
return err
}
return nil
}
func (fwd *PortForwarder) shareRemotePort(ctx context.Context) (ChannelID, error) {
id, err := fwd.session.StartSharing(ctx, fwd.name, fwd.remotePort)
if err != nil {