diff --git a/pkg/cmd/agent-task/capi/client.go b/pkg/cmd/agent-task/capi/client.go index 15765552b..3e6d92736 100644 --- a/pkg/cmd/agent-task/capi/client.go +++ b/pkg/cmd/agent-task/capi/client.go @@ -20,6 +20,8 @@ type CapiClient interface { CreateJob(ctx context.Context, owner, repo, problemStatement, baseBranch string) (*Job, error) GetJob(ctx context.Context, owner, repo, jobID string) (*Job, error) GetSession(ctx context.Context, id string) (*Session, error) + ListSessionsByResourceID(ctx context.Context, resourceType string, resourceID int64, limit int) ([]*Session, error) + GetPullRequestDatabaseID(ctx context.Context, hostname string, owner string, repo string, number int) (int64, error) } // CAPIClient is a client for interacting with the Copilot API diff --git a/pkg/cmd/agent-task/capi/client_mock.go b/pkg/cmd/agent-task/capi/client_mock.go index ba7c05ab0..7998f94d8 100644 --- a/pkg/cmd/agent-task/capi/client_mock.go +++ b/pkg/cmd/agent-task/capi/client_mock.go @@ -24,9 +24,15 @@ var _ CapiClient = &CapiClientMock{} // GetJobFunc: func(ctx context.Context, owner string, repo string, jobID string) (*Job, error) { // panic("mock out the GetJob method") // }, +// GetPullRequestDatabaseIDFunc: func(ctx context.Context, hostname string, owner string, repo string, number int) (int64, error) { +// panic("mock out the GetPullRequestDatabaseID method") +// }, // GetSessionFunc: func(ctx context.Context, id string) (*Session, error) { // panic("mock out the GetSession method") // }, +// ListSessionsByResourceIDFunc: func(ctx context.Context, resourceType string, resourceID int64, limit int) ([]*Session, error) { +// panic("mock out the ListSessionsByResourceID method") +// }, // ListSessionsForRepoFunc: func(ctx context.Context, owner string, repo string, limit int) ([]*Session, error) { // panic("mock out the ListSessionsForRepo method") // }, @@ -46,9 +52,15 @@ type CapiClientMock struct { // GetJobFunc mocks the GetJob method. GetJobFunc func(ctx context.Context, owner string, repo string, jobID string) (*Job, error) + // GetPullRequestDatabaseIDFunc mocks the GetPullRequestDatabaseID method. + GetPullRequestDatabaseIDFunc func(ctx context.Context, hostname string, owner string, repo string, number int) (int64, error) + // GetSessionFunc mocks the GetSession method. GetSessionFunc func(ctx context.Context, id string) (*Session, error) + // ListSessionsByResourceIDFunc mocks the ListSessionsByResourceID method. + ListSessionsByResourceIDFunc func(ctx context.Context, resourceType string, resourceID int64, limit int) ([]*Session, error) + // ListSessionsForRepoFunc mocks the ListSessionsForRepo method. ListSessionsForRepoFunc func(ctx context.Context, owner string, repo string, limit int) ([]*Session, error) @@ -81,6 +93,19 @@ type CapiClientMock struct { // JobID is the jobID argument value. JobID string } + // GetPullRequestDatabaseID holds details about calls to the GetPullRequestDatabaseID method. + GetPullRequestDatabaseID []struct { + // Ctx is the ctx argument value. + Ctx context.Context + // Hostname is the hostname argument value. + Hostname string + // Owner is the owner argument value. + Owner string + // Repo is the repo argument value. + Repo string + // Number is the number argument value. + Number int + } // GetSession holds details about calls to the GetSession method. GetSession []struct { // Ctx is the ctx argument value. @@ -88,6 +113,17 @@ type CapiClientMock struct { // ID is the id argument value. ID string } + // ListSessionsByResourceID holds details about calls to the ListSessionsByResourceID method. + ListSessionsByResourceID []struct { + // Ctx is the ctx argument value. + Ctx context.Context + // ResourceType is the resourceType argument value. + ResourceType string + // ResourceID is the resourceID argument value. + ResourceID int64 + // Limit is the limit argument value. + Limit int + } // ListSessionsForRepo holds details about calls to the ListSessionsForRepo method. ListSessionsForRepo []struct { // Ctx is the ctx argument value. @@ -107,11 +143,13 @@ type CapiClientMock struct { Limit int } } - lockCreateJob sync.RWMutex - lockGetJob sync.RWMutex - lockGetSession sync.RWMutex - lockListSessionsForRepo sync.RWMutex - lockListSessionsForViewer sync.RWMutex + lockCreateJob sync.RWMutex + lockGetJob sync.RWMutex + lockGetPullRequestDatabaseID sync.RWMutex + lockGetSession sync.RWMutex + lockListSessionsByResourceID sync.RWMutex + lockListSessionsForRepo sync.RWMutex + lockListSessionsForViewer sync.RWMutex } // CreateJob calls CreateJobFunc. @@ -206,6 +244,54 @@ func (mock *CapiClientMock) GetJobCalls() []struct { return calls } +// GetPullRequestDatabaseID calls GetPullRequestDatabaseIDFunc. +func (mock *CapiClientMock) GetPullRequestDatabaseID(ctx context.Context, hostname string, owner string, repo string, number int) (int64, error) { + if mock.GetPullRequestDatabaseIDFunc == nil { + panic("CapiClientMock.GetPullRequestDatabaseIDFunc: method is nil but CapiClient.GetPullRequestDatabaseID was just called") + } + callInfo := struct { + Ctx context.Context + Hostname string + Owner string + Repo string + Number int + }{ + Ctx: ctx, + Hostname: hostname, + Owner: owner, + Repo: repo, + Number: number, + } + mock.lockGetPullRequestDatabaseID.Lock() + mock.calls.GetPullRequestDatabaseID = append(mock.calls.GetPullRequestDatabaseID, callInfo) + mock.lockGetPullRequestDatabaseID.Unlock() + return mock.GetPullRequestDatabaseIDFunc(ctx, hostname, owner, repo, number) +} + +// GetPullRequestDatabaseIDCalls gets all the calls that were made to GetPullRequestDatabaseID. +// Check the length with: +// +// len(mockedCapiClient.GetPullRequestDatabaseIDCalls()) +func (mock *CapiClientMock) GetPullRequestDatabaseIDCalls() []struct { + Ctx context.Context + Hostname string + Owner string + Repo string + Number int +} { + var calls []struct { + Ctx context.Context + Hostname string + Owner string + Repo string + Number int + } + mock.lockGetPullRequestDatabaseID.RLock() + calls = mock.calls.GetPullRequestDatabaseID + mock.lockGetPullRequestDatabaseID.RUnlock() + return calls +} + // GetSession calls GetSessionFunc. func (mock *CapiClientMock) GetSession(ctx context.Context, id string) (*Session, error) { if mock.GetSessionFunc == nil { @@ -242,6 +328,50 @@ func (mock *CapiClientMock) GetSessionCalls() []struct { return calls } +// ListSessionsByResourceID calls ListSessionsByResourceIDFunc. +func (mock *CapiClientMock) ListSessionsByResourceID(ctx context.Context, resourceType string, resourceID int64, limit int) ([]*Session, error) { + if mock.ListSessionsByResourceIDFunc == nil { + panic("CapiClientMock.ListSessionsByResourceIDFunc: method is nil but CapiClient.ListSessionsByResourceID was just called") + } + callInfo := struct { + Ctx context.Context + ResourceType string + ResourceID int64 + Limit int + }{ + Ctx: ctx, + ResourceType: resourceType, + ResourceID: resourceID, + Limit: limit, + } + mock.lockListSessionsByResourceID.Lock() + mock.calls.ListSessionsByResourceID = append(mock.calls.ListSessionsByResourceID, callInfo) + mock.lockListSessionsByResourceID.Unlock() + return mock.ListSessionsByResourceIDFunc(ctx, resourceType, resourceID, limit) +} + +// ListSessionsByResourceIDCalls gets all the calls that were made to ListSessionsByResourceID. +// Check the length with: +// +// len(mockedCapiClient.ListSessionsByResourceIDCalls()) +func (mock *CapiClientMock) ListSessionsByResourceIDCalls() []struct { + Ctx context.Context + ResourceType string + ResourceID int64 + Limit int +} { + var calls []struct { + Ctx context.Context + ResourceType string + ResourceID int64 + Limit int + } + mock.lockListSessionsByResourceID.RLock() + calls = mock.calls.ListSessionsByResourceID + mock.lockListSessionsByResourceID.RUnlock() + return calls +} + // ListSessionsForRepo calls ListSessionsForRepoFunc. func (mock *CapiClientMock) ListSessionsForRepo(ctx context.Context, owner string, repo string, limit int) ([]*Session, error) { if mock.ListSessionsForRepoFunc == nil { diff --git a/pkg/cmd/agent-task/capi/sessions.go b/pkg/cmd/agent-task/capi/sessions.go index e0252a8bb..f6c5d7856 100644 --- a/pkg/cmd/agent-task/capi/sessions.go +++ b/pkg/cmd/agent-task/capi/sessions.go @@ -14,6 +14,7 @@ import ( "time" "github.com/cli/cli/v2/api" + "github.com/shurcooL/githubv4" "github.com/vmihailenco/msgpack/v5" ) @@ -131,7 +132,6 @@ func (c *CAPIClient) ListSessionsForViewer(ctx context.Context, limit int) ([]*S sessions = sessions[:limit] } - // Hydrate the result with pull request data. result, err := c.hydrateSessionPullRequestsAndUsers(sessions) if err != nil { return nil, fmt.Errorf("failed to fetch session resources: %w", err) @@ -192,7 +192,6 @@ func (c *CAPIClient) ListSessionsForRepo(ctx context.Context, owner string, repo sessions = sessions[:limit] } - // Hydrate the result with pull request data. result, err := c.hydrateSessionPullRequestsAndUsers(sessions) if err != nil { return nil, fmt.Errorf("failed to fetch session resources: %w", err) @@ -239,6 +238,65 @@ func (c *CAPIClient) GetSession(ctx context.Context, id string) (*Session, error return sessions[0], nil } +// ListSessionsByResourceID retrieves sessions associated with the given resource type and ID. +func (c *CAPIClient) ListSessionsByResourceID(ctx context.Context, resourceType string, resourceID int64, limit int) ([]*Session, error) { + if resourceType == "" || resourceID == 0 { + return nil, fmt.Errorf("missing resource type/ID") + } + + if limit == 0 { + return nil, nil + } + + url := fmt.Sprintf("%s/agents/sessions/resource/%s/%d", baseCAPIURL, url.PathEscape(resourceType), resourceID) + pageSize := defaultSessionsPerPage + + sessions := make([]session, 0, limit+pageSize) + + for page := 1; ; page++ { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, http.NoBody) + if err != nil { + return nil, err + } + + q := req.URL.Query() + q.Set("page_size", strconv.Itoa(pageSize)) + q.Set("page_number", strconv.Itoa(page)) + req.URL.RawQuery = q.Encode() + + res, err := c.httpClient.Do(req) + if err != nil { + return nil, err + } + defer res.Body.Close() + if res.StatusCode != http.StatusOK { + return nil, fmt.Errorf("failed to list sessions: %s", res.Status) + } + var response struct { + Sessions []session `json:"sessions"` + } + if err := json.NewDecoder(res.Body).Decode(&response); err != nil { + return nil, fmt.Errorf("failed to decode sessions response: %w", err) + } + + sessions = append(sessions, response.Sessions...) + if len(response.Sessions) < pageSize || len(sessions) >= limit { + break + } + } + + // Drop any above the limit + if len(sessions) > limit { + sessions = sessions[:limit] + } + + result, err := c.hydrateSessionPullRequestsAndUsers(sessions) + if err != nil { + return nil, fmt.Errorf("failed to fetch session resources: %w", err) + } + return result, nil +} + // hydrateSessionPullRequestsAndUsers hydrates pull request and user information in sessions func (c *CAPIClient) hydrateSessionPullRequestsAndUsers(sessions []session) ([]*Session, error) { if len(sessions) == 0 { @@ -248,9 +306,11 @@ func (c *CAPIClient) hydrateSessionPullRequestsAndUsers(sessions []session) ([]* prNodeIds := make([]string, 0, len(sessions)) userNodeIds := make([]string, 0, len(sessions)) for _, session := range sessions { - prNodeID := generatePullRequestNodeID(int64(session.RepoID), session.ResourceID) - if !slices.Contains(prNodeIds, prNodeID) { - prNodeIds = append(prNodeIds, prNodeID) + if session.ResourceType == "pull" { + prNodeID := generatePullRequestNodeID(int64(session.RepoID), session.ResourceID) + if !slices.Contains(prNodeIds, prNodeID) { + prNodeIds = append(prNodeIds, prNodeID) + } } userNodeId := generateUserNodeID(session.UserID) @@ -318,6 +378,34 @@ func (c *CAPIClient) hydrateSessionPullRequestsAndUsers(sessions []session) ([]* return newSessions, nil } +// GetPullRequestDatabaseID retrieves the database ID of a pull request given its number in a repository. +func (c *CAPIClient) GetPullRequestDatabaseID(ctx context.Context, hostname string, owner string, repo string, number int) (int64, error) { + var resp struct { + Repository struct { + PullRequest struct { + FullDatabaseID string `graphql:"fullDatabaseId"` + } `graphql:"pullRequest(number: $number)"` + } `graphql:"repository(owner: $owner, name: $repo)"` + } + + variables := map[string]interface{}{ + "owner": githubv4.String(owner), + "repo": githubv4.String(repo), + "number": githubv4.Int(number), + } + + apiClient := api.NewClientFromHTTP(c.httpClient) + if err := apiClient.Query(hostname, "GetPullRequestFullDatabaseID", &resp, variables); err != nil { + return 0, err + } + + databaseID, err := strconv.ParseInt(resp.Repository.PullRequest.FullDatabaseID, 10, 64) + if err != nil { + return 0, err + } + return databaseID, nil +} + // generatePullRequestNodeID converts an int64 databaseID and repoID to a GraphQL Node ID format // with the "PR_" prefix for pull requests func generatePullRequestNodeID(repoID, pullRequestID int64) string { diff --git a/pkg/cmd/agent-task/capi/sessions_test.go b/pkg/cmd/agent-task/capi/sessions_test.go index 115c7ab9e..1b750f56b 100644 --- a/pkg/cmd/agent-task/capi/sessions_test.go +++ b/pkg/cmd/agent-task/capi/sessions_test.go @@ -159,6 +159,84 @@ func TestListSessionsForViewer(t *testing.T) { }, }, }, + { + // This happens at the early moments of a session lifecycle, before a PR is created and associated with it. + name: "single session, no pull request resource", + limit: 10, + httpStubs: func(t *testing.T, reg *httpmock.Registry) { + reg.Register( + httpmock.WithHost( + httpmock.QueryMatcher("GET", "agents/sessions", url.Values{ + "page_number": {"1"}, + "page_size": {"50"}, + }), + "api.githubcopilot.com", + ), + httpmock.StringResponse(heredoc.Docf(` + { + "sessions": [ + { + "id": "sess1", + "name": "Build artifacts", + "user_id": 1, + "agent_id": 2, + "logs": "", + "state": "completed", + "owner_id": 10, + "repo_id": 1000, + "resource_type": "", + "resource_id": 0, + "created_at": "%[1]s" + } + ] + }`, + sampleDateString, + )), + ) + // GraphQL hydration + reg.Register( + httpmock.GraphQL(`query FetchPRsAndUsersForAgentTaskSessions\b`), + httpmock.GraphQLQuery(heredoc.Docf(` + { + "data": { + "nodes": [ + { + "__typename": "User", + "login": "octocat", + "name": "Octocat", + "databaseId": 1 + } + ] + } + }`, + sampleDateString, + ), func(q string, vars map[string]interface{}) { + assert.Equal(t, []interface{}{"U_kgAB"}, vars["ids"]) + }), + ) + }, + wantOut: []*Session{ + { + + ID: "sess1", + Name: "Build artifacts", + UserID: 1, + AgentID: 2, + Logs: "", + State: "completed", + OwnerID: 10, + RepoID: 1000, + ResourceType: "", + ResourceID: 0, + CreatedAt: sampleDate, + User: &api.GitHubUser{ + Login: "octocat", + Name: "Octocat", + DatabaseID: 1, + }, + }, + }, + }, { name: "multiple sessions, paginated", perPage: 1, // to enforce pagination @@ -594,6 +672,84 @@ func TestListSessionsForRepo(t *testing.T) { }, }, }, + { + // This happens at the early moments of a session lifecycle, before a PR is created and associated with it. + name: "single session, no pull request resource", + limit: 10, + httpStubs: func(t *testing.T, reg *httpmock.Registry) { + reg.Register( + httpmock.WithHost( + httpmock.QueryMatcher("GET", "agents/sessions/nwo/OWNER/REPO", url.Values{ + "page_number": {"1"}, + "page_size": {"50"}, + }), + "api.githubcopilot.com", + ), + httpmock.StringResponse(heredoc.Docf(` + { + "sessions": [ + { + "id": "sess1", + "name": "Build artifacts", + "user_id": 1, + "agent_id": 2, + "logs": "", + "state": "completed", + "owner_id": 10, + "repo_id": 1000, + "resource_type": "", + "resource_id": 0, + "created_at": "%[1]s" + } + ] + }`, + sampleDateString, + )), + ) + // GraphQL hydration + reg.Register( + httpmock.GraphQL(`query FetchPRsAndUsersForAgentTaskSessions\b`), + httpmock.GraphQLQuery(heredoc.Docf(` + { + "data": { + "nodes": [ + { + "__typename": "User", + "login": "octocat", + "name": "Octocat", + "databaseId": 1 + } + ] + } + }`, + sampleDateString, + ), func(q string, vars map[string]interface{}) { + assert.Equal(t, []interface{}{"U_kgAB"}, vars["ids"]) + }), + ) + }, + wantOut: []*Session{ + { + + ID: "sess1", + Name: "Build artifacts", + UserID: 1, + AgentID: 2, + Logs: "", + State: "completed", + OwnerID: 10, + RepoID: 1000, + ResourceType: "", + ResourceID: 0, + CreatedAt: sampleDate, + User: &api.GitHubUser{ + Login: "octocat", + Name: "Octocat", + DatabaseID: 1, + }, + }, + }, + }, { name: "multiple sessions, paginated", perPage: 1, // to enforce pagination @@ -876,6 +1032,445 @@ func TestListSessionsForRepo(t *testing.T) { } } +func TestListSessionsByResourceIDRequiresResource(t *testing.T) { + client := &CAPIClient{} + + _, err := client.ListSessionsByResourceID(context.Background(), "", 999, 0) + assert.EqualError(t, err, "missing resource type/ID") + _, err = client.ListSessionsByResourceID(context.Background(), "only-resource-type", 0, 0) + assert.EqualError(t, err, "missing resource type/ID") + _, err = client.ListSessionsByResourceID(context.Background(), "", 0, 0) + assert.EqualError(t, err, "missing resource type/ID") +} + +func TestListSessionsByResourceID(t *testing.T) { + sampleDateString := "2025-08-29T00:00:00Z" + sampleDate, err := time.Parse(time.RFC3339, sampleDateString) + require.NoError(t, err) + + resourceID := int64(999) + resourceType := "pull" + + tests := []struct { + name string + perPage int + limit int + httpStubs func(*testing.T, *httpmock.Registry) + wantErr string + wantOut []*Session + }{ + { + name: "zero limit", + limit: 0, + wantOut: nil, + }, + { + name: "no sessions", + limit: 10, + httpStubs: func(t *testing.T, reg *httpmock.Registry) { + reg.Register( + httpmock.WithHost( + httpmock.QueryMatcher("GET", "agents/sessions/resource/pull/999", url.Values{ + "page_number": {"1"}, + "page_size": {"50"}, + }), + "api.githubcopilot.com", + ), + httpmock.StringResponse(`{"sessions":[]}`), + ) + }, + wantOut: nil, + }, + { + name: "single session", + limit: 10, + httpStubs: func(t *testing.T, reg *httpmock.Registry) { + reg.Register( + httpmock.WithHost( + httpmock.QueryMatcher("GET", "agents/sessions/resource/pull/999", url.Values{ + "page_number": {"1"}, + "page_size": {"50"}, + }), + "api.githubcopilot.com", + ), + httpmock.StringResponse(heredoc.Docf(` + { + "sessions": [ + { + "id": "sess1", + "name": "Build artifacts", + "user_id": 1, + "agent_id": 2, + "logs": "", + "state": "completed", + "owner_id": 10, + "repo_id": 1000, + "resource_type": "pull", + "resource_id": 2000, + "created_at": "%[1]s" + } + ] + }`, + sampleDateString, + )), + ) + // GraphQL hydration + reg.Register( + httpmock.GraphQL(`query FetchPRsAndUsersForAgentTaskSessions\b`), + httpmock.GraphQLQuery(heredoc.Docf(` + { + "data": { + "nodes": [ + { + "__typename": "PullRequest", + "id": "PR_node", + "fullDatabaseId": "2000", + "number": 42, + "title": "Improve docs", + "state": "OPEN", + "isDraft": true, + "url": "https://github.com/OWNER/REPO/pull/42", + "body": "", + "createdAt": "%[1]s", + "updatedAt": "%[1]s", + "repository": { + "nameWithOwner": "OWNER/REPO" + } + }, + { + "__typename": "User", + "login": "octocat", + "name": "Octocat", + "databaseId": 1 + } + ] + } + }`, + sampleDateString, + ), func(q string, vars map[string]interface{}) { + assert.Equal(t, []interface{}{"PR_kwDNA-jNB9A", "U_kgAB"}, vars["ids"]) + }), + ) + }, + wantOut: []*Session{ + { + + ID: "sess1", + Name: "Build artifacts", + UserID: 1, + AgentID: 2, + Logs: "", + State: "completed", + OwnerID: 10, + RepoID: 1000, + ResourceType: "pull", + ResourceID: 2000, + CreatedAt: sampleDate, + PullRequest: &api.PullRequest{ + ID: "PR_node", + FullDatabaseID: "2000", + Number: 42, + Title: "Improve docs", + State: "OPEN", + IsDraft: true, + URL: "https://github.com/OWNER/REPO/pull/42", + Body: "", + CreatedAt: sampleDate, + UpdatedAt: sampleDate, + Repository: &api.PRRepository{ + NameWithOwner: "OWNER/REPO", + }, + }, + User: &api.GitHubUser{ + Login: "octocat", + Name: "Octocat", + DatabaseID: 1, + }, + }, + }, + }, + { + name: "multiple sessions, paginated", + perPage: 1, // to enforce pagination + limit: 2, + httpStubs: func(t *testing.T, reg *httpmock.Registry) { + reg.Register( + httpmock.WithHost( + httpmock.QueryMatcher("GET", "agents/sessions/resource/pull/999", url.Values{ + "page_number": {"1"}, + "page_size": {"1"}, + }), + "api.githubcopilot.com", + ), + httpmock.StringResponse(heredoc.Docf(` + { + "sessions": [ + { + "id": "sess1", + "name": "Build artifacts", + "user_id": 1, + "agent_id": 2, + "logs": "", + "state": "completed", + "owner_id": 10, + "repo_id": 1000, + "resource_type": "pull", + "resource_id": 2000, + "created_at": "%[1]s" + } + ] + }`, + sampleDateString, + )), + ) + + // Second page + reg.Register( + httpmock.WithHost( + httpmock.QueryMatcher("GET", "agents/sessions/resource/pull/999", url.Values{ + "page_number": {"2"}, + "page_size": {"1"}, + }), + "api.githubcopilot.com", + ), + httpmock.StringResponse(heredoc.Docf(` + { + "sessions": [ + { + "id": "sess2", + "name": "Build artifacts", + "user_id": 1, + "agent_id": 2, + "logs": "", + "state": "completed", + "owner_id": 10, + "repo_id": 1000, + "resource_type": "pull", + "resource_id": 2001, + "created_at": "%[1]s" + } + ] + }`, + sampleDateString, + )), + ) + // GraphQL hydration + reg.Register( + httpmock.GraphQL(`query FetchPRsAndUsersForAgentTaskSessions\b`), + httpmock.GraphQLQuery(heredoc.Docf(` + { + "data": { + "nodes": [ + { + "__typename": "PullRequest", + "id": "PR_node", + "fullDatabaseId": "2000", + "number": 42, + "title": "Improve docs", + "state": "OPEN", + "isDraft": true, + "url": "https://github.com/OWNER/REPO/pull/42", + "body": "", + "createdAt": "%[1]s", + "updatedAt": "%[1]s", + "repository": { + "nameWithOwner": "OWNER/REPO" + } + }, + { + "__typename": "PullRequest", + "id": "PR_node", + "fullDatabaseId": "2001", + "number": 43, + "title": "Improve docs", + "state": "OPEN", + "isDraft": true, + "url": "https://github.com/OWNER/REPO/pull/43", + "body": "", + "createdAt": "%[1]s", + "updatedAt": "%[1]s", + "repository": { + "nameWithOwner": "OWNER/REPO" + } + }, + { + "__typename": "User", + "login": "octocat", + "name": "Octocat", + "databaseId": 1 + } + ] + } + }`, + sampleDateString, + ), func(q string, vars map[string]interface{}) { + assert.Equal(t, []interface{}{"PR_kwDNA-jNB9A", "PR_kwDNA-jNB9E", "U_kgAB"}, vars["ids"]) + }), + ) + }, + wantOut: []*Session{ + { + ID: "sess1", + Name: "Build artifacts", + UserID: 1, + AgentID: 2, + Logs: "", + State: "completed", + OwnerID: 10, + RepoID: 1000, + ResourceType: "pull", + ResourceID: 2000, + CreatedAt: sampleDate, + PullRequest: &api.PullRequest{ + ID: "PR_node", + FullDatabaseID: "2000", + Number: 42, + Title: "Improve docs", + State: "OPEN", + IsDraft: true, + URL: "https://github.com/OWNER/REPO/pull/42", + Body: "", + CreatedAt: sampleDate, + UpdatedAt: sampleDate, + Repository: &api.PRRepository{ + NameWithOwner: "OWNER/REPO", + }, + }, + User: &api.GitHubUser{ + Login: "octocat", + Name: "Octocat", + DatabaseID: 1, + }, + }, + { + ID: "sess2", + Name: "Build artifacts", + UserID: 1, + AgentID: 2, + Logs: "", + State: "completed", + OwnerID: 10, + RepoID: 1000, + ResourceType: "pull", + ResourceID: 2001, + CreatedAt: sampleDate, + PullRequest: &api.PullRequest{ + ID: "PR_node", + FullDatabaseID: "2001", + Number: 43, + Title: "Improve docs", + State: "OPEN", + IsDraft: true, + URL: "https://github.com/OWNER/REPO/pull/43", + Body: "", + CreatedAt: sampleDate, + UpdatedAt: sampleDate, + Repository: &api.PRRepository{ + NameWithOwner: "OWNER/REPO", + }, + }, + User: &api.GitHubUser{ + Login: "octocat", + Name: "Octocat", + DatabaseID: 1, + }, + }, + }, + }, + { + name: "API error", + limit: 10, + httpStubs: func(t *testing.T, reg *httpmock.Registry) { + reg.Register( + httpmock.WithHost( + httpmock.QueryMatcher("GET", "agents/sessions/resource/pull/999", url.Values{ + "page_number": {"1"}, + "page_size": {"50"}, + }), + "api.githubcopilot.com", + ), + httpmock.StatusStringResponse(500, "{}"), + ) + }, + wantErr: "failed to list sessions:", + }, { + name: "API error at hydration", + limit: 10, + httpStubs: func(t *testing.T, reg *httpmock.Registry) { + reg.Register( + httpmock.WithHost( + httpmock.QueryMatcher("GET", "agents/sessions/resource/pull/999", url.Values{ + "page_number": {"1"}, + "page_size": {"50"}, + }), + "api.githubcopilot.com", + ), + httpmock.StringResponse(heredoc.Docf(` + { + "sessions": [ + { + "id": "sess1", + "name": "Build artifacts", + "user_id": 1, + "agent_id": 2, + "logs": "", + "state": "completed", + "owner_id": 10, + "repo_id": 1000, + "resource_type": "pull", + "resource_id": 2000, + "created_at": "%[1]s" + } + ] + }`, + sampleDateString, + )), + ) + // GraphQL hydration + reg.Register( + httpmock.GraphQL(`query FetchPRsAndUsersForAgentTaskSessions\b`), + httpmock.StatusStringResponse(500, `{}`), + ) + }, + wantErr: `failed to fetch session resources: non-200 OK status code:`, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + reg := &httpmock.Registry{} + if tt.httpStubs != nil { + tt.httpStubs(t, reg) + } + defer reg.Verify(t) + + httpClient := &http.Client{Transport: reg} + + cfg := config.NewBlankConfig() + capiClient := NewCAPIClient(httpClient, cfg.Authentication()) + + if tt.perPage != 0 { + last := defaultSessionsPerPage + defaultSessionsPerPage = tt.perPage + defer func() { + defaultSessionsPerPage = last + }() + } + + sessions, err := capiClient.ListSessionsByResourceID(context.Background(), resourceType, resourceID, tt.limit) + + if tt.wantErr != "" { + require.ErrorContains(t, err, tt.wantErr) + require.Nil(t, sessions) + return + } + + require.NoError(t, err) + require.Equal(t, tt.wantOut, sessions) + }) + } +} + func TestGetSessionRequiresID(t *testing.T) { client := &CAPIClient{} @@ -1020,6 +1615,70 @@ func TestGetSession(t *testing.T) { }, }, }, + { + // This happens at the early moments of a session lifecycle, before a PR is created and associated with it. + name: "success, but no pull request resource", + httpStubs: func(t *testing.T, reg *httpmock.Registry) { + reg.Register( + httpmock.WithHost(httpmock.REST("GET", "agents/sessions/some-uuid"), "api.githubcopilot.com"), + httpmock.StringResponse(heredoc.Docf(` + { + "id": "some-uuid", + "name": "Build artifacts", + "user_id": 1, + "agent_id": 2, + "logs": "", + "state": "completed", + "owner_id": 10, + "repo_id": 1000, + "resource_type": "", + "resource_id": 0, + "created_at": "%[1]s" + }`, + sampleDateString, + )), + ) + // GraphQL hydration + reg.Register( + httpmock.GraphQL(`query FetchPRsAndUsersForAgentTaskSessions\b`), + httpmock.GraphQLQuery(heredoc.Docf(` + { + "data": { + "nodes": [ + { + "__typename": "User", + "login": "octocat", + "name": "Octocat", + "databaseId": 1 + } + ] + } + }`, + sampleDateString, + ), func(q string, vars map[string]interface{}) { + assert.Equal(t, []interface{}{"U_kgAB"}, vars["ids"]) + }), + ) + }, + wantOut: &Session{ + ID: "some-uuid", + Name: "Build artifacts", + UserID: 1, + AgentID: 2, + Logs: "", + State: "completed", + OwnerID: 10, + RepoID: 1000, + ResourceType: "", + ResourceID: 0, + CreatedAt: sampleDate, + User: &api.GitHubUser{ + Login: "octocat", + Name: "Octocat", + DatabaseID: 1, + }, + }, + }, { name: "API error at hydration", httpStubs: func(t *testing.T, reg *httpmock.Registry) { @@ -1082,3 +1741,73 @@ func TestGetSession(t *testing.T) { }) } } +func TestGetPullRequestDatabaseID(t *testing.T) { + tests := []struct { + name string + httpStubs func(*testing.T, *httpmock.Registry) + wantErr string + wantOut int64 + }{ + { + name: "graphql error", + httpStubs: func(t *testing.T, reg *httpmock.Registry) { + reg.Register( + httpmock.WithHost(httpmock.GraphQL(`query GetPullRequestFullDatabaseID\b`), "api.github.com"), + httpmock.StringResponse(`{"data":{}, "errors": [{"message": "some gql error"}]}`), + ) + }, + wantErr: "some gql error", + }, + { + // This never happens in practice and it's just to cover more code path + name: "non-int database ID", + httpStubs: func(t *testing.T, reg *httpmock.Registry) { + reg.Register( + httpmock.WithHost(httpmock.GraphQL(`query GetPullRequestFullDatabaseID\b`), "api.github.com"), + httpmock.StringResponse(`{"data": {"repository": {"pullRequest": {"fullDatabaseId": "non-int"}}}}`), + ) + }, + wantErr: `strconv.ParseInt: parsing "non-int": invalid syntax`, + }, + { + name: "success", + httpStubs: func(t *testing.T, reg *httpmock.Registry) { + reg.Register( + httpmock.WithHost(httpmock.GraphQL(`query GetPullRequestFullDatabaseID\b`), "api.github.com"), + httpmock.GraphQLQuery(`{"data": {"repository": {"pullRequest": {"fullDatabaseId": "999"}}}}`, func(s string, m map[string]interface{}) { + assert.Equal(t, "OWNER", m["owner"]) + assert.Equal(t, "REPO", m["repo"]) + assert.Equal(t, float64(42), m["number"]) + }), + ) + }, + wantOut: 999, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + reg := &httpmock.Registry{} + if tt.httpStubs != nil { + tt.httpStubs(t, reg) + } + defer reg.Verify(t) + + httpClient := &http.Client{Transport: reg} + + cfg := config.NewBlankConfig() + capiClient := NewCAPIClient(httpClient, cfg.Authentication()) + + databaseID, err := capiClient.GetPullRequestDatabaseID(context.Background(), "github.com", "OWNER", "REPO", 42) + + if tt.wantErr != "" { + require.ErrorContains(t, err, tt.wantErr) + require.Zero(t, databaseID) + return + } + + require.NoError(t, err) + require.Equal(t, tt.wantOut, databaseID) + }) + } +} diff --git a/pkg/cmd/agent-task/shared/capi.go b/pkg/cmd/agent-task/shared/capi.go index f23ee86d2..f064eac2e 100644 --- a/pkg/cmd/agent-task/shared/capi.go +++ b/pkg/cmd/agent-task/shared/capi.go @@ -1,10 +1,14 @@ package shared import ( + "regexp" + "github.com/cli/cli/v2/pkg/cmd/agent-task/capi" "github.com/cli/cli/v2/pkg/cmdutil" ) +var uuidRE = regexp.MustCompile(`^[a-fA-F0-9]{8}-[a-fA-F0-9]{4}-[a-fA-F0-9]{4}-[a-fA-F0-9]{4}-[a-fA-F0-9]{12}$`) + func CapiClientFunc(f *cmdutil.Factory) func() (capi.CapiClient, error) { return func() (capi.CapiClient, error) { cfg, err := f.Config() @@ -21,3 +25,7 @@ func CapiClientFunc(f *cmdutil.Factory) func() (capi.CapiClient, error) { return capi.NewCAPIClient(httpClient, authCfg), nil } } + +func IsSessionID(s string) bool { + return uuidRE.MatchString(s) +} diff --git a/pkg/cmd/agent-task/shared/capi_test.go b/pkg/cmd/agent-task/shared/capi_test.go new file mode 100644 index 000000000..d6a106d1b --- /dev/null +++ b/pkg/cmd/agent-task/shared/capi_test.go @@ -0,0 +1,20 @@ +package shared + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestIsSession(t *testing.T) { + assert.True(t, IsSessionID("00000000-0000-0000-0000-000000000000")) + assert.True(t, IsSessionID("e2fa49d2-f164-4a56-ab99-498090b8fcdf")) + assert.True(t, IsSessionID("E2FA49D2-F164-4A56-AB99-498090B8FCDF")) + + assert.False(t, IsSessionID("")) + assert.False(t, IsSessionID(" ")) + assert.False(t, IsSessionID("\n")) + assert.False(t, IsSessionID("not-a-uuid")) + assert.False(t, IsSessionID("000000000000000000000000000000000000")) + assert.False(t, IsSessionID("00000000-0000-0000-0000-000000000000-extra")) +} diff --git a/pkg/cmd/agent-task/shared/display.go b/pkg/cmd/agent-task/shared/display.go index dd114b049..e841f4c41 100644 --- a/pkg/cmd/agent-task/shared/display.go +++ b/pkg/cmd/agent-task/shared/display.go @@ -24,6 +24,7 @@ func ColorFuncForSessionState(s capi.Session, cs *iostreams.ColorScheme) func(st return stateColor } +// SessionStateString returns the humane/capitalised form of the given session state. func SessionStateString(state string) string { switch state { case "queued": @@ -46,3 +47,17 @@ func SessionStateString(state string) string { return state } } + +type ColorFunc func(string) string + +func SessionSymbol(cs *iostreams.ColorScheme, state string) string { + noColor := func(s string) string { return s } + switch state { + case "completed": + return cs.SuccessIconWithColor(noColor) + case "failed", "timed_out", "cancelled": + return cs.FailureIconWithColor(noColor) + default: + return "-" + } +} diff --git a/pkg/cmd/agent-task/view/view.go b/pkg/cmd/agent-task/view/view.go index f6ce4d468..6dc5b6cf9 100644 --- a/pkg/cmd/agent-task/view/view.go +++ b/pkg/cmd/agent-task/view/view.go @@ -4,10 +4,15 @@ import ( "context" "errors" "fmt" + "net/http" "net/url" + "strconv" "time" "github.com/MakeNowJust/heredoc" + "github.com/cli/cli/v2/internal/ghinstance" + "github.com/cli/cli/v2/internal/ghrepo" + "github.com/cli/cli/v2/internal/prompter" "github.com/cli/cli/v2/internal/text" "github.com/cli/cli/v2/pkg/cmd/agent-task/capi" "github.com/cli/cli/v2/pkg/cmd/agent-task/shared" @@ -17,28 +22,54 @@ import ( "github.com/spf13/cobra" ) +const defaultLimit = 40 + type ViewOptions struct { IO *iostreams.IOStreams + BaseRepo func() (ghrepo.Interface, error) CapiClient func() (capi.CapiClient, error) + HttpClient func() (*http.Client, error) + Finder prShared.PRFinder + Prompter prompter.Prompter SelectorArg string + PRNumber int + SessionID string } func NewCmdView(f *cmdutil.Factory, runF func(*ViewOptions) error) *cobra.Command { opts := &ViewOptions{ IO: f.IOStreams, + HttpClient: f.HttpClient, CapiClient: shared.CapiClientFunc(f), + Prompter: f.Prompter, } cmd := &cobra.Command{ - Use: "view ", + Use: "view [ | | | ]", Short: "View an agent task session", Long: heredoc.Doc(` View an agent task session. `), - Args: cmdutil.ExactArgs(1, "a session ID is required"), + Args: cobra.MaximumNArgs(1), RunE: func(cmd *cobra.Command, args []string) error { - opts.SelectorArg = args[0] + // Support -R/--repo override + opts.BaseRepo = f.BaseRepo + + if len(args) > 0 { + opts.SelectorArg = args[0] + if shared.IsSessionID(opts.SelectorArg) { + opts.SessionID = opts.SelectorArg + } + } + + if opts.SessionID == "" && !opts.IO.CanPrompt() { + return fmt.Errorf("session ID is required when not running interactively") + } + + if opts.Finder == nil { + opts.Finder = prShared.NewFinder(f) + } if runF != nil { return runF(opts) @@ -47,6 +78,8 @@ func NewCmdView(f *cmdutil.Factory, runF func(*ViewOptions) error) *cobra.Comman }, } + cmdutil.EnableRepoOverride(cmd, f) + return cmd } @@ -57,23 +90,113 @@ func viewRun(opts *ViewOptions) error { } ctx := context.Background() + cs := opts.IO.ColorScheme() opts.IO.StartProgressIndicatorWithLabel("Fetching agent session...") defer opts.IO.StopProgressIndicator() - session, err := capiClient.GetSession(ctx, opts.SelectorArg) - opts.IO.StopProgressIndicator() + var session *capi.Session - if err != nil { - if errors.Is(err, capi.ErrSessionNotFound) { - fmt.Fprintln(opts.IO.ErrOut, "session not found") + if opts.SessionID != "" { + if sess, err := capiClient.GetSession(ctx, opts.SessionID); err != nil { + if errors.Is(err, capi.ErrSessionNotFound) { + fmt.Fprintln(opts.IO.ErrOut, "session not found") + return cmdutil.SilentError + } + return err + } else { + session = sess + } + } else { + var resourceID int64 + + if opts.SelectorArg != "" { + // Finder does not support the PR/issue reference format (e.g. owner/repo#123) + // so we need to check if the selector arg is a reference and fetch the PR + // directly. + if repo, num, err := prShared.ParseFullReference(opts.SelectorArg); err == nil { + // Since the selector was a reference (i.e. without hostname data), we need to + // check the base repo to get the hostname. + baseRepo, err := opts.BaseRepo() + if err != nil { + return err + } + + hostname := baseRepo.RepoHost() + if hostname != ghinstance.Default() { + return fmt.Errorf("agent tasks are not supported on this host: %s", hostname) + } + + resourceID, err = capiClient.GetPullRequestDatabaseID(ctx, hostname, repo.RepoOwner(), repo.RepoName(), num) + if err != nil { + return fmt.Errorf("failed to fetch pull request: %w", err) + } + } + } + + if resourceID == 0 { + findOptions := prShared.FindOptions{ + Selector: opts.SelectorArg, + Fields: []string{"id", "url", "fullDatabaseId"}, + } + + pr, repo, err := opts.Finder.Find(findOptions) + if err != nil { + return err + } + + if repo.RepoHost() != ghinstance.Default() { + return fmt.Errorf("agent tasks are not supported on this host: %s", repo.RepoHost()) + } + + databaseID, err := strconv.ParseInt(pr.FullDatabaseID, 10, 64) + if err != nil { + return fmt.Errorf("failed to parse pull request: %w", err) + } + + resourceID = databaseID + } + + // TODO(babakks): currently we just fetch a pre-defined number of + // matching sessions to avoid hitting the API too many times, but it's + // technically possible for a PR to be associated with lots of sessions + // (i.e. above our selected limit). + sessions, err := capiClient.ListSessionsByResourceID(ctx, "pull", resourceID, defaultLimit) + if err != nil { + return fmt.Errorf("failed to list sessions for pull request: %w", err) + } + + if len(sessions) == 0 { + fmt.Fprintln(opts.IO.ErrOut, "no session found for pull request") return cmdutil.SilentError } - return err + + session = sessions[0] + if len(sessions) > 1 { + now := time.Now() + options := make([]string, 0, len(sessions)) + for _, session := range sessions { + options = append(options, fmt.Sprintf( + "%s %s • %s", + shared.SessionSymbol(cs, session.State), + session.Name, + text.FuzzyAgo(now, session.CreatedAt), + )) + } + + opts.IO.StopProgressIndicator() + selected, err := opts.Prompter.Select("Select a session", options[0], options) + if err != nil { + return err + } + + session = sessions[selected] + } } + opts.IO.StopProgressIndicator() + out := opts.IO.Out - cs := opts.IO.ColorScheme() if session.PullRequest != nil { fmt.Fprintf(out, "%s • %s • %s%s\n", @@ -83,7 +206,7 @@ func viewRun(opts *ViewOptions) error { cs.ColorFromString(prShared.ColorForPRState(*session.PullRequest))(fmt.Sprintf("#%d", session.PullRequest.Number)), ) } else { - // Should never happen, but we need to cover the path + // This can happen when the session is just created and a PR is not yet available for it fmt.Fprintf(out, "%s\n", shared.ColorFuncForSessionState(*session, cs)(shared.SessionStateString(session.State))) } diff --git a/pkg/cmd/agent-task/view/view_test.go b/pkg/cmd/agent-task/view/view_test.go index 97304c399..64b63df5b 100644 --- a/pkg/cmd/agent-task/view/view_test.go +++ b/pkg/cmd/agent-task/view/view_test.go @@ -10,7 +10,10 @@ import ( "github.com/MakeNowJust/heredoc" "github.com/cli/cli/v2/api" + "github.com/cli/cli/v2/internal/ghrepo" + "github.com/cli/cli/v2/internal/prompter" "github.com/cli/cli/v2/pkg/cmd/agent-task/capi" + prShared "github.com/cli/cli/v2/pkg/cmd/pr/shared" "github.com/cli/cli/v2/pkg/cmdutil" "github.com/cli/cli/v2/pkg/iostreams" "github.com/google/shlex" @@ -20,20 +23,49 @@ import ( func TestNewCmdList(t *testing.T) { tests := []struct { - name string - args string - wantOpts ViewOptions - wantErr string + name string + tty bool + args string + wantOpts ViewOptions + wantBaseRepo ghrepo.Interface + wantErr string }{ { - name: "no arguments", - wantErr: "a session ID is required", + name: "no arg tty", + tty: true, + args: "", + wantOpts: ViewOptions{}, }, { - name: "session ID arg", - args: "some-uuid", + name: "session ID arg tty", + tty: true, + args: "00000000-0000-0000-0000-000000000000", wantOpts: ViewOptions{ - SelectorArg: "some-uuid", + SelectorArg: "00000000-0000-0000-0000-000000000000", + SessionID: "00000000-0000-0000-0000-000000000000", + }, + }, + { + name: "non-session ID arg tty", + tty: true, + args: "some-arg", + wantOpts: ViewOptions{ + SelectorArg: "some-arg", + }, + }, + { + name: "session ID required if non-tty", + tty: false, + args: "some-arg", + wantErr: "session ID is required when not running interactively", + }, + { + name: "repo override", + tty: true, + args: "some-arg -R OWNER/REPO", + wantBaseRepo: ghrepo.New("OWNER", "REPO"), + wantOpts: ViewOptions{ + SelectorArg: "some-arg", }, }, } @@ -41,6 +73,10 @@ func TestNewCmdList(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { ios, _, _, _ := iostreams.Test() + ios.SetStdinTTY(tt.tty) + ios.SetStdoutTTY(tt.tty) + ios.SetStderrTTY(tt.tty) + f := &cmdutil.Factory{ IOStreams: ios, } @@ -65,6 +101,13 @@ func TestNewCmdList(t *testing.T) { require.NoError(t, err) assert.Equal(t, tt.wantOpts.SelectorArg, gotOpts.SelectorArg) + assert.Equal(t, tt.wantOpts.SessionID, gotOpts.SessionID) + + if tt.wantBaseRepo != nil { + baseRepo, err := gotOpts.BaseRepo() + require.NoError(t, err) + assert.True(t, ghrepo.IsSame(tt.wantBaseRepo, baseRepo)) + } }) } } @@ -74,19 +117,23 @@ func Test_viewRun(t *testing.T) { tests := []struct { name string - selectorArg string tty bool + opts ViewOptions + promptStubs func(*testing.T, *prompter.MockPrompter) capiStubs func(*testing.T, *capi.CapiClientMock) wantOut string wantErr error wantStderr string }{ { - name: "not found (tty)", - tty: true, - selectorArg: "some-session-id", + name: "with session id, not found (tty)", + tty: true, + opts: ViewOptions{ + SelectorArg: "some-session-id", + SessionID: "some-session-id", + }, capiStubs: func(t *testing.T, m *capi.CapiClientMock) { - m.GetSessionFunc = func(ctx context.Context, selector string) (*capi.Session, error) { + m.GetSessionFunc = func(_ context.Context, _ string) (*capi.Session, error) { return nil, capi.ErrSessionNotFound } }, @@ -94,43 +141,29 @@ func Test_viewRun(t *testing.T) { wantErr: cmdutil.SilentError, }, { - name: "not found (nontty)", - selectorArg: "some-session-id", - capiStubs: func(t *testing.T, m *capi.CapiClientMock) { - m.GetSessionFunc = func(ctx context.Context, selector string) (*capi.Session, error) { - return nil, capi.ErrSessionNotFound - } + name: "with session id, api error (tty)", + tty: true, + opts: ViewOptions{ + SelectorArg: "some-session-id", + SessionID: "some-session-id", }, - wantStderr: "session not found\n", - wantErr: cmdutil.SilentError, - }, - { - name: "API error (tty)", - tty: true, - selectorArg: "some-session-id", capiStubs: func(t *testing.T, m *capi.CapiClientMock) { - m.GetSessionFunc = func(ctx context.Context, selector string) (*capi.Session, error) { + m.GetSessionFunc = func(_ context.Context, _ string) (*capi.Session, error) { return nil, errors.New("some error") } }, wantErr: errors.New("some error"), }, { - name: "API error (nontty)", - selectorArg: "some-session-id", - capiStubs: func(t *testing.T, m *capi.CapiClientMock) { - m.GetSessionFunc = func(ctx context.Context, selector string) (*capi.Session, error) { - return nil, errors.New("some error") - } + name: "with session id, success, with pr and user data (tty)", + tty: true, + opts: ViewOptions{ + SelectorArg: "some-session-id", + SessionID: "some-session-id", }, - wantErr: errors.New("some error"), - }, - { - name: "success, with PR and user data (tty)", - tty: true, - selectorArg: "some-session-id", capiStubs: func(t *testing.T, m *capi.CapiClientMock) { - m.GetSessionFunc = func(ctx context.Context, selector string) (*capi.Session, error) { + m.GetSessionFunc = func(_ context.Context, id string) (*capi.Session, error) { + assert.Equal(t, "some-session-id", id) return &capi.Session{ ID: "some-session-id", State: "completed", @@ -152,17 +185,22 @@ func Test_viewRun(t *testing.T) { wantOut: heredoc.Doc(` Completed • fix something • OWNER/REPO#101 Started on behalf of octocat about 6 hours ago - + View this session on GitHub: https://github.com/OWNER/REPO/pull/101/agent-sessions/some-session-id `), }, { - name: "success, without user data (tty)", - tty: true, - selectorArg: "some-session-id", + // The user data should always be there, but we need to cover the code path. + name: "with session id, success, without user data (tty)", + tty: true, + opts: ViewOptions{ + SelectorArg: "some-session-id", + SessionID: "some-session-id", + }, capiStubs: func(t *testing.T, m *capi.CapiClientMock) { - m.GetSessionFunc = func(ctx context.Context, selector string) (*capi.Session, error) { + m.GetSessionFunc = func(_ context.Context, id string) (*capi.Session, error) { + assert.Equal(t, "some-session-id", id) return &capi.Session{ ID: "some-session-id", State: "completed", @@ -181,17 +219,22 @@ func Test_viewRun(t *testing.T) { wantOut: heredoc.Doc(` Completed • fix something • OWNER/REPO#101 Started about 6 hours ago - + View this session on GitHub: https://github.com/OWNER/REPO/pull/101/agent-sessions/some-session-id `), }, { - name: "success, without PR data (tty)", - tty: true, - selectorArg: "some-session-id", + // This can happen when the session is just created and a PR is not yet available for it. + name: "with session id, success, without pr data (tty)", + tty: true, + opts: ViewOptions{ + SelectorArg: "some-session-id", + SessionID: "some-session-id", + }, capiStubs: func(t *testing.T, m *capi.CapiClientMock) { - m.GetSessionFunc = func(ctx context.Context, selector string) (*capi.Session, error) { + m.GetSessionFunc = func(_ context.Context, id string) (*capi.Session, error) { + assert.Equal(t, "some-session-id", id) return &capi.Session{ ID: "some-session-id", State: "completed", @@ -208,11 +251,16 @@ func Test_viewRun(t *testing.T) { `), }, { - name: "success, without PR nor user data (tty)", - tty: true, - selectorArg: "some-session-id", + // The user data should always be there, but we need to cover the code path. + name: "with session id, success, without pr nor user data (tty)", + tty: true, + opts: ViewOptions{ + SelectorArg: "some-session-id", + SessionID: "some-session-id", + }, capiStubs: func(t *testing.T, m *capi.CapiClientMock) { - m.GetSessionFunc = func(ctx context.Context, selector string) (*capi.Session, error) { + m.GetSessionFunc = func(_ context.Context, id string) (*capi.Session, error) { + assert.Equal(t, "some-session-id", id) return &capi.Session{ ID: "some-session-id", State: "completed", @@ -225,6 +273,267 @@ func Test_viewRun(t *testing.T) { Started about 6 hours ago `), }, + { + name: "with pr number, api error (tty)", + tty: true, + opts: ViewOptions{ + SelectorArg: "pr-number", + Finder: prShared.NewMockFinder( + "pr-number", + &api.PullRequest{FullDatabaseID: "999999"}, + ghrepo.New("OWNER", "REPO"), + ), + }, + capiStubs: func(t *testing.T, m *capi.CapiClientMock) { + m.ListSessionsByResourceIDFunc = func(_ context.Context, _ string, _ int64, _ int) ([]*capi.Session, error) { + return nil, errors.New("some error") + } + }, + wantErr: errors.New("failed to list sessions for pull request: some error"), + }, + { + name: "with pr reference, unsupported hostname (tty)", + tty: true, + opts: ViewOptions{ + SelectorArg: "OWNER/REPO#999", + BaseRepo: func() (ghrepo.Interface, error) { + return ghrepo.NewWithHost("OWNER", "REPO", "foo.com"), nil + }, + }, + wantErr: errors.New("agent tasks are not supported on this host: foo.com"), + }, + { + name: "with pr reference, api error when fetching pr database ID (tty)", + tty: true, + opts: ViewOptions{ + SelectorArg: "OWNER/REPO#999", + BaseRepo: func() (ghrepo.Interface, error) { + return ghrepo.New("OWNER", "REPO"), nil + }, + }, + capiStubs: func(t *testing.T, m *capi.CapiClientMock) { + m.GetPullRequestDatabaseIDFunc = func(_ context.Context, _ string, _ string, _ string, _ int) (int64, error) { + return 0, errors.New("some error") + } + }, + wantErr: errors.New("failed to fetch pull request: some error"), + }, + { + name: "with pr reference, api error when fetching session (tty)", + tty: true, + opts: ViewOptions{ + SelectorArg: "OWNER/REPO#999", + BaseRepo: func() (ghrepo.Interface, error) { + return ghrepo.New("OWNER", "REPO"), nil + }, + }, + capiStubs: func(t *testing.T, m *capi.CapiClientMock) { + m.GetPullRequestDatabaseIDFunc = func(_ context.Context, _ string, _ string, _ string, _ int) (int64, error) { + return 999999, nil + } + m.ListSessionsByResourceIDFunc = func(_ context.Context, _ string, _ int64, _ int) ([]*capi.Session, error) { + return nil, errors.New("some error") + } + }, + wantErr: errors.New("failed to list sessions for pull request: some error"), + }, + { + name: "with pr number, success, single session with pr and user data (tty)", + tty: true, + opts: ViewOptions{ + SelectorArg: "pr-number", + Finder: prShared.NewMockFinder( + "pr-number", + &api.PullRequest{FullDatabaseID: "999999"}, + ghrepo.New("OWNER", "REPO"), + ), + }, + capiStubs: func(t *testing.T, m *capi.CapiClientMock) { + m.ListSessionsByResourceIDFunc = func(_ context.Context, resourceType string, resourceID int64, limit int) ([]*capi.Session, error) { + assert.Equal(t, "pull", resourceType) + assert.Equal(t, int64(999999), resourceID) + assert.Equal(t, defaultLimit, limit) + return []*capi.Session{ + { + ID: "some-session-id", + State: "completed", + CreatedAt: sampleDate, + PullRequest: &api.PullRequest{ + Title: "fix something", + Number: 101, + URL: "https://github.com/OWNER/REPO/pull/101", + Repository: &api.PRRepository{ + NameWithOwner: "OWNER/REPO", + }, + }, + User: &api.GitHubUser{ + Login: "octocat", + }, + }, + }, nil + } + }, + wantOut: heredoc.Doc(` + Completed • fix something • OWNER/REPO#101 + Started on behalf of octocat about 6 hours ago + + View this session on GitHub: + https://github.com/OWNER/REPO/pull/101/agent-sessions/some-session-id + `), + }, + { + name: "with pr number, success, multiple sessions with pr and user data (tty)", + tty: true, + opts: ViewOptions{ + SelectorArg: "pr-number", + Finder: prShared.NewMockFinder( + "pr-number", + &api.PullRequest{FullDatabaseID: "999999"}, + ghrepo.New("OWNER", "REPO"), + ), + }, + capiStubs: func(t *testing.T, m *capi.CapiClientMock) { + m.ListSessionsByResourceIDFunc = func(_ context.Context, resourceType string, resourceID int64, limit int) ([]*capi.Session, error) { + assert.Equal(t, "pull", resourceType) + assert.Equal(t, int64(999999), resourceID) + assert.Equal(t, defaultLimit, limit) + return []*capi.Session{ + { + ID: "some-session-id", + Name: "session one", + State: "completed", + CreatedAt: sampleDate, + PullRequest: &api.PullRequest{ + Title: "fix something", + Number: 101, + URL: "https://github.com/OWNER/REPO/pull/101", + Repository: &api.PRRepository{ + NameWithOwner: "OWNER/REPO", + }, + }, + User: &api.GitHubUser{ + Login: "octocat", + }, + }, + { + ID: "some-other-session-id", + Name: "session two", + State: "completed", + CreatedAt: sampleDate, + PullRequest: &api.PullRequest{ + Title: "fix something else", + Number: 102, + URL: "https://github.com/OWNER/REPO/pull/102", + Repository: &api.PRRepository{ + NameWithOwner: "OWNER/REPO", + }, + }, + User: &api.GitHubUser{ + Login: "octocat", + }, + }, + }, nil + } + }, + promptStubs: func(t *testing.T, pm *prompter.MockPrompter) { + pm.RegisterSelect( + "Select a session", + []string{ + "✓ session one • about 6 hours ago", + "✓ session two • about 6 hours ago", + }, + func(_, _ string, opts []string) (int, error) { + return prompter.IndexFor(opts, "✓ session one • about 6 hours ago") + }, + ) + }, + wantOut: heredoc.Doc(` + Completed • fix something • OWNER/REPO#101 + Started on behalf of octocat about 6 hours ago + + View this session on GitHub: + https://github.com/OWNER/REPO/pull/101/agent-sessions/some-session-id + `), + }, + { + name: "with pr reference, success, multiple sessions with pr and user data (tty)", + tty: true, + opts: ViewOptions{ + SelectorArg: "OWNER/REPO#999", + BaseRepo: func() (ghrepo.Interface, error) { + return ghrepo.New("OWNER", "REPO"), nil + }, + }, + capiStubs: func(t *testing.T, m *capi.CapiClientMock) { + m.GetPullRequestDatabaseIDFunc = func(_ context.Context, hostname string, owner string, repo string, number int) (int64, error) { + assert.Equal(t, "github.com", hostname) + assert.Equal(t, "OWNER", owner) + assert.Equal(t, "REPO", repo) + assert.Equal(t, 999, number) + return 999999, nil + } + m.ListSessionsByResourceIDFunc = func(_ context.Context, resourceType string, resourceID int64, limit int) ([]*capi.Session, error) { + assert.Equal(t, "pull", resourceType) + assert.Equal(t, int64(999999), resourceID) + assert.Equal(t, defaultLimit, limit) + return []*capi.Session{ + { + ID: "some-session-id", + Name: "session one", + State: "completed", + CreatedAt: sampleDate, + PullRequest: &api.PullRequest{ + Title: "fix something", + Number: 101, + URL: "https://github.com/OWNER/REPO/pull/101", + Repository: &api.PRRepository{ + NameWithOwner: "OWNER/REPO", + }, + }, + User: &api.GitHubUser{ + Login: "octocat", + }, + }, + { + ID: "some-other-session-id", + Name: "session two", + State: "completed", + CreatedAt: sampleDate, + PullRequest: &api.PullRequest{ + Title: "fix something else", + Number: 102, + URL: "https://github.com/OWNER/REPO/pull/102", + Repository: &api.PRRepository{ + NameWithOwner: "OWNER/REPO", + }, + }, + User: &api.GitHubUser{ + Login: "octocat", + }, + }, + }, nil + } + }, + promptStubs: func(t *testing.T, pm *prompter.MockPrompter) { + pm.RegisterSelect( + "Select a session", + []string{ + "✓ session one • about 6 hours ago", + "✓ session two • about 6 hours ago", + }, + func(_, _ string, opts []string) (int, error) { + return prompter.IndexFor(opts, "✓ session one • about 6 hours ago") + }, + ) + }, + wantOut: heredoc.Doc(` + Completed • fix something • OWNER/REPO#101 + Started on behalf of octocat about 6 hours ago + + View this session on GitHub: + https://github.com/OWNER/REPO/pull/101/agent-sessions/some-session-id + `), + }, } for _, tt := range tests { @@ -234,18 +543,22 @@ func Test_viewRun(t *testing.T) { tt.capiStubs(t, capiClientMock) } + prompter := prompter.NewMockPrompter(t) + if tt.promptStubs != nil { + tt.promptStubs(t, prompter) + } + ios, _, stdout, stderr := iostreams.Test() ios.SetStdoutTTY(tt.tty) - opts := &ViewOptions{ - IO: ios, - CapiClient: func() (capi.CapiClient, error) { - return capiClientMock, nil - }, - SelectorArg: tt.selectorArg, + opts := tt.opts + opts.IO = ios + opts.Prompter = prompter + opts.CapiClient = func() (capi.CapiClient, error) { + return capiClientMock, nil } - err := viewRun(opts) + err := viewRun(&opts) if tt.wantErr != nil { assert.Error(t, err) require.EqualError(t, err, tt.wantErr.Error()) diff --git a/pkg/cmd/pr/shared/finder.go b/pkg/cmd/pr/shared/finder.go index df8cc0fd4..cb8237d58 100644 --- a/pkg/cmd/pr/shared/finder.go +++ b/pkg/cmd/pr/shared/finder.go @@ -328,6 +328,31 @@ func ParseURL(prURL string) (ghrepo.Interface, int, error) { return repo, prNumber, nil } +var fullReferenceRE = regexp.MustCompile(`^(?:([^/]+)/([^/]+))#(\d+)$`) + +// ParseFullReference parses a short issue/pull request reference of the form +// "owner/repo#number", where owner, repo and number are all required. +func ParseFullReference(s string) (ghrepo.Interface, int, error) { + if s == "" { + return nil, 0, errors.New("empty reference") + } + + m := fullReferenceRE.FindStringSubmatch(s) + if m == nil { + return nil, 0, fmt.Errorf("invalid reference: %q", s) + } + + number, err := strconv.Atoi(m[3]) + if err != nil { + return nil, 0, fmt.Errorf("invalid reference: %q", number) + } + + owner := m[1] + repo := m[2] + + return ghrepo.New(owner, repo), number, nil +} + func findByNumber(httpClient *http.Client, repo ghrepo.Interface, number int, fields []string) (*api.PullRequest, error) { type response struct { Repository struct { diff --git a/pkg/cmd/pr/shared/finder_test.go b/pkg/cmd/pr/shared/finder_test.go index 0f6da5a6e..5e33ee876 100644 --- a/pkg/cmd/pr/shared/finder_test.go +++ b/pkg/cmd/pr/shared/finder_test.go @@ -11,6 +11,7 @@ import ( "github.com/cli/cli/v2/git" "github.com/cli/cli/v2/internal/ghrepo" "github.com/cli/cli/v2/pkg/httpmock" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -78,6 +79,80 @@ func TestParseURL(t *testing.T) { } } +func TestParseFullReference(t *testing.T) { + tests := []struct { + name string + arg string + wantRepo ghrepo.Interface + wantNumber int + wantErr string + }{ + { + name: "number", + arg: "123", + wantErr: `invalid reference: "123"`, + }, + { + name: "number with hash", + arg: "#123", + wantErr: `invalid reference: "#123"`, + }, + { + name: "full form", + arg: "OWNER/REPO#123", + wantNumber: 123, + wantRepo: ghrepo.New("OWNER", "REPO"), + }, + { + name: "empty", + wantErr: "empty reference", + }, + { + name: "invalid full form, without hash", + arg: "OWNER/REPO123", + wantErr: `invalid reference: "OWNER/REPO123"`, + }, + { + name: "invalid full form, empty owner and repo", + arg: "/#123", + wantErr: `invalid reference: "/#123"`, + }, + { + name: "invalid full form, without owner", + arg: "REPO#123", + wantErr: `invalid reference: "REPO#123"`, + }, + { + name: "invalid full form, without repo", + arg: "OWNER/#123", + wantErr: `invalid reference: "OWNER/#123"`, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + repo, number, err := ParseFullReference(tt.arg) + + if tt.wantErr != "" { + require.EqualError(t, err, tt.wantErr) + assert.Nil(t, repo) + assert.Zero(t, number) + return + } + + require.NoError(t, err) + assert.Equal(t, tt.wantNumber, number) + + if tt.wantRepo != nil { + require.NotNil(t, repo) + assert.True(t, ghrepo.IsSame(tt.wantRepo, repo)) + } else { + assert.Nil(t, repo) + } + }) + } +} + type args struct { baseRepoFn func() (ghrepo.Interface, error) branchFn func() (string, error)