diff --git a/pkg/cmd/gist/clone/clone.go b/pkg/cmd/gist/clone/clone.go index c1dfee31b..41fa104fa 100644 --- a/pkg/cmd/gist/clone/clone.go +++ b/pkg/cmd/gist/clone/clone.go @@ -7,6 +7,7 @@ import ( "github.com/MakeNowJust/heredoc" "github.com/cli/cli/v2/git" "github.com/cli/cli/v2/internal/config" + "github.com/cli/cli/v2/internal/ghinstance" "github.com/cli/cli/v2/pkg/cmdutil" "github.com/cli/cli/v2/pkg/iostreams" "github.com/spf13/cobra" @@ -92,9 +93,15 @@ func cloneRun(opts *CloneOptions) error { } func formatRemoteURL(hostname string, gistID string, protocol string) string { + if ghinstance.IsEnterprise(hostname) { + if protocol == "ssh" { + return fmt.Sprintf("git@%s:gist/%s.git", hostname, gistID) + } + return fmt.Sprintf("https://%s/gist/%s.git", hostname, gistID) + } + if protocol == "ssh" { return fmt.Sprintf("git@gist.%s:%s.git", hostname, gistID) } - return fmt.Sprintf("https://gist.%s/%s.git", hostname, gistID) } diff --git a/pkg/cmd/gist/clone/clone_test.go b/pkg/cmd/gist/clone/clone_test.go index ebf9b5c31..ccca76b25 100644 --- a/pkg/cmd/gist/clone/clone_test.go +++ b/pkg/cmd/gist/clone/clone_test.go @@ -116,3 +116,60 @@ func Test_GistClone_flagError(t *testing.T) { t.Errorf("unexpected error %v", err) } } + +func Test_formatRemoteURL(t *testing.T) { + type args struct { + hostname string + gistID string + protocol string + } + tests := []struct { + name string + args args + want string + }{ + { + name: "github.com HTTPS", + args: args{ + hostname: "github.com", + protocol: "https", + gistID: "ID", + }, + want: "https://gist.github.com/ID.git", + }, + { + name: "github.com SSH", + args: args{ + hostname: "github.com", + protocol: "ssh", + gistID: "ID", + }, + want: "git@gist.github.com:ID.git", + }, + { + name: "Enterprise HTTPS", + args: args{ + hostname: "acme.org", + protocol: "https", + gistID: "ID", + }, + want: "https://acme.org/gist/ID.git", + }, + { + name: "Enterprise SSH", + args: args{ + hostname: "acme.org", + protocol: "ssh", + gistID: "ID", + }, + want: "git@acme.org:gist/ID.git", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := formatRemoteURL(tt.args.hostname, tt.args.gistID, tt.args.protocol); got != tt.want { + t.Errorf("formatRemoteURL() = %v, want %v", got, tt.want) + } + }) + } +}