cli/auth/oauth_test.go
2020-10-01 17:21:35 +02:00

258 lines
6.2 KiB
Go

package auth
import (
"bytes"
"fmt"
"io/ioutil"
"net/http"
"net/url"
"testing"
"time"
)
type roundTripper func(*http.Request) (*http.Response, error)
func (rt roundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
return rt(req)
}
func TestObtainAccessToken_deviceFlow(t *testing.T) {
requestCount := 0
rt := func(req *http.Request) (*http.Response, error) {
route := fmt.Sprintf("%s %s", req.Method, req.URL)
switch route {
case "POST https://github.com/login/device/code":
if err := req.ParseForm(); err != nil {
return nil, err
}
if req.PostForm.Get("client_id") != "CLIENT-ID" {
t.Errorf("expected POST /login/device/code to supply client_id=%q, got %q", "CLIENT-ID", req.PostForm.Get("client_id"))
}
if req.PostForm.Get("scope") != "repo gist" {
t.Errorf("expected POST /login/device/code to supply scope=%q, got %q", "repo gist", req.PostForm.Get("scope"))
}
responseData := url.Values{}
responseData.Set("device_code", "DEVICE-CODE")
responseData.Set("user_code", "1234-ABCD")
responseData.Set("verification_uri", "https://github.com/login/device")
responseData.Set("interval", "5")
responseData.Set("expires_in", "899")
return &http.Response{
StatusCode: 200,
Body: ioutil.NopCloser(bytes.NewBufferString(responseData.Encode())),
Header: http.Header{
"Content-Type": []string{"application/x-www-form-urlencoded; charset=utf-8"},
},
}, nil
case "POST https://github.com/login/oauth/access_token":
if err := req.ParseForm(); err != nil {
return nil, err
}
if req.PostForm.Get("client_id") != "CLIENT-ID" {
t.Errorf("expected POST /login/oauth/access_token to supply client_id=%q, got %q", "CLIENT-ID", req.PostForm.Get("client_id"))
}
if req.PostForm.Get("device_code") != "DEVICE-CODE" {
t.Errorf("expected POST /login/oauth/access_token to supply device_code=%q, got %q", "DEVICE-CODE", req.PostForm.Get("scope"))
}
if req.PostForm.Get("grant_type") != "urn:ietf:params:oauth:grant-type:device_code" {
t.Errorf("expected POST /login/oauth/access_token to supply grant_type=%q, got %q", "urn:ietf:params:oauth:grant-type:device_code", req.PostForm.Get("grant_type"))
}
responseData := url.Values{}
requestCount++
if requestCount == 1 {
responseData.Set("error", "authorization_pending")
} else {
responseData.Set("access_token", "OTOKEN")
}
return &http.Response{
StatusCode: 200,
Body: ioutil.NopCloser(bytes.NewBufferString(responseData.Encode())),
}, nil
default:
return nil, fmt.Errorf("unstubbed HTTP request: %v", route)
}
}
httpClient := &http.Client{
Transport: roundTripper(rt),
}
slept := time.Duration(0)
var browseURL string
var browseCode string
oa := &OAuthFlow{
Hostname: "github.com",
ClientID: "CLIENT-ID",
ClientSecret: "CLIENT-SEKRIT",
Scopes: []string{"repo", "gist"},
OpenInBrowser: func(url, code string) error {
browseURL = url
browseCode = code
return nil
},
HTTPClient: httpClient,
TimeNow: time.Now,
TimeSleep: func(d time.Duration) {
slept += d
},
}
token, err := oa.ObtainAccessToken()
if err != nil {
t.Fatalf("ObtainAccessToken error: %v", err)
}
if token != "OTOKEN" {
t.Errorf("expected token %q, got %q", "OTOKEN", token)
}
if requestCount != 2 {
t.Errorf("expected 2 HTTP pings for token, got %d", requestCount)
}
if slept.String() != "10s" {
t.Errorf("expected total sleep duration of %s, got %s", "10s", slept.String())
}
if browseURL != "https://github.com/login/device" {
t.Errorf("expected to open browser at %s, got %s", "https://github.com/login/device", browseURL)
}
if browseCode != "1234-ABCD" {
t.Errorf("expected to provide user with one-time code %q, got %q", "1234-ABCD", browseCode)
}
}
func Test_detectDeviceFlow(t *testing.T) {
type args struct {
statusCode int
values url.Values
}
tests := []struct {
name string
args args
doFallback bool
wantErr string
}{
{
name: "success",
args: args{
statusCode: 200,
values: url.Values{},
},
doFallback: false,
wantErr: "",
},
{
name: "wrong response type",
args: args{
statusCode: 200,
values: nil,
},
doFallback: true,
wantErr: "",
},
{
name: "401 unauthorized",
args: args{
statusCode: 401,
values: nil,
},
doFallback: true,
wantErr: "",
},
{
name: "403 forbidden",
args: args{
statusCode: 403,
values: nil,
},
doFallback: true,
wantErr: "",
},
{
name: "404 not found",
args: args{
statusCode: 404,
values: nil,
},
doFallback: true,
wantErr: "",
},
{
name: "422 unprocessable",
args: args{
statusCode: 422,
values: nil,
},
doFallback: true,
wantErr: "",
},
{
name: "402 payment required",
args: args{
statusCode: 402,
values: nil,
},
doFallback: false,
wantErr: "error: HTTP 402",
},
{
name: "400 bad request",
args: args{
statusCode: 400,
values: nil,
},
doFallback: false,
wantErr: "error: HTTP 400",
},
{
name: "400 with values",
args: args{
statusCode: 400,
values: url.Values{
"error": []string{"blah"},
},
},
doFallback: false,
wantErr: "error: HTTP 400",
},
{
name: "400 with unauthorized_client",
args: args{
statusCode: 400,
values: url.Values{
"error": []string{"unauthorized_client"},
},
},
doFallback: true,
wantErr: "",
},
{
name: "400 with error_description",
args: args{
statusCode: 400,
values: url.Values{
"error_description": []string{"HI"},
},
},
doFallback: false,
wantErr: "HTTP 400: HI",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := detectDeviceFlow(tt.args.statusCode, tt.args.values)
if (err != nil) != (tt.wantErr != "") {
t.Errorf("detectDeviceFlow() error = %v, wantErr %v", err, tt.wantErr)
return
}
if tt.wantErr != "" && err.Error() != tt.wantErr {
t.Errorf("error = %q, wantErr = %q", err, tt.wantErr)
}
if got != tt.doFallback {
t.Errorf("detectDeviceFlow() = %v, want %v", got, tt.doFallback)
}
})
}
}