diff --git a/pkg/cmd/agent-task/create/create_test.go b/pkg/cmd/agent-task/create/create_test.go index b855ce687..d1298e606 100644 --- a/pkg/cmd/agent-task/create/create_test.go +++ b/pkg/cmd/agent-task/create/create_test.go @@ -9,13 +9,16 @@ import ( "testing" "time" + "github.com/MakeNowJust/heredoc" "github.com/cenkalti/backoff/v4" "github.com/cli/cli/v2/internal/ghrepo" "github.com/cli/cli/v2/internal/prompter" "github.com/cli/cli/v2/pkg/cmd/agent-task/capi" + "github.com/cli/cli/v2/pkg/cmd/agent-task/shared" "github.com/cli/cli/v2/pkg/cmdutil" "github.com/cli/cli/v2/pkg/iostreams" "github.com/google/shlex" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -57,6 +60,15 @@ func TestNewCmdCreate(t *testing.T) { BaseBranch: "feature", }, }, + { + name: "with --follow", + args: "'task description from args' --follow", + wantOpts: &CreateOptions{ + ProblemStatement: "task description from args", + ProblemStatementFile: "", + Follow: true, + }, + }, } for _, tt := range tests { @@ -135,14 +147,15 @@ func Test_createRun(t *testing.T) { } tests := []struct { - name string - isTTY bool - capiStubs func(*testing.T, *capi.CapiClientMock) - opts *CreateOptions // input options (IO & BackOff set later) - wantStdout string - wantStdErr string - wantErr string - wantErrIs error + name string + isTTY bool + opts *CreateOptions // input options (IO & BackOff set later) + capiStubs func(*testing.T, *capi.CapiClientMock) + logRendererStubs func(*testing.T, *shared.LogRendererMock) + wantStdout string + wantStdErr string + wantErr string + wantErrIs error }{ { name: "interactive with file prompts to edit with file contents", @@ -428,6 +441,62 @@ func Test_createRun(t *testing.T) { }, wantStdout: "job job123 queued. View progress: https://github.com/copilot/agents\n", }, + { + name: "success with follow logs and delayed PR after polling", + opts: &CreateOptions{ + BaseRepo: func() (ghrepo.Interface, error) { + return ghrepo.New("OWNER", "REPO"), nil + }, + ProblemStatement: "Do the thing", + Follow: true, + Sleep: func(d time.Duration) {}, + }, + capiStubs: func(t *testing.T, m *capi.CapiClientMock) { + m.CreateJobFunc = func(ctx context.Context, owner, repo, problemStatement, baseBranch string) (*capi.Job, error) { + require.Equal(t, "OWNER", owner) + require.Equal(t, "REPO", repo) + require.Equal(t, "Do the thing", problemStatement) + require.Equal(t, "", baseBranch) + return &createdJobSuccess, nil + } + m.GetJobFunc = func(ctx context.Context, owner, repo, jobID string) (*capi.Job, error) { + require.Equal(t, "OWNER", owner) + require.Equal(t, "REPO", repo) + require.Equal(t, "job123", jobID) + return &createdJobSuccessWithPR, nil + } + + var count int + m.GetSessionLogsFunc = func(_ context.Context, id string) ([]byte, error) { + assert.Equal(t, "sess1", id) + + count++ + require.Less(t, count, 3, "too many calls to fetch logs") + if count == 1 { + return []byte(""), nil + } + return []byte(""), nil + } + }, + logRendererStubs: func(t *testing.T, m *shared.LogRendererMock) { + m.FollowFunc = func(fetcher func() ([]byte, error), w io.Writer, ios *iostreams.IOStreams) error { + raw, err := fetcher() + require.NoError(t, err) + w.Write([]byte("(rendered:) " + string(raw) + "\n")) + + raw, err = fetcher() + require.NoError(t, err) + w.Write([]byte("(rendered:) " + string(raw) + "\n")) + return nil + } + }, + wantStdout: heredoc.Doc(` + https://github.com/OWNER/REPO/pull/42/agent-sessions/sess1 + + (rendered:) + (rendered:) + `), + }, } for _, tt := range tests { @@ -452,6 +521,14 @@ func Test_createRun(t *testing.T) { // fast backoff tt.opts.BackOff = backoff.WithMaxRetries(&backoff.ZeroBackOff{}, 3) + logRenderer := &shared.LogRendererMock{} + if tt.logRendererStubs != nil { + tt.logRendererStubs(t, logRenderer) + } + tt.opts.LogRenderer = func() shared.LogRenderer { + return logRenderer + } + err := createRun(tt.opts) if tt.wantErrIs != nil { require.ErrorIs(t, err, tt.wantErrIs)