cli/github/client.go
2019-10-08 13:25:49 +02:00

585 lines
14 KiB
Go

package github
import (
"bytes"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"os"
"os/exec"
"strings"
"time"
"github.com/github/gh-cli/version"
)
// Host name and application URL constants used by the client in this file.
const (
// GitHubHost is the public GitHub host name. normalizeHost returns it for an
// empty host and maps it (case-insensitively) to "api.github.com".
GitHubHost string = "github.com"
// OAuthAppURL is sent as the "note_url" parameter when FindOrCreateToken
// asks the API to create a new authorization.
OAuthAppURL string = "https://github.com/github/gh-cli"
)
// userAgent is the literal "GitHub CLI " followed by version.Version.
// NOTE(review): presumably sent as the User-Agent request header by the HTTP
// client defined elsewhere in this package — confirm.
var userAgent = "GitHub CLI " + version.Version
// NewClient returns a Client for the host named h. The name is wrapped in a
// fresh Host value and handed to NewClientWithHost.
func NewClient(h string) *Client {
	host := &Host{Host: h}
	return NewClientWithHost(host)
}
// NewClientWithHost returns a Client bound to the given host. The pointer is
// stored as-is (not copied), and no API client is created until one is needed.
func NewClientWithHost(host *Host) *Client {
	c := new(Client)
	c.Host = host
	return c
}
// Client performs GitHub API calls on behalf of a single Host.
type Client struct {
// Host supplies the host name, access token, protocol and unix socket read
// by the methods in this file; ensureAccessToken may replace it.
Host *Host
// cachedClient is the simpleClient built by the first successful simpleApi
// call; later simpleApi calls return it instead of building a new one.
cachedClient *simpleClient
}
// FetchPullRequests lists pull requests of project, following "next" page
// links until the listing is exhausted or limit matching entries were
// collected.
//
// filterParams, when non-nil, is appended to the request as query parameters.
// filter, when non-nil, decides whether a pull request is kept. A limit of
// zero or less means "no limit". The page size comes from perPage(limit, 100).
//
// On failure the pull requests gathered so far are returned together with the
// error (nil if the API client itself could not be obtained).
func (client *Client) FetchPullRequests(project *Project, filterParams map[string]interface{}, limit int, filter func(*PullRequest) bool) (pulls []PullRequest, err error) {
	api, err := client.simpleApi()
	if err != nil {
		return nil, err
	}

	nextPath := fmt.Sprintf("repos/%s/%s/pulls?per_page=%d", project.Owner, project.Name, perPage(limit, 100))
	if filterParams != nil {
		nextPath = addQuery(nextPath, filterParams)
	}

	pulls = []PullRequest{}
	for nextPath != "" {
		page, fetchErr := api.GetFile(nextPath, draftsType)
		if err = checkStatus(200, "fetching pull requests", page, fetchErr); err != nil {
			return pulls, err
		}
		// Read the pagination link before decoding; an empty link ends the loop.
		nextPath = page.Link("next")

		batch := make([]PullRequest, 0)
		if err = page.Unmarshal(&batch); err != nil {
			return pulls, err
		}

		for _, candidate := range batch {
			if filter != nil && !filter(&candidate) {
				continue
			}
			pulls = append(pulls, candidate)
			if limit > 0 && len(pulls) == limit {
				// Enough results: stop both this page and the pagination.
				nextPath = ""
				break
			}
		}
	}
	return pulls, err
}
// PullRequest fetches a single pull request of project. id is interpolated
// verbatim into the "repos/<owner>/<name>/pulls/<id>" path.
//
// A non-200 response or transport failure yields a nil pull request and an
// error; a decoding failure yields the partially decoded value and the error.
func (client *Client) PullRequest(project *Project, id string) (pr *PullRequest, err error) {
	api, err := client.simpleApi()
	if err != nil {
		return nil, err
	}

	endpoint := fmt.Sprintf("repos/%s/%s/pulls/%s", project.Owner, project.Name, id)
	res, getErr := api.Get(endpoint)
	if err = checkStatus(200, "getting pull request", res, getErr); err != nil {
		return nil, err
	}

	pr = new(PullRequest)
	err = res.Unmarshal(pr)
	return pr, err
}
// CreatePullRequest posts params to the project's "pulls" endpoint (using the
// draftsType preview) and decodes the created pull request.
//
// Anything other than HTTP 201 is an error. For a 404 the error text is
// extended with a hint naming the project's web URL without its scheme.
func (client *Client) CreatePullRequest(project *Project, params map[string]interface{}) (pr *PullRequest, err error) {
	api, err := client.simpleApi()
	if err != nil {
		return nil, err
	}

	endpoint := fmt.Sprintf("repos/%s/%s/pulls", project.Owner, project.Name)
	res, postErr := api.PostJSONPreview(endpoint, params, draftsType)
	err = checkStatus(201, "creating pull request", res, postErr)
	if err != nil {
		if res != nil && res.StatusCode == 404 {
			// Drop the "scheme://" prefix so the hint shows host/owner/name.
			urlParts := strings.SplitN(project.WebURL("", "", ""), "://", 2)
			err = fmt.Errorf("%s\nAre you sure that %s exists?", err, urlParts[1])
		}
		return nil, err
	}

	pr = new(PullRequest)
	err = res.Unmarshal(pr)
	return pr, err
}
// UpdatePullRequest sends params as a JSON PATCH to pr.ApiUrl and returns the
// pull request decoded from the response. Any status other than 200 is
// reported as an error with a nil result.
func (client *Client) UpdatePullRequest(pr *PullRequest, params map[string]interface{}) (updatedPullRequest *PullRequest, err error) {
	api, err := client.simpleApi()
	if err != nil {
		return nil, err
	}

	res, patchErr := api.PatchJSON(pr.ApiUrl, params)
	if err = checkStatus(200, "updating pull request", res, patchErr); err != nil {
		return nil, err
	}

	updatedPullRequest = new(PullRequest)
	err = res.Unmarshal(updatedPullRequest)
	return updatedPullRequest, err
}
// RequestReview posts params to the "requested_reviewers" endpoint of pull
// request prNumber in project. Only HTTP 201 counts as success; the response
// body is not inspected, just closed.
func (client *Client) RequestReview(project *Project, prNumber int, params map[string]interface{}) (err error) {
	api, err := client.simpleApi()
	if err != nil {
		return err
	}

	endpoint := fmt.Sprintf("repos/%s/%s/pulls/%d/requested_reviewers", project.Owner, project.Name, prNumber)
	res, postErr := api.PostJSON(endpoint, params)
	if err = checkStatus(201, "requesting reviewer", res, postErr); err != nil {
		return err
	}

	// Nothing in the payload is needed; release the connection.
	res.Body.Close()
	return nil
}
// Repository fetches the metadata of project from "repos/<owner>/<name>".
//
// Any status other than 200, or a transport failure, yields a nil repository
// and an error. A decoding failure yields the partially decoded value and the
// error.
func (client *Client) Repository(project *Project) (repo *Repository, err error) {
	api, err := client.simpleApi()
	if err != nil {
		return nil, err
	}

	res, getErr := api.Get(fmt.Sprintf("repos/%s/%s", project.Owner, project.Name))
	if err = checkStatus(200, "getting repository info", res, getErr); err != nil {
		return nil, err
	}

	repo = &Repository{}
	// Decode into the struct itself, like the sibling methods do. Passing
	// &repo (a **Repository) would let a JSON "null" body reset repo to nil
	// and hand callers a nil result with a nil error.
	err = res.Unmarshal(repo)
	return repo, err
}
// Repository holds the repository fields this package decodes from JSON; each
// field is filled from the key named in its struct tag.
type Repository struct {
Name string `json:"name"`
FullName string `json:"full_name"`
// Parent is a nested Repository decoded from the "parent" key.
// NOTE(review): presumably the repository this one was forked from — confirm.
Parent *Repository `json:"parent"`
Owner *User `json:"owner"`
Private bool `json:"private"`
HasWiki bool `json:"has_wiki"`
Permissions *RepositoryPermissions `json:"permissions"`
HtmlUrl string `json:"html_url"`
DefaultBranch string `json:"default_branch"`
}
// RepositoryPermissions is the decoded form of a Repository's "permissions"
// object: three booleans read from the "admin", "push" and "pull" keys.
// NOTE(review): presumably these describe the authenticated user's access
// level — confirm against the API documentation.
type RepositoryPermissions struct {
Admin bool `json:"admin"`
Push bool `json:"push"`
Pull bool `json:"pull"`
}
// Issue is the JSON shape decoded for issues. PullRequest is declared with
// Issue as its underlying type, so the struct also carries the fields that
// appear on pull requests (Head, Base, Draft, requested reviewers, ...).
// Each field is filled from the key named in its struct tag.
type Issue struct {
Number int `json:"number"`
State string `json:"state"`
Title string `json:"title"`
Body string `json:"body"`
User *User `json:"user"`
PullRequest *PullRequest `json:"pull_request"`
// Head and Base are the two PullRequestSpec values compared by
// PullRequest.IsSameRepo.
Head *PullRequestSpec `json:"head"`
Base *PullRequestSpec `json:"base"`
MergeCommitSha string `json:"merge_commit_sha"`
MaintainerCanModify bool `json:"maintainer_can_modify"`
Draft bool `json:"draft"`
Comments int `json:"comments"`
Labels []IssueLabel `json:"labels"`
Assignees []User `json:"assignees"`
Milestone *Milestone `json:"milestone"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
MergedAt time.Time `json:"merged_at"`
// RequestedReviewers and RequestedTeams are searched by
// PullRequest.HasRequestedReviewer and PullRequest.HasRequestedTeam.
RequestedReviewers []User `json:"requested_reviewers"`
RequestedTeams []Team `json:"requested_teams"`
// ApiUrl is decoded from "url"; Client.UpdatePullRequest sends its PATCH
// request to this address.
ApiUrl string `json:"url"`
HtmlUrl string `json:"html_url"`
ClosedBy *User `json:"closed_by"`
}
// PullRequest has the same fields as Issue but is a distinct type with its own
// methods (IsSameRepo, HasRequestedReviewer, HasRequestedTeam).
type PullRequest Issue
// PullRequestSpec is the decoded form of a pull request's "head" or "base"
// object: a label, a ref name, a commit SHA and the repository they belong to.
type PullRequestSpec struct {
Label string `json:"label"`
Ref string `json:"ref"`
Sha string `json:"sha"`
// Repo is a pointer and is checked for nil (on the head side) by
// PullRequest.IsSameRepo before use.
Repo *Repository `json:"repo"`
}
// IsSameRepo reports whether the head and base of the pull request live in the
// same repository, i.e. both the repository name and the owner login match.
//
// It returns false, rather than panicking, when any of the head/base specs,
// their repositories, or the repository owners is missing from the decoded
// payload.
func (pr *PullRequest) IsSameRepo() bool {
	if pr.Head == nil || pr.Base == nil {
		return false
	}
	head, base := pr.Head.Repo, pr.Base.Repo
	if head == nil || base == nil {
		return false
	}
	if head.Owner == nil || base.Owner == nil {
		return false
	}
	return head.Name == base.Name && head.Owner.Login == base.Owner.Login
}
// HasRequestedReviewer reports whether a user whose login equals name,
// ignoring case, is among the pull request's requested reviewers.
func (pr *PullRequest) HasRequestedReviewer(name string) bool {
	reviewers := pr.RequestedReviewers
	for i := range reviewers {
		if strings.EqualFold(reviewers[i].Login, name) {
			return true
		}
	}
	return false
}
// HasRequestedTeam reports whether a team whose slug equals name, ignoring
// case, is among the pull request's requested teams. Only the slug is
// compared, never the team's display name.
func (pr *PullRequest) HasRequestedTeam(name string) bool {
	teams := pr.RequestedTeams
	for i := range teams {
		if strings.EqualFold(teams[i].Slug, name) {
			return true
		}
	}
	return false
}
// IssueLabel is one entry of an Issue's "labels" array: the label's name and
// its color, both decoded as plain strings.
type IssueLabel struct {
Name string `json:"name"`
Color string `json:"color"`
}
// User is the decoded form of a user object; only the "login" key is kept.
type User struct {
Login string `json:"login"`
}
// Team is the decoded form of a team object. PullRequest.HasRequestedTeam
// matches teams by Slug, not by Name.
type Team struct {
Name string `json:"name"`
Slug string `json:"slug"`
}
// Milestone is the decoded form of an Issue's "milestone" object: its number
// and title.
type Milestone struct {
Number int `json:"number"`
Title string `json:"title"`
}
// GenericAPIRequest performs an arbitrary API request and returns the raw
// response without checking its status.
//
// data selects how the payload is sent:
//   - map[string]interface{} with method "GET": added to path as query
//     parameters;
//   - map[string]interface{} with any other method: JSON-encoded into the
//     request body, unless the map is empty;
//   - io.Reader: used as the request body as-is;
//   - anything else (including nil): no body.
//
// When a body is present the Content-Type header is set to JSON first, so an
// entry in headers can still override it. ttl is stored in the API client's
// CacheTTL field before the request is made.
func (client *Client) GenericAPIRequest(method, path string, data interface{}, headers map[string]string, ttl int) (*simpleResponse, error) {
	api, err := client.simpleApi()
	if err != nil {
		return nil, err
	}
	api.CacheTTL = ttl

	var body io.Reader
	switch payload := data.(type) {
	case map[string]interface{}:
		switch {
		case method == "GET":
			path = addQuery(path, payload)
		case len(payload) > 0:
			encoded, err := json.Marshal(payload)
			if err != nil {
				return nil, err
			}
			body = bytes.NewBuffer(encoded)
		}
	case io.Reader:
		body = payload
	}

	hasBody := body != nil
	configure := func(req *http.Request) {
		if hasBody {
			req.Header.Set("Content-Type", "application/json; charset=utf-8")
		}
		for name, val := range headers {
			req.Header.Set(name, val)
		}
	}
	return api.performRequest(method, path, body, configure)
}
// CurrentUser fetches the "user" endpoint and decodes the authenticated user.
// Any status other than 200, or a transport failure, yields a nil user and an
// error.
func (client *Client) CurrentUser() (user *User, err error) {
	api, err := client.simpleApi()
	if err != nil {
		return nil, err
	}

	res, getErr := api.Get("user")
	if err = checkStatus(200, "getting current user", res, getErr); err != nil {
		return nil, err
	}

	user = new(User)
	err = res.Unmarshal(user)
	return user, err
}
// AuthorizationEntry is the part of the "authorizations" response that
// FindOrCreateToken reads: the value of the "token" key.
type AuthorizationEntry struct {
Token string `json:"token"`
}
// isToken reports whether password is accepted by the API as an access token.
// It sends password in a "token" Authorization header to the "user" endpoint
// and returns true only for an HTTP 200 answer.
//
// Side effect: api.PrepareRequest is replaced and left pointing at the
// token-authenticating hook.
func isToken(api *simpleClient, password string) bool {
	api.PrepareRequest = func(req *http.Request) {
		req.Header.Set("Authorization", "token "+password)
	}

	// The request error is ignored on purpose: any failure simply means the
	// password could not be confirmed as a token.
	res, _ := api.Get("user")
	if res == nil {
		return false
	}
	// Only the status matters; close the body so the connection is released.
	if res.Body != nil {
		res.Body.Close()
	}
	return res.StatusCode == 200
}
// FindOrCreateToken returns an access token for user.
//
// If password is at least 40 characters long and the API accepts it as a
// token (see isToken), it is returned unchanged. Otherwise a new authorization
// with the "repo" scope is requested using basic auth, plus the X-GitHub-OTP
// header when twoFactorCode is non-empty.
//
// The authorization note comes from authTokenNote. When the API answers 422
// the request is retried with the next note number, for at most maxTries
// attempts in total. Any other non-201 answer is returned as an error.
func (client *Client) FindOrCreateToken(user, password, twoFactorCode string) (token string, err error) {
	api := client.apiClient()

	if len(password) >= 40 && isToken(api, password) {
		return password, nil
	}

	params := map[string]interface{}{
		"scopes":   []string{"repo"},
		"note_url": OAuthAppURL,
	}

	api.PrepareRequest = func(req *http.Request) {
		req.SetBasicAuth(user, password)
		if twoFactorCode != "" {
			req.Header.Set("X-GitHub-OTP", twoFactorCode)
		}
	}

	const maxTries = 9
	for count := 1; ; {
		params["note"], err = authTokenNote(count)
		if err != nil {
			return
		}

		res, postErr := api.PostJSON("authorizations", params)
		if postErr != nil {
			err = postErr
			return
		}

		switch {
		case res.StatusCode == 201:
			auth := &AuthorizationEntry{}
			if err = res.Unmarshal(auth); err != nil {
				return
			}
			token = auth.Token
			return
		case res.StatusCode == 422 && count < maxTries:
			// This response is discarded; close its body before retrying so
			// each attempt does not leak a connection.
			if res.Body != nil {
				res.Body.Close()
			}
			count++
		default:
			errInfo, e := res.ErrorInfo()
			if e == nil {
				err = errInfo
			} else {
				err = e
			}
			return
		}
	}
}
// ensureAccessToken makes sure client.Host carries an access token. When the
// token is empty, the current configuration is asked (PromptForHost) for a
// host entry matching the client's host name, and that entry replaces
// client.Host. A host that already has a token is left untouched.
func (client *Client) ensureAccessToken() error {
	if client.Host.AccessToken != "" {
		return nil
	}

	prompted, err := CurrentConfig().PromptForHost(client.Host.Host)
	if err != nil {
		return err
	}
	client.Host = prompted
	return nil
}
// simpleApi returns the authenticated API client for this Client, building and
// caching it on first use. An access token is ensured on every call, before
// the cache is consulted.
//
// The installed PrepareRequest hook adds the "token" Authorization header only
// to requests whose URL host is the client's domain or a subdomain of it. For
// an "api.github.*" host the leading "api." is dropped before comparing.
func (client *Client) simpleApi() (c *simpleClient, err error) {
	if err = client.ensureAccessToken(); err != nil {
		return nil, err
	}
	if cached := client.cachedClient; cached != nil {
		return cached, nil
	}

	c = client.apiClient()
	c.PrepareRequest = func(req *http.Request) {
		domain := normalizeHost(client.Host.Host)
		if strings.HasPrefix(domain, "api.github.") {
			domain = domain[len("api."):]
		}

		target := strings.ToLower(req.URL.Host)
		if target != domain && !strings.HasSuffix(target, "."+domain) {
			// Foreign host: never send the access token there.
			return
		}
		req.Header.Set("Authorization", "token "+client.Host.AccessToken)
	}

	client.cachedClient = c
	return c, nil
}
// apiClient builds a new, unauthenticated simpleClient for the client's host.
//
// The HTTP client is configured from the HUB_TEST_HOST and HUB_VERBOSE
// environment variables and from the host's unix socket path, in which
// environment variable references are expanded. Hosts whose normalized name
// does not start with "api.github." get "/api/v3/" as the API root path.
func (client *Client) apiClient() *simpleClient {
	socketPath := os.ExpandEnv(client.Host.UnixSocket)
	testHost := os.Getenv("HUB_TEST_HOST")
	verbose := os.Getenv("HUB_VERBOSE") != ""
	transport := newHttpClient(testHost, verbose, socketPath)

	root := client.absolute(normalizeHost(client.Host.Host))
	if !strings.HasPrefix(root.Host, "api.github.") {
		root.Path = "/api/v3/"
	}

	api := &simpleClient{}
	api.httpClient = transport
	api.rootUrl = root
	return api
}
// absolute returns the root URL "https://<host>/" for host. When the client
// has a Host with a non-empty Protocol, that protocol replaces the "https"
// scheme. It panics if the resulting URL cannot be parsed.
func (client *Client) absolute(host string) *url.URL {
	root, err := url.Parse("https://" + host + "/")
	if err != nil {
		panic(err)
	}
	if h := client.Host; h != nil && h.Protocol != "" {
		root.Scheme = h.Protocol
	}
	return root
}
// normalizeHost maps a configured host name to the host used for API calls:
//
//   - "" yields GitHubHost unchanged;
//   - GitHubHost (any letter case) yields "api.github.com";
//   - "github.localhost" (any letter case) yields "api.github.localhost";
//   - every other host is returned lower-cased.
func normalizeHost(host string) string {
	switch {
	case host == "":
		return GitHubHost
	case strings.EqualFold(host, GitHubHost):
		return "api.github.com"
	case strings.EqualFold(host, "github.localhost"):
		return "api.github.localhost"
	default:
		return strings.ToLower(host)
	}
}
// checkStatus turns the outcome of an API call into a single error value.
//
//   - A non-nil err is wrapped as "Error <action>: <err>"; response is not
//     touched in that case.
//   - A response whose status equals expectedStatus yields nil.
//   - Any other status yields the response's error info passed through
//     FormatError or, if that info cannot be obtained, an error naming the
//     failure and the HTTP status.
func checkStatus(expectedStatus int, action string, response *simpleResponse, err error) error {
	if err != nil {
		return fmt.Errorf("Error %s: %s", action, err.Error())
	}
	if response.StatusCode == expectedStatus {
		return nil
	}

	info, infoErr := response.ErrorInfo()
	if infoErr != nil {
		return fmt.Errorf("Error %s: %s (HTTP %d)", action, infoErr.Error(), response.StatusCode)
	}
	return FormatError(action, info)
}
// FormatError renders an API error as a human-readable error for action.
//
// Errors that are not *errorInfo (including nil) are returned unchanged. For
// an *errorInfo the first line is "Error <action>: <reason> (HTTP <code>)",
// where reason is the text after the status code in the response's status
// line. It is followed by one sentence per recognized field error ("custom",
// "missing_field", "already_exists", "invalid", "unauthorized"); field errors
// with any other code are skipped. Without any such sentence the API's
// top-level message is used instead, with an extra hint when the current user
// could not be fetched by an integration.
func FormatError(action string, err error) (ee error) {
	e, ok := err.(*errorInfo)
	if !ok {
		return err
	}

	var reason string
	if parts := strings.SplitN(e.Response.Status, " ", 2); len(parts) >= 2 {
		reason = strings.TrimSpace(parts[1])
	}
	errStr := fmt.Sprintf("Error %s: %s (HTTP %d)", action, reason, e.Response.StatusCode)

	var sentences []string
	for _, fe := range e.Errors {
		switch fe.Code {
		case "custom":
			sentences = append(sentences, fe.Message)
		case "missing_field":
			sentences = append(sentences, fmt.Sprintf("Missing field: \"%s\"", fe.Field))
		case "already_exists":
			sentences = append(sentences, fmt.Sprintf("Duplicate value for \"%s\"", fe.Field))
		case "invalid":
			sentences = append(sentences, fmt.Sprintf("Invalid value for \"%s\"", fe.Field))
		case "unauthorized":
			sentences = append(sentences, fmt.Sprintf("Not allowed to change field \"%s\"", fe.Field))
		}
	}

	var detail string
	if len(sentences) > 0 {
		detail = strings.Join(sentences, "\n")
	} else {
		detail = e.Message
		if action == "getting current user" && e.Message == "Resource not accessible by integration" {
			detail = detail + "\nYou must specify GITHUB_USER via environment variable."
		}
	}
	if detail != "" {
		errStr = fmt.Sprintf("%s\n%s", errStr, detail)
	}

	// errStr contains server-provided text, so it must be passed as an
	// argument and never as the format string: a "%" in an API message would
	// otherwise be interpreted as a formatting verb and garble the output.
	return fmt.Errorf("%s", errStr)
}
// authTokenNote builds the note attached to a newly created authorization:
// "hub for <user>@<hostname>", with " <num>" appended when num is greater
// than 1.
//
// The user name is taken from the USER environment variable, then USERNAME,
// and finally from the output of the "whoami" command. An error is returned if
// that command or the hostname lookup fails.
func authTokenNote(num int) (string, error) {
	user := os.Getenv("USER")
	if user == "" {
		user = os.Getenv("USERNAME")
	}
	if user == "" {
		out, err := exec.Command("whoami").Output()
		if err != nil {
			return "", err
		}
		user = strings.TrimSpace(string(out))
	}

	host, err := os.Hostname()
	if err != nil {
		return "", err
	}

	note := fmt.Sprintf("hub for %s@%s", user, host)
	if num > 1 {
		note = fmt.Sprintf("%s %d", note, num)
	}
	return note, nil
}
// perPage picks the page size for a listing request. With a positive limit it
// asks for one and a half times the limit (integer arithmetic), as long as
// that stays below max. In every other case it returns max.
func perPage(limit, max int) int {
	if limit <= 0 {
		return max
	}
	padded := limit + limit/2
	if padded >= max {
		return max
	}
	return padded
}
// addQuery appends params to path as URL-encoded query parameters, joining
// with "?" or, when path already contains a "?", with "&".
//
// Only string, int, bool and nil values are encoded (nil as an empty value);
// values of any other type are skipped. An empty or nil params map returns
// path unchanged.
func addQuery(path string, params map[string]interface{}) string {
	if len(params) == 0 {
		return path
	}

	values := url.Values{}
	for name, raw := range params {
		var text string
		switch typed := raw.(type) {
		case nil:
			text = ""
		case string:
			text = typed
		case int:
			text = fmt.Sprintf("%d", typed)
		case bool:
			text = fmt.Sprintf("%v", typed)
		default:
			// Unsupported value type: leave the parameter out.
			continue
		}
		values.Add(name, text)
	}

	joiner := "&"
	if !strings.Contains(path, "?") {
		joiner = "?"
	}
	return path + joiner + values.Encode()
}