Merge pull request #12068 from cli/kw/spike-custom-agents

`gh agent-task create`: support `--custom-agent`/`-a` flag
This commit is contained in:
Kynan Ware 2025-10-31 18:19:44 -06:00 committed by GitHub
commit cd9e5e534f
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 115 additions and 37 deletions

View file

@ -16,7 +16,7 @@ const capiHost = "api.githubcopilot.com"
// may be replaced with test doubles in unit tests.
type CapiClient interface {
ListLatestSessionsForViewer(ctx context.Context, limit int) ([]*Session, error)
CreateJob(ctx context.Context, owner, repo, problemStatement, baseBranch string) (*Job, error)
CreateJob(ctx context.Context, owner, repo, problemStatement, baseBranch string, customAgent string) (*Job, error)
GetJob(ctx context.Context, owner, repo, jobID string) (*Job, error)
GetSession(ctx context.Context, id string) (*Session, error)
GetSessionLogs(ctx context.Context, id string) ([]byte, error)

View file

@ -18,7 +18,7 @@ var _ CapiClient = &CapiClientMock{}
//
// // make and configure a mocked CapiClient
// mockedCapiClient := &CapiClientMock{
// CreateJobFunc: func(ctx context.Context, owner string, repo string, problemStatement string, baseBranch string) (*Job, error) {
// CreateJobFunc: func(ctx context.Context, owner string, repo string, problemStatement string, baseBranch string, customAgent string) (*Job, error) {
// panic("mock out the CreateJob method")
// },
// GetJobFunc: func(ctx context.Context, owner string, repo string, jobID string) (*Job, error) {
@ -47,7 +47,7 @@ var _ CapiClient = &CapiClientMock{}
// }
type CapiClientMock struct {
// CreateJobFunc mocks the CreateJob method.
CreateJobFunc func(ctx context.Context, owner string, repo string, problemStatement string, baseBranch string) (*Job, error)
CreateJobFunc func(ctx context.Context, owner string, repo string, problemStatement string, baseBranch string, customAgent string) (*Job, error)
// GetJobFunc mocks the GetJob method.
GetJobFunc func(ctx context.Context, owner string, repo string, jobID string) (*Job, error)
@ -81,6 +81,8 @@ type CapiClientMock struct {
ProblemStatement string
// BaseBranch is the baseBranch argument value.
BaseBranch string
// CustomAgent is the customAgent argument value.
CustomAgent string
}
// GetJob holds details about calls to the GetJob method.
GetJob []struct {
@ -149,7 +151,7 @@ type CapiClientMock struct {
}
// CreateJob calls CreateJobFunc.
func (mock *CapiClientMock) CreateJob(ctx context.Context, owner string, repo string, problemStatement string, baseBranch string) (*Job, error) {
func (mock *CapiClientMock) CreateJob(ctx context.Context, owner string, repo string, problemStatement string, baseBranch string, customAgent string) (*Job, error) {
if mock.CreateJobFunc == nil {
panic("CapiClientMock.CreateJobFunc: method is nil but CapiClient.CreateJob was just called")
}
@ -159,17 +161,19 @@ func (mock *CapiClientMock) CreateJob(ctx context.Context, owner string, repo st
Repo string
ProblemStatement string
BaseBranch string
CustomAgent string
}{
Ctx: ctx,
Owner: owner,
Repo: repo,
ProblemStatement: problemStatement,
BaseBranch: baseBranch,
CustomAgent: customAgent,
}
mock.lockCreateJob.Lock()
mock.calls.CreateJob = append(mock.calls.CreateJob, callInfo)
mock.lockCreateJob.Unlock()
return mock.CreateJobFunc(ctx, owner, repo, problemStatement, baseBranch)
return mock.CreateJobFunc(ctx, owner, repo, problemStatement, baseBranch, customAgent)
}
// CreateJobCalls gets all the calls that were made to CreateJob.
@ -182,6 +186,7 @@ func (mock *CapiClientMock) CreateJobCalls() []struct {
Repo string
ProblemStatement string
BaseBranch string
CustomAgent string
} {
var calls []struct {
Ctx context.Context
@ -189,6 +194,7 @@ func (mock *CapiClientMock) CreateJobCalls() []struct {
Repo string
ProblemStatement string
BaseBranch string
CustomAgent string
}
mock.lockCreateJob.RLock()
calls = mock.calls.CreateJob

View file

@ -6,6 +6,7 @@ import (
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"net/url"
"time"
@ -18,6 +19,7 @@ type Job struct {
ID string `json:"job_id,omitempty"`
SessionID string `json:"session_id,omitempty"`
ProblemStatement string `json:"problem_statement,omitempty"`
CustomAgent string `json:"custom_agent,omitempty"`
EventType string `json:"event_type,omitempty"`
ContentFilterMode string `json:"content_filter_mode,omitempty"`
Status string `json:"status,omitempty"`
@ -54,7 +56,7 @@ const jobsBasePathV1 = baseCAPIURL + "/agents/swe/v1/jobs"
// CreateJob queues a new job using the v1 Jobs API. It may or may not
// return Pull Request information. If Pull Request information is required,
// follow up by polling GetJob with the job ID.
func (c *CAPIClient) CreateJob(ctx context.Context, owner, repo, problemStatement, baseBranch string) (*Job, error) {
func (c *CAPIClient) CreateJob(ctx context.Context, owner, repo, problemStatement, baseBranch, customAgent string) (*Job, error) {
if owner == "" || repo == "" {
return nil, errors.New("owner and repo are required")
}
@ -71,6 +73,7 @@ func (c *CAPIClient) CreateJob(ctx context.Context, owner, repo, problemStatemen
payload := &Job{
ProblemStatement: problemStatement,
CustomAgent: customAgent,
EventType: defaultEventType,
PullRequest: &prOpts,
}
@ -88,8 +91,10 @@ func (c *CAPIClient) CreateJob(ctx context.Context, owner, repo, problemStatemen
}
defer res.Body.Close()
body, _ := io.ReadAll(res.Body)
var j Job
if err := json.NewDecoder(res.Body).Decode(&j); err != nil {
if err := json.NewDecoder(bytes.NewReader(body)).Decode(&j); err != nil {
if res.StatusCode != http.StatusCreated && res.StatusCode != http.StatusOK { // accept 201 or 200
// This happens when there's an error like unauthorized (401).
statusText := fmt.Sprintf("%d %s", res.StatusCode, http.StatusText(res.StatusCode))
@ -99,11 +104,22 @@ func (c *CAPIClient) CreateJob(ctx context.Context, owner, repo, problemStatemen
}
if res.StatusCode != http.StatusCreated && res.StatusCode != http.StatusOK { // accept 201 or 200
if j.ErrorInfo != nil {
return nil, fmt.Errorf("failed to create job: %s", j.ErrorInfo.Message)
}
statusText := fmt.Sprintf("%d %s", res.StatusCode, http.StatusText(res.StatusCode))
return nil, fmt.Errorf("failed to create job: %s", statusText)
// If the response has an error embedded, we can use that.
// TODO: Does this really ever happen?
if j.ErrorInfo != nil {
return nil, fmt.Errorf("failed to create job: %s: %s", statusText, j.ErrorInfo.Message)
}
// If the response doesn't have an error embedded,
// try to decode the response body itself as a JobError.
var errInfo JobError
if err := json.NewDecoder(bytes.NewReader(body)).Decode(&errInfo); err != nil {
return nil, fmt.Errorf("failed to create job: %s", statusText)
}
return nil, fmt.Errorf("failed to create job: %s: %s", statusText, errInfo.Message)
}
return &j, nil

View file

@ -188,14 +188,14 @@ func TestGetJob(t *testing.T) {
func TestCreateJobRequiresRepoAndProblemStatement(t *testing.T) {
client := &CAPIClient{}
_, err := client.CreateJob(context.Background(), "", "only-repo", "", "")
_, err := client.CreateJob(context.Background(), "", "only-repo", "", "", "")
assert.EqualError(t, err, "owner and repo are required")
_, err = client.CreateJob(context.Background(), "only-owner", "", "", "")
_, err = client.CreateJob(context.Background(), "only-owner", "", "", "", "")
assert.EqualError(t, err, "owner and repo are required")
_, err = client.CreateJob(context.Background(), "", "", "", "")
_, err = client.CreateJob(context.Background(), "", "", "", "", "")
assert.EqualError(t, err, "owner and repo are required")
_, err = client.CreateJob(context.Background(), "owner", "repo", "", "")
_, err = client.CreateJob(context.Background(), "owner", "repo", "", "", "")
assert.EqualError(t, err, "problem statement is required")
}
@ -205,11 +205,12 @@ func TestCreateJob(t *testing.T) {
require.NoError(t, err)
tests := []struct {
name string
baseBranch string
httpStubs func(*testing.T, *httpmock.Registry)
wantErr string
wantOut *Job
name string
baseBranch string
customAgent string
httpStubs func(*testing.T, *httpmock.Registry)
wantErr string
wantOut *Job
}{
{
name: "success",
@ -305,6 +306,56 @@ func TestCreateJob(t *testing.T) {
UpdatedAt: sampleDate,
},
},
{
name: "Success with custom agent",
customAgent: "my-custom-agent",
httpStubs: func(t *testing.T, reg *httpmock.Registry) {
reg.Register(
httpmock.WithHost(httpmock.REST("POST", "agents/swe/v1/jobs/OWNER/REPO"), "api.githubcopilot.com"),
httpmock.RESTPayload(201,
heredoc.Docf(`
{
"job_id": "job123",
"session_id": "sess1",
"problem_statement": "Do the thing",
"custom_agent": "my-custom-agent",
"event_type": "foo",
"content_filter_mode": "foo",
"status": "foo",
"result": "foo",
"actor": {
"id": 1,
"login": "octocat"
},
"created_at": "%[1]s",
"updated_at": "%[1]s"
}
`, sampleDateString),
func(payload map[string]interface{}) {
assert.Equal(t, "Do the thing", payload["problem_statement"])
assert.Equal(t, "gh_cli", payload["event_type"])
assert.Equal(t, "my-custom-agent", payload["custom_agent"])
},
),
)
},
wantOut: &Job{
ID: "job123",
SessionID: "sess1",
ProblemStatement: "Do the thing",
CustomAgent: "my-custom-agent",
EventType: "foo",
ContentFilterMode: "foo",
Status: "foo",
Result: "foo",
Actor: &JobActor{
ID: 1,
Login: "octocat",
},
CreatedAt: sampleDate,
UpdatedAt: sampleDate,
},
},
{
name: "API error, included in response body",
httpStubs: func(t *testing.T, reg *httpmock.Registry) {
@ -317,7 +368,7 @@ func TestCreateJob(t *testing.T) {
}`)),
)
},
wantErr: "failed to create job: some error",
wantErr: "failed to create job: 500 Internal Server Error: some error",
},
{
name: "API error",
@ -327,7 +378,7 @@ func TestCreateJob(t *testing.T) {
httpmock.StatusStringResponse(500, `{}`),
)
},
wantErr: "failed to create job: 500 Internal Server Error",
wantErr: "failed to create job: 500 Internal Server Error: ",
},
{
name: "invalid JSON response, non-HTTP 200",
@ -364,7 +415,7 @@ func TestCreateJob(t *testing.T) {
cfg := config.NewBlankConfig()
capiClient := NewCAPIClient(httpClient, cfg.Authentication())
job, err := capiClient.CreateJob(context.Background(), "OWNER", "REPO", "Do the thing", tt.baseBranch)
job, err := capiClient.CreateJob(context.Background(), "OWNER", "REPO", "Do the thing", tt.baseBranch, tt.customAgent)
if tt.wantErr != "" {
require.EqualError(t, err, tt.wantErr)

View file

@ -34,6 +34,7 @@ type CreateOptions struct {
Sleep func(d time.Duration)
ProblemStatement string
CustomAgent string
BackOff backoff.BackOff
BaseBranch string
Prompter prompter.Prompter
@ -103,6 +104,9 @@ func NewCmdCreate(f *cmdutil.Factory, runF func(*CreateOptions) error) *cobra.Co
# Select a different base branch for the PR
$ gh agent-task create "fix errors" --base branch
# Create a task using the custom agent defined in '.github/agents/my-agent.md'
$ gh agent-task create "build me a new app" --custom-agent my-agent
`),
}
@ -111,6 +115,7 @@ func NewCmdCreate(f *cmdutil.Factory, runF func(*CreateOptions) error) *cobra.Co
cmd.Flags().StringVarP(&opts.ProblemStatementFile, "from-file", "F", "", "Read task description from `file` (use \"-\" to read from standard input)")
cmd.Flags().StringVarP(&opts.BaseBranch, "base", "b", "", "Base branch for the pull request (use default branch if not provided)")
cmd.Flags().BoolVar(&opts.Follow, "follow", false, "Follow agent session logs")
cmd.Flags().StringVarP(&opts.CustomAgent, "custom-agent", "a", "", "Use a custom agent for the task. e.g., use 'my-agent' for the 'my-agent.md' agent")
return cmd
}
@ -160,7 +165,7 @@ func createRun(opts *CreateOptions) error {
opts.IO.StartProgressIndicatorWithLabel(fmt.Sprintf("Creating agent task in %s/%s...", repo.RepoOwner(), repo.RepoName()))
defer opts.IO.StopProgressIndicator()
job, err := client.CreateJob(ctx, repo.RepoOwner(), repo.RepoName(), opts.ProblemStatement, opts.BaseBranch)
job, err := client.CreateJob(ctx, repo.RepoOwner(), repo.RepoName(), opts.ProblemStatement, opts.BaseBranch, opts.CustomAgent)
if err != nil {
return err
}

View file

@ -180,7 +180,7 @@ func Test_createRun(t *testing.T) {
ProblemStatement: "task description from arg",
},
capiStubs: func(t *testing.T, m *capi.CapiClientMock) {
m.CreateJobFunc = func(ctx context.Context, owner, repo, problemStatement, baseBranch string) (*capi.Job, error) {
m.CreateJobFunc = func(ctx context.Context, owner, repo, problemStatement, baseBranch, customAgent string) (*capi.Job, error) {
require.Equal(t, "OWNER", owner)
require.Equal(t, "REPO", repo)
require.Equal(t, "task description from arg", problemStatement)
@ -196,7 +196,7 @@ func Test_createRun(t *testing.T) {
ProblemStatement: "task description from arg",
},
capiStubs: func(t *testing.T, m *capi.CapiClientMock) {
m.CreateJobFunc = func(ctx context.Context, owner, repo, problemStatement, baseBranch string) (*capi.Job, error) {
m.CreateJobFunc = func(ctx context.Context, owner, repo, problemStatement, baseBranch, customAgent string) (*capi.Job, error) {
require.Equal(t, "OWNER", owner)
require.Equal(t, "REPO", repo)
require.Equal(t, "task description from arg", problemStatement)
@ -214,7 +214,7 @@ func Test_createRun(t *testing.T) {
ProblemStatementFile: taskDescFile,
},
capiStubs: func(t *testing.T, m *capi.CapiClientMock) {
m.CreateJobFunc = func(ctx context.Context, owner, repo, problemStatement, baseBranch string) (*capi.Job, error) {
m.CreateJobFunc = func(ctx context.Context, owner, repo, problemStatement, baseBranch, customAgent string) (*capi.Job, error) {
require.Equal(t, "OWNER", owner)
require.Equal(t, "REPO", repo)
require.Equal(t, "task description from file", problemStatement)
@ -231,7 +231,7 @@ func Test_createRun(t *testing.T) {
ProblemStatementFile: taskDescFile,
},
capiStubs: func(t *testing.T, m *capi.CapiClientMock) {
m.CreateJobFunc = func(ctx context.Context, owner, repo, problemStatement, baseBranch string) (*capi.Job, error) {
m.CreateJobFunc = func(ctx context.Context, owner, repo, problemStatement, baseBranch, customAgent string) (*capi.Job, error) {
require.Equal(t, "OWNER", owner)
require.Equal(t, "REPO", repo)
require.Equal(t, "task description from file", problemStatement)
@ -255,7 +255,7 @@ func Test_createRun(t *testing.T) {
},
},
capiStubs: func(t *testing.T, m *capi.CapiClientMock) {
m.CreateJobFunc = func(ctx context.Context, owner, repo, problemStatement, baseBranch string) (*capi.Job, error) {
m.CreateJobFunc = func(ctx context.Context, owner, repo, problemStatement, baseBranch, customAgent string) (*capi.Job, error) {
require.Equal(t, "From editor", problemStatement)
return &createdJobSuccessWithPR, nil
}
@ -292,7 +292,7 @@ func Test_createRun(t *testing.T) {
ProblemStatement: "task description",
},
capiStubs: func(t *testing.T, m *capi.CapiClientMock) {
m.CreateJobFunc = func(ctx context.Context, owner, repo, problemStatement, baseBranch string) (*capi.Job, error) {
m.CreateJobFunc = func(ctx context.Context, owner, repo, problemStatement, baseBranch, customAgent string) (*capi.Job, error) {
require.Equal(t, "OWNER", owner)
require.Equal(t, "REPO", repo)
require.Equal(t, "task description", problemStatement)
@ -309,7 +309,7 @@ func Test_createRun(t *testing.T) {
BaseBranch: "feature",
},
capiStubs: func(t *testing.T, m *capi.CapiClientMock) {
m.CreateJobFunc = func(ctx context.Context, owner, repo, problemStatement, baseBranch string) (*capi.Job, error) {
m.CreateJobFunc = func(ctx context.Context, owner, repo, problemStatement, baseBranch, customAgent string) (*capi.Job, error) {
require.Equal(t, "OWNER", owner)
require.Equal(t, "REPO", repo)
require.Equal(t, "Do the thing", problemStatement)
@ -334,7 +334,7 @@ func Test_createRun(t *testing.T) {
ProblemStatement: "Do the thing",
},
capiStubs: func(t *testing.T, m *capi.CapiClientMock) {
m.CreateJobFunc = func(ctx context.Context, owner, repo, problemStatement, baseBranch string) (*capi.Job, error) {
m.CreateJobFunc = func(ctx context.Context, owner, repo, problemStatement, baseBranch, customAgent string) (*capi.Job, error) {
require.Equal(t, "OWNER", owner)
require.Equal(t, "REPO", repo)
require.Equal(t, "Do the thing", problemStatement)
@ -353,7 +353,7 @@ func Test_createRun(t *testing.T) {
ProblemStatement: "Do the thing",
},
capiStubs: func(t *testing.T, m *capi.CapiClientMock) {
m.CreateJobFunc = func(ctx context.Context, owner, repo, problemStatement, baseBranch string) (*capi.Job, error) {
m.CreateJobFunc = func(ctx context.Context, owner, repo, problemStatement, baseBranch, customAgent string) (*capi.Job, error) {
require.Equal(t, "OWNER", owner)
require.Equal(t, "REPO", repo)
require.Equal(t, "Do the thing", problemStatement)
@ -376,7 +376,7 @@ func Test_createRun(t *testing.T) {
ProblemStatement: "Do the thing",
},
capiStubs: func(t *testing.T, m *capi.CapiClientMock) {
m.CreateJobFunc = func(ctx context.Context, owner, repo, problemStatement, baseBranch string) (*capi.Job, error) {
m.CreateJobFunc = func(ctx context.Context, owner, repo, problemStatement, baseBranch, customAgent string) (*capi.Job, error) {
require.Equal(t, "OWNER", owner)
require.Equal(t, "REPO", repo)
require.Equal(t, "Do the thing", problemStatement)
@ -395,7 +395,7 @@ func Test_createRun(t *testing.T) {
ProblemStatement: "Do the thing",
},
capiStubs: func(t *testing.T, m *capi.CapiClientMock) {
m.CreateJobFunc = func(ctx context.Context, owner, repo, problemStatement, baseBranch string) (*capi.Job, error) {
m.CreateJobFunc = func(ctx context.Context, owner, repo, problemStatement, baseBranch, customAgent string) (*capi.Job, error) {
require.Equal(t, "OWNER", owner)
require.Equal(t, "REPO", repo)
require.Equal(t, "Do the thing", problemStatement)
@ -420,7 +420,7 @@ func Test_createRun(t *testing.T) {
ProblemStatement: "Do the thing",
},
capiStubs: func(t *testing.T, m *capi.CapiClientMock) {
m.CreateJobFunc = func(ctx context.Context, owner, repo, problemStatement, baseBranch string) (*capi.Job, error) {
m.CreateJobFunc = func(ctx context.Context, owner, repo, problemStatement, baseBranch, customAgent string) (*capi.Job, error) {
require.Equal(t, "OWNER", owner)
require.Equal(t, "REPO", repo)
require.Equal(t, "Do the thing", problemStatement)
@ -449,7 +449,7 @@ func Test_createRun(t *testing.T) {
Sleep: func(d time.Duration) {},
},
capiStubs: func(t *testing.T, m *capi.CapiClientMock) {
m.CreateJobFunc = func(ctx context.Context, owner, repo, problemStatement, baseBranch string) (*capi.Job, error) {
m.CreateJobFunc = func(ctx context.Context, owner, repo, problemStatement, baseBranch, customAgent string) (*capi.Job, error) {
require.Equal(t, "OWNER", owner)
require.Equal(t, "REPO", repo)
require.Equal(t, "Do the thing", problemStatement)