diff --git a/acceptance/testdata/pr/pr-checkout-by-number.txtar b/acceptance/testdata/pr/pr-checkout-by-number.txtar new file mode 100644 index 000000000..374926f1d --- /dev/null +++ b/acceptance/testdata/pr/pr-checkout-by-number.txtar @@ -0,0 +1,33 @@ +# Set up env vars +env REPO=${SCRIPT_NAME}-${RANDOM_STRING} + +# Use gh as a credential helper +exec gh auth setup-git + +# Create a repository with a file so it has a default branch +exec gh repo create ${ORG}/${REPO} --add-readme --private + +# Defer repo cleanup +defer gh repo delete --yes ${ORG}/${REPO} + +# Clone the repo +exec gh repo clone ${ORG}/${REPO} + +# Prepare a branch to PR +cd ${REPO} +exec git checkout -b feature-branch +exec git commit --allow-empty -m 'Empty Commit' +exec git push -u origin feature-branch + +# Create the PR +exec gh pr create --title 'Feature Title' --body 'Feature Body' +stdout2env PR_URL + +# Remove the local branch +exec git checkout main +exec git branch -D feature-branch +stdout 'Deleted branch feature-branch' + +# Checkout the PR +exec gh pr checkout 1 +stderr 'Switched to a new branch ''feature-branch''' diff --git a/acceptance/testdata/pr/pr-checkout-with-url-from-fork.txtar b/acceptance/testdata/pr/pr-checkout-with-url-from-fork.txtar new file mode 100644 index 000000000..9a0494f4b --- /dev/null +++ b/acceptance/testdata/pr/pr-checkout-with-url-from-fork.txtar @@ -0,0 +1,37 @@ +# Set up env vars +env REPO=${SCRIPT_NAME}-${RANDOM_STRING} + +# Use gh as a credential helper +exec gh auth setup-git + +# Create a repository with a file so it has a default branch +exec gh repo create ${ORG}/${REPO} --add-readme --private + +# Defer upstream cleanup +defer gh repo delete --yes ${ORG}/${REPO} + +# Create a fork +exec gh repo fork ${ORG}/${REPO} --org ${ORG} --fork-name ${REPO}-fork + +# Defer fork cleanup +defer gh repo delete --yes ${ORG}/${REPO}-fork + +# Clone both repos +exec gh repo clone ${ORG}/${REPO} +exec gh repo clone ${ORG}/${REPO}-fork + +# Prepare a branch to PR in the fork itself +cd ${REPO}-fork +exec git checkout -b feature-branch +exec git commit --allow-empty -m 'Empty Commit' +exec git push -u origin feature-branch + +# Create the PR inside the fork +exec gh repo set-default ${ORG}/${REPO}-fork +exec gh pr create --title 'Feature Title' --body 'Feature Body' +stdout2env PR_URL + +# Checkout the PR by full URL in the upstream repo +cd ${WORK}/${REPO} +exec gh pr checkout ${PR_URL} +stderr 'Switched to branch ''feature-branch''' diff --git a/docs/install_linux.md b/docs/install_linux.md index 5943bba24..9624b4374 100644 --- a/docs/install_linux.md +++ b/docs/install_linux.md @@ -33,31 +33,39 @@ sudo apt install gh > [!NOTE] > If errors regarding GPG signatures occur, see [cli/cli#9569](https://github.com/cli/cli/issues/9569) for steps to fix this. -### Fedora, CentOS, Red Hat Enterprise Linux (dnf5) +### Fedora, CentOS, Red Hat Enterprise Linux (DNF4 & DNF5) -Install from our package repository for immediate access to latest releases: +Install from our package repository for immediate access to latest releases. + +#### DNF5 + +> [!IMPORTANT] +> **These commands apply to DNF5 only**. If you're using DNF4, please use [the DNF4 instructions](#dnf4). ```bash +# DNF5 installation commands sudo dnf install dnf5-plugins sudo dnf config-manager addrepo --from-repofile=https://cli.github.com/packages/rpm/gh-cli.repo sudo dnf install gh --repo gh-cli ``` -These commands apply for `dnf5`. If you're using `dnf4`, commands will vary slightly. +#### DNF4 -
-Show dnf4 commands +> [!IMPORTANT] +> **These commands apply to DNF4 only**. If you're using DNF5, please use [the DNF5 instructions](#dnf5). ```bash -sudo dnf4 install 'dnf-command(config-manager)' -sudo dnf4 config-manager --add-repo https://cli.github.com/packages/rpm/gh-cli.repo -sudo dnf4 install gh --repo gh-cli +# DNF4 installation commands +sudo dnf install 'dnf-command(config-manager)' +sudo dnf config-manager --add-repo https://cli.github.com/packages/rpm/gh-cli.repo +sudo dnf install gh --repo gh-cli ``` -
> [!NOTE] > If errors regarding GPG signatures occur, see [cli/cli#9569](https://github.com/cli/cli/issues/9569) for steps to fix this. +### Fedora, CentOS, Red Hat Enterprise Linux - Community repository + Alternatively, install from the [community repository](https://packages.fedoraproject.org/pkgs/gh/gh/): ```bash diff --git a/git/client.go b/git/client.go index b2cfbce45..1dea7a6d6 100644 --- a/git/client.go +++ b/git/client.go @@ -115,23 +115,23 @@ type CredentialPattern struct { var AllMatchingCredentialsPattern = CredentialPattern{allMatching: true, pattern: ""} var disallowedCredentialPattern = CredentialPattern{allMatching: false, pattern: ""} -// WM-TODO: Are there any funny remotes that might not resolve to a URL? -func CredentialPatternFromRemote(ctx context.Context, c *Client, remote string) (CredentialPattern, error) { - gitURL, err := c.GetRemoteURL(ctx, remote) - if err != nil { - return CredentialPattern{}, err - } - return CredentialPatternFromGitURL(gitURL) -} - +// CredentialPatternFromGitURL takes a git remote URL e.g. "https://github.com/cli/cli.git" or +// "git@github.com:cli/cli.git" and returns the credential pattern that should be used for it. func CredentialPatternFromGitURL(gitURL string) (CredentialPattern, error) { normalizedURL, err := ParseURL(gitURL) if err != nil { return CredentialPattern{}, fmt.Errorf("failed to parse remote URL: %w", err) } + return CredentialPatternFromHost(normalizedURL.Host), nil +} + +// CredentialPatternFromHost expects host to be in the form "github.com" and returns +// the credential pattern that should be used for it. +// It does not perform any canonicalisation e.g. "api.github.com" will not work as expected. +func CredentialPatternFromHost(host string) CredentialPattern { return CredentialPattern{ - pattern: strings.TrimSuffix(ghinstance.HostPrefix(normalizedURL.Host), "/"), - }, nil + pattern: strings.TrimSuffix(ghinstance.HostPrefix(host), "/"), + } } // AuthenticatedCommand is a wrapper around Command that included configuration to use gh @@ -202,19 +202,6 @@ func (c *Client) UpdateRemoteURL(ctx context.Context, name, url string) error { return nil } -func (c *Client) GetRemoteURL(ctx context.Context, name string) (string, error) { - args := []string{"remote", "get-url", name} - cmd, err := c.Command(ctx, args...) - if err != nil { - return "", err - } - out, err := cmd.Output() - if err != nil { - return "", err - } - return strings.TrimSpace(string(out)), nil -} - func (c *Client) SetRemoteResolution(ctx context.Context, name, resolution string) error { args := []string{"config", "--add", fmt.Sprintf("remote.%s.gh-resolved", name), resolution} cmd, err := c.Command(ctx, args...) diff --git a/git/client_test.go b/git/client_test.go index 41b651d0a..0fb7953bc 100644 --- a/git/client_test.go +++ b/git/client_test.go @@ -1564,7 +1564,7 @@ func TestCredentialPatternFromGitURL(t *testing.T) { }{ { name: "Given a well formed gitURL, it returns the corresponding CredentialPattern", - gitURL: "https://github.com/OWNER/REPO", + gitURL: "https://github.com/OWNER/REPO.git", wantCredentialPattern: CredentialPattern{ pattern: "https://github.com", allMatching: false, @@ -1591,47 +1591,25 @@ func TestCredentialPatternFromGitURL(t *testing.T) { } } -func TestCredentialPatternFromRemote(t *testing.T) { +func TestCredentialPatternFromHost(t *testing.T) { tests := []struct { name string - remote string + host string wantCredentialPattern CredentialPattern - wantErr bool }{ { - name: "Given a well formed remote, it returns the corresponding CredentialPattern", - remote: "https://github.com/OWNER/REPO", + name: "Given a well formed host, it returns the corresponding CredentialPattern", + host: "github.com", wantCredentialPattern: CredentialPattern{ pattern: "https://github.com", allMatching: false, }, }, - { - name: "Given an error from GetRemoteURL, it returns that error", - remote: "foo remote", - wantErr: true, - }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - var cmdCtx func(ctx context.Context, name string, args ...string) *exec.Cmd - if tt.wantErr { - _, cmdCtx = createCommandContext(t, 1, tt.remote, "GetRemoteURL error") - } else { - _, cmdCtx = createCommandContext(t, 0, tt.remote, "") - } - - client := Client{ - GitPath: "path/to/git", - commandContext: cmdCtx, - } - credentialPattern, err := CredentialPatternFromRemote(context.Background(), &client, tt.remote) - if tt.wantErr { - assert.ErrorContains(t, err, "GetRemoteURL error") - } else { - assert.NoError(t, err) - assert.Equal(t, tt.wantCredentialPattern, credentialPattern) - } + credentialPattern := CredentialPatternFromHost(tt.host) + require.Equal(t, tt.wantCredentialPattern, credentialPattern) }) } } diff --git a/go.mod b/go.mod index a95ccacc0..a5737cc44 100644 --- a/go.mod +++ b/go.mod @@ -18,7 +18,7 @@ require ( github.com/cpuguy83/go-md2man/v2 v2.0.5 github.com/creack/pty v1.1.24 github.com/distribution/reference v0.5.0 - github.com/gabriel-vasile/mimetype v1.4.6 + github.com/gabriel-vasile/mimetype v1.4.7 github.com/gdamore/tcell/v2 v2.5.4 github.com/google/go-cmp v0.6.0 github.com/google/go-containerregistry v0.20.2 @@ -44,10 +44,10 @@ require ( github.com/spf13/pflag v1.0.5 github.com/stretchr/testify v1.9.0 github.com/zalando/go-keyring v0.2.5 - golang.org/x/crypto v0.28.0 - golang.org/x/sync v0.8.0 - golang.org/x/term v0.25.0 - golang.org/x/text v0.19.0 + golang.org/x/crypto v0.29.0 + golang.org/x/sync v0.9.0 + golang.org/x/term v0.26.0 + golang.org/x/text v0.20.0 google.golang.org/grpc v1.64.1 google.golang.org/protobuf v1.34.2 gopkg.in/h2non/gock.v1 v1.1.2 @@ -159,8 +159,8 @@ require ( go.uber.org/zap v1.27.0 // indirect golang.org/x/exp v0.0.0-20240112132812-db7319d0e0e3 // indirect golang.org/x/mod v0.21.0 // indirect - golang.org/x/net v0.30.0 // indirect - golang.org/x/sys v0.26.0 // indirect + golang.org/x/net v0.31.0 // indirect + golang.org/x/sys v0.27.0 // indirect golang.org/x/tools v0.26.0 // indirect google.golang.org/genproto/googleapis/api v0.0.0-20240520151616-dc85e6b867a5 // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20240520151616-dc85e6b867a5 // indirect diff --git a/go.sum b/go.sum index a7d87de7d..ae81bcc33 100644 --- a/go.sum +++ b/go.sum @@ -149,8 +149,8 @@ github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHk github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0= github.com/fsnotify/fsnotify v1.7.0 h1:8JEhPFa5W2WU7YfeZzPNqzMP6Lwt7L2715Ggo0nosvA= github.com/fsnotify/fsnotify v1.7.0/go.mod h1:40Bi/Hjc2AVfZrqy+aj+yEI+/bRxZnMJyTJwOpGvigM= -github.com/gabriel-vasile/mimetype v1.4.6 h1:3+PzJTKLkvgjeTbts6msPJt4DixhT4YtFNf1gtGe3zc= -github.com/gabriel-vasile/mimetype v1.4.6/go.mod h1:JX1qVKqZd40hUPpAfiNTe0Sne7hdfKSbOqqmkq8GCXc= +github.com/gabriel-vasile/mimetype v1.4.7 h1:SKFKl7kD0RiPdbht0s7hFtjl489WcQ1VyPW8ZzUMYCA= +github.com/gabriel-vasile/mimetype v1.4.7/go.mod h1:GDlAgAyIRT27BhFl53XNAFtfjzOkLaF35JdEG0P7LtU= github.com/gdamore/encoding v1.0.0 h1:+7OoQ1Bc6eTm5niUzBa0Ctsh6JbMW6Ra+YNuAtDBdko= github.com/gdamore/encoding v1.0.0/go.mod h1:alR0ol34c49FCSBLjhosxzcPHQbf2trDkoo5dl+VrEg= github.com/gdamore/tcell/v2 v2.5.4 h1:TGU4tSjD3sCL788vFNeJnTdzpNKIw1H5dgLnJRQVv/k= @@ -488,8 +488,8 @@ go.uber.org/zap v1.27.0 h1:aJMhYGrd5QSmlpLMr2MftRKl7t8J8PTZPA732ud/XR8= go.uber.org/zap v1.27.0/go.mod h1:GB2qFLM7cTU87MWRP2mPIjqfIDnGu+VIO4V/SdhGo2E= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= -golang.org/x/crypto v0.28.0 h1:GBDwsMXVQi34v5CCYUm2jkJvu4cbtru2U4TN2PSyQnw= -golang.org/x/crypto v0.28.0/go.mod h1:rmgy+3RHxRZMyY0jjAJShp2zgEdOqj2AO7U0pYmeQ7U= +golang.org/x/crypto v0.29.0 h1:L5SG1JTTXupVV3n6sUqMTeWbjAyfPwoda2DLX8J8FrQ= +golang.org/x/crypto v0.29.0/go.mod h1:+F4F4N5hv6v38hfeYwTdx20oUvLLc+QfrE9Ax9HtgRg= golang.org/x/exp v0.0.0-20240112132812-db7319d0e0e3 h1:hNQpMuAJe5CtcUqCXaWga3FHu+kQvCqcsoVaQgSV60o= golang.org/x/exp v0.0.0-20240112132812-db7319d0e0e3/go.mod h1:idGWGoKP1toJGkd5/ig9ZLuPcZBC3ewk7SzmH0uou08= golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= @@ -498,14 +498,14 @@ golang.org/x/mod v0.21.0/go.mod h1:6SkKJ3Xj0I0BrPOZoBy3bdMptDDU9oJrpohJ3eWZ1fY= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= -golang.org/x/net v0.30.0 h1:AcW1SDZMkb8IpzCdQUaIq2sP4sZ4zw+55h6ynffypl4= -golang.org/x/net v0.30.0/go.mod h1:2wGyMJ5iFasEhkwi13ChkO/t1ECNC4X4eBKkVFyYFlU= +golang.org/x/net v0.31.0 h1:68CPQngjLL0r2AlUKiSxtQFKvzRVbnzLwMUn5SzcLHo= +golang.org/x/net v0.31.0/go.mod h1:P4fl1q7dY2hnZFxEk4pPSkDHF+QqjitcnDjUQyMM+pM= golang.org/x/oauth2 v0.22.0 h1:BzDx2FehcG7jJwgWLELCdmLuxk2i+x9UDpSiss2u0ZA= golang.org/x/oauth2 v0.22.0/go.mod h1:XYTD2NtWslqkgxebSiOHnXEap4TF09sJSc7H1sXbhtI= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.8.0 h1:3NFvSEYkUoMifnESzZl15y791HH1qU2xm6eCJU5ZPXQ= -golang.org/x/sync v0.8.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= +golang.org/x/sync v0.9.0 h1:fEo0HyrW1GIgZdpbhCRO0PkJajUS5H9IFUztCgEo2jQ= +golang.org/x/sync v0.9.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190222072716-a9d3bda3a223/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= @@ -517,19 +517,19 @@ golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220906165534-d0df966e6959/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.26.0 h1:KHjCJyddX0LoSTb3J+vWpupP9p0oznkqVk/IfjymZbo= -golang.org/x/sys v0.26.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.27.0 h1:wBqf8DvsY9Y/2P8gAfPDEYNuS30J4lPHJxXSb/nJZ+s= +golang.org/x/sys v0.27.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= -golang.org/x/term v0.25.0 h1:WtHI/ltw4NvSUig5KARz9h521QvRC8RmF/cuYqifU24= -golang.org/x/term v0.25.0/go.mod h1:RPyXicDX+6vLxogjjRxjgD2TKtmAO6NZBsBRfrOLu7M= +golang.org/x/term v0.26.0 h1:WEQa6V3Gja/BhNxg540hBip/kkaYtRg3cxg4oXSw4AU= +golang.org/x/term v0.26.0/go.mod h1:Si5m1o57C5nBNQo5z1iq+XDijt21BDBDp2bK0QI8e3E= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= golang.org/x/text v0.4.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= golang.org/x/text v0.5.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= -golang.org/x/text v0.19.0 h1:kTxAhCbGbxhK0IwgSKiMO5awPoDQ0RpfiVYBfK860YM= -golang.org/x/text v0.19.0/go.mod h1:BuEKDfySbSR4drPmRPG/7iBdf8hvFMuRexcpahXilzY= +golang.org/x/text v0.20.0 h1:gK/Kv2otX8gz+wn7Rmb3vT96ZwuoxnQlY+HlJVj7Qug= +golang.org/x/text v0.20.0/go.mod h1:D4IsuqiFMhST5bX19pQ9ikHC2GsaKyk/oF+pn3ducp4= golang.org/x/time v0.5.0 h1:o7cqy6amK/52YcAKIPlM3a+Fpj35zvRj2TP+e1xFSfk= golang.org/x/time v0.5.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= diff --git a/pkg/cmd/pr/checkout/checkout.go b/pkg/cmd/pr/checkout/checkout.go index b9ded029b..8fb591b58 100644 --- a/pkg/cmd/pr/checkout/checkout.go +++ b/pkg/cmd/pr/checkout/checkout.go @@ -157,14 +157,9 @@ func checkoutRun(opts *CheckoutOptions) error { cmdQueue = append(cmdQueue, []string{"submodule", "update", "--init", "--recursive"}) } - // Note that although we will probably be fetching from the headRemote, in practice, PR checkout can only - // ever point to one host, and we know baseRemote must be populated, where headRemote might be nil (e.g. when - // it was deleted). - credentialPattern, err := git.CredentialPatternFromRemote(context.Background(), opts.GitClient, baseRemote.Name) - if err != nil { - return err - } - err = executeCmds(opts.GitClient, credentialPattern, cmdQueue) + // Note that although we will probably be fetching from the head, in practice, PR checkout can only + // ever point to one host, and we know baseRepo must be populated. + err = executeCmds(opts.GitClient, git.CredentialPatternFromHost(baseRepo.RepoHost()), cmdQueue) if err != nil { return err } diff --git a/pkg/cmd/pr/checkout/checkout_test.go b/pkg/cmd/pr/checkout/checkout_test.go index 5c2109f20..428125bbb 100644 --- a/pkg/cmd/pr/checkout/checkout_test.go +++ b/pkg/cmd/pr/checkout/checkout_test.go @@ -101,7 +101,6 @@ func Test_checkoutRun(t *testing.T) { }, runStubs: func(cs *run.CommandStubber) { cs.Register(`git show-ref --verify -- refs/heads/feature`, 1, "") - cs.Register(`git remote get-url origin`, 0, "https://github.com/OWNER/REPO.git") cs.Register(`git fetch origin \+refs/heads/feature:refs/remotes/origin/feature`, 0, "") cs.Register(`git checkout -b feature --track origin/feature`, 0, "") }, @@ -132,7 +131,6 @@ func Test_checkoutRun(t *testing.T) { "origin": "OWNER/REPO", }, runStubs: func(cs *run.CommandStubber) { - cs.Register(`git remote get-url origin`, 0, "https://github.com/OWNER/REPO.git") cs.Register(`git fetch origin refs/pull/123/head:feature`, 0, "") cs.Register(`git config branch\.feature\.merge`, 1, "") cs.Register(`git checkout feature`, 0, "") @@ -167,7 +165,6 @@ func Test_checkoutRun(t *testing.T) { }, runStubs: func(cs *run.CommandStubber) { cs.Register(`git show-ref --verify -- refs/heads/foobar`, 1, "") - cs.Register(`git remote get-url origin`, 0, "https://github.com/OWNER/REPO.git") cs.Register(`git fetch origin \+refs/heads/feature:refs/remotes/origin/feature`, 0, "") cs.Register(`git checkout -b foobar --track origin/feature`, 0, "") }, @@ -199,7 +196,6 @@ func Test_checkoutRun(t *testing.T) { }, runStubs: func(cs *run.CommandStubber) { cs.Register(`git config branch\.foobar\.merge`, 1, "") - cs.Register(`git remote get-url origin`, 0, "https://github.com/hubot/REPO.git") cs.Register(`git fetch origin refs/pull/123/head:foobar`, 0, "") cs.Register(`git checkout foobar`, 0, "") cs.Register(`git config branch\.foobar\.remote https://github.com/hubot/REPO.git`, 0, "") @@ -386,7 +382,6 @@ func TestPRCheckout_sameRepo(t *testing.T) { cs, cmdTeardown := run.Stub() defer cmdTeardown(t) - cs.Register(`git remote get-url origin`, 0, "https://github.com/OWNER/REPO.git") cs.Register(`git fetch origin \+refs/heads/feature:refs/remotes/origin/feature`, 0, "") cs.Register(`git show-ref --verify -- refs/heads/feature`, 1, "") cs.Register(`git checkout -b feature --track origin/feature`, 0, "") @@ -406,8 +401,6 @@ func TestPRCheckout_existingBranch(t *testing.T) { cs, cmdTeardown := run.Stub() defer cmdTeardown(t) - - cs.Register(`git remote get-url origin`, 0, "https://github.com/OWNER/REPO.git") cs.Register(`git fetch origin \+refs/heads/feature:refs/remotes/origin/feature`, 0, "") cs.Register(`git show-ref --verify -- refs/heads/feature`, 0, "") cs.Register(`git checkout feature`, 0, "") @@ -440,8 +433,6 @@ func TestPRCheckout_differentRepo_remoteExists(t *testing.T) { cs, cmdTeardown := run.Stub() defer cmdTeardown(t) - - cs.Register(`git remote get-url origin`, 0, "https://github.com/OWNER/REPO.git") cs.Register(`git fetch robot-fork \+refs/heads/feature:refs/remotes/robot-fork/feature`, 0, "") cs.Register(`git show-ref --verify -- refs/heads/feature`, 1, "") cs.Register(`git checkout -b feature --track robot-fork/feature`, 0, "") @@ -462,8 +453,6 @@ func TestPRCheckout_differentRepo(t *testing.T) { cs, cmdTeardown := run.Stub() defer cmdTeardown(t) - - cs.Register(`git remote get-url origin`, 0, "https://github.com/OWNER/REPO.git") cs.Register(`git fetch origin refs/pull/123/head:feature`, 0, "") cs.Register(`git config branch\.feature\.merge`, 1, "") cs.Register(`git checkout feature`, 0, "") @@ -486,8 +475,6 @@ func TestPRCheckout_differentRepo_existingBranch(t *testing.T) { cs, cmdTeardown := run.Stub() defer cmdTeardown(t) - - cs.Register(`git remote get-url origin`, 0, "https://github.com/OWNER/REPO.git") cs.Register(`git fetch origin refs/pull/123/head:feature`, 0, "") cs.Register(`git config branch\.feature\.merge`, 0, "refs/heads/feature\n") cs.Register(`git checkout feature`, 0, "") @@ -507,8 +494,6 @@ func TestPRCheckout_detachedHead(t *testing.T) { cs, cmdTeardown := run.Stub() defer cmdTeardown(t) - - cs.Register(`git remote get-url origin`, 0, "https://github.com/OWNER/REPO.git") cs.Register(`git fetch origin refs/pull/123/head:feature`, 0, "") cs.Register(`git config branch\.feature\.merge`, 0, "refs/heads/feature\n") cs.Register(`git checkout feature`, 0, "") @@ -528,8 +513,6 @@ func TestPRCheckout_differentRepo_currentBranch(t *testing.T) { cs, cmdTeardown := run.Stub() defer cmdTeardown(t) - - cs.Register(`git remote get-url origin`, 0, "https://github.com/OWNER/REPO.git") cs.Register(`git fetch origin refs/pull/123/head`, 0, "") cs.Register(`git config branch\.feature\.merge`, 0, "refs/heads/feature\n") cs.Register(`git merge --ff-only FETCH_HEAD`, 0, "") @@ -549,8 +532,9 @@ func TestPRCheckout_differentRepo_invalidBranchName(t *testing.T) { _, cmdTeardown := run.Stub() defer cmdTeardown(t) - + output, err := runCommand(http, nil, "master", `123`, baseRepo) + assert.EqualError(t, err, `invalid branch name: "-foo"`) assert.Equal(t, "", output.Stderr()) assert.Equal(t, "", output.Stderr()) @@ -566,8 +550,6 @@ func TestPRCheckout_maintainerCanModify(t *testing.T) { cs, cmdTeardown := run.Stub() defer cmdTeardown(t) - - cs.Register(`git remote get-url origin`, 0, "https://github.com/OWNER/REPO.git") cs.Register(`git fetch origin refs/pull/123/head:feature`, 0, "") cs.Register(`git config branch\.feature\.merge`, 1, "") cs.Register(`git checkout feature`, 0, "") @@ -589,8 +571,6 @@ func TestPRCheckout_recurseSubmodules(t *testing.T) { cs, cmdTeardown := run.Stub() defer cmdTeardown(t) - - cs.Register(`git remote get-url origin`, 0, "https://github.com/OWNER/REPO.git") cs.Register(`git fetch origin \+refs/heads/feature:refs/remotes/origin/feature`, 0, "") cs.Register(`git show-ref --verify -- refs/heads/feature`, 0, "") cs.Register(`git checkout feature`, 0, "") @@ -612,8 +592,6 @@ func TestPRCheckout_force(t *testing.T) { cs, cmdTeardown := run.Stub() defer cmdTeardown(t) - - cs.Register(`git remote get-url origin`, 0, "https://github.com/OWNER/REPO.git") cs.Register(`git fetch origin \+refs/heads/feature:refs/remotes/origin/feature`, 0, "") cs.Register(`git show-ref --verify -- refs/heads/feature`, 0, "") cs.Register(`git checkout feature`, 0, "") @@ -635,9 +613,7 @@ func TestPRCheckout_detach(t *testing.T) { cs, cmdTeardown := run.Stub() defer cmdTeardown(t) - cs.Register(`git checkout --detach FETCH_HEAD`, 0, "") - cs.Register(`git remote get-url origin`, 0, "https://github.com/hubot/REPO.git") cs.Register(`git fetch origin refs/pull/123/head`, 0, "") output, err := runCommand(http, nil, "", `123 --detach`, baseRepo) diff --git a/pkg/cmd/repo/fork/fork.go b/pkg/cmd/repo/fork/fork.go index a49f5d567..73e269500 100644 --- a/pkg/cmd/repo/fork/fork.go +++ b/pkg/cmd/repo/fork/fork.go @@ -306,6 +306,10 @@ func forkRun(opts *ForkOptions) error { if err != nil { return err } + + if connectedToTerminal { + fmt.Fprintf(stderr, "%s Renamed remote %s to %s\n", cs.SuccessIcon(), cs.Bold(remoteName), cs.Bold(renameTarget)) + } } else { return fmt.Errorf("a git remote named '%s' already exists", remoteName) } diff --git a/pkg/cmd/repo/fork/fork_test.go b/pkg/cmd/repo/fork/fork_test.go index 0f94496f0..0252602fe 100644 --- a/pkg/cmd/repo/fork/fork_test.go +++ b/pkg/cmd/repo/fork/fork_test.go @@ -298,7 +298,7 @@ func TestRepoFork(t *testing.T) { return true, nil }) }, - wantErrOut: "✓ Created fork someone/REPO\n✓ Added remote origin\n", + wantErrOut: "✓ Created fork someone/REPO\n✓ Renamed remote origin to upstream\n✓ Added remote origin\n", }, { name: "implicit tty reuse existing remote", @@ -370,7 +370,7 @@ func TestRepoFork(t *testing.T) { cs.Register("git remote rename origin upstream", 0, "") cs.Register(`git remote add origin https://github.com/someone/REPO.git`, 0, "") }, - wantErrOut: "✓ Created fork someone/REPO\n✓ Added remote origin\n", + wantErrOut: "✓ Created fork someone/REPO\n✓ Renamed remote origin to upstream\n✓ Added remote origin\n", }, { name: "implicit nontty reuse existing remote", diff --git a/pkg/cmd/run/download/download.go b/pkg/cmd/run/download/download.go index 99ec45bbe..8f25e84a2 100644 --- a/pkg/cmd/run/download/download.go +++ b/pkg/cmd/run/download/download.go @@ -151,8 +151,10 @@ func runDownload(opts *DownloadOptions) error { opts.IO.StartProgressIndicator() defer opts.IO.StopProgressIndicator() - // track downloaded artifacts and avoid re-downloading any of the same name + // track downloaded artifacts and avoid re-downloading any of the same name, isolate if multiple artifacts downloaded := set.NewStringSet() + isolateArtifacts := isolateArtifacts(wantNames, wantPatterns) + for _, a := range artifacts { if a.Expired { continue @@ -165,10 +167,16 @@ func runDownload(opts *DownloadOptions) error { continue } } + destDir := opts.DestinationDir - if len(wantPatterns) != 0 || len(wantNames) != 1 { + if isolateArtifacts { destDir = filepath.Join(destDir, a.Name) } + + if !filepathDescendsFrom(destDir, opts.DestinationDir) { + return fmt.Errorf("error downloading %s: would result in path traversal", a.Name) + } + err := opts.Platform.Download(a.DownloadURL, destDir) if err != nil { return fmt.Errorf("error downloading %s: %w", a.Name, err) @@ -183,6 +191,25 @@ func runDownload(opts *DownloadOptions) error { return nil } +func isolateArtifacts(wantNames []string, wantPatterns []string) bool { + if len(wantPatterns) > 0 { + // Patterns can match multiple artifacts + return true + } + + if len(wantNames) == 0 { + // All artifacts wanted regardless what they are named + return true + } + + if len(wantNames) > 1 { + // Multiple, specific artifacts wanted + return true + } + + return false +} + func matchAnyName(names []string, name string) bool { for _, n := range names { if name == n { diff --git a/pkg/cmd/run/download/download_test.go b/pkg/cmd/run/download/download_test.go index 3c1c8f2d8..867661232 100644 --- a/pkg/cmd/run/download/download_test.go +++ b/pkg/cmd/run/download/download_test.go @@ -2,8 +2,11 @@ package download import ( "bytes" + "errors" + "fmt" "io" "net/http" + "os" "path/filepath" "testing" @@ -14,7 +17,6 @@ import ( "github.com/cli/cli/v2/pkg/iostreams" "github.com/google/shlex" "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" ) @@ -143,159 +145,584 @@ func Test_NewCmdDownload(t *testing.T) { } } +type run struct { + id string + testArtifacts []testArtifact +} + +type testArtifact struct { + artifact shared.Artifact + files []string +} + +type fakePlatform struct { + runs []run +} + +func (f *fakePlatform) List(runID string) ([]shared.Artifact, error) { + runIds := map[string]struct{}{} + if runID != "" { + runIds[runID] = struct{}{} + } else { + for _, run := range f.runs { + runIds[run.id] = struct{}{} + } + } + + var artifacts []shared.Artifact + for _, run := range f.runs { + // Skip over any runs that we aren't looking for + if _, ok := runIds[run.id]; !ok { + continue + } + + // Grab the artifacts of everything else + for _, testArtifact := range run.testArtifacts { + artifacts = append(artifacts, testArtifact.artifact) + } + } + + return artifacts, nil +} + +func (f *fakePlatform) Download(url string, dir string) error { + if err := os.MkdirAll(dir, 0755); err != nil { + return err + } + // Now to be consistent, we find the artifact with the provided URL. + // It's a bit janky to iterate the runs, to find the right artifact + // rather than keying directly to it, but it allows the setup of the + // fake platform to be declarative rather than imperative. + // Think fakePlatform { artifacts: ... } rather than fakePlatform.makeArtifactAvailable() + for _, run := range f.runs { + for _, testArtifact := range run.testArtifacts { + if testArtifact.artifact.DownloadURL == url { + for _, file := range testArtifact.files { + path := filepath.Join(dir, file) + return os.WriteFile(path, []byte{}, 0600) + } + } + } + } + + return errors.New("no artifact matches the provided URL") +} + func Test_runDownload(t *testing.T) { tests := []struct { - name string - opts DownloadOptions - mockAPI func(*mockPlatform) - promptStubs func(*prompter.MockPrompter) - wantErr string + name string + opts DownloadOptions + platform *fakePlatform + promptStubs func(*prompter.MockPrompter) + expectedFiles []string + wantErr string }{ { - name: "download non-expired", + name: "download non-expired to relative directory", opts: DownloadOptions{ RunID: "2345", DestinationDir: "./tmp", - Names: []string(nil), }, - mockAPI: func(p *mockPlatform) { - p.On("List", "2345").Return([]shared.Artifact{ + platform: &fakePlatform{ + runs: []run{ { - Name: "artifact-1", - DownloadURL: "http://download.com/artifact1.zip", - Expired: false, + id: "2345", + testArtifacts: []testArtifact{ + { + artifact: shared.Artifact{ + Name: "artifact-1", + DownloadURL: "http://download.com/artifact1.zip", + Expired: false, + }, + files: []string{ + "artifact-1-file", + }, + }, + { + artifact: shared.Artifact{ + Name: "expired-artifact", + DownloadURL: "http://download.com/expired.zip", + Expired: true, + }, + files: []string{ + "expired", + }, + }, + { + artifact: shared.Artifact{ + Name: "artifact-2", + DownloadURL: "http://download.com/artifact2.zip", + Expired: false, + }, + files: []string{ + "artifact-2-file", + }, + }, + }, }, - { - Name: "expired-artifact", - DownloadURL: "http://download.com/expired.zip", - Expired: true, - }, - { - Name: "artifact-2", - DownloadURL: "http://download.com/artifact2.zip", - Expired: false, - }, - }, nil) - p.On("Download", "http://download.com/artifact1.zip", filepath.FromSlash("tmp/artifact-1")).Return(nil) - p.On("Download", "http://download.com/artifact2.zip", filepath.FromSlash("tmp/artifact-2")).Return(nil) + }, + }, + expectedFiles: []string{ + filepath.Join("artifact-1", "artifact-1-file"), + filepath.Join("artifact-2", "artifact-2-file"), }, }, { - name: "no valid artifacts", + name: "download non-expired to absolute directory", opts: DownloadOptions{ RunID: "2345", - DestinationDir: ".", - Names: []string(nil), + DestinationDir: "/tmp", }, - mockAPI: func(p *mockPlatform) { - p.On("List", "2345").Return([]shared.Artifact{ + platform: &fakePlatform{ + runs: []run{ { - Name: "artifact-1", - DownloadURL: "http://download.com/artifact1.zip", - Expired: true, + id: "2345", + testArtifacts: []testArtifact{ + { + artifact: shared.Artifact{ + Name: "artifact-1", + DownloadURL: "http://download.com/artifact1.zip", + Expired: false, + }, + files: []string{ + "artifact-1-file", + }, + }, + { + artifact: shared.Artifact{ + Name: "expired-artifact", + DownloadURL: "http://download.com/expired.zip", + Expired: true, + }, + files: []string{ + "expired", + }, + }, + { + artifact: shared.Artifact{ + Name: "artifact-2", + DownloadURL: "http://download.com/artifact2.zip", + Expired: false, + }, + files: []string{ + "artifact-2-file", + }, + }, + }, }, - { - Name: "artifact-2", - DownloadURL: "http://download.com/artifact2.zip", - Expired: true, - }, - }, nil) + }, }, - wantErr: "no valid artifacts found to download", + expectedFiles: []string{ + filepath.Join("artifact-1", "artifact-1-file"), + filepath.Join("artifact-2", "artifact-2-file"), + }, + }, + { + name: "all artifacts are expired", + opts: DownloadOptions{ + RunID: "2345", + }, + platform: &fakePlatform{ + runs: []run{ + { + id: "2345", + testArtifacts: []testArtifact{ + { + artifact: shared.Artifact{ + Name: "artifact-1", + DownloadURL: "http://download.com/artifact1.zip", + Expired: true, + }, + files: []string{ + "artifact-1-file", + }, + }, + { + artifact: shared.Artifact{ + Name: "artifact-2", + DownloadURL: "http://download.com/artifact2.zip", + Expired: true, + }, + files: []string{ + "artifact-2-file", + }, + }, + }, + }, + }, + }, + expectedFiles: []string{}, + wantErr: "no valid artifacts found to download", }, { name: "no name matches", opts: DownloadOptions{ - RunID: "2345", - DestinationDir: ".", - Names: []string{"artifact-3"}, + RunID: "2345", + Names: []string{"artifact-3"}, }, - mockAPI: func(p *mockPlatform) { - p.On("List", "2345").Return([]shared.Artifact{ + platform: &fakePlatform{ + runs: []run{ { - Name: "artifact-1", - DownloadURL: "http://download.com/artifact1.zip", - Expired: false, + id: "2345", + testArtifacts: []testArtifact{ + { + artifact: shared.Artifact{ + Name: "artifact-1", + DownloadURL: "http://download.com/artifact1.zip", + Expired: false, + }, + files: []string{ + "artifact-1-file", + }, + }, + { + artifact: shared.Artifact{ + Name: "artifact-2", + DownloadURL: "http://download.com/artifact2.zip", + Expired: false, + }, + files: []string{ + "artifact-2-file", + }, + }, + }, }, - { - Name: "artifact-2", - DownloadURL: "http://download.com/artifact2.zip", - Expired: false, - }, - }, nil) + }, + }, + expectedFiles: []string{}, + wantErr: "no artifact matches any of the names or patterns provided", + }, + { + name: "pattern matches", + opts: DownloadOptions{ + RunID: "2345", + FilePatterns: []string{"artifact-*"}, + }, + platform: &fakePlatform{ + runs: []run{ + { + id: "2345", + testArtifacts: []testArtifact{ + { + artifact: shared.Artifact{ + Name: "artifact-1", + DownloadURL: "http://download.com/artifact1.zip", + Expired: false, + }, + files: []string{ + "artifact-1-file", + }, + }, + { + artifact: shared.Artifact{ + Name: "non-artifact-2", + DownloadURL: "http://download.com/non-artifact-2.zip", + Expired: false, + }, + files: []string{ + "non-artifact-2-file", + }, + }, + { + artifact: shared.Artifact{ + Name: "artifact-3", + DownloadURL: "http://download.com/artifact3.zip", + Expired: false, + }, + files: []string{ + "artifact-3-file", + }, + }, + }, + }, + }, + }, + expectedFiles: []string{ + filepath.Join("artifact-1", "artifact-1-file"), + filepath.Join("artifact-3", "artifact-3-file"), }, - wantErr: "no artifact matches any of the names or patterns provided", }, { name: "no pattern matches", opts: DownloadOptions{ - RunID: "2345", - DestinationDir: ".", - FilePatterns: []string{"artifiction-*"}, + RunID: "2345", + FilePatterns: []string{"artifiction-*"}, }, - mockAPI: func(p *mockPlatform) { - p.On("List", "2345").Return([]shared.Artifact{ + platform: &fakePlatform{ + runs: []run{ { - Name: "artifact-1", - DownloadURL: "http://download.com/artifact1.zip", - Expired: false, + id: "2345", + testArtifacts: []testArtifact{ + { + artifact: shared.Artifact{ + Name: "artifact-1", + DownloadURL: "http://download.com/artifact1.zip", + Expired: false, + }, + files: []string{ + "artifact-1-file", + }, + }, + { + artifact: shared.Artifact{ + Name: "artifact-2", + DownloadURL: "http://download.com/artifact2.zip", + Expired: false, + }, + files: []string{ + "artifact-2-file", + }, + }, + }, }, - { - Name: "artifact-2", - DownloadURL: "http://download.com/artifact2.zip", - Expired: false, - }, - }, nil) + }, + }, + expectedFiles: []string{}, + wantErr: "no artifact matches any of the names or patterns provided", + }, + { + name: "want specific single artifact", + opts: DownloadOptions{ + RunID: "2345", + Names: []string{"non-artifact-2"}, + }, + platform: &fakePlatform{ + runs: []run{ + { + id: "2345", + testArtifacts: []testArtifact{ + { + artifact: shared.Artifact{ + Name: "artifact-1", + DownloadURL: "http://download.com/artifact1.zip", + Expired: false, + }, + files: []string{ + "artifact-1-file", + }, + }, + { + artifact: shared.Artifact{ + Name: "non-artifact-2", + DownloadURL: "http://download.com/non-artifact-2.zip", + Expired: false, + }, + files: []string{ + "non-artifact-2-file", + }, + }, + { + artifact: shared.Artifact{ + Name: "artifact-3", + DownloadURL: "http://download.com/artifact3.zip", + Expired: false, + }, + files: []string{ + "artifact-3-file", + }, + }, + }, + }, + }, + }, + expectedFiles: []string{ + filepath.Join("non-artifact-2-file"), + }, + }, + { + name: "want specific multiple artifacts", + opts: DownloadOptions{ + RunID: "2345", + Names: []string{"artifact-1", "artifact-3"}, + }, + platform: &fakePlatform{ + runs: []run{ + { + id: "2345", + testArtifacts: []testArtifact{ + { + artifact: shared.Artifact{ + Name: "artifact-1", + DownloadURL: "http://download.com/artifact1.zip", + Expired: false, + }, + files: []string{ + "artifact-1-file", + }, + }, + { + artifact: shared.Artifact{ + Name: "non-artifact-2", + DownloadURL: "http://download.com/non-artifact-2.zip", + Expired: false, + }, + files: []string{ + "non-artifact-2-file", + }, + }, + { + artifact: shared.Artifact{ + Name: "artifact-3", + DownloadURL: "http://download.com/artifact3.zip", + Expired: false, + }, + files: []string{ + "artifact-3-file", + }, + }, + }, + }, + }, + }, + expectedFiles: []string{ + filepath.Join("artifact-1", "artifact-1-file"), + filepath.Join("artifact-3", "artifact-3-file"), + }, + }, + { + name: "avoid redownloading files of the same name", + opts: DownloadOptions{ + RunID: "2345", + }, + platform: &fakePlatform{ + runs: []run{ + { + id: "2345", + testArtifacts: []testArtifact{ + { + artifact: shared.Artifact{ + Name: "artifact-1", + DownloadURL: "http://download.com/artifact1.zip", + Expired: false, + }, + files: []string{ + "artifact-1-file", + }, + }, + { + artifact: shared.Artifact{ + Name: "artifact-1", + DownloadURL: "http://download.com/artifact2.zip", + Expired: false, + }, + files: []string{ + "artifact-2-file", + }, + }, + }, + }, + }, + }, + expectedFiles: []string{ + filepath.Join("artifact-1", "artifact-1-file"), }, - wantErr: "no artifact matches any of the names or patterns provided", }, { name: "prompt to select artifact", opts: DownloadOptions{ - RunID: "", - DoPrompt: true, - DestinationDir: ".", - Names: []string(nil), + RunID: "", + DoPrompt: true, + Names: []string(nil), }, - mockAPI: func(p *mockPlatform) { - p.On("List", "").Return([]shared.Artifact{ + platform: &fakePlatform{ + runs: []run{ { - Name: "artifact-1", - DownloadURL: "http://download.com/artifact1.zip", - Expired: false, + id: "2345", + testArtifacts: []testArtifact{ + { + artifact: shared.Artifact{ + Name: "artifact-1", + DownloadURL: "http://download.com/artifact1.zip", + Expired: false, + }, + files: []string{ + "artifact-1-file", + }, + }, + { + artifact: shared.Artifact{ + Name: "expired-artifact", + DownloadURL: "http://download.com/expired.zip", + Expired: true, + }, + files: []string{ + "expired", + }, + }, + }, }, { - Name: "expired-artifact", - DownloadURL: "http://download.com/expired.zip", - Expired: true, + id: "6789", + testArtifacts: []testArtifact{ + { + artifact: shared.Artifact{ + Name: "artifact-2", + DownloadURL: "http://download.com/artifact2.zip", + Expired: false, + }, + files: []string{ + "artifact-2-file", + }, + }, + }, }, - { - Name: "artifact-2", - DownloadURL: "http://download.com/artifact2.zip", - Expired: false, - }, - { - Name: "artifact-2", - DownloadURL: "http://download.com/artifact2.also.zip", - Expired: false, - }, - }, nil) - p.On("Download", "http://download.com/artifact2.zip", ".").Return(nil) + }, }, promptStubs: func(pm *prompter.MockPrompter) { pm.RegisterMultiSelect("Select artifacts to download:", nil, []string{"artifact-1", "artifact-2"}, func(_ string, _, opts []string) ([]int, error) { - return []int{1}, nil + for i, o := range opts { + if o == "artifact-2" { + return []int{i}, nil + } + } + return nil, fmt.Errorf("no artifact-2 found in %v", opts) }) }, + expectedFiles: []string{ + filepath.Join("artifact-2-file"), + }, + }, + { + name: "handling artifact name with path traversal exploit", + opts: DownloadOptions{ + RunID: "2345", + }, + platform: &fakePlatform{ + runs: []run{ + { + id: "2345", + testArtifacts: []testArtifact{ + { + artifact: shared.Artifact{ + Name: "..", + DownloadURL: "http://download.com/artifact1.zip", + Expired: false, + }, + files: []string{ + "etc/passwd", + }, + }, + }, + }, + }, + }, + expectedFiles: []string{}, + wantErr: "error downloading ..: would result in path traversal", }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { opts := &tt.opts + if opts.DestinationDir == "" { + opts.DestinationDir = t.TempDir() + } else { + opts.DestinationDir = filepath.Join(t.TempDir(), opts.DestinationDir) + } + ios, _, stdout, stderr := iostreams.Test() opts.IO = ios - opts.Platform = newMockPlatform(t, tt.mockAPI) + opts.Platform = tt.platform pm := prompter.NewMockPrompter(t) opts.Prompter = pm @@ -310,34 +737,31 @@ func Test_runDownload(t *testing.T) { require.NoError(t, err) } + // Check that the exact number of files exist + require.Equal(t, len(tt.expectedFiles), countFilesInDirRecursively(t, opts.DestinationDir)) + + // Then check that the exact files are correct + for _, name := range tt.expectedFiles { + require.FileExists(t, filepath.Join(opts.DestinationDir, name)) + } + assert.Equal(t, "", stdout.String()) assert.Equal(t, "", stderr.String()) }) } } -type mockPlatform struct { - mock.Mock -} +func countFilesInDirRecursively(t *testing.T, dir string) int { + t.Helper() -func newMockPlatform(t *testing.T, config func(*mockPlatform)) *mockPlatform { - m := &mockPlatform{} - m.Test(t) - t.Cleanup(func() { - m.AssertExpectations(t) - }) - if config != nil { - config(m) - } - return m -} + count := 0 + require.NoError(t, filepath.Walk(dir, func(_ string, info os.FileInfo, err error) error { + require.NoError(t, err) + if !info.IsDir() { + count++ + } + return nil + })) -func (p *mockPlatform) List(runID string) ([]shared.Artifact, error) { - args := p.Called(runID) - return args.Get(0).([]shared.Artifact), args.Error(1) -} - -func (p *mockPlatform) Download(url string, dir string) error { - args := p.Called(url, dir) - return args.Error(0) + return count } diff --git a/pkg/cmd/run/download/zip.go b/pkg/cmd/run/download/zip.go index ab5723e94..a68b75fd6 100644 --- a/pkg/cmd/run/download/zip.go +++ b/pkg/cmd/run/download/zip.go @@ -71,13 +71,25 @@ func getPerm(m os.FileMode) os.FileMode { } func filepathDescendsFrom(p, dir string) bool { + // Regardless of the logic below, `p` is never allowed to be current directory `.` or parent directory `..` + // however we check explicitly here before filepath.Rel() which doesn't cover all cases. p = filepath.Clean(p) - dir = filepath.Clean(dir) - if dir == "." && !filepath.IsAbs(p) { - return !strings.HasPrefix(p, ".."+string(filepath.Separator)) + + if p == "." || p == ".." { + return false } - if !strings.HasSuffix(dir, string(filepath.Separator)) { - dir += string(filepath.Separator) + + // filepathDescendsFrom() takes advantage of filepath.Rel() to determine if `p` is descended from `dir`: + // + // 1. filepath.Rel() calculates a path to traversal from fictious `dir` to `p`. + // 2. filepath.Rel() errors in a handful of cases where absolute and relative paths are compared as well as certain traversal edge cases + // For more information, https://github.com/golang/go/blob/00709919d09904b17cfe3bfeb35521cbd3fb04f8/src/path/filepath/path_test.go#L1510-L1515 + // 3. If the path to traverse `dir` to `p` requires `..`, then we know it is not descend from / contained in `dir` + // + // As-is, this function requires the caller to ensure `p` and `dir` are either 1) both relative or 2) both absolute. + relativePath, err := filepath.Rel(dir, p) + if err != nil { + return false } - return strings.HasPrefix(p, dir) + return !strings.HasPrefix(relativePath, "..") } diff --git a/pkg/cmd/run/download/zip_test.go b/pkg/cmd/run/download/zip_test.go index ca401cdb9..b85122ec5 100644 --- a/pkg/cmd/run/download/zip_test.go +++ b/pkg/cmd/run/download/zip_test.go @@ -130,6 +130,86 @@ func Test_filepathDescendsFrom(t *testing.T) { }, want: false, }, + { + name: "deny parent directory filename (`..`) escaping absolute directory", + args: args{ + p: filepath.FromSlash(".."), + dir: filepath.FromSlash("/var/logs/"), + }, + want: false, + }, + { + name: "deny parent directory filename (`..`) escaping current directory", + args: args{ + p: filepath.FromSlash(".."), + dir: filepath.FromSlash("."), + }, + want: false, + }, + { + name: "deny parent directory filename (`..`) escaping parent directory", + args: args{ + p: filepath.FromSlash(".."), + dir: filepath.FromSlash(".."), + }, + want: false, + }, + { + name: "deny parent directory filename (`..`) escaping relative directory", + args: args{ + p: filepath.FromSlash(".."), + dir: filepath.FromSlash("relative-dir"), + }, + want: false, + }, + { + name: "deny current directory filename (`.`) in absolute directory", + args: args{ + p: filepath.FromSlash("."), + dir: filepath.FromSlash("/var/logs/"), + }, + want: false, + }, + { + name: "deny current directory filename (`.`) in current directory", + args: args{ + p: filepath.FromSlash("."), + dir: filepath.FromSlash("."), + }, + want: false, + }, + { + name: "deny current directory filename (`.`) in parent directory", + args: args{ + p: filepath.FromSlash("."), + dir: filepath.FromSlash(".."), + }, + want: false, + }, + { + name: "deny current directory filename (`.`) in relative directory", + args: args{ + p: filepath.FromSlash("."), + dir: filepath.FromSlash("relative-dir"), + }, + want: false, + }, + { + name: "relative path, absolute dir", + args: args{ + p: filepath.FromSlash("whatever"), + dir: filepath.FromSlash("/a/b/c"), + }, + want: false, + }, + { + name: "absolute path, relative dir", + args: args{ + p: filepath.FromSlash("/a/b/c"), + dir: filepath.FromSlash("whatever"), + }, + want: false, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) {