From 626c639df5e9c02a9f02ecd67434182a52673d47 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mislav=20Marohni=C4=87?= Date: Tue, 7 Feb 2023 20:52:53 +0100 Subject: [PATCH] Disallow update checker delaying the gh process (#6978) This ensures that checking for newer versions of gh happens in the background of the main operation that the user requested, and that when that operation is completed, the gh process should immediately exit without being delayed by the update checker goroutine. --- cmd/gh/main.go | 35 +++++++++++++++++++--------------- internal/update/update.go | 32 +++++++++++++++++++++++-------- internal/update/update_test.go | 5 ++--- 3 files changed, 46 insertions(+), 26 deletions(-) diff --git a/cmd/gh/main.go b/cmd/gh/main.go index 077e91995..eb0ad0672 100644 --- a/cmd/gh/main.go +++ b/cmd/gh/main.go @@ -1,6 +1,7 @@ package main import ( + "context" "errors" "fmt" "io" @@ -53,17 +54,24 @@ func main() { func mainRun() exitCode { buildDate := build.Date buildVersion := build.Version - - updateMessageChan := make(chan *update.ReleaseInfo) - go func() { - rel, _ := checkForUpdate(buildVersion) - updateMessageChan <- rel - }() - hasDebug, _ := utils.IsDebugEnabled() cmdFactory := factory.New(buildVersion) stderr := cmdFactory.IOStreams.ErrOut + + ctx := context.Background() + + updateCtx, updateCancel := context.WithCancel(ctx) + defer updateCancel() + updateMessageChan := make(chan *update.ReleaseInfo) + go func() { + rel, err := checkForUpdate(updateCtx, cmdFactory, buildVersion) + if err != nil && hasDebug { + fmt.Fprintf(stderr, "warning: checking for update failed: %v", err) + } + updateMessageChan <- rel + }() + if !cmdFactory.IOStreams.ColorEnabled() { surveyCore.DisableColor = true ansi.DisableColors(true) @@ -209,7 +217,7 @@ func mainRun() exitCode { rootCmd.SetArgs(expandedArgs) - if cmd, err := rootCmd.ExecuteC(); err != nil { + if cmd, err := rootCmd.ExecuteContextC(ctx); err != nil { var pagerPipeError *iostreams.ErrClosedPagerPipe var noResultsError cmdutil.NoResultsError if err == cmdutil.SilentError { @@ -257,6 +265,7 @@ func mainRun() exitCode { return exitError } + updateCancel() // if the update checker hasn't completed by now, abort it newRelease := <-updateMessageChan if newRelease != nil { isHomebrew := isUnderHomebrew(cmdFactory.Executable()) @@ -348,21 +357,17 @@ func isCI() bool { os.Getenv("RUN_ID") != "" // TaskCluster, dsari } -func checkForUpdate(currentVersion string) (*update.ReleaseInfo, error) { +func checkForUpdate(ctx context.Context, f *cmdutil.Factory, currentVersion string) (*update.ReleaseInfo, error) { if !shouldCheckForUpdate() { return nil, nil } - httpClient, err := api.NewHTTPClient(api.HTTPClientOptions{ - AppVersion: currentVersion, - Log: os.Stderr, - }) + httpClient, err := f.HttpClient() if err != nil { return nil, err } - client := api.NewClientFromHTTP(httpClient) repo := updaterEnabled stateFilePath := filepath.Join(config.StateDir(), "state.yml") - return update.CheckForUpdate(client, stateFilePath, repo, currentVersion) + return update.CheckForUpdate(ctx, httpClient, stateFilePath, repo, currentVersion) } func isRecentRelease(publishedAt time.Time) bool { diff --git a/internal/update/update.go b/internal/update/update.go index e9ada22f6..6d69eeada 100644 --- a/internal/update/update.go +++ b/internal/update/update.go @@ -1,7 +1,11 @@ package update import ( + "context" + "encoding/json" "fmt" + "io" + "net/http" "os" "path/filepath" "regexp" @@ -9,8 +13,6 @@ import ( "strings" "time" - "github.com/cli/cli/v2/api" - "github.com/cli/cli/v2/internal/ghinstance" "github.com/hashicorp/go-version" "gopkg.in/yaml.v3" ) @@ -30,13 +32,13 @@ type StateEntry struct { } // CheckForUpdate checks whether this software has had a newer release on GitHub -func CheckForUpdate(client *api.Client, stateFilePath, repo, currentVersion string) (*ReleaseInfo, error) { +func CheckForUpdate(ctx context.Context, client *http.Client, stateFilePath, repo, currentVersion string) (*ReleaseInfo, error) { stateEntry, _ := getStateEntry(stateFilePath) if stateEntry != nil && time.Since(stateEntry.CheckedForUpdateAt).Hours() < 24 { return nil, nil } - releaseInfo, err := getLatestReleaseInfo(client, repo) + releaseInfo, err := getLatestReleaseInfo(ctx, client, repo) if err != nil { return nil, err } @@ -53,13 +55,27 @@ func CheckForUpdate(client *api.Client, stateFilePath, repo, currentVersion stri return nil, nil } -func getLatestReleaseInfo(client *api.Client, repo string) (*ReleaseInfo, error) { - var latestRelease ReleaseInfo - err := client.REST(ghinstance.Default(), "GET", fmt.Sprintf("repos/%s/releases/latest", repo), nil, &latestRelease) +func getLatestReleaseInfo(ctx context.Context, client *http.Client, repo string) (*ReleaseInfo, error) { + req, err := http.NewRequestWithContext(ctx, "GET", fmt.Sprintf("https://api.github.com/repos/%s/releases/latest", repo), nil) if err != nil { return nil, err } - + res, err := client.Do(req) + if err != nil { + return nil, err + } + defer func() { + _, _ = io.Copy(io.Discard, res.Body) + res.Body.Close() + }() + if res.StatusCode != 200 { + return nil, fmt.Errorf("unexpected HTTP %d", res.StatusCode) + } + dec := json.NewDecoder(res.Body) + var latestRelease ReleaseInfo + if err := dec.Decode(&latestRelease); err != nil { + return nil, err + } return &latestRelease, nil } diff --git a/internal/update/update_test.go b/internal/update/update_test.go index 96dce4f2a..bb514adfc 100644 --- a/internal/update/update_test.go +++ b/internal/update/update_test.go @@ -1,13 +1,13 @@ package update import ( + "context" "fmt" "log" "net/http" "os" "testing" - "github.com/cli/cli/v2/api" "github.com/cli/cli/v2/pkg/httpmock" ) @@ -75,7 +75,6 @@ func TestCheckForUpdate(t *testing.T) { reg := &httpmock.Registry{} httpClient := &http.Client{} httpmock.ReplaceTripper(httpClient, reg) - client := api.NewClientFromHTTP(httpClient) reg.Register( httpmock.REST("GET", "repos/OWNER/REPO/releases/latest"), @@ -85,7 +84,7 @@ func TestCheckForUpdate(t *testing.T) { }`, s.LatestVersion, s.LatestURL)), ) - rel, err := CheckForUpdate(client, tempFilePath(), "OWNER/REPO", s.CurrentVersion) + rel, err := CheckForUpdate(context.TODO(), httpClient, tempFilePath(), "OWNER/REPO", s.CurrentVersion) if err != nil { t.Fatal(err) }