diff --git a/.github/workflows/triage.yml b/.github/workflows/triage.yml index 6cd9d981d..849beebad 100644 --- a/.github/workflows/triage.yml +++ b/.github/workflows/triage.yml @@ -35,6 +35,8 @@ jobs: --- + cc: @github/cli + > $BODY EOF @@ -63,5 +65,7 @@ jobs: --- + cc: @github/cli + > $BODY EOF diff --git a/acceptance/testdata/repo/repo-create-bare.txtar b/acceptance/testdata/repo/repo-create-bare.txtar new file mode 100644 index 000000000..b835c420b --- /dev/null +++ b/acceptance/testdata/repo/repo-create-bare.txtar @@ -0,0 +1,35 @@ +# It's unclear what we want to do with these acceptance tests beyond our GHEC discovery, so skip new ones by default +skip + +# Set up env var +env REPO=${SCRIPT_NAME}-${RANDOM_STRING} + +# Use gh as a credential helper +exec gh auth setup-git + +# Initialise a local repository with two branches +# We expect a bare repo to have all refs pushed with --mirror +mkdir ${REPO} +cd ${REPO} +exec git init +exec git checkout -b feature-1 +exec git commit --allow-empty -m 'Empty Commit 1' + +exec git checkout -b feature-2 +exec git commit --allow-empty -m 'Empty Commit 2' + +# Clone a bare repo from that local repo +cd .. +exec git clone --bare ${REPO} ${REPO}-bare +cd ${REPO}-bare + +# Create a GitHub repository from that bare repo +exec gh repo create ${ORG}/${REPO} --private --source . --push --remote bare + +# Defer repo cleanup +defer gh repo delete --yes ${ORG}/${REPO} + +# Check the remote repo has both branches +exec gh api /repos/${ORG}/${REPO}/branches +stdout 'feature-1' +stdout 'feature-2' diff --git a/internal/codespaces/rpc/invoker.go b/internal/codespaces/rpc/invoker.go index b9d321802..6ba8843ac 100644 --- a/internal/codespaces/rpc/invoker.go +++ b/internal/codespaces/rpc/invoker.go @@ -8,6 +8,7 @@ import ( "fmt" "net" "os" + "regexp" "strconv" "strings" "time" @@ -241,6 +242,9 @@ func (i *invoker) StartSSHServerWithOptions(ctx context.Context, options StartSS return 0, "", fmt.Errorf("failed to parse SSH server port: %w", err) } + if !isUsernameValid(response.User) { + return 0, "", fmt.Errorf("invalid username: %s", response.User) + } return port, response.User, nil } @@ -300,3 +304,10 @@ func (i *invoker) notifyCodespaceOfClientActivity(ctx context.Context, activity return nil } + +func isUsernameValid(username string) bool { + // assuming valid usernames are alphanumeric, with these special characters allowed: . _ - + var validUsernamePattern = `^[a-zA-Z0-9_][-.a-zA-Z0-9_]*$` + re := regexp.MustCompile(validUsernamePattern) + return re.MatchString(username) +} diff --git a/pkg/cmd/attestation/download/download.go b/pkg/cmd/attestation/download/download.go index 77c928093..143912308 100644 --- a/pkg/cmd/attestation/download/download.go +++ b/pkg/cmd/attestation/download/download.go @@ -122,14 +122,13 @@ func runDownload(opts *Options) error { opts.Logger.VerbosePrintf("Downloading trusted metadata for artifact %s\n\n", opts.ArtifactPath) - c := verification.FetchAttestationsConfig{ - APIClient: opts.APIClient, - Digest: artifact.DigestWithAlg(), - Limit: opts.Limit, - Owner: opts.Owner, - Repo: opts.Repo, + params := verification.FetchRemoteAttestationsParams{ + Digest: artifact.DigestWithAlg(), + Limit: opts.Limit, + Owner: opts.Owner, + Repo: opts.Repo, } - attestations, err := verification.GetRemoteAttestations(c) + attestations, err := verification.GetRemoteAttestations(opts.APIClient, params) if err != nil { if errors.Is(err, api.ErrNoAttestations{}) { fmt.Fprintf(opts.Logger.IO.Out, "No attestations found for %s\n", opts.ArtifactPath) diff --git a/pkg/cmd/attestation/verification/attestation.go b/pkg/cmd/attestation/verification/attestation.go index 0ea91c2f7..07083a5c0 100644 --- a/pkg/cmd/attestation/verification/attestation.go +++ b/pkg/cmd/attestation/verification/attestation.go @@ -9,8 +9,8 @@ import ( "path/filepath" "github.com/cli/cli/v2/pkg/cmd/attestation/api" + "github.com/cli/cli/v2/pkg/cmd/attestation/artifact" "github.com/cli/cli/v2/pkg/cmd/attestation/artifact/oci" - "github.com/google/go-containerregistry/pkg/name" protobundle "github.com/sigstore/protobuf-specs/gen/pb-go/bundle/v1" "github.com/sigstore/sigstore-go/pkg/bundle" ) @@ -20,32 +20,11 @@ const SLSAPredicateV1 = "https://slsa.dev/provenance/v1" var ErrUnrecognisedBundleExtension = errors.New("bundle file extension not supported, must be json or jsonl") var ErrEmptyBundleFile = errors.New("provided bundle file is empty") -type FetchAttestationsConfig struct { - APIClient api.Client - BundlePath string - Digest string - Limit int - Owner string - Repo string - OCIClient oci.Client - UseBundleFromRegistry bool - NameRef name.Reference -} - -func (c *FetchAttestationsConfig) IsBundleProvided() bool { - return c.BundlePath != "" -} - -func GetAttestations(c FetchAttestationsConfig) ([]*api.Attestation, error) { - if c.IsBundleProvided() { - return GetLocalAttestations(c.BundlePath) - } - - if c.UseBundleFromRegistry { - return GetOCIAttestations(c) - } - - return GetRemoteAttestations(c) +type FetchRemoteAttestationsParams struct { + Digest string + Limit int + Owner string + Repo string } // GetLocalAttestations returns a slice of attestations read from a local bundle file. @@ -116,30 +95,30 @@ func loadBundlesFromJSONLinesFile(path string) ([]*api.Attestation, error) { return attestations, nil } -func GetRemoteAttestations(c FetchAttestationsConfig) ([]*api.Attestation, error) { - if c.APIClient == nil { +func GetRemoteAttestations(client api.Client, params FetchRemoteAttestationsParams) ([]*api.Attestation, error) { + if client == nil { return nil, fmt.Errorf("api client must be provided") } // check if Repo is set first because if Repo has been set, Owner will be set using the value of Repo. // If Repo is not set, the field will remain empty. It will not be populated using the value of Owner. - if c.Repo != "" { - attestations, err := c.APIClient.GetByRepoAndDigest(c.Repo, c.Digest, c.Limit) + if params.Repo != "" { + attestations, err := client.GetByRepoAndDigest(params.Repo, params.Digest, params.Limit) if err != nil { - return nil, fmt.Errorf("failed to fetch attestations from %s: %w", c.Repo, err) + return nil, fmt.Errorf("failed to fetch attestations from %s: %w", params.Repo, err) } return attestations, nil - } else if c.Owner != "" { - attestations, err := c.APIClient.GetByOwnerAndDigest(c.Owner, c.Digest, c.Limit) + } else if params.Owner != "" { + attestations, err := client.GetByOwnerAndDigest(params.Owner, params.Digest, params.Limit) if err != nil { - return nil, fmt.Errorf("failed to fetch attestations from %s: %w", c.Owner, err) + return nil, fmt.Errorf("failed to fetch attestations from %s: %w", params.Owner, err) } return attestations, nil } return nil, fmt.Errorf("owner or repo must be provided") } -func GetOCIAttestations(c FetchAttestationsConfig) ([]*api.Attestation, error) { - attestations, err := c.OCIClient.GetAttestations(c.NameRef, c.Digest) +func GetOCIAttestations(client oci.Client, artifact artifact.DigestedArtifact) ([]*api.Attestation, error) { + attestations, err := client.GetAttestations(artifact.NameRef(), artifact.Digest()) if err != nil { return nil, fmt.Errorf("failed to fetch OCI attestations: %w", err) } diff --git a/pkg/cmd/attestation/verification/extensions.go b/pkg/cmd/attestation/verification/extensions.go index e302d89c9..a0827e9ec 100644 --- a/pkg/cmd/attestation/verification/extensions.go +++ b/pkg/cmd/attestation/verification/extensions.go @@ -20,7 +20,7 @@ func VerifyCertExtensions(results []*AttestationProcessingResult, ec Enforcement var lastErr error for _, attestation := range results { - err := verifyCertExtensions(*attestation.VerificationResult.Signature.Certificate, ec) + err := verifyCertExtensions(*attestation.VerificationResult.Signature.Certificate, ec.Certificate) if err == nil { // if at least one attestation is verified, we're good as verification // is defined as successful if at least one attestation is verified @@ -34,28 +34,23 @@ func VerifyCertExtensions(results []*AttestationProcessingResult, ec Enforcement return lastErr } -func verifyCertExtensions(verifiedCert certificate.Summary, criteria EnforcementCriteria) error { - sourceRepositoryOwnerURI := verifiedCert.Extensions.SourceRepositoryOwnerURI - if !strings.EqualFold(criteria.Certificate.SourceRepositoryOwnerURI, sourceRepositoryOwnerURI) { - return fmt.Errorf("expected SourceRepositoryOwnerURI to be %s, got %s", criteria.Certificate.SourceRepositoryOwnerURI, sourceRepositoryOwnerURI) +func verifyCertExtensions(given, expected certificate.Summary) error { + if !strings.EqualFold(expected.SourceRepositoryOwnerURI, given.SourceRepositoryOwnerURI) { + return fmt.Errorf("expected SourceRepositoryOwnerURI to be %s, got %s", expected.SourceRepositoryOwnerURI, given.SourceRepositoryOwnerURI) } - // if repo is set, check the SourceRepositoryURI field - if criteria.Certificate.SourceRepositoryURI != "" { - sourceRepositoryURI := verifiedCert.Extensions.SourceRepositoryURI - if !strings.EqualFold(criteria.Certificate.SourceRepositoryURI, sourceRepositoryURI) { - return fmt.Errorf("expected SourceRepositoryURI to be %s, got %s", criteria.Certificate.SourceRepositoryURI, sourceRepositoryURI) - } + // if repo is set, compare the SourceRepositoryURI fields + if expected.SourceRepositoryURI != "" && !strings.EqualFold(expected.SourceRepositoryURI, given.SourceRepositoryURI) { + return fmt.Errorf("expected SourceRepositoryURI to be %s, got %s", expected.SourceRepositoryURI, given.SourceRepositoryURI) } - // if issuer is anything other than the default, use the user-provided value; - // otherwise, select the appropriate default based on the tenant - certIssuer := verifiedCert.Extensions.Issuer - if !strings.EqualFold(criteria.Certificate.Issuer, certIssuer) { - if strings.Index(certIssuer, criteria.Certificate.Issuer+"/") == 0 { - return fmt.Errorf("expected Issuer to be %s, got %s -- if you have a custom OIDC issuer policy for your enterprise, use the --cert-oidc-issuer flag with your expected issuer", criteria.Certificate.Issuer, certIssuer) + // compare the OIDC issuers. If not equal, return an error depending + // on if there is a partial match + if !strings.EqualFold(expected.Issuer, given.Issuer) { + if strings.Index(given.Issuer, expected.Issuer+"/") == 0 { + return fmt.Errorf("expected Issuer to be %s, got %s -- if you have a custom OIDC issuer policy for your enterprise, use the --cert-oidc-issuer flag with your expected issuer", expected.Issuer, given.Issuer) } - return fmt.Errorf("expected Issuer to be %s, got %s", criteria.Certificate.Issuer, certIssuer) + return fmt.Errorf("expected Issuer to be %s, got %s", expected.Issuer, given.Issuer) } return nil diff --git a/pkg/cmd/attestation/verify/attestation.go b/pkg/cmd/attestation/verify/attestation.go new file mode 100644 index 000000000..f3f2792c4 --- /dev/null +++ b/pkg/cmd/attestation/verify/attestation.go @@ -0,0 +1,50 @@ +package verify + +import ( + "fmt" + + "github.com/cli/cli/v2/internal/text" + "github.com/cli/cli/v2/pkg/cmd/attestation/api" + "github.com/cli/cli/v2/pkg/cmd/attestation/artifact" + "github.com/cli/cli/v2/pkg/cmd/attestation/verification" +) + +func getAttestations(o *Options, a artifact.DigestedArtifact) ([]*api.Attestation, string, error) { + if o.BundlePath != "" { + attestations, err := verification.GetLocalAttestations(o.BundlePath) + if err != nil { + msg := fmt.Sprintf("✗ Loading attestations from %s failed", a.URL) + return nil, msg, err + } + pluralAttestation := text.Pluralize(len(attestations), "attestation") + msg := fmt.Sprintf("Loaded %s from %s", pluralAttestation, o.BundlePath) + return attestations, msg, nil + } + + if o.UseBundleFromRegistry { + attestations, err := verification.GetOCIAttestations(o.OCIClient, a) + if err != nil { + msg := "✗ Loading attestations from OCI registry failed" + return nil, msg, err + } + pluralAttestation := text.Pluralize(len(attestations), "attestation") + msg := fmt.Sprintf("Loaded %s from %s", pluralAttestation, o.ArtifactPath) + return attestations, msg, nil + } + + params := verification.FetchRemoteAttestationsParams{ + Digest: a.DigestWithAlg(), + Limit: o.Limit, + Owner: o.Owner, + Repo: o.Repo, + } + + attestations, err := verification.GetRemoteAttestations(o.APIClient, params) + if err != nil { + msg := "✗ Loading attestations from GitHub API failed" + return nil, msg, err + } + pluralAttestation := text.Pluralize(len(attestations), "attestation") + msg := fmt.Sprintf("Loaded %s from GitHub API", pluralAttestation) + return attestations, msg, nil +} diff --git a/pkg/cmd/attestation/verify/verify.go b/pkg/cmd/attestation/verify/verify.go index 82b126dcb..2e057b9f3 100644 --- a/pkg/cmd/attestation/verify/verify.go +++ b/pkg/cmd/attestation/verify/verify.go @@ -6,7 +6,6 @@ import ( "regexp" "github.com/cli/cli/v2/internal/ghinstance" - "github.com/cli/cli/v2/internal/text" "github.com/cli/cli/v2/pkg/cmd/attestation/api" "github.com/cli/cli/v2/pkg/cmd/attestation/artifact" "github.com/cli/cli/v2/pkg/cmd/attestation/artifact/oci" @@ -222,42 +221,18 @@ func runVerify(opts *Options) error { opts.Logger.Printf("Loaded digest %s for %s\n", artifact.DigestWithAlg(), artifact.URL) - c := verification.FetchAttestationsConfig{ - APIClient: opts.APIClient, - BundlePath: opts.BundlePath, - Digest: artifact.DigestWithAlg(), - Limit: opts.Limit, - Owner: opts.Owner, - Repo: opts.Repo, - OCIClient: opts.OCIClient, - UseBundleFromRegistry: opts.UseBundleFromRegistry, - NameRef: artifact.NameRef(), - } - attestations, err := verification.GetAttestations(c) + attestations, logMsg, err := getAttestations(opts, *artifact) if err != nil { if ok := errors.Is(err, api.ErrNoAttestations{}); ok { opts.Logger.Printf(opts.Logger.ColorScheme.Red("✗ No attestations found for subject %s\n"), artifact.DigestWithAlg()) return err } - - if c.IsBundleProvided() { - opts.Logger.Printf(opts.Logger.ColorScheme.Red("✗ Loading attestations from %s failed\n"), artifact.URL) - } else if c.UseBundleFromRegistry { - opts.Logger.Println(opts.Logger.ColorScheme.Red("✗ Loading attestations from OCI registry failed")) - } else { - opts.Logger.Println(opts.Logger.ColorScheme.Red("✗ Loading attestations from GitHub API failed")) - } + // Print the message signifying failure fetching attestations + opts.Logger.Println(opts.Logger.ColorScheme.Red(logMsg)) return err } - - pluralAttestation := text.Pluralize(len(attestations), "attestation") - if c.IsBundleProvided() { - opts.Logger.Printf("Loaded %s from %s\n", pluralAttestation, opts.BundlePath) - } else if c.UseBundleFromRegistry { - opts.Logger.Printf("Loaded %s from %s\n", pluralAttestation, opts.ArtifactPath) - } else { - opts.Logger.Printf("Loaded %s from GitHub API\n", pluralAttestation) - } + // Print the message signifying success fetching attestations + opts.Logger.Println(logMsg) // Apply predicate type filter to returned attestations filteredAttestations := verification.FilterAttestations(ec.PredicateType, attestations) diff --git a/pkg/cmd/repo/create/create.go b/pkg/cmd/repo/create/create.go index 2fc127aba..79c349aa4 100644 --- a/pkg/cmd/repo/create/create.go +++ b/pkg/cmd/repo/create/create.go @@ -94,7 +94,7 @@ func NewCmdCreate(f *cmdutil.Factory, runF func(*CreateOptions) error) *cobra.Co To create a remote repository from an existing local repository, specify the source directory with %[1]s--source%[1]s. By default, the remote repository name will be the name of the source directory. - Pass %[1]s--push%[1]s to push any local commits to the new repository. + Pass %[1]s--push%[1]s to push any local commits to the new repository. If the repo is bare, this will mirror all refs. For language or platform .gitignore templates to use with %[1]s--gitignore%[1]s, . @@ -556,11 +556,11 @@ func createFromLocal(opts *CreateOptions) error { return err } - isRepo, err := isLocalRepo(opts.GitClient) + repoType, err := localRepoType(opts.GitClient) if err != nil { return err } - if !isRepo { + if repoType == unknown { if repoPath == "." { return fmt.Errorf("current directory is not a git repository. Run `git init` to initialize it") } @@ -652,22 +652,43 @@ func createFromLocal(opts *CreateOptions) error { // don't prompt for push if there are no commits if opts.Interactive && committed { + msg := fmt.Sprintf("Would you like to push commits from the current branch to %q?", baseRemote) + if repoType == bare { + msg = fmt.Sprintf("Would you like to mirror all refs to %q?", baseRemote) + } + var err error - opts.Push, err = opts.Prompter.Confirm(fmt.Sprintf("Would you like to push commits from the current branch to %q?", baseRemote), true) + opts.Push, err = opts.Prompter.Confirm(msg, true) if err != nil { return err } } - if opts.Push { + if opts.Push && repoType == working { err := opts.GitClient.Push(context.Background(), baseRemote, "HEAD") if err != nil { return err } + if isTTY { fmt.Fprintf(stdout, "%s Pushed commits to %s\n", cs.SuccessIcon(), remoteURL) } } + + if opts.Push && repoType == bare { + cmd, err := opts.GitClient.AuthenticatedCommand(context.Background(), "push", baseRemote, "--mirror") + if err != nil { + return err + } + if err = cmd.Run(); err != nil { + return err + } + + if isTTY { + fmt.Fprintf(stdout, "%s Mirrored all refs to %s\n", cs.SuccessIcon(), remoteURL) + } + } + return nil } @@ -736,22 +757,34 @@ func hasCommits(gitClient *git.Client) (bool, error) { return false, nil } -// check if path is the top level directory of a git repo -func isLocalRepo(gitClient *git.Client) (bool, error) { +type repoType int + +const ( + unknown repoType = iota + working + bare +) + +func localRepoType(gitClient *git.Client) (repoType, error) { projectDir, projectDirErr := gitClient.GitDir(context.Background()) if projectDirErr != nil { - var execError *exec.ExitError + var execError errWithExitCode if errors.As(projectDirErr, &execError) { if exitCode := int(execError.ExitCode()); exitCode == 128 { - return false, nil + return unknown, nil } - return false, projectDirErr + return unknown, projectDirErr } } - if projectDir != ".git" { - return false, nil + + switch projectDir { + case ".": + return bare, nil + case ".git": + return working, nil + default: + return unknown, nil } - return true, nil } // clone the checkout branch to specified path diff --git a/pkg/cmd/repo/create/create_test.go b/pkg/cmd/repo/create/create_test.go index cc0ec602a..c33cfdad6 100644 --- a/pkg/cmd/repo/create/create_test.go +++ b/pkg/cmd/repo/create/create_test.go @@ -443,6 +443,74 @@ func Test_createRun(t *testing.T) { }, wantStdout: "✓ Created repository OWNER/REPO on GitHub\n https://github.com/OWNER/REPO\n", }, + { + name: "interactive with existing bare repository public and push", + opts: &CreateOptions{Interactive: true}, + tty: true, + promptStubs: func(p *prompter.PrompterMock) { + p.ConfirmFunc = func(message string, defaultValue bool) (bool, error) { + switch message { + case "Add a remote?": + return true, nil + case `Would you like to mirror all refs to "origin"?`: + return true, nil + default: + return false, fmt.Errorf("unexpected confirm prompt: %s", message) + } + } + p.InputFunc = func(message, defaultValue string) (string, error) { + switch message { + case "Path to local repository": + return defaultValue, nil + case "Repository name": + return "REPO", nil + case "Description": + return "my new repo", nil + case "What should the new remote be called?": + return defaultValue, nil + default: + return "", fmt.Errorf("unexpected input prompt: %s", message) + } + } + p.SelectFunc = func(message, defaultValue string, options []string) (int, error) { + switch message { + case "What would you like to do?": + return prompter.IndexFor(options, "Push an existing local repository to GitHub") + case "Visibility": + return prompter.IndexFor(options, "Private") + default: + return 0, fmt.Errorf("unexpected select prompt: %s", message) + } + } + }, + httpStubs: func(reg *httpmock.Registry) { + reg.Register( + httpmock.GraphQL(`query UserCurrent\b`), + httpmock.StringResponse(`{"data":{"viewer":{"login":"someuser","organizations":{"nodes": []}}}}`)) + reg.Register( + httpmock.GraphQL(`mutation RepositoryCreate\b`), + httpmock.StringResponse(` + { + "data": { + "createRepository": { + "repository": { + "id": "REPOID", + "name": "REPO", + "owner": {"login":"OWNER"}, + "url": "https://github.com/OWNER/REPO" + } + } + } + }`)) + }, + execStubs: func(cs *run.CommandStubber) { + cs.Register(`git -C . rev-parse --git-dir`, 0, ".") + cs.Register(`git -C . rev-parse HEAD`, 0, "commithash") + cs.Register(`git -C . remote add origin https://github.com/OWNER/REPO`, 0, "") + cs.Register(`git -C . push origin --mirror`, 0, "") + }, + wantStdout: "✓ Created repository OWNER/REPO on GitHub\n https://github.com/OWNER/REPO\n✓ Added remote https://github.com/OWNER/REPO.git\n✓ Mirrored all refs to https://github.com/OWNER/REPO.git\n", + }, { name: "interactive with existing repository public add remote and push", opts: &CreateOptions{Interactive: true}, @@ -696,6 +764,71 @@ func Test_createRun(t *testing.T) { }, wantStdout: "https://github.com/OWNER/REPO\n", }, + { + name: "noninteractive create bare from source and push", + opts: &CreateOptions{ + Interactive: false, + Source: ".", + Push: true, + Name: "REPO", + Visibility: "PRIVATE", + }, + tty: false, + httpStubs: func(reg *httpmock.Registry) { + reg.Register( + httpmock.GraphQL(`mutation RepositoryCreate\b`), + httpmock.StringResponse(` + { + "data": { + "createRepository": { + "repository": { + "id": "REPOID", + "name": "REPO", + "owner": {"login":"OWNER"}, + "url": "https://github.com/OWNER/REPO" + } + } + } + }`)) + }, + execStubs: func(cs *run.CommandStubber) { + cs.Register(`git -C . rev-parse --git-dir`, 0, ".") + cs.Register(`git -C . rev-parse HEAD`, 0, "commithash") + cs.Register(`git -C . remote add origin https://github.com/OWNER/REPO`, 0, "") + cs.Register(`git -C . push origin --mirror`, 0, "") + }, + wantStdout: "https://github.com/OWNER/REPO\n", + }, + { + name: "noninteractive create from cwd that isn't a git repo", + opts: &CreateOptions{ + Interactive: false, + Source: ".", + Name: "REPO", + Visibility: "PRIVATE", + }, + tty: false, + execStubs: func(cs *run.CommandStubber) { + cs.Register(`git -C . rev-parse --git-dir`, 128, "") + }, + wantErr: true, + errMsg: "current directory is not a git repository. Run `git init` to initialize it", + }, + { + name: "noninteractive create from cwd that isn't a git repo", + opts: &CreateOptions{ + Interactive: false, + Source: "some-dir", + Name: "REPO", + Visibility: "PRIVATE", + }, + tty: false, + execStubs: func(cs *run.CommandStubber) { + cs.Register(`git -C some-dir rev-parse --git-dir`, 128, "") + }, + wantErr: true, + errMsg: "some-dir is not a git repository. Run `git -C \"some-dir\" init` to initialize it", + }, { name: "noninteractive clone from scratch", opts: &CreateOptions{ @@ -856,11 +989,11 @@ func Test_createRun(t *testing.T) { defer reg.Verify(t) err := createRun(tt.opts) if tt.wantErr { - assert.Error(t, err) - assert.Equal(t, tt.errMsg, err.Error()) + require.Error(t, err) + assert.Contains(t, err.Error(), tt.errMsg) return } - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, tt.wantStdout, stdout.String()) assert.Equal(t, "", stderr.String()) }) diff --git a/pkg/cmd/root/extension.go b/pkg/cmd/root/extension.go index d6d495103..52250a432 100644 --- a/pkg/cmd/root/extension.go +++ b/pkg/cmd/root/extension.go @@ -4,6 +4,8 @@ import ( "errors" "fmt" "os/exec" + "strings" + "time" "github.com/cli/cli/v2/pkg/extensions" "github.com/cli/cli/v2/pkg/iostreams" @@ -14,10 +16,33 @@ type ExternalCommandExitError struct { *exec.ExitError } +type extensionReleaseInfo struct { + CurrentVersion string + LatestVersion string + Pinned bool + URL string +} + func NewCmdExtension(io *iostreams.IOStreams, em extensions.ExtensionManager, ext extensions.Extension) *cobra.Command { + updateMessageChan := make(chan *extensionReleaseInfo) + cs := io.ColorScheme() + return &cobra.Command{ Use: ext.Name(), Short: fmt.Sprintf("Extension %s", ext.Name()), + // PreRun handles looking up whether extension has a latest version only when the command is ran. + PreRun: func(c *cobra.Command, args []string) { + go func() { + if ext.UpdateAvailable() { + updateMessageChan <- &extensionReleaseInfo{ + CurrentVersion: ext.CurrentVersion(), + LatestVersion: ext.LatestVersion(), + Pinned: ext.IsPinned(), + URL: ext.URL(), + } + } + }() + }, RunE: func(c *cobra.Command, args []string) error { args = append([]string{ext.Name()}, args...) if _, err := em.Dispatch(args, io.In, io.Out, io.ErrOut); err != nil { @@ -29,6 +54,28 @@ func NewCmdExtension(io *iostreams.IOStreams, em extensions.ExtensionManager, ex } return nil }, + // PostRun handles communicating extension release information if found + PostRun: func(c *cobra.Command, args []string) { + select { + case releaseInfo := <-updateMessageChan: + if releaseInfo != nil { + stderr := io.ErrOut + fmt.Fprintf(stderr, "\n\n%s %s → %s\n", + cs.Yellowf("A new release of %s is available:", ext.Name()), + cs.Cyan(strings.TrimPrefix(releaseInfo.CurrentVersion, "v")), + cs.Cyan(strings.TrimPrefix(releaseInfo.LatestVersion, "v"))) + if releaseInfo.Pinned { + fmt.Fprintf(stderr, "To upgrade, run: gh extension upgrade %s --force\n", ext.Name()) + } else { + fmt.Fprintf(stderr, "To upgrade, run: gh extension upgrade %s\n", ext.Name()) + } + fmt.Fprintf(stderr, "%s\n\n", + cs.Yellow(releaseInfo.URL)) + } + case <-time.After(1 * time.Second): + // Bail on checking for new extension update as its taking too long + } + }, GroupID: "extension", Annotations: map[string]string{ "skipAuthCheck": "true", diff --git a/pkg/cmd/root/extension_test.go b/pkg/cmd/root/extension_test.go new file mode 100644 index 000000000..ef94dcc71 --- /dev/null +++ b/pkg/cmd/root/extension_test.go @@ -0,0 +1,159 @@ +package root_test + +import ( + "io" + "testing" + + "github.com/MakeNowJust/heredoc" + "github.com/cli/cli/v2/pkg/cmd/root" + "github.com/cli/cli/v2/pkg/extensions" + "github.com/cli/cli/v2/pkg/iostreams" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestNewCmdExtension_Updates(t *testing.T) { + tests := []struct { + name string + extCurrentVersion string + extIsPinned bool + extLatestVersion string + extName string + extUpdateAvailable bool + extURL string + wantStderr string + }{ + { + name: "no update available", + extName: "no-update", + extUpdateAvailable: false, + extCurrentVersion: "1.0.0", + extLatestVersion: "1.0.0", + extURL: "https//github.com/dne/no-update", + }, + { + name: "major update", + extName: "major-update", + extUpdateAvailable: true, + extCurrentVersion: "1.0.0", + extLatestVersion: "2.0.0", + extURL: "https//github.com/dne/major-update", + wantStderr: heredoc.Doc(` + A new release of major-update is available: 1.0.0 → 2.0.0 + To upgrade, run: gh extension upgrade major-update + https//github.com/dne/major-update + `), + }, + { + name: "major update, pinned", + extName: "major-update", + extUpdateAvailable: true, + extCurrentVersion: "1.0.0", + extLatestVersion: "2.0.0", + extIsPinned: true, + extURL: "https//github.com/dne/major-update", + wantStderr: heredoc.Doc(` + A new release of major-update is available: 1.0.0 → 2.0.0 + To upgrade, run: gh extension upgrade major-update --force + https//github.com/dne/major-update + `), + }, + { + name: "minor update", + extName: "minor-update", + extUpdateAvailable: true, + extCurrentVersion: "1.0.0", + extLatestVersion: "1.1.0", + extURL: "https//github.com/dne/minor-update", + wantStderr: heredoc.Doc(` + A new release of minor-update is available: 1.0.0 → 1.1.0 + To upgrade, run: gh extension upgrade minor-update + https//github.com/dne/minor-update + `), + }, + { + name: "minor update, pinned", + extName: "minor-update", + extUpdateAvailable: true, + extCurrentVersion: "1.0.0", + extLatestVersion: "1.1.0", + extURL: "https//github.com/dne/minor-update", + extIsPinned: true, + wantStderr: heredoc.Doc(` + A new release of minor-update is available: 1.0.0 → 1.1.0 + To upgrade, run: gh extension upgrade minor-update --force + https//github.com/dne/minor-update + `), + }, + { + name: "patch update", + extName: "patch-update", + extUpdateAvailable: true, + extCurrentVersion: "1.0.0", + extLatestVersion: "1.0.1", + extURL: "https//github.com/dne/patch-update", + wantStderr: heredoc.Doc(` + A new release of patch-update is available: 1.0.0 → 1.0.1 + To upgrade, run: gh extension upgrade patch-update + https//github.com/dne/patch-update + `), + }, + { + name: "patch update, pinned", + extName: "patch-update", + extUpdateAvailable: true, + extCurrentVersion: "1.0.0", + extLatestVersion: "1.0.1", + extURL: "https//github.com/dne/patch-update", + extIsPinned: true, + wantStderr: heredoc.Doc(` + A new release of patch-update is available: 1.0.0 → 1.0.1 + To upgrade, run: gh extension upgrade patch-update --force + https//github.com/dne/patch-update + `), + }, + } + + for _, tt := range tests { + ios, _, _, stderr := iostreams.Test() + + em := &extensions.ExtensionManagerMock{ + DispatchFunc: func(args []string, stdin io.Reader, stdout io.Writer, stderr io.Writer) (bool, error) { + // Assume extension executed / dispatched without problems as test is focused on upgrade checking. + return true, nil + }, + } + + ext := &extensions.ExtensionMock{ + CurrentVersionFunc: func() string { + return tt.extCurrentVersion + }, + IsPinnedFunc: func() bool { + return tt.extIsPinned + }, + LatestVersionFunc: func() string { + return tt.extLatestVersion + }, + NameFunc: func() string { + return tt.extName + }, + UpdateAvailableFunc: func() bool { + return tt.extUpdateAvailable + }, + URLFunc: func() string { + return tt.extURL + }, + } + + cmd := root.NewCmdExtension(ios, em, ext) + + _, err := cmd.ExecuteC() + require.NoError(t, err) + + if tt.wantStderr == "" { + assert.Emptyf(t, stderr.String(), "executing extension command should output nothing to stderr") + } else { + assert.Containsf(t, stderr.String(), tt.wantStderr, "executing extension command should output message about upgrade to stderr") + } + } +}