Handle SAML enforcement challenge from the server (#5054)
Whenever a SSO challenge gets issued by the server by means of a URL in the `X-GitHub-SSO` response header, gh now additionally prints that URL on the standard error stream. This is achieved by installing a middleware to all HTTP requests and storing the server challenge to a value accessible by `factory.SSOURL()`. Such approach was made necessary mainly because of the `shurcool-graphql` client which doesn't give access to response headers when a GraphQL error case is encountered.
This commit is contained in:
parent
f950637b0b
commit
66c18b40f2
5 changed files with 63 additions and 7 deletions
|
|
@ -98,6 +98,22 @@ func ReplaceTripper(tr http.RoundTripper) ClientOption {
|
|||
}
|
||||
}
|
||||
|
||||
// ExtractHeader extracts a named header from any response received by this client and, if non-blank, saves
|
||||
// it to dest.
|
||||
func ExtractHeader(name string, dest *string) ClientOption {
|
||||
return func(tr http.RoundTripper) http.RoundTripper {
|
||||
return &funcTripper{roundTrip: func(req *http.Request) (*http.Response, error) {
|
||||
res, err := tr.RoundTrip(req)
|
||||
if err == nil {
|
||||
if value := res.Header.Get(name); value != "" {
|
||||
*dest = value
|
||||
}
|
||||
}
|
||||
return res, err
|
||||
}}
|
||||
}
|
||||
}
|
||||
|
||||
type funcTripper struct {
|
||||
roundTrip func(*http.Request) (*http.Response, error)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -224,8 +224,9 @@ func mainRun() exitCode {
|
|||
var httpErr api.HTTPError
|
||||
if errors.As(err, &httpErr) && httpErr.StatusCode == 401 {
|
||||
fmt.Fprintln(stderr, "Try authenticating with: gh auth login")
|
||||
} else if strings.Contains(err.Error(), "Resource protected by organization SAML enforcement") {
|
||||
fmt.Fprintln(stderr, "Try re-authenticating with: gh auth refresh")
|
||||
} else if u := factory.SSOURL(); u != "" {
|
||||
// handles organization SAML enforcement error
|
||||
fmt.Fprintf(stderr, "Authorize in your web browser: %s\n", u)
|
||||
} else if msg := httpErr.ScopesSuggestion(); msg != "" {
|
||||
fmt.Fprintln(stderr, msg)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -21,6 +21,7 @@ import (
|
|||
"github.com/cli/cli/v2/internal/config"
|
||||
"github.com/cli/cli/v2/internal/ghinstance"
|
||||
"github.com/cli/cli/v2/internal/ghrepo"
|
||||
"github.com/cli/cli/v2/pkg/cmd/factory"
|
||||
"github.com/cli/cli/v2/pkg/cmdutil"
|
||||
"github.com/cli/cli/v2/pkg/export"
|
||||
"github.com/cli/cli/v2/pkg/iostreams"
|
||||
|
|
@ -393,6 +394,9 @@ func processResponse(resp *http.Response, opts *ApiOptions, headersOutputStream
|
|||
if msg := api.ScopesSuggestion(resp); msg != "" {
|
||||
fmt.Fprintf(opts.IO.ErrOut, "gh: %s\n", msg)
|
||||
}
|
||||
if u := factory.SSOURL(); u != "" {
|
||||
fmt.Fprintf(opts.IO.ErrOut, "Authorize in your web browser: %s\n", u)
|
||||
}
|
||||
err = cmdutil.SilentError
|
||||
return
|
||||
}
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@ import (
|
|||
"fmt"
|
||||
"net/http"
|
||||
"os"
|
||||
"regexp"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
|
|
@ -54,7 +55,6 @@ var timezoneNames = map[int]string{
|
|||
}
|
||||
|
||||
type configGetter interface {
|
||||
GetOrDefault(string, string) (string, error)
|
||||
Get(string, string) (string, error)
|
||||
}
|
||||
|
||||
|
|
@ -108,6 +108,7 @@ func NewHTTPClient(io *iostreams.IOStreams, cfg configGetter, appVersion string,
|
|||
}
|
||||
return "", nil
|
||||
}),
|
||||
api.ExtractHeader("X-GitHub-SSO", &ssoHeader),
|
||||
)
|
||||
|
||||
if setAccept {
|
||||
|
|
@ -127,6 +128,22 @@ func NewHTTPClient(io *iostreams.IOStreams, cfg configGetter, appVersion string,
|
|||
return api.NewHTTPClient(opts...), nil
|
||||
}
|
||||
|
||||
var ssoHeader string
|
||||
var ssoURLRE = regexp.MustCompile(`\burl=([^;]+)`)
|
||||
|
||||
// SSOURL returns the URL of a SAML SSO challenge received by the server for clients that use ExtractHeader
|
||||
// to extract the value of the "X-GitHub-SSO" response header.
|
||||
func SSOURL() string {
|
||||
if ssoHeader == "" {
|
||||
return ""
|
||||
}
|
||||
m := ssoURLRE.FindStringSubmatch(ssoHeader)
|
||||
if m == nil {
|
||||
return ""
|
||||
}
|
||||
return m[1]
|
||||
}
|
||||
|
||||
func getHost(r *http.Request) string {
|
||||
if r.Host != "" {
|
||||
return r.Host
|
||||
|
|
|
|||
|
|
@ -25,8 +25,10 @@ func TestNewHTTPClient(t *testing.T) {
|
|||
args args
|
||||
envDebug string
|
||||
host string
|
||||
sso string
|
||||
wantHeader map[string]string
|
||||
wantStderr string
|
||||
wantSSO string
|
||||
}{
|
||||
{
|
||||
name: "github.com with Accept header",
|
||||
|
|
@ -117,11 +119,25 @@ func TestNewHTTPClient(t *testing.T) {
|
|||
},
|
||||
wantStderr: "",
|
||||
},
|
||||
{
|
||||
name: "SSO challenge in response header",
|
||||
args: args{
|
||||
config: tinyConfig{},
|
||||
appVersion: "v1.2.3",
|
||||
},
|
||||
host: "github.com",
|
||||
sso: "required; url=https://github.com/login/sso?return_to=xyz¶m=123abc; another",
|
||||
wantStderr: "",
|
||||
wantSSO: "https://github.com/login/sso?return_to=xyz¶m=123abc",
|
||||
},
|
||||
}
|
||||
|
||||
var gotReq *http.Request
|
||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
gotReq = r
|
||||
if sso := r.URL.Query().Get("sso"); sso != "" {
|
||||
w.Header().Set("X-GitHub-SSO", sso)
|
||||
}
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}))
|
||||
defer ts.Close()
|
||||
|
|
@ -139,6 +155,11 @@ func TestNewHTTPClient(t *testing.T) {
|
|||
require.NoError(t, err)
|
||||
|
||||
req, err := http.NewRequest("GET", ts.URL, nil)
|
||||
if tt.sso != "" {
|
||||
q := req.URL.Query()
|
||||
q.Set("sso", tt.sso)
|
||||
req.URL.RawQuery = q.Encode()
|
||||
}
|
||||
req.Host = tt.host
|
||||
require.NoError(t, err)
|
||||
|
||||
|
|
@ -151,16 +172,13 @@ func TestNewHTTPClient(t *testing.T) {
|
|||
|
||||
assert.Equal(t, 204, res.StatusCode)
|
||||
assert.Equal(t, tt.wantStderr, normalizeVerboseLog(stderr.String()))
|
||||
assert.Equal(t, tt.wantSSO, SSOURL())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
type tinyConfig map[string]string
|
||||
|
||||
func (c tinyConfig) GetOrDefault(host, key string) (string, error) {
|
||||
return c[fmt.Sprintf("%s:%s", host, key)], nil
|
||||
}
|
||||
|
||||
func (c tinyConfig) Get(host, key string) (string, error) {
|
||||
return c[fmt.Sprintf("%s:%s", host, key)], nil
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue