cli/auth/oauth.go
Mislav Marohnić cb4cc72e50 Handle HTTP 422 response to OAuth Device flow detection
If HTTP 422 is encountered, assume that OAuth Device Flow is unavailable
and fall back to OAuth app authorization flow.
2020-08-31 22:26:04 +02:00

263 lines
6.8 KiB
Go

package auth
import (
"crypto/rand"
"encoding/hex"
"errors"
"fmt"
"io"
"io/ioutil"
"net"
"net/http"
"net/url"
"strconv"
"strings"
"time"
"github.com/cli/cli/internal/ghinstance"
)
// randomString returns exactly length lowercase hexadecimal characters built
// from crypto/rand bytes. A non-positive length yields "" with no error.
func randomString(length int) (string, error) {
	if length <= 0 {
		return "", nil
	}
	// Each random byte encodes to two hex characters: round the byte count up
	// so odd lengths are honored, then trim the one surplus character.
	b := make([]byte, (length+1)/2)
	if _, err := rand.Read(b); err != nil {
		return "", err
	}
	return hex.EncodeToString(b)[:length], nil
}
// OAuthFlow represents the setup for authenticating with GitHub.
//
// ObtainAccessToken tries OAuth Device Flow first and falls back to a
// browser-redirect flow served by a local callback server. HTTPClient and
// OpenInBrowser are called without a nil check and must be set; the other
// function and writer fields are optional.
type OAuthFlow struct {
	// Hostname is the host used to build the https:// login endpoints; it is
	// also passed to ghinstance.IsEnterprise to choose the callback URL form.
	Hostname string
	// ClientID is sent as client_id in every request and in the authorize URL.
	ClientID string
	// ClientSecret is sent only in the code-for-token exchange of the
	// local-server (non-device) flow.
	ClientSecret string
	// Scopes are joined with spaces into the scope parameter. The local-server
	// flow falls back to "repo" when Scopes is nil.
	Scopes []string
	// OpenInBrowser receives (verification URI, user code) during Device Flow
	// and (authorize URL, "") during the local-server flow.
	OpenInBrowser func(string, string) error
	// WriteSuccessHTML, if set, writes the page shown to the browser after a
	// successful callback; otherwise a minimal built-in message is written.
	WriteSuccessHTML func(io.Writer)
	// VerboseStream, if set, receives debug logging of requests and callbacks.
	VerboseStream io.Writer
	// HTTPClient performs every POST to the login endpoints.
	HTTPClient *http.Client
	// TimeNow and TimeSleep drive Device Flow polling and expiry; they default
	// to time.Now and time.Sleep when nil.
	TimeNow   func() time.Time
	TimeSleep func(time.Duration)
}
// ObtainAccessToken guides the user through the browser OAuth flow on GitHub
// and returns the OAuth access token upon completion.
//
// It first requests a device code from https://<Hostname>/login/device/code.
// If that endpoint answers with HTTP 401, 403, 404 or 422, or with HTTP 400 and
// error=unauthorized_client, OAuth Device Flow is treated as unavailable and
// the browser-redirect flow with a local callback server is used instead. Any
// other non-200 status is returned as an error. Otherwise the verification URI
// and user code are handed to OpenInBrowser, and the token endpoint is polled
// every `interval` seconds until a token is issued, an error is reported, or
// `expires_in` seconds have elapsed.
func (oa *OAuthFlow) ObtainAccessToken() (accessToken string, err error) {
	initURL := fmt.Sprintf("https://%s/login/device/code", oa.Hostname)
	tokenURL := fmt.Sprintf("https://%s/login/oauth/access_token", oa.Hostname)

	// first, check if OAuth Device Flow is supported
	oa.logf("POST %s\n", initURL)
	resp, err := oa.HTTPClient.PostForm(initURL, url.Values{
		"client_id": {oa.ClientID},
		"scope":     {strings.Join(oa.Scopes, " ")},
	})
	if err != nil {
		return
	}
	bb, err := ioutil.ReadAll(resp.Body)
	// The body is fully consumed here. Close it now rather than deferring, so
	// the response is not held open for the whole polling loop or fallback flow.
	resp.Body.Close()
	if err != nil {
		return
	}

	// Only 200 responses and form-encoded error responses carry parameters
	// worth inspecting; any other body is ignored.
	var values url.Values
	if resp.StatusCode == 200 || strings.HasPrefix(resp.Header.Get("Content-Type"), "application/x-www-form-urlencoded") {
		values, err = url.ParseQuery(string(bb))
		if err != nil {
			return
		}
	}

	// values may still be nil here; url.Values.Get on a nil map returns "".
	switch status := resp.StatusCode; {
	case status == 401, status == 403, status == 404, status == 422,
		status == 400 && values.Get("error") == "unauthorized_client":
		// OAuth Device Flow is not available; continue with OAuth browser flow with a
		// local server endpoint as callback target
		return oa.localServerFlow()
	case status != 200:
		if desc := values.Get("error_description"); desc != "" {
			return "", fmt.Errorf("HTTP %d: %s (%s)", status, desc, initURL)
		}
		return "", fmt.Errorf("error: HTTP %d (%s)", status, initURL)
	}

	timeNow := oa.TimeNow
	if timeNow == nil {
		timeNow = time.Now
	}
	timeSleep := oa.TimeSleep
	if timeSleep == nil {
		timeSleep = time.Sleep
	}

	// RFC 8628 §3.2: "interval" is optional and clients must use 5 seconds when
	// it is absent. A value that is present but malformed is still an error.
	checkInterval := 5 * time.Second
	if raw := values.Get("interval"); raw != "" {
		intervalSeconds, convErr := strconv.Atoi(raw)
		if convErr != nil {
			return "", fmt.Errorf("could not parse interval=%q as integer: %w", raw, convErr)
		}
		checkInterval = time.Duration(intervalSeconds) * time.Second
	}

	expiresIn, err := strconv.Atoi(values.Get("expires_in"))
	if err != nil {
		return "", fmt.Errorf("could not parse expires_in=%q as integer: %w", values.Get("expires_in"), err)
	}
	expiresAt := timeNow().Add(time.Duration(expiresIn) * time.Second)

	err = oa.OpenInBrowser(values.Get("verification_uri"), values.Get("user_code"))
	if err != nil {
		return
	}

	deviceCode := values.Get("device_code")
	for {
		timeSleep(checkInterval)
		accessToken, err = oa.deviceFlowPing(tokenURL, deviceCode)
		if err != nil || accessToken != "" {
			return
		}
		// ("", nil) means the user has not finished authorizing yet.
		if timeNow().After(expiresAt) {
			return "", errors.New("authentication timed out")
		}
	}
}
// deviceFlowPing makes one request to the token endpoint for deviceCode.
//
// It returns the access token once one has been issued, ("", nil) while the
// user has not yet completed authorization (error=authorization_pending), and
// a non-nil error otherwise. Form-encoded non-200 responses are inspected as
// well as 200 responses, because RFC 6749 §5.2 servers deliver these OAuth
// errors (including authorization_pending) with HTTP 400.
func (oa *OAuthFlow) deviceFlowPing(tokenURL, deviceCode string) (accessToken string, err error) {
	oa.logf("POST %s\n", tokenURL)
	resp, err := oa.HTTPClient.PostForm(tokenURL, url.Values{
		"client_id":   {oa.ClientID},
		"device_code": {deviceCode},
		"grant_type":  {"urn:ietf:params:oauth:grant-type:device_code"},
	})
	if err != nil {
		return "", err
	}
	defer resp.Body.Close()

	bb, err := ioutil.ReadAll(resp.Body)
	if err != nil {
		return "", err
	}

	if resp.StatusCode != 200 {
		// Only trust an error body that is declared and parses as a form.
		if strings.HasPrefix(resp.Header.Get("Content-Type"), "application/x-www-form-urlencoded") {
			if errValues, parseErr := url.ParseQuery(string(bb)); parseErr == nil {
				if errValues.Get("error") == "authorization_pending" {
					return "", nil
				}
				if desc := errValues.Get("error_description"); desc != "" {
					return "", fmt.Errorf("HTTP %d: %s (%s)", resp.StatusCode, desc, tokenURL)
				}
			}
		}
		return "", fmt.Errorf("error: HTTP %d (%s)", resp.StatusCode, tokenURL)
	}

	values, err := url.ParseQuery(string(bb))
	if err != nil {
		return "", err
	}

	if token := values.Get("access_token"); token != "" {
		return token, nil
	}
	if values.Get("error") == "authorization_pending" {
		return "", nil
	}
	if desc := values.Get("error_description"); desc != "" {
		return "", errors.New(desc)
	}
	return "", errors.New("OAuth device flow error")
}
// localServerFlow runs the OAuth browser-redirect flow used when Device Flow is
// unavailable. It listens on an ephemeral 127.0.0.1 port, opens the browser at
// the host's authorize URL with that port as redirect target, waits for the
// callback carrying the authorization code, and exchanges the code (together
// with ClientSecret) for an access token at the host's token endpoint.
func (oa *OAuthFlow) localServerFlow() (accessToken string, err error) {
	// state ties the callback to this request. An empty state would let the
	// check below accept callbacks carrying no state at all, so failing to
	// generate one must abort the flow.
	state, err := randomString(20)
	if err != nil {
		return "", fmt.Errorf("could not generate OAuth state: %w", err)
	}

	listener, err := net.Listen("tcp4", "127.0.0.1:0")
	if err != nil {
		return
	}
	port := listener.Addr().(*net.TCPAddr).Port

	scopes := "repo"
	if oa.Scopes != nil {
		scopes = strings.Join(oa.Scopes, " ")
	}

	localhost := "127.0.0.1"
	callbackPath := "/callback"
	if ghinstance.IsEnterprise(oa.Hostname) {
		// the OAuth app on Enterprise hosts is still registered with a legacy callback URL
		// see https://github.com/cli/cli/pull/222, https://github.com/cli/cli/pull/650
		localhost = "localhost"
		callbackPath = "/"
	}

	q := url.Values{}
	q.Set("client_id", oa.ClientID)
	q.Set("redirect_uri", fmt.Sprintf("http://%s:%d%s", localhost, port, callbackPath))
	q.Set("scope", scopes)
	q.Set("state", state)

	startURL := fmt.Sprintf("https://%s/login/oauth/authorize?%s", oa.Hostname, q.Encode())
	oa.logf("open %s\n", startURL)
	err = oa.OpenInBrowser(startURL, "")
	if err != nil {
		// The callback server below will never run, so release the port now.
		listener.Close()
		return
	}

	// The handler sets code (or callbackErr) and closes the listener on the
	// first request to callbackPath; closing the listener makes Serve return.
	code := ""
	callbackErr := ""
	_ = http.Serve(listener, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		oa.logf("server handler: %s\n", r.URL.Path)
		if r.URL.Path != callbackPath {
			w.WriteHeader(404)
			return
		}
		defer listener.Close()
		rq := r.URL.Query()
		if state != rq.Get("state") {
			callbackErr = "state mismatch"
			fmt.Fprintf(w, "Error: state mismatch")
			return
		}
		code = rq.Get("code")
		oa.logf("server received code %q\n", code)
		if code == "" {
			// Per RFC 6749 §4.1.2.1 a denied or failed authorization redirects
			// with error parameters instead of a code.
			callbackErr = rq.Get("error_description")
			if callbackErr == "" {
				callbackErr = rq.Get("error")
			}
			fmt.Fprint(w, "Error: authorization failed")
			return
		}
		w.Header().Add("content-type", "text/html")
		if oa.WriteSuccessHTML != nil {
			oa.WriteSuccessHTML(w)
		} else {
			fmt.Fprintf(w, "<p>You have successfully authenticated. You may now close this page.</p>")
		}
	}))

	// Exchanging an empty code can only fail; report why the callback failed.
	if code == "" {
		if callbackErr == "" {
			callbackErr = "no authorization code received"
		}
		return "", fmt.Errorf("OAuth authorization failed: %s", callbackErr)
	}

	tokenURL := fmt.Sprintf("https://%s/login/oauth/access_token", oa.Hostname)
	oa.logf("POST %s\n", tokenURL)
	resp, err := oa.HTTPClient.PostForm(tokenURL,
		url.Values{
			"client_id":     {oa.ClientID},
			"client_secret": {oa.ClientSecret},
			"code":          {code},
			"state":         {state},
		})
	if err != nil {
		return
	}
	defer resp.Body.Close()

	if resp.StatusCode != 200 {
		err = fmt.Errorf("HTTP %d error while obtaining OAuth access token", resp.StatusCode)
		return
	}

	body, err := ioutil.ReadAll(resp.Body)
	if err != nil {
		return
	}
	tokenValues, err := url.ParseQuery(string(body))
	if err != nil {
		return
	}

	accessToken = tokenValues.Get("access_token")
	if accessToken == "" {
		// Prefer the server's explanation (e.g. a rejected code) when given.
		if desc := tokenValues.Get("error_description"); desc != "" {
			err = fmt.Errorf("could not obtain access token: %s", desc)
		} else {
			err = errors.New("the access token could not be read from HTTP response")
		}
	}
	return
}
// logf formats a diagnostic message and writes it to VerboseStream; nothing
// is written when VerboseStream is nil.
func (oa *OAuthFlow) logf(format string, args ...interface{}) {
	if out := oa.VerboseStream; out != nil {
		fmt.Fprintf(out, format, args...)
	}
}