diff --git a/api/queries_pr.go b/api/queries_pr.go index 525418a11..b3373a903 100644 --- a/api/queries_pr.go +++ b/api/queries_pr.go @@ -62,6 +62,7 @@ type PullRequest struct { MergedBy *Author HeadRepositoryOwner Owner HeadRepository *PRRepository + Repository *PRRepository IsCrossRepository bool IsDraft bool MaintainerCanModify bool @@ -251,8 +252,9 @@ type Workflow struct { } type PRRepository struct { - ID string `json:"id"` - Name string `json:"name"` + ID string `json:"id"` + Name string `json:"name"` + NameWithOwner string `json:"nameWithOwner"` } type AutoMergeRequest struct { diff --git a/go.mod b/go.mod index 18deafe15..7f099d7ea 100644 --- a/go.mod +++ b/go.mod @@ -51,6 +51,7 @@ require ( github.com/spf13/pflag v1.0.7 github.com/stretchr/testify v1.10.0 github.com/theupdateframework/go-tuf/v2 v2.1.1 + github.com/vmihailenco/msgpack/v5 v5.4.1 github.com/yuin/goldmark v1.7.13 github.com/zalando/go-keyring v0.2.6 golang.org/x/crypto v0.41.0 @@ -205,6 +206,7 @@ require ( github.com/transparency-dev/merkle v0.0.2 // indirect github.com/transparency-dev/tessera v0.2.1-0.20250610150926-8ee4e93b2823 // indirect github.com/vbatts/tar-split v0.12.1 // indirect + github.com/vmihailenco/tagparser/v2 v2.0.0 // indirect github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect github.com/yuin/goldmark-emoji v1.0.6 // indirect github.com/zeebo/errs v1.4.0 // indirect diff --git a/go.sum b/go.sum index eeebd6973..f28dc1cf4 100644 --- a/go.sum +++ b/go.sum @@ -1415,6 +1415,10 @@ github.com/transparency-dev/tessera v0.2.1-0.20250610150926-8ee4e93b2823 h1:s3p7 github.com/transparency-dev/tessera v0.2.1-0.20250610150926-8ee4e93b2823/go.mod h1:Jv2IDwG1q8QNXZTaI1X6QX8s96WlJn73ka2hT1n4N5c= github.com/vbatts/tar-split v0.12.1 h1:CqKoORW7BUWBe7UL/iqTVvkTBOF8UvOMKOIZykxnnbo= github.com/vbatts/tar-split v0.12.1/go.mod h1:eF6B6i6ftWQcDqEn3/iGFRFRo8cBIMSJVOpnNdfTMFA= +github.com/vmihailenco/msgpack/v5 v5.4.1 h1:cQriyiUvjTwOHg8QZaPihLWeRAAVoCpE00IUPn0Bjt8= +github.com/vmihailenco/msgpack/v5 v5.4.1/go.mod h1:GaZTsDaehaPpQVyxrf5mtQlH+pc21PIudVV/E3rRQok= +github.com/vmihailenco/tagparser/v2 v2.0.0 h1:y09buUbR+b5aycVFQs/g70pqKVZNBmxwAhO7/IwNM9g= +github.com/vmihailenco/tagparser/v2 v2.0.0/go.mod h1:Wri+At7QHww0WTrCBeu4J6bNtoV6mEfg5OIWRZA9qds= github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e h1:JVG44RsyaB9T2KIHavMF/ppJZNG9ZpyihvCd0w101no= github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e/go.mod h1:RbqR21r5mrJuqunuUZ/Dhy/avygyECGrLceyNeo4LiM= github.com/yuin/goldmark v1.1.25/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= diff --git a/pkg/cmd/agent-task/agent_task.go b/pkg/cmd/agent-task/agent_task.go index b53c6786d..cbbd3e278 100644 --- a/pkg/cmd/agent-task/agent_task.go +++ b/pkg/cmd/agent-task/agent_task.go @@ -5,6 +5,7 @@ import ( "fmt" "strings" + cmdList "github.com/cli/cli/v2/pkg/cmd/agent-task/list" "github.com/cli/cli/v2/pkg/cmdutil" "github.com/cli/go-gh/v2/pkg/auth" "github.com/spf13/cobra" @@ -25,6 +26,10 @@ func NewCmdAgentTask(f *cmdutil.Factory) *cobra.Command { return cmd.Help() }, } + + // register subcommands + cmd.AddCommand(cmdList.NewCmdList(f, nil)) + return cmd } diff --git a/pkg/cmd/agent-task/capi/client.go b/pkg/cmd/agent-task/capi/client.go new file mode 100644 index 000000000..b5b4ea9e0 --- /dev/null +++ b/pkg/cmd/agent-task/capi/client.go @@ -0,0 +1,64 @@ +package capi + +import ( + "context" + "net/http" + + "github.com/cli/cli/v2/internal/gh" +) + +const baseCAPIURL = "https://api.githubcopilot.com" +const capiHost = "api.githubcopilot.com" + +// CapiClient defines the methods used by the caller. Implementations +// may be replaced with test doubles in unit tests. +type CapiClient interface { + ListSessionsForViewer(ctx context.Context, limit int) ([]*Session, error) +} + +// CAPIClient is a client for interacting with the Copilot API +type CAPIClient struct { + httpClient *http.Client + authCfg gh.AuthConfig +} + +// NewCAPIClient creates a new CAPI client. Provide a token and an HTTP client which +// will be used as the base transport for CAPI requests. +// +// The provided HTTP client will be mutated for use with CAPI, so it should not +// be reused elsewhere. +func NewCAPIClient(httpClient *http.Client, authCfg gh.AuthConfig) *CAPIClient { + host, _ := authCfg.DefaultHost() + token, _ := authCfg.ActiveToken(host) + + httpClient.Transport = newCAPITransport(token, httpClient.Transport) + return &CAPIClient{ + httpClient: httpClient, + authCfg: authCfg, + } +} + +// capiTransport adds the Copilot auth headers +type capiTransport struct { + rp http.RoundTripper + token string +} + +func newCAPITransport(token string, rp http.RoundTripper) *capiTransport { + return &capiTransport{ + rp: rp, + token: token, + } +} + +func (ct *capiTransport) RoundTrip(req *http.Request) (*http.Response, error) { + req.Header.Set("Authorization", "Bearer "+ct.token) + + // Since this RoundTrip is reused for both Copilot API and + // GitHub API requests, we conditionally add the integration + // ID only when performing requests to the Copilot API. + if req.URL.Host == capiHost { + req.Header.Add("Copilot-Integration-Id", "copilot-4-cli") + } + return ct.rp.RoundTrip(req) +} diff --git a/pkg/cmd/agent-task/capi/sessions.go b/pkg/cmd/agent-task/capi/sessions.go new file mode 100644 index 000000000..2693af57d --- /dev/null +++ b/pkg/cmd/agent-task/capi/sessions.go @@ -0,0 +1,208 @@ +package capi + +import ( + "bytes" + "context" + "encoding/base64" + "encoding/json" + "fmt" + "net/http" + "slices" + "strconv" + "time" + + "github.com/cli/cli/v2/api" + "github.com/vmihailenco/msgpack/v5" +) + +// session is an in-flight agent task +type session struct { + ID string `json:"id"` + Name string `json:"name"` + UserID uint64 `json:"user_id"` + AgentID int64 `json:"agent_id"` + Logs string `json:"logs"` + State string `json:"state"` + OwnerID uint64 `json:"owner_id"` + RepoID uint64 `json:"repo_id"` + ResourceType string `json:"resource_type"` + ResourceID int64 `json:"resource_id"` + LastUpdatedAt time.Time `json:"last_updated_at,omitempty"` + CreatedAt time.Time `json:"created_at,omitempty"` + CompletedAt time.Time `json:"completed_at,omitempty"` + EventURL string `json:"event_url"` + EventType string `json:"event_type"` +} + +// A shim of a full pull request because looking up by node ID +// using the full api.PullRequest type fails on unions (actors) +type sessionPullRequest struct { + ID string + FullDatabaseID string + Number int + Title string + State string + URL string + Body string + + CreatedAt time.Time + UpdatedAt time.Time + ClosedAt *time.Time + MergedAt *time.Time + + // Uncomment one of these to see error + // Author api.Author + // MergedBy *api.Author + Repository *api.PRRepository +} + +// Session is a hydrated in-flight agent task +type Session struct { + session + PullRequest *api.PullRequest `json:"-"` +} + +// ListSessionsForViewer lists all agent sessions for the +// authenticated user up to limit. +func (c *CAPIClient) ListSessionsForViewer(ctx context.Context, limit int) ([]*Session, error) { + url := baseCAPIURL + "/agents/sessions" + + var sessions []session + page := 1 + perPage := 50 + + for { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, http.NoBody) + if err != nil { + return nil, err + } + + q := req.URL.Query() + q.Set("page_size", strconv.Itoa(perPage)) + q.Set("page_number", strconv.Itoa(page)) + req.URL.RawQuery = q.Encode() + + res, err := c.httpClient.Do(req) + if err != nil { + return nil, err + } + defer res.Body.Close() + if res.StatusCode != http.StatusOK { + return nil, fmt.Errorf("failed to list sessions: %s", res.Status) + } + var response struct { + Sessions []session `json:"sessions"` + } + if err := json.NewDecoder(res.Body).Decode(&response); err != nil { + return nil, fmt.Errorf("failed to decode sessions response: %w", err) + } + if len(response.Sessions) == 0 || len(sessions) >= limit { + break + } + sessions = append(sessions, response.Sessions...) + page++ + } + + // Drop any above the limit + if len(sessions) > limit { + sessions = sessions[:limit] + } + + // Hydrate the Sessions with pull request data. + Sessions, err := c.hydrateSessionPullRequests(sessions) + if err != nil { + return nil, err + } + + return Sessions, nil +} + +// hydrateSessionPullRequests hydrates pull request information in sessions +func (c *CAPIClient) hydrateSessionPullRequests(sessions []session) ([]*Session, error) { + if len(sessions) == 0 { + return nil, nil + } + + prNodeIds := make([]string, 0, len(sessions)) + + for _, session := range sessions { + prNodeID := generatePullRequestNodeID(int64(session.RepoID), session.ResourceID) + if slices.Contains(prNodeIds, prNodeID) { + continue + } + prNodeIds = append(prNodeIds, prNodeID) + } + + apiClient := api.NewClientFromHTTP(c.httpClient) + + var resp struct { + Nodes []struct { + PullRequest sessionPullRequest `graphql:"... on PullRequest"` + } `graphql:"nodes(ids: $ids)"` + } + + host, _ := c.authCfg.DefaultHost() + err := apiClient.Query(host, "FetchPRs", &resp, map[string]any{ + "ids": prNodeIds, + }) + + if err != nil { + return nil, err + } + + prs := make([]*api.PullRequest, 0, len(prNodeIds)) + for _, node := range resp.Nodes { + prs = append(prs, &api.PullRequest{ + ID: node.PullRequest.ID, + FullDatabaseID: node.PullRequest.FullDatabaseID, + Number: node.PullRequest.Number, + Title: node.PullRequest.Title, + State: node.PullRequest.State, + URL: node.PullRequest.URL, + Body: node.PullRequest.Body, + CreatedAt: node.PullRequest.CreatedAt, + UpdatedAt: node.PullRequest.UpdatedAt, + ClosedAt: node.PullRequest.ClosedAt, + MergedAt: node.PullRequest.MergedAt, + Repository: node.PullRequest.Repository, + }) + } + + newSessions := make([]*Session, 0, len(sessions)) + // For each session, we need to attach the Pull Request + for _, s := range sessions { + // For each Pull Request, check if it matches the session + for _, pr := range prs { + if strconv.FormatInt(s.ResourceID, 10) == pr.FullDatabaseID { + newSessions = append(newSessions, &Session{ + session: s, + PullRequest: pr, + }) + } + } + } + + return newSessions, nil +} + +// generatePullRequestNodeID converts an int64 databaseID and repoID to a GraphQL Node ID format +// with the "PR_" prefix for pull requests +func generatePullRequestNodeID(repoID, pullRequestID int64) string { + buf := bytes.Buffer{} + parts := []int64{0, repoID, pullRequestID} + + encoder := msgpack.NewEncoder(&buf) + encoder.UseCompactInts(true) + + // Encode the parts + err := encoder.Encode(parts) + if err != nil { + panic(err) + } + + // Use URL-safe Base64 encoding without padding + encoded := base64.RawURLEncoding.EncodeToString(buf.Bytes()) + + // Return with the PR_ prefix + return "PR_" + encoded +} diff --git a/pkg/cmd/agent-task/list/list.go b/pkg/cmd/agent-task/list/list.go new file mode 100644 index 000000000..3b2bff267 --- /dev/null +++ b/pkg/cmd/agent-task/list/list.go @@ -0,0 +1,150 @@ +package list + +import ( + "context" + "fmt" + "net/http" + "time" + + "github.com/cli/cli/v2/internal/gh" + "github.com/cli/cli/v2/internal/tableprinter" + "github.com/cli/cli/v2/pkg/cmd/agent-task/capi" + "github.com/cli/cli/v2/pkg/cmd/pr/shared" + "github.com/cli/cli/v2/pkg/cmdutil" + "github.com/cli/cli/v2/pkg/iostreams" + "github.com/spf13/cobra" +) + +const defaultLimit = 30 + +// ListOptions are the options for the list command +type ListOptions struct { + IO *iostreams.IOStreams + Config func() (gh.Config, error) + Limit int + CapiClient capi.CapiClient + HttpClient func() (*http.Client, error) +} + +// NewCmdList creates the list command +func NewCmdList(f *cmdutil.Factory, runF func(*ListOptions) error) *cobra.Command { + opts := &ListOptions{ + IO: f.IOStreams, + Config: f.Config, + Limit: defaultLimit, + HttpClient: f.HttpClient, + } + + cmd := &cobra.Command{ + Use: "list", + Short: "List agent tasks", + Args: cobra.NoArgs, + RunE: func(cmd *cobra.Command, args []string) error { + cfg, err := f.Config() + if err != nil { + return err + } + + httpClient, err := opts.HttpClient() + if err != nil { + return err + } + + authCfg := cfg.Authentication() + opts.CapiClient = capi.NewCAPIClient(httpClient, authCfg) + + if runF != nil { + return runF(opts) + } + return listRun(opts) + }, + } + + return cmd +} + +func listRun(opts *ListOptions) error { + if opts.Limit <= 0 { + opts.Limit = defaultLimit + } + + capiClient := opts.CapiClient + + opts.IO.StartProgressIndicatorWithLabel("Fetching agent tasks...") + defer opts.IO.StopProgressIndicator() + sessions, err := capiClient.ListSessionsForViewer(context.Background(), opts.Limit) + if err != nil { + return err + } + opts.IO.StopProgressIndicator() + + if len(sessions) == 0 { + fmt.Fprintln(opts.IO.Out, "no agent tasks found") + return nil + } + + cs := opts.IO.ColorScheme() + tp := tableprinter.New(opts.IO, tableprinter.WithHeader("Session ID", "Pull Request", "Repo", "Session State", "Created")) + for _, s := range sessions { + pr := "" + if s.ResourceType == "pull" && s.PullRequest.Number != 0 { + pr = fmt.Sprintf("#%d", s.PullRequest.Number) + } else { + // Skip these sessions in case they happen, for now. + continue + } + repo := "" + if s.PullRequest.Repository != nil && s.PullRequest.Repository.NameWithOwner != "" { + repo = s.PullRequest.Repository.NameWithOwner + } else { + // Skip these sessions in case they happen, for now. + continue + } + + // ID + tp.AddField(s.ID) + if tp.IsTTY() { + tp.AddField(pr, tableprinter.WithColor(cs.ColorFromString(shared.ColorForPRState(*s.PullRequest)))) + } else { + tp.AddField(pr) + } + + // Repo + tp.AddField(repo, tableprinter.WithColor(cs.Muted)) + + // State + if tp.IsTTY() { + var stateColor func(string) string + switch s.State { + case "completed": + stateColor = cs.Green + case "canceled": + stateColor = cs.Muted + case "in_progress", "queued": + stateColor = cs.Yellow + case "failed": + stateColor = cs.Red + default: + stateColor = cs.Muted + } + tp.AddField(s.State, tableprinter.WithColor(stateColor)) + } else { + tp.AddField(s.State) + } + + // Created + if tp.IsTTY() { + tp.AddTimeField(time.Now(), s.CreatedAt, cs.Muted) + } else { + tp.AddField(s.CreatedAt.Format(time.RFC3339)) + } + + tp.EndRow() + } + + if err := tp.Render(); err != nil { + return err + } + + return nil +} diff --git a/pkg/cmd/agent-task/list/list_test.go b/pkg/cmd/agent-task/list/list_test.go new file mode 100644 index 000000000..a02c77bab --- /dev/null +++ b/pkg/cmd/agent-task/list/list_test.go @@ -0,0 +1,91 @@ +package list + +import ( + "bytes" + "context" + "net/http" + "testing" + "time" + + "github.com/cli/cli/v2/api" + "github.com/cli/cli/v2/internal/config" + "github.com/cli/cli/v2/internal/gh" + capi "github.com/cli/cli/v2/pkg/cmd/agent-task/capi" + "github.com/cli/cli/v2/pkg/httpmock" + "github.com/cli/cli/v2/pkg/iostreams" + "github.com/stretchr/testify/require" +) + +// testListOptionsWithRegistry constructs ListOptions and returns the stdout buffer for assertions +func testListOptionsWithRegistry(reg *httpmock.Registry) (*ListOptions, *bytes.Buffer) { + ios, _, stdout, _ := iostreams.Test() + ios.SetStdoutTTY(true) + + opts := &ListOptions{ + IO: ios, + HttpClient: func() (*http.Client, error) { + return &http.Client{Transport: reg}, nil + }, + Config: func() (gh.Config, error) { + c := config.NewBlankConfig() + c.Set("github.com", "oauth_token", "gho_OAUTH123") + return c, nil + }, + Limit: defaultLimit, + } + + return opts, stdout +} + +// mockCAPIClient is a small test double for the CAPI client. +type mockCAPIClient struct { + sessions []*capi.Session +} + +// Updated to match production interface which now includes a limit parameter. +func (m *mockCAPIClient) ListSessionsForViewer(ctx context.Context, limit int) ([]*capi.Session, error) { + return m.sessions, nil +} + +func TestListRun_WithSessions(t *testing.T) { + reg := httpmock.Registry{} + defer reg.Verify(t) + + opts, stdout := testListOptionsWithRegistry(®) + + createdAt := time.Date(2025, time.August, 25, 12, 0, 0, 0, time.UTC) + s := &capi.Session{} + s.ID = "s1" + s.RepoID = 123 + s.ResourceType = "pull" + s.ResourceID = 456 + s.State = "completed" + s.CreatedAt = createdAt + s.PullRequest = &api.PullRequest{ + Number: 456, + State: "OPEN", + Repository: &api.PRRepository{NameWithOwner: "owner/repo"}, + } + opts.CapiClient = &mockCAPIClient{sessions: []*capi.Session{s}} + + err := listRun(opts) + require.NoError(t, err) + out := stdout.String() + require.Contains(t, out, "SESSION ID") + require.Contains(t, out, "s1") + require.Contains(t, out, "#456") + require.Contains(t, out, "owner/repo") +} + +func TestListRun_NoSessions(t *testing.T) { + reg := httpmock.Registry{} + defer reg.Verify(t) + + opts, stdout := testListOptionsWithRegistry(®) + opts.CapiClient = &mockCAPIClient{sessions: []*capi.Session{}} + + err := listRun(opts) + require.NoError(t, err) + out := stdout.String() + require.Contains(t, out, "no agent tasks found") +}