From fac2575c7346cd34b6e2a17f26309e4437d89ece Mon Sep 17 00:00:00 2001 From: Marwan Sulaiman Date: Tue, 25 Jan 2022 12:54:20 -0500 Subject: [PATCH] Add retries to Codespaces API client (#5064) --- internal/codespaces/api/api.go | 77 ++++++++++++++++++----------- internal/codespaces/api/api_test.go | 54 ++++++++++++++++++++ 2 files changed, 102 insertions(+), 29 deletions(-) diff --git a/internal/codespaces/api/api.go b/internal/codespaces/api/api.go index 8de16bbb8..2df2005ac 100644 --- a/internal/codespaces/api/api.go +++ b/internal/codespaces/api/api.go @@ -57,6 +57,7 @@ type API struct { vscsAPI string githubAPI string githubServer string + retryBackoff time.Duration } type httpClient interface { @@ -79,6 +80,7 @@ func New(serverURL, apiURL, vscsURL string, httpClient httpClient) *API { vscsAPI: strings.TrimSuffix(vscsURL, "/"), githubAPI: strings.TrimSuffix(apiURL, "/"), githubServer: strings.TrimSuffix(serverURL, "/"), + retryBackoff: 100 * time.Millisecond, } } @@ -301,24 +303,24 @@ func findNextPage(linkValue string) string { // If the codespace is not found, an error is returned. // If includeConnection is true, it will return the connection information for the codespace. func (a *API) GetCodespace(ctx context.Context, codespaceName string, includeConnection bool) (*Codespace, error) { - req, err := http.NewRequest( - http.MethodGet, - a.githubAPI+"/user/codespaces/"+codespaceName, - nil, - ) - if err != nil { - return nil, fmt.Errorf("error creating request: %w", err) - } - - if includeConnection { - q := req.URL.Query() - q.Add("internal", "true") - q.Add("refresh", "true") - req.URL.RawQuery = q.Encode() - } - - a.setHeaders(req) - resp, err := a.do(ctx, req, "/user/codespaces/*") + resp, err := a.withRetry(func() (*http.Response, error) { + req, err := http.NewRequest( + http.MethodGet, + a.githubAPI+"/user/codespaces/"+codespaceName, + nil, + ) + if err != nil { + return nil, fmt.Errorf("error creating request: %w", err) + } + if includeConnection { + q := req.URL.Query() + q.Add("internal", "true") + q.Add("refresh", "true") + req.URL.RawQuery = q.Encode() + } + a.setHeaders(req) + return a.do(ctx, req, "/user/codespaces/*") + }) if err != nil { return nil, fmt.Errorf("error making request: %w", err) } @@ -344,17 +346,18 @@ func (a *API) GetCodespace(ctx context.Context, codespaceName string, includeCon // StartCodespace starts a codespace for the user. // If the codespace is already running, the returned error from the API is ignored. func (a *API) StartCodespace(ctx context.Context, codespaceName string) error { - req, err := http.NewRequest( - http.MethodPost, - a.githubAPI+"/user/codespaces/"+codespaceName+"/start", - nil, - ) - if err != nil { - return fmt.Errorf("error creating request: %w", err) - } - - a.setHeaders(req) - resp, err := a.do(ctx, req, "/user/codespaces/*/start") + resp, err := a.withRetry(func() (*http.Response, error) { + req, err := http.NewRequest( + http.MethodPost, + a.githubAPI+"/user/codespaces/"+codespaceName+"/start", + nil, + ) + if err != nil { + return nil, fmt.Errorf("error creating request: %w", err) + } + a.setHeaders(req) + return a.do(ctx, req, "/user/codespaces/*/start") + }) if err != nil { return fmt.Errorf("error making request: %w", err) } @@ -686,3 +689,19 @@ func (a *API) do(ctx context.Context, req *http.Request, spanName string) (*http func (a *API) setHeaders(req *http.Request) { req.Header.Set("Accept", "application/vnd.github.v3+json") } + +// withRetry takes a generic function that sends an http request and retries +// only when the returned response has a >=500 status code. +func (a *API) withRetry(f func() (*http.Response, error)) (resp *http.Response, err error) { + for i := 0; i < 5; i++ { + resp, err = f() + if err != nil { + return nil, err + } + if resp.StatusCode < 500 { + break + } + time.Sleep(a.retryBackoff * (time.Duration(i) + 1)) + } + return resp, err +} diff --git a/internal/codespaces/api/api_test.go b/internal/codespaces/api/api_test.go index 6dcd06a04..bb379c2fe 100644 --- a/internal/codespaces/api/api_test.go +++ b/internal/codespaces/api/api_test.go @@ -114,3 +114,57 @@ func TestListCodespaces_unlimited(t *testing.T) { t.Fatalf("expected codespace-249, got %s", codespaces[0].Name) } } + +func TestRetries(t *testing.T) { + var callCount int + csName := "test_codespace" + handler := func(w http.ResponseWriter, r *http.Request) { + if callCount == 3 { + err := json.NewEncoder(w).Encode(Codespace{ + Name: csName, + }) + if err != nil { + t.Fatal(err) + } + return + } + callCount++ + w.WriteHeader(502) + } + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { handler(w, r) })) + t.Cleanup(srv.Close) + a := &API{ + githubAPI: srv.URL, + client: &http.Client{}, + } + cs, err := a.GetCodespace(context.Background(), "test", false) + if err != nil { + t.Fatal(err) + } + if callCount != 3 { + t.Fatalf("expected at least 2 retries but got %d", callCount) + } + if cs.Name != csName { + t.Fatalf("expected codespace name to be %q but got %q", csName, cs.Name) + } + callCount = 0 + handler = func(w http.ResponseWriter, r *http.Request) { + callCount++ + err := json.NewEncoder(w).Encode(Codespace{ + Name: csName, + }) + if err != nil { + t.Fatal(err) + } + } + cs, err = a.GetCodespace(context.Background(), "test", false) + if err != nil { + t.Fatal(err) + } + if callCount != 1 { + t.Fatalf("expected no retries but got %d calls", callCount) + } + if cs.Name != csName { + t.Fatalf("expected codespace name to be %q but got %q", csName, cs.Name) + } +}