diff --git a/internal/authflow/flow.go b/internal/authflow/flow.go index ebb5b37a0..0a195168f 100644 --- a/internal/authflow/flow.go +++ b/internal/authflow/flow.go @@ -97,7 +97,7 @@ func AuthFlow(httpClient *http.Client, oauthHost string, IO *iostreams.IOStreams return "", "", err } - userLogin, err := getViewer(oauthHost, token.Token, IO.ErrOut) + userLogin, err := getViewer(httpClient, oauthHost, token.Token) if err != nil { return "", "", err } @@ -123,16 +123,10 @@ func (c cfg) ActiveToken(hostname string) (string, string) { return c.token, "oauth_token" } -func getViewer(hostname, token string, logWriter io.Writer) (string, error) { - opts := api.HTTPClientOptions{ - Config: cfg{token: token}, - Log: logWriter, - } - client, err := api.NewHTTPClient(opts) - if err != nil { - return "", err - } - return api.CurrentLoginName(api.NewClientFromHTTP(client), hostname) +func getViewer(httpClient *http.Client, hostname, token string) (string, error) { + authedClient := *httpClient + authedClient.Transport = api.AddAuthTokenHeader(httpClient.Transport, cfg{token: token}) + return api.CurrentLoginName(api.NewClientFromHTTP(&authedClient), hostname) } func waitForEnter(r io.Reader) error { diff --git a/internal/authflow/flow_test.go b/internal/authflow/flow_test.go index b7ba1f64a..9811c3206 100644 --- a/internal/authflow/flow_test.go +++ b/internal/authflow/flow_test.go @@ -1,11 +1,54 @@ package authflow import ( + "bytes" + "io" + "net/http" "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) +func Test_getViewer_preservesUserAgent(t *testing.T) { + var receivedUA string + var receivedAuth string + + // Outer transport sets User-Agent, simulating the factory-built client's header middleware. + // Inner transport captures headers as-received to verify they survived the wrapping. + plainClient := &http.Client{ + Transport: &roundTripper{roundTrip: func(req *http.Request) (*http.Response, error) { + req.Header.Set("User-Agent", "GitHub CLI 1.2.3 Agent/copilot-cli") + return (&http.Client{ + Transport: &roundTripper{roundTrip: func(req *http.Request) (*http.Response, error) { + receivedUA = req.Header.Get("User-Agent") + receivedAuth = req.Header.Get("Authorization") + return &http.Response{ + StatusCode: 200, + Header: http.Header{"Content-Type": []string{"application/json"}}, + Body: io.NopCloser(bytes.NewBufferString(`{"data":{"viewer":{"login":"monalisa"}}}`)), + Request: req, + }, nil + }}, + }).Transport.RoundTrip(req) + }}, + } + + login, err := getViewer(plainClient, "github.com", "test-token") + require.NoError(t, err) + assert.Equal(t, "monalisa", login) + assert.Equal(t, "GitHub CLI 1.2.3 Agent/copilot-cli", receivedUA) + assert.Equal(t, "token test-token", receivedAuth) +} + +type roundTripper struct { + roundTrip func(*http.Request) (*http.Response, error) +} + +func (t *roundTripper) RoundTrip(req *http.Request) (*http.Response, error) { + return t.roundTrip(req) +} + func Test_getCallbackURI(t *testing.T) { tests := []struct { name string