diff --git a/cmd/ghcs/logs.go b/cmd/ghcs/logs.go index 34685e1e8..829ba9a31 100644 --- a/cmd/ghcs/logs.go +++ b/cmd/ghcs/logs.go @@ -88,7 +88,7 @@ func logs(ctx context.Context, tail bool, codespaceName string) error { tunnelClosed := make(chan error, 1) go func() { - tunnelClosed <- tunnel.Start(ctx) // error is non-nil + tunnelClosed <- tunnel.Forward(ctx) // error is non-nil }() cmdDone := make(chan error, 1) diff --git a/cmd/ghcs/ports.go b/cmd/ghcs/ports.go index f83757ff8..522ef61cb 100644 --- a/cmd/ghcs/ports.go +++ b/cmd/ghcs/ports.go @@ -76,15 +76,15 @@ func ports(opts *portsOptions) error { devContainerCh := getDevContainer(ctx, apiClient, codespace) - liveShareClient, err := codespaces.ConnectToLiveshare(ctx, log, apiClient, user.Login, token, codespace) + session, err := codespaces.ConnectToLiveshare(ctx, log, apiClient, user.Login, token, codespace) if err != nil { return fmt.Errorf("error connecting to Live Share: %v", err) } log.Println("Loading ports...") - ports, err := getPorts(ctx, liveShareClient) + ports, err := session.GetSharedServers(ctx) if err != nil { - return fmt.Errorf("error getting ports: %v", err) + return fmt.Errorf("error getting ports of shared servers: %v", err) } devContainerResult := <-devContainerCh @@ -116,20 +116,6 @@ func ports(opts *portsOptions) error { return nil } -func getPorts(ctx context.Context, lsclient *liveshare.Client) (liveshare.Ports, error) { - server, err := liveshare.NewServer(lsclient) - if err != nil { - return nil, fmt.Errorf("error creating server: %v", err) - } - - ports, err := server.GetSharedServers(ctx) - if err != nil { - return nil, fmt.Errorf("error getting shared servers: %v", err) - } - - return ports, nil -} - type devContainerResult struct { devContainer *devContainer err error @@ -219,22 +205,17 @@ func updatePortVisibility(log *output.Logger, codespaceName, sourcePort string, return fmt.Errorf("error getting Codespace: %v", err) } - lsclient, err := codespaces.ConnectToLiveshare(ctx, log, apiClient, user.Login, token, codespace) + session, err := codespaces.ConnectToLiveshare(ctx, log, apiClient, user.Login, token, codespace) if err != nil { return fmt.Errorf("error connecting to Live Share: %v", err) } - server, err := liveshare.NewServer(lsclient) - if err != nil { - return fmt.Errorf("error creating server: %v", err) - } - port, err := strconv.Atoi(sourcePort) if err != nil { return fmt.Errorf("error reading port number: %v", err) } - if err := server.UpdateSharedVisibility(ctx, port, public); err != nil { + if err := session.UpdateSharedVisibility(ctx, port, public); err != nil { return fmt.Errorf("error update port to public: %v", err) } @@ -285,29 +266,26 @@ func forwardPorts(log *output.Logger, codespaceName string, ports []string) erro return fmt.Errorf("error getting Codespace: %v", err) } - lsclient, err := codespaces.ConnectToLiveshare(ctx, log, apiClient, user.Login, token, codespace) + session, err := codespaces.ConnectToLiveshare(ctx, log, apiClient, user.Login, token, codespace) if err != nil { return fmt.Errorf("error connecting to Live Share: %v", err) } - server, err := liveshare.NewServer(lsclient) - if err != nil { - return fmt.Errorf("error creating server: %v", err) - } - g, gctx := errgroup.WithContext(ctx) for _, portPair := range portPairs { pp := portPair + // TODO(adonovan): fix data race on Session between + // StartSharing and NewPortForwarder. srcstr := strconv.Itoa(portPair.src) - if err := server.StartSharing(gctx, "share-"+srcstr, pp.src); err != nil { + if err := session.StartSharing(gctx, "share-"+srcstr, pp.src); err != nil { return fmt.Errorf("start sharing port: %v", err) } g.Go(func() error { log.Println("Forwarding port: " + srcstr + " ==> " + strconv.Itoa(pp.dst)) - portForwarder := liveshare.NewPortForwarder(lsclient, server, pp.dst) - if err := portForwarder.Start(gctx); err != nil { + portForwarder := liveshare.NewPortForwarder(session, pp.dst) + if err := portForwarder.Forward(gctx); err != nil { return fmt.Errorf("error forwarding port: %v", err) } @@ -315,6 +293,9 @@ func forwardPorts(log *output.Logger, codespaceName string, ports []string) erro }) } + // TODO(adonovan): fix: the waits for _all_ goroutines to terminate. + // If there are multiple ports, one long-lived successful connection + // will hide errors from any that fail. if err := g.Wait(); err != nil { return err } diff --git a/cmd/ghcs/ssh.go b/cmd/ghcs/ssh.go index aa85c33a3..91329b28c 100644 --- a/cmd/ghcs/ssh.go +++ b/cmd/ghcs/ssh.go @@ -56,20 +56,17 @@ func ssh(ctx context.Context, sshProfile, codespaceName string, localSSHServerPo return fmt.Errorf("get or choose Codespace: %v", err) } - lsclient, err := codespaces.ConnectToLiveshare(ctx, log, apiClient, user.Login, token, codespace) + session, err := codespaces.ConnectToLiveshare(ctx, log, apiClient, user.Login, token, codespace) if err != nil { return fmt.Errorf("error connecting to Live Share: %v", err) } - remoteSSHServerPort, sshUser, err := codespaces.StartSSHServer(ctx, lsclient, log) + remoteSSHServerPort, sshUser, err := codespaces.StartSSHServer(ctx, session, log) if err != nil { return fmt.Errorf("error getting ssh server details: %v", err) } - terminal, err := liveshare.NewTerminal(lsclient) - if err != nil { - return fmt.Errorf("error creating Live Share terminal: %v", err) - } + terminal := liveshare.NewTerminal(session) log.Print("Preparing SSH...") if sshProfile == "" { @@ -93,7 +90,7 @@ func ssh(ctx context.Context, sshProfile, codespaceName string, localSSHServerPo } } - tunnel, err := codespaces.NewPortForwarder(ctx, lsclient, "sshd", localSSHServerPort, remoteSSHServerPort) + tunnel, err := codespaces.NewPortForwarder(ctx, session, "sshd", localSSHServerPort, remoteSSHServerPort) if err != nil { return fmt.Errorf("make ssh tunnel: %v", err) } @@ -105,7 +102,7 @@ func ssh(ctx context.Context, sshProfile, codespaceName string, localSSHServerPo tunnelClosed := make(chan error) go func() { - tunnelClosed <- tunnel.Start(ctx) // error is always non-nil + tunnelClosed <- tunnel.Forward(ctx) // error is always non-nil }() shellClosed := make(chan error) diff --git a/internal/codespaces/codespaces.go b/internal/codespaces/codespaces.go index 90f676d28..86b703d92 100644 --- a/internal/codespaces/codespaces.go +++ b/internal/codespaces/codespaces.go @@ -73,7 +73,7 @@ func connectionReady(codespace *api.Codespace) bool { codespace.Environment.State == api.CodespaceEnvironmentStateAvailable } -func ConnectToLiveshare(ctx context.Context, log logger, apiClient *api.API, userLogin, token string, codespace *api.Codespace) (client *liveshare.Client, err error) { +func ConnectToLiveshare(ctx context.Context, log logger, apiClient *api.API, userLogin, token string, codespace *api.Codespace) (*liveshare.Session, error) { var startedCodespace bool if codespace.Environment.State != api.CodespaceEnvironmentStateAvailable { startedCodespace = true @@ -96,6 +96,7 @@ func ConnectToLiveshare(ctx context.Context, log logger, apiClient *api.API, use return nil, errors.New("timed out while waiting for the Codespace to start") } + var err error codespace, err = apiClient.GetCodespace(ctx, token, userLogin, codespace.Name) if err != nil { return nil, fmt.Errorf("error getting Codespace: %v", err) @@ -117,14 +118,10 @@ func ConnectToLiveshare(ctx context.Context, log logger, apiClient *api.API, use }), ) if err != nil { - return nil, fmt.Errorf("error creating Live Share: %v", err) + return nil, fmt.Errorf("error creating Live Share client: %v", err) } - if err := lsclient.Join(ctx); err != nil { - return nil, fmt.Errorf("error joining Live Share client: %v", err) - } - - return lsclient, nil + return lsclient.JoinWorkspace(ctx) } func GetOrChooseCodespace(ctx context.Context, apiClient *api.API, user *api.User, codespaceName string) (codespace *api.Codespace, token string, err error) { diff --git a/internal/codespaces/ssh.go b/internal/codespaces/ssh.go index 7a82e6af7..a8f1834d4 100644 --- a/internal/codespaces/ssh.go +++ b/internal/codespaces/ssh.go @@ -32,37 +32,33 @@ func UnusedPort() (int, error) { return l.Addr().(*net.TCPAddr).Port, nil } -// NewPortForwarder returns a new port forwarder for traffic between -// the Live Share client and the specified local and remote ports. +// NewPortForwarder returns a new port forwarder that forwards traffic between +// the specified local and remote ports over the provided Live Share session. // // The session name is used (along with the port) to generate // names for streams, and may appear in error messages. -func NewPortForwarder(ctx context.Context, client *liveshare.Client, sessionName string, localSSHPort, remoteSSHPort int) (*liveshare.PortForwarder, error) { +func NewPortForwarder(ctx context.Context, session *liveshare.Session, sessionName string, localSSHPort, remoteSSHPort int) (*liveshare.PortForwarder, error) { if localSSHPort == 0 { return nil, fmt.Errorf("a local port must be provided") } - server, err := liveshare.NewServer(client) - if err != nil { - return nil, fmt.Errorf("new liveshare server: %v", err) - } + // TODO(adonovan): fix data race on Session between + // StartSharing and NewPortForwarder. Perhaps combine the + // operations in go-liveshare? - if err := server.StartSharing(ctx, "sshd", remoteSSHPort); err != nil { + if err := session.StartSharing(ctx, "sshd", remoteSSHPort); err != nil { return nil, fmt.Errorf("sharing sshd port: %v", err) } - return liveshare.NewPortForwarder(client, server, localSSHPort), nil + return liveshare.NewPortForwarder(session, localSSHPort), nil } // StartSSHServer installs (if necessary) and starts the SSH in the codespace. // It returns the remote port where it is running, the user to log in with, or an error if something failed. -func StartSSHServer(ctx context.Context, client *liveshare.Client, log logger) (serverPort int, user string, err error) { +func StartSSHServer(ctx context.Context, session *liveshare.Session, log logger) (serverPort int, user string, err error) { log.Println("Fetching SSH details...") - sshServer, err := liveshare.NewSSHServer(client) - if err != nil { - return 0, "", fmt.Errorf("error creating live share: %v", err) - } + sshServer := session.SSHServer() sshServerStartResult, err := sshServer.StartRemoteServer(ctx) if err != nil { diff --git a/internal/codespaces/states.go b/internal/codespaces/states.go index a58e2b235..f0052e72c 100644 --- a/internal/codespaces/states.go +++ b/internal/codespaces/states.go @@ -40,7 +40,7 @@ func PollPostCreateStates(ctx context.Context, log logger, apiClient *api.API, u return fmt.Errorf("getting Codespace token: %v", err) } - lsclient, err := ConnectToLiveshare(ctx, log, apiClient, user.Login, token, codespace) + session, err := ConnectToLiveshare(ctx, log, apiClient, user.Login, token, codespace) if err != nil { return fmt.Errorf("connect to Live Share: %v", err) } @@ -50,19 +50,19 @@ func PollPostCreateStates(ctx context.Context, log logger, apiClient *api.API, u return err } - remoteSSHServerPort, sshUser, err := StartSSHServer(ctx, lsclient, log) + remoteSSHServerPort, sshUser, err := StartSSHServer(ctx, session, log) if err != nil { return fmt.Errorf("error getting ssh server details: %v", err) } - fwd, err := NewPortForwarder(ctx, lsclient, "sshd", localSSHPort, remoteSSHServerPort) + fwd, err := NewPortForwarder(ctx, session, "sshd", localSSHPort, remoteSSHServerPort) if err != nil { return fmt.Errorf("creating port forwarder: %v", err) } tunnelClosed := make(chan error, 1) // buffered to avoid sender stuckness go func() { - tunnelClosed <- fwd.Start(ctx) // error is non-nil + tunnelClosed <- fwd.Forward(ctx) // error is non-nil }() t := time.NewTicker(1 * time.Second)