258 lines
6.2 KiB
Go
258 lines
6.2 KiB
Go
package auth
|
|
|
|
import (
|
|
"bytes"
|
|
"fmt"
|
|
"io/ioutil"
|
|
"net/http"
|
|
"net/url"
|
|
"testing"
|
|
"time"
|
|
)
|
|
|
|
type roundTripper func(*http.Request) (*http.Response, error)
|
|
|
|
func (rt roundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
|
|
return rt(req)
|
|
}
|
|
|
|
func TestObtainAccessToken_deviceFlow(t *testing.T) {
|
|
requestCount := 0
|
|
rt := func(req *http.Request) (*http.Response, error) {
|
|
route := fmt.Sprintf("%s %s", req.Method, req.URL)
|
|
switch route {
|
|
case "POST https://github.com/login/device/code":
|
|
if err := req.ParseForm(); err != nil {
|
|
return nil, err
|
|
}
|
|
if req.PostForm.Get("client_id") != "CLIENT-ID" {
|
|
t.Errorf("expected POST /login/device/code to supply client_id=%q, got %q", "CLIENT-ID", req.PostForm.Get("client_id"))
|
|
}
|
|
if req.PostForm.Get("scope") != "repo gist" {
|
|
t.Errorf("expected POST /login/device/code to supply scope=%q, got %q", "repo gist", req.PostForm.Get("scope"))
|
|
}
|
|
|
|
responseData := url.Values{}
|
|
responseData.Set("device_code", "DEVICE-CODE")
|
|
responseData.Set("user_code", "1234-ABCD")
|
|
responseData.Set("verification_uri", "https://github.com/login/device")
|
|
responseData.Set("interval", "5")
|
|
responseData.Set("expires_in", "899")
|
|
|
|
return &http.Response{
|
|
StatusCode: 200,
|
|
Body: ioutil.NopCloser(bytes.NewBufferString(responseData.Encode())),
|
|
Header: http.Header{
|
|
"Content-Type": []string{"application/x-www-form-urlencoded; charset=utf-8"},
|
|
},
|
|
}, nil
|
|
case "POST https://github.com/login/oauth/access_token":
|
|
if err := req.ParseForm(); err != nil {
|
|
return nil, err
|
|
}
|
|
if req.PostForm.Get("client_id") != "CLIENT-ID" {
|
|
t.Errorf("expected POST /login/oauth/access_token to supply client_id=%q, got %q", "CLIENT-ID", req.PostForm.Get("client_id"))
|
|
}
|
|
if req.PostForm.Get("device_code") != "DEVICE-CODE" {
|
|
t.Errorf("expected POST /login/oauth/access_token to supply device_code=%q, got %q", "DEVICE-CODE", req.PostForm.Get("scope"))
|
|
}
|
|
if req.PostForm.Get("grant_type") != "urn:ietf:params:oauth:grant-type:device_code" {
|
|
t.Errorf("expected POST /login/oauth/access_token to supply grant_type=%q, got %q", "urn:ietf:params:oauth:grant-type:device_code", req.PostForm.Get("grant_type"))
|
|
}
|
|
|
|
responseData := url.Values{}
|
|
requestCount++
|
|
if requestCount == 1 {
|
|
responseData.Set("error", "authorization_pending")
|
|
} else {
|
|
responseData.Set("access_token", "OTOKEN")
|
|
}
|
|
|
|
return &http.Response{
|
|
StatusCode: 200,
|
|
Body: ioutil.NopCloser(bytes.NewBufferString(responseData.Encode())),
|
|
}, nil
|
|
default:
|
|
return nil, fmt.Errorf("unstubbed HTTP request: %v", route)
|
|
}
|
|
}
|
|
httpClient := &http.Client{
|
|
Transport: roundTripper(rt),
|
|
}
|
|
|
|
slept := time.Duration(0)
|
|
var browseURL string
|
|
var browseCode string
|
|
|
|
oa := &OAuthFlow{
|
|
Hostname: "github.com",
|
|
ClientID: "CLIENT-ID",
|
|
ClientSecret: "CLIENT-SEKRIT",
|
|
Scopes: []string{"repo", "gist"},
|
|
OpenInBrowser: func(url, code string) error {
|
|
browseURL = url
|
|
browseCode = code
|
|
return nil
|
|
},
|
|
HTTPClient: httpClient,
|
|
TimeNow: time.Now,
|
|
TimeSleep: func(d time.Duration) {
|
|
slept += d
|
|
},
|
|
}
|
|
|
|
token, err := oa.ObtainAccessToken()
|
|
if err != nil {
|
|
t.Fatalf("ObtainAccessToken error: %v", err)
|
|
}
|
|
|
|
if token != "OTOKEN" {
|
|
t.Errorf("expected token %q, got %q", "OTOKEN", token)
|
|
}
|
|
if requestCount != 2 {
|
|
t.Errorf("expected 2 HTTP pings for token, got %d", requestCount)
|
|
}
|
|
if slept.String() != "10s" {
|
|
t.Errorf("expected total sleep duration of %s, got %s", "10s", slept.String())
|
|
}
|
|
if browseURL != "https://github.com/login/device" {
|
|
t.Errorf("expected to open browser at %s, got %s", "https://github.com/login/device", browseURL)
|
|
}
|
|
if browseCode != "1234-ABCD" {
|
|
t.Errorf("expected to provide user with one-time code %q, got %q", "1234-ABCD", browseCode)
|
|
}
|
|
}
|
|
|
|
func Test_detectDeviceFlow(t *testing.T) {
|
|
type args struct {
|
|
statusCode int
|
|
values url.Values
|
|
}
|
|
tests := []struct {
|
|
name string
|
|
args args
|
|
doFallback bool
|
|
wantErr string
|
|
}{
|
|
{
|
|
name: "success",
|
|
args: args{
|
|
statusCode: 200,
|
|
values: url.Values{},
|
|
},
|
|
doFallback: false,
|
|
wantErr: "",
|
|
},
|
|
{
|
|
name: "wrong response type",
|
|
args: args{
|
|
statusCode: 200,
|
|
values: nil,
|
|
},
|
|
doFallback: true,
|
|
wantErr: "",
|
|
},
|
|
{
|
|
name: "401 unauthorized",
|
|
args: args{
|
|
statusCode: 401,
|
|
values: nil,
|
|
},
|
|
doFallback: true,
|
|
wantErr: "",
|
|
},
|
|
{
|
|
name: "403 forbidden",
|
|
args: args{
|
|
statusCode: 403,
|
|
values: nil,
|
|
},
|
|
doFallback: true,
|
|
wantErr: "",
|
|
},
|
|
{
|
|
name: "404 not found",
|
|
args: args{
|
|
statusCode: 404,
|
|
values: nil,
|
|
},
|
|
doFallback: true,
|
|
wantErr: "",
|
|
},
|
|
{
|
|
name: "422 unprocessable",
|
|
args: args{
|
|
statusCode: 422,
|
|
values: nil,
|
|
},
|
|
doFallback: true,
|
|
wantErr: "",
|
|
},
|
|
{
|
|
name: "402 payment required",
|
|
args: args{
|
|
statusCode: 402,
|
|
values: nil,
|
|
},
|
|
doFallback: false,
|
|
wantErr: "error: HTTP 402",
|
|
},
|
|
{
|
|
name: "400 bad request",
|
|
args: args{
|
|
statusCode: 400,
|
|
values: nil,
|
|
},
|
|
doFallback: false,
|
|
wantErr: "error: HTTP 400",
|
|
},
|
|
{
|
|
name: "400 with values",
|
|
args: args{
|
|
statusCode: 400,
|
|
values: url.Values{
|
|
"error": []string{"blah"},
|
|
},
|
|
},
|
|
doFallback: false,
|
|
wantErr: "error: HTTP 400",
|
|
},
|
|
{
|
|
name: "400 with unauthorized_client",
|
|
args: args{
|
|
statusCode: 400,
|
|
values: url.Values{
|
|
"error": []string{"unauthorized_client"},
|
|
},
|
|
},
|
|
doFallback: true,
|
|
wantErr: "",
|
|
},
|
|
{
|
|
name: "400 with error_description",
|
|
args: args{
|
|
statusCode: 400,
|
|
values: url.Values{
|
|
"error_description": []string{"HI"},
|
|
},
|
|
},
|
|
doFallback: false,
|
|
wantErr: "HTTP 400: HI",
|
|
},
|
|
}
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
got, err := detectDeviceFlow(tt.args.statusCode, tt.args.values)
|
|
if (err != nil) != (tt.wantErr != "") {
|
|
t.Errorf("detectDeviceFlow() error = %v, wantErr %v", err, tt.wantErr)
|
|
return
|
|
}
|
|
if tt.wantErr != "" && err.Error() != tt.wantErr {
|
|
t.Errorf("error = %q, wantErr = %q", err, tt.wantErr)
|
|
}
|
|
if got != tt.doFallback {
|
|
t.Errorf("detectDeviceFlow() = %v, want %v", got, tt.doFallback)
|
|
}
|
|
})
|
|
}
|
|
}
|