feat: add copilot command
Signed-off-by: Babak K. Shandiz <babakks@github.com> Co-authored-by: Kynan Ware <bagtoad@github.com> Co-authored-by: Devraj Mehta <devm33@github.com>
This commit is contained in:
parent
0f32f2ac46
commit
39880650d5
2 changed files with 1039 additions and 0 deletions
451
pkg/cmd/copilot/copilot.go
Normal file
451
pkg/cmd/copilot/copilot.go
Normal file
|
|
@ -0,0 +1,451 @@
|
|||
package copilot
|
||||
|
||||
import (
|
||||
"archive/tar"
|
||||
"archive/zip"
|
||||
"bufio"
|
||||
"compress/gzip"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
"github.com/MakeNowJust/heredoc"
|
||||
"github.com/cli/cli/v2/internal/config"
|
||||
"github.com/cli/cli/v2/internal/prompter"
|
||||
"github.com/cli/cli/v2/internal/safepaths"
|
||||
"github.com/cli/cli/v2/internal/update"
|
||||
ghzip "github.com/cli/cli/v2/internal/zip"
|
||||
"github.com/cli/cli/v2/pkg/cmdutil"
|
||||
"github.com/cli/cli/v2/pkg/iostreams"
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
type CopilotOptions struct {
|
||||
IO *iostreams.IOStreams
|
||||
HttpClient func() (*http.Client, error)
|
||||
Prompter prompter.Prompter
|
||||
|
||||
CopilotArgs []string
|
||||
Remove bool
|
||||
}
|
||||
|
||||
func NewCmdCopilot(f *cmdutil.Factory, runF func(*CopilotOptions) error) *cobra.Command {
|
||||
opts := &CopilotOptions{
|
||||
IO: f.IOStreams,
|
||||
HttpClient: f.HttpClient,
|
||||
Prompter: f.Prompter,
|
||||
}
|
||||
|
||||
cmd := &cobra.Command{
|
||||
Use: "copilot [flags] [args]",
|
||||
Short: "Run the GitHub Copilot CLI (preview)",
|
||||
Long: heredoc.Docf(`
|
||||
Runs the GitHub Copilot CLI.
|
||||
|
||||
Executing the Copilot CLI through %[1]sgh%[1]s is currently in preview and subject to change.
|
||||
|
||||
If already installed, %[1]sgh%[1]s will execute the Copilot CLI found in your %[1]sPATH%[1]s.
|
||||
If the Copilot CLI is not installed, it will be downloaded to %[2]s.
|
||||
|
||||
Use %[1]s--remove%[1]s to remove the downloaded Copilot CLI.
|
||||
|
||||
This command is only supported on Windows, Linux, and Darwin, on amd64/x64
|
||||
or arm64 architectures.
|
||||
|
||||
To prevent %[1]sgh%[1]s from interpreting flags intended for Copilot,
|
||||
use %[1]s--%[1]s before Copilot flags and args.
|
||||
|
||||
Learn more at https://gh.io/copilot-cli
|
||||
`, "`", copilotInstallDir()),
|
||||
Example: heredoc.Doc(`
|
||||
# Download and run the Copilot CLI
|
||||
$ gh copilot
|
||||
|
||||
# Run the Copilot CLI
|
||||
$ gh copilot -p "Summarize this week's commits" --allow-tool 'shell(git)'
|
||||
|
||||
# Remove the Copilot CLI (if installed through gh)
|
||||
$ gh copilot --remove
|
||||
|
||||
# Run the Copilot CLI help command
|
||||
$ gh copilot -- --help
|
||||
`),
|
||||
DisableFlagParsing: true,
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
stopParsePos := -1
|
||||
for i, arg := range args {
|
||||
if arg == "--" {
|
||||
stopParsePos = i
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
ghArgs := args
|
||||
opts.CopilotArgs = args
|
||||
if stopParsePos >= 0 {
|
||||
ghArgs = args[:stopParsePos]
|
||||
opts.CopilotArgs = args[stopParsePos+1:] // +1 to skip the "--" itself
|
||||
}
|
||||
|
||||
if slices.Contains(ghArgs, "--help") || slices.Contains(ghArgs, "-h") {
|
||||
return cmd.Help()
|
||||
}
|
||||
|
||||
if slices.Contains(ghArgs, "--remove") {
|
||||
hasOtherArgs := len(ghArgs) > 1
|
||||
if stopParsePos >= 0 {
|
||||
hasOtherArgs = hasOtherArgs || len(opts.CopilotArgs) > 0
|
||||
}
|
||||
if hasOtherArgs {
|
||||
return cmdutil.FlagErrorf("cannot use --remove with args")
|
||||
}
|
||||
opts.Remove = true
|
||||
opts.CopilotArgs = nil
|
||||
}
|
||||
|
||||
if runF != nil {
|
||||
return runF(opts)
|
||||
}
|
||||
|
||||
return runCopilot(opts)
|
||||
},
|
||||
}
|
||||
|
||||
cmdutil.DisableAuthCheck(cmd)
|
||||
|
||||
// We add this flag, even though flag parsing is disabled for this command
|
||||
// so the flag still appears in the help text.
|
||||
cmd.Flags().Bool("remove", false, "Remove the downloaded Copilot CLI")
|
||||
return cmd
|
||||
}
|
||||
|
||||
func runCopilot(opts *CopilotOptions) error {
|
||||
if opts.Remove {
|
||||
if err := removeCopilot(copilotInstallDir()); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if opts.IO.IsStdoutTTY() {
|
||||
fmt.Fprintln(opts.IO.ErrOut, "Copilot CLI removed successfully")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
copilotPath := findCopilotBinary()
|
||||
if copilotPath == "" {
|
||||
if opts.IO.CanPrompt() {
|
||||
confirmed, err := opts.Prompter.Confirm("GitHub Copilot CLI is not installed. Would you like to install it?", true)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !confirmed {
|
||||
fmt.Fprintf(opts.IO.ErrOut, "%s Copilot CLI was not installed", opts.IO.ColorScheme().WarningIcon())
|
||||
return cmdutil.SilentError
|
||||
}
|
||||
} else if !update.IsCI() {
|
||||
fmt.Fprintf(opts.IO.ErrOut, "%s Copilot CLI not installed", opts.IO.ColorScheme().WarningIcon())
|
||||
return cmdutil.SilentError
|
||||
}
|
||||
|
||||
httpClient, err := opts.HttpClient()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
copilotPath, err = downloadCopilot(httpClient, opts.IO, copilotInstallDir(), copilotBinaryPath())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
externalCmd := exec.Command(copilotPath, opts.CopilotArgs...)
|
||||
externalCmd.Stdin = opts.IO.In
|
||||
externalCmd.Stdout = opts.IO.Out
|
||||
externalCmd.Stderr = opts.IO.ErrOut
|
||||
|
||||
if err := externalCmd.Run(); err != nil {
|
||||
if exitErr, ok := err.(*exec.ExitError); ok {
|
||||
// We terminate with os.Exit here, preserving the exit code from Copilot CLI,
|
||||
// and also preventing stdio writes by callers up the stack.
|
||||
os.Exit(exitErr.ExitCode())
|
||||
}
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
const copilotBinaryName = "copilot"
|
||||
|
||||
func copilotInstallDir() string {
|
||||
return filepath.Join(config.DataDir(), "copilot")
|
||||
}
|
||||
|
||||
func copilotBinaryPath() string {
|
||||
binaryName := copilotBinaryName
|
||||
if runtime.GOOS == "windows" {
|
||||
binaryName += ".exe"
|
||||
}
|
||||
return filepath.Join(copilotInstallDir(), binaryName)
|
||||
}
|
||||
|
||||
// findCopilotBinary returns the path to the Copilot CLI binary, if installed,
|
||||
// with the following order of precedence:
|
||||
// 1. `copilot` in the PATH
|
||||
// 2. `copilot` in gh's data directory
|
||||
//
|
||||
// If not installed, it returns an empty string.
|
||||
func findCopilotBinary() string {
|
||||
if path, err := exec.LookPath(copilotBinaryName); err == nil {
|
||||
return path
|
||||
}
|
||||
|
||||
localPath := copilotBinaryPath()
|
||||
if _, err := os.Stat(localPath); err != nil {
|
||||
return ""
|
||||
}
|
||||
return localPath
|
||||
}
|
||||
|
||||
// downloadCopilot downloads and installs the Copilot CLI to installDir.
|
||||
// It returns the path to the installed Copilot binary.
|
||||
func downloadCopilot(httpClient *http.Client, ios *iostreams.IOStreams, installDir, localPath string) (string, error) {
|
||||
platform := runtime.GOOS
|
||||
arch := runtime.GOARCH
|
||||
if arch == "amd64" {
|
||||
arch = "x64"
|
||||
}
|
||||
|
||||
if arch != "x64" && arch != "arm64" {
|
||||
return "", fmt.Errorf("unsupported architecture: %s (supported: x64, arm64)", arch)
|
||||
}
|
||||
|
||||
var archiveURL string
|
||||
var archiveName string
|
||||
var isZip bool
|
||||
switch platform {
|
||||
case "windows":
|
||||
archiveName = fmt.Sprintf("copilot-%s-%s.zip", platform, arch)
|
||||
archiveURL = fmt.Sprintf("https://github.com/github/copilot-cli/releases/latest/download/%s", archiveName)
|
||||
isZip = true
|
||||
case "linux", "darwin":
|
||||
archiveName = fmt.Sprintf("copilot-%s-%s.tar.gz", platform, arch)
|
||||
archiveURL = fmt.Sprintf("https://github.com/github/copilot-cli/releases/latest/download/%s", archiveName)
|
||||
default:
|
||||
return "", fmt.Errorf("unsupported platform: %s (supported: linux, darwin, windows)", platform)
|
||||
}
|
||||
|
||||
checksumsURL := "https://github.com/github/copilot-cli/releases/latest/download/SHA256SUMS.txt"
|
||||
|
||||
ios.StartProgressIndicatorWithLabel(fmt.Sprintf("Downloading Copilot CLI from %s", archiveURL))
|
||||
defer ios.StopProgressIndicator()
|
||||
|
||||
expectedChecksum, err := fetchExpectedChecksum(httpClient, checksumsURL, archiveName)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to fetch checksums: %w", err)
|
||||
}
|
||||
|
||||
resp, err := httpClient.Get(archiveURL)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to download: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return "", fmt.Errorf("download failed with status: %s", resp.Status)
|
||||
}
|
||||
|
||||
// Download to temp file while calculating checksum
|
||||
tmpFile, err := os.CreateTemp("", "copilot-download-*")
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to create temp file: %w", err)
|
||||
}
|
||||
defer os.Remove(tmpFile.Name())
|
||||
defer tmpFile.Close()
|
||||
|
||||
hasher := sha256.New()
|
||||
if _, err := io.Copy(tmpFile, io.TeeReader(resp.Body, hasher)); err != nil {
|
||||
return "", fmt.Errorf("failed to download: %w", err)
|
||||
}
|
||||
|
||||
ios.StopProgressIndicator()
|
||||
|
||||
// Validate checksum
|
||||
actualChecksumHex := hex.EncodeToString(hasher.Sum(nil))
|
||||
if actualChecksumHex != expectedChecksum {
|
||||
return "", fmt.Errorf("checksum mismatch: expected %s, got %s", expectedChecksum, actualChecksumHex)
|
||||
}
|
||||
|
||||
if _, err := tmpFile.Seek(0, io.SeekStart); err != nil {
|
||||
return "", fmt.Errorf("failed to seek temp file: %w", err)
|
||||
}
|
||||
|
||||
if err := os.MkdirAll(installDir, 0755); err != nil {
|
||||
return "", fmt.Errorf("failed to create install directory: %w", err)
|
||||
}
|
||||
|
||||
// Extract from the downloaded data
|
||||
if isZip {
|
||||
err = extractZip(tmpFile.Name(), installDir)
|
||||
} else {
|
||||
err = extractTarGz(tmpFile, installDir)
|
||||
}
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if _, err := os.Stat(localPath); err != nil {
|
||||
return "", fmt.Errorf("copilot binary unavailable: %w", err)
|
||||
}
|
||||
|
||||
fmt.Fprintf(ios.ErrOut, "%s Copilot CLI installed successfully\n", ios.ColorScheme().SuccessIcon())
|
||||
return localPath, nil
|
||||
}
|
||||
|
||||
// fetchExpectedChecksum downloads the SHA256SUMS.txt file and returns the expected checksum for the given archive name.
|
||||
func fetchExpectedChecksum(httpClient *http.Client, checksumsURL, archiveName string) (string, error) {
|
||||
resp, err := httpClient.Get(checksumsURL)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return "", fmt.Errorf("failed to download checksums: %s", resp.Status)
|
||||
}
|
||||
|
||||
// Parse the checksums file. Possible formats are:
|
||||
// - "<checksum> <filename>" (two whitespaces)
|
||||
// - "<checksum> <filename>"
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
fields := strings.Fields(line)
|
||||
if len(fields) >= 2 {
|
||||
checksum := fields[0]
|
||||
filename := fields[1]
|
||||
if filename == archiveName {
|
||||
return checksum, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
if err := scanner.Err(); err != nil {
|
||||
return "", fmt.Errorf("failed to read checksums: %w", err)
|
||||
}
|
||||
|
||||
return "", fmt.Errorf("checksum not found for %s", archiveName)
|
||||
}
|
||||
|
||||
// extractZip reads a ZIP archive at path and extracts its contents into destDir.
|
||||
// It returns an error if the archive cannot be read,
|
||||
// or if any file or directory within the archive cannot be created or written.
|
||||
func extractZip(path, destDir string) error {
|
||||
zipReader, err := zip.OpenReader(path)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to open zip: %w", err)
|
||||
}
|
||||
defer zipReader.Close()
|
||||
|
||||
absPath, err := safepaths.ParseAbsolute(destDir)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// As of the time of writing, ghzip.ExtractZip will safely skip files that
|
||||
// would result in path traversal. This is an issue for our use-case because
|
||||
// we want to error out before extracting if there's any such file.
|
||||
// To avoid breaking the shared ghzip.ExtractZip code that expects unsafe
|
||||
// paths to be ignored and no error produced, we pre-validate here,
|
||||
// producing an error if any such file is found.
|
||||
for _, f := range zipReader.File {
|
||||
_, err := absPath.Join(f.Name)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if err := ghzip.ExtractZip(&zipReader.Reader, absPath); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// extractTarGz reads a TAR.GZ archive from r and extracts its contents into destDir.
|
||||
// It returns an error if the archive cannot be read,
|
||||
// or if any file or directory within the archive cannot be created or written.
|
||||
func extractTarGz(r io.Reader, destDir string) error {
|
||||
gzr, err := gzip.NewReader(r)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create gzip reader: %w", err)
|
||||
}
|
||||
defer gzr.Close()
|
||||
|
||||
absDestDirPath, err := safepaths.ParseAbsolute(destDir)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
tr := tar.NewReader(gzr)
|
||||
for {
|
||||
header, err := tr.Next()
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read tar: %w", err)
|
||||
}
|
||||
|
||||
absFilePath, err := absDestDirPath.Join(header.Name)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
target := absFilePath.String()
|
||||
|
||||
if header.Typeflag == tar.TypeReg {
|
||||
if err := os.MkdirAll(filepath.Dir(target), 0755); err != nil {
|
||||
return fmt.Errorf("failed to create parent directory: %w", err)
|
||||
}
|
||||
if err := extractFile(target, os.FileMode(header.Mode)&0777, tr); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// extractFile creates a file at target with the given mode and copies content from r.
|
||||
func extractFile(target string, mode os.FileMode, r io.Reader) (err error) {
|
||||
out, err := os.OpenFile(target, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, mode)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create file: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
if cerr := out.Close(); err == nil && cerr != nil {
|
||||
err = fmt.Errorf("failed to close file: %w", cerr)
|
||||
}
|
||||
}()
|
||||
if _, err := io.Copy(out, r); err != nil {
|
||||
return fmt.Errorf("failed to write file: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func removeCopilot(installDir string) error {
|
||||
if _, err := os.Stat(installDir); os.IsNotExist(err) {
|
||||
return fmt.Errorf("failed to remove Copilot CLI: Copilot CLI not installed through `gh`")
|
||||
}
|
||||
|
||||
if err := os.RemoveAll(installDir); err != nil {
|
||||
return fmt.Errorf("failed to remove Copilot CLI: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
588
pkg/cmd/copilot/copilot_test.go
Normal file
588
pkg/cmd/copilot/copilot_test.go
Normal file
|
|
@ -0,0 +1,588 @@
|
|||
package copilot
|
||||
|
||||
import (
|
||||
"archive/tar"
|
||||
"archive/zip"
|
||||
"bytes"
|
||||
"compress/gzip"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"testing"
|
||||
|
||||
"github.com/cli/cli/v2/pkg/cmdutil"
|
||||
"github.com/cli/cli/v2/pkg/httpmock"
|
||||
"github.com/cli/cli/v2/pkg/iostreams"
|
||||
"github.com/google/shlex"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestNewCmdCopilot(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
args string
|
||||
wantOpts CopilotOptions
|
||||
wantErrString string
|
||||
wantHelp bool
|
||||
}{
|
||||
{
|
||||
name: "no argument",
|
||||
args: "",
|
||||
wantOpts: CopilotOptions{
|
||||
CopilotArgs: []string{},
|
||||
},
|
||||
wantErrString: "",
|
||||
},
|
||||
{
|
||||
name: "with arguments",
|
||||
args: "some-arg some-other-arg",
|
||||
wantOpts: CopilotOptions{
|
||||
CopilotArgs: []string{"some-arg", "some-other-arg"},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "with --remove alone",
|
||||
args: "--remove",
|
||||
wantOpts: CopilotOptions{
|
||||
Remove: true,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "with non-gh flags passed to copilot",
|
||||
args: "-p testing --something-flag",
|
||||
wantOpts: CopilotOptions{
|
||||
CopilotArgs: []string{"-p", "testing", "--something-flag"},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "with --remove and arguments",
|
||||
args: "--remove some-arg",
|
||||
wantErrString: "cannot use --remove with args",
|
||||
},
|
||||
{
|
||||
name: "with --remove passed to copilot using --",
|
||||
args: "-- --remove",
|
||||
wantOpts: CopilotOptions{
|
||||
CopilotArgs: []string{"--remove"},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "with --remove and -- alone",
|
||||
args: "--remove --",
|
||||
wantOpts: CopilotOptions{
|
||||
Remove: true,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "with --remove, some invalid arg, and --",
|
||||
args: "--remove invalid-arg --",
|
||||
wantErrString: "cannot use --remove with args",
|
||||
},
|
||||
{
|
||||
name: "with --remove and -- and random arguments",
|
||||
args: "--remove -- some-arg",
|
||||
wantErrString: "cannot use --remove with args",
|
||||
},
|
||||
{
|
||||
name: "with --help, shows gh help",
|
||||
args: "--help",
|
||||
wantErrString: "",
|
||||
wantHelp: true,
|
||||
},
|
||||
{
|
||||
name: "with --help and --, shows copilot help",
|
||||
args: "-- --help",
|
||||
wantOpts: CopilotOptions{
|
||||
CopilotArgs: []string{"--help"},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
f := &cmdutil.Factory{}
|
||||
|
||||
argv, err := shlex.Split(tt.args)
|
||||
assert.NoError(t, err)
|
||||
|
||||
var gotOpts *CopilotOptions
|
||||
cmd := NewCmdCopilot(f, func(opts *CopilotOptions) error {
|
||||
gotOpts = opts
|
||||
return nil
|
||||
})
|
||||
|
||||
cmd.SetArgs(argv)
|
||||
cmd.SetIn(&bytes.Buffer{})
|
||||
cmd.SetOut(&bytes.Buffer{})
|
||||
cmd.SetErr(&bytes.Buffer{})
|
||||
|
||||
_, err = cmd.ExecuteC()
|
||||
if tt.wantErrString != "" {
|
||||
require.EqualError(t, err, tt.wantErrString)
|
||||
return
|
||||
}
|
||||
|
||||
if tt.wantHelp {
|
||||
require.NoError(t, err)
|
||||
return
|
||||
}
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, tt.wantOpts.CopilotArgs, gotOpts.CopilotArgs, "opts.CopilotArgs not as expected")
|
||||
assert.Equal(t, tt.wantOpts.Remove, gotOpts.Remove, "opts.Remove not as expected")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRemoveCopilot(t *testing.T) {
|
||||
t.Run("removes existing install directory", func(t *testing.T) {
|
||||
// Create a temporary directory to simulate the install directory
|
||||
tmpDir := t.TempDir()
|
||||
installDir := filepath.Join(tmpDir, "copilot")
|
||||
require.NoError(t, os.MkdirAll(installDir, 0755), "failed to create test directory")
|
||||
// Create a dummy file in the directory
|
||||
dummyFile := filepath.Join(installDir, "copilot")
|
||||
require.NoError(t, os.WriteFile(dummyFile, []byte("test"), 0755), "failed to create test file")
|
||||
|
||||
err := removeCopilot(installDir)
|
||||
require.NoError(t, err, "unexpected error")
|
||||
|
||||
_, err = os.Stat(installDir)
|
||||
require.True(t, os.IsNotExist(err), "expected install directory to be removed")
|
||||
})
|
||||
|
||||
t.Run("handles non-existent directory", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
installDir := filepath.Join(tmpDir, "copilot")
|
||||
|
||||
require.ErrorContains(t, removeCopilot(installDir), "failed to remove Copilot CLI")
|
||||
})
|
||||
}
|
||||
|
||||
// createTarGzBuffer creates a tar.gz archive in memory with the given files.
|
||||
func createTarGzBuffer(t *testing.T, files map[string][]byte) []byte {
|
||||
t.Helper()
|
||||
var buf bytes.Buffer
|
||||
gw := gzip.NewWriter(&buf)
|
||||
tw := tar.NewWriter(gw)
|
||||
|
||||
for name, content := range files {
|
||||
hdr := &tar.Header{
|
||||
Name: name,
|
||||
Mode: 0755,
|
||||
Size: int64(len(content)),
|
||||
}
|
||||
require.NoError(t, tw.WriteHeader(hdr), "failed to write tar header")
|
||||
_, err := tw.Write(content)
|
||||
require.NoError(t, err, "failed to write tar content")
|
||||
}
|
||||
|
||||
require.NoError(t, tw.Close(), "failed to close tar writer")
|
||||
require.NoError(t, gw.Close(), "failed to close gzip writer")
|
||||
return buf.Bytes()
|
||||
}
|
||||
|
||||
// createZipBuffer creates a zip archive in memory with the given files.
|
||||
func createZipBuffer(t *testing.T, files map[string][]byte) []byte {
|
||||
t.Helper()
|
||||
var buf bytes.Buffer
|
||||
zw := zip.NewWriter(&buf)
|
||||
|
||||
for name, content := range files {
|
||||
fw, err := zw.Create(name)
|
||||
require.NoError(t, err, "failed to create zip entry")
|
||||
_, err = fw.Write(content)
|
||||
require.NoError(t, err, "failed to write zip content")
|
||||
}
|
||||
|
||||
require.NoError(t, zw.Close(), "failed to close zip writer")
|
||||
return buf.Bytes()
|
||||
}
|
||||
|
||||
func TestExtractTarGz(t *testing.T) {
|
||||
t.Run("extracts files correctly", func(t *testing.T) {
|
||||
content := []byte("hello world")
|
||||
archive := createTarGzBuffer(t, map[string][]byte{
|
||||
"copilot": content,
|
||||
})
|
||||
|
||||
destDir := t.TempDir()
|
||||
|
||||
err := extractTarGz(bytes.NewReader(archive), destDir)
|
||||
require.NoError(t, err, "extractTarGz() error")
|
||||
|
||||
extracted, err := os.ReadFile(filepath.Join(destDir, "copilot"))
|
||||
require.NoError(t, err, "failed to read extracted file")
|
||||
require.Equal(t, content, extracted, "extracted content mismatch")
|
||||
})
|
||||
|
||||
t.Run("extracts nested files", func(t *testing.T) {
|
||||
content := []byte("nested content")
|
||||
archive := createTarGzBuffer(t, map[string][]byte{
|
||||
"subdir/file.txt": content,
|
||||
})
|
||||
|
||||
destDir := t.TempDir()
|
||||
|
||||
err := extractTarGz(bytes.NewReader(archive), destDir)
|
||||
require.NoError(t, err, "extractTarGz() error")
|
||||
|
||||
extracted, err := os.ReadFile(filepath.Join(destDir, "subdir", "file.txt"))
|
||||
require.NoError(t, err, "failed to read extracted file")
|
||||
require.Equal(t, content, extracted, "extracted content mismatch")
|
||||
})
|
||||
|
||||
t.Run("rejects path traversal", func(t *testing.T) {
|
||||
// Manually create a malicious tar.gz with path traversal
|
||||
var buf bytes.Buffer
|
||||
gw := gzip.NewWriter(&buf)
|
||||
tw := tar.NewWriter(gw)
|
||||
|
||||
hdr := &tar.Header{
|
||||
Name: "../evil.txt",
|
||||
Mode: 0755,
|
||||
Size: 4,
|
||||
}
|
||||
_ = tw.WriteHeader(hdr)
|
||||
_, _ = tw.Write([]byte("evil"))
|
||||
_ = tw.Close()
|
||||
_ = gw.Close()
|
||||
|
||||
destDir := t.TempDir()
|
||||
|
||||
err := extractTarGz(bytes.NewReader(buf.Bytes()), destDir)
|
||||
require.Error(t, err, "expected error for path traversal, got nil")
|
||||
})
|
||||
|
||||
t.Run("handles invalid gzip", func(t *testing.T) {
|
||||
destDir := t.TempDir()
|
||||
|
||||
err := extractTarGz(bytes.NewReader([]byte("not valid gzip")), destDir)
|
||||
require.Error(t, err, "expected error for invalid gzip, got nil")
|
||||
})
|
||||
}
|
||||
|
||||
func TestExtractZip(t *testing.T) {
|
||||
t.Run("extracts files correctly", func(t *testing.T) {
|
||||
zipDir := t.TempDir()
|
||||
zipPath := filepath.Join(zipDir, "archive.zip")
|
||||
content := []byte("hello world")
|
||||
archive := createZipBuffer(t, map[string][]byte{
|
||||
"copilot.exe": content,
|
||||
})
|
||||
require.NoError(t, os.WriteFile(zipPath, archive, 0x755))
|
||||
|
||||
destDir := t.TempDir()
|
||||
|
||||
err := extractZip(zipPath, destDir)
|
||||
require.NoError(t, err, "extractZip() error")
|
||||
|
||||
extracted, err := os.ReadFile(filepath.Join(destDir, "copilot.exe"))
|
||||
require.NoError(t, err, "failed to read extracted file")
|
||||
require.Equal(t, content, extracted, "extracted content mismatch")
|
||||
})
|
||||
|
||||
t.Run("extracts nested files", func(t *testing.T) {
|
||||
zipDir := t.TempDir()
|
||||
zipPath := filepath.Join(zipDir, "archive.zip")
|
||||
content := []byte("hello world")
|
||||
archive := createZipBuffer(t, map[string][]byte{
|
||||
"subdir/file.txt": content,
|
||||
})
|
||||
require.NoError(t, os.WriteFile(zipPath, archive, 0x755))
|
||||
|
||||
destDir := t.TempDir()
|
||||
|
||||
err := extractZip(zipPath, destDir)
|
||||
require.NoError(t, err, "extractZip() error")
|
||||
|
||||
extracted, err := os.ReadFile(filepath.Join(destDir, "subdir", "file.txt"))
|
||||
require.NoError(t, err, "failed to read extracted file")
|
||||
require.Equal(t, content, extracted, "extracted content mismatch")
|
||||
})
|
||||
|
||||
t.Run("rejects path traversal", func(t *testing.T) {
|
||||
zipDir := t.TempDir()
|
||||
zipPath := filepath.Join(zipDir, "archive.zip")
|
||||
|
||||
var buf bytes.Buffer
|
||||
zw := zip.NewWriter(&buf)
|
||||
|
||||
fh := &zip.FileHeader{
|
||||
Name: "../evil.txt",
|
||||
Method: zip.Store,
|
||||
}
|
||||
fw, _ := zw.CreateHeader(fh)
|
||||
_, _ = fw.Write([]byte("evil"))
|
||||
_ = zw.Close()
|
||||
|
||||
require.NoError(t, os.WriteFile(zipPath, buf.Bytes(), 0x755))
|
||||
destDir := t.TempDir()
|
||||
|
||||
err := extractZip(zipPath, destDir)
|
||||
require.Error(t, err, "expected error for path traversal, got nil")
|
||||
})
|
||||
}
|
||||
|
||||
func TestFetchExpectedChecksum(t *testing.T) {
|
||||
t.Run("parses checksums file correctly", func(t *testing.T) {
|
||||
reg := &httpmock.Registry{}
|
||||
checksums := "abc123def456 copilot-linux-x64.tar.gz\n789xyz copilot-darwin-arm64.tar.gz\n"
|
||||
reg.Register(
|
||||
httpmock.MatchAny,
|
||||
httpmock.StringResponse(checksums),
|
||||
)
|
||||
|
||||
client := &http.Client{Transport: reg}
|
||||
checksum, err := fetchExpectedChecksum(client, "https://example.com/checksums", "copilot-linux-x64.tar.gz")
|
||||
require.NoError(t, err, "unexpected error")
|
||||
require.Equal(t, "abc123def456", checksum, "checksum mismatch")
|
||||
})
|
||||
|
||||
t.Run("returns error for missing archive", func(t *testing.T) {
|
||||
reg := &httpmock.Registry{}
|
||||
checksums := "abc123 copilot-linux-x64.tar.gz\n"
|
||||
reg.Register(
|
||||
httpmock.MatchAny,
|
||||
httpmock.StringResponse(checksums),
|
||||
)
|
||||
|
||||
client := &http.Client{Transport: reg}
|
||||
_, err := fetchExpectedChecksum(client, "https://example.com/checksums", "copilot-windows-x64.zip")
|
||||
require.Error(t, err, "expected error for missing archive")
|
||||
require.Equal(t, "checksum not found for copilot-windows-x64.zip", err.Error(), "unexpected error")
|
||||
})
|
||||
|
||||
t.Run("handles single space separator", func(t *testing.T) {
|
||||
reg := &httpmock.Registry{}
|
||||
checksums := "abc123 copilot-darwin-x64.tar.gz\n"
|
||||
reg.Register(
|
||||
httpmock.MatchAny,
|
||||
httpmock.StringResponse(checksums),
|
||||
)
|
||||
|
||||
client := &http.Client{Transport: reg}
|
||||
checksum, err := fetchExpectedChecksum(client, "https://example.com/checksums", "copilot-darwin-x64.tar.gz")
|
||||
require.NoError(t, err, "unexpected error")
|
||||
require.Equal(t, "abc123", checksum, "checksum mismatch")
|
||||
})
|
||||
|
||||
t.Run("handles HTTP error", func(t *testing.T) {
|
||||
reg := &httpmock.Registry{}
|
||||
reg.Register(
|
||||
httpmock.MatchAny,
|
||||
httpmock.StatusStringResponse(http.StatusNotFound, "not found"),
|
||||
)
|
||||
|
||||
client := &http.Client{Transport: reg}
|
||||
_, err := fetchExpectedChecksum(client, "https://example.com/checksums", "copilot-linux-x64.tar.gz")
|
||||
require.Error(t, err, "expected error for HTTP 404")
|
||||
})
|
||||
}
|
||||
|
||||
func archString() string {
|
||||
arch := runtime.GOARCH
|
||||
if arch == "amd64" {
|
||||
return "x64"
|
||||
}
|
||||
return arch
|
||||
}
|
||||
|
||||
func TestDownloadCopilot(t *testing.T) {
|
||||
// Skip on unsupported architectures
|
||||
if runtime.GOARCH != "amd64" && runtime.GOARCH != "arm64" {
|
||||
t.Skip("skipping test on unsupported architecture")
|
||||
}
|
||||
|
||||
t.Run("downloads and extracts tar.gz with valid checksum", func(t *testing.T) {
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("skipping tar.gz test on windows")
|
||||
}
|
||||
|
||||
ios, _, _, stderr := iostreams.Test()
|
||||
tmpDir := t.TempDir()
|
||||
installDir := filepath.Join(tmpDir, "copilot")
|
||||
localPath := filepath.Join(installDir, "copilot")
|
||||
|
||||
// Create mock archive with copilot binary
|
||||
binaryContent := []byte("#!/bin/sh\necho copilot")
|
||||
archive := createTarGzBuffer(t, map[string][]byte{
|
||||
"copilot": binaryContent,
|
||||
})
|
||||
|
||||
// Calculate checksum
|
||||
checksum := sha256.Sum256(archive)
|
||||
checksumHex := hex.EncodeToString(checksum[:])
|
||||
archiveName := fmt.Sprintf("copilot-%s-%s.tar.gz", runtime.GOOS, archString())
|
||||
checksumFile := fmt.Sprintf("%s %s\n", checksumHex, archiveName)
|
||||
|
||||
reg := &httpmock.Registry{}
|
||||
// Register checksum endpoint
|
||||
reg.Register(
|
||||
httpmock.REST("GET", "github/copilot-cli/releases/latest/download/SHA256SUMS.txt"),
|
||||
httpmock.StringResponse(checksumFile),
|
||||
)
|
||||
// Register archive endpoint
|
||||
reg.Register(
|
||||
httpmock.REST("GET", fmt.Sprintf("github/copilot-cli/releases/latest/download/%s", archiveName)),
|
||||
httpmock.BinaryResponse(archive),
|
||||
)
|
||||
|
||||
httpClient := &http.Client{Transport: reg}
|
||||
|
||||
path, err := downloadCopilot(httpClient, ios, installDir, localPath)
|
||||
require.NoError(t, err, "downloadCopilot() error")
|
||||
require.Equal(t, localPath, path, "downloadCopilot() path mismatch")
|
||||
|
||||
// Verify binary was extracted
|
||||
extracted, err := os.ReadFile(localPath)
|
||||
require.NoError(t, err, "failed to read extracted binary")
|
||||
require.Equal(t, binaryContent, extracted, "extracted content mismatch")
|
||||
|
||||
// Verify output messages
|
||||
require.Contains(t, stderr.String(), "installed successfully", "expected success message in stderr")
|
||||
})
|
||||
|
||||
t.Run("fails with checksum mismatch", func(t *testing.T) {
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("skipping tar.gz test on windows")
|
||||
}
|
||||
|
||||
ios, _, _, _ := iostreams.Test()
|
||||
tmpDir := t.TempDir()
|
||||
installDir := filepath.Join(tmpDir, "copilot")
|
||||
localPath := filepath.Join(installDir, "copilot")
|
||||
|
||||
binaryContent := []byte("#!/bin/sh\necho copilot")
|
||||
archive := createTarGzBuffer(t, map[string][]byte{
|
||||
"copilot": binaryContent,
|
||||
})
|
||||
|
||||
// Use wrong checksum
|
||||
archiveName := fmt.Sprintf("copilot-%s-%s.tar.gz", runtime.GOOS, archString())
|
||||
checksumFile := fmt.Sprintf("%s %s\n", "0000000000000000000000000000000000000000000000000000000000000000", archiveName)
|
||||
|
||||
reg := &httpmock.Registry{}
|
||||
reg.Register(
|
||||
httpmock.REST("GET", "github/copilot-cli/releases/latest/download/SHA256SUMS.txt"),
|
||||
httpmock.StringResponse(checksumFile),
|
||||
)
|
||||
reg.Register(
|
||||
httpmock.REST("GET", fmt.Sprintf("github/copilot-cli/releases/latest/download/%s", archiveName)),
|
||||
httpmock.BinaryResponse(archive),
|
||||
)
|
||||
|
||||
httpClient := &http.Client{Transport: reg}
|
||||
|
||||
_, err := downloadCopilot(httpClient, ios, installDir, localPath)
|
||||
require.Error(t, err, "expected error for checksum mismatch, got nil")
|
||||
require.Contains(t, err.Error(), "checksum mismatch", "expected checksum mismatch error")
|
||||
})
|
||||
|
||||
t.Run("handles HTTP error on archive download", func(t *testing.T) {
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("skipping tar.gz test on windows")
|
||||
}
|
||||
|
||||
ios, _, _, _ := iostreams.Test()
|
||||
tmpDir := t.TempDir()
|
||||
installDir := filepath.Join(tmpDir, "copilot")
|
||||
localPath := filepath.Join(installDir, "copilot")
|
||||
|
||||
archiveName := fmt.Sprintf("copilot-%s-%s.tar.gz", runtime.GOOS, archString())
|
||||
checksumFile := fmt.Sprintf("%s %s\n", "abc123", archiveName)
|
||||
|
||||
reg := &httpmock.Registry{}
|
||||
reg.Register(
|
||||
httpmock.REST("GET", "github/copilot-cli/releases/latest/download/SHA256SUMS.txt"),
|
||||
httpmock.StringResponse(checksumFile),
|
||||
)
|
||||
reg.Register(
|
||||
httpmock.REST("GET", fmt.Sprintf("github/copilot-cli/releases/latest/download/%s", archiveName)),
|
||||
httpmock.StatusStringResponse(http.StatusNotFound, "not found"),
|
||||
)
|
||||
|
||||
httpClient := &http.Client{Transport: reg}
|
||||
|
||||
_, err := downloadCopilot(httpClient, ios, installDir, localPath)
|
||||
require.Error(t, err, "expected error for HTTP 404, got nil")
|
||||
require.Contains(t, err.Error(), "download failed", "expected error to contain 'download failed'")
|
||||
})
|
||||
|
||||
t.Run("handles missing binary after extraction", func(t *testing.T) {
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("skipping tar.gz test on windows")
|
||||
}
|
||||
|
||||
ios, _, _, _ := iostreams.Test()
|
||||
tmpDir := t.TempDir()
|
||||
installDir := filepath.Join(tmpDir, "copilot")
|
||||
localPath := filepath.Join(installDir, "copilot")
|
||||
|
||||
// Create archive without the expected binary name
|
||||
archive := createTarGzBuffer(t, map[string][]byte{
|
||||
"wrong-name": []byte("content"),
|
||||
})
|
||||
|
||||
checksum := sha256.Sum256(archive)
|
||||
checksumHex := hex.EncodeToString(checksum[:])
|
||||
archiveName := fmt.Sprintf("copilot-%s-%s.tar.gz", runtime.GOOS, archString())
|
||||
checksumFile := fmt.Sprintf("%s %s\n", checksumHex, archiveName)
|
||||
|
||||
reg := &httpmock.Registry{}
|
||||
reg.Register(
|
||||
httpmock.REST("GET", "github/copilot-cli/releases/latest/download/SHA256SUMS.txt"),
|
||||
httpmock.StringResponse(checksumFile),
|
||||
)
|
||||
reg.Register(
|
||||
httpmock.REST("GET", fmt.Sprintf("github/copilot-cli/releases/latest/download/%s", archiveName)),
|
||||
httpmock.BinaryResponse(archive),
|
||||
)
|
||||
|
||||
httpClient := &http.Client{Transport: reg}
|
||||
|
||||
_, err := downloadCopilot(httpClient, ios, installDir, localPath)
|
||||
assert.ErrorContains(t, err, "copilot binary unavailable")
|
||||
})
|
||||
|
||||
t.Run("downloads and extracts zip on windows", func(t *testing.T) {
|
||||
if runtime.GOOS != "windows" {
|
||||
t.Skip("skipping zip test on non-windows")
|
||||
}
|
||||
|
||||
ios, _, _, _ := iostreams.Test()
|
||||
tmpDir := t.TempDir()
|
||||
installDir := filepath.Join(tmpDir, "copilot")
|
||||
localPath := filepath.Join(installDir, "copilot.exe")
|
||||
|
||||
binaryContent := []byte("MZ fake exe content")
|
||||
archive := createZipBuffer(t, map[string][]byte{
|
||||
"copilot.exe": binaryContent,
|
||||
})
|
||||
|
||||
checksum := sha256.Sum256(archive)
|
||||
checksumHex := hex.EncodeToString(checksum[:])
|
||||
archiveName := fmt.Sprintf("copilot-%s-%s.zip", runtime.GOOS, archString())
|
||||
checksumFile := fmt.Sprintf("%s %s\n", checksumHex, archiveName)
|
||||
|
||||
reg := &httpmock.Registry{}
|
||||
reg.Register(
|
||||
httpmock.REST("GET", "github/copilot-cli/releases/latest/download/SHA256SUMS.txt"),
|
||||
httpmock.StringResponse(checksumFile),
|
||||
)
|
||||
reg.Register(
|
||||
httpmock.REST("GET", fmt.Sprintf("github/copilot-cli/releases/latest/download/%s", archiveName)),
|
||||
httpmock.BinaryResponse(archive),
|
||||
)
|
||||
|
||||
httpClient := &http.Client{Transport: reg}
|
||||
|
||||
path, err := downloadCopilot(httpClient, ios, installDir, localPath)
|
||||
require.NoError(t, err, "downloadCopilot() error")
|
||||
require.Equal(t, localPath, path, "downloadCopilot() path mismatch")
|
||||
})
|
||||
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue