From 5f8648159db13109a250a70fe011b5fda4e3b0ec Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mislav=20Marohni=C4=87?= Date: Thu, 1 Oct 2020 17:09:14 +0200 Subject: [PATCH 1/3] Fix handling of HTTP 403 in Device Flow detection --- auth/oauth.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/auth/oauth.go b/auth/oauth.go index c5066d588..14ddc12e2 100644 --- a/auth/oauth.go +++ b/auth/oauth.go @@ -69,7 +69,7 @@ func (oa *OAuthFlow) ObtainAccessToken() (accessToken string, err error) { } } - if resp.StatusCode == http.StatusUnauthorized || resp.StatusCode == http.StatusPaymentRequired || + if resp.StatusCode == http.StatusUnauthorized || resp.StatusCode == http.StatusForbidden || resp.StatusCode == http.StatusNotFound || resp.StatusCode == http.StatusUnprocessableEntity || (resp.StatusCode == http.StatusBadRequest && values != nil && values.Get("error") == "unauthorized_client") { // OAuth Device Flow is not available; continue with OAuth browser flow with a From 93642529dae41d202f1f43c6600f6d04fcc31bf1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mislav=20Marohni=C4=87?= Date: Thu, 1 Oct 2020 17:09:47 +0200 Subject: [PATCH 2/3] Enforce correct content-type in Device Flow detection --- auth/oauth.go | 37 ++++++++----- auth/oauth_test.go | 128 +++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 152 insertions(+), 13 deletions(-) diff --git a/auth/oauth.go b/auth/oauth.go index 14ddc12e2..2c8a78cd4 100644 --- a/auth/oauth.go +++ b/auth/oauth.go @@ -40,6 +40,21 @@ type OAuthFlow struct { TimeSleep func(time.Duration) } +func detectDeviceFlow(statusCode int, values url.Values) (bool, error) { + if statusCode == http.StatusUnauthorized || statusCode == http.StatusForbidden || + statusCode == http.StatusNotFound || statusCode == http.StatusUnprocessableEntity || + (statusCode == http.StatusOK && values == nil) || + (statusCode == http.StatusBadRequest && values != nil && values.Get("error") == "unauthorized_client") { + return true, nil + } else if statusCode != http.StatusOK { + if values != nil && values.Get("error_description") != "" { + return false, fmt.Errorf("HTTP %d: %s", statusCode, values.Get("error_description")) + } + return false, fmt.Errorf("error: HTTP %d", statusCode) + } + return false, nil +} + // ObtainAccessToken guides the user through the browser OAuth flow on GitHub // and returns the OAuth access token upon completion. func (oa *OAuthFlow) ObtainAccessToken() (accessToken string, err error) { @@ -58,28 +73,24 @@ func (oa *OAuthFlow) ObtainAccessToken() (accessToken string, err error) { defer resp.Body.Close() var values url.Values - bb, err := ioutil.ReadAll(resp.Body) - if err != nil { - return - } - if resp.StatusCode == http.StatusOK || strings.HasPrefix(resp.Header.Get("Content-Type"), "application/x-www-form-urlencoded") { + if strings.Contains(resp.Header.Get("Content-Type"), "application/x-www-form-urlencoded") { + var bb []byte + bb, err = ioutil.ReadAll(resp.Body) + if err != nil { + return + } values, err = url.ParseQuery(string(bb)) if err != nil { return } } - if resp.StatusCode == http.StatusUnauthorized || resp.StatusCode == http.StatusForbidden || - resp.StatusCode == http.StatusNotFound || resp.StatusCode == http.StatusUnprocessableEntity || - (resp.StatusCode == http.StatusBadRequest && values != nil && values.Get("error") == "unauthorized_client") { + if doFallback, err := detectDeviceFlow(resp.StatusCode, values); doFallback { // OAuth Device Flow is not available; continue with OAuth browser flow with a // local server endpoint as callback target return oa.localServerFlow() - } else if resp.StatusCode != http.StatusOK { - if values != nil && values.Get("error_description") != "" { - return "", fmt.Errorf("HTTP %d: %s (%s)", resp.StatusCode, values.Get("error_description"), initURL) - } - return "", fmt.Errorf("error: HTTP %d (%s)", resp.StatusCode, initURL) + } else if err != nil { + return "", fmt.Errorf("%v (%s)", err, initURL) } timeNow := oa.TimeNow diff --git a/auth/oauth_test.go b/auth/oauth_test.go index 66a4a8190..3c9e23043 100644 --- a/auth/oauth_test.go +++ b/auth/oauth_test.go @@ -42,6 +42,9 @@ func TestObtainAccessToken_deviceFlow(t *testing.T) { return &http.Response{ StatusCode: 200, Body: ioutil.NopCloser(bytes.NewBufferString(responseData.Encode())), + Header: http.Header{ + "Content-Type": []string{"application/x-www-form-urlencoded; charset=utf-8"}, + }, }, nil case "POST https://github.com/login/oauth/access_token": if err := req.ParseForm(); err != nil { @@ -119,3 +122,128 @@ func TestObtainAccessToken_deviceFlow(t *testing.T) { t.Errorf("expected to provide user with one-time code %q, got %q", "1234-ABCD", browseCode) } } + +func Test_detectDeviceFlow(t *testing.T) { + type args struct { + statusCode int + values url.Values + } + tests := []struct { + name string + args args + doFallback bool + wantErr string + }{ + { + name: "success", + args: args{ + statusCode: 200, + values: url.Values{}, + }, + doFallback: false, + wantErr: "", + }, + { + name: "wrong response type", + args: args{ + statusCode: 200, + values: nil, + }, + doFallback: true, + wantErr: "", + }, + { + name: "401 unauthorized", + args: args{ + statusCode: 401, + values: nil, + }, + doFallback: true, + wantErr: "", + }, + { + name: "403 forbidden", + args: args{ + statusCode: 403, + values: nil, + }, + doFallback: true, + wantErr: "", + }, + { + name: "404 not found", + args: args{ + statusCode: 404, + values: nil, + }, + doFallback: true, + wantErr: "", + }, + { + name: "422 unprocessable", + args: args{ + statusCode: 422, + values: nil, + }, + doFallback: true, + wantErr: "", + }, + { + name: "400 bad request", + args: args{ + statusCode: 400, + values: nil, + }, + doFallback: false, + wantErr: "error: HTTP 400", + }, + { + name: "400 with values", + args: args{ + statusCode: 400, + values: url.Values{ + "error": []string{"blah"}, + }, + }, + doFallback: false, + wantErr: "error: HTTP 400", + }, + { + name: "400 with unauthorized_client", + args: args{ + statusCode: 400, + values: url.Values{ + "error": []string{"unauthorized_client"}, + }, + }, + doFallback: true, + wantErr: "", + }, + { + name: "400 with error_description", + args: args{ + statusCode: 400, + values: url.Values{ + "error_description": []string{"HI"}, + }, + }, + doFallback: false, + wantErr: "HTTP 400: HI", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := detectDeviceFlow(tt.args.statusCode, tt.args.values) + if (err != nil) != (tt.wantErr != "") { + t.Errorf("detectDeviceFlow() error = %v, wantErr %v", err, tt.wantErr) + return + } + if tt.wantErr != "" && err.Error() != tt.wantErr { + t.Errorf("error = %q, wantErr = %q", err, tt.wantErr) + } + if got != tt.doFallback { + t.Errorf("detectDeviceFlow() = %v, want %v", got, tt.doFallback) + } + }) + } +} From 61609db9ef05859039620c033252ffe2f384e79f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mislav=20Marohni=C4=87?= Date: Thu, 1 Oct 2020 17:21:35 +0200 Subject: [PATCH 3/3] Cover HTTP 402 in oauth tests --- auth/oauth_test.go | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/auth/oauth_test.go b/auth/oauth_test.go index 3c9e23043..a9070a1b1 100644 --- a/auth/oauth_test.go +++ b/auth/oauth_test.go @@ -188,6 +188,15 @@ func Test_detectDeviceFlow(t *testing.T) { doFallback: true, wantErr: "", }, + { + name: "402 payment required", + args: args{ + statusCode: 402, + values: nil, + }, + doFallback: false, + wantErr: "error: HTTP 402", + }, { name: "400 bad request", args: args{