From de85294c790a54b8acb6606536edf1f1094027ef Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mislav=20Marohni=C4=87?= Date: Wed, 9 Oct 2019 16:48:51 +0200 Subject: [PATCH] Extract OAuth logic into a struct --- auth/oauth.go | 41 ++++++++++++++++++++++++++++------------- 1 file changed, 28 insertions(+), 13 deletions(-) diff --git a/auth/oauth.go b/auth/oauth.go index 63a9c1d01..a0978aea2 100644 --- a/auth/oauth.go +++ b/auth/oauth.go @@ -11,6 +11,18 @@ import ( "os/exec" ) +func main() { + oa := OAuthFlow{ + ClientID: os.Getenv("GH_OAUTH_CLIENT_ID"), + ClientSecret: os.Getenv("GH_OAUTH_CLIENT_SECRET"), + } + token, err := oa.ObtainAccessToken() + if err != nil { + panic(err) + } + fmt.Printf("OAuth access token: %s\n", token) +} + func randomString(length int) (string, error) { b := make([]byte, length/2) _, err := rand.Read(b) @@ -20,21 +32,23 @@ func randomString(length int) (string, error) { return fmt.Sprintf("%x", b), nil } -func main() { - state, _ := randomString(20) +type OAuthFlow struct { + ClientID string + ClientSecret string +} - clientID := os.Getenv("GH_OAUTH_CLIENT_ID") - clientSecret := os.Getenv("GH_OAUTH_CLIENT_SECRET") +func (oa *OAuthFlow) ObtainAccessToken() (accessToken string, err error) { + state, _ := randomString(20) code := "" listener, err := net.Listen("tcp", "localhost:0") if err != nil { - panic(err) + return } port := listener.Addr().(*net.TCPAddr).Port q := url.Values{} - q.Set("client_id", clientID) + q.Set("client_id", oa.ClientID) q.Set("redirect_uri", fmt.Sprintf("http://localhost:%d", port)) q.Set("scope", "repo") q.Set("state", state) @@ -42,7 +56,7 @@ func main() { cmd := exec.Command("open", fmt.Sprintf("https://github.com/login/oauth/authorize?%s", q.Encode())) err = cmd.Run() if err != nil { - panic(err) + return } http.Serve(listener, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -59,23 +73,24 @@ func main() { resp, err := http.PostForm("https://github.com/login/oauth/access_token", url.Values{ - "client_id": {clientID}, - "client_secret": {clientSecret}, + "client_id": {oa.ClientID}, + "client_secret": {oa.ClientSecret}, "code": {code}, "state": {state}, }) if err != nil { - panic(err) + return } defer resp.Body.Close() body, err := ioutil.ReadAll(resp.Body) if err != nil { - panic(err) + return } tokenValues, err := url.ParseQuery(string(body)) if err != nil { - panic(err) + return } - fmt.Printf("OAuth access token: %s\n", tokenValues.Get("access_token")) + accessToken = tokenValues.Get("access_token") + return }