feat(agent-task view): support PR arg

Signed-off-by: Babak K. Shandiz <babakks@github.com>
This commit is contained in:
Babak K. Shandiz 2025-09-05 17:03:59 +01:00
parent e68e28ddf3
commit 8482e3d2a4
No known key found for this signature in database
GPG key ID: 9472CAEFF56C742E
2 changed files with 209 additions and 19 deletions

View file

@ -4,41 +4,74 @@ import (
"context"
"errors"
"fmt"
"net/http"
"net/url"
"strconv"
"time"
"github.com/MakeNowJust/heredoc"
"github.com/cli/cli/v2/api"
"github.com/cli/cli/v2/internal/ghinstance"
"github.com/cli/cli/v2/internal/ghrepo"
"github.com/cli/cli/v2/internal/prompter"
"github.com/cli/cli/v2/internal/text"
"github.com/cli/cli/v2/pkg/cmd/agent-task/capi"
"github.com/cli/cli/v2/pkg/cmd/agent-task/shared"
prShared "github.com/cli/cli/v2/pkg/cmd/pr/shared"
"github.com/cli/cli/v2/pkg/cmdutil"
"github.com/cli/cli/v2/pkg/iostreams"
"github.com/shurcooL/githubv4"
"github.com/spf13/cobra"
)
const defaultLimit = 60
type ViewOptions struct {
IO *iostreams.IOStreams
BaseRepo func() (ghrepo.Interface, error)
CapiClient func() (capi.CapiClient, error)
HttpClient func() (*http.Client, error)
Finder prShared.PRFinder
Prompter prompter.Prompter
SelectorArg string
PRNumber int
SessionID string
}
func NewCmdView(f *cmdutil.Factory, runF func(*ViewOptions) error) *cobra.Command {
opts := &ViewOptions{
IO: f.IOStreams,
HttpClient: f.HttpClient,
CapiClient: shared.CapiClientFunc(f),
Prompter: f.Prompter,
}
cmd := &cobra.Command{
Use: "view <session-id>",
Use: "view [<session-id> | <pr-number> | <pr-url> | <pr-branch>]",
Short: "View an agent task session",
Long: heredoc.Doc(`
View an agent task session.
`),
Args: cmdutil.ExactArgs(1, "a session ID is required"),
Args: cobra.MaximumNArgs(1),
RunE: func(cmd *cobra.Command, args []string) error {
opts.SelectorArg = args[0]
// Support -R/--repo override
opts.BaseRepo = f.BaseRepo
if len(args) > 0 {
opts.SelectorArg = args[0]
if shared.IsSessionID(opts.SelectorArg) {
opts.SessionID = opts.SelectorArg
}
}
if opts.SessionID == "" && !opts.IO.CanPrompt() {
return fmt.Errorf("session ID is required when not running interactively")
}
if opts.Finder == nil {
opts.Finder = prShared.NewFinder(f)
}
if runF != nil {
return runF(opts)
@ -47,6 +80,8 @@ func NewCmdView(f *cmdutil.Factory, runF func(*ViewOptions) error) *cobra.Comman
},
}
cmdutil.EnableRepoOverride(cmd, f)
return cmd
}
@ -57,23 +92,117 @@ func viewRun(opts *ViewOptions) error {
}
ctx := context.Background()
cs := opts.IO.ColorScheme()
opts.IO.StartProgressIndicatorWithLabel("Fetching agent session...")
defer opts.IO.StopProgressIndicator()
session, err := capiClient.GetSession(ctx, opts.SelectorArg)
opts.IO.StopProgressIndicator()
var session *capi.Session
if err != nil {
if errors.Is(err, capi.ErrSessionNotFound) {
fmt.Fprintln(opts.IO.ErrOut, "session not found")
if opts.SessionID != "" {
if sess, err := capiClient.GetSession(ctx, opts.SessionID); err != nil {
if errors.Is(err, capi.ErrSessionNotFound) {
fmt.Fprintln(opts.IO.ErrOut, "session not found")
return cmdutil.SilentError
}
return err
} else {
session = sess
}
} else {
var resourceID int64
if opts.SelectorArg != "" {
// Finder does not support the PR/issue reference format (e.g. owner/repo#123)
// so we need to check if the selector arg is a reference and fetch the PR
// directly.
if repo, num, err := prShared.ParseFullReference(opts.SelectorArg); err == nil {
// We need to check the base repo to get the hostname.
baseRepo, err := opts.BaseRepo()
if err != nil {
return err
}
hostname := baseRepo.RepoHost()
if repo.RepoHost() != hostname {
return fmt.Errorf("agent tasks are not supported on this host: %s", repo.RepoHost())
}
client, err := opts.HttpClient()
if err != nil {
return err
}
resourceID, err = getPullRequestDatabaseID(ctx, client, hostname, repo, num)
if err != nil {
return fmt.Errorf("failed to get pull request: %w", err)
}
}
}
if resourceID == 0 {
findOptions := prShared.FindOptions{
Selector: opts.SelectorArg,
Fields: []string{"id", "url", "fullDatabaseId"},
}
pr, repo, err := opts.Finder.Find(findOptions)
if err != nil {
return err
}
if repo.RepoHost() != ghinstance.Default() {
return fmt.Errorf("agent tasks are not supported on this host: %s", repo.RepoHost())
}
databaseID, err := strconv.ParseInt(pr.FullDatabaseID, 10, 64)
if err != nil {
return fmt.Errorf("failed to parse pull request: %w", err)
}
resourceID = databaseID
}
// TODO(babakks): currently we just fetch a pre-defined number of
// matching sessions to avoid hitting the API too many times, but it's
// technically possible for a PR to be associated with lots of sessions
// (i.e. above our selected limit).
sessions, err := capiClient.ListSessionsByResourceID(ctx, "pull", resourceID, defaultLimit)
if err != nil {
return fmt.Errorf("failed to list sessions for pull request: %w", err)
}
if len(sessions) == 0 {
fmt.Fprintln(opts.IO.ErrOut, "no session found for pull request")
return cmdutil.SilentError
}
return err
session = sessions[0]
if len(sessions) > 1 {
now := time.Now()
options := make([]string, 0, len(sessions))
for _, session := range sessions {
options = append(options, fmt.Sprintf(
"%s %s • %s",
shared.SessionSymbol(cs, session.State),
session.Name,
text.FuzzyAgo(now, session.CreatedAt),
))
}
opts.IO.StopProgressIndicator()
selected, err := opts.Prompter.Select("Select a session", options[0], options)
if err != nil {
return err
}
session = sessions[selected]
}
}
opts.IO.StopProgressIndicator()
out := opts.IO.Out
cs := opts.IO.ColorScheme()
if session.PullRequest != nil {
fmt.Fprintf(out, "%s • %s • %s%s\n",
@ -106,3 +235,30 @@ func viewRun(opts *ViewOptions) error {
return nil
}
func getPullRequestDatabaseID(ctx context.Context, httpClient *http.Client, hostname string, repo ghrepo.Interface, number int) (int64, error) {
var resp struct {
Repository struct {
PullRequest struct {
FullDatabaseID string `graphql:"fullDatabaseId"`
} `graphql:"pullRequest(number: $number)"`
} `graphql:"repository(owner: $owner, name: $repo)"`
}
variables := map[string]interface{}{
"owner": githubv4.String(repo.RepoOwner()),
"repo": githubv4.String(repo.RepoName()),
"number": githubv4.Int(number),
}
apiClient := api.NewClientFromHTTP(httpClient)
if err := apiClient.Query(hostname, "GetPullRequestFullDatabaseID", &resp, variables); err != nil {
return 0, err
}
databaseID, err := strconv.ParseInt(resp.Repository.PullRequest.FullDatabaseID, 10, 64)
if err != nil {
return 0, err
}
return databaseID, nil
}

View file

@ -10,6 +10,7 @@ import (
"github.com/MakeNowJust/heredoc"
"github.com/cli/cli/v2/api"
"github.com/cli/cli/v2/internal/ghrepo"
"github.com/cli/cli/v2/pkg/cmd/agent-task/capi"
"github.com/cli/cli/v2/pkg/cmdutil"
"github.com/cli/cli/v2/pkg/iostreams"
@ -20,20 +21,49 @@ import (
func TestNewCmdList(t *testing.T) {
tests := []struct {
name string
args string
wantOpts ViewOptions
wantErr string
name string
tty bool
args string
wantOpts ViewOptions
wantBaseRepo ghrepo.Interface
wantErr string
}{
{
name: "no arguments",
wantErr: "a session ID is required",
name: "no arg tty",
tty: true,
args: "",
wantOpts: ViewOptions{},
},
{
name: "session ID arg",
args: "some-uuid",
name: "session ID arg tty",
tty: true,
args: "00000000-0000-0000-0000-000000000000",
wantOpts: ViewOptions{
SelectorArg: "some-uuid",
SelectorArg: "00000000-0000-0000-0000-000000000000",
SessionID: "00000000-0000-0000-0000-000000000000",
},
},
{
name: "non-session ID arg tty",
tty: true,
args: "some-arg",
wantOpts: ViewOptions{
SelectorArg: "some-arg",
},
},
{
name: "session ID required if non-tty",
tty: false,
args: "some-arg",
wantErr: "session ID is required when not running interactively",
},
{
name: "repo override",
tty: true,
args: "some-arg -R OWNER/REPO",
wantBaseRepo: ghrepo.New("OWNER", "REPO"),
wantOpts: ViewOptions{
SelectorArg: "some-arg",
},
},
}
@ -41,6 +71,10 @@ func TestNewCmdList(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ios, _, _, _ := iostreams.Test()
ios.SetStdinTTY(tt.tty)
ios.SetStdoutTTY(tt.tty)
ios.SetStderrTTY(tt.tty)
f := &cmdutil.Factory{
IOStreams: ios,
}