Before, we implemented the OAuth app authorization flow which requires a callback URL. To provide such a URL, we had to spin up a local HTTP server, which was brittle and did not cover cases where a person might want to authenticate with a browser that runs on a different machine than the GitHub CLI process. This implements the OAuth Device Authorization flow where the user is given a one-time code and asked to paste it in the browser flow. There is no callback URL, so we can avoid spinning up a local server, and the user may open a browser on any of their devices, as long as they provide the correct one-time code. If the Device Authorization flow is detected to be unavailable for the OAuth app (right now, it's specifically enabled for GitHub CLI) or for an older GitHub Enterprise instance, this falls back to the old app authentication flow.
121 lines
3.6 KiB
Go
121 lines
3.6 KiB
Go
package auth
|
|
|
|
import (
|
|
"bytes"
|
|
"fmt"
|
|
"io/ioutil"
|
|
"net/http"
|
|
"net/url"
|
|
"testing"
|
|
"time"
|
|
)
|
|
|
|
// roundTripper adapts a plain function to the http.RoundTripper interface so
// that a test can stub out the transport of an http.Client with a closure.
type roundTripper func(*http.Request) (*http.Response, error)

// RoundTrip satisfies http.RoundTripper by handing the request straight to the
// underlying function and returning whatever it returns.
func (fn roundTripper) RoundTrip(r *http.Request) (*http.Response, error) {
	resp, err := fn(r)
	return resp, err
}
|
|
|
|
func TestObtainAccessToken_deviceFlow(t *testing.T) {
|
|
requestCount := 0
|
|
rt := func(req *http.Request) (*http.Response, error) {
|
|
route := fmt.Sprintf("%s %s", req.Method, req.URL)
|
|
switch route {
|
|
case "POST https://github.com/login/device/code":
|
|
if err := req.ParseForm(); err != nil {
|
|
return nil, err
|
|
}
|
|
if req.PostForm.Get("client_id") != "CLIENT-ID" {
|
|
t.Errorf("expected POST /login/device/code to supply client_id=%q, got %q", "CLIENT-ID", req.PostForm.Get("client_id"))
|
|
}
|
|
if req.PostForm.Get("scope") != "repo gist" {
|
|
t.Errorf("expected POST /login/device/code to supply scope=%q, got %q", "repo gist", req.PostForm.Get("scope"))
|
|
}
|
|
|
|
responseData := url.Values{}
|
|
responseData.Set("device_code", "DEVICE-CODE")
|
|
responseData.Set("user_code", "1234-ABCD")
|
|
responseData.Set("verification_uri", "https://github.com/login/device")
|
|
responseData.Set("interval", "5")
|
|
responseData.Set("expires_in", "899")
|
|
|
|
return &http.Response{
|
|
StatusCode: 200,
|
|
Body: ioutil.NopCloser(bytes.NewBufferString(responseData.Encode())),
|
|
}, nil
|
|
case "POST https://github.com/login/oauth/access_token":
|
|
if err := req.ParseForm(); err != nil {
|
|
return nil, err
|
|
}
|
|
if req.PostForm.Get("client_id") != "CLIENT-ID" {
|
|
t.Errorf("expected POST /login/oauth/access_token to supply client_id=%q, got %q", "CLIENT-ID", req.PostForm.Get("client_id"))
|
|
}
|
|
if req.PostForm.Get("device_code") != "DEVICE-CODE" {
|
|
t.Errorf("expected POST /login/oauth/access_token to supply device_code=%q, got %q", "DEVICE-CODE", req.PostForm.Get("scope"))
|
|
}
|
|
if req.PostForm.Get("grant_type") != "urn:ietf:params:oauth:grant-type:device_code" {
|
|
t.Errorf("expected POST /login/oauth/access_token to supply grant_type=%q, got %q", "urn:ietf:params:oauth:grant-type:device_code", req.PostForm.Get("grant_type"))
|
|
}
|
|
|
|
responseData := url.Values{}
|
|
requestCount++
|
|
if requestCount == 1 {
|
|
responseData.Set("error", "authorization_pending")
|
|
} else {
|
|
responseData.Set("access_token", "OTOKEN")
|
|
}
|
|
|
|
return &http.Response{
|
|
StatusCode: 200,
|
|
Body: ioutil.NopCloser(bytes.NewBufferString(responseData.Encode())),
|
|
}, nil
|
|
default:
|
|
return nil, fmt.Errorf("unstubbed HTTP request: %v", route)
|
|
}
|
|
}
|
|
httpClient := &http.Client{
|
|
Transport: roundTripper(rt),
|
|
}
|
|
|
|
slept := time.Duration(0)
|
|
var browseURL string
|
|
var browseCode string
|
|
|
|
oa := &OAuthFlow{
|
|
Hostname: "github.com",
|
|
ClientID: "CLIENT-ID",
|
|
ClientSecret: "CLIENT-SEKRIT",
|
|
Scopes: []string{"repo", "gist"},
|
|
OpenInBrowser: func(url, code string) error {
|
|
browseURL = url
|
|
browseCode = code
|
|
return nil
|
|
},
|
|
HTTPClient: httpClient,
|
|
TimeNow: time.Now,
|
|
TimeSleep: func(d time.Duration) {
|
|
slept += d
|
|
},
|
|
}
|
|
|
|
token, err := oa.ObtainAccessToken()
|
|
if err != nil {
|
|
t.Fatalf("ObtainAccessToken error: %v", err)
|
|
}
|
|
|
|
if token != "OTOKEN" {
|
|
t.Errorf("expected token %q, got %q", "OTOKEN", token)
|
|
}
|
|
if requestCount != 2 {
|
|
t.Errorf("expected 2 HTTP pings for token, got %d", requestCount)
|
|
}
|
|
if slept.String() != "10s" {
|
|
t.Errorf("expected total sleep duration of %s, got %s", "10s", slept.String())
|
|
}
|
|
if browseURL != "https://github.com/login/device" {
|
|
t.Errorf("expected to open browser at %s, got %s", "https://github.com/login/device", browseURL)
|
|
}
|
|
if browseCode != "1234-ABCD" {
|
|
t.Errorf("expected to provide user with one-time code %q, got %q", "1234-ABCD", browseCode)
|
|
}
|
|
}
|