From 8eba57a9ed6ccf69c4944939cd587ff3e1403e70 Mon Sep 17 00:00:00 2001 From: Jose Garcia Date: Wed, 23 Jun 2021 20:00:24 -0400 Subject: [PATCH 001/290] initial commit --- api.go | 130 ++++++++++++++++++++++++++++++++++++++++++++++++ client.go | 28 +++++++++++ config.go | 56 +++++++++++++++++++++ example/main.go | 23 +++++++++ liveshare.go | 35 +++++++++++++ session.go | 44 ++++++++++++++++ ssh.go | 70 ++++++++++++++++++++++++++ 7 files changed, 386 insertions(+) create mode 100644 api.go create mode 100644 client.go create mode 100644 config.go create mode 100644 example/main.go create mode 100644 liveshare.go create mode 100644 session.go create mode 100644 ssh.go diff --git a/api.go b/api.go new file mode 100644 index 000000000..c7efb9830 --- /dev/null +++ b/api.go @@ -0,0 +1,130 @@ +package liveshare + +import ( + "encoding/json" + "fmt" + "io/ioutil" + "net/http" + "strings" +) + +type API struct { + Configuration *Configuration + HttpClient *http.Client + ServiceURI string + WorkspaceID string +} + +func NewAPI(configuration *Configuration) *API { + serviceURI := configuration.LiveShareEndpoint + if !strings.HasSuffix(configuration.LiveShareEndpoint, "/") { + serviceURI = configuration.LiveShareEndpoint + "/" + } + + if !strings.Contains(serviceURI, "api/v1.2") { + serviceURI = serviceURI + "api/v1.2" + } + + serviceURI = strings.TrimSuffix(serviceURI, "/") + + return &API{configuration, &http.Client{}, serviceURI, strings.ToUpper(configuration.WorkspaceID)} +} + +type WorkspaceAccessResponse struct { + SessionToken string `json:"sessionToken"` + CreatedAt string `json:"createdAt"` + UpdatedAt string `json:"updatedAt"` + Name string `json:"name"` + OwnerID string `json:"ownerId"` + JoinLink string `json:"joinLink"` + ConnectLinks []string `json:"connectLinks"` + RelayLink string `json:"relayLink"` + RelaySas string `json:"relaySas"` + HostPublicKeys []string `json:"hostPublicKeys"` + ConversationID string `json:"conversationId"` + AssociatedUserIDs []string `json:"associatedUserIds"` + AreAnonymousGuestsAllowed bool `json:"areAnonymousGuestsAllowed"` + IsHostConnected bool `json:"isHostConnected"` + ExpiresAt string `json:"expiresAt"` + InvitationLinks []string `json:"invitationLinks"` + ID string `json:"id"` +} + +func (a *API) WorkspaceAccess() (*WorkspaceAccessResponse, error) { + url := fmt.Sprintf("%s/workspace/%s/user", a.ServiceURI, a.WorkspaceID) + + req, err := http.NewRequest(http.MethodPut, url, nil) + if err != nil { + return nil, fmt.Errorf("error creating request: %v", err) + } + + a.setDefaultHeaders(req) + resp, err := a.HttpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("error making request: %v", err) + } + + b, err := ioutil.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("error reading response body: %v", err) + } + + var workspaceAccessResponse WorkspaceAccessResponse + if err := json.Unmarshal(b, &workspaceAccessResponse); err != nil { + return nil, fmt.Errorf("error unmarshaling response into json: %v", err) + } + + return &workspaceAccessResponse, nil +} + +func (a *API) setDefaultHeaders(req *http.Request) { + req.Header.Set("Authorization", "Bearer "+a.Configuration.Token) + req.Header.Set("Cache-Control", "no-cache") + req.Header.Set("Content-Type", "application/json") +} + +type WorkspaceInfoResponse struct { + CreatedAt string `json:"createdAt"` + UpdatedAt string `json:"updatedAt"` + Name string `json:"name"` + OwnerID string `json:"ownerId"` + JoinLink string `json:"joinLink"` + ConnectLinks []string `json:"connectLinks"` + RelayLink string `json:"relayLink"` + RelaySas string `json:"relaySas"` + HostPublicKeys []string `json:"hostPublicKeys"` + ConversationID string `json:"conversationId"` + AssociatedUserIDs []string `json:"associatedUserIds"` + AreAnonymousGuestsAllowed bool `json:"areAnonymousGuestsAllowed"` + IsHostConnected bool `json:"isHostConnected"` + ExpiresAt string `json:"expiresAt"` + InvitationLinks []string `json:"invitationLinks"` + ID string `json:"id"` +} + +func (a *API) WorkspaceInfo() (*WorkspaceInfoResponse, error) { + url := fmt.Sprintf("%s/workspace/%s", a.ServiceURI, a.WorkspaceID) + + req, err := http.NewRequest(http.MethodGet, url, nil) + if err != nil { + return nil, fmt.Errorf("error creating request: %v", err) + } + + a.setDefaultHeaders(req) + resp, err := a.HttpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("error making request: %v", err) + } + + b, err := ioutil.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("error reading response body: %v", err) + } + + var workspaceInfoResponse WorkspaceInfoResponse + if err := json.Unmarshal(b, &workspaceInfoResponse); err != nil { + return nil, fmt.Errorf("error unmarshaling response into json: %v", err) + } + + return &workspaceInfoResponse, nil +} diff --git a/client.go b/client.go new file mode 100644 index 000000000..a58a34c9b --- /dev/null +++ b/client.go @@ -0,0 +1,28 @@ +package liveshare + +import ( + "context" + "fmt" +) + +type Client struct { + Configuration *Configuration +} + +func NewClient(configuration *Configuration) *Client { + return &Client{configuration} +} + +func (c *Client) Join(ctx context.Context) error { + session, err := GetSession(ctx, c.Configuration) + if err != nil { + return fmt.Errorf("error getting session: %v", err) + } + + sshSession := NewSSHSession(session) + if err := sshSession.Connect(); err != nil { + return fmt.Errorf("error authenticating ssh session: %v", err) + } + + return nil +} diff --git a/config.go b/config.go new file mode 100644 index 000000000..74eb5b178 --- /dev/null +++ b/config.go @@ -0,0 +1,56 @@ +package liveshare + +import ( + "errors" + "strings" +) + +type Option func(configuration *Configuration) error + +func WithWorkspaceID(id string) Option { + return func(configuration *Configuration) error { + configuration.WorkspaceID = id + return nil + } +} + +func WithLiveShareEndpoint(liveShareEndpoint string) Option { + return func(configuration *Configuration) error { + configuration.LiveShareEndpoint = liveShareEndpoint + return nil + } +} + +func WithToken(token string) Option { + return func(configuration *Configuration) error { + configuration.Token = token + return nil + } +} + +type Configuration struct { + WorkspaceID, LiveShareEndpoint, Token string +} + +func NewConfiguration() *Configuration { + return &Configuration{ + LiveShareEndpoint: "https://prod.liveshare.vsengsaas.visualstudio.com", + } +} + +func (c *Configuration) Validate() error { + errs := []string{} + if c.WorkspaceID == "" { + errs = append(errs, "WorkspaceID is required") + } + + if c.Token == "" { + errs = append(errs, "Token is required") + } + + if len(errs) > 0 { + return errors.New(strings.Join(errs, ", ")) + } + + return nil +} diff --git a/example/main.go b/example/main.go new file mode 100644 index 000000000..79ab53377 --- /dev/null +++ b/example/main.go @@ -0,0 +1,23 @@ +package main + +import ( + "context" + "fmt" + "log" + + "github.com/josebalius/go-liveshare" +) + +func main() { + liveShare, err := liveshare.New( + liveshare.WithWorkspaceID("..."), + liveshare.WithToken("..."), + ) + if err != nil { + log.Fatal(fmt.Errorf("error creating liveshare: %v", err)) + } + + if err := liveShare.Connect(context.Background()); err != nil { + log.Fatal(fmt.Errorf("error connecting to liveshare: %v", err)) + } +} diff --git a/liveshare.go b/liveshare.go new file mode 100644 index 000000000..a8c8b69d6 --- /dev/null +++ b/liveshare.go @@ -0,0 +1,35 @@ +package liveshare + +import ( + "context" + "fmt" +) + +type LiveShare struct { + Configuration *Configuration +} + +func New(opts ...Option) (*LiveShare, error) { + configuration := NewConfiguration() + + for _, o := range opts { + if err := o(configuration); err != nil { + return nil, fmt.Errorf("error configuring liveshare: %v", err) + } + } + + if err := configuration.Validate(); err != nil { + return nil, fmt.Errorf("error validating configuration: %v", err) + } + + return &LiveShare{configuration}, nil +} + +func (l *LiveShare) Connect(ctx context.Context) error { + workspaceClient := NewClient(l.Configuration) + if err := workspaceClient.Join(ctx); err != nil { + return fmt.Errorf("error joining with workspace client: %v", err) + } + + return nil +} diff --git a/session.go b/session.go new file mode 100644 index 000000000..24a284ef2 --- /dev/null +++ b/session.go @@ -0,0 +1,44 @@ +package liveshare + +import ( + "context" + "fmt" + + "golang.org/x/sync/errgroup" +) + +type Session struct { + WorkspaceAccess *WorkspaceAccessResponse + WorkspaceInfo *WorkspaceInfoResponse +} + +func GetSession(ctx context.Context, configuration *Configuration) (*Session, error) { + api := NewAPI(configuration) + session := new(Session) + + g, ctx := errgroup.WithContext(ctx) + + g.Go(func() error { + workspaceAccess, err := api.WorkspaceAccess() + if err != nil { + return fmt.Errorf("error getting workspace access: %v", err) + } + session.WorkspaceAccess = workspaceAccess + return nil + }) + + g.Go(func() error { + workspaceInfo, err := api.WorkspaceInfo() + if err != nil { + return fmt.Errorf("error getting workspace info: %v", err) + } + session.WorkspaceInfo = workspaceInfo + return nil + }) + + if err := g.Wait(); err != nil { + return nil, err + } + + return session, nil +} diff --git a/ssh.go b/ssh.go new file mode 100644 index 000000000..af132b331 --- /dev/null +++ b/ssh.go @@ -0,0 +1,70 @@ +package liveshare + +import ( + "fmt" + "net" + "net/url" + "strings" + + "golang.org/x/crypto/ssh" + "golang.org/x/net/websocket" +) + +type SSHSession struct { + Session *Session + VersionExchangeError chan error +} + +func NewSSHSession(session *Session) *SSHSession { + return &SSHSession{ + Session: session, + } +} + +func (s *SSHSession) Connect() error { + socketStream, err := s.socketStream() + if err != nil { + return fmt.Errorf("error creating socket stream: %v", err) + } + + clientConfig := ssh.ClientConfig{ + User: "", + Auth: []ssh.AuthMethod{ + ssh.Password(s.Session.WorkspaceAccess.SessionToken), + }, + HostKeyCallback: func(hostname string, remote net.Addr, key ssh.PublicKey) error { + // TODO(josebalius): implement + return nil + }, + } + + sshClientConn, chans, reqs, err := ssh.NewClientConn(socketStream, "", &clientConfig) + if err != nil { + return fmt.Errorf("error creating ssh client connection: %v", err) + } + + fmt.Println(sshClientConn, chans, reqs) + + return nil +} + +// Reference: +// https://github.com/Azure/azure-relay-node/blob/7b57225365df3010163bf4b9e640868a02737eb6/hyco-ws/index.js#L107-L137 +func (s *SSHSession) relayURI(action string) string { + relaySas := url.QueryEscape(s.Session.WorkspaceAccess.RelaySas) + relayURI := s.Session.WorkspaceAccess.RelayLink + relayURI = strings.Replace(relayURI, "sb:", "wss:", -1) + relayURI = strings.Replace(relayURI, ".net/", ".net:443/$hc/", 1) + relayURI = relayURI + "?sb-hc-action=" + action + "&sb-hc-token=" + relaySas + return relayURI +} + +func (s *SSHSession) socketStream() (*websocket.Conn, error) { + uri := s.relayURI("connect") + ws, err := websocket.Dial(uri, "", uri) + if err != nil { + return nil, fmt.Errorf("error dialing relay connection: %v", err) + } + + return ws, nil +} From a8b1b87f7b33ab3d37fc3aed276d356d61631367 Mon Sep 17 00:00:00 2001 From: Jose Garcia Date: Thu, 24 Jun 2021 20:44:16 -0400 Subject: [PATCH 002/290] Start of RPC implementation, need to figure out format --- example/main.go | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/example/main.go b/example/main.go index 79ab53377..bbf540d7c 100644 --- a/example/main.go +++ b/example/main.go @@ -10,8 +10,8 @@ import ( func main() { liveShare, err := liveshare.New( - liveshare.WithWorkspaceID("..."), - liveshare.WithToken("..."), + liveshare.WithWorkspaceID(""), + liveshare.WithToken(""), ) if err != nil { log.Fatal(fmt.Errorf("error creating liveshare: %v", err)) @@ -20,4 +20,17 @@ func main() { if err := liveShare.Connect(context.Background()); err != nil { log.Fatal(fmt.Errorf("error connecting to liveshare: %v", err)) } + + terminal := liveShare.NewTerminal() + + cmd := terminal.NewCommand( + "/home/codespace/workspace", + "docker ps -aq --filter label=Type=codespaces --filter status=running", + ) + output, err := cmd.Run(context.Background()) + if err != nil { + log.Fatal(fmt.Errorf("error starting ssh server with liveshare: %v", err)) + } + + fmt.Println(string(output)) } From 897ab1598b3d2cd9dd08c3310f0938f5c2ce4264 Mon Sep 17 00:00:00 2001 From: Jose Garcia Date: Thu, 24 Jun 2021 20:45:03 -0400 Subject: [PATCH 003/290] RPC functionality started take two --- adapter.go | 100 +++++++++++++++++++++++++++++++++++++++++++++++++++ api.go | 40 +++++++++++---------- client.go | 9 ++--- liveshare.go | 63 ++++++++++++++++++++++++++++++-- ssh.go | 91 ++++++++++++++++++++++++++++++---------------- 5 files changed, 248 insertions(+), 55 deletions(-) create mode 100644 adapter.go diff --git a/adapter.go b/adapter.go new file mode 100644 index 000000000..fb3424734 --- /dev/null +++ b/adapter.go @@ -0,0 +1,100 @@ +package liveshare + +import ( + "errors" + "io" + "net" + "sync" + "time" + + "github.com/gorilla/websocket" +) + +type Adapter struct { + conn *websocket.Conn + readMutex sync.Mutex + writeMutex sync.Mutex + reader io.Reader +} + +func NewAdapter(conn *websocket.Conn) *Adapter { + return &Adapter{ + conn: conn, + } +} + +func (a *Adapter) Read(b []byte) (int, error) { + // Read() can be called concurrently, and we mutate some internal state here + a.readMutex.Lock() + defer a.readMutex.Unlock() + + if a.reader == nil { + messageType, reader, err := a.conn.NextReader() + if err != nil { + return 0, err + } + + if messageType != websocket.BinaryMessage { + return 0, errors.New("unexpected websocket message type") + } + + a.reader = reader + } + + bytesRead, err := a.reader.Read(b) + if err != nil { + a.reader = nil + + // EOF for the current Websocket frame, more will probably come so.. + if err == io.EOF { + // .. we must hide this from the caller since our semantics are a + // stream of bytes across many frames + err = nil + } + } + + return bytesRead, err +} + +func (a *Adapter) Write(b []byte) (int, error) { + a.writeMutex.Lock() + defer a.writeMutex.Unlock() + + nextWriter, err := a.conn.NextWriter(websocket.BinaryMessage) + if err != nil { + return 0, err + } + + bytesWritten, err := nextWriter.Write(b) + nextWriter.Close() + + return bytesWritten, err +} + +func (a *Adapter) Close() error { + return a.conn.Close() +} + +func (a *Adapter) LocalAddr() net.Addr { + return a.conn.LocalAddr() +} + +func (a *Adapter) RemoteAddr() net.Addr { + return a.conn.RemoteAddr() +} + +func (a *Adapter) SetDeadline(t time.Time) error { + if err := a.SetReadDeadline(t); err != nil { + return err + } + + return a.SetWriteDeadline(t) +} + +func (a *Adapter) SetReadDeadline(t time.Time) error { + return a.conn.SetReadDeadline(t) +} + +func (a *Adapter) SetWriteDeadline(t time.Time) error { + return a.conn.SetWriteDeadline(t) +} diff --git a/api.go b/api.go index c7efb9830..d8a4bebbd 100644 --- a/api.go +++ b/api.go @@ -5,6 +5,7 @@ import ( "fmt" "io/ioutil" "net/http" + "net/http/httputil" "strings" ) @@ -31,27 +32,28 @@ func NewAPI(configuration *Configuration) *API { } type WorkspaceAccessResponse struct { - SessionToken string `json:"sessionToken"` - CreatedAt string `json:"createdAt"` - UpdatedAt string `json:"updatedAt"` - Name string `json:"name"` - OwnerID string `json:"ownerId"` - JoinLink string `json:"joinLink"` - ConnectLinks []string `json:"connectLinks"` - RelayLink string `json:"relayLink"` - RelaySas string `json:"relaySas"` - HostPublicKeys []string `json:"hostPublicKeys"` - ConversationID string `json:"conversationId"` - AssociatedUserIDs []string `json:"associatedUserIds"` - AreAnonymousGuestsAllowed bool `json:"areAnonymousGuestsAllowed"` - IsHostConnected bool `json:"isHostConnected"` - ExpiresAt string `json:"expiresAt"` - InvitationLinks []string `json:"invitationLinks"` - ID string `json:"id"` + SessionToken string `json:"sessionToken"` + CreatedAt string `json:"createdAt"` + UpdatedAt string `json:"updatedAt"` + Name string `json:"name"` + OwnerID string `json:"ownerId"` + JoinLink string `json:"joinLink"` + ConnectLinks []string `json:"connectLinks"` + RelayLink string `json:"relayLink"` + RelaySas string `json:"relaySas"` + HostPublicKeys []string `json:"hostPublicKeys"` + ConversationID string `json:"conversationId"` + AssociatedUserIDs map[string]string `json:"associatedUserIds"` + AreAnonymousGuestsAllowed bool `json:"areAnonymousGuestsAllowed"` + IsHostConnected bool `json:"isHostConnected"` + ExpiresAt string `json:"expiresAt"` + InvitationLinks []string `json:"invitationLinks"` + ID string `json:"id"` } func (a *API) WorkspaceAccess() (*WorkspaceAccessResponse, error) { url := fmt.Sprintf("%s/workspace/%s/user", a.ServiceURI, a.WorkspaceID) + fmt.Println(url) req, err := http.NewRequest(http.MethodPut, url, nil) if err != nil { @@ -69,6 +71,8 @@ func (a *API) WorkspaceAccess() (*WorkspaceAccessResponse, error) { return nil, fmt.Errorf("error reading response body: %v", err) } + d, _ := httputil.DumpResponse(resp, true) + fmt.Println(string(d)) var workspaceAccessResponse WorkspaceAccessResponse if err := json.Unmarshal(b, &workspaceAccessResponse); err != nil { return nil, fmt.Errorf("error unmarshaling response into json: %v", err) @@ -94,7 +98,7 @@ type WorkspaceInfoResponse struct { RelaySas string `json:"relaySas"` HostPublicKeys []string `json:"hostPublicKeys"` ConversationID string `json:"conversationId"` - AssociatedUserIDs []string `json:"associatedUserIds"` + AssociatedUserIDs map[string]string AreAnonymousGuestsAllowed bool `json:"areAnonymousGuestsAllowed"` IsHostConnected bool `json:"isHostConnected"` ExpiresAt string `json:"expiresAt"` diff --git a/client.go b/client.go index a58a34c9b..0a89a125c 100644 --- a/client.go +++ b/client.go @@ -7,10 +7,11 @@ import ( type Client struct { Configuration *Configuration + SSHSession *SSHSession } func NewClient(configuration *Configuration) *Client { - return &Client{configuration} + return &Client{Configuration: configuration} } func (c *Client) Join(ctx context.Context) error { @@ -19,9 +20,9 @@ func (c *Client) Join(ctx context.Context) error { return fmt.Errorf("error getting session: %v", err) } - sshSession := NewSSHSession(session) - if err := sshSession.Connect(); err != nil { - return fmt.Errorf("error authenticating ssh session: %v", err) + c.SSHSession, err = NewSSH(session).NewSession() + if err != nil { + return fmt.Errorf("error connecting to ssh session: %v", err) } return nil diff --git a/liveshare.go b/liveshare.go index a8c8b69d6..174eac20f 100644 --- a/liveshare.go +++ b/liveshare.go @@ -3,10 +3,14 @@ package liveshare import ( "context" "fmt" + "net/rpc" ) type LiveShare struct { Configuration *Configuration + + workspaceClient *Client + terminal *Terminal } func New(opts ...Option) (*LiveShare, error) { @@ -22,14 +26,67 @@ func New(opts ...Option) (*LiveShare, error) { return nil, fmt.Errorf("error validating configuration: %v", err) } - return &LiveShare{configuration}, nil + return &LiveShare{Configuration: configuration}, nil } func (l *LiveShare) Connect(ctx context.Context) error { - workspaceClient := NewClient(l.Configuration) - if err := workspaceClient.Join(ctx); err != nil { + l.workspaceClient = NewClient(l.Configuration) + if err := l.workspaceClient.Join(ctx); err != nil { return fmt.Errorf("error joining with workspace client: %v", err) } return nil } + +type Terminal struct { + WorkspaceClient *Client + RPCClient *rpc.Client +} + +func (l *LiveShare) NewTerminal() *Terminal { + return &Terminal{ + WorkspaceClient: l.workspaceClient, + RPCClient: rpc.NewClient(l.workspaceClient.SSHSession), + } +} + +type TerminalCommand struct { + Terminal *Terminal + Cwd string + Cmd string +} + +func (t *Terminal) NewCommand(cwd, cmd string) TerminalCommand { + return TerminalCommand{t, cwd, cmd} +} + +type RunArgs struct { + Name string + Rows, Cols int + App string + Cwd string + CommandLine []string + ReadOnlyForGuests bool +} + +func (t TerminalCommand) Run(ctx context.Context) ([]byte, error) { + args := RunArgs{ + Name: "RunCommand", + Rows: 10, + Cols: 80, + App: "/bin/bash", + Cwd: t.Cwd, + CommandLine: []string{"-c", t.Cmd}, + ReadOnlyForGuests: false, + } + + var output []byte + runCall := t.Terminal.RPCClient.Go("terminal.startAsync", &args, &output, nil) + + runReply := <-runCall.Done + if runReply.Error != nil { + return nil, fmt.Errorf("error startAsync operation: %v", runReply.Error) + } + fmt.Printf("%+v\n\n", runReply) + return output, nil +} diff --git a/ssh.go b/ssh.go index af132b331..51906290e 100644 --- a/ssh.go +++ b/ssh.go @@ -2,29 +2,66 @@ package liveshare import ( "fmt" + "io" "net" "net/url" "strings" + "time" + "github.com/gorilla/websocket" "golang.org/x/crypto/ssh" - "golang.org/x/net/websocket" ) -type SSHSession struct { - Session *Session - VersionExchangeError chan error +type SSH struct { + Session *Session } -func NewSSHSession(session *Session) *SSHSession { - return &SSHSession{ +func NewSSH(session *Session) *SSH { + return &SSH{ Session: session, } } -func (s *SSHSession) Connect() error { +// Reference: +// https://github.com/Azure/azure-relay-node/blob/7b57225365df3010163bf4b9e640868a02737eb6/hyco-ws/index.js#L107-L137 +func (s *SSH) relayURI(action string) string { + relaySas := url.QueryEscape(s.Session.WorkspaceAccess.RelaySas) + relayURI := s.Session.WorkspaceAccess.RelayLink + relayURI = strings.Replace(relayURI, "sb:", "wss:", -1) + relayURI = strings.Replace(relayURI, ".net/", ".net:443/$hc/", 1) + relayURI = relayURI + "?sb-hc-action=" + action + "&sb-hc-token=" + relaySas + return relayURI +} + +func (s *SSH) socketStream() (net.Conn, error) { + uri := s.relayURI("connect") + + ws, _, err := websocket.DefaultDialer.Dial(uri, nil) + if err != nil { + return nil, fmt.Errorf("error dialing websocket connection: %v", err) + } + + return NewAdapter(ws), nil +} + +type SSHSession struct { + *ssh.Session + reader io.Reader + writer io.Writer +} + +func (s SSHSession) Read(p []byte) (n int, err error) { + return s.reader.Read(p) +} + +func (s SSHSession) Write(p []byte) (n int, err error) { + return s.writer.Write(p) +} + +func (s *SSH) NewSession() (*SSHSession, error) { socketStream, err := s.socketStream() if err != nil { - return fmt.Errorf("error creating socket stream: %v", err) + return nil, fmt.Errorf("error creating socket stream: %v", err) } clientConfig := ssh.ClientConfig{ @@ -36,35 +73,29 @@ func (s *SSHSession) Connect() error { // TODO(josebalius): implement return nil }, + Timeout: 10 * time.Second, } sshClientConn, chans, reqs, err := ssh.NewClientConn(socketStream, "", &clientConfig) if err != nil { - return fmt.Errorf("error creating ssh client connection: %v", err) + return nil, fmt.Errorf("error creating ssh client connection: %v", err) } - fmt.Println(sshClientConn, chans, reqs) - - return nil -} - -// Reference: -// https://github.com/Azure/azure-relay-node/blob/7b57225365df3010163bf4b9e640868a02737eb6/hyco-ws/index.js#L107-L137 -func (s *SSHSession) relayURI(action string) string { - relaySas := url.QueryEscape(s.Session.WorkspaceAccess.RelaySas) - relayURI := s.Session.WorkspaceAccess.RelayLink - relayURI = strings.Replace(relayURI, "sb:", "wss:", -1) - relayURI = strings.Replace(relayURI, ".net/", ".net:443/$hc/", 1) - relayURI = relayURI + "?sb-hc-action=" + action + "&sb-hc-token=" + relaySas - return relayURI -} - -func (s *SSHSession) socketStream() (*websocket.Conn, error) { - uri := s.relayURI("connect") - ws, err := websocket.Dial(uri, "", uri) + sshClient := ssh.NewClient(sshClientConn, chans, reqs) + sshSession, err := sshClient.NewSession() if err != nil { - return nil, fmt.Errorf("error dialing relay connection: %v", err) + return nil, fmt.Errorf("error creating ssh client session: %v", err) } - return ws, nil + reader, err := sshSession.StdoutPipe() + if err != nil { + return nil, fmt.Errorf("error creating ssh session reader: %v", err) + } + + writer, err := sshSession.StdinPipe() + if err != nil { + return nil, fmt.Errorf("error creating ssh session writer: %v", err) + } + + return &SSHSession{Session: sshSession, reader: reader, writer: writer}, nil } From 6cd0aa7a90c9737a8c9c0f113a301ebaf1264e5e Mon Sep 17 00:00:00 2001 From: Jose Garcia Date: Tue, 6 Jul 2021 09:12:43 -0400 Subject: [PATCH 004/290] Working albeit not imperfect implementation --- adapter.go | 100 --------------------------------------- api.go | 55 +++++++++++----------- client.go | 115 +++++++++++++++++++++++++++++++++++++++++---- example/main.go | 109 +++++++++++++++++++++++++++++++++++++------ liveshare.go | 67 -------------------------- port_forwarder.go | 63 +++++++++++++++++++++++++ rpc.go | 97 ++++++++++++++++++++++++++++++++++++++ rpc_test.go | 27 +++++++++++ server.go | 52 +++++++++++++++++++++ session.go | 40 +++++++++++----- ssh.go | 97 +++++++++++++------------------------- terminal.go | 116 ++++++++++++++++++++++++++++++++++++++++++++++ websocket.go | 105 +++++++++++++++++++++++++++++++++++++++++ 13 files changed, 746 insertions(+), 297 deletions(-) delete mode 100644 adapter.go create mode 100644 port_forwarder.go create mode 100644 rpc.go create mode 100644 rpc_test.go create mode 100644 server.go create mode 100644 terminal.go create mode 100644 websocket.go diff --git a/adapter.go b/adapter.go deleted file mode 100644 index fb3424734..000000000 --- a/adapter.go +++ /dev/null @@ -1,100 +0,0 @@ -package liveshare - -import ( - "errors" - "io" - "net" - "sync" - "time" - - "github.com/gorilla/websocket" -) - -type Adapter struct { - conn *websocket.Conn - readMutex sync.Mutex - writeMutex sync.Mutex - reader io.Reader -} - -func NewAdapter(conn *websocket.Conn) *Adapter { - return &Adapter{ - conn: conn, - } -} - -func (a *Adapter) Read(b []byte) (int, error) { - // Read() can be called concurrently, and we mutate some internal state here - a.readMutex.Lock() - defer a.readMutex.Unlock() - - if a.reader == nil { - messageType, reader, err := a.conn.NextReader() - if err != nil { - return 0, err - } - - if messageType != websocket.BinaryMessage { - return 0, errors.New("unexpected websocket message type") - } - - a.reader = reader - } - - bytesRead, err := a.reader.Read(b) - if err != nil { - a.reader = nil - - // EOF for the current Websocket frame, more will probably come so.. - if err == io.EOF { - // .. we must hide this from the caller since our semantics are a - // stream of bytes across many frames - err = nil - } - } - - return bytesRead, err -} - -func (a *Adapter) Write(b []byte) (int, error) { - a.writeMutex.Lock() - defer a.writeMutex.Unlock() - - nextWriter, err := a.conn.NextWriter(websocket.BinaryMessage) - if err != nil { - return 0, err - } - - bytesWritten, err := nextWriter.Write(b) - nextWriter.Close() - - return bytesWritten, err -} - -func (a *Adapter) Close() error { - return a.conn.Close() -} - -func (a *Adapter) LocalAddr() net.Addr { - return a.conn.LocalAddr() -} - -func (a *Adapter) RemoteAddr() net.Addr { - return a.conn.RemoteAddr() -} - -func (a *Adapter) SetDeadline(t time.Time) error { - if err := a.SetReadDeadline(t); err != nil { - return err - } - - return a.SetWriteDeadline(t) -} - -func (a *Adapter) SetReadDeadline(t time.Time) error { - return a.conn.SetReadDeadline(t) -} - -func (a *Adapter) SetWriteDeadline(t time.Time) error { - return a.conn.SetWriteDeadline(t) -} diff --git a/api.go b/api.go index d8a4bebbd..a101823df 100644 --- a/api.go +++ b/api.go @@ -5,21 +5,20 @@ import ( "fmt" "io/ioutil" "net/http" - "net/http/httputil" "strings" ) -type API struct { - Configuration *Configuration - HttpClient *http.Client - ServiceURI string - WorkspaceID string +type api struct { + client *Client + httpClient *http.Client + serviceURI string + workspaceID string } -func NewAPI(configuration *Configuration) *API { - serviceURI := configuration.LiveShareEndpoint - if !strings.HasSuffix(configuration.LiveShareEndpoint, "/") { - serviceURI = configuration.LiveShareEndpoint + "/" +func newAPI(client *Client) *api { + serviceURI := client.liveShare.Configuration.LiveShareEndpoint + if !strings.HasSuffix(client.liveShare.Configuration.LiveShareEndpoint, "/") { + serviceURI = client.liveShare.Configuration.LiveShareEndpoint + "/" } if !strings.Contains(serviceURI, "api/v1.2") { @@ -28,10 +27,10 @@ func NewAPI(configuration *Configuration) *API { serviceURI = strings.TrimSuffix(serviceURI, "/") - return &API{configuration, &http.Client{}, serviceURI, strings.ToUpper(configuration.WorkspaceID)} + return &api{client, &http.Client{}, serviceURI, strings.ToUpper(client.liveShare.Configuration.WorkspaceID)} } -type WorkspaceAccessResponse struct { +type workspaceAccessResponse struct { SessionToken string `json:"sessionToken"` CreatedAt string `json:"createdAt"` UpdatedAt string `json:"updatedAt"` @@ -51,8 +50,8 @@ type WorkspaceAccessResponse struct { ID string `json:"id"` } -func (a *API) WorkspaceAccess() (*WorkspaceAccessResponse, error) { - url := fmt.Sprintf("%s/workspace/%s/user", a.ServiceURI, a.WorkspaceID) +func (a *api) workspaceAccess() (*workspaceAccessResponse, error) { + url := fmt.Sprintf("%s/workspace/%s/user", a.serviceURI, a.workspaceID) fmt.Println(url) req, err := http.NewRequest(http.MethodPut, url, nil) @@ -61,7 +60,7 @@ func (a *API) WorkspaceAccess() (*WorkspaceAccessResponse, error) { } a.setDefaultHeaders(req) - resp, err := a.HttpClient.Do(req) + resp, err := a.httpClient.Do(req) if err != nil { return nil, fmt.Errorf("error making request: %v", err) } @@ -71,23 +70,21 @@ func (a *API) WorkspaceAccess() (*WorkspaceAccessResponse, error) { return nil, fmt.Errorf("error reading response body: %v", err) } - d, _ := httputil.DumpResponse(resp, true) - fmt.Println(string(d)) - var workspaceAccessResponse WorkspaceAccessResponse - if err := json.Unmarshal(b, &workspaceAccessResponse); err != nil { + var response workspaceAccessResponse + if err := json.Unmarshal(b, &response); err != nil { return nil, fmt.Errorf("error unmarshaling response into json: %v", err) } - return &workspaceAccessResponse, nil + return &response, nil } -func (a *API) setDefaultHeaders(req *http.Request) { - req.Header.Set("Authorization", "Bearer "+a.Configuration.Token) +func (a *api) setDefaultHeaders(req *http.Request) { + req.Header.Set("Authorization", "Bearer "+a.client.liveShare.Configuration.Token) req.Header.Set("Cache-Control", "no-cache") req.Header.Set("Content-Type", "application/json") } -type WorkspaceInfoResponse struct { +type workspaceInfoResponse struct { CreatedAt string `json:"createdAt"` UpdatedAt string `json:"updatedAt"` Name string `json:"name"` @@ -106,8 +103,8 @@ type WorkspaceInfoResponse struct { ID string `json:"id"` } -func (a *API) WorkspaceInfo() (*WorkspaceInfoResponse, error) { - url := fmt.Sprintf("%s/workspace/%s", a.ServiceURI, a.WorkspaceID) +func (a *api) workspaceInfo() (*workspaceInfoResponse, error) { + url := fmt.Sprintf("%s/workspace/%s", a.serviceURI, a.workspaceID) req, err := http.NewRequest(http.MethodGet, url, nil) if err != nil { @@ -115,7 +112,7 @@ func (a *API) WorkspaceInfo() (*WorkspaceInfoResponse, error) { } a.setDefaultHeaders(req) - resp, err := a.HttpClient.Do(req) + resp, err := a.httpClient.Do(req) if err != nil { return nil, fmt.Errorf("error making request: %v", err) } @@ -125,10 +122,10 @@ func (a *API) WorkspaceInfo() (*WorkspaceInfoResponse, error) { return nil, fmt.Errorf("error reading response body: %v", err) } - var workspaceInfoResponse WorkspaceInfoResponse - if err := json.Unmarshal(b, &workspaceInfoResponse); err != nil { + var response workspaceInfoResponse + if err := json.Unmarshal(b, &response); err != nil { return nil, fmt.Errorf("error unmarshaling response into json: %v", err) } - return &workspaceInfoResponse, nil + return &response, nil } diff --git a/client.go b/client.go index 0a89a125c..88c91ba01 100644 --- a/client.go +++ b/client.go @@ -3,27 +3,122 @@ package liveshare import ( "context" "fmt" + "log" + + "golang.org/x/crypto/ssh" ) type Client struct { - Configuration *Configuration - SSHSession *SSHSession + liveShare *LiveShare + session *session + sshSession *sshSession + rpc *rpc } -func NewClient(configuration *Configuration) *Client { - return &Client{Configuration: configuration} +// NewClient is a function ... +func (l *LiveShare) NewClient() *Client { + return &Client{liveShare: l} } -func (c *Client) Join(ctx context.Context) error { - session, err := GetSession(ctx, c.Configuration) - if err != nil { - return fmt.Errorf("error getting session: %v", err) +func (c *Client) Join(ctx context.Context) (err error) { + api := newAPI(c) + + c.session = newSession(api) + if err := c.session.init(ctx); err != nil { + return fmt.Errorf("error creating session: %v", err) } - c.SSHSession, err = NewSSH(session).NewSession() - if err != nil { + websocket := newWebsocket(c.session) + if err := websocket.connect(ctx); err != nil { + return fmt.Errorf("error connecting websocket: %v", err) + } + + c.sshSession = newSSH(c.session, websocket) + if err := c.sshSession.connect(ctx); err != nil { return fmt.Errorf("error connecting to ssh session: %v", err) } + c.rpc = newRPC(c.sshSession) + c.rpc.connect(ctx) + + _, err = c.joinWorkspace(ctx) + if err != nil { + return fmt.Errorf("error joining liveshare workspace: %v", err) + } + return nil } + +func (c *Client) hasJoined() bool { + return c.sshSession != nil && c.rpc != nil +} + +type clientCapabilities struct { + IsNonInteractive bool `json:"isNonInteractive"` +} + +type joinWorkspaceArgs struct { + ID string `json:"id"` + ConnectionMode string `json:"connectionMode"` + JoiningUserSessionToken string `json:"joiningUserSessionToken"` + ClientCapabilities clientCapabilities `json:"clientCapabilities"` +} + +type joinWorkspaceResult struct { + SessionNumber int `json:"sessionNumber"` +} + +func (c *Client) joinWorkspace(ctx context.Context) (*joinWorkspaceResult, error) { + args := joinWorkspaceArgs{ + ID: c.session.workspaceInfo.ID, + ConnectionMode: "local", + JoiningUserSessionToken: c.session.workspaceAccess.SessionToken, + ClientCapabilities: clientCapabilities{ + IsNonInteractive: false, + }, + } + + var result joinWorkspaceResult + if err := c.rpc.do(ctx, "workspace.joinWorkspace", &args, &result); err != nil { + return nil, fmt.Errorf("error making workspace.joinWorkspace call: %v", err) + } + + return &result, nil +} + +func (c *Client) openStreamingChannel(ctx context.Context, streamName, condition string) (ssh.Channel, error) { + args := getStreamArgs{streamName, condition} + var streamID string + if err := c.rpc.do(ctx, "streamManager.getStream", args, &streamID); err != nil { + return nil, fmt.Errorf("error getting stream id: %v", err) + } + + channel, reqs, err := c.sshSession.conn.OpenChannel("session", nil) + if err != nil { + return nil, fmt.Errorf("error opening ssh channel for transport: %v", err) + } + go c.processChannelRequests(ctx, reqs) + + requestType := fmt.Sprintf("stream-transport-%s", streamID) + acked, err := channel.SendRequest(requestType, true, nil) + if err != nil { + return nil, fmt.Errorf("error sending channel request: %v", err) + } + fmt.Println("ACKED: ", acked) + + return channel, nil +} + +func (c *Client) processChannelRequests(ctx context.Context, reqs <-chan *ssh.Request) { + for { + select { + case req := <-reqs: + if req != nil { + fmt.Printf("REQ: %+v\n\n", req) + log.Println("streaming channel requests are not supported") + } + case <-ctx.Done(): + break + } + } +} diff --git a/example/main.go b/example/main.go index bbf540d7c..b1fda8d32 100644 --- a/example/main.go +++ b/example/main.go @@ -1,36 +1,117 @@ package main import ( + "bufio" "context" + "flag" "fmt" "log" + "os" - "github.com/josebalius/go-liveshare" + "github.com/github/go-liveshare" ) +var workspaceIdFlag = flag.String("w", "", "workspace session id") + +func init() { + flag.Parse() +} + func main() { liveShare, err := liveshare.New( - liveshare.WithWorkspaceID(""), - liveshare.WithToken(""), + liveshare.WithWorkspaceID(*workspaceIdFlag), + liveshare.WithToken(os.Getenv("CODESPACE_TOKEN")), ) if err != nil { log.Fatal(fmt.Errorf("error creating liveshare: %v", err)) } - if err := liveShare.Connect(context.Background()); err != nil { - log.Fatal(fmt.Errorf("error connecting to liveshare: %v", err)) + ctx := context.Background() + liveShareClient := liveShare.NewClient() + if err := liveShareClient.Join(ctx); err != nil { + log.Fatal(fmt.Errorf("error joining liveshare with client: %v", err)) } - terminal := liveShare.NewTerminal() - - cmd := terminal.NewCommand( - "/home/codespace/workspace", - "docker ps -aq --filter label=Type=codespaces --filter status=running", - ) - output, err := cmd.Run(context.Background()) + terminal, err := liveShareClient.NewTerminal() if err != nil { - log.Fatal(fmt.Errorf("error starting ssh server with liveshare: %v", err)) + log.Fatal(fmt.Errorf("error creating liveshare terminal")) } - fmt.Println(string(output)) + containerID, err := getContainerID(ctx, terminal) + if err != nil { + log.Fatal(fmt.Errorf("error getting container id: %v", err)) + } + + if err := setupSSH(ctx, terminal, containerID); err != nil { + log.Fatal(fmt.Errorf("error setting up ssh: %v", err)) + } + + fmt.Println("Starting server...") + + server, err := liveShareClient.NewServer() + if err != nil { + log.Fatal(fmt.Errorf("error creating server: %v", err)) + } + + fmt.Println("Starting sharing...") + if err := server.StartSharing(ctx, "sshd", 2222); err != nil { + log.Fatal(fmt.Errorf("error server sharing: %v", err)) + } + + portForwarder := liveshare.NewLocalPortForwarder(liveShareClient, server, 2222) + + fmt.Println("Listening on port 2222") + if err := portForwarder.Start(ctx); err != nil { + log.Fatal(fmt.Errorf("error forwarding port: %v", err)) + } +} + +func setupSSH(ctx context.Context, terminal *liveshare.Terminal, containerID string) error { + cmd := terminal.NewCommand( + "/", + fmt.Sprintf("/usr/bin/docker exec -t %s /bin/bash -c \"echo -e \\\"testpwd1\\ntestpwd1\\n\\\" | sudo passwd codespace;/usr/local/share/ssh-init.sh\"", containerID), + ) + stream, err := cmd.Run(ctx) + if err != nil { + return fmt.Errorf("error running command: %v", err) + } + + scanner := bufio.NewScanner(stream) + scanner.Scan() + + fmt.Println("> Debug:", scanner.Text()) + if err := scanner.Err(); err != nil { + return fmt.Errorf("error scanning stream: %v", err) + } + + if err := stream.Close(); err != nil { + return fmt.Errorf("error closing stream: %v", err) + } + + return nil +} + +func getContainerID(ctx context.Context, terminal *liveshare.Terminal) (string, error) { + cmd := terminal.NewCommand( + "/", + "/usr/bin/docker ps -aq --filter label=Type=codespaces --filter status=running", + ) + stream, err := cmd.Run(ctx) + if err != nil { + return "", fmt.Errorf("error running command: %v", err) + } + + scanner := bufio.NewScanner(stream) + scanner.Scan() + + containerID := scanner.Text() + if err := scanner.Err(); err != nil { + return "", fmt.Errorf("error scanning stream: %v", err) + } + + if err := stream.Close(); err != nil { + return "", fmt.Errorf("error closing stream: %v", err) + } + + return containerID, nil } diff --git a/liveshare.go b/liveshare.go index 174eac20f..38222957a 100644 --- a/liveshare.go +++ b/liveshare.go @@ -1,16 +1,11 @@ package liveshare import ( - "context" "fmt" - "net/rpc" ) type LiveShare struct { Configuration *Configuration - - workspaceClient *Client - terminal *Terminal } func New(opts ...Option) (*LiveShare, error) { @@ -28,65 +23,3 @@ func New(opts ...Option) (*LiveShare, error) { return &LiveShare{Configuration: configuration}, nil } - -func (l *LiveShare) Connect(ctx context.Context) error { - l.workspaceClient = NewClient(l.Configuration) - if err := l.workspaceClient.Join(ctx); err != nil { - return fmt.Errorf("error joining with workspace client: %v", err) - } - - return nil -} - -type Terminal struct { - WorkspaceClient *Client - RPCClient *rpc.Client -} - -func (l *LiveShare) NewTerminal() *Terminal { - return &Terminal{ - WorkspaceClient: l.workspaceClient, - RPCClient: rpc.NewClient(l.workspaceClient.SSHSession), - } -} - -type TerminalCommand struct { - Terminal *Terminal - Cwd string - Cmd string -} - -func (t *Terminal) NewCommand(cwd, cmd string) TerminalCommand { - return TerminalCommand{t, cwd, cmd} -} - -type RunArgs struct { - Name string - Rows, Cols int - App string - Cwd string - CommandLine []string - ReadOnlyForGuests bool -} - -func (t TerminalCommand) Run(ctx context.Context) ([]byte, error) { - args := RunArgs{ - Name: "RunCommand", - Rows: 10, - Cols: 80, - App: "/bin/bash", - Cwd: t.Cwd, - CommandLine: []string{"-c", t.Cmd}, - ReadOnlyForGuests: false, - } - - var output []byte - runCall := t.Terminal.RPCClient.Go("terminal.startAsync", &args, &output, nil) - - runReply := <-runCall.Done - if runReply.Error != nil { - return nil, fmt.Errorf("error startAsync operation: %v", runReply.Error) - } - fmt.Printf("%+v\n\n", runReply) - return output, nil -} diff --git a/port_forwarder.go b/port_forwarder.go new file mode 100644 index 000000000..0ae5e1916 --- /dev/null +++ b/port_forwarder.go @@ -0,0 +1,63 @@ +package liveshare + +import ( + "context" + "fmt" + "io" + "log" + "net" + "strconv" + + "golang.org/x/crypto/ssh" +) + +type LocalPortForwarder struct { + client *Client + server *Server + port int + channels []ssh.Channel +} + +func NewLocalPortForwarder(client *Client, server *Server, port int) *LocalPortForwarder { + return &LocalPortForwarder{client, server, port, []ssh.Channel{}} +} + +func (l *LocalPortForwarder) Start(ctx context.Context) error { + ln, err := net.Listen("tcp", ":"+strconv.Itoa(l.port)) + if err != nil { + return fmt.Errorf("error listening on tcp port: %v", err) + } + + for { + conn, err := ln.Accept() + if err != nil { + return fmt.Errorf("error accepting incoming connection: %v", err) + } + + go l.handleConnection(ctx, conn) + } + + // clean up after ourselves + + return nil +} + +func (l *LocalPortForwarder) handleConnection(ctx context.Context, conn net.Conn) { + channel, err := l.client.openStreamingChannel(ctx, l.server.streamName, l.server.streamCondition) + if err != nil { + log.Println("errrr handle Connect") + log.Println(err) // TODO(josebalius) handle this somehow + } + l.channels = append(l.channels, channel) + + copyConn := func(writer io.Writer, reader io.Reader) { + _, err := io.Copy(writer, reader) + if err != nil { + log.Println("errrrr copyConn") + log.Println(err) //TODO(josebalius): handle this somehow + } + } + + go copyConn(conn, channel) + go copyConn(channel, conn) +} diff --git a/rpc.go b/rpc.go new file mode 100644 index 000000000..e90f71ba6 --- /dev/null +++ b/rpc.go @@ -0,0 +1,97 @@ +package liveshare + +import ( + "context" + "encoding/json" + "fmt" + "io" + "sync" + + "github.com/sourcegraph/jsonrpc2" +) + +type rpc struct { + *jsonrpc2.Conn + conn io.ReadWriteCloser + handler *rpcHandler +} + +func newRPC(conn io.ReadWriteCloser) *rpc { + return &rpc{conn: conn, handler: newRPCHandler()} +} + +func (r *rpc) connect(ctx context.Context) { + stream := jsonrpc2.NewBufferedStream(r.conn, jsonrpc2.VSCodeObjectCodec{}) + r.Conn = jsonrpc2.NewConn(ctx, stream, r.handler) +} + +func (r *rpc) do(ctx context.Context, method string, args interface{}, result interface{}) error { + b, _ := json.Marshal(args) + fmt.Println("rpc sent: ", method, string(b)) + waiter, err := r.Conn.DispatchCall(ctx, method, args) + if err != nil { + return fmt.Errorf("error on dispatch call: %v", err) + } + + // caller doesn't care about result, so lets ignore it + if result == nil { + return nil + } + + return waiter.Wait(ctx, result) +} + +type rpcHandler struct { + mutex sync.RWMutex + eventHandlers map[string][]chan *jsonrpc2.Request +} + +func newRPCHandler() *rpcHandler { + return &rpcHandler{ + eventHandlers: make(map[string][]chan *jsonrpc2.Request), + } +} + +func (r *rpcHandler) registerEventHandler(eventMethod string) <-chan *jsonrpc2.Request { + r.mutex.Lock() + defer r.mutex.Unlock() + + ch := make(chan *jsonrpc2.Request) + if _, ok := r.eventHandlers[eventMethod]; !ok { + r.eventHandlers[eventMethod] = []chan *jsonrpc2.Request{ch} + } else { + r.eventHandlers[eventMethod] = append(r.eventHandlers[eventMethod], ch) + } + return ch +} + +func (r *rpcHandler) Handle(ctx context.Context, conn *jsonrpc2.Conn, req *jsonrpc2.Request) { + r.mutex.Lock() + defer r.mutex.Unlock() + + fmt.Println("REQUEST") + fmt.Println("Method:", req.Method) + b, _ := req.MarshalJSON() + fmt.Println(string(b)) + fmt.Println("----") + fmt.Printf("%+v\n\n", r.eventHandlers) + if handlers, ok := r.eventHandlers[req.Method]; ok { + go func() { + for _, handler := range handlers { + select { + case handler <- req: + case <-ctx.Done(): + break + } + } + + r.eventHandlers[req.Method] = []chan *jsonrpc2.Request{} + }() + } else { + fmt.Println("UNHANDLED REQUEST") + fmt.Println("Method:", req.Method) + b, _ := req.MarshalJSON() + fmt.Println(string(b)) + fmt.Println("----") + } +} diff --git a/rpc_test.go b/rpc_test.go new file mode 100644 index 000000000..d16b32a4f --- /dev/null +++ b/rpc_test.go @@ -0,0 +1,27 @@ +package liveshare + +import ( + "context" + "testing" + "time" + + "github.com/sourcegraph/jsonrpc2" +) + +func TestRPCHandlerEvents(t *testing.T) { + rpcHandler := newRPCHandler() + eventCh := rpcHandler.registerEventHandler("somethingHappened") + go func() { + time.Sleep(1 * time.Second) + rpcHandler.Handle(context.Background(), nil, &jsonrpc2.Request{Method: "somethingHappened"}) + }() + ctx, _ := context.WithDeadline(context.Background(), time.Now().Add(5*time.Second)) + select { + case event := <-eventCh: + if event.Method != "somethingHappened" { + t.Error("event.Method is not the expect value") + } + case <-ctx.Done(): + t.Error("Test time out") + } +} diff --git a/server.go b/server.go new file mode 100644 index 000000000..65e03d584 --- /dev/null +++ b/server.go @@ -0,0 +1,52 @@ +package liveshare + +import ( + "context" + "errors" + "fmt" + "strconv" +) + +type Server struct { + client *Client + port int + streamName, streamCondition string +} + +func (c *Client) NewServer() (*Server, error) { + if !c.hasJoined() { + return nil, errors.New("LiveShareClient must join before creating server") + } + + return &Server{client: c}, nil +} + +type serverSharingResponse struct { + SourcePort int `json:"sourcePort"` + DestinationPort int `json:"destinationPort"` + SessionName string `json:"sessionName"` + StreamName string `json:"streamName"` + StreamCondition string `json:"streamCondition"` + BrowseURL string `json:"browseUrl"` + IsPublic bool `json:"isPublic"` + IsTCPServerConnectionEstablished bool `json:"isTCPServerConnectionEstablished"` + HasTSLHandshakePassed bool `json:"hasTSLHandshakePassed"` +} + +func (s *Server) StartSharing(ctx context.Context, protocol string, port int) error { + s.port = port + + sharingStarted := s.client.rpc.handler.registerEventHandler("serverSharing.sharingStarted") + var response serverSharingResponse + if err := s.client.rpc.do(ctx, "serverSharing.startSharing", []interface{}{ + port, protocol, fmt.Sprintf("http://localhost:%s", strconv.Itoa(port)), + }, &response); err != nil { + return err + } + <-sharingStarted + + s.streamName = response.StreamName + s.streamCondition = response.StreamCondition + + return nil +} diff --git a/session.go b/session.go index 24a284ef2..d0492a10e 100644 --- a/session.go +++ b/session.go @@ -3,42 +3,58 @@ package liveshare import ( "context" "fmt" + "net/url" + "strings" "golang.org/x/sync/errgroup" ) -type Session struct { - WorkspaceAccess *WorkspaceAccessResponse - WorkspaceInfo *WorkspaceInfoResponse +type session struct { + api *api + + workspaceAccess *workspaceAccessResponse + workspaceInfo *workspaceInfoResponse } -func GetSession(ctx context.Context, configuration *Configuration) (*Session, error) { - api := NewAPI(configuration) - session := new(Session) +func newSession(api *api) *session { + return &session{api: api} +} +func (s *session) init(ctx context.Context) error { g, ctx := errgroup.WithContext(ctx) g.Go(func() error { - workspaceAccess, err := api.WorkspaceAccess() + workspaceAccess, err := s.api.workspaceAccess() if err != nil { return fmt.Errorf("error getting workspace access: %v", err) } - session.WorkspaceAccess = workspaceAccess + s.workspaceAccess = workspaceAccess return nil }) g.Go(func() error { - workspaceInfo, err := api.WorkspaceInfo() + workspaceInfo, err := s.api.workspaceInfo() if err != nil { return fmt.Errorf("error getting workspace info: %v", err) } - session.WorkspaceInfo = workspaceInfo + s.workspaceInfo = workspaceInfo return nil }) if err := g.Wait(); err != nil { - return nil, err + return err } - return session, nil + return nil +} + +// Reference: +// https://github.com/Azure/azure-relay-node/blob/7b57225365df3010163bf4b9e640868a02737eb6/hyco-ws/index.js#L107-L137 +func (s *session) relayURI(action string) string { + relaySas := url.QueryEscape(s.workspaceAccess.RelaySas) + relayURI := s.workspaceAccess.RelayLink + relayURI = strings.Replace(relayURI, "sb:", "wss:", -1) + relayURI = strings.Replace(relayURI, ".net/", ".net:443/$hc/", 1) + relayURI = relayURI + "?sb-hc-action=" + action + "&sb-hc-token=" + relaySas + return relayURI } diff --git a/ssh.go b/ssh.go index 51906290e..9ae32ed7c 100644 --- a/ssh.go +++ b/ssh.go @@ -1,101 +1,68 @@ package liveshare import ( + "context" "fmt" "io" "net" - "net/url" - "strings" "time" - "github.com/gorilla/websocket" "golang.org/x/crypto/ssh" ) -type SSH struct { - Session *Session -} - -func NewSSH(session *Session) *SSH { - return &SSH{ - Session: session, - } -} - -// Reference: -// https://github.com/Azure/azure-relay-node/blob/7b57225365df3010163bf4b9e640868a02737eb6/hyco-ws/index.js#L107-L137 -func (s *SSH) relayURI(action string) string { - relaySas := url.QueryEscape(s.Session.WorkspaceAccess.RelaySas) - relayURI := s.Session.WorkspaceAccess.RelayLink - relayURI = strings.Replace(relayURI, "sb:", "wss:", -1) - relayURI = strings.Replace(relayURI, ".net/", ".net:443/$hc/", 1) - relayURI = relayURI + "?sb-hc-action=" + action + "&sb-hc-token=" + relaySas - return relayURI -} - -func (s *SSH) socketStream() (net.Conn, error) { - uri := s.relayURI("connect") - - ws, _, err := websocket.DefaultDialer.Dial(uri, nil) - if err != nil { - return nil, fmt.Errorf("error dialing websocket connection: %v", err) - } - - return NewAdapter(ws), nil -} - -type SSHSession struct { +type sshSession struct { *ssh.Session - reader io.Reader - writer io.Writer + session *session + socket net.Conn + conn ssh.Conn + reader io.Reader + writer io.Writer } -func (s SSHSession) Read(p []byte) (n int, err error) { - return s.reader.Read(p) +func newSSH(session *session, socket net.Conn) *sshSession { + return &sshSession{session: session, socket: socket} } -func (s SSHSession) Write(p []byte) (n int, err error) { - return s.writer.Write(p) -} - -func (s *SSH) NewSession() (*SSHSession, error) { - socketStream, err := s.socketStream() - if err != nil { - return nil, fmt.Errorf("error creating socket stream: %v", err) - } - +func (s *sshSession) connect(ctx context.Context) error { clientConfig := ssh.ClientConfig{ User: "", Auth: []ssh.AuthMethod{ - ssh.Password(s.Session.WorkspaceAccess.SessionToken), + ssh.Password(s.session.workspaceAccess.SessionToken), }, - HostKeyCallback: func(hostname string, remote net.Addr, key ssh.PublicKey) error { - // TODO(josebalius): implement - return nil - }, - Timeout: 10 * time.Second, + HostKeyAlgorithms: []string{"rsa-sha2-512", "rsa-sha2-256"}, + HostKeyCallback: ssh.InsecureIgnoreHostKey(), + Timeout: 10 * time.Second, } - sshClientConn, chans, reqs, err := ssh.NewClientConn(socketStream, "", &clientConfig) + sshClientConn, chans, reqs, err := ssh.NewClientConn(s.socket, "", &clientConfig) if err != nil { - return nil, fmt.Errorf("error creating ssh client connection: %v", err) + return fmt.Errorf("error creating ssh client connection: %v", err) } + s.conn = sshClientConn sshClient := ssh.NewClient(sshClientConn, chans, reqs) - sshSession, err := sshClient.NewSession() + s.Session, err = sshClient.NewSession() if err != nil { - return nil, fmt.Errorf("error creating ssh client session: %v", err) + return fmt.Errorf("error creating ssh client session: %v", err) } - reader, err := sshSession.StdoutPipe() + s.reader, err = s.Session.StdoutPipe() if err != nil { - return nil, fmt.Errorf("error creating ssh session reader: %v", err) + return fmt.Errorf("error creating ssh session reader: %v", err) } - writer, err := sshSession.StdinPipe() + s.writer, err = s.Session.StdinPipe() if err != nil { - return nil, fmt.Errorf("error creating ssh session writer: %v", err) + return fmt.Errorf("error creating ssh session writer: %v", err) } - return &SSHSession{Session: sshSession, reader: reader, writer: writer}, nil + return nil +} + +func (s *sshSession) Read(p []byte) (n int, err error) { + return s.reader.Read(p) +} + +func (s *sshSession) Write(p []byte) (n int, err error) { + return s.writer.Write(p) } diff --git a/terminal.go b/terminal.go new file mode 100644 index 000000000..631e75912 --- /dev/null +++ b/terminal.go @@ -0,0 +1,116 @@ +package liveshare + +import ( + "context" + "errors" + "fmt" + "io" + + "golang.org/x/crypto/ssh" +) + +type Terminal struct { + client *Client +} + +func (c *Client) NewTerminal() (*Terminal, error) { + if !c.hasJoined() { + return nil, errors.New("LiveShareClient must join before creating terminal") + } + + return &Terminal{ + client: c, + }, nil +} + +type TerminalCommand struct { + terminal *Terminal + cwd string + cmd string +} + +func (t *Terminal) NewCommand(cwd, cmd string) TerminalCommand { + return TerminalCommand{t, cwd, cmd} +} + +type runArgs struct { + Name string `json:"name"` + Rows int `json:"rows"` + Cols int `json:"cols"` + App string `json:"app"` + Cwd string `json:"cwd"` + CommandLine []string `json:"commandLine"` + ReadOnlyForGuests bool `json:"readOnlyForGuests"` +} + +type startTerminalResult struct { + ID int `json:"id"` + StreamName string `json:"streamName"` + StreamCondition string `json:"streamCondition"` + LocalPipeName string `json:"localPipeName"` + AppProcessID int `json:"appProcessId"` +} + +type getStreamArgs struct { + StreamName string `json:"streamName"` + Condition string `json:"condition"` +} + +type stopTerminalArgs struct { + ID int `json:"id"` +} + +func (t TerminalCommand) Run(ctx context.Context) (io.ReadCloser, error) { + args := runArgs{ + Name: "RunCommand", + Rows: 10, + Cols: 80, + App: "/bin/bash", + Cwd: t.cwd, + CommandLine: []string{"-c", t.cmd}, + ReadOnlyForGuests: false, + } + + terminalStarted := t.terminal.client.rpc.handler.registerEventHandler("terminal.terminalStarted") + var result startTerminalResult + if err := t.terminal.client.rpc.do(ctx, "terminal.startTerminal", &args, &result); err != nil { + return nil, fmt.Errorf("error making terminal.startTerminal call: %v", err) + } + <-terminalStarted + + channel, err := t.terminal.client.openStreamingChannel(ctx, result.StreamName, result.StreamCondition) + if err != nil { + return nil, fmt.Errorf("error opening streaming channel: %v", err) + } + + return t.newTerminalReadCloser(result.ID, channel), nil +} + +type terminalReadCloser struct { + terminalCommand TerminalCommand + terminalID int + channel ssh.Channel +} + +func (t TerminalCommand) newTerminalReadCloser(terminalID int, channel ssh.Channel) io.ReadCloser { + return terminalReadCloser{t, terminalID, channel} +} + +func (t terminalReadCloser) Read(b []byte) (int, error) { + return t.channel.Read(b) +} + +func (t terminalReadCloser) Close() error { + terminalStopped := t.terminalCommand.terminal.client.rpc.handler.registerEventHandler("terminal.terminalStopped") + if err := t.terminalCommand.terminal.client.rpc.do(context.Background(), "terminal.stopTerminal", []int{t.terminalID}, nil); err != nil { + return fmt.Errorf("error making terminal.stopTerminal call: %v", err) + } + + if err := t.channel.Close(); err != nil { + return fmt.Errorf("error closing channel: %v", err) + } + + <-terminalStopped + + return nil +} diff --git a/websocket.go b/websocket.go new file mode 100644 index 000000000..ae163e6e2 --- /dev/null +++ b/websocket.go @@ -0,0 +1,105 @@ +package liveshare + +import ( + "context" + "errors" + "io" + "net" + "sync" + "time" + + gorillawebsocket "github.com/gorilla/websocket" +) + +type websocket struct { + session *session + conn *gorillawebsocket.Conn + readMutex sync.Mutex + writeMutex sync.Mutex + reader io.Reader +} + +func newWebsocket(session *session) *websocket { + return &websocket{session: session} +} + +func (w *websocket) connect(ctx context.Context) error { + ws, _, err := gorillawebsocket.DefaultDialer.Dial(w.session.relayURI("connect"), nil) + if err != nil { + return err + } + w.conn = ws + return nil +} + +func (w *websocket) Read(b []byte) (int, error) { + w.readMutex.Lock() + defer w.readMutex.Unlock() + + if w.reader == nil { + messageType, reader, err := w.conn.NextReader() + if err != nil { + return 0, err + } + + if messageType != gorillawebsocket.BinaryMessage { + return 0, errors.New("unexpected websocket message type") + } + + w.reader = reader + } + + bytesRead, err := w.reader.Read(b) + if err != nil { + w.reader = nil + + if err == io.EOF { + err = nil + } + } + + return bytesRead, err +} + +func (w *websocket) Write(b []byte) (int, error) { + w.writeMutex.Lock() + defer w.writeMutex.Unlock() + + nextWriter, err := w.conn.NextWriter(gorillawebsocket.BinaryMessage) + if err != nil { + return 0, err + } + + bytesWritten, err := nextWriter.Write(b) + nextWriter.Close() + + return bytesWritten, err +} + +func (w *websocket) Close() error { + return w.conn.Close() +} + +func (w *websocket) LocalAddr() net.Addr { + return w.conn.LocalAddr() +} + +func (w *websocket) RemoteAddr() net.Addr { + return w.conn.RemoteAddr() +} + +func (w *websocket) SetDeadline(t time.Time) error { + if err := w.SetReadDeadline(t); err != nil { + return err + } + + return w.SetWriteDeadline(t) +} + +func (w *websocket) SetReadDeadline(t time.Time) error { + return w.conn.SetReadDeadline(t) +} + +func (w *websocket) SetWriteDeadline(t time.Time) error { + return w.conn.SetWriteDeadline(t) +} From 04a6383ccb778d0683bf63a113ac79b74d6d2b2b Mon Sep 17 00:00:00 2001 From: Jose Garcia Date: Wed, 7 Jul 2021 08:00:01 -0400 Subject: [PATCH 005/290] Tidy up go.mod --- example/main.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/example/main.go b/example/main.go index b1fda8d32..e9347bd14 100644 --- a/example/main.go +++ b/example/main.go @@ -7,6 +7,7 @@ import ( "fmt" "log" "os" + "time" "github.com/github/go-liveshare" ) @@ -88,6 +89,8 @@ func setupSSH(ctx context.Context, terminal *liveshare.Terminal, containerID str return fmt.Errorf("error closing stream: %v", err) } + time.Sleep(2 * time.Second) + return nil } From 4a0eaa3da503e5117045ab4ea39ebaafae522c0b Mon Sep 17 00:00:00 2001 From: Jose Garcia Date: Wed, 14 Jul 2021 16:12:30 -0400 Subject: [PATCH 006/290] Latest and greatest --- api/api.go | 403 +++++++++++++++++++++++++++++++++++++++++++++ cmd/ghcs/create.go | 147 +++++++++++++++++ cmd/ghcs/delete.go | 55 +++++++ cmd/ghcs/list.go | 60 +++++++ cmd/ghcs/main.go | 29 ++++ cmd/ghcs/ssh.go | 262 +++++++++++++++++++++++++++++ 6 files changed, 956 insertions(+) create mode 100644 api/api.go create mode 100644 cmd/ghcs/create.go create mode 100644 cmd/ghcs/delete.go create mode 100644 cmd/ghcs/list.go create mode 100644 cmd/ghcs/main.go create mode 100644 cmd/ghcs/ssh.go diff --git a/api/api.go b/api/api.go new file mode 100644 index 000000000..b3f7577ed --- /dev/null +++ b/api/api.go @@ -0,0 +1,403 @@ +package api + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "fmt" + "io/ioutil" + "net/http" + "sort" + "strconv" + "strings" +) + +const githubAPI = "https://api.github.com" + +type API struct { + token string + client *http.Client +} + +func New(token string) *API { + return &API{token, &http.Client{}} +} + +type User struct { + Login string `json:"login"` +} + +type errResponse struct { + Message string `json:"message"` +} + +func (a *API) GetUser(ctx context.Context) (*User, error) { + req, err := http.NewRequest(http.MethodGet, githubAPI+"/user", nil) + if err != nil { + return nil, fmt.Errorf("error creating request: %v", err) + } + + a.setHeaders(req) + resp, err := a.client.Do(req) + if err != nil { + return nil, fmt.Errorf("error making request: %v", err) + } + + b, err := ioutil.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("error reading response body: %v", err) + } + + if resp.StatusCode != http.StatusOK { + return nil, a.errorResponse(b) + } + + var response User + if err := json.Unmarshal(b, &response); err != nil { + return nil, fmt.Errorf("error unmarshaling response: %v", err) + } + + return &response, nil +} + +func (a *API) errorResponse(b []byte) error { + var response errResponse + if err := json.Unmarshal(b, &response); err != nil { + return fmt.Errorf("error unmarshaling error response: %v", err) + } + + return errors.New(response.Message) +} + +type Repository struct { + ID int `json:"id"` +} + +func (a *API) GetRepository(ctx context.Context, nwo string) (*Repository, error) { + req, err := http.NewRequest(http.MethodGet, githubAPI+"/repos/"+strings.ToLower(nwo), nil) + if err != nil { + return nil, fmt.Errorf("error creating request: %v", err) + } + + a.setHeaders(req) + resp, err := a.client.Do(req) + if err != nil { + return nil, fmt.Errorf("error making request: %v", err) + } + + b, err := ioutil.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("error reading response body: %v", err) + } + + if resp.StatusCode != http.StatusOK { + return nil, a.errorResponse(b) + } + + var response Repository + if err := json.Unmarshal(b, &response); err != nil { + return nil, fmt.Errorf("error unmarshaling response: %v", err) + } + + return &response, nil +} + +type Codespaces []*Codespace + +func (c Codespaces) SortByRecent() { + sort.Slice(c, func(i, j int) bool { + return c[i].CreatedAt > c[j].CreatedAt + }) +} + +type Codespace struct { + Name string `json:"name"` + GUID string `json:"guid"` + CreatedAt string `json:"created_at"` + Branch string `json:"branch"` + RepositoryName string `json:"repository_name"` + RepositoryNWO string `json:"repository_nwo"` + OwnerLogin string `json:"owner_login"` + Environment CodespaceEnvironment `json:"environment"` +} + +type CodespaceEnvironment struct { + State string `json:"state"` + Connection CodespaceEnvironmentConnection `json:"connection"` +} + +const ( + CodespaceEnvironmentStateAvailable = "Available" +) + +type CodespaceEnvironmentConnection struct { + SessionID string `json:"sessionId"` + SessionToken string `json:"sessionToken"` +} + +func (a *API) ListCodespaces(ctx context.Context, user *User) (Codespaces, error) { + req, err := http.NewRequest( + http.MethodGet, githubAPI+"/vscs_internal/user/"+user.Login+"/codespaces", nil, + ) + if err != nil { + return nil, fmt.Errorf("error creating request: %v", err) + } + + a.setHeaders(req) + resp, err := a.client.Do(req) + if err != nil { + return nil, fmt.Errorf("error making request: %v", err) + } + + b, err := ioutil.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("error reading response body: %v", err) + } + + if resp.StatusCode != http.StatusOK { + return nil, a.errorResponse(b) + } + + response := struct { + Codespaces Codespaces `json:"codespaces"` + }{} + if err := json.Unmarshal(b, &response); err != nil { + return nil, fmt.Errorf("error unmarshaling response: %v", err) + } + return response.Codespaces, nil +} + +type getCodespaceTokenRequest struct { + MintRepositoryToken bool `json:"mint_repository_token"` +} + +type getCodespaceTokenResponse struct { + RepositoryToken string `json:"repository_token"` +} + +func (a *API) GetCodespaceToken(ctx context.Context, codespace *Codespace) (string, error) { + reqBody, err := json.Marshal(getCodespaceTokenRequest{true}) + if err != nil { + return "", fmt.Errorf("error preparing request body: %v", err) + } + + req, err := http.NewRequest( + http.MethodPost, + githubAPI+"/vscs_internal/user/"+codespace.OwnerLogin+"/codespaces/"+codespace.Name+"/token", + bytes.NewBuffer(reqBody), + ) + if err != nil { + return "", fmt.Errorf("error creating request: %v", err) + } + + a.setHeaders(req) + resp, err := a.client.Do(req) + if err != nil { + return "", fmt.Errorf("error making request: %v", err) + } + + b, err := ioutil.ReadAll(resp.Body) + if err != nil { + return "", fmt.Errorf("error reading response body: %v", err) + } + + if resp.StatusCode != http.StatusOK { + return "", a.errorResponse(b) + } + + var response getCodespaceTokenResponse + if err := json.Unmarshal(b, &response); err != nil { + return "", fmt.Errorf("error unmarshaling response: %v", err) + } + + return response.RepositoryToken, nil +} + +func (a *API) GetCodespace(ctx context.Context, token, owner, codespace string) (*Codespace, error) { + req, err := http.NewRequest( + http.MethodGet, + githubAPI+"/vscs_internal/user/"+owner+"/codespaces/"+codespace, + nil, + ) + if err != nil { + return nil, fmt.Errorf("error creating request: %v", err) + } + + req.Header.Set("Authorization", "Bearer "+token) + resp, err := a.client.Do(req) + if err != nil { + return nil, fmt.Errorf("error making request: %v", err) + } + + b, err := ioutil.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("error reading response body: %v", err) + } + + if resp.StatusCode != http.StatusOK { + return nil, a.errorResponse(b) + } + + var response Codespace + if err := json.Unmarshal(b, &response); err != nil { + return nil, fmt.Errorf("error unmarshaling response: %v", err) + } + + return &response, nil +} + +func (a *API) StartCodespace(ctx context.Context, token string, codespace *Codespace) error { + req, err := http.NewRequest( + http.MethodPost, + githubAPI+"/vscs_internal/proxy/environments/"+codespace.GUID+"/start", + nil, + ) + if err != nil { + return fmt.Errorf("error creating request: %v", err) + } + + req.Header.Set("Authorization", "Bearer "+token) + _, err = a.client.Do(req) + if err != nil { + return fmt.Errorf("error making request: %v", err) + } + + return nil +} + +type getCodespaceRegionLocationResponse struct { + Current string `json:"current"` +} + +func (a *API) GetCodespaceRegionLocation(ctx context.Context) (string, error) { + req, err := http.NewRequest(http.MethodGet, "https://online.visualstudio.com/api/v1/locations", nil) + if err != nil { + return "", fmt.Errorf("error creating request: %v", err) + } + + resp, err := a.client.Do(req) + if err != nil { + return "", fmt.Errorf("error making request: %v", err) + } + + b, err := ioutil.ReadAll(resp.Body) + if err != nil { + return "", fmt.Errorf("error reading response body: %v", err) + } + + var response getCodespaceRegionLocationResponse + if err := json.Unmarshal(b, &response); err != nil { + return "", fmt.Errorf("error unmarshaling response: %v", err) + } + + return response.Current, nil +} + +type Skus []*Sku + +type Sku struct { + Name string `json:"name"` + DisplayName string `json:"display_name"` +} + +func (a *API) GetCodespacesSkus(ctx context.Context, user *User, repository *Repository, location string) (Skus, error) { + req, err := http.NewRequest(http.MethodGet, githubAPI+"/vscs_internal/user/"+user.Login+"/skus", nil) + if err != nil { + return nil, fmt.Errorf("err creating request: %v", err) + } + + q := req.URL.Query() + q.Add("location", location) + q.Add("repository_id", strconv.Itoa(repository.ID)) + req.URL.RawQuery = q.Encode() + + a.setHeaders(req) + resp, err := a.client.Do(req) + if err != nil { + return nil, fmt.Errorf("error making request: %v", err) + } + + b, err := ioutil.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("error reading response body: %v", err) + } + + response := struct { + Skus Skus `json:"skus"` + }{} + if err := json.Unmarshal(b, &response); err != nil { + return nil, fmt.Errorf("error unmarshaling response: %v", err) + } + + return response.Skus, nil +} + +type createCodespaceRequest struct { + RepositoryID int `json:"repository_id"` + Ref string `json:"ref"` + Location string `json:"location"` + SkuName string `json:"sku_name"` +} + +func (a *API) CreateCodespace(ctx context.Context, user *User, repository *Repository, sku *Sku, branch, location string) (*Codespace, error) { + requestBody, err := json.Marshal(createCodespaceRequest{repository.ID, branch, location, sku.Name}) + if err != nil { + return nil, fmt.Errorf("error marshaling request: %v", err) + } + + req, err := http.NewRequest(http.MethodPost, githubAPI+"/vscs_internal/user/"+user.Login+"/codespaces", bytes.NewBuffer(requestBody)) + if err != nil { + return nil, fmt.Errorf("error creating request: %v", err) + } + + a.setHeaders(req) + resp, err := a.client.Do(req) + if err != nil { + return nil, fmt.Errorf("error making request: %v", err) + } + + b, err := ioutil.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("error reading response body: %v", err) + } + + if resp.StatusCode > http.StatusAccepted { + return nil, a.errorResponse(b) + } + + var response Codespace + if err := json.Unmarshal(b, &response); err != nil { + return nil, fmt.Errorf("error unmarshaling response: %v", err) + } + + return &response, nil +} + +func (a *API) DeleteCodespace(ctx context.Context, user *User, token, codespaceName string) error { + req, err := http.NewRequest(http.MethodDelete, githubAPI+"/vscs_internal/user/"+user.Login+"/codespaces/"+codespaceName, nil) + if err != nil { + return fmt.Errorf("error creating request: %v", err) + } + + req.Header.Set("Authorization", "Bearer "+token) + resp, err := a.client.Do(req) + if err != nil { + return fmt.Errorf("error making request: %v", err) + } + + if resp.StatusCode > http.StatusAccepted { + b, err := ioutil.ReadAll(resp.Body) + if err != nil { + return fmt.Errorf("error reading response body: %v", err) + } + return a.errorResponse(b) + } + + return nil +} + +func (a *API) setHeaders(req *http.Request) { + req.Header.Set("Authorization", "Bearer "+a.token) + req.Header.Set("Accept", "application/vnd.github.v3+json") +} diff --git a/cmd/ghcs/create.go b/cmd/ghcs/create.go new file mode 100644 index 000000000..44bedb5f2 --- /dev/null +++ b/cmd/ghcs/create.go @@ -0,0 +1,147 @@ +package main + +import ( + "context" + "fmt" + "os" + "strings" + + "github.com/AlecAivazis/survey/v2" + "github.com/fatih/camelcase" + "github.com/github/ghcs/api" + "github.com/spf13/cobra" +) + +var createCmd = &cobra.Command{ + Use: "create", + Short: "Create", + Long: "Create", + RunE: func(cmd *cobra.Command, args []string) error { + return Create() + }, +} + +func init() { + rootCmd.AddCommand(createCmd) +} + +var createSurvey = []*survey.Question{ + { + Name: "repository", + Prompt: &survey.Input{Message: "Repository"}, + Validate: survey.Required, + }, + { + Name: "branch", + Prompt: &survey.Input{Message: "Branch"}, + Validate: survey.Required, + }, +} + +func Create() error { + ctx := context.Background() + apiClient := api.New(os.Getenv("GITHUB_TOKEN")) + locationCh := getLocation(ctx, apiClient) + userCh := getUser(ctx, apiClient) + + answers := struct { + Repository string + Branch string + }{} + + if err := survey.Ask(createSurvey, &answers); err != nil { + return fmt.Errorf("error getting answers: %v", err) + } + + repository, err := apiClient.GetRepository(ctx, answers.Repository) + if err != nil { + return fmt.Errorf("error getting repository: %v", err) + } + + locationResult := <-locationCh + if locationResult.Err != nil { + return fmt.Errorf("error getting codespace region location: %v", locationResult.Err) + } + + userResult := <-userCh + if userResult.Err != nil { + return fmt.Errorf("error getting codespace user: %v", userResult.Err) + } + + skus, err := apiClient.GetCodespacesSkus(ctx, userResult.User, repository, locationResult.Location) + if err != nil { + return fmt.Errorf("error getting codespace skus: %v", err) + } + + if len(skus) == 0 { + fmt.Println("There are no available machine types for this repository") + return nil + } + + skuNames := make([]string, 0, len(skus)) + skuByName := make(map[string]*api.Sku) + for _, sku := range skus { + nameParts := camelcase.Split(sku.Name) + machineName := strings.Title(strings.ToLower(nameParts[0])) + skuName := fmt.Sprintf("%s - %s", machineName, sku.DisplayName) + skuNames = append(skuNames, skuName) + skuByName[skuName] = sku + } + + skuSurvey := []*survey.Question{ + { + Name: "sku", + Prompt: &survey.Select{ + Message: "Choose Machine Type:", + Options: skuNames, + Default: skuNames[0], + }, + Validate: survey.Required, + }, + } + + skuAnswers := struct{ SKU string }{} + if err := survey.Ask(skuSurvey, &skuAnswers); err != nil { + return fmt.Errorf("error getting SKU: %v", err) + } + + sku := skuByName[skuAnswers.SKU] + fmt.Println("Creating your codespace...") + + codespace, err := apiClient.CreateCodespace(ctx, userResult.User, repository, sku, answers.Branch, locationResult.Location) + if err != nil { + return fmt.Errorf("error creating codespace: %v", err) + } + + fmt.Println("Codespace created: " + codespace.Name) + + return nil +} + +type getUserResult struct { + User *api.User + Err error +} + +func getUser(ctx context.Context, apiClient *api.API) <-chan getUserResult { + ch := make(chan getUserResult) + go func() { + user, err := apiClient.GetUser(ctx) + ch <- getUserResult{user, err} + }() + return ch +} + +type locationResult struct { + Location string + Err error +} + +func getLocation(ctx context.Context, apiClient *api.API) <-chan locationResult { + ch := make(chan locationResult) + go func() { + location, err := apiClient.GetCodespaceRegionLocation(ctx) + ch <- locationResult{location, err} + }() + return ch +} diff --git a/cmd/ghcs/delete.go b/cmd/ghcs/delete.go new file mode 100644 index 000000000..e5cd34a94 --- /dev/null +++ b/cmd/ghcs/delete.go @@ -0,0 +1,55 @@ +package main + +import ( + "context" + "errors" + "fmt" + "os" + + "github.com/github/ghcs/api" + "github.com/spf13/cobra" +) + +func NewDeleteCmd() *cobra.Command { + deleteCmd := &cobra.Command{ + Use: "delete CODESPACE_NAME", + Short: "delete", + Long: "delete", + RunE: func(cmd *cobra.Command, args []string) error { + if len(args) == 0 { + return errors.New("A Codespace name is required.") + } + return Delete(args[0]) + }, + } + + return deleteCmd +} + +func init() { + rootCmd.AddCommand(NewDeleteCmd()) +} + +func Delete(codespaceName string) error { + apiClient := api.New(os.Getenv("GITHUB_TOKEN")) + ctx := context.Background() + + user, err := apiClient.GetUser(ctx) + if err != nil { + return fmt.Errorf("error getting user: %v", err) + } + + codespace := api.Codespace{OwnerLogin: user.Login, Name: codespaceName} + token, err := apiClient.GetCodespaceToken(ctx, &codespace) + if err != nil { + return fmt.Errorf("error getting codespace token: %v", err) + } + + if err := apiClient.DeleteCodespace(ctx, user, token, codespaceName); err != nil { + return fmt.Errorf("error deleting codespace: %v", err) + } + + fmt.Println("Codespace deleted.") + + return List() +} diff --git a/cmd/ghcs/list.go b/cmd/ghcs/list.go new file mode 100644 index 000000000..e02e6a1d2 --- /dev/null +++ b/cmd/ghcs/list.go @@ -0,0 +1,60 @@ +package main + +import ( + "context" + "fmt" + "os" + + "github.com/olekukonko/tablewriter" + + "github.com/github/ghcs/api" + "github.com/spf13/cobra" +) + +func NewListCmd() *cobra.Command { + listCmd := &cobra.Command{ + Use: "list", + Short: "list", + Long: "list", + RunE: func(cmd *cobra.Command, args []string) error { + return List() + }, + } + + return listCmd +} + +func init() { + rootCmd.AddCommand(NewListCmd()) +} + +func List() error { + apiClient := api.New(os.Getenv("GITHUB_TOKEN")) + ctx := context.Background() + + user, err := apiClient.GetUser(ctx) + if err != nil { + return fmt.Errorf("error getting user: %v", err) + } + + codespaces, err := apiClient.ListCodespaces(ctx, user) + if err != nil { + return fmt.Errorf("error getting codespaces: %v", err) + } + + if len(codespaces) == 0 { + fmt.Println("You have no codespaces.") + return nil + } + + table := tablewriter.NewWriter(os.Stdout) + table.SetHeader([]string{"Name", "Repository", "Branch", "State", "Created At"}) + for _, codespace := range codespaces { + table.Append([]string{ + codespace.Name, codespace.RepositoryNWO, codespace.Branch, codespace.Environment.State, codespace.CreatedAt, + }) + } + + table.Render() + return nil +} diff --git a/cmd/ghcs/main.go b/cmd/ghcs/main.go new file mode 100644 index 000000000..400f5324c --- /dev/null +++ b/cmd/ghcs/main.go @@ -0,0 +1,29 @@ +package main + +import ( + "fmt" + "os" + + "github.com/spf13/cobra" +) + +// ghcs create +// ghcs connect +// ghcs delete +// ghcs list +func main() { + Execute() +} + +var rootCmd = &cobra.Command{ + Use: "ghcs", + Short: "Codespaces", + Long: "Codespaces", +} + +func Execute() { + if err := rootCmd.Execute(); err != nil { + fmt.Fprintln(os.Stderr, err) + os.Exit(1) + } +} diff --git a/cmd/ghcs/ssh.go b/cmd/ghcs/ssh.go new file mode 100644 index 000000000..56f22224f --- /dev/null +++ b/cmd/ghcs/ssh.go @@ -0,0 +1,262 @@ +package main + +import ( + "bufio" + "context" + "errors" + "fmt" + "log" + "math/rand" + "os" + "os/exec" + "strconv" + "strings" + "time" + + "github.com/AlecAivazis/survey/v2" + "github.com/github/ghcs/api" + "github.com/github/go-liveshare" + "github.com/spf13/cobra" +) + +func NewSSHCmd() *cobra.Command { + var sshProfile string + + sshCmd := &cobra.Command{ + Use: "ssh", + Short: "ssh", + Long: "ssh", + RunE: func(cmd *cobra.Command, args []string) error { + return SSH(sshProfile) + }, + } + + sshCmd.Flags().StringVarP(&sshProfile, "profile", "", "", "SSH Profile") + + return sshCmd +} + +func init() { + rootCmd.AddCommand(NewSSHCmd()) +} + +func SSH(sshProfile string) error { + apiClient := api.New(os.Getenv("GITHUB_TOKEN")) + ctx := context.Background() + + user, err := apiClient.GetUser(ctx) + if err != nil { + return fmt.Errorf("error getting user: %v", err) + } + + codespaces, err := apiClient.ListCodespaces(ctx, user) + if err != nil { + return fmt.Errorf("error getting codespaces: %v", err) + } + + if len(codespaces) == 0 { + fmt.Println("You have no codespaces.") + return nil + } + + codespaces.SortByRecent() + + codespacesByName := make(map[string]*api.Codespace) + codespacesNames := make([]string, 0, len(codespaces)) + for _, codespace := range codespaces { + codespacesByName[codespace.Name] = codespace + codespacesNames = append(codespacesNames, codespace.Name) + } + + sshSurvey := []*survey.Question{ + { + Name: "codespace", + Prompt: &survey.Select{ + Message: "Choose Codespace:", + Options: codespacesNames, + Default: codespacesNames[0], + }, + Validate: survey.Required, + }, + } + + answers := struct { + Codespace string + }{} + if err := survey.Ask(sshSurvey, &answers); err != nil { + return fmt.Errorf("error getting answers: %v", err) + } + + codespace := codespacesByName[answers.Codespace] + + token, err := apiClient.GetCodespaceToken(ctx, codespace) + if err != nil { + return fmt.Errorf("error getting codespace token: %v", err) + } + + if codespace.Environment.State != api.CodespaceEnvironmentStateAvailable { + fmt.Println("Starting your codespace...") + if err := apiClient.StartCodespace(ctx, token, codespace); err != nil { + return fmt.Errorf("error starting codespace: %v", err) + } + } + + retries := 0 + for codespace.Environment.Connection.SessionID == "" || codespace.Environment.State != api.CodespaceEnvironmentStateAvailable { + if retries > 1 { + if retries%2 == 0 { + fmt.Print(".") + } + + time.Sleep(1 * time.Second) + } + + if retries == 10 { + return errors.New("Failed to start codespace") + } + + codespace, err = apiClient.GetCodespace(ctx, token, codespace.OwnerLogin, codespace.Name) + if err != nil { + return fmt.Errorf("error getting codespace: %v", err) + } + + retries += 1 + } + + if retries >= 2 { + fmt.Print("\n") + } + + fmt.Println("Connecting to your codespace...") + + liveShare, err := liveshare.New( + liveshare.WithWorkspaceID(codespace.Environment.Connection.SessionID), + liveshare.WithToken(codespace.Environment.Connection.SessionToken), + ) + if err != nil { + return fmt.Errorf("error creating live share: %v", err) + } + + liveShareClient := liveShare.NewClient() + if err := liveShareClient.Join(ctx); err != nil { + return fmt.Errorf("error joining liveshare client: %v", err) + } + + terminal, err := liveShareClient.NewTerminal() + if err != nil { + return fmt.Errorf("error creating liveshare terminal: %v", err) + } + + if sshProfile == "" { + containerID, err := getContainerID(ctx, terminal) + if err != nil { + return fmt.Errorf("error getting container id: %v", err) + } + + if err := setupSSH(ctx, terminal, containerID, codespace.RepositoryName); err != nil { + return fmt.Errorf("error creating ssh server: %v", err) + } + } + + server, err := liveShareClient.NewServer() + if err != nil { + return fmt.Errorf("error creating server: %v", err) + } + + rand.Seed(time.Now().Unix()) + port := rand.Intn(9999-2000) + 2000 // improve this obviously + if err := server.StartSharing(ctx, "sshd", 2222); err != nil { + return fmt.Errorf("error sharing sshd port: %v", err) + } + + portForwarder := liveshare.NewLocalPortForwarder(liveShareClient, server, port) + go func() { + if err := portForwarder.Start(ctx); err != nil { + panic(fmt.Errorf("error forwarding port: %v", err)) + } + }() + + if err := connect(ctx, port, sshProfile); err != nil { + return fmt.Errorf("error connecting via SSH: %v", err) + } + + return nil +} + +func connect(ctx context.Context, port int, sshProfile string) error { + var cmd *exec.Cmd + if sshProfile != "" { + cmd = exec.CommandContext(ctx, "ssh", sshProfile, "-p", strconv.Itoa(port), "-C") + } else { + cmd = exec.CommandContext(ctx, "ssh", "codespace@localhost", "-C", "-p", strconv.Itoa(port), "-o", "NoHostAuthenticationForLocalhost=yes") + } + + cmd.Stdout = os.Stdout + cmd.Stdin = os.Stdin + cmd.Stderr = os.Stderr + + if err := cmd.Start(); err != nil { + return fmt.Errorf("error running ssh: %v", err) + } + + go func() { + if err := cmd.Wait(); err != nil { + log.Println(fmt.Errorf("error waiting for ssh to finish: %v", err)) + } + }() + + done := make(chan bool) + <-done + + return nil +} + +func getContainerID(ctx context.Context, terminal *liveshare.Terminal) (string, error) { + cmd := terminal.NewCommand( + "/", + "/usr/bin/docker ps -aq --filter label=Type=codespaces --filter status=running", + ) + stream, err := cmd.Run(ctx) + if err != nil { + return "", fmt.Errorf("error running command: %v", err) + } + + scanner := bufio.NewScanner(stream) + scanner.Scan() + + containerID := scanner.Text() + if err := scanner.Err(); err != nil { + return "", fmt.Errorf("error scanning stream: %v", err) + } + + if err := stream.Close(); err != nil { + return "", fmt.Errorf("error closing stream: %v", err) + } + + return containerID, nil +} + +func setupSSH(ctx context.Context, terminal *liveshare.Terminal, containerID, repositoryName string) error { + getUsernameCmd := "GITHUB_USERNAME=\"$(jq .CODESPACE_NAME /workspaces/.codespaces/shared/environment-variables.json -r | cut -f1 -d -)\"" + makeSSHDirCmd := "mkdir /home/codespace/.ssh" + getUserKeysCmd := "curl --silent --fail \"https://github.com/$(echo $GITHUB_USERNAME).keys\" > /home/codespace/.ssh/authorized_keys" + setupLoginDirCmd := fmt.Sprintf("echo \"cd /workspaces/%v\" > /home/codespace/.bash_profile", repositoryName) + + compositeCommand := []string{getUsernameCmd, makeSSHDirCmd, getUserKeysCmd, setupLoginDirCmd} + cmd := terminal.NewCommand( + "/", + fmt.Sprintf("/usr/bin/docker exec -t %s /bin/bash -c '"+strings.Join(compositeCommand, "; ")+"'", containerID), + ) + stream, err := cmd.Run(ctx) + if err != nil { + return fmt.Errorf("error running command: %v", err) + } + + if err := stream.Close(); err != nil { + return fmt.Errorf("error closing stream: %v", err) + } + + time.Sleep(1 * time.Second) + + return nil +} From 53fd96d22ec4443ecd98b8b645f7ea0eee6e6d31 Mon Sep 17 00:00:00 2001 From: Jose Garcia Date: Wed, 14 Jul 2021 20:47:06 -0400 Subject: [PATCH 007/290] Some polish and module replacement --- api.go | 1 - client.go | 7 ++----- config.go | 56 ---------------------------------------------------- liveshare.go | 52 ++++++++++++++++++++++++++++++++++++++++++++++++ rpc.go | 15 +------------- server.go | 2 -- 6 files changed, 55 insertions(+), 78 deletions(-) delete mode 100644 config.go diff --git a/api.go b/api.go index a101823df..55b6e6e93 100644 --- a/api.go +++ b/api.go @@ -52,7 +52,6 @@ type workspaceAccessResponse struct { func (a *api) workspaceAccess() (*workspaceAccessResponse, error) { url := fmt.Sprintf("%s/workspace/%s/user", a.serviceURI, a.workspaceID) - fmt.Println(url) req, err := http.NewRequest(http.MethodPut, url, nil) if err != nil { diff --git a/client.go b/client.go index 88c91ba01..0af904b92 100644 --- a/client.go +++ b/client.go @@ -3,7 +3,6 @@ package liveshare import ( "context" "fmt" - "log" "golang.org/x/crypto/ssh" ) @@ -100,11 +99,10 @@ func (c *Client) openStreamingChannel(ctx context.Context, streamName, condition go c.processChannelRequests(ctx, reqs) requestType := fmt.Sprintf("stream-transport-%s", streamID) - acked, err := channel.SendRequest(requestType, true, nil) + _, err = channel.SendRequest(requestType, true, nil) if err != nil { return nil, fmt.Errorf("error sending channel request: %v", err) } - fmt.Println("ACKED: ", acked) return channel, nil } @@ -114,8 +112,7 @@ func (c *Client) processChannelRequests(ctx context.Context, reqs <-chan *ssh.Re select { case req := <-reqs: if req != nil { - fmt.Printf("REQ: %+v\n\n", req) - log.Println("streaming channel requests are not supported") + // TODO(josebalius): Handle } case <-ctx.Done(): break diff --git a/config.go b/config.go deleted file mode 100644 index 74eb5b178..000000000 --- a/config.go +++ /dev/null @@ -1,56 +0,0 @@ -package liveshare - -import ( - "errors" - "strings" -) - -type Option func(configuration *Configuration) error - -func WithWorkspaceID(id string) Option { - return func(configuration *Configuration) error { - configuration.WorkspaceID = id - return nil - } -} - -func WithLiveShareEndpoint(liveShareEndpoint string) Option { - return func(configuration *Configuration) error { - configuration.LiveShareEndpoint = liveShareEndpoint - return nil - } -} - -func WithToken(token string) Option { - return func(configuration *Configuration) error { - configuration.Token = token - return nil - } -} - -type Configuration struct { - WorkspaceID, LiveShareEndpoint, Token string -} - -func NewConfiguration() *Configuration { - return &Configuration{ - LiveShareEndpoint: "https://prod.liveshare.vsengsaas.visualstudio.com", - } -} - -func (c *Configuration) Validate() error { - errs := []string{} - if c.WorkspaceID == "" { - errs = append(errs, "WorkspaceID is required") - } - - if c.Token == "" { - errs = append(errs, "Token is required") - } - - if len(errs) > 0 { - return errors.New(strings.Join(errs, ", ")) - } - - return nil -} diff --git a/liveshare.go b/liveshare.go index 38222957a..3c4be5c05 100644 --- a/liveshare.go +++ b/liveshare.go @@ -1,7 +1,9 @@ package liveshare import ( + "errors" "fmt" + "strings" ) type LiveShare struct { @@ -23,3 +25,53 @@ func New(opts ...Option) (*LiveShare, error) { return &LiveShare{Configuration: configuration}, nil } + +type Option func(configuration *Configuration) error + +func WithWorkspaceID(id string) Option { + return func(configuration *Configuration) error { + configuration.WorkspaceID = id + return nil + } +} + +func WithLiveShareEndpoint(liveShareEndpoint string) Option { + return func(configuration *Configuration) error { + configuration.LiveShareEndpoint = liveShareEndpoint + return nil + } +} + +func WithToken(token string) Option { + return func(configuration *Configuration) error { + configuration.Token = token + return nil + } +} + +type Configuration struct { + WorkspaceID, LiveShareEndpoint, Token string +} + +func NewConfiguration() *Configuration { + return &Configuration{ + LiveShareEndpoint: "https://prod.liveshare.vsengsaas.visualstudio.com", + } +} + +func (c *Configuration) Validate() error { + errs := []string{} + if c.WorkspaceID == "" { + errs = append(errs, "WorkspaceID is required") + } + + if c.Token == "" { + errs = append(errs, "Token is required") + } + + if len(errs) > 0 { + return errors.New(strings.Join(errs, ", ")) + } + + return nil +} diff --git a/rpc.go b/rpc.go index e90f71ba6..de427cda9 100644 --- a/rpc.go +++ b/rpc.go @@ -2,7 +2,6 @@ package liveshare import ( "context" - "encoding/json" "fmt" "io" "sync" @@ -26,8 +25,6 @@ func (r *rpc) connect(ctx context.Context) { } func (r *rpc) do(ctx context.Context, method string, args interface{}, result interface{}) error { - b, _ := json.Marshal(args) - fmt.Println("rpc sent: ", method, string(b)) waiter, err := r.Conn.DispatchCall(ctx, method, args) if err != nil { return fmt.Errorf("error on dispatch call: %v", err) @@ -69,12 +66,6 @@ func (r *rpcHandler) Handle(ctx context.Context, conn *jsonrpc2.Conn, req *jsonr r.mutex.Lock() defer r.mutex.Unlock() - fmt.Println("REQUEST") - fmt.Println("Method:", req.Method) - b, _ := req.MarshalJSON() - fmt.Println(string(b)) - fmt.Println("----") - fmt.Printf("%+v\n\n", r.eventHandlers) if handlers, ok := r.eventHandlers[req.Method]; ok { go func() { for _, handler := range handlers { @@ -88,10 +79,6 @@ func (r *rpcHandler) Handle(ctx context.Context, conn *jsonrpc2.Conn, req *jsonr r.eventHandlers[req.Method] = []chan *jsonrpc2.Request{} }() } else { - fmt.Println("UNHANDLED REQUEST") - fmt.Println("Method:", req.Method) - b, _ := req.MarshalJSON() - fmt.Println(string(b)) - fmt.Println("----") + // TODO(josebalius): Handle } } diff --git a/server.go b/server.go index 65e03d584..b0f3996c9 100644 --- a/server.go +++ b/server.go @@ -36,14 +36,12 @@ type serverSharingResponse struct { func (s *Server) StartSharing(ctx context.Context, protocol string, port int) error { s.port = port - sharingStarted := s.client.rpc.handler.registerEventHandler("serverSharing.sharingStarted") var response serverSharingResponse if err := s.client.rpc.do(ctx, "serverSharing.startSharing", []interface{}{ port, protocol, fmt.Sprintf("http://localhost:%s", strconv.Itoa(port)), }, &response); err != nil { return err } - <-sharingStarted s.streamName = response.StreamName s.streamCondition = response.StreamCondition From a5f558bf2a577bc276d2eff0b139098ee909511b Mon Sep 17 00:00:00 2001 From: Jose Garcia Date: Thu, 15 Jul 2021 08:49:18 -0400 Subject: [PATCH 008/290] Makes secrets work --- cmd/ghcs/ssh.go | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/cmd/ghcs/ssh.go b/cmd/ghcs/ssh.go index 56f22224f..366076516 100644 --- a/cmd/ghcs/ssh.go +++ b/cmd/ghcs/ssh.go @@ -147,6 +147,7 @@ func SSH(sshProfile string) error { return fmt.Errorf("error creating liveshare terminal: %v", err) } + fmt.Println("Preparing SSH...") if sshProfile == "" { containerID, err := getContainerID(ctx, terminal) if err != nil { @@ -176,6 +177,7 @@ func SSH(sshProfile string) error { } }() + fmt.Println("Ready...") if err := connect(ctx, port, sshProfile); err != nil { return fmt.Errorf("error connecting via SSH: %v", err) } @@ -240,9 +242,10 @@ func setupSSH(ctx context.Context, terminal *liveshare.Terminal, containerID, re getUsernameCmd := "GITHUB_USERNAME=\"$(jq .CODESPACE_NAME /workspaces/.codespaces/shared/environment-variables.json -r | cut -f1 -d -)\"" makeSSHDirCmd := "mkdir /home/codespace/.ssh" getUserKeysCmd := "curl --silent --fail \"https://github.com/$(echo $GITHUB_USERNAME).keys\" > /home/codespace/.ssh/authorized_keys" - setupLoginDirCmd := fmt.Sprintf("echo \"cd /workspaces/%v\" > /home/codespace/.bash_profile", repositoryName) + setupSecretsCmd := `cat /workspaces/.codespaces/shared/.user-secrets.json | jq -r ".[] | select (.type==\"EnvironmentVariable\") | .name+\"=\"+.value" > /home/codespace/.zshenv` + setupLoginDirCmd := fmt.Sprintf("echo \"cd /workspaces/%v; exec /bin/zsh;\" > /home/codespace/.bash_profile", repositoryName) - compositeCommand := []string{getUsernameCmd, makeSSHDirCmd, getUserKeysCmd, setupLoginDirCmd} + compositeCommand := []string{getUsernameCmd, makeSSHDirCmd, getUserKeysCmd, setupSecretsCmd, setupLoginDirCmd} cmd := terminal.NewCommand( "/", fmt.Sprintf("/usr/bin/docker exec -t %s /bin/bash -c '"+strings.Join(compositeCommand, "; ")+"'", containerID), From d46420e812aa34e21d667508a32378c9f3e18c90 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mislav=20Marohni=C4=87?= Date: Thu, 15 Jul 2021 16:07:23 +0200 Subject: [PATCH 009/290] Improve ssh command - Ensure parent process exits when `ssh` sub-process is done - Enable connections to `github/github` when `--profile` flag wasn't given --- cmd/ghcs/ssh.go | 41 ++++++++++++++++------------------------- 1 file changed, 16 insertions(+), 25 deletions(-) diff --git a/cmd/ghcs/ssh.go b/cmd/ghcs/ssh.go index 366076516..bdaea644e 100644 --- a/cmd/ghcs/ssh.go +++ b/cmd/ghcs/ssh.go @@ -5,7 +5,6 @@ import ( "context" "errors" "fmt" - "log" "math/rand" "os" "os/exec" @@ -177,40 +176,25 @@ func SSH(sshProfile string) error { } }() + connectDestination := sshProfile + if connectDestination == "" { + connectDestination = fmt.Sprintf("%s@localhost", getSSHUser(codespace)) + } + fmt.Println("Ready...") - if err := connect(ctx, port, sshProfile); err != nil { + if err := connect(ctx, port, connectDestination); err != nil { return fmt.Errorf("error connecting via SSH: %v", err) } return nil } -func connect(ctx context.Context, port int, sshProfile string) error { - var cmd *exec.Cmd - if sshProfile != "" { - cmd = exec.CommandContext(ctx, "ssh", sshProfile, "-p", strconv.Itoa(port), "-C") - } else { - cmd = exec.CommandContext(ctx, "ssh", "codespace@localhost", "-C", "-p", strconv.Itoa(port), "-o", "NoHostAuthenticationForLocalhost=yes") - } - +func connect(ctx context.Context, port int, destination string) error { + cmd := exec.CommandContext(ctx, "ssh", destination, "-C", "-p", strconv.Itoa(port), "-o", "NoHostAuthenticationForLocalhost=yes") cmd.Stdout = os.Stdout cmd.Stdin = os.Stdin cmd.Stderr = os.Stderr - - if err := cmd.Start(); err != nil { - return fmt.Errorf("error running ssh: %v", err) - } - - go func() { - if err := cmd.Wait(); err != nil { - log.Println(fmt.Errorf("error waiting for ssh to finish: %v", err)) - } - }() - - done := make(chan bool) - <-done - - return nil + return cmd.Run() } func getContainerID(ctx context.Context, terminal *liveshare.Terminal) (string, error) { @@ -263,3 +247,10 @@ func setupSSH(ctx context.Context, terminal *liveshare.Terminal, containerID, re return nil } + +func getSSHUser(codespace *api.Codespace) string { + if codespace.RepositoryNWO == "github/github" { + return "root" + } + return "codespace" +} From d506a97419e4f1e2e3f35746bb0426d42fd598de Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mislav=20Marohni=C4=87?= Date: Thu, 15 Jul 2021 16:10:03 +0200 Subject: [PATCH 010/290] Increase ssh command timeout and improve error message - My `github/github` codespace failed to start within 10s - Output more precise error message --- cmd/ghcs/ssh.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cmd/ghcs/ssh.go b/cmd/ghcs/ssh.go index 366076516..90a7f7bfc 100644 --- a/cmd/ghcs/ssh.go +++ b/cmd/ghcs/ssh.go @@ -111,8 +111,8 @@ func SSH(sshProfile string) error { time.Sleep(1 * time.Second) } - if retries == 10 { - return errors.New("Failed to start codespace") + if retries == 30 { + return errors.New("timed out while waiting for the codespace to start") } codespace, err = apiClient.GetCodespace(ctx, token, codespace.OwnerLogin, codespace.Name) From ecea5b821aceec24133ca88c8b2bf1e68a124841 Mon Sep 17 00:00:00 2001 From: Jose Garcia Date: Thu, 15 Jul 2021 14:35:26 +0000 Subject: [PATCH 011/290] Give more time to start --- cmd/ghcs/ssh.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cmd/ghcs/ssh.go b/cmd/ghcs/ssh.go index bdaea644e..ee33260ae 100644 --- a/cmd/ghcs/ssh.go +++ b/cmd/ghcs/ssh.go @@ -110,8 +110,8 @@ func SSH(sshProfile string) error { time.Sleep(1 * time.Second) } - if retries == 10 { - return errors.New("Failed to start codespace") + if retries == 20 { + return errors.New("Timed out waiting for Codespace to start. Try again.") } codespace, err = apiClient.GetCodespace(ctx, token, codespace.OwnerLogin, codespace.Name) From 98bcdd16cfccafd7ef601067287012d8010a150f Mon Sep 17 00:00:00 2001 From: Jose Garcia Date: Fri, 16 Jul 2021 22:34:51 +0000 Subject: [PATCH 012/290] Support for GetSharedServers --- server.go | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/server.go b/server.go index b0f3996c9..68053d9f7 100644 --- a/server.go +++ b/server.go @@ -21,7 +21,7 @@ func (c *Client) NewServer() (*Server, error) { return &Server{client: c}, nil } -type serverSharingResponse struct { +type Port struct { SourcePort int `json:"sourcePort"` DestinationPort int `json:"destinationPort"` SessionName string `json:"sessionName"` @@ -36,7 +36,7 @@ type serverSharingResponse struct { func (s *Server) StartSharing(ctx context.Context, protocol string, port int) error { s.port = port - var response serverSharingResponse + var response Port if err := s.client.rpc.do(ctx, "serverSharing.startSharing", []interface{}{ port, protocol, fmt.Sprintf("http://localhost:%s", strconv.Itoa(port)), }, &response); err != nil { @@ -48,3 +48,14 @@ func (s *Server) StartSharing(ctx context.Context, protocol string, port int) er return nil } + +type Ports []*Port + +func (s *Server) GetSharedServers(ctx context.Context) (Ports, error) { + var response Ports + if err := s.client.rpc.do(ctx, "serverSharing.getSharedServers", []string{}, &response); err != nil { + return nil, err + } + + return response, nil +} From 3c42ab8f7a3eb5068bd0bb5edbdb76cbc10664b3 Mon Sep 17 00:00:00 2001 From: Jose Garcia Date: Fri, 16 Jul 2021 18:45:38 -0400 Subject: [PATCH 013/290] ghcs ports v1 --- api/api.go | 49 +++++++++- cmd/ghcs/ports.go | 233 ++++++++++++++++++++++++++++++++++++++++++++++ cmd/ghcs/ssh.go | 2 +- 3 files changed, 282 insertions(+), 2 deletions(-) create mode 100644 cmd/ghcs/ports.go diff --git a/api/api.go b/api/api.go index b3f7577ed..00ff6b056 100644 --- a/api/api.go +++ b/api/api.go @@ -3,6 +3,7 @@ package api import ( "bytes" "context" + "encoding/base64" "encoding/json" "errors" "fmt" @@ -105,7 +106,7 @@ func (a *API) GetRepository(ctx context.Context, nwo string) (*Repository, error type Codespaces []*Codespace -func (c Codespaces) SortByRecent() { +func (c Codespaces) SortByCreatedAt() { sort.Slice(c, func(i, j int) bool { return c[i].CreatedAt > c[j].CreatedAt }) @@ -397,6 +398,52 @@ func (a *API) DeleteCodespace(ctx context.Context, user *User, token, codespaceN return nil } +type getCodespaceRepositoryContentsResponse struct { + Content string `json:"content"` +} + +func (a *API) GetCodespaceRepositoryContents(ctx context.Context, codespace *Codespace, path string) ([]byte, error) { + req, err := http.NewRequest(http.MethodGet, githubAPI+"/repos/"+codespace.RepositoryNWO+"/contents/"+path, nil) + if err != nil { + return nil, fmt.Errorf("error creating request: %v", err) + } + + q := req.URL.Query() + q.Add("ref", codespace.Branch) + req.URL.RawQuery = q.Encode() + + a.setHeaders(req) + resp, err := a.client.Do(req) + if err != nil { + return nil, fmt.Errorf("error making request: %v", err) + } + + if resp.StatusCode == http.StatusNotFound { + return nil, nil + } + + b, err := ioutil.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("error reading response body: %v", err) + } + + if resp.StatusCode != http.StatusOK { + return nil, a.errorResponse(b) + } + + var response getCodespaceRepositoryContentsResponse + if err := json.Unmarshal(b, &response); err != nil { + return nil, fmt.Errorf("error unmarshaling response: %v", err) + } + + decoded, err := base64.StdEncoding.DecodeString(response.Content) + if err != nil { + return nil, fmt.Errorf("error decoding content: %v", err) + } + + return decoded, nil +} + func (a *API) setHeaders(req *http.Request) { req.Header.Set("Authorization", "Bearer "+a.token) req.Header.Set("Accept", "application/vnd.github.v3+json") diff --git a/cmd/ghcs/ports.go b/cmd/ghcs/ports.go new file mode 100644 index 000000000..9d58fd491 --- /dev/null +++ b/cmd/ghcs/ports.go @@ -0,0 +1,233 @@ +package main + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "os" + "strconv" + "strings" + "time" + + "github.com/AlecAivazis/survey/v2" + "github.com/github/ghcs/api" + "github.com/github/go-liveshare" + "github.com/muhammadmuzzammil1998/jsonc" + "github.com/olekukonko/tablewriter" + "github.com/spf13/cobra" +) + +func NewPortsCmd() *cobra.Command { + portsCmd := &cobra.Command{ + Use: "ports", + Short: "ports", + Long: "ports", + RunE: func(cmd *cobra.Command, args []string) error { + return Ports() + }, + } + + return portsCmd +} + +func init() { + rootCmd.AddCommand(NewPortsCmd()) +} + +func Ports() error { + apiClient := api.New(os.Getenv("GITHUB_TOKEN")) + ctx := context.Background() + + user, err := apiClient.GetUser(ctx) + if err != nil { + return fmt.Errorf("error getting user: %v", err) + } + + codespaces, err := apiClient.ListCodespaces(ctx, user) + if err != nil { + return fmt.Errorf("error getting codespaces: %v", err) + } + + if len(codespaces) == 0 { + fmt.Println("You have no codespaces.") + return nil + } + + codespaces.SortByCreatedAt() + + codespacesByName := make(map[string]*api.Codespace) + codespacesNames := make([]string, 0, len(codespaces)) + for _, codespace := range codespaces { + codespacesByName[codespace.Name] = codespace + codespacesNames = append(codespacesNames, codespace.Name) + } + + portsSurvey := []*survey.Question{ + { + Name: "codespace", + Prompt: &survey.Select{ + Message: "Choose Codespace:", + Options: codespacesNames, + Default: codespacesNames[0], + }, + Validate: survey.Required, + }, + } + + answers := struct { + Codespace string + }{} + if err := survey.Ask(portsSurvey, &answers); err != nil { + return fmt.Errorf("error getting answers: %v", err) + } + + codespace := codespacesByName[answers.Codespace] + devContainerCh := getDevContainer(ctx, apiClient, codespace) + + token, err := apiClient.GetCodespaceToken(ctx, codespace) + if err != nil { + return fmt.Errorf("error getting codespace token: %v", err) + } + + if codespace.Environment.State != api.CodespaceEnvironmentStateAvailable { + fmt.Println("Starting your codespace...") + if err := apiClient.StartCodespace(ctx, token, codespace); err != nil { + return fmt.Errorf("error starting codespace: %v", err) + } + } + + retries := 0 + for codespace.Environment.Connection.SessionID == "" || codespace.Environment.State != api.CodespaceEnvironmentStateAvailable { + if retries > 1 { + if retries%2 == 0 { + fmt.Print(".") + } + + time.Sleep(1 * time.Second) + } + + if retries == 30 { + return errors.New("timed out while waiting for the codespace to start") + } + + codespace, err = apiClient.GetCodespace(ctx, token, codespace.OwnerLogin, codespace.Name) + if err != nil { + return fmt.Errorf("error getting codespace: %v", err) + } + + retries += 1 + } + + if retries >= 2 { + fmt.Print("\n") + } + + fmt.Println("Connecting to your codespace...") + + liveShare, err := liveshare.New( + liveshare.WithWorkspaceID(codespace.Environment.Connection.SessionID), + liveshare.WithToken(codespace.Environment.Connection.SessionToken), + ) + if err != nil { + return fmt.Errorf("error creating live share: %v", err) + } + + liveShareClient := liveShare.NewClient() + if err := liveShareClient.Join(ctx); err != nil { + return fmt.Errorf("error joining liveshare client: %v", err) + } + + fmt.Println("Loading ports...") + ports, err := getPorts(ctx, liveShareClient) + if err != nil { + return fmt.Errorf("error getting ports: %v", err) + } + + devContainerResult := <-devContainerCh + if devContainerResult.Err != nil { + fmt.Println("Failed to get port names: %v", devContainerResult.Err.Error()) + } + + table := tablewriter.NewWriter(os.Stdout) + table.SetHeader([]string{"Label", "Source Port", "Destination Port", "Public", "Browse URL"}) + for _, port := range ports { + sourcePort := strconv.Itoa(port.SourcePort) + var portName string + if devContainerResult.DevContainer != nil { + if attributes, ok := devContainerResult.DevContainer.PortAttributes[sourcePort]; ok { + portName = attributes.Label + } + } + + table.Append([]string{ + portName, + sourcePort, + strconv.Itoa(port.DestinationPort), + strings.ToUpper(strconv.FormatBool(port.IsPublic)), + fmt.Sprintf("https://%s-%s.githubpreview.dev/", codespace.Name, sourcePort), + }) + } + table.Render() + + return nil + +} + +func getPorts(ctx context.Context, liveShareClient *liveshare.Client) (liveshare.Ports, error) { + server, err := liveShareClient.NewServer() + if err != nil { + return nil, fmt.Errorf("error creating server: %v", err) + } + + ports, err := server.GetSharedServers(ctx) + if err != nil { + return nil, fmt.Errorf("error getting shared servers: %v", err) + } + + return ports, nil +} + +type devContainerResult struct { + DevContainer *devContainer + Err error +} + +type devContainer struct { + PortAttributes map[string]portAttribute `json:"portsAttributes"` +} + +type portAttribute struct { + Label string `json:"label"` +} + +func getDevContainer(ctx context.Context, apiClient *api.API, codespace *api.Codespace) <-chan devContainerResult { + ch := make(chan devContainerResult) + go func() { + contents, err := apiClient.GetCodespaceRepositoryContents(ctx, codespace, ".devcontainer/devcontainer.json") + if err != nil { + ch <- devContainerResult{nil, fmt.Errorf("error getting content: %v", err)} + return + } + + if contents == nil { + ch <- devContainerResult{nil, nil} + return + } + + convertedJSON := jsonc.ToJSON(contents) + if !jsonc.Valid(convertedJSON) { + ch <- devContainerResult{nil, errors.New("failed to convert json to standard json")} + return + } + + var container devContainer + if err := json.Unmarshal(convertedJSON, &container); err != nil { + ch <- devContainerResult{nil, fmt.Errorf("error unmarshaling: %v", err)} + return + } + + ch <- devContainerResult{&container, nil} + }() + return ch +} diff --git a/cmd/ghcs/ssh.go b/cmd/ghcs/ssh.go index e8e1cb671..39019054f 100644 --- a/cmd/ghcs/ssh.go +++ b/cmd/ghcs/ssh.go @@ -58,7 +58,7 @@ func SSH(sshProfile string) error { return nil } - codespaces.SortByRecent() + codespaces.SortByCreatedAt() codespacesByName := make(map[string]*api.Codespace) codespacesNames := make([]string, 0, len(codespaces)) From e373c91f8b2121a25eacf2f70f040dab14bff730 Mon Sep 17 00:00:00 2001 From: Jose Garcia Date: Sun, 18 Jul 2021 00:05:13 +0000 Subject: [PATCH 014/290] UpdateSharedServerVisibility API for Server --- server.go | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/server.go b/server.go index 68053d9f7..71ec9d4dd 100644 --- a/server.go +++ b/server.go @@ -59,3 +59,11 @@ func (s *Server) GetSharedServers(ctx context.Context) (Ports, error) { return response, nil } + +func (s *Server) UpdateSharedVisibility(ctx context.Context, port int, public bool) error { + if err := s.client.rpc.do(ctx, "serverSharing.updateSharedServerVisibility", []interface{}{port, public}, nil); err != nil { + return err + } + + return nil +} From 798413848b0c8b211caf1fd96fa3bc5b74baef29 Mon Sep 17 00:00:00 2001 From: Jose Garcia Date: Sat, 17 Jul 2021 20:32:47 -0400 Subject: [PATCH 015/290] Portfowarding private/public/forward now supported --- api/api.go | 4 +- cmd/ghcs/delete.go | 3 +- cmd/ghcs/ports.go | 246 ++++++++++++++++++++---------- cmd/ghcs/ssh.go | 96 ++---------- internal/codespaces/codespaces.go | 110 +++++++++++++ 5 files changed, 286 insertions(+), 173 deletions(-) create mode 100644 internal/codespaces/codespaces.go diff --git a/api/api.go b/api/api.go index 00ff6b056..8bf5e155d 100644 --- a/api/api.go +++ b/api/api.go @@ -177,7 +177,7 @@ type getCodespaceTokenResponse struct { RepositoryToken string `json:"repository_token"` } -func (a *API) GetCodespaceToken(ctx context.Context, codespace *Codespace) (string, error) { +func (a *API) GetCodespaceToken(ctx context.Context, ownerLogin, codespaceName string) (string, error) { reqBody, err := json.Marshal(getCodespaceTokenRequest{true}) if err != nil { return "", fmt.Errorf("error preparing request body: %v", err) @@ -185,7 +185,7 @@ func (a *API) GetCodespaceToken(ctx context.Context, codespace *Codespace) (stri req, err := http.NewRequest( http.MethodPost, - githubAPI+"/vscs_internal/user/"+codespace.OwnerLogin+"/codespaces/"+codespace.Name+"/token", + githubAPI+"/vscs_internal/user/"+ownerLogin+"/codespaces/"+codespaceName+"/token", bytes.NewBuffer(reqBody), ) if err != nil { diff --git a/cmd/ghcs/delete.go b/cmd/ghcs/delete.go index e5cd34a94..a24374a4c 100644 --- a/cmd/ghcs/delete.go +++ b/cmd/ghcs/delete.go @@ -39,8 +39,7 @@ func Delete(codespaceName string) error { return fmt.Errorf("error getting user: %v", err) } - codespace := api.Codespace{OwnerLogin: user.Login, Name: codespaceName} - token, err := apiClient.GetCodespaceToken(ctx, &codespace) + token, err := apiClient.GetCodespaceToken(ctx, user.Login, codespaceName) if err != nil { return fmt.Errorf("error getting codespace token: %v", err) } diff --git a/cmd/ghcs/ports.go b/cmd/ghcs/ports.go index 9d58fd491..7766f230d 100644 --- a/cmd/ghcs/ports.go +++ b/cmd/ghcs/ports.go @@ -8,10 +8,9 @@ import ( "os" "strconv" "strings" - "time" - "github.com/AlecAivazis/survey/v2" "github.com/github/ghcs/api" + "github.com/github/ghcs/internal/codespaces" "github.com/github/go-liveshare" "github.com/muhammadmuzzammil1998/jsonc" "github.com/olekukonko/tablewriter" @@ -28,6 +27,9 @@ func NewPortsCmd() *cobra.Command { }, } + portsCmd.AddCommand(NewPortsPublicCmd()) + portsCmd.AddCommand(NewPortsPrivateCmd()) + portsCmd.AddCommand(NewPortsForwardCmd()) return portsCmd } @@ -44,98 +46,25 @@ func Ports() error { return fmt.Errorf("error getting user: %v", err) } - codespaces, err := apiClient.ListCodespaces(ctx, user) + codespace, err := codespaces.ChooseCodespace(ctx, apiClient, user) if err != nil { - return fmt.Errorf("error getting codespaces: %v", err) + if err == codespaces.ErrNoCodespaces { + fmt.Println(err.Error()) + return nil + } + return fmt.Errorf("error choosing codespace: %v", err) } - if len(codespaces) == 0 { - fmt.Println("You have no codespaces.") - return nil - } - - codespaces.SortByCreatedAt() - - codespacesByName := make(map[string]*api.Codespace) - codespacesNames := make([]string, 0, len(codespaces)) - for _, codespace := range codespaces { - codespacesByName[codespace.Name] = codespace - codespacesNames = append(codespacesNames, codespace.Name) - } - - portsSurvey := []*survey.Question{ - { - Name: "codespace", - Prompt: &survey.Select{ - Message: "Choose Codespace:", - Options: codespacesNames, - Default: codespacesNames[0], - }, - Validate: survey.Required, - }, - } - - answers := struct { - Codespace string - }{} - if err := survey.Ask(portsSurvey, &answers); err != nil { - return fmt.Errorf("error getting answers: %v", err) - } - - codespace := codespacesByName[answers.Codespace] devContainerCh := getDevContainer(ctx, apiClient, codespace) - token, err := apiClient.GetCodespaceToken(ctx, codespace) + token, err := apiClient.GetCodespaceToken(ctx, user.Login, codespace.Name) if err != nil { return fmt.Errorf("error getting codespace token: %v", err) } - if codespace.Environment.State != api.CodespaceEnvironmentStateAvailable { - fmt.Println("Starting your codespace...") - if err := apiClient.StartCodespace(ctx, token, codespace); err != nil { - return fmt.Errorf("error starting codespace: %v", err) - } - } - - retries := 0 - for codespace.Environment.Connection.SessionID == "" || codespace.Environment.State != api.CodespaceEnvironmentStateAvailable { - if retries > 1 { - if retries%2 == 0 { - fmt.Print(".") - } - - time.Sleep(1 * time.Second) - } - - if retries == 30 { - return errors.New("timed out while waiting for the codespace to start") - } - - codespace, err = apiClient.GetCodespace(ctx, token, codespace.OwnerLogin, codespace.Name) - if err != nil { - return fmt.Errorf("error getting codespace: %v", err) - } - - retries += 1 - } - - if retries >= 2 { - fmt.Print("\n") - } - - fmt.Println("Connecting to your codespace...") - - liveShare, err := liveshare.New( - liveshare.WithWorkspaceID(codespace.Environment.Connection.SessionID), - liveshare.WithToken(codespace.Environment.Connection.SessionToken), - ) + liveShareClient, err := codespaces.ConnectToLiveshare(ctx, apiClient, token, codespace) if err != nil { - return fmt.Errorf("error creating live share: %v", err) - } - - liveShareClient := liveShare.NewClient() - if err := liveShareClient.Join(ctx); err != nil { - return fmt.Errorf("error joining liveshare client: %v", err) + return fmt.Errorf("error connecting to liveshare: %v", err) } fmt.Println("Loading ports...") @@ -144,6 +73,11 @@ func Ports() error { return fmt.Errorf("error getting ports: %v", err) } + if len(ports) == 0 { + fmt.Println("This codespace has no open ports") + return nil + } + devContainerResult := <-devContainerCh if devContainerResult.Err != nil { fmt.Println("Failed to get port names: %v", devContainerResult.Err.Error()) @@ -231,3 +165,147 @@ func getDevContainer(ctx context.Context, apiClient *api.API, codespace *api.Cod }() return ch } + +func NewPortsPublicCmd() *cobra.Command { + return &cobra.Command{ + Use: "public", + Short: "public", + Long: "public", + RunE: func(cmd *cobra.Command, args []string) error { + if len(args) < 2 { + return errors.New("[codespace_name] [source] port number are required.") + } + + return updatePortVisibility(args[0], args[1], true) + }, + } +} + +func NewPortsPrivateCmd() *cobra.Command { + return &cobra.Command{ + Use: "private", + Short: "private", + Long: "private", + RunE: func(cmd *cobra.Command, args []string) error { + if len(args) < 2 { + return errors.New("[codespace_name] [source] port number are required.") + } + + return updatePortVisibility(args[0], args[1], false) + }, + } +} + +func updatePortVisibility(codespaceName, sourcePort string, public bool) error { + ctx := context.Background() + apiClient := api.New(os.Getenv("GITHUB_TOKEN")) + + user, err := apiClient.GetUser(ctx) + if err != nil { + return fmt.Errorf("error getting user: %v", err) + } + + token, err := apiClient.GetCodespaceToken(ctx, user.Login, codespaceName) + if err != nil { + return fmt.Errorf("error getting codespace token: %v", err) + } + + codespace, err := apiClient.GetCodespace(ctx, token, user.Login, codespaceName) + if err != nil { + return fmt.Errorf("error getting codespace: %v", err) + } + + liveShareClient, err := codespaces.ConnectToLiveshare(ctx, apiClient, token, codespace) + if err != nil { + return fmt.Errorf("error connecting to liveshare: %v", err) + } + + server, err := liveShareClient.NewServer() + if err != nil { + return fmt.Errorf("error creating server: %v", err) + } + + port, err := strconv.Atoi(sourcePort) + if err != nil { + return fmt.Errorf("error reading port number: %v", err) + } + + if err := server.UpdateSharedVisibility(ctx, port, public); err != nil { + return fmt.Errorf("error update port to public: %v", err) + } + + state := "PUBLIC" + if public == false { + state = "PRIVATE" + } + + fmt.Println(fmt.Sprintf("Port %s is now %s.", sourcePort, state)) + + return nil +} + +func NewPortsForwardCmd() *cobra.Command { + return &cobra.Command{ + Use: "forward", + Short: "forward", + Long: "forward", + RunE: func(cmd *cobra.Command, args []string) error { + if len(args) < 3 { + return errors.New("[codespace_name] [source] [dst] port number are required.") + } + return forwardPort(args[0], args[1], args[2]) + }, + } +} + +func forwardPort(codespaceName, sourcePort, destPort string) error { + ctx := context.Background() + apiClient := api.New(os.Getenv("GITHUB_TOKEN")) + + user, err := apiClient.GetUser(ctx) + if err != nil { + return fmt.Errorf("error getting user: %v", err) + } + + token, err := apiClient.GetCodespaceToken(ctx, user.Login, codespaceName) + if err != nil { + return fmt.Errorf("error getting codespace token: %v", err) + } + + codespace, err := apiClient.GetCodespace(ctx, token, user.Login, codespaceName) + if err != nil { + return fmt.Errorf("error getting codespace: %v", err) + } + + liveShareClient, err := codespaces.ConnectToLiveshare(ctx, apiClient, token, codespace) + if err != nil { + return fmt.Errorf("error connecting to liveshare: %v", err) + } + + server, err := liveShareClient.NewServer() + if err != nil { + return fmt.Errorf("error creating server: %v", err) + } + + sourcePortInt, err := strconv.Atoi(sourcePort) + if err != nil { + return fmt.Errorf("error reading source port: %v", err) + } + + dstPortInt, err := strconv.Atoi(destPort) + if err != nil { + return fmt.Errorf("error reading destination port: %v", err) + } + + if err := server.StartSharing(ctx, "share-"+sourcePort, sourcePortInt); err != nil { + return fmt.Errorf("error sharing source port: %v", err) + } + + fmt.Println("Forwarding port: " + sourcePort + " -> " + destPort) + portForwarder := liveshare.NewLocalPortForwarder(liveShareClient, server, dstPortInt) + if err := portForwarder.Start(ctx); err != nil { + return fmt.Errorf("error forwarding port: %v", err) + } + + return nil +} diff --git a/cmd/ghcs/ssh.go b/cmd/ghcs/ssh.go index 39019054f..50196dd07 100644 --- a/cmd/ghcs/ssh.go +++ b/cmd/ghcs/ssh.go @@ -3,7 +3,6 @@ package main import ( "bufio" "context" - "errors" "fmt" "math/rand" "os" @@ -12,8 +11,8 @@ import ( "strings" "time" - "github.com/AlecAivazis/survey/v2" "github.com/github/ghcs/api" + "github.com/github/ghcs/internal/codespaces" "github.com/github/go-liveshare" "github.com/spf13/cobra" ) @@ -48,97 +47,24 @@ func SSH(sshProfile string) error { return fmt.Errorf("error getting user: %v", err) } - codespaces, err := apiClient.ListCodespaces(ctx, user) + codespace, err := codespaces.ChooseCodespace(ctx, apiClient, user) if err != nil { - return fmt.Errorf("error getting codespaces: %v", err) + if err == codespaces.ErrNoCodespaces { + fmt.Println(err.Error()) + return nil + } + + return fmt.Errorf("error choosing codespace: %v", err) } - if len(codespaces) == 0 { - fmt.Println("You have no codespaces.") - return nil - } - - codespaces.SortByCreatedAt() - - codespacesByName := make(map[string]*api.Codespace) - codespacesNames := make([]string, 0, len(codespaces)) - for _, codespace := range codespaces { - codespacesByName[codespace.Name] = codespace - codespacesNames = append(codespacesNames, codespace.Name) - } - - sshSurvey := []*survey.Question{ - { - Name: "codespace", - Prompt: &survey.Select{ - Message: "Choose Codespace:", - Options: codespacesNames, - Default: codespacesNames[0], - }, - Validate: survey.Required, - }, - } - - answers := struct { - Codespace string - }{} - if err := survey.Ask(sshSurvey, &answers); err != nil { - return fmt.Errorf("error getting answers: %v", err) - } - - codespace := codespacesByName[answers.Codespace] - - token, err := apiClient.GetCodespaceToken(ctx, codespace) + token, err := apiClient.GetCodespaceToken(ctx, user.Login, codespace.Name) if err != nil { return fmt.Errorf("error getting codespace token: %v", err) } - if codespace.Environment.State != api.CodespaceEnvironmentStateAvailable { - fmt.Println("Starting your codespace...") - if err := apiClient.StartCodespace(ctx, token, codespace); err != nil { - return fmt.Errorf("error starting codespace: %v", err) - } - } - - retries := 0 - for codespace.Environment.Connection.SessionID == "" || codespace.Environment.State != api.CodespaceEnvironmentStateAvailable { - if retries > 1 { - if retries%2 == 0 { - fmt.Print(".") - } - - time.Sleep(1 * time.Second) - } - - if retries == 30 { - return errors.New("timed out while waiting for the codespace to start") - } - - codespace, err = apiClient.GetCodespace(ctx, token, codespace.OwnerLogin, codespace.Name) - if err != nil { - return fmt.Errorf("error getting codespace: %v", err) - } - - retries += 1 - } - - if retries >= 2 { - fmt.Print("\n") - } - - fmt.Println("Connecting to your codespace...") - - liveShare, err := liveshare.New( - liveshare.WithWorkspaceID(codespace.Environment.Connection.SessionID), - liveshare.WithToken(codespace.Environment.Connection.SessionToken), - ) + liveShareClient, err := codespaces.ConnectToLiveshare(ctx, apiClient, token, codespace) if err != nil { - return fmt.Errorf("error creating live share: %v", err) - } - - liveShareClient := liveShare.NewClient() - if err := liveShareClient.Join(ctx); err != nil { - return fmt.Errorf("error joining liveshare client: %v", err) + return fmt.Errorf("error connecting to liveshare: %v", err) } terminal, err := liveShareClient.NewTerminal() diff --git a/internal/codespaces/codespaces.go b/internal/codespaces/codespaces.go new file mode 100644 index 000000000..be290fab1 --- /dev/null +++ b/internal/codespaces/codespaces.go @@ -0,0 +1,110 @@ +package codespaces + +import ( + "context" + "errors" + "fmt" + "time" + + "github.com/AlecAivazis/survey/v2" + "github.com/github/ghcs/api" + "github.com/github/go-liveshare" +) + +var ( + ErrNoCodespaces = errors.New("You have no codespaces.") +) + +func ChooseCodespace(ctx context.Context, apiClient *api.API, user *api.User) (*api.Codespace, error) { + codespaces, err := apiClient.ListCodespaces(ctx, user) + if err != nil { + return nil, fmt.Errorf("error getting codespaces: %v", err) + } + + if len(codespaces) == 0 { + return nil, ErrNoCodespaces + } + + codespaces.SortByCreatedAt() + + codespacesByName := make(map[string]*api.Codespace) + codespacesNames := make([]string, 0, len(codespaces)) + for _, codespace := range codespaces { + codespacesByName[codespace.Name] = codespace + codespacesNames = append(codespacesNames, codespace.Name) + } + + sshSurvey := []*survey.Question{ + { + Name: "codespace", + Prompt: &survey.Select{ + Message: "Choose Codespace:", + Options: codespacesNames, + Default: codespacesNames[0], + }, + Validate: survey.Required, + }, + } + + answers := struct { + Codespace string + }{} + if err := survey.Ask(sshSurvey, &answers); err != nil { + return nil, fmt.Errorf("error getting answers: %v", err) + } + + codespace := codespacesByName[answers.Codespace] + return codespace, nil +} + +func ConnectToLiveshare(ctx context.Context, apiClient *api.API, token string, codespace *api.Codespace) (client *liveshare.Client, err error) { + if codespace.Environment.State != api.CodespaceEnvironmentStateAvailable { + fmt.Println("Starting your codespace...") // TODO(josebalius): better way of notifying of events + if err := apiClient.StartCodespace(ctx, token, codespace); err != nil { + return nil, fmt.Errorf("error starting codespace: %v", err) + } + } + + retries := 0 + for codespace.Environment.Connection.SessionID == "" || codespace.Environment.State != api.CodespaceEnvironmentStateAvailable { + if retries > 1 { + if retries%2 == 0 { + fmt.Print(".") + } + + time.Sleep(1 * time.Second) + } + + if retries == 30 { + return nil, errors.New("timed out while waiting for the codespace to start") + } + + codespace, err = apiClient.GetCodespace(ctx, token, codespace.OwnerLogin, codespace.Name) + if err != nil { + return nil, fmt.Errorf("error getting codespace: %v", err) + } + + retries += 1 + } + + if retries >= 2 { + fmt.Print("\n") + } + + fmt.Println("Connecting to your codespace...") + + liveShare, err := liveshare.New( + liveshare.WithWorkspaceID(codespace.Environment.Connection.SessionID), + liveshare.WithToken(codespace.Environment.Connection.SessionToken), + ) + if err != nil { + return nil, fmt.Errorf("error creating live share: %v", err) + } + + liveShareClient := liveShare.NewClient() + if err := liveShareClient.Join(ctx); err != nil { + return nil, fmt.Errorf("error joining liveshare client: %v", err) + } + + return liveShareClient, nil +} From 570a407bace2ff3ace0d8b5fb57c9e56c2a7fb03 Mon Sep 17 00:00:00 2001 From: Jose Garcia Date: Mon, 19 Jul 2021 08:00:51 -0400 Subject: [PATCH 016/290] Fix directive --- cmd/ghcs/ports.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cmd/ghcs/ports.go b/cmd/ghcs/ports.go index 7766f230d..d5b863c82 100644 --- a/cmd/ghcs/ports.go +++ b/cmd/ghcs/ports.go @@ -80,7 +80,7 @@ func Ports() error { devContainerResult := <-devContainerCh if devContainerResult.Err != nil { - fmt.Println("Failed to get port names: %v", devContainerResult.Err.Error()) + fmt.Printf("Failed to get port names: %v\n", devContainerResult.Err.Error()) } table := tablewriter.NewWriter(os.Stdout) From cb29b11ab207d1bb2ef6479d77c9991501e5a675 Mon Sep 17 00:00:00 2001 From: Issy Long Date: Mon, 19 Jul 2021 18:10:15 +0100 Subject: [PATCH 017/290] cmd/ghcs/main: Fail gracefully if `GITHUB_TOKEN` entirely unset MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - I have my GitHub API token in my environment as `HOMEBREW_GITHUB_API_TOKEN`, so with things that need `GITHUB_TOKEN` I have to remember to `export GITHUB_TOKEN=$HOMEBREW_GITHUB_API_TOKEN`. - I didn't for this tool, and got this unfriendly error message: ``` ❯ ghcs list Error: error getting user: Bad credentials Usage: ghcs list [flags] Flags: -h, --help help for list error getting user: Bad credentials ``` - This moves the "do you have a `GITHUB_TOKEN`" question to the very beginning (no guarantees about org SSO access, just a string that exists), erroring out with a nice message if users don't have that envvar set: ``` issyl0 in cetus in ~/repos/github/ghcs/cmd/ghcs on gracefully-fail-if-token-envvar-unset ❯ ./ghcs list The GITHUB_TOKEN environment variable is required. Create a Personal Access Token with org SSO access at https://github.com/settings/tokens/new. issyl0 in cetus in ~/repos/github/ghcs/cmd/ghcs on gracefully-fail-if-token-envvar-unset ❯ export GITHUB_TOKEN=$HOMEBREW_GITHUB_API_TOKEN ❯ ./ghcs list +--------------------------------+--------------------+------------------------------------+----------+---------------------------+ | NAME | REPOSITORY | BRANCH | STATE | CREATED AT | +--------------------------------+--------------------+------------------------------------+----------+---------------------------+ | issyl0-github-cat-ggrpj5fvwvr | github/cat | dependabot/bundler/graphql-1.12.13 | Shutdown | 2021-07-13T12:36:53+01:00 | +--------------------------------+--------------------+------------------------------------+----------+---------------------------+ ``` --- cmd/ghcs/main.go | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/cmd/ghcs/main.go b/cmd/ghcs/main.go index 400f5324c..ce8ec91e6 100644 --- a/cmd/ghcs/main.go +++ b/cmd/ghcs/main.go @@ -22,6 +22,11 @@ var rootCmd = &cobra.Command{ } func Execute() { + if os.Getenv("GITHUB_TOKEN") == "" { + fmt.Println("The GITHUB_TOKEN environment variable is required. Create a Personal Access Token with org SSO access at https://github.com/settings/tokens/new.") + os.Exit(1) + } + if err := rootCmd.Execute(); err != nil { fmt.Fprintln(os.Stderr, err) os.Exit(1) From 4582fed1ccef6bdaa40f17c3a2985f92bedb90b3 Mon Sep 17 00:00:00 2001 From: Issy Long Date: Mon, 19 Jul 2021 18:44:02 +0100 Subject: [PATCH 018/290] cmd/ghcs/main: Add `--version` flag - This is built into Cobra the argument parser. Now `ghcs --version` exists. - When we prepare to bump the version, we need to remember to update this value else the Homebrew formula, GitHub releases and the `ghcs --version` output will be mismatched. - Fixes https://github.com/github/ghcs/issues/16. --- cmd/ghcs/main.go | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/cmd/ghcs/main.go b/cmd/ghcs/main.go index 400f5324c..6270b6fac 100644 --- a/cmd/ghcs/main.go +++ b/cmd/ghcs/main.go @@ -16,9 +16,10 @@ func main() { } var rootCmd = &cobra.Command{ - Use: "ghcs", - Short: "Codespaces", - Long: "Codespaces", + Use: "ghcs", + Short: "Codespaces", + Long: "Codespaces", + Version: "0.5.0", } func Execute() { From 6d5726d78a665643f89514fc678d9ab1ccb1a138 Mon Sep 17 00:00:00 2001 From: Jose Garcia Date: Tue, 20 Jul 2021 11:59:14 +0000 Subject: [PATCH 019/290] Better way to discard requests & close channel/conn on disconnects --- client.go | 15 +-------------- port_forwarder.go | 4 ++-- 2 files changed, 3 insertions(+), 16 deletions(-) diff --git a/client.go b/client.go index 0af904b92..456a2c321 100644 --- a/client.go +++ b/client.go @@ -96,7 +96,7 @@ func (c *Client) openStreamingChannel(ctx context.Context, streamName, condition if err != nil { return nil, fmt.Errorf("error opening ssh channel for transport: %v", err) } - go c.processChannelRequests(ctx, reqs) + go ssh.DiscardRequests(reqs) requestType := fmt.Sprintf("stream-transport-%s", streamID) _, err = channel.SendRequest(requestType, true, nil) @@ -106,16 +106,3 @@ func (c *Client) openStreamingChannel(ctx context.Context, streamName, condition return channel, nil } - -func (c *Client) processChannelRequests(ctx context.Context, reqs <-chan *ssh.Request) { - for { - select { - case req := <-reqs: - if req != nil { - // TODO(josebalius): Handle - } - case <-ctx.Done(): - break - } - } -} diff --git a/port_forwarder.go b/port_forwarder.go index 0ae5e1916..20382c208 100644 --- a/port_forwarder.go +++ b/port_forwarder.go @@ -53,8 +53,8 @@ func (l *LocalPortForwarder) handleConnection(ctx context.Context, conn net.Conn copyConn := func(writer io.Writer, reader io.Reader) { _, err := io.Copy(writer, reader) if err != nil { - log.Println("errrrr copyConn") - log.Println(err) //TODO(josebalius): handle this somehow + channel.Close() + conn.Close() } } From 6642fb520a9fd43a928cf9fcfae0de8e306a00ec Mon Sep 17 00:00:00 2001 From: Jose Garcia Date: Tue, 20 Jul 2021 08:04:34 -0400 Subject: [PATCH 020/290] Better connection handling and simpler ssh setup --- cmd/ghcs/ssh.go | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/cmd/ghcs/ssh.go b/cmd/ghcs/ssh.go index 50196dd07..59ffdb3c1 100644 --- a/cmd/ghcs/ssh.go +++ b/cmd/ghcs/ssh.go @@ -82,6 +82,8 @@ func SSH(sshProfile string) error { if err := setupSSH(ctx, terminal, containerID, codespace.RepositoryName); err != nil { return fmt.Errorf("error creating ssh server: %v", err) } + + fmt.Printf("\n") } server, err := liveShareClient.NewServer() @@ -124,6 +126,7 @@ func connect(ctx context.Context, port int, destination string) error { } func getContainerID(ctx context.Context, terminal *liveshare.Terminal) (string, error) { + fmt.Print(".") cmd := terminal.NewCommand( "/", "/usr/bin/docker ps -aq --filter label=Type=codespaces --filter status=running", @@ -133,14 +136,17 @@ func getContainerID(ctx context.Context, terminal *liveshare.Terminal) (string, return "", fmt.Errorf("error running command: %v", err) } + fmt.Print(".") scanner := bufio.NewScanner(stream) scanner.Scan() + fmt.Print(".") containerID := scanner.Text() if err := scanner.Err(); err != nil { return "", fmt.Errorf("error scanning stream: %v", err) } + fmt.Print(".") if err := stream.Close(); err != nil { return "", fmt.Errorf("error closing stream: %v", err) } @@ -149,13 +155,11 @@ func getContainerID(ctx context.Context, terminal *liveshare.Terminal) (string, } func setupSSH(ctx context.Context, terminal *liveshare.Terminal, containerID, repositoryName string) error { - getUsernameCmd := "GITHUB_USERNAME=\"$(jq .CODESPACE_NAME /workspaces/.codespaces/shared/environment-variables.json -r | cut -f1 -d -)\"" - makeSSHDirCmd := "mkdir /home/codespace/.ssh" - getUserKeysCmd := "curl --silent --fail \"https://github.com/$(echo $GITHUB_USERNAME).keys\" > /home/codespace/.ssh/authorized_keys" - setupSecretsCmd := `cat /workspaces/.codespaces/shared/.user-secrets.json | jq -r ".[] | select (.type==\"EnvironmentVariable\") | .name+\"=\"+.value" > /home/codespace/.zshenv` + setupSecretsCmd := `cp /workspaces/.codespaces/shared/.env /home/codespace/.zshenv` setupLoginDirCmd := fmt.Sprintf("echo \"cd /workspaces/%v; exec /bin/zsh;\" > /home/codespace/.bash_profile", repositoryName) - compositeCommand := []string{getUsernameCmd, makeSSHDirCmd, getUserKeysCmd, setupSecretsCmd, setupLoginDirCmd} + fmt.Print(".") + compositeCommand := []string{setupSecretsCmd, setupLoginDirCmd} cmd := terminal.NewCommand( "/", fmt.Sprintf("/usr/bin/docker exec -t %s /bin/bash -c '"+strings.Join(compositeCommand, "; ")+"'", containerID), @@ -165,6 +169,7 @@ func setupSSH(ctx context.Context, terminal *liveshare.Terminal, containerID, re return fmt.Errorf("error running command: %v", err) } + fmt.Print(".") if err := stream.Close(); err != nil { return fmt.Errorf("error closing stream: %v", err) } From 8faee1e5a951c41a790253b5f60ba09f4b7ab8ab Mon Sep 17 00:00:00 2001 From: Jose Garcia Date: Tue, 20 Jul 2021 08:09:48 -0400 Subject: [PATCH 021/290] Update main.go --- cmd/ghcs/main.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cmd/ghcs/main.go b/cmd/ghcs/main.go index f7685ed25..6297eac03 100644 --- a/cmd/ghcs/main.go +++ b/cmd/ghcs/main.go @@ -19,7 +19,7 @@ var rootCmd = &cobra.Command{ Use: "ghcs", Short: "Codespaces", Long: "Codespaces", - Version: "0.5.0", + Version: "0.5.1", } func Execute() { From e81bee6886ba0998d2faa5b2f47fda6c1eee9f28 Mon Sep 17 00:00:00 2001 From: Jose Garcia Date: Tue, 20 Jul 2021 18:43:43 -0400 Subject: [PATCH 022/290] Doesn't overwrite .zshenv and supports server-port --- cmd/ghcs/ssh.go | 27 +++++++++++++++++++-------- 1 file changed, 19 insertions(+), 8 deletions(-) diff --git a/cmd/ghcs/ssh.go b/cmd/ghcs/ssh.go index 59ffdb3c1..9977928f4 100644 --- a/cmd/ghcs/ssh.go +++ b/cmd/ghcs/ssh.go @@ -19,17 +19,19 @@ import ( func NewSSHCmd() *cobra.Command { var sshProfile string + var sshServerPort int sshCmd := &cobra.Command{ Use: "ssh", Short: "ssh", Long: "ssh", RunE: func(cmd *cobra.Command, args []string) error { - return SSH(sshProfile) + return SSH(sshProfile, sshServerPort) }, } sshCmd.Flags().StringVarP(&sshProfile, "profile", "", "", "SSH Profile") + sshCmd.Flags().IntVarP(&sshServerPort, "server-port", "", 0, "SSH Server Port") return sshCmd } @@ -38,7 +40,7 @@ func init() { rootCmd.AddCommand(NewSSHCmd()) } -func SSH(sshProfile string) error { +func SSH(sshProfile string, sshServerPort int) error { apiClient := api.New(os.Getenv("GITHUB_TOKEN")) ctx := context.Background() @@ -93,6 +95,10 @@ func SSH(sshProfile string) error { rand.Seed(time.Now().Unix()) port := rand.Intn(9999-2000) + 2000 // improve this obviously + if sshServerPort != 0 { + port = sshServerPort + } + if err := server.StartSharing(ctx, "sshd", 2222); err != nil { return fmt.Errorf("error sharing sshd port: %v", err) } @@ -110,15 +116,21 @@ func SSH(sshProfile string) error { } fmt.Println("Ready...") - if err := connect(ctx, port, connectDestination); err != nil { + if err := connect(ctx, port, connectDestination, port == sshServerPort); err != nil { return fmt.Errorf("error connecting via SSH: %v", err) } return nil } -func connect(ctx context.Context, port int, destination string) error { - cmd := exec.CommandContext(ctx, "ssh", destination, "-C", "-p", strconv.Itoa(port), "-o", "NoHostAuthenticationForLocalhost=yes") +func connect(ctx context.Context, port int, destination string, setServerPort bool) error { + cmdArgs := []string{destination, "-C", "-p", strconv.Itoa(port), "-o", "NoHostAuthenticationForLocalhost=yes"} + + if setServerPort { + fmt.Println("Connection Details: ssh " + strings.Join(cmdArgs, " ")) + } + + cmd := exec.CommandContext(ctx, "ssh", cmdArgs...) cmd.Stdout = os.Stdout cmd.Stdin = os.Stdin cmd.Stderr = os.Stderr @@ -155,11 +167,10 @@ func getContainerID(ctx context.Context, terminal *liveshare.Terminal) (string, } func setupSSH(ctx context.Context, terminal *liveshare.Terminal, containerID, repositoryName string) error { - setupSecretsCmd := `cp /workspaces/.codespaces/shared/.env /home/codespace/.zshenv` - setupLoginDirCmd := fmt.Sprintf("echo \"cd /workspaces/%v; exec /bin/zsh;\" > /home/codespace/.bash_profile", repositoryName) + setupBashProfileCmd := fmt.Sprintf(`echo "cd /workspaces/%v; export $(cat /workspaces/.codespaces/shared/.env | xargs); exec /bin/zsh;" > /home/codespace/.bash_profile`, repositoryName) fmt.Print(".") - compositeCommand := []string{setupSecretsCmd, setupLoginDirCmd} + compositeCommand := []string{setupBashProfileCmd} cmd := terminal.NewCommand( "/", fmt.Sprintf("/usr/bin/docker exec -t %s /bin/bash -c '"+strings.Join(compositeCommand, "; ")+"'", containerID), From 7a3e47ff3ef548a887997d21f7867af23de10575 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kristj=C3=A1n=20Oddsson?= Date: Wed, 21 Jul 2021 12:46:44 +0100 Subject: [PATCH 023/290] Update error message link and wording. --- cmd/ghcs/main.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cmd/ghcs/main.go b/cmd/ghcs/main.go index 6297eac03..cb7c8090c 100644 --- a/cmd/ghcs/main.go +++ b/cmd/ghcs/main.go @@ -24,7 +24,7 @@ var rootCmd = &cobra.Command{ func Execute() { if os.Getenv("GITHUB_TOKEN") == "" { - fmt.Println("The GITHUB_TOKEN environment variable is required. Create a Personal Access Token with org SSO access at https://github.com/settings/tokens/new.") + fmt.Println("The GITHUB_TOKEN environment variable is required. Create a Personal Access Token at https://github.com/settings/tokens/new?scopes=repo and make sure to enable SSO for the GitHub organization after creating the token.") os.Exit(1) } From 3e50fff2c9d60d898104c5cd791a06c42199c110 Mon Sep 17 00:00:00 2001 From: Jose Garcia Date: Wed, 21 Jul 2021 10:22:33 -0400 Subject: [PATCH 024/290] X11 support --- cmd/ghcs/ssh.go | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/cmd/ghcs/ssh.go b/cmd/ghcs/ssh.go index 9977928f4..4d6cd3d5b 100644 --- a/cmd/ghcs/ssh.go +++ b/cmd/ghcs/ssh.go @@ -124,13 +124,15 @@ func SSH(sshProfile string, sshServerPort int) error { } func connect(ctx context.Context, port int, destination string, setServerPort bool) error { - cmdArgs := []string{destination, "-C", "-p", strconv.Itoa(port), "-o", "NoHostAuthenticationForLocalhost=yes"} + connectionDetailArgs := []string{destination, "-p", strconv.Itoa(port), "-o", "NoHostAuthenticationForLocalhost=yes"} if setServerPort { - fmt.Println("Connection Details: ssh " + strings.Join(cmdArgs, " ")) + fmt.Println("Connection Details: ssh " + strings.Join(connectionDetailArgs, " ")) } - cmd := exec.CommandContext(ctx, "ssh", cmdArgs...) + cmdArgs := []string{"-X", "-Y", "-C"} // X11, X11Trust, Compression + + cmd := exec.CommandContext(ctx, "ssh", append(cmdArgs, connectionDetailArgs...)...) cmd.Stdout = os.Stdout cmd.Stdin = os.Stdin cmd.Stderr = os.Stderr From c2b136a84f4054c79923c9c06ae9a832577d9432 Mon Sep 17 00:00:00 2001 From: Jose Garcia Date: Wed, 21 Jul 2021 13:28:47 -0400 Subject: [PATCH 025/290] ghcs code command support --- cmd/ghcs/code.go | 64 ++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 64 insertions(+) create mode 100644 cmd/ghcs/code.go diff --git a/cmd/ghcs/code.go b/cmd/ghcs/code.go new file mode 100644 index 000000000..1080fd896 --- /dev/null +++ b/cmd/ghcs/code.go @@ -0,0 +1,64 @@ +package main + +import ( + "context" + "fmt" + "net/url" + "os" + + "github.com/github/ghcs/api" + "github.com/github/ghcs/internal/codespaces" + "github.com/skratchdot/open-golang/open" + "github.com/spf13/cobra" +) + +func NewCodeCmd() *cobra.Command { + return &cobra.Command{ + Use: "code", + Short: "code", + Long: "code", + RunE: func(cmd *cobra.Command, args []string) error { + var codespaceName string + if len(args) > 0 { + codespaceName = args[0] + } + return Code(codespaceName) + }, + } +} + +func init() { + rootCmd.AddCommand(NewCodeCmd()) +} + +func Code(codespaceName string) error { + apiClient := api.New(os.Getenv("GITHUB_TOKEN")) + ctx := context.Background() + + user, err := apiClient.GetUser(ctx) + if err != nil { + return fmt.Errorf("error getting user: %v", err) + } + + if codespaceName == "" { + codespace, err := codespaces.ChooseCodespace(ctx, apiClient, user) + if err != nil { + if err == codespaces.ErrNoCodespaces { + fmt.Println(err.Error()) + return nil + } + return fmt.Errorf("error choosing codespace: %v", err) + } + codespaceName = codespace.Name + } + + if err := open.Run(vscodeProtocolURL(codespaceName)); err != nil { + return fmt.Errorf("error opening vscode URL") + } + + return nil +} + +func vscodeProtocolURL(codespaceName string) string { + return fmt.Sprintf("vscode://github.codespaces/connect?name=%s", url.QueryEscape(codespaceName)) +} From 345e3e1b8af2c51c74fc8129a971495b797f025f Mon Sep 17 00:00:00 2001 From: Jose Garcia Date: Wed, 21 Jul 2021 13:50:19 -0400 Subject: [PATCH 026/290] Update main.go --- cmd/ghcs/main.go | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/cmd/ghcs/main.go b/cmd/ghcs/main.go index cb7c8090c..696975ab5 100644 --- a/cmd/ghcs/main.go +++ b/cmd/ghcs/main.go @@ -7,10 +7,6 @@ import ( "github.com/spf13/cobra" ) -// ghcs create -// ghcs connect -// ghcs delete -// ghcs list func main() { Execute() } @@ -19,7 +15,7 @@ var rootCmd = &cobra.Command{ Use: "ghcs", Short: "Codespaces", Long: "Codespaces", - Version: "0.5.1", + Version: "0.6.0", } func Execute() { From 532ee681657c0c662c2743c69bc5659774434c5f Mon Sep 17 00:00:00 2001 From: Jose Garcia Date: Wed, 21 Jul 2021 14:04:42 -0400 Subject: [PATCH 027/290] Fix ssh command order --- cmd/ghcs/ssh.go | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/cmd/ghcs/ssh.go b/cmd/ghcs/ssh.go index 4d6cd3d5b..6c4385aaf 100644 --- a/cmd/ghcs/ssh.go +++ b/cmd/ghcs/ssh.go @@ -124,15 +124,14 @@ func SSH(sshProfile string, sshServerPort int) error { } func connect(ctx context.Context, port int, destination string, setServerPort bool) error { - connectionDetailArgs := []string{destination, "-p", strconv.Itoa(port), "-o", "NoHostAuthenticationForLocalhost=yes"} + connectionDetailArgs := []string{"-p", strconv.Itoa(port), "-o", "NoHostAuthenticationForLocalhost=yes"} if setServerPort { - fmt.Println("Connection Details: ssh " + strings.Join(connectionDetailArgs, " ")) + fmt.Println("Connection Details: ssh " + destination + " " + strings.Join(connectionDetailArgs, " ")) } - cmdArgs := []string{"-X", "-Y", "-C"} // X11, X11Trust, Compression - - cmd := exec.CommandContext(ctx, "ssh", append(cmdArgs, connectionDetailArgs...)...) + args := []string{destination, "-X", "-Y", "-C"} // X11, X11Trust, Compression + cmd := exec.CommandContext(ctx, "ssh", append(args, connectionDetailArgs...)...) cmd.Stdout = os.Stdout cmd.Stdin = os.Stdin cmd.Stderr = os.Stderr From 0d6926e14bd4248186656b981bc280a20b9ce7bd Mon Sep 17 00:00:00 2001 From: Camilo Garcia La Rotta Date: Wed, 21 Jul 2021 17:41:50 -0400 Subject: [PATCH 028/290] doc: root cmd description --- cmd/ghcs/main.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cmd/ghcs/main.go b/cmd/ghcs/main.go index 696975ab5..c0659484d 100644 --- a/cmd/ghcs/main.go +++ b/cmd/ghcs/main.go @@ -13,8 +13,8 @@ func main() { var rootCmd = &cobra.Command{ Use: "ghcs", - Short: "Codespaces", - Long: "Codespaces", + Short: "Unofficial GitHub Codespaces CLI", + Long: "Unofficial CLI tool to manage and interact with GitHub Codespaces", Version: "0.6.0", } From 7a0a8fa39c517647de4a4a499a757b70d16a5321 Mon Sep 17 00:00:00 2001 From: Camilo Garcia La Rotta Date: Wed, 21 Jul 2021 18:02:50 -0400 Subject: [PATCH 029/290] feat: ghcs delete all --- cmd/ghcs/delete.go | 43 ++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 42 insertions(+), 1 deletion(-) diff --git a/cmd/ghcs/delete.go b/cmd/ghcs/delete.go index a24374a4c..976fa4fd4 100644 --- a/cmd/ghcs/delete.go +++ b/cmd/ghcs/delete.go @@ -13,7 +13,7 @@ import ( func NewDeleteCmd() *cobra.Command { deleteCmd := &cobra.Command{ Use: "delete CODESPACE_NAME", - Short: "delete", + Short: "delete codespaces", Long: "delete", RunE: func(cmd *cobra.Command, args []string) error { if len(args) == 0 { @@ -23,6 +23,17 @@ func NewDeleteCmd() *cobra.Command { }, } + deleteAllCmd := &cobra.Command{ + Use: "all", + Short: "delete all codespaces", + Long: "delete all codespaces for the user with the current token", + RunE: func(cmd *cobra.Command, args []string) error { + return DeleteAll() + }, + } + + deleteCmd.AddCommand(deleteAllCmd) + return deleteCmd } @@ -52,3 +63,33 @@ func Delete(codespaceName string) error { return List() } + +func DeleteAll() error { + apiClient := api.New(os.Getenv("GITHUB_TOKEN")) + ctx := context.Background() + + user, err := apiClient.GetUser(ctx) + if err != nil { + return fmt.Errorf("error getting user: %v", err) + } + + codespaces, err := apiClient.ListCodespaces(ctx, user) + if err != nil { + return fmt.Errorf("error getting codespaces: %v", err) + } + + for _, c := range codespaces { + token, err := apiClient.GetCodespaceToken(ctx, user.Login, c.Name) + if err != nil { + return fmt.Errorf("error getting codespace token: %v", err) + } + + if err := apiClient.DeleteCodespace(ctx, user, token, c.Name); err != nil { + return fmt.Errorf("error deleting codespace: %v", err) + } + + fmt.Printf("Codespace deleted: %s\n", c.Name) + } + + return List() +} From 5ca2fa556270763ee08e2c054b0bd80a8204ea1e Mon Sep 17 00:00:00 2001 From: Camilo Garcia La Rotta Date: Wed, 21 Jul 2021 18:13:36 -0400 Subject: [PATCH 030/290] feat: ghcs delete repo REPO_NAME --- cmd/ghcs/delete.go | 55 +++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 54 insertions(+), 1 deletion(-) diff --git a/cmd/ghcs/delete.go b/cmd/ghcs/delete.go index 976fa4fd4..d748ee3e1 100644 --- a/cmd/ghcs/delete.go +++ b/cmd/ghcs/delete.go @@ -32,7 +32,20 @@ func NewDeleteCmd() *cobra.Command { }, } - deleteCmd.AddCommand(deleteAllCmd) + deleteByRepoCmd := &cobra.Command{ + Use: "repo REPO_NAME", + Short: "delete all codespaces for the repo", + Long: `delete all the codespaces that the user with the current token has in this repo. +This includes all codespaces in all states.`, + RunE: func(cmd *cobra.Command, args []string) error { + if len(args) == 0 { + return errors.New("A Repository name is required.") + } + return DeleteByRepo(args[0]) + }, + } + + deleteCmd.AddCommand(deleteAllCmd, deleteByRepoCmd) return deleteCmd } @@ -93,3 +106,43 @@ func DeleteAll() error { return List() } + +func DeleteByRepo(repo string) error { + apiClient := api.New(os.Getenv("GITHUB_TOKEN")) + ctx := context.Background() + + user, err := apiClient.GetUser(ctx) + if err != nil { + return fmt.Errorf("error getting user: %v", err) + } + + codespaces, err := apiClient.ListCodespaces(ctx, user) + if err != nil { + return fmt.Errorf("error getting codespaces: %v", err) + } + + var deleted bool + for _, c := range codespaces { + if c.RepositoryNWO != repo { + continue + } + deleted = true + + token, err := apiClient.GetCodespaceToken(ctx, user.Login, c.Name) + if err != nil { + return fmt.Errorf("error getting codespace token: %v", err) + } + + if err := apiClient.DeleteCodespace(ctx, user, token, c.Name); err != nil { + return fmt.Errorf("error deleting codespace: %v", err) + } + + fmt.Printf("Codespace deleted: %s\n", c.Name) + } + + if !deleted { + fmt.Printf("No codespace was found for repository: %s\n", repo) + } + + return List() +} From c751e88120baa4f25ec978a092e308d915f95cc3 Mon Sep 17 00:00:00 2001 From: Camilo Garcia La Rotta Date: Wed, 21 Jul 2021 19:56:08 -0400 Subject: [PATCH 031/290] feat: introduce repo, branch and machine flags for ghcs create --- cmd/ghcs/create.go | 44 ++++++++++++++++++++++++++++++++++++-------- 1 file changed, 36 insertions(+), 8 deletions(-) diff --git a/cmd/ghcs/create.go b/cmd/ghcs/create.go index 44bedb5f2..35cd49a2d 100644 --- a/cmd/ghcs/create.go +++ b/cmd/ghcs/create.go @@ -12,17 +12,45 @@ import ( "github.com/spf13/cobra" ) -var createCmd = &cobra.Command{ - Use: "create", - Short: "Create", - Long: "Create", - RunE: func(cmd *cobra.Command, args []string) error { - return Create() - }, +var repo, branch, machine string + +type machineType string + +const ( + basicMachine machineType = "basic" + standardMachine machineType = "standard" + premiumMachine machineType = "premium" + ExtremeMachine machineType = "extreme" +) + +func newCreateCmd() *cobra.Command { + createCmd := &cobra.Command{ + Use: "create", + Short: "Create a codespace", + Long: `Create a codespace for a given repository and branch. +You must also choose the type of machine to use.`, + RunE: func(cmd *cobra.Command, args []string) error { + if machine != "" { + switch machineType(machine) { + case basicMachine, standardMachine, premiumMachine, ExtremeMachine: + break + default: + return fmt.Errorf("invalid machine type: %s", machine) + } + } + return Create() + }, + } + + createCmd.Flags().StringVarP(&repo, "repo", "r", "", "repository name with owner: user/repo") + createCmd.Flags().StringVarP(&branch, "branch", "b", "", "repository branch") + createCmd.Flags().StringVarP(&machine, "machine", "m", "", "hardware specifications for the VM. Can be: basic, standard, premium, extreme") + + return createCmd } func init() { - rootCmd.AddCommand(createCmd) + rootCmd.AddCommand(newCreateCmd()) } var createSurvey = []*survey.Question{ From aab98ccc18fcaa74c7c59bdb52b6b4a11fa3b7e2 Mon Sep 17 00:00:00 2001 From: Camilo Garcia La Rotta Date: Wed, 21 Jul 2021 20:06:05 -0400 Subject: [PATCH 032/290] feat: break out repo and branch surveys --- cmd/ghcs/create.go | 25 +++++++++++++++---------- 1 file changed, 15 insertions(+), 10 deletions(-) diff --git a/cmd/ghcs/create.go b/cmd/ghcs/create.go index 35cd49a2d..cc7436e6f 100644 --- a/cmd/ghcs/create.go +++ b/cmd/ghcs/create.go @@ -53,12 +53,14 @@ func init() { rootCmd.AddCommand(newCreateCmd()) } -var createSurvey = []*survey.Question{ +var repoSurvey = []*survey.Question{ { Name: "repository", Prompt: &survey.Input{Message: "Repository"}, Validate: survey.Required, }, +} +var branchSurvey = []*survey.Question{ { Name: "branch", Prompt: &survey.Input{Message: "Branch"}, @@ -72,16 +74,19 @@ func Create() error { locationCh := getLocation(ctx, apiClient) userCh := getUser(ctx, apiClient) - answers := struct { - Repository string - Branch string - }{} - - if err := survey.Ask(createSurvey, &answers); err != nil { - return fmt.Errorf("error getting answers: %v", err) + if repo == "" { + if err := survey.Ask(repoSurvey, &repo); err != nil { + return fmt.Errorf("error getting repository name: %v", err) + } } - repository, err := apiClient.GetRepository(ctx, answers.Repository) + if branch == "" { + if err := survey.Ask(branchSurvey, &branch); err != nil { + return fmt.Errorf("error getting branch name: %v", err) + } + } + + repository, err := apiClient.GetRepository(ctx, repo) if err != nil { return fmt.Errorf("error getting repository: %v", err) } @@ -136,7 +141,7 @@ func Create() error { sku := skuByName[skuAnswers.SKU] fmt.Println("Creating your codespace...") - codespace, err := apiClient.CreateCodespace(ctx, userResult.User, repository, sku, answers.Branch, locationResult.Location) + codespace, err := apiClient.CreateCodespace(ctx, userResult.User, repository, sku, branch, locationResult.Location) if err != nil { return fmt.Errorf("error creating codespace: %v", err) } From 3db217fef063ce2bc877d6d11f184ca5b4bc696e Mon Sep 17 00:00:00 2001 From: Camilo Garcia La Rotta Date: Wed, 21 Jul 2021 20:27:22 -0400 Subject: [PATCH 033/290] feat: make sku survey optional --- api/api.go | 4 +-- cmd/ghcs/create.go | 73 +++++++++++++++++++--------------------------- 2 files changed, 32 insertions(+), 45 deletions(-) diff --git a/api/api.go b/api/api.go index 8bf5e155d..bf16260ae 100644 --- a/api/api.go +++ b/api/api.go @@ -341,8 +341,8 @@ type createCodespaceRequest struct { SkuName string `json:"sku_name"` } -func (a *API) CreateCodespace(ctx context.Context, user *User, repository *Repository, sku *Sku, branch, location string) (*Codespace, error) { - requestBody, err := json.Marshal(createCodespaceRequest{repository.ID, branch, location, sku.Name}) +func (a *API) CreateCodespace(ctx context.Context, user *User, repository *Repository, sku, branch, location string) (*Codespace, error) { + requestBody, err := json.Marshal(createCodespaceRequest{repository.ID, branch, location, sku}) if err != nil { return nil, fmt.Errorf("error marshaling request: %v", err) } diff --git a/cmd/ghcs/create.go b/cmd/ghcs/create.go index cc7436e6f..0b0d9ebd3 100644 --- a/cmd/ghcs/create.go +++ b/cmd/ghcs/create.go @@ -14,15 +14,6 @@ import ( var repo, branch, machine string -type machineType string - -const ( - basicMachine machineType = "basic" - standardMachine machineType = "standard" - premiumMachine machineType = "premium" - ExtremeMachine machineType = "extreme" -) - func newCreateCmd() *cobra.Command { createCmd := &cobra.Command{ Use: "create", @@ -30,21 +21,13 @@ func newCreateCmd() *cobra.Command { Long: `Create a codespace for a given repository and branch. You must also choose the type of machine to use.`, RunE: func(cmd *cobra.Command, args []string) error { - if machine != "" { - switch machineType(machine) { - case basicMachine, standardMachine, premiumMachine, ExtremeMachine: - break - default: - return fmt.Errorf("invalid machine type: %s", machine) - } - } return Create() }, } createCmd.Flags().StringVarP(&repo, "repo", "r", "", "repository name with owner: user/repo") createCmd.Flags().StringVarP(&branch, "branch", "b", "", "repository branch") - createCmd.Flags().StringVarP(&machine, "machine", "m", "", "hardware specifications for the VM. Can be: basic, standard, premium, extreme") + createCmd.Flags().StringVarP(&machine, "machine", "m", "", "hardware specifications for the VM") return createCmd } @@ -111,37 +94,41 @@ func Create() error { return nil } - skuNames := make([]string, 0, len(skus)) - skuByName := make(map[string]*api.Sku) - for _, sku := range skus { - nameParts := camelcase.Split(sku.Name) - machineName := strings.Title(strings.ToLower(nameParts[0])) - skuName := fmt.Sprintf("%s - %s", machineName, sku.DisplayName) - skuNames = append(skuNames, skuName) - skuByName[skuName] = sku - } + if machine == "" { + skuNames := make([]string, 0, len(skus)) + skuByName := make(map[string]*api.Sku) + for _, sku := range skus { + nameParts := camelcase.Split(sku.Name) + machineName := strings.Title(strings.ToLower(nameParts[0])) + skuName := fmt.Sprintf("%s - %s", machineName, sku.DisplayName) + skuNames = append(skuNames, skuName) + skuByName[skuName] = sku + } - skuSurvey := []*survey.Question{ - { - Name: "sku", - Prompt: &survey.Select{ - Message: "Choose Machine Type:", - Options: skuNames, - Default: skuNames[0], + skuSurvey := []*survey.Question{ + { + Name: "sku", + Prompt: &survey.Select{ + Message: "Choose Machine Type:", + Options: skuNames, + Default: skuNames[0], + }, + Validate: survey.Required, }, - Validate: survey.Required, - }, + } + + skuAnswers := struct{ SKU string }{} + if err := survey.Ask(skuSurvey, &skuAnswers); err != nil { + return fmt.Errorf("error getting SKU: %v", err) + } + + sku := skuByName[skuAnswers.SKU] + machine = sku.Name } - skuAnswers := struct{ SKU string }{} - if err := survey.Ask(skuSurvey, &skuAnswers); err != nil { - return fmt.Errorf("error getting SKU: %v", err) - } - - sku := skuByName[skuAnswers.SKU] fmt.Println("Creating your codespace...") - codespace, err := apiClient.CreateCodespace(ctx, userResult.User, repository, sku, branch, locationResult.Location) + codespace, err := apiClient.CreateCodespace(ctx, userResult.User, repository, machine, branch, locationResult.Location) if err != nil { return fmt.Errorf("error creating codespace: %v", err) } From 7332aa428c4db7b87c4280063b89a4d10763cb3c Mon Sep 17 00:00:00 2001 From: Jose Garcia Date: Thu, 22 Jul 2021 00:45:45 +0000 Subject: [PATCH 034/290] Large refactor and solidifying of APIs before tests --- api.go | 130 ---------------------------------------------- client.go | 50 +++++++++++------- client_test.go | 24 +++++++++ connection.go | 43 +++++++++++++++ liveshare.go | 77 --------------------------- port_forwarder.go | 10 ++-- rpc.go | 2 - server.go | 8 +-- session.go | 60 --------------------- socket.go | 105 +++++++++++++++++++++++++++++++++++++ ssh.go | 16 +++--- terminal.go | 8 +-- websocket.go | 105 ------------------------------------- 13 files changed, 224 insertions(+), 414 deletions(-) delete mode 100644 api.go create mode 100644 client_test.go create mode 100644 connection.go delete mode 100644 liveshare.go delete mode 100644 session.go create mode 100644 socket.go delete mode 100644 websocket.go diff --git a/api.go b/api.go deleted file mode 100644 index 55b6e6e93..000000000 --- a/api.go +++ /dev/null @@ -1,130 +0,0 @@ -package liveshare - -import ( - "encoding/json" - "fmt" - "io/ioutil" - "net/http" - "strings" -) - -type api struct { - client *Client - httpClient *http.Client - serviceURI string - workspaceID string -} - -func newAPI(client *Client) *api { - serviceURI := client.liveShare.Configuration.LiveShareEndpoint - if !strings.HasSuffix(client.liveShare.Configuration.LiveShareEndpoint, "/") { - serviceURI = client.liveShare.Configuration.LiveShareEndpoint + "/" - } - - if !strings.Contains(serviceURI, "api/v1.2") { - serviceURI = serviceURI + "api/v1.2" - } - - serviceURI = strings.TrimSuffix(serviceURI, "/") - - return &api{client, &http.Client{}, serviceURI, strings.ToUpper(client.liveShare.Configuration.WorkspaceID)} -} - -type workspaceAccessResponse struct { - SessionToken string `json:"sessionToken"` - CreatedAt string `json:"createdAt"` - UpdatedAt string `json:"updatedAt"` - Name string `json:"name"` - OwnerID string `json:"ownerId"` - JoinLink string `json:"joinLink"` - ConnectLinks []string `json:"connectLinks"` - RelayLink string `json:"relayLink"` - RelaySas string `json:"relaySas"` - HostPublicKeys []string `json:"hostPublicKeys"` - ConversationID string `json:"conversationId"` - AssociatedUserIDs map[string]string `json:"associatedUserIds"` - AreAnonymousGuestsAllowed bool `json:"areAnonymousGuestsAllowed"` - IsHostConnected bool `json:"isHostConnected"` - ExpiresAt string `json:"expiresAt"` - InvitationLinks []string `json:"invitationLinks"` - ID string `json:"id"` -} - -func (a *api) workspaceAccess() (*workspaceAccessResponse, error) { - url := fmt.Sprintf("%s/workspace/%s/user", a.serviceURI, a.workspaceID) - - req, err := http.NewRequest(http.MethodPut, url, nil) - if err != nil { - return nil, fmt.Errorf("error creating request: %v", err) - } - - a.setDefaultHeaders(req) - resp, err := a.httpClient.Do(req) - if err != nil { - return nil, fmt.Errorf("error making request: %v", err) - } - - b, err := ioutil.ReadAll(resp.Body) - if err != nil { - return nil, fmt.Errorf("error reading response body: %v", err) - } - - var response workspaceAccessResponse - if err := json.Unmarshal(b, &response); err != nil { - return nil, fmt.Errorf("error unmarshaling response into json: %v", err) - } - - return &response, nil -} - -func (a *api) setDefaultHeaders(req *http.Request) { - req.Header.Set("Authorization", "Bearer "+a.client.liveShare.Configuration.Token) - req.Header.Set("Cache-Control", "no-cache") - req.Header.Set("Content-Type", "application/json") -} - -type workspaceInfoResponse struct { - CreatedAt string `json:"createdAt"` - UpdatedAt string `json:"updatedAt"` - Name string `json:"name"` - OwnerID string `json:"ownerId"` - JoinLink string `json:"joinLink"` - ConnectLinks []string `json:"connectLinks"` - RelayLink string `json:"relayLink"` - RelaySas string `json:"relaySas"` - HostPublicKeys []string `json:"hostPublicKeys"` - ConversationID string `json:"conversationId"` - AssociatedUserIDs map[string]string - AreAnonymousGuestsAllowed bool `json:"areAnonymousGuestsAllowed"` - IsHostConnected bool `json:"isHostConnected"` - ExpiresAt string `json:"expiresAt"` - InvitationLinks []string `json:"invitationLinks"` - ID string `json:"id"` -} - -func (a *api) workspaceInfo() (*workspaceInfoResponse, error) { - url := fmt.Sprintf("%s/workspace/%s", a.serviceURI, a.workspaceID) - - req, err := http.NewRequest(http.MethodGet, url, nil) - if err != nil { - return nil, fmt.Errorf("error creating request: %v", err) - } - - a.setDefaultHeaders(req) - resp, err := a.httpClient.Do(req) - if err != nil { - return nil, fmt.Errorf("error making request: %v", err) - } - - b, err := ioutil.ReadAll(resp.Body) - if err != nil { - return nil, fmt.Errorf("error reading response body: %v", err) - } - - var response workspaceInfoResponse - if err := json.Unmarshal(b, &response); err != nil { - return nil, fmt.Errorf("error unmarshaling response into json: %v", err) - } - - return &response, nil -} diff --git a/client.go b/client.go index 456a2c321..1ad90b336 100644 --- a/client.go +++ b/client.go @@ -8,31 +8,44 @@ import ( ) type Client struct { - liveShare *LiveShare - session *session + connection Connection + sshSession *sshSession rpc *rpc } -// NewClient is a function ... -func (l *LiveShare) NewClient() *Client { - return &Client{liveShare: l} +type ClientOption func(*Client) error + +func NewClient(opts ...ClientOption) (*Client, error) { + client := new(Client) + + for _, o := range opts { + if err := o(client); err != nil { + return nil, err + } + } + + return client, nil +} + +func WithConnection(connection Connection) ClientOption { + return func(c *Client) error { + if err := connection.validate(); err != nil { + return err + } + + c.connection = connection + return nil + } } func (c *Client) Join(ctx context.Context) (err error) { - api := newAPI(c) - - c.session = newSession(api) - if err := c.session.init(ctx); err != nil { - return fmt.Errorf("error creating session: %v", err) - } - - websocket := newWebsocket(c.session) - if err := websocket.connect(ctx); err != nil { + clientSocket := newSocket(c.connection) + if err := clientSocket.connect(ctx); err != nil { return fmt.Errorf("error connecting websocket: %v", err) } - c.sshSession = newSSH(c.session, websocket) + c.sshSession = newSSH(c.connection.SessionToken, clientSocket) if err := c.sshSession.connect(ctx); err != nil { return fmt.Errorf("error connecting to ssh session: %v", err) } @@ -69,9 +82,9 @@ type joinWorkspaceResult struct { func (c *Client) joinWorkspace(ctx context.Context) (*joinWorkspaceResult, error) { args := joinWorkspaceArgs{ - ID: c.session.workspaceInfo.ID, + ID: c.connection.SessionID, ConnectionMode: "local", - JoiningUserSessionToken: c.session.workspaceAccess.SessionToken, + JoiningUserSessionToken: c.connection.SessionToken, ClientCapabilities: clientCapabilities{ IsNonInteractive: false, }, @@ -99,8 +112,7 @@ func (c *Client) openStreamingChannel(ctx context.Context, streamName, condition go ssh.DiscardRequests(reqs) requestType := fmt.Sprintf("stream-transport-%s", streamID) - _, err = channel.SendRequest(requestType, true, nil) - if err != nil { + if _, err = channel.SendRequest(requestType, true, nil); err != nil { return nil, fmt.Errorf("error sending channel request: %v", err) } diff --git a/client_test.go b/client_test.go new file mode 100644 index 000000000..8d118974d --- /dev/null +++ b/client_test.go @@ -0,0 +1,24 @@ +package liveshare + +import ( + "testing" +) + +func TestClientJoin(t *testing.T) { + // connection := Connection{ + // SessionID: "session-id", + // SessionToken: "session-token", + // RelayEndpoint: "relay-endpoint", + // RelaySAS: "relay-sas", + // } + + // client, err := NewClient(WithConnection(connection)) + // if err != nil { + // t.Errorf("error creating client: %v", err) + // } + + // ctx := context.Background() + // if err := client.Join(ctx); err != nil { + // t.Errorf("error joining client: %v", err) + // } +} diff --git a/connection.go b/connection.go new file mode 100644 index 000000000..a97935d3b --- /dev/null +++ b/connection.go @@ -0,0 +1,43 @@ +package liveshare + +import ( + "errors" + "net/url" + "strings" +) + +type Connection struct { + SessionID string `json:"sessionId"` + SessionToken string `json:"sessionToken"` + RelaySAS string `json:"relaySas"` + RelayEndpoint string `json:"relayEndpoint"` +} + +func (r Connection) validate() error { + if r.SessionID == "" { + return errors.New("connection sessionID is required") + } + + if r.SessionToken == "" { + return errors.New("connection sessionToken is required") + } + + if r.RelaySAS == "" { + return errors.New("connection relaySas is required") + } + + if r.RelayEndpoint == "" { + return errors.New("connection relayEndpoint is required") + } + + return nil +} + +func (r Connection) uri(action string) string { + sas := url.QueryEscape(r.RelaySAS) + uri := r.RelayEndpoint + uri = strings.Replace(uri, "sb:", "wss:", -1) + uri = strings.Replace(uri, ".net/", ".net:443/$hc/", 1) + uri = uri + "?sb-hc-action=" + action + "&sb-hc-token=" + sas + return uri +} diff --git a/liveshare.go b/liveshare.go deleted file mode 100644 index 3c4be5c05..000000000 --- a/liveshare.go +++ /dev/null @@ -1,77 +0,0 @@ -package liveshare - -import ( - "errors" - "fmt" - "strings" -) - -type LiveShare struct { - Configuration *Configuration -} - -func New(opts ...Option) (*LiveShare, error) { - configuration := NewConfiguration() - - for _, o := range opts { - if err := o(configuration); err != nil { - return nil, fmt.Errorf("error configuring liveshare: %v", err) - } - } - - if err := configuration.Validate(); err != nil { - return nil, fmt.Errorf("error validating configuration: %v", err) - } - - return &LiveShare{Configuration: configuration}, nil -} - -type Option func(configuration *Configuration) error - -func WithWorkspaceID(id string) Option { - return func(configuration *Configuration) error { - configuration.WorkspaceID = id - return nil - } -} - -func WithLiveShareEndpoint(liveShareEndpoint string) Option { - return func(configuration *Configuration) error { - configuration.LiveShareEndpoint = liveShareEndpoint - return nil - } -} - -func WithToken(token string) Option { - return func(configuration *Configuration) error { - configuration.Token = token - return nil - } -} - -type Configuration struct { - WorkspaceID, LiveShareEndpoint, Token string -} - -func NewConfiguration() *Configuration { - return &Configuration{ - LiveShareEndpoint: "https://prod.liveshare.vsengsaas.visualstudio.com", - } -} - -func (c *Configuration) Validate() error { - errs := []string{} - if c.WorkspaceID == "" { - errs = append(errs, "WorkspaceID is required") - } - - if c.Token == "" { - errs = append(errs, "Token is required") - } - - if len(errs) > 0 { - return errors.New(strings.Join(errs, ", ")) - } - - return nil -} diff --git a/port_forwarder.go b/port_forwarder.go index 20382c208..1227493b2 100644 --- a/port_forwarder.go +++ b/port_forwarder.go @@ -11,18 +11,18 @@ import ( "golang.org/x/crypto/ssh" ) -type LocalPortForwarder struct { +type PortForwarder struct { client *Client server *Server port int channels []ssh.Channel } -func NewLocalPortForwarder(client *Client, server *Server, port int) *LocalPortForwarder { - return &LocalPortForwarder{client, server, port, []ssh.Channel{}} +func NewPortForwarder(client *Client, server *Server, port int) *PortForwarder { + return &PortForwarder{client, server, port, []ssh.Channel{}} } -func (l *LocalPortForwarder) Start(ctx context.Context) error { +func (l *PortForwarder) Start(ctx context.Context) error { ln, err := net.Listen("tcp", ":"+strconv.Itoa(l.port)) if err != nil { return fmt.Errorf("error listening on tcp port: %v", err) @@ -42,7 +42,7 @@ func (l *LocalPortForwarder) Start(ctx context.Context) error { return nil } -func (l *LocalPortForwarder) handleConnection(ctx context.Context, conn net.Conn) { +func (l *PortForwarder) handleConnection(ctx context.Context, conn net.Conn) { channel, err := l.client.openStreamingChannel(ctx, l.server.streamName, l.server.streamCondition) if err != nil { log.Println("errrr handle Connect") diff --git a/rpc.go b/rpc.go index de427cda9..d40046471 100644 --- a/rpc.go +++ b/rpc.go @@ -78,7 +78,5 @@ func (r *rpcHandler) Handle(ctx context.Context, conn *jsonrpc2.Conn, req *jsonr r.eventHandlers[req.Method] = []chan *jsonrpc2.Request{} }() - } else { - // TODO(josebalius): Handle } } diff --git a/server.go b/server.go index 71ec9d4dd..6f17d5ac5 100644 --- a/server.go +++ b/server.go @@ -13,12 +13,12 @@ type Server struct { streamName, streamCondition string } -func (c *Client) NewServer() (*Server, error) { - if !c.hasJoined() { - return nil, errors.New("LiveShareClient must join before creating server") +func NewServer(client *Client) (*Server, error) { + if !client.hasJoined() { + return nil, errors.New("client must join before creating server") } - return &Server{client: c}, nil + return &Server{client: client}, nil } type Port struct { diff --git a/session.go b/session.go deleted file mode 100644 index d0492a10e..000000000 --- a/session.go +++ /dev/null @@ -1,60 +0,0 @@ -package liveshare - -import ( - "context" - "fmt" - "net/url" - "strings" - - "golang.org/x/sync/errgroup" -) - -type session struct { - api *api - - workspaceAccess *workspaceAccessResponse - workspaceInfo *workspaceInfoResponse -} - -func newSession(api *api) *session { - return &session{api: api} -} - -func (s *session) init(ctx context.Context) error { - g, ctx := errgroup.WithContext(ctx) - - g.Go(func() error { - workspaceAccess, err := s.api.workspaceAccess() - if err != nil { - return fmt.Errorf("error getting workspace access: %v", err) - } - s.workspaceAccess = workspaceAccess - return nil - }) - - g.Go(func() error { - workspaceInfo, err := s.api.workspaceInfo() - if err != nil { - return fmt.Errorf("error getting workspace info: %v", err) - } - s.workspaceInfo = workspaceInfo - return nil - }) - - if err := g.Wait(); err != nil { - return err - } - - return nil -} - -// Reference: -// https://github.com/Azure/azure-relay-node/blob/7b57225365df3010163bf4b9e640868a02737eb6/hyco-ws/index.js#L107-L137 -func (s *session) relayURI(action string) string { - relaySas := url.QueryEscape(s.workspaceAccess.RelaySas) - relayURI := s.workspaceAccess.RelayLink - relayURI = strings.Replace(relayURI, "sb:", "wss:", -1) - relayURI = strings.Replace(relayURI, ".net/", ".net:443/$hc/", 1) - relayURI = relayURI + "?sb-hc-action=" + action + "&sb-hc-token=" + relaySas - return relayURI -} diff --git a/socket.go b/socket.go new file mode 100644 index 000000000..c3f75b9db --- /dev/null +++ b/socket.go @@ -0,0 +1,105 @@ +package liveshare + +import ( + "context" + "errors" + "io" + "net" + "sync" + "time" + + "github.com/gorilla/websocket" +) + +type socket struct { + addr string + conn *websocket.Conn + readMutex sync.Mutex + writeMutex sync.Mutex + reader io.Reader +} + +func newSocket(clientConn Connection) *socket { + return &socket{addr: clientConn.uri("connect")} +} + +func (s *socket) connect(ctx context.Context) error { + ws, _, err := websocket.DefaultDialer.Dial(s.addr, nil) + if err != nil { + return err + } + s.conn = ws + return nil +} + +func (s *socket) Read(b []byte) (int, error) { + s.readMutex.Lock() + defer s.readMutex.Unlock() + + if s.reader == nil { + messageType, reader, err := s.conn.NextReader() + if err != nil { + return 0, err + } + + if messageType != websocket.BinaryMessage { + return 0, errors.New("unexpected websocket message type") + } + + s.reader = reader + } + + bytesRead, err := s.reader.Read(b) + if err != nil { + s.reader = nil + + if err == io.EOF { + err = nil + } + } + + return bytesRead, err +} + +func (s *socket) Write(b []byte) (int, error) { + s.writeMutex.Lock() + defer s.writeMutex.Unlock() + + nextWriter, err := s.conn.NextWriter(websocket.BinaryMessage) + if err != nil { + return 0, err + } + + bytesWritten, err := nextWriter.Write(b) + nextWriter.Close() + + return bytesWritten, err +} + +func (s *socket) Close() error { + return s.conn.Close() +} + +func (s *socket) LocalAddr() net.Addr { + return s.conn.LocalAddr() +} + +func (s *socket) RemoteAddr() net.Addr { + return s.conn.RemoteAddr() +} + +func (s *socket) SetDeadline(t time.Time) error { + if err := s.SetReadDeadline(t); err != nil { + return err + } + + return s.SetWriteDeadline(t) +} + +func (s *socket) SetReadDeadline(t time.Time) error { + return s.conn.SetReadDeadline(t) +} + +func (s *socket) SetWriteDeadline(t time.Time) error { + return s.conn.SetWriteDeadline(t) +} diff --git a/ssh.go b/ssh.go index 9ae32ed7c..3ea2d2777 100644 --- a/ssh.go +++ b/ssh.go @@ -12,22 +12,22 @@ import ( type sshSession struct { *ssh.Session - session *session - socket net.Conn - conn ssh.Conn - reader io.Reader - writer io.Writer + token string + socket net.Conn + conn ssh.Conn + reader io.Reader + writer io.Writer } -func newSSH(session *session, socket net.Conn) *sshSession { - return &sshSession{session: session, socket: socket} +func newSSH(token string, socket net.Conn) *sshSession { + return &sshSession{token: token, socket: socket} } func (s *sshSession) connect(ctx context.Context) error { clientConfig := ssh.ClientConfig{ User: "", Auth: []ssh.AuthMethod{ - ssh.Password(s.session.workspaceAccess.SessionToken), + ssh.Password(s.token), }, HostKeyAlgorithms: []string{"rsa-sha2-512", "rsa-sha2-256"}, HostKeyCallback: ssh.InsecureIgnoreHostKey(), diff --git a/terminal.go b/terminal.go index 631e75912..1621559a1 100644 --- a/terminal.go +++ b/terminal.go @@ -13,13 +13,13 @@ type Terminal struct { client *Client } -func (c *Client) NewTerminal() (*Terminal, error) { - if !c.hasJoined() { - return nil, errors.New("LiveShareClient must join before creating terminal") +func NewTerminal(client *Client) (*Terminal, error) { + if !client.hasJoined() { + return nil, errors.New("client must join before creating terminal") } return &Terminal{ - client: c, + client: client, }, nil } diff --git a/websocket.go b/websocket.go deleted file mode 100644 index ae163e6e2..000000000 --- a/websocket.go +++ /dev/null @@ -1,105 +0,0 @@ -package liveshare - -import ( - "context" - "errors" - "io" - "net" - "sync" - "time" - - gorillawebsocket "github.com/gorilla/websocket" -) - -type websocket struct { - session *session - conn *gorillawebsocket.Conn - readMutex sync.Mutex - writeMutex sync.Mutex - reader io.Reader -} - -func newWebsocket(session *session) *websocket { - return &websocket{session: session} -} - -func (w *websocket) connect(ctx context.Context) error { - ws, _, err := gorillawebsocket.DefaultDialer.Dial(w.session.relayURI("connect"), nil) - if err != nil { - return err - } - w.conn = ws - return nil -} - -func (w *websocket) Read(b []byte) (int, error) { - w.readMutex.Lock() - defer w.readMutex.Unlock() - - if w.reader == nil { - messageType, reader, err := w.conn.NextReader() - if err != nil { - return 0, err - } - - if messageType != gorillawebsocket.BinaryMessage { - return 0, errors.New("unexpected websocket message type") - } - - w.reader = reader - } - - bytesRead, err := w.reader.Read(b) - if err != nil { - w.reader = nil - - if err == io.EOF { - err = nil - } - } - - return bytesRead, err -} - -func (w *websocket) Write(b []byte) (int, error) { - w.writeMutex.Lock() - defer w.writeMutex.Unlock() - - nextWriter, err := w.conn.NextWriter(gorillawebsocket.BinaryMessage) - if err != nil { - return 0, err - } - - bytesWritten, err := nextWriter.Write(b) - nextWriter.Close() - - return bytesWritten, err -} - -func (w *websocket) Close() error { - return w.conn.Close() -} - -func (w *websocket) LocalAddr() net.Addr { - return w.conn.LocalAddr() -} - -func (w *websocket) RemoteAddr() net.Addr { - return w.conn.RemoteAddr() -} - -func (w *websocket) SetDeadline(t time.Time) error { - if err := w.SetReadDeadline(t); err != nil { - return err - } - - return w.SetWriteDeadline(t) -} - -func (w *websocket) SetReadDeadline(t time.Time) error { - return w.conn.SetReadDeadline(t) -} - -func (w *websocket) SetWriteDeadline(t time.Time) error { - return w.conn.SetWriteDeadline(t) -} From a68cda14698887808309dda7eaa0f11ed7c51829 Mon Sep 17 00:00:00 2001 From: Camilo Garcia La Rotta Date: Wed, 21 Jul 2021 20:54:18 -0400 Subject: [PATCH 035/290] refactor: break down Create() into smaller funcs --- cmd/ghcs/create.go | 153 +++++++++++++++++++++++++++------------------ 1 file changed, 93 insertions(+), 60 deletions(-) diff --git a/cmd/ghcs/create.go b/cmd/ghcs/create.go index 0b0d9ebd3..128a5ea44 100644 --- a/cmd/ghcs/create.go +++ b/cmd/ghcs/create.go @@ -36,37 +36,19 @@ func init() { rootCmd.AddCommand(newCreateCmd()) } -var repoSurvey = []*survey.Question{ - { - Name: "repository", - Prompt: &survey.Input{Message: "Repository"}, - Validate: survey.Required, - }, -} -var branchSurvey = []*survey.Question{ - { - Name: "branch", - Prompt: &survey.Input{Message: "Branch"}, - Validate: survey.Required, - }, -} - func Create() error { ctx := context.Background() apiClient := api.New(os.Getenv("GITHUB_TOKEN")) locationCh := getLocation(ctx, apiClient) userCh := getUser(ctx, apiClient) - if repo == "" { - if err := survey.Ask(repoSurvey, &repo); err != nil { - return fmt.Errorf("error getting repository name: %v", err) - } + repo, err := getRepoName() + if err != nil { + return fmt.Errorf("error getting repository name: %v", err) } - - if branch == "" { - if err := survey.Ask(branchSurvey, &branch); err != nil { - return fmt.Errorf("error getting branch name: %v", err) - } + branch, err := getBranchName() + if err != nil { + return fmt.Errorf("error getting branch name: %v", err) } repository, err := apiClient.GetRepository(ctx, repo) @@ -84,48 +66,15 @@ func Create() error { return fmt.Errorf("error getting codespace user: %v", userResult.Err) } - skus, err := apiClient.GetCodespacesSkus(ctx, userResult.User, repository, locationResult.Location) + machine, err := getMachineName(ctx, userResult.User, repository, locationResult.Location, apiClient) if err != nil { - return fmt.Errorf("error getting codespace skus: %v", err) + return fmt.Errorf("error getting machine type: %v", err) } - - if len(skus) == 0 { + if machine == "" { fmt.Println("There are no available machine types for this repository") return nil } - if machine == "" { - skuNames := make([]string, 0, len(skus)) - skuByName := make(map[string]*api.Sku) - for _, sku := range skus { - nameParts := camelcase.Split(sku.Name) - machineName := strings.Title(strings.ToLower(nameParts[0])) - skuName := fmt.Sprintf("%s - %s", machineName, sku.DisplayName) - skuNames = append(skuNames, skuName) - skuByName[skuName] = sku - } - - skuSurvey := []*survey.Question{ - { - Name: "sku", - Prompt: &survey.Select{ - Message: "Choose Machine Type:", - Options: skuNames, - Default: skuNames[0], - }, - Validate: survey.Required, - }, - } - - skuAnswers := struct{ SKU string }{} - if err := survey.Ask(skuSurvey, &skuAnswers); err != nil { - return fmt.Errorf("error getting SKU: %v", err) - } - - sku := skuByName[skuAnswers.SKU] - machine = sku.Name - } - fmt.Println("Creating your codespace...") codespace, err := apiClient.CreateCodespace(ctx, userResult.User, repository, machine, branch, locationResult.Location) @@ -165,3 +114,87 @@ func getLocation(ctx context.Context, apiClient *api.API) <-chan locationResult }() return ch } + +func getRepoName() (string, error) { + if repo != "" { + return repo, nil + } + + repoSurvey := []*survey.Question{ + { + Name: "repository", + Prompt: &survey.Input{Message: "Repository"}, + Validate: survey.Required, + }, + } + err := survey.Ask(repoSurvey, &repo) + return repo, err +} + +func getBranchName() (string, error) { + if branch != "" { + return branch, nil + } + + branchSurvey := []*survey.Question{ + { + Name: "branch", + Prompt: &survey.Input{Message: "Branch"}, + Validate: survey.Required, + }, + } + err := survey.Ask(branchSurvey, &branch) + return branch, err +} + +func getMachineName(ctx context.Context, user *api.User, repo *api.Repository, location string, apiClient *api.API) (string, error) { + skus, err := apiClient.GetCodespacesSkus(ctx, user, repo, location) + if err != nil { + return "", fmt.Errorf("error getting codespace skus: %v", err) + } + + // if user supplied a machine type, it must be valid + // if no machine type was supplied, we don't error if there are no machine types for the current repo + if machine != "" { + for _, sku := range skus { + if machine == sku.Name { + return machine, nil + } + } + return "", fmt.Errorf("there are is no such machine for the repository: %s", machine) + } else if len(skus) == 0 { + return "", nil + } + + skuNames := make([]string, 0, len(skus)) + skuByName := make(map[string]*api.Sku) + for _, sku := range skus { + nameParts := camelcase.Split(sku.Name) + machineName := strings.Title(strings.ToLower(nameParts[0])) + skuName := fmt.Sprintf("%s - %s", machineName, sku.DisplayName) + skuNames = append(skuNames, skuName) + skuByName[skuName] = sku + } + + skuSurvey := []*survey.Question{ + { + Name: "sku", + Prompt: &survey.Select{ + Message: "Choose Machine Type:", + Options: skuNames, + Default: skuNames[0], + }, + Validate: survey.Required, + }, + } + + skuAnswers := struct{ SKU string }{} + if err := survey.Ask(skuSurvey, &skuAnswers); err != nil { + return "", fmt.Errorf("error getting SKU: %v", err) + } + + sku := skuByName[skuAnswers.SKU] + machine = sku.Name + + return machine, nil +} From fddcd876b0b6e50959b0530662c50dda1b0079c1 Mon Sep 17 00:00:00 2001 From: Jose Garcia Date: Thu, 22 Jul 2021 01:02:03 +0000 Subject: [PATCH 036/290] Some more cleanup to the port forwarder and connection --- connection.go | 16 ++++++++-------- port_forwarder.go | 28 +++++++++++++--------------- 2 files changed, 21 insertions(+), 23 deletions(-) diff --git a/connection.go b/connection.go index a97935d3b..eda050f63 100644 --- a/connection.go +++ b/connection.go @@ -7,27 +7,27 @@ import ( ) type Connection struct { - SessionID string `json:"sessionId"` - SessionToken string `json:"sessionToken"` - RelaySAS string `json:"relaySas"` - RelayEndpoint string `json:"relayEndpoint"` + SessionID string + SessionToken string + RelaySAS string + RelayEndpoint string } func (r Connection) validate() error { if r.SessionID == "" { - return errors.New("connection sessionID is required") + return errors.New("connection SessionID is required") } if r.SessionToken == "" { - return errors.New("connection sessionToken is required") + return errors.New("connection SessionToken is required") } if r.RelaySAS == "" { - return errors.New("connection relaySas is required") + return errors.New("connection RelaySAS is required") } if r.RelayEndpoint == "" { - return errors.New("connection relayEndpoint is required") + return errors.New("connection RelayEndpoint is required") } return nil diff --git a/port_forwarder.go b/port_forwarder.go index 1227493b2..8d42f3c05 100644 --- a/port_forwarder.go +++ b/port_forwarder.go @@ -4,22 +4,24 @@ import ( "context" "fmt" "io" - "log" "net" "strconv" - - "golang.org/x/crypto/ssh" ) type PortForwarder struct { - client *Client - server *Server - port int - channels []ssh.Channel + client *Client + server *Server + port int + errCh chan error } func NewPortForwarder(client *Client, server *Server, port int) *PortForwarder { - return &PortForwarder{client, server, port, []ssh.Channel{}} + return &PortForwarder{ + client: client, + server: server, + port: port, + errCh: make(chan error), + } } func (l *PortForwarder) Start(ctx context.Context) error { @@ -37,22 +39,18 @@ func (l *PortForwarder) Start(ctx context.Context) error { go l.handleConnection(ctx, conn) } - // clean up after ourselves - return nil } func (l *PortForwarder) handleConnection(ctx context.Context, conn net.Conn) { channel, err := l.client.openStreamingChannel(ctx, l.server.streamName, l.server.streamCondition) if err != nil { - log.Println("errrr handle Connect") - log.Println(err) // TODO(josebalius) handle this somehow + l.errCh <- fmt.Errorf("error opening streaming channel for new connection: %v", err) + return } - l.channels = append(l.channels, channel) copyConn := func(writer io.Writer, reader io.Reader) { - _, err := io.Copy(writer, reader) - if err != nil { + if _, err := io.Copy(writer, reader); err != nil { channel.Close() conn.Close() } From a99d0f5495a575c0694307dbbe606210b16830d7 Mon Sep 17 00:00:00 2001 From: Jose Garcia Date: Thu, 22 Jul 2021 01:07:06 +0000 Subject: [PATCH 037/290] Better naming for rpc client and ssh session --- client.go | 14 +++++++------- rpc.go | 10 +++++----- ssh.go | 2 +- 3 files changed, 13 insertions(+), 13 deletions(-) diff --git a/client.go b/client.go index 1ad90b336..d0e6e84a8 100644 --- a/client.go +++ b/client.go @@ -10,8 +10,8 @@ import ( type Client struct { connection Connection - sshSession *sshSession - rpc *rpc + ssh *sshSession + rpc *rpcClient } type ClientOption func(*Client) error @@ -45,12 +45,12 @@ func (c *Client) Join(ctx context.Context) (err error) { return fmt.Errorf("error connecting websocket: %v", err) } - c.sshSession = newSSH(c.connection.SessionToken, clientSocket) - if err := c.sshSession.connect(ctx); err != nil { + c.ssh = newSshSession(c.connection.SessionToken, clientSocket) + if err := c.ssh.connect(ctx); err != nil { return fmt.Errorf("error connecting to ssh session: %v", err) } - c.rpc = newRPC(c.sshSession) + c.rpc = newRpcClient(c.ssh) c.rpc.connect(ctx) _, err = c.joinWorkspace(ctx) @@ -62,7 +62,7 @@ func (c *Client) Join(ctx context.Context) (err error) { } func (c *Client) hasJoined() bool { - return c.sshSession != nil && c.rpc != nil + return c.ssh != nil && c.rpc != nil } type clientCapabilities struct { @@ -105,7 +105,7 @@ func (c *Client) openStreamingChannel(ctx context.Context, streamName, condition return nil, fmt.Errorf("error getting stream id: %v", err) } - channel, reqs, err := c.sshSession.conn.OpenChannel("session", nil) + channel, reqs, err := c.ssh.conn.OpenChannel("session", nil) if err != nil { return nil, fmt.Errorf("error opening ssh channel for transport: %v", err) } diff --git a/rpc.go b/rpc.go index d40046471..d624bbd74 100644 --- a/rpc.go +++ b/rpc.go @@ -9,22 +9,22 @@ import ( "github.com/sourcegraph/jsonrpc2" ) -type rpc struct { +type rpcClient struct { *jsonrpc2.Conn conn io.ReadWriteCloser handler *rpcHandler } -func newRPC(conn io.ReadWriteCloser) *rpc { - return &rpc{conn: conn, handler: newRPCHandler()} +func newRpcClient(conn io.ReadWriteCloser) *rpcClient { + return &rpcClient{conn: conn, handler: newRPCHandler()} } -func (r *rpc) connect(ctx context.Context) { +func (r *rpcClient) connect(ctx context.Context) { stream := jsonrpc2.NewBufferedStream(r.conn, jsonrpc2.VSCodeObjectCodec{}) r.Conn = jsonrpc2.NewConn(ctx, stream, r.handler) } -func (r *rpc) do(ctx context.Context, method string, args interface{}, result interface{}) error { +func (r *rpcClient) do(ctx context.Context, method string, args interface{}, result interface{}) error { waiter, err := r.Conn.DispatchCall(ctx, method, args) if err != nil { return fmt.Errorf("error on dispatch call: %v", err) diff --git a/ssh.go b/ssh.go index 3ea2d2777..e22cd69d1 100644 --- a/ssh.go +++ b/ssh.go @@ -19,7 +19,7 @@ type sshSession struct { writer io.Writer } -func newSSH(token string, socket net.Conn) *sshSession { +func newSshSession(token string, socket net.Conn) *sshSession { return &sshSession{token: token, socket: socket} } From b66d65379faba6635690cd9313fd43b60a69a59b Mon Sep 17 00:00:00 2001 From: Issy Long Date: Thu, 22 Jul 2021 11:07:23 +0100 Subject: [PATCH 038/290] cmd/ghcs/*.go: Better short descriptions of what commands do MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - I ran `--help` on `ghcs code` and saw `ghcs code` and that was it, which was surprising. I expected a description. - Here's a fix for all of the commands thus far to give them longer descriptions. - I've only done "short" descriptions in Cobra terms, and removed the "long" descriptions as they seemed like they needed to be unnecessarily verbose. Before: ``` ❯ ghcs --help Codespaces Usage: ghcs [command] Available Commands: code code create Create delete delete help Help about any command list list ports ports ssh ssh Flags: -h, --help help for ghcs -v, --version version for ghcs Use "ghcs [command] --help" for more information about a command. ❯ ghcs ssh --help ssh Usage: ghcs ssh [flags] Flags: -h, --help help for ssh --profile string SSH Profile --server-port int SSH Server Port ``` After: ``` ❯ ./ghcs --help Codespaces Usage: ghcs [command] Available Commands: code Open a GitHub Codespace in VSCode. create Create a GitHub Codespace. delete Delete a GitHub Codespace. help Help about any command list List GitHub Codespaces you have on your account. ports Forward ports from a GitHub Codespace. ssh SSH into a GitHub Codespace, for use with running tests/editing in vim, etc. Flags: -h, --help help for ghcs -v, --version version for ghcs Use "ghcs [command] --help" for more information about a command. ❯ ./ghcs ssh --help SSH into a GitHub Codespace, for use with running tests/editing in vim, etc. Usage: ghcs ssh [flags] Flags: -h, --help help for ssh --profile string SSH Profile --server-port int SSH Server Port ``` --- cmd/ghcs/code.go | 3 +-- cmd/ghcs/create.go | 3 +-- cmd/ghcs/delete.go | 3 +-- cmd/ghcs/list.go | 3 +-- cmd/ghcs/ports.go | 3 +-- cmd/ghcs/ssh.go | 3 +-- 6 files changed, 6 insertions(+), 12 deletions(-) diff --git a/cmd/ghcs/code.go b/cmd/ghcs/code.go index 1080fd896..ccb4788ee 100644 --- a/cmd/ghcs/code.go +++ b/cmd/ghcs/code.go @@ -15,8 +15,7 @@ import ( func NewCodeCmd() *cobra.Command { return &cobra.Command{ Use: "code", - Short: "code", - Long: "code", + Short: "Open a GitHub Codespace in VSCode.", RunE: func(cmd *cobra.Command, args []string) error { var codespaceName string if len(args) > 0 { diff --git a/cmd/ghcs/create.go b/cmd/ghcs/create.go index 44bedb5f2..b179b5dba 100644 --- a/cmd/ghcs/create.go +++ b/cmd/ghcs/create.go @@ -14,8 +14,7 @@ import ( var createCmd = &cobra.Command{ Use: "create", - Short: "Create", - Long: "Create", + Short: "Create a GitHub Codespace.", RunE: func(cmd *cobra.Command, args []string) error { return Create() }, diff --git a/cmd/ghcs/delete.go b/cmd/ghcs/delete.go index a24374a4c..0f274e987 100644 --- a/cmd/ghcs/delete.go +++ b/cmd/ghcs/delete.go @@ -13,8 +13,7 @@ import ( func NewDeleteCmd() *cobra.Command { deleteCmd := &cobra.Command{ Use: "delete CODESPACE_NAME", - Short: "delete", - Long: "delete", + Short: "Delete a GitHub Codespace.", RunE: func(cmd *cobra.Command, args []string) error { if len(args) == 0 { return errors.New("A Codespace name is required.") diff --git a/cmd/ghcs/list.go b/cmd/ghcs/list.go index e02e6a1d2..6db79af97 100644 --- a/cmd/ghcs/list.go +++ b/cmd/ghcs/list.go @@ -14,8 +14,7 @@ import ( func NewListCmd() *cobra.Command { listCmd := &cobra.Command{ Use: "list", - Short: "list", - Long: "list", + Short: "List GitHub Codespaces you have on your account.", RunE: func(cmd *cobra.Command, args []string) error { return List() }, diff --git a/cmd/ghcs/ports.go b/cmd/ghcs/ports.go index d5b863c82..6d2086088 100644 --- a/cmd/ghcs/ports.go +++ b/cmd/ghcs/ports.go @@ -20,8 +20,7 @@ import ( func NewPortsCmd() *cobra.Command { portsCmd := &cobra.Command{ Use: "ports", - Short: "ports", - Long: "ports", + Short: "Forward ports from a GitHub Codespace.", RunE: func(cmd *cobra.Command, args []string) error { return Ports() }, diff --git a/cmd/ghcs/ssh.go b/cmd/ghcs/ssh.go index 6c4385aaf..89aa77c1b 100644 --- a/cmd/ghcs/ssh.go +++ b/cmd/ghcs/ssh.go @@ -23,8 +23,7 @@ func NewSSHCmd() *cobra.Command { sshCmd := &cobra.Command{ Use: "ssh", - Short: "ssh", - Long: "ssh", + Short: "SSH into a GitHub Codespace, for use with running tests/editing in vim, etc.", RunE: func(cmd *cobra.Command, args []string) error { return SSH(sshProfile, sshServerPort) }, From 69865fa7623bc0a857dc8a1f4140d3c4567785a7 Mon Sep 17 00:00:00 2001 From: Issy Long Date: Thu, 22 Jul 2021 14:08:20 +0100 Subject: [PATCH 039/290] cmd/ghcs/main: Better description of `ghcs` as a whole Co-authored-by: Camilo Garcia La Rotta --- cmd/ghcs/main.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cmd/ghcs/main.go b/cmd/ghcs/main.go index 696975ab5..a89a544b4 100644 --- a/cmd/ghcs/main.go +++ b/cmd/ghcs/main.go @@ -13,8 +13,8 @@ func main() { var rootCmd = &cobra.Command{ Use: "ghcs", - Short: "Codespaces", - Long: "Codespaces", + Short: "Unofficial GitHub Codespaces CLI.", + Long: "Unofficial CLI tool to manage and interact with GitHub Codespaces.", Version: "0.6.0", } From 3ef0226e20e64088c4755b7b91a0d02a5fccf697 Mon Sep 17 00:00:00 2001 From: Camilo Garcia La Rotta Date: Thu, 22 Jul 2021 10:07:09 -0400 Subject: [PATCH 040/290] fix: output available machine names on --machine error --- cmd/ghcs/create.go | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/cmd/ghcs/create.go b/cmd/ghcs/create.go index 128a5ea44..b23e5631b 100644 --- a/cmd/ghcs/create.go +++ b/cmd/ghcs/create.go @@ -161,7 +161,13 @@ func getMachineName(ctx context.Context, user *api.User, repo *api.Repository, l return machine, nil } } - return "", fmt.Errorf("there are is no such machine for the repository: %s", machine) + + availableSkus := make([]string, len(skus)) + for i := 0; i < len(skus); i++ { + availableSkus[i] = skus[i].Name + } + + return "", fmt.Errorf("there are is no such machine for the repository: %s\nAvailable machines: %v", machine, availableSkus) } else if len(skus) == 0 { return "", nil } From 14468baba6d84fca7546e44e8d7633f0968c8aa5 Mon Sep 17 00:00:00 2001 From: Camilo Garcia La Rotta Date: Thu, 22 Jul 2021 10:13:20 -0400 Subject: [PATCH 041/290] config: bump to v0.7.0 --- cmd/ghcs/main.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cmd/ghcs/main.go b/cmd/ghcs/main.go index a89a544b4..0b5c001b2 100644 --- a/cmd/ghcs/main.go +++ b/cmd/ghcs/main.go @@ -15,7 +15,7 @@ var rootCmd = &cobra.Command{ Use: "ghcs", Short: "Unofficial GitHub Codespaces CLI.", Long: "Unofficial CLI tool to manage and interact with GitHub Codespaces.", - Version: "0.6.0", + Version: "0.7.0", } func Execute() { From 7e49db3be3129761aaddeb9acdb038797335b303 Mon Sep 17 00:00:00 2001 From: CamiloGarciaLaRotta Date: Thu, 22 Jul 2021 10:33:00 -0400 Subject: [PATCH 042/290] config: bump to 0.7.1 Hoping to prove that Goreleaser & Homebrew run automatically --- cmd/ghcs/main.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cmd/ghcs/main.go b/cmd/ghcs/main.go index 0b5c001b2..aee1b2aec 100644 --- a/cmd/ghcs/main.go +++ b/cmd/ghcs/main.go @@ -15,7 +15,7 @@ var rootCmd = &cobra.Command{ Use: "ghcs", Short: "Unofficial GitHub Codespaces CLI.", Long: "Unofficial CLI tool to manage and interact with GitHub Codespaces.", - Version: "0.7.0", + Version: "0.7.1", } func Execute() { From b9cd9af7fa83ad2fd7cca4727d5adc1be51fa384 Mon Sep 17 00:00:00 2001 From: Jose Garcia Date: Fri, 23 Jul 2021 01:17:32 +0000 Subject: [PATCH 043/290] Start of tests and comments --- client.go | 5 ++++ client_test.go | 73 +++++++++++++++++++++++++++++++++++++++++++++-- connection.go | 1 + port_forwarder.go | 3 ++ 4 files changed, 79 insertions(+), 3 deletions(-) diff --git a/client.go b/client.go index d0e6e84a8..435dd6775 100644 --- a/client.go +++ b/client.go @@ -7,6 +7,7 @@ import ( "golang.org/x/crypto/ssh" ) +// A Client capable of joining a liveshare connection type Client struct { connection Connection @@ -14,8 +15,10 @@ type Client struct { rpc *rpcClient } +// A ClientOption is a function that modifies a client type ClientOption func(*Client) error +// NewClient accepts a range of options, applies them and returns a client func NewClient(opts ...ClientOption) (*Client, error) { client := new(Client) @@ -28,6 +31,7 @@ func NewClient(opts ...ClientOption) (*Client, error) { return client, nil } +// WithConnection is a ClientOption that accepts a Connection func WithConnection(connection Connection) ClientOption { return func(c *Client) error { if err := connection.validate(); err != nil { @@ -39,6 +43,7 @@ func WithConnection(connection Connection) ClientOption { } } +// Join is a method that joins the client to the liveshare session func (c *Client) Join(ctx context.Context) (err error) { clientSocket := newSocket(c.connection) if err := clientSocket.connect(ctx); err != nil { diff --git a/client_test.go b/client_test.go index 8d118974d..86f732637 100644 --- a/client_test.go +++ b/client_test.go @@ -1,22 +1,89 @@ package liveshare import ( + "fmt" + "net/http" + "net/http/httptest" "testing" + + "github.com/gorilla/websocket" ) +func TestNewClient(t *testing.T) { + client, err := NewClient() + if err != nil { + t.Errorf("error creating new client: %v", err) + } + if client == nil { + t.Error("client is nil") + } +} + +func TestNewClientValidConnection(t *testing.T) { + connection := Connection{"1", "2", "3", "4"} + + client, err := NewClient(WithConnection(connection)) + if err != nil { + t.Errorf("error creating new client: %v", err) + } + if client == nil { + t.Error("client is nil") + } +} + +func TestNewClientWithInvalidConnection(t *testing.T) { + connection := Connection{} + + if _, err := NewClient(WithConnection(connection)); err == nil { + t.Error("err is nil") + } +} + +var upgrader = websocket.Upgrader{} + +func newMockLiveShareServer() *httptest.Server { + endpoint := func(w http.ResponseWriter, req *http.Request) { + c, err := upgrader.Upgrade(w, req, nil) + if err != nil { + fmt.Println(err) + return + } + defer c.Close() + + for { + mt, message, err := c.ReadMessage() + if err != nil { + fmt.Println(err) + break + } + + err = c.WriteMessage(mt, message) + if err != nil { + fmt.Println(err) + break + } + + } + } + + return httptest.NewTLSServer(http.HandlerFunc(endpoint)) +} + func TestClientJoin(t *testing.T) { + // server := newMockLiveShareServer() + // defer server.Close() + // connection := Connection{ // SessionID: "session-id", // SessionToken: "session-token", - // RelayEndpoint: "relay-endpoint", // RelaySAS: "relay-sas", + // RelayEndpoint: "sb" + strings.TrimPrefix(server.URL, "https"), // } // client, err := NewClient(WithConnection(connection)) // if err != nil { - // t.Errorf("error creating client: %v", err) + // t.Errorf("error creating new client: %v", err) // } - // ctx := context.Background() // if err := client.Join(ctx); err != nil { // t.Errorf("error joining client: %v", err) diff --git a/connection.go b/connection.go index eda050f63..c1a4632c8 100644 --- a/connection.go +++ b/connection.go @@ -6,6 +6,7 @@ import ( "strings" ) +// A Connection represents a set of values necessary to join a liveshare connection type Connection struct { SessionID string SessionToken string diff --git a/port_forwarder.go b/port_forwarder.go index 8d42f3c05..6d459b4d6 100644 --- a/port_forwarder.go +++ b/port_forwarder.go @@ -8,6 +8,7 @@ import ( "strconv" ) +// A PortForwader can forward ports from a remote liveshare host to localhost type PortForwarder struct { client *Client server *Server @@ -15,6 +16,7 @@ type PortForwarder struct { errCh chan error } +// NewPortForwarder creates a new PortForwader with a given client, server and port func NewPortForwarder(client *Client, server *Server, port int) *PortForwarder { return &PortForwarder{ client: client, @@ -24,6 +26,7 @@ func NewPortForwarder(client *Client, server *Server, port int) *PortForwarder { } } +// Start is a method to start forwarding the server to a localhost port func (l *PortForwarder) Start(ctx context.Context) error { ln, err := net.Listen("tcp", ":"+strconv.Itoa(l.port)) if err != nil { From 9132a28e9cf2a09359a80e104a558c77a0c0abea Mon Sep 17 00:00:00 2001 From: Jose Garcia Date: Fri, 23 Jul 2021 19:15:54 +0000 Subject: [PATCH 044/290] Checking point after continuing to flesh out mock server --- client.go | 11 ++- client_test.go | 100 +++++++++++----------- socket.go | 17 +++- test/server.go | 226 +++++++++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 299 insertions(+), 55 deletions(-) create mode 100644 test/server.go diff --git a/client.go b/client.go index 435dd6775..a8a1e3864 100644 --- a/client.go +++ b/client.go @@ -2,6 +2,7 @@ package liveshare import ( "context" + "crypto/tls" "fmt" "golang.org/x/crypto/ssh" @@ -10,6 +11,7 @@ import ( // A Client capable of joining a liveshare connection type Client struct { connection Connection + tlsConfig *tls.Config ssh *sshSession rpc *rpcClient @@ -43,9 +45,16 @@ func WithConnection(connection Connection) ClientOption { } } +func WithTLSConfig(tlsConfig *tls.Config) ClientOption { + return func(c *Client) error { + c.tlsConfig = tlsConfig + return nil + } +} + // Join is a method that joins the client to the liveshare session func (c *Client) Join(ctx context.Context) (err error) { - clientSocket := newSocket(c.connection) + clientSocket := newSocket(c.connection, c.tlsConfig) if err := clientSocket.connect(ctx); err != nil { return fmt.Errorf("error connecting websocket: %v", err) } diff --git a/client_test.go b/client_test.go index 86f732637..bf77e3dce 100644 --- a/client_test.go +++ b/client_test.go @@ -1,12 +1,14 @@ package liveshare import ( + "context" + "crypto/tls" "fmt" - "net/http" - "net/http/httptest" + "strings" "testing" - "github.com/gorilla/websocket" + livesharetest "github.com/github/go-liveshare/test" + "github.com/sourcegraph/jsonrpc2" ) func TestNewClient(t *testing.T) { @@ -39,53 +41,51 @@ func TestNewClientWithInvalidConnection(t *testing.T) { } } -var upgrader = websocket.Upgrader{} - -func newMockLiveShareServer() *httptest.Server { - endpoint := func(w http.ResponseWriter, req *http.Request) { - c, err := upgrader.Upgrade(w, req, nil) - if err != nil { - fmt.Println(err) - return - } - defer c.Close() - - for { - mt, message, err := c.ReadMessage() - if err != nil { - fmt.Println(err) - break - } - - err = c.WriteMessage(mt, message) - if err != nil { - fmt.Println(err) - break - } - - } +func TestClientJoin(t *testing.T) { + sessionToken := "session-token" + joinWorkspace := func(req *jsonrpc2.Request) (interface{}, error) { + return 1, nil } - return httptest.NewTLSServer(http.HandlerFunc(endpoint)) -} - -func TestClientJoin(t *testing.T) { - // server := newMockLiveShareServer() - // defer server.Close() - - // connection := Connection{ - // SessionID: "session-id", - // SessionToken: "session-token", - // RelaySAS: "relay-sas", - // RelayEndpoint: "sb" + strings.TrimPrefix(server.URL, "https"), - // } - - // client, err := NewClient(WithConnection(connection)) - // if err != nil { - // t.Errorf("error creating new client: %v", err) - // } - // ctx := context.Background() - // if err := client.Join(ctx); err != nil { - // t.Errorf("error joining client: %v", err) - // } + server, err := livesharetest.NewServer( + livesharetest.WithPassword(sessionToken), + livesharetest.WithService("workspace.joinWorkspace", joinWorkspace), + ) + if err != nil { + t.Errorf("error creating liveshare server: %v", err) + } + defer server.Close() + + ctx := context.Background() + connection := Connection{ + SessionID: "session-id", + SessionToken: sessionToken, + RelaySAS: "relay-sas", + RelayEndpoint: "sb" + strings.TrimPrefix(server.URL(), "https"), + } + + tlsConfig := WithTLSConfig(&tls.Config{InsecureSkipVerify: true}) + client, err := NewClient(WithConnection(connection), tlsConfig) + if err != nil { + t.Errorf("error creating new client: %v", err) + } + + clientErr := make(chan error) + go func() { + if err := client.Join(ctx); err != nil { + clientErr <- fmt.Errorf("error joining client: %v", err) + return + } + + ctx.Done() + }() + + select { + case err := <-server.Err(): + t.Errorf("error from server: %v", err) + case err := <-clientErr: + t.Errorf("error from client: %v", err) + case <-ctx.Done(): + return + } } diff --git a/socket.go b/socket.go index c3f75b9db..e4f80a0cf 100644 --- a/socket.go +++ b/socket.go @@ -2,9 +2,11 @@ package liveshare import ( "context" + "crypto/tls" "errors" "io" "net" + "net/http" "sync" "time" @@ -12,19 +14,26 @@ import ( ) type socket struct { - addr string + addr string + tlsConfig *tls.Config + conn *websocket.Conn readMutex sync.Mutex writeMutex sync.Mutex reader io.Reader } -func newSocket(clientConn Connection) *socket { - return &socket{addr: clientConn.uri("connect")} +func newSocket(clientConn Connection, tlsConfig *tls.Config) *socket { + return &socket{addr: clientConn.uri("connect"), tlsConfig: tlsConfig} } func (s *socket) connect(ctx context.Context) error { - ws, _, err := websocket.DefaultDialer.Dial(s.addr, nil) + dialer := websocket.Dialer{ + Proxy: http.ProxyFromEnvironment, + HandshakeTimeout: 45 * time.Second, + TLSClientConfig: s.tlsConfig, + } + ws, _, err := dialer.Dial(s.addr, nil) if err != nil { return err } diff --git a/test/server.go b/test/server.go new file mode 100644 index 000000000..ed8666cce --- /dev/null +++ b/test/server.go @@ -0,0 +1,226 @@ +package livesharetest + +import ( + "context" + "errors" + "fmt" + "io" + "io/ioutil" + "net/http" + "net/http/httptest" + "path/filepath" + "sync" + "time" + + "github.com/gorilla/websocket" + "github.com/sourcegraph/jsonrpc2" + "golang.org/x/crypto/ssh" +) + +type Server struct { + password string + services map[string]RpcHandleFunc + + sshConfig *ssh.ServerConfig + httptestServer *httptest.Server + errCh chan error +} + +func NewServer(opts ...ServerOption) (*Server, error) { + server := new(Server) + + for _, o := range opts { + if err := o(server); err != nil { + return nil, err + } + } + + server.sshConfig = &ssh.ServerConfig{ + PasswordCallback: sshPasswordCallback(server.password), + } + b, err := ioutil.ReadFile(filepath.Join("test", "private.key")) + if err != nil { + return nil, fmt.Errorf("error reading private.key: %v", err) + } + privateKey, err := ssh.ParsePrivateKey(b) + if err != nil { + return nil, fmt.Errorf("error parsing key: %v", err) + } + server.sshConfig.AddHostKey(privateKey) + + server.errCh = make(chan error) + server.httptestServer = httptest.NewTLSServer(http.HandlerFunc(newConnection(server))) + return server, nil +} + +type ServerOption func(*Server) error + +func WithPassword(password string) ServerOption { + return func(s *Server) error { + s.password = password + return nil + } +} + +func WithService(serviceName string, handler RpcHandleFunc) ServerOption { + return func(s *Server) error { + if s.services == nil { + s.services = make(map[string]RpcHandleFunc) + } + + s.services[serviceName] = handler + return nil + } +} + +func sshPasswordCallback(serverPassword string) func(ssh.ConnMetadata, []byte) (*ssh.Permissions, error) { + return func(conn ssh.ConnMetadata, password []byte) (*ssh.Permissions, error) { + if string(password) == serverPassword { + return nil, nil + } + return nil, errors.New("password rejected") + } +} + +func (s *Server) Close() { + s.httptestServer.Close() +} + +func (s *Server) URL() string { + return s.httptestServer.URL +} + +func (s *Server) Err() <-chan error { + return s.errCh +} + +var upgrader = websocket.Upgrader{} + +func newConnection(server *Server) http.HandlerFunc { + return func(w http.ResponseWriter, req *http.Request) { + c, err := upgrader.Upgrade(w, req, nil) + if err != nil { + server.errCh <- fmt.Errorf("error upgrading connection: %v", err) + return + } + defer c.Close() + + socketConn := newSocketConn(c) + _, chans, reqs, err := ssh.NewServerConn(socketConn, server.sshConfig) + if err != nil { + server.errCh <- fmt.Errorf("error creating new ssh conn: %v", err) + return + } + go ssh.DiscardRequests(reqs) + + for newChannel := range chans { + ch, reqs, err := newChannel.Accept() + if err != nil { + server.errCh <- fmt.Errorf("error accepting new channel: %v", err) + return + } + go ssh.DiscardRequests(reqs) + go handleNewChannel(server, ch) + } + } +} + +func handleNewChannel(server *Server, channel ssh.Channel) { + stream := jsonrpc2.NewBufferedStream(channel, jsonrpc2.VSCodeObjectCodec{}) + jsonrpc2.NewConn(context.Background(), stream, newRpcHandler(server)) +} + +type RpcHandleFunc func(req *jsonrpc2.Request) (interface{}, error) + +type rpcHandler struct { + server *Server +} + +func newRpcHandler(server *Server) *rpcHandler { + return &rpcHandler{server} +} + +func (r *rpcHandler) Handle(ctx context.Context, conn *jsonrpc2.Conn, req *jsonrpc2.Request) { + handler, found := r.server.services[req.Method] + if !found { + r.server.errCh <- fmt.Errorf("RPC Method: '%v' not serviced", req.Method) + return + } + + result, err := handler(req) + if err != nil { + r.server.errCh <- fmt.Errorf("error handling: '%v': %v", req.Method, err) + return + } + + if err := conn.Reply(ctx, req.ID, result); err != nil { + r.server.errCh <- fmt.Errorf("error replying: %v", err) + } +} + +type socketConn struct { + *websocket.Conn + + reader io.Reader + writeMutex sync.Mutex + readMutex sync.Mutex +} + +func newSocketConn(conn *websocket.Conn) *socketConn { + return &socketConn{Conn: conn} +} + +func (s *socketConn) Read(b []byte) (int, error) { + s.readMutex.Lock() + defer s.readMutex.Unlock() + + if s.reader == nil { + msgType, r, err := s.Conn.NextReader() + if err != nil { + return 0, fmt.Errorf("error getting next reader: %v", err) + } + if msgType != websocket.BinaryMessage { + return 0, fmt.Errorf("invalid message type") + } + s.reader = r + } + + bytesRead, err := s.reader.Read(b) + if err != nil { + s.reader = nil + + if err == io.EOF { + err = nil + } + } + + return bytesRead, err +} + +func (s *socketConn) Write(b []byte) (int, error) { + s.writeMutex.Lock() + defer s.writeMutex.Unlock() + + w, err := s.Conn.NextWriter(websocket.BinaryMessage) + if err != nil { + return 0, fmt.Errorf("error getting next writer: %v", err) + } + + n, err := w.Write(b) + if err != nil { + return 0, fmt.Errorf("error writing: %v", err) + } + + if err := w.Close(); err != nil { + return 0, fmt.Errorf("error closing writer: %v", err) + } + + return n, nil +} + +func (s *socketConn) SetDeadline(deadline time.Time) error { + if err := s.Conn.SetReadDeadline(deadline); err != nil { + return err + } + return s.Conn.SetWriteDeadline(deadline) +} From fcfb10cb56e6aaaae9c745bc1ee00f48897220f9 Mon Sep 17 00:00:00 2001 From: Jose Garcia Date: Fri, 23 Jul 2021 20:24:50 +0000 Subject: [PATCH 045/290] Working test for Client.Join --- client_test.go | 31 +++++++++++++++---------------- 1 file changed, 15 insertions(+), 16 deletions(-) diff --git a/client_test.go b/client_test.go index bf77e3dce..fdf566fc0 100644 --- a/client_test.go +++ b/client_test.go @@ -42,27 +42,26 @@ func TestNewClientWithInvalidConnection(t *testing.T) { } func TestClientJoin(t *testing.T) { - sessionToken := "session-token" + connection := Connection{ + SessionID: "session-id", + SessionToken: "session-token", + RelaySAS: "relay-sas", + } joinWorkspace := func(req *jsonrpc2.Request) (interface{}, error) { - return 1, nil + return joinWorkspaceResult{1}, nil } server, err := livesharetest.NewServer( - livesharetest.WithPassword(sessionToken), + livesharetest.WithPassword(connection.SessionToken), livesharetest.WithService("workspace.joinWorkspace", joinWorkspace), ) if err != nil { t.Errorf("error creating liveshare server: %v", err) } defer server.Close() + connection.RelayEndpoint = "sb" + strings.TrimPrefix(server.URL(), "https") ctx := context.Background() - connection := Connection{ - SessionID: "session-id", - SessionToken: sessionToken, - RelaySAS: "relay-sas", - RelayEndpoint: "sb" + strings.TrimPrefix(server.URL(), "https"), - } tlsConfig := WithTLSConfig(&tls.Config{InsecureSkipVerify: true}) client, err := NewClient(WithConnection(connection), tlsConfig) @@ -70,22 +69,22 @@ func TestClientJoin(t *testing.T) { t.Errorf("error creating new client: %v", err) } - clientErr := make(chan error) + done := make(chan error) go func() { if err := client.Join(ctx); err != nil { - clientErr <- fmt.Errorf("error joining client: %v", err) + done <- fmt.Errorf("error joining client: %v", err) return } - ctx.Done() + done <- nil }() select { case err := <-server.Err(): t.Errorf("error from server: %v", err) - case err := <-clientErr: - t.Errorf("error from client: %v", err) - case <-ctx.Done(): - return + case err := <-done: + if err != nil { + t.Errorf("error from client: %v", err) + } } } From 91114d35c3d04245a58f78ebf2feb6bb5edde4e2 Mon Sep 17 00:00:00 2001 From: Jose Garcia Date: Sat, 24 Jul 2021 03:44:20 +0000 Subject: [PATCH 046/290] More tests --- client_test.go | 19 +++++ port_forwarder_test.go | 1 + server_test.go | 186 +++++++++++++++++++++++++++++++++++++++++ test/server.go | 16 ++++ 4 files changed, 222 insertions(+) create mode 100644 port_forwarder_test.go create mode 100644 server_test.go diff --git a/client_test.go b/client_test.go index fdf566fc0..110c7e3b9 100644 --- a/client_test.go +++ b/client_test.go @@ -3,6 +3,8 @@ package liveshare import ( "context" "crypto/tls" + "encoding/json" + "errors" "fmt" "strings" "testing" @@ -48,12 +50,29 @@ func TestClientJoin(t *testing.T) { RelaySAS: "relay-sas", } joinWorkspace := func(req *jsonrpc2.Request) (interface{}, error) { + var joinWorkspaceReq joinWorkspaceArgs + if err := json.Unmarshal(*req.Params, &joinWorkspaceReq); err != nil { + return nil, fmt.Errorf("error unmarshaling req: %v", err) + } + if joinWorkspaceReq.ID != connection.SessionID { + return nil, errors.New("connection session id does not match") + } + if joinWorkspaceReq.ConnectionMode != "local" { + return nil, errors.New("connection mode is not local") + } + if joinWorkspaceReq.JoiningUserSessionToken != connection.SessionToken { + return nil, errors.New("connection user token does not match") + } + if joinWorkspaceReq.ClientCapabilities.IsNonInteractive != false { + return nil, errors.New("non interactive is not false") + } return joinWorkspaceResult{1}, nil } server, err := livesharetest.NewServer( livesharetest.WithPassword(connection.SessionToken), livesharetest.WithService("workspace.joinWorkspace", joinWorkspace), + livesharetest.WithRelaySAS(connection.RelaySAS), ) if err != nil { t.Errorf("error creating liveshare server: %v", err) diff --git a/port_forwarder_test.go b/port_forwarder_test.go new file mode 100644 index 000000000..e3e219705 --- /dev/null +++ b/port_forwarder_test.go @@ -0,0 +1 @@ +package liveshare diff --git a/server_test.go b/server_test.go new file mode 100644 index 000000000..cc2b9adbd --- /dev/null +++ b/server_test.go @@ -0,0 +1,186 @@ +package liveshare + +import ( + "context" + "crypto/tls" + "encoding/json" + "errors" + "fmt" + "strings" + "testing" + + livesharetest "github.com/github/go-liveshare/test" + "github.com/sourcegraph/jsonrpc2" +) + +func TestNewServerWithNotJoinedClient(t *testing.T) { + client, err := NewClient() + if err != nil { + t.Errorf("error creating new client: %v", err) + } + if _, err := NewServer(client); err == nil { + t.Error("expected error") + } +} + +func newMockJoinedClient(opts ...livesharetest.ServerOption) (*livesharetest.Server, *Client, error) { + connection := Connection{ + SessionID: "session-id", + SessionToken: "session-token", + RelaySAS: "relay-sas", + } + joinWorkspace := func(req *jsonrpc2.Request) (interface{}, error) { + return joinWorkspaceResult{1}, nil + } + opts = append( + opts, + livesharetest.WithPassword(connection.SessionToken), + livesharetest.WithService("workspace.joinWorkspace", joinWorkspace), + ) + testServer, err := livesharetest.NewServer( + opts..., + ) + connection.RelayEndpoint = "sb" + strings.TrimPrefix(testServer.URL(), "https") + tlsConfig := WithTLSConfig(&tls.Config{InsecureSkipVerify: true}) + client, err := NewClient(WithConnection(connection), tlsConfig) + if err != nil { + return nil, nil, fmt.Errorf("error creating new client: %v", err) + } + ctx := context.Background() + if err := client.Join(ctx); err != nil { + return nil, nil, fmt.Errorf("error joining client: %v", err) + } + return testServer, client, nil +} + +func TestNewServer(t *testing.T) { + testServer, client, err := newMockJoinedClient() + defer testServer.Close() + if err != nil { + t.Errorf("error creating mock joined client: %v", err) + } + server, err := NewServer(client) + if err != nil { + t.Errorf("error creating new server: %v", err) + } + if server == nil { + t.Error("server is nil") + } +} + +func TestServerStartSharing(t *testing.T) { + serverPort, serverProtocol := 2222, "sshd" + startSharing := func(req *jsonrpc2.Request) (interface{}, error) { + var args []interface{} + if err := json.Unmarshal(*req.Params, &args); err != nil { + return nil, fmt.Errorf("error unmarshaling request: %v", err) + } + if len(args) < 3 { + return nil, errors.New("not enough arguments to start sharing") + } + if port, ok := args[0].(float64); !ok { + return nil, errors.New("port argument is not an int") + } else if port != float64(serverPort) { + return nil, errors.New("port does not match serverPort") + } + if protocol, ok := args[1].(string); !ok { + return nil, errors.New("protocol argument is not a string") + } else if protocol != serverProtocol { + return nil, errors.New("protocol does not match serverProtocol") + } + if browseURL, ok := args[2].(string); !ok { + return nil, errors.New("browse url is not a string") + } else if browseURL != fmt.Sprintf("http://localhost:%v", serverPort) { + return nil, errors.New("browseURL does not match expected") + } + return Port{StreamName: "stream-name", StreamCondition: "stream-condition"}, nil + } + testServer, client, err := newMockJoinedClient( + livesharetest.WithService("serverSharing.startSharing", startSharing), + ) + defer testServer.Close() + if err != nil { + t.Errorf("error creating mock joined client: %v", err) + } + server, err := NewServer(client) + if err != nil { + t.Errorf("error creating new server: %v", err) + } + ctx := context.Background() + + done := make(chan error) + go func() { + if err := server.StartSharing(ctx, serverProtocol, serverPort); err != nil { + done <- fmt.Errorf("error sharing server: %v", err) + } + if server.streamName == "" || server.streamCondition == "" { + done <- errors.New("stream name or condition is blank") + } + done <- nil + }() + + select { + case err := <-testServer.Err(): + t.Errorf("error from server: %v", err) + case err := <-done: + if err != nil { + t.Errorf("error from client: %v", err) + } + } +} + +func TestServerGetSharedServers(t *testing.T) { + sharedServer := Port{ + SourcePort: 2222, + StreamName: "stream-name", + StreamCondition: "stream-condition", + } + getSharedServers := func(req *jsonrpc2.Request) (interface{}, error) { + return Ports{&sharedServer}, nil + } + testServer, client, err := newMockJoinedClient( + livesharetest.WithService("serverSharing.getSharedServers", getSharedServers), + ) + if err != nil { + t.Errorf("error creating new mock client: %v", err) + } + defer testServer.Close() + server, err := NewServer(client) + if err != nil { + t.Errorf("error creating new server: %v", err) + } + ctx := context.Background() + done := make(chan error) + go func() { + ports, err := server.GetSharedServers(ctx) + if err != nil { + done <- fmt.Errorf("error getting shared servers: %v", err) + } + if len(ports) < 1 { + done <- errors.New("not enough ports returned") + } + if ports[0].SourcePort != sharedServer.SourcePort { + done <- errors.New("source port does not match") + } + if ports[0].StreamName != sharedServer.StreamName { + done <- errors.New("stream name does not match") + } + if ports[0].StreamCondition != sharedServer.StreamCondition { + done <- errors.New("stream condiion does not match") + } + done <- nil + }() + + select { + case err := <-testServer.Err(): + t.Errorf("error from server: %v", err) + case err := <-done: + if err != nil { + t.Errorf("error from client: %v", err) + } + } +} + +func TestServerUpdateSharedVisibility(t *testing.T) { + +} diff --git a/test/server.go b/test/server.go index ed8666cce..abb7ac96a 100644 --- a/test/server.go +++ b/test/server.go @@ -20,6 +20,7 @@ import ( type Server struct { password string services map[string]RpcHandleFunc + relaySAS string sshConfig *ssh.ServerConfig httptestServer *httptest.Server @@ -73,6 +74,13 @@ func WithService(serviceName string, handler RpcHandleFunc) ServerOption { } } +func WithRelaySAS(sas string) ServerOption { + return func(s *Server) error { + s.relaySAS = sas + return nil + } +} + func sshPasswordCallback(serverPassword string) func(ssh.ConnMetadata, []byte) (*ssh.Permissions, error) { return func(conn ssh.ConnMetadata, password []byte) (*ssh.Permissions, error) { if string(password) == serverPassword { @@ -98,6 +106,14 @@ var upgrader = websocket.Upgrader{} func newConnection(server *Server) http.HandlerFunc { return func(w http.ResponseWriter, req *http.Request) { + if server.relaySAS != "" { + // validate the sas key + sasParam := req.URL.Query().Get("sb-hc-token") + if sasParam != server.relaySAS { + server.errCh <- errors.New("error validating sas") + return + } + } c, err := upgrader.Upgrade(w, req, nil) if err != nil { server.errCh <- fmt.Errorf("error upgrading connection: %v", err) From c092a293501cd757b99e2a89c8b8548310fa440a Mon Sep 17 00:00:00 2001 From: Issy Long Date: Mon, 26 Jul 2021 13:34:11 +0100 Subject: [PATCH 047/290] cmd/ghcs/ssh: Add `-c` parameter for specifying a Codespace to SSH to - This adds a `-c`, `--codespace` parameter to `ghcs ssh` to allow for non-interactively specifying a Codespace to SSH into, for instance if a user has recently done `ghcs list` and already knows which Codespace they want to access. Without a value for the `-c` parameter, the interactive prompt appears as usual. --- cmd/ghcs/ssh.go | 47 +++++++++++++++++++++++++++++++++-------------- 1 file changed, 33 insertions(+), 14 deletions(-) diff --git a/cmd/ghcs/ssh.go b/cmd/ghcs/ssh.go index 89aa77c1b..c3c1105e6 100644 --- a/cmd/ghcs/ssh.go +++ b/cmd/ghcs/ssh.go @@ -18,19 +18,20 @@ import ( ) func NewSSHCmd() *cobra.Command { - var sshProfile string + var sshProfile, codespaceName string var sshServerPort int sshCmd := &cobra.Command{ Use: "ssh", Short: "SSH into a GitHub Codespace, for use with running tests/editing in vim, etc.", RunE: func(cmd *cobra.Command, args []string) error { - return SSH(sshProfile, sshServerPort) + return SSH(sshProfile, codespaceName, sshServerPort) }, } sshCmd.Flags().StringVarP(&sshProfile, "profile", "", "", "SSH Profile") sshCmd.Flags().IntVarP(&sshServerPort, "server-port", "", 0, "SSH Server Port") + sshCmd.Flags().StringVarP(&codespaceName, "codespace", "c", "", "Codespace Name") return sshCmd } @@ -39,7 +40,7 @@ func init() { rootCmd.AddCommand(NewSSHCmd()) } -func SSH(sshProfile string, sshServerPort int) error { +func SSH(sshProfile, codespaceName string, sshServerPort int) error { apiClient := api.New(os.Getenv("GITHUB_TOKEN")) ctx := context.Background() @@ -48,19 +49,37 @@ func SSH(sshProfile string, sshServerPort int) error { return fmt.Errorf("error getting user: %v", err) } - codespace, err := codespaces.ChooseCodespace(ctx, apiClient, user) - if err != nil { - if err == codespaces.ErrNoCodespaces { - fmt.Println(err.Error()) - return nil + var ( + codespace *api.Codespace + token string + ) + + if codespaceName == "" { + codespace, err = codespaces.ChooseCodespace(ctx, apiClient, user) + if err != nil { + if err == codespaces.ErrNoCodespaces { + fmt.Println(err.Error()) + return nil + } + + return fmt.Errorf("error choosing codespace: %v", err) + } + codespaceName = codespace.Name + + token, err = apiClient.GetCodespaceToken(ctx, user.Login, codespaceName) + if err != nil { + return fmt.Errorf("error getting codespace token: %v", err) + } + } else { + token, err = apiClient.GetCodespaceToken(ctx, user.Login, codespaceName) + if err != nil { + return fmt.Errorf("error getting codespace token: %v", err) } - return fmt.Errorf("error choosing codespace: %v", err) - } - - token, err := apiClient.GetCodespaceToken(ctx, user.Login, codespace.Name) - if err != nil { - return fmt.Errorf("error getting codespace token: %v", err) + codespace, err = apiClient.GetCodespace(ctx, token, user.Login, codespaceName) + if err != nil { + return fmt.Errorf("error getting full codespace details: %v", err) + } } liveShareClient, err := codespaces.ConnectToLiveshare(ctx, apiClient, token, codespace) From 98282ba4b51085e03965672b1b124cc708bc6e82 Mon Sep 17 00:00:00 2001 From: Jose Garcia Date: Mon, 26 Jul 2021 14:31:00 +0000 Subject: [PATCH 048/290] Update shared visibility tests --- rpc.go | 5 ----- server_test.go | 32 +++++++++++++++++++++++++++++++- 2 files changed, 31 insertions(+), 6 deletions(-) diff --git a/rpc.go b/rpc.go index d624bbd74..8abd0e98f 100644 --- a/rpc.go +++ b/rpc.go @@ -30,11 +30,6 @@ func (r *rpcClient) do(ctx context.Context, method string, args interface{}, res return fmt.Errorf("error on dispatch call: %v", err) } - // caller doesn't care about result, so lets ignore it - if result == nil { - return nil - } - return waiter.Wait(ctx, result) } diff --git a/server_test.go b/server_test.go index cc2b9adbd..8a736b6f5 100644 --- a/server_test.go +++ b/server_test.go @@ -182,5 +182,35 @@ func TestServerGetSharedServers(t *testing.T) { } func TestServerUpdateSharedVisibility(t *testing.T) { - + updateSharedVisibility := func(req *jsonrpc2.Request) error { + return nil + } + testServer, client, err := newMockJoinedClient( + livesharetest.WithService("serverSharing.updateSharedServerVisibility", updateSharedVisibility), + ) + if err != nil { + t.Errorf("creating new mock client: %v", err) + } + defer testServer.Close() + server, err := NewServer(client) + if err != nil { + t.Errorf("creating server: %v", err) + } + ctx := context.Background() + done := make(chan error) + go func() { + if err := server.UpdateSharedVisibility(ctx, 80, true); err != nil { + done <- err + return + } + done <- nil + }() + select { + case err := <-testServer.Err(): + t.Errorf("error from server: %v", err) + case err := <-done: + if err != nil { + t.Errorf("error from client: %v", err) + } + } } From 892f73221c69d21f77d53e77eddd16f40c20c4ba Mon Sep 17 00:00:00 2001 From: Jose Garcia Date: Mon, 26 Jul 2021 14:39:52 +0000 Subject: [PATCH 049/290] Update shared visibility finalized tests --- server_test.go | 25 +++++++++++++++++++++++-- 1 file changed, 23 insertions(+), 2 deletions(-) diff --git a/server_test.go b/server_test.go index 8a736b6f5..7c9ee4288 100644 --- a/server_test.go +++ b/server_test.go @@ -182,8 +182,29 @@ func TestServerGetSharedServers(t *testing.T) { } func TestServerUpdateSharedVisibility(t *testing.T) { - updateSharedVisibility := func(req *jsonrpc2.Request) error { - return nil + updateSharedVisibility := func(rpcReq *jsonrpc2.Request) (interface{}, error) { + var req []interface{} + if err := json.Unmarshal(*rpcReq.Params, &req); err != nil { + return nil, fmt.Errorf("unmarshal req: %v", err) + } + if len(req) < 2 { + return nil, errors.New("request arguments is less than 2") + } + if port, ok := req[0].(float64); ok { + if port != 80.0 { + return nil, errors.New("port param is not expected value") + } + } else { + return nil, errors.New("port param is not a float64") + } + if public, ok := req[1].(bool); ok { + if public != true { + return nil, errors.New("pulic param is not expected value") + } + } else { + return nil, errors.New("public param is not a bool") + } + return nil, nil } testServer, client, err := newMockJoinedClient( livesharetest.WithService("serverSharing.updateSharedServerVisibility", updateSharedVisibility), From 3931c16bd765a301e71f7918bd14427a519a0e2f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mislav=20Marohni=C4=87?= Date: Mon, 26 Jul 2021 17:07:42 +0200 Subject: [PATCH 050/290] Provide version number at build time --- cmd/ghcs/main.go | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/cmd/ghcs/main.go b/cmd/ghcs/main.go index aee1b2aec..00e8be894 100644 --- a/cmd/ghcs/main.go +++ b/cmd/ghcs/main.go @@ -11,11 +11,13 @@ func main() { Execute() } +var Version = "DEV" + var rootCmd = &cobra.Command{ Use: "ghcs", Short: "Unofficial GitHub Codespaces CLI.", Long: "Unofficial CLI tool to manage and interact with GitHub Codespaces.", - Version: "0.7.1", + Version: Version, } func Execute() { From 0ab67badfad20a67a73bb170647fa115538b2995 Mon Sep 17 00:00:00 2001 From: Jose Garcia Date: Tue, 27 Jul 2021 23:19:55 +0000 Subject: [PATCH 051/290] Final changes to finish this refactor --- connection_test.go | 33 +++++++++++++ port_forwarder.go | 4 ++ port_forwarder_test.go | 102 +++++++++++++++++++++++++++++++++++++++++ server.go | 8 ++++ server_test.go | 10 ++-- socket.go | 20 ++------ test/server.go | 54 ++++++++++++++++++++-- 7 files changed, 206 insertions(+), 25 deletions(-) create mode 100644 connection_test.go diff --git a/connection_test.go b/connection_test.go new file mode 100644 index 000000000..e952290be --- /dev/null +++ b/connection_test.go @@ -0,0 +1,33 @@ +package liveshare + +import "testing" + +func TestConnectionValid(t *testing.T) { + conn := Connection{"sess-id", "sess-token", "sas", "endpoint"} + if err := conn.validate(); err != nil { + t.Error(err) + } +} + +func TestConnectionInvalid(t *testing.T) { + conn := Connection{"", "sess-token", "sas", "endpoint"} + if err := conn.validate(); err == nil { + t.Error(err) + } + conn = Connection{"sess-id", "", "sas", "endpoint"} + if err := conn.validate(); err == nil { + t.Error(err) + } + conn = Connection{"sess-id", "sess-token", "", "endpoint"} + if err := conn.validate(); err == nil { + t.Error(err) + } + conn = Connection{"sess-id", "sess-token", "sas", ""} + if err := conn.validate(); err == nil { + t.Error(err) + } + conn = Connection{"", "", "", ""} + if err := conn.validate(); err == nil { + t.Error(err) + } +} diff --git a/port_forwarder.go b/port_forwarder.go index 6d459b4d6..0a049d586 100644 --- a/port_forwarder.go +++ b/port_forwarder.go @@ -54,8 +54,12 @@ func (l *PortForwarder) handleConnection(ctx context.Context, conn net.Conn) { copyConn := func(writer io.Writer, reader io.Reader) { if _, err := io.Copy(writer, reader); err != nil { + fmt.Println(err) channel.Close() conn.Close() + if err != io.EOF { + l.errCh <- fmt.Errorf("tunnel connection: %v", err) + } } } diff --git a/port_forwarder_test.go b/port_forwarder_test.go index e3e219705..33a33b39b 100644 --- a/port_forwarder_test.go +++ b/port_forwarder_test.go @@ -1 +1,103 @@ package liveshare + +import ( + "bytes" + "context" + "errors" + "fmt" + "io" + "net" + "testing" + "time" + + livesharetest "github.com/github/go-liveshare/test" + "github.com/sourcegraph/jsonrpc2" +) + +func TestNewPortForwarder(t *testing.T) { + testServer, client, err := makeMockJoinedClient() + if err != nil { + t.Errorf("create mock client: %v", err) + } + defer testServer.Close() + server, err := NewServer(client) + if err != nil { + t.Errorf("create new server: %v", err) + } + pf := NewPortForwarder(client, server, 80) + if pf == nil { + t.Error("port forwarder is nil") + } +} + +func TestPortForwarderStart(t *testing.T) { + streamName, streamCondition := "stream-name", "stream-condition" + serverSharing := func(req *jsonrpc2.Request) (interface{}, error) { + return Port{StreamName: streamName, StreamCondition: streamCondition}, nil + } + getStream := func(req *jsonrpc2.Request) (interface{}, error) { + return "stream-id", nil + } + + stream := bytes.NewBufferString("stream-data") + testServer, client, err := makeMockJoinedClient( + livesharetest.WithService("serverSharing.startSharing", serverSharing), + livesharetest.WithService("streamManager.getStream", getStream), + livesharetest.WithStream("stream-id", stream), + ) + if err != nil { + t.Errorf("create mock client: %v", err) + } + defer testServer.Close() + + server, err := NewServer(client) + if err != nil { + t.Errorf("create new server: %v", err) + } + + ctx, _ := context.WithCancel(context.Background()) + pf := NewPortForwarder(client, server, 8000) + done := make(chan error) + + go func() { + if err := server.StartSharing(ctx, "http", 8000); err != nil { + done <- fmt.Errorf("start sharing: %v", err) + } + if err := pf.Start(ctx); err != nil { + done <- err + } + done <- nil + }() + + go func() { + var conn net.Conn + retries := 0 + for conn == nil && retries < 2 { + conn, err = net.DialTimeout("tcp", ":8000", 2*time.Second) + time.Sleep(1 * time.Second) + } + if conn == nil { + done <- errors.New("failed to connect to forwarded port") + } + b := make([]byte, len("stream-data")) + if _, err := conn.Read(b); err != nil && err != io.EOF { + done <- fmt.Errorf("reading stream: %v", err) + } + if string(b) != "stream-data" { + done <- fmt.Errorf("stream data is not expected value, got: %v", string(b)) + } + if _, err := conn.Write([]byte("new-data")); err != nil { + done <- fmt.Errorf("writing to stream: %v", err) + } + done <- nil + }() + + select { + case err := <-testServer.Err(): + t.Errorf("error from server: %v", err) + case err := <-done: + if err != nil { + t.Errorf("error from client: %v", err) + } + } +} diff --git a/server.go b/server.go index 6f17d5ac5..7e8c8b1cb 100644 --- a/server.go +++ b/server.go @@ -7,12 +7,14 @@ import ( "strconv" ) +// A Server represents the liveshare host and container server type Server struct { client *Client port int streamName, streamCondition string } +// NewServer creates a new Server with a given Client func NewServer(client *Client) (*Server, error) { if !client.hasJoined() { return nil, errors.New("client must join before creating server") @@ -21,6 +23,7 @@ func NewServer(client *Client) (*Server, error) { return &Server{client: client}, nil } +// Port represents an open port on the container type Port struct { SourcePort int `json:"sourcePort"` DestinationPort int `json:"destinationPort"` @@ -33,6 +36,7 @@ type Port struct { HasTSLHandshakePassed bool `json:"hasTSLHandshakePassed"` } +// StartSharing tells the liveshare host to start sharing the port from the container func (s *Server) StartSharing(ctx context.Context, protocol string, port int) error { s.port = port @@ -49,8 +53,10 @@ func (s *Server) StartSharing(ctx context.Context, protocol string, port int) er return nil } +// Ports is a slice of Port pointers type Ports []*Port +// GetSharedServers returns a list of available/open ports from the container func (s *Server) GetSharedServers(ctx context.Context) (Ports, error) { var response Ports if err := s.client.rpc.do(ctx, "serverSharing.getSharedServers", []string{}, &response); err != nil { @@ -60,6 +66,8 @@ func (s *Server) GetSharedServers(ctx context.Context) (Ports, error) { return response, nil } +// UpdateSharedVisibility controls port permissions and whether it can be accessed publicly +// via the Browse URL func (s *Server) UpdateSharedVisibility(ctx context.Context, port int, public bool) error { if err := s.client.rpc.do(ctx, "serverSharing.updateSharedServerVisibility", []interface{}{port, public}, nil); err != nil { return err diff --git a/server_test.go b/server_test.go index 7c9ee4288..b91fbfddc 100644 --- a/server_test.go +++ b/server_test.go @@ -23,7 +23,7 @@ func TestNewServerWithNotJoinedClient(t *testing.T) { } } -func newMockJoinedClient(opts ...livesharetest.ServerOption) (*livesharetest.Server, *Client, error) { +func makeMockJoinedClient(opts ...livesharetest.ServerOption) (*livesharetest.Server, *Client, error) { connection := Connection{ SessionID: "session-id", SessionToken: "session-token", @@ -54,7 +54,7 @@ func newMockJoinedClient(opts ...livesharetest.ServerOption) (*livesharetest.Ser } func TestNewServer(t *testing.T) { - testServer, client, err := newMockJoinedClient() + testServer, client, err := makeMockJoinedClient() defer testServer.Close() if err != nil { t.Errorf("error creating mock joined client: %v", err) @@ -95,7 +95,7 @@ func TestServerStartSharing(t *testing.T) { } return Port{StreamName: "stream-name", StreamCondition: "stream-condition"}, nil } - testServer, client, err := newMockJoinedClient( + testServer, client, err := makeMockJoinedClient( livesharetest.WithService("serverSharing.startSharing", startSharing), ) defer testServer.Close() @@ -138,7 +138,7 @@ func TestServerGetSharedServers(t *testing.T) { getSharedServers := func(req *jsonrpc2.Request) (interface{}, error) { return Ports{&sharedServer}, nil } - testServer, client, err := newMockJoinedClient( + testServer, client, err := makeMockJoinedClient( livesharetest.WithService("serverSharing.getSharedServers", getSharedServers), ) if err != nil { @@ -206,7 +206,7 @@ func TestServerUpdateSharedVisibility(t *testing.T) { } return nil, nil } - testServer, client, err := newMockJoinedClient( + testServer, client, err := makeMockJoinedClient( livesharetest.WithService("serverSharing.updateSharedServerVisibility", updateSharedVisibility), ) if err != nil { diff --git a/socket.go b/socket.go index e4f80a0cf..8744eeb96 100644 --- a/socket.go +++ b/socket.go @@ -3,11 +3,9 @@ package liveshare import ( "context" "crypto/tls" - "errors" "io" "net" "net/http" - "sync" "time" "github.com/gorilla/websocket" @@ -17,10 +15,8 @@ type socket struct { addr string tlsConfig *tls.Config - conn *websocket.Conn - readMutex sync.Mutex - writeMutex sync.Mutex - reader io.Reader + conn *websocket.Conn + reader io.Reader } func newSocket(clientConn Connection, tlsConfig *tls.Config) *socket { @@ -42,19 +38,12 @@ func (s *socket) connect(ctx context.Context) error { } func (s *socket) Read(b []byte) (int, error) { - s.readMutex.Lock() - defer s.readMutex.Unlock() - if s.reader == nil { - messageType, reader, err := s.conn.NextReader() + _, reader, err := s.conn.NextReader() if err != nil { return 0, err } - if messageType != websocket.BinaryMessage { - return 0, errors.New("unexpected websocket message type") - } - s.reader = reader } @@ -71,9 +60,6 @@ func (s *socket) Read(b []byte) (int, error) { } func (s *socket) Write(b []byte) (int, error) { - s.writeMutex.Lock() - defer s.writeMutex.Unlock() - nextWriter, err := s.conn.NextWriter(websocket.BinaryMessage) if err != nil { return 0, err diff --git a/test/server.go b/test/server.go index abb7ac96a..a52d31ab9 100644 --- a/test/server.go +++ b/test/server.go @@ -9,6 +9,7 @@ import ( "net/http" "net/http/httptest" "path/filepath" + "strings" "sync" "time" @@ -21,6 +22,7 @@ type Server struct { password string services map[string]RpcHandleFunc relaySAS string + streams map[string]io.ReadWriter sshConfig *ssh.ServerConfig httptestServer *httptest.Server @@ -50,7 +52,7 @@ func NewServer(opts ...ServerOption) (*Server, error) { server.sshConfig.AddHostKey(privateKey) server.errCh = make(chan error) - server.httptestServer = httptest.NewTLSServer(http.HandlerFunc(newConnection(server))) + server.httptestServer = httptest.NewTLSServer(http.HandlerFunc(makeConnection(server))) return server, nil } @@ -81,6 +83,16 @@ func WithRelaySAS(sas string) ServerOption { } } +func WithStream(name string, stream io.ReadWriter) ServerOption { + return func(s *Server) error { + if s.streams == nil { + s.streams = make(map[string]io.ReadWriter) + } + s.streams[name] = stream + return nil + } +} + func sshPasswordCallback(serverPassword string) func(ssh.ConnMetadata, []byte) (*ssh.Permissions, error) { return func(conn ssh.ConnMetadata, password []byte) (*ssh.Permissions, error) { if string(password) == serverPassword { @@ -104,7 +116,7 @@ func (s *Server) Err() <-chan error { var upgrader = websocket.Upgrader{} -func newConnection(server *Server) http.HandlerFunc { +func makeConnection(server *Server) http.HandlerFunc { return func(w http.ResponseWriter, req *http.Request) { if server.relaySAS != "" { // validate the sas key @@ -135,12 +147,48 @@ func newConnection(server *Server) http.HandlerFunc { server.errCh <- fmt.Errorf("error accepting new channel: %v", err) return } - go ssh.DiscardRequests(reqs) + go handleNewRequests(server, ch, reqs) go handleNewChannel(server, ch) } } } +func handleNewRequests(server *Server, channel ssh.Channel, reqs <-chan *ssh.Request) { + for req := range reqs { + if req.WantReply { + if err := req.Reply(true, nil); err != nil { + server.errCh <- fmt.Errorf("error replying to channel request: %v", err) + } + } + if strings.HasPrefix(req.Type, "stream-transport") { + forwardStream(server, req.Type, channel) + } + } +} + +func forwardStream(server *Server, streamName string, channel ssh.Channel) { + simpleStreamName := strings.TrimPrefix(streamName, "stream-transport-") + stream, found := server.streams[simpleStreamName] + if !found { + server.errCh <- fmt.Errorf("stream '%v' not found", simpleStreamName) + return + } + + copy := func(dst io.Writer, src io.Reader) { + if _, err := io.Copy(dst, src); err != nil { + fmt.Println(err) + server.errCh <- fmt.Errorf("io copy: %v", err) + return + } + } + + go copy(stream, channel) + go copy(channel, stream) + + for { + } +} + func handleNewChannel(server *Server, channel ssh.Channel) { stream := jsonrpc2.NewBufferedStream(channel, jsonrpc2.VSCodeObjectCodec{}) jsonrpc2.NewConn(context.Background(), stream, newRpcHandler(server)) From cba40ad72aa8ba432a2aeb9115de9b1c70170324 Mon Sep 17 00:00:00 2001 From: Jose Garcia Date: Wed, 28 Jul 2021 08:33:06 -0400 Subject: [PATCH 052/290] liveshare client upgrade --- cmd/ghcs/ports.go | 14 +++++++------- cmd/ghcs/ssh.go | 8 ++++---- 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/cmd/ghcs/ports.go b/cmd/ghcs/ports.go index 6d2086088..77d1b00f7 100644 --- a/cmd/ghcs/ports.go +++ b/cmd/ghcs/ports.go @@ -107,8 +107,8 @@ func Ports() error { } -func getPorts(ctx context.Context, liveShareClient *liveshare.Client) (liveshare.Ports, error) { - server, err := liveShareClient.NewServer() +func getPorts(ctx context.Context, lsclient *liveshare.Client) (liveshare.Ports, error) { + server, err := liveshare.NewServer(lsclient) if err != nil { return nil, fmt.Errorf("error creating server: %v", err) } @@ -214,12 +214,12 @@ func updatePortVisibility(codespaceName, sourcePort string, public bool) error { return fmt.Errorf("error getting codespace: %v", err) } - liveShareClient, err := codespaces.ConnectToLiveshare(ctx, apiClient, token, codespace) + lsclient, err := codespaces.ConnectToLiveshare(ctx, apiClient, token, codespace) if err != nil { return fmt.Errorf("error connecting to liveshare: %v", err) } - server, err := liveShareClient.NewServer() + server, err := liveshare.NewServer(lsclient) if err != nil { return fmt.Errorf("error creating server: %v", err) } @@ -276,12 +276,12 @@ func forwardPort(codespaceName, sourcePort, destPort string) error { return fmt.Errorf("error getting codespace: %v", err) } - liveShareClient, err := codespaces.ConnectToLiveshare(ctx, apiClient, token, codespace) + lsclient, err := codespaces.ConnectToLiveshare(ctx, apiClient, token, codespace) if err != nil { return fmt.Errorf("error connecting to liveshare: %v", err) } - server, err := liveShareClient.NewServer() + server, err := liveshare.NewServer(lsclient) if err != nil { return fmt.Errorf("error creating server: %v", err) } @@ -301,7 +301,7 @@ func forwardPort(codespaceName, sourcePort, destPort string) error { } fmt.Println("Forwarding port: " + sourcePort + " -> " + destPort) - portForwarder := liveshare.NewLocalPortForwarder(liveShareClient, server, dstPortInt) + portForwarder := liveshare.NewPortForwarder(lsclient, server, dstPortInt) if err := portForwarder.Start(ctx); err != nil { return fmt.Errorf("error forwarding port: %v", err) } diff --git a/cmd/ghcs/ssh.go b/cmd/ghcs/ssh.go index c3c1105e6..ef03ba946 100644 --- a/cmd/ghcs/ssh.go +++ b/cmd/ghcs/ssh.go @@ -82,12 +82,12 @@ func SSH(sshProfile, codespaceName string, sshServerPort int) error { } } - liveShareClient, err := codespaces.ConnectToLiveshare(ctx, apiClient, token, codespace) + lsclient, err := codespaces.ConnectToLiveshare(ctx, apiClient, token, codespace) if err != nil { return fmt.Errorf("error connecting to liveshare: %v", err) } - terminal, err := liveShareClient.NewTerminal() + terminal, err := liveshare.NewTerminal(lsclient) if err != nil { return fmt.Errorf("error creating liveshare terminal: %v", err) } @@ -106,7 +106,7 @@ func SSH(sshProfile, codespaceName string, sshServerPort int) error { fmt.Printf("\n") } - server, err := liveShareClient.NewServer() + server, err := liveshare.NewServer(lsclient) if err != nil { return fmt.Errorf("error creating server: %v", err) } @@ -121,7 +121,7 @@ func SSH(sshProfile, codespaceName string, sshServerPort int) error { return fmt.Errorf("error sharing sshd port: %v", err) } - portForwarder := liveshare.NewLocalPortForwarder(liveShareClient, server, port) + portForwarder := liveshare.NewPortForwarder(lsclient, server, port) go func() { if err := portForwarder.Start(ctx); err != nil { panic(fmt.Errorf("error forwarding port: %v", err)) From 9544f8acc9a3204756e56b4382eedc34ff0a132d Mon Sep 17 00:00:00 2001 From: Jose Garcia Date: Wed, 28 Jul 2021 09:05:58 -0400 Subject: [PATCH 053/290] Commit vendors --- api/api.go | 6 ++++-- internal/codespaces/codespaces.go | 15 +++++++++------ 2 files changed, 13 insertions(+), 8 deletions(-) diff --git a/api/api.go b/api/api.go index bf16260ae..83510d8c5 100644 --- a/api/api.go +++ b/api/api.go @@ -133,8 +133,10 @@ const ( ) type CodespaceEnvironmentConnection struct { - SessionID string `json:"sessionId"` - SessionToken string `json:"sessionToken"` + SessionID string `json:"sessionId"` + SessionToken string `json:"sessionToken"` + RelayEndpoint string `json:"relayEndpoint"` + RelaySAS string `json:"relaySas"` } func (a *API) ListCodespaces(ctx context.Context, user *User) (Codespaces, error) { diff --git a/internal/codespaces/codespaces.go b/internal/codespaces/codespaces.go index be290fab1..6c3517f39 100644 --- a/internal/codespaces/codespaces.go +++ b/internal/codespaces/codespaces.go @@ -93,18 +93,21 @@ func ConnectToLiveshare(ctx context.Context, apiClient *api.API, token string, c fmt.Println("Connecting to your codespace...") - liveShare, err := liveshare.New( - liveshare.WithWorkspaceID(codespace.Environment.Connection.SessionID), - liveshare.WithToken(codespace.Environment.Connection.SessionToken), + lsclient, err := liveshare.NewClient( + liveshare.WithConnection(liveshare.Connection{ + SessionID: codespace.Environment.Connection.SessionID, + SessionToken: codespace.Environment.Connection.SessionToken, + RelaySAS: codespace.Environment.Connection.RelaySAS, + RelayEndpoint: codespace.Environment.Connection.RelayEndpoint, + }), ) if err != nil { return nil, fmt.Errorf("error creating live share: %v", err) } - liveShareClient := liveShare.NewClient() - if err := liveShareClient.Join(ctx); err != nil { + if err := lsclient.Join(ctx); err != nil { return nil, fmt.Errorf("error joining liveshare client: %v", err) } - return liveShareClient, nil + return lsclient, nil } From 3a2ade23a4a154eb327c097214784aa95bd138c9 Mon Sep 17 00:00:00 2001 From: Jose Garcia Date: Wed, 28 Jul 2021 13:52:30 +0000 Subject: [PATCH 054/290] Connection test --- connection_test.go | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/connection_test.go b/connection_test.go index e952290be..f42ec4189 100644 --- a/connection_test.go +++ b/connection_test.go @@ -31,3 +31,11 @@ func TestConnectionInvalid(t *testing.T) { t.Error(err) } } + +func TestConnectionURI(t *testing.T) { + conn := Connection{"sess-id", "sess-token", "sas", "sb://endpoint/.net/liveshare"} + uri := conn.uri("connect") + if uri != "wss://endpoint/.net:443/$hc/liveshare?sb-hc-action=connect&sb-hc-token=sas" { + t.Errorf("uri is not correct, got: '%v'", uri) + } +} From ae29c3c1ea7358504650e0c6fd07dc4a57cbb2a0 Mon Sep 17 00:00:00 2001 From: Jose Garcia Date: Wed, 28 Jul 2021 13:55:33 +0000 Subject: [PATCH 055/290] Ignore EOF on terminal close --- terminal.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/terminal.go b/terminal.go index 1621559a1..c26d9fd9f 100644 --- a/terminal.go +++ b/terminal.go @@ -106,7 +106,7 @@ func (t terminalReadCloser) Close() error { return fmt.Errorf("error making terminal.stopTerminal call: %v", err) } - if err := t.channel.Close(); err != nil { + if err := t.channel.Close(); err != nil && err != io.EOF { return fmt.Errorf("error closing channel: %v", err) } From 58a055609dea29874e5a4e1ba00a56897a1599a2 Mon Sep 17 00:00:00 2001 From: Jose Garcia Date: Thu, 29 Jul 2021 10:57:51 -0400 Subject: [PATCH 056/290] logs cmd spike and refactor of ssh tunnel methods --- cmd/ghcs/logs.go | 95 ++++++++++++++++++++++++ cmd/ghcs/ssh.go | 87 +++++----------------- internal/codespaces/codespaces.go | 32 +++++++++ internal/codespaces/ssh.go | 116 ++++++++++++++++++++++++++++++ 4 files changed, 260 insertions(+), 70 deletions(-) create mode 100644 cmd/ghcs/logs.go create mode 100644 internal/codespaces/ssh.go diff --git a/cmd/ghcs/logs.go b/cmd/ghcs/logs.go new file mode 100644 index 000000000..3696999d5 --- /dev/null +++ b/cmd/ghcs/logs.go @@ -0,0 +1,95 @@ +package main + +import ( + "bufio" + "context" + "fmt" + "os" + + "github.com/github/ghcs/api" + "github.com/github/ghcs/internal/codespaces" + "github.com/spf13/cobra" +) + +func NewLogsCmd() *cobra.Command { + return &cobra.Command{ + Use: "logs", + Short: "Access Codespace logs", + RunE: func(cmd *cobra.Command, args []string) error { + var codespaceName string + if len(args) > 0 { + codespaceName = args[0] + } + return Logs(codespaceName) + }, + } +} + +func init() { + rootCmd.AddCommand(NewLogsCmd()) +} + +func Logs(codespaceName string) error { + apiClient := api.New(os.Getenv("GITHUB_TOKEN")) + ctx := context.Background() + + user, err := apiClient.GetUser(ctx) + if err != nil { + return fmt.Errorf("getting user: %v", err) + } + + codespace, token, err := codespaces.GetOrChooseCodespace(ctx, apiClient, user, codespaceName) + if err != nil { + return fmt.Errorf("get or choose codespace: %v", err) + } + + lsclient, err := codespaces.ConnectToLiveshare(ctx, apiClient, token, codespace) + if err != nil { + return fmt.Errorf("connecting to liveshare: %v", err) + } + + tunnelPort, connClosed, err := codespaces.MakeSSHTunnel(ctx, lsclient, 0) + if err != nil { + return fmt.Errorf("make ssh tunnel: %v", err) + } + + dst := fmt.Sprintf("%s@localhost", getSSHUser(codespace)) + stdout, err := codespaces.RunCommand( + ctx, tunnelPort, dst, "cat /workspaces/.codespaces/.persistedshare/creation.log", + ) + if err != nil { + return fmt.Errorf("run command: %v", err) + } + + done := make(chan error) + go func() { + scanner := bufio.NewScanner(stdout) + for scanner.Scan() { + fmt.Println(scanner.Text()) + } + + if err := scanner.Err(); err != nil { + done <- fmt.Errorf("error scanning: %v", err) + return + } + + if err := stdout.Close(); err != nil { + done <- fmt.Errorf("close stdout: %v", err) + return + } + done <- nil + }() + + select { + case err := <-connClosed: + if err != nil { + return fmt.Errorf("connection closed: %v", err) + } + case err := <-done: + if err != nil { + return err + } + } + + return nil +} diff --git a/cmd/ghcs/ssh.go b/cmd/ghcs/ssh.go index ef03ba946..23a4c2ca0 100644 --- a/cmd/ghcs/ssh.go +++ b/cmd/ghcs/ssh.go @@ -4,10 +4,7 @@ import ( "bufio" "context" "fmt" - "math/rand" "os" - "os/exec" - "strconv" "strings" "time" @@ -49,37 +46,9 @@ func SSH(sshProfile, codespaceName string, sshServerPort int) error { return fmt.Errorf("error getting user: %v", err) } - var ( - codespace *api.Codespace - token string - ) - - if codespaceName == "" { - codespace, err = codespaces.ChooseCodespace(ctx, apiClient, user) - if err != nil { - if err == codespaces.ErrNoCodespaces { - fmt.Println(err.Error()) - return nil - } - - return fmt.Errorf("error choosing codespace: %v", err) - } - codespaceName = codespace.Name - - token, err = apiClient.GetCodespaceToken(ctx, user.Login, codespaceName) - if err != nil { - return fmt.Errorf("error getting codespace token: %v", err) - } - } else { - token, err = apiClient.GetCodespaceToken(ctx, user.Login, codespaceName) - if err != nil { - return fmt.Errorf("error getting codespace token: %v", err) - } - - codespace, err = apiClient.GetCodespace(ctx, token, user.Login, codespaceName) - if err != nil { - return fmt.Errorf("error getting full codespace details: %v", err) - } + codespace, token, err := codespaces.GetOrChooseCodespace(ctx, apiClient, user, codespaceName) + if err != nil { + return fmt.Errorf("get or choose codespace: %v") } lsclient, err := codespaces.ConnectToLiveshare(ctx, apiClient, token, codespace) @@ -106,56 +75,34 @@ func SSH(sshProfile, codespaceName string, sshServerPort int) error { fmt.Printf("\n") } - server, err := liveshare.NewServer(lsclient) + tunnelPort, tunnelClosed, err := codespaces.MakeSSHTunnel(ctx, lsclient, sshServerPort) if err != nil { - return fmt.Errorf("error creating server: %v", err) + return fmt.Errorf("make ssh tunnel: %v", err) } - rand.Seed(time.Now().Unix()) - port := rand.Intn(9999-2000) + 2000 // improve this obviously - if sshServerPort != 0 { - port = sshServerPort - } - - if err := server.StartSharing(ctx, "sshd", 2222); err != nil { - return fmt.Errorf("error sharing sshd port: %v", err) - } - - portForwarder := liveshare.NewPortForwarder(lsclient, server, port) - go func() { - if err := portForwarder.Start(ctx); err != nil { - panic(fmt.Errorf("error forwarding port: %v", err)) - } - }() - connectDestination := sshProfile if connectDestination == "" { connectDestination = fmt.Sprintf("%s@localhost", getSSHUser(codespace)) } + usingCustomPort := tunnelPort == sshServerPort + connClosed := codespaces.ConnectToTunnel(ctx, tunnelPort, connectDestination, usingCustomPort) + fmt.Println("Ready...") - if err := connect(ctx, port, connectDestination, port == sshServerPort); err != nil { - return fmt.Errorf("error connecting via SSH: %v", err) + select { + case err := <-tunnelClosed: + if err != nil { + return fmt.Errorf("tunnel closed: %v", err) + } + case err := <-connClosed: + if err != nil { + return fmt.Errorf("connection closed: %v", err) + } } return nil } -func connect(ctx context.Context, port int, destination string, setServerPort bool) error { - connectionDetailArgs := []string{"-p", strconv.Itoa(port), "-o", "NoHostAuthenticationForLocalhost=yes"} - - if setServerPort { - fmt.Println("Connection Details: ssh " + destination + " " + strings.Join(connectionDetailArgs, " ")) - } - - args := []string{destination, "-X", "-Y", "-C"} // X11, X11Trust, Compression - cmd := exec.CommandContext(ctx, "ssh", append(args, connectionDetailArgs...)...) - cmd.Stdout = os.Stdout - cmd.Stdin = os.Stdin - cmd.Stderr = os.Stderr - return cmd.Run() -} - func getContainerID(ctx context.Context, terminal *liveshare.Terminal) (string, error) { fmt.Print(".") cmd := terminal.NewCommand( diff --git a/internal/codespaces/codespaces.go b/internal/codespaces/codespaces.go index 6c3517f39..4c62d9aff 100644 --- a/internal/codespaces/codespaces.go +++ b/internal/codespaces/codespaces.go @@ -111,3 +111,35 @@ func ConnectToLiveshare(ctx context.Context, apiClient *api.API, token string, c return lsclient, nil } + +func GetOrChooseCodespace(ctx context.Context, apiClient *api.API, user *api.User, codespaceName string) (codespace *api.Codespace, token string, err error) { + if codespaceName == "" { + codespace, err = ChooseCodespace(ctx, apiClient, user) + if err != nil { + if err == ErrNoCodespaces { + fmt.Println(err.Error()) + return nil, "", nil + } + + return nil, "", fmt.Errorf("choosing codespace: %v", err) + } + codespaceName = codespace.Name + + token, err = apiClient.GetCodespaceToken(ctx, user.Login, codespaceName) + if err != nil { + return nil, "", fmt.Errorf("getting codespace token: %v", err) + } + } else { + token, err = apiClient.GetCodespaceToken(ctx, user.Login, codespaceName) + if err != nil { + return nil, "", fmt.Errorf("getting codespace token for given codespace: %v", err) + } + + codespace, err = apiClient.GetCodespace(ctx, token, user.Login, codespaceName) + if err != nil { + return nil, "", fmt.Errorf("getting full codespace details: %v", err) + } + } + + return codespace, token, nil +} diff --git a/internal/codespaces/ssh.go b/internal/codespaces/ssh.go new file mode 100644 index 000000000..2bb661086 --- /dev/null +++ b/internal/codespaces/ssh.go @@ -0,0 +1,116 @@ +package codespaces + +import ( + "context" + "fmt" + "io" + "math/rand" + "os" + "os/exec" + "strconv" + "strings" + "time" + + "github.com/github/go-liveshare" +) + +func MakeSSHTunnel(ctx context.Context, lsclient *liveshare.Client, serverPort int) (int, <-chan error, error) { + tunnelClosed := make(chan error) + + server, err := liveshare.NewServer(lsclient) + if err != nil { + return 0, nil, fmt.Errorf("new liveshare server: %v", err) + } + + rand.Seed(time.Now().Unix()) + port := rand.Intn(9999-2000) + 2000 // improve this obviously + if serverPort != 0 { + port = serverPort + } + + // TODO(josebalius): This port won't always be 2222 + if err := server.StartSharing(ctx, "sshd", 2222); err != nil { + return 0, nil, fmt.Errorf("sharing sshd port: %v", err) + } + + go func() { + portForwarder := liveshare.NewPortForwarder(lsclient, server, port) + if err := portForwarder.Start(ctx); err != nil { + tunnelClosed <- fmt.Errorf("forwarding port: %v", err) + return + } + tunnelClosed <- nil + }() + + return port, tunnelClosed, nil +} + +func makeSSHArgs(port int, dst, cmd string) ([]string, []string) { + connArgs := []string{"-p", strconv.Itoa(port), "-o", "NoHostAuthenticationForLocalhost=yes"} + cmdArgs := append([]string{dst, "-X", "-Y", "-C"}, connArgs...) // X11, X11Trust, Compression + + if cmd != "" { + cmdArgs = append(cmdArgs, cmd) + } + + return cmdArgs, connArgs +} + +func ConnectToTunnel(ctx context.Context, port int, destination string, usingCustomPort bool) <-chan error { + connClosed := make(chan error) + args, connArgs := makeSSHArgs(port, destination, "") + + if usingCustomPort { + fmt.Println("Connection Details: ssh " + destination + " " + strings.Join(connArgs, " ")) + } + + cmd := exec.CommandContext(ctx, "ssh", args...) + cmd.Stdout = os.Stdout + cmd.Stdin = os.Stdin + cmd.Stderr = os.Stderr + + go func() { + connClosed <- cmd.Run() + }() + + return connClosed +} + +type command struct { + Cmd *exec.Cmd + StdoutPipe io.ReadCloser +} + +func newCommand(cmd *exec.Cmd) (*command, error) { + stdoutPipe, err := cmd.StdoutPipe() + if err != nil { + return nil, fmt.Errorf("create stdout pipe: %v", err) + } + + if err := cmd.Start(); err != nil { + return nil, fmt.Errorf("cmd start: %v", err) + } + + return &command{ + Cmd: cmd, + StdoutPipe: stdoutPipe, + }, nil +} + +func (c *command) Read(p []byte) (int, error) { + return c.StdoutPipe.Read(p) +} + +func (c *command) Close() error { + if err := c.StdoutPipe.Close(); err != nil { + return fmt.Errorf("close stdout: %v", err) + } + + return c.Cmd.Wait() +} + +func RunCommand(ctx context.Context, tunnelPort int, destination, cmdString string) (io.ReadCloser, error) { + args, _ := makeSSHArgs(tunnelPort, destination, cmdString) + cmd := exec.CommandContext(ctx, "ssh", args...) + return newCommand(cmd) +} From be794f1579e8839a75c8cd3149bd28bede0ce63e Mon Sep 17 00:00:00 2001 From: Jose Garcia Date: Thu, 29 Jul 2021 17:09:50 +0000 Subject: [PATCH 057/290] creation log support for cat and tail --- cmd/ghcs/logs.go | 19 +++++++++++++++---- cmd/ghcs/ssh.go | 2 +- 2 files changed, 16 insertions(+), 5 deletions(-) diff --git a/cmd/ghcs/logs.go b/cmd/ghcs/logs.go index 3696999d5..03a7c963a 100644 --- a/cmd/ghcs/logs.go +++ b/cmd/ghcs/logs.go @@ -12,7 +12,9 @@ import ( ) func NewLogsCmd() *cobra.Command { - return &cobra.Command{ + var tail bool + + logsCmd := &cobra.Command{ Use: "logs", Short: "Access Codespace logs", RunE: func(cmd *cobra.Command, args []string) error { @@ -20,16 +22,20 @@ func NewLogsCmd() *cobra.Command { if len(args) > 0 { codespaceName = args[0] } - return Logs(codespaceName) + return Logs(tail, codespaceName) }, } + + logsCmd.Flags().BoolVarP(&tail, "tail", "t", false, "Tail the logs") + + return logsCmd } func init() { rootCmd.AddCommand(NewLogsCmd()) } -func Logs(codespaceName string) error { +func Logs(tail bool, codespaceName string) error { apiClient := api.New(os.Getenv("GITHUB_TOKEN")) ctx := context.Background() @@ -53,9 +59,14 @@ func Logs(codespaceName string) error { return fmt.Errorf("make ssh tunnel: %v", err) } + cmdType := "cat" + if tail { + cmdType = "tail -f" + } + dst := fmt.Sprintf("%s@localhost", getSSHUser(codespace)) stdout, err := codespaces.RunCommand( - ctx, tunnelPort, dst, "cat /workspaces/.codespaces/.persistedshare/creation.log", + ctx, tunnelPort, dst, fmt.Sprintf("%v /workspaces/.codespaces/.persistedshare/creation.log", cmdType), ) if err != nil { return fmt.Errorf("run command: %v", err) diff --git a/cmd/ghcs/ssh.go b/cmd/ghcs/ssh.go index 23a4c2ca0..60fdee498 100644 --- a/cmd/ghcs/ssh.go +++ b/cmd/ghcs/ssh.go @@ -48,7 +48,7 @@ func SSH(sshProfile, codespaceName string, sshServerPort int) error { codespace, token, err := codespaces.GetOrChooseCodespace(ctx, apiClient, user, codespaceName) if err != nil { - return fmt.Errorf("get or choose codespace: %v") + return fmt.Errorf("get or choose codespace: %v", err) } lsclient, err := codespaces.ConnectToLiveshare(ctx, apiClient, token, codespace) From e57b390d4a75e4b97e304b65a122f5874b1e14c7 Mon Sep 17 00:00:00 2001 From: Jose Garcia Date: Tue, 3 Aug 2021 13:42:34 +0000 Subject: [PATCH 058/290] dotfiles status spike --- cmd/ghcs/create.go | 33 +++++++++- internal/codespaces/dotfiles.go | 109 ++++++++++++++++++++++++++++++++ 2 files changed, 141 insertions(+), 1 deletion(-) create mode 100644 internal/codespaces/dotfiles.go diff --git a/cmd/ghcs/create.go b/cmd/ghcs/create.go index 385e5d957..790364c19 100644 --- a/cmd/ghcs/create.go +++ b/cmd/ghcs/create.go @@ -9,6 +9,7 @@ import ( "github.com/AlecAivazis/survey/v2" "github.com/fatih/camelcase" "github.com/github/ghcs/api" + "github.com/github/ghcs/internal/codespaces" "github.com/spf13/cobra" ) @@ -17,7 +18,7 @@ var repo, branch, machine string func newCreateCmd() *cobra.Command { createCmd := &cobra.Command{ Use: "create", - Short: "Create a GitHub Codespace.", + Short: "Create a GitHub Codespace.", RunE: func(cmd *cobra.Command, args []string) error { return Create() }, @@ -80,6 +81,36 @@ func Create() error { return fmt.Errorf("error creating codespace: %v", err) } + states, err := codespaces.PollPostCreateStates(ctx, apiClient, userResult.User, codespace) + if err != nil { + return fmt.Errorf("poll post create states: %v", err) + } + + for { + select { + case stateUpdate := <-states: + if stateUpdate.Err != nil { + return fmt.Errorf("receive state update: %v", err) + } + + var inProgress bool + for _, state := range stateUpdate.PostCreateStates { + fmt.Print(state.Name) + switch state.Status { + case codespaces.PostCreateStateRunning: + inProgress = true + case codespaces.PostCreateStateFailed: + fmt.Print("...Failed") + } + fmt.Print("\n") + } + + if !inProgress { + break + } + } + } + fmt.Println("Codespace created: " + codespace.Name) return nil diff --git a/internal/codespaces/dotfiles.go b/internal/codespaces/dotfiles.go new file mode 100644 index 000000000..75411c35f --- /dev/null +++ b/internal/codespaces/dotfiles.go @@ -0,0 +1,109 @@ +package codespaces + +import ( + "context" + "encoding/json" + "fmt" + "io/ioutil" + "time" + + "github.com/github/ghcs/api" +) + +type PostCreateStateStatus string + +const ( + PostCreateStateRunning PostCreateStateStatus = "running" + PostCreateStateSuccess PostCreateStateStatus = "succeeded" + PostCreateStateFailed PostCreateStateStatus = "failed" +) + +type PostCreateStatesResult struct { + PostCreateStates PostCreateStates + Err error +} + +type PostCreateStates []*PostCreateState + +type PostCreateState struct { + Name string `json:"name"` + Status PostCreateStateStatus `json:"status"` +} + +func PollPostCreateStates(ctx context.Context, apiClient *api.API, user *api.User, codespace *api.Codespace) (<-chan PostCreateStatesResult, error) { + pollch := make(chan PostCreateStatesResult) + + token, err := apiClient.GetCodespaceToken(ctx, user.Login, codespace.Name) + if err != nil { + return nil, fmt.Errorf("getting codespace token: %v", err) + } + + lsclient, err := ConnectToLiveshare(ctx, apiClient, token, codespace) + if err != nil { + return nil, fmt.Errorf("connect to liveshare: %v", err) + } + + tunnelPort, connClosed, err := MakeSSHTunnel(ctx, lsclient, 0) + if err != nil { + return nil, fmt.Errorf("make ssh tunnel: %v", err) + } + + go func() { + t := time.NewTicker(1 * time.Second) + for { + select { + case <-ctx.Done(): + return + case err := <-connClosed: + if err != nil { + pollch <- PostCreateStatesResult{Err: fmt.Errorf("connection closed: %v", err)} + return + } + case <-t.C: + states, err := getPostCreateOutput(ctx, tunnelPort, codespace) + if err != nil { + pollch <- PostCreateStatesResult{Err: fmt.Errorf("get post create output: %v", err)} + return + } + + pollch <- PostCreateStatesResult{ + PostCreateStates: states, + } + } + } + }() + + return pollch, nil +} + +func getPostCreateOutput(ctx context.Context, tunnelPort int, codespace *api.Codespace) (PostCreateStates, error) { + stdout, err := RunCommand( + ctx, tunnelPort, sshDestination(codespace), + "cat /workspaces/.codespaces/shared/postCreateOutput.json", + ) + if err != nil { + return nil, fmt.Errorf("run command: %v", err) + } + + b, err := ioutil.ReadAll(stdout) + if err != nil { + return nil, fmt.Errorf("read output: %v", err) + } + + output := struct { + Steps PostCreateStates `json:"steps"` + }{} + if err := json.Unmarshal(b, &output); err != nil { + return nil, fmt.Errorf("unmarshal output: %v", err) + } + + return output.Steps, nil +} + +func sshDestination(codespace *api.Codespace) string { + user := "codespace" + if codespace.RepositoryNWO == "github/github" { + user = "root" + } + return user + "@localhost" +} From d5003334e36b94d5b81f27ac890ce6ebe4f06bf7 Mon Sep 17 00:00:00 2001 From: Jose Garcia Date: Tue, 3 Aug 2021 13:43:09 +0000 Subject: [PATCH 059/290] Remove secrets export --- cmd/ghcs/ssh.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cmd/ghcs/ssh.go b/cmd/ghcs/ssh.go index 60fdee498..a47d3d309 100644 --- a/cmd/ghcs/ssh.go +++ b/cmd/ghcs/ssh.go @@ -133,7 +133,7 @@ func getContainerID(ctx context.Context, terminal *liveshare.Terminal) (string, } func setupSSH(ctx context.Context, terminal *liveshare.Terminal, containerID, repositoryName string) error { - setupBashProfileCmd := fmt.Sprintf(`echo "cd /workspaces/%v; export $(cat /workspaces/.codespaces/shared/.env | xargs); exec /bin/zsh;" > /home/codespace/.bash_profile`, repositoryName) + setupBashProfileCmd := fmt.Sprintf(`echo "cd /workspaces/%v; exec /bin/zsh;" > /home/codespace/.bash_profile`, repositoryName) fmt.Print(".") compositeCommand := []string{setupBashProfileCmd} From 70f4a7b4b5dabb2133a172123c687bd2708a4ed5 Mon Sep 17 00:00:00 2001 From: Jose Garcia Date: Wed, 4 Aug 2021 13:19:00 +0000 Subject: [PATCH 060/290] Re-introduce secrets export --- cmd/ghcs/ssh.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cmd/ghcs/ssh.go b/cmd/ghcs/ssh.go index a47d3d309..60fdee498 100644 --- a/cmd/ghcs/ssh.go +++ b/cmd/ghcs/ssh.go @@ -133,7 +133,7 @@ func getContainerID(ctx context.Context, terminal *liveshare.Terminal) (string, } func setupSSH(ctx context.Context, terminal *liveshare.Terminal, containerID, repositoryName string) error { - setupBashProfileCmd := fmt.Sprintf(`echo "cd /workspaces/%v; exec /bin/zsh;" > /home/codespace/.bash_profile`, repositoryName) + setupBashProfileCmd := fmt.Sprintf(`echo "cd /workspaces/%v; export $(cat /workspaces/.codespaces/shared/.env | xargs); exec /bin/zsh;" > /home/codespace/.bash_profile`, repositoryName) fmt.Print(".") compositeCommand := []string{setupBashProfileCmd} From 140a54a009b27c41fe7d602b5fbada129fcacd4d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mislav=20Marohni=C4=87?= Date: Wed, 4 Aug 2021 15:56:10 +0200 Subject: [PATCH 061/290] Add machine-readable output formats - Default table output (when stdout is attached to a terminal) stays the same; - When stdout is redirected, output tab-separated values and no header line; - With `--json` flag, output structured JSON data. Example: $ ghcs list --json [ { "Branch": "main", "Created At": "2021-06-10T15:04:46+02:00", "Name": "mislav-playground-jvqj", "Repository": "mislav/playground", "State": "Shutdown" }, { "Branch": "master", "Created At": "2021-07-15T15:51:08+02:00", "Name": "mislav-github-github-pwgg365xv", "Repository": "github/github", "State": "Shutdown" } ] --- cmd/ghcs/delete.go | 6 +++--- cmd/ghcs/list.go | 17 ++++++++++++----- cmd/ghcs/output/format_json.go | 33 +++++++++++++++++++++++++++++++++ cmd/ghcs/output/format_table.go | 28 ++++++++++++++++++++++++++++ cmd/ghcs/output/format_tsv.go | 25 +++++++++++++++++++++++++ 5 files changed, 101 insertions(+), 8 deletions(-) create mode 100644 cmd/ghcs/output/format_json.go create mode 100644 cmd/ghcs/output/format_table.go create mode 100644 cmd/ghcs/output/format_tsv.go diff --git a/cmd/ghcs/delete.go b/cmd/ghcs/delete.go index 2694e9cdf..5cec0f80b 100644 --- a/cmd/ghcs/delete.go +++ b/cmd/ghcs/delete.go @@ -73,7 +73,7 @@ func Delete(codespaceName string) error { fmt.Println("Codespace deleted.") - return List() + return List(&ListOptions{}) } func DeleteAll() error { @@ -103,7 +103,7 @@ func DeleteAll() error { fmt.Printf("Codespace deleted: %s\n", c.Name) } - return List() + return List(&ListOptions{}) } func DeleteByRepo(repo string) error { @@ -143,5 +143,5 @@ func DeleteByRepo(repo string) error { fmt.Printf("No codespace was found for repository: %s\n", repo) } - return List() + return List(&ListOptions{}) } diff --git a/cmd/ghcs/list.go b/cmd/ghcs/list.go index 6db79af97..bbee2aae8 100644 --- a/cmd/ghcs/list.go +++ b/cmd/ghcs/list.go @@ -5,21 +5,28 @@ import ( "fmt" "os" - "github.com/olekukonko/tablewriter" - "github.com/github/ghcs/api" + "github.com/github/ghcs/cmd/ghcs/output" "github.com/spf13/cobra" ) +type ListOptions struct { + AsJSON bool +} + func NewListCmd() *cobra.Command { + opts := &ListOptions{} + listCmd := &cobra.Command{ Use: "list", Short: "List GitHub Codespaces you have on your account.", RunE: func(cmd *cobra.Command, args []string) error { - return List() + return List(opts) }, } + listCmd.Flags().BoolVar(&opts.AsJSON, "json", false, "Output as JSON") + return listCmd } @@ -27,7 +34,7 @@ func init() { rootCmd.AddCommand(NewListCmd()) } -func List() error { +func List(opts *ListOptions) error { apiClient := api.New(os.Getenv("GITHUB_TOKEN")) ctx := context.Background() @@ -46,7 +53,7 @@ func List() error { return nil } - table := tablewriter.NewWriter(os.Stdout) + table := output.NewTable(os.Stdout, opts.AsJSON) table.SetHeader([]string{"Name", "Repository", "Branch", "State", "Created At"}) for _, codespace := range codespaces { table.Append([]string{ diff --git a/cmd/ghcs/output/format_json.go b/cmd/ghcs/output/format_json.go new file mode 100644 index 000000000..37208629c --- /dev/null +++ b/cmd/ghcs/output/format_json.go @@ -0,0 +1,33 @@ +package output + +import ( + "encoding/json" + "io" +) + +type jsonwriter struct { + w io.Writer + pretty bool + cols []string + data []interface{} +} + +func (j *jsonwriter) SetHeader(cols []string) { + j.cols = cols +} + +func (j *jsonwriter) Append(values []string) { + row := make(map[string]string) + for i, v := range values { + row[j.cols[i]] = v + } + j.data = append(j.data, row) +} + +func (j *jsonwriter) Render() { + enc := json.NewEncoder(j.w) + if j.pretty { + enc.SetIndent("", " ") + } + _ = enc.Encode(j.data) +} diff --git a/cmd/ghcs/output/format_table.go b/cmd/ghcs/output/format_table.go new file mode 100644 index 000000000..97e7cab58 --- /dev/null +++ b/cmd/ghcs/output/format_table.go @@ -0,0 +1,28 @@ +package output + +import ( + "io" + "os" + + "github.com/olekukonko/tablewriter" + "golang.org/x/term" +) + +type Table interface { + SetHeader([]string) + Append([]string) + Render() +} + +func NewTable(w io.Writer, asJSON bool) Table { + f, ok := w.(*os.File) + isTTY := ok && term.IsTerminal(int(f.Fd())) + + if asJSON { + return &jsonwriter{w: w, pretty: isTTY} + } + if isTTY { + return tablewriter.NewWriter(w) + } + return &tabwriter{w: w} +} diff --git a/cmd/ghcs/output/format_tsv.go b/cmd/ghcs/output/format_tsv.go new file mode 100644 index 000000000..3f1d226ca --- /dev/null +++ b/cmd/ghcs/output/format_tsv.go @@ -0,0 +1,25 @@ +package output + +import ( + "fmt" + "io" +) + +type tabwriter struct { + w io.Writer +} + +func (j *tabwriter) SetHeader([]string) {} + +func (j *tabwriter) Append(values []string) { + var sep string + for i, v := range values { + if i == 1 { + sep = "\t" + } + fmt.Fprintf(j.w, "%s%s", sep, v) + } + fmt.Fprint(j.w, "\n") +} + +func (j *tabwriter) Render() {} From 76aca39f5bc56e01fc07628903e940d75ae5064d Mon Sep 17 00:00:00 2001 From: Jose Garcia Date: Wed, 4 Aug 2021 17:35:11 +0000 Subject: [PATCH 062/290] Create status support --- cmd/ghcs/create.go | 31 ++++++++++++++++--- cmd/ghcs/logs.go | 2 +- cmd/ghcs/ports.go | 6 ++-- cmd/ghcs/ssh.go | 7 ++--- internal/codespaces/codespaces.go | 11 ++++--- .../codespaces/{dotfiles.go => states.go} | 4 +-- 6 files changed, 41 insertions(+), 20 deletions(-) rename internal/codespaces/{dotfiles.go => states.go} (95%) diff --git a/cmd/ghcs/create.go b/cmd/ghcs/create.go index 790364c19..236ed0e01 100644 --- a/cmd/ghcs/create.go +++ b/cmd/ghcs/create.go @@ -86,6 +86,10 @@ func Create() error { return fmt.Errorf("poll post create states: %v", err) } + var lastState codespaces.PostCreateState + var breakNextState bool + +PollStates: for { select { case stateUpdate := <-states: @@ -95,18 +99,35 @@ func Create() error { var inProgress bool for _, state := range stateUpdate.PostCreateStates { - fmt.Print(state.Name) switch state.Status { case codespaces.PostCreateStateRunning: + if lastState != state { + lastState = state + fmt.Print(state.Name) + } else { + fmt.Print(".") + } + inProgress = true + break case codespaces.PostCreateStateFailed: - fmt.Print("...Failed") + if lastState.Name == state.Name && lastState.Status != state.Status { + lastState = state + fmt.Print(".Failed\n") + } + case codespaces.PostCreateStateSuccess: + if lastState.Name == state.Name && lastState.Status != state.Status { + lastState = state + fmt.Print(".Success\n") + } } - fmt.Print("\n") } - if !inProgress { - break + switch { + case !inProgress && !breakNextState: + breakNextState = true + case !inProgress && breakNextState: + break PollStates } } } diff --git a/cmd/ghcs/logs.go b/cmd/ghcs/logs.go index 03a7c963a..c8c95182e 100644 --- a/cmd/ghcs/logs.go +++ b/cmd/ghcs/logs.go @@ -49,7 +49,7 @@ func Logs(tail bool, codespaceName string) error { return fmt.Errorf("get or choose codespace: %v", err) } - lsclient, err := codespaces.ConnectToLiveshare(ctx, apiClient, token, codespace) + lsclient, err := codespaces.ConnectToLiveshare(ctx, apiClient, user.Login, token, codespace) if err != nil { return fmt.Errorf("connecting to liveshare: %v", err) } diff --git a/cmd/ghcs/ports.go b/cmd/ghcs/ports.go index 77d1b00f7..e573df483 100644 --- a/cmd/ghcs/ports.go +++ b/cmd/ghcs/ports.go @@ -61,7 +61,7 @@ func Ports() error { return fmt.Errorf("error getting codespace token: %v", err) } - liveShareClient, err := codespaces.ConnectToLiveshare(ctx, apiClient, token, codespace) + liveShareClient, err := codespaces.ConnectToLiveshare(ctx, apiClient, user.Login, token, codespace) if err != nil { return fmt.Errorf("error connecting to liveshare: %v", err) } @@ -214,7 +214,7 @@ func updatePortVisibility(codespaceName, sourcePort string, public bool) error { return fmt.Errorf("error getting codespace: %v", err) } - lsclient, err := codespaces.ConnectToLiveshare(ctx, apiClient, token, codespace) + lsclient, err := codespaces.ConnectToLiveshare(ctx, apiClient, user.Login, token, codespace) if err != nil { return fmt.Errorf("error connecting to liveshare: %v", err) } @@ -276,7 +276,7 @@ func forwardPort(codespaceName, sourcePort, destPort string) error { return fmt.Errorf("error getting codespace: %v", err) } - lsclient, err := codespaces.ConnectToLiveshare(ctx, apiClient, token, codespace) + lsclient, err := codespaces.ConnectToLiveshare(ctx, apiClient, user.Login, token, codespace) if err != nil { return fmt.Errorf("error connecting to liveshare: %v", err) } diff --git a/cmd/ghcs/ssh.go b/cmd/ghcs/ssh.go index 60fdee498..caf85c576 100644 --- a/cmd/ghcs/ssh.go +++ b/cmd/ghcs/ssh.go @@ -51,7 +51,7 @@ func SSH(sshProfile, codespaceName string, sshServerPort int) error { return fmt.Errorf("get or choose codespace: %v", err) } - lsclient, err := codespaces.ConnectToLiveshare(ctx, apiClient, token, codespace) + lsclient, err := codespaces.ConnectToLiveshare(ctx, apiClient, user.Login, token, codespace) if err != nil { return fmt.Errorf("error connecting to liveshare: %v", err) } @@ -61,7 +61,7 @@ func SSH(sshProfile, codespaceName string, sshServerPort int) error { return fmt.Errorf("error creating liveshare terminal: %v", err) } - fmt.Println("Preparing SSH...") + fmt.Print("Preparing SSH...") if sshProfile == "" { containerID, err := getContainerID(ctx, terminal) if err != nil { @@ -71,9 +71,8 @@ func SSH(sshProfile, codespaceName string, sshServerPort int) error { if err := setupSSH(ctx, terminal, containerID, codespace.RepositoryName); err != nil { return fmt.Errorf("error creating ssh server: %v", err) } - - fmt.Printf("\n") } + fmt.Print("\n") tunnelPort, tunnelClosed, err := codespaces.MakeSSHTunnel(ctx, lsclient, sshServerPort) if err != nil { diff --git a/internal/codespaces/codespaces.go b/internal/codespaces/codespaces.go index 4c62d9aff..5c40c9931 100644 --- a/internal/codespaces/codespaces.go +++ b/internal/codespaces/codespaces.go @@ -57,9 +57,11 @@ func ChooseCodespace(ctx context.Context, apiClient *api.API, user *api.User) (* return codespace, nil } -func ConnectToLiveshare(ctx context.Context, apiClient *api.API, token string, codespace *api.Codespace) (client *liveshare.Client, err error) { +func ConnectToLiveshare(ctx context.Context, apiClient *api.API, userLogin, token string, codespace *api.Codespace) (client *liveshare.Client, err error) { + var startedCodespace bool if codespace.Environment.State != api.CodespaceEnvironmentStateAvailable { - fmt.Println("Starting your codespace...") // TODO(josebalius): better way of notifying of events + startedCodespace = true + fmt.Print("Starting your codespace...") // TODO(josebalius): better way of notifying of events if err := apiClient.StartCodespace(ctx, token, codespace); err != nil { return nil, fmt.Errorf("error starting codespace: %v", err) } @@ -79,7 +81,7 @@ func ConnectToLiveshare(ctx context.Context, apiClient *api.API, token string, c return nil, errors.New("timed out while waiting for the codespace to start") } - codespace, err = apiClient.GetCodespace(ctx, token, codespace.OwnerLogin, codespace.Name) + codespace, err = apiClient.GetCodespace(ctx, token, userLogin, codespace.Name) if err != nil { return nil, fmt.Errorf("error getting codespace: %v", err) } @@ -87,10 +89,9 @@ func ConnectToLiveshare(ctx context.Context, apiClient *api.API, token string, c retries += 1 } - if retries >= 2 { + if startedCodespace { fmt.Print("\n") } - fmt.Println("Connecting to your codespace...") lsclient, err := liveshare.NewClient( diff --git a/internal/codespaces/dotfiles.go b/internal/codespaces/states.go similarity index 95% rename from internal/codespaces/dotfiles.go rename to internal/codespaces/states.go index 75411c35f..e16e4fc7d 100644 --- a/internal/codespaces/dotfiles.go +++ b/internal/codespaces/states.go @@ -23,7 +23,7 @@ type PostCreateStatesResult struct { Err error } -type PostCreateStates []*PostCreateState +type PostCreateStates []PostCreateState type PostCreateState struct { Name string `json:"name"` @@ -38,7 +38,7 @@ func PollPostCreateStates(ctx context.Context, apiClient *api.API, user *api.Use return nil, fmt.Errorf("getting codespace token: %v", err) } - lsclient, err := ConnectToLiveshare(ctx, apiClient, token, codespace) + lsclient, err := ConnectToLiveshare(ctx, apiClient, user.Login, token, codespace) if err != nil { return nil, fmt.Errorf("connect to liveshare: %v", err) } From 4362b0b241201ed90678b44ac8da6ff075784642 Mon Sep 17 00:00:00 2001 From: Issy Long Date: Wed, 4 Aug 2021 17:39:51 +0100 Subject: [PATCH 063/290] cmd/ghcs/delete: Display the interactive menu when there are no args - Currently the flow to delete a single Codespace is `gh cs list`, copy and paste the Codespace name onto the end of `gh cs delete`. - This improves consistency with other commands by letting the user choose which Codespace they want to delete, interactively. A Codespace name on the command-line still works too. --- cmd/ghcs/delete.go | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/cmd/ghcs/delete.go b/cmd/ghcs/delete.go index 2694e9cdf..f374ef7e6 100644 --- a/cmd/ghcs/delete.go +++ b/cmd/ghcs/delete.go @@ -7,18 +7,20 @@ import ( "os" "github.com/github/ghcs/api" + "github.com/github/ghcs/internal/codespaces" "github.com/spf13/cobra" ) func NewDeleteCmd() *cobra.Command { deleteCmd := &cobra.Command{ - Use: "delete CODESPACE_NAME", + Use: "delete", Short: "Delete a GitHub Codespace.", RunE: func(cmd *cobra.Command, args []string) error { - if len(args) == 0 { - return errors.New("A Codespace name is required.") + var codespaceName string + if len(args) > 0 { + codespaceName = args[0] } - return Delete(args[0]) + return Delete(codespaceName) }, } @@ -62,12 +64,12 @@ func Delete(codespaceName string) error { return fmt.Errorf("error getting user: %v", err) } - token, err := apiClient.GetCodespaceToken(ctx, user.Login, codespaceName) + codespace, token, err := codespaces.GetOrChooseCodespace(ctx, apiClient, user, codespaceName) if err != nil { - return fmt.Errorf("error getting codespace token: %v", err) + return fmt.Errorf("get or choose codespace: %v", err) } - if err := apiClient.DeleteCodespace(ctx, user, token, codespaceName); err != nil { + if err := apiClient.DeleteCodespace(ctx, user, token, codespace.Name); err != nil { return fmt.Errorf("error deleting codespace: %v", err) } From 619862a46bb99cebf4d7698dca741ffb2c596d4b Mon Sep 17 00:00:00 2001 From: Jose Garcia Date: Thu, 5 Aug 2021 15:21:26 +0000 Subject: [PATCH 064/290] initial spike for multiple port support --- cmd/ghcs/ports.go | 79 +++++++++++++++++++++++++++++++++++------------ 1 file changed, 59 insertions(+), 20 deletions(-) diff --git a/cmd/ghcs/ports.go b/cmd/ghcs/ports.go index 77d1b00f7..b51490ddc 100644 --- a/cmd/ghcs/ports.go +++ b/cmd/ghcs/ports.go @@ -15,6 +15,7 @@ import ( "github.com/muhammadmuzzammil1998/jsonc" "github.com/olekukonko/tablewriter" "github.com/spf13/cobra" + "golang.org/x/sync/errgroup" ) func NewPortsCmd() *cobra.Command { @@ -249,18 +250,23 @@ func NewPortsForwardCmd() *cobra.Command { Short: "forward", Long: "forward", RunE: func(cmd *cobra.Command, args []string) error { - if len(args) < 3 { - return errors.New("[codespace_name] [source] [dst] port number are required.") + if len(args) < 2 { + return errors.New("[codespace_name] [source]:[dst] port number are required.") } - return forwardPort(args[0], args[1], args[2]) + return forwardPort(args[0], args[1:]) }, } } -func forwardPort(codespaceName, sourcePort, destPort string) error { +func forwardPort(codespaceName string, ports []string) error { ctx := context.Background() apiClient := api.New(os.Getenv("GITHUB_TOKEN")) + portPairs, err := getPortPairs(ports) + if err != nil { + return fmt.Errorf("get port pairs: %v", err) + } + user, err := apiClient.GetUser(ctx) if err != nil { return fmt.Errorf("error getting user: %v", err) @@ -286,25 +292,58 @@ func forwardPort(codespaceName, sourcePort, destPort string) error { return fmt.Errorf("error creating server: %v", err) } - sourcePortInt, err := strconv.Atoi(sourcePort) - if err != nil { - return fmt.Errorf("error reading source port: %v", err) + g, gctx := errgroup.WithContext(ctx) + for _, portPair := range portPairs { + portPair := portPair + + srcstr := strconv.Itoa(portPair.Src) + if err := server.StartSharing(gctx, "share-"+srcstr, portPair.Src); err != nil { + return fmt.Errorf("start sharing port: %v", err) + } + + g.Go(func() error { + fmt.Println("Forwarding port: " + srcstr + " ==> " + strconv.Itoa(portPair.Dst)) + portForwarder := liveshare.NewPortForwarder(lsclient, server, portPair.Dst) + if err := portForwarder.Start(gctx); err != nil { + return fmt.Errorf("error forwarding port: %v", err) + } + + return nil + }) } - dstPortInt, err := strconv.Atoi(destPort) - if err != nil { - return fmt.Errorf("error reading destination port: %v", err) - } - - if err := server.StartSharing(ctx, "share-"+sourcePort, sourcePortInt); err != nil { - return fmt.Errorf("error sharing source port: %v", err) - } - - fmt.Println("Forwarding port: " + sourcePort + " -> " + destPort) - portForwarder := liveshare.NewPortForwarder(lsclient, server, dstPortInt) - if err := portForwarder.Start(ctx); err != nil { - return fmt.Errorf("error forwarding port: %v", err) + if err := g.Wait(); err != nil { + return err } return nil } + +type portPair struct { + Src, Dst int +} + +func getPortPairs(ports []string) ([]portPair, error) { + pp := make([]portPair, 0, len(ports)) + + for _, portString := range ports { + parts := strings.Split(portString, ":") + if len(parts) < 2 { + return pp, fmt.Errorf("port pair: '%v' is not valid", portString) + } + + srcp, err := strconv.Atoi(parts[0]) + if err != nil { + return pp, fmt.Errorf("convert source port to int: %v", err) + } + + dstp, err := strconv.Atoi(parts[1]) + if err != nil { + return pp, fmt.Errorf("convert dest port to int: %v", err) + } + + pp = append(pp, portPair{srcp, dstp}) + } + + return pp, nil +} From fbf0d286729dd355889a8caad0b80be71e4ae601 Mon Sep 17 00:00:00 2001 From: Jose Garcia Date: Fri, 6 Aug 2021 01:03:03 +0000 Subject: [PATCH 065/290] port forwarding err handling and test refactors --- example/main.go | 120 ---------------------------------------------- port_forwarder.go | 21 +++++--- test/server.go | 105 +++++++++++----------------------------- test/socket.go | 77 +++++++++++++++++++++++++++++ 4 files changed, 121 insertions(+), 202 deletions(-) delete mode 100644 example/main.go create mode 100644 test/socket.go diff --git a/example/main.go b/example/main.go deleted file mode 100644 index e9347bd14..000000000 --- a/example/main.go +++ /dev/null @@ -1,120 +0,0 @@ -package main - -import ( - "bufio" - "context" - "flag" - "fmt" - "log" - "os" - "time" - - "github.com/github/go-liveshare" -) - -var workspaceIdFlag = flag.String("w", "", "workspace session id") - -func init() { - flag.Parse() -} - -func main() { - liveShare, err := liveshare.New( - liveshare.WithWorkspaceID(*workspaceIdFlag), - liveshare.WithToken(os.Getenv("CODESPACE_TOKEN")), - ) - if err != nil { - log.Fatal(fmt.Errorf("error creating liveshare: %v", err)) - } - - ctx := context.Background() - liveShareClient := liveShare.NewClient() - if err := liveShareClient.Join(ctx); err != nil { - log.Fatal(fmt.Errorf("error joining liveshare with client: %v", err)) - } - - terminal, err := liveShareClient.NewTerminal() - if err != nil { - log.Fatal(fmt.Errorf("error creating liveshare terminal")) - } - - containerID, err := getContainerID(ctx, terminal) - if err != nil { - log.Fatal(fmt.Errorf("error getting container id: %v", err)) - } - - if err := setupSSH(ctx, terminal, containerID); err != nil { - log.Fatal(fmt.Errorf("error setting up ssh: %v", err)) - } - - fmt.Println("Starting server...") - - server, err := liveShareClient.NewServer() - if err != nil { - log.Fatal(fmt.Errorf("error creating server: %v", err)) - } - - fmt.Println("Starting sharing...") - if err := server.StartSharing(ctx, "sshd", 2222); err != nil { - log.Fatal(fmt.Errorf("error server sharing: %v", err)) - } - - portForwarder := liveshare.NewLocalPortForwarder(liveShareClient, server, 2222) - - fmt.Println("Listening on port 2222") - if err := portForwarder.Start(ctx); err != nil { - log.Fatal(fmt.Errorf("error forwarding port: %v", err)) - } -} - -func setupSSH(ctx context.Context, terminal *liveshare.Terminal, containerID string) error { - cmd := terminal.NewCommand( - "/", - fmt.Sprintf("/usr/bin/docker exec -t %s /bin/bash -c \"echo -e \\\"testpwd1\\ntestpwd1\\n\\\" | sudo passwd codespace;/usr/local/share/ssh-init.sh\"", containerID), - ) - stream, err := cmd.Run(ctx) - if err != nil { - return fmt.Errorf("error running command: %v", err) - } - - scanner := bufio.NewScanner(stream) - scanner.Scan() - - fmt.Println("> Debug:", scanner.Text()) - if err := scanner.Err(); err != nil { - return fmt.Errorf("error scanning stream: %v", err) - } - - if err := stream.Close(); err != nil { - return fmt.Errorf("error closing stream: %v", err) - } - - time.Sleep(2 * time.Second) - - return nil -} - -func getContainerID(ctx context.Context, terminal *liveshare.Terminal) (string, error) { - cmd := terminal.NewCommand( - "/", - "/usr/bin/docker ps -aq --filter label=Type=codespaces --filter status=running", - ) - stream, err := cmd.Run(ctx) - if err != nil { - return "", fmt.Errorf("error running command: %v", err) - } - - scanner := bufio.NewScanner(stream) - scanner.Scan() - - containerID := scanner.Text() - if err := scanner.Err(); err != nil { - return "", fmt.Errorf("error scanning stream: %v", err) - } - - if err := stream.Close(); err != nil { - return "", fmt.Errorf("error closing stream: %v", err) - } - - return containerID, nil -} diff --git a/port_forwarder.go b/port_forwarder.go index 0a049d586..3a73e3fce 100644 --- a/port_forwarder.go +++ b/port_forwarder.go @@ -33,13 +33,22 @@ func (l *PortForwarder) Start(ctx context.Context) error { return fmt.Errorf("error listening on tcp port: %v", err) } - for { - conn, err := ln.Accept() - if err != nil { - return fmt.Errorf("error accepting incoming connection: %v", err) - } + go func() { + for { + conn, err := ln.Accept() + if err != nil { + l.errCh <- fmt.Errorf("error accepting incoming connection: %v", err) + } - go l.handleConnection(ctx, conn) + go l.handleConnection(ctx, conn) + } + }() + + select { + case err := <-l.errCh: + return err + case <-ctx.Done(): + return ln.Close() } return nil diff --git a/test/server.go b/test/server.go index a52d31ab9..159a2a982 100644 --- a/test/server.go +++ b/test/server.go @@ -5,19 +5,43 @@ import ( "errors" "fmt" "io" - "io/ioutil" "net/http" "net/http/httptest" - "path/filepath" "strings" - "sync" - "time" "github.com/gorilla/websocket" "github.com/sourcegraph/jsonrpc2" "golang.org/x/crypto/ssh" ) +const sshPrivateKey = `-----BEGIN RSA PRIVATE KEY----- +MIIEogIBAAKCAQEAp/Jmzy/HaPNx5Bug09FX5Q/KGY4G9c4DfplhWrn31OQCqNiT +ZSLd46rdXC75liHzE7e5Ic0RJN61cYN9SNArjvEXx2vvs7szhwO7LonwPOvpYpUf +daayrgbr6S46plpx+hEZ1kO/6BqMgFuvnkIVThrEyx5b48ll8zgDABsYrKF8/p1V +SjGfb+bLwjn1NtnZF2prBG5P4ZtMR06HaPglLqBJhmc0ZMG5IZGUE7ew/VrPDqdC +f1v4XvvGiU4BLoKYy4QOhyrCGh9Uk/9u0Ea56M2bh4RqwhbpR8m7TYJZ0DVMLbGW +8C+4lCWp+xRyBNxAQh8qeQVCxYl02hPE4bXLGQIDAQABAoIBAEoVPk6UZ+UexhV2 +LnphNOFhFqgxI1bYWmhE5lHsCKuLLLUoW9RYDgL4gw6/1e7o6N3AxFRpre9Soj0B +YIl28k/qf6/DKAhjQnaDKdV8mVF2Swvmdesi7lyfxv6kGtD4wqApXPlMB2IuG94f +E5e+1MEQQ9DJgoU3eNZR1dj9GuRC3PyzPcNNJ2R/MMGFw3sOOVcLOgAukotoicuL +0SiL51rHPQu8a5/darH9EltN1GFeceJSDDhgqMP5T8Tp7g/c3//H6szon4H9W+uN +Z3UrImJ+teJjFOaVDqN93+J2eQSUk0lCPGQCd4U9I4AGDGyU6ucdcLQ58Aha9gmU +uQwkfKUCgYEA0UkuPOSDE9dbXe+yhsbOwMb1kKzJYgFDKjRTSP7D9BOMZu4YyASo +J95R4DWjePlDopafG2tNJoWX+CwUl7Uld1R3Ex6xHBa2B7hwZj860GZtr7D4mdWc +DTVjczAjp4P0K1MIFYQui1mVJterkjKuePiI6q/27L1c2jIa/39BWBcCgYEAzW8R +MFZamVw3eA2JYSpBuqhQgE5gX5IWrmVJZSUhpAQTNG/A4nxf7WGtjy9p99tm0RMb +ld05+sOmNLrzw8Pq8SBpFOd+MAca7lPLS1A2CoaAHbOqRqrzVcZ4EZ2jB3WjoLoq +yctwslGb9KmrhBCdcwT48aPAYUIJCZdqEen2xE8CgYBoMowvywGrvjwCH9X9njvP +5P7cAfrdrY04FQcmP5lmCtmLYZ267/6couaWv33dPBU9fMpIh3rI5BiOebvi8FBw +AgCq50v8lR4Z5+0mKvLoUSbpIy4SwTRJqzwRXHVT8LF/ZH6Q39egj4Bf716/kjYl +im/4kJVatsjk5a9lZ4EsDwKBgERkJ3rKJNtNggHrr8KzSLKVekdc0GTAw+BHRAny +NKLf4Gzij3pXIbBrhlZW2JZ1amNMUzCvN7AuFlUTsDeKL9saiSE2eCIRG3wgVVu7 +VmJmqJw6xgNEwkHaEvr6Wd4P4euOTtRjcB9NX/gxzDHpPiGelCoN8+vtCgkxaVSR +aV+tAoGAO4HtLOfBAVDNbVXa27aJAjQSUq8qfkwUNJNz+rwgpVQahfiVkyqAPCQM +IfRJxKWb0Wbt9ojw3AowK/k0d3LZA7FS41JSiiGKIllSGb+i7JKqKW7RHLA3VJ/E +Bq5TLNIbUzPVNVwRcGjUYpOhKU6EIw8phTJOvxnUC+g6MVqBP8U= +-----END RSA PRIVATE KEY-----` + type Server struct { password string services map[string]RpcHandleFunc @@ -41,11 +65,7 @@ func NewServer(opts ...ServerOption) (*Server, error) { server.sshConfig = &ssh.ServerConfig{ PasswordCallback: sshPasswordCallback(server.password), } - b, err := ioutil.ReadFile(filepath.Join("test", "private.key")) - if err != nil { - return nil, fmt.Errorf("error reading private.key: %v", err) - } - privateKey, err := ssh.ParsePrivateKey(b) + privateKey, err := ssh.ParsePrivateKey([]byte(sshPrivateKey)) if err != nil { return nil, fmt.Errorf("error parsing key: %v", err) } @@ -221,70 +241,3 @@ func (r *rpcHandler) Handle(ctx context.Context, conn *jsonrpc2.Conn, req *jsonr r.server.errCh <- fmt.Errorf("error replying: %v", err) } } - -type socketConn struct { - *websocket.Conn - - reader io.Reader - writeMutex sync.Mutex - readMutex sync.Mutex -} - -func newSocketConn(conn *websocket.Conn) *socketConn { - return &socketConn{Conn: conn} -} - -func (s *socketConn) Read(b []byte) (int, error) { - s.readMutex.Lock() - defer s.readMutex.Unlock() - - if s.reader == nil { - msgType, r, err := s.Conn.NextReader() - if err != nil { - return 0, fmt.Errorf("error getting next reader: %v", err) - } - if msgType != websocket.BinaryMessage { - return 0, fmt.Errorf("invalid message type") - } - s.reader = r - } - - bytesRead, err := s.reader.Read(b) - if err != nil { - s.reader = nil - - if err == io.EOF { - err = nil - } - } - - return bytesRead, err -} - -func (s *socketConn) Write(b []byte) (int, error) { - s.writeMutex.Lock() - defer s.writeMutex.Unlock() - - w, err := s.Conn.NextWriter(websocket.BinaryMessage) - if err != nil { - return 0, fmt.Errorf("error getting next writer: %v", err) - } - - n, err := w.Write(b) - if err != nil { - return 0, fmt.Errorf("error writing: %v", err) - } - - if err := w.Close(); err != nil { - return 0, fmt.Errorf("error closing writer: %v", err) - } - - return n, nil -} - -func (s *socketConn) SetDeadline(deadline time.Time) error { - if err := s.Conn.SetReadDeadline(deadline); err != nil { - return err - } - return s.Conn.SetWriteDeadline(deadline) -} diff --git a/test/socket.go b/test/socket.go new file mode 100644 index 000000000..9a2d92491 --- /dev/null +++ b/test/socket.go @@ -0,0 +1,77 @@ +package livesharetest + +import ( + "fmt" + "io" + "sync" + "time" + + "github.com/gorilla/websocket" +) + +type socketConn struct { + *websocket.Conn + + reader io.Reader + writeMutex sync.Mutex + readMutex sync.Mutex +} + +func newSocketConn(conn *websocket.Conn) *socketConn { + return &socketConn{Conn: conn} +} + +func (s *socketConn) Read(b []byte) (int, error) { + s.readMutex.Lock() + defer s.readMutex.Unlock() + + if s.reader == nil { + msgType, r, err := s.Conn.NextReader() + if err != nil { + return 0, fmt.Errorf("error getting next reader: %v", err) + } + if msgType != websocket.BinaryMessage { + return 0, fmt.Errorf("invalid message type") + } + s.reader = r + } + + bytesRead, err := s.reader.Read(b) + if err != nil { + s.reader = nil + + if err == io.EOF { + err = nil + } + } + + return bytesRead, err +} + +func (s *socketConn) Write(b []byte) (int, error) { + s.writeMutex.Lock() + defer s.writeMutex.Unlock() + + w, err := s.Conn.NextWriter(websocket.BinaryMessage) + if err != nil { + return 0, fmt.Errorf("error getting next writer: %v", err) + } + + n, err := w.Write(b) + if err != nil { + return 0, fmt.Errorf("error writing: %v", err) + } + + if err := w.Close(); err != nil { + return 0, fmt.Errorf("error closing writer: %v", err) + } + + return n, nil +} + +func (s *socketConn) SetDeadline(deadline time.Time) error { + if err := s.Conn.SetReadDeadline(deadline); err != nil { + return err + } + return s.Conn.SetWriteDeadline(deadline) +} From eb2a17645056c6b34638cf32c9b0d39e068e1e69 Mon Sep 17 00:00:00 2001 From: Jose Garcia Date: Sat, 7 Aug 2021 17:54:43 +0000 Subject: [PATCH 066/290] remove err print --- port_forwarder.go | 1 - 1 file changed, 1 deletion(-) diff --git a/port_forwarder.go b/port_forwarder.go index 3a73e3fce..06c164e8d 100644 --- a/port_forwarder.go +++ b/port_forwarder.go @@ -63,7 +63,6 @@ func (l *PortForwarder) handleConnection(ctx context.Context, conn net.Conn) { copyConn := func(writer io.Writer, reader io.Reader) { if _, err := io.Copy(writer, reader); err != nil { - fmt.Println(err) channel.Close() conn.Close() if err != io.EOF { From db95f2f71f5c390dac86811f120a193074dffbb6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mislav=20Marohni=C4=87?= Date: Thu, 12 Aug 2021 14:35:49 +0200 Subject: [PATCH 067/290] Add machine-readable output functionality to `ports` command --- cmd/ghcs/output/format_json.go | 24 +++++++++++++++++- cmd/ghcs/output/format_table.go | 9 ++++--- cmd/ghcs/output/logger.go | 45 +++++++++++++++++++++++++++++++++ cmd/ghcs/ports.go | 37 ++++++++++++++------------- 4 files changed, 93 insertions(+), 22 deletions(-) create mode 100644 cmd/ghcs/output/logger.go diff --git a/cmd/ghcs/output/format_json.go b/cmd/ghcs/output/format_json.go index 37208629c..8488e8dfa 100644 --- a/cmd/ghcs/output/format_json.go +++ b/cmd/ghcs/output/format_json.go @@ -3,6 +3,8 @@ package output import ( "encoding/json" "io" + "strings" + "unicode" ) type jsonwriter struct { @@ -19,7 +21,7 @@ func (j *jsonwriter) SetHeader(cols []string) { func (j *jsonwriter) Append(values []string) { row := make(map[string]string) for i, v := range values { - row[j.cols[i]] = v + row[camelize(j.cols[i])] = v } j.data = append(j.data, row) } @@ -31,3 +33,23 @@ func (j *jsonwriter) Render() { } _ = enc.Encode(j.data) } + +func camelize(s string) string { + var b strings.Builder + capitalizeNext := false + for i, r := range s { + if r == ' ' { + capitalizeNext = true + continue + } + if capitalizeNext { + b.WriteRune(unicode.ToUpper(r)) + capitalizeNext = false + } else if i == 0 { + b.WriteRune(unicode.ToLower(r)) + } else { + b.WriteRune(r) + } + } + return b.String() +} diff --git a/cmd/ghcs/output/format_table.go b/cmd/ghcs/output/format_table.go index 97e7cab58..e0345672d 100644 --- a/cmd/ghcs/output/format_table.go +++ b/cmd/ghcs/output/format_table.go @@ -15,9 +15,7 @@ type Table interface { } func NewTable(w io.Writer, asJSON bool) Table { - f, ok := w.(*os.File) - isTTY := ok && term.IsTerminal(int(f.Fd())) - + isTTY := isTTY(w) if asJSON { return &jsonwriter{w: w, pretty: isTTY} } @@ -26,3 +24,8 @@ func NewTable(w io.Writer, asJSON bool) Table { } return &tabwriter{w: w} } + +func isTTY(w io.Writer) bool { + f, ok := w.(*os.File) + return ok && term.IsTerminal(int(f.Fd())) +} diff --git a/cmd/ghcs/output/logger.go b/cmd/ghcs/output/logger.go new file mode 100644 index 000000000..32d05acc8 --- /dev/null +++ b/cmd/ghcs/output/logger.go @@ -0,0 +1,45 @@ +package output + +import ( + "fmt" + "io" +) + +func NewLogger(stdout, stderr io.Writer, disabled bool) *Logger { + return &Logger{ + out: stdout, + errout: stderr, + enabled: !disabled && isTTY(stdout), + } +} + +type Logger struct { + out io.Writer + errout io.Writer + enabled bool +} + +func (l *Logger) Print(v ...interface{}) (int, error) { + if !l.enabled { + return 0, nil + } + return fmt.Fprint(l.out, v...) +} + +func (l *Logger) Println(v ...interface{}) (int, error) { + if !l.enabled { + return 0, nil + } + return fmt.Fprintln(l.out, v...) +} + +func (l *Logger) Printf(f string, v ...interface{}) (int, error) { + if !l.enabled { + return 0, nil + } + return fmt.Fprintf(l.out, f, v...) +} + +func (l *Logger) Errorf(f string, v ...interface{}) (int, error) { + return fmt.Fprintf(l.errout, f, v...) +} diff --git a/cmd/ghcs/ports.go b/cmd/ghcs/ports.go index 77d1b00f7..fbbffcf1d 100644 --- a/cmd/ghcs/ports.go +++ b/cmd/ghcs/ports.go @@ -10,25 +10,36 @@ import ( "strings" "github.com/github/ghcs/api" + "github.com/github/ghcs/cmd/ghcs/output" "github.com/github/ghcs/internal/codespaces" "github.com/github/go-liveshare" "github.com/muhammadmuzzammil1998/jsonc" - "github.com/olekukonko/tablewriter" "github.com/spf13/cobra" ) +type PortsOptions struct { + CodespaceName string + AsJSON bool +} + func NewPortsCmd() *cobra.Command { + opts := &PortsOptions{} + portsCmd := &cobra.Command{ Use: "ports", Short: "Forward ports from a GitHub Codespace.", RunE: func(cmd *cobra.Command, args []string) error { - return Ports() + return Ports(opts) }, } + portsCmd.Flags().StringVarP(&opts.CodespaceName, "name", "n", "", "Name of Codespace to use") + portsCmd.Flags().BoolVar(&opts.AsJSON, "json", false, "Output as JSON") + portsCmd.AddCommand(NewPortsPublicCmd()) portsCmd.AddCommand(NewPortsPrivateCmd()) portsCmd.AddCommand(NewPortsForwardCmd()) + return portsCmd } @@ -36,16 +47,17 @@ func init() { rootCmd.AddCommand(NewPortsCmd()) } -func Ports() error { +func Ports(opts *PortsOptions) error { apiClient := api.New(os.Getenv("GITHUB_TOKEN")) ctx := context.Background() + log := output.NewLogger(os.Stdout, os.Stderr, opts.AsJSON) user, err := apiClient.GetUser(ctx) if err != nil { return fmt.Errorf("error getting user: %v", err) } - codespace, err := codespaces.ChooseCodespace(ctx, apiClient, user) + codespace, token, err := codespaces.GetOrChooseCodespace(ctx, apiClient, user, opts.CodespaceName) if err != nil { if err == codespaces.ErrNoCodespaces { fmt.Println(err.Error()) @@ -56,33 +68,23 @@ func Ports() error { devContainerCh := getDevContainer(ctx, apiClient, codespace) - token, err := apiClient.GetCodespaceToken(ctx, user.Login, codespace.Name) - if err != nil { - return fmt.Errorf("error getting codespace token: %v", err) - } - liveShareClient, err := codespaces.ConnectToLiveshare(ctx, apiClient, token, codespace) if err != nil { return fmt.Errorf("error connecting to liveshare: %v", err) } - fmt.Println("Loading ports...") + log.Println("Loading ports...") ports, err := getPorts(ctx, liveShareClient) if err != nil { return fmt.Errorf("error getting ports: %v", err) } - if len(ports) == 0 { - fmt.Println("This codespace has no open ports") - return nil - } - devContainerResult := <-devContainerCh if devContainerResult.Err != nil { - fmt.Printf("Failed to get port names: %v\n", devContainerResult.Err.Error()) + _, _ = log.Errorf("Failed to get port names: %v\n", devContainerResult.Err.Error()) } - table := tablewriter.NewWriter(os.Stdout) + table := output.NewTable(os.Stdout, opts.AsJSON) table.SetHeader([]string{"Label", "Source Port", "Destination Port", "Public", "Browse URL"}) for _, port := range ports { sourcePort := strconv.Itoa(port.SourcePort) @@ -104,7 +106,6 @@ func Ports() error { table.Render() return nil - } func getPorts(ctx context.Context, lsclient *liveshare.Client) (liveshare.Ports, error) { From 41e223869e0d8e0c96f79eb6801875b3b39f1109 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mislav=20Marohni=C4=87?= Date: Thu, 12 Aug 2021 14:37:06 +0200 Subject: [PATCH 068/290] Fix mapping port numbers to labels --- cmd/ghcs/ports.go | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/cmd/ghcs/ports.go b/cmd/ghcs/ports.go index fbbffcf1d..27b4e0614 100644 --- a/cmd/ghcs/ports.go +++ b/cmd/ghcs/ports.go @@ -1,6 +1,7 @@ package main import ( + "bytes" "context" "encoding/json" "errors" @@ -149,7 +150,7 @@ func getDevContainer(ctx context.Context, apiClient *api.API, codespace *api.Cod return } - convertedJSON := jsonc.ToJSON(contents) + convertedJSON := normalizeJSON(jsonc.ToJSON(contents)) if !jsonc.Valid(convertedJSON) { ch <- devContainerResult{nil, errors.New("failed to convert json to standard json")} return @@ -309,3 +310,8 @@ func forwardPort(codespaceName, sourcePort, destPort string) error { return nil } + +func normalizeJSON(j []byte) []byte { + // remove trailing commas + return bytes.ReplaceAll(j, []byte("},}"), []byte("}}")) +} From 20d75f0ff9198f952311a8aa21bf648df6147767 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mislav=20Marohni=C4=87?= Date: Thu, 12 Aug 2021 14:37:23 +0200 Subject: [PATCH 069/290] Normalize logging, output, and error reporting - Return errors as errors, not print to stdout and return nil - Ensure errors and warnings are always written to stderr, not stout - Do not print progress to stdout unless stdout is a terminal --- cmd/ghcs/code.go | 3 +-- cmd/ghcs/create.go | 12 +++++++----- cmd/ghcs/delete.go | 12 ++++++++---- cmd/ghcs/list.go | 5 ----- cmd/ghcs/logs.go | 4 +++- cmd/ghcs/main.go | 2 +- cmd/ghcs/ports.go | 30 ++++++++++++++++-------------- cmd/ghcs/ssh.go | 16 +++++----------- internal/codespaces/codespaces.go | 19 +++++++++++-------- 9 files changed, 52 insertions(+), 51 deletions(-) diff --git a/cmd/ghcs/code.go b/cmd/ghcs/code.go index ccb4788ee..4880436ed 100644 --- a/cmd/ghcs/code.go +++ b/cmd/ghcs/code.go @@ -43,8 +43,7 @@ func Code(codespaceName string) error { codespace, err := codespaces.ChooseCodespace(ctx, apiClient, user) if err != nil { if err == codespaces.ErrNoCodespaces { - fmt.Println(err.Error()) - return nil + return err } return fmt.Errorf("error choosing codespace: %v", err) } diff --git a/cmd/ghcs/create.go b/cmd/ghcs/create.go index 385e5d957..b4a4fcb1f 100644 --- a/cmd/ghcs/create.go +++ b/cmd/ghcs/create.go @@ -2,6 +2,7 @@ package main import ( "context" + "errors" "fmt" "os" "strings" @@ -9,6 +10,7 @@ import ( "github.com/AlecAivazis/survey/v2" "github.com/fatih/camelcase" "github.com/github/ghcs/api" + "github.com/github/ghcs/cmd/ghcs/output" "github.com/spf13/cobra" ) @@ -17,7 +19,7 @@ var repo, branch, machine string func newCreateCmd() *cobra.Command { createCmd := &cobra.Command{ Use: "create", - Short: "Create a GitHub Codespace.", + Short: "Create a GitHub Codespace.", RunE: func(cmd *cobra.Command, args []string) error { return Create() }, @@ -39,6 +41,7 @@ func Create() error { apiClient := api.New(os.Getenv("GITHUB_TOKEN")) locationCh := getLocation(ctx, apiClient) userCh := getUser(ctx, apiClient) + log := output.NewLogger(os.Stdout, os.Stderr, false) repo, err := getRepoName() if err != nil { @@ -69,18 +72,17 @@ func Create() error { return fmt.Errorf("error getting machine type: %v", err) } if machine == "" { - fmt.Println("There are no available machine types for this repository") - return nil + return errors.New("There are no available machine types for this repository") } - fmt.Println("Creating your codespace...") + log.Println("Creating your codespace...") codespace, err := apiClient.CreateCodespace(ctx, userResult.User, repository, machine, branch, locationResult.Location) if err != nil { return fmt.Errorf("error creating codespace: %v", err) } - fmt.Println("Codespace created: " + codespace.Name) + log.Printf("Codespace created: %s\n", codespace.Name) return nil } diff --git a/cmd/ghcs/delete.go b/cmd/ghcs/delete.go index df7a3f1be..d625a1ca3 100644 --- a/cmd/ghcs/delete.go +++ b/cmd/ghcs/delete.go @@ -7,6 +7,7 @@ import ( "os" "github.com/github/ghcs/api" + "github.com/github/ghcs/cmd/ghcs/output" "github.com/github/ghcs/internal/codespaces" "github.com/spf13/cobra" ) @@ -58,6 +59,7 @@ func init() { func Delete(codespaceName string) error { apiClient := api.New(os.Getenv("GITHUB_TOKEN")) ctx := context.Background() + log := output.NewLogger(os.Stdout, os.Stderr, false) user, err := apiClient.GetUser(ctx) if err != nil { @@ -73,7 +75,7 @@ func Delete(codespaceName string) error { return fmt.Errorf("error deleting codespace: %v", err) } - fmt.Println("Codespace deleted.") + log.Println("Codespace deleted.") return List(&ListOptions{}) } @@ -81,6 +83,7 @@ func Delete(codespaceName string) error { func DeleteAll() error { apiClient := api.New(os.Getenv("GITHUB_TOKEN")) ctx := context.Background() + log := output.NewLogger(os.Stdout, os.Stderr, false) user, err := apiClient.GetUser(ctx) if err != nil { @@ -102,7 +105,7 @@ func DeleteAll() error { return fmt.Errorf("error deleting codespace: %v", err) } - fmt.Printf("Codespace deleted: %s\n", c.Name) + log.Printf("Codespace deleted: %s\n", c.Name) } return List(&ListOptions{}) @@ -111,6 +114,7 @@ func DeleteAll() error { func DeleteByRepo(repo string) error { apiClient := api.New(os.Getenv("GITHUB_TOKEN")) ctx := context.Background() + log := output.NewLogger(os.Stdout, os.Stderr, false) user, err := apiClient.GetUser(ctx) if err != nil { @@ -138,11 +142,11 @@ func DeleteByRepo(repo string) error { return fmt.Errorf("error deleting codespace: %v", err) } - fmt.Printf("Codespace deleted: %s\n", c.Name) + log.Printf("Codespace deleted: %s\n", c.Name) } if !deleted { - fmt.Printf("No codespace was found for repository: %s\n", repo) + return fmt.Errorf("No codespace was found for repository: %s", repo) } return List(&ListOptions{}) diff --git a/cmd/ghcs/list.go b/cmd/ghcs/list.go index bbee2aae8..acc01b7ad 100644 --- a/cmd/ghcs/list.go +++ b/cmd/ghcs/list.go @@ -48,11 +48,6 @@ func List(opts *ListOptions) error { return fmt.Errorf("error getting codespaces: %v", err) } - if len(codespaces) == 0 { - fmt.Println("You have no codespaces.") - return nil - } - table := output.NewTable(os.Stdout, opts.AsJSON) table.SetHeader([]string{"Name", "Repository", "Branch", "State", "Created At"}) for _, codespace := range codespaces { diff --git a/cmd/ghcs/logs.go b/cmd/ghcs/logs.go index 03a7c963a..45fd8bca8 100644 --- a/cmd/ghcs/logs.go +++ b/cmd/ghcs/logs.go @@ -7,6 +7,7 @@ import ( "os" "github.com/github/ghcs/api" + "github.com/github/ghcs/cmd/ghcs/output" "github.com/github/ghcs/internal/codespaces" "github.com/spf13/cobra" ) @@ -38,6 +39,7 @@ func init() { func Logs(tail bool, codespaceName string) error { apiClient := api.New(os.Getenv("GITHUB_TOKEN")) ctx := context.Background() + log := output.NewLogger(os.Stdout, os.Stderr, false) user, err := apiClient.GetUser(ctx) if err != nil { @@ -49,7 +51,7 @@ func Logs(tail bool, codespaceName string) error { return fmt.Errorf("get or choose codespace: %v", err) } - lsclient, err := codespaces.ConnectToLiveshare(ctx, apiClient, token, codespace) + lsclient, err := codespaces.ConnectToLiveshare(ctx, log, apiClient, token, codespace) if err != nil { return fmt.Errorf("connecting to liveshare: %v", err) } diff --git a/cmd/ghcs/main.go b/cmd/ghcs/main.go index 00e8be894..a2617788a 100644 --- a/cmd/ghcs/main.go +++ b/cmd/ghcs/main.go @@ -22,7 +22,7 @@ var rootCmd = &cobra.Command{ func Execute() { if os.Getenv("GITHUB_TOKEN") == "" { - fmt.Println("The GITHUB_TOKEN environment variable is required. Create a Personal Access Token at https://github.com/settings/tokens/new?scopes=repo and make sure to enable SSO for the GitHub organization after creating the token.") + fmt.Fprintln(os.Stderr, "The GITHUB_TOKEN environment variable is required. Create a Personal Access Token at https://github.com/settings/tokens/new?scopes=repo and make sure to enable SSO for the GitHub organization after creating the token.") os.Exit(1) } diff --git a/cmd/ghcs/ports.go b/cmd/ghcs/ports.go index 27b4e0614..613a1834b 100644 --- a/cmd/ghcs/ports.go +++ b/cmd/ghcs/ports.go @@ -61,15 +61,14 @@ func Ports(opts *PortsOptions) error { codespace, token, err := codespaces.GetOrChooseCodespace(ctx, apiClient, user, opts.CodespaceName) if err != nil { if err == codespaces.ErrNoCodespaces { - fmt.Println(err.Error()) - return nil + return err } return fmt.Errorf("error choosing codespace: %v", err) } devContainerCh := getDevContainer(ctx, apiClient, codespace) - liveShareClient, err := codespaces.ConnectToLiveshare(ctx, apiClient, token, codespace) + liveShareClient, err := codespaces.ConnectToLiveshare(ctx, log, apiClient, token, codespace) if err != nil { return fmt.Errorf("error connecting to liveshare: %v", err) } @@ -177,7 +176,8 @@ func NewPortsPublicCmd() *cobra.Command { return errors.New("[codespace_name] [source] port number are required.") } - return updatePortVisibility(args[0], args[1], true) + log := output.NewLogger(os.Stdout, os.Stderr, false) + return updatePortVisibility(log, args[0], args[1], true) }, } } @@ -192,12 +192,13 @@ func NewPortsPrivateCmd() *cobra.Command { return errors.New("[codespace_name] [source] port number are required.") } - return updatePortVisibility(args[0], args[1], false) + log := output.NewLogger(os.Stdout, os.Stderr, false) + return updatePortVisibility(log, args[0], args[1], false) }, } } -func updatePortVisibility(codespaceName, sourcePort string, public bool) error { +func updatePortVisibility(log *output.Logger, codespaceName, sourcePort string, public bool) error { ctx := context.Background() apiClient := api.New(os.Getenv("GITHUB_TOKEN")) @@ -216,7 +217,7 @@ func updatePortVisibility(codespaceName, sourcePort string, public bool) error { return fmt.Errorf("error getting codespace: %v", err) } - lsclient, err := codespaces.ConnectToLiveshare(ctx, apiClient, token, codespace) + lsclient, err := codespaces.ConnectToLiveshare(ctx, log, apiClient, token, codespace) if err != nil { return fmt.Errorf("error connecting to liveshare: %v", err) } @@ -236,11 +237,10 @@ func updatePortVisibility(codespaceName, sourcePort string, public bool) error { } state := "PUBLIC" - if public == false { + if !public { state = "PRIVATE" } - - fmt.Println(fmt.Sprintf("Port %s is now %s.", sourcePort, state)) + log.Printf("Port %s is now %s.\n", sourcePort, state) return nil } @@ -254,12 +254,14 @@ func NewPortsForwardCmd() *cobra.Command { if len(args) < 3 { return errors.New("[codespace_name] [source] [dst] port number are required.") } - return forwardPort(args[0], args[1], args[2]) + + log := output.NewLogger(os.Stdout, os.Stderr, false) + return forwardPort(log, args[0], args[1], args[2]) }, } } -func forwardPort(codespaceName, sourcePort, destPort string) error { +func forwardPort(log *output.Logger, codespaceName, sourcePort, destPort string) error { ctx := context.Background() apiClient := api.New(os.Getenv("GITHUB_TOKEN")) @@ -278,7 +280,7 @@ func forwardPort(codespaceName, sourcePort, destPort string) error { return fmt.Errorf("error getting codespace: %v", err) } - lsclient, err := codespaces.ConnectToLiveshare(ctx, apiClient, token, codespace) + lsclient, err := codespaces.ConnectToLiveshare(ctx, log, apiClient, token, codespace) if err != nil { return fmt.Errorf("error connecting to liveshare: %v", err) } @@ -302,7 +304,7 @@ func forwardPort(codespaceName, sourcePort, destPort string) error { return fmt.Errorf("error sharing source port: %v", err) } - fmt.Println("Forwarding port: " + sourcePort + " -> " + destPort) + log.Println("Forwarding port: " + sourcePort + " -> " + destPort) portForwarder := liveshare.NewPortForwarder(lsclient, server, dstPortInt) if err := portForwarder.Start(ctx); err != nil { return fmt.Errorf("error forwarding port: %v", err) diff --git a/cmd/ghcs/ssh.go b/cmd/ghcs/ssh.go index 60fdee498..372061c55 100644 --- a/cmd/ghcs/ssh.go +++ b/cmd/ghcs/ssh.go @@ -9,6 +9,7 @@ import ( "time" "github.com/github/ghcs/api" + "github.com/github/ghcs/cmd/ghcs/output" "github.com/github/ghcs/internal/codespaces" "github.com/github/go-liveshare" "github.com/spf13/cobra" @@ -40,6 +41,7 @@ func init() { func SSH(sshProfile, codespaceName string, sshServerPort int) error { apiClient := api.New(os.Getenv("GITHUB_TOKEN")) ctx := context.Background() + log := output.NewLogger(os.Stdout, os.Stderr, false) user, err := apiClient.GetUser(ctx) if err != nil { @@ -51,7 +53,7 @@ func SSH(sshProfile, codespaceName string, sshServerPort int) error { return fmt.Errorf("get or choose codespace: %v", err) } - lsclient, err := codespaces.ConnectToLiveshare(ctx, apiClient, token, codespace) + lsclient, err := codespaces.ConnectToLiveshare(ctx, log, apiClient, token, codespace) if err != nil { return fmt.Errorf("error connecting to liveshare: %v", err) } @@ -61,7 +63,7 @@ func SSH(sshProfile, codespaceName string, sshServerPort int) error { return fmt.Errorf("error creating liveshare terminal: %v", err) } - fmt.Println("Preparing SSH...") + log.Println("Preparing SSH...") if sshProfile == "" { containerID, err := getContainerID(ctx, terminal) if err != nil { @@ -71,8 +73,6 @@ func SSH(sshProfile, codespaceName string, sshServerPort int) error { if err := setupSSH(ctx, terminal, containerID, codespace.RepositoryName); err != nil { return fmt.Errorf("error creating ssh server: %v", err) } - - fmt.Printf("\n") } tunnelPort, tunnelClosed, err := codespaces.MakeSSHTunnel(ctx, lsclient, sshServerPort) @@ -88,7 +88,7 @@ func SSH(sshProfile, codespaceName string, sshServerPort int) error { usingCustomPort := tunnelPort == sshServerPort connClosed := codespaces.ConnectToTunnel(ctx, tunnelPort, connectDestination, usingCustomPort) - fmt.Println("Ready...") + log.Println("Ready...") select { case err := <-tunnelClosed: if err != nil { @@ -104,7 +104,6 @@ func SSH(sshProfile, codespaceName string, sshServerPort int) error { } func getContainerID(ctx context.Context, terminal *liveshare.Terminal) (string, error) { - fmt.Print(".") cmd := terminal.NewCommand( "/", "/usr/bin/docker ps -aq --filter label=Type=codespaces --filter status=running", @@ -114,17 +113,14 @@ func getContainerID(ctx context.Context, terminal *liveshare.Terminal) (string, return "", fmt.Errorf("error running command: %v", err) } - fmt.Print(".") scanner := bufio.NewScanner(stream) scanner.Scan() - fmt.Print(".") containerID := scanner.Text() if err := scanner.Err(); err != nil { return "", fmt.Errorf("error scanning stream: %v", err) } - fmt.Print(".") if err := stream.Close(); err != nil { return "", fmt.Errorf("error closing stream: %v", err) } @@ -135,7 +131,6 @@ func getContainerID(ctx context.Context, terminal *liveshare.Terminal) (string, func setupSSH(ctx context.Context, terminal *liveshare.Terminal, containerID, repositoryName string) error { setupBashProfileCmd := fmt.Sprintf(`echo "cd /workspaces/%v; export $(cat /workspaces/.codespaces/shared/.env | xargs); exec /bin/zsh;" > /home/codespace/.bash_profile`, repositoryName) - fmt.Print(".") compositeCommand := []string{setupBashProfileCmd} cmd := terminal.NewCommand( "/", @@ -146,7 +141,6 @@ func setupSSH(ctx context.Context, terminal *liveshare.Terminal, containerID, re return fmt.Errorf("error running command: %v", err) } - fmt.Print(".") if err := stream.Close(); err != nil { return fmt.Errorf("error closing stream: %v", err) } diff --git a/internal/codespaces/codespaces.go b/internal/codespaces/codespaces.go index 4c62d9aff..48369cfa0 100644 --- a/internal/codespaces/codespaces.go +++ b/internal/codespaces/codespaces.go @@ -57,9 +57,14 @@ func ChooseCodespace(ctx context.Context, apiClient *api.API, user *api.User) (* return codespace, nil } -func ConnectToLiveshare(ctx context.Context, apiClient *api.API, token string, codespace *api.Codespace) (client *liveshare.Client, err error) { +type logger interface { + Print(v ...interface{}) (int, error) + Println(v ...interface{}) (int, error) +} + +func ConnectToLiveshare(ctx context.Context, log logger, apiClient *api.API, token string, codespace *api.Codespace) (client *liveshare.Client, err error) { if codespace.Environment.State != api.CodespaceEnvironmentStateAvailable { - fmt.Println("Starting your codespace...") // TODO(josebalius): better way of notifying of events + log.Println("Starting your codespace...") if err := apiClient.StartCodespace(ctx, token, codespace); err != nil { return nil, fmt.Errorf("error starting codespace: %v", err) } @@ -69,7 +74,7 @@ func ConnectToLiveshare(ctx context.Context, apiClient *api.API, token string, c for codespace.Environment.Connection.SessionID == "" || codespace.Environment.State != api.CodespaceEnvironmentStateAvailable { if retries > 1 { if retries%2 == 0 { - fmt.Print(".") + log.Print(".") } time.Sleep(1 * time.Second) @@ -88,10 +93,10 @@ func ConnectToLiveshare(ctx context.Context, apiClient *api.API, token string, c } if retries >= 2 { - fmt.Print("\n") + log.Print("\n") } - fmt.Println("Connecting to your codespace...") + log.Println("Connecting to your codespace...") lsclient, err := liveshare.NewClient( liveshare.WithConnection(liveshare.Connection{ @@ -117,10 +122,8 @@ func GetOrChooseCodespace(ctx context.Context, apiClient *api.API, user *api.Use codespace, err = ChooseCodespace(ctx, apiClient, user) if err != nil { if err == ErrNoCodespaces { - fmt.Println(err.Error()) - return nil, "", nil + return nil, "", err } - return nil, "", fmt.Errorf("choosing codespace: %v", err) } codespaceName = codespace.Name From c9c1ff8dacdee9fae5d386d8f084890b6c6147a8 Mon Sep 17 00:00:00 2001 From: Jose Garcia Date: Mon, 16 Aug 2021 20:16:50 +0000 Subject: [PATCH 070/290] add back . indicators & update ConnectToTunnel --- cmd/ghcs/logs.go | 2 +- cmd/ghcs/ssh.go | 20 +++++++++++++++----- 2 files changed, 16 insertions(+), 6 deletions(-) diff --git a/cmd/ghcs/logs.go b/cmd/ghcs/logs.go index 45fd8bca8..d9422d49d 100644 --- a/cmd/ghcs/logs.go +++ b/cmd/ghcs/logs.go @@ -78,7 +78,7 @@ func Logs(tail bool, codespaceName string) error { go func() { scanner := bufio.NewScanner(stdout) for scanner.Scan() { - fmt.Println(scanner.Text()) + log.Println(scanner.Text()) } if err := scanner.Err(); err != nil { diff --git a/cmd/ghcs/ssh.go b/cmd/ghcs/ssh.go index 372061c55..40f8fbc18 100644 --- a/cmd/ghcs/ssh.go +++ b/cmd/ghcs/ssh.go @@ -65,14 +65,16 @@ func SSH(sshProfile, codespaceName string, sshServerPort int) error { log.Println("Preparing SSH...") if sshProfile == "" { - containerID, err := getContainerID(ctx, terminal) + containerID, err := getContainerID(ctx, log, terminal) if err != nil { return fmt.Errorf("error getting container id: %v", err) } - if err := setupSSH(ctx, terminal, containerID, codespace.RepositoryName); err != nil { + if err := setupSSH(ctx, log, terminal, containerID, codespace.RepositoryName); err != nil { return fmt.Errorf("error creating ssh server: %v", err) } + + log.Print("\n") } tunnelPort, tunnelClosed, err := codespaces.MakeSSHTunnel(ctx, lsclient, sshServerPort) @@ -86,7 +88,7 @@ func SSH(sshProfile, codespaceName string, sshServerPort int) error { } usingCustomPort := tunnelPort == sshServerPort - connClosed := codespaces.ConnectToTunnel(ctx, tunnelPort, connectDestination, usingCustomPort) + connClosed := codespaces.ConnectToTunnel(ctx, log, tunnelPort, connectDestination, usingCustomPort) log.Println("Ready...") select { @@ -103,24 +105,30 @@ func SSH(sshProfile, codespaceName string, sshServerPort int) error { return nil } -func getContainerID(ctx context.Context, terminal *liveshare.Terminal) (string, error) { +func getContainerID(ctx context.Context, logger *output.Logger, terminal *liveshare.Terminal) (string, error) { + logger.Print(".") + cmd := terminal.NewCommand( "/", "/usr/bin/docker ps -aq --filter label=Type=codespaces --filter status=running", ) + stream, err := cmd.Run(ctx) if err != nil { return "", fmt.Errorf("error running command: %v", err) } + logger.Print(".") scanner := bufio.NewScanner(stream) scanner.Scan() + logger.Print(".") containerID := scanner.Text() if err := scanner.Err(); err != nil { return "", fmt.Errorf("error scanning stream: %v", err) } + logger.Print(".") if err := stream.Close(); err != nil { return "", fmt.Errorf("error closing stream: %v", err) } @@ -128,9 +136,10 @@ func getContainerID(ctx context.Context, terminal *liveshare.Terminal) (string, return containerID, nil } -func setupSSH(ctx context.Context, terminal *liveshare.Terminal, containerID, repositoryName string) error { +func setupSSH(ctx context.Context, logger *output.Logger, terminal *liveshare.Terminal, containerID, repositoryName string) error { setupBashProfileCmd := fmt.Sprintf(`echo "cd /workspaces/%v; export $(cat /workspaces/.codespaces/shared/.env | xargs); exec /bin/zsh;" > /home/codespace/.bash_profile`, repositoryName) + logger.Print(".") compositeCommand := []string{setupBashProfileCmd} cmd := terminal.NewCommand( "/", @@ -141,6 +150,7 @@ func setupSSH(ctx context.Context, terminal *liveshare.Terminal, containerID, re return fmt.Errorf("error running command: %v", err) } + logger.Print(".") if err := stream.Close(); err != nil { return fmt.Errorf("error closing stream: %v", err) } From 22be26431e918fa10142ddb471ffaf151d9877a5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mislav=20Marohni=C4=87?= Date: Mon, 16 Aug 2021 22:28:39 +0200 Subject: [PATCH 071/290] Have `--codespace ` flag be consistent across commands --- cmd/ghcs/ports.go | 2 +- cmd/ghcs/ssh.go | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/cmd/ghcs/ports.go b/cmd/ghcs/ports.go index 613a1834b..1d5d5ff9f 100644 --- a/cmd/ghcs/ports.go +++ b/cmd/ghcs/ports.go @@ -34,7 +34,7 @@ func NewPortsCmd() *cobra.Command { }, } - portsCmd.Flags().StringVarP(&opts.CodespaceName, "name", "n", "", "Name of Codespace to use") + portsCmd.Flags().StringVarP(&opts.CodespaceName, "codespace", "c", "", "The `name` of the Codespace to use") portsCmd.Flags().BoolVar(&opts.AsJSON, "json", false, "Output as JSON") portsCmd.AddCommand(NewPortsPublicCmd()) diff --git a/cmd/ghcs/ssh.go b/cmd/ghcs/ssh.go index 372061c55..b6bbe254d 100644 --- a/cmd/ghcs/ssh.go +++ b/cmd/ghcs/ssh.go @@ -29,7 +29,7 @@ func NewSSHCmd() *cobra.Command { sshCmd.Flags().StringVarP(&sshProfile, "profile", "", "", "SSH Profile") sshCmd.Flags().IntVarP(&sshServerPort, "server-port", "", 0, "SSH Server Port") - sshCmd.Flags().StringVarP(&codespaceName, "codespace", "c", "", "Codespace Name") + sshCmd.Flags().StringVarP(&codespaceName, "codespace", "c", "", "The `name` of the Codespace to use") return sshCmd } From 97d8285b5870d478a22bbbc949ff663e77043141 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mislav=20Marohni=C4=87?= Date: Mon, 16 Aug 2021 23:19:20 +0200 Subject: [PATCH 072/290] Do not require GITHUB_TOKEN for merely viewing command help --- cmd/ghcs/main.go | 37 +++++++++++++++++++++++++------------ 1 file changed, 25 insertions(+), 12 deletions(-) diff --git a/cmd/ghcs/main.go b/cmd/ghcs/main.go index 00e8be894..58037437a 100644 --- a/cmd/ghcs/main.go +++ b/cmd/ghcs/main.go @@ -1,33 +1,46 @@ package main import ( + "errors" "fmt" + "io" "os" "github.com/spf13/cobra" ) func main() { - Execute() + if err := rootCmd.Execute(); err != nil { + explainError(os.Stderr, err) + os.Exit(1) + } } var Version = "DEV" var rootCmd = &cobra.Command{ - Use: "ghcs", - Short: "Unofficial GitHub Codespaces CLI.", - Long: "Unofficial CLI tool to manage and interact with GitHub Codespaces.", + Use: "ghcs", + Long: `Unofficial CLI tool to manage GitHub Codespaces. + +Running commands requires the GITHUB_TOKEN environment variable to be set to a +token to access the GitHub API with.`, Version: Version, + + PersistentPreRunE: func(cmd *cobra.Command, args []string) error { + if os.Getenv("GITHUB_TOKEN") == "" { + return tokenError + } + return nil + }, } -func Execute() { - if os.Getenv("GITHUB_TOKEN") == "" { - fmt.Println("The GITHUB_TOKEN environment variable is required. Create a Personal Access Token at https://github.com/settings/tokens/new?scopes=repo and make sure to enable SSO for the GitHub organization after creating the token.") - os.Exit(1) - } +var tokenError = errors.New("GITHUB_TOKEN is missing") - if err := rootCmd.Execute(); err != nil { - fmt.Fprintln(os.Stderr, err) - os.Exit(1) +func explainError(w io.Writer, err error) { + if errors.Is(err, tokenError) { + fmt.Fprintln(w, "The GITHUB_TOKEN environment variable is required. Create a Personal Access Token at https://github.com/settings/tokens/new?scopes=repo") + fmt.Fprintln(w, "Make sure to enable SSO for your organizations after creating the token.") + return } + // fmt.Fprintf(w, "%v\n", err) } From 5e472bc0e5996f478d69388d2fc6cb24afff9c11 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mislav=20Marohni=C4=87?= Date: Mon, 16 Aug 2021 23:24:11 +0200 Subject: [PATCH 073/290] Improve command descriptions and argument assertions --- cmd/ghcs/code.go | 5 +++-- cmd/ghcs/create.go | 3 ++- cmd/ghcs/delete.go | 20 ++++++++------------ cmd/ghcs/list.go | 3 ++- cmd/ghcs/logs.go | 3 ++- cmd/ghcs/ports.go | 32 +++++++++++--------------------- cmd/ghcs/ssh.go | 7 ++++--- 7 files changed, 32 insertions(+), 41 deletions(-) diff --git a/cmd/ghcs/code.go b/cmd/ghcs/code.go index ccb4788ee..23dc0f767 100644 --- a/cmd/ghcs/code.go +++ b/cmd/ghcs/code.go @@ -14,8 +14,9 @@ import ( func NewCodeCmd() *cobra.Command { return &cobra.Command{ - Use: "code", - Short: "Open a GitHub Codespace in VSCode.", + Use: "code []", + Short: "Open a Codespace in VS Code", + Args: cobra.MaximumNArgs(1), RunE: func(cmd *cobra.Command, args []string) error { var codespaceName string if len(args) > 0 { diff --git a/cmd/ghcs/create.go b/cmd/ghcs/create.go index 385e5d957..2228bc105 100644 --- a/cmd/ghcs/create.go +++ b/cmd/ghcs/create.go @@ -17,7 +17,8 @@ var repo, branch, machine string func newCreateCmd() *cobra.Command { createCmd := &cobra.Command{ Use: "create", - Short: "Create a GitHub Codespace.", + Short: "Create a Codespace", + Args: cobra.NoArgs, RunE: func(cmd *cobra.Command, args []string) error { return Create() }, diff --git a/cmd/ghcs/delete.go b/cmd/ghcs/delete.go index f374ef7e6..c3a842e4f 100644 --- a/cmd/ghcs/delete.go +++ b/cmd/ghcs/delete.go @@ -2,7 +2,6 @@ package main import ( "context" - "errors" "fmt" "os" @@ -13,8 +12,9 @@ import ( func NewDeleteCmd() *cobra.Command { deleteCmd := &cobra.Command{ - Use: "delete", - Short: "Delete a GitHub Codespace.", + Use: "delete []", + Short: "Delete a Codespace", + Args: cobra.MaximumNArgs(1), RunE: func(cmd *cobra.Command, args []string) error { var codespaceName string if len(args) > 0 { @@ -26,22 +26,18 @@ func NewDeleteCmd() *cobra.Command { deleteAllCmd := &cobra.Command{ Use: "all", - Short: "delete all codespaces", - Long: "delete all codespaces for the user with the current token", + Short: "Delete all Codespaces for the current user", + Args: cobra.NoArgs, RunE: func(cmd *cobra.Command, args []string) error { return DeleteAll() }, } deleteByRepoCmd := &cobra.Command{ - Use: "repo REPO_NAME", - Short: "delete all codespaces for the repo", - Long: `delete all the codespaces that the user with the current token has in this repo. -This includes all codespaces in all states.`, + Use: "repo ", + Short: "Delete all Codespaces for a repository", + Args: cobra.ExactArgs(1), RunE: func(cmd *cobra.Command, args []string) error { - if len(args) == 0 { - return errors.New("A Repository name is required.") - } return DeleteByRepo(args[0]) }, } diff --git a/cmd/ghcs/list.go b/cmd/ghcs/list.go index 6db79af97..f19095ff8 100644 --- a/cmd/ghcs/list.go +++ b/cmd/ghcs/list.go @@ -14,7 +14,8 @@ import ( func NewListCmd() *cobra.Command { listCmd := &cobra.Command{ Use: "list", - Short: "List GitHub Codespaces you have on your account.", + Short: "List your Codespaces", + Args: cobra.NoArgs, RunE: func(cmd *cobra.Command, args []string) error { return List() }, diff --git a/cmd/ghcs/logs.go b/cmd/ghcs/logs.go index 03a7c963a..3c192b081 100644 --- a/cmd/ghcs/logs.go +++ b/cmd/ghcs/logs.go @@ -15,8 +15,9 @@ func NewLogsCmd() *cobra.Command { var tail bool logsCmd := &cobra.Command{ - Use: "logs", + Use: "logs []", Short: "Access Codespace logs", + Args: cobra.MaximumNArgs(1), RunE: func(cmd *cobra.Command, args []string) error { var codespaceName string if len(args) > 0 { diff --git a/cmd/ghcs/ports.go b/cmd/ghcs/ports.go index 77d1b00f7..2e0e44908 100644 --- a/cmd/ghcs/ports.go +++ b/cmd/ghcs/ports.go @@ -20,7 +20,8 @@ import ( func NewPortsCmd() *cobra.Command { portsCmd := &cobra.Command{ Use: "ports", - Short: "Forward ports from a GitHub Codespace.", + Short: "List ports in a Codespace", + Args: cobra.NoArgs, RunE: func(cmd *cobra.Command, args []string) error { return Ports() }, @@ -167,14 +168,10 @@ func getDevContainer(ctx context.Context, apiClient *api.API, codespace *api.Cod func NewPortsPublicCmd() *cobra.Command { return &cobra.Command{ - Use: "public", - Short: "public", - Long: "public", + Use: "public ", + Short: "Mark port as public", + Args: cobra.ExactArgs(2), RunE: func(cmd *cobra.Command, args []string) error { - if len(args) < 2 { - return errors.New("[codespace_name] [source] port number are required.") - } - return updatePortVisibility(args[0], args[1], true) }, } @@ -182,14 +179,10 @@ func NewPortsPublicCmd() *cobra.Command { func NewPortsPrivateCmd() *cobra.Command { return &cobra.Command{ - Use: "private", - Short: "private", - Long: "private", + Use: "private ", + Short: "Mark port as private", + Args: cobra.ExactArgs(2), RunE: func(cmd *cobra.Command, args []string) error { - if len(args) < 2 { - return errors.New("[codespace_name] [source] port number are required.") - } - return updatePortVisibility(args[0], args[1], false) }, } @@ -245,13 +238,10 @@ func updatePortVisibility(codespaceName, sourcePort string, public bool) error { func NewPortsForwardCmd() *cobra.Command { return &cobra.Command{ - Use: "forward", - Short: "forward", - Long: "forward", + Use: "forward ", + Short: "Forward port", + Args: cobra.ExactArgs(3), RunE: func(cmd *cobra.Command, args []string) error { - if len(args) < 3 { - return errors.New("[codespace_name] [source] [dst] port number are required.") - } return forwardPort(args[0], args[1], args[2]) }, } diff --git a/cmd/ghcs/ssh.go b/cmd/ghcs/ssh.go index 60fdee498..428cd74bf 100644 --- a/cmd/ghcs/ssh.go +++ b/cmd/ghcs/ssh.go @@ -20,14 +20,15 @@ func NewSSHCmd() *cobra.Command { sshCmd := &cobra.Command{ Use: "ssh", - Short: "SSH into a GitHub Codespace, for use with running tests/editing in vim, etc.", + Short: "SSH into a Codespace", + Args: cobra.NoArgs, RunE: func(cmd *cobra.Command, args []string) error { return SSH(sshProfile, codespaceName, sshServerPort) }, } - sshCmd.Flags().StringVarP(&sshProfile, "profile", "", "", "SSH Profile") - sshCmd.Flags().IntVarP(&sshServerPort, "server-port", "", 0, "SSH Server Port") + sshCmd.Flags().StringVarP(&sshProfile, "profile", "", "", "The `name` of the SSH profile to use") + sshCmd.Flags().IntVarP(&sshServerPort, "server-port", "", 0, "SSH server port number") sshCmd.Flags().StringVarP(&codespaceName, "codespace", "c", "", "Codespace Name") return sshCmd From b47686163a2d06ce4e1b46d480f34133df611f2f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mislav=20Marohni=C4=87?= Date: Tue, 17 Aug 2021 13:04:55 +0200 Subject: [PATCH 074/290] Fixes for log/output streams --- cmd/ghcs/logs.go | 2 +- internal/codespaces/ssh.go | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/cmd/ghcs/logs.go b/cmd/ghcs/logs.go index d9422d49d..45fd8bca8 100644 --- a/cmd/ghcs/logs.go +++ b/cmd/ghcs/logs.go @@ -78,7 +78,7 @@ func Logs(tail bool, codespaceName string) error { go func() { scanner := bufio.NewScanner(stdout) for scanner.Scan() { - log.Println(scanner.Text()) + fmt.Println(scanner.Text()) } if err := scanner.Err(); err != nil { diff --git a/internal/codespaces/ssh.go b/internal/codespaces/ssh.go index 2bb661086..672ba3b7b 100644 --- a/internal/codespaces/ssh.go +++ b/internal/codespaces/ssh.go @@ -56,12 +56,12 @@ func makeSSHArgs(port int, dst, cmd string) ([]string, []string) { return cmdArgs, connArgs } -func ConnectToTunnel(ctx context.Context, port int, destination string, usingCustomPort bool) <-chan error { +func ConnectToTunnel(ctx context.Context, log logger, port int, destination string, usingCustomPort bool) <-chan error { connClosed := make(chan error) args, connArgs := makeSSHArgs(port, destination, "") if usingCustomPort { - fmt.Println("Connection Details: ssh " + destination + " " + strings.Join(connArgs, " ")) + log.Println("Connection Details: ssh " + destination + " " + strings.Join(connArgs, " ")) } cmd := exec.CommandContext(ctx, "ssh", args...) From b5670252decdcc142d3efd928403c67c771d1c22 Mon Sep 17 00:00:00 2001 From: Jose Garcia Date: Tue, 17 Aug 2021 12:58:46 +0000 Subject: [PATCH 075/290] small update to description --- cmd/ghcs/ports.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cmd/ghcs/ports.go b/cmd/ghcs/ports.go index 9985c708a..5c3f5a166 100644 --- a/cmd/ghcs/ports.go +++ b/cmd/ghcs/ports.go @@ -242,7 +242,7 @@ func updatePortVisibility(log *output.Logger, codespaceName, sourcePort string, func NewPortsForwardCmd() *cobra.Command { return &cobra.Command{ Use: "forward ", - Short: "Forward port", + Short: "Forward ports", Args: cobra.MinimumNArgs(2), RunE: func(cmd *cobra.Command, args []string) error { log := output.NewLogger(os.Stdout, os.Stderr, false) From 8533d084614a373412d20013e3c7e0a7d77dd833 Mon Sep 17 00:00:00 2001 From: Jose Garcia Date: Tue, 17 Aug 2021 13:07:40 +0000 Subject: [PATCH 076/290] rename var --- cmd/ghcs/ports.go | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/cmd/ghcs/ports.go b/cmd/ghcs/ports.go index 5c3f5a166..0f2460e0a 100644 --- a/cmd/ghcs/ports.go +++ b/cmd/ghcs/ports.go @@ -287,16 +287,16 @@ func forwardPorts(log *output.Logger, codespaceName string, ports []string) erro g, gctx := errgroup.WithContext(ctx) for _, portPair := range portPairs { - portPair := portPair + pp := portPair srcstr := strconv.Itoa(portPair.Src) - if err := server.StartSharing(gctx, "share-"+srcstr, portPair.Src); err != nil { + if err := server.StartSharing(gctx, "share-"+srcstr, pp.Src); err != nil { return fmt.Errorf("start sharing port: %v", err) } g.Go(func() error { - log.Println("Forwarding port: " + srcstr + " ==> " + strconv.Itoa(portPair.Dst)) - portForwarder := liveshare.NewPortForwarder(lsclient, server, portPair.Dst) + log.Println("Forwarding port: " + srcstr + " ==> " + strconv.Itoa(pp.Dst)) + portForwarder := liveshare.NewPortForwarder(lsclient, server, pp.Dst) if err := portForwarder.Start(gctx); err != nil { return fmt.Errorf("error forwarding port: %v", err) } From 269196c94f3ca7b2d8fb9efe001295bc161d6948 Mon Sep 17 00:00:00 2001 From: Jose Garcia Date: Wed, 18 Aug 2021 15:12:47 +0000 Subject: [PATCH 077/290] support existing connections for port forwarding --- port_forwarder.go | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/port_forwarder.go b/port_forwarder.go index 06c164e8d..e6eedf16c 100644 --- a/port_forwarder.go +++ b/port_forwarder.go @@ -54,7 +54,12 @@ func (l *PortForwarder) Start(ctx context.Context) error { return nil } -func (l *PortForwarder) handleConnection(ctx context.Context, conn net.Conn) { +func (l *PortForwarder) StartWithConn(ctx context.Context, conn io.ReadWriteCloser) error { + go l.handleConnection(ctx, conn) + return <-l.errCh +} + +func (l *PortForwarder) handleConnection(ctx context.Context, conn io.ReadWriteCloser) { channel, err := l.client.openStreamingChannel(ctx, l.server.streamName, l.server.streamCondition) if err != nil { l.errCh <- fmt.Errorf("error opening streaming channel for new connection: %v", err) From 5af1cccb73310136e0de4d1055ee55e9db29cb7a Mon Sep 17 00:00:00 2001 From: Issy Long Date: Wed, 18 Aug 2021 18:05:59 +0100 Subject: [PATCH 078/290] cmd/ghcs/delete: When matching repos to delete, standardize casing - It was possible to delete Codespaces for repo `SomePerson/foo` but not `someperson/foo`, despite the fact that the GitHub APIs don't actually care about casing - `SomePerson` and `someperson` is the same account. - This fixes that by lowercasing both the user-provided repo name, and the repository that is attached to the Codespace for a match. - Fixes #76. --- cmd/ghcs/delete.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/cmd/ghcs/delete.go b/cmd/ghcs/delete.go index d79bcc448..e789ba1ba 100644 --- a/cmd/ghcs/delete.go +++ b/cmd/ghcs/delete.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "os" + "strings" "github.com/github/ghcs/api" "github.com/github/ghcs/cmd/ghcs/output" @@ -124,7 +125,7 @@ func DeleteByRepo(repo string) error { var deleted bool for _, c := range codespaces { - if c.RepositoryNWO != repo { + if strings.ToLower(c.RepositoryNWO) != strings.ToLower(repo) { continue } deleted = true From 28a3644a079169b78aa0a8149e9eed15ef98445d Mon Sep 17 00:00:00 2001 From: Issy Long Date: Wed, 18 Aug 2021 18:15:15 +0100 Subject: [PATCH 079/290] cmd/ghcs/delete: I learnt about `strings.EqualFold` - thanks, linter! --- cmd/ghcs/delete.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cmd/ghcs/delete.go b/cmd/ghcs/delete.go index e789ba1ba..c357171d1 100644 --- a/cmd/ghcs/delete.go +++ b/cmd/ghcs/delete.go @@ -125,7 +125,7 @@ func DeleteByRepo(repo string) error { var deleted bool for _, c := range codespaces { - if strings.ToLower(c.RepositoryNWO) != strings.ToLower(repo) { + if !strings.EqualFold(c.RepositoryNWO, repo) { continue } deleted = true From a53eb53ad4c6850ce13234b123b122c16796e1c7 Mon Sep 17 00:00:00 2001 From: Issy Long Date: Thu, 19 Aug 2021 10:10:30 +0100 Subject: [PATCH 080/290] cmd/ghcs/ports: Fix usage docs for the new `source:forward` syntax Co-authored-by: George Brocklehurst --- cmd/ghcs/ports.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cmd/ghcs/ports.go b/cmd/ghcs/ports.go index 0f2460e0a..09397af54 100644 --- a/cmd/ghcs/ports.go +++ b/cmd/ghcs/ports.go @@ -241,7 +241,7 @@ func updatePortVisibility(log *output.Logger, codespaceName, sourcePort string, func NewPortsForwardCmd() *cobra.Command { return &cobra.Command{ - Use: "forward ", + Use: "forward :", Short: "Forward ports", Args: cobra.MinimumNArgs(2), RunE: func(cmd *cobra.Command, args []string) error { From 530c0244f9461874b206a4a1324759d8032ec745 Mon Sep 17 00:00:00 2001 From: Josh Gross Date: Thu, 19 Aug 2021 17:37:57 -0400 Subject: [PATCH 081/290] Add support to `code` for VS Code Insiders --- cmd/ghcs/code.go | 26 ++++++++++++++++++++------ 1 file changed, 20 insertions(+), 6 deletions(-) diff --git a/cmd/ghcs/code.go b/cmd/ghcs/code.go index ac5bffe8b..9a5abddfb 100644 --- a/cmd/ghcs/code.go +++ b/cmd/ghcs/code.go @@ -12,8 +12,14 @@ import ( "github.com/spf13/cobra" ) +type CodeOptions struct { + UseInsiders bool +} + func NewCodeCmd() *cobra.Command { - return &cobra.Command{ + opts := &CodeOptions{} + + codeCmd := &cobra.Command{ Use: "code []", Short: "Open a Codespace in VS Code", Args: cobra.MaximumNArgs(1), @@ -22,16 +28,20 @@ func NewCodeCmd() *cobra.Command { if len(args) > 0 { codespaceName = args[0] } - return Code(codespaceName) + return Code(codespaceName, opts) }, } + + codeCmd.Flags().BoolVar(&opts.UseInsiders, "insiders", false, "Use the insiders version of VS Code") + + return codeCmd } func init() { rootCmd.AddCommand(NewCodeCmd()) } -func Code(codespaceName string) error { +func Code(codespaceName string, opts *CodeOptions) error { apiClient := api.New(os.Getenv("GITHUB_TOKEN")) ctx := context.Background() @@ -51,13 +61,17 @@ func Code(codespaceName string) error { codespaceName = codespace.Name } - if err := open.Run(vscodeProtocolURL(codespaceName)); err != nil { + if err := open.Run(vscodeProtocolURL(codespaceName, opts.UseInsiders)); err != nil { return fmt.Errorf("error opening vscode URL") } return nil } -func vscodeProtocolURL(codespaceName string) string { - return fmt.Sprintf("vscode://github.codespaces/connect?name=%s", url.QueryEscape(codespaceName)) +func vscodeProtocolURL(codespaceName string, useInsiders bool) string { + application := "vscode" + if useInsiders { + application = "vscode-insiders" + } + return fmt.Sprintf("%s://github.codespaces/connect?name=%s", application, url.QueryEscape(codespaceName)) } From ae88091fd8276d57437aa7f44053f1462f3412e2 Mon Sep 17 00:00:00 2001 From: Josh Gross Date: Mon, 23 Aug 2021 12:01:13 -0400 Subject: [PATCH 082/290] Replace options struct with variable --- cmd/ghcs/code.go | 14 +++++--------- 1 file changed, 5 insertions(+), 9 deletions(-) diff --git a/cmd/ghcs/code.go b/cmd/ghcs/code.go index 9a5abddfb..81dbdbb2c 100644 --- a/cmd/ghcs/code.go +++ b/cmd/ghcs/code.go @@ -12,12 +12,8 @@ import ( "github.com/spf13/cobra" ) -type CodeOptions struct { - UseInsiders bool -} - func NewCodeCmd() *cobra.Command { - opts := &CodeOptions{} + useInsiders := false codeCmd := &cobra.Command{ Use: "code []", @@ -28,11 +24,11 @@ func NewCodeCmd() *cobra.Command { if len(args) > 0 { codespaceName = args[0] } - return Code(codespaceName, opts) + return Code(codespaceName, useInsiders) }, } - codeCmd.Flags().BoolVar(&opts.UseInsiders, "insiders", false, "Use the insiders version of VS Code") + codeCmd.Flags().BoolVar(&useInsiders, "insiders", false, "Use the insiders version of VS Code") return codeCmd } @@ -41,7 +37,7 @@ func init() { rootCmd.AddCommand(NewCodeCmd()) } -func Code(codespaceName string, opts *CodeOptions) error { +func Code(codespaceName string, useInsiders bool) error { apiClient := api.New(os.Getenv("GITHUB_TOKEN")) ctx := context.Background() @@ -61,7 +57,7 @@ func Code(codespaceName string, opts *CodeOptions) error { codespaceName = codespace.Name } - if err := open.Run(vscodeProtocolURL(codespaceName, opts.UseInsiders)); err != nil { + if err := open.Run(vscodeProtocolURL(codespaceName, useInsiders)); err != nil { return fmt.Errorf("error opening vscode URL") } From 30be4c98f95c0827b2060054271b8b8e736fce86 Mon Sep 17 00:00:00 2001 From: Gabriel Ramirez Date: Tue, 24 Aug 2021 13:12:18 -0500 Subject: [PATCH 083/290] Send codespace name to Stdout to enable scripting --- cmd/ghcs/create.go | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/cmd/ghcs/create.go b/cmd/ghcs/create.go index bb2e18eff..c3e8a24a1 100644 --- a/cmd/ghcs/create.go +++ b/cmd/ghcs/create.go @@ -83,7 +83,9 @@ func Create() error { return fmt.Errorf("error creating codespace: %v", err) } - log.Printf("Codespace created: %s\n", codespace.Name) + log.Printf("Codespace created: ") + + fmt.Fprintln(os.Stdout, codespace.Name) return nil } From 46ee45bcdd99c54849f76c4484950c7825a85616 Mon Sep 17 00:00:00 2001 From: Jose Garcia Date: Tue, 24 Aug 2021 17:46:24 -0400 Subject: [PATCH 084/290] simplify the state iteration --- cmd/ghcs/create.go | 64 +++++++++++++++---------------- cmd/ghcs/ssh.go | 2 +- internal/codespaces/codespaces.go | 3 +- 3 files changed, 33 insertions(+), 36 deletions(-) diff --git a/cmd/ghcs/create.go b/cmd/ghcs/create.go index 9c509b851..890e6b424 100644 --- a/cmd/ghcs/create.go +++ b/cmd/ghcs/create.go @@ -92,46 +92,42 @@ func Create() error { var lastState codespaces.PostCreateState var breakNextState bool -PollStates: for { - select { - case stateUpdate := <-states: - if stateUpdate.Err != nil { - return fmt.Errorf("receive state update: %v", err) - } + stateUpdate := <-states + if stateUpdate.Err != nil { + return fmt.Errorf("receive state update: %v", err) + } - var inProgress bool - for _, state := range stateUpdate.PostCreateStates { - switch state.Status { - case codespaces.PostCreateStateRunning: - if lastState != state { - lastState = state - fmt.Print(state.Name) - } else { - fmt.Print(".") - } + var inProgress bool + for _, state := range stateUpdate.PostCreateStates { + switch state.Status { + case codespaces.PostCreateStateRunning: + if lastState != state { + lastState = state + log.Print(state.Name) + } else { + log.Print(".") + } - inProgress = true - break - case codespaces.PostCreateStateFailed: - if lastState.Name == state.Name && lastState.Status != state.Status { - lastState = state - fmt.Print(".Failed\n") - } - case codespaces.PostCreateStateSuccess: - if lastState.Name == state.Name && lastState.Status != state.Status { - lastState = state - fmt.Print(".Success\n") - } + inProgress = true + break + case codespaces.PostCreateStateFailed: + if lastState.Name == state.Name && lastState.Status != state.Status { + lastState = state + log.Print(".Failed\n") + } + case codespaces.PostCreateStateSuccess: + if lastState.Name == state.Name && lastState.Status != state.Status { + lastState = state + log.Print(".Success\n") } } + } - switch { - case !inProgress && !breakNextState: - breakNextState = true - case !inProgress && breakNextState: - break PollStates - } + if !inProgress && !breakNextState { + breakNextState = true + } else if !inProgress && breakNextState { + break } } diff --git a/cmd/ghcs/ssh.go b/cmd/ghcs/ssh.go index bb1edfeee..a895b6b4d 100644 --- a/cmd/ghcs/ssh.go +++ b/cmd/ghcs/ssh.go @@ -64,7 +64,7 @@ func SSH(sshProfile, codespaceName string, sshServerPort int) error { return fmt.Errorf("error creating liveshare terminal: %v", err) } - log.Println("Preparing SSH...") + log.Print("Preparing SSH...") if sshProfile == "" { containerID, err := getContainerID(ctx, log, terminal) if err != nil { diff --git a/internal/codespaces/codespaces.go b/internal/codespaces/codespaces.go index bbf63b709..005ea0fda 100644 --- a/internal/codespaces/codespaces.go +++ b/internal/codespaces/codespaces.go @@ -66,7 +66,7 @@ func ConnectToLiveshare(ctx context.Context, log logger, apiClient *api.API, use var startedCodespace bool if codespace.Environment.State != api.CodespaceEnvironmentStateAvailable { startedCodespace = true - log.Println("Starting your codespace...") + log.Print("Starting your codespace...") if err := apiClient.StartCodespace(ctx, token, codespace); err != nil { return nil, fmt.Errorf("error starting codespace: %v", err) } @@ -97,6 +97,7 @@ func ConnectToLiveshare(ctx context.Context, log logger, apiClient *api.API, use if startedCodespace { fmt.Print("\n") } + log.Println("Connecting to your codespace...") lsclient, err := liveshare.NewClient( From 2ef6e95982342bcb0069c810889a853d8857f4a7 Mon Sep 17 00:00:00 2001 From: Jose Garcia Date: Tue, 24 Aug 2021 20:15:21 -0400 Subject: [PATCH 085/290] show status under a flag --- cmd/ghcs/create.go | 48 ++++++++++++++++++++++++++++++++-------------- 1 file changed, 34 insertions(+), 14 deletions(-) diff --git a/cmd/ghcs/create.go b/cmd/ghcs/create.go index 890e6b424..9b25b00b5 100644 --- a/cmd/ghcs/create.go +++ b/cmd/ghcs/create.go @@ -17,19 +17,29 @@ import ( var repo, branch, machine string +type CreateOptions struct { + Repo string + Branch string + Machine string + ShowStatus bool +} + func newCreateCmd() *cobra.Command { + opts := &CreateOptions{} + createCmd := &cobra.Command{ Use: "create", Short: "Create a Codespace", Args: cobra.NoArgs, RunE: func(cmd *cobra.Command, args []string) error { - return Create() + return Create(opts) }, } - createCmd.Flags().StringVarP(&repo, "repo", "r", "", "repository name with owner: user/repo") - createCmd.Flags().StringVarP(&branch, "branch", "b", "", "repository branch") - createCmd.Flags().StringVarP(&machine, "machine", "m", "", "hardware specifications for the VM") + createCmd.Flags().StringVarP(&opts.Repo, "repo", "r", "", "repository name with owner: user/repo") + createCmd.Flags().StringVarP(&opts.Branch, "branch", "b", "", "repository branch") + createCmd.Flags().StringVarP(&opts.Machine, "machine", "m", "", "hardware specifications for the VM") + createCmd.Flags().BoolVarP(&opts.ShowStatus, "status", "s", false, "show status of post-create command and dotfiles") return createCmd } @@ -38,18 +48,18 @@ func init() { rootCmd.AddCommand(newCreateCmd()) } -func Create() error { +func Create(opts *CreateOptions) error { ctx := context.Background() apiClient := api.New(os.Getenv("GITHUB_TOKEN")) locationCh := getLocation(ctx, apiClient) userCh := getUser(ctx, apiClient) log := output.NewLogger(os.Stdout, os.Stderr, false) - repo, err := getRepoName() + repo, err := getRepoName(opts.Repo) if err != nil { return fmt.Errorf("error getting repository name: %v", err) } - branch, err := getBranchName() + branch, err := getBranchName(opts.Branch) if err != nil { return fmt.Errorf("error getting branch name: %v", err) } @@ -69,7 +79,7 @@ func Create() error { return fmt.Errorf("error getting codespace user: %v", userResult.Err) } - machine, err := getMachineName(ctx, userResult.User, repository, locationResult.Location, apiClient) + machine, err := getMachineName(ctx, opts.Machine, userResult.User, repository, locationResult.Location, apiClient) if err != nil { return fmt.Errorf("error getting machine type: %v", err) } @@ -84,7 +94,19 @@ func Create() error { return fmt.Errorf("error creating codespace: %v", err) } - states, err := codespaces.PollPostCreateStates(ctx, log, apiClient, userResult.User, codespace) + if opts.ShowStatus { + if err := showStatus(ctx, log, apiClient, userResult.User, codespace); err != nil { + return fmt.Errorf("show status: %w", err) + } + } + + log.Printf("Codespace created: %s\n", codespace.Name) + + return nil +} + +func showStatus(ctx context.Context, log *output.Logger, apiClient *api.API, user *api.User, codespace *api.Codespace) error { + states, err := codespaces.PollPostCreateStates(ctx, log, apiClient, user, codespace) if err != nil { return fmt.Errorf("poll post create states: %v", err) } @@ -131,8 +153,6 @@ func Create() error { } } - log.Printf("Codespace created: %s\n", codespace.Name) - return nil } @@ -164,7 +184,7 @@ func getLocation(ctx context.Context, apiClient *api.API) <-chan locationResult return ch } -func getRepoName() (string, error) { +func getRepoName(repo string) (string, error) { if repo != "" { return repo, nil } @@ -180,7 +200,7 @@ func getRepoName() (string, error) { return repo, err } -func getBranchName() (string, error) { +func getBranchName(branch string) (string, error) { if branch != "" { return branch, nil } @@ -196,7 +216,7 @@ func getBranchName() (string, error) { return branch, err } -func getMachineName(ctx context.Context, user *api.User, repo *api.Repository, location string, apiClient *api.API) (string, error) { +func getMachineName(ctx context.Context, machine string, user *api.User, repo *api.Repository, location string, apiClient *api.API) (string, error) { skus, err := apiClient.GetCodespacesSkus(ctx, user, repo, location) if err != nil { return "", fmt.Errorf("error getting codespace skus: %v", err) From 151eb2b656e8ee2b4864653f3f785133f506c47e Mon Sep 17 00:00:00 2001 From: Jose Garcia Date: Thu, 26 Aug 2021 08:35:30 -0400 Subject: [PATCH 086/290] fix linter --- cmd/ghcs/create.go | 3 --- 1 file changed, 3 deletions(-) diff --git a/cmd/ghcs/create.go b/cmd/ghcs/create.go index 9b25b00b5..d08c4913d 100644 --- a/cmd/ghcs/create.go +++ b/cmd/ghcs/create.go @@ -15,8 +15,6 @@ import ( "github.com/spf13/cobra" ) -var repo, branch, machine string - type CreateOptions struct { Repo string Branch string @@ -132,7 +130,6 @@ func showStatus(ctx context.Context, log *output.Logger, apiClient *api.API, use } inProgress = true - break case codespaces.PostCreateStateFailed: if lastState.Name == state.Name && lastState.Status != state.Status { lastState = state From b6094e0006b8fd73c390429b2b46166048706b84 Mon Sep 17 00:00:00 2001 From: Edmundo Gonzalez <51725820+edgonmsft@users.noreply.github.com> Date: Thu, 26 Aug 2021 21:50:20 +0000 Subject: [PATCH 087/290] Changes to point to RPC service. --- cmd/ghcs/logs.go | 13 ++++- cmd/ghcs/ssh.go | 88 +++---------------------------- internal/codespaces/codespaces.go | 24 +++++++++ internal/codespaces/ssh.go | 4 +- 4 files changed, 43 insertions(+), 86 deletions(-) diff --git a/cmd/ghcs/logs.go b/cmd/ghcs/logs.go index dd8664597..6f93ee3b9 100644 --- a/cmd/ghcs/logs.go +++ b/cmd/ghcs/logs.go @@ -57,7 +57,16 @@ func Logs(tail bool, codespaceName string) error { return fmt.Errorf("connecting to liveshare: %v", err) } - tunnelPort, connClosed, err := codespaces.MakeSSHTunnel(ctx, lsclient, 0) + result, remoteSSHServerPort, sshUser, _, err := codespaces.StartSSHServer(ctx, lsclient) + if err != nil { + return fmt.Errorf("error getting ssh server details: %v", err) + } + + if !result { + return fmt.Errorf("error starting ssh: %v", err) + } + + tunnelPort, connClosed, err := codespaces.MakeSSHTunnel(ctx, lsclient, 0, remoteSSHServerPort) if err != nil { return fmt.Errorf("make ssh tunnel: %v", err) } @@ -67,7 +76,7 @@ func Logs(tail bool, codespaceName string) error { cmdType = "tail -f" } - dst := fmt.Sprintf("%s@localhost", getSSHUser(codespace)) + dst := fmt.Sprintf("%s@localhost", sshUser) stdout, err := codespaces.RunCommand( ctx, tunnelPort, dst, fmt.Sprintf("%v /workspaces/.codespaces/.persistedshare/creation.log", cmdType), ) diff --git a/cmd/ghcs/ssh.go b/cmd/ghcs/ssh.go index 1754f968a..f3e621824 100644 --- a/cmd/ghcs/ssh.go +++ b/cmd/ghcs/ssh.go @@ -1,17 +1,13 @@ package main import ( - "bufio" "context" "fmt" "os" - "strings" - "time" "github.com/github/ghcs/api" "github.com/github/ghcs/cmd/ghcs/output" "github.com/github/ghcs/internal/codespaces" - "github.com/github/go-liveshare" "github.com/spf13/cobra" ) @@ -59,33 +55,23 @@ func SSH(sshProfile, codespaceName string, sshServerPort int) error { return fmt.Errorf("error connecting to liveshare: %v", err) } - terminal, err := liveshare.NewTerminal(lsclient) + result, remoteSSHServerPort, sshUser, _, err := codespaces.StartSSHServer(ctx, lsclient) if err != nil { - return fmt.Errorf("error creating liveshare terminal: %v", err) + return fmt.Errorf("error getting ssh server details: %v", err) } - log.Println("Preparing SSH...") - if sshProfile == "" { - containerID, err := getContainerID(ctx, log, terminal) - if err != nil { - return fmt.Errorf("error getting container id: %v", err) - } - - if err := setupSSH(ctx, log, terminal, containerID, codespace.RepositoryName); err != nil { - return fmt.Errorf("error creating ssh server: %v", err) - } - - log.Print("\n") + if !result { + return fmt.Errorf("error starting ssh: %v", err) } - tunnelPort, tunnelClosed, err := codespaces.MakeSSHTunnel(ctx, lsclient, sshServerPort) + tunnelPort, tunnelClosed, err := codespaces.MakeSSHTunnel(ctx, lsclient, sshServerPort, remoteSSHServerPort) if err != nil { return fmt.Errorf("make ssh tunnel: %v", err) } connectDestination := sshProfile if connectDestination == "" { - connectDestination = fmt.Sprintf("%s@localhost", getSSHUser(codespace)) + connectDestination = fmt.Sprintf("%s@localhost", sshUser) } usingCustomPort := tunnelPort == sshServerPort @@ -105,65 +91,3 @@ func SSH(sshProfile, codespaceName string, sshServerPort int) error { return nil } - -func getContainerID(ctx context.Context, logger *output.Logger, terminal *liveshare.Terminal) (string, error) { - logger.Print(".") - - cmd := terminal.NewCommand( - "/", - "/usr/bin/docker ps -aq --filter label=Type=codespaces --filter status=running", - ) - - stream, err := cmd.Run(ctx) - if err != nil { - return "", fmt.Errorf("error running command: %v", err) - } - - logger.Print(".") - scanner := bufio.NewScanner(stream) - scanner.Scan() - - logger.Print(".") - containerID := scanner.Text() - if err := scanner.Err(); err != nil { - return "", fmt.Errorf("error scanning stream: %v", err) - } - - logger.Print(".") - if err := stream.Close(); err != nil { - return "", fmt.Errorf("error closing stream: %v", err) - } - - return containerID, nil -} - -func setupSSH(ctx context.Context, logger *output.Logger, terminal *liveshare.Terminal, containerID, repositoryName string) error { - setupBashProfileCmd := fmt.Sprintf(`echo "cd /workspaces/%v; export $(cat /workspaces/.codespaces/shared/.env | xargs); exec /bin/zsh;" > /home/codespace/.bash_profile`, repositoryName) - - logger.Print(".") - compositeCommand := []string{setupBashProfileCmd} - cmd := terminal.NewCommand( - "/", - fmt.Sprintf("/usr/bin/docker exec -t %s /bin/bash -c '"+strings.Join(compositeCommand, "; ")+"'", containerID), - ) - stream, err := cmd.Run(ctx) - if err != nil { - return fmt.Errorf("error running command: %v", err) - } - - logger.Print(".") - if err := stream.Close(); err != nil { - return fmt.Errorf("error closing stream: %v", err) - } - - time.Sleep(1 * time.Second) - - return nil -} - -func getSSHUser(codespace *api.Codespace) string { - if codespace.RepositoryNWO == "github/github" { - return "root" - } - return "codespace" -} diff --git a/internal/codespaces/codespaces.go b/internal/codespaces/codespaces.go index 48369cfa0..30173f018 100644 --- a/internal/codespaces/codespaces.go +++ b/internal/codespaces/codespaces.go @@ -4,6 +4,7 @@ import ( "context" "errors" "fmt" + "strconv" "time" "github.com/AlecAivazis/survey/v2" @@ -117,6 +118,29 @@ func ConnectToLiveshare(ctx context.Context, log logger, apiClient *api.API, tok return lsclient, nil } +func StartSSHServer(ctx context.Context, client *liveshare.Client) (result bool, serverPort int, user string, message string, err error) { + sshRpc, err := liveshare.NewSSHRpc(client) + if err != nil { + return false, 0, "", "", fmt.Errorf("error creating live share: %v", err) + } + + sshRpcResult, err := sshRpc.StartRemoteServer(ctx) + if err != nil { + return false, 0, "", "", fmt.Errorf("error creating live share: %v", err) + } + + if !sshRpcResult.Result { + return false, 0, "", sshRpcResult.Message, nil + } + + portInt, err := strconv.Atoi(sshRpcResult.ServerPort) + if err != nil { + return false, 0, "", "", fmt.Errorf("error parsing port: %v", err) + } + + return sshRpcResult.Result, portInt, sshRpcResult.User, sshRpcResult.Message, err +} + func GetOrChooseCodespace(ctx context.Context, apiClient *api.API, user *api.User, codespaceName string) (codespace *api.Codespace, token string, err error) { if codespaceName == "" { codespace, err = ChooseCodespace(ctx, apiClient, user) diff --git a/internal/codespaces/ssh.go b/internal/codespaces/ssh.go index 672ba3b7b..cf6118704 100644 --- a/internal/codespaces/ssh.go +++ b/internal/codespaces/ssh.go @@ -14,7 +14,7 @@ import ( "github.com/github/go-liveshare" ) -func MakeSSHTunnel(ctx context.Context, lsclient *liveshare.Client, serverPort int) (int, <-chan error, error) { +func MakeSSHTunnel(ctx context.Context, lsclient *liveshare.Client, serverPort int, remoteSSHPort int) (int, <-chan error, error) { tunnelClosed := make(chan error) server, err := liveshare.NewServer(lsclient) @@ -29,7 +29,7 @@ func MakeSSHTunnel(ctx context.Context, lsclient *liveshare.Client, serverPort i } // TODO(josebalius): This port won't always be 2222 - if err := server.StartSharing(ctx, "sshd", 2222); err != nil { + if err := server.StartSharing(ctx, "sshd", remoteSSHPort); err != nil { return 0, nil, fmt.Errorf("sharing sshd port: %v", err) } From a89c17a564b9a9a9bbee261d3b4a158c4efff36d Mon Sep 17 00:00:00 2001 From: Edmundo Gonzalez <51725820+edgonmsft@users.noreply.github.com> Date: Thu, 26 Aug 2021 22:34:56 +0000 Subject: [PATCH 088/290] Adding sshRPC interface --- sshRpc.go | 34 ++++++++++++++++++++++++++++++++++ 1 file changed, 34 insertions(+) create mode 100644 sshRpc.go diff --git a/sshRpc.go b/sshRpc.go new file mode 100644 index 000000000..78ac90f82 --- /dev/null +++ b/sshRpc.go @@ -0,0 +1,34 @@ +package liveshare + +import ( + "context" + "errors" +) + +type SshRpc struct { + client *Client +} + +func NewSSHRpc(client *Client) (*SshRpc, error) { + if !client.hasJoined() { + return nil, errors.New("client must join before creating server") + } + return &SshRpc{client: client}, nil +} + +type SshServerStartResult struct { + Result bool `json:"result"` + ServerPort string `json:"serverPort"` + User string `json:"user"` + Message string `json:"message"` +} + +func (s *SshRpc) StartRemoteServer(ctx context.Context) (SshServerStartResult, error) { + var response SshServerStartResult + + if err := s.client.rpc.do(ctx, "ISshServerHostService.startRemoteServer", []string{}, &response); err != nil { + return response, err + } + + return response, nil +} From 18ab421b0846e7a152d431646b12702f14b26ba9 Mon Sep 17 00:00:00 2001 From: Edmundo Gonzalez <51725820+edgonmsft@users.noreply.github.com> Date: Thu, 26 Aug 2021 23:04:05 +0000 Subject: [PATCH 089/290] Rename to SSHServer --- sshRpc.go | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/sshRpc.go b/sshRpc.go index 78ac90f82..ec7d8dfd1 100644 --- a/sshRpc.go +++ b/sshRpc.go @@ -5,26 +5,26 @@ import ( "errors" ) -type SshRpc struct { +type SSHServer struct { client *Client } -func NewSSHRpc(client *Client) (*SshRpc, error) { +func NewSSHServer(client *Client) (*SSHServer, error) { if !client.hasJoined() { return nil, errors.New("client must join before creating server") } - return &SshRpc{client: client}, nil + return &SSHServer{client: client}, nil } -type SshServerStartResult struct { +type SSHServerStartResult struct { Result bool `json:"result"` ServerPort string `json:"serverPort"` User string `json:"user"` Message string `json:"message"` } -func (s *SshRpc) StartRemoteServer(ctx context.Context) (SshServerStartResult, error) { - var response SshServerStartResult +func (s *SSHServer) StartRemoteServer(ctx context.Context) (SSHServerStartResult, error) { + var response SSHServerStartResult if err := s.client.rpc.do(ctx, "ISshServerHostService.startRemoteServer", []string{}, &response); err != nil { return response, err From 0eb769d608552e54f6db6bdb53e70996893a6acb Mon Sep 17 00:00:00 2001 From: Edmundo Gonzalez <51725820+edgonmsft@users.noreply.github.com> Date: Thu, 26 Aug 2021 23:04:35 +0000 Subject: [PATCH 090/290] Rename File --- sshRpc.go => sshServer.go | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename sshRpc.go => sshServer.go (100%) diff --git a/sshRpc.go b/sshServer.go similarity index 100% rename from sshRpc.go rename to sshServer.go From d5a26e1536048ab6294b064960f22c8e5ced71cc Mon Sep 17 00:00:00 2001 From: Edmundo Gonzalez <51725820+edgonmsft@users.noreply.github.com> Date: Thu, 26 Aug 2021 23:14:13 +0000 Subject: [PATCH 091/290] Apply renames on the go-liveshare side. --- internal/codespaces/codespaces.go | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/internal/codespaces/codespaces.go b/internal/codespaces/codespaces.go index 30173f018..3db66b427 100644 --- a/internal/codespaces/codespaces.go +++ b/internal/codespaces/codespaces.go @@ -119,26 +119,26 @@ func ConnectToLiveshare(ctx context.Context, log logger, apiClient *api.API, tok } func StartSSHServer(ctx context.Context, client *liveshare.Client) (result bool, serverPort int, user string, message string, err error) { - sshRpc, err := liveshare.NewSSHRpc(client) + sshServer, err := liveshare.NewSSHServer(client) if err != nil { return false, 0, "", "", fmt.Errorf("error creating live share: %v", err) } - sshRpcResult, err := sshRpc.StartRemoteServer(ctx) + sshServerStartResult, err := sshServer.StartRemoteServer(ctx) if err != nil { return false, 0, "", "", fmt.Errorf("error creating live share: %v", err) } - if !sshRpcResult.Result { - return false, 0, "", sshRpcResult.Message, nil + if !sshServerStartResult.Result { + return false, 0, "", sshServerStartResult.Message, nil } - portInt, err := strconv.Atoi(sshRpcResult.ServerPort) + portInt, err := strconv.Atoi(sshServerStartResult.ServerPort) if err != nil { return false, 0, "", "", fmt.Errorf("error parsing port: %v", err) } - return sshRpcResult.Result, portInt, sshRpcResult.User, sshRpcResult.Message, err + return sshServerStartResult.Result, portInt, sshServerStartResult.User, sshServerStartResult.Message, err } func GetOrChooseCodespace(ctx context.Context, apiClient *api.API, user *api.User, codespaceName string) (codespace *api.Codespace, token string, err error) { From 273782bcbcb06bb143d28f322fcc1e935e378737 Mon Sep 17 00:00:00 2001 From: Jose Garcia Date: Fri, 27 Aug 2021 11:49:21 +0000 Subject: [PATCH 092/290] rename file --- sshServer.go => ssh_server.go | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename sshServer.go => ssh_server.go (100%) diff --git a/sshServer.go b/ssh_server.go similarity index 100% rename from sshServer.go rename to ssh_server.go From 3dcee5cca72f0a70aed7cdc270cd4a0d6f0d584b Mon Sep 17 00:00:00 2001 From: Jose Garcia Date: Fri, 27 Aug 2021 12:41:36 +0000 Subject: [PATCH 093/290] remove dst port column and add docs --- cmd/ghcs/ports.go | 17 ++++++++++++++--- 1 file changed, 14 insertions(+), 3 deletions(-) diff --git a/cmd/ghcs/ports.go b/cmd/ghcs/ports.go index 09397af54..9e2713a5e 100644 --- a/cmd/ghcs/ports.go +++ b/cmd/ghcs/ports.go @@ -19,11 +19,16 @@ import ( "golang.org/x/sync/errgroup" ) +// PortOptions represents the options accepted by the ports command. type PortsOptions struct { + // CodespaceName is the name of the codespace, optional CodespaceName string - AsJSON bool + + // AsJSON dictates whether the command returns a json output or not, optional + AsJSON bool } +// NewPortsCmd returns a new cobra command representing the ports command and sub commands func NewPortsCmd() *cobra.Command { opts := &PortsOptions{} @@ -50,6 +55,7 @@ func init() { rootCmd.AddCommand(NewPortsCmd()) } +// Ports accepts a PortOptions pointer and logs a list the list of available open ports found in a codespace func Ports(opts *PortsOptions) error { apiClient := api.New(os.Getenv("GITHUB_TOKEN")) ctx := context.Background() @@ -87,7 +93,7 @@ func Ports(opts *PortsOptions) error { } table := output.NewTable(os.Stdout, opts.AsJSON) - table.SetHeader([]string{"Label", "Source Port", "Destination Port", "Public", "Browse URL"}) + table.SetHeader([]string{"Label", "Port", "Public", "Browse URL"}) for _, port := range ports { sourcePort := strconv.Itoa(port.SourcePort) var portName string @@ -100,7 +106,6 @@ func Ports(opts *PortsOptions) error { table.Append([]string{ portName, sourcePort, - strconv.Itoa(port.DestinationPort), strings.ToUpper(strconv.FormatBool(port.IsPublic)), fmt.Sprintf("https://%s-%s.githubpreview.dev/", codespace.Name, sourcePort), }) @@ -168,6 +173,8 @@ func getDevContainer(ctx context.Context, apiClient *api.API, codespace *api.Cod return ch } +// NewPortsPublicCmd returns a cobra command representing the ports subcommand used +// to make a given port public func NewPortsPublicCmd() *cobra.Command { return &cobra.Command{ Use: "public ", @@ -180,6 +187,8 @@ func NewPortsPublicCmd() *cobra.Command { } } +// NewPortsPrivateCmd rturns a cobra command representing the ports subcommand used +// to make a given port private func NewPortsPrivateCmd() *cobra.Command { return &cobra.Command{ Use: "private ", @@ -239,6 +248,8 @@ func updatePortVisibility(log *output.Logger, codespaceName, sourcePort string, return nil } +// NewPortsForwardCmd returns a cobra command representing the ports subcommand used to forward +// ports from the codespace to localhost, it supports multiple ports to be forwarded at once func NewPortsForwardCmd() *cobra.Command { return &cobra.Command{ Use: "forward :", From 0392c5017408cac6e63b71dd13b7e318350fd225 Mon Sep 17 00:00:00 2001 From: Alan Donovan Date: Fri, 27 Aug 2021 11:25:24 -0400 Subject: [PATCH 094/290] api: close HTTP response body on all paths --- api/api.go | 19 ++++++++++++++++++- 1 file changed, 18 insertions(+), 1 deletion(-) diff --git a/api/api.go b/api/api.go index 83510d8c5..25faf5724 100644 --- a/api/api.go +++ b/api/api.go @@ -44,6 +44,7 @@ func (a *API) GetUser(ctx context.Context) (*User, error) { if err != nil { return nil, fmt.Errorf("error making request: %v", err) } + defer resp.Body.Close() b, err := ioutil.ReadAll(resp.Body) if err != nil { @@ -86,6 +87,7 @@ func (a *API) GetRepository(ctx context.Context, nwo string) (*Repository, error if err != nil { return nil, fmt.Errorf("error making request: %v", err) } + defer resp.Body.Close() b, err := ioutil.ReadAll(resp.Body) if err != nil { @@ -152,6 +154,7 @@ func (a *API) ListCodespaces(ctx context.Context, user *User) (Codespaces, error if err != nil { return nil, fmt.Errorf("error making request: %v", err) } + defer resp.Body.Close() b, err := ioutil.ReadAll(resp.Body) if err != nil { @@ -199,6 +202,7 @@ func (a *API) GetCodespaceToken(ctx context.Context, ownerLogin, codespaceName s if err != nil { return "", fmt.Errorf("error making request: %v", err) } + defer resp.Body.Close() b, err := ioutil.ReadAll(resp.Body) if err != nil { @@ -232,6 +236,7 @@ func (a *API) GetCodespace(ctx context.Context, token, owner, codespace string) if err != nil { return nil, fmt.Errorf("error making request: %v", err) } + defer resp.Body.Close() b, err := ioutil.ReadAll(resp.Body) if err != nil { @@ -261,10 +266,13 @@ func (a *API) StartCodespace(ctx context.Context, token string, codespace *Codes } req.Header.Set("Authorization", "Bearer "+token) - _, err = a.client.Do(req) + resp, err := a.client.Do(req) if err != nil { return fmt.Errorf("error making request: %v", err) } + defer resp.Body.Close() + + // TODO: check status code? return nil } @@ -283,12 +291,15 @@ func (a *API) GetCodespaceRegionLocation(ctx context.Context) (string, error) { if err != nil { return "", fmt.Errorf("error making request: %v", err) } + defer resp.Body.Close() b, err := ioutil.ReadAll(resp.Body) if err != nil { return "", fmt.Errorf("error reading response body: %v", err) } + // TODO: check status code? + var response getCodespaceRegionLocationResponse if err := json.Unmarshal(b, &response); err != nil { return "", fmt.Errorf("error unmarshaling response: %v", err) @@ -320,12 +331,15 @@ func (a *API) GetCodespacesSkus(ctx context.Context, user *User, repository *Rep if err != nil { return nil, fmt.Errorf("error making request: %v", err) } + defer resp.Body.Close() b, err := ioutil.ReadAll(resp.Body) if err != nil { return nil, fmt.Errorf("error reading response body: %v", err) } + // TODO: check status code? + response := struct { Skus Skus `json:"skus"` }{} @@ -359,6 +373,7 @@ func (a *API) CreateCodespace(ctx context.Context, user *User, repository *Repos if err != nil { return nil, fmt.Errorf("error making request: %v", err) } + defer resp.Body.Close() b, err := ioutil.ReadAll(resp.Body) if err != nil { @@ -388,6 +403,7 @@ func (a *API) DeleteCodespace(ctx context.Context, user *User, token, codespaceN if err != nil { return fmt.Errorf("error making request: %v", err) } + defer resp.Body.Close() if resp.StatusCode > http.StatusAccepted { b, err := ioutil.ReadAll(resp.Body) @@ -419,6 +435,7 @@ func (a *API) GetCodespaceRepositoryContents(ctx context.Context, codespace *Cod if err != nil { return nil, fmt.Errorf("error making request: %v", err) } + defer resp.Body.Close() if resp.StatusCode == http.StatusNotFound { return nil, nil From 5dc923777be8c5ca232128c2b7924420b49b6bfd Mon Sep 17 00:00:00 2001 From: Jose Garcia Date: Fri, 27 Aug 2021 15:32:18 +0000 Subject: [PATCH 095/290] update docs, make ports private to be more consistent --- cmd/ghcs/ports.go | 21 ++++++++++----------- 1 file changed, 10 insertions(+), 11 deletions(-) diff --git a/cmd/ghcs/ports.go b/cmd/ghcs/ports.go index 9e2713a5e..2318097aa 100644 --- a/cmd/ghcs/ports.go +++ b/cmd/ghcs/ports.go @@ -21,14 +21,15 @@ import ( // PortOptions represents the options accepted by the ports command. type PortsOptions struct { - // CodespaceName is the name of the codespace, optional + // CodespaceName is the name of the codespace, optional. CodespaceName string - // AsJSON dictates whether the command returns a json output or not, optional + // AsJSON dictates whether the command returns a json output or not, optional. AsJSON bool } -// NewPortsCmd returns a new cobra command representing the ports command and sub commands +// NewPortsCmd returns a Cobra "ports" command that displays a table of available ports, +// according to the specified flags. func NewPortsCmd() *cobra.Command { opts := &PortsOptions{} @@ -37,7 +38,7 @@ func NewPortsCmd() *cobra.Command { Short: "List ports in a Codespace", Args: cobra.NoArgs, RunE: func(cmd *cobra.Command, args []string) error { - return Ports(opts) + return ports(opts) }, } @@ -55,8 +56,7 @@ func init() { rootCmd.AddCommand(NewPortsCmd()) } -// Ports accepts a PortOptions pointer and logs a list the list of available open ports found in a codespace -func Ports(opts *PortsOptions) error { +func ports(opts *PortsOptions) error { apiClient := api.New(os.Getenv("GITHUB_TOKEN")) ctx := context.Background() log := output.NewLogger(os.Stdout, os.Stderr, opts.AsJSON) @@ -173,7 +173,7 @@ func getDevContainer(ctx context.Context, apiClient *api.API, codespace *api.Cod return ch } -// NewPortsPublicCmd returns a cobra command representing the ports subcommand used +// NewPortsPublicCmd returns a Cobra "ports public" subcommand, which makes a given port public. // to make a given port public func NewPortsPublicCmd() *cobra.Command { return &cobra.Command{ @@ -187,8 +187,7 @@ func NewPortsPublicCmd() *cobra.Command { } } -// NewPortsPrivateCmd rturns a cobra command representing the ports subcommand used -// to make a given port private +// NewPortsPrivateCmd returns a Cobra "ports private" subcommand, which makes a given port private. func NewPortsPrivateCmd() *cobra.Command { return &cobra.Command{ Use: "private ", @@ -248,8 +247,8 @@ func updatePortVisibility(log *output.Logger, codespaceName, sourcePort string, return nil } -// NewPortsForwardCmd returns a cobra command representing the ports subcommand used to forward -// ports from the codespace to localhost, it supports multiple ports to be forwarded at once +// NewPortsForwardCmd returns a Cobra "ports forward" subcommand, which forwards a set of +// port pairs from the codespace to localhost. func NewPortsForwardCmd() *cobra.Command { return &cobra.Command{ Use: "forward :", From 8e95493872f31e953a33a86381e12ab93f7999f5 Mon Sep 17 00:00:00 2001 From: Jose Garcia Date: Fri, 27 Aug 2021 15:46:40 +0000 Subject: [PATCH 096/290] period --- cmd/ghcs/ports.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cmd/ghcs/ports.go b/cmd/ghcs/ports.go index 2318097aa..ea501b73e 100644 --- a/cmd/ghcs/ports.go +++ b/cmd/ghcs/ports.go @@ -174,7 +174,7 @@ func getDevContainer(ctx context.Context, apiClient *api.API, codespace *api.Cod } // NewPortsPublicCmd returns a Cobra "ports public" subcommand, which makes a given port public. -// to make a given port public +// to make a given port public. func NewPortsPublicCmd() *cobra.Command { return &cobra.Command{ Use: "public ", From 38ff786a7d2b209a74b3ab9bfd1ee3f2106323da Mon Sep 17 00:00:00 2001 From: Alan Donovan Date: Fri, 27 Aug 2021 11:25:24 -0400 Subject: [PATCH 097/290] cmd/ghcs: style tweaks --- api/api.go | 40 +++++++------------ cmd/ghcs/code.go | 8 ++-- cmd/ghcs/create.go | 6 +-- cmd/ghcs/delete.go | 22 +++++------ cmd/ghcs/list.go | 18 ++++----- cmd/ghcs/logs.go | 8 ++-- cmd/ghcs/main.go | 9 +++-- cmd/ghcs/ports.go | 64 ++++++++++++++++--------------- cmd/ghcs/ssh.go | 8 ++-- internal/codespaces/codespaces.go | 5 ++- 10 files changed, 92 insertions(+), 96 deletions(-) diff --git a/api/api.go b/api/api.go index 83510d8c5..9d9eae79f 100644 --- a/api/api.go +++ b/api/api.go @@ -1,3 +1,4 @@ +// TODO(adonovan): rename to package codespaces, and codespaces.Client? package api import ( @@ -9,7 +10,6 @@ import ( "fmt" "io/ioutil" "net/http" - "sort" "strconv" "strings" ) @@ -29,10 +29,6 @@ type User struct { Login string `json:"login"` } -type errResponse struct { - Message string `json:"message"` -} - func (a *API) GetUser(ctx context.Context) (*User, error) { req, err := http.NewRequest(http.MethodGet, githubAPI+"/user", nil) if err != nil { @@ -63,7 +59,9 @@ func (a *API) GetUser(ctx context.Context) (*User, error) { } func (a *API) errorResponse(b []byte) error { - var response errResponse + var response struct { + Message string `json:"message"` + } if err := json.Unmarshal(b, &response); err != nil { return fmt.Errorf("error unmarshaling error response: %v", err) } @@ -104,14 +102,6 @@ func (a *API) GetRepository(ctx context.Context, nwo string) (*Repository, error return &response, nil } -type Codespaces []*Codespace - -func (c Codespaces) SortByCreatedAt() { - sort.Slice(c, func(i, j int) bool { - return c[i].CreatedAt > c[j].CreatedAt - }) -} - type Codespace struct { Name string `json:"name"` GUID string `json:"guid"` @@ -139,7 +129,7 @@ type CodespaceEnvironmentConnection struct { RelaySAS string `json:"relaySas"` } -func (a *API) ListCodespaces(ctx context.Context, user *User) (Codespaces, error) { +func (a *API) ListCodespaces(ctx context.Context, user *User) ([]*Codespace, error) { req, err := http.NewRequest( http.MethodGet, githubAPI+"/vscs_internal/user/"+user.Login+"/codespaces", nil, ) @@ -162,9 +152,9 @@ func (a *API) ListCodespaces(ctx context.Context, user *User) (Codespaces, error return nil, a.errorResponse(b) } - response := struct { - Codespaces Codespaces `json:"codespaces"` - }{} + var response struct { + Codespaces []*Codespace `json:"codespaces"` + } if err := json.Unmarshal(b, &response); err != nil { return nil, fmt.Errorf("error unmarshaling response: %v", err) } @@ -297,14 +287,12 @@ func (a *API) GetCodespaceRegionLocation(ctx context.Context) (string, error) { return response.Current, nil } -type Skus []*Sku - -type Sku struct { +type SKU struct { Name string `json:"name"` DisplayName string `json:"display_name"` } -func (a *API) GetCodespacesSkus(ctx context.Context, user *User, repository *Repository, location string) (Skus, error) { +func (a *API) GetCodespacesSkus(ctx context.Context, user *User, repository *Repository, location string) ([]*SKU, error) { req, err := http.NewRequest(http.MethodGet, githubAPI+"/vscs_internal/user/"+user.Login+"/skus", nil) if err != nil { return nil, fmt.Errorf("err creating request: %v", err) @@ -326,14 +314,14 @@ func (a *API) GetCodespacesSkus(ctx context.Context, user *User, repository *Rep return nil, fmt.Errorf("error reading response body: %v", err) } - response := struct { - Skus Skus `json:"skus"` - }{} + var response struct { + SKUs []*SKU `json:"skus"` + } if err := json.Unmarshal(b, &response); err != nil { return nil, fmt.Errorf("error unmarshaling response: %v", err) } - return response.Skus, nil + return response.SKUs, nil } type createCodespaceRequest struct { diff --git a/cmd/ghcs/code.go b/cmd/ghcs/code.go index 81dbdbb2c..9bd4db634 100644 --- a/cmd/ghcs/code.go +++ b/cmd/ghcs/code.go @@ -12,7 +12,7 @@ import ( "github.com/spf13/cobra" ) -func NewCodeCmd() *cobra.Command { +func newCodeCmd() *cobra.Command { useInsiders := false codeCmd := &cobra.Command{ @@ -24,7 +24,7 @@ func NewCodeCmd() *cobra.Command { if len(args) > 0 { codespaceName = args[0] } - return Code(codespaceName, useInsiders) + return code(codespaceName, useInsiders) }, } @@ -34,10 +34,10 @@ func NewCodeCmd() *cobra.Command { } func init() { - rootCmd.AddCommand(NewCodeCmd()) + rootCmd.AddCommand(newCodeCmd()) } -func Code(codespaceName string, useInsiders bool) error { +func code(codespaceName string, useInsiders bool) error { apiClient := api.New(os.Getenv("GITHUB_TOKEN")) ctx := context.Background() diff --git a/cmd/ghcs/create.go b/cmd/ghcs/create.go index c3e8a24a1..8b4e1a743 100644 --- a/cmd/ghcs/create.go +++ b/cmd/ghcs/create.go @@ -22,7 +22,7 @@ func newCreateCmd() *cobra.Command { Short: "Create a Codespace", Args: cobra.NoArgs, RunE: func(cmd *cobra.Command, args []string) error { - return Create() + return create() }, } @@ -37,7 +37,7 @@ func init() { rootCmd.AddCommand(newCreateCmd()) } -func Create() error { +func create() error { ctx := context.Background() apiClient := api.New(os.Getenv("GITHUB_TOKEN")) locationCh := getLocation(ctx, apiClient) @@ -176,7 +176,7 @@ func getMachineName(ctx context.Context, user *api.User, repo *api.Repository, l } skuNames := make([]string, 0, len(skus)) - skuByName := make(map[string]*api.Sku) + skuByName := make(map[string]*api.SKU) for _, sku := range skus { nameParts := camelcase.Split(sku.Name) machineName := strings.Title(strings.ToLower(nameParts[0])) diff --git a/cmd/ghcs/delete.go b/cmd/ghcs/delete.go index c357171d1..d37029753 100644 --- a/cmd/ghcs/delete.go +++ b/cmd/ghcs/delete.go @@ -12,7 +12,7 @@ import ( "github.com/spf13/cobra" ) -func NewDeleteCmd() *cobra.Command { +func newDeleteCmd() *cobra.Command { deleteCmd := &cobra.Command{ Use: "delete []", Short: "Delete a Codespace", @@ -22,7 +22,7 @@ func NewDeleteCmd() *cobra.Command { if len(args) > 0 { codespaceName = args[0] } - return Delete(codespaceName) + return delete_(codespaceName) }, } @@ -31,7 +31,7 @@ func NewDeleteCmd() *cobra.Command { Short: "Delete all Codespaces for the current user", Args: cobra.NoArgs, RunE: func(cmd *cobra.Command, args []string) error { - return DeleteAll() + return deleteAll() }, } @@ -40,7 +40,7 @@ func NewDeleteCmd() *cobra.Command { Short: "Delete all Codespaces for a repository", Args: cobra.ExactArgs(1), RunE: func(cmd *cobra.Command, args []string) error { - return DeleteByRepo(args[0]) + return deleteByRepo(args[0]) }, } @@ -50,10 +50,10 @@ func NewDeleteCmd() *cobra.Command { } func init() { - rootCmd.AddCommand(NewDeleteCmd()) + rootCmd.AddCommand(newDeleteCmd()) } -func Delete(codespaceName string) error { +func delete_(codespaceName string) error { apiClient := api.New(os.Getenv("GITHUB_TOKEN")) ctx := context.Background() log := output.NewLogger(os.Stdout, os.Stderr, false) @@ -74,10 +74,10 @@ func Delete(codespaceName string) error { log.Println("Codespace deleted.") - return List(&ListOptions{}) + return list(&listOptions{}) } -func DeleteAll() error { +func deleteAll() error { apiClient := api.New(os.Getenv("GITHUB_TOKEN")) ctx := context.Background() log := output.NewLogger(os.Stdout, os.Stderr, false) @@ -105,10 +105,10 @@ func DeleteAll() error { log.Printf("Codespace deleted: %s\n", c.Name) } - return List(&ListOptions{}) + return list(&listOptions{}) } -func DeleteByRepo(repo string) error { +func deleteByRepo(repo string) error { apiClient := api.New(os.Getenv("GITHUB_TOKEN")) ctx := context.Background() log := output.NewLogger(os.Stdout, os.Stderr, false) @@ -146,5 +146,5 @@ func DeleteByRepo(repo string) error { return fmt.Errorf("No codespace was found for repository: %s", repo) } - return List(&ListOptions{}) + return list(&listOptions{}) } diff --git a/cmd/ghcs/list.go b/cmd/ghcs/list.go index 27b11d4fd..a19439296 100644 --- a/cmd/ghcs/list.go +++ b/cmd/ghcs/list.go @@ -10,32 +10,32 @@ import ( "github.com/spf13/cobra" ) -type ListOptions struct { - AsJSON bool +type listOptions struct { + asJSON bool } -func NewListCmd() *cobra.Command { - opts := &ListOptions{} +func newListCmd() *cobra.Command { + opts := &listOptions{} listCmd := &cobra.Command{ Use: "list", Short: "List your Codespaces", Args: cobra.NoArgs, RunE: func(cmd *cobra.Command, args []string) error { - return List(opts) + return list(opts) }, } - listCmd.Flags().BoolVar(&opts.AsJSON, "json", false, "Output as JSON") + listCmd.Flags().BoolVar(&opts.asJSON, "json", false, "Output as JSON") return listCmd } func init() { - rootCmd.AddCommand(NewListCmd()) + rootCmd.AddCommand(newListCmd()) } -func List(opts *ListOptions) error { +func list(opts *listOptions) error { apiClient := api.New(os.Getenv("GITHUB_TOKEN")) ctx := context.Background() @@ -49,7 +49,7 @@ func List(opts *ListOptions) error { return fmt.Errorf("error getting codespaces: %v", err) } - table := output.NewTable(os.Stdout, opts.AsJSON) + table := output.NewTable(os.Stdout, opts.asJSON) table.SetHeader([]string{"Name", "Repository", "Branch", "State", "Created At"}) for _, codespace := range codespaces { table.Append([]string{ diff --git a/cmd/ghcs/logs.go b/cmd/ghcs/logs.go index dd8664597..006f9c477 100644 --- a/cmd/ghcs/logs.go +++ b/cmd/ghcs/logs.go @@ -12,7 +12,7 @@ import ( "github.com/spf13/cobra" ) -func NewLogsCmd() *cobra.Command { +func newLogsCmd() *cobra.Command { var tail bool logsCmd := &cobra.Command{ @@ -24,7 +24,7 @@ func NewLogsCmd() *cobra.Command { if len(args) > 0 { codespaceName = args[0] } - return Logs(tail, codespaceName) + return logs(tail, codespaceName) }, } @@ -34,10 +34,10 @@ func NewLogsCmd() *cobra.Command { } func init() { - rootCmd.AddCommand(NewLogsCmd()) + rootCmd.AddCommand(newLogsCmd()) } -func Logs(tail bool, codespaceName string) error { +func logs(tail bool, codespaceName string) error { apiClient := api.New(os.Getenv("GITHUB_TOKEN")) ctx := context.Background() log := output.NewLogger(os.Stdout, os.Stderr, false) diff --git a/cmd/ghcs/main.go b/cmd/ghcs/main.go index 58037437a..dbf1dc714 100644 --- a/cmd/ghcs/main.go +++ b/cmd/ghcs/main.go @@ -1,5 +1,8 @@ package main +// TODO(adonovan): write 'help' commands, in manner of the 'go' tool. +// Document GITHUB_TOKEN. + import ( "errors" "fmt" @@ -16,7 +19,7 @@ func main() { } } -var Version = "DEV" +var version = "DEV" var rootCmd = &cobra.Command{ Use: "ghcs", @@ -24,7 +27,7 @@ var rootCmd = &cobra.Command{ Running commands requires the GITHUB_TOKEN environment variable to be set to a token to access the GitHub API with.`, - Version: Version, + Version: version, PersistentPreRunE: func(cmd *cobra.Command, args []string) error { if os.Getenv("GITHUB_TOKEN") == "" { @@ -42,5 +45,5 @@ func explainError(w io.Writer, err error) { fmt.Fprintln(w, "Make sure to enable SSO for your organizations after creating the token.") return } - // fmt.Fprintf(w, "%v\n", err) + fmt.Fprintf(w, "%v\n", err) } diff --git a/cmd/ghcs/ports.go b/cmd/ghcs/ports.go index 09397af54..c5d127892 100644 --- a/cmd/ghcs/ports.go +++ b/cmd/ghcs/ports.go @@ -19,48 +19,48 @@ import ( "golang.org/x/sync/errgroup" ) -type PortsOptions struct { - CodespaceName string - AsJSON bool +type portsOptions struct { + codespaceName string + asJSON bool } -func NewPortsCmd() *cobra.Command { - opts := &PortsOptions{} +func newPortsCmd() *cobra.Command { + opts := &portsOptions{} portsCmd := &cobra.Command{ Use: "ports", Short: "List ports in a Codespace", Args: cobra.NoArgs, RunE: func(cmd *cobra.Command, args []string) error { - return Ports(opts) + return ports(opts) }, } - portsCmd.Flags().StringVarP(&opts.CodespaceName, "codespace", "c", "", "The `name` of the Codespace to use") - portsCmd.Flags().BoolVar(&opts.AsJSON, "json", false, "Output as JSON") + portsCmd.Flags().StringVarP(&opts.codespaceName, "codespace", "c", "", "The `name` of the Codespace to use") + portsCmd.Flags().BoolVar(&opts.asJSON, "json", false, "Output as JSON") - portsCmd.AddCommand(NewPortsPublicCmd()) - portsCmd.AddCommand(NewPortsPrivateCmd()) - portsCmd.AddCommand(NewPortsForwardCmd()) + portsCmd.AddCommand(newPortsPublicCmd()) + portsCmd.AddCommand(newPortsPrivateCmd()) + portsCmd.AddCommand(newPortsForwardCmd()) return portsCmd } func init() { - rootCmd.AddCommand(NewPortsCmd()) + rootCmd.AddCommand(newPortsCmd()) } -func Ports(opts *PortsOptions) error { +func ports(opts *portsOptions) error { apiClient := api.New(os.Getenv("GITHUB_TOKEN")) ctx := context.Background() - log := output.NewLogger(os.Stdout, os.Stderr, opts.AsJSON) + log := output.NewLogger(os.Stdout, os.Stderr, opts.asJSON) user, err := apiClient.GetUser(ctx) if err != nil { return fmt.Errorf("error getting user: %v", err) } - codespace, token, err := codespaces.GetOrChooseCodespace(ctx, apiClient, user, opts.CodespaceName) + codespace, token, err := codespaces.GetOrChooseCodespace(ctx, apiClient, user, opts.codespaceName) if err != nil { if err == codespaces.ErrNoCodespaces { return err @@ -82,17 +82,18 @@ func Ports(opts *PortsOptions) error { } devContainerResult := <-devContainerCh - if devContainerResult.Err != nil { - _, _ = log.Errorf("Failed to get port names: %v\n", devContainerResult.Err.Error()) + if devContainerResult.err != nil { + _, _ = log.Errorf("Failed to get port names: %v\n", devContainerResult.err.Error()) + // TODO(adonovan): should this cause non-zero exit? } - table := output.NewTable(os.Stdout, opts.AsJSON) + table := output.NewTable(os.Stdout, opts.asJSON) table.SetHeader([]string{"Label", "Source Port", "Destination Port", "Public", "Browse URL"}) for _, port := range ports { sourcePort := strconv.Itoa(port.SourcePort) var portName string - if devContainerResult.DevContainer != nil { - if attributes, ok := devContainerResult.DevContainer.PortAttributes[sourcePort]; ok { + if devContainerResult.devContainer != nil { + if attributes, ok := devContainerResult.devContainer.PortAttributes[sourcePort]; ok { portName = attributes.Label } } @@ -125,8 +126,8 @@ func getPorts(ctx context.Context, lsclient *liveshare.Client) (liveshare.Ports, } type devContainerResult struct { - DevContainer *devContainer - Err error + devContainer *devContainer + err error } type devContainer struct { @@ -168,7 +169,7 @@ func getDevContainer(ctx context.Context, apiClient *api.API, codespace *api.Cod return ch } -func NewPortsPublicCmd() *cobra.Command { +func newPortsPublicCmd() *cobra.Command { return &cobra.Command{ Use: "public ", Short: "Mark port as public", @@ -180,7 +181,7 @@ func NewPortsPublicCmd() *cobra.Command { } } -func NewPortsPrivateCmd() *cobra.Command { +func newPortsPrivateCmd() *cobra.Command { return &cobra.Command{ Use: "private ", Short: "Mark port as private", @@ -239,7 +240,7 @@ func updatePortVisibility(log *output.Logger, codespaceName, sourcePort string, return nil } -func NewPortsForwardCmd() *cobra.Command { +func newPortsForwardCmd() *cobra.Command { return &cobra.Command{ Use: "forward :", Short: "Forward ports", @@ -289,14 +290,14 @@ func forwardPorts(log *output.Logger, codespaceName string, ports []string) erro for _, portPair := range portPairs { pp := portPair - srcstr := strconv.Itoa(portPair.Src) - if err := server.StartSharing(gctx, "share-"+srcstr, pp.Src); err != nil { + srcstr := strconv.Itoa(portPair.src) + if err := server.StartSharing(gctx, "share-"+srcstr, pp.src); err != nil { return fmt.Errorf("start sharing port: %v", err) } g.Go(func() error { - log.Println("Forwarding port: " + srcstr + " ==> " + strconv.Itoa(pp.Dst)) - portForwarder := liveshare.NewPortForwarder(lsclient, server, pp.Dst) + log.Println("Forwarding port: " + srcstr + " ==> " + strconv.Itoa(pp.dst)) + portForwarder := liveshare.NewPortForwarder(lsclient, server, pp.dst) if err := portForwarder.Start(gctx); err != nil { return fmt.Errorf("error forwarding port: %v", err) } @@ -313,16 +314,17 @@ func forwardPorts(log *output.Logger, codespaceName string, ports []string) erro } type portPair struct { - Src, Dst int + src, dst int } +// getPortPairs parses a list of strings of form "%d:%d" into pairs of numbers. func getPortPairs(ports []string) ([]portPair, error) { pp := make([]portPair, 0, len(ports)) for _, portString := range ports { parts := strings.Split(portString, ":") if len(parts) < 2 { - return pp, fmt.Errorf("port pair: '%v' is not valid", portString) + return nil, fmt.Errorf("port pair: '%v' is not valid", portString) } srcp, err := strconv.Atoi(parts[0]) diff --git a/cmd/ghcs/ssh.go b/cmd/ghcs/ssh.go index 1754f968a..a02dd557e 100644 --- a/cmd/ghcs/ssh.go +++ b/cmd/ghcs/ssh.go @@ -15,7 +15,7 @@ import ( "github.com/spf13/cobra" ) -func NewSSHCmd() *cobra.Command { +func newSSHCmd() *cobra.Command { var sshProfile, codespaceName string var sshServerPort int @@ -24,7 +24,7 @@ func NewSSHCmd() *cobra.Command { Short: "SSH into a Codespace", Args: cobra.NoArgs, RunE: func(cmd *cobra.Command, args []string) error { - return SSH(sshProfile, codespaceName, sshServerPort) + return ssh(sshProfile, codespaceName, sshServerPort) }, } @@ -36,10 +36,10 @@ func NewSSHCmd() *cobra.Command { } func init() { - rootCmd.AddCommand(NewSSHCmd()) + rootCmd.AddCommand(newSSHCmd()) } -func SSH(sshProfile, codespaceName string, sshServerPort int) error { +func ssh(sshProfile, codespaceName string, sshServerPort int) error { apiClient := api.New(os.Getenv("GITHUB_TOKEN")) ctx := context.Background() log := output.NewLogger(os.Stdout, os.Stderr, false) diff --git a/internal/codespaces/codespaces.go b/internal/codespaces/codespaces.go index 48369cfa0..a346c06d7 100644 --- a/internal/codespaces/codespaces.go +++ b/internal/codespaces/codespaces.go @@ -4,6 +4,7 @@ import ( "context" "errors" "fmt" + "sort" "time" "github.com/AlecAivazis/survey/v2" @@ -25,7 +26,9 @@ func ChooseCodespace(ctx context.Context, apiClient *api.API, user *api.User) (* return nil, ErrNoCodespaces } - codespaces.SortByCreatedAt() + sort.Slice(codespaces, func(i, j int) bool { + return codespaces[i].CreatedAt > codespaces[j].CreatedAt + }) codespacesByName := make(map[string]*api.Codespace) codespacesNames := make([]string, 0, len(codespaces)) From cb6552f4cae0d5255471319455f24254a5bcfae0 Mon Sep 17 00:00:00 2001 From: Jose Garcia Date: Fri, 27 Aug 2021 12:50:32 -0400 Subject: [PATCH 098/290] more efficient impl for processing states --- cmd/ghcs/create.go | 39 ++++++++++++++++++------------- internal/codespaces/codespaces.go | 11 ++++++++- internal/codespaces/states.go | 5 ++++ 3 files changed, 38 insertions(+), 17 deletions(-) diff --git a/cmd/ghcs/create.go b/cmd/ghcs/create.go index fb9834351..9ffcc58d1 100644 --- a/cmd/ghcs/create.go +++ b/cmd/ghcs/create.go @@ -112,6 +112,7 @@ func showStatus(ctx context.Context, log *output.Logger, apiClient *api.API, use } var lastState codespaces.PostCreateState + finishedStates := make(map[string]bool) var breakNextState bool for { @@ -122,25 +123,31 @@ func showStatus(ctx context.Context, log *output.Logger, apiClient *api.API, use var inProgress bool for _, state := range stateUpdate.PostCreateStates { - switch state.Status { - case codespaces.PostCreateStateRunning: - if lastState != state { - lastState = state - log.Print(state.Name) - } else { - log.Print(".") - } + if _, found := finishedStates[state.Name]; found { + continue // skip this state as we've processed it already + } - inProgress = true - case codespaces.PostCreateStateFailed: - if lastState.Name == state.Name && lastState.Status != state.Status { + if state.Name != lastState.Name { + log.Print(state.Name) + + if state.Status == codespaces.PostCreateStateRunning { + inProgress = true lastState = state - log.Print(".Failed\n") + log.Print("...") + break + } else { + finishedStates[state.Name] = true + log.Println("..." + state.Status) } - case codespaces.PostCreateStateSuccess: - if lastState.Name == state.Name && lastState.Status != state.Status { - lastState = state - log.Print(".Success\n") + } else { + if state.Status == codespaces.PostCreateStateRunning { + inProgress = true + log.Print(".") + break + } else { + finishedStates[state.Name] = true + log.Println(state.Status) + lastState = codespaces.PostCreateState{} // reset the value } } } diff --git a/internal/codespaces/codespaces.go b/internal/codespaces/codespaces.go index 005ea0fda..11e0a8902 100644 --- a/internal/codespaces/codespaces.go +++ b/internal/codespaces/codespaces.go @@ -62,6 +62,15 @@ type logger interface { Println(v ...interface{}) (int, error) } +func connectionReady(codespace *api.Codespace) bool { + ready := codespace.Environment.Connection.SessionID != "" + ready = ready && codespace.Environment.Connection.SessionToken != "" + ready = ready && codespace.Environment.Connection.RelayEndpoint != "" + ready = ready && codespace.Environment.Connection.RelaySAS != "" + ready = ready && codespace.Environment.State == api.CodespaceEnvironmentStateAvailable + return ready +} + func ConnectToLiveshare(ctx context.Context, log logger, apiClient *api.API, userLogin, token string, codespace *api.Codespace) (client *liveshare.Client, err error) { var startedCodespace bool if codespace.Environment.State != api.CodespaceEnvironmentStateAvailable { @@ -73,7 +82,7 @@ func ConnectToLiveshare(ctx context.Context, log logger, apiClient *api.API, use } retries := 0 - for codespace.Environment.Connection.SessionID == "" || codespace.Environment.State != api.CodespaceEnvironmentStateAvailable { + for !connectionReady(codespace) { if retries > 1 { if retries%2 == 0 { log.Print(".") diff --git a/internal/codespaces/states.go b/internal/codespaces/states.go index 078dc546c..3ad6c9a4c 100644 --- a/internal/codespaces/states.go +++ b/internal/codespaces/states.go @@ -5,6 +5,7 @@ import ( "encoding/json" "fmt" "io/ioutil" + "strings" "time" "github.com/github/ghcs/api" @@ -12,6 +13,10 @@ import ( type PostCreateStateStatus string +func (p PostCreateStateStatus) String() string { + return strings.Title(string(p)) +} + const ( PostCreateStateRunning PostCreateStateStatus = "running" PostCreateStateSuccess PostCreateStateStatus = "succeeded" From da8655209b59fe4a5791e00d26fbbe63f9dc1db4 Mon Sep 17 00:00:00 2001 From: Jose Garcia Date: Fri, 27 Aug 2021 12:52:30 -0400 Subject: [PATCH 099/290] make things private --- cmd/ghcs/create.go | 32 ++++++++++++++++---------------- 1 file changed, 16 insertions(+), 16 deletions(-) diff --git a/cmd/ghcs/create.go b/cmd/ghcs/create.go index 9ffcc58d1..bb416d53f 100644 --- a/cmd/ghcs/create.go +++ b/cmd/ghcs/create.go @@ -15,29 +15,29 @@ import ( "github.com/spf13/cobra" ) -type CreateOptions struct { - Repo string - Branch string - Machine string - ShowStatus bool +type createOptions struct { + repo string + branch string + machine string + showStatus bool } func newCreateCmd() *cobra.Command { - opts := &CreateOptions{} + opts := &createOptions{} createCmd := &cobra.Command{ Use: "create", Short: "Create a Codespace", Args: cobra.NoArgs, RunE: func(cmd *cobra.Command, args []string) error { - return Create(opts) + return create(opts) }, } - createCmd.Flags().StringVarP(&opts.Repo, "repo", "r", "", "repository name with owner: user/repo") - createCmd.Flags().StringVarP(&opts.Branch, "branch", "b", "", "repository branch") - createCmd.Flags().StringVarP(&opts.Machine, "machine", "m", "", "hardware specifications for the VM") - createCmd.Flags().BoolVarP(&opts.ShowStatus, "status", "s", false, "show status of post-create command and dotfiles") + createCmd.Flags().StringVarP(&opts.repo, "repo", "r", "", "repository name with owner: user/repo") + createCmd.Flags().StringVarP(&opts.branch, "branch", "b", "", "repository branch") + createCmd.Flags().StringVarP(&opts.machine, "machine", "m", "", "hardware specifications for the VM") + createCmd.Flags().BoolVarP(&opts.showStatus, "status", "s", false, "show status of post-create command and dotfiles") return createCmd } @@ -46,18 +46,18 @@ func init() { rootCmd.AddCommand(newCreateCmd()) } -func Create(opts *CreateOptions) error { +func create(opts *createOptions) error { ctx := context.Background() apiClient := api.New(os.Getenv("GITHUB_TOKEN")) locationCh := getLocation(ctx, apiClient) userCh := getUser(ctx, apiClient) log := output.NewLogger(os.Stdout, os.Stderr, false) - repo, err := getRepoName(opts.Repo) + repo, err := getRepoName(opts.repo) if err != nil { return fmt.Errorf("error getting repository name: %v", err) } - branch, err := getBranchName(opts.Branch) + branch, err := getBranchName(opts.branch) if err != nil { return fmt.Errorf("error getting branch name: %v", err) } @@ -77,7 +77,7 @@ func Create(opts *CreateOptions) error { return fmt.Errorf("error getting codespace user: %v", userResult.Err) } - machine, err := getMachineName(ctx, opts.Machine, userResult.User, repository, locationResult.Location, apiClient) + machine, err := getMachineName(ctx, opts.machine, userResult.User, repository, locationResult.Location, apiClient) if err != nil { return fmt.Errorf("error getting machine type: %v", err) } @@ -92,7 +92,7 @@ func Create(opts *CreateOptions) error { return fmt.Errorf("error creating codespace: %v", err) } - if opts.ShowStatus { + if opts.showStatus { if err := showStatus(ctx, log, apiClient, userResult.User, codespace); err != nil { return fmt.Errorf("show status: %w", err) } From 2cc91c224edcbc8a266c8b82c9b898364f399cc3 Mon Sep 17 00:00:00 2001 From: Jose Garcia Date: Fri, 27 Aug 2021 12:53:27 -0400 Subject: [PATCH 100/290] add comment for improvements --- internal/codespaces/codespaces.go | 1 + 1 file changed, 1 insertion(+) diff --git a/internal/codespaces/codespaces.go b/internal/codespaces/codespaces.go index 11e0a8902..8897c0839 100644 --- a/internal/codespaces/codespaces.go +++ b/internal/codespaces/codespaces.go @@ -62,6 +62,7 @@ type logger interface { Println(v ...interface{}) (int, error) } +// TODO(josebalius): we should move some of this to the liveshare.Connection struct func connectionReady(codespace *api.Codespace) bool { ready := codespace.Environment.Connection.SessionID != "" ready = ready && codespace.Environment.Connection.SessionToken != "" From 90f3ac6f56a5b44fc65da1dd3003e97b2f4b378b Mon Sep 17 00:00:00 2001 From: Alan Donovan Date: Fri, 27 Aug 2021 14:23:33 -0400 Subject: [PATCH 101/290] check status codes --- api/api.go | 20 +++++++++++++++++--- 1 file changed, 17 insertions(+), 3 deletions(-) diff --git a/api/api.go b/api/api.go index 25faf5724..fb386a57b 100644 --- a/api/api.go +++ b/api/api.go @@ -272,7 +272,17 @@ func (a *API) StartCodespace(ctx context.Context, token string, codespace *Codes } defer resp.Body.Close() - // TODO: check status code? + b, err := ioutil.ReadAll(resp.Body) + if err != nil { + return fmt.Errorf("error reading response body: %v", err) + } + + // TODO(adonovan): the status code proxied from VSCS may distinguish + // "already running" from "fresh start". Find out what code it uses + // and allow it too. + if resp.StatusCode != http.StatusOK { + return a.errorResponse(b) + } return nil } @@ -298,7 +308,9 @@ func (a *API) GetCodespaceRegionLocation(ctx context.Context) (string, error) { return "", fmt.Errorf("error reading response body: %v", err) } - // TODO: check status code? + if resp.StatusCode != http.StatusOK { + return "", a.errorResponse(b) + } var response getCodespaceRegionLocationResponse if err := json.Unmarshal(b, &response); err != nil { @@ -338,7 +350,9 @@ func (a *API) GetCodespacesSkus(ctx context.Context, user *User, repository *Rep return nil, fmt.Errorf("error reading response body: %v", err) } - // TODO: check status code? + if resp.StatusCode != http.StatusOK { + return nil, a.errorResponse(b) + } response := struct { Skus Skus `json:"skus"` From da34d12abb099b9ea3e1093c6c8af6c0fecae4ad Mon Sep 17 00:00:00 2001 From: Alan Donovan Date: Fri, 27 Aug 2021 15:26:34 -0400 Subject: [PATCH 102/290] respond to review --- api/api.go | 2 +- cmd/ghcs/main.go | 3 --- cmd/ghcs/ports.go | 4 ++-- 3 files changed, 3 insertions(+), 6 deletions(-) diff --git a/api/api.go b/api/api.go index 9d9eae79f..f534c4587 100644 --- a/api/api.go +++ b/api/api.go @@ -1,4 +1,4 @@ -// TODO(adonovan): rename to package codespaces, and codespaces.Client? +// TODO(adonovan): rename to package codespaces, and codespaces.Client. package api import ( diff --git a/cmd/ghcs/main.go b/cmd/ghcs/main.go index dbf1dc714..bc9bc2c6b 100644 --- a/cmd/ghcs/main.go +++ b/cmd/ghcs/main.go @@ -1,8 +1,5 @@ package main -// TODO(adonovan): write 'help' commands, in manner of the 'go' tool. -// Document GITHUB_TOKEN. - import ( "errors" "fmt" diff --git a/cmd/ghcs/ports.go b/cmd/ghcs/ports.go index c5d127892..c5088bfe6 100644 --- a/cmd/ghcs/ports.go +++ b/cmd/ghcs/ports.go @@ -83,8 +83,8 @@ func ports(opts *portsOptions) error { devContainerResult := <-devContainerCh if devContainerResult.err != nil { - _, _ = log.Errorf("Failed to get port names: %v\n", devContainerResult.err.Error()) - // TODO(adonovan): should this cause non-zero exit? + // Warn about failure to read the devcontainer file. Not a ghcs command error. + log.Errorf("Failed to get port names: %v\n", devContainerResult.err.Error()) } table := output.NewTable(os.Stdout, opts.asJSON) From d8f1baa519e1daa0441664956882f433a6f91df1 Mon Sep 17 00:00:00 2001 From: Alan Donovan Date: Fri, 27 Aug 2021 15:36:45 -0400 Subject: [PATCH 103/290] more SKU renames. --- api/api.go | 2 +- cmd/ghcs/create.go | 10 +++++----- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/api/api.go b/api/api.go index f534c4587..3e405dd5c 100644 --- a/api/api.go +++ b/api/api.go @@ -292,7 +292,7 @@ type SKU struct { DisplayName string `json:"display_name"` } -func (a *API) GetCodespacesSkus(ctx context.Context, user *User, repository *Repository, location string) ([]*SKU, error) { +func (a *API) GetCodespacesSKUs(ctx context.Context, user *User, repository *Repository, location string) ([]*SKU, error) { req, err := http.NewRequest(http.MethodGet, githubAPI+"/vscs_internal/user/"+user.Login+"/skus", nil) if err != nil { return nil, fmt.Errorf("err creating request: %v", err) diff --git a/cmd/ghcs/create.go b/cmd/ghcs/create.go index 8b4e1a743..ef2209856 100644 --- a/cmd/ghcs/create.go +++ b/cmd/ghcs/create.go @@ -151,9 +151,9 @@ func getBranchName() (string, error) { } func getMachineName(ctx context.Context, user *api.User, repo *api.Repository, location string, apiClient *api.API) (string, error) { - skus, err := apiClient.GetCodespacesSkus(ctx, user, repo, location) + skus, err := apiClient.GetCodespacesSKUs(ctx, user, repo, location) if err != nil { - return "", fmt.Errorf("error getting codespace skus: %v", err) + return "", fmt.Errorf("error getting codespace SKUs: %v", err) } // if user supplied a machine type, it must be valid @@ -165,12 +165,12 @@ func getMachineName(ctx context.Context, user *api.User, repo *api.Repository, l } } - availableSkus := make([]string, len(skus)) + availableSKUs := make([]string, len(skus)) for i := 0; i < len(skus); i++ { - availableSkus[i] = skus[i].Name + availableSKUs[i] = skus[i].Name } - return "", fmt.Errorf("there are is no such machine for the repository: %s\nAvailable machines: %v", machine, availableSkus) + return "", fmt.Errorf("there is no such machine for the repository: %s\nAvailable machines: %v", machine, availableSKUs) } else if len(skus) == 0 { return "", nil } From a5ae72cb26fe05bb5d1ac031af14631183fe23d9 Mon Sep 17 00:00:00 2001 From: Alan Donovan Date: Fri, 27 Aug 2021 15:38:41 -0400 Subject: [PATCH 104/290] revert removal of _ = f() to pacify linter --- cmd/ghcs/ports.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cmd/ghcs/ports.go b/cmd/ghcs/ports.go index c5088bfe6..492107bfc 100644 --- a/cmd/ghcs/ports.go +++ b/cmd/ghcs/ports.go @@ -84,7 +84,7 @@ func ports(opts *portsOptions) error { devContainerResult := <-devContainerCh if devContainerResult.err != nil { // Warn about failure to read the devcontainer file. Not a ghcs command error. - log.Errorf("Failed to get port names: %v\n", devContainerResult.err.Error()) + _, _ = log.Errorf("Failed to get port names: %v\n", devContainerResult.err.Error()) } table := output.NewTable(os.Stdout, opts.asJSON) From 8b395b5ab5b86366ca6e2c6b5a46133cac703028 Mon Sep 17 00:00:00 2001 From: Alan Donovan Date: Fri, 27 Aug 2021 15:53:55 -0400 Subject: [PATCH 105/290] ghcs code: improve vscode error --- cmd/ghcs/code.go | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/cmd/ghcs/code.go b/cmd/ghcs/code.go index 81dbdbb2c..fd71da74f 100644 --- a/cmd/ghcs/code.go +++ b/cmd/ghcs/code.go @@ -57,8 +57,9 @@ func Code(codespaceName string, useInsiders bool) error { codespaceName = codespace.Name } - if err := open.Run(vscodeProtocolURL(codespaceName, useInsiders)); err != nil { - return fmt.Errorf("error opening vscode URL") + url := vscodeProtocolURL(codespaceName, useInsiders) + if err := open.Run(url); err != nil { + return fmt.Errorf("error opening vscode URL %s: %s. (Is VSCode installed?)", url, err) } return nil From e423cb0ef953f4d6e02707547f1c02615f9a7cff Mon Sep 17 00:00:00 2001 From: Alan Donovan Date: Fri, 27 Aug 2021 16:09:02 -0400 Subject: [PATCH 106/290] display colon and cursor in survey prompts --- cmd/ghcs/create.go | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/cmd/ghcs/create.go b/cmd/ghcs/create.go index c3e8a24a1..7b33322be 100644 --- a/cmd/ghcs/create.go +++ b/cmd/ghcs/create.go @@ -126,11 +126,11 @@ func getRepoName() (string, error) { repoSurvey := []*survey.Question{ { Name: "repository", - Prompt: &survey.Input{Message: "Repository"}, + Prompt: &survey.Input{Message: "Repository:"}, Validate: survey.Required, }, } - err := survey.Ask(repoSurvey, &repo) + err := ask(repoSurvey, &repo) return repo, err } @@ -142,11 +142,11 @@ func getBranchName() (string, error) { branchSurvey := []*survey.Question{ { Name: "branch", - Prompt: &survey.Input{Message: "Branch"}, + Prompt: &survey.Input{Message: "Branch:"}, Validate: survey.Required, }, } - err := survey.Ask(branchSurvey, &branch) + err := ask(branchSurvey, &branch) return branch, err } @@ -198,7 +198,7 @@ func getMachineName(ctx context.Context, user *api.User, repo *api.Repository, l } skuAnswers := struct{ SKU string }{} - if err := survey.Ask(skuSurvey, &skuAnswers); err != nil { + if err := ask(skuSurvey, &skuAnswers); err != nil { return "", fmt.Errorf("error getting SKU: %v", err) } @@ -207,3 +207,8 @@ func getMachineName(ctx context.Context, user *api.User, repo *api.Repository, l return machine, nil } + +// ask asks survery questions using standard options. +func ask(qs []*survey.Question, response interface{}) error { + return survey.Ask(qs, response, survey.WithShowCursor(true)) +} From 1e8a8370fee4988be9ab089473d8d7ad07438761 Mon Sep 17 00:00:00 2001 From: Jose Garcia Date: Fri, 27 Aug 2021 16:29:02 -0400 Subject: [PATCH 107/290] initial round of PR feedback --- cmd/ghcs/create.go | 23 ++++++++++++----------- internal/codespaces/codespaces.go | 17 ++++++----------- internal/codespaces/states.go | 12 +++++------- 3 files changed, 23 insertions(+), 29 deletions(-) diff --git a/cmd/ghcs/create.go b/cmd/ghcs/create.go index bb416d53f..c3cfbce22 100644 --- a/cmd/ghcs/create.go +++ b/cmd/ghcs/create.go @@ -108,7 +108,7 @@ func create(opts *createOptions) error { func showStatus(ctx context.Context, log *output.Logger, apiClient *api.API, user *api.User, codespace *api.Codespace) error { states, err := codespaces.PollPostCreateStates(ctx, log, apiClient, user, codespace) if err != nil { - return fmt.Errorf("poll post create states: %v", err) + return fmt.Errorf("failed to subscribe to state changes from codespace: %v", err) } var lastState codespaces.PostCreateState @@ -135,27 +135,28 @@ func showStatus(ctx context.Context, log *output.Logger, apiClient *api.API, use lastState = state log.Print("...") break - } else { - finishedStates[state.Name] = true - log.Println("..." + state.Status) } + + finishedStates[state.Name] = true + log.Println("..." + state.Status) } else { if state.Status == codespaces.PostCreateStateRunning { inProgress = true log.Print(".") break - } else { - finishedStates[state.Name] = true - log.Println(state.Status) - lastState = codespaces.PostCreateState{} // reset the value } + + finishedStates[state.Name] = true + log.Println(state.Status) + lastState = codespaces.PostCreateState{} // reset the value } } - if !inProgress && !breakNextState { + if !inProgress { + if breakNextState { + break + } breakNextState = true - } else if !inProgress && breakNextState { - break } } diff --git a/internal/codespaces/codespaces.go b/internal/codespaces/codespaces.go index 8897c0839..05254c9a7 100644 --- a/internal/codespaces/codespaces.go +++ b/internal/codespaces/codespaces.go @@ -62,14 +62,12 @@ type logger interface { Println(v ...interface{}) (int, error) } -// TODO(josebalius): we should move some of this to the liveshare.Connection struct func connectionReady(codespace *api.Codespace) bool { - ready := codespace.Environment.Connection.SessionID != "" - ready = ready && codespace.Environment.Connection.SessionToken != "" - ready = ready && codespace.Environment.Connection.RelayEndpoint != "" - ready = ready && codespace.Environment.Connection.RelaySAS != "" - ready = ready && codespace.Environment.State == api.CodespaceEnvironmentStateAvailable - return ready + return codespace.Environment.Connection.SessionID != "" && + codespace.Environment.Connection.SessionToken != "" && + codespace.Environment.Connection.RelayEndpoint != "" && + codespace.Environment.Connection.RelaySAS != "" && + codespace.Environment.State == api.CodespaceEnvironmentStateAvailable } func ConnectToLiveshare(ctx context.Context, log logger, apiClient *api.API, userLogin, token string, codespace *api.Codespace) (client *liveshare.Client, err error) { @@ -82,8 +80,7 @@ func ConnectToLiveshare(ctx context.Context, log logger, apiClient *api.API, use } } - retries := 0 - for !connectionReady(codespace) { + for retries := 0; !connectionReady(codespace); retries++ { if retries > 1 { if retries%2 == 0 { log.Print(".") @@ -100,8 +97,6 @@ func ConnectToLiveshare(ctx context.Context, log logger, apiClient *api.API, use if err != nil { return nil, fmt.Errorf("error getting codespace: %v", err) } - - retries += 1 } if startedCodespace { diff --git a/internal/codespaces/states.go b/internal/codespaces/states.go index 3ad6c9a4c..9ace150d6 100644 --- a/internal/codespaces/states.go +++ b/internal/codespaces/states.go @@ -24,12 +24,10 @@ const ( ) type PostCreateStatesResult struct { - PostCreateStates PostCreateStates + PostCreateStates []PostCreateState Err error } -type PostCreateStates []PostCreateState - type PostCreateState struct { Name string `json:"name"` Status PostCreateStateStatus `json:"status"` @@ -81,7 +79,7 @@ func PollPostCreateStates(ctx context.Context, log logger, apiClient *api.API, u return pollch, nil } -func getPostCreateOutput(ctx context.Context, tunnelPort int, codespace *api.Codespace) (PostCreateStates, error) { +func getPostCreateOutput(ctx context.Context, tunnelPort int, codespace *api.Codespace) ([]PostCreateState, error) { stdout, err := RunCommand( ctx, tunnelPort, sshDestination(codespace), "cat /workspaces/.codespaces/shared/postCreateOutput.json", @@ -95,9 +93,9 @@ func getPostCreateOutput(ctx context.Context, tunnelPort int, codespace *api.Cod return nil, fmt.Errorf("read output: %v", err) } - output := struct { - Steps PostCreateStates `json:"steps"` - }{} + var output struct { + Steps []PostCreateState `json:"steps"` + } if err := json.Unmarshal(b, &output); err != nil { return nil, fmt.Errorf("unmarshal output: %v", err) } From e5f45d4bfab826871cebabda66da6f46ad0be504 Mon Sep 17 00:00:00 2001 From: Jose Garcia Date: Fri, 27 Aug 2021 16:41:22 -0400 Subject: [PATCH 108/290] docs and improvement to the showStatus implementation --- cmd/ghcs/create.go | 79 ++++++++++++++++++++++++++-------------------- 1 file changed, 45 insertions(+), 34 deletions(-) diff --git a/cmd/ghcs/create.go b/cmd/ghcs/create.go index c3cfbce22..42e8f11be 100644 --- a/cmd/ghcs/create.go +++ b/cmd/ghcs/create.go @@ -115,48 +115,54 @@ func showStatus(ctx context.Context, log *output.Logger, apiClient *api.API, use finishedStates := make(map[string]bool) var breakNextState bool +PollStates: for { - stateUpdate := <-states - if stateUpdate.Err != nil { - return fmt.Errorf("receive state update: %v", err) - } + select { + case <-ctx.Done(): + return nil - var inProgress bool - for _, state := range stateUpdate.PostCreateStates { - if _, found := finishedStates[state.Name]; found { - continue // skip this state as we've processed it already + case stateUpdate := <-states: + if stateUpdate.Err != nil { + return fmt.Errorf("receive state update: %v", err) } - if state.Name != lastState.Name { - log.Print(state.Name) - - if state.Status == codespaces.PostCreateStateRunning { - inProgress = true - lastState = state - log.Print("...") - break + var inProgress bool + for _, state := range stateUpdate.PostCreateStates { + if _, found := finishedStates[state.Name]; found { + continue // skip this state as we've processed it already } - finishedStates[state.Name] = true - log.Println("..." + state.Status) - } else { - if state.Status == codespaces.PostCreateStateRunning { - inProgress = true - log.Print(".") - break + if state.Name != lastState.Name { + log.Print(state.Name) + + if state.Status == codespaces.PostCreateStateRunning { + inProgress = true + lastState = state + log.Print("...") + break + } + + finishedStates[state.Name] = true + log.Println("..." + state.Status) + } else { + if state.Status == codespaces.PostCreateStateRunning { + inProgress = true + log.Print(".") + break + } + + finishedStates[state.Name] = true + log.Println(state.Status) + lastState = codespaces.PostCreateState{} // reset the value } - - finishedStates[state.Name] = true - log.Println(state.Status) - lastState = codespaces.PostCreateState{} // reset the value } - } - if !inProgress { - if breakNextState { - break + if !inProgress { + if breakNextState { + break PollStates + } + breakNextState = true } - breakNextState = true } } @@ -168,6 +174,7 @@ type getUserResult struct { Err error } +// getUser fetches the user record associated with the GITHUB_TOKEN func getUser(ctx context.Context, apiClient *api.API) <-chan getUserResult { ch := make(chan getUserResult) go func() { @@ -182,6 +189,7 @@ type locationResult struct { Err error } +// getLocation fetches the closest Codespace datacenter region/location to the user. func getLocation(ctx context.Context, apiClient *api.API) <-chan locationResult { ch := make(chan locationResult) go func() { @@ -191,6 +199,7 @@ func getLocation(ctx context.Context, apiClient *api.API) <-chan locationResult return ch } +// getRepoName prompts the user for the name of the repository, or returns the repository if non-empty. func getRepoName(repo string) (string, error) { if repo != "" { return repo, nil @@ -207,6 +216,7 @@ func getRepoName(repo string) (string, error) { return repo, err } +// getBranchName prompts the user for the name of the branch, or returns the branch if non-empty. func getBranchName(branch string) (string, error) { if branch != "" { return branch, nil @@ -223,6 +233,7 @@ func getBranchName(branch string) (string, error) { return branch, err } +// getMachineName prompts the user to select the machine type, or validates the machine if non-empty. func getMachineName(ctx context.Context, machine string, user *api.User, repo *api.Repository, location string, apiClient *api.API) (string, error) { skus, err := apiClient.GetCodespacesSkus(ctx, user, repo, location) if err != nil { @@ -243,7 +254,7 @@ func getMachineName(ctx context.Context, machine string, user *api.User, repo *a availableSkus[i] = skus[i].Name } - return "", fmt.Errorf("there are is no such machine for the repository: %s\nAvailable machines: %v", machine, availableSkus) + return "", fmt.Errorf("there is no such machine for the repository: %s\nAvailable machines: %v", machine, availableSkus) } else if len(skus) == 0 { return "", nil } @@ -270,7 +281,7 @@ func getMachineName(ctx context.Context, machine string, user *api.User, repo *a }, } - skuAnswers := struct{ SKU string }{} + var skuAnswers struct{ SKU string } if err := survey.Ask(skuSurvey, &skuAnswers); err != nil { return "", fmt.Errorf("error getting SKU: %v", err) } From 368e8c61105f7beafcdedd3d6c3376293db345ca Mon Sep 17 00:00:00 2001 From: Jose Garcia Date: Fri, 27 Aug 2021 17:34:06 -0400 Subject: [PATCH 109/290] simplify contract for state polling --- cmd/ghcs/create.go | 83 ++++++++++++++++------------------- internal/codespaces/states.go | 51 +++++++++------------ 2 files changed, 58 insertions(+), 76 deletions(-) diff --git a/cmd/ghcs/create.go b/cmd/ghcs/create.go index 42e8f11be..6b6d10511 100644 --- a/cmd/ghcs/create.go +++ b/cmd/ghcs/create.go @@ -106,64 +106,55 @@ func create(opts *createOptions) error { } func showStatus(ctx context.Context, log *output.Logger, apiClient *api.API, user *api.User, codespace *api.Codespace) error { - states, err := codespaces.PollPostCreateStates(ctx, log, apiClient, user, codespace) - if err != nil { - return fmt.Errorf("failed to subscribe to state changes from codespace: %v", err) - } - var lastState codespaces.PostCreateState - finishedStates := make(map[string]bool) var breakNextState bool -PollStates: - for { - select { - case <-ctx.Done(): - return nil + finishedStates := make(map[string]bool) + ctx, stopPolling := context.WithCancel(ctx) - case stateUpdate := <-states: - if stateUpdate.Err != nil { - return fmt.Errorf("receive state update: %v", err) + poller := func(states []codespaces.PostCreateState) { + var inProgress bool + for _, state := range states { + if _, found := finishedStates[state.Name]; found { + continue // skip this state as we've processed it already } - var inProgress bool - for _, state := range stateUpdate.PostCreateStates { - if _, found := finishedStates[state.Name]; found { - continue // skip this state as we've processed it already + if state.Name != lastState.Name { + log.Print(state.Name) + + if state.Status == codespaces.PostCreateStateRunning { + inProgress = true + lastState = state + log.Print("...") + break } - if state.Name != lastState.Name { - log.Print(state.Name) - - if state.Status == codespaces.PostCreateStateRunning { - inProgress = true - lastState = state - log.Print("...") - break - } - - finishedStates[state.Name] = true - log.Println("..." + state.Status) - } else { - if state.Status == codespaces.PostCreateStateRunning { - inProgress = true - log.Print(".") - break - } - - finishedStates[state.Name] = true - log.Println(state.Status) - lastState = codespaces.PostCreateState{} // reset the value + finishedStates[state.Name] = true + log.Println("..." + state.Status) + } else { + if state.Status == codespaces.PostCreateStateRunning { + inProgress = true + log.Print(".") + break } - } - if !inProgress { - if breakNextState { - break PollStates - } - breakNextState = true + finishedStates[state.Name] = true + log.Println(state.Status) + lastState = codespaces.PostCreateState{} // reset the value } } + + if !inProgress { + if breakNextState { + stopPolling() + return + } + breakNextState = true + } + } + + if err := codespaces.PollPostCreateStates(ctx, log, apiClient, user, codespace, poller); err != nil { + return fmt.Errorf("failed to poll state changes from codespace: %v", err) } return nil diff --git a/internal/codespaces/states.go b/internal/codespaces/states.go index 9ace150d6..427726a46 100644 --- a/internal/codespaces/states.go +++ b/internal/codespaces/states.go @@ -33,50 +33,40 @@ type PostCreateState struct { Status PostCreateStateStatus `json:"status"` } -func PollPostCreateStates(ctx context.Context, log logger, apiClient *api.API, user *api.User, codespace *api.Codespace) (<-chan PostCreateStatesResult, error) { - pollch := make(chan PostCreateStatesResult) - +func PollPostCreateStates(ctx context.Context, log logger, apiClient *api.API, user *api.User, codespace *api.Codespace, poller func([]PostCreateState)) error { token, err := apiClient.GetCodespaceToken(ctx, user.Login, codespace.Name) if err != nil { - return nil, fmt.Errorf("getting codespace token: %v", err) + return fmt.Errorf("getting codespace token: %v", err) } lsclient, err := ConnectToLiveshare(ctx, log, apiClient, user.Login, token, codespace) if err != nil { - return nil, fmt.Errorf("connect to liveshare: %v", err) + return fmt.Errorf("connect to liveshare: %v", err) } tunnelPort, connClosed, err := MakeSSHTunnel(ctx, lsclient, 0) if err != nil { - return nil, fmt.Errorf("make ssh tunnel: %v", err) + return fmt.Errorf("make ssh tunnel: %v", err) } - go func() { - t := time.NewTicker(1 * time.Second) - for { - select { - case <-ctx.Done(): - return - case err := <-connClosed: - if err != nil { - pollch <- PostCreateStatesResult{Err: fmt.Errorf("connection closed: %v", err)} - return - } - case <-t.C: - states, err := getPostCreateOutput(ctx, tunnelPort, codespace) - if err != nil { - pollch <- PostCreateStatesResult{Err: fmt.Errorf("get post create output: %v", err)} - return - } - - pollch <- PostCreateStatesResult{ - PostCreateStates: states, - } + t := time.NewTicker(1 * time.Second) + for { + select { + case <-ctx.Done(): + return nil + case err := <-connClosed: + return fmt.Errorf("connection closed: %v", err) + case <-t.C: + states, err := getPostCreateOutput(ctx, tunnelPort, codespace) + if err != nil { + return fmt.Errorf("get post create output: %v", err) } - } - }() - return pollch, nil + poller(states) + } + } + + return nil } func getPostCreateOutput(ctx context.Context, tunnelPort int, codespace *api.Codespace) ([]PostCreateState, error) { @@ -87,6 +77,7 @@ func getPostCreateOutput(ctx context.Context, tunnelPort int, codespace *api.Cod if err != nil { return nil, fmt.Errorf("run command: %v", err) } + defer stdout.Close() b, err := ioutil.ReadAll(stdout) if err != nil { From a5a18026cc34382da6b7c7bf3a6d950f7a8f0678 Mon Sep 17 00:00:00 2001 From: Jose Garcia Date: Fri, 27 Aug 2021 17:39:10 -0400 Subject: [PATCH 110/290] fix linter --- internal/codespaces/states.go | 2 -- 1 file changed, 2 deletions(-) diff --git a/internal/codespaces/states.go b/internal/codespaces/states.go index 427726a46..36cf5e5e3 100644 --- a/internal/codespaces/states.go +++ b/internal/codespaces/states.go @@ -65,8 +65,6 @@ func PollPostCreateStates(ctx context.Context, log logger, apiClient *api.API, u poller(states) } } - - return nil } func getPostCreateOutput(ctx context.Context, tunnelPort int, codespace *api.Codespace) ([]PostCreateState, error) { From dcf4f041e9817e7aed0f2e584f7c00884c85f258 Mon Sep 17 00:00:00 2001 From: Alan Donovan Date: Fri, 27 Aug 2021 18:01:52 -0400 Subject: [PATCH 111/290] deal with Start errors, non-JSON --- api/api.go | 45 +++++++++++++++++++++++++-------------------- 1 file changed, 25 insertions(+), 20 deletions(-) diff --git a/api/api.go b/api/api.go index fb386a57b..c86de79f7 100644 --- a/api/api.go +++ b/api/api.go @@ -1,5 +1,11 @@ package api +// For descriptions of service interfaces, see: +// - https://online.visualstudio.com/api/swagger (for visualstudio.com) +// - https://docs.github.com/en/rest/reference/repos (for api.github.com) +// - https://github.com/github/github/blob/master/app/api/codespaces.rb (for vscs_internal) +// TODO(adonovan): replace the last link with a public doc URL when available. + import ( "bytes" "context" @@ -29,10 +35,6 @@ type User struct { Login string `json:"login"` } -type errResponse struct { - Message string `json:"message"` -} - func (a *API) GetUser(ctx context.Context) (*User, error) { req, err := http.NewRequest(http.MethodGet, githubAPI+"/user", nil) if err != nil { @@ -52,7 +54,7 @@ func (a *API) GetUser(ctx context.Context) (*User, error) { } if resp.StatusCode != http.StatusOK { - return nil, a.errorResponse(b) + return nil, jsonErrorResponse(b) } var response User @@ -63,8 +65,10 @@ func (a *API) GetUser(ctx context.Context) (*User, error) { return &response, nil } -func (a *API) errorResponse(b []byte) error { - var response errResponse +func jsonErrorResponse(b []byte) error { + var response struct { + Message string `json:"message"` + } if err := json.Unmarshal(b, &response); err != nil { return fmt.Errorf("error unmarshaling error response: %v", err) } @@ -95,7 +99,7 @@ func (a *API) GetRepository(ctx context.Context, nwo string) (*Repository, error } if resp.StatusCode != http.StatusOK { - return nil, a.errorResponse(b) + return nil, jsonErrorResponse(b) } var response Repository @@ -162,7 +166,7 @@ func (a *API) ListCodespaces(ctx context.Context, user *User) (Codespaces, error } if resp.StatusCode != http.StatusOK { - return nil, a.errorResponse(b) + return nil, jsonErrorResponse(b) } response := struct { @@ -210,7 +214,7 @@ func (a *API) GetCodespaceToken(ctx context.Context, ownerLogin, codespaceName s } if resp.StatusCode != http.StatusOK { - return "", a.errorResponse(b) + return "", jsonErrorResponse(b) } var response getCodespaceTokenResponse @@ -244,7 +248,7 @@ func (a *API) GetCodespace(ctx context.Context, token, owner, codespace string) } if resp.StatusCode != http.StatusOK { - return nil, a.errorResponse(b) + return nil, jsonErrorResponse(b) } var response Codespace @@ -277,11 +281,12 @@ func (a *API) StartCodespace(ctx context.Context, token string, codespace *Codes return fmt.Errorf("error reading response body: %v", err) } - // TODO(adonovan): the status code proxied from VSCS may distinguish - // "already running" from "fresh start". Find out what code it uses - // and allow it too. if resp.StatusCode != http.StatusOK { - return a.errorResponse(b) + // Error response is numeric code and/or string message, not JSON. + if len(b) > 100 { + b = append(b[:97], "..."...) + } + return fmt.Errorf("failed to start codespace: %s", b) } return nil @@ -309,7 +314,7 @@ func (a *API) GetCodespaceRegionLocation(ctx context.Context) (string, error) { } if resp.StatusCode != http.StatusOK { - return "", a.errorResponse(b) + return "", jsonErrorResponse(b) } var response getCodespaceRegionLocationResponse @@ -351,7 +356,7 @@ func (a *API) GetCodespacesSkus(ctx context.Context, user *User, repository *Rep } if resp.StatusCode != http.StatusOK { - return nil, a.errorResponse(b) + return nil, jsonErrorResponse(b) } response := struct { @@ -395,7 +400,7 @@ func (a *API) CreateCodespace(ctx context.Context, user *User, repository *Repos } if resp.StatusCode > http.StatusAccepted { - return nil, a.errorResponse(b) + return nil, jsonErrorResponse(b) } var response Codespace @@ -424,7 +429,7 @@ func (a *API) DeleteCodespace(ctx context.Context, user *User, token, codespaceN if err != nil { return fmt.Errorf("error reading response body: %v", err) } - return a.errorResponse(b) + return jsonErrorResponse(b) } return nil @@ -461,7 +466,7 @@ func (a *API) GetCodespaceRepositoryContents(ctx context.Context, codespace *Cod } if resp.StatusCode != http.StatusOK { - return nil, a.errorResponse(b) + return nil, jsonErrorResponse(b) } var response getCodespaceRepositoryContentsResponse From 272af2fadf4694bf2ed59845aa469c8bd2a8c0ea Mon Sep 17 00:00:00 2001 From: Jose Garcia Date: Fri, 27 Aug 2021 18:12:06 -0400 Subject: [PATCH 112/290] add docs --- internal/codespaces/states.go | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/internal/codespaces/states.go b/internal/codespaces/states.go index 36cf5e5e3..ed41fab9b 100644 --- a/internal/codespaces/states.go +++ b/internal/codespaces/states.go @@ -11,6 +11,7 @@ import ( "github.com/github/ghcs/api" ) +// PostCreateStateStatus is a string value representing the different statuses a state can have. type PostCreateStateStatus string func (p PostCreateStateStatus) String() string { @@ -23,16 +24,15 @@ const ( PostCreateStateFailed PostCreateStateStatus = "failed" ) -type PostCreateStatesResult struct { - PostCreateStates []PostCreateState - Err error -} - +// PostCreateState is a combination of a state and status value that is captured +// during codespace creation. type PostCreateState struct { Name string `json:"name"` Status PostCreateStateStatus `json:"status"` } +// PollPostCreateStates polls the state file in a codespace after creation and calls the poller +// with a slice of states to be processed. func PollPostCreateStates(ctx context.Context, log logger, apiClient *api.API, user *api.User, codespace *api.Codespace, poller func([]PostCreateState)) error { token, err := apiClient.GetCodespaceToken(ctx, user.Login, codespace.Name) if err != nil { @@ -92,6 +92,7 @@ func getPostCreateOutput(ctx context.Context, tunnelPort int, codespace *api.Cod return output.Steps, nil } +// TODO(josebalius): this won't be needed soon func sshDestination(codespace *api.Codespace) string { user := "codespace" if codespace.RepositoryNWO == "github/github" { From 0cf2640c863898fb3a78f19be2d87f29f2ba677f Mon Sep 17 00:00:00 2001 From: Jose Garcia Date: Fri, 27 Aug 2021 18:14:10 -0400 Subject: [PATCH 113/290] better docs and stop ticker --- internal/codespaces/states.go | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/internal/codespaces/states.go b/internal/codespaces/states.go index ed41fab9b..5c3dcef45 100644 --- a/internal/codespaces/states.go +++ b/internal/codespaces/states.go @@ -31,8 +31,9 @@ type PostCreateState struct { Status PostCreateStateStatus `json:"status"` } -// PollPostCreateStates polls the state file in a codespace after creation and calls the poller -// with a slice of states to be processed. +// PollPostCreateStates watches for state changes in a codespace, +// and calls the supplied poller for each batch of state changes. +// It runs until the context is cancelled or SSH tunnel is closed. func PollPostCreateStates(ctx context.Context, log logger, apiClient *api.API, user *api.User, codespace *api.Codespace, poller func([]PostCreateState)) error { token, err := apiClient.GetCodespaceToken(ctx, user.Login, codespace.Name) if err != nil { @@ -50,6 +51,8 @@ func PollPostCreateStates(ctx context.Context, log logger, apiClient *api.API, u } t := time.NewTicker(1 * time.Second) + defer t.Stop() + for { select { case <-ctx.Done(): From 5db9e2d83e04754f85f18c6a6f9e9834826e86cc Mon Sep 17 00:00:00 2001 From: Edmundo Gonzalez <51725820+edgonmsft@users.noreply.github.com> Date: Mon, 30 Aug 2021 04:48:56 +0000 Subject: [PATCH 114/290] PR changes. --- cmd/ghcs/logs.go | 6 +-- cmd/ghcs/ssh.go | 79 +++++++++++++++++++++++++++++-- internal/codespaces/codespaces.go | 15 +++--- internal/codespaces/ssh.go | 10 ++-- 4 files changed, 92 insertions(+), 18 deletions(-) diff --git a/cmd/ghcs/logs.go b/cmd/ghcs/logs.go index 6f93ee3b9..31af10dda 100644 --- a/cmd/ghcs/logs.go +++ b/cmd/ghcs/logs.go @@ -57,15 +57,11 @@ func Logs(tail bool, codespaceName string) error { return fmt.Errorf("connecting to liveshare: %v", err) } - result, remoteSSHServerPort, sshUser, _, err := codespaces.StartSSHServer(ctx, lsclient) + remoteSSHServerPort, sshUser, err := codespaces.StartSSHServer(ctx, lsclient) if err != nil { return fmt.Errorf("error getting ssh server details: %v", err) } - if !result { - return fmt.Errorf("error starting ssh: %v", err) - } - tunnelPort, connClosed, err := codespaces.MakeSSHTunnel(ctx, lsclient, 0, remoteSSHServerPort) if err != nil { return fmt.Errorf("make ssh tunnel: %v", err) diff --git a/cmd/ghcs/ssh.go b/cmd/ghcs/ssh.go index f3e621824..e3e51e08e 100644 --- a/cmd/ghcs/ssh.go +++ b/cmd/ghcs/ssh.go @@ -1,13 +1,17 @@ package main import ( + "bufio" "context" "fmt" "os" + "strings" + "time" "github.com/github/ghcs/api" "github.com/github/ghcs/cmd/ghcs/output" "github.com/github/ghcs/internal/codespaces" + "github.com/github/go-liveshare" "github.com/spf13/cobra" ) @@ -55,15 +59,29 @@ func SSH(sshProfile, codespaceName string, sshServerPort int) error { return fmt.Errorf("error connecting to liveshare: %v", err) } - result, remoteSSHServerPort, sshUser, _, err := codespaces.StartSSHServer(ctx, lsclient) + remoteSSHServerPort, sshUser, err := codespaces.StartSSHServer(ctx, lsclient) if err != nil { return fmt.Errorf("error getting ssh server details: %v", err) } - if !result { - return fmt.Errorf("error starting ssh: %v", err) + terminal, err := liveshare.NewTerminal(lsclient) + if err != nil { + return fmt.Errorf("error creating liveshare terminal: %v", err) } + log.Print("Preparing SSH...") + if sshProfile == "" { + containerID, err := getContainerID(ctx, log, terminal) + if err != nil { + return fmt.Errorf("error getting container id: %v", err) + } + + if err := setupEnv(ctx, log, terminal, containerID, codespace.RepositoryName, sshUser); err != nil { + return fmt.Errorf("error creating ssh server: %v", err) + } + } + log.Print("\n") + tunnelPort, tunnelClosed, err := codespaces.MakeSSHTunnel(ctx, lsclient, sshServerPort, remoteSSHServerPort) if err != nil { return fmt.Errorf("make ssh tunnel: %v", err) @@ -91,3 +109,58 @@ func SSH(sshProfile, codespaceName string, sshServerPort int) error { return nil } + +func getContainerID(ctx context.Context, logger *output.Logger, terminal *liveshare.Terminal) (string, error) { + logger.Print(".") + + cmd := terminal.NewCommand( + "/", + "/usr/bin/docker ps -aq --filter label=Type=codespaces --filter status=running", + ) + + stream, err := cmd.Run(ctx) + if err != nil { + return "", fmt.Errorf("error running command: %v", err) + } + + logger.Print(".") + scanner := bufio.NewScanner(stream) + scanner.Scan() + + logger.Print(".") + containerID := scanner.Text() + if err := scanner.Err(); err != nil { + return "", fmt.Errorf("error scanning stream: %v", err) + } + + logger.Print(".") + if err := stream.Close(); err != nil { + return "", fmt.Errorf("error closing stream: %v", err) + } + + return containerID, nil +} + +func setupEnv(ctx context.Context, logger *output.Logger, terminal *liveshare.Terminal, containerID, repositoryName, containerUser string) error { + setupBashProfileCmd := fmt.Sprintf(`echo "cd /workspaces/%v; export $(cat /workspaces/.codespaces/shared/.env | xargs); exec /bin/zsh;" > /home/%v/.bash_profile`, repositoryName, containerUser) + + logger.Print(".") + compositeCommand := []string{setupBashProfileCmd} + cmd := terminal.NewCommand( + "/", + fmt.Sprintf("/usr/bin/docker exec -t %s /bin/bash -c '"+strings.Join(compositeCommand, "; ")+"'", containerID), + ) + stream, err := cmd.Run(ctx) + if err != nil { + return fmt.Errorf("error running command: %v", err) + } + + logger.Print(".") + if err := stream.Close(); err != nil { + return fmt.Errorf("error closing stream: %v", err) + } + + time.Sleep(1 * time.Second) + + return nil +} diff --git a/internal/codespaces/codespaces.go b/internal/codespaces/codespaces.go index 3db66b427..8b18f7d7d 100644 --- a/internal/codespaces/codespaces.go +++ b/internal/codespaces/codespaces.go @@ -118,27 +118,30 @@ func ConnectToLiveshare(ctx context.Context, log logger, apiClient *api.API, tok return lsclient, nil } -func StartSSHServer(ctx context.Context, client *liveshare.Client) (result bool, serverPort int, user string, message string, err error) { +// StartSSHServer starts and installs the SSH server in the codespace +// returns the remote port where it is running, the user to use to login +// or an error if something failed. +func StartSSHServer(ctx context.Context, client *liveshare.Client) (serverPort int, user string, err error) { sshServer, err := liveshare.NewSSHServer(client) if err != nil { - return false, 0, "", "", fmt.Errorf("error creating live share: %v", err) + return 0, "", fmt.Errorf("error creating live share: %v", err) } sshServerStartResult, err := sshServer.StartRemoteServer(ctx) if err != nil { - return false, 0, "", "", fmt.Errorf("error creating live share: %v", err) + return 0, "", fmt.Errorf("error starting live share: %v", err) } if !sshServerStartResult.Result { - return false, 0, "", sshServerStartResult.Message, nil + return 0, "", errors.New(sshServerStartResult.Message) } portInt, err := strconv.Atoi(sshServerStartResult.ServerPort) if err != nil { - return false, 0, "", "", fmt.Errorf("error parsing port: %v", err) + return 0, "", fmt.Errorf("error parsing port: %v", err) } - return sshServerStartResult.Result, portInt, sshServerStartResult.User, sshServerStartResult.Message, err + return portInt, sshServerStartResult.User, nil } func GetOrChooseCodespace(ctx context.Context, apiClient *api.API, user *api.User, codespaceName string) (codespace *api.Codespace, token string, err error) { diff --git a/internal/codespaces/ssh.go b/internal/codespaces/ssh.go index cf6118704..cd944e114 100644 --- a/internal/codespaces/ssh.go +++ b/internal/codespaces/ssh.go @@ -14,7 +14,10 @@ import ( "github.com/github/go-liveshare" ) -func MakeSSHTunnel(ctx context.Context, lsclient *liveshare.Client, serverPort int, remoteSSHPort int) (int, <-chan error, error) { +// MakeSSHTunnel This function initializes the liveshare tunnel +// Creates the tunnel from a local port to a remote port. +// Returns the local port that was used, the channel and the error if any. +func MakeSSHTunnel(ctx context.Context, lsclient *liveshare.Client, localSSHPort int, remoteSSHPort int) (int, <-chan error, error) { tunnelClosed := make(chan error) server, err := liveshare.NewServer(lsclient) @@ -24,11 +27,10 @@ func MakeSSHTunnel(ctx context.Context, lsclient *liveshare.Client, serverPort i rand.Seed(time.Now().Unix()) port := rand.Intn(9999-2000) + 2000 // improve this obviously - if serverPort != 0 { - port = serverPort + if localSSHPort != 0 { + port = localSSHPort } - // TODO(josebalius): This port won't always be 2222 if err := server.StartSharing(ctx, "sshd", remoteSSHPort); err != nil { return 0, nil, fmt.Errorf("sharing sshd port: %v", err) } From 13917a289df253bd819511ee75d7851551b5f86e Mon Sep 17 00:00:00 2001 From: Edmundo Gonzalez <51725820+edgonmsft@users.noreply.github.com> Date: Mon, 30 Aug 2021 04:52:27 +0000 Subject: [PATCH 115/290] Moved function to ssh.go file. --- internal/codespaces/codespaces.go | 27 --------------------------- internal/codespaces/ssh.go | 27 +++++++++++++++++++++++++++ 2 files changed, 27 insertions(+), 27 deletions(-) diff --git a/internal/codespaces/codespaces.go b/internal/codespaces/codespaces.go index 8b18f7d7d..48369cfa0 100644 --- a/internal/codespaces/codespaces.go +++ b/internal/codespaces/codespaces.go @@ -4,7 +4,6 @@ import ( "context" "errors" "fmt" - "strconv" "time" "github.com/AlecAivazis/survey/v2" @@ -118,32 +117,6 @@ func ConnectToLiveshare(ctx context.Context, log logger, apiClient *api.API, tok return lsclient, nil } -// StartSSHServer starts and installs the SSH server in the codespace -// returns the remote port where it is running, the user to use to login -// or an error if something failed. -func StartSSHServer(ctx context.Context, client *liveshare.Client) (serverPort int, user string, err error) { - sshServer, err := liveshare.NewSSHServer(client) - if err != nil { - return 0, "", fmt.Errorf("error creating live share: %v", err) - } - - sshServerStartResult, err := sshServer.StartRemoteServer(ctx) - if err != nil { - return 0, "", fmt.Errorf("error starting live share: %v", err) - } - - if !sshServerStartResult.Result { - return 0, "", errors.New(sshServerStartResult.Message) - } - - portInt, err := strconv.Atoi(sshServerStartResult.ServerPort) - if err != nil { - return 0, "", fmt.Errorf("error parsing port: %v", err) - } - - return portInt, sshServerStartResult.User, nil -} - func GetOrChooseCodespace(ctx context.Context, apiClient *api.API, user *api.User, codespaceName string) (codespace *api.Codespace, token string, err error) { if codespaceName == "" { codespace, err = ChooseCodespace(ctx, apiClient, user) diff --git a/internal/codespaces/ssh.go b/internal/codespaces/ssh.go index cd944e114..a1abcf381 100644 --- a/internal/codespaces/ssh.go +++ b/internal/codespaces/ssh.go @@ -2,6 +2,7 @@ package codespaces import ( "context" + "errors" "fmt" "io" "math/rand" @@ -47,6 +48,32 @@ func MakeSSHTunnel(ctx context.Context, lsclient *liveshare.Client, localSSHPort return port, tunnelClosed, nil } +// StartSSHServer starts and installs the SSH server in the codespace +// returns the remote port where it is running, the user to use to login +// or an error if something failed. +func StartSSHServer(ctx context.Context, client *liveshare.Client) (serverPort int, user string, err error) { + sshServer, err := liveshare.NewSSHServer(client) + if err != nil { + return 0, "", fmt.Errorf("error creating live share: %v", err) + } + + sshServerStartResult, err := sshServer.StartRemoteServer(ctx) + if err != nil { + return 0, "", fmt.Errorf("error starting live share: %v", err) + } + + if !sshServerStartResult.Result { + return 0, "", errors.New(sshServerStartResult.Message) + } + + portInt, err := strconv.Atoi(sshServerStartResult.ServerPort) + if err != nil { + return 0, "", fmt.Errorf("error parsing port: %v", err) + } + + return portInt, sshServerStartResult.User, nil +} + func makeSSHArgs(port int, dst, cmd string) ([]string, []string) { connArgs := []string{"-p", strconv.Itoa(port), "-o", "NoHostAuthenticationForLocalhost=yes"} cmdArgs := append([]string{dst, "-X", "-Y", "-C"}, connArgs...) // X11, X11Trust, Compression From 0c066cbd099622d16516b346445050d16f9d68df Mon Sep 17 00:00:00 2001 From: Edmundo Gonzalez <51725820+edgonmsft@users.noreply.github.com> Date: Mon, 30 Aug 2021 05:05:43 +0000 Subject: [PATCH 116/290] Fix compilation error. --- internal/codespaces/states.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/internal/codespaces/states.go b/internal/codespaces/states.go index 5c3dcef45..7b98b12dd 100644 --- a/internal/codespaces/states.go +++ b/internal/codespaces/states.go @@ -45,7 +45,7 @@ func PollPostCreateStates(ctx context.Context, log logger, apiClient *api.API, u return fmt.Errorf("connect to liveshare: %v", err) } - tunnelPort, connClosed, err := MakeSSHTunnel(ctx, lsclient, 0) + tunnelPort, connClosed, err := MakeSSHTunnel(ctx, lsclient, 0, 2222) if err != nil { return fmt.Errorf("make ssh tunnel: %v", err) } From 954d46dce5846c94c2f3cfd07206acacc5208d19 Mon Sep 17 00:00:00 2001 From: Edmundo Gonzalez <51725820+edgonmsft@users.noreply.github.com> Date: Mon, 30 Aug 2021 17:30:28 +0000 Subject: [PATCH 117/290] Changes from comments on pr. --- internal/codespaces/ssh.go | 8 ++------ internal/codespaces/states.go | 7 ++++++- 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/internal/codespaces/ssh.go b/internal/codespaces/ssh.go index a1abcf381..1c3807742 100644 --- a/internal/codespaces/ssh.go +++ b/internal/codespaces/ssh.go @@ -15,9 +15,6 @@ import ( "github.com/github/go-liveshare" ) -// MakeSSHTunnel This function initializes the liveshare tunnel -// Creates the tunnel from a local port to a remote port. -// Returns the local port that was used, the channel and the error if any. func MakeSSHTunnel(ctx context.Context, lsclient *liveshare.Client, localSSHPort int, remoteSSHPort int) (int, <-chan error, error) { tunnelClosed := make(chan error) @@ -48,9 +45,8 @@ func MakeSSHTunnel(ctx context.Context, lsclient *liveshare.Client, localSSHPort return port, tunnelClosed, nil } -// StartSSHServer starts and installs the SSH server in the codespace -// returns the remote port where it is running, the user to use to login -// or an error if something failed. +// StartSSHServer installs (if necessary) and starts the SSH in the codespace. +// It returns the remote port where it is running, the user to log in with, or an error if something failed. func StartSSHServer(ctx context.Context, client *liveshare.Client) (serverPort int, user string, err error) { sshServer, err := liveshare.NewSSHServer(client) if err != nil { diff --git a/internal/codespaces/states.go b/internal/codespaces/states.go index 7b98b12dd..d2aa389ef 100644 --- a/internal/codespaces/states.go +++ b/internal/codespaces/states.go @@ -45,7 +45,12 @@ func PollPostCreateStates(ctx context.Context, log logger, apiClient *api.API, u return fmt.Errorf("connect to liveshare: %v", err) } - tunnelPort, connClosed, err := MakeSSHTunnel(ctx, lsclient, 0, 2222) + remoteSSHServerPort, _, err := StartSSHServer(ctx, lsclient) + if err != nil { + return fmt.Errorf("error getting ssh server details: %v", err) + } + + tunnelPort, connClosed, err := MakeSSHTunnel(ctx, lsclient, 0, remoteSSHServerPort) if err != nil { return fmt.Errorf("make ssh tunnel: %v", err) } From 903b7be7dea2d4a0f5a2c9cc4ef3053d90029ca2 Mon Sep 17 00:00:00 2001 From: Edmundo Gonzalez <51725820+edgonmsft@users.noreply.github.com> Date: Mon, 30 Aug 2021 21:01:13 +0000 Subject: [PATCH 118/290] Comments from pr. --- cmd/ghcs/logs.go | 2 +- cmd/ghcs/ssh.go | 2 +- internal/codespaces/ssh.go | 4 +++- internal/codespaces/states.go | 14 +++++--------- 4 files changed, 10 insertions(+), 12 deletions(-) diff --git a/cmd/ghcs/logs.go b/cmd/ghcs/logs.go index ea5531bac..ec9e63a56 100644 --- a/cmd/ghcs/logs.go +++ b/cmd/ghcs/logs.go @@ -57,7 +57,7 @@ func logs(tail bool, codespaceName string) error { return fmt.Errorf("connecting to liveshare: %v", err) } - remoteSSHServerPort, sshUser, err := codespaces.StartSSHServer(ctx, lsclient) + remoteSSHServerPort, sshUser, err := codespaces.StartSSHServer(ctx, lsclient, log) if err != nil { return fmt.Errorf("error getting ssh server details: %v", err) } diff --git a/cmd/ghcs/ssh.go b/cmd/ghcs/ssh.go index fd98397fe..ad8743360 100644 --- a/cmd/ghcs/ssh.go +++ b/cmd/ghcs/ssh.go @@ -59,7 +59,7 @@ func ssh(sshProfile, codespaceName string, sshServerPort int) error { return fmt.Errorf("error connecting to liveshare: %v", err) } - remoteSSHServerPort, sshUser, err := codespaces.StartSSHServer(ctx, lsclient) + remoteSSHServerPort, sshUser, err := codespaces.StartSSHServer(ctx, lsclient, log) if err != nil { return fmt.Errorf("error getting ssh server details: %v", err) } diff --git a/internal/codespaces/ssh.go b/internal/codespaces/ssh.go index 1c3807742..16ffed07b 100644 --- a/internal/codespaces/ssh.go +++ b/internal/codespaces/ssh.go @@ -47,7 +47,9 @@ func MakeSSHTunnel(ctx context.Context, lsclient *liveshare.Client, localSSHPort // StartSSHServer installs (if necessary) and starts the SSH in the codespace. // It returns the remote port where it is running, the user to log in with, or an error if something failed. -func StartSSHServer(ctx context.Context, client *liveshare.Client) (serverPort int, user string, err error) { +func StartSSHServer(ctx context.Context, client *liveshare.Client, log logger) (serverPort int, user string, err error) { + log.Println("Fetching SSH details...") + sshServer, err := liveshare.NewSSHServer(client) if err != nil { return 0, "", fmt.Errorf("error creating live share: %v", err) diff --git a/internal/codespaces/states.go b/internal/codespaces/states.go index d2aa389ef..f3e9cbefe 100644 --- a/internal/codespaces/states.go +++ b/internal/codespaces/states.go @@ -45,7 +45,7 @@ func PollPostCreateStates(ctx context.Context, log logger, apiClient *api.API, u return fmt.Errorf("connect to liveshare: %v", err) } - remoteSSHServerPort, _, err := StartSSHServer(ctx, lsclient) + remoteSSHServerPort, sshUser, err := StartSSHServer(ctx, lsclient, log) if err != nil { return fmt.Errorf("error getting ssh server details: %v", err) } @@ -65,7 +65,7 @@ func PollPostCreateStates(ctx context.Context, log logger, apiClient *api.API, u case err := <-connClosed: return fmt.Errorf("connection closed: %v", err) case <-t.C: - states, err := getPostCreateOutput(ctx, tunnelPort, codespace) + states, err := getPostCreateOutput(ctx, tunnelPort, codespace, sshUser) if err != nil { return fmt.Errorf("get post create output: %v", err) } @@ -75,9 +75,9 @@ func PollPostCreateStates(ctx context.Context, log logger, apiClient *api.API, u } } -func getPostCreateOutput(ctx context.Context, tunnelPort int, codespace *api.Codespace) ([]PostCreateState, error) { +func getPostCreateOutput(ctx context.Context, tunnelPort int, codespace *api.Codespace, user string) ([]PostCreateState, error) { stdout, err := RunCommand( - ctx, tunnelPort, sshDestination(codespace), + ctx, tunnelPort, sshDestination(codespace, user), "cat /workspaces/.codespaces/shared/postCreateOutput.json", ) if err != nil { @@ -101,10 +101,6 @@ func getPostCreateOutput(ctx context.Context, tunnelPort int, codespace *api.Cod } // TODO(josebalius): this won't be needed soon -func sshDestination(codespace *api.Codespace) string { - user := "codespace" - if codespace.RepositoryNWO == "github/github" { - user = "root" - } +func sshDestination(codespace *api.Codespace, user string) string { return user + "@localhost" } From 4af240d87da018b38c4765f31ece31b7aa2c8478 Mon Sep 17 00:00:00 2001 From: Alan Donovan Date: Mon, 30 Aug 2021 17:36:28 -0400 Subject: [PATCH 119/290] handle errors in port forwarding --- port_forwarder.go | 113 +++++++++++++++++++++++++++-------------- port_forwarder_test.go | 5 +- 2 files changed, 77 insertions(+), 41 deletions(-) diff --git a/port_forwarder.go b/port_forwarder.go index e6eedf16c..cc7b6ea1d 100644 --- a/port_forwarder.go +++ b/port_forwarder.go @@ -5,77 +5,116 @@ import ( "fmt" "io" "net" - "strconv" ) -// A PortForwader can forward ports from a remote liveshare host to localhost +// A PortForwarder forwards TCP traffic between a port on a remote +// LiveShare host and a local port. type PortForwarder struct { client *Client server *Server port int - errCh chan error } -// NewPortForwarder creates a new PortForwader with a given client, server and port +// NewPortForwarder creates a new PortForwarder that connects a given client, server and port. func NewPortForwarder(client *Client, server *Server, port int) *PortForwarder { return &PortForwarder{ client: client, server: server, port: port, - errCh: make(chan error), } } -// Start is a method to start forwarding the server to a localhost port -func (l *PortForwarder) Start(ctx context.Context) error { - ln, err := net.Listen("tcp", ":"+strconv.Itoa(l.port)) +// Forward enables port forwarding. It accepts and handles TCP +// connections until it encounters the first error, which may include +// context cancellation. Its result is non-nil. +func (l *PortForwarder) Forward(ctx context.Context) (err error) { + listen, err := net.Listen("tcp", fmt.Sprintf(":%d", l.port)) if err != nil { - return fmt.Errorf("error listening on tcp port: %v", err) + return fmt.Errorf("error listening on TCP port: %v", err) } + defer safeClose(listen, &err) + errc := make(chan error, 1) + sendError := func(err error) { + // Use non-blocking send, to avoid goroutines getting + // stuck in case of concurrent or sequential errors. + select { + case errc <- err: + default: + } + } go func() { for { - conn, err := ln.Accept() + conn, err := listen.Accept() if err != nil { - l.errCh <- fmt.Errorf("error accepting incoming connection: %v", err) + sendError(err) + return } - go l.handleConnection(ctx, conn) + go func() { + if err := l.handleConnection(ctx, conn); err != nil { + sendError(err) + } + }() } }() + return awaitError(ctx, errc) +} + +// ForwardWithConn handles port forwarding for a single connection. +func (l *PortForwarder) ForwardWithConn(ctx context.Context, conn io.ReadWriteCloser) error { + // Create buffered channel so that send doesn't get stuck after context cancellation. + errc := make(chan error, 1) + go func() { + if err := l.handleConnection(ctx, conn); err != nil { + errc <- err + } + }() + return awaitError(ctx, errc) +} + +func awaitError(ctx context.Context, errc <-chan error) error { select { - case err := <-l.errCh: + case err := <-errc: return err case <-ctx.Done(): - return ln.Close() + return ctx.Err() // canceled } +} +// handleConnection handles forwarding for a single accepted connection, then closes it. +func (l *PortForwarder) handleConnection(ctx context.Context, conn io.ReadWriteCloser) (err error) { + defer safeClose(conn, &err) + + channel, err := l.client.openStreamingChannel(ctx, l.server.streamName, l.server.streamCondition) + if err != nil { + return fmt.Errorf("error opening streaming channel for new connection: %v", err) + } + defer safeClose(channel, &err) + + errs := make(chan error, 2) + copyConn := func(w io.Writer, r io.Reader) { + _, err := io.Copy(w, r) + errs <- err + } + go copyConn(conn, channel) + go copyConn(channel, conn) + + // await result + for i := 0; i < 2; i++ { + if err := <-errs; err != nil && err != io.EOF { + return fmt.Errorf("tunnel connection: %v", err) + } + } return nil } -func (l *PortForwarder) StartWithConn(ctx context.Context, conn io.ReadWriteCloser) error { - go l.handleConnection(ctx, conn) - return <-l.errCh -} - -func (l *PortForwarder) handleConnection(ctx context.Context, conn io.ReadWriteCloser) { - channel, err := l.client.openStreamingChannel(ctx, l.server.streamName, l.server.streamCondition) - if err != nil { - l.errCh <- fmt.Errorf("error opening streaming channel for new connection: %v", err) - return +// safeClose reports the error (to *err) from closing the stream only +// if no other error was previously reported. +func safeClose(closer io.Closer, err *error) { + closeErr := closer.Close() + if *err == nil { + *err = closeErr } - - copyConn := func(writer io.Writer, reader io.Reader) { - if _, err := io.Copy(writer, reader); err != nil { - channel.Close() - conn.Close() - if err != io.EOF { - l.errCh <- fmt.Errorf("tunnel connection: %v", err) - } - } - } - - go copyConn(conn, channel) - go copyConn(channel, conn) } diff --git a/port_forwarder_test.go b/port_forwarder_test.go index 33a33b39b..3ae846937 100644 --- a/port_forwarder_test.go +++ b/port_forwarder_test.go @@ -63,10 +63,7 @@ func TestPortForwarderStart(t *testing.T) { if err := server.StartSharing(ctx, "http", 8000); err != nil { done <- fmt.Errorf("start sharing: %v", err) } - if err := pf.Start(ctx); err != nil { - done <- err - } - done <- nil + done <- pf.Forward(ctx) }() go func() { From 40317e91f8ae0a5460234ae398155af1650c9086 Mon Sep 17 00:00:00 2001 From: Alan Donovan Date: Sat, 28 Aug 2021 20:02:08 -0400 Subject: [PATCH 120/290] cleanup to ssh api --- cmd/ghcs/logs.go | 42 ++++------- cmd/ghcs/ssh.go | 12 ++-- internal/codespaces/ssh.go | 132 ++++++++++++++++------------------ internal/codespaces/states.go | 19 ++--- 4 files changed, 90 insertions(+), 115 deletions(-) diff --git a/cmd/ghcs/logs.go b/cmd/ghcs/logs.go index 662e20b79..b8bc462fc 100644 --- a/cmd/ghcs/logs.go +++ b/cmd/ghcs/logs.go @@ -1,7 +1,6 @@ package main import ( - "bufio" "context" "fmt" "os" @@ -24,7 +23,7 @@ func newLogsCmd() *cobra.Command { if len(args) > 0 { codespaceName = args[0] } - return logs(tail, codespaceName) + return logs(context.Background(), tail, codespaceName) }, } @@ -37,9 +36,12 @@ func init() { rootCmd.AddCommand(newLogsCmd()) } -func logs(tail bool, codespaceName string) error { +func logs(ctx context.Context, tail bool, codespaceName string) error { + // Ensure all child tasks (port forwarding, remote exec) terminate before return. + ctx, cancel := context.WithCancel(ctx) + defer cancel() + apiClient := api.New(os.Getenv("GITHUB_TOKEN")) - ctx := context.Background() log := output.NewLogger(os.Stdout, os.Stderr, false) user, err := apiClient.GetUser(ctx) @@ -57,7 +59,7 @@ func logs(tail bool, codespaceName string) error { return fmt.Errorf("connecting to liveshare: %v", err) } - tunnelPort, connClosed, err := codespaces.MakeSSHTunnel(ctx, lsclient, 0) + tunnelPort, connClosed, err := codespaces.StartPortForwarding(ctx, lsclient, "sshd", 0) if err != nil { return fmt.Errorf("make ssh tunnel: %v", err) } @@ -68,31 +70,13 @@ func logs(tail bool, codespaceName string) error { } dst := fmt.Sprintf("%s@localhost", getSSHUser(codespace)) - stdout, err := codespaces.RunCommand( - ctx, tunnelPort, dst, fmt.Sprintf("%v /workspaces/.codespaces/.persistedshare/creation.log", cmdType), + cmd := codespaces.NewRemoteCommand( + ctx, tunnelPort, dst, fmt.Sprintf("%s /workspaces/.codespaces/.persistedshare/creation.log", cmdType), ) - if err != nil { - return fmt.Errorf("run command: %v", err) - } - done := make(chan error) - go func() { - scanner := bufio.NewScanner(stdout) - for scanner.Scan() { - fmt.Println(scanner.Text()) - } - - if err := scanner.Err(); err != nil { - done <- fmt.Errorf("error scanning: %v", err) - return - } - - if err := stdout.Close(); err != nil { - done <- fmt.Errorf("close stdout: %v", err) - return - } - done <- nil - }() + // Channel is buffered to avoid a goroutine leak when connClosed occurs before done. + done := make(chan error, 1) + go func() { done <- cmd.Run() }() select { case err := <-connClosed: @@ -101,7 +85,7 @@ func logs(tail bool, codespaceName string) error { } case err := <-done: if err != nil { - return err + return fmt.Errorf("error retrieving logs: %v", err) } } diff --git a/cmd/ghcs/ssh.go b/cmd/ghcs/ssh.go index 84964829a..579093c3d 100644 --- a/cmd/ghcs/ssh.go +++ b/cmd/ghcs/ssh.go @@ -77,7 +77,7 @@ func ssh(sshProfile, codespaceName string, sshServerPort int) error { } log.Print("\n") - tunnelPort, tunnelClosed, err := codespaces.MakeSSHTunnel(ctx, lsclient, sshServerPort) + tunnelPort, tunnelClosed, err := codespaces.StartPortForwarding(ctx, lsclient, "sshd", sshServerPort) if err != nil { return fmt.Errorf("make ssh tunnel: %v", err) } @@ -88,7 +88,11 @@ func ssh(sshProfile, codespaceName string, sshServerPort int) error { } usingCustomPort := tunnelPort == sshServerPort - connClosed := codespaces.ConnectToTunnel(ctx, log, tunnelPort, connectDestination, usingCustomPort) + + shellClosed := make(chan error) + go func() { + shellClosed <- codespaces.Shell(ctx, log, tunnelPort, connectDestination, usingCustomPort) + }() log.Println("Ready...") select { @@ -96,9 +100,9 @@ func ssh(sshProfile, codespaceName string, sshServerPort int) error { if err != nil { return fmt.Errorf("tunnel closed: %v", err) } - case err := <-connClosed: + case err := <-shellClosed: if err != nil { - return fmt.Errorf("connection closed: %v", err) + return fmt.Errorf("shell closed: %v", err) } } diff --git a/internal/codespaces/ssh.go b/internal/codespaces/ssh.go index 672ba3b7b..67ffdbdd2 100644 --- a/internal/codespaces/ssh.go +++ b/internal/codespaces/ssh.go @@ -3,7 +3,6 @@ package codespaces import ( "context" "fmt" - "io" "math/rand" "os" "os/exec" @@ -14,27 +13,53 @@ import ( "github.com/github/go-liveshare" ) -func MakeSSHTunnel(ctx context.Context, lsclient *liveshare.Client, serverPort int) (int, <-chan error, error) { - tunnelClosed := make(chan error) - +// StartPortForwarding starts LiveShare port forwarding of traffic of +// the specified protocol (e.g. "sshd") between the LiveShare client +// and the specified local port, or, if zero, a port chosen at random; +// the effective port number is returned. Forwarding continues in the +// background until an error is encountered (including cancellation of +// the context). Therefore clients must cancel the context +// +// REVIEWERS: where is the set of legal values of protocol defined? +// It appears to be: "whatever is supported by the LiveShare service's +// serverSharing.startSharing method". Where is that defined? +// +// TODO(adonovan): simplify API concurrency from API. Either: +// 1) return a stop function so that clients don't forget to stop forwarding. +// 2) avoid creating a goroutine and returning a channel. Use approach of +// http.ListenAndServe, which runs until it encounters an error +// (incl. cancellation). But this means we can't return the port. +// Can we make the client responsible for supplying it? +// 3) return a PortForwarding object that encapsulates the port, +// and has NewRemoteCommand as a method. It will need a Stop method, +// and an Error method for querying whether the session has failed +// asynchronously. +func StartPortForwarding(ctx context.Context, lsclient *liveshare.Client, protocol string, localPort int) (int, <-chan error, error) { server, err := liveshare.NewServer(lsclient) if err != nil { return 0, nil, fmt.Errorf("new liveshare server: %v", err) } - rand.Seed(time.Now().Unix()) - port := rand.Intn(9999-2000) + 2000 // improve this obviously - if serverPort != 0 { - port = serverPort + if localPort == 0 { + // improve this obviously + // REVIEWERS: any reason not to use the global PRNG? + rng := rand.New(rand.NewSource(time.Now().Unix())) + localPort = rng.Intn(9999-2000) + 2000 + // TODO(adonovan): loop if port is taken? } // TODO(josebalius): This port won't always be 2222 - if err := server.StartSharing(ctx, "sshd", 2222); err != nil { + if err := server.StartSharing(ctx, protocol, 2222); err != nil { return 0, nil, fmt.Errorf("sharing sshd port: %v", err) } + tunnelClosed := make(chan error) go func() { - portForwarder := liveshare.NewPortForwarder(lsclient, server, port) + // TODO(adonovan): simplify liveshare API to combine NewPortForwarder and Start + // methods into a single ForwardPort call, like http.ListenAndServe. + // (Start is a misnomer: it runs the complete session.) + // Also document that it never returns a nil error. + portForwarder := liveshare.NewPortForwarder(lsclient, server, localPort) if err := portForwarder.Start(ctx); err != nil { tunnelClosed <- fmt.Errorf("forwarding port: %v", err) return @@ -42,75 +67,42 @@ func MakeSSHTunnel(ctx context.Context, lsclient *liveshare.Client, serverPort i tunnelClosed <- nil }() - return port, tunnelClosed, nil + return localPort, tunnelClosed, nil } -func makeSSHArgs(port int, dst, cmd string) ([]string, []string) { - connArgs := []string{"-p", strconv.Itoa(port), "-o", "NoHostAuthenticationForLocalhost=yes"} - cmdArgs := append([]string{dst, "-X", "-Y", "-C"}, connArgs...) // X11, X11Trust, Compression - - if cmd != "" { - cmdArgs = append(cmdArgs, cmd) - } - - return cmdArgs, connArgs -} - -func ConnectToTunnel(ctx context.Context, log logger, port int, destination string, usingCustomPort bool) <-chan error { - connClosed := make(chan error) - args, connArgs := makeSSHArgs(port, destination, "") +// Shell runs an interactive secure shell over an existing +// port-forwarding session. It runs until the shell is terminated +// (including by cancellation of the context). +func Shell(ctx context.Context, log logger, port int, destination string, usingCustomPort bool) error { + cmd, connArgs := newSSHCommand(ctx, port, destination, "") if usingCustomPort { log.Println("Connection Details: ssh " + destination + " " + strings.Join(connArgs, " ")) } - cmd := exec.CommandContext(ctx, "ssh", args...) + return cmd.Run() +} + +// NewRemoteCommand returns a partially populated exec.Cmd that will +// securely run a shell command on the remote machine. +func NewRemoteCommand(ctx context.Context, tunnelPort int, destination, command string) *exec.Cmd { + cmd, _ := newSSHCommand(ctx, tunnelPort, destination, command) + return cmd +} + +func newSSHCommand(ctx context.Context, port int, dst, command string) (*exec.Cmd, []string) { + connArgs := []string{"-p", strconv.Itoa(port), "-o", "NoHostAuthenticationForLocalhost=yes"} + cmdArgs := append([]string{dst, "-X", "-Y", "-C"}, connArgs...) // X11, X11Trust, Compression + + // An empty command enables port forwarding but not execution. + if command != "" { + cmdArgs = append(cmdArgs, command) + } + + cmd := exec.CommandContext(ctx, "ssh", cmdArgs...) cmd.Stdout = os.Stdout cmd.Stdin = os.Stdin cmd.Stderr = os.Stderr - go func() { - connClosed <- cmd.Run() - }() - - return connClosed -} - -type command struct { - Cmd *exec.Cmd - StdoutPipe io.ReadCloser -} - -func newCommand(cmd *exec.Cmd) (*command, error) { - stdoutPipe, err := cmd.StdoutPipe() - if err != nil { - return nil, fmt.Errorf("create stdout pipe: %v", err) - } - - if err := cmd.Start(); err != nil { - return nil, fmt.Errorf("cmd start: %v", err) - } - - return &command{ - Cmd: cmd, - StdoutPipe: stdoutPipe, - }, nil -} - -func (c *command) Read(p []byte) (int, error) { - return c.StdoutPipe.Read(p) -} - -func (c *command) Close() error { - if err := c.StdoutPipe.Close(); err != nil { - return fmt.Errorf("close stdout: %v", err) - } - - return c.Cmd.Wait() -} - -func RunCommand(ctx context.Context, tunnelPort int, destination, cmdString string) (io.ReadCloser, error) { - args, _ := makeSSHArgs(tunnelPort, destination, cmdString) - cmd := exec.CommandContext(ctx, "ssh", args...) - return newCommand(cmd) + return cmd, connArgs } diff --git a/internal/codespaces/states.go b/internal/codespaces/states.go index 5c3dcef45..fe34f5486 100644 --- a/internal/codespaces/states.go +++ b/internal/codespaces/states.go @@ -1,10 +1,10 @@ package codespaces import ( + "bytes" "context" "encoding/json" "fmt" - "io/ioutil" "strings" "time" @@ -45,7 +45,7 @@ func PollPostCreateStates(ctx context.Context, log logger, apiClient *api.API, u return fmt.Errorf("connect to liveshare: %v", err) } - tunnelPort, connClosed, err := MakeSSHTunnel(ctx, lsclient, 0) + tunnelPort, connClosed, err := StartPortForwarding(ctx, lsclient, "sshd", 0) if err != nil { return fmt.Errorf("make ssh tunnel: %v", err) } @@ -71,24 +71,19 @@ func PollPostCreateStates(ctx context.Context, log logger, apiClient *api.API, u } func getPostCreateOutput(ctx context.Context, tunnelPort int, codespace *api.Codespace) ([]PostCreateState, error) { - stdout, err := RunCommand( + cmd := NewRemoteCommand( ctx, tunnelPort, sshDestination(codespace), "cat /workspaces/.codespaces/shared/postCreateOutput.json", ) - if err != nil { + stdout := new(bytes.Buffer) + cmd.Stdout = stdout + if err := cmd.Run(); err != nil { return nil, fmt.Errorf("run command: %v", err) } - defer stdout.Close() - - b, err := ioutil.ReadAll(stdout) - if err != nil { - return nil, fmt.Errorf("read output: %v", err) - } - var output struct { Steps []PostCreateState `json:"steps"` } - if err := json.Unmarshal(b, &output); err != nil { + if err := json.Unmarshal(stdout.Bytes(), &output); err != nil { return nil, fmt.Errorf("unmarshal output: %v", err) } From ea97e2e73dfe8b89ffdaafb1f9c83a8cd4f2919c Mon Sep 17 00:00:00 2001 From: Alan Donovan Date: Mon, 30 Aug 2021 18:15:37 -0400 Subject: [PATCH 121/290] remove sleep 1s --- cmd/ghcs/ssh.go | 3 --- 1 file changed, 3 deletions(-) diff --git a/cmd/ghcs/ssh.go b/cmd/ghcs/ssh.go index 84964829a..23f49c1d1 100644 --- a/cmd/ghcs/ssh.go +++ b/cmd/ghcs/ssh.go @@ -6,7 +6,6 @@ import ( "fmt" "os" "strings" - "time" "github.com/github/ghcs/api" "github.com/github/ghcs/cmd/ghcs/output" @@ -155,8 +154,6 @@ func setupSSH(ctx context.Context, logger *output.Logger, terminal *liveshare.Te return fmt.Errorf("error closing stream: %v", err) } - time.Sleep(1 * time.Second) - return nil } From 15dab395a519bee18c9b2a2f7bd7cebafc578b1b Mon Sep 17 00:00:00 2001 From: Alan Donovan Date: Mon, 30 Aug 2021 18:23:55 -0400 Subject: [PATCH 122/290] in Start, ignore HTTP 503 with reason 7 EnvironmentNotShutdown --- api/api.go | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/api/api.go b/api/api.go index a8f6ea724..bb3eba489 100644 --- a/api/api.go +++ b/api/api.go @@ -274,11 +274,17 @@ func (a *API) StartCodespace(ctx context.Context, token string, codespace *Codes } if resp.StatusCode != http.StatusOK { - // Error response is numeric code and/or string message, not JSON. + // Error response is typically a numeric code (not an error message, nor JSON). if len(b) > 100 { b = append(b[:97], "..."...) } - return fmt.Errorf("failed to start codespace: %s", b) + + if resp.StatusCode == http.StatusServiceUnavailable && strings.TrimSpace(string(b)) == "7" { + // HTTP 503 with error code 7 (EnvironmentNotShutdown) is benign. + // Ignore it. + } else { + return fmt.Errorf("failed to start codespace: %s", b) + } } return nil From b63972b62f2564dad26a922d346108c4f7683953 Mon Sep 17 00:00:00 2001 From: Alan Donovan Date: Tue, 31 Aug 2021 11:07:26 -0400 Subject: [PATCH 123/290] spell Live Share product name correctly in UI --- client.go | 2 +- client_test.go | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/client.go b/client.go index a8a1e3864..fe890f0bf 100644 --- a/client.go +++ b/client.go @@ -69,7 +69,7 @@ func (c *Client) Join(ctx context.Context) (err error) { _, err = c.joinWorkspace(ctx) if err != nil { - return fmt.Errorf("error joining liveshare workspace: %v", err) + return fmt.Errorf("error joining Live Share workspace: %v", err) } return nil diff --git a/client_test.go b/client_test.go index 110c7e3b9..f1591ed51 100644 --- a/client_test.go +++ b/client_test.go @@ -75,7 +75,7 @@ func TestClientJoin(t *testing.T) { livesharetest.WithRelaySAS(connection.RelaySAS), ) if err != nil { - t.Errorf("error creating liveshare server: %v", err) + t.Errorf("error creating Live Share server: %v", err) } defer server.Close() connection.RelayEndpoint = "sb" + strings.TrimPrefix(server.URL(), "https") From bbcf2dd321527e08df5c483b1646b8b7fab53d78 Mon Sep 17 00:00:00 2001 From: Alan Donovan Date: Tue, 31 Aug 2021 11:15:26 -0400 Subject: [PATCH 124/290] spell product names (Codespaces, Live Share) correctly --- api/api.go | 2 +- cmd/ghcs/code.go | 2 +- cmd/ghcs/create.go | 12 ++++++------ cmd/ghcs/delete.go | 18 +++++++++--------- cmd/ghcs/list.go | 2 +- cmd/ghcs/logs.go | 4 ++-- cmd/ghcs/ports.go | 16 ++++++++-------- cmd/ghcs/ssh.go | 6 +++--- internal/codespaces/codespaces.go | 26 +++++++++++++------------- internal/codespaces/ssh.go | 2 +- internal/codespaces/states.go | 4 ++-- 11 files changed, 47 insertions(+), 47 deletions(-) diff --git a/api/api.go b/api/api.go index a8f6ea724..b9d4213a0 100644 --- a/api/api.go +++ b/api/api.go @@ -278,7 +278,7 @@ func (a *API) StartCodespace(ctx context.Context, token string, codespace *Codes if len(b) > 100 { b = append(b[:97], "..."...) } - return fmt.Errorf("failed to start codespace: %s", b) + return fmt.Errorf("failed to start Codespace: %s", b) } return nil diff --git a/cmd/ghcs/code.go b/cmd/ghcs/code.go index 19d76fadb..b3b43f050 100644 --- a/cmd/ghcs/code.go +++ b/cmd/ghcs/code.go @@ -52,7 +52,7 @@ func code(codespaceName string, useInsiders bool) error { if err == codespaces.ErrNoCodespaces { return err } - return fmt.Errorf("error choosing codespace: %v", err) + return fmt.Errorf("error choosing Codespace: %v", err) } codespaceName = codespace.Name } diff --git a/cmd/ghcs/create.go b/cmd/ghcs/create.go index 7d5d59923..71e31f298 100644 --- a/cmd/ghcs/create.go +++ b/cmd/ghcs/create.go @@ -69,12 +69,12 @@ func create(opts *createOptions) error { locationResult := <-locationCh if locationResult.Err != nil { - return fmt.Errorf("error getting codespace region location: %v", locationResult.Err) + return fmt.Errorf("error getting Codespace region location: %v", locationResult.Err) } userResult := <-userCh if userResult.Err != nil { - return fmt.Errorf("error getting codespace user: %v", userResult.Err) + return fmt.Errorf("error getting Codespace user: %v", userResult.Err) } machine, err := getMachineName(ctx, opts.machine, userResult.User, repository, locationResult.Location, apiClient) @@ -85,11 +85,11 @@ func create(opts *createOptions) error { return errors.New("There are no available machine types for this repository") } - log.Println("Creating your codespace...") + log.Println("Creating your Codespace...") codespace, err := apiClient.CreateCodespace(ctx, userResult.User, repository, machine, branch, locationResult.Location) if err != nil { - return fmt.Errorf("error creating codespace: %v", err) + return fmt.Errorf("error creating Codespace: %v", err) } if opts.showStatus { @@ -154,7 +154,7 @@ func showStatus(ctx context.Context, log *output.Logger, apiClient *api.API, use } if err := codespaces.PollPostCreateStates(ctx, log, apiClient, user, codespace, poller); err != nil { - return fmt.Errorf("failed to poll state changes from codespace: %v", err) + return fmt.Errorf("failed to poll state changes from Codespace: %v", err) } return nil @@ -228,7 +228,7 @@ func getBranchName(branch string) (string, error) { func getMachineName(ctx context.Context, machine string, user *api.User, repo *api.Repository, location string, apiClient *api.API) (string, error) { skus, err := apiClient.GetCodespacesSKUs(ctx, user, repo, location) if err != nil { - return "", fmt.Errorf("error getting codespace SKUs: %v", err) + return "", fmt.Errorf("error getting Codespace SKUs: %v", err) } // if user supplied a machine type, it must be valid diff --git a/cmd/ghcs/delete.go b/cmd/ghcs/delete.go index d37029753..c42f57f6f 100644 --- a/cmd/ghcs/delete.go +++ b/cmd/ghcs/delete.go @@ -65,11 +65,11 @@ func delete_(codespaceName string) error { codespace, token, err := codespaces.GetOrChooseCodespace(ctx, apiClient, user, codespaceName) if err != nil { - return fmt.Errorf("get or choose codespace: %v", err) + return fmt.Errorf("get or choose Codespace: %v", err) } if err := apiClient.DeleteCodespace(ctx, user, token, codespace.Name); err != nil { - return fmt.Errorf("error deleting codespace: %v", err) + return fmt.Errorf("error deleting Codespace: %v", err) } log.Println("Codespace deleted.") @@ -89,17 +89,17 @@ func deleteAll() error { codespaces, err := apiClient.ListCodespaces(ctx, user) if err != nil { - return fmt.Errorf("error getting codespaces: %v", err) + return fmt.Errorf("error getting Codespaces: %v", err) } for _, c := range codespaces { token, err := apiClient.GetCodespaceToken(ctx, user.Login, c.Name) if err != nil { - return fmt.Errorf("error getting codespace token: %v", err) + return fmt.Errorf("error getting Codespace token: %v", err) } if err := apiClient.DeleteCodespace(ctx, user, token, c.Name); err != nil { - return fmt.Errorf("error deleting codespace: %v", err) + return fmt.Errorf("error deleting Codespace: %v", err) } log.Printf("Codespace deleted: %s\n", c.Name) @@ -120,7 +120,7 @@ func deleteByRepo(repo string) error { codespaces, err := apiClient.ListCodespaces(ctx, user) if err != nil { - return fmt.Errorf("error getting codespaces: %v", err) + return fmt.Errorf("error getting Codespaces: %v", err) } var deleted bool @@ -132,18 +132,18 @@ func deleteByRepo(repo string) error { token, err := apiClient.GetCodespaceToken(ctx, user.Login, c.Name) if err != nil { - return fmt.Errorf("error getting codespace token: %v", err) + return fmt.Errorf("error getting Codespace token: %v", err) } if err := apiClient.DeleteCodespace(ctx, user, token, c.Name); err != nil { - return fmt.Errorf("error deleting codespace: %v", err) + return fmt.Errorf("error deleting Codespace: %v", err) } log.Printf("Codespace deleted: %s\n", c.Name) } if !deleted { - return fmt.Errorf("No codespace was found for repository: %s", repo) + return fmt.Errorf("No Codespace was found for repository: %s", repo) } return list(&listOptions{}) diff --git a/cmd/ghcs/list.go b/cmd/ghcs/list.go index a19439296..0c055e98e 100644 --- a/cmd/ghcs/list.go +++ b/cmd/ghcs/list.go @@ -46,7 +46,7 @@ func list(opts *listOptions) error { codespaces, err := apiClient.ListCodespaces(ctx, user) if err != nil { - return fmt.Errorf("error getting codespaces: %v", err) + return fmt.Errorf("error getting Codespaces: %v", err) } table := output.NewTable(os.Stdout, opts.asJSON) diff --git a/cmd/ghcs/logs.go b/cmd/ghcs/logs.go index 662e20b79..4c840e77c 100644 --- a/cmd/ghcs/logs.go +++ b/cmd/ghcs/logs.go @@ -49,12 +49,12 @@ func logs(tail bool, codespaceName string) error { codespace, token, err := codespaces.GetOrChooseCodespace(ctx, apiClient, user, codespaceName) if err != nil { - return fmt.Errorf("get or choose codespace: %v", err) + return fmt.Errorf("get or choose Codespace: %v", err) } lsclient, err := codespaces.ConnectToLiveshare(ctx, log, apiClient, user.Login, token, codespace) if err != nil { - return fmt.Errorf("connecting to liveshare: %v", err) + return fmt.Errorf("connecting to Live Share: %v", err) } tunnelPort, connClosed, err := codespaces.MakeSSHTunnel(ctx, lsclient, 0) diff --git a/cmd/ghcs/ports.go b/cmd/ghcs/ports.go index 3e89eb984..f83757ff8 100644 --- a/cmd/ghcs/ports.go +++ b/cmd/ghcs/ports.go @@ -71,14 +71,14 @@ func ports(opts *portsOptions) error { if err == codespaces.ErrNoCodespaces { return err } - return fmt.Errorf("error choosing codespace: %v", err) + return fmt.Errorf("error choosing Codespace: %v", err) } devContainerCh := getDevContainer(ctx, apiClient, codespace) liveShareClient, err := codespaces.ConnectToLiveshare(ctx, log, apiClient, user.Login, token, codespace) if err != nil { - return fmt.Errorf("error connecting to liveshare: %v", err) + return fmt.Errorf("error connecting to Live Share: %v", err) } log.Println("Loading ports...") @@ -211,17 +211,17 @@ func updatePortVisibility(log *output.Logger, codespaceName, sourcePort string, token, err := apiClient.GetCodespaceToken(ctx, user.Login, codespaceName) if err != nil { - return fmt.Errorf("error getting codespace token: %v", err) + return fmt.Errorf("error getting Codespace token: %v", err) } codespace, err := apiClient.GetCodespace(ctx, token, user.Login, codespaceName) if err != nil { - return fmt.Errorf("error getting codespace: %v", err) + return fmt.Errorf("error getting Codespace: %v", err) } lsclient, err := codespaces.ConnectToLiveshare(ctx, log, apiClient, user.Login, token, codespace) if err != nil { - return fmt.Errorf("error connecting to liveshare: %v", err) + return fmt.Errorf("error connecting to Live Share: %v", err) } server, err := liveshare.NewServer(lsclient) @@ -277,17 +277,17 @@ func forwardPorts(log *output.Logger, codespaceName string, ports []string) erro token, err := apiClient.GetCodespaceToken(ctx, user.Login, codespaceName) if err != nil { - return fmt.Errorf("error getting codespace token: %v", err) + return fmt.Errorf("error getting Codespace token: %v", err) } codespace, err := apiClient.GetCodespace(ctx, token, user.Login, codespaceName) if err != nil { - return fmt.Errorf("error getting codespace: %v", err) + return fmt.Errorf("error getting Codespace: %v", err) } lsclient, err := codespaces.ConnectToLiveshare(ctx, log, apiClient, user.Login, token, codespace) if err != nil { - return fmt.Errorf("error connecting to liveshare: %v", err) + return fmt.Errorf("error connecting to Live Share: %v", err) } server, err := liveshare.NewServer(lsclient) diff --git a/cmd/ghcs/ssh.go b/cmd/ghcs/ssh.go index 84964829a..7fecf0d1f 100644 --- a/cmd/ghcs/ssh.go +++ b/cmd/ghcs/ssh.go @@ -51,17 +51,17 @@ func ssh(sshProfile, codespaceName string, sshServerPort int) error { codespace, token, err := codespaces.GetOrChooseCodespace(ctx, apiClient, user, codespaceName) if err != nil { - return fmt.Errorf("get or choose codespace: %v", err) + return fmt.Errorf("get or choose Codespace: %v", err) } lsclient, err := codespaces.ConnectToLiveshare(ctx, log, apiClient, user.Login, token, codespace) if err != nil { - return fmt.Errorf("error connecting to liveshare: %v", err) + return fmt.Errorf("error connecting to Live Share: %v", err) } terminal, err := liveshare.NewTerminal(lsclient) if err != nil { - return fmt.Errorf("error creating liveshare terminal: %v", err) + return fmt.Errorf("error creating Live Share terminal: %v", err) } log.Print("Preparing SSH...") diff --git a/internal/codespaces/codespaces.go b/internal/codespaces/codespaces.go index 3214a6dea..90f676d28 100644 --- a/internal/codespaces/codespaces.go +++ b/internal/codespaces/codespaces.go @@ -13,13 +13,13 @@ import ( ) var ( - ErrNoCodespaces = errors.New("You have no codespaces.") + ErrNoCodespaces = errors.New("You have no Codespaces.") ) func ChooseCodespace(ctx context.Context, apiClient *api.API, user *api.User) (*api.Codespace, error) { codespaces, err := apiClient.ListCodespaces(ctx, user) if err != nil { - return nil, fmt.Errorf("error getting codespaces: %v", err) + return nil, fmt.Errorf("error getting Codespaces: %v", err) } if len(codespaces) == 0 { @@ -77,9 +77,9 @@ func ConnectToLiveshare(ctx context.Context, log logger, apiClient *api.API, use var startedCodespace bool if codespace.Environment.State != api.CodespaceEnvironmentStateAvailable { startedCodespace = true - log.Print("Starting your codespace...") + log.Print("Starting your Codespace...") if err := apiClient.StartCodespace(ctx, token, codespace); err != nil { - return nil, fmt.Errorf("error starting codespace: %v", err) + return nil, fmt.Errorf("error starting Codespace: %v", err) } } @@ -93,12 +93,12 @@ func ConnectToLiveshare(ctx context.Context, log logger, apiClient *api.API, use } if retries == 30 { - return nil, errors.New("timed out while waiting for the codespace to start") + return nil, errors.New("timed out while waiting for the Codespace to start") } codespace, err = apiClient.GetCodespace(ctx, token, userLogin, codespace.Name) if err != nil { - return nil, fmt.Errorf("error getting codespace: %v", err) + return nil, fmt.Errorf("error getting Codespace: %v", err) } } @@ -106,7 +106,7 @@ func ConnectToLiveshare(ctx context.Context, log logger, apiClient *api.API, use fmt.Print("\n") } - log.Println("Connecting to your codespace...") + log.Println("Connecting to your Codespace...") lsclient, err := liveshare.NewClient( liveshare.WithConnection(liveshare.Connection{ @@ -117,11 +117,11 @@ func ConnectToLiveshare(ctx context.Context, log logger, apiClient *api.API, use }), ) if err != nil { - return nil, fmt.Errorf("error creating live share: %v", err) + return nil, fmt.Errorf("error creating Live Share: %v", err) } if err := lsclient.Join(ctx); err != nil { - return nil, fmt.Errorf("error joining liveshare client: %v", err) + return nil, fmt.Errorf("error joining Live Share client: %v", err) } return lsclient, nil @@ -134,23 +134,23 @@ func GetOrChooseCodespace(ctx context.Context, apiClient *api.API, user *api.Use if err == ErrNoCodespaces { return nil, "", err } - return nil, "", fmt.Errorf("choosing codespace: %v", err) + return nil, "", fmt.Errorf("choosing Codespace: %v", err) } codespaceName = codespace.Name token, err = apiClient.GetCodespaceToken(ctx, user.Login, codespaceName) if err != nil { - return nil, "", fmt.Errorf("getting codespace token: %v", err) + return nil, "", fmt.Errorf("getting Codespace token: %v", err) } } else { token, err = apiClient.GetCodespaceToken(ctx, user.Login, codespaceName) if err != nil { - return nil, "", fmt.Errorf("getting codespace token for given codespace: %v", err) + return nil, "", fmt.Errorf("getting Codespace token for given codespace: %v", err) } codespace, err = apiClient.GetCodespace(ctx, token, user.Login, codespaceName) if err != nil { - return nil, "", fmt.Errorf("getting full codespace details: %v", err) + return nil, "", fmt.Errorf("getting full Codespace details: %v", err) } } diff --git a/internal/codespaces/ssh.go b/internal/codespaces/ssh.go index 672ba3b7b..ba55efae5 100644 --- a/internal/codespaces/ssh.go +++ b/internal/codespaces/ssh.go @@ -19,7 +19,7 @@ func MakeSSHTunnel(ctx context.Context, lsclient *liveshare.Client, serverPort i server, err := liveshare.NewServer(lsclient) if err != nil { - return 0, nil, fmt.Errorf("new liveshare server: %v", err) + return 0, nil, fmt.Errorf("new Live Share server: %v", err) } rand.Seed(time.Now().Unix()) diff --git a/internal/codespaces/states.go b/internal/codespaces/states.go index 5c3dcef45..b6d6937a8 100644 --- a/internal/codespaces/states.go +++ b/internal/codespaces/states.go @@ -37,12 +37,12 @@ type PostCreateState struct { func PollPostCreateStates(ctx context.Context, log logger, apiClient *api.API, user *api.User, codespace *api.Codespace, poller func([]PostCreateState)) error { token, err := apiClient.GetCodespaceToken(ctx, user.Login, codespace.Name) if err != nil { - return fmt.Errorf("getting codespace token: %v", err) + return fmt.Errorf("getting Codespace token: %v", err) } lsclient, err := ConnectToLiveshare(ctx, log, apiClient, user.Login, token, codespace) if err != nil { - return fmt.Errorf("connect to liveshare: %v", err) + return fmt.Errorf("connect to Live Share: %v", err) } tunnelPort, connClosed, err := MakeSSHTunnel(ctx, lsclient, 0) From 509e037a5e916a2ebc7744859b4cc2a83998edc0 Mon Sep 17 00:00:00 2001 From: Alan Donovan Date: Tue, 31 Aug 2021 12:01:59 -0400 Subject: [PATCH 125/290] address review comments --- cmd/ghcs/ssh.go | 9 ++++++--- internal/codespaces/ssh.go | 32 +++++++++++++++----------------- 2 files changed, 21 insertions(+), 20 deletions(-) diff --git a/cmd/ghcs/ssh.go b/cmd/ghcs/ssh.go index 579093c3d..3de123dfa 100644 --- a/cmd/ghcs/ssh.go +++ b/cmd/ghcs/ssh.go @@ -24,7 +24,7 @@ func newSSHCmd() *cobra.Command { Short: "SSH into a Codespace", Args: cobra.NoArgs, RunE: func(cmd *cobra.Command, args []string) error { - return ssh(sshProfile, codespaceName, sshServerPort) + return ssh(context.Background(), sshProfile, codespaceName, sshServerPort) }, } @@ -39,9 +39,12 @@ func init() { rootCmd.AddCommand(newSSHCmd()) } -func ssh(sshProfile, codespaceName string, sshServerPort int) error { +func ssh(ctx context.Context, sshProfile, codespaceName string, sshServerPort int) error { + // Ensure all child tasks (e.g. port forwarding) terminate before return. + ctx, cancel := context.WithCancel(ctx) + defer cancel() + apiClient := api.New(os.Getenv("GITHUB_TOKEN")) - ctx := context.Background() log := output.NewLogger(os.Stdout, os.Stderr, false) user, err := apiClient.GetUser(ctx) diff --git a/internal/codespaces/ssh.go b/internal/codespaces/ssh.go index 67ffdbdd2..ac68f13ca 100644 --- a/internal/codespaces/ssh.go +++ b/internal/codespaces/ssh.go @@ -8,21 +8,19 @@ import ( "os/exec" "strconv" "strings" - "time" "github.com/github/go-liveshare" ) -// StartPortForwarding starts LiveShare port forwarding of traffic of -// the specified protocol (e.g. "sshd") between the LiveShare client -// and the specified local port, or, if zero, a port chosen at random; -// the effective port number is returned. Forwarding continues in the -// background until an error is encountered (including cancellation of -// the context). Therefore clients must cancel the context +// StartPortForwarding starts LiveShare port forwarding of traffic +// between the LiveShare client and the specified local port, or, if +// zero, a port chosen at random; the effective port number is +// returned. Forwarding continues in the background until an error is +// encountered (including cancellation of the context). Therefore +// clients must cancel the context. // -// REVIEWERS: where is the set of legal values of protocol defined? -// It appears to be: "whatever is supported by the LiveShare service's -// serverSharing.startSharing method". Where is that defined? +// The session name is used (along with the port) to generate +// names for streams, and may appear in error messages. // // TODO(adonovan): simplify API concurrency from API. Either: // 1) return a stop function so that clients don't forget to stop forwarding. @@ -34,22 +32,19 @@ import ( // and has NewRemoteCommand as a method. It will need a Stop method, // and an Error method for querying whether the session has failed // asynchronously. -func StartPortForwarding(ctx context.Context, lsclient *liveshare.Client, protocol string, localPort int) (int, <-chan error, error) { +func StartPortForwarding(ctx context.Context, lsclient *liveshare.Client, sessionName string, localPort int) (int, <-chan error, error) { server, err := liveshare.NewServer(lsclient) if err != nil { return 0, nil, fmt.Errorf("new liveshare server: %v", err) } if localPort == 0 { - // improve this obviously - // REVIEWERS: any reason not to use the global PRNG? - rng := rand.New(rand.NewSource(time.Now().Unix())) - localPort = rng.Intn(9999-2000) + 2000 - // TODO(adonovan): loop if port is taken? + localPort = rand.Intn(9999-2000) + 2000 + // TODO(adonovan): retry if port is taken? } // TODO(josebalius): This port won't always be 2222 - if err := server.StartSharing(ctx, protocol, 2222); err != nil { + if err := server.StartSharing(ctx, sessionName, 2222); err != nil { return 0, nil, fmt.Errorf("sharing sshd port: %v", err) } @@ -90,8 +85,11 @@ func NewRemoteCommand(ctx context.Context, tunnelPort int, destination, command return cmd } +// newSSHCommand populates an exec.Cmd to run a command (or if blank, +// an interactive shell) over ssh. func newSSHCommand(ctx context.Context, port int, dst, command string) (*exec.Cmd, []string) { connArgs := []string{"-p", strconv.Itoa(port), "-o", "NoHostAuthenticationForLocalhost=yes"} + // TODO(adonovan): eliminate X11 and X11Trust flags where unneeded. cmdArgs := append([]string{dst, "-X", "-Y", "-C"}, connArgs...) // X11, X11Trust, Compression // An empty command enables port forwarding but not execution. From c0aae52289a8c3274c85d64639050dec66063138 Mon Sep 17 00:00:00 2001 From: Alan Donovan Date: Tue, 31 Aug 2021 13:52:37 -0400 Subject: [PATCH 126/290] move port choice, and PortForwarder.Start call, into clients --- cmd/ghcs/logs.go | 35 ++++++++++------ cmd/ghcs/ssh.go | 27 ++++++++---- internal/codespaces/ssh.go | 77 +++++++++++++++-------------------- internal/codespaces/states.go | 24 ++++++++--- 4 files changed, 93 insertions(+), 70 deletions(-) diff --git a/cmd/ghcs/logs.go b/cmd/ghcs/logs.go index b8bc462fc..a5e9235bc 100644 --- a/cmd/ghcs/logs.go +++ b/cmd/ghcs/logs.go @@ -59,7 +59,12 @@ func logs(ctx context.Context, tail bool, codespaceName string) error { return fmt.Errorf("connecting to liveshare: %v", err) } - tunnelPort, connClosed, err := codespaces.StartPortForwarding(ctx, lsclient, "sshd", 0) + port, err := codespaces.UnusedPort() + if err != nil { + return err + } + + tunnel, err := codespaces.NewPortForwarder(ctx, lsclient, "sshd", port) if err != nil { return fmt.Errorf("make ssh tunnel: %v", err) } @@ -71,23 +76,29 @@ func logs(ctx context.Context, tail bool, codespaceName string) error { dst := fmt.Sprintf("%s@localhost", getSSHUser(codespace)) cmd := codespaces.NewRemoteCommand( - ctx, tunnelPort, dst, fmt.Sprintf("%s /workspaces/.codespaces/.persistedshare/creation.log", cmdType), + ctx, port, dst, fmt.Sprintf("%s /workspaces/.codespaces/.persistedshare/creation.log", cmdType), ) - // Channel is buffered to avoid a goroutine leak when connClosed occurs before done. - done := make(chan error, 1) - go func() { done <- cmd.Run() }() + // Error channels are buffered so that neither sending goroutine gets stuck. + + tunnelClosed := make(chan error, 1) + go func() { + tunnelClosed <- tunnel.Start(ctx) // error is non-nil + }() + + cmdDone := make(chan error, 1) + go func() { + cmdDone <- cmd.Run() + }() select { - case err := <-connClosed: - if err != nil { - return fmt.Errorf("connection closed: %v", err) - } - case err := <-done: + case err := <-tunnelClosed: + return fmt.Errorf("connection closed: %v", err) + + case err := <-cmdDone: if err != nil { return fmt.Errorf("error retrieving logs: %v", err) } + return nil // success } - - return nil } diff --git a/cmd/ghcs/ssh.go b/cmd/ghcs/ssh.go index 3de123dfa..2b51c4cc1 100644 --- a/cmd/ghcs/ssh.go +++ b/cmd/ghcs/ssh.go @@ -80,7 +80,17 @@ func ssh(ctx context.Context, sshProfile, codespaceName string, sshServerPort in } log.Print("\n") - tunnelPort, tunnelClosed, err := codespaces.StartPortForwarding(ctx, lsclient, "sshd", sshServerPort) + usingCustomPort := true + if sshServerPort == 0 { + usingCustomPort = false // suppress log of command line in Shell + port, err := codespaces.UnusedPort() + if err != nil { + return err + } + sshServerPort = port + } + + tunnel, err := codespaces.NewPortForwarder(ctx, lsclient, "sshd", sshServerPort) if err != nil { return fmt.Errorf("make ssh tunnel: %v", err) } @@ -90,26 +100,27 @@ func ssh(ctx context.Context, sshProfile, codespaceName string, sshServerPort in connectDestination = fmt.Sprintf("%s@localhost", getSSHUser(codespace)) } - usingCustomPort := tunnelPort == sshServerPort + tunnelClosed := make(chan error) + go func() { + tunnelClosed <- tunnel.Start(ctx) // error is always non-nil + }() shellClosed := make(chan error) go func() { - shellClosed <- codespaces.Shell(ctx, log, tunnelPort, connectDestination, usingCustomPort) + shellClosed <- codespaces.Shell(ctx, log, sshServerPort, connectDestination, usingCustomPort) }() log.Println("Ready...") select { case err := <-tunnelClosed: - if err != nil { - return fmt.Errorf("tunnel closed: %v", err) - } + return fmt.Errorf("tunnel closed: %v", err) + case err := <-shellClosed: if err != nil { return fmt.Errorf("shell closed: %v", err) } + return nil // success } - - return nil } func getContainerID(ctx context.Context, logger *output.Logger, terminal *liveshare.Terminal) (string, error) { diff --git a/internal/codespaces/ssh.go b/internal/codespaces/ssh.go index ac68f13ca..5118ad91c 100644 --- a/internal/codespaces/ssh.go +++ b/internal/codespaces/ssh.go @@ -3,7 +3,7 @@ package codespaces import ( "context" "fmt" - "math/rand" + "net" "os" "os/exec" "strconv" @@ -12,57 +12,47 @@ import ( "github.com/github/go-liveshare" ) -// StartPortForwarding starts LiveShare port forwarding of traffic -// between the LiveShare client and the specified local port, or, if -// zero, a port chosen at random; the effective port number is -// returned. Forwarding continues in the background until an error is -// encountered (including cancellation of the context). Therefore -// clients must cancel the context. +// UnusedPort returns the number of a local TCP port that is currently +// unbound, or an error if none was available. +// +// Use of this function carries an inherent risk of a time-of-check to +// time-of-use race against other processes. +func UnusedPort() (int, error) { + addr, err := net.ResolveTCPAddr("tcp", "localhost:0") + if err != nil { + return 0, fmt.Errorf("internal error while choosing port: %v", err) + } + + l, err := net.ListenTCP("tcp", addr) + if err != nil { + return 0, fmt.Errorf("choosing available port: %v", err) + } + defer l.Close() + return l.Addr().(*net.TCPAddr).Port, nil +} + +// NewPortForwarder returns a new port forwarder for traffic between +// the Live Share client and the specified local port (which must be +// available). // // The session name is used (along with the port) to generate // names for streams, and may appear in error messages. -// -// TODO(adonovan): simplify API concurrency from API. Either: -// 1) return a stop function so that clients don't forget to stop forwarding. -// 2) avoid creating a goroutine and returning a channel. Use approach of -// http.ListenAndServe, which runs until it encounters an error -// (incl. cancellation). But this means we can't return the port. -// Can we make the client responsible for supplying it? -// 3) return a PortForwarding object that encapsulates the port, -// and has NewRemoteCommand as a method. It will need a Stop method, -// and an Error method for querying whether the session has failed -// asynchronously. -func StartPortForwarding(ctx context.Context, lsclient *liveshare.Client, sessionName string, localPort int) (int, <-chan error, error) { - server, err := liveshare.NewServer(lsclient) - if err != nil { - return 0, nil, fmt.Errorf("new liveshare server: %v", err) +func NewPortForwarder(ctx context.Context, client *liveshare.Client, sessionName string, localPort int) (*liveshare.PortForwarder, error) { + if localPort == 0 { + return nil, fmt.Errorf("a local port must be provided") } - if localPort == 0 { - localPort = rand.Intn(9999-2000) + 2000 - // TODO(adonovan): retry if port is taken? + server, err := liveshare.NewServer(client) + if err != nil { + return nil, fmt.Errorf("new liveshare server: %v", err) } // TODO(josebalius): This port won't always be 2222 if err := server.StartSharing(ctx, sessionName, 2222); err != nil { - return 0, nil, fmt.Errorf("sharing sshd port: %v", err) + return nil, fmt.Errorf("sharing sshd port: %v", err) } - tunnelClosed := make(chan error) - go func() { - // TODO(adonovan): simplify liveshare API to combine NewPortForwarder and Start - // methods into a single ForwardPort call, like http.ListenAndServe. - // (Start is a misnomer: it runs the complete session.) - // Also document that it never returns a nil error. - portForwarder := liveshare.NewPortForwarder(lsclient, server, localPort) - if err := portForwarder.Start(ctx); err != nil { - tunnelClosed <- fmt.Errorf("forwarding port: %v", err) - return - } - tunnelClosed <- nil - }() - - return localPort, tunnelClosed, nil + return liveshare.NewPortForwarder(client, server, localPort), nil } // Shell runs an interactive secure shell over an existing @@ -78,8 +68,8 @@ func Shell(ctx context.Context, log logger, port int, destination string, usingC return cmd.Run() } -// NewRemoteCommand returns a partially populated exec.Cmd that will -// securely run a shell command on the remote machine. +// NewRemoteCommand returns an exec.Cmd that will securely run a shell +// command on the remote machine. func NewRemoteCommand(ctx context.Context, tunnelPort int, destination, command string) *exec.Cmd { cmd, _ := newSSHCommand(ctx, tunnelPort, destination, command) return cmd @@ -92,7 +82,6 @@ func newSSHCommand(ctx context.Context, port int, dst, command string) (*exec.Cm // TODO(adonovan): eliminate X11 and X11Trust flags where unneeded. cmdArgs := append([]string{dst, "-X", "-Y", "-C"}, connArgs...) // X11, X11Trust, Compression - // An empty command enables port forwarding but not execution. if command != "" { cmdArgs = append(cmdArgs, command) } diff --git a/internal/codespaces/states.go b/internal/codespaces/states.go index fe34f5486..870840304 100644 --- a/internal/codespaces/states.go +++ b/internal/codespaces/states.go @@ -45,22 +45,34 @@ func PollPostCreateStates(ctx context.Context, log logger, apiClient *api.API, u return fmt.Errorf("connect to liveshare: %v", err) } - tunnelPort, connClosed, err := StartPortForwarding(ctx, lsclient, "sshd", 0) + port, err := UnusedPort() if err != nil { - return fmt.Errorf("make ssh tunnel: %v", err) + return err } + fwd, err := NewPortForwarder(ctx, lsclient, "sshd", port) + if err != nil { + return fmt.Errorf("creating port forwarder: %v", err) + } + + tunnelClosed := make(chan error, 1) // buffered to avoid sender stuckness + go func() { + tunnelClosed <- fwd.Start(ctx) // error is non-nil + }() + t := time.NewTicker(1 * time.Second) defer t.Stop() for { select { case <-ctx.Done(): - return nil - case err := <-connClosed: - return fmt.Errorf("connection closed: %v", err) + return nil // canceled + + case err := <-tunnelClosed: + return fmt.Errorf("connection failed: %v", err) + case <-t.C: - states, err := getPostCreateOutput(ctx, tunnelPort, codespace) + states, err := getPostCreateOutput(ctx, port, codespace) if err != nil { return fmt.Errorf("get post create output: %v", err) } From 535d832f8abfaaf997931b162411dd7aeedfa4ba Mon Sep 17 00:00:00 2001 From: Jose Garcia Date: Tue, 31 Aug 2021 15:50:04 -0400 Subject: [PATCH 127/290] small tweak --- internal/codespaces/states.go | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/internal/codespaces/states.go b/internal/codespaces/states.go index f3e9cbefe..274d27951 100644 --- a/internal/codespaces/states.go +++ b/internal/codespaces/states.go @@ -77,7 +77,7 @@ func PollPostCreateStates(ctx context.Context, log logger, apiClient *api.API, u func getPostCreateOutput(ctx context.Context, tunnelPort int, codespace *api.Codespace, user string) ([]PostCreateState, error) { stdout, err := RunCommand( - ctx, tunnelPort, sshDestination(codespace, user), + ctx, tunnelPort, fmt.Sprintf("%s@localhost", user), "cat /workspaces/.codespaces/shared/postCreateOutput.json", ) if err != nil { @@ -98,9 +98,4 @@ func getPostCreateOutput(ctx context.Context, tunnelPort int, codespace *api.Cod } return output.Steps, nil -} - -// TODO(josebalius): this won't be needed soon -func sshDestination(codespace *api.Codespace, user string) string { - return user + "@localhost" -} +} \ No newline at end of file From ebb04d1753f27fa8747916907e9012c577708e42 Mon Sep 17 00:00:00 2001 From: Jose Garcia Date: Tue, 31 Aug 2021 19:52:32 +0000 Subject: [PATCH 128/290] format code --- internal/codespaces/states.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/internal/codespaces/states.go b/internal/codespaces/states.go index 274d27951..d09c399e4 100644 --- a/internal/codespaces/states.go +++ b/internal/codespaces/states.go @@ -98,4 +98,4 @@ func getPostCreateOutput(ctx context.Context, tunnelPort int, codespace *api.Cod } return output.Steps, nil -} \ No newline at end of file +} From 6a527941bf8c6098cf7763ae695b690024fccca2 Mon Sep 17 00:00:00 2001 From: Alan Donovan Date: Mon, 30 Aug 2021 18:09:52 -0400 Subject: [PATCH 129/290] suppress display of usage message after errors --- cmd/ghcs/main.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/cmd/ghcs/main.go b/cmd/ghcs/main.go index bc9bc2c6b..2f1515ac0 100644 --- a/cmd/ghcs/main.go +++ b/cmd/ghcs/main.go @@ -19,7 +19,8 @@ func main() { var version = "DEV" var rootCmd = &cobra.Command{ - Use: "ghcs", + Use: "ghcs", + SilenceUsage: true, // don't print usage message after each error (see #80) Long: `Unofficial CLI tool to manage GitHub Codespaces. Running commands requires the GITHUB_TOKEN environment variable to be set to a From 3aad0bbeb4025b588d5192f153859d90e9f4289a Mon Sep 17 00:00:00 2001 From: Alan Donovan Date: Tue, 31 Aug 2021 17:27:53 -0400 Subject: [PATCH 130/290] check context error in PollPostCreateStates --- internal/codespaces/states.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/internal/codespaces/states.go b/internal/codespaces/states.go index ee255f3e4..a58e2b235 100644 --- a/internal/codespaces/states.go +++ b/internal/codespaces/states.go @@ -33,7 +33,7 @@ type PostCreateState struct { // PollPostCreateStates watches for state changes in a codespace, // and calls the supplied poller for each batch of state changes. -// It runs until the context is cancelled or SSH tunnel is closed. +// It runs until it encounters an error, including cancellation of the context. func PollPostCreateStates(ctx context.Context, log logger, apiClient *api.API, user *api.User, codespace *api.Codespace, poller func([]PostCreateState)) error { token, err := apiClient.GetCodespaceToken(ctx, user.Login, codespace.Name) if err != nil { @@ -71,7 +71,7 @@ func PollPostCreateStates(ctx context.Context, log logger, apiClient *api.API, u for { select { case <-ctx.Done(): - return nil // canceled + return ctx.Err() case err := <-tunnelClosed: return fmt.Errorf("connection failed: %v", err) From 55fa17d8bc3055ddd143ac0b4e70f8513c01ef70 Mon Sep 17 00:00:00 2001 From: Alan Donovan Date: Tue, 31 Aug 2021 17:30:40 -0400 Subject: [PATCH 131/290] wip --- client.go | 2 +- port_forwarder.go | 3 +-- port_forwarder_test.go | 3 ++- rpc.go | 26 +++++++++++++------------- rpc_test.go | 3 ++- 5 files changed, 19 insertions(+), 18 deletions(-) diff --git a/client.go b/client.go index a8a1e3864..628d557b5 100644 --- a/client.go +++ b/client.go @@ -64,7 +64,7 @@ func (c *Client) Join(ctx context.Context) (err error) { return fmt.Errorf("error connecting to ssh session: %v", err) } - c.rpc = newRpcClient(c.ssh) + c.rpc = newRPCClient(c.ssh) c.rpc.connect(ctx) _, err = c.joinWorkspace(ctx) diff --git a/port_forwarder.go b/port_forwarder.go index e6eedf16c..774fec863 100644 --- a/port_forwarder.go +++ b/port_forwarder.go @@ -48,10 +48,9 @@ func (l *PortForwarder) Start(ctx context.Context) error { case err := <-l.errCh: return err case <-ctx.Done(): + // TODO ctx.Error? return ln.Close() } - - return nil } func (l *PortForwarder) StartWithConn(ctx context.Context, conn io.ReadWriteCloser) error { diff --git a/port_forwarder_test.go b/port_forwarder_test.go index 33a33b39b..a3621c075 100644 --- a/port_forwarder_test.go +++ b/port_forwarder_test.go @@ -55,7 +55,8 @@ func TestPortForwarderStart(t *testing.T) { t.Errorf("create new server: %v", err) } - ctx, _ := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() pf := NewPortForwarder(client, server, 8000) done := make(chan error) diff --git a/rpc.go b/rpc.go index 8abd0e98f..3fea63b10 100644 --- a/rpc.go +++ b/rpc.go @@ -15,7 +15,7 @@ type rpcClient struct { handler *rpcHandler } -func newRpcClient(conn io.ReadWriteCloser) *rpcClient { +func newRPCClient(conn io.ReadWriteCloser) *rpcClient { return &rpcClient{conn: conn, handler: newRPCHandler()} } @@ -24,17 +24,17 @@ func (r *rpcClient) connect(ctx context.Context) { r.Conn = jsonrpc2.NewConn(ctx, stream, r.handler) } -func (r *rpcClient) do(ctx context.Context, method string, args interface{}, result interface{}) error { +func (r *rpcClient) do(ctx context.Context, method string, args, result interface{}) error { waiter, err := r.Conn.DispatchCall(ctx, method, args) if err != nil { - return fmt.Errorf("error on dispatch call: %v", err) + return fmt.Errorf("error dispatching %q call: %v", method, err) } return waiter.Wait(ctx, result) } type rpcHandler struct { - mutex sync.RWMutex + mutex sync.Mutex eventHandlers map[string][]chan *jsonrpc2.Request } @@ -44,34 +44,34 @@ func newRPCHandler() *rpcHandler { } } +// TODO: document obligations around chan. It appears to be used for at most one request. func (r *rpcHandler) registerEventHandler(eventMethod string) <-chan *jsonrpc2.Request { r.mutex.Lock() defer r.mutex.Unlock() ch := make(chan *jsonrpc2.Request) - if _, ok := r.eventHandlers[eventMethod]; !ok { - r.eventHandlers[eventMethod] = []chan *jsonrpc2.Request{ch} - } else { - r.eventHandlers[eventMethod] = append(r.eventHandlers[eventMethod], ch) - } + r.eventHandlers[eventMethod] = append(r.eventHandlers[eventMethod], ch) return ch } func (r *rpcHandler) Handle(ctx context.Context, conn *jsonrpc2.Conn, req *jsonrpc2.Request) { r.mutex.Lock() - defer r.mutex.Unlock() + handlers := r.eventHandlers[req.Method] + r.eventHandlers[req.Method] = nil + r.mutex.Unlock() - if handlers, ok := r.eventHandlers[req.Method]; ok { + if len(handlers) > 0 { go func() { + // Broadcast the request to each handler in sequence. + // TODO rethink this. needs function call. for _, handler := range handlers { select { case handler <- req: case <-ctx.Done(): + // TODO: ctx.Err break } } - - r.eventHandlers[req.Method] = []chan *jsonrpc2.Request{} }() } } diff --git a/rpc_test.go b/rpc_test.go index d16b32a4f..7543152d1 100644 --- a/rpc_test.go +++ b/rpc_test.go @@ -15,7 +15,8 @@ func TestRPCHandlerEvents(t *testing.T) { time.Sleep(1 * time.Second) rpcHandler.Handle(context.Background(), nil, &jsonrpc2.Request{Method: "somethingHappened"}) }() - ctx, _ := context.WithDeadline(context.Background(), time.Now().Add(5*time.Second)) + ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(5*time.Second)) + defer cancel() select { case event := <-eventCh: if event.Method != "somethingHappened" { From 2163aba3d5ae5f9643e295038a0710bc57b78cc7 Mon Sep 17 00:00:00 2001 From: Jose Garcia Date: Wed, 1 Sep 2021 13:54:45 -0400 Subject: [PATCH 132/290] pass branch for sku selection, pre-select if only one is returned --- api/api.go | 3 ++- cmd/ghcs/create.go | 10 +++++++--- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/api/api.go b/api/api.go index faa71d253..7686d8f81 100644 --- a/api/api.go +++ b/api/api.go @@ -327,7 +327,7 @@ type SKU struct { DisplayName string `json:"display_name"` } -func (a *API) GetCodespacesSKUs(ctx context.Context, user *User, repository *Repository, location string) ([]*SKU, error) { +func (a *API) GetCodespacesSKUs(ctx context.Context, user *User, repository *Repository, branch, location string) ([]*SKU, error) { req, err := http.NewRequest(http.MethodGet, githubAPI+"/vscs_internal/user/"+user.Login+"/skus", nil) if err != nil { return nil, fmt.Errorf("err creating request: %v", err) @@ -335,6 +335,7 @@ func (a *API) GetCodespacesSKUs(ctx context.Context, user *User, repository *Rep q := req.URL.Query() q.Add("location", location) + q.Add("ref", branch) q.Add("repository_id", strconv.Itoa(repository.ID)) req.URL.RawQuery = q.Encode() diff --git a/cmd/ghcs/create.go b/cmd/ghcs/create.go index 71e31f298..f8276bc5e 100644 --- a/cmd/ghcs/create.go +++ b/cmd/ghcs/create.go @@ -77,7 +77,7 @@ func create(opts *createOptions) error { return fmt.Errorf("error getting Codespace user: %v", userResult.Err) } - machine, err := getMachineName(ctx, opts.machine, userResult.User, repository, locationResult.Location, apiClient) + machine, err := getMachineName(ctx, opts.machine, userResult.User, repository, branch, locationResult.Location, apiClient) if err != nil { return fmt.Errorf("error getting machine type: %v", err) } @@ -225,8 +225,8 @@ func getBranchName(branch string) (string, error) { } // getMachineName prompts the user to select the machine type, or validates the machine if non-empty. -func getMachineName(ctx context.Context, machine string, user *api.User, repo *api.Repository, location string, apiClient *api.API) (string, error) { - skus, err := apiClient.GetCodespacesSKUs(ctx, user, repo, location) +func getMachineName(ctx context.Context, machine string, user *api.User, repo *api.Repository, branch, location string, apiClient *api.API) (string, error) { + skus, err := apiClient.GetCodespacesSKUs(ctx, user, repo, branch, location) if err != nil { return "", fmt.Errorf("error getting Codespace SKUs: %v", err) } @@ -250,6 +250,10 @@ func getMachineName(ctx context.Context, machine string, user *api.User, repo *a return "", nil } + if len(skus) == 1 { + return skus[0].Name, nil // VS Code does not prompt for SKU if there is only one, this makes us consistent with that behavior + } + skuNames := make([]string, 0, len(skus)) skuByName := make(map[string]*api.SKU) for _, sku := range skus { From bfeb4e77c9063519547e1b9cf142fbc3f002e8a3 Mon Sep 17 00:00:00 2001 From: Jose Garcia Date: Wed, 1 Sep 2021 14:38:37 -0400 Subject: [PATCH 133/290] remove dir command --- cmd/ghcs/ssh.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cmd/ghcs/ssh.go b/cmd/ghcs/ssh.go index aa85c33a3..202083afc 100644 --- a/cmd/ghcs/ssh.go +++ b/cmd/ghcs/ssh.go @@ -158,7 +158,7 @@ func getContainerID(ctx context.Context, logger *output.Logger, terminal *livesh } func setupEnv(ctx context.Context, logger *output.Logger, terminal *liveshare.Terminal, containerID, repositoryName, containerUser string) error { - setupBashProfileCmd := fmt.Sprintf(`echo "cd /workspaces/%v; export $(cat /workspaces/.codespaces/shared/.env | xargs); exec /bin/zsh;" > /home/%v/.bash_profile`, repositoryName, containerUser) + setupBashProfileCmd := fmt.Sprintf(`echo "export $(cat /workspaces/.codespaces/shared/.env | xargs); exec /bin/zsh;" > /home/%v/.bash_profile`, repositoryName, containerUser) logger.Print(".") compositeCommand := []string{setupBashProfileCmd} From 49ccdd3d21a277094306a4462a6b19103e8575d1 Mon Sep 17 00:00:00 2001 From: Alan Donovan Date: Wed, 1 Sep 2021 17:26:26 -0400 Subject: [PATCH 134/290] use correct correct spelling of codespace --- api/api.go | 2 +- cmd/ghcs/code.go | 6 +++--- cmd/ghcs/create.go | 14 +++++++------- cmd/ghcs/delete.go | 24 ++++++++++++------------ cmd/ghcs/list.go | 4 ++-- cmd/ghcs/logs.go | 4 ++-- cmd/ghcs/ports.go | 14 +++++++------- cmd/ghcs/ssh.go | 6 +++--- internal/codespaces/codespaces.go | 24 ++++++++++++------------ internal/codespaces/states.go | 2 +- 10 files changed, 50 insertions(+), 50 deletions(-) diff --git a/api/api.go b/api/api.go index 7686d8f81..1cd69073e 100644 --- a/api/api.go +++ b/api/api.go @@ -282,7 +282,7 @@ func (a *API) StartCodespace(ctx context.Context, token string, codespace *Codes // HTTP 503 with error code 7 (EnvironmentNotShutdown) is benign. // Ignore it. } else { - return fmt.Errorf("failed to start Codespace: %s", b) + return fmt.Errorf("failed to start codespace: %s", b) } } diff --git a/cmd/ghcs/code.go b/cmd/ghcs/code.go index b3b43f050..5bad53648 100644 --- a/cmd/ghcs/code.go +++ b/cmd/ghcs/code.go @@ -17,7 +17,7 @@ func newCodeCmd() *cobra.Command { codeCmd := &cobra.Command{ Use: "code []", - Short: "Open a Codespace in VS Code", + Short: "Open a codespace in VS Code", Args: cobra.MaximumNArgs(1), RunE: func(cmd *cobra.Command, args []string) error { var codespaceName string @@ -52,14 +52,14 @@ func code(codespaceName string, useInsiders bool) error { if err == codespaces.ErrNoCodespaces { return err } - return fmt.Errorf("error choosing Codespace: %v", err) + return fmt.Errorf("error choosing codespace: %v", err) } codespaceName = codespace.Name } url := vscodeProtocolURL(codespaceName, useInsiders) if err := open.Run(url); err != nil { - return fmt.Errorf("error opening vscode URL %s: %s. (Is VSCode installed?)", url, err) + return fmt.Errorf("error opening vscode URL %s: %s. (Is VS Code installed?)", url, err) } return nil diff --git a/cmd/ghcs/create.go b/cmd/ghcs/create.go index f8276bc5e..55b74d6e7 100644 --- a/cmd/ghcs/create.go +++ b/cmd/ghcs/create.go @@ -27,7 +27,7 @@ func newCreateCmd() *cobra.Command { createCmd := &cobra.Command{ Use: "create", - Short: "Create a Codespace", + Short: "Create a codespace", Args: cobra.NoArgs, RunE: func(cmd *cobra.Command, args []string) error { return create(opts) @@ -69,12 +69,12 @@ func create(opts *createOptions) error { locationResult := <-locationCh if locationResult.Err != nil { - return fmt.Errorf("error getting Codespace region location: %v", locationResult.Err) + return fmt.Errorf("error getting codespace region location: %v", locationResult.Err) } userResult := <-userCh if userResult.Err != nil { - return fmt.Errorf("error getting Codespace user: %v", userResult.Err) + return fmt.Errorf("error getting codespace user: %v", userResult.Err) } machine, err := getMachineName(ctx, opts.machine, userResult.User, repository, branch, locationResult.Location, apiClient) @@ -85,11 +85,11 @@ func create(opts *createOptions) error { return errors.New("There are no available machine types for this repository") } - log.Println("Creating your Codespace...") + log.Println("Creating your codespace...") codespace, err := apiClient.CreateCodespace(ctx, userResult.User, repository, machine, branch, locationResult.Location) if err != nil { - return fmt.Errorf("error creating Codespace: %v", err) + return fmt.Errorf("error creating codespace: %v", err) } if opts.showStatus { @@ -154,7 +154,7 @@ func showStatus(ctx context.Context, log *output.Logger, apiClient *api.API, use } if err := codespaces.PollPostCreateStates(ctx, log, apiClient, user, codespace, poller); err != nil { - return fmt.Errorf("failed to poll state changes from Codespace: %v", err) + return fmt.Errorf("failed to poll state changes from codespace: %v", err) } return nil @@ -228,7 +228,7 @@ func getBranchName(branch string) (string, error) { func getMachineName(ctx context.Context, machine string, user *api.User, repo *api.Repository, branch, location string, apiClient *api.API) (string, error) { skus, err := apiClient.GetCodespacesSKUs(ctx, user, repo, branch, location) if err != nil { - return "", fmt.Errorf("error getting Codespace SKUs: %v", err) + return "", fmt.Errorf("error getting codespace SKUs: %v", err) } // if user supplied a machine type, it must be valid diff --git a/cmd/ghcs/delete.go b/cmd/ghcs/delete.go index c42f57f6f..92c405766 100644 --- a/cmd/ghcs/delete.go +++ b/cmd/ghcs/delete.go @@ -15,7 +15,7 @@ import ( func newDeleteCmd() *cobra.Command { deleteCmd := &cobra.Command{ Use: "delete []", - Short: "Delete a Codespace", + Short: "Delete a codespace", Args: cobra.MaximumNArgs(1), RunE: func(cmd *cobra.Command, args []string) error { var codespaceName string @@ -28,7 +28,7 @@ func newDeleteCmd() *cobra.Command { deleteAllCmd := &cobra.Command{ Use: "all", - Short: "Delete all Codespaces for the current user", + Short: "Delete all codespaces for the current user", Args: cobra.NoArgs, RunE: func(cmd *cobra.Command, args []string) error { return deleteAll() @@ -37,7 +37,7 @@ func newDeleteCmd() *cobra.Command { deleteByRepoCmd := &cobra.Command{ Use: "repo ", - Short: "Delete all Codespaces for a repository", + Short: "Delete all codespaces for a repository", Args: cobra.ExactArgs(1), RunE: func(cmd *cobra.Command, args []string) error { return deleteByRepo(args[0]) @@ -65,11 +65,11 @@ func delete_(codespaceName string) error { codespace, token, err := codespaces.GetOrChooseCodespace(ctx, apiClient, user, codespaceName) if err != nil { - return fmt.Errorf("get or choose Codespace: %v", err) + return fmt.Errorf("get or choose codespace: %v", err) } if err := apiClient.DeleteCodespace(ctx, user, token, codespace.Name); err != nil { - return fmt.Errorf("error deleting Codespace: %v", err) + return fmt.Errorf("error deleting codespace: %v", err) } log.Println("Codespace deleted.") @@ -89,17 +89,17 @@ func deleteAll() error { codespaces, err := apiClient.ListCodespaces(ctx, user) if err != nil { - return fmt.Errorf("error getting Codespaces: %v", err) + return fmt.Errorf("error getting codespaces: %v", err) } for _, c := range codespaces { token, err := apiClient.GetCodespaceToken(ctx, user.Login, c.Name) if err != nil { - return fmt.Errorf("error getting Codespace token: %v", err) + return fmt.Errorf("error getting codespace token: %v", err) } if err := apiClient.DeleteCodespace(ctx, user, token, c.Name); err != nil { - return fmt.Errorf("error deleting Codespace: %v", err) + return fmt.Errorf("error deleting codespace: %v", err) } log.Printf("Codespace deleted: %s\n", c.Name) @@ -120,7 +120,7 @@ func deleteByRepo(repo string) error { codespaces, err := apiClient.ListCodespaces(ctx, user) if err != nil { - return fmt.Errorf("error getting Codespaces: %v", err) + return fmt.Errorf("error getting codespaces: %v", err) } var deleted bool @@ -132,18 +132,18 @@ func deleteByRepo(repo string) error { token, err := apiClient.GetCodespaceToken(ctx, user.Login, c.Name) if err != nil { - return fmt.Errorf("error getting Codespace token: %v", err) + return fmt.Errorf("error getting codespace token: %v", err) } if err := apiClient.DeleteCodespace(ctx, user, token, c.Name); err != nil { - return fmt.Errorf("error deleting Codespace: %v", err) + return fmt.Errorf("error deleting codespace: %v", err) } log.Printf("Codespace deleted: %s\n", c.Name) } if !deleted { - return fmt.Errorf("No Codespace was found for repository: %s", repo) + return fmt.Errorf("No codespace was found for repository: %s", repo) } return list(&listOptions{}) diff --git a/cmd/ghcs/list.go b/cmd/ghcs/list.go index 0c055e98e..c6075a988 100644 --- a/cmd/ghcs/list.go +++ b/cmd/ghcs/list.go @@ -19,7 +19,7 @@ func newListCmd() *cobra.Command { listCmd := &cobra.Command{ Use: "list", - Short: "List your Codespaces", + Short: "List your codespaces", Args: cobra.NoArgs, RunE: func(cmd *cobra.Command, args []string) error { return list(opts) @@ -46,7 +46,7 @@ func list(opts *listOptions) error { codespaces, err := apiClient.ListCodespaces(ctx, user) if err != nil { - return fmt.Errorf("error getting Codespaces: %v", err) + return fmt.Errorf("error getting codespaces: %v", err) } table := output.NewTable(os.Stdout, opts.asJSON) diff --git a/cmd/ghcs/logs.go b/cmd/ghcs/logs.go index 34685e1e8..65e9dbcbe 100644 --- a/cmd/ghcs/logs.go +++ b/cmd/ghcs/logs.go @@ -16,7 +16,7 @@ func newLogsCmd() *cobra.Command { logsCmd := &cobra.Command{ Use: "logs []", - Short: "Access Codespace logs", + Short: "Access codespace logs", Args: cobra.MaximumNArgs(1), RunE: func(cmd *cobra.Command, args []string) error { var codespaceName string @@ -51,7 +51,7 @@ func logs(ctx context.Context, tail bool, codespaceName string) error { codespace, token, err := codespaces.GetOrChooseCodespace(ctx, apiClient, user, codespaceName) if err != nil { - return fmt.Errorf("get or choose Codespace: %v", err) + return fmt.Errorf("get or choose codespace: %v", err) } lsclient, err := codespaces.ConnectToLiveshare(ctx, log, apiClient, user.Login, token, codespace) diff --git a/cmd/ghcs/ports.go b/cmd/ghcs/ports.go index f83757ff8..792ff97e7 100644 --- a/cmd/ghcs/ports.go +++ b/cmd/ghcs/ports.go @@ -35,14 +35,14 @@ func newPortsCmd() *cobra.Command { portsCmd := &cobra.Command{ Use: "ports", - Short: "List ports in a Codespace", + Short: "List ports in a codespace", Args: cobra.NoArgs, RunE: func(cmd *cobra.Command, args []string) error { return ports(opts) }, } - portsCmd.Flags().StringVarP(&opts.codespaceName, "codespace", "c", "", "The `name` of the Codespace to use") + portsCmd.Flags().StringVarP(&opts.codespaceName, "codespace", "c", "", "The `name` of the codespace to use") portsCmd.Flags().BoolVar(&opts.asJSON, "json", false, "Output as JSON") portsCmd.AddCommand(newPortsPublicCmd()) @@ -71,7 +71,7 @@ func ports(opts *portsOptions) error { if err == codespaces.ErrNoCodespaces { return err } - return fmt.Errorf("error choosing Codespace: %v", err) + return fmt.Errorf("error choosing codespace: %v", err) } devContainerCh := getDevContainer(ctx, apiClient, codespace) @@ -211,12 +211,12 @@ func updatePortVisibility(log *output.Logger, codespaceName, sourcePort string, token, err := apiClient.GetCodespaceToken(ctx, user.Login, codespaceName) if err != nil { - return fmt.Errorf("error getting Codespace token: %v", err) + return fmt.Errorf("error getting codespace token: %v", err) } codespace, err := apiClient.GetCodespace(ctx, token, user.Login, codespaceName) if err != nil { - return fmt.Errorf("error getting Codespace: %v", err) + return fmt.Errorf("error getting codespace: %v", err) } lsclient, err := codespaces.ConnectToLiveshare(ctx, log, apiClient, user.Login, token, codespace) @@ -277,12 +277,12 @@ func forwardPorts(log *output.Logger, codespaceName string, ports []string) erro token, err := apiClient.GetCodespaceToken(ctx, user.Login, codespaceName) if err != nil { - return fmt.Errorf("error getting Codespace token: %v", err) + return fmt.Errorf("error getting codespace token: %v", err) } codespace, err := apiClient.GetCodespace(ctx, token, user.Login, codespaceName) if err != nil { - return fmt.Errorf("error getting Codespace: %v", err) + return fmt.Errorf("error getting codespace: %v", err) } lsclient, err := codespaces.ConnectToLiveshare(ctx, log, apiClient, user.Login, token, codespace) diff --git a/cmd/ghcs/ssh.go b/cmd/ghcs/ssh.go index aa85c33a3..2bc08bead 100644 --- a/cmd/ghcs/ssh.go +++ b/cmd/ghcs/ssh.go @@ -20,7 +20,7 @@ func newSSHCmd() *cobra.Command { sshCmd := &cobra.Command{ Use: "ssh", - Short: "SSH into a Codespace", + Short: "SSH into a codespace", Args: cobra.NoArgs, RunE: func(cmd *cobra.Command, args []string) error { return ssh(context.Background(), sshProfile, codespaceName, sshServerPort) @@ -29,7 +29,7 @@ func newSSHCmd() *cobra.Command { sshCmd.Flags().StringVarP(&sshProfile, "profile", "", "", "The `name` of the SSH profile to use") sshCmd.Flags().IntVarP(&sshServerPort, "server-port", "", 0, "SSH server port number") - sshCmd.Flags().StringVarP(&codespaceName, "codespace", "c", "", "The `name` of the Codespace to use") + sshCmd.Flags().StringVarP(&codespaceName, "codespace", "c", "", "The `name` of the codespace to use") return sshCmd } @@ -53,7 +53,7 @@ func ssh(ctx context.Context, sshProfile, codespaceName string, localSSHServerPo codespace, token, err := codespaces.GetOrChooseCodespace(ctx, apiClient, user, codespaceName) if err != nil { - return fmt.Errorf("get or choose Codespace: %v", err) + return fmt.Errorf("get or choose codespace: %v", err) } lsclient, err := codespaces.ConnectToLiveshare(ctx, log, apiClient, user.Login, token, codespace) diff --git a/internal/codespaces/codespaces.go b/internal/codespaces/codespaces.go index 90f676d28..f37c42ed3 100644 --- a/internal/codespaces/codespaces.go +++ b/internal/codespaces/codespaces.go @@ -13,13 +13,13 @@ import ( ) var ( - ErrNoCodespaces = errors.New("You have no Codespaces.") + ErrNoCodespaces = errors.New("You have no codespaces.") ) func ChooseCodespace(ctx context.Context, apiClient *api.API, user *api.User) (*api.Codespace, error) { codespaces, err := apiClient.ListCodespaces(ctx, user) if err != nil { - return nil, fmt.Errorf("error getting Codespaces: %v", err) + return nil, fmt.Errorf("error getting codespaces: %v", err) } if len(codespaces) == 0 { @@ -41,7 +41,7 @@ func ChooseCodespace(ctx context.Context, apiClient *api.API, user *api.User) (* { Name: "codespace", Prompt: &survey.Select{ - Message: "Choose Codespace:", + Message: "Choose codespace:", Options: codespacesNames, Default: codespacesNames[0], }, @@ -77,9 +77,9 @@ func ConnectToLiveshare(ctx context.Context, log logger, apiClient *api.API, use var startedCodespace bool if codespace.Environment.State != api.CodespaceEnvironmentStateAvailable { startedCodespace = true - log.Print("Starting your Codespace...") + log.Print("Starting your codespace...") if err := apiClient.StartCodespace(ctx, token, codespace); err != nil { - return nil, fmt.Errorf("error starting Codespace: %v", err) + return nil, fmt.Errorf("error starting codespace: %v", err) } } @@ -93,12 +93,12 @@ func ConnectToLiveshare(ctx context.Context, log logger, apiClient *api.API, use } if retries == 30 { - return nil, errors.New("timed out while waiting for the Codespace to start") + return nil, errors.New("timed out while waiting for the codespace to start") } codespace, err = apiClient.GetCodespace(ctx, token, userLogin, codespace.Name) if err != nil { - return nil, fmt.Errorf("error getting Codespace: %v", err) + return nil, fmt.Errorf("error getting codespace: %v", err) } } @@ -106,7 +106,7 @@ func ConnectToLiveshare(ctx context.Context, log logger, apiClient *api.API, use fmt.Print("\n") } - log.Println("Connecting to your Codespace...") + log.Println("Connecting to your codespace...") lsclient, err := liveshare.NewClient( liveshare.WithConnection(liveshare.Connection{ @@ -134,23 +134,23 @@ func GetOrChooseCodespace(ctx context.Context, apiClient *api.API, user *api.Use if err == ErrNoCodespaces { return nil, "", err } - return nil, "", fmt.Errorf("choosing Codespace: %v", err) + return nil, "", fmt.Errorf("choosing codespace: %v", err) } codespaceName = codespace.Name token, err = apiClient.GetCodespaceToken(ctx, user.Login, codespaceName) if err != nil { - return nil, "", fmt.Errorf("getting Codespace token: %v", err) + return nil, "", fmt.Errorf("getting codespace token: %v", err) } } else { token, err = apiClient.GetCodespaceToken(ctx, user.Login, codespaceName) if err != nil { - return nil, "", fmt.Errorf("getting Codespace token for given codespace: %v", err) + return nil, "", fmt.Errorf("getting codespace token for given codespace: %v", err) } codespace, err = apiClient.GetCodespace(ctx, token, user.Login, codespaceName) if err != nil { - return nil, "", fmt.Errorf("getting full Codespace details: %v", err) + return nil, "", fmt.Errorf("getting full codespace details: %v", err) } } diff --git a/internal/codespaces/states.go b/internal/codespaces/states.go index a58e2b235..ce242a69b 100644 --- a/internal/codespaces/states.go +++ b/internal/codespaces/states.go @@ -37,7 +37,7 @@ type PostCreateState struct { func PollPostCreateStates(ctx context.Context, log logger, apiClient *api.API, user *api.User, codespace *api.Codespace, poller func([]PostCreateState)) error { token, err := apiClient.GetCodespaceToken(ctx, user.Login, codespace.Name) if err != nil { - return fmt.Errorf("getting Codespace token: %v", err) + return fmt.Errorf("getting codespace token: %v", err) } lsclient, err := ConnectToLiveshare(ctx, log, apiClient, user.Login, token, codespace) From 72a2099a50bea5862ad3597833fc247ff94e0679 Mon Sep 17 00:00:00 2001 From: Alan Donovan Date: Wed, 1 Sep 2021 17:50:24 -0400 Subject: [PATCH 135/290] fix breakage from API changes --- cmd/ghcs/logs.go | 2 +- cmd/ghcs/ports.go | 2 +- cmd/ghcs/ssh.go | 2 +- internal/codespaces/states.go | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/cmd/ghcs/logs.go b/cmd/ghcs/logs.go index 34685e1e8..829ba9a31 100644 --- a/cmd/ghcs/logs.go +++ b/cmd/ghcs/logs.go @@ -88,7 +88,7 @@ func logs(ctx context.Context, tail bool, codespaceName string) error { tunnelClosed := make(chan error, 1) go func() { - tunnelClosed <- tunnel.Start(ctx) // error is non-nil + tunnelClosed <- tunnel.Forward(ctx) // error is non-nil }() cmdDone := make(chan error, 1) diff --git a/cmd/ghcs/ports.go b/cmd/ghcs/ports.go index f83757ff8..695c4e491 100644 --- a/cmd/ghcs/ports.go +++ b/cmd/ghcs/ports.go @@ -307,7 +307,7 @@ func forwardPorts(log *output.Logger, codespaceName string, ports []string) erro g.Go(func() error { log.Println("Forwarding port: " + srcstr + " ==> " + strconv.Itoa(pp.dst)) portForwarder := liveshare.NewPortForwarder(lsclient, server, pp.dst) - if err := portForwarder.Start(gctx); err != nil { + if err := portForwarder.Forward(gctx); err != nil { return fmt.Errorf("error forwarding port: %v", err) } diff --git a/cmd/ghcs/ssh.go b/cmd/ghcs/ssh.go index aa85c33a3..c6d150360 100644 --- a/cmd/ghcs/ssh.go +++ b/cmd/ghcs/ssh.go @@ -105,7 +105,7 @@ func ssh(ctx context.Context, sshProfile, codespaceName string, localSSHServerPo tunnelClosed := make(chan error) go func() { - tunnelClosed <- tunnel.Start(ctx) // error is always non-nil + tunnelClosed <- tunnel.Forward(ctx) // error is always non-nil }() shellClosed := make(chan error) diff --git a/internal/codespaces/states.go b/internal/codespaces/states.go index a58e2b235..5a5d72d6c 100644 --- a/internal/codespaces/states.go +++ b/internal/codespaces/states.go @@ -62,7 +62,7 @@ func PollPostCreateStates(ctx context.Context, log logger, apiClient *api.API, u tunnelClosed := make(chan error, 1) // buffered to avoid sender stuckness go func() { - tunnelClosed <- fwd.Start(ctx) // error is non-nil + tunnelClosed <- fwd.Forward(ctx) // error is non-nil }() t := time.NewTicker(1 * time.Second) From af38292f1e0a80e0ef6d996f0d67aee0452c7232 Mon Sep 17 00:00:00 2001 From: Alan Donovan Date: Wed, 1 Sep 2021 18:12:23 -0400 Subject: [PATCH 136/290] fix data races --- rpc.go | 47 +++++++++++++++++++---------------------------- rpc_test.go | 5 ++++- terminal.go | 15 +++++++++++---- 3 files changed, 34 insertions(+), 33 deletions(-) diff --git a/rpc.go b/rpc.go index 3fea63b10..d1e020d17 100644 --- a/rpc.go +++ b/rpc.go @@ -33,45 +33,36 @@ func (r *rpcClient) do(ctx context.Context, method string, args, result interfac return waiter.Wait(ctx, result) } +type rpcHandlerFunc = func(*jsonrpc2.Request) + type rpcHandler struct { - mutex sync.Mutex - eventHandlers map[string][]chan *jsonrpc2.Request + handlersMu sync.Mutex + handlers map[string][]rpcHandlerFunc } func newRPCHandler() *rpcHandler { return &rpcHandler{ - eventHandlers: make(map[string][]chan *jsonrpc2.Request), + handlers: make(map[string][]rpcHandlerFunc), } } -// TODO: document obligations around chan. It appears to be used for at most one request. -func (r *rpcHandler) registerEventHandler(eventMethod string) <-chan *jsonrpc2.Request { - r.mutex.Lock() - defer r.mutex.Unlock() - - ch := make(chan *jsonrpc2.Request) - r.eventHandlers[eventMethod] = append(r.eventHandlers[eventMethod], ch) - return ch +// registerEventHandler registers a handler for the specified event. +// After the next occurrence of the event, the handler will be called, +// once, in its own goroutine. +func (r *rpcHandler) registerEventHandler(eventMethod string, h rpcHandlerFunc) { + r.handlersMu.Lock() + r.handlers[eventMethod] = append(r.handlers[eventMethod], h) + r.handlersMu.Unlock() } +// Handle calls all registered handlers for the request, concurrently, each in its own goroutine. func (r *rpcHandler) Handle(ctx context.Context, conn *jsonrpc2.Conn, req *jsonrpc2.Request) { - r.mutex.Lock() - handlers := r.eventHandlers[req.Method] - r.eventHandlers[req.Method] = nil - r.mutex.Unlock() + r.handlersMu.Lock() + handlers := r.handlers[req.Method] + r.handlers[req.Method] = nil + r.handlersMu.Unlock() - if len(handlers) > 0 { - go func() { - // Broadcast the request to each handler in sequence. - // TODO rethink this. needs function call. - for _, handler := range handlers { - select { - case handler <- req: - case <-ctx.Done(): - // TODO: ctx.Err - break - } - } - }() + for _, h := range handlers { + go h(req) } } diff --git a/rpc_test.go b/rpc_test.go index 7543152d1..cf9c4cf81 100644 --- a/rpc_test.go +++ b/rpc_test.go @@ -10,7 +10,10 @@ import ( func TestRPCHandlerEvents(t *testing.T) { rpcHandler := newRPCHandler() - eventCh := rpcHandler.registerEventHandler("somethingHappened") + eventCh := make(chan *jsonrpc2.Request) + rpcHandler.registerEventHandler("somethingHappened", func(req *jsonrpc2.Request) { + eventCh <- req + }) go func() { time.Sleep(1 * time.Second) rpcHandler.Handle(context.Background(), nil, &jsonrpc2.Request{Method: "somethingHappened"}) diff --git a/terminal.go b/terminal.go index c26d9fd9f..32dd54248 100644 --- a/terminal.go +++ b/terminal.go @@ -6,6 +6,7 @@ import ( "fmt" "io" + "github.com/sourcegraph/jsonrpc2" "golang.org/x/crypto/ssh" ) @@ -71,12 +72,15 @@ func (t TerminalCommand) Run(ctx context.Context) (io.ReadCloser, error) { ReadOnlyForGuests: false, } - terminalStarted := t.terminal.client.rpc.handler.registerEventHandler("terminal.terminalStarted") + started := make(chan struct{}) + t.terminal.client.rpc.handler.registerEventHandler("terminal.terminalStarted", func(*jsonrpc2.Request) { + close(started) + }) var result startTerminalResult if err := t.terminal.client.rpc.do(ctx, "terminal.startTerminal", &args, &result); err != nil { return nil, fmt.Errorf("error making terminal.startTerminal call: %v", err) } - <-terminalStarted + <-started channel, err := t.terminal.client.openStreamingChannel(ctx, result.StreamName, result.StreamCondition) if err != nil { @@ -101,7 +105,10 @@ func (t terminalReadCloser) Read(b []byte) (int, error) { } func (t terminalReadCloser) Close() error { - terminalStopped := t.terminalCommand.terminal.client.rpc.handler.registerEventHandler("terminal.terminalStopped") + stopped := make(chan struct{}) + t.terminalCommand.terminal.client.rpc.handler.registerEventHandler("terminal.terminalStopped", func(*jsonrpc2.Request) { + close(stopped) + }) if err := t.terminalCommand.terminal.client.rpc.do(context.Background(), "terminal.stopTerminal", []int{t.terminalID}, nil); err != nil { return fmt.Errorf("error making terminal.stopTerminal call: %v", err) } @@ -110,7 +117,7 @@ func (t terminalReadCloser) Close() error { return fmt.Errorf("error closing channel: %v", err) } - <-terminalStopped + <-stopped return nil } From c31fc05746b02ace70283197b33fd3f7b6d0866a Mon Sep 17 00:00:00 2001 From: Alan Donovan Date: Thu, 2 Sep 2021 09:09:05 -0400 Subject: [PATCH 137/290] more typo fixes --- api/api.go | 2 +- cmd/ghcs/create.go | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/api/api.go b/api/api.go index 1cd69073e..08ca2c370 100644 --- a/api/api.go +++ b/api/api.go @@ -330,7 +330,7 @@ type SKU struct { func (a *API) GetCodespacesSKUs(ctx context.Context, user *User, repository *Repository, branch, location string) ([]*SKU, error) { req, err := http.NewRequest(http.MethodGet, githubAPI+"/vscs_internal/user/"+user.Login+"/skus", nil) if err != nil { - return nil, fmt.Errorf("err creating request: %v", err) + return nil, fmt.Errorf("error creating request: %v", err) } q := req.URL.Query() diff --git a/cmd/ghcs/create.go b/cmd/ghcs/create.go index 55b74d6e7..bd1d89e4e 100644 --- a/cmd/ghcs/create.go +++ b/cmd/ghcs/create.go @@ -228,7 +228,7 @@ func getBranchName(branch string) (string, error) { func getMachineName(ctx context.Context, machine string, user *api.User, repo *api.Repository, branch, location string, apiClient *api.API) (string, error) { skus, err := apiClient.GetCodespacesSKUs(ctx, user, repo, branch, location) if err != nil { - return "", fmt.Errorf("error getting codespace SKUs: %v", err) + return "", fmt.Errorf("error requesting machine instance types: %v", err) } // if user supplied a machine type, it must be valid From 4cceda1af02e3a097418baadd14048d331780c50 Mon Sep 17 00:00:00 2001 From: Alan Donovan Date: Thu, 2 Sep 2021 11:06:49 -0400 Subject: [PATCH 138/290] rename Server to Session and simplify API --- client.go | 48 +++++++++------------ client_test.go | 8 ++-- port_forwarder.go | 19 ++++----- port_forwarder_test.go | 21 +++------- rpc.go | 1 + server.go => session.go | 34 +++++---------- server_test.go => session_test.go | 70 +++++++------------------------ ssh.go | 2 +- ssh_server.go | 18 ++++---- terminal.go | 23 ++++------ 10 files changed, 82 insertions(+), 162 deletions(-) rename server.go => session.go (58%) rename server_test.go => session_test.go (74%) diff --git a/client.go b/client.go index fe890f0bf..140db3b02 100644 --- a/client.go +++ b/client.go @@ -8,13 +8,10 @@ import ( "golang.org/x/crypto/ssh" ) -// A Client capable of joining a liveshare connection +// A Client capable of joining a Live Share workspace. type Client struct { connection Connection tlsConfig *tls.Config - - ssh *sshSession - rpc *rpcClient } // A ClientOption is a function that modifies a client @@ -52,31 +49,26 @@ func WithTLSConfig(tlsConfig *tls.Config) ClientOption { } } -// Join is a method that joins the client to the liveshare session -func (c *Client) Join(ctx context.Context) (err error) { +// JoinWorkspace connects the client to the server's Live Share +// workspace and returns a session representing their connection. +func (c *Client) JoinWorkspace(ctx context.Context) (*Session, error) { clientSocket := newSocket(c.connection, c.tlsConfig) if err := clientSocket.connect(ctx); err != nil { - return fmt.Errorf("error connecting websocket: %v", err) + return nil, fmt.Errorf("error connecting websocket: %v", err) } - c.ssh = newSshSession(c.connection.SessionToken, clientSocket) - if err := c.ssh.connect(ctx); err != nil { - return fmt.Errorf("error connecting to ssh session: %v", err) + ssh := newSSHSession(c.connection.SessionToken, clientSocket) + if err := ssh.connect(ctx); err != nil { + return nil, fmt.Errorf("error connecting to ssh session: %v", err) } - c.rpc = newRpcClient(c.ssh) - c.rpc.connect(ctx) - - _, err = c.joinWorkspace(ctx) - if err != nil { - return fmt.Errorf("error joining Live Share workspace: %v", err) + rpc := newRpcClient(ssh) + rpc.connect(ctx) + if _, err := c.joinWorkspace(ctx, rpc); err != nil { + return nil, fmt.Errorf("error joining Live Share workspace: %v", err) } - return nil -} - -func (c *Client) hasJoined() bool { - return c.ssh != nil && c.rpc != nil + return &Session{ssh: ssh, rpc: rpc}, nil } type clientCapabilities struct { @@ -94,32 +86,32 @@ type joinWorkspaceResult struct { SessionNumber int `json:"sessionNumber"` } -func (c *Client) joinWorkspace(ctx context.Context) (*joinWorkspaceResult, error) { +func (client *Client) joinWorkspace(ctx context.Context, rpc *rpcClient) (*joinWorkspaceResult, error) { args := joinWorkspaceArgs{ - ID: c.connection.SessionID, + ID: client.connection.SessionID, ConnectionMode: "local", - JoiningUserSessionToken: c.connection.SessionToken, + JoiningUserSessionToken: client.connection.SessionToken, ClientCapabilities: clientCapabilities{ IsNonInteractive: false, }, } var result joinWorkspaceResult - if err := c.rpc.do(ctx, "workspace.joinWorkspace", &args, &result); err != nil { + if err := rpc.do(ctx, "workspace.joinWorkspace", &args, &result); err != nil { return nil, fmt.Errorf("error making workspace.joinWorkspace call: %v", err) } return &result, nil } -func (c *Client) openStreamingChannel(ctx context.Context, streamName, condition string) (ssh.Channel, error) { +func (session *Session) openStreamingChannel(ctx context.Context, streamName, condition string) (ssh.Channel, error) { args := getStreamArgs{streamName, condition} var streamID string - if err := c.rpc.do(ctx, "streamManager.getStream", args, &streamID); err != nil { + if err := session.rpc.do(ctx, "streamManager.getStream", args, &streamID); err != nil { return nil, fmt.Errorf("error getting stream id: %v", err) } - channel, reqs, err := c.ssh.conn.OpenChannel("session", nil) + channel, reqs, err := session.ssh.conn.OpenChannel("session", nil) if err != nil { return nil, fmt.Errorf("error opening ssh channel for transport: %v", err) } diff --git a/client_test.go b/client_test.go index f1591ed51..c1e61f6e8 100644 --- a/client_test.go +++ b/client_test.go @@ -43,7 +43,7 @@ func TestNewClientWithInvalidConnection(t *testing.T) { } } -func TestClientJoin(t *testing.T) { +func TestJoinSession(t *testing.T) { connection := Connection{ SessionID: "session-id", SessionToken: "session-token", @@ -90,10 +90,12 @@ func TestClientJoin(t *testing.T) { done := make(chan error) go func() { - if err := client.Join(ctx); err != nil { - done <- fmt.Errorf("error joining client: %v", err) + session, err := client.JoinWorkspace(ctx) + if err != nil { + done <- fmt.Errorf("error joining workspace: %v", err) return } + _ = session done <- nil }() diff --git a/port_forwarder.go b/port_forwarder.go index cc7b6ea1d..29dee58f9 100644 --- a/port_forwarder.go +++ b/port_forwarder.go @@ -7,20 +7,17 @@ import ( "net" ) -// A PortForwarder forwards TCP traffic between a port on a remote -// LiveShare host and a local port. +// A PortForwarder forwards TCP traffic between a local TCP port and a LiveShare session. type PortForwarder struct { - client *Client - server *Server - port int + session *Session + port int } -// NewPortForwarder creates a new PortForwarder that connects a given client, server and port. -func NewPortForwarder(client *Client, server *Server, port int) *PortForwarder { +// NewPortForwarder creates a new PortForwarder for a given Live Share session and local TCP port. +func NewPortForwarder(session *Session, port int) *PortForwarder { return &PortForwarder{ - client: client, - server: server, - port: port, + session: session, + port: port, } } @@ -87,7 +84,7 @@ func awaitError(ctx context.Context, errc <-chan error) error { func (l *PortForwarder) handleConnection(ctx context.Context, conn io.ReadWriteCloser) (err error) { defer safeClose(conn, &err) - channel, err := l.client.openStreamingChannel(ctx, l.server.streamName, l.server.streamCondition) + channel, err := l.session.openStreamingChannel(ctx, l.session.streamName, l.session.streamCondition) if err != nil { return fmt.Errorf("error opening streaming channel for new connection: %v", err) } diff --git a/port_forwarder_test.go b/port_forwarder_test.go index 3ae846937..6af5c7e70 100644 --- a/port_forwarder_test.go +++ b/port_forwarder_test.go @@ -15,16 +15,12 @@ import ( ) func TestNewPortForwarder(t *testing.T) { - testServer, client, err := makeMockJoinedClient() + testServer, session, err := makeMockSession() if err != nil { t.Errorf("create mock client: %v", err) } defer testServer.Close() - server, err := NewServer(client) - if err != nil { - t.Errorf("create new server: %v", err) - } - pf := NewPortForwarder(client, server, 80) + pf := NewPortForwarder(session, 80) if pf == nil { t.Error("port forwarder is nil") } @@ -40,27 +36,22 @@ func TestPortForwarderStart(t *testing.T) { } stream := bytes.NewBufferString("stream-data") - testServer, client, err := makeMockJoinedClient( + testServer, session, err := makeMockSession( livesharetest.WithService("serverSharing.startSharing", serverSharing), livesharetest.WithService("streamManager.getStream", getStream), livesharetest.WithStream("stream-id", stream), ) if err != nil { - t.Errorf("create mock client: %v", err) + t.Errorf("create mock session: %v", err) } defer testServer.Close() - server, err := NewServer(client) - if err != nil { - t.Errorf("create new server: %v", err) - } - ctx, _ := context.WithCancel(context.Background()) - pf := NewPortForwarder(client, server, 8000) + pf := NewPortForwarder(session, 8000) done := make(chan error) go func() { - if err := server.StartSharing(ctx, "http", 8000); err != nil { + if err := session.StartSharing(ctx, "http", 8000); err != nil { done <- fmt.Errorf("start sharing: %v", err) } done <- pf.Forward(ctx) diff --git a/rpc.go b/rpc.go index 8abd0e98f..c58ab419d 100644 --- a/rpc.go +++ b/rpc.go @@ -21,6 +21,7 @@ func newRpcClient(conn io.ReadWriteCloser) *rpcClient { func (r *rpcClient) connect(ctx context.Context) { stream := jsonrpc2.NewBufferedStream(r.conn, jsonrpc2.VSCodeObjectCodec{}) + // TODO(adonovan): fix: ensure r.Conn is eventually Closed! r.Conn = jsonrpc2.NewConn(ctx, stream, r.handler) } diff --git a/server.go b/session.go similarity index 58% rename from server.go rename to session.go index 7e8c8b1cb..b1a175df3 100644 --- a/server.go +++ b/session.go @@ -2,27 +2,18 @@ package liveshare import ( "context" - "errors" "fmt" "strconv" ) -// A Server represents the liveshare host and container server -type Server struct { - client *Client +// A Session represents the session between a connected Live Share client and server. +type Session struct { + ssh *sshSession + rpc *rpcClient port int streamName, streamCondition string } -// NewServer creates a new Server with a given Client -func NewServer(client *Client) (*Server, error) { - if !client.hasJoined() { - return nil, errors.New("client must join before creating server") - } - - return &Server{client: client}, nil -} - // Port represents an open port on the container type Port struct { SourcePort int `json:"sourcePort"` @@ -37,11 +28,11 @@ type Port struct { } // StartSharing tells the liveshare host to start sharing the port from the container -func (s *Server) StartSharing(ctx context.Context, protocol string, port int) error { +func (s *Session) StartSharing(ctx context.Context, protocol string, port int) error { s.port = port var response Port - if err := s.client.rpc.do(ctx, "serverSharing.startSharing", []interface{}{ + if err := s.rpc.do(ctx, "serverSharing.startSharing", []interface{}{ port, protocol, fmt.Sprintf("http://localhost:%s", strconv.Itoa(port)), }, &response); err != nil { return err @@ -53,13 +44,10 @@ func (s *Server) StartSharing(ctx context.Context, protocol string, port int) er return nil } -// Ports is a slice of Port pointers -type Ports []*Port - // GetSharedServers returns a list of available/open ports from the container -func (s *Server) GetSharedServers(ctx context.Context) (Ports, error) { - var response Ports - if err := s.client.rpc.do(ctx, "serverSharing.getSharedServers", []string{}, &response); err != nil { +func (s *Session) GetSharedServers(ctx context.Context) ([]*Port, error) { + var response []*Port + if err := s.rpc.do(ctx, "serverSharing.getSharedServers", []string{}, &response); err != nil { return nil, err } @@ -68,8 +56,8 @@ func (s *Server) GetSharedServers(ctx context.Context) (Ports, error) { // UpdateSharedVisibility controls port permissions and whether it can be accessed publicly // via the Browse URL -func (s *Server) UpdateSharedVisibility(ctx context.Context, port int, public bool) error { - if err := s.client.rpc.do(ctx, "serverSharing.updateSharedServerVisibility", []interface{}{port, public}, nil); err != nil { +func (s *Session) UpdateSharedVisibility(ctx context.Context, port int, public bool) error { + if err := s.rpc.do(ctx, "serverSharing.updateSharedServerVisibility", []interface{}{port, public}, nil); err != nil { return err } diff --git a/server_test.go b/session_test.go similarity index 74% rename from server_test.go rename to session_test.go index b91fbfddc..005eacfbd 100644 --- a/server_test.go +++ b/session_test.go @@ -13,17 +13,7 @@ import ( "github.com/sourcegraph/jsonrpc2" ) -func TestNewServerWithNotJoinedClient(t *testing.T) { - client, err := NewClient() - if err != nil { - t.Errorf("error creating new client: %v", err) - } - if _, err := NewServer(client); err == nil { - t.Error("expected error") - } -} - -func makeMockJoinedClient(opts ...livesharetest.ServerOption) (*livesharetest.Server, *Client, error) { +func makeMockSession(opts ...livesharetest.ServerOption) (*livesharetest.Server, *Session, error) { connection := Connection{ SessionID: "session-id", SessionToken: "session-token", @@ -47,25 +37,11 @@ func makeMockJoinedClient(opts ...livesharetest.ServerOption) (*livesharetest.Se return nil, nil, fmt.Errorf("error creating new client: %v", err) } ctx := context.Background() - if err := client.Join(ctx); err != nil { - return nil, nil, fmt.Errorf("error joining client: %v", err) - } - return testServer, client, nil -} - -func TestNewServer(t *testing.T) { - testServer, client, err := makeMockJoinedClient() - defer testServer.Close() + session, err := client.JoinWorkspace(ctx) if err != nil { - t.Errorf("error creating mock joined client: %v", err) - } - server, err := NewServer(client) - if err != nil { - t.Errorf("error creating new server: %v", err) - } - if server == nil { - t.Error("server is nil") + return nil, nil, fmt.Errorf("error joining workspace: %v", err) } + return testServer, session, nil } func TestServerStartSharing(t *testing.T) { @@ -95,25 +71,21 @@ func TestServerStartSharing(t *testing.T) { } return Port{StreamName: "stream-name", StreamCondition: "stream-condition"}, nil } - testServer, client, err := makeMockJoinedClient( + testServer, session, err := makeMockSession( livesharetest.WithService("serverSharing.startSharing", startSharing), ) defer testServer.Close() if err != nil { - t.Errorf("error creating mock joined client: %v", err) - } - server, err := NewServer(client) - if err != nil { - t.Errorf("error creating new server: %v", err) + t.Errorf("error creating mock session: %v", err) } ctx := context.Background() done := make(chan error) go func() { - if err := server.StartSharing(ctx, serverProtocol, serverPort); err != nil { + if err := session.StartSharing(ctx, serverProtocol, serverPort); err != nil { done <- fmt.Errorf("error sharing server: %v", err) } - if server.streamName == "" || server.streamCondition == "" { + if session.streamName == "" || session.streamCondition == "" { done <- errors.New("stream name or condition is blank") } done <- nil @@ -136,23 +108,19 @@ func TestServerGetSharedServers(t *testing.T) { StreamCondition: "stream-condition", } getSharedServers := func(req *jsonrpc2.Request) (interface{}, error) { - return Ports{&sharedServer}, nil + return []*Port{&sharedServer}, nil } - testServer, client, err := makeMockJoinedClient( + testServer, session, err := makeMockSession( livesharetest.WithService("serverSharing.getSharedServers", getSharedServers), ) if err != nil { - t.Errorf("error creating new mock client: %v", err) + t.Errorf("error creating mock session: %v", err) } defer testServer.Close() - server, err := NewServer(client) - if err != nil { - t.Errorf("error creating new server: %v", err) - } ctx := context.Background() done := make(chan error) go func() { - ports, err := server.GetSharedServers(ctx) + ports, err := session.GetSharedServers(ctx) if err != nil { done <- fmt.Errorf("error getting shared servers: %v", err) } @@ -206,25 +174,17 @@ func TestServerUpdateSharedVisibility(t *testing.T) { } return nil, nil } - testServer, client, err := makeMockJoinedClient( + testServer, session, err := makeMockSession( livesharetest.WithService("serverSharing.updateSharedServerVisibility", updateSharedVisibility), ) if err != nil { - t.Errorf("creating new mock client: %v", err) + t.Errorf("creating mock session: %v", err) } defer testServer.Close() - server, err := NewServer(client) - if err != nil { - t.Errorf("creating server: %v", err) - } ctx := context.Background() done := make(chan error) go func() { - if err := server.UpdateSharedVisibility(ctx, 80, true); err != nil { - done <- err - return - } - done <- nil + done <- session.UpdateSharedVisibility(ctx, 80, true) }() select { case err := <-testServer.Err(): diff --git a/ssh.go b/ssh.go index e22cd69d1..b68d400a1 100644 --- a/ssh.go +++ b/ssh.go @@ -19,7 +19,7 @@ type sshSession struct { writer io.Writer } -func newSshSession(token string, socket net.Conn) *sshSession { +func newSSHSession(token string, socket net.Conn) *sshSession { return &sshSession{token: token, socket: socket} } diff --git a/ssh_server.go b/ssh_server.go index ec7d8dfd1..03b45f25f 100644 --- a/ssh_server.go +++ b/ssh_server.go @@ -2,18 +2,14 @@ package liveshare import ( "context" - "errors" ) type SSHServer struct { - client *Client + session *Session } -func NewSSHServer(client *Client) (*SSHServer, error) { - if !client.hasJoined() { - return nil, errors.New("client must join before creating server") - } - return &SSHServer{client: client}, nil +func (session *Session) SSHServer() *SSHServer { + return &SSHServer{session: session} } type SSHServerStartResult struct { @@ -23,12 +19,12 @@ type SSHServerStartResult struct { Message string `json:"message"` } -func (s *SSHServer) StartRemoteServer(ctx context.Context) (SSHServerStartResult, error) { +func (s *SSHServer) StartRemoteServer(ctx context.Context) (*SSHServerStartResult, error) { var response SSHServerStartResult - if err := s.client.rpc.do(ctx, "ISshServerHostService.startRemoteServer", []string{}, &response); err != nil { - return response, err + if err := s.session.rpc.do(ctx, "ISshServerHostService.startRemoteServer", []string{}, &response); err != nil { + return nil, err } - return response, nil + return &response, nil } diff --git a/terminal.go b/terminal.go index c26d9fd9f..07532f426 100644 --- a/terminal.go +++ b/terminal.go @@ -2,7 +2,6 @@ package liveshare import ( "context" - "errors" "fmt" "io" @@ -10,17 +9,11 @@ import ( ) type Terminal struct { - client *Client + session *Session } -func NewTerminal(client *Client) (*Terminal, error) { - if !client.hasJoined() { - return nil, errors.New("client must join before creating terminal") - } - - return &Terminal{ - client: client, - }, nil +func NewTerminal(session *Session) *Terminal { + return &Terminal{session: session} } type TerminalCommand struct { @@ -71,14 +64,14 @@ func (t TerminalCommand) Run(ctx context.Context) (io.ReadCloser, error) { ReadOnlyForGuests: false, } - terminalStarted := t.terminal.client.rpc.handler.registerEventHandler("terminal.terminalStarted") + terminalStarted := t.terminal.session.rpc.handler.registerEventHandler("terminal.terminalStarted") var result startTerminalResult - if err := t.terminal.client.rpc.do(ctx, "terminal.startTerminal", &args, &result); err != nil { + if err := t.terminal.session.rpc.do(ctx, "terminal.startTerminal", &args, &result); err != nil { return nil, fmt.Errorf("error making terminal.startTerminal call: %v", err) } <-terminalStarted - channel, err := t.terminal.client.openStreamingChannel(ctx, result.StreamName, result.StreamCondition) + channel, err := t.terminal.session.openStreamingChannel(ctx, result.StreamName, result.StreamCondition) if err != nil { return nil, fmt.Errorf("error opening streaming channel: %v", err) } @@ -101,8 +94,8 @@ func (t terminalReadCloser) Read(b []byte) (int, error) { } func (t terminalReadCloser) Close() error { - terminalStopped := t.terminalCommand.terminal.client.rpc.handler.registerEventHandler("terminal.terminalStopped") - if err := t.terminalCommand.terminal.client.rpc.do(context.Background(), "terminal.stopTerminal", []int{t.terminalID}, nil); err != nil { + terminalStopped := t.terminalCommand.terminal.session.rpc.handler.registerEventHandler("terminal.terminalStopped") + if err := t.terminalCommand.terminal.session.rpc.do(context.Background(), "terminal.stopTerminal", []int{t.terminalID}, nil); err != nil { return fmt.Errorf("error making terminal.stopTerminal call: %v", err) } From 05a3d90a99b4f884a492c797f352458519d1252a Mon Sep 17 00:00:00 2001 From: Alan Donovan Date: Thu, 2 Sep 2021 11:39:29 -0400 Subject: [PATCH 139/290] more tweaks --- client.go | 12 ++++++------ port_forwarder_test.go | 3 ++- rpc_test.go | 3 ++- session.go | 5 +++-- 4 files changed, 13 insertions(+), 10 deletions(-) diff --git a/client.go b/client.go index 140db3b02..19b0aff50 100644 --- a/client.go +++ b/client.go @@ -86,11 +86,11 @@ type joinWorkspaceResult struct { SessionNumber int `json:"sessionNumber"` } -func (client *Client) joinWorkspace(ctx context.Context, rpc *rpcClient) (*joinWorkspaceResult, error) { +func (c *Client) joinWorkspace(ctx context.Context, rpc *rpcClient) (*joinWorkspaceResult, error) { args := joinWorkspaceArgs{ - ID: client.connection.SessionID, + ID: c.connection.SessionID, ConnectionMode: "local", - JoiningUserSessionToken: client.connection.SessionToken, + JoiningUserSessionToken: c.connection.SessionToken, ClientCapabilities: clientCapabilities{ IsNonInteractive: false, }, @@ -104,14 +104,14 @@ func (client *Client) joinWorkspace(ctx context.Context, rpc *rpcClient) (*joinW return &result, nil } -func (session *Session) openStreamingChannel(ctx context.Context, streamName, condition string) (ssh.Channel, error) { +func (s *Session) openStreamingChannel(ctx context.Context, streamName, condition string) (ssh.Channel, error) { args := getStreamArgs{streamName, condition} var streamID string - if err := session.rpc.do(ctx, "streamManager.getStream", args, &streamID); err != nil { + if err := s.rpc.do(ctx, "streamManager.getStream", args, &streamID); err != nil { return nil, fmt.Errorf("error getting stream id: %v", err) } - channel, reqs, err := session.ssh.conn.OpenChannel("session", nil) + channel, reqs, err := s.ssh.conn.OpenChannel("session", nil) if err != nil { return nil, fmt.Errorf("error opening ssh channel for transport: %v", err) } diff --git a/port_forwarder_test.go b/port_forwarder_test.go index 6af5c7e70..44ef59fe0 100644 --- a/port_forwarder_test.go +++ b/port_forwarder_test.go @@ -46,7 +46,8 @@ func TestPortForwarderStart(t *testing.T) { } defer testServer.Close() - ctx, _ := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() pf := NewPortForwarder(session, 8000) done := make(chan error) diff --git a/rpc_test.go b/rpc_test.go index d16b32a4f..7543152d1 100644 --- a/rpc_test.go +++ b/rpc_test.go @@ -15,7 +15,8 @@ func TestRPCHandlerEvents(t *testing.T) { time.Sleep(1 * time.Second) rpcHandler.Handle(context.Background(), nil, &jsonrpc2.Request{Method: "somethingHappened"}) }() - ctx, _ := context.WithDeadline(context.Background(), time.Now().Add(5*time.Second)) + ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(5*time.Second)) + defer cancel() select { case event := <-eventCh: if event.Method != "somethingHappened" { diff --git a/session.go b/session.go index b1a175df3..d13bba9f1 100644 --- a/session.go +++ b/session.go @@ -3,7 +3,6 @@ package liveshare import ( "context" "fmt" - "strconv" ) // A Session represents the session between a connected Live Share client and server. @@ -25,6 +24,8 @@ type Port struct { IsPublic bool `json:"isPublic"` IsTCPServerConnectionEstablished bool `json:"isTCPServerConnectionEstablished"` HasTSLHandshakePassed bool `json:"hasTSLHandshakePassed"` + // ^^^ + // TODO(adonovan): fix possible typo in field name, and audit others. } // StartSharing tells the liveshare host to start sharing the port from the container @@ -33,7 +34,7 @@ func (s *Session) StartSharing(ctx context.Context, protocol string, port int) e var response Port if err := s.rpc.do(ctx, "serverSharing.startSharing", []interface{}{ - port, protocol, fmt.Sprintf("http://localhost:%s", strconv.Itoa(port)), + port, protocol, fmt.Sprintf("http://localhost:%d", port), }, &response); err != nil { return err } From 6f45c7fa7dfd4553483d28242df77ca059a174d1 Mon Sep 17 00:00:00 2001 From: Alan Donovan Date: Thu, 2 Sep 2021 12:14:04 -0400 Subject: [PATCH 140/290] point out data races to be fixed --- session.go | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/session.go b/session.go index ed87c6c2c..d57906f26 100644 --- a/session.go +++ b/session.go @@ -7,13 +7,16 @@ import ( // A Session represents the session between a connected Live Share client and server. type Session struct { - ssh *sshSession - rpc *rpcClient - port int + ssh *sshSession + rpc *rpcClient + + // TODO(adonovan): fix: avoid data race of state accessed by + // multiple calls to StartSharing and concurrent calls to + // PortForwarder. Perhaps combine the two operations in the API? streamName, streamCondition string } -// Port represents an open port on the container +// Port describes a port exposed by the container. type Port struct { SourcePort int `json:"sourcePort"` DestinationPort int `json:"destinationPort"` @@ -31,8 +34,6 @@ type Port struct { // StartSharing tells the Live Share host to start sharing the specified port from the container. // The sessionName describes the purpose of the port or service. func (s *Session) StartSharing(ctx context.Context, sessionName string, port int) error { - s.port = port - var response Port if err := s.rpc.do(ctx, "serverSharing.startSharing", []interface{}{ port, sessionName, fmt.Sprintf("http://localhost:%d", port), @@ -46,7 +47,8 @@ func (s *Session) StartSharing(ctx context.Context, sessionName string, port int return nil } -// GetSharedServers returns a list of available/open ports from the container +// GetSharedServers returns a description of each container port +// shared by a prior call to StartSharing by some client. func (s *Session) GetSharedServers(ctx context.Context) ([]*Port, error) { var response []*Port if err := s.rpc.do(ctx, "serverSharing.getSharedServers", []string{}, &response); err != nil { From 8570f4111d954d9f15b78ed763ddb34ab7932740 Mon Sep 17 00:00:00 2001 From: Alan Donovan Date: Thu, 2 Sep 2021 11:14:36 -0400 Subject: [PATCH 141/290] sketch after API changes in go-liveshare#11 --- cmd/ghcs/logs.go | 2 +- cmd/ghcs/ports.go | 47 +++++++++---------------------- cmd/ghcs/ssh.go | 13 ++++----- internal/codespaces/codespaces.go | 11 +++----- internal/codespaces/ssh.go | 24 +++++++--------- internal/codespaces/states.go | 8 +++--- 6 files changed, 38 insertions(+), 67 deletions(-) diff --git a/cmd/ghcs/logs.go b/cmd/ghcs/logs.go index 34685e1e8..829ba9a31 100644 --- a/cmd/ghcs/logs.go +++ b/cmd/ghcs/logs.go @@ -88,7 +88,7 @@ func logs(ctx context.Context, tail bool, codespaceName string) error { tunnelClosed := make(chan error, 1) go func() { - tunnelClosed <- tunnel.Start(ctx) // error is non-nil + tunnelClosed <- tunnel.Forward(ctx) // error is non-nil }() cmdDone := make(chan error, 1) diff --git a/cmd/ghcs/ports.go b/cmd/ghcs/ports.go index f83757ff8..522ef61cb 100644 --- a/cmd/ghcs/ports.go +++ b/cmd/ghcs/ports.go @@ -76,15 +76,15 @@ func ports(opts *portsOptions) error { devContainerCh := getDevContainer(ctx, apiClient, codespace) - liveShareClient, err := codespaces.ConnectToLiveshare(ctx, log, apiClient, user.Login, token, codespace) + session, err := codespaces.ConnectToLiveshare(ctx, log, apiClient, user.Login, token, codespace) if err != nil { return fmt.Errorf("error connecting to Live Share: %v", err) } log.Println("Loading ports...") - ports, err := getPorts(ctx, liveShareClient) + ports, err := session.GetSharedServers(ctx) if err != nil { - return fmt.Errorf("error getting ports: %v", err) + return fmt.Errorf("error getting ports of shared servers: %v", err) } devContainerResult := <-devContainerCh @@ -116,20 +116,6 @@ func ports(opts *portsOptions) error { return nil } -func getPorts(ctx context.Context, lsclient *liveshare.Client) (liveshare.Ports, error) { - server, err := liveshare.NewServer(lsclient) - if err != nil { - return nil, fmt.Errorf("error creating server: %v", err) - } - - ports, err := server.GetSharedServers(ctx) - if err != nil { - return nil, fmt.Errorf("error getting shared servers: %v", err) - } - - return ports, nil -} - type devContainerResult struct { devContainer *devContainer err error @@ -219,22 +205,17 @@ func updatePortVisibility(log *output.Logger, codespaceName, sourcePort string, return fmt.Errorf("error getting Codespace: %v", err) } - lsclient, err := codespaces.ConnectToLiveshare(ctx, log, apiClient, user.Login, token, codespace) + session, err := codespaces.ConnectToLiveshare(ctx, log, apiClient, user.Login, token, codespace) if err != nil { return fmt.Errorf("error connecting to Live Share: %v", err) } - server, err := liveshare.NewServer(lsclient) - if err != nil { - return fmt.Errorf("error creating server: %v", err) - } - port, err := strconv.Atoi(sourcePort) if err != nil { return fmt.Errorf("error reading port number: %v", err) } - if err := server.UpdateSharedVisibility(ctx, port, public); err != nil { + if err := session.UpdateSharedVisibility(ctx, port, public); err != nil { return fmt.Errorf("error update port to public: %v", err) } @@ -285,29 +266,26 @@ func forwardPorts(log *output.Logger, codespaceName string, ports []string) erro return fmt.Errorf("error getting Codespace: %v", err) } - lsclient, err := codespaces.ConnectToLiveshare(ctx, log, apiClient, user.Login, token, codespace) + session, err := codespaces.ConnectToLiveshare(ctx, log, apiClient, user.Login, token, codespace) if err != nil { return fmt.Errorf("error connecting to Live Share: %v", err) } - server, err := liveshare.NewServer(lsclient) - if err != nil { - return fmt.Errorf("error creating server: %v", err) - } - g, gctx := errgroup.WithContext(ctx) for _, portPair := range portPairs { pp := portPair + // TODO(adonovan): fix data race on Session between + // StartSharing and NewPortForwarder. srcstr := strconv.Itoa(portPair.src) - if err := server.StartSharing(gctx, "share-"+srcstr, pp.src); err != nil { + if err := session.StartSharing(gctx, "share-"+srcstr, pp.src); err != nil { return fmt.Errorf("start sharing port: %v", err) } g.Go(func() error { log.Println("Forwarding port: " + srcstr + " ==> " + strconv.Itoa(pp.dst)) - portForwarder := liveshare.NewPortForwarder(lsclient, server, pp.dst) - if err := portForwarder.Start(gctx); err != nil { + portForwarder := liveshare.NewPortForwarder(session, pp.dst) + if err := portForwarder.Forward(gctx); err != nil { return fmt.Errorf("error forwarding port: %v", err) } @@ -315,6 +293,9 @@ func forwardPorts(log *output.Logger, codespaceName string, ports []string) erro }) } + // TODO(adonovan): fix: the waits for _all_ goroutines to terminate. + // If there are multiple ports, one long-lived successful connection + // will hide errors from any that fail. if err := g.Wait(); err != nil { return err } diff --git a/cmd/ghcs/ssh.go b/cmd/ghcs/ssh.go index aa85c33a3..91329b28c 100644 --- a/cmd/ghcs/ssh.go +++ b/cmd/ghcs/ssh.go @@ -56,20 +56,17 @@ func ssh(ctx context.Context, sshProfile, codespaceName string, localSSHServerPo return fmt.Errorf("get or choose Codespace: %v", err) } - lsclient, err := codespaces.ConnectToLiveshare(ctx, log, apiClient, user.Login, token, codespace) + session, err := codespaces.ConnectToLiveshare(ctx, log, apiClient, user.Login, token, codespace) if err != nil { return fmt.Errorf("error connecting to Live Share: %v", err) } - remoteSSHServerPort, sshUser, err := codespaces.StartSSHServer(ctx, lsclient, log) + remoteSSHServerPort, sshUser, err := codespaces.StartSSHServer(ctx, session, log) if err != nil { return fmt.Errorf("error getting ssh server details: %v", err) } - terminal, err := liveshare.NewTerminal(lsclient) - if err != nil { - return fmt.Errorf("error creating Live Share terminal: %v", err) - } + terminal := liveshare.NewTerminal(session) log.Print("Preparing SSH...") if sshProfile == "" { @@ -93,7 +90,7 @@ func ssh(ctx context.Context, sshProfile, codespaceName string, localSSHServerPo } } - tunnel, err := codespaces.NewPortForwarder(ctx, lsclient, "sshd", localSSHServerPort, remoteSSHServerPort) + tunnel, err := codespaces.NewPortForwarder(ctx, session, "sshd", localSSHServerPort, remoteSSHServerPort) if err != nil { return fmt.Errorf("make ssh tunnel: %v", err) } @@ -105,7 +102,7 @@ func ssh(ctx context.Context, sshProfile, codespaceName string, localSSHServerPo tunnelClosed := make(chan error) go func() { - tunnelClosed <- tunnel.Start(ctx) // error is always non-nil + tunnelClosed <- tunnel.Forward(ctx) // error is always non-nil }() shellClosed := make(chan error) diff --git a/internal/codespaces/codespaces.go b/internal/codespaces/codespaces.go index 90f676d28..86b703d92 100644 --- a/internal/codespaces/codespaces.go +++ b/internal/codespaces/codespaces.go @@ -73,7 +73,7 @@ func connectionReady(codespace *api.Codespace) bool { codespace.Environment.State == api.CodespaceEnvironmentStateAvailable } -func ConnectToLiveshare(ctx context.Context, log logger, apiClient *api.API, userLogin, token string, codespace *api.Codespace) (client *liveshare.Client, err error) { +func ConnectToLiveshare(ctx context.Context, log logger, apiClient *api.API, userLogin, token string, codespace *api.Codespace) (*liveshare.Session, error) { var startedCodespace bool if codespace.Environment.State != api.CodespaceEnvironmentStateAvailable { startedCodespace = true @@ -96,6 +96,7 @@ func ConnectToLiveshare(ctx context.Context, log logger, apiClient *api.API, use return nil, errors.New("timed out while waiting for the Codespace to start") } + var err error codespace, err = apiClient.GetCodespace(ctx, token, userLogin, codespace.Name) if err != nil { return nil, fmt.Errorf("error getting Codespace: %v", err) @@ -117,14 +118,10 @@ func ConnectToLiveshare(ctx context.Context, log logger, apiClient *api.API, use }), ) if err != nil { - return nil, fmt.Errorf("error creating Live Share: %v", err) + return nil, fmt.Errorf("error creating Live Share client: %v", err) } - if err := lsclient.Join(ctx); err != nil { - return nil, fmt.Errorf("error joining Live Share client: %v", err) - } - - return lsclient, nil + return lsclient.JoinWorkspace(ctx) } func GetOrChooseCodespace(ctx context.Context, apiClient *api.API, user *api.User, codespaceName string) (codespace *api.Codespace, token string, err error) { diff --git a/internal/codespaces/ssh.go b/internal/codespaces/ssh.go index 7a82e6af7..a8f1834d4 100644 --- a/internal/codespaces/ssh.go +++ b/internal/codespaces/ssh.go @@ -32,37 +32,33 @@ func UnusedPort() (int, error) { return l.Addr().(*net.TCPAddr).Port, nil } -// NewPortForwarder returns a new port forwarder for traffic between -// the Live Share client and the specified local and remote ports. +// NewPortForwarder returns a new port forwarder that forwards traffic between +// the specified local and remote ports over the provided Live Share session. // // The session name is used (along with the port) to generate // names for streams, and may appear in error messages. -func NewPortForwarder(ctx context.Context, client *liveshare.Client, sessionName string, localSSHPort, remoteSSHPort int) (*liveshare.PortForwarder, error) { +func NewPortForwarder(ctx context.Context, session *liveshare.Session, sessionName string, localSSHPort, remoteSSHPort int) (*liveshare.PortForwarder, error) { if localSSHPort == 0 { return nil, fmt.Errorf("a local port must be provided") } - server, err := liveshare.NewServer(client) - if err != nil { - return nil, fmt.Errorf("new liveshare server: %v", err) - } + // TODO(adonovan): fix data race on Session between + // StartSharing and NewPortForwarder. Perhaps combine the + // operations in go-liveshare? - if err := server.StartSharing(ctx, "sshd", remoteSSHPort); err != nil { + if err := session.StartSharing(ctx, "sshd", remoteSSHPort); err != nil { return nil, fmt.Errorf("sharing sshd port: %v", err) } - return liveshare.NewPortForwarder(client, server, localSSHPort), nil + return liveshare.NewPortForwarder(session, localSSHPort), nil } // StartSSHServer installs (if necessary) and starts the SSH in the codespace. // It returns the remote port where it is running, the user to log in with, or an error if something failed. -func StartSSHServer(ctx context.Context, client *liveshare.Client, log logger) (serverPort int, user string, err error) { +func StartSSHServer(ctx context.Context, session *liveshare.Session, log logger) (serverPort int, user string, err error) { log.Println("Fetching SSH details...") - sshServer, err := liveshare.NewSSHServer(client) - if err != nil { - return 0, "", fmt.Errorf("error creating live share: %v", err) - } + sshServer := session.SSHServer() sshServerStartResult, err := sshServer.StartRemoteServer(ctx) if err != nil { diff --git a/internal/codespaces/states.go b/internal/codespaces/states.go index a58e2b235..f0052e72c 100644 --- a/internal/codespaces/states.go +++ b/internal/codespaces/states.go @@ -40,7 +40,7 @@ func PollPostCreateStates(ctx context.Context, log logger, apiClient *api.API, u return fmt.Errorf("getting Codespace token: %v", err) } - lsclient, err := ConnectToLiveshare(ctx, log, apiClient, user.Login, token, codespace) + session, err := ConnectToLiveshare(ctx, log, apiClient, user.Login, token, codespace) if err != nil { return fmt.Errorf("connect to Live Share: %v", err) } @@ -50,19 +50,19 @@ func PollPostCreateStates(ctx context.Context, log logger, apiClient *api.API, u return err } - remoteSSHServerPort, sshUser, err := StartSSHServer(ctx, lsclient, log) + remoteSSHServerPort, sshUser, err := StartSSHServer(ctx, session, log) if err != nil { return fmt.Errorf("error getting ssh server details: %v", err) } - fwd, err := NewPortForwarder(ctx, lsclient, "sshd", localSSHPort, remoteSSHServerPort) + fwd, err := NewPortForwarder(ctx, session, "sshd", localSSHPort, remoteSSHServerPort) if err != nil { return fmt.Errorf("creating port forwarder: %v", err) } tunnelClosed := make(chan error, 1) // buffered to avoid sender stuckness go func() { - tunnelClosed <- fwd.Start(ctx) // error is non-nil + tunnelClosed <- fwd.Forward(ctx) // error is non-nil }() t := time.NewTicker(1 * time.Second) From c15d810d68ab0e9ffb0e2e094125845557d2ad7c Mon Sep 17 00:00:00 2001 From: Jose Garcia Date: Thu, 2 Sep 2021 13:27:12 -0400 Subject: [PATCH 142/290] remove extra verb arg --- cmd/ghcs/ssh.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cmd/ghcs/ssh.go b/cmd/ghcs/ssh.go index 202083afc..a12e44209 100644 --- a/cmd/ghcs/ssh.go +++ b/cmd/ghcs/ssh.go @@ -158,7 +158,7 @@ func getContainerID(ctx context.Context, logger *output.Logger, terminal *livesh } func setupEnv(ctx context.Context, logger *output.Logger, terminal *liveshare.Terminal, containerID, repositoryName, containerUser string) error { - setupBashProfileCmd := fmt.Sprintf(`echo "export $(cat /workspaces/.codespaces/shared/.env | xargs); exec /bin/zsh;" > /home/%v/.bash_profile`, repositoryName, containerUser) + setupBashProfileCmd := fmt.Sprintf(`echo "export $(cat /workspaces/.codespaces/shared/.env | xargs); exec /bin/zsh;" > /home/%v/.bash_profile`, containerUser) logger.Print(".") compositeCommand := []string{setupBashProfileCmd} From 5c65cfd2498785d82617357d0ce49ffc8a78c7c2 Mon Sep 17 00:00:00 2001 From: Jose Garcia Date: Thu, 2 Sep 2021 13:42:52 -0400 Subject: [PATCH 143/290] ignore any 7 err code in start --- api/api.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/api/api.go b/api/api.go index 08ca2c370..470c99695 100644 --- a/api/api.go +++ b/api/api.go @@ -278,8 +278,8 @@ func (a *API) StartCodespace(ctx context.Context, token string, codespace *Codes if len(b) > 100 { b = append(b[:97], "..."...) } - if resp.StatusCode == http.StatusServiceUnavailable && strings.TrimSpace(string(b)) == "7" { - // HTTP 503 with error code 7 (EnvironmentNotShutdown) is benign. + if strings.TrimSpace(string(b)) == "7" { + // NON HTTP 200 with error code 7 (EnvironmentNotShutdown) is benign. // Ignore it. } else { return fmt.Errorf("failed to start codespace: %s", b) From 87b15aa264e583688aa9b448ea57663b87a2b4cf Mon Sep 17 00:00:00 2001 From: Alan Donovan Date: Thu, 2 Sep 2021 14:03:48 -0400 Subject: [PATCH 144/290] Fix data race in StartSharing --- client.go | 13 +++++++++-- port_forwarder.go | 50 +++++++++++++++++++++++++++++++----------- port_forwarder_test.go | 11 ++++------ session.go | 24 +++++++------------- session_test.go | 5 +++-- terminal.go | 2 +- 6 files changed, 64 insertions(+), 41 deletions(-) diff --git a/client.go b/client.go index 377ec2512..0088662f7 100644 --- a/client.go +++ b/client.go @@ -86,6 +86,12 @@ type joinWorkspaceResult struct { SessionNumber int `json:"sessionNumber"` } +// A channelID is an identifier for an exposed port on a remote +// container that may be used to open an SSH channel to it. +type channelID struct { + name, condition string +} + func (c *Client) joinWorkspace(ctx context.Context, rpc *rpcClient) (*joinWorkspaceResult, error) { args := joinWorkspaceArgs{ ID: c.connection.SessionID, @@ -104,8 +110,11 @@ func (c *Client) joinWorkspace(ctx context.Context, rpc *rpcClient) (*joinWorksp return &result, nil } -func (s *Session) openStreamingChannel(ctx context.Context, streamName, condition string) (ssh.Channel, error) { - args := getStreamArgs{streamName, condition} +func (s *Session) openStreamingChannel(ctx context.Context, id channelID) (ssh.Channel, error) { + args := getStreamArgs{ + StreamName: id.name, + Condition: id.condition, + } var streamID string if err := s.rpc.do(ctx, "streamManager.getStream", args, &streamID); err != nil { return nil, fmt.Errorf("error getting stream id: %v", err) diff --git a/port_forwarder.go b/port_forwarder.go index 29dee58f9..4391ef55c 100644 --- a/port_forwarder.go +++ b/port_forwarder.go @@ -9,23 +9,34 @@ import ( // A PortForwarder forwards TCP traffic between a local TCP port and a LiveShare session. type PortForwarder struct { - session *Session - port int + session *Session + name string + localPort, remotePort int } -// NewPortForwarder creates a new PortForwarder for a given Live Share session and local TCP port. -func NewPortForwarder(session *Session, port int) *PortForwarder { +// NewPortForwarder creates a new PortForwarder that forwards traffic +// between the local port and the container's remote port over the +// specified Live Share session. The name describes the purpose of the +// remote port or service. +func NewPortForwarder(session *Session, name string, localPort, remotePort int) *PortForwarder { return &PortForwarder{ - session: session, - port: port, + session: session, + name: name, + localPort: localPort, + remotePort: remotePort, } } // Forward enables port forwarding. It accepts and handles TCP // connections until it encounters the first error, which may include // context cancellation. Its result is non-nil. -func (l *PortForwarder) Forward(ctx context.Context) (err error) { - listen, err := net.Listen("tcp", fmt.Sprintf(":%d", l.port)) +func (fwd *PortForwarder) Forward(ctx context.Context) (err error) { + id, err := fwd.shareRemotePort(ctx) + if err != nil { + return err + } + + listen, err := net.Listen("tcp", fmt.Sprintf(":%d", fwd.localPort)) if err != nil { return fmt.Errorf("error listening on TCP port: %v", err) } @@ -49,7 +60,7 @@ func (l *PortForwarder) Forward(ctx context.Context) (err error) { } go func() { - if err := l.handleConnection(ctx, conn); err != nil { + if err := fwd.handleConnection(ctx, id, conn); err != nil { sendError(err) } }() @@ -60,17 +71,30 @@ func (l *PortForwarder) Forward(ctx context.Context) (err error) { } // ForwardWithConn handles port forwarding for a single connection. -func (l *PortForwarder) ForwardWithConn(ctx context.Context, conn io.ReadWriteCloser) error { +func (fwd *PortForwarder) ForwardWithConn(ctx context.Context, conn io.ReadWriteCloser) error { + id, err := fwd.shareRemotePort(ctx) + if err != nil { + return err + } + // Create buffered channel so that send doesn't get stuck after context cancellation. errc := make(chan error, 1) go func() { - if err := l.handleConnection(ctx, conn); err != nil { + if err := fwd.handleConnection(ctx, id, conn); err != nil { errc <- err } }() return awaitError(ctx, errc) } +func (fwd *PortForwarder) shareRemotePort(ctx context.Context) (channelID, error) { + id, err := fwd.session.startSharing(ctx, fwd.name, fwd.remotePort) + if err != nil { + err = fmt.Errorf("failed to share remote port %d: %v", fwd.remotePort, err) + } + return id, nil +} + func awaitError(ctx context.Context, errc <-chan error) error { select { case err := <-errc: @@ -81,10 +105,10 @@ func awaitError(ctx context.Context, errc <-chan error) error { } // handleConnection handles forwarding for a single accepted connection, then closes it. -func (l *PortForwarder) handleConnection(ctx context.Context, conn io.ReadWriteCloser) (err error) { +func (fwd *PortForwarder) handleConnection(ctx context.Context, id channelID, conn io.ReadWriteCloser) (err error) { defer safeClose(conn, &err) - channel, err := l.session.openStreamingChannel(ctx, l.session.streamName, l.session.streamCondition) + channel, err := fwd.session.openStreamingChannel(ctx, id) if err != nil { return fmt.Errorf("error opening streaming channel for new connection: %v", err) } diff --git a/port_forwarder_test.go b/port_forwarder_test.go index 44ef59fe0..d47730995 100644 --- a/port_forwarder_test.go +++ b/port_forwarder_test.go @@ -20,7 +20,7 @@ func TestNewPortForwarder(t *testing.T) { t.Errorf("create mock client: %v", err) } defer testServer.Close() - pf := NewPortForwarder(session, 80) + pf := NewPortForwarder(session, "ssh", 81, 80) if pf == nil { t.Error("port forwarder is nil") } @@ -48,14 +48,11 @@ func TestPortForwarderStart(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - pf := NewPortForwarder(session, 8000) - done := make(chan error) + done := make(chan error) go func() { - if err := session.StartSharing(ctx, "http", 8000); err != nil { - done <- fmt.Errorf("start sharing: %v", err) - } - done <- pf.Forward(ctx) + const name, local, remote = "ssh", 8000, 8000 + done <- NewPortForwarder(session, name, local, remote).Forward(ctx) }() go func() { diff --git a/session.go b/session.go index d57906f26..0e3120cd7 100644 --- a/session.go +++ b/session.go @@ -9,11 +9,6 @@ import ( type Session struct { ssh *sshSession rpc *rpcClient - - // TODO(adonovan): fix: avoid data race of state accessed by - // multiple calls to StartSharing and concurrent calls to - // PortForwarder. Perhaps combine the two operations in the API? - streamName, streamCondition string } // Port describes a port exposed by the container. @@ -31,20 +26,17 @@ type Port struct { // TODO(adonovan): fix possible typo in field name, and audit others. } -// StartSharing tells the Live Share host to start sharing the specified port from the container. -// The sessionName describes the purpose of the port or service. -func (s *Session) StartSharing(ctx context.Context, sessionName string, port int) error { +// startSharing tells the Live Share host to start sharing the specified port from the container. +// The sessionName describes the purpose of the remote port or service. +// It returns an identifier that can be used to open an SSH channel to the remote port. +func (s *Session) startSharing(ctx context.Context, sessionName string, port int) (channelID, error) { + args := []interface{}{port, sessionName, fmt.Sprintf("http://localhost:%d", port)} var response Port - if err := s.rpc.do(ctx, "serverSharing.startSharing", []interface{}{ - port, sessionName, fmt.Sprintf("http://localhost:%d", port), - }, &response); err != nil { - return err + if err := s.rpc.do(ctx, "serverSharing.startSharing", args, &response); err != nil { + return channelID{}, err } - s.streamName = response.StreamName - s.streamCondition = response.StreamCondition - - return nil + return channelID{response.StreamName, response.StreamCondition}, nil } // GetSharedServers returns a description of each container port diff --git a/session_test.go b/session_test.go index 005eacfbd..54aab16c8 100644 --- a/session_test.go +++ b/session_test.go @@ -82,10 +82,11 @@ func TestServerStartSharing(t *testing.T) { done := make(chan error) go func() { - if err := session.StartSharing(ctx, serverProtocol, serverPort); err != nil { + streamID, err := session.startSharing(ctx, serverProtocol, serverPort) + if err != nil { done <- fmt.Errorf("error sharing server: %v", err) } - if session.streamName == "" || session.streamCondition == "" { + if streamID.name == "" || streamID.condition == "" { done <- errors.New("stream name or condition is blank") } done <- nil diff --git a/terminal.go b/terminal.go index 96938ed89..24a0f5121 100644 --- a/terminal.go +++ b/terminal.go @@ -75,7 +75,7 @@ func (t TerminalCommand) Run(ctx context.Context) (io.ReadCloser, error) { } <-started - channel, err := t.terminal.session.openStreamingChannel(ctx, result.StreamName, result.StreamCondition) + channel, err := t.terminal.session.openStreamingChannel(ctx, channelID{result.StreamName, result.StreamCondition}) if err != nil { return nil, fmt.Errorf("error opening streaming channel: %v", err) } From 090af2290b186306f27e75c112afeed7df25b8d5 Mon Sep 17 00:00:00 2001 From: Jose Garcia Date: Thu, 2 Sep 2021 14:13:40 -0400 Subject: [PATCH 145/290] pr feedback --- api/api.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/api/api.go b/api/api.go index 470c99695..7bb43811a 100644 --- a/api/api.go +++ b/api/api.go @@ -279,7 +279,7 @@ func (a *API) StartCodespace(ctx context.Context, token string, codespace *Codes b = append(b[:97], "..."...) } if strings.TrimSpace(string(b)) == "7" { - // NON HTTP 200 with error code 7 (EnvironmentNotShutdown) is benign. + // Non-HTTP 200 with error code 7 (EnvironmentNotShutdown) is benign. // Ignore it. } else { return fmt.Errorf("failed to start codespace: %s", b) From 94b91661cc68b200e30e809d8c26b41a7f37c1af Mon Sep 17 00:00:00 2001 From: Alan Donovan Date: Thu, 2 Sep 2021 14:30:19 -0400 Subject: [PATCH 146/290] don't forget to close conn in case of sharing error --- port_forwarder.go | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/port_forwarder.go b/port_forwarder.go index 4391ef55c..2d1217c24 100644 --- a/port_forwarder.go +++ b/port_forwarder.go @@ -18,6 +18,12 @@ type PortForwarder struct { // between the local port and the container's remote port over the // specified Live Share session. The name describes the purpose of the // remote port or service. +// +// TODO(adonovan): the localPort param is redundant wrt ForwardWithConn. +// Simpler: do away with the NewPortForwarder type altogether: +// +// - ForwardToLocalPort(ctx, session, name, remote, local) +// - ForwardToConnection(ctx, session, name, remote, conn) func NewPortForwarder(session *Session, name string, localPort, remotePort int) *PortForwarder { return &PortForwarder{ session: session, @@ -74,6 +80,7 @@ func (fwd *PortForwarder) Forward(ctx context.Context) (err error) { func (fwd *PortForwarder) ForwardWithConn(ctx context.Context, conn io.ReadWriteCloser) error { id, err := fwd.shareRemotePort(ctx) if err != nil { + conn.Close() return err } From 94319d4cfeaa6b6a0389e75c0401e265e2078e09 Mon Sep 17 00:00:00 2001 From: Alan Donovan Date: Thu, 2 Sep 2021 15:34:57 -0400 Subject: [PATCH 147/290] move localPort parameter to ForwardToLocalPort --- port_forwarder.go | 39 +++++++++++++++++---------------------- port_forwarder_test.go | 4 ++-- 2 files changed, 19 insertions(+), 24 deletions(-) diff --git a/port_forwarder.go b/port_forwarder.go index 2d1217c24..4a46cd4e6 100644 --- a/port_forwarder.go +++ b/port_forwarder.go @@ -7,42 +7,36 @@ import ( "net" ) -// A PortForwarder forwards TCP traffic between a local TCP port and a LiveShare session. +// A PortForwarder forwards TCP traffic over a LiveShare session from a port on a remote +// container to a local destination such as a network port or Go reader/writer. type PortForwarder struct { - session *Session - name string - localPort, remotePort int + session *Session + name string + remotePort int } -// NewPortForwarder creates a new PortForwarder that forwards traffic -// between the local port and the container's remote port over the -// specified Live Share session. The name describes the purpose of the -// remote port or service. -// -// TODO(adonovan): the localPort param is redundant wrt ForwardWithConn. -// Simpler: do away with the NewPortForwarder type altogether: -// -// - ForwardToLocalPort(ctx, session, name, remote, local) -// - ForwardToConnection(ctx, session, name, remote, conn) -func NewPortForwarder(session *Session, name string, localPort, remotePort int) *PortForwarder { +// NewPortForwarder returns a new PortForwarder for the specified +// remote port and Live Share session. The name describes the purpose +// of the remote port or service. +func NewPortForwarder(session *Session, name string, remotePort int) *PortForwarder { return &PortForwarder{ session: session, name: name, - localPort: localPort, remotePort: remotePort, } } -// Forward enables port forwarding. It accepts and handles TCP -// connections until it encounters the first error, which may include +// ForwardToLocalPort forwards traffic between the container's remote +// port and a local TCP port. It accepts and handles TCP connections +// on the local until it encounters the first error, which may include // context cancellation. Its result is non-nil. -func (fwd *PortForwarder) Forward(ctx context.Context) (err error) { +func (fwd *PortForwarder) ForwardToLocalPort(ctx context.Context, localPort int) (err error) { id, err := fwd.shareRemotePort(ctx) if err != nil { return err } - listen, err := net.Listen("tcp", fmt.Sprintf(":%d", fwd.localPort)) + listen, err := net.Listen("tcp", fmt.Sprintf(":%d", localPort)) if err != nil { return fmt.Errorf("error listening on TCP port: %v", err) } @@ -76,8 +70,9 @@ func (fwd *PortForwarder) Forward(ctx context.Context) (err error) { return awaitError(ctx, errc) } -// ForwardWithConn handles port forwarding for a single connection. -func (fwd *PortForwarder) ForwardWithConn(ctx context.Context, conn io.ReadWriteCloser) error { +// Forward forwards traffic between the container's remote port and +// the specified read/write stream. On return, the stream is closed. +func (fwd *PortForwarder) Forward(ctx context.Context, conn io.ReadWriteCloser) error { id, err := fwd.shareRemotePort(ctx) if err != nil { conn.Close() diff --git a/port_forwarder_test.go b/port_forwarder_test.go index d47730995..6ccb3d05e 100644 --- a/port_forwarder_test.go +++ b/port_forwarder_test.go @@ -20,7 +20,7 @@ func TestNewPortForwarder(t *testing.T) { t.Errorf("create mock client: %v", err) } defer testServer.Close() - pf := NewPortForwarder(session, "ssh", 81, 80) + pf := NewPortForwarder(session, "ssh", 80) if pf == nil { t.Error("port forwarder is nil") } @@ -52,7 +52,7 @@ func TestPortForwarderStart(t *testing.T) { done := make(chan error) go func() { const name, local, remote = "ssh", 8000, 8000 - done <- NewPortForwarder(session, name, local, remote).Forward(ctx) + done <- NewPortForwarder(session, name, remote).ForwardToLocalPort(ctx, local) }() go func() { From 4438b85e294e510edf97510ede486db175e8f084 Mon Sep 17 00:00:00 2001 From: Alan Donovan Date: Thu, 2 Sep 2021 15:41:36 -0400 Subject: [PATCH 148/290] comment tweaks --- port_forwarder.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/port_forwarder.go b/port_forwarder.go index 4a46cd4e6..f4895bb60 100644 --- a/port_forwarder.go +++ b/port_forwarder.go @@ -27,9 +27,9 @@ func NewPortForwarder(session *Session, name string, remotePort int) *PortForwar } // ForwardToLocalPort forwards traffic between the container's remote -// port and a local TCP port. It accepts and handles TCP connections -// on the local until it encounters the first error, which may include -// context cancellation. Its result is non-nil. +// port and a local TCP port. It accepts and handles connections on +// the local port until it encounters the first error, which may +// include context cancellation. Its error result is always non-nil. func (fwd *PortForwarder) ForwardToLocalPort(ctx context.Context, localPort int) (err error) { id, err := fwd.shareRemotePort(ctx) if err != nil { From 3485bacc97751521be724326189a566f02e30fb7 Mon Sep 17 00:00:00 2001 From: Alan Donovan Date: Thu, 2 Sep 2021 14:10:29 -0400 Subject: [PATCH 149/290] fix StartSharing data race --- cmd/ghcs/logs.go | 10 +++--- cmd/ghcs/ports.go | 60 +++++++++++++---------------------- cmd/ghcs/ssh.go | 7 ++-- internal/codespaces/ssh.go | 21 ------------ internal/codespaces/states.go | 9 ++---- 5 files changed, 31 insertions(+), 76 deletions(-) diff --git a/cmd/ghcs/logs.go b/cmd/ghcs/logs.go index 829ba9a31..07e247cd0 100644 --- a/cmd/ghcs/logs.go +++ b/cmd/ghcs/logs.go @@ -8,6 +8,7 @@ import ( "github.com/github/ghcs/api" "github.com/github/ghcs/cmd/ghcs/output" "github.com/github/ghcs/internal/codespaces" + "github.com/github/go-liveshare" "github.com/spf13/cobra" ) @@ -54,7 +55,7 @@ func logs(ctx context.Context, tail bool, codespaceName string) error { return fmt.Errorf("get or choose Codespace: %v", err) } - lsclient, err := codespaces.ConnectToLiveshare(ctx, log, apiClient, user.Login, token, codespace) + session, err := codespaces.ConnectToLiveshare(ctx, log, apiClient, user.Login, token, codespace) if err != nil { return fmt.Errorf("connecting to Live Share: %v", err) } @@ -64,15 +65,12 @@ func logs(ctx context.Context, tail bool, codespaceName string) error { return err } - remoteSSHServerPort, sshUser, err := codespaces.StartSSHServer(ctx, lsclient, log) + remoteSSHServerPort, sshUser, err := codespaces.StartSSHServer(ctx, session, log) if err != nil { return fmt.Errorf("error getting ssh server details: %v", err) } - tunnel, err := codespaces.NewPortForwarder(ctx, lsclient, "sshd", localSSHPort, remoteSSHServerPort) - if err != nil { - return fmt.Errorf("make ssh tunnel: %v", err) - } + tunnel := liveshare.NewPortForwarder(session, "sshd", localSSHPort, remoteSSHServerPort) cmdType := "cat" if tail { diff --git a/cmd/ghcs/ports.go b/cmd/ghcs/ports.go index 522ef61cb..539c03320 100644 --- a/cmd/ghcs/ports.go +++ b/cmd/ghcs/ports.go @@ -16,7 +16,6 @@ import ( "github.com/github/go-liveshare" "github.com/muhammadmuzzammil1998/jsonc" "github.com/spf13/cobra" - "golang.org/x/sync/errgroup" ) // portOptions represents the options accepted by the ports command. @@ -232,7 +231,7 @@ func updatePortVisibility(log *output.Logger, codespaceName, sourcePort string, // port pairs from the codespace to localhost. func newPortsForwardCmd() *cobra.Command { return &cobra.Command{ - Use: "forward :", + Use: "forward :", Short: "Forward ports", Args: cobra.MinimumNArgs(2), RunE: func(cmd *cobra.Command, args []string) error { @@ -271,63 +270,48 @@ func forwardPorts(log *output.Logger, codespaceName string, ports []string) erro return fmt.Errorf("error connecting to Live Share: %v", err) } - g, gctx := errgroup.WithContext(ctx) - for _, portPair := range portPairs { - pp := portPair - - // TODO(adonovan): fix data race on Session between - // StartSharing and NewPortForwarder. - srcstr := strconv.Itoa(portPair.src) - if err := session.StartSharing(gctx, "share-"+srcstr, pp.src); err != nil { - return fmt.Errorf("start sharing port: %v", err) - } - - g.Go(func() error { - log.Println("Forwarding port: " + srcstr + " ==> " + strconv.Itoa(pp.dst)) - portForwarder := liveshare.NewPortForwarder(session, pp.dst) - if err := portForwarder.Forward(gctx); err != nil { - return fmt.Errorf("error forwarding port: %v", err) - } - - return nil - }) + // Run forwarding of all ports concurrently, aborting all of + // them at the first failure, including cancellation of the context. + errc := make(chan error, len(portPairs)) + ctx, cancel := context.WithCancel(ctx) + defer cancel() + for _, pair := range portPairs { + log.Printf("Forwarding ports: remote %d <=> local %d\n", pair.remote, pair.local) + name := fmt.Sprintf("share-%d", pair.remote) + fwd := liveshare.NewPortForwarder(session, name, pair.remote, pair.local) + go func() { + errc <- fwd.Forward(ctx) // error always non-nil + }() } - // TODO(adonovan): fix: the waits for _all_ goroutines to terminate. - // If there are multiple ports, one long-lived successful connection - // will hide errors from any that fail. - if err := g.Wait(); err != nil { - return err - } - - return nil + return <-errc // first error } type portPair struct { - src, dst int + remote, local int } -// getPortPairs parses a list of strings of form "%d:%d" into pairs of numbers. +// getPortPairs parses a list of strings of form "%d:%d" into pairs of (remote, local) numbers. func getPortPairs(ports []string) ([]portPair, error) { pp := make([]portPair, 0, len(ports)) for _, portString := range ports { parts := strings.Split(portString, ":") if len(parts) < 2 { - return nil, fmt.Errorf("port pair: '%v' is not valid", portString) + return nil, fmt.Errorf("port pair: %q is not valid", portString) } - srcp, err := strconv.Atoi(parts[0]) + remote, err := strconv.Atoi(parts[0]) if err != nil { - return pp, fmt.Errorf("convert source port to int: %v", err) + return pp, fmt.Errorf("convert remote port to int: %v", err) } - dstp, err := strconv.Atoi(parts[1]) + local, err := strconv.Atoi(parts[1]) if err != nil { - return pp, fmt.Errorf("convert dest port to int: %v", err) + return pp, fmt.Errorf("convert local port to int: %v", err) } - pp = append(pp, portPair{srcp, dstp}) + pp = append(pp, portPair{local, remote}) } return pp, nil diff --git a/cmd/ghcs/ssh.go b/cmd/ghcs/ssh.go index 91329b28c..4d0134082 100644 --- a/cmd/ghcs/ssh.go +++ b/cmd/ghcs/ssh.go @@ -28,7 +28,7 @@ func newSSHCmd() *cobra.Command { } sshCmd.Flags().StringVarP(&sshProfile, "profile", "", "", "The `name` of the SSH profile to use") - sshCmd.Flags().IntVarP(&sshServerPort, "server-port", "", 0, "SSH server port number") + sshCmd.Flags().IntVarP(&sshServerPort, "server-port", "", 0, "SSH server port number (0 => pick unused)") sshCmd.Flags().StringVarP(&codespaceName, "codespace", "c", "", "The `name` of the Codespace to use") return sshCmd @@ -90,10 +90,7 @@ func ssh(ctx context.Context, sshProfile, codespaceName string, localSSHServerPo } } - tunnel, err := codespaces.NewPortForwarder(ctx, session, "sshd", localSSHServerPort, remoteSSHServerPort) - if err != nil { - return fmt.Errorf("make ssh tunnel: %v", err) - } + tunnel := liveshare.NewPortForwarder(session, "sshd", localSSHServerPort, remoteSSHServerPort) connectDestination := sshProfile if connectDestination == "" { diff --git a/internal/codespaces/ssh.go b/internal/codespaces/ssh.go index a8f1834d4..1ef2b819f 100644 --- a/internal/codespaces/ssh.go +++ b/internal/codespaces/ssh.go @@ -32,27 +32,6 @@ func UnusedPort() (int, error) { return l.Addr().(*net.TCPAddr).Port, nil } -// NewPortForwarder returns a new port forwarder that forwards traffic between -// the specified local and remote ports over the provided Live Share session. -// -// The session name is used (along with the port) to generate -// names for streams, and may appear in error messages. -func NewPortForwarder(ctx context.Context, session *liveshare.Session, sessionName string, localSSHPort, remoteSSHPort int) (*liveshare.PortForwarder, error) { - if localSSHPort == 0 { - return nil, fmt.Errorf("a local port must be provided") - } - - // TODO(adonovan): fix data race on Session between - // StartSharing and NewPortForwarder. Perhaps combine the - // operations in go-liveshare? - - if err := session.StartSharing(ctx, "sshd", remoteSSHPort); err != nil { - return nil, fmt.Errorf("sharing sshd port: %v", err) - } - - return liveshare.NewPortForwarder(session, localSSHPort), nil -} - // StartSSHServer installs (if necessary) and starts the SSH in the codespace. // It returns the remote port where it is running, the user to log in with, or an error if something failed. func StartSSHServer(ctx context.Context, session *liveshare.Session, log logger) (serverPort int, user string, err error) { diff --git a/internal/codespaces/states.go b/internal/codespaces/states.go index f0052e72c..a745d34e5 100644 --- a/internal/codespaces/states.go +++ b/internal/codespaces/states.go @@ -9,6 +9,7 @@ import ( "time" "github.com/github/ghcs/api" + "github.com/github/go-liveshare" ) // PostCreateStateStatus is a string value representing the different statuses a state can have. @@ -55,14 +56,10 @@ func PollPostCreateStates(ctx context.Context, log logger, apiClient *api.API, u return fmt.Errorf("error getting ssh server details: %v", err) } - fwd, err := NewPortForwarder(ctx, session, "sshd", localSSHPort, remoteSSHServerPort) - if err != nil { - return fmt.Errorf("creating port forwarder: %v", err) - } - tunnelClosed := make(chan error, 1) // buffered to avoid sender stuckness go func() { - tunnelClosed <- fwd.Forward(ctx) // error is non-nil + fwd := liveshare.NewPortForwarder(session, "sshd", remoteSSHServerPort) + tunnelClosed <- fwd.ForwardToLocalPort(ctx, localSSHPort) // error is non-nil }() t := time.NewTicker(1 * time.Second) From cee761238ba166fad40414058c6ae6837f65324f Mon Sep 17 00:00:00 2001 From: Alan Donovan Date: Thu, 2 Sep 2021 15:51:56 -0400 Subject: [PATCH 150/290] update go-liveshare@v0.11.0 --- cmd/ghcs/logs.go | 5 ++--- cmd/ghcs/ports.go | 4 ++-- cmd/ghcs/ssh.go | 5 ++--- 3 files changed, 6 insertions(+), 8 deletions(-) diff --git a/cmd/ghcs/logs.go b/cmd/ghcs/logs.go index 07e247cd0..616674942 100644 --- a/cmd/ghcs/logs.go +++ b/cmd/ghcs/logs.go @@ -70,8 +70,6 @@ func logs(ctx context.Context, tail bool, codespaceName string) error { return fmt.Errorf("error getting ssh server details: %v", err) } - tunnel := liveshare.NewPortForwarder(session, "sshd", localSSHPort, remoteSSHServerPort) - cmdType := "cat" if tail { cmdType = "tail -f" @@ -86,7 +84,8 @@ func logs(ctx context.Context, tail bool, codespaceName string) error { tunnelClosed := make(chan error, 1) go func() { - tunnelClosed <- tunnel.Forward(ctx) // error is non-nil + fwd := liveshare.NewPortForwarder(session, "sshd", remoteSSHServerPort) + tunnelClosed <- fwd.ForwardToLocalPort(ctx, localSSHPort) // error is non-nil }() cmdDone := make(chan error, 1) diff --git a/cmd/ghcs/ports.go b/cmd/ghcs/ports.go index 539c03320..1e7ca96ac 100644 --- a/cmd/ghcs/ports.go +++ b/cmd/ghcs/ports.go @@ -278,9 +278,9 @@ func forwardPorts(log *output.Logger, codespaceName string, ports []string) erro for _, pair := range portPairs { log.Printf("Forwarding ports: remote %d <=> local %d\n", pair.remote, pair.local) name := fmt.Sprintf("share-%d", pair.remote) - fwd := liveshare.NewPortForwarder(session, name, pair.remote, pair.local) go func() { - errc <- fwd.Forward(ctx) // error always non-nil + fwd := liveshare.NewPortForwarder(session, name, pair.remote) + errc <- fwd.ForwardToLocalPort(ctx, pair.local) // error always non-nil }() } diff --git a/cmd/ghcs/ssh.go b/cmd/ghcs/ssh.go index 4d0134082..b9a8b7db4 100644 --- a/cmd/ghcs/ssh.go +++ b/cmd/ghcs/ssh.go @@ -90,8 +90,6 @@ func ssh(ctx context.Context, sshProfile, codespaceName string, localSSHServerPo } } - tunnel := liveshare.NewPortForwarder(session, "sshd", localSSHServerPort, remoteSSHServerPort) - connectDestination := sshProfile if connectDestination == "" { connectDestination = fmt.Sprintf("%s@localhost", sshUser) @@ -99,7 +97,8 @@ func ssh(ctx context.Context, sshProfile, codespaceName string, localSSHServerPo tunnelClosed := make(chan error) go func() { - tunnelClosed <- tunnel.Forward(ctx) // error is always non-nil + fwd := liveshare.NewPortForwarder(session, "sshd", remoteSSHServerPort) + tunnelClosed <- fwd.ForwardToLocalPort(ctx, localSSHServerPort) // error is always non-nil }() shellClosed := make(chan error) From 1162c8adff7d236009ec99b8704e90f12b344e4c Mon Sep 17 00:00:00 2001 From: Alan Donovan Date: Thu, 2 Sep 2021 16:02:09 -0400 Subject: [PATCH 151/290] fix go vet loopclosure finding --- cmd/ghcs/ports.go | 1 + 1 file changed, 1 insertion(+) diff --git a/cmd/ghcs/ports.go b/cmd/ghcs/ports.go index 0d875980c..800803269 100644 --- a/cmd/ghcs/ports.go +++ b/cmd/ghcs/ports.go @@ -276,6 +276,7 @@ func forwardPorts(log *output.Logger, codespaceName string, ports []string) erro ctx, cancel := context.WithCancel(ctx) defer cancel() for _, pair := range portPairs { + pair := pair log.Printf("Forwarding ports: remote %d <=> local %d\n", pair.remote, pair.local) name := fmt.Sprintf("share-%d", pair.remote) go func() { From 981b2545bc91e6f190c8ac8b8152a7c7cb0695a5 Mon Sep 17 00:00:00 2001 From: Alan Donovan Date: Thu, 2 Sep 2021 17:04:07 -0400 Subject: [PATCH 152/290] sketch of changes for https://github.com/github/go-liveshare/pull/13 --- cmd/ghcs/logs.go | 10 +++++++--- cmd/ghcs/ports.go | 13 ++++++++++--- cmd/ghcs/ssh.go | 18 ++++++++++-------- internal/codespaces/ssh.go | 20 -------------------- internal/codespaces/states.go | 9 ++++++--- 5 files changed, 33 insertions(+), 37 deletions(-) diff --git a/cmd/ghcs/logs.go b/cmd/ghcs/logs.go index 49acb3449..590596603 100644 --- a/cmd/ghcs/logs.go +++ b/cmd/ghcs/logs.go @@ -3,6 +3,7 @@ package main import ( "context" "fmt" + "net" "os" "github.com/github/ghcs/api" @@ -60,10 +61,13 @@ func logs(ctx context.Context, tail bool, codespaceName string) error { return fmt.Errorf("connecting to Live Share: %v", err) } - localSSHPort, err := codespaces.UnusedPort() + // Ensure local port is listening before client (getPostCreateOutput) connects. + listen, err := liveshare.Listen(0) // zero => arbitrary if err != nil { return err } + defer listen.Close() + localPort := listen.Addr().(*net.TCPAddr).Port remoteSSHServerPort, sshUser, err := codespaces.StartSSHServer(ctx, session, log) if err != nil { @@ -77,7 +81,7 @@ func logs(ctx context.Context, tail bool, codespaceName string) error { dst := fmt.Sprintf("%s@localhost", sshUser) cmd := codespaces.NewRemoteCommand( - ctx, localSSHPort, dst, fmt.Sprintf("%s /workspaces/.codespaces/.persistedshare/creation.log", cmdType), + ctx, localPort, dst, fmt.Sprintf("%s /workspaces/.codespaces/.persistedshare/creation.log", cmdType), ) // Error channels are buffered so that neither sending goroutine gets stuck. @@ -85,7 +89,7 @@ func logs(ctx context.Context, tail bool, codespaceName string) error { tunnelClosed := make(chan error, 1) go func() { fwd := liveshare.NewPortForwarder(session, "sshd", remoteSSHServerPort) - tunnelClosed <- fwd.ForwardToLocalPort(ctx, localSSHPort) // error is non-nil + tunnelClosed <- fwd.ForwardToLocalPort(ctx, listen) // error is non-nil }() cmdDone := make(chan error, 1) diff --git a/cmd/ghcs/ports.go b/cmd/ghcs/ports.go index 800803269..3e403294a 100644 --- a/cmd/ghcs/ports.go +++ b/cmd/ghcs/ports.go @@ -277,11 +277,18 @@ func forwardPorts(log *output.Logger, codespaceName string, ports []string) erro defer cancel() for _, pair := range portPairs { pair := pair - log.Printf("Forwarding ports: remote %d <=> local %d\n", pair.remote, pair.local) - name := fmt.Sprintf("share-%d", pair.remote) + go func() { + listen, err := liveshare.Listen(pair.local) + if err != nil { + errc <- err + return + } + defer listen.Close() + log.Printf("Forwarding ports: remote %d <=> local %d\n", pair.remote, pair.local) + name := fmt.Sprintf("share-%d", pair.remote) fwd := liveshare.NewPortForwarder(session, name, pair.remote) - errc <- fwd.ForwardToLocalPort(ctx, pair.local) // error always non-nil + errc <- fwd.ForwardToLocalPort(ctx, listen) // error always non-nil }() } diff --git a/cmd/ghcs/ssh.go b/cmd/ghcs/ssh.go index 183019504..2637dab99 100644 --- a/cmd/ghcs/ssh.go +++ b/cmd/ghcs/ssh.go @@ -4,6 +4,7 @@ import ( "bufio" "context" "fmt" + "net" "os" "strings" @@ -81,14 +82,15 @@ func ssh(ctx context.Context, sshProfile, codespaceName string, localSSHServerPo } log.Print("\n") - usingCustomPort := true - if localSSHServerPort == 0 { - usingCustomPort = false // suppress log of command line in Shell - localSSHServerPort, err = codespaces.UnusedPort() - if err != nil { - return err - } + usingCustomPort := localSSHServerPort != 0 // suppress log of command line in Shell + + // Ensure local port is listening before client (Shell) connects. + listen, err := liveshare.Listen(localSSHServerPort) + if err != nil { + return err } + defer listen.Close() + localSSHServerPort = listen.Addr().(*net.TCPAddr).Port connectDestination := sshProfile if connectDestination == "" { @@ -98,7 +100,7 @@ func ssh(ctx context.Context, sshProfile, codespaceName string, localSSHServerPo tunnelClosed := make(chan error) go func() { fwd := liveshare.NewPortForwarder(session, "sshd", remoteSSHServerPort) - tunnelClosed <- fwd.ForwardToLocalPort(ctx, localSSHServerPort) // error is always non-nil + tunnelClosed <- fwd.ForwardToLocalPort(ctx, listen) // error is always non-nil }() shellClosed := make(chan error) diff --git a/internal/codespaces/ssh.go b/internal/codespaces/ssh.go index 1ef2b819f..14dbfbb88 100644 --- a/internal/codespaces/ssh.go +++ b/internal/codespaces/ssh.go @@ -4,7 +4,6 @@ import ( "context" "errors" "fmt" - "net" "os" "os/exec" "strconv" @@ -13,25 +12,6 @@ import ( "github.com/github/go-liveshare" ) -// UnusedPort returns the number of a local TCP port that is currently -// unbound, or an error if none was available. -// -// Use of this function carries an inherent risk of a time-of-check to -// time-of-use race against other processes. -func UnusedPort() (int, error) { - addr, err := net.ResolveTCPAddr("tcp", "localhost:0") - if err != nil { - return 0, fmt.Errorf("internal error while choosing port: %v", err) - } - - l, err := net.ListenTCP("tcp", addr) - if err != nil { - return 0, fmt.Errorf("choosing available port: %v", err) - } - defer l.Close() - return l.Addr().(*net.TCPAddr).Port, nil -} - // StartSSHServer installs (if necessary) and starts the SSH in the codespace. // It returns the remote port where it is running, the user to log in with, or an error if something failed. func StartSSHServer(ctx context.Context, session *liveshare.Session, log logger) (serverPort int, user string, err error) { diff --git a/internal/codespaces/states.go b/internal/codespaces/states.go index 271674e5f..99a713ba8 100644 --- a/internal/codespaces/states.go +++ b/internal/codespaces/states.go @@ -5,6 +5,7 @@ import ( "context" "encoding/json" "fmt" + "net" "strings" "time" @@ -46,10 +47,12 @@ func PollPostCreateStates(ctx context.Context, log logger, apiClient *api.API, u return fmt.Errorf("connect to Live Share: %v", err) } - localSSHPort, err := UnusedPort() + // Ensure local port is listening before client (getPostCreateOutput) connects. + listen, err := liveshare.Listen(0) if err != nil { return err } + localPort := listen.Addr().(*net.TCPAddr).Port remoteSSHServerPort, sshUser, err := StartSSHServer(ctx, session, log) if err != nil { @@ -59,7 +62,7 @@ func PollPostCreateStates(ctx context.Context, log logger, apiClient *api.API, u tunnelClosed := make(chan error, 1) // buffered to avoid sender stuckness go func() { fwd := liveshare.NewPortForwarder(session, "sshd", remoteSSHServerPort) - tunnelClosed <- fwd.ForwardToLocalPort(ctx, localSSHPort) // error is non-nil + tunnelClosed <- fwd.ForwardToLocalPort(ctx, listen) // error is non-nil }() t := time.NewTicker(1 * time.Second) @@ -74,7 +77,7 @@ func PollPostCreateStates(ctx context.Context, log logger, apiClient *api.API, u return fmt.Errorf("connection failed: %v", err) case <-t.C: - states, err := getPostCreateOutput(ctx, localSSHPort, codespace, sshUser) + states, err := getPostCreateOutput(ctx, localPort, codespace, sshUser) if err != nil { return fmt.Errorf("get post create output: %v", err) } From 786a6319959b1fc38adf07b9e867d66cdd1ce352 Mon Sep 17 00:00:00 2001 From: Alan Donovan Date: Thu, 2 Sep 2021 17:21:24 -0400 Subject: [PATCH 153/290] fix local/remote confusion in getPorts (!) --- cmd/ghcs/ports.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cmd/ghcs/ports.go b/cmd/ghcs/ports.go index 800803269..dd8f70131 100644 --- a/cmd/ghcs/ports.go +++ b/cmd/ghcs/ports.go @@ -312,7 +312,7 @@ func getPortPairs(ports []string) ([]portPair, error) { return pp, fmt.Errorf("convert local port to int: %v", err) } - pp = append(pp, portPair{local, remote}) + pp = append(pp, portPair{remote, local}) } return pp, nil From 5bd0519ef32827e59d94003b995a62a8915f48d4 Mon Sep 17 00:00:00 2001 From: Alan Donovan Date: Thu, 2 Sep 2021 16:45:23 -0400 Subject: [PATCH 154/290] move Listen call into clients to avoid race --- port_forwarder.go | 26 ++++++++++++++++---------- port_forwarder_test.go | 10 ++++++++-- 2 files changed, 24 insertions(+), 12 deletions(-) diff --git a/port_forwarder.go b/port_forwarder.go index f4895bb60..fe0d7d80e 100644 --- a/port_forwarder.go +++ b/port_forwarder.go @@ -26,22 +26,28 @@ func NewPortForwarder(session *Session, name string, remotePort int) *PortForwar } } +// ListenTCP calls listen on the chosen local TCP port. Zero picks an arbitrary port. +// It is provided for the convenience of callers of ForwardToLocalPort. +func Listen(port int) (net.Listener, error) { + return net.Listen("tcp", fmt.Sprintf(":%d", port)) +} + // ForwardToLocalPort forwards traffic between the container's remote -// port and a local TCP port. It accepts and handles connections on -// the local port until it encounters the first error, which may -// include context cancellation. Its error result is always non-nil. -func (fwd *PortForwarder) ForwardToLocalPort(ctx context.Context, localPort int) (err error) { +// port and a local TCP port, which must already be listening for +// connections. (Accepting a listener rather than a port number avoids +// races against other processes opening ports, and against a client +// connecting to the socket prematurely.) +// +// ForwardToLocalPort accepts and handles connections on the local +// port until it encounters the first error, which may include context +// cancellation. Its error result is always non-nil. The caller is +// responsible for closing the listening port. +func (fwd *PortForwarder) ForwardToLocalPort(ctx context.Context, listen net.Listener) (err error) { id, err := fwd.shareRemotePort(ctx) if err != nil { return err } - listen, err := net.Listen("tcp", fmt.Sprintf(":%d", localPort)) - if err != nil { - return fmt.Errorf("error listening on TCP port: %v", err) - } - defer safeClose(listen, &err) - errc := make(chan error, 1) sendError := func(err error) { // Use non-blocking send, to avoid goroutines getting diff --git a/port_forwarder_test.go b/port_forwarder_test.go index 6ccb3d05e..68b658b6b 100644 --- a/port_forwarder_test.go +++ b/port_forwarder_test.go @@ -46,13 +46,19 @@ func TestPortForwarderStart(t *testing.T) { } defer testServer.Close() + listen, err := Listen(8000) // local port + if err != nil { + t.Fatal(err) + } + defer listen.Close() + ctx, cancel := context.WithCancel(context.Background()) defer cancel() done := make(chan error) go func() { - const name, local, remote = "ssh", 8000, 8000 - done <- NewPortForwarder(session, name, remote).ForwardToLocalPort(ctx, local) + const name, remote = "ssh", 8000 + done <- NewPortForwarder(session, name, remote).ForwardToLocalPort(ctx, listen) }() go func() { From e2552fbd2a049a8314d83e1b1357b6b5494267dd Mon Sep 17 00:00:00 2001 From: Alan Donovan Date: Fri, 3 Sep 2021 09:43:31 -0400 Subject: [PATCH 155/290] rename to ForwardToListener --- port_forwarder.go | 17 +++++++++-------- port_forwarder_test.go | 4 ++-- 2 files changed, 11 insertions(+), 10 deletions(-) diff --git a/port_forwarder.go b/port_forwarder.go index fe0d7d80e..593b70fe7 100644 --- a/port_forwarder.go +++ b/port_forwarder.go @@ -26,23 +26,24 @@ func NewPortForwarder(session *Session, name string, remotePort int) *PortForwar } } -// ListenTCP calls listen on the chosen local TCP port. Zero picks an arbitrary port. -// It is provided for the convenience of callers of ForwardToLocalPort. -func Listen(port int) (net.Listener, error) { +// ListenTCP calls listen on the chosen local TCP port. Zero picks an +// arbitrary port. It is provided for the convenience of callers of +// ForwardToListener. +func ListenTCP(port int) (net.Listener, error) { return net.Listen("tcp", fmt.Sprintf(":%d", port)) } -// ForwardToLocalPort forwards traffic between the container's remote -// port and a local TCP port, which must already be listening for +// ForwardToListener forwards traffic between the container's remote +// port and a local port, which must already be listening for // connections. (Accepting a listener rather than a port number avoids // races against other processes opening ports, and against a client // connecting to the socket prematurely.) // -// ForwardToLocalPort accepts and handles connections on the local -// port until it encounters the first error, which may include context +// ForwardToListener accepts and handles connections on the local port +// until it encounters the first error, which may include context // cancellation. Its error result is always non-nil. The caller is // responsible for closing the listening port. -func (fwd *PortForwarder) ForwardToLocalPort(ctx context.Context, listen net.Listener) (err error) { +func (fwd *PortForwarder) ForwardToListener(ctx context.Context, listen net.Listener) (err error) { id, err := fwd.shareRemotePort(ctx) if err != nil { return err diff --git a/port_forwarder_test.go b/port_forwarder_test.go index 68b658b6b..d6a4e7708 100644 --- a/port_forwarder_test.go +++ b/port_forwarder_test.go @@ -46,7 +46,7 @@ func TestPortForwarderStart(t *testing.T) { } defer testServer.Close() - listen, err := Listen(8000) // local port + listen, err := ListenTCP(8000) // local port if err != nil { t.Fatal(err) } @@ -58,7 +58,7 @@ func TestPortForwarderStart(t *testing.T) { done := make(chan error) go func() { const name, remote = "ssh", 8000 - done <- NewPortForwarder(session, name, remote).ForwardToLocalPort(ctx, listen) + done <- NewPortForwarder(session, name, remote).ForwardToListener(ctx, listen) }() go func() { From b1d83fe294e3082e4def167e7b891b9c5c81a80c Mon Sep 17 00:00:00 2001 From: Jose Garcia Date: Fri, 3 Sep 2021 11:33:33 -0400 Subject: [PATCH 156/290] codespace flag, deprecate argument --- cmd/ghcs/logs.go | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/cmd/ghcs/logs.go b/cmd/ghcs/logs.go index 49acb3449..fd92ff739 100644 --- a/cmd/ghcs/logs.go +++ b/cmd/ghcs/logs.go @@ -13,21 +13,27 @@ import ( ) func newLogsCmd() *cobra.Command { - var tail bool + var ( + codespace string + tail bool + ) + + log := output.NewLogger(os.Stdout, os.Stderr, false) logsCmd := &cobra.Command{ - Use: "logs []", + Use: "logs", Short: "Access codespace logs", Args: cobra.MaximumNArgs(1), RunE: func(cmd *cobra.Command, args []string) error { - var codespaceName string if len(args) > 0 { - codespaceName = args[0] + log.Println(" argument is deprecated. Use --codespace instead.") + codespace = args[0] } - return logs(context.Background(), tail, codespaceName) + return logs(context.Background(), log, codespace, tail) }, } + logsCmd.Flags().StringVarP(&codespace, "codespace", "c", "", "Name of the codespace") logsCmd.Flags().BoolVarP(&tail, "tail", "t", false, "Tail the logs") return logsCmd @@ -37,13 +43,12 @@ func init() { rootCmd.AddCommand(newLogsCmd()) } -func logs(ctx context.Context, tail bool, codespaceName string) error { +func logs(ctx context.Context, log *output.Logger, codespaceName string, tail bool) error { // Ensure all child tasks (port forwarding, remote exec) terminate before return. ctx, cancel := context.WithCancel(ctx) defer cancel() apiClient := api.New(os.Getenv("GITHUB_TOKEN")) - log := output.NewLogger(os.Stdout, os.Stderr, false) user, err := apiClient.GetUser(ctx) if err != nil { From 3216cbc07f9eb698ae88d3214c0647650a542bb7 Mon Sep 17 00:00:00 2001 From: Jose Garcia Date: Fri, 3 Sep 2021 11:43:10 -0400 Subject: [PATCH 157/290] codespace flag, deprecate argument --- cmd/ghcs/code.go | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/cmd/ghcs/code.go b/cmd/ghcs/code.go index 5bad53648..ba4f0234e 100644 --- a/cmd/ghcs/code.go +++ b/cmd/ghcs/code.go @@ -7,27 +7,34 @@ import ( "os" "github.com/github/ghcs/api" + "github.com/github/ghcs/cmd/ghcs/output" "github.com/github/ghcs/internal/codespaces" "github.com/skratchdot/open-golang/open" "github.com/spf13/cobra" ) func newCodeCmd() *cobra.Command { - useInsiders := false + var ( + codespace string + useInsiders bool + ) + + log := output.NewLogger(os.Stdout, os.Stderr, false) codeCmd := &cobra.Command{ - Use: "code []", + Use: "code", Short: "Open a codespace in VS Code", Args: cobra.MaximumNArgs(1), RunE: func(cmd *cobra.Command, args []string) error { - var codespaceName string if len(args) > 0 { - codespaceName = args[0] + log.Println(" argument is deprecated. Use --codespace instead.") + codespace = args[0] } - return code(codespaceName, useInsiders) + return code(codespace, useInsiders) }, } + codeCmd.Flags().StringVarP(&codespace, "codespace", "c", "", "Name of the codespace") codeCmd.Flags().BoolVar(&useInsiders, "insiders", false, "Use the insiders version of VS Code") return codeCmd From 9dbf267e54fd679563474e442d0956c444f5f328 Mon Sep 17 00:00:00 2001 From: Jose Garcia Date: Fri, 3 Sep 2021 12:33:47 -0400 Subject: [PATCH 158/290] codespace flag, deprecate argument --- cmd/ghcs/ports.go | 93 ++++++++++++++++++++++++++++++++--------------- 1 file changed, 64 insertions(+), 29 deletions(-) diff --git a/cmd/ghcs/ports.go b/cmd/ghcs/ports.go index dd8f70131..4ef460b51 100644 --- a/cmd/ghcs/ports.go +++ b/cmd/ghcs/ports.go @@ -18,31 +18,25 @@ import ( "github.com/spf13/cobra" ) -// portOptions represents the options accepted by the ports command. -type portsOptions struct { - // CodespaceName is the name of the codespace, optional. - codespaceName string - - // AsJSON dictates whether the command returns a json output or not, optional. - asJSON bool -} - // newPortsCmd returns a Cobra "ports" command that displays a table of available ports, // according to the specified flags. func newPortsCmd() *cobra.Command { - opts := &portsOptions{} + var ( + codespace string + asJSON bool + ) portsCmd := &cobra.Command{ Use: "ports", Short: "List ports in a codespace", Args: cobra.NoArgs, RunE: func(cmd *cobra.Command, args []string) error { - return ports(opts) + return ports(codespace, asJSON) }, } - portsCmd.Flags().StringVarP(&opts.codespaceName, "codespace", "c", "", "The `name` of the codespace to use") - portsCmd.Flags().BoolVar(&opts.asJSON, "json", false, "Output as JSON") + portsCmd.Flags().StringVarP(&codespace, "codespace", "c", "", "Name of the codespace") + portsCmd.Flags().BoolVar(&asJSON, "json", false, "Output as JSON") portsCmd.AddCommand(newPortsPublicCmd()) portsCmd.AddCommand(newPortsPrivateCmd()) @@ -55,17 +49,17 @@ func init() { rootCmd.AddCommand(newPortsCmd()) } -func ports(opts *portsOptions) error { +func ports(codespaceName string, asJSON bool) error { apiClient := api.New(os.Getenv("GITHUB_TOKEN")) ctx := context.Background() - log := output.NewLogger(os.Stdout, os.Stderr, opts.asJSON) + log := output.NewLogger(os.Stdout, os.Stderr, asJSON) user, err := apiClient.GetUser(ctx) if err != nil { return fmt.Errorf("error getting user: %v", err) } - codespace, token, err := codespaces.GetOrChooseCodespace(ctx, apiClient, user, opts.codespaceName) + codespace, token, err := codespaces.GetOrChooseCodespace(ctx, apiClient, user, codespaceName) if err != nil { if err == codespaces.ErrNoCodespaces { return err @@ -92,7 +86,7 @@ func ports(opts *portsOptions) error { _, _ = log.Errorf("Failed to get port names: %v\n", devContainerResult.err.Error()) } - table := output.NewTable(os.Stdout, opts.asJSON) + table := output.NewTable(os.Stdout, asJSON) table.SetHeader([]string{"Label", "Port", "Public", "Browse URL"}) for _, port := range ports { sourcePort := strconv.Itoa(port.SourcePort) @@ -161,28 +155,54 @@ func getDevContainer(ctx context.Context, apiClient *api.API, codespace *api.Cod // newPortsPublicCmd returns a Cobra "ports public" subcommand, which makes a given port public. func newPortsPublicCmd() *cobra.Command { - return &cobra.Command{ - Use: "public ", + var codespace string + + newPortsPublicCmd := &cobra.Command{ + Use: "public ", Short: "Mark port as public", - Args: cobra.ExactArgs(2), + Args: cobra.MinimumNArgs(1), RunE: func(cmd *cobra.Command, args []string) error { log := output.NewLogger(os.Stdout, os.Stderr, false) - return updatePortVisibility(log, args[0], args[1], true) + + port := args[0] + if len(args) > 1 { + log.Println(" argument is deprecated. Use --codespace instead.") + codespace, port = args[0], args[1] + } + + return updatePortVisibility(log, codespace, port, true) }, } + + newPortsPublicCmd.Flags().StringVarP(&codespace, "codespace", "c", "", "Name of the codespace") + + return newPortsPublicCmd } // newPortsPrivateCmd returns a Cobra "ports private" subcommand, which makes a given port private. func newPortsPrivateCmd() *cobra.Command { - return &cobra.Command{ - Use: "private ", + var codespace string + + newPortsPrivateCmd := &cobra.Command{ + Use: "private ", Short: "Mark port as private", - Args: cobra.ExactArgs(2), + Args: cobra.MinimumNArgs(1), RunE: func(cmd *cobra.Command, args []string) error { log := output.NewLogger(os.Stdout, os.Stderr, false) - return updatePortVisibility(log, args[0], args[1], false) + + port := args[0] + if len(args) > 1 { + log.Println(" argument is deprecated. Use --codespace instead.") + codespace, port = args[0], args[1] + } + + return updatePortVisibility(log, codespace, port, false) }, } + + newPortsPrivateCmd.Flags().StringVarP(&codespace, "codespace", "c", "", "Name of the codespace") + + return newPortsPrivateCmd } func updatePortVisibility(log *output.Logger, codespaceName, sourcePort string, public bool) error { @@ -230,15 +250,30 @@ func updatePortVisibility(log *output.Logger, codespaceName, sourcePort string, // NewPortsForwardCmd returns a Cobra "ports forward" subcommand, which forwards a set of // port pairs from the codespace to localhost. func newPortsForwardCmd() *cobra.Command { - return &cobra.Command{ - Use: "forward :", + var codespace string + + newPortsForwardCmd := &cobra.Command{ + Use: "forward :...", Short: "Forward ports", - Args: cobra.MinimumNArgs(2), + Args: cobra.MinimumNArgs(1), RunE: func(cmd *cobra.Command, args []string) error { log := output.NewLogger(os.Stdout, os.Stderr, false) - return forwardPorts(log, args[0], args[1:]) + + ports := args[0:] + if len(args) > 1 && !strings.Contains(args[0], ":") { + // assume this is a codespace name + log.Println(" argument is deprecated. Use --codespace instead.") + codespace = args[0] + ports = args[1:] + } + + return forwardPorts(log, codespace, ports) }, } + + newPortsForwardCmd.Flags().StringVarP(&codespace, "codespace", "c", "", "Name of the codespace") + + return newPortsForwardCmd } func forwardPorts(log *output.Logger, codespaceName string, ports []string) error { From 9193b03b696eb0f91eb0f7d1273b3aed38f514f5 Mon Sep 17 00:00:00 2001 From: Jose Garcia Date: Fri, 3 Sep 2021 12:40:01 -0400 Subject: [PATCH 159/290] introduce follow, deprecate tail --- cmd/ghcs/logs.go | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/cmd/ghcs/logs.go b/cmd/ghcs/logs.go index fd92ff739..69ec12c42 100644 --- a/cmd/ghcs/logs.go +++ b/cmd/ghcs/logs.go @@ -16,6 +16,7 @@ func newLogsCmd() *cobra.Command { var ( codespace string tail bool + follow bool ) log := output.NewLogger(os.Stdout, os.Stderr, false) @@ -29,12 +30,17 @@ func newLogsCmd() *cobra.Command { log.Println(" argument is deprecated. Use --codespace instead.") codespace = args[0] } - return logs(context.Background(), log, codespace, tail) + if tail { + log.Println("--tail flag is deprecated. Use --follow instead.") + follow = true + } + return logs(context.Background(), log, codespace, follow) }, } logsCmd.Flags().StringVarP(&codespace, "codespace", "c", "", "Name of the codespace") - logsCmd.Flags().BoolVarP(&tail, "tail", "t", false, "Tail the logs") + logsCmd.Flags().BoolVarP(&tail, "tail", "t", false, "Tail the logs (deprecated, use --follow)") + logsCmd.Flags().BoolVarP(&follow, "follow", "f", false, "Tail and follow the logs") return logsCmd } From 43198b24aa6c5342dba92cbe514ab30f9dea05ac Mon Sep 17 00:00:00 2001 From: Alan Donovan Date: Fri, 3 Sep 2021 12:50:11 -0400 Subject: [PATCH 160/290] use errgroup --- cmd/ghcs/logs.go | 28 +++++++--------------------- cmd/ghcs/ports.go | 18 +++++++----------- cmd/ghcs/ssh.go | 27 ++++++++++----------------- 3 files changed, 24 insertions(+), 49 deletions(-) diff --git a/cmd/ghcs/logs.go b/cmd/ghcs/logs.go index d3b2c063f..5e7e8c0a5 100644 --- a/cmd/ghcs/logs.go +++ b/cmd/ghcs/logs.go @@ -11,6 +11,7 @@ import ( "github.com/github/ghcs/internal/codespaces" "github.com/github/go-liveshare" "github.com/spf13/cobra" + "golang.org/x/sync/errgroup" ) func newLogsCmd() *cobra.Command { @@ -84,27 +85,12 @@ func logs(ctx context.Context, tail bool, codespaceName string) error { ctx, localPort, dst, fmt.Sprintf("%s /workspaces/.codespaces/.persistedshare/creation.log", cmdType), ) - // Error channels are buffered so that neither sending goroutine gets stuck. - - tunnelClosed := make(chan error, 1) - go func() { + group, ctx := errgroup.WithContext(ctx) + group.Go(func() error { fwd := liveshare.NewPortForwarder(session, "sshd", remoteSSHServerPort) - tunnelClosed <- fwd.ForwardToListener(ctx, listen) // error is non-nil - }() - - cmdDone := make(chan error, 1) - go func() { - cmdDone <- cmd.Run() - }() - - select { - case err := <-tunnelClosed: + err := fwd.ForwardToListener(ctx, listen) // error is non-nil return fmt.Errorf("connection closed: %v", err) - - case err := <-cmdDone: - if err != nil { - return fmt.Errorf("error retrieving logs: %v", err) - } - return nil // success - } + }) + group.Go(cmd.Run) + return group.Wait() } diff --git a/cmd/ghcs/ports.go b/cmd/ghcs/ports.go index 6c582c504..958b25996 100644 --- a/cmd/ghcs/ports.go +++ b/cmd/ghcs/ports.go @@ -16,6 +16,7 @@ import ( "github.com/github/go-liveshare" "github.com/muhammadmuzzammil1998/jsonc" "github.com/spf13/cobra" + "golang.org/x/sync/errgroup" ) // portOptions represents the options accepted by the ports command. @@ -272,27 +273,22 @@ func forwardPorts(log *output.Logger, codespaceName string, ports []string) erro // Run forwarding of all ports concurrently, aborting all of // them at the first failure, including cancellation of the context. - errc := make(chan error, len(portPairs)) - ctx, cancel := context.WithCancel(ctx) - defer cancel() + group, ctx := errgroup.WithContext(ctx) for _, pair := range portPairs { pair := pair - - go func() { + group.Go(func() error { listen, err := liveshare.ListenTCP(pair.local) if err != nil { - errc <- err - return + return nil } defer listen.Close() log.Printf("Forwarding ports: remote %d <=> local %d\n", pair.remote, pair.local) name := fmt.Sprintf("share-%d", pair.remote) fwd := liveshare.NewPortForwarder(session, name, pair.remote) - errc <- fwd.ForwardToListener(ctx, listen) // error always non-nil - }() + return fwd.ForwardToListener(ctx, listen) // error always non-nil + }) } - - return <-errc // first error + return group.Wait() // first error } type portPair struct { diff --git a/cmd/ghcs/ssh.go b/cmd/ghcs/ssh.go index d7c0847e7..55a406c94 100644 --- a/cmd/ghcs/ssh.go +++ b/cmd/ghcs/ssh.go @@ -13,6 +13,7 @@ import ( "github.com/github/ghcs/internal/codespaces" "github.com/github/go-liveshare" "github.com/spf13/cobra" + "golang.org/x/sync/errgroup" ) func newSSHCmd() *cobra.Command { @@ -97,28 +98,20 @@ func ssh(ctx context.Context, sshProfile, codespaceName string, localSSHServerPo connectDestination = fmt.Sprintf("%s@localhost", sshUser) } - tunnelClosed := make(chan error) - go func() { - fwd := liveshare.NewPortForwarder(session, "sshd", remoteSSHServerPort) - tunnelClosed <- fwd.ForwardToListener(ctx, listen) // error is always non-nil - }() - - shellClosed := make(chan error) - go func() { - shellClosed <- codespaces.Shell(ctx, log, localSSHServerPort, connectDestination, usingCustomPort) - }() - log.Println("Ready...") - select { - case err := <-tunnelClosed: + group, ctx := errgroup.WithContext(ctx) + group.Go(func() error { + fwd := liveshare.NewPortForwarder(session, "sshd", remoteSSHServerPort) + err := fwd.ForwardToListener(ctx, listen) // always non-nil return fmt.Errorf("tunnel closed: %v", err) - - case err := <-shellClosed: - if err != nil { + }) + group.Go(func() error { + if err := codespaces.Shell(ctx, log, localSSHServerPort, connectDestination, usingCustomPort); err != nil { return fmt.Errorf("shell closed: %v", err) } return nil // success - } + }) + return group.Wait() } func getContainerID(ctx context.Context, logger *output.Logger, terminal *liveshare.Terminal) (string, error) { From 2c660fa2e5a47c499f74aeb7dc522349a5753d3a Mon Sep 17 00:00:00 2001 From: Alan Donovan Date: Fri, 3 Sep 2021 12:55:40 -0400 Subject: [PATCH 161/290] avoid ListenTCP helper --- cmd/ghcs/logs.go | 2 +- cmd/ghcs/ports.go | 3 ++- cmd/ghcs/ssh.go | 2 +- internal/codespaces/states.go | 2 +- 4 files changed, 5 insertions(+), 4 deletions(-) diff --git a/cmd/ghcs/logs.go b/cmd/ghcs/logs.go index 5e7e8c0a5..f069a58ab 100644 --- a/cmd/ghcs/logs.go +++ b/cmd/ghcs/logs.go @@ -63,7 +63,7 @@ func logs(ctx context.Context, tail bool, codespaceName string) error { } // Ensure local port is listening before client (getPostCreateOutput) connects. - listen, err := liveshare.ListenTCP(0) // zero => arbitrary + listen, err := net.Listen("tcp", ":0") // arbitrary port if err != nil { return err } diff --git a/cmd/ghcs/ports.go b/cmd/ghcs/ports.go index 958b25996..fb76022d7 100644 --- a/cmd/ghcs/ports.go +++ b/cmd/ghcs/ports.go @@ -6,6 +6,7 @@ import ( "encoding/json" "errors" "fmt" + "net" "os" "strconv" "strings" @@ -277,7 +278,7 @@ func forwardPorts(log *output.Logger, codespaceName string, ports []string) erro for _, pair := range portPairs { pair := pair group.Go(func() error { - listen, err := liveshare.ListenTCP(pair.local) + listen, err := net.Listen("tcp", fmt.Sprintf(":%d", pair.local)) if err != nil { return nil } diff --git a/cmd/ghcs/ssh.go b/cmd/ghcs/ssh.go index 55a406c94..6e2724e73 100644 --- a/cmd/ghcs/ssh.go +++ b/cmd/ghcs/ssh.go @@ -86,7 +86,7 @@ func ssh(ctx context.Context, sshProfile, codespaceName string, localSSHServerPo usingCustomPort := localSSHServerPort != 0 // suppress log of command line in Shell // Ensure local port is listening before client (Shell) connects. - listen, err := liveshare.ListenTCP(localSSHServerPort) + listen, err := net.Listen("tcp", fmt.Sprintf(":%d", localSSHServerPort)) if err != nil { return err } diff --git a/internal/codespaces/states.go b/internal/codespaces/states.go index 46d4f5ed5..492ce3964 100644 --- a/internal/codespaces/states.go +++ b/internal/codespaces/states.go @@ -48,7 +48,7 @@ func PollPostCreateStates(ctx context.Context, log logger, apiClient *api.API, u } // Ensure local port is listening before client (getPostCreateOutput) connects. - listen, err := liveshare.ListenTCP(0) + listen, err := net.Listen("tcp", ":0") // arbitrary port if err != nil { return err } From 9e81dc7fdef457f09a18bd81a231b34a3e8f7d03 Mon Sep 17 00:00:00 2001 From: Alan Donovan Date: Fri, 3 Sep 2021 12:56:47 -0400 Subject: [PATCH 162/290] fix missing error return --- cmd/ghcs/ports.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cmd/ghcs/ports.go b/cmd/ghcs/ports.go index fb76022d7..4258991b6 100644 --- a/cmd/ghcs/ports.go +++ b/cmd/ghcs/ports.go @@ -280,7 +280,7 @@ func forwardPorts(log *output.Logger, codespaceName string, ports []string) erro group.Go(func() error { listen, err := net.Listen("tcp", fmt.Sprintf(":%d", pair.local)) if err != nil { - return nil + return err } defer listen.Close() log.Printf("Forwarding ports: remote %d <=> local %d\n", pair.remote, pair.local) From 50523c4f1087ea361b1216d894a26c7ca6fb7d46 Mon Sep 17 00:00:00 2001 From: Alan Donovan Date: Fri, 3 Sep 2021 14:39:47 -0400 Subject: [PATCH 163/290] remove ListenTCP and add workaround for ssh.channel.Close EOF --- port_forwarder.go | 17 +++++++++-------- port_forwarder_test.go | 2 +- 2 files changed, 10 insertions(+), 9 deletions(-) diff --git a/port_forwarder.go b/port_forwarder.go index 593b70fe7..dc91222ed 100644 --- a/port_forwarder.go +++ b/port_forwarder.go @@ -26,13 +26,6 @@ func NewPortForwarder(session *Session, name string, remotePort int) *PortForwar } } -// ListenTCP calls listen on the chosen local TCP port. Zero picks an -// arbitrary port. It is provided for the convenience of callers of -// ForwardToListener. -func ListenTCP(port int) (net.Listener, error) { - return net.Listen("tcp", fmt.Sprintf(":%d", port)) -} - // ForwardToListener forwards traffic between the container's remote // port and a local port, which must already be listening for // connections. (Accepting a listener rather than a port number avoids @@ -121,7 +114,15 @@ func (fwd *PortForwarder) handleConnection(ctx context.Context, id channelID, co if err != nil { return fmt.Errorf("error opening streaming channel for new connection: %v", err) } - defer safeClose(channel, &err) + // Ideally we would call safeClose again, but (*ssh.channel).Close + // appears to have a bug that causes it return io.EOF spuriously + // if its peer closed first; see github.com/golang/go/issues/38115. + defer func() { + closeErr := channel.Close() + if err == nil && closeErr != io.EOF { + err = closeErr + } + }() errs := make(chan error, 2) copyConn := func(w io.Writer, r io.Reader) { diff --git a/port_forwarder_test.go b/port_forwarder_test.go index d6a4e7708..c4245f513 100644 --- a/port_forwarder_test.go +++ b/port_forwarder_test.go @@ -46,7 +46,7 @@ func TestPortForwarderStart(t *testing.T) { } defer testServer.Close() - listen, err := ListenTCP(8000) // local port + listen, err := net.Listen("tcp", ":8000") if err != nil { t.Fatal(err) } From b79ea871fd38c1ba4d6b3f8a995a1754a72eb651 Mon Sep 17 00:00:00 2001 From: Jose Garcia Date: Fri, 3 Sep 2021 16:04:00 -0400 Subject: [PATCH 164/290] rename arg --- cmd/ghcs/logs.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cmd/ghcs/logs.go b/cmd/ghcs/logs.go index 69ec12c42..8a83b252a 100644 --- a/cmd/ghcs/logs.go +++ b/cmd/ghcs/logs.go @@ -49,7 +49,7 @@ func init() { rootCmd.AddCommand(newLogsCmd()) } -func logs(ctx context.Context, log *output.Logger, codespaceName string, tail bool) error { +func logs(ctx context.Context, log *output.Logger, codespaceName string, follow bool) error { // Ensure all child tasks (port forwarding, remote exec) terminate before return. ctx, cancel := context.WithCancel(ctx) defer cancel() @@ -82,7 +82,7 @@ func logs(ctx context.Context, log *output.Logger, codespaceName string, tail bo } cmdType := "cat" - if tail { + if follow { cmdType = "tail -f" } From d395dae3a875fb248b6fcc1abe705064d422d057 Mon Sep 17 00:00:00 2001 From: Alan Donovan Date: Mon, 6 Sep 2021 15:17:24 -0400 Subject: [PATCH 165/290] don't double-print errors --- cmd/ghcs/main.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/cmd/ghcs/main.go b/cmd/ghcs/main.go index 2f1515ac0..3e861dacb 100644 --- a/cmd/ghcs/main.go +++ b/cmd/ghcs/main.go @@ -19,8 +19,9 @@ func main() { var version = "DEV" var rootCmd = &cobra.Command{ - Use: "ghcs", - SilenceUsage: true, // don't print usage message after each error (see #80) + Use: "ghcs", + SilenceUsage: true, // don't print usage message after each error (see #80) + SilenceErrors: false, // print errors automatically so that main need not Long: `Unofficial CLI tool to manage GitHub Codespaces. Running commands requires the GITHUB_TOKEN environment variable to be set to a @@ -43,5 +44,4 @@ func explainError(w io.Writer, err error) { fmt.Fprintln(w, "Make sure to enable SSO for your organizations after creating the token.") return } - fmt.Fprintf(w, "%v\n", err) } From fda40a96826f9915a73545894d855f03d47670c7 Mon Sep 17 00:00:00 2001 From: Jose Garcia Date: Wed, 8 Sep 2021 10:29:30 -0400 Subject: [PATCH 166/290] new Errorln method and add comments --- cmd/ghcs/output/logger.go | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/cmd/ghcs/output/logger.go b/cmd/ghcs/output/logger.go index 32d05acc8..a2aa68ba1 100644 --- a/cmd/ghcs/output/logger.go +++ b/cmd/ghcs/output/logger.go @@ -5,6 +5,8 @@ import ( "io" ) +// NewLogger returns a Logger that will write to the given stdout/stderr writers. +// Disable the Logger to prevent it from writing to stdout in a TTY environment. func NewLogger(stdout, stderr io.Writer, disabled bool) *Logger { return &Logger{ out: stdout, @@ -13,12 +15,16 @@ func NewLogger(stdout, stderr io.Writer, disabled bool) *Logger { } } +// Logger writes to the given stdout/stderr writers. +// If not enabled, Print functions will noop but Error functions will continue +// to write to the stderr writer. type Logger struct { out io.Writer errout io.Writer enabled bool } +// Print writes the arguments to the stdout writer. func (l *Logger) Print(v ...interface{}) (int, error) { if !l.enabled { return 0, nil @@ -26,6 +32,7 @@ func (l *Logger) Print(v ...interface{}) (int, error) { return fmt.Fprint(l.out, v...) } +// Println writes the arguments to the stdout writer with a newline at the end. func (l *Logger) Println(v ...interface{}) (int, error) { if !l.enabled { return 0, nil @@ -33,6 +40,7 @@ func (l *Logger) Println(v ...interface{}) (int, error) { return fmt.Fprintln(l.out, v...) } +// Printf writes the formatted arguments to the stdout writer. func (l *Logger) Printf(f string, v ...interface{}) (int, error) { if !l.enabled { return 0, nil @@ -40,6 +48,12 @@ func (l *Logger) Printf(f string, v ...interface{}) (int, error) { return fmt.Fprintf(l.out, f, v...) } +// Errorf writes the formatted arguments to the stderr writer. func (l *Logger) Errorf(f string, v ...interface{}) (int, error) { return fmt.Fprintf(l.errout, f, v...) } + +// Errorln writes the arguments to the stderr writer with a newline at the end. +func (l *Logger) Errorln(v ...interface{}) (int, error) { + return fmt.Fprintln(l.errout, v...) +} From c86cd34f5ef9680de672d388b8d9a1dc2c4e35b3 Mon Sep 17 00:00:00 2001 From: Jose Garcia Date: Wed, 8 Sep 2021 13:38:27 -0400 Subject: [PATCH 167/290] switch to Errorln --- cmd/ghcs/logs.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cmd/ghcs/logs.go b/cmd/ghcs/logs.go index 49369d1c0..15929c9bb 100644 --- a/cmd/ghcs/logs.go +++ b/cmd/ghcs/logs.go @@ -29,11 +29,11 @@ func newLogsCmd() *cobra.Command { Args: cobra.MaximumNArgs(1), RunE: func(cmd *cobra.Command, args []string) error { if len(args) > 0 { - log.Println(" argument is deprecated. Use --codespace instead.") + log.Errorln(" argument is deprecated. Use --codespace instead.") codespace = args[0] } if tail { - log.Println("--tail flag is deprecated. Use --follow instead.") + log.Errorln("--tail flag is deprecated. Use --follow instead.") follow = true } return logs(context.Background(), log, codespace, follow) From d8138c08b8b59a39918c572e06d2d38317beadf0 Mon Sep 17 00:00:00 2001 From: Jose Garcia Date: Wed, 8 Sep 2021 14:56:16 -0400 Subject: [PATCH 168/290] revert errgroup usage for ssh and logs --- cmd/ghcs/logs.go | 26 +++++++++++++++++++------- cmd/ghcs/ssh.go | 24 +++++++++++++++--------- 2 files changed, 34 insertions(+), 16 deletions(-) diff --git a/cmd/ghcs/logs.go b/cmd/ghcs/logs.go index f069a58ab..e35ecb728 100644 --- a/cmd/ghcs/logs.go +++ b/cmd/ghcs/logs.go @@ -11,7 +11,6 @@ import ( "github.com/github/ghcs/internal/codespaces" "github.com/github/go-liveshare" "github.com/spf13/cobra" - "golang.org/x/sync/errgroup" ) func newLogsCmd() *cobra.Command { @@ -85,12 +84,25 @@ func logs(ctx context.Context, tail bool, codespaceName string) error { ctx, localPort, dst, fmt.Sprintf("%s /workspaces/.codespaces/.persistedshare/creation.log", cmdType), ) - group, ctx := errgroup.WithContext(ctx) - group.Go(func() error { + tunnelClosed := make(chan error, 1) + go func() { fwd := liveshare.NewPortForwarder(session, "sshd", remoteSSHServerPort) - err := fwd.ForwardToListener(ctx, listen) // error is non-nil + tunnelClosed <- fwd.ForwardToListener(ctx, listen) // error is non-nil + }() + + cmdDone := make(chan error, 1) + go func() { + cmdDone <- cmd.Run() + }() + + select { + case err := <-tunnelClosed: return fmt.Errorf("connection closed: %v", err) - }) - group.Go(cmd.Run) - return group.Wait() + case err := <-cmdDone: + if err != nil { + return fmt.Errorf("error retrieving logs: %v", err) + } + + return nil // success + } } diff --git a/cmd/ghcs/ssh.go b/cmd/ghcs/ssh.go index 6e2724e73..e2003b347 100644 --- a/cmd/ghcs/ssh.go +++ b/cmd/ghcs/ssh.go @@ -13,7 +13,6 @@ import ( "github.com/github/ghcs/internal/codespaces" "github.com/github/go-liveshare" "github.com/spf13/cobra" - "golang.org/x/sync/errgroup" ) func newSSHCmd() *cobra.Command { @@ -99,19 +98,26 @@ func ssh(ctx context.Context, sshProfile, codespaceName string, localSSHServerPo } log.Println("Ready...") - group, ctx := errgroup.WithContext(ctx) - group.Go(func() error { + tunnelClosed := make(chan error, 1) + go func() { fwd := liveshare.NewPortForwarder(session, "sshd", remoteSSHServerPort) - err := fwd.ForwardToListener(ctx, listen) // always non-nil + tunnelClosed <- fwd.ForwardToListener(ctx, listen) // always non-nil + }() + + shellClosed := make(chan error, 1) + go func() { + shellClosed <- codespaces.Shell(ctx, log, localSSHServerPort, connectDestination, usingCustomPort) + }() + + select { + case err := <-tunnelClosed: return fmt.Errorf("tunnel closed: %v", err) - }) - group.Go(func() error { - if err := codespaces.Shell(ctx, log, localSSHServerPort, connectDestination, usingCustomPort); err != nil { + case err := <-shellClosed: + if err != nil { return fmt.Errorf("shell closed: %v", err) } return nil // success - }) - return group.Wait() + } } func getContainerID(ctx context.Context, logger *output.Logger, terminal *liveshare.Terminal) (string, error) { From 3a46f2ac56e9aa23037f76e56a0d2858ea41f152 Mon Sep 17 00:00:00 2001 From: Alan Donovan Date: Wed, 8 Sep 2021 17:15:42 -0400 Subject: [PATCH 169/290] add --lightstep flag for tracing --- api/api.go | 31 +++++++++++------ cmd/ghcs/main.go | 90 +++++++++++++++++++++++++++++++++++++++++------- 2 files changed, 98 insertions(+), 23 deletions(-) diff --git a/api/api.go b/api/api.go index 7bb43811a..28c7072ea 100644 --- a/api/api.go +++ b/api/api.go @@ -18,6 +18,8 @@ import ( "net/http" "strconv" "strings" + + "github.com/opentracing/opentracing-go" ) const githubAPI = "https://api.github.com" @@ -42,7 +44,7 @@ func (a *API) GetUser(ctx context.Context) (*User, error) { } a.setHeaders(req) - resp, err := a.client.Do(req) + resp, err := a.do(ctx, req, "/user") if err != nil { return nil, fmt.Errorf("error making request: %v", err) } @@ -87,7 +89,7 @@ func (a *API) GetRepository(ctx context.Context, nwo string) (*Repository, error } a.setHeaders(req) - resp, err := a.client.Do(req) + resp, err := a.do(ctx, req, "/repos/*") if err != nil { return nil, fmt.Errorf("error making request: %v", err) } @@ -146,7 +148,7 @@ func (a *API) ListCodespaces(ctx context.Context, user *User) ([]*Codespace, err } a.setHeaders(req) - resp, err := a.client.Do(req) + resp, err := a.do(ctx, req, "/vscs_internal/user/*/codespaces") if err != nil { return nil, fmt.Errorf("error making request: %v", err) } @@ -194,7 +196,7 @@ func (a *API) GetCodespaceToken(ctx context.Context, ownerLogin, codespaceName s } a.setHeaders(req) - resp, err := a.client.Do(req) + resp, err := a.do(ctx, req, "/vscs_internal/user/*/codespaces/*/token") if err != nil { return "", fmt.Errorf("error making request: %v", err) } @@ -228,7 +230,7 @@ func (a *API) GetCodespace(ctx context.Context, token, owner, codespace string) } req.Header.Set("Authorization", "Bearer "+token) - resp, err := a.client.Do(req) + resp, err := a.do(ctx, req, "/vscs_internal/user/*/codespaces/*") if err != nil { return nil, fmt.Errorf("error making request: %v", err) } @@ -262,7 +264,7 @@ func (a *API) StartCodespace(ctx context.Context, token string, codespace *Codes } req.Header.Set("Authorization", "Bearer "+token) - resp, err := a.client.Do(req) + resp, err := a.do(ctx, req, "/vscs_internal/proxy/environments/*/start") if err != nil { return fmt.Errorf("error making request: %v", err) } @@ -299,7 +301,7 @@ func (a *API) GetCodespaceRegionLocation(ctx context.Context) (string, error) { return "", fmt.Errorf("error creating request: %v", err) } - resp, err := a.client.Do(req) + resp, err := a.do(ctx, req, req.URL.String()) if err != nil { return "", fmt.Errorf("error making request: %v", err) } @@ -340,7 +342,7 @@ func (a *API) GetCodespacesSKUs(ctx context.Context, user *User, repository *Rep req.URL.RawQuery = q.Encode() a.setHeaders(req) - resp, err := a.client.Do(req) + resp, err := a.do(ctx, req, "/vscs_internal/user/*/skus") if err != nil { return nil, fmt.Errorf("error making request: %v", err) } @@ -384,7 +386,7 @@ func (a *API) CreateCodespace(ctx context.Context, user *User, repository *Repos } a.setHeaders(req) - resp, err := a.client.Do(req) + resp, err := a.do(ctx, req, "/vscs_internal/user/*/codespaces") if err != nil { return nil, fmt.Errorf("error making request: %v", err) } @@ -414,7 +416,7 @@ func (a *API) DeleteCodespace(ctx context.Context, user *User, token, codespaceN } req.Header.Set("Authorization", "Bearer "+token) - resp, err := a.client.Do(req) + resp, err := a.do(ctx, req, "/vscs_internal/user/*/codespaces/*") if err != nil { return fmt.Errorf("error making request: %v", err) } @@ -446,7 +448,7 @@ func (a *API) GetCodespaceRepositoryContents(ctx context.Context, codespace *Cod req.URL.RawQuery = q.Encode() a.setHeaders(req) - resp, err := a.client.Do(req) + resp, err := a.do(ctx, req, "/repos/*/contents/*") if err != nil { return nil, fmt.Errorf("error making request: %v", err) } @@ -478,6 +480,13 @@ func (a *API) GetCodespaceRepositoryContents(ctx context.Context, codespace *Cod return decoded, nil } +func (a *API) do(ctx context.Context, req *http.Request, spanName string) (*http.Response, error) { + // TODO(adonovan): use NewRequestWithContext(ctx) and drop ctx parameter. + span, ctx := opentracing.StartSpanFromContext(ctx, spanName) + defer span.Finish() + return a.client.Do(req) +} + func (a *API) setHeaders(req *http.Request) { req.Header.Set("Authorization", "Bearer "+a.token) req.Header.Set("Accept", "application/vnd.github.v3+json") diff --git a/cmd/ghcs/main.go b/cmd/ghcs/main.go index 3e861dacb..5d7d7cf23 100644 --- a/cmd/ghcs/main.go +++ b/cmd/ghcs/main.go @@ -4,8 +4,13 @@ import ( "errors" "fmt" "io" + "log" "os" + "strconv" + "strings" + "github.com/lightstep/lightstep-tracer-go" + "github.com/opentracing/opentracing-go" "github.com/spf13/cobra" ) @@ -18,22 +23,33 @@ func main() { var version = "DEV" -var rootCmd = &cobra.Command{ - Use: "ghcs", - SilenceUsage: true, // don't print usage message after each error (see #80) - SilenceErrors: false, // print errors automatically so that main need not - Long: `Unofficial CLI tool to manage GitHub Codespaces. +var rootCmd = newRootCmd() + +func newRootCmd() *cobra.Command { + var lightstep string + + root := &cobra.Command{ + Use: "ghcs", + SilenceUsage: true, // don't print usage message after each error (see #80) + SilenceErrors: false, // print errors automatically so that main need not + Long: `Unofficial CLI tool to manage GitHub Codespaces. Running commands requires the GITHUB_TOKEN environment variable to be set to a token to access the GitHub API with.`, - Version: version, + Version: version, - PersistentPreRunE: func(cmd *cobra.Command, args []string) error { - if os.Getenv("GITHUB_TOKEN") == "" { - return tokenError - } - return nil - }, + PersistentPreRunE: func(cmd *cobra.Command, args []string) error { + if os.Getenv("GITHUB_TOKEN") == "" { + return tokenError + } + initLightstep(lightstep) + return nil + }, + } + + root.PersistentFlags().StringVar(&lightstep, "lightstep", "", "Lightstep tracing endpoint (service:token@host:port)") + + return root } var tokenError = errors.New("GITHUB_TOKEN is missing") @@ -45,3 +61,53 @@ func explainError(w io.Writer, err error) { return } } + +// initLightstep parses the --lightstep=service:token@host:port flag and +// enables tracing if non-empty. +func initLightstep(config string) { + if config == "" { + return + } + + cut := func(s, sep string) (pre, post string) { + if i := strings.Index(s, sep); i >= 0 { + return s[:i], s[i+len(sep):] + } + return s, "" + } + + // Parse service:password@host:port. + serviceToken, hostPort := cut(config, "@") + service, token := cut(serviceToken, ":") + host, port := cut(hostPort, ":") + portI, err := strconv.Atoi(port) + if err != nil { + log.Fatalf("invalid lightstep configuration: %s", config) + } + + // View at https://app.lightstep.com/github-prod/service-directory/ghcs/deployments + // --lightstep=ghcs:dhhPgaoavzIHz3tJMnj3Oz88Md2VC4HpwcZ8mpoWwwuOwcfU3x+K70lLhJJAXsk63T3bWfPXGgrAwTMQxLY=@lightstep-collector.service.iad.github.net:443 + // From https://app.lightstep.com/github-prod/project ghcs + + opentracing.SetGlobalTracer(lightstep.NewTracer(lightstep.Options{ + AccessToken: token, + Collector: lightstep.Endpoint{ + Host: host, + Port: portI, + Plaintext: false, + }, + Tags: opentracing.Tags{ + lightstep.ComponentNameKey: service, + }, + })) + + // Report failure to record traces. + lightstep.SetGlobalEventHandler(func(ev lightstep.Event) { + switch ev := ev.(type) { + case lightstep.EventStatusReport, lightstep.MetricEventStatusReport: + // ignore + default: + log.Printf("[trace] %s", ev) + } + }) +} From 72659a360334186804e5cfcf781d504104ca50a8 Mon Sep 17 00:00:00 2001 From: Alan Donovan Date: Wed, 8 Sep 2021 17:21:54 -0400 Subject: [PATCH 170/290] add lightstep instrumentation --- client.go | 7 +++++++ port_forwarder.go | 5 +++++ rpc.go | 4 ++++ 3 files changed, 16 insertions(+) diff --git a/client.go b/client.go index 0088662f7..566db6cd3 100644 --- a/client.go +++ b/client.go @@ -5,6 +5,7 @@ import ( "crypto/tls" "fmt" + "github.com/opentracing/opentracing-go" "golang.org/x/crypto/ssh" ) @@ -52,6 +53,9 @@ func WithTLSConfig(tlsConfig *tls.Config) ClientOption { // JoinWorkspace connects the client to the server's Live Share // workspace and returns a session representing their connection. func (c *Client) JoinWorkspace(ctx context.Context) (*Session, error) { + span, ctx := opentracing.StartSpanFromContext(ctx, "Client.JoinWorkspace") + defer span.Finish() + clientSocket := newSocket(c.connection, c.tlsConfig) if err := clientSocket.connect(ctx); err != nil { return nil, fmt.Errorf("error connecting websocket: %v", err) @@ -120,6 +124,9 @@ func (s *Session) openStreamingChannel(ctx context.Context, id channelID) (ssh.C return nil, fmt.Errorf("error getting stream id: %v", err) } + span, ctx := opentracing.StartSpanFromContext(ctx, "Session.OpenChannel+SendRequest") + defer span.Finish() + channel, reqs, err := s.ssh.conn.OpenChannel("session", nil) if err != nil { return nil, fmt.Errorf("error opening ssh channel for transport: %v", err) diff --git a/port_forwarder.go b/port_forwarder.go index dc91222ed..7d3363ba2 100644 --- a/port_forwarder.go +++ b/port_forwarder.go @@ -5,6 +5,8 @@ import ( "fmt" "io" "net" + + "github.com/opentracing/opentracing-go" ) // A PortForwarder forwards TCP traffic over a LiveShare session from a port on a remote @@ -108,6 +110,9 @@ func awaitError(ctx context.Context, errc <-chan error) error { // handleConnection handles forwarding for a single accepted connection, then closes it. func (fwd *PortForwarder) handleConnection(ctx context.Context, id channelID, conn io.ReadWriteCloser) (err error) { + span, ctx := opentracing.StartSpanFromContext(ctx, "PortForwarder.handleConnection") + defer span.Finish() + defer safeClose(conn, &err) channel, err := fwd.session.openStreamingChannel(ctx, id) diff --git a/rpc.go b/rpc.go index 237606fe0..10aa2c7eb 100644 --- a/rpc.go +++ b/rpc.go @@ -6,6 +6,7 @@ import ( "io" "sync" + "github.com/opentracing/opentracing-go" "github.com/sourcegraph/jsonrpc2" ) @@ -26,6 +27,9 @@ func (r *rpcClient) connect(ctx context.Context) { } func (r *rpcClient) do(ctx context.Context, method string, args, result interface{}) error { + span, ctx := opentracing.StartSpanFromContext(ctx, method) + defer span.Finish() + waiter, err := r.Conn.DispatchCall(ctx, method, args) if err != nil { return fmt.Errorf("error dispatching %q call: %v", method, err) From c6a991586104cc933e09fd1ba008ec98f6e32073 Mon Sep 17 00:00:00 2001 From: Alan Donovan Date: Wed, 8 Sep 2021 17:15:42 -0400 Subject: [PATCH 171/290] add --lightstep flag for tracing --- api/api.go | 31 ++++++++++------- cmd/ghcs/main.go | 86 +++++++++++++++++++++++++++++++++++++++++------- 2 files changed, 94 insertions(+), 23 deletions(-) diff --git a/api/api.go b/api/api.go index 7bb43811a..28c7072ea 100644 --- a/api/api.go +++ b/api/api.go @@ -18,6 +18,8 @@ import ( "net/http" "strconv" "strings" + + "github.com/opentracing/opentracing-go" ) const githubAPI = "https://api.github.com" @@ -42,7 +44,7 @@ func (a *API) GetUser(ctx context.Context) (*User, error) { } a.setHeaders(req) - resp, err := a.client.Do(req) + resp, err := a.do(ctx, req, "/user") if err != nil { return nil, fmt.Errorf("error making request: %v", err) } @@ -87,7 +89,7 @@ func (a *API) GetRepository(ctx context.Context, nwo string) (*Repository, error } a.setHeaders(req) - resp, err := a.client.Do(req) + resp, err := a.do(ctx, req, "/repos/*") if err != nil { return nil, fmt.Errorf("error making request: %v", err) } @@ -146,7 +148,7 @@ func (a *API) ListCodespaces(ctx context.Context, user *User) ([]*Codespace, err } a.setHeaders(req) - resp, err := a.client.Do(req) + resp, err := a.do(ctx, req, "/vscs_internal/user/*/codespaces") if err != nil { return nil, fmt.Errorf("error making request: %v", err) } @@ -194,7 +196,7 @@ func (a *API) GetCodespaceToken(ctx context.Context, ownerLogin, codespaceName s } a.setHeaders(req) - resp, err := a.client.Do(req) + resp, err := a.do(ctx, req, "/vscs_internal/user/*/codespaces/*/token") if err != nil { return "", fmt.Errorf("error making request: %v", err) } @@ -228,7 +230,7 @@ func (a *API) GetCodespace(ctx context.Context, token, owner, codespace string) } req.Header.Set("Authorization", "Bearer "+token) - resp, err := a.client.Do(req) + resp, err := a.do(ctx, req, "/vscs_internal/user/*/codespaces/*") if err != nil { return nil, fmt.Errorf("error making request: %v", err) } @@ -262,7 +264,7 @@ func (a *API) StartCodespace(ctx context.Context, token string, codespace *Codes } req.Header.Set("Authorization", "Bearer "+token) - resp, err := a.client.Do(req) + resp, err := a.do(ctx, req, "/vscs_internal/proxy/environments/*/start") if err != nil { return fmt.Errorf("error making request: %v", err) } @@ -299,7 +301,7 @@ func (a *API) GetCodespaceRegionLocation(ctx context.Context) (string, error) { return "", fmt.Errorf("error creating request: %v", err) } - resp, err := a.client.Do(req) + resp, err := a.do(ctx, req, req.URL.String()) if err != nil { return "", fmt.Errorf("error making request: %v", err) } @@ -340,7 +342,7 @@ func (a *API) GetCodespacesSKUs(ctx context.Context, user *User, repository *Rep req.URL.RawQuery = q.Encode() a.setHeaders(req) - resp, err := a.client.Do(req) + resp, err := a.do(ctx, req, "/vscs_internal/user/*/skus") if err != nil { return nil, fmt.Errorf("error making request: %v", err) } @@ -384,7 +386,7 @@ func (a *API) CreateCodespace(ctx context.Context, user *User, repository *Repos } a.setHeaders(req) - resp, err := a.client.Do(req) + resp, err := a.do(ctx, req, "/vscs_internal/user/*/codespaces") if err != nil { return nil, fmt.Errorf("error making request: %v", err) } @@ -414,7 +416,7 @@ func (a *API) DeleteCodespace(ctx context.Context, user *User, token, codespaceN } req.Header.Set("Authorization", "Bearer "+token) - resp, err := a.client.Do(req) + resp, err := a.do(ctx, req, "/vscs_internal/user/*/codespaces/*") if err != nil { return fmt.Errorf("error making request: %v", err) } @@ -446,7 +448,7 @@ func (a *API) GetCodespaceRepositoryContents(ctx context.Context, codespace *Cod req.URL.RawQuery = q.Encode() a.setHeaders(req) - resp, err := a.client.Do(req) + resp, err := a.do(ctx, req, "/repos/*/contents/*") if err != nil { return nil, fmt.Errorf("error making request: %v", err) } @@ -478,6 +480,13 @@ func (a *API) GetCodespaceRepositoryContents(ctx context.Context, codespace *Cod return decoded, nil } +func (a *API) do(ctx context.Context, req *http.Request, spanName string) (*http.Response, error) { + // TODO(adonovan): use NewRequestWithContext(ctx) and drop ctx parameter. + span, ctx := opentracing.StartSpanFromContext(ctx, spanName) + defer span.Finish() + return a.client.Do(req) +} + func (a *API) setHeaders(req *http.Request) { req.Header.Set("Authorization", "Bearer "+a.token) req.Header.Set("Accept", "application/vnd.github.v3+json") diff --git a/cmd/ghcs/main.go b/cmd/ghcs/main.go index 3e861dacb..b1108547e 100644 --- a/cmd/ghcs/main.go +++ b/cmd/ghcs/main.go @@ -4,8 +4,13 @@ import ( "errors" "fmt" "io" + "log" "os" + "strconv" + "strings" + "github.com/lightstep/lightstep-tracer-go" + "github.com/opentracing/opentracing-go" "github.com/spf13/cobra" ) @@ -18,22 +23,33 @@ func main() { var version = "DEV" -var rootCmd = &cobra.Command{ - Use: "ghcs", - SilenceUsage: true, // don't print usage message after each error (see #80) - SilenceErrors: false, // print errors automatically so that main need not - Long: `Unofficial CLI tool to manage GitHub Codespaces. +var rootCmd = newRootCmd() + +func newRootCmd() *cobra.Command { + var lightstep string + + root := &cobra.Command{ + Use: "ghcs", + SilenceUsage: true, // don't print usage message after each error (see #80) + SilenceErrors: false, // print errors automatically so that main need not + Long: `Unofficial CLI tool to manage GitHub Codespaces. Running commands requires the GITHUB_TOKEN environment variable to be set to a token to access the GitHub API with.`, - Version: version, + Version: version, - PersistentPreRunE: func(cmd *cobra.Command, args []string) error { - if os.Getenv("GITHUB_TOKEN") == "" { - return tokenError - } - return nil - }, + PersistentPreRunE: func(cmd *cobra.Command, args []string) error { + if os.Getenv("GITHUB_TOKEN") == "" { + return tokenError + } + initLightstep(lightstep) + return nil + }, + } + + root.PersistentFlags().StringVar(&lightstep, "lightstep", "", "Lightstep tracing endpoint (service:token@host:port)") + + return root } var tokenError = errors.New("GITHUB_TOKEN is missing") @@ -45,3 +61,49 @@ func explainError(w io.Writer, err error) { return } } + +// initLightstep parses the --lightstep=service:token@host:port flag and +// enables tracing if non-empty. +func initLightstep(config string) { + if config == "" { + return + } + + cut := func(s, sep string) (pre, post string) { + if i := strings.Index(s, sep); i >= 0 { + return s[:i], s[i+len(sep):] + } + return s, "" + } + + // Parse service:password@host:port. + serviceToken, hostPort := cut(config, "@") + service, token := cut(serviceToken, ":") + host, port := cut(hostPort, ":") + portI, err := strconv.Atoi(port) + if err != nil { + log.Fatalf("invalid lightstep configuration: %s", config) + } + + opentracing.SetGlobalTracer(lightstep.NewTracer(lightstep.Options{ + AccessToken: token, + Collector: lightstep.Endpoint{ + Host: host, + Port: portI, + Plaintext: false, + }, + Tags: opentracing.Tags{ + lightstep.ComponentNameKey: service, + }, + })) + + // Report failure to record traces. + lightstep.SetGlobalEventHandler(func(ev lightstep.Event) { + switch ev := ev.(type) { + case lightstep.EventStatusReport, lightstep.MetricEventStatusReport: + // ignore + default: + log.Printf("[trace] %s", ev) + } + }) +} From 09a660905081d74504a376d262a9886a150526d9 Mon Sep 17 00:00:00 2001 From: Christian Gregg Date: Thu, 9 Sep 2021 12:37:05 +0100 Subject: [PATCH 172/290] Show * after branch name if codespace working directory is dirty Append a `*` to the end of a branch name in `ghcs list` if the working directory of the codespace is dirty (has uncommited or unpushed changes). Closes: #104 --- api/api.go | 10 ++++++++++ cmd/ghcs/list.go | 17 ++++++++++++++++- 2 files changed, 26 insertions(+), 1 deletion(-) diff --git a/api/api.go b/api/api.go index 7bb43811a..00844a422 100644 --- a/api/api.go +++ b/api/api.go @@ -124,6 +124,16 @@ type Codespace struct { type CodespaceEnvironment struct { State string `json:"state"` Connection CodespaceEnvironmentConnection `json:"connection"` + GitStatus CodespaceEnvironmentGitStatus `json:"gitStatus"` +} + +type CodespaceEnvironmentGitStatus struct { + Ahead int `json:"ahead"` + Behind int `json:"behind"` + Branch string `json:"branch"` + Commit string `json:"commit"` + HasUnpushedChanges bool `json:"hasUnpushedChanges"` + HasUncommitedChanges bool `json:"hasUncommitedChanges"` } const ( diff --git a/cmd/ghcs/list.go b/cmd/ghcs/list.go index c6075a988..ee26e3013 100644 --- a/cmd/ghcs/list.go +++ b/cmd/ghcs/list.go @@ -53,10 +53,25 @@ func list(opts *listOptions) error { table.SetHeader([]string{"Name", "Repository", "Branch", "State", "Created At"}) for _, codespace := range codespaces { table.Append([]string{ - codespace.Name, codespace.RepositoryNWO, codespace.Branch, codespace.Environment.State, codespace.CreatedAt, + codespace.Name, + codespace.RepositoryNWO, + branch(codespace), + codespace.Environment.State, + codespace.CreatedAt, }) } table.Render() return nil } + +func branch(codespace *api.Codespace) string { + name := codespace.Branch + gitStatus := codespace.Environment.GitStatus + + if gitStatus.HasUncommitedChanges || gitStatus.HasUnpushedChanges { + name += "*" + } + + return name +} From 3b198c1707737ed000febd0eb9b469452231f097 Mon Sep 17 00:00:00 2001 From: Jose Garcia Date: Thu, 9 Sep 2021 10:09:14 -0400 Subject: [PATCH 173/290] switch to Errorln --- cmd/ghcs/code.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cmd/ghcs/code.go b/cmd/ghcs/code.go index ba4f0234e..d14a66926 100644 --- a/cmd/ghcs/code.go +++ b/cmd/ghcs/code.go @@ -27,7 +27,7 @@ func newCodeCmd() *cobra.Command { Args: cobra.MaximumNArgs(1), RunE: func(cmd *cobra.Command, args []string) error { if len(args) > 0 { - log.Println(" argument is deprecated. Use --codespace instead.") + log.Errorln(" argument is deprecated. Use --codespace instead.") codespace = args[0] } return code(codespace, useInsiders) From 230bf640c5f020ec866596c5fbf53011c355b54d Mon Sep 17 00:00:00 2001 From: Jose Garcia Date: Thu, 9 Sep 2021 11:06:18 -0400 Subject: [PATCH 174/290] global flag, choose codespace when empty --- cmd/ghcs/ports.go | 64 +++++++++++++++++-------------- internal/codespaces/codespaces.go | 2 + 2 files changed, 38 insertions(+), 28 deletions(-) diff --git a/cmd/ghcs/ports.go b/cmd/ghcs/ports.go index 2fab4254d..2bb3e6917 100644 --- a/cmd/ghcs/ports.go +++ b/cmd/ghcs/ports.go @@ -37,7 +37,7 @@ func newPortsCmd() *cobra.Command { }, } - portsCmd.Flags().StringVarP(&codespace, "codespace", "c", "", "Name of the codespace") + portsCmd.PersistentFlags().StringVarP(&codespace, "codespace", "c", "", "Name of the codespace") portsCmd.Flags().BoolVar(&asJSON, "json", false, "Output as JSON") portsCmd.AddCommand(newPortsPublicCmd()) @@ -157,18 +157,24 @@ func getDevContainer(ctx context.Context, apiClient *api.API, codespace *api.Cod // newPortsPublicCmd returns a Cobra "ports public" subcommand, which makes a given port public. func newPortsPublicCmd() *cobra.Command { - var codespace string - newPortsPublicCmd := &cobra.Command{ Use: "public ", Short: "Mark port as public", Args: cobra.MinimumNArgs(1), RunE: func(cmd *cobra.Command, args []string) error { + codespace, err := cmd.Flags().GetString("codespace") + if err != nil { + // should only happen if flag is not defined + // or if the flag is not of string type + // since it's a persistent flag that we control it should never happen + return fmt.Errorf("get codespace flag: %v", err) + } + log := output.NewLogger(os.Stdout, os.Stderr, false) port := args[0] if len(args) > 1 { - log.Println(" argument is deprecated. Use --codespace instead.") + log.Errorln(" argument is deprecated. Use --codespace instead.") codespace, port = args[0], args[1] } @@ -176,25 +182,29 @@ func newPortsPublicCmd() *cobra.Command { }, } - newPortsPublicCmd.Flags().StringVarP(&codespace, "codespace", "c", "", "Name of the codespace") - return newPortsPublicCmd } // newPortsPrivateCmd returns a Cobra "ports private" subcommand, which makes a given port private. func newPortsPrivateCmd() *cobra.Command { - var codespace string - newPortsPrivateCmd := &cobra.Command{ Use: "private ", Short: "Mark port as private", Args: cobra.MinimumNArgs(1), RunE: func(cmd *cobra.Command, args []string) error { + codespace, err := cmd.Flags().GetString("codespace") + if err != nil { + // should only happen if flag is not defined + // or if the flag is not of string type + // since it's a persistent flag that we control it should never happen + return fmt.Errorf("get codespace flag: %v", err) + } + log := output.NewLogger(os.Stdout, os.Stderr, false) port := args[0] if len(args) > 1 { - log.Println(" argument is deprecated. Use --codespace instead.") + log.Errorln(" argument is deprecated. Use --codespace instead.") codespace, port = args[0], args[1] } @@ -202,8 +212,6 @@ func newPortsPrivateCmd() *cobra.Command { }, } - newPortsPrivateCmd.Flags().StringVarP(&codespace, "codespace", "c", "", "Name of the codespace") - return newPortsPrivateCmd } @@ -216,13 +224,11 @@ func updatePortVisibility(log *output.Logger, codespaceName, sourcePort string, return fmt.Errorf("error getting user: %v", err) } - token, err := apiClient.GetCodespaceToken(ctx, user.Login, codespaceName) - if err != nil { - return fmt.Errorf("error getting codespace token: %v", err) - } - - codespace, err := apiClient.GetCodespace(ctx, token, user.Login, codespaceName) + codespace, token, err := codespaces.GetOrChooseCodespace(ctx, apiClient, user, codespaceName) if err != nil { + if err == codespaces.ErrNoCodespaces { + return err + } return fmt.Errorf("error getting codespace: %v", err) } @@ -252,19 +258,25 @@ func updatePortVisibility(log *output.Logger, codespaceName, sourcePort string, // NewPortsForwardCmd returns a Cobra "ports forward" subcommand, which forwards a set of // port pairs from the codespace to localhost. func newPortsForwardCmd() *cobra.Command { - var codespace string - newPortsForwardCmd := &cobra.Command{ Use: "forward :...", Short: "Forward ports", Args: cobra.MinimumNArgs(1), RunE: func(cmd *cobra.Command, args []string) error { + codespace, err := cmd.Flags().GetString("codespace") + if err != nil { + // should only happen if flag is not defined + // or if the flag is not of string type + // since it's a persistent flag that we control it should never happen + return fmt.Errorf("get codespace flag: %v", err) + } + log := output.NewLogger(os.Stdout, os.Stderr, false) ports := args[0:] if len(args) > 1 && !strings.Contains(args[0], ":") { // assume this is a codespace name - log.Println(" argument is deprecated. Use --codespace instead.") + log.Errorln(" argument is deprecated. Use --codespace instead.") codespace = args[0] ports = args[1:] } @@ -273,8 +285,6 @@ func newPortsForwardCmd() *cobra.Command { }, } - newPortsForwardCmd.Flags().StringVarP(&codespace, "codespace", "c", "", "Name of the codespace") - return newPortsForwardCmd } @@ -292,13 +302,11 @@ func forwardPorts(log *output.Logger, codespaceName string, ports []string) erro return fmt.Errorf("error getting user: %v", err) } - token, err := apiClient.GetCodespaceToken(ctx, user.Login, codespaceName) - if err != nil { - return fmt.Errorf("error getting codespace token: %v", err) - } - - codespace, err := apiClient.GetCodespace(ctx, token, user.Login, codespaceName) + codespace, token, err := codespaces.GetOrChooseCodespace(ctx, apiClient, user, codespaceName) if err != nil { + if err == codespaces.ErrNoCodespaces { + return err + } return fmt.Errorf("error getting codespace: %v", err) } diff --git a/internal/codespaces/codespaces.go b/internal/codespaces/codespaces.go index b5fa4a583..fd04b303e 100644 --- a/internal/codespaces/codespaces.go +++ b/internal/codespaces/codespaces.go @@ -124,6 +124,8 @@ func ConnectToLiveshare(ctx context.Context, log logger, apiClient *api.API, use return lsclient.JoinWorkspace(ctx) } +// GetOrChooseCodespace prompts the user to choose a codespace if the codespaceName is empty. +// It then fetches the codespace token and the codespace record. func GetOrChooseCodespace(ctx context.Context, apiClient *api.API, user *api.User, codespaceName string) (codespace *api.Codespace, token string, err error) { if codespaceName == "" { codespace, err = ChooseCodespace(ctx, apiClient, user) From cbb82535448b11b53dfd51ea8a67c37e6a9ac1f4 Mon Sep 17 00:00:00 2001 From: Alan Donovan Date: Thu, 9 Sep 2021 10:28:55 -0400 Subject: [PATCH 175/290] consolidate survey functions --- cmd/ghcs/code.go | 5 +- cmd/ghcs/common.go | 101 ++++++++++++++++++++++++++++++ cmd/ghcs/create.go | 5 -- cmd/ghcs/delete.go | 3 +- cmd/ghcs/logs.go | 2 +- cmd/ghcs/ports.go | 4 +- cmd/ghcs/ssh.go | 2 +- internal/codespaces/codespaces.go | 80 ----------------------- 8 files changed, 108 insertions(+), 94 deletions(-) create mode 100644 cmd/ghcs/common.go diff --git a/cmd/ghcs/code.go b/cmd/ghcs/code.go index 5bad53648..81c2cc9a2 100644 --- a/cmd/ghcs/code.go +++ b/cmd/ghcs/code.go @@ -7,7 +7,6 @@ import ( "os" "github.com/github/ghcs/api" - "github.com/github/ghcs/internal/codespaces" "github.com/skratchdot/open-golang/open" "github.com/spf13/cobra" ) @@ -47,9 +46,9 @@ func code(codespaceName string, useInsiders bool) error { } if codespaceName == "" { - codespace, err := codespaces.ChooseCodespace(ctx, apiClient, user) + codespace, err := chooseCodespace(ctx, apiClient, user) if err != nil { - if err == codespaces.ErrNoCodespaces { + if err == errNoCodespaces { return err } return fmt.Errorf("error choosing codespace: %v", err) diff --git a/cmd/ghcs/common.go b/cmd/ghcs/common.go new file mode 100644 index 000000000..712c04daa --- /dev/null +++ b/cmd/ghcs/common.go @@ -0,0 +1,101 @@ +package main + +// This file defines functions common to the entire ghcs command set. + +import ( + "context" + "errors" + "fmt" + "sort" + + "github.com/AlecAivazis/survey/v2" + "github.com/github/ghcs/api" + "golang.org/x/term" +) + +var errNoCodespaces = errors.New("You have no codespaces.") + +func chooseCodespace(ctx context.Context, apiClient *api.API, user *api.User) (*api.Codespace, error) { + codespaces, err := apiClient.ListCodespaces(ctx, user) + if err != nil { + return nil, fmt.Errorf("error getting codespaces: %v", err) + } + + if len(codespaces) == 0 { + return nil, errNoCodespaces + } + + sort.Slice(codespaces, func(i, j int) bool { + return codespaces[i].CreatedAt > codespaces[j].CreatedAt + }) + + codespacesByName := make(map[string]*api.Codespace) + codespacesNames := make([]string, 0, len(codespaces)) + for _, codespace := range codespaces { + codespacesByName[codespace.Name] = codespace + codespacesNames = append(codespacesNames, codespace.Name) + } + + sshSurvey := []*survey.Question{ + { + Name: "codespace", + Prompt: &survey.Select{ + Message: "Choose codespace:", + Options: codespacesNames, + Default: codespacesNames[0], + }, + Validate: survey.Required, + }, + } + + var answers struct { + Codespace string + } + if err := ask(sshSurvey, &answers); err != nil { + return nil, fmt.Errorf("error getting answers: %v", err) + } + + codespace := codespacesByName[answers.Codespace] + return codespace, nil +} + +func getOrChooseCodespace(ctx context.Context, apiClient *api.API, user *api.User, codespaceName string) (codespace *api.Codespace, token string, err error) { + if codespaceName == "" { + codespace, err = chooseCodespace(ctx, apiClient, user) + if err != nil { + if err == errNoCodespaces { + return nil, "", err + } + return nil, "", fmt.Errorf("choosing codespace: %v", err) + } + codespaceName = codespace.Name + + token, err = apiClient.GetCodespaceToken(ctx, user.Login, codespaceName) + if err != nil { + return nil, "", fmt.Errorf("getting codespace token: %v", err) + } + } else { + token, err = apiClient.GetCodespaceToken(ctx, user.Login, codespaceName) + if err != nil { + return nil, "", fmt.Errorf("getting codespace token for given codespace: %v", err) + } + + codespace, err = apiClient.GetCodespace(ctx, token, user.Login, codespaceName) + if err != nil { + return nil, "", fmt.Errorf("getting full codespace details: %v", err) + } + } + + return codespace, token, nil +} + +var hasTTY = term.IsTerminal(0) && term.IsTerminal(1) // is process connected to a terminal? + +// ask asks survey questions on the terminal, using standard options. +// It fails unless hasTTY, but ideally callers should avoid calling it in that case. +func ask(qs []*survey.Question, response interface{}) error { + if !hasTTY { + return fmt.Errorf("no terminal") + } + return survey.Ask(qs, response, survey.WithShowCursor(true)) +} diff --git a/cmd/ghcs/create.go b/cmd/ghcs/create.go index bd1d89e4e..093450e7d 100644 --- a/cmd/ghcs/create.go +++ b/cmd/ghcs/create.go @@ -286,8 +286,3 @@ func getMachineName(ctx context.Context, machine string, user *api.User, repo *a return machine, nil } - -// ask asks survery questions using standard options. -func ask(qs []*survey.Question, response interface{}) error { - return survey.Ask(qs, response, survey.WithShowCursor(true)) -} diff --git a/cmd/ghcs/delete.go b/cmd/ghcs/delete.go index 92c405766..75b9362bb 100644 --- a/cmd/ghcs/delete.go +++ b/cmd/ghcs/delete.go @@ -8,7 +8,6 @@ import ( "github.com/github/ghcs/api" "github.com/github/ghcs/cmd/ghcs/output" - "github.com/github/ghcs/internal/codespaces" "github.com/spf13/cobra" ) @@ -63,7 +62,7 @@ func delete_(codespaceName string) error { return fmt.Errorf("error getting user: %v", err) } - codespace, token, err := codespaces.GetOrChooseCodespace(ctx, apiClient, user, codespaceName) + codespace, token, err := getOrChooseCodespace(ctx, apiClient, user, codespaceName) if err != nil { return fmt.Errorf("get or choose codespace: %v", err) } diff --git a/cmd/ghcs/logs.go b/cmd/ghcs/logs.go index e35ecb728..9998b5f0f 100644 --- a/cmd/ghcs/logs.go +++ b/cmd/ghcs/logs.go @@ -51,7 +51,7 @@ func logs(ctx context.Context, tail bool, codespaceName string) error { return fmt.Errorf("getting user: %v", err) } - codespace, token, err := codespaces.GetOrChooseCodespace(ctx, apiClient, user, codespaceName) + codespace, token, err := getOrChooseCodespace(ctx, apiClient, user, codespaceName) if err != nil { return fmt.Errorf("get or choose codespace: %v", err) } diff --git a/cmd/ghcs/ports.go b/cmd/ghcs/ports.go index 4258991b6..c175549a0 100644 --- a/cmd/ghcs/ports.go +++ b/cmd/ghcs/ports.go @@ -67,9 +67,9 @@ func ports(opts *portsOptions) error { return fmt.Errorf("error getting user: %v", err) } - codespace, token, err := codespaces.GetOrChooseCodespace(ctx, apiClient, user, opts.codespaceName) + codespace, token, err := getOrChooseCodespace(ctx, apiClient, user, opts.codespaceName) if err != nil { - if err == codespaces.ErrNoCodespaces { + if err == errNoCodespaces { return err } return fmt.Errorf("error choosing codespace: %v", err) diff --git a/cmd/ghcs/ssh.go b/cmd/ghcs/ssh.go index e2003b347..aefb959f3 100644 --- a/cmd/ghcs/ssh.go +++ b/cmd/ghcs/ssh.go @@ -52,7 +52,7 @@ func ssh(ctx context.Context, sshProfile, codespaceName string, localSSHServerPo return fmt.Errorf("error getting user: %v", err) } - codespace, token, err := codespaces.GetOrChooseCodespace(ctx, apiClient, user, codespaceName) + codespace, token, err := getOrChooseCodespace(ctx, apiClient, user, codespaceName) if err != nil { return fmt.Errorf("get or choose codespace: %v", err) } diff --git a/internal/codespaces/codespaces.go b/internal/codespaces/codespaces.go index b5fa4a583..9aee3564c 100644 --- a/internal/codespaces/codespaces.go +++ b/internal/codespaces/codespaces.go @@ -4,62 +4,12 @@ import ( "context" "errors" "fmt" - "sort" "time" - "github.com/AlecAivazis/survey/v2" "github.com/github/ghcs/api" "github.com/github/go-liveshare" ) -var ( - ErrNoCodespaces = errors.New("You have no codespaces.") -) - -func ChooseCodespace(ctx context.Context, apiClient *api.API, user *api.User) (*api.Codespace, error) { - codespaces, err := apiClient.ListCodespaces(ctx, user) - if err != nil { - return nil, fmt.Errorf("error getting codespaces: %v", err) - } - - if len(codespaces) == 0 { - return nil, ErrNoCodespaces - } - - sort.Slice(codespaces, func(i, j int) bool { - return codespaces[i].CreatedAt > codespaces[j].CreatedAt - }) - - codespacesByName := make(map[string]*api.Codespace) - codespacesNames := make([]string, 0, len(codespaces)) - for _, codespace := range codespaces { - codespacesByName[codespace.Name] = codespace - codespacesNames = append(codespacesNames, codespace.Name) - } - - sshSurvey := []*survey.Question{ - { - Name: "codespace", - Prompt: &survey.Select{ - Message: "Choose codespace:", - Options: codespacesNames, - Default: codespacesNames[0], - }, - Validate: survey.Required, - }, - } - - answers := struct { - Codespace string - }{} - if err := survey.Ask(sshSurvey, &answers); err != nil { - return nil, fmt.Errorf("error getting answers: %v", err) - } - - codespace := codespacesByName[answers.Codespace] - return codespace, nil -} - type logger interface { Print(v ...interface{}) (int, error) Println(v ...interface{}) (int, error) @@ -123,33 +73,3 @@ func ConnectToLiveshare(ctx context.Context, log logger, apiClient *api.API, use return lsclient.JoinWorkspace(ctx) } - -func GetOrChooseCodespace(ctx context.Context, apiClient *api.API, user *api.User, codespaceName string) (codespace *api.Codespace, token string, err error) { - if codespaceName == "" { - codespace, err = ChooseCodespace(ctx, apiClient, user) - if err != nil { - if err == ErrNoCodespaces { - return nil, "", err - } - return nil, "", fmt.Errorf("choosing codespace: %v", err) - } - codespaceName = codespace.Name - - token, err = apiClient.GetCodespaceToken(ctx, user.Login, codespaceName) - if err != nil { - return nil, "", fmt.Errorf("getting codespace token: %v", err) - } - } else { - token, err = apiClient.GetCodespaceToken(ctx, user.Login, codespaceName) - if err != nil { - return nil, "", fmt.Errorf("getting codespace token for given codespace: %v", err) - } - - codespace, err = apiClient.GetCodespace(ctx, token, user.Login, codespaceName) - if err != nil { - return nil, "", fmt.Errorf("getting full codespace details: %v", err) - } - } - - return codespace, token, nil -} From 81f08e7baf9ec7f3610d852b08834b338a77b18d Mon Sep 17 00:00:00 2001 From: Jose Garcia Date: Thu, 9 Sep 2021 12:08:07 -0400 Subject: [PATCH 176/290] start converting to flags --- cmd/ghcs/ssh.go | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/cmd/ghcs/ssh.go b/cmd/ghcs/ssh.go index e2003b347..cd552099f 100644 --- a/cmd/ghcs/ssh.go +++ b/cmd/ghcs/ssh.go @@ -22,15 +22,14 @@ func newSSHCmd() *cobra.Command { sshCmd := &cobra.Command{ Use: "ssh", Short: "SSH into a codespace", - Args: cobra.NoArgs, RunE: func(cmd *cobra.Command, args []string) error { return ssh(context.Background(), sshProfile, codespaceName, sshServerPort) }, } - sshCmd.Flags().StringVarP(&sshProfile, "profile", "", "", "The `name` of the SSH profile to use") + sshCmd.Flags().StringVarP(&sshProfile, "profile", "", "", "Name of the SSH profile to use") sshCmd.Flags().IntVarP(&sshServerPort, "server-port", "", 0, "SSH server port number (0 => pick unused)") - sshCmd.Flags().StringVarP(&codespaceName, "codespace", "c", "", "The `name` of the codespace to use") + sshCmd.Flags().StringVarP(&codespaceName, "codespace", "c", "", "Name of the codespace") return sshCmd } @@ -67,10 +66,10 @@ func ssh(ctx context.Context, sshProfile, codespaceName string, localSSHServerPo return fmt.Errorf("error getting ssh server details: %v", err) } - terminal := liveshare.NewTerminal(session) - log.Print("Preparing SSH...") if sshProfile == "" { + terminal := liveshare.NewTerminal(session) + containerID, err := getContainerID(ctx, log, terminal) if err != nil { return fmt.Errorf("error getting container id: %v", err) From 8b0e8c990e68dc9b74ed3d40f4c73c9a24bacb7b Mon Sep 17 00:00:00 2001 From: Jose Garcia Date: Thu, 9 Sep 2021 17:31:18 +0000 Subject: [PATCH 177/290] ignore pf conn errors --- port_forwarder.go | 25 ++++++++----------------- 1 file changed, 8 insertions(+), 17 deletions(-) diff --git a/port_forwarder.go b/port_forwarder.go index dc91222ed..1351025cb 100644 --- a/port_forwarder.go +++ b/port_forwarder.go @@ -33,9 +33,7 @@ func NewPortForwarder(session *Session, name string, remotePort int) *PortForwar // connecting to the socket prematurely.) // // ForwardToListener accepts and handles connections on the local port -// until it encounters the first error, which may include context -// cancellation. Its error result is always non-nil. The caller is -// responsible for closing the listening port. +// until the context is cancelled. The caller is responsible for closing the listening port. func (fwd *PortForwarder) ForwardToListener(ctx context.Context, listen net.Listener) (err error) { id, err := fwd.shareRemotePort(ctx) if err != nil { @@ -124,21 +122,14 @@ func (fwd *PortForwarder) handleConnection(ctx context.Context, id channelID, co } }() - errs := make(chan error, 2) - copyConn := func(w io.Writer, r io.Reader) { - _, err := io.Copy(w, r) - errs <- err - } - go copyConn(conn, channel) - go copyConn(channel, conn) + // Bi-directional copy of data. + // If any individual connection has an error, we can safely ignore them + // and defer to connection clients to handle data loss as necessary. + go io.Copy(conn, channel) + go io.Copy(channel, conn) - // await result - for i := 0; i < 2; i++ { - if err := <-errs; err != nil && err != io.EOF { - return fmt.Errorf("tunnel connection: %v", err) - } - } - return nil + <-ctx.Done() + return ctx.Err() } // safeClose reports the error (to *err) from closing the stream only From ee44ecc944cbd6288dadec9c9b723367f81b9f8c Mon Sep 17 00:00:00 2001 From: Alan Donovan Date: Thu, 9 Sep 2021 13:42:44 -0400 Subject: [PATCH 178/290] include span context in HTTP request --- api/api.go | 1 + 1 file changed, 1 insertion(+) diff --git a/api/api.go b/api/api.go index 28c7072ea..cb9f9371a 100644 --- a/api/api.go +++ b/api/api.go @@ -484,6 +484,7 @@ func (a *API) do(ctx context.Context, req *http.Request, spanName string) (*http // TODO(adonovan): use NewRequestWithContext(ctx) and drop ctx parameter. span, ctx := opentracing.StartSpanFromContext(ctx, spanName) defer span.Finish() + req = req.WithContext(ctx) return a.client.Do(req) } From 1ff5c514fb82458558757fc8f1fcdf6cc838afc0 Mon Sep 17 00:00:00 2001 From: Jose Garcia Date: Thu, 9 Sep 2021 18:35:05 +0000 Subject: [PATCH 179/290] fix erroneous ctx waiting and introduce back io.EOF handling --- port_forwarder.go | 27 +++++++++++++++++++++------ 1 file changed, 21 insertions(+), 6 deletions(-) diff --git a/port_forwarder.go b/port_forwarder.go index 1351025cb..f47d11565 100644 --- a/port_forwarder.go +++ b/port_forwarder.go @@ -123,13 +123,28 @@ func (fwd *PortForwarder) handleConnection(ctx context.Context, id channelID, co }() // Bi-directional copy of data. - // If any individual connection has an error, we can safely ignore them - // and defer to connection clients to handle data loss as necessary. - go io.Copy(conn, channel) - go io.Copy(channel, conn) + errs := make(chan error, 2) + copyConn := func(w io.Writer, r io.Reader) { + _, err := io.Copy(w, r) + errs <- err + } + go copyConn(conn, channel) + go copyConn(channel, conn) - <-ctx.Done() - return ctx.Err() + // wait until context is cancelled or we've received two io.EOF +Loop: + for i := 0; i < 2; i++ { + select { + case <-ctx.Done(): + return ctx.Err() + case err := <-errs: + if err != nil && err != io.EOF { + break Loop // non-EOF errors stop connection handling + } + } + } + + return nil } // safeClose reports the error (to *err) from closing the stream only From 920f793c6ddf001901308df947091c9d98563fdd Mon Sep 17 00:00:00 2001 From: Jose Garcia Date: Thu, 9 Sep 2021 19:33:16 +0000 Subject: [PATCH 180/290] pr feedback --- port_forwarder.go | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/port_forwarder.go b/port_forwarder.go index f47d11565..e8649c693 100644 --- a/port_forwarder.go +++ b/port_forwarder.go @@ -122,7 +122,7 @@ func (fwd *PortForwarder) handleConnection(ctx context.Context, id channelID, co } }() - // Bi-directional copy of data. + // bi-directional copy of data. errs := make(chan error, 2) copyConn := func(w io.Writer, r io.Reader) { _, err := io.Copy(w, r) @@ -131,20 +131,18 @@ func (fwd *PortForwarder) handleConnection(ctx context.Context, id channelID, co go copyConn(conn, channel) go copyConn(channel, conn) - // wait until context is cancelled or we've received two io.EOF -Loop: - for i := 0; i < 2; i++ { + // wait until context is cancelled or both copies are done + for i := 0; ; { select { case <-ctx.Done(): return ctx.Err() - case err := <-errs: - if err != nil && err != io.EOF { - break Loop // non-EOF errors stop connection handling + case <-errs: + i++ + if i == 2 { + return nil } } } - - return nil } // safeClose reports the error (to *err) from closing the stream only From efe519cb7af8015a0747ac20d91877627b21b8cc Mon Sep 17 00:00:00 2001 From: Jose Garcia Date: Thu, 9 Sep 2021 20:11:45 +0000 Subject: [PATCH 181/290] comments + fix Forward method --- port_forwarder.go | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/port_forwarder.go b/port_forwarder.go index e8649c693..be191c211 100644 --- a/port_forwarder.go +++ b/port_forwarder.go @@ -80,9 +80,7 @@ func (fwd *PortForwarder) Forward(ctx context.Context, conn io.ReadWriteCloser) // Create buffered channel so that send doesn't get stuck after context cancellation. errc := make(chan error, 1) go func() { - if err := fwd.handleConnection(ctx, id, conn); err != nil { - errc <- err - } + errc <- fwd.handleConnection(ctx, id, conn) }() return awaitError(ctx, errc) } @@ -131,7 +129,9 @@ func (fwd *PortForwarder) handleConnection(ctx context.Context, id channelID, co go copyConn(conn, channel) go copyConn(channel, conn) - // wait until context is cancelled or both copies are done + // Wait until context is cancelled or both copies are done. + // Discard errors from io.Copy; they should not cause (e.g.) ForwardToListener to fail. + // TODO: how can we proxy errors from Copy so that each peer can distinguish an error from a short file? for i := 0; ; { select { case <-ctx.Done(): From 22f9824ec8da912e5a4cecb3b77b7f23130f1829 Mon Sep 17 00:00:00 2001 From: Alan Donovan Date: Thu, 9 Sep 2021 16:31:15 -0400 Subject: [PATCH 182/290] deliver SIGINT to self after Ctrl-C in survey --- cmd/ghcs/common.go | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/cmd/ghcs/common.go b/cmd/ghcs/common.go index 712c04daa..b61d8bb81 100644 --- a/cmd/ghcs/common.go +++ b/cmd/ghcs/common.go @@ -6,9 +6,11 @@ import ( "context" "errors" "fmt" + "os" "sort" "github.com/AlecAivazis/survey/v2" + "github.com/AlecAivazis/survey/v2/terminal" "github.com/github/ghcs/api" "golang.org/x/term" ) @@ -97,5 +99,16 @@ func ask(qs []*survey.Question, response interface{}) error { if !hasTTY { return fmt.Errorf("no terminal") } - return survey.Ask(qs, response, survey.WithShowCursor(true)) + err := survey.Ask(qs, response, survey.WithShowCursor(true)) + // The survey package temporarily clears the terminal's ISIG mode bit + // (see tcsetattr(3)) so the QUIT button (Ctrl-C) is reported as + // ASCII \x03 (ETX) instead of delivering SIGINT to the application. + // So we have to serve ourselves the SIGINT. + // + // https://github.com/AlecAivazis/survey/#why-isnt-sending-a-sigint-aka-ctrl-c-signal-working + if err == terminal.InterruptErr { + self, _ := os.FindProcess(os.Getpid()) + _ = self.Signal(os.Interrupt) // assumes POSIX + } + return err } From 2cbe1207742a7e5add6d0da54edda4257a98083c Mon Sep 17 00:00:00 2001 From: Alan Donovan Date: Thu, 9 Sep 2021 16:37:26 -0400 Subject: [PATCH 183/290] return err, don"t fatal --- cmd/ghcs/main.go | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/cmd/ghcs/main.go b/cmd/ghcs/main.go index 44db33113..3a692d443 100644 --- a/cmd/ghcs/main.go +++ b/cmd/ghcs/main.go @@ -42,8 +42,7 @@ token to access the GitHub API with.`, if os.Getenv("GITHUB_TOKEN") == "" { return tokenError } - initLightstep(lightstep) - return nil + return initLightstep(lightstep) }, } @@ -64,9 +63,9 @@ func explainError(w io.Writer, err error) { // initLightstep parses the --lightstep=service:token@host:port flag and // enables tracing if non-empty. -func initLightstep(config string) { +func initLightstep(config string) error { if config == "" { - return + return nil } cut := func(s, sep string) (pre, post string) { @@ -82,7 +81,7 @@ func initLightstep(config string) { host, port := cut(hostPort, ":") portI, err := strconv.Atoi(port) if err != nil { - log.Fatalf("invalid Lightstep configuration: %s", config) + return fmt.Errorf("invalid Lightstep configuration: %s", config) } opentracing.SetGlobalTracer(lightstep.NewTracer(lightstep.Options{ @@ -106,4 +105,6 @@ func initLightstep(config string) { log.Printf("[trace] %s", ev) } }) + + return nil } From 272ea57b541c33846add2b9b493c0ca011985c93 Mon Sep 17 00:00:00 2001 From: Jose Garcia Date: Thu, 9 Sep 2021 21:00:09 +0000 Subject: [PATCH 184/290] revert comment update --- port_forwarder.go | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/port_forwarder.go b/port_forwarder.go index be191c211..8011d19fc 100644 --- a/port_forwarder.go +++ b/port_forwarder.go @@ -33,7 +33,9 @@ func NewPortForwarder(session *Session, name string, remotePort int) *PortForwar // connecting to the socket prematurely.) // // ForwardToListener accepts and handles connections on the local port -// until the context is cancelled. The caller is responsible for closing the listening port. +// until it encounters the first error, which may include context +// cancellation. Its error result is always non-nil. The caller is +// responsible for closing the listening port. func (fwd *PortForwarder) ForwardToListener(ctx context.Context, listen net.Listener) (err error) { id, err := fwd.shareRemotePort(ctx) if err != nil { From 4f6cab195a89535ae8a96f9ad5bef8afba145b23 Mon Sep 17 00:00:00 2001 From: Alan Donovan Date: Fri, 10 Sep 2021 10:08:54 -0400 Subject: [PATCH 185/290] wait for sigint delivery --- cmd/ghcs/common.go | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/cmd/ghcs/common.go b/cmd/ghcs/common.go index b61d8bb81..e8927464a 100644 --- a/cmd/ghcs/common.go +++ b/cmd/ghcs/common.go @@ -109,6 +109,10 @@ func ask(qs []*survey.Question, response interface{}) error { if err == terminal.InterruptErr { self, _ := os.FindProcess(os.Getpid()) _ = self.Signal(os.Interrupt) // assumes POSIX + + // Suspend the goroutine, to avoid a race between + // return from main and async delivery of INT signal. + select {} } return err } From 810c127608a200318cfbd4d8247d0a6948f9443c Mon Sep 17 00:00:00 2001 From: Issy Long Date: Fri, 10 Sep 2021 15:24:46 +0100 Subject: [PATCH 186/290] goreleaser: Fix version string replacement - The `Version` variable casing changed in https://github.com/github/ghcs/commit/6a4950cf7ae02afc36cee06c26c232bc3fb71347#diff-d897a31624bae4fe935e8dc2243f41626c68639be6643535297c06935277ffb4, so we need to update our version setting code. - Otherwise, for `ghcs 0.11.0`, `ghcs --version` would print "DEV". --- cmd/ghcs/main.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cmd/ghcs/main.go b/cmd/ghcs/main.go index 3a692d443..651d98c1d 100644 --- a/cmd/ghcs/main.go +++ b/cmd/ghcs/main.go @@ -21,7 +21,7 @@ func main() { } } -var version = "DEV" +var version = "DEV" // Replaced in the release build process (by GoReleaser or Homebrew) by the git tag version number. var rootCmd = newRootCmd() From efb2569d2bc22c09835b9fd761a052fa692cf35d Mon Sep 17 00:00:00 2001 From: Alan Donovan Date: Fri, 10 Sep 2021 12:29:25 -0400 Subject: [PATCH 187/290] move vendored go-ghcs-crypto to internal module --- internal/crypto/AUTHORS | 3 + internal/crypto/CONTRIBUTORS | 3 + internal/crypto/LICENSE | 27 + internal/crypto/PATENTS | 22 + internal/crypto/blowfish/block.go | 159 ++ internal/crypto/blowfish/cipher.go | 99 + internal/crypto/blowfish/const.go | 199 ++ internal/crypto/chacha20/chacha_arm64.go | 16 + internal/crypto/chacha20/chacha_arm64.s | 307 +++ internal/crypto/chacha20/chacha_generic.go | 398 ++++ internal/crypto/chacha20/chacha_noasm.go | 13 + internal/crypto/chacha20/chacha_ppc64le.go | 16 + internal/crypto/chacha20/chacha_ppc64le.s | 449 +++++ internal/crypto/chacha20/chacha_s390x.go | 26 + internal/crypto/chacha20/chacha_s390x.s | 224 ++ internal/crypto/chacha20/xor.go | 42 + internal/crypto/curve25519/curve25519.go | 95 + .../crypto/curve25519/curve25519_amd64.go | 240 +++ internal/crypto/curve25519/curve25519_amd64.s | 1793 +++++++++++++++++ .../crypto/curve25519/curve25519_generic.go | 828 ++++++++ .../crypto/curve25519/curve25519_noasm.go | 11 + internal/crypto/ed25519/ed25519.go | 222 ++ internal/crypto/ed25519/ed25519_go113.go | 73 + .../ed25519/internal/edwards25519/const.go | 1422 +++++++++++++ .../internal/edwards25519/edwards25519.go | 1793 +++++++++++++++++ internal/crypto/go.mod | 9 + internal/crypto/internal/subtle/aliasing.go | 32 + .../internal/subtle/aliasing_appengine.go | 35 + internal/crypto/poly1305/bits_compat.go | 39 + internal/crypto/poly1305/bits_go1.13.go | 21 + internal/crypto/poly1305/mac_noasm.go | 9 + internal/crypto/poly1305/poly1305.go | 99 + internal/crypto/poly1305/sum_amd64.go | 47 + internal/crypto/poly1305/sum_amd64.s | 108 + internal/crypto/poly1305/sum_generic.go | 310 +++ internal/crypto/poly1305/sum_ppc64le.go | 47 + internal/crypto/poly1305/sum_ppc64le.s | 181 ++ internal/crypto/poly1305/sum_s390x.go | 75 + internal/crypto/poly1305/sum_s390x.s | 503 +++++ internal/crypto/ssh/buffer.go | 97 + internal/crypto/ssh/certs.go | 556 +++++ internal/crypto/ssh/channel.go | 633 ++++++ internal/crypto/ssh/cipher.go | 781 +++++++ internal/crypto/ssh/client.go | 287 +++ internal/crypto/ssh/client_auth.go | 641 ++++++ internal/crypto/ssh/common.go | 408 ++++ internal/crypto/ssh/connection.go | 143 ++ internal/crypto/ssh/doc.go | 21 + internal/crypto/ssh/handshake.go | 646 ++++++ .../ssh/internal/bcrypt_pbkdf/bcrypt_pbkdf.go | 93 + internal/crypto/ssh/kex.go | 789 ++++++++ internal/crypto/ssh/keys.go | 1493 ++++++++++++++ internal/crypto/ssh/mac.go | 61 + internal/crypto/ssh/messages.go | 866 ++++++++ internal/crypto/ssh/mux.go | 351 ++++ internal/crypto/ssh/server.go | 743 +++++++ internal/crypto/ssh/session.go | 647 ++++++ internal/crypto/ssh/ssh_gss.go | 139 ++ internal/crypto/ssh/streamlocal.go | 116 ++ internal/crypto/ssh/tcpip.go | 474 +++++ internal/crypto/ssh/terminal/terminal.go | 987 +++++++++ internal/crypto/ssh/terminal/util.go | 114 ++ internal/crypto/ssh/terminal/util_aix.go | 12 + internal/crypto/ssh/terminal/util_bsd.go | 12 + internal/crypto/ssh/terminal/util_linux.go | 10 + internal/crypto/ssh/terminal/util_plan9.go | 58 + internal/crypto/ssh/terminal/util_solaris.go | 124 ++ internal/crypto/ssh/terminal/util_windows.go | 105 + internal/crypto/ssh/transport.go | 353 ++++ 69 files changed, 21755 insertions(+) create mode 100644 internal/crypto/AUTHORS create mode 100644 internal/crypto/CONTRIBUTORS create mode 100644 internal/crypto/LICENSE create mode 100644 internal/crypto/PATENTS create mode 100644 internal/crypto/blowfish/block.go create mode 100644 internal/crypto/blowfish/cipher.go create mode 100644 internal/crypto/blowfish/const.go create mode 100644 internal/crypto/chacha20/chacha_arm64.go create mode 100644 internal/crypto/chacha20/chacha_arm64.s create mode 100644 internal/crypto/chacha20/chacha_generic.go create mode 100644 internal/crypto/chacha20/chacha_noasm.go create mode 100644 internal/crypto/chacha20/chacha_ppc64le.go create mode 100644 internal/crypto/chacha20/chacha_ppc64le.s create mode 100644 internal/crypto/chacha20/chacha_s390x.go create mode 100644 internal/crypto/chacha20/chacha_s390x.s create mode 100644 internal/crypto/chacha20/xor.go create mode 100644 internal/crypto/curve25519/curve25519.go create mode 100644 internal/crypto/curve25519/curve25519_amd64.go create mode 100644 internal/crypto/curve25519/curve25519_amd64.s create mode 100644 internal/crypto/curve25519/curve25519_generic.go create mode 100644 internal/crypto/curve25519/curve25519_noasm.go create mode 100644 internal/crypto/ed25519/ed25519.go create mode 100644 internal/crypto/ed25519/ed25519_go113.go create mode 100644 internal/crypto/ed25519/internal/edwards25519/const.go create mode 100644 internal/crypto/ed25519/internal/edwards25519/edwards25519.go create mode 100644 internal/crypto/go.mod create mode 100644 internal/crypto/internal/subtle/aliasing.go create mode 100644 internal/crypto/internal/subtle/aliasing_appengine.go create mode 100644 internal/crypto/poly1305/bits_compat.go create mode 100644 internal/crypto/poly1305/bits_go1.13.go create mode 100644 internal/crypto/poly1305/mac_noasm.go create mode 100644 internal/crypto/poly1305/poly1305.go create mode 100644 internal/crypto/poly1305/sum_amd64.go create mode 100644 internal/crypto/poly1305/sum_amd64.s create mode 100644 internal/crypto/poly1305/sum_generic.go create mode 100644 internal/crypto/poly1305/sum_ppc64le.go create mode 100644 internal/crypto/poly1305/sum_ppc64le.s create mode 100644 internal/crypto/poly1305/sum_s390x.go create mode 100644 internal/crypto/poly1305/sum_s390x.s create mode 100644 internal/crypto/ssh/buffer.go create mode 100644 internal/crypto/ssh/certs.go create mode 100644 internal/crypto/ssh/channel.go create mode 100644 internal/crypto/ssh/cipher.go create mode 100644 internal/crypto/ssh/client.go create mode 100644 internal/crypto/ssh/client_auth.go create mode 100644 internal/crypto/ssh/common.go create mode 100644 internal/crypto/ssh/connection.go create mode 100644 internal/crypto/ssh/doc.go create mode 100644 internal/crypto/ssh/handshake.go create mode 100644 internal/crypto/ssh/internal/bcrypt_pbkdf/bcrypt_pbkdf.go create mode 100644 internal/crypto/ssh/kex.go create mode 100644 internal/crypto/ssh/keys.go create mode 100644 internal/crypto/ssh/mac.go create mode 100644 internal/crypto/ssh/messages.go create mode 100644 internal/crypto/ssh/mux.go create mode 100644 internal/crypto/ssh/server.go create mode 100644 internal/crypto/ssh/session.go create mode 100644 internal/crypto/ssh/ssh_gss.go create mode 100644 internal/crypto/ssh/streamlocal.go create mode 100644 internal/crypto/ssh/tcpip.go create mode 100644 internal/crypto/ssh/terminal/terminal.go create mode 100644 internal/crypto/ssh/terminal/util.go create mode 100644 internal/crypto/ssh/terminal/util_aix.go create mode 100644 internal/crypto/ssh/terminal/util_bsd.go create mode 100644 internal/crypto/ssh/terminal/util_linux.go create mode 100644 internal/crypto/ssh/terminal/util_plan9.go create mode 100644 internal/crypto/ssh/terminal/util_solaris.go create mode 100644 internal/crypto/ssh/terminal/util_windows.go create mode 100644 internal/crypto/ssh/transport.go diff --git a/internal/crypto/AUTHORS b/internal/crypto/AUTHORS new file mode 100644 index 000000000..2b00ddba0 --- /dev/null +++ b/internal/crypto/AUTHORS @@ -0,0 +1,3 @@ +# This source code refers to The Go Authors for copyright purposes. +# The master list of authors is in the main Go distribution, +# visible at https://tip.golang.org/AUTHORS. diff --git a/internal/crypto/CONTRIBUTORS b/internal/crypto/CONTRIBUTORS new file mode 100644 index 000000000..1fbd3e976 --- /dev/null +++ b/internal/crypto/CONTRIBUTORS @@ -0,0 +1,3 @@ +# This source code was written by the Go contributors. +# The master list of contributors is in the main Go distribution, +# visible at https://tip.golang.org/CONTRIBUTORS. diff --git a/internal/crypto/LICENSE b/internal/crypto/LICENSE new file mode 100644 index 000000000..6a66aea5e --- /dev/null +++ b/internal/crypto/LICENSE @@ -0,0 +1,27 @@ +Copyright (c) 2009 The Go Authors. All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: + + * Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above +copyright notice, this list of conditions and the following disclaimer +in the documentation and/or other materials provided with the +distribution. + * Neither the name of Google Inc. nor the names of its +contributors may be used to endorse or promote products derived from +this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/internal/crypto/PATENTS b/internal/crypto/PATENTS new file mode 100644 index 000000000..733099041 --- /dev/null +++ b/internal/crypto/PATENTS @@ -0,0 +1,22 @@ +Additional IP Rights Grant (Patents) + +"This implementation" means the copyrightable works distributed by +Google as part of the Go project. + +Google hereby grants to You a perpetual, worldwide, non-exclusive, +no-charge, royalty-free, irrevocable (except as stated in this section) +patent license to make, have made, use, offer to sell, sell, import, +transfer and otherwise run, modify and propagate the contents of this +implementation of Go, where such license applies only to those patent +claims, both currently owned or controlled by Google and acquired in +the future, licensable by Google that are necessarily infringed by this +implementation of Go. This grant does not include claims that would be +infringed only as a consequence of further modification of this +implementation. If you or your agent or exclusive licensee institute or +order or agree to the institution of patent litigation against any +entity (including a cross-claim or counterclaim in a lawsuit) alleging +that this implementation of Go or any code incorporated within this +implementation of Go constitutes direct or contributory patent +infringement, or inducement of patent infringement, then any patent +rights granted to you under this License for this implementation of Go +shall terminate as of the date such litigation is filed. diff --git a/internal/crypto/blowfish/block.go b/internal/crypto/blowfish/block.go new file mode 100644 index 000000000..9d80f1952 --- /dev/null +++ b/internal/crypto/blowfish/block.go @@ -0,0 +1,159 @@ +// Copyright 2010 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package blowfish + +// getNextWord returns the next big-endian uint32 value from the byte slice +// at the given position in a circular manner, updating the position. +func getNextWord(b []byte, pos *int) uint32 { + var w uint32 + j := *pos + for i := 0; i < 4; i++ { + w = w<<8 | uint32(b[j]) + j++ + if j >= len(b) { + j = 0 + } + } + *pos = j + return w +} + +// ExpandKey performs a key expansion on the given *Cipher. Specifically, it +// performs the Blowfish algorithm's key schedule which sets up the *Cipher's +// pi and substitution tables for calls to Encrypt. This is used, primarily, +// by the bcrypt package to reuse the Blowfish key schedule during its +// set up. It's unlikely that you need to use this directly. +func ExpandKey(key []byte, c *Cipher) { + j := 0 + for i := 0; i < 18; i++ { + // Using inlined getNextWord for performance. + var d uint32 + for k := 0; k < 4; k++ { + d = d<<8 | uint32(key[j]) + j++ + if j >= len(key) { + j = 0 + } + } + c.p[i] ^= d + } + + var l, r uint32 + for i := 0; i < 18; i += 2 { + l, r = encryptBlock(l, r, c) + c.p[i], c.p[i+1] = l, r + } + + for i := 0; i < 256; i += 2 { + l, r = encryptBlock(l, r, c) + c.s0[i], c.s0[i+1] = l, r + } + for i := 0; i < 256; i += 2 { + l, r = encryptBlock(l, r, c) + c.s1[i], c.s1[i+1] = l, r + } + for i := 0; i < 256; i += 2 { + l, r = encryptBlock(l, r, c) + c.s2[i], c.s2[i+1] = l, r + } + for i := 0; i < 256; i += 2 { + l, r = encryptBlock(l, r, c) + c.s3[i], c.s3[i+1] = l, r + } +} + +// This is similar to ExpandKey, but folds the salt during the key +// schedule. While ExpandKey is essentially expandKeyWithSalt with an all-zero +// salt passed in, reusing ExpandKey turns out to be a place of inefficiency +// and specializing it here is useful. +func expandKeyWithSalt(key []byte, salt []byte, c *Cipher) { + j := 0 + for i := 0; i < 18; i++ { + c.p[i] ^= getNextWord(key, &j) + } + + j = 0 + var l, r uint32 + for i := 0; i < 18; i += 2 { + l ^= getNextWord(salt, &j) + r ^= getNextWord(salt, &j) + l, r = encryptBlock(l, r, c) + c.p[i], c.p[i+1] = l, r + } + + for i := 0; i < 256; i += 2 { + l ^= getNextWord(salt, &j) + r ^= getNextWord(salt, &j) + l, r = encryptBlock(l, r, c) + c.s0[i], c.s0[i+1] = l, r + } + + for i := 0; i < 256; i += 2 { + l ^= getNextWord(salt, &j) + r ^= getNextWord(salt, &j) + l, r = encryptBlock(l, r, c) + c.s1[i], c.s1[i+1] = l, r + } + + for i := 0; i < 256; i += 2 { + l ^= getNextWord(salt, &j) + r ^= getNextWord(salt, &j) + l, r = encryptBlock(l, r, c) + c.s2[i], c.s2[i+1] = l, r + } + + for i := 0; i < 256; i += 2 { + l ^= getNextWord(salt, &j) + r ^= getNextWord(salt, &j) + l, r = encryptBlock(l, r, c) + c.s3[i], c.s3[i+1] = l, r + } +} + +func encryptBlock(l, r uint32, c *Cipher) (uint32, uint32) { + xl, xr := l, r + xl ^= c.p[0] + xr ^= ((c.s0[byte(xl>>24)] + c.s1[byte(xl>>16)]) ^ c.s2[byte(xl>>8)]) + c.s3[byte(xl)] ^ c.p[1] + xl ^= ((c.s0[byte(xr>>24)] + c.s1[byte(xr>>16)]) ^ c.s2[byte(xr>>8)]) + c.s3[byte(xr)] ^ c.p[2] + xr ^= ((c.s0[byte(xl>>24)] + c.s1[byte(xl>>16)]) ^ c.s2[byte(xl>>8)]) + c.s3[byte(xl)] ^ c.p[3] + xl ^= ((c.s0[byte(xr>>24)] + c.s1[byte(xr>>16)]) ^ c.s2[byte(xr>>8)]) + c.s3[byte(xr)] ^ c.p[4] + xr ^= ((c.s0[byte(xl>>24)] + c.s1[byte(xl>>16)]) ^ c.s2[byte(xl>>8)]) + c.s3[byte(xl)] ^ c.p[5] + xl ^= ((c.s0[byte(xr>>24)] + c.s1[byte(xr>>16)]) ^ c.s2[byte(xr>>8)]) + c.s3[byte(xr)] ^ c.p[6] + xr ^= ((c.s0[byte(xl>>24)] + c.s1[byte(xl>>16)]) ^ c.s2[byte(xl>>8)]) + c.s3[byte(xl)] ^ c.p[7] + xl ^= ((c.s0[byte(xr>>24)] + c.s1[byte(xr>>16)]) ^ c.s2[byte(xr>>8)]) + c.s3[byte(xr)] ^ c.p[8] + xr ^= ((c.s0[byte(xl>>24)] + c.s1[byte(xl>>16)]) ^ c.s2[byte(xl>>8)]) + c.s3[byte(xl)] ^ c.p[9] + xl ^= ((c.s0[byte(xr>>24)] + c.s1[byte(xr>>16)]) ^ c.s2[byte(xr>>8)]) + c.s3[byte(xr)] ^ c.p[10] + xr ^= ((c.s0[byte(xl>>24)] + c.s1[byte(xl>>16)]) ^ c.s2[byte(xl>>8)]) + c.s3[byte(xl)] ^ c.p[11] + xl ^= ((c.s0[byte(xr>>24)] + c.s1[byte(xr>>16)]) ^ c.s2[byte(xr>>8)]) + c.s3[byte(xr)] ^ c.p[12] + xr ^= ((c.s0[byte(xl>>24)] + c.s1[byte(xl>>16)]) ^ c.s2[byte(xl>>8)]) + c.s3[byte(xl)] ^ c.p[13] + xl ^= ((c.s0[byte(xr>>24)] + c.s1[byte(xr>>16)]) ^ c.s2[byte(xr>>8)]) + c.s3[byte(xr)] ^ c.p[14] + xr ^= ((c.s0[byte(xl>>24)] + c.s1[byte(xl>>16)]) ^ c.s2[byte(xl>>8)]) + c.s3[byte(xl)] ^ c.p[15] + xl ^= ((c.s0[byte(xr>>24)] + c.s1[byte(xr>>16)]) ^ c.s2[byte(xr>>8)]) + c.s3[byte(xr)] ^ c.p[16] + xr ^= c.p[17] + return xr, xl +} + +func decryptBlock(l, r uint32, c *Cipher) (uint32, uint32) { + xl, xr := l, r + xl ^= c.p[17] + xr ^= ((c.s0[byte(xl>>24)] + c.s1[byte(xl>>16)]) ^ c.s2[byte(xl>>8)]) + c.s3[byte(xl)] ^ c.p[16] + xl ^= ((c.s0[byte(xr>>24)] + c.s1[byte(xr>>16)]) ^ c.s2[byte(xr>>8)]) + c.s3[byte(xr)] ^ c.p[15] + xr ^= ((c.s0[byte(xl>>24)] + c.s1[byte(xl>>16)]) ^ c.s2[byte(xl>>8)]) + c.s3[byte(xl)] ^ c.p[14] + xl ^= ((c.s0[byte(xr>>24)] + c.s1[byte(xr>>16)]) ^ c.s2[byte(xr>>8)]) + c.s3[byte(xr)] ^ c.p[13] + xr ^= ((c.s0[byte(xl>>24)] + c.s1[byte(xl>>16)]) ^ c.s2[byte(xl>>8)]) + c.s3[byte(xl)] ^ c.p[12] + xl ^= ((c.s0[byte(xr>>24)] + c.s1[byte(xr>>16)]) ^ c.s2[byte(xr>>8)]) + c.s3[byte(xr)] ^ c.p[11] + xr ^= ((c.s0[byte(xl>>24)] + c.s1[byte(xl>>16)]) ^ c.s2[byte(xl>>8)]) + c.s3[byte(xl)] ^ c.p[10] + xl ^= ((c.s0[byte(xr>>24)] + c.s1[byte(xr>>16)]) ^ c.s2[byte(xr>>8)]) + c.s3[byte(xr)] ^ c.p[9] + xr ^= ((c.s0[byte(xl>>24)] + c.s1[byte(xl>>16)]) ^ c.s2[byte(xl>>8)]) + c.s3[byte(xl)] ^ c.p[8] + xl ^= ((c.s0[byte(xr>>24)] + c.s1[byte(xr>>16)]) ^ c.s2[byte(xr>>8)]) + c.s3[byte(xr)] ^ c.p[7] + xr ^= ((c.s0[byte(xl>>24)] + c.s1[byte(xl>>16)]) ^ c.s2[byte(xl>>8)]) + c.s3[byte(xl)] ^ c.p[6] + xl ^= ((c.s0[byte(xr>>24)] + c.s1[byte(xr>>16)]) ^ c.s2[byte(xr>>8)]) + c.s3[byte(xr)] ^ c.p[5] + xr ^= ((c.s0[byte(xl>>24)] + c.s1[byte(xl>>16)]) ^ c.s2[byte(xl>>8)]) + c.s3[byte(xl)] ^ c.p[4] + xl ^= ((c.s0[byte(xr>>24)] + c.s1[byte(xr>>16)]) ^ c.s2[byte(xr>>8)]) + c.s3[byte(xr)] ^ c.p[3] + xr ^= ((c.s0[byte(xl>>24)] + c.s1[byte(xl>>16)]) ^ c.s2[byte(xl>>8)]) + c.s3[byte(xl)] ^ c.p[2] + xl ^= ((c.s0[byte(xr>>24)] + c.s1[byte(xr>>16)]) ^ c.s2[byte(xr>>8)]) + c.s3[byte(xr)] ^ c.p[1] + xr ^= c.p[0] + return xr, xl +} diff --git a/internal/crypto/blowfish/cipher.go b/internal/crypto/blowfish/cipher.go new file mode 100644 index 000000000..213bf204a --- /dev/null +++ b/internal/crypto/blowfish/cipher.go @@ -0,0 +1,99 @@ +// Copyright 2010 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package blowfish implements Bruce Schneier's Blowfish encryption algorithm. +// +// Blowfish is a legacy cipher and its short block size makes it vulnerable to +// birthday bound attacks (see https://sweet32.info). It should only be used +// where compatibility with legacy systems, not security, is the goal. +// +// Deprecated: any new system should use AES (from crypto/aes, if necessary in +// an AEAD mode like crypto/cipher.NewGCM) or XChaCha20-Poly1305 (from +// golang.org/x/crypto/chacha20poly1305). +package blowfish // import "golang.org/x/crypto/blowfish" + +// The code is a port of Bruce Schneier's C implementation. +// See https://www.schneier.com/blowfish.html. + +import "strconv" + +// The Blowfish block size in bytes. +const BlockSize = 8 + +// A Cipher is an instance of Blowfish encryption using a particular key. +type Cipher struct { + p [18]uint32 + s0, s1, s2, s3 [256]uint32 +} + +type KeySizeError int + +func (k KeySizeError) Error() string { + return "crypto/blowfish: invalid key size " + strconv.Itoa(int(k)) +} + +// NewCipher creates and returns a Cipher. +// The key argument should be the Blowfish key, from 1 to 56 bytes. +func NewCipher(key []byte) (*Cipher, error) { + var result Cipher + if k := len(key); k < 1 || k > 56 { + return nil, KeySizeError(k) + } + initCipher(&result) + ExpandKey(key, &result) + return &result, nil +} + +// NewSaltedCipher creates a returns a Cipher that folds a salt into its key +// schedule. For most purposes, NewCipher, instead of NewSaltedCipher, is +// sufficient and desirable. For bcrypt compatibility, the key can be over 56 +// bytes. +func NewSaltedCipher(key, salt []byte) (*Cipher, error) { + if len(salt) == 0 { + return NewCipher(key) + } + var result Cipher + if k := len(key); k < 1 { + return nil, KeySizeError(k) + } + initCipher(&result) + expandKeyWithSalt(key, salt, &result) + return &result, nil +} + +// BlockSize returns the Blowfish block size, 8 bytes. +// It is necessary to satisfy the Block interface in the +// package "crypto/cipher". +func (c *Cipher) BlockSize() int { return BlockSize } + +// Encrypt encrypts the 8-byte buffer src using the key k +// and stores the result in dst. +// Note that for amounts of data larger than a block, +// it is not safe to just call Encrypt on successive blocks; +// instead, use an encryption mode like CBC (see crypto/cipher/cbc.go). +func (c *Cipher) Encrypt(dst, src []byte) { + l := uint32(src[0])<<24 | uint32(src[1])<<16 | uint32(src[2])<<8 | uint32(src[3]) + r := uint32(src[4])<<24 | uint32(src[5])<<16 | uint32(src[6])<<8 | uint32(src[7]) + l, r = encryptBlock(l, r, c) + dst[0], dst[1], dst[2], dst[3] = byte(l>>24), byte(l>>16), byte(l>>8), byte(l) + dst[4], dst[5], dst[6], dst[7] = byte(r>>24), byte(r>>16), byte(r>>8), byte(r) +} + +// Decrypt decrypts the 8-byte buffer src using the key k +// and stores the result in dst. +func (c *Cipher) Decrypt(dst, src []byte) { + l := uint32(src[0])<<24 | uint32(src[1])<<16 | uint32(src[2])<<8 | uint32(src[3]) + r := uint32(src[4])<<24 | uint32(src[5])<<16 | uint32(src[6])<<8 | uint32(src[7]) + l, r = decryptBlock(l, r, c) + dst[0], dst[1], dst[2], dst[3] = byte(l>>24), byte(l>>16), byte(l>>8), byte(l) + dst[4], dst[5], dst[6], dst[7] = byte(r>>24), byte(r>>16), byte(r>>8), byte(r) +} + +func initCipher(c *Cipher) { + copy(c.p[0:], p[0:]) + copy(c.s0[0:], s0[0:]) + copy(c.s1[0:], s1[0:]) + copy(c.s2[0:], s2[0:]) + copy(c.s3[0:], s3[0:]) +} diff --git a/internal/crypto/blowfish/const.go b/internal/crypto/blowfish/const.go new file mode 100644 index 000000000..d04077595 --- /dev/null +++ b/internal/crypto/blowfish/const.go @@ -0,0 +1,199 @@ +// Copyright 2010 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// The startup permutation array and substitution boxes. +// They are the hexadecimal digits of PI; see: +// https://www.schneier.com/code/constants.txt. + +package blowfish + +var s0 = [256]uint32{ + 0xd1310ba6, 0x98dfb5ac, 0x2ffd72db, 0xd01adfb7, 0xb8e1afed, 0x6a267e96, + 0xba7c9045, 0xf12c7f99, 0x24a19947, 0xb3916cf7, 0x0801f2e2, 0x858efc16, + 0x636920d8, 0x71574e69, 0xa458fea3, 0xf4933d7e, 0x0d95748f, 0x728eb658, + 0x718bcd58, 0x82154aee, 0x7b54a41d, 0xc25a59b5, 0x9c30d539, 0x2af26013, + 0xc5d1b023, 0x286085f0, 0xca417918, 0xb8db38ef, 0x8e79dcb0, 0x603a180e, + 0x6c9e0e8b, 0xb01e8a3e, 0xd71577c1, 0xbd314b27, 0x78af2fda, 0x55605c60, + 0xe65525f3, 0xaa55ab94, 0x57489862, 0x63e81440, 0x55ca396a, 0x2aab10b6, + 0xb4cc5c34, 0x1141e8ce, 0xa15486af, 0x7c72e993, 0xb3ee1411, 0x636fbc2a, + 0x2ba9c55d, 0x741831f6, 0xce5c3e16, 0x9b87931e, 0xafd6ba33, 0x6c24cf5c, + 0x7a325381, 0x28958677, 0x3b8f4898, 0x6b4bb9af, 0xc4bfe81b, 0x66282193, + 0x61d809cc, 0xfb21a991, 0x487cac60, 0x5dec8032, 0xef845d5d, 0xe98575b1, + 0xdc262302, 0xeb651b88, 0x23893e81, 0xd396acc5, 0x0f6d6ff3, 0x83f44239, + 0x2e0b4482, 0xa4842004, 0x69c8f04a, 0x9e1f9b5e, 0x21c66842, 0xf6e96c9a, + 0x670c9c61, 0xabd388f0, 0x6a51a0d2, 0xd8542f68, 0x960fa728, 0xab5133a3, + 0x6eef0b6c, 0x137a3be4, 0xba3bf050, 0x7efb2a98, 0xa1f1651d, 0x39af0176, + 0x66ca593e, 0x82430e88, 0x8cee8619, 0x456f9fb4, 0x7d84a5c3, 0x3b8b5ebe, + 0xe06f75d8, 0x85c12073, 0x401a449f, 0x56c16aa6, 0x4ed3aa62, 0x363f7706, + 0x1bfedf72, 0x429b023d, 0x37d0d724, 0xd00a1248, 0xdb0fead3, 0x49f1c09b, + 0x075372c9, 0x80991b7b, 0x25d479d8, 0xf6e8def7, 0xe3fe501a, 0xb6794c3b, + 0x976ce0bd, 0x04c006ba, 0xc1a94fb6, 0x409f60c4, 0x5e5c9ec2, 0x196a2463, + 0x68fb6faf, 0x3e6c53b5, 0x1339b2eb, 0x3b52ec6f, 0x6dfc511f, 0x9b30952c, + 0xcc814544, 0xaf5ebd09, 0xbee3d004, 0xde334afd, 0x660f2807, 0x192e4bb3, + 0xc0cba857, 0x45c8740f, 0xd20b5f39, 0xb9d3fbdb, 0x5579c0bd, 0x1a60320a, + 0xd6a100c6, 0x402c7279, 0x679f25fe, 0xfb1fa3cc, 0x8ea5e9f8, 0xdb3222f8, + 0x3c7516df, 0xfd616b15, 0x2f501ec8, 0xad0552ab, 0x323db5fa, 0xfd238760, + 0x53317b48, 0x3e00df82, 0x9e5c57bb, 0xca6f8ca0, 0x1a87562e, 0xdf1769db, + 0xd542a8f6, 0x287effc3, 0xac6732c6, 0x8c4f5573, 0x695b27b0, 0xbbca58c8, + 0xe1ffa35d, 0xb8f011a0, 0x10fa3d98, 0xfd2183b8, 0x4afcb56c, 0x2dd1d35b, + 0x9a53e479, 0xb6f84565, 0xd28e49bc, 0x4bfb9790, 0xe1ddf2da, 0xa4cb7e33, + 0x62fb1341, 0xcee4c6e8, 0xef20cada, 0x36774c01, 0xd07e9efe, 0x2bf11fb4, + 0x95dbda4d, 0xae909198, 0xeaad8e71, 0x6b93d5a0, 0xd08ed1d0, 0xafc725e0, + 0x8e3c5b2f, 0x8e7594b7, 0x8ff6e2fb, 0xf2122b64, 0x8888b812, 0x900df01c, + 0x4fad5ea0, 0x688fc31c, 0xd1cff191, 0xb3a8c1ad, 0x2f2f2218, 0xbe0e1777, + 0xea752dfe, 0x8b021fa1, 0xe5a0cc0f, 0xb56f74e8, 0x18acf3d6, 0xce89e299, + 0xb4a84fe0, 0xfd13e0b7, 0x7cc43b81, 0xd2ada8d9, 0x165fa266, 0x80957705, + 0x93cc7314, 0x211a1477, 0xe6ad2065, 0x77b5fa86, 0xc75442f5, 0xfb9d35cf, + 0xebcdaf0c, 0x7b3e89a0, 0xd6411bd3, 0xae1e7e49, 0x00250e2d, 0x2071b35e, + 0x226800bb, 0x57b8e0af, 0x2464369b, 0xf009b91e, 0x5563911d, 0x59dfa6aa, + 0x78c14389, 0xd95a537f, 0x207d5ba2, 0x02e5b9c5, 0x83260376, 0x6295cfa9, + 0x11c81968, 0x4e734a41, 0xb3472dca, 0x7b14a94a, 0x1b510052, 0x9a532915, + 0xd60f573f, 0xbc9bc6e4, 0x2b60a476, 0x81e67400, 0x08ba6fb5, 0x571be91f, + 0xf296ec6b, 0x2a0dd915, 0xb6636521, 0xe7b9f9b6, 0xff34052e, 0xc5855664, + 0x53b02d5d, 0xa99f8fa1, 0x08ba4799, 0x6e85076a, +} + +var s1 = [256]uint32{ + 0x4b7a70e9, 0xb5b32944, 0xdb75092e, 0xc4192623, 0xad6ea6b0, 0x49a7df7d, + 0x9cee60b8, 0x8fedb266, 0xecaa8c71, 0x699a17ff, 0x5664526c, 0xc2b19ee1, + 0x193602a5, 0x75094c29, 0xa0591340, 0xe4183a3e, 0x3f54989a, 0x5b429d65, + 0x6b8fe4d6, 0x99f73fd6, 0xa1d29c07, 0xefe830f5, 0x4d2d38e6, 0xf0255dc1, + 0x4cdd2086, 0x8470eb26, 0x6382e9c6, 0x021ecc5e, 0x09686b3f, 0x3ebaefc9, + 0x3c971814, 0x6b6a70a1, 0x687f3584, 0x52a0e286, 0xb79c5305, 0xaa500737, + 0x3e07841c, 0x7fdeae5c, 0x8e7d44ec, 0x5716f2b8, 0xb03ada37, 0xf0500c0d, + 0xf01c1f04, 0x0200b3ff, 0xae0cf51a, 0x3cb574b2, 0x25837a58, 0xdc0921bd, + 0xd19113f9, 0x7ca92ff6, 0x94324773, 0x22f54701, 0x3ae5e581, 0x37c2dadc, + 0xc8b57634, 0x9af3dda7, 0xa9446146, 0x0fd0030e, 0xecc8c73e, 0xa4751e41, + 0xe238cd99, 0x3bea0e2f, 0x3280bba1, 0x183eb331, 0x4e548b38, 0x4f6db908, + 0x6f420d03, 0xf60a04bf, 0x2cb81290, 0x24977c79, 0x5679b072, 0xbcaf89af, + 0xde9a771f, 0xd9930810, 0xb38bae12, 0xdccf3f2e, 0x5512721f, 0x2e6b7124, + 0x501adde6, 0x9f84cd87, 0x7a584718, 0x7408da17, 0xbc9f9abc, 0xe94b7d8c, + 0xec7aec3a, 0xdb851dfa, 0x63094366, 0xc464c3d2, 0xef1c1847, 0x3215d908, + 0xdd433b37, 0x24c2ba16, 0x12a14d43, 0x2a65c451, 0x50940002, 0x133ae4dd, + 0x71dff89e, 0x10314e55, 0x81ac77d6, 0x5f11199b, 0x043556f1, 0xd7a3c76b, + 0x3c11183b, 0x5924a509, 0xf28fe6ed, 0x97f1fbfa, 0x9ebabf2c, 0x1e153c6e, + 0x86e34570, 0xeae96fb1, 0x860e5e0a, 0x5a3e2ab3, 0x771fe71c, 0x4e3d06fa, + 0x2965dcb9, 0x99e71d0f, 0x803e89d6, 0x5266c825, 0x2e4cc978, 0x9c10b36a, + 0xc6150eba, 0x94e2ea78, 0xa5fc3c53, 0x1e0a2df4, 0xf2f74ea7, 0x361d2b3d, + 0x1939260f, 0x19c27960, 0x5223a708, 0xf71312b6, 0xebadfe6e, 0xeac31f66, + 0xe3bc4595, 0xa67bc883, 0xb17f37d1, 0x018cff28, 0xc332ddef, 0xbe6c5aa5, + 0x65582185, 0x68ab9802, 0xeecea50f, 0xdb2f953b, 0x2aef7dad, 0x5b6e2f84, + 0x1521b628, 0x29076170, 0xecdd4775, 0x619f1510, 0x13cca830, 0xeb61bd96, + 0x0334fe1e, 0xaa0363cf, 0xb5735c90, 0x4c70a239, 0xd59e9e0b, 0xcbaade14, + 0xeecc86bc, 0x60622ca7, 0x9cab5cab, 0xb2f3846e, 0x648b1eaf, 0x19bdf0ca, + 0xa02369b9, 0x655abb50, 0x40685a32, 0x3c2ab4b3, 0x319ee9d5, 0xc021b8f7, + 0x9b540b19, 0x875fa099, 0x95f7997e, 0x623d7da8, 0xf837889a, 0x97e32d77, + 0x11ed935f, 0x16681281, 0x0e358829, 0xc7e61fd6, 0x96dedfa1, 0x7858ba99, + 0x57f584a5, 0x1b227263, 0x9b83c3ff, 0x1ac24696, 0xcdb30aeb, 0x532e3054, + 0x8fd948e4, 0x6dbc3128, 0x58ebf2ef, 0x34c6ffea, 0xfe28ed61, 0xee7c3c73, + 0x5d4a14d9, 0xe864b7e3, 0x42105d14, 0x203e13e0, 0x45eee2b6, 0xa3aaabea, + 0xdb6c4f15, 0xfacb4fd0, 0xc742f442, 0xef6abbb5, 0x654f3b1d, 0x41cd2105, + 0xd81e799e, 0x86854dc7, 0xe44b476a, 0x3d816250, 0xcf62a1f2, 0x5b8d2646, + 0xfc8883a0, 0xc1c7b6a3, 0x7f1524c3, 0x69cb7492, 0x47848a0b, 0x5692b285, + 0x095bbf00, 0xad19489d, 0x1462b174, 0x23820e00, 0x58428d2a, 0x0c55f5ea, + 0x1dadf43e, 0x233f7061, 0x3372f092, 0x8d937e41, 0xd65fecf1, 0x6c223bdb, + 0x7cde3759, 0xcbee7460, 0x4085f2a7, 0xce77326e, 0xa6078084, 0x19f8509e, + 0xe8efd855, 0x61d99735, 0xa969a7aa, 0xc50c06c2, 0x5a04abfc, 0x800bcadc, + 0x9e447a2e, 0xc3453484, 0xfdd56705, 0x0e1e9ec9, 0xdb73dbd3, 0x105588cd, + 0x675fda79, 0xe3674340, 0xc5c43465, 0x713e38d8, 0x3d28f89e, 0xf16dff20, + 0x153e21e7, 0x8fb03d4a, 0xe6e39f2b, 0xdb83adf7, +} + +var s2 = [256]uint32{ + 0xe93d5a68, 0x948140f7, 0xf64c261c, 0x94692934, 0x411520f7, 0x7602d4f7, + 0xbcf46b2e, 0xd4a20068, 0xd4082471, 0x3320f46a, 0x43b7d4b7, 0x500061af, + 0x1e39f62e, 0x97244546, 0x14214f74, 0xbf8b8840, 0x4d95fc1d, 0x96b591af, + 0x70f4ddd3, 0x66a02f45, 0xbfbc09ec, 0x03bd9785, 0x7fac6dd0, 0x31cb8504, + 0x96eb27b3, 0x55fd3941, 0xda2547e6, 0xabca0a9a, 0x28507825, 0x530429f4, + 0x0a2c86da, 0xe9b66dfb, 0x68dc1462, 0xd7486900, 0x680ec0a4, 0x27a18dee, + 0x4f3ffea2, 0xe887ad8c, 0xb58ce006, 0x7af4d6b6, 0xaace1e7c, 0xd3375fec, + 0xce78a399, 0x406b2a42, 0x20fe9e35, 0xd9f385b9, 0xee39d7ab, 0x3b124e8b, + 0x1dc9faf7, 0x4b6d1856, 0x26a36631, 0xeae397b2, 0x3a6efa74, 0xdd5b4332, + 0x6841e7f7, 0xca7820fb, 0xfb0af54e, 0xd8feb397, 0x454056ac, 0xba489527, + 0x55533a3a, 0x20838d87, 0xfe6ba9b7, 0xd096954b, 0x55a867bc, 0xa1159a58, + 0xcca92963, 0x99e1db33, 0xa62a4a56, 0x3f3125f9, 0x5ef47e1c, 0x9029317c, + 0xfdf8e802, 0x04272f70, 0x80bb155c, 0x05282ce3, 0x95c11548, 0xe4c66d22, + 0x48c1133f, 0xc70f86dc, 0x07f9c9ee, 0x41041f0f, 0x404779a4, 0x5d886e17, + 0x325f51eb, 0xd59bc0d1, 0xf2bcc18f, 0x41113564, 0x257b7834, 0x602a9c60, + 0xdff8e8a3, 0x1f636c1b, 0x0e12b4c2, 0x02e1329e, 0xaf664fd1, 0xcad18115, + 0x6b2395e0, 0x333e92e1, 0x3b240b62, 0xeebeb922, 0x85b2a20e, 0xe6ba0d99, + 0xde720c8c, 0x2da2f728, 0xd0127845, 0x95b794fd, 0x647d0862, 0xe7ccf5f0, + 0x5449a36f, 0x877d48fa, 0xc39dfd27, 0xf33e8d1e, 0x0a476341, 0x992eff74, + 0x3a6f6eab, 0xf4f8fd37, 0xa812dc60, 0xa1ebddf8, 0x991be14c, 0xdb6e6b0d, + 0xc67b5510, 0x6d672c37, 0x2765d43b, 0xdcd0e804, 0xf1290dc7, 0xcc00ffa3, + 0xb5390f92, 0x690fed0b, 0x667b9ffb, 0xcedb7d9c, 0xa091cf0b, 0xd9155ea3, + 0xbb132f88, 0x515bad24, 0x7b9479bf, 0x763bd6eb, 0x37392eb3, 0xcc115979, + 0x8026e297, 0xf42e312d, 0x6842ada7, 0xc66a2b3b, 0x12754ccc, 0x782ef11c, + 0x6a124237, 0xb79251e7, 0x06a1bbe6, 0x4bfb6350, 0x1a6b1018, 0x11caedfa, + 0x3d25bdd8, 0xe2e1c3c9, 0x44421659, 0x0a121386, 0xd90cec6e, 0xd5abea2a, + 0x64af674e, 0xda86a85f, 0xbebfe988, 0x64e4c3fe, 0x9dbc8057, 0xf0f7c086, + 0x60787bf8, 0x6003604d, 0xd1fd8346, 0xf6381fb0, 0x7745ae04, 0xd736fccc, + 0x83426b33, 0xf01eab71, 0xb0804187, 0x3c005e5f, 0x77a057be, 0xbde8ae24, + 0x55464299, 0xbf582e61, 0x4e58f48f, 0xf2ddfda2, 0xf474ef38, 0x8789bdc2, + 0x5366f9c3, 0xc8b38e74, 0xb475f255, 0x46fcd9b9, 0x7aeb2661, 0x8b1ddf84, + 0x846a0e79, 0x915f95e2, 0x466e598e, 0x20b45770, 0x8cd55591, 0xc902de4c, + 0xb90bace1, 0xbb8205d0, 0x11a86248, 0x7574a99e, 0xb77f19b6, 0xe0a9dc09, + 0x662d09a1, 0xc4324633, 0xe85a1f02, 0x09f0be8c, 0x4a99a025, 0x1d6efe10, + 0x1ab93d1d, 0x0ba5a4df, 0xa186f20f, 0x2868f169, 0xdcb7da83, 0x573906fe, + 0xa1e2ce9b, 0x4fcd7f52, 0x50115e01, 0xa70683fa, 0xa002b5c4, 0x0de6d027, + 0x9af88c27, 0x773f8641, 0xc3604c06, 0x61a806b5, 0xf0177a28, 0xc0f586e0, + 0x006058aa, 0x30dc7d62, 0x11e69ed7, 0x2338ea63, 0x53c2dd94, 0xc2c21634, + 0xbbcbee56, 0x90bcb6de, 0xebfc7da1, 0xce591d76, 0x6f05e409, 0x4b7c0188, + 0x39720a3d, 0x7c927c24, 0x86e3725f, 0x724d9db9, 0x1ac15bb4, 0xd39eb8fc, + 0xed545578, 0x08fca5b5, 0xd83d7cd3, 0x4dad0fc4, 0x1e50ef5e, 0xb161e6f8, + 0xa28514d9, 0x6c51133c, 0x6fd5c7e7, 0x56e14ec4, 0x362abfce, 0xddc6c837, + 0xd79a3234, 0x92638212, 0x670efa8e, 0x406000e0, +} + +var s3 = [256]uint32{ + 0x3a39ce37, 0xd3faf5cf, 0xabc27737, 0x5ac52d1b, 0x5cb0679e, 0x4fa33742, + 0xd3822740, 0x99bc9bbe, 0xd5118e9d, 0xbf0f7315, 0xd62d1c7e, 0xc700c47b, + 0xb78c1b6b, 0x21a19045, 0xb26eb1be, 0x6a366eb4, 0x5748ab2f, 0xbc946e79, + 0xc6a376d2, 0x6549c2c8, 0x530ff8ee, 0x468dde7d, 0xd5730a1d, 0x4cd04dc6, + 0x2939bbdb, 0xa9ba4650, 0xac9526e8, 0xbe5ee304, 0xa1fad5f0, 0x6a2d519a, + 0x63ef8ce2, 0x9a86ee22, 0xc089c2b8, 0x43242ef6, 0xa51e03aa, 0x9cf2d0a4, + 0x83c061ba, 0x9be96a4d, 0x8fe51550, 0xba645bd6, 0x2826a2f9, 0xa73a3ae1, + 0x4ba99586, 0xef5562e9, 0xc72fefd3, 0xf752f7da, 0x3f046f69, 0x77fa0a59, + 0x80e4a915, 0x87b08601, 0x9b09e6ad, 0x3b3ee593, 0xe990fd5a, 0x9e34d797, + 0x2cf0b7d9, 0x022b8b51, 0x96d5ac3a, 0x017da67d, 0xd1cf3ed6, 0x7c7d2d28, + 0x1f9f25cf, 0xadf2b89b, 0x5ad6b472, 0x5a88f54c, 0xe029ac71, 0xe019a5e6, + 0x47b0acfd, 0xed93fa9b, 0xe8d3c48d, 0x283b57cc, 0xf8d56629, 0x79132e28, + 0x785f0191, 0xed756055, 0xf7960e44, 0xe3d35e8c, 0x15056dd4, 0x88f46dba, + 0x03a16125, 0x0564f0bd, 0xc3eb9e15, 0x3c9057a2, 0x97271aec, 0xa93a072a, + 0x1b3f6d9b, 0x1e6321f5, 0xf59c66fb, 0x26dcf319, 0x7533d928, 0xb155fdf5, + 0x03563482, 0x8aba3cbb, 0x28517711, 0xc20ad9f8, 0xabcc5167, 0xccad925f, + 0x4de81751, 0x3830dc8e, 0x379d5862, 0x9320f991, 0xea7a90c2, 0xfb3e7bce, + 0x5121ce64, 0x774fbe32, 0xa8b6e37e, 0xc3293d46, 0x48de5369, 0x6413e680, + 0xa2ae0810, 0xdd6db224, 0x69852dfd, 0x09072166, 0xb39a460a, 0x6445c0dd, + 0x586cdecf, 0x1c20c8ae, 0x5bbef7dd, 0x1b588d40, 0xccd2017f, 0x6bb4e3bb, + 0xdda26a7e, 0x3a59ff45, 0x3e350a44, 0xbcb4cdd5, 0x72eacea8, 0xfa6484bb, + 0x8d6612ae, 0xbf3c6f47, 0xd29be463, 0x542f5d9e, 0xaec2771b, 0xf64e6370, + 0x740e0d8d, 0xe75b1357, 0xf8721671, 0xaf537d5d, 0x4040cb08, 0x4eb4e2cc, + 0x34d2466a, 0x0115af84, 0xe1b00428, 0x95983a1d, 0x06b89fb4, 0xce6ea048, + 0x6f3f3b82, 0x3520ab82, 0x011a1d4b, 0x277227f8, 0x611560b1, 0xe7933fdc, + 0xbb3a792b, 0x344525bd, 0xa08839e1, 0x51ce794b, 0x2f32c9b7, 0xa01fbac9, + 0xe01cc87e, 0xbcc7d1f6, 0xcf0111c3, 0xa1e8aac7, 0x1a908749, 0xd44fbd9a, + 0xd0dadecb, 0xd50ada38, 0x0339c32a, 0xc6913667, 0x8df9317c, 0xe0b12b4f, + 0xf79e59b7, 0x43f5bb3a, 0xf2d519ff, 0x27d9459c, 0xbf97222c, 0x15e6fc2a, + 0x0f91fc71, 0x9b941525, 0xfae59361, 0xceb69ceb, 0xc2a86459, 0x12baa8d1, + 0xb6c1075e, 0xe3056a0c, 0x10d25065, 0xcb03a442, 0xe0ec6e0e, 0x1698db3b, + 0x4c98a0be, 0x3278e964, 0x9f1f9532, 0xe0d392df, 0xd3a0342b, 0x8971f21e, + 0x1b0a7441, 0x4ba3348c, 0xc5be7120, 0xc37632d8, 0xdf359f8d, 0x9b992f2e, + 0xe60b6f47, 0x0fe3f11d, 0xe54cda54, 0x1edad891, 0xce6279cf, 0xcd3e7e6f, + 0x1618b166, 0xfd2c1d05, 0x848fd2c5, 0xf6fb2299, 0xf523f357, 0xa6327623, + 0x93a83531, 0x56cccd02, 0xacf08162, 0x5a75ebb5, 0x6e163697, 0x88d273cc, + 0xde966292, 0x81b949d0, 0x4c50901b, 0x71c65614, 0xe6c6c7bd, 0x327a140a, + 0x45e1d006, 0xc3f27b9a, 0xc9aa53fd, 0x62a80f00, 0xbb25bfe2, 0x35bdd2f6, + 0x71126905, 0xb2040222, 0xb6cbcf7c, 0xcd769c2b, 0x53113ec0, 0x1640e3d3, + 0x38abbd60, 0x2547adf0, 0xba38209c, 0xf746ce76, 0x77afa1c5, 0x20756060, + 0x85cbfe4e, 0x8ae88dd8, 0x7aaaf9b0, 0x4cf9aa7e, 0x1948c25c, 0x02fb8a8c, + 0x01c36ae4, 0xd6ebe1f9, 0x90d4f869, 0xa65cdea0, 0x3f09252d, 0xc208e69f, + 0xb74e6132, 0xce77e25b, 0x578fdfe3, 0x3ac372e6, +} + +var p = [18]uint32{ + 0x243f6a88, 0x85a308d3, 0x13198a2e, 0x03707344, 0xa4093822, 0x299f31d0, + 0x082efa98, 0xec4e6c89, 0x452821e6, 0x38d01377, 0xbe5466cf, 0x34e90c6c, + 0xc0ac29b7, 0xc97c50dd, 0x3f84d5b5, 0xb5470917, 0x9216d5d9, 0x8979fb1b, +} diff --git a/internal/crypto/chacha20/chacha_arm64.go b/internal/crypto/chacha20/chacha_arm64.go new file mode 100644 index 000000000..b799e440b --- /dev/null +++ b/internal/crypto/chacha20/chacha_arm64.go @@ -0,0 +1,16 @@ +// Copyright 2018 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build go1.11,!gccgo,!purego + +package chacha20 + +const bufSize = 256 + +//go:noescape +func xorKeyStreamVX(dst, src []byte, key *[8]uint32, nonce *[3]uint32, counter *uint32) + +func (c *Cipher) xorKeyStreamBlocks(dst, src []byte) { + xorKeyStreamVX(dst, src, &c.key, &c.nonce, &c.counter) +} diff --git a/internal/crypto/chacha20/chacha_arm64.s b/internal/crypto/chacha20/chacha_arm64.s new file mode 100644 index 000000000..891481539 --- /dev/null +++ b/internal/crypto/chacha20/chacha_arm64.s @@ -0,0 +1,307 @@ +// Copyright 2018 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build go1.11,!gccgo,!purego + +#include "textflag.h" + +#define NUM_ROUNDS 10 + +// func xorKeyStreamVX(dst, src []byte, key *[8]uint32, nonce *[3]uint32, counter *uint32) +TEXT ·xorKeyStreamVX(SB), NOSPLIT, $0 + MOVD dst+0(FP), R1 + MOVD src+24(FP), R2 + MOVD src_len+32(FP), R3 + MOVD key+48(FP), R4 + MOVD nonce+56(FP), R6 + MOVD counter+64(FP), R7 + + MOVD $·constants(SB), R10 + MOVD $·incRotMatrix(SB), R11 + + MOVW (R7), R20 + + AND $~255, R3, R13 + ADD R2, R13, R12 // R12 for block end + AND $255, R3, R13 +loop: + MOVD $NUM_ROUNDS, R21 + VLD1 (R11), [V30.S4, V31.S4] + + // load contants + // VLD4R (R10), [V0.S4, V1.S4, V2.S4, V3.S4] + WORD $0x4D60E940 + + // load keys + // VLD4R 16(R4), [V4.S4, V5.S4, V6.S4, V7.S4] + WORD $0x4DFFE884 + // VLD4R 16(R4), [V8.S4, V9.S4, V10.S4, V11.S4] + WORD $0x4DFFE888 + SUB $32, R4 + + // load counter + nonce + // VLD1R (R7), [V12.S4] + WORD $0x4D40C8EC + + // VLD3R (R6), [V13.S4, V14.S4, V15.S4] + WORD $0x4D40E8CD + + // update counter + VADD V30.S4, V12.S4, V12.S4 + +chacha: + // V0..V3 += V4..V7 + // V12..V15 <<<= ((V12..V15 XOR V0..V3), 16) + VADD V0.S4, V4.S4, V0.S4 + VADD V1.S4, V5.S4, V1.S4 + VADD V2.S4, V6.S4, V2.S4 + VADD V3.S4, V7.S4, V3.S4 + VEOR V12.B16, V0.B16, V12.B16 + VEOR V13.B16, V1.B16, V13.B16 + VEOR V14.B16, V2.B16, V14.B16 + VEOR V15.B16, V3.B16, V15.B16 + VREV32 V12.H8, V12.H8 + VREV32 V13.H8, V13.H8 + VREV32 V14.H8, V14.H8 + VREV32 V15.H8, V15.H8 + // V8..V11 += V12..V15 + // V4..V7 <<<= ((V4..V7 XOR V8..V11), 12) + VADD V8.S4, V12.S4, V8.S4 + VADD V9.S4, V13.S4, V9.S4 + VADD V10.S4, V14.S4, V10.S4 + VADD V11.S4, V15.S4, V11.S4 + VEOR V8.B16, V4.B16, V16.B16 + VEOR V9.B16, V5.B16, V17.B16 + VEOR V10.B16, V6.B16, V18.B16 + VEOR V11.B16, V7.B16, V19.B16 + VSHL $12, V16.S4, V4.S4 + VSHL $12, V17.S4, V5.S4 + VSHL $12, V18.S4, V6.S4 + VSHL $12, V19.S4, V7.S4 + VSRI $20, V16.S4, V4.S4 + VSRI $20, V17.S4, V5.S4 + VSRI $20, V18.S4, V6.S4 + VSRI $20, V19.S4, V7.S4 + + // V0..V3 += V4..V7 + // V12..V15 <<<= ((V12..V15 XOR V0..V3), 8) + VADD V0.S4, V4.S4, V0.S4 + VADD V1.S4, V5.S4, V1.S4 + VADD V2.S4, V6.S4, V2.S4 + VADD V3.S4, V7.S4, V3.S4 + VEOR V12.B16, V0.B16, V12.B16 + VEOR V13.B16, V1.B16, V13.B16 + VEOR V14.B16, V2.B16, V14.B16 + VEOR V15.B16, V3.B16, V15.B16 + VTBL V31.B16, [V12.B16], V12.B16 + VTBL V31.B16, [V13.B16], V13.B16 + VTBL V31.B16, [V14.B16], V14.B16 + VTBL V31.B16, [V15.B16], V15.B16 + + // V8..V11 += V12..V15 + // V4..V7 <<<= ((V4..V7 XOR V8..V11), 7) + VADD V12.S4, V8.S4, V8.S4 + VADD V13.S4, V9.S4, V9.S4 + VADD V14.S4, V10.S4, V10.S4 + VADD V15.S4, V11.S4, V11.S4 + VEOR V8.B16, V4.B16, V16.B16 + VEOR V9.B16, V5.B16, V17.B16 + VEOR V10.B16, V6.B16, V18.B16 + VEOR V11.B16, V7.B16, V19.B16 + VSHL $7, V16.S4, V4.S4 + VSHL $7, V17.S4, V5.S4 + VSHL $7, V18.S4, V6.S4 + VSHL $7, V19.S4, V7.S4 + VSRI $25, V16.S4, V4.S4 + VSRI $25, V17.S4, V5.S4 + VSRI $25, V18.S4, V6.S4 + VSRI $25, V19.S4, V7.S4 + + // V0..V3 += V5..V7, V4 + // V15,V12-V14 <<<= ((V15,V12-V14 XOR V0..V3), 16) + VADD V0.S4, V5.S4, V0.S4 + VADD V1.S4, V6.S4, V1.S4 + VADD V2.S4, V7.S4, V2.S4 + VADD V3.S4, V4.S4, V3.S4 + VEOR V15.B16, V0.B16, V15.B16 + VEOR V12.B16, V1.B16, V12.B16 + VEOR V13.B16, V2.B16, V13.B16 + VEOR V14.B16, V3.B16, V14.B16 + VREV32 V12.H8, V12.H8 + VREV32 V13.H8, V13.H8 + VREV32 V14.H8, V14.H8 + VREV32 V15.H8, V15.H8 + + // V10 += V15; V5 <<<= ((V10 XOR V5), 12) + // ... + VADD V15.S4, V10.S4, V10.S4 + VADD V12.S4, V11.S4, V11.S4 + VADD V13.S4, V8.S4, V8.S4 + VADD V14.S4, V9.S4, V9.S4 + VEOR V10.B16, V5.B16, V16.B16 + VEOR V11.B16, V6.B16, V17.B16 + VEOR V8.B16, V7.B16, V18.B16 + VEOR V9.B16, V4.B16, V19.B16 + VSHL $12, V16.S4, V5.S4 + VSHL $12, V17.S4, V6.S4 + VSHL $12, V18.S4, V7.S4 + VSHL $12, V19.S4, V4.S4 + VSRI $20, V16.S4, V5.S4 + VSRI $20, V17.S4, V6.S4 + VSRI $20, V18.S4, V7.S4 + VSRI $20, V19.S4, V4.S4 + + // V0 += V5; V15 <<<= ((V0 XOR V15), 8) + // ... + VADD V5.S4, V0.S4, V0.S4 + VADD V6.S4, V1.S4, V1.S4 + VADD V7.S4, V2.S4, V2.S4 + VADD V4.S4, V3.S4, V3.S4 + VEOR V0.B16, V15.B16, V15.B16 + VEOR V1.B16, V12.B16, V12.B16 + VEOR V2.B16, V13.B16, V13.B16 + VEOR V3.B16, V14.B16, V14.B16 + VTBL V31.B16, [V12.B16], V12.B16 + VTBL V31.B16, [V13.B16], V13.B16 + VTBL V31.B16, [V14.B16], V14.B16 + VTBL V31.B16, [V15.B16], V15.B16 + + // V10 += V15; V5 <<<= ((V10 XOR V5), 7) + // ... + VADD V15.S4, V10.S4, V10.S4 + VADD V12.S4, V11.S4, V11.S4 + VADD V13.S4, V8.S4, V8.S4 + VADD V14.S4, V9.S4, V9.S4 + VEOR V10.B16, V5.B16, V16.B16 + VEOR V11.B16, V6.B16, V17.B16 + VEOR V8.B16, V7.B16, V18.B16 + VEOR V9.B16, V4.B16, V19.B16 + VSHL $7, V16.S4, V5.S4 + VSHL $7, V17.S4, V6.S4 + VSHL $7, V18.S4, V7.S4 + VSHL $7, V19.S4, V4.S4 + VSRI $25, V16.S4, V5.S4 + VSRI $25, V17.S4, V6.S4 + VSRI $25, V18.S4, V7.S4 + VSRI $25, V19.S4, V4.S4 + + SUB $1, R21 + CBNZ R21, chacha + + // VLD4R (R10), [V16.S4, V17.S4, V18.S4, V19.S4] + WORD $0x4D60E950 + + // VLD4R 16(R4), [V20.S4, V21.S4, V22.S4, V23.S4] + WORD $0x4DFFE894 + VADD V30.S4, V12.S4, V12.S4 + VADD V16.S4, V0.S4, V0.S4 + VADD V17.S4, V1.S4, V1.S4 + VADD V18.S4, V2.S4, V2.S4 + VADD V19.S4, V3.S4, V3.S4 + // VLD4R 16(R4), [V24.S4, V25.S4, V26.S4, V27.S4] + WORD $0x4DFFE898 + // restore R4 + SUB $32, R4 + + // load counter + nonce + // VLD1R (R7), [V28.S4] + WORD $0x4D40C8FC + // VLD3R (R6), [V29.S4, V30.S4, V31.S4] + WORD $0x4D40E8DD + + VADD V20.S4, V4.S4, V4.S4 + VADD V21.S4, V5.S4, V5.S4 + VADD V22.S4, V6.S4, V6.S4 + VADD V23.S4, V7.S4, V7.S4 + VADD V24.S4, V8.S4, V8.S4 + VADD V25.S4, V9.S4, V9.S4 + VADD V26.S4, V10.S4, V10.S4 + VADD V27.S4, V11.S4, V11.S4 + VADD V28.S4, V12.S4, V12.S4 + VADD V29.S4, V13.S4, V13.S4 + VADD V30.S4, V14.S4, V14.S4 + VADD V31.S4, V15.S4, V15.S4 + + VZIP1 V1.S4, V0.S4, V16.S4 + VZIP2 V1.S4, V0.S4, V17.S4 + VZIP1 V3.S4, V2.S4, V18.S4 + VZIP2 V3.S4, V2.S4, V19.S4 + VZIP1 V5.S4, V4.S4, V20.S4 + VZIP2 V5.S4, V4.S4, V21.S4 + VZIP1 V7.S4, V6.S4, V22.S4 + VZIP2 V7.S4, V6.S4, V23.S4 + VZIP1 V9.S4, V8.S4, V24.S4 + VZIP2 V9.S4, V8.S4, V25.S4 + VZIP1 V11.S4, V10.S4, V26.S4 + VZIP2 V11.S4, V10.S4, V27.S4 + VZIP1 V13.S4, V12.S4, V28.S4 + VZIP2 V13.S4, V12.S4, V29.S4 + VZIP1 V15.S4, V14.S4, V30.S4 + VZIP2 V15.S4, V14.S4, V31.S4 + VZIP1 V18.D2, V16.D2, V0.D2 + VZIP2 V18.D2, V16.D2, V4.D2 + VZIP1 V19.D2, V17.D2, V8.D2 + VZIP2 V19.D2, V17.D2, V12.D2 + VLD1.P 64(R2), [V16.B16, V17.B16, V18.B16, V19.B16] + + VZIP1 V22.D2, V20.D2, V1.D2 + VZIP2 V22.D2, V20.D2, V5.D2 + VZIP1 V23.D2, V21.D2, V9.D2 + VZIP2 V23.D2, V21.D2, V13.D2 + VLD1.P 64(R2), [V20.B16, V21.B16, V22.B16, V23.B16] + VZIP1 V26.D2, V24.D2, V2.D2 + VZIP2 V26.D2, V24.D2, V6.D2 + VZIP1 V27.D2, V25.D2, V10.D2 + VZIP2 V27.D2, V25.D2, V14.D2 + VLD1.P 64(R2), [V24.B16, V25.B16, V26.B16, V27.B16] + VZIP1 V30.D2, V28.D2, V3.D2 + VZIP2 V30.D2, V28.D2, V7.D2 + VZIP1 V31.D2, V29.D2, V11.D2 + VZIP2 V31.D2, V29.D2, V15.D2 + VLD1.P 64(R2), [V28.B16, V29.B16, V30.B16, V31.B16] + VEOR V0.B16, V16.B16, V16.B16 + VEOR V1.B16, V17.B16, V17.B16 + VEOR V2.B16, V18.B16, V18.B16 + VEOR V3.B16, V19.B16, V19.B16 + VST1.P [V16.B16, V17.B16, V18.B16, V19.B16], 64(R1) + VEOR V4.B16, V20.B16, V20.B16 + VEOR V5.B16, V21.B16, V21.B16 + VEOR V6.B16, V22.B16, V22.B16 + VEOR V7.B16, V23.B16, V23.B16 + VST1.P [V20.B16, V21.B16, V22.B16, V23.B16], 64(R1) + VEOR V8.B16, V24.B16, V24.B16 + VEOR V9.B16, V25.B16, V25.B16 + VEOR V10.B16, V26.B16, V26.B16 + VEOR V11.B16, V27.B16, V27.B16 + VST1.P [V24.B16, V25.B16, V26.B16, V27.B16], 64(R1) + VEOR V12.B16, V28.B16, V28.B16 + VEOR V13.B16, V29.B16, V29.B16 + VEOR V14.B16, V30.B16, V30.B16 + VEOR V15.B16, V31.B16, V31.B16 + VST1.P [V28.B16, V29.B16, V30.B16, V31.B16], 64(R1) + + ADD $4, R20 + MOVW R20, (R7) // update counter + + CMP R2, R12 + BGT loop + + RET + + +DATA ·constants+0x00(SB)/4, $0x61707865 +DATA ·constants+0x04(SB)/4, $0x3320646e +DATA ·constants+0x08(SB)/4, $0x79622d32 +DATA ·constants+0x0c(SB)/4, $0x6b206574 +GLOBL ·constants(SB), NOPTR|RODATA, $32 + +DATA ·incRotMatrix+0x00(SB)/4, $0x00000000 +DATA ·incRotMatrix+0x04(SB)/4, $0x00000001 +DATA ·incRotMatrix+0x08(SB)/4, $0x00000002 +DATA ·incRotMatrix+0x0c(SB)/4, $0x00000003 +DATA ·incRotMatrix+0x10(SB)/4, $0x02010003 +DATA ·incRotMatrix+0x14(SB)/4, $0x06050407 +DATA ·incRotMatrix+0x18(SB)/4, $0x0A09080B +DATA ·incRotMatrix+0x1c(SB)/4, $0x0E0D0C0F +GLOBL ·incRotMatrix(SB), NOPTR|RODATA, $32 diff --git a/internal/crypto/chacha20/chacha_generic.go b/internal/crypto/chacha20/chacha_generic.go new file mode 100644 index 000000000..a2ecf5c32 --- /dev/null +++ b/internal/crypto/chacha20/chacha_generic.go @@ -0,0 +1,398 @@ +// Copyright 2016 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package chacha20 implements the ChaCha20 and XChaCha20 encryption algorithms +// as specified in RFC 8439 and draft-irtf-cfrg-xchacha-01. +package chacha20 + +import ( + "crypto/cipher" + "encoding/binary" + "errors" + "math/bits" + + "golang.org/x/crypto/internal/subtle" +) + +const ( + // KeySize is the size of the key used by this cipher, in bytes. + KeySize = 32 + + // NonceSize is the size of the nonce used with the standard variant of this + // cipher, in bytes. + // + // Note that this is too short to be safely generated at random if the same + // key is reused more than 2³² times. + NonceSize = 12 + + // NonceSizeX is the size of the nonce used with the XChaCha20 variant of + // this cipher, in bytes. + NonceSizeX = 24 +) + +// Cipher is a stateful instance of ChaCha20 or XChaCha20 using a particular key +// and nonce. A *Cipher implements the cipher.Stream interface. +type Cipher struct { + // The ChaCha20 state is 16 words: 4 constant, 8 of key, 1 of counter + // (incremented after each block), and 3 of nonce. + key [8]uint32 + counter uint32 + nonce [3]uint32 + + // The last len bytes of buf are leftover key stream bytes from the previous + // XORKeyStream invocation. The size of buf depends on how many blocks are + // computed at a time by xorKeyStreamBlocks. + buf [bufSize]byte + len int + + // overflow is set when the counter overflowed, no more blocks can be + // generated, and the next XORKeyStream call should panic. + overflow bool + + // The counter-independent results of the first round are cached after they + // are computed the first time. + precompDone bool + p1, p5, p9, p13 uint32 + p2, p6, p10, p14 uint32 + p3, p7, p11, p15 uint32 +} + +var _ cipher.Stream = (*Cipher)(nil) + +// NewUnauthenticatedCipher creates a new ChaCha20 stream cipher with the given +// 32 bytes key and a 12 or 24 bytes nonce. If a nonce of 24 bytes is provided, +// the XChaCha20 construction will be used. It returns an error if key or nonce +// have any other length. +// +// Note that ChaCha20, like all stream ciphers, is not authenticated and allows +// attackers to silently tamper with the plaintext. For this reason, it is more +// appropriate as a building block than as a standalone encryption mechanism. +// Instead, consider using package golang.org/x/crypto/chacha20poly1305. +func NewUnauthenticatedCipher(key, nonce []byte) (*Cipher, error) { + // This function is split into a wrapper so that the Cipher allocation will + // be inlined, and depending on how the caller uses the return value, won't + // escape to the heap. + c := &Cipher{} + return newUnauthenticatedCipher(c, key, nonce) +} + +func newUnauthenticatedCipher(c *Cipher, key, nonce []byte) (*Cipher, error) { + if len(key) != KeySize { + return nil, errors.New("chacha20: wrong key size") + } + if len(nonce) == NonceSizeX { + // XChaCha20 uses the ChaCha20 core to mix 16 bytes of the nonce into a + // derived key, allowing it to operate on a nonce of 24 bytes. See + // draft-irtf-cfrg-xchacha-01, Section 2.3. + key, _ = HChaCha20(key, nonce[0:16]) + cNonce := make([]byte, NonceSize) + copy(cNonce[4:12], nonce[16:24]) + nonce = cNonce + } else if len(nonce) != NonceSize { + return nil, errors.New("chacha20: wrong nonce size") + } + + key, nonce = key[:KeySize], nonce[:NonceSize] // bounds check elimination hint + c.key = [8]uint32{ + binary.LittleEndian.Uint32(key[0:4]), + binary.LittleEndian.Uint32(key[4:8]), + binary.LittleEndian.Uint32(key[8:12]), + binary.LittleEndian.Uint32(key[12:16]), + binary.LittleEndian.Uint32(key[16:20]), + binary.LittleEndian.Uint32(key[20:24]), + binary.LittleEndian.Uint32(key[24:28]), + binary.LittleEndian.Uint32(key[28:32]), + } + c.nonce = [3]uint32{ + binary.LittleEndian.Uint32(nonce[0:4]), + binary.LittleEndian.Uint32(nonce[4:8]), + binary.LittleEndian.Uint32(nonce[8:12]), + } + return c, nil +} + +// The constant first 4 words of the ChaCha20 state. +const ( + j0 uint32 = 0x61707865 // expa + j1 uint32 = 0x3320646e // nd 3 + j2 uint32 = 0x79622d32 // 2-by + j3 uint32 = 0x6b206574 // te k +) + +const blockSize = 64 + +// quarterRound is the core of ChaCha20. It shuffles the bits of 4 state words. +// It's executed 4 times for each of the 20 ChaCha20 rounds, operating on all 16 +// words each round, in columnar or diagonal groups of 4 at a time. +func quarterRound(a, b, c, d uint32) (uint32, uint32, uint32, uint32) { + a += b + d ^= a + d = bits.RotateLeft32(d, 16) + c += d + b ^= c + b = bits.RotateLeft32(b, 12) + a += b + d ^= a + d = bits.RotateLeft32(d, 8) + c += d + b ^= c + b = bits.RotateLeft32(b, 7) + return a, b, c, d +} + +// SetCounter sets the Cipher counter. The next invocation of XORKeyStream will +// behave as if (64 * counter) bytes had been encrypted so far. +// +// To prevent accidental counter reuse, SetCounter panics if counter is less +// than the current value. +// +// Note that the execution time of XORKeyStream is not independent of the +// counter value. +func (s *Cipher) SetCounter(counter uint32) { + // Internally, s may buffer multiple blocks, which complicates this + // implementation slightly. When checking whether the counter has rolled + // back, we must use both s.counter and s.len to determine how many blocks + // we have already output. + outputCounter := s.counter - uint32(s.len)/blockSize + if s.overflow || counter < outputCounter { + panic("chacha20: SetCounter attempted to rollback counter") + } + + // In the general case, we set the new counter value and reset s.len to 0, + // causing the next call to XORKeyStream to refill the buffer. However, if + // we're advancing within the existing buffer, we can save work by simply + // setting s.len. + if counter < s.counter { + s.len = int(s.counter-counter) * blockSize + } else { + s.counter = counter + s.len = 0 + } +} + +// XORKeyStream XORs each byte in the given slice with a byte from the +// cipher's key stream. Dst and src must overlap entirely or not at all. +// +// If len(dst) < len(src), XORKeyStream will panic. It is acceptable +// to pass a dst bigger than src, and in that case, XORKeyStream will +// only update dst[:len(src)] and will not touch the rest of dst. +// +// Multiple calls to XORKeyStream behave as if the concatenation of +// the src buffers was passed in a single run. That is, Cipher +// maintains state and does not reset at each XORKeyStream call. +func (s *Cipher) XORKeyStream(dst, src []byte) { + if len(src) == 0 { + return + } + if len(dst) < len(src) { + panic("chacha20: output smaller than input") + } + dst = dst[:len(src)] + if subtle.InexactOverlap(dst, src) { + panic("chacha20: invalid buffer overlap") + } + + // First, drain any remaining key stream from a previous XORKeyStream. + if s.len != 0 { + keyStream := s.buf[bufSize-s.len:] + if len(src) < len(keyStream) { + keyStream = keyStream[:len(src)] + } + _ = src[len(keyStream)-1] // bounds check elimination hint + for i, b := range keyStream { + dst[i] = src[i] ^ b + } + s.len -= len(keyStream) + dst, src = dst[len(keyStream):], src[len(keyStream):] + } + if len(src) == 0 { + return + } + + // If we'd need to let the counter overflow and keep generating output, + // panic immediately. If instead we'd only reach the last block, remember + // not to generate any more output after the buffer is drained. + numBlocks := (uint64(len(src)) + blockSize - 1) / blockSize + if s.overflow || uint64(s.counter)+numBlocks > 1<<32 { + panic("chacha20: counter overflow") + } else if uint64(s.counter)+numBlocks == 1<<32 { + s.overflow = true + } + + // xorKeyStreamBlocks implementations expect input lengths that are a + // multiple of bufSize. Platform-specific ones process multiple blocks at a + // time, so have bufSizes that are a multiple of blockSize. + + full := len(src) - len(src)%bufSize + if full > 0 { + s.xorKeyStreamBlocks(dst[:full], src[:full]) + } + dst, src = dst[full:], src[full:] + + // If using a multi-block xorKeyStreamBlocks would overflow, use the generic + // one that does one block at a time. + const blocksPerBuf = bufSize / blockSize + if uint64(s.counter)+blocksPerBuf > 1<<32 { + s.buf = [bufSize]byte{} + numBlocks := (len(src) + blockSize - 1) / blockSize + buf := s.buf[bufSize-numBlocks*blockSize:] + copy(buf, src) + s.xorKeyStreamBlocksGeneric(buf, buf) + s.len = len(buf) - copy(dst, buf) + return + } + + // If we have a partial (multi-)block, pad it for xorKeyStreamBlocks, and + // keep the leftover keystream for the next XORKeyStream invocation. + if len(src) > 0 { + s.buf = [bufSize]byte{} + copy(s.buf[:], src) + s.xorKeyStreamBlocks(s.buf[:], s.buf[:]) + s.len = bufSize - copy(dst, s.buf[:]) + } +} + +func (s *Cipher) xorKeyStreamBlocksGeneric(dst, src []byte) { + if len(dst) != len(src) || len(dst)%blockSize != 0 { + panic("chacha20: internal error: wrong dst and/or src length") + } + + // To generate each block of key stream, the initial cipher state + // (represented below) is passed through 20 rounds of shuffling, + // alternatively applying quarterRounds by columns (like 1, 5, 9, 13) + // or by diagonals (like 1, 6, 11, 12). + // + // 0:cccccccc 1:cccccccc 2:cccccccc 3:cccccccc + // 4:kkkkkkkk 5:kkkkkkkk 6:kkkkkkkk 7:kkkkkkkk + // 8:kkkkkkkk 9:kkkkkkkk 10:kkkkkkkk 11:kkkkkkkk + // 12:bbbbbbbb 13:nnnnnnnn 14:nnnnnnnn 15:nnnnnnnn + // + // c=constant k=key b=blockcount n=nonce + var ( + c0, c1, c2, c3 = j0, j1, j2, j3 + c4, c5, c6, c7 = s.key[0], s.key[1], s.key[2], s.key[3] + c8, c9, c10, c11 = s.key[4], s.key[5], s.key[6], s.key[7] + _, c13, c14, c15 = s.counter, s.nonce[0], s.nonce[1], s.nonce[2] + ) + + // Three quarters of the first round don't depend on the counter, so we can + // calculate them here, and reuse them for multiple blocks in the loop, and + // for future XORKeyStream invocations. + if !s.precompDone { + s.p1, s.p5, s.p9, s.p13 = quarterRound(c1, c5, c9, c13) + s.p2, s.p6, s.p10, s.p14 = quarterRound(c2, c6, c10, c14) + s.p3, s.p7, s.p11, s.p15 = quarterRound(c3, c7, c11, c15) + s.precompDone = true + } + + // A condition of len(src) > 0 would be sufficient, but this also + // acts as a bounds check elimination hint. + for len(src) >= 64 && len(dst) >= 64 { + // The remainder of the first column round. + fcr0, fcr4, fcr8, fcr12 := quarterRound(c0, c4, c8, s.counter) + + // The second diagonal round. + x0, x5, x10, x15 := quarterRound(fcr0, s.p5, s.p10, s.p15) + x1, x6, x11, x12 := quarterRound(s.p1, s.p6, s.p11, fcr12) + x2, x7, x8, x13 := quarterRound(s.p2, s.p7, fcr8, s.p13) + x3, x4, x9, x14 := quarterRound(s.p3, fcr4, s.p9, s.p14) + + // The remaining 18 rounds. + for i := 0; i < 9; i++ { + // Column round. + x0, x4, x8, x12 = quarterRound(x0, x4, x8, x12) + x1, x5, x9, x13 = quarterRound(x1, x5, x9, x13) + x2, x6, x10, x14 = quarterRound(x2, x6, x10, x14) + x3, x7, x11, x15 = quarterRound(x3, x7, x11, x15) + + // Diagonal round. + x0, x5, x10, x15 = quarterRound(x0, x5, x10, x15) + x1, x6, x11, x12 = quarterRound(x1, x6, x11, x12) + x2, x7, x8, x13 = quarterRound(x2, x7, x8, x13) + x3, x4, x9, x14 = quarterRound(x3, x4, x9, x14) + } + + // Add back the initial state to generate the key stream, then + // XOR the key stream with the source and write out the result. + addXor(dst[0:4], src[0:4], x0, c0) + addXor(dst[4:8], src[4:8], x1, c1) + addXor(dst[8:12], src[8:12], x2, c2) + addXor(dst[12:16], src[12:16], x3, c3) + addXor(dst[16:20], src[16:20], x4, c4) + addXor(dst[20:24], src[20:24], x5, c5) + addXor(dst[24:28], src[24:28], x6, c6) + addXor(dst[28:32], src[28:32], x7, c7) + addXor(dst[32:36], src[32:36], x8, c8) + addXor(dst[36:40], src[36:40], x9, c9) + addXor(dst[40:44], src[40:44], x10, c10) + addXor(dst[44:48], src[44:48], x11, c11) + addXor(dst[48:52], src[48:52], x12, s.counter) + addXor(dst[52:56], src[52:56], x13, c13) + addXor(dst[56:60], src[56:60], x14, c14) + addXor(dst[60:64], src[60:64], x15, c15) + + s.counter += 1 + + src, dst = src[blockSize:], dst[blockSize:] + } +} + +// HChaCha20 uses the ChaCha20 core to generate a derived key from a 32 bytes +// key and a 16 bytes nonce. It returns an error if key or nonce have any other +// length. It is used as part of the XChaCha20 construction. +func HChaCha20(key, nonce []byte) ([]byte, error) { + // This function is split into a wrapper so that the slice allocation will + // be inlined, and depending on how the caller uses the return value, won't + // escape to the heap. + out := make([]byte, 32) + return hChaCha20(out, key, nonce) +} + +func hChaCha20(out, key, nonce []byte) ([]byte, error) { + if len(key) != KeySize { + return nil, errors.New("chacha20: wrong HChaCha20 key size") + } + if len(nonce) != 16 { + return nil, errors.New("chacha20: wrong HChaCha20 nonce size") + } + + x0, x1, x2, x3 := j0, j1, j2, j3 + x4 := binary.LittleEndian.Uint32(key[0:4]) + x5 := binary.LittleEndian.Uint32(key[4:8]) + x6 := binary.LittleEndian.Uint32(key[8:12]) + x7 := binary.LittleEndian.Uint32(key[12:16]) + x8 := binary.LittleEndian.Uint32(key[16:20]) + x9 := binary.LittleEndian.Uint32(key[20:24]) + x10 := binary.LittleEndian.Uint32(key[24:28]) + x11 := binary.LittleEndian.Uint32(key[28:32]) + x12 := binary.LittleEndian.Uint32(nonce[0:4]) + x13 := binary.LittleEndian.Uint32(nonce[4:8]) + x14 := binary.LittleEndian.Uint32(nonce[8:12]) + x15 := binary.LittleEndian.Uint32(nonce[12:16]) + + for i := 0; i < 10; i++ { + // Diagonal round. + x0, x4, x8, x12 = quarterRound(x0, x4, x8, x12) + x1, x5, x9, x13 = quarterRound(x1, x5, x9, x13) + x2, x6, x10, x14 = quarterRound(x2, x6, x10, x14) + x3, x7, x11, x15 = quarterRound(x3, x7, x11, x15) + + // Column round. + x0, x5, x10, x15 = quarterRound(x0, x5, x10, x15) + x1, x6, x11, x12 = quarterRound(x1, x6, x11, x12) + x2, x7, x8, x13 = quarterRound(x2, x7, x8, x13) + x3, x4, x9, x14 = quarterRound(x3, x4, x9, x14) + } + + _ = out[31] // bounds check elimination hint + binary.LittleEndian.PutUint32(out[0:4], x0) + binary.LittleEndian.PutUint32(out[4:8], x1) + binary.LittleEndian.PutUint32(out[8:12], x2) + binary.LittleEndian.PutUint32(out[12:16], x3) + binary.LittleEndian.PutUint32(out[16:20], x12) + binary.LittleEndian.PutUint32(out[20:24], x13) + binary.LittleEndian.PutUint32(out[24:28], x14) + binary.LittleEndian.PutUint32(out[28:32], x15) + return out, nil +} diff --git a/internal/crypto/chacha20/chacha_noasm.go b/internal/crypto/chacha20/chacha_noasm.go new file mode 100644 index 000000000..4635307b8 --- /dev/null +++ b/internal/crypto/chacha20/chacha_noasm.go @@ -0,0 +1,13 @@ +// Copyright 2018 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build !arm64,!s390x,!ppc64le arm64,!go1.11 gccgo purego + +package chacha20 + +const bufSize = blockSize + +func (s *Cipher) xorKeyStreamBlocks(dst, src []byte) { + s.xorKeyStreamBlocksGeneric(dst, src) +} diff --git a/internal/crypto/chacha20/chacha_ppc64le.go b/internal/crypto/chacha20/chacha_ppc64le.go new file mode 100644 index 000000000..b79933034 --- /dev/null +++ b/internal/crypto/chacha20/chacha_ppc64le.go @@ -0,0 +1,16 @@ +// Copyright 2019 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build !gccgo,!purego + +package chacha20 + +const bufSize = 256 + +//go:noescape +func chaCha20_ctr32_vsx(out, inp *byte, len int, key *[8]uint32, counter *uint32) + +func (c *Cipher) xorKeyStreamBlocks(dst, src []byte) { + chaCha20_ctr32_vsx(&dst[0], &src[0], len(src), &c.key, &c.counter) +} diff --git a/internal/crypto/chacha20/chacha_ppc64le.s b/internal/crypto/chacha20/chacha_ppc64le.s new file mode 100644 index 000000000..23c602164 --- /dev/null +++ b/internal/crypto/chacha20/chacha_ppc64le.s @@ -0,0 +1,449 @@ +// Copyright 2019 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Based on CRYPTOGAMS code with the following comment: +// # ==================================================================== +// # Written by Andy Polyakov for the OpenSSL +// # project. The module is, however, dual licensed under OpenSSL and +// # CRYPTOGAMS licenses depending on where you obtain it. For further +// # details see http://www.openssl.org/~appro/cryptogams/. +// # ==================================================================== + +// Code for the perl script that generates the ppc64 assembler +// can be found in the cryptogams repository at the link below. It is based on +// the original from openssl. + +// https://github.com/dot-asm/cryptogams/commit/a60f5b50ed908e91 + +// The differences in this and the original implementation are +// due to the calling conventions and initialization of constants. + +// +build !gccgo,!purego + +#include "textflag.h" + +#define OUT R3 +#define INP R4 +#define LEN R5 +#define KEY R6 +#define CNT R7 +#define TMP R15 + +#define CONSTBASE R16 +#define BLOCKS R17 + +DATA consts<>+0x00(SB)/8, $0x3320646e61707865 +DATA consts<>+0x08(SB)/8, $0x6b20657479622d32 +DATA consts<>+0x10(SB)/8, $0x0000000000000001 +DATA consts<>+0x18(SB)/8, $0x0000000000000000 +DATA consts<>+0x20(SB)/8, $0x0000000000000004 +DATA consts<>+0x28(SB)/8, $0x0000000000000000 +DATA consts<>+0x30(SB)/8, $0x0a0b08090e0f0c0d +DATA consts<>+0x38(SB)/8, $0x0203000106070405 +DATA consts<>+0x40(SB)/8, $0x090a0b080d0e0f0c +DATA consts<>+0x48(SB)/8, $0x0102030005060704 +DATA consts<>+0x50(SB)/8, $0x6170786561707865 +DATA consts<>+0x58(SB)/8, $0x6170786561707865 +DATA consts<>+0x60(SB)/8, $0x3320646e3320646e +DATA consts<>+0x68(SB)/8, $0x3320646e3320646e +DATA consts<>+0x70(SB)/8, $0x79622d3279622d32 +DATA consts<>+0x78(SB)/8, $0x79622d3279622d32 +DATA consts<>+0x80(SB)/8, $0x6b2065746b206574 +DATA consts<>+0x88(SB)/8, $0x6b2065746b206574 +DATA consts<>+0x90(SB)/8, $0x0000000100000000 +DATA consts<>+0x98(SB)/8, $0x0000000300000002 +GLOBL consts<>(SB), RODATA, $0xa0 + +//func chaCha20_ctr32_vsx(out, inp *byte, len int, key *[8]uint32, counter *uint32) +TEXT ·chaCha20_ctr32_vsx(SB),NOSPLIT,$64-40 + MOVD out+0(FP), OUT + MOVD inp+8(FP), INP + MOVD len+16(FP), LEN + MOVD key+24(FP), KEY + MOVD counter+32(FP), CNT + + // Addressing for constants + MOVD $consts<>+0x00(SB), CONSTBASE + MOVD $16, R8 + MOVD $32, R9 + MOVD $48, R10 + MOVD $64, R11 + SRD $6, LEN, BLOCKS + // V16 + LXVW4X (CONSTBASE)(R0), VS48 + ADD $80,CONSTBASE + + // Load key into V17,V18 + LXVW4X (KEY)(R0), VS49 + LXVW4X (KEY)(R8), VS50 + + // Load CNT, NONCE into V19 + LXVW4X (CNT)(R0), VS51 + + // Clear V27 + VXOR V27, V27, V27 + + // V28 + LXVW4X (CONSTBASE)(R11), VS60 + + // splat slot from V19 -> V26 + VSPLTW $0, V19, V26 + + VSLDOI $4, V19, V27, V19 + VSLDOI $12, V27, V19, V19 + + VADDUWM V26, V28, V26 + + MOVD $10, R14 + MOVD R14, CTR + +loop_outer_vsx: + // V0, V1, V2, V3 + LXVW4X (R0)(CONSTBASE), VS32 + LXVW4X (R8)(CONSTBASE), VS33 + LXVW4X (R9)(CONSTBASE), VS34 + LXVW4X (R10)(CONSTBASE), VS35 + + // splat values from V17, V18 into V4-V11 + VSPLTW $0, V17, V4 + VSPLTW $1, V17, V5 + VSPLTW $2, V17, V6 + VSPLTW $3, V17, V7 + VSPLTW $0, V18, V8 + VSPLTW $1, V18, V9 + VSPLTW $2, V18, V10 + VSPLTW $3, V18, V11 + + // VOR + VOR V26, V26, V12 + + // splat values from V19 -> V13, V14, V15 + VSPLTW $1, V19, V13 + VSPLTW $2, V19, V14 + VSPLTW $3, V19, V15 + + // splat const values + VSPLTISW $-16, V27 + VSPLTISW $12, V28 + VSPLTISW $8, V29 + VSPLTISW $7, V30 + +loop_vsx: + VADDUWM V0, V4, V0 + VADDUWM V1, V5, V1 + VADDUWM V2, V6, V2 + VADDUWM V3, V7, V3 + + VXOR V12, V0, V12 + VXOR V13, V1, V13 + VXOR V14, V2, V14 + VXOR V15, V3, V15 + + VRLW V12, V27, V12 + VRLW V13, V27, V13 + VRLW V14, V27, V14 + VRLW V15, V27, V15 + + VADDUWM V8, V12, V8 + VADDUWM V9, V13, V9 + VADDUWM V10, V14, V10 + VADDUWM V11, V15, V11 + + VXOR V4, V8, V4 + VXOR V5, V9, V5 + VXOR V6, V10, V6 + VXOR V7, V11, V7 + + VRLW V4, V28, V4 + VRLW V5, V28, V5 + VRLW V6, V28, V6 + VRLW V7, V28, V7 + + VADDUWM V0, V4, V0 + VADDUWM V1, V5, V1 + VADDUWM V2, V6, V2 + VADDUWM V3, V7, V3 + + VXOR V12, V0, V12 + VXOR V13, V1, V13 + VXOR V14, V2, V14 + VXOR V15, V3, V15 + + VRLW V12, V29, V12 + VRLW V13, V29, V13 + VRLW V14, V29, V14 + VRLW V15, V29, V15 + + VADDUWM V8, V12, V8 + VADDUWM V9, V13, V9 + VADDUWM V10, V14, V10 + VADDUWM V11, V15, V11 + + VXOR V4, V8, V4 + VXOR V5, V9, V5 + VXOR V6, V10, V6 + VXOR V7, V11, V7 + + VRLW V4, V30, V4 + VRLW V5, V30, V5 + VRLW V6, V30, V6 + VRLW V7, V30, V7 + + VADDUWM V0, V5, V0 + VADDUWM V1, V6, V1 + VADDUWM V2, V7, V2 + VADDUWM V3, V4, V3 + + VXOR V15, V0, V15 + VXOR V12, V1, V12 + VXOR V13, V2, V13 + VXOR V14, V3, V14 + + VRLW V15, V27, V15 + VRLW V12, V27, V12 + VRLW V13, V27, V13 + VRLW V14, V27, V14 + + VADDUWM V10, V15, V10 + VADDUWM V11, V12, V11 + VADDUWM V8, V13, V8 + VADDUWM V9, V14, V9 + + VXOR V5, V10, V5 + VXOR V6, V11, V6 + VXOR V7, V8, V7 + VXOR V4, V9, V4 + + VRLW V5, V28, V5 + VRLW V6, V28, V6 + VRLW V7, V28, V7 + VRLW V4, V28, V4 + + VADDUWM V0, V5, V0 + VADDUWM V1, V6, V1 + VADDUWM V2, V7, V2 + VADDUWM V3, V4, V3 + + VXOR V15, V0, V15 + VXOR V12, V1, V12 + VXOR V13, V2, V13 + VXOR V14, V3, V14 + + VRLW V15, V29, V15 + VRLW V12, V29, V12 + VRLW V13, V29, V13 + VRLW V14, V29, V14 + + VADDUWM V10, V15, V10 + VADDUWM V11, V12, V11 + VADDUWM V8, V13, V8 + VADDUWM V9, V14, V9 + + VXOR V5, V10, V5 + VXOR V6, V11, V6 + VXOR V7, V8, V7 + VXOR V4, V9, V4 + + VRLW V5, V30, V5 + VRLW V6, V30, V6 + VRLW V7, V30, V7 + VRLW V4, V30, V4 + BC 16, LT, loop_vsx + + VADDUWM V12, V26, V12 + + WORD $0x13600F8C // VMRGEW V0, V1, V27 + WORD $0x13821F8C // VMRGEW V2, V3, V28 + + WORD $0x10000E8C // VMRGOW V0, V1, V0 + WORD $0x10421E8C // VMRGOW V2, V3, V2 + + WORD $0x13A42F8C // VMRGEW V4, V5, V29 + WORD $0x13C63F8C // VMRGEW V6, V7, V30 + + XXPERMDI VS32, VS34, $0, VS33 + XXPERMDI VS32, VS34, $3, VS35 + XXPERMDI VS59, VS60, $0, VS32 + XXPERMDI VS59, VS60, $3, VS34 + + WORD $0x10842E8C // VMRGOW V4, V5, V4 + WORD $0x10C63E8C // VMRGOW V6, V7, V6 + + WORD $0x13684F8C // VMRGEW V8, V9, V27 + WORD $0x138A5F8C // VMRGEW V10, V11, V28 + + XXPERMDI VS36, VS38, $0, VS37 + XXPERMDI VS36, VS38, $3, VS39 + XXPERMDI VS61, VS62, $0, VS36 + XXPERMDI VS61, VS62, $3, VS38 + + WORD $0x11084E8C // VMRGOW V8, V9, V8 + WORD $0x114A5E8C // VMRGOW V10, V11, V10 + + WORD $0x13AC6F8C // VMRGEW V12, V13, V29 + WORD $0x13CE7F8C // VMRGEW V14, V15, V30 + + XXPERMDI VS40, VS42, $0, VS41 + XXPERMDI VS40, VS42, $3, VS43 + XXPERMDI VS59, VS60, $0, VS40 + XXPERMDI VS59, VS60, $3, VS42 + + WORD $0x118C6E8C // VMRGOW V12, V13, V12 + WORD $0x11CE7E8C // VMRGOW V14, V15, V14 + + VSPLTISW $4, V27 + VADDUWM V26, V27, V26 + + XXPERMDI VS44, VS46, $0, VS45 + XXPERMDI VS44, VS46, $3, VS47 + XXPERMDI VS61, VS62, $0, VS44 + XXPERMDI VS61, VS62, $3, VS46 + + VADDUWM V0, V16, V0 + VADDUWM V4, V17, V4 + VADDUWM V8, V18, V8 + VADDUWM V12, V19, V12 + + CMPU LEN, $64 + BLT tail_vsx + + // Bottom of loop + LXVW4X (INP)(R0), VS59 + LXVW4X (INP)(R8), VS60 + LXVW4X (INP)(R9), VS61 + LXVW4X (INP)(R10), VS62 + + VXOR V27, V0, V27 + VXOR V28, V4, V28 + VXOR V29, V8, V29 + VXOR V30, V12, V30 + + STXVW4X VS59, (OUT)(R0) + STXVW4X VS60, (OUT)(R8) + ADD $64, INP + STXVW4X VS61, (OUT)(R9) + ADD $-64, LEN + STXVW4X VS62, (OUT)(R10) + ADD $64, OUT + BEQ done_vsx + + VADDUWM V1, V16, V0 + VADDUWM V5, V17, V4 + VADDUWM V9, V18, V8 + VADDUWM V13, V19, V12 + + CMPU LEN, $64 + BLT tail_vsx + + LXVW4X (INP)(R0), VS59 + LXVW4X (INP)(R8), VS60 + LXVW4X (INP)(R9), VS61 + LXVW4X (INP)(R10), VS62 + VXOR V27, V0, V27 + + VXOR V28, V4, V28 + VXOR V29, V8, V29 + VXOR V30, V12, V30 + + STXVW4X VS59, (OUT)(R0) + STXVW4X VS60, (OUT)(R8) + ADD $64, INP + STXVW4X VS61, (OUT)(R9) + ADD $-64, LEN + STXVW4X VS62, (OUT)(V10) + ADD $64, OUT + BEQ done_vsx + + VADDUWM V2, V16, V0 + VADDUWM V6, V17, V4 + VADDUWM V10, V18, V8 + VADDUWM V14, V19, V12 + + CMPU LEN, $64 + BLT tail_vsx + + LXVW4X (INP)(R0), VS59 + LXVW4X (INP)(R8), VS60 + LXVW4X (INP)(R9), VS61 + LXVW4X (INP)(R10), VS62 + + VXOR V27, V0, V27 + VXOR V28, V4, V28 + VXOR V29, V8, V29 + VXOR V30, V12, V30 + + STXVW4X VS59, (OUT)(R0) + STXVW4X VS60, (OUT)(R8) + ADD $64, INP + STXVW4X VS61, (OUT)(R9) + ADD $-64, LEN + STXVW4X VS62, (OUT)(R10) + ADD $64, OUT + BEQ done_vsx + + VADDUWM V3, V16, V0 + VADDUWM V7, V17, V4 + VADDUWM V11, V18, V8 + VADDUWM V15, V19, V12 + + CMPU LEN, $64 + BLT tail_vsx + + LXVW4X (INP)(R0), VS59 + LXVW4X (INP)(R8), VS60 + LXVW4X (INP)(R9), VS61 + LXVW4X (INP)(R10), VS62 + + VXOR V27, V0, V27 + VXOR V28, V4, V28 + VXOR V29, V8, V29 + VXOR V30, V12, V30 + + STXVW4X VS59, (OUT)(R0) + STXVW4X VS60, (OUT)(R8) + ADD $64, INP + STXVW4X VS61, (OUT)(R9) + ADD $-64, LEN + STXVW4X VS62, (OUT)(R10) + ADD $64, OUT + + MOVD $10, R14 + MOVD R14, CTR + BNE loop_outer_vsx + +done_vsx: + // Increment counter by number of 64 byte blocks + MOVD (CNT), R14 + ADD BLOCKS, R14 + MOVD R14, (CNT) + RET + +tail_vsx: + ADD $32, R1, R11 + MOVD LEN, CTR + + // Save values on stack to copy from + STXVW4X VS32, (R11)(R0) + STXVW4X VS36, (R11)(R8) + STXVW4X VS40, (R11)(R9) + STXVW4X VS44, (R11)(R10) + ADD $-1, R11, R12 + ADD $-1, INP + ADD $-1, OUT + +looptail_vsx: + // Copying the result to OUT + // in bytes. + MOVBZU 1(R12), KEY + MOVBZU 1(INP), TMP + XOR KEY, TMP, KEY + MOVBU KEY, 1(OUT) + BC 16, LT, looptail_vsx + + // Clear the stack values + STXVW4X VS48, (R11)(R0) + STXVW4X VS48, (R11)(R8) + STXVW4X VS48, (R11)(R9) + STXVW4X VS48, (R11)(R10) + BR done_vsx diff --git a/internal/crypto/chacha20/chacha_s390x.go b/internal/crypto/chacha20/chacha_s390x.go new file mode 100644 index 000000000..a9244bdf4 --- /dev/null +++ b/internal/crypto/chacha20/chacha_s390x.go @@ -0,0 +1,26 @@ +// Copyright 2018 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build !gccgo,!purego + +package chacha20 + +import "golang.org/x/sys/cpu" + +var haveAsm = cpu.S390X.HasVX + +const bufSize = 256 + +// xorKeyStreamVX is an assembly implementation of XORKeyStream. It must only +// be called when the vector facility is available. Implementation in asm_s390x.s. +//go:noescape +func xorKeyStreamVX(dst, src []byte, key *[8]uint32, nonce *[3]uint32, counter *uint32) + +func (c *Cipher) xorKeyStreamBlocks(dst, src []byte) { + if cpu.S390X.HasVX { + xorKeyStreamVX(dst, src, &c.key, &c.nonce, &c.counter) + } else { + c.xorKeyStreamBlocksGeneric(dst, src) + } +} diff --git a/internal/crypto/chacha20/chacha_s390x.s b/internal/crypto/chacha20/chacha_s390x.s new file mode 100644 index 000000000..89c658c41 --- /dev/null +++ b/internal/crypto/chacha20/chacha_s390x.s @@ -0,0 +1,224 @@ +// Copyright 2018 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build !gccgo,!purego + +#include "go_asm.h" +#include "textflag.h" + +// This is an implementation of the ChaCha20 encryption algorithm as +// specified in RFC 7539. It uses vector instructions to compute +// 4 keystream blocks in parallel (256 bytes) which are then XORed +// with the bytes in the input slice. + +GLOBL ·constants<>(SB), RODATA|NOPTR, $32 +// BSWAP: swap bytes in each 4-byte element +DATA ·constants<>+0x00(SB)/4, $0x03020100 +DATA ·constants<>+0x04(SB)/4, $0x07060504 +DATA ·constants<>+0x08(SB)/4, $0x0b0a0908 +DATA ·constants<>+0x0c(SB)/4, $0x0f0e0d0c +// J0: [j0, j1, j2, j3] +DATA ·constants<>+0x10(SB)/4, $0x61707865 +DATA ·constants<>+0x14(SB)/4, $0x3320646e +DATA ·constants<>+0x18(SB)/4, $0x79622d32 +DATA ·constants<>+0x1c(SB)/4, $0x6b206574 + +#define BSWAP V5 +#define J0 V6 +#define KEY0 V7 +#define KEY1 V8 +#define NONCE V9 +#define CTR V10 +#define M0 V11 +#define M1 V12 +#define M2 V13 +#define M3 V14 +#define INC V15 +#define X0 V16 +#define X1 V17 +#define X2 V18 +#define X3 V19 +#define X4 V20 +#define X5 V21 +#define X6 V22 +#define X7 V23 +#define X8 V24 +#define X9 V25 +#define X10 V26 +#define X11 V27 +#define X12 V28 +#define X13 V29 +#define X14 V30 +#define X15 V31 + +#define NUM_ROUNDS 20 + +#define ROUND4(a0, a1, a2, a3, b0, b1, b2, b3, c0, c1, c2, c3, d0, d1, d2, d3) \ + VAF a1, a0, a0 \ + VAF b1, b0, b0 \ + VAF c1, c0, c0 \ + VAF d1, d0, d0 \ + VX a0, a2, a2 \ + VX b0, b2, b2 \ + VX c0, c2, c2 \ + VX d0, d2, d2 \ + VERLLF $16, a2, a2 \ + VERLLF $16, b2, b2 \ + VERLLF $16, c2, c2 \ + VERLLF $16, d2, d2 \ + VAF a2, a3, a3 \ + VAF b2, b3, b3 \ + VAF c2, c3, c3 \ + VAF d2, d3, d3 \ + VX a3, a1, a1 \ + VX b3, b1, b1 \ + VX c3, c1, c1 \ + VX d3, d1, d1 \ + VERLLF $12, a1, a1 \ + VERLLF $12, b1, b1 \ + VERLLF $12, c1, c1 \ + VERLLF $12, d1, d1 \ + VAF a1, a0, a0 \ + VAF b1, b0, b0 \ + VAF c1, c0, c0 \ + VAF d1, d0, d0 \ + VX a0, a2, a2 \ + VX b0, b2, b2 \ + VX c0, c2, c2 \ + VX d0, d2, d2 \ + VERLLF $8, a2, a2 \ + VERLLF $8, b2, b2 \ + VERLLF $8, c2, c2 \ + VERLLF $8, d2, d2 \ + VAF a2, a3, a3 \ + VAF b2, b3, b3 \ + VAF c2, c3, c3 \ + VAF d2, d3, d3 \ + VX a3, a1, a1 \ + VX b3, b1, b1 \ + VX c3, c1, c1 \ + VX d3, d1, d1 \ + VERLLF $7, a1, a1 \ + VERLLF $7, b1, b1 \ + VERLLF $7, c1, c1 \ + VERLLF $7, d1, d1 + +#define PERMUTE(mask, v0, v1, v2, v3) \ + VPERM v0, v0, mask, v0 \ + VPERM v1, v1, mask, v1 \ + VPERM v2, v2, mask, v2 \ + VPERM v3, v3, mask, v3 + +#define ADDV(x, v0, v1, v2, v3) \ + VAF x, v0, v0 \ + VAF x, v1, v1 \ + VAF x, v2, v2 \ + VAF x, v3, v3 + +#define XORV(off, dst, src, v0, v1, v2, v3) \ + VLM off(src), M0, M3 \ + PERMUTE(BSWAP, v0, v1, v2, v3) \ + VX v0, M0, M0 \ + VX v1, M1, M1 \ + VX v2, M2, M2 \ + VX v3, M3, M3 \ + VSTM M0, M3, off(dst) + +#define SHUFFLE(a, b, c, d, t, u, v, w) \ + VMRHF a, c, t \ // t = {a[0], c[0], a[1], c[1]} + VMRHF b, d, u \ // u = {b[0], d[0], b[1], d[1]} + VMRLF a, c, v \ // v = {a[2], c[2], a[3], c[3]} + VMRLF b, d, w \ // w = {b[2], d[2], b[3], d[3]} + VMRHF t, u, a \ // a = {a[0], b[0], c[0], d[0]} + VMRLF t, u, b \ // b = {a[1], b[1], c[1], d[1]} + VMRHF v, w, c \ // c = {a[2], b[2], c[2], d[2]} + VMRLF v, w, d // d = {a[3], b[3], c[3], d[3]} + +// func xorKeyStreamVX(dst, src []byte, key *[8]uint32, nonce *[3]uint32, counter *uint32) +TEXT ·xorKeyStreamVX(SB), NOSPLIT, $0 + MOVD $·constants<>(SB), R1 + MOVD dst+0(FP), R2 // R2=&dst[0] + LMG src+24(FP), R3, R4 // R3=&src[0] R4=len(src) + MOVD key+48(FP), R5 // R5=key + MOVD nonce+56(FP), R6 // R6=nonce + MOVD counter+64(FP), R7 // R7=counter + + // load BSWAP and J0 + VLM (R1), BSWAP, J0 + + // setup + MOVD $95, R0 + VLM (R5), KEY0, KEY1 + VLL R0, (R6), NONCE + VZERO M0 + VLEIB $7, $32, M0 + VSRLB M0, NONCE, NONCE + + // initialize counter values + VLREPF (R7), CTR + VZERO INC + VLEIF $1, $1, INC + VLEIF $2, $2, INC + VLEIF $3, $3, INC + VAF INC, CTR, CTR + VREPIF $4, INC + +chacha: + VREPF $0, J0, X0 + VREPF $1, J0, X1 + VREPF $2, J0, X2 + VREPF $3, J0, X3 + VREPF $0, KEY0, X4 + VREPF $1, KEY0, X5 + VREPF $2, KEY0, X6 + VREPF $3, KEY0, X7 + VREPF $0, KEY1, X8 + VREPF $1, KEY1, X9 + VREPF $2, KEY1, X10 + VREPF $3, KEY1, X11 + VLR CTR, X12 + VREPF $1, NONCE, X13 + VREPF $2, NONCE, X14 + VREPF $3, NONCE, X15 + + MOVD $(NUM_ROUNDS/2), R1 + +loop: + ROUND4(X0, X4, X12, X8, X1, X5, X13, X9, X2, X6, X14, X10, X3, X7, X15, X11) + ROUND4(X0, X5, X15, X10, X1, X6, X12, X11, X2, X7, X13, X8, X3, X4, X14, X9) + + ADD $-1, R1 + BNE loop + + // decrement length + ADD $-256, R4 + + // rearrange vectors + SHUFFLE(X0, X1, X2, X3, M0, M1, M2, M3) + ADDV(J0, X0, X1, X2, X3) + SHUFFLE(X4, X5, X6, X7, M0, M1, M2, M3) + ADDV(KEY0, X4, X5, X6, X7) + SHUFFLE(X8, X9, X10, X11, M0, M1, M2, M3) + ADDV(KEY1, X8, X9, X10, X11) + VAF CTR, X12, X12 + SHUFFLE(X12, X13, X14, X15, M0, M1, M2, M3) + ADDV(NONCE, X12, X13, X14, X15) + + // increment counters + VAF INC, CTR, CTR + + // xor keystream with plaintext + XORV(0*64, R2, R3, X0, X4, X8, X12) + XORV(1*64, R2, R3, X1, X5, X9, X13) + XORV(2*64, R2, R3, X2, X6, X10, X14) + XORV(3*64, R2, R3, X3, X7, X11, X15) + + // increment pointers + MOVD $256(R2), R2 + MOVD $256(R3), R3 + + CMPBNE R4, $0, chacha + + VSTEF $0, CTR, (R7) + RET diff --git a/internal/crypto/chacha20/xor.go b/internal/crypto/chacha20/xor.go new file mode 100644 index 000000000..c2d04851e --- /dev/null +++ b/internal/crypto/chacha20/xor.go @@ -0,0 +1,42 @@ +// Copyright 2018 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found src the LICENSE file. + +package chacha20 + +import "runtime" + +// Platforms that have fast unaligned 32-bit little endian accesses. +const unaligned = runtime.GOARCH == "386" || + runtime.GOARCH == "amd64" || + runtime.GOARCH == "arm64" || + runtime.GOARCH == "ppc64le" || + runtime.GOARCH == "s390x" + +// addXor reads a little endian uint32 from src, XORs it with (a + b) and +// places the result in little endian byte order in dst. +func addXor(dst, src []byte, a, b uint32) { + _, _ = src[3], dst[3] // bounds check elimination hint + if unaligned { + // The compiler should optimize this code into + // 32-bit unaligned little endian loads and stores. + // TODO: delete once the compiler does a reliably + // good job with the generic code below. + // See issue #25111 for more details. + v := uint32(src[0]) + v |= uint32(src[1]) << 8 + v |= uint32(src[2]) << 16 + v |= uint32(src[3]) << 24 + v ^= a + b + dst[0] = byte(v) + dst[1] = byte(v >> 8) + dst[2] = byte(v >> 16) + dst[3] = byte(v >> 24) + } else { + a += b + dst[0] = src[0] ^ byte(a) + dst[1] = src[1] ^ byte(a>>8) + dst[2] = src[2] ^ byte(a>>16) + dst[3] = src[3] ^ byte(a>>24) + } +} diff --git a/internal/crypto/curve25519/curve25519.go b/internal/crypto/curve25519/curve25519.go new file mode 100644 index 000000000..4b9a655d1 --- /dev/null +++ b/internal/crypto/curve25519/curve25519.go @@ -0,0 +1,95 @@ +// Copyright 2019 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package curve25519 provides an implementation of the X25519 function, which +// performs scalar multiplication on the elliptic curve known as Curve25519. +// See RFC 7748. +package curve25519 // import "golang.org/x/crypto/curve25519" + +import ( + "crypto/subtle" + "fmt" +) + +// ScalarMult sets dst to the product scalar * point. +// +// Deprecated: when provided a low-order point, ScalarMult will set dst to all +// zeroes, irrespective of the scalar. Instead, use the X25519 function, which +// will return an error. +func ScalarMult(dst, scalar, point *[32]byte) { + scalarMult(dst, scalar, point) +} + +// ScalarBaseMult sets dst to the product scalar * base where base is the +// standard generator. +// +// It is recommended to use the X25519 function with Basepoint instead, as +// copying into fixed size arrays can lead to unexpected bugs. +func ScalarBaseMult(dst, scalar *[32]byte) { + ScalarMult(dst, scalar, &basePoint) +} + +const ( + // ScalarSize is the size of the scalar input to X25519. + ScalarSize = 32 + // PointSize is the size of the point input to X25519. + PointSize = 32 +) + +// Basepoint is the canonical Curve25519 generator. +var Basepoint []byte + +var basePoint = [32]byte{9, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0} + +func init() { Basepoint = basePoint[:] } + +func checkBasepoint() { + if subtle.ConstantTimeCompare(Basepoint, []byte{ + 0x09, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + }) != 1 { + panic("curve25519: global Basepoint value was modified") + } +} + +// X25519 returns the result of the scalar multiplication (scalar * point), +// according to RFC 7748, Section 5. scalar, point and the return value are +// slices of 32 bytes. +// +// scalar can be generated at random, for example with crypto/rand. point should +// be either Basepoint or the output of another X25519 call. +// +// If point is Basepoint (but not if it's a different slice with the same +// contents) a precomputed implementation might be used for performance. +func X25519(scalar, point []byte) ([]byte, error) { + // Outline the body of function, to let the allocation be inlined in the + // caller, and possibly avoid escaping to the heap. + var dst [32]byte + return x25519(&dst, scalar, point) +} + +func x25519(dst *[32]byte, scalar, point []byte) ([]byte, error) { + var in [32]byte + if l := len(scalar); l != 32 { + return nil, fmt.Errorf("bad scalar length: %d, expected %d", l, 32) + } + if l := len(point); l != 32 { + return nil, fmt.Errorf("bad point length: %d, expected %d", l, 32) + } + copy(in[:], scalar) + if &point[0] == &Basepoint[0] { + checkBasepoint() + ScalarBaseMult(dst, &in) + } else { + var base, zero [32]byte + copy(base[:], point) + ScalarMult(dst, &in, &base) + if subtle.ConstantTimeCompare(dst[:], zero[:]) == 1 { + return nil, fmt.Errorf("bad input point: low order point") + } + } + return dst[:], nil +} diff --git a/internal/crypto/curve25519/curve25519_amd64.go b/internal/crypto/curve25519/curve25519_amd64.go new file mode 100644 index 000000000..5120b779b --- /dev/null +++ b/internal/crypto/curve25519/curve25519_amd64.go @@ -0,0 +1,240 @@ +// Copyright 2012 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build amd64,!gccgo,!appengine,!purego + +package curve25519 + +// These functions are implemented in the .s files. The names of the functions +// in the rest of the file are also taken from the SUPERCOP sources to help +// people following along. + +//go:noescape + +func cswap(inout *[5]uint64, v uint64) + +//go:noescape + +func ladderstep(inout *[5][5]uint64) + +//go:noescape + +func freeze(inout *[5]uint64) + +//go:noescape + +func mul(dest, a, b *[5]uint64) + +//go:noescape + +func square(out, in *[5]uint64) + +// mladder uses a Montgomery ladder to calculate (xr/zr) *= s. +func mladder(xr, zr *[5]uint64, s *[32]byte) { + var work [5][5]uint64 + + work[0] = *xr + setint(&work[1], 1) + setint(&work[2], 0) + work[3] = *xr + setint(&work[4], 1) + + j := uint(6) + var prevbit byte + + for i := 31; i >= 0; i-- { + for j < 8 { + bit := ((*s)[i] >> j) & 1 + swap := bit ^ prevbit + prevbit = bit + cswap(&work[1], uint64(swap)) + ladderstep(&work) + j-- + } + j = 7 + } + + *xr = work[1] + *zr = work[2] +} + +func scalarMult(out, in, base *[32]byte) { + var e [32]byte + copy(e[:], (*in)[:]) + e[0] &= 248 + e[31] &= 127 + e[31] |= 64 + + var t, z [5]uint64 + unpack(&t, base) + mladder(&t, &z, &e) + invert(&z, &z) + mul(&t, &t, &z) + pack(out, &t) +} + +func setint(r *[5]uint64, v uint64) { + r[0] = v + r[1] = 0 + r[2] = 0 + r[3] = 0 + r[4] = 0 +} + +// unpack sets r = x where r consists of 5, 51-bit limbs in little-endian +// order. +func unpack(r *[5]uint64, x *[32]byte) { + r[0] = uint64(x[0]) | + uint64(x[1])<<8 | + uint64(x[2])<<16 | + uint64(x[3])<<24 | + uint64(x[4])<<32 | + uint64(x[5])<<40 | + uint64(x[6]&7)<<48 + + r[1] = uint64(x[6])>>3 | + uint64(x[7])<<5 | + uint64(x[8])<<13 | + uint64(x[9])<<21 | + uint64(x[10])<<29 | + uint64(x[11])<<37 | + uint64(x[12]&63)<<45 + + r[2] = uint64(x[12])>>6 | + uint64(x[13])<<2 | + uint64(x[14])<<10 | + uint64(x[15])<<18 | + uint64(x[16])<<26 | + uint64(x[17])<<34 | + uint64(x[18])<<42 | + uint64(x[19]&1)<<50 + + r[3] = uint64(x[19])>>1 | + uint64(x[20])<<7 | + uint64(x[21])<<15 | + uint64(x[22])<<23 | + uint64(x[23])<<31 | + uint64(x[24])<<39 | + uint64(x[25]&15)<<47 + + r[4] = uint64(x[25])>>4 | + uint64(x[26])<<4 | + uint64(x[27])<<12 | + uint64(x[28])<<20 | + uint64(x[29])<<28 | + uint64(x[30])<<36 | + uint64(x[31]&127)<<44 +} + +// pack sets out = x where out is the usual, little-endian form of the 5, +// 51-bit limbs in x. +func pack(out *[32]byte, x *[5]uint64) { + t := *x + freeze(&t) + + out[0] = byte(t[0]) + out[1] = byte(t[0] >> 8) + out[2] = byte(t[0] >> 16) + out[3] = byte(t[0] >> 24) + out[4] = byte(t[0] >> 32) + out[5] = byte(t[0] >> 40) + out[6] = byte(t[0] >> 48) + + out[6] ^= byte(t[1]<<3) & 0xf8 + out[7] = byte(t[1] >> 5) + out[8] = byte(t[1] >> 13) + out[9] = byte(t[1] >> 21) + out[10] = byte(t[1] >> 29) + out[11] = byte(t[1] >> 37) + out[12] = byte(t[1] >> 45) + + out[12] ^= byte(t[2]<<6) & 0xc0 + out[13] = byte(t[2] >> 2) + out[14] = byte(t[2] >> 10) + out[15] = byte(t[2] >> 18) + out[16] = byte(t[2] >> 26) + out[17] = byte(t[2] >> 34) + out[18] = byte(t[2] >> 42) + out[19] = byte(t[2] >> 50) + + out[19] ^= byte(t[3]<<1) & 0xfe + out[20] = byte(t[3] >> 7) + out[21] = byte(t[3] >> 15) + out[22] = byte(t[3] >> 23) + out[23] = byte(t[3] >> 31) + out[24] = byte(t[3] >> 39) + out[25] = byte(t[3] >> 47) + + out[25] ^= byte(t[4]<<4) & 0xf0 + out[26] = byte(t[4] >> 4) + out[27] = byte(t[4] >> 12) + out[28] = byte(t[4] >> 20) + out[29] = byte(t[4] >> 28) + out[30] = byte(t[4] >> 36) + out[31] = byte(t[4] >> 44) +} + +// invert calculates r = x^-1 mod p using Fermat's little theorem. +func invert(r *[5]uint64, x *[5]uint64) { + var z2, z9, z11, z2_5_0, z2_10_0, z2_20_0, z2_50_0, z2_100_0, t [5]uint64 + + square(&z2, x) /* 2 */ + square(&t, &z2) /* 4 */ + square(&t, &t) /* 8 */ + mul(&z9, &t, x) /* 9 */ + mul(&z11, &z9, &z2) /* 11 */ + square(&t, &z11) /* 22 */ + mul(&z2_5_0, &t, &z9) /* 2^5 - 2^0 = 31 */ + + square(&t, &z2_5_0) /* 2^6 - 2^1 */ + for i := 1; i < 5; i++ { /* 2^20 - 2^10 */ + square(&t, &t) + } + mul(&z2_10_0, &t, &z2_5_0) /* 2^10 - 2^0 */ + + square(&t, &z2_10_0) /* 2^11 - 2^1 */ + for i := 1; i < 10; i++ { /* 2^20 - 2^10 */ + square(&t, &t) + } + mul(&z2_20_0, &t, &z2_10_0) /* 2^20 - 2^0 */ + + square(&t, &z2_20_0) /* 2^21 - 2^1 */ + for i := 1; i < 20; i++ { /* 2^40 - 2^20 */ + square(&t, &t) + } + mul(&t, &t, &z2_20_0) /* 2^40 - 2^0 */ + + square(&t, &t) /* 2^41 - 2^1 */ + for i := 1; i < 10; i++ { /* 2^50 - 2^10 */ + square(&t, &t) + } + mul(&z2_50_0, &t, &z2_10_0) /* 2^50 - 2^0 */ + + square(&t, &z2_50_0) /* 2^51 - 2^1 */ + for i := 1; i < 50; i++ { /* 2^100 - 2^50 */ + square(&t, &t) + } + mul(&z2_100_0, &t, &z2_50_0) /* 2^100 - 2^0 */ + + square(&t, &z2_100_0) /* 2^101 - 2^1 */ + for i := 1; i < 100; i++ { /* 2^200 - 2^100 */ + square(&t, &t) + } + mul(&t, &t, &z2_100_0) /* 2^200 - 2^0 */ + + square(&t, &t) /* 2^201 - 2^1 */ + for i := 1; i < 50; i++ { /* 2^250 - 2^50 */ + square(&t, &t) + } + mul(&t, &t, &z2_50_0) /* 2^250 - 2^0 */ + + square(&t, &t) /* 2^251 - 2^1 */ + square(&t, &t) /* 2^252 - 2^2 */ + square(&t, &t) /* 2^253 - 2^3 */ + + square(&t, &t) /* 2^254 - 2^4 */ + + square(&t, &t) /* 2^255 - 2^5 */ + mul(r, &t, &z11) /* 2^255 - 21 */ +} diff --git a/internal/crypto/curve25519/curve25519_amd64.s b/internal/crypto/curve25519/curve25519_amd64.s new file mode 100644 index 000000000..0250c8885 --- /dev/null +++ b/internal/crypto/curve25519/curve25519_amd64.s @@ -0,0 +1,1793 @@ +// Copyright 2012 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// This code was translated into a form compatible with 6a from the public +// domain sources in SUPERCOP: https://bench.cr.yp.to/supercop.html + +// +build amd64,!gccgo,!appengine,!purego + +#define REDMASK51 0x0007FFFFFFFFFFFF + +// These constants cannot be encoded in non-MOVQ immediates. +// We access them directly from memory instead. + +DATA ·_121666_213(SB)/8, $996687872 +GLOBL ·_121666_213(SB), 8, $8 + +DATA ·_2P0(SB)/8, $0xFFFFFFFFFFFDA +GLOBL ·_2P0(SB), 8, $8 + +DATA ·_2P1234(SB)/8, $0xFFFFFFFFFFFFE +GLOBL ·_2P1234(SB), 8, $8 + +// func freeze(inout *[5]uint64) +TEXT ·freeze(SB),7,$0-8 + MOVQ inout+0(FP), DI + + MOVQ 0(DI),SI + MOVQ 8(DI),DX + MOVQ 16(DI),CX + MOVQ 24(DI),R8 + MOVQ 32(DI),R9 + MOVQ $REDMASK51,AX + MOVQ AX,R10 + SUBQ $18,R10 + MOVQ $3,R11 +REDUCELOOP: + MOVQ SI,R12 + SHRQ $51,R12 + ANDQ AX,SI + ADDQ R12,DX + MOVQ DX,R12 + SHRQ $51,R12 + ANDQ AX,DX + ADDQ R12,CX + MOVQ CX,R12 + SHRQ $51,R12 + ANDQ AX,CX + ADDQ R12,R8 + MOVQ R8,R12 + SHRQ $51,R12 + ANDQ AX,R8 + ADDQ R12,R9 + MOVQ R9,R12 + SHRQ $51,R12 + ANDQ AX,R9 + IMUL3Q $19,R12,R12 + ADDQ R12,SI + SUBQ $1,R11 + JA REDUCELOOP + MOVQ $1,R12 + CMPQ R10,SI + CMOVQLT R11,R12 + CMPQ AX,DX + CMOVQNE R11,R12 + CMPQ AX,CX + CMOVQNE R11,R12 + CMPQ AX,R8 + CMOVQNE R11,R12 + CMPQ AX,R9 + CMOVQNE R11,R12 + NEGQ R12 + ANDQ R12,AX + ANDQ R12,R10 + SUBQ R10,SI + SUBQ AX,DX + SUBQ AX,CX + SUBQ AX,R8 + SUBQ AX,R9 + MOVQ SI,0(DI) + MOVQ DX,8(DI) + MOVQ CX,16(DI) + MOVQ R8,24(DI) + MOVQ R9,32(DI) + RET + +// func ladderstep(inout *[5][5]uint64) +TEXT ·ladderstep(SB),0,$296-8 + MOVQ inout+0(FP),DI + + MOVQ 40(DI),SI + MOVQ 48(DI),DX + MOVQ 56(DI),CX + MOVQ 64(DI),R8 + MOVQ 72(DI),R9 + MOVQ SI,AX + MOVQ DX,R10 + MOVQ CX,R11 + MOVQ R8,R12 + MOVQ R9,R13 + ADDQ ·_2P0(SB),AX + ADDQ ·_2P1234(SB),R10 + ADDQ ·_2P1234(SB),R11 + ADDQ ·_2P1234(SB),R12 + ADDQ ·_2P1234(SB),R13 + ADDQ 80(DI),SI + ADDQ 88(DI),DX + ADDQ 96(DI),CX + ADDQ 104(DI),R8 + ADDQ 112(DI),R9 + SUBQ 80(DI),AX + SUBQ 88(DI),R10 + SUBQ 96(DI),R11 + SUBQ 104(DI),R12 + SUBQ 112(DI),R13 + MOVQ SI,0(SP) + MOVQ DX,8(SP) + MOVQ CX,16(SP) + MOVQ R8,24(SP) + MOVQ R9,32(SP) + MOVQ AX,40(SP) + MOVQ R10,48(SP) + MOVQ R11,56(SP) + MOVQ R12,64(SP) + MOVQ R13,72(SP) + MOVQ 40(SP),AX + MULQ 40(SP) + MOVQ AX,SI + MOVQ DX,CX + MOVQ 40(SP),AX + SHLQ $1,AX + MULQ 48(SP) + MOVQ AX,R8 + MOVQ DX,R9 + MOVQ 40(SP),AX + SHLQ $1,AX + MULQ 56(SP) + MOVQ AX,R10 + MOVQ DX,R11 + MOVQ 40(SP),AX + SHLQ $1,AX + MULQ 64(SP) + MOVQ AX,R12 + MOVQ DX,R13 + MOVQ 40(SP),AX + SHLQ $1,AX + MULQ 72(SP) + MOVQ AX,R14 + MOVQ DX,R15 + MOVQ 48(SP),AX + MULQ 48(SP) + ADDQ AX,R10 + ADCQ DX,R11 + MOVQ 48(SP),AX + SHLQ $1,AX + MULQ 56(SP) + ADDQ AX,R12 + ADCQ DX,R13 + MOVQ 48(SP),AX + SHLQ $1,AX + MULQ 64(SP) + ADDQ AX,R14 + ADCQ DX,R15 + MOVQ 48(SP),DX + IMUL3Q $38,DX,AX + MULQ 72(SP) + ADDQ AX,SI + ADCQ DX,CX + MOVQ 56(SP),AX + MULQ 56(SP) + ADDQ AX,R14 + ADCQ DX,R15 + MOVQ 56(SP),DX + IMUL3Q $38,DX,AX + MULQ 64(SP) + ADDQ AX,SI + ADCQ DX,CX + MOVQ 56(SP),DX + IMUL3Q $38,DX,AX + MULQ 72(SP) + ADDQ AX,R8 + ADCQ DX,R9 + MOVQ 64(SP),DX + IMUL3Q $19,DX,AX + MULQ 64(SP) + ADDQ AX,R8 + ADCQ DX,R9 + MOVQ 64(SP),DX + IMUL3Q $38,DX,AX + MULQ 72(SP) + ADDQ AX,R10 + ADCQ DX,R11 + MOVQ 72(SP),DX + IMUL3Q $19,DX,AX + MULQ 72(SP) + ADDQ AX,R12 + ADCQ DX,R13 + MOVQ $REDMASK51,DX + SHLQ $13,SI,CX + ANDQ DX,SI + SHLQ $13,R8,R9 + ANDQ DX,R8 + ADDQ CX,R8 + SHLQ $13,R10,R11 + ANDQ DX,R10 + ADDQ R9,R10 + SHLQ $13,R12,R13 + ANDQ DX,R12 + ADDQ R11,R12 + SHLQ $13,R14,R15 + ANDQ DX,R14 + ADDQ R13,R14 + IMUL3Q $19,R15,CX + ADDQ CX,SI + MOVQ SI,CX + SHRQ $51,CX + ADDQ R8,CX + ANDQ DX,SI + MOVQ CX,R8 + SHRQ $51,CX + ADDQ R10,CX + ANDQ DX,R8 + MOVQ CX,R9 + SHRQ $51,CX + ADDQ R12,CX + ANDQ DX,R9 + MOVQ CX,AX + SHRQ $51,CX + ADDQ R14,CX + ANDQ DX,AX + MOVQ CX,R10 + SHRQ $51,CX + IMUL3Q $19,CX,CX + ADDQ CX,SI + ANDQ DX,R10 + MOVQ SI,80(SP) + MOVQ R8,88(SP) + MOVQ R9,96(SP) + MOVQ AX,104(SP) + MOVQ R10,112(SP) + MOVQ 0(SP),AX + MULQ 0(SP) + MOVQ AX,SI + MOVQ DX,CX + MOVQ 0(SP),AX + SHLQ $1,AX + MULQ 8(SP) + MOVQ AX,R8 + MOVQ DX,R9 + MOVQ 0(SP),AX + SHLQ $1,AX + MULQ 16(SP) + MOVQ AX,R10 + MOVQ DX,R11 + MOVQ 0(SP),AX + SHLQ $1,AX + MULQ 24(SP) + MOVQ AX,R12 + MOVQ DX,R13 + MOVQ 0(SP),AX + SHLQ $1,AX + MULQ 32(SP) + MOVQ AX,R14 + MOVQ DX,R15 + MOVQ 8(SP),AX + MULQ 8(SP) + ADDQ AX,R10 + ADCQ DX,R11 + MOVQ 8(SP),AX + SHLQ $1,AX + MULQ 16(SP) + ADDQ AX,R12 + ADCQ DX,R13 + MOVQ 8(SP),AX + SHLQ $1,AX + MULQ 24(SP) + ADDQ AX,R14 + ADCQ DX,R15 + MOVQ 8(SP),DX + IMUL3Q $38,DX,AX + MULQ 32(SP) + ADDQ AX,SI + ADCQ DX,CX + MOVQ 16(SP),AX + MULQ 16(SP) + ADDQ AX,R14 + ADCQ DX,R15 + MOVQ 16(SP),DX + IMUL3Q $38,DX,AX + MULQ 24(SP) + ADDQ AX,SI + ADCQ DX,CX + MOVQ 16(SP),DX + IMUL3Q $38,DX,AX + MULQ 32(SP) + ADDQ AX,R8 + ADCQ DX,R9 + MOVQ 24(SP),DX + IMUL3Q $19,DX,AX + MULQ 24(SP) + ADDQ AX,R8 + ADCQ DX,R9 + MOVQ 24(SP),DX + IMUL3Q $38,DX,AX + MULQ 32(SP) + ADDQ AX,R10 + ADCQ DX,R11 + MOVQ 32(SP),DX + IMUL3Q $19,DX,AX + MULQ 32(SP) + ADDQ AX,R12 + ADCQ DX,R13 + MOVQ $REDMASK51,DX + SHLQ $13,SI,CX + ANDQ DX,SI + SHLQ $13,R8,R9 + ANDQ DX,R8 + ADDQ CX,R8 + SHLQ $13,R10,R11 + ANDQ DX,R10 + ADDQ R9,R10 + SHLQ $13,R12,R13 + ANDQ DX,R12 + ADDQ R11,R12 + SHLQ $13,R14,R15 + ANDQ DX,R14 + ADDQ R13,R14 + IMUL3Q $19,R15,CX + ADDQ CX,SI + MOVQ SI,CX + SHRQ $51,CX + ADDQ R8,CX + ANDQ DX,SI + MOVQ CX,R8 + SHRQ $51,CX + ADDQ R10,CX + ANDQ DX,R8 + MOVQ CX,R9 + SHRQ $51,CX + ADDQ R12,CX + ANDQ DX,R9 + MOVQ CX,AX + SHRQ $51,CX + ADDQ R14,CX + ANDQ DX,AX + MOVQ CX,R10 + SHRQ $51,CX + IMUL3Q $19,CX,CX + ADDQ CX,SI + ANDQ DX,R10 + MOVQ SI,120(SP) + MOVQ R8,128(SP) + MOVQ R9,136(SP) + MOVQ AX,144(SP) + MOVQ R10,152(SP) + MOVQ SI,SI + MOVQ R8,DX + MOVQ R9,CX + MOVQ AX,R8 + MOVQ R10,R9 + ADDQ ·_2P0(SB),SI + ADDQ ·_2P1234(SB),DX + ADDQ ·_2P1234(SB),CX + ADDQ ·_2P1234(SB),R8 + ADDQ ·_2P1234(SB),R9 + SUBQ 80(SP),SI + SUBQ 88(SP),DX + SUBQ 96(SP),CX + SUBQ 104(SP),R8 + SUBQ 112(SP),R9 + MOVQ SI,160(SP) + MOVQ DX,168(SP) + MOVQ CX,176(SP) + MOVQ R8,184(SP) + MOVQ R9,192(SP) + MOVQ 120(DI),SI + MOVQ 128(DI),DX + MOVQ 136(DI),CX + MOVQ 144(DI),R8 + MOVQ 152(DI),R9 + MOVQ SI,AX + MOVQ DX,R10 + MOVQ CX,R11 + MOVQ R8,R12 + MOVQ R9,R13 + ADDQ ·_2P0(SB),AX + ADDQ ·_2P1234(SB),R10 + ADDQ ·_2P1234(SB),R11 + ADDQ ·_2P1234(SB),R12 + ADDQ ·_2P1234(SB),R13 + ADDQ 160(DI),SI + ADDQ 168(DI),DX + ADDQ 176(DI),CX + ADDQ 184(DI),R8 + ADDQ 192(DI),R9 + SUBQ 160(DI),AX + SUBQ 168(DI),R10 + SUBQ 176(DI),R11 + SUBQ 184(DI),R12 + SUBQ 192(DI),R13 + MOVQ SI,200(SP) + MOVQ DX,208(SP) + MOVQ CX,216(SP) + MOVQ R8,224(SP) + MOVQ R9,232(SP) + MOVQ AX,240(SP) + MOVQ R10,248(SP) + MOVQ R11,256(SP) + MOVQ R12,264(SP) + MOVQ R13,272(SP) + MOVQ 224(SP),SI + IMUL3Q $19,SI,AX + MOVQ AX,280(SP) + MULQ 56(SP) + MOVQ AX,SI + MOVQ DX,CX + MOVQ 232(SP),DX + IMUL3Q $19,DX,AX + MOVQ AX,288(SP) + MULQ 48(SP) + ADDQ AX,SI + ADCQ DX,CX + MOVQ 200(SP),AX + MULQ 40(SP) + ADDQ AX,SI + ADCQ DX,CX + MOVQ 200(SP),AX + MULQ 48(SP) + MOVQ AX,R8 + MOVQ DX,R9 + MOVQ 200(SP),AX + MULQ 56(SP) + MOVQ AX,R10 + MOVQ DX,R11 + MOVQ 200(SP),AX + MULQ 64(SP) + MOVQ AX,R12 + MOVQ DX,R13 + MOVQ 200(SP),AX + MULQ 72(SP) + MOVQ AX,R14 + MOVQ DX,R15 + MOVQ 208(SP),AX + MULQ 40(SP) + ADDQ AX,R8 + ADCQ DX,R9 + MOVQ 208(SP),AX + MULQ 48(SP) + ADDQ AX,R10 + ADCQ DX,R11 + MOVQ 208(SP),AX + MULQ 56(SP) + ADDQ AX,R12 + ADCQ DX,R13 + MOVQ 208(SP),AX + MULQ 64(SP) + ADDQ AX,R14 + ADCQ DX,R15 + MOVQ 208(SP),DX + IMUL3Q $19,DX,AX + MULQ 72(SP) + ADDQ AX,SI + ADCQ DX,CX + MOVQ 216(SP),AX + MULQ 40(SP) + ADDQ AX,R10 + ADCQ DX,R11 + MOVQ 216(SP),AX + MULQ 48(SP) + ADDQ AX,R12 + ADCQ DX,R13 + MOVQ 216(SP),AX + MULQ 56(SP) + ADDQ AX,R14 + ADCQ DX,R15 + MOVQ 216(SP),DX + IMUL3Q $19,DX,AX + MULQ 64(SP) + ADDQ AX,SI + ADCQ DX,CX + MOVQ 216(SP),DX + IMUL3Q $19,DX,AX + MULQ 72(SP) + ADDQ AX,R8 + ADCQ DX,R9 + MOVQ 224(SP),AX + MULQ 40(SP) + ADDQ AX,R12 + ADCQ DX,R13 + MOVQ 224(SP),AX + MULQ 48(SP) + ADDQ AX,R14 + ADCQ DX,R15 + MOVQ 280(SP),AX + MULQ 64(SP) + ADDQ AX,R8 + ADCQ DX,R9 + MOVQ 280(SP),AX + MULQ 72(SP) + ADDQ AX,R10 + ADCQ DX,R11 + MOVQ 232(SP),AX + MULQ 40(SP) + ADDQ AX,R14 + ADCQ DX,R15 + MOVQ 288(SP),AX + MULQ 56(SP) + ADDQ AX,R8 + ADCQ DX,R9 + MOVQ 288(SP),AX + MULQ 64(SP) + ADDQ AX,R10 + ADCQ DX,R11 + MOVQ 288(SP),AX + MULQ 72(SP) + ADDQ AX,R12 + ADCQ DX,R13 + MOVQ $REDMASK51,DX + SHLQ $13,SI,CX + ANDQ DX,SI + SHLQ $13,R8,R9 + ANDQ DX,R8 + ADDQ CX,R8 + SHLQ $13,R10,R11 + ANDQ DX,R10 + ADDQ R9,R10 + SHLQ $13,R12,R13 + ANDQ DX,R12 + ADDQ R11,R12 + SHLQ $13,R14,R15 + ANDQ DX,R14 + ADDQ R13,R14 + IMUL3Q $19,R15,CX + ADDQ CX,SI + MOVQ SI,CX + SHRQ $51,CX + ADDQ R8,CX + MOVQ CX,R8 + SHRQ $51,CX + ANDQ DX,SI + ADDQ R10,CX + MOVQ CX,R9 + SHRQ $51,CX + ANDQ DX,R8 + ADDQ R12,CX + MOVQ CX,AX + SHRQ $51,CX + ANDQ DX,R9 + ADDQ R14,CX + MOVQ CX,R10 + SHRQ $51,CX + ANDQ DX,AX + IMUL3Q $19,CX,CX + ADDQ CX,SI + ANDQ DX,R10 + MOVQ SI,40(SP) + MOVQ R8,48(SP) + MOVQ R9,56(SP) + MOVQ AX,64(SP) + MOVQ R10,72(SP) + MOVQ 264(SP),SI + IMUL3Q $19,SI,AX + MOVQ AX,200(SP) + MULQ 16(SP) + MOVQ AX,SI + MOVQ DX,CX + MOVQ 272(SP),DX + IMUL3Q $19,DX,AX + MOVQ AX,208(SP) + MULQ 8(SP) + ADDQ AX,SI + ADCQ DX,CX + MOVQ 240(SP),AX + MULQ 0(SP) + ADDQ AX,SI + ADCQ DX,CX + MOVQ 240(SP),AX + MULQ 8(SP) + MOVQ AX,R8 + MOVQ DX,R9 + MOVQ 240(SP),AX + MULQ 16(SP) + MOVQ AX,R10 + MOVQ DX,R11 + MOVQ 240(SP),AX + MULQ 24(SP) + MOVQ AX,R12 + MOVQ DX,R13 + MOVQ 240(SP),AX + MULQ 32(SP) + MOVQ AX,R14 + MOVQ DX,R15 + MOVQ 248(SP),AX + MULQ 0(SP) + ADDQ AX,R8 + ADCQ DX,R9 + MOVQ 248(SP),AX + MULQ 8(SP) + ADDQ AX,R10 + ADCQ DX,R11 + MOVQ 248(SP),AX + MULQ 16(SP) + ADDQ AX,R12 + ADCQ DX,R13 + MOVQ 248(SP),AX + MULQ 24(SP) + ADDQ AX,R14 + ADCQ DX,R15 + MOVQ 248(SP),DX + IMUL3Q $19,DX,AX + MULQ 32(SP) + ADDQ AX,SI + ADCQ DX,CX + MOVQ 256(SP),AX + MULQ 0(SP) + ADDQ AX,R10 + ADCQ DX,R11 + MOVQ 256(SP),AX + MULQ 8(SP) + ADDQ AX,R12 + ADCQ DX,R13 + MOVQ 256(SP),AX + MULQ 16(SP) + ADDQ AX,R14 + ADCQ DX,R15 + MOVQ 256(SP),DX + IMUL3Q $19,DX,AX + MULQ 24(SP) + ADDQ AX,SI + ADCQ DX,CX + MOVQ 256(SP),DX + IMUL3Q $19,DX,AX + MULQ 32(SP) + ADDQ AX,R8 + ADCQ DX,R9 + MOVQ 264(SP),AX + MULQ 0(SP) + ADDQ AX,R12 + ADCQ DX,R13 + MOVQ 264(SP),AX + MULQ 8(SP) + ADDQ AX,R14 + ADCQ DX,R15 + MOVQ 200(SP),AX + MULQ 24(SP) + ADDQ AX,R8 + ADCQ DX,R9 + MOVQ 200(SP),AX + MULQ 32(SP) + ADDQ AX,R10 + ADCQ DX,R11 + MOVQ 272(SP),AX + MULQ 0(SP) + ADDQ AX,R14 + ADCQ DX,R15 + MOVQ 208(SP),AX + MULQ 16(SP) + ADDQ AX,R8 + ADCQ DX,R9 + MOVQ 208(SP),AX + MULQ 24(SP) + ADDQ AX,R10 + ADCQ DX,R11 + MOVQ 208(SP),AX + MULQ 32(SP) + ADDQ AX,R12 + ADCQ DX,R13 + MOVQ $REDMASK51,DX + SHLQ $13,SI,CX + ANDQ DX,SI + SHLQ $13,R8,R9 + ANDQ DX,R8 + ADDQ CX,R8 + SHLQ $13,R10,R11 + ANDQ DX,R10 + ADDQ R9,R10 + SHLQ $13,R12,R13 + ANDQ DX,R12 + ADDQ R11,R12 + SHLQ $13,R14,R15 + ANDQ DX,R14 + ADDQ R13,R14 + IMUL3Q $19,R15,CX + ADDQ CX,SI + MOVQ SI,CX + SHRQ $51,CX + ADDQ R8,CX + MOVQ CX,R8 + SHRQ $51,CX + ANDQ DX,SI + ADDQ R10,CX + MOVQ CX,R9 + SHRQ $51,CX + ANDQ DX,R8 + ADDQ R12,CX + MOVQ CX,AX + SHRQ $51,CX + ANDQ DX,R9 + ADDQ R14,CX + MOVQ CX,R10 + SHRQ $51,CX + ANDQ DX,AX + IMUL3Q $19,CX,CX + ADDQ CX,SI + ANDQ DX,R10 + MOVQ SI,DX + MOVQ R8,CX + MOVQ R9,R11 + MOVQ AX,R12 + MOVQ R10,R13 + ADDQ ·_2P0(SB),DX + ADDQ ·_2P1234(SB),CX + ADDQ ·_2P1234(SB),R11 + ADDQ ·_2P1234(SB),R12 + ADDQ ·_2P1234(SB),R13 + ADDQ 40(SP),SI + ADDQ 48(SP),R8 + ADDQ 56(SP),R9 + ADDQ 64(SP),AX + ADDQ 72(SP),R10 + SUBQ 40(SP),DX + SUBQ 48(SP),CX + SUBQ 56(SP),R11 + SUBQ 64(SP),R12 + SUBQ 72(SP),R13 + MOVQ SI,120(DI) + MOVQ R8,128(DI) + MOVQ R9,136(DI) + MOVQ AX,144(DI) + MOVQ R10,152(DI) + MOVQ DX,160(DI) + MOVQ CX,168(DI) + MOVQ R11,176(DI) + MOVQ R12,184(DI) + MOVQ R13,192(DI) + MOVQ 120(DI),AX + MULQ 120(DI) + MOVQ AX,SI + MOVQ DX,CX + MOVQ 120(DI),AX + SHLQ $1,AX + MULQ 128(DI) + MOVQ AX,R8 + MOVQ DX,R9 + MOVQ 120(DI),AX + SHLQ $1,AX + MULQ 136(DI) + MOVQ AX,R10 + MOVQ DX,R11 + MOVQ 120(DI),AX + SHLQ $1,AX + MULQ 144(DI) + MOVQ AX,R12 + MOVQ DX,R13 + MOVQ 120(DI),AX + SHLQ $1,AX + MULQ 152(DI) + MOVQ AX,R14 + MOVQ DX,R15 + MOVQ 128(DI),AX + MULQ 128(DI) + ADDQ AX,R10 + ADCQ DX,R11 + MOVQ 128(DI),AX + SHLQ $1,AX + MULQ 136(DI) + ADDQ AX,R12 + ADCQ DX,R13 + MOVQ 128(DI),AX + SHLQ $1,AX + MULQ 144(DI) + ADDQ AX,R14 + ADCQ DX,R15 + MOVQ 128(DI),DX + IMUL3Q $38,DX,AX + MULQ 152(DI) + ADDQ AX,SI + ADCQ DX,CX + MOVQ 136(DI),AX + MULQ 136(DI) + ADDQ AX,R14 + ADCQ DX,R15 + MOVQ 136(DI),DX + IMUL3Q $38,DX,AX + MULQ 144(DI) + ADDQ AX,SI + ADCQ DX,CX + MOVQ 136(DI),DX + IMUL3Q $38,DX,AX + MULQ 152(DI) + ADDQ AX,R8 + ADCQ DX,R9 + MOVQ 144(DI),DX + IMUL3Q $19,DX,AX + MULQ 144(DI) + ADDQ AX,R8 + ADCQ DX,R9 + MOVQ 144(DI),DX + IMUL3Q $38,DX,AX + MULQ 152(DI) + ADDQ AX,R10 + ADCQ DX,R11 + MOVQ 152(DI),DX + IMUL3Q $19,DX,AX + MULQ 152(DI) + ADDQ AX,R12 + ADCQ DX,R13 + MOVQ $REDMASK51,DX + SHLQ $13,SI,CX + ANDQ DX,SI + SHLQ $13,R8,R9 + ANDQ DX,R8 + ADDQ CX,R8 + SHLQ $13,R10,R11 + ANDQ DX,R10 + ADDQ R9,R10 + SHLQ $13,R12,R13 + ANDQ DX,R12 + ADDQ R11,R12 + SHLQ $13,R14,R15 + ANDQ DX,R14 + ADDQ R13,R14 + IMUL3Q $19,R15,CX + ADDQ CX,SI + MOVQ SI,CX + SHRQ $51,CX + ADDQ R8,CX + ANDQ DX,SI + MOVQ CX,R8 + SHRQ $51,CX + ADDQ R10,CX + ANDQ DX,R8 + MOVQ CX,R9 + SHRQ $51,CX + ADDQ R12,CX + ANDQ DX,R9 + MOVQ CX,AX + SHRQ $51,CX + ADDQ R14,CX + ANDQ DX,AX + MOVQ CX,R10 + SHRQ $51,CX + IMUL3Q $19,CX,CX + ADDQ CX,SI + ANDQ DX,R10 + MOVQ SI,120(DI) + MOVQ R8,128(DI) + MOVQ R9,136(DI) + MOVQ AX,144(DI) + MOVQ R10,152(DI) + MOVQ 160(DI),AX + MULQ 160(DI) + MOVQ AX,SI + MOVQ DX,CX + MOVQ 160(DI),AX + SHLQ $1,AX + MULQ 168(DI) + MOVQ AX,R8 + MOVQ DX,R9 + MOVQ 160(DI),AX + SHLQ $1,AX + MULQ 176(DI) + MOVQ AX,R10 + MOVQ DX,R11 + MOVQ 160(DI),AX + SHLQ $1,AX + MULQ 184(DI) + MOVQ AX,R12 + MOVQ DX,R13 + MOVQ 160(DI),AX + SHLQ $1,AX + MULQ 192(DI) + MOVQ AX,R14 + MOVQ DX,R15 + MOVQ 168(DI),AX + MULQ 168(DI) + ADDQ AX,R10 + ADCQ DX,R11 + MOVQ 168(DI),AX + SHLQ $1,AX + MULQ 176(DI) + ADDQ AX,R12 + ADCQ DX,R13 + MOVQ 168(DI),AX + SHLQ $1,AX + MULQ 184(DI) + ADDQ AX,R14 + ADCQ DX,R15 + MOVQ 168(DI),DX + IMUL3Q $38,DX,AX + MULQ 192(DI) + ADDQ AX,SI + ADCQ DX,CX + MOVQ 176(DI),AX + MULQ 176(DI) + ADDQ AX,R14 + ADCQ DX,R15 + MOVQ 176(DI),DX + IMUL3Q $38,DX,AX + MULQ 184(DI) + ADDQ AX,SI + ADCQ DX,CX + MOVQ 176(DI),DX + IMUL3Q $38,DX,AX + MULQ 192(DI) + ADDQ AX,R8 + ADCQ DX,R9 + MOVQ 184(DI),DX + IMUL3Q $19,DX,AX + MULQ 184(DI) + ADDQ AX,R8 + ADCQ DX,R9 + MOVQ 184(DI),DX + IMUL3Q $38,DX,AX + MULQ 192(DI) + ADDQ AX,R10 + ADCQ DX,R11 + MOVQ 192(DI),DX + IMUL3Q $19,DX,AX + MULQ 192(DI) + ADDQ AX,R12 + ADCQ DX,R13 + MOVQ $REDMASK51,DX + SHLQ $13,SI,CX + ANDQ DX,SI + SHLQ $13,R8,R9 + ANDQ DX,R8 + ADDQ CX,R8 + SHLQ $13,R10,R11 + ANDQ DX,R10 + ADDQ R9,R10 + SHLQ $13,R12,R13 + ANDQ DX,R12 + ADDQ R11,R12 + SHLQ $13,R14,R15 + ANDQ DX,R14 + ADDQ R13,R14 + IMUL3Q $19,R15,CX + ADDQ CX,SI + MOVQ SI,CX + SHRQ $51,CX + ADDQ R8,CX + ANDQ DX,SI + MOVQ CX,R8 + SHRQ $51,CX + ADDQ R10,CX + ANDQ DX,R8 + MOVQ CX,R9 + SHRQ $51,CX + ADDQ R12,CX + ANDQ DX,R9 + MOVQ CX,AX + SHRQ $51,CX + ADDQ R14,CX + ANDQ DX,AX + MOVQ CX,R10 + SHRQ $51,CX + IMUL3Q $19,CX,CX + ADDQ CX,SI + ANDQ DX,R10 + MOVQ SI,160(DI) + MOVQ R8,168(DI) + MOVQ R9,176(DI) + MOVQ AX,184(DI) + MOVQ R10,192(DI) + MOVQ 184(DI),SI + IMUL3Q $19,SI,AX + MOVQ AX,0(SP) + MULQ 16(DI) + MOVQ AX,SI + MOVQ DX,CX + MOVQ 192(DI),DX + IMUL3Q $19,DX,AX + MOVQ AX,8(SP) + MULQ 8(DI) + ADDQ AX,SI + ADCQ DX,CX + MOVQ 160(DI),AX + MULQ 0(DI) + ADDQ AX,SI + ADCQ DX,CX + MOVQ 160(DI),AX + MULQ 8(DI) + MOVQ AX,R8 + MOVQ DX,R9 + MOVQ 160(DI),AX + MULQ 16(DI) + MOVQ AX,R10 + MOVQ DX,R11 + MOVQ 160(DI),AX + MULQ 24(DI) + MOVQ AX,R12 + MOVQ DX,R13 + MOVQ 160(DI),AX + MULQ 32(DI) + MOVQ AX,R14 + MOVQ DX,R15 + MOVQ 168(DI),AX + MULQ 0(DI) + ADDQ AX,R8 + ADCQ DX,R9 + MOVQ 168(DI),AX + MULQ 8(DI) + ADDQ AX,R10 + ADCQ DX,R11 + MOVQ 168(DI),AX + MULQ 16(DI) + ADDQ AX,R12 + ADCQ DX,R13 + MOVQ 168(DI),AX + MULQ 24(DI) + ADDQ AX,R14 + ADCQ DX,R15 + MOVQ 168(DI),DX + IMUL3Q $19,DX,AX + MULQ 32(DI) + ADDQ AX,SI + ADCQ DX,CX + MOVQ 176(DI),AX + MULQ 0(DI) + ADDQ AX,R10 + ADCQ DX,R11 + MOVQ 176(DI),AX + MULQ 8(DI) + ADDQ AX,R12 + ADCQ DX,R13 + MOVQ 176(DI),AX + MULQ 16(DI) + ADDQ AX,R14 + ADCQ DX,R15 + MOVQ 176(DI),DX + IMUL3Q $19,DX,AX + MULQ 24(DI) + ADDQ AX,SI + ADCQ DX,CX + MOVQ 176(DI),DX + IMUL3Q $19,DX,AX + MULQ 32(DI) + ADDQ AX,R8 + ADCQ DX,R9 + MOVQ 184(DI),AX + MULQ 0(DI) + ADDQ AX,R12 + ADCQ DX,R13 + MOVQ 184(DI),AX + MULQ 8(DI) + ADDQ AX,R14 + ADCQ DX,R15 + MOVQ 0(SP),AX + MULQ 24(DI) + ADDQ AX,R8 + ADCQ DX,R9 + MOVQ 0(SP),AX + MULQ 32(DI) + ADDQ AX,R10 + ADCQ DX,R11 + MOVQ 192(DI),AX + MULQ 0(DI) + ADDQ AX,R14 + ADCQ DX,R15 + MOVQ 8(SP),AX + MULQ 16(DI) + ADDQ AX,R8 + ADCQ DX,R9 + MOVQ 8(SP),AX + MULQ 24(DI) + ADDQ AX,R10 + ADCQ DX,R11 + MOVQ 8(SP),AX + MULQ 32(DI) + ADDQ AX,R12 + ADCQ DX,R13 + MOVQ $REDMASK51,DX + SHLQ $13,SI,CX + ANDQ DX,SI + SHLQ $13,R8,R9 + ANDQ DX,R8 + ADDQ CX,R8 + SHLQ $13,R10,R11 + ANDQ DX,R10 + ADDQ R9,R10 + SHLQ $13,R12,R13 + ANDQ DX,R12 + ADDQ R11,R12 + SHLQ $13,R14,R15 + ANDQ DX,R14 + ADDQ R13,R14 + IMUL3Q $19,R15,CX + ADDQ CX,SI + MOVQ SI,CX + SHRQ $51,CX + ADDQ R8,CX + MOVQ CX,R8 + SHRQ $51,CX + ANDQ DX,SI + ADDQ R10,CX + MOVQ CX,R9 + SHRQ $51,CX + ANDQ DX,R8 + ADDQ R12,CX + MOVQ CX,AX + SHRQ $51,CX + ANDQ DX,R9 + ADDQ R14,CX + MOVQ CX,R10 + SHRQ $51,CX + ANDQ DX,AX + IMUL3Q $19,CX,CX + ADDQ CX,SI + ANDQ DX,R10 + MOVQ SI,160(DI) + MOVQ R8,168(DI) + MOVQ R9,176(DI) + MOVQ AX,184(DI) + MOVQ R10,192(DI) + MOVQ 144(SP),SI + IMUL3Q $19,SI,AX + MOVQ AX,0(SP) + MULQ 96(SP) + MOVQ AX,SI + MOVQ DX,CX + MOVQ 152(SP),DX + IMUL3Q $19,DX,AX + MOVQ AX,8(SP) + MULQ 88(SP) + ADDQ AX,SI + ADCQ DX,CX + MOVQ 120(SP),AX + MULQ 80(SP) + ADDQ AX,SI + ADCQ DX,CX + MOVQ 120(SP),AX + MULQ 88(SP) + MOVQ AX,R8 + MOVQ DX,R9 + MOVQ 120(SP),AX + MULQ 96(SP) + MOVQ AX,R10 + MOVQ DX,R11 + MOVQ 120(SP),AX + MULQ 104(SP) + MOVQ AX,R12 + MOVQ DX,R13 + MOVQ 120(SP),AX + MULQ 112(SP) + MOVQ AX,R14 + MOVQ DX,R15 + MOVQ 128(SP),AX + MULQ 80(SP) + ADDQ AX,R8 + ADCQ DX,R9 + MOVQ 128(SP),AX + MULQ 88(SP) + ADDQ AX,R10 + ADCQ DX,R11 + MOVQ 128(SP),AX + MULQ 96(SP) + ADDQ AX,R12 + ADCQ DX,R13 + MOVQ 128(SP),AX + MULQ 104(SP) + ADDQ AX,R14 + ADCQ DX,R15 + MOVQ 128(SP),DX + IMUL3Q $19,DX,AX + MULQ 112(SP) + ADDQ AX,SI + ADCQ DX,CX + MOVQ 136(SP),AX + MULQ 80(SP) + ADDQ AX,R10 + ADCQ DX,R11 + MOVQ 136(SP),AX + MULQ 88(SP) + ADDQ AX,R12 + ADCQ DX,R13 + MOVQ 136(SP),AX + MULQ 96(SP) + ADDQ AX,R14 + ADCQ DX,R15 + MOVQ 136(SP),DX + IMUL3Q $19,DX,AX + MULQ 104(SP) + ADDQ AX,SI + ADCQ DX,CX + MOVQ 136(SP),DX + IMUL3Q $19,DX,AX + MULQ 112(SP) + ADDQ AX,R8 + ADCQ DX,R9 + MOVQ 144(SP),AX + MULQ 80(SP) + ADDQ AX,R12 + ADCQ DX,R13 + MOVQ 144(SP),AX + MULQ 88(SP) + ADDQ AX,R14 + ADCQ DX,R15 + MOVQ 0(SP),AX + MULQ 104(SP) + ADDQ AX,R8 + ADCQ DX,R9 + MOVQ 0(SP),AX + MULQ 112(SP) + ADDQ AX,R10 + ADCQ DX,R11 + MOVQ 152(SP),AX + MULQ 80(SP) + ADDQ AX,R14 + ADCQ DX,R15 + MOVQ 8(SP),AX + MULQ 96(SP) + ADDQ AX,R8 + ADCQ DX,R9 + MOVQ 8(SP),AX + MULQ 104(SP) + ADDQ AX,R10 + ADCQ DX,R11 + MOVQ 8(SP),AX + MULQ 112(SP) + ADDQ AX,R12 + ADCQ DX,R13 + MOVQ $REDMASK51,DX + SHLQ $13,SI,CX + ANDQ DX,SI + SHLQ $13,R8,R9 + ANDQ DX,R8 + ADDQ CX,R8 + SHLQ $13,R10,R11 + ANDQ DX,R10 + ADDQ R9,R10 + SHLQ $13,R12,R13 + ANDQ DX,R12 + ADDQ R11,R12 + SHLQ $13,R14,R15 + ANDQ DX,R14 + ADDQ R13,R14 + IMUL3Q $19,R15,CX + ADDQ CX,SI + MOVQ SI,CX + SHRQ $51,CX + ADDQ R8,CX + MOVQ CX,R8 + SHRQ $51,CX + ANDQ DX,SI + ADDQ R10,CX + MOVQ CX,R9 + SHRQ $51,CX + ANDQ DX,R8 + ADDQ R12,CX + MOVQ CX,AX + SHRQ $51,CX + ANDQ DX,R9 + ADDQ R14,CX + MOVQ CX,R10 + SHRQ $51,CX + ANDQ DX,AX + IMUL3Q $19,CX,CX + ADDQ CX,SI + ANDQ DX,R10 + MOVQ SI,40(DI) + MOVQ R8,48(DI) + MOVQ R9,56(DI) + MOVQ AX,64(DI) + MOVQ R10,72(DI) + MOVQ 160(SP),AX + MULQ ·_121666_213(SB) + SHRQ $13,AX + MOVQ AX,SI + MOVQ DX,CX + MOVQ 168(SP),AX + MULQ ·_121666_213(SB) + SHRQ $13,AX + ADDQ AX,CX + MOVQ DX,R8 + MOVQ 176(SP),AX + MULQ ·_121666_213(SB) + SHRQ $13,AX + ADDQ AX,R8 + MOVQ DX,R9 + MOVQ 184(SP),AX + MULQ ·_121666_213(SB) + SHRQ $13,AX + ADDQ AX,R9 + MOVQ DX,R10 + MOVQ 192(SP),AX + MULQ ·_121666_213(SB) + SHRQ $13,AX + ADDQ AX,R10 + IMUL3Q $19,DX,DX + ADDQ DX,SI + ADDQ 80(SP),SI + ADDQ 88(SP),CX + ADDQ 96(SP),R8 + ADDQ 104(SP),R9 + ADDQ 112(SP),R10 + MOVQ SI,80(DI) + MOVQ CX,88(DI) + MOVQ R8,96(DI) + MOVQ R9,104(DI) + MOVQ R10,112(DI) + MOVQ 104(DI),SI + IMUL3Q $19,SI,AX + MOVQ AX,0(SP) + MULQ 176(SP) + MOVQ AX,SI + MOVQ DX,CX + MOVQ 112(DI),DX + IMUL3Q $19,DX,AX + MOVQ AX,8(SP) + MULQ 168(SP) + ADDQ AX,SI + ADCQ DX,CX + MOVQ 80(DI),AX + MULQ 160(SP) + ADDQ AX,SI + ADCQ DX,CX + MOVQ 80(DI),AX + MULQ 168(SP) + MOVQ AX,R8 + MOVQ DX,R9 + MOVQ 80(DI),AX + MULQ 176(SP) + MOVQ AX,R10 + MOVQ DX,R11 + MOVQ 80(DI),AX + MULQ 184(SP) + MOVQ AX,R12 + MOVQ DX,R13 + MOVQ 80(DI),AX + MULQ 192(SP) + MOVQ AX,R14 + MOVQ DX,R15 + MOVQ 88(DI),AX + MULQ 160(SP) + ADDQ AX,R8 + ADCQ DX,R9 + MOVQ 88(DI),AX + MULQ 168(SP) + ADDQ AX,R10 + ADCQ DX,R11 + MOVQ 88(DI),AX + MULQ 176(SP) + ADDQ AX,R12 + ADCQ DX,R13 + MOVQ 88(DI),AX + MULQ 184(SP) + ADDQ AX,R14 + ADCQ DX,R15 + MOVQ 88(DI),DX + IMUL3Q $19,DX,AX + MULQ 192(SP) + ADDQ AX,SI + ADCQ DX,CX + MOVQ 96(DI),AX + MULQ 160(SP) + ADDQ AX,R10 + ADCQ DX,R11 + MOVQ 96(DI),AX + MULQ 168(SP) + ADDQ AX,R12 + ADCQ DX,R13 + MOVQ 96(DI),AX + MULQ 176(SP) + ADDQ AX,R14 + ADCQ DX,R15 + MOVQ 96(DI),DX + IMUL3Q $19,DX,AX + MULQ 184(SP) + ADDQ AX,SI + ADCQ DX,CX + MOVQ 96(DI),DX + IMUL3Q $19,DX,AX + MULQ 192(SP) + ADDQ AX,R8 + ADCQ DX,R9 + MOVQ 104(DI),AX + MULQ 160(SP) + ADDQ AX,R12 + ADCQ DX,R13 + MOVQ 104(DI),AX + MULQ 168(SP) + ADDQ AX,R14 + ADCQ DX,R15 + MOVQ 0(SP),AX + MULQ 184(SP) + ADDQ AX,R8 + ADCQ DX,R9 + MOVQ 0(SP),AX + MULQ 192(SP) + ADDQ AX,R10 + ADCQ DX,R11 + MOVQ 112(DI),AX + MULQ 160(SP) + ADDQ AX,R14 + ADCQ DX,R15 + MOVQ 8(SP),AX + MULQ 176(SP) + ADDQ AX,R8 + ADCQ DX,R9 + MOVQ 8(SP),AX + MULQ 184(SP) + ADDQ AX,R10 + ADCQ DX,R11 + MOVQ 8(SP),AX + MULQ 192(SP) + ADDQ AX,R12 + ADCQ DX,R13 + MOVQ $REDMASK51,DX + SHLQ $13,SI,CX + ANDQ DX,SI + SHLQ $13,R8,R9 + ANDQ DX,R8 + ADDQ CX,R8 + SHLQ $13,R10,R11 + ANDQ DX,R10 + ADDQ R9,R10 + SHLQ $13,R12,R13 + ANDQ DX,R12 + ADDQ R11,R12 + SHLQ $13,R14,R15 + ANDQ DX,R14 + ADDQ R13,R14 + IMUL3Q $19,R15,CX + ADDQ CX,SI + MOVQ SI,CX + SHRQ $51,CX + ADDQ R8,CX + MOVQ CX,R8 + SHRQ $51,CX + ANDQ DX,SI + ADDQ R10,CX + MOVQ CX,R9 + SHRQ $51,CX + ANDQ DX,R8 + ADDQ R12,CX + MOVQ CX,AX + SHRQ $51,CX + ANDQ DX,R9 + ADDQ R14,CX + MOVQ CX,R10 + SHRQ $51,CX + ANDQ DX,AX + IMUL3Q $19,CX,CX + ADDQ CX,SI + ANDQ DX,R10 + MOVQ SI,80(DI) + MOVQ R8,88(DI) + MOVQ R9,96(DI) + MOVQ AX,104(DI) + MOVQ R10,112(DI) + RET + +// func cswap(inout *[4][5]uint64, v uint64) +TEXT ·cswap(SB),7,$0 + MOVQ inout+0(FP),DI + MOVQ v+8(FP),SI + + SUBQ $1, SI + NOTQ SI + MOVQ SI, X15 + PSHUFD $0x44, X15, X15 + + MOVOU 0(DI), X0 + MOVOU 16(DI), X2 + MOVOU 32(DI), X4 + MOVOU 48(DI), X6 + MOVOU 64(DI), X8 + MOVOU 80(DI), X1 + MOVOU 96(DI), X3 + MOVOU 112(DI), X5 + MOVOU 128(DI), X7 + MOVOU 144(DI), X9 + + MOVO X1, X10 + MOVO X3, X11 + MOVO X5, X12 + MOVO X7, X13 + MOVO X9, X14 + + PXOR X0, X10 + PXOR X2, X11 + PXOR X4, X12 + PXOR X6, X13 + PXOR X8, X14 + PAND X15, X10 + PAND X15, X11 + PAND X15, X12 + PAND X15, X13 + PAND X15, X14 + PXOR X10, X0 + PXOR X10, X1 + PXOR X11, X2 + PXOR X11, X3 + PXOR X12, X4 + PXOR X12, X5 + PXOR X13, X6 + PXOR X13, X7 + PXOR X14, X8 + PXOR X14, X9 + + MOVOU X0, 0(DI) + MOVOU X2, 16(DI) + MOVOU X4, 32(DI) + MOVOU X6, 48(DI) + MOVOU X8, 64(DI) + MOVOU X1, 80(DI) + MOVOU X3, 96(DI) + MOVOU X5, 112(DI) + MOVOU X7, 128(DI) + MOVOU X9, 144(DI) + RET + +// func mul(dest, a, b *[5]uint64) +TEXT ·mul(SB),0,$16-24 + MOVQ dest+0(FP), DI + MOVQ a+8(FP), SI + MOVQ b+16(FP), DX + + MOVQ DX,CX + MOVQ 24(SI),DX + IMUL3Q $19,DX,AX + MOVQ AX,0(SP) + MULQ 16(CX) + MOVQ AX,R8 + MOVQ DX,R9 + MOVQ 32(SI),DX + IMUL3Q $19,DX,AX + MOVQ AX,8(SP) + MULQ 8(CX) + ADDQ AX,R8 + ADCQ DX,R9 + MOVQ 0(SI),AX + MULQ 0(CX) + ADDQ AX,R8 + ADCQ DX,R9 + MOVQ 0(SI),AX + MULQ 8(CX) + MOVQ AX,R10 + MOVQ DX,R11 + MOVQ 0(SI),AX + MULQ 16(CX) + MOVQ AX,R12 + MOVQ DX,R13 + MOVQ 0(SI),AX + MULQ 24(CX) + MOVQ AX,R14 + MOVQ DX,R15 + MOVQ 0(SI),AX + MULQ 32(CX) + MOVQ AX,BX + MOVQ DX,BP + MOVQ 8(SI),AX + MULQ 0(CX) + ADDQ AX,R10 + ADCQ DX,R11 + MOVQ 8(SI),AX + MULQ 8(CX) + ADDQ AX,R12 + ADCQ DX,R13 + MOVQ 8(SI),AX + MULQ 16(CX) + ADDQ AX,R14 + ADCQ DX,R15 + MOVQ 8(SI),AX + MULQ 24(CX) + ADDQ AX,BX + ADCQ DX,BP + MOVQ 8(SI),DX + IMUL3Q $19,DX,AX + MULQ 32(CX) + ADDQ AX,R8 + ADCQ DX,R9 + MOVQ 16(SI),AX + MULQ 0(CX) + ADDQ AX,R12 + ADCQ DX,R13 + MOVQ 16(SI),AX + MULQ 8(CX) + ADDQ AX,R14 + ADCQ DX,R15 + MOVQ 16(SI),AX + MULQ 16(CX) + ADDQ AX,BX + ADCQ DX,BP + MOVQ 16(SI),DX + IMUL3Q $19,DX,AX + MULQ 24(CX) + ADDQ AX,R8 + ADCQ DX,R9 + MOVQ 16(SI),DX + IMUL3Q $19,DX,AX + MULQ 32(CX) + ADDQ AX,R10 + ADCQ DX,R11 + MOVQ 24(SI),AX + MULQ 0(CX) + ADDQ AX,R14 + ADCQ DX,R15 + MOVQ 24(SI),AX + MULQ 8(CX) + ADDQ AX,BX + ADCQ DX,BP + MOVQ 0(SP),AX + MULQ 24(CX) + ADDQ AX,R10 + ADCQ DX,R11 + MOVQ 0(SP),AX + MULQ 32(CX) + ADDQ AX,R12 + ADCQ DX,R13 + MOVQ 32(SI),AX + MULQ 0(CX) + ADDQ AX,BX + ADCQ DX,BP + MOVQ 8(SP),AX + MULQ 16(CX) + ADDQ AX,R10 + ADCQ DX,R11 + MOVQ 8(SP),AX + MULQ 24(CX) + ADDQ AX,R12 + ADCQ DX,R13 + MOVQ 8(SP),AX + MULQ 32(CX) + ADDQ AX,R14 + ADCQ DX,R15 + MOVQ $REDMASK51,SI + SHLQ $13,R8,R9 + ANDQ SI,R8 + SHLQ $13,R10,R11 + ANDQ SI,R10 + ADDQ R9,R10 + SHLQ $13,R12,R13 + ANDQ SI,R12 + ADDQ R11,R12 + SHLQ $13,R14,R15 + ANDQ SI,R14 + ADDQ R13,R14 + SHLQ $13,BX,BP + ANDQ SI,BX + ADDQ R15,BX + IMUL3Q $19,BP,DX + ADDQ DX,R8 + MOVQ R8,DX + SHRQ $51,DX + ADDQ R10,DX + MOVQ DX,CX + SHRQ $51,DX + ANDQ SI,R8 + ADDQ R12,DX + MOVQ DX,R9 + SHRQ $51,DX + ANDQ SI,CX + ADDQ R14,DX + MOVQ DX,AX + SHRQ $51,DX + ANDQ SI,R9 + ADDQ BX,DX + MOVQ DX,R10 + SHRQ $51,DX + ANDQ SI,AX + IMUL3Q $19,DX,DX + ADDQ DX,R8 + ANDQ SI,R10 + MOVQ R8,0(DI) + MOVQ CX,8(DI) + MOVQ R9,16(DI) + MOVQ AX,24(DI) + MOVQ R10,32(DI) + RET + +// func square(out, in *[5]uint64) +TEXT ·square(SB),7,$0-16 + MOVQ out+0(FP), DI + MOVQ in+8(FP), SI + + MOVQ 0(SI),AX + MULQ 0(SI) + MOVQ AX,CX + MOVQ DX,R8 + MOVQ 0(SI),AX + SHLQ $1,AX + MULQ 8(SI) + MOVQ AX,R9 + MOVQ DX,R10 + MOVQ 0(SI),AX + SHLQ $1,AX + MULQ 16(SI) + MOVQ AX,R11 + MOVQ DX,R12 + MOVQ 0(SI),AX + SHLQ $1,AX + MULQ 24(SI) + MOVQ AX,R13 + MOVQ DX,R14 + MOVQ 0(SI),AX + SHLQ $1,AX + MULQ 32(SI) + MOVQ AX,R15 + MOVQ DX,BX + MOVQ 8(SI),AX + MULQ 8(SI) + ADDQ AX,R11 + ADCQ DX,R12 + MOVQ 8(SI),AX + SHLQ $1,AX + MULQ 16(SI) + ADDQ AX,R13 + ADCQ DX,R14 + MOVQ 8(SI),AX + SHLQ $1,AX + MULQ 24(SI) + ADDQ AX,R15 + ADCQ DX,BX + MOVQ 8(SI),DX + IMUL3Q $38,DX,AX + MULQ 32(SI) + ADDQ AX,CX + ADCQ DX,R8 + MOVQ 16(SI),AX + MULQ 16(SI) + ADDQ AX,R15 + ADCQ DX,BX + MOVQ 16(SI),DX + IMUL3Q $38,DX,AX + MULQ 24(SI) + ADDQ AX,CX + ADCQ DX,R8 + MOVQ 16(SI),DX + IMUL3Q $38,DX,AX + MULQ 32(SI) + ADDQ AX,R9 + ADCQ DX,R10 + MOVQ 24(SI),DX + IMUL3Q $19,DX,AX + MULQ 24(SI) + ADDQ AX,R9 + ADCQ DX,R10 + MOVQ 24(SI),DX + IMUL3Q $38,DX,AX + MULQ 32(SI) + ADDQ AX,R11 + ADCQ DX,R12 + MOVQ 32(SI),DX + IMUL3Q $19,DX,AX + MULQ 32(SI) + ADDQ AX,R13 + ADCQ DX,R14 + MOVQ $REDMASK51,SI + SHLQ $13,CX,R8 + ANDQ SI,CX + SHLQ $13,R9,R10 + ANDQ SI,R9 + ADDQ R8,R9 + SHLQ $13,R11,R12 + ANDQ SI,R11 + ADDQ R10,R11 + SHLQ $13,R13,R14 + ANDQ SI,R13 + ADDQ R12,R13 + SHLQ $13,R15,BX + ANDQ SI,R15 + ADDQ R14,R15 + IMUL3Q $19,BX,DX + ADDQ DX,CX + MOVQ CX,DX + SHRQ $51,DX + ADDQ R9,DX + ANDQ SI,CX + MOVQ DX,R8 + SHRQ $51,DX + ADDQ R11,DX + ANDQ SI,R8 + MOVQ DX,R9 + SHRQ $51,DX + ADDQ R13,DX + ANDQ SI,R9 + MOVQ DX,AX + SHRQ $51,DX + ADDQ R15,DX + ANDQ SI,AX + MOVQ DX,R10 + SHRQ $51,DX + IMUL3Q $19,DX,DX + ADDQ DX,CX + ANDQ SI,R10 + MOVQ CX,0(DI) + MOVQ R8,8(DI) + MOVQ R9,16(DI) + MOVQ AX,24(DI) + MOVQ R10,32(DI) + RET diff --git a/internal/crypto/curve25519/curve25519_generic.go b/internal/crypto/curve25519/curve25519_generic.go new file mode 100644 index 000000000..c43b13fc8 --- /dev/null +++ b/internal/crypto/curve25519/curve25519_generic.go @@ -0,0 +1,828 @@ +// Copyright 2013 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package curve25519 + +import "encoding/binary" + +// This code is a port of the public domain, "ref10" implementation of +// curve25519 from SUPERCOP 20130419 by D. J. Bernstein. + +// fieldElement represents an element of the field GF(2^255 - 19). An element +// t, entries t[0]...t[9], represents the integer t[0]+2^26 t[1]+2^51 t[2]+2^77 +// t[3]+2^102 t[4]+...+2^230 t[9]. Bounds on each t[i] vary depending on +// context. +type fieldElement [10]int32 + +func feZero(fe *fieldElement) { + for i := range fe { + fe[i] = 0 + } +} + +func feOne(fe *fieldElement) { + feZero(fe) + fe[0] = 1 +} + +func feAdd(dst, a, b *fieldElement) { + for i := range dst { + dst[i] = a[i] + b[i] + } +} + +func feSub(dst, a, b *fieldElement) { + for i := range dst { + dst[i] = a[i] - b[i] + } +} + +func feCopy(dst, src *fieldElement) { + for i := range dst { + dst[i] = src[i] + } +} + +// feCSwap replaces (f,g) with (g,f) if b == 1; replaces (f,g) with (f,g) if b == 0. +// +// Preconditions: b in {0,1}. +func feCSwap(f, g *fieldElement, b int32) { + b = -b + for i := range f { + t := b & (f[i] ^ g[i]) + f[i] ^= t + g[i] ^= t + } +} + +// load3 reads a 24-bit, little-endian value from in. +func load3(in []byte) int64 { + var r int64 + r = int64(in[0]) + r |= int64(in[1]) << 8 + r |= int64(in[2]) << 16 + return r +} + +// load4 reads a 32-bit, little-endian value from in. +func load4(in []byte) int64 { + return int64(binary.LittleEndian.Uint32(in)) +} + +func feFromBytes(dst *fieldElement, src *[32]byte) { + h0 := load4(src[:]) + h1 := load3(src[4:]) << 6 + h2 := load3(src[7:]) << 5 + h3 := load3(src[10:]) << 3 + h4 := load3(src[13:]) << 2 + h5 := load4(src[16:]) + h6 := load3(src[20:]) << 7 + h7 := load3(src[23:]) << 5 + h8 := load3(src[26:]) << 4 + h9 := (load3(src[29:]) & 0x7fffff) << 2 + + var carry [10]int64 + carry[9] = (h9 + 1<<24) >> 25 + h0 += carry[9] * 19 + h9 -= carry[9] << 25 + carry[1] = (h1 + 1<<24) >> 25 + h2 += carry[1] + h1 -= carry[1] << 25 + carry[3] = (h3 + 1<<24) >> 25 + h4 += carry[3] + h3 -= carry[3] << 25 + carry[5] = (h5 + 1<<24) >> 25 + h6 += carry[5] + h5 -= carry[5] << 25 + carry[7] = (h7 + 1<<24) >> 25 + h8 += carry[7] + h7 -= carry[7] << 25 + + carry[0] = (h0 + 1<<25) >> 26 + h1 += carry[0] + h0 -= carry[0] << 26 + carry[2] = (h2 + 1<<25) >> 26 + h3 += carry[2] + h2 -= carry[2] << 26 + carry[4] = (h4 + 1<<25) >> 26 + h5 += carry[4] + h4 -= carry[4] << 26 + carry[6] = (h6 + 1<<25) >> 26 + h7 += carry[6] + h6 -= carry[6] << 26 + carry[8] = (h8 + 1<<25) >> 26 + h9 += carry[8] + h8 -= carry[8] << 26 + + dst[0] = int32(h0) + dst[1] = int32(h1) + dst[2] = int32(h2) + dst[3] = int32(h3) + dst[4] = int32(h4) + dst[5] = int32(h5) + dst[6] = int32(h6) + dst[7] = int32(h7) + dst[8] = int32(h8) + dst[9] = int32(h9) +} + +// feToBytes marshals h to s. +// Preconditions: +// |h| bounded by 1.1*2^25,1.1*2^24,1.1*2^25,1.1*2^24,etc. +// +// Write p=2^255-19; q=floor(h/p). +// Basic claim: q = floor(2^(-255)(h + 19 2^(-25)h9 + 2^(-1))). +// +// Proof: +// Have |h|<=p so |q|<=1 so |19^2 2^(-255) q|<1/4. +// Also have |h-2^230 h9|<2^230 so |19 2^(-255)(h-2^230 h9)|<1/4. +// +// Write y=2^(-1)-19^2 2^(-255)q-19 2^(-255)(h-2^230 h9). +// Then 0> 25 + q = (h[0] + q) >> 26 + q = (h[1] + q) >> 25 + q = (h[2] + q) >> 26 + q = (h[3] + q) >> 25 + q = (h[4] + q) >> 26 + q = (h[5] + q) >> 25 + q = (h[6] + q) >> 26 + q = (h[7] + q) >> 25 + q = (h[8] + q) >> 26 + q = (h[9] + q) >> 25 + + // Goal: Output h-(2^255-19)q, which is between 0 and 2^255-20. + h[0] += 19 * q + // Goal: Output h-2^255 q, which is between 0 and 2^255-20. + + carry[0] = h[0] >> 26 + h[1] += carry[0] + h[0] -= carry[0] << 26 + carry[1] = h[1] >> 25 + h[2] += carry[1] + h[1] -= carry[1] << 25 + carry[2] = h[2] >> 26 + h[3] += carry[2] + h[2] -= carry[2] << 26 + carry[3] = h[3] >> 25 + h[4] += carry[3] + h[3] -= carry[3] << 25 + carry[4] = h[4] >> 26 + h[5] += carry[4] + h[4] -= carry[4] << 26 + carry[5] = h[5] >> 25 + h[6] += carry[5] + h[5] -= carry[5] << 25 + carry[6] = h[6] >> 26 + h[7] += carry[6] + h[6] -= carry[6] << 26 + carry[7] = h[7] >> 25 + h[8] += carry[7] + h[7] -= carry[7] << 25 + carry[8] = h[8] >> 26 + h[9] += carry[8] + h[8] -= carry[8] << 26 + carry[9] = h[9] >> 25 + h[9] -= carry[9] << 25 + // h10 = carry9 + + // Goal: Output h[0]+...+2^255 h10-2^255 q, which is between 0 and 2^255-20. + // Have h[0]+...+2^230 h[9] between 0 and 2^255-1; + // evidently 2^255 h10-2^255 q = 0. + // Goal: Output h[0]+...+2^230 h[9]. + + s[0] = byte(h[0] >> 0) + s[1] = byte(h[0] >> 8) + s[2] = byte(h[0] >> 16) + s[3] = byte((h[0] >> 24) | (h[1] << 2)) + s[4] = byte(h[1] >> 6) + s[5] = byte(h[1] >> 14) + s[6] = byte((h[1] >> 22) | (h[2] << 3)) + s[7] = byte(h[2] >> 5) + s[8] = byte(h[2] >> 13) + s[9] = byte((h[2] >> 21) | (h[3] << 5)) + s[10] = byte(h[3] >> 3) + s[11] = byte(h[3] >> 11) + s[12] = byte((h[3] >> 19) | (h[4] << 6)) + s[13] = byte(h[4] >> 2) + s[14] = byte(h[4] >> 10) + s[15] = byte(h[4] >> 18) + s[16] = byte(h[5] >> 0) + s[17] = byte(h[5] >> 8) + s[18] = byte(h[5] >> 16) + s[19] = byte((h[5] >> 24) | (h[6] << 1)) + s[20] = byte(h[6] >> 7) + s[21] = byte(h[6] >> 15) + s[22] = byte((h[6] >> 23) | (h[7] << 3)) + s[23] = byte(h[7] >> 5) + s[24] = byte(h[7] >> 13) + s[25] = byte((h[7] >> 21) | (h[8] << 4)) + s[26] = byte(h[8] >> 4) + s[27] = byte(h[8] >> 12) + s[28] = byte((h[8] >> 20) | (h[9] << 6)) + s[29] = byte(h[9] >> 2) + s[30] = byte(h[9] >> 10) + s[31] = byte(h[9] >> 18) +} + +// feMul calculates h = f * g +// Can overlap h with f or g. +// +// Preconditions: +// |f| bounded by 1.1*2^26,1.1*2^25,1.1*2^26,1.1*2^25,etc. +// |g| bounded by 1.1*2^26,1.1*2^25,1.1*2^26,1.1*2^25,etc. +// +// Postconditions: +// |h| bounded by 1.1*2^25,1.1*2^24,1.1*2^25,1.1*2^24,etc. +// +// Notes on implementation strategy: +// +// Using schoolbook multiplication. +// Karatsuba would save a little in some cost models. +// +// Most multiplications by 2 and 19 are 32-bit precomputations; +// cheaper than 64-bit postcomputations. +// +// There is one remaining multiplication by 19 in the carry chain; +// one *19 precomputation can be merged into this, +// but the resulting data flow is considerably less clean. +// +// There are 12 carries below. +// 10 of them are 2-way parallelizable and vectorizable. +// Can get away with 11 carries, but then data flow is much deeper. +// +// With tighter constraints on inputs can squeeze carries into int32. +func feMul(h, f, g *fieldElement) { + f0 := f[0] + f1 := f[1] + f2 := f[2] + f3 := f[3] + f4 := f[4] + f5 := f[5] + f6 := f[6] + f7 := f[7] + f8 := f[8] + f9 := f[9] + g0 := g[0] + g1 := g[1] + g2 := g[2] + g3 := g[3] + g4 := g[4] + g5 := g[5] + g6 := g[6] + g7 := g[7] + g8 := g[8] + g9 := g[9] + g1_19 := 19 * g1 // 1.4*2^29 + g2_19 := 19 * g2 // 1.4*2^30; still ok + g3_19 := 19 * g3 + g4_19 := 19 * g4 + g5_19 := 19 * g5 + g6_19 := 19 * g6 + g7_19 := 19 * g7 + g8_19 := 19 * g8 + g9_19 := 19 * g9 + f1_2 := 2 * f1 + f3_2 := 2 * f3 + f5_2 := 2 * f5 + f7_2 := 2 * f7 + f9_2 := 2 * f9 + f0g0 := int64(f0) * int64(g0) + f0g1 := int64(f0) * int64(g1) + f0g2 := int64(f0) * int64(g2) + f0g3 := int64(f0) * int64(g3) + f0g4 := int64(f0) * int64(g4) + f0g5 := int64(f0) * int64(g5) + f0g6 := int64(f0) * int64(g6) + f0g7 := int64(f0) * int64(g7) + f0g8 := int64(f0) * int64(g8) + f0g9 := int64(f0) * int64(g9) + f1g0 := int64(f1) * int64(g0) + f1g1_2 := int64(f1_2) * int64(g1) + f1g2 := int64(f1) * int64(g2) + f1g3_2 := int64(f1_2) * int64(g3) + f1g4 := int64(f1) * int64(g4) + f1g5_2 := int64(f1_2) * int64(g5) + f1g6 := int64(f1) * int64(g6) + f1g7_2 := int64(f1_2) * int64(g7) + f1g8 := int64(f1) * int64(g8) + f1g9_38 := int64(f1_2) * int64(g9_19) + f2g0 := int64(f2) * int64(g0) + f2g1 := int64(f2) * int64(g1) + f2g2 := int64(f2) * int64(g2) + f2g3 := int64(f2) * int64(g3) + f2g4 := int64(f2) * int64(g4) + f2g5 := int64(f2) * int64(g5) + f2g6 := int64(f2) * int64(g6) + f2g7 := int64(f2) * int64(g7) + f2g8_19 := int64(f2) * int64(g8_19) + f2g9_19 := int64(f2) * int64(g9_19) + f3g0 := int64(f3) * int64(g0) + f3g1_2 := int64(f3_2) * int64(g1) + f3g2 := int64(f3) * int64(g2) + f3g3_2 := int64(f3_2) * int64(g3) + f3g4 := int64(f3) * int64(g4) + f3g5_2 := int64(f3_2) * int64(g5) + f3g6 := int64(f3) * int64(g6) + f3g7_38 := int64(f3_2) * int64(g7_19) + f3g8_19 := int64(f3) * int64(g8_19) + f3g9_38 := int64(f3_2) * int64(g9_19) + f4g0 := int64(f4) * int64(g0) + f4g1 := int64(f4) * int64(g1) + f4g2 := int64(f4) * int64(g2) + f4g3 := int64(f4) * int64(g3) + f4g4 := int64(f4) * int64(g4) + f4g5 := int64(f4) * int64(g5) + f4g6_19 := int64(f4) * int64(g6_19) + f4g7_19 := int64(f4) * int64(g7_19) + f4g8_19 := int64(f4) * int64(g8_19) + f4g9_19 := int64(f4) * int64(g9_19) + f5g0 := int64(f5) * int64(g0) + f5g1_2 := int64(f5_2) * int64(g1) + f5g2 := int64(f5) * int64(g2) + f5g3_2 := int64(f5_2) * int64(g3) + f5g4 := int64(f5) * int64(g4) + f5g5_38 := int64(f5_2) * int64(g5_19) + f5g6_19 := int64(f5) * int64(g6_19) + f5g7_38 := int64(f5_2) * int64(g7_19) + f5g8_19 := int64(f5) * int64(g8_19) + f5g9_38 := int64(f5_2) * int64(g9_19) + f6g0 := int64(f6) * int64(g0) + f6g1 := int64(f6) * int64(g1) + f6g2 := int64(f6) * int64(g2) + f6g3 := int64(f6) * int64(g3) + f6g4_19 := int64(f6) * int64(g4_19) + f6g5_19 := int64(f6) * int64(g5_19) + f6g6_19 := int64(f6) * int64(g6_19) + f6g7_19 := int64(f6) * int64(g7_19) + f6g8_19 := int64(f6) * int64(g8_19) + f6g9_19 := int64(f6) * int64(g9_19) + f7g0 := int64(f7) * int64(g0) + f7g1_2 := int64(f7_2) * int64(g1) + f7g2 := int64(f7) * int64(g2) + f7g3_38 := int64(f7_2) * int64(g3_19) + f7g4_19 := int64(f7) * int64(g4_19) + f7g5_38 := int64(f7_2) * int64(g5_19) + f7g6_19 := int64(f7) * int64(g6_19) + f7g7_38 := int64(f7_2) * int64(g7_19) + f7g8_19 := int64(f7) * int64(g8_19) + f7g9_38 := int64(f7_2) * int64(g9_19) + f8g0 := int64(f8) * int64(g0) + f8g1 := int64(f8) * int64(g1) + f8g2_19 := int64(f8) * int64(g2_19) + f8g3_19 := int64(f8) * int64(g3_19) + f8g4_19 := int64(f8) * int64(g4_19) + f8g5_19 := int64(f8) * int64(g5_19) + f8g6_19 := int64(f8) * int64(g6_19) + f8g7_19 := int64(f8) * int64(g7_19) + f8g8_19 := int64(f8) * int64(g8_19) + f8g9_19 := int64(f8) * int64(g9_19) + f9g0 := int64(f9) * int64(g0) + f9g1_38 := int64(f9_2) * int64(g1_19) + f9g2_19 := int64(f9) * int64(g2_19) + f9g3_38 := int64(f9_2) * int64(g3_19) + f9g4_19 := int64(f9) * int64(g4_19) + f9g5_38 := int64(f9_2) * int64(g5_19) + f9g6_19 := int64(f9) * int64(g6_19) + f9g7_38 := int64(f9_2) * int64(g7_19) + f9g8_19 := int64(f9) * int64(g8_19) + f9g9_38 := int64(f9_2) * int64(g9_19) + h0 := f0g0 + f1g9_38 + f2g8_19 + f3g7_38 + f4g6_19 + f5g5_38 + f6g4_19 + f7g3_38 + f8g2_19 + f9g1_38 + h1 := f0g1 + f1g0 + f2g9_19 + f3g8_19 + f4g7_19 + f5g6_19 + f6g5_19 + f7g4_19 + f8g3_19 + f9g2_19 + h2 := f0g2 + f1g1_2 + f2g0 + f3g9_38 + f4g8_19 + f5g7_38 + f6g6_19 + f7g5_38 + f8g4_19 + f9g3_38 + h3 := f0g3 + f1g2 + f2g1 + f3g0 + f4g9_19 + f5g8_19 + f6g7_19 + f7g6_19 + f8g5_19 + f9g4_19 + h4 := f0g4 + f1g3_2 + f2g2 + f3g1_2 + f4g0 + f5g9_38 + f6g8_19 + f7g7_38 + f8g6_19 + f9g5_38 + h5 := f0g5 + f1g4 + f2g3 + f3g2 + f4g1 + f5g0 + f6g9_19 + f7g8_19 + f8g7_19 + f9g6_19 + h6 := f0g6 + f1g5_2 + f2g4 + f3g3_2 + f4g2 + f5g1_2 + f6g0 + f7g9_38 + f8g8_19 + f9g7_38 + h7 := f0g7 + f1g6 + f2g5 + f3g4 + f4g3 + f5g2 + f6g1 + f7g0 + f8g9_19 + f9g8_19 + h8 := f0g8 + f1g7_2 + f2g6 + f3g5_2 + f4g4 + f5g3_2 + f6g2 + f7g1_2 + f8g0 + f9g9_38 + h9 := f0g9 + f1g8 + f2g7 + f3g6 + f4g5 + f5g4 + f6g3 + f7g2 + f8g1 + f9g0 + var carry [10]int64 + + // |h0| <= (1.1*1.1*2^52*(1+19+19+19+19)+1.1*1.1*2^50*(38+38+38+38+38)) + // i.e. |h0| <= 1.2*2^59; narrower ranges for h2, h4, h6, h8 + // |h1| <= (1.1*1.1*2^51*(1+1+19+19+19+19+19+19+19+19)) + // i.e. |h1| <= 1.5*2^58; narrower ranges for h3, h5, h7, h9 + + carry[0] = (h0 + (1 << 25)) >> 26 + h1 += carry[0] + h0 -= carry[0] << 26 + carry[4] = (h4 + (1 << 25)) >> 26 + h5 += carry[4] + h4 -= carry[4] << 26 + // |h0| <= 2^25 + // |h4| <= 2^25 + // |h1| <= 1.51*2^58 + // |h5| <= 1.51*2^58 + + carry[1] = (h1 + (1 << 24)) >> 25 + h2 += carry[1] + h1 -= carry[1] << 25 + carry[5] = (h5 + (1 << 24)) >> 25 + h6 += carry[5] + h5 -= carry[5] << 25 + // |h1| <= 2^24; from now on fits into int32 + // |h5| <= 2^24; from now on fits into int32 + // |h2| <= 1.21*2^59 + // |h6| <= 1.21*2^59 + + carry[2] = (h2 + (1 << 25)) >> 26 + h3 += carry[2] + h2 -= carry[2] << 26 + carry[6] = (h6 + (1 << 25)) >> 26 + h7 += carry[6] + h6 -= carry[6] << 26 + // |h2| <= 2^25; from now on fits into int32 unchanged + // |h6| <= 2^25; from now on fits into int32 unchanged + // |h3| <= 1.51*2^58 + // |h7| <= 1.51*2^58 + + carry[3] = (h3 + (1 << 24)) >> 25 + h4 += carry[3] + h3 -= carry[3] << 25 + carry[7] = (h7 + (1 << 24)) >> 25 + h8 += carry[7] + h7 -= carry[7] << 25 + // |h3| <= 2^24; from now on fits into int32 unchanged + // |h7| <= 2^24; from now on fits into int32 unchanged + // |h4| <= 1.52*2^33 + // |h8| <= 1.52*2^33 + + carry[4] = (h4 + (1 << 25)) >> 26 + h5 += carry[4] + h4 -= carry[4] << 26 + carry[8] = (h8 + (1 << 25)) >> 26 + h9 += carry[8] + h8 -= carry[8] << 26 + // |h4| <= 2^25; from now on fits into int32 unchanged + // |h8| <= 2^25; from now on fits into int32 unchanged + // |h5| <= 1.01*2^24 + // |h9| <= 1.51*2^58 + + carry[9] = (h9 + (1 << 24)) >> 25 + h0 += carry[9] * 19 + h9 -= carry[9] << 25 + // |h9| <= 2^24; from now on fits into int32 unchanged + // |h0| <= 1.8*2^37 + + carry[0] = (h0 + (1 << 25)) >> 26 + h1 += carry[0] + h0 -= carry[0] << 26 + // |h0| <= 2^25; from now on fits into int32 unchanged + // |h1| <= 1.01*2^24 + + h[0] = int32(h0) + h[1] = int32(h1) + h[2] = int32(h2) + h[3] = int32(h3) + h[4] = int32(h4) + h[5] = int32(h5) + h[6] = int32(h6) + h[7] = int32(h7) + h[8] = int32(h8) + h[9] = int32(h9) +} + +// feSquare calculates h = f*f. Can overlap h with f. +// +// Preconditions: +// |f| bounded by 1.1*2^26,1.1*2^25,1.1*2^26,1.1*2^25,etc. +// +// Postconditions: +// |h| bounded by 1.1*2^25,1.1*2^24,1.1*2^25,1.1*2^24,etc. +func feSquare(h, f *fieldElement) { + f0 := f[0] + f1 := f[1] + f2 := f[2] + f3 := f[3] + f4 := f[4] + f5 := f[5] + f6 := f[6] + f7 := f[7] + f8 := f[8] + f9 := f[9] + f0_2 := 2 * f0 + f1_2 := 2 * f1 + f2_2 := 2 * f2 + f3_2 := 2 * f3 + f4_2 := 2 * f4 + f5_2 := 2 * f5 + f6_2 := 2 * f6 + f7_2 := 2 * f7 + f5_38 := 38 * f5 // 1.31*2^30 + f6_19 := 19 * f6 // 1.31*2^30 + f7_38 := 38 * f7 // 1.31*2^30 + f8_19 := 19 * f8 // 1.31*2^30 + f9_38 := 38 * f9 // 1.31*2^30 + f0f0 := int64(f0) * int64(f0) + f0f1_2 := int64(f0_2) * int64(f1) + f0f2_2 := int64(f0_2) * int64(f2) + f0f3_2 := int64(f0_2) * int64(f3) + f0f4_2 := int64(f0_2) * int64(f4) + f0f5_2 := int64(f0_2) * int64(f5) + f0f6_2 := int64(f0_2) * int64(f6) + f0f7_2 := int64(f0_2) * int64(f7) + f0f8_2 := int64(f0_2) * int64(f8) + f0f9_2 := int64(f0_2) * int64(f9) + f1f1_2 := int64(f1_2) * int64(f1) + f1f2_2 := int64(f1_2) * int64(f2) + f1f3_4 := int64(f1_2) * int64(f3_2) + f1f4_2 := int64(f1_2) * int64(f4) + f1f5_4 := int64(f1_2) * int64(f5_2) + f1f6_2 := int64(f1_2) * int64(f6) + f1f7_4 := int64(f1_2) * int64(f7_2) + f1f8_2 := int64(f1_2) * int64(f8) + f1f9_76 := int64(f1_2) * int64(f9_38) + f2f2 := int64(f2) * int64(f2) + f2f3_2 := int64(f2_2) * int64(f3) + f2f4_2 := int64(f2_2) * int64(f4) + f2f5_2 := int64(f2_2) * int64(f5) + f2f6_2 := int64(f2_2) * int64(f6) + f2f7_2 := int64(f2_2) * int64(f7) + f2f8_38 := int64(f2_2) * int64(f8_19) + f2f9_38 := int64(f2) * int64(f9_38) + f3f3_2 := int64(f3_2) * int64(f3) + f3f4_2 := int64(f3_2) * int64(f4) + f3f5_4 := int64(f3_2) * int64(f5_2) + f3f6_2 := int64(f3_2) * int64(f6) + f3f7_76 := int64(f3_2) * int64(f7_38) + f3f8_38 := int64(f3_2) * int64(f8_19) + f3f9_76 := int64(f3_2) * int64(f9_38) + f4f4 := int64(f4) * int64(f4) + f4f5_2 := int64(f4_2) * int64(f5) + f4f6_38 := int64(f4_2) * int64(f6_19) + f4f7_38 := int64(f4) * int64(f7_38) + f4f8_38 := int64(f4_2) * int64(f8_19) + f4f9_38 := int64(f4) * int64(f9_38) + f5f5_38 := int64(f5) * int64(f5_38) + f5f6_38 := int64(f5_2) * int64(f6_19) + f5f7_76 := int64(f5_2) * int64(f7_38) + f5f8_38 := int64(f5_2) * int64(f8_19) + f5f9_76 := int64(f5_2) * int64(f9_38) + f6f6_19 := int64(f6) * int64(f6_19) + f6f7_38 := int64(f6) * int64(f7_38) + f6f8_38 := int64(f6_2) * int64(f8_19) + f6f9_38 := int64(f6) * int64(f9_38) + f7f7_38 := int64(f7) * int64(f7_38) + f7f8_38 := int64(f7_2) * int64(f8_19) + f7f9_76 := int64(f7_2) * int64(f9_38) + f8f8_19 := int64(f8) * int64(f8_19) + f8f9_38 := int64(f8) * int64(f9_38) + f9f9_38 := int64(f9) * int64(f9_38) + h0 := f0f0 + f1f9_76 + f2f8_38 + f3f7_76 + f4f6_38 + f5f5_38 + h1 := f0f1_2 + f2f9_38 + f3f8_38 + f4f7_38 + f5f6_38 + h2 := f0f2_2 + f1f1_2 + f3f9_76 + f4f8_38 + f5f7_76 + f6f6_19 + h3 := f0f3_2 + f1f2_2 + f4f9_38 + f5f8_38 + f6f7_38 + h4 := f0f4_2 + f1f3_4 + f2f2 + f5f9_76 + f6f8_38 + f7f7_38 + h5 := f0f5_2 + f1f4_2 + f2f3_2 + f6f9_38 + f7f8_38 + h6 := f0f6_2 + f1f5_4 + f2f4_2 + f3f3_2 + f7f9_76 + f8f8_19 + h7 := f0f7_2 + f1f6_2 + f2f5_2 + f3f4_2 + f8f9_38 + h8 := f0f8_2 + f1f7_4 + f2f6_2 + f3f5_4 + f4f4 + f9f9_38 + h9 := f0f9_2 + f1f8_2 + f2f7_2 + f3f6_2 + f4f5_2 + var carry [10]int64 + + carry[0] = (h0 + (1 << 25)) >> 26 + h1 += carry[0] + h0 -= carry[0] << 26 + carry[4] = (h4 + (1 << 25)) >> 26 + h5 += carry[4] + h4 -= carry[4] << 26 + + carry[1] = (h1 + (1 << 24)) >> 25 + h2 += carry[1] + h1 -= carry[1] << 25 + carry[5] = (h5 + (1 << 24)) >> 25 + h6 += carry[5] + h5 -= carry[5] << 25 + + carry[2] = (h2 + (1 << 25)) >> 26 + h3 += carry[2] + h2 -= carry[2] << 26 + carry[6] = (h6 + (1 << 25)) >> 26 + h7 += carry[6] + h6 -= carry[6] << 26 + + carry[3] = (h3 + (1 << 24)) >> 25 + h4 += carry[3] + h3 -= carry[3] << 25 + carry[7] = (h7 + (1 << 24)) >> 25 + h8 += carry[7] + h7 -= carry[7] << 25 + + carry[4] = (h4 + (1 << 25)) >> 26 + h5 += carry[4] + h4 -= carry[4] << 26 + carry[8] = (h8 + (1 << 25)) >> 26 + h9 += carry[8] + h8 -= carry[8] << 26 + + carry[9] = (h9 + (1 << 24)) >> 25 + h0 += carry[9] * 19 + h9 -= carry[9] << 25 + + carry[0] = (h0 + (1 << 25)) >> 26 + h1 += carry[0] + h0 -= carry[0] << 26 + + h[0] = int32(h0) + h[1] = int32(h1) + h[2] = int32(h2) + h[3] = int32(h3) + h[4] = int32(h4) + h[5] = int32(h5) + h[6] = int32(h6) + h[7] = int32(h7) + h[8] = int32(h8) + h[9] = int32(h9) +} + +// feMul121666 calculates h = f * 121666. Can overlap h with f. +// +// Preconditions: +// |f| bounded by 1.1*2^26,1.1*2^25,1.1*2^26,1.1*2^25,etc. +// +// Postconditions: +// |h| bounded by 1.1*2^25,1.1*2^24,1.1*2^25,1.1*2^24,etc. +func feMul121666(h, f *fieldElement) { + h0 := int64(f[0]) * 121666 + h1 := int64(f[1]) * 121666 + h2 := int64(f[2]) * 121666 + h3 := int64(f[3]) * 121666 + h4 := int64(f[4]) * 121666 + h5 := int64(f[5]) * 121666 + h6 := int64(f[6]) * 121666 + h7 := int64(f[7]) * 121666 + h8 := int64(f[8]) * 121666 + h9 := int64(f[9]) * 121666 + var carry [10]int64 + + carry[9] = (h9 + (1 << 24)) >> 25 + h0 += carry[9] * 19 + h9 -= carry[9] << 25 + carry[1] = (h1 + (1 << 24)) >> 25 + h2 += carry[1] + h1 -= carry[1] << 25 + carry[3] = (h3 + (1 << 24)) >> 25 + h4 += carry[3] + h3 -= carry[3] << 25 + carry[5] = (h5 + (1 << 24)) >> 25 + h6 += carry[5] + h5 -= carry[5] << 25 + carry[7] = (h7 + (1 << 24)) >> 25 + h8 += carry[7] + h7 -= carry[7] << 25 + + carry[0] = (h0 + (1 << 25)) >> 26 + h1 += carry[0] + h0 -= carry[0] << 26 + carry[2] = (h2 + (1 << 25)) >> 26 + h3 += carry[2] + h2 -= carry[2] << 26 + carry[4] = (h4 + (1 << 25)) >> 26 + h5 += carry[4] + h4 -= carry[4] << 26 + carry[6] = (h6 + (1 << 25)) >> 26 + h7 += carry[6] + h6 -= carry[6] << 26 + carry[8] = (h8 + (1 << 25)) >> 26 + h9 += carry[8] + h8 -= carry[8] << 26 + + h[0] = int32(h0) + h[1] = int32(h1) + h[2] = int32(h2) + h[3] = int32(h3) + h[4] = int32(h4) + h[5] = int32(h5) + h[6] = int32(h6) + h[7] = int32(h7) + h[8] = int32(h8) + h[9] = int32(h9) +} + +// feInvert sets out = z^-1. +func feInvert(out, z *fieldElement) { + var t0, t1, t2, t3 fieldElement + var i int + + feSquare(&t0, z) + for i = 1; i < 1; i++ { + feSquare(&t0, &t0) + } + feSquare(&t1, &t0) + for i = 1; i < 2; i++ { + feSquare(&t1, &t1) + } + feMul(&t1, z, &t1) + feMul(&t0, &t0, &t1) + feSquare(&t2, &t0) + for i = 1; i < 1; i++ { + feSquare(&t2, &t2) + } + feMul(&t1, &t1, &t2) + feSquare(&t2, &t1) + for i = 1; i < 5; i++ { + feSquare(&t2, &t2) + } + feMul(&t1, &t2, &t1) + feSquare(&t2, &t1) + for i = 1; i < 10; i++ { + feSquare(&t2, &t2) + } + feMul(&t2, &t2, &t1) + feSquare(&t3, &t2) + for i = 1; i < 20; i++ { + feSquare(&t3, &t3) + } + feMul(&t2, &t3, &t2) + feSquare(&t2, &t2) + for i = 1; i < 10; i++ { + feSquare(&t2, &t2) + } + feMul(&t1, &t2, &t1) + feSquare(&t2, &t1) + for i = 1; i < 50; i++ { + feSquare(&t2, &t2) + } + feMul(&t2, &t2, &t1) + feSquare(&t3, &t2) + for i = 1; i < 100; i++ { + feSquare(&t3, &t3) + } + feMul(&t2, &t3, &t2) + feSquare(&t2, &t2) + for i = 1; i < 50; i++ { + feSquare(&t2, &t2) + } + feMul(&t1, &t2, &t1) + feSquare(&t1, &t1) + for i = 1; i < 5; i++ { + feSquare(&t1, &t1) + } + feMul(out, &t1, &t0) +} + +func scalarMultGeneric(out, in, base *[32]byte) { + var e [32]byte + + copy(e[:], in[:]) + e[0] &= 248 + e[31] &= 127 + e[31] |= 64 + + var x1, x2, z2, x3, z3, tmp0, tmp1 fieldElement + feFromBytes(&x1, base) + feOne(&x2) + feCopy(&x3, &x1) + feOne(&z3) + + swap := int32(0) + for pos := 254; pos >= 0; pos-- { + b := e[pos/8] >> uint(pos&7) + b &= 1 + swap ^= int32(b) + feCSwap(&x2, &x3, swap) + feCSwap(&z2, &z3, swap) + swap = int32(b) + + feSub(&tmp0, &x3, &z3) + feSub(&tmp1, &x2, &z2) + feAdd(&x2, &x2, &z2) + feAdd(&z2, &x3, &z3) + feMul(&z3, &tmp0, &x2) + feMul(&z2, &z2, &tmp1) + feSquare(&tmp0, &tmp1) + feSquare(&tmp1, &x2) + feAdd(&x3, &z3, &z2) + feSub(&z2, &z3, &z2) + feMul(&x2, &tmp1, &tmp0) + feSub(&tmp1, &tmp1, &tmp0) + feSquare(&z2, &z2) + feMul121666(&z3, &tmp1) + feSquare(&x3, &x3) + feAdd(&tmp0, &tmp0, &z3) + feMul(&z3, &x1, &z2) + feMul(&z2, &tmp1, &tmp0) + } + + feCSwap(&x2, &x3, swap) + feCSwap(&z2, &z3, swap) + + feInvert(&z2, &z2) + feMul(&x2, &x2, &z2) + feToBytes(out, &x2) +} diff --git a/internal/crypto/curve25519/curve25519_noasm.go b/internal/crypto/curve25519/curve25519_noasm.go new file mode 100644 index 000000000..047d49afc --- /dev/null +++ b/internal/crypto/curve25519/curve25519_noasm.go @@ -0,0 +1,11 @@ +// Copyright 2019 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build !amd64 gccgo appengine purego + +package curve25519 + +func scalarMult(out, in, base *[32]byte) { + scalarMultGeneric(out, in, base) +} diff --git a/internal/crypto/ed25519/ed25519.go b/internal/crypto/ed25519/ed25519.go new file mode 100644 index 000000000..c7f8c7e64 --- /dev/null +++ b/internal/crypto/ed25519/ed25519.go @@ -0,0 +1,222 @@ +// Copyright 2016 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// In Go 1.13, the ed25519 package was promoted to the standard library as +// crypto/ed25519, and this package became a wrapper for the standard library one. +// +// +build !go1.13 + +// Package ed25519 implements the Ed25519 signature algorithm. See +// https://ed25519.cr.yp.to/. +// +// These functions are also compatible with the “Ed25519” function defined in +// RFC 8032. However, unlike RFC 8032's formulation, this package's private key +// representation includes a public key suffix to make multiple signing +// operations with the same key more efficient. This package refers to the RFC +// 8032 private key as the “seed”. +package ed25519 + +// This code is a port of the public domain, “ref10” implementation of ed25519 +// from SUPERCOP. + +import ( + "bytes" + "crypto" + cryptorand "crypto/rand" + "crypto/sha512" + "errors" + "io" + "strconv" + + "golang.org/x/crypto/ed25519/internal/edwards25519" +) + +const ( + // PublicKeySize is the size, in bytes, of public keys as used in this package. + PublicKeySize = 32 + // PrivateKeySize is the size, in bytes, of private keys as used in this package. + PrivateKeySize = 64 + // SignatureSize is the size, in bytes, of signatures generated and verified by this package. + SignatureSize = 64 + // SeedSize is the size, in bytes, of private key seeds. These are the private key representations used by RFC 8032. + SeedSize = 32 +) + +// PublicKey is the type of Ed25519 public keys. +type PublicKey []byte + +// PrivateKey is the type of Ed25519 private keys. It implements crypto.Signer. +type PrivateKey []byte + +// Public returns the PublicKey corresponding to priv. +func (priv PrivateKey) Public() crypto.PublicKey { + publicKey := make([]byte, PublicKeySize) + copy(publicKey, priv[32:]) + return PublicKey(publicKey) +} + +// Seed returns the private key seed corresponding to priv. It is provided for +// interoperability with RFC 8032. RFC 8032's private keys correspond to seeds +// in this package. +func (priv PrivateKey) Seed() []byte { + seed := make([]byte, SeedSize) + copy(seed, priv[:32]) + return seed +} + +// Sign signs the given message with priv. +// Ed25519 performs two passes over messages to be signed and therefore cannot +// handle pre-hashed messages. Thus opts.HashFunc() must return zero to +// indicate the message hasn't been hashed. This can be achieved by passing +// crypto.Hash(0) as the value for opts. +func (priv PrivateKey) Sign(rand io.Reader, message []byte, opts crypto.SignerOpts) (signature []byte, err error) { + if opts.HashFunc() != crypto.Hash(0) { + return nil, errors.New("ed25519: cannot sign hashed message") + } + + return Sign(priv, message), nil +} + +// GenerateKey generates a public/private key pair using entropy from rand. +// If rand is nil, crypto/rand.Reader will be used. +func GenerateKey(rand io.Reader) (PublicKey, PrivateKey, error) { + if rand == nil { + rand = cryptorand.Reader + } + + seed := make([]byte, SeedSize) + if _, err := io.ReadFull(rand, seed); err != nil { + return nil, nil, err + } + + privateKey := NewKeyFromSeed(seed) + publicKey := make([]byte, PublicKeySize) + copy(publicKey, privateKey[32:]) + + return publicKey, privateKey, nil +} + +// NewKeyFromSeed calculates a private key from a seed. It will panic if +// len(seed) is not SeedSize. This function is provided for interoperability +// with RFC 8032. RFC 8032's private keys correspond to seeds in this +// package. +func NewKeyFromSeed(seed []byte) PrivateKey { + if l := len(seed); l != SeedSize { + panic("ed25519: bad seed length: " + strconv.Itoa(l)) + } + + digest := sha512.Sum512(seed) + digest[0] &= 248 + digest[31] &= 127 + digest[31] |= 64 + + var A edwards25519.ExtendedGroupElement + var hBytes [32]byte + copy(hBytes[:], digest[:]) + edwards25519.GeScalarMultBase(&A, &hBytes) + var publicKeyBytes [32]byte + A.ToBytes(&publicKeyBytes) + + privateKey := make([]byte, PrivateKeySize) + copy(privateKey, seed) + copy(privateKey[32:], publicKeyBytes[:]) + + return privateKey +} + +// Sign signs the message with privateKey and returns a signature. It will +// panic if len(privateKey) is not PrivateKeySize. +func Sign(privateKey PrivateKey, message []byte) []byte { + if l := len(privateKey); l != PrivateKeySize { + panic("ed25519: bad private key length: " + strconv.Itoa(l)) + } + + h := sha512.New() + h.Write(privateKey[:32]) + + var digest1, messageDigest, hramDigest [64]byte + var expandedSecretKey [32]byte + h.Sum(digest1[:0]) + copy(expandedSecretKey[:], digest1[:]) + expandedSecretKey[0] &= 248 + expandedSecretKey[31] &= 63 + expandedSecretKey[31] |= 64 + + h.Reset() + h.Write(digest1[32:]) + h.Write(message) + h.Sum(messageDigest[:0]) + + var messageDigestReduced [32]byte + edwards25519.ScReduce(&messageDigestReduced, &messageDigest) + var R edwards25519.ExtendedGroupElement + edwards25519.GeScalarMultBase(&R, &messageDigestReduced) + + var encodedR [32]byte + R.ToBytes(&encodedR) + + h.Reset() + h.Write(encodedR[:]) + h.Write(privateKey[32:]) + h.Write(message) + h.Sum(hramDigest[:0]) + var hramDigestReduced [32]byte + edwards25519.ScReduce(&hramDigestReduced, &hramDigest) + + var s [32]byte + edwards25519.ScMulAdd(&s, &hramDigestReduced, &expandedSecretKey, &messageDigestReduced) + + signature := make([]byte, SignatureSize) + copy(signature[:], encodedR[:]) + copy(signature[32:], s[:]) + + return signature +} + +// Verify reports whether sig is a valid signature of message by publicKey. It +// will panic if len(publicKey) is not PublicKeySize. +func Verify(publicKey PublicKey, message, sig []byte) bool { + if l := len(publicKey); l != PublicKeySize { + panic("ed25519: bad public key length: " + strconv.Itoa(l)) + } + + if len(sig) != SignatureSize || sig[63]&224 != 0 { + return false + } + + var A edwards25519.ExtendedGroupElement + var publicKeyBytes [32]byte + copy(publicKeyBytes[:], publicKey) + if !A.FromBytes(&publicKeyBytes) { + return false + } + edwards25519.FeNeg(&A.X, &A.X) + edwards25519.FeNeg(&A.T, &A.T) + + h := sha512.New() + h.Write(sig[:32]) + h.Write(publicKey[:]) + h.Write(message) + var digest [64]byte + h.Sum(digest[:0]) + + var hReduced [32]byte + edwards25519.ScReduce(&hReduced, &digest) + + var R edwards25519.ProjectiveGroupElement + var s [32]byte + copy(s[:], sig[32:]) + + // https://tools.ietf.org/html/rfc8032#section-5.1.7 requires that s be in + // the range [0, order) in order to prevent signature malleability. + if !edwards25519.ScMinimal(&s) { + return false + } + + edwards25519.GeDoubleScalarMultVartime(&R, &hReduced, &A, &s) + + var checkR [32]byte + R.ToBytes(&checkR) + return bytes.Equal(sig[:32], checkR[:]) +} diff --git a/internal/crypto/ed25519/ed25519_go113.go b/internal/crypto/ed25519/ed25519_go113.go new file mode 100644 index 000000000..d1448d8d2 --- /dev/null +++ b/internal/crypto/ed25519/ed25519_go113.go @@ -0,0 +1,73 @@ +// Copyright 2019 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build go1.13 + +// Package ed25519 implements the Ed25519 signature algorithm. See +// https://ed25519.cr.yp.to/. +// +// These functions are also compatible with the “Ed25519” function defined in +// RFC 8032. However, unlike RFC 8032's formulation, this package's private key +// representation includes a public key suffix to make multiple signing +// operations with the same key more efficient. This package refers to the RFC +// 8032 private key as the “seed”. +// +// Beginning with Go 1.13, the functionality of this package was moved to the +// standard library as crypto/ed25519. This package only acts as a compatibility +// wrapper. +package ed25519 + +import ( + "crypto/ed25519" + "io" +) + +const ( + // PublicKeySize is the size, in bytes, of public keys as used in this package. + PublicKeySize = 32 + // PrivateKeySize is the size, in bytes, of private keys as used in this package. + PrivateKeySize = 64 + // SignatureSize is the size, in bytes, of signatures generated and verified by this package. + SignatureSize = 64 + // SeedSize is the size, in bytes, of private key seeds. These are the private key representations used by RFC 8032. + SeedSize = 32 +) + +// PublicKey is the type of Ed25519 public keys. +// +// This type is an alias for crypto/ed25519's PublicKey type. +// See the crypto/ed25519 package for the methods on this type. +type PublicKey = ed25519.PublicKey + +// PrivateKey is the type of Ed25519 private keys. It implements crypto.Signer. +// +// This type is an alias for crypto/ed25519's PrivateKey type. +// See the crypto/ed25519 package for the methods on this type. +type PrivateKey = ed25519.PrivateKey + +// GenerateKey generates a public/private key pair using entropy from rand. +// If rand is nil, crypto/rand.Reader will be used. +func GenerateKey(rand io.Reader) (PublicKey, PrivateKey, error) { + return ed25519.GenerateKey(rand) +} + +// NewKeyFromSeed calculates a private key from a seed. It will panic if +// len(seed) is not SeedSize. This function is provided for interoperability +// with RFC 8032. RFC 8032's private keys correspond to seeds in this +// package. +func NewKeyFromSeed(seed []byte) PrivateKey { + return ed25519.NewKeyFromSeed(seed) +} + +// Sign signs the message with privateKey and returns a signature. It will +// panic if len(privateKey) is not PrivateKeySize. +func Sign(privateKey PrivateKey, message []byte) []byte { + return ed25519.Sign(privateKey, message) +} + +// Verify reports whether sig is a valid signature of message by publicKey. It +// will panic if len(publicKey) is not PublicKeySize. +func Verify(publicKey PublicKey, message, sig []byte) bool { + return ed25519.Verify(publicKey, message, sig) +} diff --git a/internal/crypto/ed25519/internal/edwards25519/const.go b/internal/crypto/ed25519/internal/edwards25519/const.go new file mode 100644 index 000000000..e39f086c1 --- /dev/null +++ b/internal/crypto/ed25519/internal/edwards25519/const.go @@ -0,0 +1,1422 @@ +// Copyright 2016 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package edwards25519 + +// These values are from the public domain, “ref10” implementation of ed25519 +// from SUPERCOP. + +// d is a constant in the Edwards curve equation. +var d = FieldElement{ + -10913610, 13857413, -15372611, 6949391, 114729, -8787816, -6275908, -3247719, -18696448, -12055116, +} + +// d2 is 2*d. +var d2 = FieldElement{ + -21827239, -5839606, -30745221, 13898782, 229458, 15978800, -12551817, -6495438, 29715968, 9444199, +} + +// SqrtM1 is the square-root of -1 in the field. +var SqrtM1 = FieldElement{ + -32595792, -7943725, 9377950, 3500415, 12389472, -272473, -25146209, -2005654, 326686, 11406482, +} + +// A is a constant in the Montgomery-form of curve25519. +var A = FieldElement{ + 486662, 0, 0, 0, 0, 0, 0, 0, 0, 0, +} + +// bi contains precomputed multiples of the base-point. See the Ed25519 paper +// for a discussion about how these values are used. +var bi = [8]PreComputedGroupElement{ + { + FieldElement{25967493, -14356035, 29566456, 3660896, -12694345, 4014787, 27544626, -11754271, -6079156, 2047605}, + FieldElement{-12545711, 934262, -2722910, 3049990, -727428, 9406986, 12720692, 5043384, 19500929, -15469378}, + FieldElement{-8738181, 4489570, 9688441, -14785194, 10184609, -12363380, 29287919, 11864899, -24514362, -4438546}, + }, + { + FieldElement{15636291, -9688557, 24204773, -7912398, 616977, -16685262, 27787600, -14772189, 28944400, -1550024}, + FieldElement{16568933, 4717097, -11556148, -1102322, 15682896, -11807043, 16354577, -11775962, 7689662, 11199574}, + FieldElement{30464156, -5976125, -11779434, -15670865, 23220365, 15915852, 7512774, 10017326, -17749093, -9920357}, + }, + { + FieldElement{10861363, 11473154, 27284546, 1981175, -30064349, 12577861, 32867885, 14515107, -15438304, 10819380}, + FieldElement{4708026, 6336745, 20377586, 9066809, -11272109, 6594696, -25653668, 12483688, -12668491, 5581306}, + FieldElement{19563160, 16186464, -29386857, 4097519, 10237984, -4348115, 28542350, 13850243, -23678021, -15815942}, + }, + { + FieldElement{5153746, 9909285, 1723747, -2777874, 30523605, 5516873, 19480852, 5230134, -23952439, -15175766}, + FieldElement{-30269007, -3463509, 7665486, 10083793, 28475525, 1649722, 20654025, 16520125, 30598449, 7715701}, + FieldElement{28881845, 14381568, 9657904, 3680757, -20181635, 7843316, -31400660, 1370708, 29794553, -1409300}, + }, + { + FieldElement{-22518993, -6692182, 14201702, -8745502, -23510406, 8844726, 18474211, -1361450, -13062696, 13821877}, + FieldElement{-6455177, -7839871, 3374702, -4740862, -27098617, -10571707, 31655028, -7212327, 18853322, -14220951}, + FieldElement{4566830, -12963868, -28974889, -12240689, -7602672, -2830569, -8514358, -10431137, 2207753, -3209784}, + }, + { + FieldElement{-25154831, -4185821, 29681144, 7868801, -6854661, -9423865, -12437364, -663000, -31111463, -16132436}, + FieldElement{25576264, -2703214, 7349804, -11814844, 16472782, 9300885, 3844789, 15725684, 171356, 6466918}, + FieldElement{23103977, 13316479, 9739013, -16149481, 817875, -15038942, 8965339, -14088058, -30714912, 16193877}, + }, + { + FieldElement{-33521811, 3180713, -2394130, 14003687, -16903474, -16270840, 17238398, 4729455, -18074513, 9256800}, + FieldElement{-25182317, -4174131, 32336398, 5036987, -21236817, 11360617, 22616405, 9761698, -19827198, 630305}, + FieldElement{-13720693, 2639453, -24237460, -7406481, 9494427, -5774029, -6554551, -15960994, -2449256, -14291300}, + }, + { + FieldElement{-3151181, -5046075, 9282714, 6866145, -31907062, -863023, -18940575, 15033784, 25105118, -7894876}, + FieldElement{-24326370, 15950226, -31801215, -14592823, -11662737, -5090925, 1573892, -2625887, 2198790, -15804619}, + FieldElement{-3099351, 10324967, -2241613, 7453183, -5446979, -2735503, -13812022, -16236442, -32461234, -12290683}, + }, +} + +// base contains precomputed multiples of the base-point. See the Ed25519 paper +// for a discussion about how these values are used. +var base = [32][8]PreComputedGroupElement{ + { + { + FieldElement{25967493, -14356035, 29566456, 3660896, -12694345, 4014787, 27544626, -11754271, -6079156, 2047605}, + FieldElement{-12545711, 934262, -2722910, 3049990, -727428, 9406986, 12720692, 5043384, 19500929, -15469378}, + FieldElement{-8738181, 4489570, 9688441, -14785194, 10184609, -12363380, 29287919, 11864899, -24514362, -4438546}, + }, + { + FieldElement{-12815894, -12976347, -21581243, 11784320, -25355658, -2750717, -11717903, -3814571, -358445, -10211303}, + FieldElement{-21703237, 6903825, 27185491, 6451973, -29577724, -9554005, -15616551, 11189268, -26829678, -5319081}, + FieldElement{26966642, 11152617, 32442495, 15396054, 14353839, -12752335, -3128826, -9541118, -15472047, -4166697}, + }, + { + FieldElement{15636291, -9688557, 24204773, -7912398, 616977, -16685262, 27787600, -14772189, 28944400, -1550024}, + FieldElement{16568933, 4717097, -11556148, -1102322, 15682896, -11807043, 16354577, -11775962, 7689662, 11199574}, + FieldElement{30464156, -5976125, -11779434, -15670865, 23220365, 15915852, 7512774, 10017326, -17749093, -9920357}, + }, + { + FieldElement{-17036878, 13921892, 10945806, -6033431, 27105052, -16084379, -28926210, 15006023, 3284568, -6276540}, + FieldElement{23599295, -8306047, -11193664, -7687416, 13236774, 10506355, 7464579, 9656445, 13059162, 10374397}, + FieldElement{7798556, 16710257, 3033922, 2874086, 28997861, 2835604, 32406664, -3839045, -641708, -101325}, + }, + { + FieldElement{10861363, 11473154, 27284546, 1981175, -30064349, 12577861, 32867885, 14515107, -15438304, 10819380}, + FieldElement{4708026, 6336745, 20377586, 9066809, -11272109, 6594696, -25653668, 12483688, -12668491, 5581306}, + FieldElement{19563160, 16186464, -29386857, 4097519, 10237984, -4348115, 28542350, 13850243, -23678021, -15815942}, + }, + { + FieldElement{-15371964, -12862754, 32573250, 4720197, -26436522, 5875511, -19188627, -15224819, -9818940, -12085777}, + FieldElement{-8549212, 109983, 15149363, 2178705, 22900618, 4543417, 3044240, -15689887, 1762328, 14866737}, + FieldElement{-18199695, -15951423, -10473290, 1707278, -17185920, 3916101, -28236412, 3959421, 27914454, 4383652}, + }, + { + FieldElement{5153746, 9909285, 1723747, -2777874, 30523605, 5516873, 19480852, 5230134, -23952439, -15175766}, + FieldElement{-30269007, -3463509, 7665486, 10083793, 28475525, 1649722, 20654025, 16520125, 30598449, 7715701}, + FieldElement{28881845, 14381568, 9657904, 3680757, -20181635, 7843316, -31400660, 1370708, 29794553, -1409300}, + }, + { + FieldElement{14499471, -2729599, -33191113, -4254652, 28494862, 14271267, 30290735, 10876454, -33154098, 2381726}, + FieldElement{-7195431, -2655363, -14730155, 462251, -27724326, 3941372, -6236617, 3696005, -32300832, 15351955}, + FieldElement{27431194, 8222322, 16448760, -3907995, -18707002, 11938355, -32961401, -2970515, 29551813, 10109425}, + }, + }, + { + { + FieldElement{-13657040, -13155431, -31283750, 11777098, 21447386, 6519384, -2378284, -1627556, 10092783, -4764171}, + FieldElement{27939166, 14210322, 4677035, 16277044, -22964462, -12398139, -32508754, 12005538, -17810127, 12803510}, + FieldElement{17228999, -15661624, -1233527, 300140, -1224870, -11714777, 30364213, -9038194, 18016357, 4397660}, + }, + { + FieldElement{-10958843, -7690207, 4776341, -14954238, 27850028, -15602212, -26619106, 14544525, -17477504, 982639}, + FieldElement{29253598, 15796703, -2863982, -9908884, 10057023, 3163536, 7332899, -4120128, -21047696, 9934963}, + FieldElement{5793303, 16271923, -24131614, -10116404, 29188560, 1206517, -14747930, 4559895, -30123922, -10897950}, + }, + { + FieldElement{-27643952, -11493006, 16282657, -11036493, 28414021, -15012264, 24191034, 4541697, -13338309, 5500568}, + FieldElement{12650548, -1497113, 9052871, 11355358, -17680037, -8400164, -17430592, 12264343, 10874051, 13524335}, + FieldElement{25556948, -3045990, 714651, 2510400, 23394682, -10415330, 33119038, 5080568, -22528059, 5376628}, + }, + { + FieldElement{-26088264, -4011052, -17013699, -3537628, -6726793, 1920897, -22321305, -9447443, 4535768, 1569007}, + FieldElement{-2255422, 14606630, -21692440, -8039818, 28430649, 8775819, -30494562, 3044290, 31848280, 12543772}, + FieldElement{-22028579, 2943893, -31857513, 6777306, 13784462, -4292203, -27377195, -2062731, 7718482, 14474653}, + }, + { + FieldElement{2385315, 2454213, -22631320, 46603, -4437935, -15680415, 656965, -7236665, 24316168, -5253567}, + FieldElement{13741529, 10911568, -33233417, -8603737, -20177830, -1033297, 33040651, -13424532, -20729456, 8321686}, + FieldElement{21060490, -2212744, 15712757, -4336099, 1639040, 10656336, 23845965, -11874838, -9984458, 608372}, + }, + { + FieldElement{-13672732, -15087586, -10889693, -7557059, -6036909, 11305547, 1123968, -6780577, 27229399, 23887}, + FieldElement{-23244140, -294205, -11744728, 14712571, -29465699, -2029617, 12797024, -6440308, -1633405, 16678954}, + FieldElement{-29500620, 4770662, -16054387, 14001338, 7830047, 9564805, -1508144, -4795045, -17169265, 4904953}, + }, + { + FieldElement{24059557, 14617003, 19037157, -15039908, 19766093, -14906429, 5169211, 16191880, 2128236, -4326833}, + FieldElement{-16981152, 4124966, -8540610, -10653797, 30336522, -14105247, -29806336, 916033, -6882542, -2986532}, + FieldElement{-22630907, 12419372, -7134229, -7473371, -16478904, 16739175, 285431, 2763829, 15736322, 4143876}, + }, + { + FieldElement{2379352, 11839345, -4110402, -5988665, 11274298, 794957, 212801, -14594663, 23527084, -16458268}, + FieldElement{33431127, -11130478, -17838966, -15626900, 8909499, 8376530, -32625340, 4087881, -15188911, -14416214}, + FieldElement{1767683, 7197987, -13205226, -2022635, -13091350, 448826, 5799055, 4357868, -4774191, -16323038}, + }, + }, + { + { + FieldElement{6721966, 13833823, -23523388, -1551314, 26354293, -11863321, 23365147, -3949732, 7390890, 2759800}, + FieldElement{4409041, 2052381, 23373853, 10530217, 7676779, -12885954, 21302353, -4264057, 1244380, -12919645}, + FieldElement{-4421239, 7169619, 4982368, -2957590, 30256825, -2777540, 14086413, 9208236, 15886429, 16489664}, + }, + { + FieldElement{1996075, 10375649, 14346367, 13311202, -6874135, -16438411, -13693198, 398369, -30606455, -712933}, + FieldElement{-25307465, 9795880, -2777414, 14878809, -33531835, 14780363, 13348553, 12076947, -30836462, 5113182}, + FieldElement{-17770784, 11797796, 31950843, 13929123, -25888302, 12288344, -30341101, -7336386, 13847711, 5387222}, + }, + { + FieldElement{-18582163, -3416217, 17824843, -2340966, 22744343, -10442611, 8763061, 3617786, -19600662, 10370991}, + FieldElement{20246567, -14369378, 22358229, -543712, 18507283, -10413996, 14554437, -8746092, 32232924, 16763880}, + FieldElement{9648505, 10094563, 26416693, 14745928, -30374318, -6472621, 11094161, 15689506, 3140038, -16510092}, + }, + { + FieldElement{-16160072, 5472695, 31895588, 4744994, 8823515, 10365685, -27224800, 9448613, -28774454, 366295}, + FieldElement{19153450, 11523972, -11096490, -6503142, -24647631, 5420647, 28344573, 8041113, 719605, 11671788}, + FieldElement{8678025, 2694440, -6808014, 2517372, 4964326, 11152271, -15432916, -15266516, 27000813, -10195553}, + }, + { + FieldElement{-15157904, 7134312, 8639287, -2814877, -7235688, 10421742, 564065, 5336097, 6750977, -14521026}, + FieldElement{11836410, -3979488, 26297894, 16080799, 23455045, 15735944, 1695823, -8819122, 8169720, 16220347}, + FieldElement{-18115838, 8653647, 17578566, -6092619, -8025777, -16012763, -11144307, -2627664, -5990708, -14166033}, + }, + { + FieldElement{-23308498, -10968312, 15213228, -10081214, -30853605, -11050004, 27884329, 2847284, 2655861, 1738395}, + FieldElement{-27537433, -14253021, -25336301, -8002780, -9370762, 8129821, 21651608, -3239336, -19087449, -11005278}, + FieldElement{1533110, 3437855, 23735889, 459276, 29970501, 11335377, 26030092, 5821408, 10478196, 8544890}, + }, + { + FieldElement{32173121, -16129311, 24896207, 3921497, 22579056, -3410854, 19270449, 12217473, 17789017, -3395995}, + FieldElement{-30552961, -2228401, -15578829, -10147201, 13243889, 517024, 15479401, -3853233, 30460520, 1052596}, + FieldElement{-11614875, 13323618, 32618793, 8175907, -15230173, 12596687, 27491595, -4612359, 3179268, -9478891}, + }, + { + FieldElement{31947069, -14366651, -4640583, -15339921, -15125977, -6039709, -14756777, -16411740, 19072640, -9511060}, + FieldElement{11685058, 11822410, 3158003, -13952594, 33402194, -4165066, 5977896, -5215017, 473099, 5040608}, + FieldElement{-20290863, 8198642, -27410132, 11602123, 1290375, -2799760, 28326862, 1721092, -19558642, -3131606}, + }, + }, + { + { + FieldElement{7881532, 10687937, 7578723, 7738378, -18951012, -2553952, 21820786, 8076149, -27868496, 11538389}, + FieldElement{-19935666, 3899861, 18283497, -6801568, -15728660, -11249211, 8754525, 7446702, -5676054, 5797016}, + FieldElement{-11295600, -3793569, -15782110, -7964573, 12708869, -8456199, 2014099, -9050574, -2369172, -5877341}, + }, + { + FieldElement{-22472376, -11568741, -27682020, 1146375, 18956691, 16640559, 1192730, -3714199, 15123619, 10811505}, + FieldElement{14352098, -3419715, -18942044, 10822655, 32750596, 4699007, -70363, 15776356, -28886779, -11974553}, + FieldElement{-28241164, -8072475, -4978962, -5315317, 29416931, 1847569, -20654173, -16484855, 4714547, -9600655}, + }, + { + FieldElement{15200332, 8368572, 19679101, 15970074, -31872674, 1959451, 24611599, -4543832, -11745876, 12340220}, + FieldElement{12876937, -10480056, 33134381, 6590940, -6307776, 14872440, 9613953, 8241152, 15370987, 9608631}, + FieldElement{-4143277, -12014408, 8446281, -391603, 4407738, 13629032, -7724868, 15866074, -28210621, -8814099}, + }, + { + FieldElement{26660628, -15677655, 8393734, 358047, -7401291, 992988, -23904233, 858697, 20571223, 8420556}, + FieldElement{14620715, 13067227, -15447274, 8264467, 14106269, 15080814, 33531827, 12516406, -21574435, -12476749}, + FieldElement{236881, 10476226, 57258, -14677024, 6472998, 2466984, 17258519, 7256740, 8791136, 15069930}, + }, + { + FieldElement{1276410, -9371918, 22949635, -16322807, -23493039, -5702186, 14711875, 4874229, -30663140, -2331391}, + FieldElement{5855666, 4990204, -13711848, 7294284, -7804282, 1924647, -1423175, -7912378, -33069337, 9234253}, + FieldElement{20590503, -9018988, 31529744, -7352666, -2706834, 10650548, 31559055, -11609587, 18979186, 13396066}, + }, + { + FieldElement{24474287, 4968103, 22267082, 4407354, 24063882, -8325180, -18816887, 13594782, 33514650, 7021958}, + FieldElement{-11566906, -6565505, -21365085, 15928892, -26158305, 4315421, -25948728, -3916677, -21480480, 12868082}, + FieldElement{-28635013, 13504661, 19988037, -2132761, 21078225, 6443208, -21446107, 2244500, -12455797, -8089383}, + }, + { + FieldElement{-30595528, 13793479, -5852820, 319136, -25723172, -6263899, 33086546, 8957937, -15233648, 5540521}, + FieldElement{-11630176, -11503902, -8119500, -7643073, 2620056, 1022908, -23710744, -1568984, -16128528, -14962807}, + FieldElement{23152971, 775386, 27395463, 14006635, -9701118, 4649512, 1689819, 892185, -11513277, -15205948}, + }, + { + FieldElement{9770129, 9586738, 26496094, 4324120, 1556511, -3550024, 27453819, 4763127, -19179614, 5867134}, + FieldElement{-32765025, 1927590, 31726409, -4753295, 23962434, -16019500, 27846559, 5931263, -29749703, -16108455}, + FieldElement{27461885, -2977536, 22380810, 1815854, -23033753, -3031938, 7283490, -15148073, -19526700, 7734629}, + }, + }, + { + { + FieldElement{-8010264, -9590817, -11120403, 6196038, 29344158, -13430885, 7585295, -3176626, 18549497, 15302069}, + FieldElement{-32658337, -6171222, -7672793, -11051681, 6258878, 13504381, 10458790, -6418461, -8872242, 8424746}, + FieldElement{24687205, 8613276, -30667046, -3233545, 1863892, -1830544, 19206234, 7134917, -11284482, -828919}, + }, + { + FieldElement{11334899, -9218022, 8025293, 12707519, 17523892, -10476071, 10243738, -14685461, -5066034, 16498837}, + FieldElement{8911542, 6887158, -9584260, -6958590, 11145641, -9543680, 17303925, -14124238, 6536641, 10543906}, + FieldElement{-28946384, 15479763, -17466835, 568876, -1497683, 11223454, -2669190, -16625574, -27235709, 8876771}, + }, + { + FieldElement{-25742899, -12566864, -15649966, -846607, -33026686, -796288, -33481822, 15824474, -604426, -9039817}, + FieldElement{10330056, 70051, 7957388, -9002667, 9764902, 15609756, 27698697, -4890037, 1657394, 3084098}, + FieldElement{10477963, -7470260, 12119566, -13250805, 29016247, -5365589, 31280319, 14396151, -30233575, 15272409}, + }, + { + FieldElement{-12288309, 3169463, 28813183, 16658753, 25116432, -5630466, -25173957, -12636138, -25014757, 1950504}, + FieldElement{-26180358, 9489187, 11053416, -14746161, -31053720, 5825630, -8384306, -8767532, 15341279, 8373727}, + FieldElement{28685821, 7759505, -14378516, -12002860, -31971820, 4079242, 298136, -10232602, -2878207, 15190420}, + }, + { + FieldElement{-32932876, 13806336, -14337485, -15794431, -24004620, 10940928, 8669718, 2742393, -26033313, -6875003}, + FieldElement{-1580388, -11729417, -25979658, -11445023, -17411874, -10912854, 9291594, -16247779, -12154742, 6048605}, + FieldElement{-30305315, 14843444, 1539301, 11864366, 20201677, 1900163, 13934231, 5128323, 11213262, 9168384}, + }, + { + FieldElement{-26280513, 11007847, 19408960, -940758, -18592965, -4328580, -5088060, -11105150, 20470157, -16398701}, + FieldElement{-23136053, 9282192, 14855179, -15390078, -7362815, -14408560, -22783952, 14461608, 14042978, 5230683}, + FieldElement{29969567, -2741594, -16711867, -8552442, 9175486, -2468974, 21556951, 3506042, -5933891, -12449708}, + }, + { + FieldElement{-3144746, 8744661, 19704003, 4581278, -20430686, 6830683, -21284170, 8971513, -28539189, 15326563}, + FieldElement{-19464629, 10110288, -17262528, -3503892, -23500387, 1355669, -15523050, 15300988, -20514118, 9168260}, + FieldElement{-5353335, 4488613, -23803248, 16314347, 7780487, -15638939, -28948358, 9601605, 33087103, -9011387}, + }, + { + FieldElement{-19443170, -15512900, -20797467, -12445323, -29824447, 10229461, -27444329, -15000531, -5996870, 15664672}, + FieldElement{23294591, -16632613, -22650781, -8470978, 27844204, 11461195, 13099750, -2460356, 18151676, 13417686}, + FieldElement{-24722913, -4176517, -31150679, 5988919, -26858785, 6685065, 1661597, -12551441, 15271676, -15452665}, + }, + }, + { + { + FieldElement{11433042, -13228665, 8239631, -5279517, -1985436, -725718, -18698764, 2167544, -6921301, -13440182}, + FieldElement{-31436171, 15575146, 30436815, 12192228, -22463353, 9395379, -9917708, -8638997, 12215110, 12028277}, + FieldElement{14098400, 6555944, 23007258, 5757252, -15427832, -12950502, 30123440, 4617780, -16900089, -655628}, + }, + { + FieldElement{-4026201, -15240835, 11893168, 13718664, -14809462, 1847385, -15819999, 10154009, 23973261, -12684474}, + FieldElement{-26531820, -3695990, -1908898, 2534301, -31870557, -16550355, 18341390, -11419951, 32013174, -10103539}, + FieldElement{-25479301, 10876443, -11771086, -14625140, -12369567, 1838104, 21911214, 6354752, 4425632, -837822}, + }, + { + FieldElement{-10433389, -14612966, 22229858, -3091047, -13191166, 776729, -17415375, -12020462, 4725005, 14044970}, + FieldElement{19268650, -7304421, 1555349, 8692754, -21474059, -9910664, 6347390, -1411784, -19522291, -16109756}, + FieldElement{-24864089, 12986008, -10898878, -5558584, -11312371, -148526, 19541418, 8180106, 9282262, 10282508}, + }, + { + FieldElement{-26205082, 4428547, -8661196, -13194263, 4098402, -14165257, 15522535, 8372215, 5542595, -10702683}, + FieldElement{-10562541, 14895633, 26814552, -16673850, -17480754, -2489360, -2781891, 6993761, -18093885, 10114655}, + FieldElement{-20107055, -929418, 31422704, 10427861, -7110749, 6150669, -29091755, -11529146, 25953725, -106158}, + }, + { + FieldElement{-4234397, -8039292, -9119125, 3046000, 2101609, -12607294, 19390020, 6094296, -3315279, 12831125}, + FieldElement{-15998678, 7578152, 5310217, 14408357, -33548620, -224739, 31575954, 6326196, 7381791, -2421839}, + FieldElement{-20902779, 3296811, 24736065, -16328389, 18374254, 7318640, 6295303, 8082724, -15362489, 12339664}, + }, + { + FieldElement{27724736, 2291157, 6088201, -14184798, 1792727, 5857634, 13848414, 15768922, 25091167, 14856294}, + FieldElement{-18866652, 8331043, 24373479, 8541013, -701998, -9269457, 12927300, -12695493, -22182473, -9012899}, + FieldElement{-11423429, -5421590, 11632845, 3405020, 30536730, -11674039, -27260765, 13866390, 30146206, 9142070}, + }, + { + FieldElement{3924129, -15307516, -13817122, -10054960, 12291820, -668366, -27702774, 9326384, -8237858, 4171294}, + FieldElement{-15921940, 16037937, 6713787, 16606682, -21612135, 2790944, 26396185, 3731949, 345228, -5462949}, + FieldElement{-21327538, 13448259, 25284571, 1143661, 20614966, -8849387, 2031539, -12391231, -16253183, -13582083}, + }, + { + FieldElement{31016211, -16722429, 26371392, -14451233, -5027349, 14854137, 17477601, 3842657, 28012650, -16405420}, + FieldElement{-5075835, 9368966, -8562079, -4600902, -15249953, 6970560, -9189873, 16292057, -8867157, 3507940}, + FieldElement{29439664, 3537914, 23333589, 6997794, -17555561, -11018068, -15209202, -15051267, -9164929, 6580396}, + }, + }, + { + { + FieldElement{-12185861, -7679788, 16438269, 10826160, -8696817, -6235611, 17860444, -9273846, -2095802, 9304567}, + FieldElement{20714564, -4336911, 29088195, 7406487, 11426967, -5095705, 14792667, -14608617, 5289421, -477127}, + FieldElement{-16665533, -10650790, -6160345, -13305760, 9192020, -1802462, 17271490, 12349094, 26939669, -3752294}, + }, + { + FieldElement{-12889898, 9373458, 31595848, 16374215, 21471720, 13221525, -27283495, -12348559, -3698806, 117887}, + FieldElement{22263325, -6560050, 3984570, -11174646, -15114008, -566785, 28311253, 5358056, -23319780, 541964}, + FieldElement{16259219, 3261970, 2309254, -15534474, -16885711, -4581916, 24134070, -16705829, -13337066, -13552195}, + }, + { + FieldElement{9378160, -13140186, -22845982, -12745264, 28198281, -7244098, -2399684, -717351, 690426, 14876244}, + FieldElement{24977353, -314384, -8223969, -13465086, 28432343, -1176353, -13068804, -12297348, -22380984, 6618999}, + FieldElement{-1538174, 11685646, 12944378, 13682314, -24389511, -14413193, 8044829, -13817328, 32239829, -5652762}, + }, + { + FieldElement{-18603066, 4762990, -926250, 8885304, -28412480, -3187315, 9781647, -10350059, 32779359, 5095274}, + FieldElement{-33008130, -5214506, -32264887, -3685216, 9460461, -9327423, -24601656, 14506724, 21639561, -2630236}, + FieldElement{-16400943, -13112215, 25239338, 15531969, 3987758, -4499318, -1289502, -6863535, 17874574, 558605}, + }, + { + FieldElement{-13600129, 10240081, 9171883, 16131053, -20869254, 9599700, 33499487, 5080151, 2085892, 5119761}, + FieldElement{-22205145, -2519528, -16381601, 414691, -25019550, 2170430, 30634760, -8363614, -31999993, -5759884}, + FieldElement{-6845704, 15791202, 8550074, -1312654, 29928809, -12092256, 27534430, -7192145, -22351378, 12961482}, + }, + { + FieldElement{-24492060, -9570771, 10368194, 11582341, -23397293, -2245287, 16533930, 8206996, -30194652, -5159638}, + FieldElement{-11121496, -3382234, 2307366, 6362031, -135455, 8868177, -16835630, 7031275, 7589640, 8945490}, + FieldElement{-32152748, 8917967, 6661220, -11677616, -1192060, -15793393, 7251489, -11182180, 24099109, -14456170}, + }, + { + FieldElement{5019558, -7907470, 4244127, -14714356, -26933272, 6453165, -19118182, -13289025, -6231896, -10280736}, + FieldElement{10853594, 10721687, 26480089, 5861829, -22995819, 1972175, -1866647, -10557898, -3363451, -6441124}, + FieldElement{-17002408, 5906790, 221599, -6563147, 7828208, -13248918, 24362661, -2008168, -13866408, 7421392}, + }, + { + FieldElement{8139927, -6546497, 32257646, -5890546, 30375719, 1886181, -21175108, 15441252, 28826358, -4123029}, + FieldElement{6267086, 9695052, 7709135, -16603597, -32869068, -1886135, 14795160, -7840124, 13746021, -1742048}, + FieldElement{28584902, 7787108, -6732942, -15050729, 22846041, -7571236, -3181936, -363524, 4771362, -8419958}, + }, + }, + { + { + FieldElement{24949256, 6376279, -27466481, -8174608, -18646154, -9930606, 33543569, -12141695, 3569627, 11342593}, + FieldElement{26514989, 4740088, 27912651, 3697550, 19331575, -11472339, 6809886, 4608608, 7325975, -14801071}, + FieldElement{-11618399, -14554430, -24321212, 7655128, -1369274, 5214312, -27400540, 10258390, -17646694, -8186692}, + }, + { + FieldElement{11431204, 15823007, 26570245, 14329124, 18029990, 4796082, -31446179, 15580664, 9280358, -3973687}, + FieldElement{-160783, -10326257, -22855316, -4304997, -20861367, -13621002, -32810901, -11181622, -15545091, 4387441}, + FieldElement{-20799378, 12194512, 3937617, -5805892, -27154820, 9340370, -24513992, 8548137, 20617071, -7482001}, + }, + { + FieldElement{-938825, -3930586, -8714311, 16124718, 24603125, -6225393, -13775352, -11875822, 24345683, 10325460}, + FieldElement{-19855277, -1568885, -22202708, 8714034, 14007766, 6928528, 16318175, -1010689, 4766743, 3552007}, + FieldElement{-21751364, -16730916, 1351763, -803421, -4009670, 3950935, 3217514, 14481909, 10988822, -3994762}, + }, + { + FieldElement{15564307, -14311570, 3101243, 5684148, 30446780, -8051356, 12677127, -6505343, -8295852, 13296005}, + FieldElement{-9442290, 6624296, -30298964, -11913677, -4670981, -2057379, 31521204, 9614054, -30000824, 12074674}, + FieldElement{4771191, -135239, 14290749, -13089852, 27992298, 14998318, -1413936, -1556716, 29832613, -16391035}, + }, + { + FieldElement{7064884, -7541174, -19161962, -5067537, -18891269, -2912736, 25825242, 5293297, -27122660, 13101590}, + FieldElement{-2298563, 2439670, -7466610, 1719965, -27267541, -16328445, 32512469, -5317593, -30356070, -4190957}, + FieldElement{-30006540, 10162316, -33180176, 3981723, -16482138, -13070044, 14413974, 9515896, 19568978, 9628812}, + }, + { + FieldElement{33053803, 199357, 15894591, 1583059, 27380243, -4580435, -17838894, -6106839, -6291786, 3437740}, + FieldElement{-18978877, 3884493, 19469877, 12726490, 15913552, 13614290, -22961733, 70104, 7463304, 4176122}, + FieldElement{-27124001, 10659917, 11482427, -16070381, 12771467, -6635117, -32719404, -5322751, 24216882, 5944158}, + }, + { + FieldElement{8894125, 7450974, -2664149, -9765752, -28080517, -12389115, 19345746, 14680796, 11632993, 5847885}, + FieldElement{26942781, -2315317, 9129564, -4906607, 26024105, 11769399, -11518837, 6367194, -9727230, 4782140}, + FieldElement{19916461, -4828410, -22910704, -11414391, 25606324, -5972441, 33253853, 8220911, 6358847, -1873857}, + }, + { + FieldElement{801428, -2081702, 16569428, 11065167, 29875704, 96627, 7908388, -4480480, -13538503, 1387155}, + FieldElement{19646058, 5720633, -11416706, 12814209, 11607948, 12749789, 14147075, 15156355, -21866831, 11835260}, + FieldElement{19299512, 1155910, 28703737, 14890794, 2925026, 7269399, 26121523, 15467869, -26560550, 5052483}, + }, + }, + { + { + FieldElement{-3017432, 10058206, 1980837, 3964243, 22160966, 12322533, -6431123, -12618185, 12228557, -7003677}, + FieldElement{32944382, 14922211, -22844894, 5188528, 21913450, -8719943, 4001465, 13238564, -6114803, 8653815}, + FieldElement{22865569, -4652735, 27603668, -12545395, 14348958, 8234005, 24808405, 5719875, 28483275, 2841751}, + }, + { + FieldElement{-16420968, -1113305, -327719, -12107856, 21886282, -15552774, -1887966, -315658, 19932058, -12739203}, + FieldElement{-11656086, 10087521, -8864888, -5536143, -19278573, -3055912, 3999228, 13239134, -4777469, -13910208}, + FieldElement{1382174, -11694719, 17266790, 9194690, -13324356, 9720081, 20403944, 11284705, -14013818, 3093230}, + }, + { + FieldElement{16650921, -11037932, -1064178, 1570629, -8329746, 7352753, -302424, 16271225, -24049421, -6691850}, + FieldElement{-21911077, -5927941, -4611316, -5560156, -31744103, -10785293, 24123614, 15193618, -21652117, -16739389}, + FieldElement{-9935934, -4289447, -25279823, 4372842, 2087473, 10399484, 31870908, 14690798, 17361620, 11864968}, + }, + { + FieldElement{-11307610, 6210372, 13206574, 5806320, -29017692, -13967200, -12331205, -7486601, -25578460, -16240689}, + FieldElement{14668462, -12270235, 26039039, 15305210, 25515617, 4542480, 10453892, 6577524, 9145645, -6443880}, + FieldElement{5974874, 3053895, -9433049, -10385191, -31865124, 3225009, -7972642, 3936128, -5652273, -3050304}, + }, + { + FieldElement{30625386, -4729400, -25555961, -12792866, -20484575, 7695099, 17097188, -16303496, -27999779, 1803632}, + FieldElement{-3553091, 9865099, -5228566, 4272701, -5673832, -16689700, 14911344, 12196514, -21405489, 7047412}, + FieldElement{20093277, 9920966, -11138194, -5343857, 13161587, 12044805, -32856851, 4124601, -32343828, -10257566}, + }, + { + FieldElement{-20788824, 14084654, -13531713, 7842147, 19119038, -13822605, 4752377, -8714640, -21679658, 2288038}, + FieldElement{-26819236, -3283715, 29965059, 3039786, -14473765, 2540457, 29457502, 14625692, -24819617, 12570232}, + FieldElement{-1063558, -11551823, 16920318, 12494842, 1278292, -5869109, -21159943, -3498680, -11974704, 4724943}, + }, + { + FieldElement{17960970, -11775534, -4140968, -9702530, -8876562, -1410617, -12907383, -8659932, -29576300, 1903856}, + FieldElement{23134274, -14279132, -10681997, -1611936, 20684485, 15770816, -12989750, 3190296, 26955097, 14109738}, + FieldElement{15308788, 5320727, -30113809, -14318877, 22902008, 7767164, 29425325, -11277562, 31960942, 11934971}, + }, + { + FieldElement{-27395711, 8435796, 4109644, 12222639, -24627868, 14818669, 20638173, 4875028, 10491392, 1379718}, + FieldElement{-13159415, 9197841, 3875503, -8936108, -1383712, -5879801, 33518459, 16176658, 21432314, 12180697}, + FieldElement{-11787308, 11500838, 13787581, -13832590, -22430679, 10140205, 1465425, 12689540, -10301319, -13872883}, + }, + }, + { + { + FieldElement{5414091, -15386041, -21007664, 9643570, 12834970, 1186149, -2622916, -1342231, 26128231, 6032912}, + FieldElement{-26337395, -13766162, 32496025, -13653919, 17847801, -12669156, 3604025, 8316894, -25875034, -10437358}, + FieldElement{3296484, 6223048, 24680646, -12246460, -23052020, 5903205, -8862297, -4639164, 12376617, 3188849}, + }, + { + FieldElement{29190488, -14659046, 27549113, -1183516, 3520066, -10697301, 32049515, -7309113, -16109234, -9852307}, + FieldElement{-14744486, -9309156, 735818, -598978, -20407687, -5057904, 25246078, -15795669, 18640741, -960977}, + FieldElement{-6928835, -16430795, 10361374, 5642961, 4910474, 12345252, -31638386, -494430, 10530747, 1053335}, + }, + { + FieldElement{-29265967, -14186805, -13538216, -12117373, -19457059, -10655384, -31462369, -2948985, 24018831, 15026644}, + FieldElement{-22592535, -3145277, -2289276, 5953843, -13440189, 9425631, 25310643, 13003497, -2314791, -15145616}, + FieldElement{-27419985, -603321, -8043984, -1669117, -26092265, 13987819, -27297622, 187899, -23166419, -2531735}, + }, + { + FieldElement{-21744398, -13810475, 1844840, 5021428, -10434399, -15911473, 9716667, 16266922, -5070217, 726099}, + FieldElement{29370922, -6053998, 7334071, -15342259, 9385287, 2247707, -13661962, -4839461, 30007388, -15823341}, + FieldElement{-936379, 16086691, 23751945, -543318, -1167538, -5189036, 9137109, 730663, 9835848, 4555336}, + }, + { + FieldElement{-23376435, 1410446, -22253753, -12899614, 30867635, 15826977, 17693930, 544696, -11985298, 12422646}, + FieldElement{31117226, -12215734, -13502838, 6561947, -9876867, -12757670, -5118685, -4096706, 29120153, 13924425}, + FieldElement{-17400879, -14233209, 19675799, -2734756, -11006962, -5858820, -9383939, -11317700, 7240931, -237388}, + }, + { + FieldElement{-31361739, -11346780, -15007447, -5856218, -22453340, -12152771, 1222336, 4389483, 3293637, -15551743}, + FieldElement{-16684801, -14444245, 11038544, 11054958, -13801175, -3338533, -24319580, 7733547, 12796905, -6335822}, + FieldElement{-8759414, -10817836, -25418864, 10783769, -30615557, -9746811, -28253339, 3647836, 3222231, -11160462}, + }, + { + FieldElement{18606113, 1693100, -25448386, -15170272, 4112353, 10045021, 23603893, -2048234, -7550776, 2484985}, + FieldElement{9255317, -3131197, -12156162, -1004256, 13098013, -9214866, 16377220, -2102812, -19802075, -3034702}, + FieldElement{-22729289, 7496160, -5742199, 11329249, 19991973, -3347502, -31718148, 9936966, -30097688, -10618797}, + }, + { + FieldElement{21878590, -5001297, 4338336, 13643897, -3036865, 13160960, 19708896, 5415497, -7360503, -4109293}, + FieldElement{27736861, 10103576, 12500508, 8502413, -3413016, -9633558, 10436918, -1550276, -23659143, -8132100}, + FieldElement{19492550, -12104365, -29681976, -852630, -3208171, 12403437, 30066266, 8367329, 13243957, 8709688}, + }, + }, + { + { + FieldElement{12015105, 2801261, 28198131, 10151021, 24818120, -4743133, -11194191, -5645734, 5150968, 7274186}, + FieldElement{2831366, -12492146, 1478975, 6122054, 23825128, -12733586, 31097299, 6083058, 31021603, -9793610}, + FieldElement{-2529932, -2229646, 445613, 10720828, -13849527, -11505937, -23507731, 16354465, 15067285, -14147707}, + }, + { + FieldElement{7840942, 14037873, -33364863, 15934016, -728213, -3642706, 21403988, 1057586, -19379462, -12403220}, + FieldElement{915865, -16469274, 15608285, -8789130, -24357026, 6060030, -17371319, 8410997, -7220461, 16527025}, + FieldElement{32922597, -556987, 20336074, -16184568, 10903705, -5384487, 16957574, 52992, 23834301, 6588044}, + }, + { + FieldElement{32752030, 11232950, 3381995, -8714866, 22652988, -10744103, 17159699, 16689107, -20314580, -1305992}, + FieldElement{-4689649, 9166776, -25710296, -10847306, 11576752, 12733943, 7924251, -2752281, 1976123, -7249027}, + FieldElement{21251222, 16309901, -2983015, -6783122, 30810597, 12967303, 156041, -3371252, 12331345, -8237197}, + }, + { + FieldElement{8651614, -4477032, -16085636, -4996994, 13002507, 2950805, 29054427, -5106970, 10008136, -4667901}, + FieldElement{31486080, 15114593, -14261250, 12951354, 14369431, -7387845, 16347321, -13662089, 8684155, -10532952}, + FieldElement{19443825, 11385320, 24468943, -9659068, -23919258, 2187569, -26263207, -6086921, 31316348, 14219878}, + }, + { + FieldElement{-28594490, 1193785, 32245219, 11392485, 31092169, 15722801, 27146014, 6992409, 29126555, 9207390}, + FieldElement{32382935, 1110093, 18477781, 11028262, -27411763, -7548111, -4980517, 10843782, -7957600, -14435730}, + FieldElement{2814918, 7836403, 27519878, -7868156, -20894015, -11553689, -21494559, 8550130, 28346258, 1994730}, + }, + { + FieldElement{-19578299, 8085545, -14000519, -3948622, 2785838, -16231307, -19516951, 7174894, 22628102, 8115180}, + FieldElement{-30405132, 955511, -11133838, -15078069, -32447087, -13278079, -25651578, 3317160, -9943017, 930272}, + FieldElement{-15303681, -6833769, 28856490, 1357446, 23421993, 1057177, 24091212, -1388970, -22765376, -10650715}, + }, + { + FieldElement{-22751231, -5303997, -12907607, -12768866, -15811511, -7797053, -14839018, -16554220, -1867018, 8398970}, + FieldElement{-31969310, 2106403, -4736360, 1362501, 12813763, 16200670, 22981545, -6291273, 18009408, -15772772}, + FieldElement{-17220923, -9545221, -27784654, 14166835, 29815394, 7444469, 29551787, -3727419, 19288549, 1325865}, + }, + { + FieldElement{15100157, -15835752, -23923978, -1005098, -26450192, 15509408, 12376730, -3479146, 33166107, -8042750}, + FieldElement{20909231, 13023121, -9209752, 16251778, -5778415, -8094914, 12412151, 10018715, 2213263, -13878373}, + FieldElement{32529814, -11074689, 30361439, -16689753, -9135940, 1513226, 22922121, 6382134, -5766928, 8371348}, + }, + }, + { + { + FieldElement{9923462, 11271500, 12616794, 3544722, -29998368, -1721626, 12891687, -8193132, -26442943, 10486144}, + FieldElement{-22597207, -7012665, 8587003, -8257861, 4084309, -12970062, 361726, 2610596, -23921530, -11455195}, + FieldElement{5408411, -1136691, -4969122, 10561668, 24145918, 14240566, 31319731, -4235541, 19985175, -3436086}, + }, + { + FieldElement{-13994457, 16616821, 14549246, 3341099, 32155958, 13648976, -17577068, 8849297, 65030, 8370684}, + FieldElement{-8320926, -12049626, 31204563, 5839400, -20627288, -1057277, -19442942, 6922164, 12743482, -9800518}, + FieldElement{-2361371, 12678785, 28815050, 4759974, -23893047, 4884717, 23783145, 11038569, 18800704, 255233}, + }, + { + FieldElement{-5269658, -1773886, 13957886, 7990715, 23132995, 728773, 13393847, 9066957, 19258688, -14753793}, + FieldElement{-2936654, -10827535, -10432089, 14516793, -3640786, 4372541, -31934921, 2209390, -1524053, 2055794}, + FieldElement{580882, 16705327, 5468415, -2683018, -30926419, -14696000, -7203346, -8994389, -30021019, 7394435}, + }, + { + FieldElement{23838809, 1822728, -15738443, 15242727, 8318092, -3733104, -21672180, -3492205, -4821741, 14799921}, + FieldElement{13345610, 9759151, 3371034, -16137791, 16353039, 8577942, 31129804, 13496856, -9056018, 7402518}, + FieldElement{2286874, -4435931, -20042458, -2008336, -13696227, 5038122, 11006906, -15760352, 8205061, 1607563}, + }, + { + FieldElement{14414086, -8002132, 3331830, -3208217, 22249151, -5594188, 18364661, -2906958, 30019587, -9029278}, + FieldElement{-27688051, 1585953, -10775053, 931069, -29120221, -11002319, -14410829, 12029093, 9944378, 8024}, + FieldElement{4368715, -3709630, 29874200, -15022983, -20230386, -11410704, -16114594, -999085, -8142388, 5640030}, + }, + { + FieldElement{10299610, 13746483, 11661824, 16234854, 7630238, 5998374, 9809887, -16694564, 15219798, -14327783}, + FieldElement{27425505, -5719081, 3055006, 10660664, 23458024, 595578, -15398605, -1173195, -18342183, 9742717}, + FieldElement{6744077, 2427284, 26042789, 2720740, -847906, 1118974, 32324614, 7406442, 12420155, 1994844}, + }, + { + FieldElement{14012521, -5024720, -18384453, -9578469, -26485342, -3936439, -13033478, -10909803, 24319929, -6446333}, + FieldElement{16412690, -4507367, 10772641, 15929391, -17068788, -4658621, 10555945, -10484049, -30102368, -4739048}, + FieldElement{22397382, -7767684, -9293161, -12792868, 17166287, -9755136, -27333065, 6199366, 21880021, -12250760}, + }, + { + FieldElement{-4283307, 5368523, -31117018, 8163389, -30323063, 3209128, 16557151, 8890729, 8840445, 4957760}, + FieldElement{-15447727, 709327, -6919446, -10870178, -29777922, 6522332, -21720181, 12130072, -14796503, 5005757}, + FieldElement{-2114751, -14308128, 23019042, 15765735, -25269683, 6002752, 10183197, -13239326, -16395286, -2176112}, + }, + }, + { + { + FieldElement{-19025756, 1632005, 13466291, -7995100, -23640451, 16573537, -32013908, -3057104, 22208662, 2000468}, + FieldElement{3065073, -1412761, -25598674, -361432, -17683065, -5703415, -8164212, 11248527, -3691214, -7414184}, + FieldElement{10379208, -6045554, 8877319, 1473647, -29291284, -12507580, 16690915, 2553332, -3132688, 16400289}, + }, + { + FieldElement{15716668, 1254266, -18472690, 7446274, -8448918, 6344164, -22097271, -7285580, 26894937, 9132066}, + FieldElement{24158887, 12938817, 11085297, -8177598, -28063478, -4457083, -30576463, 64452, -6817084, -2692882}, + FieldElement{13488534, 7794716, 22236231, 5989356, 25426474, -12578208, 2350710, -3418511, -4688006, 2364226}, + }, + { + FieldElement{16335052, 9132434, 25640582, 6678888, 1725628, 8517937, -11807024, -11697457, 15445875, -7798101}, + FieldElement{29004207, -7867081, 28661402, -640412, -12794003, -7943086, 31863255, -4135540, -278050, -15759279}, + FieldElement{-6122061, -14866665, -28614905, 14569919, -10857999, -3591829, 10343412, -6976290, -29828287, -10815811}, + }, + { + FieldElement{27081650, 3463984, 14099042, -4517604, 1616303, -6205604, 29542636, 15372179, 17293797, 960709}, + FieldElement{20263915, 11434237, -5765435, 11236810, 13505955, -10857102, -16111345, 6493122, -19384511, 7639714}, + FieldElement{-2830798, -14839232, 25403038, -8215196, -8317012, -16173699, 18006287, -16043750, 29994677, -15808121}, + }, + { + FieldElement{9769828, 5202651, -24157398, -13631392, -28051003, -11561624, -24613141, -13860782, -31184575, 709464}, + FieldElement{12286395, 13076066, -21775189, -1176622, -25003198, 4057652, -32018128, -8890874, 16102007, 13205847}, + FieldElement{13733362, 5599946, 10557076, 3195751, -5557991, 8536970, -25540170, 8525972, 10151379, 10394400}, + }, + { + FieldElement{4024660, -16137551, 22436262, 12276534, -9099015, -2686099, 19698229, 11743039, -33302334, 8934414}, + FieldElement{-15879800, -4525240, -8580747, -2934061, 14634845, -698278, -9449077, 3137094, -11536886, 11721158}, + FieldElement{17555939, -5013938, 8268606, 2331751, -22738815, 9761013, 9319229, 8835153, -9205489, -1280045}, + }, + { + FieldElement{-461409, -7830014, 20614118, 16688288, -7514766, -4807119, 22300304, 505429, 6108462, -6183415}, + FieldElement{-5070281, 12367917, -30663534, 3234473, 32617080, -8422642, 29880583, -13483331, -26898490, -7867459}, + FieldElement{-31975283, 5726539, 26934134, 10237677, -3173717, -605053, 24199304, 3795095, 7592688, -14992079}, + }, + { + FieldElement{21594432, -14964228, 17466408, -4077222, 32537084, 2739898, 6407723, 12018833, -28256052, 4298412}, + FieldElement{-20650503, -11961496, -27236275, 570498, 3767144, -1717540, 13891942, -1569194, 13717174, 10805743}, + FieldElement{-14676630, -15644296, 15287174, 11927123, 24177847, -8175568, -796431, 14860609, -26938930, -5863836}, + }, + }, + { + { + FieldElement{12962541, 5311799, -10060768, 11658280, 18855286, -7954201, 13286263, -12808704, -4381056, 9882022}, + FieldElement{18512079, 11319350, -20123124, 15090309, 18818594, 5271736, -22727904, 3666879, -23967430, -3299429}, + FieldElement{-6789020, -3146043, 16192429, 13241070, 15898607, -14206114, -10084880, -6661110, -2403099, 5276065}, + }, + { + FieldElement{30169808, -5317648, 26306206, -11750859, 27814964, 7069267, 7152851, 3684982, 1449224, 13082861}, + FieldElement{10342826, 3098505, 2119311, 193222, 25702612, 12233820, 23697382, 15056736, -21016438, -8202000}, + FieldElement{-33150110, 3261608, 22745853, 7948688, 19370557, -15177665, -26171976, 6482814, -10300080, -11060101}, + }, + { + FieldElement{32869458, -5408545, 25609743, 15678670, -10687769, -15471071, 26112421, 2521008, -22664288, 6904815}, + FieldElement{29506923, 4457497, 3377935, -9796444, -30510046, 12935080, 1561737, 3841096, -29003639, -6657642}, + FieldElement{10340844, -6630377, -18656632, -2278430, 12621151, -13339055, 30878497, -11824370, -25584551, 5181966}, + }, + { + FieldElement{25940115, -12658025, 17324188, -10307374, -8671468, 15029094, 24396252, -16450922, -2322852, -12388574}, + FieldElement{-21765684, 9916823, -1300409, 4079498, -1028346, 11909559, 1782390, 12641087, 20603771, -6561742}, + FieldElement{-18882287, -11673380, 24849422, 11501709, 13161720, -4768874, 1925523, 11914390, 4662781, 7820689}, + }, + { + FieldElement{12241050, -425982, 8132691, 9393934, 32846760, -1599620, 29749456, 12172924, 16136752, 15264020}, + FieldElement{-10349955, -14680563, -8211979, 2330220, -17662549, -14545780, 10658213, 6671822, 19012087, 3772772}, + FieldElement{3753511, -3421066, 10617074, 2028709, 14841030, -6721664, 28718732, -15762884, 20527771, 12988982}, + }, + { + FieldElement{-14822485, -5797269, -3707987, 12689773, -898983, -10914866, -24183046, -10564943, 3299665, -12424953}, + FieldElement{-16777703, -15253301, -9642417, 4978983, 3308785, 8755439, 6943197, 6461331, -25583147, 8991218}, + FieldElement{-17226263, 1816362, -1673288, -6086439, 31783888, -8175991, -32948145, 7417950, -30242287, 1507265}, + }, + { + FieldElement{29692663, 6829891, -10498800, 4334896, 20945975, -11906496, -28887608, 8209391, 14606362, -10647073}, + FieldElement{-3481570, 8707081, 32188102, 5672294, 22096700, 1711240, -33020695, 9761487, 4170404, -2085325}, + FieldElement{-11587470, 14855945, -4127778, -1531857, -26649089, 15084046, 22186522, 16002000, -14276837, -8400798}, + }, + { + FieldElement{-4811456, 13761029, -31703877, -2483919, -3312471, 7869047, -7113572, -9620092, 13240845, 10965870}, + FieldElement{-7742563, -8256762, -14768334, -13656260, -23232383, 12387166, 4498947, 14147411, 29514390, 4302863}, + FieldElement{-13413405, -12407859, 20757302, -13801832, 14785143, 8976368, -5061276, -2144373, 17846988, -13971927}, + }, + }, + { + { + FieldElement{-2244452, -754728, -4597030, -1066309, -6247172, 1455299, -21647728, -9214789, -5222701, 12650267}, + FieldElement{-9906797, -16070310, 21134160, 12198166, -27064575, 708126, 387813, 13770293, -19134326, 10958663}, + FieldElement{22470984, 12369526, 23446014, -5441109, -21520802, -9698723, -11772496, -11574455, -25083830, 4271862}, + }, + { + FieldElement{-25169565, -10053642, -19909332, 15361595, -5984358, 2159192, 75375, -4278529, -32526221, 8469673}, + FieldElement{15854970, 4148314, -8893890, 7259002, 11666551, 13824734, -30531198, 2697372, 24154791, -9460943}, + FieldElement{15446137, -15806644, 29759747, 14019369, 30811221, -9610191, -31582008, 12840104, 24913809, 9815020}, + }, + { + FieldElement{-4709286, -5614269, -31841498, -12288893, -14443537, 10799414, -9103676, 13438769, 18735128, 9466238}, + FieldElement{11933045, 9281483, 5081055, -5183824, -2628162, -4905629, -7727821, -10896103, -22728655, 16199064}, + FieldElement{14576810, 379472, -26786533, -8317236, -29426508, -10812974, -102766, 1876699, 30801119, 2164795}, + }, + { + FieldElement{15995086, 3199873, 13672555, 13712240, -19378835, -4647646, -13081610, -15496269, -13492807, 1268052}, + FieldElement{-10290614, -3659039, -3286592, 10948818, 23037027, 3794475, -3470338, -12600221, -17055369, 3565904}, + FieldElement{29210088, -9419337, -5919792, -4952785, 10834811, -13327726, -16512102, -10820713, -27162222, -14030531}, + }, + { + FieldElement{-13161890, 15508588, 16663704, -8156150, -28349942, 9019123, -29183421, -3769423, 2244111, -14001979}, + FieldElement{-5152875, -3800936, -9306475, -6071583, 16243069, 14684434, -25673088, -16180800, 13491506, 4641841}, + FieldElement{10813417, 643330, -19188515, -728916, 30292062, -16600078, 27548447, -7721242, 14476989, -12767431}, + }, + { + FieldElement{10292079, 9984945, 6481436, 8279905, -7251514, 7032743, 27282937, -1644259, -27912810, 12651324}, + FieldElement{-31185513, -813383, 22271204, 11835308, 10201545, 15351028, 17099662, 3988035, 21721536, -3148940}, + FieldElement{10202177, -6545839, -31373232, -9574638, -32150642, -8119683, -12906320, 3852694, 13216206, 14842320}, + }, + { + FieldElement{-15815640, -10601066, -6538952, -7258995, -6984659, -6581778, -31500847, 13765824, -27434397, 9900184}, + FieldElement{14465505, -13833331, -32133984, -14738873, -27443187, 12990492, 33046193, 15796406, -7051866, -8040114}, + FieldElement{30924417, -8279620, 6359016, -12816335, 16508377, 9071735, -25488601, 15413635, 9524356, -7018878}, + }, + { + FieldElement{12274201, -13175547, 32627641, -1785326, 6736625, 13267305, 5237659, -5109483, 15663516, 4035784}, + FieldElement{-2951309, 8903985, 17349946, 601635, -16432815, -4612556, -13732739, -15889334, -22258478, 4659091}, + FieldElement{-16916263, -4952973, -30393711, -15158821, 20774812, 15897498, 5736189, 15026997, -2178256, -13455585}, + }, + }, + { + { + FieldElement{-8858980, -2219056, 28571666, -10155518, -474467, -10105698, -3801496, 278095, 23440562, -290208}, + FieldElement{10226241, -5928702, 15139956, 120818, -14867693, 5218603, 32937275, 11551483, -16571960, -7442864}, + FieldElement{17932739, -12437276, -24039557, 10749060, 11316803, 7535897, 22503767, 5561594, -3646624, 3898661}, + }, + { + FieldElement{7749907, -969567, -16339731, -16464, -25018111, 15122143, -1573531, 7152530, 21831162, 1245233}, + FieldElement{26958459, -14658026, 4314586, 8346991, -5677764, 11960072, -32589295, -620035, -30402091, -16716212}, + FieldElement{-12165896, 9166947, 33491384, 13673479, 29787085, 13096535, 6280834, 14587357, -22338025, 13987525}, + }, + { + FieldElement{-24349909, 7778775, 21116000, 15572597, -4833266, -5357778, -4300898, -5124639, -7469781, -2858068}, + FieldElement{9681908, -6737123, -31951644, 13591838, -6883821, 386950, 31622781, 6439245, -14581012, 4091397}, + FieldElement{-8426427, 1470727, -28109679, -1596990, 3978627, -5123623, -19622683, 12092163, 29077877, -14741988}, + }, + { + FieldElement{5269168, -6859726, -13230211, -8020715, 25932563, 1763552, -5606110, -5505881, -20017847, 2357889}, + FieldElement{32264008, -15407652, -5387735, -1160093, -2091322, -3946900, 23104804, -12869908, 5727338, 189038}, + FieldElement{14609123, -8954470, -6000566, -16622781, -14577387, -7743898, -26745169, 10942115, -25888931, -14884697}, + }, + { + FieldElement{20513500, 5557931, -15604613, 7829531, 26413943, -2019404, -21378968, 7471781, 13913677, -5137875}, + FieldElement{-25574376, 11967826, 29233242, 12948236, -6754465, 4713227, -8940970, 14059180, 12878652, 8511905}, + FieldElement{-25656801, 3393631, -2955415, -7075526, -2250709, 9366908, -30223418, 6812974, 5568676, -3127656}, + }, + { + FieldElement{11630004, 12144454, 2116339, 13606037, 27378885, 15676917, -17408753, -13504373, -14395196, 8070818}, + FieldElement{27117696, -10007378, -31282771, -5570088, 1127282, 12772488, -29845906, 10483306, -11552749, -1028714}, + FieldElement{10637467, -5688064, 5674781, 1072708, -26343588, -6982302, -1683975, 9177853, -27493162, 15431203}, + }, + { + FieldElement{20525145, 10892566, -12742472, 12779443, -29493034, 16150075, -28240519, 14943142, -15056790, -7935931}, + FieldElement{-30024462, 5626926, -551567, -9981087, 753598, 11981191, 25244767, -3239766, -3356550, 9594024}, + FieldElement{-23752644, 2636870, -5163910, -10103818, 585134, 7877383, 11345683, -6492290, 13352335, -10977084}, + }, + { + FieldElement{-1931799, -5407458, 3304649, -12884869, 17015806, -4877091, -29783850, -7752482, -13215537, -319204}, + FieldElement{20239939, 6607058, 6203985, 3483793, -18386976, -779229, -20723742, 15077870, -22750759, 14523817}, + FieldElement{27406042, -6041657, 27423596, -4497394, 4996214, 10002360, -28842031, -4545494, -30172742, -4805667}, + }, + }, + { + { + FieldElement{11374242, 12660715, 17861383, -12540833, 10935568, 1099227, -13886076, -9091740, -27727044, 11358504}, + FieldElement{-12730809, 10311867, 1510375, 10778093, -2119455, -9145702, 32676003, 11149336, -26123651, 4985768}, + FieldElement{-19096303, 341147, -6197485, -239033, 15756973, -8796662, -983043, 13794114, -19414307, -15621255}, + }, + { + FieldElement{6490081, 11940286, 25495923, -7726360, 8668373, -8751316, 3367603, 6970005, -1691065, -9004790}, + FieldElement{1656497, 13457317, 15370807, 6364910, 13605745, 8362338, -19174622, -5475723, -16796596, -5031438}, + FieldElement{-22273315, -13524424, -64685, -4334223, -18605636, -10921968, -20571065, -7007978, -99853, -10237333}, + }, + { + FieldElement{17747465, 10039260, 19368299, -4050591, -20630635, -16041286, 31992683, -15857976, -29260363, -5511971}, + FieldElement{31932027, -4986141, -19612382, 16366580, 22023614, 88450, 11371999, -3744247, 4882242, -10626905}, + FieldElement{29796507, 37186, 19818052, 10115756, -11829032, 3352736, 18551198, 3272828, -5190932, -4162409}, + }, + { + FieldElement{12501286, 4044383, -8612957, -13392385, -32430052, 5136599, -19230378, -3529697, 330070, -3659409}, + FieldElement{6384877, 2899513, 17807477, 7663917, -2358888, 12363165, 25366522, -8573892, -271295, 12071499}, + FieldElement{-8365515, -4042521, 25133448, -4517355, -6211027, 2265927, -32769618, 1936675, -5159697, 3829363}, + }, + { + FieldElement{28425966, -5835433, -577090, -4697198, -14217555, 6870930, 7921550, -6567787, 26333140, 14267664}, + FieldElement{-11067219, 11871231, 27385719, -10559544, -4585914, -11189312, 10004786, -8709488, -21761224, 8930324}, + FieldElement{-21197785, -16396035, 25654216, -1725397, 12282012, 11008919, 1541940, 4757911, -26491501, -16408940}, + }, + { + FieldElement{13537262, -7759490, -20604840, 10961927, -5922820, -13218065, -13156584, 6217254, -15943699, 13814990}, + FieldElement{-17422573, 15157790, 18705543, 29619, 24409717, -260476, 27361681, 9257833, -1956526, -1776914}, + FieldElement{-25045300, -10191966, 15366585, 15166509, -13105086, 8423556, -29171540, 12361135, -18685978, 4578290}, + }, + { + FieldElement{24579768, 3711570, 1342322, -11180126, -27005135, 14124956, -22544529, 14074919, 21964432, 8235257}, + FieldElement{-6528613, -2411497, 9442966, -5925588, 12025640, -1487420, -2981514, -1669206, 13006806, 2355433}, + FieldElement{-16304899, -13605259, -6632427, -5142349, 16974359, -10911083, 27202044, 1719366, 1141648, -12796236}, + }, + { + FieldElement{-12863944, -13219986, -8318266, -11018091, -6810145, -4843894, 13475066, -3133972, 32674895, 13715045}, + FieldElement{11423335, -5468059, 32344216, 8962751, 24989809, 9241752, -13265253, 16086212, -28740881, -15642093}, + FieldElement{-1409668, 12530728, -6368726, 10847387, 19531186, -14132160, -11709148, 7791794, -27245943, 4383347}, + }, + }, + { + { + FieldElement{-28970898, 5271447, -1266009, -9736989, -12455236, 16732599, -4862407, -4906449, 27193557, 6245191}, + FieldElement{-15193956, 5362278, -1783893, 2695834, 4960227, 12840725, 23061898, 3260492, 22510453, 8577507}, + FieldElement{-12632451, 11257346, -32692994, 13548177, -721004, 10879011, 31168030, 13952092, -29571492, -3635906}, + }, + { + FieldElement{3877321, -9572739, 32416692, 5405324, -11004407, -13656635, 3759769, 11935320, 5611860, 8164018}, + FieldElement{-16275802, 14667797, 15906460, 12155291, -22111149, -9039718, 32003002, -8832289, 5773085, -8422109}, + FieldElement{-23788118, -8254300, 1950875, 8937633, 18686727, 16459170, -905725, 12376320, 31632953, 190926}, + }, + { + FieldElement{-24593607, -16138885, -8423991, 13378746, 14162407, 6901328, -8288749, 4508564, -25341555, -3627528}, + FieldElement{8884438, -5884009, 6023974, 10104341, -6881569, -4941533, 18722941, -14786005, -1672488, 827625}, + FieldElement{-32720583, -16289296, -32503547, 7101210, 13354605, 2659080, -1800575, -14108036, -24878478, 1541286}, + }, + { + FieldElement{2901347, -1117687, 3880376, -10059388, -17620940, -3612781, -21802117, -3567481, 20456845, -1885033}, + FieldElement{27019610, 12299467, -13658288, -1603234, -12861660, -4861471, -19540150, -5016058, 29439641, 15138866}, + FieldElement{21536104, -6626420, -32447818, -10690208, -22408077, 5175814, -5420040, -16361163, 7779328, 109896}, + }, + { + FieldElement{30279744, 14648750, -8044871, 6425558, 13639621, -743509, 28698390, 12180118, 23177719, -554075}, + FieldElement{26572847, 3405927, -31701700, 12890905, -19265668, 5335866, -6493768, 2378492, 4439158, -13279347}, + FieldElement{-22716706, 3489070, -9225266, -332753, 18875722, -1140095, 14819434, -12731527, -17717757, -5461437}, + }, + { + FieldElement{-5056483, 16566551, 15953661, 3767752, -10436499, 15627060, -820954, 2177225, 8550082, -15114165}, + FieldElement{-18473302, 16596775, -381660, 15663611, 22860960, 15585581, -27844109, -3582739, -23260460, -8428588}, + FieldElement{-32480551, 15707275, -8205912, -5652081, 29464558, 2713815, -22725137, 15860482, -21902570, 1494193}, + }, + { + FieldElement{-19562091, -14087393, -25583872, -9299552, 13127842, 759709, 21923482, 16529112, 8742704, 12967017}, + FieldElement{-28464899, 1553205, 32536856, -10473729, -24691605, -406174, -8914625, -2933896, -29903758, 15553883}, + FieldElement{21877909, 3230008, 9881174, 10539357, -4797115, 2841332, 11543572, 14513274, 19375923, -12647961}, + }, + { + FieldElement{8832269, -14495485, 13253511, 5137575, 5037871, 4078777, 24880818, -6222716, 2862653, 9455043}, + FieldElement{29306751, 5123106, 20245049, -14149889, 9592566, 8447059, -2077124, -2990080, 15511449, 4789663}, + FieldElement{-20679756, 7004547, 8824831, -9434977, -4045704, -3750736, -5754762, 108893, 23513200, 16652362}, + }, + }, + { + { + FieldElement{-33256173, 4144782, -4476029, -6579123, 10770039, -7155542, -6650416, -12936300, -18319198, 10212860}, + FieldElement{2756081, 8598110, 7383731, -6859892, 22312759, -1105012, 21179801, 2600940, -9988298, -12506466}, + FieldElement{-24645692, 13317462, -30449259, -15653928, 21365574, -10869657, 11344424, 864440, -2499677, -16710063}, + }, + { + FieldElement{-26432803, 6148329, -17184412, -14474154, 18782929, -275997, -22561534, 211300, 2719757, 4940997}, + FieldElement{-1323882, 3911313, -6948744, 14759765, -30027150, 7851207, 21690126, 8518463, 26699843, 5276295}, + FieldElement{-13149873, -6429067, 9396249, 365013, 24703301, -10488939, 1321586, 149635, -15452774, 7159369}, + }, + { + FieldElement{9987780, -3404759, 17507962, 9505530, 9731535, -2165514, 22356009, 8312176, 22477218, -8403385}, + FieldElement{18155857, -16504990, 19744716, 9006923, 15154154, -10538976, 24256460, -4864995, -22548173, 9334109}, + FieldElement{2986088, -4911893, 10776628, -3473844, 10620590, -7083203, -21413845, 14253545, -22587149, 536906}, + }, + { + FieldElement{4377756, 8115836, 24567078, 15495314, 11625074, 13064599, 7390551, 10589625, 10838060, -15420424}, + FieldElement{-19342404, 867880, 9277171, -3218459, -14431572, -1986443, 19295826, -15796950, 6378260, 699185}, + FieldElement{7895026, 4057113, -7081772, -13077756, -17886831, -323126, -716039, 15693155, -5045064, -13373962}, + }, + { + FieldElement{-7737563, -5869402, -14566319, -7406919, 11385654, 13201616, 31730678, -10962840, -3918636, -9669325}, + FieldElement{10188286, -15770834, -7336361, 13427543, 22223443, 14896287, 30743455, 7116568, -21786507, 5427593}, + FieldElement{696102, 13206899, 27047647, -10632082, 15285305, -9853179, 10798490, -4578720, 19236243, 12477404}, + }, + { + FieldElement{-11229439, 11243796, -17054270, -8040865, -788228, -8167967, -3897669, 11180504, -23169516, 7733644}, + FieldElement{17800790, -14036179, -27000429, -11766671, 23887827, 3149671, 23466177, -10538171, 10322027, 15313801}, + FieldElement{26246234, 11968874, 32263343, -5468728, 6830755, -13323031, -15794704, -101982, -24449242, 10890804}, + }, + { + FieldElement{-31365647, 10271363, -12660625, -6267268, 16690207, -13062544, -14982212, 16484931, 25180797, -5334884}, + FieldElement{-586574, 10376444, -32586414, -11286356, 19801893, 10997610, 2276632, 9482883, 316878, 13820577}, + FieldElement{-9882808, -4510367, -2115506, 16457136, -11100081, 11674996, 30756178, -7515054, 30696930, -3712849}, + }, + { + FieldElement{32988917, -9603412, 12499366, 7910787, -10617257, -11931514, -7342816, -9985397, -32349517, 7392473}, + FieldElement{-8855661, 15927861, 9866406, -3649411, -2396914, -16655781, -30409476, -9134995, 25112947, -2926644}, + FieldElement{-2504044, -436966, 25621774, -5678772, 15085042, -5479877, -24884878, -13526194, 5537438, -13914319}, + }, + }, + { + { + FieldElement{-11225584, 2320285, -9584280, 10149187, -33444663, 5808648, -14876251, -1729667, 31234590, 6090599}, + FieldElement{-9633316, 116426, 26083934, 2897444, -6364437, -2688086, 609721, 15878753, -6970405, -9034768}, + FieldElement{-27757857, 247744, -15194774, -9002551, 23288161, -10011936, -23869595, 6503646, 20650474, 1804084}, + }, + { + FieldElement{-27589786, 15456424, 8972517, 8469608, 15640622, 4439847, 3121995, -10329713, 27842616, -202328}, + FieldElement{-15306973, 2839644, 22530074, 10026331, 4602058, 5048462, 28248656, 5031932, -11375082, 12714369}, + FieldElement{20807691, -7270825, 29286141, 11421711, -27876523, -13868230, -21227475, 1035546, -19733229, 12796920}, + }, + { + FieldElement{12076899, -14301286, -8785001, -11848922, -25012791, 16400684, -17591495, -12899438, 3480665, -15182815}, + FieldElement{-32361549, 5457597, 28548107, 7833186, 7303070, -11953545, -24363064, -15921875, -33374054, 2771025}, + FieldElement{-21389266, 421932, 26597266, 6860826, 22486084, -6737172, -17137485, -4210226, -24552282, 15673397}, + }, + { + FieldElement{-20184622, 2338216, 19788685, -9620956, -4001265, -8740893, -20271184, 4733254, 3727144, -12934448}, + FieldElement{6120119, 814863, -11794402, -622716, 6812205, -15747771, 2019594, 7975683, 31123697, -10958981}, + FieldElement{30069250, -11435332, 30434654, 2958439, 18399564, -976289, 12296869, 9204260, -16432438, 9648165}, + }, + { + FieldElement{32705432, -1550977, 30705658, 7451065, -11805606, 9631813, 3305266, 5248604, -26008332, -11377501}, + FieldElement{17219865, 2375039, -31570947, -5575615, -19459679, 9219903, 294711, 15298639, 2662509, -16297073}, + FieldElement{-1172927, -7558695, -4366770, -4287744, -21346413, -8434326, 32087529, -1222777, 32247248, -14389861}, + }, + { + FieldElement{14312628, 1221556, 17395390, -8700143, -4945741, -8684635, -28197744, -9637817, -16027623, -13378845}, + FieldElement{-1428825, -9678990, -9235681, 6549687, -7383069, -468664, 23046502, 9803137, 17597934, 2346211}, + FieldElement{18510800, 15337574, 26171504, 981392, -22241552, 7827556, -23491134, -11323352, 3059833, -11782870}, + }, + { + FieldElement{10141598, 6082907, 17829293, -1947643, 9830092, 13613136, -25556636, -5544586, -33502212, 3592096}, + FieldElement{33114168, -15889352, -26525686, -13343397, 33076705, 8716171, 1151462, 1521897, -982665, -6837803}, + FieldElement{-32939165, -4255815, 23947181, -324178, -33072974, -12305637, -16637686, 3891704, 26353178, 693168}, + }, + { + FieldElement{30374239, 1595580, -16884039, 13186931, 4600344, 406904, 9585294, -400668, 31375464, 14369965}, + FieldElement{-14370654, -7772529, 1510301, 6434173, -18784789, -6262728, 32732230, -13108839, 17901441, 16011505}, + FieldElement{18171223, -11934626, -12500402, 15197122, -11038147, -15230035, -19172240, -16046376, 8764035, 12309598}, + }, + }, + { + { + FieldElement{5975908, -5243188, -19459362, -9681747, -11541277, 14015782, -23665757, 1228319, 17544096, -10593782}, + FieldElement{5811932, -1715293, 3442887, -2269310, -18367348, -8359541, -18044043, -15410127, -5565381, 12348900}, + FieldElement{-31399660, 11407555, 25755363, 6891399, -3256938, 14872274, -24849353, 8141295, -10632534, -585479}, + }, + { + FieldElement{-12675304, 694026, -5076145, 13300344, 14015258, -14451394, -9698672, -11329050, 30944593, 1130208}, + FieldElement{8247766, -6710942, -26562381, -7709309, -14401939, -14648910, 4652152, 2488540, 23550156, -271232}, + FieldElement{17294316, -3788438, 7026748, 15626851, 22990044, 113481, 2267737, -5908146, -408818, -137719}, + }, + { + FieldElement{16091085, -16253926, 18599252, 7340678, 2137637, -1221657, -3364161, 14550936, 3260525, -7166271}, + FieldElement{-4910104, -13332887, 18550887, 10864893, -16459325, -7291596, -23028869, -13204905, -12748722, 2701326}, + FieldElement{-8574695, 16099415, 4629974, -16340524, -20786213, -6005432, -10018363, 9276971, 11329923, 1862132}, + }, + { + FieldElement{14763076, -15903608, -30918270, 3689867, 3511892, 10313526, -21951088, 12219231, -9037963, -940300}, + FieldElement{8894987, -3446094, 6150753, 3013931, 301220, 15693451, -31981216, -2909717, -15438168, 11595570}, + FieldElement{15214962, 3537601, -26238722, -14058872, 4418657, -15230761, 13947276, 10730794, -13489462, -4363670}, + }, + { + FieldElement{-2538306, 7682793, 32759013, 263109, -29984731, -7955452, -22332124, -10188635, 977108, 699994}, + FieldElement{-12466472, 4195084, -9211532, 550904, -15565337, 12917920, 19118110, -439841, -30534533, -14337913}, + FieldElement{31788461, -14507657, 4799989, 7372237, 8808585, -14747943, 9408237, -10051775, 12493932, -5409317}, + }, + { + FieldElement{-25680606, 5260744, -19235809, -6284470, -3695942, 16566087, 27218280, 2607121, 29375955, 6024730}, + FieldElement{842132, -2794693, -4763381, -8722815, 26332018, -12405641, 11831880, 6985184, -9940361, 2854096}, + FieldElement{-4847262, -7969331, 2516242, -5847713, 9695691, -7221186, 16512645, 960770, 12121869, 16648078}, + }, + { + FieldElement{-15218652, 14667096, -13336229, 2013717, 30598287, -464137, -31504922, -7882064, 20237806, 2838411}, + FieldElement{-19288047, 4453152, 15298546, -16178388, 22115043, -15972604, 12544294, -13470457, 1068881, -12499905}, + FieldElement{-9558883, -16518835, 33238498, 13506958, 30505848, -1114596, -8486907, -2630053, 12521378, 4845654}, + }, + { + FieldElement{-28198521, 10744108, -2958380, 10199664, 7759311, -13088600, 3409348, -873400, -6482306, -12885870}, + FieldElement{-23561822, 6230156, -20382013, 10655314, -24040585, -11621172, 10477734, -1240216, -3113227, 13974498}, + FieldElement{12966261, 15550616, -32038948, -1615346, 21025980, -629444, 5642325, 7188737, 18895762, 12629579}, + }, + }, + { + { + FieldElement{14741879, -14946887, 22177208, -11721237, 1279741, 8058600, 11758140, 789443, 32195181, 3895677}, + FieldElement{10758205, 15755439, -4509950, 9243698, -4879422, 6879879, -2204575, -3566119, -8982069, 4429647}, + FieldElement{-2453894, 15725973, -20436342, -10410672, -5803908, -11040220, -7135870, -11642895, 18047436, -15281743}, + }, + { + FieldElement{-25173001, -11307165, 29759956, 11776784, -22262383, -15820455, 10993114, -12850837, -17620701, -9408468}, + FieldElement{21987233, 700364, -24505048, 14972008, -7774265, -5718395, 32155026, 2581431, -29958985, 8773375}, + FieldElement{-25568350, 454463, -13211935, 16126715, 25240068, 8594567, 20656846, 12017935, -7874389, -13920155}, + }, + { + FieldElement{6028182, 6263078, -31011806, -11301710, -818919, 2461772, -31841174, -5468042, -1721788, -2776725}, + FieldElement{-12278994, 16624277, 987579, -5922598, 32908203, 1248608, 7719845, -4166698, 28408820, 6816612}, + FieldElement{-10358094, -8237829, 19549651, -12169222, 22082623, 16147817, 20613181, 13982702, -10339570, 5067943}, + }, + { + FieldElement{-30505967, -3821767, 12074681, 13582412, -19877972, 2443951, -19719286, 12746132, 5331210, -10105944}, + FieldElement{30528811, 3601899, -1957090, 4619785, -27361822, -15436388, 24180793, -12570394, 27679908, -1648928}, + FieldElement{9402404, -13957065, 32834043, 10838634, -26580150, -13237195, 26653274, -8685565, 22611444, -12715406}, + }, + { + FieldElement{22190590, 1118029, 22736441, 15130463, -30460692, -5991321, 19189625, -4648942, 4854859, 6622139}, + FieldElement{-8310738, -2953450, -8262579, -3388049, -10401731, -271929, 13424426, -3567227, 26404409, 13001963}, + FieldElement{-31241838, -15415700, -2994250, 8939346, 11562230, -12840670, -26064365, -11621720, -15405155, 11020693}, + }, + { + FieldElement{1866042, -7949489, -7898649, -10301010, 12483315, 13477547, 3175636, -12424163, 28761762, 1406734}, + FieldElement{-448555, -1777666, 13018551, 3194501, -9580420, -11161737, 24760585, -4347088, 25577411, -13378680}, + FieldElement{-24290378, 4759345, -690653, -1852816, 2066747, 10693769, -29595790, 9884936, -9368926, 4745410}, + }, + { + FieldElement{-9141284, 6049714, -19531061, -4341411, -31260798, 9944276, -15462008, -11311852, 10931924, -11931931}, + FieldElement{-16561513, 14112680, -8012645, 4817318, -8040464, -11414606, -22853429, 10856641, -20470770, 13434654}, + FieldElement{22759489, -10073434, -16766264, -1871422, 13637442, -10168091, 1765144, -12654326, 28445307, -5364710}, + }, + { + FieldElement{29875063, 12493613, 2795536, -3786330, 1710620, 15181182, -10195717, -8788675, 9074234, 1167180}, + FieldElement{-26205683, 11014233, -9842651, -2635485, -26908120, 7532294, -18716888, -9535498, 3843903, 9367684}, + FieldElement{-10969595, -6403711, 9591134, 9582310, 11349256, 108879, 16235123, 8601684, -139197, 4242895}, + }, + }, + { + { + FieldElement{22092954, -13191123, -2042793, -11968512, 32186753, -11517388, -6574341, 2470660, -27417366, 16625501}, + FieldElement{-11057722, 3042016, 13770083, -9257922, 584236, -544855, -7770857, 2602725, -27351616, 14247413}, + FieldElement{6314175, -10264892, -32772502, 15957557, -10157730, 168750, -8618807, 14290061, 27108877, -1180880}, + }, + { + FieldElement{-8586597, -7170966, 13241782, 10960156, -32991015, -13794596, 33547976, -11058889, -27148451, 981874}, + FieldElement{22833440, 9293594, -32649448, -13618667, -9136966, 14756819, -22928859, -13970780, -10479804, -16197962}, + FieldElement{-7768587, 3326786, -28111797, 10783824, 19178761, 14905060, 22680049, 13906969, -15933690, 3797899}, + }, + { + FieldElement{21721356, -4212746, -12206123, 9310182, -3882239, -13653110, 23740224, -2709232, 20491983, -8042152}, + FieldElement{9209270, -15135055, -13256557, -6167798, -731016, 15289673, 25947805, 15286587, 30997318, -6703063}, + FieldElement{7392032, 16618386, 23946583, -8039892, -13265164, -1533858, -14197445, -2321576, 17649998, -250080}, + }, + { + FieldElement{-9301088, -14193827, 30609526, -3049543, -25175069, -1283752, -15241566, -9525724, -2233253, 7662146}, + FieldElement{-17558673, 1763594, -33114336, 15908610, -30040870, -12174295, 7335080, -8472199, -3174674, 3440183}, + FieldElement{-19889700, -5977008, -24111293, -9688870, 10799743, -16571957, 40450, -4431835, 4862400, 1133}, + }, + { + FieldElement{-32856209, -7873957, -5422389, 14860950, -16319031, 7956142, 7258061, 311861, -30594991, -7379421}, + FieldElement{-3773428, -1565936, 28985340, 7499440, 24445838, 9325937, 29727763, 16527196, 18278453, 15405622}, + FieldElement{-4381906, 8508652, -19898366, -3674424, -5984453, 15149970, -13313598, 843523, -21875062, 13626197}, + }, + { + FieldElement{2281448, -13487055, -10915418, -2609910, 1879358, 16164207, -10783882, 3953792, 13340839, 15928663}, + FieldElement{31727126, -7179855, -18437503, -8283652, 2875793, -16390330, -25269894, -7014826, -23452306, 5964753}, + FieldElement{4100420, -5959452, -17179337, 6017714, -18705837, 12227141, -26684835, 11344144, 2538215, -7570755}, + }, + { + FieldElement{-9433605, 6123113, 11159803, -2156608, 30016280, 14966241, -20474983, 1485421, -629256, -15958862}, + FieldElement{-26804558, 4260919, 11851389, 9658551, -32017107, 16367492, -20205425, -13191288, 11659922, -11115118}, + FieldElement{26180396, 10015009, -30844224, -8581293, 5418197, 9480663, 2231568, -10170080, 33100372, -1306171}, + }, + { + FieldElement{15121113, -5201871, -10389905, 15427821, -27509937, -15992507, 21670947, 4486675, -5931810, -14466380}, + FieldElement{16166486, -9483733, -11104130, 6023908, -31926798, -1364923, 2340060, -16254968, -10735770, -10039824}, + FieldElement{28042865, -3557089, -12126526, 12259706, -3717498, -6945899, 6766453, -8689599, 18036436, 5803270}, + }, + }, + { + { + FieldElement{-817581, 6763912, 11803561, 1585585, 10958447, -2671165, 23855391, 4598332, -6159431, -14117438}, + FieldElement{-31031306, -14256194, 17332029, -2383520, 31312682, -5967183, 696309, 50292, -20095739, 11763584}, + FieldElement{-594563, -2514283, -32234153, 12643980, 12650761, 14811489, 665117, -12613632, -19773211, -10713562}, + }, + { + FieldElement{30464590, -11262872, -4127476, -12734478, 19835327, -7105613, -24396175, 2075773, -17020157, 992471}, + FieldElement{18357185, -6994433, 7766382, 16342475, -29324918, 411174, 14578841, 8080033, -11574335, -10601610}, + FieldElement{19598397, 10334610, 12555054, 2555664, 18821899, -10339780, 21873263, 16014234, 26224780, 16452269}, + }, + { + FieldElement{-30223925, 5145196, 5944548, 16385966, 3976735, 2009897, -11377804, -7618186, -20533829, 3698650}, + FieldElement{14187449, 3448569, -10636236, -10810935, -22663880, -3433596, 7268410, -10890444, 27394301, 12015369}, + FieldElement{19695761, 16087646, 28032085, 12999827, 6817792, 11427614, 20244189, -1312777, -13259127, -3402461}, + }, + { + FieldElement{30860103, 12735208, -1888245, -4699734, -16974906, 2256940, -8166013, 12298312, -8550524, -10393462}, + FieldElement{-5719826, -11245325, -1910649, 15569035, 26642876, -7587760, -5789354, -15118654, -4976164, 12651793}, + FieldElement{-2848395, 9953421, 11531313, -5282879, 26895123, -12697089, -13118820, -16517902, 9768698, -2533218}, + }, + { + FieldElement{-24719459, 1894651, -287698, -4704085, 15348719, -8156530, 32767513, 12765450, 4940095, 10678226}, + FieldElement{18860224, 15980149, -18987240, -1562570, -26233012, -11071856, -7843882, 13944024, -24372348, 16582019}, + FieldElement{-15504260, 4970268, -29893044, 4175593, -20993212, -2199756, -11704054, 15444560, -11003761, 7989037}, + }, + { + FieldElement{31490452, 5568061, -2412803, 2182383, -32336847, 4531686, -32078269, 6200206, -19686113, -14800171}, + FieldElement{-17308668, -15879940, -31522777, -2831, -32887382, 16375549, 8680158, -16371713, 28550068, -6857132}, + FieldElement{-28126887, -5688091, 16837845, -1820458, -6850681, 12700016, -30039981, 4364038, 1155602, 5988841}, + }, + { + FieldElement{21890435, -13272907, -12624011, 12154349, -7831873, 15300496, 23148983, -4470481, 24618407, 8283181}, + FieldElement{-33136107, -10512751, 9975416, 6841041, -31559793, 16356536, 3070187, -7025928, 1466169, 10740210}, + FieldElement{-1509399, -15488185, -13503385, -10655916, 32799044, 909394, -13938903, -5779719, -32164649, -15327040}, + }, + { + FieldElement{3960823, -14267803, -28026090, -15918051, -19404858, 13146868, 15567327, 951507, -3260321, -573935}, + FieldElement{24740841, 5052253, -30094131, 8961361, 25877428, 6165135, -24368180, 14397372, -7380369, -6144105}, + FieldElement{-28888365, 3510803, -28103278, -1158478, -11238128, -10631454, -15441463, -14453128, -1625486, -6494814}, + }, + }, + { + { + FieldElement{793299, -9230478, 8836302, -6235707, -27360908, -2369593, 33152843, -4885251, -9906200, -621852}, + FieldElement{5666233, 525582, 20782575, -8038419, -24538499, 14657740, 16099374, 1468826, -6171428, -15186581}, + FieldElement{-4859255, -3779343, -2917758, -6748019, 7778750, 11688288, -30404353, -9871238, -1558923, -9863646}, + }, + { + FieldElement{10896332, -7719704, 824275, 472601, -19460308, 3009587, 25248958, 14783338, -30581476, -15757844}, + FieldElement{10566929, 12612572, -31944212, 11118703, -12633376, 12362879, 21752402, 8822496, 24003793, 14264025}, + FieldElement{27713862, -7355973, -11008240, 9227530, 27050101, 2504721, 23886875, -13117525, 13958495, -5732453}, + }, + { + FieldElement{-23481610, 4867226, -27247128, 3900521, 29838369, -8212291, -31889399, -10041781, 7340521, -15410068}, + FieldElement{4646514, -8011124, -22766023, -11532654, 23184553, 8566613, 31366726, -1381061, -15066784, -10375192}, + FieldElement{-17270517, 12723032, -16993061, 14878794, 21619651, -6197576, 27584817, 3093888, -8843694, 3849921}, + }, + { + FieldElement{-9064912, 2103172, 25561640, -15125738, -5239824, 9582958, 32477045, -9017955, 5002294, -15550259}, + FieldElement{-12057553, -11177906, 21115585, -13365155, 8808712, -12030708, 16489530, 13378448, -25845716, 12741426}, + FieldElement{-5946367, 10645103, -30911586, 15390284, -3286982, -7118677, 24306472, 15852464, 28834118, -7646072}, + }, + { + FieldElement{-17335748, -9107057, -24531279, 9434953, -8472084, -583362, -13090771, 455841, 20461858, 5491305}, + FieldElement{13669248, -16095482, -12481974, -10203039, -14569770, -11893198, -24995986, 11293807, -28588204, -9421832}, + FieldElement{28497928, 6272777, -33022994, 14470570, 8906179, -1225630, 18504674, -14165166, 29867745, -8795943}, + }, + { + FieldElement{-16207023, 13517196, -27799630, -13697798, 24009064, -6373891, -6367600, -13175392, 22853429, -4012011}, + FieldElement{24191378, 16712145, -13931797, 15217831, 14542237, 1646131, 18603514, -11037887, 12876623, -2112447}, + FieldElement{17902668, 4518229, -411702, -2829247, 26878217, 5258055, -12860753, 608397, 16031844, 3723494}, + }, + { + FieldElement{-28632773, 12763728, -20446446, 7577504, 33001348, -13017745, 17558842, -7872890, 23896954, -4314245}, + FieldElement{-20005381, -12011952, 31520464, 605201, 2543521, 5991821, -2945064, 7229064, -9919646, -8826859}, + FieldElement{28816045, 298879, -28165016, -15920938, 19000928, -1665890, -12680833, -2949325, -18051778, -2082915}, + }, + { + FieldElement{16000882, -344896, 3493092, -11447198, -29504595, -13159789, 12577740, 16041268, -19715240, 7847707}, + FieldElement{10151868, 10572098, 27312476, 7922682, 14825339, 4723128, -32855931, -6519018, -10020567, 3852848}, + FieldElement{-11430470, 15697596, -21121557, -4420647, 5386314, 15063598, 16514493, -15932110, 29330899, -15076224}, + }, + }, + { + { + FieldElement{-25499735, -4378794, -15222908, -6901211, 16615731, 2051784, 3303702, 15490, -27548796, 12314391}, + FieldElement{15683520, -6003043, 18109120, -9980648, 15337968, -5997823, -16717435, 15921866, 16103996, -3731215}, + FieldElement{-23169824, -10781249, 13588192, -1628807, -3798557, -1074929, -19273607, 5402699, -29815713, -9841101}, + }, + { + FieldElement{23190676, 2384583, -32714340, 3462154, -29903655, -1529132, -11266856, 8911517, -25205859, 2739713}, + FieldElement{21374101, -3554250, -33524649, 9874411, 15377179, 11831242, -33529904, 6134907, 4931255, 11987849}, + FieldElement{-7732, -2978858, -16223486, 7277597, 105524, -322051, -31480539, 13861388, -30076310, 10117930}, + }, + { + FieldElement{-29501170, -10744872, -26163768, 13051539, -25625564, 5089643, -6325503, 6704079, 12890019, 15728940}, + FieldElement{-21972360, -11771379, -951059, -4418840, 14704840, 2695116, 903376, -10428139, 12885167, 8311031}, + FieldElement{-17516482, 5352194, 10384213, -13811658, 7506451, 13453191, 26423267, 4384730, 1888765, -5435404}, + }, + { + FieldElement{-25817338, -3107312, -13494599, -3182506, 30896459, -13921729, -32251644, -12707869, -19464434, -3340243}, + FieldElement{-23607977, -2665774, -526091, 4651136, 5765089, 4618330, 6092245, 14845197, 17151279, -9854116}, + FieldElement{-24830458, -12733720, -15165978, 10367250, -29530908, -265356, 22825805, -7087279, -16866484, 16176525}, + }, + { + FieldElement{-23583256, 6564961, 20063689, 3798228, -4740178, 7359225, 2006182, -10363426, -28746253, -10197509}, + FieldElement{-10626600, -4486402, -13320562, -5125317, 3432136, -6393229, 23632037, -1940610, 32808310, 1099883}, + FieldElement{15030977, 5768825, -27451236, -2887299, -6427378, -15361371, -15277896, -6809350, 2051441, -15225865}, + }, + { + FieldElement{-3362323, -7239372, 7517890, 9824992, 23555850, 295369, 5148398, -14154188, -22686354, 16633660}, + FieldElement{4577086, -16752288, 13249841, -15304328, 19958763, -14537274, 18559670, -10759549, 8402478, -9864273}, + FieldElement{-28406330, -1051581, -26790155, -907698, -17212414, -11030789, 9453451, -14980072, 17983010, 9967138}, + }, + { + FieldElement{-25762494, 6524722, 26585488, 9969270, 24709298, 1220360, -1677990, 7806337, 17507396, 3651560}, + FieldElement{-10420457, -4118111, 14584639, 15971087, -15768321, 8861010, 26556809, -5574557, -18553322, -11357135}, + FieldElement{2839101, 14284142, 4029895, 3472686, 14402957, 12689363, -26642121, 8459447, -5605463, -7621941}, + }, + { + FieldElement{-4839289, -3535444, 9744961, 2871048, 25113978, 3187018, -25110813, -849066, 17258084, -7977739}, + FieldElement{18164541, -10595176, -17154882, -1542417, 19237078, -9745295, 23357533, -15217008, 26908270, 12150756}, + FieldElement{-30264870, -7647865, 5112249, -7036672, -1499807, -6974257, 43168, -5537701, -32302074, 16215819}, + }, + }, + { + { + FieldElement{-6898905, 9824394, -12304779, -4401089, -31397141, -6276835, 32574489, 12532905, -7503072, -8675347}, + FieldElement{-27343522, -16515468, -27151524, -10722951, 946346, 16291093, 254968, 7168080, 21676107, -1943028}, + FieldElement{21260961, -8424752, -16831886, -11920822, -23677961, 3968121, -3651949, -6215466, -3556191, -7913075}, + }, + { + FieldElement{16544754, 13250366, -16804428, 15546242, -4583003, 12757258, -2462308, -8680336, -18907032, -9662799}, + FieldElement{-2415239, -15577728, 18312303, 4964443, -15272530, -12653564, 26820651, 16690659, 25459437, -4564609}, + FieldElement{-25144690, 11425020, 28423002, -11020557, -6144921, -15826224, 9142795, -2391602, -6432418, -1644817}, + }, + { + FieldElement{-23104652, 6253476, 16964147, -3768872, -25113972, -12296437, -27457225, -16344658, 6335692, 7249989}, + FieldElement{-30333227, 13979675, 7503222, -12368314, -11956721, -4621693, -30272269, 2682242, 25993170, -12478523}, + FieldElement{4364628, 5930691, 32304656, -10044554, -8054781, 15091131, 22857016, -10598955, 31820368, 15075278}, + }, + { + FieldElement{31879134, -8918693, 17258761, 90626, -8041836, -4917709, 24162788, -9650886, -17970238, 12833045}, + FieldElement{19073683, 14851414, -24403169, -11860168, 7625278, 11091125, -19619190, 2074449, -9413939, 14905377}, + FieldElement{24483667, -11935567, -2518866, -11547418, -1553130, 15355506, -25282080, 9253129, 27628530, -7555480}, + }, + { + FieldElement{17597607, 8340603, 19355617, 552187, 26198470, -3176583, 4593324, -9157582, -14110875, 15297016}, + FieldElement{510886, 14337390, -31785257, 16638632, 6328095, 2713355, -20217417, -11864220, 8683221, 2921426}, + FieldElement{18606791, 11874196, 27155355, -5281482, -24031742, 6265446, -25178240, -1278924, 4674690, 13890525}, + }, + { + FieldElement{13609624, 13069022, -27372361, -13055908, 24360586, 9592974, 14977157, 9835105, 4389687, 288396}, + FieldElement{9922506, -519394, 13613107, 5883594, -18758345, -434263, -12304062, 8317628, 23388070, 16052080}, + FieldElement{12720016, 11937594, -31970060, -5028689, 26900120, 8561328, -20155687, -11632979, -14754271, -10812892}, + }, + { + FieldElement{15961858, 14150409, 26716931, -665832, -22794328, 13603569, 11829573, 7467844, -28822128, 929275}, + FieldElement{11038231, -11582396, -27310482, -7316562, -10498527, -16307831, -23479533, -9371869, -21393143, 2465074}, + FieldElement{20017163, -4323226, 27915242, 1529148, 12396362, 15675764, 13817261, -9658066, 2463391, -4622140}, + }, + { + FieldElement{-16358878, -12663911, -12065183, 4996454, -1256422, 1073572, 9583558, 12851107, 4003896, 12673717}, + FieldElement{-1731589, -15155870, -3262930, 16143082, 19294135, 13385325, 14741514, -9103726, 7903886, 2348101}, + FieldElement{24536016, -16515207, 12715592, -3862155, 1511293, 10047386, -3842346, -7129159, -28377538, 10048127}, + }, + }, + { + { + FieldElement{-12622226, -6204820, 30718825, 2591312, -10617028, 12192840, 18873298, -7297090, -32297756, 15221632}, + FieldElement{-26478122, -11103864, 11546244, -1852483, 9180880, 7656409, -21343950, 2095755, 29769758, 6593415}, + FieldElement{-31994208, -2907461, 4176912, 3264766, 12538965, -868111, 26312345, -6118678, 30958054, 8292160}, + }, + { + FieldElement{31429822, -13959116, 29173532, 15632448, 12174511, -2760094, 32808831, 3977186, 26143136, -3148876}, + FieldElement{22648901, 1402143, -22799984, 13746059, 7936347, 365344, -8668633, -1674433, -3758243, -2304625}, + FieldElement{-15491917, 8012313, -2514730, -12702462, -23965846, -10254029, -1612713, -1535569, -16664475, 8194478}, + }, + { + FieldElement{27338066, -7507420, -7414224, 10140405, -19026427, -6589889, 27277191, 8855376, 28572286, 3005164}, + FieldElement{26287124, 4821776, 25476601, -4145903, -3764513, -15788984, -18008582, 1182479, -26094821, -13079595}, + FieldElement{-7171154, 3178080, 23970071, 6201893, -17195577, -4489192, -21876275, -13982627, 32208683, -1198248}, + }, + { + FieldElement{-16657702, 2817643, -10286362, 14811298, 6024667, 13349505, -27315504, -10497842, -27672585, -11539858}, + FieldElement{15941029, -9405932, -21367050, 8062055, 31876073, -238629, -15278393, -1444429, 15397331, -4130193}, + FieldElement{8934485, -13485467, -23286397, -13423241, -32446090, 14047986, 31170398, -1441021, -27505566, 15087184}, + }, + { + FieldElement{-18357243, -2156491, 24524913, -16677868, 15520427, -6360776, -15502406, 11461896, 16788528, -5868942}, + FieldElement{-1947386, 16013773, 21750665, 3714552, -17401782, -16055433, -3770287, -10323320, 31322514, -11615635}, + FieldElement{21426655, -5650218, -13648287, -5347537, -28812189, -4920970, -18275391, -14621414, 13040862, -12112948}, + }, + { + FieldElement{11293895, 12478086, -27136401, 15083750, -29307421, 14748872, 14555558, -13417103, 1613711, 4896935}, + FieldElement{-25894883, 15323294, -8489791, -8057900, 25967126, -13425460, 2825960, -4897045, -23971776, -11267415}, + FieldElement{-15924766, -5229880, -17443532, 6410664, 3622847, 10243618, 20615400, 12405433, -23753030, -8436416}, + }, + { + FieldElement{-7091295, 12556208, -20191352, 9025187, -17072479, 4333801, 4378436, 2432030, 23097949, -566018}, + FieldElement{4565804, -16025654, 20084412, -7842817, 1724999, 189254, 24767264, 10103221, -18512313, 2424778}, + FieldElement{366633, -11976806, 8173090, -6890119, 30788634, 5745705, -7168678, 1344109, -3642553, 12412659}, + }, + { + FieldElement{-24001791, 7690286, 14929416, -168257, -32210835, -13412986, 24162697, -15326504, -3141501, 11179385}, + FieldElement{18289522, -14724954, 8056945, 16430056, -21729724, 7842514, -6001441, -1486897, -18684645, -11443503}, + FieldElement{476239, 6601091, -6152790, -9723375, 17503545, -4863900, 27672959, 13403813, 11052904, 5219329}, + }, + }, + { + { + FieldElement{20678546, -8375738, -32671898, 8849123, -5009758, 14574752, 31186971, -3973730, 9014762, -8579056}, + FieldElement{-13644050, -10350239, -15962508, 5075808, -1514661, -11534600, -33102500, 9160280, 8473550, -3256838}, + FieldElement{24900749, 14435722, 17209120, -15292541, -22592275, 9878983, -7689309, -16335821, -24568481, 11788948}, + }, + { + FieldElement{-3118155, -11395194, -13802089, 14797441, 9652448, -6845904, -20037437, 10410733, -24568470, -1458691}, + FieldElement{-15659161, 16736706, -22467150, 10215878, -9097177, 7563911, 11871841, -12505194, -18513325, 8464118}, + FieldElement{-23400612, 8348507, -14585951, -861714, -3950205, -6373419, 14325289, 8628612, 33313881, -8370517}, + }, + { + FieldElement{-20186973, -4967935, 22367356, 5271547, -1097117, -4788838, -24805667, -10236854, -8940735, -5818269}, + FieldElement{-6948785, -1795212, -32625683, -16021179, 32635414, -7374245, 15989197, -12838188, 28358192, -4253904}, + FieldElement{-23561781, -2799059, -32351682, -1661963, -9147719, 10429267, -16637684, 4072016, -5351664, 5596589}, + }, + { + FieldElement{-28236598, -3390048, 12312896, 6213178, 3117142, 16078565, 29266239, 2557221, 1768301, 15373193}, + FieldElement{-7243358, -3246960, -4593467, -7553353, -127927, -912245, -1090902, -4504991, -24660491, 3442910}, + FieldElement{-30210571, 5124043, 14181784, 8197961, 18964734, -11939093, 22597931, 7176455, -18585478, 13365930}, + }, + { + FieldElement{-7877390, -1499958, 8324673, 4690079, 6261860, 890446, 24538107, -8570186, -9689599, -3031667}, + FieldElement{25008904, -10771599, -4305031, -9638010, 16265036, 15721635, 683793, -11823784, 15723479, -15163481}, + FieldElement{-9660625, 12374379, -27006999, -7026148, -7724114, -12314514, 11879682, 5400171, 519526, -1235876}, + }, + { + FieldElement{22258397, -16332233, -7869817, 14613016, -22520255, -2950923, -20353881, 7315967, 16648397, 7605640}, + FieldElement{-8081308, -8464597, -8223311, 9719710, 19259459, -15348212, 23994942, -5281555, -9468848, 4763278}, + FieldElement{-21699244, 9220969, -15730624, 1084137, -25476107, -2852390, 31088447, -7764523, -11356529, 728112}, + }, + { + FieldElement{26047220, -11751471, -6900323, -16521798, 24092068, 9158119, -4273545, -12555558, -29365436, -5498272}, + FieldElement{17510331, -322857, 5854289, 8403524, 17133918, -3112612, -28111007, 12327945, 10750447, 10014012}, + FieldElement{-10312768, 3936952, 9156313, -8897683, 16498692, -994647, -27481051, -666732, 3424691, 7540221}, + }, + { + FieldElement{30322361, -6964110, 11361005, -4143317, 7433304, 4989748, -7071422, -16317219, -9244265, 15258046}, + FieldElement{13054562, -2779497, 19155474, 469045, -12482797, 4566042, 5631406, 2711395, 1062915, -5136345}, + FieldElement{-19240248, -11254599, -29509029, -7499965, -5835763, 13005411, -6066489, 12194497, 32960380, 1459310}, + }, + }, + { + { + FieldElement{19852034, 7027924, 23669353, 10020366, 8586503, -6657907, 394197, -6101885, 18638003, -11174937}, + FieldElement{31395534, 15098109, 26581030, 8030562, -16527914, -5007134, 9012486, -7584354, -6643087, -5442636}, + FieldElement{-9192165, -2347377, -1997099, 4529534, 25766844, 607986, -13222, 9677543, -32294889, -6456008}, + }, + { + FieldElement{-2444496, -149937, 29348902, 8186665, 1873760, 12489863, -30934579, -7839692, -7852844, -8138429}, + FieldElement{-15236356, -15433509, 7766470, 746860, 26346930, -10221762, -27333451, 10754588, -9431476, 5203576}, + FieldElement{31834314, 14135496, -770007, 5159118, 20917671, -16768096, -7467973, -7337524, 31809243, 7347066}, + }, + { + FieldElement{-9606723, -11874240, 20414459, 13033986, 13716524, -11691881, 19797970, -12211255, 15192876, -2087490}, + FieldElement{-12663563, -2181719, 1168162, -3804809, 26747877, -14138091, 10609330, 12694420, 33473243, -13382104}, + FieldElement{33184999, 11180355, 15832085, -11385430, -1633671, 225884, 15089336, -11023903, -6135662, 14480053}, + }, + { + FieldElement{31308717, -5619998, 31030840, -1897099, 15674547, -6582883, 5496208, 13685227, 27595050, 8737275}, + FieldElement{-20318852, -15150239, 10933843, -16178022, 8335352, -7546022, -31008351, -12610604, 26498114, 66511}, + FieldElement{22644454, -8761729, -16671776, 4884562, -3105614, -13559366, 30540766, -4286747, -13327787, -7515095}, + }, + { + FieldElement{-28017847, 9834845, 18617207, -2681312, -3401956, -13307506, 8205540, 13585437, -17127465, 15115439}, + FieldElement{23711543, -672915, 31206561, -8362711, 6164647, -9709987, -33535882, -1426096, 8236921, 16492939}, + FieldElement{-23910559, -13515526, -26299483, -4503841, 25005590, -7687270, 19574902, 10071562, 6708380, -6222424}, + }, + { + FieldElement{2101391, -4930054, 19702731, 2367575, -15427167, 1047675, 5301017, 9328700, 29955601, -11678310}, + FieldElement{3096359, 9271816, -21620864, -15521844, -14847996, -7592937, -25892142, -12635595, -9917575, 6216608}, + FieldElement{-32615849, 338663, -25195611, 2510422, -29213566, -13820213, 24822830, -6146567, -26767480, 7525079}, + }, + { + FieldElement{-23066649, -13985623, 16133487, -7896178, -3389565, 778788, -910336, -2782495, -19386633, 11994101}, + FieldElement{21691500, -13624626, -641331, -14367021, 3285881, -3483596, -25064666, 9718258, -7477437, 13381418}, + FieldElement{18445390, -4202236, 14979846, 11622458, -1727110, -3582980, 23111648, -6375247, 28535282, 15779576}, + }, + { + FieldElement{30098053, 3089662, -9234387, 16662135, -21306940, 11308411, -14068454, 12021730, 9955285, -16303356}, + FieldElement{9734894, -14576830, -7473633, -9138735, 2060392, 11313496, -18426029, 9924399, 20194861, 13380996}, + FieldElement{-26378102, -7965207, -22167821, 15789297, -18055342, -6168792, -1984914, 15707771, 26342023, 10146099}, + }, + }, + { + { + FieldElement{-26016874, -219943, 21339191, -41388, 19745256, -2878700, -29637280, 2227040, 21612326, -545728}, + FieldElement{-13077387, 1184228, 23562814, -5970442, -20351244, -6348714, 25764461, 12243797, -20856566, 11649658}, + FieldElement{-10031494, 11262626, 27384172, 2271902, 26947504, -15997771, 39944, 6114064, 33514190, 2333242}, + }, + { + FieldElement{-21433588, -12421821, 8119782, 7219913, -21830522, -9016134, -6679750, -12670638, 24350578, -13450001}, + FieldElement{-4116307, -11271533, -23886186, 4843615, -30088339, 690623, -31536088, -10406836, 8317860, 12352766}, + FieldElement{18200138, -14475911, -33087759, -2696619, -23702521, -9102511, -23552096, -2287550, 20712163, 6719373}, + }, + { + FieldElement{26656208, 6075253, -7858556, 1886072, -28344043, 4262326, 11117530, -3763210, 26224235, -3297458}, + FieldElement{-17168938, -14854097, -3395676, -16369877, -19954045, 14050420, 21728352, 9493610, 18620611, -16428628}, + FieldElement{-13323321, 13325349, 11432106, 5964811, 18609221, 6062965, -5269471, -9725556, -30701573, -16479657}, + }, + { + FieldElement{-23860538, -11233159, 26961357, 1640861, -32413112, -16737940, 12248509, -5240639, 13735342, 1934062}, + FieldElement{25089769, 6742589, 17081145, -13406266, 21909293, -16067981, -15136294, -3765346, -21277997, 5473616}, + FieldElement{31883677, -7961101, 1083432, -11572403, 22828471, 13290673, -7125085, 12469656, 29111212, -5451014}, + }, + { + FieldElement{24244947, -15050407, -26262976, 2791540, -14997599, 16666678, 24367466, 6388839, -10295587, 452383}, + FieldElement{-25640782, -3417841, 5217916, 16224624, 19987036, -4082269, -24236251, -5915248, 15766062, 8407814}, + FieldElement{-20406999, 13990231, 15495425, 16395525, 5377168, 15166495, -8917023, -4388953, -8067909, 2276718}, + }, + { + FieldElement{30157918, 12924066, -17712050, 9245753, 19895028, 3368142, -23827587, 5096219, 22740376, -7303417}, + FieldElement{2041139, -14256350, 7783687, 13876377, -25946985, -13352459, 24051124, 13742383, -15637599, 13295222}, + FieldElement{33338237, -8505733, 12532113, 7977527, 9106186, -1715251, -17720195, -4612972, -4451357, -14669444}, + }, + { + FieldElement{-20045281, 5454097, -14346548, 6447146, 28862071, 1883651, -2469266, -4141880, 7770569, 9620597}, + FieldElement{23208068, 7979712, 33071466, 8149229, 1758231, -10834995, 30945528, -1694323, -33502340, -14767970}, + FieldElement{1439958, -16270480, -1079989, -793782, 4625402, 10647766, -5043801, 1220118, 30494170, -11440799}, + }, + { + FieldElement{-5037580, -13028295, -2970559, -3061767, 15640974, -6701666, -26739026, 926050, -1684339, -13333647}, + FieldElement{13908495, -3549272, 30919928, -6273825, -21521863, 7989039, 9021034, 9078865, 3353509, 4033511}, + FieldElement{-29663431, -15113610, 32259991, -344482, 24295849, -12912123, 23161163, 8839127, 27485041, 7356032}, + }, + }, + { + { + FieldElement{9661027, 705443, 11980065, -5370154, -1628543, 14661173, -6346142, 2625015, 28431036, -16771834}, + FieldElement{-23839233, -8311415, -25945511, 7480958, -17681669, -8354183, -22545972, 14150565, 15970762, 4099461}, + FieldElement{29262576, 16756590, 26350592, -8793563, 8529671, -11208050, 13617293, -9937143, 11465739, 8317062}, + }, + { + FieldElement{-25493081, -6962928, 32500200, -9419051, -23038724, -2302222, 14898637, 3848455, 20969334, -5157516}, + FieldElement{-20384450, -14347713, -18336405, 13884722, -33039454, 2842114, -21610826, -3649888, 11177095, 14989547}, + FieldElement{-24496721, -11716016, 16959896, 2278463, 12066309, 10137771, 13515641, 2581286, -28487508, 9930240}, + }, + { + FieldElement{-17751622, -2097826, 16544300, -13009300, -15914807, -14949081, 18345767, -13403753, 16291481, -5314038}, + FieldElement{-33229194, 2553288, 32678213, 9875984, 8534129, 6889387, -9676774, 6957617, 4368891, 9788741}, + FieldElement{16660756, 7281060, -10830758, 12911820, 20108584, -8101676, -21722536, -8613148, 16250552, -11111103}, + }, + { + FieldElement{-19765507, 2390526, -16551031, 14161980, 1905286, 6414907, 4689584, 10604807, -30190403, 4782747}, + FieldElement{-1354539, 14736941, -7367442, -13292886, 7710542, -14155590, -9981571, 4383045, 22546403, 437323}, + FieldElement{31665577, -12180464, -16186830, 1491339, -18368625, 3294682, 27343084, 2786261, -30633590, -14097016}, + }, + { + FieldElement{-14467279, -683715, -33374107, 7448552, 19294360, 14334329, -19690631, 2355319, -19284671, -6114373}, + FieldElement{15121312, -15796162, 6377020, -6031361, -10798111, -12957845, 18952177, 15496498, -29380133, 11754228}, + FieldElement{-2637277, -13483075, 8488727, -14303896, 12728761, -1622493, 7141596, 11724556, 22761615, -10134141}, + }, + { + FieldElement{16918416, 11729663, -18083579, 3022987, -31015732, -13339659, -28741185, -12227393, 32851222, 11717399}, + FieldElement{11166634, 7338049, -6722523, 4531520, -29468672, -7302055, 31474879, 3483633, -1193175, -4030831}, + FieldElement{-185635, 9921305, 31456609, -13536438, -12013818, 13348923, 33142652, 6546660, -19985279, -3948376}, + }, + { + FieldElement{-32460596, 11266712, -11197107, -7899103, 31703694, 3855903, -8537131, -12833048, -30772034, -15486313}, + FieldElement{-18006477, 12709068, 3991746, -6479188, -21491523, -10550425, -31135347, -16049879, 10928917, 3011958}, + FieldElement{-6957757, -15594337, 31696059, 334240, 29576716, 14796075, -30831056, -12805180, 18008031, 10258577}, + }, + { + FieldElement{-22448644, 15655569, 7018479, -4410003, -30314266, -1201591, -1853465, 1367120, 25127874, 6671743}, + FieldElement{29701166, -14373934, -10878120, 9279288, -17568, 13127210, 21382910, 11042292, 25838796, 4642684}, + FieldElement{-20430234, 14955537, -24126347, 8124619, -5369288, -5990470, 30468147, -13900640, 18423289, 4177476}, + }, + }, +} diff --git a/internal/crypto/ed25519/internal/edwards25519/edwards25519.go b/internal/crypto/ed25519/internal/edwards25519/edwards25519.go new file mode 100644 index 000000000..fd03c252a --- /dev/null +++ b/internal/crypto/ed25519/internal/edwards25519/edwards25519.go @@ -0,0 +1,1793 @@ +// Copyright 2016 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package edwards25519 + +import "encoding/binary" + +// This code is a port of the public domain, “ref10” implementation of ed25519 +// from SUPERCOP. + +// FieldElement represents an element of the field GF(2^255 - 19). An element +// t, entries t[0]...t[9], represents the integer t[0]+2^26 t[1]+2^51 t[2]+2^77 +// t[3]+2^102 t[4]+...+2^230 t[9]. Bounds on each t[i] vary depending on +// context. +type FieldElement [10]int32 + +var zero FieldElement + +func FeZero(fe *FieldElement) { + copy(fe[:], zero[:]) +} + +func FeOne(fe *FieldElement) { + FeZero(fe) + fe[0] = 1 +} + +func FeAdd(dst, a, b *FieldElement) { + dst[0] = a[0] + b[0] + dst[1] = a[1] + b[1] + dst[2] = a[2] + b[2] + dst[3] = a[3] + b[3] + dst[4] = a[4] + b[4] + dst[5] = a[5] + b[5] + dst[6] = a[6] + b[6] + dst[7] = a[7] + b[7] + dst[8] = a[8] + b[8] + dst[9] = a[9] + b[9] +} + +func FeSub(dst, a, b *FieldElement) { + dst[0] = a[0] - b[0] + dst[1] = a[1] - b[1] + dst[2] = a[2] - b[2] + dst[3] = a[3] - b[3] + dst[4] = a[4] - b[4] + dst[5] = a[5] - b[5] + dst[6] = a[6] - b[6] + dst[7] = a[7] - b[7] + dst[8] = a[8] - b[8] + dst[9] = a[9] - b[9] +} + +func FeCopy(dst, src *FieldElement) { + copy(dst[:], src[:]) +} + +// Replace (f,g) with (g,g) if b == 1; +// replace (f,g) with (f,g) if b == 0. +// +// Preconditions: b in {0,1}. +func FeCMove(f, g *FieldElement, b int32) { + b = -b + f[0] ^= b & (f[0] ^ g[0]) + f[1] ^= b & (f[1] ^ g[1]) + f[2] ^= b & (f[2] ^ g[2]) + f[3] ^= b & (f[3] ^ g[3]) + f[4] ^= b & (f[4] ^ g[4]) + f[5] ^= b & (f[5] ^ g[5]) + f[6] ^= b & (f[6] ^ g[6]) + f[7] ^= b & (f[7] ^ g[7]) + f[8] ^= b & (f[8] ^ g[8]) + f[9] ^= b & (f[9] ^ g[9]) +} + +func load3(in []byte) int64 { + var r int64 + r = int64(in[0]) + r |= int64(in[1]) << 8 + r |= int64(in[2]) << 16 + return r +} + +func load4(in []byte) int64 { + var r int64 + r = int64(in[0]) + r |= int64(in[1]) << 8 + r |= int64(in[2]) << 16 + r |= int64(in[3]) << 24 + return r +} + +func FeFromBytes(dst *FieldElement, src *[32]byte) { + h0 := load4(src[:]) + h1 := load3(src[4:]) << 6 + h2 := load3(src[7:]) << 5 + h3 := load3(src[10:]) << 3 + h4 := load3(src[13:]) << 2 + h5 := load4(src[16:]) + h6 := load3(src[20:]) << 7 + h7 := load3(src[23:]) << 5 + h8 := load3(src[26:]) << 4 + h9 := (load3(src[29:]) & 8388607) << 2 + + FeCombine(dst, h0, h1, h2, h3, h4, h5, h6, h7, h8, h9) +} + +// FeToBytes marshals h to s. +// Preconditions: +// |h| bounded by 1.1*2^25,1.1*2^24,1.1*2^25,1.1*2^24,etc. +// +// Write p=2^255-19; q=floor(h/p). +// Basic claim: q = floor(2^(-255)(h + 19 2^(-25)h9 + 2^(-1))). +// +// Proof: +// Have |h|<=p so |q|<=1 so |19^2 2^(-255) q|<1/4. +// Also have |h-2^230 h9|<2^230 so |19 2^(-255)(h-2^230 h9)|<1/4. +// +// Write y=2^(-1)-19^2 2^(-255)q-19 2^(-255)(h-2^230 h9). +// Then 0> 25 + q = (h[0] + q) >> 26 + q = (h[1] + q) >> 25 + q = (h[2] + q) >> 26 + q = (h[3] + q) >> 25 + q = (h[4] + q) >> 26 + q = (h[5] + q) >> 25 + q = (h[6] + q) >> 26 + q = (h[7] + q) >> 25 + q = (h[8] + q) >> 26 + q = (h[9] + q) >> 25 + + // Goal: Output h-(2^255-19)q, which is between 0 and 2^255-20. + h[0] += 19 * q + // Goal: Output h-2^255 q, which is between 0 and 2^255-20. + + carry[0] = h[0] >> 26 + h[1] += carry[0] + h[0] -= carry[0] << 26 + carry[1] = h[1] >> 25 + h[2] += carry[1] + h[1] -= carry[1] << 25 + carry[2] = h[2] >> 26 + h[3] += carry[2] + h[2] -= carry[2] << 26 + carry[3] = h[3] >> 25 + h[4] += carry[3] + h[3] -= carry[3] << 25 + carry[4] = h[4] >> 26 + h[5] += carry[4] + h[4] -= carry[4] << 26 + carry[5] = h[5] >> 25 + h[6] += carry[5] + h[5] -= carry[5] << 25 + carry[6] = h[6] >> 26 + h[7] += carry[6] + h[6] -= carry[6] << 26 + carry[7] = h[7] >> 25 + h[8] += carry[7] + h[7] -= carry[7] << 25 + carry[8] = h[8] >> 26 + h[9] += carry[8] + h[8] -= carry[8] << 26 + carry[9] = h[9] >> 25 + h[9] -= carry[9] << 25 + // h10 = carry9 + + // Goal: Output h[0]+...+2^255 h10-2^255 q, which is between 0 and 2^255-20. + // Have h[0]+...+2^230 h[9] between 0 and 2^255-1; + // evidently 2^255 h10-2^255 q = 0. + // Goal: Output h[0]+...+2^230 h[9]. + + s[0] = byte(h[0] >> 0) + s[1] = byte(h[0] >> 8) + s[2] = byte(h[0] >> 16) + s[3] = byte((h[0] >> 24) | (h[1] << 2)) + s[4] = byte(h[1] >> 6) + s[5] = byte(h[1] >> 14) + s[6] = byte((h[1] >> 22) | (h[2] << 3)) + s[7] = byte(h[2] >> 5) + s[8] = byte(h[2] >> 13) + s[9] = byte((h[2] >> 21) | (h[3] << 5)) + s[10] = byte(h[3] >> 3) + s[11] = byte(h[3] >> 11) + s[12] = byte((h[3] >> 19) | (h[4] << 6)) + s[13] = byte(h[4] >> 2) + s[14] = byte(h[4] >> 10) + s[15] = byte(h[4] >> 18) + s[16] = byte(h[5] >> 0) + s[17] = byte(h[5] >> 8) + s[18] = byte(h[5] >> 16) + s[19] = byte((h[5] >> 24) | (h[6] << 1)) + s[20] = byte(h[6] >> 7) + s[21] = byte(h[6] >> 15) + s[22] = byte((h[6] >> 23) | (h[7] << 3)) + s[23] = byte(h[7] >> 5) + s[24] = byte(h[7] >> 13) + s[25] = byte((h[7] >> 21) | (h[8] << 4)) + s[26] = byte(h[8] >> 4) + s[27] = byte(h[8] >> 12) + s[28] = byte((h[8] >> 20) | (h[9] << 6)) + s[29] = byte(h[9] >> 2) + s[30] = byte(h[9] >> 10) + s[31] = byte(h[9] >> 18) +} + +func FeIsNegative(f *FieldElement) byte { + var s [32]byte + FeToBytes(&s, f) + return s[0] & 1 +} + +func FeIsNonZero(f *FieldElement) int32 { + var s [32]byte + FeToBytes(&s, f) + var x uint8 + for _, b := range s { + x |= b + } + x |= x >> 4 + x |= x >> 2 + x |= x >> 1 + return int32(x & 1) +} + +// FeNeg sets h = -f +// +// Preconditions: +// |f| bounded by 1.1*2^25,1.1*2^24,1.1*2^25,1.1*2^24,etc. +// +// Postconditions: +// |h| bounded by 1.1*2^25,1.1*2^24,1.1*2^25,1.1*2^24,etc. +func FeNeg(h, f *FieldElement) { + h[0] = -f[0] + h[1] = -f[1] + h[2] = -f[2] + h[3] = -f[3] + h[4] = -f[4] + h[5] = -f[5] + h[6] = -f[6] + h[7] = -f[7] + h[8] = -f[8] + h[9] = -f[9] +} + +func FeCombine(h *FieldElement, h0, h1, h2, h3, h4, h5, h6, h7, h8, h9 int64) { + var c0, c1, c2, c3, c4, c5, c6, c7, c8, c9 int64 + + /* + |h0| <= (1.1*1.1*2^52*(1+19+19+19+19)+1.1*1.1*2^50*(38+38+38+38+38)) + i.e. |h0| <= 1.2*2^59; narrower ranges for h2, h4, h6, h8 + |h1| <= (1.1*1.1*2^51*(1+1+19+19+19+19+19+19+19+19)) + i.e. |h1| <= 1.5*2^58; narrower ranges for h3, h5, h7, h9 + */ + + c0 = (h0 + (1 << 25)) >> 26 + h1 += c0 + h0 -= c0 << 26 + c4 = (h4 + (1 << 25)) >> 26 + h5 += c4 + h4 -= c4 << 26 + /* |h0| <= 2^25 */ + /* |h4| <= 2^25 */ + /* |h1| <= 1.51*2^58 */ + /* |h5| <= 1.51*2^58 */ + + c1 = (h1 + (1 << 24)) >> 25 + h2 += c1 + h1 -= c1 << 25 + c5 = (h5 + (1 << 24)) >> 25 + h6 += c5 + h5 -= c5 << 25 + /* |h1| <= 2^24; from now on fits into int32 */ + /* |h5| <= 2^24; from now on fits into int32 */ + /* |h2| <= 1.21*2^59 */ + /* |h6| <= 1.21*2^59 */ + + c2 = (h2 + (1 << 25)) >> 26 + h3 += c2 + h2 -= c2 << 26 + c6 = (h6 + (1 << 25)) >> 26 + h7 += c6 + h6 -= c6 << 26 + /* |h2| <= 2^25; from now on fits into int32 unchanged */ + /* |h6| <= 2^25; from now on fits into int32 unchanged */ + /* |h3| <= 1.51*2^58 */ + /* |h7| <= 1.51*2^58 */ + + c3 = (h3 + (1 << 24)) >> 25 + h4 += c3 + h3 -= c3 << 25 + c7 = (h7 + (1 << 24)) >> 25 + h8 += c7 + h7 -= c7 << 25 + /* |h3| <= 2^24; from now on fits into int32 unchanged */ + /* |h7| <= 2^24; from now on fits into int32 unchanged */ + /* |h4| <= 1.52*2^33 */ + /* |h8| <= 1.52*2^33 */ + + c4 = (h4 + (1 << 25)) >> 26 + h5 += c4 + h4 -= c4 << 26 + c8 = (h8 + (1 << 25)) >> 26 + h9 += c8 + h8 -= c8 << 26 + /* |h4| <= 2^25; from now on fits into int32 unchanged */ + /* |h8| <= 2^25; from now on fits into int32 unchanged */ + /* |h5| <= 1.01*2^24 */ + /* |h9| <= 1.51*2^58 */ + + c9 = (h9 + (1 << 24)) >> 25 + h0 += c9 * 19 + h9 -= c9 << 25 + /* |h9| <= 2^24; from now on fits into int32 unchanged */ + /* |h0| <= 1.8*2^37 */ + + c0 = (h0 + (1 << 25)) >> 26 + h1 += c0 + h0 -= c0 << 26 + /* |h0| <= 2^25; from now on fits into int32 unchanged */ + /* |h1| <= 1.01*2^24 */ + + h[0] = int32(h0) + h[1] = int32(h1) + h[2] = int32(h2) + h[3] = int32(h3) + h[4] = int32(h4) + h[5] = int32(h5) + h[6] = int32(h6) + h[7] = int32(h7) + h[8] = int32(h8) + h[9] = int32(h9) +} + +// FeMul calculates h = f * g +// Can overlap h with f or g. +// +// Preconditions: +// |f| bounded by 1.1*2^26,1.1*2^25,1.1*2^26,1.1*2^25,etc. +// |g| bounded by 1.1*2^26,1.1*2^25,1.1*2^26,1.1*2^25,etc. +// +// Postconditions: +// |h| bounded by 1.1*2^25,1.1*2^24,1.1*2^25,1.1*2^24,etc. +// +// Notes on implementation strategy: +// +// Using schoolbook multiplication. +// Karatsuba would save a little in some cost models. +// +// Most multiplications by 2 and 19 are 32-bit precomputations; +// cheaper than 64-bit postcomputations. +// +// There is one remaining multiplication by 19 in the carry chain; +// one *19 precomputation can be merged into this, +// but the resulting data flow is considerably less clean. +// +// There are 12 carries below. +// 10 of them are 2-way parallelizable and vectorizable. +// Can get away with 11 carries, but then data flow is much deeper. +// +// With tighter constraints on inputs, can squeeze carries into int32. +func FeMul(h, f, g *FieldElement) { + f0 := int64(f[0]) + f1 := int64(f[1]) + f2 := int64(f[2]) + f3 := int64(f[3]) + f4 := int64(f[4]) + f5 := int64(f[5]) + f6 := int64(f[6]) + f7 := int64(f[7]) + f8 := int64(f[8]) + f9 := int64(f[9]) + + f1_2 := int64(2 * f[1]) + f3_2 := int64(2 * f[3]) + f5_2 := int64(2 * f[5]) + f7_2 := int64(2 * f[7]) + f9_2 := int64(2 * f[9]) + + g0 := int64(g[0]) + g1 := int64(g[1]) + g2 := int64(g[2]) + g3 := int64(g[3]) + g4 := int64(g[4]) + g5 := int64(g[5]) + g6 := int64(g[6]) + g7 := int64(g[7]) + g8 := int64(g[8]) + g9 := int64(g[9]) + + g1_19 := int64(19 * g[1]) /* 1.4*2^29 */ + g2_19 := int64(19 * g[2]) /* 1.4*2^30; still ok */ + g3_19 := int64(19 * g[3]) + g4_19 := int64(19 * g[4]) + g5_19 := int64(19 * g[5]) + g6_19 := int64(19 * g[6]) + g7_19 := int64(19 * g[7]) + g8_19 := int64(19 * g[8]) + g9_19 := int64(19 * g[9]) + + h0 := f0*g0 + f1_2*g9_19 + f2*g8_19 + f3_2*g7_19 + f4*g6_19 + f5_2*g5_19 + f6*g4_19 + f7_2*g3_19 + f8*g2_19 + f9_2*g1_19 + h1 := f0*g1 + f1*g0 + f2*g9_19 + f3*g8_19 + f4*g7_19 + f5*g6_19 + f6*g5_19 + f7*g4_19 + f8*g3_19 + f9*g2_19 + h2 := f0*g2 + f1_2*g1 + f2*g0 + f3_2*g9_19 + f4*g8_19 + f5_2*g7_19 + f6*g6_19 + f7_2*g5_19 + f8*g4_19 + f9_2*g3_19 + h3 := f0*g3 + f1*g2 + f2*g1 + f3*g0 + f4*g9_19 + f5*g8_19 + f6*g7_19 + f7*g6_19 + f8*g5_19 + f9*g4_19 + h4 := f0*g4 + f1_2*g3 + f2*g2 + f3_2*g1 + f4*g0 + f5_2*g9_19 + f6*g8_19 + f7_2*g7_19 + f8*g6_19 + f9_2*g5_19 + h5 := f0*g5 + f1*g4 + f2*g3 + f3*g2 + f4*g1 + f5*g0 + f6*g9_19 + f7*g8_19 + f8*g7_19 + f9*g6_19 + h6 := f0*g6 + f1_2*g5 + f2*g4 + f3_2*g3 + f4*g2 + f5_2*g1 + f6*g0 + f7_2*g9_19 + f8*g8_19 + f9_2*g7_19 + h7 := f0*g7 + f1*g6 + f2*g5 + f3*g4 + f4*g3 + f5*g2 + f6*g1 + f7*g0 + f8*g9_19 + f9*g8_19 + h8 := f0*g8 + f1_2*g7 + f2*g6 + f3_2*g5 + f4*g4 + f5_2*g3 + f6*g2 + f7_2*g1 + f8*g0 + f9_2*g9_19 + h9 := f0*g9 + f1*g8 + f2*g7 + f3*g6 + f4*g5 + f5*g4 + f6*g3 + f7*g2 + f8*g1 + f9*g0 + + FeCombine(h, h0, h1, h2, h3, h4, h5, h6, h7, h8, h9) +} + +func feSquare(f *FieldElement) (h0, h1, h2, h3, h4, h5, h6, h7, h8, h9 int64) { + f0 := int64(f[0]) + f1 := int64(f[1]) + f2 := int64(f[2]) + f3 := int64(f[3]) + f4 := int64(f[4]) + f5 := int64(f[5]) + f6 := int64(f[6]) + f7 := int64(f[7]) + f8 := int64(f[8]) + f9 := int64(f[9]) + f0_2 := int64(2 * f[0]) + f1_2 := int64(2 * f[1]) + f2_2 := int64(2 * f[2]) + f3_2 := int64(2 * f[3]) + f4_2 := int64(2 * f[4]) + f5_2 := int64(2 * f[5]) + f6_2 := int64(2 * f[6]) + f7_2 := int64(2 * f[7]) + f5_38 := 38 * f5 // 1.31*2^30 + f6_19 := 19 * f6 // 1.31*2^30 + f7_38 := 38 * f7 // 1.31*2^30 + f8_19 := 19 * f8 // 1.31*2^30 + f9_38 := 38 * f9 // 1.31*2^30 + + h0 = f0*f0 + f1_2*f9_38 + f2_2*f8_19 + f3_2*f7_38 + f4_2*f6_19 + f5*f5_38 + h1 = f0_2*f1 + f2*f9_38 + f3_2*f8_19 + f4*f7_38 + f5_2*f6_19 + h2 = f0_2*f2 + f1_2*f1 + f3_2*f9_38 + f4_2*f8_19 + f5_2*f7_38 + f6*f6_19 + h3 = f0_2*f3 + f1_2*f2 + f4*f9_38 + f5_2*f8_19 + f6*f7_38 + h4 = f0_2*f4 + f1_2*f3_2 + f2*f2 + f5_2*f9_38 + f6_2*f8_19 + f7*f7_38 + h5 = f0_2*f5 + f1_2*f4 + f2_2*f3 + f6*f9_38 + f7_2*f8_19 + h6 = f0_2*f6 + f1_2*f5_2 + f2_2*f4 + f3_2*f3 + f7_2*f9_38 + f8*f8_19 + h7 = f0_2*f7 + f1_2*f6 + f2_2*f5 + f3_2*f4 + f8*f9_38 + h8 = f0_2*f8 + f1_2*f7_2 + f2_2*f6 + f3_2*f5_2 + f4*f4 + f9*f9_38 + h9 = f0_2*f9 + f1_2*f8 + f2_2*f7 + f3_2*f6 + f4_2*f5 + + return +} + +// FeSquare calculates h = f*f. Can overlap h with f. +// +// Preconditions: +// |f| bounded by 1.1*2^26,1.1*2^25,1.1*2^26,1.1*2^25,etc. +// +// Postconditions: +// |h| bounded by 1.1*2^25,1.1*2^24,1.1*2^25,1.1*2^24,etc. +func FeSquare(h, f *FieldElement) { + h0, h1, h2, h3, h4, h5, h6, h7, h8, h9 := feSquare(f) + FeCombine(h, h0, h1, h2, h3, h4, h5, h6, h7, h8, h9) +} + +// FeSquare2 sets h = 2 * f * f +// +// Can overlap h with f. +// +// Preconditions: +// |f| bounded by 1.65*2^26,1.65*2^25,1.65*2^26,1.65*2^25,etc. +// +// Postconditions: +// |h| bounded by 1.01*2^25,1.01*2^24,1.01*2^25,1.01*2^24,etc. +// See fe_mul.c for discussion of implementation strategy. +func FeSquare2(h, f *FieldElement) { + h0, h1, h2, h3, h4, h5, h6, h7, h8, h9 := feSquare(f) + + h0 += h0 + h1 += h1 + h2 += h2 + h3 += h3 + h4 += h4 + h5 += h5 + h6 += h6 + h7 += h7 + h8 += h8 + h9 += h9 + + FeCombine(h, h0, h1, h2, h3, h4, h5, h6, h7, h8, h9) +} + +func FeInvert(out, z *FieldElement) { + var t0, t1, t2, t3 FieldElement + var i int + + FeSquare(&t0, z) // 2^1 + FeSquare(&t1, &t0) // 2^2 + for i = 1; i < 2; i++ { // 2^3 + FeSquare(&t1, &t1) + } + FeMul(&t1, z, &t1) // 2^3 + 2^0 + FeMul(&t0, &t0, &t1) // 2^3 + 2^1 + 2^0 + FeSquare(&t2, &t0) // 2^4 + 2^2 + 2^1 + FeMul(&t1, &t1, &t2) // 2^4 + 2^3 + 2^2 + 2^1 + 2^0 + FeSquare(&t2, &t1) // 5,4,3,2,1 + for i = 1; i < 5; i++ { // 9,8,7,6,5 + FeSquare(&t2, &t2) + } + FeMul(&t1, &t2, &t1) // 9,8,7,6,5,4,3,2,1,0 + FeSquare(&t2, &t1) // 10..1 + for i = 1; i < 10; i++ { // 19..10 + FeSquare(&t2, &t2) + } + FeMul(&t2, &t2, &t1) // 19..0 + FeSquare(&t3, &t2) // 20..1 + for i = 1; i < 20; i++ { // 39..20 + FeSquare(&t3, &t3) + } + FeMul(&t2, &t3, &t2) // 39..0 + FeSquare(&t2, &t2) // 40..1 + for i = 1; i < 10; i++ { // 49..10 + FeSquare(&t2, &t2) + } + FeMul(&t1, &t2, &t1) // 49..0 + FeSquare(&t2, &t1) // 50..1 + for i = 1; i < 50; i++ { // 99..50 + FeSquare(&t2, &t2) + } + FeMul(&t2, &t2, &t1) // 99..0 + FeSquare(&t3, &t2) // 100..1 + for i = 1; i < 100; i++ { // 199..100 + FeSquare(&t3, &t3) + } + FeMul(&t2, &t3, &t2) // 199..0 + FeSquare(&t2, &t2) // 200..1 + for i = 1; i < 50; i++ { // 249..50 + FeSquare(&t2, &t2) + } + FeMul(&t1, &t2, &t1) // 249..0 + FeSquare(&t1, &t1) // 250..1 + for i = 1; i < 5; i++ { // 254..5 + FeSquare(&t1, &t1) + } + FeMul(out, &t1, &t0) // 254..5,3,1,0 +} + +func fePow22523(out, z *FieldElement) { + var t0, t1, t2 FieldElement + var i int + + FeSquare(&t0, z) + for i = 1; i < 1; i++ { + FeSquare(&t0, &t0) + } + FeSquare(&t1, &t0) + for i = 1; i < 2; i++ { + FeSquare(&t1, &t1) + } + FeMul(&t1, z, &t1) + FeMul(&t0, &t0, &t1) + FeSquare(&t0, &t0) + for i = 1; i < 1; i++ { + FeSquare(&t0, &t0) + } + FeMul(&t0, &t1, &t0) + FeSquare(&t1, &t0) + for i = 1; i < 5; i++ { + FeSquare(&t1, &t1) + } + FeMul(&t0, &t1, &t0) + FeSquare(&t1, &t0) + for i = 1; i < 10; i++ { + FeSquare(&t1, &t1) + } + FeMul(&t1, &t1, &t0) + FeSquare(&t2, &t1) + for i = 1; i < 20; i++ { + FeSquare(&t2, &t2) + } + FeMul(&t1, &t2, &t1) + FeSquare(&t1, &t1) + for i = 1; i < 10; i++ { + FeSquare(&t1, &t1) + } + FeMul(&t0, &t1, &t0) + FeSquare(&t1, &t0) + for i = 1; i < 50; i++ { + FeSquare(&t1, &t1) + } + FeMul(&t1, &t1, &t0) + FeSquare(&t2, &t1) + for i = 1; i < 100; i++ { + FeSquare(&t2, &t2) + } + FeMul(&t1, &t2, &t1) + FeSquare(&t1, &t1) + for i = 1; i < 50; i++ { + FeSquare(&t1, &t1) + } + FeMul(&t0, &t1, &t0) + FeSquare(&t0, &t0) + for i = 1; i < 2; i++ { + FeSquare(&t0, &t0) + } + FeMul(out, &t0, z) +} + +// Group elements are members of the elliptic curve -x^2 + y^2 = 1 + d * x^2 * +// y^2 where d = -121665/121666. +// +// Several representations are used: +// ProjectiveGroupElement: (X:Y:Z) satisfying x=X/Z, y=Y/Z +// ExtendedGroupElement: (X:Y:Z:T) satisfying x=X/Z, y=Y/Z, XY=ZT +// CompletedGroupElement: ((X:Z),(Y:T)) satisfying x=X/Z, y=Y/T +// PreComputedGroupElement: (y+x,y-x,2dxy) + +type ProjectiveGroupElement struct { + X, Y, Z FieldElement +} + +type ExtendedGroupElement struct { + X, Y, Z, T FieldElement +} + +type CompletedGroupElement struct { + X, Y, Z, T FieldElement +} + +type PreComputedGroupElement struct { + yPlusX, yMinusX, xy2d FieldElement +} + +type CachedGroupElement struct { + yPlusX, yMinusX, Z, T2d FieldElement +} + +func (p *ProjectiveGroupElement) Zero() { + FeZero(&p.X) + FeOne(&p.Y) + FeOne(&p.Z) +} + +func (p *ProjectiveGroupElement) Double(r *CompletedGroupElement) { + var t0 FieldElement + + FeSquare(&r.X, &p.X) + FeSquare(&r.Z, &p.Y) + FeSquare2(&r.T, &p.Z) + FeAdd(&r.Y, &p.X, &p.Y) + FeSquare(&t0, &r.Y) + FeAdd(&r.Y, &r.Z, &r.X) + FeSub(&r.Z, &r.Z, &r.X) + FeSub(&r.X, &t0, &r.Y) + FeSub(&r.T, &r.T, &r.Z) +} + +func (p *ProjectiveGroupElement) ToBytes(s *[32]byte) { + var recip, x, y FieldElement + + FeInvert(&recip, &p.Z) + FeMul(&x, &p.X, &recip) + FeMul(&y, &p.Y, &recip) + FeToBytes(s, &y) + s[31] ^= FeIsNegative(&x) << 7 +} + +func (p *ExtendedGroupElement) Zero() { + FeZero(&p.X) + FeOne(&p.Y) + FeOne(&p.Z) + FeZero(&p.T) +} + +func (p *ExtendedGroupElement) Double(r *CompletedGroupElement) { + var q ProjectiveGroupElement + p.ToProjective(&q) + q.Double(r) +} + +func (p *ExtendedGroupElement) ToCached(r *CachedGroupElement) { + FeAdd(&r.yPlusX, &p.Y, &p.X) + FeSub(&r.yMinusX, &p.Y, &p.X) + FeCopy(&r.Z, &p.Z) + FeMul(&r.T2d, &p.T, &d2) +} + +func (p *ExtendedGroupElement) ToProjective(r *ProjectiveGroupElement) { + FeCopy(&r.X, &p.X) + FeCopy(&r.Y, &p.Y) + FeCopy(&r.Z, &p.Z) +} + +func (p *ExtendedGroupElement) ToBytes(s *[32]byte) { + var recip, x, y FieldElement + + FeInvert(&recip, &p.Z) + FeMul(&x, &p.X, &recip) + FeMul(&y, &p.Y, &recip) + FeToBytes(s, &y) + s[31] ^= FeIsNegative(&x) << 7 +} + +func (p *ExtendedGroupElement) FromBytes(s *[32]byte) bool { + var u, v, v3, vxx, check FieldElement + + FeFromBytes(&p.Y, s) + FeOne(&p.Z) + FeSquare(&u, &p.Y) + FeMul(&v, &u, &d) + FeSub(&u, &u, &p.Z) // y = y^2-1 + FeAdd(&v, &v, &p.Z) // v = dy^2+1 + + FeSquare(&v3, &v) + FeMul(&v3, &v3, &v) // v3 = v^3 + FeSquare(&p.X, &v3) + FeMul(&p.X, &p.X, &v) + FeMul(&p.X, &p.X, &u) // x = uv^7 + + fePow22523(&p.X, &p.X) // x = (uv^7)^((q-5)/8) + FeMul(&p.X, &p.X, &v3) + FeMul(&p.X, &p.X, &u) // x = uv^3(uv^7)^((q-5)/8) + + var tmpX, tmp2 [32]byte + + FeSquare(&vxx, &p.X) + FeMul(&vxx, &vxx, &v) + FeSub(&check, &vxx, &u) // vx^2-u + if FeIsNonZero(&check) == 1 { + FeAdd(&check, &vxx, &u) // vx^2+u + if FeIsNonZero(&check) == 1 { + return false + } + FeMul(&p.X, &p.X, &SqrtM1) + + FeToBytes(&tmpX, &p.X) + for i, v := range tmpX { + tmp2[31-i] = v + } + } + + if FeIsNegative(&p.X) != (s[31] >> 7) { + FeNeg(&p.X, &p.X) + } + + FeMul(&p.T, &p.X, &p.Y) + return true +} + +func (p *CompletedGroupElement) ToProjective(r *ProjectiveGroupElement) { + FeMul(&r.X, &p.X, &p.T) + FeMul(&r.Y, &p.Y, &p.Z) + FeMul(&r.Z, &p.Z, &p.T) +} + +func (p *CompletedGroupElement) ToExtended(r *ExtendedGroupElement) { + FeMul(&r.X, &p.X, &p.T) + FeMul(&r.Y, &p.Y, &p.Z) + FeMul(&r.Z, &p.Z, &p.T) + FeMul(&r.T, &p.X, &p.Y) +} + +func (p *PreComputedGroupElement) Zero() { + FeOne(&p.yPlusX) + FeOne(&p.yMinusX) + FeZero(&p.xy2d) +} + +func geAdd(r *CompletedGroupElement, p *ExtendedGroupElement, q *CachedGroupElement) { + var t0 FieldElement + + FeAdd(&r.X, &p.Y, &p.X) + FeSub(&r.Y, &p.Y, &p.X) + FeMul(&r.Z, &r.X, &q.yPlusX) + FeMul(&r.Y, &r.Y, &q.yMinusX) + FeMul(&r.T, &q.T2d, &p.T) + FeMul(&r.X, &p.Z, &q.Z) + FeAdd(&t0, &r.X, &r.X) + FeSub(&r.X, &r.Z, &r.Y) + FeAdd(&r.Y, &r.Z, &r.Y) + FeAdd(&r.Z, &t0, &r.T) + FeSub(&r.T, &t0, &r.T) +} + +func geSub(r *CompletedGroupElement, p *ExtendedGroupElement, q *CachedGroupElement) { + var t0 FieldElement + + FeAdd(&r.X, &p.Y, &p.X) + FeSub(&r.Y, &p.Y, &p.X) + FeMul(&r.Z, &r.X, &q.yMinusX) + FeMul(&r.Y, &r.Y, &q.yPlusX) + FeMul(&r.T, &q.T2d, &p.T) + FeMul(&r.X, &p.Z, &q.Z) + FeAdd(&t0, &r.X, &r.X) + FeSub(&r.X, &r.Z, &r.Y) + FeAdd(&r.Y, &r.Z, &r.Y) + FeSub(&r.Z, &t0, &r.T) + FeAdd(&r.T, &t0, &r.T) +} + +func geMixedAdd(r *CompletedGroupElement, p *ExtendedGroupElement, q *PreComputedGroupElement) { + var t0 FieldElement + + FeAdd(&r.X, &p.Y, &p.X) + FeSub(&r.Y, &p.Y, &p.X) + FeMul(&r.Z, &r.X, &q.yPlusX) + FeMul(&r.Y, &r.Y, &q.yMinusX) + FeMul(&r.T, &q.xy2d, &p.T) + FeAdd(&t0, &p.Z, &p.Z) + FeSub(&r.X, &r.Z, &r.Y) + FeAdd(&r.Y, &r.Z, &r.Y) + FeAdd(&r.Z, &t0, &r.T) + FeSub(&r.T, &t0, &r.T) +} + +func geMixedSub(r *CompletedGroupElement, p *ExtendedGroupElement, q *PreComputedGroupElement) { + var t0 FieldElement + + FeAdd(&r.X, &p.Y, &p.X) + FeSub(&r.Y, &p.Y, &p.X) + FeMul(&r.Z, &r.X, &q.yMinusX) + FeMul(&r.Y, &r.Y, &q.yPlusX) + FeMul(&r.T, &q.xy2d, &p.T) + FeAdd(&t0, &p.Z, &p.Z) + FeSub(&r.X, &r.Z, &r.Y) + FeAdd(&r.Y, &r.Z, &r.Y) + FeSub(&r.Z, &t0, &r.T) + FeAdd(&r.T, &t0, &r.T) +} + +func slide(r *[256]int8, a *[32]byte) { + for i := range r { + r[i] = int8(1 & (a[i>>3] >> uint(i&7))) + } + + for i := range r { + if r[i] != 0 { + for b := 1; b <= 6 && i+b < 256; b++ { + if r[i+b] != 0 { + if r[i]+(r[i+b]<= -15 { + r[i] -= r[i+b] << uint(b) + for k := i + b; k < 256; k++ { + if r[k] == 0 { + r[k] = 1 + break + } + r[k] = 0 + } + } else { + break + } + } + } + } + } +} + +// GeDoubleScalarMultVartime sets r = a*A + b*B +// where a = a[0]+256*a[1]+...+256^31 a[31]. +// and b = b[0]+256*b[1]+...+256^31 b[31]. +// B is the Ed25519 base point (x,4/5) with x positive. +func GeDoubleScalarMultVartime(r *ProjectiveGroupElement, a *[32]byte, A *ExtendedGroupElement, b *[32]byte) { + var aSlide, bSlide [256]int8 + var Ai [8]CachedGroupElement // A,3A,5A,7A,9A,11A,13A,15A + var t CompletedGroupElement + var u, A2 ExtendedGroupElement + var i int + + slide(&aSlide, a) + slide(&bSlide, b) + + A.ToCached(&Ai[0]) + A.Double(&t) + t.ToExtended(&A2) + + for i := 0; i < 7; i++ { + geAdd(&t, &A2, &Ai[i]) + t.ToExtended(&u) + u.ToCached(&Ai[i+1]) + } + + r.Zero() + + for i = 255; i >= 0; i-- { + if aSlide[i] != 0 || bSlide[i] != 0 { + break + } + } + + for ; i >= 0; i-- { + r.Double(&t) + + if aSlide[i] > 0 { + t.ToExtended(&u) + geAdd(&t, &u, &Ai[aSlide[i]/2]) + } else if aSlide[i] < 0 { + t.ToExtended(&u) + geSub(&t, &u, &Ai[(-aSlide[i])/2]) + } + + if bSlide[i] > 0 { + t.ToExtended(&u) + geMixedAdd(&t, &u, &bi[bSlide[i]/2]) + } else if bSlide[i] < 0 { + t.ToExtended(&u) + geMixedSub(&t, &u, &bi[(-bSlide[i])/2]) + } + + t.ToProjective(r) + } +} + +// equal returns 1 if b == c and 0 otherwise, assuming that b and c are +// non-negative. +func equal(b, c int32) int32 { + x := uint32(b ^ c) + x-- + return int32(x >> 31) +} + +// negative returns 1 if b < 0 and 0 otherwise. +func negative(b int32) int32 { + return (b >> 31) & 1 +} + +func PreComputedGroupElementCMove(t, u *PreComputedGroupElement, b int32) { + FeCMove(&t.yPlusX, &u.yPlusX, b) + FeCMove(&t.yMinusX, &u.yMinusX, b) + FeCMove(&t.xy2d, &u.xy2d, b) +} + +func selectPoint(t *PreComputedGroupElement, pos int32, b int32) { + var minusT PreComputedGroupElement + bNegative := negative(b) + bAbs := b - (((-bNegative) & b) << 1) + + t.Zero() + for i := int32(0); i < 8; i++ { + PreComputedGroupElementCMove(t, &base[pos][i], equal(bAbs, i+1)) + } + FeCopy(&minusT.yPlusX, &t.yMinusX) + FeCopy(&minusT.yMinusX, &t.yPlusX) + FeNeg(&minusT.xy2d, &t.xy2d) + PreComputedGroupElementCMove(t, &minusT, bNegative) +} + +// GeScalarMultBase computes h = a*B, where +// a = a[0]+256*a[1]+...+256^31 a[31] +// B is the Ed25519 base point (x,4/5) with x positive. +// +// Preconditions: +// a[31] <= 127 +func GeScalarMultBase(h *ExtendedGroupElement, a *[32]byte) { + var e [64]int8 + + for i, v := range a { + e[2*i] = int8(v & 15) + e[2*i+1] = int8((v >> 4) & 15) + } + + // each e[i] is between 0 and 15 and e[63] is between 0 and 7. + + carry := int8(0) + for i := 0; i < 63; i++ { + e[i] += carry + carry = (e[i] + 8) >> 4 + e[i] -= carry << 4 + } + e[63] += carry + // each e[i] is between -8 and 8. + + h.Zero() + var t PreComputedGroupElement + var r CompletedGroupElement + for i := int32(1); i < 64; i += 2 { + selectPoint(&t, i/2, int32(e[i])) + geMixedAdd(&r, h, &t) + r.ToExtended(h) + } + + var s ProjectiveGroupElement + + h.Double(&r) + r.ToProjective(&s) + s.Double(&r) + r.ToProjective(&s) + s.Double(&r) + r.ToProjective(&s) + s.Double(&r) + r.ToExtended(h) + + for i := int32(0); i < 64; i += 2 { + selectPoint(&t, i/2, int32(e[i])) + geMixedAdd(&r, h, &t) + r.ToExtended(h) + } +} + +// The scalars are GF(2^252 + 27742317777372353535851937790883648493). + +// Input: +// a[0]+256*a[1]+...+256^31*a[31] = a +// b[0]+256*b[1]+...+256^31*b[31] = b +// c[0]+256*c[1]+...+256^31*c[31] = c +// +// Output: +// s[0]+256*s[1]+...+256^31*s[31] = (ab+c) mod l +// where l = 2^252 + 27742317777372353535851937790883648493. +func ScMulAdd(s, a, b, c *[32]byte) { + a0 := 2097151 & load3(a[:]) + a1 := 2097151 & (load4(a[2:]) >> 5) + a2 := 2097151 & (load3(a[5:]) >> 2) + a3 := 2097151 & (load4(a[7:]) >> 7) + a4 := 2097151 & (load4(a[10:]) >> 4) + a5 := 2097151 & (load3(a[13:]) >> 1) + a6 := 2097151 & (load4(a[15:]) >> 6) + a7 := 2097151 & (load3(a[18:]) >> 3) + a8 := 2097151 & load3(a[21:]) + a9 := 2097151 & (load4(a[23:]) >> 5) + a10 := 2097151 & (load3(a[26:]) >> 2) + a11 := (load4(a[28:]) >> 7) + b0 := 2097151 & load3(b[:]) + b1 := 2097151 & (load4(b[2:]) >> 5) + b2 := 2097151 & (load3(b[5:]) >> 2) + b3 := 2097151 & (load4(b[7:]) >> 7) + b4 := 2097151 & (load4(b[10:]) >> 4) + b5 := 2097151 & (load3(b[13:]) >> 1) + b6 := 2097151 & (load4(b[15:]) >> 6) + b7 := 2097151 & (load3(b[18:]) >> 3) + b8 := 2097151 & load3(b[21:]) + b9 := 2097151 & (load4(b[23:]) >> 5) + b10 := 2097151 & (load3(b[26:]) >> 2) + b11 := (load4(b[28:]) >> 7) + c0 := 2097151 & load3(c[:]) + c1 := 2097151 & (load4(c[2:]) >> 5) + c2 := 2097151 & (load3(c[5:]) >> 2) + c3 := 2097151 & (load4(c[7:]) >> 7) + c4 := 2097151 & (load4(c[10:]) >> 4) + c5 := 2097151 & (load3(c[13:]) >> 1) + c6 := 2097151 & (load4(c[15:]) >> 6) + c7 := 2097151 & (load3(c[18:]) >> 3) + c8 := 2097151 & load3(c[21:]) + c9 := 2097151 & (load4(c[23:]) >> 5) + c10 := 2097151 & (load3(c[26:]) >> 2) + c11 := (load4(c[28:]) >> 7) + var carry [23]int64 + + s0 := c0 + a0*b0 + s1 := c1 + a0*b1 + a1*b0 + s2 := c2 + a0*b2 + a1*b1 + a2*b0 + s3 := c3 + a0*b3 + a1*b2 + a2*b1 + a3*b0 + s4 := c4 + a0*b4 + a1*b3 + a2*b2 + a3*b1 + a4*b0 + s5 := c5 + a0*b5 + a1*b4 + a2*b3 + a3*b2 + a4*b1 + a5*b0 + s6 := c6 + a0*b6 + a1*b5 + a2*b4 + a3*b3 + a4*b2 + a5*b1 + a6*b0 + s7 := c7 + a0*b7 + a1*b6 + a2*b5 + a3*b4 + a4*b3 + a5*b2 + a6*b1 + a7*b0 + s8 := c8 + a0*b8 + a1*b7 + a2*b6 + a3*b5 + a4*b4 + a5*b3 + a6*b2 + a7*b1 + a8*b0 + s9 := c9 + a0*b9 + a1*b8 + a2*b7 + a3*b6 + a4*b5 + a5*b4 + a6*b3 + a7*b2 + a8*b1 + a9*b0 + s10 := c10 + a0*b10 + a1*b9 + a2*b8 + a3*b7 + a4*b6 + a5*b5 + a6*b4 + a7*b3 + a8*b2 + a9*b1 + a10*b0 + s11 := c11 + a0*b11 + a1*b10 + a2*b9 + a3*b8 + a4*b7 + a5*b6 + a6*b5 + a7*b4 + a8*b3 + a9*b2 + a10*b1 + a11*b0 + s12 := a1*b11 + a2*b10 + a3*b9 + a4*b8 + a5*b7 + a6*b6 + a7*b5 + a8*b4 + a9*b3 + a10*b2 + a11*b1 + s13 := a2*b11 + a3*b10 + a4*b9 + a5*b8 + a6*b7 + a7*b6 + a8*b5 + a9*b4 + a10*b3 + a11*b2 + s14 := a3*b11 + a4*b10 + a5*b9 + a6*b8 + a7*b7 + a8*b6 + a9*b5 + a10*b4 + a11*b3 + s15 := a4*b11 + a5*b10 + a6*b9 + a7*b8 + a8*b7 + a9*b6 + a10*b5 + a11*b4 + s16 := a5*b11 + a6*b10 + a7*b9 + a8*b8 + a9*b7 + a10*b6 + a11*b5 + s17 := a6*b11 + a7*b10 + a8*b9 + a9*b8 + a10*b7 + a11*b6 + s18 := a7*b11 + a8*b10 + a9*b9 + a10*b8 + a11*b7 + s19 := a8*b11 + a9*b10 + a10*b9 + a11*b8 + s20 := a9*b11 + a10*b10 + a11*b9 + s21 := a10*b11 + a11*b10 + s22 := a11 * b11 + s23 := int64(0) + + carry[0] = (s0 + (1 << 20)) >> 21 + s1 += carry[0] + s0 -= carry[0] << 21 + carry[2] = (s2 + (1 << 20)) >> 21 + s3 += carry[2] + s2 -= carry[2] << 21 + carry[4] = (s4 + (1 << 20)) >> 21 + s5 += carry[4] + s4 -= carry[4] << 21 + carry[6] = (s6 + (1 << 20)) >> 21 + s7 += carry[6] + s6 -= carry[6] << 21 + carry[8] = (s8 + (1 << 20)) >> 21 + s9 += carry[8] + s8 -= carry[8] << 21 + carry[10] = (s10 + (1 << 20)) >> 21 + s11 += carry[10] + s10 -= carry[10] << 21 + carry[12] = (s12 + (1 << 20)) >> 21 + s13 += carry[12] + s12 -= carry[12] << 21 + carry[14] = (s14 + (1 << 20)) >> 21 + s15 += carry[14] + s14 -= carry[14] << 21 + carry[16] = (s16 + (1 << 20)) >> 21 + s17 += carry[16] + s16 -= carry[16] << 21 + carry[18] = (s18 + (1 << 20)) >> 21 + s19 += carry[18] + s18 -= carry[18] << 21 + carry[20] = (s20 + (1 << 20)) >> 21 + s21 += carry[20] + s20 -= carry[20] << 21 + carry[22] = (s22 + (1 << 20)) >> 21 + s23 += carry[22] + s22 -= carry[22] << 21 + + carry[1] = (s1 + (1 << 20)) >> 21 + s2 += carry[1] + s1 -= carry[1] << 21 + carry[3] = (s3 + (1 << 20)) >> 21 + s4 += carry[3] + s3 -= carry[3] << 21 + carry[5] = (s5 + (1 << 20)) >> 21 + s6 += carry[5] + s5 -= carry[5] << 21 + carry[7] = (s7 + (1 << 20)) >> 21 + s8 += carry[7] + s7 -= carry[7] << 21 + carry[9] = (s9 + (1 << 20)) >> 21 + s10 += carry[9] + s9 -= carry[9] << 21 + carry[11] = (s11 + (1 << 20)) >> 21 + s12 += carry[11] + s11 -= carry[11] << 21 + carry[13] = (s13 + (1 << 20)) >> 21 + s14 += carry[13] + s13 -= carry[13] << 21 + carry[15] = (s15 + (1 << 20)) >> 21 + s16 += carry[15] + s15 -= carry[15] << 21 + carry[17] = (s17 + (1 << 20)) >> 21 + s18 += carry[17] + s17 -= carry[17] << 21 + carry[19] = (s19 + (1 << 20)) >> 21 + s20 += carry[19] + s19 -= carry[19] << 21 + carry[21] = (s21 + (1 << 20)) >> 21 + s22 += carry[21] + s21 -= carry[21] << 21 + + s11 += s23 * 666643 + s12 += s23 * 470296 + s13 += s23 * 654183 + s14 -= s23 * 997805 + s15 += s23 * 136657 + s16 -= s23 * 683901 + s23 = 0 + + s10 += s22 * 666643 + s11 += s22 * 470296 + s12 += s22 * 654183 + s13 -= s22 * 997805 + s14 += s22 * 136657 + s15 -= s22 * 683901 + s22 = 0 + + s9 += s21 * 666643 + s10 += s21 * 470296 + s11 += s21 * 654183 + s12 -= s21 * 997805 + s13 += s21 * 136657 + s14 -= s21 * 683901 + s21 = 0 + + s8 += s20 * 666643 + s9 += s20 * 470296 + s10 += s20 * 654183 + s11 -= s20 * 997805 + s12 += s20 * 136657 + s13 -= s20 * 683901 + s20 = 0 + + s7 += s19 * 666643 + s8 += s19 * 470296 + s9 += s19 * 654183 + s10 -= s19 * 997805 + s11 += s19 * 136657 + s12 -= s19 * 683901 + s19 = 0 + + s6 += s18 * 666643 + s7 += s18 * 470296 + s8 += s18 * 654183 + s9 -= s18 * 997805 + s10 += s18 * 136657 + s11 -= s18 * 683901 + s18 = 0 + + carry[6] = (s6 + (1 << 20)) >> 21 + s7 += carry[6] + s6 -= carry[6] << 21 + carry[8] = (s8 + (1 << 20)) >> 21 + s9 += carry[8] + s8 -= carry[8] << 21 + carry[10] = (s10 + (1 << 20)) >> 21 + s11 += carry[10] + s10 -= carry[10] << 21 + carry[12] = (s12 + (1 << 20)) >> 21 + s13 += carry[12] + s12 -= carry[12] << 21 + carry[14] = (s14 + (1 << 20)) >> 21 + s15 += carry[14] + s14 -= carry[14] << 21 + carry[16] = (s16 + (1 << 20)) >> 21 + s17 += carry[16] + s16 -= carry[16] << 21 + + carry[7] = (s7 + (1 << 20)) >> 21 + s8 += carry[7] + s7 -= carry[7] << 21 + carry[9] = (s9 + (1 << 20)) >> 21 + s10 += carry[9] + s9 -= carry[9] << 21 + carry[11] = (s11 + (1 << 20)) >> 21 + s12 += carry[11] + s11 -= carry[11] << 21 + carry[13] = (s13 + (1 << 20)) >> 21 + s14 += carry[13] + s13 -= carry[13] << 21 + carry[15] = (s15 + (1 << 20)) >> 21 + s16 += carry[15] + s15 -= carry[15] << 21 + + s5 += s17 * 666643 + s6 += s17 * 470296 + s7 += s17 * 654183 + s8 -= s17 * 997805 + s9 += s17 * 136657 + s10 -= s17 * 683901 + s17 = 0 + + s4 += s16 * 666643 + s5 += s16 * 470296 + s6 += s16 * 654183 + s7 -= s16 * 997805 + s8 += s16 * 136657 + s9 -= s16 * 683901 + s16 = 0 + + s3 += s15 * 666643 + s4 += s15 * 470296 + s5 += s15 * 654183 + s6 -= s15 * 997805 + s7 += s15 * 136657 + s8 -= s15 * 683901 + s15 = 0 + + s2 += s14 * 666643 + s3 += s14 * 470296 + s4 += s14 * 654183 + s5 -= s14 * 997805 + s6 += s14 * 136657 + s7 -= s14 * 683901 + s14 = 0 + + s1 += s13 * 666643 + s2 += s13 * 470296 + s3 += s13 * 654183 + s4 -= s13 * 997805 + s5 += s13 * 136657 + s6 -= s13 * 683901 + s13 = 0 + + s0 += s12 * 666643 + s1 += s12 * 470296 + s2 += s12 * 654183 + s3 -= s12 * 997805 + s4 += s12 * 136657 + s5 -= s12 * 683901 + s12 = 0 + + carry[0] = (s0 + (1 << 20)) >> 21 + s1 += carry[0] + s0 -= carry[0] << 21 + carry[2] = (s2 + (1 << 20)) >> 21 + s3 += carry[2] + s2 -= carry[2] << 21 + carry[4] = (s4 + (1 << 20)) >> 21 + s5 += carry[4] + s4 -= carry[4] << 21 + carry[6] = (s6 + (1 << 20)) >> 21 + s7 += carry[6] + s6 -= carry[6] << 21 + carry[8] = (s8 + (1 << 20)) >> 21 + s9 += carry[8] + s8 -= carry[8] << 21 + carry[10] = (s10 + (1 << 20)) >> 21 + s11 += carry[10] + s10 -= carry[10] << 21 + + carry[1] = (s1 + (1 << 20)) >> 21 + s2 += carry[1] + s1 -= carry[1] << 21 + carry[3] = (s3 + (1 << 20)) >> 21 + s4 += carry[3] + s3 -= carry[3] << 21 + carry[5] = (s5 + (1 << 20)) >> 21 + s6 += carry[5] + s5 -= carry[5] << 21 + carry[7] = (s7 + (1 << 20)) >> 21 + s8 += carry[7] + s7 -= carry[7] << 21 + carry[9] = (s9 + (1 << 20)) >> 21 + s10 += carry[9] + s9 -= carry[9] << 21 + carry[11] = (s11 + (1 << 20)) >> 21 + s12 += carry[11] + s11 -= carry[11] << 21 + + s0 += s12 * 666643 + s1 += s12 * 470296 + s2 += s12 * 654183 + s3 -= s12 * 997805 + s4 += s12 * 136657 + s5 -= s12 * 683901 + s12 = 0 + + carry[0] = s0 >> 21 + s1 += carry[0] + s0 -= carry[0] << 21 + carry[1] = s1 >> 21 + s2 += carry[1] + s1 -= carry[1] << 21 + carry[2] = s2 >> 21 + s3 += carry[2] + s2 -= carry[2] << 21 + carry[3] = s3 >> 21 + s4 += carry[3] + s3 -= carry[3] << 21 + carry[4] = s4 >> 21 + s5 += carry[4] + s4 -= carry[4] << 21 + carry[5] = s5 >> 21 + s6 += carry[5] + s5 -= carry[5] << 21 + carry[6] = s6 >> 21 + s7 += carry[6] + s6 -= carry[6] << 21 + carry[7] = s7 >> 21 + s8 += carry[7] + s7 -= carry[7] << 21 + carry[8] = s8 >> 21 + s9 += carry[8] + s8 -= carry[8] << 21 + carry[9] = s9 >> 21 + s10 += carry[9] + s9 -= carry[9] << 21 + carry[10] = s10 >> 21 + s11 += carry[10] + s10 -= carry[10] << 21 + carry[11] = s11 >> 21 + s12 += carry[11] + s11 -= carry[11] << 21 + + s0 += s12 * 666643 + s1 += s12 * 470296 + s2 += s12 * 654183 + s3 -= s12 * 997805 + s4 += s12 * 136657 + s5 -= s12 * 683901 + s12 = 0 + + carry[0] = s0 >> 21 + s1 += carry[0] + s0 -= carry[0] << 21 + carry[1] = s1 >> 21 + s2 += carry[1] + s1 -= carry[1] << 21 + carry[2] = s2 >> 21 + s3 += carry[2] + s2 -= carry[2] << 21 + carry[3] = s3 >> 21 + s4 += carry[3] + s3 -= carry[3] << 21 + carry[4] = s4 >> 21 + s5 += carry[4] + s4 -= carry[4] << 21 + carry[5] = s5 >> 21 + s6 += carry[5] + s5 -= carry[5] << 21 + carry[6] = s6 >> 21 + s7 += carry[6] + s6 -= carry[6] << 21 + carry[7] = s7 >> 21 + s8 += carry[7] + s7 -= carry[7] << 21 + carry[8] = s8 >> 21 + s9 += carry[8] + s8 -= carry[8] << 21 + carry[9] = s9 >> 21 + s10 += carry[9] + s9 -= carry[9] << 21 + carry[10] = s10 >> 21 + s11 += carry[10] + s10 -= carry[10] << 21 + + s[0] = byte(s0 >> 0) + s[1] = byte(s0 >> 8) + s[2] = byte((s0 >> 16) | (s1 << 5)) + s[3] = byte(s1 >> 3) + s[4] = byte(s1 >> 11) + s[5] = byte((s1 >> 19) | (s2 << 2)) + s[6] = byte(s2 >> 6) + s[7] = byte((s2 >> 14) | (s3 << 7)) + s[8] = byte(s3 >> 1) + s[9] = byte(s3 >> 9) + s[10] = byte((s3 >> 17) | (s4 << 4)) + s[11] = byte(s4 >> 4) + s[12] = byte(s4 >> 12) + s[13] = byte((s4 >> 20) | (s5 << 1)) + s[14] = byte(s5 >> 7) + s[15] = byte((s5 >> 15) | (s6 << 6)) + s[16] = byte(s6 >> 2) + s[17] = byte(s6 >> 10) + s[18] = byte((s6 >> 18) | (s7 << 3)) + s[19] = byte(s7 >> 5) + s[20] = byte(s7 >> 13) + s[21] = byte(s8 >> 0) + s[22] = byte(s8 >> 8) + s[23] = byte((s8 >> 16) | (s9 << 5)) + s[24] = byte(s9 >> 3) + s[25] = byte(s9 >> 11) + s[26] = byte((s9 >> 19) | (s10 << 2)) + s[27] = byte(s10 >> 6) + s[28] = byte((s10 >> 14) | (s11 << 7)) + s[29] = byte(s11 >> 1) + s[30] = byte(s11 >> 9) + s[31] = byte(s11 >> 17) +} + +// Input: +// s[0]+256*s[1]+...+256^63*s[63] = s +// +// Output: +// s[0]+256*s[1]+...+256^31*s[31] = s mod l +// where l = 2^252 + 27742317777372353535851937790883648493. +func ScReduce(out *[32]byte, s *[64]byte) { + s0 := 2097151 & load3(s[:]) + s1 := 2097151 & (load4(s[2:]) >> 5) + s2 := 2097151 & (load3(s[5:]) >> 2) + s3 := 2097151 & (load4(s[7:]) >> 7) + s4 := 2097151 & (load4(s[10:]) >> 4) + s5 := 2097151 & (load3(s[13:]) >> 1) + s6 := 2097151 & (load4(s[15:]) >> 6) + s7 := 2097151 & (load3(s[18:]) >> 3) + s8 := 2097151 & load3(s[21:]) + s9 := 2097151 & (load4(s[23:]) >> 5) + s10 := 2097151 & (load3(s[26:]) >> 2) + s11 := 2097151 & (load4(s[28:]) >> 7) + s12 := 2097151 & (load4(s[31:]) >> 4) + s13 := 2097151 & (load3(s[34:]) >> 1) + s14 := 2097151 & (load4(s[36:]) >> 6) + s15 := 2097151 & (load3(s[39:]) >> 3) + s16 := 2097151 & load3(s[42:]) + s17 := 2097151 & (load4(s[44:]) >> 5) + s18 := 2097151 & (load3(s[47:]) >> 2) + s19 := 2097151 & (load4(s[49:]) >> 7) + s20 := 2097151 & (load4(s[52:]) >> 4) + s21 := 2097151 & (load3(s[55:]) >> 1) + s22 := 2097151 & (load4(s[57:]) >> 6) + s23 := (load4(s[60:]) >> 3) + + s11 += s23 * 666643 + s12 += s23 * 470296 + s13 += s23 * 654183 + s14 -= s23 * 997805 + s15 += s23 * 136657 + s16 -= s23 * 683901 + s23 = 0 + + s10 += s22 * 666643 + s11 += s22 * 470296 + s12 += s22 * 654183 + s13 -= s22 * 997805 + s14 += s22 * 136657 + s15 -= s22 * 683901 + s22 = 0 + + s9 += s21 * 666643 + s10 += s21 * 470296 + s11 += s21 * 654183 + s12 -= s21 * 997805 + s13 += s21 * 136657 + s14 -= s21 * 683901 + s21 = 0 + + s8 += s20 * 666643 + s9 += s20 * 470296 + s10 += s20 * 654183 + s11 -= s20 * 997805 + s12 += s20 * 136657 + s13 -= s20 * 683901 + s20 = 0 + + s7 += s19 * 666643 + s8 += s19 * 470296 + s9 += s19 * 654183 + s10 -= s19 * 997805 + s11 += s19 * 136657 + s12 -= s19 * 683901 + s19 = 0 + + s6 += s18 * 666643 + s7 += s18 * 470296 + s8 += s18 * 654183 + s9 -= s18 * 997805 + s10 += s18 * 136657 + s11 -= s18 * 683901 + s18 = 0 + + var carry [17]int64 + + carry[6] = (s6 + (1 << 20)) >> 21 + s7 += carry[6] + s6 -= carry[6] << 21 + carry[8] = (s8 + (1 << 20)) >> 21 + s9 += carry[8] + s8 -= carry[8] << 21 + carry[10] = (s10 + (1 << 20)) >> 21 + s11 += carry[10] + s10 -= carry[10] << 21 + carry[12] = (s12 + (1 << 20)) >> 21 + s13 += carry[12] + s12 -= carry[12] << 21 + carry[14] = (s14 + (1 << 20)) >> 21 + s15 += carry[14] + s14 -= carry[14] << 21 + carry[16] = (s16 + (1 << 20)) >> 21 + s17 += carry[16] + s16 -= carry[16] << 21 + + carry[7] = (s7 + (1 << 20)) >> 21 + s8 += carry[7] + s7 -= carry[7] << 21 + carry[9] = (s9 + (1 << 20)) >> 21 + s10 += carry[9] + s9 -= carry[9] << 21 + carry[11] = (s11 + (1 << 20)) >> 21 + s12 += carry[11] + s11 -= carry[11] << 21 + carry[13] = (s13 + (1 << 20)) >> 21 + s14 += carry[13] + s13 -= carry[13] << 21 + carry[15] = (s15 + (1 << 20)) >> 21 + s16 += carry[15] + s15 -= carry[15] << 21 + + s5 += s17 * 666643 + s6 += s17 * 470296 + s7 += s17 * 654183 + s8 -= s17 * 997805 + s9 += s17 * 136657 + s10 -= s17 * 683901 + s17 = 0 + + s4 += s16 * 666643 + s5 += s16 * 470296 + s6 += s16 * 654183 + s7 -= s16 * 997805 + s8 += s16 * 136657 + s9 -= s16 * 683901 + s16 = 0 + + s3 += s15 * 666643 + s4 += s15 * 470296 + s5 += s15 * 654183 + s6 -= s15 * 997805 + s7 += s15 * 136657 + s8 -= s15 * 683901 + s15 = 0 + + s2 += s14 * 666643 + s3 += s14 * 470296 + s4 += s14 * 654183 + s5 -= s14 * 997805 + s6 += s14 * 136657 + s7 -= s14 * 683901 + s14 = 0 + + s1 += s13 * 666643 + s2 += s13 * 470296 + s3 += s13 * 654183 + s4 -= s13 * 997805 + s5 += s13 * 136657 + s6 -= s13 * 683901 + s13 = 0 + + s0 += s12 * 666643 + s1 += s12 * 470296 + s2 += s12 * 654183 + s3 -= s12 * 997805 + s4 += s12 * 136657 + s5 -= s12 * 683901 + s12 = 0 + + carry[0] = (s0 + (1 << 20)) >> 21 + s1 += carry[0] + s0 -= carry[0] << 21 + carry[2] = (s2 + (1 << 20)) >> 21 + s3 += carry[2] + s2 -= carry[2] << 21 + carry[4] = (s4 + (1 << 20)) >> 21 + s5 += carry[4] + s4 -= carry[4] << 21 + carry[6] = (s6 + (1 << 20)) >> 21 + s7 += carry[6] + s6 -= carry[6] << 21 + carry[8] = (s8 + (1 << 20)) >> 21 + s9 += carry[8] + s8 -= carry[8] << 21 + carry[10] = (s10 + (1 << 20)) >> 21 + s11 += carry[10] + s10 -= carry[10] << 21 + + carry[1] = (s1 + (1 << 20)) >> 21 + s2 += carry[1] + s1 -= carry[1] << 21 + carry[3] = (s3 + (1 << 20)) >> 21 + s4 += carry[3] + s3 -= carry[3] << 21 + carry[5] = (s5 + (1 << 20)) >> 21 + s6 += carry[5] + s5 -= carry[5] << 21 + carry[7] = (s7 + (1 << 20)) >> 21 + s8 += carry[7] + s7 -= carry[7] << 21 + carry[9] = (s9 + (1 << 20)) >> 21 + s10 += carry[9] + s9 -= carry[9] << 21 + carry[11] = (s11 + (1 << 20)) >> 21 + s12 += carry[11] + s11 -= carry[11] << 21 + + s0 += s12 * 666643 + s1 += s12 * 470296 + s2 += s12 * 654183 + s3 -= s12 * 997805 + s4 += s12 * 136657 + s5 -= s12 * 683901 + s12 = 0 + + carry[0] = s0 >> 21 + s1 += carry[0] + s0 -= carry[0] << 21 + carry[1] = s1 >> 21 + s2 += carry[1] + s1 -= carry[1] << 21 + carry[2] = s2 >> 21 + s3 += carry[2] + s2 -= carry[2] << 21 + carry[3] = s3 >> 21 + s4 += carry[3] + s3 -= carry[3] << 21 + carry[4] = s4 >> 21 + s5 += carry[4] + s4 -= carry[4] << 21 + carry[5] = s5 >> 21 + s6 += carry[5] + s5 -= carry[5] << 21 + carry[6] = s6 >> 21 + s7 += carry[6] + s6 -= carry[6] << 21 + carry[7] = s7 >> 21 + s8 += carry[7] + s7 -= carry[7] << 21 + carry[8] = s8 >> 21 + s9 += carry[8] + s8 -= carry[8] << 21 + carry[9] = s9 >> 21 + s10 += carry[9] + s9 -= carry[9] << 21 + carry[10] = s10 >> 21 + s11 += carry[10] + s10 -= carry[10] << 21 + carry[11] = s11 >> 21 + s12 += carry[11] + s11 -= carry[11] << 21 + + s0 += s12 * 666643 + s1 += s12 * 470296 + s2 += s12 * 654183 + s3 -= s12 * 997805 + s4 += s12 * 136657 + s5 -= s12 * 683901 + s12 = 0 + + carry[0] = s0 >> 21 + s1 += carry[0] + s0 -= carry[0] << 21 + carry[1] = s1 >> 21 + s2 += carry[1] + s1 -= carry[1] << 21 + carry[2] = s2 >> 21 + s3 += carry[2] + s2 -= carry[2] << 21 + carry[3] = s3 >> 21 + s4 += carry[3] + s3 -= carry[3] << 21 + carry[4] = s4 >> 21 + s5 += carry[4] + s4 -= carry[4] << 21 + carry[5] = s5 >> 21 + s6 += carry[5] + s5 -= carry[5] << 21 + carry[6] = s6 >> 21 + s7 += carry[6] + s6 -= carry[6] << 21 + carry[7] = s7 >> 21 + s8 += carry[7] + s7 -= carry[7] << 21 + carry[8] = s8 >> 21 + s9 += carry[8] + s8 -= carry[8] << 21 + carry[9] = s9 >> 21 + s10 += carry[9] + s9 -= carry[9] << 21 + carry[10] = s10 >> 21 + s11 += carry[10] + s10 -= carry[10] << 21 + + out[0] = byte(s0 >> 0) + out[1] = byte(s0 >> 8) + out[2] = byte((s0 >> 16) | (s1 << 5)) + out[3] = byte(s1 >> 3) + out[4] = byte(s1 >> 11) + out[5] = byte((s1 >> 19) | (s2 << 2)) + out[6] = byte(s2 >> 6) + out[7] = byte((s2 >> 14) | (s3 << 7)) + out[8] = byte(s3 >> 1) + out[9] = byte(s3 >> 9) + out[10] = byte((s3 >> 17) | (s4 << 4)) + out[11] = byte(s4 >> 4) + out[12] = byte(s4 >> 12) + out[13] = byte((s4 >> 20) | (s5 << 1)) + out[14] = byte(s5 >> 7) + out[15] = byte((s5 >> 15) | (s6 << 6)) + out[16] = byte(s6 >> 2) + out[17] = byte(s6 >> 10) + out[18] = byte((s6 >> 18) | (s7 << 3)) + out[19] = byte(s7 >> 5) + out[20] = byte(s7 >> 13) + out[21] = byte(s8 >> 0) + out[22] = byte(s8 >> 8) + out[23] = byte((s8 >> 16) | (s9 << 5)) + out[24] = byte(s9 >> 3) + out[25] = byte(s9 >> 11) + out[26] = byte((s9 >> 19) | (s10 << 2)) + out[27] = byte(s10 >> 6) + out[28] = byte((s10 >> 14) | (s11 << 7)) + out[29] = byte(s11 >> 1) + out[30] = byte(s11 >> 9) + out[31] = byte(s11 >> 17) +} + +// order is the order of Curve25519 in little-endian form. +var order = [4]uint64{0x5812631a5cf5d3ed, 0x14def9dea2f79cd6, 0, 0x1000000000000000} + +// ScMinimal returns true if the given scalar is less than the order of the +// curve. +func ScMinimal(scalar *[32]byte) bool { + for i := 3; ; i-- { + v := binary.LittleEndian.Uint64(scalar[i*8:]) + if v > order[i] { + return false + } else if v < order[i] { + break + } else if i == 0 { + return false + } + } + + return true +} diff --git a/internal/crypto/go.mod b/internal/crypto/go.mod new file mode 100644 index 000000000..4ea849af0 --- /dev/null +++ b/internal/crypto/go.mod @@ -0,0 +1,9 @@ +module github.com/github/go-ghcs-crypto + +go 1.11 + +require ( + golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97 + golang.org/x/net v0.0.0-20210226172049-e18ecbb05110 + golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1 +) diff --git a/internal/crypto/internal/subtle/aliasing.go b/internal/crypto/internal/subtle/aliasing.go new file mode 100644 index 000000000..f38797bfa --- /dev/null +++ b/internal/crypto/internal/subtle/aliasing.go @@ -0,0 +1,32 @@ +// Copyright 2018 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build !appengine + +// Package subtle implements functions that are often useful in cryptographic +// code but require careful thought to use correctly. +package subtle // import "golang.org/x/crypto/internal/subtle" + +import "unsafe" + +// AnyOverlap reports whether x and y share memory at any (not necessarily +// corresponding) index. The memory beyond the slice length is ignored. +func AnyOverlap(x, y []byte) bool { + return len(x) > 0 && len(y) > 0 && + uintptr(unsafe.Pointer(&x[0])) <= uintptr(unsafe.Pointer(&y[len(y)-1])) && + uintptr(unsafe.Pointer(&y[0])) <= uintptr(unsafe.Pointer(&x[len(x)-1])) +} + +// InexactOverlap reports whether x and y share memory at any non-corresponding +// index. The memory beyond the slice length is ignored. Note that x and y can +// have different lengths and still not have any inexact overlap. +// +// InexactOverlap can be used to implement the requirements of the crypto/cipher +// AEAD, Block, BlockMode and Stream interfaces. +func InexactOverlap(x, y []byte) bool { + if len(x) == 0 || len(y) == 0 || &x[0] == &y[0] { + return false + } + return AnyOverlap(x, y) +} diff --git a/internal/crypto/internal/subtle/aliasing_appengine.go b/internal/crypto/internal/subtle/aliasing_appengine.go new file mode 100644 index 000000000..0cc4a8a64 --- /dev/null +++ b/internal/crypto/internal/subtle/aliasing_appengine.go @@ -0,0 +1,35 @@ +// Copyright 2018 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build appengine + +// Package subtle implements functions that are often useful in cryptographic +// code but require careful thought to use correctly. +package subtle // import "golang.org/x/crypto/internal/subtle" + +// This is the Google App Engine standard variant based on reflect +// because the unsafe package and cgo are disallowed. + +import "reflect" + +// AnyOverlap reports whether x and y share memory at any (not necessarily +// corresponding) index. The memory beyond the slice length is ignored. +func AnyOverlap(x, y []byte) bool { + return len(x) > 0 && len(y) > 0 && + reflect.ValueOf(&x[0]).Pointer() <= reflect.ValueOf(&y[len(y)-1]).Pointer() && + reflect.ValueOf(&y[0]).Pointer() <= reflect.ValueOf(&x[len(x)-1]).Pointer() +} + +// InexactOverlap reports whether x and y share memory at any non-corresponding +// index. The memory beyond the slice length is ignored. Note that x and y can +// have different lengths and still not have any inexact overlap. +// +// InexactOverlap can be used to implement the requirements of the crypto/cipher +// AEAD, Block, BlockMode and Stream interfaces. +func InexactOverlap(x, y []byte) bool { + if len(x) == 0 || len(y) == 0 || &x[0] == &y[0] { + return false + } + return AnyOverlap(x, y) +} diff --git a/internal/crypto/poly1305/bits_compat.go b/internal/crypto/poly1305/bits_compat.go new file mode 100644 index 000000000..157a69f61 --- /dev/null +++ b/internal/crypto/poly1305/bits_compat.go @@ -0,0 +1,39 @@ +// Copyright 2019 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build !go1.13 + +package poly1305 + +// Generic fallbacks for the math/bits intrinsics, copied from +// src/math/bits/bits.go. They were added in Go 1.12, but Add64 and Sum64 had +// variable time fallbacks until Go 1.13. + +func bitsAdd64(x, y, carry uint64) (sum, carryOut uint64) { + sum = x + y + carry + carryOut = ((x & y) | ((x | y) &^ sum)) >> 63 + return +} + +func bitsSub64(x, y, borrow uint64) (diff, borrowOut uint64) { + diff = x - y - borrow + borrowOut = ((^x & y) | (^(x ^ y) & diff)) >> 63 + return +} + +func bitsMul64(x, y uint64) (hi, lo uint64) { + const mask32 = 1<<32 - 1 + x0 := x & mask32 + x1 := x >> 32 + y0 := y & mask32 + y1 := y >> 32 + w0 := x0 * y0 + t := x1*y0 + w0>>32 + w1 := t & mask32 + w2 := t >> 32 + w1 += x0 * y1 + hi = x1*y1 + w2 + w1>>32 + lo = x * y + return +} diff --git a/internal/crypto/poly1305/bits_go1.13.go b/internal/crypto/poly1305/bits_go1.13.go new file mode 100644 index 000000000..a0a185f0f --- /dev/null +++ b/internal/crypto/poly1305/bits_go1.13.go @@ -0,0 +1,21 @@ +// Copyright 2019 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build go1.13 + +package poly1305 + +import "math/bits" + +func bitsAdd64(x, y, carry uint64) (sum, carryOut uint64) { + return bits.Add64(x, y, carry) +} + +func bitsSub64(x, y, borrow uint64) (diff, borrowOut uint64) { + return bits.Sub64(x, y, borrow) +} + +func bitsMul64(x, y uint64) (hi, lo uint64) { + return bits.Mul64(x, y) +} diff --git a/internal/crypto/poly1305/mac_noasm.go b/internal/crypto/poly1305/mac_noasm.go new file mode 100644 index 000000000..d118f30ed --- /dev/null +++ b/internal/crypto/poly1305/mac_noasm.go @@ -0,0 +1,9 @@ +// Copyright 2018 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build !amd64,!ppc64le,!s390x gccgo purego + +package poly1305 + +type mac struct{ macGeneric } diff --git a/internal/crypto/poly1305/poly1305.go b/internal/crypto/poly1305/poly1305.go new file mode 100644 index 000000000..9d7a6af09 --- /dev/null +++ b/internal/crypto/poly1305/poly1305.go @@ -0,0 +1,99 @@ +// Copyright 2012 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package poly1305 implements Poly1305 one-time message authentication code as +// specified in https://cr.yp.to/mac/poly1305-20050329.pdf. +// +// Poly1305 is a fast, one-time authentication function. It is infeasible for an +// attacker to generate an authenticator for a message without the key. However, a +// key must only be used for a single message. Authenticating two different +// messages with the same key allows an attacker to forge authenticators for other +// messages with the same key. +// +// Poly1305 was originally coupled with AES in order to make Poly1305-AES. AES was +// used with a fixed key in order to generate one-time keys from an nonce. +// However, in this package AES isn't used and the one-time key is specified +// directly. +package poly1305 // import "golang.org/x/crypto/poly1305" + +import "crypto/subtle" + +// TagSize is the size, in bytes, of a poly1305 authenticator. +const TagSize = 16 + +// Sum generates an authenticator for msg using a one-time key and puts the +// 16-byte result into out. Authenticating two different messages with the same +// key allows an attacker to forge messages at will. +func Sum(out *[16]byte, m []byte, key *[32]byte) { + h := New(key) + h.Write(m) + h.Sum(out[:0]) +} + +// Verify returns true if mac is a valid authenticator for m with the given key. +func Verify(mac *[16]byte, m []byte, key *[32]byte) bool { + var tmp [16]byte + Sum(&tmp, m, key) + return subtle.ConstantTimeCompare(tmp[:], mac[:]) == 1 +} + +// New returns a new MAC computing an authentication +// tag of all data written to it with the given key. +// This allows writing the message progressively instead +// of passing it as a single slice. Common users should use +// the Sum function instead. +// +// The key must be unique for each message, as authenticating +// two different messages with the same key allows an attacker +// to forge messages at will. +func New(key *[32]byte) *MAC { + m := &MAC{} + initialize(key, &m.macState) + return m +} + +// MAC is an io.Writer computing an authentication tag +// of the data written to it. +// +// MAC cannot be used like common hash.Hash implementations, +// because using a poly1305 key twice breaks its security. +// Therefore writing data to a running MAC after calling +// Sum or Verify causes it to panic. +type MAC struct { + mac // platform-dependent implementation + + finalized bool +} + +// Size returns the number of bytes Sum will return. +func (h *MAC) Size() int { return TagSize } + +// Write adds more data to the running message authentication code. +// It never returns an error. +// +// It must not be called after the first call of Sum or Verify. +func (h *MAC) Write(p []byte) (n int, err error) { + if h.finalized { + panic("poly1305: write to MAC after Sum or Verify") + } + return h.mac.Write(p) +} + +// Sum computes the authenticator of all data written to the +// message authentication code. +func (h *MAC) Sum(b []byte) []byte { + var mac [TagSize]byte + h.mac.Sum(&mac) + h.finalized = true + return append(b, mac[:]...) +} + +// Verify returns whether the authenticator of all data written to +// the message authentication code matches the expected value. +func (h *MAC) Verify(expected []byte) bool { + var mac [TagSize]byte + h.mac.Sum(&mac) + h.finalized = true + return subtle.ConstantTimeCompare(expected, mac[:]) == 1 +} diff --git a/internal/crypto/poly1305/sum_amd64.go b/internal/crypto/poly1305/sum_amd64.go new file mode 100644 index 000000000..99e5a1d50 --- /dev/null +++ b/internal/crypto/poly1305/sum_amd64.go @@ -0,0 +1,47 @@ +// Copyright 2012 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build !gccgo,!purego + +package poly1305 + +//go:noescape +func update(state *macState, msg []byte) + +// mac is a wrapper for macGeneric that redirects calls that would have gone to +// updateGeneric to update. +// +// Its Write and Sum methods are otherwise identical to the macGeneric ones, but +// using function pointers would carry a major performance cost. +type mac struct{ macGeneric } + +func (h *mac) Write(p []byte) (int, error) { + nn := len(p) + if h.offset > 0 { + n := copy(h.buffer[h.offset:], p) + if h.offset+n < TagSize { + h.offset += n + return nn, nil + } + p = p[n:] + h.offset = 0 + update(&h.macState, h.buffer[:]) + } + if n := len(p) - (len(p) % TagSize); n > 0 { + update(&h.macState, p[:n]) + p = p[n:] + } + if len(p) > 0 { + h.offset += copy(h.buffer[h.offset:], p) + } + return nn, nil +} + +func (h *mac) Sum(out *[16]byte) { + state := h.macState + if h.offset > 0 { + update(&state, h.buffer[:h.offset]) + } + finalize(out, &state.h, &state.s) +} diff --git a/internal/crypto/poly1305/sum_amd64.s b/internal/crypto/poly1305/sum_amd64.s new file mode 100644 index 000000000..8d394a212 --- /dev/null +++ b/internal/crypto/poly1305/sum_amd64.s @@ -0,0 +1,108 @@ +// Copyright 2012 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build !gccgo,!purego + +#include "textflag.h" + +#define POLY1305_ADD(msg, h0, h1, h2) \ + ADDQ 0(msg), h0; \ + ADCQ 8(msg), h1; \ + ADCQ $1, h2; \ + LEAQ 16(msg), msg + +#define POLY1305_MUL(h0, h1, h2, r0, r1, t0, t1, t2, t3) \ + MOVQ r0, AX; \ + MULQ h0; \ + MOVQ AX, t0; \ + MOVQ DX, t1; \ + MOVQ r0, AX; \ + MULQ h1; \ + ADDQ AX, t1; \ + ADCQ $0, DX; \ + MOVQ r0, t2; \ + IMULQ h2, t2; \ + ADDQ DX, t2; \ + \ + MOVQ r1, AX; \ + MULQ h0; \ + ADDQ AX, t1; \ + ADCQ $0, DX; \ + MOVQ DX, h0; \ + MOVQ r1, t3; \ + IMULQ h2, t3; \ + MOVQ r1, AX; \ + MULQ h1; \ + ADDQ AX, t2; \ + ADCQ DX, t3; \ + ADDQ h0, t2; \ + ADCQ $0, t3; \ + \ + MOVQ t0, h0; \ + MOVQ t1, h1; \ + MOVQ t2, h2; \ + ANDQ $3, h2; \ + MOVQ t2, t0; \ + ANDQ $0xFFFFFFFFFFFFFFFC, t0; \ + ADDQ t0, h0; \ + ADCQ t3, h1; \ + ADCQ $0, h2; \ + SHRQ $2, t3, t2; \ + SHRQ $2, t3; \ + ADDQ t2, h0; \ + ADCQ t3, h1; \ + ADCQ $0, h2 + +// func update(state *[7]uint64, msg []byte) +TEXT ·update(SB), $0-32 + MOVQ state+0(FP), DI + MOVQ msg_base+8(FP), SI + MOVQ msg_len+16(FP), R15 + + MOVQ 0(DI), R8 // h0 + MOVQ 8(DI), R9 // h1 + MOVQ 16(DI), R10 // h2 + MOVQ 24(DI), R11 // r0 + MOVQ 32(DI), R12 // r1 + + CMPQ R15, $16 + JB bytes_between_0_and_15 + +loop: + POLY1305_ADD(SI, R8, R9, R10) + +multiply: + POLY1305_MUL(R8, R9, R10, R11, R12, BX, CX, R13, R14) + SUBQ $16, R15 + CMPQ R15, $16 + JAE loop + +bytes_between_0_and_15: + TESTQ R15, R15 + JZ done + MOVQ $1, BX + XORQ CX, CX + XORQ R13, R13 + ADDQ R15, SI + +flush_buffer: + SHLQ $8, BX, CX + SHLQ $8, BX + MOVB -1(SI), R13 + XORQ R13, BX + DECQ SI + DECQ R15 + JNZ flush_buffer + + ADDQ BX, R8 + ADCQ CX, R9 + ADCQ $0, R10 + MOVQ $16, R15 + JMP multiply + +done: + MOVQ R8, 0(DI) + MOVQ R9, 8(DI) + MOVQ R10, 16(DI) + RET diff --git a/internal/crypto/poly1305/sum_generic.go b/internal/crypto/poly1305/sum_generic.go new file mode 100644 index 000000000..c942a6590 --- /dev/null +++ b/internal/crypto/poly1305/sum_generic.go @@ -0,0 +1,310 @@ +// Copyright 2018 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// This file provides the generic implementation of Sum and MAC. Other files +// might provide optimized assembly implementations of some of this code. + +package poly1305 + +import "encoding/binary" + +// Poly1305 [RFC 7539] is a relatively simple algorithm: the authentication tag +// for a 64 bytes message is approximately +// +// s + m[0:16] * r⁴ + m[16:32] * r³ + m[32:48] * r² + m[48:64] * r mod 2¹³⁰ - 5 +// +// for some secret r and s. It can be computed sequentially like +// +// for len(msg) > 0: +// h += read(msg, 16) +// h *= r +// h %= 2¹³⁰ - 5 +// return h + s +// +// All the complexity is about doing performant constant-time math on numbers +// larger than any available numeric type. + +func sumGeneric(out *[TagSize]byte, msg []byte, key *[32]byte) { + h := newMACGeneric(key) + h.Write(msg) + h.Sum(out) +} + +func newMACGeneric(key *[32]byte) macGeneric { + m := macGeneric{} + initialize(key, &m.macState) + return m +} + +// macState holds numbers in saturated 64-bit little-endian limbs. That is, +// the value of [x0, x1, x2] is x[0] + x[1] * 2⁶⁴ + x[2] * 2¹²⁸. +type macState struct { + // h is the main accumulator. It is to be interpreted modulo 2¹³⁰ - 5, but + // can grow larger during and after rounds. It must, however, remain below + // 2 * (2¹³⁰ - 5). + h [3]uint64 + // r and s are the private key components. + r [2]uint64 + s [2]uint64 +} + +type macGeneric struct { + macState + + buffer [TagSize]byte + offset int +} + +// Write splits the incoming message into TagSize chunks, and passes them to +// update. It buffers incomplete chunks. +func (h *macGeneric) Write(p []byte) (int, error) { + nn := len(p) + if h.offset > 0 { + n := copy(h.buffer[h.offset:], p) + if h.offset+n < TagSize { + h.offset += n + return nn, nil + } + p = p[n:] + h.offset = 0 + updateGeneric(&h.macState, h.buffer[:]) + } + if n := len(p) - (len(p) % TagSize); n > 0 { + updateGeneric(&h.macState, p[:n]) + p = p[n:] + } + if len(p) > 0 { + h.offset += copy(h.buffer[h.offset:], p) + } + return nn, nil +} + +// Sum flushes the last incomplete chunk from the buffer, if any, and generates +// the MAC output. It does not modify its state, in order to allow for multiple +// calls to Sum, even if no Write is allowed after Sum. +func (h *macGeneric) Sum(out *[TagSize]byte) { + state := h.macState + if h.offset > 0 { + updateGeneric(&state, h.buffer[:h.offset]) + } + finalize(out, &state.h, &state.s) +} + +// [rMask0, rMask1] is the specified Poly1305 clamping mask in little-endian. It +// clears some bits of the secret coefficient to make it possible to implement +// multiplication more efficiently. +const ( + rMask0 = 0x0FFFFFFC0FFFFFFF + rMask1 = 0x0FFFFFFC0FFFFFFC +) + +// initialize loads the 256-bit key into the two 128-bit secret values r and s. +func initialize(key *[32]byte, m *macState) { + m.r[0] = binary.LittleEndian.Uint64(key[0:8]) & rMask0 + m.r[1] = binary.LittleEndian.Uint64(key[8:16]) & rMask1 + m.s[0] = binary.LittleEndian.Uint64(key[16:24]) + m.s[1] = binary.LittleEndian.Uint64(key[24:32]) +} + +// uint128 holds a 128-bit number as two 64-bit limbs, for use with the +// bits.Mul64 and bits.Add64 intrinsics. +type uint128 struct { + lo, hi uint64 +} + +func mul64(a, b uint64) uint128 { + hi, lo := bitsMul64(a, b) + return uint128{lo, hi} +} + +func add128(a, b uint128) uint128 { + lo, c := bitsAdd64(a.lo, b.lo, 0) + hi, c := bitsAdd64(a.hi, b.hi, c) + if c != 0 { + panic("poly1305: unexpected overflow") + } + return uint128{lo, hi} +} + +func shiftRightBy2(a uint128) uint128 { + a.lo = a.lo>>2 | (a.hi&3)<<62 + a.hi = a.hi >> 2 + return a +} + +// updateGeneric absorbs msg into the state.h accumulator. For each chunk m of +// 128 bits of message, it computes +// +// h₊ = (h + m) * r mod 2¹³⁰ - 5 +// +// If the msg length is not a multiple of TagSize, it assumes the last +// incomplete chunk is the final one. +func updateGeneric(state *macState, msg []byte) { + h0, h1, h2 := state.h[0], state.h[1], state.h[2] + r0, r1 := state.r[0], state.r[1] + + for len(msg) > 0 { + var c uint64 + + // For the first step, h + m, we use a chain of bits.Add64 intrinsics. + // The resulting value of h might exceed 2¹³⁰ - 5, but will be partially + // reduced at the end of the multiplication below. + // + // The spec requires us to set a bit just above the message size, not to + // hide leading zeroes. For full chunks, that's 1 << 128, so we can just + // add 1 to the most significant (2¹²⁸) limb, h2. + if len(msg) >= TagSize { + h0, c = bitsAdd64(h0, binary.LittleEndian.Uint64(msg[0:8]), 0) + h1, c = bitsAdd64(h1, binary.LittleEndian.Uint64(msg[8:16]), c) + h2 += c + 1 + + msg = msg[TagSize:] + } else { + var buf [TagSize]byte + copy(buf[:], msg) + buf[len(msg)] = 1 + + h0, c = bitsAdd64(h0, binary.LittleEndian.Uint64(buf[0:8]), 0) + h1, c = bitsAdd64(h1, binary.LittleEndian.Uint64(buf[8:16]), c) + h2 += c + + msg = nil + } + + // Multiplication of big number limbs is similar to elementary school + // columnar multiplication. Instead of digits, there are 64-bit limbs. + // + // We are multiplying a 3 limbs number, h, by a 2 limbs number, r. + // + // h2 h1 h0 x + // r1 r0 = + // ---------------- + // h2r0 h1r0 h0r0 <-- individual 128-bit products + // + h2r1 h1r1 h0r1 + // ------------------------ + // m3 m2 m1 m0 <-- result in 128-bit overlapping limbs + // ------------------------ + // m3.hi m2.hi m1.hi m0.hi <-- carry propagation + // + m3.lo m2.lo m1.lo m0.lo + // ------------------------------- + // t4 t3 t2 t1 t0 <-- final result in 64-bit limbs + // + // The main difference from pen-and-paper multiplication is that we do + // carry propagation in a separate step, as if we wrote two digit sums + // at first (the 128-bit limbs), and then carried the tens all at once. + + h0r0 := mul64(h0, r0) + h1r0 := mul64(h1, r0) + h2r0 := mul64(h2, r0) + h0r1 := mul64(h0, r1) + h1r1 := mul64(h1, r1) + h2r1 := mul64(h2, r1) + + // Since h2 is known to be at most 7 (5 + 1 + 1), and r0 and r1 have their + // top 4 bits cleared by rMask{0,1}, we know that their product is not going + // to overflow 64 bits, so we can ignore the high part of the products. + // + // This also means that the product doesn't have a fifth limb (t4). + if h2r0.hi != 0 { + panic("poly1305: unexpected overflow") + } + if h2r1.hi != 0 { + panic("poly1305: unexpected overflow") + } + + m0 := h0r0 + m1 := add128(h1r0, h0r1) // These two additions don't overflow thanks again + m2 := add128(h2r0, h1r1) // to the 4 masked bits at the top of r0 and r1. + m3 := h2r1 + + t0 := m0.lo + t1, c := bitsAdd64(m1.lo, m0.hi, 0) + t2, c := bitsAdd64(m2.lo, m1.hi, c) + t3, _ := bitsAdd64(m3.lo, m2.hi, c) + + // Now we have the result as 4 64-bit limbs, and we need to reduce it + // modulo 2¹³⁰ - 5. The special shape of this Crandall prime lets us do + // a cheap partial reduction according to the reduction identity + // + // c * 2¹³⁰ + n = c * 5 + n mod 2¹³⁰ - 5 + // + // because 2¹³⁰ = 5 mod 2¹³⁰ - 5. Partial reduction since the result is + // likely to be larger than 2¹³⁰ - 5, but still small enough to fit the + // assumptions we make about h in the rest of the code. + // + // See also https://speakerdeck.com/gtank/engineering-prime-numbers?slide=23 + + // We split the final result at the 2¹³⁰ mark into h and cc, the carry. + // Note that the carry bits are effectively shifted left by 2, in other + // words, cc = c * 4 for the c in the reduction identity. + h0, h1, h2 = t0, t1, t2&maskLow2Bits + cc := uint128{t2 & maskNotLow2Bits, t3} + + // To add c * 5 to h, we first add cc = c * 4, and then add (cc >> 2) = c. + + h0, c = bitsAdd64(h0, cc.lo, 0) + h1, c = bitsAdd64(h1, cc.hi, c) + h2 += c + + cc = shiftRightBy2(cc) + + h0, c = bitsAdd64(h0, cc.lo, 0) + h1, c = bitsAdd64(h1, cc.hi, c) + h2 += c + + // h2 is at most 3 + 1 + 1 = 5, making the whole of h at most + // + // 5 * 2¹²⁸ + (2¹²⁸ - 1) = 6 * 2¹²⁸ - 1 + } + + state.h[0], state.h[1], state.h[2] = h0, h1, h2 +} + +const ( + maskLow2Bits uint64 = 0x0000000000000003 + maskNotLow2Bits uint64 = ^maskLow2Bits +) + +// select64 returns x if v == 1 and y if v == 0, in constant time. +func select64(v, x, y uint64) uint64 { return ^(v-1)&x | (v-1)&y } + +// [p0, p1, p2] is 2¹³⁰ - 5 in little endian order. +const ( + p0 = 0xFFFFFFFFFFFFFFFB + p1 = 0xFFFFFFFFFFFFFFFF + p2 = 0x0000000000000003 +) + +// finalize completes the modular reduction of h and computes +// +// out = h + s mod 2¹²⁸ +// +func finalize(out *[TagSize]byte, h *[3]uint64, s *[2]uint64) { + h0, h1, h2 := h[0], h[1], h[2] + + // After the partial reduction in updateGeneric, h might be more than + // 2¹³⁰ - 5, but will be less than 2 * (2¹³⁰ - 5). To complete the reduction + // in constant time, we compute t = h - (2¹³⁰ - 5), and select h as the + // result if the subtraction underflows, and t otherwise. + + hMinusP0, b := bitsSub64(h0, p0, 0) + hMinusP1, b := bitsSub64(h1, p1, b) + _, b = bitsSub64(h2, p2, b) + + // h = h if h < p else h - p + h0 = select64(b, h0, hMinusP0) + h1 = select64(b, h1, hMinusP1) + + // Finally, we compute the last Poly1305 step + // + // tag = h + s mod 2¹²⁸ + // + // by just doing a wide addition with the 128 low bits of h and discarding + // the overflow. + h0, c := bitsAdd64(h0, s[0], 0) + h1, _ = bitsAdd64(h1, s[1], c) + + binary.LittleEndian.PutUint64(out[0:8], h0) + binary.LittleEndian.PutUint64(out[8:16], h1) +} diff --git a/internal/crypto/poly1305/sum_ppc64le.go b/internal/crypto/poly1305/sum_ppc64le.go new file mode 100644 index 000000000..2e7a120b1 --- /dev/null +++ b/internal/crypto/poly1305/sum_ppc64le.go @@ -0,0 +1,47 @@ +// Copyright 2019 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build !gccgo,!purego + +package poly1305 + +//go:noescape +func update(state *macState, msg []byte) + +// mac is a wrapper for macGeneric that redirects calls that would have gone to +// updateGeneric to update. +// +// Its Write and Sum methods are otherwise identical to the macGeneric ones, but +// using function pointers would carry a major performance cost. +type mac struct{ macGeneric } + +func (h *mac) Write(p []byte) (int, error) { + nn := len(p) + if h.offset > 0 { + n := copy(h.buffer[h.offset:], p) + if h.offset+n < TagSize { + h.offset += n + return nn, nil + } + p = p[n:] + h.offset = 0 + update(&h.macState, h.buffer[:]) + } + if n := len(p) - (len(p) % TagSize); n > 0 { + update(&h.macState, p[:n]) + p = p[n:] + } + if len(p) > 0 { + h.offset += copy(h.buffer[h.offset:], p) + } + return nn, nil +} + +func (h *mac) Sum(out *[16]byte) { + state := h.macState + if h.offset > 0 { + update(&state, h.buffer[:h.offset]) + } + finalize(out, &state.h, &state.s) +} diff --git a/internal/crypto/poly1305/sum_ppc64le.s b/internal/crypto/poly1305/sum_ppc64le.s new file mode 100644 index 000000000..4e0281387 --- /dev/null +++ b/internal/crypto/poly1305/sum_ppc64le.s @@ -0,0 +1,181 @@ +// Copyright 2019 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build !gccgo,!purego + +#include "textflag.h" + +// This was ported from the amd64 implementation. + +#define POLY1305_ADD(msg, h0, h1, h2, t0, t1, t2) \ + MOVD (msg), t0; \ + MOVD 8(msg), t1; \ + MOVD $1, t2; \ + ADDC t0, h0, h0; \ + ADDE t1, h1, h1; \ + ADDE t2, h2; \ + ADD $16, msg + +#define POLY1305_MUL(h0, h1, h2, r0, r1, t0, t1, t2, t3, t4, t5) \ + MULLD r0, h0, t0; \ + MULLD r0, h1, t4; \ + MULHDU r0, h0, t1; \ + MULHDU r0, h1, t5; \ + ADDC t4, t1, t1; \ + MULLD r0, h2, t2; \ + ADDZE t5; \ + MULHDU r1, h0, t4; \ + MULLD r1, h0, h0; \ + ADD t5, t2, t2; \ + ADDC h0, t1, t1; \ + MULLD h2, r1, t3; \ + ADDZE t4, h0; \ + MULHDU r1, h1, t5; \ + MULLD r1, h1, t4; \ + ADDC t4, t2, t2; \ + ADDE t5, t3, t3; \ + ADDC h0, t2, t2; \ + MOVD $-4, t4; \ + MOVD t0, h0; \ + MOVD t1, h1; \ + ADDZE t3; \ + ANDCC $3, t2, h2; \ + AND t2, t4, t0; \ + ADDC t0, h0, h0; \ + ADDE t3, h1, h1; \ + SLD $62, t3, t4; \ + SRD $2, t2; \ + ADDZE h2; \ + OR t4, t2, t2; \ + SRD $2, t3; \ + ADDC t2, h0, h0; \ + ADDE t3, h1, h1; \ + ADDZE h2 + +DATA ·poly1305Mask<>+0x00(SB)/8, $0x0FFFFFFC0FFFFFFF +DATA ·poly1305Mask<>+0x08(SB)/8, $0x0FFFFFFC0FFFFFFC +GLOBL ·poly1305Mask<>(SB), RODATA, $16 + +// func update(state *[7]uint64, msg []byte) +TEXT ·update(SB), $0-32 + MOVD state+0(FP), R3 + MOVD msg_base+8(FP), R4 + MOVD msg_len+16(FP), R5 + + MOVD 0(R3), R8 // h0 + MOVD 8(R3), R9 // h1 + MOVD 16(R3), R10 // h2 + MOVD 24(R3), R11 // r0 + MOVD 32(R3), R12 // r1 + + CMP R5, $16 + BLT bytes_between_0_and_15 + +loop: + POLY1305_ADD(R4, R8, R9, R10, R20, R21, R22) + +multiply: + POLY1305_MUL(R8, R9, R10, R11, R12, R16, R17, R18, R14, R20, R21) + ADD $-16, R5 + CMP R5, $16 + BGE loop + +bytes_between_0_and_15: + CMP $0, R5 + BEQ done + MOVD $0, R16 // h0 + MOVD $0, R17 // h1 + +flush_buffer: + CMP R5, $8 + BLE just1 + + MOVD $8, R21 + SUB R21, R5, R21 + + // Greater than 8 -- load the rightmost remaining bytes in msg + // and put into R17 (h1) + MOVD (R4)(R21), R17 + MOVD $16, R22 + + // Find the offset to those bytes + SUB R5, R22, R22 + SLD $3, R22 + + // Shift to get only the bytes in msg + SRD R22, R17, R17 + + // Put 1 at high end + MOVD $1, R23 + SLD $3, R21 + SLD R21, R23, R23 + OR R23, R17, R17 + + // Remainder is 8 + MOVD $8, R5 + +just1: + CMP R5, $8 + BLT less8 + + // Exactly 8 + MOVD (R4), R16 + + CMP $0, R17 + + // Check if we've already set R17; if not + // set 1 to indicate end of msg. + BNE carry + MOVD $1, R17 + BR carry + +less8: + MOVD $0, R16 // h0 + MOVD $0, R22 // shift count + CMP R5, $4 + BLT less4 + MOVWZ (R4), R16 + ADD $4, R4 + ADD $-4, R5 + MOVD $32, R22 + +less4: + CMP R5, $2 + BLT less2 + MOVHZ (R4), R21 + SLD R22, R21, R21 + OR R16, R21, R16 + ADD $16, R22 + ADD $-2, R5 + ADD $2, R4 + +less2: + CMP $0, R5 + BEQ insert1 + MOVBZ (R4), R21 + SLD R22, R21, R21 + OR R16, R21, R16 + ADD $8, R22 + +insert1: + // Insert 1 at end of msg + MOVD $1, R21 + SLD R22, R21, R21 + OR R16, R21, R16 + +carry: + // Add new values to h0, h1, h2 + ADDC R16, R8 + ADDE R17, R9 + ADDE $0, R10 + MOVD $16, R5 + ADD R5, R4 + BR multiply + +done: + // Save h0, h1, h2 in state + MOVD R8, 0(R3) + MOVD R9, 8(R3) + MOVD R10, 16(R3) + RET diff --git a/internal/crypto/poly1305/sum_s390x.go b/internal/crypto/poly1305/sum_s390x.go new file mode 100644 index 000000000..958fedc07 --- /dev/null +++ b/internal/crypto/poly1305/sum_s390x.go @@ -0,0 +1,75 @@ +// Copyright 2018 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build !gccgo,!purego + +package poly1305 + +import ( + "golang.org/x/sys/cpu" +) + +// updateVX is an assembly implementation of Poly1305 that uses vector +// instructions. It must only be called if the vector facility (vx) is +// available. +//go:noescape +func updateVX(state *macState, msg []byte) + +// mac is a replacement for macGeneric that uses a larger buffer and redirects +// calls that would have gone to updateGeneric to updateVX if the vector +// facility is installed. +// +// A larger buffer is required for good performance because the vector +// implementation has a higher fixed cost per call than the generic +// implementation. +type mac struct { + macState + + buffer [16 * TagSize]byte // size must be a multiple of block size (16) + offset int +} + +func (h *mac) Write(p []byte) (int, error) { + nn := len(p) + if h.offset > 0 { + n := copy(h.buffer[h.offset:], p) + if h.offset+n < len(h.buffer) { + h.offset += n + return nn, nil + } + p = p[n:] + h.offset = 0 + if cpu.S390X.HasVX { + updateVX(&h.macState, h.buffer[:]) + } else { + updateGeneric(&h.macState, h.buffer[:]) + } + } + + tail := len(p) % len(h.buffer) // number of bytes to copy into buffer + body := len(p) - tail // number of bytes to process now + if body > 0 { + if cpu.S390X.HasVX { + updateVX(&h.macState, p[:body]) + } else { + updateGeneric(&h.macState, p[:body]) + } + } + h.offset = copy(h.buffer[:], p[body:]) // copy tail bytes - can be 0 + return nn, nil +} + +func (h *mac) Sum(out *[TagSize]byte) { + state := h.macState + remainder := h.buffer[:h.offset] + + // Use the generic implementation if we have 2 or fewer blocks left + // to sum. The vector implementation has a higher startup time. + if cpu.S390X.HasVX && len(remainder) > 2*TagSize { + updateVX(&state, remainder) + } else if len(remainder) > 0 { + updateGeneric(&state, remainder) + } + finalize(out, &state.h, &state.s) +} diff --git a/internal/crypto/poly1305/sum_s390x.s b/internal/crypto/poly1305/sum_s390x.s new file mode 100644 index 000000000..0fa9ee6e0 --- /dev/null +++ b/internal/crypto/poly1305/sum_s390x.s @@ -0,0 +1,503 @@ +// Copyright 2018 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build !gccgo,!purego + +#include "textflag.h" + +// This implementation of Poly1305 uses the vector facility (vx) +// to process up to 2 blocks (32 bytes) per iteration using an +// algorithm based on the one described in: +// +// NEON crypto, Daniel J. Bernstein & Peter Schwabe +// https://cryptojedi.org/papers/neoncrypto-20120320.pdf +// +// This algorithm uses 5 26-bit limbs to represent a 130-bit +// value. These limbs are, for the most part, zero extended and +// placed into 64-bit vector register elements. Each vector +// register is 128-bits wide and so holds 2 of these elements. +// Using 26-bit limbs allows us plenty of headroom to accomodate +// accumulations before and after multiplication without +// overflowing either 32-bits (before multiplication) or 64-bits +// (after multiplication). +// +// In order to parallelise the operations required to calculate +// the sum we use two separate accumulators and then sum those +// in an extra final step. For compatibility with the generic +// implementation we perform this summation at the end of every +// updateVX call. +// +// To use two accumulators we must multiply the message blocks +// by r² rather than r. Only the final message block should be +// multiplied by r. +// +// Example: +// +// We want to calculate the sum (h) for a 64 byte message (m): +// +// h = m[0:16]r⁴ + m[16:32]r³ + m[32:48]r² + m[48:64]r +// +// To do this we split the calculation into the even indices +// and odd indices of the message. These form our SIMD 'lanes': +// +// h = m[ 0:16]r⁴ + m[32:48]r² + <- lane 0 +// m[16:32]r³ + m[48:64]r <- lane 1 +// +// To calculate this iteratively we refactor so that both lanes +// are written in terms of r² and r: +// +// h = (m[ 0:16]r² + m[32:48])r² + <- lane 0 +// (m[16:32]r² + m[48:64])r <- lane 1 +// ^ ^ +// | coefficients for second iteration +// coefficients for first iteration +// +// So in this case we would have two iterations. In the first +// both lanes are multiplied by r². In the second only the +// first lane is multiplied by r² and the second lane is +// instead multiplied by r. This gives use the odd and even +// powers of r that we need from the original equation. +// +// Notation: +// +// h - accumulator +// r - key +// m - message +// +// [a, b] - SIMD register holding two 64-bit values +// [a, b, c, d] - SIMD register holding four 32-bit values +// xᵢ[n] - limb n of variable x with bit width i +// +// Limbs are expressed in little endian order, so for 26-bit +// limbs x₂₆[4] will be the most significant limb and x₂₆[0] +// will be the least significant limb. + +// masking constants +#define MOD24 V0 // [0x0000000000ffffff, 0x0000000000ffffff] - mask low 24-bits +#define MOD26 V1 // [0x0000000003ffffff, 0x0000000003ffffff] - mask low 26-bits + +// expansion constants (see EXPAND macro) +#define EX0 V2 +#define EX1 V3 +#define EX2 V4 + +// key (r², r or 1 depending on context) +#define R_0 V5 +#define R_1 V6 +#define R_2 V7 +#define R_3 V8 +#define R_4 V9 + +// precalculated coefficients (5r², 5r or 0 depending on context) +#define R5_1 V10 +#define R5_2 V11 +#define R5_3 V12 +#define R5_4 V13 + +// message block (m) +#define M_0 V14 +#define M_1 V15 +#define M_2 V16 +#define M_3 V17 +#define M_4 V18 + +// accumulator (h) +#define H_0 V19 +#define H_1 V20 +#define H_2 V21 +#define H_3 V22 +#define H_4 V23 + +// temporary registers (for short-lived values) +#define T_0 V24 +#define T_1 V25 +#define T_2 V26 +#define T_3 V27 +#define T_4 V28 + +GLOBL ·constants<>(SB), RODATA, $0x30 +// EX0 +DATA ·constants<>+0x00(SB)/8, $0x0006050403020100 +DATA ·constants<>+0x08(SB)/8, $0x1016151413121110 +// EX1 +DATA ·constants<>+0x10(SB)/8, $0x060c0b0a09080706 +DATA ·constants<>+0x18(SB)/8, $0x161c1b1a19181716 +// EX2 +DATA ·constants<>+0x20(SB)/8, $0x0d0d0d0d0d0f0e0d +DATA ·constants<>+0x28(SB)/8, $0x1d1d1d1d1d1f1e1d + +// MULTIPLY multiplies each lane of f and g, partially reduced +// modulo 2¹³⁰ - 5. The result, h, consists of partial products +// in each lane that need to be reduced further to produce the +// final result. +// +// h₁₃₀ = (f₁₃₀g₁₃₀) % 2¹³⁰ + (5f₁₃₀g₁₃₀) / 2¹³⁰ +// +// Note that the multiplication by 5 of the high bits is +// achieved by precalculating the multiplication of four of the +// g coefficients by 5. These are g51-g54. +#define MULTIPLY(f0, f1, f2, f3, f4, g0, g1, g2, g3, g4, g51, g52, g53, g54, h0, h1, h2, h3, h4) \ + VMLOF f0, g0, h0 \ + VMLOF f0, g3, h3 \ + VMLOF f0, g1, h1 \ + VMLOF f0, g4, h4 \ + VMLOF f0, g2, h2 \ + VMLOF f1, g54, T_0 \ + VMLOF f1, g2, T_3 \ + VMLOF f1, g0, T_1 \ + VMLOF f1, g3, T_4 \ + VMLOF f1, g1, T_2 \ + VMALOF f2, g53, h0, h0 \ + VMALOF f2, g1, h3, h3 \ + VMALOF f2, g54, h1, h1 \ + VMALOF f2, g2, h4, h4 \ + VMALOF f2, g0, h2, h2 \ + VMALOF f3, g52, T_0, T_0 \ + VMALOF f3, g0, T_3, T_3 \ + VMALOF f3, g53, T_1, T_1 \ + VMALOF f3, g1, T_4, T_4 \ + VMALOF f3, g54, T_2, T_2 \ + VMALOF f4, g51, h0, h0 \ + VMALOF f4, g54, h3, h3 \ + VMALOF f4, g52, h1, h1 \ + VMALOF f4, g0, h4, h4 \ + VMALOF f4, g53, h2, h2 \ + VAG T_0, h0, h0 \ + VAG T_3, h3, h3 \ + VAG T_1, h1, h1 \ + VAG T_4, h4, h4 \ + VAG T_2, h2, h2 + +// REDUCE performs the following carry operations in four +// stages, as specified in Bernstein & Schwabe: +// +// 1: h₂₆[0]->h₂₆[1] h₂₆[3]->h₂₆[4] +// 2: h₂₆[1]->h₂₆[2] h₂₆[4]->h₂₆[0] +// 3: h₂₆[0]->h₂₆[1] h₂₆[2]->h₂₆[3] +// 4: h₂₆[3]->h₂₆[4] +// +// The result is that all of the limbs are limited to 26-bits +// except for h₂₆[1] and h₂₆[4] which are limited to 27-bits. +// +// Note that although each limb is aligned at 26-bit intervals +// they may contain values that exceed 2²⁶ - 1, hence the need +// to carry the excess bits in each limb. +#define REDUCE(h0, h1, h2, h3, h4) \ + VESRLG $26, h0, T_0 \ + VESRLG $26, h3, T_1 \ + VN MOD26, h0, h0 \ + VN MOD26, h3, h3 \ + VAG T_0, h1, h1 \ + VAG T_1, h4, h4 \ + VESRLG $26, h1, T_2 \ + VESRLG $26, h4, T_3 \ + VN MOD26, h1, h1 \ + VN MOD26, h4, h4 \ + VESLG $2, T_3, T_4 \ + VAG T_3, T_4, T_4 \ + VAG T_2, h2, h2 \ + VAG T_4, h0, h0 \ + VESRLG $26, h2, T_0 \ + VESRLG $26, h0, T_1 \ + VN MOD26, h2, h2 \ + VN MOD26, h0, h0 \ + VAG T_0, h3, h3 \ + VAG T_1, h1, h1 \ + VESRLG $26, h3, T_2 \ + VN MOD26, h3, h3 \ + VAG T_2, h4, h4 + +// EXPAND splits the 128-bit little-endian values in0 and in1 +// into 26-bit big-endian limbs and places the results into +// the first and second lane of d₂₆[0:4] respectively. +// +// The EX0, EX1 and EX2 constants are arrays of byte indices +// for permutation. The permutation both reverses the bytes +// in the input and ensures the bytes are copied into the +// destination limb ready to be shifted into their final +// position. +#define EXPAND(in0, in1, d0, d1, d2, d3, d4) \ + VPERM in0, in1, EX0, d0 \ + VPERM in0, in1, EX1, d2 \ + VPERM in0, in1, EX2, d4 \ + VESRLG $26, d0, d1 \ + VESRLG $30, d2, d3 \ + VESRLG $4, d2, d2 \ + VN MOD26, d0, d0 \ // [in0₂₆[0], in1₂₆[0]] + VN MOD26, d3, d3 \ // [in0₂₆[3], in1₂₆[3]] + VN MOD26, d1, d1 \ // [in0₂₆[1], in1₂₆[1]] + VN MOD24, d4, d4 \ // [in0₂₆[4], in1₂₆[4]] + VN MOD26, d2, d2 // [in0₂₆[2], in1₂₆[2]] + +// func updateVX(state *macState, msg []byte) +TEXT ·updateVX(SB), NOSPLIT, $0 + MOVD state+0(FP), R1 + LMG msg+8(FP), R2, R3 // R2=msg_base, R3=msg_len + + // load EX0, EX1 and EX2 + MOVD $·constants<>(SB), R5 + VLM (R5), EX0, EX2 + + // generate masks + VGMG $(64-24), $63, MOD24 // [0x00ffffff, 0x00ffffff] + VGMG $(64-26), $63, MOD26 // [0x03ffffff, 0x03ffffff] + + // load h (accumulator) and r (key) from state + VZERO T_1 // [0, 0] + VL 0(R1), T_0 // [h₆₄[0], h₆₄[1]] + VLEG $0, 16(R1), T_1 // [h₆₄[2], 0] + VL 24(R1), T_2 // [r₆₄[0], r₆₄[1]] + VPDI $0, T_0, T_2, T_3 // [h₆₄[0], r₆₄[0]] + VPDI $5, T_0, T_2, T_4 // [h₆₄[1], r₆₄[1]] + + // unpack h and r into 26-bit limbs + // note: h₆₄[2] may have the low 3 bits set, so h₂₆[4] is a 27-bit value + VN MOD26, T_3, H_0 // [h₂₆[0], r₂₆[0]] + VZERO H_1 // [0, 0] + VZERO H_3 // [0, 0] + VGMG $(64-12-14), $(63-12), T_0 // [0x03fff000, 0x03fff000] - 26-bit mask with low 12 bits masked out + VESLG $24, T_1, T_1 // [h₆₄[2]<<24, 0] + VERIMG $-26&63, T_3, MOD26, H_1 // [h₂₆[1], r₂₆[1]] + VESRLG $+52&63, T_3, H_2 // [h₂₆[2], r₂₆[2]] - low 12 bits only + VERIMG $-14&63, T_4, MOD26, H_3 // [h₂₆[1], r₂₆[1]] + VESRLG $40, T_4, H_4 // [h₂₆[4], r₂₆[4]] - low 24 bits only + VERIMG $+12&63, T_4, T_0, H_2 // [h₂₆[2], r₂₆[2]] - complete + VO T_1, H_4, H_4 // [h₂₆[4], r₂₆[4]] - complete + + // replicate r across all 4 vector elements + VREPF $3, H_0, R_0 // [r₂₆[0], r₂₆[0], r₂₆[0], r₂₆[0]] + VREPF $3, H_1, R_1 // [r₂₆[1], r₂₆[1], r₂₆[1], r₂₆[1]] + VREPF $3, H_2, R_2 // [r₂₆[2], r₂₆[2], r₂₆[2], r₂₆[2]] + VREPF $3, H_3, R_3 // [r₂₆[3], r₂₆[3], r₂₆[3], r₂₆[3]] + VREPF $3, H_4, R_4 // [r₂₆[4], r₂₆[4], r₂₆[4], r₂₆[4]] + + // zero out lane 1 of h + VLEIG $1, $0, H_0 // [h₂₆[0], 0] + VLEIG $1, $0, H_1 // [h₂₆[1], 0] + VLEIG $1, $0, H_2 // [h₂₆[2], 0] + VLEIG $1, $0, H_3 // [h₂₆[3], 0] + VLEIG $1, $0, H_4 // [h₂₆[4], 0] + + // calculate 5r (ignore least significant limb) + VREPIF $5, T_0 + VMLF T_0, R_1, R5_1 // [5r₂₆[1], 5r₂₆[1], 5r₂₆[1], 5r₂₆[1]] + VMLF T_0, R_2, R5_2 // [5r₂₆[2], 5r₂₆[2], 5r₂₆[2], 5r₂₆[2]] + VMLF T_0, R_3, R5_3 // [5r₂₆[3], 5r₂₆[3], 5r₂₆[3], 5r₂₆[3]] + VMLF T_0, R_4, R5_4 // [5r₂₆[4], 5r₂₆[4], 5r₂₆[4], 5r₂₆[4]] + + // skip r² calculation if we are only calculating one block + CMPBLE R3, $16, skip + + // calculate r² + MULTIPLY(R_0, R_1, R_2, R_3, R_4, R_0, R_1, R_2, R_3, R_4, R5_1, R5_2, R5_3, R5_4, M_0, M_1, M_2, M_3, M_4) + REDUCE(M_0, M_1, M_2, M_3, M_4) + VGBM $0x0f0f, T_0 + VERIMG $0, M_0, T_0, R_0 // [r₂₆[0], r²₂₆[0], r₂₆[0], r²₂₆[0]] + VERIMG $0, M_1, T_0, R_1 // [r₂₆[1], r²₂₆[1], r₂₆[1], r²₂₆[1]] + VERIMG $0, M_2, T_0, R_2 // [r₂₆[2], r²₂₆[2], r₂₆[2], r²₂₆[2]] + VERIMG $0, M_3, T_0, R_3 // [r₂₆[3], r²₂₆[3], r₂₆[3], r²₂₆[3]] + VERIMG $0, M_4, T_0, R_4 // [r₂₆[4], r²₂₆[4], r₂₆[4], r²₂₆[4]] + + // calculate 5r² (ignore least significant limb) + VREPIF $5, T_0 + VMLF T_0, R_1, R5_1 // [5r₂₆[1], 5r²₂₆[1], 5r₂₆[1], 5r²₂₆[1]] + VMLF T_0, R_2, R5_2 // [5r₂₆[2], 5r²₂₆[2], 5r₂₆[2], 5r²₂₆[2]] + VMLF T_0, R_3, R5_3 // [5r₂₆[3], 5r²₂₆[3], 5r₂₆[3], 5r²₂₆[3]] + VMLF T_0, R_4, R5_4 // [5r₂₆[4], 5r²₂₆[4], 5r₂₆[4], 5r²₂₆[4]] + +loop: + CMPBLE R3, $32, b2 // 2 or fewer blocks remaining, need to change key coefficients + + // load next 2 blocks from message + VLM (R2), T_0, T_1 + + // update message slice + SUB $32, R3 + MOVD $32(R2), R2 + + // unpack message blocks into 26-bit big-endian limbs + EXPAND(T_0, T_1, M_0, M_1, M_2, M_3, M_4) + + // add 2¹²⁸ to each message block value + VLEIB $4, $1, M_4 + VLEIB $12, $1, M_4 + +multiply: + // accumulate the incoming message + VAG H_0, M_0, M_0 + VAG H_3, M_3, M_3 + VAG H_1, M_1, M_1 + VAG H_4, M_4, M_4 + VAG H_2, M_2, M_2 + + // multiply the accumulator by the key coefficient + MULTIPLY(M_0, M_1, M_2, M_3, M_4, R_0, R_1, R_2, R_3, R_4, R5_1, R5_2, R5_3, R5_4, H_0, H_1, H_2, H_3, H_4) + + // carry and partially reduce the partial products + REDUCE(H_0, H_1, H_2, H_3, H_4) + + CMPBNE R3, $0, loop + +finish: + // sum lane 0 and lane 1 and put the result in lane 1 + VZERO T_0 + VSUMQG H_0, T_0, H_0 + VSUMQG H_3, T_0, H_3 + VSUMQG H_1, T_0, H_1 + VSUMQG H_4, T_0, H_4 + VSUMQG H_2, T_0, H_2 + + // reduce again after summation + // TODO(mundaym): there might be a more efficient way to do this + // now that we only have 1 active lane. For example, we could + // simultaneously pack the values as we reduce them. + REDUCE(H_0, H_1, H_2, H_3, H_4) + + // carry h[1] through to h[4] so that only h[4] can exceed 2²⁶ - 1 + // TODO(mundaym): in testing this final carry was unnecessary. + // Needs a proof before it can be removed though. + VESRLG $26, H_1, T_1 + VN MOD26, H_1, H_1 + VAQ T_1, H_2, H_2 + VESRLG $26, H_2, T_2 + VN MOD26, H_2, H_2 + VAQ T_2, H_3, H_3 + VESRLG $26, H_3, T_3 + VN MOD26, H_3, H_3 + VAQ T_3, H_4, H_4 + + // h is now < 2(2¹³⁰ - 5) + // Pack each lane in h₂₆[0:4] into h₁₂₈[0:1]. + VESLG $26, H_1, H_1 + VESLG $26, H_3, H_3 + VO H_0, H_1, H_0 + VO H_2, H_3, H_2 + VESLG $4, H_2, H_2 + VLEIB $7, $48, H_1 + VSLB H_1, H_2, H_2 + VO H_0, H_2, H_0 + VLEIB $7, $104, H_1 + VSLB H_1, H_4, H_3 + VO H_3, H_0, H_0 + VLEIB $7, $24, H_1 + VSRLB H_1, H_4, H_1 + + // update state + VSTEG $1, H_0, 0(R1) + VSTEG $0, H_0, 8(R1) + VSTEG $1, H_1, 16(R1) + RET + +b2: // 2 or fewer blocks remaining + CMPBLE R3, $16, b1 + + // Load the 2 remaining blocks (17-32 bytes remaining). + MOVD $-17(R3), R0 // index of final byte to load modulo 16 + VL (R2), T_0 // load full 16 byte block + VLL R0, 16(R2), T_1 // load final (possibly partial) block and pad with zeros to 16 bytes + + // The Poly1305 algorithm requires that a 1 bit be appended to + // each message block. If the final block is less than 16 bytes + // long then it is easiest to insert the 1 before the message + // block is split into 26-bit limbs. If, on the other hand, the + // final message block is 16 bytes long then we append the 1 bit + // after expansion as normal. + MOVBZ $1, R0 + MOVD $-16(R3), R3 // index of byte in last block to insert 1 at (could be 16) + CMPBEQ R3, $16, 2(PC) // skip the insertion if the final block is 16 bytes long + VLVGB R3, R0, T_1 // insert 1 into the byte at index R3 + + // Split both blocks into 26-bit limbs in the appropriate lanes. + EXPAND(T_0, T_1, M_0, M_1, M_2, M_3, M_4) + + // Append a 1 byte to the end of the second to last block. + VLEIB $4, $1, M_4 + + // Append a 1 byte to the end of the last block only if it is a + // full 16 byte block. + CMPBNE R3, $16, 2(PC) + VLEIB $12, $1, M_4 + + // Finally, set up the coefficients for the final multiplication. + // We have previously saved r and 5r in the 32-bit even indexes + // of the R_[0-4] and R5_[1-4] coefficient registers. + // + // We want lane 0 to be multiplied by r² so that can be kept the + // same. We want lane 1 to be multiplied by r so we need to move + // the saved r value into the 32-bit odd index in lane 1 by + // rotating the 64-bit lane by 32. + VGBM $0x00ff, T_0 // [0, 0xffffffffffffffff] - mask lane 1 only + VERIMG $32, R_0, T_0, R_0 // [_, r²₂₆[0], _, r₂₆[0]] + VERIMG $32, R_1, T_0, R_1 // [_, r²₂₆[1], _, r₂₆[1]] + VERIMG $32, R_2, T_0, R_2 // [_, r²₂₆[2], _, r₂₆[2]] + VERIMG $32, R_3, T_0, R_3 // [_, r²₂₆[3], _, r₂₆[3]] + VERIMG $32, R_4, T_0, R_4 // [_, r²₂₆[4], _, r₂₆[4]] + VERIMG $32, R5_1, T_0, R5_1 // [_, 5r²₂₆[1], _, 5r₂₆[1]] + VERIMG $32, R5_2, T_0, R5_2 // [_, 5r²₂₆[2], _, 5r₂₆[2]] + VERIMG $32, R5_3, T_0, R5_3 // [_, 5r²₂₆[3], _, 5r₂₆[3]] + VERIMG $32, R5_4, T_0, R5_4 // [_, 5r²₂₆[4], _, 5r₂₆[4]] + + MOVD $0, R3 + BR multiply + +skip: + CMPBEQ R3, $0, finish + +b1: // 1 block remaining + + // Load the final block (1-16 bytes). This will be placed into + // lane 0. + MOVD $-1(R3), R0 + VLL R0, (R2), T_0 // pad to 16 bytes with zeros + + // The Poly1305 algorithm requires that a 1 bit be appended to + // each message block. If the final block is less than 16 bytes + // long then it is easiest to insert the 1 before the message + // block is split into 26-bit limbs. If, on the other hand, the + // final message block is 16 bytes long then we append the 1 bit + // after expansion as normal. + MOVBZ $1, R0 + CMPBEQ R3, $16, 2(PC) + VLVGB R3, R0, T_0 + + // Set the message block in lane 1 to the value 0 so that it + // can be accumulated without affecting the final result. + VZERO T_1 + + // Split the final message block into 26-bit limbs in lane 0. + // Lane 1 will be contain 0. + EXPAND(T_0, T_1, M_0, M_1, M_2, M_3, M_4) + + // Append a 1 byte to the end of the last block only if it is a + // full 16 byte block. + CMPBNE R3, $16, 2(PC) + VLEIB $4, $1, M_4 + + // We have previously saved r and 5r in the 32-bit even indexes + // of the R_[0-4] and R5_[1-4] coefficient registers. + // + // We want lane 0 to be multiplied by r so we need to move the + // saved r value into the 32-bit odd index in lane 0. We want + // lane 1 to be set to the value 1. This makes multiplication + // a no-op. We do this by setting lane 1 in every register to 0 + // and then just setting the 32-bit index 3 in R_0 to 1. + VZERO T_0 + MOVD $0, R0 + MOVD $0x10111213, R12 + VLVGP R12, R0, T_1 // [_, 0x10111213, _, 0x00000000] + VPERM T_0, R_0, T_1, R_0 // [_, r₂₆[0], _, 0] + VPERM T_0, R_1, T_1, R_1 // [_, r₂₆[1], _, 0] + VPERM T_0, R_2, T_1, R_2 // [_, r₂₆[2], _, 0] + VPERM T_0, R_3, T_1, R_3 // [_, r₂₆[3], _, 0] + VPERM T_0, R_4, T_1, R_4 // [_, r₂₆[4], _, 0] + VPERM T_0, R5_1, T_1, R5_1 // [_, 5r₂₆[1], _, 0] + VPERM T_0, R5_2, T_1, R5_2 // [_, 5r₂₆[2], _, 0] + VPERM T_0, R5_3, T_1, R5_3 // [_, 5r₂₆[3], _, 0] + VPERM T_0, R5_4, T_1, R5_4 // [_, 5r₂₆[4], _, 0] + + // Set the value of lane 1 to be 1. + VLEIF $3, $1, R_0 // [_, r₂₆[0], _, 1] + + MOVD $0, R3 + BR multiply diff --git a/internal/crypto/ssh/buffer.go b/internal/crypto/ssh/buffer.go new file mode 100644 index 000000000..1ab07d078 --- /dev/null +++ b/internal/crypto/ssh/buffer.go @@ -0,0 +1,97 @@ +// Copyright 2012 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ssh + +import ( + "io" + "sync" +) + +// buffer provides a linked list buffer for data exchange +// between producer and consumer. Theoretically the buffer is +// of unlimited capacity as it does no allocation of its own. +type buffer struct { + // protects concurrent access to head, tail and closed + *sync.Cond + + head *element // the buffer that will be read first + tail *element // the buffer that will be read last + + closed bool +} + +// An element represents a single link in a linked list. +type element struct { + buf []byte + next *element +} + +// newBuffer returns an empty buffer that is not closed. +func newBuffer() *buffer { + e := new(element) + b := &buffer{ + Cond: newCond(), + head: e, + tail: e, + } + return b +} + +// write makes buf available for Read to receive. +// buf must not be modified after the call to write. +func (b *buffer) write(buf []byte) { + b.Cond.L.Lock() + e := &element{buf: buf} + b.tail.next = e + b.tail = e + b.Cond.Signal() + b.Cond.L.Unlock() +} + +// eof closes the buffer. Reads from the buffer once all +// the data has been consumed will receive io.EOF. +func (b *buffer) eof() { + b.Cond.L.Lock() + b.closed = true + b.Cond.Signal() + b.Cond.L.Unlock() +} + +// Read reads data from the internal buffer in buf. Reads will block +// if no data is available, or until the buffer is closed. +func (b *buffer) Read(buf []byte) (n int, err error) { + b.Cond.L.Lock() + defer b.Cond.L.Unlock() + + for len(buf) > 0 { + // if there is data in b.head, copy it + if len(b.head.buf) > 0 { + r := copy(buf, b.head.buf) + buf, b.head.buf = buf[r:], b.head.buf[r:] + n += r + continue + } + // if there is a next buffer, make it the head + if len(b.head.buf) == 0 && b.head != b.tail { + b.head = b.head.next + continue + } + + // if at least one byte has been copied, return + if n > 0 { + break + } + + // if nothing was read, and there is nothing outstanding + // check to see if the buffer is closed. + if b.closed { + err = io.EOF + break + } + // out of buffers, wait for producer + b.Cond.Wait() + } + return +} diff --git a/internal/crypto/ssh/certs.go b/internal/crypto/ssh/certs.go new file mode 100644 index 000000000..3bf086828 --- /dev/null +++ b/internal/crypto/ssh/certs.go @@ -0,0 +1,556 @@ +// Copyright 2012 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ssh + +import ( + "bytes" + "errors" + "fmt" + "io" + "net" + "sort" + "time" +) + +// These constants from [PROTOCOL.certkeys] represent the algorithm names +// for certificate types supported by this package. +const ( + CertAlgoRSAv01 = "ssh-rsa-cert-v01@openssh.com" + CertAlgoRSASHA2256v01 = "rsa-sha2-256-cert-v01@openssh.com" + CertAlgoRSASHA2512v01 = "rsa-sha2-512-cert-v01@openssh.com" + CertAlgoDSAv01 = "ssh-dss-cert-v01@openssh.com" + CertAlgoECDSA256v01 = "ecdsa-sha2-nistp256-cert-v01@openssh.com" + CertAlgoECDSA384v01 = "ecdsa-sha2-nistp384-cert-v01@openssh.com" + CertAlgoECDSA521v01 = "ecdsa-sha2-nistp521-cert-v01@openssh.com" + CertAlgoSKECDSA256v01 = "sk-ecdsa-sha2-nistp256-cert-v01@openssh.com" + CertAlgoED25519v01 = "ssh-ed25519-cert-v01@openssh.com" + CertAlgoSKED25519v01 = "sk-ssh-ed25519-cert-v01@openssh.com" +) + +// Certificate types distinguish between host and user +// certificates. The values can be set in the CertType field of +// Certificate. +const ( + UserCert = 1 + HostCert = 2 +) + +// Signature represents a cryptographic signature. +type Signature struct { + Format string + Blob []byte + Rest []byte `ssh:"rest"` +} + +// CertTimeInfinity can be used for OpenSSHCertV01.ValidBefore to indicate that +// a certificate does not expire. +const CertTimeInfinity = 1<<64 - 1 + +// An Certificate represents an OpenSSH certificate as defined in +// [PROTOCOL.certkeys]?rev=1.8. The Certificate type implements the +// PublicKey interface, so it can be unmarshaled using +// ParsePublicKey. +type Certificate struct { + Nonce []byte + Key PublicKey + Serial uint64 + CertType uint32 + KeyId string + ValidPrincipals []string + ValidAfter uint64 + ValidBefore uint64 + Permissions + Reserved []byte + SignatureKey PublicKey + Signature *Signature +} + +// genericCertData holds the key-independent part of the certificate data. +// Overall, certificates contain an nonce, public key fields and +// key-independent fields. +type genericCertData struct { + Serial uint64 + CertType uint32 + KeyId string + ValidPrincipals []byte + ValidAfter uint64 + ValidBefore uint64 + CriticalOptions []byte + Extensions []byte + Reserved []byte + SignatureKey []byte + Signature []byte +} + +func marshalStringList(namelist []string) []byte { + var to []byte + for _, name := range namelist { + s := struct{ N string }{name} + to = append(to, Marshal(&s)...) + } + return to +} + +type optionsTuple struct { + Key string + Value []byte +} + +type optionsTupleValue struct { + Value string +} + +// serialize a map of critical options or extensions +// issue #10569 - per [PROTOCOL.certkeys] and SSH implementation, +// we need two length prefixes for a non-empty string value +func marshalTuples(tups map[string]string) []byte { + keys := make([]string, 0, len(tups)) + for key := range tups { + keys = append(keys, key) + } + sort.Strings(keys) + + var ret []byte + for _, key := range keys { + s := optionsTuple{Key: key} + if value := tups[key]; len(value) > 0 { + s.Value = Marshal(&optionsTupleValue{value}) + } + ret = append(ret, Marshal(&s)...) + } + return ret +} + +// issue #10569 - per [PROTOCOL.certkeys] and SSH implementation, +// we need two length prefixes for a non-empty option value +func parseTuples(in []byte) (map[string]string, error) { + tups := map[string]string{} + var lastKey string + var haveLastKey bool + + for len(in) > 0 { + var key, val, extra []byte + var ok bool + + if key, in, ok = parseString(in); !ok { + return nil, errShortRead + } + keyStr := string(key) + // according to [PROTOCOL.certkeys], the names must be in + // lexical order. + if haveLastKey && keyStr <= lastKey { + return nil, fmt.Errorf("ssh: certificate options are not in lexical order") + } + lastKey, haveLastKey = keyStr, true + // the next field is a data field, which if non-empty has a string embedded + if val, in, ok = parseString(in); !ok { + return nil, errShortRead + } + if len(val) > 0 { + val, extra, ok = parseString(val) + if !ok { + return nil, errShortRead + } + if len(extra) > 0 { + return nil, fmt.Errorf("ssh: unexpected trailing data after certificate option value") + } + tups[keyStr] = string(val) + } else { + tups[keyStr] = "" + } + } + return tups, nil +} + +func parseCert(in []byte, privAlgo string) (*Certificate, error) { + nonce, rest, ok := parseString(in) + if !ok { + return nil, errShortRead + } + + key, rest, err := parsePubKey(rest, privAlgo) + if err != nil { + return nil, err + } + + var g genericCertData + if err := Unmarshal(rest, &g); err != nil { + return nil, err + } + + c := &Certificate{ + Nonce: nonce, + Key: key, + Serial: g.Serial, + CertType: g.CertType, + KeyId: g.KeyId, + ValidAfter: g.ValidAfter, + ValidBefore: g.ValidBefore, + } + + for principals := g.ValidPrincipals; len(principals) > 0; { + principal, rest, ok := parseString(principals) + if !ok { + return nil, errShortRead + } + c.ValidPrincipals = append(c.ValidPrincipals, string(principal)) + principals = rest + } + + c.CriticalOptions, err = parseTuples(g.CriticalOptions) + if err != nil { + return nil, err + } + c.Extensions, err = parseTuples(g.Extensions) + if err != nil { + return nil, err + } + c.Reserved = g.Reserved + k, err := ParsePublicKey(g.SignatureKey) + if err != nil { + return nil, err + } + + c.SignatureKey = k + c.Signature, rest, ok = parseSignatureBody(g.Signature) + if !ok || len(rest) > 0 { + return nil, errors.New("ssh: signature parse error") + } + + return c, nil +} + +type openSSHCertSigner struct { + pub *Certificate + signer Signer +} + +type algorithmOpenSSHCertSigner struct { + *openSSHCertSigner + algorithmSigner AlgorithmSigner +} + +// NewCertSigner returns a Signer that signs with the given Certificate, whose +// private key is held by signer. It returns an error if the public key in cert +// doesn't match the key used by signer. +func NewCertSigner(cert *Certificate, signer Signer) (Signer, error) { + if bytes.Compare(cert.Key.Marshal(), signer.PublicKey().Marshal()) != 0 { + return nil, errors.New("ssh: signer and cert have different public key") + } + + if algorithmSigner, ok := signer.(AlgorithmSigner); ok { + return &algorithmOpenSSHCertSigner{ + &openSSHCertSigner{cert, signer}, algorithmSigner}, nil + } else { + return &openSSHCertSigner{cert, signer}, nil + } +} + +func (s *openSSHCertSigner) Sign(rand io.Reader, data []byte) (*Signature, error) { + return s.signer.Sign(rand, data) +} + +func (s *openSSHCertSigner) PublicKey() PublicKey { + return s.pub +} + +func (s *algorithmOpenSSHCertSigner) SignWithAlgorithm(rand io.Reader, data []byte, algorithm string) (*Signature, error) { + return s.algorithmSigner.SignWithAlgorithm(rand, data, algorithm) +} + +const sourceAddressCriticalOption = "source-address" + +// CertChecker does the work of verifying a certificate. Its methods +// can be plugged into ClientConfig.HostKeyCallback and +// ServerConfig.PublicKeyCallback. For the CertChecker to work, +// minimally, the IsAuthority callback should be set. +type CertChecker struct { + // SupportedCriticalOptions lists the CriticalOptions that the + // server application layer understands. These are only used + // for user certificates. + SupportedCriticalOptions []string + + // IsUserAuthority should return true if the key is recognized as an + // authority for the given user certificate. This allows for + // certificates to be signed by other certificates. This must be set + // if this CertChecker will be checking user certificates. + IsUserAuthority func(auth PublicKey) bool + + // IsHostAuthority should report whether the key is recognized as + // an authority for this host. This allows for certificates to be + // signed by other keys, and for those other keys to only be valid + // signers for particular hostnames. This must be set if this + // CertChecker will be checking host certificates. + IsHostAuthority func(auth PublicKey, address string) bool + + // Clock is used for verifying time stamps. If nil, time.Now + // is used. + Clock func() time.Time + + // UserKeyFallback is called when CertChecker.Authenticate encounters a + // public key that is not a certificate. It must implement validation + // of user keys or else, if nil, all such keys are rejected. + UserKeyFallback func(conn ConnMetadata, key PublicKey) (*Permissions, error) + + // HostKeyFallback is called when CertChecker.CheckHostKey encounters a + // public key that is not a certificate. It must implement host key + // validation or else, if nil, all such keys are rejected. + HostKeyFallback HostKeyCallback + + // IsRevoked is called for each certificate so that revocation checking + // can be implemented. It should return true if the given certificate + // is revoked and false otherwise. If nil, no certificates are + // considered to have been revoked. + IsRevoked func(cert *Certificate) bool +} + +// CheckHostKey checks a host key certificate. This method can be +// plugged into ClientConfig.HostKeyCallback. +func (c *CertChecker) CheckHostKey(addr string, remote net.Addr, key PublicKey) error { + cert, ok := key.(*Certificate) + if !ok { + if c.HostKeyFallback != nil { + return c.HostKeyFallback(addr, remote, key) + } + return errors.New("ssh: non-certificate host key") + } + if cert.CertType != HostCert { + return fmt.Errorf("ssh: certificate presented as a host key has type %d", cert.CertType) + } + if !c.IsHostAuthority(cert.SignatureKey, addr) { + return fmt.Errorf("ssh: no authorities for hostname: %v", addr) + } + + hostname, _, err := net.SplitHostPort(addr) + if err != nil { + return err + } + + // Pass hostname only as principal for host certificates (consistent with OpenSSH) + return c.CheckCert(hostname, cert) +} + +// Authenticate checks a user certificate. Authenticate can be used as +// a value for ServerConfig.PublicKeyCallback. +func (c *CertChecker) Authenticate(conn ConnMetadata, pubKey PublicKey) (*Permissions, error) { + cert, ok := pubKey.(*Certificate) + if !ok { + if c.UserKeyFallback != nil { + return c.UserKeyFallback(conn, pubKey) + } + return nil, errors.New("ssh: normal key pairs not accepted") + } + + if cert.CertType != UserCert { + return nil, fmt.Errorf("ssh: cert has type %d", cert.CertType) + } + if !c.IsUserAuthority(cert.SignatureKey) { + return nil, fmt.Errorf("ssh: certificate signed by unrecognized authority") + } + + if err := c.CheckCert(conn.User(), cert); err != nil { + return nil, err + } + + return &cert.Permissions, nil +} + +// CheckCert checks CriticalOptions, ValidPrincipals, revocation, timestamp and +// the signature of the certificate. +func (c *CertChecker) CheckCert(principal string, cert *Certificate) error { + if c.IsRevoked != nil && c.IsRevoked(cert) { + return fmt.Errorf("ssh: certificate serial %d revoked", cert.Serial) + } + + for opt := range cert.CriticalOptions { + // sourceAddressCriticalOption will be enforced by + // serverAuthenticate + if opt == sourceAddressCriticalOption { + continue + } + + found := false + for _, supp := range c.SupportedCriticalOptions { + if supp == opt { + found = true + break + } + } + if !found { + return fmt.Errorf("ssh: unsupported critical option %q in certificate", opt) + } + } + + if len(cert.ValidPrincipals) > 0 { + // By default, certs are valid for all users/hosts. + found := false + for _, p := range cert.ValidPrincipals { + if p == principal { + found = true + break + } + } + if !found { + return fmt.Errorf("ssh: principal %q not in the set of valid principals for given certificate: %q", principal, cert.ValidPrincipals) + } + } + + clock := c.Clock + if clock == nil { + clock = time.Now + } + + unixNow := clock().Unix() + if after := int64(cert.ValidAfter); after < 0 || unixNow < int64(cert.ValidAfter) { + return fmt.Errorf("ssh: cert is not yet valid") + } + if before := int64(cert.ValidBefore); cert.ValidBefore != uint64(CertTimeInfinity) && (unixNow >= before || before < 0) { + return fmt.Errorf("ssh: cert has expired") + } + if err := cert.SignatureKey.Verify(cert.bytesForSigning(), cert.Signature); err != nil { + return fmt.Errorf("ssh: certificate signature does not verify") + } + + return nil +} + +// SignCert signs the certificate with an authority, setting the Nonce, +// SignatureKey, and Signature fields. +func (c *Certificate) SignCert(rand io.Reader, authority Signer) error { + c.Nonce = make([]byte, 32) + if _, err := io.ReadFull(rand, c.Nonce); err != nil { + return err + } + c.SignatureKey = authority.PublicKey() + + if v, ok := authority.(AlgorithmSigner); ok { + if v.PublicKey().Type() == KeyAlgoRSA { + authority = &defaultAlgorithmSigner{v, SigAlgoRSASHA2512} + } + } + + sig, err := authority.Sign(rand, c.bytesForSigning()) + if err != nil { + return err + } + c.Signature = sig + return nil +} + +var certAlgoNames = map[string]string{ + KeyAlgoRSA: CertAlgoRSAv01, + KeyAlgoRSASHA2256: CertAlgoRSASHA2256v01, + KeyAlgoRSASHA2512: CertAlgoRSASHA2512v01, + KeyAlgoDSA: CertAlgoDSAv01, + KeyAlgoECDSA256: CertAlgoECDSA256v01, + KeyAlgoECDSA384: CertAlgoECDSA384v01, + KeyAlgoECDSA521: CertAlgoECDSA521v01, + KeyAlgoSKECDSA256: CertAlgoSKECDSA256v01, + KeyAlgoED25519: CertAlgoED25519v01, + KeyAlgoSKED25519: CertAlgoSKED25519v01, +} + +// certToPrivAlgo returns the underlying algorithm for a certificate algorithm. +// Panics if a non-certificate algorithm is passed. +func certToPrivAlgo(algo string) string { + for privAlgo, pubAlgo := range certAlgoNames { + if pubAlgo == algo { + return privAlgo + } + } + panic("unknown cert algorithm") +} + +func (cert *Certificate) bytesForSigning() []byte { + c2 := *cert + c2.Signature = nil + out := c2.Marshal() + // Drop trailing signature length. + return out[:len(out)-4] +} + +// Marshal serializes c into OpenSSH's wire format. It is part of the +// PublicKey interface. +func (c *Certificate) Marshal() []byte { + generic := genericCertData{ + Serial: c.Serial, + CertType: c.CertType, + KeyId: c.KeyId, + ValidPrincipals: marshalStringList(c.ValidPrincipals), + ValidAfter: uint64(c.ValidAfter), + ValidBefore: uint64(c.ValidBefore), + CriticalOptions: marshalTuples(c.CriticalOptions), + Extensions: marshalTuples(c.Extensions), + Reserved: c.Reserved, + SignatureKey: c.SignatureKey.Marshal(), + } + if c.Signature != nil { + generic.Signature = Marshal(c.Signature) + } + genericBytes := Marshal(&generic) + keyBytes := c.Key.Marshal() + _, keyBytes, _ = parseString(keyBytes) + prefix := Marshal(&struct { + Name string + Nonce []byte + Key []byte `ssh:"rest"` + }{c.Type(), c.Nonce, keyBytes}) + + result := make([]byte, 0, len(prefix)+len(genericBytes)) + result = append(result, prefix...) + result = append(result, genericBytes...) + return result +} + +// Type returns the key name. It is part of the PublicKey interface. +func (c *Certificate) Type() string { + algo, ok := certAlgoNames[c.Key.Type()] + if !ok { + panic("unknown cert key type " + c.Key.Type()) + } + return algo +} + +// Verify verifies a signature against the certificate's public +// key. It is part of the PublicKey interface. +func (c *Certificate) Verify(data []byte, sig *Signature) error { + return c.Key.Verify(data, sig) +} + +func parseSignatureBody(in []byte) (out *Signature, rest []byte, ok bool) { + format, in, ok := parseString(in) + if !ok { + return + } + + out = &Signature{ + Format: string(format), + } + + if out.Blob, in, ok = parseString(in); !ok { + return + } + + switch out.Format { + case KeyAlgoSKECDSA256, CertAlgoSKECDSA256v01, KeyAlgoSKED25519, CertAlgoSKED25519v01: + out.Rest = in + return out, nil, ok + } + + return out, in, ok +} + +func parseSignature(in []byte) (out *Signature, rest []byte, ok bool) { + sigBytes, rest, ok := parseString(in) + if !ok { + return + } + + out, trailing, ok := parseSignatureBody(sigBytes) + if !ok || len(trailing) > 0 { + return nil, nil, false + } + return +} diff --git a/internal/crypto/ssh/channel.go b/internal/crypto/ssh/channel.go new file mode 100644 index 000000000..c0834c00d --- /dev/null +++ b/internal/crypto/ssh/channel.go @@ -0,0 +1,633 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ssh + +import ( + "encoding/binary" + "errors" + "fmt" + "io" + "log" + "sync" +) + +const ( + minPacketLength = 9 + // channelMaxPacket contains the maximum number of bytes that will be + // sent in a single packet. As per RFC 4253, section 6.1, 32k is also + // the minimum. + channelMaxPacket = 1 << 15 + // We follow OpenSSH here. + channelWindowSize = 64 * channelMaxPacket +) + +// NewChannel represents an incoming request to a channel. It must either be +// accepted for use by calling Accept, or rejected by calling Reject. +type NewChannel interface { + // Accept accepts the channel creation request. It returns the Channel + // and a Go channel containing SSH requests. The Go channel must be + // serviced otherwise the Channel will hang. + Accept() (Channel, <-chan *Request, error) + + // Reject rejects the channel creation request. After calling + // this, no other methods on the Channel may be called. + Reject(reason RejectionReason, message string) error + + // ChannelType returns the type of the channel, as supplied by the + // client. + ChannelType() string + + // ExtraData returns the arbitrary payload for this channel, as supplied + // by the client. This data is specific to the channel type. + ExtraData() []byte +} + +// A Channel is an ordered, reliable, flow-controlled, duplex stream +// that is multiplexed over an SSH connection. +type Channel interface { + // Read reads up to len(data) bytes from the channel. + Read(data []byte) (int, error) + + // Write writes len(data) bytes to the channel. + Write(data []byte) (int, error) + + // Close signals end of channel use. No data may be sent after this + // call. + Close() error + + // CloseWrite signals the end of sending in-band + // data. Requests may still be sent, and the other side may + // still send data + CloseWrite() error + + // SendRequest sends a channel request. If wantReply is true, + // it will wait for a reply and return the result as a + // boolean, otherwise the return value will be false. Channel + // requests are out-of-band messages so they may be sent even + // if the data stream is closed or blocked by flow control. + // If the channel is closed before a reply is returned, io.EOF + // is returned. + SendRequest(name string, wantReply bool, payload []byte) (bool, error) + + // Stderr returns an io.ReadWriter that writes to this channel + // with the extended data type set to stderr. Stderr may + // safely be read and written from a different goroutine than + // Read and Write respectively. + Stderr() io.ReadWriter +} + +// Request is a request sent outside of the normal stream of +// data. Requests can either be specific to an SSH channel, or they +// can be global. +type Request struct { + Type string + WantReply bool + Payload []byte + + ch *channel + mux *mux +} + +// Reply sends a response to a request. It must be called for all requests +// where WantReply is true and is a no-op otherwise. The payload argument is +// ignored for replies to channel-specific requests. +func (r *Request) Reply(ok bool, payload []byte) error { + if !r.WantReply { + return nil + } + + if r.ch == nil { + return r.mux.ackRequest(ok, payload) + } + + return r.ch.ackRequest(ok) +} + +// RejectionReason is an enumeration used when rejecting channel creation +// requests. See RFC 4254, section 5.1. +type RejectionReason uint32 + +const ( + Prohibited RejectionReason = iota + 1 + ConnectionFailed + UnknownChannelType + ResourceShortage +) + +// String converts the rejection reason to human readable form. +func (r RejectionReason) String() string { + switch r { + case Prohibited: + return "administratively prohibited" + case ConnectionFailed: + return "connect failed" + case UnknownChannelType: + return "unknown channel type" + case ResourceShortage: + return "resource shortage" + } + return fmt.Sprintf("unknown reason %d", int(r)) +} + +func min(a uint32, b int) uint32 { + if a < uint32(b) { + return a + } + return uint32(b) +} + +type channelDirection uint8 + +const ( + channelInbound channelDirection = iota + channelOutbound +) + +// channel is an implementation of the Channel interface that works +// with the mux class. +type channel struct { + // R/O after creation + chanType string + extraData []byte + localId, remoteId uint32 + + // maxIncomingPayload and maxRemotePayload are the maximum + // payload sizes of normal and extended data packets for + // receiving and sending, respectively. The wire packet will + // be 9 or 13 bytes larger (excluding encryption overhead). + maxIncomingPayload uint32 + maxRemotePayload uint32 + + mux *mux + + // decided is set to true if an accept or reject message has been sent + // (for outbound channels) or received (for inbound channels). + decided bool + + // direction contains either channelOutbound, for channels created + // locally, or channelInbound, for channels created by the peer. + direction channelDirection + + // Pending internal channel messages. + msg chan interface{} + + // Since requests have no ID, there can be only one request + // with WantReply=true outstanding. This lock is held by a + // goroutine that has such an outgoing request pending. + sentRequestMu sync.Mutex + + incomingRequests chan *Request + + sentEOF bool + + // thread-safe data + remoteWin window + pending *buffer + extPending *buffer + + // windowMu protects myWindow, the flow-control window. + windowMu sync.Mutex + myWindow uint32 + + // writeMu serializes calls to mux.conn.writePacket() and + // protects sentClose and packetPool. This mutex must be + // different from windowMu, as writePacket can block if there + // is a key exchange pending. + writeMu sync.Mutex + sentClose bool + + // packetPool has a buffer for each extended channel ID to + // save allocations during writes. + packetPool map[uint32][]byte +} + +// writePacket sends a packet. If the packet is a channel close, it updates +// sentClose. This method takes the lock c.writeMu. +func (ch *channel) writePacket(packet []byte) error { + ch.writeMu.Lock() + if ch.sentClose { + ch.writeMu.Unlock() + return io.EOF + } + ch.sentClose = (packet[0] == msgChannelClose) + err := ch.mux.conn.writePacket(packet) + ch.writeMu.Unlock() + return err +} + +func (ch *channel) sendMessage(msg interface{}) error { + if debugMux { + log.Printf("send(%d): %#v", ch.mux.chanList.offset, msg) + } + + p := Marshal(msg) + binary.BigEndian.PutUint32(p[1:], ch.remoteId) + return ch.writePacket(p) +} + +// WriteExtended writes data to a specific extended stream. These streams are +// used, for example, for stderr. +func (ch *channel) WriteExtended(data []byte, extendedCode uint32) (n int, err error) { + if ch.sentEOF { + return 0, io.EOF + } + // 1 byte message type, 4 bytes remoteId, 4 bytes data length + opCode := byte(msgChannelData) + headerLength := uint32(9) + if extendedCode > 0 { + headerLength += 4 + opCode = msgChannelExtendedData + } + + ch.writeMu.Lock() + packet := ch.packetPool[extendedCode] + // We don't remove the buffer from packetPool, so + // WriteExtended calls from different goroutines will be + // flagged as errors by the race detector. + ch.writeMu.Unlock() + + for len(data) > 0 { + space := min(ch.maxRemotePayload, len(data)) + if space, err = ch.remoteWin.reserve(space); err != nil { + return n, err + } + if want := headerLength + space; uint32(cap(packet)) < want { + packet = make([]byte, want) + } else { + packet = packet[:want] + } + + todo := data[:space] + + packet[0] = opCode + binary.BigEndian.PutUint32(packet[1:], ch.remoteId) + if extendedCode > 0 { + binary.BigEndian.PutUint32(packet[5:], uint32(extendedCode)) + } + binary.BigEndian.PutUint32(packet[headerLength-4:], uint32(len(todo))) + copy(packet[headerLength:], todo) + if err = ch.writePacket(packet); err != nil { + return n, err + } + + n += len(todo) + data = data[len(todo):] + } + + ch.writeMu.Lock() + ch.packetPool[extendedCode] = packet + ch.writeMu.Unlock() + + return n, err +} + +func (ch *channel) handleData(packet []byte) error { + headerLen := 9 + isExtendedData := packet[0] == msgChannelExtendedData + if isExtendedData { + headerLen = 13 + } + if len(packet) < headerLen { + // malformed data packet + return parseError(packet[0]) + } + + var extended uint32 + if isExtendedData { + extended = binary.BigEndian.Uint32(packet[5:]) + } + + length := binary.BigEndian.Uint32(packet[headerLen-4 : headerLen]) + if length == 0 { + return nil + } + if length > ch.maxIncomingPayload { + // TODO(hanwen): should send Disconnect? + return errors.New("ssh: incoming packet exceeds maximum payload size") + } + + data := packet[headerLen:] + if length != uint32(len(data)) { + return errors.New("ssh: wrong packet length") + } + + ch.windowMu.Lock() + if ch.myWindow < length { + ch.windowMu.Unlock() + // TODO(hanwen): should send Disconnect with reason? + return errors.New("ssh: remote side wrote too much") + } + ch.myWindow -= length + ch.windowMu.Unlock() + + if extended == 1 { + ch.extPending.write(data) + } else if extended > 0 { + // discard other extended data. + } else { + ch.pending.write(data) + } + return nil +} + +func (c *channel) adjustWindow(n uint32) error { + c.windowMu.Lock() + // Since myWindow is managed on our side, and can never exceed + // the initial window setting, we don't worry about overflow. + c.myWindow += uint32(n) + c.windowMu.Unlock() + return c.sendMessage(windowAdjustMsg{ + AdditionalBytes: uint32(n), + }) +} + +func (c *channel) ReadExtended(data []byte, extended uint32) (n int, err error) { + switch extended { + case 1: + n, err = c.extPending.Read(data) + case 0: + n, err = c.pending.Read(data) + default: + return 0, fmt.Errorf("ssh: extended code %d unimplemented", extended) + } + + if n > 0 { + err = c.adjustWindow(uint32(n)) + // sendWindowAdjust can return io.EOF if the remote + // peer has closed the connection, however we want to + // defer forwarding io.EOF to the caller of Read until + // the buffer has been drained. + if n > 0 && err == io.EOF { + err = nil + } + } + + return n, err +} + +func (c *channel) close() { + c.pending.eof() + c.extPending.eof() + close(c.msg) + close(c.incomingRequests) + c.writeMu.Lock() + // This is not necessary for a normal channel teardown, but if + // there was another error, it is. + c.sentClose = true + c.writeMu.Unlock() + // Unblock writers. + c.remoteWin.close() +} + +// responseMessageReceived is called when a success or failure message is +// received on a channel to check that such a message is reasonable for the +// given channel. +func (ch *channel) responseMessageReceived() error { + if ch.direction == channelInbound { + return errors.New("ssh: channel response message received on inbound channel") + } + if ch.decided { + return errors.New("ssh: duplicate response received for channel") + } + ch.decided = true + return nil +} + +func (ch *channel) handlePacket(packet []byte) error { + switch packet[0] { + case msgChannelData, msgChannelExtendedData: + return ch.handleData(packet) + case msgChannelClose: + ch.sendMessage(channelCloseMsg{PeersID: ch.remoteId}) + ch.mux.chanList.remove(ch.localId) + ch.close() + return nil + case msgChannelEOF: + // RFC 4254 is mute on how EOF affects dataExt messages but + // it is logical to signal EOF at the same time. + ch.extPending.eof() + ch.pending.eof() + return nil + } + + decoded, err := decode(packet) + if err != nil { + return err + } + + switch msg := decoded.(type) { + case *channelOpenFailureMsg: + if err := ch.responseMessageReceived(); err != nil { + return err + } + ch.mux.chanList.remove(msg.PeersID) + ch.msg <- msg + case *channelOpenConfirmMsg: + if err := ch.responseMessageReceived(); err != nil { + return err + } + if msg.MaxPacketSize < minPacketLength || msg.MaxPacketSize > 1<<31 { + return fmt.Errorf("ssh: invalid MaxPacketSize %d from peer", msg.MaxPacketSize) + } + ch.remoteId = msg.MyID + ch.maxRemotePayload = msg.MaxPacketSize + ch.remoteWin.add(msg.MyWindow) + ch.msg <- msg + case *windowAdjustMsg: + if !ch.remoteWin.add(msg.AdditionalBytes) { + return fmt.Errorf("ssh: invalid window update for %d bytes", msg.AdditionalBytes) + } + case *channelRequestMsg: + req := Request{ + Type: msg.Request, + WantReply: msg.WantReply, + Payload: msg.RequestSpecificData, + ch: ch, + } + + ch.incomingRequests <- &req + default: + ch.msg <- msg + } + return nil +} + +func (m *mux) newChannel(chanType string, direction channelDirection, extraData []byte) *channel { + ch := &channel{ + remoteWin: window{Cond: newCond()}, + myWindow: channelWindowSize, + pending: newBuffer(), + extPending: newBuffer(), + direction: direction, + incomingRequests: make(chan *Request, chanSize), + msg: make(chan interface{}, chanSize), + chanType: chanType, + extraData: extraData, + mux: m, + packetPool: make(map[uint32][]byte), + } + ch.localId = m.chanList.add(ch) + return ch +} + +var errUndecided = errors.New("ssh: must Accept or Reject channel") +var errDecidedAlready = errors.New("ssh: can call Accept or Reject only once") + +type extChannel struct { + code uint32 + ch *channel +} + +func (e *extChannel) Write(data []byte) (n int, err error) { + return e.ch.WriteExtended(data, e.code) +} + +func (e *extChannel) Read(data []byte) (n int, err error) { + return e.ch.ReadExtended(data, e.code) +} + +func (ch *channel) Accept() (Channel, <-chan *Request, error) { + if ch.decided { + return nil, nil, errDecidedAlready + } + ch.maxIncomingPayload = channelMaxPacket + confirm := channelOpenConfirmMsg{ + PeersID: ch.remoteId, + MyID: ch.localId, + MyWindow: ch.myWindow, + MaxPacketSize: ch.maxIncomingPayload, + } + ch.decided = true + if err := ch.sendMessage(confirm); err != nil { + return nil, nil, err + } + + return ch, ch.incomingRequests, nil +} + +func (ch *channel) Reject(reason RejectionReason, message string) error { + if ch.decided { + return errDecidedAlready + } + reject := channelOpenFailureMsg{ + PeersID: ch.remoteId, + Reason: reason, + Message: message, + Language: "en", + } + ch.decided = true + return ch.sendMessage(reject) +} + +func (ch *channel) Read(data []byte) (int, error) { + if !ch.decided { + return 0, errUndecided + } + return ch.ReadExtended(data, 0) +} + +func (ch *channel) Write(data []byte) (int, error) { + if !ch.decided { + return 0, errUndecided + } + return ch.WriteExtended(data, 0) +} + +func (ch *channel) CloseWrite() error { + if !ch.decided { + return errUndecided + } + ch.sentEOF = true + return ch.sendMessage(channelEOFMsg{ + PeersID: ch.remoteId}) +} + +func (ch *channel) Close() error { + if !ch.decided { + return errUndecided + } + + return ch.sendMessage(channelCloseMsg{ + PeersID: ch.remoteId}) +} + +// Extended returns an io.ReadWriter that sends and receives data on the given, +// SSH extended stream. Such streams are used, for example, for stderr. +func (ch *channel) Extended(code uint32) io.ReadWriter { + if !ch.decided { + return nil + } + return &extChannel{code, ch} +} + +func (ch *channel) Stderr() io.ReadWriter { + return ch.Extended(1) +} + +func (ch *channel) SendRequest(name string, wantReply bool, payload []byte) (bool, error) { + if !ch.decided { + return false, errUndecided + } + + if wantReply { + ch.sentRequestMu.Lock() + defer ch.sentRequestMu.Unlock() + } + + msg := channelRequestMsg{ + PeersID: ch.remoteId, + Request: name, + WantReply: wantReply, + RequestSpecificData: payload, + } + + if err := ch.sendMessage(msg); err != nil { + return false, err + } + + if wantReply { + m, ok := (<-ch.msg) + if !ok { + return false, io.EOF + } + switch m.(type) { + case *channelRequestFailureMsg: + return false, nil + case *channelRequestSuccessMsg: + return true, nil + default: + return false, fmt.Errorf("ssh: unexpected response to channel request: %#v", m) + } + } + + return false, nil +} + +// ackRequest either sends an ack or nack to the channel request. +func (ch *channel) ackRequest(ok bool) error { + if !ch.decided { + return errUndecided + } + + var msg interface{} + if !ok { + msg = channelRequestFailureMsg{ + PeersID: ch.remoteId, + } + } else { + msg = channelRequestSuccessMsg{ + PeersID: ch.remoteId, + } + } + return ch.sendMessage(msg) +} + +func (ch *channel) ChannelType() string { + return ch.chanType +} + +func (ch *channel) ExtraData() []byte { + return ch.extraData +} diff --git a/internal/crypto/ssh/cipher.go b/internal/crypto/ssh/cipher.go new file mode 100644 index 000000000..8bd6b3daf --- /dev/null +++ b/internal/crypto/ssh/cipher.go @@ -0,0 +1,781 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ssh + +import ( + "crypto/aes" + "crypto/cipher" + "crypto/des" + "crypto/rc4" + "crypto/subtle" + "encoding/binary" + "errors" + "fmt" + "hash" + "io" + "io/ioutil" + + "golang.org/x/crypto/chacha20" + "golang.org/x/crypto/poly1305" +) + +const ( + packetSizeMultiple = 16 // TODO(huin) this should be determined by the cipher. + + // RFC 4253 section 6.1 defines a minimum packet size of 32768 that implementations + // MUST be able to process (plus a few more kilobytes for padding and mac). The RFC + // indicates implementations SHOULD be able to handle larger packet sizes, but then + // waffles on about reasonable limits. + // + // OpenSSH caps their maxPacket at 256kB so we choose to do + // the same. maxPacket is also used to ensure that uint32 + // length fields do not overflow, so it should remain well + // below 4G. + maxPacket = 256 * 1024 +) + +// noneCipher implements cipher.Stream and provides no encryption. It is used +// by the transport before the first key-exchange. +type noneCipher struct{} + +func (c noneCipher) XORKeyStream(dst, src []byte) { + copy(dst, src) +} + +func newAESCTR(key, iv []byte) (cipher.Stream, error) { + c, err := aes.NewCipher(key) + if err != nil { + return nil, err + } + return cipher.NewCTR(c, iv), nil +} + +func newRC4(key, iv []byte) (cipher.Stream, error) { + return rc4.NewCipher(key) +} + +type cipherMode struct { + keySize int + ivSize int + create func(key, iv []byte, macKey []byte, algs directionAlgorithms) (packetCipher, error) +} + +func streamCipherMode(skip int, createFunc func(key, iv []byte) (cipher.Stream, error)) func(key, iv []byte, macKey []byte, algs directionAlgorithms) (packetCipher, error) { + return func(key, iv, macKey []byte, algs directionAlgorithms) (packetCipher, error) { + stream, err := createFunc(key, iv) + if err != nil { + return nil, err + } + + var streamDump []byte + if skip > 0 { + streamDump = make([]byte, 512) + } + + for remainingToDump := skip; remainingToDump > 0; { + dumpThisTime := remainingToDump + if dumpThisTime > len(streamDump) { + dumpThisTime = len(streamDump) + } + stream.XORKeyStream(streamDump[:dumpThisTime], streamDump[:dumpThisTime]) + remainingToDump -= dumpThisTime + } + + mac := macModes[algs.MAC].new(macKey) + return &streamPacketCipher{ + mac: mac, + etm: macModes[algs.MAC].etm, + macResult: make([]byte, mac.Size()), + cipher: stream, + }, nil + } +} + +// cipherModes documents properties of supported ciphers. Ciphers not included +// are not supported and will not be negotiated, even if explicitly requested in +// ClientConfig.Crypto.Ciphers. +var cipherModes = map[string]*cipherMode{ + // Ciphers from RFC4344, which introduced many CTR-based ciphers. Algorithms + // are defined in the order specified in the RFC. + "aes128-ctr": {16, aes.BlockSize, streamCipherMode(0, newAESCTR)}, + "aes192-ctr": {24, aes.BlockSize, streamCipherMode(0, newAESCTR)}, + "aes256-ctr": {32, aes.BlockSize, streamCipherMode(0, newAESCTR)}, + + // Ciphers from RFC4345, which introduces security-improved arcfour ciphers. + // They are defined in the order specified in the RFC. + "arcfour128": {16, 0, streamCipherMode(1536, newRC4)}, + "arcfour256": {32, 0, streamCipherMode(1536, newRC4)}, + + // Cipher defined in RFC 4253, which describes SSH Transport Layer Protocol. + // Note that this cipher is not safe, as stated in RFC 4253: "Arcfour (and + // RC4) has problems with weak keys, and should be used with caution." + // RFC4345 introduces improved versions of Arcfour. + "arcfour": {16, 0, streamCipherMode(0, newRC4)}, + + // AEAD ciphers + gcmCipherID: {16, 12, newGCMCipher}, + chacha20Poly1305ID: {64, 0, newChaCha20Cipher}, + + // CBC mode is insecure and so is not included in the default config. + // (See https://www.ieee-security.org/TC/SP2013/papers/4977a526.pdf). If absolutely + // needed, it's possible to specify a custom Config to enable it. + // You should expect that an active attacker can recover plaintext if + // you do. + aes128cbcID: {16, aes.BlockSize, newAESCBCCipher}, + + // 3des-cbc is insecure and is not included in the default + // config. + tripledescbcID: {24, des.BlockSize, newTripleDESCBCCipher}, +} + +// prefixLen is the length of the packet prefix that contains the packet length +// and number of padding bytes. +const prefixLen = 5 + +// streamPacketCipher is a packetCipher using a stream cipher. +type streamPacketCipher struct { + mac hash.Hash + cipher cipher.Stream + etm bool + + // The following members are to avoid per-packet allocations. + prefix [prefixLen]byte + seqNumBytes [4]byte + padding [2 * packetSizeMultiple]byte + packetData []byte + macResult []byte +} + +// readCipherPacket reads and decrypt a single packet from the reader argument. +func (s *streamPacketCipher) readCipherPacket(seqNum uint32, r io.Reader) ([]byte, error) { + if _, err := io.ReadFull(r, s.prefix[:]); err != nil { + return nil, err + } + + var encryptedPaddingLength [1]byte + if s.mac != nil && s.etm { + copy(encryptedPaddingLength[:], s.prefix[4:5]) + s.cipher.XORKeyStream(s.prefix[4:5], s.prefix[4:5]) + } else { + s.cipher.XORKeyStream(s.prefix[:], s.prefix[:]) + } + + length := binary.BigEndian.Uint32(s.prefix[0:4]) + paddingLength := uint32(s.prefix[4]) + + var macSize uint32 + if s.mac != nil { + s.mac.Reset() + binary.BigEndian.PutUint32(s.seqNumBytes[:], seqNum) + s.mac.Write(s.seqNumBytes[:]) + if s.etm { + s.mac.Write(s.prefix[:4]) + s.mac.Write(encryptedPaddingLength[:]) + } else { + s.mac.Write(s.prefix[:]) + } + macSize = uint32(s.mac.Size()) + } + + if length <= paddingLength+1 { + return nil, errors.New("ssh: invalid packet length, packet too small") + } + + if length > maxPacket { + return nil, errors.New("ssh: invalid packet length, packet too large") + } + + // the maxPacket check above ensures that length-1+macSize + // does not overflow. + if uint32(cap(s.packetData)) < length-1+macSize { + s.packetData = make([]byte, length-1+macSize) + } else { + s.packetData = s.packetData[:length-1+macSize] + } + + if _, err := io.ReadFull(r, s.packetData); err != nil { + return nil, err + } + mac := s.packetData[length-1:] + data := s.packetData[:length-1] + + if s.mac != nil && s.etm { + s.mac.Write(data) + } + + s.cipher.XORKeyStream(data, data) + + if s.mac != nil { + if !s.etm { + s.mac.Write(data) + } + s.macResult = s.mac.Sum(s.macResult[:0]) + if subtle.ConstantTimeCompare(s.macResult, mac) != 1 { + return nil, errors.New("ssh: MAC failure") + } + } + + return s.packetData[:length-paddingLength-1], nil +} + +// writeCipherPacket encrypts and sends a packet of data to the writer argument +func (s *streamPacketCipher) writeCipherPacket(seqNum uint32, w io.Writer, rand io.Reader, packet []byte) error { + if len(packet) > maxPacket { + return errors.New("ssh: packet too large") + } + + aadlen := 0 + if s.mac != nil && s.etm { + // packet length is not encrypted for EtM modes + aadlen = 4 + } + + paddingLength := packetSizeMultiple - (prefixLen+len(packet)-aadlen)%packetSizeMultiple + if paddingLength < 4 { + paddingLength += packetSizeMultiple + } + + length := len(packet) + 1 + paddingLength + binary.BigEndian.PutUint32(s.prefix[:], uint32(length)) + s.prefix[4] = byte(paddingLength) + padding := s.padding[:paddingLength] + if _, err := io.ReadFull(rand, padding); err != nil { + return err + } + + if s.mac != nil { + s.mac.Reset() + binary.BigEndian.PutUint32(s.seqNumBytes[:], seqNum) + s.mac.Write(s.seqNumBytes[:]) + + if s.etm { + // For EtM algorithms, the packet length must stay unencrypted, + // but the following data (padding length) must be encrypted + s.cipher.XORKeyStream(s.prefix[4:5], s.prefix[4:5]) + } + + s.mac.Write(s.prefix[:]) + + if !s.etm { + // For non-EtM algorithms, the algorithm is applied on unencrypted data + s.mac.Write(packet) + s.mac.Write(padding) + } + } + + if !(s.mac != nil && s.etm) { + // For EtM algorithms, the padding length has already been encrypted + // and the packet length must remain unencrypted + s.cipher.XORKeyStream(s.prefix[:], s.prefix[:]) + } + + s.cipher.XORKeyStream(packet, packet) + s.cipher.XORKeyStream(padding, padding) + + if s.mac != nil && s.etm { + // For EtM algorithms, packet and padding must be encrypted + s.mac.Write(packet) + s.mac.Write(padding) + } + + if _, err := w.Write(s.prefix[:]); err != nil { + return err + } + if _, err := w.Write(packet); err != nil { + return err + } + if _, err := w.Write(padding); err != nil { + return err + } + + if s.mac != nil { + s.macResult = s.mac.Sum(s.macResult[:0]) + if _, err := w.Write(s.macResult); err != nil { + return err + } + } + + return nil +} + +type gcmCipher struct { + aead cipher.AEAD + prefix [4]byte + iv []byte + buf []byte +} + +func newGCMCipher(key, iv, unusedMacKey []byte, unusedAlgs directionAlgorithms) (packetCipher, error) { + c, err := aes.NewCipher(key) + if err != nil { + return nil, err + } + + aead, err := cipher.NewGCM(c) + if err != nil { + return nil, err + } + + return &gcmCipher{ + aead: aead, + iv: iv, + }, nil +} + +const gcmTagSize = 16 + +func (c *gcmCipher) writeCipherPacket(seqNum uint32, w io.Writer, rand io.Reader, packet []byte) error { + // Pad out to multiple of 16 bytes. This is different from the + // stream cipher because that encrypts the length too. + padding := byte(packetSizeMultiple - (1+len(packet))%packetSizeMultiple) + if padding < 4 { + padding += packetSizeMultiple + } + + length := uint32(len(packet) + int(padding) + 1) + binary.BigEndian.PutUint32(c.prefix[:], length) + if _, err := w.Write(c.prefix[:]); err != nil { + return err + } + + if cap(c.buf) < int(length) { + c.buf = make([]byte, length) + } else { + c.buf = c.buf[:length] + } + + c.buf[0] = padding + copy(c.buf[1:], packet) + if _, err := io.ReadFull(rand, c.buf[1+len(packet):]); err != nil { + return err + } + c.buf = c.aead.Seal(c.buf[:0], c.iv, c.buf, c.prefix[:]) + if _, err := w.Write(c.buf); err != nil { + return err + } + c.incIV() + + return nil +} + +func (c *gcmCipher) incIV() { + for i := 4 + 7; i >= 4; i-- { + c.iv[i]++ + if c.iv[i] != 0 { + break + } + } +} + +func (c *gcmCipher) readCipherPacket(seqNum uint32, r io.Reader) ([]byte, error) { + if _, err := io.ReadFull(r, c.prefix[:]); err != nil { + return nil, err + } + length := binary.BigEndian.Uint32(c.prefix[:]) + if length > maxPacket { + return nil, errors.New("ssh: max packet length exceeded") + } + + if cap(c.buf) < int(length+gcmTagSize) { + c.buf = make([]byte, length+gcmTagSize) + } else { + c.buf = c.buf[:length+gcmTagSize] + } + + if _, err := io.ReadFull(r, c.buf); err != nil { + return nil, err + } + + plain, err := c.aead.Open(c.buf[:0], c.iv, c.buf, c.prefix[:]) + if err != nil { + return nil, err + } + c.incIV() + + padding := plain[0] + if padding < 4 { + // padding is a byte, so it automatically satisfies + // the maximum size, which is 255. + return nil, fmt.Errorf("ssh: illegal padding %d", padding) + } + + if int(padding+1) >= len(plain) { + return nil, fmt.Errorf("ssh: padding %d too large", padding) + } + plain = plain[1 : length-uint32(padding)] + return plain, nil +} + +// cbcCipher implements aes128-cbc cipher defined in RFC 4253 section 6.1 +type cbcCipher struct { + mac hash.Hash + macSize uint32 + decrypter cipher.BlockMode + encrypter cipher.BlockMode + + // The following members are to avoid per-packet allocations. + seqNumBytes [4]byte + packetData []byte + macResult []byte + + // Amount of data we should still read to hide which + // verification error triggered. + oracleCamouflage uint32 +} + +func newCBCCipher(c cipher.Block, key, iv, macKey []byte, algs directionAlgorithms) (packetCipher, error) { + cbc := &cbcCipher{ + mac: macModes[algs.MAC].new(macKey), + decrypter: cipher.NewCBCDecrypter(c, iv), + encrypter: cipher.NewCBCEncrypter(c, iv), + packetData: make([]byte, 1024), + } + if cbc.mac != nil { + cbc.macSize = uint32(cbc.mac.Size()) + } + + return cbc, nil +} + +func newAESCBCCipher(key, iv, macKey []byte, algs directionAlgorithms) (packetCipher, error) { + c, err := aes.NewCipher(key) + if err != nil { + return nil, err + } + + cbc, err := newCBCCipher(c, key, iv, macKey, algs) + if err != nil { + return nil, err + } + + return cbc, nil +} + +func newTripleDESCBCCipher(key, iv, macKey []byte, algs directionAlgorithms) (packetCipher, error) { + c, err := des.NewTripleDESCipher(key) + if err != nil { + return nil, err + } + + cbc, err := newCBCCipher(c, key, iv, macKey, algs) + if err != nil { + return nil, err + } + + return cbc, nil +} + +func maxUInt32(a, b int) uint32 { + if a > b { + return uint32(a) + } + return uint32(b) +} + +const ( + cbcMinPacketSizeMultiple = 8 + cbcMinPacketSize = 16 + cbcMinPaddingSize = 4 +) + +// cbcError represents a verification error that may leak information. +type cbcError string + +func (e cbcError) Error() string { return string(e) } + +func (c *cbcCipher) readCipherPacket(seqNum uint32, r io.Reader) ([]byte, error) { + p, err := c.readCipherPacketLeaky(seqNum, r) + if err != nil { + if _, ok := err.(cbcError); ok { + // Verification error: read a fixed amount of + // data, to make distinguishing between + // failing MAC and failing length check more + // difficult. + io.CopyN(ioutil.Discard, r, int64(c.oracleCamouflage)) + } + } + return p, err +} + +func (c *cbcCipher) readCipherPacketLeaky(seqNum uint32, r io.Reader) ([]byte, error) { + blockSize := c.decrypter.BlockSize() + + // Read the header, which will include some of the subsequent data in the + // case of block ciphers - this is copied back to the payload later. + // How many bytes of payload/padding will be read with this first read. + firstBlockLength := uint32((prefixLen + blockSize - 1) / blockSize * blockSize) + firstBlock := c.packetData[:firstBlockLength] + if _, err := io.ReadFull(r, firstBlock); err != nil { + return nil, err + } + + c.oracleCamouflage = maxPacket + 4 + c.macSize - firstBlockLength + + c.decrypter.CryptBlocks(firstBlock, firstBlock) + length := binary.BigEndian.Uint32(firstBlock[:4]) + if length > maxPacket { + return nil, cbcError("ssh: packet too large") + } + if length+4 < maxUInt32(cbcMinPacketSize, blockSize) { + // The minimum size of a packet is 16 (or the cipher block size, whichever + // is larger) bytes. + return nil, cbcError("ssh: packet too small") + } + // The length of the packet (including the length field but not the MAC) must + // be a multiple of the block size or 8, whichever is larger. + if (length+4)%maxUInt32(cbcMinPacketSizeMultiple, blockSize) != 0 { + return nil, cbcError("ssh: invalid packet length multiple") + } + + paddingLength := uint32(firstBlock[4]) + if paddingLength < cbcMinPaddingSize || length <= paddingLength+1 { + return nil, cbcError("ssh: invalid packet length") + } + + // Positions within the c.packetData buffer: + macStart := 4 + length + paddingStart := macStart - paddingLength + + // Entire packet size, starting before length, ending at end of mac. + entirePacketSize := macStart + c.macSize + + // Ensure c.packetData is large enough for the entire packet data. + if uint32(cap(c.packetData)) < entirePacketSize { + // Still need to upsize and copy, but this should be rare at runtime, only + // on upsizing the packetData buffer. + c.packetData = make([]byte, entirePacketSize) + copy(c.packetData, firstBlock) + } else { + c.packetData = c.packetData[:entirePacketSize] + } + + n, err := io.ReadFull(r, c.packetData[firstBlockLength:]) + if err != nil { + return nil, err + } + c.oracleCamouflage -= uint32(n) + + remainingCrypted := c.packetData[firstBlockLength:macStart] + c.decrypter.CryptBlocks(remainingCrypted, remainingCrypted) + + mac := c.packetData[macStart:] + if c.mac != nil { + c.mac.Reset() + binary.BigEndian.PutUint32(c.seqNumBytes[:], seqNum) + c.mac.Write(c.seqNumBytes[:]) + c.mac.Write(c.packetData[:macStart]) + c.macResult = c.mac.Sum(c.macResult[:0]) + if subtle.ConstantTimeCompare(c.macResult, mac) != 1 { + return nil, cbcError("ssh: MAC failure") + } + } + + return c.packetData[prefixLen:paddingStart], nil +} + +func (c *cbcCipher) writeCipherPacket(seqNum uint32, w io.Writer, rand io.Reader, packet []byte) error { + effectiveBlockSize := maxUInt32(cbcMinPacketSizeMultiple, c.encrypter.BlockSize()) + + // Length of encrypted portion of the packet (header, payload, padding). + // Enforce minimum padding and packet size. + encLength := maxUInt32(prefixLen+len(packet)+cbcMinPaddingSize, cbcMinPaddingSize) + // Enforce block size. + encLength = (encLength + effectiveBlockSize - 1) / effectiveBlockSize * effectiveBlockSize + + length := encLength - 4 + paddingLength := int(length) - (1 + len(packet)) + + // Overall buffer contains: header, payload, padding, mac. + // Space for the MAC is reserved in the capacity but not the slice length. + bufferSize := encLength + c.macSize + if uint32(cap(c.packetData)) < bufferSize { + c.packetData = make([]byte, encLength, bufferSize) + } else { + c.packetData = c.packetData[:encLength] + } + + p := c.packetData + + // Packet header. + binary.BigEndian.PutUint32(p, length) + p = p[4:] + p[0] = byte(paddingLength) + + // Payload. + p = p[1:] + copy(p, packet) + + // Padding. + p = p[len(packet):] + if _, err := io.ReadFull(rand, p); err != nil { + return err + } + + if c.mac != nil { + c.mac.Reset() + binary.BigEndian.PutUint32(c.seqNumBytes[:], seqNum) + c.mac.Write(c.seqNumBytes[:]) + c.mac.Write(c.packetData) + // The MAC is now appended into the capacity reserved for it earlier. + c.packetData = c.mac.Sum(c.packetData) + } + + c.encrypter.CryptBlocks(c.packetData[:encLength], c.packetData[:encLength]) + + if _, err := w.Write(c.packetData); err != nil { + return err + } + + return nil +} + +const chacha20Poly1305ID = "chacha20-poly1305@openssh.com" + +// chacha20Poly1305Cipher implements the chacha20-poly1305@openssh.com +// AEAD, which is described here: +// +// https://tools.ietf.org/html/draft-josefsson-ssh-chacha20-poly1305-openssh-00 +// +// the methods here also implement padding, which RFC4253 Section 6 +// also requires of stream ciphers. +type chacha20Poly1305Cipher struct { + lengthKey [32]byte + contentKey [32]byte + buf []byte +} + +func newChaCha20Cipher(key, unusedIV, unusedMACKey []byte, unusedAlgs directionAlgorithms) (packetCipher, error) { + if len(key) != 64 { + panic(len(key)) + } + + c := &chacha20Poly1305Cipher{ + buf: make([]byte, 256), + } + + copy(c.contentKey[:], key[:32]) + copy(c.lengthKey[:], key[32:]) + return c, nil +} + +func (c *chacha20Poly1305Cipher) readCipherPacket(seqNum uint32, r io.Reader) ([]byte, error) { + nonce := make([]byte, 12) + binary.BigEndian.PutUint32(nonce[8:], seqNum) + s, err := chacha20.NewUnauthenticatedCipher(c.contentKey[:], nonce) + if err != nil { + return nil, err + } + var polyKey, discardBuf [32]byte + s.XORKeyStream(polyKey[:], polyKey[:]) + s.XORKeyStream(discardBuf[:], discardBuf[:]) // skip the next 32 bytes + + encryptedLength := c.buf[:4] + if _, err := io.ReadFull(r, encryptedLength); err != nil { + return nil, err + } + + var lenBytes [4]byte + ls, err := chacha20.NewUnauthenticatedCipher(c.lengthKey[:], nonce) + if err != nil { + return nil, err + } + ls.XORKeyStream(lenBytes[:], encryptedLength) + + length := binary.BigEndian.Uint32(lenBytes[:]) + if length > maxPacket { + return nil, errors.New("ssh: invalid packet length, packet too large") + } + + contentEnd := 4 + length + packetEnd := contentEnd + poly1305.TagSize + if uint32(cap(c.buf)) < packetEnd { + c.buf = make([]byte, packetEnd) + copy(c.buf[:], encryptedLength) + } else { + c.buf = c.buf[:packetEnd] + } + + if _, err := io.ReadFull(r, c.buf[4:packetEnd]); err != nil { + return nil, err + } + + var mac [poly1305.TagSize]byte + copy(mac[:], c.buf[contentEnd:packetEnd]) + if !poly1305.Verify(&mac, c.buf[:contentEnd], &polyKey) { + return nil, errors.New("ssh: MAC failure") + } + + plain := c.buf[4:contentEnd] + s.XORKeyStream(plain, plain) + + padding := plain[0] + if padding < 4 { + // padding is a byte, so it automatically satisfies + // the maximum size, which is 255. + return nil, fmt.Errorf("ssh: illegal padding %d", padding) + } + + if int(padding)+1 >= len(plain) { + return nil, fmt.Errorf("ssh: padding %d too large", padding) + } + + plain = plain[1 : len(plain)-int(padding)] + + return plain, nil +} + +func (c *chacha20Poly1305Cipher) writeCipherPacket(seqNum uint32, w io.Writer, rand io.Reader, payload []byte) error { + nonce := make([]byte, 12) + binary.BigEndian.PutUint32(nonce[8:], seqNum) + s, err := chacha20.NewUnauthenticatedCipher(c.contentKey[:], nonce) + if err != nil { + return err + } + var polyKey, discardBuf [32]byte + s.XORKeyStream(polyKey[:], polyKey[:]) + s.XORKeyStream(discardBuf[:], discardBuf[:]) // skip the next 32 bytes + + // There is no blocksize, so fall back to multiple of 8 byte + // padding, as described in RFC 4253, Sec 6. + const packetSizeMultiple = 8 + + padding := packetSizeMultiple - (1+len(payload))%packetSizeMultiple + if padding < 4 { + padding += packetSizeMultiple + } + + // size (4 bytes), padding (1), payload, padding, tag. + totalLength := 4 + 1 + len(payload) + padding + poly1305.TagSize + if cap(c.buf) < totalLength { + c.buf = make([]byte, totalLength) + } else { + c.buf = c.buf[:totalLength] + } + + binary.BigEndian.PutUint32(c.buf, uint32(1+len(payload)+padding)) + ls, err := chacha20.NewUnauthenticatedCipher(c.lengthKey[:], nonce) + if err != nil { + return err + } + ls.XORKeyStream(c.buf, c.buf[:4]) + c.buf[4] = byte(padding) + copy(c.buf[5:], payload) + packetEnd := 5 + len(payload) + padding + if _, err := io.ReadFull(rand, c.buf[5+len(payload):packetEnd]); err != nil { + return err + } + + s.XORKeyStream(c.buf[4:], c.buf[4:packetEnd]) + + var mac [poly1305.TagSize]byte + poly1305.Sum(&mac, c.buf[:packetEnd], &polyKey) + + copy(c.buf[packetEnd:], mac[:]) + + if _, err := w.Write(c.buf); err != nil { + return err + } + return nil +} diff --git a/internal/crypto/ssh/client.go b/internal/crypto/ssh/client.go new file mode 100644 index 000000000..2d5a83b39 --- /dev/null +++ b/internal/crypto/ssh/client.go @@ -0,0 +1,287 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ssh + +import ( + "bytes" + "errors" + "fmt" + "net" + "os" + "sync" + "time" +) + +// Client implements a traditional SSH client that supports shells, +// subprocesses, TCP port/streamlocal forwarding and tunneled dialing. +type Client struct { + Conn + + handleForwardsOnce sync.Once // guards calling (*Client).handleForwards + + forwards forwardList // forwarded tcpip connections from the remote side + mu sync.Mutex + channelHandlers map[string]chan NewChannel +} + +// HandleChannelOpen returns a channel on which NewChannel requests +// for the given type are sent. If the type already is being handled, +// nil is returned. The channel is closed when the connection is closed. +func (c *Client) HandleChannelOpen(channelType string) <-chan NewChannel { + c.mu.Lock() + defer c.mu.Unlock() + if c.channelHandlers == nil { + // The SSH channel has been closed. + c := make(chan NewChannel) + close(c) + return c + } + + ch := c.channelHandlers[channelType] + if ch != nil { + return nil + } + + ch = make(chan NewChannel, chanSize) + c.channelHandlers[channelType] = ch + return ch +} + +// NewClient creates a Client on top of the given connection. +func NewClient(c Conn, chans <-chan NewChannel, reqs <-chan *Request) *Client { + conn := &Client{ + Conn: c, + channelHandlers: make(map[string]chan NewChannel, 1), + } + + go conn.handleGlobalRequests(reqs) + go conn.handleChannelOpens(chans) + go func() { + conn.Wait() + conn.forwards.closeAll() + }() + return conn +} + +// NewClientConn establishes an authenticated SSH connection using c +// as the underlying transport. The Request and NewChannel channels +// must be serviced or the connection will hang. +func NewClientConn(c net.Conn, addr string, config *ClientConfig) (Conn, <-chan NewChannel, <-chan *Request, error) { + fullConf := *config + fullConf.SetDefaults() + if fullConf.HostKeyCallback == nil { + c.Close() + return nil, nil, nil, errors.New("ssh: must specify HostKeyCallback") + } + + conn := &connection{ + sshConn: sshConn{conn: c}, + } + + if err := conn.clientHandshake(addr, &fullConf); err != nil { + c.Close() + return nil, nil, nil, fmt.Errorf("ssh: handshake failed: %v", err) + } + conn.mux = newMux(conn.transport) + return conn, conn.mux.incomingChannels, conn.mux.incomingRequests, nil +} + +// clientHandshake performs the client side key exchange. See RFC 4253 Section +// 7. +func (c *connection) clientHandshake(dialAddress string, config *ClientConfig) error { + if config.ClientVersion != "" { + c.clientVersion = []byte(config.ClientVersion) + } else { + c.clientVersion = []byte(packageVersion) + } + var err error + c.serverVersion, err = exchangeVersions(c.sshConn.conn, c.clientVersion) + if err != nil { + return err + } + + c.transport = newClientTransport( + newTransport(c.sshConn.conn, config.Rand, true /* is client */), + c.clientVersion, c.serverVersion, config, dialAddress, c.sshConn.RemoteAddr()) + if err := c.transport.waitSession(); err != nil { + return err + } + + c.sessionID = c.transport.getSessionID() + return c.clientAuthenticate(config) +} + +// verifyHostKeySignature verifies the host key obtained in the key +// exchange. +func verifyHostKeySignature(hostKey PublicKey, algo string, result *kexResult) error { + sig, rest, ok := parseSignatureBody(result.Signature) + if len(rest) > 0 || !ok { + return errors.New("ssh: signature parse error") + } + + for privAlgo, pubAlgo := range certAlgoNames { + if pubAlgo == algo { + algo = privAlgo + } + } + if sig.Format != algo { + return fmt.Errorf("ssh: invalid signature algorithm %q, expected %q", sig.Format, algo) + } + + return hostKey.Verify(result.H, sig) +} + +// NewSession opens a new Session for this client. (A session is a remote +// execution of a program.) +func (c *Client) NewSession() (*Session, error) { + ch, in, err := c.OpenChannel("session", nil) + if err != nil { + return nil, err + } + return newSession(ch, in) +} + +func (c *Client) handleGlobalRequests(incoming <-chan *Request) { + for r := range incoming { + // This handles keepalive messages and matches + // the behaviour of OpenSSH. + r.Reply(false, nil) + } +} + +// handleChannelOpens channel open messages from the remote side. +func (c *Client) handleChannelOpens(in <-chan NewChannel) { + for ch := range in { + c.mu.Lock() + handler := c.channelHandlers[ch.ChannelType()] + c.mu.Unlock() + + if handler != nil { + handler <- ch + } else { + ch.Reject(UnknownChannelType, fmt.Sprintf("unknown channel type: %v", ch.ChannelType())) + } + } + + c.mu.Lock() + for _, ch := range c.channelHandlers { + close(ch) + } + c.channelHandlers = nil + c.mu.Unlock() +} + +// Dial starts a client connection to the given SSH server. It is a +// convenience function that connects to the given network address, +// initiates the SSH handshake, and then sets up a Client. For access +// to incoming channels and requests, use net.Dial with NewClientConn +// instead. +func Dial(network, addr string, config *ClientConfig) (*Client, error) { + conn, err := net.DialTimeout(network, addr, config.Timeout) + if err != nil { + return nil, err + } + c, chans, reqs, err := NewClientConn(conn, addr, config) + if err != nil { + return nil, err + } + return NewClient(c, chans, reqs), nil +} + +// HostKeyCallback is the function type used for verifying server +// keys. A HostKeyCallback must return nil if the host key is OK, or +// an error to reject it. It receives the hostname as passed to Dial +// or NewClientConn. The remote address is the RemoteAddr of the +// net.Conn underlying the SSH connection. +type HostKeyCallback func(hostname string, remote net.Addr, key PublicKey) error + +// BannerCallback is the function type used for treat the banner sent by +// the server. A BannerCallback receives the message sent by the remote server. +type BannerCallback func(message string) error + +// A ClientConfig structure is used to configure a Client. It must not be +// modified after having been passed to an SSH function. +type ClientConfig struct { + // Config contains configuration that is shared between clients and + // servers. + Config + + // User contains the username to authenticate as. + User string + + // Auth contains possible authentication methods to use with the + // server. Only the first instance of a particular RFC 4252 method will + // be used during authentication. + Auth []AuthMethod + + // HostKeyCallback is called during the cryptographic + // handshake to validate the server's host key. The client + // configuration must supply this callback for the connection + // to succeed. The functions InsecureIgnoreHostKey or + // FixedHostKey can be used for simplistic host key checks. + HostKeyCallback HostKeyCallback + + // BannerCallback is called during the SSH dance to display a custom + // server's message. The client configuration can supply this callback to + // handle it as wished. The function BannerDisplayStderr can be used for + // simplistic display on Stderr. + BannerCallback BannerCallback + + // ClientVersion contains the version identification string that will + // be used for the connection. If empty, a reasonable default is used. + ClientVersion string + + // HostKeyAlgorithms lists the key types that the client will + // accept from the server as host key, in order of + // preference. If empty, a reasonable default is used. Any + // string returned from PublicKey.Type method may be used, or + // any of the CertAlgoXxxx and KeyAlgoXxxx constants. + HostKeyAlgorithms []string + + // Timeout is the maximum amount of time for the TCP connection to establish. + // + // A Timeout of zero means no timeout. + Timeout time.Duration +} + +// InsecureIgnoreHostKey returns a function that can be used for +// ClientConfig.HostKeyCallback to accept any host key. It should +// not be used for production code. +func InsecureIgnoreHostKey() HostKeyCallback { + return func(hostname string, remote net.Addr, key PublicKey) error { + return nil + } +} + +type fixedHostKey struct { + key PublicKey +} + +func (f *fixedHostKey) check(hostname string, remote net.Addr, key PublicKey) error { + if f.key == nil { + return fmt.Errorf("ssh: required host key was nil") + } + if !bytes.Equal(key.Marshal(), f.key.Marshal()) { + return fmt.Errorf("ssh: host key mismatch") + } + return nil +} + +// FixedHostKey returns a function for use in +// ClientConfig.HostKeyCallback to accept only a specific host key. +func FixedHostKey(key PublicKey) HostKeyCallback { + hk := &fixedHostKey{key} + return hk.check +} + +// BannerDisplayStderr returns a function that can be used for +// ClientConfig.BannerCallback to display banners on os.Stderr. +func BannerDisplayStderr() BannerCallback { + return func(banner string) error { + _, err := os.Stderr.WriteString(banner) + + return err + } +} diff --git a/internal/crypto/ssh/client_auth.go b/internal/crypto/ssh/client_auth.go new file mode 100644 index 000000000..f3265655e --- /dev/null +++ b/internal/crypto/ssh/client_auth.go @@ -0,0 +1,641 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ssh + +import ( + "bytes" + "errors" + "fmt" + "io" +) + +type authResult int + +const ( + authFailure authResult = iota + authPartialSuccess + authSuccess +) + +// clientAuthenticate authenticates with the remote server. See RFC 4252. +func (c *connection) clientAuthenticate(config *ClientConfig) error { + // initiate user auth session + if err := c.transport.writePacket(Marshal(&serviceRequestMsg{serviceUserAuth})); err != nil { + return err + } + packet, err := c.transport.readPacket() + if err != nil { + return err + } + var serviceAccept serviceAcceptMsg + if err := Unmarshal(packet, &serviceAccept); err != nil { + return err + } + + // during the authentication phase the client first attempts the "none" method + // then any untried methods suggested by the server. + var tried []string + var lastMethods []string + + sessionID := c.transport.getSessionID() + for auth := AuthMethod(new(noneAuth)); auth != nil; { + ok, methods, err := auth.auth(sessionID, config.User, c.transport, config.Rand) + if err != nil { + return err + } + if ok == authSuccess { + // success + return nil + } else if ok == authFailure { + if m := auth.method(); !contains(tried, m) { + tried = append(tried, m) + } + } + if methods == nil { + methods = lastMethods + } + lastMethods = methods + + auth = nil + + findNext: + for _, a := range config.Auth { + candidateMethod := a.method() + if contains(tried, candidateMethod) { + continue + } + for _, meth := range methods { + if meth == candidateMethod { + auth = a + break findNext + } + } + } + } + return fmt.Errorf("ssh: unable to authenticate, attempted methods %v, no supported methods remain", tried) +} + +func contains(list []string, e string) bool { + for _, s := range list { + if s == e { + return true + } + } + return false +} + +// An AuthMethod represents an instance of an RFC 4252 authentication method. +type AuthMethod interface { + // auth authenticates user over transport t. + // Returns true if authentication is successful. + // If authentication is not successful, a []string of alternative + // method names is returned. If the slice is nil, it will be ignored + // and the previous set of possible methods will be reused. + auth(session []byte, user string, p packetConn, rand io.Reader) (authResult, []string, error) + + // method returns the RFC 4252 method name. + method() string +} + +// "none" authentication, RFC 4252 section 5.2. +type noneAuth int + +func (n *noneAuth) auth(session []byte, user string, c packetConn, rand io.Reader) (authResult, []string, error) { + if err := c.writePacket(Marshal(&userAuthRequestMsg{ + User: user, + Service: serviceSSH, + Method: "none", + })); err != nil { + return authFailure, nil, err + } + + return handleAuthResponse(c) +} + +func (n *noneAuth) method() string { + return "none" +} + +// passwordCallback is an AuthMethod that fetches the password through +// a function call, e.g. by prompting the user. +type passwordCallback func() (password string, err error) + +func (cb passwordCallback) auth(session []byte, user string, c packetConn, rand io.Reader) (authResult, []string, error) { + type passwordAuthMsg struct { + User string `sshtype:"50"` + Service string + Method string + Reply bool + Password string + } + + pw, err := cb() + // REVIEW NOTE: is there a need to support skipping a password attempt? + // The program may only find out that the user doesn't have a password + // when prompting. + if err != nil { + return authFailure, nil, err + } + + if err := c.writePacket(Marshal(&passwordAuthMsg{ + User: user, + Service: serviceSSH, + Method: cb.method(), + Reply: false, + Password: pw, + })); err != nil { + return authFailure, nil, err + } + + return handleAuthResponse(c) +} + +func (cb passwordCallback) method() string { + return "password" +} + +// Password returns an AuthMethod using the given password. +func Password(secret string) AuthMethod { + return passwordCallback(func() (string, error) { return secret, nil }) +} + +// PasswordCallback returns an AuthMethod that uses a callback for +// fetching a password. +func PasswordCallback(prompt func() (secret string, err error)) AuthMethod { + return passwordCallback(prompt) +} + +type publickeyAuthMsg struct { + User string `sshtype:"50"` + Service string + Method string + // HasSig indicates to the receiver packet that the auth request is signed and + // should be used for authentication of the request. + HasSig bool + Algoname string + PubKey []byte + // Sig is tagged with "rest" so Marshal will exclude it during + // validateKey + Sig []byte `ssh:"rest"` +} + +// publicKeyCallback is an AuthMethod that uses a set of key +// pairs for authentication. +type publicKeyCallback func() ([]Signer, error) + +func (cb publicKeyCallback) method() string { + return "publickey" +} + +func (cb publicKeyCallback) auth(session []byte, user string, c packetConn, rand io.Reader) (authResult, []string, error) { + // Authentication is performed by sending an enquiry to test if a key is + // acceptable to the remote. If the key is acceptable, the client will + // attempt to authenticate with the valid key. If not the client will repeat + // the process with the remaining keys. + + signers, err := cb() + if err != nil { + return authFailure, nil, err + } + var methods []string + for _, signer := range signers { + ok, err := validateKey(signer.PublicKey(), user, c) + if err != nil { + return authFailure, nil, err + } + if !ok { + continue + } + + pub := signer.PublicKey() + pubKey := pub.Marshal() + sign, err := signer.Sign(rand, buildDataSignedForAuth(session, userAuthRequestMsg{ + User: user, + Service: serviceSSH, + Method: cb.method(), + }, []byte(pub.Type()), pubKey)) + if err != nil { + return authFailure, nil, err + } + + // manually wrap the serialized signature in a string + s := Marshal(sign) + sig := make([]byte, stringLength(len(s))) + marshalString(sig, s) + msg := publickeyAuthMsg{ + User: user, + Service: serviceSSH, + Method: cb.method(), + HasSig: true, + Algoname: pub.Type(), + PubKey: pubKey, + Sig: sig, + } + p := Marshal(&msg) + if err := c.writePacket(p); err != nil { + return authFailure, nil, err + } + var success authResult + success, methods, err = handleAuthResponse(c) + if err != nil { + return authFailure, nil, err + } + + // If authentication succeeds or the list of available methods does not + // contain the "publickey" method, do not attempt to authenticate with any + // other keys. According to RFC 4252 Section 7, the latter can occur when + // additional authentication methods are required. + if success == authSuccess || !containsMethod(methods, cb.method()) { + return success, methods, err + } + } + + return authFailure, methods, nil +} + +func containsMethod(methods []string, method string) bool { + for _, m := range methods { + if m == method { + return true + } + } + + return false +} + +// validateKey validates the key provided is acceptable to the server. +func validateKey(key PublicKey, user string, c packetConn) (bool, error) { + pubKey := key.Marshal() + msg := publickeyAuthMsg{ + User: user, + Service: serviceSSH, + Method: "publickey", + HasSig: false, + Algoname: key.Type(), + PubKey: pubKey, + } + if err := c.writePacket(Marshal(&msg)); err != nil { + return false, err + } + + return confirmKeyAck(key, c) +} + +func confirmKeyAck(key PublicKey, c packetConn) (bool, error) { + pubKey := key.Marshal() + algoname := key.Type() + + for { + packet, err := c.readPacket() + if err != nil { + return false, err + } + switch packet[0] { + case msgUserAuthBanner: + if err := handleBannerResponse(c, packet); err != nil { + return false, err + } + case msgUserAuthPubKeyOk: + var msg userAuthPubKeyOkMsg + if err := Unmarshal(packet, &msg); err != nil { + return false, err + } + if msg.Algo != algoname || !bytes.Equal(msg.PubKey, pubKey) { + return false, nil + } + return true, nil + case msgUserAuthFailure: + return false, nil + default: + return false, unexpectedMessageError(msgUserAuthSuccess, packet[0]) + } + } +} + +// PublicKeys returns an AuthMethod that uses the given key +// pairs. +func PublicKeys(signers ...Signer) AuthMethod { + return publicKeyCallback(func() ([]Signer, error) { return signers, nil }) +} + +// PublicKeysCallback returns an AuthMethod that runs the given +// function to obtain a list of key pairs. +func PublicKeysCallback(getSigners func() (signers []Signer, err error)) AuthMethod { + return publicKeyCallback(getSigners) +} + +// handleAuthResponse returns whether the preceding authentication request succeeded +// along with a list of remaining authentication methods to try next and +// an error if an unexpected response was received. +func handleAuthResponse(c packetConn) (authResult, []string, error) { + for { + packet, err := c.readPacket() + if err != nil { + return authFailure, nil, err + } + + switch packet[0] { + case msgUserAuthBanner: + if err := handleBannerResponse(c, packet); err != nil { + return authFailure, nil, err + } + case msgUserAuthFailure: + var msg userAuthFailureMsg + if err := Unmarshal(packet, &msg); err != nil { + return authFailure, nil, err + } + if msg.PartialSuccess { + return authPartialSuccess, msg.Methods, nil + } + return authFailure, msg.Methods, nil + case msgUserAuthSuccess: + return authSuccess, nil, nil + default: + return authFailure, nil, unexpectedMessageError(msgUserAuthSuccess, packet[0]) + } + } +} + +func handleBannerResponse(c packetConn, packet []byte) error { + var msg userAuthBannerMsg + if err := Unmarshal(packet, &msg); err != nil { + return err + } + + transport, ok := c.(*handshakeTransport) + if !ok { + return nil + } + + if transport.bannerCallback != nil { + return transport.bannerCallback(msg.Message) + } + + return nil +} + +// KeyboardInteractiveChallenge should print questions, optionally +// disabling echoing (e.g. for passwords), and return all the answers. +// Challenge may be called multiple times in a single session. After +// successful authentication, the server may send a challenge with no +// questions, for which the user and instruction messages should be +// printed. RFC 4256 section 3.3 details how the UI should behave for +// both CLI and GUI environments. +type KeyboardInteractiveChallenge func(user, instruction string, questions []string, echos []bool) (answers []string, err error) + +// KeyboardInteractive returns an AuthMethod using a prompt/response +// sequence controlled by the server. +func KeyboardInteractive(challenge KeyboardInteractiveChallenge) AuthMethod { + return challenge +} + +func (cb KeyboardInteractiveChallenge) method() string { + return "keyboard-interactive" +} + +func (cb KeyboardInteractiveChallenge) auth(session []byte, user string, c packetConn, rand io.Reader) (authResult, []string, error) { + type initiateMsg struct { + User string `sshtype:"50"` + Service string + Method string + Language string + Submethods string + } + + if err := c.writePacket(Marshal(&initiateMsg{ + User: user, + Service: serviceSSH, + Method: "keyboard-interactive", + })); err != nil { + return authFailure, nil, err + } + + for { + packet, err := c.readPacket() + if err != nil { + return authFailure, nil, err + } + + // like handleAuthResponse, but with less options. + switch packet[0] { + case msgUserAuthBanner: + if err := handleBannerResponse(c, packet); err != nil { + return authFailure, nil, err + } + continue + case msgUserAuthInfoRequest: + // OK + case msgUserAuthFailure: + var msg userAuthFailureMsg + if err := Unmarshal(packet, &msg); err != nil { + return authFailure, nil, err + } + if msg.PartialSuccess { + return authPartialSuccess, msg.Methods, nil + } + return authFailure, msg.Methods, nil + case msgUserAuthSuccess: + return authSuccess, nil, nil + default: + return authFailure, nil, unexpectedMessageError(msgUserAuthInfoRequest, packet[0]) + } + + var msg userAuthInfoRequestMsg + if err := Unmarshal(packet, &msg); err != nil { + return authFailure, nil, err + } + + // Manually unpack the prompt/echo pairs. + rest := msg.Prompts + var prompts []string + var echos []bool + for i := 0; i < int(msg.NumPrompts); i++ { + prompt, r, ok := parseString(rest) + if !ok || len(r) == 0 { + return authFailure, nil, errors.New("ssh: prompt format error") + } + prompts = append(prompts, string(prompt)) + echos = append(echos, r[0] != 0) + rest = r[1:] + } + + if len(rest) != 0 { + return authFailure, nil, errors.New("ssh: extra data following keyboard-interactive pairs") + } + + answers, err := cb(msg.User, msg.Instruction, prompts, echos) + if err != nil { + return authFailure, nil, err + } + + if len(answers) != len(prompts) { + return authFailure, nil, errors.New("ssh: not enough answers from keyboard-interactive callback") + } + responseLength := 1 + 4 + for _, a := range answers { + responseLength += stringLength(len(a)) + } + serialized := make([]byte, responseLength) + p := serialized + p[0] = msgUserAuthInfoResponse + p = p[1:] + p = marshalUint32(p, uint32(len(answers))) + for _, a := range answers { + p = marshalString(p, []byte(a)) + } + + if err := c.writePacket(serialized); err != nil { + return authFailure, nil, err + } + } +} + +type retryableAuthMethod struct { + authMethod AuthMethod + maxTries int +} + +func (r *retryableAuthMethod) auth(session []byte, user string, c packetConn, rand io.Reader) (ok authResult, methods []string, err error) { + for i := 0; r.maxTries <= 0 || i < r.maxTries; i++ { + ok, methods, err = r.authMethod.auth(session, user, c, rand) + if ok != authFailure || err != nil { // either success, partial success or error terminate + return ok, methods, err + } + } + return ok, methods, err +} + +func (r *retryableAuthMethod) method() string { + return r.authMethod.method() +} + +// RetryableAuthMethod is a decorator for other auth methods enabling them to +// be retried up to maxTries before considering that AuthMethod itself failed. +// If maxTries is <= 0, will retry indefinitely +// +// This is useful for interactive clients using challenge/response type +// authentication (e.g. Keyboard-Interactive, Password, etc) where the user +// could mistype their response resulting in the server issuing a +// SSH_MSG_USERAUTH_FAILURE (rfc4252 #8 [password] and rfc4256 #3.4 +// [keyboard-interactive]); Without this decorator, the non-retryable +// AuthMethod would be removed from future consideration, and never tried again +// (and so the user would never be able to retry their entry). +func RetryableAuthMethod(auth AuthMethod, maxTries int) AuthMethod { + return &retryableAuthMethod{authMethod: auth, maxTries: maxTries} +} + +// GSSAPIWithMICAuthMethod is an AuthMethod with "gssapi-with-mic" authentication. +// See RFC 4462 section 3 +// gssAPIClient is implementation of the GSSAPIClient interface, see the definition of the interface for details. +// target is the server host you want to log in to. +func GSSAPIWithMICAuthMethod(gssAPIClient GSSAPIClient, target string) AuthMethod { + if gssAPIClient == nil { + panic("gss-api client must be not nil with enable gssapi-with-mic") + } + return &gssAPIWithMICCallback{gssAPIClient: gssAPIClient, target: target} +} + +type gssAPIWithMICCallback struct { + gssAPIClient GSSAPIClient + target string +} + +func (g *gssAPIWithMICCallback) auth(session []byte, user string, c packetConn, rand io.Reader) (authResult, []string, error) { + m := &userAuthRequestMsg{ + User: user, + Service: serviceSSH, + Method: g.method(), + } + // The GSS-API authentication method is initiated when the client sends an SSH_MSG_USERAUTH_REQUEST. + // See RFC 4462 section 3.2. + m.Payload = appendU32(m.Payload, 1) + m.Payload = appendString(m.Payload, string(krb5OID)) + if err := c.writePacket(Marshal(m)); err != nil { + return authFailure, nil, err + } + // The server responds to the SSH_MSG_USERAUTH_REQUEST with either an + // SSH_MSG_USERAUTH_FAILURE if none of the mechanisms are supported or + // with an SSH_MSG_USERAUTH_GSSAPI_RESPONSE. + // See RFC 4462 section 3.3. + // OpenSSH supports Kerberos V5 mechanism only for GSS-API authentication,so I don't want to check + // selected mech if it is valid. + packet, err := c.readPacket() + if err != nil { + return authFailure, nil, err + } + userAuthGSSAPIResp := &userAuthGSSAPIResponse{} + if err := Unmarshal(packet, userAuthGSSAPIResp); err != nil { + return authFailure, nil, err + } + // Start the loop into the exchange token. + // See RFC 4462 section 3.4. + var token []byte + defer g.gssAPIClient.DeleteSecContext() + for { + // Initiates the establishment of a security context between the application and a remote peer. + nextToken, needContinue, err := g.gssAPIClient.InitSecContext("host@"+g.target, token, false) + if err != nil { + return authFailure, nil, err + } + if len(nextToken) > 0 { + if err := c.writePacket(Marshal(&userAuthGSSAPIToken{ + Token: nextToken, + })); err != nil { + return authFailure, nil, err + } + } + if !needContinue { + break + } + packet, err = c.readPacket() + if err != nil { + return authFailure, nil, err + } + switch packet[0] { + case msgUserAuthFailure: + var msg userAuthFailureMsg + if err := Unmarshal(packet, &msg); err != nil { + return authFailure, nil, err + } + if msg.PartialSuccess { + return authPartialSuccess, msg.Methods, nil + } + return authFailure, msg.Methods, nil + case msgUserAuthGSSAPIError: + userAuthGSSAPIErrorResp := &userAuthGSSAPIError{} + if err := Unmarshal(packet, userAuthGSSAPIErrorResp); err != nil { + return authFailure, nil, err + } + return authFailure, nil, fmt.Errorf("GSS-API Error:\n"+ + "Major Status: %d\n"+ + "Minor Status: %d\n"+ + "Error Message: %s\n", userAuthGSSAPIErrorResp.MajorStatus, userAuthGSSAPIErrorResp.MinorStatus, + userAuthGSSAPIErrorResp.Message) + case msgUserAuthGSSAPIToken: + userAuthGSSAPITokenReq := &userAuthGSSAPIToken{} + if err := Unmarshal(packet, userAuthGSSAPITokenReq); err != nil { + return authFailure, nil, err + } + token = userAuthGSSAPITokenReq.Token + } + } + // Binding Encryption Keys. + // See RFC 4462 section 3.5. + micField := buildMIC(string(session), user, "ssh-connection", "gssapi-with-mic") + micToken, err := g.gssAPIClient.GetMIC(micField) + if err != nil { + return authFailure, nil, err + } + if err := c.writePacket(Marshal(&userAuthGSSAPIMIC{ + MIC: micToken, + })); err != nil { + return authFailure, nil, err + } + return handleAuthResponse(c) +} + +func (g *gssAPIWithMICCallback) method() string { + return "gssapi-with-mic" +} diff --git a/internal/crypto/ssh/common.go b/internal/crypto/ssh/common.go new file mode 100644 index 000000000..f8ce35db2 --- /dev/null +++ b/internal/crypto/ssh/common.go @@ -0,0 +1,408 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ssh + +import ( + "crypto" + "crypto/rand" + "fmt" + "io" + "math" + "sync" + + _ "crypto/sha1" + _ "crypto/sha256" + _ "crypto/sha512" +) + +// These are string constants in the SSH protocol. +const ( + compressionNone = "none" + serviceUserAuth = "ssh-userauth" + serviceSSH = "ssh-connection" +) + +// supportedCiphers lists ciphers we support but might not recommend. +var supportedCiphers = []string{ + "aes128-ctr", "aes192-ctr", "aes256-ctr", + "aes128-gcm@openssh.com", + chacha20Poly1305ID, + "arcfour256", "arcfour128", "arcfour", + aes128cbcID, + tripledescbcID, +} + +// preferredCiphers specifies the default preference for ciphers. +var preferredCiphers = []string{ + "aes128-gcm@openssh.com", + chacha20Poly1305ID, + "aes128-ctr", "aes192-ctr", "aes256-ctr", +} + +// supportedKexAlgos specifies the supported key-exchange algorithms in +// preference order. +var supportedKexAlgos = []string{ + kexAlgoCurve25519SHA256, + // P384 and P521 are not constant-time yet, but since we don't + // reuse ephemeral keys, using them for ECDH should be OK. + kexAlgoECDH256, kexAlgoECDH384, kexAlgoECDH521, + kexAlgoDH14SHA1, kexAlgoDH1SHA1, +} + +// serverForbiddenKexAlgos contains key exchange algorithms, that are forbidden +// for the server half. +var serverForbiddenKexAlgos = map[string]struct{}{ + kexAlgoDHGEXSHA1: {}, // server half implementation is only minimal to satisfy the automated tests + kexAlgoDHGEXSHA256: {}, // server half implementation is only minimal to satisfy the automated tests +} + +// preferredKexAlgos specifies the default preference for key-exchange algorithms +// in preference order. +var preferredKexAlgos = []string{ + kexAlgoCurve25519SHA256, + kexAlgoECDH256, kexAlgoECDH384, kexAlgoECDH521, + kexAlgoDH14SHA1, +} + +// supportedHostKeyAlgos specifies the supported host-key algorithms (i.e. methods +// of authenticating servers) in preference order. +var supportedHostKeyAlgos = []string{ + CertAlgoRSASHA2512v01, CertAlgoRSASHA2256v01, CertAlgoRSAv01, + CertAlgoDSAv01, CertAlgoECDSA256v01, + CertAlgoECDSA384v01, CertAlgoECDSA521v01, CertAlgoED25519v01, + + KeyAlgoECDSA256, KeyAlgoECDSA384, KeyAlgoECDSA521, + KeyAlgoRSASHA2512, KeyAlgoRSASHA2256, + KeyAlgoRSA, KeyAlgoDSA, + + KeyAlgoED25519, +} + +// supportedMACs specifies a default set of MAC algorithms in preference order. +// This is based on RFC 4253, section 6.4, but with hmac-md5 variants removed +// because they have reached the end of their useful life. +var supportedMACs = []string{ + "hmac-sha2-256-etm@openssh.com", "hmac-sha2-256", "hmac-sha1", "hmac-sha1-96", +} + +var supportedCompressions = []string{compressionNone} + +// hashFuncs keeps the mapping of supported algorithms to their respective +// hashes needed for signature verification. +var hashFuncs = map[string]crypto.Hash{ + KeyAlgoRSA: crypto.SHA1, + KeyAlgoRSASHA2256: crypto.SHA256, + KeyAlgoRSASHA2512: crypto.SHA512, + KeyAlgoDSA: crypto.SHA1, + KeyAlgoECDSA256: crypto.SHA256, + KeyAlgoECDSA384: crypto.SHA384, + KeyAlgoECDSA521: crypto.SHA512, + CertAlgoRSAv01: crypto.SHA1, + CertAlgoDSAv01: crypto.SHA1, + CertAlgoECDSA256v01: crypto.SHA256, + CertAlgoECDSA384v01: crypto.SHA384, + CertAlgoECDSA521v01: crypto.SHA512, +} + +// unexpectedMessageError results when the SSH message that we received didn't +// match what we wanted. +func unexpectedMessageError(expected, got uint8) error { + return fmt.Errorf("ssh: unexpected message type %d (expected %d)", got, expected) +} + +// parseError results from a malformed SSH message. +func parseError(tag uint8) error { + return fmt.Errorf("ssh: parse error in message type %d", tag) +} + +func findCommon(what string, client []string, server []string) (common string, err error) { + for _, c := range client { + for _, s := range server { + if c == s { + return c, nil + } + } + } + return "", fmt.Errorf("ssh: no common algorithm for %s; client offered: %v, server offered: %v", what, client, server) +} + +// directionAlgorithms records algorithm choices in one direction (either read or write) +type directionAlgorithms struct { + Cipher string + MAC string + Compression string +} + +// rekeyBytes returns a rekeying intervals in bytes. +func (a *directionAlgorithms) rekeyBytes() int64 { + // According to RFC4344 block ciphers should rekey after + // 2^(BLOCKSIZE/4) blocks. For all AES flavors BLOCKSIZE is + // 128. + switch a.Cipher { + case "aes128-ctr", "aes192-ctr", "aes256-ctr", gcmCipherID, aes128cbcID: + return 16 * (1 << 32) + + } + + // For others, stick with RFC4253 recommendation to rekey after 1 Gb of data. + return 1 << 30 +} + +type algorithms struct { + kex string + hostKey string + w directionAlgorithms + r directionAlgorithms +} + +func findAgreedAlgorithms(isClient bool, clientKexInit, serverKexInit *kexInitMsg) (algs *algorithms, err error) { + result := &algorithms{} + + result.kex, err = findCommon("key exchange", clientKexInit.KexAlgos, serverKexInit.KexAlgos) + if err != nil { + return + } + + result.hostKey, err = findCommon("host key", clientKexInit.ServerHostKeyAlgos, serverKexInit.ServerHostKeyAlgos) + if err != nil { + return + } + + stoc, ctos := &result.w, &result.r + if isClient { + ctos, stoc = stoc, ctos + } + + ctos.Cipher, err = findCommon("client to server cipher", clientKexInit.CiphersClientServer, serverKexInit.CiphersClientServer) + if err != nil { + return + } + + stoc.Cipher, err = findCommon("server to client cipher", clientKexInit.CiphersServerClient, serverKexInit.CiphersServerClient) + if err != nil { + return + } + + ctos.MAC, err = findCommon("client to server MAC", clientKexInit.MACsClientServer, serverKexInit.MACsClientServer) + if err != nil { + return + } + + stoc.MAC, err = findCommon("server to client MAC", clientKexInit.MACsServerClient, serverKexInit.MACsServerClient) + if err != nil { + return + } + + ctos.Compression, err = findCommon("client to server compression", clientKexInit.CompressionClientServer, serverKexInit.CompressionClientServer) + if err != nil { + return + } + + stoc.Compression, err = findCommon("server to client compression", clientKexInit.CompressionServerClient, serverKexInit.CompressionServerClient) + if err != nil { + return + } + + return result, nil +} + +// If rekeythreshold is too small, we can't make any progress sending +// stuff. +const minRekeyThreshold uint64 = 256 + +// Config contains configuration data common to both ServerConfig and +// ClientConfig. +type Config struct { + // Rand provides the source of entropy for cryptographic + // primitives. If Rand is nil, the cryptographic random reader + // in package crypto/rand will be used. + Rand io.Reader + + // The maximum number of bytes sent or received after which a + // new key is negotiated. It must be at least 256. If + // unspecified, a size suitable for the chosen cipher is used. + RekeyThreshold uint64 + + // The allowed key exchanges algorithms. If unspecified then a + // default set of algorithms is used. + KeyExchanges []string + + // The allowed cipher algorithms. If unspecified then a sensible + // default is used. + Ciphers []string + + // The allowed MAC algorithms. If unspecified then a sensible default + // is used. + MACs []string +} + +// SetDefaults sets sensible values for unset fields in config. This is +// exported for testing: Configs passed to SSH functions are copied and have +// default values set automatically. +func (c *Config) SetDefaults() { + if c.Rand == nil { + c.Rand = rand.Reader + } + if c.Ciphers == nil { + c.Ciphers = preferredCiphers + } + var ciphers []string + for _, c := range c.Ciphers { + if cipherModes[c] != nil { + // reject the cipher if we have no cipherModes definition + ciphers = append(ciphers, c) + } + } + c.Ciphers = ciphers + + if c.KeyExchanges == nil { + c.KeyExchanges = preferredKexAlgos + } + + if c.MACs == nil { + c.MACs = supportedMACs + } + + if c.RekeyThreshold == 0 { + // cipher specific default + } else if c.RekeyThreshold < minRekeyThreshold { + c.RekeyThreshold = minRekeyThreshold + } else if c.RekeyThreshold >= math.MaxInt64 { + // Avoid weirdness if somebody uses -1 as a threshold. + c.RekeyThreshold = math.MaxInt64 + } +} + +// buildDataSignedForAuth returns the data that is signed in order to prove +// possession of a private key. See RFC 4252, section 7. +func buildDataSignedForAuth(sessionID []byte, req userAuthRequestMsg, algo, pubKey []byte) []byte { + data := struct { + Session []byte + Type byte + User string + Service string + Method string + Sign bool + Algo []byte + PubKey []byte + }{ + sessionID, + msgUserAuthRequest, + req.User, + req.Service, + req.Method, + true, + algo, + pubKey, + } + return Marshal(data) +} + +func appendU16(buf []byte, n uint16) []byte { + return append(buf, byte(n>>8), byte(n)) +} + +func appendU32(buf []byte, n uint32) []byte { + return append(buf, byte(n>>24), byte(n>>16), byte(n>>8), byte(n)) +} + +func appendU64(buf []byte, n uint64) []byte { + return append(buf, + byte(n>>56), byte(n>>48), byte(n>>40), byte(n>>32), + byte(n>>24), byte(n>>16), byte(n>>8), byte(n)) +} + +func appendInt(buf []byte, n int) []byte { + return appendU32(buf, uint32(n)) +} + +func appendString(buf []byte, s string) []byte { + buf = appendU32(buf, uint32(len(s))) + buf = append(buf, s...) + return buf +} + +func appendBool(buf []byte, b bool) []byte { + if b { + return append(buf, 1) + } + return append(buf, 0) +} + +// newCond is a helper to hide the fact that there is no usable zero +// value for sync.Cond. +func newCond() *sync.Cond { return sync.NewCond(new(sync.Mutex)) } + +// window represents the buffer available to clients +// wishing to write to a channel. +type window struct { + *sync.Cond + win uint32 // RFC 4254 5.2 says the window size can grow to 2^32-1 + writeWaiters int + closed bool +} + +// add adds win to the amount of window available +// for consumers. +func (w *window) add(win uint32) bool { + // a zero sized window adjust is a noop. + if win == 0 { + return true + } + w.L.Lock() + if w.win+win < win { + w.L.Unlock() + return false + } + w.win += win + // It is unusual that multiple goroutines would be attempting to reserve + // window space, but not guaranteed. Use broadcast to notify all waiters + // that additional window is available. + w.Broadcast() + w.L.Unlock() + return true +} + +// close sets the window to closed, so all reservations fail +// immediately. +func (w *window) close() { + w.L.Lock() + w.closed = true + w.Broadcast() + w.L.Unlock() +} + +// reserve reserves win from the available window capacity. +// If no capacity remains, reserve will block. reserve may +// return less than requested. +func (w *window) reserve(win uint32) (uint32, error) { + var err error + w.L.Lock() + w.writeWaiters++ + w.Broadcast() + for w.win == 0 && !w.closed { + w.Wait() + } + w.writeWaiters-- + if w.win < win { + win = w.win + } + w.win -= win + if w.closed { + err = io.EOF + } + w.L.Unlock() + return win, err +} + +// waitWriterBlocked waits until some goroutine is blocked for further +// writes. It is used in tests only. +func (w *window) waitWriterBlocked() { + w.Cond.L.Lock() + for w.writeWaiters == 0 { + w.Cond.Wait() + } + w.Cond.L.Unlock() +} diff --git a/internal/crypto/ssh/connection.go b/internal/crypto/ssh/connection.go new file mode 100644 index 000000000..fd6b0681b --- /dev/null +++ b/internal/crypto/ssh/connection.go @@ -0,0 +1,143 @@ +// Copyright 2013 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ssh + +import ( + "fmt" + "net" +) + +// OpenChannelError is returned if the other side rejects an +// OpenChannel request. +type OpenChannelError struct { + Reason RejectionReason + Message string +} + +func (e *OpenChannelError) Error() string { + return fmt.Sprintf("ssh: rejected: %s (%s)", e.Reason, e.Message) +} + +// ConnMetadata holds metadata for the connection. +type ConnMetadata interface { + // User returns the user ID for this connection. + User() string + + // SessionID returns the session hash, also denoted by H. + SessionID() []byte + + // ClientVersion returns the client's version string as hashed + // into the session ID. + ClientVersion() []byte + + // ServerVersion returns the server's version string as hashed + // into the session ID. + ServerVersion() []byte + + // RemoteAddr returns the remote address for this connection. + RemoteAddr() net.Addr + + // LocalAddr returns the local address for this connection. + LocalAddr() net.Addr +} + +// Conn represents an SSH connection for both server and client roles. +// Conn is the basis for implementing an application layer, such +// as ClientConn, which implements the traditional shell access for +// clients. +type Conn interface { + ConnMetadata + + // SendRequest sends a global request, and returns the + // reply. If wantReply is true, it returns the response status + // and payload. See also RFC4254, section 4. + SendRequest(name string, wantReply bool, payload []byte) (bool, []byte, error) + + // OpenChannel tries to open an channel. If the request is + // rejected, it returns *OpenChannelError. On success it returns + // the SSH Channel and a Go channel for incoming, out-of-band + // requests. The Go channel must be serviced, or the + // connection will hang. + OpenChannel(name string, data []byte) (Channel, <-chan *Request, error) + + // Close closes the underlying network connection + Close() error + + // Wait blocks until the connection has shut down, and returns the + // error causing the shutdown. + Wait() error + + // TODO(hanwen): consider exposing: + // RequestKeyChange + // Disconnect +} + +// DiscardRequests consumes and rejects all requests from the +// passed-in channel. +func DiscardRequests(in <-chan *Request) { + for req := range in { + if req.WantReply { + req.Reply(false, nil) + } + } +} + +// A connection represents an incoming connection. +type connection struct { + transport *handshakeTransport + sshConn + + // The connection protocol. + *mux +} + +func (c *connection) Close() error { + return c.sshConn.conn.Close() +} + +// sshconn provides net.Conn metadata, but disallows direct reads and +// writes. +type sshConn struct { + conn net.Conn + + user string + sessionID []byte + clientVersion []byte + serverVersion []byte +} + +func dup(src []byte) []byte { + dst := make([]byte, len(src)) + copy(dst, src) + return dst +} + +func (c *sshConn) User() string { + return c.user +} + +func (c *sshConn) RemoteAddr() net.Addr { + return c.conn.RemoteAddr() +} + +func (c *sshConn) Close() error { + return c.conn.Close() +} + +func (c *sshConn) LocalAddr() net.Addr { + return c.conn.LocalAddr() +} + +func (c *sshConn) SessionID() []byte { + return dup(c.sessionID) +} + +func (c *sshConn) ClientVersion() []byte { + return dup(c.clientVersion) +} + +func (c *sshConn) ServerVersion() []byte { + return dup(c.serverVersion) +} diff --git a/internal/crypto/ssh/doc.go b/internal/crypto/ssh/doc.go new file mode 100644 index 000000000..67b7322c0 --- /dev/null +++ b/internal/crypto/ssh/doc.go @@ -0,0 +1,21 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +/* +Package ssh implements an SSH client and server. + +SSH is a transport security protocol, an authentication protocol and a +family of application protocols. The most typical application level +protocol is a remote shell and this is specifically implemented. However, +the multiplexed nature of SSH is exposed to users that wish to support +others. + +References: + [PROTOCOL.certkeys]: http://cvsweb.openbsd.org/cgi-bin/cvsweb/src/usr.bin/ssh/PROTOCOL.certkeys?rev=HEAD + [SSH-PARAMETERS]: http://www.iana.org/assignments/ssh-parameters/ssh-parameters.xml#ssh-parameters-1 + +This package does not fall under the stability promise of the Go language itself, +so its API may be changed when pressing needs arise. +*/ +package ssh // import "golang.org/x/crypto/ssh" diff --git a/internal/crypto/ssh/handshake.go b/internal/crypto/ssh/handshake.go new file mode 100644 index 000000000..48424d1cc --- /dev/null +++ b/internal/crypto/ssh/handshake.go @@ -0,0 +1,646 @@ +// Copyright 2013 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ssh + +import ( + "crypto/rand" + "errors" + "fmt" + "io" + "log" + "net" + "sync" +) + +// debugHandshake, if set, prints messages sent and received. Key +// exchange messages are printed as if DH were used, so the debug +// messages are wrong when using ECDH. +const debugHandshake = false + +// chanSize sets the amount of buffering SSH connections. This is +// primarily for testing: setting chanSize=0 uncovers deadlocks more +// quickly. +const chanSize = 16 + +// keyingTransport is a packet based transport that supports key +// changes. It need not be thread-safe. It should pass through +// msgNewKeys in both directions. +type keyingTransport interface { + packetConn + + // prepareKeyChange sets up a key change. The key change for a + // direction will be effected if a msgNewKeys message is sent + // or received. + prepareKeyChange(*algorithms, *kexResult) error +} + +// handshakeTransport implements rekeying on top of a keyingTransport +// and offers a thread-safe writePacket() interface. +type handshakeTransport struct { + conn keyingTransport + config *Config + + serverVersion []byte + clientVersion []byte + + // hostKeys is non-empty if we are the server. In that case, + // it contains all host keys that can be used to sign the + // connection. + hostKeys map[string]Signer + + // hostKeyAlgorithms is non-empty if we are the client. In that case, + // we accept these key types from the server as host key. + hostKeyAlgorithms []string + + // On read error, incoming is closed, and readError is set. + incoming chan []byte + readError error + + mu sync.Mutex + writeError error + sentInitPacket []byte + sentInitMsg *kexInitMsg + pendingPackets [][]byte // Used when a key exchange is in progress. + + // If the read loop wants to schedule a kex, it pings this + // channel, and the write loop will send out a kex + // message. + requestKex chan struct{} + + // If the other side requests or confirms a kex, its kexInit + // packet is sent here for the write loop to find it. + startKex chan *pendingKex + + // data for host key checking + hostKeyCallback HostKeyCallback + dialAddress string + remoteAddr net.Addr + + // bannerCallback is non-empty if we are the client and it has been set in + // ClientConfig. In that case it is called during the user authentication + // dance to handle a custom server's message. + bannerCallback BannerCallback + + // Algorithms agreed in the last key exchange. + algorithms *algorithms + + readPacketsLeft uint32 + readBytesLeft int64 + + writePacketsLeft uint32 + writeBytesLeft int64 + + // The session ID or nil if first kex did not complete yet. + sessionID []byte +} + +type pendingKex struct { + otherInit []byte + done chan error +} + +func newHandshakeTransport(conn keyingTransport, config *Config, clientVersion, serverVersion []byte) *handshakeTransport { + t := &handshakeTransport{ + conn: conn, + serverVersion: serverVersion, + clientVersion: clientVersion, + incoming: make(chan []byte, chanSize), + requestKex: make(chan struct{}, 1), + startKex: make(chan *pendingKex, 1), + + config: config, + } + t.resetReadThresholds() + t.resetWriteThresholds() + + // We always start with a mandatory key exchange. + t.requestKex <- struct{}{} + return t +} + +func newClientTransport(conn keyingTransport, clientVersion, serverVersion []byte, config *ClientConfig, dialAddr string, addr net.Addr) *handshakeTransport { + t := newHandshakeTransport(conn, &config.Config, clientVersion, serverVersion) + t.dialAddress = dialAddr + t.remoteAddr = addr + t.hostKeyCallback = config.HostKeyCallback + t.bannerCallback = config.BannerCallback + if config.HostKeyAlgorithms != nil { + t.hostKeyAlgorithms = config.HostKeyAlgorithms + } else { + t.hostKeyAlgorithms = supportedHostKeyAlgos + } + go t.readLoop() + go t.kexLoop() + return t +} + +func newServerTransport(conn keyingTransport, clientVersion, serverVersion []byte, config *ServerConfig) *handshakeTransport { + t := newHandshakeTransport(conn, &config.Config, clientVersion, serverVersion) + t.hostKeys = config.hostKeys + go t.readLoop() + go t.kexLoop() + return t +} + +func (t *handshakeTransport) getSessionID() []byte { + return t.sessionID +} + +// waitSession waits for the session to be established. This should be +// the first thing to call after instantiating handshakeTransport. +func (t *handshakeTransport) waitSession() error { + p, err := t.readPacket() + if err != nil { + return err + } + if p[0] != msgNewKeys { + return fmt.Errorf("ssh: first packet should be msgNewKeys") + } + + return nil +} + +func (t *handshakeTransport) id() string { + if len(t.hostKeys) > 0 { + return "server" + } + return "client" +} + +func (t *handshakeTransport) printPacket(p []byte, write bool) { + action := "got" + if write { + action = "sent" + } + + if p[0] == msgChannelData || p[0] == msgChannelExtendedData { + log.Printf("%s %s data (packet %d bytes)", t.id(), action, len(p)) + } else { + msg, err := decode(p) + log.Printf("%s %s %T %+v (%+v)", t.id(), action, msg, msg, err) + } +} + +func (t *handshakeTransport) readPacket() ([]byte, error) { + p, ok := <-t.incoming + if !ok { + return nil, t.readError + } + return p, nil +} + +func (t *handshakeTransport) readLoop() { + first := true + for { + p, err := t.readOnePacket(first) + first = false + if err != nil { + t.readError = err + close(t.incoming) + break + } + if p[0] == msgIgnore || p[0] == msgDebug { + continue + } + t.incoming <- p + } + + // Stop writers too. + t.recordWriteError(t.readError) + + // Unblock the writer should it wait for this. + close(t.startKex) + + // Don't close t.requestKex; it's also written to from writePacket. +} + +func (t *handshakeTransport) pushPacket(p []byte) error { + if debugHandshake { + t.printPacket(p, true) + } + return t.conn.writePacket(p) +} + +func (t *handshakeTransport) getWriteError() error { + t.mu.Lock() + defer t.mu.Unlock() + return t.writeError +} + +func (t *handshakeTransport) recordWriteError(err error) { + t.mu.Lock() + defer t.mu.Unlock() + if t.writeError == nil && err != nil { + t.writeError = err + } +} + +func (t *handshakeTransport) requestKeyExchange() { + select { + case t.requestKex <- struct{}{}: + default: + // something already requested a kex, so do nothing. + } +} + +func (t *handshakeTransport) resetWriteThresholds() { + t.writePacketsLeft = packetRekeyThreshold + if t.config.RekeyThreshold > 0 { + t.writeBytesLeft = int64(t.config.RekeyThreshold) + } else if t.algorithms != nil { + t.writeBytesLeft = t.algorithms.w.rekeyBytes() + } else { + t.writeBytesLeft = 1 << 30 + } +} + +func (t *handshakeTransport) kexLoop() { + +write: + for t.getWriteError() == nil { + var request *pendingKex + var sent bool + + for request == nil || !sent { + var ok bool + select { + case request, ok = <-t.startKex: + if !ok { + break write + } + case <-t.requestKex: + break + } + + if !sent { + if err := t.sendKexInit(); err != nil { + t.recordWriteError(err) + break + } + sent = true + } + } + + if err := t.getWriteError(); err != nil { + if request != nil { + request.done <- err + } + break + } + + // We're not servicing t.requestKex, but that is OK: + // we never block on sending to t.requestKex. + + // We're not servicing t.startKex, but the remote end + // has just sent us a kexInitMsg, so it can't send + // another key change request, until we close the done + // channel on the pendingKex request. + + err := t.enterKeyExchange(request.otherInit) + + t.mu.Lock() + t.writeError = err + t.sentInitPacket = nil + t.sentInitMsg = nil + + t.resetWriteThresholds() + + // we have completed the key exchange. Since the + // reader is still blocked, it is safe to clear out + // the requestKex channel. This avoids the situation + // where: 1) we consumed our own request for the + // initial kex, and 2) the kex from the remote side + // caused another send on the requestKex channel, + clear: + for { + select { + case <-t.requestKex: + // + default: + break clear + } + } + + request.done <- t.writeError + + // kex finished. Push packets that we received while + // the kex was in progress. Don't look at t.startKex + // and don't increment writtenSinceKex: if we trigger + // another kex while we are still busy with the last + // one, things will become very confusing. + for _, p := range t.pendingPackets { + t.writeError = t.pushPacket(p) + if t.writeError != nil { + break + } + } + t.pendingPackets = t.pendingPackets[:0] + t.mu.Unlock() + } + + // drain startKex channel. We don't service t.requestKex + // because nobody does blocking sends there. + go func() { + for init := range t.startKex { + init.done <- t.writeError + } + }() + + // Unblock reader. + t.conn.Close() +} + +// The protocol uses uint32 for packet counters, so we can't let them +// reach 1<<32. We will actually read and write more packets than +// this, though: the other side may send more packets, and after we +// hit this limit on writing we will send a few more packets for the +// key exchange itself. +const packetRekeyThreshold = (1 << 31) + +func (t *handshakeTransport) resetReadThresholds() { + t.readPacketsLeft = packetRekeyThreshold + if t.config.RekeyThreshold > 0 { + t.readBytesLeft = int64(t.config.RekeyThreshold) + } else if t.algorithms != nil { + t.readBytesLeft = t.algorithms.r.rekeyBytes() + } else { + t.readBytesLeft = 1 << 30 + } +} + +func (t *handshakeTransport) readOnePacket(first bool) ([]byte, error) { + p, err := t.conn.readPacket() + if err != nil { + return nil, err + } + + if t.readPacketsLeft > 0 { + t.readPacketsLeft-- + } else { + t.requestKeyExchange() + } + + if t.readBytesLeft > 0 { + t.readBytesLeft -= int64(len(p)) + } else { + t.requestKeyExchange() + } + + if debugHandshake { + t.printPacket(p, false) + } + + if first && p[0] != msgKexInit { + return nil, fmt.Errorf("ssh: first packet should be msgKexInit") + } + + if p[0] != msgKexInit { + return p, nil + } + + firstKex := t.sessionID == nil + + kex := pendingKex{ + done: make(chan error, 1), + otherInit: p, + } + t.startKex <- &kex + err = <-kex.done + + if debugHandshake { + log.Printf("%s exited key exchange (first %v), err %v", t.id(), firstKex, err) + } + + if err != nil { + return nil, err + } + + t.resetReadThresholds() + + // By default, a key exchange is hidden from higher layers by + // translating it into msgIgnore. + successPacket := []byte{msgIgnore} + if firstKex { + // sendKexInit() for the first kex waits for + // msgNewKeys so the authentication process is + // guaranteed to happen over an encrypted transport. + successPacket = []byte{msgNewKeys} + } + + return successPacket, nil +} + +// sendKexInit sends a key change message. +func (t *handshakeTransport) sendKexInit() error { + t.mu.Lock() + defer t.mu.Unlock() + if t.sentInitMsg != nil { + // kexInits may be sent either in response to the other side, + // or because our side wants to initiate a key change, so we + // may have already sent a kexInit. In that case, don't send a + // second kexInit. + return nil + } + + msg := &kexInitMsg{ + KexAlgos: t.config.KeyExchanges, + CiphersClientServer: t.config.Ciphers, + CiphersServerClient: t.config.Ciphers, + MACsClientServer: t.config.MACs, + MACsServerClient: t.config.MACs, + CompressionClientServer: supportedCompressions, + CompressionServerClient: supportedCompressions, + } + io.ReadFull(rand.Reader, msg.Cookie[:]) + + if len(t.hostKeys) > 0 { + for alg, _ := range t.hostKeys { + msg.ServerHostKeyAlgos = append(msg.ServerHostKeyAlgos, alg) + } + } else { + msg.ServerHostKeyAlgos = t.hostKeyAlgorithms + } + packet := Marshal(msg) + + // writePacket destroys the contents, so save a copy. + packetCopy := make([]byte, len(packet)) + copy(packetCopy, packet) + + if err := t.pushPacket(packetCopy); err != nil { + return err + } + + t.sentInitMsg = msg + t.sentInitPacket = packet + + return nil +} + +func (t *handshakeTransport) writePacket(p []byte) error { + switch p[0] { + case msgKexInit: + return errors.New("ssh: only handshakeTransport can send kexInit") + case msgNewKeys: + return errors.New("ssh: only handshakeTransport can send newKeys") + } + + t.mu.Lock() + defer t.mu.Unlock() + if t.writeError != nil { + return t.writeError + } + + if t.sentInitMsg != nil { + // Copy the packet so the writer can reuse the buffer. + cp := make([]byte, len(p)) + copy(cp, p) + t.pendingPackets = append(t.pendingPackets, cp) + return nil + } + + if t.writeBytesLeft > 0 { + t.writeBytesLeft -= int64(len(p)) + } else { + t.requestKeyExchange() + } + + if t.writePacketsLeft > 0 { + t.writePacketsLeft-- + } else { + t.requestKeyExchange() + } + + if err := t.pushPacket(p); err != nil { + t.writeError = err + } + + return nil +} + +func (t *handshakeTransport) Close() error { + return t.conn.Close() +} + +func (t *handshakeTransport) enterKeyExchange(otherInitPacket []byte) error { + if debugHandshake { + log.Printf("%s entered key exchange", t.id()) + } + + otherInit := &kexInitMsg{} + if err := Unmarshal(otherInitPacket, otherInit); err != nil { + return err + } + + magics := handshakeMagics{ + clientVersion: t.clientVersion, + serverVersion: t.serverVersion, + clientKexInit: otherInitPacket, + serverKexInit: t.sentInitPacket, + } + + clientInit := otherInit + serverInit := t.sentInitMsg + isClient := len(t.hostKeys) == 0 + if isClient { + clientInit, serverInit = serverInit, clientInit + + magics.clientKexInit = t.sentInitPacket + magics.serverKexInit = otherInitPacket + } + + var err error + t.algorithms, err = findAgreedAlgorithms(isClient, clientInit, serverInit) + if err != nil { + return err + } + + // We don't send FirstKexFollows, but we handle receiving it. + // + // RFC 4253 section 7 defines the kex and the agreement method for + // first_kex_packet_follows. It states that the guessed packet + // should be ignored if the "kex algorithm and/or the host + // key algorithm is guessed wrong (server and client have + // different preferred algorithm), or if any of the other + // algorithms cannot be agreed upon". The other algorithms have + // already been checked above so the kex algorithm and host key + // algorithm are checked here. + if otherInit.FirstKexFollows && (clientInit.KexAlgos[0] != serverInit.KexAlgos[0] || clientInit.ServerHostKeyAlgos[0] != serverInit.ServerHostKeyAlgos[0]) { + // other side sent a kex message for the wrong algorithm, + // which we have to ignore. + if _, err := t.conn.readPacket(); err != nil { + return err + } + } + + kex, ok := kexAlgoMap[t.algorithms.kex] + if !ok { + return fmt.Errorf("ssh: unexpected key exchange algorithm %v", t.algorithms.kex) + } + + var result *kexResult + if len(t.hostKeys) > 0 { + result, err = t.server(kex, t.algorithms, &magics) + } else { + result, err = t.client(kex, t.algorithms, &magics) + } + + if err != nil { + return err + } + + if t.sessionID == nil { + t.sessionID = result.H + } + result.SessionID = t.sessionID + + if err := t.conn.prepareKeyChange(t.algorithms, result); err != nil { + return err + } + if err = t.conn.writePacket([]byte{msgNewKeys}); err != nil { + return err + } + if packet, err := t.conn.readPacket(); err != nil { + return err + } else if packet[0] != msgNewKeys { + return unexpectedMessageError(msgNewKeys, packet[0]) + } + + return nil +} + +func (t *handshakeTransport) server(kex kexAlgorithm, algs *algorithms, magics *handshakeMagics) (*kexResult, error) { + var hostKey Signer + for alg, k := range t.hostKeys { + if algs.hostKey == alg { + hostKey = k + } + } + + r, err := kex.Server(t.conn, t.config.Rand, magics, hostKey) + return r, err +} + +func (t *handshakeTransport) client(kex kexAlgorithm, algs *algorithms, magics *handshakeMagics) (*kexResult, error) { + result, err := kex.Client(t.conn, t.config.Rand, magics) + if err != nil { + return nil, err + } + + hostKey, err := ParsePublicKey(result.HostKey) + if err != nil { + return nil, err + } + + if err := verifyHostKeySignature(hostKey, algs.hostKey, result); err != nil { + return nil, err + } + + err = t.hostKeyCallback(t.dialAddress, t.remoteAddr, hostKey) + if err != nil { + return nil, err + } + + return result, nil +} diff --git a/internal/crypto/ssh/internal/bcrypt_pbkdf/bcrypt_pbkdf.go b/internal/crypto/ssh/internal/bcrypt_pbkdf/bcrypt_pbkdf.go new file mode 100644 index 000000000..af81d2665 --- /dev/null +++ b/internal/crypto/ssh/internal/bcrypt_pbkdf/bcrypt_pbkdf.go @@ -0,0 +1,93 @@ +// Copyright 2014 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package bcrypt_pbkdf implements bcrypt_pbkdf(3) from OpenBSD. +// +// See https://flak.tedunangst.com/post/bcrypt-pbkdf and +// https://cvsweb.openbsd.org/cgi-bin/cvsweb/src/lib/libutil/bcrypt_pbkdf.c. +package bcrypt_pbkdf + +import ( + "crypto/sha512" + "errors" + "golang.org/x/crypto/blowfish" +) + +const blockSize = 32 + +// Key derives a key from the password, salt and rounds count, returning a +// []byte of length keyLen that can be used as cryptographic key. +func Key(password, salt []byte, rounds, keyLen int) ([]byte, error) { + if rounds < 1 { + return nil, errors.New("bcrypt_pbkdf: number of rounds is too small") + } + if len(password) == 0 { + return nil, errors.New("bcrypt_pbkdf: empty password") + } + if len(salt) == 0 || len(salt) > 1<<20 { + return nil, errors.New("bcrypt_pbkdf: bad salt length") + } + if keyLen > 1024 { + return nil, errors.New("bcrypt_pbkdf: keyLen is too large") + } + + numBlocks := (keyLen + blockSize - 1) / blockSize + key := make([]byte, numBlocks*blockSize) + + h := sha512.New() + h.Write(password) + shapass := h.Sum(nil) + + shasalt := make([]byte, 0, sha512.Size) + cnt, tmp := make([]byte, 4), make([]byte, blockSize) + for block := 1; block <= numBlocks; block++ { + h.Reset() + h.Write(salt) + cnt[0] = byte(block >> 24) + cnt[1] = byte(block >> 16) + cnt[2] = byte(block >> 8) + cnt[3] = byte(block) + h.Write(cnt) + bcryptHash(tmp, shapass, h.Sum(shasalt)) + + out := make([]byte, blockSize) + copy(out, tmp) + for i := 2; i <= rounds; i++ { + h.Reset() + h.Write(tmp) + bcryptHash(tmp, shapass, h.Sum(shasalt)) + for j := 0; j < len(out); j++ { + out[j] ^= tmp[j] + } + } + + for i, v := range out { + key[i*numBlocks+(block-1)] = v + } + } + return key[:keyLen], nil +} + +var magic = []byte("OxychromaticBlowfishSwatDynamite") + +func bcryptHash(out, shapass, shasalt []byte) { + c, err := blowfish.NewSaltedCipher(shapass, shasalt) + if err != nil { + panic(err) + } + for i := 0; i < 64; i++ { + blowfish.ExpandKey(shasalt, c) + blowfish.ExpandKey(shapass, c) + } + copy(out, magic) + for i := 0; i < 32; i += 8 { + for j := 0; j < 64; j++ { + c.Encrypt(out[i:i+8], out[i:i+8]) + } + } + // Swap bytes due to different endianness. + for i := 0; i < 32; i += 4 { + out[i+3], out[i+2], out[i+1], out[i] = out[i], out[i+1], out[i+2], out[i+3] + } +} diff --git a/internal/crypto/ssh/kex.go b/internal/crypto/ssh/kex.go new file mode 100644 index 000000000..7eedb209f --- /dev/null +++ b/internal/crypto/ssh/kex.go @@ -0,0 +1,789 @@ +// Copyright 2013 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ssh + +import ( + "crypto" + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "crypto/subtle" + "encoding/binary" + "errors" + "fmt" + "io" + "math/big" + + "golang.org/x/crypto/curve25519" +) + +const ( + kexAlgoDH1SHA1 = "diffie-hellman-group1-sha1" + kexAlgoDH14SHA1 = "diffie-hellman-group14-sha1" + kexAlgoECDH256 = "ecdh-sha2-nistp256" + kexAlgoECDH384 = "ecdh-sha2-nistp384" + kexAlgoECDH521 = "ecdh-sha2-nistp521" + kexAlgoCurve25519SHA256 = "curve25519-sha256@libssh.org" + + // For the following kex only the client half contains a production + // ready implementation. The server half only consists of a minimal + // implementation to satisfy the automated tests. + kexAlgoDHGEXSHA1 = "diffie-hellman-group-exchange-sha1" + kexAlgoDHGEXSHA256 = "diffie-hellman-group-exchange-sha256" +) + +// kexResult captures the outcome of a key exchange. +type kexResult struct { + // Session hash. See also RFC 4253, section 8. + H []byte + + // Shared secret. See also RFC 4253, section 8. + K []byte + + // Host key as hashed into H. + HostKey []byte + + // Signature of H. + Signature []byte + + // A cryptographic hash function that matches the security + // level of the key exchange algorithm. It is used for + // calculating H, and for deriving keys from H and K. + Hash crypto.Hash + + // The session ID, which is the first H computed. This is used + // to derive key material inside the transport. + SessionID []byte +} + +// handshakeMagics contains data that is always included in the +// session hash. +type handshakeMagics struct { + clientVersion, serverVersion []byte + clientKexInit, serverKexInit []byte +} + +func (m *handshakeMagics) write(w io.Writer) { + writeString(w, m.clientVersion) + writeString(w, m.serverVersion) + writeString(w, m.clientKexInit) + writeString(w, m.serverKexInit) +} + +// kexAlgorithm abstracts different key exchange algorithms. +type kexAlgorithm interface { + // Server runs server-side key agreement, signing the result + // with a hostkey. + Server(p packetConn, rand io.Reader, magics *handshakeMagics, s Signer) (*kexResult, error) + + // Client runs the client-side key agreement. Caller is + // responsible for verifying the host key signature. + Client(p packetConn, rand io.Reader, magics *handshakeMagics) (*kexResult, error) +} + +// dhGroup is a multiplicative group suitable for implementing Diffie-Hellman key agreement. +type dhGroup struct { + g, p, pMinus1 *big.Int +} + +func (group *dhGroup) diffieHellman(theirPublic, myPrivate *big.Int) (*big.Int, error) { + if theirPublic.Cmp(bigOne) <= 0 || theirPublic.Cmp(group.pMinus1) >= 0 { + return nil, errors.New("ssh: DH parameter out of bounds") + } + return new(big.Int).Exp(theirPublic, myPrivate, group.p), nil +} + +func (group *dhGroup) Client(c packetConn, randSource io.Reader, magics *handshakeMagics) (*kexResult, error) { + hashFunc := crypto.SHA1 + + var x *big.Int + for { + var err error + if x, err = rand.Int(randSource, group.pMinus1); err != nil { + return nil, err + } + if x.Sign() > 0 { + break + } + } + + X := new(big.Int).Exp(group.g, x, group.p) + kexDHInit := kexDHInitMsg{ + X: X, + } + if err := c.writePacket(Marshal(&kexDHInit)); err != nil { + return nil, err + } + + packet, err := c.readPacket() + if err != nil { + return nil, err + } + + var kexDHReply kexDHReplyMsg + if err = Unmarshal(packet, &kexDHReply); err != nil { + return nil, err + } + + ki, err := group.diffieHellman(kexDHReply.Y, x) + if err != nil { + return nil, err + } + + h := hashFunc.New() + magics.write(h) + writeString(h, kexDHReply.HostKey) + writeInt(h, X) + writeInt(h, kexDHReply.Y) + K := make([]byte, intLength(ki)) + marshalInt(K, ki) + h.Write(K) + + return &kexResult{ + H: h.Sum(nil), + K: K, + HostKey: kexDHReply.HostKey, + Signature: kexDHReply.Signature, + Hash: crypto.SHA1, + }, nil +} + +func (group *dhGroup) Server(c packetConn, randSource io.Reader, magics *handshakeMagics, priv Signer) (result *kexResult, err error) { + hashFunc := crypto.SHA1 + packet, err := c.readPacket() + if err != nil { + return + } + var kexDHInit kexDHInitMsg + if err = Unmarshal(packet, &kexDHInit); err != nil { + return + } + + var y *big.Int + for { + if y, err = rand.Int(randSource, group.pMinus1); err != nil { + return + } + if y.Sign() > 0 { + break + } + } + + Y := new(big.Int).Exp(group.g, y, group.p) + ki, err := group.diffieHellman(kexDHInit.X, y) + if err != nil { + return nil, err + } + + hostKeyBytes := priv.PublicKey().Marshal() + + h := hashFunc.New() + magics.write(h) + writeString(h, hostKeyBytes) + writeInt(h, kexDHInit.X) + writeInt(h, Y) + + K := make([]byte, intLength(ki)) + marshalInt(K, ki) + h.Write(K) + + H := h.Sum(nil) + + // H is already a hash, but the hostkey signing will apply its + // own key-specific hash algorithm. + sig, err := signAndMarshal(priv, randSource, H) + if err != nil { + return nil, err + } + + kexDHReply := kexDHReplyMsg{ + HostKey: hostKeyBytes, + Y: Y, + Signature: sig, + } + packet = Marshal(&kexDHReply) + + err = c.writePacket(packet) + return &kexResult{ + H: H, + K: K, + HostKey: hostKeyBytes, + Signature: sig, + Hash: crypto.SHA1, + }, err +} + +// ecdh performs Elliptic Curve Diffie-Hellman key exchange as +// described in RFC 5656, section 4. +type ecdh struct { + curve elliptic.Curve +} + +func (kex *ecdh) Client(c packetConn, rand io.Reader, magics *handshakeMagics) (*kexResult, error) { + ephKey, err := ecdsa.GenerateKey(kex.curve, rand) + if err != nil { + return nil, err + } + + kexInit := kexECDHInitMsg{ + ClientPubKey: elliptic.Marshal(kex.curve, ephKey.PublicKey.X, ephKey.PublicKey.Y), + } + + serialized := Marshal(&kexInit) + if err := c.writePacket(serialized); err != nil { + return nil, err + } + + packet, err := c.readPacket() + if err != nil { + return nil, err + } + + var reply kexECDHReplyMsg + if err = Unmarshal(packet, &reply); err != nil { + return nil, err + } + + x, y, err := unmarshalECKey(kex.curve, reply.EphemeralPubKey) + if err != nil { + return nil, err + } + + // generate shared secret + secret, _ := kex.curve.ScalarMult(x, y, ephKey.D.Bytes()) + + h := ecHash(kex.curve).New() + magics.write(h) + writeString(h, reply.HostKey) + writeString(h, kexInit.ClientPubKey) + writeString(h, reply.EphemeralPubKey) + K := make([]byte, intLength(secret)) + marshalInt(K, secret) + h.Write(K) + + return &kexResult{ + H: h.Sum(nil), + K: K, + HostKey: reply.HostKey, + Signature: reply.Signature, + Hash: ecHash(kex.curve), + }, nil +} + +// unmarshalECKey parses and checks an EC key. +func unmarshalECKey(curve elliptic.Curve, pubkey []byte) (x, y *big.Int, err error) { + x, y = elliptic.Unmarshal(curve, pubkey) + if x == nil { + return nil, nil, errors.New("ssh: elliptic.Unmarshal failure") + } + if !validateECPublicKey(curve, x, y) { + return nil, nil, errors.New("ssh: public key not on curve") + } + return x, y, nil +} + +// validateECPublicKey checks that the point is a valid public key for +// the given curve. See [SEC1], 3.2.2 +func validateECPublicKey(curve elliptic.Curve, x, y *big.Int) bool { + if x.Sign() == 0 && y.Sign() == 0 { + return false + } + + if x.Cmp(curve.Params().P) >= 0 { + return false + } + + if y.Cmp(curve.Params().P) >= 0 { + return false + } + + if !curve.IsOnCurve(x, y) { + return false + } + + // We don't check if N * PubKey == 0, since + // + // - the NIST curves have cofactor = 1, so this is implicit. + // (We don't foresee an implementation that supports non NIST + // curves) + // + // - for ephemeral keys, we don't need to worry about small + // subgroup attacks. + return true +} + +func (kex *ecdh) Server(c packetConn, rand io.Reader, magics *handshakeMagics, priv Signer) (result *kexResult, err error) { + packet, err := c.readPacket() + if err != nil { + return nil, err + } + + var kexECDHInit kexECDHInitMsg + if err = Unmarshal(packet, &kexECDHInit); err != nil { + return nil, err + } + + clientX, clientY, err := unmarshalECKey(kex.curve, kexECDHInit.ClientPubKey) + if err != nil { + return nil, err + } + + // We could cache this key across multiple users/multiple + // connection attempts, but the benefit is small. OpenSSH + // generates a new key for each incoming connection. + ephKey, err := ecdsa.GenerateKey(kex.curve, rand) + if err != nil { + return nil, err + } + + hostKeyBytes := priv.PublicKey().Marshal() + + serializedEphKey := elliptic.Marshal(kex.curve, ephKey.PublicKey.X, ephKey.PublicKey.Y) + + // generate shared secret + secret, _ := kex.curve.ScalarMult(clientX, clientY, ephKey.D.Bytes()) + + h := ecHash(kex.curve).New() + magics.write(h) + writeString(h, hostKeyBytes) + writeString(h, kexECDHInit.ClientPubKey) + writeString(h, serializedEphKey) + + K := make([]byte, intLength(secret)) + marshalInt(K, secret) + h.Write(K) + + H := h.Sum(nil) + + // H is already a hash, but the hostkey signing will apply its + // own key-specific hash algorithm. + sig, err := signAndMarshal(priv, rand, H) + if err != nil { + return nil, err + } + + reply := kexECDHReplyMsg{ + EphemeralPubKey: serializedEphKey, + HostKey: hostKeyBytes, + Signature: sig, + } + + serialized := Marshal(&reply) + if err := c.writePacket(serialized); err != nil { + return nil, err + } + + return &kexResult{ + H: H, + K: K, + HostKey: reply.HostKey, + Signature: sig, + Hash: ecHash(kex.curve), + }, nil +} + +var kexAlgoMap = map[string]kexAlgorithm{} + +func init() { + // This is the group called diffie-hellman-group1-sha1 in RFC + // 4253 and Oakley Group 2 in RFC 2409. + p, _ := new(big.Int).SetString("FFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE65381FFFFFFFFFFFFFFFF", 16) + kexAlgoMap[kexAlgoDH1SHA1] = &dhGroup{ + g: new(big.Int).SetInt64(2), + p: p, + pMinus1: new(big.Int).Sub(p, bigOne), + } + + // This is the group called diffie-hellman-group14-sha1 in RFC + // 4253 and Oakley Group 14 in RFC 3526. + p, _ = new(big.Int).SetString("FFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE45B3DC2007CB8A163BF0598DA48361C55D39A69163FA8FD24CF5F83655D23DCA3AD961C62F356208552BB9ED529077096966D670C354E4ABC9804F1746C08CA18217C32905E462E36CE3BE39E772C180E86039B2783A2EC07A28FB5C55DF06F4C52C9DE2BCBF6955817183995497CEA956AE515D2261898FA051015728E5A8AACAA68FFFFFFFFFFFFFFFF", 16) + + kexAlgoMap[kexAlgoDH14SHA1] = &dhGroup{ + g: new(big.Int).SetInt64(2), + p: p, + pMinus1: new(big.Int).Sub(p, bigOne), + } + + kexAlgoMap[kexAlgoECDH521] = &ecdh{elliptic.P521()} + kexAlgoMap[kexAlgoECDH384] = &ecdh{elliptic.P384()} + kexAlgoMap[kexAlgoECDH256] = &ecdh{elliptic.P256()} + kexAlgoMap[kexAlgoCurve25519SHA256] = &curve25519sha256{} + kexAlgoMap[kexAlgoDHGEXSHA1] = &dhGEXSHA{hashFunc: crypto.SHA1} + kexAlgoMap[kexAlgoDHGEXSHA256] = &dhGEXSHA{hashFunc: crypto.SHA256} +} + +// curve25519sha256 implements the curve25519-sha256@libssh.org key +// agreement protocol, as described in +// https://git.libssh.org/projects/libssh.git/tree/doc/curve25519-sha256@libssh.org.txt +type curve25519sha256 struct{} + +type curve25519KeyPair struct { + priv [32]byte + pub [32]byte +} + +func (kp *curve25519KeyPair) generate(rand io.Reader) error { + if _, err := io.ReadFull(rand, kp.priv[:]); err != nil { + return err + } + curve25519.ScalarBaseMult(&kp.pub, &kp.priv) + return nil +} + +// curve25519Zeros is just an array of 32 zero bytes so that we have something +// convenient to compare against in order to reject curve25519 points with the +// wrong order. +var curve25519Zeros [32]byte + +func (kex *curve25519sha256) Client(c packetConn, rand io.Reader, magics *handshakeMagics) (*kexResult, error) { + var kp curve25519KeyPair + if err := kp.generate(rand); err != nil { + return nil, err + } + if err := c.writePacket(Marshal(&kexECDHInitMsg{kp.pub[:]})); err != nil { + return nil, err + } + + packet, err := c.readPacket() + if err != nil { + return nil, err + } + + var reply kexECDHReplyMsg + if err = Unmarshal(packet, &reply); err != nil { + return nil, err + } + if len(reply.EphemeralPubKey) != 32 { + return nil, errors.New("ssh: peer's curve25519 public value has wrong length") + } + + var servPub, secret [32]byte + copy(servPub[:], reply.EphemeralPubKey) + curve25519.ScalarMult(&secret, &kp.priv, &servPub) + if subtle.ConstantTimeCompare(secret[:], curve25519Zeros[:]) == 1 { + return nil, errors.New("ssh: peer's curve25519 public value has wrong order") + } + + h := crypto.SHA256.New() + magics.write(h) + writeString(h, reply.HostKey) + writeString(h, kp.pub[:]) + writeString(h, reply.EphemeralPubKey) + + ki := new(big.Int).SetBytes(secret[:]) + K := make([]byte, intLength(ki)) + marshalInt(K, ki) + h.Write(K) + + return &kexResult{ + H: h.Sum(nil), + K: K, + HostKey: reply.HostKey, + Signature: reply.Signature, + Hash: crypto.SHA256, + }, nil +} + +func (kex *curve25519sha256) Server(c packetConn, rand io.Reader, magics *handshakeMagics, priv Signer) (result *kexResult, err error) { + packet, err := c.readPacket() + if err != nil { + return + } + var kexInit kexECDHInitMsg + if err = Unmarshal(packet, &kexInit); err != nil { + return + } + + if len(kexInit.ClientPubKey) != 32 { + return nil, errors.New("ssh: peer's curve25519 public value has wrong length") + } + + var kp curve25519KeyPair + if err := kp.generate(rand); err != nil { + return nil, err + } + + var clientPub, secret [32]byte + copy(clientPub[:], kexInit.ClientPubKey) + curve25519.ScalarMult(&secret, &kp.priv, &clientPub) + if subtle.ConstantTimeCompare(secret[:], curve25519Zeros[:]) == 1 { + return nil, errors.New("ssh: peer's curve25519 public value has wrong order") + } + + hostKeyBytes := priv.PublicKey().Marshal() + + h := crypto.SHA256.New() + magics.write(h) + writeString(h, hostKeyBytes) + writeString(h, kexInit.ClientPubKey) + writeString(h, kp.pub[:]) + + ki := new(big.Int).SetBytes(secret[:]) + K := make([]byte, intLength(ki)) + marshalInt(K, ki) + h.Write(K) + + H := h.Sum(nil) + + sig, err := signAndMarshal(priv, rand, H) + if err != nil { + return nil, err + } + + reply := kexECDHReplyMsg{ + EphemeralPubKey: kp.pub[:], + HostKey: hostKeyBytes, + Signature: sig, + } + if err := c.writePacket(Marshal(&reply)); err != nil { + return nil, err + } + return &kexResult{ + H: H, + K: K, + HostKey: hostKeyBytes, + Signature: sig, + Hash: crypto.SHA256, + }, nil +} + +// dhGEXSHA implements the diffie-hellman-group-exchange-sha1 and +// diffie-hellman-group-exchange-sha256 key agreement protocols, +// as described in RFC 4419 +type dhGEXSHA struct { + g, p *big.Int + hashFunc crypto.Hash +} + +const numMRTests = 64 + +const ( + dhGroupExchangeMinimumBits = 2048 + dhGroupExchangePreferredBits = 2048 + dhGroupExchangeMaximumBits = 8192 +) + +func (gex *dhGEXSHA) diffieHellman(theirPublic, myPrivate *big.Int) (*big.Int, error) { + if theirPublic.Sign() <= 0 || theirPublic.Cmp(gex.p) >= 0 { + return nil, fmt.Errorf("ssh: DH parameter out of bounds") + } + return new(big.Int).Exp(theirPublic, myPrivate, gex.p), nil +} + +func (gex dhGEXSHA) Client(c packetConn, randSource io.Reader, magics *handshakeMagics) (*kexResult, error) { + // Send GexRequest + kexDHGexRequest := kexDHGexRequestMsg{ + MinBits: dhGroupExchangeMinimumBits, + PreferedBits: dhGroupExchangePreferredBits, + MaxBits: dhGroupExchangeMaximumBits, + } + if err := c.writePacket(Marshal(&kexDHGexRequest)); err != nil { + return nil, err + } + + // Receive GexGroup + packet, err := c.readPacket() + if err != nil { + return nil, err + } + + var kexDHGexGroup kexDHGexGroupMsg + if err = Unmarshal(packet, &kexDHGexGroup); err != nil { + return nil, err + } + + // reject if p's bit length < dhGroupExchangeMinimumBits or > dhGroupExchangeMaximumBits + if kexDHGexGroup.P.BitLen() < dhGroupExchangeMinimumBits || kexDHGexGroup.P.BitLen() > dhGroupExchangeMaximumBits { + return nil, fmt.Errorf("ssh: server-generated gex p is out of range (%d bits)", kexDHGexGroup.P.BitLen()) + } + + gex.p = kexDHGexGroup.P + gex.g = kexDHGexGroup.G + + // Check if p is safe by verifing that p and (p-1)/2 are primes + one := big.NewInt(1) + var pHalf = &big.Int{} + pHalf.Rsh(gex.p, 1) + if !gex.p.ProbablyPrime(numMRTests) || !pHalf.ProbablyPrime(numMRTests) { + return nil, fmt.Errorf("ssh: server provided gex p is not safe") + } + + // Check if g is safe by verifing that g > 1 and g < p - 1 + var pMinusOne = &big.Int{} + pMinusOne.Sub(gex.p, one) + if gex.g.Cmp(one) != 1 && gex.g.Cmp(pMinusOne) != -1 { + return nil, fmt.Errorf("ssh: server provided gex g is not safe") + } + + // Send GexInit + x, err := rand.Int(randSource, pHalf) + if err != nil { + return nil, err + } + X := new(big.Int).Exp(gex.g, x, gex.p) + kexDHGexInit := kexDHGexInitMsg{ + X: X, + } + if err := c.writePacket(Marshal(&kexDHGexInit)); err != nil { + return nil, err + } + + // Receive GexReply + packet, err = c.readPacket() + if err != nil { + return nil, err + } + + var kexDHGexReply kexDHGexReplyMsg + if err = Unmarshal(packet, &kexDHGexReply); err != nil { + return nil, err + } + + kInt, err := gex.diffieHellman(kexDHGexReply.Y, x) + if err != nil { + return nil, err + } + + // Check if k is safe by verifing that k > 1 and k < p - 1 + if kInt.Cmp(one) != 1 && kInt.Cmp(pMinusOne) != -1 { + return nil, fmt.Errorf("ssh: derived k is not safe") + } + + h := gex.hashFunc.New() + magics.write(h) + writeString(h, kexDHGexReply.HostKey) + binary.Write(h, binary.BigEndian, uint32(dhGroupExchangeMinimumBits)) + binary.Write(h, binary.BigEndian, uint32(dhGroupExchangePreferredBits)) + binary.Write(h, binary.BigEndian, uint32(dhGroupExchangeMaximumBits)) + writeInt(h, gex.p) + writeInt(h, gex.g) + writeInt(h, X) + writeInt(h, kexDHGexReply.Y) + K := make([]byte, intLength(kInt)) + marshalInt(K, kInt) + h.Write(K) + + return &kexResult{ + H: h.Sum(nil), + K: K, + HostKey: kexDHGexReply.HostKey, + Signature: kexDHGexReply.Signature, + Hash: gex.hashFunc, + }, nil +} + +// Server half implementation of the Diffie Hellman Key Exchange with SHA1 and SHA256. +// +// This is a minimal implementation to satisfy the automated tests. +func (gex dhGEXSHA) Server(c packetConn, randSource io.Reader, magics *handshakeMagics, priv Signer) (result *kexResult, err error) { + // Receive GexRequest + packet, err := c.readPacket() + if err != nil { + return + } + var kexDHGexRequest kexDHGexRequestMsg + if err = Unmarshal(packet, &kexDHGexRequest); err != nil { + return + } + + // smoosh the user's preferred size into our own limits + if kexDHGexRequest.PreferedBits > dhGroupExchangeMaximumBits { + kexDHGexRequest.PreferedBits = dhGroupExchangeMaximumBits + } + if kexDHGexRequest.PreferedBits < dhGroupExchangeMinimumBits { + kexDHGexRequest.PreferedBits = dhGroupExchangeMinimumBits + } + // fix min/max if they're inconsistent. technically, we could just pout + // and hang up, but there's no harm in giving them the benefit of the + // doubt and just picking a bitsize for them. + if kexDHGexRequest.MinBits > kexDHGexRequest.PreferedBits { + kexDHGexRequest.MinBits = kexDHGexRequest.PreferedBits + } + if kexDHGexRequest.MaxBits < kexDHGexRequest.PreferedBits { + kexDHGexRequest.MaxBits = kexDHGexRequest.PreferedBits + } + + // Send GexGroup + // This is the group called diffie-hellman-group14-sha1 in RFC + // 4253 and Oakley Group 14 in RFC 3526. + p, _ := new(big.Int).SetString("FFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE45B3DC2007CB8A163BF0598DA48361C55D39A69163FA8FD24CF5F83655D23DCA3AD961C62F356208552BB9ED529077096966D670C354E4ABC9804F1746C08CA18217C32905E462E36CE3BE39E772C180E86039B2783A2EC07A28FB5C55DF06F4C52C9DE2BCBF6955817183995497CEA956AE515D2261898FA051015728E5A8AACAA68FFFFFFFFFFFFFFFF", 16) + gex.p = p + gex.g = big.NewInt(2) + + kexDHGexGroup := kexDHGexGroupMsg{ + P: gex.p, + G: gex.g, + } + if err := c.writePacket(Marshal(&kexDHGexGroup)); err != nil { + return nil, err + } + + // Receive GexInit + packet, err = c.readPacket() + if err != nil { + return + } + var kexDHGexInit kexDHGexInitMsg + if err = Unmarshal(packet, &kexDHGexInit); err != nil { + return + } + + var pHalf = &big.Int{} + pHalf.Rsh(gex.p, 1) + + y, err := rand.Int(randSource, pHalf) + if err != nil { + return + } + + Y := new(big.Int).Exp(gex.g, y, gex.p) + kInt, err := gex.diffieHellman(kexDHGexInit.X, y) + if err != nil { + return nil, err + } + + hostKeyBytes := priv.PublicKey().Marshal() + + h := gex.hashFunc.New() + magics.write(h) + writeString(h, hostKeyBytes) + binary.Write(h, binary.BigEndian, uint32(dhGroupExchangeMinimumBits)) + binary.Write(h, binary.BigEndian, uint32(dhGroupExchangePreferredBits)) + binary.Write(h, binary.BigEndian, uint32(dhGroupExchangeMaximumBits)) + writeInt(h, gex.p) + writeInt(h, gex.g) + writeInt(h, kexDHGexInit.X) + writeInt(h, Y) + + K := make([]byte, intLength(kInt)) + marshalInt(K, kInt) + h.Write(K) + + H := h.Sum(nil) + + // H is already a hash, but the hostkey signing will apply its + // own key-specific hash algorithm. + sig, err := signAndMarshal(priv, randSource, H) + if err != nil { + return nil, err + } + + kexDHGexReply := kexDHGexReplyMsg{ + HostKey: hostKeyBytes, + Y: Y, + Signature: sig, + } + packet = Marshal(&kexDHGexReply) + + err = c.writePacket(packet) + + return &kexResult{ + H: H, + K: K, + HostKey: hostKeyBytes, + Signature: sig, + Hash: gex.hashFunc, + }, err +} diff --git a/internal/crypto/ssh/keys.go b/internal/crypto/ssh/keys.go new file mode 100644 index 000000000..2bc67ad7d --- /dev/null +++ b/internal/crypto/ssh/keys.go @@ -0,0 +1,1493 @@ +// Copyright 2012 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ssh + +import ( + "bytes" + "crypto" + "crypto/aes" + "crypto/cipher" + "crypto/dsa" + "crypto/ecdsa" + "crypto/elliptic" + "crypto/md5" + "crypto/rsa" + "crypto/sha256" + "crypto/x509" + "encoding/asn1" + "encoding/base64" + "encoding/hex" + "encoding/pem" + "errors" + "fmt" + "io" + "math/big" + "strings" + + "golang.org/x/crypto/ed25519" + "golang.org/x/crypto/ssh/internal/bcrypt_pbkdf" +) + +// These constants represent the algorithm names for key types supported by this +// package. +const ( + KeyAlgoRSA = "ssh-rsa" + KeyAlgoRSASHA2256 = "rsa-sha2-256" + KeyAlgoRSASHA2512 = "rsa-sha2-512" + KeyAlgoDSA = "ssh-dss" + KeyAlgoECDSA256 = "ecdsa-sha2-nistp256" + KeyAlgoSKECDSA256 = "sk-ecdsa-sha2-nistp256@openssh.com" + KeyAlgoECDSA384 = "ecdsa-sha2-nistp384" + KeyAlgoECDSA521 = "ecdsa-sha2-nistp521" + KeyAlgoED25519 = "ssh-ed25519" + KeyAlgoSKED25519 = "sk-ssh-ed25519@openssh.com" +) + +// These constants represent non-default signature algorithms that are supported +// as algorithm parameters to AlgorithmSigner.SignWithAlgorithm methods. See +// [PROTOCOL.agent] section 4.5.1 and +// https://tools.ietf.org/html/draft-ietf-curdle-rsa-sha2-10 +const ( + SigAlgoRSA = "ssh-rsa" + SigAlgoRSASHA2256 = "rsa-sha2-256" + SigAlgoRSASHA2512 = "rsa-sha2-512" +) + +// parsePubKey parses a public key of the given algorithm. +// Use ParsePublicKey for keys with prepended algorithm. +func parsePubKey(in []byte, algo string) (pubKey PublicKey, rest []byte, err error) { + switch algo { + case KeyAlgoRSA, KeyAlgoRSASHA2256, KeyAlgoRSASHA2512: + return parseRSA(in) + case KeyAlgoDSA: + return parseDSA(in) + case KeyAlgoECDSA256, KeyAlgoECDSA384, KeyAlgoECDSA521: + return parseECDSA(in) + case KeyAlgoSKECDSA256: + return parseSKECDSA(in) + case KeyAlgoED25519: + return parseED25519(in) + case KeyAlgoSKED25519: + return parseSKEd25519(in) + case CertAlgoRSAv01, CertAlgoRSASHA2256v01, CertAlgoRSASHA2512v01, CertAlgoDSAv01, CertAlgoECDSA256v01, CertAlgoECDSA384v01, CertAlgoECDSA521v01, CertAlgoSKECDSA256v01, CertAlgoED25519v01, CertAlgoSKED25519v01: + cert, err := parseCert(in, certToPrivAlgo(algo)) + if err != nil { + return nil, nil, err + } + return cert, nil, nil + } + return nil, nil, fmt.Errorf("ssh: unknown key algorithm: %v", algo) +} + +// parseAuthorizedKey parses a public key in OpenSSH authorized_keys format +// (see sshd(8) manual page) once the options and key type fields have been +// removed. +func parseAuthorizedKey(in []byte) (out PublicKey, comment string, err error) { + in = bytes.TrimSpace(in) + + i := bytes.IndexAny(in, " \t") + if i == -1 { + i = len(in) + } + base64Key := in[:i] + + key := make([]byte, base64.StdEncoding.DecodedLen(len(base64Key))) + n, err := base64.StdEncoding.Decode(key, base64Key) + if err != nil { + return nil, "", err + } + key = key[:n] + out, err = ParsePublicKey(key) + if err != nil { + return nil, "", err + } + comment = string(bytes.TrimSpace(in[i:])) + return out, comment, nil +} + +// ParseKnownHosts parses an entry in the format of the known_hosts file. +// +// The known_hosts format is documented in the sshd(8) manual page. This +// function will parse a single entry from in. On successful return, marker +// will contain the optional marker value (i.e. "cert-authority" or "revoked") +// or else be empty, hosts will contain the hosts that this entry matches, +// pubKey will contain the public key and comment will contain any trailing +// comment at the end of the line. See the sshd(8) manual page for the various +// forms that a host string can take. +// +// The unparsed remainder of the input will be returned in rest. This function +// can be called repeatedly to parse multiple entries. +// +// If no entries were found in the input then err will be io.EOF. Otherwise a +// non-nil err value indicates a parse error. +func ParseKnownHosts(in []byte) (marker string, hosts []string, pubKey PublicKey, comment string, rest []byte, err error) { + for len(in) > 0 { + end := bytes.IndexByte(in, '\n') + if end != -1 { + rest = in[end+1:] + in = in[:end] + } else { + rest = nil + } + + end = bytes.IndexByte(in, '\r') + if end != -1 { + in = in[:end] + } + + in = bytes.TrimSpace(in) + if len(in) == 0 || in[0] == '#' { + in = rest + continue + } + + i := bytes.IndexAny(in, " \t") + if i == -1 { + in = rest + continue + } + + // Strip out the beginning of the known_host key. + // This is either an optional marker or a (set of) hostname(s). + keyFields := bytes.Fields(in) + if len(keyFields) < 3 || len(keyFields) > 5 { + return "", nil, nil, "", nil, errors.New("ssh: invalid entry in known_hosts data") + } + + // keyFields[0] is either "@cert-authority", "@revoked" or a comma separated + // list of hosts + marker := "" + if keyFields[0][0] == '@' { + marker = string(keyFields[0][1:]) + keyFields = keyFields[1:] + } + + hosts := string(keyFields[0]) + // keyFields[1] contains the key type (e.g. “ssh-rsa”). + // However, that information is duplicated inside the + // base64-encoded key and so is ignored here. + + key := bytes.Join(keyFields[2:], []byte(" ")) + if pubKey, comment, err = parseAuthorizedKey(key); err != nil { + return "", nil, nil, "", nil, err + } + + return marker, strings.Split(hosts, ","), pubKey, comment, rest, nil + } + + return "", nil, nil, "", nil, io.EOF +} + +// ParseAuthorizedKeys parses a public key from an authorized_keys +// file used in OpenSSH according to the sshd(8) manual page. +func ParseAuthorizedKey(in []byte) (out PublicKey, comment string, options []string, rest []byte, err error) { + for len(in) > 0 { + end := bytes.IndexByte(in, '\n') + if end != -1 { + rest = in[end+1:] + in = in[:end] + } else { + rest = nil + } + + end = bytes.IndexByte(in, '\r') + if end != -1 { + in = in[:end] + } + + in = bytes.TrimSpace(in) + if len(in) == 0 || in[0] == '#' { + in = rest + continue + } + + i := bytes.IndexAny(in, " \t") + if i == -1 { + in = rest + continue + } + + if out, comment, err = parseAuthorizedKey(in[i:]); err == nil { + return out, comment, options, rest, nil + } + + // No key type recognised. Maybe there's an options field at + // the beginning. + var b byte + inQuote := false + var candidateOptions []string + optionStart := 0 + for i, b = range in { + isEnd := !inQuote && (b == ' ' || b == '\t') + if (b == ',' && !inQuote) || isEnd { + if i-optionStart > 0 { + candidateOptions = append(candidateOptions, string(in[optionStart:i])) + } + optionStart = i + 1 + } + if isEnd { + break + } + if b == '"' && (i == 0 || (i > 0 && in[i-1] != '\\')) { + inQuote = !inQuote + } + } + for i < len(in) && (in[i] == ' ' || in[i] == '\t') { + i++ + } + if i == len(in) { + // Invalid line: unmatched quote + in = rest + continue + } + + in = in[i:] + i = bytes.IndexAny(in, " \t") + if i == -1 { + in = rest + continue + } + + if out, comment, err = parseAuthorizedKey(in[i:]); err == nil { + options = candidateOptions + return out, comment, options, rest, nil + } + + in = rest + continue + } + + return nil, "", nil, nil, errors.New("ssh: no key found") +} + +// ParsePublicKey parses an SSH public key formatted for use in +// the SSH wire protocol according to RFC 4253, section 6.6. +func ParsePublicKey(in []byte) (out PublicKey, err error) { + algo, in, ok := parseString(in) + if !ok { + return nil, errShortRead + } + var rest []byte + out, rest, err = parsePubKey(in, string(algo)) + if len(rest) > 0 { + return nil, errors.New("ssh: trailing junk in public key") + } + + return out, err +} + +// MarshalAuthorizedKey serializes key for inclusion in an OpenSSH +// authorized_keys file. The return value ends with newline. +func MarshalAuthorizedKey(key PublicKey) []byte { + b := &bytes.Buffer{} + b.WriteString(key.Type()) + b.WriteByte(' ') + e := base64.NewEncoder(base64.StdEncoding, b) + e.Write(key.Marshal()) + e.Close() + b.WriteByte('\n') + return b.Bytes() +} + +// PublicKey is an abstraction of different types of public keys. +type PublicKey interface { + // Type returns the key's type, e.g. "ssh-rsa". + Type() string + + // Marshal returns the serialized key data in SSH wire format, + // with the name prefix. To unmarshal the returned data, use + // the ParsePublicKey function. + Marshal() []byte + + // Verify that sig is a signature on the given data using this + // key. This function will hash the data appropriately first. + Verify(data []byte, sig *Signature) error +} + +// CryptoPublicKey, if implemented by a PublicKey, +// returns the underlying crypto.PublicKey form of the key. +type CryptoPublicKey interface { + CryptoPublicKey() crypto.PublicKey +} + +// A Signer can create signatures that verify against a public key. +type Signer interface { + // PublicKey returns an associated PublicKey instance. + PublicKey() PublicKey + + // Sign returns raw signature for the given data. This method + // will apply the hash specified for the keytype to the data. + Sign(rand io.Reader, data []byte) (*Signature, error) +} + +// A AlgorithmSigner is a Signer that also supports specifying a specific +// algorithm to use for signing. +type AlgorithmSigner interface { + Signer + + // SignWithAlgorithm is like Signer.Sign, but allows specification of a + // non-default signing algorithm. See the SigAlgo* constants in this + // package for signature algorithms supported by this package. Callers may + // pass an empty string for the algorithm in which case the AlgorithmSigner + // will use its default algorithm. + SignWithAlgorithm(rand io.Reader, data []byte, algorithm string) (*Signature, error) +} + +type rsaPublicKey rsa.PublicKey + +func (r *rsaPublicKey) Type() string { + return "ssh-rsa" +} + +// parseRSA parses an RSA key according to RFC 4253, section 6.6. +func parseRSA(in []byte) (out PublicKey, rest []byte, err error) { + var w struct { + E *big.Int + N *big.Int + Rest []byte `ssh:"rest"` + } + if err := Unmarshal(in, &w); err != nil { + return nil, nil, err + } + + if w.E.BitLen() > 24 { + return nil, nil, errors.New("ssh: exponent too large") + } + e := w.E.Int64() + if e < 3 || e&1 == 0 { + return nil, nil, errors.New("ssh: incorrect exponent") + } + + var key rsa.PublicKey + key.E = int(e) + key.N = w.N + return (*rsaPublicKey)(&key), w.Rest, nil +} + +func (r *rsaPublicKey) Marshal() []byte { + e := new(big.Int).SetInt64(int64(r.E)) + // RSA publickey struct layout should match the struct used by + // parseRSACert in the x/crypto/ssh/agent package. + wirekey := struct { + Name string + E *big.Int + N *big.Int + }{ + KeyAlgoRSA, + e, + r.N, + } + return Marshal(&wirekey) +} + +func (r *rsaPublicKey) Verify(data []byte, sig *Signature) error { + var hash crypto.Hash + switch sig.Format { + case SigAlgoRSA: + hash = crypto.SHA1 + case SigAlgoRSASHA2256: + hash = crypto.SHA256 + case SigAlgoRSASHA2512: + hash = crypto.SHA512 + default: + return fmt.Errorf("ssh: signature type %s for key type %s", sig.Format, r.Type()) + } + h := hash.New() + h.Write(data) + digest := h.Sum(nil) + return rsa.VerifyPKCS1v15((*rsa.PublicKey)(r), hash, digest, sig.Blob) +} + +func (r *rsaPublicKey) CryptoPublicKey() crypto.PublicKey { + return (*rsa.PublicKey)(r) +} + +type dsaPublicKey dsa.PublicKey + +func (k *dsaPublicKey) Type() string { + return "ssh-dss" +} + +func checkDSAParams(param *dsa.Parameters) error { + // SSH specifies FIPS 186-2, which only provided a single size + // (1024 bits) DSA key. FIPS 186-3 allows for larger key + // sizes, which would confuse SSH. + if l := param.P.BitLen(); l != 1024 { + return fmt.Errorf("ssh: unsupported DSA key size %d", l) + } + + return nil +} + +// parseDSA parses an DSA key according to RFC 4253, section 6.6. +func parseDSA(in []byte) (out PublicKey, rest []byte, err error) { + var w struct { + P, Q, G, Y *big.Int + Rest []byte `ssh:"rest"` + } + if err := Unmarshal(in, &w); err != nil { + return nil, nil, err + } + + param := dsa.Parameters{ + P: w.P, + Q: w.Q, + G: w.G, + } + if err := checkDSAParams(¶m); err != nil { + return nil, nil, err + } + + key := &dsaPublicKey{ + Parameters: param, + Y: w.Y, + } + return key, w.Rest, nil +} + +func (k *dsaPublicKey) Marshal() []byte { + // DSA publickey struct layout should match the struct used by + // parseDSACert in the x/crypto/ssh/agent package. + w := struct { + Name string + P, Q, G, Y *big.Int + }{ + k.Type(), + k.P, + k.Q, + k.G, + k.Y, + } + + return Marshal(&w) +} + +func (k *dsaPublicKey) Verify(data []byte, sig *Signature) error { + if sig.Format != k.Type() { + return fmt.Errorf("ssh: signature type %s for key type %s", sig.Format, k.Type()) + } + h := crypto.SHA1.New() + h.Write(data) + digest := h.Sum(nil) + + // Per RFC 4253, section 6.6, + // The value for 'dss_signature_blob' is encoded as a string containing + // r, followed by s (which are 160-bit integers, without lengths or + // padding, unsigned, and in network byte order). + // For DSS purposes, sig.Blob should be exactly 40 bytes in length. + if len(sig.Blob) != 40 { + return errors.New("ssh: DSA signature parse error") + } + r := new(big.Int).SetBytes(sig.Blob[:20]) + s := new(big.Int).SetBytes(sig.Blob[20:]) + if dsa.Verify((*dsa.PublicKey)(k), digest, r, s) { + return nil + } + return errors.New("ssh: signature did not verify") +} + +func (k *dsaPublicKey) CryptoPublicKey() crypto.PublicKey { + return (*dsa.PublicKey)(k) +} + +type dsaPrivateKey struct { + *dsa.PrivateKey +} + +func (k *dsaPrivateKey) PublicKey() PublicKey { + return (*dsaPublicKey)(&k.PrivateKey.PublicKey) +} + +func (k *dsaPrivateKey) Sign(rand io.Reader, data []byte) (*Signature, error) { + return k.SignWithAlgorithm(rand, data, "") +} + +func (k *dsaPrivateKey) SignWithAlgorithm(rand io.Reader, data []byte, algorithm string) (*Signature, error) { + if algorithm != "" && algorithm != k.PublicKey().Type() { + return nil, fmt.Errorf("ssh: unsupported signature algorithm %s", algorithm) + } + + h := crypto.SHA1.New() + h.Write(data) + digest := h.Sum(nil) + r, s, err := dsa.Sign(rand, k.PrivateKey, digest) + if err != nil { + return nil, err + } + + sig := make([]byte, 40) + rb := r.Bytes() + sb := s.Bytes() + + copy(sig[20-len(rb):20], rb) + copy(sig[40-len(sb):], sb) + + return &Signature{ + Format: k.PublicKey().Type(), + Blob: sig, + }, nil +} + +type ecdsaPublicKey ecdsa.PublicKey + +func (k *ecdsaPublicKey) Type() string { + return "ecdsa-sha2-" + k.nistID() +} + +func (k *ecdsaPublicKey) nistID() string { + switch k.Params().BitSize { + case 256: + return "nistp256" + case 384: + return "nistp384" + case 521: + return "nistp521" + } + panic("ssh: unsupported ecdsa key size") +} + +type ed25519PublicKey ed25519.PublicKey + +func (k ed25519PublicKey) Type() string { + return KeyAlgoED25519 +} + +func parseED25519(in []byte) (out PublicKey, rest []byte, err error) { + var w struct { + KeyBytes []byte + Rest []byte `ssh:"rest"` + } + + if err := Unmarshal(in, &w); err != nil { + return nil, nil, err + } + + if l := len(w.KeyBytes); l != ed25519.PublicKeySize { + return nil, nil, fmt.Errorf("invalid size %d for Ed25519 public key", l) + } + + return ed25519PublicKey(w.KeyBytes), w.Rest, nil +} + +func (k ed25519PublicKey) Marshal() []byte { + w := struct { + Name string + KeyBytes []byte + }{ + KeyAlgoED25519, + []byte(k), + } + return Marshal(&w) +} + +func (k ed25519PublicKey) Verify(b []byte, sig *Signature) error { + if sig.Format != k.Type() { + return fmt.Errorf("ssh: signature type %s for key type %s", sig.Format, k.Type()) + } + if l := len(k); l != ed25519.PublicKeySize { + return fmt.Errorf("ssh: invalid size %d for Ed25519 public key", l) + } + + if ok := ed25519.Verify(ed25519.PublicKey(k), b, sig.Blob); !ok { + return errors.New("ssh: signature did not verify") + } + + return nil +} + +func (k ed25519PublicKey) CryptoPublicKey() crypto.PublicKey { + return ed25519.PublicKey(k) +} + +func supportedEllipticCurve(curve elliptic.Curve) bool { + return curve == elliptic.P256() || curve == elliptic.P384() || curve == elliptic.P521() +} + +// ecHash returns the hash to match the given elliptic curve, see RFC +// 5656, section 6.2.1 +func ecHash(curve elliptic.Curve) crypto.Hash { + bitSize := curve.Params().BitSize + switch { + case bitSize <= 256: + return crypto.SHA256 + case bitSize <= 384: + return crypto.SHA384 + } + return crypto.SHA512 +} + +// parseECDSA parses an ECDSA key according to RFC 5656, section 3.1. +func parseECDSA(in []byte) (out PublicKey, rest []byte, err error) { + var w struct { + Curve string + KeyBytes []byte + Rest []byte `ssh:"rest"` + } + + if err := Unmarshal(in, &w); err != nil { + return nil, nil, err + } + + key := new(ecdsa.PublicKey) + + switch w.Curve { + case "nistp256": + key.Curve = elliptic.P256() + case "nistp384": + key.Curve = elliptic.P384() + case "nistp521": + key.Curve = elliptic.P521() + default: + return nil, nil, errors.New("ssh: unsupported curve") + } + + key.X, key.Y = elliptic.Unmarshal(key.Curve, w.KeyBytes) + if key.X == nil || key.Y == nil { + return nil, nil, errors.New("ssh: invalid curve point") + } + return (*ecdsaPublicKey)(key), w.Rest, nil +} + +func (k *ecdsaPublicKey) Marshal() []byte { + // See RFC 5656, section 3.1. + keyBytes := elliptic.Marshal(k.Curve, k.X, k.Y) + // ECDSA publickey struct layout should match the struct used by + // parseECDSACert in the x/crypto/ssh/agent package. + w := struct { + Name string + ID string + Key []byte + }{ + k.Type(), + k.nistID(), + keyBytes, + } + + return Marshal(&w) +} + +func (k *ecdsaPublicKey) Verify(data []byte, sig *Signature) error { + if sig.Format != k.Type() { + return fmt.Errorf("ssh: signature type %s for key type %s", sig.Format, k.Type()) + } + + h := ecHash(k.Curve).New() + h.Write(data) + digest := h.Sum(nil) + + // Per RFC 5656, section 3.1.2, + // The ecdsa_signature_blob value has the following specific encoding: + // mpint r + // mpint s + var ecSig struct { + R *big.Int + S *big.Int + } + + if err := Unmarshal(sig.Blob, &ecSig); err != nil { + return err + } + + if ecdsa.Verify((*ecdsa.PublicKey)(k), digest, ecSig.R, ecSig.S) { + return nil + } + return errors.New("ssh: signature did not verify") +} + +func (k *ecdsaPublicKey) CryptoPublicKey() crypto.PublicKey { + return (*ecdsa.PublicKey)(k) +} + +// skFields holds the additional fields present in U2F/FIDO2 signatures. +// See openssh/PROTOCOL.u2f 'SSH U2F Signatures' for details. +type skFields struct { + // Flags contains U2F/FIDO2 flags such as 'user present' + Flags byte + // Counter is a monotonic signature counter which can be + // used to detect concurrent use of a private key, should + // it be extracted from hardware. + Counter uint32 +} + +type skECDSAPublicKey struct { + // application is a URL-like string, typically "ssh:" for SSH. + // see openssh/PROTOCOL.u2f for details. + application string + ecdsa.PublicKey +} + +func (k *skECDSAPublicKey) Type() string { + return KeyAlgoSKECDSA256 +} + +func (k *skECDSAPublicKey) nistID() string { + return "nistp256" +} + +func parseSKECDSA(in []byte) (out PublicKey, rest []byte, err error) { + var w struct { + Curve string + KeyBytes []byte + Application string + Rest []byte `ssh:"rest"` + } + + if err := Unmarshal(in, &w); err != nil { + return nil, nil, err + } + + key := new(skECDSAPublicKey) + key.application = w.Application + + if w.Curve != "nistp256" { + return nil, nil, errors.New("ssh: unsupported curve") + } + key.Curve = elliptic.P256() + + key.X, key.Y = elliptic.Unmarshal(key.Curve, w.KeyBytes) + if key.X == nil || key.Y == nil { + return nil, nil, errors.New("ssh: invalid curve point") + } + + return key, w.Rest, nil +} + +func (k *skECDSAPublicKey) Marshal() []byte { + // See RFC 5656, section 3.1. + keyBytes := elliptic.Marshal(k.Curve, k.X, k.Y) + w := struct { + Name string + ID string + Key []byte + Application string + }{ + k.Type(), + k.nistID(), + keyBytes, + k.application, + } + + return Marshal(&w) +} + +func (k *skECDSAPublicKey) Verify(data []byte, sig *Signature) error { + if sig.Format != k.Type() { + return fmt.Errorf("ssh: signature type %s for key type %s", sig.Format, k.Type()) + } + + h := ecHash(k.Curve).New() + h.Write([]byte(k.application)) + appDigest := h.Sum(nil) + + h.Reset() + h.Write(data) + dataDigest := h.Sum(nil) + + var ecSig struct { + R *big.Int + S *big.Int + } + if err := Unmarshal(sig.Blob, &ecSig); err != nil { + return err + } + + var skf skFields + if err := Unmarshal(sig.Rest, &skf); err != nil { + return err + } + + blob := struct { + ApplicationDigest []byte `ssh:"rest"` + Flags byte + Counter uint32 + MessageDigest []byte `ssh:"rest"` + }{ + appDigest, + skf.Flags, + skf.Counter, + dataDigest, + } + + original := Marshal(blob) + + h.Reset() + h.Write(original) + digest := h.Sum(nil) + + if ecdsa.Verify((*ecdsa.PublicKey)(&k.PublicKey), digest, ecSig.R, ecSig.S) { + return nil + } + return errors.New("ssh: signature did not verify") +} + +type skEd25519PublicKey struct { + // application is a URL-like string, typically "ssh:" for SSH. + // see openssh/PROTOCOL.u2f for details. + application string + ed25519.PublicKey +} + +func (k *skEd25519PublicKey) Type() string { + return KeyAlgoSKED25519 +} + +func parseSKEd25519(in []byte) (out PublicKey, rest []byte, err error) { + var w struct { + KeyBytes []byte + Application string + Rest []byte `ssh:"rest"` + } + + if err := Unmarshal(in, &w); err != nil { + return nil, nil, err + } + + if l := len(w.KeyBytes); l != ed25519.PublicKeySize { + return nil, nil, fmt.Errorf("invalid size %d for Ed25519 public key", l) + } + + key := new(skEd25519PublicKey) + key.application = w.Application + key.PublicKey = ed25519.PublicKey(w.KeyBytes) + + return key, w.Rest, nil +} + +func (k *skEd25519PublicKey) Marshal() []byte { + w := struct { + Name string + KeyBytes []byte + Application string + }{ + KeyAlgoSKED25519, + []byte(k.PublicKey), + k.application, + } + return Marshal(&w) +} + +func (k *skEd25519PublicKey) Verify(data []byte, sig *Signature) error { + if sig.Format != k.Type() { + return fmt.Errorf("ssh: signature type %s for key type %s", sig.Format, k.Type()) + } + if l := len(k.PublicKey); l != ed25519.PublicKeySize { + return fmt.Errorf("invalid size %d for Ed25519 public key", l) + } + + h := sha256.New() + h.Write([]byte(k.application)) + appDigest := h.Sum(nil) + + h.Reset() + h.Write(data) + dataDigest := h.Sum(nil) + + var edSig struct { + Signature []byte `ssh:"rest"` + } + + if err := Unmarshal(sig.Blob, &edSig); err != nil { + return err + } + + var skf skFields + if err := Unmarshal(sig.Rest, &skf); err != nil { + return err + } + + blob := struct { + ApplicationDigest []byte `ssh:"rest"` + Flags byte + Counter uint32 + MessageDigest []byte `ssh:"rest"` + }{ + appDigest, + skf.Flags, + skf.Counter, + dataDigest, + } + + original := Marshal(blob) + + if ok := ed25519.Verify(k.PublicKey, original, edSig.Signature); !ok { + return errors.New("ssh: signature did not verify") + } + + return nil +} + +// NewSignerFromKey takes an *rsa.PrivateKey, *dsa.PrivateKey, +// *ecdsa.PrivateKey or any other crypto.Signer and returns a +// corresponding Signer instance. ECDSA keys must use P-256, P-384 or +// P-521. DSA keys must use parameter size L1024N160. +func NewSignerFromKey(key interface{}) (Signer, error) { + switch key := key.(type) { + case crypto.Signer: + return NewSignerFromSigner(key) + case *dsa.PrivateKey: + return newDSAPrivateKey(key) + default: + return nil, fmt.Errorf("ssh: unsupported key type %T", key) + } +} + +type defaultAlgorithmSigner struct { + AlgorithmSigner + algorithm string +} + +func (s *defaultAlgorithmSigner) PublicKey() PublicKey { + return s.AlgorithmSigner.PublicKey() +} + +func (s *defaultAlgorithmSigner) Sign(rand io.Reader, data []byte) (*Signature, error) { + return s.AlgorithmSigner.SignWithAlgorithm(rand, data, s.algorithm) +} + +func (s *defaultAlgorithmSigner) SignWithAlgorithm(rand io.Reader, data []byte, algorithm string) (*Signature, error) { + return s.AlgorithmSigner.SignWithAlgorithm(rand, data, algorithm) +} + +func newDSAPrivateKey(key *dsa.PrivateKey) (Signer, error) { + if err := checkDSAParams(&key.PublicKey.Parameters); err != nil { + return nil, err + } + + return &dsaPrivateKey{key}, nil +} + +type wrappedSigner struct { + signer crypto.Signer + pubKey PublicKey +} + +// NewSignerFromSigner takes any crypto.Signer implementation and +// returns a corresponding Signer interface. This can be used, for +// example, with keys kept in hardware modules. +func NewSignerFromSigner(signer crypto.Signer) (Signer, error) { + pubKey, err := NewPublicKey(signer.Public()) + if err != nil { + return nil, err + } + + return &wrappedSigner{signer, pubKey}, nil +} + +func (s *wrappedSigner) PublicKey() PublicKey { + return s.pubKey +} + +func (s *wrappedSigner) Sign(rand io.Reader, data []byte) (*Signature, error) { + return s.SignWithAlgorithm(rand, data, "") +} + +func (s *wrappedSigner) SignWithAlgorithm(rand io.Reader, data []byte, algorithm string) (*Signature, error) { + var hashFunc crypto.Hash + + if _, ok := s.pubKey.(*rsaPublicKey); ok { + // RSA keys support a few hash functions determined by the requested signature algorithm + switch algorithm { + case "", SigAlgoRSA: + algorithm = SigAlgoRSA + hashFunc = crypto.SHA1 + case SigAlgoRSASHA2256: + hashFunc = crypto.SHA256 + case SigAlgoRSASHA2512: + hashFunc = crypto.SHA512 + default: + return nil, fmt.Errorf("ssh: unsupported signature algorithm %s", algorithm) + } + } else { + // The only supported algorithm for all other key types is the same as the type of the key + if algorithm == "" { + algorithm = s.pubKey.Type() + } else if algorithm != s.pubKey.Type() { + return nil, fmt.Errorf("ssh: unsupported signature algorithm %s", algorithm) + } + + switch key := s.pubKey.(type) { + case *dsaPublicKey: + hashFunc = crypto.SHA1 + case *ecdsaPublicKey: + hashFunc = ecHash(key.Curve) + case ed25519PublicKey: + default: + return nil, fmt.Errorf("ssh: unsupported key type %T", key) + } + } + + var digest []byte + if hashFunc != 0 { + h := hashFunc.New() + h.Write(data) + digest = h.Sum(nil) + } else { + digest = data + } + + signature, err := s.signer.Sign(rand, digest, hashFunc) + if err != nil { + return nil, err + } + + // crypto.Signer.Sign is expected to return an ASN.1-encoded signature + // for ECDSA and DSA, but that's not the encoding expected by SSH, so + // re-encode. + switch s.pubKey.(type) { + case *ecdsaPublicKey, *dsaPublicKey: + type asn1Signature struct { + R, S *big.Int + } + asn1Sig := new(asn1Signature) + _, err := asn1.Unmarshal(signature, asn1Sig) + if err != nil { + return nil, err + } + + switch s.pubKey.(type) { + case *ecdsaPublicKey: + signature = Marshal(asn1Sig) + + case *dsaPublicKey: + signature = make([]byte, 40) + r := asn1Sig.R.Bytes() + s := asn1Sig.S.Bytes() + copy(signature[20-len(r):20], r) + copy(signature[40-len(s):40], s) + } + } + + return &Signature{ + Format: algorithm, + Blob: signature, + }, nil +} + +// NewPublicKey takes an *rsa.PublicKey, *dsa.PublicKey, *ecdsa.PublicKey, +// or ed25519.PublicKey returns a corresponding PublicKey instance. +// ECDSA keys must use P-256, P-384 or P-521. +func NewPublicKey(key interface{}) (PublicKey, error) { + switch key := key.(type) { + case *rsa.PublicKey: + return (*rsaPublicKey)(key), nil + case *ecdsa.PublicKey: + if !supportedEllipticCurve(key.Curve) { + return nil, errors.New("ssh: only P-256, P-384 and P-521 EC keys are supported") + } + return (*ecdsaPublicKey)(key), nil + case *dsa.PublicKey: + return (*dsaPublicKey)(key), nil + case ed25519.PublicKey: + if l := len(key); l != ed25519.PublicKeySize { + return nil, fmt.Errorf("ssh: invalid size %d for Ed25519 public key", l) + } + return ed25519PublicKey(key), nil + default: + return nil, fmt.Errorf("ssh: unsupported key type %T", key) + } +} + +// ParsePrivateKey returns a Signer from a PEM encoded private key. It supports +// the same keys as ParseRawPrivateKey. If the private key is encrypted, it +// will return a PassphraseMissingError. +func ParsePrivateKey(pemBytes []byte) (Signer, error) { + key, err := ParseRawPrivateKey(pemBytes) + if err != nil { + return nil, err + } + + return NewSignerFromKey(key) +} + +// ParsePrivateKeyWithPassphrase returns a Signer from a PEM encoded private +// key and passphrase. It supports the same keys as +// ParseRawPrivateKeyWithPassphrase. +func ParsePrivateKeyWithPassphrase(pemBytes, passphrase []byte) (Signer, error) { + key, err := ParseRawPrivateKeyWithPassphrase(pemBytes, passphrase) + if err != nil { + return nil, err + } + + return NewSignerFromKey(key) +} + +// encryptedBlock tells whether a private key is +// encrypted by examining its Proc-Type header +// for a mention of ENCRYPTED +// according to RFC 1421 Section 4.6.1.1. +func encryptedBlock(block *pem.Block) bool { + return strings.Contains(block.Headers["Proc-Type"], "ENCRYPTED") +} + +// A PassphraseMissingError indicates that parsing this private key requires a +// passphrase. Use ParsePrivateKeyWithPassphrase. +type PassphraseMissingError struct { + // PublicKey will be set if the private key format includes an unencrypted + // public key along with the encrypted private key. + PublicKey PublicKey +} + +func (*PassphraseMissingError) Error() string { + return "ssh: this private key is passphrase protected" +} + +// ParseRawPrivateKey returns a private key from a PEM encoded private key. It +// supports RSA (PKCS#1), PKCS#8, DSA (OpenSSL), and ECDSA private keys. If the +// private key is encrypted, it will return a PassphraseMissingError. +func ParseRawPrivateKey(pemBytes []byte) (interface{}, error) { + block, _ := pem.Decode(pemBytes) + if block == nil { + return nil, errors.New("ssh: no key found") + } + + if encryptedBlock(block) { + return nil, &PassphraseMissingError{} + } + + switch block.Type { + case "RSA PRIVATE KEY": + return x509.ParsePKCS1PrivateKey(block.Bytes) + // RFC5208 - https://tools.ietf.org/html/rfc5208 + case "PRIVATE KEY": + return x509.ParsePKCS8PrivateKey(block.Bytes) + case "EC PRIVATE KEY": + return x509.ParseECPrivateKey(block.Bytes) + case "DSA PRIVATE KEY": + return ParseDSAPrivateKey(block.Bytes) + case "OPENSSH PRIVATE KEY": + return parseOpenSSHPrivateKey(block.Bytes, unencryptedOpenSSHKey) + default: + return nil, fmt.Errorf("ssh: unsupported key type %q", block.Type) + } +} + +// ParseRawPrivateKeyWithPassphrase returns a private key decrypted with +// passphrase from a PEM encoded private key. If the passphrase is wrong, it +// will return x509.IncorrectPasswordError. +func ParseRawPrivateKeyWithPassphrase(pemBytes, passphrase []byte) (interface{}, error) { + block, _ := pem.Decode(pemBytes) + if block == nil { + return nil, errors.New("ssh: no key found") + } + + if block.Type == "OPENSSH PRIVATE KEY" { + return parseOpenSSHPrivateKey(block.Bytes, passphraseProtectedOpenSSHKey(passphrase)) + } + + if !encryptedBlock(block) || !x509.IsEncryptedPEMBlock(block) { + return nil, errors.New("ssh: not an encrypted key") + } + + buf, err := x509.DecryptPEMBlock(block, passphrase) + if err != nil { + if err == x509.IncorrectPasswordError { + return nil, err + } + return nil, fmt.Errorf("ssh: cannot decode encrypted private keys: %v", err) + } + + switch block.Type { + case "RSA PRIVATE KEY": + return x509.ParsePKCS1PrivateKey(buf) + case "EC PRIVATE KEY": + return x509.ParseECPrivateKey(buf) + case "DSA PRIVATE KEY": + return ParseDSAPrivateKey(buf) + default: + return nil, fmt.Errorf("ssh: unsupported key type %q", block.Type) + } +} + +// ParseDSAPrivateKey returns a DSA private key from its ASN.1 DER encoding, as +// specified by the OpenSSL DSA man page. +func ParseDSAPrivateKey(der []byte) (*dsa.PrivateKey, error) { + var k struct { + Version int + P *big.Int + Q *big.Int + G *big.Int + Pub *big.Int + Priv *big.Int + } + rest, err := asn1.Unmarshal(der, &k) + if err != nil { + return nil, errors.New("ssh: failed to parse DSA key: " + err.Error()) + } + if len(rest) > 0 { + return nil, errors.New("ssh: garbage after DSA key") + } + + return &dsa.PrivateKey{ + PublicKey: dsa.PublicKey{ + Parameters: dsa.Parameters{ + P: k.P, + Q: k.Q, + G: k.G, + }, + Y: k.Pub, + }, + X: k.Priv, + }, nil +} + +func unencryptedOpenSSHKey(cipherName, kdfName, kdfOpts string, privKeyBlock []byte) ([]byte, error) { + if kdfName != "none" || cipherName != "none" { + return nil, &PassphraseMissingError{} + } + if kdfOpts != "" { + return nil, errors.New("ssh: invalid openssh private key") + } + return privKeyBlock, nil +} + +func passphraseProtectedOpenSSHKey(passphrase []byte) openSSHDecryptFunc { + return func(cipherName, kdfName, kdfOpts string, privKeyBlock []byte) ([]byte, error) { + if kdfName == "none" || cipherName == "none" { + return nil, errors.New("ssh: key is not password protected") + } + if kdfName != "bcrypt" { + return nil, fmt.Errorf("ssh: unknown KDF %q, only supports %q", kdfName, "bcrypt") + } + + var opts struct { + Salt string + Rounds uint32 + } + if err := Unmarshal([]byte(kdfOpts), &opts); err != nil { + return nil, err + } + + k, err := bcrypt_pbkdf.Key(passphrase, []byte(opts.Salt), int(opts.Rounds), 32+16) + if err != nil { + return nil, err + } + key, iv := k[:32], k[32:] + + c, err := aes.NewCipher(key) + if err != nil { + return nil, err + } + switch cipherName { + case "aes256-ctr": + ctr := cipher.NewCTR(c, iv) + ctr.XORKeyStream(privKeyBlock, privKeyBlock) + case "aes256-cbc": + if len(privKeyBlock)%c.BlockSize() != 0 { + return nil, fmt.Errorf("ssh: invalid encrypted private key length, not a multiple of the block size") + } + cbc := cipher.NewCBCDecrypter(c, iv) + cbc.CryptBlocks(privKeyBlock, privKeyBlock) + default: + return nil, fmt.Errorf("ssh: unknown cipher %q, only supports %q or %q", cipherName, "aes256-ctr", "aes256-cbc") + } + + return privKeyBlock, nil + } +} + +type openSSHDecryptFunc func(CipherName, KdfName, KdfOpts string, PrivKeyBlock []byte) ([]byte, error) + +// parseOpenSSHPrivateKey parses an OpenSSH private key, using the decrypt +// function to unwrap the encrypted portion. unencryptedOpenSSHKey can be used +// as the decrypt function to parse an unencrypted private key. See +// https://github.com/openssh/openssh-portable/blob/master/PROTOCOL.key. +func parseOpenSSHPrivateKey(key []byte, decrypt openSSHDecryptFunc) (crypto.PrivateKey, error) { + const magic = "openssh-key-v1\x00" + if len(key) < len(magic) || string(key[:len(magic)]) != magic { + return nil, errors.New("ssh: invalid openssh private key format") + } + remaining := key[len(magic):] + + var w struct { + CipherName string + KdfName string + KdfOpts string + NumKeys uint32 + PubKey []byte + PrivKeyBlock []byte + } + + if err := Unmarshal(remaining, &w); err != nil { + return nil, err + } + if w.NumKeys != 1 { + // We only support single key files, and so does OpenSSH. + // https://github.com/openssh/openssh-portable/blob/4103a3ec7/sshkey.c#L4171 + return nil, errors.New("ssh: multi-key files are not supported") + } + + privKeyBlock, err := decrypt(w.CipherName, w.KdfName, w.KdfOpts, w.PrivKeyBlock) + if err != nil { + if err, ok := err.(*PassphraseMissingError); ok { + pub, errPub := ParsePublicKey(w.PubKey) + if errPub != nil { + return nil, fmt.Errorf("ssh: failed to parse embedded public key: %v", errPub) + } + err.PublicKey = pub + } + return nil, err + } + + pk1 := struct { + Check1 uint32 + Check2 uint32 + Keytype string + Rest []byte `ssh:"rest"` + }{} + + if err := Unmarshal(privKeyBlock, &pk1); err != nil || pk1.Check1 != pk1.Check2 { + if w.CipherName != "none" { + return nil, x509.IncorrectPasswordError + } + return nil, errors.New("ssh: malformed OpenSSH key") + } + + switch pk1.Keytype { + case KeyAlgoRSA: + // https://github.com/openssh/openssh-portable/blob/master/sshkey.c#L2760-L2773 + key := struct { + N *big.Int + E *big.Int + D *big.Int + Iqmp *big.Int + P *big.Int + Q *big.Int + Comment string + Pad []byte `ssh:"rest"` + }{} + + if err := Unmarshal(pk1.Rest, &key); err != nil { + return nil, err + } + + if err := checkOpenSSHKeyPadding(key.Pad); err != nil { + return nil, err + } + + pk := &rsa.PrivateKey{ + PublicKey: rsa.PublicKey{ + N: key.N, + E: int(key.E.Int64()), + }, + D: key.D, + Primes: []*big.Int{key.P, key.Q}, + } + + if err := pk.Validate(); err != nil { + return nil, err + } + + pk.Precompute() + + return pk, nil + case KeyAlgoED25519: + key := struct { + Pub []byte + Priv []byte + Comment string + Pad []byte `ssh:"rest"` + }{} + + if err := Unmarshal(pk1.Rest, &key); err != nil { + return nil, err + } + + if len(key.Priv) != ed25519.PrivateKeySize { + return nil, errors.New("ssh: private key unexpected length") + } + + if err := checkOpenSSHKeyPadding(key.Pad); err != nil { + return nil, err + } + + pk := ed25519.PrivateKey(make([]byte, ed25519.PrivateKeySize)) + copy(pk, key.Priv) + return &pk, nil + case KeyAlgoECDSA256, KeyAlgoECDSA384, KeyAlgoECDSA521: + key := struct { + Curve string + Pub []byte + D *big.Int + Comment string + Pad []byte `ssh:"rest"` + }{} + + if err := Unmarshal(pk1.Rest, &key); err != nil { + return nil, err + } + + if err := checkOpenSSHKeyPadding(key.Pad); err != nil { + return nil, err + } + + var curve elliptic.Curve + switch key.Curve { + case "nistp256": + curve = elliptic.P256() + case "nistp384": + curve = elliptic.P384() + case "nistp521": + curve = elliptic.P521() + default: + return nil, errors.New("ssh: unhandled elliptic curve: " + key.Curve) + } + + X, Y := elliptic.Unmarshal(curve, key.Pub) + if X == nil || Y == nil { + return nil, errors.New("ssh: failed to unmarshal public key") + } + + if key.D.Cmp(curve.Params().N) >= 0 { + return nil, errors.New("ssh: scalar is out of range") + } + + x, y := curve.ScalarBaseMult(key.D.Bytes()) + if x.Cmp(X) != 0 || y.Cmp(Y) != 0 { + return nil, errors.New("ssh: public key does not match private key") + } + + return &ecdsa.PrivateKey{ + PublicKey: ecdsa.PublicKey{ + Curve: curve, + X: X, + Y: Y, + }, + D: key.D, + }, nil + default: + return nil, errors.New("ssh: unhandled key type") + } +} + +func checkOpenSSHKeyPadding(pad []byte) error { + for i, b := range pad { + if int(b) != i+1 { + return errors.New("ssh: padding not as expected") + } + } + return nil +} + +// FingerprintLegacyMD5 returns the user presentation of the key's +// fingerprint as described by RFC 4716 section 4. +func FingerprintLegacyMD5(pubKey PublicKey) string { + md5sum := md5.Sum(pubKey.Marshal()) + hexarray := make([]string, len(md5sum)) + for i, c := range md5sum { + hexarray[i] = hex.EncodeToString([]byte{c}) + } + return strings.Join(hexarray, ":") +} + +// FingerprintSHA256 returns the user presentation of the key's +// fingerprint as unpadded base64 encoded sha256 hash. +// This format was introduced from OpenSSH 6.8. +// https://www.openssh.com/txt/release-6.8 +// https://tools.ietf.org/html/rfc4648#section-3.2 (unpadded base64 encoding) +func FingerprintSHA256(pubKey PublicKey) string { + sha256sum := sha256.Sum256(pubKey.Marshal()) + hash := base64.RawStdEncoding.EncodeToString(sha256sum[:]) + return "SHA256:" + hash +} diff --git a/internal/crypto/ssh/mac.go b/internal/crypto/ssh/mac.go new file mode 100644 index 000000000..c07a06285 --- /dev/null +++ b/internal/crypto/ssh/mac.go @@ -0,0 +1,61 @@ +// Copyright 2012 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ssh + +// Message authentication support + +import ( + "crypto/hmac" + "crypto/sha1" + "crypto/sha256" + "hash" +) + +type macMode struct { + keySize int + etm bool + new func(key []byte) hash.Hash +} + +// truncatingMAC wraps around a hash.Hash and truncates the output digest to +// a given size. +type truncatingMAC struct { + length int + hmac hash.Hash +} + +func (t truncatingMAC) Write(data []byte) (int, error) { + return t.hmac.Write(data) +} + +func (t truncatingMAC) Sum(in []byte) []byte { + out := t.hmac.Sum(in) + return out[:len(in)+t.length] +} + +func (t truncatingMAC) Reset() { + t.hmac.Reset() +} + +func (t truncatingMAC) Size() int { + return t.length +} + +func (t truncatingMAC) BlockSize() int { return t.hmac.BlockSize() } + +var macModes = map[string]*macMode{ + "hmac-sha2-256-etm@openssh.com": {32, true, func(key []byte) hash.Hash { + return hmac.New(sha256.New, key) + }}, + "hmac-sha2-256": {32, false, func(key []byte) hash.Hash { + return hmac.New(sha256.New, key) + }}, + "hmac-sha1": {20, false, func(key []byte) hash.Hash { + return hmac.New(sha1.New, key) + }}, + "hmac-sha1-96": {20, false, func(key []byte) hash.Hash { + return truncatingMAC{12, hmac.New(sha1.New, key)} + }}, +} diff --git a/internal/crypto/ssh/messages.go b/internal/crypto/ssh/messages.go new file mode 100644 index 000000000..ac41a4168 --- /dev/null +++ b/internal/crypto/ssh/messages.go @@ -0,0 +1,866 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ssh + +import ( + "bytes" + "encoding/binary" + "errors" + "fmt" + "io" + "math/big" + "reflect" + "strconv" + "strings" +) + +// These are SSH message type numbers. They are scattered around several +// documents but many were taken from [SSH-PARAMETERS]. +const ( + msgIgnore = 2 + msgUnimplemented = 3 + msgDebug = 4 + msgNewKeys = 21 +) + +// SSH messages: +// +// These structures mirror the wire format of the corresponding SSH messages. +// They are marshaled using reflection with the marshal and unmarshal functions +// in this file. The only wrinkle is that a final member of type []byte with a +// ssh tag of "rest" receives the remainder of a packet when unmarshaling. + +// See RFC 4253, section 11.1. +const msgDisconnect = 1 + +// disconnectMsg is the message that signals a disconnect. It is also +// the error type returned from mux.Wait() +type disconnectMsg struct { + Reason uint32 `sshtype:"1"` + Message string + Language string +} + +func (d *disconnectMsg) Error() string { + return fmt.Sprintf("ssh: disconnect, reason %d: %s", d.Reason, d.Message) +} + +// See RFC 4253, section 7.1. +const msgKexInit = 20 + +type kexInitMsg struct { + Cookie [16]byte `sshtype:"20"` + KexAlgos []string + ServerHostKeyAlgos []string + CiphersClientServer []string + CiphersServerClient []string + MACsClientServer []string + MACsServerClient []string + CompressionClientServer []string + CompressionServerClient []string + LanguagesClientServer []string + LanguagesServerClient []string + FirstKexFollows bool + Reserved uint32 +} + +// See RFC 4253, section 8. + +// Diffie-Helman +const msgKexDHInit = 30 + +type kexDHInitMsg struct { + X *big.Int `sshtype:"30"` +} + +const msgKexECDHInit = 30 + +type kexECDHInitMsg struct { + ClientPubKey []byte `sshtype:"30"` +} + +const msgKexECDHReply = 31 + +type kexECDHReplyMsg struct { + HostKey []byte `sshtype:"31"` + EphemeralPubKey []byte + Signature []byte +} + +const msgKexDHReply = 31 + +type kexDHReplyMsg struct { + HostKey []byte `sshtype:"31"` + Y *big.Int + Signature []byte +} + +// See RFC 4419, section 5. +const msgKexDHGexGroup = 31 + +type kexDHGexGroupMsg struct { + P *big.Int `sshtype:"31"` + G *big.Int +} + +const msgKexDHGexInit = 32 + +type kexDHGexInitMsg struct { + X *big.Int `sshtype:"32"` +} + +const msgKexDHGexReply = 33 + +type kexDHGexReplyMsg struct { + HostKey []byte `sshtype:"33"` + Y *big.Int + Signature []byte +} + +const msgKexDHGexRequest = 34 + +type kexDHGexRequestMsg struct { + MinBits uint32 `sshtype:"34"` + PreferedBits uint32 + MaxBits uint32 +} + +// See RFC 4253, section 10. +const msgServiceRequest = 5 + +type serviceRequestMsg struct { + Service string `sshtype:"5"` +} + +// See RFC 4253, section 10. +const msgServiceAccept = 6 + +type serviceAcceptMsg struct { + Service string `sshtype:"6"` +} + +// See RFC 4252, section 5. +const msgUserAuthRequest = 50 + +type userAuthRequestMsg struct { + User string `sshtype:"50"` + Service string + Method string + Payload []byte `ssh:"rest"` +} + +// Used for debug printouts of packets. +type userAuthSuccessMsg struct { +} + +// See RFC 4252, section 5.1 +const msgUserAuthFailure = 51 + +type userAuthFailureMsg struct { + Methods []string `sshtype:"51"` + PartialSuccess bool +} + +// See RFC 4252, section 5.1 +const msgUserAuthSuccess = 52 + +// See RFC 4252, section 5.4 +const msgUserAuthBanner = 53 + +type userAuthBannerMsg struct { + Message string `sshtype:"53"` + // unused, but required to allow message parsing + Language string +} + +// See RFC 4256, section 3.2 +const msgUserAuthInfoRequest = 60 +const msgUserAuthInfoResponse = 61 + +type userAuthInfoRequestMsg struct { + User string `sshtype:"60"` + Instruction string + DeprecatedLanguage string + NumPrompts uint32 + Prompts []byte `ssh:"rest"` +} + +// See RFC 4254, section 5.1. +const msgChannelOpen = 90 + +type channelOpenMsg struct { + ChanType string `sshtype:"90"` + PeersID uint32 + PeersWindow uint32 + MaxPacketSize uint32 + TypeSpecificData []byte `ssh:"rest"` +} + +const msgChannelExtendedData = 95 +const msgChannelData = 94 + +// Used for debug print outs of packets. +type channelDataMsg struct { + PeersID uint32 `sshtype:"94"` + Length uint32 + Rest []byte `ssh:"rest"` +} + +// See RFC 4254, section 5.1. +const msgChannelOpenConfirm = 91 + +type channelOpenConfirmMsg struct { + PeersID uint32 `sshtype:"91"` + MyID uint32 + MyWindow uint32 + MaxPacketSize uint32 + TypeSpecificData []byte `ssh:"rest"` +} + +// See RFC 4254, section 5.1. +const msgChannelOpenFailure = 92 + +type channelOpenFailureMsg struct { + PeersID uint32 `sshtype:"92"` + Reason RejectionReason + Message string + Language string +} + +const msgChannelRequest = 98 + +type channelRequestMsg struct { + PeersID uint32 `sshtype:"98"` + Request string + WantReply bool + RequestSpecificData []byte `ssh:"rest"` +} + +// See RFC 4254, section 5.4. +const msgChannelSuccess = 99 + +type channelRequestSuccessMsg struct { + PeersID uint32 `sshtype:"99"` +} + +// See RFC 4254, section 5.4. +const msgChannelFailure = 100 + +type channelRequestFailureMsg struct { + PeersID uint32 `sshtype:"100"` +} + +// See RFC 4254, section 5.3 +const msgChannelClose = 97 + +type channelCloseMsg struct { + PeersID uint32 `sshtype:"97"` +} + +// See RFC 4254, section 5.3 +const msgChannelEOF = 96 + +type channelEOFMsg struct { + PeersID uint32 `sshtype:"96"` +} + +// See RFC 4254, section 4 +const msgGlobalRequest = 80 + +type globalRequestMsg struct { + Type string `sshtype:"80"` + WantReply bool + Data []byte `ssh:"rest"` +} + +// See RFC 4254, section 4 +const msgRequestSuccess = 81 + +type globalRequestSuccessMsg struct { + Data []byte `ssh:"rest" sshtype:"81"` +} + +// See RFC 4254, section 4 +const msgRequestFailure = 82 + +type globalRequestFailureMsg struct { + Data []byte `ssh:"rest" sshtype:"82"` +} + +// See RFC 4254, section 5.2 +const msgChannelWindowAdjust = 93 + +type windowAdjustMsg struct { + PeersID uint32 `sshtype:"93"` + AdditionalBytes uint32 +} + +// See RFC 4252, section 7 +const msgUserAuthPubKeyOk = 60 + +type userAuthPubKeyOkMsg struct { + Algo string `sshtype:"60"` + PubKey []byte +} + +// See RFC 4462, section 3 +const msgUserAuthGSSAPIResponse = 60 + +type userAuthGSSAPIResponse struct { + SupportMech []byte `sshtype:"60"` +} + +const msgUserAuthGSSAPIToken = 61 + +type userAuthGSSAPIToken struct { + Token []byte `sshtype:"61"` +} + +const msgUserAuthGSSAPIMIC = 66 + +type userAuthGSSAPIMIC struct { + MIC []byte `sshtype:"66"` +} + +// See RFC 4462, section 3.9 +const msgUserAuthGSSAPIErrTok = 64 + +type userAuthGSSAPIErrTok struct { + ErrorToken []byte `sshtype:"64"` +} + +// See RFC 4462, section 3.8 +const msgUserAuthGSSAPIError = 65 + +type userAuthGSSAPIError struct { + MajorStatus uint32 `sshtype:"65"` + MinorStatus uint32 + Message string + LanguageTag string +} + +// typeTags returns the possible type bytes for the given reflect.Type, which +// should be a struct. The possible values are separated by a '|' character. +func typeTags(structType reflect.Type) (tags []byte) { + tagStr := structType.Field(0).Tag.Get("sshtype") + + for _, tag := range strings.Split(tagStr, "|") { + i, err := strconv.Atoi(tag) + if err == nil { + tags = append(tags, byte(i)) + } + } + + return tags +} + +func fieldError(t reflect.Type, field int, problem string) error { + if problem != "" { + problem = ": " + problem + } + return fmt.Errorf("ssh: unmarshal error for field %s of type %s%s", t.Field(field).Name, t.Name(), problem) +} + +var errShortRead = errors.New("ssh: short read") + +// Unmarshal parses data in SSH wire format into a structure. The out +// argument should be a pointer to struct. If the first member of the +// struct has the "sshtype" tag set to a '|'-separated set of numbers +// in decimal, the packet must start with one of those numbers. In +// case of error, Unmarshal returns a ParseError or +// UnexpectedMessageError. +func Unmarshal(data []byte, out interface{}) error { + v := reflect.ValueOf(out).Elem() + structType := v.Type() + expectedTypes := typeTags(structType) + + var expectedType byte + if len(expectedTypes) > 0 { + expectedType = expectedTypes[0] + } + + if len(data) == 0 { + return parseError(expectedType) + } + + if len(expectedTypes) > 0 { + goodType := false + for _, e := range expectedTypes { + if e > 0 && data[0] == e { + goodType = true + break + } + } + if !goodType { + return fmt.Errorf("ssh: unexpected message type %d (expected one of %v)", data[0], expectedTypes) + } + data = data[1:] + } + + var ok bool + for i := 0; i < v.NumField(); i++ { + field := v.Field(i) + t := field.Type() + switch t.Kind() { + case reflect.Bool: + if len(data) < 1 { + return errShortRead + } + field.SetBool(data[0] != 0) + data = data[1:] + case reflect.Array: + if t.Elem().Kind() != reflect.Uint8 { + return fieldError(structType, i, "array of unsupported type") + } + if len(data) < t.Len() { + return errShortRead + } + for j, n := 0, t.Len(); j < n; j++ { + field.Index(j).Set(reflect.ValueOf(data[j])) + } + data = data[t.Len():] + case reflect.Uint64: + var u64 uint64 + if u64, data, ok = parseUint64(data); !ok { + return errShortRead + } + field.SetUint(u64) + case reflect.Uint32: + var u32 uint32 + if u32, data, ok = parseUint32(data); !ok { + return errShortRead + } + field.SetUint(uint64(u32)) + case reflect.Uint8: + if len(data) < 1 { + return errShortRead + } + field.SetUint(uint64(data[0])) + data = data[1:] + case reflect.String: + var s []byte + if s, data, ok = parseString(data); !ok { + return fieldError(structType, i, "") + } + field.SetString(string(s)) + case reflect.Slice: + switch t.Elem().Kind() { + case reflect.Uint8: + if structType.Field(i).Tag.Get("ssh") == "rest" { + field.Set(reflect.ValueOf(data)) + data = nil + } else { + var s []byte + if s, data, ok = parseString(data); !ok { + return errShortRead + } + field.Set(reflect.ValueOf(s)) + } + case reflect.String: + var nl []string + if nl, data, ok = parseNameList(data); !ok { + return errShortRead + } + field.Set(reflect.ValueOf(nl)) + default: + return fieldError(structType, i, "slice of unsupported type") + } + case reflect.Ptr: + if t == bigIntType { + var n *big.Int + if n, data, ok = parseInt(data); !ok { + return errShortRead + } + field.Set(reflect.ValueOf(n)) + } else { + return fieldError(structType, i, "pointer to unsupported type") + } + default: + return fieldError(structType, i, fmt.Sprintf("unsupported type: %v", t)) + } + } + + if len(data) != 0 { + return parseError(expectedType) + } + + return nil +} + +// Marshal serializes the message in msg to SSH wire format. The msg +// argument should be a struct or pointer to struct. If the first +// member has the "sshtype" tag set to a number in decimal, that +// number is prepended to the result. If the last of member has the +// "ssh" tag set to "rest", its contents are appended to the output. +func Marshal(msg interface{}) []byte { + out := make([]byte, 0, 64) + return marshalStruct(out, msg) +} + +func marshalStruct(out []byte, msg interface{}) []byte { + v := reflect.Indirect(reflect.ValueOf(msg)) + msgTypes := typeTags(v.Type()) + if len(msgTypes) > 0 { + out = append(out, msgTypes[0]) + } + + for i, n := 0, v.NumField(); i < n; i++ { + field := v.Field(i) + switch t := field.Type(); t.Kind() { + case reflect.Bool: + var v uint8 + if field.Bool() { + v = 1 + } + out = append(out, v) + case reflect.Array: + if t.Elem().Kind() != reflect.Uint8 { + panic(fmt.Sprintf("array of non-uint8 in field %d: %T", i, field.Interface())) + } + for j, l := 0, t.Len(); j < l; j++ { + out = append(out, uint8(field.Index(j).Uint())) + } + case reflect.Uint32: + out = appendU32(out, uint32(field.Uint())) + case reflect.Uint64: + out = appendU64(out, uint64(field.Uint())) + case reflect.Uint8: + out = append(out, uint8(field.Uint())) + case reflect.String: + s := field.String() + out = appendInt(out, len(s)) + out = append(out, s...) + case reflect.Slice: + switch t.Elem().Kind() { + case reflect.Uint8: + if v.Type().Field(i).Tag.Get("ssh") != "rest" { + out = appendInt(out, field.Len()) + } + out = append(out, field.Bytes()...) + case reflect.String: + offset := len(out) + out = appendU32(out, 0) + if n := field.Len(); n > 0 { + for j := 0; j < n; j++ { + f := field.Index(j) + if j != 0 { + out = append(out, ',') + } + out = append(out, f.String()...) + } + // overwrite length value + binary.BigEndian.PutUint32(out[offset:], uint32(len(out)-offset-4)) + } + default: + panic(fmt.Sprintf("slice of unknown type in field %d: %T", i, field.Interface())) + } + case reflect.Ptr: + if t == bigIntType { + var n *big.Int + nValue := reflect.ValueOf(&n) + nValue.Elem().Set(field) + needed := intLength(n) + oldLength := len(out) + + if cap(out)-len(out) < needed { + newOut := make([]byte, len(out), 2*(len(out)+needed)) + copy(newOut, out) + out = newOut + } + out = out[:oldLength+needed] + marshalInt(out[oldLength:], n) + } else { + panic(fmt.Sprintf("pointer to unknown type in field %d: %T", i, field.Interface())) + } + } + } + + return out +} + +var bigOne = big.NewInt(1) + +func parseString(in []byte) (out, rest []byte, ok bool) { + if len(in) < 4 { + return + } + length := binary.BigEndian.Uint32(in) + in = in[4:] + if uint32(len(in)) < length { + return + } + out = in[:length] + rest = in[length:] + ok = true + return +} + +var ( + comma = []byte{','} + emptyNameList = []string{} +) + +func parseNameList(in []byte) (out []string, rest []byte, ok bool) { + contents, rest, ok := parseString(in) + if !ok { + return + } + if len(contents) == 0 { + out = emptyNameList + return + } + parts := bytes.Split(contents, comma) + out = make([]string, len(parts)) + for i, part := range parts { + out[i] = string(part) + } + return +} + +func parseInt(in []byte) (out *big.Int, rest []byte, ok bool) { + contents, rest, ok := parseString(in) + if !ok { + return + } + out = new(big.Int) + + if len(contents) > 0 && contents[0]&0x80 == 0x80 { + // This is a negative number + notBytes := make([]byte, len(contents)) + for i := range notBytes { + notBytes[i] = ^contents[i] + } + out.SetBytes(notBytes) + out.Add(out, bigOne) + out.Neg(out) + } else { + // Positive number + out.SetBytes(contents) + } + ok = true + return +} + +func parseUint32(in []byte) (uint32, []byte, bool) { + if len(in) < 4 { + return 0, nil, false + } + return binary.BigEndian.Uint32(in), in[4:], true +} + +func parseUint64(in []byte) (uint64, []byte, bool) { + if len(in) < 8 { + return 0, nil, false + } + return binary.BigEndian.Uint64(in), in[8:], true +} + +func intLength(n *big.Int) int { + length := 4 /* length bytes */ + if n.Sign() < 0 { + nMinus1 := new(big.Int).Neg(n) + nMinus1.Sub(nMinus1, bigOne) + bitLen := nMinus1.BitLen() + if bitLen%8 == 0 { + // The number will need 0xff padding + length++ + } + length += (bitLen + 7) / 8 + } else if n.Sign() == 0 { + // A zero is the zero length string + } else { + bitLen := n.BitLen() + if bitLen%8 == 0 { + // The number will need 0x00 padding + length++ + } + length += (bitLen + 7) / 8 + } + + return length +} + +func marshalUint32(to []byte, n uint32) []byte { + binary.BigEndian.PutUint32(to, n) + return to[4:] +} + +func marshalUint64(to []byte, n uint64) []byte { + binary.BigEndian.PutUint64(to, n) + return to[8:] +} + +func marshalInt(to []byte, n *big.Int) []byte { + lengthBytes := to + to = to[4:] + length := 0 + + if n.Sign() < 0 { + // A negative number has to be converted to two's-complement + // form. So we'll subtract 1 and invert. If the + // most-significant-bit isn't set then we'll need to pad the + // beginning with 0xff in order to keep the number negative. + nMinus1 := new(big.Int).Neg(n) + nMinus1.Sub(nMinus1, bigOne) + bytes := nMinus1.Bytes() + for i := range bytes { + bytes[i] ^= 0xff + } + if len(bytes) == 0 || bytes[0]&0x80 == 0 { + to[0] = 0xff + to = to[1:] + length++ + } + nBytes := copy(to, bytes) + to = to[nBytes:] + length += nBytes + } else if n.Sign() == 0 { + // A zero is the zero length string + } else { + bytes := n.Bytes() + if len(bytes) > 0 && bytes[0]&0x80 != 0 { + // We'll have to pad this with a 0x00 in order to + // stop it looking like a negative number. + to[0] = 0 + to = to[1:] + length++ + } + nBytes := copy(to, bytes) + to = to[nBytes:] + length += nBytes + } + + lengthBytes[0] = byte(length >> 24) + lengthBytes[1] = byte(length >> 16) + lengthBytes[2] = byte(length >> 8) + lengthBytes[3] = byte(length) + return to +} + +func writeInt(w io.Writer, n *big.Int) { + length := intLength(n) + buf := make([]byte, length) + marshalInt(buf, n) + w.Write(buf) +} + +func writeString(w io.Writer, s []byte) { + var lengthBytes [4]byte + lengthBytes[0] = byte(len(s) >> 24) + lengthBytes[1] = byte(len(s) >> 16) + lengthBytes[2] = byte(len(s) >> 8) + lengthBytes[3] = byte(len(s)) + w.Write(lengthBytes[:]) + w.Write(s) +} + +func stringLength(n int) int { + return 4 + n +} + +func marshalString(to []byte, s []byte) []byte { + to[0] = byte(len(s) >> 24) + to[1] = byte(len(s) >> 16) + to[2] = byte(len(s) >> 8) + to[3] = byte(len(s)) + to = to[4:] + copy(to, s) + return to[len(s):] +} + +var bigIntType = reflect.TypeOf((*big.Int)(nil)) + +// Decode a packet into its corresponding message. +func decode(packet []byte) (interface{}, error) { + var msg interface{} + switch packet[0] { + case msgDisconnect: + msg = new(disconnectMsg) + case msgServiceRequest: + msg = new(serviceRequestMsg) + case msgServiceAccept: + msg = new(serviceAcceptMsg) + case msgKexInit: + msg = new(kexInitMsg) + case msgKexDHInit: + msg = new(kexDHInitMsg) + case msgKexDHReply: + msg = new(kexDHReplyMsg) + case msgUserAuthRequest: + msg = new(userAuthRequestMsg) + case msgUserAuthSuccess: + return new(userAuthSuccessMsg), nil + case msgUserAuthFailure: + msg = new(userAuthFailureMsg) + case msgUserAuthPubKeyOk: + msg = new(userAuthPubKeyOkMsg) + case msgGlobalRequest: + msg = new(globalRequestMsg) + case msgRequestSuccess: + msg = new(globalRequestSuccessMsg) + case msgRequestFailure: + msg = new(globalRequestFailureMsg) + case msgChannelOpen: + msg = new(channelOpenMsg) + case msgChannelData: + msg = new(channelDataMsg) + case msgChannelOpenConfirm: + msg = new(channelOpenConfirmMsg) + case msgChannelOpenFailure: + msg = new(channelOpenFailureMsg) + case msgChannelWindowAdjust: + msg = new(windowAdjustMsg) + case msgChannelEOF: + msg = new(channelEOFMsg) + case msgChannelClose: + msg = new(channelCloseMsg) + case msgChannelRequest: + msg = new(channelRequestMsg) + case msgChannelSuccess: + msg = new(channelRequestSuccessMsg) + case msgChannelFailure: + msg = new(channelRequestFailureMsg) + case msgUserAuthGSSAPIToken: + msg = new(userAuthGSSAPIToken) + case msgUserAuthGSSAPIMIC: + msg = new(userAuthGSSAPIMIC) + case msgUserAuthGSSAPIErrTok: + msg = new(userAuthGSSAPIErrTok) + case msgUserAuthGSSAPIError: + msg = new(userAuthGSSAPIError) + default: + return nil, unexpectedMessageError(0, packet[0]) + } + if err := Unmarshal(packet, msg); err != nil { + return nil, err + } + return msg, nil +} + +var packetTypeNames = map[byte]string{ + msgDisconnect: "disconnectMsg", + msgServiceRequest: "serviceRequestMsg", + msgServiceAccept: "serviceAcceptMsg", + msgKexInit: "kexInitMsg", + msgKexDHInit: "kexDHInitMsg", + msgKexDHReply: "kexDHReplyMsg", + msgUserAuthRequest: "userAuthRequestMsg", + msgUserAuthSuccess: "userAuthSuccessMsg", + msgUserAuthFailure: "userAuthFailureMsg", + msgUserAuthPubKeyOk: "userAuthPubKeyOkMsg", + msgGlobalRequest: "globalRequestMsg", + msgRequestSuccess: "globalRequestSuccessMsg", + msgRequestFailure: "globalRequestFailureMsg", + msgChannelOpen: "channelOpenMsg", + msgChannelData: "channelDataMsg", + msgChannelOpenConfirm: "channelOpenConfirmMsg", + msgChannelOpenFailure: "channelOpenFailureMsg", + msgChannelWindowAdjust: "windowAdjustMsg", + msgChannelEOF: "channelEOFMsg", + msgChannelClose: "channelCloseMsg", + msgChannelRequest: "channelRequestMsg", + msgChannelSuccess: "channelRequestSuccessMsg", + msgChannelFailure: "channelRequestFailureMsg", +} diff --git a/internal/crypto/ssh/mux.go b/internal/crypto/ssh/mux.go new file mode 100644 index 000000000..9654c0186 --- /dev/null +++ b/internal/crypto/ssh/mux.go @@ -0,0 +1,351 @@ +// Copyright 2013 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ssh + +import ( + "encoding/binary" + "fmt" + "io" + "log" + "sync" + "sync/atomic" +) + +// debugMux, if set, causes messages in the connection protocol to be +// logged. +const debugMux = false + +// chanList is a thread safe channel list. +type chanList struct { + // protects concurrent access to chans + sync.Mutex + + // chans are indexed by the local id of the channel, which the + // other side should send in the PeersId field. + chans []*channel + + // This is a debugging aid: it offsets all IDs by this + // amount. This helps distinguish otherwise identical + // server/client muxes + offset uint32 +} + +// Assigns a channel ID to the given channel. +func (c *chanList) add(ch *channel) uint32 { + c.Lock() + defer c.Unlock() + for i := range c.chans { + if c.chans[i] == nil { + c.chans[i] = ch + return uint32(i) + c.offset + } + } + c.chans = append(c.chans, ch) + return uint32(len(c.chans)-1) + c.offset +} + +// getChan returns the channel for the given ID. +func (c *chanList) getChan(id uint32) *channel { + id -= c.offset + + c.Lock() + defer c.Unlock() + if id < uint32(len(c.chans)) { + return c.chans[id] + } + return nil +} + +func (c *chanList) remove(id uint32) { + id -= c.offset + c.Lock() + if id < uint32(len(c.chans)) { + c.chans[id] = nil + } + c.Unlock() +} + +// dropAll forgets all channels it knows, returning them in a slice. +func (c *chanList) dropAll() []*channel { + c.Lock() + defer c.Unlock() + var r []*channel + + for _, ch := range c.chans { + if ch == nil { + continue + } + r = append(r, ch) + } + c.chans = nil + return r +} + +// mux represents the state for the SSH connection protocol, which +// multiplexes many channels onto a single packet transport. +type mux struct { + conn packetConn + chanList chanList + + incomingChannels chan NewChannel + + globalSentMu sync.Mutex + globalResponses chan interface{} + incomingRequests chan *Request + + errCond *sync.Cond + err error +} + +// When debugging, each new chanList instantiation has a different +// offset. +var globalOff uint32 + +func (m *mux) Wait() error { + m.errCond.L.Lock() + defer m.errCond.L.Unlock() + for m.err == nil { + m.errCond.Wait() + } + return m.err +} + +// newMux returns a mux that runs over the given connection. +func newMux(p packetConn) *mux { + m := &mux{ + conn: p, + incomingChannels: make(chan NewChannel, chanSize), + globalResponses: make(chan interface{}, 1), + incomingRequests: make(chan *Request, chanSize), + errCond: newCond(), + } + if debugMux { + m.chanList.offset = atomic.AddUint32(&globalOff, 1) + } + + go m.loop() + return m +} + +func (m *mux) sendMessage(msg interface{}) error { + p := Marshal(msg) + if debugMux { + log.Printf("send global(%d): %#v", m.chanList.offset, msg) + } + return m.conn.writePacket(p) +} + +func (m *mux) SendRequest(name string, wantReply bool, payload []byte) (bool, []byte, error) { + if wantReply { + m.globalSentMu.Lock() + defer m.globalSentMu.Unlock() + } + + if err := m.sendMessage(globalRequestMsg{ + Type: name, + WantReply: wantReply, + Data: payload, + }); err != nil { + return false, nil, err + } + + if !wantReply { + return false, nil, nil + } + + msg, ok := <-m.globalResponses + if !ok { + return false, nil, io.EOF + } + switch msg := msg.(type) { + case *globalRequestFailureMsg: + return false, msg.Data, nil + case *globalRequestSuccessMsg: + return true, msg.Data, nil + default: + return false, nil, fmt.Errorf("ssh: unexpected response to request: %#v", msg) + } +} + +// ackRequest must be called after processing a global request that +// has WantReply set. +func (m *mux) ackRequest(ok bool, data []byte) error { + if ok { + return m.sendMessage(globalRequestSuccessMsg{Data: data}) + } + return m.sendMessage(globalRequestFailureMsg{Data: data}) +} + +func (m *mux) Close() error { + return m.conn.Close() +} + +// loop runs the connection machine. It will process packets until an +// error is encountered. To synchronize on loop exit, use mux.Wait. +func (m *mux) loop() { + var err error + for err == nil { + err = m.onePacket() + } + + for _, ch := range m.chanList.dropAll() { + ch.close() + } + + close(m.incomingChannels) + close(m.incomingRequests) + close(m.globalResponses) + + m.conn.Close() + + m.errCond.L.Lock() + m.err = err + m.errCond.Broadcast() + m.errCond.L.Unlock() + + if debugMux { + log.Println("loop exit", err) + } +} + +// onePacket reads and processes one packet. +func (m *mux) onePacket() error { + packet, err := m.conn.readPacket() + if err != nil { + return err + } + + if debugMux { + if packet[0] == msgChannelData || packet[0] == msgChannelExtendedData { + log.Printf("decoding(%d): data packet - %d bytes", m.chanList.offset, len(packet)) + } else { + p, _ := decode(packet) + log.Printf("decoding(%d): %d %#v - %d bytes", m.chanList.offset, packet[0], p, len(packet)) + } + } + + switch packet[0] { + case msgChannelOpen: + return m.handleChannelOpen(packet) + case msgGlobalRequest, msgRequestSuccess, msgRequestFailure: + return m.handleGlobalPacket(packet) + } + + // assume a channel packet. + if len(packet) < 5 { + return parseError(packet[0]) + } + id := binary.BigEndian.Uint32(packet[1:]) + ch := m.chanList.getChan(id) + if ch == nil { + return m.handleUnknownChannelPacket(id, packet) + } + + return ch.handlePacket(packet) +} + +func (m *mux) handleGlobalPacket(packet []byte) error { + msg, err := decode(packet) + if err != nil { + return err + } + + switch msg := msg.(type) { + case *globalRequestMsg: + m.incomingRequests <- &Request{ + Type: msg.Type, + WantReply: msg.WantReply, + Payload: msg.Data, + mux: m, + } + case *globalRequestSuccessMsg, *globalRequestFailureMsg: + m.globalResponses <- msg + default: + panic(fmt.Sprintf("not a global message %#v", msg)) + } + + return nil +} + +// handleChannelOpen schedules a channel to be Accept()ed. +func (m *mux) handleChannelOpen(packet []byte) error { + var msg channelOpenMsg + if err := Unmarshal(packet, &msg); err != nil { + return err + } + + if msg.MaxPacketSize < minPacketLength || msg.MaxPacketSize > 1<<31 { + failMsg := channelOpenFailureMsg{ + PeersID: msg.PeersID, + Reason: ConnectionFailed, + Message: "invalid request", + Language: "en_US.UTF-8", + } + return m.sendMessage(failMsg) + } + + c := m.newChannel(msg.ChanType, channelInbound, msg.TypeSpecificData) + c.remoteId = msg.PeersID + c.maxRemotePayload = msg.MaxPacketSize + c.remoteWin.add(msg.PeersWindow) + m.incomingChannels <- c + return nil +} + +func (m *mux) OpenChannel(chanType string, extra []byte) (Channel, <-chan *Request, error) { + ch, err := m.openChannel(chanType, extra) + if err != nil { + return nil, nil, err + } + + return ch, ch.incomingRequests, nil +} + +func (m *mux) openChannel(chanType string, extra []byte) (*channel, error) { + ch := m.newChannel(chanType, channelOutbound, extra) + + ch.maxIncomingPayload = channelMaxPacket + + open := channelOpenMsg{ + ChanType: chanType, + PeersWindow: ch.myWindow, + MaxPacketSize: ch.maxIncomingPayload, + TypeSpecificData: extra, + PeersID: ch.localId, + } + if err := m.sendMessage(open); err != nil { + return nil, err + } + + switch msg := (<-ch.msg).(type) { + case *channelOpenConfirmMsg: + return ch, nil + case *channelOpenFailureMsg: + return nil, &OpenChannelError{msg.Reason, msg.Message} + default: + return nil, fmt.Errorf("ssh: unexpected packet in response to channel open: %T", msg) + } +} + +func (m *mux) handleUnknownChannelPacket(id uint32, packet []byte) error { + msg, err := decode(packet) + if err != nil { + return err + } + + switch msg := msg.(type) { + // RFC 4254 section 5.4 says unrecognized channel requests should + // receive a failure response. + case *channelRequestMsg: + if msg.WantReply { + return m.sendMessage(channelRequestFailureMsg{ + PeersID: msg.PeersID, + }) + } + return nil + default: + return fmt.Errorf("ssh: invalid channel %d", id) + } +} diff --git a/internal/crypto/ssh/server.go b/internal/crypto/ssh/server.go new file mode 100644 index 000000000..3cd0032e7 --- /dev/null +++ b/internal/crypto/ssh/server.go @@ -0,0 +1,743 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ssh + +import ( + "bytes" + "errors" + "fmt" + "io" + "net" + "strings" +) + +// The Permissions type holds fine-grained permissions that are +// specific to a user or a specific authentication method for a user. +// The Permissions value for a successful authentication attempt is +// available in ServerConn, so it can be used to pass information from +// the user-authentication phase to the application layer. +type Permissions struct { + // CriticalOptions indicate restrictions to the default + // permissions, and are typically used in conjunction with + // user certificates. The standard for SSH certificates + // defines "force-command" (only allow the given command to + // execute) and "source-address" (only allow connections from + // the given address). The SSH package currently only enforces + // the "source-address" critical option. It is up to server + // implementations to enforce other critical options, such as + // "force-command", by checking them after the SSH handshake + // is successful. In general, SSH servers should reject + // connections that specify critical options that are unknown + // or not supported. + CriticalOptions map[string]string + + // Extensions are extra functionality that the server may + // offer on authenticated connections. Lack of support for an + // extension does not preclude authenticating a user. Common + // extensions are "permit-agent-forwarding", + // "permit-X11-forwarding". The Go SSH library currently does + // not act on any extension, and it is up to server + // implementations to honor them. Extensions can be used to + // pass data from the authentication callbacks to the server + // application layer. + Extensions map[string]string +} + +type GSSAPIWithMICConfig struct { + // AllowLogin, must be set, is called when gssapi-with-mic + // authentication is selected (RFC 4462 section 3). The srcName is from the + // results of the GSS-API authentication. The format is username@DOMAIN. + // GSSAPI just guarantees to the server who the user is, but not if they can log in, and with what permissions. + // This callback is called after the user identity is established with GSSAPI to decide if the user can login with + // which permissions. If the user is allowed to login, it should return a nil error. + AllowLogin func(conn ConnMetadata, srcName string) (*Permissions, error) + + // Server must be set. It's the implementation + // of the GSSAPIServer interface. See GSSAPIServer interface for details. + Server GSSAPIServer +} + +// ServerConfig holds server specific configuration data. +type ServerConfig struct { + // Config contains configuration shared between client and server. + Config + + hostKeys map[string]Signer + + // NoClientAuth is true if clients are allowed to connect without + // authenticating. + NoClientAuth bool + + // MaxAuthTries specifies the maximum number of authentication attempts + // permitted per connection. If set to a negative number, the number of + // attempts are unlimited. If set to zero, the number of attempts are limited + // to 6. + MaxAuthTries int + + // PasswordCallback, if non-nil, is called when a user + // attempts to authenticate using a password. + PasswordCallback func(conn ConnMetadata, password []byte) (*Permissions, error) + + // PublicKeyCallback, if non-nil, is called when a client + // offers a public key for authentication. It must return a nil error + // if the given public key can be used to authenticate the + // given user. For example, see CertChecker.Authenticate. A + // call to this function does not guarantee that the key + // offered is in fact used to authenticate. To record any data + // depending on the public key, store it inside a + // Permissions.Extensions entry. + PublicKeyCallback func(conn ConnMetadata, key PublicKey) (*Permissions, error) + + // KeyboardInteractiveCallback, if non-nil, is called when + // keyboard-interactive authentication is selected (RFC + // 4256). The client object's Challenge function should be + // used to query the user. The callback may offer multiple + // Challenge rounds. To avoid information leaks, the client + // should be presented a challenge even if the user is + // unknown. + KeyboardInteractiveCallback func(conn ConnMetadata, client KeyboardInteractiveChallenge) (*Permissions, error) + + // AuthLogCallback, if non-nil, is called to log all authentication + // attempts. + AuthLogCallback func(conn ConnMetadata, method string, err error) + + // ServerVersion is the version identification string to announce in + // the public handshake. + // If empty, a reasonable default is used. + // Note that RFC 4253 section 4.2 requires that this string start with + // "SSH-2.0-". + ServerVersion string + + // BannerCallback, if present, is called and the return string is sent to + // the client after key exchange completed but before authentication. + BannerCallback func(conn ConnMetadata) string + + // GSSAPIWithMICConfig includes gssapi server and callback, which if both non-nil, is used + // when gssapi-with-mic authentication is selected (RFC 4462 section 3). + GSSAPIWithMICConfig *GSSAPIWithMICConfig +} + +// AddHostKey adds a private key as a host key. If an existing host +// key exists with the same algorithm, it is overwritten. Each server +// config must have at least one host key. +func (s *ServerConfig) AddHostKey(key Signer) { + if s.hostKeys == nil { + s.hostKeys = make(map[string]Signer) + } + + keyType := key.PublicKey().Type() + switch keyType { + case KeyAlgoRSA, KeyAlgoRSASHA2256, KeyAlgoRSASHA2512: + if algorithmSigner, ok := key.(AlgorithmSigner); ok { + s.hostKeys[KeyAlgoRSA] = &defaultAlgorithmSigner{ + algorithmSigner, SigAlgoRSA, + } + s.hostKeys[KeyAlgoRSASHA2256] = &defaultAlgorithmSigner{ + algorithmSigner, SigAlgoRSASHA2256, + } + s.hostKeys[KeyAlgoRSASHA2512] = &defaultAlgorithmSigner{ + algorithmSigner, SigAlgoRSASHA2512, + } + return + } + case CertAlgoRSAv01, CertAlgoRSASHA2256v01, CertAlgoRSASHA2512v01: + if algorithmSigner, ok := key.(AlgorithmSigner); ok { + s.hostKeys[CertAlgoRSAv01] = &defaultAlgorithmSigner{ + algorithmSigner, SigAlgoRSA, + } + s.hostKeys[CertAlgoRSASHA2256v01] = &defaultAlgorithmSigner{ + algorithmSigner, SigAlgoRSASHA2256, + } + s.hostKeys[CertAlgoRSASHA2512v01] = &defaultAlgorithmSigner{ + algorithmSigner, SigAlgoRSASHA2512, + } + return + } + } + s.hostKeys[keyType] = key + +} + +// cachedPubKey contains the results of querying whether a public key is +// acceptable for a user. +type cachedPubKey struct { + user string + pubKeyData []byte + result error + perms *Permissions +} + +const maxCachedPubKeys = 16 + +// pubKeyCache caches tests for public keys. Since SSH clients +// will query whether a public key is acceptable before attempting to +// authenticate with it, we end up with duplicate queries for public +// key validity. The cache only applies to a single ServerConn. +type pubKeyCache struct { + keys []cachedPubKey +} + +// get returns the result for a given user/algo/key tuple. +func (c *pubKeyCache) get(user string, pubKeyData []byte) (cachedPubKey, bool) { + for _, k := range c.keys { + if k.user == user && bytes.Equal(k.pubKeyData, pubKeyData) { + return k, true + } + } + return cachedPubKey{}, false +} + +// add adds the given tuple to the cache. +func (c *pubKeyCache) add(candidate cachedPubKey) { + if len(c.keys) < maxCachedPubKeys { + c.keys = append(c.keys, candidate) + } +} + +// ServerConn is an authenticated SSH connection, as seen from the +// server +type ServerConn struct { + Conn + + // If the succeeding authentication callback returned a + // non-nil Permissions pointer, it is stored here. + Permissions *Permissions +} + +// NewServerConn starts a new SSH server with c as the underlying +// transport. It starts with a handshake and, if the handshake is +// unsuccessful, it closes the connection and returns an error. The +// Request and NewChannel channels must be serviced, or the connection +// will hang. +// +// The returned error may be of type *ServerAuthError for +// authentication errors. +func NewServerConn(c net.Conn, config *ServerConfig) (*ServerConn, <-chan NewChannel, <-chan *Request, error) { + fullConf := *config + fullConf.SetDefaults() + if fullConf.MaxAuthTries == 0 { + fullConf.MaxAuthTries = 6 + } + // Check if the config contains any unsupported key exchanges + for _, kex := range fullConf.KeyExchanges { + if _, ok := serverForbiddenKexAlgos[kex]; ok { + return nil, nil, nil, fmt.Errorf("ssh: unsupported key exchange %s for server", kex) + } + } + + s := &connection{ + sshConn: sshConn{conn: c}, + } + perms, err := s.serverHandshake(&fullConf) + if err != nil { + c.Close() + return nil, nil, nil, err + } + return &ServerConn{s, perms}, s.mux.incomingChannels, s.mux.incomingRequests, nil +} + +// signAndMarshal signs the data with the appropriate algorithm, +// and serializes the result in SSH wire format. +func signAndMarshal(k Signer, rand io.Reader, data []byte) ([]byte, error) { + sig, err := k.Sign(rand, data) + if err != nil { + return nil, err + } + + return Marshal(sig), nil +} + +// handshake performs key exchange and user authentication. +func (s *connection) serverHandshake(config *ServerConfig) (*Permissions, error) { + if len(config.hostKeys) == 0 { + return nil, errors.New("ssh: server has no host keys") + } + + if !config.NoClientAuth && config.PasswordCallback == nil && config.PublicKeyCallback == nil && + config.KeyboardInteractiveCallback == nil && (config.GSSAPIWithMICConfig == nil || + config.GSSAPIWithMICConfig.AllowLogin == nil || config.GSSAPIWithMICConfig.Server == nil) { + return nil, errors.New("ssh: no authentication methods configured but NoClientAuth is also false") + } + + if config.ServerVersion != "" { + s.serverVersion = []byte(config.ServerVersion) + } else { + s.serverVersion = []byte(packageVersion) + } + var err error + s.clientVersion, err = exchangeVersions(s.sshConn.conn, s.serverVersion) + if err != nil { + return nil, err + } + + tr := newTransport(s.sshConn.conn, config.Rand, false /* not client */) + s.transport = newServerTransport(tr, s.clientVersion, s.serverVersion, config) + + if err := s.transport.waitSession(); err != nil { + return nil, err + } + + // We just did the key change, so the session ID is established. + s.sessionID = s.transport.getSessionID() + + var packet []byte + if packet, err = s.transport.readPacket(); err != nil { + return nil, err + } + + var serviceRequest serviceRequestMsg + if err = Unmarshal(packet, &serviceRequest); err != nil { + return nil, err + } + if serviceRequest.Service != serviceUserAuth { + return nil, errors.New("ssh: requested service '" + serviceRequest.Service + "' before authenticating") + } + serviceAccept := serviceAcceptMsg{ + Service: serviceUserAuth, + } + if err := s.transport.writePacket(Marshal(&serviceAccept)); err != nil { + return nil, err + } + + perms, err := s.serverAuthenticate(config) + if err != nil { + return nil, err + } + s.mux = newMux(s.transport) + return perms, err +} + +func isAcceptableAlgo(algo string) bool { + switch algo { + case KeyAlgoRSA, KeyAlgoDSA, KeyAlgoECDSA256, KeyAlgoECDSA384, KeyAlgoECDSA521, KeyAlgoSKECDSA256, KeyAlgoED25519, KeyAlgoSKED25519, + CertAlgoRSAv01, CertAlgoDSAv01, CertAlgoECDSA256v01, CertAlgoECDSA384v01, CertAlgoECDSA521v01, CertAlgoSKECDSA256v01, CertAlgoED25519v01, CertAlgoSKED25519v01: + return true + } + return false +} + +func checkSourceAddress(addr net.Addr, sourceAddrs string) error { + if addr == nil { + return errors.New("ssh: no address known for client, but source-address match required") + } + + tcpAddr, ok := addr.(*net.TCPAddr) + if !ok { + return fmt.Errorf("ssh: remote address %v is not an TCP address when checking source-address match", addr) + } + + for _, sourceAddr := range strings.Split(sourceAddrs, ",") { + if allowedIP := net.ParseIP(sourceAddr); allowedIP != nil { + if allowedIP.Equal(tcpAddr.IP) { + return nil + } + } else { + _, ipNet, err := net.ParseCIDR(sourceAddr) + if err != nil { + return fmt.Errorf("ssh: error parsing source-address restriction %q: %v", sourceAddr, err) + } + + if ipNet.Contains(tcpAddr.IP) { + return nil + } + } + } + + return fmt.Errorf("ssh: remote address %v is not allowed because of source-address restriction", addr) +} + +func gssExchangeToken(gssapiConfig *GSSAPIWithMICConfig, firstToken []byte, s *connection, + sessionID []byte, userAuthReq userAuthRequestMsg) (authErr error, perms *Permissions, err error) { + gssAPIServer := gssapiConfig.Server + defer gssAPIServer.DeleteSecContext() + var srcName string + for { + var ( + outToken []byte + needContinue bool + ) + outToken, srcName, needContinue, err = gssAPIServer.AcceptSecContext(firstToken) + if err != nil { + return err, nil, nil + } + if len(outToken) != 0 { + if err := s.transport.writePacket(Marshal(&userAuthGSSAPIToken{ + Token: outToken, + })); err != nil { + return nil, nil, err + } + } + if !needContinue { + break + } + packet, err := s.transport.readPacket() + if err != nil { + return nil, nil, err + } + userAuthGSSAPITokenReq := &userAuthGSSAPIToken{} + if err := Unmarshal(packet, userAuthGSSAPITokenReq); err != nil { + return nil, nil, err + } + } + packet, err := s.transport.readPacket() + if err != nil { + return nil, nil, err + } + userAuthGSSAPIMICReq := &userAuthGSSAPIMIC{} + if err := Unmarshal(packet, userAuthGSSAPIMICReq); err != nil { + return nil, nil, err + } + mic := buildMIC(string(sessionID), userAuthReq.User, userAuthReq.Service, userAuthReq.Method) + if err := gssAPIServer.VerifyMIC(mic, userAuthGSSAPIMICReq.MIC); err != nil { + return err, nil, nil + } + perms, authErr = gssapiConfig.AllowLogin(s, srcName) + return authErr, perms, nil +} + +// ServerAuthError represents server authentication errors and is +// sometimes returned by NewServerConn. It appends any authentication +// errors that may occur, and is returned if all of the authentication +// methods provided by the user failed to authenticate. +type ServerAuthError struct { + // Errors contains authentication errors returned by the authentication + // callback methods. The first entry is typically ErrNoAuth. + Errors []error +} + +func (l ServerAuthError) Error() string { + var errs []string + for _, err := range l.Errors { + errs = append(errs, err.Error()) + } + return "[" + strings.Join(errs, ", ") + "]" +} + +// ErrNoAuth is the error value returned if no +// authentication method has been passed yet. This happens as a normal +// part of the authentication loop, since the client first tries +// 'none' authentication to discover available methods. +// It is returned in ServerAuthError.Errors from NewServerConn. +var ErrNoAuth = errors.New("ssh: no auth passed yet") + +func (s *connection) serverAuthenticate(config *ServerConfig) (*Permissions, error) { + sessionID := s.transport.getSessionID() + var cache pubKeyCache + var perms *Permissions + + authFailures := 0 + var authErrs []error + var displayedBanner bool + +userAuthLoop: + for { + if authFailures >= config.MaxAuthTries && config.MaxAuthTries > 0 { + discMsg := &disconnectMsg{ + Reason: 2, + Message: "too many authentication failures", + } + + if err := s.transport.writePacket(Marshal(discMsg)); err != nil { + return nil, err + } + + return nil, discMsg + } + + var userAuthReq userAuthRequestMsg + if packet, err := s.transport.readPacket(); err != nil { + if err == io.EOF { + return nil, &ServerAuthError{Errors: authErrs} + } + return nil, err + } else if err = Unmarshal(packet, &userAuthReq); err != nil { + return nil, err + } + + if userAuthReq.Service != serviceSSH { + return nil, errors.New("ssh: client attempted to negotiate for unknown service: " + userAuthReq.Service) + } + + s.user = userAuthReq.User + + if !displayedBanner && config.BannerCallback != nil { + displayedBanner = true + msg := config.BannerCallback(s) + if msg != "" { + bannerMsg := &userAuthBannerMsg{ + Message: msg, + } + if err := s.transport.writePacket(Marshal(bannerMsg)); err != nil { + return nil, err + } + } + } + + perms = nil + authErr := ErrNoAuth + + switch userAuthReq.Method { + case "none": + if config.NoClientAuth { + authErr = nil + } + + // allow initial attempt of 'none' without penalty + if authFailures == 0 { + authFailures-- + } + case "password": + if config.PasswordCallback == nil { + authErr = errors.New("ssh: password auth not configured") + break + } + payload := userAuthReq.Payload + if len(payload) < 1 || payload[0] != 0 { + return nil, parseError(msgUserAuthRequest) + } + payload = payload[1:] + password, payload, ok := parseString(payload) + if !ok || len(payload) > 0 { + return nil, parseError(msgUserAuthRequest) + } + + perms, authErr = config.PasswordCallback(s, password) + case "keyboard-interactive": + if config.KeyboardInteractiveCallback == nil { + authErr = errors.New("ssh: keyboard-interactive auth not configured") + break + } + + prompter := &sshClientKeyboardInteractive{s} + perms, authErr = config.KeyboardInteractiveCallback(s, prompter.Challenge) + case "publickey": + if config.PublicKeyCallback == nil { + authErr = errors.New("ssh: publickey auth not configured") + break + } + payload := userAuthReq.Payload + if len(payload) < 1 { + return nil, parseError(msgUserAuthRequest) + } + isQuery := payload[0] == 0 + payload = payload[1:] + algoBytes, payload, ok := parseString(payload) + if !ok { + return nil, parseError(msgUserAuthRequest) + } + algo := string(algoBytes) + if !isAcceptableAlgo(algo) { + authErr = fmt.Errorf("ssh: algorithm %q not accepted", algo) + break + } + + pubKeyData, payload, ok := parseString(payload) + if !ok { + return nil, parseError(msgUserAuthRequest) + } + + pubKey, err := ParsePublicKey(pubKeyData) + if err != nil { + return nil, err + } + + candidate, ok := cache.get(s.user, pubKeyData) + if !ok { + candidate.user = s.user + candidate.pubKeyData = pubKeyData + candidate.perms, candidate.result = config.PublicKeyCallback(s, pubKey) + if candidate.result == nil && candidate.perms != nil && candidate.perms.CriticalOptions != nil && candidate.perms.CriticalOptions[sourceAddressCriticalOption] != "" { + candidate.result = checkSourceAddress( + s.RemoteAddr(), + candidate.perms.CriticalOptions[sourceAddressCriticalOption]) + } + cache.add(candidate) + } + + if isQuery { + // The client can query if the given public key + // would be okay. + + if len(payload) > 0 { + return nil, parseError(msgUserAuthRequest) + } + + if candidate.result == nil { + okMsg := userAuthPubKeyOkMsg{ + Algo: algo, + PubKey: pubKeyData, + } + if err = s.transport.writePacket(Marshal(&okMsg)); err != nil { + return nil, err + } + continue userAuthLoop + } + authErr = candidate.result + } else { + sig, payload, ok := parseSignature(payload) + if !ok || len(payload) > 0 { + return nil, parseError(msgUserAuthRequest) + } + // Ensure the public key algo and signature algo + // are supported. Compare the private key + // algorithm name that corresponds to algo with + // sig.Format. This is usually the same, but + // for certs, the names differ. + if !isAcceptableAlgo(sig.Format) { + authErr = fmt.Errorf("ssh: algorithm %q not accepted", sig.Format) + break + } + signedData := buildDataSignedForAuth(sessionID, userAuthReq, algoBytes, pubKeyData) + + if err := pubKey.Verify(signedData, sig); err != nil { + return nil, err + } + + authErr = candidate.result + perms = candidate.perms + } + case "gssapi-with-mic": + gssapiConfig := config.GSSAPIWithMICConfig + userAuthRequestGSSAPI, err := parseGSSAPIPayload(userAuthReq.Payload) + if err != nil { + return nil, parseError(msgUserAuthRequest) + } + // OpenSSH supports Kerberos V5 mechanism only for GSS-API authentication. + if userAuthRequestGSSAPI.N == 0 { + authErr = fmt.Errorf("ssh: Mechanism negotiation is not supported") + break + } + var i uint32 + present := false + for i = 0; i < userAuthRequestGSSAPI.N; i++ { + if userAuthRequestGSSAPI.OIDS[i].Equal(krb5Mesh) { + present = true + break + } + } + if !present { + authErr = fmt.Errorf("ssh: GSSAPI authentication must use the Kerberos V5 mechanism") + break + } + // Initial server response, see RFC 4462 section 3.3. + if err := s.transport.writePacket(Marshal(&userAuthGSSAPIResponse{ + SupportMech: krb5OID, + })); err != nil { + return nil, err + } + // Exchange token, see RFC 4462 section 3.4. + packet, err := s.transport.readPacket() + if err != nil { + return nil, err + } + userAuthGSSAPITokenReq := &userAuthGSSAPIToken{} + if err := Unmarshal(packet, userAuthGSSAPITokenReq); err != nil { + return nil, err + } + authErr, perms, err = gssExchangeToken(gssapiConfig, userAuthGSSAPITokenReq.Token, s, sessionID, + userAuthReq) + if err != nil { + return nil, err + } + default: + authErr = fmt.Errorf("ssh: unknown method %q", userAuthReq.Method) + } + + authErrs = append(authErrs, authErr) + + if config.AuthLogCallback != nil { + config.AuthLogCallback(s, userAuthReq.Method, authErr) + } + + if authErr == nil { + break userAuthLoop + } + + authFailures++ + + var failureMsg userAuthFailureMsg + if config.PasswordCallback != nil { + failureMsg.Methods = append(failureMsg.Methods, "password") + } + if config.PublicKeyCallback != nil { + failureMsg.Methods = append(failureMsg.Methods, "publickey") + } + if config.KeyboardInteractiveCallback != nil { + failureMsg.Methods = append(failureMsg.Methods, "keyboard-interactive") + } + if config.GSSAPIWithMICConfig != nil && config.GSSAPIWithMICConfig.Server != nil && + config.GSSAPIWithMICConfig.AllowLogin != nil { + failureMsg.Methods = append(failureMsg.Methods, "gssapi-with-mic") + } + + if len(failureMsg.Methods) == 0 { + return nil, errors.New("ssh: no authentication methods configured but NoClientAuth is also false") + } + + if err := s.transport.writePacket(Marshal(&failureMsg)); err != nil { + return nil, err + } + } + + if err := s.transport.writePacket([]byte{msgUserAuthSuccess}); err != nil { + return nil, err + } + return perms, nil +} + +// sshClientKeyboardInteractive implements a ClientKeyboardInteractive by +// asking the client on the other side of a ServerConn. +type sshClientKeyboardInteractive struct { + *connection +} + +func (c *sshClientKeyboardInteractive) Challenge(user, instruction string, questions []string, echos []bool) (answers []string, err error) { + if len(questions) != len(echos) { + return nil, errors.New("ssh: echos and questions must have equal length") + } + + var prompts []byte + for i := range questions { + prompts = appendString(prompts, questions[i]) + prompts = appendBool(prompts, echos[i]) + } + + if err := c.transport.writePacket(Marshal(&userAuthInfoRequestMsg{ + Instruction: instruction, + NumPrompts: uint32(len(questions)), + Prompts: prompts, + })); err != nil { + return nil, err + } + + packet, err := c.transport.readPacket() + if err != nil { + return nil, err + } + if packet[0] != msgUserAuthInfoResponse { + return nil, unexpectedMessageError(msgUserAuthInfoResponse, packet[0]) + } + packet = packet[1:] + + n, packet, ok := parseUint32(packet) + if !ok || int(n) != len(questions) { + return nil, parseError(msgUserAuthInfoResponse) + } + + for i := uint32(0); i < n; i++ { + ans, rest, ok := parseString(packet) + if !ok { + return nil, parseError(msgUserAuthInfoResponse) + } + + answers = append(answers, string(ans)) + packet = rest + } + if len(packet) != 0 { + return nil, errors.New("ssh: junk at end of message") + } + + return answers, nil +} diff --git a/internal/crypto/ssh/session.go b/internal/crypto/ssh/session.go new file mode 100644 index 000000000..d3321f6b7 --- /dev/null +++ b/internal/crypto/ssh/session.go @@ -0,0 +1,647 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ssh + +// Session implements an interactive session described in +// "RFC 4254, section 6". + +import ( + "bytes" + "encoding/binary" + "errors" + "fmt" + "io" + "io/ioutil" + "sync" +) + +type Signal string + +// POSIX signals as listed in RFC 4254 Section 6.10. +const ( + SIGABRT Signal = "ABRT" + SIGALRM Signal = "ALRM" + SIGFPE Signal = "FPE" + SIGHUP Signal = "HUP" + SIGILL Signal = "ILL" + SIGINT Signal = "INT" + SIGKILL Signal = "KILL" + SIGPIPE Signal = "PIPE" + SIGQUIT Signal = "QUIT" + SIGSEGV Signal = "SEGV" + SIGTERM Signal = "TERM" + SIGUSR1 Signal = "USR1" + SIGUSR2 Signal = "USR2" +) + +var signals = map[Signal]int{ + SIGABRT: 6, + SIGALRM: 14, + SIGFPE: 8, + SIGHUP: 1, + SIGILL: 4, + SIGINT: 2, + SIGKILL: 9, + SIGPIPE: 13, + SIGQUIT: 3, + SIGSEGV: 11, + SIGTERM: 15, +} + +type TerminalModes map[uint8]uint32 + +// POSIX terminal mode flags as listed in RFC 4254 Section 8. +const ( + tty_OP_END = 0 + VINTR = 1 + VQUIT = 2 + VERASE = 3 + VKILL = 4 + VEOF = 5 + VEOL = 6 + VEOL2 = 7 + VSTART = 8 + VSTOP = 9 + VSUSP = 10 + VDSUSP = 11 + VREPRINT = 12 + VWERASE = 13 + VLNEXT = 14 + VFLUSH = 15 + VSWTCH = 16 + VSTATUS = 17 + VDISCARD = 18 + IGNPAR = 30 + PARMRK = 31 + INPCK = 32 + ISTRIP = 33 + INLCR = 34 + IGNCR = 35 + ICRNL = 36 + IUCLC = 37 + IXON = 38 + IXANY = 39 + IXOFF = 40 + IMAXBEL = 41 + ISIG = 50 + ICANON = 51 + XCASE = 52 + ECHO = 53 + ECHOE = 54 + ECHOK = 55 + ECHONL = 56 + NOFLSH = 57 + TOSTOP = 58 + IEXTEN = 59 + ECHOCTL = 60 + ECHOKE = 61 + PENDIN = 62 + OPOST = 70 + OLCUC = 71 + ONLCR = 72 + OCRNL = 73 + ONOCR = 74 + ONLRET = 75 + CS7 = 90 + CS8 = 91 + PARENB = 92 + PARODD = 93 + TTY_OP_ISPEED = 128 + TTY_OP_OSPEED = 129 +) + +// A Session represents a connection to a remote command or shell. +type Session struct { + // Stdin specifies the remote process's standard input. + // If Stdin is nil, the remote process reads from an empty + // bytes.Buffer. + Stdin io.Reader + + // Stdout and Stderr specify the remote process's standard + // output and error. + // + // If either is nil, Run connects the corresponding file + // descriptor to an instance of ioutil.Discard. There is a + // fixed amount of buffering that is shared for the two streams. + // If either blocks it may eventually cause the remote + // command to block. + Stdout io.Writer + Stderr io.Writer + + ch Channel // the channel backing this session + started bool // true once Start, Run or Shell is invoked. + copyFuncs []func() error + errors chan error // one send per copyFunc + + // true if pipe method is active + stdinpipe, stdoutpipe, stderrpipe bool + + // stdinPipeWriter is non-nil if StdinPipe has not been called + // and Stdin was specified by the user; it is the write end of + // a pipe connecting Session.Stdin to the stdin channel. + stdinPipeWriter io.WriteCloser + + exitStatus chan error +} + +// SendRequest sends an out-of-band channel request on the SSH channel +// underlying the session. +func (s *Session) SendRequest(name string, wantReply bool, payload []byte) (bool, error) { + return s.ch.SendRequest(name, wantReply, payload) +} + +func (s *Session) Close() error { + return s.ch.Close() +} + +// RFC 4254 Section 6.4. +type setenvRequest struct { + Name string + Value string +} + +// Setenv sets an environment variable that will be applied to any +// command executed by Shell or Run. +func (s *Session) Setenv(name, value string) error { + msg := setenvRequest{ + Name: name, + Value: value, + } + ok, err := s.ch.SendRequest("env", true, Marshal(&msg)) + if err == nil && !ok { + err = errors.New("ssh: setenv failed") + } + return err +} + +// RFC 4254 Section 6.2. +type ptyRequestMsg struct { + Term string + Columns uint32 + Rows uint32 + Width uint32 + Height uint32 + Modelist string +} + +// RequestPty requests the association of a pty with the session on the remote host. +func (s *Session) RequestPty(term string, h, w int, termmodes TerminalModes) error { + var tm []byte + for k, v := range termmodes { + kv := struct { + Key byte + Val uint32 + }{k, v} + + tm = append(tm, Marshal(&kv)...) + } + tm = append(tm, tty_OP_END) + req := ptyRequestMsg{ + Term: term, + Columns: uint32(w), + Rows: uint32(h), + Width: uint32(w * 8), + Height: uint32(h * 8), + Modelist: string(tm), + } + ok, err := s.ch.SendRequest("pty-req", true, Marshal(&req)) + if err == nil && !ok { + err = errors.New("ssh: pty-req failed") + } + return err +} + +// RFC 4254 Section 6.5. +type subsystemRequestMsg struct { + Subsystem string +} + +// RequestSubsystem requests the association of a subsystem with the session on the remote host. +// A subsystem is a predefined command that runs in the background when the ssh session is initiated +func (s *Session) RequestSubsystem(subsystem string) error { + msg := subsystemRequestMsg{ + Subsystem: subsystem, + } + ok, err := s.ch.SendRequest("subsystem", true, Marshal(&msg)) + if err == nil && !ok { + err = errors.New("ssh: subsystem request failed") + } + return err +} + +// RFC 4254 Section 6.7. +type ptyWindowChangeMsg struct { + Columns uint32 + Rows uint32 + Width uint32 + Height uint32 +} + +// WindowChange informs the remote host about a terminal window dimension change to h rows and w columns. +func (s *Session) WindowChange(h, w int) error { + req := ptyWindowChangeMsg{ + Columns: uint32(w), + Rows: uint32(h), + Width: uint32(w * 8), + Height: uint32(h * 8), + } + _, err := s.ch.SendRequest("window-change", false, Marshal(&req)) + return err +} + +// RFC 4254 Section 6.9. +type signalMsg struct { + Signal string +} + +// Signal sends the given signal to the remote process. +// sig is one of the SIG* constants. +func (s *Session) Signal(sig Signal) error { + msg := signalMsg{ + Signal: string(sig), + } + + _, err := s.ch.SendRequest("signal", false, Marshal(&msg)) + return err +} + +// RFC 4254 Section 6.5. +type execMsg struct { + Command string +} + +// Start runs cmd on the remote host. Typically, the remote +// server passes cmd to the shell for interpretation. +// A Session only accepts one call to Run, Start or Shell. +func (s *Session) Start(cmd string) error { + if s.started { + return errors.New("ssh: session already started") + } + req := execMsg{ + Command: cmd, + } + + ok, err := s.ch.SendRequest("exec", true, Marshal(&req)) + if err == nil && !ok { + err = fmt.Errorf("ssh: command %v failed", cmd) + } + if err != nil { + return err + } + return s.start() +} + +// Run runs cmd on the remote host. Typically, the remote +// server passes cmd to the shell for interpretation. +// A Session only accepts one call to Run, Start, Shell, Output, +// or CombinedOutput. +// +// The returned error is nil if the command runs, has no problems +// copying stdin, stdout, and stderr, and exits with a zero exit +// status. +// +// If the remote server does not send an exit status, an error of type +// *ExitMissingError is returned. If the command completes +// unsuccessfully or is interrupted by a signal, the error is of type +// *ExitError. Other error types may be returned for I/O problems. +func (s *Session) Run(cmd string) error { + err := s.Start(cmd) + if err != nil { + return err + } + return s.Wait() +} + +// Output runs cmd on the remote host and returns its standard output. +func (s *Session) Output(cmd string) ([]byte, error) { + if s.Stdout != nil { + return nil, errors.New("ssh: Stdout already set") + } + var b bytes.Buffer + s.Stdout = &b + err := s.Run(cmd) + return b.Bytes(), err +} + +type singleWriter struct { + b bytes.Buffer + mu sync.Mutex +} + +func (w *singleWriter) Write(p []byte) (int, error) { + w.mu.Lock() + defer w.mu.Unlock() + return w.b.Write(p) +} + +// CombinedOutput runs cmd on the remote host and returns its combined +// standard output and standard error. +func (s *Session) CombinedOutput(cmd string) ([]byte, error) { + if s.Stdout != nil { + return nil, errors.New("ssh: Stdout already set") + } + if s.Stderr != nil { + return nil, errors.New("ssh: Stderr already set") + } + var b singleWriter + s.Stdout = &b + s.Stderr = &b + err := s.Run(cmd) + return b.b.Bytes(), err +} + +// Shell starts a login shell on the remote host. A Session only +// accepts one call to Run, Start, Shell, Output, or CombinedOutput. +func (s *Session) Shell() error { + if s.started { + return errors.New("ssh: session already started") + } + + ok, err := s.ch.SendRequest("shell", true, nil) + if err == nil && !ok { + return errors.New("ssh: could not start shell") + } + if err != nil { + return err + } + return s.start() +} + +func (s *Session) start() error { + s.started = true + + type F func(*Session) + for _, setupFd := range []F{(*Session).stdin, (*Session).stdout, (*Session).stderr} { + setupFd(s) + } + + s.errors = make(chan error, len(s.copyFuncs)) + for _, fn := range s.copyFuncs { + go func(fn func() error) { + s.errors <- fn() + }(fn) + } + return nil +} + +// Wait waits for the remote command to exit. +// +// The returned error is nil if the command runs, has no problems +// copying stdin, stdout, and stderr, and exits with a zero exit +// status. +// +// If the remote server does not send an exit status, an error of type +// *ExitMissingError is returned. If the command completes +// unsuccessfully or is interrupted by a signal, the error is of type +// *ExitError. Other error types may be returned for I/O problems. +func (s *Session) Wait() error { + if !s.started { + return errors.New("ssh: session not started") + } + waitErr := <-s.exitStatus + + if s.stdinPipeWriter != nil { + s.stdinPipeWriter.Close() + } + var copyError error + for range s.copyFuncs { + if err := <-s.errors; err != nil && copyError == nil { + copyError = err + } + } + if waitErr != nil { + return waitErr + } + return copyError +} + +func (s *Session) wait(reqs <-chan *Request) error { + wm := Waitmsg{status: -1} + // Wait for msg channel to be closed before returning. + for msg := range reqs { + switch msg.Type { + case "exit-status": + wm.status = int(binary.BigEndian.Uint32(msg.Payload)) + case "exit-signal": + var sigval struct { + Signal string + CoreDumped bool + Error string + Lang string + } + if err := Unmarshal(msg.Payload, &sigval); err != nil { + return err + } + + // Must sanitize strings? + wm.signal = sigval.Signal + wm.msg = sigval.Error + wm.lang = sigval.Lang + default: + // This handles keepalives and matches + // OpenSSH's behaviour. + if msg.WantReply { + msg.Reply(false, nil) + } + } + } + if wm.status == 0 { + return nil + } + if wm.status == -1 { + // exit-status was never sent from server + if wm.signal == "" { + // signal was not sent either. RFC 4254 + // section 6.10 recommends against this + // behavior, but it is allowed, so we let + // clients handle it. + return &ExitMissingError{} + } + wm.status = 128 + if _, ok := signals[Signal(wm.signal)]; ok { + wm.status += signals[Signal(wm.signal)] + } + } + + return &ExitError{wm} +} + +// ExitMissingError is returned if a session is torn down cleanly, but +// the server sends no confirmation of the exit status. +type ExitMissingError struct{} + +func (e *ExitMissingError) Error() string { + return "wait: remote command exited without exit status or exit signal" +} + +func (s *Session) stdin() { + if s.stdinpipe { + return + } + var stdin io.Reader + if s.Stdin == nil { + stdin = new(bytes.Buffer) + } else { + r, w := io.Pipe() + go func() { + _, err := io.Copy(w, s.Stdin) + w.CloseWithError(err) + }() + stdin, s.stdinPipeWriter = r, w + } + s.copyFuncs = append(s.copyFuncs, func() error { + _, err := io.Copy(s.ch, stdin) + if err1 := s.ch.CloseWrite(); err == nil && err1 != io.EOF { + err = err1 + } + return err + }) +} + +func (s *Session) stdout() { + if s.stdoutpipe { + return + } + if s.Stdout == nil { + s.Stdout = ioutil.Discard + } + s.copyFuncs = append(s.copyFuncs, func() error { + _, err := io.Copy(s.Stdout, s.ch) + return err + }) +} + +func (s *Session) stderr() { + if s.stderrpipe { + return + } + if s.Stderr == nil { + s.Stderr = ioutil.Discard + } + s.copyFuncs = append(s.copyFuncs, func() error { + _, err := io.Copy(s.Stderr, s.ch.Stderr()) + return err + }) +} + +// sessionStdin reroutes Close to CloseWrite. +type sessionStdin struct { + io.Writer + ch Channel +} + +func (s *sessionStdin) Close() error { + return s.ch.CloseWrite() +} + +// StdinPipe returns a pipe that will be connected to the +// remote command's standard input when the command starts. +func (s *Session) StdinPipe() (io.WriteCloser, error) { + if s.Stdin != nil { + return nil, errors.New("ssh: Stdin already set") + } + if s.started { + return nil, errors.New("ssh: StdinPipe after process started") + } + s.stdinpipe = true + return &sessionStdin{s.ch, s.ch}, nil +} + +// StdoutPipe returns a pipe that will be connected to the +// remote command's standard output when the command starts. +// There is a fixed amount of buffering that is shared between +// stdout and stderr streams. If the StdoutPipe reader is +// not serviced fast enough it may eventually cause the +// remote command to block. +func (s *Session) StdoutPipe() (io.Reader, error) { + if s.Stdout != nil { + return nil, errors.New("ssh: Stdout already set") + } + if s.started { + return nil, errors.New("ssh: StdoutPipe after process started") + } + s.stdoutpipe = true + return s.ch, nil +} + +// StderrPipe returns a pipe that will be connected to the +// remote command's standard error when the command starts. +// There is a fixed amount of buffering that is shared between +// stdout and stderr streams. If the StderrPipe reader is +// not serviced fast enough it may eventually cause the +// remote command to block. +func (s *Session) StderrPipe() (io.Reader, error) { + if s.Stderr != nil { + return nil, errors.New("ssh: Stderr already set") + } + if s.started { + return nil, errors.New("ssh: StderrPipe after process started") + } + s.stderrpipe = true + return s.ch.Stderr(), nil +} + +// newSession returns a new interactive session on the remote host. +func newSession(ch Channel, reqs <-chan *Request) (*Session, error) { + s := &Session{ + ch: ch, + } + s.exitStatus = make(chan error, 1) + go func() { + s.exitStatus <- s.wait(reqs) + }() + + return s, nil +} + +// An ExitError reports unsuccessful completion of a remote command. +type ExitError struct { + Waitmsg +} + +func (e *ExitError) Error() string { + return e.Waitmsg.String() +} + +// Waitmsg stores the information about an exited remote command +// as reported by Wait. +type Waitmsg struct { + status int + signal string + msg string + lang string +} + +// ExitStatus returns the exit status of the remote command. +func (w Waitmsg) ExitStatus() int { + return w.status +} + +// Signal returns the exit signal of the remote command if +// it was terminated violently. +func (w Waitmsg) Signal() string { + return w.signal +} + +// Msg returns the exit message given by the remote command +func (w Waitmsg) Msg() string { + return w.msg +} + +// Lang returns the language tag. See RFC 3066 +func (w Waitmsg) Lang() string { + return w.lang +} + +func (w Waitmsg) String() string { + str := fmt.Sprintf("Process exited with status %v", w.status) + if w.signal != "" { + str += fmt.Sprintf(" from signal %v", w.signal) + } + if w.msg != "" { + str += fmt.Sprintf(". Reason was: %v", w.msg) + } + return str +} diff --git a/internal/crypto/ssh/ssh_gss.go b/internal/crypto/ssh/ssh_gss.go new file mode 100644 index 000000000..24bd7c8e8 --- /dev/null +++ b/internal/crypto/ssh/ssh_gss.go @@ -0,0 +1,139 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ssh + +import ( + "encoding/asn1" + "errors" +) + +var krb5OID []byte + +func init() { + krb5OID, _ = asn1.Marshal(krb5Mesh) +} + +// GSSAPIClient provides the API to plug-in GSSAPI authentication for client logins. +type GSSAPIClient interface { + // InitSecContext initiates the establishment of a security context for GSS-API between the + // ssh client and ssh server. Initially the token parameter should be specified as nil. + // The routine may return a outputToken which should be transferred to + // the ssh server, where the ssh server will present it to + // AcceptSecContext. If no token need be sent, InitSecContext will indicate this by setting + // needContinue to false. To complete the context + // establishment, one or more reply tokens may be required from the ssh + // server;if so, InitSecContext will return a needContinue which is true. + // In this case, InitSecContext should be called again when the + // reply token is received from the ssh server, passing the reply + // token to InitSecContext via the token parameters. + // See RFC 2743 section 2.2.1 and RFC 4462 section 3.4. + InitSecContext(target string, token []byte, isGSSDelegCreds bool) (outputToken []byte, needContinue bool, err error) + // GetMIC generates a cryptographic MIC for the SSH2 message, and places + // the MIC in a token for transfer to the ssh server. + // The contents of the MIC field are obtained by calling GSS_GetMIC() + // over the following, using the GSS-API context that was just + // established: + // string session identifier + // byte SSH_MSG_USERAUTH_REQUEST + // string user name + // string service + // string "gssapi-with-mic" + // See RFC 2743 section 2.3.1 and RFC 4462 3.5. + GetMIC(micFiled []byte) ([]byte, error) + // Whenever possible, it should be possible for + // DeleteSecContext() calls to be successfully processed even + // if other calls cannot succeed, thereby enabling context-related + // resources to be released. + // In addition to deleting established security contexts, + // gss_delete_sec_context must also be able to delete "half-built" + // security contexts resulting from an incomplete sequence of + // InitSecContext()/AcceptSecContext() calls. + // See RFC 2743 section 2.2.3. + DeleteSecContext() error +} + +// GSSAPIServer provides the API to plug in GSSAPI authentication for server logins. +type GSSAPIServer interface { + // AcceptSecContext allows a remotely initiated security context between the application + // and a remote peer to be established by the ssh client. The routine may return a + // outputToken which should be transferred to the ssh client, + // where the ssh client will present it to InitSecContext. + // If no token need be sent, AcceptSecContext will indicate this + // by setting the needContinue to false. To + // complete the context establishment, one or more reply tokens may be + // required from the ssh client. if so, AcceptSecContext + // will return a needContinue which is true, in which case it + // should be called again when the reply token is received from the ssh + // client, passing the token to AcceptSecContext via the + // token parameters. + // The srcName return value is the authenticated username. + // See RFC 2743 section 2.2.2 and RFC 4462 section 3.4. + AcceptSecContext(token []byte) (outputToken []byte, srcName string, needContinue bool, err error) + // VerifyMIC verifies that a cryptographic MIC, contained in the token parameter, + // fits the supplied message is received from the ssh client. + // See RFC 2743 section 2.3.2. + VerifyMIC(micField []byte, micToken []byte) error + // Whenever possible, it should be possible for + // DeleteSecContext() calls to be successfully processed even + // if other calls cannot succeed, thereby enabling context-related + // resources to be released. + // In addition to deleting established security contexts, + // gss_delete_sec_context must also be able to delete "half-built" + // security contexts resulting from an incomplete sequence of + // InitSecContext()/AcceptSecContext() calls. + // See RFC 2743 section 2.2.3. + DeleteSecContext() error +} + +var ( + // OpenSSH supports Kerberos V5 mechanism only for GSS-API authentication, + // so we also support the krb5 mechanism only. + // See RFC 1964 section 1. + krb5Mesh = asn1.ObjectIdentifier{1, 2, 840, 113554, 1, 2, 2} +) + +// The GSS-API authentication method is initiated when the client sends an SSH_MSG_USERAUTH_REQUEST +// See RFC 4462 section 3.2. +type userAuthRequestGSSAPI struct { + N uint32 + OIDS []asn1.ObjectIdentifier +} + +func parseGSSAPIPayload(payload []byte) (*userAuthRequestGSSAPI, error) { + n, rest, ok := parseUint32(payload) + if !ok { + return nil, errors.New("parse uint32 failed") + } + s := &userAuthRequestGSSAPI{ + N: n, + OIDS: make([]asn1.ObjectIdentifier, n), + } + for i := 0; i < int(n); i++ { + var ( + desiredMech []byte + err error + ) + desiredMech, rest, ok = parseString(rest) + if !ok { + return nil, errors.New("parse string failed") + } + if rest, err = asn1.Unmarshal(desiredMech, &s.OIDS[i]); err != nil { + return nil, err + } + + } + return s, nil +} + +// See RFC 4462 section 3.6. +func buildMIC(sessionID string, username string, service string, authMethod string) []byte { + out := make([]byte, 0, 0) + out = appendString(out, sessionID) + out = append(out, msgUserAuthRequest) + out = appendString(out, username) + out = appendString(out, service) + out = appendString(out, authMethod) + return out +} diff --git a/internal/crypto/ssh/streamlocal.go b/internal/crypto/ssh/streamlocal.go new file mode 100644 index 000000000..b171b330b --- /dev/null +++ b/internal/crypto/ssh/streamlocal.go @@ -0,0 +1,116 @@ +package ssh + +import ( + "errors" + "io" + "net" +) + +// streamLocalChannelOpenDirectMsg is a struct used for SSH_MSG_CHANNEL_OPEN message +// with "direct-streamlocal@openssh.com" string. +// +// See openssh-portable/PROTOCOL, section 2.4. connection: Unix domain socket forwarding +// https://github.com/openssh/openssh-portable/blob/master/PROTOCOL#L235 +type streamLocalChannelOpenDirectMsg struct { + socketPath string + reserved0 string + reserved1 uint32 +} + +// forwardedStreamLocalPayload is a struct used for SSH_MSG_CHANNEL_OPEN message +// with "forwarded-streamlocal@openssh.com" string. +type forwardedStreamLocalPayload struct { + SocketPath string + Reserved0 string +} + +// streamLocalChannelForwardMsg is a struct used for SSH2_MSG_GLOBAL_REQUEST message +// with "streamlocal-forward@openssh.com"/"cancel-streamlocal-forward@openssh.com" string. +type streamLocalChannelForwardMsg struct { + socketPath string +} + +// ListenUnix is similar to ListenTCP but uses a Unix domain socket. +func (c *Client) ListenUnix(socketPath string) (net.Listener, error) { + c.handleForwardsOnce.Do(c.handleForwards) + m := streamLocalChannelForwardMsg{ + socketPath, + } + // send message + ok, _, err := c.SendRequest("streamlocal-forward@openssh.com", true, Marshal(&m)) + if err != nil { + return nil, err + } + if !ok { + return nil, errors.New("ssh: streamlocal-forward@openssh.com request denied by peer") + } + ch := c.forwards.add(&net.UnixAddr{Name: socketPath, Net: "unix"}) + + return &unixListener{socketPath, c, ch}, nil +} + +func (c *Client) dialStreamLocal(socketPath string) (Channel, error) { + msg := streamLocalChannelOpenDirectMsg{ + socketPath: socketPath, + } + ch, in, err := c.OpenChannel("direct-streamlocal@openssh.com", Marshal(&msg)) + if err != nil { + return nil, err + } + go DiscardRequests(in) + return ch, err +} + +type unixListener struct { + socketPath string + + conn *Client + in <-chan forward +} + +// Accept waits for and returns the next connection to the listener. +func (l *unixListener) Accept() (net.Conn, error) { + s, ok := <-l.in + if !ok { + return nil, io.EOF + } + ch, incoming, err := s.newCh.Accept() + if err != nil { + return nil, err + } + go DiscardRequests(incoming) + + return &chanConn{ + Channel: ch, + laddr: &net.UnixAddr{ + Name: l.socketPath, + Net: "unix", + }, + raddr: &net.UnixAddr{ + Name: "@", + Net: "unix", + }, + }, nil +} + +// Close closes the listener. +func (l *unixListener) Close() error { + // this also closes the listener. + l.conn.forwards.remove(&net.UnixAddr{Name: l.socketPath, Net: "unix"}) + m := streamLocalChannelForwardMsg{ + l.socketPath, + } + ok, _, err := l.conn.SendRequest("cancel-streamlocal-forward@openssh.com", true, Marshal(&m)) + if err == nil && !ok { + err = errors.New("ssh: cancel-streamlocal-forward@openssh.com failed") + } + return err +} + +// Addr returns the listener's network address. +func (l *unixListener) Addr() net.Addr { + return &net.UnixAddr{ + Name: l.socketPath, + Net: "unix", + } +} diff --git a/internal/crypto/ssh/tcpip.go b/internal/crypto/ssh/tcpip.go new file mode 100644 index 000000000..80d35f5ec --- /dev/null +++ b/internal/crypto/ssh/tcpip.go @@ -0,0 +1,474 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ssh + +import ( + "errors" + "fmt" + "io" + "math/rand" + "net" + "strconv" + "strings" + "sync" + "time" +) + +// Listen requests the remote peer open a listening socket on +// addr. Incoming connections will be available by calling Accept on +// the returned net.Listener. The listener must be serviced, or the +// SSH connection may hang. +// N must be "tcp", "tcp4", "tcp6", or "unix". +func (c *Client) Listen(n, addr string) (net.Listener, error) { + switch n { + case "tcp", "tcp4", "tcp6": + laddr, err := net.ResolveTCPAddr(n, addr) + if err != nil { + return nil, err + } + return c.ListenTCP(laddr) + case "unix": + return c.ListenUnix(addr) + default: + return nil, fmt.Errorf("ssh: unsupported protocol: %s", n) + } +} + +// Automatic port allocation is broken with OpenSSH before 6.0. See +// also https://bugzilla.mindrot.org/show_bug.cgi?id=2017. In +// particular, OpenSSH 5.9 sends a channelOpenMsg with port number 0, +// rather than the actual port number. This means you can never open +// two different listeners with auto allocated ports. We work around +// this by trying explicit ports until we succeed. + +const openSSHPrefix = "OpenSSH_" + +var portRandomizer = rand.New(rand.NewSource(time.Now().UnixNano())) + +// isBrokenOpenSSHVersion returns true if the given version string +// specifies a version of OpenSSH that is known to have a bug in port +// forwarding. +func isBrokenOpenSSHVersion(versionStr string) bool { + i := strings.Index(versionStr, openSSHPrefix) + if i < 0 { + return false + } + i += len(openSSHPrefix) + j := i + for ; j < len(versionStr); j++ { + if versionStr[j] < '0' || versionStr[j] > '9' { + break + } + } + version, _ := strconv.Atoi(versionStr[i:j]) + return version < 6 +} + +// autoPortListenWorkaround simulates automatic port allocation by +// trying random ports repeatedly. +func (c *Client) autoPortListenWorkaround(laddr *net.TCPAddr) (net.Listener, error) { + var sshListener net.Listener + var err error + const tries = 10 + for i := 0; i < tries; i++ { + addr := *laddr + addr.Port = 1024 + portRandomizer.Intn(60000) + sshListener, err = c.ListenTCP(&addr) + if err == nil { + laddr.Port = addr.Port + return sshListener, err + } + } + return nil, fmt.Errorf("ssh: listen on random port failed after %d tries: %v", tries, err) +} + +// RFC 4254 7.1 +type channelForwardMsg struct { + addr string + rport uint32 +} + +// handleForwards starts goroutines handling forwarded connections. +// It's called on first use by (*Client).ListenTCP to not launch +// goroutines until needed. +func (c *Client) handleForwards() { + go c.forwards.handleChannels(c.HandleChannelOpen("forwarded-tcpip")) + go c.forwards.handleChannels(c.HandleChannelOpen("forwarded-streamlocal@openssh.com")) +} + +// ListenTCP requests the remote peer open a listening socket +// on laddr. Incoming connections will be available by calling +// Accept on the returned net.Listener. +func (c *Client) ListenTCP(laddr *net.TCPAddr) (net.Listener, error) { + c.handleForwardsOnce.Do(c.handleForwards) + if laddr.Port == 0 && isBrokenOpenSSHVersion(string(c.ServerVersion())) { + return c.autoPortListenWorkaround(laddr) + } + + m := channelForwardMsg{ + laddr.IP.String(), + uint32(laddr.Port), + } + // send message + ok, resp, err := c.SendRequest("tcpip-forward", true, Marshal(&m)) + if err != nil { + return nil, err + } + if !ok { + return nil, errors.New("ssh: tcpip-forward request denied by peer") + } + + // If the original port was 0, then the remote side will + // supply a real port number in the response. + if laddr.Port == 0 { + var p struct { + Port uint32 + } + if err := Unmarshal(resp, &p); err != nil { + return nil, err + } + laddr.Port = int(p.Port) + } + + // Register this forward, using the port number we obtained. + ch := c.forwards.add(laddr) + + return &tcpListener{laddr, c, ch}, nil +} + +// forwardList stores a mapping between remote +// forward requests and the tcpListeners. +type forwardList struct { + sync.Mutex + entries []forwardEntry +} + +// forwardEntry represents an established mapping of a laddr on a +// remote ssh server to a channel connected to a tcpListener. +type forwardEntry struct { + laddr net.Addr + c chan forward +} + +// forward represents an incoming forwarded tcpip connection. The +// arguments to add/remove/lookup should be address as specified in +// the original forward-request. +type forward struct { + newCh NewChannel // the ssh client channel underlying this forward + raddr net.Addr // the raddr of the incoming connection +} + +func (l *forwardList) add(addr net.Addr) chan forward { + l.Lock() + defer l.Unlock() + f := forwardEntry{ + laddr: addr, + c: make(chan forward, 1), + } + l.entries = append(l.entries, f) + return f.c +} + +// See RFC 4254, section 7.2 +type forwardedTCPPayload struct { + Addr string + Port uint32 + OriginAddr string + OriginPort uint32 +} + +// parseTCPAddr parses the originating address from the remote into a *net.TCPAddr. +func parseTCPAddr(addr string, port uint32) (*net.TCPAddr, error) { + if port == 0 || port > 65535 { + return nil, fmt.Errorf("ssh: port number out of range: %d", port) + } + ip := net.ParseIP(string(addr)) + if ip == nil { + return nil, fmt.Errorf("ssh: cannot parse IP address %q", addr) + } + return &net.TCPAddr{IP: ip, Port: int(port)}, nil +} + +func (l *forwardList) handleChannels(in <-chan NewChannel) { + for ch := range in { + var ( + laddr net.Addr + raddr net.Addr + err error + ) + switch channelType := ch.ChannelType(); channelType { + case "forwarded-tcpip": + var payload forwardedTCPPayload + if err = Unmarshal(ch.ExtraData(), &payload); err != nil { + ch.Reject(ConnectionFailed, "could not parse forwarded-tcpip payload: "+err.Error()) + continue + } + + // RFC 4254 section 7.2 specifies that incoming + // addresses should list the address, in string + // format. It is implied that this should be an IP + // address, as it would be impossible to connect to it + // otherwise. + laddr, err = parseTCPAddr(payload.Addr, payload.Port) + if err != nil { + ch.Reject(ConnectionFailed, err.Error()) + continue + } + raddr, err = parseTCPAddr(payload.OriginAddr, payload.OriginPort) + if err != nil { + ch.Reject(ConnectionFailed, err.Error()) + continue + } + + case "forwarded-streamlocal@openssh.com": + var payload forwardedStreamLocalPayload + if err = Unmarshal(ch.ExtraData(), &payload); err != nil { + ch.Reject(ConnectionFailed, "could not parse forwarded-streamlocal@openssh.com payload: "+err.Error()) + continue + } + laddr = &net.UnixAddr{ + Name: payload.SocketPath, + Net: "unix", + } + raddr = &net.UnixAddr{ + Name: "@", + Net: "unix", + } + default: + panic(fmt.Errorf("ssh: unknown channel type %s", channelType)) + } + if ok := l.forward(laddr, raddr, ch); !ok { + // Section 7.2, implementations MUST reject spurious incoming + // connections. + ch.Reject(Prohibited, "no forward for address") + continue + } + + } +} + +// remove removes the forward entry, and the channel feeding its +// listener. +func (l *forwardList) remove(addr net.Addr) { + l.Lock() + defer l.Unlock() + for i, f := range l.entries { + if addr.Network() == f.laddr.Network() && addr.String() == f.laddr.String() { + l.entries = append(l.entries[:i], l.entries[i+1:]...) + close(f.c) + return + } + } +} + +// closeAll closes and clears all forwards. +func (l *forwardList) closeAll() { + l.Lock() + defer l.Unlock() + for _, f := range l.entries { + close(f.c) + } + l.entries = nil +} + +func (l *forwardList) forward(laddr, raddr net.Addr, ch NewChannel) bool { + l.Lock() + defer l.Unlock() + for _, f := range l.entries { + if laddr.Network() == f.laddr.Network() && laddr.String() == f.laddr.String() { + f.c <- forward{newCh: ch, raddr: raddr} + return true + } + } + return false +} + +type tcpListener struct { + laddr *net.TCPAddr + + conn *Client + in <-chan forward +} + +// Accept waits for and returns the next connection to the listener. +func (l *tcpListener) Accept() (net.Conn, error) { + s, ok := <-l.in + if !ok { + return nil, io.EOF + } + ch, incoming, err := s.newCh.Accept() + if err != nil { + return nil, err + } + go DiscardRequests(incoming) + + return &chanConn{ + Channel: ch, + laddr: l.laddr, + raddr: s.raddr, + }, nil +} + +// Close closes the listener. +func (l *tcpListener) Close() error { + m := channelForwardMsg{ + l.laddr.IP.String(), + uint32(l.laddr.Port), + } + + // this also closes the listener. + l.conn.forwards.remove(l.laddr) + ok, _, err := l.conn.SendRequest("cancel-tcpip-forward", true, Marshal(&m)) + if err == nil && !ok { + err = errors.New("ssh: cancel-tcpip-forward failed") + } + return err +} + +// Addr returns the listener's network address. +func (l *tcpListener) Addr() net.Addr { + return l.laddr +} + +// Dial initiates a connection to the addr from the remote host. +// The resulting connection has a zero LocalAddr() and RemoteAddr(). +func (c *Client) Dial(n, addr string) (net.Conn, error) { + var ch Channel + switch n { + case "tcp", "tcp4", "tcp6": + // Parse the address into host and numeric port. + host, portString, err := net.SplitHostPort(addr) + if err != nil { + return nil, err + } + port, err := strconv.ParseUint(portString, 10, 16) + if err != nil { + return nil, err + } + ch, err = c.dial(net.IPv4zero.String(), 0, host, int(port)) + if err != nil { + return nil, err + } + // Use a zero address for local and remote address. + zeroAddr := &net.TCPAddr{ + IP: net.IPv4zero, + Port: 0, + } + return &chanConn{ + Channel: ch, + laddr: zeroAddr, + raddr: zeroAddr, + }, nil + case "unix": + var err error + ch, err = c.dialStreamLocal(addr) + if err != nil { + return nil, err + } + return &chanConn{ + Channel: ch, + laddr: &net.UnixAddr{ + Name: "@", + Net: "unix", + }, + raddr: &net.UnixAddr{ + Name: addr, + Net: "unix", + }, + }, nil + default: + return nil, fmt.Errorf("ssh: unsupported protocol: %s", n) + } +} + +// DialTCP connects to the remote address raddr on the network net, +// which must be "tcp", "tcp4", or "tcp6". If laddr is not nil, it is used +// as the local address for the connection. +func (c *Client) DialTCP(n string, laddr, raddr *net.TCPAddr) (net.Conn, error) { + if laddr == nil { + laddr = &net.TCPAddr{ + IP: net.IPv4zero, + Port: 0, + } + } + ch, err := c.dial(laddr.IP.String(), laddr.Port, raddr.IP.String(), raddr.Port) + if err != nil { + return nil, err + } + return &chanConn{ + Channel: ch, + laddr: laddr, + raddr: raddr, + }, nil +} + +// RFC 4254 7.2 +type channelOpenDirectMsg struct { + raddr string + rport uint32 + laddr string + lport uint32 +} + +func (c *Client) dial(laddr string, lport int, raddr string, rport int) (Channel, error) { + msg := channelOpenDirectMsg{ + raddr: raddr, + rport: uint32(rport), + laddr: laddr, + lport: uint32(lport), + } + ch, in, err := c.OpenChannel("direct-tcpip", Marshal(&msg)) + if err != nil { + return nil, err + } + go DiscardRequests(in) + return ch, err +} + +type tcpChan struct { + Channel // the backing channel +} + +// chanConn fulfills the net.Conn interface without +// the tcpChan having to hold laddr or raddr directly. +type chanConn struct { + Channel + laddr, raddr net.Addr +} + +// LocalAddr returns the local network address. +func (t *chanConn) LocalAddr() net.Addr { + return t.laddr +} + +// RemoteAddr returns the remote network address. +func (t *chanConn) RemoteAddr() net.Addr { + return t.raddr +} + +// SetDeadline sets the read and write deadlines associated +// with the connection. +func (t *chanConn) SetDeadline(deadline time.Time) error { + if err := t.SetReadDeadline(deadline); err != nil { + return err + } + return t.SetWriteDeadline(deadline) +} + +// SetReadDeadline sets the read deadline. +// A zero value for t means Read will not time out. +// After the deadline, the error from Read will implement net.Error +// with Timeout() == true. +func (t *chanConn) SetReadDeadline(deadline time.Time) error { + // for compatibility with previous version, + // the error message contains "tcpChan" + return errors.New("ssh: tcpChan: deadline not supported") +} + +// SetWriteDeadline exists to satisfy the net.Conn interface +// but is not implemented by this type. It always returns an error. +func (t *chanConn) SetWriteDeadline(deadline time.Time) error { + return errors.New("ssh: tcpChan: deadline not supported") +} diff --git a/internal/crypto/ssh/terminal/terminal.go b/internal/crypto/ssh/terminal/terminal.go new file mode 100644 index 000000000..2ffb97bfb --- /dev/null +++ b/internal/crypto/ssh/terminal/terminal.go @@ -0,0 +1,987 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package terminal + +import ( + "bytes" + "io" + "runtime" + "strconv" + "sync" + "unicode/utf8" +) + +// EscapeCodes contains escape sequences that can be written to the terminal in +// order to achieve different styles of text. +type EscapeCodes struct { + // Foreground colors + Black, Red, Green, Yellow, Blue, Magenta, Cyan, White []byte + + // Reset all attributes + Reset []byte +} + +var vt100EscapeCodes = EscapeCodes{ + Black: []byte{keyEscape, '[', '3', '0', 'm'}, + Red: []byte{keyEscape, '[', '3', '1', 'm'}, + Green: []byte{keyEscape, '[', '3', '2', 'm'}, + Yellow: []byte{keyEscape, '[', '3', '3', 'm'}, + Blue: []byte{keyEscape, '[', '3', '4', 'm'}, + Magenta: []byte{keyEscape, '[', '3', '5', 'm'}, + Cyan: []byte{keyEscape, '[', '3', '6', 'm'}, + White: []byte{keyEscape, '[', '3', '7', 'm'}, + + Reset: []byte{keyEscape, '[', '0', 'm'}, +} + +// Terminal contains the state for running a VT100 terminal that is capable of +// reading lines of input. +type Terminal struct { + // AutoCompleteCallback, if non-null, is called for each keypress with + // the full input line and the current position of the cursor (in + // bytes, as an index into |line|). If it returns ok=false, the key + // press is processed normally. Otherwise it returns a replacement line + // and the new cursor position. + AutoCompleteCallback func(line string, pos int, key rune) (newLine string, newPos int, ok bool) + + // Escape contains a pointer to the escape codes for this terminal. + // It's always a valid pointer, although the escape codes themselves + // may be empty if the terminal doesn't support them. + Escape *EscapeCodes + + // lock protects the terminal and the state in this object from + // concurrent processing of a key press and a Write() call. + lock sync.Mutex + + c io.ReadWriter + prompt []rune + + // line is the current line being entered. + line []rune + // pos is the logical position of the cursor in line + pos int + // echo is true if local echo is enabled + echo bool + // pasteActive is true iff there is a bracketed paste operation in + // progress. + pasteActive bool + + // cursorX contains the current X value of the cursor where the left + // edge is 0. cursorY contains the row number where the first row of + // the current line is 0. + cursorX, cursorY int + // maxLine is the greatest value of cursorY so far. + maxLine int + + termWidth, termHeight int + + // outBuf contains the terminal data to be sent. + outBuf []byte + // remainder contains the remainder of any partial key sequences after + // a read. It aliases into inBuf. + remainder []byte + inBuf [256]byte + + // history contains previously entered commands so that they can be + // accessed with the up and down keys. + history stRingBuffer + // historyIndex stores the currently accessed history entry, where zero + // means the immediately previous entry. + historyIndex int + // When navigating up and down the history it's possible to return to + // the incomplete, initial line. That value is stored in + // historyPending. + historyPending string +} + +// NewTerminal runs a VT100 terminal on the given ReadWriter. If the ReadWriter is +// a local terminal, that terminal must first have been put into raw mode. +// prompt is a string that is written at the start of each input line (i.e. +// "> "). +func NewTerminal(c io.ReadWriter, prompt string) *Terminal { + return &Terminal{ + Escape: &vt100EscapeCodes, + c: c, + prompt: []rune(prompt), + termWidth: 80, + termHeight: 24, + echo: true, + historyIndex: -1, + } +} + +const ( + keyCtrlC = 3 + keyCtrlD = 4 + keyCtrlU = 21 + keyEnter = '\r' + keyEscape = 27 + keyBackspace = 127 + keyUnknown = 0xd800 /* UTF-16 surrogate area */ + iota + keyUp + keyDown + keyLeft + keyRight + keyAltLeft + keyAltRight + keyHome + keyEnd + keyDeleteWord + keyDeleteLine + keyClearScreen + keyPasteStart + keyPasteEnd +) + +var ( + crlf = []byte{'\r', '\n'} + pasteStart = []byte{keyEscape, '[', '2', '0', '0', '~'} + pasteEnd = []byte{keyEscape, '[', '2', '0', '1', '~'} +) + +// bytesToKey tries to parse a key sequence from b. If successful, it returns +// the key and the remainder of the input. Otherwise it returns utf8.RuneError. +func bytesToKey(b []byte, pasteActive bool) (rune, []byte) { + if len(b) == 0 { + return utf8.RuneError, nil + } + + if !pasteActive { + switch b[0] { + case 1: // ^A + return keyHome, b[1:] + case 2: // ^B + return keyLeft, b[1:] + case 5: // ^E + return keyEnd, b[1:] + case 6: // ^F + return keyRight, b[1:] + case 8: // ^H + return keyBackspace, b[1:] + case 11: // ^K + return keyDeleteLine, b[1:] + case 12: // ^L + return keyClearScreen, b[1:] + case 23: // ^W + return keyDeleteWord, b[1:] + case 14: // ^N + return keyDown, b[1:] + case 16: // ^P + return keyUp, b[1:] + } + } + + if b[0] != keyEscape { + if !utf8.FullRune(b) { + return utf8.RuneError, b + } + r, l := utf8.DecodeRune(b) + return r, b[l:] + } + + if !pasteActive && len(b) >= 3 && b[0] == keyEscape && b[1] == '[' { + switch b[2] { + case 'A': + return keyUp, b[3:] + case 'B': + return keyDown, b[3:] + case 'C': + return keyRight, b[3:] + case 'D': + return keyLeft, b[3:] + case 'H': + return keyHome, b[3:] + case 'F': + return keyEnd, b[3:] + } + } + + if !pasteActive && len(b) >= 6 && b[0] == keyEscape && b[1] == '[' && b[2] == '1' && b[3] == ';' && b[4] == '3' { + switch b[5] { + case 'C': + return keyAltRight, b[6:] + case 'D': + return keyAltLeft, b[6:] + } + } + + if !pasteActive && len(b) >= 6 && bytes.Equal(b[:6], pasteStart) { + return keyPasteStart, b[6:] + } + + if pasteActive && len(b) >= 6 && bytes.Equal(b[:6], pasteEnd) { + return keyPasteEnd, b[6:] + } + + // If we get here then we have a key that we don't recognise, or a + // partial sequence. It's not clear how one should find the end of a + // sequence without knowing them all, but it seems that [a-zA-Z~] only + // appears at the end of a sequence. + for i, c := range b[0:] { + if c >= 'a' && c <= 'z' || c >= 'A' && c <= 'Z' || c == '~' { + return keyUnknown, b[i+1:] + } + } + + return utf8.RuneError, b +} + +// queue appends data to the end of t.outBuf +func (t *Terminal) queue(data []rune) { + t.outBuf = append(t.outBuf, []byte(string(data))...) +} + +var eraseUnderCursor = []rune{' ', keyEscape, '[', 'D'} +var space = []rune{' '} + +func isPrintable(key rune) bool { + isInSurrogateArea := key >= 0xd800 && key <= 0xdbff + return key >= 32 && !isInSurrogateArea +} + +// moveCursorToPos appends data to t.outBuf which will move the cursor to the +// given, logical position in the text. +func (t *Terminal) moveCursorToPos(pos int) { + if !t.echo { + return + } + + x := visualLength(t.prompt) + pos + y := x / t.termWidth + x = x % t.termWidth + + up := 0 + if y < t.cursorY { + up = t.cursorY - y + } + + down := 0 + if y > t.cursorY { + down = y - t.cursorY + } + + left := 0 + if x < t.cursorX { + left = t.cursorX - x + } + + right := 0 + if x > t.cursorX { + right = x - t.cursorX + } + + t.cursorX = x + t.cursorY = y + t.move(up, down, left, right) +} + +func (t *Terminal) move(up, down, left, right int) { + m := []rune{} + + // 1 unit up can be expressed as ^[[A or ^[A + // 5 units up can be expressed as ^[[5A + + if up == 1 { + m = append(m, keyEscape, '[', 'A') + } else if up > 1 { + m = append(m, keyEscape, '[') + m = append(m, []rune(strconv.Itoa(up))...) + m = append(m, 'A') + } + + if down == 1 { + m = append(m, keyEscape, '[', 'B') + } else if down > 1 { + m = append(m, keyEscape, '[') + m = append(m, []rune(strconv.Itoa(down))...) + m = append(m, 'B') + } + + if right == 1 { + m = append(m, keyEscape, '[', 'C') + } else if right > 1 { + m = append(m, keyEscape, '[') + m = append(m, []rune(strconv.Itoa(right))...) + m = append(m, 'C') + } + + if left == 1 { + m = append(m, keyEscape, '[', 'D') + } else if left > 1 { + m = append(m, keyEscape, '[') + m = append(m, []rune(strconv.Itoa(left))...) + m = append(m, 'D') + } + + t.queue(m) +} + +func (t *Terminal) clearLineToRight() { + op := []rune{keyEscape, '[', 'K'} + t.queue(op) +} + +const maxLineLength = 4096 + +func (t *Terminal) setLine(newLine []rune, newPos int) { + if t.echo { + t.moveCursorToPos(0) + t.writeLine(newLine) + for i := len(newLine); i < len(t.line); i++ { + t.writeLine(space) + } + t.moveCursorToPos(newPos) + } + t.line = newLine + t.pos = newPos +} + +func (t *Terminal) advanceCursor(places int) { + t.cursorX += places + t.cursorY += t.cursorX / t.termWidth + if t.cursorY > t.maxLine { + t.maxLine = t.cursorY + } + t.cursorX = t.cursorX % t.termWidth + + if places > 0 && t.cursorX == 0 { + // Normally terminals will advance the current position + // when writing a character. But that doesn't happen + // for the last character in a line. However, when + // writing a character (except a new line) that causes + // a line wrap, the position will be advanced two + // places. + // + // So, if we are stopping at the end of a line, we + // need to write a newline so that our cursor can be + // advanced to the next line. + t.outBuf = append(t.outBuf, '\r', '\n') + } +} + +func (t *Terminal) eraseNPreviousChars(n int) { + if n == 0 { + return + } + + if t.pos < n { + n = t.pos + } + t.pos -= n + t.moveCursorToPos(t.pos) + + copy(t.line[t.pos:], t.line[n+t.pos:]) + t.line = t.line[:len(t.line)-n] + if t.echo { + t.writeLine(t.line[t.pos:]) + for i := 0; i < n; i++ { + t.queue(space) + } + t.advanceCursor(n) + t.moveCursorToPos(t.pos) + } +} + +// countToLeftWord returns then number of characters from the cursor to the +// start of the previous word. +func (t *Terminal) countToLeftWord() int { + if t.pos == 0 { + return 0 + } + + pos := t.pos - 1 + for pos > 0 { + if t.line[pos] != ' ' { + break + } + pos-- + } + for pos > 0 { + if t.line[pos] == ' ' { + pos++ + break + } + pos-- + } + + return t.pos - pos +} + +// countToRightWord returns then number of characters from the cursor to the +// start of the next word. +func (t *Terminal) countToRightWord() int { + pos := t.pos + for pos < len(t.line) { + if t.line[pos] == ' ' { + break + } + pos++ + } + for pos < len(t.line) { + if t.line[pos] != ' ' { + break + } + pos++ + } + return pos - t.pos +} + +// visualLength returns the number of visible glyphs in s. +func visualLength(runes []rune) int { + inEscapeSeq := false + length := 0 + + for _, r := range runes { + switch { + case inEscapeSeq: + if (r >= 'a' && r <= 'z') || (r >= 'A' && r <= 'Z') { + inEscapeSeq = false + } + case r == '\x1b': + inEscapeSeq = true + default: + length++ + } + } + + return length +} + +// handleKey processes the given key and, optionally, returns a line of text +// that the user has entered. +func (t *Terminal) handleKey(key rune) (line string, ok bool) { + if t.pasteActive && key != keyEnter { + t.addKeyToLine(key) + return + } + + switch key { + case keyBackspace: + if t.pos == 0 { + return + } + t.eraseNPreviousChars(1) + case keyAltLeft: + // move left by a word. + t.pos -= t.countToLeftWord() + t.moveCursorToPos(t.pos) + case keyAltRight: + // move right by a word. + t.pos += t.countToRightWord() + t.moveCursorToPos(t.pos) + case keyLeft: + if t.pos == 0 { + return + } + t.pos-- + t.moveCursorToPos(t.pos) + case keyRight: + if t.pos == len(t.line) { + return + } + t.pos++ + t.moveCursorToPos(t.pos) + case keyHome: + if t.pos == 0 { + return + } + t.pos = 0 + t.moveCursorToPos(t.pos) + case keyEnd: + if t.pos == len(t.line) { + return + } + t.pos = len(t.line) + t.moveCursorToPos(t.pos) + case keyUp: + entry, ok := t.history.NthPreviousEntry(t.historyIndex + 1) + if !ok { + return "", false + } + if t.historyIndex == -1 { + t.historyPending = string(t.line) + } + t.historyIndex++ + runes := []rune(entry) + t.setLine(runes, len(runes)) + case keyDown: + switch t.historyIndex { + case -1: + return + case 0: + runes := []rune(t.historyPending) + t.setLine(runes, len(runes)) + t.historyIndex-- + default: + entry, ok := t.history.NthPreviousEntry(t.historyIndex - 1) + if ok { + t.historyIndex-- + runes := []rune(entry) + t.setLine(runes, len(runes)) + } + } + case keyEnter: + t.moveCursorToPos(len(t.line)) + t.queue([]rune("\r\n")) + line = string(t.line) + ok = true + t.line = t.line[:0] + t.pos = 0 + t.cursorX = 0 + t.cursorY = 0 + t.maxLine = 0 + case keyDeleteWord: + // Delete zero or more spaces and then one or more characters. + t.eraseNPreviousChars(t.countToLeftWord()) + case keyDeleteLine: + // Delete everything from the current cursor position to the + // end of line. + for i := t.pos; i < len(t.line); i++ { + t.queue(space) + t.advanceCursor(1) + } + t.line = t.line[:t.pos] + t.moveCursorToPos(t.pos) + case keyCtrlD: + // Erase the character under the current position. + // The EOF case when the line is empty is handled in + // readLine(). + if t.pos < len(t.line) { + t.pos++ + t.eraseNPreviousChars(1) + } + case keyCtrlU: + t.eraseNPreviousChars(t.pos) + case keyClearScreen: + // Erases the screen and moves the cursor to the home position. + t.queue([]rune("\x1b[2J\x1b[H")) + t.queue(t.prompt) + t.cursorX, t.cursorY = 0, 0 + t.advanceCursor(visualLength(t.prompt)) + t.setLine(t.line, t.pos) + default: + if t.AutoCompleteCallback != nil { + prefix := string(t.line[:t.pos]) + suffix := string(t.line[t.pos:]) + + t.lock.Unlock() + newLine, newPos, completeOk := t.AutoCompleteCallback(prefix+suffix, len(prefix), key) + t.lock.Lock() + + if completeOk { + t.setLine([]rune(newLine), utf8.RuneCount([]byte(newLine)[:newPos])) + return + } + } + if !isPrintable(key) { + return + } + if len(t.line) == maxLineLength { + return + } + t.addKeyToLine(key) + } + return +} + +// addKeyToLine inserts the given key at the current position in the current +// line. +func (t *Terminal) addKeyToLine(key rune) { + if len(t.line) == cap(t.line) { + newLine := make([]rune, len(t.line), 2*(1+len(t.line))) + copy(newLine, t.line) + t.line = newLine + } + t.line = t.line[:len(t.line)+1] + copy(t.line[t.pos+1:], t.line[t.pos:]) + t.line[t.pos] = key + if t.echo { + t.writeLine(t.line[t.pos:]) + } + t.pos++ + t.moveCursorToPos(t.pos) +} + +func (t *Terminal) writeLine(line []rune) { + for len(line) != 0 { + remainingOnLine := t.termWidth - t.cursorX + todo := len(line) + if todo > remainingOnLine { + todo = remainingOnLine + } + t.queue(line[:todo]) + t.advanceCursor(visualLength(line[:todo])) + line = line[todo:] + } +} + +// writeWithCRLF writes buf to w but replaces all occurrences of \n with \r\n. +func writeWithCRLF(w io.Writer, buf []byte) (n int, err error) { + for len(buf) > 0 { + i := bytes.IndexByte(buf, '\n') + todo := len(buf) + if i >= 0 { + todo = i + } + + var nn int + nn, err = w.Write(buf[:todo]) + n += nn + if err != nil { + return n, err + } + buf = buf[todo:] + + if i >= 0 { + if _, err = w.Write(crlf); err != nil { + return n, err + } + n++ + buf = buf[1:] + } + } + + return n, nil +} + +func (t *Terminal) Write(buf []byte) (n int, err error) { + t.lock.Lock() + defer t.lock.Unlock() + + if t.cursorX == 0 && t.cursorY == 0 { + // This is the easy case: there's nothing on the screen that we + // have to move out of the way. + return writeWithCRLF(t.c, buf) + } + + // We have a prompt and possibly user input on the screen. We + // have to clear it first. + t.move(0 /* up */, 0 /* down */, t.cursorX /* left */, 0 /* right */) + t.cursorX = 0 + t.clearLineToRight() + + for t.cursorY > 0 { + t.move(1 /* up */, 0, 0, 0) + t.cursorY-- + t.clearLineToRight() + } + + if _, err = t.c.Write(t.outBuf); err != nil { + return + } + t.outBuf = t.outBuf[:0] + + if n, err = writeWithCRLF(t.c, buf); err != nil { + return + } + + t.writeLine(t.prompt) + if t.echo { + t.writeLine(t.line) + } + + t.moveCursorToPos(t.pos) + + if _, err = t.c.Write(t.outBuf); err != nil { + return + } + t.outBuf = t.outBuf[:0] + return +} + +// ReadPassword temporarily changes the prompt and reads a password, without +// echo, from the terminal. +func (t *Terminal) ReadPassword(prompt string) (line string, err error) { + t.lock.Lock() + defer t.lock.Unlock() + + oldPrompt := t.prompt + t.prompt = []rune(prompt) + t.echo = false + + line, err = t.readLine() + + t.prompt = oldPrompt + t.echo = true + + return +} + +// ReadLine returns a line of input from the terminal. +func (t *Terminal) ReadLine() (line string, err error) { + t.lock.Lock() + defer t.lock.Unlock() + + return t.readLine() +} + +func (t *Terminal) readLine() (line string, err error) { + // t.lock must be held at this point + + if t.cursorX == 0 && t.cursorY == 0 { + t.writeLine(t.prompt) + t.c.Write(t.outBuf) + t.outBuf = t.outBuf[:0] + } + + lineIsPasted := t.pasteActive + + for { + rest := t.remainder + lineOk := false + for !lineOk { + var key rune + key, rest = bytesToKey(rest, t.pasteActive) + if key == utf8.RuneError { + break + } + if !t.pasteActive { + if key == keyCtrlD { + if len(t.line) == 0 { + return "", io.EOF + } + } + if key == keyCtrlC { + return "", io.EOF + } + if key == keyPasteStart { + t.pasteActive = true + if len(t.line) == 0 { + lineIsPasted = true + } + continue + } + } else if key == keyPasteEnd { + t.pasteActive = false + continue + } + if !t.pasteActive { + lineIsPasted = false + } + line, lineOk = t.handleKey(key) + } + if len(rest) > 0 { + n := copy(t.inBuf[:], rest) + t.remainder = t.inBuf[:n] + } else { + t.remainder = nil + } + t.c.Write(t.outBuf) + t.outBuf = t.outBuf[:0] + if lineOk { + if t.echo { + t.historyIndex = -1 + t.history.Add(line) + } + if lineIsPasted { + err = ErrPasteIndicator + } + return + } + + // t.remainder is a slice at the beginning of t.inBuf + // containing a partial key sequence + readBuf := t.inBuf[len(t.remainder):] + var n int + + t.lock.Unlock() + n, err = t.c.Read(readBuf) + t.lock.Lock() + + if err != nil { + return + } + + t.remainder = t.inBuf[:n+len(t.remainder)] + } +} + +// SetPrompt sets the prompt to be used when reading subsequent lines. +func (t *Terminal) SetPrompt(prompt string) { + t.lock.Lock() + defer t.lock.Unlock() + + t.prompt = []rune(prompt) +} + +func (t *Terminal) clearAndRepaintLinePlusNPrevious(numPrevLines int) { + // Move cursor to column zero at the start of the line. + t.move(t.cursorY, 0, t.cursorX, 0) + t.cursorX, t.cursorY = 0, 0 + t.clearLineToRight() + for t.cursorY < numPrevLines { + // Move down a line + t.move(0, 1, 0, 0) + t.cursorY++ + t.clearLineToRight() + } + // Move back to beginning. + t.move(t.cursorY, 0, 0, 0) + t.cursorX, t.cursorY = 0, 0 + + t.queue(t.prompt) + t.advanceCursor(visualLength(t.prompt)) + t.writeLine(t.line) + t.moveCursorToPos(t.pos) +} + +func (t *Terminal) SetSize(width, height int) error { + t.lock.Lock() + defer t.lock.Unlock() + + if width == 0 { + width = 1 + } + + oldWidth := t.termWidth + t.termWidth, t.termHeight = width, height + + switch { + case width == oldWidth: + // If the width didn't change then nothing else needs to be + // done. + return nil + case len(t.line) == 0 && t.cursorX == 0 && t.cursorY == 0: + // If there is nothing on current line and no prompt printed, + // just do nothing + return nil + case width < oldWidth: + // Some terminals (e.g. xterm) will truncate lines that were + // too long when shinking. Others, (e.g. gnome-terminal) will + // attempt to wrap them. For the former, repainting t.maxLine + // works great, but that behaviour goes badly wrong in the case + // of the latter because they have doubled every full line. + + // We assume that we are working on a terminal that wraps lines + // and adjust the cursor position based on every previous line + // wrapping and turning into two. This causes the prompt on + // xterms to move upwards, which isn't great, but it avoids a + // huge mess with gnome-terminal. + if t.cursorX >= t.termWidth { + t.cursorX = t.termWidth - 1 + } + t.cursorY *= 2 + t.clearAndRepaintLinePlusNPrevious(t.maxLine * 2) + case width > oldWidth: + // If the terminal expands then our position calculations will + // be wrong in the future because we think the cursor is + // |t.pos| chars into the string, but there will be a gap at + // the end of any wrapped line. + // + // But the position will actually be correct until we move, so + // we can move back to the beginning and repaint everything. + t.clearAndRepaintLinePlusNPrevious(t.maxLine) + } + + _, err := t.c.Write(t.outBuf) + t.outBuf = t.outBuf[:0] + return err +} + +type pasteIndicatorError struct{} + +func (pasteIndicatorError) Error() string { + return "terminal: ErrPasteIndicator not correctly handled" +} + +// ErrPasteIndicator may be returned from ReadLine as the error, in addition +// to valid line data. It indicates that bracketed paste mode is enabled and +// that the returned line consists only of pasted data. Programs may wish to +// interpret pasted data more literally than typed data. +var ErrPasteIndicator = pasteIndicatorError{} + +// SetBracketedPasteMode requests that the terminal bracket paste operations +// with markers. Not all terminals support this but, if it is supported, then +// enabling this mode will stop any autocomplete callback from running due to +// pastes. Additionally, any lines that are completely pasted will be returned +// from ReadLine with the error set to ErrPasteIndicator. +func (t *Terminal) SetBracketedPasteMode(on bool) { + if on { + io.WriteString(t.c, "\x1b[?2004h") + } else { + io.WriteString(t.c, "\x1b[?2004l") + } +} + +// stRingBuffer is a ring buffer of strings. +type stRingBuffer struct { + // entries contains max elements. + entries []string + max int + // head contains the index of the element most recently added to the ring. + head int + // size contains the number of elements in the ring. + size int +} + +func (s *stRingBuffer) Add(a string) { + if s.entries == nil { + const defaultNumEntries = 100 + s.entries = make([]string, defaultNumEntries) + s.max = defaultNumEntries + } + + s.head = (s.head + 1) % s.max + s.entries[s.head] = a + if s.size < s.max { + s.size++ + } +} + +// NthPreviousEntry returns the value passed to the nth previous call to Add. +// If n is zero then the immediately prior value is returned, if one, then the +// next most recent, and so on. If such an element doesn't exist then ok is +// false. +func (s *stRingBuffer) NthPreviousEntry(n int) (value string, ok bool) { + if n >= s.size { + return "", false + } + index := s.head - n + if index < 0 { + index += s.max + } + return s.entries[index], true +} + +// readPasswordLine reads from reader until it finds \n or io.EOF. +// The slice returned does not include the \n. +// readPasswordLine also ignores any \r it finds. +// Windows uses \r as end of line. So, on Windows, readPasswordLine +// reads until it finds \r and ignores any \n it finds during processing. +func readPasswordLine(reader io.Reader) ([]byte, error) { + var buf [1]byte + var ret []byte + + for { + n, err := reader.Read(buf[:]) + if n > 0 { + switch buf[0] { + case '\b': + if len(ret) > 0 { + ret = ret[:len(ret)-1] + } + case '\n': + if runtime.GOOS != "windows" { + return ret, nil + } + // otherwise ignore \n + case '\r': + if runtime.GOOS == "windows" { + return ret, nil + } + // otherwise ignore \r + default: + ret = append(ret, buf[0]) + } + continue + } + if err != nil { + if err == io.EOF && len(ret) > 0 { + return ret, nil + } + return ret, err + } + } +} diff --git a/internal/crypto/ssh/terminal/util.go b/internal/crypto/ssh/terminal/util.go new file mode 100644 index 000000000..391104084 --- /dev/null +++ b/internal/crypto/ssh/terminal/util.go @@ -0,0 +1,114 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build aix darwin dragonfly freebsd linux,!appengine netbsd openbsd + +// Package terminal provides support functions for dealing with terminals, as +// commonly found on UNIX systems. +// +// Putting a terminal into raw mode is the most common requirement: +// +// oldState, err := terminal.MakeRaw(0) +// if err != nil { +// panic(err) +// } +// defer terminal.Restore(0, oldState) +package terminal // import "golang.org/x/crypto/ssh/terminal" + +import ( + "golang.org/x/sys/unix" +) + +// State contains the state of a terminal. +type State struct { + termios unix.Termios +} + +// IsTerminal returns whether the given file descriptor is a terminal. +func IsTerminal(fd int) bool { + _, err := unix.IoctlGetTermios(fd, ioctlReadTermios) + return err == nil +} + +// MakeRaw put the terminal connected to the given file descriptor into raw +// mode and returns the previous state of the terminal so that it can be +// restored. +func MakeRaw(fd int) (*State, error) { + termios, err := unix.IoctlGetTermios(fd, ioctlReadTermios) + if err != nil { + return nil, err + } + + oldState := State{termios: *termios} + + // This attempts to replicate the behaviour documented for cfmakeraw in + // the termios(3) manpage. + termios.Iflag &^= unix.IGNBRK | unix.BRKINT | unix.PARMRK | unix.ISTRIP | unix.INLCR | unix.IGNCR | unix.ICRNL | unix.IXON + termios.Oflag &^= unix.OPOST + termios.Lflag &^= unix.ECHO | unix.ECHONL | unix.ICANON | unix.ISIG | unix.IEXTEN + termios.Cflag &^= unix.CSIZE | unix.PARENB + termios.Cflag |= unix.CS8 + termios.Cc[unix.VMIN] = 1 + termios.Cc[unix.VTIME] = 0 + if err := unix.IoctlSetTermios(fd, ioctlWriteTermios, termios); err != nil { + return nil, err + } + + return &oldState, nil +} + +// GetState returns the current state of a terminal which may be useful to +// restore the terminal after a signal. +func GetState(fd int) (*State, error) { + termios, err := unix.IoctlGetTermios(fd, ioctlReadTermios) + if err != nil { + return nil, err + } + + return &State{termios: *termios}, nil +} + +// Restore restores the terminal connected to the given file descriptor to a +// previous state. +func Restore(fd int, state *State) error { + return unix.IoctlSetTermios(fd, ioctlWriteTermios, &state.termios) +} + +// GetSize returns the dimensions of the given terminal. +func GetSize(fd int) (width, height int, err error) { + ws, err := unix.IoctlGetWinsize(fd, unix.TIOCGWINSZ) + if err != nil { + return -1, -1, err + } + return int(ws.Col), int(ws.Row), nil +} + +// passwordReader is an io.Reader that reads from a specific file descriptor. +type passwordReader int + +func (r passwordReader) Read(buf []byte) (int, error) { + return unix.Read(int(r), buf) +} + +// ReadPassword reads a line of input from a terminal without local echo. This +// is commonly used for inputting passwords and other sensitive data. The slice +// returned does not include the \n. +func ReadPassword(fd int) ([]byte, error) { + termios, err := unix.IoctlGetTermios(fd, ioctlReadTermios) + if err != nil { + return nil, err + } + + newState := *termios + newState.Lflag &^= unix.ECHO + newState.Lflag |= unix.ICANON | unix.ISIG + newState.Iflag |= unix.ICRNL + if err := unix.IoctlSetTermios(fd, ioctlWriteTermios, &newState); err != nil { + return nil, err + } + + defer unix.IoctlSetTermios(fd, ioctlWriteTermios, termios) + + return readPasswordLine(passwordReader(fd)) +} diff --git a/internal/crypto/ssh/terminal/util_aix.go b/internal/crypto/ssh/terminal/util_aix.go new file mode 100644 index 000000000..dfcd62785 --- /dev/null +++ b/internal/crypto/ssh/terminal/util_aix.go @@ -0,0 +1,12 @@ +// Copyright 2018 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build aix + +package terminal + +import "golang.org/x/sys/unix" + +const ioctlReadTermios = unix.TCGETS +const ioctlWriteTermios = unix.TCSETS diff --git a/internal/crypto/ssh/terminal/util_bsd.go b/internal/crypto/ssh/terminal/util_bsd.go new file mode 100644 index 000000000..cb23a5904 --- /dev/null +++ b/internal/crypto/ssh/terminal/util_bsd.go @@ -0,0 +1,12 @@ +// Copyright 2013 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build darwin dragonfly freebsd netbsd openbsd + +package terminal + +import "golang.org/x/sys/unix" + +const ioctlReadTermios = unix.TIOCGETA +const ioctlWriteTermios = unix.TIOCSETA diff --git a/internal/crypto/ssh/terminal/util_linux.go b/internal/crypto/ssh/terminal/util_linux.go new file mode 100644 index 000000000..5fadfe8a1 --- /dev/null +++ b/internal/crypto/ssh/terminal/util_linux.go @@ -0,0 +1,10 @@ +// Copyright 2013 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package terminal + +import "golang.org/x/sys/unix" + +const ioctlReadTermios = unix.TCGETS +const ioctlWriteTermios = unix.TCSETS diff --git a/internal/crypto/ssh/terminal/util_plan9.go b/internal/crypto/ssh/terminal/util_plan9.go new file mode 100644 index 000000000..9317ac7ed --- /dev/null +++ b/internal/crypto/ssh/terminal/util_plan9.go @@ -0,0 +1,58 @@ +// Copyright 2016 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package terminal provides support functions for dealing with terminals, as +// commonly found on UNIX systems. +// +// Putting a terminal into raw mode is the most common requirement: +// +// oldState, err := terminal.MakeRaw(0) +// if err != nil { +// panic(err) +// } +// defer terminal.Restore(0, oldState) +package terminal + +import ( + "fmt" + "runtime" +) + +type State struct{} + +// IsTerminal returns whether the given file descriptor is a terminal. +func IsTerminal(fd int) bool { + return false +} + +// MakeRaw put the terminal connected to the given file descriptor into raw +// mode and returns the previous state of the terminal so that it can be +// restored. +func MakeRaw(fd int) (*State, error) { + return nil, fmt.Errorf("terminal: MakeRaw not implemented on %s/%s", runtime.GOOS, runtime.GOARCH) +} + +// GetState returns the current state of a terminal which may be useful to +// restore the terminal after a signal. +func GetState(fd int) (*State, error) { + return nil, fmt.Errorf("terminal: GetState not implemented on %s/%s", runtime.GOOS, runtime.GOARCH) +} + +// Restore restores the terminal connected to the given file descriptor to a +// previous state. +func Restore(fd int, state *State) error { + return fmt.Errorf("terminal: Restore not implemented on %s/%s", runtime.GOOS, runtime.GOARCH) +} + +// GetSize returns the dimensions of the given terminal. +func GetSize(fd int) (width, height int, err error) { + return 0, 0, fmt.Errorf("terminal: GetSize not implemented on %s/%s", runtime.GOOS, runtime.GOARCH) +} + +// ReadPassword reads a line of input from a terminal without local echo. This +// is commonly used for inputting passwords and other sensitive data. The slice +// returned does not include the \n. +func ReadPassword(fd int) ([]byte, error) { + return nil, fmt.Errorf("terminal: ReadPassword not implemented on %s/%s", runtime.GOOS, runtime.GOARCH) +} diff --git a/internal/crypto/ssh/terminal/util_solaris.go b/internal/crypto/ssh/terminal/util_solaris.go new file mode 100644 index 000000000..3d5f06a9f --- /dev/null +++ b/internal/crypto/ssh/terminal/util_solaris.go @@ -0,0 +1,124 @@ +// Copyright 2015 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build solaris + +package terminal // import "golang.org/x/crypto/ssh/terminal" + +import ( + "golang.org/x/sys/unix" + "io" + "syscall" +) + +// State contains the state of a terminal. +type State struct { + termios unix.Termios +} + +// IsTerminal returns whether the given file descriptor is a terminal. +func IsTerminal(fd int) bool { + _, err := unix.IoctlGetTermio(fd, unix.TCGETA) + return err == nil +} + +// ReadPassword reads a line of input from a terminal without local echo. This +// is commonly used for inputting passwords and other sensitive data. The slice +// returned does not include the \n. +func ReadPassword(fd int) ([]byte, error) { + // see also: http://src.illumos.org/source/xref/illumos-gate/usr/src/lib/libast/common/uwin/getpass.c + val, err := unix.IoctlGetTermios(fd, unix.TCGETS) + if err != nil { + return nil, err + } + oldState := *val + + newState := oldState + newState.Lflag &^= syscall.ECHO + newState.Lflag |= syscall.ICANON | syscall.ISIG + newState.Iflag |= syscall.ICRNL + err = unix.IoctlSetTermios(fd, unix.TCSETS, &newState) + if err != nil { + return nil, err + } + + defer unix.IoctlSetTermios(fd, unix.TCSETS, &oldState) + + var buf [16]byte + var ret []byte + for { + n, err := syscall.Read(fd, buf[:]) + if err != nil { + return nil, err + } + if n == 0 { + if len(ret) == 0 { + return nil, io.EOF + } + break + } + if buf[n-1] == '\n' { + n-- + } + ret = append(ret, buf[:n]...) + if n < len(buf) { + break + } + } + + return ret, nil +} + +// MakeRaw puts the terminal connected to the given file descriptor into raw +// mode and returns the previous state of the terminal so that it can be +// restored. +// see http://cr.illumos.org/~webrev/andy_js/1060/ +func MakeRaw(fd int) (*State, error) { + termios, err := unix.IoctlGetTermios(fd, unix.TCGETS) + if err != nil { + return nil, err + } + + oldState := State{termios: *termios} + + termios.Iflag &^= unix.IGNBRK | unix.BRKINT | unix.PARMRK | unix.ISTRIP | unix.INLCR | unix.IGNCR | unix.ICRNL | unix.IXON + termios.Oflag &^= unix.OPOST + termios.Lflag &^= unix.ECHO | unix.ECHONL | unix.ICANON | unix.ISIG | unix.IEXTEN + termios.Cflag &^= unix.CSIZE | unix.PARENB + termios.Cflag |= unix.CS8 + termios.Cc[unix.VMIN] = 1 + termios.Cc[unix.VTIME] = 0 + + if err := unix.IoctlSetTermios(fd, unix.TCSETS, termios); err != nil { + return nil, err + } + + return &oldState, nil +} + +// Restore restores the terminal connected to the given file descriptor to a +// previous state. +func Restore(fd int, oldState *State) error { + return unix.IoctlSetTermios(fd, unix.TCSETS, &oldState.termios) +} + +// GetState returns the current state of a terminal which may be useful to +// restore the terminal after a signal. +func GetState(fd int) (*State, error) { + termios, err := unix.IoctlGetTermios(fd, unix.TCGETS) + if err != nil { + return nil, err + } + + return &State{termios: *termios}, nil +} + +// GetSize returns the dimensions of the given terminal. +func GetSize(fd int) (width, height int, err error) { + ws, err := unix.IoctlGetWinsize(fd, unix.TIOCGWINSZ) + if err != nil { + return 0, 0, err + } + return int(ws.Col), int(ws.Row), nil +} diff --git a/internal/crypto/ssh/terminal/util_windows.go b/internal/crypto/ssh/terminal/util_windows.go new file mode 100644 index 000000000..f614e9cb6 --- /dev/null +++ b/internal/crypto/ssh/terminal/util_windows.go @@ -0,0 +1,105 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build windows + +// Package terminal provides support functions for dealing with terminals, as +// commonly found on UNIX systems. +// +// Putting a terminal into raw mode is the most common requirement: +// +// oldState, err := terminal.MakeRaw(0) +// if err != nil { +// panic(err) +// } +// defer terminal.Restore(0, oldState) +package terminal + +import ( + "os" + + "golang.org/x/sys/windows" +) + +type State struct { + mode uint32 +} + +// IsTerminal returns whether the given file descriptor is a terminal. +func IsTerminal(fd int) bool { + var st uint32 + err := windows.GetConsoleMode(windows.Handle(fd), &st) + return err == nil +} + +// MakeRaw put the terminal connected to the given file descriptor into raw +// mode and returns the previous state of the terminal so that it can be +// restored. +func MakeRaw(fd int) (*State, error) { + var st uint32 + if err := windows.GetConsoleMode(windows.Handle(fd), &st); err != nil { + return nil, err + } + raw := st &^ (windows.ENABLE_ECHO_INPUT | windows.ENABLE_PROCESSED_INPUT | windows.ENABLE_LINE_INPUT | windows.ENABLE_PROCESSED_OUTPUT) + if err := windows.SetConsoleMode(windows.Handle(fd), raw); err != nil { + return nil, err + } + return &State{st}, nil +} + +// GetState returns the current state of a terminal which may be useful to +// restore the terminal after a signal. +func GetState(fd int) (*State, error) { + var st uint32 + if err := windows.GetConsoleMode(windows.Handle(fd), &st); err != nil { + return nil, err + } + return &State{st}, nil +} + +// Restore restores the terminal connected to the given file descriptor to a +// previous state. +func Restore(fd int, state *State) error { + return windows.SetConsoleMode(windows.Handle(fd), state.mode) +} + +// GetSize returns the visible dimensions of the given terminal. +// +// These dimensions don't include any scrollback buffer height. +func GetSize(fd int) (width, height int, err error) { + var info windows.ConsoleScreenBufferInfo + if err := windows.GetConsoleScreenBufferInfo(windows.Handle(fd), &info); err != nil { + return 0, 0, err + } + return int(info.Window.Right - info.Window.Left + 1), int(info.Window.Bottom - info.Window.Top + 1), nil +} + +// ReadPassword reads a line of input from a terminal without local echo. This +// is commonly used for inputting passwords and other sensitive data. The slice +// returned does not include the \n. +func ReadPassword(fd int) ([]byte, error) { + var st uint32 + if err := windows.GetConsoleMode(windows.Handle(fd), &st); err != nil { + return nil, err + } + old := st + + st &^= (windows.ENABLE_ECHO_INPUT | windows.ENABLE_LINE_INPUT) + st |= (windows.ENABLE_PROCESSED_OUTPUT | windows.ENABLE_PROCESSED_INPUT) + if err := windows.SetConsoleMode(windows.Handle(fd), st); err != nil { + return nil, err + } + + defer windows.SetConsoleMode(windows.Handle(fd), old) + + var h windows.Handle + p, _ := windows.GetCurrentProcess() + if err := windows.DuplicateHandle(p, windows.Handle(fd), p, &h, 0, false, windows.DUPLICATE_SAME_ACCESS); err != nil { + return nil, err + } + + f := os.NewFile(uintptr(h), "stdin") + defer f.Close() + return readPasswordLine(f) +} diff --git a/internal/crypto/ssh/transport.go b/internal/crypto/ssh/transport.go new file mode 100644 index 000000000..49ddc2e7d --- /dev/null +++ b/internal/crypto/ssh/transport.go @@ -0,0 +1,353 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ssh + +import ( + "bufio" + "bytes" + "errors" + "io" + "log" +) + +// debugTransport if set, will print packet types as they go over the +// wire. No message decoding is done, to minimize the impact on timing. +const debugTransport = false + +const ( + gcmCipherID = "aes128-gcm@openssh.com" + aes128cbcID = "aes128-cbc" + tripledescbcID = "3des-cbc" +) + +// packetConn represents a transport that implements packet based +// operations. +type packetConn interface { + // Encrypt and send a packet of data to the remote peer. + writePacket(packet []byte) error + + // Read a packet from the connection. The read is blocking, + // i.e. if error is nil, then the returned byte slice is + // always non-empty. + readPacket() ([]byte, error) + + // Close closes the write-side of the connection. + Close() error +} + +// transport is the keyingTransport that implements the SSH packet +// protocol. +type transport struct { + reader connectionState + writer connectionState + + bufReader *bufio.Reader + bufWriter *bufio.Writer + rand io.Reader + isClient bool + io.Closer +} + +// packetCipher represents a combination of SSH encryption/MAC +// protocol. A single instance should be used for one direction only. +type packetCipher interface { + // writeCipherPacket encrypts the packet and writes it to w. The + // contents of the packet are generally scrambled. + writeCipherPacket(seqnum uint32, w io.Writer, rand io.Reader, packet []byte) error + + // readCipherPacket reads and decrypts a packet of data. The + // returned packet may be overwritten by future calls of + // readPacket. + readCipherPacket(seqnum uint32, r io.Reader) ([]byte, error) +} + +// connectionState represents one side (read or write) of the +// connection. This is necessary because each direction has its own +// keys, and can even have its own algorithms +type connectionState struct { + packetCipher + seqNum uint32 + dir direction + pendingKeyChange chan packetCipher +} + +// prepareKeyChange sets up key material for a keychange. The key changes in +// both directions are triggered by reading and writing a msgNewKey packet +// respectively. +func (t *transport) prepareKeyChange(algs *algorithms, kexResult *kexResult) error { + ciph, err := newPacketCipher(t.reader.dir, algs.r, kexResult) + if err != nil { + return err + } + t.reader.pendingKeyChange <- ciph + + ciph, err = newPacketCipher(t.writer.dir, algs.w, kexResult) + if err != nil { + return err + } + t.writer.pendingKeyChange <- ciph + + return nil +} + +func (t *transport) printPacket(p []byte, write bool) { + if len(p) == 0 { + return + } + who := "server" + if t.isClient { + who = "client" + } + what := "read" + if write { + what = "write" + } + + log.Println(what, who, p[0]) +} + +// Read and decrypt next packet. +func (t *transport) readPacket() (p []byte, err error) { + for { + p, err = t.reader.readPacket(t.bufReader) + if err != nil { + break + } + if len(p) == 0 || (p[0] != msgIgnore && p[0] != msgDebug) { + break + } + } + if debugTransport { + t.printPacket(p, false) + } + + return p, err +} + +func (s *connectionState) readPacket(r *bufio.Reader) ([]byte, error) { + packet, err := s.packetCipher.readCipherPacket(s.seqNum, r) + s.seqNum++ + if err == nil && len(packet) == 0 { + err = errors.New("ssh: zero length packet") + } + + if len(packet) > 0 { + switch packet[0] { + case msgNewKeys: + select { + case cipher := <-s.pendingKeyChange: + s.packetCipher = cipher + default: + return nil, errors.New("ssh: got bogus newkeys message") + } + + case msgDisconnect: + // Transform a disconnect message into an + // error. Since this is lowest level at which + // we interpret message types, doing it here + // ensures that we don't have to handle it + // elsewhere. + var msg disconnectMsg + if err := Unmarshal(packet, &msg); err != nil { + return nil, err + } + return nil, &msg + } + } + + // The packet may point to an internal buffer, so copy the + // packet out here. + fresh := make([]byte, len(packet)) + copy(fresh, packet) + + return fresh, err +} + +func (t *transport) writePacket(packet []byte) error { + if debugTransport { + t.printPacket(packet, true) + } + return t.writer.writePacket(t.bufWriter, t.rand, packet) +} + +func (s *connectionState) writePacket(w *bufio.Writer, rand io.Reader, packet []byte) error { + changeKeys := len(packet) > 0 && packet[0] == msgNewKeys + + err := s.packetCipher.writeCipherPacket(s.seqNum, w, rand, packet) + if err != nil { + return err + } + if err = w.Flush(); err != nil { + return err + } + s.seqNum++ + if changeKeys { + select { + case cipher := <-s.pendingKeyChange: + s.packetCipher = cipher + default: + panic("ssh: no key material for msgNewKeys") + } + } + return err +} + +func newTransport(rwc io.ReadWriteCloser, rand io.Reader, isClient bool) *transport { + t := &transport{ + bufReader: bufio.NewReader(rwc), + bufWriter: bufio.NewWriter(rwc), + rand: rand, + reader: connectionState{ + packetCipher: &streamPacketCipher{cipher: noneCipher{}}, + pendingKeyChange: make(chan packetCipher, 1), + }, + writer: connectionState{ + packetCipher: &streamPacketCipher{cipher: noneCipher{}}, + pendingKeyChange: make(chan packetCipher, 1), + }, + Closer: rwc, + } + t.isClient = isClient + + if isClient { + t.reader.dir = serverKeys + t.writer.dir = clientKeys + } else { + t.reader.dir = clientKeys + t.writer.dir = serverKeys + } + + return t +} + +type direction struct { + ivTag []byte + keyTag []byte + macKeyTag []byte +} + +var ( + serverKeys = direction{[]byte{'B'}, []byte{'D'}, []byte{'F'}} + clientKeys = direction{[]byte{'A'}, []byte{'C'}, []byte{'E'}} +) + +// setupKeys sets the cipher and MAC keys from kex.K, kex.H and sessionId, as +// described in RFC 4253, section 6.4. direction should either be serverKeys +// (to setup server->client keys) or clientKeys (for client->server keys). +func newPacketCipher(d direction, algs directionAlgorithms, kex *kexResult) (packetCipher, error) { + cipherMode := cipherModes[algs.Cipher] + macMode := macModes[algs.MAC] + + iv := make([]byte, cipherMode.ivSize) + key := make([]byte, cipherMode.keySize) + macKey := make([]byte, macMode.keySize) + + generateKeyMaterial(iv, d.ivTag, kex) + generateKeyMaterial(key, d.keyTag, kex) + generateKeyMaterial(macKey, d.macKeyTag, kex) + + return cipherModes[algs.Cipher].create(key, iv, macKey, algs) +} + +// generateKeyMaterial fills out with key material generated from tag, K, H +// and sessionId, as specified in RFC 4253, section 7.2. +func generateKeyMaterial(out, tag []byte, r *kexResult) { + var digestsSoFar []byte + + h := r.Hash.New() + for len(out) > 0 { + h.Reset() + h.Write(r.K) + h.Write(r.H) + + if len(digestsSoFar) == 0 { + h.Write(tag) + h.Write(r.SessionID) + } else { + h.Write(digestsSoFar) + } + + digest := h.Sum(nil) + n := copy(out, digest) + out = out[n:] + if len(out) > 0 { + digestsSoFar = append(digestsSoFar, digest...) + } + } +} + +const packageVersion = "SSH-2.0-Go" + +// Sends and receives a version line. The versionLine string should +// be US ASCII, start with "SSH-2.0-", and should not include a +// newline. exchangeVersions returns the other side's version line. +func exchangeVersions(rw io.ReadWriter, versionLine []byte) (them []byte, err error) { + // Contrary to the RFC, we do not ignore lines that don't + // start with "SSH-2.0-" to make the library usable with + // nonconforming servers. + for _, c := range versionLine { + // The spec disallows non US-ASCII chars, and + // specifically forbids null chars. + if c < 32 { + return nil, errors.New("ssh: junk character in version line") + } + } + if _, err = rw.Write(append(versionLine, '\r', '\n')); err != nil { + return + } + + them, err = readVersion(rw) + return them, err +} + +// maxVersionStringBytes is the maximum number of bytes that we'll +// accept as a version string. RFC 4253 section 4.2 limits this at 255 +// chars +const maxVersionStringBytes = 255 + +// Read version string as specified by RFC 4253, section 4.2. +func readVersion(r io.Reader) ([]byte, error) { + versionString := make([]byte, 0, 64) + var ok bool + var buf [1]byte + + for length := 0; length < maxVersionStringBytes; length++ { + _, err := io.ReadFull(r, buf[:]) + if err != nil { + return nil, err + } + // The RFC says that the version should be terminated with \r\n + // but several SSH servers actually only send a \n. + if buf[0] == '\n' { + if !bytes.HasPrefix(versionString, []byte("SSH-")) { + // RFC 4253 says we need to ignore all version string lines + // except the one containing the SSH version (provided that + // all the lines do not exceed 255 bytes in total). + versionString = versionString[:0] + continue + } + ok = true + break + } + + // non ASCII chars are disallowed, but we are lenient, + // since Go doesn't use null-terminated strings. + + // The RFC allows a comment after a space, however, + // all of it (version and comments) goes into the + // session hash. + versionString = append(versionString, buf[0]) + } + + if !ok { + return nil, errors.New("ssh: overflow reading version string") + } + + // There might be a '\r' on the end which we should remove. + if len(versionString) > 0 && versionString[len(versionString)-1] == '\r' { + versionString = versionString[:len(versionString)-1] + } + return versionString, nil +} From 9b9e533cb9697c63bac18062f1fc38ab07505d06 Mon Sep 17 00:00:00 2001 From: Jose Garcia Date: Fri, 10 Sep 2021 13:32:30 -0400 Subject: [PATCH 188/290] add comment for special handling --- cmd/ghcs/ports.go | 1 + 1 file changed, 1 insertion(+) diff --git a/cmd/ghcs/ports.go b/cmd/ghcs/ports.go index 8b73626fa..4b190b36f 100644 --- a/cmd/ghcs/ports.go +++ b/cmd/ghcs/ports.go @@ -63,6 +63,7 @@ func ports(codespaceName string, asJSON bool) error { codespace, token, err := getOrChooseCodespace(ctx, apiClient, user, codespaceName) if err != nil { + // TODO(josebalius): remove special handling of this error here and it other places if err == errNoCodespaces { return err } From 34e52ba24a30f3b289ad1f485f09eef3c57a13b9 Mon Sep 17 00:00:00 2001 From: Jose Garcia Date: Fri, 10 Sep 2021 14:11:50 -0400 Subject: [PATCH 189/290] deprecate subcommands --- cmd/ghcs/delete.go | 51 ++++++++++++++++++++++++++++++++-------------- 1 file changed, 36 insertions(+), 15 deletions(-) diff --git a/cmd/ghcs/delete.go b/cmd/ghcs/delete.go index 75b9362bb..fdda850b1 100644 --- a/cmd/ghcs/delete.go +++ b/cmd/ghcs/delete.go @@ -2,6 +2,7 @@ package main import ( "context" + "errors" "fmt" "os" "strings" @@ -12,34 +13,57 @@ import ( ) func newDeleteCmd() *cobra.Command { + var ( + codespace string + allCodespaces bool + repo string + ) + + log := output.NewLogger(os.Stdout, os.Stderr, false) + deleteCmd := &cobra.Command{ - Use: "delete []", + Use: "delete", Short: "Delete a codespace", - Args: cobra.MaximumNArgs(1), RunE: func(cmd *cobra.Command, args []string) error { - var codespaceName string if len(args) > 0 { - codespaceName = args[0] + log.Errorln(" argument is deprecated. Use --codespace instead.") + codespace = args[0] + } + + switch { + case allCodespaces && repo != "": + return errors.New("both --all and --repo is not supported.") + case allCodespaces: + return deleteAll(log) + case repo != "": + return deleteByRepo(log, repo) + default: + return delete_(log, codespace) } - return delete_(codespaceName) }, } + deleteCmd.Flags().StringVarP(&codespace, "codespace", "c", "", "Name of the codespace") + deleteCmd.Flags().BoolVar(&allCodespaces, "all", false, "Delete all codespaces") + deleteCmd.Flags().StringVarP(&repo, "repo", "r", "", "Delete all codespaces for a repository") + deleteAllCmd := &cobra.Command{ Use: "all", - Short: "Delete all codespaces for the current user", + Short: "(Deprecated) Delete all codespaces for the current user", Args: cobra.NoArgs, RunE: func(cmd *cobra.Command, args []string) error { - return deleteAll() + log.Errorln("all command is deprecated. Use --all instead.") + return deleteAll(log) }, } deleteByRepoCmd := &cobra.Command{ Use: "repo ", - Short: "Delete all codespaces for a repository", + Short: "(Deprecated) Delete all codespaces for a repository", Args: cobra.ExactArgs(1), RunE: func(cmd *cobra.Command, args []string) error { - return deleteByRepo(args[0]) + log.Errorln("repo command is deprecated. Use --repo instead.") + return deleteByRepo(log, args[0]) }, } @@ -52,10 +76,9 @@ func init() { rootCmd.AddCommand(newDeleteCmd()) } -func delete_(codespaceName string) error { +func delete_(log *output.Logger, codespaceName string) error { apiClient := api.New(os.Getenv("GITHUB_TOKEN")) ctx := context.Background() - log := output.NewLogger(os.Stdout, os.Stderr, false) user, err := apiClient.GetUser(ctx) if err != nil { @@ -76,10 +99,9 @@ func delete_(codespaceName string) error { return list(&listOptions{}) } -func deleteAll() error { +func deleteAll(log *output.Logger) error { apiClient := api.New(os.Getenv("GITHUB_TOKEN")) ctx := context.Background() - log := output.NewLogger(os.Stdout, os.Stderr, false) user, err := apiClient.GetUser(ctx) if err != nil { @@ -107,10 +129,9 @@ func deleteAll() error { return list(&listOptions{}) } -func deleteByRepo(repo string) error { +func deleteByRepo(log *output.Logger, repo string) error { apiClient := api.New(os.Getenv("GITHUB_TOKEN")) ctx := context.Background() - log := output.NewLogger(os.Stdout, os.Stderr, false) user, err := apiClient.GetUser(ctx) if err != nil { From 798075045e39e4dc1dd2319dde9665bd69edeab1 Mon Sep 17 00:00:00 2001 From: Jose Garcia Date: Fri, 10 Sep 2021 14:58:47 -0400 Subject: [PATCH 190/290] remove terminal, bash_profile setup --- cmd/ghcs/ssh.go | 70 ------------------------------------------------- 1 file changed, 70 deletions(-) diff --git a/cmd/ghcs/ssh.go b/cmd/ghcs/ssh.go index aefb959f3..c97c72f9b 100644 --- a/cmd/ghcs/ssh.go +++ b/cmd/ghcs/ssh.go @@ -1,12 +1,10 @@ package main import ( - "bufio" "context" "fmt" "net" "os" - "strings" "github.com/github/ghcs/api" "github.com/github/ghcs/cmd/ghcs/output" @@ -67,21 +65,6 @@ func ssh(ctx context.Context, sshProfile, codespaceName string, localSSHServerPo return fmt.Errorf("error getting ssh server details: %v", err) } - terminal := liveshare.NewTerminal(session) - - log.Print("Preparing SSH...") - if sshProfile == "" { - containerID, err := getContainerID(ctx, log, terminal) - if err != nil { - return fmt.Errorf("error getting container id: %v", err) - } - - if err := setupEnv(ctx, log, terminal, containerID, codespace.RepositoryName, sshUser); err != nil { - return fmt.Errorf("error creating ssh server: %v", err) - } - } - log.Print("\n") - usingCustomPort := localSSHServerPort != 0 // suppress log of command line in Shell // Ensure local port is listening before client (Shell) connects. @@ -119,56 +102,3 @@ func ssh(ctx context.Context, sshProfile, codespaceName string, localSSHServerPo return nil // success } } - -func getContainerID(ctx context.Context, logger *output.Logger, terminal *liveshare.Terminal) (string, error) { - logger.Print(".") - - cmd := terminal.NewCommand( - "/", - "/usr/bin/docker ps -aq --filter label=Type=codespaces --filter status=running", - ) - - stream, err := cmd.Run(ctx) - if err != nil { - return "", fmt.Errorf("error running command: %v", err) - } - - logger.Print(".") - scanner := bufio.NewScanner(stream) - scanner.Scan() - - logger.Print(".") - containerID := scanner.Text() - if err := scanner.Err(); err != nil { - return "", fmt.Errorf("error scanning stream: %v", err) - } - - logger.Print(".") - if err := stream.Close(); err != nil { - return "", fmt.Errorf("error closing stream: %v", err) - } - - return containerID, nil -} - -func setupEnv(ctx context.Context, logger *output.Logger, terminal *liveshare.Terminal, containerID, repositoryName, containerUser string) error { - setupBashProfileCmd := fmt.Sprintf(`echo "export $(cat /workspaces/.codespaces/shared/.env | xargs); exec /bin/zsh;" > /home/%v/.bash_profile`, containerUser) - - logger.Print(".") - compositeCommand := []string{setupBashProfileCmd} - cmd := terminal.NewCommand( - "/", - fmt.Sprintf("/usr/bin/docker exec -t %s /bin/bash -c '"+strings.Join(compositeCommand, "; ")+"'", containerID), - ) - stream, err := cmd.Run(ctx) - if err != nil { - return fmt.Errorf("error running command: %v", err) - } - - logger.Print(".") - if err := stream.Close(); err != nil { - return fmt.Errorf("error closing stream: %v", err) - } - - return nil -} From 5b23d87d47f4ffa6dfce6c10794e9e32ec5c6371 Mon Sep 17 00:00:00 2001 From: Alan Donovan Date: Fri, 10 Sep 2021 15:09:45 -0400 Subject: [PATCH 191/290] Remove Terminal, no longer needed by ghcs --- client.go | 4 ++ rpc.go | 41 +++---------------- rpc_test.go | 31 -------------- terminal.go | 116 ---------------------------------------------------- 4 files changed, 9 insertions(+), 183 deletions(-) delete mode 100644 rpc_test.go delete mode 100644 terminal.go diff --git a/client.go b/client.go index 566db6cd3..65e80a94a 100644 --- a/client.go +++ b/client.go @@ -115,6 +115,10 @@ func (c *Client) joinWorkspace(ctx context.Context, rpc *rpcClient) (*joinWorksp } func (s *Session) openStreamingChannel(ctx context.Context, id channelID) (ssh.Channel, error) { + type getStreamArgs struct { + StreamName string `json:"streamName"` + Condition string `json:"condition"` + } args := getStreamArgs{ StreamName: id.name, Condition: id.condition, diff --git a/rpc.go b/rpc.go index 10aa2c7eb..68e187ad6 100644 --- a/rpc.go +++ b/rpc.go @@ -4,7 +4,6 @@ import ( "context" "fmt" "io" - "sync" "github.com/opentracing/opentracing-go" "github.com/sourcegraph/jsonrpc2" @@ -12,18 +11,17 @@ import ( type rpcClient struct { *jsonrpc2.Conn - conn io.ReadWriteCloser - handler *rpcHandler + conn io.ReadWriteCloser } func newRPCClient(conn io.ReadWriteCloser) *rpcClient { - return &rpcClient{conn: conn, handler: newRPCHandler()} + return &rpcClient{conn: conn} } func (r *rpcClient) connect(ctx context.Context) { stream := jsonrpc2.NewBufferedStream(r.conn, jsonrpc2.VSCodeObjectCodec{}) // TODO(adonovan): fix: ensure r.Conn is eventually Closed! - r.Conn = jsonrpc2.NewConn(ctx, stream, r.handler) + r.Conn = jsonrpc2.NewConn(ctx, stream, nullHandler{}) } func (r *rpcClient) do(ctx context.Context, method string, args, result interface{}) error { @@ -38,36 +36,7 @@ func (r *rpcClient) do(ctx context.Context, method string, args, result interfac return waiter.Wait(ctx, result) } -type rpcHandlerFunc = func(*jsonrpc2.Request) +type nullHandler struct{} -type rpcHandler struct { - handlersMu sync.Mutex - handlers map[string][]rpcHandlerFunc -} - -func newRPCHandler() *rpcHandler { - return &rpcHandler{ - handlers: make(map[string][]rpcHandlerFunc), - } -} - -// registerEventHandler registers a handler for the specified event. -// After the next occurrence of the event, the handler will be called, -// once, in its own goroutine. -func (r *rpcHandler) registerEventHandler(eventMethod string, h rpcHandlerFunc) { - r.handlersMu.Lock() - r.handlers[eventMethod] = append(r.handlers[eventMethod], h) - r.handlersMu.Unlock() -} - -// Handle calls all registered handlers for the request, concurrently, each in its own goroutine. -func (r *rpcHandler) Handle(ctx context.Context, conn *jsonrpc2.Conn, req *jsonrpc2.Request) { - r.handlersMu.Lock() - handlers := r.handlers[req.Method] - r.handlers[req.Method] = nil - r.handlersMu.Unlock() - - for _, h := range handlers { - go h(req) - } +func (nullHandler) Handle(ctx context.Context, conn *jsonrpc2.Conn, req *jsonrpc2.Request) { } diff --git a/rpc_test.go b/rpc_test.go deleted file mode 100644 index cf9c4cf81..000000000 --- a/rpc_test.go +++ /dev/null @@ -1,31 +0,0 @@ -package liveshare - -import ( - "context" - "testing" - "time" - - "github.com/sourcegraph/jsonrpc2" -) - -func TestRPCHandlerEvents(t *testing.T) { - rpcHandler := newRPCHandler() - eventCh := make(chan *jsonrpc2.Request) - rpcHandler.registerEventHandler("somethingHappened", func(req *jsonrpc2.Request) { - eventCh <- req - }) - go func() { - time.Sleep(1 * time.Second) - rpcHandler.Handle(context.Background(), nil, &jsonrpc2.Request{Method: "somethingHappened"}) - }() - ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(5*time.Second)) - defer cancel() - select { - case event := <-eventCh: - if event.Method != "somethingHappened" { - t.Error("event.Method is not the expect value") - } - case <-ctx.Done(): - t.Error("Test time out") - } -} diff --git a/terminal.go b/terminal.go deleted file mode 100644 index 24a0f5121..000000000 --- a/terminal.go +++ /dev/null @@ -1,116 +0,0 @@ -package liveshare - -import ( - "context" - "fmt" - "io" - - "github.com/sourcegraph/jsonrpc2" - "golang.org/x/crypto/ssh" -) - -type Terminal struct { - session *Session -} - -func NewTerminal(session *Session) *Terminal { - return &Terminal{session: session} -} - -type TerminalCommand struct { - terminal *Terminal - cwd string - cmd string -} - -func (t *Terminal) NewCommand(cwd, cmd string) TerminalCommand { - return TerminalCommand{t, cwd, cmd} -} - -type runArgs struct { - Name string `json:"name"` - Rows int `json:"rows"` - Cols int `json:"cols"` - App string `json:"app"` - Cwd string `json:"cwd"` - CommandLine []string `json:"commandLine"` - ReadOnlyForGuests bool `json:"readOnlyForGuests"` -} - -type startTerminalResult struct { - ID int `json:"id"` - StreamName string `json:"streamName"` - StreamCondition string `json:"streamCondition"` - LocalPipeName string `json:"localPipeName"` - AppProcessID int `json:"appProcessId"` -} - -type getStreamArgs struct { - StreamName string `json:"streamName"` - Condition string `json:"condition"` -} - -type stopTerminalArgs struct { - ID int `json:"id"` -} - -func (t TerminalCommand) Run(ctx context.Context) (io.ReadCloser, error) { - args := runArgs{ - Name: "RunCommand", - Rows: 10, - Cols: 80, - App: "/bin/bash", - Cwd: t.cwd, - CommandLine: []string{"-c", t.cmd}, - ReadOnlyForGuests: false, - } - - started := make(chan struct{}) - t.terminal.session.rpc.handler.registerEventHandler("terminal.terminalStarted", func(*jsonrpc2.Request) { - close(started) - }) - var result startTerminalResult - if err := t.terminal.session.rpc.do(ctx, "terminal.startTerminal", &args, &result); err != nil { - return nil, fmt.Errorf("error making terminal.startTerminal call: %v", err) - } - <-started - - channel, err := t.terminal.session.openStreamingChannel(ctx, channelID{result.StreamName, result.StreamCondition}) - if err != nil { - return nil, fmt.Errorf("error opening streaming channel: %v", err) - } - - return t.newTerminalReadCloser(result.ID, channel), nil -} - -type terminalReadCloser struct { - terminalCommand TerminalCommand - terminalID int - channel ssh.Channel -} - -func (t TerminalCommand) newTerminalReadCloser(terminalID int, channel ssh.Channel) io.ReadCloser { - return terminalReadCloser{t, terminalID, channel} -} - -func (t terminalReadCloser) Read(b []byte) (int, error) { - return t.channel.Read(b) -} - -func (t terminalReadCloser) Close() error { - stopped := make(chan struct{}) - t.terminalCommand.terminal.session.rpc.handler.registerEventHandler("terminal.terminalStopped", func(*jsonrpc2.Request) { - close(stopped) - }) - if err := t.terminalCommand.terminal.session.rpc.do(context.Background(), "terminal.stopTerminal", []int{t.terminalID}, nil); err != nil { - return fmt.Errorf("error making terminal.stopTerminal call: %v", err) - } - - if err := t.channel.Close(); err != nil && err != io.EOF { - return fmt.Errorf("error closing channel: %v", err) - } - - <-stopped - - return nil -} From af301bfff1ab4669e95f79067bb4aae1460f17c0 Mon Sep 17 00:00:00 2001 From: Alan Donovan Date: Fri, 10 Sep 2021 17:36:20 -0400 Subject: [PATCH 192/290] stdin/stdout fds are not 0/1 on windows --- cmd/ghcs/common.go | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/cmd/ghcs/common.go b/cmd/ghcs/common.go index 229e04c78..b77d4a041 100644 --- a/cmd/ghcs/common.go +++ b/cmd/ghcs/common.go @@ -93,7 +93,9 @@ func getOrChooseCodespace(ctx context.Context, apiClient *api.API, user *api.Use return codespace, token, nil } -var hasTTY = term.IsTerminal(0) && term.IsTerminal(1) // is process connected to a terminal? +// hasTTY indicates whether the process connected to a terminal. +// It is not portable to assume stdin/stdout are fds 0 and 1. +var hasTTY = term.IsTerminal(os.Stdin.Fd()) && term.IsTerminal(os.Stdout.Fd()) // ask asks survey questions on the terminal, using standard options. // It fails unless hasTTY, but ideally callers should avoid calling it in that case. From 1526ab5bff3065558bbc3b40db3eb501f6e56061 Mon Sep 17 00:00:00 2001 From: Alan Donovan Date: Fri, 10 Sep 2021 18:08:48 -0400 Subject: [PATCH 193/290] fix URL --- cmd/ghcs/common.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cmd/ghcs/common.go b/cmd/ghcs/common.go index b77d4a041..c4171acc2 100644 --- a/cmd/ghcs/common.go +++ b/cmd/ghcs/common.go @@ -109,7 +109,7 @@ func ask(qs []*survey.Question, response interface{}) error { // ASCII \x03 (ETX) instead of delivering SIGINT to the application. // So we have to serve ourselves the SIGINT. // - // https://github.com/AlecAivazis/survey/#why-isnt-sending-a-sigint-aka-ctrl-c-signal-working + // https://github.com/AlecAivazis/survey/#why-isnt-ctrl-c-working if err == terminal.InterruptErr { self, _ := os.FindProcess(os.Getpid()) _ = self.Signal(os.Interrupt) // assumes POSIX From c4be0a0e284ed1d22fc26dc452b81b56ab4549b5 Mon Sep 17 00:00:00 2001 From: Alan Donovan Date: Mon, 13 Sep 2021 09:29:46 -0400 Subject: [PATCH 194/290] this time without compile errors --- cmd/ghcs/common.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cmd/ghcs/common.go b/cmd/ghcs/common.go index c4171acc2..bfcb67496 100644 --- a/cmd/ghcs/common.go +++ b/cmd/ghcs/common.go @@ -95,7 +95,7 @@ func getOrChooseCodespace(ctx context.Context, apiClient *api.API, user *api.Use // hasTTY indicates whether the process connected to a terminal. // It is not portable to assume stdin/stdout are fds 0 and 1. -var hasTTY = term.IsTerminal(os.Stdin.Fd()) && term.IsTerminal(os.Stdout.Fd()) +var hasTTY = term.IsTerminal(int(os.Stdin.Fd())) && term.IsTerminal(int(os.Stdout.Fd())) // ask asks survey questions on the terminal, using standard options. // It fails unless hasTTY, but ideally callers should avoid calling it in that case. From f5adc9e3a75eecd6a20b2541e9406fad2820a4fc Mon Sep 17 00:00:00 2001 From: Jose Garcia Date: Mon, 13 Sep 2021 10:58:00 -0400 Subject: [PATCH 195/290] remove all deprecation messages and deprecated functionality --- cmd/ghcs/code.go | 7 ------- cmd/ghcs/delete.go | 27 --------------------------- cmd/ghcs/logs.go | 10 ---------- cmd/ghcs/ports.go | 29 +++-------------------------- 4 files changed, 3 insertions(+), 70 deletions(-) diff --git a/cmd/ghcs/code.go b/cmd/ghcs/code.go index d34b75ed8..3bd67053d 100644 --- a/cmd/ghcs/code.go +++ b/cmd/ghcs/code.go @@ -7,7 +7,6 @@ import ( "os" "github.com/github/ghcs/api" - "github.com/github/ghcs/cmd/ghcs/output" "github.com/skratchdot/open-golang/open" "github.com/spf13/cobra" ) @@ -18,17 +17,11 @@ func newCodeCmd() *cobra.Command { useInsiders bool ) - log := output.NewLogger(os.Stdout, os.Stderr, false) - codeCmd := &cobra.Command{ Use: "code", Short: "Open a codespace in VS Code", Args: cobra.MaximumNArgs(1), RunE: func(cmd *cobra.Command, args []string) error { - if len(args) > 0 { - log.Errorln(" argument is deprecated. Use --codespace instead.") - codespace = args[0] - } return code(codespace, useInsiders) }, } diff --git a/cmd/ghcs/delete.go b/cmd/ghcs/delete.go index fdda850b1..7800a13c0 100644 --- a/cmd/ghcs/delete.go +++ b/cmd/ghcs/delete.go @@ -25,11 +25,6 @@ func newDeleteCmd() *cobra.Command { Use: "delete", Short: "Delete a codespace", RunE: func(cmd *cobra.Command, args []string) error { - if len(args) > 0 { - log.Errorln(" argument is deprecated. Use --codespace instead.") - codespace = args[0] - } - switch { case allCodespaces && repo != "": return errors.New("both --all and --repo is not supported.") @@ -47,28 +42,6 @@ func newDeleteCmd() *cobra.Command { deleteCmd.Flags().BoolVar(&allCodespaces, "all", false, "Delete all codespaces") deleteCmd.Flags().StringVarP(&repo, "repo", "r", "", "Delete all codespaces for a repository") - deleteAllCmd := &cobra.Command{ - Use: "all", - Short: "(Deprecated) Delete all codespaces for the current user", - Args: cobra.NoArgs, - RunE: func(cmd *cobra.Command, args []string) error { - log.Errorln("all command is deprecated. Use --all instead.") - return deleteAll(log) - }, - } - - deleteByRepoCmd := &cobra.Command{ - Use: "repo ", - Short: "(Deprecated) Delete all codespaces for a repository", - Args: cobra.ExactArgs(1), - RunE: func(cmd *cobra.Command, args []string) error { - log.Errorln("repo command is deprecated. Use --repo instead.") - return deleteByRepo(log, args[0]) - }, - } - - deleteCmd.AddCommand(deleteAllCmd, deleteByRepoCmd) - return deleteCmd } diff --git a/cmd/ghcs/logs.go b/cmd/ghcs/logs.go index cb6ba19d2..a685a364d 100644 --- a/cmd/ghcs/logs.go +++ b/cmd/ghcs/logs.go @@ -16,7 +16,6 @@ import ( func newLogsCmd() *cobra.Command { var ( codespace string - tail bool follow bool ) @@ -27,20 +26,11 @@ func newLogsCmd() *cobra.Command { Short: "Access codespace logs", Args: cobra.MaximumNArgs(1), RunE: func(cmd *cobra.Command, args []string) error { - if len(args) > 0 { - log.Errorln(" argument is deprecated. Use --codespace instead.") - codespace = args[0] - } - if tail { - log.Errorln("--tail flag is deprecated. Use --follow instead.") - follow = true - } return logs(context.Background(), log, codespace, follow) }, } logsCmd.Flags().StringVarP(&codespace, "codespace", "c", "", "Name of the codespace") - logsCmd.Flags().BoolVarP(&tail, "tail", "t", false, "Tail the logs (deprecated, use --follow)") logsCmd.Flags().BoolVarP(&follow, "follow", "f", false, "Tail and follow the logs") return logsCmd diff --git a/cmd/ghcs/ports.go b/cmd/ghcs/ports.go index 4b190b36f..f6d1f6e2a 100644 --- a/cmd/ghcs/ports.go +++ b/cmd/ghcs/ports.go @@ -172,14 +172,7 @@ func newPortsPublicCmd() *cobra.Command { } log := output.NewLogger(os.Stdout, os.Stderr, false) - - port := args[0] - if len(args) > 1 { - log.Errorln(" argument is deprecated. Use --codespace instead.") - codespace, port = args[0], args[1] - } - - return updatePortVisibility(log, codespace, port, true) + return updatePortVisibility(log, codespace, args[0], true) }, } } @@ -200,14 +193,7 @@ func newPortsPrivateCmd() *cobra.Command { } log := output.NewLogger(os.Stdout, os.Stderr, false) - - port := args[0] - if len(args) > 1 { - log.Errorln(" argument is deprecated. Use --codespace instead.") - codespace, port = args[0], args[1] - } - - return updatePortVisibility(log, codespace, port, false) + return updatePortVisibility(log, codespace, args[0], false) }, } } @@ -269,16 +255,7 @@ func newPortsForwardCmd() *cobra.Command { } log := output.NewLogger(os.Stdout, os.Stderr, false) - - ports := args[0:] - if len(args) > 1 && !strings.Contains(args[0], ":") { - // assume this is a codespace name - log.Errorln(" argument is deprecated. Use --codespace instead.") - codespace = args[0] - ports = args[1:] - } - - return forwardPorts(log, codespace, ports) + return forwardPorts(log, codespace, args) }, } } From 497b45e4e2e41c2fca55f61995a289a6e62022e6 Mon Sep 17 00:00:00 2001 From: Jose Garcia Date: Tue, 14 Sep 2021 23:57:40 +0000 Subject: [PATCH 196/290] ssh server docs --- ssh_server.go | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/ssh_server.go b/ssh_server.go index 03b45f25f..ca66ec7de 100644 --- a/ssh_server.go +++ b/ssh_server.go @@ -4,14 +4,20 @@ import ( "context" ) +// A SSHServer handles starting the remote SSH server. +// If there is no SSH server available it installs one. type SSHServer struct { session *Session } +// SSHServer returns a new SSHServer from the LiveShare Session. func (session *Session) SSHServer() *SSHServer { return &SSHServer{session: session} } +// SSHServerStartResult contains whether or not the start of the SSH server was +// successful. If it succeeded the server port and user is included. If it failed, +// it contains an explanation message. type SSHServerStartResult struct { Result bool `json:"result"` ServerPort string `json:"serverPort"` @@ -19,6 +25,7 @@ type SSHServerStartResult struct { Message string `json:"message"` } +// StartRemoteServer starts or install the remote SSH server and returns the result. func (s *SSHServer) StartRemoteServer(ctx context.Context) (*SSHServerStartResult, error) { var response SSHServerStartResult From fb5a35568ca44340ba6e332452bca6873baae62d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mislav=20Marohni=C4=87?= Date: Wed, 15 Sep 2021 13:58:10 +0200 Subject: [PATCH 197/290] Ensure original errors are wrapped with "%w" instead of "%v" --- api/api.go | 92 +++++++++++++++---------------- cmd/ghcs/code.go | 4 +- cmd/ghcs/common.go | 12 ++-- cmd/ghcs/create.go | 20 +++---- cmd/ghcs/delete.go | 22 ++++---- cmd/ghcs/list.go | 4 +- cmd/ghcs/logs.go | 12 ++-- cmd/ghcs/ports.go | 40 +++++++------- cmd/ghcs/ssh.go | 12 ++-- internal/codespaces/codespaces.go | 6 +- internal/codespaces/ssh.go | 4 +- internal/codespaces/states.go | 14 ++--- 12 files changed, 121 insertions(+), 121 deletions(-) diff --git a/api/api.go b/api/api.go index ad69f23fc..47948bbba 100644 --- a/api/api.go +++ b/api/api.go @@ -40,19 +40,19 @@ type User struct { func (a *API) GetUser(ctx context.Context) (*User, error) { req, err := http.NewRequest(http.MethodGet, githubAPI+"/user", nil) if err != nil { - return nil, fmt.Errorf("error creating request: %v", err) + return nil, fmt.Errorf("error creating request: %w", err) } a.setHeaders(req) resp, err := a.do(ctx, req, "/user") if err != nil { - return nil, fmt.Errorf("error making request: %v", err) + return nil, fmt.Errorf("error making request: %w", err) } defer resp.Body.Close() b, err := ioutil.ReadAll(resp.Body) if err != nil { - return nil, fmt.Errorf("error reading response body: %v", err) + return nil, fmt.Errorf("error reading response body: %w", err) } if resp.StatusCode != http.StatusOK { @@ -61,7 +61,7 @@ func (a *API) GetUser(ctx context.Context) (*User, error) { var response User if err := json.Unmarshal(b, &response); err != nil { - return nil, fmt.Errorf("error unmarshaling response: %v", err) + return nil, fmt.Errorf("error unmarshaling response: %w", err) } return &response, nil @@ -72,7 +72,7 @@ func jsonErrorResponse(b []byte) error { Message string `json:"message"` } if err := json.Unmarshal(b, &response); err != nil { - return fmt.Errorf("error unmarshaling error response: %v", err) + return fmt.Errorf("error unmarshaling error response: %w", err) } return errors.New(response.Message) @@ -85,19 +85,19 @@ type Repository struct { func (a *API) GetRepository(ctx context.Context, nwo string) (*Repository, error) { req, err := http.NewRequest(http.MethodGet, githubAPI+"/repos/"+strings.ToLower(nwo), nil) if err != nil { - return nil, fmt.Errorf("error creating request: %v", err) + return nil, fmt.Errorf("error creating request: %w", err) } a.setHeaders(req) resp, err := a.do(ctx, req, "/repos/*") if err != nil { - return nil, fmt.Errorf("error making request: %v", err) + return nil, fmt.Errorf("error making request: %w", err) } defer resp.Body.Close() b, err := ioutil.ReadAll(resp.Body) if err != nil { - return nil, fmt.Errorf("error reading response body: %v", err) + return nil, fmt.Errorf("error reading response body: %w", err) } if resp.StatusCode != http.StatusOK { @@ -106,7 +106,7 @@ func (a *API) GetRepository(ctx context.Context, nwo string) (*Repository, error var response Repository if err := json.Unmarshal(b, &response); err != nil { - return nil, fmt.Errorf("error unmarshaling response: %v", err) + return nil, fmt.Errorf("error unmarshaling response: %w", err) } return &response, nil @@ -154,19 +154,19 @@ func (a *API) ListCodespaces(ctx context.Context, user *User) ([]*Codespace, err http.MethodGet, githubAPI+"/vscs_internal/user/"+user.Login+"/codespaces", nil, ) if err != nil { - return nil, fmt.Errorf("error creating request: %v", err) + return nil, fmt.Errorf("error creating request: %w", err) } a.setHeaders(req) resp, err := a.do(ctx, req, "/vscs_internal/user/*/codespaces") if err != nil { - return nil, fmt.Errorf("error making request: %v", err) + return nil, fmt.Errorf("error making request: %w", err) } defer resp.Body.Close() b, err := ioutil.ReadAll(resp.Body) if err != nil { - return nil, fmt.Errorf("error reading response body: %v", err) + return nil, fmt.Errorf("error reading response body: %w", err) } if resp.StatusCode != http.StatusOK { @@ -177,7 +177,7 @@ func (a *API) ListCodespaces(ctx context.Context, user *User) ([]*Codespace, err Codespaces []*Codespace `json:"codespaces"` } if err := json.Unmarshal(b, &response); err != nil { - return nil, fmt.Errorf("error unmarshaling response: %v", err) + return nil, fmt.Errorf("error unmarshaling response: %w", err) } return response.Codespaces, nil } @@ -193,7 +193,7 @@ type getCodespaceTokenResponse struct { func (a *API) GetCodespaceToken(ctx context.Context, ownerLogin, codespaceName string) (string, error) { reqBody, err := json.Marshal(getCodespaceTokenRequest{true}) if err != nil { - return "", fmt.Errorf("error preparing request body: %v", err) + return "", fmt.Errorf("error preparing request body: %w", err) } req, err := http.NewRequest( @@ -202,19 +202,19 @@ func (a *API) GetCodespaceToken(ctx context.Context, ownerLogin, codespaceName s bytes.NewBuffer(reqBody), ) if err != nil { - return "", fmt.Errorf("error creating request: %v", err) + return "", fmt.Errorf("error creating request: %w", err) } a.setHeaders(req) resp, err := a.do(ctx, req, "/vscs_internal/user/*/codespaces/*/token") if err != nil { - return "", fmt.Errorf("error making request: %v", err) + return "", fmt.Errorf("error making request: %w", err) } defer resp.Body.Close() b, err := ioutil.ReadAll(resp.Body) if err != nil { - return "", fmt.Errorf("error reading response body: %v", err) + return "", fmt.Errorf("error reading response body: %w", err) } if resp.StatusCode != http.StatusOK { @@ -223,7 +223,7 @@ func (a *API) GetCodespaceToken(ctx context.Context, ownerLogin, codespaceName s var response getCodespaceTokenResponse if err := json.Unmarshal(b, &response); err != nil { - return "", fmt.Errorf("error unmarshaling response: %v", err) + return "", fmt.Errorf("error unmarshaling response: %w", err) } return response.RepositoryToken, nil @@ -236,19 +236,19 @@ func (a *API) GetCodespace(ctx context.Context, token, owner, codespace string) nil, ) if err != nil { - return nil, fmt.Errorf("error creating request: %v", err) + return nil, fmt.Errorf("error creating request: %w", err) } req.Header.Set("Authorization", "Bearer "+token) resp, err := a.do(ctx, req, "/vscs_internal/user/*/codespaces/*") if err != nil { - return nil, fmt.Errorf("error making request: %v", err) + return nil, fmt.Errorf("error making request: %w", err) } defer resp.Body.Close() b, err := ioutil.ReadAll(resp.Body) if err != nil { - return nil, fmt.Errorf("error reading response body: %v", err) + return nil, fmt.Errorf("error reading response body: %w", err) } if resp.StatusCode != http.StatusOK { @@ -257,7 +257,7 @@ func (a *API) GetCodespace(ctx context.Context, token, owner, codespace string) var response Codespace if err := json.Unmarshal(b, &response); err != nil { - return nil, fmt.Errorf("error unmarshaling response: %v", err) + return nil, fmt.Errorf("error unmarshaling response: %w", err) } return &response, nil @@ -270,19 +270,19 @@ func (a *API) StartCodespace(ctx context.Context, token string, codespace *Codes nil, ) if err != nil { - return fmt.Errorf("error creating request: %v", err) + return fmt.Errorf("error creating request: %w", err) } req.Header.Set("Authorization", "Bearer "+token) resp, err := a.do(ctx, req, "/vscs_internal/proxy/environments/*/start") if err != nil { - return fmt.Errorf("error making request: %v", err) + return fmt.Errorf("error making request: %w", err) } defer resp.Body.Close() b, err := ioutil.ReadAll(resp.Body) if err != nil { - return fmt.Errorf("error reading response body: %v", err) + return fmt.Errorf("error reading response body: %w", err) } if resp.StatusCode != http.StatusOK { @@ -308,18 +308,18 @@ type getCodespaceRegionLocationResponse struct { func (a *API) GetCodespaceRegionLocation(ctx context.Context) (string, error) { req, err := http.NewRequest(http.MethodGet, "https://online.visualstudio.com/api/v1/locations", nil) if err != nil { - return "", fmt.Errorf("error creating request: %v", err) + return "", fmt.Errorf("error creating request: %w", err) } resp, err := a.do(ctx, req, req.URL.String()) if err != nil { - return "", fmt.Errorf("error making request: %v", err) + return "", fmt.Errorf("error making request: %w", err) } defer resp.Body.Close() b, err := ioutil.ReadAll(resp.Body) if err != nil { - return "", fmt.Errorf("error reading response body: %v", err) + return "", fmt.Errorf("error reading response body: %w", err) } if resp.StatusCode != http.StatusOK { @@ -328,7 +328,7 @@ func (a *API) GetCodespaceRegionLocation(ctx context.Context) (string, error) { var response getCodespaceRegionLocationResponse if err := json.Unmarshal(b, &response); err != nil { - return "", fmt.Errorf("error unmarshaling response: %v", err) + return "", fmt.Errorf("error unmarshaling response: %w", err) } return response.Current, nil @@ -342,7 +342,7 @@ type SKU struct { func (a *API) GetCodespacesSKUs(ctx context.Context, user *User, repository *Repository, branch, location string) ([]*SKU, error) { req, err := http.NewRequest(http.MethodGet, githubAPI+"/vscs_internal/user/"+user.Login+"/skus", nil) if err != nil { - return nil, fmt.Errorf("error creating request: %v", err) + return nil, fmt.Errorf("error creating request: %w", err) } q := req.URL.Query() @@ -354,13 +354,13 @@ func (a *API) GetCodespacesSKUs(ctx context.Context, user *User, repository *Rep a.setHeaders(req) resp, err := a.do(ctx, req, "/vscs_internal/user/*/skus") if err != nil { - return nil, fmt.Errorf("error making request: %v", err) + return nil, fmt.Errorf("error making request: %w", err) } defer resp.Body.Close() b, err := ioutil.ReadAll(resp.Body) if err != nil { - return nil, fmt.Errorf("error reading response body: %v", err) + return nil, fmt.Errorf("error reading response body: %w", err) } if resp.StatusCode != http.StatusOK { @@ -371,7 +371,7 @@ func (a *API) GetCodespacesSKUs(ctx context.Context, user *User, repository *Rep SKUs []*SKU `json:"skus"` } if err := json.Unmarshal(b, &response); err != nil { - return nil, fmt.Errorf("error unmarshaling response: %v", err) + return nil, fmt.Errorf("error unmarshaling response: %w", err) } return response.SKUs, nil @@ -387,24 +387,24 @@ type createCodespaceRequest struct { func (a *API) CreateCodespace(ctx context.Context, user *User, repository *Repository, sku, branch, location string) (*Codespace, error) { requestBody, err := json.Marshal(createCodespaceRequest{repository.ID, branch, location, sku}) if err != nil { - return nil, fmt.Errorf("error marshaling request: %v", err) + return nil, fmt.Errorf("error marshaling request: %w", err) } req, err := http.NewRequest(http.MethodPost, githubAPI+"/vscs_internal/user/"+user.Login+"/codespaces", bytes.NewBuffer(requestBody)) if err != nil { - return nil, fmt.Errorf("error creating request: %v", err) + return nil, fmt.Errorf("error creating request: %w", err) } a.setHeaders(req) resp, err := a.do(ctx, req, "/vscs_internal/user/*/codespaces") if err != nil { - return nil, fmt.Errorf("error making request: %v", err) + return nil, fmt.Errorf("error making request: %w", err) } defer resp.Body.Close() b, err := ioutil.ReadAll(resp.Body) if err != nil { - return nil, fmt.Errorf("error reading response body: %v", err) + return nil, fmt.Errorf("error reading response body: %w", err) } if resp.StatusCode > http.StatusAccepted { @@ -413,7 +413,7 @@ func (a *API) CreateCodespace(ctx context.Context, user *User, repository *Repos var response Codespace if err := json.Unmarshal(b, &response); err != nil { - return nil, fmt.Errorf("error unmarshaling response: %v", err) + return nil, fmt.Errorf("error unmarshaling response: %w", err) } return &response, nil @@ -422,20 +422,20 @@ func (a *API) CreateCodespace(ctx context.Context, user *User, repository *Repos func (a *API) DeleteCodespace(ctx context.Context, user *User, token, codespaceName string) error { req, err := http.NewRequest(http.MethodDelete, githubAPI+"/vscs_internal/user/"+user.Login+"/codespaces/"+codespaceName, nil) if err != nil { - return fmt.Errorf("error creating request: %v", err) + return fmt.Errorf("error creating request: %w", err) } req.Header.Set("Authorization", "Bearer "+token) resp, err := a.do(ctx, req, "/vscs_internal/user/*/codespaces/*") if err != nil { - return fmt.Errorf("error making request: %v", err) + return fmt.Errorf("error making request: %w", err) } defer resp.Body.Close() if resp.StatusCode > http.StatusAccepted { b, err := ioutil.ReadAll(resp.Body) if err != nil { - return fmt.Errorf("error reading response body: %v", err) + return fmt.Errorf("error reading response body: %w", err) } return jsonErrorResponse(b) } @@ -450,7 +450,7 @@ type getCodespaceRepositoryContentsResponse struct { func (a *API) GetCodespaceRepositoryContents(ctx context.Context, codespace *Codespace, path string) ([]byte, error) { req, err := http.NewRequest(http.MethodGet, githubAPI+"/repos/"+codespace.RepositoryNWO+"/contents/"+path, nil) if err != nil { - return nil, fmt.Errorf("error creating request: %v", err) + return nil, fmt.Errorf("error creating request: %w", err) } q := req.URL.Query() @@ -460,7 +460,7 @@ func (a *API) GetCodespaceRepositoryContents(ctx context.Context, codespace *Cod a.setHeaders(req) resp, err := a.do(ctx, req, "/repos/*/contents/*") if err != nil { - return nil, fmt.Errorf("error making request: %v", err) + return nil, fmt.Errorf("error making request: %w", err) } defer resp.Body.Close() @@ -470,7 +470,7 @@ func (a *API) GetCodespaceRepositoryContents(ctx context.Context, codespace *Cod b, err := ioutil.ReadAll(resp.Body) if err != nil { - return nil, fmt.Errorf("error reading response body: %v", err) + return nil, fmt.Errorf("error reading response body: %w", err) } if resp.StatusCode != http.StatusOK { @@ -479,12 +479,12 @@ func (a *API) GetCodespaceRepositoryContents(ctx context.Context, codespace *Cod var response getCodespaceRepositoryContentsResponse if err := json.Unmarshal(b, &response); err != nil { - return nil, fmt.Errorf("error unmarshaling response: %v", err) + return nil, fmt.Errorf("error unmarshaling response: %w", err) } decoded, err := base64.StdEncoding.DecodeString(response.Content) if err != nil { - return nil, fmt.Errorf("error decoding content: %v", err) + return nil, fmt.Errorf("error decoding content: %w", err) } return decoded, nil diff --git a/cmd/ghcs/code.go b/cmd/ghcs/code.go index 3bd67053d..bdad09828 100644 --- a/cmd/ghcs/code.go +++ b/cmd/ghcs/code.go @@ -42,7 +42,7 @@ func code(codespaceName string, useInsiders bool) error { user, err := apiClient.GetUser(ctx) if err != nil { - return fmt.Errorf("error getting user: %v", err) + return fmt.Errorf("error getting user: %w", err) } if codespaceName == "" { @@ -51,7 +51,7 @@ func code(codespaceName string, useInsiders bool) error { if err == errNoCodespaces { return err } - return fmt.Errorf("error choosing codespace: %v", err) + return fmt.Errorf("error choosing codespace: %w", err) } codespaceName = codespace.Name } diff --git a/cmd/ghcs/common.go b/cmd/ghcs/common.go index bfcb67496..2e716e897 100644 --- a/cmd/ghcs/common.go +++ b/cmd/ghcs/common.go @@ -20,7 +20,7 @@ var errNoCodespaces = errors.New("You have no codespaces.") func chooseCodespace(ctx context.Context, apiClient *api.API, user *api.User) (*api.Codespace, error) { codespaces, err := apiClient.ListCodespaces(ctx, user) if err != nil { - return nil, fmt.Errorf("error getting codespaces: %v", err) + return nil, fmt.Errorf("error getting codespaces: %w", err) } if len(codespaces) == 0 { @@ -54,7 +54,7 @@ func chooseCodespace(ctx context.Context, apiClient *api.API, user *api.User) (* Codespace string } if err := ask(sshSurvey, &answers); err != nil { - return nil, fmt.Errorf("error getting answers: %v", err) + return nil, fmt.Errorf("error getting answers: %w", err) } codespace := codespacesByName[answers.Codespace] @@ -70,23 +70,23 @@ func getOrChooseCodespace(ctx context.Context, apiClient *api.API, user *api.Use if err == errNoCodespaces { return nil, "", err } - return nil, "", fmt.Errorf("choosing codespace: %v", err) + return nil, "", fmt.Errorf("choosing codespace: %w", err) } codespaceName = codespace.Name token, err = apiClient.GetCodespaceToken(ctx, user.Login, codespaceName) if err != nil { - return nil, "", fmt.Errorf("getting codespace token: %v", err) + return nil, "", fmt.Errorf("getting codespace token: %w", err) } } else { token, err = apiClient.GetCodespaceToken(ctx, user.Login, codespaceName) if err != nil { - return nil, "", fmt.Errorf("getting codespace token for given codespace: %v", err) + return nil, "", fmt.Errorf("getting codespace token for given codespace: %w", err) } codespace, err = apiClient.GetCodespace(ctx, token, user.Login, codespaceName) if err != nil { - return nil, "", fmt.Errorf("getting full codespace details: %v", err) + return nil, "", fmt.Errorf("getting full codespace details: %w", err) } } diff --git a/cmd/ghcs/create.go b/cmd/ghcs/create.go index 093450e7d..9f9b1da7a 100644 --- a/cmd/ghcs/create.go +++ b/cmd/ghcs/create.go @@ -55,31 +55,31 @@ func create(opts *createOptions) error { repo, err := getRepoName(opts.repo) if err != nil { - return fmt.Errorf("error getting repository name: %v", err) + return fmt.Errorf("error getting repository name: %w", err) } branch, err := getBranchName(opts.branch) if err != nil { - return fmt.Errorf("error getting branch name: %v", err) + return fmt.Errorf("error getting branch name: %w", err) } repository, err := apiClient.GetRepository(ctx, repo) if err != nil { - return fmt.Errorf("error getting repository: %v", err) + return fmt.Errorf("error getting repository: %w", err) } locationResult := <-locationCh if locationResult.Err != nil { - return fmt.Errorf("error getting codespace region location: %v", locationResult.Err) + return fmt.Errorf("error getting codespace region location: %w", locationResult.Err) } userResult := <-userCh if userResult.Err != nil { - return fmt.Errorf("error getting codespace user: %v", userResult.Err) + return fmt.Errorf("error getting codespace user: %w", userResult.Err) } machine, err := getMachineName(ctx, opts.machine, userResult.User, repository, branch, locationResult.Location, apiClient) if err != nil { - return fmt.Errorf("error getting machine type: %v", err) + return fmt.Errorf("error getting machine type: %w", err) } if machine == "" { return errors.New("There are no available machine types for this repository") @@ -89,7 +89,7 @@ func create(opts *createOptions) error { codespace, err := apiClient.CreateCodespace(ctx, userResult.User, repository, machine, branch, locationResult.Location) if err != nil { - return fmt.Errorf("error creating codespace: %v", err) + return fmt.Errorf("error creating codespace: %w", err) } if opts.showStatus { @@ -154,7 +154,7 @@ func showStatus(ctx context.Context, log *output.Logger, apiClient *api.API, use } if err := codespaces.PollPostCreateStates(ctx, log, apiClient, user, codespace, poller); err != nil { - return fmt.Errorf("failed to poll state changes from codespace: %v", err) + return fmt.Errorf("failed to poll state changes from codespace: %w", err) } return nil @@ -228,7 +228,7 @@ func getBranchName(branch string) (string, error) { func getMachineName(ctx context.Context, machine string, user *api.User, repo *api.Repository, branch, location string, apiClient *api.API) (string, error) { skus, err := apiClient.GetCodespacesSKUs(ctx, user, repo, branch, location) if err != nil { - return "", fmt.Errorf("error requesting machine instance types: %v", err) + return "", fmt.Errorf("error requesting machine instance types: %w", err) } // if user supplied a machine type, it must be valid @@ -278,7 +278,7 @@ func getMachineName(ctx context.Context, machine string, user *api.User, repo *a var skuAnswers struct{ SKU string } if err := ask(skuSurvey, &skuAnswers); err != nil { - return "", fmt.Errorf("error getting SKU: %v", err) + return "", fmt.Errorf("error getting SKU: %w", err) } sku := skuByName[skuAnswers.SKU] diff --git a/cmd/ghcs/delete.go b/cmd/ghcs/delete.go index 7800a13c0..100793010 100644 --- a/cmd/ghcs/delete.go +++ b/cmd/ghcs/delete.go @@ -55,16 +55,16 @@ func delete_(log *output.Logger, codespaceName string) error { user, err := apiClient.GetUser(ctx) if err != nil { - return fmt.Errorf("error getting user: %v", err) + return fmt.Errorf("error getting user: %w", err) } codespace, token, err := getOrChooseCodespace(ctx, apiClient, user, codespaceName) if err != nil { - return fmt.Errorf("get or choose codespace: %v", err) + return fmt.Errorf("get or choose codespace: %w", err) } if err := apiClient.DeleteCodespace(ctx, user, token, codespace.Name); err != nil { - return fmt.Errorf("error deleting codespace: %v", err) + return fmt.Errorf("error deleting codespace: %w", err) } log.Println("Codespace deleted.") @@ -78,22 +78,22 @@ func deleteAll(log *output.Logger) error { user, err := apiClient.GetUser(ctx) if err != nil { - return fmt.Errorf("error getting user: %v", err) + return fmt.Errorf("error getting user: %w", err) } codespaces, err := apiClient.ListCodespaces(ctx, user) if err != nil { - return fmt.Errorf("error getting codespaces: %v", err) + return fmt.Errorf("error getting codespaces: %w", err) } for _, c := range codespaces { token, err := apiClient.GetCodespaceToken(ctx, user.Login, c.Name) if err != nil { - return fmt.Errorf("error getting codespace token: %v", err) + return fmt.Errorf("error getting codespace token: %w", err) } if err := apiClient.DeleteCodespace(ctx, user, token, c.Name); err != nil { - return fmt.Errorf("error deleting codespace: %v", err) + return fmt.Errorf("error deleting codespace: %w", err) } log.Printf("Codespace deleted: %s\n", c.Name) @@ -108,12 +108,12 @@ func deleteByRepo(log *output.Logger, repo string) error { user, err := apiClient.GetUser(ctx) if err != nil { - return fmt.Errorf("error getting user: %v", err) + return fmt.Errorf("error getting user: %w", err) } codespaces, err := apiClient.ListCodespaces(ctx, user) if err != nil { - return fmt.Errorf("error getting codespaces: %v", err) + return fmt.Errorf("error getting codespaces: %w", err) } var deleted bool @@ -125,11 +125,11 @@ func deleteByRepo(log *output.Logger, repo string) error { token, err := apiClient.GetCodespaceToken(ctx, user.Login, c.Name) if err != nil { - return fmt.Errorf("error getting codespace token: %v", err) + return fmt.Errorf("error getting codespace token: %w", err) } if err := apiClient.DeleteCodespace(ctx, user, token, c.Name); err != nil { - return fmt.Errorf("error deleting codespace: %v", err) + return fmt.Errorf("error deleting codespace: %w", err) } log.Printf("Codespace deleted: %s\n", c.Name) diff --git a/cmd/ghcs/list.go b/cmd/ghcs/list.go index ee26e3013..a315d12ed 100644 --- a/cmd/ghcs/list.go +++ b/cmd/ghcs/list.go @@ -41,12 +41,12 @@ func list(opts *listOptions) error { user, err := apiClient.GetUser(ctx) if err != nil { - return fmt.Errorf("error getting user: %v", err) + return fmt.Errorf("error getting user: %w", err) } codespaces, err := apiClient.ListCodespaces(ctx, user) if err != nil { - return fmt.Errorf("error getting codespaces: %v", err) + return fmt.Errorf("error getting codespaces: %w", err) } table := output.NewTable(os.Stdout, opts.asJSON) diff --git a/cmd/ghcs/logs.go b/cmd/ghcs/logs.go index a685a364d..6b5e2f875 100644 --- a/cmd/ghcs/logs.go +++ b/cmd/ghcs/logs.go @@ -49,17 +49,17 @@ func logs(ctx context.Context, log *output.Logger, codespaceName string, follow user, err := apiClient.GetUser(ctx) if err != nil { - return fmt.Errorf("getting user: %v", err) + return fmt.Errorf("getting user: %w", err) } codespace, token, err := getOrChooseCodespace(ctx, apiClient, user, codespaceName) if err != nil { - return fmt.Errorf("get or choose codespace: %v", err) + return fmt.Errorf("get or choose codespace: %w", err) } session, err := codespaces.ConnectToLiveshare(ctx, log, apiClient, user.Login, token, codespace) if err != nil { - return fmt.Errorf("connecting to Live Share: %v", err) + return fmt.Errorf("connecting to Live Share: %w", err) } // Ensure local port is listening before client (getPostCreateOutput) connects. @@ -72,7 +72,7 @@ func logs(ctx context.Context, log *output.Logger, codespaceName string, follow remoteSSHServerPort, sshUser, err := codespaces.StartSSHServer(ctx, session, log) if err != nil { - return fmt.Errorf("error getting ssh server details: %v", err) + return fmt.Errorf("error getting ssh server details: %w", err) } cmdType := "cat" @@ -98,10 +98,10 @@ func logs(ctx context.Context, log *output.Logger, codespaceName string, follow select { case err := <-tunnelClosed: - return fmt.Errorf("connection closed: %v", err) + return fmt.Errorf("connection closed: %w", err) case err := <-cmdDone: if err != nil { - return fmt.Errorf("error retrieving logs: %v", err) + return fmt.Errorf("error retrieving logs: %w", err) } return nil // success diff --git a/cmd/ghcs/ports.go b/cmd/ghcs/ports.go index f6d1f6e2a..12c631c04 100644 --- a/cmd/ghcs/ports.go +++ b/cmd/ghcs/ports.go @@ -58,7 +58,7 @@ func ports(codespaceName string, asJSON bool) error { user, err := apiClient.GetUser(ctx) if err != nil { - return fmt.Errorf("error getting user: %v", err) + return fmt.Errorf("error getting user: %w", err) } codespace, token, err := getOrChooseCodespace(ctx, apiClient, user, codespaceName) @@ -67,20 +67,20 @@ func ports(codespaceName string, asJSON bool) error { if err == errNoCodespaces { return err } - return fmt.Errorf("error choosing codespace: %v", err) + return fmt.Errorf("error choosing codespace: %w", err) } devContainerCh := getDevContainer(ctx, apiClient, codespace) session, err := codespaces.ConnectToLiveshare(ctx, log, apiClient, user.Login, token, codespace) if err != nil { - return fmt.Errorf("error connecting to Live Share: %v", err) + return fmt.Errorf("error connecting to Live Share: %w", err) } log.Println("Loading ports...") ports, err := session.GetSharedServers(ctx) if err != nil { - return fmt.Errorf("error getting ports of shared servers: %v", err) + return fmt.Errorf("error getting ports of shared servers: %w", err) } devContainerResult := <-devContainerCh @@ -130,7 +130,7 @@ func getDevContainer(ctx context.Context, apiClient *api.API, codespace *api.Cod go func() { contents, err := apiClient.GetCodespaceRepositoryContents(ctx, codespace, ".devcontainer/devcontainer.json") if err != nil { - ch <- devContainerResult{nil, fmt.Errorf("error getting content: %v", err)} + ch <- devContainerResult{nil, fmt.Errorf("error getting content: %w", err)} return } @@ -147,7 +147,7 @@ func getDevContainer(ctx context.Context, apiClient *api.API, codespace *api.Cod var container devContainer if err := json.Unmarshal(convertedJSON, &container); err != nil { - ch <- devContainerResult{nil, fmt.Errorf("error unmarshaling: %v", err)} + ch <- devContainerResult{nil, fmt.Errorf("error unmarshaling: %w", err)} return } @@ -168,7 +168,7 @@ func newPortsPublicCmd() *cobra.Command { // should only happen if flag is not defined // or if the flag is not of string type // since it's a persistent flag that we control it should never happen - return fmt.Errorf("get codespace flag: %v", err) + return fmt.Errorf("get codespace flag: %w", err) } log := output.NewLogger(os.Stdout, os.Stderr, false) @@ -189,7 +189,7 @@ func newPortsPrivateCmd() *cobra.Command { // should only happen if flag is not defined // or if the flag is not of string type // since it's a persistent flag that we control it should never happen - return fmt.Errorf("get codespace flag: %v", err) + return fmt.Errorf("get codespace flag: %w", err) } log := output.NewLogger(os.Stdout, os.Stderr, false) @@ -204,7 +204,7 @@ func updatePortVisibility(log *output.Logger, codespaceName, sourcePort string, user, err := apiClient.GetUser(ctx) if err != nil { - return fmt.Errorf("error getting user: %v", err) + return fmt.Errorf("error getting user: %w", err) } codespace, token, err := getOrChooseCodespace(ctx, apiClient, user, codespaceName) @@ -212,21 +212,21 @@ func updatePortVisibility(log *output.Logger, codespaceName, sourcePort string, if err == errNoCodespaces { return err } - return fmt.Errorf("error getting codespace: %v", err) + return fmt.Errorf("error getting codespace: %w", err) } session, err := codespaces.ConnectToLiveshare(ctx, log, apiClient, user.Login, token, codespace) if err != nil { - return fmt.Errorf("error connecting to Live Share: %v", err) + return fmt.Errorf("error connecting to Live Share: %w", err) } port, err := strconv.Atoi(sourcePort) if err != nil { - return fmt.Errorf("error reading port number: %v", err) + return fmt.Errorf("error reading port number: %w", err) } if err := session.UpdateSharedVisibility(ctx, port, public); err != nil { - return fmt.Errorf("error update port to public: %v", err) + return fmt.Errorf("error update port to public: %w", err) } state := "PUBLIC" @@ -251,7 +251,7 @@ func newPortsForwardCmd() *cobra.Command { // should only happen if flag is not defined // or if the flag is not of string type // since it's a persistent flag that we control it should never happen - return fmt.Errorf("get codespace flag: %v", err) + return fmt.Errorf("get codespace flag: %w", err) } log := output.NewLogger(os.Stdout, os.Stderr, false) @@ -266,12 +266,12 @@ func forwardPorts(log *output.Logger, codespaceName string, ports []string) erro portPairs, err := getPortPairs(ports) if err != nil { - return fmt.Errorf("get port pairs: %v", err) + return fmt.Errorf("get port pairs: %w", err) } user, err := apiClient.GetUser(ctx) if err != nil { - return fmt.Errorf("error getting user: %v", err) + return fmt.Errorf("error getting user: %w", err) } codespace, token, err := getOrChooseCodespace(ctx, apiClient, user, codespaceName) @@ -279,12 +279,12 @@ func forwardPorts(log *output.Logger, codespaceName string, ports []string) erro if err == errNoCodespaces { return err } - return fmt.Errorf("error getting codespace: %v", err) + return fmt.Errorf("error getting codespace: %w", err) } session, err := codespaces.ConnectToLiveshare(ctx, log, apiClient, user.Login, token, codespace) if err != nil { - return fmt.Errorf("error connecting to Live Share: %v", err) + return fmt.Errorf("error connecting to Live Share: %w", err) } // Run forwarding of all ports concurrently, aborting all of @@ -323,12 +323,12 @@ func getPortPairs(ports []string) ([]portPair, error) { remote, err := strconv.Atoi(parts[0]) if err != nil { - return pp, fmt.Errorf("convert remote port to int: %v", err) + return pp, fmt.Errorf("convert remote port to int: %w", err) } local, err := strconv.Atoi(parts[1]) if err != nil { - return pp, fmt.Errorf("convert local port to int: %v", err) + return pp, fmt.Errorf("convert local port to int: %w", err) } pp = append(pp, portPair{remote, local}) diff --git a/cmd/ghcs/ssh.go b/cmd/ghcs/ssh.go index e5289e205..fd2086bcc 100644 --- a/cmd/ghcs/ssh.go +++ b/cmd/ghcs/ssh.go @@ -46,22 +46,22 @@ func ssh(ctx context.Context, sshProfile, codespaceName string, localSSHServerPo user, err := apiClient.GetUser(ctx) if err != nil { - return fmt.Errorf("error getting user: %v", err) + return fmt.Errorf("error getting user: %w", err) } codespace, token, err := getOrChooseCodespace(ctx, apiClient, user, codespaceName) if err != nil { - return fmt.Errorf("get or choose codespace: %v", err) + return fmt.Errorf("get or choose codespace: %w", err) } session, err := codespaces.ConnectToLiveshare(ctx, log, apiClient, user.Login, token, codespace) if err != nil { - return fmt.Errorf("error connecting to Live Share: %v", err) + return fmt.Errorf("error connecting to Live Share: %w", err) } remoteSSHServerPort, sshUser, err := codespaces.StartSSHServer(ctx, session, log) if err != nil { - return fmt.Errorf("error getting ssh server details: %v", err) + return fmt.Errorf("error getting ssh server details: %w", err) } usingCustomPort := localSSHServerPort != 0 // suppress log of command line in Shell @@ -93,10 +93,10 @@ func ssh(ctx context.Context, sshProfile, codespaceName string, localSSHServerPo select { case err := <-tunnelClosed: - return fmt.Errorf("tunnel closed: %v", err) + return fmt.Errorf("tunnel closed: %w", err) case err := <-shellClosed: if err != nil { - return fmt.Errorf("shell closed: %v", err) + return fmt.Errorf("shell closed: %w", err) } return nil // success } diff --git a/internal/codespaces/codespaces.go b/internal/codespaces/codespaces.go index 9aee3564c..805eb3c96 100644 --- a/internal/codespaces/codespaces.go +++ b/internal/codespaces/codespaces.go @@ -29,7 +29,7 @@ func ConnectToLiveshare(ctx context.Context, log logger, apiClient *api.API, use startedCodespace = true log.Print("Starting your codespace...") if err := apiClient.StartCodespace(ctx, token, codespace); err != nil { - return nil, fmt.Errorf("error starting codespace: %v", err) + return nil, fmt.Errorf("error starting codespace: %w", err) } } @@ -49,7 +49,7 @@ func ConnectToLiveshare(ctx context.Context, log logger, apiClient *api.API, use var err error codespace, err = apiClient.GetCodespace(ctx, token, userLogin, codespace.Name) if err != nil { - return nil, fmt.Errorf("error getting codespace: %v", err) + return nil, fmt.Errorf("error getting codespace: %w", err) } } @@ -68,7 +68,7 @@ func ConnectToLiveshare(ctx context.Context, log logger, apiClient *api.API, use }), ) if err != nil { - return nil, fmt.Errorf("error creating Live Share client: %v", err) + return nil, fmt.Errorf("error creating Live Share client: %w", err) } return lsclient.JoinWorkspace(ctx) diff --git a/internal/codespaces/ssh.go b/internal/codespaces/ssh.go index 14dbfbb88..256acee2b 100644 --- a/internal/codespaces/ssh.go +++ b/internal/codespaces/ssh.go @@ -21,7 +21,7 @@ func StartSSHServer(ctx context.Context, session *liveshare.Session, log logger) sshServerStartResult, err := sshServer.StartRemoteServer(ctx) if err != nil { - return 0, "", fmt.Errorf("error starting live share: %v", err) + return 0, "", fmt.Errorf("error starting live share: %w", err) } if !sshServerStartResult.Result { @@ -30,7 +30,7 @@ func StartSSHServer(ctx context.Context, session *liveshare.Session, log logger) portInt, err := strconv.Atoi(sshServerStartResult.ServerPort) if err != nil { - return 0, "", fmt.Errorf("error parsing port: %v", err) + return 0, "", fmt.Errorf("error parsing port: %w", err) } return portInt, sshServerStartResult.User, nil diff --git a/internal/codespaces/states.go b/internal/codespaces/states.go index 492ce3964..2d7da9d75 100644 --- a/internal/codespaces/states.go +++ b/internal/codespaces/states.go @@ -39,12 +39,12 @@ type PostCreateState struct { func PollPostCreateStates(ctx context.Context, log logger, apiClient *api.API, user *api.User, codespace *api.Codespace, poller func([]PostCreateState)) error { token, err := apiClient.GetCodespaceToken(ctx, user.Login, codespace.Name) if err != nil { - return fmt.Errorf("getting codespace token: %v", err) + return fmt.Errorf("getting codespace token: %w", err) } session, err := ConnectToLiveshare(ctx, log, apiClient, user.Login, token, codespace) if err != nil { - return fmt.Errorf("connect to Live Share: %v", err) + return fmt.Errorf("connect to Live Share: %w", err) } // Ensure local port is listening before client (getPostCreateOutput) connects. @@ -56,7 +56,7 @@ func PollPostCreateStates(ctx context.Context, log logger, apiClient *api.API, u remoteSSHServerPort, sshUser, err := StartSSHServer(ctx, session, log) if err != nil { - return fmt.Errorf("error getting ssh server details: %v", err) + return fmt.Errorf("error getting ssh server details: %w", err) } tunnelClosed := make(chan error, 1) // buffered to avoid sender stuckness @@ -74,12 +74,12 @@ func PollPostCreateStates(ctx context.Context, log logger, apiClient *api.API, u return ctx.Err() case err := <-tunnelClosed: - return fmt.Errorf("connection failed: %v", err) + return fmt.Errorf("connection failed: %w", err) case <-t.C: states, err := getPostCreateOutput(ctx, localPort, codespace, sshUser) if err != nil { - return fmt.Errorf("get post create output: %v", err) + return fmt.Errorf("get post create output: %w", err) } poller(states) @@ -95,13 +95,13 @@ func getPostCreateOutput(ctx context.Context, tunnelPort int, codespace *api.Cod stdout := new(bytes.Buffer) cmd.Stdout = stdout if err := cmd.Run(); err != nil { - return nil, fmt.Errorf("run command: %v", err) + return nil, fmt.Errorf("run command: %w", err) } var output struct { Steps []PostCreateState `json:"steps"` } if err := json.Unmarshal(stdout.Bytes(), &output); err != nil { - return nil, fmt.Errorf("unmarshal output: %v", err) + return nil, fmt.Errorf("unmarshal output: %w", err) } return output.Steps, nil From 8abff2af97688a47744066aac4cd632f22259b53 Mon Sep 17 00:00:00 2001 From: Jose Garcia Date: Wed, 15 Sep 2021 13:14:58 +0000 Subject: [PATCH 198/290] move StartSSHServer to Session --- port_forwarder.go | 2 +- session.go | 28 ++++++++++++++++++++++++++++ ssh_server.go | 37 ------------------------------------- 3 files changed, 29 insertions(+), 38 deletions(-) delete mode 100644 ssh_server.go diff --git a/port_forwarder.go b/port_forwarder.go index 400d6ac97..5dafd0c65 100644 --- a/port_forwarder.go +++ b/port_forwarder.go @@ -9,7 +9,7 @@ import ( "github.com/opentracing/opentracing-go" ) -// A PortForwarder forwards TCP traffic over a LiveShare session from a port on a remote +// A PortForwarder forwards TCP traffic over a Live Share session from a port on a remote // container to a local destination such as a network port or Go reader/writer. type PortForwarder struct { session *Session diff --git a/session.go b/session.go index 0e3120cd7..1d13ad58c 100644 --- a/session.go +++ b/session.go @@ -3,6 +3,7 @@ package liveshare import ( "context" "fmt" + "strconv" ) // A Session represents the session between a connected Live Share client and server. @@ -59,3 +60,30 @@ func (s *Session) UpdateSharedVisibility(ctx context.Context, port int, public b return nil } + +// StartSSHServer starts the SSHD server and returns the user and port for which to authenticate with. +// If there is no SSHD server installed on the server, it will attempt to install it. The installation +// process can take upwards of 20+ seconds. +func (s *Session) StartSSHServer(ctx context.Context) (string, int64, error) { + var response struct { + Result bool `json:"result"` + ServerPort string `json:"serverPort"` + User string `json:"user"` + Message string `json:"message"` + } + + if err := s.rpc.do(ctx, "ISshServerHostService.startRemoteServer", []string{}, &response); err != nil { + return "", 0, err + } + + if !response.Result { + return "", 0, fmt.Errorf("failed to start server: %s", response.Message) + } + + port, err := strconv.ParseInt(response.ServerPort, 10, 64) + if err != nil { + return "", 0, fmt.Errorf("failed to parse port: %w", err) + } + + return response.User, port, nil +} diff --git a/ssh_server.go b/ssh_server.go deleted file mode 100644 index ca66ec7de..000000000 --- a/ssh_server.go +++ /dev/null @@ -1,37 +0,0 @@ -package liveshare - -import ( - "context" -) - -// A SSHServer handles starting the remote SSH server. -// If there is no SSH server available it installs one. -type SSHServer struct { - session *Session -} - -// SSHServer returns a new SSHServer from the LiveShare Session. -func (session *Session) SSHServer() *SSHServer { - return &SSHServer{session: session} -} - -// SSHServerStartResult contains whether or not the start of the SSH server was -// successful. If it succeeded the server port and user is included. If it failed, -// it contains an explanation message. -type SSHServerStartResult struct { - Result bool `json:"result"` - ServerPort string `json:"serverPort"` - User string `json:"user"` - Message string `json:"message"` -} - -// StartRemoteServer starts or install the remote SSH server and returns the result. -func (s *SSHServer) StartRemoteServer(ctx context.Context) (*SSHServerStartResult, error) { - var response SSHServerStartResult - - if err := s.session.rpc.do(ctx, "ISshServerHostService.startRemoteServer", []string{}, &response); err != nil { - return nil, err - } - - return &response, nil -} From 20e618fd025e115726853db4b4fc37ec76295ebd Mon Sep 17 00:00:00 2001 From: Jose Garcia Date: Wed, 15 Sep 2021 13:49:03 +0000 Subject: [PATCH 199/290] pr feedback --- session.go | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/session.go b/session.go index 1d13ad58c..f427fac6d 100644 --- a/session.go +++ b/session.go @@ -61,10 +61,9 @@ func (s *Session) UpdateSharedVisibility(ctx context.Context, port int, public b return nil } -// StartSSHServer starts the SSHD server and returns the user and port for which to authenticate with. -// If there is no SSHD server installed on the server, it will attempt to install it. The installation -// process can take upwards of 20+ seconds. -func (s *Session) StartSSHServer(ctx context.Context) (string, int64, error) { +// StartsSSHServer starts an SSH server in the container, installing sshd if necessary, +// and returns the port on which it listens and the user name clients should provide. +func (s *Session) StartSSHServer(ctx context.Context) (int, string, error) { var response struct { Result bool `json:"result"` ServerPort string `json:"serverPort"` @@ -73,17 +72,17 @@ func (s *Session) StartSSHServer(ctx context.Context) (string, int64, error) { } if err := s.rpc.do(ctx, "ISshServerHostService.startRemoteServer", []string{}, &response); err != nil { - return "", 0, err + return 0, "", err } if !response.Result { - return "", 0, fmt.Errorf("failed to start server: %s", response.Message) + return 0, "", fmt.Errorf("failed to start server: %s", response.Message) } - port, err := strconv.ParseInt(response.ServerPort, 10, 64) + port, err := strconv.Atoi(response.ServerPort) if err != nil { - return "", 0, fmt.Errorf("failed to parse port: %w", err) + return 0, "", fmt.Errorf("failed to parse port: %w", err) } - return response.User, port, nil + return port, response.User, nil } From 547c62922050ad98fa8a20f42c2532b05c932909 Mon Sep 17 00:00:00 2001 From: Jose Garcia Date: Wed, 15 Sep 2021 10:38:19 -0400 Subject: [PATCH 200/290] fix ctx cancellation errors & fix todo for X11 forwarding --- cmd/ghcs/create.go | 7 ++++++- internal/codespaces/ssh.go | 9 +++++++-- internal/codespaces/states.go | 14 +++++++------- 3 files changed, 20 insertions(+), 10 deletions(-) diff --git a/cmd/ghcs/create.go b/cmd/ghcs/create.go index 093450e7d..3669d7434 100644 --- a/cmd/ghcs/create.go +++ b/cmd/ghcs/create.go @@ -153,7 +153,12 @@ func showStatus(ctx context.Context, log *output.Logger, apiClient *api.API, use } } - if err := codespaces.PollPostCreateStates(ctx, log, apiClient, user, codespace, poller); err != nil { + err := codespaces.PollPostCreateStates(ctx, log, apiClient, user, codespace, poller) + if err != nil { + if errors.Is(err, context.Canceled) && breakNextState { + return nil // we cancelled the context to stop polling, we can ignore the error + } + return fmt.Errorf("failed to poll state changes from codespace: %v", err) } diff --git a/internal/codespaces/ssh.go b/internal/codespaces/ssh.go index 14dbfbb88..7dcab3de4 100644 --- a/internal/codespaces/ssh.go +++ b/internal/codespaces/ssh.go @@ -60,9 +60,14 @@ func NewRemoteCommand(ctx context.Context, tunnelPort int, destination, command // an interactive shell) over ssh. func newSSHCommand(ctx context.Context, port int, dst, command string) (*exec.Cmd, []string) { connArgs := []string{"-p", strconv.Itoa(port), "-o", "NoHostAuthenticationForLocalhost=yes"} - // TODO(adonovan): eliminate X11 and X11Trust flags where unneeded. - cmdArgs := append([]string{dst, "-X", "-Y", "-C"}, connArgs...) // X11, X11Trust, Compression + cmdArgs := []string{dst, "-C"} // Always use Compression + if command == "" { + // if we are in a shell send X11 and X11Trust + cmdArgs = append(cmdArgs, "-X", "-Y") + } + + cmdArgs = append(cmdArgs, connArgs...) if command != "" { cmdArgs = append(cmdArgs, command) } diff --git a/internal/codespaces/states.go b/internal/codespaces/states.go index 492ce3964..2d7da9d75 100644 --- a/internal/codespaces/states.go +++ b/internal/codespaces/states.go @@ -39,12 +39,12 @@ type PostCreateState struct { func PollPostCreateStates(ctx context.Context, log logger, apiClient *api.API, user *api.User, codespace *api.Codespace, poller func([]PostCreateState)) error { token, err := apiClient.GetCodespaceToken(ctx, user.Login, codespace.Name) if err != nil { - return fmt.Errorf("getting codespace token: %v", err) + return fmt.Errorf("getting codespace token: %w", err) } session, err := ConnectToLiveshare(ctx, log, apiClient, user.Login, token, codespace) if err != nil { - return fmt.Errorf("connect to Live Share: %v", err) + return fmt.Errorf("connect to Live Share: %w", err) } // Ensure local port is listening before client (getPostCreateOutput) connects. @@ -56,7 +56,7 @@ func PollPostCreateStates(ctx context.Context, log logger, apiClient *api.API, u remoteSSHServerPort, sshUser, err := StartSSHServer(ctx, session, log) if err != nil { - return fmt.Errorf("error getting ssh server details: %v", err) + return fmt.Errorf("error getting ssh server details: %w", err) } tunnelClosed := make(chan error, 1) // buffered to avoid sender stuckness @@ -74,12 +74,12 @@ func PollPostCreateStates(ctx context.Context, log logger, apiClient *api.API, u return ctx.Err() case err := <-tunnelClosed: - return fmt.Errorf("connection failed: %v", err) + return fmt.Errorf("connection failed: %w", err) case <-t.C: states, err := getPostCreateOutput(ctx, localPort, codespace, sshUser) if err != nil { - return fmt.Errorf("get post create output: %v", err) + return fmt.Errorf("get post create output: %w", err) } poller(states) @@ -95,13 +95,13 @@ func getPostCreateOutput(ctx context.Context, tunnelPort int, codespace *api.Cod stdout := new(bytes.Buffer) cmd.Stdout = stdout if err := cmd.Run(); err != nil { - return nil, fmt.Errorf("run command: %v", err) + return nil, fmt.Errorf("run command: %w", err) } var output struct { Steps []PostCreateState `json:"steps"` } if err := json.Unmarshal(stdout.Bytes(), &output); err != nil { - return nil, fmt.Errorf("unmarshal output: %v", err) + return nil, fmt.Errorf("unmarshal output: %w", err) } return output.Steps, nil From 06719866c95e834d87fd2ecd87b5889b14d6df2b Mon Sep 17 00:00:00 2001 From: Alan Donovan Date: Wed, 15 Sep 2021 13:09:31 -0400 Subject: [PATCH 201/290] move api to internal/ --- cmd/ghcs/code.go | 2 +- cmd/ghcs/common.go | 2 +- cmd/ghcs/create.go | 2 +- cmd/ghcs/delete.go | 2 +- cmd/ghcs/list.go | 2 +- cmd/ghcs/logs.go | 2 +- cmd/ghcs/ports.go | 2 +- cmd/ghcs/ssh.go | 2 +- {api => internal/api}/api.go | 19 ++++++++++++++++++- internal/codespaces/codespaces.go | 2 +- internal/codespaces/states.go | 2 +- 11 files changed, 28 insertions(+), 11 deletions(-) rename {api => internal/api}/api.go (94%) diff --git a/cmd/ghcs/code.go b/cmd/ghcs/code.go index 3bd67053d..0ec363a6d 100644 --- a/cmd/ghcs/code.go +++ b/cmd/ghcs/code.go @@ -6,7 +6,7 @@ import ( "net/url" "os" - "github.com/github/ghcs/api" + "github.com/github/ghcs/internal/api" "github.com/skratchdot/open-golang/open" "github.com/spf13/cobra" ) diff --git a/cmd/ghcs/common.go b/cmd/ghcs/common.go index bfcb67496..133e6d1de 100644 --- a/cmd/ghcs/common.go +++ b/cmd/ghcs/common.go @@ -11,7 +11,7 @@ import ( "github.com/AlecAivazis/survey/v2" "github.com/AlecAivazis/survey/v2/terminal" - "github.com/github/ghcs/api" + "github.com/github/ghcs/internal/api" "golang.org/x/term" ) diff --git a/cmd/ghcs/create.go b/cmd/ghcs/create.go index 093450e7d..82414f80e 100644 --- a/cmd/ghcs/create.go +++ b/cmd/ghcs/create.go @@ -9,8 +9,8 @@ import ( "github.com/AlecAivazis/survey/v2" "github.com/fatih/camelcase" - "github.com/github/ghcs/api" "github.com/github/ghcs/cmd/ghcs/output" + "github.com/github/ghcs/internal/api" "github.com/github/ghcs/internal/codespaces" "github.com/spf13/cobra" ) diff --git a/cmd/ghcs/delete.go b/cmd/ghcs/delete.go index 7800a13c0..3f5d68684 100644 --- a/cmd/ghcs/delete.go +++ b/cmd/ghcs/delete.go @@ -7,8 +7,8 @@ import ( "os" "strings" - "github.com/github/ghcs/api" "github.com/github/ghcs/cmd/ghcs/output" + "github.com/github/ghcs/internal/api" "github.com/spf13/cobra" ) diff --git a/cmd/ghcs/list.go b/cmd/ghcs/list.go index ee26e3013..dee1b0875 100644 --- a/cmd/ghcs/list.go +++ b/cmd/ghcs/list.go @@ -5,8 +5,8 @@ import ( "fmt" "os" - "github.com/github/ghcs/api" "github.com/github/ghcs/cmd/ghcs/output" + "github.com/github/ghcs/internal/api" "github.com/spf13/cobra" ) diff --git a/cmd/ghcs/logs.go b/cmd/ghcs/logs.go index a685a364d..8d7dc475e 100644 --- a/cmd/ghcs/logs.go +++ b/cmd/ghcs/logs.go @@ -6,8 +6,8 @@ import ( "net" "os" - "github.com/github/ghcs/api" "github.com/github/ghcs/cmd/ghcs/output" + "github.com/github/ghcs/internal/api" "github.com/github/ghcs/internal/codespaces" "github.com/github/go-liveshare" "github.com/spf13/cobra" diff --git a/cmd/ghcs/ports.go b/cmd/ghcs/ports.go index f6d1f6e2a..8256c13e2 100644 --- a/cmd/ghcs/ports.go +++ b/cmd/ghcs/ports.go @@ -11,8 +11,8 @@ import ( "strconv" "strings" - "github.com/github/ghcs/api" "github.com/github/ghcs/cmd/ghcs/output" + "github.com/github/ghcs/internal/api" "github.com/github/ghcs/internal/codespaces" "github.com/github/go-liveshare" "github.com/muhammadmuzzammil1998/jsonc" diff --git a/cmd/ghcs/ssh.go b/cmd/ghcs/ssh.go index e5289e205..b020a799e 100644 --- a/cmd/ghcs/ssh.go +++ b/cmd/ghcs/ssh.go @@ -6,8 +6,8 @@ import ( "net" "os" - "github.com/github/ghcs/api" "github.com/github/ghcs/cmd/ghcs/output" + "github.com/github/ghcs/internal/api" "github.com/github/ghcs/internal/codespaces" "github.com/github/go-liveshare" "github.com/spf13/cobra" diff --git a/api/api.go b/internal/api/api.go similarity index 94% rename from api/api.go rename to internal/api/api.go index ad69f23fc..8277e903a 100644 --- a/api/api.go +++ b/internal/api/api.go @@ -1,4 +1,3 @@ -// TODO(adonovan): rename to package codespaces, and codespaces.Client. package api // For descriptions of service interfaces, see: @@ -7,6 +6,24 @@ package api // - https://github.com/github/github/blob/master/app/api/codespaces.rb (for vscs_internal) // TODO(adonovan): replace the last link with a public doc URL when available. +// TODO(adonovan): a possible reorganization would be to split this +// file into three internal packages, one per backend service, and to +// rename api.API to github.Client: +// +// - github.GetUser(github.Client) +// - github.GetRepository(Client) +// - github.ReadFile(Client, nwo, branch, path) // was GetCodespaceRepositoryContents +// - codespaces.Create(Client, user, repo, sku, branch, location) +// - codespaces.Delete(Client, user, token, name) +// - codespaces.Get(Client, token, owner, name) +// - codespaces.GetMachineTypes(Client, user, repo, branch, location) +// - codespaces.GetToken(Client, login, name) +// - codespaces.List(Client, user) +// - codespaces.Start(Client, token, codespace) +// - visualstudio.GetRegionLocation(http.Client) // no dependency on github +// +// This would make the meaning of each operation clearer. + import ( "bytes" "context" diff --git a/internal/codespaces/codespaces.go b/internal/codespaces/codespaces.go index 9aee3564c..804c2dab5 100644 --- a/internal/codespaces/codespaces.go +++ b/internal/codespaces/codespaces.go @@ -6,7 +6,7 @@ import ( "fmt" "time" - "github.com/github/ghcs/api" + "github.com/github/ghcs/internal/api" "github.com/github/go-liveshare" ) diff --git a/internal/codespaces/states.go b/internal/codespaces/states.go index 492ce3964..c7a7767ae 100644 --- a/internal/codespaces/states.go +++ b/internal/codespaces/states.go @@ -9,7 +9,7 @@ import ( "strings" "time" - "github.com/github/ghcs/api" + "github.com/github/ghcs/internal/api" "github.com/github/go-liveshare" ) From 0f72e3d88642e36a1f3cbb2ef95866b2269541d2 Mon Sep 17 00:00:00 2001 From: Jose Garcia Date: Wed, 15 Sep 2021 14:29:16 -0400 Subject: [PATCH 202/290] defer stopPolling and docs --- cmd/ghcs/create.go | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/cmd/ghcs/create.go b/cmd/ghcs/create.go index 3669d7434..aa4c14658 100644 --- a/cmd/ghcs/create.go +++ b/cmd/ghcs/create.go @@ -105,12 +105,16 @@ func create(opts *createOptions) error { return nil } +// showStatus polls the codespace for a list of post create states and their status. It will keep polling +// until all states have finished. Once all states have finished, we poll once more to check if any new +// states have been introduced and stop polling otherwise. func showStatus(ctx context.Context, log *output.Logger, apiClient *api.API, user *api.User, codespace *api.Codespace) error { var lastState codespaces.PostCreateState var breakNextState bool finishedStates := make(map[string]bool) ctx, stopPolling := context.WithCancel(ctx) + defer stopPolling() poller := func(states []codespaces.PostCreateState) { var inProgress bool From ecd0c7056798bcd1b0d1fb2d65d3da7b1477a7be Mon Sep 17 00:00:00 2001 From: Jose Garcia Date: Wed, 15 Sep 2021 15:15:28 -0400 Subject: [PATCH 203/290] upgrade to go-liveshare 0.16.0 --- cmd/ghcs/ssh.go | 3 ++- internal/codespaces/ssh.go | 42 ----------------------------------- internal/codespaces/states.go | 3 ++- 3 files changed, 4 insertions(+), 44 deletions(-) diff --git a/cmd/ghcs/ssh.go b/cmd/ghcs/ssh.go index e5289e205..08c3bd7ca 100644 --- a/cmd/ghcs/ssh.go +++ b/cmd/ghcs/ssh.go @@ -59,7 +59,8 @@ func ssh(ctx context.Context, sshProfile, codespaceName string, localSSHServerPo return fmt.Errorf("error connecting to Live Share: %v", err) } - remoteSSHServerPort, sshUser, err := codespaces.StartSSHServer(ctx, session, log) + log.Println("Fetching SSH Details...") + remoteSSHServerPort, sshUser, err := session.StartSSHServer(ctx) if err != nil { return fmt.Errorf("error getting ssh server details: %v", err) } diff --git a/internal/codespaces/ssh.go b/internal/codespaces/ssh.go index 14dbfbb88..28c2761b1 100644 --- a/internal/codespaces/ssh.go +++ b/internal/codespaces/ssh.go @@ -2,53 +2,11 @@ package codespaces import ( "context" - "errors" - "fmt" "os" "os/exec" "strconv" - "strings" - - "github.com/github/go-liveshare" ) -// StartSSHServer installs (if necessary) and starts the SSH in the codespace. -// It returns the remote port where it is running, the user to log in with, or an error if something failed. -func StartSSHServer(ctx context.Context, session *liveshare.Session, log logger) (serverPort int, user string, err error) { - log.Println("Fetching SSH details...") - - sshServer := session.SSHServer() - - sshServerStartResult, err := sshServer.StartRemoteServer(ctx) - if err != nil { - return 0, "", fmt.Errorf("error starting live share: %v", err) - } - - if !sshServerStartResult.Result { - return 0, "", errors.New(sshServerStartResult.Message) - } - - portInt, err := strconv.Atoi(sshServerStartResult.ServerPort) - if err != nil { - return 0, "", fmt.Errorf("error parsing port: %v", err) - } - - return portInt, sshServerStartResult.User, nil -} - -// Shell runs an interactive secure shell over an existing -// port-forwarding session. It runs until the shell is terminated -// (including by cancellation of the context). -func Shell(ctx context.Context, log logger, port int, destination string, usingCustomPort bool) error { - cmd, connArgs := newSSHCommand(ctx, port, destination, "") - - if usingCustomPort { - log.Println("Connection Details: ssh " + destination + " " + strings.Join(connArgs, " ")) - } - - return cmd.Run() -} - // NewRemoteCommand returns an exec.Cmd that will securely run a shell // command on the remote machine. func NewRemoteCommand(ctx context.Context, tunnelPort int, destination, command string) *exec.Cmd { diff --git a/internal/codespaces/states.go b/internal/codespaces/states.go index 492ce3964..f4375fa5d 100644 --- a/internal/codespaces/states.go +++ b/internal/codespaces/states.go @@ -54,7 +54,8 @@ func PollPostCreateStates(ctx context.Context, log logger, apiClient *api.API, u } localPort := listen.Addr().(*net.TCPAddr).Port - remoteSSHServerPort, sshUser, err := StartSSHServer(ctx, session, log) + log.Println("Fetching SSH Details...") + remoteSSHServerPort, sshUser, err := session.StartSSHServer(ctx) if err != nil { return fmt.Errorf("error getting ssh server details: %v", err) } From 26d3199082dae4dcf8bf75d234e94ff781b872c4 Mon Sep 17 00:00:00 2001 From: Jose Garcia Date: Wed, 15 Sep 2021 15:18:54 -0400 Subject: [PATCH 204/290] add back codespaces.Shell --- internal/codespaces/ssh.go | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/internal/codespaces/ssh.go b/internal/codespaces/ssh.go index 28c2761b1..d39fc17a0 100644 --- a/internal/codespaces/ssh.go +++ b/internal/codespaces/ssh.go @@ -5,8 +5,22 @@ import ( "os" "os/exec" "strconv" + "strings" ) +// Shell runs an interactive secure shell over an existing +// port-forwarding session. It runs until the shell is terminated +// (including by cancellation of the context). +func Shell(ctx context.Context, log logger, port int, destination string, usingCustomPort bool) error { + cmd, connArgs := newSSHCommand(ctx, port, destination, "") + + if usingCustomPort { + log.Println("Connection Details: ssh " + destination + " " + strings.Join(connArgs, " ")) + } + + return cmd.Run() +} + // NewRemoteCommand returns an exec.Cmd that will securely run a shell // command on the remote machine. func NewRemoteCommand(ctx context.Context, tunnelPort int, destination, command string) *exec.Cmd { From c5bd8c41279a91711be7ce5a33cfcba4c98208bd Mon Sep 17 00:00:00 2001 From: Jose Garcia Date: Wed, 15 Sep 2021 15:37:37 -0400 Subject: [PATCH 205/290] initial spike to accept args --- cmd/ghcs/ssh.go | 6 +++--- internal/codespaces/ssh.go | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/cmd/ghcs/ssh.go b/cmd/ghcs/ssh.go index e5289e205..3bc2110a0 100644 --- a/cmd/ghcs/ssh.go +++ b/cmd/ghcs/ssh.go @@ -21,7 +21,7 @@ func newSSHCmd() *cobra.Command { Use: "ssh", Short: "SSH into a codespace", RunE: func(cmd *cobra.Command, args []string) error { - return ssh(context.Background(), sshProfile, codespaceName, sshServerPort) + return ssh(context.Background(), args, sshProfile, codespaceName, sshServerPort) }, } @@ -36,7 +36,7 @@ func init() { rootCmd.AddCommand(newSSHCmd()) } -func ssh(ctx context.Context, sshProfile, codespaceName string, localSSHServerPort int) error { +func ssh(ctx context.Context, sshArgs []string, sshProfile, codespaceName string, localSSHServerPort int) error { // Ensure all child tasks (e.g. port forwarding) terminate before return. ctx, cancel := context.WithCancel(ctx) defer cancel() @@ -88,7 +88,7 @@ func ssh(ctx context.Context, sshProfile, codespaceName string, localSSHServerPo shellClosed := make(chan error, 1) go func() { - shellClosed <- codespaces.Shell(ctx, log, localSSHServerPort, connectDestination, usingCustomPort) + shellClosed <- codespaces.Shell(ctx, log, sshArgs, localSSHServerPort, connectDestination, usingCustomPort) }() select { diff --git a/internal/codespaces/ssh.go b/internal/codespaces/ssh.go index 14dbfbb88..717371286 100644 --- a/internal/codespaces/ssh.go +++ b/internal/codespaces/ssh.go @@ -39,7 +39,7 @@ func StartSSHServer(ctx context.Context, session *liveshare.Session, log logger) // Shell runs an interactive secure shell over an existing // port-forwarding session. It runs until the shell is terminated // (including by cancellation of the context). -func Shell(ctx context.Context, log logger, port int, destination string, usingCustomPort bool) error { +func Shell(ctx context.Context, log logger, sshArgs []string, port int, destination string, usingCustomPort bool) error { cmd, connArgs := newSSHCommand(ctx, port, destination, "") if usingCustomPort { From b2234969e4f728e049a9dc0f5eeff856baf91af7 Mon Sep 17 00:00:00 2001 From: Jose Garcia Date: Wed, 15 Sep 2021 15:40:07 -0400 Subject: [PATCH 206/290] update logs --- cmd/ghcs/logs.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/cmd/ghcs/logs.go b/cmd/ghcs/logs.go index a685a364d..be283818f 100644 --- a/cmd/ghcs/logs.go +++ b/cmd/ghcs/logs.go @@ -70,7 +70,8 @@ func logs(ctx context.Context, log *output.Logger, codespaceName string, follow defer listen.Close() localPort := listen.Addr().(*net.TCPAddr).Port - remoteSSHServerPort, sshUser, err := codespaces.StartSSHServer(ctx, session, log) + log.Println("Fetching SSH Details...") + remoteSSHServerPort, sshUser, err := session.StartSSHServer(ctx) if err != nil { return fmt.Errorf("error getting ssh server details: %v", err) } From cc1b86461e457a5ab0eac363f360791389f1fe83 Mon Sep 17 00:00:00 2001 From: Christian Gregg Date: Thu, 16 Sep 2021 13:47:15 +0100 Subject: [PATCH 207/290] Confirm deletion of codespaces with unpushed/uncommited changes (#129) Adds a confirmation dialog on `ghcs delete` if the codespace in question has unpushed or uncommited changes. This confirmation can be skipped using the `--force` or `-f` flag. Closes: #84 Closes: #10 --- cmd/ghcs/delete.go | 72 +++++++++++++++++++++++++++++++++++++++++----- cmd/ghcs/list.go | 13 ++++----- 2 files changed, 70 insertions(+), 15 deletions(-) diff --git a/cmd/ghcs/delete.go b/cmd/ghcs/delete.go index b2547e350..961310675 100644 --- a/cmd/ghcs/delete.go +++ b/cmd/ghcs/delete.go @@ -7,6 +7,7 @@ import ( "os" "strings" + "github.com/AlecAivazis/survey/v2" "github.com/github/ghcs/cmd/ghcs/output" "github.com/github/ghcs/internal/api" "github.com/spf13/cobra" @@ -17,10 +18,10 @@ func newDeleteCmd() *cobra.Command { codespace string allCodespaces bool repo string + force bool ) log := output.NewLogger(os.Stdout, os.Stderr, false) - deleteCmd := &cobra.Command{ Use: "delete", Short: "Delete a codespace", @@ -29,11 +30,11 @@ func newDeleteCmd() *cobra.Command { case allCodespaces && repo != "": return errors.New("both --all and --repo is not supported.") case allCodespaces: - return deleteAll(log) + return deleteAll(log, force) case repo != "": - return deleteByRepo(log, repo) + return deleteByRepo(log, repo, force) default: - return delete_(log, codespace) + return delete_(log, codespace, force) } }, } @@ -41,6 +42,7 @@ func newDeleteCmd() *cobra.Command { deleteCmd.Flags().StringVarP(&codespace, "codespace", "c", "", "Name of the codespace") deleteCmd.Flags().BoolVar(&allCodespaces, "all", false, "Delete all codespaces") deleteCmd.Flags().StringVarP(&repo, "repo", "r", "", "Delete all codespaces for a repository") + deleteCmd.Flags().BoolVarP(&force, "force", "f", false, "Delete codespaces with unsaved changes without confirmation") return deleteCmd } @@ -49,7 +51,7 @@ func init() { rootCmd.AddCommand(newDeleteCmd()) } -func delete_(log *output.Logger, codespaceName string) error { +func delete_(log *output.Logger, codespaceName string, force bool) error { apiClient := api.New(os.Getenv("GITHUB_TOKEN")) ctx := context.Background() @@ -63,6 +65,15 @@ func delete_(log *output.Logger, codespaceName string) error { return fmt.Errorf("get or choose codespace: %w", err) } + confirmed, err := confirmDeletion(codespace, force) + if err != nil { + return fmt.Errorf("deletion could not be confirmed: %w", err) + } + + if !confirmed { + return nil + } + if err := apiClient.DeleteCodespace(ctx, user, token, codespace.Name); err != nil { return fmt.Errorf("error deleting codespace: %w", err) } @@ -72,7 +83,7 @@ func delete_(log *output.Logger, codespaceName string) error { return list(&listOptions{}) } -func deleteAll(log *output.Logger) error { +func deleteAll(log *output.Logger, force bool) error { apiClient := api.New(os.Getenv("GITHUB_TOKEN")) ctx := context.Background() @@ -87,6 +98,15 @@ func deleteAll(log *output.Logger) error { } for _, c := range codespaces { + confirmed, err := confirmDeletion(c, force) + if err != nil { + return fmt.Errorf("deletion could not be confirmed: %w", err) + } + + if !confirmed { + continue + } + token, err := apiClient.GetCodespaceToken(ctx, user.Login, c.Name) if err != nil { return fmt.Errorf("error getting codespace token: %w", err) @@ -102,7 +122,7 @@ func deleteAll(log *output.Logger) error { return list(&listOptions{}) } -func deleteByRepo(log *output.Logger, repo string) error { +func deleteByRepo(log *output.Logger, repo string, force bool) error { apiClient := api.New(os.Getenv("GITHUB_TOKEN")) ctx := context.Background() @@ -121,6 +141,16 @@ func deleteByRepo(log *output.Logger, repo string) error { if !strings.EqualFold(c.RepositoryNWO, repo) { continue } + + confirmed, err := confirmDeletion(c, force) + if err != nil { + return fmt.Errorf("deletion could not be confirmed: %w", err) + } + + if !confirmed { + continue + } + deleted = true token, err := apiClient.GetCodespaceToken(ctx, user.Login, c.Name) @@ -141,3 +171,31 @@ func deleteByRepo(log *output.Logger, repo string) error { return list(&listOptions{}) } + +func confirmDeletion(codespace *api.Codespace, force bool) (bool, error) { + gs := codespace.Environment.GitStatus + hasUnsavedChanges := gs.HasUncommitedChanges || gs.HasUnpushedChanges + if force || !hasUnsavedChanges { + return true, nil + } + if !hasTTY { + return false, fmt.Errorf("codespace %s has unsaved changes (use --force to override)", codespace.Name) + } + + var confirmed struct { + Confirmed bool + } + q := []*survey.Question{ + { + Name: "confirmed", + Prompt: &survey.Confirm{ + Message: fmt.Sprintf("Codespace %s has unsaved changes. OK to delete?", codespace.Name), + }, + }, + } + if err := ask(q, &confirmed); err != nil { + return false, fmt.Errorf("failed to prompt: %w", err) + } + + return confirmed.Confirmed, nil +} diff --git a/cmd/ghcs/list.go b/cmd/ghcs/list.go index 2db609d50..72a25becc 100644 --- a/cmd/ghcs/list.go +++ b/cmd/ghcs/list.go @@ -55,7 +55,7 @@ func list(opts *listOptions) error { table.Append([]string{ codespace.Name, codespace.RepositoryNWO, - branch(codespace), + codespace.Name + dirtyStar(codespace.Environment.GitStatus), codespace.Environment.State, codespace.CreatedAt, }) @@ -65,13 +65,10 @@ func list(opts *listOptions) error { return nil } -func branch(codespace *api.Codespace) string { - name := codespace.Branch - gitStatus := codespace.Environment.GitStatus - - if gitStatus.HasUncommitedChanges || gitStatus.HasUnpushedChanges { - name += "*" +func dirtyStar(status api.CodespaceEnvironmentGitStatus) string { + if status.HasUncommitedChanges || status.HasUnpushedChanges { + return "*" } - return name + return "" } From 8a0f8b6d1c1834186ca4fcb6da650d34df89b1eb Mon Sep 17 00:00:00 2001 From: Jose Garcia Date: Thu, 16 Sep 2021 10:32:27 -0400 Subject: [PATCH 208/290] parse ssh args and command --- internal/codespaces/ssh.go | 63 ++++++++++++++++++++++++++------ internal/codespaces/ssh_test.go | 64 +++++++++++++++++++++++++++++++++ 2 files changed, 117 insertions(+), 10 deletions(-) create mode 100644 internal/codespaces/ssh_test.go diff --git a/internal/codespaces/ssh.go b/internal/codespaces/ssh.go index 661caecdc..0eee21286 100644 --- a/internal/codespaces/ssh.go +++ b/internal/codespaces/ssh.go @@ -12,7 +12,7 @@ import ( // port-forwarding session. It runs until the shell is terminated // (including by cancellation of the context). func Shell(ctx context.Context, log logger, sshArgs []string, port int, destination string, usingCustomPort bool) error { - cmd, connArgs := newSSHCommand(ctx, port, destination, "") + cmd, connArgs := newSSHCommand(ctx, port, destination, sshArgs) if usingCustomPort { log.Println("Connection Details: ssh " + destination + " " + strings.Join(connArgs, " ")) @@ -23,23 +23,21 @@ func Shell(ctx context.Context, log logger, sshArgs []string, port int, destinat // NewRemoteCommand returns an exec.Cmd that will securely run a shell // command on the remote machine. -func NewRemoteCommand(ctx context.Context, tunnelPort int, destination, command string) *exec.Cmd { - cmd, _ := newSSHCommand(ctx, tunnelPort, destination, command) +func NewRemoteCommand(ctx context.Context, tunnelPort int, destination string, sshArgs ...string) *exec.Cmd { + cmd, _ := newSSHCommand(ctx, tunnelPort, destination, sshArgs) return cmd } // newSSHCommand populates an exec.Cmd to run a command (or if blank, // an interactive shell) over ssh. -func newSSHCommand(ctx context.Context, port int, dst, command string) (*exec.Cmd, []string) { +func newSSHCommand(ctx context.Context, port int, dst string, cmdArgs []string) (*exec.Cmd, []string) { connArgs := []string{"-p", strconv.Itoa(port), "-o", "NoHostAuthenticationForLocalhost=yes"} - cmdArgs := []string{dst, "-C"} // Always use Compression - if command == "" { - // if we are in a shell send X11 and X11Trust - cmdArgs = append(cmdArgs, "-X", "-Y") - } - + cmdArgs, command := parseSSHArgs(cmdArgs) cmdArgs = append(cmdArgs, connArgs...) + cmdArgs = append(cmdArgs, "-C") // Compression + cmdArgs = append(cmdArgs, dst) // user@host + if command != "" { cmdArgs = append(cmdArgs, command) } @@ -51,3 +49,48 @@ func newSSHCommand(ctx context.Context, port int, dst, command string) (*exec.Cm return cmd, connArgs } + +var sshArgumentFlags = map[string]bool{ + "-b": true, + "-c": true, + "-D": true, + "-e": true, + "-F": true, + "-I": true, + "-i": true, + "-L": true, + "-l": true, + "-m": true, + "-O": true, + "-o": true, + "-p": true, + "-R": true, + "-S": true, + "-W": true, + "-w": true, +} + +func parseSSHArgs(sshArgs []string) ([]string, string) { + var ( + cmdArgs []string + command []string + flagArgument bool + ) + + for _, arg := range sshArgs { + switch { + case strings.HasPrefix(arg, "-"): + cmdArgs = append(cmdArgs, arg) + if _, ok := sshArgumentFlags[arg]; ok { + flagArgument = true + } + case flagArgument: + cmdArgs = append(cmdArgs, arg) + flagArgument = false + default: + command = append(command, arg) + } + } + + return cmdArgs, strings.Join(command, " ") +} diff --git a/internal/codespaces/ssh_test.go b/internal/codespaces/ssh_test.go new file mode 100644 index 000000000..cd92b39a6 --- /dev/null +++ b/internal/codespaces/ssh_test.go @@ -0,0 +1,64 @@ +package codespaces + +import "testing" + +func TestParseSSHArgs(t *testing.T) { + type testCase struct { + Args []string + ParsedArgs []string + Command string + } + + testCases := []testCase{ + { + Args: []string{"-X", "-Y"}, + ParsedArgs: []string{"-X", "-Y"}, + Command: "", + }, + { + Args: []string{"-X", "-Y", "-o", "someoption=test"}, + ParsedArgs: []string{"-X", "-Y", "-o", "someoption=test"}, + Command: "", + }, + { + Args: []string{"-X", "-Y", "-o", "someoption=test", "somecommand"}, + ParsedArgs: []string{"-X", "-Y", "-o", "someoption=test"}, + Command: "somecommand", + }, + { + Args: []string{"-X", "-Y", "-o", "someoption=test", "echo", "test"}, + ParsedArgs: []string{"-X", "-Y", "-o", "someoption=test"}, + Command: "echo test", + }, + { + Args: []string{"somecommand"}, + ParsedArgs: []string{}, + Command: "somecommand", + }, + { + Args: []string{"echo", "test"}, + ParsedArgs: []string{}, + Command: "echo test", + }, + { + Args: []string{"-v", "echo", "hello", "world"}, + ParsedArgs: []string{"-v"}, + Command: "echo hello world", + }, + } + + for _, tcase := range testCases { + args, command := parseSSHArgs(tcase.Args) + if len(args) != len(tcase.ParsedArgs) { + t.Fatalf("args do not match length of expected args. %#v, got '%d', expected: '%d'", tcase, len(args), len(tcase.ParsedArgs)) + } + for i, arg := range args { + if arg != tcase.ParsedArgs[i] { + t.Fatalf("arg does not match expected parsed arg. %v, got '%s', expected: '%s'", tcase, arg, tcase.ParsedArgs[i]) + } + } + if command != tcase.Command { + t.Fatalf("command does not match expected command. %v, got: '%s', expected: '%s'", tcase, command, tcase.Command) + } + } +} From 68f4cad1af18a1a10f8bea2ef8a3e78a11df6567 Mon Sep 17 00:00:00 2001 From: Raffaele Di Fazio Date: Thu, 16 Sep 2021 16:42:53 +0200 Subject: [PATCH 209/290] implement delete all with thresold Signed-off-by: Raffaele Di Fazio --- cmd/ghcs/delete.go | 47 +++++++++++- internal/api/api.go | 49 +++++++++---- internal/api/api_test.go | 155 +++++++++++++++++++++++++++++++++++++++ 3 files changed, 234 insertions(+), 17 deletions(-) create mode 100644 internal/api/api_test.go diff --git a/cmd/ghcs/delete.go b/cmd/ghcs/delete.go index 961310675..a07559cfe 100644 --- a/cmd/ghcs/delete.go +++ b/cmd/ghcs/delete.go @@ -15,10 +15,11 @@ import ( func newDeleteCmd() *cobra.Command { var ( - codespace string - allCodespaces bool - repo string - force bool + codespace string + allCodespaces bool + repo string + force bool + keepThresholdDays int ) log := output.NewLogger(os.Stdout, os.Stderr, false) @@ -29,6 +30,8 @@ func newDeleteCmd() *cobra.Command { switch { case allCodespaces && repo != "": return errors.New("both --all and --repo is not supported.") + case allCodespaces && keepThresholdDays != 0: + return deleteWithThreshold(log, keepThresholdDays) case allCodespaces: return deleteAll(log, force) case repo != "": @@ -43,6 +46,7 @@ func newDeleteCmd() *cobra.Command { deleteCmd.Flags().BoolVar(&allCodespaces, "all", false, "Delete all codespaces") deleteCmd.Flags().StringVarP(&repo, "repo", "r", "", "Delete all codespaces for a repository") deleteCmd.Flags().BoolVarP(&force, "force", "f", false, "Delete codespaces with unsaved changes without confirmation") + deleteCmd.Flags().IntVar(&keepThresholdDays, "days", 0, "Value of threshold for codespaces to keep") return deleteCmd } @@ -199,3 +203,38 @@ func confirmDeletion(codespace *api.Codespace, force bool) (bool, error) { return confirmed.Confirmed, nil } + +func deleteWithThreshold(log *output.Logger, keepThresholdDays int) error { + apiClient := api.New(os.Getenv("GITHUB_TOKEN")) + ctx := context.Background() + + user, err := apiClient.GetUser(ctx) + if err != nil { + return fmt.Errorf("error getting user: %v", err) + } + + codespaces, err := apiClient.ListCodespaces(ctx, user) + if err != nil { + return fmt.Errorf("error getting codespaces: %v", err) + } + + codespacesToDelete, err := apiClient.FilterCodespacesToDelete(codespaces, keepThresholdDays) + if err != nil { + return err + } + + for _, c := range codespacesToDelete { + token, err := apiClient.GetCodespaceToken(ctx, user.Login, c.Name) + if err != nil { + return fmt.Errorf("error getting codespace token: %v", err) + } + + if err := apiClient.DeleteCodespace(ctx, user, token, c.Name); err != nil { + return fmt.Errorf("error deleting codespace: %v", err) + } + + log.Printf("Codespace deleted: %s\n", c.Name) + } + + return list(&listOptions{}) +} diff --git a/internal/api/api.go b/internal/api/api.go index 1246389e8..e0d506c02 100644 --- a/internal/api/api.go +++ b/internal/api/api.go @@ -35,19 +35,23 @@ import ( "net/http" "strconv" "strings" + "time" "github.com/opentracing/opentracing-go" ) const githubAPI = "https://api.github.com" +var now func() time.Time = time.Now + type API struct { - token string - client *http.Client + token string + client *http.Client + githubAPI string } func New(token string) *API { - return &API{token, &http.Client{}} + return &API{token, &http.Client{}, githubAPI} } type User struct { @@ -55,7 +59,7 @@ type User struct { } func (a *API) GetUser(ctx context.Context) (*User, error) { - req, err := http.NewRequest(http.MethodGet, githubAPI+"/user", nil) + req, err := http.NewRequest(http.MethodGet, a.githubAPI+"/user", nil) if err != nil { return nil, fmt.Errorf("error creating request: %w", err) } @@ -100,7 +104,7 @@ type Repository struct { } func (a *API) GetRepository(ctx context.Context, nwo string) (*Repository, error) { - req, err := http.NewRequest(http.MethodGet, githubAPI+"/repos/"+strings.ToLower(nwo), nil) + req, err := http.NewRequest(http.MethodGet, a.githubAPI+"/repos/"+strings.ToLower(nwo), nil) if err != nil { return nil, fmt.Errorf("error creating request: %w", err) } @@ -133,6 +137,7 @@ type Codespace struct { Name string `json:"name"` GUID string `json:"guid"` CreatedAt string `json:"created_at"` + LastUsedAt string `json:"last_used_at"` Branch string `json:"branch"` RepositoryName string `json:"repository_name"` RepositoryNWO string `json:"repository_nwo"` @@ -168,7 +173,7 @@ type CodespaceEnvironmentConnection struct { func (a *API) ListCodespaces(ctx context.Context, user *User) ([]*Codespace, error) { req, err := http.NewRequest( - http.MethodGet, githubAPI+"/vscs_internal/user/"+user.Login+"/codespaces", nil, + http.MethodGet, a.githubAPI+"/vscs_internal/user/"+user.Login+"/codespaces", nil, ) if err != nil { return nil, fmt.Errorf("error creating request: %w", err) @@ -215,7 +220,7 @@ func (a *API) GetCodespaceToken(ctx context.Context, ownerLogin, codespaceName s req, err := http.NewRequest( http.MethodPost, - githubAPI+"/vscs_internal/user/"+ownerLogin+"/codespaces/"+codespaceName+"/token", + a.githubAPI+"/vscs_internal/user/"+ownerLogin+"/codespaces/"+codespaceName+"/token", bytes.NewBuffer(reqBody), ) if err != nil { @@ -249,7 +254,7 @@ func (a *API) GetCodespaceToken(ctx context.Context, ownerLogin, codespaceName s func (a *API) GetCodespace(ctx context.Context, token, owner, codespace string) (*Codespace, error) { req, err := http.NewRequest( http.MethodGet, - githubAPI+"/vscs_internal/user/"+owner+"/codespaces/"+codespace, + a.githubAPI+"/vscs_internal/user/"+owner+"/codespaces/"+codespace, nil, ) if err != nil { @@ -283,7 +288,7 @@ func (a *API) GetCodespace(ctx context.Context, token, owner, codespace string) func (a *API) StartCodespace(ctx context.Context, token string, codespace *Codespace) error { req, err := http.NewRequest( http.MethodPost, - githubAPI+"/vscs_internal/proxy/environments/"+codespace.GUID+"/start", + a.githubAPI+"/vscs_internal/proxy/environments/"+codespace.GUID+"/start", nil, ) if err != nil { @@ -357,7 +362,7 @@ type SKU struct { } func (a *API) GetCodespacesSKUs(ctx context.Context, user *User, repository *Repository, branch, location string) ([]*SKU, error) { - req, err := http.NewRequest(http.MethodGet, githubAPI+"/vscs_internal/user/"+user.Login+"/skus", nil) + req, err := http.NewRequest(http.MethodGet, a.githubAPI+"/vscs_internal/user/"+user.Login+"/skus", nil) if err != nil { return nil, fmt.Errorf("error creating request: %w", err) } @@ -407,7 +412,7 @@ func (a *API) CreateCodespace(ctx context.Context, user *User, repository *Repos return nil, fmt.Errorf("error marshaling request: %w", err) } - req, err := http.NewRequest(http.MethodPost, githubAPI+"/vscs_internal/user/"+user.Login+"/codespaces", bytes.NewBuffer(requestBody)) + req, err := http.NewRequest(http.MethodPost, a.githubAPI+"/vscs_internal/user/"+user.Login+"/codespaces", bytes.NewBuffer(requestBody)) if err != nil { return nil, fmt.Errorf("error creating request: %w", err) } @@ -437,7 +442,7 @@ func (a *API) CreateCodespace(ctx context.Context, user *User, repository *Repos } func (a *API) DeleteCodespace(ctx context.Context, user *User, token, codespaceName string) error { - req, err := http.NewRequest(http.MethodDelete, githubAPI+"/vscs_internal/user/"+user.Login+"/codespaces/"+codespaceName, nil) + req, err := http.NewRequest(http.MethodDelete, a.githubAPI+"/vscs_internal/user/"+user.Login+"/codespaces/"+codespaceName, nil) if err != nil { return fmt.Errorf("error creating request: %w", err) } @@ -465,7 +470,7 @@ type getCodespaceRepositoryContentsResponse struct { } func (a *API) GetCodespaceRepositoryContents(ctx context.Context, codespace *Codespace, path string) ([]byte, error) { - req, err := http.NewRequest(http.MethodGet, githubAPI+"/repos/"+codespace.RepositoryNWO+"/contents/"+path, nil) + req, err := http.NewRequest(http.MethodGet, a.githubAPI+"/repos/"+codespace.RepositoryNWO+"/contents/"+path, nil) if err != nil { return nil, fmt.Errorf("error creating request: %w", err) } @@ -507,6 +512,24 @@ func (a *API) GetCodespaceRepositoryContents(ctx context.Context, codespace *Cod return decoded, nil } +func (a *API) FilterCodespacesToDelete(codespaces []*Codespace, keepThresholdDays int) ([]*Codespace, error) { + if keepThresholdDays < 0 { + return nil, fmt.Errorf("invalid value for threshold: %d", keepThresholdDays) + } + codespacesToDelete := []*Codespace{} + for _, codespace := range codespaces { + // get a date from a string representation + t, err := time.Parse(time.RFC3339, codespace.LastUsedAt) + if err != nil { + return nil, fmt.Errorf("error parsing last used at date: %v", err) + } + if t.Before(now().AddDate(0, 0, -keepThresholdDays)) && codespace.Environment.State == "Shutdown" { + codespacesToDelete = append(codespacesToDelete, codespace) + } + } + return codespacesToDelete, nil +} + func (a *API) do(ctx context.Context, req *http.Request, spanName string) (*http.Response, error) { // TODO(adonovan): use NewRequestWithContext(ctx) and drop ctx parameter. span, ctx := opentracing.StartSpanFromContext(ctx, spanName) diff --git a/internal/api/api_test.go b/internal/api/api_test.go new file mode 100644 index 000000000..6653248e1 --- /dev/null +++ b/internal/api/api_test.go @@ -0,0 +1,155 @@ +package api + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "testing" + "time" +) + +func TestListCodespaces(t *testing.T) { + user := &User{ + Login: "testuser", + } + + codespaces := []*Codespace{ + { + Name: "testcodespace", + CreatedAt: "2021-08-09T10:10:24+02:00", + LastUsedAt: "2021-08-09T13:10:24+02:00", + }, + } + svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + response := struct { + Codespaces []*Codespace `json:"codespaces"` + }{ + Codespaces: codespaces, + } + data, _ := json.Marshal(response) + fmt.Fprint(w, string(data)) + })) + defer svr.Close() + + api := API{ + githubAPI: svr.URL, + client: &http.Client{}, + token: "faketoken", + } + ctx := context.TODO() + codespaces, err := api.ListCodespaces(ctx, user) + if err != nil { + t.Fatal(err) + } + + if len(codespaces) != 1 { + t.Fatalf("expected 1 codespace, got %d", len(codespaces)) + } + + if codespaces[0].Name != "testcodespace" { + t.Fatalf("expected testcodespace, got %s", codespaces[0].Name) + } + +} + +func TestCleanupUnusedCodespaces(t *testing.T) { + type args struct { + codespaces []*Codespace + thresholdDays int + } + tests := []struct { + name string + now time.Time + args args + wantErr bool + deleted []*Codespace + }{ + { + name: "no codespaces is to be deleted", + + args: args{ + codespaces: []*Codespace{ + { + Name: "testcodespace", + CreatedAt: "2021-08-09T10:10:24+02:00", + LastUsedAt: "2021-08-09T13:10:24+02:00", + Environment: CodespaceEnvironment{ + State: "Shutdown", + }, + }, + }, + thresholdDays: 1, + }, + now: time.Date(2021, 8, 9, 20, 10, 24, 0, time.UTC), + deleted: []*Codespace{}, + }, + { + name: "one codespace is to be deleted", + + args: args{ + codespaces: []*Codespace{ + { + Name: "testcodespace", + CreatedAt: "2021-08-09T10:10:24+02:00", + LastUsedAt: "2021-08-09T13:10:24+02:00", + Environment: CodespaceEnvironment{ + State: "Shutdown", + }, + }, + }, + thresholdDays: 1, + }, + now: time.Date(2021, 8, 15, 20, 12, 24, 0, time.UTC), + deleted: []*Codespace{ + { + Name: "testcodespace", + CreatedAt: "2021-08-09T10:10:24+02:00", + LastUsedAt: "2021-08-09T13:10:24+02:00", + }, + }, + }, + { + name: "threshold is invalid", + + args: args{ + codespaces: []*Codespace{ + { + Name: "testcodespace", + CreatedAt: "2021-08-09T10:10:24+02:00", + LastUsedAt: "2021-08-09T13:10:24+02:00", + Environment: CodespaceEnvironment{ + State: "Shutdown", + }, + }, + }, + thresholdDays: -1, + }, + now: time.Date(2021, 8, 15, 20, 12, 24, 0, time.UTC), + wantErr: true, + deleted: []*Codespace{}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + + now = func() time.Time { + return tt.now + } + + a := &API{ + token: "testtoken", + client: &http.Client{}, + } + codespaces, err := a.FilterCodespacesToDelete(tt.args.codespaces, tt.args.thresholdDays) + if (err != nil) != tt.wantErr { + t.Errorf("API.CleanupUnusedCodespaces() error = %v, wantErr %v", err, tt.wantErr) + } + + if len(codespaces) != len(tt.deleted) { + t.Errorf("expected %d deleted codespaces, got %d", len(tt.deleted), len(codespaces)) + } + }) + } +} From 5cd90fea889c9a3a496fa4b16030933a5b58bcfe Mon Sep 17 00:00:00 2001 From: Raffaele Di Fazio Date: Thu, 16 Sep 2021 16:45:07 +0200 Subject: [PATCH 210/290] fix linter Signed-off-by: Raffaele Di Fazio --- cmd/ghcs/delete.go | 8 ++++---- internal/api/api.go | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/cmd/ghcs/delete.go b/cmd/ghcs/delete.go index a07559cfe..cb718d47a 100644 --- a/cmd/ghcs/delete.go +++ b/cmd/ghcs/delete.go @@ -210,12 +210,12 @@ func deleteWithThreshold(log *output.Logger, keepThresholdDays int) error { user, err := apiClient.GetUser(ctx) if err != nil { - return fmt.Errorf("error getting user: %v", err) + return fmt.Errorf("error getting user: %w", err) } codespaces, err := apiClient.ListCodespaces(ctx, user) if err != nil { - return fmt.Errorf("error getting codespaces: %v", err) + return fmt.Errorf("error getting codespaces: %w", err) } codespacesToDelete, err := apiClient.FilterCodespacesToDelete(codespaces, keepThresholdDays) @@ -226,11 +226,11 @@ func deleteWithThreshold(log *output.Logger, keepThresholdDays int) error { for _, c := range codespacesToDelete { token, err := apiClient.GetCodespaceToken(ctx, user.Login, c.Name) if err != nil { - return fmt.Errorf("error getting codespace token: %v", err) + return fmt.Errorf("error getting codespace token: %w", err) } if err := apiClient.DeleteCodespace(ctx, user, token, c.Name); err != nil { - return fmt.Errorf("error deleting codespace: %v", err) + return fmt.Errorf("error deleting codespace: %w", err) } log.Printf("Codespace deleted: %s\n", c.Name) diff --git a/internal/api/api.go b/internal/api/api.go index e0d506c02..db26191a8 100644 --- a/internal/api/api.go +++ b/internal/api/api.go @@ -521,7 +521,7 @@ func (a *API) FilterCodespacesToDelete(codespaces []*Codespace, keepThresholdDay // get a date from a string representation t, err := time.Parse(time.RFC3339, codespace.LastUsedAt) if err != nil { - return nil, fmt.Errorf("error parsing last used at date: %v", err) + return nil, fmt.Errorf("error parsing last used at date: %w", err) } if t.Before(now().AddDate(0, 0, -keepThresholdDays)) && codespace.Environment.State == "Shutdown" { codespacesToDelete = append(codespacesToDelete, codespace) From 35e0f95243e1048033748de94afb932d6fa401e8 Mon Sep 17 00:00:00 2001 From: Raffaele Di Fazio Date: Thu, 16 Sep 2021 18:42:41 +0200 Subject: [PATCH 211/290] Update cmd/ghcs/delete.go Co-authored-by: CamiloGarciaLaRotta --- cmd/ghcs/delete.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cmd/ghcs/delete.go b/cmd/ghcs/delete.go index cb718d47a..8f3027f6f 100644 --- a/cmd/ghcs/delete.go +++ b/cmd/ghcs/delete.go @@ -46,7 +46,7 @@ func newDeleteCmd() *cobra.Command { deleteCmd.Flags().BoolVar(&allCodespaces, "all", false, "Delete all codespaces") deleteCmd.Flags().StringVarP(&repo, "repo", "r", "", "Delete all codespaces for a repository") deleteCmd.Flags().BoolVarP(&force, "force", "f", false, "Delete codespaces with unsaved changes without confirmation") - deleteCmd.Flags().IntVar(&keepThresholdDays, "days", 0, "Value of threshold for codespaces to keep") + deleteCmd.Flags().IntVar(&keepThresholdDays, "days", 0, "Minimum number of days that a codespace has to have to be deleted. Only shutdown codespaces will be considered for deletion.") return deleteCmd } From 22e9da790c92e6e65b5ec531a466856bda0fc2a3 Mon Sep 17 00:00:00 2001 From: Raffaele Di Fazio Date: Thu, 16 Sep 2021 18:43:16 +0200 Subject: [PATCH 212/290] Update internal/api/api_test.go Co-authored-by: CamiloGarciaLaRotta --- internal/api/api_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/internal/api/api_test.go b/internal/api/api_test.go index 6653248e1..22e9ac234 100644 --- a/internal/api/api_test.go +++ b/internal/api/api_test.go @@ -54,7 +54,7 @@ func TestListCodespaces(t *testing.T) { } -func TestCleanupUnusedCodespaces(t *testing.T) { +func TestDeleteCodespacesByAge(t *testing.T) { type args struct { codespaces []*Codespace thresholdDays int From 455dabb484f818cb88e1e204b98ff86e2cf5cb8f Mon Sep 17 00:00:00 2001 From: Raffaele Di Fazio Date: Thu, 16 Sep 2021 18:49:44 +0200 Subject: [PATCH 213/290] use named params Signed-off-by: Raffaele Di Fazio --- internal/api/api.go | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/internal/api/api.go b/internal/api/api.go index db26191a8..d7aee0f66 100644 --- a/internal/api/api.go +++ b/internal/api/api.go @@ -51,7 +51,11 @@ type API struct { } func New(token string) *API { - return &API{token, &http.Client{}, githubAPI} + return &API{ + token: token, + client: &http.Client{}, + githubAPI: githubAPI, + } } type User struct { From 42e47a98d7b51d0ea70ee1bcc5a09392a49080fc Mon Sep 17 00:00:00 2001 From: Jose Garcia Date: Thu, 16 Sep 2021 15:22:47 -0400 Subject: [PATCH 214/290] add docs, simplify map, error on invalid args --- cmd/ghcs/logs.go | 5 ++- internal/codespaces/ssh.go | 58 ++++++++++++++++----------------- internal/codespaces/ssh_test.go | 13 +++++++- internal/codespaces/states.go | 6 +++- 4 files changed, 50 insertions(+), 32 deletions(-) diff --git a/cmd/ghcs/logs.go b/cmd/ghcs/logs.go index ccfb46236..725a243b0 100644 --- a/cmd/ghcs/logs.go +++ b/cmd/ghcs/logs.go @@ -82,9 +82,12 @@ func logs(ctx context.Context, log *output.Logger, codespaceName string, follow } dst := fmt.Sprintf("%s@localhost", sshUser) - cmd := codespaces.NewRemoteCommand( + cmd, err := codespaces.NewRemoteCommand( ctx, localPort, dst, fmt.Sprintf("%s /workspaces/.codespaces/.persistedshare/creation.log", cmdType), ) + if err != nil { + return fmt.Errorf("remote command: %w", err) + } tunnelClosed := make(chan error, 1) go func() { diff --git a/internal/codespaces/ssh.go b/internal/codespaces/ssh.go index 0eee21286..b58741e34 100644 --- a/internal/codespaces/ssh.go +++ b/internal/codespaces/ssh.go @@ -2,6 +2,7 @@ package codespaces import ( "context" + "fmt" "os" "os/exec" "strconv" @@ -12,7 +13,10 @@ import ( // port-forwarding session. It runs until the shell is terminated // (including by cancellation of the context). func Shell(ctx context.Context, log logger, sshArgs []string, port int, destination string, usingCustomPort bool) error { - cmd, connArgs := newSSHCommand(ctx, port, destination, sshArgs) + cmd, connArgs, err := newSSHCommand(ctx, port, destination, sshArgs) + if err != nil { + return fmt.Errorf("failed to create ssh command: %w", err) + } if usingCustomPort { log.Println("Connection Details: ssh " + destination + " " + strings.Join(connArgs, " ")) @@ -23,17 +27,27 @@ func Shell(ctx context.Context, log logger, sshArgs []string, port int, destinat // NewRemoteCommand returns an exec.Cmd that will securely run a shell // command on the remote machine. -func NewRemoteCommand(ctx context.Context, tunnelPort int, destination string, sshArgs ...string) *exec.Cmd { - cmd, _ := newSSHCommand(ctx, tunnelPort, destination, sshArgs) - return cmd +func NewRemoteCommand(ctx context.Context, tunnelPort int, destination string, sshArgs ...string) (*exec.Cmd, error) { + cmd, _, err := newSSHCommand(ctx, tunnelPort, destination, sshArgs) + return cmd, err } // newSSHCommand populates an exec.Cmd to run a command (or if blank, // an interactive shell) over ssh. -func newSSHCommand(ctx context.Context, port int, dst string, cmdArgs []string) (*exec.Cmd, []string) { +func newSSHCommand(ctx context.Context, port int, dst string, cmdArgs []string) (*exec.Cmd, []string, error) { connArgs := []string{"-p", strconv.Itoa(port), "-o", "NoHostAuthenticationForLocalhost=yes"} - cmdArgs, command := parseSSHArgs(cmdArgs) + // The ssh command syntax is: ssh [flags] user@host command [args...] + // There is no way to specify the user@host destination as a flag. + // Unfortunately, that means we need to know which user-provided words are + // SSH flags and which are command arguments so that we can place + // them before or after the destination, and that means we need to know all + // the flags and their arities. + cmdArgs, command, err := parseSSHArgs(cmdArgs) + if err != nil { + return nil, []string{}, err + } + cmdArgs = append(cmdArgs, connArgs...) cmdArgs = append(cmdArgs, "-C") // Compression cmdArgs = append(cmdArgs, dst) // user@host @@ -47,30 +61,12 @@ func newSSHCommand(ctx context.Context, port int, dst string, cmdArgs []string) cmd.Stdin = os.Stdin cmd.Stderr = os.Stderr - return cmd, connArgs + return cmd, connArgs, nil } -var sshArgumentFlags = map[string]bool{ - "-b": true, - "-c": true, - "-D": true, - "-e": true, - "-F": true, - "-I": true, - "-i": true, - "-L": true, - "-l": true, - "-m": true, - "-O": true, - "-o": true, - "-p": true, - "-R": true, - "-S": true, - "-W": true, - "-w": true, -} +var sshArgumentFlags = "-b-c-D-e-F-I-i-L-l-m-O-o-p-R-S-W-w" -func parseSSHArgs(sshArgs []string) ([]string, string) { +func parseSSHArgs(sshArgs []string) ([]string, string, error) { var ( cmdArgs []string command []string @@ -80,8 +76,12 @@ func parseSSHArgs(sshArgs []string) ([]string, string) { for _, arg := range sshArgs { switch { case strings.HasPrefix(arg, "-"): + if len(command) > 0 { + return []string{}, "", fmt.Errorf("invalid flag after command: %s", arg) + } + cmdArgs = append(cmdArgs, arg) - if _, ok := sshArgumentFlags[arg]; ok { + if strings.Contains(sshArgumentFlags, arg) { flagArgument = true } case flagArgument: @@ -92,5 +92,5 @@ func parseSSHArgs(sshArgs []string) ([]string, string) { } } - return cmdArgs, strings.Join(command, " ") + return cmdArgs, strings.Join(command, " "), nil } diff --git a/internal/codespaces/ssh_test.go b/internal/codespaces/ssh_test.go index cd92b39a6..2847ffc9f 100644 --- a/internal/codespaces/ssh_test.go +++ b/internal/codespaces/ssh_test.go @@ -48,7 +48,11 @@ func TestParseSSHArgs(t *testing.T) { } for _, tcase := range testCases { - args, command := parseSSHArgs(tcase.Args) + args, command, err := parseSSHArgs(tcase.Args) + if err != nil { + t.Errorf("received unexpected error: %w", err) + } + if len(args) != len(tcase.ParsedArgs) { t.Fatalf("args do not match length of expected args. %#v, got '%d', expected: '%d'", tcase, len(args), len(tcase.ParsedArgs)) } @@ -62,3 +66,10 @@ func TestParseSSHArgs(t *testing.T) { } } } + +func TestParseSSHArgsError(t *testing.T) { + _, _, err := parseSSHArgs([]string{"-X", "test", "-Y"}) + if err == nil { + t.Error("expected an error for invalid args") + } +} diff --git a/internal/codespaces/states.go b/internal/codespaces/states.go index 99683b51d..408f11941 100644 --- a/internal/codespaces/states.go +++ b/internal/codespaces/states.go @@ -89,10 +89,14 @@ func PollPostCreateStates(ctx context.Context, log logger, apiClient *api.API, u } func getPostCreateOutput(ctx context.Context, tunnelPort int, codespace *api.Codespace, user string) ([]PostCreateState, error) { - cmd := NewRemoteCommand( + cmd, err := NewRemoteCommand( ctx, tunnelPort, fmt.Sprintf("%s@localhost", user), "cat /workspaces/.codespaces/shared/postCreateOutput.json", ) + if err != nil { + return nil, fmt.Errorf("remote command: %w", err) + } + stdout := new(bytes.Buffer) cmd.Stdout = stdout if err := cmd.Run(); err != nil { From bc74c4aafab3c0f2bd25cccb7b2a796b4357a607 Mon Sep 17 00:00:00 2001 From: Alan Donovan Date: Thu, 16 Sep 2021 18:24:43 -0400 Subject: [PATCH 215/290] make delete --repo parallel --- cmd/ghcs/delete.go | 53 ++++++++++++++++++++++++++++++++++++---------- 1 file changed, 42 insertions(+), 11 deletions(-) diff --git a/cmd/ghcs/delete.go b/cmd/ghcs/delete.go index 7800a13c0..a3971f491 100644 --- a/cmd/ghcs/delete.go +++ b/cmd/ghcs/delete.go @@ -6,6 +6,7 @@ import ( "fmt" "os" "strings" + "sync" "github.com/github/ghcs/api" "github.com/github/ghcs/cmd/ghcs/output" @@ -116,28 +117,58 @@ func deleteByRepo(log *output.Logger, repo string) error { return fmt.Errorf("error getting codespaces: %v", err) } - var deleted bool - for _, c := range codespaces { - if !strings.EqualFold(c.RepositoryNWO, repo) { - continue - } - deleted = true - - token, err := apiClient.GetCodespaceToken(ctx, user.Login, c.Name) + delete := func(name string) error { + token, err := apiClient.GetCodespaceToken(ctx, user.Login, name) if err != nil { return fmt.Errorf("error getting codespace token: %v", err) } - if err := apiClient.DeleteCodespace(ctx, user, token, c.Name); err != nil { + if err := apiClient.DeleteCodespace(ctx, user, token, name); err != nil { return fmt.Errorf("error deleting codespace: %v", err) } - log.Printf("Codespace deleted: %s\n", c.Name) + return nil } - if !deleted { + // Perform deletions in parallel. + var ( + found bool + mu sync.Mutex // guards errs, logger + errs []error + wg sync.WaitGroup + ) + for _, c := range codespaces { + if !strings.EqualFold(c.RepositoryNWO, repo) { + continue + } + found = true + c := c + wg.Add(1) + go func() { + defer wg.Done() + err := delete(c.Name) + mu.Lock() + defer mu.Unlock() + if err != nil { + errs = append(errs, err) + } else { + log.Printf("Codespace deleted: %s\n", c.Name) + } + }() + } + if !found { return fmt.Errorf("No codespace was found for repository: %s", repo) } + wg.Wait() + + // Return first error, plus count of others. + if errs != nil { + err := errs[0] + if others := len(errs) - 1; others > 0 { + err = fmt.Errorf("%w (+%d more)", err, others) + } + return err + } return list(&listOptions{}) } From 29c2a17866cc03f3c701b7503aa26b90d0950c97 Mon Sep 17 00:00:00 2001 From: Raffaele Di Fazio Date: Fri, 17 Sep 2021 08:55:54 +0200 Subject: [PATCH 216/290] Update cmd/ghcs/delete.go Co-authored-by: Jose Garcia --- cmd/ghcs/delete.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cmd/ghcs/delete.go b/cmd/ghcs/delete.go index 8f3027f6f..6aeedb4c5 100644 --- a/cmd/ghcs/delete.go +++ b/cmd/ghcs/delete.go @@ -46,7 +46,7 @@ func newDeleteCmd() *cobra.Command { deleteCmd.Flags().BoolVar(&allCodespaces, "all", false, "Delete all codespaces") deleteCmd.Flags().StringVarP(&repo, "repo", "r", "", "Delete all codespaces for a repository") deleteCmd.Flags().BoolVarP(&force, "force", "f", false, "Delete codespaces with unsaved changes without confirmation") - deleteCmd.Flags().IntVar(&keepThresholdDays, "days", 0, "Minimum number of days that a codespace has to have to be deleted. Only shutdown codespaces will be considered for deletion.") + deleteCmd.Flags().IntVar(&keepThresholdDays, "days", 0, "Minimum number of days since the codespace was created") return deleteCmd } From a4f1fa076b7c1d98d47af20bd698c3036ff50c1d Mon Sep 17 00:00:00 2001 From: Max Beizer Date: Fri, 17 Sep 2021 06:10:37 -0500 Subject: [PATCH 217/290] Fix up all the static-check warnings (#162) --- cmd/ghcs/common.go | 2 +- cmd/ghcs/create.go | 2 +- cmd/ghcs/delete.go | 4 ++-- cmd/ghcs/main.go | 6 +++--- 4 files changed, 7 insertions(+), 7 deletions(-) diff --git a/cmd/ghcs/common.go b/cmd/ghcs/common.go index 2f4862e13..e71e3dfe4 100644 --- a/cmd/ghcs/common.go +++ b/cmd/ghcs/common.go @@ -15,7 +15,7 @@ import ( "golang.org/x/term" ) -var errNoCodespaces = errors.New("You have no codespaces.") +var errNoCodespaces = errors.New("you have no codespaces") func chooseCodespace(ctx context.Context, apiClient *api.API, user *api.User) (*api.Codespace, error) { codespaces, err := apiClient.ListCodespaces(ctx, user) diff --git a/cmd/ghcs/create.go b/cmd/ghcs/create.go index 8b7a2e7d9..2125176fd 100644 --- a/cmd/ghcs/create.go +++ b/cmd/ghcs/create.go @@ -82,7 +82,7 @@ func create(opts *createOptions) error { return fmt.Errorf("error getting machine type: %w", err) } if machine == "" { - return errors.New("There are no available machine types for this repository") + return errors.New("there are no available machine types for this repository") } log.Println("Creating your codespace...") diff --git a/cmd/ghcs/delete.go b/cmd/ghcs/delete.go index 961310675..2c0de43ea 100644 --- a/cmd/ghcs/delete.go +++ b/cmd/ghcs/delete.go @@ -28,7 +28,7 @@ func newDeleteCmd() *cobra.Command { RunE: func(cmd *cobra.Command, args []string) error { switch { case allCodespaces && repo != "": - return errors.New("both --all and --repo is not supported.") + return errors.New("both --all and --repo is not supported") case allCodespaces: return deleteAll(log, force) case repo != "": @@ -166,7 +166,7 @@ func deleteByRepo(log *output.Logger, repo string, force bool) error { } if !deleted { - return fmt.Errorf("No codespace was found for repository: %s", repo) + return fmt.Errorf("no codespace was found for repository: %s", repo) } return list(&listOptions{}) diff --git a/cmd/ghcs/main.go b/cmd/ghcs/main.go index 651d98c1d..7903dad2a 100644 --- a/cmd/ghcs/main.go +++ b/cmd/ghcs/main.go @@ -40,7 +40,7 @@ token to access the GitHub API with.`, PersistentPreRunE: func(cmd *cobra.Command, args []string) error { if os.Getenv("GITHUB_TOKEN") == "" { - return tokenError + return errTokenMissing } return initLightstep(lightstep) }, @@ -51,10 +51,10 @@ token to access the GitHub API with.`, return root } -var tokenError = errors.New("GITHUB_TOKEN is missing") +var errTokenMissing = errors.New("GITHUB_TOKEN is missing") func explainError(w io.Writer, err error) { - if errors.Is(err, tokenError) { + if errors.Is(err, errTokenMissing) { fmt.Fprintln(w, "The GITHUB_TOKEN environment variable is required. Create a Personal Access Token at https://github.com/settings/tokens/new?scopes=repo") fmt.Fprintln(w, "Make sure to enable SSO for your organizations after creating the token.") return From 054fec0ba117508cb761e527bcf7a30d449b9a89 Mon Sep 17 00:00:00 2001 From: Raffaele Di Fazio Date: Fri, 17 Sep 2021 14:45:08 +0200 Subject: [PATCH 218/290] address code comments Signed-off-by: Raffaele Di Fazio --- cmd/ghcs/delete.go | 55 +++++++++------------ internal/api/api.go | 24 +--------- internal/api/api_test.go | 101 --------------------------------------- 3 files changed, 23 insertions(+), 157 deletions(-) diff --git a/cmd/ghcs/delete.go b/cmd/ghcs/delete.go index 6aeedb4c5..4ad776cf9 100644 --- a/cmd/ghcs/delete.go +++ b/cmd/ghcs/delete.go @@ -6,6 +6,7 @@ import ( "fmt" "os" "strings" + "time" "github.com/AlecAivazis/survey/v2" "github.com/github/ghcs/cmd/ghcs/output" @@ -13,6 +14,8 @@ import ( "github.com/spf13/cobra" ) +var now func() time.Time = time.Now + func newDeleteCmd() *cobra.Command { var ( codespace string @@ -30,10 +33,8 @@ func newDeleteCmd() *cobra.Command { switch { case allCodespaces && repo != "": return errors.New("both --all and --repo is not supported.") - case allCodespaces && keepThresholdDays != 0: - return deleteWithThreshold(log, keepThresholdDays) case allCodespaces: - return deleteAll(log, force) + return deleteAll(log, force, keepThresholdDays) case repo != "": return deleteByRepo(log, repo, force) default: @@ -87,7 +88,7 @@ func delete_(log *output.Logger, codespaceName string, force bool) error { return list(&listOptions{}) } -func deleteAll(log *output.Logger, force bool) error { +func deleteAll(log *output.Logger, force bool, keepThresholdDays int) error { apiClient := api.New(os.Getenv("GITHUB_TOKEN")) ctx := context.Background() @@ -101,7 +102,12 @@ func deleteAll(log *output.Logger, force bool) error { return fmt.Errorf("error getting codespaces: %w", err) } - for _, c := range codespaces { + codespacesToDelete, err := filterCodespacesToDelete(codespaces, keepThresholdDays) + if err != nil { + return err + } + + for _, c := range codespacesToDelete { confirmed, err := confirmDeletion(c, force) if err != nil { return fmt.Errorf("deletion could not be confirmed: %w", err) @@ -204,37 +210,20 @@ func confirmDeletion(codespace *api.Codespace, force bool) (bool, error) { return confirmed.Confirmed, nil } -func deleteWithThreshold(log *output.Logger, keepThresholdDays int) error { - apiClient := api.New(os.Getenv("GITHUB_TOKEN")) - ctx := context.Background() - - user, err := apiClient.GetUser(ctx) - if err != nil { - return fmt.Errorf("error getting user: %w", err) +func filterCodespacesToDelete(codespaces []*api.Codespace, keepThresholdDays int) ([]*api.Codespace, error) { + if keepThresholdDays < 0 { + return nil, fmt.Errorf("invalid value for threshold: %d", keepThresholdDays) } - - codespaces, err := apiClient.ListCodespaces(ctx, user) - if err != nil { - return fmt.Errorf("error getting codespaces: %w", err) - } - - codespacesToDelete, err := apiClient.FilterCodespacesToDelete(codespaces, keepThresholdDays) - if err != nil { - return err - } - - for _, c := range codespacesToDelete { - token, err := apiClient.GetCodespaceToken(ctx, user.Login, c.Name) + codespacesToDelete := []*api.Codespace{} + for _, codespace := range codespaces { + // get a date from a string representation + t, err := time.Parse(time.RFC3339, codespace.LastUsedAt) if err != nil { - return fmt.Errorf("error getting codespace token: %w", err) + return nil, fmt.Errorf("error parsing last used at date: %w", err) } - - if err := apiClient.DeleteCodespace(ctx, user, token, c.Name); err != nil { - return fmt.Errorf("error deleting codespace: %w", err) + if t.Before(now().AddDate(0, 0, -keepThresholdDays)) && codespace.Environment.State == "Shutdown" { + codespacesToDelete = append(codespacesToDelete, codespace) } - - log.Printf("Codespace deleted: %s\n", c.Name) } - - return list(&listOptions{}) + return codespacesToDelete, nil } diff --git a/internal/api/api.go b/internal/api/api.go index d7aee0f66..823e3cf86 100644 --- a/internal/api/api.go +++ b/internal/api/api.go @@ -31,19 +31,15 @@ import ( "encoding/json" "errors" "fmt" + "github.com/opentracing/opentracing-go" "io/ioutil" "net/http" "strconv" "strings" - "time" - - "github.com/opentracing/opentracing-go" ) const githubAPI = "https://api.github.com" -var now func() time.Time = time.Now - type API struct { token string client *http.Client @@ -516,24 +512,6 @@ func (a *API) GetCodespaceRepositoryContents(ctx context.Context, codespace *Cod return decoded, nil } -func (a *API) FilterCodespacesToDelete(codespaces []*Codespace, keepThresholdDays int) ([]*Codespace, error) { - if keepThresholdDays < 0 { - return nil, fmt.Errorf("invalid value for threshold: %d", keepThresholdDays) - } - codespacesToDelete := []*Codespace{} - for _, codespace := range codespaces { - // get a date from a string representation - t, err := time.Parse(time.RFC3339, codespace.LastUsedAt) - if err != nil { - return nil, fmt.Errorf("error parsing last used at date: %w", err) - } - if t.Before(now().AddDate(0, 0, -keepThresholdDays)) && codespace.Environment.State == "Shutdown" { - codespacesToDelete = append(codespacesToDelete, codespace) - } - } - return codespacesToDelete, nil -} - func (a *API) do(ctx context.Context, req *http.Request, spanName string) (*http.Response, error) { // TODO(adonovan): use NewRequestWithContext(ctx) and drop ctx parameter. span, ctx := opentracing.StartSpanFromContext(ctx, spanName) diff --git a/internal/api/api_test.go b/internal/api/api_test.go index 22e9ac234..c1f4e5c19 100644 --- a/internal/api/api_test.go +++ b/internal/api/api_test.go @@ -7,7 +7,6 @@ import ( "net/http" "net/http/httptest" "testing" - "time" ) func TestListCodespaces(t *testing.T) { @@ -53,103 +52,3 @@ func TestListCodespaces(t *testing.T) { } } - -func TestDeleteCodespacesByAge(t *testing.T) { - type args struct { - codespaces []*Codespace - thresholdDays int - } - tests := []struct { - name string - now time.Time - args args - wantErr bool - deleted []*Codespace - }{ - { - name: "no codespaces is to be deleted", - - args: args{ - codespaces: []*Codespace{ - { - Name: "testcodespace", - CreatedAt: "2021-08-09T10:10:24+02:00", - LastUsedAt: "2021-08-09T13:10:24+02:00", - Environment: CodespaceEnvironment{ - State: "Shutdown", - }, - }, - }, - thresholdDays: 1, - }, - now: time.Date(2021, 8, 9, 20, 10, 24, 0, time.UTC), - deleted: []*Codespace{}, - }, - { - name: "one codespace is to be deleted", - - args: args{ - codespaces: []*Codespace{ - { - Name: "testcodespace", - CreatedAt: "2021-08-09T10:10:24+02:00", - LastUsedAt: "2021-08-09T13:10:24+02:00", - Environment: CodespaceEnvironment{ - State: "Shutdown", - }, - }, - }, - thresholdDays: 1, - }, - now: time.Date(2021, 8, 15, 20, 12, 24, 0, time.UTC), - deleted: []*Codespace{ - { - Name: "testcodespace", - CreatedAt: "2021-08-09T10:10:24+02:00", - LastUsedAt: "2021-08-09T13:10:24+02:00", - }, - }, - }, - { - name: "threshold is invalid", - - args: args{ - codespaces: []*Codespace{ - { - Name: "testcodespace", - CreatedAt: "2021-08-09T10:10:24+02:00", - LastUsedAt: "2021-08-09T13:10:24+02:00", - Environment: CodespaceEnvironment{ - State: "Shutdown", - }, - }, - }, - thresholdDays: -1, - }, - now: time.Date(2021, 8, 15, 20, 12, 24, 0, time.UTC), - wantErr: true, - deleted: []*Codespace{}, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - - now = func() time.Time { - return tt.now - } - - a := &API{ - token: "testtoken", - client: &http.Client{}, - } - codespaces, err := a.FilterCodespacesToDelete(tt.args.codespaces, tt.args.thresholdDays) - if (err != nil) != tt.wantErr { - t.Errorf("API.CleanupUnusedCodespaces() error = %v, wantErr %v", err, tt.wantErr) - } - - if len(codespaces) != len(tt.deleted) { - t.Errorf("expected %d deleted codespaces, got %d", len(tt.deleted), len(codespaces)) - } - }) - } -} From c6b5fb5ba336cc3160c637bab413342c15ffcfc5 Mon Sep 17 00:00:00 2001 From: Raffaele Di Fazio Date: Fri, 17 Sep 2021 14:55:50 +0200 Subject: [PATCH 219/290] add the tests Signed-off-by: Raffaele Di Fazio --- cmd/ghcs/delete_test.go | 103 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 103 insertions(+) create mode 100644 cmd/ghcs/delete_test.go diff --git a/cmd/ghcs/delete_test.go b/cmd/ghcs/delete_test.go new file mode 100644 index 000000000..efb59124e --- /dev/null +++ b/cmd/ghcs/delete_test.go @@ -0,0 +1,103 @@ +package main + +import ( + "github.com/github/ghcs/internal/api" + "testing" + "time" +) + +func TestFilterCodespacesToDelete(t *testing.T) { + type args struct { + codespaces []*api.Codespace + thresholdDays int + } + tests := []struct { + name string + now time.Time + args args + wantErr bool + deleted []*api.Codespace + }{ + { + name: "no codespaces is to be deleted", + + args: args{ + codespaces: []*api.Codespace{ + { + Name: "testcodespace", + CreatedAt: "2021-08-09T10:10:24+02:00", + LastUsedAt: "2021-08-09T13:10:24+02:00", + Environment: api.CodespaceEnvironment{ + State: "Shutdown", + }, + }, + }, + thresholdDays: 1, + }, + now: time.Date(2021, 8, 9, 20, 10, 24, 0, time.UTC), + deleted: []*api.Codespace{}, + }, + { + name: "one codespace is to be deleted", + + args: args{ + codespaces: []*api.Codespace{ + { + Name: "testcodespace", + CreatedAt: "2021-08-09T10:10:24+02:00", + LastUsedAt: "2021-08-09T13:10:24+02:00", + Environment: api.CodespaceEnvironment{ + State: "Shutdown", + }, + }, + }, + thresholdDays: 1, + }, + now: time.Date(2021, 8, 15, 20, 12, 24, 0, time.UTC), + deleted: []*api.Codespace{ + { + Name: "testcodespace", + CreatedAt: "2021-08-09T10:10:24+02:00", + LastUsedAt: "2021-08-09T13:10:24+02:00", + }, + }, + }, + { + name: "threshold is invalid", + + args: args{ + codespaces: []*api.Codespace{ + { + Name: "testcodespace", + CreatedAt: "2021-08-09T10:10:24+02:00", + LastUsedAt: "2021-08-09T13:10:24+02:00", + Environment: api.CodespaceEnvironment{ + State: "Shutdown", + }, + }, + }, + thresholdDays: -1, + }, + now: time.Date(2021, 8, 15, 20, 12, 24, 0, time.UTC), + wantErr: true, + deleted: []*api.Codespace{}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + + now = func() time.Time { + return tt.now + } + + codespaces, err := filterCodespacesToDelete(tt.args.codespaces, tt.args.thresholdDays) + if (err != nil) != tt.wantErr { + t.Errorf("API.CleanupUnusedCodespaces() error = %v, wantErr %v", err, tt.wantErr) + } + + if len(codespaces) != len(tt.deleted) { + t.Errorf("expected %d deleted codespaces, got %d", len(tt.deleted), len(codespaces)) + } + }) + } +} From 4de457281375b5306b0b5644c7446033a27d7e0c Mon Sep 17 00:00:00 2001 From: Alan Donovan Date: Fri, 17 Sep 2021 09:31:05 -0400 Subject: [PATCH 220/290] add comment --- cmd/ghcs/delete.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/cmd/ghcs/delete.go b/cmd/ghcs/delete.go index 25b77bd5e..1ff8122b3 100644 --- a/cmd/ghcs/delete.go +++ b/cmd/ghcs/delete.go @@ -150,7 +150,8 @@ func deleteByRepo(log *output.Logger, repo string, force bool) error { return nil } - // Perform deletions in parallel. + // Perform deletions in parallel, for performance, + // and to ensure all are attempted even if any one fails. var ( found bool mu sync.Mutex // guards errs, logger From 747d7e717354bcb0671e2e36ae2d31c35b42f437 Mon Sep 17 00:00:00 2001 From: Alan Donovan Date: Fri, 17 Sep 2021 09:45:49 -0400 Subject: [PATCH 221/290] Restore confirmation to delete -r, lost in botched merge --- cmd/ghcs/delete.go | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/cmd/ghcs/delete.go b/cmd/ghcs/delete.go index d7fb1359e..2bb5ea7ac 100644 --- a/cmd/ghcs/delete.go +++ b/cmd/ghcs/delete.go @@ -162,6 +162,19 @@ func deleteByRepo(log *output.Logger, repo string, force bool) error { if !strings.EqualFold(c.RepositoryNWO, repo) { continue } + + confirmed, err := confirmDeletion(c, force) + if err != nil { + mu.Lock() + errs = append(errs, fmt.Errorf("deletion could not be confirmed: %w", err)) + mu.Unlock() + continue + } + + if !confirmed { + continue + } + found = true c := c wg.Add(1) From d23eca8c5fa8c9b40c679dc05884c2b63eefba72 Mon Sep 17 00:00:00 2001 From: Alan Donovan Date: Fri, 17 Sep 2021 09:51:11 -0400 Subject: [PATCH 222/290] remove "list" operation from "delete -r" command --- cmd/ghcs/delete.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cmd/ghcs/delete.go b/cmd/ghcs/delete.go index d7fb1359e..a4516f2b4 100644 --- a/cmd/ghcs/delete.go +++ b/cmd/ghcs/delete.go @@ -191,7 +191,7 @@ func deleteByRepo(log *output.Logger, repo string, force bool) error { return err } - return list(&listOptions{}) + return nil } func confirmDeletion(codespace *api.Codespace, force bool) (bool, error) { From ce4bbe5bd862917f5cc51478c69b2e1d0bd02639 Mon Sep 17 00:00:00 2001 From: Alan Donovan Date: Fri, 17 Sep 2021 10:13:35 -0400 Subject: [PATCH 223/290] list: show branch (not name) in branch column --- cmd/ghcs/list.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cmd/ghcs/list.go b/cmd/ghcs/list.go index 72a25becc..fb8d83c78 100644 --- a/cmd/ghcs/list.go +++ b/cmd/ghcs/list.go @@ -55,7 +55,7 @@ func list(opts *listOptions) error { table.Append([]string{ codespace.Name, codespace.RepositoryNWO, - codespace.Name + dirtyStar(codespace.Environment.GitStatus), + codespace.Branch + dirtyStar(codespace.Environment.GitStatus), codespace.Environment.State, codespace.CreatedAt, }) From c2f3537a322e25b8ffdfcd6ab31b9081f8695995 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mislav=20Marohni=C4=87?= Date: Fri, 17 Sep 2021 16:26:20 +0200 Subject: [PATCH 224/290] Separate "main" package from "ghcs" package To make "ghcs" importable, this separates out the `main()` function into its own package that lives under "cmd/ghcs/main". Typically the main package would be called "cmd/ghcs", but we wanted to leave the current ghcs implementation where it is to avoid causing conflicts with current work in progress. Co-authored-by: Jose Garcia --- cmd/ghcs/code.go | 6 +----- cmd/ghcs/common.go | 2 +- cmd/ghcs/create.go | 6 +----- cmd/ghcs/delete.go | 6 +----- cmd/ghcs/list.go | 6 +----- cmd/ghcs/logs.go | 6 +----- cmd/ghcs/main/main.go | 26 ++++++++++++++++++++++++++ cmd/ghcs/ports.go | 6 +----- cmd/ghcs/{main.go => root.go} | 35 ++++++++++++++--------------------- cmd/ghcs/ssh.go | 6 +----- 10 files changed, 48 insertions(+), 57 deletions(-) create mode 100644 cmd/ghcs/main/main.go rename cmd/ghcs/{main.go => root.go} (77%) diff --git a/cmd/ghcs/code.go b/cmd/ghcs/code.go index d905c75ac..245a362be 100644 --- a/cmd/ghcs/code.go +++ b/cmd/ghcs/code.go @@ -1,4 +1,4 @@ -package main +package ghcs import ( "context" @@ -32,10 +32,6 @@ func newCodeCmd() *cobra.Command { return codeCmd } -func init() { - rootCmd.AddCommand(newCodeCmd()) -} - func code(codespaceName string, useInsiders bool) error { apiClient := api.New(os.Getenv("GITHUB_TOKEN")) ctx := context.Background() diff --git a/cmd/ghcs/common.go b/cmd/ghcs/common.go index e71e3dfe4..f15d2fef6 100644 --- a/cmd/ghcs/common.go +++ b/cmd/ghcs/common.go @@ -1,4 +1,4 @@ -package main +package ghcs // This file defines functions common to the entire ghcs command set. diff --git a/cmd/ghcs/create.go b/cmd/ghcs/create.go index 2125176fd..493a70247 100644 --- a/cmd/ghcs/create.go +++ b/cmd/ghcs/create.go @@ -1,4 +1,4 @@ -package main +package ghcs import ( "context" @@ -42,10 +42,6 @@ func newCreateCmd() *cobra.Command { return createCmd } -func init() { - rootCmd.AddCommand(newCreateCmd()) -} - func create(opts *createOptions) error { ctx := context.Background() apiClient := api.New(os.Getenv("GITHUB_TOKEN")) diff --git a/cmd/ghcs/delete.go b/cmd/ghcs/delete.go index 2c0de43ea..b75def7b6 100644 --- a/cmd/ghcs/delete.go +++ b/cmd/ghcs/delete.go @@ -1,4 +1,4 @@ -package main +package ghcs import ( "context" @@ -47,10 +47,6 @@ func newDeleteCmd() *cobra.Command { return deleteCmd } -func init() { - rootCmd.AddCommand(newDeleteCmd()) -} - func delete_(log *output.Logger, codespaceName string, force bool) error { apiClient := api.New(os.Getenv("GITHUB_TOKEN")) ctx := context.Background() diff --git a/cmd/ghcs/list.go b/cmd/ghcs/list.go index 72a25becc..09b0ea5b0 100644 --- a/cmd/ghcs/list.go +++ b/cmd/ghcs/list.go @@ -1,4 +1,4 @@ -package main +package ghcs import ( "context" @@ -31,10 +31,6 @@ func newListCmd() *cobra.Command { return listCmd } -func init() { - rootCmd.AddCommand(newListCmd()) -} - func list(opts *listOptions) error { apiClient := api.New(os.Getenv("GITHUB_TOKEN")) ctx := context.Background() diff --git a/cmd/ghcs/logs.go b/cmd/ghcs/logs.go index f65fa1109..40cac88ed 100644 --- a/cmd/ghcs/logs.go +++ b/cmd/ghcs/logs.go @@ -1,4 +1,4 @@ -package main +package ghcs import ( "context" @@ -36,10 +36,6 @@ func newLogsCmd() *cobra.Command { return logsCmd } -func init() { - rootCmd.AddCommand(newLogsCmd()) -} - func logs(ctx context.Context, log *output.Logger, codespaceName string, follow bool) error { // Ensure all child tasks (port forwarding, remote exec) terminate before return. ctx, cancel := context.WithCancel(ctx) diff --git a/cmd/ghcs/main/main.go b/cmd/ghcs/main/main.go new file mode 100644 index 000000000..01dde1270 --- /dev/null +++ b/cmd/ghcs/main/main.go @@ -0,0 +1,26 @@ +package main + +import ( + "errors" + "fmt" + "io" + "os" + + "github.com/github/ghcs/cmd/ghcs" +) + +func main() { + rootCmd := ghcs.NewRootCmd() + if err := rootCmd.Execute(); err != nil { + explainError(os.Stderr, err) + os.Exit(1) + } +} + +func explainError(w io.Writer, err error) { + if errors.Is(err, ghcs.ErrTokenMissing) { + fmt.Fprintln(w, "The GITHUB_TOKEN environment variable is required. Create a Personal Access Token at https://github.com/settings/tokens/new?scopes=repo") + fmt.Fprintln(w, "Make sure to enable SSO for your organizations after creating the token.") + return + } +} diff --git a/cmd/ghcs/ports.go b/cmd/ghcs/ports.go index 7bc53c441..052dfd37c 100644 --- a/cmd/ghcs/ports.go +++ b/cmd/ghcs/ports.go @@ -1,4 +1,4 @@ -package main +package ghcs import ( "bytes" @@ -47,10 +47,6 @@ func newPortsCmd() *cobra.Command { return portsCmd } -func init() { - rootCmd.AddCommand(newPortsCmd()) -} - func ports(codespaceName string, asJSON bool) error { apiClient := api.New(os.Getenv("GITHUB_TOKEN")) ctx := context.Background() diff --git a/cmd/ghcs/main.go b/cmd/ghcs/root.go similarity index 77% rename from cmd/ghcs/main.go rename to cmd/ghcs/root.go index 7903dad2a..6db4144a8 100644 --- a/cmd/ghcs/main.go +++ b/cmd/ghcs/root.go @@ -1,9 +1,8 @@ -package main +package ghcs import ( "errors" "fmt" - "io" "log" "os" "strconv" @@ -14,18 +13,12 @@ import ( "github.com/spf13/cobra" ) -func main() { - if err := rootCmd.Execute(); err != nil { - explainError(os.Stderr, err) - os.Exit(1) - } -} - var version = "DEV" // Replaced in the release build process (by GoReleaser or Homebrew) by the git tag version number. -var rootCmd = newRootCmd() +// GithubToken is a temporary stopgap to make the token configurable by apps that import this package +var GithubToken = os.Getenv("GITHUB_TOKEN") -func newRootCmd() *cobra.Command { +func NewRootCmd() *cobra.Command { var lightstep string root := &cobra.Command{ @@ -40,7 +33,7 @@ token to access the GitHub API with.`, PersistentPreRunE: func(cmd *cobra.Command, args []string) error { if os.Getenv("GITHUB_TOKEN") == "" { - return errTokenMissing + return ErrTokenMissing } return initLightstep(lightstep) }, @@ -48,18 +41,18 @@ token to access the GitHub API with.`, root.PersistentFlags().StringVar(&lightstep, "lightstep", "", "Lightstep tracing endpoint (service:token@host:port)") + root.AddCommand(newCodeCmd()) + root.AddCommand(newCreateCmd()) + root.AddCommand(newDeleteCmd()) + root.AddCommand(newListCmd()) + root.AddCommand(newLogsCmd()) + root.AddCommand(newPortsCmd()) + root.AddCommand(newSSHCmd()) + return root } -var errTokenMissing = errors.New("GITHUB_TOKEN is missing") - -func explainError(w io.Writer, err error) { - if errors.Is(err, errTokenMissing) { - fmt.Fprintln(w, "The GITHUB_TOKEN environment variable is required. Create a Personal Access Token at https://github.com/settings/tokens/new?scopes=repo") - fmt.Fprintln(w, "Make sure to enable SSO for your organizations after creating the token.") - return - } -} +var ErrTokenMissing = errors.New("GITHUB_TOKEN is missing") // initLightstep parses the --lightstep=service:token@host:port flag and // enables tracing if non-empty. diff --git a/cmd/ghcs/ssh.go b/cmd/ghcs/ssh.go index 6d5f2376b..527ae120f 100644 --- a/cmd/ghcs/ssh.go +++ b/cmd/ghcs/ssh.go @@ -1,4 +1,4 @@ -package main +package ghcs import ( "context" @@ -32,10 +32,6 @@ func newSSHCmd() *cobra.Command { return sshCmd } -func init() { - rootCmd.AddCommand(newSSHCmd()) -} - func ssh(ctx context.Context, sshProfile, codespaceName string, localSSHServerPort int) error { // Ensure all child tasks (e.g. port forwarding) terminate before return. ctx, cancel := context.WithCancel(ctx) From 8c0c7a8e19c5f971f434cb2b69c9e8a4dcbbccdf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mislav=20Marohni=C4=87?= Date: Fri, 17 Sep 2021 16:29:35 +0200 Subject: [PATCH 225/290] Make GITHUB_TOKEN configurable through Go member Co-authored-by: Jose Garcia --- cmd/ghcs/code.go | 3 +-- cmd/ghcs/create.go | 2 +- cmd/ghcs/delete.go | 6 +++--- cmd/ghcs/list.go | 2 +- cmd/ghcs/logs.go | 2 +- cmd/ghcs/ports.go | 6 +++--- cmd/ghcs/ssh.go | 2 +- 7 files changed, 11 insertions(+), 12 deletions(-) diff --git a/cmd/ghcs/code.go b/cmd/ghcs/code.go index 245a362be..9f09438d5 100644 --- a/cmd/ghcs/code.go +++ b/cmd/ghcs/code.go @@ -4,7 +4,6 @@ import ( "context" "fmt" "net/url" - "os" "github.com/github/ghcs/internal/api" "github.com/skratchdot/open-golang/open" @@ -33,7 +32,7 @@ func newCodeCmd() *cobra.Command { } func code(codespaceName string, useInsiders bool) error { - apiClient := api.New(os.Getenv("GITHUB_TOKEN")) + apiClient := api.New(GithubToken) ctx := context.Background() user, err := apiClient.GetUser(ctx) diff --git a/cmd/ghcs/create.go b/cmd/ghcs/create.go index 493a70247..45aa794e6 100644 --- a/cmd/ghcs/create.go +++ b/cmd/ghcs/create.go @@ -44,7 +44,7 @@ func newCreateCmd() *cobra.Command { func create(opts *createOptions) error { ctx := context.Background() - apiClient := api.New(os.Getenv("GITHUB_TOKEN")) + apiClient := api.New(GithubToken) locationCh := getLocation(ctx, apiClient) userCh := getUser(ctx, apiClient) log := output.NewLogger(os.Stdout, os.Stderr, false) diff --git a/cmd/ghcs/delete.go b/cmd/ghcs/delete.go index b75def7b6..34c1bc095 100644 --- a/cmd/ghcs/delete.go +++ b/cmd/ghcs/delete.go @@ -48,7 +48,7 @@ func newDeleteCmd() *cobra.Command { } func delete_(log *output.Logger, codespaceName string, force bool) error { - apiClient := api.New(os.Getenv("GITHUB_TOKEN")) + apiClient := api.New(GithubToken) ctx := context.Background() user, err := apiClient.GetUser(ctx) @@ -80,7 +80,7 @@ func delete_(log *output.Logger, codespaceName string, force bool) error { } func deleteAll(log *output.Logger, force bool) error { - apiClient := api.New(os.Getenv("GITHUB_TOKEN")) + apiClient := api.New(GithubToken) ctx := context.Background() user, err := apiClient.GetUser(ctx) @@ -119,7 +119,7 @@ func deleteAll(log *output.Logger, force bool) error { } func deleteByRepo(log *output.Logger, repo string, force bool) error { - apiClient := api.New(os.Getenv("GITHUB_TOKEN")) + apiClient := api.New(GithubToken) ctx := context.Background() user, err := apiClient.GetUser(ctx) diff --git a/cmd/ghcs/list.go b/cmd/ghcs/list.go index 09b0ea5b0..ccc150f08 100644 --- a/cmd/ghcs/list.go +++ b/cmd/ghcs/list.go @@ -32,7 +32,7 @@ func newListCmd() *cobra.Command { } func list(opts *listOptions) error { - apiClient := api.New(os.Getenv("GITHUB_TOKEN")) + apiClient := api.New(GithubToken) ctx := context.Background() user, err := apiClient.GetUser(ctx) diff --git a/cmd/ghcs/logs.go b/cmd/ghcs/logs.go index 40cac88ed..4051cc209 100644 --- a/cmd/ghcs/logs.go +++ b/cmd/ghcs/logs.go @@ -41,7 +41,7 @@ func logs(ctx context.Context, log *output.Logger, codespaceName string, follow ctx, cancel := context.WithCancel(ctx) defer cancel() - apiClient := api.New(os.Getenv("GITHUB_TOKEN")) + apiClient := api.New(GithubToken) user, err := apiClient.GetUser(ctx) if err != nil { diff --git a/cmd/ghcs/ports.go b/cmd/ghcs/ports.go index 052dfd37c..ebfd281cd 100644 --- a/cmd/ghcs/ports.go +++ b/cmd/ghcs/ports.go @@ -48,7 +48,7 @@ func newPortsCmd() *cobra.Command { } func ports(codespaceName string, asJSON bool) error { - apiClient := api.New(os.Getenv("GITHUB_TOKEN")) + apiClient := api.New(GithubToken) ctx := context.Background() log := output.NewLogger(os.Stdout, os.Stderr, asJSON) @@ -196,7 +196,7 @@ func newPortsPrivateCmd() *cobra.Command { func updatePortVisibility(log *output.Logger, codespaceName, sourcePort string, public bool) error { ctx := context.Background() - apiClient := api.New(os.Getenv("GITHUB_TOKEN")) + apiClient := api.New(GithubToken) user, err := apiClient.GetUser(ctx) if err != nil { @@ -258,7 +258,7 @@ func newPortsForwardCmd() *cobra.Command { func forwardPorts(log *output.Logger, codespaceName string, ports []string) error { ctx := context.Background() - apiClient := api.New(os.Getenv("GITHUB_TOKEN")) + apiClient := api.New(GithubToken) portPairs, err := getPortPairs(ports) if err != nil { diff --git a/cmd/ghcs/ssh.go b/cmd/ghcs/ssh.go index 527ae120f..5063e8fc9 100644 --- a/cmd/ghcs/ssh.go +++ b/cmd/ghcs/ssh.go @@ -37,7 +37,7 @@ func ssh(ctx context.Context, sshProfile, codespaceName string, localSSHServerPo ctx, cancel := context.WithCancel(ctx) defer cancel() - apiClient := api.New(os.Getenv("GITHUB_TOKEN")) + apiClient := api.New(GithubToken) log := output.NewLogger(os.Stdout, os.Stderr, false) user, err := apiClient.GetUser(ctx) From 60d066f0a69ad86e6a09053284bc4beef924dfa0 Mon Sep 17 00:00:00 2001 From: Jose Garcia Date: Fri, 17 Sep 2021 11:51:37 -0400 Subject: [PATCH 226/290] PR Feedback - return nil for slices - handle `-L -l` case - document `parseSSHArgs` --- internal/codespaces/ssh.go | 43 +++++++++++++++++---------------- internal/codespaces/ssh_test.go | 35 ++++++++++++++++++--------- 2 files changed, 45 insertions(+), 33 deletions(-) diff --git a/internal/codespaces/ssh.go b/internal/codespaces/ssh.go index b58741e34..33fbd092a 100644 --- a/internal/codespaces/ssh.go +++ b/internal/codespaces/ssh.go @@ -45,15 +45,15 @@ func newSSHCommand(ctx context.Context, port int, dst string, cmdArgs []string) // the flags and their arities. cmdArgs, command, err := parseSSHArgs(cmdArgs) if err != nil { - return nil, []string{}, err + return nil, nil, err } cmdArgs = append(cmdArgs, connArgs...) cmdArgs = append(cmdArgs, "-C") // Compression cmdArgs = append(cmdArgs, dst) // user@host - if command != "" { - cmdArgs = append(cmdArgs, command) + if command != nil { + cmdArgs = append(cmdArgs, command...) } cmd := exec.CommandContext(ctx, "ssh", cmdArgs...) @@ -64,33 +64,34 @@ func newSSHCommand(ctx context.Context, port int, dst string, cmdArgs []string) return cmd, connArgs, nil } -var sshArgumentFlags = "-b-c-D-e-F-I-i-L-l-m-O-o-p-R-S-W-w" - -func parseSSHArgs(sshArgs []string) ([]string, string, error) { +// parseSSHArgs parses SSH arguments into two distinct slices of flags +// and command. It returns an error if flags are found after a command +// or if a unary flag is provided without an argument. +func parseSSHArgs(args []string) ([]string, []string, error) { var ( - cmdArgs []string - command []string - flagArgument bool + cmdArgs []string + command []string ) - for _, arg := range sshArgs { - switch { - case strings.HasPrefix(arg, "-"): - if len(command) > 0 { - return []string{}, "", fmt.Errorf("invalid flag after command: %s", arg) + for i := 0; i < len(args); i++ { + arg := args[i] + if strings.HasPrefix(arg, "-") { + if command != nil { + return nil, nil, fmt.Errorf("invalid flag after command: %s", arg) } cmdArgs = append(cmdArgs, arg) - if strings.Contains(sshArgumentFlags, arg) { - flagArgument = true + if strings.Contains("bcDeFIiLlmOopRSWw", arg[1:2]) { + if i++; i == len(args) { + return nil, nil, fmt.Errorf("invalid unary flag without argument: %s", arg) + } + + cmdArgs = append(cmdArgs, args[i]) } - case flagArgument: - cmdArgs = append(cmdArgs, arg) - flagArgument = false - default: + } else { command = append(command, arg) } } - return cmdArgs, strings.Join(command, " "), nil + return cmdArgs, command, nil } diff --git a/internal/codespaces/ssh_test.go b/internal/codespaces/ssh_test.go index 2847ffc9f..04d52b090 100644 --- a/internal/codespaces/ssh_test.go +++ b/internal/codespaces/ssh_test.go @@ -6,44 +6,49 @@ func TestParseSSHArgs(t *testing.T) { type testCase struct { Args []string ParsedArgs []string - Command string + Command []string } testCases := []testCase{ { Args: []string{"-X", "-Y"}, ParsedArgs: []string{"-X", "-Y"}, - Command: "", + Command: nil, }, { Args: []string{"-X", "-Y", "-o", "someoption=test"}, ParsedArgs: []string{"-X", "-Y", "-o", "someoption=test"}, - Command: "", + Command: nil, }, { Args: []string{"-X", "-Y", "-o", "someoption=test", "somecommand"}, ParsedArgs: []string{"-X", "-Y", "-o", "someoption=test"}, - Command: "somecommand", + Command: []string{"somecommand"}, }, { Args: []string{"-X", "-Y", "-o", "someoption=test", "echo", "test"}, ParsedArgs: []string{"-X", "-Y", "-o", "someoption=test"}, - Command: "echo test", + Command: []string{"echo", "test"}, }, { Args: []string{"somecommand"}, ParsedArgs: []string{}, - Command: "somecommand", + Command: []string{"somecommand"}, }, { Args: []string{"echo", "test"}, ParsedArgs: []string{}, - Command: "echo test", + Command: []string{"echo", "test"}, }, { Args: []string{"-v", "echo", "hello", "world"}, ParsedArgs: []string{"-v"}, - Command: "echo hello world", + Command: []string{"echo", "hello", "world"}, + }, + { + Args: []string{"-L", "-l"}, + ParsedArgs: []string{"-L", "-l"}, + Command: nil, }, } @@ -54,15 +59,21 @@ func TestParseSSHArgs(t *testing.T) { } if len(args) != len(tcase.ParsedArgs) { - t.Fatalf("args do not match length of expected args. %#v, got '%d', expected: '%d'", tcase, len(args), len(tcase.ParsedArgs)) + t.Fatalf("args do not match length of expected args. %#v, got '%d'", tcase, len(args)) } + if len(command) != len(tcase.Command) { + t.Fatalf("command dooes not match length of expected command. %#v, got '%d'", tcase, len(command)) + } + for i, arg := range args { if arg != tcase.ParsedArgs[i] { - t.Fatalf("arg does not match expected parsed arg. %v, got '%s', expected: '%s'", tcase, arg, tcase.ParsedArgs[i]) + t.Fatalf("arg does not match expected parsed arg. %v, got '%s'", tcase, arg) } } - if command != tcase.Command { - t.Fatalf("command does not match expected command. %v, got: '%s', expected: '%s'", tcase, command, tcase.Command) + for i, c := range command { + if c != tcase.Command[i] { + t.Fatalf("command does not match expected command. %v, got: '%v'", tcase, command) + } } } } From 54265afda00db981d030cffa6c9564161e096ca9 Mon Sep 17 00:00:00 2001 From: Jose Garcia Date: Fri, 17 Sep 2021 13:43:23 -0400 Subject: [PATCH 227/290] PR Feedback - use named returns - handle command flags + test case - simplify tests --- internal/codespaces/ssh.go | 20 +++++-------- internal/codespaces/ssh_test.go | 52 ++++++++++++++++++--------------- 2 files changed, 37 insertions(+), 35 deletions(-) diff --git a/internal/codespaces/ssh.go b/internal/codespaces/ssh.go index 33fbd092a..e99f8971d 100644 --- a/internal/codespaces/ssh.go +++ b/internal/codespaces/ssh.go @@ -67,23 +67,19 @@ func newSSHCommand(ctx context.Context, port int, dst string, cmdArgs []string) // parseSSHArgs parses SSH arguments into two distinct slices of flags // and command. It returns an error if flags are found after a command // or if a unary flag is provided without an argument. -func parseSSHArgs(args []string) ([]string, []string, error) { - var ( - cmdArgs []string - command []string - ) - +func parseSSHArgs(args []string) (cmdArgs []string, command []string, err error) { for i := 0; i < len(args); i++ { arg := args[i] - if strings.HasPrefix(arg, "-") { - if command != nil { - return nil, nil, fmt.Errorf("invalid flag after command: %s", arg) - } + if command != nil { + command = append(command, arg) + continue + } + if strings.HasPrefix(arg, "-") { cmdArgs = append(cmdArgs, arg) - if strings.Contains("bcDeFIiLlmOopRSWw", arg[1:2]) { + if len(arg) == 2 && strings.Contains("bcDeFIiLlmOopRSWw", arg[1:2]) { if i++; i == len(args) { - return nil, nil, fmt.Errorf("invalid unary flag without argument: %s", arg) + return nil, nil, fmt.Errorf("ssh flag: %s requires an argument", arg) } cmdArgs = append(cmdArgs, args[i]) diff --git a/internal/codespaces/ssh_test.go b/internal/codespaces/ssh_test.go index 04d52b090..5450adf1a 100644 --- a/internal/codespaces/ssh_test.go +++ b/internal/codespaces/ssh_test.go @@ -1,12 +1,16 @@ package codespaces -import "testing" +import ( + "fmt" + "testing" +) func TestParseSSHArgs(t *testing.T) { type testCase struct { Args []string ParsedArgs []string Command []string + Error bool } testCases := []testCase{ @@ -50,37 +54,39 @@ func TestParseSSHArgs(t *testing.T) { ParsedArgs: []string{"-L", "-l"}, Command: nil, }, + { + Args: []string{"-v", "echo", "-n", "test"}, + ParsedArgs: []string{"-v"}, + Command: []string{"echo", "-n", "test"}, + }, + { + Args: []string{"-b"}, + ParsedArgs: nil, + Command: nil, + Error: true, + }, } for _, tcase := range testCases { args, command, err := parseSSHArgs(tcase.Args) - if err != nil { - t.Errorf("received unexpected error: %w", err) + if err != nil && !tcase.Error { + t.Errorf("unexpected error: %v on test case: %#v", err, tcase) + continue } - if len(args) != len(tcase.ParsedArgs) { - t.Fatalf("args do not match length of expected args. %#v, got '%d'", tcase, len(args)) - } - if len(command) != len(tcase.Command) { - t.Fatalf("command dooes not match length of expected command. %#v, got '%d'", tcase, len(command)) + if tcase.Error && err == nil { + t.Errorf("expected error and got nil: %#v", tcase) + continue } - for i, arg := range args { - if arg != tcase.ParsedArgs[i] { - t.Fatalf("arg does not match expected parsed arg. %v, got '%s'", tcase, arg) - } + argsStr, parsedArgsStr := fmt.Sprintf("%s", args), fmt.Sprintf("%s", tcase.ParsedArgs) + if argsStr != parsedArgsStr { + t.Errorf("args do not match parsed args. got: '%s', expected: '%s'", argsStr, parsedArgsStr) } - for i, c := range command { - if c != tcase.Command[i] { - t.Fatalf("command does not match expected command. %v, got: '%v'", tcase, command) - } + + commandStr, parsedCommandStr := fmt.Sprintf("%s", command), fmt.Sprintf("%s", tcase.Command) + if commandStr != parsedCommandStr { + t.Errorf("command does not match parsed command. got: '%s', expected: '%s'", commandStr, parsedCommandStr) } } } - -func TestParseSSHArgsError(t *testing.T) { - _, _, err := parseSSHArgs([]string{"-X", "test", "-Y"}) - if err == nil { - t.Error("expected an error for invalid args") - } -} From 76037ee75367125bdea702aa92885947cff3973c Mon Sep 17 00:00:00 2001 From: Jose Garcia Date: Fri, 17 Sep 2021 13:54:00 -0400 Subject: [PATCH 228/290] Update docs, simplify loop to append to command --- internal/codespaces/ssh.go | 16 +++++++--------- internal/codespaces/ssh_test.go | 2 +- 2 files changed, 8 insertions(+), 10 deletions(-) diff --git a/internal/codespaces/ssh.go b/internal/codespaces/ssh.go index e99f8971d..4cd0a4c92 100644 --- a/internal/codespaces/ssh.go +++ b/internal/codespaces/ssh.go @@ -64,16 +64,11 @@ func newSSHCommand(ctx context.Context, port int, dst string, cmdArgs []string) return cmd, connArgs, nil } -// parseSSHArgs parses SSH arguments into two distinct slices of flags -// and command. It returns an error if flags are found after a command -// or if a unary flag is provided without an argument. +// parseSSHArgs parses SSH arguments into two distinct slices of flags and command. +// It returns an error if a unary flag is provided without an argument. func parseSSHArgs(args []string) (cmdArgs []string, command []string, err error) { for i := 0; i < len(args); i++ { arg := args[i] - if command != nil { - command = append(command, arg) - continue - } if strings.HasPrefix(arg, "-") { cmdArgs = append(cmdArgs, arg) @@ -84,9 +79,12 @@ func parseSSHArgs(args []string) (cmdArgs []string, command []string, err error) cmdArgs = append(cmdArgs, args[i]) } - } else { - command = append(command, arg) + continue } + + // if we've started parsing the command, append all further args to it + command = append(command, args[i:]...) + break } return cmdArgs, command, nil diff --git a/internal/codespaces/ssh_test.go b/internal/codespaces/ssh_test.go index 5450adf1a..ed6922762 100644 --- a/internal/codespaces/ssh_test.go +++ b/internal/codespaces/ssh_test.go @@ -69,7 +69,7 @@ func TestParseSSHArgs(t *testing.T) { for _, tcase := range testCases { args, command, err := parseSSHArgs(tcase.Args) - if err != nil && !tcase.Error { + if !tcase.Error && err != nil { t.Errorf("unexpected error: %v on test case: %#v", err, tcase) continue } From 65e1c6f789fb52415b74b6a3b5d33787732b2ab4 Mon Sep 17 00:00:00 2001 From: Jose Garcia Date: Fri, 17 Sep 2021 13:56:38 -0400 Subject: [PATCH 229/290] More test cases --- internal/codespaces/ssh_test.go | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/internal/codespaces/ssh_test.go b/internal/codespaces/ssh_test.go index ed6922762..c3e1b4c0a 100644 --- a/internal/codespaces/ssh_test.go +++ b/internal/codespaces/ssh_test.go @@ -14,6 +14,7 @@ func TestParseSSHArgs(t *testing.T) { } testCases := []testCase{ + {}, // empty test case { Args: []string{"-X", "-Y"}, ParsedArgs: []string{"-X", "-Y"}, @@ -59,6 +60,11 @@ func TestParseSSHArgs(t *testing.T) { ParsedArgs: []string{"-v"}, Command: []string{"echo", "-n", "test"}, }, + { + Args: []string{"-v", "echo", "-b", "test"}, + ParsedArgs: []string{"-v"}, + Command: []string{"echo", "-b", "test"}, + }, { Args: []string{"-b"}, ParsedArgs: nil, From 9f84015bd010818416442935717ae174c1072098 Mon Sep 17 00:00:00 2001 From: Jose Garcia Date: Fri, 17 Sep 2021 14:00:16 -0400 Subject: [PATCH 230/290] Avoid append --- internal/codespaces/ssh.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/internal/codespaces/ssh.go b/internal/codespaces/ssh.go index 4cd0a4c92..6563db91a 100644 --- a/internal/codespaces/ssh.go +++ b/internal/codespaces/ssh.go @@ -82,8 +82,8 @@ func parseSSHArgs(args []string) (cmdArgs []string, command []string, err error) continue } - // if we've started parsing the command, append all further args to it - command = append(command, args[i:]...) + // if we've started parsing the command, set it to the rest of the args + command = args[i:] break } From da58313358f62a478794fa1337841dcc00aab065 Mon Sep 17 00:00:00 2001 From: Jose Garcia Date: Fri, 17 Sep 2021 14:03:31 -0400 Subject: [PATCH 231/290] Remove redudant type def --- internal/codespaces/ssh.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/internal/codespaces/ssh.go b/internal/codespaces/ssh.go index 6563db91a..4c6ccb6d7 100644 --- a/internal/codespaces/ssh.go +++ b/internal/codespaces/ssh.go @@ -66,7 +66,7 @@ func newSSHCommand(ctx context.Context, port int, dst string, cmdArgs []string) // parseSSHArgs parses SSH arguments into two distinct slices of flags and command. // It returns an error if a unary flag is provided without an argument. -func parseSSHArgs(args []string) (cmdArgs []string, command []string, err error) { +func parseSSHArgs(args []string) (cmdArgs, command []string, err error) { for i := 0; i < len(args); i++ { arg := args[i] From 5890d6ad66ee56899ad497fe503d987adebfe744 Mon Sep 17 00:00:00 2001 From: Jose Garcia Date: Fri, 17 Sep 2021 15:04:55 -0400 Subject: [PATCH 232/290] Switch if block logic, assert err string --- internal/codespaces/ssh.go | 25 ++++++++++++------------- internal/codespaces/ssh_test.go | 19 +++++++++++++------ 2 files changed, 25 insertions(+), 19 deletions(-) diff --git a/internal/codespaces/ssh.go b/internal/codespaces/ssh.go index 4c6ccb6d7..36c8bf5b2 100644 --- a/internal/codespaces/ssh.go +++ b/internal/codespaces/ssh.go @@ -70,21 +70,20 @@ func parseSSHArgs(args []string) (cmdArgs, command []string, err error) { for i := 0; i < len(args); i++ { arg := args[i] - if strings.HasPrefix(arg, "-") { - cmdArgs = append(cmdArgs, arg) - if len(arg) == 2 && strings.Contains("bcDeFIiLlmOopRSWw", arg[1:2]) { - if i++; i == len(args) { - return nil, nil, fmt.Errorf("ssh flag: %s requires an argument", arg) - } - - cmdArgs = append(cmdArgs, args[i]) - } - continue + // if we've started parsing the command, set it to the rest of the args + if !strings.HasPrefix(arg, "-") { + command = args[i:] + break } - // if we've started parsing the command, set it to the rest of the args - command = args[i:] - break + cmdArgs = append(cmdArgs, arg) + if len(arg) == 2 && strings.Contains("bcDeFIiLlmOopRSWw", arg[1:2]) { + if i++; i == len(args) { + return nil, nil, fmt.Errorf("ssh flag: %s requires an argument", arg) + } + + cmdArgs = append(cmdArgs, args[i]) + } } return cmdArgs, command, nil diff --git a/internal/codespaces/ssh_test.go b/internal/codespaces/ssh_test.go index c3e1b4c0a..c804f6000 100644 --- a/internal/codespaces/ssh_test.go +++ b/internal/codespaces/ssh_test.go @@ -10,7 +10,7 @@ func TestParseSSHArgs(t *testing.T) { Args []string ParsedArgs []string Command []string - Error bool + Error string } testCases := []testCase{ @@ -69,19 +69,26 @@ func TestParseSSHArgs(t *testing.T) { Args: []string{"-b"}, ParsedArgs: nil, Command: nil, - Error: true, + Error: "ssh flag: -b requires an argument", }, } for _, tcase := range testCases { args, command, err := parseSSHArgs(tcase.Args) - if !tcase.Error && err != nil { - t.Errorf("unexpected error: %v on test case: %#v", err, tcase) + if tcase.Error != "" { + if err == nil { + t.Errorf("expected error and got nil: %#v", tcase) + } + + if err.Error() != tcase.Error { + t.Errorf("error does not match expected error, got: '%s', expected: '%s'", err.Error(), tcase.Error) + } + continue } - if tcase.Error && err == nil { - t.Errorf("expected error and got nil: %#v", tcase) + if err != nil { + t.Errorf("unexpected error: %v on test case: %#v", err, tcase) continue } From 47c6a5fce818b6681edec6a39d292f9387c0d008 Mon Sep 17 00:00:00 2001 From: Jose Garcia Date: Fri, 17 Sep 2021 15:13:09 -0400 Subject: [PATCH 233/290] Update usage --- cmd/ghcs/ssh.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cmd/ghcs/ssh.go b/cmd/ghcs/ssh.go index cdfcdbcbe..390fc25c8 100644 --- a/cmd/ghcs/ssh.go +++ b/cmd/ghcs/ssh.go @@ -18,7 +18,7 @@ func newSSHCmd() *cobra.Command { var sshServerPort int sshCmd := &cobra.Command{ - Use: "ssh", + Use: "ssh [flags] -- [ssh-flags] [command]", Short: "SSH into a codespace", RunE: func(cmd *cobra.Command, args []string) error { return ssh(context.Background(), args, sshProfile, codespaceName, sshServerPort) From 82c19729d3ce9fcc0c604c811ddd609ed6190f5e Mon Sep 17 00:00:00 2001 From: Jose Garcia Date: Fri, 17 Sep 2021 15:17:38 -0400 Subject: [PATCH 234/290] Wrap -- with optional argument brackets --- cmd/ghcs/ssh.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cmd/ghcs/ssh.go b/cmd/ghcs/ssh.go index 390fc25c8..e459f265a 100644 --- a/cmd/ghcs/ssh.go +++ b/cmd/ghcs/ssh.go @@ -18,7 +18,7 @@ func newSSHCmd() *cobra.Command { var sshServerPort int sshCmd := &cobra.Command{ - Use: "ssh [flags] -- [ssh-flags] [command]", + Use: "ssh [flags] [--] [ssh-flags] [command]", Short: "SSH into a codespace", RunE: func(cmd *cobra.Command, args []string) error { return ssh(context.Background(), args, sshProfile, codespaceName, sshServerPort) From 11024f71fabc349404b1115884eb26e0ecf5ea2b Mon Sep 17 00:00:00 2001 From: Raffaele Di Fazio Date: Mon, 20 Sep 2021 10:27:29 +0200 Subject: [PATCH 235/290] force is not used in delete by repo Signed-off-by: Raffaele Di Fazio --- cmd/ghcs/delete.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cmd/ghcs/delete.go b/cmd/ghcs/delete.go index 97864664a..5aff182c5 100644 --- a/cmd/ghcs/delete.go +++ b/cmd/ghcs/delete.go @@ -37,7 +37,7 @@ func newDeleteCmd() *cobra.Command { case allCodespaces: return deleteAll(log, force, keepThresholdDays) case repo != "": - return deleteByRepo(log, repo, force) + return deleteByRepo(log, repo) default: return delete_(log, codespace, force) } @@ -133,7 +133,7 @@ func deleteAll(log *output.Logger, force bool, keepThresholdDays int) error { return list(&listOptions{}) } -func deleteByRepo(log *output.Logger, repo string, force bool) error { +func deleteByRepo(log *output.Logger, repo string) error { apiClient := api.New(os.Getenv("GITHUB_TOKEN")) ctx := context.Background() From 4721e7004be64656c693901ae87a236ee646cd51 Mon Sep 17 00:00:00 2001 From: Raffaele Di Fazio Date: Mon, 20 Sep 2021 11:10:44 +0200 Subject: [PATCH 236/290] add threshold to delete by repo Signed-off-by: Raffaele Di Fazio --- cmd/ghcs/delete.go | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/cmd/ghcs/delete.go b/cmd/ghcs/delete.go index 5aff182c5..b91962f92 100644 --- a/cmd/ghcs/delete.go +++ b/cmd/ghcs/delete.go @@ -37,7 +37,7 @@ func newDeleteCmd() *cobra.Command { case allCodespaces: return deleteAll(log, force, keepThresholdDays) case repo != "": - return deleteByRepo(log, repo) + return deleteByRepo(log, repo, keepThresholdDays) default: return delete_(log, codespace, force) } @@ -133,7 +133,7 @@ func deleteAll(log *output.Logger, force bool, keepThresholdDays int) error { return list(&listOptions{}) } -func deleteByRepo(log *output.Logger, repo string) error { +func deleteByRepo(log *output.Logger, repo string, keepThresholdDays int) error { apiClient := api.New(os.Getenv("GITHUB_TOKEN")) ctx := context.Background() @@ -147,6 +147,11 @@ func deleteByRepo(log *output.Logger, repo string) error { return fmt.Errorf("error getting codespaces: %w", err) } + codespaces, err = filterCodespacesToDelete(codespaces, keepThresholdDays) + if err != nil { + return err + } + delete := func(name string) error { token, err := apiClient.GetCodespaceToken(ctx, user.Login, name) if err != nil { From c4f0eda96d18199431b66c83c75ab954e13af685 Mon Sep 17 00:00:00 2001 From: Raffaele Di Fazio Date: Mon, 20 Sep 2021 11:54:30 +0200 Subject: [PATCH 237/290] force was actually needed by a next commit Signed-off-by: Raffaele Di Fazio --- cmd/ghcs/delete.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cmd/ghcs/delete.go b/cmd/ghcs/delete.go index 621a2050b..c5bd53f98 100644 --- a/cmd/ghcs/delete.go +++ b/cmd/ghcs/delete.go @@ -37,7 +37,7 @@ func newDeleteCmd() *cobra.Command { case allCodespaces: return deleteAll(log, force, keepThresholdDays) case repo != "": - return deleteByRepo(log, repo, keepThresholdDays) + return deleteByRepo(log, repo, force, keepThresholdDays) default: return delete_(log, codespace, force) } @@ -133,7 +133,7 @@ func deleteAll(log *output.Logger, force bool, keepThresholdDays int) error { return list(&listOptions{}) } -func deleteByRepo(log *output.Logger, repo string, keepThresholdDays int) error { +func deleteByRepo(log *output.Logger, repo string, force bool, keepThresholdDays int) error { apiClient := api.New(os.Getenv("GITHUB_TOKEN")) ctx := context.Background() From 57d04dc5f020ebefbe080e1fa6873dcded731d7a Mon Sep 17 00:00:00 2001 From: Jose Garcia Date: Mon, 20 Sep 2021 13:16:38 +0000 Subject: [PATCH 238/290] Allow clients to Close a Session, general tidy up - Allow clients to call Close on a Session to clean up resources - Switch to the %w verb for error wrapping - Fix typo on Port struct after verifying the server does not have a typo --- client.go | 14 +++++++------- port_forwarder.go | 4 ++-- rpc.go | 3 +-- session.go | 18 +++++++++++++++--- ssh.go | 8 ++++---- 5 files changed, 29 insertions(+), 18 deletions(-) diff --git a/client.go b/client.go index 65e80a94a..ba9d2f5e7 100644 --- a/client.go +++ b/client.go @@ -58,18 +58,18 @@ func (c *Client) JoinWorkspace(ctx context.Context) (*Session, error) { clientSocket := newSocket(c.connection, c.tlsConfig) if err := clientSocket.connect(ctx); err != nil { - return nil, fmt.Errorf("error connecting websocket: %v", err) + return nil, fmt.Errorf("error connecting websocket: %w", err) } ssh := newSSHSession(c.connection.SessionToken, clientSocket) if err := ssh.connect(ctx); err != nil { - return nil, fmt.Errorf("error connecting to ssh session: %v", err) + return nil, fmt.Errorf("error connecting to ssh session: %w", err) } rpc := newRPCClient(ssh) rpc.connect(ctx) if _, err := c.joinWorkspace(ctx, rpc); err != nil { - return nil, fmt.Errorf("error joining Live Share workspace: %v", err) + return nil, fmt.Errorf("error joining Live Share workspace: %w", err) } return &Session{ssh: ssh, rpc: rpc}, nil @@ -108,7 +108,7 @@ func (c *Client) joinWorkspace(ctx context.Context, rpc *rpcClient) (*joinWorksp var result joinWorkspaceResult if err := rpc.do(ctx, "workspace.joinWorkspace", &args, &result); err != nil { - return nil, fmt.Errorf("error making workspace.joinWorkspace call: %v", err) + return nil, fmt.Errorf("error making workspace.joinWorkspace call: %w", err) } return &result, nil @@ -125,7 +125,7 @@ func (s *Session) openStreamingChannel(ctx context.Context, id channelID) (ssh.C } var streamID string if err := s.rpc.do(ctx, "streamManager.getStream", args, &streamID); err != nil { - return nil, fmt.Errorf("error getting stream id: %v", err) + return nil, fmt.Errorf("error getting stream id: %w", err) } span, ctx := opentracing.StartSpanFromContext(ctx, "Session.OpenChannel+SendRequest") @@ -133,13 +133,13 @@ func (s *Session) openStreamingChannel(ctx context.Context, id channelID) (ssh.C channel, reqs, err := s.ssh.conn.OpenChannel("session", nil) if err != nil { - return nil, fmt.Errorf("error opening ssh channel for transport: %v", err) + return nil, fmt.Errorf("error opening ssh channel for transport: %w", err) } go ssh.DiscardRequests(reqs) requestType := fmt.Sprintf("stream-transport-%s", streamID) if _, err = channel.SendRequest(requestType, true, nil); err != nil { - return nil, fmt.Errorf("error sending channel request: %v", err) + return nil, fmt.Errorf("error sending channel request: %w", err) } return channel, nil diff --git a/port_forwarder.go b/port_forwarder.go index 5dafd0c65..56401cc4d 100644 --- a/port_forwarder.go +++ b/port_forwarder.go @@ -92,7 +92,7 @@ func (fwd *PortForwarder) Forward(ctx context.Context, conn io.ReadWriteCloser) func (fwd *PortForwarder) shareRemotePort(ctx context.Context) (channelID, error) { id, err := fwd.session.startSharing(ctx, fwd.name, fwd.remotePort) if err != nil { - err = fmt.Errorf("failed to share remote port %d: %v", fwd.remotePort, err) + err = fmt.Errorf("failed to share remote port %d: %w", fwd.remotePort, err) } return id, nil } @@ -115,7 +115,7 @@ func (fwd *PortForwarder) handleConnection(ctx context.Context, id channelID, co channel, err := fwd.session.openStreamingChannel(ctx, id) if err != nil { - return fmt.Errorf("error opening streaming channel for new connection: %v", err) + return fmt.Errorf("error opening streaming channel for new connection: %w", err) } // Ideally we would call safeClose again, but (*ssh.channel).Close // appears to have a bug that causes it return io.EOF spuriously diff --git a/rpc.go b/rpc.go index 68e187ad6..bfd214c89 100644 --- a/rpc.go +++ b/rpc.go @@ -20,7 +20,6 @@ func newRPCClient(conn io.ReadWriteCloser) *rpcClient { func (r *rpcClient) connect(ctx context.Context) { stream := jsonrpc2.NewBufferedStream(r.conn, jsonrpc2.VSCodeObjectCodec{}) - // TODO(adonovan): fix: ensure r.Conn is eventually Closed! r.Conn = jsonrpc2.NewConn(ctx, stream, nullHandler{}) } @@ -30,7 +29,7 @@ func (r *rpcClient) do(ctx context.Context, method string, args, result interfac waiter, err := r.Conn.DispatchCall(ctx, method, args) if err != nil { - return fmt.Errorf("error dispatching %q call: %v", method, err) + return fmt.Errorf("error dispatching %q call: %w", method, err) } return waiter.Wait(ctx, result) diff --git a/session.go b/session.go index f427fac6d..6a078da7e 100644 --- a/session.go +++ b/session.go @@ -12,6 +12,20 @@ type Session struct { rpc *rpcClient } +// Close should be called by users to clean up RPC and SSH resources whenever the session +// is no longer active. +func (s *Session) Close() error { + if err := s.rpc.Close(); err != nil { + return fmt.Errorf("failed to close RPC conn: %w", err) + } + + if err := s.ssh.Close(); err != nil { + return fmt.Errorf("failed to close SSH conn: %w", err) + } + + return nil +} + // Port describes a port exposed by the container. type Port struct { SourcePort int `json:"sourcePort"` @@ -22,9 +36,7 @@ type Port struct { BrowseURL string `json:"browseUrl"` IsPublic bool `json:"isPublic"` IsTCPServerConnectionEstablished bool `json:"isTCPServerConnectionEstablished"` - HasTSLHandshakePassed bool `json:"hasTSLHandshakePassed"` - // ^^^ - // TODO(adonovan): fix possible typo in field name, and audit others. + HasTLSHandshakePassed bool `json:"hasTLSHandshakePassed"` } // startSharing tells the Live Share host to start sharing the specified port from the container. diff --git a/ssh.go b/ssh.go index b68d400a1..15f67d2a4 100644 --- a/ssh.go +++ b/ssh.go @@ -36,24 +36,24 @@ func (s *sshSession) connect(ctx context.Context) error { sshClientConn, chans, reqs, err := ssh.NewClientConn(s.socket, "", &clientConfig) if err != nil { - return fmt.Errorf("error creating ssh client connection: %v", err) + return fmt.Errorf("error creating ssh client connection: %w", err) } s.conn = sshClientConn sshClient := ssh.NewClient(sshClientConn, chans, reqs) s.Session, err = sshClient.NewSession() if err != nil { - return fmt.Errorf("error creating ssh client session: %v", err) + return fmt.Errorf("error creating ssh client session: %w", err) } s.reader, err = s.Session.StdoutPipe() if err != nil { - return fmt.Errorf("error creating ssh session reader: %v", err) + return fmt.Errorf("error creating ssh session reader: %w", err) } s.writer, err = s.Session.StdinPipe() if err != nil { - return fmt.Errorf("error creating ssh session writer: %v", err) + return fmt.Errorf("error creating ssh session writer: %w", err) } return nil From c222c3d696ef599229a47b20a023aa2ca2ecfdea Mon Sep 17 00:00:00 2001 From: Raffaele Di Fazio Date: Mon, 20 Sep 2021 18:23:00 +0200 Subject: [PATCH 239/290] drop check on shut down Signed-off-by: Raffaele Di Fazio --- cmd/ghcs/delete.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cmd/ghcs/delete.go b/cmd/ghcs/delete.go index c5bd53f98..dbbbae814 100644 --- a/cmd/ghcs/delete.go +++ b/cmd/ghcs/delete.go @@ -261,7 +261,7 @@ func filterCodespacesToDelete(codespaces []*api.Codespace, keepThresholdDays int if err != nil { return nil, fmt.Errorf("error parsing last used at date: %w", err) } - if t.Before(now().AddDate(0, 0, -keepThresholdDays)) && codespace.Environment.State == "Shutdown" { + if t.Before(now().AddDate(0, 0, -keepThresholdDays)) { codespacesToDelete = append(codespacesToDelete, codespace) } } From b894d3e1340da8aeaf8b83df77a88cc85b1f169a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mislav=20Marohni=C4=87?= Date: Mon, 20 Sep 2021 18:37:00 +0200 Subject: [PATCH 240/290] Simplify delete implementation --- cmd/ghcs/common.go | 3 + cmd/ghcs/delete.go | 263 +++++++++++++-------------------------------- 2 files changed, 79 insertions(+), 187 deletions(-) diff --git a/cmd/ghcs/common.go b/cmd/ghcs/common.go index e71e3dfe4..46a4f8c0b 100644 --- a/cmd/ghcs/common.go +++ b/cmd/ghcs/common.go @@ -22,7 +22,10 @@ func chooseCodespace(ctx context.Context, apiClient *api.API, user *api.User) (* if err != nil { return nil, fmt.Errorf("error getting codespaces: %w", err) } + return chooseCodespaceFromList(ctx, codespaces) +} +func chooseCodespaceFromList(ctx context.Context, codespaces []*api.Codespace) (*api.Codespace, error) { if len(codespaces) == 0 { return nil, errNoCodespaces } diff --git a/cmd/ghcs/delete.go b/cmd/ghcs/delete.go index c5bd53f98..fdb813c83 100644 --- a/cmd/ghcs/delete.go +++ b/cmd/ghcs/delete.go @@ -2,53 +2,57 @@ package main import ( "context" - "errors" "fmt" "os" "strings" - "sync" "time" "github.com/AlecAivazis/survey/v2" "github.com/github/ghcs/cmd/ghcs/output" "github.com/github/ghcs/internal/api" "github.com/spf13/cobra" + "golang.org/x/sync/errgroup" ) -var now func() time.Time = time.Now +type deleteOptions struct { + deleteAll bool + skipConfirm bool + isInteractive bool + codespaceName string + repoFilter string + keepDays uint16 + now func() time.Time + apiClient *api.API +} func newDeleteCmd() *cobra.Command { - var ( - codespace string - allCodespaces bool - repo string - force bool - keepThresholdDays int - ) + opts := deleteOptions{ + apiClient: api.New(os.Getenv("GITHUB_TOKEN")), + now: time.Now, + isInteractive: hasTTY, + } - log := output.NewLogger(os.Stdout, os.Stderr, false) deleteCmd := &cobra.Command{ Use: "delete", Short: "Delete a codespace", RunE: func(cmd *cobra.Command, args []string) error { - switch { - case allCodespaces && repo != "": - return errors.New("both --all and --repo is not supported") - case allCodespaces: - return deleteAll(log, force, keepThresholdDays) - case repo != "": - return deleteByRepo(log, repo, force, keepThresholdDays) - default: - return delete_(log, codespace, force) - } + // switch { + // case allCodespaces && repo != "": + // return errors.New("both --all and --repo is not supported") + // case allCodespaces: + // return deleteAll(log, force, keepThresholdDays) + // case repo != "": + // return deleteByRepo(log, repo, force, keepThresholdDays) + log := output.NewLogger(os.Stdout, os.Stderr, false) + return delete(context.Background(), log, opts) }, } - deleteCmd.Flags().StringVarP(&codespace, "codespace", "c", "", "Name of the codespace") - deleteCmd.Flags().BoolVar(&allCodespaces, "all", false, "Delete all codespaces") - deleteCmd.Flags().StringVarP(&repo, "repo", "r", "", "Delete all codespaces for a repository") - deleteCmd.Flags().BoolVarP(&force, "force", "f", false, "Delete codespaces with unsaved changes without confirmation") - deleteCmd.Flags().IntVar(&keepThresholdDays, "days", 0, "Minimum number of days since the codespace was created") + deleteCmd.Flags().StringVarP(&opts.codespaceName, "codespace", "c", "", "Delete codespace by `name`") + deleteCmd.Flags().BoolVar(&opts.deleteAll, "all", false, "Delete all codespaces") + deleteCmd.Flags().StringVarP(&opts.repoFilter, "repo", "r", "", "Delete codespaces for a repository") + deleteCmd.Flags().BoolVarP(&opts.skipConfirm, "force", "f", false, "Skip confirmation for codespaces that contain unsaved changes") + deleteCmd.Flags().Uint16Var(&opts.keepDays, "days", 0, "Delete codespaces older than `N` days") return deleteCmd } @@ -57,175 +61,78 @@ func init() { rootCmd.AddCommand(newDeleteCmd()) } -func delete_(log *output.Logger, codespaceName string, force bool) error { - apiClient := api.New(os.Getenv("GITHUB_TOKEN")) - ctx := context.Background() - - user, err := apiClient.GetUser(ctx) +func delete(ctx context.Context, log *output.Logger, opts deleteOptions) error { + user, err := opts.apiClient.GetUser(ctx) if err != nil { return fmt.Errorf("error getting user: %w", err) } - codespace, token, err := getOrChooseCodespace(ctx, apiClient, user, codespaceName) - if err != nil { - return fmt.Errorf("get or choose codespace: %w", err) - } - - confirmed, err := confirmDeletion(codespace, force) - if err != nil { - return fmt.Errorf("deletion could not be confirmed: %w", err) - } - - if !confirmed { - return nil - } - - if err := apiClient.DeleteCodespace(ctx, user, token, codespace.Name); err != nil { - return fmt.Errorf("error deleting codespace: %w", err) - } - - log.Println("Codespace deleted.") - - return list(&listOptions{}) -} - -func deleteAll(log *output.Logger, force bool, keepThresholdDays int) error { - apiClient := api.New(os.Getenv("GITHUB_TOKEN")) - ctx := context.Background() - - user, err := apiClient.GetUser(ctx) - if err != nil { - return fmt.Errorf("error getting user: %w", err) - } - - codespaces, err := apiClient.ListCodespaces(ctx, user) + codespaces, err := opts.apiClient.ListCodespaces(ctx, user) if err != nil { return fmt.Errorf("error getting codespaces: %w", err) } - codespacesToDelete, err := filterCodespacesToDelete(codespaces, keepThresholdDays) - if err != nil { - return err - } - - for _, c := range codespacesToDelete { - confirmed, err := confirmDeletion(c, force) + nameFilter := opts.codespaceName + if nameFilter == "" && !opts.deleteAll && opts.repoFilter == "" { + c, err := chooseCodespaceFromList(ctx, codespaces) if err != nil { - return fmt.Errorf("deletion could not be confirmed: %w", err) + return fmt.Errorf("error choosing codespace: %w", err) } - - if !confirmed { - continue - } - - token, err := apiClient.GetCodespaceToken(ctx, user.Login, c.Name) - if err != nil { - return fmt.Errorf("error getting codespace token: %w", err) - } - - if err := apiClient.DeleteCodespace(ctx, user, token, c.Name); err != nil { - return fmt.Errorf("error deleting codespace: %w", err) - } - - log.Printf("Codespace deleted: %s\n", c.Name) + nameFilter = c.Name } - return list(&listOptions{}) -} - -func deleteByRepo(log *output.Logger, repo string, force bool, keepThresholdDays int) error { - apiClient := api.New(os.Getenv("GITHUB_TOKEN")) - ctx := context.Background() - - user, err := apiClient.GetUser(ctx) - if err != nil { - return fmt.Errorf("error getting user: %w", err) - } - - codespaces, err := apiClient.ListCodespaces(ctx, user) - if err != nil { - return fmt.Errorf("error getting codespaces: %w", err) - } - - codespaces, err = filterCodespacesToDelete(codespaces, keepThresholdDays) - if err != nil { - return err - } - - delete := func(name string) error { - token, err := apiClient.GetCodespaceToken(ctx, user.Login, name) - if err != nil { - return fmt.Errorf("error getting codespace token: %w", err) - } - - if err := apiClient.DeleteCodespace(ctx, user, token, name); err != nil { - return fmt.Errorf("error deleting codespace: %w", err) - } - - return nil - } - - // Perform deletions in parallel, for performance, - // and to ensure all are attempted even if any one fails. - var ( - found bool - mu sync.Mutex // guards errs, logger - errs []error - wg sync.WaitGroup - ) + var codespacesToDelete []*api.Codespace + lastUpdatedCutoffTime := opts.now().AddDate(0, 0, -int(opts.keepDays)) for _, c := range codespaces { - if !strings.EqualFold(c.RepositoryNWO, repo) { + if nameFilter != "" && c.Name != nameFilter { continue } - - confirmed, err := confirmDeletion(c, force) - if err != nil { - mu.Lock() - errs = append(errs, fmt.Errorf("deletion could not be confirmed: %w", err)) - mu.Unlock() + if opts.repoFilter != "" && !strings.EqualFold(c.RepositoryNWO, opts.repoFilter) { continue } - - if !confirmed { - continue - } - - found = true - c := c - wg.Add(1) - go func() { - defer wg.Done() - err := delete(c.Name) - mu.Lock() - defer mu.Unlock() + if opts.keepDays > 0 { + t, err := time.Parse(time.RFC3339, c.LastUsedAt) if err != nil { - errs = append(errs, err) - } else { - log.Printf("Codespace deleted: %s\n", c.Name) + return fmt.Errorf("error parsing last_used_at timestamp %q: %w", c.LastUsedAt, err) + } + if t.After(lastUpdatedCutoffTime) { + continue } - }() - } - if !found { - return fmt.Errorf("no codespace was found for repository: %s", repo) - } - wg.Wait() - - // Return first error, plus count of others. - if errs != nil { - err := errs[0] - if others := len(errs) - 1; others > 0 { - err = fmt.Errorf("%w (+%d more)", err, others) } - return err + if nameFilter == "" || !opts.skipConfirm { + confirmed, err := confirmDeletion(c) + if err != nil { + return fmt.Errorf("deletion could not be confirmed: %w", err) + } + if !confirmed { + continue + } + } + codespacesToDelete = append(codespacesToDelete, c) } - return nil + g := errgroup.Group{} + for _, c := range codespacesToDelete { + codespaceName := c.Name + g.Go(func() error { + token, err := opts.apiClient.GetCodespaceToken(ctx, user.Login, codespaceName) + if err != nil { + return fmt.Errorf("error getting codespace token: %w", err) + } + if err := opts.apiClient.DeleteCodespace(ctx, user, token, codespaceName); err != nil { + return fmt.Errorf("error deleting codespace: %w", err) + } + return nil + }) + } + + return g.Wait() } -func confirmDeletion(codespace *api.Codespace, force bool) (bool, error) { +func confirmDeletion(codespace *api.Codespace) (bool, error) { gs := codespace.Environment.GitStatus hasUnsavedChanges := gs.HasUncommitedChanges || gs.HasUnpushedChanges - if force || !hasUnsavedChanges { + if !hasUnsavedChanges { return true, nil } if !hasTTY { @@ -249,21 +156,3 @@ func confirmDeletion(codespace *api.Codespace, force bool) (bool, error) { return confirmed.Confirmed, nil } - -func filterCodespacesToDelete(codespaces []*api.Codespace, keepThresholdDays int) ([]*api.Codespace, error) { - if keepThresholdDays < 0 { - return nil, fmt.Errorf("invalid value for threshold: %d", keepThresholdDays) - } - codespacesToDelete := []*api.Codespace{} - for _, codespace := range codespaces { - // get a date from a string representation - t, err := time.Parse(time.RFC3339, codespace.LastUsedAt) - if err != nil { - return nil, fmt.Errorf("error parsing last used at date: %w", err) - } - if t.Before(now().AddDate(0, 0, -keepThresholdDays)) && codespace.Environment.State == "Shutdown" { - codespacesToDelete = append(codespacesToDelete, codespace) - } - } - return codespacesToDelete, nil -} From 9e08b7477da09d6ccb717716fa2c5874579a72a1 Mon Sep 17 00:00:00 2001 From: Alan Donovan Date: Mon, 20 Sep 2021 13:40:45 -0400 Subject: [PATCH 241/290] delete: reject position args --- cmd/ghcs/delete.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/cmd/ghcs/delete.go b/cmd/ghcs/delete.go index eb00e567f..e6c5c9f53 100644 --- a/cmd/ghcs/delete.go +++ b/cmd/ghcs/delete.go @@ -27,6 +27,9 @@ func newDeleteCmd() *cobra.Command { Use: "delete", Short: "Delete a codespace", RunE: func(cmd *cobra.Command, args []string) error { + if len(args) > 0 { + return fmt.Errorf("delete: unexpected positional arguments") + } switch { case allCodespaces && repo != "": return errors.New("both --all and --repo is not supported") From dbb80d8b1ef8bf2289c9d6614dbf95d4ad6fba0d Mon Sep 17 00:00:00 2001 From: Alan Donovan Date: Mon, 20 Sep 2021 16:01:43 -0400 Subject: [PATCH 242/290] check for authorised SSH keys --- cmd/ghcs/ssh.go | 21 +++++++++++++++++++++ internal/api/api.go | 26 ++++++++++++++++++++++++++ 2 files changed, 47 insertions(+) diff --git a/cmd/ghcs/ssh.go b/cmd/ghcs/ssh.go index 6d5f2376b..3967f5512 100644 --- a/cmd/ghcs/ssh.go +++ b/cmd/ghcs/ssh.go @@ -49,6 +49,23 @@ func ssh(ctx context.Context, sshProfile, codespaceName string, localSSHServerPo return fmt.Errorf("error getting user: %w", err) } + // Check whether the user has registered any SSH keys. + // See https://github.com/github/ghcs/issues/166#issuecomment-921769703 + checkAuthKeys := func(user string) error { + keys, err := apiClient.AuthorizedKeys(ctx, user) + if err != nil { + return fmt.Errorf("failed to read GitHub-authorized SSH keys for %s: %w", user, err) + } + if len(keys) == 0 { + return fmt.Errorf("user %s has no GitHub-authorized SSH keys", user) + } + return nil // success + } + authkeys := make(chan error, 1) + go func() { + authkeys <- checkAuthKeys(user.Login) + }() + codespace, token, err := getOrChooseCodespace(ctx, apiClient, user, codespaceName) if err != nil { return fmt.Errorf("get or choose codespace: %w", err) @@ -59,6 +76,10 @@ func ssh(ctx context.Context, sshProfile, codespaceName string, localSSHServerPo return fmt.Errorf("error connecting to Live Share: %w", err) } + if err := <-authkeys; err != nil { + return err + } + log.Println("Fetching SSH Details...") remoteSSHServerPort, sshUser, err := session.StartSSHServer(ctx) if err != nil { diff --git a/internal/api/api.go b/internal/api/api.go index 1246389e8..2dd4d71b2 100644 --- a/internal/api/api.go +++ b/internal/api/api.go @@ -13,6 +13,7 @@ package api // - github.GetUser(github.Client) // - github.GetRepository(Client) // - github.ReadFile(Client, nwo, branch, path) // was GetCodespaceRepositoryContents +// - github.AuthorizedKeys(Client, user) // - codespaces.Create(Client, user, repo, sku, branch, location) // - codespaces.Delete(Client, user, token, name) // - codespaces.Get(Client, token, owner, name) @@ -507,6 +508,31 @@ func (a *API) GetCodespaceRepositoryContents(ctx context.Context, codespace *Cod return decoded, nil } +// AuthorizedKeys returns the public keys (in ~/.ssh/authorized_keys +// format) registered by the specified GitHub user. +func (a *API) AuthorizedKeys(ctx context.Context, user string) ([]byte, error) { + url := fmt.Sprintf("https://github.com/%s.keys", user) + req, err := http.NewRequest(http.MethodGet, url, nil) + if err != nil { + return nil, err + } + resp, err := a.do(ctx, req, "/user.keys") + if err != nil { + return nil, err + } + defer resp.Body.Close() + + b, err := ioutil.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("error reading response body: %w", err) + } + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("server returned %s", resp.Status) + } + return b, nil +} + func (a *API) do(ctx context.Context, req *http.Request, spanName string) (*http.Response, error) { // TODO(adonovan): use NewRequestWithContext(ctx) and drop ctx parameter. span, ctx := opentracing.StartSpanFromContext(ctx, spanName) From 40886479ae42cff937e37febadeea4708451d4cb Mon Sep 17 00:00:00 2001 From: Jose Garcia Date: Mon, 20 Sep 2021 20:35:12 +0000 Subject: [PATCH 243/290] Close SSH even if RPC Close fails --- session.go | 1 + 1 file changed, 1 insertion(+) diff --git a/session.go b/session.go index 6a078da7e..5ea961d82 100644 --- a/session.go +++ b/session.go @@ -16,6 +16,7 @@ type Session struct { // is no longer active. func (s *Session) Close() error { if err := s.rpc.Close(); err != nil { + s.ssh.Close() // close SSH and ignore error return fmt.Errorf("failed to close RPC conn: %w", err) } From 7f682f9c398099f30ab0824db56afc95a9edce9e Mon Sep 17 00:00:00 2001 From: Jose Garcia Date: Mon, 20 Sep 2021 16:56:57 -0400 Subject: [PATCH 244/290] Close Live Share sessions - New helper method codespaces.CloseSession to be used using defer - Upgrade to go-liveshare v0.17.0 --- cmd/ghcs/logs.go | 3 ++- cmd/ghcs/ports.go | 9 ++++++--- cmd/ghcs/ssh.go | 3 ++- internal/codespaces/codespaces.go | 8 ++++++++ internal/codespaces/states.go | 3 ++- 5 files changed, 20 insertions(+), 6 deletions(-) diff --git a/cmd/ghcs/logs.go b/cmd/ghcs/logs.go index 19528061a..4a38319e6 100644 --- a/cmd/ghcs/logs.go +++ b/cmd/ghcs/logs.go @@ -40,7 +40,7 @@ func init() { rootCmd.AddCommand(newLogsCmd()) } -func logs(ctx context.Context, log *output.Logger, codespaceName string, follow bool) error { +func logs(ctx context.Context, log *output.Logger, codespaceName string, follow bool) (err error) { // Ensure all child tasks (port forwarding, remote exec) terminate before return. ctx, cancel := context.WithCancel(ctx) defer cancel() @@ -61,6 +61,7 @@ func logs(ctx context.Context, log *output.Logger, codespaceName string, follow if err != nil { return fmt.Errorf("connecting to Live Share: %w", err) } + defer codespaces.CloseSession(session, &err) // Ensure local port is listening before client (getPostCreateOutput) connects. listen, err := net.Listen("tcp", ":0") // arbitrary port diff --git a/cmd/ghcs/ports.go b/cmd/ghcs/ports.go index 7bc53c441..45f92da7d 100644 --- a/cmd/ghcs/ports.go +++ b/cmd/ghcs/ports.go @@ -51,7 +51,7 @@ func init() { rootCmd.AddCommand(newPortsCmd()) } -func ports(codespaceName string, asJSON bool) error { +func ports(codespaceName string, asJSON bool) (err error) { apiClient := api.New(os.Getenv("GITHUB_TOKEN")) ctx := context.Background() log := output.NewLogger(os.Stdout, os.Stderr, asJSON) @@ -76,6 +76,7 @@ func ports(codespaceName string, asJSON bool) error { if err != nil { return fmt.Errorf("error connecting to Live Share: %w", err) } + defer codespaces.CloseSession(session, &err) log.Println("Loading ports...") ports, err := session.GetSharedServers(ctx) @@ -198,7 +199,7 @@ func newPortsPrivateCmd() *cobra.Command { } } -func updatePortVisibility(log *output.Logger, codespaceName, sourcePort string, public bool) error { +func updatePortVisibility(log *output.Logger, codespaceName, sourcePort string, public bool) (err error) { ctx := context.Background() apiClient := api.New(os.Getenv("GITHUB_TOKEN")) @@ -219,6 +220,7 @@ func updatePortVisibility(log *output.Logger, codespaceName, sourcePort string, if err != nil { return fmt.Errorf("error connecting to Live Share: %w", err) } + defer codespaces.CloseSession(session, &err) port, err := strconv.Atoi(sourcePort) if err != nil { @@ -260,7 +262,7 @@ func newPortsForwardCmd() *cobra.Command { } } -func forwardPorts(log *output.Logger, codespaceName string, ports []string) error { +func forwardPorts(log *output.Logger, codespaceName string, ports []string) (err error) { ctx := context.Background() apiClient := api.New(os.Getenv("GITHUB_TOKEN")) @@ -286,6 +288,7 @@ func forwardPorts(log *output.Logger, codespaceName string, ports []string) erro if err != nil { return fmt.Errorf("error connecting to Live Share: %w", err) } + defer codespaces.CloseSession(session, &err) // Run forwarding of all ports concurrently, aborting all of // them at the first failure, including cancellation of the context. diff --git a/cmd/ghcs/ssh.go b/cmd/ghcs/ssh.go index 4ece84d91..a92c99bb3 100644 --- a/cmd/ghcs/ssh.go +++ b/cmd/ghcs/ssh.go @@ -36,7 +36,7 @@ func init() { rootCmd.AddCommand(newSSHCmd()) } -func ssh(ctx context.Context, sshArgs []string, sshProfile, codespaceName string, localSSHServerPort int) error { +func ssh(ctx context.Context, sshArgs []string, sshProfile, codespaceName string, localSSHServerPort int) (err error) { // Ensure all child tasks (e.g. port forwarding) terminate before return. ctx, cancel := context.WithCancel(ctx) defer cancel() @@ -58,6 +58,7 @@ func ssh(ctx context.Context, sshArgs []string, sshProfile, codespaceName string if err != nil { return fmt.Errorf("error connecting to Live Share: %w", err) } + defer codespaces.CloseSession(session, &err) log.Println("Fetching SSH Details...") remoteSSHServerPort, sshUser, err := session.StartSSHServer(ctx) diff --git a/internal/codespaces/codespaces.go b/internal/codespaces/codespaces.go index 6235ca3a0..fe62a3d79 100644 --- a/internal/codespaces/codespaces.go +++ b/internal/codespaces/codespaces.go @@ -73,3 +73,11 @@ func ConnectToLiveshare(ctx context.Context, log logger, apiClient *api.API, use return lsclient.JoinWorkspace(ctx) } + +// CloseSession closes the Live Share session and assigns the error to the pointer if it is nil. +func CloseSession(session *liveshare.Session, err *error) { + closeErr := session.Close() + if *err == nil { + *err = closeErr + } +} diff --git a/internal/codespaces/states.go b/internal/codespaces/states.go index 408f11941..7e464d919 100644 --- a/internal/codespaces/states.go +++ b/internal/codespaces/states.go @@ -36,7 +36,7 @@ type PostCreateState struct { // PollPostCreateStates watches for state changes in a codespace, // and calls the supplied poller for each batch of state changes. // It runs until it encounters an error, including cancellation of the context. -func PollPostCreateStates(ctx context.Context, log logger, apiClient *api.API, user *api.User, codespace *api.Codespace, poller func([]PostCreateState)) error { +func PollPostCreateStates(ctx context.Context, log logger, apiClient *api.API, user *api.User, codespace *api.Codespace, poller func([]PostCreateState)) (err error) { token, err := apiClient.GetCodespaceToken(ctx, user.Login, codespace.Name) if err != nil { return fmt.Errorf("getting codespace token: %w", err) @@ -46,6 +46,7 @@ func PollPostCreateStates(ctx context.Context, log logger, apiClient *api.API, u if err != nil { return fmt.Errorf("connect to Live Share: %w", err) } + defer CloseSession(session, &err) // Ensure local port is listening before client (getPostCreateOutput) connects. listen, err := net.Listen("tcp", ":0") // arbitrary port From 23f6d449e0f6bcf9846dbc57b652f129db7e133d Mon Sep 17 00:00:00 2001 From: Jose Garcia Date: Mon, 20 Sep 2021 21:16:54 +0000 Subject: [PATCH 245/290] Close RPC conn only - Only close SSH if RPC fails. Closing RPC automatically closes the underlying stream which in this case is the SSH connection. - I thought about closing the SSH conn instead of RPC, but there is a bit more cleanup that the RPC library needs to do. --- session.go | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/session.go b/session.go index 5ea961d82..5202205a8 100644 --- a/session.go +++ b/session.go @@ -15,16 +15,14 @@ type Session struct { // Close should be called by users to clean up RPC and SSH resources whenever the session // is no longer active. func (s *Session) Close() error { - if err := s.rpc.Close(); err != nil { + // Closing the RPC conn closes the underlying stream (SSH) + // So we only need to close once + err := s.rpc.Close() + if err != nil { s.ssh.Close() // close SSH and ignore error - return fmt.Errorf("failed to close RPC conn: %w", err) } - if err := s.ssh.Close(); err != nil { - return fmt.Errorf("failed to close SSH conn: %w", err) - } - - return nil + return err } // Port describes a port exposed by the container. From a83b3c08167cddb3f8edb07c25ad1de428194c79 Mon Sep 17 00:00:00 2001 From: Jose Garcia Date: Tue, 21 Sep 2021 08:46:32 -0400 Subject: [PATCH 246/290] Update to go-livesare v0.18.0 - Only set err if closeErr is non-nil --- internal/codespaces/codespaces.go | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/internal/codespaces/codespaces.go b/internal/codespaces/codespaces.go index fe62a3d79..ae1115905 100644 --- a/internal/codespaces/codespaces.go +++ b/internal/codespaces/codespaces.go @@ -23,6 +23,8 @@ func connectionReady(codespace *api.Codespace) bool { codespace.Environment.State == api.CodespaceEnvironmentStateAvailable } +// ConnectToLiveshare creates a Live Share client and joins the Live Share session. +// It will start the Codespace if it is not already running, it will time out after 60 seconds if fails to start. func ConnectToLiveshare(ctx context.Context, log logger, apiClient *api.API, userLogin, token string, codespace *api.Codespace) (*liveshare.Session, error) { var startedCodespace bool if codespace.Environment.State != api.CodespaceEnvironmentStateAvailable { @@ -75,9 +77,10 @@ func ConnectToLiveshare(ctx context.Context, log logger, apiClient *api.API, use } // CloseSession closes the Live Share session and assigns the error to the pointer if it is nil. +// It is meant to be called using defer with a named return argument for the error. func CloseSession(session *liveshare.Session, err *error) { closeErr := session.Close() - if *err == nil { - *err = closeErr + if *err == nil && closeErr != nil { + *err = fmt.Errorf("failed to close Live Share session: %w", closeErr) } } From 5f6b3a5eeed2c8d0ea3c9073df06dad33686af88 Mon Sep 17 00:00:00 2001 From: Jose Garcia Date: Tue, 21 Sep 2021 13:46:30 +0000 Subject: [PATCH 247/290] Add error context to Session.Close --- session.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/session.go b/session.go index 5202205a8..929e8605b 100644 --- a/session.go +++ b/session.go @@ -17,12 +17,12 @@ type Session struct { func (s *Session) Close() error { // Closing the RPC conn closes the underlying stream (SSH) // So we only need to close once - err := s.rpc.Close() - if err != nil { + if err := s.rpc.Close(); err != nil { s.ssh.Close() // close SSH and ignore error + return fmt.Errorf("error while closing Live Share session: %w", err) } - return err + return nil } // Port describes a port exposed by the container. From 0b68aaab7edf7083679e0d257b4fc2e18aa5e26e Mon Sep 17 00:00:00 2001 From: Jose Garcia Date: Tue, 21 Sep 2021 09:59:16 -0400 Subject: [PATCH 248/290] Return error on 202 responses - Start implementing the retry/poll flow --- cmd/ghcs/create.go | 19 ++++++++++++++++++- internal/api/api.go | 7 ++++++- 2 files changed, 24 insertions(+), 2 deletions(-) diff --git a/cmd/ghcs/create.go b/cmd/ghcs/create.go index 2125176fd..ff3e13962 100644 --- a/cmd/ghcs/create.go +++ b/cmd/ghcs/create.go @@ -6,6 +6,7 @@ import ( "fmt" "os" "strings" + "time" "github.com/AlecAivazis/survey/v2" "github.com/fatih/camelcase" @@ -87,8 +88,20 @@ func create(opts *createOptions) error { log.Println("Creating your codespace...") - codespace, err := apiClient.CreateCodespace(ctx, userResult.User, repository, machine, branch, locationResult.Location) + codespace, err := apiClient.CreateCodespace( + ctx, userResult.User, repository, machine, branch, locationResult.Location, + ) if err != nil { + if err == api.ErrCreateAsyncRetry { + createRetryCtx, cancelRetry := context.WithTimeout(ctx, 2*time.Minute) + defer cancelRetry() + + codespace, err = pollForProvisionedCodespace(createRetryCtx, codespace) + if err != nil { + return fmt.Errorf("error creating codespace after retry: %w", err) + } + } + return fmt.Errorf("error creating codespace: %w", err) } @@ -105,6 +118,10 @@ func create(opts *createOptions) error { return nil } +func pollForProvisionedCodespace(ctx context.Context, provisioningCodespace *api.Codespace) (*api.Codespace, error) { + return nil, nil +} + // showStatus polls the codespace for a list of post create states and their status. It will keep polling // until all states have finished. Once all states have finished, we poll once more to check if any new // states have been introduced and stop polling otherwise. diff --git a/internal/api/api.go b/internal/api/api.go index 1246389e8..df9fd10c7 100644 --- a/internal/api/api.go +++ b/internal/api/api.go @@ -401,6 +401,8 @@ type createCodespaceRequest struct { SkuName string `json:"sku_name"` } +var ErrCreateAsyncRetry = errors.New("initial creation failed, retrying async") + func (a *API) CreateCodespace(ctx context.Context, user *User, repository *Repository, sku, branch, location string) (*Codespace, error) { requestBody, err := json.Marshal(createCodespaceRequest{repository.ID, branch, location, sku}) if err != nil { @@ -424,8 +426,11 @@ func (a *API) CreateCodespace(ctx context.Context, user *User, repository *Repos return nil, fmt.Errorf("error reading response body: %w", err) } - if resp.StatusCode > http.StatusAccepted { + switch { + case resp.StatusCode > http.StatusAccepted: return nil, jsonErrorResponse(b) + case resp.StatusCode == http.StatusAccepted: + return nil, ErrCreateAsyncRetry } var response Codespace From d3d1ce726d5853c907775e2bafd8b0dbd163e416 Mon Sep 17 00:00:00 2001 From: Alan Donovan Date: Tue, 21 Sep 2021 09:59:19 -0400 Subject: [PATCH 249/290] do logs too --- cmd/ghcs/common.go | 14 ++++++++++++++ cmd/ghcs/logs.go | 9 +++++++++ cmd/ghcs/ssh.go | 14 +------------- 3 files changed, 24 insertions(+), 13 deletions(-) diff --git a/cmd/ghcs/common.go b/cmd/ghcs/common.go index e71e3dfe4..ba23ef8b4 100644 --- a/cmd/ghcs/common.go +++ b/cmd/ghcs/common.go @@ -120,3 +120,17 @@ func ask(qs []*survey.Question, response interface{}) error { } return err } + +// checkAuthorizedKeys reports an error if the user has not registered any SSH keys; +// see https://github.com/github/ghcs/issues/166#issuecomment-921769703. +// The check is not required for security but it improves the error message. +func checkAuthorizedKeys(ctx context.Context, client *api.API, user string) error { + keys, err := client.AuthorizedKeys(ctx, user) + if err != nil { + return fmt.Errorf("failed to read GitHub-authorized SSH keys for %s: %w", user, err) + } + if len(keys) == 0 { + return fmt.Errorf("user %s has no GitHub-authorized SSH keys", user) + } + return nil // success +} diff --git a/cmd/ghcs/logs.go b/cmd/ghcs/logs.go index f65fa1109..d0f164c37 100644 --- a/cmd/ghcs/logs.go +++ b/cmd/ghcs/logs.go @@ -52,6 +52,11 @@ func logs(ctx context.Context, log *output.Logger, codespaceName string, follow return fmt.Errorf("getting user: %w", err) } + authkeys := make(chan error, 1) + go func() { + authkeys <- checkAuthorizedKeys(ctx, apiClient, user.Login) + }() + codespace, token, err := getOrChooseCodespace(ctx, apiClient, user, codespaceName) if err != nil { return fmt.Errorf("get or choose codespace: %w", err) @@ -62,6 +67,10 @@ func logs(ctx context.Context, log *output.Logger, codespaceName string, follow return fmt.Errorf("connecting to Live Share: %w", err) } + if err := <-authkeys; err != nil { + return err + } + // Ensure local port is listening before client (getPostCreateOutput) connects. listen, err := net.Listen("tcp", ":0") // arbitrary port if err != nil { diff --git a/cmd/ghcs/ssh.go b/cmd/ghcs/ssh.go index 3967f5512..23e87b33d 100644 --- a/cmd/ghcs/ssh.go +++ b/cmd/ghcs/ssh.go @@ -49,21 +49,9 @@ func ssh(ctx context.Context, sshProfile, codespaceName string, localSSHServerPo return fmt.Errorf("error getting user: %w", err) } - // Check whether the user has registered any SSH keys. - // See https://github.com/github/ghcs/issues/166#issuecomment-921769703 - checkAuthKeys := func(user string) error { - keys, err := apiClient.AuthorizedKeys(ctx, user) - if err != nil { - return fmt.Errorf("failed to read GitHub-authorized SSH keys for %s: %w", user, err) - } - if len(keys) == 0 { - return fmt.Errorf("user %s has no GitHub-authorized SSH keys", user) - } - return nil // success - } authkeys := make(chan error, 1) go func() { - authkeys <- checkAuthKeys(user.Login) + authkeys <- checkAuthorizedKeys(ctx, apiClient, user.Login) }() codespace, token, err := getOrChooseCodespace(ctx, apiClient, user, codespaceName) From e8e914c220b9ec828b4965a07a843d67bb4c3c18 Mon Sep 17 00:00:00 2001 From: Jose Garcia Date: Tue, 21 Sep 2021 10:05:48 -0400 Subject: [PATCH 250/290] PR Feedback - Upgrade to go-liveshare v0.19.0 - Remove export helper method - Use local implementation --- cmd/ghcs/common.go | 7 +++++++ cmd/ghcs/logs.go | 2 +- cmd/ghcs/ports.go | 6 +++--- cmd/ghcs/ssh.go | 2 +- internal/codespaces/codespaces.go | 9 --------- internal/codespaces/states.go | 6 +++++- 6 files changed, 17 insertions(+), 15 deletions(-) diff --git a/cmd/ghcs/common.go b/cmd/ghcs/common.go index e71e3dfe4..79fda32ef 100644 --- a/cmd/ghcs/common.go +++ b/cmd/ghcs/common.go @@ -6,6 +6,7 @@ import ( "context" "errors" "fmt" + "io" "os" "sort" @@ -93,6 +94,12 @@ func getOrChooseCodespace(ctx context.Context, apiClient *api.API, user *api.Use return codespace, token, nil } +func safeClose(closer io.Closer, err *error) { + if closeErr := closer.Close(); *err == nil { + *err = closeErr + } +} + // hasTTY indicates whether the process connected to a terminal. // It is not portable to assume stdin/stdout are fds 0 and 1. var hasTTY = term.IsTerminal(int(os.Stdin.Fd())) && term.IsTerminal(int(os.Stdout.Fd())) diff --git a/cmd/ghcs/logs.go b/cmd/ghcs/logs.go index 4a38319e6..db83250c5 100644 --- a/cmd/ghcs/logs.go +++ b/cmd/ghcs/logs.go @@ -61,7 +61,7 @@ func logs(ctx context.Context, log *output.Logger, codespaceName string, follow if err != nil { return fmt.Errorf("connecting to Live Share: %w", err) } - defer codespaces.CloseSession(session, &err) + defer safeClose(session, &err) // Ensure local port is listening before client (getPostCreateOutput) connects. listen, err := net.Listen("tcp", ":0") // arbitrary port diff --git a/cmd/ghcs/ports.go b/cmd/ghcs/ports.go index 45f92da7d..f48dd1e6d 100644 --- a/cmd/ghcs/ports.go +++ b/cmd/ghcs/ports.go @@ -76,7 +76,7 @@ func ports(codespaceName string, asJSON bool) (err error) { if err != nil { return fmt.Errorf("error connecting to Live Share: %w", err) } - defer codespaces.CloseSession(session, &err) + defer safeClose(session, &err) log.Println("Loading ports...") ports, err := session.GetSharedServers(ctx) @@ -220,7 +220,7 @@ func updatePortVisibility(log *output.Logger, codespaceName, sourcePort string, if err != nil { return fmt.Errorf("error connecting to Live Share: %w", err) } - defer codespaces.CloseSession(session, &err) + defer safeClose(session, &err) port, err := strconv.Atoi(sourcePort) if err != nil { @@ -288,7 +288,7 @@ func forwardPorts(log *output.Logger, codespaceName string, ports []string) (err if err != nil { return fmt.Errorf("error connecting to Live Share: %w", err) } - defer codespaces.CloseSession(session, &err) + defer safeClose(session, &err) // Run forwarding of all ports concurrently, aborting all of // them at the first failure, including cancellation of the context. diff --git a/cmd/ghcs/ssh.go b/cmd/ghcs/ssh.go index a92c99bb3..88117c480 100644 --- a/cmd/ghcs/ssh.go +++ b/cmd/ghcs/ssh.go @@ -58,7 +58,7 @@ func ssh(ctx context.Context, sshArgs []string, sshProfile, codespaceName string if err != nil { return fmt.Errorf("error connecting to Live Share: %w", err) } - defer codespaces.CloseSession(session, &err) + defer safeClose(session, &err) log.Println("Fetching SSH Details...") remoteSSHServerPort, sshUser, err := session.StartSSHServer(ctx) diff --git a/internal/codespaces/codespaces.go b/internal/codespaces/codespaces.go index ae1115905..2933c9d8d 100644 --- a/internal/codespaces/codespaces.go +++ b/internal/codespaces/codespaces.go @@ -75,12 +75,3 @@ func ConnectToLiveshare(ctx context.Context, log logger, apiClient *api.API, use return lsclient.JoinWorkspace(ctx) } - -// CloseSession closes the Live Share session and assigns the error to the pointer if it is nil. -// It is meant to be called using defer with a named return argument for the error. -func CloseSession(session *liveshare.Session, err *error) { - closeErr := session.Close() - if *err == nil && closeErr != nil { - *err = fmt.Errorf("failed to close Live Share session: %w", closeErr) - } -} diff --git a/internal/codespaces/states.go b/internal/codespaces/states.go index 7e464d919..31105d576 100644 --- a/internal/codespaces/states.go +++ b/internal/codespaces/states.go @@ -46,7 +46,11 @@ func PollPostCreateStates(ctx context.Context, log logger, apiClient *api.API, u if err != nil { return fmt.Errorf("connect to Live Share: %w", err) } - defer CloseSession(session, &err) + defer func() { + if closeErr := session.Close(); err == nil { + err = closeErr + } + }() // Ensure local port is listening before client (getPostCreateOutput) connects. listen, err := net.Listen("tcp", ":0") // arbitrary port From 323462ca5c3ed803da22b47f68e24d7d697c43bf Mon Sep 17 00:00:00 2001 From: Jose Garcia Date: Tue, 21 Sep 2021 12:37:11 -0400 Subject: [PATCH 251/290] Poll codespace on ErrCreateAsyncRetry error - Introduce tests for the poller - Attempt to fetch codespace for 2 mins --- cmd/ghcs/create.go | 50 +++++++++++++++++++--- cmd/ghcs/create_test.go | 93 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 137 insertions(+), 6 deletions(-) create mode 100644 cmd/ghcs/create_test.go diff --git a/cmd/ghcs/create.go b/cmd/ghcs/create.go index ff3e13962..93016bbf8 100644 --- a/cmd/ghcs/create.go +++ b/cmd/ghcs/create.go @@ -92,13 +92,19 @@ func create(opts *createOptions) error { ctx, userResult.User, repository, machine, branch, locationResult.Location, ) if err != nil { + // This error is returned by the API when the initial creation fails with a retryable error. + // A retryable error means that GitHub will retry to re-create Codespace and clients should poll + // the API and attempt to fetch the Codespace for the next two minutes. if err == api.ErrCreateAsyncRetry { - createRetryCtx, cancelRetry := context.WithTimeout(ctx, 2*time.Minute) - defer cancelRetry() + log.Print("Switching to async provisioning...") + pollctx, cancel := context.WithTimeout(ctx, 2*time.Minute) + defer cancel() + + codespace, err = pollForCodespace(pollctx, apiClient, log, userResult.User, codespace) + log.Print("\n") - codespace, err = pollForProvisionedCodespace(createRetryCtx, codespace) if err != nil { - return fmt.Errorf("error creating codespace after retry: %w", err) + return fmt.Errorf("error creating codespace with async provisioning: %s: %w", codespace.Name, err) } } @@ -118,8 +124,40 @@ func create(opts *createOptions) error { return nil } -func pollForProvisionedCodespace(ctx context.Context, provisioningCodespace *api.Codespace) (*api.Codespace, error) { - return nil, nil +type apiClient interface { + GetCodespaceToken(context.Context, string, string) (string, error) + GetCodespace(context.Context, string, string, string) (*api.Codespace, error) +} + +// pollForCodespace polls the Codespaces API every second fetching the codespace. +// If it succeeds at fetching the codespace, we consider the codespace provisioned. +// Context should be cancelled to stop polling. +func pollForCodespace( + ctx context.Context, client apiClient, log *output.Logger, user *api.User, provisioningCodespace *api.Codespace, +) (*api.Codespace, error) { + ticker := time.NewTicker(1 * time.Second) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + return nil, ctx.Err() + case <-ticker.C: + log.Print(".") + token, err := client.GetCodespaceToken(ctx, user.Login, provisioningCodespace.Name) + if err != nil { + // Do nothing. We expect this to fail until the codespace is provisioned + continue + } + + codespace, err := client.GetCodespace(ctx, token, user.Login, provisioningCodespace.Name) + if err != nil { + return nil, fmt.Errorf("failed to get codespace: %w", err) + } + + return codespace, nil + } + } } // showStatus polls the codespace for a list of post create states and their status. It will keep polling diff --git a/cmd/ghcs/create_test.go b/cmd/ghcs/create_test.go new file mode 100644 index 000000000..36769dc14 --- /dev/null +++ b/cmd/ghcs/create_test.go @@ -0,0 +1,93 @@ +package main + +import ( + "context" + "errors" + "fmt" + "testing" + "time" + + "github.com/github/ghcs/cmd/ghcs/output" + "github.com/github/ghcs/internal/api" +) + +type mockAPIClient struct { + getCodespaceToken func(context.Context, string, string) (string, error) + getCodespace func(context.Context, string, string, string) (*api.Codespace, error) +} + +func (m *mockAPIClient) GetCodespaceToken(ctx context.Context, userLogin, codespaceName string) (string, error) { + if m.getCodespaceToken == nil { + return "", errors.New("mock api client GetCodespaceToken not implemented") + } + + return m.getCodespaceToken(ctx, userLogin, codespaceName) +} + +func (m *mockAPIClient) GetCodespace(ctx context.Context, token, userLogin, codespaceName string) (*api.Codespace, error) { + if m.getCodespace == nil { + return nil, errors.New("mock api client GetCodespace not implemented") + } + + return m.getCodespace(ctx, token, userLogin, codespaceName) +} + +func TestPollForCodespace(t *testing.T) { + logger := output.NewLogger(nil, nil, false) + user := &api.User{Login: "test"} + tmpCodespace := &api.Codespace{Name: "tmp-codespace"} + codespaceToken := "codespace-token" + + ctxTimeout := 1 * time.Second + exceedTime := 2 * time.Second + exceedProvisioningTime := false + + api := &mockAPIClient{ + getCodespaceToken: func(ctx context.Context, userLogin, codespace string) (string, error) { + if exceedProvisioningTime { + ticker := time.NewTicker(exceedTime) + defer ticker.Stop() + <-ticker.C + } + if userLogin != user.Login { + return "", fmt.Errorf("user does not match, got: %s, expected: %s", userLogin, user.Login) + } + if codespace != tmpCodespace.Name { + return "", fmt.Errorf("codespace does not match, got: %s, expected: %s", codespace, tmpCodespace.Name) + } + return codespaceToken, nil + }, + getCodespace: func(ctx context.Context, token, userLogin, codespace string) (*api.Codespace, error) { + if token != codespaceToken { + return nil, fmt.Errorf("token does not match, got: %s, expected: %s", token, codespaceToken) + } + if userLogin != user.Login { + return nil, fmt.Errorf("user does not match, got: %s, expected: %s", userLogin, user.Login) + } + if codespace != tmpCodespace.Name { + return nil, fmt.Errorf("codespace does not match, got: %s, expected: %s", codespace, tmpCodespace.Name) + } + return tmpCodespace, nil + }, + } + + ctx, cancel := context.WithTimeout(context.Background(), ctxTimeout) + defer cancel() + + codespace, err := pollForCodespace(ctx, api, logger, user, tmpCodespace) + if err != nil { + t.Error(err) + } + if tmpCodespace.Name != codespace.Name { + t.Errorf("returned codespace does not match, got: %s, expected: %s", codespace.Name, tmpCodespace.Name) + } + + exceedProvisioningTime = true + ctx, cancel = context.WithTimeout(ctx, ctxTimeout) + defer cancel() + + _, err = pollForCodespace(ctx, api, logger, user, tmpCodespace) + if err == nil { + t.Error("expected context deadline exceeded error, got nil") + } +} From b3b675d108d02f32b24ad69b33f1dacdd5e85c1d Mon Sep 17 00:00:00 2001 From: Alan Donovan Date: Tue, 21 Sep 2021 12:44:30 -0400 Subject: [PATCH 252/290] Merge NewClient and JoinWorkspace into Connect --- client.go | 105 +++++++++++++++++++++++++----------------------- client_test.go | 46 ++------------------- connection.go | 2 +- session_test.go | 9 +---- 4 files changed, 61 insertions(+), 101 deletions(-) diff --git a/client.go b/client.go index ba9d2f5e7..3f9345ce4 100644 --- a/client.go +++ b/client.go @@ -1,3 +1,13 @@ +// Package liveshare is a Go client library for the Visual Studio Live Share +// service, which provides collaborative, distibuted editing and debugging. +// See https://docs.microsoft.com/en-us/visualstudio/liveshare for an overview. +// +// It provides the ability for a Go program to connect to a Live Share +// workspace (Connect), to expose a TCP port on a remote host +// (UpdateSharedVisibility), to start an SSH server listening on an +// exposed port (StartSSHServer), and to forward connections between +// the remote port and a local listening TCP port (ForwardToListener) +// or a local Go reader/writer (Forward). package liveshare import ( @@ -9,66 +19,79 @@ import ( "golang.org/x/crypto/ssh" ) -// A Client capable of joining a Live Share workspace. -type Client struct { +// A client capable of joining a Live Share workspace. +type client struct { connection Connection tlsConfig *tls.Config } -// A ClientOption is a function that modifies a client -type ClientOption func(*Client) error +// An Option updates the initial configuration state of a Live Share connection. +type Option func(*client) error -// NewClient accepts a range of options, applies them and returns a client -func NewClient(opts ...ClientOption) (*Client, error) { - client := new(Client) - - for _, o := range opts { - if err := o(client); err != nil { - return nil, err - } - } - - return client, nil -} - -// WithConnection is a ClientOption that accepts a Connection -func WithConnection(connection Connection) ClientOption { - return func(c *Client) error { +// WithConnection is a Option that accepts a Connection. +// +// TODO(adonovan): WithConnection is not optional, so it should not be +// not an Option. We should make Connection a mandatory parameter of +// Connect, at which point, why not just merge +// client+Option+Connection, rename it to Options, do away with the +// function mechanism, and express TLS config (etc) as public fields +// of Options with sensible zero values, like websocket.Dialer, etc? +func WithConnection(connection Connection) Option { + return func(cli *client) error { if err := connection.validate(); err != nil { return err } - c.connection = connection + cli.connection = connection return nil } } -func WithTLSConfig(tlsConfig *tls.Config) ClientOption { - return func(c *Client) error { - c.tlsConfig = tlsConfig +// WithTLSConfig returns a Connect option that sets the TLS configuration. +func WithTLSConfig(tlsConfig *tls.Config) Option { + return func(cli *client) error { + cli.tlsConfig = tlsConfig return nil } } -// JoinWorkspace connects the client to the server's Live Share -// workspace and returns a session representing their connection. -func (c *Client) JoinWorkspace(ctx context.Context) (*Session, error) { - span, ctx := opentracing.StartSpanFromContext(ctx, "Client.JoinWorkspace") +// Connect connects to a Live Share workspace specified by the +// options, and returns a session representing the connection. +// The caller must call the session's Close method to end the session. +func Connect(ctx context.Context, opts ...Option) (*Session, error) { + cli := new(client) + for _, opt := range opts { + if err := opt(cli); err != nil { + return nil, fmt.Errorf("error applying Live Share connect option: %w", err) + } + } + + span, ctx := opentracing.StartSpanFromContext(ctx, "Connect") defer span.Finish() - clientSocket := newSocket(c.connection, c.tlsConfig) - if err := clientSocket.connect(ctx); err != nil { + sock := newSocket(cli.connection, cli.tlsConfig) + if err := sock.connect(ctx); err != nil { return nil, fmt.Errorf("error connecting websocket: %w", err) } - ssh := newSSHSession(c.connection.SessionToken, clientSocket) + ssh := newSSHSession(cli.connection.SessionToken, sock) if err := ssh.connect(ctx); err != nil { return nil, fmt.Errorf("error connecting to ssh session: %w", err) } rpc := newRPCClient(ssh) rpc.connect(ctx) - if _, err := c.joinWorkspace(ctx, rpc); err != nil { + + args := joinWorkspaceArgs{ + ID: cli.connection.SessionID, + ConnectionMode: "local", + JoiningUserSessionToken: cli.connection.SessionToken, + ClientCapabilities: clientCapabilities{ + IsNonInteractive: false, + }, + } + var result joinWorkspaceResult + if err := rpc.do(ctx, "workspace.joinWorkspace", &args, &result); err != nil { return nil, fmt.Errorf("error joining Live Share workspace: %w", err) } @@ -96,24 +119,6 @@ type channelID struct { name, condition string } -func (c *Client) joinWorkspace(ctx context.Context, rpc *rpcClient) (*joinWorkspaceResult, error) { - args := joinWorkspaceArgs{ - ID: c.connection.SessionID, - ConnectionMode: "local", - JoiningUserSessionToken: c.connection.SessionToken, - ClientCapabilities: clientCapabilities{ - IsNonInteractive: false, - }, - } - - var result joinWorkspaceResult - if err := rpc.do(ctx, "workspace.joinWorkspace", &args, &result); err != nil { - return nil, fmt.Errorf("error making workspace.joinWorkspace call: %w", err) - } - - return &result, nil -} - func (s *Session) openStreamingChannel(ctx context.Context, id channelID) (ssh.Channel, error) { type getStreamArgs struct { StreamName string `json:"streamName"` diff --git a/client_test.go b/client_test.go index c1e61f6e8..369c53b28 100644 --- a/client_test.go +++ b/client_test.go @@ -13,37 +13,7 @@ import ( "github.com/sourcegraph/jsonrpc2" ) -func TestNewClient(t *testing.T) { - client, err := NewClient() - if err != nil { - t.Errorf("error creating new client: %v", err) - } - if client == nil { - t.Error("client is nil") - } -} - -func TestNewClientValidConnection(t *testing.T) { - connection := Connection{"1", "2", "3", "4"} - - client, err := NewClient(WithConnection(connection)) - if err != nil { - t.Errorf("error creating new client: %v", err) - } - if client == nil { - t.Error("client is nil") - } -} - -func TestNewClientWithInvalidConnection(t *testing.T) { - connection := Connection{} - - if _, err := NewClient(WithConnection(connection)); err == nil { - t.Error("err is nil") - } -} - -func TestJoinSession(t *testing.T) { +func TestConnect(t *testing.T) { connection := Connection{ SessionID: "session-id", SessionToken: "session-token", @@ -83,21 +53,11 @@ func TestJoinSession(t *testing.T) { ctx := context.Background() tlsConfig := WithTLSConfig(&tls.Config{InsecureSkipVerify: true}) - client, err := NewClient(WithConnection(connection), tlsConfig) - if err != nil { - t.Errorf("error creating new client: %v", err) - } done := make(chan error) go func() { - session, err := client.JoinWorkspace(ctx) - if err != nil { - done <- fmt.Errorf("error joining workspace: %v", err) - return - } - _ = session - - done <- nil + _, err := Connect(ctx, WithConnection(connection), tlsConfig) // ignore session + done <- err }() select { diff --git a/connection.go b/connection.go index c1a4632c8..f402e4bb9 100644 --- a/connection.go +++ b/connection.go @@ -6,7 +6,7 @@ import ( "strings" ) -// A Connection represents a set of values necessary to join a liveshare connection +// A Connection represents a set of values necessary to join a liveshare connection. type Connection struct { SessionID string SessionToken string diff --git a/session_test.go b/session_test.go index 54aab16c8..3be90cb0e 100644 --- a/session_test.go +++ b/session_test.go @@ -32,14 +32,9 @@ func makeMockSession(opts ...livesharetest.ServerOption) (*livesharetest.Server, ) connection.RelayEndpoint = "sb" + strings.TrimPrefix(testServer.URL(), "https") tlsConfig := WithTLSConfig(&tls.Config{InsecureSkipVerify: true}) - client, err := NewClient(WithConnection(connection), tlsConfig) + session, err := Connect(context.Background(), WithConnection(connection), tlsConfig) if err != nil { - return nil, nil, fmt.Errorf("error creating new client: %v", err) - } - ctx := context.Background() - session, err := client.JoinWorkspace(ctx) - if err != nil { - return nil, nil, fmt.Errorf("error joining workspace: %v", err) + return nil, nil, fmt.Errorf("error connecting to Live Share: %v", err) } return testServer, session, nil } From 861811baf03d461e0d89113b07c80ff414c4e146 Mon Sep 17 00:00:00 2001 From: Jose Garcia Date: Tue, 21 Sep 2021 14:02:05 -0400 Subject: [PATCH 253/290] Upgrade pkg name after merge --- cmd/ghcs/create_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cmd/ghcs/create_test.go b/cmd/ghcs/create_test.go index 36769dc14..3df900afc 100644 --- a/cmd/ghcs/create_test.go +++ b/cmd/ghcs/create_test.go @@ -1,4 +1,4 @@ -package main +package ghcs import ( "context" From 678da44c28506629da1feb53b34efbe59d38b7f0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mislav=20Marohni=C4=87?= Date: Tue, 21 Sep 2021 21:09:26 +0200 Subject: [PATCH 254/290] Simplify delete further --- cmd/ghcs/common.go | 2 +- cmd/ghcs/delete.go | 63 ++++++++++++++++------------ cmd/ghcs/delete_test.go | 91 ++++------------------------------------ cmd/ghcs/list.go | 2 +- internal/api/api.go | 16 ++++--- internal/api/api_test.go | 6 +-- 6 files changed, 58 insertions(+), 122 deletions(-) diff --git a/cmd/ghcs/common.go b/cmd/ghcs/common.go index a3963d22f..4ebc89a2d 100644 --- a/cmd/ghcs/common.go +++ b/cmd/ghcs/common.go @@ -19,7 +19,7 @@ import ( var errNoCodespaces = errors.New("you have no codespaces") func chooseCodespace(ctx context.Context, apiClient *api.API, user *api.User) (*api.Codespace, error) { - codespaces, err := apiClient.ListCodespaces(ctx, user) + codespaces, err := apiClient.ListCodespaces(ctx, user.Login) if err != nil { return nil, fmt.Errorf("error getting codespaces: %w", err) } diff --git a/cmd/ghcs/delete.go b/cmd/ghcs/delete.go index a8005f1e9..c08344163 100644 --- a/cmd/ghcs/delete.go +++ b/cmd/ghcs/delete.go @@ -2,13 +2,13 @@ package ghcs import ( "context" + "errors" "fmt" "os" "strings" "time" "github.com/AlecAivazis/survey/v2" - "github.com/github/ghcs/cmd/ghcs/output" "github.com/github/ghcs/internal/api" "github.com/spf13/cobra" "golang.org/x/sync/errgroup" @@ -17,19 +17,32 @@ import ( type deleteOptions struct { deleteAll bool skipConfirm bool - isInteractive bool codespaceName string repoFilter string keepDays uint16 + + isInteractive bool now func() time.Time - apiClient *api.API + apiClient apiClient + prompter prompter +} + +type prompter interface { + Confirm(message string) (bool, error) +} + +type apiClient interface { + GetUser(ctx context.Context) (*api.User, error) + ListCodespaces(ctx context.Context, user string) ([]*api.Codespace, error) + DeleteCodespace(ctx context.Context, user, name string) error } func newDeleteCmd() *cobra.Command { opts := deleteOptions{ - apiClient: api.New(os.Getenv("GITHUB_TOKEN")), - now: time.Now, isInteractive: hasTTY, + now: time.Now, + apiClient: api.New(os.Getenv("GITHUB_TOKEN")), + prompter: &surveyPrompter{}, } deleteCmd := &cobra.Command{ @@ -37,15 +50,10 @@ func newDeleteCmd() *cobra.Command { Short: "Delete a codespace", Args: cobra.NoArgs, RunE: func(cmd *cobra.Command, args []string) error { - // switch { - // case allCodespaces && repo != "": - // return errors.New("both --all and --repo is not supported") - // case allCodespaces: - // return deleteAll(log, force, keepThresholdDays) - // case repo != "": - // return deleteByRepo(log, repo, force, keepThresholdDays) - log := output.NewLogger(os.Stdout, os.Stderr, false) - return delete(context.Background(), log, opts) + if opts.deleteAll && opts.repoFilter != "" { + return errors.New("both --all and --repo is not supported") + } + return delete(context.Background(), opts) }, } @@ -58,13 +66,13 @@ func newDeleteCmd() *cobra.Command { return deleteCmd } -func delete(ctx context.Context, log *output.Logger, opts deleteOptions) error { +func delete(ctx context.Context, opts deleteOptions) error { user, err := opts.apiClient.GetUser(ctx) if err != nil { return fmt.Errorf("error getting user: %w", err) } - codespaces, err := opts.apiClient.ListCodespaces(ctx, user) + codespaces, err := opts.apiClient.ListCodespaces(ctx, user.Login) if err != nil { return fmt.Errorf("error getting codespaces: %w", err) } @@ -78,7 +86,7 @@ func delete(ctx context.Context, log *output.Logger, opts deleteOptions) error { nameFilter = c.Name } - var codespacesToDelete []*api.Codespace + codespacesToDelete := make([]*api.Codespace, 0, len(codespaces)) lastUpdatedCutoffTime := opts.now().AddDate(0, 0, -int(opts.keepDays)) for _, c := range codespaces { if nameFilter != "" && c.Name != nameFilter { @@ -97,9 +105,9 @@ func delete(ctx context.Context, log *output.Logger, opts deleteOptions) error { } } if nameFilter == "" || !opts.skipConfirm { - confirmed, err := confirmDeletion(c) + confirmed, err := confirmDeletion(opts.prompter, c, opts.isInteractive) if err != nil { - return fmt.Errorf("deletion could not be confirmed: %w", err) + return fmt.Errorf("unable to confirm: %w", err) } if !confirmed { continue @@ -112,11 +120,7 @@ func delete(ctx context.Context, log *output.Logger, opts deleteOptions) error { for _, c := range codespacesToDelete { codespaceName := c.Name g.Go(func() error { - token, err := opts.apiClient.GetCodespaceToken(ctx, user.Login, codespaceName) - if err != nil { - return fmt.Errorf("error getting codespace token: %w", err) - } - if err := opts.apiClient.DeleteCodespace(ctx, user, token, codespaceName); err != nil { + if err := opts.apiClient.DeleteCodespace(ctx, user.Login, codespaceName); err != nil { return fmt.Errorf("error deleting codespace: %w", err) } return nil @@ -126,16 +130,21 @@ func delete(ctx context.Context, log *output.Logger, opts deleteOptions) error { return g.Wait() } -func confirmDeletion(codespace *api.Codespace) (bool, error) { +func confirmDeletion(p prompter, codespace *api.Codespace, isInteractive bool) (bool, error) { gs := codespace.Environment.GitStatus hasUnsavedChanges := gs.HasUncommitedChanges || gs.HasUnpushedChanges if !hasUnsavedChanges { return true, nil } - if !hasTTY { + if !isInteractive { return false, fmt.Errorf("codespace %s has unsaved changes (use --force to override)", codespace.Name) } + return p.Confirm(fmt.Sprintf("Codespace %s has unsaved changes. OK to delete?", codespace.Name)) +} +type surveyPrompter struct{} + +func (p *surveyPrompter) Confirm(message string) (bool, error) { var confirmed struct { Confirmed bool } @@ -143,7 +152,7 @@ func confirmDeletion(codespace *api.Codespace) (bool, error) { { Name: "confirmed", Prompt: &survey.Confirm{ - Message: fmt.Sprintf("Codespace %s has unsaved changes. OK to delete?", codespace.Name), + Message: message, }, }, } diff --git a/cmd/ghcs/delete_test.go b/cmd/ghcs/delete_test.go index c43331bea..783ad80e6 100644 --- a/cmd/ghcs/delete_test.go +++ b/cmd/ghcs/delete_test.go @@ -1,103 +1,28 @@ package ghcs import ( + "context" "testing" - "time" - - "github.com/github/ghcs/internal/api" ) -func TestFilterCodespacesToDelete(t *testing.T) { - type args struct { - codespaces []*api.Codespace - thresholdDays int - } +func TestDelete(t *testing.T) { tests := []struct { name string - now time.Time - args args + opts deleteOptions wantErr bool - deleted []*api.Codespace }{ { - name: "no codespaces is to be deleted", - - args: args{ - codespaces: []*api.Codespace{ - { - Name: "testcodespace", - CreatedAt: "2021-08-09T10:10:24+02:00", - LastUsedAt: "2021-08-09T13:10:24+02:00", - Environment: api.CodespaceEnvironment{ - State: "Shutdown", - }, - }, - }, - thresholdDays: 1, + name: "by name", + opts: deleteOptions{ + codespaceName: "foo-bar-123", }, - now: time.Date(2021, 8, 9, 20, 10, 24, 0, time.UTC), - deleted: []*api.Codespace{}, - }, - { - name: "one codespace is to be deleted", - - args: args{ - codespaces: []*api.Codespace{ - { - Name: "testcodespace", - CreatedAt: "2021-08-09T10:10:24+02:00", - LastUsedAt: "2021-08-09T13:10:24+02:00", - Environment: api.CodespaceEnvironment{ - State: "Shutdown", - }, - }, - }, - thresholdDays: 1, - }, - now: time.Date(2021, 8, 15, 20, 12, 24, 0, time.UTC), - deleted: []*api.Codespace{ - { - Name: "testcodespace", - CreatedAt: "2021-08-09T10:10:24+02:00", - LastUsedAt: "2021-08-09T13:10:24+02:00", - }, - }, - }, - { - name: "threshold is invalid", - - args: args{ - codespaces: []*api.Codespace{ - { - Name: "testcodespace", - CreatedAt: "2021-08-09T10:10:24+02:00", - LastUsedAt: "2021-08-09T13:10:24+02:00", - Environment: api.CodespaceEnvironment{ - State: "Shutdown", - }, - }, - }, - thresholdDays: -1, - }, - now: time.Date(2021, 8, 15, 20, 12, 24, 0, time.UTC), - wantErr: true, - deleted: []*api.Codespace{}, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - - now = func() time.Time { - return tt.now - } - - codespaces, err := filterCodespacesToDelete(tt.args.codespaces, tt.args.thresholdDays) + err := delete(context.Background(), tt.opts) if (err != nil) != tt.wantErr { - t.Errorf("API.CleanupUnusedCodespaces() error = %v, wantErr %v", err, tt.wantErr) - } - - if len(codespaces) != len(tt.deleted) { - t.Errorf("expected %d deleted codespaces, got %d", len(tt.deleted), len(codespaces)) + t.Errorf("delete() error = %v, wantErr %v", err, tt.wantErr) } }) } diff --git a/cmd/ghcs/list.go b/cmd/ghcs/list.go index 7ee156012..85eabaef5 100644 --- a/cmd/ghcs/list.go +++ b/cmd/ghcs/list.go @@ -40,7 +40,7 @@ func list(opts *listOptions) error { return fmt.Errorf("error getting user: %w", err) } - codespaces, err := apiClient.ListCodespaces(ctx, user) + codespaces, err := apiClient.ListCodespaces(ctx, user.Login) if err != nil { return fmt.Errorf("error getting codespaces: %w", err) } diff --git a/internal/api/api.go b/internal/api/api.go index 12d5a7263..4d4078c9c 100644 --- a/internal/api/api.go +++ b/internal/api/api.go @@ -32,11 +32,12 @@ import ( "encoding/json" "errors" "fmt" - "github.com/opentracing/opentracing-go" "io/ioutil" "net/http" "strconv" "strings" + + "github.com/opentracing/opentracing-go" ) const githubAPI = "https://api.github.com" @@ -172,9 +173,9 @@ type CodespaceEnvironmentConnection struct { RelaySAS string `json:"relaySas"` } -func (a *API) ListCodespaces(ctx context.Context, user *User) ([]*Codespace, error) { +func (a *API) ListCodespaces(ctx context.Context, user string) ([]*Codespace, error) { req, err := http.NewRequest( - http.MethodGet, a.githubAPI+"/vscs_internal/user/"+user.Login+"/codespaces", nil, + http.MethodGet, a.githubAPI+"/vscs_internal/user/"+user+"/codespaces", nil, ) if err != nil { return nil, fmt.Errorf("error creating request: %w", err) @@ -442,8 +443,13 @@ func (a *API) CreateCodespace(ctx context.Context, user *User, repository *Repos return &response, nil } -func (a *API) DeleteCodespace(ctx context.Context, user *User, token, codespaceName string) error { - req, err := http.NewRequest(http.MethodDelete, a.githubAPI+"/vscs_internal/user/"+user.Login+"/codespaces/"+codespaceName, nil) +func (a *API) DeleteCodespace(ctx context.Context, user string, codespaceName string) error { + token, err := a.GetCodespaceToken(ctx, user, codespaceName) + if err != nil { + return fmt.Errorf("error getting codespace token: %w", err) + } + + req, err := http.NewRequest(http.MethodDelete, a.githubAPI+"/vscs_internal/user/"+user+"/codespaces/"+codespaceName, nil) if err != nil { return fmt.Errorf("error creating request: %w", err) } diff --git a/internal/api/api_test.go b/internal/api/api_test.go index c1f4e5c19..6fb162030 100644 --- a/internal/api/api_test.go +++ b/internal/api/api_test.go @@ -10,10 +10,6 @@ import ( ) func TestListCodespaces(t *testing.T) { - user := &User{ - Login: "testuser", - } - codespaces := []*Codespace{ { Name: "testcodespace", @@ -38,7 +34,7 @@ func TestListCodespaces(t *testing.T) { token: "faketoken", } ctx := context.TODO() - codespaces, err := api.ListCodespaces(ctx, user) + codespaces, err := api.ListCodespaces(ctx, "testuser") if err != nil { t.Fatal(err) } From f8a8713520f031758a2b75dc70c5faaea2927ea5 Mon Sep 17 00:00:00 2001 From: Alan Donovan Date: Tue, 21 Sep 2021 15:23:02 -0400 Subject: [PATCH 255/290] refactor Options API --- client.go | 77 ++++++++++++++++++++++------------------------ client_test.go | 16 +++++----- connection.go | 44 -------------------------- connection_test.go | 41 ------------------------ options_test.go | 56 +++++++++++++++++++++++++++++++++ session_test.go | 22 ++++++------- socket.go | 4 +-- 7 files changed, 113 insertions(+), 147 deletions(-) delete mode 100644 connection.go delete mode 100644 connection_test.go create mode 100644 options_test.go diff --git a/client.go b/client.go index 3f9345ce4..b51e25ea6 100644 --- a/client.go +++ b/client.go @@ -13,68 +13,65 @@ package liveshare import ( "context" "crypto/tls" + "errors" "fmt" + "net/url" + "strings" "github.com/opentracing/opentracing-go" "golang.org/x/crypto/ssh" ) -// A client capable of joining a Live Share workspace. -type client struct { - connection Connection - tlsConfig *tls.Config +// An Options specifies Live Share connection parameters. +type Options struct { + SessionID string + SessionToken string // token for SSH session + RelaySAS string + RelayEndpoint string + TLSConfig *tls.Config // (optional) } -// An Option updates the initial configuration state of a Live Share connection. -type Option func(*client) error - -// WithConnection is a Option that accepts a Connection. -// -// TODO(adonovan): WithConnection is not optional, so it should not be -// not an Option. We should make Connection a mandatory parameter of -// Connect, at which point, why not just merge -// client+Option+Connection, rename it to Options, do away with the -// function mechanism, and express TLS config (etc) as public fields -// of Options with sensible zero values, like websocket.Dialer, etc? -func WithConnection(connection Connection) Option { - return func(cli *client) error { - if err := connection.validate(); err != nil { - return err - } - - cli.connection = connection - return nil +// uri returns a websocket URL for the specified options. +func (opts *Options) uri(action string) (string, error) { + if opts.SessionID == "" { + return "", errors.New("SessionID is required") } -} - -// WithTLSConfig returns a Connect option that sets the TLS configuration. -func WithTLSConfig(tlsConfig *tls.Config) Option { - return func(cli *client) error { - cli.tlsConfig = tlsConfig - return nil + if opts.RelaySAS == "" { + return "", errors.New("RelaySAS is required") } + if opts.RelayEndpoint == "" { + return "", errors.New("RelayEndpoint is required") + } + + sas := url.QueryEscape(opts.RelaySAS) + uri := opts.RelayEndpoint + uri = strings.Replace(uri, "sb:", "wss:", -1) + uri = strings.Replace(uri, ".net/", ".net:443/$hc/", 1) + uri = uri + "?sb-hc-action=" + action + "&sb-hc-token=" + sas + return uri, nil } // Connect connects to a Live Share workspace specified by the // options, and returns a session representing the connection. // The caller must call the session's Close method to end the session. -func Connect(ctx context.Context, opts ...Option) (*Session, error) { - cli := new(client) - for _, opt := range opts { - if err := opt(cli); err != nil { - return nil, fmt.Errorf("error applying Live Share connect option: %w", err) - } +func Connect(ctx context.Context, opts Options) (*Session, error) { + uri, err := opts.uri("connect") + if err != nil { + return nil, err } span, ctx := opentracing.StartSpanFromContext(ctx, "Connect") defer span.Finish() - sock := newSocket(cli.connection, cli.tlsConfig) + sock := newSocket(uri, opts.TLSConfig) if err := sock.connect(ctx); err != nil { return nil, fmt.Errorf("error connecting websocket: %w", err) } - ssh := newSSHSession(cli.connection.SessionToken, sock) + if opts.SessionToken == "" { + return nil, errors.New("SessionToken is required") + } + ssh := newSSHSession(opts.SessionToken, sock) if err := ssh.connect(ctx); err != nil { return nil, fmt.Errorf("error connecting to ssh session: %w", err) } @@ -83,9 +80,9 @@ func Connect(ctx context.Context, opts ...Option) (*Session, error) { rpc.connect(ctx) args := joinWorkspaceArgs{ - ID: cli.connection.SessionID, + ID: opts.SessionID, ConnectionMode: "local", - JoiningUserSessionToken: cli.connection.SessionToken, + JoiningUserSessionToken: opts.SessionToken, ClientCapabilities: clientCapabilities{ IsNonInteractive: false, }, diff --git a/client_test.go b/client_test.go index 369c53b28..2b95f738f 100644 --- a/client_test.go +++ b/client_test.go @@ -14,7 +14,7 @@ import ( ) func TestConnect(t *testing.T) { - connection := Connection{ + opts := Options{ SessionID: "session-id", SessionToken: "session-token", RelaySAS: "relay-sas", @@ -24,13 +24,13 @@ func TestConnect(t *testing.T) { if err := json.Unmarshal(*req.Params, &joinWorkspaceReq); err != nil { return nil, fmt.Errorf("error unmarshaling req: %v", err) } - if joinWorkspaceReq.ID != connection.SessionID { + if joinWorkspaceReq.ID != opts.SessionID { return nil, errors.New("connection session id does not match") } if joinWorkspaceReq.ConnectionMode != "local" { return nil, errors.New("connection mode is not local") } - if joinWorkspaceReq.JoiningUserSessionToken != connection.SessionToken { + if joinWorkspaceReq.JoiningUserSessionToken != opts.SessionToken { return nil, errors.New("connection user token does not match") } if joinWorkspaceReq.ClientCapabilities.IsNonInteractive != false { @@ -40,23 +40,23 @@ func TestConnect(t *testing.T) { } server, err := livesharetest.NewServer( - livesharetest.WithPassword(connection.SessionToken), + livesharetest.WithPassword(opts.SessionToken), livesharetest.WithService("workspace.joinWorkspace", joinWorkspace), - livesharetest.WithRelaySAS(connection.RelaySAS), + livesharetest.WithRelaySAS(opts.RelaySAS), ) if err != nil { t.Errorf("error creating Live Share server: %v", err) } defer server.Close() - connection.RelayEndpoint = "sb" + strings.TrimPrefix(server.URL(), "https") + opts.RelayEndpoint = "sb" + strings.TrimPrefix(server.URL(), "https") ctx := context.Background() - tlsConfig := WithTLSConfig(&tls.Config{InsecureSkipVerify: true}) + opts.TLSConfig = &tls.Config{InsecureSkipVerify: true} done := make(chan error) go func() { - _, err := Connect(ctx, WithConnection(connection), tlsConfig) // ignore session + _, err := Connect(ctx, opts) // ignore session done <- err }() diff --git a/connection.go b/connection.go deleted file mode 100644 index f402e4bb9..000000000 --- a/connection.go +++ /dev/null @@ -1,44 +0,0 @@ -package liveshare - -import ( - "errors" - "net/url" - "strings" -) - -// A Connection represents a set of values necessary to join a liveshare connection. -type Connection struct { - SessionID string - SessionToken string - RelaySAS string - RelayEndpoint string -} - -func (r Connection) validate() error { - if r.SessionID == "" { - return errors.New("connection SessionID is required") - } - - if r.SessionToken == "" { - return errors.New("connection SessionToken is required") - } - - if r.RelaySAS == "" { - return errors.New("connection RelaySAS is required") - } - - if r.RelayEndpoint == "" { - return errors.New("connection RelayEndpoint is required") - } - - return nil -} - -func (r Connection) uri(action string) string { - sas := url.QueryEscape(r.RelaySAS) - uri := r.RelayEndpoint - uri = strings.Replace(uri, "sb:", "wss:", -1) - uri = strings.Replace(uri, ".net/", ".net:443/$hc/", 1) - uri = uri + "?sb-hc-action=" + action + "&sb-hc-token=" + sas - return uri -} diff --git a/connection_test.go b/connection_test.go deleted file mode 100644 index f42ec4189..000000000 --- a/connection_test.go +++ /dev/null @@ -1,41 +0,0 @@ -package liveshare - -import "testing" - -func TestConnectionValid(t *testing.T) { - conn := Connection{"sess-id", "sess-token", "sas", "endpoint"} - if err := conn.validate(); err != nil { - t.Error(err) - } -} - -func TestConnectionInvalid(t *testing.T) { - conn := Connection{"", "sess-token", "sas", "endpoint"} - if err := conn.validate(); err == nil { - t.Error(err) - } - conn = Connection{"sess-id", "", "sas", "endpoint"} - if err := conn.validate(); err == nil { - t.Error(err) - } - conn = Connection{"sess-id", "sess-token", "", "endpoint"} - if err := conn.validate(); err == nil { - t.Error(err) - } - conn = Connection{"sess-id", "sess-token", "sas", ""} - if err := conn.validate(); err == nil { - t.Error(err) - } - conn = Connection{"", "", "", ""} - if err := conn.validate(); err == nil { - t.Error(err) - } -} - -func TestConnectionURI(t *testing.T) { - conn := Connection{"sess-id", "sess-token", "sas", "sb://endpoint/.net/liveshare"} - uri := conn.uri("connect") - if uri != "wss://endpoint/.net:443/$hc/liveshare?sb-hc-action=connect&sb-hc-token=sas" { - t.Errorf("uri is not correct, got: '%v'", uri) - } -} diff --git a/options_test.go b/options_test.go new file mode 100644 index 000000000..830c59104 --- /dev/null +++ b/options_test.go @@ -0,0 +1,56 @@ +package liveshare + +import ( + "context" + "testing" +) + +func TestBadOptions(t *testing.T) { + goodOptions := Options{ + SessionID: "sess-id", + SessionToken: "sess-token", + RelaySAS: "sas", + RelayEndpoint: "endpoint", + } + + opts := goodOptions + opts.SessionID = "" + checkBadOptions(t, opts) + + opts = goodOptions + opts.SessionToken = "" + checkBadOptions(t, opts) + + opts = goodOptions + opts.RelaySAS = "" + checkBadOptions(t, opts) + + opts = goodOptions + opts.RelayEndpoint = "" + checkBadOptions(t, opts) + + opts = Options{} + checkBadOptions(t, opts) +} + +func checkBadOptions(t *testing.T, opts Options) { + if _, err := Connect(context.Background(), opts); err == nil { + t.Errorf("Connect(%+v): no error", opts) + } +} + +func TestOptionsURI(t *testing.T) { + opts := Options{ + SessionID: "sess-id", + SessionToken: "sess-token", + RelaySAS: "sas", + RelayEndpoint: "sb://endpoint/.net/liveshare", + } + uri, err := opts.uri("connect") + if err != nil { + t.Fatal(err) + } + if uri != "wss://endpoint/.net:443/$hc/liveshare?sb-hc-action=connect&sb-hc-token=sas" { + t.Errorf("uri is not correct, got: '%v'", uri) + } +} diff --git a/session_test.go b/session_test.go index 3be90cb0e..cd0a7b474 100644 --- a/session_test.go +++ b/session_test.go @@ -14,25 +14,23 @@ import ( ) func makeMockSession(opts ...livesharetest.ServerOption) (*livesharetest.Server, *Session, error) { - connection := Connection{ - SessionID: "session-id", - SessionToken: "session-token", - RelaySAS: "relay-sas", - } joinWorkspace := func(req *jsonrpc2.Request) (interface{}, error) { return joinWorkspaceResult{1}, nil } + const sessionToken = "session-token" opts = append( opts, - livesharetest.WithPassword(connection.SessionToken), + livesharetest.WithPassword(sessionToken), livesharetest.WithService("workspace.joinWorkspace", joinWorkspace), ) - testServer, err := livesharetest.NewServer( - opts..., - ) - connection.RelayEndpoint = "sb" + strings.TrimPrefix(testServer.URL(), "https") - tlsConfig := WithTLSConfig(&tls.Config{InsecureSkipVerify: true}) - session, err := Connect(context.Background(), WithConnection(connection), tlsConfig) + testServer, err := livesharetest.NewServer(opts...) + session, err := Connect(context.Background(), Options{ + SessionID: "session-id", + SessionToken: sessionToken, + RelayEndpoint: "sb" + strings.TrimPrefix(testServer.URL(), "https"), + RelaySAS: "relay-sas", + TLSConfig: &tls.Config{InsecureSkipVerify: true}, + }) if err != nil { return nil, nil, fmt.Errorf("error connecting to Live Share: %v", err) } diff --git a/socket.go b/socket.go index 8744eeb96..f66436f65 100644 --- a/socket.go +++ b/socket.go @@ -19,8 +19,8 @@ type socket struct { reader io.Reader } -func newSocket(clientConn Connection, tlsConfig *tls.Config) *socket { - return &socket{addr: clientConn.uri("connect"), tlsConfig: tlsConfig} +func newSocket(uri string, tlsConfig *tls.Config) *socket { + return &socket{addr: uri, tlsConfig: tlsConfig} } func (s *socket) connect(ctx context.Context) error { From 48e3473a953b9502a840bf9ea40fb818dca05f5b Mon Sep 17 00:00:00 2001 From: Jose Garcia Date: Tue, 21 Sep 2021 18:18:30 -0400 Subject: [PATCH 256/290] PR Feedback - Bring context.Timeout into the poller - Accept duration and interval - Other tidy up --- cmd/ghcs/create.go | 31 ++++++++++++------------------- cmd/ghcs/create_test.go | 26 ++++++++------------------ 2 files changed, 20 insertions(+), 37 deletions(-) diff --git a/cmd/ghcs/create.go b/cmd/ghcs/create.go index fd54a170c..a13807094 100644 --- a/cmd/ghcs/create.go +++ b/cmd/ghcs/create.go @@ -84,19 +84,17 @@ func create(opts *createOptions) error { log.Println("Creating your codespace...") - codespace, err := apiClient.CreateCodespace( - ctx, userResult.User, repository, machine, branch, locationResult.Location, - ) + codespace, err := apiClient.CreateCodespace(ctx, userResult.User, repository, machine, branch, locationResult.Location) if err != nil { // This error is returned by the API when the initial creation fails with a retryable error. // A retryable error means that GitHub will retry to re-create Codespace and clients should poll // the API and attempt to fetch the Codespace for the next two minutes. if err == api.ErrCreateAsyncRetry { log.Print("Switching to async provisioning...") - pollctx, cancel := context.WithTimeout(ctx, 2*time.Minute) - defer cancel() - codespace, err = pollForCodespace(pollctx, apiClient, log, userResult.User, codespace) + pollTimeout := 2 * time.Minute + pollInterval := 1 * time.Second + codespace, err = pollForCodespace(ctx, apiClient, log, pollTimeout, pollInterval, userResult.User.Login, codespace.Name) log.Print("\n") if err != nil { @@ -125,13 +123,13 @@ type apiClient interface { GetCodespace(context.Context, string, string, string) (*api.Codespace, error) } -// pollForCodespace polls the Codespaces API every second fetching the codespace. +// pollForCodespace polls the Codespaces GET endpoint on a given interval for a specified duration. // If it succeeds at fetching the codespace, we consider the codespace provisioned. -// Context should be cancelled to stop polling. -func pollForCodespace( - ctx context.Context, client apiClient, log *output.Logger, user *api.User, provisioningCodespace *api.Codespace, -) (*api.Codespace, error) { - ticker := time.NewTicker(1 * time.Second) +func pollForCodespace(ctx context.Context, client apiClient, log *output.Logger, duration, interval time.Duration, user, name string) (*api.Codespace, error) { + ctx, cancel := context.WithTimeout(ctx, duration) + defer cancel() + + ticker := time.NewTicker(interval) defer ticker.Stop() for { @@ -140,18 +138,13 @@ func pollForCodespace( return nil, ctx.Err() case <-ticker.C: log.Print(".") - token, err := client.GetCodespaceToken(ctx, user.Login, provisioningCodespace.Name) + token, err := client.GetCodespaceToken(ctx, user, name) if err != nil { // Do nothing. We expect this to fail until the codespace is provisioned continue } - codespace, err := client.GetCodespace(ctx, token, user.Login, provisioningCodespace.Name) - if err != nil { - return nil, fmt.Errorf("failed to get codespace: %w", err) - } - - return codespace, nil + return client.GetCodespace(ctx, token, user, name) } } } diff --git a/cmd/ghcs/create_test.go b/cmd/ghcs/create_test.go index 3df900afc..e86fa00e6 100644 --- a/cmd/ghcs/create_test.go +++ b/cmd/ghcs/create_test.go @@ -37,18 +37,13 @@ func TestPollForCodespace(t *testing.T) { user := &api.User{Login: "test"} tmpCodespace := &api.Codespace{Name: "tmp-codespace"} codespaceToken := "codespace-token" + ctx := context.Background() - ctxTimeout := 1 * time.Second - exceedTime := 2 * time.Second - exceedProvisioningTime := false + pollInterval := 50 * time.Millisecond + pollTimeout := 100 * time.Millisecond api := &mockAPIClient{ getCodespaceToken: func(ctx context.Context, userLogin, codespace string) (string, error) { - if exceedProvisioningTime { - ticker := time.NewTicker(exceedTime) - defer ticker.Stop() - <-ticker.C - } if userLogin != user.Login { return "", fmt.Errorf("user does not match, got: %s, expected: %s", userLogin, user.Login) } @@ -71,10 +66,7 @@ func TestPollForCodespace(t *testing.T) { }, } - ctx, cancel := context.WithTimeout(context.Background(), ctxTimeout) - defer cancel() - - codespace, err := pollForCodespace(ctx, api, logger, user, tmpCodespace) + codespace, err := pollForCodespace(ctx, api, logger, pollTimeout, pollInterval, user.Login, tmpCodespace.Name) if err != nil { t.Error(err) } @@ -82,12 +74,10 @@ func TestPollForCodespace(t *testing.T) { t.Errorf("returned codespace does not match, got: %s, expected: %s", codespace.Name, tmpCodespace.Name) } - exceedProvisioningTime = true - ctx, cancel = context.WithTimeout(ctx, ctxTimeout) - defer cancel() - - _, err = pollForCodespace(ctx, api, logger, user, tmpCodespace) - if err == nil { + // swap the durations to trigger a timeout + pollTimeout, pollInterval = pollInterval, pollTimeout + _, err = pollForCodespace(ctx, api, logger, pollTimeout, pollInterval, user.Login, tmpCodespace.Name) + if err != context.DeadlineExceeded { t.Error("expected context deadline exceeded error, got nil") } } From 86717f14a1a6c0ba1f9f45f55367863541c69d53 Mon Sep 17 00:00:00 2001 From: Jose Garcia Date: Wed, 22 Sep 2021 09:09:09 -0400 Subject: [PATCH 257/290] Implement codespaces.Provision - Move polling logic into the Provision function - Document the behavior expected of callers when an ErrCreateAsyncRetry is returned --- cmd/ghcs/create.go | 56 ++------------- internal/api/api.go | 4 ++ internal/codespaces/codespaces.go | 68 +++++++++++++++++++ .../codespaces/codespaces_test.go | 11 ++- 4 files changed, 89 insertions(+), 50 deletions(-) rename cmd/ghcs/create_test.go => internal/codespaces/codespaces_test.go (85%) diff --git a/cmd/ghcs/create.go b/cmd/ghcs/create.go index a13807094..91c8bcb8d 100644 --- a/cmd/ghcs/create.go +++ b/cmd/ghcs/create.go @@ -6,7 +6,6 @@ import ( "fmt" "os" "strings" - "time" "github.com/AlecAivazis/survey/v2" "github.com/fatih/camelcase" @@ -84,24 +83,14 @@ func create(opts *createOptions) error { log.Println("Creating your codespace...") - codespace, err := apiClient.CreateCodespace(ctx, userResult.User, repository, machine, branch, locationResult.Location) + codespace, err := codespaces.Provision(ctx, log, apiClient, &codespaces.ProvisionParams{ + User: userResult.User, + Repository: repository, + Branch: branch, + Machine: machine, + Location: locationResult.Location, + }) if err != nil { - // This error is returned by the API when the initial creation fails with a retryable error. - // A retryable error means that GitHub will retry to re-create Codespace and clients should poll - // the API and attempt to fetch the Codespace for the next two minutes. - if err == api.ErrCreateAsyncRetry { - log.Print("Switching to async provisioning...") - - pollTimeout := 2 * time.Minute - pollInterval := 1 * time.Second - codespace, err = pollForCodespace(ctx, apiClient, log, pollTimeout, pollInterval, userResult.User.Login, codespace.Name) - log.Print("\n") - - if err != nil { - return fmt.Errorf("error creating codespace with async provisioning: %s: %w", codespace.Name, err) - } - } - return fmt.Errorf("error creating codespace: %w", err) } @@ -118,37 +107,6 @@ func create(opts *createOptions) error { return nil } -type apiClient interface { - GetCodespaceToken(context.Context, string, string) (string, error) - GetCodespace(context.Context, string, string, string) (*api.Codespace, error) -} - -// pollForCodespace polls the Codespaces GET endpoint on a given interval for a specified duration. -// If it succeeds at fetching the codespace, we consider the codespace provisioned. -func pollForCodespace(ctx context.Context, client apiClient, log *output.Logger, duration, interval time.Duration, user, name string) (*api.Codespace, error) { - ctx, cancel := context.WithTimeout(ctx, duration) - defer cancel() - - ticker := time.NewTicker(interval) - defer ticker.Stop() - - for { - select { - case <-ctx.Done(): - return nil, ctx.Err() - case <-ticker.C: - log.Print(".") - token, err := client.GetCodespaceToken(ctx, user, name) - if err != nil { - // Do nothing. We expect this to fail until the codespace is provisioned - continue - } - - return client.GetCodespace(ctx, token, user, name) - } - } -} - // showStatus polls the codespace for a list of post create states and their status. It will keep polling // until all states have finished. Once all states have finished, we poll once more to check if any new // states have been introduced and stop polling otherwise. diff --git a/internal/api/api.go b/internal/api/api.go index 394efc6af..ad9177eff 100644 --- a/internal/api/api.go +++ b/internal/api/api.go @@ -431,6 +431,10 @@ func (a *API) CreateCodespace(ctx context.Context, user *User, repository *Repos case resp.StatusCode > http.StatusAccepted: return nil, jsonErrorResponse(b) case resp.StatusCode == http.StatusAccepted: + // When the API returns a 202, it means that the initial creation failed but it is + // being retried. For clients this means that they must implement a polling strategy + // to check for the codespace existence for the next two minutes. We return an error + // here so callers can detect and handle this condition. return nil, ErrCreateAsyncRetry } diff --git a/internal/codespaces/codespaces.go b/internal/codespaces/codespaces.go index 2933c9d8d..c67f88c3b 100644 --- a/internal/codespaces/codespaces.go +++ b/internal/codespaces/codespaces.go @@ -75,3 +75,71 @@ func ConnectToLiveshare(ctx context.Context, log logger, apiClient *api.API, use return lsclient.JoinWorkspace(ctx) } + +type apiClient interface { + CreateCodespace(ctx context.Context, user *api.User, repo *api.Repository, machine, branch, location string) (*api.Codespace, error) + GetCodespaceToken(ctx context.Context, userLogin, codespaceName string) (string, error) + GetCodespace(ctx context.Context, token, userLogin, codespaceName string) (*api.Codespace, error) +} + +// ProvisionParams are the required parameters for provisioning a Codespace. +type ProvisionParams struct { + User *api.User + Repository *api.Repository + Branch, Machine, Location string +} + +// Provision creates a codespace with the given parameters and handles polling in the case +// of initial creation failures. +func Provision(ctx context.Context, log logger, client apiClient, params *ProvisionParams) (*api.Codespace, error) { + codespace, err := client.CreateCodespace( + ctx, params.User, params.Repository, params.Machine, params.Branch, params.Location, + ) + if err != nil { + // This error is returned by the API when the initial creation fails with a retryable error. + // A retryable error means that GitHub will retry to re-create Codespace and clients should poll + // the API and attempt to fetch the Codespace for the next two minutes. + if err == api.ErrCreateAsyncRetry { + log.Print("Switching to async provisioning...") + + pollTimeout := 2 * time.Minute + pollInterval := 1 * time.Second + codespace, err = pollForCodespace(ctx, client, log, pollTimeout, pollInterval, params.User.Login, codespace.Name) + log.Print("\n") + + if err != nil { + return nil, fmt.Errorf("error creating codespace with async provisioning: %s: %w", codespace.Name, err) + } + } + + return nil, err + } + + return codespace, nil +} + +// pollForCodespace polls the Codespaces GET endpoint on a given interval for a specified duration. +// If it succeeds at fetching the codespace, we consider the codespace provisioned. +func pollForCodespace(ctx context.Context, client apiClient, log logger, duration, interval time.Duration, user, name string) (*api.Codespace, error) { + ctx, cancel := context.WithTimeout(ctx, duration) + defer cancel() + + ticker := time.NewTicker(interval) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + return nil, ctx.Err() + case <-ticker.C: + log.Print(".") + token, err := client.GetCodespaceToken(ctx, user, name) + if err != nil { + // Do nothing. We expect this to fail until the codespace is provisioned + continue + } + + return client.GetCodespace(ctx, token, user, name) + } + } +} diff --git a/cmd/ghcs/create_test.go b/internal/codespaces/codespaces_test.go similarity index 85% rename from cmd/ghcs/create_test.go rename to internal/codespaces/codespaces_test.go index e86fa00e6..53aba0557 100644 --- a/cmd/ghcs/create_test.go +++ b/internal/codespaces/codespaces_test.go @@ -1,4 +1,4 @@ -package ghcs +package codespaces import ( "context" @@ -12,10 +12,19 @@ import ( ) type mockAPIClient struct { + createCodespace func(context.Context, *api.User, *api.Repository, string, string, string) (*api.Codespace, error) getCodespaceToken func(context.Context, string, string) (string, error) getCodespace func(context.Context, string, string, string) (*api.Codespace, error) } +func (m *mockAPIClient) CreateCodespace(ctx context.Context, user *api.User, repo *api.Repository, machine, branch, location string) (*api.Codespace, error) { + if m.createCodespace == nil { + return nil, errors.New("mock api client CreateCodespace not implemented") + } + + return m.createCodespace(ctx, user, repo, machine, branch, location) +} + func (m *mockAPIClient) GetCodespaceToken(ctx context.Context, userLogin, codespaceName string) (string, error) { if m.getCodespaceToken == nil { return "", errors.New("mock api client GetCodespaceToken not implemented") From 2a0ea1617b3fccd06d29c4e4b81fd6d4b815fa15 Mon Sep 17 00:00:00 2001 From: Jose Garcia Date: Wed, 22 Sep 2021 09:40:45 -0400 Subject: [PATCH 258/290] Handle specific error for GetCodespaceToken --- internal/api/api.go | 7 +++++++ internal/codespaces/codespaces.go | 8 ++++++-- 2 files changed, 13 insertions(+), 2 deletions(-) diff --git a/internal/api/api.go b/internal/api/api.go index ad9177eff..58d01a428 100644 --- a/internal/api/api.go +++ b/internal/api/api.go @@ -208,6 +208,8 @@ type getCodespaceTokenResponse struct { RepositoryToken string `json:"repository_token"` } +var ErrNotProvisioned = errors.New("codespace not provisioned") + func (a *API) GetCodespaceToken(ctx context.Context, ownerLogin, codespaceName string) (string, error) { reqBody, err := json.Marshal(getCodespaceTokenRequest{true}) if err != nil { @@ -236,6 +238,11 @@ func (a *API) GetCodespaceToken(ctx context.Context, ownerLogin, codespaceName s } if resp.StatusCode != http.StatusOK { + + if resp.StatusCode == http.StatusUnprocessableEntity { + return "", ErrNotProvisioned + } + return "", jsonErrorResponse(b) } diff --git a/internal/codespaces/codespaces.go b/internal/codespaces/codespaces.go index c67f88c3b..1d60bcd21 100644 --- a/internal/codespaces/codespaces.go +++ b/internal/codespaces/codespaces.go @@ -135,8 +135,12 @@ func pollForCodespace(ctx context.Context, client apiClient, log logger, duratio log.Print(".") token, err := client.GetCodespaceToken(ctx, user, name) if err != nil { - // Do nothing. We expect this to fail until the codespace is provisioned - continue + if err == api.ErrNotProvisioned { + // Do nothing. We expect this to fail until the codespace is provisioned + continue + } + + return nil, fmt.Errorf("failed to get codespace token: %w", err) } return client.GetCodespace(ctx, token, user, name) From 8c5330d9e9691289c29bd6130efbca265985023c Mon Sep 17 00:00:00 2001 From: Jose Garcia Date: Wed, 22 Sep 2021 10:04:18 -0400 Subject: [PATCH 259/290] Rename error --- internal/api/api.go | 4 ++-- internal/codespaces/codespaces.go | 5 ++--- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/internal/api/api.go b/internal/api/api.go index 58d01a428..6b19b0703 100644 --- a/internal/api/api.go +++ b/internal/api/api.go @@ -409,7 +409,7 @@ type createCodespaceRequest struct { SkuName string `json:"sku_name"` } -var ErrCreateAsyncRetry = errors.New("initial creation failed, retrying async") +var ErrProvisioningInProgress = errors.New("provisioning in progress") func (a *API) CreateCodespace(ctx context.Context, user *User, repository *Repository, sku, branch, location string) (*Codespace, error) { requestBody, err := json.Marshal(createCodespaceRequest{repository.ID, branch, location, sku}) @@ -442,7 +442,7 @@ func (a *API) CreateCodespace(ctx context.Context, user *User, repository *Repos // being retried. For clients this means that they must implement a polling strategy // to check for the codespace existence for the next two minutes. We return an error // here so callers can detect and handle this condition. - return nil, ErrCreateAsyncRetry + return nil, ErrProvisioningInProgress } var response Codespace diff --git a/internal/codespaces/codespaces.go b/internal/codespaces/codespaces.go index 1d60bcd21..8a0e21b3d 100644 --- a/internal/codespaces/codespaces.go +++ b/internal/codespaces/codespaces.go @@ -99,11 +99,10 @@ func Provision(ctx context.Context, log logger, client apiClient, params *Provis // This error is returned by the API when the initial creation fails with a retryable error. // A retryable error means that GitHub will retry to re-create Codespace and clients should poll // the API and attempt to fetch the Codespace for the next two minutes. - if err == api.ErrCreateAsyncRetry { - log.Print("Switching to async provisioning...") - + if err == api.ErrProvisioningInProgress { pollTimeout := 2 * time.Minute pollInterval := 1 * time.Second + log.Print(".") codespace, err = pollForCodespace(ctx, client, log, pollTimeout, pollInterval, params.User.Login, codespace.Name) log.Print("\n") From cb7b535b917ffddc38f12069b32fbce5a4034eb7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mislav=20Marohni=C4=87?= Date: Wed, 22 Sep 2021 16:11:34 +0200 Subject: [PATCH 260/290] Add tests for delete --- cmd/ghcs/delete.go | 10 ++- cmd/ghcs/delete_test.go | 169 +++++++++++++++++++++++++++++++++-- cmd/ghcs/mock_api.go | 180 ++++++++++++++++++++++++++++++++++++++ cmd/ghcs/mock_prompter.go | 73 ++++++++++++++++ 4 files changed, 425 insertions(+), 7 deletions(-) create mode 100644 cmd/ghcs/mock_api.go create mode 100644 cmd/ghcs/mock_prompter.go diff --git a/cmd/ghcs/delete.go b/cmd/ghcs/delete.go index c08344163..3408f08d7 100644 --- a/cmd/ghcs/delete.go +++ b/cmd/ghcs/delete.go @@ -27,10 +27,12 @@ type deleteOptions struct { prompter prompter } +//go:generate moq -fmt goimports -rm -out mock_prompter.go . prompter type prompter interface { Confirm(message string) (bool, error) } +//go:generate moq -fmt goimports -rm -out mock_api.go . apiClient type apiClient interface { GetUser(ctx context.Context) (*api.User, error) ListCodespaces(ctx context.Context, user string) ([]*api.Codespace, error) @@ -57,9 +59,9 @@ func newDeleteCmd() *cobra.Command { }, } - deleteCmd.Flags().StringVarP(&opts.codespaceName, "codespace", "c", "", "Delete codespace by `name`") + deleteCmd.Flags().StringVarP(&opts.codespaceName, "codespace", "c", "", "The `name` of the codespace to delete") deleteCmd.Flags().BoolVar(&opts.deleteAll, "all", false, "Delete all codespaces") - deleteCmd.Flags().StringVarP(&opts.repoFilter, "repo", "r", "", "Delete codespaces for a repository") + deleteCmd.Flags().StringVarP(&opts.repoFilter, "repo", "r", "", "Delete codespaces for a `repository`") deleteCmd.Flags().BoolVarP(&opts.skipConfirm, "force", "f", false, "Skip confirmation for codespaces that contain unsaved changes") deleteCmd.Flags().Uint16Var(&opts.keepDays, "days", 0, "Delete codespaces older than `N` days") @@ -116,6 +118,10 @@ func delete(ctx context.Context, opts deleteOptions) error { codespacesToDelete = append(codespacesToDelete, c) } + if len(codespacesToDelete) == 0 { + return errors.New("no codespaces to delete") + } + g := errgroup.Group{} for _, c := range codespacesToDelete { codespaceName := c.Name diff --git a/cmd/ghcs/delete_test.go b/cmd/ghcs/delete_test.go index 783ad80e6..754254494 100644 --- a/cmd/ghcs/delete_test.go +++ b/cmd/ghcs/delete_test.go @@ -2,28 +2,187 @@ package ghcs import ( "context" + "fmt" + "sort" "testing" + "time" + + "github.com/github/ghcs/internal/api" ) func TestDelete(t *testing.T) { + user := &api.User{Login: "hubot"} + now, _ := time.Parse(time.RFC3339, "2021-09-22T00:00:00Z") + daysAgo := func(n int) string { + return now.Add(time.Hour * -time.Duration(24*n)).Format(time.RFC3339) + } + tests := []struct { - name string - opts deleteOptions - wantErr bool + name string + opts deleteOptions + codespaces []*api.Codespace + confirms map[string]bool + wantErr bool + wantDeleted []string }{ { name: "by name", opts: deleteOptions{ - codespaceName: "foo-bar-123", + codespaceName: "hubot-robawt-abc", }, + codespaces: []*api.Codespace{ + { + Name: "monalisa-spoonknife-123", + }, + { + Name: "hubot-robawt-abc", + }, + }, + wantDeleted: []string{"hubot-robawt-abc"}, + }, + { + name: "by repo", + opts: deleteOptions{ + repoFilter: "monalisa/spoon-knife", + }, + codespaces: []*api.Codespace{ + { + Name: "monalisa-spoonknife-123", + RepositoryNWO: "monalisa/Spoon-Knife", + }, + { + Name: "hubot-robawt-abc", + RepositoryNWO: "hubot/ROBAWT", + }, + { + Name: "monalisa-spoonknife-c4f3", + RepositoryNWO: "monalisa/Spoon-Knife", + }, + }, + wantDeleted: []string{"monalisa-spoonknife-123", "monalisa-spoonknife-c4f3"}, + }, + { + name: "unused", + opts: deleteOptions{ + deleteAll: true, + keepDays: 3, + }, + codespaces: []*api.Codespace{ + { + Name: "monalisa-spoonknife-123", + LastUsedAt: daysAgo(1), + }, + { + Name: "hubot-robawt-abc", + LastUsedAt: daysAgo(4), + }, + { + Name: "monalisa-spoonknife-c4f3", + LastUsedAt: daysAgo(10), + }, + }, + wantDeleted: []string{"hubot-robawt-abc", "monalisa-spoonknife-c4f3"}, + }, + { + name: "with confirm", + opts: deleteOptions{ + isInteractive: true, + deleteAll: true, + skipConfirm: false, + }, + codespaces: []*api.Codespace{ + { + Name: "monalisa-spoonknife-123", + Environment: api.CodespaceEnvironment{ + GitStatus: api.CodespaceEnvironmentGitStatus{ + HasUnpushedChanges: true, + }, + }, + }, + { + Name: "hubot-robawt-abc", + Environment: api.CodespaceEnvironment{ + GitStatus: api.CodespaceEnvironmentGitStatus{ + HasUncommitedChanges: true, + }, + }, + }, + { + Name: "monalisa-spoonknife-c4f3", + Environment: api.CodespaceEnvironment{ + GitStatus: api.CodespaceEnvironmentGitStatus{ + HasUnpushedChanges: false, + HasUncommitedChanges: false, + }, + }, + }, + }, + confirms: map[string]bool{ + "Codespace monalisa-spoonknife-123 has unsaved changes. OK to delete?": false, + "Codespace hubot-robawt-abc has unsaved changes. OK to delete?": true, + }, + wantDeleted: []string{"hubot-robawt-abc", "monalisa-spoonknife-c4f3"}, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - err := delete(context.Background(), tt.opts) + apiMock := &apiClientMock{ + GetUserFunc: func(_ context.Context) (*api.User, error) { + return user, nil + }, + ListCodespacesFunc: func(_ context.Context, userLogin string) ([]*api.Codespace, error) { + if userLogin != user.Login { + return nil, fmt.Errorf("unexpected user %q", userLogin) + } + return tt.codespaces, nil + }, + DeleteCodespaceFunc: func(_ context.Context, userLogin, name string) error { + if userLogin != user.Login { + return fmt.Errorf("unexpected user %q", userLogin) + } + return nil + }, + } + opts := tt.opts + opts.apiClient = apiMock + opts.now = func() time.Time { return now } + opts.prompter = &prompterMock{ + ConfirmFunc: func(msg string) (bool, error) { + res, found := tt.confirms[msg] + if !found { + return false, fmt.Errorf("unexpected prompt %q", msg) + } + return res, nil + }, + } + + err := delete(context.Background(), opts) if (err != nil) != tt.wantErr { t.Errorf("delete() error = %v, wantErr %v", err, tt.wantErr) } + if n := len(apiMock.GetUserCalls()); n != 1 { + t.Errorf("GetUser invoked %d times, expected %d", n, 1) + } + var gotDeleted []string + for _, delArgs := range apiMock.DeleteCodespaceCalls() { + gotDeleted = append(gotDeleted, delArgs.Name) + } + sort.Strings(gotDeleted) + if !sliceEquals(gotDeleted, tt.wantDeleted) { + t.Errorf("deleted %q, want %q", gotDeleted, tt.wantDeleted) + } }) } } + +func sliceEquals(a, b []string) bool { + if len(a) != len(b) { + return false + } + for i := range a { + if a[i] != b[i] { + return false + } + } + return true +} diff --git a/cmd/ghcs/mock_api.go b/cmd/ghcs/mock_api.go new file mode 100644 index 000000000..46edd2835 --- /dev/null +++ b/cmd/ghcs/mock_api.go @@ -0,0 +1,180 @@ +// Code generated by moq; DO NOT EDIT. +// github.com/matryer/moq + +package ghcs + +import ( + "context" + "sync" + + "github.com/github/ghcs/internal/api" +) + +// Ensure, that apiClientMock does implement apiClient. +// If this is not the case, regenerate this file with moq. +var _ apiClient = &apiClientMock{} + +// apiClientMock is a mock implementation of apiClient. +// +// func TestSomethingThatUsesapiClient(t *testing.T) { +// +// // make and configure a mocked apiClient +// mockedapiClient := &apiClientMock{ +// DeleteCodespaceFunc: func(ctx context.Context, user string, name string) error { +// panic("mock out the DeleteCodespace method") +// }, +// GetUserFunc: func(ctx context.Context) (*api.User, error) { +// panic("mock out the GetUser method") +// }, +// ListCodespacesFunc: func(ctx context.Context, user string) ([]*api.Codespace, error) { +// panic("mock out the ListCodespaces method") +// }, +// } +// +// // use mockedapiClient in code that requires apiClient +// // and then make assertions. +// +// } +type apiClientMock struct { + // DeleteCodespaceFunc mocks the DeleteCodespace method. + DeleteCodespaceFunc func(ctx context.Context, user string, name string) error + + // GetUserFunc mocks the GetUser method. + GetUserFunc func(ctx context.Context) (*api.User, error) + + // ListCodespacesFunc mocks the ListCodespaces method. + ListCodespacesFunc func(ctx context.Context, user string) ([]*api.Codespace, error) + + // calls tracks calls to the methods. + calls struct { + // DeleteCodespace holds details about calls to the DeleteCodespace method. + DeleteCodespace []struct { + // Ctx is the ctx argument value. + Ctx context.Context + // User is the user argument value. + User string + // Name is the name argument value. + Name string + } + // GetUser holds details about calls to the GetUser method. + GetUser []struct { + // Ctx is the ctx argument value. + Ctx context.Context + } + // ListCodespaces holds details about calls to the ListCodespaces method. + ListCodespaces []struct { + // Ctx is the ctx argument value. + Ctx context.Context + // User is the user argument value. + User string + } + } + lockDeleteCodespace sync.RWMutex + lockGetUser sync.RWMutex + lockListCodespaces sync.RWMutex +} + +// DeleteCodespace calls DeleteCodespaceFunc. +func (mock *apiClientMock) DeleteCodespace(ctx context.Context, user string, name string) error { + if mock.DeleteCodespaceFunc == nil { + panic("apiClientMock.DeleteCodespaceFunc: method is nil but apiClient.DeleteCodespace was just called") + } + callInfo := struct { + Ctx context.Context + User string + Name string + }{ + Ctx: ctx, + User: user, + Name: name, + } + mock.lockDeleteCodespace.Lock() + mock.calls.DeleteCodespace = append(mock.calls.DeleteCodespace, callInfo) + mock.lockDeleteCodespace.Unlock() + return mock.DeleteCodespaceFunc(ctx, user, name) +} + +// DeleteCodespaceCalls gets all the calls that were made to DeleteCodespace. +// Check the length with: +// len(mockedapiClient.DeleteCodespaceCalls()) +func (mock *apiClientMock) DeleteCodespaceCalls() []struct { + Ctx context.Context + User string + Name string +} { + var calls []struct { + Ctx context.Context + User string + Name string + } + mock.lockDeleteCodespace.RLock() + calls = mock.calls.DeleteCodespace + mock.lockDeleteCodespace.RUnlock() + return calls +} + +// GetUser calls GetUserFunc. +func (mock *apiClientMock) GetUser(ctx context.Context) (*api.User, error) { + if mock.GetUserFunc == nil { + panic("apiClientMock.GetUserFunc: method is nil but apiClient.GetUser was just called") + } + callInfo := struct { + Ctx context.Context + }{ + Ctx: ctx, + } + mock.lockGetUser.Lock() + mock.calls.GetUser = append(mock.calls.GetUser, callInfo) + mock.lockGetUser.Unlock() + return mock.GetUserFunc(ctx) +} + +// GetUserCalls gets all the calls that were made to GetUser. +// Check the length with: +// len(mockedapiClient.GetUserCalls()) +func (mock *apiClientMock) GetUserCalls() []struct { + Ctx context.Context +} { + var calls []struct { + Ctx context.Context + } + mock.lockGetUser.RLock() + calls = mock.calls.GetUser + mock.lockGetUser.RUnlock() + return calls +} + +// ListCodespaces calls ListCodespacesFunc. +func (mock *apiClientMock) ListCodespaces(ctx context.Context, user string) ([]*api.Codespace, error) { + if mock.ListCodespacesFunc == nil { + panic("apiClientMock.ListCodespacesFunc: method is nil but apiClient.ListCodespaces was just called") + } + callInfo := struct { + Ctx context.Context + User string + }{ + Ctx: ctx, + User: user, + } + mock.lockListCodespaces.Lock() + mock.calls.ListCodespaces = append(mock.calls.ListCodespaces, callInfo) + mock.lockListCodespaces.Unlock() + return mock.ListCodespacesFunc(ctx, user) +} + +// ListCodespacesCalls gets all the calls that were made to ListCodespaces. +// Check the length with: +// len(mockedapiClient.ListCodespacesCalls()) +func (mock *apiClientMock) ListCodespacesCalls() []struct { + Ctx context.Context + User string +} { + var calls []struct { + Ctx context.Context + User string + } + mock.lockListCodespaces.RLock() + calls = mock.calls.ListCodespaces + mock.lockListCodespaces.RUnlock() + return calls +} diff --git a/cmd/ghcs/mock_prompter.go b/cmd/ghcs/mock_prompter.go new file mode 100644 index 000000000..e15209c03 --- /dev/null +++ b/cmd/ghcs/mock_prompter.go @@ -0,0 +1,73 @@ +// Code generated by moq; DO NOT EDIT. +// github.com/matryer/moq + +package ghcs + +import ( + "sync" +) + +// Ensure, that prompterMock does implement prompter. +// If this is not the case, regenerate this file with moq. +var _ prompter = &prompterMock{} + +// prompterMock is a mock implementation of prompter. +// +// func TestSomethingThatUsesprompter(t *testing.T) { +// +// // make and configure a mocked prompter +// mockedprompter := &prompterMock{ +// ConfirmFunc: func(message string) (bool, error) { +// panic("mock out the Confirm method") +// }, +// } +// +// // use mockedprompter in code that requires prompter +// // and then make assertions. +// +// } +type prompterMock struct { + // ConfirmFunc mocks the Confirm method. + ConfirmFunc func(message string) (bool, error) + + // calls tracks calls to the methods. + calls struct { + // Confirm holds details about calls to the Confirm method. + Confirm []struct { + // Message is the message argument value. + Message string + } + } + lockConfirm sync.RWMutex +} + +// Confirm calls ConfirmFunc. +func (mock *prompterMock) Confirm(message string) (bool, error) { + if mock.ConfirmFunc == nil { + panic("prompterMock.ConfirmFunc: method is nil but prompter.Confirm was just called") + } + callInfo := struct { + Message string + }{ + Message: message, + } + mock.lockConfirm.Lock() + mock.calls.Confirm = append(mock.calls.Confirm, callInfo) + mock.lockConfirm.Unlock() + return mock.ConfirmFunc(message) +} + +// ConfirmCalls gets all the calls that were made to Confirm. +// Check the length with: +// len(mockedprompter.ConfirmCalls()) +func (mock *prompterMock) ConfirmCalls() []struct { + Message string +} { + var calls []struct { + Message string + } + mock.lockConfirm.RLock() + calls = mock.calls.Confirm + mock.lockConfirm.RUnlock() + return calls +} From 32d3a38465ef15e8e7b305dccfef31dbc05c07f1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mislav=20Marohni=C4=87?= Date: Wed, 22 Sep 2021 16:39:50 +0200 Subject: [PATCH 261/290] Name of the codespace --- cmd/ghcs/delete.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cmd/ghcs/delete.go b/cmd/ghcs/delete.go index 3408f08d7..4f1c71842 100644 --- a/cmd/ghcs/delete.go +++ b/cmd/ghcs/delete.go @@ -59,7 +59,7 @@ func newDeleteCmd() *cobra.Command { }, } - deleteCmd.Flags().StringVarP(&opts.codespaceName, "codespace", "c", "", "The `name` of the codespace to delete") + deleteCmd.Flags().StringVarP(&opts.codespaceName, "codespace", "c", "", "Name of the codespace") deleteCmd.Flags().BoolVar(&opts.deleteAll, "all", false, "Delete all codespaces") deleteCmd.Flags().StringVarP(&opts.repoFilter, "repo", "r", "", "Delete codespaces for a `repository`") deleteCmd.Flags().BoolVarP(&opts.skipConfirm, "force", "f", false, "Skip confirmation for codespaces that contain unsaved changes") From d2d21996bc1a12a24f3b757e2fbc2ae933aa8a5e Mon Sep 17 00:00:00 2001 From: Jose Garcia Date: Wed, 22 Sep 2021 11:49:41 -0400 Subject: [PATCH 262/290] Move ProvisionCodespace to API client - Make CreateCodespace private along with its errors --- cmd/ghcs/create.go | 2 +- internal/api/api.go | 82 ++++++++++++++++++- .../codespaces_test.go => api/api_test.go} | 22 ++--- internal/codespaces/codespaces.go | 71 ---------------- 4 files changed, 86 insertions(+), 91 deletions(-) rename internal/{codespaces/codespaces_test.go => api/api_test.go} (77%) diff --git a/cmd/ghcs/create.go b/cmd/ghcs/create.go index 91c8bcb8d..eb6d1bea6 100644 --- a/cmd/ghcs/create.go +++ b/cmd/ghcs/create.go @@ -83,7 +83,7 @@ func create(opts *createOptions) error { log.Println("Creating your codespace...") - codespace, err := codespaces.Provision(ctx, log, apiClient, &codespaces.ProvisionParams{ + codespace, err := apiClient.ProvisionCodespace(ctx, log, &api.ProvisionCodespaceParams{ User: userResult.User, Repository: repository, Branch: branch, diff --git a/internal/api/api.go b/internal/api/api.go index 6b19b0703..c3ad0aadc 100644 --- a/internal/api/api.go +++ b/internal/api/api.go @@ -36,6 +36,7 @@ import ( "net/http" "strconv" "strings" + "time" "github.com/opentracing/opentracing-go" ) @@ -402,6 +403,81 @@ func (a *API) GetCodespacesSKUs(ctx context.Context, user *User, repository *Rep return response.SKUs, nil } +// ProvisionCodespaceParams are the required parameters for provisioning a Codespace. +type ProvisionCodespaceParams struct { + User *User + Repository *Repository + Branch, Machine, Location string +} + +type logger interface { + Print(v ...interface{}) (int, error) + Println(v ...interface{}) (int, error) +} + +// ProvisionCodespace creates a codespace with the given parameters and handles polling in the case +// of initial creation failures. +func (a *API) ProvisionCodespace(ctx context.Context, log logger, params *ProvisionCodespaceParams) (*Codespace, error) { + codespace, err := a.createCodespace( + ctx, params.User, params.Repository, params.Machine, params.Branch, params.Location, + ) + if err != nil { + // This error is returned by the API when the initial creation fails with a retryable error. + // A retryable error means that GitHub will retry to re-create Codespace and clients should poll + // the API and attempt to fetch the Codespace for the next two minutes. + if err == errProvisioningInProgress { + pollTimeout := 2 * time.Minute + pollInterval := 1 * time.Second + log.Print(".") + codespace, err = pollForCodespace(ctx, a, log, pollTimeout, pollInterval, params.User.Login, codespace.Name) + log.Print("\n") + + if err != nil { + return nil, fmt.Errorf("error creating codespace with async provisioning: %s: %w", codespace.Name, err) + } + } + + return nil, err + } + + return codespace, nil +} + +type apiClient interface { + GetCodespaceToken(ctx context.Context, userLogin, codespaceName string) (string, error) + GetCodespace(ctx context.Context, token, userLogin, codespaceName string) (*Codespace, error) +} + +// pollForCodespace polls the Codespaces GET endpoint on a given interval for a specified duration. +// If it succeeds at fetching the codespace, we consider the codespace provisioned. +func pollForCodespace(ctx context.Context, client apiClient, log logger, duration, interval time.Duration, user, name string) (*Codespace, error) { + ctx, cancel := context.WithTimeout(ctx, duration) + defer cancel() + + ticker := time.NewTicker(interval) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + return nil, ctx.Err() + case <-ticker.C: + log.Print(".") + token, err := client.GetCodespaceToken(ctx, user, name) + if err != nil { + if err == ErrNotProvisioned { + // Do nothing. We expect this to fail until the codespace is provisioned + continue + } + + return nil, fmt.Errorf("failed to get codespace token: %w", err) + } + + return client.GetCodespace(ctx, token, user, name) + } + } +} + type createCodespaceRequest struct { RepositoryID int `json:"repository_id"` Ref string `json:"ref"` @@ -409,9 +485,9 @@ type createCodespaceRequest struct { SkuName string `json:"sku_name"` } -var ErrProvisioningInProgress = errors.New("provisioning in progress") +var errProvisioningInProgress = errors.New("provisioning in progress") -func (a *API) CreateCodespace(ctx context.Context, user *User, repository *Repository, sku, branch, location string) (*Codespace, error) { +func (a *API) createCodespace(ctx context.Context, user *User, repository *Repository, sku, branch, location string) (*Codespace, error) { requestBody, err := json.Marshal(createCodespaceRequest{repository.ID, branch, location, sku}) if err != nil { return nil, fmt.Errorf("error marshaling request: %w", err) @@ -442,7 +518,7 @@ func (a *API) CreateCodespace(ctx context.Context, user *User, repository *Repos // being retried. For clients this means that they must implement a polling strategy // to check for the codespace existence for the next two minutes. We return an error // here so callers can detect and handle this condition. - return nil, ErrProvisioningInProgress + return nil, errProvisioningInProgress } var response Codespace diff --git a/internal/codespaces/codespaces_test.go b/internal/api/api_test.go similarity index 77% rename from internal/codespaces/codespaces_test.go rename to internal/api/api_test.go index 53aba0557..eb5226a59 100644 --- a/internal/codespaces/codespaces_test.go +++ b/internal/api/api_test.go @@ -1,4 +1,4 @@ -package codespaces +package api import ( "context" @@ -8,21 +8,11 @@ import ( "time" "github.com/github/ghcs/cmd/ghcs/output" - "github.com/github/ghcs/internal/api" ) type mockAPIClient struct { - createCodespace func(context.Context, *api.User, *api.Repository, string, string, string) (*api.Codespace, error) getCodespaceToken func(context.Context, string, string) (string, error) - getCodespace func(context.Context, string, string, string) (*api.Codespace, error) -} - -func (m *mockAPIClient) CreateCodespace(ctx context.Context, user *api.User, repo *api.Repository, machine, branch, location string) (*api.Codespace, error) { - if m.createCodespace == nil { - return nil, errors.New("mock api client CreateCodespace not implemented") - } - - return m.createCodespace(ctx, user, repo, machine, branch, location) + getCodespace func(context.Context, string, string, string) (*Codespace, error) } func (m *mockAPIClient) GetCodespaceToken(ctx context.Context, userLogin, codespaceName string) (string, error) { @@ -33,7 +23,7 @@ func (m *mockAPIClient) GetCodespaceToken(ctx context.Context, userLogin, codesp return m.getCodespaceToken(ctx, userLogin, codespaceName) } -func (m *mockAPIClient) GetCodespace(ctx context.Context, token, userLogin, codespaceName string) (*api.Codespace, error) { +func (m *mockAPIClient) GetCodespace(ctx context.Context, token, userLogin, codespaceName string) (*Codespace, error) { if m.getCodespace == nil { return nil, errors.New("mock api client GetCodespace not implemented") } @@ -43,8 +33,8 @@ func (m *mockAPIClient) GetCodespace(ctx context.Context, token, userLogin, code func TestPollForCodespace(t *testing.T) { logger := output.NewLogger(nil, nil, false) - user := &api.User{Login: "test"} - tmpCodespace := &api.Codespace{Name: "tmp-codespace"} + user := &User{Login: "test"} + tmpCodespace := &Codespace{Name: "tmp-codespace"} codespaceToken := "codespace-token" ctx := context.Background() @@ -61,7 +51,7 @@ func TestPollForCodespace(t *testing.T) { } return codespaceToken, nil }, - getCodespace: func(ctx context.Context, token, userLogin, codespace string) (*api.Codespace, error) { + getCodespace: func(ctx context.Context, token, userLogin, codespace string) (*Codespace, error) { if token != codespaceToken { return nil, fmt.Errorf("token does not match, got: %s, expected: %s", token, codespaceToken) } diff --git a/internal/codespaces/codespaces.go b/internal/codespaces/codespaces.go index 8a0e21b3d..2933c9d8d 100644 --- a/internal/codespaces/codespaces.go +++ b/internal/codespaces/codespaces.go @@ -75,74 +75,3 @@ func ConnectToLiveshare(ctx context.Context, log logger, apiClient *api.API, use return lsclient.JoinWorkspace(ctx) } - -type apiClient interface { - CreateCodespace(ctx context.Context, user *api.User, repo *api.Repository, machine, branch, location string) (*api.Codespace, error) - GetCodespaceToken(ctx context.Context, userLogin, codespaceName string) (string, error) - GetCodespace(ctx context.Context, token, userLogin, codespaceName string) (*api.Codespace, error) -} - -// ProvisionParams are the required parameters for provisioning a Codespace. -type ProvisionParams struct { - User *api.User - Repository *api.Repository - Branch, Machine, Location string -} - -// Provision creates a codespace with the given parameters and handles polling in the case -// of initial creation failures. -func Provision(ctx context.Context, log logger, client apiClient, params *ProvisionParams) (*api.Codespace, error) { - codespace, err := client.CreateCodespace( - ctx, params.User, params.Repository, params.Machine, params.Branch, params.Location, - ) - if err != nil { - // This error is returned by the API when the initial creation fails with a retryable error. - // A retryable error means that GitHub will retry to re-create Codespace and clients should poll - // the API and attempt to fetch the Codespace for the next two minutes. - if err == api.ErrProvisioningInProgress { - pollTimeout := 2 * time.Minute - pollInterval := 1 * time.Second - log.Print(".") - codespace, err = pollForCodespace(ctx, client, log, pollTimeout, pollInterval, params.User.Login, codespace.Name) - log.Print("\n") - - if err != nil { - return nil, fmt.Errorf("error creating codespace with async provisioning: %s: %w", codespace.Name, err) - } - } - - return nil, err - } - - return codespace, nil -} - -// pollForCodespace polls the Codespaces GET endpoint on a given interval for a specified duration. -// If it succeeds at fetching the codespace, we consider the codespace provisioned. -func pollForCodespace(ctx context.Context, client apiClient, log logger, duration, interval time.Duration, user, name string) (*api.Codespace, error) { - ctx, cancel := context.WithTimeout(ctx, duration) - defer cancel() - - ticker := time.NewTicker(interval) - defer ticker.Stop() - - for { - select { - case <-ctx.Done(): - return nil, ctx.Err() - case <-ticker.C: - log.Print(".") - token, err := client.GetCodespaceToken(ctx, user, name) - if err != nil { - if err == api.ErrNotProvisioned { - // Do nothing. We expect this to fail until the codespace is provisioned - continue - } - - return nil, fmt.Errorf("failed to get codespace token: %w", err) - } - - return client.GetCodespace(ctx, token, user, name) - } - } -} From 70a2ea2e6aaf36cd8e9206adab640892c7892d0d Mon Sep 17 00:00:00 2001 From: Jose Garcia Date: Wed, 22 Sep 2021 13:19:26 -0400 Subject: [PATCH 263/290] PR Feedback - Rename ProvisionCodespace -> CreateCodespace - Rename createCodespace -> startCreate - Additional docs/comments - Simplify ProvisionCodespaceParams --- cmd/ghcs/create.go | 12 ++++++------ internal/api/api.go | 39 +++++++++++++++++++-------------------- 2 files changed, 25 insertions(+), 26 deletions(-) diff --git a/cmd/ghcs/create.go b/cmd/ghcs/create.go index eb6d1bea6..c0943549c 100644 --- a/cmd/ghcs/create.go +++ b/cmd/ghcs/create.go @@ -83,12 +83,12 @@ func create(opts *createOptions) error { log.Println("Creating your codespace...") - codespace, err := apiClient.ProvisionCodespace(ctx, log, &api.ProvisionCodespaceParams{ - User: userResult.User, - Repository: repository, - Branch: branch, - Machine: machine, - Location: locationResult.Location, + codespace, err := apiClient.CreateCodespace(ctx, log, &api.ProvisionCodespaceParams{ + User: userResult.User, + RepositoryID: repository, + Branch: branch, + Machine: machine, + Location: locationResult.Location, }) if err != nil { return fmt.Errorf("error creating codespace: %w", err) diff --git a/internal/api/api.go b/internal/api/api.go index c3ad0aadc..a1e580e0f 100644 --- a/internal/api/api.go +++ b/internal/api/api.go @@ -239,7 +239,6 @@ func (a *API) GetCodespaceToken(ctx context.Context, ownerLogin, codespaceName s } if resp.StatusCode != http.StatusOK { - if resp.StatusCode == http.StatusUnprocessableEntity { return "", ErrNotProvisioned } @@ -405,8 +404,8 @@ func (a *API) GetCodespacesSKUs(ctx context.Context, user *User, repository *Rep // ProvisionCodespaceParams are the required parameters for provisioning a Codespace. type ProvisionCodespaceParams struct { - User *User - Repository *Repository + User string + RepositoryID int Branch, Machine, Location string } @@ -415,21 +414,21 @@ type logger interface { Println(v ...interface{}) (int, error) } -// ProvisionCodespace creates a codespace with the given parameters and handles polling in the case -// of initial creation failures. -func (a *API) ProvisionCodespace(ctx context.Context, log logger, params *ProvisionCodespaceParams) (*Codespace, error) { - codespace, err := a.createCodespace( - ctx, params.User, params.Repository, params.Machine, params.Branch, params.Location, +// CreateCodespace creates a codespace with the given parameters and returns a non-nil error if it +// fails to create. +func (a *API) CreateCodespace(ctx context.Context, log logger, params *ProvisionCodespaceParams) (*Codespace, error) { + codespace, err := a.startCreate( + ctx, params.User, params.RepositoryID, params.Machine, params.Branch, params.Location, ) if err != nil { - // This error is returned by the API when the initial creation fails with a retryable error. - // A retryable error means that GitHub will retry to re-create Codespace and clients should poll - // the API and attempt to fetch the Codespace for the next two minutes. + // errProvisioningInProgress indicates that codespace creation did not complete + // within the GitHub API RPC time limit (10s), so it continues asynchronously. + // We must poll the server to discover the outcome. if err == errProvisioningInProgress { pollTimeout := 2 * time.Minute pollInterval := 1 * time.Second log.Print(".") - codespace, err = pollForCodespace(ctx, a, log, pollTimeout, pollInterval, params.User.Login, codespace.Name) + codespace, err = pollForCodespace(ctx, a, log, pollTimeout, pollInterval, params.User, codespace.Name) log.Print("\n") if err != nil { @@ -487,13 +486,17 @@ type createCodespaceRequest struct { var errProvisioningInProgress = errors.New("provisioning in progress") -func (a *API) createCodespace(ctx context.Context, user *User, repository *Repository, sku, branch, location string) (*Codespace, error) { - requestBody, err := json.Marshal(createCodespaceRequest{repository.ID, branch, location, sku}) +// startCreate starts the creation of a codespace. +// It may return success or an error, or errProvisioningInProgress indicating that the operation +// did not complete before the GitHub API's time limit for RPCs (10s), in which case the caller +// must poll the server to learn the outcome. +func (a *API) startCreate(ctx context.Context, user string, repository int, sku, branch, location string) (*Codespace, error) { + requestBody, err := json.Marshal(createCodespaceRequest{repository, branch, location, sku}) if err != nil { return nil, fmt.Errorf("error marshaling request: %w", err) } - req, err := http.NewRequest(http.MethodPost, githubAPI+"/vscs_internal/user/"+user.Login+"/codespaces", bytes.NewBuffer(requestBody)) + req, err := http.NewRequest(http.MethodPost, githubAPI+"/vscs_internal/user/"+user+"/codespaces", bytes.NewBuffer(requestBody)) if err != nil { return nil, fmt.Errorf("error creating request: %w", err) } @@ -514,11 +517,7 @@ func (a *API) createCodespace(ctx context.Context, user *User, repository *Repos case resp.StatusCode > http.StatusAccepted: return nil, jsonErrorResponse(b) case resp.StatusCode == http.StatusAccepted: - // When the API returns a 202, it means that the initial creation failed but it is - // being retried. For clients this means that they must implement a polling strategy - // to check for the codespace existence for the next two minutes. We return an error - // here so callers can detect and handle this condition. - return nil, errProvisioningInProgress + return nil, errProvisioningInProgress // RPC finished before result of creation known } var response Codespace From 208f1721b5a29834c9f6420b765b95dd41ce7020 Mon Sep 17 00:00:00 2001 From: Jose Garcia Date: Wed, 22 Sep 2021 13:21:02 -0400 Subject: [PATCH 264/290] Rename ProvisionCodespaceParams --- cmd/ghcs/create.go | 6 +++--- internal/api/api.go | 6 +++--- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/cmd/ghcs/create.go b/cmd/ghcs/create.go index c0943549c..e37c1a200 100644 --- a/cmd/ghcs/create.go +++ b/cmd/ghcs/create.go @@ -83,9 +83,9 @@ func create(opts *createOptions) error { log.Println("Creating your codespace...") - codespace, err := apiClient.CreateCodespace(ctx, log, &api.ProvisionCodespaceParams{ - User: userResult.User, - RepositoryID: repository, + codespace, err := apiClient.CreateCodespace(ctx, log, &api.CreateCodespaceParams{ + User: userResult.User.Login, + RepositoryID: repository.ID, Branch: branch, Machine: machine, Location: locationResult.Location, diff --git a/internal/api/api.go b/internal/api/api.go index a1e580e0f..273c64435 100644 --- a/internal/api/api.go +++ b/internal/api/api.go @@ -402,8 +402,8 @@ func (a *API) GetCodespacesSKUs(ctx context.Context, user *User, repository *Rep return response.SKUs, nil } -// ProvisionCodespaceParams are the required parameters for provisioning a Codespace. -type ProvisionCodespaceParams struct { +// CreateCodespaceParams are the required parameters for provisioning a Codespace. +type CreateCodespaceParams struct { User string RepositoryID int Branch, Machine, Location string @@ -416,7 +416,7 @@ type logger interface { // CreateCodespace creates a codespace with the given parameters and returns a non-nil error if it // fails to create. -func (a *API) CreateCodespace(ctx context.Context, log logger, params *ProvisionCodespaceParams) (*Codespace, error) { +func (a *API) CreateCodespace(ctx context.Context, log logger, params *CreateCodespaceParams) (*Codespace, error) { codespace, err := a.startCreate( ctx, params.User, params.RepositoryID, params.Machine, params.Branch, params.Location, ) From a55f7af92c5e35491ed002e26b7105caf6d1fa5f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mislav=20Marohni=C4=87?= Date: Wed, 22 Sep 2021 19:36:25 +0200 Subject: [PATCH 265/290] Correct wrong args constraints --- cmd/ghcs/code.go | 2 +- cmd/ghcs/logs.go | 2 +- cmd/ghcs/ports.go | 4 ++-- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/cmd/ghcs/code.go b/cmd/ghcs/code.go index 9f09438d5..4a4259e42 100644 --- a/cmd/ghcs/code.go +++ b/cmd/ghcs/code.go @@ -19,7 +19,7 @@ func newCodeCmd() *cobra.Command { codeCmd := &cobra.Command{ Use: "code", Short: "Open a codespace in VS Code", - Args: cobra.MaximumNArgs(1), + Args: cobra.NoArgs, RunE: func(cmd *cobra.Command, args []string) error { return code(codespace, useInsiders) }, diff --git a/cmd/ghcs/logs.go b/cmd/ghcs/logs.go index 514c36966..74a763a72 100644 --- a/cmd/ghcs/logs.go +++ b/cmd/ghcs/logs.go @@ -24,7 +24,7 @@ func newLogsCmd() *cobra.Command { logsCmd := &cobra.Command{ Use: "logs", Short: "Access codespace logs", - Args: cobra.MaximumNArgs(1), + Args: cobra.NoArgs, RunE: func(cmd *cobra.Command, args []string) error { return logs(context.Background(), log, codespace, follow) }, diff --git a/cmd/ghcs/ports.go b/cmd/ghcs/ports.go index aeecf0a07..2e93d78ae 100644 --- a/cmd/ghcs/ports.go +++ b/cmd/ghcs/ports.go @@ -158,7 +158,7 @@ func newPortsPublicCmd() *cobra.Command { return &cobra.Command{ Use: "public ", Short: "Mark port as public", - Args: cobra.MinimumNArgs(1), + Args: cobra.ExactArgs(1), RunE: func(cmd *cobra.Command, args []string) error { codespace, err := cmd.Flags().GetString("codespace") if err != nil { @@ -179,7 +179,7 @@ func newPortsPrivateCmd() *cobra.Command { return &cobra.Command{ Use: "private ", Short: "Mark port as private", - Args: cobra.MinimumNArgs(1), + Args: cobra.ExactArgs(1), RunE: func(cmd *cobra.Command, args []string) error { codespace, err := cmd.Flags().GetString("codespace") if err != nil { From 7a91ba5942f6535ce840312594ec5fcc630be5d8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mislav=20Marohni=C4=87?= Date: Wed, 22 Sep 2021 19:51:12 +0200 Subject: [PATCH 266/290] Print usage help when args given to "NoArgs" commands --- cmd/ghcs/code.go | 2 +- cmd/ghcs/common.go | 10 ++++++++++ cmd/ghcs/create.go | 2 +- cmd/ghcs/delete.go | 2 +- cmd/ghcs/list.go | 2 +- cmd/ghcs/logs.go | 2 +- cmd/ghcs/main/main.go | 11 ++++++++--- cmd/ghcs/ports.go | 2 +- 8 files changed, 24 insertions(+), 9 deletions(-) diff --git a/cmd/ghcs/code.go b/cmd/ghcs/code.go index 4a4259e42..08d2cff1a 100644 --- a/cmd/ghcs/code.go +++ b/cmd/ghcs/code.go @@ -19,7 +19,7 @@ func newCodeCmd() *cobra.Command { codeCmd := &cobra.Command{ Use: "code", Short: "Open a codespace in VS Code", - Args: cobra.NoArgs, + Args: noArgsConstraint, RunE: func(cmd *cobra.Command, args []string) error { return code(codespace, useInsiders) }, diff --git a/cmd/ghcs/common.go b/cmd/ghcs/common.go index 4ebc89a2d..371ca30b8 100644 --- a/cmd/ghcs/common.go +++ b/cmd/ghcs/common.go @@ -13,6 +13,7 @@ import ( "github.com/AlecAivazis/survey/v2" "github.com/AlecAivazis/survey/v2/terminal" "github.com/github/ghcs/internal/api" + "github.com/spf13/cobra" "golang.org/x/term" ) @@ -144,3 +145,12 @@ func checkAuthorizedKeys(ctx context.Context, client *api.API, user string) erro } return nil // success } + +var ErrTooManyArgs = errors.New("the command accepts no arguments") + +func noArgsConstraint(cmd *cobra.Command, args []string) error { + if len(args) > 0 { + return ErrTooManyArgs + } + return nil +} diff --git a/cmd/ghcs/create.go b/cmd/ghcs/create.go index 45aa794e6..fbc5e099a 100644 --- a/cmd/ghcs/create.go +++ b/cmd/ghcs/create.go @@ -28,7 +28,7 @@ func newCreateCmd() *cobra.Command { createCmd := &cobra.Command{ Use: "create", Short: "Create a codespace", - Args: cobra.NoArgs, + Args: noArgsConstraint, RunE: func(cmd *cobra.Command, args []string) error { return create(opts) }, diff --git a/cmd/ghcs/delete.go b/cmd/ghcs/delete.go index 4f1c71842..defc46883 100644 --- a/cmd/ghcs/delete.go +++ b/cmd/ghcs/delete.go @@ -50,7 +50,7 @@ func newDeleteCmd() *cobra.Command { deleteCmd := &cobra.Command{ Use: "delete", Short: "Delete a codespace", - Args: cobra.NoArgs, + Args: noArgsConstraint, RunE: func(cmd *cobra.Command, args []string) error { if opts.deleteAll && opts.repoFilter != "" { return errors.New("both --all and --repo is not supported") diff --git a/cmd/ghcs/list.go b/cmd/ghcs/list.go index 85eabaef5..065b7aa6d 100644 --- a/cmd/ghcs/list.go +++ b/cmd/ghcs/list.go @@ -20,7 +20,7 @@ func newListCmd() *cobra.Command { listCmd := &cobra.Command{ Use: "list", Short: "List your codespaces", - Args: cobra.NoArgs, + Args: noArgsConstraint, RunE: func(cmd *cobra.Command, args []string) error { return list(opts) }, diff --git a/cmd/ghcs/logs.go b/cmd/ghcs/logs.go index 74a763a72..01f677cf2 100644 --- a/cmd/ghcs/logs.go +++ b/cmd/ghcs/logs.go @@ -24,7 +24,7 @@ func newLogsCmd() *cobra.Command { logsCmd := &cobra.Command{ Use: "logs", Short: "Access codespace logs", - Args: cobra.NoArgs, + Args: noArgsConstraint, RunE: func(cmd *cobra.Command, args []string) error { return logs(context.Background(), log, codespace, follow) }, diff --git a/cmd/ghcs/main/main.go b/cmd/ghcs/main/main.go index 01dde1270..6b890d740 100644 --- a/cmd/ghcs/main/main.go +++ b/cmd/ghcs/main/main.go @@ -7,20 +7,25 @@ import ( "os" "github.com/github/ghcs/cmd/ghcs" + "github.com/spf13/cobra" ) func main() { rootCmd := ghcs.NewRootCmd() - if err := rootCmd.Execute(); err != nil { - explainError(os.Stderr, err) + if cmd, err := rootCmd.ExecuteC(); err != nil { + explainError(os.Stderr, err, cmd) os.Exit(1) } } -func explainError(w io.Writer, err error) { +func explainError(w io.Writer, err error, cmd *cobra.Command) { if errors.Is(err, ghcs.ErrTokenMissing) { fmt.Fprintln(w, "The GITHUB_TOKEN environment variable is required. Create a Personal Access Token at https://github.com/settings/tokens/new?scopes=repo") fmt.Fprintln(w, "Make sure to enable SSO for your organizations after creating the token.") return } + if errors.Is(err, ghcs.ErrTooManyArgs) { + _ = cmd.Usage() + return + } } diff --git a/cmd/ghcs/ports.go b/cmd/ghcs/ports.go index 2e93d78ae..1e6021809 100644 --- a/cmd/ghcs/ports.go +++ b/cmd/ghcs/ports.go @@ -31,7 +31,7 @@ func newPortsCmd() *cobra.Command { portsCmd := &cobra.Command{ Use: "ports", Short: "List ports in a codespace", - Args: cobra.NoArgs, + Args: noArgsConstraint, RunE: func(cmd *cobra.Command, args []string) error { return ports(codespace, asJSON) }, From 9a558bc58c0d6d2c9a50f6123242ba0e9bec1257 Mon Sep 17 00:00:00 2001 From: Jose Garcia Date: Wed, 22 Sep 2021 15:03:12 -0400 Subject: [PATCH 267/290] Early return if polling is not required - Add context to errors in poller --- cmd/ghcs/create.go | 4 ++-- internal/api/api.go | 36 +++++++++++++++++------------------- 2 files changed, 19 insertions(+), 21 deletions(-) diff --git a/cmd/ghcs/create.go b/cmd/ghcs/create.go index e37c1a200..0dffd5710 100644 --- a/cmd/ghcs/create.go +++ b/cmd/ghcs/create.go @@ -81,8 +81,7 @@ func create(opts *createOptions) error { return errors.New("there are no available machine types for this repository") } - log.Println("Creating your codespace...") - + log.Print("Creating your codespace...") codespace, err := apiClient.CreateCodespace(ctx, log, &api.CreateCodespaceParams{ User: userResult.User.Login, RepositoryID: repository.ID, @@ -90,6 +89,7 @@ func create(opts *createOptions) error { Machine: machine, Location: locationResult.Location, }) + log.Print("\n") if err != nil { return fmt.Errorf("error creating codespace: %w", err) } diff --git a/internal/api/api.go b/internal/api/api.go index 273c64435..eac2c3a88 100644 --- a/internal/api/api.go +++ b/internal/api/api.go @@ -209,6 +209,8 @@ type getCodespaceTokenResponse struct { RepositoryToken string `json:"repository_token"` } +// ErrNotProvisioned is returned by GetCodespacesToken to indicate that the +// creation of a codespace is not yet complete and that the caller should try again. var ErrNotProvisioned = errors.New("codespace not provisioned") func (a *API) GetCodespaceToken(ctx context.Context, ownerLogin, codespaceName string) (string, error) { @@ -420,26 +422,17 @@ func (a *API) CreateCodespace(ctx context.Context, log logger, params *CreateCod codespace, err := a.startCreate( ctx, params.User, params.RepositoryID, params.Machine, params.Branch, params.Location, ) - if err != nil { - // errProvisioningInProgress indicates that codespace creation did not complete - // within the GitHub API RPC time limit (10s), so it continues asynchronously. - // We must poll the server to discover the outcome. - if err == errProvisioningInProgress { - pollTimeout := 2 * time.Minute - pollInterval := 1 * time.Second - log.Print(".") - codespace, err = pollForCodespace(ctx, a, log, pollTimeout, pollInterval, params.User, codespace.Name) - log.Print("\n") - - if err != nil { - return nil, fmt.Errorf("error creating codespace with async provisioning: %s: %w", codespace.Name, err) - } - } - - return nil, err + if err != errProvisioningInProgress { + return codespace, err } - return codespace, nil + // errProvisioningInProgress indicates that codespace creation did not complete + // within the GitHub API RPC time limit (10s), so it continues asynchronously. + // We must poll the server to discover the outcome. + pollTimeout := 2 * time.Minute + pollInterval := 1 * time.Second + + return pollForCodespace(ctx, a, log, pollTimeout, pollInterval, params.User, codespace.Name) } type apiClient interface { @@ -472,7 +465,12 @@ func pollForCodespace(ctx context.Context, client apiClient, log logger, duratio return nil, fmt.Errorf("failed to get codespace token: %w", err) } - return client.GetCodespace(ctx, token, user, name) + codespace, err := client.GetCodespace(ctx, token, user, name) + if err != nil { + return nil, fmt.Errorf("failed to get codespace: %w", err) + } + + return codespace, nil } } } From 4e0ac15fe045012a2398690fefefff66c86d43a7 Mon Sep 17 00:00:00 2001 From: Jose Garcia Date: Wed, 22 Sep 2021 15:10:47 -0400 Subject: [PATCH 268/290] Add buffer to channels to avoid goroutine leak --- cmd/ghcs/create.go | 4 ++-- cmd/ghcs/ports.go | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/cmd/ghcs/create.go b/cmd/ghcs/create.go index 45aa794e6..b52690d1d 100644 --- a/cmd/ghcs/create.go +++ b/cmd/ghcs/create.go @@ -172,7 +172,7 @@ type getUserResult struct { // getUser fetches the user record associated with the GITHUB_TOKEN func getUser(ctx context.Context, apiClient *api.API) <-chan getUserResult { - ch := make(chan getUserResult) + ch := make(chan getUserResult, 1) go func() { user, err := apiClient.GetUser(ctx) ch <- getUserResult{user, err} @@ -187,7 +187,7 @@ type locationResult struct { // getLocation fetches the closest Codespace datacenter region/location to the user. func getLocation(ctx context.Context, apiClient *api.API) <-chan locationResult { - ch := make(chan locationResult) + ch := make(chan locationResult, 1) go func() { location, err := apiClient.GetCodespaceRegionLocation(ctx) ch <- locationResult{location, err} diff --git a/cmd/ghcs/ports.go b/cmd/ghcs/ports.go index aeecf0a07..8a4f855fa 100644 --- a/cmd/ghcs/ports.go +++ b/cmd/ghcs/ports.go @@ -123,7 +123,7 @@ type portAttribute struct { } func getDevContainer(ctx context.Context, apiClient *api.API, codespace *api.Codespace) <-chan devContainerResult { - ch := make(chan devContainerResult) + ch := make(chan devContainerResult, 1) go func() { contents, err := apiClient.GetCodespaceRepositoryContents(ctx, codespace, ".devcontainer/devcontainer.json") if err != nil { From 13d7804a359f8062817ec1e1da183da1e08a927a Mon Sep 17 00:00:00 2001 From: Jose Garcia Date: Thu, 23 Sep 2021 08:26:23 -0400 Subject: [PATCH 269/290] Remove API test, inline poller --- internal/api/api.go | 22 ++--------- internal/api/api_test.go | 82 ---------------------------------------- 2 files changed, 4 insertions(+), 100 deletions(-) delete mode 100644 internal/api/api_test.go diff --git a/internal/api/api.go b/internal/api/api.go index eac2c3a88..50c4a03de 100644 --- a/internal/api/api.go +++ b/internal/api/api.go @@ -429,24 +429,10 @@ func (a *API) CreateCodespace(ctx context.Context, log logger, params *CreateCod // errProvisioningInProgress indicates that codespace creation did not complete // within the GitHub API RPC time limit (10s), so it continues asynchronously. // We must poll the server to discover the outcome. - pollTimeout := 2 * time.Minute - pollInterval := 1 * time.Second - - return pollForCodespace(ctx, a, log, pollTimeout, pollInterval, params.User, codespace.Name) -} - -type apiClient interface { - GetCodespaceToken(ctx context.Context, userLogin, codespaceName string) (string, error) - GetCodespace(ctx context.Context, token, userLogin, codespaceName string) (*Codespace, error) -} - -// pollForCodespace polls the Codespaces GET endpoint on a given interval for a specified duration. -// If it succeeds at fetching the codespace, we consider the codespace provisioned. -func pollForCodespace(ctx context.Context, client apiClient, log logger, duration, interval time.Duration, user, name string) (*Codespace, error) { - ctx, cancel := context.WithTimeout(ctx, duration) + ctx, cancel := context.WithTimeout(ctx, 2*time.Minute) defer cancel() - ticker := time.NewTicker(interval) + ticker := time.NewTicker(1 * time.Second) defer ticker.Stop() for { @@ -455,7 +441,7 @@ func pollForCodespace(ctx context.Context, client apiClient, log logger, duratio return nil, ctx.Err() case <-ticker.C: log.Print(".") - token, err := client.GetCodespaceToken(ctx, user, name) + token, err := a.GetCodespaceToken(ctx, params.User, codespace.Name) if err != nil { if err == ErrNotProvisioned { // Do nothing. We expect this to fail until the codespace is provisioned @@ -465,7 +451,7 @@ func pollForCodespace(ctx context.Context, client apiClient, log logger, duratio return nil, fmt.Errorf("failed to get codespace token: %w", err) } - codespace, err := client.GetCodespace(ctx, token, user, name) + codespace, err = a.GetCodespace(ctx, token, params.User, codespace.Name) if err != nil { return nil, fmt.Errorf("failed to get codespace: %w", err) } diff --git a/internal/api/api_test.go b/internal/api/api_test.go deleted file mode 100644 index eb5226a59..000000000 --- a/internal/api/api_test.go +++ /dev/null @@ -1,82 +0,0 @@ -package api - -import ( - "context" - "errors" - "fmt" - "testing" - "time" - - "github.com/github/ghcs/cmd/ghcs/output" -) - -type mockAPIClient struct { - getCodespaceToken func(context.Context, string, string) (string, error) - getCodespace func(context.Context, string, string, string) (*Codespace, error) -} - -func (m *mockAPIClient) GetCodespaceToken(ctx context.Context, userLogin, codespaceName string) (string, error) { - if m.getCodespaceToken == nil { - return "", errors.New("mock api client GetCodespaceToken not implemented") - } - - return m.getCodespaceToken(ctx, userLogin, codespaceName) -} - -func (m *mockAPIClient) GetCodespace(ctx context.Context, token, userLogin, codespaceName string) (*Codespace, error) { - if m.getCodespace == nil { - return nil, errors.New("mock api client GetCodespace not implemented") - } - - return m.getCodespace(ctx, token, userLogin, codespaceName) -} - -func TestPollForCodespace(t *testing.T) { - logger := output.NewLogger(nil, nil, false) - user := &User{Login: "test"} - tmpCodespace := &Codespace{Name: "tmp-codespace"} - codespaceToken := "codespace-token" - ctx := context.Background() - - pollInterval := 50 * time.Millisecond - pollTimeout := 100 * time.Millisecond - - api := &mockAPIClient{ - getCodespaceToken: func(ctx context.Context, userLogin, codespace string) (string, error) { - if userLogin != user.Login { - return "", fmt.Errorf("user does not match, got: %s, expected: %s", userLogin, user.Login) - } - if codespace != tmpCodespace.Name { - return "", fmt.Errorf("codespace does not match, got: %s, expected: %s", codespace, tmpCodespace.Name) - } - return codespaceToken, nil - }, - getCodespace: func(ctx context.Context, token, userLogin, codespace string) (*Codespace, error) { - if token != codespaceToken { - return nil, fmt.Errorf("token does not match, got: %s, expected: %s", token, codespaceToken) - } - if userLogin != user.Login { - return nil, fmt.Errorf("user does not match, got: %s, expected: %s", userLogin, user.Login) - } - if codespace != tmpCodespace.Name { - return nil, fmt.Errorf("codespace does not match, got: %s, expected: %s", codespace, tmpCodespace.Name) - } - return tmpCodespace, nil - }, - } - - codespace, err := pollForCodespace(ctx, api, logger, pollTimeout, pollInterval, user.Login, tmpCodespace.Name) - if err != nil { - t.Error(err) - } - if tmpCodespace.Name != codespace.Name { - t.Errorf("returned codespace does not match, got: %s, expected: %s", codespace.Name, tmpCodespace.Name) - } - - // swap the durations to trigger a timeout - pollTimeout, pollInterval = pollInterval, pollTimeout - _, err = pollForCodespace(ctx, api, logger, pollTimeout, pollInterval, user.Login, tmpCodespace.Name) - if err != context.DeadlineExceeded { - t.Error("expected context deadline exceeded error, got nil") - } -} From 186b90b12e4d253d091a265714965bc96284c78f Mon Sep 17 00:00:00 2001 From: Jose Garcia Date: Thu, 23 Sep 2021 08:29:24 -0400 Subject: [PATCH 270/290] Rename request type --- internal/api/api.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/internal/api/api.go b/internal/api/api.go index 50c4a03de..fdf5c5b55 100644 --- a/internal/api/api.go +++ b/internal/api/api.go @@ -461,7 +461,7 @@ func (a *API) CreateCodespace(ctx context.Context, log logger, params *CreateCod } } -type createCodespaceRequest struct { +type startCreateRequest struct { RepositoryID int `json:"repository_id"` Ref string `json:"ref"` Location string `json:"location"` @@ -475,7 +475,7 @@ var errProvisioningInProgress = errors.New("provisioning in progress") // did not complete before the GitHub API's time limit for RPCs (10s), in which case the caller // must poll the server to learn the outcome. func (a *API) startCreate(ctx context.Context, user string, repository int, sku, branch, location string) (*Codespace, error) { - requestBody, err := json.Marshal(createCodespaceRequest{repository, branch, location, sku}) + requestBody, err := json.Marshal(startCreateRequest{repository, branch, location, sku}) if err != nil { return nil, fmt.Errorf("error marshaling request: %w", err) } From 9654dc4bd3711ed6ec00a112355313827cfe95bf Mon Sep 17 00:00:00 2001 From: Jose Garcia Date: Thu, 23 Sep 2021 10:07:14 -0400 Subject: [PATCH 271/290] Update to go-liveshare v0.20.0 --- internal/codespaces/codespaces.go | 19 ++++++------------- 1 file changed, 6 insertions(+), 13 deletions(-) diff --git a/internal/codespaces/codespaces.go b/internal/codespaces/codespaces.go index 2933c9d8d..7b27b817e 100644 --- a/internal/codespaces/codespaces.go +++ b/internal/codespaces/codespaces.go @@ -61,17 +61,10 @@ func ConnectToLiveshare(ctx context.Context, log logger, apiClient *api.API, use log.Println("Connecting to your codespace...") - lsclient, err := liveshare.NewClient( - liveshare.WithConnection(liveshare.Connection{ - SessionID: codespace.Environment.Connection.SessionID, - SessionToken: codespace.Environment.Connection.SessionToken, - RelaySAS: codespace.Environment.Connection.RelaySAS, - RelayEndpoint: codespace.Environment.Connection.RelayEndpoint, - }), - ) - if err != nil { - return nil, fmt.Errorf("error creating Live Share client: %w", err) - } - - return lsclient.JoinWorkspace(ctx) + return liveshare.Connect(ctx, liveshare.Options{ + SessionID: codespace.Environment.Connection.SessionID, + SessionToken: codespace.Environment.Connection.SessionToken, + RelaySAS: codespace.Environment.Connection.RelaySAS, + RelayEndpoint: codespace.Environment.Connection.RelayEndpoint, + }) } From f1c35ba9daa06996205082f05a275ce97aa68297 Mon Sep 17 00:00:00 2001 From: Jose Garcia Date: Thu, 23 Sep 2021 10:21:01 -0400 Subject: [PATCH 272/290] Update docs --- internal/codespaces/codespaces.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/internal/codespaces/codespaces.go b/internal/codespaces/codespaces.go index 7b27b817e..43809bab9 100644 --- a/internal/codespaces/codespaces.go +++ b/internal/codespaces/codespaces.go @@ -23,8 +23,8 @@ func connectionReady(codespace *api.Codespace) bool { codespace.Environment.State == api.CodespaceEnvironmentStateAvailable } -// ConnectToLiveshare creates a Live Share client and joins the Live Share session. -// It will start the Codespace if it is not already running, it will time out after 60 seconds if fails to start. +// ConnectToLiveshare waits for a Codespace to become running, +// and connects to it using a Live Share session. func ConnectToLiveshare(ctx context.Context, log logger, apiClient *api.API, userLogin, token string, codespace *api.Codespace) (*liveshare.Session, error) { var startedCodespace bool if codespace.Environment.State != api.CodespaceEnvironmentStateAvailable { From e8212a80a9dcdbecb698f47bd45176ad1703bff1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mislav=20Marohni=C4=87?= Date: Thu, 23 Sep 2021 17:14:25 +0200 Subject: [PATCH 273/290] Print `delete` failures as they occur --- cmd/ghcs/delete.go | 18 ++++++++++++++---- cmd/ghcs/delete_test.go | 2 +- 2 files changed, 15 insertions(+), 5 deletions(-) diff --git a/cmd/ghcs/delete.go b/cmd/ghcs/delete.go index 4f1c71842..b4c0bcbcf 100644 --- a/cmd/ghcs/delete.go +++ b/cmd/ghcs/delete.go @@ -9,6 +9,7 @@ import ( "time" "github.com/AlecAivazis/survey/v2" + "github.com/github/ghcs/cmd/ghcs/output" "github.com/github/ghcs/internal/api" "github.com/spf13/cobra" "golang.org/x/sync/errgroup" @@ -55,7 +56,8 @@ func newDeleteCmd() *cobra.Command { if opts.deleteAll && opts.repoFilter != "" { return errors.New("both --all and --repo is not supported") } - return delete(context.Background(), opts) + log := output.NewLogger(cmd.OutOrStdout(), cmd.ErrOrStderr(), !opts.isInteractive) + return delete(context.Background(), log, opts) }, } @@ -68,7 +70,11 @@ func newDeleteCmd() *cobra.Command { return deleteCmd } -func delete(ctx context.Context, opts deleteOptions) error { +type logger interface { + Errorf(format string, v ...interface{}) (int, error) +} + +func delete(ctx context.Context, log logger, opts deleteOptions) error { user, err := opts.apiClient.GetUser(ctx) if err != nil { return fmt.Errorf("error getting user: %w", err) @@ -127,13 +133,17 @@ func delete(ctx context.Context, opts deleteOptions) error { codespaceName := c.Name g.Go(func() error { if err := opts.apiClient.DeleteCodespace(ctx, user.Login, codespaceName); err != nil { - return fmt.Errorf("error deleting codespace: %w", err) + log.Errorf("error deleting codespace %q: %v", codespaceName, err) + return err } return nil }) } - return g.Wait() + if err := g.Wait(); err != nil { + return errors.New("some codespaces failed to delete") + } + return nil } func confirmDeletion(p prompter, codespace *api.Codespace, isInteractive bool) (bool, error) { diff --git a/cmd/ghcs/delete_test.go b/cmd/ghcs/delete_test.go index 754254494..beb371dd4 100644 --- a/cmd/ghcs/delete_test.go +++ b/cmd/ghcs/delete_test.go @@ -156,7 +156,7 @@ func TestDelete(t *testing.T) { }, } - err := delete(context.Background(), opts) + err := delete(context.Background(), nil, opts) if (err != nil) != tt.wantErr { t.Errorf("delete() error = %v, wantErr %v", err, tt.wantErr) } From 6ca35d0e730d1adaecc1c7c79c9c4892e2138449 Mon Sep 17 00:00:00 2001 From: Jose Garcia Date: Thu, 23 Sep 2021 11:18:49 -0400 Subject: [PATCH 274/290] Moved files to liveshare dir --- client.go => liveshare/client.go | 0 client_test.go => liveshare/client_test.go | 0 options_test.go => liveshare/options_test.go | 0 port_forwarder.go => liveshare/port_forwarder.go | 0 port_forwarder_test.go => liveshare/port_forwarder_test.go | 0 rpc.go => liveshare/rpc.go | 0 session.go => liveshare/session.go | 0 session_test.go => liveshare/session_test.go | 0 socket.go => liveshare/socket.go | 0 ssh.go => liveshare/ssh.go | 0 {test => liveshare/test}/server.go | 0 {test => liveshare/test}/socket.go | 0 12 files changed, 0 insertions(+), 0 deletions(-) rename client.go => liveshare/client.go (100%) rename client_test.go => liveshare/client_test.go (100%) rename options_test.go => liveshare/options_test.go (100%) rename port_forwarder.go => liveshare/port_forwarder.go (100%) rename port_forwarder_test.go => liveshare/port_forwarder_test.go (100%) rename rpc.go => liveshare/rpc.go (100%) rename session.go => liveshare/session.go (100%) rename session_test.go => liveshare/session_test.go (100%) rename socket.go => liveshare/socket.go (100%) rename ssh.go => liveshare/ssh.go (100%) rename {test => liveshare/test}/server.go (100%) rename {test => liveshare/test}/socket.go (100%) diff --git a/client.go b/liveshare/client.go similarity index 100% rename from client.go rename to liveshare/client.go diff --git a/client_test.go b/liveshare/client_test.go similarity index 100% rename from client_test.go rename to liveshare/client_test.go diff --git a/options_test.go b/liveshare/options_test.go similarity index 100% rename from options_test.go rename to liveshare/options_test.go diff --git a/port_forwarder.go b/liveshare/port_forwarder.go similarity index 100% rename from port_forwarder.go rename to liveshare/port_forwarder.go diff --git a/port_forwarder_test.go b/liveshare/port_forwarder_test.go similarity index 100% rename from port_forwarder_test.go rename to liveshare/port_forwarder_test.go diff --git a/rpc.go b/liveshare/rpc.go similarity index 100% rename from rpc.go rename to liveshare/rpc.go diff --git a/session.go b/liveshare/session.go similarity index 100% rename from session.go rename to liveshare/session.go diff --git a/session_test.go b/liveshare/session_test.go similarity index 100% rename from session_test.go rename to liveshare/session_test.go diff --git a/socket.go b/liveshare/socket.go similarity index 100% rename from socket.go rename to liveshare/socket.go diff --git a/ssh.go b/liveshare/ssh.go similarity index 100% rename from ssh.go rename to liveshare/ssh.go diff --git a/test/server.go b/liveshare/test/server.go similarity index 100% rename from test/server.go rename to liveshare/test/server.go diff --git a/test/socket.go b/liveshare/test/socket.go similarity index 100% rename from test/socket.go rename to liveshare/test/socket.go From f4396e8f1a0b79630e81b233b802d49cd0172dad Mon Sep 17 00:00:00 2001 From: Jose Garcia Date: Thu, 23 Sep 2021 11:28:04 -0400 Subject: [PATCH 275/290] Inline go-liveshare with history --- cmd/ghcs/logs.go | 2 +- cmd/ghcs/ports.go | 2 +- cmd/ghcs/ssh.go | 2 +- internal/codespaces/codespaces.go | 2 +- internal/codespaces/states.go | 2 +- {liveshare => internal/liveshare}/client.go | 0 {liveshare => internal/liveshare}/client_test.go | 2 +- {liveshare => internal/liveshare}/options_test.go | 0 {liveshare => internal/liveshare}/port_forwarder.go | 0 {liveshare => internal/liveshare}/port_forwarder_test.go | 2 +- {liveshare => internal/liveshare}/rpc.go | 0 {liveshare => internal/liveshare}/session.go | 0 {liveshare => internal/liveshare}/session_test.go | 2 +- {liveshare => internal/liveshare}/socket.go | 0 {liveshare => internal/liveshare}/ssh.go | 0 {liveshare => internal/liveshare}/test/server.go | 0 {liveshare => internal/liveshare}/test/socket.go | 0 17 files changed, 8 insertions(+), 8 deletions(-) rename {liveshare => internal/liveshare}/client.go (100%) rename {liveshare => internal/liveshare}/client_test.go (96%) rename {liveshare => internal/liveshare}/options_test.go (100%) rename {liveshare => internal/liveshare}/port_forwarder.go (100%) rename {liveshare => internal/liveshare}/port_forwarder_test.go (97%) rename {liveshare => internal/liveshare}/rpc.go (100%) rename {liveshare => internal/liveshare}/session.go (100%) rename {liveshare => internal/liveshare}/session_test.go (98%) rename {liveshare => internal/liveshare}/socket.go (100%) rename {liveshare => internal/liveshare}/ssh.go (100%) rename {liveshare => internal/liveshare}/test/server.go (100%) rename {liveshare => internal/liveshare}/test/socket.go (100%) diff --git a/cmd/ghcs/logs.go b/cmd/ghcs/logs.go index 514c36966..4bc83001b 100644 --- a/cmd/ghcs/logs.go +++ b/cmd/ghcs/logs.go @@ -9,7 +9,7 @@ import ( "github.com/github/ghcs/cmd/ghcs/output" "github.com/github/ghcs/internal/api" "github.com/github/ghcs/internal/codespaces" - "github.com/github/go-liveshare" + "github.com/github/ghcs/internal/liveshare" "github.com/spf13/cobra" ) diff --git a/cmd/ghcs/ports.go b/cmd/ghcs/ports.go index 8a4f855fa..24ec7a6e8 100644 --- a/cmd/ghcs/ports.go +++ b/cmd/ghcs/ports.go @@ -14,7 +14,7 @@ import ( "github.com/github/ghcs/cmd/ghcs/output" "github.com/github/ghcs/internal/api" "github.com/github/ghcs/internal/codespaces" - "github.com/github/go-liveshare" + "github.com/github/ghcs/internal/liveshare" "github.com/muhammadmuzzammil1998/jsonc" "github.com/spf13/cobra" "golang.org/x/sync/errgroup" diff --git a/cmd/ghcs/ssh.go b/cmd/ghcs/ssh.go index 3a49e6ebc..bb771107a 100644 --- a/cmd/ghcs/ssh.go +++ b/cmd/ghcs/ssh.go @@ -9,7 +9,7 @@ import ( "github.com/github/ghcs/cmd/ghcs/output" "github.com/github/ghcs/internal/api" "github.com/github/ghcs/internal/codespaces" - "github.com/github/go-liveshare" + "github.com/github/ghcs/internal/liveshare" "github.com/spf13/cobra" ) diff --git a/internal/codespaces/codespaces.go b/internal/codespaces/codespaces.go index 43809bab9..1cd605abc 100644 --- a/internal/codespaces/codespaces.go +++ b/internal/codespaces/codespaces.go @@ -7,7 +7,7 @@ import ( "time" "github.com/github/ghcs/internal/api" - "github.com/github/go-liveshare" + "github.com/github/ghcs/internal/liveshare" ) type logger interface { diff --git a/internal/codespaces/states.go b/internal/codespaces/states.go index 31105d576..c7d61b41e 100644 --- a/internal/codespaces/states.go +++ b/internal/codespaces/states.go @@ -10,7 +10,7 @@ import ( "time" "github.com/github/ghcs/internal/api" - "github.com/github/go-liveshare" + "github.com/github/ghcs/internal/liveshare" ) // PostCreateStateStatus is a string value representing the different statuses a state can have. diff --git a/liveshare/client.go b/internal/liveshare/client.go similarity index 100% rename from liveshare/client.go rename to internal/liveshare/client.go diff --git a/liveshare/client_test.go b/internal/liveshare/client_test.go similarity index 96% rename from liveshare/client_test.go rename to internal/liveshare/client_test.go index 2b95f738f..55139d762 100644 --- a/liveshare/client_test.go +++ b/internal/liveshare/client_test.go @@ -9,7 +9,7 @@ import ( "strings" "testing" - livesharetest "github.com/github/go-liveshare/test" + livesharetest "github.com/github/ghcs/internal/liveshare/test" "github.com/sourcegraph/jsonrpc2" ) diff --git a/liveshare/options_test.go b/internal/liveshare/options_test.go similarity index 100% rename from liveshare/options_test.go rename to internal/liveshare/options_test.go diff --git a/liveshare/port_forwarder.go b/internal/liveshare/port_forwarder.go similarity index 100% rename from liveshare/port_forwarder.go rename to internal/liveshare/port_forwarder.go diff --git a/liveshare/port_forwarder_test.go b/internal/liveshare/port_forwarder_test.go similarity index 97% rename from liveshare/port_forwarder_test.go rename to internal/liveshare/port_forwarder_test.go index c4245f513..25b4b2c80 100644 --- a/liveshare/port_forwarder_test.go +++ b/internal/liveshare/port_forwarder_test.go @@ -10,7 +10,7 @@ import ( "testing" "time" - livesharetest "github.com/github/go-liveshare/test" + livesharetest "github.com/github/ghcs/internal/liveshare/test" "github.com/sourcegraph/jsonrpc2" ) diff --git a/liveshare/rpc.go b/internal/liveshare/rpc.go similarity index 100% rename from liveshare/rpc.go rename to internal/liveshare/rpc.go diff --git a/liveshare/session.go b/internal/liveshare/session.go similarity index 100% rename from liveshare/session.go rename to internal/liveshare/session.go diff --git a/liveshare/session_test.go b/internal/liveshare/session_test.go similarity index 98% rename from liveshare/session_test.go rename to internal/liveshare/session_test.go index cd0a7b474..1273c6f2b 100644 --- a/liveshare/session_test.go +++ b/internal/liveshare/session_test.go @@ -9,7 +9,7 @@ import ( "strings" "testing" - livesharetest "github.com/github/go-liveshare/test" + livesharetest "github.com/github/ghcs/internal/liveshare/test" "github.com/sourcegraph/jsonrpc2" ) diff --git a/liveshare/socket.go b/internal/liveshare/socket.go similarity index 100% rename from liveshare/socket.go rename to internal/liveshare/socket.go diff --git a/liveshare/ssh.go b/internal/liveshare/ssh.go similarity index 100% rename from liveshare/ssh.go rename to internal/liveshare/ssh.go diff --git a/liveshare/test/server.go b/internal/liveshare/test/server.go similarity index 100% rename from liveshare/test/server.go rename to internal/liveshare/test/server.go diff --git a/liveshare/test/socket.go b/internal/liveshare/test/socket.go similarity index 100% rename from liveshare/test/socket.go rename to internal/liveshare/test/socket.go From d0c65e549067426f80caf6dc5a99f99ffa4006cd Mon Sep 17 00:00:00 2001 From: Jose Garcia Date: Thu, 23 Sep 2021 11:36:27 -0400 Subject: [PATCH 276/290] Linter fixes --- internal/liveshare/session_test.go | 11 +++++++++- internal/liveshare/test/server.go | 32 +++++++++++++++--------------- 2 files changed, 26 insertions(+), 17 deletions(-) diff --git a/internal/liveshare/session_test.go b/internal/liveshare/session_test.go index 1273c6f2b..0ffdfe136 100644 --- a/internal/liveshare/session_test.go +++ b/internal/liveshare/session_test.go @@ -24,6 +24,10 @@ func makeMockSession(opts ...livesharetest.ServerOption) (*livesharetest.Server, livesharetest.WithService("workspace.joinWorkspace", joinWorkspace), ) testServer, err := livesharetest.NewServer(opts...) + if err != nil { + return nil, nil, fmt.Errorf("error creating server: %w", err) + } + session, err := Connect(context.Background(), Options{ SessionID: "session-id", SessionToken: sessionToken, @@ -67,7 +71,12 @@ func TestServerStartSharing(t *testing.T) { testServer, session, err := makeMockSession( livesharetest.WithService("serverSharing.startSharing", startSharing), ) - defer testServer.Close() + defer func() { + if err := testServer.Close(); err != nil { + t.Errorf("failed to close test server: %w", err) + } + }() + if err != nil { t.Errorf("error creating mock session: %v", err) } diff --git a/internal/liveshare/test/server.go b/internal/liveshare/test/server.go index 159a2a982..8f80d1bce 100644 --- a/internal/liveshare/test/server.go +++ b/internal/liveshare/test/server.go @@ -44,7 +44,7 @@ Bq5TLNIbUzPVNVwRcGjUYpOhKU6EIw8phTJOvxnUC+g6MVqBP8U= type Server struct { password string - services map[string]RpcHandleFunc + services map[string]RPCHandleFunc relaySAS string streams map[string]io.ReadWriter @@ -67,7 +67,7 @@ func NewServer(opts ...ServerOption) (*Server, error) { } privateKey, err := ssh.ParsePrivateKey([]byte(sshPrivateKey)) if err != nil { - return nil, fmt.Errorf("error parsing key: %v", err) + return nil, fmt.Errorf("error parsing key: %w", err) } server.sshConfig.AddHostKey(privateKey) @@ -85,10 +85,10 @@ func WithPassword(password string) ServerOption { } } -func WithService(serviceName string, handler RpcHandleFunc) ServerOption { +func WithService(serviceName string, handler RPCHandleFunc) ServerOption { return func(s *Server) error { if s.services == nil { - s.services = make(map[string]RpcHandleFunc) + s.services = make(map[string]RPCHandleFunc) } s.services[serviceName] = handler @@ -148,7 +148,7 @@ func makeConnection(server *Server) http.HandlerFunc { } c, err := upgrader.Upgrade(w, req, nil) if err != nil { - server.errCh <- fmt.Errorf("error upgrading connection: %v", err) + server.errCh <- fmt.Errorf("error upgrading connection: %w", err) return } defer c.Close() @@ -156,7 +156,7 @@ func makeConnection(server *Server) http.HandlerFunc { socketConn := newSocketConn(c) _, chans, reqs, err := ssh.NewServerConn(socketConn, server.sshConfig) if err != nil { - server.errCh <- fmt.Errorf("error creating new ssh conn: %v", err) + server.errCh <- fmt.Errorf("error creating new ssh conn: %w", err) return } go ssh.DiscardRequests(reqs) @@ -164,7 +164,7 @@ func makeConnection(server *Server) http.HandlerFunc { for newChannel := range chans { ch, reqs, err := newChannel.Accept() if err != nil { - server.errCh <- fmt.Errorf("error accepting new channel: %v", err) + server.errCh <- fmt.Errorf("error accepting new channel: %w", err) return } go handleNewRequests(server, ch, reqs) @@ -177,7 +177,7 @@ func handleNewRequests(server *Server, channel ssh.Channel, reqs <-chan *ssh.Req for req := range reqs { if req.WantReply { if err := req.Reply(true, nil); err != nil { - server.errCh <- fmt.Errorf("error replying to channel request: %v", err) + server.errCh <- fmt.Errorf("error replying to channel request: %w", err) } } if strings.HasPrefix(req.Type, "stream-transport") { @@ -190,14 +190,14 @@ func forwardStream(server *Server, streamName string, channel ssh.Channel) { simpleStreamName := strings.TrimPrefix(streamName, "stream-transport-") stream, found := server.streams[simpleStreamName] if !found { - server.errCh <- fmt.Errorf("stream '%v' not found", simpleStreamName) + server.errCh <- fmt.Errorf("stream '%w' not found", simpleStreamName) return } copy := func(dst io.Writer, src io.Reader) { if _, err := io.Copy(dst, src); err != nil { fmt.Println(err) - server.errCh <- fmt.Errorf("io copy: %v", err) + server.errCh <- fmt.Errorf("io copy: %w", err) return } } @@ -211,33 +211,33 @@ func forwardStream(server *Server, streamName string, channel ssh.Channel) { func handleNewChannel(server *Server, channel ssh.Channel) { stream := jsonrpc2.NewBufferedStream(channel, jsonrpc2.VSCodeObjectCodec{}) - jsonrpc2.NewConn(context.Background(), stream, newRpcHandler(server)) + jsonrpc2.NewConn(context.Background(), stream, newRPCHandler(server)) } -type RpcHandleFunc func(req *jsonrpc2.Request) (interface{}, error) +type RPCHandleFunc func(req *jsonrpc2.Request) (interface{}, error) type rpcHandler struct { server *Server } -func newRpcHandler(server *Server) *rpcHandler { +func newRPCHandler(server *Server) *rpcHandler { return &rpcHandler{server} } func (r *rpcHandler) Handle(ctx context.Context, conn *jsonrpc2.Conn, req *jsonrpc2.Request) { handler, found := r.server.services[req.Method] if !found { - r.server.errCh <- fmt.Errorf("RPC Method: '%v' not serviced", req.Method) + r.server.errCh <- fmt.Errorf("RPC Method: '%s' not serviced", req.Method) return } result, err := handler(req) if err != nil { - r.server.errCh <- fmt.Errorf("error handling: '%v': %v", req.Method, err) + r.server.errCh <- fmt.Errorf("error handling: '%s': %w", req.Method, err) return } if err := conn.Reply(ctx, req.ID, result); err != nil { - r.server.errCh <- fmt.Errorf("error replying: %v", err) + r.server.errCh <- fmt.Errorf("error replying: %w", err) } } From 958990cef83defd3c70278d3b4597c9165641128 Mon Sep 17 00:00:00 2001 From: Jose Garcia Date: Thu, 23 Sep 2021 11:47:52 -0400 Subject: [PATCH 277/290] More linter fixes --- internal/liveshare/session_test.go | 6 +----- internal/liveshare/test/server.go | 16 +++++++++------- internal/liveshare/test/socket.go | 6 +++--- 3 files changed, 13 insertions(+), 15 deletions(-) diff --git a/internal/liveshare/session_test.go b/internal/liveshare/session_test.go index 0ffdfe136..c830c33b1 100644 --- a/internal/liveshare/session_test.go +++ b/internal/liveshare/session_test.go @@ -71,11 +71,7 @@ func TestServerStartSharing(t *testing.T) { testServer, session, err := makeMockSession( livesharetest.WithService("serverSharing.startSharing", startSharing), ) - defer func() { - if err := testServer.Close(); err != nil { - t.Errorf("failed to close test server: %w", err) - } - }() + defer testServer.Close() if err != nil { t.Errorf("error creating mock session: %v", err) diff --git a/internal/liveshare/test/server.go b/internal/liveshare/test/server.go index 8f80d1bce..9b898dafb 100644 --- a/internal/liveshare/test/server.go +++ b/internal/liveshare/test/server.go @@ -138,6 +138,9 @@ var upgrader = websocket.Upgrader{} func makeConnection(server *Server) http.HandlerFunc { return func(w http.ResponseWriter, req *http.Request) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + if server.relaySAS != "" { // validate the sas key sasParam := req.URL.Query().Get("sb-hc-token") @@ -167,13 +170,13 @@ func makeConnection(server *Server) http.HandlerFunc { server.errCh <- fmt.Errorf("error accepting new channel: %w", err) return } - go handleNewRequests(server, ch, reqs) + go handleNewRequests(ctx, server, ch, reqs) go handleNewChannel(server, ch) } } } -func handleNewRequests(server *Server, channel ssh.Channel, reqs <-chan *ssh.Request) { +func handleNewRequests(ctx context.Context, server *Server, channel ssh.Channel, reqs <-chan *ssh.Request) { for req := range reqs { if req.WantReply { if err := req.Reply(true, nil); err != nil { @@ -181,16 +184,16 @@ func handleNewRequests(server *Server, channel ssh.Channel, reqs <-chan *ssh.Req } } if strings.HasPrefix(req.Type, "stream-transport") { - forwardStream(server, req.Type, channel) + forwardStream(ctx, server, req.Type, channel) } } } -func forwardStream(server *Server, streamName string, channel ssh.Channel) { +func forwardStream(ctx context.Context, server *Server, streamName string, channel ssh.Channel) { simpleStreamName := strings.TrimPrefix(streamName, "stream-transport-") stream, found := server.streams[simpleStreamName] if !found { - server.errCh <- fmt.Errorf("stream '%w' not found", simpleStreamName) + server.errCh <- fmt.Errorf("stream '%s' not found", simpleStreamName) return } @@ -205,8 +208,7 @@ func forwardStream(server *Server, streamName string, channel ssh.Channel) { go copy(stream, channel) go copy(channel, stream) - for { - } + <-ctx.Done() // TODO(josebalius): improve this } func handleNewChannel(server *Server, channel ssh.Channel) { diff --git a/internal/liveshare/test/socket.go b/internal/liveshare/test/socket.go index 9a2d92491..0a7a8baf0 100644 --- a/internal/liveshare/test/socket.go +++ b/internal/liveshare/test/socket.go @@ -28,7 +28,7 @@ func (s *socketConn) Read(b []byte) (int, error) { if s.reader == nil { msgType, r, err := s.Conn.NextReader() if err != nil { - return 0, fmt.Errorf("error getting next reader: %v", err) + return 0, fmt.Errorf("error getting next reader: %w", err) } if msgType != websocket.BinaryMessage { return 0, fmt.Errorf("invalid message type") @@ -54,7 +54,7 @@ func (s *socketConn) Write(b []byte) (int, error) { w, err := s.Conn.NextWriter(websocket.BinaryMessage) if err != nil { - return 0, fmt.Errorf("error getting next writer: %v", err) + return 0, fmt.Errorf("error getting next writer: %w", err) } n, err := w.Write(b) @@ -63,7 +63,7 @@ func (s *socketConn) Write(b []byte) (int, error) { } if err := w.Close(); err != nil { - return 0, fmt.Errorf("error closing writer: %v", err) + return 0, fmt.Errorf("error closing writer: %w", err) } return n, nil From fb53ccb06a1b21e06ce2849d98c2c79f14e4ba76 Mon Sep 17 00:00:00 2001 From: Jose Garcia Date: Thu, 23 Sep 2021 11:56:41 -0400 Subject: [PATCH 278/290] Linter fixes --- internal/liveshare/client_test.go | 8 ++++---- internal/liveshare/port_forwarder_test.go | 14 +++++++------- internal/liveshare/test/socket.go | 2 +- 3 files changed, 12 insertions(+), 12 deletions(-) diff --git a/internal/liveshare/client_test.go b/internal/liveshare/client_test.go index 55139d762..12ea903b6 100644 --- a/internal/liveshare/client_test.go +++ b/internal/liveshare/client_test.go @@ -22,7 +22,7 @@ func TestConnect(t *testing.T) { joinWorkspace := func(req *jsonrpc2.Request) (interface{}, error) { var joinWorkspaceReq joinWorkspaceArgs if err := json.Unmarshal(*req.Params, &joinWorkspaceReq); err != nil { - return nil, fmt.Errorf("error unmarshaling req: %v", err) + return nil, fmt.Errorf("error unmarshaling req: %w", err) } if joinWorkspaceReq.ID != opts.SessionID { return nil, errors.New("connection session id does not match") @@ -45,7 +45,7 @@ func TestConnect(t *testing.T) { livesharetest.WithRelaySAS(opts.RelaySAS), ) if err != nil { - t.Errorf("error creating Live Share server: %v", err) + t.Errorf("error creating Live Share server: %w", err) } defer server.Close() opts.RelayEndpoint = "sb" + strings.TrimPrefix(server.URL(), "https") @@ -62,10 +62,10 @@ func TestConnect(t *testing.T) { select { case err := <-server.Err(): - t.Errorf("error from server: %v", err) + t.Errorf("error from server: %w", err) case err := <-done: if err != nil { - t.Errorf("error from client: %v", err) + t.Errorf("error from client: %w", err) } } } diff --git a/internal/liveshare/port_forwarder_test.go b/internal/liveshare/port_forwarder_test.go index 25b4b2c80..64dfb5c88 100644 --- a/internal/liveshare/port_forwarder_test.go +++ b/internal/liveshare/port_forwarder_test.go @@ -17,7 +17,7 @@ import ( func TestNewPortForwarder(t *testing.T) { testServer, session, err := makeMockSession() if err != nil { - t.Errorf("create mock client: %v", err) + t.Errorf("create mock client: %w", err) } defer testServer.Close() pf := NewPortForwarder(session, "ssh", 80) @@ -42,7 +42,7 @@ func TestPortForwarderStart(t *testing.T) { livesharetest.WithStream("stream-id", stream), ) if err != nil { - t.Errorf("create mock session: %v", err) + t.Errorf("create mock session: %w", err) } defer testServer.Close() @@ -73,23 +73,23 @@ func TestPortForwarderStart(t *testing.T) { } b := make([]byte, len("stream-data")) if _, err := conn.Read(b); err != nil && err != io.EOF { - done <- fmt.Errorf("reading stream: %v", err) + done <- fmt.Errorf("reading stream: %w", err) } if string(b) != "stream-data" { - done <- fmt.Errorf("stream data is not expected value, got: %v", string(b)) + done <- fmt.Errorf("stream data is not expected value, got: %s", string(b)) } if _, err := conn.Write([]byte("new-data")); err != nil { - done <- fmt.Errorf("writing to stream: %v", err) + done <- fmt.Errorf("writing to stream: %w", err) } done <- nil }() select { case err := <-testServer.Err(): - t.Errorf("error from server: %v", err) + t.Errorf("error from server: %w", err) case err := <-done: if err != nil { - t.Errorf("error from client: %v", err) + t.Errorf("error from client: %w", err) } } } diff --git a/internal/liveshare/test/socket.go b/internal/liveshare/test/socket.go index 0a7a8baf0..00cd64a1b 100644 --- a/internal/liveshare/test/socket.go +++ b/internal/liveshare/test/socket.go @@ -59,7 +59,7 @@ func (s *socketConn) Write(b []byte) (int, error) { n, err := w.Write(b) if err != nil { - return 0, fmt.Errorf("error writing: %v", err) + return 0, fmt.Errorf("error writing: %w", err) } if err := w.Close(); err != nil { From c4114cc972ccbccabbd3b7e1152703ebab12a892 Mon Sep 17 00:00:00 2001 From: Jose Garcia Date: Thu, 23 Sep 2021 11:58:55 -0400 Subject: [PATCH 279/290] Linter fixes --- internal/liveshare/session_test.go | 30 +++++++++++++++--------------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/internal/liveshare/session_test.go b/internal/liveshare/session_test.go index c830c33b1..47bac3108 100644 --- a/internal/liveshare/session_test.go +++ b/internal/liveshare/session_test.go @@ -36,7 +36,7 @@ func makeMockSession(opts ...livesharetest.ServerOption) (*livesharetest.Server, TLSConfig: &tls.Config{InsecureSkipVerify: true}, }) if err != nil { - return nil, nil, fmt.Errorf("error connecting to Live Share: %v", err) + return nil, nil, fmt.Errorf("error connecting to Live Share: %w", err) } return testServer, session, nil } @@ -46,7 +46,7 @@ func TestServerStartSharing(t *testing.T) { startSharing := func(req *jsonrpc2.Request) (interface{}, error) { var args []interface{} if err := json.Unmarshal(*req.Params, &args); err != nil { - return nil, fmt.Errorf("error unmarshaling request: %v", err) + return nil, fmt.Errorf("error unmarshaling request: %w", err) } if len(args) < 3 { return nil, errors.New("not enough arguments to start sharing") @@ -63,7 +63,7 @@ func TestServerStartSharing(t *testing.T) { } if browseURL, ok := args[2].(string); !ok { return nil, errors.New("browse url is not a string") - } else if browseURL != fmt.Sprintf("http://localhost:%v", serverPort) { + } else if browseURL != fmt.Sprintf("http://localhost:%d", serverPort) { return nil, errors.New("browseURL does not match expected") } return Port{StreamName: "stream-name", StreamCondition: "stream-condition"}, nil @@ -74,7 +74,7 @@ func TestServerStartSharing(t *testing.T) { defer testServer.Close() if err != nil { - t.Errorf("error creating mock session: %v", err) + t.Errorf("error creating mock session: %w", err) } ctx := context.Background() @@ -82,7 +82,7 @@ func TestServerStartSharing(t *testing.T) { go func() { streamID, err := session.startSharing(ctx, serverProtocol, serverPort) if err != nil { - done <- fmt.Errorf("error sharing server: %v", err) + done <- fmt.Errorf("error sharing server: %w", err) } if streamID.name == "" || streamID.condition == "" { done <- errors.New("stream name or condition is blank") @@ -92,10 +92,10 @@ func TestServerStartSharing(t *testing.T) { select { case err := <-testServer.Err(): - t.Errorf("error from server: %v", err) + t.Errorf("error from server: %w", err) case err := <-done: if err != nil { - t.Errorf("error from client: %v", err) + t.Errorf("error from client: %w", err) } } } @@ -113,7 +113,7 @@ func TestServerGetSharedServers(t *testing.T) { livesharetest.WithService("serverSharing.getSharedServers", getSharedServers), ) if err != nil { - t.Errorf("error creating mock session: %v", err) + t.Errorf("error creating mock session: %w", err) } defer testServer.Close() ctx := context.Background() @@ -121,7 +121,7 @@ func TestServerGetSharedServers(t *testing.T) { go func() { ports, err := session.GetSharedServers(ctx) if err != nil { - done <- fmt.Errorf("error getting shared servers: %v", err) + done <- fmt.Errorf("error getting shared servers: %w", err) } if len(ports) < 1 { done <- errors.New("not enough ports returned") @@ -140,10 +140,10 @@ func TestServerGetSharedServers(t *testing.T) { select { case err := <-testServer.Err(): - t.Errorf("error from server: %v", err) + t.Errorf("error from server: %w", err) case err := <-done: if err != nil { - t.Errorf("error from client: %v", err) + t.Errorf("error from client: %w", err) } } } @@ -152,7 +152,7 @@ func TestServerUpdateSharedVisibility(t *testing.T) { updateSharedVisibility := func(rpcReq *jsonrpc2.Request) (interface{}, error) { var req []interface{} if err := json.Unmarshal(*rpcReq.Params, &req); err != nil { - return nil, fmt.Errorf("unmarshal req: %v", err) + return nil, fmt.Errorf("unmarshal req: %w", err) } if len(req) < 2 { return nil, errors.New("request arguments is less than 2") @@ -177,7 +177,7 @@ func TestServerUpdateSharedVisibility(t *testing.T) { livesharetest.WithService("serverSharing.updateSharedServerVisibility", updateSharedVisibility), ) if err != nil { - t.Errorf("creating mock session: %v", err) + t.Errorf("creating mock session: %w", err) } defer testServer.Close() ctx := context.Background() @@ -187,10 +187,10 @@ func TestServerUpdateSharedVisibility(t *testing.T) { }() select { case err := <-testServer.Err(): - t.Errorf("error from server: %v", err) + t.Errorf("error from server: %w", err) case err := <-done: if err != nil { - t.Errorf("error from client: %v", err) + t.Errorf("error from client: %w", err) } } } From 75c1dfdf49e4b43c31704639d5d33b8361fb58e5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mislav=20Marohni=C4=87?= Date: Thu, 23 Sep 2021 18:57:22 +0200 Subject: [PATCH 280/290] Fetch codespace by name directly if name argument given --- cmd/ghcs/delete.go | 44 +++++++++---- cmd/ghcs/delete_test.go | 33 +++++++--- cmd/ghcs/mock_api.go | 126 +++++++++++++++++++++++++++++++++++--- cmd/ghcs/mock_prompter.go | 4 -- 4 files changed, 174 insertions(+), 33 deletions(-) diff --git a/cmd/ghcs/delete.go b/cmd/ghcs/delete.go index b4c0bcbcf..0c21c1674 100644 --- a/cmd/ghcs/delete.go +++ b/cmd/ghcs/delete.go @@ -28,14 +28,16 @@ type deleteOptions struct { prompter prompter } -//go:generate moq -fmt goimports -rm -out mock_prompter.go . prompter +//go:generate moq -fmt goimports -rm -skip-ensure -out mock_prompter.go . prompter type prompter interface { Confirm(message string) (bool, error) } -//go:generate moq -fmt goimports -rm -out mock_api.go . apiClient +//go:generate moq -fmt goimports -rm -skip-ensure -out mock_api.go . apiClient type apiClient interface { GetUser(ctx context.Context) (*api.User, error) + GetCodespaceToken(ctx context.Context, user, name string) (string, error) + GetCodespace(ctx context.Context, token, user, name string) (*api.Codespace, error) ListCodespaces(ctx context.Context, user string) ([]*api.Codespace, error) DeleteCodespace(ctx context.Context, user, name string) error } @@ -80,18 +82,34 @@ func delete(ctx context.Context, log logger, opts deleteOptions) error { return fmt.Errorf("error getting user: %w", err) } - codespaces, err := opts.apiClient.ListCodespaces(ctx, user.Login) - if err != nil { - return fmt.Errorf("error getting codespaces: %w", err) - } - + var codespaces []*api.Codespace nameFilter := opts.codespaceName - if nameFilter == "" && !opts.deleteAll && opts.repoFilter == "" { - c, err := chooseCodespaceFromList(ctx, codespaces) + if nameFilter == "" { + codespaces, err = opts.apiClient.ListCodespaces(ctx, user.Login) if err != nil { - return fmt.Errorf("error choosing codespace: %w", err) + return fmt.Errorf("error getting codespaces: %w", err) } - nameFilter = c.Name + + if !opts.deleteAll && opts.repoFilter == "" { + c, err := chooseCodespaceFromList(ctx, codespaces) + if err != nil { + return fmt.Errorf("error choosing codespace: %w", err) + } + nameFilter = c.Name + } + } else { + // TODO: this token is discarded and then re-requested later in DeleteCodespace + token, err := opts.apiClient.GetCodespaceToken(ctx, user.Login, nameFilter) + if err != nil { + return fmt.Errorf("error getting codespace token: %w", err) + } + + codespace, err := opts.apiClient.GetCodespace(ctx, token, user.Login, nameFilter) + if err != nil { + return fmt.Errorf("error fetching codespace information: %w", err) + } + + codespaces = []*api.Codespace{codespace} } codespacesToDelete := make([]*api.Codespace, 0, len(codespaces)) @@ -112,7 +130,7 @@ func delete(ctx context.Context, log logger, opts deleteOptions) error { continue } } - if nameFilter == "" || !opts.skipConfirm { + if !opts.skipConfirm { confirmed, err := confirmDeletion(opts.prompter, c, opts.isInteractive) if err != nil { return fmt.Errorf("unable to confirm: %w", err) @@ -133,7 +151,7 @@ func delete(ctx context.Context, log logger, opts deleteOptions) error { codespaceName := c.Name g.Go(func() error { if err := opts.apiClient.DeleteCodespace(ctx, user.Login, codespaceName); err != nil { - log.Errorf("error deleting codespace %q: %v", codespaceName, err) + _, _ = log.Errorf("error deleting codespace %q: %v", codespaceName, err) return err } return nil diff --git a/cmd/ghcs/delete_test.go b/cmd/ghcs/delete_test.go index beb371dd4..7d90a3bd2 100644 --- a/cmd/ghcs/delete_test.go +++ b/cmd/ghcs/delete_test.go @@ -31,9 +31,6 @@ func TestDelete(t *testing.T) { codespaceName: "hubot-robawt-abc", }, codespaces: []*api.Codespace{ - { - Name: "monalisa-spoonknife-123", - }, { Name: "hubot-robawt-abc", }, @@ -130,12 +127,6 @@ func TestDelete(t *testing.T) { GetUserFunc: func(_ context.Context) (*api.User, error) { return user, nil }, - ListCodespacesFunc: func(_ context.Context, userLogin string) ([]*api.Codespace, error) { - if userLogin != user.Login { - return nil, fmt.Errorf("unexpected user %q", userLogin) - } - return tt.codespaces, nil - }, DeleteCodespaceFunc: func(_ context.Context, userLogin, name string) error { if userLogin != user.Login { return fmt.Errorf("unexpected user %q", userLogin) @@ -143,6 +134,30 @@ func TestDelete(t *testing.T) { return nil }, } + if tt.opts.codespaceName == "" { + apiMock.ListCodespacesFunc = func(_ context.Context, userLogin string) ([]*api.Codespace, error) { + if userLogin != user.Login { + return nil, fmt.Errorf("unexpected user %q", userLogin) + } + return tt.codespaces, nil + } + } else { + apiMock.GetCodespaceTokenFunc = func(_ context.Context, userLogin, name string) (string, error) { + if userLogin != user.Login { + return "", fmt.Errorf("unexpected user %q", userLogin) + } + return "CS_TOKEN", nil + } + apiMock.GetCodespaceFunc = func(_ context.Context, token, userLogin, name string) (*api.Codespace, error) { + if userLogin != user.Login { + return nil, fmt.Errorf("unexpected user %q", userLogin) + } + if token != "CS_TOKEN" { + return nil, fmt.Errorf("unexpected token %q", token) + } + return tt.codespaces[0], nil + } + } opts := tt.opts opts.apiClient = apiMock opts.now = func() time.Time { return now } diff --git a/cmd/ghcs/mock_api.go b/cmd/ghcs/mock_api.go index 46edd2835..256a30ec3 100644 --- a/cmd/ghcs/mock_api.go +++ b/cmd/ghcs/mock_api.go @@ -10,10 +10,6 @@ import ( "github.com/github/ghcs/internal/api" ) -// Ensure, that apiClientMock does implement apiClient. -// If this is not the case, regenerate this file with moq. -var _ apiClient = &apiClientMock{} - // apiClientMock is a mock implementation of apiClient. // // func TestSomethingThatUsesapiClient(t *testing.T) { @@ -23,6 +19,12 @@ var _ apiClient = &apiClientMock{} // DeleteCodespaceFunc: func(ctx context.Context, user string, name string) error { // panic("mock out the DeleteCodespace method") // }, +// GetCodespaceFunc: func(ctx context.Context, token string, user string, name string) (*api.Codespace, error) { +// panic("mock out the GetCodespace method") +// }, +// GetCodespaceTokenFunc: func(ctx context.Context, user string, name string) (string, error) { +// panic("mock out the GetCodespaceToken method") +// }, // GetUserFunc: func(ctx context.Context) (*api.User, error) { // panic("mock out the GetUser method") // }, @@ -39,6 +41,12 @@ type apiClientMock struct { // DeleteCodespaceFunc mocks the DeleteCodespace method. DeleteCodespaceFunc func(ctx context.Context, user string, name string) error + // GetCodespaceFunc mocks the GetCodespace method. + GetCodespaceFunc func(ctx context.Context, token string, user string, name string) (*api.Codespace, error) + + // GetCodespaceTokenFunc mocks the GetCodespaceToken method. + GetCodespaceTokenFunc func(ctx context.Context, user string, name string) (string, error) + // GetUserFunc mocks the GetUser method. GetUserFunc func(ctx context.Context) (*api.User, error) @@ -56,6 +64,26 @@ type apiClientMock struct { // Name is the name argument value. Name string } + // GetCodespace holds details about calls to the GetCodespace method. + GetCodespace []struct { + // Ctx is the ctx argument value. + Ctx context.Context + // Token is the token argument value. + Token string + // User is the user argument value. + User string + // Name is the name argument value. + Name string + } + // GetCodespaceToken holds details about calls to the GetCodespaceToken method. + GetCodespaceToken []struct { + // Ctx is the ctx argument value. + Ctx context.Context + // User is the user argument value. + User string + // Name is the name argument value. + Name string + } // GetUser holds details about calls to the GetUser method. GetUser []struct { // Ctx is the ctx argument value. @@ -69,9 +97,11 @@ type apiClientMock struct { User string } } - lockDeleteCodespace sync.RWMutex - lockGetUser sync.RWMutex - lockListCodespaces sync.RWMutex + lockDeleteCodespace sync.RWMutex + lockGetCodespace sync.RWMutex + lockGetCodespaceToken sync.RWMutex + lockGetUser sync.RWMutex + lockListCodespaces sync.RWMutex } // DeleteCodespace calls DeleteCodespaceFunc. @@ -113,6 +143,88 @@ func (mock *apiClientMock) DeleteCodespaceCalls() []struct { return calls } +// GetCodespace calls GetCodespaceFunc. +func (mock *apiClientMock) GetCodespace(ctx context.Context, token string, user string, name string) (*api.Codespace, error) { + if mock.GetCodespaceFunc == nil { + panic("apiClientMock.GetCodespaceFunc: method is nil but apiClient.GetCodespace was just called") + } + callInfo := struct { + Ctx context.Context + Token string + User string + Name string + }{ + Ctx: ctx, + Token: token, + User: user, + Name: name, + } + mock.lockGetCodespace.Lock() + mock.calls.GetCodespace = append(mock.calls.GetCodespace, callInfo) + mock.lockGetCodespace.Unlock() + return mock.GetCodespaceFunc(ctx, token, user, name) +} + +// GetCodespaceCalls gets all the calls that were made to GetCodespace. +// Check the length with: +// len(mockedapiClient.GetCodespaceCalls()) +func (mock *apiClientMock) GetCodespaceCalls() []struct { + Ctx context.Context + Token string + User string + Name string +} { + var calls []struct { + Ctx context.Context + Token string + User string + Name string + } + mock.lockGetCodespace.RLock() + calls = mock.calls.GetCodespace + mock.lockGetCodespace.RUnlock() + return calls +} + +// GetCodespaceToken calls GetCodespaceTokenFunc. +func (mock *apiClientMock) GetCodespaceToken(ctx context.Context, user string, name string) (string, error) { + if mock.GetCodespaceTokenFunc == nil { + panic("apiClientMock.GetCodespaceTokenFunc: method is nil but apiClient.GetCodespaceToken was just called") + } + callInfo := struct { + Ctx context.Context + User string + Name string + }{ + Ctx: ctx, + User: user, + Name: name, + } + mock.lockGetCodespaceToken.Lock() + mock.calls.GetCodespaceToken = append(mock.calls.GetCodespaceToken, callInfo) + mock.lockGetCodespaceToken.Unlock() + return mock.GetCodespaceTokenFunc(ctx, user, name) +} + +// GetCodespaceTokenCalls gets all the calls that were made to GetCodespaceToken. +// Check the length with: +// len(mockedapiClient.GetCodespaceTokenCalls()) +func (mock *apiClientMock) GetCodespaceTokenCalls() []struct { + Ctx context.Context + User string + Name string +} { + var calls []struct { + Ctx context.Context + User string + Name string + } + mock.lockGetCodespaceToken.RLock() + calls = mock.calls.GetCodespaceToken + mock.lockGetCodespaceToken.RUnlock() + return calls +} + // GetUser calls GetUserFunc. func (mock *apiClientMock) GetUser(ctx context.Context) (*api.User, error) { if mock.GetUserFunc == nil { diff --git a/cmd/ghcs/mock_prompter.go b/cmd/ghcs/mock_prompter.go index e15209c03..56581b64d 100644 --- a/cmd/ghcs/mock_prompter.go +++ b/cmd/ghcs/mock_prompter.go @@ -7,10 +7,6 @@ import ( "sync" ) -// Ensure, that prompterMock does implement prompter. -// If this is not the case, regenerate this file with moq. -var _ prompter = &prompterMock{} - // prompterMock is a mock implementation of prompter. // // func TestSomethingThatUsesprompter(t *testing.T) { From b8f35f950ca104c88489a9dd0f4586cd2a47fa36 Mon Sep 17 00:00:00 2001 From: Jose Garcia Date: Thu, 23 Sep 2021 13:14:35 -0400 Subject: [PATCH 281/290] Linter fixes --- internal/liveshare/client.go | 2 +- internal/liveshare/port_forwarder.go | 2 +- internal/liveshare/session_test.go | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/internal/liveshare/client.go b/internal/liveshare/client.go index b51e25ea6..76f8146de 100644 --- a/internal/liveshare/client.go +++ b/internal/liveshare/client.go @@ -130,7 +130,7 @@ func (s *Session) openStreamingChannel(ctx context.Context, id channelID) (ssh.C return nil, fmt.Errorf("error getting stream id: %w", err) } - span, ctx := opentracing.StartSpanFromContext(ctx, "Session.OpenChannel+SendRequest") + span, _ := opentracing.StartSpanFromContext(ctx, "Session.OpenChannel+SendRequest") defer span.Finish() channel, reqs, err := s.ssh.conn.OpenChannel("session", nil) diff --git a/internal/liveshare/port_forwarder.go b/internal/liveshare/port_forwarder.go index 56401cc4d..fcc7ba767 100644 --- a/internal/liveshare/port_forwarder.go +++ b/internal/liveshare/port_forwarder.go @@ -94,7 +94,7 @@ func (fwd *PortForwarder) shareRemotePort(ctx context.Context) (channelID, error if err != nil { err = fmt.Errorf("failed to share remote port %d: %w", fwd.remotePort, err) } - return id, nil + return id, err } func awaitError(ctx context.Context, errc <-chan error) error { diff --git a/internal/liveshare/session_test.go b/internal/liveshare/session_test.go index 47bac3108..c9a1be567 100644 --- a/internal/liveshare/session_test.go +++ b/internal/liveshare/session_test.go @@ -71,7 +71,7 @@ func TestServerStartSharing(t *testing.T) { testServer, session, err := makeMockSession( livesharetest.WithService("serverSharing.startSharing", startSharing), ) - defer testServer.Close() + defer testServer.Close() //nolint:stylecheck // httptest.Server does not return errors on Close() if err != nil { t.Errorf("error creating mock session: %w", err) From 08bc181d79f30f344f361494ac490762453edb16 Mon Sep 17 00:00:00 2001 From: Jose Garcia Date: Thu, 23 Sep 2021 13:16:20 -0400 Subject: [PATCH 282/290] Linter fixes --- internal/liveshare/session_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/internal/liveshare/session_test.go b/internal/liveshare/session_test.go index c9a1be567..bca11d885 100644 --- a/internal/liveshare/session_test.go +++ b/internal/liveshare/session_test.go @@ -71,7 +71,7 @@ func TestServerStartSharing(t *testing.T) { testServer, session, err := makeMockSession( livesharetest.WithService("serverSharing.startSharing", startSharing), ) - defer testServer.Close() //nolint:stylecheck // httptest.Server does not return errors on Close() + defer testServer.Close() //nolint - httptest.Server does not return errors on Close() if err != nil { t.Errorf("error creating mock session: %w", err) From 65dcb0f428ff703135c43db9ce1246860905f469 Mon Sep 17 00:00:00 2001 From: Jose Garcia Date: Thu, 23 Sep 2021 13:22:20 -0400 Subject: [PATCH 283/290] Linter fixes --- internal/liveshare/session_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/internal/liveshare/session_test.go b/internal/liveshare/session_test.go index bca11d885..af41dd117 100644 --- a/internal/liveshare/session_test.go +++ b/internal/liveshare/session_test.go @@ -71,7 +71,7 @@ func TestServerStartSharing(t *testing.T) { testServer, session, err := makeMockSession( livesharetest.WithService("serverSharing.startSharing", startSharing), ) - defer testServer.Close() //nolint - httptest.Server does not return errors on Close() + defer testServer.Close() //nolint:staticcheck // httptest.Server does not return errors on Close() if err != nil { t.Errorf("error creating mock session: %w", err) From 5d6ea5029ed7cf2ab39abf1cb2fc37da5883b842 Mon Sep 17 00:00:00 2001 From: Jose Garcia Date: Thu, 23 Sep 2021 13:36:04 -0400 Subject: [PATCH 284/290] Linter fixes --- internal/liveshare/client.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/internal/liveshare/client.go b/internal/liveshare/client.go index 76f8146de..2b1f97831 100644 --- a/internal/liveshare/client.go +++ b/internal/liveshare/client.go @@ -130,8 +130,9 @@ func (s *Session) openStreamingChannel(ctx context.Context, id channelID) (ssh.C return nil, fmt.Errorf("error getting stream id: %w", err) } - span, _ := opentracing.StartSpanFromContext(ctx, "Session.OpenChannel+SendRequest") + span, ctx := opentracing.StartSpanFromContext(ctx, "Session.OpenChannel+SendRequest") defer span.Finish() + _ = ctx // ctx is not currently used channel, reqs, err := s.ssh.conn.OpenChannel("session", nil) if err != nil { From 3d017b282484617a73ca185d27dfcefedefe2e46 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mislav=20Marohni=C4=87?= Date: Fri, 24 Sep 2021 15:09:41 +0200 Subject: [PATCH 285/290] Fix stderr output on delete errors --- cmd/ghcs/delete.go | 2 +- cmd/ghcs/delete_test.go | 38 +++++++++++++++++++++++++++++++++++++- 2 files changed, 38 insertions(+), 2 deletions(-) diff --git a/cmd/ghcs/delete.go b/cmd/ghcs/delete.go index 0c21c1674..94aaaf214 100644 --- a/cmd/ghcs/delete.go +++ b/cmd/ghcs/delete.go @@ -151,7 +151,7 @@ func delete(ctx context.Context, log logger, opts deleteOptions) error { codespaceName := c.Name g.Go(func() error { if err := opts.apiClient.DeleteCodespace(ctx, user.Login, codespaceName); err != nil { - _, _ = log.Errorf("error deleting codespace %q: %v", codespaceName, err) + _, _ = log.Errorf("error deleting codespace %q: %v\n", codespaceName, err) return err } return nil diff --git a/cmd/ghcs/delete_test.go b/cmd/ghcs/delete_test.go index 7d90a3bd2..47e6a4d6c 100644 --- a/cmd/ghcs/delete_test.go +++ b/cmd/ghcs/delete_test.go @@ -1,12 +1,15 @@ package ghcs import ( + "bytes" "context" + "errors" "fmt" "sort" "testing" "time" + "github.com/github/ghcs/cmd/ghcs/output" "github.com/github/ghcs/internal/api" ) @@ -22,8 +25,11 @@ func TestDelete(t *testing.T) { opts deleteOptions codespaces []*api.Codespace confirms map[string]bool + deleteErr error wantErr bool wantDeleted []string + wantStdout string + wantStderr string }{ { name: "by name", @@ -80,6 +86,24 @@ func TestDelete(t *testing.T) { }, wantDeleted: []string{"hubot-robawt-abc", "monalisa-spoonknife-c4f3"}, }, + { + name: "deletion failed", + opts: deleteOptions{ + deleteAll: true, + }, + codespaces: []*api.Codespace{ + { + Name: "monalisa-spoonknife-123", + }, + { + Name: "hubot-robawt-abc", + }, + }, + deleteErr: errors.New("aborted by test"), + wantErr: true, + wantDeleted: []string{"hubot-robawt-abc", "monalisa-spoonknife-123"}, + wantStderr: "error deleting codespace \"hubot-robawt-abc\": aborted by test\nerror deleting codespace \"monalisa-spoonknife-123\": aborted by test\n", + }, { name: "with confirm", opts: deleteOptions{ @@ -131,6 +155,9 @@ func TestDelete(t *testing.T) { if userLogin != user.Login { return fmt.Errorf("unexpected user %q", userLogin) } + if tt.deleteErr != nil { + return tt.deleteErr + } return nil }, } @@ -171,7 +198,10 @@ func TestDelete(t *testing.T) { }, } - err := delete(context.Background(), nil, opts) + stdout := &bytes.Buffer{} + stderr := &bytes.Buffer{} + log := output.NewLogger(stdout, stderr, false) + err := delete(context.Background(), log, opts) if (err != nil) != tt.wantErr { t.Errorf("delete() error = %v, wantErr %v", err, tt.wantErr) } @@ -186,6 +216,12 @@ func TestDelete(t *testing.T) { if !sliceEquals(gotDeleted, tt.wantDeleted) { t.Errorf("deleted %q, want %q", gotDeleted, tt.wantDeleted) } + if out := stdout.String(); out != tt.wantStdout { + t.Errorf("stdout = %q, want %q", out, tt.wantStdout) + } + if out := stderr.String(); out != tt.wantStderr { + t.Errorf("stderr = %q, want %q", out, tt.wantStderr) + } }) } } From ca0f89d3bc1bbf2292ec4e0e2b3fbf97e1047fd6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mislav=20Marohni=C4=87?= Date: Fri, 24 Sep 2021 16:03:44 +0200 Subject: [PATCH 286/290] Introduce an App struct that executes core business logic The Cobra commands are now a light wrapper around the App struct. Co-authored-by: Jose Garcia --- cmd/ghcs/code.go | 15 +- cmd/ghcs/common.go | 35 ++- cmd/ghcs/create.go | 36 ++- cmd/ghcs/delete.go | 36 +-- cmd/ghcs/delete_test.go | 8 +- cmd/ghcs/list.go | 13 +- cmd/ghcs/logs.go | 23 +- cmd/ghcs/main/main.go | 27 ++- cmd/ghcs/mock_api.go | 383 +++++++++++++++++++++++++++++- cmd/ghcs/ports.go | 72 +++--- cmd/ghcs/root.go | 26 +- cmd/ghcs/ssh.go | 27 +-- internal/api/api.go | 21 +- internal/codespaces/codespaces.go | 8 +- internal/codespaces/states.go | 2 +- 15 files changed, 557 insertions(+), 175 deletions(-) diff --git a/cmd/ghcs/code.go b/cmd/ghcs/code.go index 08d2cff1a..cfcd989e2 100644 --- a/cmd/ghcs/code.go +++ b/cmd/ghcs/code.go @@ -5,12 +5,11 @@ import ( "fmt" "net/url" - "github.com/github/ghcs/internal/api" "github.com/skratchdot/open-golang/open" "github.com/spf13/cobra" ) -func newCodeCmd() *cobra.Command { +func newCodeCmd(app *App) *cobra.Command { var ( codespace string useInsiders bool @@ -21,7 +20,7 @@ func newCodeCmd() *cobra.Command { Short: "Open a codespace in VS Code", Args: noArgsConstraint, RunE: func(cmd *cobra.Command, args []string) error { - return code(codespace, useInsiders) + return app.VSCode(cmd.Context(), codespace, useInsiders) }, } @@ -31,17 +30,15 @@ func newCodeCmd() *cobra.Command { return codeCmd } -func code(codespaceName string, useInsiders bool) error { - apiClient := api.New(GithubToken) - ctx := context.Background() - - user, err := apiClient.GetUser(ctx) +// VSCode opens a codespace in the local VS VSCode application. +func (a *App) VSCode(ctx context.Context, codespaceName string, useInsiders bool) error { + user, err := a.apiClient.GetUser(ctx) if err != nil { return fmt.Errorf("error getting user: %w", err) } if codespaceName == "" { - codespace, err := chooseCodespace(ctx, apiClient, user) + codespace, err := chooseCodespace(ctx, a.apiClient, user) if err != nil { if err == errNoCodespaces { return err diff --git a/cmd/ghcs/common.go b/cmd/ghcs/common.go index 371ca30b8..e60fa7c96 100644 --- a/cmd/ghcs/common.go +++ b/cmd/ghcs/common.go @@ -12,14 +12,43 @@ import ( "github.com/AlecAivazis/survey/v2" "github.com/AlecAivazis/survey/v2/terminal" + "github.com/github/ghcs/cmd/ghcs/output" "github.com/github/ghcs/internal/api" "github.com/spf13/cobra" "golang.org/x/term" ) +type App struct { + apiClient apiClient + logger *output.Logger +} + +func NewApp(logger *output.Logger, apiClient apiClient) *App { + return &App{ + apiClient: apiClient, + logger: logger, + } +} + +//go:generate moq -fmt goimports -rm -skip-ensure -out mock_api.go . apiClient +type apiClient interface { + GetUser(ctx context.Context) (*api.User, error) + GetCodespaceToken(ctx context.Context, user, name string) (string, error) + GetCodespace(ctx context.Context, token, user, name string) (*api.Codespace, error) + ListCodespaces(ctx context.Context, user string) ([]*api.Codespace, error) + DeleteCodespace(ctx context.Context, user, name string) error + StartCodespace(ctx context.Context, token string, codespace *api.Codespace) error + CreateCodespace(ctx context.Context, logger api.Logger, params *api.CreateCodespaceParams) (*api.Codespace, error) + GetRepository(ctx context.Context, nwo string) (*api.Repository, error) + AuthorizedKeys(ctx context.Context, user string) ([]byte, error) + GetCodespaceRegionLocation(ctx context.Context) (string, error) + GetCodespacesSKUs(ctx context.Context, user *api.User, repository *api.Repository, branch, location string) ([]*api.SKU, error) + GetCodespaceRepositoryContents(ctx context.Context, codespace *api.Codespace, path string) ([]byte, error) +} + var errNoCodespaces = errors.New("you have no codespaces") -func chooseCodespace(ctx context.Context, apiClient *api.API, user *api.User) (*api.Codespace, error) { +func chooseCodespace(ctx context.Context, apiClient apiClient, user *api.User) (*api.Codespace, error) { codespaces, err := apiClient.ListCodespaces(ctx, user.Login) if err != nil { return nil, fmt.Errorf("error getting codespaces: %w", err) @@ -68,7 +97,7 @@ func chooseCodespaceFromList(ctx context.Context, codespaces []*api.Codespace) ( // getOrChooseCodespace prompts the user to choose a codespace if the codespaceName is empty. // It then fetches the codespace token and the codespace record. -func getOrChooseCodespace(ctx context.Context, apiClient *api.API, user *api.User, codespaceName string) (codespace *api.Codespace, token string, err error) { +func getOrChooseCodespace(ctx context.Context, apiClient apiClient, user *api.User, codespaceName string) (codespace *api.Codespace, token string, err error) { if codespaceName == "" { codespace, err = chooseCodespace(ctx, apiClient, user) if err != nil { @@ -135,7 +164,7 @@ func ask(qs []*survey.Question, response interface{}) error { // checkAuthorizedKeys reports an error if the user has not registered any SSH keys; // see https://github.com/github/ghcs/issues/166#issuecomment-921769703. // The check is not required for security but it improves the error message. -func checkAuthorizedKeys(ctx context.Context, client *api.API, user string) error { +func checkAuthorizedKeys(ctx context.Context, client apiClient, user string) error { keys, err := client.AuthorizedKeys(ctx, user) if err != nil { return fmt.Errorf("failed to read GitHub-authorized SSH keys for %s: %w", user, err) diff --git a/cmd/ghcs/create.go b/cmd/ghcs/create.go index 345489f6b..c92a6edff 100644 --- a/cmd/ghcs/create.go +++ b/cmd/ghcs/create.go @@ -22,7 +22,7 @@ type createOptions struct { showStatus bool } -func newCreateCmd() *cobra.Command { +func newCreateCmd(app *App) *cobra.Command { opts := &createOptions{} createCmd := &cobra.Command{ @@ -30,7 +30,7 @@ func newCreateCmd() *cobra.Command { Short: "Create a codespace", Args: noArgsConstraint, RunE: func(cmd *cobra.Command, args []string) error { - return create(opts) + return app.Create(cmd.Context(), opts) }, } @@ -42,12 +42,10 @@ func newCreateCmd() *cobra.Command { return createCmd } -func create(opts *createOptions) error { - ctx := context.Background() - apiClient := api.New(GithubToken) - locationCh := getLocation(ctx, apiClient) - userCh := getUser(ctx, apiClient) - log := output.NewLogger(os.Stdout, os.Stderr, false) +// Create creates a new Codespace +func (a *App) Create(ctx context.Context, opts *createOptions) error { + locationCh := getLocation(ctx, a.apiClient) + userCh := getUser(ctx, a.apiClient) repo, err := getRepoName(opts.repo) if err != nil { @@ -58,7 +56,7 @@ func create(opts *createOptions) error { return fmt.Errorf("error getting branch name: %w", err) } - repository, err := apiClient.GetRepository(ctx, repo) + repository, err := a.apiClient.GetRepository(ctx, repo) if err != nil { return fmt.Errorf("error getting repository: %w", err) } @@ -73,7 +71,7 @@ func create(opts *createOptions) error { return fmt.Errorf("error getting codespace user: %w", userResult.Err) } - machine, err := getMachineName(ctx, opts.machine, userResult.User, repository, branch, locationResult.Location, apiClient) + machine, err := getMachineName(ctx, opts.machine, userResult.User, repository, branch, locationResult.Location, a.apiClient) if err != nil { return fmt.Errorf("error getting machine type: %w", err) } @@ -81,26 +79,26 @@ func create(opts *createOptions) error { return errors.New("there are no available machine types for this repository") } - log.Print("Creating your codespace...") - codespace, err := apiClient.CreateCodespace(ctx, log, &api.CreateCodespaceParams{ + a.logger.Print("Creating your codespace...") + codespace, err := a.apiClient.CreateCodespace(ctx, a.logger, &api.CreateCodespaceParams{ User: userResult.User.Login, RepositoryID: repository.ID, Branch: branch, Machine: machine, Location: locationResult.Location, }) - log.Print("\n") + a.logger.Print("\n") if err != nil { return fmt.Errorf("error creating codespace: %w", err) } if opts.showStatus { - if err := showStatus(ctx, log, apiClient, userResult.User, codespace); err != nil { + if err := showStatus(ctx, a.logger, a.apiClient, userResult.User, codespace); err != nil { return fmt.Errorf("show status: %w", err) } } - log.Printf("Codespace created: ") + a.logger.Printf("Codespace created: ") fmt.Fprintln(os.Stdout, codespace.Name) @@ -110,7 +108,7 @@ func create(opts *createOptions) error { // showStatus polls the codespace for a list of post create states and their status. It will keep polling // until all states have finished. Once all states have finished, we poll once more to check if any new // states have been introduced and stop polling otherwise. -func showStatus(ctx context.Context, log *output.Logger, apiClient *api.API, user *api.User, codespace *api.Codespace) error { +func showStatus(ctx context.Context, log *output.Logger, apiClient apiClient, user *api.User, codespace *api.Codespace) error { var lastState codespaces.PostCreateState var breakNextState bool @@ -177,7 +175,7 @@ type getUserResult struct { } // getUser fetches the user record associated with the GITHUB_TOKEN -func getUser(ctx context.Context, apiClient *api.API) <-chan getUserResult { +func getUser(ctx context.Context, apiClient apiClient) <-chan getUserResult { ch := make(chan getUserResult, 1) go func() { user, err := apiClient.GetUser(ctx) @@ -192,7 +190,7 @@ type locationResult struct { } // getLocation fetches the closest Codespace datacenter region/location to the user. -func getLocation(ctx context.Context, apiClient *api.API) <-chan locationResult { +func getLocation(ctx context.Context, apiClient apiClient) <-chan locationResult { ch := make(chan locationResult, 1) go func() { location, err := apiClient.GetCodespaceRegionLocation(ctx) @@ -236,7 +234,7 @@ func getBranchName(branch string) (string, error) { } // getMachineName prompts the user to select the machine type, or validates the machine if non-empty. -func getMachineName(ctx context.Context, machine string, user *api.User, repo *api.Repository, branch, location string, apiClient *api.API) (string, error) { +func getMachineName(ctx context.Context, machine string, user *api.User, repo *api.Repository, branch, location string, apiClient apiClient) (string, error) { skus, err := apiClient.GetCodespacesSKUs(ctx, user, repo, branch, location) if err != nil { return "", fmt.Errorf("error requesting machine instance types: %w", err) diff --git a/cmd/ghcs/delete.go b/cmd/ghcs/delete.go index b5d25e7bb..d7fed4e68 100644 --- a/cmd/ghcs/delete.go +++ b/cmd/ghcs/delete.go @@ -4,12 +4,10 @@ import ( "context" "errors" "fmt" - "os" "strings" "time" "github.com/AlecAivazis/survey/v2" - "github.com/github/ghcs/cmd/ghcs/output" "github.com/github/ghcs/internal/api" "github.com/spf13/cobra" "golang.org/x/sync/errgroup" @@ -24,7 +22,6 @@ type deleteOptions struct { isInteractive bool now func() time.Time - apiClient apiClient prompter prompter } @@ -33,20 +30,10 @@ type prompter interface { Confirm(message string) (bool, error) } -//go:generate moq -fmt goimports -rm -skip-ensure -out mock_api.go . apiClient -type apiClient interface { - GetUser(ctx context.Context) (*api.User, error) - GetCodespaceToken(ctx context.Context, user, name string) (string, error) - GetCodespace(ctx context.Context, token, user, name string) (*api.Codespace, error) - ListCodespaces(ctx context.Context, user string) ([]*api.Codespace, error) - DeleteCodespace(ctx context.Context, user, name string) error -} - -func newDeleteCmd() *cobra.Command { +func newDeleteCmd(app *App) *cobra.Command { opts := deleteOptions{ isInteractive: hasTTY, now: time.Now, - apiClient: api.New(os.Getenv("GITHUB_TOKEN")), prompter: &surveyPrompter{}, } @@ -58,8 +45,7 @@ func newDeleteCmd() *cobra.Command { if opts.deleteAll && opts.repoFilter != "" { return errors.New("both --all and --repo is not supported") } - log := output.NewLogger(cmd.OutOrStdout(), cmd.ErrOrStderr(), !opts.isInteractive) - return delete(context.Background(), log, opts) + return app.Delete(cmd.Context(), opts) }, } @@ -72,12 +58,8 @@ func newDeleteCmd() *cobra.Command { return deleteCmd } -type logger interface { - Errorf(format string, v ...interface{}) (int, error) -} - -func delete(ctx context.Context, log logger, opts deleteOptions) error { - user, err := opts.apiClient.GetUser(ctx) +func (a *App) Delete(ctx context.Context, opts deleteOptions) error { + user, err := a.apiClient.GetUser(ctx) if err != nil { return fmt.Errorf("error getting user: %w", err) } @@ -85,7 +67,7 @@ func delete(ctx context.Context, log logger, opts deleteOptions) error { var codespaces []*api.Codespace nameFilter := opts.codespaceName if nameFilter == "" { - codespaces, err = opts.apiClient.ListCodespaces(ctx, user.Login) + codespaces, err = a.apiClient.ListCodespaces(ctx, user.Login) if err != nil { return fmt.Errorf("error getting codespaces: %w", err) } @@ -99,12 +81,12 @@ func delete(ctx context.Context, log logger, opts deleteOptions) error { } } else { // TODO: this token is discarded and then re-requested later in DeleteCodespace - token, err := opts.apiClient.GetCodespaceToken(ctx, user.Login, nameFilter) + token, err := a.apiClient.GetCodespaceToken(ctx, user.Login, nameFilter) if err != nil { return fmt.Errorf("error getting codespace token: %w", err) } - codespace, err := opts.apiClient.GetCodespace(ctx, token, user.Login, nameFilter) + codespace, err := a.apiClient.GetCodespace(ctx, token, user.Login, nameFilter) if err != nil { return fmt.Errorf("error fetching codespace information: %w", err) } @@ -150,8 +132,8 @@ func delete(ctx context.Context, log logger, opts deleteOptions) error { for _, c := range codespacesToDelete { codespaceName := c.Name g.Go(func() error { - if err := opts.apiClient.DeleteCodespace(ctx, user.Login, codespaceName); err != nil { - _, _ = log.Errorf("error deleting codespace %q: %v\n", codespaceName, err) + if err := a.apiClient.DeleteCodespace(ctx, user.Login, codespaceName); err != nil { + _, _ = a.logger.Errorf("error deleting codespace %q: %v\n", codespaceName, err) return err } return nil diff --git a/cmd/ghcs/delete_test.go b/cmd/ghcs/delete_test.go index 47e6a4d6c..ab7b01d30 100644 --- a/cmd/ghcs/delete_test.go +++ b/cmd/ghcs/delete_test.go @@ -186,7 +186,6 @@ func TestDelete(t *testing.T) { } } opts := tt.opts - opts.apiClient = apiMock opts.now = func() time.Time { return now } opts.prompter = &prompterMock{ ConfirmFunc: func(msg string) (bool, error) { @@ -200,8 +199,11 @@ func TestDelete(t *testing.T) { stdout := &bytes.Buffer{} stderr := &bytes.Buffer{} - log := output.NewLogger(stdout, stderr, false) - err := delete(context.Background(), log, opts) + app := &App{ + apiClient: apiMock, + logger: output.NewLogger(stdout, stderr, false), + } + err := app.Delete(context.Background(), opts) if (err != nil) != tt.wantErr { t.Errorf("delete() error = %v, wantErr %v", err, tt.wantErr) } diff --git a/cmd/ghcs/list.go b/cmd/ghcs/list.go index 065b7aa6d..842b9313d 100644 --- a/cmd/ghcs/list.go +++ b/cmd/ghcs/list.go @@ -14,7 +14,7 @@ type listOptions struct { asJSON bool } -func newListCmd() *cobra.Command { +func newListCmd(app *App) *cobra.Command { opts := &listOptions{} listCmd := &cobra.Command{ @@ -22,7 +22,7 @@ func newListCmd() *cobra.Command { Short: "List your codespaces", Args: noArgsConstraint, RunE: func(cmd *cobra.Command, args []string) error { - return list(opts) + return app.List(cmd.Context(), opts) }, } @@ -31,16 +31,13 @@ func newListCmd() *cobra.Command { return listCmd } -func list(opts *listOptions) error { - apiClient := api.New(GithubToken) - ctx := context.Background() - - user, err := apiClient.GetUser(ctx) +func (a *App) List(ctx context.Context, opts *listOptions) error { + user, err := a.apiClient.GetUser(ctx) if err != nil { return fmt.Errorf("error getting user: %w", err) } - codespaces, err := apiClient.ListCodespaces(ctx, user.Login) + codespaces, err := a.apiClient.ListCodespaces(ctx, user.Login) if err != nil { return fmt.Errorf("error getting codespaces: %w", err) } diff --git a/cmd/ghcs/logs.go b/cmd/ghcs/logs.go index 0cddc6377..7f73d893c 100644 --- a/cmd/ghcs/logs.go +++ b/cmd/ghcs/logs.go @@ -4,29 +4,24 @@ import ( "context" "fmt" "net" - "os" - "github.com/github/ghcs/cmd/ghcs/output" - "github.com/github/ghcs/internal/api" "github.com/github/ghcs/internal/codespaces" "github.com/github/ghcs/internal/liveshare" "github.com/spf13/cobra" ) -func newLogsCmd() *cobra.Command { +func newLogsCmd(app *App) *cobra.Command { var ( codespace string follow bool ) - log := output.NewLogger(os.Stdout, os.Stderr, false) - logsCmd := &cobra.Command{ Use: "logs", Short: "Access codespace logs", Args: noArgsConstraint, RunE: func(cmd *cobra.Command, args []string) error { - return logs(context.Background(), log, codespace, follow) + return app.Logs(cmd.Context(), codespace, follow) }, } @@ -36,29 +31,27 @@ func newLogsCmd() *cobra.Command { return logsCmd } -func logs(ctx context.Context, log *output.Logger, codespaceName string, follow bool) (err error) { +func (a *App) Logs(ctx context.Context, codespaceName string, follow bool) (err error) { // Ensure all child tasks (port forwarding, remote exec) terminate before return. ctx, cancel := context.WithCancel(ctx) defer cancel() - apiClient := api.New(GithubToken) - - user, err := apiClient.GetUser(ctx) + user, err := a.apiClient.GetUser(ctx) if err != nil { return fmt.Errorf("getting user: %w", err) } authkeys := make(chan error, 1) go func() { - authkeys <- checkAuthorizedKeys(ctx, apiClient, user.Login) + authkeys <- checkAuthorizedKeys(ctx, a.apiClient, user.Login) }() - codespace, token, err := getOrChooseCodespace(ctx, apiClient, user, codespaceName) + codespace, token, err := getOrChooseCodespace(ctx, a.apiClient, user, codespaceName) if err != nil { return fmt.Errorf("get or choose codespace: %w", err) } - session, err := codespaces.ConnectToLiveshare(ctx, log, apiClient, user.Login, token, codespace) + session, err := codespaces.ConnectToLiveshare(ctx, a.logger, a.apiClient, user.Login, token, codespace) if err != nil { return fmt.Errorf("connecting to Live Share: %w", err) } @@ -76,7 +69,7 @@ func logs(ctx context.Context, log *output.Logger, codespaceName string, follow defer listen.Close() localPort := listen.Addr().(*net.TCPAddr).Port - log.Println("Fetching SSH Details...") + a.logger.Println("Fetching SSH Details...") remoteSSHServerPort, sshUser, err := session.StartSSHServer(ctx) if err != nil { return fmt.Errorf("error getting ssh server details: %w", err) diff --git a/cmd/ghcs/main/main.go b/cmd/ghcs/main/main.go index 6b890d740..7c6b2a175 100644 --- a/cmd/ghcs/main/main.go +++ b/cmd/ghcs/main/main.go @@ -4,22 +4,45 @@ import ( "errors" "fmt" "io" + "net/http" "os" "github.com/github/ghcs/cmd/ghcs" + "github.com/github/ghcs/cmd/ghcs/output" + "github.com/github/ghcs/internal/api" "github.com/spf13/cobra" ) func main() { - rootCmd := ghcs.NewRootCmd() + token := os.Getenv("GITHUB_TOKEN") + rootCmd := ghcs.NewRootCmd(ghcs.NewApp( + output.NewLogger(os.Stdout, os.Stderr, false), + api.New(token, http.DefaultClient), + )) + + // Require GITHUB_TOKEN through a Cobra pre-run hook so that Cobra's help system for commands can still + // function without the token set. + oldPreRun := rootCmd.PersistentPreRunE + rootCmd.PersistentPreRunE = func(cmd *cobra.Command, args []string) error { + if token == "" { + return errTokenMissing + } + if oldPreRun != nil { + return oldPreRun(cmd, args) + } + return nil + } + if cmd, err := rootCmd.ExecuteC(); err != nil { explainError(os.Stderr, err, cmd) os.Exit(1) } } +var errTokenMissing = errors.New("GITHUB_TOKEN is missing") + func explainError(w io.Writer, err error, cmd *cobra.Command) { - if errors.Is(err, ghcs.ErrTokenMissing) { + if errors.Is(err, errTokenMissing) { fmt.Fprintln(w, "The GITHUB_TOKEN environment variable is required. Create a Personal Access Token at https://github.com/settings/tokens/new?scopes=repo") fmt.Fprintln(w, "Make sure to enable SSO for your organizations after creating the token.") return diff --git a/cmd/ghcs/mock_api.go b/cmd/ghcs/mock_api.go index 256a30ec3..93abe7ed6 100644 --- a/cmd/ghcs/mock_api.go +++ b/cmd/ghcs/mock_api.go @@ -16,21 +16,42 @@ import ( // // // make and configure a mocked apiClient // mockedapiClient := &apiClientMock{ +// AuthorizedKeysFunc: func(ctx context.Context, user string) ([]byte, error) { +// panic("mock out the AuthorizedKeys method") +// }, +// CreateCodespaceFunc: func(ctx context.Context, logger api.Logger, params *api.CreateCodespaceParams) (*api.Codespace, error) { +// panic("mock out the CreateCodespace method") +// }, // DeleteCodespaceFunc: func(ctx context.Context, user string, name string) error { // panic("mock out the DeleteCodespace method") // }, // GetCodespaceFunc: func(ctx context.Context, token string, user string, name string) (*api.Codespace, error) { // panic("mock out the GetCodespace method") // }, +// GetCodespaceRegionLocationFunc: func(ctx context.Context) (string, error) { +// panic("mock out the GetCodespaceRegionLocation method") +// }, +// GetCodespaceRepositoryContentsFunc: func(ctx context.Context, codespace *api.Codespace, path string) ([]byte, error) { +// panic("mock out the GetCodespaceRepositoryContents method") +// }, // GetCodespaceTokenFunc: func(ctx context.Context, user string, name string) (string, error) { // panic("mock out the GetCodespaceToken method") // }, +// GetCodespacesSKUsFunc: func(ctx context.Context, user *api.User, repository *api.Repository, branch string, location string) ([]*api.SKU, error) { +// panic("mock out the GetCodespacesSKUs method") +// }, +// GetRepositoryFunc: func(ctx context.Context, nwo string) (*api.Repository, error) { +// panic("mock out the GetRepository method") +// }, // GetUserFunc: func(ctx context.Context) (*api.User, error) { // panic("mock out the GetUser method") // }, // ListCodespacesFunc: func(ctx context.Context, user string) ([]*api.Codespace, error) { // panic("mock out the ListCodespaces method") // }, +// StartCodespaceFunc: func(ctx context.Context, token string, codespace *api.Codespace) error { +// panic("mock out the StartCodespace method") +// }, // } // // // use mockedapiClient in code that requires apiClient @@ -38,23 +59,60 @@ import ( // // } type apiClientMock struct { + // AuthorizedKeysFunc mocks the AuthorizedKeys method. + AuthorizedKeysFunc func(ctx context.Context, user string) ([]byte, error) + + // CreateCodespaceFunc mocks the CreateCodespace method. + CreateCodespaceFunc func(ctx context.Context, logger api.Logger, params *api.CreateCodespaceParams) (*api.Codespace, error) + // DeleteCodespaceFunc mocks the DeleteCodespace method. DeleteCodespaceFunc func(ctx context.Context, user string, name string) error // GetCodespaceFunc mocks the GetCodespace method. GetCodespaceFunc func(ctx context.Context, token string, user string, name string) (*api.Codespace, error) + // GetCodespaceRegionLocationFunc mocks the GetCodespaceRegionLocation method. + GetCodespaceRegionLocationFunc func(ctx context.Context) (string, error) + + // GetCodespaceRepositoryContentsFunc mocks the GetCodespaceRepositoryContents method. + GetCodespaceRepositoryContentsFunc func(ctx context.Context, codespace *api.Codespace, path string) ([]byte, error) + // GetCodespaceTokenFunc mocks the GetCodespaceToken method. GetCodespaceTokenFunc func(ctx context.Context, user string, name string) (string, error) + // GetCodespacesSKUsFunc mocks the GetCodespacesSKUs method. + GetCodespacesSKUsFunc func(ctx context.Context, user *api.User, repository *api.Repository, branch string, location string) ([]*api.SKU, error) + + // GetRepositoryFunc mocks the GetRepository method. + GetRepositoryFunc func(ctx context.Context, nwo string) (*api.Repository, error) + // GetUserFunc mocks the GetUser method. GetUserFunc func(ctx context.Context) (*api.User, error) // ListCodespacesFunc mocks the ListCodespaces method. ListCodespacesFunc func(ctx context.Context, user string) ([]*api.Codespace, error) + // StartCodespaceFunc mocks the StartCodespace method. + StartCodespaceFunc func(ctx context.Context, token string, codespace *api.Codespace) error + // calls tracks calls to the methods. calls struct { + // AuthorizedKeys holds details about calls to the AuthorizedKeys method. + AuthorizedKeys []struct { + // Ctx is the ctx argument value. + Ctx context.Context + // User is the user argument value. + User string + } + // CreateCodespace holds details about calls to the CreateCodespace method. + CreateCodespace []struct { + // Ctx is the ctx argument value. + Ctx context.Context + // Logger is the logger argument value. + Logger api.Logger + // Params is the params argument value. + Params *api.CreateCodespaceParams + } // DeleteCodespace holds details about calls to the DeleteCodespace method. DeleteCodespace []struct { // Ctx is the ctx argument value. @@ -75,6 +133,20 @@ type apiClientMock struct { // Name is the name argument value. Name string } + // GetCodespaceRegionLocation holds details about calls to the GetCodespaceRegionLocation method. + GetCodespaceRegionLocation []struct { + // Ctx is the ctx argument value. + Ctx context.Context + } + // GetCodespaceRepositoryContents holds details about calls to the GetCodespaceRepositoryContents method. + GetCodespaceRepositoryContents []struct { + // Ctx is the ctx argument value. + Ctx context.Context + // Codespace is the codespace argument value. + Codespace *api.Codespace + // Path is the path argument value. + Path string + } // GetCodespaceToken holds details about calls to the GetCodespaceToken method. GetCodespaceToken []struct { // Ctx is the ctx argument value. @@ -84,6 +156,26 @@ type apiClientMock struct { // Name is the name argument value. Name string } + // GetCodespacesSKUs holds details about calls to the GetCodespacesSKUs method. + GetCodespacesSKUs []struct { + // Ctx is the ctx argument value. + Ctx context.Context + // User is the user argument value. + User *api.User + // Repository is the repository argument value. + Repository *api.Repository + // Branch is the branch argument value. + Branch string + // Location is the location argument value. + Location string + } + // GetRepository holds details about calls to the GetRepository method. + GetRepository []struct { + // Ctx is the ctx argument value. + Ctx context.Context + // Nwo is the nwo argument value. + Nwo string + } // GetUser holds details about calls to the GetUser method. GetUser []struct { // Ctx is the ctx argument value. @@ -96,12 +188,102 @@ type apiClientMock struct { // User is the user argument value. User string } + // StartCodespace holds details about calls to the StartCodespace method. + StartCodespace []struct { + // Ctx is the ctx argument value. + Ctx context.Context + // Token is the token argument value. + Token string + // Codespace is the codespace argument value. + Codespace *api.Codespace + } } - lockDeleteCodespace sync.RWMutex - lockGetCodespace sync.RWMutex - lockGetCodespaceToken sync.RWMutex - lockGetUser sync.RWMutex - lockListCodespaces sync.RWMutex + lockAuthorizedKeys sync.RWMutex + lockCreateCodespace sync.RWMutex + lockDeleteCodespace sync.RWMutex + lockGetCodespace sync.RWMutex + lockGetCodespaceRegionLocation sync.RWMutex + lockGetCodespaceRepositoryContents sync.RWMutex + lockGetCodespaceToken sync.RWMutex + lockGetCodespacesSKUs sync.RWMutex + lockGetRepository sync.RWMutex + lockGetUser sync.RWMutex + lockListCodespaces sync.RWMutex + lockStartCodespace sync.RWMutex +} + +// AuthorizedKeys calls AuthorizedKeysFunc. +func (mock *apiClientMock) AuthorizedKeys(ctx context.Context, user string) ([]byte, error) { + if mock.AuthorizedKeysFunc == nil { + panic("apiClientMock.AuthorizedKeysFunc: method is nil but apiClient.AuthorizedKeys was just called") + } + callInfo := struct { + Ctx context.Context + User string + }{ + Ctx: ctx, + User: user, + } + mock.lockAuthorizedKeys.Lock() + mock.calls.AuthorizedKeys = append(mock.calls.AuthorizedKeys, callInfo) + mock.lockAuthorizedKeys.Unlock() + return mock.AuthorizedKeysFunc(ctx, user) +} + +// AuthorizedKeysCalls gets all the calls that were made to AuthorizedKeys. +// Check the length with: +// len(mockedapiClient.AuthorizedKeysCalls()) +func (mock *apiClientMock) AuthorizedKeysCalls() []struct { + Ctx context.Context + User string +} { + var calls []struct { + Ctx context.Context + User string + } + mock.lockAuthorizedKeys.RLock() + calls = mock.calls.AuthorizedKeys + mock.lockAuthorizedKeys.RUnlock() + return calls +} + +// CreateCodespace calls CreateCodespaceFunc. +func (mock *apiClientMock) CreateCodespace(ctx context.Context, logger api.Logger, params *api.CreateCodespaceParams) (*api.Codespace, error) { + if mock.CreateCodespaceFunc == nil { + panic("apiClientMock.CreateCodespaceFunc: method is nil but apiClient.CreateCodespace was just called") + } + callInfo := struct { + Ctx context.Context + Logger api.Logger + Params *api.CreateCodespaceParams + }{ + Ctx: ctx, + Logger: logger, + Params: params, + } + mock.lockCreateCodespace.Lock() + mock.calls.CreateCodespace = append(mock.calls.CreateCodespace, callInfo) + mock.lockCreateCodespace.Unlock() + return mock.CreateCodespaceFunc(ctx, logger, params) +} + +// CreateCodespaceCalls gets all the calls that were made to CreateCodespace. +// Check the length with: +// len(mockedapiClient.CreateCodespaceCalls()) +func (mock *apiClientMock) CreateCodespaceCalls() []struct { + Ctx context.Context + Logger api.Logger + Params *api.CreateCodespaceParams +} { + var calls []struct { + Ctx context.Context + Logger api.Logger + Params *api.CreateCodespaceParams + } + mock.lockCreateCodespace.RLock() + calls = mock.calls.CreateCodespace + mock.lockCreateCodespace.RUnlock() + return calls } // DeleteCodespace calls DeleteCodespaceFunc. @@ -186,6 +368,76 @@ func (mock *apiClientMock) GetCodespaceCalls() []struct { return calls } +// GetCodespaceRegionLocation calls GetCodespaceRegionLocationFunc. +func (mock *apiClientMock) GetCodespaceRegionLocation(ctx context.Context) (string, error) { + if mock.GetCodespaceRegionLocationFunc == nil { + panic("apiClientMock.GetCodespaceRegionLocationFunc: method is nil but apiClient.GetCodespaceRegionLocation was just called") + } + callInfo := struct { + Ctx context.Context + }{ + Ctx: ctx, + } + mock.lockGetCodespaceRegionLocation.Lock() + mock.calls.GetCodespaceRegionLocation = append(mock.calls.GetCodespaceRegionLocation, callInfo) + mock.lockGetCodespaceRegionLocation.Unlock() + return mock.GetCodespaceRegionLocationFunc(ctx) +} + +// GetCodespaceRegionLocationCalls gets all the calls that were made to GetCodespaceRegionLocation. +// Check the length with: +// len(mockedapiClient.GetCodespaceRegionLocationCalls()) +func (mock *apiClientMock) GetCodespaceRegionLocationCalls() []struct { + Ctx context.Context +} { + var calls []struct { + Ctx context.Context + } + mock.lockGetCodespaceRegionLocation.RLock() + calls = mock.calls.GetCodespaceRegionLocation + mock.lockGetCodespaceRegionLocation.RUnlock() + return calls +} + +// GetCodespaceRepositoryContents calls GetCodespaceRepositoryContentsFunc. +func (mock *apiClientMock) GetCodespaceRepositoryContents(ctx context.Context, codespace *api.Codespace, path string) ([]byte, error) { + if mock.GetCodespaceRepositoryContentsFunc == nil { + panic("apiClientMock.GetCodespaceRepositoryContentsFunc: method is nil but apiClient.GetCodespaceRepositoryContents was just called") + } + callInfo := struct { + Ctx context.Context + Codespace *api.Codespace + Path string + }{ + Ctx: ctx, + Codespace: codespace, + Path: path, + } + mock.lockGetCodespaceRepositoryContents.Lock() + mock.calls.GetCodespaceRepositoryContents = append(mock.calls.GetCodespaceRepositoryContents, callInfo) + mock.lockGetCodespaceRepositoryContents.Unlock() + return mock.GetCodespaceRepositoryContentsFunc(ctx, codespace, path) +} + +// GetCodespaceRepositoryContentsCalls gets all the calls that were made to GetCodespaceRepositoryContents. +// Check the length with: +// len(mockedapiClient.GetCodespaceRepositoryContentsCalls()) +func (mock *apiClientMock) GetCodespaceRepositoryContentsCalls() []struct { + Ctx context.Context + Codespace *api.Codespace + Path string +} { + var calls []struct { + Ctx context.Context + Codespace *api.Codespace + Path string + } + mock.lockGetCodespaceRepositoryContents.RLock() + calls = mock.calls.GetCodespaceRepositoryContents + mock.lockGetCodespaceRepositoryContents.RUnlock() + return calls +} + // GetCodespaceToken calls GetCodespaceTokenFunc. func (mock *apiClientMock) GetCodespaceToken(ctx context.Context, user string, name string) (string, error) { if mock.GetCodespaceTokenFunc == nil { @@ -225,6 +477,88 @@ func (mock *apiClientMock) GetCodespaceTokenCalls() []struct { return calls } +// GetCodespacesSKUs calls GetCodespacesSKUsFunc. +func (mock *apiClientMock) GetCodespacesSKUs(ctx context.Context, user *api.User, repository *api.Repository, branch string, location string) ([]*api.SKU, error) { + if mock.GetCodespacesSKUsFunc == nil { + panic("apiClientMock.GetCodespacesSKUsFunc: method is nil but apiClient.GetCodespacesSKUs was just called") + } + callInfo := struct { + Ctx context.Context + User *api.User + Repository *api.Repository + Branch string + Location string + }{ + Ctx: ctx, + User: user, + Repository: repository, + Branch: branch, + Location: location, + } + mock.lockGetCodespacesSKUs.Lock() + mock.calls.GetCodespacesSKUs = append(mock.calls.GetCodespacesSKUs, callInfo) + mock.lockGetCodespacesSKUs.Unlock() + return mock.GetCodespacesSKUsFunc(ctx, user, repository, branch, location) +} + +// GetCodespacesSKUsCalls gets all the calls that were made to GetCodespacesSKUs. +// Check the length with: +// len(mockedapiClient.GetCodespacesSKUsCalls()) +func (mock *apiClientMock) GetCodespacesSKUsCalls() []struct { + Ctx context.Context + User *api.User + Repository *api.Repository + Branch string + Location string +} { + var calls []struct { + Ctx context.Context + User *api.User + Repository *api.Repository + Branch string + Location string + } + mock.lockGetCodespacesSKUs.RLock() + calls = mock.calls.GetCodespacesSKUs + mock.lockGetCodespacesSKUs.RUnlock() + return calls +} + +// GetRepository calls GetRepositoryFunc. +func (mock *apiClientMock) GetRepository(ctx context.Context, nwo string) (*api.Repository, error) { + if mock.GetRepositoryFunc == nil { + panic("apiClientMock.GetRepositoryFunc: method is nil but apiClient.GetRepository was just called") + } + callInfo := struct { + Ctx context.Context + Nwo string + }{ + Ctx: ctx, + Nwo: nwo, + } + mock.lockGetRepository.Lock() + mock.calls.GetRepository = append(mock.calls.GetRepository, callInfo) + mock.lockGetRepository.Unlock() + return mock.GetRepositoryFunc(ctx, nwo) +} + +// GetRepositoryCalls gets all the calls that were made to GetRepository. +// Check the length with: +// len(mockedapiClient.GetRepositoryCalls()) +func (mock *apiClientMock) GetRepositoryCalls() []struct { + Ctx context.Context + Nwo string +} { + var calls []struct { + Ctx context.Context + Nwo string + } + mock.lockGetRepository.RLock() + calls = mock.calls.GetRepository + mock.lockGetRepository.RUnlock() + return calls +} + // GetUser calls GetUserFunc. func (mock *apiClientMock) GetUser(ctx context.Context) (*api.User, error) { if mock.GetUserFunc == nil { @@ -290,3 +624,42 @@ func (mock *apiClientMock) ListCodespacesCalls() []struct { mock.lockListCodespaces.RUnlock() return calls } + +// StartCodespace calls StartCodespaceFunc. +func (mock *apiClientMock) StartCodespace(ctx context.Context, token string, codespace *api.Codespace) error { + if mock.StartCodespaceFunc == nil { + panic("apiClientMock.StartCodespaceFunc: method is nil but apiClient.StartCodespace was just called") + } + callInfo := struct { + Ctx context.Context + Token string + Codespace *api.Codespace + }{ + Ctx: ctx, + Token: token, + Codespace: codespace, + } + mock.lockStartCodespace.Lock() + mock.calls.StartCodespace = append(mock.calls.StartCodespace, callInfo) + mock.lockStartCodespace.Unlock() + return mock.StartCodespaceFunc(ctx, token, codespace) +} + +// StartCodespaceCalls gets all the calls that were made to StartCodespace. +// Check the length with: +// len(mockedapiClient.StartCodespaceCalls()) +func (mock *apiClientMock) StartCodespaceCalls() []struct { + Ctx context.Context + Token string + Codespace *api.Codespace +} { + var calls []struct { + Ctx context.Context + Token string + Codespace *api.Codespace + } + mock.lockStartCodespace.RLock() + calls = mock.calls.StartCodespace + mock.lockStartCodespace.RUnlock() + return calls +} diff --git a/cmd/ghcs/ports.go b/cmd/ghcs/ports.go index f423245bd..06eabad6d 100644 --- a/cmd/ghcs/ports.go +++ b/cmd/ghcs/ports.go @@ -22,7 +22,7 @@ import ( // newPortsCmd returns a Cobra "ports" command that displays a table of available ports, // according to the specified flags. -func newPortsCmd() *cobra.Command { +func newPortsCmd(app *App) *cobra.Command { var ( codespace string asJSON bool @@ -33,31 +33,28 @@ func newPortsCmd() *cobra.Command { Short: "List ports in a codespace", Args: noArgsConstraint, RunE: func(cmd *cobra.Command, args []string) error { - return ports(codespace, asJSON) + return app.ListPorts(cmd.Context(), codespace, asJSON) }, } portsCmd.PersistentFlags().StringVarP(&codespace, "codespace", "c", "", "Name of the codespace") portsCmd.Flags().BoolVar(&asJSON, "json", false, "Output as JSON") - portsCmd.AddCommand(newPortsPublicCmd()) - portsCmd.AddCommand(newPortsPrivateCmd()) - portsCmd.AddCommand(newPortsForwardCmd()) + portsCmd.AddCommand(newPortsPublicCmd(app)) + portsCmd.AddCommand(newPortsPrivateCmd(app)) + portsCmd.AddCommand(newPortsForwardCmd(app)) return portsCmd } -func ports(codespaceName string, asJSON bool) (err error) { - apiClient := api.New(os.Getenv("GITHUB_TOKEN")) - ctx := context.Background() - log := output.NewLogger(os.Stdout, os.Stderr, asJSON) - - user, err := apiClient.GetUser(ctx) +// ListPorts lists known ports in a codespace. +func (a *App) ListPorts(ctx context.Context, codespaceName string, asJSON bool) (err error) { + user, err := a.apiClient.GetUser(ctx) if err != nil { return fmt.Errorf("error getting user: %w", err) } - codespace, token, err := getOrChooseCodespace(ctx, apiClient, user, codespaceName) + codespace, token, err := getOrChooseCodespace(ctx, a.apiClient, user, codespaceName) if err != nil { // TODO(josebalius): remove special handling of this error here and it other places if err == errNoCodespaces { @@ -66,15 +63,15 @@ func ports(codespaceName string, asJSON bool) (err error) { return fmt.Errorf("error choosing codespace: %w", err) } - devContainerCh := getDevContainer(ctx, apiClient, codespace) + devContainerCh := getDevContainer(ctx, a.apiClient, codespace) - session, err := codespaces.ConnectToLiveshare(ctx, log, apiClient, user.Login, token, codespace) + session, err := codespaces.ConnectToLiveshare(ctx, a.logger, a.apiClient, user.Login, token, codespace) if err != nil { return fmt.Errorf("error connecting to Live Share: %w", err) } defer safeClose(session, &err) - log.Println("Loading ports...") + a.logger.Println("Loading ports...") ports, err := session.GetSharedServers(ctx) if err != nil { return fmt.Errorf("error getting ports of shared servers: %w", err) @@ -83,7 +80,7 @@ func ports(codespaceName string, asJSON bool) (err error) { devContainerResult := <-devContainerCh if devContainerResult.err != nil { // Warn about failure to read the devcontainer file. Not a ghcs command error. - _, _ = log.Errorf("Failed to get port names: %v\n", devContainerResult.err.Error()) + _, _ = a.logger.Errorf("Failed to get port names: %v\n", devContainerResult.err.Error()) } table := output.NewTable(os.Stdout, asJSON) @@ -122,7 +119,7 @@ type portAttribute struct { Label string `json:"label"` } -func getDevContainer(ctx context.Context, apiClient *api.API, codespace *api.Codespace) <-chan devContainerResult { +func getDevContainer(ctx context.Context, apiClient apiClient, codespace *api.Codespace) <-chan devContainerResult { ch := make(chan devContainerResult, 1) go func() { contents, err := apiClient.GetCodespaceRepositoryContents(ctx, codespace, ".devcontainer/devcontainer.json") @@ -154,7 +151,7 @@ func getDevContainer(ctx context.Context, apiClient *api.API, codespace *api.Cod } // newPortsPublicCmd returns a Cobra "ports public" subcommand, which makes a given port public. -func newPortsPublicCmd() *cobra.Command { +func newPortsPublicCmd(app *App) *cobra.Command { return &cobra.Command{ Use: "public ", Short: "Mark port as public", @@ -168,14 +165,13 @@ func newPortsPublicCmd() *cobra.Command { return fmt.Errorf("get codespace flag: %w", err) } - log := output.NewLogger(os.Stdout, os.Stderr, false) - return updatePortVisibility(log, codespace, args[0], true) + return app.UpdatePortVisibility(cmd.Context(), codespace, args[0], true) }, } } // newPortsPrivateCmd returns a Cobra "ports private" subcommand, which makes a given port private. -func newPortsPrivateCmd() *cobra.Command { +func newPortsPrivateCmd(app *App) *cobra.Command { return &cobra.Command{ Use: "private ", Short: "Mark port as private", @@ -189,22 +185,18 @@ func newPortsPrivateCmd() *cobra.Command { return fmt.Errorf("get codespace flag: %w", err) } - log := output.NewLogger(os.Stdout, os.Stderr, false) - return updatePortVisibility(log, codespace, args[0], false) + return app.UpdatePortVisibility(cmd.Context(), codespace, args[0], false) }, } } -func updatePortVisibility(log *output.Logger, codespaceName, sourcePort string, public bool) (err error) { - ctx := context.Background() - apiClient := api.New(GithubToken) - - user, err := apiClient.GetUser(ctx) +func (a *App) UpdatePortVisibility(ctx context.Context, codespaceName, sourcePort string, public bool) (err error) { + user, err := a.apiClient.GetUser(ctx) if err != nil { return fmt.Errorf("error getting user: %w", err) } - codespace, token, err := getOrChooseCodespace(ctx, apiClient, user, codespaceName) + codespace, token, err := getOrChooseCodespace(ctx, a.apiClient, user, codespaceName) if err != nil { if err == errNoCodespaces { return err @@ -212,7 +204,7 @@ func updatePortVisibility(log *output.Logger, codespaceName, sourcePort string, return fmt.Errorf("error getting codespace: %w", err) } - session, err := codespaces.ConnectToLiveshare(ctx, log, apiClient, user.Login, token, codespace) + session, err := codespaces.ConnectToLiveshare(ctx, a.logger, a.apiClient, user.Login, token, codespace) if err != nil { return fmt.Errorf("error connecting to Live Share: %w", err) } @@ -231,14 +223,14 @@ func updatePortVisibility(log *output.Logger, codespaceName, sourcePort string, if !public { state = "PRIVATE" } - log.Printf("Port %s is now %s.\n", sourcePort, state) + a.logger.Printf("Port %s is now %s.\n", sourcePort, state) return nil } // NewPortsForwardCmd returns a Cobra "ports forward" subcommand, which forwards a set of // port pairs from the codespace to localhost. -func newPortsForwardCmd() *cobra.Command { +func newPortsForwardCmd(app *App) *cobra.Command { return &cobra.Command{ Use: "forward :...", Short: "Forward ports", @@ -252,27 +244,23 @@ func newPortsForwardCmd() *cobra.Command { return fmt.Errorf("get codespace flag: %w", err) } - log := output.NewLogger(os.Stdout, os.Stderr, false) - return forwardPorts(log, codespace, args) + return app.ForwardPorts(cmd.Context(), codespace, args) }, } } -func forwardPorts(log *output.Logger, codespaceName string, ports []string) (err error) { - ctx := context.Background() - apiClient := api.New(GithubToken) - +func (a *App) ForwardPorts(ctx context.Context, codespaceName string, ports []string) (err error) { portPairs, err := getPortPairs(ports) if err != nil { return fmt.Errorf("get port pairs: %w", err) } - user, err := apiClient.GetUser(ctx) + user, err := a.apiClient.GetUser(ctx) if err != nil { return fmt.Errorf("error getting user: %w", err) } - codespace, token, err := getOrChooseCodespace(ctx, apiClient, user, codespaceName) + codespace, token, err := getOrChooseCodespace(ctx, a.apiClient, user, codespaceName) if err != nil { if err == errNoCodespaces { return err @@ -280,7 +268,7 @@ func forwardPorts(log *output.Logger, codespaceName string, ports []string) (err return fmt.Errorf("error getting codespace: %w", err) } - session, err := codespaces.ConnectToLiveshare(ctx, log, apiClient, user.Login, token, codespace) + session, err := codespaces.ConnectToLiveshare(ctx, a.logger, a.apiClient, user.Login, token, codespace) if err != nil { return fmt.Errorf("error connecting to Live Share: %w", err) } @@ -297,7 +285,7 @@ func forwardPorts(log *output.Logger, codespaceName string, ports []string) (err return err } defer listen.Close() - log.Printf("Forwarding ports: remote %d <=> local %d\n", pair.remote, pair.local) + a.logger.Printf("Forwarding ports: remote %d <=> local %d\n", pair.remote, pair.local) name := fmt.Sprintf("share-%d", pair.remote) fwd := liveshare.NewPortForwarder(session, name, pair.remote) return fwd.ForwardToListener(ctx, listen) // error always non-nil diff --git a/cmd/ghcs/root.go b/cmd/ghcs/root.go index 6db4144a8..b71f4a0ff 100644 --- a/cmd/ghcs/root.go +++ b/cmd/ghcs/root.go @@ -1,10 +1,8 @@ package ghcs import ( - "errors" "fmt" "log" - "os" "strconv" "strings" @@ -15,10 +13,7 @@ import ( var version = "DEV" // Replaced in the release build process (by GoReleaser or Homebrew) by the git tag version number. -// GithubToken is a temporary stopgap to make the token configurable by apps that import this package -var GithubToken = os.Getenv("GITHUB_TOKEN") - -func NewRootCmd() *cobra.Command { +func NewRootCmd(app *App) *cobra.Command { var lightstep string root := &cobra.Command{ @@ -32,28 +27,23 @@ token to access the GitHub API with.`, Version: version, PersistentPreRunE: func(cmd *cobra.Command, args []string) error { - if os.Getenv("GITHUB_TOKEN") == "" { - return ErrTokenMissing - } return initLightstep(lightstep) }, } root.PersistentFlags().StringVar(&lightstep, "lightstep", "", "Lightstep tracing endpoint (service:token@host:port)") - root.AddCommand(newCodeCmd()) - root.AddCommand(newCreateCmd()) - root.AddCommand(newDeleteCmd()) - root.AddCommand(newListCmd()) - root.AddCommand(newLogsCmd()) - root.AddCommand(newPortsCmd()) - root.AddCommand(newSSHCmd()) + root.AddCommand(newCodeCmd(app)) + root.AddCommand(newCreateCmd(app)) + root.AddCommand(newDeleteCmd(app)) + root.AddCommand(newListCmd(app)) + root.AddCommand(newLogsCmd(app)) + root.AddCommand(newPortsCmd(app)) + root.AddCommand(newSSHCmd(app)) return root } -var ErrTokenMissing = errors.New("GITHUB_TOKEN is missing") - // initLightstep parses the --lightstep=service:token@host:port flag and // enables tracing if non-empty. func initLightstep(config string) error { diff --git a/cmd/ghcs/ssh.go b/cmd/ghcs/ssh.go index bb771107a..bda7c28bb 100644 --- a/cmd/ghcs/ssh.go +++ b/cmd/ghcs/ssh.go @@ -4,16 +4,13 @@ import ( "context" "fmt" "net" - "os" - "github.com/github/ghcs/cmd/ghcs/output" - "github.com/github/ghcs/internal/api" "github.com/github/ghcs/internal/codespaces" "github.com/github/ghcs/internal/liveshare" "github.com/spf13/cobra" ) -func newSSHCmd() *cobra.Command { +func newSSHCmd(app *App) *cobra.Command { var sshProfile, codespaceName string var sshServerPort int @@ -21,7 +18,7 @@ func newSSHCmd() *cobra.Command { Use: "ssh [flags] [--] [ssh-flags] [command]", Short: "SSH into a codespace", RunE: func(cmd *cobra.Command, args []string) error { - return ssh(context.Background(), args, sshProfile, codespaceName, sshServerPort) + return app.SSH(cmd.Context(), args, sshProfile, codespaceName, sshServerPort) }, } @@ -32,30 +29,28 @@ func newSSHCmd() *cobra.Command { return sshCmd } -func ssh(ctx context.Context, sshArgs []string, sshProfile, codespaceName string, localSSHServerPort int) (err error) { +// SSH opens an ssh session or runs an ssh command in a codespace. +func (a *App) SSH(ctx context.Context, sshArgs []string, sshProfile, codespaceName string, localSSHServerPort int) (err error) { // Ensure all child tasks (e.g. port forwarding) terminate before return. ctx, cancel := context.WithCancel(ctx) defer cancel() - apiClient := api.New(GithubToken) - log := output.NewLogger(os.Stdout, os.Stderr, false) - - user, err := apiClient.GetUser(ctx) + user, err := a.apiClient.GetUser(ctx) if err != nil { return fmt.Errorf("error getting user: %w", err) } authkeys := make(chan error, 1) go func() { - authkeys <- checkAuthorizedKeys(ctx, apiClient, user.Login) + authkeys <- checkAuthorizedKeys(ctx, a.apiClient, user.Login) }() - codespace, token, err := getOrChooseCodespace(ctx, apiClient, user, codespaceName) + codespace, token, err := getOrChooseCodespace(ctx, a.apiClient, user, codespaceName) if err != nil { return fmt.Errorf("get or choose codespace: %w", err) } - session, err := codespaces.ConnectToLiveshare(ctx, log, apiClient, user.Login, token, codespace) + session, err := codespaces.ConnectToLiveshare(ctx, a.logger, a.apiClient, user.Login, token, codespace) if err != nil { return fmt.Errorf("error connecting to Live Share: %w", err) } @@ -65,7 +60,7 @@ func ssh(ctx context.Context, sshArgs []string, sshProfile, codespaceName string return err } - log.Println("Fetching SSH Details...") + a.logger.Println("Fetching SSH Details...") remoteSSHServerPort, sshUser, err := session.StartSSHServer(ctx) if err != nil { return fmt.Errorf("error getting ssh server details: %w", err) @@ -86,7 +81,7 @@ func ssh(ctx context.Context, sshArgs []string, sshProfile, codespaceName string connectDestination = fmt.Sprintf("%s@localhost", sshUser) } - log.Println("Ready...") + a.logger.Println("Ready...") tunnelClosed := make(chan error, 1) go func() { fwd := liveshare.NewPortForwarder(session, "sshd", remoteSSHServerPort) @@ -95,7 +90,7 @@ func ssh(ctx context.Context, sshArgs []string, sshProfile, codespaceName string shellClosed := make(chan error, 1) go func() { - shellClosed <- codespaces.Shell(ctx, log, sshArgs, localSSHServerPort, connectDestination, usingCustomPort) + shellClosed <- codespaces.Shell(ctx, a.logger, sshArgs, localSSHServerPort, connectDestination, usingCustomPort) }() select { diff --git a/internal/api/api.go b/internal/api/api.go index bfccfc6c9..4cce56894 100644 --- a/internal/api/api.go +++ b/internal/api/api.go @@ -45,14 +45,18 @@ const githubAPI = "https://api.github.com" type API struct { token string - client *http.Client + client httpClient githubAPI string } -func New(token string) *API { +type httpClient interface { + Do(req *http.Request) (*http.Response, error) +} + +func New(token string, httpClient httpClient) *API { return &API{ token: token, - client: &http.Client{}, + client: httpClient, githubAPI: githubAPI, } } @@ -272,6 +276,7 @@ func (a *API) GetCodespace(ctx context.Context, token, owner, codespace string) return nil, fmt.Errorf("error creating request: %w", err) } + // TODO: use a.setHeaders() req.Header.Set("Authorization", "Bearer "+token) resp, err := a.do(ctx, req, "/vscs_internal/user/*/codespaces/*") if err != nil { @@ -306,6 +311,7 @@ func (a *API) StartCodespace(ctx context.Context, token string, codespace *Codes return fmt.Errorf("error creating request: %w", err) } + // TODO: use a.setHeaders() req.Header.Set("Authorization", "Bearer "+token) resp, err := a.do(ctx, req, "/vscs_internal/proxy/environments/*/start") if err != nil { @@ -417,14 +423,14 @@ type CreateCodespaceParams struct { Branch, Machine, Location string } -type logger interface { +type Logger interface { Print(v ...interface{}) (int, error) Println(v ...interface{}) (int, error) } // CreateCodespace creates a codespace with the given parameters and returns a non-nil error if it // fails to create. -func (a *API) CreateCodespace(ctx context.Context, log logger, params *CreateCodespaceParams) (*Codespace, error) { +func (a *API) CreateCodespace(ctx context.Context, log Logger, params *CreateCodespaceParams) (*Codespace, error) { codespace, err := a.startCreate( ctx, params.User, params.RepositoryID, params.Machine, params.Branch, params.Location, ) @@ -529,6 +535,7 @@ func (a *API) DeleteCodespace(ctx context.Context, user string, codespaceName st return fmt.Errorf("error creating request: %w", err) } + // TODO: use a.setHeaders() req.Header.Set("Authorization", "Bearer "+token) resp, err := a.do(ctx, req, "/vscs_internal/user/*/codespaces/*") if err != nil { @@ -628,6 +635,8 @@ func (a *API) do(ctx context.Context, req *http.Request, spanName string) (*http } func (a *API) setHeaders(req *http.Request) { - req.Header.Set("Authorization", "Bearer "+a.token) + if a.token != "" { + req.Header.Set("Authorization", "Bearer "+a.token) + } req.Header.Set("Accept", "application/vnd.github.v3+json") } diff --git a/internal/codespaces/codespaces.go b/internal/codespaces/codespaces.go index 1cd605abc..f3cf71b51 100644 --- a/internal/codespaces/codespaces.go +++ b/internal/codespaces/codespaces.go @@ -23,9 +23,15 @@ func connectionReady(codespace *api.Codespace) bool { codespace.Environment.State == api.CodespaceEnvironmentStateAvailable } +type apiClient interface { + GetCodespace(ctx context.Context, token, user, name string) (*api.Codespace, error) + GetCodespaceToken(ctx context.Context, user, codespace string) (string, error) + StartCodespace(ctx context.Context, token string, codespace *api.Codespace) error +} + // ConnectToLiveshare waits for a Codespace to become running, // and connects to it using a Live Share session. -func ConnectToLiveshare(ctx context.Context, log logger, apiClient *api.API, userLogin, token string, codespace *api.Codespace) (*liveshare.Session, error) { +func ConnectToLiveshare(ctx context.Context, log logger, apiClient apiClient, userLogin, token string, codespace *api.Codespace) (*liveshare.Session, error) { var startedCodespace bool if codespace.Environment.State != api.CodespaceEnvironmentStateAvailable { startedCodespace = true diff --git a/internal/codespaces/states.go b/internal/codespaces/states.go index c7d61b41e..0b395d6e3 100644 --- a/internal/codespaces/states.go +++ b/internal/codespaces/states.go @@ -36,7 +36,7 @@ type PostCreateState struct { // PollPostCreateStates watches for state changes in a codespace, // and calls the supplied poller for each batch of state changes. // It runs until it encounters an error, including cancellation of the context. -func PollPostCreateStates(ctx context.Context, log logger, apiClient *api.API, user *api.User, codespace *api.Codespace, poller func([]PostCreateState)) (err error) { +func PollPostCreateStates(ctx context.Context, log logger, apiClient apiClient, user *api.User, codespace *api.Codespace, poller func([]PostCreateState)) (err error) { token, err := apiClient.GetCodespaceToken(ctx, user.Login, codespace.Name) if err != nil { return fmt.Errorf("getting codespace token: %w", err) From dc8f6ef183f6c4d7a0f4135376d54724302abb01 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mislav=20Marohni=C4=87?= Date: Fri, 24 Sep 2021 17:30:31 +0200 Subject: [PATCH 287/290] No longer accept a logger in CreateCodespace The API layer shouldn't concern itself with logging progress to stderr. Instead, we will subsequently add progress indicators in the caller around CreateCodespace and other potentially slow commands as needed. --- cmd/ghcs/common.go | 2 +- cmd/ghcs/create.go | 2 +- cmd/ghcs/mock_api.go | 14 ++++---------- internal/api/api.go | 8 +------- 4 files changed, 7 insertions(+), 19 deletions(-) diff --git a/cmd/ghcs/common.go b/cmd/ghcs/common.go index e60fa7c96..fcdbb9f11 100644 --- a/cmd/ghcs/common.go +++ b/cmd/ghcs/common.go @@ -38,7 +38,7 @@ type apiClient interface { ListCodespaces(ctx context.Context, user string) ([]*api.Codespace, error) DeleteCodespace(ctx context.Context, user, name string) error StartCodespace(ctx context.Context, token string, codespace *api.Codespace) error - CreateCodespace(ctx context.Context, logger api.Logger, params *api.CreateCodespaceParams) (*api.Codespace, error) + CreateCodespace(ctx context.Context, params *api.CreateCodespaceParams) (*api.Codespace, error) GetRepository(ctx context.Context, nwo string) (*api.Repository, error) AuthorizedKeys(ctx context.Context, user string) ([]byte, error) GetCodespaceRegionLocation(ctx context.Context) (string, error) diff --git a/cmd/ghcs/create.go b/cmd/ghcs/create.go index c92a6edff..7e861e08d 100644 --- a/cmd/ghcs/create.go +++ b/cmd/ghcs/create.go @@ -80,7 +80,7 @@ func (a *App) Create(ctx context.Context, opts *createOptions) error { } a.logger.Print("Creating your codespace...") - codespace, err := a.apiClient.CreateCodespace(ctx, a.logger, &api.CreateCodespaceParams{ + codespace, err := a.apiClient.CreateCodespace(ctx, &api.CreateCodespaceParams{ User: userResult.User.Login, RepositoryID: repository.ID, Branch: branch, diff --git a/cmd/ghcs/mock_api.go b/cmd/ghcs/mock_api.go index 93abe7ed6..ef08c0a78 100644 --- a/cmd/ghcs/mock_api.go +++ b/cmd/ghcs/mock_api.go @@ -19,7 +19,7 @@ import ( // AuthorizedKeysFunc: func(ctx context.Context, user string) ([]byte, error) { // panic("mock out the AuthorizedKeys method") // }, -// CreateCodespaceFunc: func(ctx context.Context, logger api.Logger, params *api.CreateCodespaceParams) (*api.Codespace, error) { +// CreateCodespaceFunc: func(ctx context.Context, params *api.CreateCodespaceParams) (*api.Codespace, error) { // panic("mock out the CreateCodespace method") // }, // DeleteCodespaceFunc: func(ctx context.Context, user string, name string) error { @@ -63,7 +63,7 @@ type apiClientMock struct { AuthorizedKeysFunc func(ctx context.Context, user string) ([]byte, error) // CreateCodespaceFunc mocks the CreateCodespace method. - CreateCodespaceFunc func(ctx context.Context, logger api.Logger, params *api.CreateCodespaceParams) (*api.Codespace, error) + CreateCodespaceFunc func(ctx context.Context, params *api.CreateCodespaceParams) (*api.Codespace, error) // DeleteCodespaceFunc mocks the DeleteCodespace method. DeleteCodespaceFunc func(ctx context.Context, user string, name string) error @@ -108,8 +108,6 @@ type apiClientMock struct { CreateCodespace []struct { // Ctx is the ctx argument value. Ctx context.Context - // Logger is the logger argument value. - Logger api.Logger // Params is the params argument value. Params *api.CreateCodespaceParams } @@ -248,23 +246,21 @@ func (mock *apiClientMock) AuthorizedKeysCalls() []struct { } // CreateCodespace calls CreateCodespaceFunc. -func (mock *apiClientMock) CreateCodespace(ctx context.Context, logger api.Logger, params *api.CreateCodespaceParams) (*api.Codespace, error) { +func (mock *apiClientMock) CreateCodespace(ctx context.Context, params *api.CreateCodespaceParams) (*api.Codespace, error) { if mock.CreateCodespaceFunc == nil { panic("apiClientMock.CreateCodespaceFunc: method is nil but apiClient.CreateCodespace was just called") } callInfo := struct { Ctx context.Context - Logger api.Logger Params *api.CreateCodespaceParams }{ Ctx: ctx, - Logger: logger, Params: params, } mock.lockCreateCodespace.Lock() mock.calls.CreateCodespace = append(mock.calls.CreateCodespace, callInfo) mock.lockCreateCodespace.Unlock() - return mock.CreateCodespaceFunc(ctx, logger, params) + return mock.CreateCodespaceFunc(ctx, params) } // CreateCodespaceCalls gets all the calls that were made to CreateCodespace. @@ -272,12 +268,10 @@ func (mock *apiClientMock) CreateCodespace(ctx context.Context, logger api.Logge // len(mockedapiClient.CreateCodespaceCalls()) func (mock *apiClientMock) CreateCodespaceCalls() []struct { Ctx context.Context - Logger api.Logger Params *api.CreateCodespaceParams } { var calls []struct { Ctx context.Context - Logger api.Logger Params *api.CreateCodespaceParams } mock.lockCreateCodespace.RLock() diff --git a/internal/api/api.go b/internal/api/api.go index 4cce56894..efc24bcfb 100644 --- a/internal/api/api.go +++ b/internal/api/api.go @@ -423,14 +423,9 @@ type CreateCodespaceParams struct { Branch, Machine, Location string } -type Logger interface { - Print(v ...interface{}) (int, error) - Println(v ...interface{}) (int, error) -} - // CreateCodespace creates a codespace with the given parameters and returns a non-nil error if it // fails to create. -func (a *API) CreateCodespace(ctx context.Context, log Logger, params *CreateCodespaceParams) (*Codespace, error) { +func (a *API) CreateCodespace(ctx context.Context, params *CreateCodespaceParams) (*Codespace, error) { codespace, err := a.startCreate( ctx, params.User, params.RepositoryID, params.Machine, params.Branch, params.Location, ) @@ -452,7 +447,6 @@ func (a *API) CreateCodespace(ctx context.Context, log Logger, params *CreateCod case <-ctx.Done(): return nil, ctx.Err() case <-ticker.C: - log.Print(".") token, err := a.GetCodespaceToken(ctx, params.User, codespace.Name) if err != nil { if err == ErrNotProvisioned { From c82d4c54724d9d879350052d7f0c993d92ec13c7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mislav=20Marohni=C4=87?= Date: Fri, 24 Sep 2021 17:36:18 +0200 Subject: [PATCH 288/290] Avoid passing params struct as pointer --- cmd/ghcs/create.go | 4 ++-- cmd/ghcs/list.go | 14 +++++--------- 2 files changed, 7 insertions(+), 11 deletions(-) diff --git a/cmd/ghcs/create.go b/cmd/ghcs/create.go index 7e861e08d..7174e7721 100644 --- a/cmd/ghcs/create.go +++ b/cmd/ghcs/create.go @@ -23,7 +23,7 @@ type createOptions struct { } func newCreateCmd(app *App) *cobra.Command { - opts := &createOptions{} + opts := createOptions{} createCmd := &cobra.Command{ Use: "create", @@ -43,7 +43,7 @@ func newCreateCmd(app *App) *cobra.Command { } // Create creates a new Codespace -func (a *App) Create(ctx context.Context, opts *createOptions) error { +func (a *App) Create(ctx context.Context, opts createOptions) error { locationCh := getLocation(ctx, a.apiClient) userCh := getUser(ctx, a.apiClient) diff --git a/cmd/ghcs/list.go b/cmd/ghcs/list.go index 842b9313d..1fc59cff0 100644 --- a/cmd/ghcs/list.go +++ b/cmd/ghcs/list.go @@ -10,28 +10,24 @@ import ( "github.com/spf13/cobra" ) -type listOptions struct { - asJSON bool -} - func newListCmd(app *App) *cobra.Command { - opts := &listOptions{} + var asJSON bool listCmd := &cobra.Command{ Use: "list", Short: "List your codespaces", Args: noArgsConstraint, RunE: func(cmd *cobra.Command, args []string) error { - return app.List(cmd.Context(), opts) + return app.List(cmd.Context(), asJSON) }, } - listCmd.Flags().BoolVar(&opts.asJSON, "json", false, "Output as JSON") + listCmd.Flags().BoolVar(&asJSON, "json", false, "Output as JSON") return listCmd } -func (a *App) List(ctx context.Context, opts *listOptions) error { +func (a *App) List(ctx context.Context, asJSON bool) error { user, err := a.apiClient.GetUser(ctx) if err != nil { return fmt.Errorf("error getting user: %w", err) @@ -42,7 +38,7 @@ func (a *App) List(ctx context.Context, opts *listOptions) error { return fmt.Errorf("error getting codespaces: %w", err) } - table := output.NewTable(os.Stdout, opts.asJSON) + table := output.NewTable(os.Stdout, asJSON) table.SetHeader([]string{"Name", "Repository", "Branch", "State", "Created At"}) for _, codespace := range codespaces { table.Append([]string{ From 57d9b1a9e1acae43a14e4e577be5c1f41a8472a0 Mon Sep 17 00:00:00 2001 From: Alan Donovan Date: Mon, 27 Sep 2021 14:51:52 -0400 Subject: [PATCH 289/290] create: decode JSON error heuristically --- internal/api/api.go | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/internal/api/api.go b/internal/api/api.go index bfccfc6c9..84f2c4a5a 100644 --- a/internal/api/api.go +++ b/internal/api/api.go @@ -319,7 +319,10 @@ func (a *API) StartCodespace(ctx context.Context, token string, codespace *Codes } if resp.StatusCode != http.StatusOK { - // Error response is typically a numeric code (not an error message, nor JSON). + // Error response may be a numeric code or a JSON {"message": "..."}. + if bytes.HasPrefix(b, []byte("{")) { + return jsonErrorResponse(b) // probably JSON + } if len(b) > 100 { b = append(b[:97], "..."...) } From f947ef3448e47cd0b1fa38e8f3e8a43d42bbfb52 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mislav=20Marohni=C4=87?= Date: Tue, 28 Sep 2021 16:42:35 +0200 Subject: [PATCH 290/290] Remove lightstep configuration The `github.com/shirou/gopsutil` dependency of lightstep-tracer is giving us trouble during building. Ref. https://github.com/shirou/gopsutil/issues/976 Another build problem raises its head even after we upgrade gopsutil to a version where the above bug is fixed. --- cmd/ghcs/root.go | 63 ------------------------------------------------ 1 file changed, 63 deletions(-) diff --git a/cmd/ghcs/root.go b/cmd/ghcs/root.go index b71f4a0ff..c9fdd2876 100644 --- a/cmd/ghcs/root.go +++ b/cmd/ghcs/root.go @@ -1,21 +1,12 @@ package ghcs import ( - "fmt" - "log" - "strconv" - "strings" - - "github.com/lightstep/lightstep-tracer-go" - "github.com/opentracing/opentracing-go" "github.com/spf13/cobra" ) var version = "DEV" // Replaced in the release build process (by GoReleaser or Homebrew) by the git tag version number. func NewRootCmd(app *App) *cobra.Command { - var lightstep string - root := &cobra.Command{ Use: "ghcs", SilenceUsage: true, // don't print usage message after each error (see #80) @@ -25,14 +16,8 @@ func NewRootCmd(app *App) *cobra.Command { Running commands requires the GITHUB_TOKEN environment variable to be set to a token to access the GitHub API with.`, Version: version, - - PersistentPreRunE: func(cmd *cobra.Command, args []string) error { - return initLightstep(lightstep) - }, } - root.PersistentFlags().StringVar(&lightstep, "lightstep", "", "Lightstep tracing endpoint (service:token@host:port)") - root.AddCommand(newCodeCmd(app)) root.AddCommand(newCreateCmd(app)) root.AddCommand(newDeleteCmd(app)) @@ -43,51 +28,3 @@ token to access the GitHub API with.`, return root } - -// initLightstep parses the --lightstep=service:token@host:port flag and -// enables tracing if non-empty. -func initLightstep(config string) error { - if config == "" { - return nil - } - - cut := func(s, sep string) (pre, post string) { - if i := strings.Index(s, sep); i >= 0 { - return s[:i], s[i+len(sep):] - } - return s, "" - } - - // Parse service:token@host:port. - serviceToken, hostPort := cut(config, "@") - service, token := cut(serviceToken, ":") - host, port := cut(hostPort, ":") - portI, err := strconv.Atoi(port) - if err != nil { - return fmt.Errorf("invalid Lightstep configuration: %s", config) - } - - opentracing.SetGlobalTracer(lightstep.NewTracer(lightstep.Options{ - AccessToken: token, - Collector: lightstep.Endpoint{ - Host: host, - Port: portI, - Plaintext: false, - }, - Tags: opentracing.Tags{ - lightstep.ComponentNameKey: service, - }, - })) - - // Report failure to record traces. - lightstep.SetGlobalEventHandler(func(ev lightstep.Event) { - switch ev := ev.(type) { - case lightstep.EventStatusReport, lightstep.MetricEventStatusReport: - // ignore - default: - log.Printf("[trace] %s", ev) - } - }) - - return nil -}