diff --git a/api/client.go b/api/client.go index 3fcf0d930..bf6827d34 100644 --- a/api/client.go +++ b/api/client.go @@ -98,6 +98,22 @@ func ReplaceTripper(tr http.RoundTripper) ClientOption { } } +// ExtractHeader extracts a named header from any response received by this client and, if non-blank, saves +// it to dest. +func ExtractHeader(name string, dest *string) ClientOption { + return func(tr http.RoundTripper) http.RoundTripper { + return &funcTripper{roundTrip: func(req *http.Request) (*http.Response, error) { + res, err := tr.RoundTrip(req) + if err == nil { + if value := res.Header.Get(name); value != "" { + *dest = value + } + } + return res, err + }} + } +} + type funcTripper struct { roundTrip func(*http.Request) (*http.Response, error) } diff --git a/cmd/gh/main.go b/cmd/gh/main.go index 50f8335a3..10cd94c2e 100644 --- a/cmd/gh/main.go +++ b/cmd/gh/main.go @@ -224,8 +224,9 @@ func mainRun() exitCode { var httpErr api.HTTPError if errors.As(err, &httpErr) && httpErr.StatusCode == 401 { fmt.Fprintln(stderr, "Try authenticating with: gh auth login") - } else if strings.Contains(err.Error(), "Resource protected by organization SAML enforcement") { - fmt.Fprintln(stderr, "Try re-authenticating with: gh auth refresh") + } else if u := factory.SSOURL(); u != "" { + // handles organization SAML enforcement error + fmt.Fprintf(stderr, "Authorize in your web browser: %s\n", u) } else if msg := httpErr.ScopesSuggestion(); msg != "" { fmt.Fprintln(stderr, msg) } diff --git a/pkg/cmd/api/api.go b/pkg/cmd/api/api.go index 473139faf..08c149360 100644 --- a/pkg/cmd/api/api.go +++ b/pkg/cmd/api/api.go @@ -21,6 +21,7 @@ import ( "github.com/cli/cli/v2/internal/config" "github.com/cli/cli/v2/internal/ghinstance" "github.com/cli/cli/v2/internal/ghrepo" + "github.com/cli/cli/v2/pkg/cmd/factory" "github.com/cli/cli/v2/pkg/cmdutil" "github.com/cli/cli/v2/pkg/export" "github.com/cli/cli/v2/pkg/iostreams" @@ -393,6 +394,9 @@ func processResponse(resp *http.Response, opts *ApiOptions, headersOutputStream if msg := api.ScopesSuggestion(resp); msg != "" { fmt.Fprintf(opts.IO.ErrOut, "gh: %s\n", msg) } + if u := factory.SSOURL(); u != "" { + fmt.Fprintf(opts.IO.ErrOut, "Authorize in your web browser: %s\n", u) + } err = cmdutil.SilentError return } diff --git a/pkg/cmd/factory/http.go b/pkg/cmd/factory/http.go index d1b8b54ed..7037b1558 100644 --- a/pkg/cmd/factory/http.go +++ b/pkg/cmd/factory/http.go @@ -4,6 +4,7 @@ import ( "fmt" "net/http" "os" + "regexp" "strings" "time" @@ -54,7 +55,6 @@ var timezoneNames = map[int]string{ } type configGetter interface { - GetOrDefault(string, string) (string, error) Get(string, string) (string, error) } @@ -108,6 +108,7 @@ func NewHTTPClient(io *iostreams.IOStreams, cfg configGetter, appVersion string, } return "", nil }), + api.ExtractHeader("X-GitHub-SSO", &ssoHeader), ) if setAccept { @@ -127,6 +128,22 @@ func NewHTTPClient(io *iostreams.IOStreams, cfg configGetter, appVersion string, return api.NewHTTPClient(opts...), nil } +var ssoHeader string +var ssoURLRE = regexp.MustCompile(`\burl=([^;]+)`) + +// SSOURL returns the URL of a SAML SSO challenge received by the server for clients that use ExtractHeader +// to extract the value of the "X-GitHub-SSO" response header. +func SSOURL() string { + if ssoHeader == "" { + return "" + } + m := ssoURLRE.FindStringSubmatch(ssoHeader) + if m == nil { + return "" + } + return m[1] +} + func getHost(r *http.Request) string { if r.Host != "" { return r.Host diff --git a/pkg/cmd/factory/http_test.go b/pkg/cmd/factory/http_test.go index 0039289a7..0cb5ac15c 100644 --- a/pkg/cmd/factory/http_test.go +++ b/pkg/cmd/factory/http_test.go @@ -25,8 +25,10 @@ func TestNewHTTPClient(t *testing.T) { args args envDebug string host string + sso string wantHeader map[string]string wantStderr string + wantSSO string }{ { name: "github.com with Accept header", @@ -117,11 +119,25 @@ func TestNewHTTPClient(t *testing.T) { }, wantStderr: "", }, + { + name: "SSO challenge in response header", + args: args{ + config: tinyConfig{}, + appVersion: "v1.2.3", + }, + host: "github.com", + sso: "required; url=https://github.com/login/sso?return_to=xyz¶m=123abc; another", + wantStderr: "", + wantSSO: "https://github.com/login/sso?return_to=xyz¶m=123abc", + }, } var gotReq *http.Request ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { gotReq = r + if sso := r.URL.Query().Get("sso"); sso != "" { + w.Header().Set("X-GitHub-SSO", sso) + } w.WriteHeader(http.StatusNoContent) })) defer ts.Close() @@ -139,6 +155,11 @@ func TestNewHTTPClient(t *testing.T) { require.NoError(t, err) req, err := http.NewRequest("GET", ts.URL, nil) + if tt.sso != "" { + q := req.URL.Query() + q.Set("sso", tt.sso) + req.URL.RawQuery = q.Encode() + } req.Host = tt.host require.NoError(t, err) @@ -151,16 +172,13 @@ func TestNewHTTPClient(t *testing.T) { assert.Equal(t, 204, res.StatusCode) assert.Equal(t, tt.wantStderr, normalizeVerboseLog(stderr.String())) + assert.Equal(t, tt.wantSSO, SSOURL()) }) } } type tinyConfig map[string]string -func (c tinyConfig) GetOrDefault(host, key string) (string, error) { - return c[fmt.Sprintf("%s:%s", host, key)], nil -} - func (c tinyConfig) Get(host, key string) (string, error) { return c[fmt.Sprintf("%s:%s", host, key)], nil }