diff --git a/cmd/gh/main.go b/cmd/gh/main.go index 077e91995..eb0ad0672 100644 --- a/cmd/gh/main.go +++ b/cmd/gh/main.go @@ -1,6 +1,7 @@ package main import ( + "context" "errors" "fmt" "io" @@ -53,17 +54,24 @@ func main() { func mainRun() exitCode { buildDate := build.Date buildVersion := build.Version - - updateMessageChan := make(chan *update.ReleaseInfo) - go func() { - rel, _ := checkForUpdate(buildVersion) - updateMessageChan <- rel - }() - hasDebug, _ := utils.IsDebugEnabled() cmdFactory := factory.New(buildVersion) stderr := cmdFactory.IOStreams.ErrOut + + ctx := context.Background() + + updateCtx, updateCancel := context.WithCancel(ctx) + defer updateCancel() + updateMessageChan := make(chan *update.ReleaseInfo) + go func() { + rel, err := checkForUpdate(updateCtx, cmdFactory, buildVersion) + if err != nil && hasDebug { + fmt.Fprintf(stderr, "warning: checking for update failed: %v", err) + } + updateMessageChan <- rel + }() + if !cmdFactory.IOStreams.ColorEnabled() { surveyCore.DisableColor = true ansi.DisableColors(true) @@ -209,7 +217,7 @@ func mainRun() exitCode { rootCmd.SetArgs(expandedArgs) - if cmd, err := rootCmd.ExecuteC(); err != nil { + if cmd, err := rootCmd.ExecuteContextC(ctx); err != nil { var pagerPipeError *iostreams.ErrClosedPagerPipe var noResultsError cmdutil.NoResultsError if err == cmdutil.SilentError { @@ -257,6 +265,7 @@ func mainRun() exitCode { return exitError } + updateCancel() // if the update checker hasn't completed by now, abort it newRelease := <-updateMessageChan if newRelease != nil { isHomebrew := isUnderHomebrew(cmdFactory.Executable()) @@ -348,21 +357,17 @@ func isCI() bool { os.Getenv("RUN_ID") != "" // TaskCluster, dsari } -func checkForUpdate(currentVersion string) (*update.ReleaseInfo, error) { +func checkForUpdate(ctx context.Context, f *cmdutil.Factory, currentVersion string) (*update.ReleaseInfo, error) { if !shouldCheckForUpdate() { return nil, nil } - httpClient, err := api.NewHTTPClient(api.HTTPClientOptions{ - AppVersion: currentVersion, - Log: os.Stderr, - }) + httpClient, err := f.HttpClient() if err != nil { return nil, err } - client := api.NewClientFromHTTP(httpClient) repo := updaterEnabled stateFilePath := filepath.Join(config.StateDir(), "state.yml") - return update.CheckForUpdate(client, stateFilePath, repo, currentVersion) + return update.CheckForUpdate(ctx, httpClient, stateFilePath, repo, currentVersion) } func isRecentRelease(publishedAt time.Time) bool { diff --git a/internal/update/update.go b/internal/update/update.go index e9ada22f6..6d69eeada 100644 --- a/internal/update/update.go +++ b/internal/update/update.go @@ -1,7 +1,11 @@ package update import ( + "context" + "encoding/json" "fmt" + "io" + "net/http" "os" "path/filepath" "regexp" @@ -9,8 +13,6 @@ import ( "strings" "time" - "github.com/cli/cli/v2/api" - "github.com/cli/cli/v2/internal/ghinstance" "github.com/hashicorp/go-version" "gopkg.in/yaml.v3" ) @@ -30,13 +32,13 @@ type StateEntry struct { } // CheckForUpdate checks whether this software has had a newer release on GitHub -func CheckForUpdate(client *api.Client, stateFilePath, repo, currentVersion string) (*ReleaseInfo, error) { +func CheckForUpdate(ctx context.Context, client *http.Client, stateFilePath, repo, currentVersion string) (*ReleaseInfo, error) { stateEntry, _ := getStateEntry(stateFilePath) if stateEntry != nil && time.Since(stateEntry.CheckedForUpdateAt).Hours() < 24 { return nil, nil } - releaseInfo, err := getLatestReleaseInfo(client, repo) + releaseInfo, err := getLatestReleaseInfo(ctx, client, repo) if err != nil { return nil, err } @@ -53,13 +55,27 @@ func CheckForUpdate(client *api.Client, stateFilePath, repo, currentVersion stri return nil, nil } -func getLatestReleaseInfo(client *api.Client, repo string) (*ReleaseInfo, error) { - var latestRelease ReleaseInfo - err := client.REST(ghinstance.Default(), "GET", fmt.Sprintf("repos/%s/releases/latest", repo), nil, &latestRelease) +func getLatestReleaseInfo(ctx context.Context, client *http.Client, repo string) (*ReleaseInfo, error) { + req, err := http.NewRequestWithContext(ctx, "GET", fmt.Sprintf("https://api.github.com/repos/%s/releases/latest", repo), nil) if err != nil { return nil, err } - + res, err := client.Do(req) + if err != nil { + return nil, err + } + defer func() { + _, _ = io.Copy(io.Discard, res.Body) + res.Body.Close() + }() + if res.StatusCode != 200 { + return nil, fmt.Errorf("unexpected HTTP %d", res.StatusCode) + } + dec := json.NewDecoder(res.Body) + var latestRelease ReleaseInfo + if err := dec.Decode(&latestRelease); err != nil { + return nil, err + } return &latestRelease, nil } diff --git a/internal/update/update_test.go b/internal/update/update_test.go index 96dce4f2a..bb514adfc 100644 --- a/internal/update/update_test.go +++ b/internal/update/update_test.go @@ -1,13 +1,13 @@ package update import ( + "context" "fmt" "log" "net/http" "os" "testing" - "github.com/cli/cli/v2/api" "github.com/cli/cli/v2/pkg/httpmock" ) @@ -75,7 +75,6 @@ func TestCheckForUpdate(t *testing.T) { reg := &httpmock.Registry{} httpClient := &http.Client{} httpmock.ReplaceTripper(httpClient, reg) - client := api.NewClientFromHTTP(httpClient) reg.Register( httpmock.REST("GET", "repos/OWNER/REPO/releases/latest"), @@ -85,7 +84,7 @@ func TestCheckForUpdate(t *testing.T) { }`, s.LatestVersion, s.LatestURL)), ) - rel, err := CheckForUpdate(client, tempFilePath(), "OWNER/REPO", s.CurrentVersion) + rel, err := CheckForUpdate(context.TODO(), httpClient, tempFilePath(), "OWNER/REPO", s.CurrentVersion) if err != nil { t.Fatal(err) }