Merge pull request #11746 from cli/babakks/add-follow-option-to-agent-task-create

`gh agent-task create`: add `--follow` flag
This commit is contained in:
Kynan Ware 2025-09-16 18:33:26 -06:00 committed by GitHub
commit c55003bba2
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 174 additions and 49 deletions

View file

@ -21,25 +21,38 @@ import (
"github.com/spf13/cobra"
)
const defaultLogPollInterval = 5 * time.Second
// CreateOptions holds options for create command
type CreateOptions struct {
IO *iostreams.IOStreams
BaseRepo func() (ghrepo.Interface, error)
CapiClient func() (capi.CapiClient, error)
Config func() (gh.Config, error)
IO *iostreams.IOStreams
BaseRepo func() (ghrepo.Interface, error)
CapiClient func() (capi.CapiClient, error)
Config func() (gh.Config, error)
LogRenderer func() shared.LogRenderer
Sleep func(d time.Duration)
ProblemStatement string
BackOff backoff.BackOff
BaseBranch string
Prompter prompter.Prompter
ProblemStatementFile string
Follow bool
}
func defaultLogRenderer() shared.LogRenderer {
return shared.NewLogRenderer()
}
func NewCmdCreate(f *cmdutil.Factory, runF func(*CreateOptions) error) *cobra.Command {
opts := &CreateOptions{
IO: f.IOStreams,
CapiClient: shared.CapiClientFunc(f),
Config: f.Config,
Prompter: f.Prompter,
IO: f.IOStreams,
CapiClient: shared.CapiClientFunc(f),
Config: f.Config,
Prompter: f.Prompter,
LogRenderer: defaultLogRenderer,
Sleep: time.Sleep,
}
cmd := &cobra.Command{
@ -70,6 +83,9 @@ func NewCmdCreate(f *cmdutil.Factory, runF func(*CreateOptions) error) *cobra.Co
# Create a task from an inline description
$ gh agent-task create "build me a new app"
# Create a task from an inline description and follow logs
$ gh agent-task create "build me a new app" --follow
# Create a task from a file
$ gh agent-task create -F task-desc.md
@ -91,6 +107,7 @@ func NewCmdCreate(f *cmdutil.Factory, runF func(*CreateOptions) error) *cobra.Co
cmd.Flags().StringVarP(&opts.ProblemStatementFile, "from-file", "F", "", "Read task description from `file` (use \"-\" to read from standard input)")
cmd.Flags().StringVarP(&opts.BaseBranch, "base", "b", "", "Base branch for the pull request (use default branch if not provided)")
cmd.Flags().BoolVar(&opts.Follow, "follow", false, "Follow agent session logs")
return cmd
}
@ -151,40 +168,24 @@ func createRun(opts *CreateOptions) error {
return err
}
// Print this agent session URL and exit if we happen to get it.
// Right now, this never happens.
if job.PullRequest != nil && job.PullRequest.Number > 0 {
fmt.Fprintf(opts.IO.Out, "%s\n", agentSessionWebURL(repo, job))
return nil
}
// Otherwise, poll using exponential backoff until we either observe a PR or hit the overall timeout.
if opts.BackOff == nil {
opts.BackOff = backoff.NewExponentialBackOff(
backoff.WithMaxElapsedTime(10*time.Second),
backoff.WithInitialInterval(300*time.Millisecond),
backoff.WithMaxInterval(10*time.Second),
backoff.WithMultiplier(1.5),
)
}
jobWithPR, err := fetchJobWithBackoff(ctx, client, repo, job.ID, opts.BackOff)
if err != nil {
// If this does happen ever, we still want the user to get the
// fallback message and URL. So, we don't return with this error,
// but we do still want to print it.
fmt.Fprintf(opts.IO.ErrOut, "%v\n", err)
}
if jobWithPR != nil {
opts.IO.StopProgressIndicator()
fmt.Fprintln(opts.IO.Out, agentSessionWebURL(repo, jobWithPR))
return nil
}
// Fallback if PR not yet ready
sessionURL, err := fetchJobSessionURL(ctx, client, repo, job, opts.BackOff)
opts.IO.StopProgressIndicator()
fmt.Fprintf(opts.IO.Out, "job %s queued. View progress: https://github.com/copilot/agents\n", job.ID)
if sessionURL != "" {
fmt.Fprintln(opts.IO.Out, sessionURL)
} else {
if err != nil {
// If this does happen ever, we still want the user to get the fallback
// message and URL. So, we don't return with this error, but we do still
// want to print it.
fmt.Fprintf(opts.IO.ErrOut, "%v\n", err)
}
fmt.Fprintf(opts.IO.Out, "job %s queued. View progress: %s\n", job.ID, capi.AgentsHomeURL)
}
if opts.Follow {
return followLogs(opts, client, job.SessionID)
}
return nil
}
@ -198,6 +199,31 @@ func agentSessionWebURL(repo ghrepo.Interface, j *capi.Job) string {
return fmt.Sprintf("https://github.com/%s/%s/pull/%d/agent-sessions/%s", url.PathEscape(repo.RepoOwner()), url.PathEscape(repo.RepoName()), j.PullRequest.Number, url.PathEscape(j.SessionID))
}
// fetchJobSessionURL tries to return the agent session URL for a job. If the pull
// request is not yet available, ("", nil) is returned.
func fetchJobSessionURL(ctx context.Context, client capi.CapiClient, repo ghrepo.Interface, job *capi.Job, bo backoff.BackOff) (string, error) {
if job.PullRequest != nil && job.PullRequest.Number > 0 {
// Return the agent session URL if we happen to get it.
// Right now, this never happens.
return agentSessionWebURL(repo, job), nil
}
if bo == nil {
bo = backoff.NewExponentialBackOff(
backoff.WithMaxElapsedTime(10*time.Second),
backoff.WithInitialInterval(300*time.Millisecond),
backoff.WithMaxInterval(10*time.Second),
backoff.WithMultiplier(1.5),
)
}
jobWithPR, err := fetchJobWithBackoff(ctx, client, repo, job.ID, bo)
if jobWithPR != nil {
return agentSessionWebURL(repo, jobWithPR), nil
}
return "", err
}
// fetchJobWithBackoff polls the job resource until a PR number is present or the overall
// timeout elapses. It returns the updated Job on success, (nil, nil) on timeout,
// and (nil, error) only for non-retryable failures.
@ -228,3 +254,25 @@ func fetchJobWithBackoff(ctx context.Context, client capi.CapiClient, repo ghrep
}
return result, nil
}
func followLogs(opts *CreateOptions, capiClient capi.CapiClient, sessionID string) error {
ctx := context.Background()
renderer := opts.LogRenderer()
var called bool
fetcher := func() ([]byte, error) {
if called {
opts.Sleep(defaultLogPollInterval)
}
called = true
raw, err := capiClient.GetSessionLogs(ctx, sessionID)
if err != nil {
return nil, err
}
return raw, nil
}
fmt.Fprintln(opts.IO.Out, "")
return renderer.Follow(fetcher, opts.IO.Out, opts.IO)
}

View file

@ -9,13 +9,16 @@ import (
"testing"
"time"
"github.com/MakeNowJust/heredoc"
"github.com/cenkalti/backoff/v4"
"github.com/cli/cli/v2/internal/ghrepo"
"github.com/cli/cli/v2/internal/prompter"
"github.com/cli/cli/v2/pkg/cmd/agent-task/capi"
"github.com/cli/cli/v2/pkg/cmd/agent-task/shared"
"github.com/cli/cli/v2/pkg/cmdutil"
"github.com/cli/cli/v2/pkg/iostreams"
"github.com/google/shlex"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
@ -57,6 +60,15 @@ func TestNewCmdCreate(t *testing.T) {
BaseBranch: "feature",
},
},
{
name: "with --follow",
args: "'task description from args' --follow",
wantOpts: &CreateOptions{
ProblemStatement: "task description from args",
ProblemStatementFile: "",
Follow: true,
},
},
}
for _, tt := range tests {
@ -135,14 +147,15 @@ func Test_createRun(t *testing.T) {
}
tests := []struct {
name string
isTTY bool
capiStubs func(*testing.T, *capi.CapiClientMock)
opts *CreateOptions // input options (IO & BackOff set later)
wantStdout string
wantStdErr string
wantErr string
wantErrIs error
name string
isTTY bool
opts *CreateOptions // input options (IO & BackOff set later)
capiStubs func(*testing.T, *capi.CapiClientMock)
logRendererStubs func(*testing.T, *shared.LogRendererMock)
wantStdout string
wantStdErr string
wantErr string
wantErrIs error
}{
{
name: "interactive with file prompts to edit with file contents",
@ -428,6 +441,62 @@ func Test_createRun(t *testing.T) {
},
wantStdout: "job job123 queued. View progress: https://github.com/copilot/agents\n",
},
{
name: "success with follow logs and delayed PR after polling",
opts: &CreateOptions{
BaseRepo: func() (ghrepo.Interface, error) {
return ghrepo.New("OWNER", "REPO"), nil
},
ProblemStatement: "Do the thing",
Follow: true,
Sleep: func(d time.Duration) {},
},
capiStubs: func(t *testing.T, m *capi.CapiClientMock) {
m.CreateJobFunc = func(ctx context.Context, owner, repo, problemStatement, baseBranch string) (*capi.Job, error) {
require.Equal(t, "OWNER", owner)
require.Equal(t, "REPO", repo)
require.Equal(t, "Do the thing", problemStatement)
require.Equal(t, "", baseBranch)
return &createdJobSuccess, nil
}
m.GetJobFunc = func(ctx context.Context, owner, repo, jobID string) (*capi.Job, error) {
require.Equal(t, "OWNER", owner)
require.Equal(t, "REPO", repo)
require.Equal(t, "job123", jobID)
return &createdJobSuccessWithPR, nil
}
var count int
m.GetSessionLogsFunc = func(_ context.Context, id string) ([]byte, error) {
assert.Equal(t, "sess1", id)
count++
require.Less(t, count, 3, "too many calls to fetch logs")
if count == 1 {
return []byte("<raw-logs-one>"), nil
}
return []byte("<raw-logs-two>"), nil
}
},
logRendererStubs: func(t *testing.T, m *shared.LogRendererMock) {
m.FollowFunc = func(fetcher func() ([]byte, error), w io.Writer, ios *iostreams.IOStreams) error {
raw, err := fetcher()
require.NoError(t, err)
w.Write([]byte("(rendered:) " + string(raw) + "\n"))
raw, err = fetcher()
require.NoError(t, err)
w.Write([]byte("(rendered:) " + string(raw) + "\n"))
return nil
}
},
wantStdout: heredoc.Doc(`
https://github.com/OWNER/REPO/pull/42/agent-sessions/sess1
(rendered:) <raw-logs-one>
(rendered:) <raw-logs-two>
`),
},
}
for _, tt := range tests {
@ -452,6 +521,14 @@ func Test_createRun(t *testing.T) {
// fast backoff
tt.opts.BackOff = backoff.WithMaxRetries(&backoff.ZeroBackOff{}, 3)
logRenderer := &shared.LogRendererMock{}
if tt.logRendererStubs != nil {
tt.logRendererStubs(t, logRenderer)
}
tt.opts.LogRenderer = func() shared.LogRenderer {
return logRenderer
}
err := createRun(tt.opts)
if tt.wantErrIs != nil {
require.ErrorIs(t, err, tt.wantErrIs)