diff --git a/api/client.go b/api/client.go index f307b2bf5..3d2e77dc3 100644 --- a/api/client.go +++ b/api/client.go @@ -91,6 +91,11 @@ var issuedScopesWarning bool // CheckScopes checks whether an OAuth scope is present in a response func CheckScopes(wantedScope string, cb func(string) error) ClientOption { + wantedCandidates := []string{wantedScope} + if strings.HasPrefix(wantedScope, "read:") { + wantedCandidates = append(wantedCandidates, "admin:"+strings.TrimPrefix(wantedScope, "read:")) + } + return func(tr http.RoundTripper) http.RoundTripper { return &funcTripper{roundTrip: func(req *http.Request) (*http.Response, error) { res, err := tr.RoundTrip(req) @@ -102,10 +107,13 @@ func CheckScopes(wantedScope string, cb func(string) error) ClientOption { hasScopes := strings.Split(res.Header.Get("X-Oauth-Scopes"), ",") hasWanted := false + outer: for _, s := range hasScopes { - if wantedScope == strings.TrimSpace(s) { - hasWanted = true - break + for _, w := range wantedCandidates { + if w == strings.TrimSpace(s) { + hasWanted = true + break outer + } } } diff --git a/api/client_test.go b/api/client_test.go index b7c226c8f..4a56fc277 100644 --- a/api/client_test.go +++ b/api/client_test.go @@ -4,6 +4,7 @@ import ( "bytes" "errors" "io/ioutil" + "net/http" "reflect" "testing" @@ -91,5 +92,79 @@ func TestRESTError(t *testing.T) { } if httpErr.Error() != "HTTP 422: OH NO (https://api.github.com/repos/branch)" { t.Errorf("got %q", httpErr.Error()) + + } +} + +func Test_CheckScopes(t *testing.T) { + tests := []struct { + name string + wantScope string + responseApp string + responseScopes string + expectCallback bool + }{ + { + name: "missing read:org", + wantScope: "read:org", + responseApp: "APPID", + responseScopes: "repo, gist", + expectCallback: true, + }, + { + name: "has read:org", + wantScope: "read:org", + responseApp: "APPID", + responseScopes: "repo, read:org, gist", + expectCallback: false, + }, + { + name: "has admin:org", + wantScope: "read:org", + responseApp: "APPID", + responseScopes: "repo, admin:org, gist", + expectCallback: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tr := &httpmock.Registry{} + tr.Register(httpmock.MatchAny, func(*http.Request) (*http.Response, error) { + return &http.Response{ + StatusCode: 200, + Header: http.Header{ + "X-Oauth-Client-Id": []string{tt.responseApp}, + "X-Oauth-Scopes": []string{tt.responseScopes}, + }, + }, nil + }) + + callbackInvoked := false + var gotAppID string + fn := CheckScopes(tt.wantScope, func(appID string) error { + callbackInvoked = true + gotAppID = appID + return nil + }) + + rt := fn(tr) + req, err := http.NewRequest("GET", "https://api.github.com/hello", nil) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + issuedScopesWarning = false + _, err = rt.RoundTrip(req) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if tt.expectCallback != callbackInvoked { + t.Fatalf("expected CheckScopes callback: %v", tt.expectCallback) + } + if tt.expectCallback && gotAppID != tt.responseApp { + t.Errorf("unexpected app ID: %q", gotAppID) + } + }) } }