diff --git a/pkg/cmd/copilot/copilot.go b/pkg/cmd/copilot/copilot.go new file mode 100644 index 000000000..1684e507f --- /dev/null +++ b/pkg/cmd/copilot/copilot.go @@ -0,0 +1,451 @@ +package copilot + +import ( + "archive/tar" + "archive/zip" + "bufio" + "compress/gzip" + "crypto/sha256" + "encoding/hex" + "fmt" + "io" + "net/http" + "os" + "os/exec" + "path/filepath" + "runtime" + "slices" + "strings" + + "github.com/MakeNowJust/heredoc" + "github.com/cli/cli/v2/internal/config" + "github.com/cli/cli/v2/internal/prompter" + "github.com/cli/cli/v2/internal/safepaths" + "github.com/cli/cli/v2/internal/update" + ghzip "github.com/cli/cli/v2/internal/zip" + "github.com/cli/cli/v2/pkg/cmdutil" + "github.com/cli/cli/v2/pkg/iostreams" + "github.com/spf13/cobra" +) + +type CopilotOptions struct { + IO *iostreams.IOStreams + HttpClient func() (*http.Client, error) + Prompter prompter.Prompter + + CopilotArgs []string + Remove bool +} + +func NewCmdCopilot(f *cmdutil.Factory, runF func(*CopilotOptions) error) *cobra.Command { + opts := &CopilotOptions{ + IO: f.IOStreams, + HttpClient: f.HttpClient, + Prompter: f.Prompter, + } + + cmd := &cobra.Command{ + Use: "copilot [flags] [args]", + Short: "Run the GitHub Copilot CLI (preview)", + Long: heredoc.Docf(` + Runs the GitHub Copilot CLI. + + Executing the Copilot CLI through %[1]sgh%[1]s is currently in preview and subject to change. + + If already installed, %[1]sgh%[1]s will execute the Copilot CLI found in your %[1]sPATH%[1]s. + If the Copilot CLI is not installed, it will be downloaded to %[2]s. + + Use %[1]s--remove%[1]s to remove the downloaded Copilot CLI. + + This command is only supported on Windows, Linux, and Darwin, on amd64/x64 + or arm64 architectures. + + To prevent %[1]sgh%[1]s from interpreting flags intended for Copilot, + use %[1]s--%[1]s before Copilot flags and args. + + Learn more at https://gh.io/copilot-cli + `, "`", copilotInstallDir()), + Example: heredoc.Doc(` + # Download and run the Copilot CLI + $ gh copilot + + # Run the Copilot CLI + $ gh copilot -p "Summarize this week's commits" --allow-tool 'shell(git)' + + # Remove the Copilot CLI (if installed through gh) + $ gh copilot --remove + + # Run the Copilot CLI help command + $ gh copilot -- --help + `), + DisableFlagParsing: true, + RunE: func(cmd *cobra.Command, args []string) error { + stopParsePos := -1 + for i, arg := range args { + if arg == "--" { + stopParsePos = i + break + } + } + + ghArgs := args + opts.CopilotArgs = args + if stopParsePos >= 0 { + ghArgs = args[:stopParsePos] + opts.CopilotArgs = args[stopParsePos+1:] // +1 to skip the "--" itself + } + + if slices.Contains(ghArgs, "--help") || slices.Contains(ghArgs, "-h") { + return cmd.Help() + } + + if slices.Contains(ghArgs, "--remove") { + hasOtherArgs := len(ghArgs) > 1 + if stopParsePos >= 0 { + hasOtherArgs = hasOtherArgs || len(opts.CopilotArgs) > 0 + } + if hasOtherArgs { + return cmdutil.FlagErrorf("cannot use --remove with args") + } + opts.Remove = true + opts.CopilotArgs = nil + } + + if runF != nil { + return runF(opts) + } + + return runCopilot(opts) + }, + } + + cmdutil.DisableAuthCheck(cmd) + + // We add this flag, even though flag parsing is disabled for this command + // so the flag still appears in the help text. + cmd.Flags().Bool("remove", false, "Remove the downloaded Copilot CLI") + return cmd +} + +func runCopilot(opts *CopilotOptions) error { + if opts.Remove { + if err := removeCopilot(copilotInstallDir()); err != nil { + return err + } + + if opts.IO.IsStdoutTTY() { + fmt.Fprintln(opts.IO.ErrOut, "Copilot CLI removed successfully") + } + return nil + } + + copilotPath := findCopilotBinary() + if copilotPath == "" { + if opts.IO.CanPrompt() { + confirmed, err := opts.Prompter.Confirm("GitHub Copilot CLI is not installed. Would you like to install it?", true) + if err != nil { + return err + } + if !confirmed { + fmt.Fprintf(opts.IO.ErrOut, "%s Copilot CLI was not installed", opts.IO.ColorScheme().WarningIcon()) + return cmdutil.SilentError + } + } else if !update.IsCI() { + fmt.Fprintf(opts.IO.ErrOut, "%s Copilot CLI not installed", opts.IO.ColorScheme().WarningIcon()) + return cmdutil.SilentError + } + + httpClient, err := opts.HttpClient() + if err != nil { + return err + } + + copilotPath, err = downloadCopilot(httpClient, opts.IO, copilotInstallDir(), copilotBinaryPath()) + if err != nil { + return err + } + } + + externalCmd := exec.Command(copilotPath, opts.CopilotArgs...) + externalCmd.Stdin = opts.IO.In + externalCmd.Stdout = opts.IO.Out + externalCmd.Stderr = opts.IO.ErrOut + + if err := externalCmd.Run(); err != nil { + if exitErr, ok := err.(*exec.ExitError); ok { + // We terminate with os.Exit here, preserving the exit code from Copilot CLI, + // and also preventing stdio writes by callers up the stack. + os.Exit(exitErr.ExitCode()) + } + return err + } + return nil +} + +const copilotBinaryName = "copilot" + +func copilotInstallDir() string { + return filepath.Join(config.DataDir(), "copilot") +} + +func copilotBinaryPath() string { + binaryName := copilotBinaryName + if runtime.GOOS == "windows" { + binaryName += ".exe" + } + return filepath.Join(copilotInstallDir(), binaryName) +} + +// findCopilotBinary returns the path to the Copilot CLI binary, if installed, +// with the following order of precedence: +// 1. `copilot` in the PATH +// 2. `copilot` in gh's data directory +// +// If not installed, it returns an empty string. +func findCopilotBinary() string { + if path, err := exec.LookPath(copilotBinaryName); err == nil { + return path + } + + localPath := copilotBinaryPath() + if _, err := os.Stat(localPath); err != nil { + return "" + } + return localPath +} + +// downloadCopilot downloads and installs the Copilot CLI to installDir. +// It returns the path to the installed Copilot binary. +func downloadCopilot(httpClient *http.Client, ios *iostreams.IOStreams, installDir, localPath string) (string, error) { + platform := runtime.GOOS + arch := runtime.GOARCH + if arch == "amd64" { + arch = "x64" + } + + if arch != "x64" && arch != "arm64" { + return "", fmt.Errorf("unsupported architecture: %s (supported: x64, arm64)", arch) + } + + var archiveURL string + var archiveName string + var isZip bool + switch platform { + case "windows": + archiveName = fmt.Sprintf("copilot-%s-%s.zip", platform, arch) + archiveURL = fmt.Sprintf("https://github.com/github/copilot-cli/releases/latest/download/%s", archiveName) + isZip = true + case "linux", "darwin": + archiveName = fmt.Sprintf("copilot-%s-%s.tar.gz", platform, arch) + archiveURL = fmt.Sprintf("https://github.com/github/copilot-cli/releases/latest/download/%s", archiveName) + default: + return "", fmt.Errorf("unsupported platform: %s (supported: linux, darwin, windows)", platform) + } + + checksumsURL := "https://github.com/github/copilot-cli/releases/latest/download/SHA256SUMS.txt" + + ios.StartProgressIndicatorWithLabel(fmt.Sprintf("Downloading Copilot CLI from %s", archiveURL)) + defer ios.StopProgressIndicator() + + expectedChecksum, err := fetchExpectedChecksum(httpClient, checksumsURL, archiveName) + if err != nil { + return "", fmt.Errorf("failed to fetch checksums: %w", err) + } + + resp, err := httpClient.Get(archiveURL) + if err != nil { + return "", fmt.Errorf("failed to download: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return "", fmt.Errorf("download failed with status: %s", resp.Status) + } + + // Download to temp file while calculating checksum + tmpFile, err := os.CreateTemp("", "copilot-download-*") + if err != nil { + return "", fmt.Errorf("failed to create temp file: %w", err) + } + defer os.Remove(tmpFile.Name()) + defer tmpFile.Close() + + hasher := sha256.New() + if _, err := io.Copy(tmpFile, io.TeeReader(resp.Body, hasher)); err != nil { + return "", fmt.Errorf("failed to download: %w", err) + } + + ios.StopProgressIndicator() + + // Validate checksum + actualChecksumHex := hex.EncodeToString(hasher.Sum(nil)) + if actualChecksumHex != expectedChecksum { + return "", fmt.Errorf("checksum mismatch: expected %s, got %s", expectedChecksum, actualChecksumHex) + } + + if _, err := tmpFile.Seek(0, io.SeekStart); err != nil { + return "", fmt.Errorf("failed to seek temp file: %w", err) + } + + if err := os.MkdirAll(installDir, 0755); err != nil { + return "", fmt.Errorf("failed to create install directory: %w", err) + } + + // Extract from the downloaded data + if isZip { + err = extractZip(tmpFile.Name(), installDir) + } else { + err = extractTarGz(tmpFile, installDir) + } + if err != nil { + return "", err + } + + if _, err := os.Stat(localPath); err != nil { + return "", fmt.Errorf("copilot binary unavailable: %w", err) + } + + fmt.Fprintf(ios.ErrOut, "%s Copilot CLI installed successfully\n", ios.ColorScheme().SuccessIcon()) + return localPath, nil +} + +// fetchExpectedChecksum downloads the SHA256SUMS.txt file and returns the expected checksum for the given archive name. +func fetchExpectedChecksum(httpClient *http.Client, checksumsURL, archiveName string) (string, error) { + resp, err := httpClient.Get(checksumsURL) + if err != nil { + return "", err + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return "", fmt.Errorf("failed to download checksums: %s", resp.Status) + } + + // Parse the checksums file. Possible formats are: + // - " " (two whitespaces) + // - " " + scanner := bufio.NewScanner(resp.Body) + for scanner.Scan() { + line := scanner.Text() + fields := strings.Fields(line) + if len(fields) >= 2 { + checksum := fields[0] + filename := fields[1] + if filename == archiveName { + return checksum, nil + } + } + } + if err := scanner.Err(); err != nil { + return "", fmt.Errorf("failed to read checksums: %w", err) + } + + return "", fmt.Errorf("checksum not found for %s", archiveName) +} + +// extractZip reads a ZIP archive at path and extracts its contents into destDir. +// It returns an error if the archive cannot be read, +// or if any file or directory within the archive cannot be created or written. +func extractZip(path, destDir string) error { + zipReader, err := zip.OpenReader(path) + if err != nil { + return fmt.Errorf("failed to open zip: %w", err) + } + defer zipReader.Close() + + absPath, err := safepaths.ParseAbsolute(destDir) + if err != nil { + return err + } + + // As of the time of writing, ghzip.ExtractZip will safely skip files that + // would result in path traversal. This is an issue for our use-case because + // we want to error out before extracting if there's any such file. + // To avoid breaking the shared ghzip.ExtractZip code that expects unsafe + // paths to be ignored and no error produced, we pre-validate here, + // producing an error if any such file is found. + for _, f := range zipReader.File { + _, err := absPath.Join(f.Name) + if err != nil { + return err + } + } + + if err := ghzip.ExtractZip(&zipReader.Reader, absPath); err != nil { + return err + } + + return nil +} + +// extractTarGz reads a TAR.GZ archive from r and extracts its contents into destDir. +// It returns an error if the archive cannot be read, +// or if any file or directory within the archive cannot be created or written. +func extractTarGz(r io.Reader, destDir string) error { + gzr, err := gzip.NewReader(r) + if err != nil { + return fmt.Errorf("failed to create gzip reader: %w", err) + } + defer gzr.Close() + + absDestDirPath, err := safepaths.ParseAbsolute(destDir) + if err != nil { + return err + } + + tr := tar.NewReader(gzr) + for { + header, err := tr.Next() + if err == io.EOF { + break + } + if err != nil { + return fmt.Errorf("failed to read tar: %w", err) + } + + absFilePath, err := absDestDirPath.Join(header.Name) + if err != nil { + return err + } + target := absFilePath.String() + + if header.Typeflag == tar.TypeReg { + if err := os.MkdirAll(filepath.Dir(target), 0755); err != nil { + return fmt.Errorf("failed to create parent directory: %w", err) + } + if err := extractFile(target, os.FileMode(header.Mode)&0777, tr); err != nil { + return err + } + } + } + return nil +} + +// extractFile creates a file at target with the given mode and copies content from r. +func extractFile(target string, mode os.FileMode, r io.Reader) (err error) { + out, err := os.OpenFile(target, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, mode) + if err != nil { + return fmt.Errorf("failed to create file: %w", err) + } + defer func() { + if cerr := out.Close(); err == nil && cerr != nil { + err = fmt.Errorf("failed to close file: %w", cerr) + } + }() + if _, err := io.Copy(out, r); err != nil { + return fmt.Errorf("failed to write file: %w", err) + } + return nil +} + +func removeCopilot(installDir string) error { + if _, err := os.Stat(installDir); os.IsNotExist(err) { + return fmt.Errorf("failed to remove Copilot CLI: Copilot CLI not installed through `gh`") + } + + if err := os.RemoveAll(installDir); err != nil { + return fmt.Errorf("failed to remove Copilot CLI: %w", err) + } + + return nil +} diff --git a/pkg/cmd/copilot/copilot_test.go b/pkg/cmd/copilot/copilot_test.go new file mode 100644 index 000000000..94e5f8b9f --- /dev/null +++ b/pkg/cmd/copilot/copilot_test.go @@ -0,0 +1,588 @@ +package copilot + +import ( + "archive/tar" + "archive/zip" + "bytes" + "compress/gzip" + "crypto/sha256" + "encoding/hex" + "fmt" + "net/http" + "os" + "path/filepath" + "runtime" + "testing" + + "github.com/cli/cli/v2/pkg/cmdutil" + "github.com/cli/cli/v2/pkg/httpmock" + "github.com/cli/cli/v2/pkg/iostreams" + "github.com/google/shlex" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestNewCmdCopilot(t *testing.T) { + tests := []struct { + name string + args string + wantOpts CopilotOptions + wantErrString string + wantHelp bool + }{ + { + name: "no argument", + args: "", + wantOpts: CopilotOptions{ + CopilotArgs: []string{}, + }, + wantErrString: "", + }, + { + name: "with arguments", + args: "some-arg some-other-arg", + wantOpts: CopilotOptions{ + CopilotArgs: []string{"some-arg", "some-other-arg"}, + }, + }, + { + name: "with --remove alone", + args: "--remove", + wantOpts: CopilotOptions{ + Remove: true, + }, + }, + { + name: "with non-gh flags passed to copilot", + args: "-p testing --something-flag", + wantOpts: CopilotOptions{ + CopilotArgs: []string{"-p", "testing", "--something-flag"}, + }, + }, + { + name: "with --remove and arguments", + args: "--remove some-arg", + wantErrString: "cannot use --remove with args", + }, + { + name: "with --remove passed to copilot using --", + args: "-- --remove", + wantOpts: CopilotOptions{ + CopilotArgs: []string{"--remove"}, + }, + }, + { + name: "with --remove and -- alone", + args: "--remove --", + wantOpts: CopilotOptions{ + Remove: true, + }, + }, + { + name: "with --remove, some invalid arg, and --", + args: "--remove invalid-arg --", + wantErrString: "cannot use --remove with args", + }, + { + name: "with --remove and -- and random arguments", + args: "--remove -- some-arg", + wantErrString: "cannot use --remove with args", + }, + { + name: "with --help, shows gh help", + args: "--help", + wantErrString: "", + wantHelp: true, + }, + { + name: "with --help and --, shows copilot help", + args: "-- --help", + wantOpts: CopilotOptions{ + CopilotArgs: []string{"--help"}, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + f := &cmdutil.Factory{} + + argv, err := shlex.Split(tt.args) + assert.NoError(t, err) + + var gotOpts *CopilotOptions + cmd := NewCmdCopilot(f, func(opts *CopilotOptions) error { + gotOpts = opts + return nil + }) + + cmd.SetArgs(argv) + cmd.SetIn(&bytes.Buffer{}) + cmd.SetOut(&bytes.Buffer{}) + cmd.SetErr(&bytes.Buffer{}) + + _, err = cmd.ExecuteC() + if tt.wantErrString != "" { + require.EqualError(t, err, tt.wantErrString) + return + } + + if tt.wantHelp { + require.NoError(t, err) + return + } + + require.NoError(t, err) + assert.Equal(t, tt.wantOpts.CopilotArgs, gotOpts.CopilotArgs, "opts.CopilotArgs not as expected") + assert.Equal(t, tt.wantOpts.Remove, gotOpts.Remove, "opts.Remove not as expected") + }) + } +} + +func TestRemoveCopilot(t *testing.T) { + t.Run("removes existing install directory", func(t *testing.T) { + // Create a temporary directory to simulate the install directory + tmpDir := t.TempDir() + installDir := filepath.Join(tmpDir, "copilot") + require.NoError(t, os.MkdirAll(installDir, 0755), "failed to create test directory") + // Create a dummy file in the directory + dummyFile := filepath.Join(installDir, "copilot") + require.NoError(t, os.WriteFile(dummyFile, []byte("test"), 0755), "failed to create test file") + + err := removeCopilot(installDir) + require.NoError(t, err, "unexpected error") + + _, err = os.Stat(installDir) + require.True(t, os.IsNotExist(err), "expected install directory to be removed") + }) + + t.Run("handles non-existent directory", func(t *testing.T) { + tmpDir := t.TempDir() + installDir := filepath.Join(tmpDir, "copilot") + + require.ErrorContains(t, removeCopilot(installDir), "failed to remove Copilot CLI") + }) +} + +// createTarGzBuffer creates a tar.gz archive in memory with the given files. +func createTarGzBuffer(t *testing.T, files map[string][]byte) []byte { + t.Helper() + var buf bytes.Buffer + gw := gzip.NewWriter(&buf) + tw := tar.NewWriter(gw) + + for name, content := range files { + hdr := &tar.Header{ + Name: name, + Mode: 0755, + Size: int64(len(content)), + } + require.NoError(t, tw.WriteHeader(hdr), "failed to write tar header") + _, err := tw.Write(content) + require.NoError(t, err, "failed to write tar content") + } + + require.NoError(t, tw.Close(), "failed to close tar writer") + require.NoError(t, gw.Close(), "failed to close gzip writer") + return buf.Bytes() +} + +// createZipBuffer creates a zip archive in memory with the given files. +func createZipBuffer(t *testing.T, files map[string][]byte) []byte { + t.Helper() + var buf bytes.Buffer + zw := zip.NewWriter(&buf) + + for name, content := range files { + fw, err := zw.Create(name) + require.NoError(t, err, "failed to create zip entry") + _, err = fw.Write(content) + require.NoError(t, err, "failed to write zip content") + } + + require.NoError(t, zw.Close(), "failed to close zip writer") + return buf.Bytes() +} + +func TestExtractTarGz(t *testing.T) { + t.Run("extracts files correctly", func(t *testing.T) { + content := []byte("hello world") + archive := createTarGzBuffer(t, map[string][]byte{ + "copilot": content, + }) + + destDir := t.TempDir() + + err := extractTarGz(bytes.NewReader(archive), destDir) + require.NoError(t, err, "extractTarGz() error") + + extracted, err := os.ReadFile(filepath.Join(destDir, "copilot")) + require.NoError(t, err, "failed to read extracted file") + require.Equal(t, content, extracted, "extracted content mismatch") + }) + + t.Run("extracts nested files", func(t *testing.T) { + content := []byte("nested content") + archive := createTarGzBuffer(t, map[string][]byte{ + "subdir/file.txt": content, + }) + + destDir := t.TempDir() + + err := extractTarGz(bytes.NewReader(archive), destDir) + require.NoError(t, err, "extractTarGz() error") + + extracted, err := os.ReadFile(filepath.Join(destDir, "subdir", "file.txt")) + require.NoError(t, err, "failed to read extracted file") + require.Equal(t, content, extracted, "extracted content mismatch") + }) + + t.Run("rejects path traversal", func(t *testing.T) { + // Manually create a malicious tar.gz with path traversal + var buf bytes.Buffer + gw := gzip.NewWriter(&buf) + tw := tar.NewWriter(gw) + + hdr := &tar.Header{ + Name: "../evil.txt", + Mode: 0755, + Size: 4, + } + _ = tw.WriteHeader(hdr) + _, _ = tw.Write([]byte("evil")) + _ = tw.Close() + _ = gw.Close() + + destDir := t.TempDir() + + err := extractTarGz(bytes.NewReader(buf.Bytes()), destDir) + require.Error(t, err, "expected error for path traversal, got nil") + }) + + t.Run("handles invalid gzip", func(t *testing.T) { + destDir := t.TempDir() + + err := extractTarGz(bytes.NewReader([]byte("not valid gzip")), destDir) + require.Error(t, err, "expected error for invalid gzip, got nil") + }) +} + +func TestExtractZip(t *testing.T) { + t.Run("extracts files correctly", func(t *testing.T) { + zipDir := t.TempDir() + zipPath := filepath.Join(zipDir, "archive.zip") + content := []byte("hello world") + archive := createZipBuffer(t, map[string][]byte{ + "copilot.exe": content, + }) + require.NoError(t, os.WriteFile(zipPath, archive, 0x755)) + + destDir := t.TempDir() + + err := extractZip(zipPath, destDir) + require.NoError(t, err, "extractZip() error") + + extracted, err := os.ReadFile(filepath.Join(destDir, "copilot.exe")) + require.NoError(t, err, "failed to read extracted file") + require.Equal(t, content, extracted, "extracted content mismatch") + }) + + t.Run("extracts nested files", func(t *testing.T) { + zipDir := t.TempDir() + zipPath := filepath.Join(zipDir, "archive.zip") + content := []byte("hello world") + archive := createZipBuffer(t, map[string][]byte{ + "subdir/file.txt": content, + }) + require.NoError(t, os.WriteFile(zipPath, archive, 0x755)) + + destDir := t.TempDir() + + err := extractZip(zipPath, destDir) + require.NoError(t, err, "extractZip() error") + + extracted, err := os.ReadFile(filepath.Join(destDir, "subdir", "file.txt")) + require.NoError(t, err, "failed to read extracted file") + require.Equal(t, content, extracted, "extracted content mismatch") + }) + + t.Run("rejects path traversal", func(t *testing.T) { + zipDir := t.TempDir() + zipPath := filepath.Join(zipDir, "archive.zip") + + var buf bytes.Buffer + zw := zip.NewWriter(&buf) + + fh := &zip.FileHeader{ + Name: "../evil.txt", + Method: zip.Store, + } + fw, _ := zw.CreateHeader(fh) + _, _ = fw.Write([]byte("evil")) + _ = zw.Close() + + require.NoError(t, os.WriteFile(zipPath, buf.Bytes(), 0x755)) + destDir := t.TempDir() + + err := extractZip(zipPath, destDir) + require.Error(t, err, "expected error for path traversal, got nil") + }) +} + +func TestFetchExpectedChecksum(t *testing.T) { + t.Run("parses checksums file correctly", func(t *testing.T) { + reg := &httpmock.Registry{} + checksums := "abc123def456 copilot-linux-x64.tar.gz\n789xyz copilot-darwin-arm64.tar.gz\n" + reg.Register( + httpmock.MatchAny, + httpmock.StringResponse(checksums), + ) + + client := &http.Client{Transport: reg} + checksum, err := fetchExpectedChecksum(client, "https://example.com/checksums", "copilot-linux-x64.tar.gz") + require.NoError(t, err, "unexpected error") + require.Equal(t, "abc123def456", checksum, "checksum mismatch") + }) + + t.Run("returns error for missing archive", func(t *testing.T) { + reg := &httpmock.Registry{} + checksums := "abc123 copilot-linux-x64.tar.gz\n" + reg.Register( + httpmock.MatchAny, + httpmock.StringResponse(checksums), + ) + + client := &http.Client{Transport: reg} + _, err := fetchExpectedChecksum(client, "https://example.com/checksums", "copilot-windows-x64.zip") + require.Error(t, err, "expected error for missing archive") + require.Equal(t, "checksum not found for copilot-windows-x64.zip", err.Error(), "unexpected error") + }) + + t.Run("handles single space separator", func(t *testing.T) { + reg := &httpmock.Registry{} + checksums := "abc123 copilot-darwin-x64.tar.gz\n" + reg.Register( + httpmock.MatchAny, + httpmock.StringResponse(checksums), + ) + + client := &http.Client{Transport: reg} + checksum, err := fetchExpectedChecksum(client, "https://example.com/checksums", "copilot-darwin-x64.tar.gz") + require.NoError(t, err, "unexpected error") + require.Equal(t, "abc123", checksum, "checksum mismatch") + }) + + t.Run("handles HTTP error", func(t *testing.T) { + reg := &httpmock.Registry{} + reg.Register( + httpmock.MatchAny, + httpmock.StatusStringResponse(http.StatusNotFound, "not found"), + ) + + client := &http.Client{Transport: reg} + _, err := fetchExpectedChecksum(client, "https://example.com/checksums", "copilot-linux-x64.tar.gz") + require.Error(t, err, "expected error for HTTP 404") + }) +} + +func archString() string { + arch := runtime.GOARCH + if arch == "amd64" { + return "x64" + } + return arch +} + +func TestDownloadCopilot(t *testing.T) { + // Skip on unsupported architectures + if runtime.GOARCH != "amd64" && runtime.GOARCH != "arm64" { + t.Skip("skipping test on unsupported architecture") + } + + t.Run("downloads and extracts tar.gz with valid checksum", func(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("skipping tar.gz test on windows") + } + + ios, _, _, stderr := iostreams.Test() + tmpDir := t.TempDir() + installDir := filepath.Join(tmpDir, "copilot") + localPath := filepath.Join(installDir, "copilot") + + // Create mock archive with copilot binary + binaryContent := []byte("#!/bin/sh\necho copilot") + archive := createTarGzBuffer(t, map[string][]byte{ + "copilot": binaryContent, + }) + + // Calculate checksum + checksum := sha256.Sum256(archive) + checksumHex := hex.EncodeToString(checksum[:]) + archiveName := fmt.Sprintf("copilot-%s-%s.tar.gz", runtime.GOOS, archString()) + checksumFile := fmt.Sprintf("%s %s\n", checksumHex, archiveName) + + reg := &httpmock.Registry{} + // Register checksum endpoint + reg.Register( + httpmock.REST("GET", "github/copilot-cli/releases/latest/download/SHA256SUMS.txt"), + httpmock.StringResponse(checksumFile), + ) + // Register archive endpoint + reg.Register( + httpmock.REST("GET", fmt.Sprintf("github/copilot-cli/releases/latest/download/%s", archiveName)), + httpmock.BinaryResponse(archive), + ) + + httpClient := &http.Client{Transport: reg} + + path, err := downloadCopilot(httpClient, ios, installDir, localPath) + require.NoError(t, err, "downloadCopilot() error") + require.Equal(t, localPath, path, "downloadCopilot() path mismatch") + + // Verify binary was extracted + extracted, err := os.ReadFile(localPath) + require.NoError(t, err, "failed to read extracted binary") + require.Equal(t, binaryContent, extracted, "extracted content mismatch") + + // Verify output messages + require.Contains(t, stderr.String(), "installed successfully", "expected success message in stderr") + }) + + t.Run("fails with checksum mismatch", func(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("skipping tar.gz test on windows") + } + + ios, _, _, _ := iostreams.Test() + tmpDir := t.TempDir() + installDir := filepath.Join(tmpDir, "copilot") + localPath := filepath.Join(installDir, "copilot") + + binaryContent := []byte("#!/bin/sh\necho copilot") + archive := createTarGzBuffer(t, map[string][]byte{ + "copilot": binaryContent, + }) + + // Use wrong checksum + archiveName := fmt.Sprintf("copilot-%s-%s.tar.gz", runtime.GOOS, archString()) + checksumFile := fmt.Sprintf("%s %s\n", "0000000000000000000000000000000000000000000000000000000000000000", archiveName) + + reg := &httpmock.Registry{} + reg.Register( + httpmock.REST("GET", "github/copilot-cli/releases/latest/download/SHA256SUMS.txt"), + httpmock.StringResponse(checksumFile), + ) + reg.Register( + httpmock.REST("GET", fmt.Sprintf("github/copilot-cli/releases/latest/download/%s", archiveName)), + httpmock.BinaryResponse(archive), + ) + + httpClient := &http.Client{Transport: reg} + + _, err := downloadCopilot(httpClient, ios, installDir, localPath) + require.Error(t, err, "expected error for checksum mismatch, got nil") + require.Contains(t, err.Error(), "checksum mismatch", "expected checksum mismatch error") + }) + + t.Run("handles HTTP error on archive download", func(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("skipping tar.gz test on windows") + } + + ios, _, _, _ := iostreams.Test() + tmpDir := t.TempDir() + installDir := filepath.Join(tmpDir, "copilot") + localPath := filepath.Join(installDir, "copilot") + + archiveName := fmt.Sprintf("copilot-%s-%s.tar.gz", runtime.GOOS, archString()) + checksumFile := fmt.Sprintf("%s %s\n", "abc123", archiveName) + + reg := &httpmock.Registry{} + reg.Register( + httpmock.REST("GET", "github/copilot-cli/releases/latest/download/SHA256SUMS.txt"), + httpmock.StringResponse(checksumFile), + ) + reg.Register( + httpmock.REST("GET", fmt.Sprintf("github/copilot-cli/releases/latest/download/%s", archiveName)), + httpmock.StatusStringResponse(http.StatusNotFound, "not found"), + ) + + httpClient := &http.Client{Transport: reg} + + _, err := downloadCopilot(httpClient, ios, installDir, localPath) + require.Error(t, err, "expected error for HTTP 404, got nil") + require.Contains(t, err.Error(), "download failed", "expected error to contain 'download failed'") + }) + + t.Run("handles missing binary after extraction", func(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("skipping tar.gz test on windows") + } + + ios, _, _, _ := iostreams.Test() + tmpDir := t.TempDir() + installDir := filepath.Join(tmpDir, "copilot") + localPath := filepath.Join(installDir, "copilot") + + // Create archive without the expected binary name + archive := createTarGzBuffer(t, map[string][]byte{ + "wrong-name": []byte("content"), + }) + + checksum := sha256.Sum256(archive) + checksumHex := hex.EncodeToString(checksum[:]) + archiveName := fmt.Sprintf("copilot-%s-%s.tar.gz", runtime.GOOS, archString()) + checksumFile := fmt.Sprintf("%s %s\n", checksumHex, archiveName) + + reg := &httpmock.Registry{} + reg.Register( + httpmock.REST("GET", "github/copilot-cli/releases/latest/download/SHA256SUMS.txt"), + httpmock.StringResponse(checksumFile), + ) + reg.Register( + httpmock.REST("GET", fmt.Sprintf("github/copilot-cli/releases/latest/download/%s", archiveName)), + httpmock.BinaryResponse(archive), + ) + + httpClient := &http.Client{Transport: reg} + + _, err := downloadCopilot(httpClient, ios, installDir, localPath) + assert.ErrorContains(t, err, "copilot binary unavailable") + }) + + t.Run("downloads and extracts zip on windows", func(t *testing.T) { + if runtime.GOOS != "windows" { + t.Skip("skipping zip test on non-windows") + } + + ios, _, _, _ := iostreams.Test() + tmpDir := t.TempDir() + installDir := filepath.Join(tmpDir, "copilot") + localPath := filepath.Join(installDir, "copilot.exe") + + binaryContent := []byte("MZ fake exe content") + archive := createZipBuffer(t, map[string][]byte{ + "copilot.exe": binaryContent, + }) + + checksum := sha256.Sum256(archive) + checksumHex := hex.EncodeToString(checksum[:]) + archiveName := fmt.Sprintf("copilot-%s-%s.zip", runtime.GOOS, archString()) + checksumFile := fmt.Sprintf("%s %s\n", checksumHex, archiveName) + + reg := &httpmock.Registry{} + reg.Register( + httpmock.REST("GET", "github/copilot-cli/releases/latest/download/SHA256SUMS.txt"), + httpmock.StringResponse(checksumFile), + ) + reg.Register( + httpmock.REST("GET", fmt.Sprintf("github/copilot-cli/releases/latest/download/%s", archiveName)), + httpmock.BinaryResponse(archive), + ) + + httpClient := &http.Client{Transport: reg} + + path, err := downloadCopilot(httpClient, ios, installDir, localPath) + require.NoError(t, err, "downloadCopilot() error") + require.Equal(t, localPath, path, "downloadCopilot() path mismatch") + }) +}