merge upstream
This commit is contained in:
commit
efc6fd369c
7 changed files with 69 additions and 83 deletions
|
|
@ -3,6 +3,7 @@ package main
|
|||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
|
||||
"github.com/github/ghcs/api"
|
||||
|
|
@ -10,6 +11,7 @@ import (
|
|||
"github.com/github/ghcs/internal/codespaces"
|
||||
"github.com/github/go-liveshare"
|
||||
"github.com/spf13/cobra"
|
||||
"golang.org/x/sync/errgroup"
|
||||
)
|
||||
|
||||
func newLogsCmd() *cobra.Command {
|
||||
|
|
@ -71,10 +73,13 @@ func logs(ctx context.Context, log *output.Logger, codespaceName string, follow
|
|||
return fmt.Errorf("connecting to Live Share: %v", err)
|
||||
}
|
||||
|
||||
localSSHPort, err := codespaces.UnusedPort()
|
||||
// Ensure local port is listening before client (getPostCreateOutput) connects.
|
||||
listen, err := net.Listen("tcp", ":0") // arbitrary port
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer listen.Close()
|
||||
localPort := listen.Addr().(*net.TCPAddr).Port
|
||||
|
||||
remoteSSHServerPort, sshUser, err := codespaces.StartSSHServer(ctx, session, log)
|
||||
if err != nil {
|
||||
|
|
@ -88,30 +93,15 @@ func logs(ctx context.Context, log *output.Logger, codespaceName string, follow
|
|||
|
||||
dst := fmt.Sprintf("%s@localhost", sshUser)
|
||||
cmd := codespaces.NewRemoteCommand(
|
||||
ctx, localSSHPort, dst, fmt.Sprintf("%s /workspaces/.codespaces/.persistedshare/creation.log", cmdType),
|
||||
ctx, localPort, dst, fmt.Sprintf("%s /workspaces/.codespaces/.persistedshare/creation.log", cmdType),
|
||||
)
|
||||
|
||||
// Error channels are buffered so that neither sending goroutine gets stuck.
|
||||
|
||||
tunnelClosed := make(chan error, 1)
|
||||
go func() {
|
||||
group, ctx := errgroup.WithContext(ctx)
|
||||
group.Go(func() error {
|
||||
fwd := liveshare.NewPortForwarder(session, "sshd", remoteSSHServerPort)
|
||||
tunnelClosed <- fwd.ForwardToLocalPort(ctx, localSSHPort) // error is non-nil
|
||||
}()
|
||||
|
||||
cmdDone := make(chan error, 1)
|
||||
go func() {
|
||||
cmdDone <- cmd.Run()
|
||||
}()
|
||||
|
||||
select {
|
||||
case err := <-tunnelClosed:
|
||||
err := fwd.ForwardToListener(ctx, listen) // error is non-nil
|
||||
return fmt.Errorf("connection closed: %v", err)
|
||||
|
||||
case err := <-cmdDone:
|
||||
if err != nil {
|
||||
return fmt.Errorf("error retrieving logs: %v", err)
|
||||
}
|
||||
return nil // success
|
||||
}
|
||||
})
|
||||
group.Go(cmd.Run)
|
||||
return group.Wait()
|
||||
}
|
||||
|
|
|
|||
|
|
@ -19,8 +19,9 @@ func main() {
|
|||
var version = "DEV"
|
||||
|
||||
var rootCmd = &cobra.Command{
|
||||
Use: "ghcs",
|
||||
SilenceUsage: true, // don't print usage message after each error (see #80)
|
||||
Use: "ghcs",
|
||||
SilenceUsage: true, // don't print usage message after each error (see #80)
|
||||
SilenceErrors: false, // print errors automatically so that main need not
|
||||
Long: `Unofficial CLI tool to manage GitHub Codespaces.
|
||||
|
||||
Running commands requires the GITHUB_TOKEN environment variable to be set to a
|
||||
|
|
@ -43,5 +44,4 @@ func explainError(w io.Writer, err error) {
|
|||
fmt.Fprintln(w, "Make sure to enable SSO for your organizations after creating the token.")
|
||||
return
|
||||
}
|
||||
fmt.Fprintf(w, "%v\n", err)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -5,6 +5,8 @@ import (
|
|||
"io"
|
||||
)
|
||||
|
||||
// NewLogger returns a Logger that will write to the given stdout/stderr writers.
|
||||
// Disable the Logger to prevent it from writing to stdout in a TTY environment.
|
||||
func NewLogger(stdout, stderr io.Writer, disabled bool) *Logger {
|
||||
return &Logger{
|
||||
out: stdout,
|
||||
|
|
@ -13,12 +15,16 @@ func NewLogger(stdout, stderr io.Writer, disabled bool) *Logger {
|
|||
}
|
||||
}
|
||||
|
||||
// Logger writes to the given stdout/stderr writers.
|
||||
// If not enabled, Print functions will noop but Error functions will continue
|
||||
// to write to the stderr writer.
|
||||
type Logger struct {
|
||||
out io.Writer
|
||||
errout io.Writer
|
||||
enabled bool
|
||||
}
|
||||
|
||||
// Print writes the arguments to the stdout writer.
|
||||
func (l *Logger) Print(v ...interface{}) (int, error) {
|
||||
if !l.enabled {
|
||||
return 0, nil
|
||||
|
|
@ -26,6 +32,7 @@ func (l *Logger) Print(v ...interface{}) (int, error) {
|
|||
return fmt.Fprint(l.out, v...)
|
||||
}
|
||||
|
||||
// Println writes the arguments to the stdout writer with a newline at the end.
|
||||
func (l *Logger) Println(v ...interface{}) (int, error) {
|
||||
if !l.enabled {
|
||||
return 0, nil
|
||||
|
|
@ -33,6 +40,7 @@ func (l *Logger) Println(v ...interface{}) (int, error) {
|
|||
return fmt.Fprintln(l.out, v...)
|
||||
}
|
||||
|
||||
// Printf writes the formatted arguments to the stdout writer.
|
||||
func (l *Logger) Printf(f string, v ...interface{}) (int, error) {
|
||||
if !l.enabled {
|
||||
return 0, nil
|
||||
|
|
@ -40,6 +48,12 @@ func (l *Logger) Printf(f string, v ...interface{}) (int, error) {
|
|||
return fmt.Fprintf(l.out, f, v...)
|
||||
}
|
||||
|
||||
// Errorf writes the formatted arguments to the stderr writer.
|
||||
func (l *Logger) Errorf(f string, v ...interface{}) (int, error) {
|
||||
return fmt.Fprintf(l.errout, f, v...)
|
||||
}
|
||||
|
||||
// Errorln writes the arguments to the stderr writer with a newline at the end.
|
||||
func (l *Logger) Errorln(v ...interface{}) (int, error) {
|
||||
return fmt.Fprintln(l.errout, v...)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@ import (
|
|||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
|
@ -16,6 +17,7 @@ import (
|
|||
"github.com/github/go-liveshare"
|
||||
"github.com/muhammadmuzzammil1998/jsonc"
|
||||
"github.com/spf13/cobra"
|
||||
"golang.org/x/sync/errgroup"
|
||||
)
|
||||
|
||||
// portOptions represents the options accepted by the ports command.
|
||||
|
|
@ -272,20 +274,22 @@ func forwardPorts(log *output.Logger, codespaceName string, ports []string) erro
|
|||
|
||||
// Run forwarding of all ports concurrently, aborting all of
|
||||
// them at the first failure, including cancellation of the context.
|
||||
errc := make(chan error, len(portPairs))
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
group, ctx := errgroup.WithContext(ctx)
|
||||
for _, pair := range portPairs {
|
||||
pair := pair
|
||||
log.Printf("Forwarding ports: remote %d <=> local %d\n", pair.remote, pair.local)
|
||||
name := fmt.Sprintf("share-%d", pair.remote)
|
||||
go func() {
|
||||
group.Go(func() error {
|
||||
listen, err := net.Listen("tcp", fmt.Sprintf(":%d", pair.local))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer listen.Close()
|
||||
log.Printf("Forwarding ports: remote %d <=> local %d\n", pair.remote, pair.local)
|
||||
name := fmt.Sprintf("share-%d", pair.remote)
|
||||
fwd := liveshare.NewPortForwarder(session, name, pair.remote)
|
||||
errc <- fwd.ForwardToLocalPort(ctx, pair.local) // error always non-nil
|
||||
}()
|
||||
return fwd.ForwardToListener(ctx, listen) // error always non-nil
|
||||
})
|
||||
}
|
||||
|
||||
return <-errc // first error
|
||||
return group.Wait() // first error
|
||||
}
|
||||
|
||||
type portPair struct {
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@ import (
|
|||
"bufio"
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
|
|
@ -12,6 +13,7 @@ import (
|
|||
"github.com/github/ghcs/internal/codespaces"
|
||||
"github.com/github/go-liveshare"
|
||||
"github.com/spf13/cobra"
|
||||
"golang.org/x/sync/errgroup"
|
||||
)
|
||||
|
||||
func newSSHCmd() *cobra.Command {
|
||||
|
|
@ -81,42 +83,35 @@ func ssh(ctx context.Context, sshProfile, codespaceName string, localSSHServerPo
|
|||
}
|
||||
log.Print("\n")
|
||||
|
||||
usingCustomPort := true
|
||||
if localSSHServerPort == 0 {
|
||||
usingCustomPort = false // suppress log of command line in Shell
|
||||
localSSHServerPort, err = codespaces.UnusedPort()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
usingCustomPort := localSSHServerPort != 0 // suppress log of command line in Shell
|
||||
|
||||
// Ensure local port is listening before client (Shell) connects.
|
||||
listen, err := net.Listen("tcp", fmt.Sprintf(":%d", localSSHServerPort))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer listen.Close()
|
||||
localSSHServerPort = listen.Addr().(*net.TCPAddr).Port
|
||||
|
||||
connectDestination := sshProfile
|
||||
if connectDestination == "" {
|
||||
connectDestination = fmt.Sprintf("%s@localhost", sshUser)
|
||||
}
|
||||
|
||||
tunnelClosed := make(chan error)
|
||||
go func() {
|
||||
fwd := liveshare.NewPortForwarder(session, "sshd", remoteSSHServerPort)
|
||||
tunnelClosed <- fwd.ForwardToLocalPort(ctx, localSSHServerPort) // error is always non-nil
|
||||
}()
|
||||
|
||||
shellClosed := make(chan error)
|
||||
go func() {
|
||||
shellClosed <- codespaces.Shell(ctx, log, localSSHServerPort, connectDestination, usingCustomPort)
|
||||
}()
|
||||
|
||||
log.Println("Ready...")
|
||||
select {
|
||||
case err := <-tunnelClosed:
|
||||
group, ctx := errgroup.WithContext(ctx)
|
||||
group.Go(func() error {
|
||||
fwd := liveshare.NewPortForwarder(session, "sshd", remoteSSHServerPort)
|
||||
err := fwd.ForwardToListener(ctx, listen) // always non-nil
|
||||
return fmt.Errorf("tunnel closed: %v", err)
|
||||
|
||||
case err := <-shellClosed:
|
||||
if err != nil {
|
||||
})
|
||||
group.Go(func() error {
|
||||
if err := codespaces.Shell(ctx, log, localSSHServerPort, connectDestination, usingCustomPort); err != nil {
|
||||
return fmt.Errorf("shell closed: %v", err)
|
||||
}
|
||||
return nil // success
|
||||
}
|
||||
})
|
||||
return group.Wait()
|
||||
}
|
||||
|
||||
func getContainerID(ctx context.Context, logger *output.Logger, terminal *liveshare.Terminal) (string, error) {
|
||||
|
|
|
|||
|
|
@ -4,7 +4,6 @@ import (
|
|||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"os/exec"
|
||||
"strconv"
|
||||
|
|
@ -13,25 +12,6 @@ import (
|
|||
"github.com/github/go-liveshare"
|
||||
)
|
||||
|
||||
// UnusedPort returns the number of a local TCP port that is currently
|
||||
// unbound, or an error if none was available.
|
||||
//
|
||||
// Use of this function carries an inherent risk of a time-of-check to
|
||||
// time-of-use race against other processes.
|
||||
func UnusedPort() (int, error) {
|
||||
addr, err := net.ResolveTCPAddr("tcp", "localhost:0")
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("internal error while choosing port: %v", err)
|
||||
}
|
||||
|
||||
l, err := net.ListenTCP("tcp", addr)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("choosing available port: %v", err)
|
||||
}
|
||||
defer l.Close()
|
||||
return l.Addr().(*net.TCPAddr).Port, nil
|
||||
}
|
||||
|
||||
// StartSSHServer installs (if necessary) and starts the SSH in the codespace.
|
||||
// It returns the remote port where it is running, the user to log in with, or an error if something failed.
|
||||
func StartSSHServer(ctx context.Context, session *liveshare.Session, log logger) (serverPort int, user string, err error) {
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@ import (
|
|||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
|
|
@ -46,10 +47,12 @@ func PollPostCreateStates(ctx context.Context, log logger, apiClient *api.API, u
|
|||
return fmt.Errorf("connect to Live Share: %v", err)
|
||||
}
|
||||
|
||||
localSSHPort, err := UnusedPort()
|
||||
// Ensure local port is listening before client (getPostCreateOutput) connects.
|
||||
listen, err := net.Listen("tcp", ":0") // arbitrary port
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
localPort := listen.Addr().(*net.TCPAddr).Port
|
||||
|
||||
remoteSSHServerPort, sshUser, err := StartSSHServer(ctx, session, log)
|
||||
if err != nil {
|
||||
|
|
@ -59,7 +62,7 @@ func PollPostCreateStates(ctx context.Context, log logger, apiClient *api.API, u
|
|||
tunnelClosed := make(chan error, 1) // buffered to avoid sender stuckness
|
||||
go func() {
|
||||
fwd := liveshare.NewPortForwarder(session, "sshd", remoteSSHServerPort)
|
||||
tunnelClosed <- fwd.ForwardToLocalPort(ctx, localSSHPort) // error is non-nil
|
||||
tunnelClosed <- fwd.ForwardToListener(ctx, listen) // error is non-nil
|
||||
}()
|
||||
|
||||
t := time.NewTicker(1 * time.Second)
|
||||
|
|
@ -74,7 +77,7 @@ func PollPostCreateStates(ctx context.Context, log logger, apiClient *api.API, u
|
|||
return fmt.Errorf("connection failed: %v", err)
|
||||
|
||||
case <-t.C:
|
||||
states, err := getPostCreateOutput(ctx, localSSHPort, codespace, sshUser)
|
||||
states, err := getPostCreateOutput(ctx, localPort, codespace, sshUser)
|
||||
if err != nil {
|
||||
return fmt.Errorf("get post create output: %v", err)
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue