diff --git a/pkg/cmd/agent-task/capi/client.go b/pkg/cmd/agent-task/capi/client.go index c25549121..aee09bcd9 100644 --- a/pkg/cmd/agent-task/capi/client.go +++ b/pkg/cmd/agent-task/capi/client.go @@ -16,7 +16,7 @@ const capiHost = "api.githubcopilot.com" // may be replaced with test doubles in unit tests. type CapiClient interface { ListLatestSessionsForViewer(ctx context.Context, limit int) ([]*Session, error) - CreateJob(ctx context.Context, owner, repo, problemStatement, baseBranch string) (*Job, error) + CreateJob(ctx context.Context, owner, repo, problemStatement, baseBranch string, customAgent string) (*Job, error) GetJob(ctx context.Context, owner, repo, jobID string) (*Job, error) GetSession(ctx context.Context, id string) (*Session, error) GetSessionLogs(ctx context.Context, id string) ([]byte, error) diff --git a/pkg/cmd/agent-task/capi/client_mock.go b/pkg/cmd/agent-task/capi/client_mock.go index 325a8d513..c594a6e23 100644 --- a/pkg/cmd/agent-task/capi/client_mock.go +++ b/pkg/cmd/agent-task/capi/client_mock.go @@ -18,7 +18,7 @@ var _ CapiClient = &CapiClientMock{} // // // make and configure a mocked CapiClient // mockedCapiClient := &CapiClientMock{ -// CreateJobFunc: func(ctx context.Context, owner string, repo string, problemStatement string, baseBranch string) (*Job, error) { +// CreateJobFunc: func(ctx context.Context, owner string, repo string, problemStatement string, baseBranch string, customAgent string) (*Job, error) { // panic("mock out the CreateJob method") // }, // GetJobFunc: func(ctx context.Context, owner string, repo string, jobID string) (*Job, error) { @@ -47,7 +47,7 @@ var _ CapiClient = &CapiClientMock{} // } type CapiClientMock struct { // CreateJobFunc mocks the CreateJob method. - CreateJobFunc func(ctx context.Context, owner string, repo string, problemStatement string, baseBranch string) (*Job, error) + CreateJobFunc func(ctx context.Context, owner string, repo string, problemStatement string, baseBranch string, customAgent string) (*Job, error) // GetJobFunc mocks the GetJob method. GetJobFunc func(ctx context.Context, owner string, repo string, jobID string) (*Job, error) @@ -81,6 +81,8 @@ type CapiClientMock struct { ProblemStatement string // BaseBranch is the baseBranch argument value. BaseBranch string + // CustomAgent is the customAgent argument value. + CustomAgent string } // GetJob holds details about calls to the GetJob method. GetJob []struct { @@ -149,7 +151,7 @@ type CapiClientMock struct { } // CreateJob calls CreateJobFunc. -func (mock *CapiClientMock) CreateJob(ctx context.Context, owner string, repo string, problemStatement string, baseBranch string) (*Job, error) { +func (mock *CapiClientMock) CreateJob(ctx context.Context, owner string, repo string, problemStatement string, baseBranch string, customAgent string) (*Job, error) { if mock.CreateJobFunc == nil { panic("CapiClientMock.CreateJobFunc: method is nil but CapiClient.CreateJob was just called") } @@ -159,17 +161,19 @@ func (mock *CapiClientMock) CreateJob(ctx context.Context, owner string, repo st Repo string ProblemStatement string BaseBranch string + CustomAgent string }{ Ctx: ctx, Owner: owner, Repo: repo, ProblemStatement: problemStatement, BaseBranch: baseBranch, + CustomAgent: customAgent, } mock.lockCreateJob.Lock() mock.calls.CreateJob = append(mock.calls.CreateJob, callInfo) mock.lockCreateJob.Unlock() - return mock.CreateJobFunc(ctx, owner, repo, problemStatement, baseBranch) + return mock.CreateJobFunc(ctx, owner, repo, problemStatement, baseBranch, customAgent) } // CreateJobCalls gets all the calls that were made to CreateJob. @@ -182,6 +186,7 @@ func (mock *CapiClientMock) CreateJobCalls() []struct { Repo string ProblemStatement string BaseBranch string + CustomAgent string } { var calls []struct { Ctx context.Context @@ -189,6 +194,7 @@ func (mock *CapiClientMock) CreateJobCalls() []struct { Repo string ProblemStatement string BaseBranch string + CustomAgent string } mock.lockCreateJob.RLock() calls = mock.calls.CreateJob diff --git a/pkg/cmd/agent-task/capi/job.go b/pkg/cmd/agent-task/capi/job.go index 5a56323cb..2d5c2d264 100644 --- a/pkg/cmd/agent-task/capi/job.go +++ b/pkg/cmd/agent-task/capi/job.go @@ -6,6 +6,7 @@ import ( "encoding/json" "errors" "fmt" + "io" "net/http" "net/url" "time" @@ -18,6 +19,7 @@ type Job struct { ID string `json:"job_id,omitempty"` SessionID string `json:"session_id,omitempty"` ProblemStatement string `json:"problem_statement,omitempty"` + CustomAgent string `json:"custom_agent,omitempty"` EventType string `json:"event_type,omitempty"` ContentFilterMode string `json:"content_filter_mode,omitempty"` Status string `json:"status,omitempty"` @@ -54,7 +56,7 @@ const jobsBasePathV1 = baseCAPIURL + "/agents/swe/v1/jobs" // CreateJob queues a new job using the v1 Jobs API. It may or may not // return Pull Request information. If Pull Request information is required // following up by polling GetJob with the job ID is necessary. -func (c *CAPIClient) CreateJob(ctx context.Context, owner, repo, problemStatement, baseBranch string) (*Job, error) { +func (c *CAPIClient) CreateJob(ctx context.Context, owner, repo, problemStatement, baseBranch, customAgent string) (*Job, error) { if owner == "" || repo == "" { return nil, errors.New("owner and repo are required") } @@ -71,6 +73,7 @@ func (c *CAPIClient) CreateJob(ctx context.Context, owner, repo, problemStatemen payload := &Job{ ProblemStatement: problemStatement, + CustomAgent: customAgent, EventType: defaultEventType, PullRequest: &prOpts, } @@ -88,8 +91,10 @@ func (c *CAPIClient) CreateJob(ctx context.Context, owner, repo, problemStatemen } defer res.Body.Close() + body, _ := io.ReadAll(res.Body) + var j Job - if err := json.NewDecoder(res.Body).Decode(&j); err != nil { + if err := json.NewDecoder(bytes.NewReader(body)).Decode(&j); err != nil { if res.StatusCode != http.StatusCreated && res.StatusCode != http.StatusOK { // accept 201 or 200 // This happens when there's an error like unauthorized (401). statusText := fmt.Sprintf("%d %s", res.StatusCode, http.StatusText(res.StatusCode)) @@ -99,11 +104,22 @@ func (c *CAPIClient) CreateJob(ctx context.Context, owner, repo, problemStatemen } if res.StatusCode != http.StatusCreated && res.StatusCode != http.StatusOK { // accept 201 or 200 - if j.ErrorInfo != nil { - return nil, fmt.Errorf("failed to create job: %s", j.ErrorInfo.Message) - } statusText := fmt.Sprintf("%d %s", res.StatusCode, http.StatusText(res.StatusCode)) - return nil, fmt.Errorf("failed to create job: %s", statusText) + + // If the response has error embeded, we can use that. + // TODO: Does this really ever happen? + if j.ErrorInfo != nil { + return nil, fmt.Errorf("failed to create job: %s: %s", statusText, j.ErrorInfo.Message) + } + + // If the response doesn't have error embedded, + // try to decode the response itself as a jobError. + var errInfo JobError + if err := json.NewDecoder(bytes.NewReader(body)).Decode(&errInfo); err != nil { + return nil, fmt.Errorf("failed to create job: %s", statusText) + } + + return nil, fmt.Errorf("failed to create job: %s: %s", statusText, errInfo.Message) } return &j, nil diff --git a/pkg/cmd/agent-task/capi/job_test.go b/pkg/cmd/agent-task/capi/job_test.go index 573ce3039..4e14a28a6 100644 --- a/pkg/cmd/agent-task/capi/job_test.go +++ b/pkg/cmd/agent-task/capi/job_test.go @@ -188,14 +188,14 @@ func TestGetJob(t *testing.T) { func TestCreateJobRequiresRepoAndProblemStatement(t *testing.T) { client := &CAPIClient{} - _, err := client.CreateJob(context.Background(), "", "only-repo", "", "") + _, err := client.CreateJob(context.Background(), "", "only-repo", "", "", "") assert.EqualError(t, err, "owner and repo are required") - _, err = client.CreateJob(context.Background(), "only-owner", "", "", "") + _, err = client.CreateJob(context.Background(), "only-owner", "", "", "", "") assert.EqualError(t, err, "owner and repo are required") - _, err = client.CreateJob(context.Background(), "", "", "", "") + _, err = client.CreateJob(context.Background(), "", "", "", "", "") assert.EqualError(t, err, "owner and repo are required") - _, err = client.CreateJob(context.Background(), "owner", "repo", "", "") + _, err = client.CreateJob(context.Background(), "owner", "repo", "", "", "") assert.EqualError(t, err, "problem statement is required") } @@ -205,11 +205,12 @@ func TestCreateJob(t *testing.T) { require.NoError(t, err) tests := []struct { - name string - baseBranch string - httpStubs func(*testing.T, *httpmock.Registry) - wantErr string - wantOut *Job + name string + baseBranch string + customAgent string + httpStubs func(*testing.T, *httpmock.Registry) + wantErr string + wantOut *Job }{ { name: "success", @@ -305,6 +306,56 @@ func TestCreateJob(t *testing.T) { UpdatedAt: sampleDate, }, }, + { + name: "Success with custom agent", + customAgent: "my-custom-agent", + httpStubs: func(t *testing.T, reg *httpmock.Registry) { + reg.Register( + httpmock.WithHost(httpmock.REST("POST", "agents/swe/v1/jobs/OWNER/REPO"), "api.githubcopilot.com"), + httpmock.RESTPayload(201, + heredoc.Docf(` + { + "job_id": "job123", + "session_id": "sess1", + "problem_statement": "Do the thing", + "custom_agent": "my-custom-agent", + "event_type": "foo", + "content_filter_mode": "foo", + "status": "foo", + "result": "foo", + "actor": { + "id": 1, + "login": "octocat" + }, + "created_at": "%[1]s", + "updated_at": "%[1]s" + } + `, sampleDateString), + func(payload map[string]interface{}) { + assert.Equal(t, "Do the thing", payload["problem_statement"]) + assert.Equal(t, "gh_cli", payload["event_type"]) + assert.Equal(t, "my-custom-agent", payload["custom_agent"]) + }, + ), + ) + }, + wantOut: &Job{ + ID: "job123", + SessionID: "sess1", + ProblemStatement: "Do the thing", + CustomAgent: "my-custom-agent", + EventType: "foo", + ContentFilterMode: "foo", + Status: "foo", + Result: "foo", + Actor: &JobActor{ + ID: 1, + Login: "octocat", + }, + CreatedAt: sampleDate, + UpdatedAt: sampleDate, + }, + }, { name: "API error, included in response body", httpStubs: func(t *testing.T, reg *httpmock.Registry) { @@ -317,7 +368,7 @@ func TestCreateJob(t *testing.T) { }`)), ) }, - wantErr: "failed to create job: some error", + wantErr: "failed to create job: 500 Internal Server Error: some error", }, { name: "API error", @@ -327,7 +378,7 @@ func TestCreateJob(t *testing.T) { httpmock.StatusStringResponse(500, `{}`), ) }, - wantErr: "failed to create job: 500 Internal Server Error", + wantErr: "failed to create job: 500 Internal Server Error: ", }, { name: "invalid JSON response, non-HTTP 200", @@ -364,7 +415,7 @@ func TestCreateJob(t *testing.T) { cfg := config.NewBlankConfig() capiClient := NewCAPIClient(httpClient, cfg.Authentication()) - job, err := capiClient.CreateJob(context.Background(), "OWNER", "REPO", "Do the thing", tt.baseBranch) + job, err := capiClient.CreateJob(context.Background(), "OWNER", "REPO", "Do the thing", tt.baseBranch, tt.customAgent) if tt.wantErr != "" { require.EqualError(t, err, tt.wantErr) diff --git a/pkg/cmd/agent-task/create/create.go b/pkg/cmd/agent-task/create/create.go index cf5f7fd11..a9176e966 100644 --- a/pkg/cmd/agent-task/create/create.go +++ b/pkg/cmd/agent-task/create/create.go @@ -34,6 +34,7 @@ type CreateOptions struct { Sleep func(d time.Duration) ProblemStatement string + CustomAgent string BackOff backoff.BackOff BaseBranch string Prompter prompter.Prompter @@ -103,6 +104,9 @@ func NewCmdCreate(f *cmdutil.Factory, runF func(*CreateOptions) error) *cobra.Co # Select a different base branch for the PR $ gh agent-task create "fix errors" --base branch + + # Create a task using the custom agent defined in '.github/agents/my-agent.md' + $ gh agent-task create "build me a new app" --custom-agent my-agent `), } @@ -111,6 +115,7 @@ func NewCmdCreate(f *cmdutil.Factory, runF func(*CreateOptions) error) *cobra.Co cmd.Flags().StringVarP(&opts.ProblemStatementFile, "from-file", "F", "", "Read task description from `file` (use \"-\" to read from standard input)") cmd.Flags().StringVarP(&opts.BaseBranch, "base", "b", "", "Base branch for the pull request (use default branch if not provided)") cmd.Flags().BoolVar(&opts.Follow, "follow", false, "Follow agent session logs") + cmd.Flags().StringVarP(&opts.CustomAgent, "custom-agent", "a", "", "Use a custom agent for the task. e.g., use 'my-agent' for the 'my-agent.md' agent") return cmd } @@ -160,7 +165,7 @@ func createRun(opts *CreateOptions) error { opts.IO.StartProgressIndicatorWithLabel(fmt.Sprintf("Creating agent task in %s/%s...", repo.RepoOwner(), repo.RepoName())) defer opts.IO.StopProgressIndicator() - job, err := client.CreateJob(ctx, repo.RepoOwner(), repo.RepoName(), opts.ProblemStatement, opts.BaseBranch) + job, err := client.CreateJob(ctx, repo.RepoOwner(), repo.RepoName(), opts.ProblemStatement, opts.BaseBranch, opts.CustomAgent) if err != nil { return err } diff --git a/pkg/cmd/agent-task/create/create_test.go b/pkg/cmd/agent-task/create/create_test.go index aa02150ba..a041c347b 100644 --- a/pkg/cmd/agent-task/create/create_test.go +++ b/pkg/cmd/agent-task/create/create_test.go @@ -180,7 +180,7 @@ func Test_createRun(t *testing.T) { ProblemStatement: "task description from arg", }, capiStubs: func(t *testing.T, m *capi.CapiClientMock) { - m.CreateJobFunc = func(ctx context.Context, owner, repo, problemStatement, baseBranch string) (*capi.Job, error) { + m.CreateJobFunc = func(ctx context.Context, owner, repo, problemStatement, baseBranch, customAgent string) (*capi.Job, error) { require.Equal(t, "OWNER", owner) require.Equal(t, "REPO", repo) require.Equal(t, "task description from arg", problemStatement) @@ -196,7 +196,7 @@ func Test_createRun(t *testing.T) { ProblemStatement: "task description from arg", }, capiStubs: func(t *testing.T, m *capi.CapiClientMock) { - m.CreateJobFunc = func(ctx context.Context, owner, repo, problemStatement, baseBranch string) (*capi.Job, error) { + m.CreateJobFunc = func(ctx context.Context, owner, repo, problemStatement, baseBranch, customAgent string) (*capi.Job, error) { require.Equal(t, "OWNER", owner) require.Equal(t, "REPO", repo) require.Equal(t, "task description from arg", problemStatement) @@ -214,7 +214,7 @@ func Test_createRun(t *testing.T) { ProblemStatementFile: taskDescFile, }, capiStubs: func(t *testing.T, m *capi.CapiClientMock) { - m.CreateJobFunc = func(ctx context.Context, owner, repo, problemStatement, baseBranch string) (*capi.Job, error) { + m.CreateJobFunc = func(ctx context.Context, owner, repo, problemStatement, baseBranch, customAgent string) (*capi.Job, error) { require.Equal(t, "OWNER", owner) require.Equal(t, "REPO", repo) require.Equal(t, "task description from file", problemStatement) @@ -231,7 +231,7 @@ func Test_createRun(t *testing.T) { ProblemStatementFile: taskDescFile, }, capiStubs: func(t *testing.T, m *capi.CapiClientMock) { - m.CreateJobFunc = func(ctx context.Context, owner, repo, problemStatement, baseBranch string) (*capi.Job, error) { + m.CreateJobFunc = func(ctx context.Context, owner, repo, problemStatement, baseBranch, customAgent string) (*capi.Job, error) { require.Equal(t, "OWNER", owner) require.Equal(t, "REPO", repo) require.Equal(t, "task description from file", problemStatement) @@ -255,7 +255,7 @@ func Test_createRun(t *testing.T) { }, }, capiStubs: func(t *testing.T, m *capi.CapiClientMock) { - m.CreateJobFunc = func(ctx context.Context, owner, repo, problemStatement, baseBranch string) (*capi.Job, error) { + m.CreateJobFunc = func(ctx context.Context, owner, repo, problemStatement, baseBranch, customAgent string) (*capi.Job, error) { require.Equal(t, "From editor", problemStatement) return &createdJobSuccessWithPR, nil } @@ -292,7 +292,7 @@ func Test_createRun(t *testing.T) { ProblemStatement: "task description", }, capiStubs: func(t *testing.T, m *capi.CapiClientMock) { - m.CreateJobFunc = func(ctx context.Context, owner, repo, problemStatement, baseBranch string) (*capi.Job, error) { + m.CreateJobFunc = func(ctx context.Context, owner, repo, problemStatement, baseBranch, customAgent string) (*capi.Job, error) { require.Equal(t, "OWNER", owner) require.Equal(t, "REPO", repo) require.Equal(t, "task description", problemStatement) @@ -309,7 +309,7 @@ func Test_createRun(t *testing.T) { BaseBranch: "feature", }, capiStubs: func(t *testing.T, m *capi.CapiClientMock) { - m.CreateJobFunc = func(ctx context.Context, owner, repo, problemStatement, baseBranch string) (*capi.Job, error) { + m.CreateJobFunc = func(ctx context.Context, owner, repo, problemStatement, baseBranch, customAgent string) (*capi.Job, error) { require.Equal(t, "OWNER", owner) require.Equal(t, "REPO", repo) require.Equal(t, "Do the thing", problemStatement) @@ -334,7 +334,7 @@ func Test_createRun(t *testing.T) { ProblemStatement: "Do the thing", }, capiStubs: func(t *testing.T, m *capi.CapiClientMock) { - m.CreateJobFunc = func(ctx context.Context, owner, repo, problemStatement, baseBranch string) (*capi.Job, error) { + m.CreateJobFunc = func(ctx context.Context, owner, repo, problemStatement, baseBranch, customAgent string) (*capi.Job, error) { require.Equal(t, "OWNER", owner) require.Equal(t, "REPO", repo) require.Equal(t, "Do the thing", problemStatement) @@ -353,7 +353,7 @@ func Test_createRun(t *testing.T) { ProblemStatement: "Do the thing", }, capiStubs: func(t *testing.T, m *capi.CapiClientMock) { - m.CreateJobFunc = func(ctx context.Context, owner, repo, problemStatement, baseBranch string) (*capi.Job, error) { + m.CreateJobFunc = func(ctx context.Context, owner, repo, problemStatement, baseBranch, customAgent string) (*capi.Job, error) { require.Equal(t, "OWNER", owner) require.Equal(t, "REPO", repo) require.Equal(t, "Do the thing", problemStatement) @@ -376,7 +376,7 @@ func Test_createRun(t *testing.T) { ProblemStatement: "Do the thing", }, capiStubs: func(t *testing.T, m *capi.CapiClientMock) { - m.CreateJobFunc = func(ctx context.Context, owner, repo, problemStatement, baseBranch string) (*capi.Job, error) { + m.CreateJobFunc = func(ctx context.Context, owner, repo, problemStatement, baseBranch, customAgent string) (*capi.Job, error) { require.Equal(t, "OWNER", owner) require.Equal(t, "REPO", repo) require.Equal(t, "Do the thing", problemStatement) @@ -395,7 +395,7 @@ func Test_createRun(t *testing.T) { ProblemStatement: "Do the thing", }, capiStubs: func(t *testing.T, m *capi.CapiClientMock) { - m.CreateJobFunc = func(ctx context.Context, owner, repo, problemStatement, baseBranch string) (*capi.Job, error) { + m.CreateJobFunc = func(ctx context.Context, owner, repo, problemStatement, baseBranch, customAgent string) (*capi.Job, error) { require.Equal(t, "OWNER", owner) require.Equal(t, "REPO", repo) require.Equal(t, "Do the thing", problemStatement) @@ -420,7 +420,7 @@ func Test_createRun(t *testing.T) { ProblemStatement: "Do the thing", }, capiStubs: func(t *testing.T, m *capi.CapiClientMock) { - m.CreateJobFunc = func(ctx context.Context, owner, repo, problemStatement, baseBranch string) (*capi.Job, error) { + m.CreateJobFunc = func(ctx context.Context, owner, repo, problemStatement, baseBranch, customAgent string) (*capi.Job, error) { require.Equal(t, "OWNER", owner) require.Equal(t, "REPO", repo) require.Equal(t, "Do the thing", problemStatement) @@ -449,7 +449,7 @@ func Test_createRun(t *testing.T) { Sleep: func(d time.Duration) {}, }, capiStubs: func(t *testing.T, m *capi.CapiClientMock) { - m.CreateJobFunc = func(ctx context.Context, owner, repo, problemStatement, baseBranch string) (*capi.Job, error) { + m.CreateJobFunc = func(ctx context.Context, owner, repo, problemStatement, baseBranch, customAgent string) (*capi.Job, error) { require.Equal(t, "OWNER", owner) require.Equal(t, "REPO", repo) require.Equal(t, "Do the thing", problemStatement)