diff --git a/api/http_client.go b/api/http_client.go index 83f228409..3fb83974e 100644 --- a/api/http_client.go +++ b/api/http_client.go @@ -14,7 +14,7 @@ import ( ) type tokenGetter interface { - AuthToken(string) (string, string) + Token(string) (string, string) } type HTTPClientOptions struct { @@ -93,7 +93,7 @@ func AddAuthTokenHeader(rt http.RoundTripper, cfg tokenGetter) http.RoundTripper // If the header is already set in the request, don't overwrite it. if req.Header.Get(authorization) == "" { hostname := ghinstance.NormalizeHostname(getHost(req)) - if token, _ := cfg.AuthToken(hostname); token != "" { + if token, _ := cfg.Token(hostname); token != "" { req.Header.Set(authorization, fmt.Sprintf("token %s", token)) } } diff --git a/api/http_client_test.go b/api/http_client_test.go index d1dce0da6..ee59dbfc1 100644 --- a/api/http_client_test.go +++ b/api/http_client_test.go @@ -202,7 +202,7 @@ func TestNewHTTPClient(t *testing.T) { type tinyConfig map[string]string -func (c tinyConfig) AuthToken(host string) (string, string) { +func (c tinyConfig) Token(host string) (string, string) { return c[fmt.Sprintf("%s:%s", host, "oauth_token")], "oauth_token" } diff --git a/go.mod b/go.mod index fcba58a19..4c6e68bf6 100644 --- a/go.mod +++ b/go.mod @@ -6,10 +6,10 @@ require ( github.com/AlecAivazis/survey/v2 v2.3.6 github.com/MakeNowJust/heredoc v1.0.0 github.com/briandowns/spinner v1.18.1 - github.com/cenkalti/backoff/v4 v4.1.3 + github.com/cenkalti/backoff/v4 v4.2.0 github.com/charmbracelet/glamour v0.5.1-0.20220727184942-e70ff2d969da github.com/charmbracelet/lipgloss v0.5.0 - github.com/cli/go-gh v1.0.0 + github.com/cli/go-gh v1.2.1 github.com/cli/oauth v1.0.1 github.com/cli/safeexec v1.0.1 github.com/cpuguy83/go-md2man/v2 v2.0.2 @@ -35,10 +35,11 @@ require ( github.com/spf13/cobra v1.6.1 github.com/spf13/pflag v1.0.5 github.com/stretchr/testify v1.7.5 + github.com/zalando/go-keyring v0.2.2 golang.org/x/crypto v0.0.0-20210921155107-089bfa567519 golang.org/x/sync v0.1.0 - golang.org/x/term v0.0.0-20210927222741-03fcf44c2211 - golang.org/x/text v0.5.0 + golang.org/x/term v0.5.0 + golang.org/x/text v0.7.0 google.golang.org/grpc v1.49.0 google.golang.org/protobuf v1.27.1 gopkg.in/yaml.v3 v3.0.1 @@ -46,13 +47,16 @@ require ( require ( github.com/alecthomas/chroma v0.10.0 // indirect + github.com/alessio/shellescape v1.4.1 // indirect github.com/aymerick/douceur v0.2.0 // indirect github.com/cli/browser v1.1.0 // indirect github.com/cli/shurcooL-graphql v0.0.2 // indirect + github.com/danieljoos/wincred v1.1.2 // indirect github.com/davecgh/go-spew v1.1.1 // indirect github.com/dlclark/regexp2 v1.4.0 // indirect github.com/fatih/color v1.7.0 // indirect github.com/gdamore/encoding v1.0.0 // indirect + github.com/godbus/dbus/v5 v5.1.0 // indirect github.com/golang/protobuf v1.5.2 // indirect github.com/gorilla/css v1.0.0 // indirect github.com/hashicorp/errwrap v1.0.0 // indirect @@ -73,9 +77,9 @@ require ( github.com/thlib/go-timezone-local v0.0.0-20210907160436-ef149e42d28e // indirect github.com/yuin/goldmark v1.4.13 // indirect github.com/yuin/goldmark-emoji v1.0.1 // indirect - golang.org/x/net v0.0.0-20220923203811-8be639271d50 // indirect + golang.org/x/net v0.7.0 // indirect golang.org/x/oauth2 v0.0.0-20220309155454-6242fa91716a // indirect - golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab // indirect + golang.org/x/sys v0.5.0 // indirect google.golang.org/genproto v0.0.0-20200825200019-8632dd797987 // indirect gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 // indirect ) diff --git a/go.sum b/go.sum index 03be7b86e..05979d813 100644 --- a/go.sum +++ b/go.sum @@ -41,12 +41,14 @@ github.com/Netflix/go-expect v0.0.0-20220104043353-73e0943537d2 h1:+vx7roKuyA63n github.com/Netflix/go-expect v0.0.0-20220104043353-73e0943537d2/go.mod h1:HBCaDeC1lPdgDeDbhX8XFpy1jqjK0IBG8W5K+xYqA0w= github.com/alecthomas/chroma v0.10.0 h1:7XDcGkCQopCNKjZHfYrNLraA+M7e0fMiJ/Mfikbfjek= github.com/alecthomas/chroma v0.10.0/go.mod h1:jtJATyUxlIORhUOFNA9NZDWGAQ8wpxQQqNSB4rjA/1s= +github.com/alessio/shellescape v1.4.1 h1:V7yhSDDn8LP4lc4jS8pFkt0zCnzVJlG5JXy9BVKJUX0= +github.com/alessio/shellescape v1.4.1/go.mod h1:PZAiSCk0LJaZkiCSkPv8qIobYglO3FPpyFjDCtHLS30= github.com/aymerick/douceur v0.2.0 h1:Mv+mAeH1Q+n9Fr+oyamOlAkUNPWPlA8PPGR0QAaYuPk= github.com/aymerick/douceur v0.2.0/go.mod h1:wlT5vV2O3h55X9m7iVYN0TBM0NH/MmbLnd30/FjWUq4= github.com/briandowns/spinner v1.18.1 h1:yhQmQtM1zsqFsouh09Bk/jCjd50pC3EOGsh28gLVvwY= github.com/briandowns/spinner v1.18.1/go.mod h1:mQak9GHqbspjC/5iUx3qMlIho8xBS/ppAL/hX5SmPJU= -github.com/cenkalti/backoff/v4 v4.1.3 h1:cFAlzYUlVYDysBEH2T5hyJZMh3+5+WCBvSnK6Q8UtC4= -github.com/cenkalti/backoff/v4 v4.1.3/go.mod h1:scbssz8iZGpm3xbr14ovlUdkxfGXNInqkPWOWmG2CLw= +github.com/cenkalti/backoff/v4 v4.2.0 h1:HN5dHm3WBOgndBH6E8V0q2jIYIR3s9yglV8k/+MN3u4= +github.com/cenkalti/backoff/v4 v4.2.0/go.mod h1:Y3VNntkOUPxTVeUxJ/G5vcM//AlwfmyYozVcomhLiZE= github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU= github.com/charmbracelet/glamour v0.5.1-0.20220727184942-e70ff2d969da h1:FGz53GWQRiKQ/5xUsoCCkewSQIC7u81Scaxx2nUy3nM= github.com/charmbracelet/glamour v0.5.1-0.20220727184942-e70ff2d969da/go.mod h1:HXz79SMFnF9arKxqeoHWxmo1BhplAH7wehlRhKQIL94= @@ -60,8 +62,8 @@ github.com/cli/browser v1.1.0 h1:xOZBfkfY9L9vMBgqb1YwRirGu6QFaQ5dP/vXt5ENSOY= github.com/cli/browser v1.1.0/go.mod h1:HKMQAt9t12kov91Mn7RfZxyJQQgWgyS/3SZswlZ5iTI= github.com/cli/crypto v0.0.0-20210929142629-6be313f59b03 h1:3f4uHLfWx4/WlnMPXGai03eoWAI+oGHJwr+5OXfxCr8= github.com/cli/crypto v0.0.0-20210929142629-6be313f59b03/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= -github.com/cli/go-gh v1.0.0 h1:zE1YUAUYqGXNZuICEBeOkIMJ5F50BS0ftvtoWGlsEFI= -github.com/cli/go-gh v1.0.0/go.mod h1:bqxLdCoTZ73BuiPEJx4olcO/XKhVZaFDchFagYRBweE= +github.com/cli/go-gh v1.2.1 h1:xFrjejSsgPiwXFP6VYynKWwxLQcNJy3Twbu82ZDlR/o= +github.com/cli/go-gh v1.2.1/go.mod h1:Jxk8X+TCO4Ui/GarwY9tByWm/8zp4jJktzVZNlTW5VM= github.com/cli/oauth v1.0.1 h1:pXnTFl/qUegXHK531Dv0LNjW4mLx626eS42gnzfXJPA= github.com/cli/oauth v1.0.1/go.mod h1:qd/FX8ZBD6n1sVNQO3aIdRxeu5LGw9WhKnYhIIoC2A4= github.com/cli/safeexec v1.0.0/go.mod h1:Z/D4tTN8Vs5gXYHDCbaM1S/anmEDnJb1iW0+EJ5zx3Q= @@ -76,6 +78,8 @@ github.com/cpuguy83/go-md2man/v2 v2.0.2/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46t github.com/creack/pty v1.1.17/go.mod h1:MOBLtS5ELjhRRrroQr9kyvTxUAFNvYEK993ew/Vr4O4= github.com/creack/pty v1.1.18 h1:n56/Zwd5o6whRC5PMGretI4IdRLlmBXYNjScPaBgsbY= github.com/creack/pty v1.1.18/go.mod h1:MOBLtS5ELjhRRrroQr9kyvTxUAFNvYEK993ew/Vr4O4= +github.com/danieljoos/wincred v1.1.2 h1:QLdCxFs1/Yl4zduvBdcHB8goaYk9RARS2SgLLRuAyr0= +github.com/danieljoos/wincred v1.1.2/go.mod h1:GijpziifJoIBfYh+S7BbkdUTU4LfM+QnGqR5Vl2tAx0= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= @@ -96,6 +100,8 @@ github.com/gdamore/tcell/v2 v2.5.4/go.mod h1:dZgRy5v4iMobMEcWNYBtREnDZAT9DYmfqIk github.com/go-gl/glfw v0.0.0-20190409004039-e6da0acd62b1/go.mod h1:vR7hzQXu2zJy9AVAgeJqvqgH9Q5CA+iKCZ2gyEVpxRU= github.com/go-gl/glfw/v3.3/glfw v0.0.0-20191125211704-12ad95a8df72/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8= github.com/go-gl/glfw/v3.3/glfw v0.0.0-20200222043503-6f7a984d4dc4/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8= +github.com/godbus/dbus/v5 v5.1.0 h1:4KLkAxT3aOY8Li4FRJe/KvhoNFFxo0m6fNuFUO8QJUk= +github.com/godbus/dbus/v5 v5.1.0/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= github.com/golang/groupcache v0.0.0-20190702054246-869f871628b6/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= github.com/golang/groupcache v0.0.0-20191227052852-215e87163ea7/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= @@ -265,6 +271,8 @@ github.com/yuin/goldmark v1.4.13 h1:fVcFKWvrslecOb/tg+Cc05dkeYx540o0FuFt3nUVDoE= github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= github.com/yuin/goldmark-emoji v1.0.1 h1:ctuWEyzGBwiucEqxzwe0SOYDXPAucOrE9NQC18Wa1os= github.com/yuin/goldmark-emoji v1.0.1/go.mod h1:2w1E6FEWLcDQkoTE+7HU6QF1F6SLlNGjRIBbIZQFqkQ= +github.com/zalando/go-keyring v0.2.2 h1:f0xmpYiSrHtSNAVgwip93Cg8tuF45HJM6rHq/A5RI/4= +github.com/zalando/go-keyring v0.2.2/go.mod h1:sI3evg9Wvpw3+n4SqplGSJUMwtDeROfD4nsFz4z9PG0= go.opencensus.io v0.21.0/go.mod h1:mSImk1erAIZhrmZN+AvHh14ztQfjbGwt4TtuofqLduU= go.opencensus.io v0.22.0/go.mod h1:+kGneAE2xo2IficOXnaByMWTGM9T73dGwxeWcUqIpI8= go.opencensus.io v0.22.2/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw= @@ -333,8 +341,9 @@ golang.org/x/net v0.0.0-20211015210444-4f30a5c0130f/go.mod h1:9nx3DQGgdP8bBQD5qx golang.org/x/net v0.0.0-20220127200216-cd36cc0744dd/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk= golang.org/x/net v0.0.0-20220624214902-1bab6f366d9e/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= -golang.org/x/net v0.0.0-20220923203811-8be639271d50 h1:vKyz8L3zkd+xrMeIaBsQ/MNVPVFSffdaU3ZyYlBGFnI= golang.org/x/net v0.0.0-20220923203811-8be639271d50/go.mod h1:YDH+HFinaLZZlnHAfSS6ZXJJ9M9t4Dl22yv3iI2vPwk= +golang.org/x/net v0.7.0 h1:rJrUqqhjsgNp7KqAIc25s9pZnjU7TUcSY7HcVZjdn1g= +golang.org/x/net v0.7.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= @@ -385,6 +394,7 @@ golang.org/x/sys v0.0.0-20210319071255-635bc2c9138d/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20210819135213-f52c844e1c1c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210831042530-f4d43177bf5e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20211019181941-9d821ace8654/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20211216021012-1d35b9e2eb4e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= @@ -393,12 +403,14 @@ golang.org/x/sys v0.0.0-20220422013727-9388b58f7150/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220728004956-3c1f35247d10/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab h1:2QkjZIsXupsJbJIdSjjUOgWK3aEtzyuh2mPt3l/CkeU= golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.5.0 h1:MUK/U/4lj1t1oPg0HfuXDN/Z1wv31ZJ/YcPiGccS4DU= +golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210503060354-a79de5458b56/go.mod h1:tfny5GFUkzUvx4ps4ajbZsCe5lw1metzhBm9T3x7oIY= -golang.org/x/term v0.0.0-20210927222741-03fcf44c2211 h1:JGgROgKl9N8DuW20oFS5gxc+lE67/N3FcwmBPMe7ArY= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= +golang.org/x/term v0.5.0 h1:n2a8QNdAb0sZNpU9R1ALUXBbY+w51fCQDN+7EdxNBsY= +golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k= golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= @@ -406,8 +418,9 @@ golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= -golang.org/x/text v0.5.0 h1:OLmvp0KP+FVG99Ct/qFiL/Fhk4zp4QQnZ7b2U+5piUM= golang.org/x/text v0.5.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= +golang.org/x/text v0.7.0 h1:4BRB4x83lYWy72KwLD/qYDuTu7q9PjSagHvijDw7cLo= +golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= diff --git a/internal/authflow/flow.go b/internal/authflow/flow.go index 53b65af9e..ddac948cd 100644 --- a/internal/authflow/flow.go +++ b/internal/authflow/flow.go @@ -6,7 +6,6 @@ import ( "io" "net/http" "net/url" - "os" "regexp" "strings" @@ -28,37 +27,7 @@ var ( jsonTypeRE = regexp.MustCompile(`[/+]json($|;)`) ) -type iconfig interface { - Get(string, string) (string, error) - Set(string, string, string) - Write() error -} - -func AuthFlowWithConfig(cfg iconfig, IO *iostreams.IOStreams, hostname, notice string, additionalScopes []string, isInteractive bool) (string, error) { - // TODO this probably shouldn't live in this package. It should probably be in a new package that - // depends on both iostreams and config. - - // FIXME: this duplicates `factory.browserLauncher()` - browserLauncher := os.Getenv("GH_BROWSER") - if browserLauncher == "" { - browserLauncher, _ = cfg.Get("", "browser") - } - if browserLauncher == "" { - browserLauncher = os.Getenv("BROWSER") - } - - token, userLogin, err := authFlow(hostname, IO, notice, additionalScopes, isInteractive, browserLauncher) - if err != nil { - return "", err - } - - cfg.Set(hostname, "user", userLogin) - cfg.Set(hostname, "oauth_token", token) - - return token, cfg.Write() -} - -func authFlow(oauthHost string, IO *iostreams.IOStreams, notice string, additionalScopes []string, isInteractive bool, browserLauncher string) (string, string, error) { +func AuthFlow(oauthHost string, IO *iostreams.IOStreams, notice string, additionalScopes []string, isInteractive bool, b browser.Browser) (string, string, error) { w := IO.ErrOut cs := IO.ColorScheme() @@ -106,7 +75,6 @@ func authFlow(oauthHost string, IO *iostreams.IOStreams, notice string, addition fmt.Fprintf(w, "%s to open %s in your browser... ", cs.Bold("Press Enter"), oauthHost) _ = waitForEnter(IO.In) - b := browser.New(browserLauncher, IO.Out, IO.ErrOut) if err := b.Browse(authURL); err != nil { fmt.Fprintf(w, "%s Failed opening a web browser at %s\n", cs.Red("!"), authURL) fmt.Fprintf(w, " %s\n", err) @@ -138,16 +106,16 @@ func authFlow(oauthHost string, IO *iostreams.IOStreams, notice string, addition } type cfg struct { - authToken string + token string } -func (c cfg) AuthToken(hostname string) (string, string) { - return c.authToken, "oauth_token" +func (c cfg) Token(hostname string) (string, string) { + return c.token, "oauth_token" } func getViewer(hostname, token string, logWriter io.Writer) (string, error) { opts := api.HTTPClientOptions{ - Config: cfg{authToken: token}, + Config: cfg{token: token}, Log: logWriter, } client, err := api.NewHTTPClient(opts) diff --git a/internal/config/config.go b/internal/config/config.go index 8fbd9e758..ce5796c9e 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -6,26 +6,26 @@ import ( ghAuth "github.com/cli/go-gh/pkg/auth" ghConfig "github.com/cli/go-gh/pkg/config" + "github.com/zalando/go-keyring" ) const ( - hosts = "hosts" - aliases = "aliases" + aliases = "aliases" + hosts = "hosts" + oauthToken = "oauth_token" ) // This interface describes interacting with some persistent configuration for gh. // //go:generate moq -rm -out config_mock.go . Config type Config interface { - AuthToken(string) (string, string) Get(string, string) (string, error) GetOrDefault(string, string) (string, error) Set(string, string, string) - UnsetHost(string) - Hosts() []string - DefaultHost() (string, string) - Aliases() *AliasConfig Write() error + + Aliases() *AliasConfig + Authentication() *AuthConfig } func NewConfig() (Config, error) { @@ -41,10 +41,6 @@ type cfg struct { cfg *ghConfig.Config } -func (c *cfg) AuthToken(hostname string) (string, string) { - return ghAuth.TokenForHost(hostname) -} - func (c *cfg) Get(hostname, key string) (string, error) { if hostname != "" { val, err := c.cfg.Get([]string{hosts, hostname, key}) @@ -86,27 +82,16 @@ func (c *cfg) Set(hostname, key, value string) { c.cfg.Set([]string{hosts, hostname, key}, value) } -func (c *cfg) UnsetHost(hostname string) { - if hostname == "" { - return - } - _ = c.cfg.Remove([]string{hosts, hostname}) -} - -func (c *cfg) Hosts() []string { - return ghAuth.KnownHosts() -} - -func (c *cfg) DefaultHost() (string, string) { - return ghAuth.DefaultHost() +func (c *cfg) Write() error { + return ghConfig.Write(c.cfg) } func (c *cfg) Aliases() *AliasConfig { return &AliasConfig{cfg: c.cfg} } -func (c *cfg) Write() error { - return ghConfig.Write(c.cfg) +func (c *cfg) Authentication() *AuthConfig { + return &AuthConfig{cfg: c.cfg} } func defaultFor(key string) string { @@ -127,6 +112,132 @@ func defaultExists(key string) bool { return false } +// AuthConfig is used for interacting with some persistent configuration for gh, +// with knowledge on how to access encrypted storage when neccesarry. +// Behavior is scoped to authentication specific tasks. +type AuthConfig struct { + cfg *ghConfig.Config + defaultHostOverride func() (string, string) + hostsOverride func() []string + tokenOverride func(string) (string, string) +} + +// Token will retrieve the auth token for the given hostname, +// searching environment variables, plain text config, and +// lastly encypted storage. +func (c *AuthConfig) Token(hostname string) (string, string) { + if c.tokenOverride != nil { + return c.tokenOverride(hostname) + } + token, source := ghAuth.TokenFromEnvOrConfig(hostname) + if token == "" { + var err error + token, err = c.TokenFromKeyring(hostname) + if err == nil { + source = "keyring" + } + } + return token, source +} + +// SetToken will override any token resolution and return the given +// token and source for all calls to Token. Use for testing purposes only. +func (c *AuthConfig) SetToken(token, source string) { + c.tokenOverride = func(_ string) (string, string) { + return token, source + } +} + +// TokenFromKeyring will retrieve the auth token for the given hostname, +// only searching in encrypted storage. +func (c *AuthConfig) TokenFromKeyring(hostname string) (string, error) { + return keyring.Get(keyringServiceName(hostname), "") +} + +// User will retrieve the username for the logged in user at the given hostname. +func (c *AuthConfig) User(hostname string) (string, error) { + return c.cfg.Get([]string{hosts, hostname, "user"}) +} + +// GitProtocol will retrieve the git protocol for the logged in user at the given hostname. +// If none is set it will return the default value. +func (c *AuthConfig) GitProtocol(hostname string) (string, error) { + key := "git_protocol" + val, err := c.cfg.Get([]string{hosts, hostname, key}) + if err == nil { + return val, err + } + return defaultFor(key), nil +} + +func (c *AuthConfig) Hosts() []string { + if c.hostsOverride != nil { + return c.hostsOverride() + } + return ghAuth.KnownHosts() +} + +// SetHosts will override any hosts resolution and return the given +// hosts for all calls to Hosts. Use for testing purposes only. +func (c *AuthConfig) SetHosts(hosts []string) { + c.hostsOverride = func() []string { + return hosts + } +} + +func (c *AuthConfig) DefaultHost() (string, string) { + if c.defaultHostOverride != nil { + return c.defaultHostOverride() + } + return ghAuth.DefaultHost() +} + +// SetDefaultHost will override any host resolution and return the given +// host and source for all calls to DefaultHost. Use for testing purposes only. +func (c *AuthConfig) SetDefaultHost(host, source string) { + c.defaultHostOverride = func() (string, string) { + return host, source + } +} + +// Login will set user, git protocol, and auth token for the given hostname. +// If the encrypt option is specified it will first try to store the auth token +// in encrypted storage and will fall back to the plain text config file. +func (c *AuthConfig) Login(hostname, username, token, gitProtocol string, secureStorage bool) error { + var setErr error + if secureStorage { + if setErr = keyring.Set(keyringServiceName(hostname), "", token); setErr == nil { + // Clean up the previous oauth_token from the config file. + _ = c.cfg.Remove([]string{hosts, hostname, oauthToken}) + } + } + if !secureStorage || setErr != nil { + c.cfg.Set([]string{hosts, hostname, oauthToken}, token) + } + if username != "" { + c.cfg.Set([]string{hosts, hostname, "user"}, username) + } + if gitProtocol != "" { + c.cfg.Set([]string{hosts, hostname, "git_protocol"}, gitProtocol) + } + return ghConfig.Write(c.cfg) +} + +// Logout will remove user, git protocol, and auth token for the given hostname. +// It will remove the auth token from the encrypted storage if it exists there. +func (c *AuthConfig) Logout(hostname string) error { + if hostname == "" { + return nil + } + _ = c.cfg.Remove([]string{hosts, hostname}) + _ = keyring.Delete(keyringServiceName(hostname), "") + return ghConfig.Write(c.cfg) +} + +func keyringServiceName(hostname string) string { + return "gh:" + hostname +} + type AliasConfig struct { cfg *ghConfig.Config } diff --git a/internal/config/config_mock.go b/internal/config/config_mock.go index 86f45c1c2..6fd76dc55 100644 --- a/internal/config/config_mock.go +++ b/internal/config/config_mock.go @@ -20,11 +20,8 @@ var _ Config = &ConfigMock{} // AliasesFunc: func() *AliasConfig { // panic("mock out the Aliases method") // }, -// AuthTokenFunc: func(s string) (string, string) { -// panic("mock out the AuthToken method") -// }, -// DefaultHostFunc: func() (string, string) { -// panic("mock out the DefaultHost method") +// AuthenticationFunc: func() *AuthConfig { +// panic("mock out the Authentication method") // }, // GetFunc: func(s1 string, s2 string) (string, error) { // panic("mock out the Get method") @@ -32,15 +29,9 @@ var _ Config = &ConfigMock{} // GetOrDefaultFunc: func(s1 string, s2 string) (string, error) { // panic("mock out the GetOrDefault method") // }, -// HostsFunc: func() []string { -// panic("mock out the Hosts method") -// }, // SetFunc: func(s1 string, s2 string, s3 string) { // panic("mock out the Set method") // }, -// UnsetHostFunc: func(s string) { -// panic("mock out the UnsetHost method") -// }, // WriteFunc: func() error { // panic("mock out the Write method") // }, @@ -54,11 +45,8 @@ type ConfigMock struct { // AliasesFunc mocks the Aliases method. AliasesFunc func() *AliasConfig - // AuthTokenFunc mocks the AuthToken method. - AuthTokenFunc func(s string) (string, string) - - // DefaultHostFunc mocks the DefaultHost method. - DefaultHostFunc func() (string, string) + // AuthenticationFunc mocks the Authentication method. + AuthenticationFunc func() *AuthConfig // GetFunc mocks the Get method. GetFunc func(s1 string, s2 string) (string, error) @@ -66,15 +54,9 @@ type ConfigMock struct { // GetOrDefaultFunc mocks the GetOrDefault method. GetOrDefaultFunc func(s1 string, s2 string) (string, error) - // HostsFunc mocks the Hosts method. - HostsFunc func() []string - // SetFunc mocks the Set method. SetFunc func(s1 string, s2 string, s3 string) - // UnsetHostFunc mocks the UnsetHost method. - UnsetHostFunc func(s string) - // WriteFunc mocks the Write method. WriteFunc func() error @@ -83,13 +65,8 @@ type ConfigMock struct { // Aliases holds details about calls to the Aliases method. Aliases []struct { } - // AuthToken holds details about calls to the AuthToken method. - AuthToken []struct { - // S is the s argument value. - S string - } - // DefaultHost holds details about calls to the DefaultHost method. - DefaultHost []struct { + // Authentication holds details about calls to the Authentication method. + Authentication []struct { } // Get holds details about calls to the Get method. Get []struct { @@ -105,9 +82,6 @@ type ConfigMock struct { // S2 is the s2 argument value. S2 string } - // Hosts holds details about calls to the Hosts method. - Hosts []struct { - } // Set holds details about calls to the Set method. Set []struct { // S1 is the s1 argument value. @@ -117,24 +91,16 @@ type ConfigMock struct { // S3 is the s3 argument value. S3 string } - // UnsetHost holds details about calls to the UnsetHost method. - UnsetHost []struct { - // S is the s argument value. - S string - } // Write holds details about calls to the Write method. Write []struct { } } - lockAliases sync.RWMutex - lockAuthToken sync.RWMutex - lockDefaultHost sync.RWMutex - lockGet sync.RWMutex - lockGetOrDefault sync.RWMutex - lockHosts sync.RWMutex - lockSet sync.RWMutex - lockUnsetHost sync.RWMutex - lockWrite sync.RWMutex + lockAliases sync.RWMutex + lockAuthentication sync.RWMutex + lockGet sync.RWMutex + lockGetOrDefault sync.RWMutex + lockSet sync.RWMutex + lockWrite sync.RWMutex } // Aliases calls AliasesFunc. @@ -164,62 +130,30 @@ func (mock *ConfigMock) AliasesCalls() []struct { return calls } -// AuthToken calls AuthTokenFunc. -func (mock *ConfigMock) AuthToken(s string) (string, string) { - if mock.AuthTokenFunc == nil { - panic("ConfigMock.AuthTokenFunc: method is nil but Config.AuthToken was just called") - } - callInfo := struct { - S string - }{ - S: s, - } - mock.lockAuthToken.Lock() - mock.calls.AuthToken = append(mock.calls.AuthToken, callInfo) - mock.lockAuthToken.Unlock() - return mock.AuthTokenFunc(s) -} - -// AuthTokenCalls gets all the calls that were made to AuthToken. -// Check the length with: -// -// len(mockedConfig.AuthTokenCalls()) -func (mock *ConfigMock) AuthTokenCalls() []struct { - S string -} { - var calls []struct { - S string - } - mock.lockAuthToken.RLock() - calls = mock.calls.AuthToken - mock.lockAuthToken.RUnlock() - return calls -} - -// DefaultHost calls DefaultHostFunc. -func (mock *ConfigMock) DefaultHost() (string, string) { - if mock.DefaultHostFunc == nil { - panic("ConfigMock.DefaultHostFunc: method is nil but Config.DefaultHost was just called") +// Authentication calls AuthenticationFunc. +func (mock *ConfigMock) Authentication() *AuthConfig { + if mock.AuthenticationFunc == nil { + panic("ConfigMock.AuthenticationFunc: method is nil but Config.Authentication was just called") } callInfo := struct { }{} - mock.lockDefaultHost.Lock() - mock.calls.DefaultHost = append(mock.calls.DefaultHost, callInfo) - mock.lockDefaultHost.Unlock() - return mock.DefaultHostFunc() + mock.lockAuthentication.Lock() + mock.calls.Authentication = append(mock.calls.Authentication, callInfo) + mock.lockAuthentication.Unlock() + return mock.AuthenticationFunc() } -// DefaultHostCalls gets all the calls that were made to DefaultHost. +// AuthenticationCalls gets all the calls that were made to Authentication. // Check the length with: // -// len(mockedConfig.DefaultHostCalls()) -func (mock *ConfigMock) DefaultHostCalls() []struct { +// len(mockedConfig.AuthenticationCalls()) +func (mock *ConfigMock) AuthenticationCalls() []struct { } { var calls []struct { } - mock.lockDefaultHost.RLock() - calls = mock.calls.DefaultHost - mock.lockDefaultHost.RUnlock() + mock.lockAuthentication.RLock() + calls = mock.calls.Authentication + mock.lockAuthentication.RUnlock() return calls } @@ -295,33 +229,6 @@ func (mock *ConfigMock) GetOrDefaultCalls() []struct { return calls } -// Hosts calls HostsFunc. -func (mock *ConfigMock) Hosts() []string { - if mock.HostsFunc == nil { - panic("ConfigMock.HostsFunc: method is nil but Config.Hosts was just called") - } - callInfo := struct { - }{} - mock.lockHosts.Lock() - mock.calls.Hosts = append(mock.calls.Hosts, callInfo) - mock.lockHosts.Unlock() - return mock.HostsFunc() -} - -// HostsCalls gets all the calls that were made to Hosts. -// Check the length with: -// -// len(mockedConfig.HostsCalls()) -func (mock *ConfigMock) HostsCalls() []struct { -} { - var calls []struct { - } - mock.lockHosts.RLock() - calls = mock.calls.Hosts - mock.lockHosts.RUnlock() - return calls -} - // Set calls SetFunc. func (mock *ConfigMock) Set(s1 string, s2 string, s3 string) { if mock.SetFunc == nil { @@ -362,38 +269,6 @@ func (mock *ConfigMock) SetCalls() []struct { return calls } -// UnsetHost calls UnsetHostFunc. -func (mock *ConfigMock) UnsetHost(s string) { - if mock.UnsetHostFunc == nil { - panic("ConfigMock.UnsetHostFunc: method is nil but Config.UnsetHost was just called") - } - callInfo := struct { - S string - }{ - S: s, - } - mock.lockUnsetHost.Lock() - mock.calls.UnsetHost = append(mock.calls.UnsetHost, callInfo) - mock.lockUnsetHost.Unlock() - mock.UnsetHostFunc(s) -} - -// UnsetHostCalls gets all the calls that were made to UnsetHost. -// Check the length with: -// -// len(mockedConfig.UnsetHostCalls()) -func (mock *ConfigMock) UnsetHostCalls() []struct { - S string -} { - var calls []struct { - S string - } - mock.lockUnsetHost.RLock() - calls = mock.calls.UnsetHost - mock.lockUnsetHost.RUnlock() - return calls -} - // Write calls WriteFunc. func (mock *ConfigMock) Write() error { if mock.WriteFunc == nil { diff --git a/internal/config/stub.go b/internal/config/stub.go index 0fed79cf5..f8e56ee4e 100644 --- a/internal/config/stub.go +++ b/internal/config/stub.go @@ -34,10 +34,6 @@ func NewFromString(cfgStr string) *ConfigMock { c := ghConfig.ReadFromString(cfgStr) cfg := cfg{c} mock := &ConfigMock{} - mock.AuthTokenFunc = func(host string) (string, string) { - token, _ := c.Get([]string{"hosts", host, "oauth_token"}) - return token, "oauth_token" - } mock.GetFunc = func(host, key string) (string, error) { return cfg.Get(host, key) } @@ -47,21 +43,27 @@ func NewFromString(cfgStr string) *ConfigMock { mock.SetFunc = func(host, key, value string) { cfg.Set(host, key, value) } - mock.UnsetHostFunc = func(host string) { - cfg.UnsetHost(host) - } - mock.HostsFunc = func() []string { - keys, _ := c.Keys([]string{"hosts"}) - return keys - } - mock.DefaultHostFunc = func() (string, string) { - return "github.com", "default" + mock.WriteFunc = func() error { + return cfg.Write() } mock.AliasesFunc = func() *AliasConfig { return &AliasConfig{cfg: c} } - mock.WriteFunc = func() error { - return cfg.Write() + mock.AuthenticationFunc = func() *AuthConfig { + return &AuthConfig{ + cfg: c, + defaultHostOverride: func() (string, string) { + return "github.com", "default" + }, + hostsOverride: func() []string { + keys, _ := c.Keys([]string{"hosts"}) + return keys + }, + tokenOverride: func(hostname string) (string, string) { + token, _ := c.Get([]string{hosts, hostname, oauthToken}) + return token, "oauth_token" + }, + } } return mock } diff --git a/pkg/cmd/api/api.go b/pkg/cmd/api/api.go index 643f14094..1a6f84d0b 100644 --- a/pkg/cmd/api/api.go +++ b/pkg/cmd/api/api.go @@ -309,7 +309,7 @@ func apiRun(opts *ApiOptions) error { return err } - host, _ := cfg.DefaultHost() + host, _ := cfg.Authentication().DefaultHost() if opts.Hostname != "" { host = opts.Hostname diff --git a/pkg/cmd/auth/gitcredential/helper.go b/pkg/cmd/auth/gitcredential/helper.go index 759025f16..fcceaf186 100644 --- a/pkg/cmd/auth/gitcredential/helper.go +++ b/pkg/cmd/auth/gitcredential/helper.go @@ -14,8 +14,8 @@ import ( const tokenUser = "x-access-token" type config interface { - AuthToken(string) (string, string) - Get(string, string) (string, error) + Token(string) (string, string) + User(string) (string, error) } type CredentialOptions struct { @@ -29,7 +29,11 @@ func NewCmdCredential(f *cmdutil.Factory, runF func(*CredentialOptions) error) * opts := &CredentialOptions{ IO: f.IOStreams, Config: func() (config, error) { - return f.Config() + cfg, err := f.Config() + if err != nil { + return nil, err + } + return cfg.Authentication(), nil }, } @@ -108,16 +112,16 @@ func helperRun(opts *CredentialOptions) error { lookupHost := wants["host"] var gotUser string - gotToken, source := cfg.AuthToken(lookupHost) + gotToken, source := cfg.Token(lookupHost) if gotToken == "" && strings.HasPrefix(lookupHost, "gist.") { lookupHost = strings.TrimPrefix(lookupHost, "gist.") - gotToken, source = cfg.AuthToken(lookupHost) + gotToken, source = cfg.Token(lookupHost) } if strings.HasSuffix(source, "_TOKEN") { gotUser = tokenUser } else { - gotUser, _ = cfg.Get(lookupHost, "user") + gotUser, _ = cfg.User(lookupHost) if gotUser == "" { gotUser = tokenUser } diff --git a/pkg/cmd/auth/gitcredential/helper_test.go b/pkg/cmd/auth/gitcredential/helper_test.go index b80e77b86..f66df1d16 100644 --- a/pkg/cmd/auth/gitcredential/helper_test.go +++ b/pkg/cmd/auth/gitcredential/helper_test.go @@ -8,15 +8,14 @@ import ( "github.com/cli/cli/v2/pkg/iostreams" ) -// why not just use the config stub argh type tinyConfig map[string]string -func (c tinyConfig) AuthToken(host string) (string, string) { +func (c tinyConfig) Token(host string) (string, string) { return c[fmt.Sprintf("%s:%s", host, "oauth_token")], c["_source"] } -func (c tinyConfig) Get(host, key string) (string, error) { - return c[fmt.Sprintf("%s:%s", host, key)], nil +func (c tinyConfig) User(host string) (string, error) { + return c[fmt.Sprintf("%s:%s", host, "user")], nil } func Test_helperRun(t *testing.T) { diff --git a/pkg/cmd/auth/login/login.go b/pkg/cmd/auth/login/login.go index 50cf111f2..3eda42ae0 100644 --- a/pkg/cmd/auth/login/login.go +++ b/pkg/cmd/auth/login/login.go @@ -8,6 +8,7 @@ import ( "github.com/MakeNowJust/heredoc" "github.com/cli/cli/v2/git" + "github.com/cli/cli/v2/internal/browser" "github.com/cli/cli/v2/internal/config" "github.com/cli/cli/v2/internal/ghinstance" "github.com/cli/cli/v2/pkg/cmd/auth/shared" @@ -23,16 +24,18 @@ type LoginOptions struct { HttpClient func() (*http.Client, error) GitClient *git.Client Prompter shared.Prompt + Browser browser.Browser MainExecutable string Interactive bool - Hostname string - Scopes []string - Token string - Web bool - GitProtocol string + Hostname string + Scopes []string + Token string + Web bool + GitProtocol string + SecureStorage bool } func NewCmdLogin(f *cmdutil.Factory, runF func(*LoginOptions) error) *cobra.Command { @@ -42,6 +45,7 @@ func NewCmdLogin(f *cmdutil.Factory, runF func(*LoginOptions) error) *cobra.Comm HttpClient: f.HttpClient, GitClient: f.GitClient, Prompter: f.Prompter, + Browser: f.Browser, } var tokenStdin bool @@ -120,6 +124,7 @@ func NewCmdLogin(f *cmdutil.Factory, runF func(*LoginOptions) error) *cobra.Comm cmd.Flags().BoolVar(&tokenStdin, "with-token", false, "Read token from standard input") cmd.Flags().BoolVarP(&opts.Web, "web", "w", false, "Open a browser to authenticate") cmdutil.StringEnumFlag(cmd, &opts.GitProtocol, "git-protocol", "p", "", []string{"ssh", "https"}, "The protocol to use for git operations") + cmd.Flags().BoolVarP(&opts.SecureStorage, "secure-storage", "", false, "Save authentication credentials in secure credential store") return cmd } @@ -129,6 +134,12 @@ func loginRun(opts *LoginOptions) error { if err != nil { return err } + authCfg := cfg.Authentication() + + if opts.SecureStorage { + cs := opts.IO.ColorScheme() + fmt.Fprintf(opts.IO.ErrOut, "%s Using secure storage could break installed extensions\n", cs.WarningIcon()) + } hostname := opts.Hostname if opts.Interactive && hostname == "" { @@ -139,7 +150,7 @@ func loginRun(opts *LoginOptions) error { } } - if src, writeable := shared.AuthTokenWriteable(cfg, hostname); !writeable { + if src, writeable := shared.AuthTokenWriteable(authCfg, hostname); !writeable { fmt.Fprintf(opts.IO.ErrOut, "The value of the %s environment variable is being used for authentication.\n", src) fmt.Fprint(opts.IO.ErrOut, "To have GitHub CLI store credentials instead, first clear the value from the environment.\n") return cmdutil.SilentError @@ -151,18 +162,14 @@ func loginRun(opts *LoginOptions) error { } if opts.Token != "" { - cfg.Set(hostname, "oauth_token", opts.Token) - if err := shared.HasMinimumScopes(httpClient, hostname, opts.Token); err != nil { return fmt.Errorf("error validating token: %w", err) } - if opts.GitProtocol != "" { - cfg.Set(hostname, "git_protocol", opts.GitProtocol) - } - return cfg.Write() + // Adding a user key ensures that a nonempty host section gets written to the config file. + return authCfg.Login(hostname, "x-access-token", opts.Token, opts.GitProtocol, opts.SecureStorage) } - existingToken, _ := cfg.AuthToken(hostname) + existingToken, _ := authCfg.Token(hostname) if existingToken != "" && opts.Interactive { if err := shared.HasMinimumScopes(httpClient, hostname, existingToken); err == nil { keepGoing, err := opts.Prompter.Confirm(fmt.Sprintf("You're already logged into %s. Do you want to re-authenticate?", hostname), false) @@ -176,17 +183,19 @@ func loginRun(opts *LoginOptions) error { } return shared.Login(&shared.LoginOptions{ - IO: opts.IO, - Config: cfg, - HTTPClient: httpClient, - Hostname: hostname, - Interactive: opts.Interactive, - Web: opts.Web, - Scopes: opts.Scopes, - Executable: opts.MainExecutable, - GitProtocol: opts.GitProtocol, - Prompter: opts.Prompter, - GitClient: opts.GitClient, + IO: opts.IO, + Config: authCfg, + HTTPClient: httpClient, + Hostname: hostname, + Interactive: opts.Interactive, + Web: opts.Web, + Scopes: opts.Scopes, + Executable: opts.MainExecutable, + GitProtocol: opts.GitProtocol, + Prompter: opts.Prompter, + GitClient: opts.GitClient, + Browser: opts.Browser, + SecureStorage: opts.SecureStorage, }) } diff --git a/pkg/cmd/auth/login/login_test.go b/pkg/cmd/auth/login/login_test.go index 85e133b2b..9f95cc9ee 100644 --- a/pkg/cmd/auth/login/login_test.go +++ b/pkg/cmd/auth/login/login_test.go @@ -17,6 +17,7 @@ import ( "github.com/cli/cli/v2/pkg/iostreams" "github.com/google/shlex" "github.com/stretchr/testify/assert" + "github.com/zalando/go-keyring" ) func stubHomeDir(t *testing.T, dir string) { @@ -172,6 +173,23 @@ func Test_NewCmdLogin(t *testing.T) { Interactive: true, }, }, + { + name: "tty secure-storage", + stdinTTY: true, + cli: "--secure-storage", + wants: LoginOptions{ + Interactive: true, + SecureStorage: true, + }, + }, + { + name: "nontty secure-storage", + cli: "--secure-storage", + wants: LoginOptions{ + Hostname: "github.com", + SecureStorage: true, + }, + }, } for _, tt := range tests { @@ -223,13 +241,14 @@ func Test_NewCmdLogin(t *testing.T) { func Test_loginRun_nontty(t *testing.T) { tests := []struct { - name string - opts *LoginOptions - httpStubs func(*httpmock.Registry) - cfgStubs func(*config.ConfigMock) - wantHosts string - wantErr string - wantStderr string + name string + opts *LoginOptions + httpStubs func(*httpmock.Registry) + cfgStubs func(*config.ConfigMock) + wantHosts string + wantErr string + wantStderr string + wantSecureToken string }{ { name: "with token", @@ -240,7 +259,7 @@ func Test_loginRun_nontty(t *testing.T) { httpStubs: func(reg *httpmock.Registry) { reg.Register(httpmock.REST("GET", ""), httpmock.ScopesResponder("repo,read:org")) }, - wantHosts: "github.com:\n oauth_token: abc123\n", + wantHosts: "github.com:\n oauth_token: abc123\n user: x-access-token\n", }, { name: "with token and https git-protocol", @@ -252,7 +271,7 @@ func Test_loginRun_nontty(t *testing.T) { httpStubs: func(reg *httpmock.Registry) { reg.Register(httpmock.REST("GET", ""), httpmock.ScopesResponder("repo,read:org")) }, - wantHosts: "github.com:\n oauth_token: abc123\n git_protocol: https\n", + wantHosts: "github.com:\n oauth_token: abc123\n user: x-access-token\n git_protocol: https\n", }, { name: "with token and non-default host", @@ -263,7 +282,7 @@ func Test_loginRun_nontty(t *testing.T) { httpStubs: func(reg *httpmock.Registry) { reg.Register(httpmock.REST("GET", "api/v3/"), httpmock.ScopesResponder("repo,read:org")) }, - wantHosts: "albert.wesker:\n oauth_token: abc123\n", + wantHosts: "albert.wesker:\n oauth_token: abc123\n user: x-access-token\n", }, { name: "missing repo scope", @@ -296,7 +315,7 @@ func Test_loginRun_nontty(t *testing.T) { httpStubs: func(reg *httpmock.Registry) { reg.Register(httpmock.REST("GET", ""), httpmock.ScopesResponder("repo,admin:org")) }, - wantHosts: "github.com:\n oauth_token: abc456\n", + wantHosts: "github.com:\n oauth_token: abc456\n user: x-access-token\n", }, { name: "github.com token from environment", @@ -305,8 +324,10 @@ func Test_loginRun_nontty(t *testing.T) { Token: "abc456", }, cfgStubs: func(c *config.ConfigMock) { - c.AuthTokenFunc = func(string) (string, string) { - return "value_from_env", "GH_TOKEN" + authCfg := c.Authentication() + authCfg.SetToken("value_from_env", "GH_TOKEN") + c.AuthenticationFunc = func() *config.AuthConfig { + return authCfg } }, wantErr: "SilentError", @@ -322,8 +343,10 @@ func Test_loginRun_nontty(t *testing.T) { Token: "abc456", }, cfgStubs: func(c *config.ConfigMock) { - c.AuthTokenFunc = func(string) (string, string) { - return "value_from_env", "GH_ENTERPRISE_TOKEN" + authCfg := c.Authentication() + authCfg.SetToken("value_from_env", "GH_ENTERPRISE_TOKEN") + c.AuthenticationFunc = func() *config.AuthConfig { + return authCfg } }, wantErr: "SilentError", @@ -332,15 +355,30 @@ func Test_loginRun_nontty(t *testing.T) { To have GitHub CLI store credentials instead, first clear the value from the environment. `), }, + { + name: "with token and secure storage", + opts: &LoginOptions{ + Hostname: "github.com", + Token: "abc123", + SecureStorage: true, + }, + httpStubs: func(reg *httpmock.Registry) { + reg.Register(httpmock.REST("GET", ""), httpmock.ScopesResponder("repo,read:org")) + }, + wantHosts: "github.com:\n user: x-access-token\n", + wantSecureToken: "abc123", + wantStderr: "! Using secure storage could break installed extensions\n", + }, } for _, tt := range tests { - ios, _, stdout, stderr := iostreams.Test() - ios.SetStdinTTY(false) - ios.SetStdoutTTY(false) - tt.opts.IO = ios - t.Run(tt.name, func(t *testing.T) { + ios, _, stdout, stderr := iostreams.Test() + ios.SetStdinTTY(false) + ios.SetStdoutTTY(false) + tt.opts.IO = ios + + keyring.MockInit() readConfigs := config.StubWriteConfig(t) cfg := config.NewBlankConfig() if tt.cfgStubs != nil { @@ -351,6 +389,7 @@ func Test_loginRun_nontty(t *testing.T) { } reg := &httpmock.Registry{} + defer reg.Verify(t) tt.opts.HttpClient = func() (*http.Client, error) { return &http.Client{Transport: reg}, nil } @@ -371,11 +410,12 @@ func Test_loginRun_nontty(t *testing.T) { mainBuf := bytes.Buffer{} hostsBuf := bytes.Buffer{} readConfigs(&mainBuf, &hostsBuf) + secureToken, _ := cfg.Authentication().TokenFromKeyring(tt.opts.Hostname) assert.Equal(t, "", stdout.String()) assert.Equal(t, tt.wantStderr, stderr.String()) assert.Equal(t, tt.wantHosts, hostsBuf.String()) - reg.Verify(t) + assert.Equal(t, tt.wantSecureToken, secureToken) }) } } @@ -384,14 +424,15 @@ func Test_loginRun_Survey(t *testing.T) { stubHomeDir(t, t.TempDir()) tests := []struct { - name string - opts *LoginOptions - httpStubs func(*httpmock.Registry) - prompterStubs func(*prompter.PrompterMock) - runStubs func(*run.CommandStubber) - wantHosts string - wantErrOut *regexp.Regexp - cfgStubs func(*config.ConfigMock) + name string + opts *LoginOptions + httpStubs func(*httpmock.Registry) + prompterStubs func(*prompter.PrompterMock) + runStubs func(*run.CommandStubber) + cfgStubs func(*config.ConfigMock) + wantHosts string + wantErrOut *regexp.Regexp + wantSecureToken string }{ { name: "already authenticated", @@ -399,8 +440,10 @@ func Test_loginRun_Survey(t *testing.T) { Interactive: true, }, cfgStubs: func(c *config.ConfigMock) { - c.AuthTokenFunc = func(h string) (string, string) { - return "ghi789", "oauth_token" + authCfg := c.Authentication() + authCfg.SetToken("ghi789", "oauth_token") + c.AuthenticationFunc = func() *config.AuthConfig { + return authCfg } }, httpStubs: func(reg *httpmock.Registry) { @@ -547,32 +590,62 @@ func Test_loginRun_Survey(t *testing.T) { }, wantErrOut: regexp.MustCompile("Tip: you can generate a Personal Access Token here https://github.com/settings/tokens"), }, - // TODO how to test browser auth? + { + name: "secure storage", + opts: &LoginOptions{ + Hostname: "github.com", + Interactive: true, + SecureStorage: true, + }, + prompterStubs: func(pm *prompter.PrompterMock) { + pm.SelectFunc = func(prompt, _ string, opts []string) (int, error) { + switch prompt { + case "What is your preferred protocol for Git operations?": + return prompter.IndexFor(opts, "HTTPS") + case "How would you like to authenticate GitHub CLI?": + return prompter.IndexFor(opts, "Paste an authentication token") + } + return -1, prompter.NoSuchPromptErr(prompt) + } + }, + runStubs: func(rs *run.CommandStubber) { + rs.Register(`git config credential\.https:/`, 1, "") + rs.Register(`git config credential\.helper`, 1, "") + }, + wantHosts: heredoc.Doc(` + github.com: + user: jillv + git_protocol: https + `), + wantErrOut: regexp.MustCompile("! Using secure storage could break installed extensions"), + wantSecureToken: "def456", + }, } for _, tt := range tests { - if tt.opts == nil { - tt.opts = &LoginOptions{} - } - ios, _, _, stderr := iostreams.Test() - - ios.SetStdinTTY(true) - ios.SetStderrTTY(true) - ios.SetStdoutTTY(true) - - tt.opts.IO = ios - - readConfigs := config.StubWriteConfig(t) - - cfg := config.NewBlankConfig() - if tt.cfgStubs != nil { - tt.cfgStubs(cfg) - } - tt.opts.Config = func() (config.Config, error) { - return cfg, nil - } - t.Run(tt.name, func(t *testing.T) { + if tt.opts == nil { + tt.opts = &LoginOptions{} + } + ios, _, _, stderr := iostreams.Test() + + ios.SetStdinTTY(true) + ios.SetStderrTTY(true) + ios.SetStdoutTTY(true) + + tt.opts.IO = ios + + keyring.MockInit() + readConfigs := config.StubWriteConfig(t) + + cfg := config.NewBlankConfig() + if tt.cfgStubs != nil { + tt.cfgStubs(cfg) + } + tt.opts.Config = func() (config.Config, error) { + return cfg, nil + } + reg := &httpmock.Registry{} tt.opts.HttpClient = func() (*http.Client, error) { return &http.Client{Transport: reg}, nil @@ -614,8 +687,10 @@ func Test_loginRun_Survey(t *testing.T) { mainBuf := bytes.Buffer{} hostsBuf := bytes.Buffer{} readConfigs(&mainBuf, &hostsBuf) + secureToken, _ := cfg.Authentication().TokenFromKeyring(tt.opts.Hostname) assert.Equal(t, tt.wantHosts, hostsBuf.String()) + assert.Equal(t, tt.wantSecureToken, secureToken) if tt.wantErrOut == nil { assert.Equal(t, "", stderr.String()) } else { diff --git a/pkg/cmd/auth/logout/logout.go b/pkg/cmd/auth/logout/logout.go index 4495cd5f6..c871a5333 100644 --- a/pkg/cmd/auth/logout/logout.go +++ b/pkg/cmd/auth/logout/logout.go @@ -69,8 +69,9 @@ func logoutRun(opts *LogoutOptions) error { if err != nil { return err } + authCfg := cfg.Authentication() - candidates := cfg.Hosts() + candidates := authCfg.Hosts() if len(candidates) == 0 { return fmt.Errorf("not logged in to any hosts") } @@ -100,7 +101,7 @@ func logoutRun(opts *LogoutOptions) error { } } - if src, writeable := shared.AuthTokenWriteable(cfg, hostname); !writeable { + if src, writeable := shared.AuthTokenWriteable(authCfg, hostname); !writeable { fmt.Fprintf(opts.IO.ErrOut, "The value of the %s environment variable is being used for authentication.\n", src) fmt.Fprint(opts.IO.ErrOut, "To erase credentials stored in GitHub CLI, first clear the value from the environment.\n") return cmdutil.SilentError @@ -116,7 +117,7 @@ func logoutRun(opts *LogoutOptions) error { if err != nil { // suppressing; the user is trying to delete this token and it might be bad. // we'll see if the username is in the config and fall back to that. - username, _ = cfg.Get(hostname, "user") + username, _ = authCfg.User(hostname) } usernameStr := "" @@ -124,9 +125,7 @@ func logoutRun(opts *LogoutOptions) error { usernameStr = fmt.Sprintf(" account '%s'", username) } - cfg.UnsetHost(hostname) - err = cfg.Write() - if err != nil { + if err := authCfg.Logout(hostname); err != nil { return fmt.Errorf("failed to write config, authentication configuration not updated: %w", err) } diff --git a/pkg/cmd/auth/logout/logout_test.go b/pkg/cmd/auth/logout/logout_test.go index 0a5e32f7c..75c215aaa 100644 --- a/pkg/cmd/auth/logout/logout_test.go +++ b/pkg/cmd/auth/logout/logout_test.go @@ -2,6 +2,7 @@ package logout import ( "bytes" + "fmt" "net/http" "regexp" "testing" @@ -13,6 +14,7 @@ import ( "github.com/cli/cli/v2/pkg/iostreams" "github.com/google/shlex" "github.com/stretchr/testify/assert" + "github.com/zalando/go-keyring" ) func Test_NewCmdLogout(t *testing.T) { @@ -96,6 +98,7 @@ func Test_logoutRun_tty(t *testing.T) { opts *LogoutOptions prompterStubs func(*prompter.PrompterMock) cfgHosts []string + secureStorage bool wantHosts string wantErrOut *regexp.Regexp wantErr string @@ -133,14 +136,31 @@ func Test_logoutRun_tty(t *testing.T) { wantHosts: "github.com:\n oauth_token: abc123\n", wantErrOut: regexp.MustCompile(`Logged out of cheryl.mason account 'cybilb'`), }, + { + name: "secure storage", + secureStorage: true, + opts: &LogoutOptions{ + Hostname: "github.com", + }, + cfgHosts: []string{"github.com"}, + wantHosts: "{}\n", + wantErrOut: regexp.MustCompile(`Logged out of github.com account 'cybilb'`), + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { + keyring.MockInit() readConfigs := config.StubWriteConfig(t) cfg := config.NewFromString("") for _, hostname := range tt.cfgHosts { - cfg.Set(hostname, "oauth_token", "abc123") + if tt.secureStorage { + cfg.Set(hostname, "user", "monalisa") + _ = keyring.Set(fmt.Sprintf("gh:%s", hostname), "", "abc123") + cfg.Authentication().SetToken("abc123", "keyring") + } else { + cfg.Set(hostname, "oauth_token", "abc123") + } } tt.opts.Config = func() (config.Config, error) { return cfg, nil @@ -183,8 +203,10 @@ func Test_logoutRun_tty(t *testing.T) { mainBuf := bytes.Buffer{} hostsBuf := bytes.Buffer{} readConfigs(&mainBuf, &hostsBuf) + secureToken, _ := cfg.Authentication().TokenFromKeyring(tt.opts.Hostname) assert.Equal(t, tt.wantHosts, hostsBuf.String()) + assert.Equal(t, "", secureToken) reg.Verify(t) }) } @@ -192,12 +214,13 @@ func Test_logoutRun_tty(t *testing.T) { func Test_logoutRun_nontty(t *testing.T) { tests := []struct { - name string - opts *LogoutOptions - cfgHosts []string - wantHosts string - wantErr string - ghtoken string + name string + opts *LogoutOptions + cfgHosts []string + secureStorage bool + ghtoken string + wantHosts string + wantErr string }{ { name: "hostname, one host", @@ -222,14 +245,30 @@ func Test_logoutRun_nontty(t *testing.T) { }, wantErr: `not logged in to any hosts`, }, + { + name: "secure storage", + secureStorage: true, + opts: &LogoutOptions{ + Hostname: "harry.mason", + }, + cfgHosts: []string{"harry.mason"}, + wantHosts: "{}\n", + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { + keyring.MockInit() readConfigs := config.StubWriteConfig(t) cfg := config.NewFromString("") for _, hostname := range tt.cfgHosts { - cfg.Set(hostname, "oauth_token", "abc123") + if tt.secureStorage { + cfg.Set(hostname, "user", "monalisa") + _ = keyring.Set(fmt.Sprintf("gh:%s", hostname), "", "abc123") + cfg.Authentication().SetToken("abc123", "keyring") + } else { + cfg.Set(hostname, "oauth_token", "abc123") + } } tt.opts.Config = func() (config.Config, error) { return cfg, nil @@ -257,8 +296,10 @@ func Test_logoutRun_nontty(t *testing.T) { mainBuf := bytes.Buffer{} hostsBuf := bytes.Buffer{} readConfigs(&mainBuf, &hostsBuf) + secureToken, _ := cfg.Authentication().TokenFromKeyring(tt.opts.Hostname) assert.Equal(t, tt.wantHosts, hostsBuf.String()) + assert.Equal(t, "", secureToken) reg.Verify(t) }) } diff --git a/pkg/cmd/auth/refresh/refresh.go b/pkg/cmd/auth/refresh/refresh.go index c66fb5237..ce4c5029f 100644 --- a/pkg/cmd/auth/refresh/refresh.go +++ b/pkg/cmd/auth/refresh/refresh.go @@ -26,18 +26,26 @@ type RefreshOptions struct { Hostname string Scopes []string - AuthFlow func(config.Config, *iostreams.IOStreams, string, []string, bool) error + AuthFlow func(*config.AuthConfig, *iostreams.IOStreams, string, []string, bool, bool) error - Interactive bool + Interactive bool + SecureStorage bool } func NewCmdRefresh(f *cmdutil.Factory, runF func(*RefreshOptions) error) *cobra.Command { opts := &RefreshOptions{ IO: f.IOStreams, Config: f.Config, - AuthFlow: func(cfg config.Config, io *iostreams.IOStreams, hostname string, scopes []string, interactive bool) error { - _, err := authflow.AuthFlowWithConfig(cfg, io, hostname, "", scopes, interactive) - return err + AuthFlow: func(authCfg *config.AuthConfig, io *iostreams.IOStreams, hostname string, scopes []string, interactive, secureStorage bool) error { + if secureStorage { + cs := io.ColorScheme() + fmt.Fprintf(io.ErrOut, "%s Using secure storage could break installed extensions", cs.WarningIcon()) + } + token, username, err := authflow.AuthFlow(hostname, io, "", scopes, interactive, f.Browser) + if err != nil { + return err + } + return authCfg.Login(hostname, username, token, "", secureStorage) }, HttpClient: &http.Client{}, GitClient: f.GitClient, @@ -77,6 +85,7 @@ func NewCmdRefresh(f *cmdutil.Factory, runF func(*RefreshOptions) error) *cobra. cmd.Flags().StringVarP(&opts.Hostname, "hostname", "h", "", "The GitHub host to use for authentication") cmd.Flags().StringSliceVarP(&opts.Scopes, "scopes", "s", nil, "Additional authentication scopes for gh to have") + cmd.Flags().BoolVarP(&opts.SecureStorage, "secure-storage", "", false, "Save authentication credentials in secure credential store") return cmd } @@ -86,8 +95,9 @@ func refreshRun(opts *RefreshOptions) error { if err != nil { return err } + authCfg := cfg.Authentication() - candidates := cfg.Hosts() + candidates := authCfg.Hosts() if len(candidates) == 0 { return fmt.Errorf("not logged in to any hosts. Use 'gh auth login' to authenticate with a host") } @@ -117,14 +127,14 @@ func refreshRun(opts *RefreshOptions) error { } } - if src, writeable := shared.AuthTokenWriteable(cfg, hostname); !writeable { + if src, writeable := shared.AuthTokenWriteable(authCfg, hostname); !writeable { fmt.Fprintf(opts.IO.ErrOut, "The value of the %s environment variable is being used for authentication.\n", src) fmt.Fprint(opts.IO.ErrOut, "To refresh credentials stored in GitHub CLI, first clear the value from the environment.\n") return cmdutil.SilentError } var additionalScopes []string - if oldToken, _ := cfg.AuthToken(hostname); oldToken != "" { + if oldToken, _ := authCfg.Token(hostname); oldToken != "" { if oldScopes, err := shared.GetScopes(opts.HttpClient, hostname, oldToken); err == nil { for _, s := range strings.Split(oldScopes, ",") { s = strings.TrimSpace(s) @@ -140,7 +150,7 @@ func refreshRun(opts *RefreshOptions) error { Prompter: opts.Prompter, GitClient: opts.GitClient, } - gitProtocol, _ := cfg.GetOrDefault(hostname, "git_protocol") + gitProtocol, _ := authCfg.GitProtocol(hostname) if opts.Interactive && gitProtocol == "https" { if err := credentialFlow.Prompt(hostname); err != nil { return err @@ -148,7 +158,7 @@ func refreshRun(opts *RefreshOptions) error { additionalScopes = append(additionalScopes, credentialFlow.Scopes()...) } - if err := opts.AuthFlow(cfg, opts.IO, hostname, append(opts.Scopes, additionalScopes...), opts.Interactive); err != nil { + if err := opts.AuthFlow(authCfg, opts.IO, hostname, append(opts.Scopes, additionalScopes...), opts.Interactive, opts.SecureStorage); err != nil { return err } @@ -156,8 +166,8 @@ func refreshRun(opts *RefreshOptions) error { fmt.Fprintf(opts.IO.ErrOut, "%s Authentication complete.\n", cs.SuccessIcon()) if credentialFlow.ShouldSetup() { - username, _ := cfg.Get(hostname, "user") - password, _ := cfg.AuthToken(hostname) + username, _ := authCfg.User(hostname) + password, _ := authCfg.Token(hostname) if err := credentialFlow.Setup(hostname, username, password); err != nil { return err } diff --git a/pkg/cmd/auth/refresh/refresh_test.go b/pkg/cmd/auth/refresh/refresh_test.go index d1d6042c0..863bc0b7c 100644 --- a/pkg/cmd/auth/refresh/refresh_test.go +++ b/pkg/cmd/auth/refresh/refresh_test.go @@ -85,6 +85,14 @@ func Test_NewCmdRefresh(t *testing.T) { Scopes: []string{"repo:invite", "read:public_key"}, }, }, + { + name: "secure storage", + tty: true, + cli: "--secure-storage", + wants: RefreshOptions{ + SecureStorage: true, + }, + }, } for _, tt := range tests { @@ -126,8 +134,10 @@ func Test_NewCmdRefresh(t *testing.T) { } type authArgs struct { - hostname string - scopes []string + hostname string + scopes []string + interactive bool + secureStorage bool } func Test_refreshRun(t *testing.T) { @@ -230,17 +240,33 @@ func Test_refreshRun(t *testing.T) { scopes: []string{"repo:invite", "public_key:read", "delete_repo", "codespace"}, }, }, + { + name: "secure storage", + cfgHosts: []string{ + "obed.morton", + }, + opts: &RefreshOptions{ + Hostname: "obed.morton", + SecureStorage: true, + }, + wantAuthArgs: authArgs{ + hostname: "obed.morton", + scopes: nil, + secureStorage: true, + }, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { aa := authArgs{} - tt.opts.AuthFlow = func(_ config.Config, _ *iostreams.IOStreams, hostname string, scopes []string, interactive bool) error { + tt.opts.AuthFlow = func(_ *config.AuthConfig, _ *iostreams.IOStreams, hostname string, scopes []string, interactive, secureStorage bool) error { aa.hostname = hostname aa.scopes = scopes + aa.interactive = interactive + aa.secureStorage = secureStorage return nil } - _ = config.StubWriteConfig(t) cfg := config.NewFromString("") for _, hostname := range tt.cfgHosts { cfg.Set(hostname, "oauth_token", "abc123") @@ -291,6 +317,8 @@ func Test_refreshRun(t *testing.T) { assert.Equal(t, tt.wantAuthArgs.hostname, aa.hostname) assert.Equal(t, tt.wantAuthArgs.scopes, aa.scopes) + assert.Equal(t, tt.wantAuthArgs.interactive, aa.interactive) + assert.Equal(t, tt.wantAuthArgs.secureStorage, aa.secureStorage) }) } } diff --git a/pkg/cmd/auth/setupgit/setupgit.go b/pkg/cmd/auth/setupgit/setupgit.go index edc531235..a041ebfc1 100644 --- a/pkg/cmd/auth/setupgit/setupgit.go +++ b/pkg/cmd/auth/setupgit/setupgit.go @@ -54,8 +54,9 @@ func setupGitRun(opts *SetupGitOptions) error { if err != nil { return err } + authCfg := cfg.Authentication() - hostnames := cfg.Hosts() + hostnames := authCfg.Hosts() stderr := opts.IO.ErrOut cs := opts.IO.ColorScheme() diff --git a/pkg/cmd/auth/setupgit/setupgit_test.go b/pkg/cmd/auth/setupgit/setupgit_test.go index 1f3f0be92..635cece8c 100644 --- a/pkg/cmd/auth/setupgit/setupgit_test.go +++ b/pkg/cmd/auth/setupgit/setupgit_test.go @@ -38,8 +38,10 @@ func Test_setupGitRun(t *testing.T) { opts: &SetupGitOptions{ Config: func() (config.Config, error) { cfg := &config.ConfigMock{} - cfg.HostsFunc = func() []string { - return []string{} + cfg.AuthenticationFunc = func() *config.AuthConfig { + authCfg := &config.AuthConfig{} + authCfg.SetHosts([]string{}) + return authCfg } return cfg, nil }, @@ -53,8 +55,10 @@ func Test_setupGitRun(t *testing.T) { Hostname: "foo", Config: func() (config.Config, error) { cfg := &config.ConfigMock{} - cfg.HostsFunc = func() []string { - return []string{"bar"} + cfg.AuthenticationFunc = func() *config.AuthConfig { + authCfg := &config.AuthConfig{} + authCfg.SetHosts([]string{"bar"}) + return authCfg } return cfg, nil }, @@ -70,8 +74,10 @@ func Test_setupGitRun(t *testing.T) { }, Config: func() (config.Config, error) { cfg := &config.ConfigMock{} - cfg.HostsFunc = func() []string { - return []string{"bar"} + cfg.AuthenticationFunc = func() *config.AuthConfig { + authCfg := &config.AuthConfig{} + authCfg.SetHosts([]string{"bar"}) + return authCfg } return cfg, nil }, @@ -85,8 +91,10 @@ func Test_setupGitRun(t *testing.T) { gitConfigure: &mockGitConfigurer{}, Config: func() (config.Config, error) { cfg := &config.ConfigMock{} - cfg.HostsFunc = func() []string { - return []string{"bar"} + cfg.AuthenticationFunc = func() *config.AuthConfig { + authCfg := &config.AuthConfig{} + authCfg.SetHosts([]string{"bar"}) + return authCfg } return cfg, nil }, @@ -99,8 +107,10 @@ func Test_setupGitRun(t *testing.T) { gitConfigure: &mockGitConfigurer{}, Config: func() (config.Config, error) { cfg := &config.ConfigMock{} - cfg.HostsFunc = func() []string { - return []string{"bar", "yes"} + cfg.AuthenticationFunc = func() *config.AuthConfig { + authCfg := &config.AuthConfig{} + authCfg.SetHosts([]string{"bar", "yes"}) + return authCfg } return cfg, nil }, diff --git a/pkg/cmd/auth/shared/login_flow.go b/pkg/cmd/auth/shared/login_flow.go index 99d82b5f7..0201af155 100644 --- a/pkg/cmd/auth/shared/login_flow.go +++ b/pkg/cmd/auth/shared/login_flow.go @@ -10,6 +10,7 @@ import ( "github.com/cli/cli/v2/api" "github.com/cli/cli/v2/git" "github.com/cli/cli/v2/internal/authflow" + "github.com/cli/cli/v2/internal/browser" "github.com/cli/cli/v2/internal/ghinstance" "github.com/cli/cli/v2/pkg/cmd/ssh-key/add" "github.com/cli/cli/v2/pkg/iostreams" @@ -19,23 +20,23 @@ import ( const defaultSSHKeyTitle = "GitHub CLI" type iconfig interface { - Get(string, string) (string, error) - Set(string, string, string) - Write() error + Login(string, string, string, string, bool) error } type LoginOptions struct { - IO *iostreams.IOStreams - Config iconfig - HTTPClient *http.Client - GitClient *git.Client - Hostname string - Interactive bool - Web bool - Scopes []string - Executable string - GitProtocol string - Prompter Prompt + IO *iostreams.IOStreams + Config iconfig + HTTPClient *http.Client + GitClient *git.Client + Hostname string + Interactive bool + Web bool + Scopes []string + Executable string + GitProtocol string + Prompter Prompt + Browser browser.Browser + SecureStorage bool sshContext ssh.Context } @@ -145,16 +146,15 @@ func Login(opts *LoginOptions) error { } var authToken string - userValidated := false + var username string if authMode == 0 { var err error - authToken, err = authflow.AuthFlowWithConfig(cfg, opts.IO, hostname, "", append(opts.Scopes, additionalScopes...), opts.Interactive) + authToken, username, err = authflow.AuthFlow(hostname, opts.IO, "", append(opts.Scopes, additionalScopes...), opts.Interactive, opts.Browser) if err != nil { return fmt.Errorf("failed to authenticate via web browser: %w", err) } fmt.Fprintf(opts.IO.ErrOut, "%s Authentication complete.\n", cs.SuccessIcon()) - userValidated = true } else { minimumScopes := append([]string{"repo", "read:org"}, additionalScopes...) fmt.Fprint(opts.IO.ErrOut, heredoc.Docf(` @@ -162,7 +162,8 @@ func Login(opts *LoginOptions) error { The minimum required scopes are %s. `, hostname, scopesSentence(minimumScopes, ghinstance.IsEnterprise(hostname)))) - authToken, err := opts.Prompter.AuthToken() + var err error + authToken, err = opts.Prompter.AuthToken() if err != nil { return err } @@ -170,32 +171,23 @@ func Login(opts *LoginOptions) error { if err := HasMinimumScopes(httpClient, hostname, authToken); err != nil { return fmt.Errorf("error validating token: %w", err) } - - cfg.Set(hostname, "oauth_token", authToken) } - var username string - if userValidated { - username, _ = cfg.Get(hostname, "user") - } else { + if username == "" { apiClient := api.NewClientFromHTTP(httpClient) var err error username, err = api.CurrentLoginName(apiClient, hostname) if err != nil { return fmt.Errorf("error using api: %w", err) } - - cfg.Set(hostname, "user", username) } if gitProtocol != "" { fmt.Fprintf(opts.IO.ErrOut, "- gh config set -h %s git_protocol %s\n", hostname, gitProtocol) - cfg.Set(hostname, "git_protocol", gitProtocol) fmt.Fprintf(opts.IO.ErrOut, "%s Configured git protocol\n", cs.SuccessIcon()) } - err := cfg.Write() - if err != nil { + if err := cfg.Login(hostname, username, authToken, gitProtocol, opts.SecureStorage); err != nil { return err } diff --git a/pkg/cmd/auth/shared/login_flow_test.go b/pkg/cmd/auth/shared/login_flow_test.go index 40d045661..f01ba46ff 100644 --- a/pkg/cmd/auth/shared/login_flow_test.go +++ b/pkg/cmd/auth/shared/login_flow_test.go @@ -18,15 +18,10 @@ import ( type tinyConfig map[string]string -func (c tinyConfig) Get(host, key string) (string, error) { - return c[fmt.Sprintf("%s:%s", host, key)], nil -} - -func (c tinyConfig) Set(host string, key string, value string) { - c[fmt.Sprintf("%s:%s", host, key)] = value -} - -func (c tinyConfig) Write() error { +func (c tinyConfig) Login(host, username, token, gitProtocol string, encrypt bool) error { + c[fmt.Sprintf("%s:%s", host, "user")] = username + c[fmt.Sprintf("%s:%s", host, "oauth_token")] = token + c[fmt.Sprintf("%s:%s", host, "git_protocol")] = gitProtocol return nil } diff --git a/pkg/cmd/auth/shared/writeable.go b/pkg/cmd/auth/shared/writeable.go index ef117f32d..cf8e678e4 100644 --- a/pkg/cmd/auth/shared/writeable.go +++ b/pkg/cmd/auth/shared/writeable.go @@ -1,14 +1,12 @@ package shared import ( + "strings" + "github.com/cli/cli/v2/internal/config" ) -const ( - oauthToken = "oauth_token" -) - -func AuthTokenWriteable(cfg config.Config, hostname string) (string, bool) { - token, src := cfg.AuthToken(hostname) - return src, (token == "" || src == oauthToken) +func AuthTokenWriteable(authCfg *config.AuthConfig, hostname string) (string, bool) { + token, src := authCfg.Token(hostname) + return src, (token == "" || !strings.HasSuffix(src, "_TOKEN")) } diff --git a/pkg/cmd/auth/status/status.go b/pkg/cmd/auth/status/status.go index c8c4f0ee6..8070fe9ae 100644 --- a/pkg/cmd/auth/status/status.go +++ b/pkg/cmd/auth/status/status.go @@ -61,6 +61,7 @@ func statusRun(opts *StatusOptions) error { if err != nil { return err } + authCfg := cfg.Authentication() // TODO check tty @@ -70,7 +71,7 @@ func statusRun(opts *StatusOptions) error { statusInfo := map[string][]string{} - hostnames := cfg.Hosts() + hostnames := authCfg.Hosts() if len(hostnames) == 0 { fmt.Fprintf(stderr, "You are not logged into any GitHub hosts. Run %s to authenticate.\n", cs.Bold("gh auth login")) @@ -91,13 +92,13 @@ func statusRun(opts *StatusOptions) error { } isHostnameFound = true - token, tokenSource := cfg.AuthToken(hostname) + token, tokenSource := authCfg.Token(hostname) if tokenSource == "oauth_token" { // The go-gh function TokenForHost returns this value as source for tokens read from the // config file, but we want the file path instead. This attempts to reconstruct it. tokenSource = filepath.Join(config.ConfigDir(), "hosts.yml") } - _, tokenIsWriteable := shared.AuthTokenWriteable(cfg, hostname) + _, tokenIsWriteable := shared.AuthTokenWriteable(authCfg, hostname) statusInfo[hostname] = []string{} addMsg := func(x string, ys ...interface{}) { @@ -138,7 +139,7 @@ func statusRun(opts *StatusOptions) error { } addMsg("%s Logged in to %s as %s (%s)", cs.SuccessIcon(), hostname, cs.Bold(username), tokenSource) - proto, _ := cfg.GetOrDefault(hostname, "git_protocol") + proto, _ := authCfg.GitProtocol(hostname) if proto != "" { addMsg("%s Git operations for %s configured to use %s protocol.", cs.SuccessIcon(), hostname, cs.Bold(proto)) diff --git a/pkg/cmd/auth/token/token.go b/pkg/cmd/auth/token/token.go index c28d42d26..093c6b4f3 100644 --- a/pkg/cmd/auth/token/token.go +++ b/pkg/cmd/auth/token/token.go @@ -14,7 +14,8 @@ type TokenOptions struct { IO *iostreams.IOStreams Config func() (config.Config, error) - Hostname string + Hostname string + SecureStorage bool } func NewCmdToken(f *cmdutil.Factory, runF func(*TokenOptions) error) *cobra.Command { @@ -37,6 +38,8 @@ func NewCmdToken(f *cmdutil.Factory, runF func(*TokenOptions) error) *cobra.Comm } cmd.Flags().StringVarP(&opts.Hostname, "hostname", "h", "", "The hostname of the GitHub instance authenticated with") + cmd.Flags().BoolVarP(&opts.SecureStorage, "secure-storage", "", false, "Search only secure credential store for authentication token") + _ = cmd.Flags().MarkHidden("secure-storeage") return cmd } @@ -51,8 +54,14 @@ func tokenRun(opts *TokenOptions) error { if err != nil { return err } + authCfg := cfg.Authentication() - val, _ := cfg.AuthToken(hostname) + var val string + if opts.SecureStorage { + val, _ = authCfg.TokenFromKeyring(hostname) + } else { + val, _ = authCfg.Token(hostname) + } if val == "" { return fmt.Errorf("no oauth token") } diff --git a/pkg/cmd/auth/token/token_test.go b/pkg/cmd/auth/token/token_test.go index 22d13a9a5..463c073b2 100644 --- a/pkg/cmd/auth/token/token_test.go +++ b/pkg/cmd/auth/token/token_test.go @@ -9,6 +9,7 @@ import ( "github.com/cli/cli/v2/pkg/iostreams" "github.com/google/shlex" "github.com/stretchr/testify/assert" + "github.com/zalando/go-keyring" ) func TestNewCmdToken(t *testing.T) { @@ -34,6 +35,11 @@ func TestNewCmdToken(t *testing.T) { input: "-h github.mycompany.com", output: TokenOptions{Hostname: "github.mycompany.com"}, }, + { + name: "with secure-storage", + input: "--secure-storage", + output: TokenOptions{SecureStorage: true}, + }, } for _, tt := range tests { @@ -71,11 +77,12 @@ func TestNewCmdToken(t *testing.T) { assert.NoError(t, err) assert.Equal(t, tt.output.Hostname, cmdOpts.Hostname) + assert.Equal(t, tt.output.SecureStorage, cmdOpts.SecureStorage) }) } } -func Test_tokenRun(t *testing.T) { +func TestTokenRun(t *testing.T) { tests := []struct { name string opts TokenOptions @@ -121,17 +128,77 @@ func Test_tokenRun(t *testing.T) { } for _, tt := range tests { - ios, _, stdout, _ := iostreams.Test() - tt.opts.IO = ios - t.Run(tt.name, func(t *testing.T) { + ios, _, stdout, _ := iostreams.Test() + tt.opts.IO = ios + err := tokenRun(&tt.opts) + if tt.wantErr { + assert.Error(t, err) + assert.EqualError(t, err, tt.wantErrMsg) + return + } + assert.NoError(t, err) + assert.Equal(t, tt.wantStdout, stdout.String()) + }) + } +} + +func TestTokenRunSecureStorage(t *testing.T) { + tests := []struct { + name string + opts TokenOptions + wantStdout string + wantErr bool + wantErrMsg string + }{ + { + name: "token", + opts: TokenOptions{ + Config: func() (config.Config, error) { + cfg := config.NewBlankConfig() + _ = keyring.Set("gh:github.com", "", "gho_ABCDEFG") + return cfg, nil + }, + }, + wantStdout: "gho_ABCDEFG\n", + }, + { + name: "token by hostname", + opts: TokenOptions{ + Config: func() (config.Config, error) { + cfg := config.NewBlankConfig() + _ = keyring.Set("gh:mycompany.com", "", "gho_1234567") + return cfg, nil + }, + Hostname: "mycompany.com", + }, + wantStdout: "gho_1234567\n", + }, + { + name: "no token", + opts: TokenOptions{ + Config: func() (config.Config, error) { + cfg := config.NewBlankConfig() + return cfg, nil + }, + }, + wantErr: true, + wantErrMsg: "no oauth token", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + keyring.MockInit() + ios, _, stdout, _ := iostreams.Test() + tt.opts.IO = ios + tt.opts.SecureStorage = true err := tokenRun(&tt.opts) if tt.wantErr { assert.Error(t, err) assert.EqualError(t, err, tt.wantErrMsg) return } - assert.NoError(t, err) assert.Equal(t, tt.wantStdout, stdout.String()) }) diff --git a/pkg/cmd/browse/browse.go b/pkg/cmd/browse/browse.go index 524f9f20e..619c7f0aa 100644 --- a/pkg/cmd/browse/browse.go +++ b/pkg/cmd/browse/browse.go @@ -174,7 +174,7 @@ func parseSection(baseRepo ghrepo.Interface, opts *BrowseOptions) (string, error } } - if isNumber(opts.SelectorArg) { + if !opts.CommitFlag && isNumber(opts.SelectorArg) { return fmt.Sprintf("issues/%s", strings.TrimPrefix(opts.SelectorArg, "#")), nil } diff --git a/pkg/cmd/browse/browse_test.go b/pkg/cmd/browse/browse_test.go index 3a283f8f4..57b6ad99b 100644 --- a/pkg/cmd/browse/browse_test.go +++ b/pkg/cmd/browse/browse_test.go @@ -123,6 +123,15 @@ func TestNewCmdBrowse(t *testing.T) { }, wantsErr: false, }, + { + name: "commit hash flag", + cli: "-c 123", + wants: BrowseOptions{ + CommitFlag: true, + SelectorArg: "123", + }, + wantsErr: false, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -395,6 +404,17 @@ func Test_runBrowse(t *testing.T) { wantsErr: false, expectedURL: "https://github.com/vilmibm/gh-user-status/tree/6f1a2405cace1633d89a79c74c65f22fe78f9659/main.go", }, + { + name: "open number only commit hash", + opts: BrowseOptions{ + CommitFlag: true, + SelectorArg: "1234567890", + GitClient: &testGitClient{}, + }, + baseRepo: ghrepo.New("yanskun", "ILoveGitHub"), + wantsErr: false, + expectedURL: "https://github.com/yanskun/ILoveGitHub/commit/1234567890", + }, { name: "relative path from browse_test.go", opts: BrowseOptions{ diff --git a/pkg/cmd/codespace/code.go b/pkg/cmd/codespace/code.go index 429b5ce71..77153ccd7 100644 --- a/pkg/cmd/codespace/code.go +++ b/pkg/cmd/codespace/code.go @@ -10,7 +10,7 @@ import ( func newCodeCmd(app *App) *cobra.Command { var ( - codespace string + selector *CodespaceSelector useInsiders bool useWeb bool ) @@ -20,11 +20,12 @@ func newCodeCmd(app *App) *cobra.Command { Short: "Open a codespace in Visual Studio Code", Args: noArgsConstraint, RunE: func(cmd *cobra.Command, args []string) error { - return app.VSCode(cmd.Context(), codespace, useInsiders, useWeb) + return app.VSCode(cmd.Context(), selector, useInsiders, useWeb) }, } - codeCmd.Flags().StringVarP(&codespace, "codespace", "c", "", "Name of the codespace") + selector = AddCodespaceSelector(codeCmd, app.apiClient) + codeCmd.Flags().BoolVar(&useInsiders, "insiders", false, "Use the insiders version of Visual Studio Code") codeCmd.Flags().BoolVarP(&useWeb, "web", "w", false, "Use the web version of Visual Studio Code") @@ -32,8 +33,8 @@ func newCodeCmd(app *App) *cobra.Command { } // VSCode opens a codespace in the local VS VSCode application. -func (a *App) VSCode(ctx context.Context, codespaceName string, useInsiders bool, useWeb bool) error { - codespace, err := getOrChooseCodespace(ctx, a.apiClient, codespaceName) +func (a *App) VSCode(ctx context.Context, selector *CodespaceSelector, useInsiders bool, useWeb bool) error { + codespace, err := selector.Select(ctx) if err != nil { return err } diff --git a/pkg/cmd/codespace/code_test.go b/pkg/cmd/codespace/code_test.go index f43d8a20c..be2743c5d 100644 --- a/pkg/cmd/codespace/code_test.go +++ b/pkg/cmd/codespace/code_test.go @@ -69,7 +69,9 @@ func TestApp_VSCode(t *testing.T) { apiClient: testCodeApiMock(), io: ios, } - if err := a.VSCode(context.Background(), tt.args.codespaceName, tt.args.useInsiders, tt.args.useWeb); (err != nil) != tt.wantErr { + selector := &CodespaceSelector{api: a.apiClient, codespaceName: tt.args.codespaceName} + + if err := a.VSCode(context.Background(), selector, tt.args.useInsiders, tt.args.useWeb); (err != nil) != tt.wantErr { t.Errorf("App.VSCode() error = %v, wantErr %v", err, tt.wantErr) } b.Verify(t, tt.wantURL) @@ -85,8 +87,9 @@ func TestApp_VSCode(t *testing.T) { func TestPendingOperationDisallowsCode(t *testing.T) { app := testingCodeApp() + selector := &CodespaceSelector{api: app.apiClient, codespaceName: "disabledCodespace"} - if err := app.VSCode(context.Background(), "disabledCodespace", false, false); err != nil { + if err := app.VSCode(context.Background(), selector, false, false); err != nil { if err.Error() != "codespace is disabled while it has a pending operation: Some pending operation" { t.Errorf("expected pending operation error, but got: %v", err) } diff --git a/pkg/cmd/codespace/codespace_selector.go b/pkg/cmd/codespace/codespace_selector.go new file mode 100644 index 000000000..69b9ed06b --- /dev/null +++ b/pkg/cmd/codespace/codespace_selector.go @@ -0,0 +1,123 @@ +package codespace + +import ( + "context" + "errors" + "fmt" + "strings" + + "github.com/cli/cli/v2/internal/codespaces/api" + "github.com/spf13/cobra" +) + +type CodespaceSelector struct { + api apiClient + + repoName string + codespaceName string +} + +var errNoFilteredCodespaces = errors.New("you have no codespaces meeting the filter criteria") + +// AddCodespaceSelector adds persistent flags for selecting a codespace to the given command and returns a CodespaceSelector which applies them +func AddCodespaceSelector(cmd *cobra.Command, api apiClient) *CodespaceSelector { + cs := &CodespaceSelector{api: api} + + cmd.PersistentFlags().StringVarP(&cs.codespaceName, "codespace", "c", "", "Name of the codespace") + cmd.PersistentFlags().StringVarP(&cs.repoName, "repo", "R", "", "Filter codespace selection by repository name (user/repo)") + + cmd.MarkFlagsMutuallyExclusive("codespace", "repo") + + return cs +} + +func (cs *CodespaceSelector) Select(ctx context.Context) (codespace *api.Codespace, err error) { + if cs.codespaceName != "" { + codespace, err = cs.api.GetCodespace(ctx, cs.codespaceName, true) + if err != nil { + return nil, fmt.Errorf("getting full codespace details: %w", err) + } + } else { + codespaces, err := cs.fetchCodespaces(ctx) + if err != nil { + return nil, err + } + + codespace, err = cs.chooseCodespace(ctx, codespaces) + if err != nil { + return nil, err + } + } + + if codespace.PendingOperation { + return nil, fmt.Errorf( + "codespace is disabled while it has a pending operation: %s", + codespace.PendingOperationDisabledReason, + ) + } + + return codespace, nil +} + +func (cs *CodespaceSelector) SelectName(ctx context.Context) (string, error) { + if cs.codespaceName != "" { + return cs.codespaceName, nil + } + + codespaces, err := cs.fetchCodespaces(ctx) + if err != nil { + return "", err + } + + codespace, err := cs.chooseCodespace(ctx, codespaces) + if err != nil { + return "", err + } + + return codespace.Name, nil +} + +func (cs *CodespaceSelector) fetchCodespaces(ctx context.Context) (codespaces []*api.Codespace, err error) { + codespaces, err = cs.api.ListCodespaces(ctx, api.ListCodespacesOptions{}) + if err != nil { + return nil, fmt.Errorf("error getting codespaces: %w", err) + } + + if len(codespaces) == 0 { + return nil, errNoCodespaces + } + + // Note that repo filtering done here can also be done in api.ListCodespaces. + // We do it here instead so that we can differentiate no codespaces in general vs. none after filtering. + if cs.repoName != "" { + var filteredCodespaces []*api.Codespace + for _, c := range codespaces { + if !strings.EqualFold(c.Repository.FullName, cs.repoName) { + continue + } + + filteredCodespaces = append(filteredCodespaces, c) + } + + codespaces = filteredCodespaces + } + + if len(codespaces) == 0 { + return nil, errNoFilteredCodespaces + } + + return codespaces, err +} + +func (cs *CodespaceSelector) chooseCodespace(ctx context.Context, codespaces []*api.Codespace) (codespace *api.Codespace, err error) { + skipPromptForSingleOption := cs.repoName != "" + codespace, err = chooseCodespaceFromList(ctx, codespaces, false, skipPromptForSingleOption) + if err != nil { + if err == errNoCodespaces { + return nil, err + } + return nil, fmt.Errorf("choosing codespace: %w", err) + } + + return codespace, nil +} diff --git a/pkg/cmd/codespace/codespace_selector_test.go b/pkg/cmd/codespace/codespace_selector_test.go new file mode 100644 index 000000000..30ba3c588 --- /dev/null +++ b/pkg/cmd/codespace/codespace_selector_test.go @@ -0,0 +1,137 @@ +package codespace + +import ( + "context" + "fmt" + "testing" + + "github.com/cli/cli/v2/internal/codespaces/api" +) + +func TestSelectWithCodespaceName(t *testing.T) { + wantName := "mock-codespace" + + api := &apiClientMock{ + GetCodespaceFunc: func(ctx context.Context, name string, includeConnection bool) (*api.Codespace, error) { + if name != wantName { + t.Errorf("incorrect name: want %s, got %s", wantName, name) + } + + return &api.Codespace{}, nil + }, + } + + cs := &CodespaceSelector{api: api, codespaceName: wantName} + + _, err := cs.Select(context.Background()) + + if err != nil { + t.Errorf("unexpected error: %v", err) + } +} + +func TestSelectNameWithCodespaceName(t *testing.T) { + wantName := "mock-codespace" + + cs := &CodespaceSelector{codespaceName: wantName} + + name, err := cs.SelectName(context.Background()) + + if name != wantName { + t.Errorf("incorrect name: want %s, got %s", wantName, name) + } + + if err != nil { + t.Errorf("unexpected error: %v", err) + } +} + +func TestFetchCodespaces(t *testing.T) { + var ( + repoA1 = &api.Codespace{Name: "1", Repository: api.Repository{FullName: "mock/A"}} + repoA2 = &api.Codespace{Name: "2", Repository: api.Repository{FullName: "mock/A"}} + + repoB1 = &api.Codespace{Name: "1", Repository: api.Repository{FullName: "mock/B"}} + ) + + tests := []struct { + tName string + apiCodespaces []*api.Codespace + repoName string + wantCodespaces []*api.Codespace + wantErr error + }{ + // Empty case + { + "empty", nil, "", nil, errNoCodespaces, + }, + + // Tests with no filtering + { + "no filtering, single codespace", + []*api.Codespace{repoA1}, + "", + []*api.Codespace{repoA1}, + nil, + }, + { + "no filtering, multiple codespaces", + []*api.Codespace{repoA1, repoA2, repoB1}, + "", + []*api.Codespace{repoA1, repoA2, repoB1}, + nil, + }, + + // Test repo filtering + { + "repo filtering, single codespace", + []*api.Codespace{repoA1}, + "mock/A", + []*api.Codespace{repoA1}, + nil, + }, + { + "repo filtering, multiple codespaces", + []*api.Codespace{repoA1, repoA2, repoB1}, + "mock/A", + []*api.Codespace{repoA1, repoA2}, + nil, + }, + { + "repo filtering, multiple codespaces 2", + []*api.Codespace{repoA1, repoA2, repoB1}, + "mock/B", + []*api.Codespace{repoB1}, + nil, + }, + { + "repo filtering, no matches", + []*api.Codespace{repoA1, repoA2, repoB1}, + "mock/C", + nil, + errNoFilteredCodespaces, + }, + } + + for _, tt := range tests { + t.Run(tt.tName, func(t *testing.T) { + api := &apiClientMock{ + ListCodespacesFunc: func(ctx context.Context, opts api.ListCodespacesOptions) ([]*api.Codespace, error) { + return tt.apiCodespaces, nil + }, + } + + cs := &CodespaceSelector{api: api, repoName: tt.repoName} + + codespaces, err := cs.fetchCodespaces(context.Background()) + + if err != tt.wantErr { + t.Errorf("expected error to be %v, got %v", tt.wantErr, err) + } + + if fmt.Sprintf("%v", tt.wantCodespaces) != fmt.Sprintf("%v", codespaces) { + t.Errorf("expected codespaces to be %v, got %v", tt.wantCodespaces, codespaces) + } + }) + } +} diff --git a/pkg/cmd/codespace/common.go b/pkg/cmd/codespace/common.go index 24235ed27..3fb7cac10 100644 --- a/pkg/cmd/codespace/common.go +++ b/pkg/cmd/codespace/common.go @@ -102,21 +102,17 @@ type apiClient interface { var errNoCodespaces = errors.New("you have no codespaces") -func chooseCodespace(ctx context.Context, apiClient apiClient) (*api.Codespace, error) { - codespaces, err := apiClient.ListCodespaces(ctx, api.ListCodespacesOptions{}) - if err != nil { - return nil, fmt.Errorf("error getting codespaces: %w", err) - } - return chooseCodespaceFromList(ctx, codespaces, false) -} - // chooseCodespaceFromList returns the codespace that the user has interactively selected from the list, or // an error if there are no codespaces. -func chooseCodespaceFromList(ctx context.Context, codespaces []*api.Codespace, includeOwner bool) (*api.Codespace, error) { +func chooseCodespaceFromList(ctx context.Context, codespaces []*api.Codespace, includeOwner bool, skipPromptForSingleOption bool) (*api.Codespace, error) { if len(codespaces) == 0 { return nil, errNoCodespaces } + if skipPromptForSingleOption && len(codespaces) == 1 { + return codespaces[0], nil + } + sortedCodespaces := codespaces sort.Slice(sortedCodespaces, func(i, j int) bool { return sortedCodespaces[i].CreatedAt > sortedCodespaces[j].CreatedAt @@ -154,35 +150,6 @@ func formatCodespacesForSelect(codespaces []*api.Codespace, includeOwner bool) [ return names } -// getOrChooseCodespace prompts the user to choose a codespace if the codespaceName is empty. -// It then fetches the codespace record with full connection details. -// TODO(josebalius): accept a progress indicator or *App and show progress when fetching. -func getOrChooseCodespace(ctx context.Context, apiClient apiClient, codespaceName string) (codespace *api.Codespace, err error) { - if codespaceName == "" { - codespace, err = chooseCodespace(ctx, apiClient) - if err != nil { - if err == errNoCodespaces { - return nil, err - } - return nil, fmt.Errorf("choosing codespace: %w", err) - } - } else { - codespace, err = apiClient.GetCodespace(ctx, codespaceName, true) - if err != nil { - return nil, fmt.Errorf("getting full codespace details: %w", err) - } - } - - if codespace.PendingOperation { - return nil, fmt.Errorf( - "codespace is disabled while it has a pending operation: %s", - codespace.PendingOperationDisabledReason, - ) - } - - return codespace, nil -} - func safeClose(closer io.Closer, err *error) { if closeErr := closer.Close(); *err == nil { *err = closeErr @@ -289,5 +256,9 @@ func addDeprecatedRepoShorthand(cmd *cobra.Command, target *string) error { return fmt.Errorf("error marking `-r` shorthand as deprecated: %w", err) } + if cmd.Flag("codespace") != nil { + cmd.MarkFlagsMutuallyExclusive("codespace", "repo-deprecated") + } + return nil } diff --git a/pkg/cmd/codespace/delete.go b/pkg/cmd/codespace/delete.go index 3af44395e..d01c3c642 100644 --- a/pkg/cmd/codespace/delete.go +++ b/pkg/cmd/codespace/delete.go @@ -41,6 +41,8 @@ func newDeleteCmd(app *App) *cobra.Command { prompter: &surveyPrompter{}, } + var selector *CodespaceSelector + deleteCmd := &cobra.Command{ Use: "delete", Short: "Delete codespaces", @@ -54,6 +56,11 @@ func newDeleteCmd(app *App) *cobra.Command { `), Args: noArgsConstraint, RunE: func(cmd *cobra.Command, args []string) error { + // TODO: ideally we would use the selector directly, but the logic here is too intertwined with other flags to do so elegantly + // After the admin subcommand is added (see https://github.com/cli/cli/pull/6944#issuecomment-1419553639) we can revisit this. + opts.codespaceName = selector.codespaceName + opts.repoFilter = selector.repoName + if opts.deleteAll && opts.repoFilter != "" { return cmdutil.FlagErrorf("both `--all` and `--repo` is not supported") } @@ -64,13 +71,12 @@ func newDeleteCmd(app *App) *cobra.Command { }, } - deleteCmd.Flags().StringVarP(&opts.codespaceName, "codespace", "c", "", "Name of the codespace") - deleteCmd.Flags().BoolVar(&opts.deleteAll, "all", false, "Delete all codespaces") - deleteCmd.Flags().StringVarP(&opts.repoFilter, "repo", "R", "", "Delete codespaces for a `repository`") - if err := addDeprecatedRepoShorthand(deleteCmd, &opts.repoFilter); err != nil { + selector = AddCodespaceSelector(deleteCmd, app.apiClient) + if err := addDeprecatedRepoShorthand(deleteCmd, &selector.repoName); err != nil { fmt.Fprintf(app.io.ErrOut, "%v\n", err) } + deleteCmd.Flags().BoolVar(&opts.deleteAll, "all", false, "Delete all codespaces") deleteCmd.Flags().BoolVarP(&opts.skipConfirm, "force", "f", false, "Skip confirmation for codespaces that contain unsaved changes") deleteCmd.Flags().Uint16Var(&opts.keepDays, "days", 0, "Delete codespaces older than `N` days") deleteCmd.Flags().StringVarP(&opts.orgName, "org", "o", "", "The `login` handle of the organization (admin-only)") @@ -100,7 +106,7 @@ func (a *App) Delete(ctx context.Context, opts deleteOptions) (err error) { if !opts.deleteAll && opts.repoFilter == "" { includeUsername := opts.orgName != "" - c, err := chooseCodespaceFromList(ctx, codespaces, includeUsername) + c, err := chooseCodespaceFromList(ctx, codespaces, includeUsername, false) if err != nil { return fmt.Errorf("error choosing codespace: %w", err) } diff --git a/pkg/cmd/codespace/edit.go b/pkg/cmd/codespace/edit.go index d07173623..0eecbc257 100644 --- a/pkg/cmd/codespace/edit.go +++ b/pkg/cmd/codespace/edit.go @@ -2,6 +2,7 @@ package codespace import ( "context" + "errors" "fmt" "github.com/cli/cli/v2/internal/codespaces/api" @@ -10,9 +11,9 @@ import ( ) type editOptions struct { - codespaceName string - displayName string - machine string + selector *CodespaceSelector + displayName string + machine string } func newEditCmd(app *App) *cobra.Command { @@ -31,7 +32,7 @@ func newEditCmd(app *App) *cobra.Command { }, } - editCmd.Flags().StringVarP(&opts.codespaceName, "codespace", "c", "", "Name of the codespace") + opts.selector = AddCodespaceSelector(editCmd, app.apiClient) editCmd.Flags().StringVarP(&opts.displayName, "display-name", "d", "", "Set the display name") editCmd.Flags().StringVar(&opts.displayName, "displayName", "", "display name") if err := editCmd.Flags().MarkDeprecated("displayName", "use `--display-name` instead"); err != nil { @@ -44,21 +45,17 @@ func newEditCmd(app *App) *cobra.Command { // Edits a codespace func (a *App) Edit(ctx context.Context, opts editOptions) error { - codespaceName := opts.codespaceName - - if codespaceName == "" { - selectedCodespace, err := chooseCodespace(ctx, a.apiClient) - if err != nil { - if err == errNoCodespaces { - return err - } - return fmt.Errorf("error choosing codespace: %w", err) + codespaceName, err := opts.selector.SelectName(ctx) + if err != nil { + // TODO: is there a cleaner way to do this? + if errors.Is(err, errNoCodespaces) || errors.Is(err, errNoFilteredCodespaces) { + return err } - codespaceName = selectedCodespace.Name + return fmt.Errorf("error choosing codespace: %w", err) } a.StartProgressIndicatorWithLabel("Editing codespace") - _, err := a.apiClient.EditCodespace(ctx, codespaceName, &api.EditCodespaceParams{ + _, err = a.apiClient.EditCodespace(ctx, codespaceName, &api.EditCodespaceParams{ DisplayName: opts.displayName, Machine: opts.machine, }) diff --git a/pkg/cmd/codespace/edit_test.go b/pkg/cmd/codespace/edit_test.go index 886d9e455..01fb4f4ef 100644 --- a/pkg/cmd/codespace/edit_test.go +++ b/pkg/cmd/codespace/edit_test.go @@ -23,9 +23,9 @@ func TestEdit(t *testing.T) { { name: "edit codespace display name", opts: editOptions{ - codespaceName: "hubot", - displayName: "hubot-changed", - machine: "", + selector: &CodespaceSelector{codespaceName: "hubot"}, + displayName: "hubot-changed", + machine: "", }, wantEdits: &api.EditCodespaceParams{ DisplayName: "hubot-changed", @@ -54,9 +54,9 @@ func TestEdit(t *testing.T) { { name: "edit codespace machine", opts: editOptions{ - codespaceName: "hubot", - displayName: "", - machine: "machine", + selector: &CodespaceSelector{codespaceName: "hubot"}, + displayName: "", + machine: "machine", }, wantEdits: &api.EditCodespaceParams{ Machine: "machine", @@ -92,6 +92,11 @@ func TestEdit(t *testing.T) { var err error if tt.cliArgs == nil { + if tt.opts.selector == nil { + t.Fatalf("selector must be set in opts if cliArgs are not provided") + } + + tt.opts.selector.api = apiMock err = a.Edit(context.Background(), tt.opts) } else { cmd := newEditCmd(a) diff --git a/pkg/cmd/codespace/jupyter.go b/pkg/cmd/codespace/jupyter.go index 0e3e0dee0..1374916bb 100644 --- a/pkg/cmd/codespace/jupyter.go +++ b/pkg/cmd/codespace/jupyter.go @@ -13,28 +13,28 @@ import ( ) func newJupyterCmd(app *App) *cobra.Command { - var codespace string + var selector *CodespaceSelector jupyterCmd := &cobra.Command{ Use: "jupyter", Short: "Open a codespace in JupyterLab", Args: noArgsConstraint, RunE: func(cmd *cobra.Command, args []string) error { - return app.Jupyter(cmd.Context(), codespace) + return app.Jupyter(cmd.Context(), selector) }, } - jupyterCmd.Flags().StringVarP(&codespace, "codespace", "c", "", "Name of the codespace") + selector = AddCodespaceSelector(jupyterCmd, app.apiClient) return jupyterCmd } -func (a *App) Jupyter(ctx context.Context, codespaceName string) (err error) { +func (a *App) Jupyter(ctx context.Context, selector *CodespaceSelector) (err error) { // Ensure all child tasks (e.g. port forwarding) terminate before return. ctx, cancel := context.WithCancel(ctx) defer cancel() - codespace, err := getOrChooseCodespace(ctx, a.apiClient, codespaceName) + codespace, err := selector.Select(ctx) if err != nil { return err } diff --git a/pkg/cmd/codespace/logs.go b/pkg/cmd/codespace/logs.go index 9a42b9866..6f0e88402 100644 --- a/pkg/cmd/codespace/logs.go +++ b/pkg/cmd/codespace/logs.go @@ -12,8 +12,8 @@ import ( func newLogsCmd(app *App) *cobra.Command { var ( - codespace string - follow bool + selector *CodespaceSelector + follow bool ) logsCmd := &cobra.Command{ @@ -21,22 +21,23 @@ func newLogsCmd(app *App) *cobra.Command { Short: "Access codespace logs", Args: noArgsConstraint, RunE: func(cmd *cobra.Command, args []string) error { - return app.Logs(cmd.Context(), codespace, follow) + return app.Logs(cmd.Context(), selector, follow) }, } - logsCmd.Flags().StringVarP(&codespace, "codespace", "c", "", "Name of the codespace") + selector = AddCodespaceSelector(logsCmd, app.apiClient) + logsCmd.Flags().BoolVarP(&follow, "follow", "f", false, "Tail and follow the logs") return logsCmd } -func (a *App) Logs(ctx context.Context, codespaceName string, follow bool) (err error) { +func (a *App) Logs(ctx context.Context, selector *CodespaceSelector, follow bool) (err error) { // Ensure all child tasks (port forwarding, remote exec) terminate before return. ctx, cancel := context.WithCancel(ctx) defer cancel() - codespace, err := getOrChooseCodespace(ctx, a.apiClient, codespaceName) + codespace, err := selector.Select(ctx) if err != nil { return err } diff --git a/pkg/cmd/codespace/logs_test.go b/pkg/cmd/codespace/logs_test.go index 161657b4d..c4ba1ef59 100644 --- a/pkg/cmd/codespace/logs_test.go +++ b/pkg/cmd/codespace/logs_test.go @@ -10,8 +10,9 @@ import ( func TestPendingOperationDisallowsLogs(t *testing.T) { app := testingLogsApp() + selector := &CodespaceSelector{api: app.apiClient, codespaceName: "disabledCodespace"} - if err := app.Logs(context.Background(), "disabledCodespace", false); err != nil { + if err := app.Logs(context.Background(), selector, false); err != nil { if err.Error() != "codespace is disabled while it has a pending operation: Some pending operation" { t.Errorf("expected pending operation error, but got: %v", err) } diff --git a/pkg/cmd/codespace/ports.go b/pkg/cmd/codespace/ports.go index d36fa54a1..2f84d1678 100644 --- a/pkg/cmd/codespace/ports.go +++ b/pkg/cmd/codespace/ports.go @@ -29,30 +29,33 @@ const ( // newPortsCmd returns a Cobra "ports" command that displays a table of available ports, // according to the specified flags. func newPortsCmd(app *App) *cobra.Command { - var codespace string - var exporter cmdutil.Exporter + var ( + selector *CodespaceSelector + exporter cmdutil.Exporter + ) portsCmd := &cobra.Command{ Use: "ports", Short: "List ports in a codespace", Args: noArgsConstraint, RunE: func(cmd *cobra.Command, args []string) error { - return app.ListPorts(cmd.Context(), codespace, exporter) + return app.ListPorts(cmd.Context(), selector, exporter) }, } - portsCmd.PersistentFlags().StringVarP(&codespace, "codespace", "c", "", "Name of the codespace") + selector = AddCodespaceSelector(portsCmd, app.apiClient) + cmdutil.AddJSONFlags(portsCmd, &exporter, portFields) - portsCmd.AddCommand(newPortsForwardCmd(app)) - portsCmd.AddCommand(newPortsVisibilityCmd(app)) + portsCmd.AddCommand(newPortsForwardCmd(app, selector)) + portsCmd.AddCommand(newPortsVisibilityCmd(app, selector)) return portsCmd } // ListPorts lists known ports in a codespace. -func (a *App) ListPorts(ctx context.Context, codespaceName string, exporter cmdutil.Exporter) (err error) { - codespace, err := getOrChooseCodespace(ctx, a.apiClient, codespaceName) +func (a *App) ListPorts(ctx context.Context, selector *CodespaceSelector, exporter cmdutil.Exporter) (err error) { + codespace, err := selector.Select(ctx) if err != nil { return err } @@ -218,21 +221,14 @@ func getDevContainer(ctx context.Context, apiClient apiClient, codespace *api.Co return ch } -func newPortsVisibilityCmd(app *App) *cobra.Command { +func newPortsVisibilityCmd(app *App, selector *CodespaceSelector) *cobra.Command { return &cobra.Command{ Use: "visibility :{public|private|org}...", Short: "Change the visibility of the forwarded port", Example: "gh codespace ports visibility 80:org 3000:private 8000:public", Args: cobra.MinimumNArgs(1), RunE: func(cmd *cobra.Command, args []string) error { - codespace, err := cmd.Flags().GetString("codespace") - if err != nil { - // should only happen if flag is not defined - // or if the flag is not of string type - // since it's a persistent flag that we control it should never happen - return fmt.Errorf("get codespace flag: %w", err) - } - return app.UpdatePortVisibility(cmd.Context(), codespace, args) + return app.UpdatePortVisibility(cmd.Context(), selector, args) }, } } @@ -261,13 +257,13 @@ func (e *ErrUpdatingPortVisibility) Unwrap() error { var errUpdatePortVisibilityForbidden = errors.New("organization admin has forbidden this privacy setting") -func (a *App) UpdatePortVisibility(ctx context.Context, codespaceName string, args []string) (err error) { +func (a *App) UpdatePortVisibility(ctx context.Context, selector *CodespaceSelector, args []string) (err error) { ports, err := a.parsePortVisibilities(args) if err != nil { return fmt.Errorf("error parsing port arguments: %w", err) } - codespace, err := getOrChooseCodespace(ctx, a.apiClient, codespaceName) + codespace, err := selector.Select(ctx) if err != nil { return err } @@ -347,32 +343,24 @@ func (a *App) parsePortVisibilities(args []string) ([]portVisibility, error) { // NewPortsForwardCmd returns a Cobra "ports forward" subcommand, which forwards a set of // port pairs from the codespace to localhost. -func newPortsForwardCmd(app *App) *cobra.Command { +func newPortsForwardCmd(app *App, selector *CodespaceSelector) *cobra.Command { return &cobra.Command{ Use: "forward :...", Short: "Forward ports", Args: cobra.MinimumNArgs(1), RunE: func(cmd *cobra.Command, args []string) error { - codespace, err := cmd.Flags().GetString("codespace") - if err != nil { - // should only happen if flag is not defined - // or if the flag is not of string type - // since it's a persistent flag that we control it should never happen - return fmt.Errorf("get codespace flag: %w", err) - } - - return app.ForwardPorts(cmd.Context(), codespace, args) + return app.ForwardPorts(cmd.Context(), selector, args) }, } } -func (a *App) ForwardPorts(ctx context.Context, codespaceName string, ports []string) (err error) { +func (a *App) ForwardPorts(ctx context.Context, selector *CodespaceSelector, ports []string) (err error) { portPairs, err := getPortPairs(ports) if err != nil { return fmt.Errorf("get port pairs: %w", err) } - codespace, err := getOrChooseCodespace(ctx, a.apiClient, codespaceName) + codespace, err := selector.Select(ctx) if err != nil { return err } diff --git a/pkg/cmd/codespace/ports_test.go b/pkg/cmd/codespace/ports_test.go index ea61b11a5..9979345b2 100644 --- a/pkg/cmd/codespace/ports_test.go +++ b/pkg/cmd/codespace/ports_test.go @@ -207,13 +207,16 @@ func runUpdateVisibilityTest(t *testing.T, portVisibilities []portVisibility, ev portArgs = append(portArgs, fmt.Sprintf("%d:%s", pv.number, pv.visibility)) } - return a.UpdatePortVisibility(ctx, "codespace-name", portArgs) + selector := &CodespaceSelector{api: a.apiClient, codespaceName: "codespace-name"} + + return a.UpdatePortVisibility(ctx, selector, portArgs) } func TestPendingOperationDisallowsListPorts(t *testing.T) { app := testingPortsApp() + selector := &CodespaceSelector{api: app.apiClient, codespaceName: "disabledCodespace"} - if err := app.ListPorts(context.Background(), "disabledCodespace", nil); err != nil { + if err := app.ListPorts(context.Background(), selector, nil); err != nil { if err.Error() != "codespace is disabled while it has a pending operation: Some pending operation" { t.Errorf("expected pending operation error, but got: %v", err) } @@ -224,8 +227,9 @@ func TestPendingOperationDisallowsListPorts(t *testing.T) { func TestPendingOperationDisallowsUpdatePortVisability(t *testing.T) { app := testingPortsApp() + selector := &CodespaceSelector{api: app.apiClient, codespaceName: "disabledCodespace"} - if err := app.UpdatePortVisibility(context.Background(), "disabledCodespace", nil); err != nil { + if err := app.UpdatePortVisibility(context.Background(), selector, nil); err != nil { if err.Error() != "codespace is disabled while it has a pending operation: Some pending operation" { t.Errorf("expected pending operation error, but got: %v", err) } @@ -236,8 +240,9 @@ func TestPendingOperationDisallowsUpdatePortVisability(t *testing.T) { func TestPendingOperationDisallowsForwardPorts(t *testing.T) { app := testingPortsApp() + selector := &CodespaceSelector{api: app.apiClient, codespaceName: "disabledCodespace"} - if err := app.ForwardPorts(context.Background(), "disabledCodespace", nil); err != nil { + if err := app.ForwardPorts(context.Background(), selector, nil); err != nil { if err.Error() != "codespace is disabled while it has a pending operation: Some pending operation" { t.Errorf("expected pending operation error, but got: %v", err) } diff --git a/pkg/cmd/codespace/rebuild.go b/pkg/cmd/codespace/rebuild.go index 923f471c4..17e00670b 100644 --- a/pkg/cmd/codespace/rebuild.go +++ b/pkg/cmd/codespace/rebuild.go @@ -10,8 +10,10 @@ import ( ) func newRebuildCmd(app *App) *cobra.Command { - var codespace string - var fullRebuild bool + var ( + selector *CodespaceSelector + fullRebuild bool + ) rebuildCmd := &cobra.Command{ Use: "rebuild", @@ -21,21 +23,22 @@ preserved. Your codespace will be rebuilt using your working directory's dev container. A full rebuild also removes cached Docker images.`, Args: cobra.NoArgs, RunE: func(cmd *cobra.Command, args []string) error { - return app.Rebuild(cmd.Context(), codespace, fullRebuild) + return app.Rebuild(cmd.Context(), selector, fullRebuild) }, } - rebuildCmd.Flags().StringVarP(&codespace, "codespace", "c", "", "name of the codespace") + selector = AddCodespaceSelector(rebuildCmd, app.apiClient) + rebuildCmd.Flags().BoolVar(&fullRebuild, "full", false, "perform a full rebuild") return rebuildCmd } -func (a *App) Rebuild(ctx context.Context, codespaceName string, full bool) (err error) { +func (a *App) Rebuild(ctx context.Context, selector *CodespaceSelector, full bool) (err error) { ctx, cancel := context.WithCancel(ctx) defer cancel() - codespace, err := getOrChooseCodespace(ctx, a.apiClient, codespaceName) + codespace, err := selector.Select(ctx) if err != nil { return err } diff --git a/pkg/cmd/codespace/rebuild_test.go b/pkg/cmd/codespace/rebuild_test.go index f2496d089..b38bababe 100644 --- a/pkg/cmd/codespace/rebuild_test.go +++ b/pkg/cmd/codespace/rebuild_test.go @@ -14,8 +14,9 @@ func TestAlreadyRebuildingCodespace(t *testing.T) { State: api.CodespaceStateRebuilding, } app := testingRebuildApp(*rebuildingCodespace) + selector := &CodespaceSelector{api: app.apiClient, codespaceName: "rebuildingCodespace"} - err := app.Rebuild(context.Background(), "rebuildingCodespace", false) + err := app.Rebuild(context.Background(), selector, false) if err != nil { t.Errorf("rebuilding a codespace that was already rebuilding: %v", err) } diff --git a/pkg/cmd/codespace/select.go b/pkg/cmd/codespace/select.go index 32cc277b2..cb6a4128e 100644 --- a/pkg/cmd/codespace/select.go +++ b/pkg/cmd/codespace/select.go @@ -10,10 +10,13 @@ import ( type selectOptions struct { filePath string + selector *CodespaceSelector } func newSelectCmd(app *App) *cobra.Command { - opts := selectOptions{} + var ( + opts selectOptions + ) selectCmd := &cobra.Command{ Use: "select", @@ -21,19 +24,39 @@ func newSelectCmd(app *App) *cobra.Command { Hidden: true, Args: noArgsConstraint, RunE: func(cmd *cobra.Command, args []string) error { - return app.Select(cmd.Context(), "", opts) + return app.Select(cmd.Context(), opts) }, } + opts.selector = AddCodespaceSelector(selectCmd, app.apiClient) selectCmd.Flags().StringVarP(&opts.filePath, "file", "f", "", "Output file path") return selectCmd } -// Hidden codespace select command allows to reuse existing codespace selection -// dialog by external GH CLI extensions. By default, print selected codespace name -// to stdout. Pass file argument to save result into a file instead. -func (a *App) Select(ctx context.Context, name string, opts selectOptions) (err error) { - codespace, err := getOrChooseCodespace(ctx, a.apiClient, name) +// Hidden codespace `select` command allows to reuse existing codespace selection +// dialog by external GH CLI extensions. By default output selected codespace name +// into `stdout`. Pass `--file`(`-f`) flag along with a file path to output selected +// codespace name into a file instead. +// +// ## Examples +// +// With `stdout` output: +// +// ```shell +// +// gh codespace select +// +// ``` +// +// With `into-a-file` output: +// +// ```shell +// +// gh codespace select --file /tmp/selected_codespace.txt +// +// ``` +func (a *App) Select(ctx context.Context, opts selectOptions) (err error) { + codespace, err := opts.selector.Select(ctx) if err != nil { return err } diff --git a/pkg/cmd/codespace/select_test.go b/pkg/cmd/codespace/select_test.go index e97a7720e..02ea6f967 100644 --- a/pkg/cmd/codespace/select_test.go +++ b/pkg/cmd/codespace/select_test.go @@ -50,6 +50,7 @@ func TestApp_Select(t *testing.T) { a := NewApp(ios, nil, testSelectApiMock(), nil, nil) opts := selectOptions{} + if tt.outputToFile { file, err := os.CreateTemp("", "codespace-selection-test") if err != nil { @@ -61,7 +62,9 @@ func TestApp_Select(t *testing.T) { opts = selectOptions{filePath: file.Name()} } - if err := a.Select(context.Background(), tt.arg, opts); (err != nil) != tt.wantErr { + opts.selector = &CodespaceSelector{api: a.apiClient, codespaceName: tt.arg} + + if err := a.Select(context.Background(), opts); (err != nil) != tt.wantErr { t.Errorf("App.Select() error = %v, wantErr %v", err, tt.wantErr) } diff --git a/pkg/cmd/codespace/ssh.go b/pkg/cmd/codespace/ssh.go index d8788b45a..3295956f7 100644 --- a/pkg/cmd/codespace/ssh.go +++ b/pkg/cmd/codespace/ssh.go @@ -36,7 +36,7 @@ const automaticPrivateKeyName = "codespaces.auto" var errKeyFileNotFound = errors.New("SSH key file does not exist") type sshOptions struct { - codespace string + selector *CodespaceSelector profile string serverPort int debug bool @@ -87,7 +87,7 @@ func newSSHCmd(app *App) *cobra.Command { `), PreRunE: func(c *cobra.Command, args []string) error { if opts.stdio { - if opts.codespace == "" { + if opts.selector.codespaceName == "" { return errors.New("`--stdio` requires explicit `--codespace`") } if opts.config { @@ -122,7 +122,7 @@ func newSSHCmd(app *App) *cobra.Command { sshCmd.Flags().StringVarP(&opts.profile, "profile", "", "", "Name of the SSH profile to use") sshCmd.Flags().IntVarP(&opts.serverPort, "server-port", "", 0, "SSH server port number (0 => pick unused)") - sshCmd.Flags().StringVarP(&opts.codespace, "codespace", "c", "", "Name of the codespace") + opts.selector = AddCodespaceSelector(sshCmd, app.apiClient) sshCmd.Flags().BoolVarP(&opts.debug, "debug", "d", false, "Log debug data to a file") sshCmd.Flags().StringVarP(&opts.debugFile, "debug-file", "", "", "Path of the file log to") sshCmd.Flags().BoolVarP(&opts.config, "config", "", false, "Write OpenSSH configuration to stdout") @@ -160,7 +160,7 @@ func (a *App) SSH(ctx context.Context, sshArgs []string, opts sshOptions) (err e args = append([]string{"-i", keyPair.PrivateKeyPath}, args...) } - codespace, err := getOrChooseCodespace(ctx, a.apiClient, opts.codespace) + codespace, err := opts.selector.Select(ctx) if err != nil { return err } @@ -471,13 +471,13 @@ func (a *App) printOpenSSHConfig(ctx context.Context, opts sshOptions) (err erro defer cancel() var csList []*api.Codespace - if opts.codespace == "" { + if opts.selector.codespaceName == "" { a.StartProgressIndicatorWithLabel("Fetching codespaces") csList, err = a.apiClient.ListCodespaces(ctx, api.ListCodespacesOptions{}) a.StopProgressIndicator() } else { var codespace *api.Codespace - codespace, err = getOrChooseCodespace(ctx, a.apiClient, opts.codespace) + codespace, err = opts.selector.Select(ctx) csList = []*api.Codespace{codespace} } if err != nil { @@ -494,7 +494,7 @@ func (a *App) printOpenSSHConfig(ctx context.Context, opts sshOptions) (err erro var wg sync.WaitGroup var status error for _, cs := range csList { - if cs.State != "Available" && opts.codespace == "" { + if cs.State != "Available" && opts.selector.codespaceName == "" { fmt.Fprintf(os.Stderr, "skipping unavailable codespace %s: %s\n", cs.Name, cs.State) status = cmdutil.SilentError continue @@ -656,7 +656,7 @@ func newCpCmd(app *App) *cobra.Command { // We don't expose all sshOptions. cpCmd.Flags().BoolVarP(&opts.recursive, "recursive", "r", false, "Recursively copy directories") cpCmd.Flags().BoolVarP(&opts.expand, "expand", "e", false, "Expand remote file names on remote shell") - cpCmd.Flags().StringVarP(&opts.codespace, "codespace", "c", "", "Name of the codespace") + opts.selector = AddCodespaceSelector(cpCmd, app.apiClient) cpCmd.Flags().StringVarP(&opts.profile, "profile", "p", "", "Name of the SSH profile to use") return cpCmd } diff --git a/pkg/cmd/codespace/ssh_test.go b/pkg/cmd/codespace/ssh_test.go index 1740b00f7..b76d304d7 100644 --- a/pkg/cmd/codespace/ssh_test.go +++ b/pkg/cmd/codespace/ssh_test.go @@ -15,8 +15,9 @@ import ( func TestPendingOperationDisallowsSSH(t *testing.T) { app := testingSSHApp() + selector := &CodespaceSelector{api: app.apiClient, codespaceName: "disabledCodespace"} - if err := app.SSH(context.Background(), []string{}, sshOptions{codespace: "disabledCodespace"}); err != nil { + if err := app.SSH(context.Background(), []string{}, sshOptions{selector: selector}); err != nil { if err.Error() != "codespace is disabled while it has a pending operation: Some pending operation" { t.Errorf("expected pending operation error, but got: %v", err) } diff --git a/pkg/cmd/codespace/stop.go b/pkg/cmd/codespace/stop.go index bc60e9dc1..1deb27bfa 100644 --- a/pkg/cmd/codespace/stop.go +++ b/pkg/cmd/codespace/stop.go @@ -11,9 +11,9 @@ import ( ) type stopOptions struct { - codespaceName string - orgName string - userName string + selector *CodespaceSelector + orgName string + userName string } func newStopCmd(app *App) *cobra.Command { @@ -24,13 +24,13 @@ func newStopCmd(app *App) *cobra.Command { Short: "Stop a running codespace", Args: noArgsConstraint, RunE: func(cmd *cobra.Command, args []string) error { - if opts.orgName != "" && opts.codespaceName != "" && opts.userName == "" { + if opts.orgName != "" && opts.selector.codespaceName != "" && opts.userName == "" { return cmdutil.FlagErrorf("using `--org` with `--codespace` requires `--user`") } return app.StopCodespace(cmd.Context(), opts) }, } - stopCmd.Flags().StringVarP(&opts.codespaceName, "codespace", "c", "", "Name of the codespace") + opts.selector = AddCodespaceSelector(stopCmd, app.apiClient) stopCmd.Flags().StringVarP(&opts.orgName, "org", "o", "", "The `login` handle of the organization (admin-only)") stopCmd.Flags().StringVarP(&opts.userName, "user", "u", "", "The `username` to stop codespace for (used with --org)") @@ -38,12 +38,19 @@ func newStopCmd(app *App) *cobra.Command { } func (a *App) StopCodespace(ctx context.Context, opts *stopOptions) error { - codespaceName := opts.codespaceName - ownerName := opts.userName + var ( + codespaceName = opts.selector.codespaceName + repoName = opts.selector.repoName + ownerName = opts.userName + ) if codespaceName == "" { a.StartProgressIndicatorWithLabel("Fetching codespaces") - codespaces, err := a.apiClient.ListCodespaces(ctx, api.ListCodespacesOptions{OrgName: opts.orgName, UserName: ownerName}) + codespaces, err := a.apiClient.ListCodespaces(ctx, api.ListCodespacesOptions{ + RepoName: repoName, + OrgName: opts.orgName, + UserName: ownerName, + }) a.StopProgressIndicator() if err != nil { return fmt.Errorf("failed to list codespaces: %w", err) @@ -61,7 +68,8 @@ func (a *App) StopCodespace(ctx context.Context, opts *stopOptions) error { } includeOwner := opts.orgName != "" - codespace, err := chooseCodespaceFromList(ctx, runningCodespaces, includeOwner) + skipPromptForSingleOption := repoName != "" + codespace, err := chooseCodespaceFromList(ctx, runningCodespaces, includeOwner, skipPromptForSingleOption) if err != nil { return fmt.Errorf("failed to choose codespace: %w", err) } diff --git a/pkg/cmd/codespace/stop_test.go b/pkg/cmd/codespace/stop_test.go index bdd3721c0..78e07fcec 100644 --- a/pkg/cmd/codespace/stop_test.go +++ b/pkg/cmd/codespace/stop_test.go @@ -22,7 +22,7 @@ func TestApp_StopCodespace(t *testing.T) { { name: "Stop a codespace I own", opts: &stopOptions{ - codespaceName: "test-codespace", + selector: &CodespaceSelector{codespaceName: "test-codespace"}, }, fields: fields{ apiClient: &apiClientMock{ @@ -52,9 +52,9 @@ func TestApp_StopCodespace(t *testing.T) { { name: "Stop a codespace as an org admin", opts: &stopOptions{ - codespaceName: "test-codespace", - orgName: "test-org", - userName: "test-user", + selector: &CodespaceSelector{codespaceName: "test-codespace"}, + orgName: "test-org", + userName: "test-user", }, fields: fields{ apiClient: &apiClientMock{ diff --git a/pkg/cmd/config/list/list.go b/pkg/cmd/config/list/list.go index f9da66d3f..2faed3f15 100644 --- a/pkg/cmd/config/list/list.go +++ b/pkg/cmd/config/list/list.go @@ -51,7 +51,7 @@ func listRun(opts *ListOptions) error { if opts.Hostname != "" { host = opts.Hostname } else { - host, _ = cfg.DefaultHost() + host, _ = cfg.Authentication().DefaultHost() } configOptions := config.ConfigOptions() diff --git a/pkg/cmd/extension/browse/browse.go b/pkg/cmd/extension/browse/browse.go index 247d43c5b..47b3ea8b2 100644 --- a/pkg/cmd/extension/browse/browse.go +++ b/pkg/cmd/extension/browse/browse.go @@ -343,7 +343,7 @@ func getExtensions(opts ExtBrowseOpts) ([]extEntry, error) { return extEntries, fmt.Errorf("failed to search for extensions: %w", err) } - host, _ := opts.Cfg.DefaultHost() + host, _ := opts.Cfg.Authentication().DefaultHost() for _, repo := range result.Items { if !strings.HasPrefix(repo.Name, "gh-") { diff --git a/pkg/cmd/extension/browse/browse_test.go b/pkg/cmd/extension/browse/browse_test.go index aacfb37b5..5ad117356 100644 --- a/pkg/cmd/extension/browse/browse_test.go +++ b/pkg/cmd/extension/browse/browse_test.go @@ -76,7 +76,11 @@ func Test_getExtensionRepos(t *testing.T) { } cfg := config.NewBlankConfig() - cfg.DefaultHostFunc = func() (string, string) { return "github.com", "" } + cfg.AuthenticationFunc = func() *config.AuthConfig { + authCfg := &config.AuthConfig{} + authCfg.SetDefaultHost("github.com", "") + return authCfg + } reg.Register( httpmock.QueryMatcher("GET", "search/repositories", values), diff --git a/pkg/cmd/extension/command.go b/pkg/cmd/extension/command.go index 722163dd2..bfb87188e 100644 --- a/pkg/cmd/extension/command.go +++ b/pkg/cmd/extension/command.go @@ -137,7 +137,7 @@ func NewCmdExtension(f *cmdutil.Factory) *cobra.Command { query.Keywords = args query.Qualifiers = qualifiers - host, _ := cfg.DefaultHost() + host, _ := cfg.Authentication().DefaultHost() searcher := search.NewSearcher(client, host) if webMode { @@ -445,7 +445,7 @@ func NewCmdExtension(f *cmdutil.Factory) *cobra.Command { if err != nil { return err } - host, _ := cfg.DefaultHost() + host, _ := cfg.Authentication().DefaultHost() client, err := f.HttpClient() if err != nil { return err diff --git a/pkg/cmd/extension/manager.go b/pkg/cmd/extension/manager.go index 4f0e6a9fb..af4dbe9a4 100644 --- a/pkg/cmd/extension/manager.go +++ b/pkg/cmd/extension/manager.go @@ -708,7 +708,7 @@ func (m *Manager) goBinScaffolding(name string) error { return err } - host, _ := m.config.DefaultHost() + host, _ := m.config.Authentication().DefaultHost() currentUser, err := api.CurrentLoginName(api.NewClientFromHTTP(m.client), host) if err != nil { diff --git a/pkg/cmd/factory/default.go b/pkg/cmd/factory/default.go index c74815907..4ff7a77f7 100644 --- a/pkg/cmd/factory/default.go +++ b/pkg/cmd/factory/default.go @@ -96,7 +96,7 @@ func httpClientFunc(f *cmdutil.Factory, appVersion string) func() (*http.Client, return nil, err } opts := api.HTTPClientOptions{ - Config: cfg, + Config: cfg.Authentication(), Log: io.ErrOut, LogColorize: io.ColorEnabled(), AppVersion: appVersion, diff --git a/pkg/cmd/factory/default_test.go b/pkg/cmd/factory/default_test.go index f9b316bce..b58544b44 100644 --- a/pkg/cmd/factory/default_test.go +++ b/pkg/cmd/factory/default_test.go @@ -71,21 +71,19 @@ func Test_BaseRepo(t *testing.T) { }, getConfig: func() (config.Config, error) { cfg := &config.ConfigMock{} - cfg.HostsFunc = func() []string { + cfg.AuthenticationFunc = func() *config.AuthConfig { + authCfg := &config.AuthConfig{} hosts := []string{"nonsense.com"} if tt.override != "" { hosts = append([]string{tt.override}, hosts...) } - return hosts - } - cfg.DefaultHostFunc = func() (string, string) { + authCfg.SetHosts(hosts) + authCfg.SetToken("", "") + authCfg.SetDefaultHost("nonsense.com", "hosts") if tt.override != "" { - return tt.override, "GH_HOST" + authCfg.SetDefaultHost(tt.override, "GH_HOST") } - return "nonsense.com", "hosts" - } - cfg.AuthTokenFunc = func(string) (string, string) { - return "", "" + return authCfg } return cfg, nil }, @@ -211,21 +209,19 @@ func Test_SmartBaseRepo(t *testing.T) { }, getConfig: func() (config.Config, error) { cfg := &config.ConfigMock{} - cfg.AuthTokenFunc = func(_ string) (string, string) { - return "", "" - } - cfg.HostsFunc = func() []string { + cfg.AuthenticationFunc = func() *config.AuthConfig { + authCfg := &config.AuthConfig{} hosts := []string{"nonsense.com"} if tt.override != "" { hosts = append([]string{tt.override}, hosts...) } - return hosts - } - cfg.DefaultHostFunc = func() (string, string) { + authCfg.SetHosts(hosts) + authCfg.SetToken("", "") + authCfg.SetDefaultHost("nonsense.com", "hosts") if tt.override != "" { - return tt.override, "GH_HOST" + authCfg.SetDefaultHost(tt.override, "GH_HOST") } - return "nonsense.com", "hosts" + return authCfg } return cfg, nil }, diff --git a/pkg/cmd/factory/remote_resolver.go b/pkg/cmd/factory/remote_resolver.go index 86a0aa324..a672ba5e6 100644 --- a/pkg/cmd/factory/remote_resolver.go +++ b/pkg/cmd/factory/remote_resolver.go @@ -53,11 +53,11 @@ func (rr *remoteResolver) Resolver() func() (context.Remotes, error) { return nil, err } - authedHosts := cfg.Hosts() + authedHosts := cfg.Authentication().Hosts() if len(authedHosts) == 0 { return nil, errors.New("could not find any host configurations") } - defaultHost, src := cfg.DefaultHost() + defaultHost, src := cfg.Authentication().DefaultHost() // Use set to dedupe list of hosts hostsSet := set.NewStringSet() @@ -86,7 +86,7 @@ func (rr *remoteResolver) Resolver() func() (context.Remotes, error) { dummyHostname := "example.com" if isHostEnv(src) { return nil, fmt.Errorf("none of the git remotes configured for this repository correspond to the %s environment variable. Try adding a matching remote or unsetting the variable.", src) - } else if v, _ := cfg.AuthToken(dummyHostname); v != "" { + } else if v, _ := cfg.Authentication().Token(dummyHostname); v != "" { return nil, errors.New("set the GH_HOST environment variable to specify which GitHub host to use") } return nil, errors.New("none of the git remotes configured for this repository point to a known GitHub host. To tell gh about a new GitHub host, please use `gh auth login`") diff --git a/pkg/cmd/factory/remote_resolver_test.go b/pkg/cmd/factory/remote_resolver_test.go index bc20686a9..a38470595 100644 --- a/pkg/cmd/factory/remote_resolver_test.go +++ b/pkg/cmd/factory/remote_resolver_test.go @@ -32,11 +32,11 @@ func Test_remoteResolver(t *testing.T) { }, config: func() config.Config { cfg := &config.ConfigMock{} - cfg.HostsFunc = func() []string { - return []string{} - } - cfg.DefaultHostFunc = func() (string, string) { - return "github.com", "default" + cfg.AuthenticationFunc = func() *config.AuthConfig { + authCfg := &config.AuthConfig{} + authCfg.SetHosts([]string{}) + authCfg.SetDefaultHost("github.com", "default") + return authCfg } return cfg }(), @@ -49,11 +49,11 @@ func Test_remoteResolver(t *testing.T) { }, config: func() config.Config { cfg := &config.ConfigMock{} - cfg.HostsFunc = func() []string { - return []string{"example.com"} - } - cfg.DefaultHostFunc = func() (string, string) { - return "example.com", "hosts" + cfg.AuthenticationFunc = func() *config.AuthConfig { + authCfg := &config.AuthConfig{} + authCfg.SetHosts([]string{"example.com"}) + authCfg.SetDefaultHost("example.com", "hosts") + return authCfg } return cfg }(), @@ -68,14 +68,12 @@ func Test_remoteResolver(t *testing.T) { }, config: func() config.Config { cfg := &config.ConfigMock{} - cfg.HostsFunc = func() []string { - return []string{"example.com"} - } - cfg.DefaultHostFunc = func() (string, string) { - return "example.com", "hosts" - } - cfg.AuthTokenFunc = func(string) (string, string) { - return "", "" + cfg.AuthenticationFunc = func() *config.AuthConfig { + authCfg := &config.AuthConfig{} + authCfg.SetHosts([]string{"example.com"}) + authCfg.SetToken("", "") + authCfg.SetDefaultHost("example.com", "hosts") + return authCfg } return cfg }(), @@ -90,11 +88,11 @@ func Test_remoteResolver(t *testing.T) { }, config: func() config.Config { cfg := &config.ConfigMock{} - cfg.HostsFunc = func() []string { - return []string{"example.com"} - } - cfg.DefaultHostFunc = func() (string, string) { - return "example.com", "hosts" + cfg.AuthenticationFunc = func() *config.AuthConfig { + authCfg := &config.AuthConfig{} + authCfg.SetHosts([]string{"example.com"}) + authCfg.SetDefaultHost("example.com", "hosts") + return authCfg } return cfg }(), @@ -109,11 +107,11 @@ func Test_remoteResolver(t *testing.T) { }, config: func() config.Config { cfg := &config.ConfigMock{} - cfg.HostsFunc = func() []string { - return []string{"example.com"} - } - cfg.DefaultHostFunc = func() (string, string) { - return "example.com", "default" + cfg.AuthenticationFunc = func() *config.AuthConfig { + authCfg := &config.AuthConfig{} + authCfg.SetHosts([]string{"example.com"}) + authCfg.SetDefaultHost("example.com", "default") + return authCfg } return cfg }(), @@ -131,11 +129,11 @@ func Test_remoteResolver(t *testing.T) { }, config: func() config.Config { cfg := &config.ConfigMock{} - cfg.HostsFunc = func() []string { - return []string{"example.com"} - } - cfg.DefaultHostFunc = func() (string, string) { - return "example.com", "default" + cfg.AuthenticationFunc = func() *config.AuthConfig { + authCfg := &config.AuthConfig{} + authCfg.SetHosts([]string{"example.com"}) + authCfg.SetDefaultHost("example.com", "default") + return authCfg } return cfg }(), @@ -150,14 +148,12 @@ func Test_remoteResolver(t *testing.T) { }, config: func() config.Config { cfg := &config.ConfigMock{} - cfg.HostsFunc = func() []string { - return []string{"example.com", "github.com"} - } - cfg.DefaultHostFunc = func() (string, string) { - return "github.com", "default" - } - cfg.AuthTokenFunc = func(string) (string, string) { - return "", "" + cfg.AuthenticationFunc = func() *config.AuthConfig { + authCfg := &config.AuthConfig{} + authCfg.SetHosts([]string{"example.com", "github.com"}) + authCfg.SetToken("", "") + authCfg.SetDefaultHost("example.com", "default") + return authCfg } return cfg }(), @@ -173,11 +169,11 @@ func Test_remoteResolver(t *testing.T) { }, config: func() config.Config { cfg := &config.ConfigMock{} - cfg.HostsFunc = func() []string { - return []string{"example.com", "github.com"} - } - cfg.DefaultHostFunc = func() (string, string) { - return "github.com", "default" + cfg.AuthenticationFunc = func() *config.AuthConfig { + authCfg := &config.AuthConfig{} + authCfg.SetHosts([]string{"example.com", "github.com"}) + authCfg.SetDefaultHost("github.com", "default") + return authCfg } return cfg }(), @@ -196,11 +192,11 @@ func Test_remoteResolver(t *testing.T) { }, config: func() config.Config { cfg := &config.ConfigMock{} - cfg.HostsFunc = func() []string { - return []string{"example.com", "github.com"} - } - cfg.DefaultHostFunc = func() (string, string) { - return "github.com", "default" + cfg.AuthenticationFunc = func() *config.AuthConfig { + authCfg := &config.AuthConfig{} + authCfg.SetHosts([]string{"example.com", "github.com"}) + authCfg.SetDefaultHost("github.com", "default") + return authCfg } return cfg }(), @@ -215,11 +211,11 @@ func Test_remoteResolver(t *testing.T) { }, config: func() config.Config { cfg := &config.ConfigMock{} - cfg.HostsFunc = func() []string { - return []string{"example.com"} - } - cfg.DefaultHostFunc = func() (string, string) { - return "test.com", "GH_HOST" + cfg.AuthenticationFunc = func() *config.AuthConfig { + authCfg := &config.AuthConfig{} + authCfg.SetHosts([]string{"example.com"}) + authCfg.SetDefaultHost("test.com", "GH_HOST") + return authCfg } return cfg }(), @@ -235,11 +231,11 @@ func Test_remoteResolver(t *testing.T) { }, config: func() config.Config { cfg := &config.ConfigMock{} - cfg.HostsFunc = func() []string { - return []string{"example.com"} - } - cfg.DefaultHostFunc = func() (string, string) { - return "test.com", "GH_HOST" + cfg.AuthenticationFunc = func() *config.AuthConfig { + authCfg := &config.AuthConfig{} + authCfg.SetHosts([]string{"example.com"}) + authCfg.SetDefaultHost("test.com", "GH_HOST") + return authCfg } return cfg }(), @@ -256,11 +252,11 @@ func Test_remoteResolver(t *testing.T) { }, config: func() config.Config { cfg := &config.ConfigMock{} - cfg.HostsFunc = func() []string { - return []string{"example.com", "test.com"} - } - cfg.DefaultHostFunc = func() (string, string) { - return "test.com", "GH_HOST" + cfg.AuthenticationFunc = func() *config.AuthConfig { + authCfg := &config.AuthConfig{} + authCfg.SetHosts([]string{"example.com", "test.com"}) + authCfg.SetDefaultHost("test.com", "GH_HOST") + return authCfg } return cfg }(), diff --git a/pkg/cmd/gist/clone/clone.go b/pkg/cmd/gist/clone/clone.go index ed074e700..9b75a309c 100644 --- a/pkg/cmd/gist/clone/clone.go +++ b/pkg/cmd/gist/clone/clone.go @@ -79,7 +79,7 @@ func cloneRun(opts *CloneOptions) error { if err != nil { return err } - hostname, _ := cfg.DefaultHost() + hostname, _ := cfg.Authentication().DefaultHost() protocol, err := cfg.GetOrDefault(hostname, "git_protocol") if err != nil { return err diff --git a/pkg/cmd/gist/create/create.go b/pkg/cmd/gist/create/create.go index 8dfd19664..34476afef 100644 --- a/pkg/cmd/gist/create/create.go +++ b/pkg/cmd/gist/create/create.go @@ -143,7 +143,7 @@ func createRun(opts *CreateOptions) error { return err } - host, _ := cfg.DefaultHost() + host, _ := cfg.Authentication().DefaultHost() opts.IO.StartProgressIndicator() gist, err := createGist(httpClient, host, opts.Description, opts.Public, files) diff --git a/pkg/cmd/gist/delete/delete.go b/pkg/cmd/gist/delete/delete.go index b4e76a69a..22d27b65d 100644 --- a/pkg/cmd/gist/delete/delete.go +++ b/pkg/cmd/gist/delete/delete.go @@ -64,7 +64,7 @@ func deleteRun(opts *DeleteOptions) error { return err } - host, _ := cfg.DefaultHost() + host, _ := cfg.Authentication().DefaultHost() apiClient := api.NewClientFromHTTP(client) if err := deleteGist(apiClient, host, gistID); err != nil { diff --git a/pkg/cmd/gist/edit/edit.go b/pkg/cmd/gist/edit/edit.go index a96b3feb7..773ea41df 100644 --- a/pkg/cmd/gist/edit/edit.go +++ b/pkg/cmd/gist/edit/edit.go @@ -107,7 +107,7 @@ func editRun(opts *EditOptions) error { return err } - host, _ := cfg.DefaultHost() + host, _ := cfg.Authentication().DefaultHost() gist, err := shared.GetGist(client, host, gistID) if err != nil { diff --git a/pkg/cmd/gist/list/list.go b/pkg/cmd/gist/list/list.go index 733b0d8f5..c49f1fffd 100644 --- a/pkg/cmd/gist/list/list.go +++ b/pkg/cmd/gist/list/list.go @@ -76,7 +76,7 @@ func listRun(opts *ListOptions) error { return err } - host, _ := cfg.DefaultHost() + host, _ := cfg.Authentication().DefaultHost() gists, err := shared.ListGists(client, host, opts.Limit, opts.Visibility) if err != nil { diff --git a/pkg/cmd/gist/view/view.go b/pkg/cmd/gist/view/view.go index 489994f54..545681522 100644 --- a/pkg/cmd/gist/view/view.go +++ b/pkg/cmd/gist/view/view.go @@ -85,7 +85,7 @@ func viewRun(opts *ViewOptions) error { return err } - hostname, _ := cfg.DefaultHost() + hostname, _ := cfg.Authentication().DefaultHost() cs := opts.IO.ColorScheme() if gistID == "" { diff --git a/pkg/cmd/gpg-key/add/add.go b/pkg/cmd/gpg-key/add/add.go index 37ba07678..54482d029 100644 --- a/pkg/cmd/gpg-key/add/add.go +++ b/pkg/cmd/gpg-key/add/add.go @@ -79,7 +79,7 @@ func runAdd(opts *AddOptions) error { return err } - hostname, _ := cfg.DefaultHost() + hostname, _ := cfg.Authentication().DefaultHost() err = gpgKeyUpload(httpClient, hostname, keyReader, opts.Title) if err != nil { diff --git a/pkg/cmd/gpg-key/delete/delete.go b/pkg/cmd/gpg-key/delete/delete.go index b8569f72d..0f152297a 100644 --- a/pkg/cmd/gpg-key/delete/delete.go +++ b/pkg/cmd/gpg-key/delete/delete.go @@ -65,7 +65,7 @@ func deleteRun(opts *DeleteOptions) error { return err } - host, _ := cfg.DefaultHost() + host, _ := cfg.Authentication().DefaultHost() gpgKeys, err := getGPGKeys(httpClient, host) if err != nil { return err diff --git a/pkg/cmd/gpg-key/list/list.go b/pkg/cmd/gpg-key/list/list.go index 3a52c64fe..623cc63e6 100644 --- a/pkg/cmd/gpg-key/list/list.go +++ b/pkg/cmd/gpg-key/list/list.go @@ -54,7 +54,7 @@ func listRun(opts *ListOptions) error { return err } - host, _ := cfg.DefaultHost() + host, _ := cfg.Authentication().DefaultHost() gpgKeys, err := userKeys(apiClient, host, "") if err != nil { diff --git a/pkg/cmd/repo/archive/archive.go b/pkg/cmd/repo/archive/archive.go index 50b1e7908..fd2716111 100644 --- a/pkg/cmd/repo/archive/archive.go +++ b/pkg/cmd/repo/archive/archive.go @@ -87,7 +87,7 @@ func archiveRun(opts *ArchiveOptions) error { return err } - hostname, _ := cfg.DefaultHost() + hostname, _ := cfg.Authentication().DefaultHost() currentUser, err := api.CurrentLoginName(apiClient, hostname) if err != nil { diff --git a/pkg/cmd/repo/clone/clone.go b/pkg/cmd/repo/clone/clone.go index 4a5ce3e82..fc39013c0 100644 --- a/pkg/cmd/repo/clone/clone.go +++ b/pkg/cmd/repo/clone/clone.go @@ -114,7 +114,7 @@ func cloneRun(opts *CloneOptions) error { if repositoryIsFullName { fullName = opts.Repository } else { - host, _ := cfg.DefaultHost() + host, _ := cfg.Authentication().DefaultHost() currentUser, err := api.CurrentLoginName(apiClient, host) if err != nil { return err diff --git a/pkg/cmd/repo/create/create.go b/pkg/cmd/repo/create/create.go index 085857dc1..6d6eeef30 100644 --- a/pkg/cmd/repo/create/create.go +++ b/pkg/cmd/repo/create/create.go @@ -206,7 +206,7 @@ func NewCmdCreate(f *cmdutil.Factory, runF func(*CreateOptions) error) *cobra.Co if err != nil { return nil, cobra.ShellCompDirectiveError } - hostname, _ := cfg.DefaultHost() + hostname, _ := cfg.Authentication().DefaultHost() results, err := listGitIgnoreTemplates(httpClient, hostname) if err != nil { return nil, cobra.ShellCompDirectiveError @@ -223,7 +223,7 @@ func NewCmdCreate(f *cmdutil.Factory, runF func(*CreateOptions) error) *cobra.Co if err != nil { return nil, cobra.ShellCompDirectiveError } - hostname, _ := cfg.DefaultHost() + hostname, _ := cfg.Authentication().DefaultHost() licenses, err := listLicenseTemplates(httpClient, hostname) if err != nil { return nil, cobra.ShellCompDirectiveError @@ -271,7 +271,7 @@ func createFromScratch(opts *CreateOptions) error { return err } - host, _ := cfg.DefaultHost() + host, _ := cfg.Authentication().DefaultHost() if opts.Interactive { opts.Name, opts.Description, opts.Visibility, err = interactiveRepoInfo(httpClient, host, opts.Prompter, "") @@ -420,7 +420,7 @@ func createFromLocal(opts *CreateOptions) error { if err != nil { return err } - host, _ := cfg.DefaultHost() + host, _ := cfg.Authentication().DefaultHost() if opts.Interactive { var err error diff --git a/pkg/cmd/repo/fork/fork.go b/pkg/cmd/repo/fork/fork.go index 0721485a8..8334cf0e5 100644 --- a/pkg/cmd/repo/fork/fork.go +++ b/pkg/cmd/repo/fork/fork.go @@ -2,6 +2,7 @@ package fork import ( "context" + "errors" "fmt" "net/http" "net/url" @@ -9,6 +10,7 @@ import ( "time" "github.com/MakeNowJust/heredoc" + "github.com/cenkalti/backoff/v4" "github.com/cli/cli/v2/api" ghContext "github.com/cli/cli/v2/context" "github.com/cli/cli/v2/git" @@ -32,6 +34,7 @@ type ForkOptions struct { BaseRepo func() (ghrepo.Interface, error) Remotes func() (ghContext.Remotes, error) Since func(time.Time) time.Duration + BackOff backoff.BackOff GitArgs []string Repository string @@ -46,6 +49,10 @@ type ForkOptions struct { DefaultBranchOnly bool } +type errWithExitCode interface { + ExitCode() int +} + // TODO warn about useless flags (--remote, --remote-name) when running from outside a repository // TODO output over STDOUT not STDERR // TODO remote-name has no effect on its own; error that or change behavior @@ -202,7 +209,7 @@ func forkRun(opts *ForkOptions) error { cs.Bold(ghrepo.FullName(forkedRepo)), "already exists") } else { - fmt.Fprintf(stderr, "%s already exists", ghrepo.FullName(forkedRepo)) + fmt.Fprintf(stderr, "%s already exists\n", ghrepo.FullName(forkedRepo)) } } else { if connectedToTerminal { @@ -317,8 +324,25 @@ func forkRun(opts *ForkOptions) error { } } if cloneDesired { - forkedRepoURL := ghrepo.FormatRemoteURL(forkedRepo, protocol) - cloneDir, err := gitClient.Clone(ctx, forkedRepoURL, opts.GitArgs) + // Allow injecting alternative BackOff in tests. + if opts.BackOff == nil { + bo := backoff.NewConstantBackOff(3 * time.Second) + opts.BackOff = bo + } + + cloneDir, err := backoff.RetryWithData(func() (string, error) { + forkedRepoURL := ghrepo.FormatRemoteURL(forkedRepo, protocol) + dir, err := gitClient.Clone(ctx, forkedRepoURL, opts.GitArgs) + if err == nil { + return dir, err + } + var execError errWithExitCode + if errors.As(err, &execError) && execError.ExitCode() == 128 { + return "", err + } + return "", backoff.Permanent(err) + }, backoff.WithContext(backoff.WithMaxRetries(opts.BackOff, 3), ctx)) + if err != nil { return fmt.Errorf("failed to clone fork: %w", err) } diff --git a/pkg/cmd/repo/fork/fork_test.go b/pkg/cmd/repo/fork/fork_test.go index 8022cd2dc..5833acb5d 100644 --- a/pkg/cmd/repo/fork/fork_test.go +++ b/pkg/cmd/repo/fork/fork_test.go @@ -9,6 +9,7 @@ import ( "testing" "time" + "github.com/cenkalti/backoff/v4" "github.com/cli/cli/v2/context" "github.com/cli/cli/v2/git" "github.com/cli/cli/v2/internal/config" @@ -406,7 +407,7 @@ func TestRepoFork(t *testing.T) { }, }, httpStubs: forkPost, - wantErrOut: "someone/REPO already exists", + wantErrOut: "someone/REPO already exists\n", }, { name: "implicit nontty --remote", @@ -559,7 +560,7 @@ func TestRepoFork(t *testing.T) { }, }, httpStubs: forkPost, - wantErrOut: "someone/REPO already exists", + wantErrOut: "someone/REPO already exists\n", }, { name: "repo arg nontty clone arg already exists", @@ -576,7 +577,7 @@ func TestRepoFork(t *testing.T) { cs.Register(`git -C REPO remote add upstream https://github\.com/OWNER/REPO\.git`, 0, "") cs.Register(`git -C REPO fetch upstream`, 0, "") }, - wantErrOut: "someone/REPO already exists", + wantErrOut: "someone/REPO already exists\n", }, { name: "repo arg nontty clone arg", @@ -663,78 +664,111 @@ func TestRepoFork(t *testing.T) { }, wantErrOut: "āœ“ Created fork OWNER/REPO\nāœ“ Renamed fork to OWNER/NEW_REPO\n", }, + { + name: "retries clone up to four times if necessary", + opts: &ForkOptions{ + Repository: "OWNER/REPO", + Clone: true, + BackOff: &backoff.ZeroBackOff{}, + }, + httpStubs: forkPost, + execStubs: func(cs *run.CommandStubber) { + cs.Register(`git clone https://github.com/someone/REPO\.git`, 128, "") + cs.Register(`git clone https://github.com/someone/REPO\.git`, 128, "") + cs.Register(`git clone https://github.com/someone/REPO\.git`, 128, "") + cs.Register(`git clone https://github.com/someone/REPO\.git`, 0, "") + cs.Register(`git -C REPO remote add upstream https://github\.com/OWNER/REPO\.git`, 0, "") + cs.Register(`git -C REPO fetch upstream`, 0, "") + }, + }, + { + name: "does not retry clone if error occurs and exit code is not 128", + opts: &ForkOptions{ + Repository: "OWNER/REPO", + Clone: true, + BackOff: &backoff.ZeroBackOff{}, + }, + httpStubs: forkPost, + execStubs: func(cs *run.CommandStubber) { + cs.Register(`git clone https://github.com/someone/REPO\.git`, 128, "") + cs.Register(`git clone https://github.com/someone/REPO\.git`, 65, "") + }, + wantErr: true, + errMsg: `failed to clone fork: failed to run git: git -c credential.helper= -c credential.helper=!"[^"]+" auth git-credential clone https://github.com/someone/REPO\.git exited with status 65`, + }, } for _, tt := range tests { - ios, _, stdout, stderr := iostreams.Test() - ios.SetStdinTTY(tt.tty) - ios.SetStdoutTTY(tt.tty) - ios.SetStderrTTY(tt.tty) - tt.opts.IO = ios - - tt.opts.BaseRepo = func() (ghrepo.Interface, error) { - return ghrepo.New("OWNER", "REPO"), nil - } - - reg := &httpmock.Registry{} - if tt.httpStubs != nil { - tt.httpStubs(reg) - } - tt.opts.HttpClient = func() (*http.Client, error) { - return &http.Client{Transport: reg}, nil - } - - cfg := config.NewBlankConfig() - if tt.cfgStubs != nil { - tt.cfgStubs(cfg) - } - tt.opts.Config = func() (config.Config, error) { - return cfg, nil - } - - tt.opts.Remotes = func() (context.Remotes, error) { - if tt.remotes == nil { - return []*context.Remote{ - { - Remote: &git.Remote{ - Name: "origin", - FetchURL: &url.URL{}, - }, - Repo: ghrepo.New("OWNER", "REPO"), - }, - }, nil - } - return tt.remotes, nil - } - - tt.opts.GitClient = &git.Client{ - GhPath: "some/path/gh", - GitPath: "some/path/git", - } - - //nolint:staticcheck // SA1019: prompt.InitAskStubber is deprecated: use NewAskStubber - as, teardown := prompt.InitAskStubber() - defer teardown() - if tt.askStubs != nil { - tt.askStubs(as) - } - cs, restoreRun := run.Stub() - defer restoreRun(t) - if tt.execStubs != nil { - tt.execStubs(cs) - } - t.Run(tt.name, func(t *testing.T) { + ios, _, stdout, stderr := iostreams.Test() + ios.SetStdinTTY(tt.tty) + ios.SetStdoutTTY(tt.tty) + ios.SetStderrTTY(tt.tty) + tt.opts.IO = ios + + tt.opts.BaseRepo = func() (ghrepo.Interface, error) { + return ghrepo.New("OWNER", "REPO"), nil + } + + reg := &httpmock.Registry{} + if tt.httpStubs != nil { + tt.httpStubs(reg) + } + tt.opts.HttpClient = func() (*http.Client, error) { + return &http.Client{Transport: reg}, nil + } + + cfg := config.NewBlankConfig() + if tt.cfgStubs != nil { + tt.cfgStubs(cfg) + } + tt.opts.Config = func() (config.Config, error) { + return cfg, nil + } + + tt.opts.Remotes = func() (context.Remotes, error) { + if tt.remotes == nil { + return []*context.Remote{ + { + Remote: &git.Remote{ + Name: "origin", + FetchURL: &url.URL{}, + }, + Repo: ghrepo.New("OWNER", "REPO"), + }, + }, nil + } + return tt.remotes, nil + } + + tt.opts.GitClient = &git.Client{ + GhPath: "some/path/gh", + GitPath: "some/path/git", + } + + //nolint:staticcheck // SA1019: prompt.InitAskStubber is deprecated: use NewAskStubber + as, teardown := prompt.InitAskStubber() + defer teardown() + if tt.askStubs != nil { + tt.askStubs(as) + } + + cs, restoreRun := run.Stub() + defer restoreRun(t) + if tt.execStubs != nil { + tt.execStubs(cs) + } + if tt.opts.Since == nil { tt.opts.Since = func(t time.Time) time.Duration { return 2 * time.Second } } + defer reg.Verify(t) err := forkRun(tt.opts) if tt.wantErr { - assert.Error(t, err) - assert.Equal(t, tt.errMsg, err.Error()) + assert.Error(t, err, tt.errMsg) return } diff --git a/pkg/cmd/repo/garden/garden.go b/pkg/cmd/repo/garden/garden.go index f8e0722f3..c9c59d420 100644 --- a/pkg/cmd/repo/garden/garden.go +++ b/pkg/cmd/repo/garden/garden.go @@ -155,7 +155,7 @@ func gardenRun(opts *GardenOptions) error { if err != nil { return err } - hostname, _ := cfg.DefaultHost() + hostname, _ := cfg.Authentication().DefaultHost() currentUser, err := api.CurrentLoginName(apiClient, hostname) if err != nil { return err diff --git a/pkg/cmd/repo/list/list.go b/pkg/cmd/repo/list/list.go index d6ebd1b0a..f9c0ca8ac 100644 --- a/pkg/cmd/repo/list/list.go +++ b/pkg/cmd/repo/list/list.go @@ -119,7 +119,7 @@ func listRun(opts *ListOptions) error { return err } - host, _ := cfg.DefaultHost() + host, _ := cfg.Authentication().DefaultHost() if opts.Detector == nil { cachedClient := api.NewCachedHTTPClient(httpClient, time.Hour*24) diff --git a/pkg/cmd/repo/view/view.go b/pkg/cmd/repo/view/view.go index 5074d02ee..722301397 100644 --- a/pkg/cmd/repo/view/view.go +++ b/pkg/cmd/repo/view/view.go @@ -94,7 +94,7 @@ func viewRun(opts *ViewOptions) error { if err != nil { return err } - hostname, _ := cfg.DefaultHost() + hostname, _ := cfg.Authentication().DefaultHost() currentUser, err := api.CurrentLoginName(apiClient, hostname) if err != nil { return err diff --git a/pkg/cmd/root/root.go b/pkg/cmd/root/root.go index ce0a09cec..718561df1 100644 --- a/pkg/cmd/root/root.go +++ b/pkg/cmd/root/root.go @@ -144,7 +144,7 @@ func bareHTTPClient(f *cmdutil.Factory, version string) func() (*http.Client, er } opts := api.HTTPClientOptions{ AppVersion: version, - Config: cfg, + Config: cfg.Authentication(), Log: f.IOStreams.ErrOut, LogColorize: f.IOStreams.ColorEnabled(), SkipAcceptHeaders: true, diff --git a/pkg/cmd/search/shared/shared.go b/pkg/cmd/search/shared/shared.go index e751848d4..141ab5421 100644 --- a/pkg/cmd/search/shared/shared.go +++ b/pkg/cmd/search/shared/shared.go @@ -42,7 +42,7 @@ func Searcher(f *cmdutil.Factory) (search.Searcher, error) { if err != nil { return nil, err } - host, _ := cfg.DefaultHost() + host, _ := cfg.Authentication().DefaultHost() client, err := f.HttpClient() if err != nil { return nil, err diff --git a/pkg/cmd/secret/delete/delete.go b/pkg/cmd/secret/delete/delete.go index fc75a490d..ce5275d87 100644 --- a/pkg/cmd/secret/delete/delete.go +++ b/pkg/cmd/secret/delete/delete.go @@ -122,7 +122,7 @@ func removeRun(opts *DeleteOptions) error { return err } - host, _ := cfg.DefaultHost() + host, _ := cfg.Authentication().DefaultHost() err = client.REST(host, "DELETE", path, nil, nil) if err != nil { diff --git a/pkg/cmd/secret/list/list.go b/pkg/cmd/secret/list/list.go index 4844e8cc7..faee8ae3e 100644 --- a/pkg/cmd/secret/list/list.go +++ b/pkg/cmd/secret/list/list.go @@ -123,7 +123,7 @@ func listRun(opts *ListOptions) error { return err } - host, _ = cfg.DefaultHost() + host, _ = cfg.Authentication().DefaultHost() if secretEntity == shared.User { secrets, err = getUserSecrets(client, host, showSelectedRepoInfo) diff --git a/pkg/cmd/secret/set/set.go b/pkg/cmd/secret/set/set.go index 793d14cef..f540f112c 100644 --- a/pkg/cmd/secret/set/set.go +++ b/pkg/cmd/secret/set/set.go @@ -186,7 +186,7 @@ func setRun(opts *SetOptions) error { if err != nil { return err } - host, _ = cfg.DefaultHost() + host, _ = cfg.Authentication().DefaultHost() } secretEntity, err := shared.GetSecretEntity(orgName, envName, opts.UserSecrets) diff --git a/pkg/cmd/ssh-key/add/add.go b/pkg/cmd/ssh-key/add/add.go index 5b0cb7aa5..e3f5088d2 100644 --- a/pkg/cmd/ssh-key/add/add.go +++ b/pkg/cmd/ssh-key/add/add.go @@ -77,7 +77,7 @@ func runAdd(opts *AddOptions) error { return err } - hostname, _ := cfg.DefaultHost() + hostname, _ := cfg.Authentication().DefaultHost() err = SSHKeyUpload(httpClient, hostname, keyReader, opts.Title) if err != nil { diff --git a/pkg/cmd/ssh-key/delete/delete.go b/pkg/cmd/ssh-key/delete/delete.go index de53f1391..b6ed9f014 100644 --- a/pkg/cmd/ssh-key/delete/delete.go +++ b/pkg/cmd/ssh-key/delete/delete.go @@ -66,7 +66,7 @@ func deleteRun(opts *DeleteOptions) error { return err } - host, _ := cfg.DefaultHost() + host, _ := cfg.Authentication().DefaultHost() key, err := getSSHKey(httpClient, host, opts.KeyID) if err != nil { return err diff --git a/pkg/cmd/ssh-key/list/list.go b/pkg/cmd/ssh-key/list/list.go index 376404d42..d057b198c 100644 --- a/pkg/cmd/ssh-key/list/list.go +++ b/pkg/cmd/ssh-key/list/list.go @@ -53,7 +53,7 @@ func listRun(opts *ListOptions) error { return err } - host, _ := cfg.DefaultHost() + host, _ := cfg.Authentication().DefaultHost() sshKeys, err := userKeys(apiClient, host, "") if err != nil { diff --git a/pkg/cmd/status/status.go b/pkg/cmd/status/status.go index c73db711a..438175d72 100644 --- a/pkg/cmd/status/status.go +++ b/pkg/cmd/status/status.go @@ -68,7 +68,7 @@ func NewCmdStatus(f *cmdutil.Factory, runF func(*StatusOptions) error) *cobra.Co return err } - opts.HostConfig = cfg + opts.HostConfig = cfg.Authentication() if runF != nil { return runF(opts) diff --git a/pkg/cmdutil/auth_check.go b/pkg/cmdutil/auth_check.go index 03f76de6f..ea18d1502 100644 --- a/pkg/cmdutil/auth_check.go +++ b/pkg/cmdutil/auth_check.go @@ -18,12 +18,12 @@ func CheckAuth(cfg config.Config) bool { // authentication tokens set for enterprise hosts. // Any non-github.com hostname is fine here dummyHostname := "example.com" - token, _ := cfg.AuthToken(dummyHostname) + token, _ := cfg.Authentication().Token(dummyHostname) if token != "" { return true } - if len(cfg.Hosts()) > 0 { + if len(cfg.Authentication().Hosts()) > 0 { return true } diff --git a/pkg/cmdutil/auth_check_test.go b/pkg/cmdutil/auth_check_test.go index 797824f1c..6bef0d413 100644 --- a/pkg/cmdutil/auth_check_test.go +++ b/pkg/cmdutil/auth_check_test.go @@ -21,8 +21,10 @@ func Test_CheckAuth(t *testing.T) { { name: "no known hosts, env auth token", cfgStubs: func(c *config.ConfigMock) { - c.AuthTokenFunc = func(string) (string, string) { - return "token", "GITHUB_TOKEN" + c.AuthenticationFunc = func() *config.AuthConfig { + authCfg := &config.AuthConfig{} + authCfg.SetToken("token", "GITHUB_TOKEN") + return authCfg } }, expected: true, diff --git a/pkg/cmdutil/repo_override.go b/pkg/cmdutil/repo_override.go index 29043d54d..791dd919a 100644 --- a/pkg/cmdutil/repo_override.go +++ b/pkg/cmdutil/repo_override.go @@ -31,7 +31,7 @@ func EnableRepoOverride(cmd *cobra.Command, f *Factory) { if err != nil { return nil, cobra.ShellCompDirectiveError } - defaultHost, _ := config.DefaultHost() + defaultHost, _ := config.Authentication().DefaultHost() var results []string for _, remote := range remotes { diff --git a/pkg/liveshare/test/server.go b/pkg/liveshare/test/server.go index 5dd4e56aa..0b4f6a7ba 100644 --- a/pkg/liveshare/test/server.go +++ b/pkg/liveshare/test/server.go @@ -242,16 +242,17 @@ func handleRequests(ctx context.Context, server *Server, channel ssh.Channel, re errc := make(chan error, 1) go func() { for req := range reqs { - if req.WantReply { - if err := req.Reply(true, nil); err != nil { + r := req + if r.WantReply { + if err := r.Reply(true, nil); err != nil { sendError(errc, fmt.Errorf("error replying to channel request: %w", err)) return } } - if strings.HasPrefix(req.Type, "stream-transport") { + if strings.HasPrefix(r.Type, "stream-transport") { go func() { - if err := forwardStream(ctx, server, req.Type, channel); err != nil { + if err := forwardStream(ctx, server, r.Type, channel); err != nil { sendError(errc, fmt.Errorf("failed to forward stream: %w", err)) } }()