cli/pkg/cmd/agent-task/shared/capi_test.go
Kynan Ware 78b958f9ae
fix(agent-task): resolve Copilot API URL dynamically (#12956)
* fix(agent-task): resolve Copilot API URL dynamically

Query viewer.copilotEndpoints.api to get the correct Copilot API URL
for the user's host instead of hardcoding api.githubcopilot.com. This
fixes 401 errors for ghe.com tenancy users whose Copilot API lives at
a different endpoint.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>
2026-03-18 18:14:02 +00:00

164 lines
4.1 KiB
Go

package shared
import (
"net/http"
"testing"
"github.com/cli/cli/v2/internal/config"
"github.com/cli/cli/v2/internal/gh"
ghmock "github.com/cli/cli/v2/internal/gh/mock"
"github.com/cli/cli/v2/pkg/cmdutil"
"github.com/cli/cli/v2/pkg/httpmock"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestResolveCapiURL(t *testing.T) {
tests := []struct {
name string
resp string
wantURL string
wantErr bool
}{
{
name: "returns resolved URL",
resp: `{"data":{"viewer":{"copilotEndpoints":{"api":"https://test-copilot-api.example.com"}}}}`,
wantURL: "https://test-copilot-api.example.com",
},
{
name: "ghe.com tenant URL",
resp: `{"data":{"viewer":{"copilotEndpoints":{"api":"https://test-copilot-api.tenant.example.com"}}}}`,
wantURL: "https://test-copilot-api.tenant.example.com",
},
{
name: "empty URL returns error",
resp: `{"data":{"viewer":{"copilotEndpoints":{"api":""}}}}`,
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
reg := &httpmock.Registry{}
defer reg.Verify(t)
reg.Register(
httpmock.GraphQL(`query CopilotEndpoints\b`),
httpmock.StringResponse(tt.resp),
)
httpClient := &http.Client{Transport: reg}
url, err := resolveCapiURL(httpClient, "github.com")
if tt.wantErr {
require.Error(t, err)
return
}
require.NoError(t, err)
assert.Equal(t, tt.wantURL, url)
})
}
}
func TestCapiClientFuncResolvesURL(t *testing.T) {
reg := &httpmock.Registry{}
defer reg.Verify(t)
reg.Register(
httpmock.GraphQL(`query CopilotEndpoints\b`),
httpmock.StringResponse(`{"data":{"viewer":{"copilotEndpoints":{"api":"https://test-copilot-api.example.com"}}}}`),
)
f := &cmdutil.Factory{
Config: func() (gh.Config, error) {
return &ghmock.ConfigMock{
AuthenticationFunc: func() gh.AuthConfig {
c := &config.AuthConfig{}
c.SetDefaultHost("github.com", "hosts")
c.SetActiveToken("gho_TOKEN", "oauth_token")
return c
},
}, nil
},
HttpClient: func() (*http.Client, error) {
return &http.Client{Transport: reg}, nil
},
}
clientFunc := CapiClientFunc(f)
client, err := clientFunc()
require.NoError(t, err)
require.NotNil(t, client)
// Verify the GraphQL resolution was called
require.Len(t, reg.Requests, 1)
}
func TestIsSession(t *testing.T) {
assert.True(t, IsSessionID("00000000-0000-0000-0000-000000000000"))
assert.True(t, IsSessionID("e2fa49d2-f164-4a56-ab99-498090b8fcdf"))
assert.True(t, IsSessionID("E2FA49D2-F164-4A56-AB99-498090B8FCDF"))
assert.False(t, IsSessionID(""))
assert.False(t, IsSessionID(" "))
assert.False(t, IsSessionID("\n"))
assert.False(t, IsSessionID("not-a-uuid"))
assert.False(t, IsSessionID("000000000000000000000000000000000000"))
assert.False(t, IsSessionID("00000000-0000-0000-0000-000000000000-extra"))
}
func TestParsePullRequestAgentSessionURL(t *testing.T) {
tests := []struct {
name string
url string
wantSessionID string
wantErr bool
}{
{
name: "valid",
url: "https://github.com/OWNER/REPO/pull/123/agent-sessions/e2fa49d2-f164-4a56-ab99-498090b8fcdf",
wantSessionID: "e2fa49d2-f164-4a56-ab99-498090b8fcdf",
},
{
name: "invalid session id",
url: "https://github.com/OWNER/REPO/pull/123/agent-sessions/fff",
wantErr: true,
},
{
name: "no session id, trailing slash",
url: "https://github.com/OWNER/REPO/pull/123/agent-sessions/",
wantErr: true,
},
{
name: "no session id",
url: "https://github.com/OWNER/REPO/pull/123/agent-sessions",
wantErr: true,
},
{
name: "invalid pr url",
url: "https://github.com/OWNER/REPO/issues/123",
wantErr: true,
},
{
name: "empty",
url: "",
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
sessionID, err := ParseSessionIDFromURL(tt.url)
if tt.wantErr {
require.Error(t, err)
assert.Zero(t, sessionID)
return
}
require.NoError(t, err)
assert.Equal(t, tt.wantSessionID, sessionID)
})
}
}