Handle SAML enforcement challenge from the server (#5054)

Whenever a SSO challenge gets issued by the server by means of a URL in
the `X-GitHub-SSO` response header, gh now additionally prints that URL
on the standard error stream.

This is achieved by installing a middleware to all HTTP requests and
storing the server challenge to a value accessible by `factory.SSOURL()`.
Such approach was made necessary mainly because of the
`shurcool-graphql` client which doesn't give access to response headers
when a GraphQL error case is encountered.
This commit is contained in:
Mislav Marohnić 2022-01-19 14:22:22 +01:00 committed by GitHub
parent f950637b0b
commit 66c18b40f2
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 63 additions and 7 deletions

View file

@ -98,6 +98,22 @@ func ReplaceTripper(tr http.RoundTripper) ClientOption {
}
}
// ExtractHeader extracts a named header from any response received by this client and, if non-blank, saves
// it to dest.
func ExtractHeader(name string, dest *string) ClientOption {
return func(tr http.RoundTripper) http.RoundTripper {
return &funcTripper{roundTrip: func(req *http.Request) (*http.Response, error) {
res, err := tr.RoundTrip(req)
if err == nil {
if value := res.Header.Get(name); value != "" {
*dest = value
}
}
return res, err
}}
}
}
type funcTripper struct {
roundTrip func(*http.Request) (*http.Response, error)
}

View file

@ -224,8 +224,9 @@ func mainRun() exitCode {
var httpErr api.HTTPError
if errors.As(err, &httpErr) && httpErr.StatusCode == 401 {
fmt.Fprintln(stderr, "Try authenticating with: gh auth login")
} else if strings.Contains(err.Error(), "Resource protected by organization SAML enforcement") {
fmt.Fprintln(stderr, "Try re-authenticating with: gh auth refresh")
} else if u := factory.SSOURL(); u != "" {
// handles organization SAML enforcement error
fmt.Fprintf(stderr, "Authorize in your web browser: %s\n", u)
} else if msg := httpErr.ScopesSuggestion(); msg != "" {
fmt.Fprintln(stderr, msg)
}

View file

@ -21,6 +21,7 @@ import (
"github.com/cli/cli/v2/internal/config"
"github.com/cli/cli/v2/internal/ghinstance"
"github.com/cli/cli/v2/internal/ghrepo"
"github.com/cli/cli/v2/pkg/cmd/factory"
"github.com/cli/cli/v2/pkg/cmdutil"
"github.com/cli/cli/v2/pkg/export"
"github.com/cli/cli/v2/pkg/iostreams"
@ -393,6 +394,9 @@ func processResponse(resp *http.Response, opts *ApiOptions, headersOutputStream
if msg := api.ScopesSuggestion(resp); msg != "" {
fmt.Fprintf(opts.IO.ErrOut, "gh: %s\n", msg)
}
if u := factory.SSOURL(); u != "" {
fmt.Fprintf(opts.IO.ErrOut, "Authorize in your web browser: %s\n", u)
}
err = cmdutil.SilentError
return
}

View file

@ -4,6 +4,7 @@ import (
"fmt"
"net/http"
"os"
"regexp"
"strings"
"time"
@ -54,7 +55,6 @@ var timezoneNames = map[int]string{
}
type configGetter interface {
GetOrDefault(string, string) (string, error)
Get(string, string) (string, error)
}
@ -108,6 +108,7 @@ func NewHTTPClient(io *iostreams.IOStreams, cfg configGetter, appVersion string,
}
return "", nil
}),
api.ExtractHeader("X-GitHub-SSO", &ssoHeader),
)
if setAccept {
@ -127,6 +128,22 @@ func NewHTTPClient(io *iostreams.IOStreams, cfg configGetter, appVersion string,
return api.NewHTTPClient(opts...), nil
}
var ssoHeader string
var ssoURLRE = regexp.MustCompile(`\burl=([^;]+)`)
// SSOURL returns the URL of a SAML SSO challenge received by the server for clients that use ExtractHeader
// to extract the value of the "X-GitHub-SSO" response header.
func SSOURL() string {
if ssoHeader == "" {
return ""
}
m := ssoURLRE.FindStringSubmatch(ssoHeader)
if m == nil {
return ""
}
return m[1]
}
func getHost(r *http.Request) string {
if r.Host != "" {
return r.Host

View file

@ -25,8 +25,10 @@ func TestNewHTTPClient(t *testing.T) {
args args
envDebug string
host string
sso string
wantHeader map[string]string
wantStderr string
wantSSO string
}{
{
name: "github.com with Accept header",
@ -117,11 +119,25 @@ func TestNewHTTPClient(t *testing.T) {
},
wantStderr: "",
},
{
name: "SSO challenge in response header",
args: args{
config: tinyConfig{},
appVersion: "v1.2.3",
},
host: "github.com",
sso: "required; url=https://github.com/login/sso?return_to=xyz&param=123abc; another",
wantStderr: "",
wantSSO: "https://github.com/login/sso?return_to=xyz&param=123abc",
},
}
var gotReq *http.Request
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
gotReq = r
if sso := r.URL.Query().Get("sso"); sso != "" {
w.Header().Set("X-GitHub-SSO", sso)
}
w.WriteHeader(http.StatusNoContent)
}))
defer ts.Close()
@ -139,6 +155,11 @@ func TestNewHTTPClient(t *testing.T) {
require.NoError(t, err)
req, err := http.NewRequest("GET", ts.URL, nil)
if tt.sso != "" {
q := req.URL.Query()
q.Set("sso", tt.sso)
req.URL.RawQuery = q.Encode()
}
req.Host = tt.host
require.NoError(t, err)
@ -151,16 +172,13 @@ func TestNewHTTPClient(t *testing.T) {
assert.Equal(t, 204, res.StatusCode)
assert.Equal(t, tt.wantStderr, normalizeVerboseLog(stderr.String()))
assert.Equal(t, tt.wantSSO, SSOURL())
})
}
}
type tinyConfig map[string]string
func (c tinyConfig) GetOrDefault(host, key string) (string, error) {
return c[fmt.Sprintf("%s:%s", host, key)], nil
}
func (c tinyConfig) Get(host, key string) (string, error) {
return c[fmt.Sprintf("%s:%s", host, key)], nil
}