Merge branch 'trunk' into print-policy-info

This commit is contained in:
Meredith Lancaster 2024-11-18 13:57:34 -07:00
commit a55f9a6301
12 changed files with 528 additions and 108 deletions

View file

@ -35,6 +35,8 @@ jobs:
---
cc: @github/cli
> $BODY
EOF
@ -63,5 +65,7 @@ jobs:
---
cc: @github/cli
> $BODY
EOF

View file

@ -0,0 +1,35 @@
# It's unclear what we want to do with these acceptance tests beyond our GHEC discovery, so skip new ones by default
skip
# Set up env var
env REPO=${SCRIPT_NAME}-${RANDOM_STRING}
# Use gh as a credential helper
exec gh auth setup-git
# Initialise a local repository with two branches
# We expect a bare repo to have all refs pushed with --mirror
mkdir ${REPO}
cd ${REPO}
exec git init
exec git checkout -b feature-1
exec git commit --allow-empty -m 'Empty Commit 1'
exec git checkout -b feature-2
exec git commit --allow-empty -m 'Empty Commit 2'
# Clone a bare repo from that local repo
cd ..
exec git clone --bare ${REPO} ${REPO}-bare
cd ${REPO}-bare
# Create a GitHub repository from that bare repo
exec gh repo create ${ORG}/${REPO} --private --source . --push --remote bare
# Defer repo cleanup
defer gh repo delete --yes ${ORG}/${REPO}
# Check the remote repo has both branches
exec gh api /repos/${ORG}/${REPO}/branches
stdout 'feature-1'
stdout 'feature-2'

View file

@ -8,6 +8,7 @@ import (
"fmt"
"net"
"os"
"regexp"
"strconv"
"strings"
"time"
@ -241,6 +242,9 @@ func (i *invoker) StartSSHServerWithOptions(ctx context.Context, options StartSS
return 0, "", fmt.Errorf("failed to parse SSH server port: %w", err)
}
if !isUsernameValid(response.User) {
return 0, "", fmt.Errorf("invalid username: %s", response.User)
}
return port, response.User, nil
}
@ -300,3 +304,10 @@ func (i *invoker) notifyCodespaceOfClientActivity(ctx context.Context, activity
return nil
}
func isUsernameValid(username string) bool {
// assuming valid usernames are alphanumeric, with these special characters allowed: . _ -
var validUsernamePattern = `^[a-zA-Z0-9_][-.a-zA-Z0-9_]*$`
re := regexp.MustCompile(validUsernamePattern)
return re.MatchString(username)
}

View file

@ -122,14 +122,13 @@ func runDownload(opts *Options) error {
opts.Logger.VerbosePrintf("Downloading trusted metadata for artifact %s\n\n", opts.ArtifactPath)
c := verification.FetchAttestationsConfig{
APIClient: opts.APIClient,
Digest: artifact.DigestWithAlg(),
Limit: opts.Limit,
Owner: opts.Owner,
Repo: opts.Repo,
params := verification.FetchRemoteAttestationsParams{
Digest: artifact.DigestWithAlg(),
Limit: opts.Limit,
Owner: opts.Owner,
Repo: opts.Repo,
}
attestations, err := verification.GetRemoteAttestations(c)
attestations, err := verification.GetRemoteAttestations(opts.APIClient, params)
if err != nil {
if errors.Is(err, api.ErrNoAttestations{}) {
fmt.Fprintf(opts.Logger.IO.Out, "No attestations found for %s\n", opts.ArtifactPath)

View file

@ -9,8 +9,8 @@ import (
"path/filepath"
"github.com/cli/cli/v2/pkg/cmd/attestation/api"
"github.com/cli/cli/v2/pkg/cmd/attestation/artifact"
"github.com/cli/cli/v2/pkg/cmd/attestation/artifact/oci"
"github.com/google/go-containerregistry/pkg/name"
protobundle "github.com/sigstore/protobuf-specs/gen/pb-go/bundle/v1"
"github.com/sigstore/sigstore-go/pkg/bundle"
)
@ -20,32 +20,11 @@ const SLSAPredicateV1 = "https://slsa.dev/provenance/v1"
var ErrUnrecognisedBundleExtension = errors.New("bundle file extension not supported, must be json or jsonl")
var ErrEmptyBundleFile = errors.New("provided bundle file is empty")
type FetchAttestationsConfig struct {
APIClient api.Client
BundlePath string
Digest string
Limit int
Owner string
Repo string
OCIClient oci.Client
UseBundleFromRegistry bool
NameRef name.Reference
}
func (c *FetchAttestationsConfig) IsBundleProvided() bool {
return c.BundlePath != ""
}
func GetAttestations(c FetchAttestationsConfig) ([]*api.Attestation, error) {
if c.IsBundleProvided() {
return GetLocalAttestations(c.BundlePath)
}
if c.UseBundleFromRegistry {
return GetOCIAttestations(c)
}
return GetRemoteAttestations(c)
type FetchRemoteAttestationsParams struct {
Digest string
Limit int
Owner string
Repo string
}
// GetLocalAttestations returns a slice of attestations read from a local bundle file.
@ -116,30 +95,30 @@ func loadBundlesFromJSONLinesFile(path string) ([]*api.Attestation, error) {
return attestations, nil
}
func GetRemoteAttestations(c FetchAttestationsConfig) ([]*api.Attestation, error) {
if c.APIClient == nil {
func GetRemoteAttestations(client api.Client, params FetchRemoteAttestationsParams) ([]*api.Attestation, error) {
if client == nil {
return nil, fmt.Errorf("api client must be provided")
}
// check if Repo is set first because if Repo has been set, Owner will be set using the value of Repo.
// If Repo is not set, the field will remain empty. It will not be populated using the value of Owner.
if c.Repo != "" {
attestations, err := c.APIClient.GetByRepoAndDigest(c.Repo, c.Digest, c.Limit)
if params.Repo != "" {
attestations, err := client.GetByRepoAndDigest(params.Repo, params.Digest, params.Limit)
if err != nil {
return nil, fmt.Errorf("failed to fetch attestations from %s: %w", c.Repo, err)
return nil, fmt.Errorf("failed to fetch attestations from %s: %w", params.Repo, err)
}
return attestations, nil
} else if c.Owner != "" {
attestations, err := c.APIClient.GetByOwnerAndDigest(c.Owner, c.Digest, c.Limit)
} else if params.Owner != "" {
attestations, err := client.GetByOwnerAndDigest(params.Owner, params.Digest, params.Limit)
if err != nil {
return nil, fmt.Errorf("failed to fetch attestations from %s: %w", c.Owner, err)
return nil, fmt.Errorf("failed to fetch attestations from %s: %w", params.Owner, err)
}
return attestations, nil
}
return nil, fmt.Errorf("owner or repo must be provided")
}
func GetOCIAttestations(c FetchAttestationsConfig) ([]*api.Attestation, error) {
attestations, err := c.OCIClient.GetAttestations(c.NameRef, c.Digest)
func GetOCIAttestations(client oci.Client, artifact artifact.DigestedArtifact) ([]*api.Attestation, error) {
attestations, err := client.GetAttestations(artifact.NameRef(), artifact.Digest())
if err != nil {
return nil, fmt.Errorf("failed to fetch OCI attestations: %w", err)
}

View file

@ -20,7 +20,7 @@ func VerifyCertExtensions(results []*AttestationProcessingResult, ec Enforcement
var lastErr error
for _, attestation := range results {
err := verifyCertExtensions(*attestation.VerificationResult.Signature.Certificate, ec)
err := verifyCertExtensions(*attestation.VerificationResult.Signature.Certificate, ec.Certificate)
if err == nil {
// if at least one attestation is verified, we're good as verification
// is defined as successful if at least one attestation is verified
@ -34,28 +34,23 @@ func VerifyCertExtensions(results []*AttestationProcessingResult, ec Enforcement
return lastErr
}
func verifyCertExtensions(verifiedCert certificate.Summary, criteria EnforcementCriteria) error {
sourceRepositoryOwnerURI := verifiedCert.Extensions.SourceRepositoryOwnerURI
if !strings.EqualFold(criteria.Certificate.SourceRepositoryOwnerURI, sourceRepositoryOwnerURI) {
return fmt.Errorf("expected SourceRepositoryOwnerURI to be %s, got %s", criteria.Certificate.SourceRepositoryOwnerURI, sourceRepositoryOwnerURI)
func verifyCertExtensions(given, expected certificate.Summary) error {
if !strings.EqualFold(expected.SourceRepositoryOwnerURI, given.SourceRepositoryOwnerURI) {
return fmt.Errorf("expected SourceRepositoryOwnerURI to be %s, got %s", expected.SourceRepositoryOwnerURI, given.SourceRepositoryOwnerURI)
}
// if repo is set, check the SourceRepositoryURI field
if criteria.Certificate.SourceRepositoryURI != "" {
sourceRepositoryURI := verifiedCert.Extensions.SourceRepositoryURI
if !strings.EqualFold(criteria.Certificate.SourceRepositoryURI, sourceRepositoryURI) {
return fmt.Errorf("expected SourceRepositoryURI to be %s, got %s", criteria.Certificate.SourceRepositoryURI, sourceRepositoryURI)
}
// if repo is set, compare the SourceRepositoryURI fields
if expected.SourceRepositoryURI != "" && !strings.EqualFold(expected.SourceRepositoryURI, given.SourceRepositoryURI) {
return fmt.Errorf("expected SourceRepositoryURI to be %s, got %s", expected.SourceRepositoryURI, given.SourceRepositoryURI)
}
// if issuer is anything other than the default, use the user-provided value;
// otherwise, select the appropriate default based on the tenant
certIssuer := verifiedCert.Extensions.Issuer
if !strings.EqualFold(criteria.Certificate.Issuer, certIssuer) {
if strings.Index(certIssuer, criteria.Certificate.Issuer+"/") == 0 {
return fmt.Errorf("expected Issuer to be %s, got %s -- if you have a custom OIDC issuer policy for your enterprise, use the --cert-oidc-issuer flag with your expected issuer", criteria.Certificate.Issuer, certIssuer)
// compare the OIDC issuers. If not equal, return an error depending
// on if there is a partial match
if !strings.EqualFold(expected.Issuer, given.Issuer) {
if strings.Index(given.Issuer, expected.Issuer+"/") == 0 {
return fmt.Errorf("expected Issuer to be %s, got %s -- if you have a custom OIDC issuer policy for your enterprise, use the --cert-oidc-issuer flag with your expected issuer", expected.Issuer, given.Issuer)
}
return fmt.Errorf("expected Issuer to be %s, got %s", criteria.Certificate.Issuer, certIssuer)
return fmt.Errorf("expected Issuer to be %s, got %s", expected.Issuer, given.Issuer)
}
return nil

View file

@ -0,0 +1,50 @@
package verify
import (
"fmt"
"github.com/cli/cli/v2/internal/text"
"github.com/cli/cli/v2/pkg/cmd/attestation/api"
"github.com/cli/cli/v2/pkg/cmd/attestation/artifact"
"github.com/cli/cli/v2/pkg/cmd/attestation/verification"
)
func getAttestations(o *Options, a artifact.DigestedArtifact) ([]*api.Attestation, string, error) {
if o.BundlePath != "" {
attestations, err := verification.GetLocalAttestations(o.BundlePath)
if err != nil {
msg := fmt.Sprintf("✗ Loading attestations from %s failed", a.URL)
return nil, msg, err
}
pluralAttestation := text.Pluralize(len(attestations), "attestation")
msg := fmt.Sprintf("Loaded %s from %s", pluralAttestation, o.BundlePath)
return attestations, msg, nil
}
if o.UseBundleFromRegistry {
attestations, err := verification.GetOCIAttestations(o.OCIClient, a)
if err != nil {
msg := "✗ Loading attestations from OCI registry failed"
return nil, msg, err
}
pluralAttestation := text.Pluralize(len(attestations), "attestation")
msg := fmt.Sprintf("Loaded %s from %s", pluralAttestation, o.ArtifactPath)
return attestations, msg, nil
}
params := verification.FetchRemoteAttestationsParams{
Digest: a.DigestWithAlg(),
Limit: o.Limit,
Owner: o.Owner,
Repo: o.Repo,
}
attestations, err := verification.GetRemoteAttestations(o.APIClient, params)
if err != nil {
msg := "✗ Loading attestations from GitHub API failed"
return nil, msg, err
}
pluralAttestation := text.Pluralize(len(attestations), "attestation")
msg := fmt.Sprintf("Loaded %s from GitHub API", pluralAttestation)
return attestations, msg, nil
}

View file

@ -6,7 +6,6 @@ import (
"regexp"
"github.com/cli/cli/v2/internal/ghinstance"
"github.com/cli/cli/v2/internal/text"
"github.com/cli/cli/v2/pkg/cmd/attestation/api"
"github.com/cli/cli/v2/pkg/cmd/attestation/artifact"
"github.com/cli/cli/v2/pkg/cmd/attestation/artifact/oci"
@ -222,42 +221,18 @@ func runVerify(opts *Options) error {
opts.Logger.Printf("Loaded digest %s for %s\n", artifact.DigestWithAlg(), artifact.URL)
c := verification.FetchAttestationsConfig{
APIClient: opts.APIClient,
BundlePath: opts.BundlePath,
Digest: artifact.DigestWithAlg(),
Limit: opts.Limit,
Owner: opts.Owner,
Repo: opts.Repo,
OCIClient: opts.OCIClient,
UseBundleFromRegistry: opts.UseBundleFromRegistry,
NameRef: artifact.NameRef(),
}
attestations, err := verification.GetAttestations(c)
attestations, logMsg, err := getAttestations(opts, *artifact)
if err != nil {
if ok := errors.Is(err, api.ErrNoAttestations{}); ok {
opts.Logger.Printf(opts.Logger.ColorScheme.Red("✗ No attestations found for subject %s\n"), artifact.DigestWithAlg())
return err
}
if c.IsBundleProvided() {
opts.Logger.Printf(opts.Logger.ColorScheme.Red("✗ Loading attestations from %s failed\n"), artifact.URL)
} else if c.UseBundleFromRegistry {
opts.Logger.Println(opts.Logger.ColorScheme.Red("✗ Loading attestations from OCI registry failed"))
} else {
opts.Logger.Println(opts.Logger.ColorScheme.Red("✗ Loading attestations from GitHub API failed"))
}
// Print the message signifying failure fetching attestations
opts.Logger.Println(opts.Logger.ColorScheme.Red(logMsg))
return err
}
pluralAttestation := text.Pluralize(len(attestations), "attestation")
if c.IsBundleProvided() {
opts.Logger.Printf("Loaded %s from %s\n", pluralAttestation, opts.BundlePath)
} else if c.UseBundleFromRegistry {
opts.Logger.Printf("Loaded %s from %s\n", pluralAttestation, opts.ArtifactPath)
} else {
opts.Logger.Printf("Loaded %s from GitHub API\n", pluralAttestation)
}
// Print the message signifying success fetching attestations
opts.Logger.Println(logMsg)
// Apply predicate type filter to returned attestations
filteredAttestations := verification.FilterAttestations(ec.PredicateType, attestations)

View file

@ -94,7 +94,7 @@ func NewCmdCreate(f *cmdutil.Factory, runF func(*CreateOptions) error) *cobra.Co
To create a remote repository from an existing local repository, specify the source directory with %[1]s--source%[1]s.
By default, the remote repository name will be the name of the source directory.
Pass %[1]s--push%[1]s to push any local commits to the new repository.
Pass %[1]s--push%[1]s to push any local commits to the new repository. If the repo is bare, this will mirror all refs.
For language or platform .gitignore templates to use with %[1]s--gitignore%[1]s, <https://github.com/github/gitignore>.
@ -556,11 +556,11 @@ func createFromLocal(opts *CreateOptions) error {
return err
}
isRepo, err := isLocalRepo(opts.GitClient)
repoType, err := localRepoType(opts.GitClient)
if err != nil {
return err
}
if !isRepo {
if repoType == unknown {
if repoPath == "." {
return fmt.Errorf("current directory is not a git repository. Run `git init` to initialize it")
}
@ -652,22 +652,43 @@ func createFromLocal(opts *CreateOptions) error {
// don't prompt for push if there are no commits
if opts.Interactive && committed {
msg := fmt.Sprintf("Would you like to push commits from the current branch to %q?", baseRemote)
if repoType == bare {
msg = fmt.Sprintf("Would you like to mirror all refs to %q?", baseRemote)
}
var err error
opts.Push, err = opts.Prompter.Confirm(fmt.Sprintf("Would you like to push commits from the current branch to %q?", baseRemote), true)
opts.Push, err = opts.Prompter.Confirm(msg, true)
if err != nil {
return err
}
}
if opts.Push {
if opts.Push && repoType == working {
err := opts.GitClient.Push(context.Background(), baseRemote, "HEAD")
if err != nil {
return err
}
if isTTY {
fmt.Fprintf(stdout, "%s Pushed commits to %s\n", cs.SuccessIcon(), remoteURL)
}
}
if opts.Push && repoType == bare {
cmd, err := opts.GitClient.AuthenticatedCommand(context.Background(), "push", baseRemote, "--mirror")
if err != nil {
return err
}
if err = cmd.Run(); err != nil {
return err
}
if isTTY {
fmt.Fprintf(stdout, "%s Mirrored all refs to %s\n", cs.SuccessIcon(), remoteURL)
}
}
return nil
}
@ -736,22 +757,34 @@ func hasCommits(gitClient *git.Client) (bool, error) {
return false, nil
}
// check if path is the top level directory of a git repo
func isLocalRepo(gitClient *git.Client) (bool, error) {
type repoType int
const (
unknown repoType = iota
working
bare
)
func localRepoType(gitClient *git.Client) (repoType, error) {
projectDir, projectDirErr := gitClient.GitDir(context.Background())
if projectDirErr != nil {
var execError *exec.ExitError
var execError errWithExitCode
if errors.As(projectDirErr, &execError) {
if exitCode := int(execError.ExitCode()); exitCode == 128 {
return false, nil
return unknown, nil
}
return false, projectDirErr
return unknown, projectDirErr
}
}
if projectDir != ".git" {
return false, nil
switch projectDir {
case ".":
return bare, nil
case ".git":
return working, nil
default:
return unknown, nil
}
return true, nil
}
// clone the checkout branch to specified path

View file

@ -443,6 +443,74 @@ func Test_createRun(t *testing.T) {
},
wantStdout: "✓ Created repository OWNER/REPO on GitHub\n https://github.com/OWNER/REPO\n",
},
{
name: "interactive with existing bare repository public and push",
opts: &CreateOptions{Interactive: true},
tty: true,
promptStubs: func(p *prompter.PrompterMock) {
p.ConfirmFunc = func(message string, defaultValue bool) (bool, error) {
switch message {
case "Add a remote?":
return true, nil
case `Would you like to mirror all refs to "origin"?`:
return true, nil
default:
return false, fmt.Errorf("unexpected confirm prompt: %s", message)
}
}
p.InputFunc = func(message, defaultValue string) (string, error) {
switch message {
case "Path to local repository":
return defaultValue, nil
case "Repository name":
return "REPO", nil
case "Description":
return "my new repo", nil
case "What should the new remote be called?":
return defaultValue, nil
default:
return "", fmt.Errorf("unexpected input prompt: %s", message)
}
}
p.SelectFunc = func(message, defaultValue string, options []string) (int, error) {
switch message {
case "What would you like to do?":
return prompter.IndexFor(options, "Push an existing local repository to GitHub")
case "Visibility":
return prompter.IndexFor(options, "Private")
default:
return 0, fmt.Errorf("unexpected select prompt: %s", message)
}
}
},
httpStubs: func(reg *httpmock.Registry) {
reg.Register(
httpmock.GraphQL(`query UserCurrent\b`),
httpmock.StringResponse(`{"data":{"viewer":{"login":"someuser","organizations":{"nodes": []}}}}`))
reg.Register(
httpmock.GraphQL(`mutation RepositoryCreate\b`),
httpmock.StringResponse(`
{
"data": {
"createRepository": {
"repository": {
"id": "REPOID",
"name": "REPO",
"owner": {"login":"OWNER"},
"url": "https://github.com/OWNER/REPO"
}
}
}
}`))
},
execStubs: func(cs *run.CommandStubber) {
cs.Register(`git -C . rev-parse --git-dir`, 0, ".")
cs.Register(`git -C . rev-parse HEAD`, 0, "commithash")
cs.Register(`git -C . remote add origin https://github.com/OWNER/REPO`, 0, "")
cs.Register(`git -C . push origin --mirror`, 0, "")
},
wantStdout: "✓ Created repository OWNER/REPO on GitHub\n https://github.com/OWNER/REPO\n✓ Added remote https://github.com/OWNER/REPO.git\n✓ Mirrored all refs to https://github.com/OWNER/REPO.git\n",
},
{
name: "interactive with existing repository public add remote and push",
opts: &CreateOptions{Interactive: true},
@ -696,6 +764,71 @@ func Test_createRun(t *testing.T) {
},
wantStdout: "https://github.com/OWNER/REPO\n",
},
{
name: "noninteractive create bare from source and push",
opts: &CreateOptions{
Interactive: false,
Source: ".",
Push: true,
Name: "REPO",
Visibility: "PRIVATE",
},
tty: false,
httpStubs: func(reg *httpmock.Registry) {
reg.Register(
httpmock.GraphQL(`mutation RepositoryCreate\b`),
httpmock.StringResponse(`
{
"data": {
"createRepository": {
"repository": {
"id": "REPOID",
"name": "REPO",
"owner": {"login":"OWNER"},
"url": "https://github.com/OWNER/REPO"
}
}
}
}`))
},
execStubs: func(cs *run.CommandStubber) {
cs.Register(`git -C . rev-parse --git-dir`, 0, ".")
cs.Register(`git -C . rev-parse HEAD`, 0, "commithash")
cs.Register(`git -C . remote add origin https://github.com/OWNER/REPO`, 0, "")
cs.Register(`git -C . push origin --mirror`, 0, "")
},
wantStdout: "https://github.com/OWNER/REPO\n",
},
{
name: "noninteractive create from cwd that isn't a git repo",
opts: &CreateOptions{
Interactive: false,
Source: ".",
Name: "REPO",
Visibility: "PRIVATE",
},
tty: false,
execStubs: func(cs *run.CommandStubber) {
cs.Register(`git -C . rev-parse --git-dir`, 128, "")
},
wantErr: true,
errMsg: "current directory is not a git repository. Run `git init` to initialize it",
},
{
name: "noninteractive create from cwd that isn't a git repo",
opts: &CreateOptions{
Interactive: false,
Source: "some-dir",
Name: "REPO",
Visibility: "PRIVATE",
},
tty: false,
execStubs: func(cs *run.CommandStubber) {
cs.Register(`git -C some-dir rev-parse --git-dir`, 128, "")
},
wantErr: true,
errMsg: "some-dir is not a git repository. Run `git -C \"some-dir\" init` to initialize it",
},
{
name: "noninteractive clone from scratch",
opts: &CreateOptions{
@ -856,11 +989,11 @@ func Test_createRun(t *testing.T) {
defer reg.Verify(t)
err := createRun(tt.opts)
if tt.wantErr {
assert.Error(t, err)
assert.Equal(t, tt.errMsg, err.Error())
require.Error(t, err)
assert.Contains(t, err.Error(), tt.errMsg)
return
}
assert.NoError(t, err)
require.NoError(t, err)
assert.Equal(t, tt.wantStdout, stdout.String())
assert.Equal(t, "", stderr.String())
})

View file

@ -4,6 +4,8 @@ import (
"errors"
"fmt"
"os/exec"
"strings"
"time"
"github.com/cli/cli/v2/pkg/extensions"
"github.com/cli/cli/v2/pkg/iostreams"
@ -14,10 +16,33 @@ type ExternalCommandExitError struct {
*exec.ExitError
}
type extensionReleaseInfo struct {
CurrentVersion string
LatestVersion string
Pinned bool
URL string
}
func NewCmdExtension(io *iostreams.IOStreams, em extensions.ExtensionManager, ext extensions.Extension) *cobra.Command {
updateMessageChan := make(chan *extensionReleaseInfo)
cs := io.ColorScheme()
return &cobra.Command{
Use: ext.Name(),
Short: fmt.Sprintf("Extension %s", ext.Name()),
// PreRun handles looking up whether extension has a latest version only when the command is ran.
PreRun: func(c *cobra.Command, args []string) {
go func() {
if ext.UpdateAvailable() {
updateMessageChan <- &extensionReleaseInfo{
CurrentVersion: ext.CurrentVersion(),
LatestVersion: ext.LatestVersion(),
Pinned: ext.IsPinned(),
URL: ext.URL(),
}
}
}()
},
RunE: func(c *cobra.Command, args []string) error {
args = append([]string{ext.Name()}, args...)
if _, err := em.Dispatch(args, io.In, io.Out, io.ErrOut); err != nil {
@ -29,6 +54,28 @@ func NewCmdExtension(io *iostreams.IOStreams, em extensions.ExtensionManager, ex
}
return nil
},
// PostRun handles communicating extension release information if found
PostRun: func(c *cobra.Command, args []string) {
select {
case releaseInfo := <-updateMessageChan:
if releaseInfo != nil {
stderr := io.ErrOut
fmt.Fprintf(stderr, "\n\n%s %s → %s\n",
cs.Yellowf("A new release of %s is available:", ext.Name()),
cs.Cyan(strings.TrimPrefix(releaseInfo.CurrentVersion, "v")),
cs.Cyan(strings.TrimPrefix(releaseInfo.LatestVersion, "v")))
if releaseInfo.Pinned {
fmt.Fprintf(stderr, "To upgrade, run: gh extension upgrade %s --force\n", ext.Name())
} else {
fmt.Fprintf(stderr, "To upgrade, run: gh extension upgrade %s\n", ext.Name())
}
fmt.Fprintf(stderr, "%s\n\n",
cs.Yellow(releaseInfo.URL))
}
case <-time.After(1 * time.Second):
// Bail on checking for new extension update as its taking too long
}
},
GroupID: "extension",
Annotations: map[string]string{
"skipAuthCheck": "true",

View file

@ -0,0 +1,159 @@
package root_test
import (
"io"
"testing"
"github.com/MakeNowJust/heredoc"
"github.com/cli/cli/v2/pkg/cmd/root"
"github.com/cli/cli/v2/pkg/extensions"
"github.com/cli/cli/v2/pkg/iostreams"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestNewCmdExtension_Updates(t *testing.T) {
tests := []struct {
name string
extCurrentVersion string
extIsPinned bool
extLatestVersion string
extName string
extUpdateAvailable bool
extURL string
wantStderr string
}{
{
name: "no update available",
extName: "no-update",
extUpdateAvailable: false,
extCurrentVersion: "1.0.0",
extLatestVersion: "1.0.0",
extURL: "https//github.com/dne/no-update",
},
{
name: "major update",
extName: "major-update",
extUpdateAvailable: true,
extCurrentVersion: "1.0.0",
extLatestVersion: "2.0.0",
extURL: "https//github.com/dne/major-update",
wantStderr: heredoc.Doc(`
A new release of major-update is available: 1.0.0 2.0.0
To upgrade, run: gh extension upgrade major-update
https//github.com/dne/major-update
`),
},
{
name: "major update, pinned",
extName: "major-update",
extUpdateAvailable: true,
extCurrentVersion: "1.0.0",
extLatestVersion: "2.0.0",
extIsPinned: true,
extURL: "https//github.com/dne/major-update",
wantStderr: heredoc.Doc(`
A new release of major-update is available: 1.0.0 2.0.0
To upgrade, run: gh extension upgrade major-update --force
https//github.com/dne/major-update
`),
},
{
name: "minor update",
extName: "minor-update",
extUpdateAvailable: true,
extCurrentVersion: "1.0.0",
extLatestVersion: "1.1.0",
extURL: "https//github.com/dne/minor-update",
wantStderr: heredoc.Doc(`
A new release of minor-update is available: 1.0.0 1.1.0
To upgrade, run: gh extension upgrade minor-update
https//github.com/dne/minor-update
`),
},
{
name: "minor update, pinned",
extName: "minor-update",
extUpdateAvailable: true,
extCurrentVersion: "1.0.0",
extLatestVersion: "1.1.0",
extURL: "https//github.com/dne/minor-update",
extIsPinned: true,
wantStderr: heredoc.Doc(`
A new release of minor-update is available: 1.0.0 1.1.0
To upgrade, run: gh extension upgrade minor-update --force
https//github.com/dne/minor-update
`),
},
{
name: "patch update",
extName: "patch-update",
extUpdateAvailable: true,
extCurrentVersion: "1.0.0",
extLatestVersion: "1.0.1",
extURL: "https//github.com/dne/patch-update",
wantStderr: heredoc.Doc(`
A new release of patch-update is available: 1.0.0 1.0.1
To upgrade, run: gh extension upgrade patch-update
https//github.com/dne/patch-update
`),
},
{
name: "patch update, pinned",
extName: "patch-update",
extUpdateAvailable: true,
extCurrentVersion: "1.0.0",
extLatestVersion: "1.0.1",
extURL: "https//github.com/dne/patch-update",
extIsPinned: true,
wantStderr: heredoc.Doc(`
A new release of patch-update is available: 1.0.0 1.0.1
To upgrade, run: gh extension upgrade patch-update --force
https//github.com/dne/patch-update
`),
},
}
for _, tt := range tests {
ios, _, _, stderr := iostreams.Test()
em := &extensions.ExtensionManagerMock{
DispatchFunc: func(args []string, stdin io.Reader, stdout io.Writer, stderr io.Writer) (bool, error) {
// Assume extension executed / dispatched without problems as test is focused on upgrade checking.
return true, nil
},
}
ext := &extensions.ExtensionMock{
CurrentVersionFunc: func() string {
return tt.extCurrentVersion
},
IsPinnedFunc: func() bool {
return tt.extIsPinned
},
LatestVersionFunc: func() string {
return tt.extLatestVersion
},
NameFunc: func() string {
return tt.extName
},
UpdateAvailableFunc: func() bool {
return tt.extUpdateAvailable
},
URLFunc: func() string {
return tt.extURL
},
}
cmd := root.NewCmdExtension(ios, em, ext)
_, err := cmd.ExecuteC()
require.NoError(t, err)
if tt.wantStderr == "" {
assert.Emptyf(t, stderr.String(), "executing extension command should output nothing to stderr")
} else {
assert.Containsf(t, stderr.String(), tt.wantStderr, "executing extension command should output message about upgrade to stderr")
}
}
}