diff --git a/cmd/ghcs/ssh.go b/cmd/ghcs/ssh.go index 90a7f7bfc..e8e1cb671 100644 --- a/cmd/ghcs/ssh.go +++ b/cmd/ghcs/ssh.go @@ -5,7 +5,6 @@ import ( "context" "errors" "fmt" - "log" "math/rand" "os" "os/exec" @@ -177,40 +176,25 @@ func SSH(sshProfile string) error { } }() + connectDestination := sshProfile + if connectDestination == "" { + connectDestination = fmt.Sprintf("%s@localhost", getSSHUser(codespace)) + } + fmt.Println("Ready...") - if err := connect(ctx, port, sshProfile); err != nil { + if err := connect(ctx, port, connectDestination); err != nil { return fmt.Errorf("error connecting via SSH: %v", err) } return nil } -func connect(ctx context.Context, port int, sshProfile string) error { - var cmd *exec.Cmd - if sshProfile != "" { - cmd = exec.CommandContext(ctx, "ssh", sshProfile, "-p", strconv.Itoa(port), "-C") - } else { - cmd = exec.CommandContext(ctx, "ssh", "codespace@localhost", "-C", "-p", strconv.Itoa(port), "-o", "NoHostAuthenticationForLocalhost=yes") - } - +func connect(ctx context.Context, port int, destination string) error { + cmd := exec.CommandContext(ctx, "ssh", destination, "-C", "-p", strconv.Itoa(port), "-o", "NoHostAuthenticationForLocalhost=yes") cmd.Stdout = os.Stdout cmd.Stdin = os.Stdin cmd.Stderr = os.Stderr - - if err := cmd.Start(); err != nil { - return fmt.Errorf("error running ssh: %v", err) - } - - go func() { - if err := cmd.Wait(); err != nil { - log.Println(fmt.Errorf("error waiting for ssh to finish: %v", err)) - } - }() - - done := make(chan bool) - <-done - - return nil + return cmd.Run() } func getContainerID(ctx context.Context, terminal *liveshare.Terminal) (string, error) { @@ -263,3 +247,10 @@ func setupSSH(ctx context.Context, terminal *liveshare.Terminal, containerID, re return nil } + +func getSSHUser(codespace *api.Codespace) string { + if codespace.RepositoryNWO == "github/github" { + return "root" + } + return "codespace" +}