Merge pull request #2054 from cli/ghe-auth
Handle edge cases in GHE auth
This commit is contained in:
commit
54e292703b
2 changed files with 161 additions and 13 deletions
|
|
@ -40,6 +40,21 @@ type OAuthFlow struct {
|
|||
TimeSleep func(time.Duration)
|
||||
}
|
||||
|
||||
func detectDeviceFlow(statusCode int, values url.Values) (bool, error) {
|
||||
if statusCode == http.StatusUnauthorized || statusCode == http.StatusForbidden ||
|
||||
statusCode == http.StatusNotFound || statusCode == http.StatusUnprocessableEntity ||
|
||||
(statusCode == http.StatusOK && values == nil) ||
|
||||
(statusCode == http.StatusBadRequest && values != nil && values.Get("error") == "unauthorized_client") {
|
||||
return true, nil
|
||||
} else if statusCode != http.StatusOK {
|
||||
if values != nil && values.Get("error_description") != "" {
|
||||
return false, fmt.Errorf("HTTP %d: %s", statusCode, values.Get("error_description"))
|
||||
}
|
||||
return false, fmt.Errorf("error: HTTP %d", statusCode)
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// ObtainAccessToken guides the user through the browser OAuth flow on GitHub
|
||||
// and returns the OAuth access token upon completion.
|
||||
func (oa *OAuthFlow) ObtainAccessToken() (accessToken string, err error) {
|
||||
|
|
@ -58,28 +73,24 @@ func (oa *OAuthFlow) ObtainAccessToken() (accessToken string, err error) {
|
|||
defer resp.Body.Close()
|
||||
|
||||
var values url.Values
|
||||
bb, err := ioutil.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if resp.StatusCode == http.StatusOK || strings.HasPrefix(resp.Header.Get("Content-Type"), "application/x-www-form-urlencoded") {
|
||||
if strings.Contains(resp.Header.Get("Content-Type"), "application/x-www-form-urlencoded") {
|
||||
var bb []byte
|
||||
bb, err = ioutil.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
values, err = url.ParseQuery(string(bb))
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if resp.StatusCode == http.StatusUnauthorized || resp.StatusCode == http.StatusPaymentRequired ||
|
||||
resp.StatusCode == http.StatusNotFound || resp.StatusCode == http.StatusUnprocessableEntity ||
|
||||
(resp.StatusCode == http.StatusBadRequest && values != nil && values.Get("error") == "unauthorized_client") {
|
||||
if doFallback, err := detectDeviceFlow(resp.StatusCode, values); doFallback {
|
||||
// OAuth Device Flow is not available; continue with OAuth browser flow with a
|
||||
// local server endpoint as callback target
|
||||
return oa.localServerFlow()
|
||||
} else if resp.StatusCode != http.StatusOK {
|
||||
if values != nil && values.Get("error_description") != "" {
|
||||
return "", fmt.Errorf("HTTP %d: %s (%s)", resp.StatusCode, values.Get("error_description"), initURL)
|
||||
}
|
||||
return "", fmt.Errorf("error: HTTP %d (%s)", resp.StatusCode, initURL)
|
||||
} else if err != nil {
|
||||
return "", fmt.Errorf("%v (%s)", err, initURL)
|
||||
}
|
||||
|
||||
timeNow := oa.TimeNow
|
||||
|
|
|
|||
|
|
@ -42,6 +42,9 @@ func TestObtainAccessToken_deviceFlow(t *testing.T) {
|
|||
return &http.Response{
|
||||
StatusCode: 200,
|
||||
Body: ioutil.NopCloser(bytes.NewBufferString(responseData.Encode())),
|
||||
Header: http.Header{
|
||||
"Content-Type": []string{"application/x-www-form-urlencoded; charset=utf-8"},
|
||||
},
|
||||
}, nil
|
||||
case "POST https://github.com/login/oauth/access_token":
|
||||
if err := req.ParseForm(); err != nil {
|
||||
|
|
@ -119,3 +122,137 @@ func TestObtainAccessToken_deviceFlow(t *testing.T) {
|
|||
t.Errorf("expected to provide user with one-time code %q, got %q", "1234-ABCD", browseCode)
|
||||
}
|
||||
}
|
||||
|
||||
func Test_detectDeviceFlow(t *testing.T) {
|
||||
type args struct {
|
||||
statusCode int
|
||||
values url.Values
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
doFallback bool
|
||||
wantErr string
|
||||
}{
|
||||
{
|
||||
name: "success",
|
||||
args: args{
|
||||
statusCode: 200,
|
||||
values: url.Values{},
|
||||
},
|
||||
doFallback: false,
|
||||
wantErr: "",
|
||||
},
|
||||
{
|
||||
name: "wrong response type",
|
||||
args: args{
|
||||
statusCode: 200,
|
||||
values: nil,
|
||||
},
|
||||
doFallback: true,
|
||||
wantErr: "",
|
||||
},
|
||||
{
|
||||
name: "401 unauthorized",
|
||||
args: args{
|
||||
statusCode: 401,
|
||||
values: nil,
|
||||
},
|
||||
doFallback: true,
|
||||
wantErr: "",
|
||||
},
|
||||
{
|
||||
name: "403 forbidden",
|
||||
args: args{
|
||||
statusCode: 403,
|
||||
values: nil,
|
||||
},
|
||||
doFallback: true,
|
||||
wantErr: "",
|
||||
},
|
||||
{
|
||||
name: "404 not found",
|
||||
args: args{
|
||||
statusCode: 404,
|
||||
values: nil,
|
||||
},
|
||||
doFallback: true,
|
||||
wantErr: "",
|
||||
},
|
||||
{
|
||||
name: "422 unprocessable",
|
||||
args: args{
|
||||
statusCode: 422,
|
||||
values: nil,
|
||||
},
|
||||
doFallback: true,
|
||||
wantErr: "",
|
||||
},
|
||||
{
|
||||
name: "402 payment required",
|
||||
args: args{
|
||||
statusCode: 402,
|
||||
values: nil,
|
||||
},
|
||||
doFallback: false,
|
||||
wantErr: "error: HTTP 402",
|
||||
},
|
||||
{
|
||||
name: "400 bad request",
|
||||
args: args{
|
||||
statusCode: 400,
|
||||
values: nil,
|
||||
},
|
||||
doFallback: false,
|
||||
wantErr: "error: HTTP 400",
|
||||
},
|
||||
{
|
||||
name: "400 with values",
|
||||
args: args{
|
||||
statusCode: 400,
|
||||
values: url.Values{
|
||||
"error": []string{"blah"},
|
||||
},
|
||||
},
|
||||
doFallback: false,
|
||||
wantErr: "error: HTTP 400",
|
||||
},
|
||||
{
|
||||
name: "400 with unauthorized_client",
|
||||
args: args{
|
||||
statusCode: 400,
|
||||
values: url.Values{
|
||||
"error": []string{"unauthorized_client"},
|
||||
},
|
||||
},
|
||||
doFallback: true,
|
||||
wantErr: "",
|
||||
},
|
||||
{
|
||||
name: "400 with error_description",
|
||||
args: args{
|
||||
statusCode: 400,
|
||||
values: url.Values{
|
||||
"error_description": []string{"HI"},
|
||||
},
|
||||
},
|
||||
doFallback: false,
|
||||
wantErr: "HTTP 400: HI",
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := detectDeviceFlow(tt.args.statusCode, tt.args.values)
|
||||
if (err != nil) != (tt.wantErr != "") {
|
||||
t.Errorf("detectDeviceFlow() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if tt.wantErr != "" && err.Error() != tt.wantErr {
|
||||
t.Errorf("error = %q, wantErr = %q", err, tt.wantErr)
|
||||
}
|
||||
if got != tt.doFallback {
|
||||
t.Errorf("detectDeviceFlow() = %v, want %v", got, tt.doFallback)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue