Merge pull request #99 from github/runcommand

preparatory cleanups to ssh tunnel/port forwarding code
This commit is contained in:
Alan Donovan 2021-08-31 17:30:00 -04:00 committed by GitHub
commit c0fbb7e9fb
4 changed files with 147 additions and 150 deletions

View file

@ -1,7 +1,6 @@
package main
import (
"bufio"
"context"
"fmt"
"os"
@ -24,7 +23,7 @@ func newLogsCmd() *cobra.Command {
if len(args) > 0 {
codespaceName = args[0]
}
return logs(tail, codespaceName)
return logs(context.Background(), tail, codespaceName)
},
}
@ -37,9 +36,12 @@ func init() {
rootCmd.AddCommand(newLogsCmd())
}
func logs(tail bool, codespaceName string) error {
func logs(ctx context.Context, tail bool, codespaceName string) error {
// Ensure all child tasks (port forwarding, remote exec) terminate before return.
ctx, cancel := context.WithCancel(ctx)
defer cancel()
apiClient := api.New(os.Getenv("GITHUB_TOKEN"))
ctx := context.Background()
log := output.NewLogger(os.Stdout, os.Stderr, false)
user, err := apiClient.GetUser(ctx)
@ -57,12 +59,17 @@ func logs(tail bool, codespaceName string) error {
return fmt.Errorf("connecting to Live Share: %v", err)
}
localSSHPort, err := codespaces.UnusedPort()
if err != nil {
return err
}
remoteSSHServerPort, sshUser, err := codespaces.StartSSHServer(ctx, lsclient, log)
if err != nil {
return fmt.Errorf("error getting ssh server details: %v", err)
}
tunnelPort, connClosed, err := codespaces.MakeSSHTunnel(ctx, lsclient, 0, remoteSSHServerPort)
tunnel, err := codespaces.NewPortForwarder(ctx, lsclient, "sshd", localSSHPort, remoteSSHServerPort)
if err != nil {
return fmt.Errorf("make ssh tunnel: %v", err)
}
@ -73,42 +80,30 @@ func logs(tail bool, codespaceName string) error {
}
dst := fmt.Sprintf("%s@localhost", sshUser)
stdout, err := codespaces.RunCommand(
ctx, tunnelPort, dst, fmt.Sprintf("%v /workspaces/.codespaces/.persistedshare/creation.log", cmdType),
cmd := codespaces.NewRemoteCommand(
ctx, localSSHPort, dst, fmt.Sprintf("%s /workspaces/.codespaces/.persistedshare/creation.log", cmdType),
)
if err != nil {
return fmt.Errorf("run command: %v", err)
}
done := make(chan error)
// Error channels are buffered so that neither sending goroutine gets stuck.
tunnelClosed := make(chan error, 1)
go func() {
scanner := bufio.NewScanner(stdout)
for scanner.Scan() {
fmt.Println(scanner.Text())
}
tunnelClosed <- tunnel.Start(ctx) // error is non-nil
}()
if err := scanner.Err(); err != nil {
done <- fmt.Errorf("error scanning: %v", err)
return
}
if err := stdout.Close(); err != nil {
done <- fmt.Errorf("close stdout: %v", err)
return
}
done <- nil
cmdDone := make(chan error, 1)
go func() {
cmdDone <- cmd.Run()
}()
select {
case err := <-connClosed:
if err != nil {
return fmt.Errorf("connection closed: %v", err)
}
case err := <-done:
if err != nil {
return err
}
}
case err := <-tunnelClosed:
return fmt.Errorf("connection closed: %v", err)
return nil
case err := <-cmdDone:
if err != nil {
return fmt.Errorf("error retrieving logs: %v", err)
}
return nil // success
}
}

View file

@ -23,7 +23,7 @@ func newSSHCmd() *cobra.Command {
Short: "SSH into a Codespace",
Args: cobra.NoArgs,
RunE: func(cmd *cobra.Command, args []string) error {
return ssh(sshProfile, codespaceName, sshServerPort)
return ssh(context.Background(), sshProfile, codespaceName, sshServerPort)
},
}
@ -38,9 +38,12 @@ func init() {
rootCmd.AddCommand(newSSHCmd())
}
func ssh(sshProfile, codespaceName string, sshServerPort int) error {
func ssh(ctx context.Context, sshProfile, codespaceName string, localSSHServerPort int) error {
// Ensure all child tasks (e.g. port forwarding) terminate before return.
ctx, cancel := context.WithCancel(ctx)
defer cancel()
apiClient := api.New(os.Getenv("GITHUB_TOKEN"))
ctx := context.Background()
log := output.NewLogger(os.Stdout, os.Stderr, false)
user, err := apiClient.GetUser(ctx)
@ -81,7 +84,16 @@ func ssh(sshProfile, codespaceName string, sshServerPort int) error {
}
log.Print("\n")
tunnelPort, tunnelClosed, err := codespaces.MakeSSHTunnel(ctx, lsclient, sshServerPort, remoteSSHServerPort)
usingCustomPort := true
if localSSHServerPort == 0 {
usingCustomPort = false // suppress log of command line in Shell
localSSHServerPort, err = codespaces.UnusedPort()
if err != nil {
return err
}
}
tunnel, err := codespaces.NewPortForwarder(ctx, lsclient, "sshd", localSSHServerPort, remoteSSHServerPort)
if err != nil {
return fmt.Errorf("make ssh tunnel: %v", err)
}
@ -91,22 +103,27 @@ func ssh(sshProfile, codespaceName string, sshServerPort int) error {
connectDestination = fmt.Sprintf("%s@localhost", sshUser)
}
usingCustomPort := tunnelPort == sshServerPort
connClosed := codespaces.ConnectToTunnel(ctx, log, tunnelPort, connectDestination, usingCustomPort)
tunnelClosed := make(chan error)
go func() {
tunnelClosed <- tunnel.Start(ctx) // error is always non-nil
}()
shellClosed := make(chan error)
go func() {
shellClosed <- codespaces.Shell(ctx, log, localSSHServerPort, connectDestination, usingCustomPort)
}()
log.Println("Ready...")
select {
case err := <-tunnelClosed:
if err != nil {
return fmt.Errorf("tunnel closed: %v", err)
}
case err := <-connClosed:
if err != nil {
return fmt.Errorf("connection closed: %v", err)
}
}
return fmt.Errorf("tunnel closed: %v", err)
return nil
case err := <-shellClosed:
if err != nil {
return fmt.Errorf("shell closed: %v", err)
}
return nil // success
}
}
func getContainerID(ctx context.Context, logger *output.Logger, terminal *liveshare.Terminal) (string, error) {

View file

@ -4,45 +4,54 @@ import (
"context"
"errors"
"fmt"
"io"
"math/rand"
"net"
"os"
"os/exec"
"strconv"
"strings"
"time"
"github.com/github/go-liveshare"
)
func MakeSSHTunnel(ctx context.Context, lsclient *liveshare.Client, localSSHPort int, remoteSSHPort int) (int, <-chan error, error) {
tunnelClosed := make(chan error)
server, err := liveshare.NewServer(lsclient)
// UnusedPort returns the number of a local TCP port that is currently
// unbound, or an error if none was available.
//
// Use of this function carries an inherent risk of a time-of-check to
// time-of-use race against other processes.
func UnusedPort() (int, error) {
addr, err := net.ResolveTCPAddr("tcp", "localhost:0")
if err != nil {
return 0, nil, fmt.Errorf("new Live Share server: %v", err)
return 0, fmt.Errorf("internal error while choosing port: %v", err)
}
rand.Seed(time.Now().Unix())
port := rand.Intn(9999-2000) + 2000 // improve this obviously
if localSSHPort != 0 {
port = localSSHPort
l, err := net.ListenTCP("tcp", addr)
if err != nil {
return 0, fmt.Errorf("choosing available port: %v", err)
}
defer l.Close()
return l.Addr().(*net.TCPAddr).Port, nil
}
// NewPortForwarder returns a new port forwarder for traffic between
// the Live Share client and the specified local and remote ports.
//
// The session name is used (along with the port) to generate
// names for streams, and may appear in error messages.
func NewPortForwarder(ctx context.Context, client *liveshare.Client, sessionName string, localSSHPort, remoteSSHPort int) (*liveshare.PortForwarder, error) {
if localSSHPort == 0 {
return nil, fmt.Errorf("a local port must be provided")
}
server, err := liveshare.NewServer(client)
if err != nil {
return nil, fmt.Errorf("new liveshare server: %v", err)
}
if err := server.StartSharing(ctx, "sshd", remoteSSHPort); err != nil {
return 0, nil, fmt.Errorf("sharing sshd port: %v", err)
return nil, fmt.Errorf("sharing sshd port: %v", err)
}
go func() {
portForwarder := liveshare.NewPortForwarder(lsclient, server, port)
if err := portForwarder.Start(ctx); err != nil {
tunnelClosed <- fmt.Errorf("forwarding port: %v", err)
return
}
tunnelClosed <- nil
}()
return port, tunnelClosed, nil
return liveshare.NewPortForwarder(client, server, localSSHPort), nil
}
// StartSSHServer installs (if necessary) and starts the SSH in the codespace.
@ -72,72 +81,41 @@ func StartSSHServer(ctx context.Context, client *liveshare.Client, log logger) (
return portInt, sshServerStartResult.User, nil
}
func makeSSHArgs(port int, dst, cmd string) ([]string, []string) {
connArgs := []string{"-p", strconv.Itoa(port), "-o", "NoHostAuthenticationForLocalhost=yes"}
cmdArgs := append([]string{dst, "-X", "-Y", "-C"}, connArgs...) // X11, X11Trust, Compression
if cmd != "" {
cmdArgs = append(cmdArgs, cmd)
}
return cmdArgs, connArgs
}
func ConnectToTunnel(ctx context.Context, log logger, port int, destination string, usingCustomPort bool) <-chan error {
connClosed := make(chan error)
args, connArgs := makeSSHArgs(port, destination, "")
// Shell runs an interactive secure shell over an existing
// port-forwarding session. It runs until the shell is terminated
// (including by cancellation of the context).
func Shell(ctx context.Context, log logger, port int, destination string, usingCustomPort bool) error {
cmd, connArgs := newSSHCommand(ctx, port, destination, "")
if usingCustomPort {
log.Println("Connection Details: ssh " + destination + " " + strings.Join(connArgs, " "))
}
cmd := exec.CommandContext(ctx, "ssh", args...)
return cmd.Run()
}
// NewRemoteCommand returns an exec.Cmd that will securely run a shell
// command on the remote machine.
func NewRemoteCommand(ctx context.Context, tunnelPort int, destination, command string) *exec.Cmd {
cmd, _ := newSSHCommand(ctx, tunnelPort, destination, command)
return cmd
}
// newSSHCommand populates an exec.Cmd to run a command (or if blank,
// an interactive shell) over ssh.
func newSSHCommand(ctx context.Context, port int, dst, command string) (*exec.Cmd, []string) {
connArgs := []string{"-p", strconv.Itoa(port), "-o", "NoHostAuthenticationForLocalhost=yes"}
// TODO(adonovan): eliminate X11 and X11Trust flags where unneeded.
cmdArgs := append([]string{dst, "-X", "-Y", "-C"}, connArgs...) // X11, X11Trust, Compression
if command != "" {
cmdArgs = append(cmdArgs, command)
}
cmd := exec.CommandContext(ctx, "ssh", cmdArgs...)
cmd.Stdout = os.Stdout
cmd.Stdin = os.Stdin
cmd.Stderr = os.Stderr
go func() {
connClosed <- cmd.Run()
}()
return connClosed
}
type command struct {
Cmd *exec.Cmd
StdoutPipe io.ReadCloser
}
func newCommand(cmd *exec.Cmd) (*command, error) {
stdoutPipe, err := cmd.StdoutPipe()
if err != nil {
return nil, fmt.Errorf("create stdout pipe: %v", err)
}
if err := cmd.Start(); err != nil {
return nil, fmt.Errorf("cmd start: %v", err)
}
return &command{
Cmd: cmd,
StdoutPipe: stdoutPipe,
}, nil
}
func (c *command) Read(p []byte) (int, error) {
return c.StdoutPipe.Read(p)
}
func (c *command) Close() error {
if err := c.StdoutPipe.Close(); err != nil {
return fmt.Errorf("close stdout: %v", err)
}
return c.Cmd.Wait()
}
func RunCommand(ctx context.Context, tunnelPort int, destination, cmdString string) (io.ReadCloser, error) {
args, _ := makeSSHArgs(tunnelPort, destination, cmdString)
cmd := exec.CommandContext(ctx, "ssh", args...)
return newCommand(cmd)
return cmd, connArgs
}

View file

@ -1,10 +1,10 @@
package codespaces
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io/ioutil"
"strings"
"time"
@ -33,7 +33,7 @@ type PostCreateState struct {
// PollPostCreateStates watches for state changes in a codespace,
// and calls the supplied poller for each batch of state changes.
// It runs until the context is cancelled or SSH tunnel is closed.
// It runs until it encounters an error, including cancellation of the context.
func PollPostCreateStates(ctx context.Context, log logger, apiClient *api.API, user *api.User, codespace *api.Codespace, poller func([]PostCreateState)) error {
token, err := apiClient.GetCodespaceToken(ctx, user.Login, codespace.Name)
if err != nil {
@ -45,27 +45,39 @@ func PollPostCreateStates(ctx context.Context, log logger, apiClient *api.API, u
return fmt.Errorf("connect to Live Share: %v", err)
}
localSSHPort, err := UnusedPort()
if err != nil {
return err
}
remoteSSHServerPort, sshUser, err := StartSSHServer(ctx, lsclient, log)
if err != nil {
return fmt.Errorf("error getting ssh server details: %v", err)
}
tunnelPort, connClosed, err := MakeSSHTunnel(ctx, lsclient, 0, remoteSSHServerPort)
fwd, err := NewPortForwarder(ctx, lsclient, "sshd", localSSHPort, remoteSSHServerPort)
if err != nil {
return fmt.Errorf("make ssh tunnel: %v", err)
return fmt.Errorf("creating port forwarder: %v", err)
}
tunnelClosed := make(chan error, 1) // buffered to avoid sender stuckness
go func() {
tunnelClosed <- fwd.Start(ctx) // error is non-nil
}()
t := time.NewTicker(1 * time.Second)
defer t.Stop()
for {
select {
case <-ctx.Done():
return nil
case err := <-connClosed:
return fmt.Errorf("connection closed: %v", err)
return ctx.Err()
case err := <-tunnelClosed:
return fmt.Errorf("connection failed: %v", err)
case <-t.C:
states, err := getPostCreateOutput(ctx, tunnelPort, codespace, sshUser)
states, err := getPostCreateOutput(ctx, localSSHPort, codespace, sshUser)
if err != nil {
return fmt.Errorf("get post create output: %v", err)
}
@ -76,24 +88,19 @@ func PollPostCreateStates(ctx context.Context, log logger, apiClient *api.API, u
}
func getPostCreateOutput(ctx context.Context, tunnelPort int, codespace *api.Codespace, user string) ([]PostCreateState, error) {
stdout, err := RunCommand(
cmd := NewRemoteCommand(
ctx, tunnelPort, fmt.Sprintf("%s@localhost", user),
"cat /workspaces/.codespaces/shared/postCreateOutput.json",
)
if err != nil {
stdout := new(bytes.Buffer)
cmd.Stdout = stdout
if err := cmd.Run(); err != nil {
return nil, fmt.Errorf("run command: %v", err)
}
defer stdout.Close()
b, err := ioutil.ReadAll(stdout)
if err != nil {
return nil, fmt.Errorf("read output: %v", err)
}
var output struct {
Steps []PostCreateState `json:"steps"`
}
if err := json.Unmarshal(b, &output); err != nil {
if err := json.Unmarshal(stdout.Bytes(), &output); err != nil {
return nil, fmt.Errorf("unmarshal output: %v", err)
}