fix(agent-task): resolve Copilot API URL dynamically (#12956)

* fix(agent-task): resolve Copilot API URL dynamically

Query viewer.copilotEndpoints.api to get the correct Copilot API URL
for the user's host instead of hardcoding api.githubcopilot.com. This
fixes 401 errors for ghe.com tenancy users whose Copilot API lives at
a different endpoint.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
Kynan Ware 2026-03-18 12:14:02 -06:00 committed by GitHub
parent 37800dd56a
commit 78b958f9ae
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 166 additions and 33 deletions

View file

@ -3,13 +3,11 @@ package capi
import (
"context"
"net/http"
"net/url"
)
//go:generate moq -rm -out client_mock.go . CapiClient
const baseCAPIURL = "https://api.githubcopilot.com"
const capiHost = "api.githubcopilot.com"
// CapiClient defines the methods used by the caller. Implementations
// may be replaced with test doubles in unit tests.
type CapiClient interface {
@ -24,33 +22,42 @@ type CapiClient interface {
// CAPIClient is a client for interacting with the Copilot API
type CAPIClient struct {
httpClient *http.Client
host string
httpClient *http.Client
host string
capiBaseURL string
}
// NewCAPIClient creates a new CAPI client. Provide a token, host, and an HTTP client which
// will be used as the base transport for CAPI requests.
// NewCAPIClient creates a new CAPI client. Provide a token, the user's GitHub
// host, the resolved Copilot API URL, and an HTTP client which will be used as
// the base transport for CAPI requests.
//
// The provided HTTP client will be mutated for use with CAPI, so it should not
// be reused elsewhere.
func NewCAPIClient(httpClient *http.Client, token string, host string) *CAPIClient {
httpClient.Transport = newCAPITransport(token, httpClient.Transport)
func NewCAPIClient(httpClient *http.Client, token string, host string, capiBaseURL string) *CAPIClient {
httpClient.Transport = newCAPITransport(token, capiBaseURL, httpClient.Transport)
return &CAPIClient{
httpClient: httpClient,
host: host,
httpClient: httpClient,
host: host,
capiBaseURL: capiBaseURL,
}
}
// capiTransport adds the Copilot auth headers
type capiTransport struct {
rp http.RoundTripper
token string
rp http.RoundTripper
token string
capiHost string
}
func newCAPITransport(token string, rp http.RoundTripper) *capiTransport {
func newCAPITransport(token string, capiBaseURL string, rp http.RoundTripper) *capiTransport {
capiHost := ""
if u, err := url.Parse(capiBaseURL); err == nil {
capiHost = u.Host
}
return &capiTransport{
rp: rp,
token: token,
rp: rp,
token: token,
capiHost: capiHost,
}
}
@ -60,10 +67,10 @@ func (ct *capiTransport) RoundTrip(req *http.Request) (*http.Response, error) {
// Since this RoundTrip is reused for both Copilot API and
// GitHub API requests, we conditionally add the integration
// ID only when performing requests to the Copilot API.
if req.URL.Host == capiHost {
if req.URL.Host == ct.capiHost {
req.Header.Add("Copilot-Integration-Id", "copilot-4-cli")
// This is quick fix to ensure that we are not using GitHub API versions while targeting CAPI.
// Ensure we are not using GitHub API versions while targeting CAPI.
req.Header.Set("X-GitHub-Api-Version", "2026-01-09")
}
return ct.rp.RoundTrip(req)

View file

@ -51,7 +51,9 @@ type JobError struct {
Service string `json:"service"`
}
const jobsBasePathV1 = baseCAPIURL + "/agents/swe/v1/jobs"
func (c *CAPIClient) jobsBasePathV1() string {
return c.capiBaseURL + "/agents/swe/v1/jobs"
}
// CreateJob queues a new job using the v1 Jobs API. It may or may not
// return Pull Request information. If Pull Request information is required
@ -64,7 +66,7 @@ func (c *CAPIClient) CreateJob(ctx context.Context, owner, repo, problemStatemen
return nil, errors.New("problem statement is required")
}
url := fmt.Sprintf("%s/%s/%s", jobsBasePathV1, url.PathEscape(owner), url.PathEscape(repo))
url := fmt.Sprintf("%s/%s/%s", c.jobsBasePathV1(), url.PathEscape(owner), url.PathEscape(repo))
prOpts := JobPullRequest{}
if baseBranch != "" {
@ -130,7 +132,7 @@ func (c *CAPIClient) GetJob(ctx context.Context, owner, repo, jobID string) (*Jo
if owner == "" || repo == "" || jobID == "" {
return nil, errors.New("owner, repo, and jobID are required")
}
url := fmt.Sprintf("%s/%s/%s/%s", jobsBasePathV1, url.PathEscape(owner), url.PathEscape(repo), url.PathEscape(jobID))
url := fmt.Sprintf("%s/%s/%s/%s", c.jobsBasePathV1(), url.PathEscape(owner), url.PathEscape(repo), url.PathEscape(jobID))
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, http.NoBody)
if err != nil {
return nil, err

View file

@ -167,7 +167,7 @@ func TestGetJob(t *testing.T) {
httpClient := &http.Client{Transport: reg}
capiClient := NewCAPIClient(httpClient, "", "github.com")
capiClient := NewCAPIClient(httpClient, "", "github.com", "https://api.githubcopilot.com")
job, err := capiClient.GetJob(context.Background(), "OWNER", "REPO", "job123")
@ -410,7 +410,7 @@ func TestCreateJob(t *testing.T) {
httpClient := &http.Client{Transport: reg}
capiClient := NewCAPIClient(httpClient, "", "github.com")
capiClient := NewCAPIClient(httpClient, "", "github.com", "https://api.githubcopilot.com")
job, err := capiClient.CreateJob(context.Background(), "OWNER", "REPO", "Do the thing", tt.baseBranch, tt.customAgent)

View file

@ -217,13 +217,16 @@ func (c *CAPIClient) ListLatestSessionsForViewer(ctx context.Context, limit int)
return nil, nil
}
url := baseCAPIURL + "/agents/sessions"
sessionsURL, err := url.JoinPath(c.capiBaseURL, "agents", "sessions")
if err != nil {
return nil, fmt.Errorf("failed to build sessions URL: %w", err)
}
pageSize := defaultSessionsPerPage
seenResources := make(map[int64]struct{})
latestSessions := make([]session, 0, limit)
for page := 1; ; page++ {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, http.NoBody)
req, err := http.NewRequestWithContext(ctx, http.MethodGet, sessionsURL, http.NoBody)
if err != nil {
return nil, err
}
@ -296,7 +299,7 @@ func (c *CAPIClient) GetSession(ctx context.Context, id string) (*Session, error
return nil, fmt.Errorf("missing session ID")
}
url := fmt.Sprintf("%s/agents/sessions/%s", baseCAPIURL, url.PathEscape(id))
url := fmt.Sprintf("%s/agents/sessions/%s", c.capiBaseURL, url.PathEscape(id))
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, http.NoBody)
if err != nil {
@ -335,7 +338,7 @@ func (c *CAPIClient) GetSessionLogs(ctx context.Context, id string) ([]byte, err
return nil, fmt.Errorf("missing session ID")
}
url := fmt.Sprintf("%s/agents/sessions/%s/logs", baseCAPIURL, url.PathEscape(id))
url := fmt.Sprintf("%s/agents/sessions/%s/logs", c.capiBaseURL, url.PathEscape(id))
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, http.NoBody)
if err != nil {
@ -368,7 +371,7 @@ func (c *CAPIClient) ListSessionsByResourceID(ctx context.Context, resourceType
return nil, nil
}
url := fmt.Sprintf("%s/agents/resource/%s/%d", baseCAPIURL, url.PathEscape(resourceType), resourceID)
url := fmt.Sprintf("%s/agents/resource/%s/%d", c.capiBaseURL, url.PathEscape(resourceType), resourceID)
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, http.NoBody)
if err != nil {

View file

@ -1161,7 +1161,7 @@ func TestListLatestSessionsForViewer(t *testing.T) {
httpClient := &http.Client{Transport: reg}
capiClient := NewCAPIClient(httpClient, "", "github.com")
capiClient := NewCAPIClient(httpClient, "", "github.com", "https://api.githubcopilot.com")
if tt.perPage != 0 {
last := defaultSessionsPerPage
@ -1540,7 +1540,7 @@ func TestListSessionsByResourceID(t *testing.T) {
httpClient := &http.Client{Transport: reg}
capiClient := NewCAPIClient(httpClient, "", "github.com")
capiClient := NewCAPIClient(httpClient, "", "github.com", "https://api.githubcopilot.com")
if tt.perPage != 0 {
last := defaultSessionsPerPage
@ -1819,7 +1819,7 @@ func TestGetSession(t *testing.T) {
httpClient := &http.Client{Transport: reg}
capiClient := NewCAPIClient(httpClient, "", "github.com")
capiClient := NewCAPIClient(httpClient, "", "github.com", "https://api.githubcopilot.com")
session, err := capiClient.GetSession(context.Background(), "some-uuid")
@ -1895,7 +1895,7 @@ func TestGetPullRequestDatabaseID(t *testing.T) {
httpClient := &http.Client{Transport: reg}
capiClient := NewCAPIClient(httpClient, "", "github.com")
capiClient := NewCAPIClient(httpClient, "", "github.com", "https://api.githubcopilot.com")
databaseID, url, err := capiClient.GetPullRequestDatabaseID(context.Background(), "github.com", "OWNER", "REPO", 42)

View file

@ -3,8 +3,11 @@ package shared
import (
"errors"
"fmt"
"net/http"
"regexp"
"time"
"github.com/cli/cli/v2/api"
"github.com/cli/cli/v2/pkg/cmd/agent-task/capi"
prShared "github.com/cli/cli/v2/pkg/cmd/pr/shared"
"github.com/cli/cli/v2/pkg/cmdutil"
@ -30,10 +33,40 @@ func CapiClientFunc(f *cmdutil.Factory) func() (capi.CapiClient, error) {
authCfg := cfg.Authentication()
host, _ := authCfg.DefaultHost()
token, _ := authCfg.ActiveToken(host)
return capi.NewCAPIClient(httpClient, token, host), nil
cachedClient := api.NewCachedHTTPClient(httpClient, time.Minute*10)
capiBaseURL, err := resolveCapiURL(cachedClient, host)
if err != nil {
return nil, fmt.Errorf("failed to resolve Copilot API URL: %w", err)
}
return capi.NewCAPIClient(httpClient, token, host, capiBaseURL), nil
}
}
// resolveCapiURL queries the GitHub API for the Copilot API endpoint URL.
func resolveCapiURL(httpClient *http.Client, host string) (string, error) {
apiClient := api.NewClientFromHTTP(httpClient)
var resp struct {
Viewer struct {
CopilotEndpoints struct {
Api string `graphql:"api"`
} `graphql:"copilotEndpoints"`
} `graphql:"viewer"`
}
if err := apiClient.Query(host, "CopilotEndpoints", &resp, nil); err != nil {
return "", err
}
if resp.Viewer.CopilotEndpoints.Api == "" {
return "", errors.New("empty Copilot API URL returned")
}
return resp.Viewer.CopilotEndpoints.Api, nil
}
func IsSessionID(s string) bool {
return sessionIDRegexp.MatchString(s)
}

View file

@ -1,12 +1,100 @@
package shared
import (
"net/http"
"testing"
"github.com/cli/cli/v2/internal/config"
"github.com/cli/cli/v2/internal/gh"
ghmock "github.com/cli/cli/v2/internal/gh/mock"
"github.com/cli/cli/v2/pkg/cmdutil"
"github.com/cli/cli/v2/pkg/httpmock"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestResolveCapiURL(t *testing.T) {
tests := []struct {
name string
resp string
wantURL string
wantErr bool
}{
{
name: "returns resolved URL",
resp: `{"data":{"viewer":{"copilotEndpoints":{"api":"https://test-copilot-api.example.com"}}}}`,
wantURL: "https://test-copilot-api.example.com",
},
{
name: "ghe.com tenant URL",
resp: `{"data":{"viewer":{"copilotEndpoints":{"api":"https://test-copilot-api.tenant.example.com"}}}}`,
wantURL: "https://test-copilot-api.tenant.example.com",
},
{
name: "empty URL returns error",
resp: `{"data":{"viewer":{"copilotEndpoints":{"api":""}}}}`,
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
reg := &httpmock.Registry{}
defer reg.Verify(t)
reg.Register(
httpmock.GraphQL(`query CopilotEndpoints\b`),
httpmock.StringResponse(tt.resp),
)
httpClient := &http.Client{Transport: reg}
url, err := resolveCapiURL(httpClient, "github.com")
if tt.wantErr {
require.Error(t, err)
return
}
require.NoError(t, err)
assert.Equal(t, tt.wantURL, url)
})
}
}
func TestCapiClientFuncResolvesURL(t *testing.T) {
reg := &httpmock.Registry{}
defer reg.Verify(t)
reg.Register(
httpmock.GraphQL(`query CopilotEndpoints\b`),
httpmock.StringResponse(`{"data":{"viewer":{"copilotEndpoints":{"api":"https://test-copilot-api.example.com"}}}}`),
)
f := &cmdutil.Factory{
Config: func() (gh.Config, error) {
return &ghmock.ConfigMock{
AuthenticationFunc: func() gh.AuthConfig {
c := &config.AuthConfig{}
c.SetDefaultHost("github.com", "hosts")
c.SetActiveToken("gho_TOKEN", "oauth_token")
return c
},
}, nil
},
HttpClient: func() (*http.Client, error) {
return &http.Client{Transport: reg}, nil
},
}
clientFunc := CapiClientFunc(f)
client, err := clientFunc()
require.NoError(t, err)
require.NotNil(t, client)
// Verify the GraphQL resolution was called
require.Len(t, reg.Requests, 1)
}
func TestIsSession(t *testing.T) {
assert.True(t, IsSessionID("00000000-0000-0000-0000-000000000000"))
assert.True(t, IsSessionID("e2fa49d2-f164-4a56-ab99-498090b8fcdf"))