Merge pull request #2054 from cli/ghe-auth

Handle edge cases in GHE auth
This commit is contained in:
Mislav Marohnić 2020-10-06 14:27:32 +02:00 committed by GitHub
commit 54e292703b
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 161 additions and 13 deletions

View file

@ -40,6 +40,21 @@ type OAuthFlow struct {
TimeSleep func(time.Duration)
}
func detectDeviceFlow(statusCode int, values url.Values) (bool, error) {
if statusCode == http.StatusUnauthorized || statusCode == http.StatusForbidden ||
statusCode == http.StatusNotFound || statusCode == http.StatusUnprocessableEntity ||
(statusCode == http.StatusOK && values == nil) ||
(statusCode == http.StatusBadRequest && values != nil && values.Get("error") == "unauthorized_client") {
return true, nil
} else if statusCode != http.StatusOK {
if values != nil && values.Get("error_description") != "" {
return false, fmt.Errorf("HTTP %d: %s", statusCode, values.Get("error_description"))
}
return false, fmt.Errorf("error: HTTP %d", statusCode)
}
return false, nil
}
// ObtainAccessToken guides the user through the browser OAuth flow on GitHub
// and returns the OAuth access token upon completion.
func (oa *OAuthFlow) ObtainAccessToken() (accessToken string, err error) {
@ -58,28 +73,24 @@ func (oa *OAuthFlow) ObtainAccessToken() (accessToken string, err error) {
defer resp.Body.Close()
var values url.Values
bb, err := ioutil.ReadAll(resp.Body)
if err != nil {
return
}
if resp.StatusCode == http.StatusOK || strings.HasPrefix(resp.Header.Get("Content-Type"), "application/x-www-form-urlencoded") {
if strings.Contains(resp.Header.Get("Content-Type"), "application/x-www-form-urlencoded") {
var bb []byte
bb, err = ioutil.ReadAll(resp.Body)
if err != nil {
return
}
values, err = url.ParseQuery(string(bb))
if err != nil {
return
}
}
if resp.StatusCode == http.StatusUnauthorized || resp.StatusCode == http.StatusPaymentRequired ||
resp.StatusCode == http.StatusNotFound || resp.StatusCode == http.StatusUnprocessableEntity ||
(resp.StatusCode == http.StatusBadRequest && values != nil && values.Get("error") == "unauthorized_client") {
if doFallback, err := detectDeviceFlow(resp.StatusCode, values); doFallback {
// OAuth Device Flow is not available; continue with OAuth browser flow with a
// local server endpoint as callback target
return oa.localServerFlow()
} else if resp.StatusCode != http.StatusOK {
if values != nil && values.Get("error_description") != "" {
return "", fmt.Errorf("HTTP %d: %s (%s)", resp.StatusCode, values.Get("error_description"), initURL)
}
return "", fmt.Errorf("error: HTTP %d (%s)", resp.StatusCode, initURL)
} else if err != nil {
return "", fmt.Errorf("%v (%s)", err, initURL)
}
timeNow := oa.TimeNow

View file

@ -42,6 +42,9 @@ func TestObtainAccessToken_deviceFlow(t *testing.T) {
return &http.Response{
StatusCode: 200,
Body: ioutil.NopCloser(bytes.NewBufferString(responseData.Encode())),
Header: http.Header{
"Content-Type": []string{"application/x-www-form-urlencoded; charset=utf-8"},
},
}, nil
case "POST https://github.com/login/oauth/access_token":
if err := req.ParseForm(); err != nil {
@ -119,3 +122,137 @@ func TestObtainAccessToken_deviceFlow(t *testing.T) {
t.Errorf("expected to provide user with one-time code %q, got %q", "1234-ABCD", browseCode)
}
}
func Test_detectDeviceFlow(t *testing.T) {
type args struct {
statusCode int
values url.Values
}
tests := []struct {
name string
args args
doFallback bool
wantErr string
}{
{
name: "success",
args: args{
statusCode: 200,
values: url.Values{},
},
doFallback: false,
wantErr: "",
},
{
name: "wrong response type",
args: args{
statusCode: 200,
values: nil,
},
doFallback: true,
wantErr: "",
},
{
name: "401 unauthorized",
args: args{
statusCode: 401,
values: nil,
},
doFallback: true,
wantErr: "",
},
{
name: "403 forbidden",
args: args{
statusCode: 403,
values: nil,
},
doFallback: true,
wantErr: "",
},
{
name: "404 not found",
args: args{
statusCode: 404,
values: nil,
},
doFallback: true,
wantErr: "",
},
{
name: "422 unprocessable",
args: args{
statusCode: 422,
values: nil,
},
doFallback: true,
wantErr: "",
},
{
name: "402 payment required",
args: args{
statusCode: 402,
values: nil,
},
doFallback: false,
wantErr: "error: HTTP 402",
},
{
name: "400 bad request",
args: args{
statusCode: 400,
values: nil,
},
doFallback: false,
wantErr: "error: HTTP 400",
},
{
name: "400 with values",
args: args{
statusCode: 400,
values: url.Values{
"error": []string{"blah"},
},
},
doFallback: false,
wantErr: "error: HTTP 400",
},
{
name: "400 with unauthorized_client",
args: args{
statusCode: 400,
values: url.Values{
"error": []string{"unauthorized_client"},
},
},
doFallback: true,
wantErr: "",
},
{
name: "400 with error_description",
args: args{
statusCode: 400,
values: url.Values{
"error_description": []string{"HI"},
},
},
doFallback: false,
wantErr: "HTTP 400: HI",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := detectDeviceFlow(tt.args.statusCode, tt.args.values)
if (err != nil) != (tt.wantErr != "") {
t.Errorf("detectDeviceFlow() error = %v, wantErr %v", err, tt.wantErr)
return
}
if tt.wantErr != "" && err.Error() != tt.wantErr {
t.Errorf("error = %q, wantErr = %q", err, tt.wantErr)
}
if got != tt.doFallback {
t.Errorf("detectDeviceFlow() = %v, want %v", got, tt.doFallback)
}
})
}
}