fix(agent-task): resolve Copilot API URL dynamically (#12956)
* fix(agent-task): resolve Copilot API URL dynamically Query viewer.copilotEndpoints.api to get the correct Copilot API URL for the user's host instead of hardcoding api.githubcopilot.com. This fixes 401 errors for ghe.com tenancy users whose Copilot API lives at a different endpoint. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
parent
37800dd56a
commit
78b958f9ae
7 changed files with 166 additions and 33 deletions
|
|
@ -3,13 +3,11 @@ package capi
|
|||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/url"
|
||||
)
|
||||
|
||||
//go:generate moq -rm -out client_mock.go . CapiClient
|
||||
|
||||
const baseCAPIURL = "https://api.githubcopilot.com"
|
||||
const capiHost = "api.githubcopilot.com"
|
||||
|
||||
// CapiClient defines the methods used by the caller. Implementations
|
||||
// may be replaced with test doubles in unit tests.
|
||||
type CapiClient interface {
|
||||
|
|
@ -24,33 +22,42 @@ type CapiClient interface {
|
|||
|
||||
// CAPIClient is a client for interacting with the Copilot API
|
||||
type CAPIClient struct {
|
||||
httpClient *http.Client
|
||||
host string
|
||||
httpClient *http.Client
|
||||
host string
|
||||
capiBaseURL string
|
||||
}
|
||||
|
||||
// NewCAPIClient creates a new CAPI client. Provide a token, host, and an HTTP client which
|
||||
// will be used as the base transport for CAPI requests.
|
||||
// NewCAPIClient creates a new CAPI client. Provide a token, the user's GitHub
|
||||
// host, the resolved Copilot API URL, and an HTTP client which will be used as
|
||||
// the base transport for CAPI requests.
|
||||
//
|
||||
// The provided HTTP client will be mutated for use with CAPI, so it should not
|
||||
// be reused elsewhere.
|
||||
func NewCAPIClient(httpClient *http.Client, token string, host string) *CAPIClient {
|
||||
httpClient.Transport = newCAPITransport(token, httpClient.Transport)
|
||||
func NewCAPIClient(httpClient *http.Client, token string, host string, capiBaseURL string) *CAPIClient {
|
||||
httpClient.Transport = newCAPITransport(token, capiBaseURL, httpClient.Transport)
|
||||
return &CAPIClient{
|
||||
httpClient: httpClient,
|
||||
host: host,
|
||||
httpClient: httpClient,
|
||||
host: host,
|
||||
capiBaseURL: capiBaseURL,
|
||||
}
|
||||
}
|
||||
|
||||
// capiTransport adds the Copilot auth headers
|
||||
type capiTransport struct {
|
||||
rp http.RoundTripper
|
||||
token string
|
||||
rp http.RoundTripper
|
||||
token string
|
||||
capiHost string
|
||||
}
|
||||
|
||||
func newCAPITransport(token string, rp http.RoundTripper) *capiTransport {
|
||||
func newCAPITransport(token string, capiBaseURL string, rp http.RoundTripper) *capiTransport {
|
||||
capiHost := ""
|
||||
if u, err := url.Parse(capiBaseURL); err == nil {
|
||||
capiHost = u.Host
|
||||
}
|
||||
return &capiTransport{
|
||||
rp: rp,
|
||||
token: token,
|
||||
rp: rp,
|
||||
token: token,
|
||||
capiHost: capiHost,
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -60,10 +67,10 @@ func (ct *capiTransport) RoundTrip(req *http.Request) (*http.Response, error) {
|
|||
// Since this RoundTrip is reused for both Copilot API and
|
||||
// GitHub API requests, we conditionally add the integration
|
||||
// ID only when performing requests to the Copilot API.
|
||||
if req.URL.Host == capiHost {
|
||||
if req.URL.Host == ct.capiHost {
|
||||
req.Header.Add("Copilot-Integration-Id", "copilot-4-cli")
|
||||
|
||||
// This is quick fix to ensure that we are not using GitHub API versions while targeting CAPI.
|
||||
// Ensure we are not using GitHub API versions while targeting CAPI.
|
||||
req.Header.Set("X-GitHub-Api-Version", "2026-01-09")
|
||||
}
|
||||
return ct.rp.RoundTrip(req)
|
||||
|
|
|
|||
|
|
@ -51,7 +51,9 @@ type JobError struct {
|
|||
Service string `json:"service"`
|
||||
}
|
||||
|
||||
const jobsBasePathV1 = baseCAPIURL + "/agents/swe/v1/jobs"
|
||||
func (c *CAPIClient) jobsBasePathV1() string {
|
||||
return c.capiBaseURL + "/agents/swe/v1/jobs"
|
||||
}
|
||||
|
||||
// CreateJob queues a new job using the v1 Jobs API. It may or may not
|
||||
// return Pull Request information. If Pull Request information is required
|
||||
|
|
@ -64,7 +66,7 @@ func (c *CAPIClient) CreateJob(ctx context.Context, owner, repo, problemStatemen
|
|||
return nil, errors.New("problem statement is required")
|
||||
}
|
||||
|
||||
url := fmt.Sprintf("%s/%s/%s", jobsBasePathV1, url.PathEscape(owner), url.PathEscape(repo))
|
||||
url := fmt.Sprintf("%s/%s/%s", c.jobsBasePathV1(), url.PathEscape(owner), url.PathEscape(repo))
|
||||
|
||||
prOpts := JobPullRequest{}
|
||||
if baseBranch != "" {
|
||||
|
|
@ -130,7 +132,7 @@ func (c *CAPIClient) GetJob(ctx context.Context, owner, repo, jobID string) (*Jo
|
|||
if owner == "" || repo == "" || jobID == "" {
|
||||
return nil, errors.New("owner, repo, and jobID are required")
|
||||
}
|
||||
url := fmt.Sprintf("%s/%s/%s/%s", jobsBasePathV1, url.PathEscape(owner), url.PathEscape(repo), url.PathEscape(jobID))
|
||||
url := fmt.Sprintf("%s/%s/%s/%s", c.jobsBasePathV1(), url.PathEscape(owner), url.PathEscape(repo), url.PathEscape(jobID))
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, http.NoBody)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
|
|
|||
|
|
@ -167,7 +167,7 @@ func TestGetJob(t *testing.T) {
|
|||
|
||||
httpClient := &http.Client{Transport: reg}
|
||||
|
||||
capiClient := NewCAPIClient(httpClient, "", "github.com")
|
||||
capiClient := NewCAPIClient(httpClient, "", "github.com", "https://api.githubcopilot.com")
|
||||
|
||||
job, err := capiClient.GetJob(context.Background(), "OWNER", "REPO", "job123")
|
||||
|
||||
|
|
@ -410,7 +410,7 @@ func TestCreateJob(t *testing.T) {
|
|||
|
||||
httpClient := &http.Client{Transport: reg}
|
||||
|
||||
capiClient := NewCAPIClient(httpClient, "", "github.com")
|
||||
capiClient := NewCAPIClient(httpClient, "", "github.com", "https://api.githubcopilot.com")
|
||||
|
||||
job, err := capiClient.CreateJob(context.Background(), "OWNER", "REPO", "Do the thing", tt.baseBranch, tt.customAgent)
|
||||
|
||||
|
|
|
|||
|
|
@ -217,13 +217,16 @@ func (c *CAPIClient) ListLatestSessionsForViewer(ctx context.Context, limit int)
|
|||
return nil, nil
|
||||
}
|
||||
|
||||
url := baseCAPIURL + "/agents/sessions"
|
||||
sessionsURL, err := url.JoinPath(c.capiBaseURL, "agents", "sessions")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to build sessions URL: %w", err)
|
||||
}
|
||||
pageSize := defaultSessionsPerPage
|
||||
|
||||
seenResources := make(map[int64]struct{})
|
||||
latestSessions := make([]session, 0, limit)
|
||||
for page := 1; ; page++ {
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, http.NoBody)
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, sessionsURL, http.NoBody)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
@ -296,7 +299,7 @@ func (c *CAPIClient) GetSession(ctx context.Context, id string) (*Session, error
|
|||
return nil, fmt.Errorf("missing session ID")
|
||||
}
|
||||
|
||||
url := fmt.Sprintf("%s/agents/sessions/%s", baseCAPIURL, url.PathEscape(id))
|
||||
url := fmt.Sprintf("%s/agents/sessions/%s", c.capiBaseURL, url.PathEscape(id))
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, http.NoBody)
|
||||
if err != nil {
|
||||
|
|
@ -335,7 +338,7 @@ func (c *CAPIClient) GetSessionLogs(ctx context.Context, id string) ([]byte, err
|
|||
return nil, fmt.Errorf("missing session ID")
|
||||
}
|
||||
|
||||
url := fmt.Sprintf("%s/agents/sessions/%s/logs", baseCAPIURL, url.PathEscape(id))
|
||||
url := fmt.Sprintf("%s/agents/sessions/%s/logs", c.capiBaseURL, url.PathEscape(id))
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, http.NoBody)
|
||||
if err != nil {
|
||||
|
|
@ -368,7 +371,7 @@ func (c *CAPIClient) ListSessionsByResourceID(ctx context.Context, resourceType
|
|||
return nil, nil
|
||||
}
|
||||
|
||||
url := fmt.Sprintf("%s/agents/resource/%s/%d", baseCAPIURL, url.PathEscape(resourceType), resourceID)
|
||||
url := fmt.Sprintf("%s/agents/resource/%s/%d", c.capiBaseURL, url.PathEscape(resourceType), resourceID)
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, http.NoBody)
|
||||
if err != nil {
|
||||
|
|
|
|||
|
|
@ -1161,7 +1161,7 @@ func TestListLatestSessionsForViewer(t *testing.T) {
|
|||
|
||||
httpClient := &http.Client{Transport: reg}
|
||||
|
||||
capiClient := NewCAPIClient(httpClient, "", "github.com")
|
||||
capiClient := NewCAPIClient(httpClient, "", "github.com", "https://api.githubcopilot.com")
|
||||
|
||||
if tt.perPage != 0 {
|
||||
last := defaultSessionsPerPage
|
||||
|
|
@ -1540,7 +1540,7 @@ func TestListSessionsByResourceID(t *testing.T) {
|
|||
|
||||
httpClient := &http.Client{Transport: reg}
|
||||
|
||||
capiClient := NewCAPIClient(httpClient, "", "github.com")
|
||||
capiClient := NewCAPIClient(httpClient, "", "github.com", "https://api.githubcopilot.com")
|
||||
|
||||
if tt.perPage != 0 {
|
||||
last := defaultSessionsPerPage
|
||||
|
|
@ -1819,7 +1819,7 @@ func TestGetSession(t *testing.T) {
|
|||
|
||||
httpClient := &http.Client{Transport: reg}
|
||||
|
||||
capiClient := NewCAPIClient(httpClient, "", "github.com")
|
||||
capiClient := NewCAPIClient(httpClient, "", "github.com", "https://api.githubcopilot.com")
|
||||
|
||||
session, err := capiClient.GetSession(context.Background(), "some-uuid")
|
||||
|
||||
|
|
@ -1895,7 +1895,7 @@ func TestGetPullRequestDatabaseID(t *testing.T) {
|
|||
|
||||
httpClient := &http.Client{Transport: reg}
|
||||
|
||||
capiClient := NewCAPIClient(httpClient, "", "github.com")
|
||||
capiClient := NewCAPIClient(httpClient, "", "github.com", "https://api.githubcopilot.com")
|
||||
|
||||
databaseID, url, err := capiClient.GetPullRequestDatabaseID(context.Background(), "github.com", "OWNER", "REPO", 42)
|
||||
|
||||
|
|
|
|||
|
|
@ -3,8 +3,11 @@ package shared
|
|||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"regexp"
|
||||
"time"
|
||||
|
||||
"github.com/cli/cli/v2/api"
|
||||
"github.com/cli/cli/v2/pkg/cmd/agent-task/capi"
|
||||
prShared "github.com/cli/cli/v2/pkg/cmd/pr/shared"
|
||||
"github.com/cli/cli/v2/pkg/cmdutil"
|
||||
|
|
@ -30,10 +33,40 @@ func CapiClientFunc(f *cmdutil.Factory) func() (capi.CapiClient, error) {
|
|||
authCfg := cfg.Authentication()
|
||||
host, _ := authCfg.DefaultHost()
|
||||
token, _ := authCfg.ActiveToken(host)
|
||||
return capi.NewCAPIClient(httpClient, token, host), nil
|
||||
|
||||
cachedClient := api.NewCachedHTTPClient(httpClient, time.Minute*10)
|
||||
capiBaseURL, err := resolveCapiURL(cachedClient, host)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to resolve Copilot API URL: %w", err)
|
||||
}
|
||||
|
||||
return capi.NewCAPIClient(httpClient, token, host, capiBaseURL), nil
|
||||
}
|
||||
}
|
||||
|
||||
// resolveCapiURL queries the GitHub API for the Copilot API endpoint URL.
|
||||
func resolveCapiURL(httpClient *http.Client, host string) (string, error) {
|
||||
apiClient := api.NewClientFromHTTP(httpClient)
|
||||
|
||||
var resp struct {
|
||||
Viewer struct {
|
||||
CopilotEndpoints struct {
|
||||
Api string `graphql:"api"`
|
||||
} `graphql:"copilotEndpoints"`
|
||||
} `graphql:"viewer"`
|
||||
}
|
||||
|
||||
if err := apiClient.Query(host, "CopilotEndpoints", &resp, nil); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if resp.Viewer.CopilotEndpoints.Api == "" {
|
||||
return "", errors.New("empty Copilot API URL returned")
|
||||
}
|
||||
|
||||
return resp.Viewer.CopilotEndpoints.Api, nil
|
||||
}
|
||||
|
||||
func IsSessionID(s string) bool {
|
||||
return sessionIDRegexp.MatchString(s)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,12 +1,100 @@
|
|||
package shared
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/cli/cli/v2/internal/config"
|
||||
"github.com/cli/cli/v2/internal/gh"
|
||||
ghmock "github.com/cli/cli/v2/internal/gh/mock"
|
||||
"github.com/cli/cli/v2/pkg/cmdutil"
|
||||
"github.com/cli/cli/v2/pkg/httpmock"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestResolveCapiURL(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
resp string
|
||||
wantURL string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "returns resolved URL",
|
||||
resp: `{"data":{"viewer":{"copilotEndpoints":{"api":"https://test-copilot-api.example.com"}}}}`,
|
||||
wantURL: "https://test-copilot-api.example.com",
|
||||
},
|
||||
{
|
||||
name: "ghe.com tenant URL",
|
||||
resp: `{"data":{"viewer":{"copilotEndpoints":{"api":"https://test-copilot-api.tenant.example.com"}}}}`,
|
||||
wantURL: "https://test-copilot-api.tenant.example.com",
|
||||
},
|
||||
{
|
||||
name: "empty URL returns error",
|
||||
resp: `{"data":{"viewer":{"copilotEndpoints":{"api":""}}}}`,
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
reg := &httpmock.Registry{}
|
||||
defer reg.Verify(t)
|
||||
|
||||
reg.Register(
|
||||
httpmock.GraphQL(`query CopilotEndpoints\b`),
|
||||
httpmock.StringResponse(tt.resp),
|
||||
)
|
||||
|
||||
httpClient := &http.Client{Transport: reg}
|
||||
url, err := resolveCapiURL(httpClient, "github.com")
|
||||
|
||||
if tt.wantErr {
|
||||
require.Error(t, err)
|
||||
return
|
||||
}
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, tt.wantURL, url)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCapiClientFuncResolvesURL(t *testing.T) {
|
||||
reg := &httpmock.Registry{}
|
||||
defer reg.Verify(t)
|
||||
|
||||
reg.Register(
|
||||
httpmock.GraphQL(`query CopilotEndpoints\b`),
|
||||
httpmock.StringResponse(`{"data":{"viewer":{"copilotEndpoints":{"api":"https://test-copilot-api.example.com"}}}}`),
|
||||
)
|
||||
|
||||
f := &cmdutil.Factory{
|
||||
Config: func() (gh.Config, error) {
|
||||
return &ghmock.ConfigMock{
|
||||
AuthenticationFunc: func() gh.AuthConfig {
|
||||
c := &config.AuthConfig{}
|
||||
c.SetDefaultHost("github.com", "hosts")
|
||||
c.SetActiveToken("gho_TOKEN", "oauth_token")
|
||||
return c
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
HttpClient: func() (*http.Client, error) {
|
||||
return &http.Client{Transport: reg}, nil
|
||||
},
|
||||
}
|
||||
|
||||
clientFunc := CapiClientFunc(f)
|
||||
client, err := clientFunc()
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, client)
|
||||
|
||||
// Verify the GraphQL resolution was called
|
||||
require.Len(t, reg.Requests, 1)
|
||||
}
|
||||
|
||||
func TestIsSession(t *testing.T) {
|
||||
assert.True(t, IsSessionID("00000000-0000-0000-0000-000000000000"))
|
||||
assert.True(t, IsSessionID("e2fa49d2-f164-4a56-ab99-498090b8fcdf"))
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue