Add test
This commit is contained in:
parent
de98dbd378
commit
1231ddd01c
7 changed files with 149 additions and 91 deletions
|
|
@ -7,6 +7,7 @@ import (
|
|||
"io"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"path"
|
||||
)
|
||||
|
||||
// ClientOption represents an argument to NewClient
|
||||
|
|
@ -98,6 +99,39 @@ func (c Client) GraphQL(query string, variables map[string]interface{}, data int
|
|||
return handleResponse(resp, data)
|
||||
}
|
||||
|
||||
// REST performs a REST request and parses the response.
|
||||
func (c Client) REST(method string, p string, data interface{}) error {
|
||||
url := path.Join("https://api.github.com/", p)
|
||||
req, err := http.NewRequest(method, url, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json; charset=utf-8")
|
||||
|
||||
resp, err := c.http.Do(req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
success := resp.StatusCode >= 200 && resp.StatusCode < 300
|
||||
if !success {
|
||||
return handleHTTPError(resp)
|
||||
}
|
||||
|
||||
body, err := ioutil.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = json.Unmarshal(body, &data)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func handleResponse(resp *http.Response, data interface{}) error {
|
||||
success := resp.StatusCode >= 200 && resp.StatusCode < 300
|
||||
|
||||
|
|
|
|||
|
|
@ -21,9 +21,9 @@ var BuildDate = "YYYY-MM-DD"
|
|||
|
||||
func init() {
|
||||
RootCmd.Version = fmt.Sprintf("%s (%s)", strings.TrimPrefix(Version, "v"), BuildDate)
|
||||
RootCmd.AddCommand(versionCmd)
|
||||
RootCmd.AddCommand(versionCmd)
|
||||
|
||||
RootCmd.PersistentFlags().StringP("repo", "R", "", "Current GitHub repository")
|
||||
RootCmd.PersistentFlags().StringP("repo", "R", "", "Current GitHub repository")
|
||||
RootCmd.PersistentFlags().Bool("help", false, "Show help for command")
|
||||
RootCmd.Flags().Bool("version", false, "Print gh version")
|
||||
// TODO:
|
||||
|
|
@ -66,6 +66,10 @@ var initContext = func() context.Context {
|
|||
return ctx
|
||||
}
|
||||
|
||||
func BasicClient() (*api.Client, error) {
|
||||
return apiClientForContext(initContext())
|
||||
}
|
||||
|
||||
func contextForCommand(cmd *cobra.Command) context.Context {
|
||||
ctx := initContext()
|
||||
if repo, err := cmd.Flags().GetString("repo"); err == nil && repo != "" {
|
||||
|
|
|
|||
36
main.go
36
main.go
|
|
@ -10,15 +10,31 @@ import (
|
|||
)
|
||||
|
||||
func main() {
|
||||
isProduction := os.Getenv("APP_ENV") != "production"
|
||||
update.RunWhileCheckingForUpdate(isProduction, func() {
|
||||
if cmd, err := command.RootCmd.ExecuteC(); err != nil {
|
||||
fmt.Fprintln(os.Stderr, err)
|
||||
_, isFlagError := err.(command.FlagError)
|
||||
if isFlagError || strings.HasPrefix(err.Error(), "unknown command ") {
|
||||
fmt.Fprintln(os.Stderr, cmd.UsageString())
|
||||
}
|
||||
os.Exit(1)
|
||||
alertMsgChan := make(chan *string)
|
||||
go updateInBackground(alertMsgChan)
|
||||
|
||||
if cmd, err := command.RootCmd.ExecuteC(); err != nil {
|
||||
fmt.Fprintln(os.Stderr, err)
|
||||
_, isFlagError := err.(command.FlagError)
|
||||
if isFlagError || strings.HasPrefix(err.Error(), "unknown command ") {
|
||||
fmt.Fprintln(os.Stderr, cmd.UsageString())
|
||||
}
|
||||
})
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
alertMsg := <-alertMsgChan
|
||||
if alertMsg != nil {
|
||||
fmt.Fprintf(os.Stderr, *alertMsg)
|
||||
}
|
||||
}
|
||||
|
||||
func updateInBackground(alertMsgChan chan *string) {
|
||||
client, err := command.BasicClient()
|
||||
if err != nil {
|
||||
alertMsgChan <- nil
|
||||
return
|
||||
}
|
||||
|
||||
alertMsg := update.UpdateMessage(client)
|
||||
alertMsgChan <- alertMsg
|
||||
}
|
||||
|
|
|
|||
4
test/fixtures/latestRelease.json
vendored
Normal file
4
test/fixtures/latestRelease.json
vendored
Normal file
|
|
@ -0,0 +1,4 @@
|
|||
{
|
||||
"tag_name": "v1.0.0",
|
||||
"html_url": "https://www.spacejam.com/archive/spacejam/movie/jam.htm"
|
||||
}
|
||||
|
|
@ -1,79 +0,0 @@
|
|||
package update
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"os"
|
||||
|
||||
"github.com/github/gh-cli/command"
|
||||
"github.com/github/gh-cli/utils"
|
||||
"golang.org/x/crypto/ssh/terminal"
|
||||
)
|
||||
|
||||
const nwo = "github/homebrew-gh"
|
||||
|
||||
type releaseInfo struct {
|
||||
Version string `json:"tag_name"`
|
||||
URL string `json:"html_url"`
|
||||
}
|
||||
|
||||
func RunWhileCheckingForUpdate(isProduction bool, f func()) {
|
||||
if isProduction {
|
||||
f()
|
||||
return
|
||||
}
|
||||
|
||||
newReleaseChan := make(chan *releaseInfo)
|
||||
go checkForUpdate(newReleaseChan)
|
||||
f()
|
||||
|
||||
newRelease := <-newReleaseChan
|
||||
if newRelease != nil {
|
||||
fmt.Printf(utils.Cyan(`
|
||||
A new version of gh is available! %s → %s
|
||||
Changelog: %s
|
||||
Run 'brew upgrade gh' to update!`)+"\n\n", command.Version, newRelease.Version, newRelease.URL)
|
||||
}
|
||||
}
|
||||
|
||||
func checkForUpdate(newReleaseChan chan *releaseInfo) {
|
||||
// Ignore if this stdout is not a tty
|
||||
if !terminal.IsTerminal(int(os.Stdout.Fd())) {
|
||||
newReleaseChan <- nil
|
||||
return
|
||||
}
|
||||
|
||||
latestRelease, err := getLatestRelease()
|
||||
if err != nil {
|
||||
newReleaseChan <- nil
|
||||
return
|
||||
}
|
||||
|
||||
updateAvailable := latestRelease.Version != command.Version
|
||||
|
||||
if updateAvailable {
|
||||
newReleaseChan <- latestRelease
|
||||
} else {
|
||||
newReleaseChan <- nil
|
||||
}
|
||||
}
|
||||
|
||||
func getLatestRelease() (*releaseInfo, error) {
|
||||
url := fmt.Sprintf("https://api.github.com/repos/%s/releases/latest", nwo)
|
||||
resp, err := http.Get(url)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
data, err := ioutil.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var r releaseInfo
|
||||
json.Unmarshal(data, &r)
|
||||
return &r, nil
|
||||
}
|
||||
43
update/update.go
Normal file
43
update/update.go
Normal file
|
|
@ -0,0 +1,43 @@
|
|||
package update
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
"github.com/github/gh-cli/api"
|
||||
"github.com/github/gh-cli/command"
|
||||
"github.com/github/gh-cli/utils"
|
||||
)
|
||||
|
||||
const nwo = "github/homebrew-gh"
|
||||
|
||||
type releaseInfo struct {
|
||||
Version string `json:"tag_name"`
|
||||
URL string `json:"html_url"`
|
||||
}
|
||||
|
||||
func UpdateMessage(client *api.Client) *string {
|
||||
latestRelease, err := getLatestRelease(client)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "%+v", err)
|
||||
return nil
|
||||
}
|
||||
|
||||
updateAvailable := latestRelease.Version != command.Version
|
||||
if updateAvailable {
|
||||
alertMsg := fmt.Sprintf(utils.Cyan(`
|
||||
A new version of gh is available! %s → %s
|
||||
Changelog: %s
|
||||
Run 'brew upgrade gh' to update!`)+"\n\n", command.Version, latestRelease.Version, latestRelease.URL)
|
||||
return &alertMsg
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func getLatestRelease(client *api.Client) (*releaseInfo, error) {
|
||||
path := fmt.Sprintf("repos/%s/releases/latest", nwo)
|
||||
var r releaseInfo
|
||||
err := client.REST("GET", path, &r)
|
||||
return &r, err
|
||||
}
|
||||
36
update/update_test.go
Normal file
36
update/update_test.go
Normal file
|
|
@ -0,0 +1,36 @@
|
|||
package update
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/github/gh-cli/api"
|
||||
"github.com/github/gh-cli/command"
|
||||
)
|
||||
|
||||
func TestRunWhileCheckingForUpdate(t *testing.T) {
|
||||
originalVersion := command.Version
|
||||
command.Version = "v0.0.0"
|
||||
defer func() {
|
||||
command.Version = originalVersion
|
||||
}()
|
||||
|
||||
http := &api.FakeHTTP{}
|
||||
jsonFile, _ := os.Open("../test/fixtures/latestRelease.json")
|
||||
defer jsonFile.Close()
|
||||
http.StubResponse(200, jsonFile)
|
||||
|
||||
client := api.NewClient(api.ReplaceTripper(http))
|
||||
alertMsg := *UpdateMessage(client)
|
||||
fmt.Printf("🌭 %+v\n", alertMsg)
|
||||
|
||||
if !strings.Contains(alertMsg, command.Version) {
|
||||
t.Errorf("expected: \"%v\" to contain \"%v\"", alertMsg, command.Version)
|
||||
}
|
||||
|
||||
if !strings.Contains(alertMsg, "v1.0.0") {
|
||||
t.Errorf("expected: \"%v\" to contain \"%v\"", alertMsg, "v1.0.0")
|
||||
}
|
||||
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue