diff --git a/pkg/cmd/agent-task/create/create.go b/pkg/cmd/agent-task/create/create.go index c0d1c292e..08c1f8c02 100644 --- a/pkg/cmd/agent-task/create/create.go +++ b/pkg/cmd/agent-task/create/create.go @@ -23,14 +23,15 @@ import ( // CreateOptions holds options for create command type CreateOptions struct { - IO *iostreams.IOStreams - BaseRepo func() (ghrepo.Interface, error) - CapiClient func() (capi.CapiClient, error) - Config func() (gh.Config, error) - ProblemStatement string - BackOff backoff.BackOff - BaseBranch string - Prompter prompter.Prompter + IO *iostreams.IOStreams + BaseRepo func() (ghrepo.Interface, error) + CapiClient func() (capi.CapiClient, error) + Config func() (gh.Config, error) + ProblemStatement string + BackOff backoff.BackOff + BaseBranch string + Prompter prompter.Prompter + ProblemStatementFile string } func NewCmdCreate(f *cmdutil.Factory, runF func(*CreateOptions) error) *cobra.Command { @@ -41,8 +42,6 @@ func NewCmdCreate(f *cmdutil.Factory, runF func(*CreateOptions) error) *cobra.Co Prompter: f.Prompter, } - var fromFileName string - cmd := &cobra.Command{ Use: "create [] [flags]", Short: "Create an agent task (preview)", @@ -51,23 +50,15 @@ func NewCmdCreate(f *cmdutil.Factory, runF func(*CreateOptions) error) *cobra.Co // Support -R/--repo override opts.BaseRepo = f.BaseRepo - if err := cmdutil.MutuallyExclusive("only one of -F or arg can be provided", len(args) > 0, fromFileName != ""); err != nil { + if err := cmdutil.MutuallyExclusive("only one of -F or arg can be provided", len(args) > 0, opts.ProblemStatementFile != ""); err != nil { return err } - // Gather arg inputs for ProblemStatement + // Populate ProblemStatement from arg if len(args) > 0 { opts.ProblemStatement = args[0] - } else if fromFileName != "" { - fileContent, err := cmdutil.ReadFile(fromFileName, opts.IO.In) - if err != nil { - return cmdutil.FlagErrorf("could not read task description file: %v", err) - } - trimmed := strings.TrimSpace(string(fileContent)) - if trimmed == "" { - return cmdutil.FlagErrorf("task description file is empty") - } - opts.ProblemStatement = trimmed + } else if opts.ProblemStatementFile == "" && !opts.IO.CanPrompt() { + return cmdutil.FlagErrorf("a task description or -F is required when running non-interactively") } if runF != nil { @@ -85,9 +76,12 @@ func NewCmdCreate(f *cmdutil.Factory, runF func(*CreateOptions) error) *cobra.Co # Create a task with problem statement from stdin $ echo "build me a new app" | gh agent-task create -F - - # Create a task with an editor prompt (interactive) + # Create a task with an editor $ gh agent-task create + # Create a task with an editor and a file as a template + $ gh agent-task create -F task-desc.md + # Select a different base branch for the PR $ gh agent-task create "fix errors" --base branch `), @@ -95,7 +89,7 @@ func NewCmdCreate(f *cmdutil.Factory, runF func(*CreateOptions) error) *cobra.Co cmdutil.EnableRepoOverride(cmd, f) - cmd.Flags().StringVarP(&fromFileName, "from-file", "F", "", "Read task description from `file` (use \"-\" to read from standard input)") + cmd.Flags().StringVarP(&opts.ProblemStatementFile, "from-file", "F", "", "Read task description from `file` (use \"-\" to read from standard input)") cmd.Flags().StringVarP(&opts.BaseBranch, "base", "b", "", "Base branch for the pull request (use default branch if not provided)") return cmd @@ -110,20 +104,37 @@ func createRun(opts *CreateOptions) error { } if opts.ProblemStatement == "" { - if !opts.IO.CanPrompt() { - return cmdutil.FlagErrorf("a task description or -F is required when running non-interactively") + // Load initial problem statement from file, if provided + if opts.ProblemStatementFile != "" { + fileContent, err := cmdutil.ReadFile(opts.ProblemStatementFile, opts.IO.In) + if err != nil { + return cmdutil.FlagErrorf("could not read task description file: %v", err) + } + opts.ProblemStatement = strings.TrimSpace(string(fileContent)) } - desc, err := opts.Prompter.MarkdownEditor("Enter the task description", opts.ProblemStatement, false) + if opts.IO.CanPrompt() { + desc, err := opts.Prompter.MarkdownEditor("Enter the task description", opts.ProblemStatement, false) + if err != nil { + return err + } + opts.ProblemStatement = strings.TrimSpace(desc) + } + } + + if opts.ProblemStatement == "" { + fmt.Fprintf(opts.IO.ErrOut, "a task description is required.\n") + return cmdutil.SilentError + } + + if opts.IO.CanPrompt() { + confirm, err := opts.Prompter.Confirm("Submit agent task", true) if err != nil { return err } - - trimmed := strings.TrimSpace(desc) - if trimmed == "" { - return cmdutil.FlagErrorf("a task description is required") + if !confirm { + return cmdutil.SilentError } - opts.ProblemStatement = trimmed } client, err := opts.CapiClient() diff --git a/pkg/cmd/agent-task/create/create_test.go b/pkg/cmd/agent-task/create/create_test.go index edf03f5c9..b855ce687 100644 --- a/pkg/cmd/agent-task/create/create_test.go +++ b/pkg/cmd/agent-task/create/create_test.go @@ -3,7 +3,6 @@ package create import ( "context" "errors" - "fmt" "io" "os" "path/filepath" @@ -21,46 +20,27 @@ import ( ) func TestNewCmdCreate(t *testing.T) { - tmpDir := t.TempDir() - - tmpEmptyFile := filepath.Join(tmpDir, "empty-task-description.md") - err := os.WriteFile(tmpEmptyFile, []byte(" \n\n"), 0600) - require.NoError(t, err) - - tmpFile := filepath.Join(tmpDir, "task-description.md") - err = os.WriteFile(tmpFile, []byte("task description from file"), 0600) - require.NoError(t, err) - tests := []struct { name string args string - stdin string - wantOpts *CreateOptions // nil when expecting error + tty bool + wantOpts *CreateOptions wantErr string }{ { name: "no args nor file returns no error (prompting path)", + tty: true, + wantOpts: &CreateOptions{ + ProblemStatement: "", + ProblemStatementFile: "", + }, }, { name: "arg only success", args: "'task description from args'", wantOpts: &CreateOptions{ - ProblemStatement: "task description from args", - }, - }, - { - name: "from-file success", - args: fmt.Sprintf("-F '%s'", tmpFile), - wantOpts: &CreateOptions{ - ProblemStatement: "task description from file", - }, - }, - { - name: "file content from stdin success", - args: "-F -", - stdin: "task description from stdin", - wantOpts: &CreateOptions{ - ProblemStatement: "task description from stdin", + ProblemStatement: "task description from args", + ProblemStatementFile: "", }, }, { @@ -69,29 +49,25 @@ func TestNewCmdCreate(t *testing.T) { wantErr: "only one of -F or arg can be provided", }, { - name: "missing file path", - args: "-F does-not-exist.md", - wantErr: "could not read task description file: open does-not-exist.md:", - }, - { - name: "empty file", - args: fmt.Sprintf("-F '%s'", tmpEmptyFile), - wantErr: "task description file is empty", - }, - { - name: "empty from stdin", - args: "-F -", - stdin: " \n\n", - wantErr: "task description file is empty", + name: "base branch sets baseBranch field", + args: "'task description' -b feature", + wantOpts: &CreateOptions{ + ProblemStatement: "task description", + ProblemStatementFile: "", + BaseBranch: "feature", + }, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { ios, stdin, _, _ := iostreams.Test() - f := &cmdutil.Factory{ - IOStreams: ios, + if tt.tty { + ios.SetStdinTTY(true) + ios.SetStdoutTTY(true) + ios.SetStderrTTY(true) } + f := &cmdutil.Factory{IOStreams: ios} var gotOpts *CreateOptions cmd := NewCmdCreate(f, func(o *CreateOptions) error { @@ -102,31 +78,33 @@ func TestNewCmdCreate(t *testing.T) { argv, err := shlex.Split(tt.args) require.NoError(t, err) cmd.SetArgs(argv) - cmd.SetIn(stdin) cmd.SetOut(io.Discard) cmd.SetErr(io.Discard) - if tt.stdin != "" { - stdin.WriteString(tt.stdin) - } - _, err = cmd.ExecuteC() - if tt.wantErr != "" { - require.ErrorContains(t, err, tt.wantErr) - return + require.Error(t, err, tt.wantErr) + } else { + require.NoError(t, err) } - require.NoError(t, err) if tt.wantOpts != nil { require.Equal(t, tt.wantOpts.ProblemStatement, gotOpts.ProblemStatement) + require.Equal(t, tt.wantOpts.ProblemStatementFile, gotOpts.ProblemStatementFile) + require.Equal(t, tt.wantOpts.BaseBranch, gotOpts.BaseBranch) } }) } } func Test_createRun(t *testing.T) { + tmpDir := t.TempDir() + taskDescFile := filepath.Join(tmpDir, "task-description.md") + emptyTaskDescFile := filepath.Join(tmpDir, "empty-task-description.md") + require.NoError(t, os.WriteFile(taskDescFile, []byte("task description from file"), 0600)) + require.NoError(t, os.WriteFile(emptyTaskDescFile, []byte(" \n\n"), 0600)) + sampleDateString := "2025-08-29T00:00:00Z" sampleDate, err := time.Parse(time.RFC3339, sampleDateString) require.NoError(t, err) @@ -157,64 +135,169 @@ func Test_createRun(t *testing.T) { } tests := []struct { - name string - capiStubs func(*testing.T, *capi.CapiClientMock) - baseRepoFunc func() (ghrepo.Interface, error) - baseBranch string - isTTY bool - prompterMock *prompter.PrompterMock - problemStatement string - wantStdout string - wantStdErr string - wantErr string + name string + isTTY bool + capiStubs func(*testing.T, *capi.CapiClientMock) + opts *CreateOptions // input options (IO & BackOff set later) + wantStdout string + wantStdErr string + wantErr string + wantErrIs error }{ { - name: "missing repo returns error", - baseRepoFunc: func() (ghrepo.Interface, error) { return nil, nil }, - wantErr: "a repository is required; re-run in a repository or supply one with --repo owner/name", + name: "interactive with file prompts to edit with file contents", + opts: &CreateOptions{ + BaseRepo: func() (ghrepo.Interface, error) { return ghrepo.New("OWNER", "REPO"), nil }, + ProblemStatement: "", + ProblemStatementFile: taskDescFile, + Prompter: &prompter.PrompterMock{ + MarkdownEditorFunc: func(prompt, defaultValue string, blankAllowed bool) (string, error) { + require.Equal(t, "Enter the task description", prompt) + require.Equal(t, "task description from file", defaultValue) + return "edited task description", nil + }, + ConfirmFunc: func(message string, defaultValue bool) (bool, error) { + require.Equal(t, "Submit agent task", message) + return true, nil + }, + }, + }, + isTTY: true, + capiStubs: func(t *testing.T, m *capi.CapiClientMock) { + m.CreateJobFunc = func(ctx context.Context, owner, repo, problemStatement, baseBranch string) (*capi.Job, error) { + require.Equal(t, "OWNER", owner) + require.Equal(t, "REPO", repo) + require.Equal(t, "edited task description", problemStatement) + return &createdJobSuccessWithPR, nil + } + }, + wantStdout: "https://github.com/OWNER/REPO/pull/42/agent-sessions/sess1\n", }, { - name: "non-interactive empty description returns error", - baseRepoFunc: func() (ghrepo.Interface, error) { return ghrepo.New("OWNER", "REPO"), nil }, - problemStatement: "", - wantErr: "a task description or -F is required when running non-interactively", + name: "interactively rejecting confirmation prompt aborts task creation", + opts: &CreateOptions{ + BaseRepo: func() (ghrepo.Interface, error) { return ghrepo.New("OWNER", "REPO"), nil }, + ProblemStatement: "", + Prompter: &prompter.PrompterMock{ + MarkdownEditorFunc: func(prompt, defaultValue string, blankAllowed bool) (string, error) { + require.Equal(t, "Enter the task description", prompt) + return "From editor", nil + }, + ConfirmFunc: func(message string, defaultValue bool) (bool, error) { + require.Equal(t, "Submit agent task", message) + return false, nil + }, + }, + }, + isTTY: true, + wantErr: "SilentError", + wantErrIs: cmdutil.SilentError, + wantStdErr: "", }, { - name: "interactive prompt success", - baseRepoFunc: func() (ghrepo.Interface, error) { return ghrepo.New("OWNER", "REPO"), nil }, - isTTY: true, - problemStatement: "", + name: "interactively entering task description with editor, no file", + isTTY: true, + opts: &CreateOptions{ + BaseRepo: func() (ghrepo.Interface, error) { + return ghrepo.New("OWNER", "REPO"), nil + }, + ProblemStatement: "", + Prompter: &prompter.PrompterMock{ + MarkdownEditorFunc: func(prompt, defaultValue string, blankAllowed bool) (string, error) { + require.Equal(t, "Enter the task description", prompt) + return "From editor", nil + }, + ConfirmFunc: func(message string, defaultValue bool) (bool, error) { + require.Equal(t, "Submit agent task", message) + return true, nil + }, + }, + }, capiStubs: func(t *testing.T, m *capi.CapiClientMock) { m.CreateJobFunc = func(ctx context.Context, owner, repo, problemStatement, baseBranch string) (*capi.Job, error) { require.Equal(t, "From editor", problemStatement) return &createdJobSuccessWithPR, nil } }, - prompterMock: &prompter.PrompterMock{ - MarkdownEditorFunc: func(prompt, defaultValue string, blankAllowed bool) (string, error) { - require.Equal(t, "Enter the task description", prompt) - return "From editor", nil + wantStdout: "https://github.com/OWNER/REPO/pull/42/agent-sessions/sess1\n", + }, + { + name: "empty task description from interactive prompt returns error", + isTTY: true, + opts: &CreateOptions{ + BaseRepo: func() (ghrepo.Interface, error) { + return ghrepo.New("OWNER", "REPO"), nil }, + Prompter: &prompter.PrompterMock{ + MarkdownEditorFunc: func(prompt, defaultValue string, blankAllowed bool) (string, error) { + return " ", nil + }, + }, + }, + wantErr: "SilentError", + wantErrIs: cmdutil.SilentError, + wantStdErr: "a task description is required.\n", + }, + { + name: "problem statement loaded from file non-interactively doesn't prompt or return error", + opts: &CreateOptions{ + BaseRepo: func() (ghrepo.Interface, error) { return ghrepo.New("OWNER", "REPO"), nil }, + ProblemStatement: "", + ProblemStatementFile: taskDescFile, + }, + capiStubs: func(t *testing.T, m *capi.CapiClientMock) { + m.CreateJobFunc = func(ctx context.Context, owner, repo, problemStatement, baseBranch string) (*capi.Job, error) { + require.Equal(t, "OWNER", owner) + require.Equal(t, "REPO", repo) + require.Equal(t, "task description from file", problemStatement) + return &createdJobSuccessWithPR, nil + } }, wantStdout: "https://github.com/OWNER/REPO/pull/42/agent-sessions/sess1\n", }, { - name: "interactive prompt empty returns error", - baseRepoFunc: func() (ghrepo.Interface, error) { return ghrepo.New("OWNER", "REPO"), nil }, - isTTY: true, - problemStatement: "", - prompterMock: &prompter.PrompterMock{ - MarkdownEditorFunc: func(prompt, defaultValue string, blankAllowed bool) (string, error) { - return " ", nil - }, - }, - wantErr: "a task description is required", + name: "missing repo returns error", + opts: &CreateOptions{ + BaseRepo: func() (ghrepo.Interface, error) { + return nil, nil + }}, + wantErr: "a repository is required; re-run in a repository or supply one with --repo owner/name", }, { - name: "base branch included in create payload", - baseRepoFunc: func() (ghrepo.Interface, error) { return ghrepo.New("OWNER", "REPO"), nil }, - baseBranch: "feature", - problemStatement: "Do the thing", + name: "non-interactive empty description returns error", + opts: &CreateOptions{ + BaseRepo: func() (ghrepo.Interface, error) { + return ghrepo.New("OWNER", "REPO"), nil + }, + ProblemStatement: "", + }, + wantErr: "SilentError", + wantErrIs: cmdutil.SilentError, + wantStdErr: "a task description is required.\n", + }, + { + name: "problem statement loaded from arg non-interactively doesn't prompt or return error", + opts: &CreateOptions{ + BaseRepo: func() (ghrepo.Interface, error) { return ghrepo.New("OWNER", "REPO"), nil }, + ProblemStatement: "task description", + }, + capiStubs: func(t *testing.T, m *capi.CapiClientMock) { + m.CreateJobFunc = func(ctx context.Context, owner, repo, problemStatement, baseBranch string) (*capi.Job, error) { + require.Equal(t, "OWNER", owner) + require.Equal(t, "REPO", repo) + require.Equal(t, "task description", problemStatement) + return &createdJobSuccessWithPR, nil + } + }, + wantStdout: "https://github.com/OWNER/REPO/pull/42/agent-sessions/sess1\n", + }, + { + name: "base branch included in create payload", + opts: &CreateOptions{ + BaseRepo: func() (ghrepo.Interface, error) { return ghrepo.New("OWNER", "REPO"), nil }, + ProblemStatement: "Do the thing", + BaseBranch: "feature", + }, capiStubs: func(t *testing.T, m *capi.CapiClientMock) { m.CreateJobFunc = func(ctx context.Context, owner, repo, problemStatement, baseBranch string) (*capi.Job, error) { require.Equal(t, "OWNER", owner) @@ -233,24 +316,32 @@ func Test_createRun(t *testing.T) { wantStdout: "https://github.com/OWNER/REPO/pull/42/agent-sessions/sess1\n", }, { - name: "create task API failure returns error", - baseRepoFunc: func() (ghrepo.Interface, error) { return ghrepo.New("OWNER", "REPO"), nil }, - problemStatement: "Do the thing", + name: "create task API failure returns error", + opts: &CreateOptions{ + BaseRepo: func() (ghrepo.Interface, error) { + return ghrepo.New("OWNER", "REPO"), nil + }, + ProblemStatement: "Do the thing", + }, capiStubs: func(t *testing.T, m *capi.CapiClientMock) { m.CreateJobFunc = func(ctx context.Context, owner, repo, problemStatement, baseBranch string) (*capi.Job, error) { require.Equal(t, "OWNER", owner) require.Equal(t, "REPO", repo) require.Equal(t, "Do the thing", problemStatement) require.Equal(t, "", baseBranch) - return nil, errors.New("some error") + return nil, errors.New("some API error") } }, - wantErr: "some error", + wantErr: "some API error", }, { - name: "get job API failure surfaces error", - baseRepoFunc: func() (ghrepo.Interface, error) { return ghrepo.New("OWNER", "REPO"), nil }, - problemStatement: "Do the thing", + name: "get job API failure surfaces error", + opts: &CreateOptions{ + BaseRepo: func() (ghrepo.Interface, error) { + return ghrepo.New("OWNER", "REPO"), nil + }, + ProblemStatement: "Do the thing", + }, capiStubs: func(t *testing.T, m *capi.CapiClientMock) { m.CreateJobFunc = func(ctx context.Context, owner, repo, problemStatement, baseBranch string) (*capi.Job, error) { require.Equal(t, "OWNER", owner) @@ -267,9 +358,13 @@ func Test_createRun(t *testing.T) { wantStdout: "job job123 queued. View progress: https://github.com/copilot/agents\n", }, { - name: "success with immediate PR", - baseRepoFunc: func() (ghrepo.Interface, error) { return ghrepo.New("OWNER", "REPO"), nil }, - problemStatement: "Do the thing", + name: "success with immediate PR", + opts: &CreateOptions{ + BaseRepo: func() (ghrepo.Interface, error) { + return ghrepo.New("OWNER", "REPO"), nil + }, + ProblemStatement: "Do the thing", + }, capiStubs: func(t *testing.T, m *capi.CapiClientMock) { m.CreateJobFunc = func(ctx context.Context, owner, repo, problemStatement, baseBranch string) (*capi.Job, error) { require.Equal(t, "OWNER", owner) @@ -282,9 +377,13 @@ func Test_createRun(t *testing.T) { wantStdout: "https://github.com/OWNER/REPO/pull/42/agent-sessions/sess1\n", }, { - name: "success with delayed PR after polling", - baseRepoFunc: func() (ghrepo.Interface, error) { return ghrepo.New("OWNER", "REPO"), nil }, - problemStatement: "Do the thing", + name: "success with delayed PR after polling", + opts: &CreateOptions{ + BaseRepo: func() (ghrepo.Interface, error) { + return ghrepo.New("OWNER", "REPO"), nil + }, + ProblemStatement: "Do the thing", + }, capiStubs: func(t *testing.T, m *capi.CapiClientMock) { m.CreateJobFunc = func(ctx context.Context, owner, repo, problemStatement, baseBranch string) (*capi.Job, error) { require.Equal(t, "OWNER", owner) @@ -303,9 +402,13 @@ func Test_createRun(t *testing.T) { wantStdout: "https://github.com/OWNER/REPO/pull/42/agent-sessions/sess1\n", }, { - name: "fallback after timeout returns link to global agents page", - baseRepoFunc: func() (ghrepo.Interface, error) { return ghrepo.New("OWNER", "REPO"), nil }, - problemStatement: "Do the thing", + name: "fallback after polling timeout returns link to global agents page", + opts: &CreateOptions{ + BaseRepo: func() (ghrepo.Interface, error) { + return ghrepo.New("OWNER", "REPO"), nil + }, + ProblemStatement: "Do the thing", + }, capiStubs: func(t *testing.T, m *capi.CapiClientMock) { m.CreateJobFunc = func(ctx context.Context, owner, repo, problemStatement, baseBranch string) (*capi.Job, error) { require.Equal(t, "OWNER", owner) @@ -341,27 +444,22 @@ func Test_createRun(t *testing.T) { ios.SetStdoutTTY(true) } - opts := &CreateOptions{ - IO: ios, - ProblemStatement: tt.problemStatement, - BaseRepo: tt.baseRepoFunc, - BaseBranch: tt.baseBranch, - Prompter: tt.prompterMock, - CapiClient: func() (capi.CapiClient, error) { - return capiClientMock, nil - }, + tt.opts.IO = ios + tt.opts.CapiClient = func() (capi.CapiClient, error) { + return capiClientMock, nil } - // A backoff with no interval between retries to keep tests fast, - // and also a max number of retries so we don't infinitely poll. - opts.BackOff = backoff.WithMaxRetries(&backoff.ZeroBackOff{}, 3) - - err := createRun(opts) + // fast backoff + tt.opts.BackOff = backoff.WithMaxRetries(&backoff.ZeroBackOff{}, 3) + err := createRun(tt.opts) + if tt.wantErrIs != nil { + require.ErrorIs(t, err, tt.wantErrIs) + } if tt.wantErr != "" { require.Error(t, err) require.Equal(t, tt.wantErr, err.Error()) - } else { + } else if tt.wantErrIs == nil { require.NoError(t, err) }