Merge pull request #99 from github/runcommand
preparatory cleanups to ssh tunnel/port forwarding code
This commit is contained in:
commit
c0fbb7e9fb
4 changed files with 147 additions and 150 deletions
|
|
@ -1,7 +1,6 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
|
|
@ -24,7 +23,7 @@ func newLogsCmd() *cobra.Command {
|
|||
if len(args) > 0 {
|
||||
codespaceName = args[0]
|
||||
}
|
||||
return logs(tail, codespaceName)
|
||||
return logs(context.Background(), tail, codespaceName)
|
||||
},
|
||||
}
|
||||
|
||||
|
|
@ -37,9 +36,12 @@ func init() {
|
|||
rootCmd.AddCommand(newLogsCmd())
|
||||
}
|
||||
|
||||
func logs(tail bool, codespaceName string) error {
|
||||
func logs(ctx context.Context, tail bool, codespaceName string) error {
|
||||
// Ensure all child tasks (port forwarding, remote exec) terminate before return.
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
|
||||
apiClient := api.New(os.Getenv("GITHUB_TOKEN"))
|
||||
ctx := context.Background()
|
||||
log := output.NewLogger(os.Stdout, os.Stderr, false)
|
||||
|
||||
user, err := apiClient.GetUser(ctx)
|
||||
|
|
@ -57,12 +59,17 @@ func logs(tail bool, codespaceName string) error {
|
|||
return fmt.Errorf("connecting to Live Share: %v", err)
|
||||
}
|
||||
|
||||
localSSHPort, err := codespaces.UnusedPort()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
remoteSSHServerPort, sshUser, err := codespaces.StartSSHServer(ctx, lsclient, log)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error getting ssh server details: %v", err)
|
||||
}
|
||||
|
||||
tunnelPort, connClosed, err := codespaces.MakeSSHTunnel(ctx, lsclient, 0, remoteSSHServerPort)
|
||||
tunnel, err := codespaces.NewPortForwarder(ctx, lsclient, "sshd", localSSHPort, remoteSSHServerPort)
|
||||
if err != nil {
|
||||
return fmt.Errorf("make ssh tunnel: %v", err)
|
||||
}
|
||||
|
|
@ -73,42 +80,30 @@ func logs(tail bool, codespaceName string) error {
|
|||
}
|
||||
|
||||
dst := fmt.Sprintf("%s@localhost", sshUser)
|
||||
stdout, err := codespaces.RunCommand(
|
||||
ctx, tunnelPort, dst, fmt.Sprintf("%v /workspaces/.codespaces/.persistedshare/creation.log", cmdType),
|
||||
cmd := codespaces.NewRemoteCommand(
|
||||
ctx, localSSHPort, dst, fmt.Sprintf("%s /workspaces/.codespaces/.persistedshare/creation.log", cmdType),
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("run command: %v", err)
|
||||
}
|
||||
|
||||
done := make(chan error)
|
||||
// Error channels are buffered so that neither sending goroutine gets stuck.
|
||||
|
||||
tunnelClosed := make(chan error, 1)
|
||||
go func() {
|
||||
scanner := bufio.NewScanner(stdout)
|
||||
for scanner.Scan() {
|
||||
fmt.Println(scanner.Text())
|
||||
}
|
||||
tunnelClosed <- tunnel.Start(ctx) // error is non-nil
|
||||
}()
|
||||
|
||||
if err := scanner.Err(); err != nil {
|
||||
done <- fmt.Errorf("error scanning: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if err := stdout.Close(); err != nil {
|
||||
done <- fmt.Errorf("close stdout: %v", err)
|
||||
return
|
||||
}
|
||||
done <- nil
|
||||
cmdDone := make(chan error, 1)
|
||||
go func() {
|
||||
cmdDone <- cmd.Run()
|
||||
}()
|
||||
|
||||
select {
|
||||
case err := <-connClosed:
|
||||
if err != nil {
|
||||
return fmt.Errorf("connection closed: %v", err)
|
||||
}
|
||||
case err := <-done:
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
case err := <-tunnelClosed:
|
||||
return fmt.Errorf("connection closed: %v", err)
|
||||
|
||||
return nil
|
||||
case err := <-cmdDone:
|
||||
if err != nil {
|
||||
return fmt.Errorf("error retrieving logs: %v", err)
|
||||
}
|
||||
return nil // success
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -23,7 +23,7 @@ func newSSHCmd() *cobra.Command {
|
|||
Short: "SSH into a Codespace",
|
||||
Args: cobra.NoArgs,
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
return ssh(sshProfile, codespaceName, sshServerPort)
|
||||
return ssh(context.Background(), sshProfile, codespaceName, sshServerPort)
|
||||
},
|
||||
}
|
||||
|
||||
|
|
@ -38,9 +38,12 @@ func init() {
|
|||
rootCmd.AddCommand(newSSHCmd())
|
||||
}
|
||||
|
||||
func ssh(sshProfile, codespaceName string, sshServerPort int) error {
|
||||
func ssh(ctx context.Context, sshProfile, codespaceName string, localSSHServerPort int) error {
|
||||
// Ensure all child tasks (e.g. port forwarding) terminate before return.
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
|
||||
apiClient := api.New(os.Getenv("GITHUB_TOKEN"))
|
||||
ctx := context.Background()
|
||||
log := output.NewLogger(os.Stdout, os.Stderr, false)
|
||||
|
||||
user, err := apiClient.GetUser(ctx)
|
||||
|
|
@ -81,7 +84,16 @@ func ssh(sshProfile, codespaceName string, sshServerPort int) error {
|
|||
}
|
||||
log.Print("\n")
|
||||
|
||||
tunnelPort, tunnelClosed, err := codespaces.MakeSSHTunnel(ctx, lsclient, sshServerPort, remoteSSHServerPort)
|
||||
usingCustomPort := true
|
||||
if localSSHServerPort == 0 {
|
||||
usingCustomPort = false // suppress log of command line in Shell
|
||||
localSSHServerPort, err = codespaces.UnusedPort()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
tunnel, err := codespaces.NewPortForwarder(ctx, lsclient, "sshd", localSSHServerPort, remoteSSHServerPort)
|
||||
if err != nil {
|
||||
return fmt.Errorf("make ssh tunnel: %v", err)
|
||||
}
|
||||
|
|
@ -91,22 +103,27 @@ func ssh(sshProfile, codespaceName string, sshServerPort int) error {
|
|||
connectDestination = fmt.Sprintf("%s@localhost", sshUser)
|
||||
}
|
||||
|
||||
usingCustomPort := tunnelPort == sshServerPort
|
||||
connClosed := codespaces.ConnectToTunnel(ctx, log, tunnelPort, connectDestination, usingCustomPort)
|
||||
tunnelClosed := make(chan error)
|
||||
go func() {
|
||||
tunnelClosed <- tunnel.Start(ctx) // error is always non-nil
|
||||
}()
|
||||
|
||||
shellClosed := make(chan error)
|
||||
go func() {
|
||||
shellClosed <- codespaces.Shell(ctx, log, localSSHServerPort, connectDestination, usingCustomPort)
|
||||
}()
|
||||
|
||||
log.Println("Ready...")
|
||||
select {
|
||||
case err := <-tunnelClosed:
|
||||
if err != nil {
|
||||
return fmt.Errorf("tunnel closed: %v", err)
|
||||
}
|
||||
case err := <-connClosed:
|
||||
if err != nil {
|
||||
return fmt.Errorf("connection closed: %v", err)
|
||||
}
|
||||
}
|
||||
return fmt.Errorf("tunnel closed: %v", err)
|
||||
|
||||
return nil
|
||||
case err := <-shellClosed:
|
||||
if err != nil {
|
||||
return fmt.Errorf("shell closed: %v", err)
|
||||
}
|
||||
return nil // success
|
||||
}
|
||||
}
|
||||
|
||||
func getContainerID(ctx context.Context, logger *output.Logger, terminal *liveshare.Terminal) (string, error) {
|
||||
|
|
|
|||
|
|
@ -4,45 +4,54 @@ import (
|
|||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"math/rand"
|
||||
"net"
|
||||
"os"
|
||||
"os/exec"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/github/go-liveshare"
|
||||
)
|
||||
|
||||
func MakeSSHTunnel(ctx context.Context, lsclient *liveshare.Client, localSSHPort int, remoteSSHPort int) (int, <-chan error, error) {
|
||||
tunnelClosed := make(chan error)
|
||||
|
||||
server, err := liveshare.NewServer(lsclient)
|
||||
// UnusedPort returns the number of a local TCP port that is currently
|
||||
// unbound, or an error if none was available.
|
||||
//
|
||||
// Use of this function carries an inherent risk of a time-of-check to
|
||||
// time-of-use race against other processes.
|
||||
func UnusedPort() (int, error) {
|
||||
addr, err := net.ResolveTCPAddr("tcp", "localhost:0")
|
||||
if err != nil {
|
||||
return 0, nil, fmt.Errorf("new Live Share server: %v", err)
|
||||
return 0, fmt.Errorf("internal error while choosing port: %v", err)
|
||||
}
|
||||
|
||||
rand.Seed(time.Now().Unix())
|
||||
port := rand.Intn(9999-2000) + 2000 // improve this obviously
|
||||
if localSSHPort != 0 {
|
||||
port = localSSHPort
|
||||
l, err := net.ListenTCP("tcp", addr)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("choosing available port: %v", err)
|
||||
}
|
||||
defer l.Close()
|
||||
return l.Addr().(*net.TCPAddr).Port, nil
|
||||
}
|
||||
|
||||
// NewPortForwarder returns a new port forwarder for traffic between
|
||||
// the Live Share client and the specified local and remote ports.
|
||||
//
|
||||
// The session name is used (along with the port) to generate
|
||||
// names for streams, and may appear in error messages.
|
||||
func NewPortForwarder(ctx context.Context, client *liveshare.Client, sessionName string, localSSHPort, remoteSSHPort int) (*liveshare.PortForwarder, error) {
|
||||
if localSSHPort == 0 {
|
||||
return nil, fmt.Errorf("a local port must be provided")
|
||||
}
|
||||
|
||||
server, err := liveshare.NewServer(client)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("new liveshare server: %v", err)
|
||||
}
|
||||
|
||||
if err := server.StartSharing(ctx, "sshd", remoteSSHPort); err != nil {
|
||||
return 0, nil, fmt.Errorf("sharing sshd port: %v", err)
|
||||
return nil, fmt.Errorf("sharing sshd port: %v", err)
|
||||
}
|
||||
|
||||
go func() {
|
||||
portForwarder := liveshare.NewPortForwarder(lsclient, server, port)
|
||||
if err := portForwarder.Start(ctx); err != nil {
|
||||
tunnelClosed <- fmt.Errorf("forwarding port: %v", err)
|
||||
return
|
||||
}
|
||||
tunnelClosed <- nil
|
||||
}()
|
||||
|
||||
return port, tunnelClosed, nil
|
||||
return liveshare.NewPortForwarder(client, server, localSSHPort), nil
|
||||
}
|
||||
|
||||
// StartSSHServer installs (if necessary) and starts the SSH in the codespace.
|
||||
|
|
@ -72,72 +81,41 @@ func StartSSHServer(ctx context.Context, client *liveshare.Client, log logger) (
|
|||
return portInt, sshServerStartResult.User, nil
|
||||
}
|
||||
|
||||
func makeSSHArgs(port int, dst, cmd string) ([]string, []string) {
|
||||
connArgs := []string{"-p", strconv.Itoa(port), "-o", "NoHostAuthenticationForLocalhost=yes"}
|
||||
cmdArgs := append([]string{dst, "-X", "-Y", "-C"}, connArgs...) // X11, X11Trust, Compression
|
||||
|
||||
if cmd != "" {
|
||||
cmdArgs = append(cmdArgs, cmd)
|
||||
}
|
||||
|
||||
return cmdArgs, connArgs
|
||||
}
|
||||
|
||||
func ConnectToTunnel(ctx context.Context, log logger, port int, destination string, usingCustomPort bool) <-chan error {
|
||||
connClosed := make(chan error)
|
||||
args, connArgs := makeSSHArgs(port, destination, "")
|
||||
// Shell runs an interactive secure shell over an existing
|
||||
// port-forwarding session. It runs until the shell is terminated
|
||||
// (including by cancellation of the context).
|
||||
func Shell(ctx context.Context, log logger, port int, destination string, usingCustomPort bool) error {
|
||||
cmd, connArgs := newSSHCommand(ctx, port, destination, "")
|
||||
|
||||
if usingCustomPort {
|
||||
log.Println("Connection Details: ssh " + destination + " " + strings.Join(connArgs, " "))
|
||||
}
|
||||
|
||||
cmd := exec.CommandContext(ctx, "ssh", args...)
|
||||
return cmd.Run()
|
||||
}
|
||||
|
||||
// NewRemoteCommand returns an exec.Cmd that will securely run a shell
|
||||
// command on the remote machine.
|
||||
func NewRemoteCommand(ctx context.Context, tunnelPort int, destination, command string) *exec.Cmd {
|
||||
cmd, _ := newSSHCommand(ctx, tunnelPort, destination, command)
|
||||
return cmd
|
||||
}
|
||||
|
||||
// newSSHCommand populates an exec.Cmd to run a command (or if blank,
|
||||
// an interactive shell) over ssh.
|
||||
func newSSHCommand(ctx context.Context, port int, dst, command string) (*exec.Cmd, []string) {
|
||||
connArgs := []string{"-p", strconv.Itoa(port), "-o", "NoHostAuthenticationForLocalhost=yes"}
|
||||
// TODO(adonovan): eliminate X11 and X11Trust flags where unneeded.
|
||||
cmdArgs := append([]string{dst, "-X", "-Y", "-C"}, connArgs...) // X11, X11Trust, Compression
|
||||
|
||||
if command != "" {
|
||||
cmdArgs = append(cmdArgs, command)
|
||||
}
|
||||
|
||||
cmd := exec.CommandContext(ctx, "ssh", cmdArgs...)
|
||||
cmd.Stdout = os.Stdout
|
||||
cmd.Stdin = os.Stdin
|
||||
cmd.Stderr = os.Stderr
|
||||
|
||||
go func() {
|
||||
connClosed <- cmd.Run()
|
||||
}()
|
||||
|
||||
return connClosed
|
||||
}
|
||||
|
||||
type command struct {
|
||||
Cmd *exec.Cmd
|
||||
StdoutPipe io.ReadCloser
|
||||
}
|
||||
|
||||
func newCommand(cmd *exec.Cmd) (*command, error) {
|
||||
stdoutPipe, err := cmd.StdoutPipe()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create stdout pipe: %v", err)
|
||||
}
|
||||
|
||||
if err := cmd.Start(); err != nil {
|
||||
return nil, fmt.Errorf("cmd start: %v", err)
|
||||
}
|
||||
|
||||
return &command{
|
||||
Cmd: cmd,
|
||||
StdoutPipe: stdoutPipe,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *command) Read(p []byte) (int, error) {
|
||||
return c.StdoutPipe.Read(p)
|
||||
}
|
||||
|
||||
func (c *command) Close() error {
|
||||
if err := c.StdoutPipe.Close(); err != nil {
|
||||
return fmt.Errorf("close stdout: %v", err)
|
||||
}
|
||||
|
||||
return c.Cmd.Wait()
|
||||
}
|
||||
|
||||
func RunCommand(ctx context.Context, tunnelPort int, destination, cmdString string) (io.ReadCloser, error) {
|
||||
args, _ := makeSSHArgs(tunnelPort, destination, cmdString)
|
||||
cmd := exec.CommandContext(ctx, "ssh", args...)
|
||||
return newCommand(cmd)
|
||||
return cmd, connArgs
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,10 +1,10 @@
|
|||
package codespaces
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
|
|
@ -33,7 +33,7 @@ type PostCreateState struct {
|
|||
|
||||
// PollPostCreateStates watches for state changes in a codespace,
|
||||
// and calls the supplied poller for each batch of state changes.
|
||||
// It runs until the context is cancelled or SSH tunnel is closed.
|
||||
// It runs until it encounters an error, including cancellation of the context.
|
||||
func PollPostCreateStates(ctx context.Context, log logger, apiClient *api.API, user *api.User, codespace *api.Codespace, poller func([]PostCreateState)) error {
|
||||
token, err := apiClient.GetCodespaceToken(ctx, user.Login, codespace.Name)
|
||||
if err != nil {
|
||||
|
|
@ -45,27 +45,39 @@ func PollPostCreateStates(ctx context.Context, log logger, apiClient *api.API, u
|
|||
return fmt.Errorf("connect to Live Share: %v", err)
|
||||
}
|
||||
|
||||
localSSHPort, err := UnusedPort()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
remoteSSHServerPort, sshUser, err := StartSSHServer(ctx, lsclient, log)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error getting ssh server details: %v", err)
|
||||
}
|
||||
|
||||
tunnelPort, connClosed, err := MakeSSHTunnel(ctx, lsclient, 0, remoteSSHServerPort)
|
||||
fwd, err := NewPortForwarder(ctx, lsclient, "sshd", localSSHPort, remoteSSHServerPort)
|
||||
if err != nil {
|
||||
return fmt.Errorf("make ssh tunnel: %v", err)
|
||||
return fmt.Errorf("creating port forwarder: %v", err)
|
||||
}
|
||||
|
||||
tunnelClosed := make(chan error, 1) // buffered to avoid sender stuckness
|
||||
go func() {
|
||||
tunnelClosed <- fwd.Start(ctx) // error is non-nil
|
||||
}()
|
||||
|
||||
t := time.NewTicker(1 * time.Second)
|
||||
defer t.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil
|
||||
case err := <-connClosed:
|
||||
return fmt.Errorf("connection closed: %v", err)
|
||||
return ctx.Err()
|
||||
|
||||
case err := <-tunnelClosed:
|
||||
return fmt.Errorf("connection failed: %v", err)
|
||||
|
||||
case <-t.C:
|
||||
states, err := getPostCreateOutput(ctx, tunnelPort, codespace, sshUser)
|
||||
states, err := getPostCreateOutput(ctx, localSSHPort, codespace, sshUser)
|
||||
if err != nil {
|
||||
return fmt.Errorf("get post create output: %v", err)
|
||||
}
|
||||
|
|
@ -76,24 +88,19 @@ func PollPostCreateStates(ctx context.Context, log logger, apiClient *api.API, u
|
|||
}
|
||||
|
||||
func getPostCreateOutput(ctx context.Context, tunnelPort int, codespace *api.Codespace, user string) ([]PostCreateState, error) {
|
||||
stdout, err := RunCommand(
|
||||
cmd := NewRemoteCommand(
|
||||
ctx, tunnelPort, fmt.Sprintf("%s@localhost", user),
|
||||
"cat /workspaces/.codespaces/shared/postCreateOutput.json",
|
||||
)
|
||||
if err != nil {
|
||||
stdout := new(bytes.Buffer)
|
||||
cmd.Stdout = stdout
|
||||
if err := cmd.Run(); err != nil {
|
||||
return nil, fmt.Errorf("run command: %v", err)
|
||||
}
|
||||
defer stdout.Close()
|
||||
|
||||
b, err := ioutil.ReadAll(stdout)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read output: %v", err)
|
||||
}
|
||||
|
||||
var output struct {
|
||||
Steps []PostCreateState `json:"steps"`
|
||||
}
|
||||
if err := json.Unmarshal(b, &output); err != nil {
|
||||
if err := json.Unmarshal(stdout.Bytes(), &output); err != nil {
|
||||
return nil, fmt.Errorf("unmarshal output: %v", err)
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue