Merge pull request #11674 from cli/babakks/add-pr-number-arg-support-to-view-cmd

`gh agent-task view`: support PR number arg
This commit is contained in:
Kynan Ware 2025-09-09 10:44:51 -06:00 committed by GitHub
commit f9617d990f
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
11 changed files with 1611 additions and 83 deletions

View file

@ -20,6 +20,8 @@ type CapiClient interface {
CreateJob(ctx context.Context, owner, repo, problemStatement, baseBranch string) (*Job, error)
GetJob(ctx context.Context, owner, repo, jobID string) (*Job, error)
GetSession(ctx context.Context, id string) (*Session, error)
ListSessionsByResourceID(ctx context.Context, resourceType string, resourceID int64, limit int) ([]*Session, error)
GetPullRequestDatabaseID(ctx context.Context, hostname string, owner string, repo string, number int) (int64, error)
}
// CAPIClient is a client for interacting with the Copilot API

View file

@ -24,9 +24,15 @@ var _ CapiClient = &CapiClientMock{}
// GetJobFunc: func(ctx context.Context, owner string, repo string, jobID string) (*Job, error) {
// panic("mock out the GetJob method")
// },
// GetPullRequestDatabaseIDFunc: func(ctx context.Context, hostname string, owner string, repo string, number int) (int64, error) {
// panic("mock out the GetPullRequestDatabaseID method")
// },
// GetSessionFunc: func(ctx context.Context, id string) (*Session, error) {
// panic("mock out the GetSession method")
// },
// ListSessionsByResourceIDFunc: func(ctx context.Context, resourceType string, resourceID int64, limit int) ([]*Session, error) {
// panic("mock out the ListSessionsByResourceID method")
// },
// ListSessionsForRepoFunc: func(ctx context.Context, owner string, repo string, limit int) ([]*Session, error) {
// panic("mock out the ListSessionsForRepo method")
// },
@ -46,9 +52,15 @@ type CapiClientMock struct {
// GetJobFunc mocks the GetJob method.
GetJobFunc func(ctx context.Context, owner string, repo string, jobID string) (*Job, error)
// GetPullRequestDatabaseIDFunc mocks the GetPullRequestDatabaseID method.
GetPullRequestDatabaseIDFunc func(ctx context.Context, hostname string, owner string, repo string, number int) (int64, error)
// GetSessionFunc mocks the GetSession method.
GetSessionFunc func(ctx context.Context, id string) (*Session, error)
// ListSessionsByResourceIDFunc mocks the ListSessionsByResourceID method.
ListSessionsByResourceIDFunc func(ctx context.Context, resourceType string, resourceID int64, limit int) ([]*Session, error)
// ListSessionsForRepoFunc mocks the ListSessionsForRepo method.
ListSessionsForRepoFunc func(ctx context.Context, owner string, repo string, limit int) ([]*Session, error)
@ -81,6 +93,19 @@ type CapiClientMock struct {
// JobID is the jobID argument value.
JobID string
}
// GetPullRequestDatabaseID holds details about calls to the GetPullRequestDatabaseID method.
GetPullRequestDatabaseID []struct {
// Ctx is the ctx argument value.
Ctx context.Context
// Hostname is the hostname argument value.
Hostname string
// Owner is the owner argument value.
Owner string
// Repo is the repo argument value.
Repo string
// Number is the number argument value.
Number int
}
// GetSession holds details about calls to the GetSession method.
GetSession []struct {
// Ctx is the ctx argument value.
@ -88,6 +113,17 @@ type CapiClientMock struct {
// ID is the id argument value.
ID string
}
// ListSessionsByResourceID holds details about calls to the ListSessionsByResourceID method.
ListSessionsByResourceID []struct {
// Ctx is the ctx argument value.
Ctx context.Context
// ResourceType is the resourceType argument value.
ResourceType string
// ResourceID is the resourceID argument value.
ResourceID int64
// Limit is the limit argument value.
Limit int
}
// ListSessionsForRepo holds details about calls to the ListSessionsForRepo method.
ListSessionsForRepo []struct {
// Ctx is the ctx argument value.
@ -107,11 +143,13 @@ type CapiClientMock struct {
Limit int
}
}
lockCreateJob sync.RWMutex
lockGetJob sync.RWMutex
lockGetSession sync.RWMutex
lockListSessionsForRepo sync.RWMutex
lockListSessionsForViewer sync.RWMutex
lockCreateJob sync.RWMutex
lockGetJob sync.RWMutex
lockGetPullRequestDatabaseID sync.RWMutex
lockGetSession sync.RWMutex
lockListSessionsByResourceID sync.RWMutex
lockListSessionsForRepo sync.RWMutex
lockListSessionsForViewer sync.RWMutex
}
// CreateJob calls CreateJobFunc.
@ -206,6 +244,54 @@ func (mock *CapiClientMock) GetJobCalls() []struct {
return calls
}
// GetPullRequestDatabaseID calls GetPullRequestDatabaseIDFunc.
func (mock *CapiClientMock) GetPullRequestDatabaseID(ctx context.Context, hostname string, owner string, repo string, number int) (int64, error) {
if mock.GetPullRequestDatabaseIDFunc == nil {
panic("CapiClientMock.GetPullRequestDatabaseIDFunc: method is nil but CapiClient.GetPullRequestDatabaseID was just called")
}
callInfo := struct {
Ctx context.Context
Hostname string
Owner string
Repo string
Number int
}{
Ctx: ctx,
Hostname: hostname,
Owner: owner,
Repo: repo,
Number: number,
}
mock.lockGetPullRequestDatabaseID.Lock()
mock.calls.GetPullRequestDatabaseID = append(mock.calls.GetPullRequestDatabaseID, callInfo)
mock.lockGetPullRequestDatabaseID.Unlock()
return mock.GetPullRequestDatabaseIDFunc(ctx, hostname, owner, repo, number)
}
// GetPullRequestDatabaseIDCalls gets all the calls that were made to GetPullRequestDatabaseID.
// Check the length with:
//
// len(mockedCapiClient.GetPullRequestDatabaseIDCalls())
func (mock *CapiClientMock) GetPullRequestDatabaseIDCalls() []struct {
Ctx context.Context
Hostname string
Owner string
Repo string
Number int
} {
var calls []struct {
Ctx context.Context
Hostname string
Owner string
Repo string
Number int
}
mock.lockGetPullRequestDatabaseID.RLock()
calls = mock.calls.GetPullRequestDatabaseID
mock.lockGetPullRequestDatabaseID.RUnlock()
return calls
}
// GetSession calls GetSessionFunc.
func (mock *CapiClientMock) GetSession(ctx context.Context, id string) (*Session, error) {
if mock.GetSessionFunc == nil {
@ -242,6 +328,50 @@ func (mock *CapiClientMock) GetSessionCalls() []struct {
return calls
}
// ListSessionsByResourceID calls ListSessionsByResourceIDFunc.
func (mock *CapiClientMock) ListSessionsByResourceID(ctx context.Context, resourceType string, resourceID int64, limit int) ([]*Session, error) {
if mock.ListSessionsByResourceIDFunc == nil {
panic("CapiClientMock.ListSessionsByResourceIDFunc: method is nil but CapiClient.ListSessionsByResourceID was just called")
}
callInfo := struct {
Ctx context.Context
ResourceType string
ResourceID int64
Limit int
}{
Ctx: ctx,
ResourceType: resourceType,
ResourceID: resourceID,
Limit: limit,
}
mock.lockListSessionsByResourceID.Lock()
mock.calls.ListSessionsByResourceID = append(mock.calls.ListSessionsByResourceID, callInfo)
mock.lockListSessionsByResourceID.Unlock()
return mock.ListSessionsByResourceIDFunc(ctx, resourceType, resourceID, limit)
}
// ListSessionsByResourceIDCalls gets all the calls that were made to ListSessionsByResourceID.
// Check the length with:
//
// len(mockedCapiClient.ListSessionsByResourceIDCalls())
func (mock *CapiClientMock) ListSessionsByResourceIDCalls() []struct {
Ctx context.Context
ResourceType string
ResourceID int64
Limit int
} {
var calls []struct {
Ctx context.Context
ResourceType string
ResourceID int64
Limit int
}
mock.lockListSessionsByResourceID.RLock()
calls = mock.calls.ListSessionsByResourceID
mock.lockListSessionsByResourceID.RUnlock()
return calls
}
// ListSessionsForRepo calls ListSessionsForRepoFunc.
func (mock *CapiClientMock) ListSessionsForRepo(ctx context.Context, owner string, repo string, limit int) ([]*Session, error) {
if mock.ListSessionsForRepoFunc == nil {

View file

@ -14,6 +14,7 @@ import (
"time"
"github.com/cli/cli/v2/api"
"github.com/shurcooL/githubv4"
"github.com/vmihailenco/msgpack/v5"
)
@ -131,7 +132,6 @@ func (c *CAPIClient) ListSessionsForViewer(ctx context.Context, limit int) ([]*S
sessions = sessions[:limit]
}
// Hydrate the result with pull request data.
result, err := c.hydrateSessionPullRequestsAndUsers(sessions)
if err != nil {
return nil, fmt.Errorf("failed to fetch session resources: %w", err)
@ -192,7 +192,6 @@ func (c *CAPIClient) ListSessionsForRepo(ctx context.Context, owner string, repo
sessions = sessions[:limit]
}
// Hydrate the result with pull request data.
result, err := c.hydrateSessionPullRequestsAndUsers(sessions)
if err != nil {
return nil, fmt.Errorf("failed to fetch session resources: %w", err)
@ -239,6 +238,65 @@ func (c *CAPIClient) GetSession(ctx context.Context, id string) (*Session, error
return sessions[0], nil
}
// ListSessionsByResourceID retrieves sessions associated with the given resource type and ID.
func (c *CAPIClient) ListSessionsByResourceID(ctx context.Context, resourceType string, resourceID int64, limit int) ([]*Session, error) {
if resourceType == "" || resourceID == 0 {
return nil, fmt.Errorf("missing resource type/ID")
}
if limit == 0 {
return nil, nil
}
url := fmt.Sprintf("%s/agents/sessions/resource/%s/%d", baseCAPIURL, url.PathEscape(resourceType), resourceID)
pageSize := defaultSessionsPerPage
sessions := make([]session, 0, limit+pageSize)
for page := 1; ; page++ {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, http.NoBody)
if err != nil {
return nil, err
}
q := req.URL.Query()
q.Set("page_size", strconv.Itoa(pageSize))
q.Set("page_number", strconv.Itoa(page))
req.URL.RawQuery = q.Encode()
res, err := c.httpClient.Do(req)
if err != nil {
return nil, err
}
defer res.Body.Close()
if res.StatusCode != http.StatusOK {
return nil, fmt.Errorf("failed to list sessions: %s", res.Status)
}
var response struct {
Sessions []session `json:"sessions"`
}
if err := json.NewDecoder(res.Body).Decode(&response); err != nil {
return nil, fmt.Errorf("failed to decode sessions response: %w", err)
}
sessions = append(sessions, response.Sessions...)
if len(response.Sessions) < pageSize || len(sessions) >= limit {
break
}
}
// Drop any above the limit
if len(sessions) > limit {
sessions = sessions[:limit]
}
result, err := c.hydrateSessionPullRequestsAndUsers(sessions)
if err != nil {
return nil, fmt.Errorf("failed to fetch session resources: %w", err)
}
return result, nil
}
// hydrateSessionPullRequestsAndUsers hydrates pull request and user information in sessions
func (c *CAPIClient) hydrateSessionPullRequestsAndUsers(sessions []session) ([]*Session, error) {
if len(sessions) == 0 {
@ -248,9 +306,11 @@ func (c *CAPIClient) hydrateSessionPullRequestsAndUsers(sessions []session) ([]*
prNodeIds := make([]string, 0, len(sessions))
userNodeIds := make([]string, 0, len(sessions))
for _, session := range sessions {
prNodeID := generatePullRequestNodeID(int64(session.RepoID), session.ResourceID)
if !slices.Contains(prNodeIds, prNodeID) {
prNodeIds = append(prNodeIds, prNodeID)
if session.ResourceType == "pull" {
prNodeID := generatePullRequestNodeID(int64(session.RepoID), session.ResourceID)
if !slices.Contains(prNodeIds, prNodeID) {
prNodeIds = append(prNodeIds, prNodeID)
}
}
userNodeId := generateUserNodeID(session.UserID)
@ -318,6 +378,34 @@ func (c *CAPIClient) hydrateSessionPullRequestsAndUsers(sessions []session) ([]*
return newSessions, nil
}
// GetPullRequestDatabaseID retrieves the database ID of a pull request given its number in a repository.
func (c *CAPIClient) GetPullRequestDatabaseID(ctx context.Context, hostname string, owner string, repo string, number int) (int64, error) {
var resp struct {
Repository struct {
PullRequest struct {
FullDatabaseID string `graphql:"fullDatabaseId"`
} `graphql:"pullRequest(number: $number)"`
} `graphql:"repository(owner: $owner, name: $repo)"`
}
variables := map[string]interface{}{
"owner": githubv4.String(owner),
"repo": githubv4.String(repo),
"number": githubv4.Int(number),
}
apiClient := api.NewClientFromHTTP(c.httpClient)
if err := apiClient.Query(hostname, "GetPullRequestFullDatabaseID", &resp, variables); err != nil {
return 0, err
}
databaseID, err := strconv.ParseInt(resp.Repository.PullRequest.FullDatabaseID, 10, 64)
if err != nil {
return 0, err
}
return databaseID, nil
}
// generatePullRequestNodeID converts an int64 databaseID and repoID to a GraphQL Node ID format
// with the "PR_" prefix for pull requests
func generatePullRequestNodeID(repoID, pullRequestID int64) string {

View file

@ -159,6 +159,84 @@ func TestListSessionsForViewer(t *testing.T) {
},
},
},
{
// This happens at the early moments of a session lifecycle, before a PR is created and associated with it.
name: "single session, no pull request resource",
limit: 10,
httpStubs: func(t *testing.T, reg *httpmock.Registry) {
reg.Register(
httpmock.WithHost(
httpmock.QueryMatcher("GET", "agents/sessions", url.Values{
"page_number": {"1"},
"page_size": {"50"},
}),
"api.githubcopilot.com",
),
httpmock.StringResponse(heredoc.Docf(`
{
"sessions": [
{
"id": "sess1",
"name": "Build artifacts",
"user_id": 1,
"agent_id": 2,
"logs": "",
"state": "completed",
"owner_id": 10,
"repo_id": 1000,
"resource_type": "",
"resource_id": 0,
"created_at": "%[1]s"
}
]
}`,
sampleDateString,
)),
)
// GraphQL hydration
reg.Register(
httpmock.GraphQL(`query FetchPRsAndUsersForAgentTaskSessions\b`),
httpmock.GraphQLQuery(heredoc.Docf(`
{
"data": {
"nodes": [
{
"__typename": "User",
"login": "octocat",
"name": "Octocat",
"databaseId": 1
}
]
}
}`,
sampleDateString,
), func(q string, vars map[string]interface{}) {
assert.Equal(t, []interface{}{"U_kgAB"}, vars["ids"])
}),
)
},
wantOut: []*Session{
{
ID: "sess1",
Name: "Build artifacts",
UserID: 1,
AgentID: 2,
Logs: "",
State: "completed",
OwnerID: 10,
RepoID: 1000,
ResourceType: "",
ResourceID: 0,
CreatedAt: sampleDate,
User: &api.GitHubUser{
Login: "octocat",
Name: "Octocat",
DatabaseID: 1,
},
},
},
},
{
name: "multiple sessions, paginated",
perPage: 1, // to enforce pagination
@ -594,6 +672,84 @@ func TestListSessionsForRepo(t *testing.T) {
},
},
},
{
// This happens at the early moments of a session lifecycle, before a PR is created and associated with it.
name: "single session, no pull request resource",
limit: 10,
httpStubs: func(t *testing.T, reg *httpmock.Registry) {
reg.Register(
httpmock.WithHost(
httpmock.QueryMatcher("GET", "agents/sessions/nwo/OWNER/REPO", url.Values{
"page_number": {"1"},
"page_size": {"50"},
}),
"api.githubcopilot.com",
),
httpmock.StringResponse(heredoc.Docf(`
{
"sessions": [
{
"id": "sess1",
"name": "Build artifacts",
"user_id": 1,
"agent_id": 2,
"logs": "",
"state": "completed",
"owner_id": 10,
"repo_id": 1000,
"resource_type": "",
"resource_id": 0,
"created_at": "%[1]s"
}
]
}`,
sampleDateString,
)),
)
// GraphQL hydration
reg.Register(
httpmock.GraphQL(`query FetchPRsAndUsersForAgentTaskSessions\b`),
httpmock.GraphQLQuery(heredoc.Docf(`
{
"data": {
"nodes": [
{
"__typename": "User",
"login": "octocat",
"name": "Octocat",
"databaseId": 1
}
]
}
}`,
sampleDateString,
), func(q string, vars map[string]interface{}) {
assert.Equal(t, []interface{}{"U_kgAB"}, vars["ids"])
}),
)
},
wantOut: []*Session{
{
ID: "sess1",
Name: "Build artifacts",
UserID: 1,
AgentID: 2,
Logs: "",
State: "completed",
OwnerID: 10,
RepoID: 1000,
ResourceType: "",
ResourceID: 0,
CreatedAt: sampleDate,
User: &api.GitHubUser{
Login: "octocat",
Name: "Octocat",
DatabaseID: 1,
},
},
},
},
{
name: "multiple sessions, paginated",
perPage: 1, // to enforce pagination
@ -876,6 +1032,445 @@ func TestListSessionsForRepo(t *testing.T) {
}
}
func TestListSessionsByResourceIDRequiresResource(t *testing.T) {
client := &CAPIClient{}
_, err := client.ListSessionsByResourceID(context.Background(), "", 999, 0)
assert.EqualError(t, err, "missing resource type/ID")
_, err = client.ListSessionsByResourceID(context.Background(), "only-resource-type", 0, 0)
assert.EqualError(t, err, "missing resource type/ID")
_, err = client.ListSessionsByResourceID(context.Background(), "", 0, 0)
assert.EqualError(t, err, "missing resource type/ID")
}
func TestListSessionsByResourceID(t *testing.T) {
sampleDateString := "2025-08-29T00:00:00Z"
sampleDate, err := time.Parse(time.RFC3339, sampleDateString)
require.NoError(t, err)
resourceID := int64(999)
resourceType := "pull"
tests := []struct {
name string
perPage int
limit int
httpStubs func(*testing.T, *httpmock.Registry)
wantErr string
wantOut []*Session
}{
{
name: "zero limit",
limit: 0,
wantOut: nil,
},
{
name: "no sessions",
limit: 10,
httpStubs: func(t *testing.T, reg *httpmock.Registry) {
reg.Register(
httpmock.WithHost(
httpmock.QueryMatcher("GET", "agents/sessions/resource/pull/999", url.Values{
"page_number": {"1"},
"page_size": {"50"},
}),
"api.githubcopilot.com",
),
httpmock.StringResponse(`{"sessions":[]}`),
)
},
wantOut: nil,
},
{
name: "single session",
limit: 10,
httpStubs: func(t *testing.T, reg *httpmock.Registry) {
reg.Register(
httpmock.WithHost(
httpmock.QueryMatcher("GET", "agents/sessions/resource/pull/999", url.Values{
"page_number": {"1"},
"page_size": {"50"},
}),
"api.githubcopilot.com",
),
httpmock.StringResponse(heredoc.Docf(`
{
"sessions": [
{
"id": "sess1",
"name": "Build artifacts",
"user_id": 1,
"agent_id": 2,
"logs": "",
"state": "completed",
"owner_id": 10,
"repo_id": 1000,
"resource_type": "pull",
"resource_id": 2000,
"created_at": "%[1]s"
}
]
}`,
sampleDateString,
)),
)
// GraphQL hydration
reg.Register(
httpmock.GraphQL(`query FetchPRsAndUsersForAgentTaskSessions\b`),
httpmock.GraphQLQuery(heredoc.Docf(`
{
"data": {
"nodes": [
{
"__typename": "PullRequest",
"id": "PR_node",
"fullDatabaseId": "2000",
"number": 42,
"title": "Improve docs",
"state": "OPEN",
"isDraft": true,
"url": "https://github.com/OWNER/REPO/pull/42",
"body": "",
"createdAt": "%[1]s",
"updatedAt": "%[1]s",
"repository": {
"nameWithOwner": "OWNER/REPO"
}
},
{
"__typename": "User",
"login": "octocat",
"name": "Octocat",
"databaseId": 1
}
]
}
}`,
sampleDateString,
), func(q string, vars map[string]interface{}) {
assert.Equal(t, []interface{}{"PR_kwDNA-jNB9A", "U_kgAB"}, vars["ids"])
}),
)
},
wantOut: []*Session{
{
ID: "sess1",
Name: "Build artifacts",
UserID: 1,
AgentID: 2,
Logs: "",
State: "completed",
OwnerID: 10,
RepoID: 1000,
ResourceType: "pull",
ResourceID: 2000,
CreatedAt: sampleDate,
PullRequest: &api.PullRequest{
ID: "PR_node",
FullDatabaseID: "2000",
Number: 42,
Title: "Improve docs",
State: "OPEN",
IsDraft: true,
URL: "https://github.com/OWNER/REPO/pull/42",
Body: "",
CreatedAt: sampleDate,
UpdatedAt: sampleDate,
Repository: &api.PRRepository{
NameWithOwner: "OWNER/REPO",
},
},
User: &api.GitHubUser{
Login: "octocat",
Name: "Octocat",
DatabaseID: 1,
},
},
},
},
{
name: "multiple sessions, paginated",
perPage: 1, // to enforce pagination
limit: 2,
httpStubs: func(t *testing.T, reg *httpmock.Registry) {
reg.Register(
httpmock.WithHost(
httpmock.QueryMatcher("GET", "agents/sessions/resource/pull/999", url.Values{
"page_number": {"1"},
"page_size": {"1"},
}),
"api.githubcopilot.com",
),
httpmock.StringResponse(heredoc.Docf(`
{
"sessions": [
{
"id": "sess1",
"name": "Build artifacts",
"user_id": 1,
"agent_id": 2,
"logs": "",
"state": "completed",
"owner_id": 10,
"repo_id": 1000,
"resource_type": "pull",
"resource_id": 2000,
"created_at": "%[1]s"
}
]
}`,
sampleDateString,
)),
)
// Second page
reg.Register(
httpmock.WithHost(
httpmock.QueryMatcher("GET", "agents/sessions/resource/pull/999", url.Values{
"page_number": {"2"},
"page_size": {"1"},
}),
"api.githubcopilot.com",
),
httpmock.StringResponse(heredoc.Docf(`
{
"sessions": [
{
"id": "sess2",
"name": "Build artifacts",
"user_id": 1,
"agent_id": 2,
"logs": "",
"state": "completed",
"owner_id": 10,
"repo_id": 1000,
"resource_type": "pull",
"resource_id": 2001,
"created_at": "%[1]s"
}
]
}`,
sampleDateString,
)),
)
// GraphQL hydration
reg.Register(
httpmock.GraphQL(`query FetchPRsAndUsersForAgentTaskSessions\b`),
httpmock.GraphQLQuery(heredoc.Docf(`
{
"data": {
"nodes": [
{
"__typename": "PullRequest",
"id": "PR_node",
"fullDatabaseId": "2000",
"number": 42,
"title": "Improve docs",
"state": "OPEN",
"isDraft": true,
"url": "https://github.com/OWNER/REPO/pull/42",
"body": "",
"createdAt": "%[1]s",
"updatedAt": "%[1]s",
"repository": {
"nameWithOwner": "OWNER/REPO"
}
},
{
"__typename": "PullRequest",
"id": "PR_node",
"fullDatabaseId": "2001",
"number": 43,
"title": "Improve docs",
"state": "OPEN",
"isDraft": true,
"url": "https://github.com/OWNER/REPO/pull/43",
"body": "",
"createdAt": "%[1]s",
"updatedAt": "%[1]s",
"repository": {
"nameWithOwner": "OWNER/REPO"
}
},
{
"__typename": "User",
"login": "octocat",
"name": "Octocat",
"databaseId": 1
}
]
}
}`,
sampleDateString,
), func(q string, vars map[string]interface{}) {
assert.Equal(t, []interface{}{"PR_kwDNA-jNB9A", "PR_kwDNA-jNB9E", "U_kgAB"}, vars["ids"])
}),
)
},
wantOut: []*Session{
{
ID: "sess1",
Name: "Build artifacts",
UserID: 1,
AgentID: 2,
Logs: "",
State: "completed",
OwnerID: 10,
RepoID: 1000,
ResourceType: "pull",
ResourceID: 2000,
CreatedAt: sampleDate,
PullRequest: &api.PullRequest{
ID: "PR_node",
FullDatabaseID: "2000",
Number: 42,
Title: "Improve docs",
State: "OPEN",
IsDraft: true,
URL: "https://github.com/OWNER/REPO/pull/42",
Body: "",
CreatedAt: sampleDate,
UpdatedAt: sampleDate,
Repository: &api.PRRepository{
NameWithOwner: "OWNER/REPO",
},
},
User: &api.GitHubUser{
Login: "octocat",
Name: "Octocat",
DatabaseID: 1,
},
},
{
ID: "sess2",
Name: "Build artifacts",
UserID: 1,
AgentID: 2,
Logs: "",
State: "completed",
OwnerID: 10,
RepoID: 1000,
ResourceType: "pull",
ResourceID: 2001,
CreatedAt: sampleDate,
PullRequest: &api.PullRequest{
ID: "PR_node",
FullDatabaseID: "2001",
Number: 43,
Title: "Improve docs",
State: "OPEN",
IsDraft: true,
URL: "https://github.com/OWNER/REPO/pull/43",
Body: "",
CreatedAt: sampleDate,
UpdatedAt: sampleDate,
Repository: &api.PRRepository{
NameWithOwner: "OWNER/REPO",
},
},
User: &api.GitHubUser{
Login: "octocat",
Name: "Octocat",
DatabaseID: 1,
},
},
},
},
{
name: "API error",
limit: 10,
httpStubs: func(t *testing.T, reg *httpmock.Registry) {
reg.Register(
httpmock.WithHost(
httpmock.QueryMatcher("GET", "agents/sessions/resource/pull/999", url.Values{
"page_number": {"1"},
"page_size": {"50"},
}),
"api.githubcopilot.com",
),
httpmock.StatusStringResponse(500, "{}"),
)
},
wantErr: "failed to list sessions:",
}, {
name: "API error at hydration",
limit: 10,
httpStubs: func(t *testing.T, reg *httpmock.Registry) {
reg.Register(
httpmock.WithHost(
httpmock.QueryMatcher("GET", "agents/sessions/resource/pull/999", url.Values{
"page_number": {"1"},
"page_size": {"50"},
}),
"api.githubcopilot.com",
),
httpmock.StringResponse(heredoc.Docf(`
{
"sessions": [
{
"id": "sess1",
"name": "Build artifacts",
"user_id": 1,
"agent_id": 2,
"logs": "",
"state": "completed",
"owner_id": 10,
"repo_id": 1000,
"resource_type": "pull",
"resource_id": 2000,
"created_at": "%[1]s"
}
]
}`,
sampleDateString,
)),
)
// GraphQL hydration
reg.Register(
httpmock.GraphQL(`query FetchPRsAndUsersForAgentTaskSessions\b`),
httpmock.StatusStringResponse(500, `{}`),
)
},
wantErr: `failed to fetch session resources: non-200 OK status code:`,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
reg := &httpmock.Registry{}
if tt.httpStubs != nil {
tt.httpStubs(t, reg)
}
defer reg.Verify(t)
httpClient := &http.Client{Transport: reg}
cfg := config.NewBlankConfig()
capiClient := NewCAPIClient(httpClient, cfg.Authentication())
if tt.perPage != 0 {
last := defaultSessionsPerPage
defaultSessionsPerPage = tt.perPage
defer func() {
defaultSessionsPerPage = last
}()
}
sessions, err := capiClient.ListSessionsByResourceID(context.Background(), resourceType, resourceID, tt.limit)
if tt.wantErr != "" {
require.ErrorContains(t, err, tt.wantErr)
require.Nil(t, sessions)
return
}
require.NoError(t, err)
require.Equal(t, tt.wantOut, sessions)
})
}
}
func TestGetSessionRequiresID(t *testing.T) {
client := &CAPIClient{}
@ -1020,6 +1615,70 @@ func TestGetSession(t *testing.T) {
},
},
},
{
// This happens at the early moments of a session lifecycle, before a PR is created and associated with it.
name: "success, but no pull request resource",
httpStubs: func(t *testing.T, reg *httpmock.Registry) {
reg.Register(
httpmock.WithHost(httpmock.REST("GET", "agents/sessions/some-uuid"), "api.githubcopilot.com"),
httpmock.StringResponse(heredoc.Docf(`
{
"id": "some-uuid",
"name": "Build artifacts",
"user_id": 1,
"agent_id": 2,
"logs": "",
"state": "completed",
"owner_id": 10,
"repo_id": 1000,
"resource_type": "",
"resource_id": 0,
"created_at": "%[1]s"
}`,
sampleDateString,
)),
)
// GraphQL hydration
reg.Register(
httpmock.GraphQL(`query FetchPRsAndUsersForAgentTaskSessions\b`),
httpmock.GraphQLQuery(heredoc.Docf(`
{
"data": {
"nodes": [
{
"__typename": "User",
"login": "octocat",
"name": "Octocat",
"databaseId": 1
}
]
}
}`,
sampleDateString,
), func(q string, vars map[string]interface{}) {
assert.Equal(t, []interface{}{"U_kgAB"}, vars["ids"])
}),
)
},
wantOut: &Session{
ID: "some-uuid",
Name: "Build artifacts",
UserID: 1,
AgentID: 2,
Logs: "",
State: "completed",
OwnerID: 10,
RepoID: 1000,
ResourceType: "",
ResourceID: 0,
CreatedAt: sampleDate,
User: &api.GitHubUser{
Login: "octocat",
Name: "Octocat",
DatabaseID: 1,
},
},
},
{
name: "API error at hydration",
httpStubs: func(t *testing.T, reg *httpmock.Registry) {
@ -1082,3 +1741,73 @@ func TestGetSession(t *testing.T) {
})
}
}
func TestGetPullRequestDatabaseID(t *testing.T) {
tests := []struct {
name string
httpStubs func(*testing.T, *httpmock.Registry)
wantErr string
wantOut int64
}{
{
name: "graphql error",
httpStubs: func(t *testing.T, reg *httpmock.Registry) {
reg.Register(
httpmock.WithHost(httpmock.GraphQL(`query GetPullRequestFullDatabaseID\b`), "api.github.com"),
httpmock.StringResponse(`{"data":{}, "errors": [{"message": "some gql error"}]}`),
)
},
wantErr: "some gql error",
},
{
// This never happens in practice and it's just to cover more code path
name: "non-int database ID",
httpStubs: func(t *testing.T, reg *httpmock.Registry) {
reg.Register(
httpmock.WithHost(httpmock.GraphQL(`query GetPullRequestFullDatabaseID\b`), "api.github.com"),
httpmock.StringResponse(`{"data": {"repository": {"pullRequest": {"fullDatabaseId": "non-int"}}}}`),
)
},
wantErr: `strconv.ParseInt: parsing "non-int": invalid syntax`,
},
{
name: "success",
httpStubs: func(t *testing.T, reg *httpmock.Registry) {
reg.Register(
httpmock.WithHost(httpmock.GraphQL(`query GetPullRequestFullDatabaseID\b`), "api.github.com"),
httpmock.GraphQLQuery(`{"data": {"repository": {"pullRequest": {"fullDatabaseId": "999"}}}}`, func(s string, m map[string]interface{}) {
assert.Equal(t, "OWNER", m["owner"])
assert.Equal(t, "REPO", m["repo"])
assert.Equal(t, float64(42), m["number"])
}),
)
},
wantOut: 999,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
reg := &httpmock.Registry{}
if tt.httpStubs != nil {
tt.httpStubs(t, reg)
}
defer reg.Verify(t)
httpClient := &http.Client{Transport: reg}
cfg := config.NewBlankConfig()
capiClient := NewCAPIClient(httpClient, cfg.Authentication())
databaseID, err := capiClient.GetPullRequestDatabaseID(context.Background(), "github.com", "OWNER", "REPO", 42)
if tt.wantErr != "" {
require.ErrorContains(t, err, tt.wantErr)
require.Zero(t, databaseID)
return
}
require.NoError(t, err)
require.Equal(t, tt.wantOut, databaseID)
})
}
}

View file

@ -1,10 +1,14 @@
package shared
import (
"regexp"
"github.com/cli/cli/v2/pkg/cmd/agent-task/capi"
"github.com/cli/cli/v2/pkg/cmdutil"
)
var uuidRE = regexp.MustCompile(`^[a-fA-F0-9]{8}-[a-fA-F0-9]{4}-[a-fA-F0-9]{4}-[a-fA-F0-9]{4}-[a-fA-F0-9]{12}$`)
func CapiClientFunc(f *cmdutil.Factory) func() (capi.CapiClient, error) {
return func() (capi.CapiClient, error) {
cfg, err := f.Config()
@ -21,3 +25,7 @@ func CapiClientFunc(f *cmdutil.Factory) func() (capi.CapiClient, error) {
return capi.NewCAPIClient(httpClient, authCfg), nil
}
}
func IsSessionID(s string) bool {
return uuidRE.MatchString(s)
}

View file

@ -0,0 +1,20 @@
package shared
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestIsSession(t *testing.T) {
assert.True(t, IsSessionID("00000000-0000-0000-0000-000000000000"))
assert.True(t, IsSessionID("e2fa49d2-f164-4a56-ab99-498090b8fcdf"))
assert.True(t, IsSessionID("E2FA49D2-F164-4A56-AB99-498090B8FCDF"))
assert.False(t, IsSessionID(""))
assert.False(t, IsSessionID(" "))
assert.False(t, IsSessionID("\n"))
assert.False(t, IsSessionID("not-a-uuid"))
assert.False(t, IsSessionID("000000000000000000000000000000000000"))
assert.False(t, IsSessionID("00000000-0000-0000-0000-000000000000-extra"))
}

View file

@ -24,6 +24,7 @@ func ColorFuncForSessionState(s capi.Session, cs *iostreams.ColorScheme) func(st
return stateColor
}
// SessionStateString returns the humane/capitalised form of the given session state.
func SessionStateString(state string) string {
switch state {
case "queued":
@ -46,3 +47,17 @@ func SessionStateString(state string) string {
return state
}
}
type ColorFunc func(string) string
func SessionSymbol(cs *iostreams.ColorScheme, state string) string {
noColor := func(s string) string { return s }
switch state {
case "completed":
return cs.SuccessIconWithColor(noColor)
case "failed", "timed_out", "cancelled":
return cs.FailureIconWithColor(noColor)
default:
return "-"
}
}

View file

@ -4,10 +4,15 @@ import (
"context"
"errors"
"fmt"
"net/http"
"net/url"
"strconv"
"time"
"github.com/MakeNowJust/heredoc"
"github.com/cli/cli/v2/internal/ghinstance"
"github.com/cli/cli/v2/internal/ghrepo"
"github.com/cli/cli/v2/internal/prompter"
"github.com/cli/cli/v2/internal/text"
"github.com/cli/cli/v2/pkg/cmd/agent-task/capi"
"github.com/cli/cli/v2/pkg/cmd/agent-task/shared"
@ -17,28 +22,54 @@ import (
"github.com/spf13/cobra"
)
const defaultLimit = 40
type ViewOptions struct {
IO *iostreams.IOStreams
BaseRepo func() (ghrepo.Interface, error)
CapiClient func() (capi.CapiClient, error)
HttpClient func() (*http.Client, error)
Finder prShared.PRFinder
Prompter prompter.Prompter
SelectorArg string
PRNumber int
SessionID string
}
func NewCmdView(f *cmdutil.Factory, runF func(*ViewOptions) error) *cobra.Command {
opts := &ViewOptions{
IO: f.IOStreams,
HttpClient: f.HttpClient,
CapiClient: shared.CapiClientFunc(f),
Prompter: f.Prompter,
}
cmd := &cobra.Command{
Use: "view <session-id>",
Use: "view [<session-id> | <pr-number> | <pr-url> | <pr-branch>]",
Short: "View an agent task session",
Long: heredoc.Doc(`
View an agent task session.
`),
Args: cmdutil.ExactArgs(1, "a session ID is required"),
Args: cobra.MaximumNArgs(1),
RunE: func(cmd *cobra.Command, args []string) error {
opts.SelectorArg = args[0]
// Support -R/--repo override
opts.BaseRepo = f.BaseRepo
if len(args) > 0 {
opts.SelectorArg = args[0]
if shared.IsSessionID(opts.SelectorArg) {
opts.SessionID = opts.SelectorArg
}
}
if opts.SessionID == "" && !opts.IO.CanPrompt() {
return fmt.Errorf("session ID is required when not running interactively")
}
if opts.Finder == nil {
opts.Finder = prShared.NewFinder(f)
}
if runF != nil {
return runF(opts)
@ -47,6 +78,8 @@ func NewCmdView(f *cmdutil.Factory, runF func(*ViewOptions) error) *cobra.Comman
},
}
cmdutil.EnableRepoOverride(cmd, f)
return cmd
}
@ -57,23 +90,113 @@ func viewRun(opts *ViewOptions) error {
}
ctx := context.Background()
cs := opts.IO.ColorScheme()
opts.IO.StartProgressIndicatorWithLabel("Fetching agent session...")
defer opts.IO.StopProgressIndicator()
session, err := capiClient.GetSession(ctx, opts.SelectorArg)
opts.IO.StopProgressIndicator()
var session *capi.Session
if err != nil {
if errors.Is(err, capi.ErrSessionNotFound) {
fmt.Fprintln(opts.IO.ErrOut, "session not found")
if opts.SessionID != "" {
if sess, err := capiClient.GetSession(ctx, opts.SessionID); err != nil {
if errors.Is(err, capi.ErrSessionNotFound) {
fmt.Fprintln(opts.IO.ErrOut, "session not found")
return cmdutil.SilentError
}
return err
} else {
session = sess
}
} else {
var resourceID int64
if opts.SelectorArg != "" {
// Finder does not support the PR/issue reference format (e.g. owner/repo#123)
// so we need to check if the selector arg is a reference and fetch the PR
// directly.
if repo, num, err := prShared.ParseFullReference(opts.SelectorArg); err == nil {
// Since the selector was a reference (i.e. without hostname data), we need to
// check the base repo to get the hostname.
baseRepo, err := opts.BaseRepo()
if err != nil {
return err
}
hostname := baseRepo.RepoHost()
if hostname != ghinstance.Default() {
return fmt.Errorf("agent tasks are not supported on this host: %s", hostname)
}
resourceID, err = capiClient.GetPullRequestDatabaseID(ctx, hostname, repo.RepoOwner(), repo.RepoName(), num)
if err != nil {
return fmt.Errorf("failed to fetch pull request: %w", err)
}
}
}
if resourceID == 0 {
findOptions := prShared.FindOptions{
Selector: opts.SelectorArg,
Fields: []string{"id", "url", "fullDatabaseId"},
}
pr, repo, err := opts.Finder.Find(findOptions)
if err != nil {
return err
}
if repo.RepoHost() != ghinstance.Default() {
return fmt.Errorf("agent tasks are not supported on this host: %s", repo.RepoHost())
}
databaseID, err := strconv.ParseInt(pr.FullDatabaseID, 10, 64)
if err != nil {
return fmt.Errorf("failed to parse pull request: %w", err)
}
resourceID = databaseID
}
// TODO(babakks): currently we just fetch a pre-defined number of
// matching sessions to avoid hitting the API too many times, but it's
// technically possible for a PR to be associated with lots of sessions
// (i.e. above our selected limit).
sessions, err := capiClient.ListSessionsByResourceID(ctx, "pull", resourceID, defaultLimit)
if err != nil {
return fmt.Errorf("failed to list sessions for pull request: %w", err)
}
if len(sessions) == 0 {
fmt.Fprintln(opts.IO.ErrOut, "no session found for pull request")
return cmdutil.SilentError
}
return err
session = sessions[0]
if len(sessions) > 1 {
now := time.Now()
options := make([]string, 0, len(sessions))
for _, session := range sessions {
options = append(options, fmt.Sprintf(
"%s %s • %s",
shared.SessionSymbol(cs, session.State),
session.Name,
text.FuzzyAgo(now, session.CreatedAt),
))
}
opts.IO.StopProgressIndicator()
selected, err := opts.Prompter.Select("Select a session", options[0], options)
if err != nil {
return err
}
session = sessions[selected]
}
}
opts.IO.StopProgressIndicator()
out := opts.IO.Out
cs := opts.IO.ColorScheme()
if session.PullRequest != nil {
fmt.Fprintf(out, "%s • %s • %s%s\n",
@ -83,7 +206,7 @@ func viewRun(opts *ViewOptions) error {
cs.ColorFromString(prShared.ColorForPRState(*session.PullRequest))(fmt.Sprintf("#%d", session.PullRequest.Number)),
)
} else {
// Should never happen, but we need to cover the path
// This can happen when the session is just created and a PR is not yet available for it
fmt.Fprintf(out, "%s\n", shared.ColorFuncForSessionState(*session, cs)(shared.SessionStateString(session.State)))
}

View file

@ -10,7 +10,10 @@ import (
"github.com/MakeNowJust/heredoc"
"github.com/cli/cli/v2/api"
"github.com/cli/cli/v2/internal/ghrepo"
"github.com/cli/cli/v2/internal/prompter"
"github.com/cli/cli/v2/pkg/cmd/agent-task/capi"
prShared "github.com/cli/cli/v2/pkg/cmd/pr/shared"
"github.com/cli/cli/v2/pkg/cmdutil"
"github.com/cli/cli/v2/pkg/iostreams"
"github.com/google/shlex"
@ -20,20 +23,49 @@ import (
func TestNewCmdList(t *testing.T) {
tests := []struct {
name string
args string
wantOpts ViewOptions
wantErr string
name string
tty bool
args string
wantOpts ViewOptions
wantBaseRepo ghrepo.Interface
wantErr string
}{
{
name: "no arguments",
wantErr: "a session ID is required",
name: "no arg tty",
tty: true,
args: "",
wantOpts: ViewOptions{},
},
{
name: "session ID arg",
args: "some-uuid",
name: "session ID arg tty",
tty: true,
args: "00000000-0000-0000-0000-000000000000",
wantOpts: ViewOptions{
SelectorArg: "some-uuid",
SelectorArg: "00000000-0000-0000-0000-000000000000",
SessionID: "00000000-0000-0000-0000-000000000000",
},
},
{
name: "non-session ID arg tty",
tty: true,
args: "some-arg",
wantOpts: ViewOptions{
SelectorArg: "some-arg",
},
},
{
name: "session ID required if non-tty",
tty: false,
args: "some-arg",
wantErr: "session ID is required when not running interactively",
},
{
name: "repo override",
tty: true,
args: "some-arg -R OWNER/REPO",
wantBaseRepo: ghrepo.New("OWNER", "REPO"),
wantOpts: ViewOptions{
SelectorArg: "some-arg",
},
},
}
@ -41,6 +73,10 @@ func TestNewCmdList(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ios, _, _, _ := iostreams.Test()
ios.SetStdinTTY(tt.tty)
ios.SetStdoutTTY(tt.tty)
ios.SetStderrTTY(tt.tty)
f := &cmdutil.Factory{
IOStreams: ios,
}
@ -65,6 +101,13 @@ func TestNewCmdList(t *testing.T) {
require.NoError(t, err)
assert.Equal(t, tt.wantOpts.SelectorArg, gotOpts.SelectorArg)
assert.Equal(t, tt.wantOpts.SessionID, gotOpts.SessionID)
if tt.wantBaseRepo != nil {
baseRepo, err := gotOpts.BaseRepo()
require.NoError(t, err)
assert.True(t, ghrepo.IsSame(tt.wantBaseRepo, baseRepo))
}
})
}
}
@ -74,19 +117,23 @@ func Test_viewRun(t *testing.T) {
tests := []struct {
name string
selectorArg string
tty bool
opts ViewOptions
promptStubs func(*testing.T, *prompter.MockPrompter)
capiStubs func(*testing.T, *capi.CapiClientMock)
wantOut string
wantErr error
wantStderr string
}{
{
name: "not found (tty)",
tty: true,
selectorArg: "some-session-id",
name: "with session id, not found (tty)",
tty: true,
opts: ViewOptions{
SelectorArg: "some-session-id",
SessionID: "some-session-id",
},
capiStubs: func(t *testing.T, m *capi.CapiClientMock) {
m.GetSessionFunc = func(ctx context.Context, selector string) (*capi.Session, error) {
m.GetSessionFunc = func(_ context.Context, _ string) (*capi.Session, error) {
return nil, capi.ErrSessionNotFound
}
},
@ -94,43 +141,29 @@ func Test_viewRun(t *testing.T) {
wantErr: cmdutil.SilentError,
},
{
name: "not found (nontty)",
selectorArg: "some-session-id",
capiStubs: func(t *testing.T, m *capi.CapiClientMock) {
m.GetSessionFunc = func(ctx context.Context, selector string) (*capi.Session, error) {
return nil, capi.ErrSessionNotFound
}
name: "with session id, api error (tty)",
tty: true,
opts: ViewOptions{
SelectorArg: "some-session-id",
SessionID: "some-session-id",
},
wantStderr: "session not found\n",
wantErr: cmdutil.SilentError,
},
{
name: "API error (tty)",
tty: true,
selectorArg: "some-session-id",
capiStubs: func(t *testing.T, m *capi.CapiClientMock) {
m.GetSessionFunc = func(ctx context.Context, selector string) (*capi.Session, error) {
m.GetSessionFunc = func(_ context.Context, _ string) (*capi.Session, error) {
return nil, errors.New("some error")
}
},
wantErr: errors.New("some error"),
},
{
name: "API error (nontty)",
selectorArg: "some-session-id",
capiStubs: func(t *testing.T, m *capi.CapiClientMock) {
m.GetSessionFunc = func(ctx context.Context, selector string) (*capi.Session, error) {
return nil, errors.New("some error")
}
name: "with session id, success, with pr and user data (tty)",
tty: true,
opts: ViewOptions{
SelectorArg: "some-session-id",
SessionID: "some-session-id",
},
wantErr: errors.New("some error"),
},
{
name: "success, with PR and user data (tty)",
tty: true,
selectorArg: "some-session-id",
capiStubs: func(t *testing.T, m *capi.CapiClientMock) {
m.GetSessionFunc = func(ctx context.Context, selector string) (*capi.Session, error) {
m.GetSessionFunc = func(_ context.Context, id string) (*capi.Session, error) {
assert.Equal(t, "some-session-id", id)
return &capi.Session{
ID: "some-session-id",
State: "completed",
@ -152,17 +185,22 @@ func Test_viewRun(t *testing.T) {
wantOut: heredoc.Doc(`
Completed fix something OWNER/REPO#101
Started on behalf of octocat about 6 hours ago
View this session on GitHub:
https://github.com/OWNER/REPO/pull/101/agent-sessions/some-session-id
`),
},
{
name: "success, without user data (tty)",
tty: true,
selectorArg: "some-session-id",
// The user data should always be there, but we need to cover the code path.
name: "with session id, success, without user data (tty)",
tty: true,
opts: ViewOptions{
SelectorArg: "some-session-id",
SessionID: "some-session-id",
},
capiStubs: func(t *testing.T, m *capi.CapiClientMock) {
m.GetSessionFunc = func(ctx context.Context, selector string) (*capi.Session, error) {
m.GetSessionFunc = func(_ context.Context, id string) (*capi.Session, error) {
assert.Equal(t, "some-session-id", id)
return &capi.Session{
ID: "some-session-id",
State: "completed",
@ -181,17 +219,22 @@ func Test_viewRun(t *testing.T) {
wantOut: heredoc.Doc(`
Completed fix something OWNER/REPO#101
Started about 6 hours ago
View this session on GitHub:
https://github.com/OWNER/REPO/pull/101/agent-sessions/some-session-id
`),
},
{
name: "success, without PR data (tty)",
tty: true,
selectorArg: "some-session-id",
// This can happen when the session is just created and a PR is not yet available for it.
name: "with session id, success, without pr data (tty)",
tty: true,
opts: ViewOptions{
SelectorArg: "some-session-id",
SessionID: "some-session-id",
},
capiStubs: func(t *testing.T, m *capi.CapiClientMock) {
m.GetSessionFunc = func(ctx context.Context, selector string) (*capi.Session, error) {
m.GetSessionFunc = func(_ context.Context, id string) (*capi.Session, error) {
assert.Equal(t, "some-session-id", id)
return &capi.Session{
ID: "some-session-id",
State: "completed",
@ -208,11 +251,16 @@ func Test_viewRun(t *testing.T) {
`),
},
{
name: "success, without PR nor user data (tty)",
tty: true,
selectorArg: "some-session-id",
// The user data should always be there, but we need to cover the code path.
name: "with session id, success, without pr nor user data (tty)",
tty: true,
opts: ViewOptions{
SelectorArg: "some-session-id",
SessionID: "some-session-id",
},
capiStubs: func(t *testing.T, m *capi.CapiClientMock) {
m.GetSessionFunc = func(ctx context.Context, selector string) (*capi.Session, error) {
m.GetSessionFunc = func(_ context.Context, id string) (*capi.Session, error) {
assert.Equal(t, "some-session-id", id)
return &capi.Session{
ID: "some-session-id",
State: "completed",
@ -225,6 +273,267 @@ func Test_viewRun(t *testing.T) {
Started about 6 hours ago
`),
},
{
name: "with pr number, api error (tty)",
tty: true,
opts: ViewOptions{
SelectorArg: "pr-number",
Finder: prShared.NewMockFinder(
"pr-number",
&api.PullRequest{FullDatabaseID: "999999"},
ghrepo.New("OWNER", "REPO"),
),
},
capiStubs: func(t *testing.T, m *capi.CapiClientMock) {
m.ListSessionsByResourceIDFunc = func(_ context.Context, _ string, _ int64, _ int) ([]*capi.Session, error) {
return nil, errors.New("some error")
}
},
wantErr: errors.New("failed to list sessions for pull request: some error"),
},
{
name: "with pr reference, unsupported hostname (tty)",
tty: true,
opts: ViewOptions{
SelectorArg: "OWNER/REPO#999",
BaseRepo: func() (ghrepo.Interface, error) {
return ghrepo.NewWithHost("OWNER", "REPO", "foo.com"), nil
},
},
wantErr: errors.New("agent tasks are not supported on this host: foo.com"),
},
{
name: "with pr reference, api error when fetching pr database ID (tty)",
tty: true,
opts: ViewOptions{
SelectorArg: "OWNER/REPO#999",
BaseRepo: func() (ghrepo.Interface, error) {
return ghrepo.New("OWNER", "REPO"), nil
},
},
capiStubs: func(t *testing.T, m *capi.CapiClientMock) {
m.GetPullRequestDatabaseIDFunc = func(_ context.Context, _ string, _ string, _ string, _ int) (int64, error) {
return 0, errors.New("some error")
}
},
wantErr: errors.New("failed to fetch pull request: some error"),
},
{
name: "with pr reference, api error when fetching session (tty)",
tty: true,
opts: ViewOptions{
SelectorArg: "OWNER/REPO#999",
BaseRepo: func() (ghrepo.Interface, error) {
return ghrepo.New("OWNER", "REPO"), nil
},
},
capiStubs: func(t *testing.T, m *capi.CapiClientMock) {
m.GetPullRequestDatabaseIDFunc = func(_ context.Context, _ string, _ string, _ string, _ int) (int64, error) {
return 999999, nil
}
m.ListSessionsByResourceIDFunc = func(_ context.Context, _ string, _ int64, _ int) ([]*capi.Session, error) {
return nil, errors.New("some error")
}
},
wantErr: errors.New("failed to list sessions for pull request: some error"),
},
{
name: "with pr number, success, single session with pr and user data (tty)",
tty: true,
opts: ViewOptions{
SelectorArg: "pr-number",
Finder: prShared.NewMockFinder(
"pr-number",
&api.PullRequest{FullDatabaseID: "999999"},
ghrepo.New("OWNER", "REPO"),
),
},
capiStubs: func(t *testing.T, m *capi.CapiClientMock) {
m.ListSessionsByResourceIDFunc = func(_ context.Context, resourceType string, resourceID int64, limit int) ([]*capi.Session, error) {
assert.Equal(t, "pull", resourceType)
assert.Equal(t, int64(999999), resourceID)
assert.Equal(t, defaultLimit, limit)
return []*capi.Session{
{
ID: "some-session-id",
State: "completed",
CreatedAt: sampleDate,
PullRequest: &api.PullRequest{
Title: "fix something",
Number: 101,
URL: "https://github.com/OWNER/REPO/pull/101",
Repository: &api.PRRepository{
NameWithOwner: "OWNER/REPO",
},
},
User: &api.GitHubUser{
Login: "octocat",
},
},
}, nil
}
},
wantOut: heredoc.Doc(`
Completed fix something OWNER/REPO#101
Started on behalf of octocat about 6 hours ago
View this session on GitHub:
https://github.com/OWNER/REPO/pull/101/agent-sessions/some-session-id
`),
},
{
name: "with pr number, success, multiple sessions with pr and user data (tty)",
tty: true,
opts: ViewOptions{
SelectorArg: "pr-number",
Finder: prShared.NewMockFinder(
"pr-number",
&api.PullRequest{FullDatabaseID: "999999"},
ghrepo.New("OWNER", "REPO"),
),
},
capiStubs: func(t *testing.T, m *capi.CapiClientMock) {
m.ListSessionsByResourceIDFunc = func(_ context.Context, resourceType string, resourceID int64, limit int) ([]*capi.Session, error) {
assert.Equal(t, "pull", resourceType)
assert.Equal(t, int64(999999), resourceID)
assert.Equal(t, defaultLimit, limit)
return []*capi.Session{
{
ID: "some-session-id",
Name: "session one",
State: "completed",
CreatedAt: sampleDate,
PullRequest: &api.PullRequest{
Title: "fix something",
Number: 101,
URL: "https://github.com/OWNER/REPO/pull/101",
Repository: &api.PRRepository{
NameWithOwner: "OWNER/REPO",
},
},
User: &api.GitHubUser{
Login: "octocat",
},
},
{
ID: "some-other-session-id",
Name: "session two",
State: "completed",
CreatedAt: sampleDate,
PullRequest: &api.PullRequest{
Title: "fix something else",
Number: 102,
URL: "https://github.com/OWNER/REPO/pull/102",
Repository: &api.PRRepository{
NameWithOwner: "OWNER/REPO",
},
},
User: &api.GitHubUser{
Login: "octocat",
},
},
}, nil
}
},
promptStubs: func(t *testing.T, pm *prompter.MockPrompter) {
pm.RegisterSelect(
"Select a session",
[]string{
"✓ session one • about 6 hours ago",
"✓ session two • about 6 hours ago",
},
func(_, _ string, opts []string) (int, error) {
return prompter.IndexFor(opts, "✓ session one • about 6 hours ago")
},
)
},
wantOut: heredoc.Doc(`
Completed fix something OWNER/REPO#101
Started on behalf of octocat about 6 hours ago
View this session on GitHub:
https://github.com/OWNER/REPO/pull/101/agent-sessions/some-session-id
`),
},
{
name: "with pr reference, success, multiple sessions with pr and user data (tty)",
tty: true,
opts: ViewOptions{
SelectorArg: "OWNER/REPO#999",
BaseRepo: func() (ghrepo.Interface, error) {
return ghrepo.New("OWNER", "REPO"), nil
},
},
capiStubs: func(t *testing.T, m *capi.CapiClientMock) {
m.GetPullRequestDatabaseIDFunc = func(_ context.Context, hostname string, owner string, repo string, number int) (int64, error) {
assert.Equal(t, "github.com", hostname)
assert.Equal(t, "OWNER", owner)
assert.Equal(t, "REPO", repo)
assert.Equal(t, 999, number)
return 999999, nil
}
m.ListSessionsByResourceIDFunc = func(_ context.Context, resourceType string, resourceID int64, limit int) ([]*capi.Session, error) {
assert.Equal(t, "pull", resourceType)
assert.Equal(t, int64(999999), resourceID)
assert.Equal(t, defaultLimit, limit)
return []*capi.Session{
{
ID: "some-session-id",
Name: "session one",
State: "completed",
CreatedAt: sampleDate,
PullRequest: &api.PullRequest{
Title: "fix something",
Number: 101,
URL: "https://github.com/OWNER/REPO/pull/101",
Repository: &api.PRRepository{
NameWithOwner: "OWNER/REPO",
},
},
User: &api.GitHubUser{
Login: "octocat",
},
},
{
ID: "some-other-session-id",
Name: "session two",
State: "completed",
CreatedAt: sampleDate,
PullRequest: &api.PullRequest{
Title: "fix something else",
Number: 102,
URL: "https://github.com/OWNER/REPO/pull/102",
Repository: &api.PRRepository{
NameWithOwner: "OWNER/REPO",
},
},
User: &api.GitHubUser{
Login: "octocat",
},
},
}, nil
}
},
promptStubs: func(t *testing.T, pm *prompter.MockPrompter) {
pm.RegisterSelect(
"Select a session",
[]string{
"✓ session one • about 6 hours ago",
"✓ session two • about 6 hours ago",
},
func(_, _ string, opts []string) (int, error) {
return prompter.IndexFor(opts, "✓ session one • about 6 hours ago")
},
)
},
wantOut: heredoc.Doc(`
Completed fix something OWNER/REPO#101
Started on behalf of octocat about 6 hours ago
View this session on GitHub:
https://github.com/OWNER/REPO/pull/101/agent-sessions/some-session-id
`),
},
}
for _, tt := range tests {
@ -234,18 +543,22 @@ func Test_viewRun(t *testing.T) {
tt.capiStubs(t, capiClientMock)
}
prompter := prompter.NewMockPrompter(t)
if tt.promptStubs != nil {
tt.promptStubs(t, prompter)
}
ios, _, stdout, stderr := iostreams.Test()
ios.SetStdoutTTY(tt.tty)
opts := &ViewOptions{
IO: ios,
CapiClient: func() (capi.CapiClient, error) {
return capiClientMock, nil
},
SelectorArg: tt.selectorArg,
opts := tt.opts
opts.IO = ios
opts.Prompter = prompter
opts.CapiClient = func() (capi.CapiClient, error) {
return capiClientMock, nil
}
err := viewRun(opts)
err := viewRun(&opts)
if tt.wantErr != nil {
assert.Error(t, err)
require.EqualError(t, err, tt.wantErr.Error())

View file

@ -328,6 +328,31 @@ func ParseURL(prURL string) (ghrepo.Interface, int, error) {
return repo, prNumber, nil
}
var fullReferenceRE = regexp.MustCompile(`^(?:([^/]+)/([^/]+))#(\d+)$`)
// ParseFullReference parses a short issue/pull request reference of the form
// "owner/repo#number", where owner, repo and number are all required.
func ParseFullReference(s string) (ghrepo.Interface, int, error) {
if s == "" {
return nil, 0, errors.New("empty reference")
}
m := fullReferenceRE.FindStringSubmatch(s)
if m == nil {
return nil, 0, fmt.Errorf("invalid reference: %q", s)
}
number, err := strconv.Atoi(m[3])
if err != nil {
return nil, 0, fmt.Errorf("invalid reference: %q", number)
}
owner := m[1]
repo := m[2]
return ghrepo.New(owner, repo), number, nil
}
func findByNumber(httpClient *http.Client, repo ghrepo.Interface, number int, fields []string) (*api.PullRequest, error) {
type response struct {
Repository struct {

View file

@ -11,6 +11,7 @@ import (
"github.com/cli/cli/v2/git"
"github.com/cli/cli/v2/internal/ghrepo"
"github.com/cli/cli/v2/pkg/httpmock"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
@ -78,6 +79,80 @@ func TestParseURL(t *testing.T) {
}
}
func TestParseFullReference(t *testing.T) {
tests := []struct {
name string
arg string
wantRepo ghrepo.Interface
wantNumber int
wantErr string
}{
{
name: "number",
arg: "123",
wantErr: `invalid reference: "123"`,
},
{
name: "number with hash",
arg: "#123",
wantErr: `invalid reference: "#123"`,
},
{
name: "full form",
arg: "OWNER/REPO#123",
wantNumber: 123,
wantRepo: ghrepo.New("OWNER", "REPO"),
},
{
name: "empty",
wantErr: "empty reference",
},
{
name: "invalid full form, without hash",
arg: "OWNER/REPO123",
wantErr: `invalid reference: "OWNER/REPO123"`,
},
{
name: "invalid full form, empty owner and repo",
arg: "/#123",
wantErr: `invalid reference: "/#123"`,
},
{
name: "invalid full form, without owner",
arg: "REPO#123",
wantErr: `invalid reference: "REPO#123"`,
},
{
name: "invalid full form, without repo",
arg: "OWNER/#123",
wantErr: `invalid reference: "OWNER/#123"`,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
repo, number, err := ParseFullReference(tt.arg)
if tt.wantErr != "" {
require.EqualError(t, err, tt.wantErr)
assert.Nil(t, repo)
assert.Zero(t, number)
return
}
require.NoError(t, err)
assert.Equal(t, tt.wantNumber, number)
if tt.wantRepo != nil {
require.NotNil(t, repo)
assert.True(t, ghrepo.IsSame(tt.wantRepo, repo))
} else {
assert.Nil(t, repo)
}
})
}
}
type args struct {
baseRepoFn func() (ghrepo.Interface, error)
branchFn func() (string, error)