diff --git a/pkg/cmd/agent-task/capi/client.go b/pkg/cmd/agent-task/capi/client.go index 9021d6086..1e9cad3c8 100644 --- a/pkg/cmd/agent-task/capi/client.go +++ b/pkg/cmd/agent-task/capi/client.go @@ -15,7 +15,7 @@ const capiHost = "api.githubcopilot.com" type CapiClient interface { ListSessionsForViewer(ctx context.Context, limit int) ([]*Session, error) ListSessionsForRepo(ctx context.Context, owner string, repo string, limit int) ([]*Session, error) - CreateJob(ctx context.Context, owner, repo, problemStatement string) (*Job, error) + CreateJob(ctx context.Context, owner, repo, problemStatement, baseBranch string) (*Job, error) GetJob(ctx context.Context, owner, repo, jobID string) (*Job, error) } diff --git a/pkg/cmd/agent-task/capi/job.go b/pkg/cmd/agent-task/capi/job.go index 03eaa376d..26bd3cf51 100644 --- a/pkg/cmd/agent-task/capi/job.go +++ b/pkg/cmd/agent-task/capi/job.go @@ -38,8 +38,9 @@ type JobActor struct { } type JobPullRequest struct { - ID int `json:"id"` - Number int `json:"number"` + ID int `json:"id"` + Number int `json:"number"` + BaseRef string `json:"base_ref,omitempty"` } type JobError struct { @@ -53,7 +54,7 @@ const jobsBasePathV1 = baseCAPIURL + "/agents/swe/v1/jobs" // CreateJob queues a new job using the v1 Jobs API. It may or may not // return Pull Request information. If Pull Request information is required // following up by polling GetJob with the job ID is necessary. -func (c *CAPIClient) CreateJob(ctx context.Context, owner, repo, problemStatement string) (*Job, error) { +func (c *CAPIClient) CreateJob(ctx context.Context, owner, repo, problemStatement, baseBranch string) (*Job, error) { if owner == "" || repo == "" { return nil, errors.New("owner and repo are required") } @@ -63,10 +64,17 @@ func (c *CAPIClient) CreateJob(ctx context.Context, owner, repo, problemStatemen url := fmt.Sprintf("%s/%s/%s", jobsBasePathV1, url.PathEscape(owner), url.PathEscape(repo)) + prOpts := JobPullRequest{} + if baseBranch != "" { + prOpts.BaseRef = "refs/heads/" + baseBranch + } + payload := &Job{ ProblemStatement: problemStatement, EventType: defaultEventType, + PullRequest: &prOpts, } + b, _ := json.Marshal(payload) req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(b)) diff --git a/pkg/cmd/agent-task/create/create.go b/pkg/cmd/agent-task/create/create.go index 41f615c3f..89e09561c 100644 --- a/pkg/cmd/agent-task/create/create.go +++ b/pkg/cmd/agent-task/create/create.go @@ -5,10 +5,12 @@ import ( "errors" "fmt" "net/url" + "strings" "time" "github.com/cenkalti/backoff/v4" + "github.com/MakeNowJust/heredoc" "github.com/cli/cli/v2/internal/gh" "github.com/cli/cli/v2/internal/ghrepo" "github.com/cli/cli/v2/pkg/cmd/agent-task/capi" @@ -25,24 +27,42 @@ type CreateOptions struct { Config func() (gh.Config, error) ProblemStatement string BackOff backoff.BackOff + BaseBranch string } func NewCmdCreate(f *cmdutil.Factory, runF func(*CreateOptions) error) *cobra.Command { opts := &CreateOptions{ IO: f.IOStreams, } + + var fromFileName string + cmd := &cobra.Command{ - Use: "create \"\"", + Use: "create [] [flags]", Short: "Create an agent task (preview)", Args: cobra.MaximumNArgs(1), RunE: func(cmd *cobra.Command, args []string) error { - // TODO: We'll support prompting for the problem statement if not provided - // and from file flags, later. - if len(args) == 0 { - return cmdutil.FlagErrorf("a task description is required") + if err := cmdutil.MutuallyExclusive("only one of -F or arg can be provided", len(args) > 0, fromFileName != ""); err != nil { + return err } - opts.ProblemStatement = args[0] + // Populate ProblemStatement from either arg or file + if len(args) > 0 { + opts.ProblemStatement = args[0] + } else if fromFileName != "" { + fileContent, err := cmdutil.ReadFile(fromFileName, opts.IO.In) + if err != nil { + return cmdutil.FlagErrorf("could not read task description file: %v", err) + } + trimmed := strings.TrimSpace(string(fileContent)) + if trimmed == "" { + return cmdutil.FlagErrorf("task description file is empty") + } + opts.ProblemStatement = trimmed + } + if opts.ProblemStatement == "" { + return cmdutil.FlagErrorf("a task description is required") + } // Support -R/--repo override if f != nil { opts.BaseRepo = f.BaseRepo @@ -52,11 +72,21 @@ func NewCmdCreate(f *cmdutil.Factory, runF func(*CreateOptions) error) *cobra.Co } return createRun(opts) }, + Example: heredoc.Doc(` + # Create a task from an inline description + $ gh agent-task create "build me a new app" + + # Create a task from a file + $ gh agent-task create -F task-desc.md + `), } if f != nil { cmdutil.EnableRepoOverride(cmd, f) } + cmd.Flags().StringVarP(&fromFileName, "from-file", "F", "", "Read task description from file") + cmd.Flags().StringVarP(&opts.BaseBranch, "base", "b", "", "Base branch for the task") + opts.CapiClient = func() (capi.CapiClient, error) { cfg, err := f.Config() if err != nil { @@ -96,7 +126,7 @@ func createRun(opts *CreateOptions) error { opts.IO.StartProgressIndicatorWithLabel(fmt.Sprintf("Creating agent task in %s/%s...", repo.RepoOwner(), repo.RepoName())) defer opts.IO.StopProgressIndicator() - job, err := client.CreateJob(ctx, repo.RepoOwner(), repo.RepoName(), opts.ProblemStatement) + job, err := client.CreateJob(ctx, repo.RepoOwner(), repo.RepoName(), opts.ProblemStatement, opts.BaseBranch) if err != nil { return err } @@ -112,9 +142,9 @@ func createRun(opts *CreateOptions) error { // Ensure we have a backoff strategy. if opts.BackOff == nil { opts.BackOff = backoff.NewExponentialBackOff( - backoff.WithMaxElapsedTime(4*time.Second), + backoff.WithMaxElapsedTime(10*time.Second), backoff.WithInitialInterval(300*time.Millisecond), - backoff.WithMaxInterval(2*time.Second), + backoff.WithMaxInterval(10*time.Second), backoff.WithMultiplier(1.5), ) } diff --git a/pkg/cmd/agent-task/create/create_test.go b/pkg/cmd/agent-task/create/create_test.go index 977d32dfb..a7bb8a166 100644 --- a/pkg/cmd/agent-task/create/create_test.go +++ b/pkg/cmd/agent-task/create/create_test.go @@ -2,6 +2,9 @@ package create import ( "net/http" + "os" + "path/filepath" + "slices" "testing" "github.com/MakeNowJust/heredoc" @@ -17,13 +20,104 @@ import ( // Test basic option parsing & repository requirement func TestNewCmdCreate_Args(t *testing.T) { - f := &cmdutil.Factory{} - cmd := NewCmdCreate(f, func(o *CreateOptions) error { return nil }) - // no args should error via cobra MinimumNArgs before our runF - // TODO once we support more sources of problem statement input, - // this will change. - _, err := cmd.ExecuteC() - require.Error(t, err) + tests := []struct { + name string + args []string + fileContent string // if non-empty, create temp file and substitute {{FILE}} token in args + wantOpts *CreateOptions // nil when expecting error + expectedErr string + }{ + { + name: "no args nor file", + args: []string{}, + expectedErr: "a task description is required", + }, + { + name: "arg only success", + args: []string{"task description from args"}, + wantOpts: &CreateOptions{ + ProblemStatement: "task description from args", + }, + }, + { + name: "from-file success", + args: []string{"-F", "{{FILE}}"}, + fileContent: "task description from file", + wantOpts: &CreateOptions{ + ProblemStatement: "task description from file", + }, + }, + { + name: "file content from stdin success", + args: []string{"-F", "-"}, + fileContent: "task from stdin", + wantOpts: &CreateOptions{ProblemStatement: "task from stdin"}, + }, + { + name: "mutually exclusive arg and file", + args: []string{"Some task inline", "-F", "{{FILE}}"}, + fileContent: "Some task", + expectedErr: "only one of -F or arg can be provided", + }, + { + name: "missing file path", + args: []string{"-F", "does-not-exist.md"}, + expectedErr: "could not read task description file: open does-not-exist.md: no such file or directory", + }, + { + name: "empty file", + args: []string{"-F", "{{FILE}}"}, + fileContent: " \n\n", + expectedErr: "task description file is empty", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ios, stdinBuf, _, _ := iostreams.Test() + + // Provide file content either via stdin ( -F - ) or by creating a temp file + if tt.fileContent != "" { + isStdin := len(tt.args) == 2 && tt.args[0] == "-F" && tt.args[1] == "-" + hasFileToken := slices.Contains(tt.args, "{{FILE}}") + + switch { + case isStdin: + stdinBuf.WriteString(tt.fileContent) + case hasFileToken: + dir := t.TempDir() + path := filepath.Join(dir, "task.md") + if err := os.WriteFile(path, []byte(tt.fileContent), 0o600); err != nil { + t.Fatalf("failed to write temp file: %v", err) + } + for i, a := range tt.args { + if a == "{{FILE}}" { + tt.args[i] = path + } + } + } + } + + f := &cmdutil.Factory{IOStreams: ios} + var gotOpts *CreateOptions + cmd := NewCmdCreate(f, func(o *CreateOptions) error { + gotOpts = o + return nil + }) + cmd.SetArgs(tt.args) + _, err := cmd.ExecuteC() + + if tt.expectedErr != "" { + require.Error(t, err) + require.Equal(t, tt.expectedErr, err.Error()) + return + } + require.NoError(t, err) + if tt.wantOpts != nil { + require.Equal(t, tt.wantOpts.ProblemStatement, gotOpts.ProblemStatement) + } + }) + } } func Test_createRun(t *testing.T) { @@ -55,10 +149,35 @@ func Test_createRun(t *testing.T) { stubs func(*httpmock.Registry) baseRepoFunc func() (ghrepo.Interface, error) problemStatement string + baseBranch string wantStdout string wantStdErr string wantErr string }{ + { + name: "base branch included in create payload", + baseRepoFunc: func() (ghrepo.Interface, error) { return ghrepo.New("OWNER", "REPO"), nil }, + problemStatement: "Do the thing", + baseBranch: "feature", + stubs: func(reg *httpmock.Registry) { + reg.Register( + httpmock.WithHost(httpmock.REST("POST", "agents/swe/v1/jobs/OWNER/REPO"), "api.githubcopilot.com"), + httpmock.RESTPayload(201, createdJobSuccessWithPRResponse, func(payload map[string]interface{}) { + prRaw, ok := payload["pull_request"].(map[string]interface{}) + if !ok { + require.FailNow(t, "expected pull_request object in payload") + } + if prRaw["base_ref"] != "refs/heads/feature" { + require.FailNow(t, "expected pull_request.base_ref to be 'refs/heads/feature'") + } + if payload["problem_statement"] != "Do the thing" { + require.FailNow(t, "unexpected problem_statement value") + } + }), + ) + }, + wantStdout: "https://github.com/OWNER/REPO/pull/42/agent-sessions/sess1\n", + }, { name: "get job API failure surfaces error", baseRepoFunc: func() (ghrepo.Interface, error) { return ghrepo.New("OWNER", "REPO"), nil }, @@ -156,6 +275,7 @@ func Test_createRun(t *testing.T) { IO: ios, ProblemStatement: tt.problemStatement, BaseRepo: tt.baseRepoFunc, + BaseBranch: tt.baseBranch, } // A backoff with no internal between retries to keep tests fast,