diff --git a/pkg/cmd/pr/diff/diff.go b/pkg/cmd/pr/diff/diff.go index c6902e5b5..3c55e7db8 100644 --- a/pkg/cmd/pr/diff/diff.go +++ b/pkg/cmd/pr/diff/diff.go @@ -6,8 +6,6 @@ import ( "fmt" "io" "net/http" - "os" - "os/exec" "strings" "github.com/cli/cli/api" @@ -16,7 +14,6 @@ import ( "github.com/cli/cli/pkg/cmd/pr/shared" "github.com/cli/cli/pkg/cmdutil" "github.com/cli/cli/pkg/iostreams" - "github.com/google/shlex" "github.com/spf13/cobra" ) @@ -93,17 +90,17 @@ func diffRun(opts *DiffOptions) error { } defer diff.Close() + err = opts.IO.StartPager() + if err != nil { + return err + } + defer opts.IO.StopPager() + if opts.UseColor == "never" { _, err = io.Copy(opts.IO.Out, diff) return err } - if opts.IO.IsStdoutTTY() { - if pager := os.Getenv("PAGER"); pager != "" { - return runPager(pager, diff, opts.IO.Out) - } - } - diffLines := bufio.NewScanner(diff) for diffLines.Scan() { diffLine := diffLines.Text() @@ -148,14 +145,3 @@ func isRemovalLine(dl string) bool { func validColorFlag(c string) bool { return c == "auto" || c == "always" || c == "never" } - -var runPager = func(pager string, diff io.Reader, out io.Writer) error { - args, err := shlex.Split(pager) - if err != nil { - return err - } - pagerCmd := exec.Command(args[0], args[1:]...) - pagerCmd.Stdin = diff - pagerCmd.Stdout = out - return pagerCmd.Run() -} diff --git a/pkg/cmd/pr/diff/diff_test.go b/pkg/cmd/pr/diff/diff_test.go index dc93970b6..5e362ac4d 100644 --- a/pkg/cmd/pr/diff/diff_test.go +++ b/pkg/cmd/pr/diff/diff_test.go @@ -2,10 +2,8 @@ package diff import ( "bytes" - "io" "io/ioutil" "net/http" - "os" "testing" "github.com/cli/cli/context" @@ -214,13 +212,8 @@ func TestPRDiff_notty(t *testing.T) { } func TestPRDiff_tty(t *testing.T) { - pager := os.Getenv("PAGER") http := &httpmock.Registry{} - defer func() { - os.Setenv("PAGER", pager) - http.Verify(t) - }() - os.Setenv("PAGER", "") + defer http.Verify(t) http.StubResponse(200, bytes.NewBufferString(` { "data": { "repository": { "pullRequests": { "nodes": [ { "url": "https://github.com/OWNER/REPO/pull/123", @@ -237,38 +230,6 @@ func TestPRDiff_tty(t *testing.T) { assert.Contains(t, output.String(), "\x1b[32m+site: bin/gh\x1b[m") } -func TestPRDiff_pager(t *testing.T) { - realRunPager := runPager - pager := os.Getenv("PAGER") - http := &httpmock.Registry{} - defer func() { - runPager = realRunPager - os.Setenv("PAGER", pager) - http.Verify(t) - }() - runPager = func(pager string, diff io.Reader, out io.Writer) error { - _, err := io.Copy(out, diff) - return err - } - os.Setenv("PAGER", "fakepager") - http.StubResponse(200, bytes.NewBufferString(` - { "data": { "repository": { "pullRequests": { "nodes": [ - { "url": "https://github.com/OWNER/REPO/pull/123", - "number": 123, - "id": "foobar123", - "headRefName": "feature", - "baseRefName": "master" } - ] } } } }`)) - http.StubResponse(200, bytes.NewBufferString(testDiff)) - output, err := runCommand(http, nil, true, "") - if err != nil { - t.Fatalf("unexpected error: %s", err) - } - if diff := cmp.Diff(testDiff, output.String()); diff != "" { - t.Errorf("command output did not match:\n%s", diff) - } -} - const testDiff = `diff --git a/.github/workflows/releases.yml b/.github/workflows/releases.yml index 73974448..b7fc0154 100644 --- a/.github/workflows/releases.yml diff --git a/pkg/iostreams/iostreams.go b/pkg/iostreams/iostreams.go index e9b9c2220..d3f7e3c11 100644 --- a/pkg/iostreams/iostreams.go +++ b/pkg/iostreams/iostreams.go @@ -10,6 +10,7 @@ import ( "strconv" "strings" + "github.com/google/shlex" "github.com/mattn/go-colorable" "github.com/mattn/go-isatty" "golang.org/x/crypto/ssh/terminal" @@ -28,6 +29,9 @@ type IOStreams struct { stdoutIsTTY bool stderrTTYOverride bool stderrIsTTY bool + + pagerCommand string + pagerProcess *os.Process } func (s *IOStreams) ColorEnabled() bool { @@ -79,6 +83,51 @@ func (s *IOStreams) IsStderrTTY() bool { return false } +func (s *IOStreams) StartPager() error { + if s.pagerCommand == "" || !s.IsStdoutTTY() { + return nil + } + + pagerArgs, err := shlex.Split(s.pagerCommand) + if err != nil { + return err + } + + pagerEnv := os.Environ() + if _, ok := os.LookupEnv("LESS"); !ok { + pagerEnv = append(pagerEnv, "LESS=FRX") + } + if _, ok := os.LookupEnv("LV"); !ok { + pagerEnv = append(pagerEnv, "LV=-c") + } + + pagerCmd := exec.Command(pagerArgs[0], pagerArgs[1:]...) + pagerCmd.Env = pagerEnv + pagerCmd.Stdout = s.Out + pagerCmd.Stderr = s.ErrOut + pagedOut, err := pagerCmd.StdinPipe() + if err != nil { + return err + } + s.Out = pagedOut + err = pagerCmd.Start() + if err != nil { + return err + } + s.pagerProcess = pagerCmd.Process + return nil +} + +func (s *IOStreams) StopPager() { + if s.pagerProcess == nil { + return + } + + s.Out.(io.ReadCloser).Close() + _, _ = s.pagerProcess.Wait() + s.pagerProcess = nil +} + func (s *IOStreams) TerminalWidth() int { defaultWidth := 80 if s.stdoutTTYOverride { @@ -111,6 +160,7 @@ func System() *IOStreams { Out: colorable.NewColorable(os.Stdout), ErrOut: colorable.NewColorable(os.Stderr), colorEnabled: os.Getenv("NO_COLOR") == "" && stdoutIsTTY, + pagerCommand: os.Getenv("PAGER"), } // prevent duplicate isTerminal queries now that we know the answer