This commit is contained in:
Corey Johnson 2019-12-02 15:08:36 -08:00
parent de98dbd378
commit 1231ddd01c
7 changed files with 149 additions and 91 deletions

View file

@ -7,6 +7,7 @@ import (
"io"
"io/ioutil"
"net/http"
"path"
)
// ClientOption represents an argument to NewClient
@ -98,6 +99,39 @@ func (c Client) GraphQL(query string, variables map[string]interface{}, data int
return handleResponse(resp, data)
}
// REST performs a REST request and parses the response.
func (c Client) REST(method string, p string, data interface{}) error {
url := path.Join("https://api.github.com/", p)
req, err := http.NewRequest(method, url, nil)
if err != nil {
return err
}
req.Header.Set("Content-Type", "application/json; charset=utf-8")
resp, err := c.http.Do(req)
if err != nil {
return err
}
defer resp.Body.Close()
success := resp.StatusCode >= 200 && resp.StatusCode < 300
if !success {
return handleHTTPError(resp)
}
body, err := ioutil.ReadAll(resp.Body)
if err != nil {
return err
}
err = json.Unmarshal(body, &data)
if err != nil {
return err
}
return nil
}
func handleResponse(resp *http.Response, data interface{}) error {
success := resp.StatusCode >= 200 && resp.StatusCode < 300

View file

@ -21,9 +21,9 @@ var BuildDate = "YYYY-MM-DD"
func init() {
RootCmd.Version = fmt.Sprintf("%s (%s)", strings.TrimPrefix(Version, "v"), BuildDate)
RootCmd.AddCommand(versionCmd)
RootCmd.AddCommand(versionCmd)
RootCmd.PersistentFlags().StringP("repo", "R", "", "Current GitHub repository")
RootCmd.PersistentFlags().StringP("repo", "R", "", "Current GitHub repository")
RootCmd.PersistentFlags().Bool("help", false, "Show help for command")
RootCmd.Flags().Bool("version", false, "Print gh version")
// TODO:
@ -66,6 +66,10 @@ var initContext = func() context.Context {
return ctx
}
func BasicClient() (*api.Client, error) {
return apiClientForContext(initContext())
}
func contextForCommand(cmd *cobra.Command) context.Context {
ctx := initContext()
if repo, err := cmd.Flags().GetString("repo"); err == nil && repo != "" {

36
main.go
View file

@ -10,15 +10,31 @@ import (
)
func main() {
isProduction := os.Getenv("APP_ENV") != "production"
update.RunWhileCheckingForUpdate(isProduction, func() {
if cmd, err := command.RootCmd.ExecuteC(); err != nil {
fmt.Fprintln(os.Stderr, err)
_, isFlagError := err.(command.FlagError)
if isFlagError || strings.HasPrefix(err.Error(), "unknown command ") {
fmt.Fprintln(os.Stderr, cmd.UsageString())
}
os.Exit(1)
alertMsgChan := make(chan *string)
go updateInBackground(alertMsgChan)
if cmd, err := command.RootCmd.ExecuteC(); err != nil {
fmt.Fprintln(os.Stderr, err)
_, isFlagError := err.(command.FlagError)
if isFlagError || strings.HasPrefix(err.Error(), "unknown command ") {
fmt.Fprintln(os.Stderr, cmd.UsageString())
}
})
os.Exit(1)
}
alertMsg := <-alertMsgChan
if alertMsg != nil {
fmt.Fprintf(os.Stderr, *alertMsg)
}
}
func updateInBackground(alertMsgChan chan *string) {
client, err := command.BasicClient()
if err != nil {
alertMsgChan <- nil
return
}
alertMsg := update.UpdateMessage(client)
alertMsgChan <- alertMsg
}

4
test/fixtures/latestRelease.json vendored Normal file
View file

@ -0,0 +1,4 @@
{
"tag_name": "v1.0.0",
"html_url": "https://www.spacejam.com/archive/spacejam/movie/jam.htm"
}

View file

@ -1,79 +0,0 @@
package update
import (
"encoding/json"
"fmt"
"io/ioutil"
"net/http"
"os"
"github.com/github/gh-cli/command"
"github.com/github/gh-cli/utils"
"golang.org/x/crypto/ssh/terminal"
)
const nwo = "github/homebrew-gh"
type releaseInfo struct {
Version string `json:"tag_name"`
URL string `json:"html_url"`
}
func RunWhileCheckingForUpdate(isProduction bool, f func()) {
if isProduction {
f()
return
}
newReleaseChan := make(chan *releaseInfo)
go checkForUpdate(newReleaseChan)
f()
newRelease := <-newReleaseChan
if newRelease != nil {
fmt.Printf(utils.Cyan(`
A new version of gh is available! %s %s
Changelog: %s
Run 'brew upgrade gh' to update!`)+"\n\n", command.Version, newRelease.Version, newRelease.URL)
}
}
func checkForUpdate(newReleaseChan chan *releaseInfo) {
// Ignore if this stdout is not a tty
if !terminal.IsTerminal(int(os.Stdout.Fd())) {
newReleaseChan <- nil
return
}
latestRelease, err := getLatestRelease()
if err != nil {
newReleaseChan <- nil
return
}
updateAvailable := latestRelease.Version != command.Version
if updateAvailable {
newReleaseChan <- latestRelease
} else {
newReleaseChan <- nil
}
}
func getLatestRelease() (*releaseInfo, error) {
url := fmt.Sprintf("https://api.github.com/repos/%s/releases/latest", nwo)
resp, err := http.Get(url)
if err != nil {
return nil, err
}
defer resp.Body.Close()
data, err := ioutil.ReadAll(resp.Body)
if err != nil {
return nil, err
}
var r releaseInfo
json.Unmarshal(data, &r)
return &r, nil
}

43
update/update.go Normal file
View file

@ -0,0 +1,43 @@
package update
import (
"fmt"
"os"
"github.com/github/gh-cli/api"
"github.com/github/gh-cli/command"
"github.com/github/gh-cli/utils"
)
const nwo = "github/homebrew-gh"
type releaseInfo struct {
Version string `json:"tag_name"`
URL string `json:"html_url"`
}
func UpdateMessage(client *api.Client) *string {
latestRelease, err := getLatestRelease(client)
if err != nil {
fmt.Fprintf(os.Stderr, "%+v", err)
return nil
}
updateAvailable := latestRelease.Version != command.Version
if updateAvailable {
alertMsg := fmt.Sprintf(utils.Cyan(`
A new version of gh is available! %s %s
Changelog: %s
Run 'brew upgrade gh' to update!`)+"\n\n", command.Version, latestRelease.Version, latestRelease.URL)
return &alertMsg
}
return nil
}
func getLatestRelease(client *api.Client) (*releaseInfo, error) {
path := fmt.Sprintf("repos/%s/releases/latest", nwo)
var r releaseInfo
err := client.REST("GET", path, &r)
return &r, err
}

36
update/update_test.go Normal file
View file

@ -0,0 +1,36 @@
package update
import (
"fmt"
"os"
"strings"
"testing"
"github.com/github/gh-cli/api"
"github.com/github/gh-cli/command"
)
func TestRunWhileCheckingForUpdate(t *testing.T) {
originalVersion := command.Version
command.Version = "v0.0.0"
defer func() {
command.Version = originalVersion
}()
http := &api.FakeHTTP{}
jsonFile, _ := os.Open("../test/fixtures/latestRelease.json")
defer jsonFile.Close()
http.StubResponse(200, jsonFile)
client := api.NewClient(api.ReplaceTripper(http))
alertMsg := *UpdateMessage(client)
fmt.Printf("🌭 %+v\n", alertMsg)
if !strings.Contains(alertMsg, command.Version) {
t.Errorf("expected: \"%v\" to contain \"%v\"", alertMsg, command.Version)
}
if !strings.Contains(alertMsg, "v1.0.0") {
t.Errorf("expected: \"%v\" to contain \"%v\"", alertMsg, "v1.0.0")
}
}