Merge pull request #11746 from cli/babakks/add-follow-option-to-agent-task-create
`gh agent-task create`: add `--follow` flag
This commit is contained in:
commit
c55003bba2
2 changed files with 174 additions and 49 deletions
|
|
@ -21,25 +21,38 @@ import (
|
|||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
const defaultLogPollInterval = 5 * time.Second
|
||||
|
||||
// CreateOptions holds options for create command
|
||||
type CreateOptions struct {
|
||||
IO *iostreams.IOStreams
|
||||
BaseRepo func() (ghrepo.Interface, error)
|
||||
CapiClient func() (capi.CapiClient, error)
|
||||
Config func() (gh.Config, error)
|
||||
IO *iostreams.IOStreams
|
||||
BaseRepo func() (ghrepo.Interface, error)
|
||||
CapiClient func() (capi.CapiClient, error)
|
||||
Config func() (gh.Config, error)
|
||||
|
||||
LogRenderer func() shared.LogRenderer
|
||||
Sleep func(d time.Duration)
|
||||
|
||||
ProblemStatement string
|
||||
BackOff backoff.BackOff
|
||||
BaseBranch string
|
||||
Prompter prompter.Prompter
|
||||
ProblemStatementFile string
|
||||
Follow bool
|
||||
}
|
||||
|
||||
func defaultLogRenderer() shared.LogRenderer {
|
||||
return shared.NewLogRenderer()
|
||||
}
|
||||
|
||||
func NewCmdCreate(f *cmdutil.Factory, runF func(*CreateOptions) error) *cobra.Command {
|
||||
opts := &CreateOptions{
|
||||
IO: f.IOStreams,
|
||||
CapiClient: shared.CapiClientFunc(f),
|
||||
Config: f.Config,
|
||||
Prompter: f.Prompter,
|
||||
IO: f.IOStreams,
|
||||
CapiClient: shared.CapiClientFunc(f),
|
||||
Config: f.Config,
|
||||
Prompter: f.Prompter,
|
||||
LogRenderer: defaultLogRenderer,
|
||||
Sleep: time.Sleep,
|
||||
}
|
||||
|
||||
cmd := &cobra.Command{
|
||||
|
|
@ -70,6 +83,9 @@ func NewCmdCreate(f *cmdutil.Factory, runF func(*CreateOptions) error) *cobra.Co
|
|||
# Create a task from an inline description
|
||||
$ gh agent-task create "build me a new app"
|
||||
|
||||
# Create a task from an inline description and follow logs
|
||||
$ gh agent-task create "build me a new app" --follow
|
||||
|
||||
# Create a task from a file
|
||||
$ gh agent-task create -F task-desc.md
|
||||
|
||||
|
|
@ -91,6 +107,7 @@ func NewCmdCreate(f *cmdutil.Factory, runF func(*CreateOptions) error) *cobra.Co
|
|||
|
||||
cmd.Flags().StringVarP(&opts.ProblemStatementFile, "from-file", "F", "", "Read task description from `file` (use \"-\" to read from standard input)")
|
||||
cmd.Flags().StringVarP(&opts.BaseBranch, "base", "b", "", "Base branch for the pull request (use default branch if not provided)")
|
||||
cmd.Flags().BoolVar(&opts.Follow, "follow", false, "Follow agent session logs")
|
||||
|
||||
return cmd
|
||||
}
|
||||
|
|
@ -151,40 +168,24 @@ func createRun(opts *CreateOptions) error {
|
|||
return err
|
||||
}
|
||||
|
||||
// Print this agent session URL and exit if we happen to get it.
|
||||
// Right now, this never happens.
|
||||
if job.PullRequest != nil && job.PullRequest.Number > 0 {
|
||||
fmt.Fprintf(opts.IO.Out, "%s\n", agentSessionWebURL(repo, job))
|
||||
return nil
|
||||
}
|
||||
|
||||
// Otherwise, poll using exponential backoff until we either observe a PR or hit the overall timeout.
|
||||
if opts.BackOff == nil {
|
||||
opts.BackOff = backoff.NewExponentialBackOff(
|
||||
backoff.WithMaxElapsedTime(10*time.Second),
|
||||
backoff.WithInitialInterval(300*time.Millisecond),
|
||||
backoff.WithMaxInterval(10*time.Second),
|
||||
backoff.WithMultiplier(1.5),
|
||||
)
|
||||
}
|
||||
|
||||
jobWithPR, err := fetchJobWithBackoff(ctx, client, repo, job.ID, opts.BackOff)
|
||||
if err != nil {
|
||||
// If this does happen ever, we still want the user to get the
|
||||
// fallback message and URL. So, we don't return with this error,
|
||||
// but we do still want to print it.
|
||||
fmt.Fprintf(opts.IO.ErrOut, "%v\n", err)
|
||||
}
|
||||
|
||||
if jobWithPR != nil {
|
||||
opts.IO.StopProgressIndicator()
|
||||
fmt.Fprintln(opts.IO.Out, agentSessionWebURL(repo, jobWithPR))
|
||||
return nil
|
||||
}
|
||||
|
||||
// Fallback if PR not yet ready
|
||||
sessionURL, err := fetchJobSessionURL(ctx, client, repo, job, opts.BackOff)
|
||||
opts.IO.StopProgressIndicator()
|
||||
fmt.Fprintf(opts.IO.Out, "job %s queued. View progress: https://github.com/copilot/agents\n", job.ID)
|
||||
|
||||
if sessionURL != "" {
|
||||
fmt.Fprintln(opts.IO.Out, sessionURL)
|
||||
} else {
|
||||
if err != nil {
|
||||
// If this does happen ever, we still want the user to get the fallback
|
||||
// message and URL. So, we don't return with this error, but we do still
|
||||
// want to print it.
|
||||
fmt.Fprintf(opts.IO.ErrOut, "%v\n", err)
|
||||
}
|
||||
fmt.Fprintf(opts.IO.Out, "job %s queued. View progress: %s\n", job.ID, capi.AgentsHomeURL)
|
||||
}
|
||||
|
||||
if opts.Follow {
|
||||
return followLogs(opts, client, job.SessionID)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
|
|
@ -198,6 +199,31 @@ func agentSessionWebURL(repo ghrepo.Interface, j *capi.Job) string {
|
|||
return fmt.Sprintf("https://github.com/%s/%s/pull/%d/agent-sessions/%s", url.PathEscape(repo.RepoOwner()), url.PathEscape(repo.RepoName()), j.PullRequest.Number, url.PathEscape(j.SessionID))
|
||||
}
|
||||
|
||||
// fetchJobSessionURL tries to return the agent session URL for a job. If the pull
|
||||
// request is not yet available, ("", nil) is returned.
|
||||
func fetchJobSessionURL(ctx context.Context, client capi.CapiClient, repo ghrepo.Interface, job *capi.Job, bo backoff.BackOff) (string, error) {
|
||||
if job.PullRequest != nil && job.PullRequest.Number > 0 {
|
||||
// Return the agent session URL if we happen to get it.
|
||||
// Right now, this never happens.
|
||||
return agentSessionWebURL(repo, job), nil
|
||||
}
|
||||
|
||||
if bo == nil {
|
||||
bo = backoff.NewExponentialBackOff(
|
||||
backoff.WithMaxElapsedTime(10*time.Second),
|
||||
backoff.WithInitialInterval(300*time.Millisecond),
|
||||
backoff.WithMaxInterval(10*time.Second),
|
||||
backoff.WithMultiplier(1.5),
|
||||
)
|
||||
}
|
||||
|
||||
jobWithPR, err := fetchJobWithBackoff(ctx, client, repo, job.ID, bo)
|
||||
if jobWithPR != nil {
|
||||
return agentSessionWebURL(repo, jobWithPR), nil
|
||||
}
|
||||
return "", err
|
||||
}
|
||||
|
||||
// fetchJobWithBackoff polls the job resource until a PR number is present or the overall
|
||||
// timeout elapses. It returns the updated Job on success, (nil, nil) on timeout,
|
||||
// and (nil, error) only for non-retryable failures.
|
||||
|
|
@ -228,3 +254,25 @@ func fetchJobWithBackoff(ctx context.Context, client capi.CapiClient, repo ghrep
|
|||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func followLogs(opts *CreateOptions, capiClient capi.CapiClient, sessionID string) error {
|
||||
ctx := context.Background()
|
||||
|
||||
renderer := opts.LogRenderer()
|
||||
|
||||
var called bool
|
||||
fetcher := func() ([]byte, error) {
|
||||
if called {
|
||||
opts.Sleep(defaultLogPollInterval)
|
||||
}
|
||||
called = true
|
||||
raw, err := capiClient.GetSessionLogs(ctx, sessionID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return raw, nil
|
||||
}
|
||||
|
||||
fmt.Fprintln(opts.IO.Out, "")
|
||||
return renderer.Follow(fetcher, opts.IO.Out, opts.IO)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -9,13 +9,16 @@ import (
|
|||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/MakeNowJust/heredoc"
|
||||
"github.com/cenkalti/backoff/v4"
|
||||
"github.com/cli/cli/v2/internal/ghrepo"
|
||||
"github.com/cli/cli/v2/internal/prompter"
|
||||
"github.com/cli/cli/v2/pkg/cmd/agent-task/capi"
|
||||
"github.com/cli/cli/v2/pkg/cmd/agent-task/shared"
|
||||
"github.com/cli/cli/v2/pkg/cmdutil"
|
||||
"github.com/cli/cli/v2/pkg/iostreams"
|
||||
"github.com/google/shlex"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
|
|
@ -57,6 +60,15 @@ func TestNewCmdCreate(t *testing.T) {
|
|||
BaseBranch: "feature",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "with --follow",
|
||||
args: "'task description from args' --follow",
|
||||
wantOpts: &CreateOptions{
|
||||
ProblemStatement: "task description from args",
|
||||
ProblemStatementFile: "",
|
||||
Follow: true,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
|
|
@ -135,14 +147,15 @@ func Test_createRun(t *testing.T) {
|
|||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
isTTY bool
|
||||
capiStubs func(*testing.T, *capi.CapiClientMock)
|
||||
opts *CreateOptions // input options (IO & BackOff set later)
|
||||
wantStdout string
|
||||
wantStdErr string
|
||||
wantErr string
|
||||
wantErrIs error
|
||||
name string
|
||||
isTTY bool
|
||||
opts *CreateOptions // input options (IO & BackOff set later)
|
||||
capiStubs func(*testing.T, *capi.CapiClientMock)
|
||||
logRendererStubs func(*testing.T, *shared.LogRendererMock)
|
||||
wantStdout string
|
||||
wantStdErr string
|
||||
wantErr string
|
||||
wantErrIs error
|
||||
}{
|
||||
{
|
||||
name: "interactive with file prompts to edit with file contents",
|
||||
|
|
@ -428,6 +441,62 @@ func Test_createRun(t *testing.T) {
|
|||
},
|
||||
wantStdout: "job job123 queued. View progress: https://github.com/copilot/agents\n",
|
||||
},
|
||||
{
|
||||
name: "success with follow logs and delayed PR after polling",
|
||||
opts: &CreateOptions{
|
||||
BaseRepo: func() (ghrepo.Interface, error) {
|
||||
return ghrepo.New("OWNER", "REPO"), nil
|
||||
},
|
||||
ProblemStatement: "Do the thing",
|
||||
Follow: true,
|
||||
Sleep: func(d time.Duration) {},
|
||||
},
|
||||
capiStubs: func(t *testing.T, m *capi.CapiClientMock) {
|
||||
m.CreateJobFunc = func(ctx context.Context, owner, repo, problemStatement, baseBranch string) (*capi.Job, error) {
|
||||
require.Equal(t, "OWNER", owner)
|
||||
require.Equal(t, "REPO", repo)
|
||||
require.Equal(t, "Do the thing", problemStatement)
|
||||
require.Equal(t, "", baseBranch)
|
||||
return &createdJobSuccess, nil
|
||||
}
|
||||
m.GetJobFunc = func(ctx context.Context, owner, repo, jobID string) (*capi.Job, error) {
|
||||
require.Equal(t, "OWNER", owner)
|
||||
require.Equal(t, "REPO", repo)
|
||||
require.Equal(t, "job123", jobID)
|
||||
return &createdJobSuccessWithPR, nil
|
||||
}
|
||||
|
||||
var count int
|
||||
m.GetSessionLogsFunc = func(_ context.Context, id string) ([]byte, error) {
|
||||
assert.Equal(t, "sess1", id)
|
||||
|
||||
count++
|
||||
require.Less(t, count, 3, "too many calls to fetch logs")
|
||||
if count == 1 {
|
||||
return []byte("<raw-logs-one>"), nil
|
||||
}
|
||||
return []byte("<raw-logs-two>"), nil
|
||||
}
|
||||
},
|
||||
logRendererStubs: func(t *testing.T, m *shared.LogRendererMock) {
|
||||
m.FollowFunc = func(fetcher func() ([]byte, error), w io.Writer, ios *iostreams.IOStreams) error {
|
||||
raw, err := fetcher()
|
||||
require.NoError(t, err)
|
||||
w.Write([]byte("(rendered:) " + string(raw) + "\n"))
|
||||
|
||||
raw, err = fetcher()
|
||||
require.NoError(t, err)
|
||||
w.Write([]byte("(rendered:) " + string(raw) + "\n"))
|
||||
return nil
|
||||
}
|
||||
},
|
||||
wantStdout: heredoc.Doc(`
|
||||
https://github.com/OWNER/REPO/pull/42/agent-sessions/sess1
|
||||
|
||||
(rendered:) <raw-logs-one>
|
||||
(rendered:) <raw-logs-two>
|
||||
`),
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
|
|
@ -452,6 +521,14 @@ func Test_createRun(t *testing.T) {
|
|||
// fast backoff
|
||||
tt.opts.BackOff = backoff.WithMaxRetries(&backoff.ZeroBackOff{}, 3)
|
||||
|
||||
logRenderer := &shared.LogRendererMock{}
|
||||
if tt.logRendererStubs != nil {
|
||||
tt.logRendererStubs(t, logRenderer)
|
||||
}
|
||||
tt.opts.LogRenderer = func() shared.LogRenderer {
|
||||
return logRenderer
|
||||
}
|
||||
|
||||
err := createRun(tt.opts)
|
||||
if tt.wantErrIs != nil {
|
||||
require.ErrorIs(t, err, tt.wantErrIs)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue