feat: add copilot command

Signed-off-by: Babak K. Shandiz <babakks@github.com>

Co-authored-by: Kynan Ware <bagtoad@github.com>
Co-authored-by: Devraj Mehta <devm33@github.com>
This commit is contained in:
Babak K. Shandiz 2026-01-19 10:29:27 +00:00
parent 0f32f2ac46
commit 39880650d5
No known key found for this signature in database
GPG key ID: 9472CAEFF56C742E
2 changed files with 1039 additions and 0 deletions

451
pkg/cmd/copilot/copilot.go Normal file
View file

@ -0,0 +1,451 @@
package copilot
import (
"archive/tar"
"archive/zip"
"bufio"
"compress/gzip"
"crypto/sha256"
"encoding/hex"
"fmt"
"io"
"net/http"
"os"
"os/exec"
"path/filepath"
"runtime"
"slices"
"strings"
"github.com/MakeNowJust/heredoc"
"github.com/cli/cli/v2/internal/config"
"github.com/cli/cli/v2/internal/prompter"
"github.com/cli/cli/v2/internal/safepaths"
"github.com/cli/cli/v2/internal/update"
ghzip "github.com/cli/cli/v2/internal/zip"
"github.com/cli/cli/v2/pkg/cmdutil"
"github.com/cli/cli/v2/pkg/iostreams"
"github.com/spf13/cobra"
)
type CopilotOptions struct {
IO *iostreams.IOStreams
HttpClient func() (*http.Client, error)
Prompter prompter.Prompter
CopilotArgs []string
Remove bool
}
func NewCmdCopilot(f *cmdutil.Factory, runF func(*CopilotOptions) error) *cobra.Command {
opts := &CopilotOptions{
IO: f.IOStreams,
HttpClient: f.HttpClient,
Prompter: f.Prompter,
}
cmd := &cobra.Command{
Use: "copilot [flags] [args]",
Short: "Run the GitHub Copilot CLI (preview)",
Long: heredoc.Docf(`
Runs the GitHub Copilot CLI.
Executing the Copilot CLI through %[1]sgh%[1]s is currently in preview and subject to change.
If already installed, %[1]sgh%[1]s will execute the Copilot CLI found in your %[1]sPATH%[1]s.
If the Copilot CLI is not installed, it will be downloaded to %[2]s.
Use %[1]s--remove%[1]s to remove the downloaded Copilot CLI.
This command is only supported on Windows, Linux, and Darwin, on amd64/x64
or arm64 architectures.
To prevent %[1]sgh%[1]s from interpreting flags intended for Copilot,
use %[1]s--%[1]s before Copilot flags and args.
Learn more at https://gh.io/copilot-cli
`, "`", copilotInstallDir()),
Example: heredoc.Doc(`
# Download and run the Copilot CLI
$ gh copilot
# Run the Copilot CLI
$ gh copilot -p "Summarize this week's commits" --allow-tool 'shell(git)'
# Remove the Copilot CLI (if installed through gh)
$ gh copilot --remove
# Run the Copilot CLI help command
$ gh copilot -- --help
`),
DisableFlagParsing: true,
RunE: func(cmd *cobra.Command, args []string) error {
stopParsePos := -1
for i, arg := range args {
if arg == "--" {
stopParsePos = i
break
}
}
ghArgs := args
opts.CopilotArgs = args
if stopParsePos >= 0 {
ghArgs = args[:stopParsePos]
opts.CopilotArgs = args[stopParsePos+1:] // +1 to skip the "--" itself
}
if slices.Contains(ghArgs, "--help") || slices.Contains(ghArgs, "-h") {
return cmd.Help()
}
if slices.Contains(ghArgs, "--remove") {
hasOtherArgs := len(ghArgs) > 1
if stopParsePos >= 0 {
hasOtherArgs = hasOtherArgs || len(opts.CopilotArgs) > 0
}
if hasOtherArgs {
return cmdutil.FlagErrorf("cannot use --remove with args")
}
opts.Remove = true
opts.CopilotArgs = nil
}
if runF != nil {
return runF(opts)
}
return runCopilot(opts)
},
}
cmdutil.DisableAuthCheck(cmd)
// We add this flag, even though flag parsing is disabled for this command
// so the flag still appears in the help text.
cmd.Flags().Bool("remove", false, "Remove the downloaded Copilot CLI")
return cmd
}
func runCopilot(opts *CopilotOptions) error {
if opts.Remove {
if err := removeCopilot(copilotInstallDir()); err != nil {
return err
}
if opts.IO.IsStdoutTTY() {
fmt.Fprintln(opts.IO.ErrOut, "Copilot CLI removed successfully")
}
return nil
}
copilotPath := findCopilotBinary()
if copilotPath == "" {
if opts.IO.CanPrompt() {
confirmed, err := opts.Prompter.Confirm("GitHub Copilot CLI is not installed. Would you like to install it?", true)
if err != nil {
return err
}
if !confirmed {
fmt.Fprintf(opts.IO.ErrOut, "%s Copilot CLI was not installed", opts.IO.ColorScheme().WarningIcon())
return cmdutil.SilentError
}
} else if !update.IsCI() {
fmt.Fprintf(opts.IO.ErrOut, "%s Copilot CLI not installed", opts.IO.ColorScheme().WarningIcon())
return cmdutil.SilentError
}
httpClient, err := opts.HttpClient()
if err != nil {
return err
}
copilotPath, err = downloadCopilot(httpClient, opts.IO, copilotInstallDir(), copilotBinaryPath())
if err != nil {
return err
}
}
externalCmd := exec.Command(copilotPath, opts.CopilotArgs...)
externalCmd.Stdin = opts.IO.In
externalCmd.Stdout = opts.IO.Out
externalCmd.Stderr = opts.IO.ErrOut
if err := externalCmd.Run(); err != nil {
if exitErr, ok := err.(*exec.ExitError); ok {
// We terminate with os.Exit here, preserving the exit code from Copilot CLI,
// and also preventing stdio writes by callers up the stack.
os.Exit(exitErr.ExitCode())
}
return err
}
return nil
}
const copilotBinaryName = "copilot"
func copilotInstallDir() string {
return filepath.Join(config.DataDir(), "copilot")
}
func copilotBinaryPath() string {
binaryName := copilotBinaryName
if runtime.GOOS == "windows" {
binaryName += ".exe"
}
return filepath.Join(copilotInstallDir(), binaryName)
}
// findCopilotBinary returns the path to the Copilot CLI binary, if installed,
// with the following order of precedence:
// 1. `copilot` in the PATH
// 2. `copilot` in gh's data directory
//
// If not installed, it returns an empty string.
func findCopilotBinary() string {
if path, err := exec.LookPath(copilotBinaryName); err == nil {
return path
}
localPath := copilotBinaryPath()
if _, err := os.Stat(localPath); err != nil {
return ""
}
return localPath
}
// downloadCopilot downloads and installs the Copilot CLI to installDir.
// It returns the path to the installed Copilot binary.
func downloadCopilot(httpClient *http.Client, ios *iostreams.IOStreams, installDir, localPath string) (string, error) {
platform := runtime.GOOS
arch := runtime.GOARCH
if arch == "amd64" {
arch = "x64"
}
if arch != "x64" && arch != "arm64" {
return "", fmt.Errorf("unsupported architecture: %s (supported: x64, arm64)", arch)
}
var archiveURL string
var archiveName string
var isZip bool
switch platform {
case "windows":
archiveName = fmt.Sprintf("copilot-%s-%s.zip", platform, arch)
archiveURL = fmt.Sprintf("https://github.com/github/copilot-cli/releases/latest/download/%s", archiveName)
isZip = true
case "linux", "darwin":
archiveName = fmt.Sprintf("copilot-%s-%s.tar.gz", platform, arch)
archiveURL = fmt.Sprintf("https://github.com/github/copilot-cli/releases/latest/download/%s", archiveName)
default:
return "", fmt.Errorf("unsupported platform: %s (supported: linux, darwin, windows)", platform)
}
checksumsURL := "https://github.com/github/copilot-cli/releases/latest/download/SHA256SUMS.txt"
ios.StartProgressIndicatorWithLabel(fmt.Sprintf("Downloading Copilot CLI from %s", archiveURL))
defer ios.StopProgressIndicator()
expectedChecksum, err := fetchExpectedChecksum(httpClient, checksumsURL, archiveName)
if err != nil {
return "", fmt.Errorf("failed to fetch checksums: %w", err)
}
resp, err := httpClient.Get(archiveURL)
if err != nil {
return "", fmt.Errorf("failed to download: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return "", fmt.Errorf("download failed with status: %s", resp.Status)
}
// Download to temp file while calculating checksum
tmpFile, err := os.CreateTemp("", "copilot-download-*")
if err != nil {
return "", fmt.Errorf("failed to create temp file: %w", err)
}
defer os.Remove(tmpFile.Name())
defer tmpFile.Close()
hasher := sha256.New()
if _, err := io.Copy(tmpFile, io.TeeReader(resp.Body, hasher)); err != nil {
return "", fmt.Errorf("failed to download: %w", err)
}
ios.StopProgressIndicator()
// Validate checksum
actualChecksumHex := hex.EncodeToString(hasher.Sum(nil))
if actualChecksumHex != expectedChecksum {
return "", fmt.Errorf("checksum mismatch: expected %s, got %s", expectedChecksum, actualChecksumHex)
}
if _, err := tmpFile.Seek(0, io.SeekStart); err != nil {
return "", fmt.Errorf("failed to seek temp file: %w", err)
}
if err := os.MkdirAll(installDir, 0755); err != nil {
return "", fmt.Errorf("failed to create install directory: %w", err)
}
// Extract from the downloaded data
if isZip {
err = extractZip(tmpFile.Name(), installDir)
} else {
err = extractTarGz(tmpFile, installDir)
}
if err != nil {
return "", err
}
if _, err := os.Stat(localPath); err != nil {
return "", fmt.Errorf("copilot binary unavailable: %w", err)
}
fmt.Fprintf(ios.ErrOut, "%s Copilot CLI installed successfully\n", ios.ColorScheme().SuccessIcon())
return localPath, nil
}
// fetchExpectedChecksum downloads the SHA256SUMS.txt file and returns the expected checksum for the given archive name.
func fetchExpectedChecksum(httpClient *http.Client, checksumsURL, archiveName string) (string, error) {
resp, err := httpClient.Get(checksumsURL)
if err != nil {
return "", err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return "", fmt.Errorf("failed to download checksums: %s", resp.Status)
}
// Parse the checksums file. Possible formats are:
// - "<checksum> <filename>" (two whitespaces)
// - "<checksum> <filename>"
scanner := bufio.NewScanner(resp.Body)
for scanner.Scan() {
line := scanner.Text()
fields := strings.Fields(line)
if len(fields) >= 2 {
checksum := fields[0]
filename := fields[1]
if filename == archiveName {
return checksum, nil
}
}
}
if err := scanner.Err(); err != nil {
return "", fmt.Errorf("failed to read checksums: %w", err)
}
return "", fmt.Errorf("checksum not found for %s", archiveName)
}
// extractZip reads a ZIP archive at path and extracts its contents into destDir.
// It returns an error if the archive cannot be read,
// or if any file or directory within the archive cannot be created or written.
func extractZip(path, destDir string) error {
zipReader, err := zip.OpenReader(path)
if err != nil {
return fmt.Errorf("failed to open zip: %w", err)
}
defer zipReader.Close()
absPath, err := safepaths.ParseAbsolute(destDir)
if err != nil {
return err
}
// As of the time of writing, ghzip.ExtractZip will safely skip files that
// would result in path traversal. This is an issue for our use-case because
// we want to error out before extracting if there's any such file.
// To avoid breaking the shared ghzip.ExtractZip code that expects unsafe
// paths to be ignored and no error produced, we pre-validate here,
// producing an error if any such file is found.
for _, f := range zipReader.File {
_, err := absPath.Join(f.Name)
if err != nil {
return err
}
}
if err := ghzip.ExtractZip(&zipReader.Reader, absPath); err != nil {
return err
}
return nil
}
// extractTarGz reads a TAR.GZ archive from r and extracts its contents into destDir.
// It returns an error if the archive cannot be read,
// or if any file or directory within the archive cannot be created or written.
func extractTarGz(r io.Reader, destDir string) error {
gzr, err := gzip.NewReader(r)
if err != nil {
return fmt.Errorf("failed to create gzip reader: %w", err)
}
defer gzr.Close()
absDestDirPath, err := safepaths.ParseAbsolute(destDir)
if err != nil {
return err
}
tr := tar.NewReader(gzr)
for {
header, err := tr.Next()
if err == io.EOF {
break
}
if err != nil {
return fmt.Errorf("failed to read tar: %w", err)
}
absFilePath, err := absDestDirPath.Join(header.Name)
if err != nil {
return err
}
target := absFilePath.String()
if header.Typeflag == tar.TypeReg {
if err := os.MkdirAll(filepath.Dir(target), 0755); err != nil {
return fmt.Errorf("failed to create parent directory: %w", err)
}
if err := extractFile(target, os.FileMode(header.Mode)&0777, tr); err != nil {
return err
}
}
}
return nil
}
// extractFile creates a file at target with the given mode and copies content from r.
func extractFile(target string, mode os.FileMode, r io.Reader) (err error) {
out, err := os.OpenFile(target, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, mode)
if err != nil {
return fmt.Errorf("failed to create file: %w", err)
}
defer func() {
if cerr := out.Close(); err == nil && cerr != nil {
err = fmt.Errorf("failed to close file: %w", cerr)
}
}()
if _, err := io.Copy(out, r); err != nil {
return fmt.Errorf("failed to write file: %w", err)
}
return nil
}
func removeCopilot(installDir string) error {
if _, err := os.Stat(installDir); os.IsNotExist(err) {
return fmt.Errorf("failed to remove Copilot CLI: Copilot CLI not installed through `gh`")
}
if err := os.RemoveAll(installDir); err != nil {
return fmt.Errorf("failed to remove Copilot CLI: %w", err)
}
return nil
}

View file

@ -0,0 +1,588 @@
package copilot
import (
"archive/tar"
"archive/zip"
"bytes"
"compress/gzip"
"crypto/sha256"
"encoding/hex"
"fmt"
"net/http"
"os"
"path/filepath"
"runtime"
"testing"
"github.com/cli/cli/v2/pkg/cmdutil"
"github.com/cli/cli/v2/pkg/httpmock"
"github.com/cli/cli/v2/pkg/iostreams"
"github.com/google/shlex"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestNewCmdCopilot(t *testing.T) {
tests := []struct {
name string
args string
wantOpts CopilotOptions
wantErrString string
wantHelp bool
}{
{
name: "no argument",
args: "",
wantOpts: CopilotOptions{
CopilotArgs: []string{},
},
wantErrString: "",
},
{
name: "with arguments",
args: "some-arg some-other-arg",
wantOpts: CopilotOptions{
CopilotArgs: []string{"some-arg", "some-other-arg"},
},
},
{
name: "with --remove alone",
args: "--remove",
wantOpts: CopilotOptions{
Remove: true,
},
},
{
name: "with non-gh flags passed to copilot",
args: "-p testing --something-flag",
wantOpts: CopilotOptions{
CopilotArgs: []string{"-p", "testing", "--something-flag"},
},
},
{
name: "with --remove and arguments",
args: "--remove some-arg",
wantErrString: "cannot use --remove with args",
},
{
name: "with --remove passed to copilot using --",
args: "-- --remove",
wantOpts: CopilotOptions{
CopilotArgs: []string{"--remove"},
},
},
{
name: "with --remove and -- alone",
args: "--remove --",
wantOpts: CopilotOptions{
Remove: true,
},
},
{
name: "with --remove, some invalid arg, and --",
args: "--remove invalid-arg --",
wantErrString: "cannot use --remove with args",
},
{
name: "with --remove and -- and random arguments",
args: "--remove -- some-arg",
wantErrString: "cannot use --remove with args",
},
{
name: "with --help, shows gh help",
args: "--help",
wantErrString: "",
wantHelp: true,
},
{
name: "with --help and --, shows copilot help",
args: "-- --help",
wantOpts: CopilotOptions{
CopilotArgs: []string{"--help"},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
f := &cmdutil.Factory{}
argv, err := shlex.Split(tt.args)
assert.NoError(t, err)
var gotOpts *CopilotOptions
cmd := NewCmdCopilot(f, func(opts *CopilotOptions) error {
gotOpts = opts
return nil
})
cmd.SetArgs(argv)
cmd.SetIn(&bytes.Buffer{})
cmd.SetOut(&bytes.Buffer{})
cmd.SetErr(&bytes.Buffer{})
_, err = cmd.ExecuteC()
if tt.wantErrString != "" {
require.EqualError(t, err, tt.wantErrString)
return
}
if tt.wantHelp {
require.NoError(t, err)
return
}
require.NoError(t, err)
assert.Equal(t, tt.wantOpts.CopilotArgs, gotOpts.CopilotArgs, "opts.CopilotArgs not as expected")
assert.Equal(t, tt.wantOpts.Remove, gotOpts.Remove, "opts.Remove not as expected")
})
}
}
func TestRemoveCopilot(t *testing.T) {
t.Run("removes existing install directory", func(t *testing.T) {
// Create a temporary directory to simulate the install directory
tmpDir := t.TempDir()
installDir := filepath.Join(tmpDir, "copilot")
require.NoError(t, os.MkdirAll(installDir, 0755), "failed to create test directory")
// Create a dummy file in the directory
dummyFile := filepath.Join(installDir, "copilot")
require.NoError(t, os.WriteFile(dummyFile, []byte("test"), 0755), "failed to create test file")
err := removeCopilot(installDir)
require.NoError(t, err, "unexpected error")
_, err = os.Stat(installDir)
require.True(t, os.IsNotExist(err), "expected install directory to be removed")
})
t.Run("handles non-existent directory", func(t *testing.T) {
tmpDir := t.TempDir()
installDir := filepath.Join(tmpDir, "copilot")
require.ErrorContains(t, removeCopilot(installDir), "failed to remove Copilot CLI")
})
}
// createTarGzBuffer creates a tar.gz archive in memory with the given files.
func createTarGzBuffer(t *testing.T, files map[string][]byte) []byte {
t.Helper()
var buf bytes.Buffer
gw := gzip.NewWriter(&buf)
tw := tar.NewWriter(gw)
for name, content := range files {
hdr := &tar.Header{
Name: name,
Mode: 0755,
Size: int64(len(content)),
}
require.NoError(t, tw.WriteHeader(hdr), "failed to write tar header")
_, err := tw.Write(content)
require.NoError(t, err, "failed to write tar content")
}
require.NoError(t, tw.Close(), "failed to close tar writer")
require.NoError(t, gw.Close(), "failed to close gzip writer")
return buf.Bytes()
}
// createZipBuffer creates a zip archive in memory with the given files.
func createZipBuffer(t *testing.T, files map[string][]byte) []byte {
t.Helper()
var buf bytes.Buffer
zw := zip.NewWriter(&buf)
for name, content := range files {
fw, err := zw.Create(name)
require.NoError(t, err, "failed to create zip entry")
_, err = fw.Write(content)
require.NoError(t, err, "failed to write zip content")
}
require.NoError(t, zw.Close(), "failed to close zip writer")
return buf.Bytes()
}
func TestExtractTarGz(t *testing.T) {
t.Run("extracts files correctly", func(t *testing.T) {
content := []byte("hello world")
archive := createTarGzBuffer(t, map[string][]byte{
"copilot": content,
})
destDir := t.TempDir()
err := extractTarGz(bytes.NewReader(archive), destDir)
require.NoError(t, err, "extractTarGz() error")
extracted, err := os.ReadFile(filepath.Join(destDir, "copilot"))
require.NoError(t, err, "failed to read extracted file")
require.Equal(t, content, extracted, "extracted content mismatch")
})
t.Run("extracts nested files", func(t *testing.T) {
content := []byte("nested content")
archive := createTarGzBuffer(t, map[string][]byte{
"subdir/file.txt": content,
})
destDir := t.TempDir()
err := extractTarGz(bytes.NewReader(archive), destDir)
require.NoError(t, err, "extractTarGz() error")
extracted, err := os.ReadFile(filepath.Join(destDir, "subdir", "file.txt"))
require.NoError(t, err, "failed to read extracted file")
require.Equal(t, content, extracted, "extracted content mismatch")
})
t.Run("rejects path traversal", func(t *testing.T) {
// Manually create a malicious tar.gz with path traversal
var buf bytes.Buffer
gw := gzip.NewWriter(&buf)
tw := tar.NewWriter(gw)
hdr := &tar.Header{
Name: "../evil.txt",
Mode: 0755,
Size: 4,
}
_ = tw.WriteHeader(hdr)
_, _ = tw.Write([]byte("evil"))
_ = tw.Close()
_ = gw.Close()
destDir := t.TempDir()
err := extractTarGz(bytes.NewReader(buf.Bytes()), destDir)
require.Error(t, err, "expected error for path traversal, got nil")
})
t.Run("handles invalid gzip", func(t *testing.T) {
destDir := t.TempDir()
err := extractTarGz(bytes.NewReader([]byte("not valid gzip")), destDir)
require.Error(t, err, "expected error for invalid gzip, got nil")
})
}
func TestExtractZip(t *testing.T) {
t.Run("extracts files correctly", func(t *testing.T) {
zipDir := t.TempDir()
zipPath := filepath.Join(zipDir, "archive.zip")
content := []byte("hello world")
archive := createZipBuffer(t, map[string][]byte{
"copilot.exe": content,
})
require.NoError(t, os.WriteFile(zipPath, archive, 0x755))
destDir := t.TempDir()
err := extractZip(zipPath, destDir)
require.NoError(t, err, "extractZip() error")
extracted, err := os.ReadFile(filepath.Join(destDir, "copilot.exe"))
require.NoError(t, err, "failed to read extracted file")
require.Equal(t, content, extracted, "extracted content mismatch")
})
t.Run("extracts nested files", func(t *testing.T) {
zipDir := t.TempDir()
zipPath := filepath.Join(zipDir, "archive.zip")
content := []byte("hello world")
archive := createZipBuffer(t, map[string][]byte{
"subdir/file.txt": content,
})
require.NoError(t, os.WriteFile(zipPath, archive, 0x755))
destDir := t.TempDir()
err := extractZip(zipPath, destDir)
require.NoError(t, err, "extractZip() error")
extracted, err := os.ReadFile(filepath.Join(destDir, "subdir", "file.txt"))
require.NoError(t, err, "failed to read extracted file")
require.Equal(t, content, extracted, "extracted content mismatch")
})
t.Run("rejects path traversal", func(t *testing.T) {
zipDir := t.TempDir()
zipPath := filepath.Join(zipDir, "archive.zip")
var buf bytes.Buffer
zw := zip.NewWriter(&buf)
fh := &zip.FileHeader{
Name: "../evil.txt",
Method: zip.Store,
}
fw, _ := zw.CreateHeader(fh)
_, _ = fw.Write([]byte("evil"))
_ = zw.Close()
require.NoError(t, os.WriteFile(zipPath, buf.Bytes(), 0x755))
destDir := t.TempDir()
err := extractZip(zipPath, destDir)
require.Error(t, err, "expected error for path traversal, got nil")
})
}
func TestFetchExpectedChecksum(t *testing.T) {
t.Run("parses checksums file correctly", func(t *testing.T) {
reg := &httpmock.Registry{}
checksums := "abc123def456 copilot-linux-x64.tar.gz\n789xyz copilot-darwin-arm64.tar.gz\n"
reg.Register(
httpmock.MatchAny,
httpmock.StringResponse(checksums),
)
client := &http.Client{Transport: reg}
checksum, err := fetchExpectedChecksum(client, "https://example.com/checksums", "copilot-linux-x64.tar.gz")
require.NoError(t, err, "unexpected error")
require.Equal(t, "abc123def456", checksum, "checksum mismatch")
})
t.Run("returns error for missing archive", func(t *testing.T) {
reg := &httpmock.Registry{}
checksums := "abc123 copilot-linux-x64.tar.gz\n"
reg.Register(
httpmock.MatchAny,
httpmock.StringResponse(checksums),
)
client := &http.Client{Transport: reg}
_, err := fetchExpectedChecksum(client, "https://example.com/checksums", "copilot-windows-x64.zip")
require.Error(t, err, "expected error for missing archive")
require.Equal(t, "checksum not found for copilot-windows-x64.zip", err.Error(), "unexpected error")
})
t.Run("handles single space separator", func(t *testing.T) {
reg := &httpmock.Registry{}
checksums := "abc123 copilot-darwin-x64.tar.gz\n"
reg.Register(
httpmock.MatchAny,
httpmock.StringResponse(checksums),
)
client := &http.Client{Transport: reg}
checksum, err := fetchExpectedChecksum(client, "https://example.com/checksums", "copilot-darwin-x64.tar.gz")
require.NoError(t, err, "unexpected error")
require.Equal(t, "abc123", checksum, "checksum mismatch")
})
t.Run("handles HTTP error", func(t *testing.T) {
reg := &httpmock.Registry{}
reg.Register(
httpmock.MatchAny,
httpmock.StatusStringResponse(http.StatusNotFound, "not found"),
)
client := &http.Client{Transport: reg}
_, err := fetchExpectedChecksum(client, "https://example.com/checksums", "copilot-linux-x64.tar.gz")
require.Error(t, err, "expected error for HTTP 404")
})
}
func archString() string {
arch := runtime.GOARCH
if arch == "amd64" {
return "x64"
}
return arch
}
func TestDownloadCopilot(t *testing.T) {
// Skip on unsupported architectures
if runtime.GOARCH != "amd64" && runtime.GOARCH != "arm64" {
t.Skip("skipping test on unsupported architecture")
}
t.Run("downloads and extracts tar.gz with valid checksum", func(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("skipping tar.gz test on windows")
}
ios, _, _, stderr := iostreams.Test()
tmpDir := t.TempDir()
installDir := filepath.Join(tmpDir, "copilot")
localPath := filepath.Join(installDir, "copilot")
// Create mock archive with copilot binary
binaryContent := []byte("#!/bin/sh\necho copilot")
archive := createTarGzBuffer(t, map[string][]byte{
"copilot": binaryContent,
})
// Calculate checksum
checksum := sha256.Sum256(archive)
checksumHex := hex.EncodeToString(checksum[:])
archiveName := fmt.Sprintf("copilot-%s-%s.tar.gz", runtime.GOOS, archString())
checksumFile := fmt.Sprintf("%s %s\n", checksumHex, archiveName)
reg := &httpmock.Registry{}
// Register checksum endpoint
reg.Register(
httpmock.REST("GET", "github/copilot-cli/releases/latest/download/SHA256SUMS.txt"),
httpmock.StringResponse(checksumFile),
)
// Register archive endpoint
reg.Register(
httpmock.REST("GET", fmt.Sprintf("github/copilot-cli/releases/latest/download/%s", archiveName)),
httpmock.BinaryResponse(archive),
)
httpClient := &http.Client{Transport: reg}
path, err := downloadCopilot(httpClient, ios, installDir, localPath)
require.NoError(t, err, "downloadCopilot() error")
require.Equal(t, localPath, path, "downloadCopilot() path mismatch")
// Verify binary was extracted
extracted, err := os.ReadFile(localPath)
require.NoError(t, err, "failed to read extracted binary")
require.Equal(t, binaryContent, extracted, "extracted content mismatch")
// Verify output messages
require.Contains(t, stderr.String(), "installed successfully", "expected success message in stderr")
})
t.Run("fails with checksum mismatch", func(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("skipping tar.gz test on windows")
}
ios, _, _, _ := iostreams.Test()
tmpDir := t.TempDir()
installDir := filepath.Join(tmpDir, "copilot")
localPath := filepath.Join(installDir, "copilot")
binaryContent := []byte("#!/bin/sh\necho copilot")
archive := createTarGzBuffer(t, map[string][]byte{
"copilot": binaryContent,
})
// Use wrong checksum
archiveName := fmt.Sprintf("copilot-%s-%s.tar.gz", runtime.GOOS, archString())
checksumFile := fmt.Sprintf("%s %s\n", "0000000000000000000000000000000000000000000000000000000000000000", archiveName)
reg := &httpmock.Registry{}
reg.Register(
httpmock.REST("GET", "github/copilot-cli/releases/latest/download/SHA256SUMS.txt"),
httpmock.StringResponse(checksumFile),
)
reg.Register(
httpmock.REST("GET", fmt.Sprintf("github/copilot-cli/releases/latest/download/%s", archiveName)),
httpmock.BinaryResponse(archive),
)
httpClient := &http.Client{Transport: reg}
_, err := downloadCopilot(httpClient, ios, installDir, localPath)
require.Error(t, err, "expected error for checksum mismatch, got nil")
require.Contains(t, err.Error(), "checksum mismatch", "expected checksum mismatch error")
})
t.Run("handles HTTP error on archive download", func(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("skipping tar.gz test on windows")
}
ios, _, _, _ := iostreams.Test()
tmpDir := t.TempDir()
installDir := filepath.Join(tmpDir, "copilot")
localPath := filepath.Join(installDir, "copilot")
archiveName := fmt.Sprintf("copilot-%s-%s.tar.gz", runtime.GOOS, archString())
checksumFile := fmt.Sprintf("%s %s\n", "abc123", archiveName)
reg := &httpmock.Registry{}
reg.Register(
httpmock.REST("GET", "github/copilot-cli/releases/latest/download/SHA256SUMS.txt"),
httpmock.StringResponse(checksumFile),
)
reg.Register(
httpmock.REST("GET", fmt.Sprintf("github/copilot-cli/releases/latest/download/%s", archiveName)),
httpmock.StatusStringResponse(http.StatusNotFound, "not found"),
)
httpClient := &http.Client{Transport: reg}
_, err := downloadCopilot(httpClient, ios, installDir, localPath)
require.Error(t, err, "expected error for HTTP 404, got nil")
require.Contains(t, err.Error(), "download failed", "expected error to contain 'download failed'")
})
t.Run("handles missing binary after extraction", func(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("skipping tar.gz test on windows")
}
ios, _, _, _ := iostreams.Test()
tmpDir := t.TempDir()
installDir := filepath.Join(tmpDir, "copilot")
localPath := filepath.Join(installDir, "copilot")
// Create archive without the expected binary name
archive := createTarGzBuffer(t, map[string][]byte{
"wrong-name": []byte("content"),
})
checksum := sha256.Sum256(archive)
checksumHex := hex.EncodeToString(checksum[:])
archiveName := fmt.Sprintf("copilot-%s-%s.tar.gz", runtime.GOOS, archString())
checksumFile := fmt.Sprintf("%s %s\n", checksumHex, archiveName)
reg := &httpmock.Registry{}
reg.Register(
httpmock.REST("GET", "github/copilot-cli/releases/latest/download/SHA256SUMS.txt"),
httpmock.StringResponse(checksumFile),
)
reg.Register(
httpmock.REST("GET", fmt.Sprintf("github/copilot-cli/releases/latest/download/%s", archiveName)),
httpmock.BinaryResponse(archive),
)
httpClient := &http.Client{Transport: reg}
_, err := downloadCopilot(httpClient, ios, installDir, localPath)
assert.ErrorContains(t, err, "copilot binary unavailable")
})
t.Run("downloads and extracts zip on windows", func(t *testing.T) {
if runtime.GOOS != "windows" {
t.Skip("skipping zip test on non-windows")
}
ios, _, _, _ := iostreams.Test()
tmpDir := t.TempDir()
installDir := filepath.Join(tmpDir, "copilot")
localPath := filepath.Join(installDir, "copilot.exe")
binaryContent := []byte("MZ fake exe content")
archive := createZipBuffer(t, map[string][]byte{
"copilot.exe": binaryContent,
})
checksum := sha256.Sum256(archive)
checksumHex := hex.EncodeToString(checksum[:])
archiveName := fmt.Sprintf("copilot-%s-%s.zip", runtime.GOOS, archString())
checksumFile := fmt.Sprintf("%s %s\n", checksumHex, archiveName)
reg := &httpmock.Registry{}
reg.Register(
httpmock.REST("GET", "github/copilot-cli/releases/latest/download/SHA256SUMS.txt"),
httpmock.StringResponse(checksumFile),
)
reg.Register(
httpmock.REST("GET", fmt.Sprintf("github/copilot-cli/releases/latest/download/%s", archiveName)),
httpmock.BinaryResponse(archive),
)
httpClient := &http.Client{Transport: reg}
path, err := downloadCopilot(httpClient, ios, installDir, localPath)
require.NoError(t, err, "downloadCopilot() error")
require.Equal(t, localPath, path, "downloadCopilot() path mismatch")
})
}