diff --git a/internal/codespaces/api/api.go b/internal/codespaces/api/api.go index 0f114cdbe..fd1bba99e 100644 --- a/internal/codespaces/api/api.go +++ b/internal/codespaces/api/api.go @@ -34,7 +34,6 @@ import ( "fmt" "io/ioutil" "net/http" - "strconv" "strings" "time" @@ -410,14 +409,15 @@ func (a *API) GetCodespaceRegionLocation(ctx context.Context) (string, error) { return response.Current, nil } -type SKU struct { +type Machine struct { Name string `json:"name"` DisplayName string `json:"display_name"` } -// GetCodespacesSKUs returns the available SKUs for the user for a given repo, branch and location. -func (a *API) GetCodespacesSKUs(ctx context.Context, user *User, repository *Repository, branch, location string) ([]*SKU, error) { - req, err := http.NewRequest(http.MethodGet, a.githubAPI+"/vscs_internal/user/"+user.Login+"/skus", nil) +// GetCodespacesMachines returns the codespaces machines for the given repo, branch and location. +func (a *API) GetCodespacesMachines(ctx context.Context, repoID int, branch, location string) ([]*Machine, error) { + reqURL := fmt.Sprintf("%s/repositories/%d/codespaces/machines", a.githubAPI, repoID) + req, err := http.NewRequest(http.MethodGet, reqURL, nil) if err != nil { return nil, fmt.Errorf("error creating request: %w", err) } @@ -425,11 +425,10 @@ func (a *API) GetCodespacesSKUs(ctx context.Context, user *User, repository *Rep q := req.URL.Query() q.Add("location", location) q.Add("ref", branch) - q.Add("repository_id", strconv.Itoa(repository.ID)) req.URL.RawQuery = q.Encode() a.setHeaders(req) - resp, err := a.do(ctx, req, "/vscs_internal/user/*/skus") + resp, err := a.do(ctx, req, "/repositories/*/codespaces/machines") if err != nil { return nil, fmt.Errorf("error making request: %w", err) } @@ -445,13 +444,13 @@ func (a *API) GetCodespacesSKUs(ctx context.Context, user *User, repository *Rep } var response struct { - SKUs []*SKU `json:"skus"` + Machines []*Machine `json:"machines"` } if err := json.Unmarshal(b, &response); err != nil { return nil, fmt.Errorf("error unmarshaling response: %w", err) } - return response.SKUs, nil + return response.Machines, nil } // CreateCodespaceParams are the required parameters for provisioning a Codespace. diff --git a/pkg/cmd/codespace/common.go b/pkg/cmd/codespace/common.go index f0ad80c35..831db7366 100644 --- a/pkg/cmd/codespace/common.go +++ b/pkg/cmd/codespace/common.go @@ -42,7 +42,7 @@ type apiClient interface { GetRepository(ctx context.Context, nwo string) (*api.Repository, error) AuthorizedKeys(ctx context.Context, user string) ([]byte, error) GetCodespaceRegionLocation(ctx context.Context) (string, error) - GetCodespacesSKUs(ctx context.Context, user *api.User, repository *api.Repository, branch, location string) ([]*api.SKU, error) + GetCodespacesMachines(ctx context.Context, repoID int, branch, location string) ([]*api.Machine, error) GetCodespaceRepositoryContents(ctx context.Context, codespace *api.Codespace, path string) ([]byte, error) } diff --git a/pkg/cmd/codespace/create.go b/pkg/cmd/codespace/create.go index 2f0f628f5..96342f27d 100644 --- a/pkg/cmd/codespace/create.go +++ b/pkg/cmd/codespace/create.go @@ -71,7 +71,7 @@ func (a *App) Create(ctx context.Context, opts createOptions) error { return fmt.Errorf("error getting codespace user: %w", userResult.Err) } - machine, err := getMachineName(ctx, opts.machine, userResult.User, repository, branch, locationResult.Location, a.apiClient) + machine, err := getMachineName(ctx, a.apiClient, repository.ID, opts.machine, branch, locationResult.Location) if err != nil { return fmt.Errorf("error getting machine type: %w", err) } @@ -234,8 +234,8 @@ func getBranchName(branch string) (string, error) { } // getMachineName prompts the user to select the machine type, or validates the machine if non-empty. -func getMachineName(ctx context.Context, machine string, user *api.User, repo *api.Repository, branch, location string, apiClient apiClient) (string, error) { - skus, err := apiClient.GetCodespacesSKUs(ctx, user, repo, branch, location) +func getMachineName(ctx context.Context, apiClient apiClient, repoID int, machine, branch, location string) (string, error) { + machines, err := apiClient.GetCodespacesMachines(ctx, repoID, branch, location) if err != nil { return "", fmt.Errorf("error requesting machine instance types: %w", err) } @@ -243,55 +243,55 @@ func getMachineName(ctx context.Context, machine string, user *api.User, repo *a // if user supplied a machine type, it must be valid // if no machine type was supplied, we don't error if there are no machine types for the current repo if machine != "" { - for _, sku := range skus { - if machine == sku.Name { + for _, m := range machines { + if machine == m.Name { return machine, nil } } - availableSKUs := make([]string, len(skus)) - for i := 0; i < len(skus); i++ { - availableSKUs[i] = skus[i].Name + availableMachines := make([]string, len(machines)) + for i := 0; i < len(machines); i++ { + availableMachines[i] = machines[i].Name } - return "", fmt.Errorf("there is no such machine for the repository: %s\nAvailable machines: %v", machine, availableSKUs) - } else if len(skus) == 0 { + return "", fmt.Errorf("there is no such machine for the repository: %s\nAvailable machines: %v", machine, availableMachines) + } else if len(machines) == 0 { return "", nil } - if len(skus) == 1 { - return skus[0].Name, nil // VS Code does not prompt for SKU if there is only one, this makes us consistent with that behavior + if len(machines) == 1 { + // VS Code does not prompt for machine if there is only one, this makes us consistent with that behavior + return machines[0].Name, nil } - skuNames := make([]string, 0, len(skus)) - skuByName := make(map[string]*api.SKU) - for _, sku := range skus { - nameParts := camelcase.Split(sku.Name) + machineNames := make([]string, 0, len(machines)) + machineByName := make(map[string]*api.Machine) + for _, m := range machines { + nameParts := camelcase.Split(m.Name) machineName := strings.Title(strings.ToLower(nameParts[0])) - skuName := fmt.Sprintf("%s - %s", machineName, sku.DisplayName) - skuNames = append(skuNames, skuName) - skuByName[skuName] = sku + machineName = fmt.Sprintf("%s - %s", machineName, m.DisplayName) + machineNames = append(machineNames, machineName) + machineByName[machineName] = m } - skuSurvey := []*survey.Question{ + machineSurvey := []*survey.Question{ { - Name: "sku", + Name: "machine", Prompt: &survey.Select{ Message: "Choose Machine Type:", - Options: skuNames, - Default: skuNames[0], + Options: machineNames, + Default: machineNames[0], }, Validate: survey.Required, }, } - var skuAnswers struct{ SKU string } - if err := ask(skuSurvey, &skuAnswers); err != nil { - return "", fmt.Errorf("error getting SKU: %w", err) + var machineAnswers struct{ Machine string } + if err := ask(machineSurvey, &machineAnswers); err != nil { + return "", fmt.Errorf("error getting machine: %w", err) } - sku := skuByName[skuAnswers.SKU] - machine = sku.Name + selectedMachine := machineByName[machineAnswers.Machine] - return machine, nil + return selectedMachine.Name, nil } diff --git a/pkg/cmd/codespace/mock_api.go b/pkg/cmd/codespace/mock_api.go index a66deeefc..05ddd7612 100644 --- a/pkg/cmd/codespace/mock_api.go +++ b/pkg/cmd/codespace/mock_api.go @@ -37,8 +37,8 @@ import ( // GetCodespaceTokenFunc: func(ctx context.Context, user string, name string) (string, error) { // panic("mock out the GetCodespaceToken method") // }, -// GetCodespacesSKUsFunc: func(ctx context.Context, user *api.User, repository *api.Repository, branch string, location string) ([]*api.SKU, error) { -// panic("mock out the GetCodespacesSKUs method") +// GetCodespacesMachinesFunc: func(ctx context.Context, repoID int, branch string, location string) ([]*api.Machine, error) { +// panic("mock out the GetCodespacesMachines method") // }, // GetRepositoryFunc: func(ctx context.Context, nwo string) (*api.Repository, error) { // panic("mock out the GetRepository method") @@ -80,8 +80,8 @@ type apiClientMock struct { // GetCodespaceTokenFunc mocks the GetCodespaceToken method. GetCodespaceTokenFunc func(ctx context.Context, user string, name string) (string, error) - // GetCodespacesSKUsFunc mocks the GetCodespacesSKUs method. - GetCodespacesSKUsFunc func(ctx context.Context, user *api.User, repository *api.Repository, branch string, location string) ([]*api.SKU, error) + // GetCodespacesMachinesFunc mocks the GetCodespacesMachines method. + GetCodespacesMachinesFunc func(ctx context.Context, repoID int, branch string, location string) ([]*api.Machine, error) // GetRepositoryFunc mocks the GetRepository method. GetRepositoryFunc func(ctx context.Context, nwo string) (*api.Repository, error) @@ -152,14 +152,12 @@ type apiClientMock struct { // Name is the name argument value. Name string } - // GetCodespacesSKUs holds details about calls to the GetCodespacesSKUs method. - GetCodespacesSKUs []struct { + // GetCodespacesMachines holds details about calls to the GetCodespacesMachines method. + GetCodespacesMachines []struct { // Ctx is the ctx argument value. Ctx context.Context - // User is the user argument value. - User *api.User - // Repository is the repository argument value. - Repository *api.Repository + // RepoID is the repoID argument value. + RepoID int // Branch is the branch argument value. Branch string // Location is the location argument value. @@ -197,7 +195,7 @@ type apiClientMock struct { lockGetCodespaceRegionLocation sync.RWMutex lockGetCodespaceRepositoryContents sync.RWMutex lockGetCodespaceToken sync.RWMutex - lockGetCodespacesSKUs sync.RWMutex + lockGetCodespacesMachines sync.RWMutex lockGetRepository sync.RWMutex lockGetUser sync.RWMutex lockListCodespaces sync.RWMutex @@ -461,50 +459,46 @@ func (mock *apiClientMock) GetCodespaceTokenCalls() []struct { return calls } -// GetCodespacesSKUs calls GetCodespacesSKUsFunc. -func (mock *apiClientMock) GetCodespacesSKUs(ctx context.Context, user *api.User, repository *api.Repository, branch string, location string) ([]*api.SKU, error) { - if mock.GetCodespacesSKUsFunc == nil { - panic("apiClientMock.GetCodespacesSKUsFunc: method is nil but apiClient.GetCodespacesSKUs was just called") +// GetCodespacesMachines calls GetCodespacesMachinesFunc. +func (mock *apiClientMock) GetCodespacesMachines(ctx context.Context, repoID int, branch string, location string) ([]*api.Machine, error) { + if mock.GetCodespacesMachinesFunc == nil { + panic("apiClientMock.GetCodespacesMachinesFunc: method is nil but apiClient.GetCodespacesMachines was just called") } callInfo := struct { - Ctx context.Context - User *api.User - Repository *api.Repository - Branch string - Location string + Ctx context.Context + RepoID int + Branch string + Location string }{ - Ctx: ctx, - User: user, - Repository: repository, - Branch: branch, - Location: location, + Ctx: ctx, + RepoID: repoID, + Branch: branch, + Location: location, } - mock.lockGetCodespacesSKUs.Lock() - mock.calls.GetCodespacesSKUs = append(mock.calls.GetCodespacesSKUs, callInfo) - mock.lockGetCodespacesSKUs.Unlock() - return mock.GetCodespacesSKUsFunc(ctx, user, repository, branch, location) + mock.lockGetCodespacesMachines.Lock() + mock.calls.GetCodespacesMachines = append(mock.calls.GetCodespacesMachines, callInfo) + mock.lockGetCodespacesMachines.Unlock() + return mock.GetCodespacesMachinesFunc(ctx, repoID, branch, location) } -// GetCodespacesSKUsCalls gets all the calls that were made to GetCodespacesSKUs. +// GetCodespacesMachinesCalls gets all the calls that were made to GetCodespacesMachines. // Check the length with: -// len(mockedapiClient.GetCodespacesSKUsCalls()) -func (mock *apiClientMock) GetCodespacesSKUsCalls() []struct { - Ctx context.Context - User *api.User - Repository *api.Repository - Branch string - Location string +// len(mockedapiClient.GetCodespacesMachinesCalls()) +func (mock *apiClientMock) GetCodespacesMachinesCalls() []struct { + Ctx context.Context + RepoID int + Branch string + Location string } { var calls []struct { - Ctx context.Context - User *api.User - Repository *api.Repository - Branch string - Location string + Ctx context.Context + RepoID int + Branch string + Location string } - mock.lockGetCodespacesSKUs.RLock() - calls = mock.calls.GetCodespacesSKUs - mock.lockGetCodespacesSKUs.RUnlock() + mock.lockGetCodespacesMachines.RLock() + calls = mock.calls.GetCodespacesMachines + mock.lockGetCodespacesMachines.RUnlock() return calls }