Merge pull request #1 from dmgardiner25/jg/fixes
Fixes races and timeout issue
This commit is contained in:
commit
b7d7674f78
6 changed files with 136 additions and 112 deletions
|
|
@ -19,9 +19,8 @@ import (
|
|||
)
|
||||
|
||||
const (
|
||||
serverConnectionTimeout = 5 * time.Second
|
||||
requestTimeout = 30 * time.Second
|
||||
portConnectionTimeout = 30 * time.Second
|
||||
ConnectionTimeout = 5 * time.Second
|
||||
RequestTimeout = 30 * time.Second
|
||||
)
|
||||
|
||||
const (
|
||||
|
|
@ -34,6 +33,7 @@ type Client struct {
|
|||
token string
|
||||
listener net.Listener
|
||||
jupyterClient jupyter.JupyterServerHostClient
|
||||
cancelPF context.CancelFunc
|
||||
}
|
||||
|
||||
type liveshareSession interface {
|
||||
|
|
@ -48,47 +48,67 @@ func Connect(ctx context.Context, session liveshareSession, token string) (*Clie
|
|||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to listen to local port over tcp: %w", err)
|
||||
}
|
||||
|
||||
// Tunnel the remote gRPC server port to the local port
|
||||
localAddress := fmt.Sprintf("127.0.0.1:%d", listener.Addr().(*net.TCPAddr).Port)
|
||||
internalTunnelClosed := make(chan error, 1)
|
||||
go func() {
|
||||
fwd := liveshare.NewPortForwarder(session, codespacesInternalSessionName, codespacesInternalPort, true)
|
||||
internalTunnelClosed <- fwd.ForwardToListener(ctx, listener)
|
||||
|
||||
client := &Client{
|
||||
token: token,
|
||||
listener: listener,
|
||||
}
|
||||
|
||||
// Create a cancelable context to be able to cancel background tasks
|
||||
// if we encounter an error while connecting to the gRPC server
|
||||
connectctx, cancel := context.WithCancel(context.Background())
|
||||
defer func() {
|
||||
if err != nil {
|
||||
cancel()
|
||||
}
|
||||
}()
|
||||
|
||||
// Ping the port to ensure that it is fully forwarded before continuing
|
||||
connctx, cancel := context.WithTimeout(ctx, portConnectionTimeout)
|
||||
defer cancel()
|
||||
err = liveshare.WaitForPortConnection(connctx, localAddress)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to connect to local port: %w", err)
|
||||
ch := make(chan error, 2) // Buffered channel to ensure we don't block on the goroutine
|
||||
|
||||
// Ensure we close the port forwarder if we encounter an error
|
||||
// or once the gRPC connection is closed. pfcancel is retained
|
||||
// to close the PF whenever we close the gRPC connection.
|
||||
pfctx, pfcancel := context.WithCancel(connectctx)
|
||||
client.cancelPF = pfcancel
|
||||
|
||||
// Tunnel the remote gRPC server port to the local port
|
||||
go func() {
|
||||
fwd := liveshare.NewPortForwarder(session, codespacesInternalSessionName, codespacesInternalPort, true)
|
||||
ch <- fwd.ForwardToListener(pfctx, listener)
|
||||
}()
|
||||
|
||||
var conn *grpc.ClientConn
|
||||
go func() {
|
||||
// Attempt to connect to the port
|
||||
opts := []grpc.DialOption{
|
||||
grpc.WithTransportCredentials(insecure.NewCredentials()),
|
||||
grpc.WithBlock(),
|
||||
}
|
||||
conn, err = grpc.DialContext(connectctx, localAddress, opts...)
|
||||
ch <- err // nil if we successfully connected
|
||||
}()
|
||||
|
||||
// Wait for the connection to be established or for the context to be cancelled
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
case err := <-ch:
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
// Attempt to connect to the port
|
||||
opts := []grpc.DialOption{
|
||||
grpc.WithTransportCredentials(insecure.NewCredentials()),
|
||||
grpc.WithBlock(),
|
||||
}
|
||||
ctx, cancel = context.WithTimeout(ctx, serverConnectionTimeout)
|
||||
defer cancel()
|
||||
conn, err := grpc.DialContext(ctx, localAddress, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
client.conn = conn
|
||||
client.jupyterClient = jupyter.NewJupyterServerHostClient(conn)
|
||||
|
||||
g := &Client{
|
||||
conn: conn,
|
||||
token: token,
|
||||
listener: listener,
|
||||
jupyterClient: jupyter.NewJupyterServerHostClient(conn),
|
||||
}
|
||||
|
||||
return g, nil
|
||||
return client, nil
|
||||
}
|
||||
|
||||
// Closes the gRPC connection
|
||||
func (g *Client) Close() error {
|
||||
g.cancelPF()
|
||||
|
||||
// Closing the local listener effectively closes the gRPC connection
|
||||
if err := g.listener.Close(); err != nil {
|
||||
g.conn.Close() // If we fail to close the listener, explicitly close the gRPC connection and ignore any error
|
||||
|
|
@ -105,9 +125,7 @@ func (g *Client) appendMetadata(ctx context.Context) context.Context {
|
|||
|
||||
// Starts a remote JupyterLab server to allow the user to connect to the codespace via JupyterLab in their browser
|
||||
func (g *Client) StartJupyterServer(ctx context.Context) (port int, serverUrl string, err error) {
|
||||
ctx, cancel := context.WithTimeout(ctx, requestTimeout)
|
||||
ctx = g.appendMetadata(ctx)
|
||||
defer cancel()
|
||||
|
||||
response, err := g.jupyterClient.GetRunningServer(ctx, &jupyter.GetRunningServerRequest{})
|
||||
if err != nil {
|
||||
|
|
|
|||
|
|
@ -3,29 +3,35 @@ package grpc
|
|||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"log"
|
||||
"testing"
|
||||
|
||||
"github.com/cli/cli/v2/internal/codespaces/grpc/test"
|
||||
grpctest "github.com/cli/cli/v2/internal/codespaces/grpc/test"
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
func startServer(t *testing.T) {
|
||||
t.Helper()
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
// Start the gRPC server in the background
|
||||
go func() {
|
||||
err := test.StartServer()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
err := grpctest.StartServer(ctx)
|
||||
if err != nil && err != context.Canceled {
|
||||
log.Println(fmt.Errorf("error starting test server: %v", err))
|
||||
}
|
||||
}()
|
||||
|
||||
m.Run()
|
||||
os.Exit(0)
|
||||
// Stop the gRPC server when the test is done
|
||||
t.Cleanup(func() {
|
||||
cancel()
|
||||
})
|
||||
}
|
||||
|
||||
func connect(t *testing.T) (ctx context.Context, client *Client) {
|
||||
func connect(t *testing.T) (client *Client) {
|
||||
t.Helper()
|
||||
ctx = context.Background()
|
||||
client, err := Connect(ctx, &test.Session{}, "token")
|
||||
|
||||
client, err := Connect(context.Background(), &grpctest.Session{}, "token")
|
||||
if err != nil {
|
||||
t.Fatalf("error connecting to internal server: %v", err)
|
||||
}
|
||||
|
|
@ -34,31 +40,34 @@ func connect(t *testing.T) (ctx context.Context, client *Client) {
|
|||
client.Close()
|
||||
})
|
||||
|
||||
return ctx, client
|
||||
return client
|
||||
}
|
||||
|
||||
// Test that the gRPC client returns the correct port and URL when the JupyterLab server starts successfully
|
||||
func TestStartJupyterServerSuccess(t *testing.T) {
|
||||
ctx, client := connect(t)
|
||||
port, url, err := client.StartJupyterServer(ctx)
|
||||
startServer(t)
|
||||
client := connect(t)
|
||||
|
||||
port, url, err := client.StartJupyterServer(context.Background())
|
||||
if err != nil {
|
||||
t.Fatalf("expected %v, got %v", nil, err)
|
||||
}
|
||||
if port != test.JupyterPort {
|
||||
t.Fatalf("expected %d, got %d", test.JupyterPort, port)
|
||||
if port != grpctest.JupyterPort {
|
||||
t.Fatalf("expected %d, got %d", grpctest.JupyterPort, port)
|
||||
}
|
||||
if url != test.JupyterServerUrl {
|
||||
t.Fatalf("expected %s, got %s", test.JupyterServerUrl, url)
|
||||
if url != grpctest.JupyterServerUrl {
|
||||
t.Fatalf("expected %s, got %s", grpctest.JupyterServerUrl, url)
|
||||
}
|
||||
}
|
||||
|
||||
// Test that the gRPC client returns an error when the JupyterLab server fails to start
|
||||
func TestStartJupyterServerFailure(t *testing.T) {
|
||||
ctx, client := connect(t)
|
||||
test.JupyterMessage = "error message"
|
||||
test.JupyterResult = false
|
||||
errorMessage := fmt.Sprintf("failed to start JupyterLab: %s", test.JupyterMessage)
|
||||
port, url, err := client.StartJupyterServer(ctx)
|
||||
startServer(t)
|
||||
client := connect(t)
|
||||
grpctest.JupyterMessage = "error message"
|
||||
grpctest.JupyterResult = false
|
||||
errorMessage := fmt.Sprintf("failed to start JupyterLab: %s", grpctest.JupyterMessage)
|
||||
port, url, err := client.StartJupyterServer(context.Background())
|
||||
if err.Error() != errorMessage {
|
||||
t.Fatalf("expected %v, got %v", errorMessage, err)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -35,7 +35,7 @@ func (s *server) GetRunningServer(ctx context.Context, in *jupyter.GetRunningSer
|
|||
}
|
||||
|
||||
// Starts the mock gRPC server listening on port 50051
|
||||
func StartServer() error {
|
||||
func StartServer(ctx context.Context) error {
|
||||
listener, err := net.Listen("tcp", fmt.Sprintf("127.0.0.1:%d", ServerPort))
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to listen: %v", err)
|
||||
|
|
@ -44,9 +44,19 @@ func StartServer() error {
|
|||
|
||||
s := grpc.NewServer()
|
||||
jupyter.RegisterJupyterServerHostServer(s, &server{})
|
||||
if err := s.Serve(listener); err != nil {
|
||||
return fmt.Errorf("failed to serve: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
ch := make(chan error, 1)
|
||||
go func() {
|
||||
if err := s.Serve(listener); err != nil {
|
||||
ch <- fmt.Errorf("failed to serve: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
s.Stop()
|
||||
return ctx.Err()
|
||||
case err := <-ch:
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -3,7 +3,6 @@ package test
|
|||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
"net"
|
||||
|
||||
"github.com/cli/cli/v2/pkg/liveshare"
|
||||
|
|
@ -11,23 +10,22 @@ import (
|
|||
)
|
||||
|
||||
type Session struct {
|
||||
channel ssh.Channel
|
||||
}
|
||||
|
||||
func (s *Session) KeepAlive(reason string) {
|
||||
}
|
||||
|
||||
func (s *Session) StartSharing(ctx context.Context, sessionName string, port int) (liveshare.ChannelID, error) {
|
||||
conn, err := net.Dial("tcp", fmt.Sprintf("127.0.0.1:%d", ServerPort))
|
||||
if err != nil {
|
||||
return liveshare.ChannelID{}, err
|
||||
}
|
||||
s.channel = &Channel{conn}
|
||||
return liveshare.ChannelID{}, nil
|
||||
}
|
||||
|
||||
// Creates mock SSH channel connected to the mock gRPC server
|
||||
func (s *Session) OpenStreamingChannel(ctx context.Context, id liveshare.ChannelID) (ssh.Channel, error) {
|
||||
dialer := net.Dialer{}
|
||||
conn, err := dialer.Dial("tcp", fmt.Sprintf("127.0.0.1:%d", ServerPort))
|
||||
if err != nil {
|
||||
log.Fatalf("failed to connect to the grpc server: %v", err)
|
||||
}
|
||||
return &Channel{
|
||||
conn: conn,
|
||||
}, nil
|
||||
return s.channel, nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -45,17 +45,17 @@ func (a *App) Jupyter(ctx context.Context, codespaceName string) (err error) {
|
|||
defer safeClose(session, &err)
|
||||
|
||||
a.StartProgressIndicatorWithLabel("Starting JupyterLab on codespace")
|
||||
client, err := grpc.Connect(ctx, session, codespace.Connection.SessionToken)
|
||||
client, err := connectToGRPCServer(ctx, session, codespace.Connection.SessionToken)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error connecting to internal server: %w", err)
|
||||
return fmt.Errorf("failed to connect to internal server: %w", err)
|
||||
}
|
||||
defer safeClose(client, &err)
|
||||
|
||||
serverPort, serverUrl, err := client.StartJupyterServer(ctx)
|
||||
a.StopProgressIndicator()
|
||||
serverPort, serverUrl, err := startJupyterServer(ctx, client)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to start JupyterLab server: %w", err)
|
||||
}
|
||||
a.StopProgressIndicator()
|
||||
|
||||
// Pass 0 to pick a random port
|
||||
listen, err := net.Listen("tcp", fmt.Sprintf("127.0.0.1:%d", 0))
|
||||
|
|
@ -87,3 +87,27 @@ func (a *App) Jupyter(ctx context.Context, codespaceName string) (err error) {
|
|||
return nil // success
|
||||
}
|
||||
}
|
||||
|
||||
func connectToGRPCServer(ctx context.Context, session liveshareSession, token string) (*grpc.Client, error) {
|
||||
ctx, cancel := context.WithTimeout(ctx, grpc.ConnectionTimeout)
|
||||
defer cancel()
|
||||
|
||||
client, err := grpc.Connect(ctx, session, token)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error connecting to internal server: %w", err)
|
||||
}
|
||||
|
||||
return client, nil
|
||||
}
|
||||
|
||||
func startJupyterServer(ctx context.Context, client *grpc.Client) (int, string, error) {
|
||||
ctx, cancel := context.WithTimeout(ctx, grpc.RequestTimeout)
|
||||
defer cancel()
|
||||
|
||||
serverPort, serverUrl, err := client.StartJupyterServer(ctx)
|
||||
if err != nil {
|
||||
return 0, "", fmt.Errorf("failed to start JupyterLab server: %w", err)
|
||||
}
|
||||
|
||||
return serverPort, serverUrl, nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -99,41 +99,6 @@ func (fwd *PortForwarder) Forward(ctx context.Context, conn io.ReadWriteCloser)
|
|||
return awaitError(ctx, errc)
|
||||
}
|
||||
|
||||
// Loops until we can connect to the address or the context is canceled.
|
||||
func WaitForPortConnection(ctx context.Context, address string) error {
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
default:
|
||||
err := connectToAddr(address)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
return nil // success
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Connects to and pings a given address to ensure that the server is shared and the port is forwarded.
|
||||
func connectToAddr(address string) error {
|
||||
// Verify that the port can be connected to
|
||||
conn, err := net.Dial("tcp", address)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
// Send a ping and make sure it succeed
|
||||
_, err = conn.Write([]byte("ping"))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (fwd *PortForwarder) shareRemotePort(ctx context.Context) (ChannelID, error) {
|
||||
id, err := fwd.session.StartSharing(ctx, fwd.name, fwd.remotePort)
|
||||
if err != nil {
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue