This commit is contained in:
Alan Donovan 2021-09-02 17:04:07 -04:00
parent 1162c8adff
commit 981b2545bc
5 changed files with 33 additions and 37 deletions

View file

@ -3,6 +3,7 @@ package main
import ( import (
"context" "context"
"fmt" "fmt"
"net"
"os" "os"
"github.com/github/ghcs/api" "github.com/github/ghcs/api"
@ -60,10 +61,13 @@ func logs(ctx context.Context, tail bool, codespaceName string) error {
return fmt.Errorf("connecting to Live Share: %v", err) return fmt.Errorf("connecting to Live Share: %v", err)
} }
localSSHPort, err := codespaces.UnusedPort() // Ensure local port is listening before client (getPostCreateOutput) connects.
listen, err := liveshare.Listen(0) // zero => arbitrary
if err != nil { if err != nil {
return err return err
} }
defer listen.Close()
localPort := listen.Addr().(*net.TCPAddr).Port
remoteSSHServerPort, sshUser, err := codespaces.StartSSHServer(ctx, session, log) remoteSSHServerPort, sshUser, err := codespaces.StartSSHServer(ctx, session, log)
if err != nil { if err != nil {
@ -77,7 +81,7 @@ func logs(ctx context.Context, tail bool, codespaceName string) error {
dst := fmt.Sprintf("%s@localhost", sshUser) dst := fmt.Sprintf("%s@localhost", sshUser)
cmd := codespaces.NewRemoteCommand( cmd := codespaces.NewRemoteCommand(
ctx, localSSHPort, dst, fmt.Sprintf("%s /workspaces/.codespaces/.persistedshare/creation.log", cmdType), ctx, localPort, dst, fmt.Sprintf("%s /workspaces/.codespaces/.persistedshare/creation.log", cmdType),
) )
// Error channels are buffered so that neither sending goroutine gets stuck. // Error channels are buffered so that neither sending goroutine gets stuck.
@ -85,7 +89,7 @@ func logs(ctx context.Context, tail bool, codespaceName string) error {
tunnelClosed := make(chan error, 1) tunnelClosed := make(chan error, 1)
go func() { go func() {
fwd := liveshare.NewPortForwarder(session, "sshd", remoteSSHServerPort) fwd := liveshare.NewPortForwarder(session, "sshd", remoteSSHServerPort)
tunnelClosed <- fwd.ForwardToLocalPort(ctx, localSSHPort) // error is non-nil tunnelClosed <- fwd.ForwardToLocalPort(ctx, listen) // error is non-nil
}() }()
cmdDone := make(chan error, 1) cmdDone := make(chan error, 1)

View file

@ -277,11 +277,18 @@ func forwardPorts(log *output.Logger, codespaceName string, ports []string) erro
defer cancel() defer cancel()
for _, pair := range portPairs { for _, pair := range portPairs {
pair := pair pair := pair
log.Printf("Forwarding ports: remote %d <=> local %d\n", pair.remote, pair.local)
name := fmt.Sprintf("share-%d", pair.remote)
go func() { go func() {
listen, err := liveshare.Listen(pair.local)
if err != nil {
errc <- err
return
}
defer listen.Close()
log.Printf("Forwarding ports: remote %d <=> local %d\n", pair.remote, pair.local)
name := fmt.Sprintf("share-%d", pair.remote)
fwd := liveshare.NewPortForwarder(session, name, pair.remote) fwd := liveshare.NewPortForwarder(session, name, pair.remote)
errc <- fwd.ForwardToLocalPort(ctx, pair.local) // error always non-nil errc <- fwd.ForwardToLocalPort(ctx, listen) // error always non-nil
}() }()
} }

View file

@ -4,6 +4,7 @@ import (
"bufio" "bufio"
"context" "context"
"fmt" "fmt"
"net"
"os" "os"
"strings" "strings"
@ -81,14 +82,15 @@ func ssh(ctx context.Context, sshProfile, codespaceName string, localSSHServerPo
} }
log.Print("\n") log.Print("\n")
usingCustomPort := true usingCustomPort := localSSHServerPort != 0 // suppress log of command line in Shell
if localSSHServerPort == 0 {
usingCustomPort = false // suppress log of command line in Shell // Ensure local port is listening before client (Shell) connects.
localSSHServerPort, err = codespaces.UnusedPort() listen, err := liveshare.Listen(localSSHServerPort)
if err != nil { if err != nil {
return err return err
}
} }
defer listen.Close()
localSSHServerPort = listen.Addr().(*net.TCPAddr).Port
connectDestination := sshProfile connectDestination := sshProfile
if connectDestination == "" { if connectDestination == "" {
@ -98,7 +100,7 @@ func ssh(ctx context.Context, sshProfile, codespaceName string, localSSHServerPo
tunnelClosed := make(chan error) tunnelClosed := make(chan error)
go func() { go func() {
fwd := liveshare.NewPortForwarder(session, "sshd", remoteSSHServerPort) fwd := liveshare.NewPortForwarder(session, "sshd", remoteSSHServerPort)
tunnelClosed <- fwd.ForwardToLocalPort(ctx, localSSHServerPort) // error is always non-nil tunnelClosed <- fwd.ForwardToLocalPort(ctx, listen) // error is always non-nil
}() }()
shellClosed := make(chan error) shellClosed := make(chan error)

View file

@ -4,7 +4,6 @@ import (
"context" "context"
"errors" "errors"
"fmt" "fmt"
"net"
"os" "os"
"os/exec" "os/exec"
"strconv" "strconv"
@ -13,25 +12,6 @@ import (
"github.com/github/go-liveshare" "github.com/github/go-liveshare"
) )
// UnusedPort returns the number of a local TCP port that is currently
// unbound, or an error if none was available.
//
// Use of this function carries an inherent risk of a time-of-check to
// time-of-use race against other processes.
func UnusedPort() (int, error) {
addr, err := net.ResolveTCPAddr("tcp", "localhost:0")
if err != nil {
return 0, fmt.Errorf("internal error while choosing port: %v", err)
}
l, err := net.ListenTCP("tcp", addr)
if err != nil {
return 0, fmt.Errorf("choosing available port: %v", err)
}
defer l.Close()
return l.Addr().(*net.TCPAddr).Port, nil
}
// StartSSHServer installs (if necessary) and starts the SSH in the codespace. // StartSSHServer installs (if necessary) and starts the SSH in the codespace.
// It returns the remote port where it is running, the user to log in with, or an error if something failed. // It returns the remote port where it is running, the user to log in with, or an error if something failed.
func StartSSHServer(ctx context.Context, session *liveshare.Session, log logger) (serverPort int, user string, err error) { func StartSSHServer(ctx context.Context, session *liveshare.Session, log logger) (serverPort int, user string, err error) {

View file

@ -5,6 +5,7 @@ import (
"context" "context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"net"
"strings" "strings"
"time" "time"
@ -46,10 +47,12 @@ func PollPostCreateStates(ctx context.Context, log logger, apiClient *api.API, u
return fmt.Errorf("connect to Live Share: %v", err) return fmt.Errorf("connect to Live Share: %v", err)
} }
localSSHPort, err := UnusedPort() // Ensure local port is listening before client (getPostCreateOutput) connects.
listen, err := liveshare.Listen(0)
if err != nil { if err != nil {
return err return err
} }
localPort := listen.Addr().(*net.TCPAddr).Port
remoteSSHServerPort, sshUser, err := StartSSHServer(ctx, session, log) remoteSSHServerPort, sshUser, err := StartSSHServer(ctx, session, log)
if err != nil { if err != nil {
@ -59,7 +62,7 @@ func PollPostCreateStates(ctx context.Context, log logger, apiClient *api.API, u
tunnelClosed := make(chan error, 1) // buffered to avoid sender stuckness tunnelClosed := make(chan error, 1) // buffered to avoid sender stuckness
go func() { go func() {
fwd := liveshare.NewPortForwarder(session, "sshd", remoteSSHServerPort) fwd := liveshare.NewPortForwarder(session, "sshd", remoteSSHServerPort)
tunnelClosed <- fwd.ForwardToLocalPort(ctx, localSSHPort) // error is non-nil tunnelClosed <- fwd.ForwardToLocalPort(ctx, listen) // error is non-nil
}() }()
t := time.NewTicker(1 * time.Second) t := time.NewTicker(1 * time.Second)
@ -74,7 +77,7 @@ func PollPostCreateStates(ctx context.Context, log logger, apiClient *api.API, u
return fmt.Errorf("connection failed: %v", err) return fmt.Errorf("connection failed: %v", err)
case <-t.C: case <-t.C:
states, err := getPostCreateOutput(ctx, localSSHPort, codespace, sshUser) states, err := getPostCreateOutput(ctx, localPort, codespace, sshUser)
if err != nil { if err != nil {
return fmt.Errorf("get post create output: %v", err) return fmt.Errorf("get post create output: %v", err)
} }