Divide shared repo package and add queries tests

This commit is contained in:
bagtoad 2024-09-29 10:40:22 -06:00
parent fd8c4633e3
commit 21f0d9466e
8 changed files with 182 additions and 40 deletions

View file

@ -19,7 +19,8 @@ import (
"github.com/cli/cli/v2/git"
"github.com/cli/cli/v2/internal/gh"
"github.com/cli/cli/v2/internal/ghrepo"
"github.com/cli/cli/v2/pkg/cmd/repo/shared"
"github.com/cli/cli/v2/pkg/cmd/repo/shared/format"
"github.com/cli/cli/v2/pkg/cmd/repo/shared/queries"
"github.com/cli/cli/v2/pkg/cmdutil"
"github.com/cli/cli/v2/pkg/iostreams"
"github.com/spf13/cobra"
@ -244,7 +245,7 @@ func NewCmdCreate(f *cmdutil.Factory, runF func(*CreateOptions) error) *cobra.Co
return nil, cobra.ShellCompDirectiveError
}
hostname, _ := cfg.Authentication().DefaultHost()
licenses, err := shared.ListLicenseTemplates(httpClient, hostname)
licenses, err := queries.ListLicenseTemplates(httpClient, hostname)
if err != nil {
return nil, cobra.ShellCompDirectiveError
}
@ -317,9 +318,9 @@ func createFromScratch(opts *CreateOptions) error {
return err
}
targetRepo := shared.NormalizeRepoName(opts.Name)
targetRepo := format.NormalizeRepoName(opts.Name)
if idx := strings.IndexRune(opts.Name, '/'); idx > 0 {
targetRepo = opts.Name[0:idx+1] + shared.NormalizeRepoName(opts.Name[idx+1:])
targetRepo = opts.Name[0:idx+1] + format.NormalizeRepoName(opts.Name[idx+1:])
}
confirmed, err := opts.Prompter.Confirm(fmt.Sprintf(`This will create "%s" as a %s repository on GitHub. Continue?`, targetRepo, strings.ToLower(opts.Visibility)), true)
if err != nil {
@ -476,9 +477,9 @@ func createFromTemplate(opts *CreateOptions) error {
}
templateRepoMainBranch := templateRepo.DefaultBranchRef.Name
targetRepo := shared.NormalizeRepoName(opts.Name)
targetRepo := format.NormalizeRepoName(opts.Name)
if idx := strings.IndexRune(opts.Name, '/'); idx > 0 {
targetRepo = opts.Name[0:idx+1] + shared.NormalizeRepoName(opts.Name[idx+1:])
targetRepo = opts.Name[0:idx+1] + format.NormalizeRepoName(opts.Name[idx+1:])
}
confirmed, err := opts.Prompter.Confirm(fmt.Sprintf(`This will create "%s" as a %s repository on GitHub. Continue?`, targetRepo, strings.ToLower(opts.Visibility)), true)
if err != nil {
@ -830,7 +831,7 @@ func interactiveLicense(client *http.Client, hostname string, prompter iprompter
return "", nil
}
licenses, err := shared.ListLicenseTemplates(client, hostname)
licenses, err := queries.ListLicenseTemplates(client, hostname)
if err != nil {
return "", err
}

View file

@ -16,7 +16,7 @@ import (
"github.com/cli/cli/v2/git"
"github.com/cli/cli/v2/internal/gh"
"github.com/cli/cli/v2/internal/ghrepo"
"github.com/cli/cli/v2/pkg/cmd/repo/shared"
"github.com/cli/cli/v2/pkg/cmd/repo/shared/format"
"github.com/cli/cli/v2/pkg/cmdutil"
"github.com/cli/cli/v2/pkg/iostreams"
"github.com/spf13/cobra"
@ -225,7 +225,7 @@ func forkRun(opts *ForkOptions) error {
}
// Rename the new repo if necessary
if opts.ForkName != "" && !strings.EqualFold(forkedRepo.RepoName(), shared.NormalizeRepoName(opts.ForkName)) {
if opts.ForkName != "" && !strings.EqualFold(forkedRepo.RepoName(), format.NormalizeRepoName(opts.ForkName)) {
forkedRepo, err = api.RenameRepo(apiClient, forkedRepo, opts.ForkName)
if err != nil {
return fmt.Errorf("could not rename fork: %w", err)

View file

@ -6,7 +6,7 @@ import (
"github.com/cli/cli/v2/api"
"github.com/cli/cli/v2/internal/gh"
"github.com/cli/cli/v2/internal/tableprinter"
"github.com/cli/cli/v2/pkg/cmd/repo/shared"
"github.com/cli/cli/v2/pkg/cmd/repo/shared/queries"
"github.com/cli/cli/v2/pkg/cmdutil"
"github.com/cli/cli/v2/pkg/iostreams"
"github.com/spf13/cobra"
@ -67,7 +67,7 @@ func listRun(opts *ListOptions) error {
}
hostname, _ := cfg.Authentication().DefaultHost()
licenses, err := shared.ListLicenseTemplates(client, hostname)
licenses, err := queries.ListLicenseTemplates(client, hostname)
if err != nil {
return err
}

View file

@ -0,0 +1,14 @@
package format
import (
"regexp"
"strings"
)
var invalidCharactersRE = regexp.MustCompile(`[^\w._-]+`)
// NormalizeRepoName takes in the repo name the user inputted and normalizes it using the same logic as GitHub (GitHub.com/new)
func NormalizeRepoName(repoName string) string {
newName := invalidCharactersRE.ReplaceAllString(repoName, "-")
return strings.TrimSuffix(newName, ".git")
}

View file

@ -1,4 +1,4 @@
package shared
package format
import (
"testing"

View file

@ -0,0 +1,19 @@
package queries
import (
"net/http"
"github.com/cli/cli/v2/api"
)
// ListLicenseTemplates fetches available repository templates.
// It uses API v3 because license template isn't supported by GraphQL.
func ListLicenseTemplates(httpClient *http.Client, hostname string) ([]api.License, error) {
var licenseTemplates []api.License
client := api.NewClientFromHTTP(httpClient)
err := client.REST(hostname, "GET", "licenses", nil, &licenseTemplates)
if err != nil {
return nil, err
}
return licenseTemplates, nil
}

View file

@ -0,0 +1,136 @@
package queries
import (
"net/http"
"testing"
"github.com/cli/cli/v2/api"
"github.com/cli/cli/v2/pkg/httpmock"
"github.com/stretchr/testify/assert"
)
func TestListLicenseTemplates(t *testing.T) {
tests := []struct {
name string
httpStubs func(t *testing.T, reg *httpmock.Registry)
hostname string
wantLicenses []api.License
wantErr bool
wantErrMsg string
httpClient func() (*http.Client, error)
}{
{
name: "happy path",
httpStubs: func(t *testing.T, reg *httpmock.Registry) {
reg.Register(
httpmock.REST("GET", "licenses"),
httpmock.StringResponse(`[
{
"key": "mit",
"name": "MIT License",
"spdx_id": "MIT",
"url": "https://api.github.com/licenses/mit",
"node_id": "MDc6TGljZW5zZW1pdA=="
},
{
"key": "lgpl-3.0",
"name": "GNU Lesser General Public License v3.0",
"spdx_id": "LGPL-3.0",
"url": "https://api.github.com/licenses/lgpl-3.0",
"node_id": "MDc6TGljZW5zZW1pdA=="
},
{
"key": "mpl-2.0",
"name": "Mozilla Public License 2.0",
"spdx_id": "MPL-2.0",
"url": "https://api.github.com/licenses/mpl-2.0",
"node_id": "MDc6TGljZW5zZW1pdA=="
},
{
"key": "agpl-3.0",
"name": "GNU Affero General Public License v3.0",
"spdx_id": "AGPL-3.0",
"url": "https://api.github.com/licenses/agpl-3.0",
"node_id": "MDc6TGljZW5zZW1pdA=="
},
{
"key": "unlicense",
"name": "The Unlicense",
"spdx_id": "Unlicense",
"url": "https://api.github.com/licenses/unlicense",
"node_id": "MDc6TGljZW5zZW1pdA=="
},
{
"key": "apache-2.0",
"name": "Apache License 2.0",
"spdx_id": "Apache-2.0",
"url": "https://api.github.com/licenses/apache-2.0",
"node_id": "MDc6TGljZW5zZW1pdA=="
},
{
"key": "gpl-3.0",
"name": "GNU General Public License v3.0",
"spdx_id": "GPL-3.0",
"url": "https://api.github.com/licenses/gpl-3.0",
"node_id": "MDc6TGljZW5zZW1pdA=="
}
]`,
))
},
hostname: "api.github.com",
wantLicenses: []api.License{
{
Key: "mit",
Name: "MIT License",
},
{
Key: "lgpl-3.0",
Name: "GNU Lesser General Public License v3.0",
},
{
Key: "mpl-2.0",
Name: "Mozilla Public License 2.0",
},
{
Key: "agpl-3.0",
Name: "GNU Affero General Public License v3.0",
},
{
Key: "unlicense",
Name: "The Unlicense",
},
{
Key: "apache-2.0",
Name: "Apache License 2.0",
},
{
Key: "gpl-3.0",
Name: "GNU General Public License v3.0",
},
},
wantErr: false,
},
}
for _, tt := range tests {
reg := &httpmock.Registry{}
if tt.httpStubs != nil {
tt.httpStubs(t, reg)
}
tt.httpClient = func() (*http.Client, error) {
return &http.Client{Transport: reg}, nil
}
client, _ := tt.httpClient()
t.Run(tt.name, func(t *testing.T) {
defer reg.Verify(t)
gotLicenses, err := ListLicenseTemplates(client, tt.hostname)
if !tt.wantErr {
assert.NoError(t, err, "Expected no error while fetching /licenses")
}
if tt.wantErr {
assert.Error(t, err, "Expected error while fetching /licenses")
}
assert.Equal(t, tt.wantLicenses, gotLicenses, "Licenses fetched is not as expected")
})
}
}

View file

@ -1,28 +0,0 @@
package shared
import (
"net/http"
"regexp"
"strings"
"github.com/cli/cli/v2/api"
)
var invalidCharactersRE = regexp.MustCompile(`[^\w._-]+`)
// NormalizeRepoName takes in the repo name the user inputted and normalizes it using the same logic as GitHub (GitHub.com/new)
func NormalizeRepoName(repoName string) string {
newName := invalidCharactersRE.ReplaceAllString(repoName, "-")
return strings.TrimSuffix(newName, ".git")
}
// ListLicenseTemplates uses API v3 here because license template isn't supported by GraphQL yet.
func ListLicenseTemplates(httpClient *http.Client, hostname string) ([]api.License, error) {
var licenseTemplates []api.License
client := api.NewClientFromHTTP(httpClient)
err := client.REST(hostname, "GET", "licenses", nil, &licenseTemplates)
if err != nil {
return nil, err
}
return licenseTemplates, nil
}