diff --git a/pkg/cmd/agent-task/capi/client.go b/pkg/cmd/agent-task/capi/client.go index ecec9a024..15765552b 100644 --- a/pkg/cmd/agent-task/capi/client.go +++ b/pkg/cmd/agent-task/capi/client.go @@ -19,6 +19,7 @@ type CapiClient interface { ListSessionsForRepo(ctx context.Context, owner string, repo string, limit int) ([]*Session, error) CreateJob(ctx context.Context, owner, repo, problemStatement, baseBranch string) (*Job, error) GetJob(ctx context.Context, owner, repo, jobID string) (*Job, error) + GetSession(ctx context.Context, id string) (*Session, error) } // CAPIClient is a client for interacting with the Copilot API diff --git a/pkg/cmd/agent-task/capi/client_mock.go b/pkg/cmd/agent-task/capi/client_mock.go index 621585587..ba7c05ab0 100644 --- a/pkg/cmd/agent-task/capi/client_mock.go +++ b/pkg/cmd/agent-task/capi/client_mock.go @@ -24,6 +24,9 @@ var _ CapiClient = &CapiClientMock{} // GetJobFunc: func(ctx context.Context, owner string, repo string, jobID string) (*Job, error) { // panic("mock out the GetJob method") // }, +// GetSessionFunc: func(ctx context.Context, id string) (*Session, error) { +// panic("mock out the GetSession method") +// }, // ListSessionsForRepoFunc: func(ctx context.Context, owner string, repo string, limit int) ([]*Session, error) { // panic("mock out the ListSessionsForRepo method") // }, @@ -43,6 +46,9 @@ type CapiClientMock struct { // GetJobFunc mocks the GetJob method. GetJobFunc func(ctx context.Context, owner string, repo string, jobID string) (*Job, error) + // GetSessionFunc mocks the GetSession method. + GetSessionFunc func(ctx context.Context, id string) (*Session, error) + // ListSessionsForRepoFunc mocks the ListSessionsForRepo method. ListSessionsForRepoFunc func(ctx context.Context, owner string, repo string, limit int) ([]*Session, error) @@ -75,6 +81,13 @@ type CapiClientMock struct { // JobID is the jobID argument value. JobID string } + // GetSession holds details about calls to the GetSession method. + GetSession []struct { + // Ctx is the ctx argument value. + Ctx context.Context + // ID is the id argument value. + ID string + } // ListSessionsForRepo holds details about calls to the ListSessionsForRepo method. ListSessionsForRepo []struct { // Ctx is the ctx argument value. @@ -96,6 +109,7 @@ type CapiClientMock struct { } lockCreateJob sync.RWMutex lockGetJob sync.RWMutex + lockGetSession sync.RWMutex lockListSessionsForRepo sync.RWMutex lockListSessionsForViewer sync.RWMutex } @@ -192,6 +206,42 @@ func (mock *CapiClientMock) GetJobCalls() []struct { return calls } +// GetSession calls GetSessionFunc. +func (mock *CapiClientMock) GetSession(ctx context.Context, id string) (*Session, error) { + if mock.GetSessionFunc == nil { + panic("CapiClientMock.GetSessionFunc: method is nil but CapiClient.GetSession was just called") + } + callInfo := struct { + Ctx context.Context + ID string + }{ + Ctx: ctx, + ID: id, + } + mock.lockGetSession.Lock() + mock.calls.GetSession = append(mock.calls.GetSession, callInfo) + mock.lockGetSession.Unlock() + return mock.GetSessionFunc(ctx, id) +} + +// GetSessionCalls gets all the calls that were made to GetSession. +// Check the length with: +// +// len(mockedCapiClient.GetSessionCalls()) +func (mock *CapiClientMock) GetSessionCalls() []struct { + Ctx context.Context + ID string +} { + var calls []struct { + Ctx context.Context + ID string + } + mock.lockGetSession.RLock() + calls = mock.calls.GetSession + mock.lockGetSession.RUnlock() + return calls +} + // ListSessionsForRepo calls ListSessionsForRepoFunc. func (mock *CapiClientMock) ListSessionsForRepo(ctx context.Context, owner string, repo string, limit int) ([]*Session, error) { if mock.ListSessionsForRepoFunc == nil {