From 8482e3d2a45df18e79e148acf98bb1a51997b529 Mon Sep 17 00:00:00 2001 From: "Babak K. Shandiz" Date: Fri, 5 Sep 2025 17:03:59 +0100 Subject: [PATCH] feat(agent-task view): support PR arg Signed-off-by: Babak K. Shandiz --- pkg/cmd/agent-task/view/view.go | 176 +++++++++++++++++++++++++-- pkg/cmd/agent-task/view/view_test.go | 52 ++++++-- 2 files changed, 209 insertions(+), 19 deletions(-) diff --git a/pkg/cmd/agent-task/view/view.go b/pkg/cmd/agent-task/view/view.go index f6ce4d468..687ed86cb 100644 --- a/pkg/cmd/agent-task/view/view.go +++ b/pkg/cmd/agent-task/view/view.go @@ -4,41 +4,74 @@ import ( "context" "errors" "fmt" + "net/http" "net/url" + "strconv" "time" "github.com/MakeNowJust/heredoc" + "github.com/cli/cli/v2/api" + "github.com/cli/cli/v2/internal/ghinstance" + "github.com/cli/cli/v2/internal/ghrepo" + "github.com/cli/cli/v2/internal/prompter" "github.com/cli/cli/v2/internal/text" "github.com/cli/cli/v2/pkg/cmd/agent-task/capi" "github.com/cli/cli/v2/pkg/cmd/agent-task/shared" prShared "github.com/cli/cli/v2/pkg/cmd/pr/shared" "github.com/cli/cli/v2/pkg/cmdutil" "github.com/cli/cli/v2/pkg/iostreams" + "github.com/shurcooL/githubv4" "github.com/spf13/cobra" ) +const defaultLimit = 60 + type ViewOptions struct { IO *iostreams.IOStreams + BaseRepo func() (ghrepo.Interface, error) CapiClient func() (capi.CapiClient, error) + HttpClient func() (*http.Client, error) + Finder prShared.PRFinder + Prompter prompter.Prompter SelectorArg string + PRNumber int + SessionID string } func NewCmdView(f *cmdutil.Factory, runF func(*ViewOptions) error) *cobra.Command { opts := &ViewOptions{ IO: f.IOStreams, + HttpClient: f.HttpClient, CapiClient: shared.CapiClientFunc(f), + Prompter: f.Prompter, } cmd := &cobra.Command{ - Use: "view ", + Use: "view [ | | | ]", Short: "View an agent task session", Long: heredoc.Doc(` View an agent task session. `), - Args: cmdutil.ExactArgs(1, "a session ID is required"), + Args: cobra.MaximumNArgs(1), RunE: func(cmd *cobra.Command, args []string) error { - opts.SelectorArg = args[0] + // Support -R/--repo override + opts.BaseRepo = f.BaseRepo + + if len(args) > 0 { + opts.SelectorArg = args[0] + if shared.IsSessionID(opts.SelectorArg) { + opts.SessionID = opts.SelectorArg + } + } + + if opts.SessionID == "" && !opts.IO.CanPrompt() { + return fmt.Errorf("session ID is required when not running interactively") + } + + if opts.Finder == nil { + opts.Finder = prShared.NewFinder(f) + } if runF != nil { return runF(opts) @@ -47,6 +80,8 @@ func NewCmdView(f *cmdutil.Factory, runF func(*ViewOptions) error) *cobra.Comman }, } + cmdutil.EnableRepoOverride(cmd, f) + return cmd } @@ -57,23 +92,117 @@ func viewRun(opts *ViewOptions) error { } ctx := context.Background() + cs := opts.IO.ColorScheme() opts.IO.StartProgressIndicatorWithLabel("Fetching agent session...") defer opts.IO.StopProgressIndicator() - session, err := capiClient.GetSession(ctx, opts.SelectorArg) - opts.IO.StopProgressIndicator() + var session *capi.Session - if err != nil { - if errors.Is(err, capi.ErrSessionNotFound) { - fmt.Fprintln(opts.IO.ErrOut, "session not found") + if opts.SessionID != "" { + if sess, err := capiClient.GetSession(ctx, opts.SessionID); err != nil { + if errors.Is(err, capi.ErrSessionNotFound) { + fmt.Fprintln(opts.IO.ErrOut, "session not found") + return cmdutil.SilentError + } + return err + } else { + session = sess + } + } else { + var resourceID int64 + + if opts.SelectorArg != "" { + // Finder does not support the PR/issue reference format (e.g. owner/repo#123) + // so we need to check if the selector arg is a reference and fetch the PR + // directly. + if repo, num, err := prShared.ParseFullReference(opts.SelectorArg); err == nil { + // We need to check the base repo to get the hostname. + baseRepo, err := opts.BaseRepo() + if err != nil { + return err + } + + hostname := baseRepo.RepoHost() + if repo.RepoHost() != hostname { + return fmt.Errorf("agent tasks are not supported on this host: %s", repo.RepoHost()) + } + + client, err := opts.HttpClient() + if err != nil { + return err + } + + resourceID, err = getPullRequestDatabaseID(ctx, client, hostname, repo, num) + if err != nil { + return fmt.Errorf("failed to get pull request: %w", err) + } + } + } + + if resourceID == 0 { + findOptions := prShared.FindOptions{ + Selector: opts.SelectorArg, + Fields: []string{"id", "url", "fullDatabaseId"}, + } + + pr, repo, err := opts.Finder.Find(findOptions) + if err != nil { + return err + } + + if repo.RepoHost() != ghinstance.Default() { + return fmt.Errorf("agent tasks are not supported on this host: %s", repo.RepoHost()) + } + + databaseID, err := strconv.ParseInt(pr.FullDatabaseID, 10, 64) + if err != nil { + return fmt.Errorf("failed to parse pull request: %w", err) + } + + resourceID = databaseID + } + + // TODO(babakks): currently we just fetch a pre-defined number of + // matching sessions to avoid hitting the API too many times, but it's + // technically possible for a PR to be associated with lots of sessions + // (i.e. above our selected limit). + sessions, err := capiClient.ListSessionsByResourceID(ctx, "pull", resourceID, defaultLimit) + if err != nil { + return fmt.Errorf("failed to list sessions for pull request: %w", err) + } + + if len(sessions) == 0 { + fmt.Fprintln(opts.IO.ErrOut, "no session found for pull request") return cmdutil.SilentError } - return err + + session = sessions[0] + if len(sessions) > 1 { + now := time.Now() + options := make([]string, 0, len(sessions)) + for _, session := range sessions { + options = append(options, fmt.Sprintf( + "%s %s • %s", + shared.SessionSymbol(cs, session.State), + session.Name, + text.FuzzyAgo(now, session.CreatedAt), + )) + } + + opts.IO.StopProgressIndicator() + selected, err := opts.Prompter.Select("Select a session", options[0], options) + if err != nil { + return err + } + + session = sessions[selected] + } } + opts.IO.StopProgressIndicator() + out := opts.IO.Out - cs := opts.IO.ColorScheme() if session.PullRequest != nil { fmt.Fprintf(out, "%s • %s • %s%s\n", @@ -106,3 +235,30 @@ func viewRun(opts *ViewOptions) error { return nil } + +func getPullRequestDatabaseID(ctx context.Context, httpClient *http.Client, hostname string, repo ghrepo.Interface, number int) (int64, error) { + var resp struct { + Repository struct { + PullRequest struct { + FullDatabaseID string `graphql:"fullDatabaseId"` + } `graphql:"pullRequest(number: $number)"` + } `graphql:"repository(owner: $owner, name: $repo)"` + } + + variables := map[string]interface{}{ + "owner": githubv4.String(repo.RepoOwner()), + "repo": githubv4.String(repo.RepoName()), + "number": githubv4.Int(number), + } + + apiClient := api.NewClientFromHTTP(httpClient) + if err := apiClient.Query(hostname, "GetPullRequestFullDatabaseID", &resp, variables); err != nil { + return 0, err + } + + databaseID, err := strconv.ParseInt(resp.Repository.PullRequest.FullDatabaseID, 10, 64) + if err != nil { + return 0, err + } + return databaseID, nil +} diff --git a/pkg/cmd/agent-task/view/view_test.go b/pkg/cmd/agent-task/view/view_test.go index 97304c399..5710f3ec3 100644 --- a/pkg/cmd/agent-task/view/view_test.go +++ b/pkg/cmd/agent-task/view/view_test.go @@ -10,6 +10,7 @@ import ( "github.com/MakeNowJust/heredoc" "github.com/cli/cli/v2/api" + "github.com/cli/cli/v2/internal/ghrepo" "github.com/cli/cli/v2/pkg/cmd/agent-task/capi" "github.com/cli/cli/v2/pkg/cmdutil" "github.com/cli/cli/v2/pkg/iostreams" @@ -20,20 +21,49 @@ import ( func TestNewCmdList(t *testing.T) { tests := []struct { - name string - args string - wantOpts ViewOptions - wantErr string + name string + tty bool + args string + wantOpts ViewOptions + wantBaseRepo ghrepo.Interface + wantErr string }{ { - name: "no arguments", - wantErr: "a session ID is required", + name: "no arg tty", + tty: true, + args: "", + wantOpts: ViewOptions{}, }, { - name: "session ID arg", - args: "some-uuid", + name: "session ID arg tty", + tty: true, + args: "00000000-0000-0000-0000-000000000000", wantOpts: ViewOptions{ - SelectorArg: "some-uuid", + SelectorArg: "00000000-0000-0000-0000-000000000000", + SessionID: "00000000-0000-0000-0000-000000000000", + }, + }, + { + name: "non-session ID arg tty", + tty: true, + args: "some-arg", + wantOpts: ViewOptions{ + SelectorArg: "some-arg", + }, + }, + { + name: "session ID required if non-tty", + tty: false, + args: "some-arg", + wantErr: "session ID is required when not running interactively", + }, + { + name: "repo override", + tty: true, + args: "some-arg -R OWNER/REPO", + wantBaseRepo: ghrepo.New("OWNER", "REPO"), + wantOpts: ViewOptions{ + SelectorArg: "some-arg", }, }, } @@ -41,6 +71,10 @@ func TestNewCmdList(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { ios, _, _, _ := iostreams.Test() + ios.SetStdinTTY(tt.tty) + ios.SetStdoutTTY(tt.tty) + ios.SetStderrTTY(tt.tty) + f := &cmdutil.Factory{ IOStreams: ios, }