From 07ec8c629d13e5f5c8a43d727fad4fdae2c987f6 Mon Sep 17 00:00:00 2001 From: "Babak K. Shandiz" Date: Thu, 4 Sep 2025 21:01:56 +0100 Subject: [PATCH] test(agent-task create): use `CapiClientMock` Signed-off-by: Babak K. Shandiz --- pkg/cmd/agent-task/create/create_test.go | 300 +++++++++++------------ 1 file changed, 150 insertions(+), 150 deletions(-) diff --git a/pkg/cmd/agent-task/create/create_test.go b/pkg/cmd/agent-task/create/create_test.go index 8cd82ac8c..3a244c0e7 100644 --- a/pkg/cmd/agent-task/create/create_test.go +++ b/pkg/cmd/agent-task/create/create_test.go @@ -1,20 +1,19 @@ package create import ( + "context" + "errors" "fmt" "io" - "net/http" "os" "path/filepath" "testing" + "time" - "github.com/MakeNowJust/heredoc" "github.com/cenkalti/backoff/v4" - "github.com/cli/cli/v2/internal/config" "github.com/cli/cli/v2/internal/ghrepo" "github.com/cli/cli/v2/pkg/cmd/agent-task/capi" "github.com/cli/cli/v2/pkg/cmdutil" - "github.com/cli/cli/v2/pkg/httpmock" "github.com/cli/cli/v2/pkg/iostreams" "github.com/google/shlex" "github.com/stretchr/testify/require" @@ -128,177 +127,182 @@ func TestNewCmdCreate(t *testing.T) { } func Test_createRun(t *testing.T) { - createdJobSuccessResponse := heredoc.Doc(`{ - "job_id":"job123", - "session_id":"sess1", - "actor":{"id":1,"login":"octocat"}, - "created_at":"2025-08-29T00:00:00Z", - "updated_at":"2025-08-29T00:00:00Z" - }`) - createdJobSuccessWithPRResponse := heredoc.Doc(`{ - "job_id":"job123", - "session_id":"sess1", - "actor":{"id":1,"login":"octocat"}, - "created_at":"2025-08-29T00:00:00Z", - "updated_at":"2025-08-29T00:00:00Z", - "pull_request":{"id":101,"number":42} - }`) - createdJobTimeoutResponse := heredoc.Doc(`{ - "job_id":"jobABC", - "session_id":"sess1", - "actor":{"id":1,"login":"octocat"}, - "created_at":"2025-08-29T00:00:00Z", - "updated_at":"2025-08-29T00:00:00Z" - }`) + sampleDateString := "2025-08-29T00:00:00Z" + sampleDate, err := time.Parse(time.RFC3339, sampleDateString) + require.NoError(t, err) + + createdJobSuccess := capi.Job{ + ID: "job123", + SessionID: "sess1", + Actor: &capi.JobActor{ + ID: 1, + Login: "octocat", + }, + CreatedAt: sampleDate, + UpdatedAt: sampleDate, + } + createdJobSuccessWithPR := capi.Job{ + ID: "job123", + SessionID: "sess1", + Actor: &capi.JobActor{ + ID: 1, + Login: "octocat", + }, + CreatedAt: sampleDate, + UpdatedAt: sampleDate, + PullRequest: &capi.JobPullRequest{ + ID: 101, + Number: 42, + }, + } tests := []struct { - name string - stubs func(*httpmock.Registry) - baseRepoFunc func() (ghrepo.Interface, error) - problemStatement string - baseBranch string - wantStdout string - wantStdErr string - wantErr string + name string + capiStubs func(*testing.T, *capi.CapiClientMock) + baseRepoFunc func() (ghrepo.Interface, error) + baseBranch string + wantStdout string + wantStdErr string + wantErr string }{ { - name: "base branch included in create payload", - baseRepoFunc: func() (ghrepo.Interface, error) { return ghrepo.New("OWNER", "REPO"), nil }, - problemStatement: "Do the thing", - baseBranch: "feature", - stubs: func(reg *httpmock.Registry) { - reg.Register( - httpmock.WithHost(httpmock.REST("POST", "agents/swe/v1/jobs/OWNER/REPO"), "api.githubcopilot.com"), - httpmock.RESTPayload(201, createdJobSuccessWithPRResponse, func(payload map[string]interface{}) { - prRaw, ok := payload["pull_request"].(map[string]interface{}) - if !ok { - require.FailNow(t, "expected pull_request object in payload") - } - if prRaw["base_ref"] != "refs/heads/feature" { - require.FailNow(t, "expected pull_request.base_ref to be 'refs/heads/feature'") - } - if payload["problem_statement"] != "Do the thing" { - require.FailNow(t, "unexpected problem_statement value") - } - }), - ) - }, - wantStdout: "https://github.com/OWNER/REPO/pull/42/agent-sessions/sess1\n", + name: "missing repo returns error", + baseRepoFunc: func() (ghrepo.Interface, error) { return nil, nil }, + wantErr: "a repository is required; re-run in a repository or supply one with --repo owner/name", }, { - name: "get job API failure surfaces error", - baseRepoFunc: func() (ghrepo.Interface, error) { return ghrepo.New("OWNER", "REPO"), nil }, - problemStatement: "Do the thing", - stubs: func(reg *httpmock.Registry) { - reg.Register( - httpmock.WithHost(httpmock.REST("POST", "agents/swe/v1/jobs/OWNER/REPO"), "api.githubcopilot.com"), - httpmock.StatusStringResponse(201, createdJobTimeoutResponse), - ) - reg.Register( - httpmock.WithHost(httpmock.REST("GET", "agents/swe/v1/jobs/OWNER/REPO/jobABC"), "api.githubcopilot.com"), - httpmock.StatusStringResponse(500, `{"error":{"message":"internal server error"}}`), - ) - }, - wantStdErr: "failed to get job: 500 Internal Server Error\n", - wantStdout: "job jobABC queued. View progress: https://github.com/copilot/agents\n", - }, - { - name: "success with immediate PR", - baseRepoFunc: func() (ghrepo.Interface, error) { return ghrepo.New("OWNER", "REPO"), nil }, - problemStatement: "Do the thing", - stubs: func(reg *httpmock.Registry) { - reg.Register( - httpmock.WithHost(httpmock.REST("POST", "agents/swe/v1/jobs/OWNER/REPO"), "api.githubcopilot.com"), - httpmock.StatusStringResponse(201, createdJobSuccessWithPRResponse), - ) - }, - wantStdout: "https://github.com/OWNER/REPO/pull/42/agent-sessions/sess1\n", - }, - { - name: "success with delayed PR after polling", - baseRepoFunc: func() (ghrepo.Interface, error) { return ghrepo.New("OWNER", "REPO"), nil }, - problemStatement: "Do the thing", - stubs: func(reg *httpmock.Registry) { - reg.Register( - httpmock.WithHost(httpmock.REST("POST", "agents/swe/v1/jobs/OWNER/REPO"), "api.githubcopilot.com"), - httpmock.StatusStringResponse(201, createdJobSuccessResponse), - ) - reg.Register( - httpmock.WithHost(httpmock.REST("GET", "agents/swe/v1/jobs/OWNER/REPO/job123"), "api.githubcopilot.com"), - httpmock.StringResponse(`{"job_id":"job123","pull_request":{"id":101,"number":42}}`), - ) - }, - wantStdout: "https://github.com/OWNER/REPO/pull/42\n", - }, - { - name: "fallback after timeout returns link to global agents page", - baseRepoFunc: func() (ghrepo.Interface, error) { return ghrepo.New("OWNER", "REPO"), nil }, - problemStatement: "Do the thing", - stubs: func(reg *httpmock.Registry) { - reg.Register( - httpmock.WithHost(httpmock.REST("POST", "agents/swe/v1/jobs/OWNER/REPO"), "api.githubcopilot.com"), - httpmock.StatusStringResponse(201, createdJobTimeoutResponse), - ) - // 4 attempts: initial + 3 retries - for range 4 { - reg.Register( - httpmock.WithHost(httpmock.REST("GET", "agents/swe/v1/jobs/OWNER/REPO/jobABC"), "api.githubcopilot.com"), - httpmock.StringResponse(`{"job_id":"jobABC"}`), - ) + name: "base branch included in create payload", + baseRepoFunc: func() (ghrepo.Interface, error) { return ghrepo.New("OWNER", "REPO"), nil }, + baseBranch: "feature", + capiStubs: func(t *testing.T, m *capi.CapiClientMock) { + m.CreateJobFunc = func(ctx context.Context, owner, repo, problemStatement, baseBranch string) (*capi.Job, error) { + require.Equal(t, "OWNER", owner) + require.Equal(t, "REPO", repo) + require.Equal(t, "Do the thing", problemStatement) + require.Equal(t, "feature", baseBranch) + return &createdJobSuccess, nil + } + m.GetJobFunc = func(ctx context.Context, owner, repo, jobID string) (*capi.Job, error) { + require.Equal(t, "OWNER", owner) + require.Equal(t, "REPO", repo) + require.Equal(t, "job123", jobID) + return &createdJobSuccessWithPR, nil } }, - wantStdout: "job jobABC queued. View progress: https://github.com/copilot/agents\n", + wantStdout: "https://github.com/OWNER/REPO/pull/42/agent-sessions/sess1\n", }, { - name: "missing repo returns error", - problemStatement: "task", - baseRepoFunc: func() (ghrepo.Interface, error) { return nil, nil }, - wantErr: "a repository is required; re-run in a repository or supply one with --repo owner/name", - }, - { - name: "create task API failure returns error", - baseRepoFunc: func() (ghrepo.Interface, error) { return ghrepo.New("OWNER", "REPO"), nil }, - problemStatement: "do the thing", - stubs: func(reg *httpmock.Registry) { - reg.Register( - httpmock.WithHost(httpmock.REST("POST", "agents/swe/v1/jobs/OWNER/REPO"), "api.githubcopilot.com"), - httpmock.StatusStringResponse(500, `{"error":{"message":"some API error"}}`), - ) + name: "create task API failure returns error", + baseRepoFunc: func() (ghrepo.Interface, error) { return ghrepo.New("OWNER", "REPO"), nil }, + capiStubs: func(t *testing.T, m *capi.CapiClientMock) { + m.CreateJobFunc = func(ctx context.Context, owner, repo, problemStatement, baseBranch string) (*capi.Job, error) { + require.Equal(t, "OWNER", owner) + require.Equal(t, "REPO", repo) + require.Equal(t, "Do the thing", problemStatement) + require.Equal(t, "", baseBranch) + return nil, errors.New("some error") + } }, - wantErr: "failed to create job: some API error", + wantErr: "some error", }, { - name: "missing task description returns error", - baseRepoFunc: func() (ghrepo.Interface, error) { return ghrepo.New("OWNER", "REPO"), nil }, - problemStatement: "", - wantErr: "a task description is required", + name: "get job API failure surfaces error", + baseRepoFunc: func() (ghrepo.Interface, error) { return ghrepo.New("OWNER", "REPO"), nil }, + capiStubs: func(t *testing.T, m *capi.CapiClientMock) { + m.CreateJobFunc = func(ctx context.Context, owner, repo, problemStatement, baseBranch string) (*capi.Job, error) { + require.Equal(t, "OWNER", owner) + require.Equal(t, "REPO", repo) + require.Equal(t, "Do the thing", problemStatement) + require.Equal(t, "", baseBranch) + return &createdJobSuccess, nil + } + m.GetJobFunc = func(ctx context.Context, owner, repo, jobID string) (*capi.Job, error) { + return nil, errors.New("some error") + } + }, + wantStdErr: "some error\n", + wantStdout: "job job123 queued. View progress: https://github.com/copilot/agents\n", + }, + { + name: "success with immediate PR", + baseRepoFunc: func() (ghrepo.Interface, error) { return ghrepo.New("OWNER", "REPO"), nil }, + capiStubs: func(t *testing.T, m *capi.CapiClientMock) { + m.CreateJobFunc = func(ctx context.Context, owner, repo, problemStatement, baseBranch string) (*capi.Job, error) { + require.Equal(t, "OWNER", owner) + require.Equal(t, "REPO", repo) + require.Equal(t, "Do the thing", problemStatement) + require.Equal(t, "", baseBranch) + return &createdJobSuccessWithPR, nil + } + }, + wantStdout: "https://github.com/OWNER/REPO/pull/42/agent-sessions/sess1\n", + }, + { + name: "success with delayed PR after polling", + baseRepoFunc: func() (ghrepo.Interface, error) { return ghrepo.New("OWNER", "REPO"), nil }, + capiStubs: func(t *testing.T, m *capi.CapiClientMock) { + m.CreateJobFunc = func(ctx context.Context, owner, repo, problemStatement, baseBranch string) (*capi.Job, error) { + require.Equal(t, "OWNER", owner) + require.Equal(t, "REPO", repo) + require.Equal(t, "Do the thing", problemStatement) + require.Equal(t, "", baseBranch) + return &createdJobSuccess, nil + } + m.GetJobFunc = func(ctx context.Context, owner, repo, jobID string) (*capi.Job, error) { + require.Equal(t, "OWNER", owner) + require.Equal(t, "REPO", repo) + require.Equal(t, "job123", jobID) + return &createdJobSuccessWithPR, nil + } + }, + wantStdout: "https://github.com/OWNER/REPO/pull/42/agent-sessions/sess1\n", + }, + { + name: "fallback after timeout returns link to global agents page", + baseRepoFunc: func() (ghrepo.Interface, error) { return ghrepo.New("OWNER", "REPO"), nil }, + capiStubs: func(t *testing.T, m *capi.CapiClientMock) { + m.CreateJobFunc = func(ctx context.Context, owner, repo, problemStatement, baseBranch string) (*capi.Job, error) { + require.Equal(t, "OWNER", owner) + require.Equal(t, "REPO", repo) + require.Equal(t, "Do the thing", problemStatement) + require.Equal(t, "", baseBranch) + return &createdJobSuccess, nil + } + + count := 0 + m.GetJobFunc = func(ctx context.Context, owner, repo, jobID string) (*capi.Job, error) { + if count++; count > 4 { + require.FailNow(t, "too many get calls") + } + return &createdJobSuccess, nil + } + }, + wantStdout: "job job123 queued. View progress: https://github.com/copilot/agents\n", }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { + capiClientMock := &capi.CapiClientMock{} + if tt.capiStubs != nil { + tt.capiStubs(t, capiClientMock) + } + ios, _, stdout, stderr := iostreams.Test() opts := &CreateOptions{ IO: ios, - ProblemStatement: tt.problemStatement, + ProblemStatement: "Do the thing", BaseRepo: tt.baseRepoFunc, BaseBranch: tt.baseBranch, + CapiClient: func() (capi.CapiClient, error) { + return capiClientMock, nil + }, } // A backoff with no internal between retries to keep tests fast, // and also a max number of retries so we don't infinitely poll. opts.BackOff = backoff.WithMaxRetries(&backoff.ZeroBackOff{}, 3) - reg := &httpmock.Registry{} - if tt.stubs != nil { - tt.stubs(reg) - cfg := config.NewBlankConfig() - cfg.Set("github.com", "oauth_token", "OTOKEN") - authCfg := cfg.Authentication() - client := capi.NewCAPIClient(&http.Client{Transport: reg}, authCfg) - opts.CapiClient = func() (capi.CapiClient, error) { return client, nil } - } - err := createRun(opts) if tt.wantErr != "" { @@ -310,10 +314,6 @@ func Test_createRun(t *testing.T) { require.Equal(t, tt.wantStdout, stdout.String()) require.Equal(t, tt.wantStdErr, stderr.String()) - - if tt.stubs != nil { - reg.Verify(t) - } }) } }