diff --git a/command/gist.go b/command/gist.go index 227f8bd92..477702ab7 100644 --- a/command/gist.go +++ b/command/gist.go @@ -10,6 +10,7 @@ import ( "github.com/MakeNowJust/heredoc" "github.com/cli/cli/api" + "github.com/cli/cli/pkg/cmdutil" "github.com/cli/cli/utils" "github.com/spf13/cobra" ) @@ -52,6 +53,22 @@ By default, gists are private; use '--public' to make publicly listed ones.`, # create a gist from output piped from another command $ cat cool.txt | gh gist create `), + Args: func(cmd *cobra.Command, args []string) error { + if len(args) > 0 { + return nil + } + + info, err := os.Stdin.Stat() + if err != nil { + return fmt.Errorf("failed to check STDIN: %w", err) + } + + stdinIsTTY := (info.Mode() & os.ModeCharDevice) == os.ModeCharDevice + if stdinIsTTY { + return cmdutil.FlagError{Err: errors.New("no filenames passed and nothing on STDIN")} + } + return nil + }, RunE: gistCreate, } @@ -83,15 +100,12 @@ func gistCreate(cmd *cobra.Command, args []string) error { return fmt.Errorf("did not understand arguments: %w", err) } - info, err := os.Stdin.Stat() - if err != nil { - return fmt.Errorf("failed to check STDIN: %w", err) - } - if (info.Mode() & os.ModeCharDevice) != os.ModeCharDevice { - args = append(args, "-") + fileArgs := args + if len(args) == 0 { + fileArgs = []string{"-"} } - files, err := processFiles(os.Stdin, args) + files, err := processFiles(os.Stdin, fileArgs) if err != nil { return fmt.Errorf("failed to collect files for posting: %w", err) } @@ -128,11 +142,11 @@ func processOpts(cmd *cobra.Command) (*Opts, error) { }, err } -func processFiles(stdin io.Reader, filenames []string) (map[string]string, error) { +func processFiles(stdin io.ReadCloser, filenames []string) (map[string]string, error) { fs := map[string]string{} if len(filenames) == 0 { - return fs, errors.New("no filenames passed and nothing on STDIN") + return nil, errors.New("no files passed") } for i, f := range filenames { @@ -145,6 +159,7 @@ func processFiles(stdin io.Reader, filenames []string) (map[string]string, error if err != nil { return fs, fmt.Errorf("failed to read from stdin: %w", err) } + stdin.Close() } else { content, err = ioutil.ReadFile(f) if err != nil { diff --git a/command/gist_test.go b/command/gist_test.go index 7eb530e76..4a1d5d032 100644 --- a/command/gist_test.go +++ b/command/gist_test.go @@ -6,6 +6,8 @@ import ( "io/ioutil" "strings" "testing" + + "github.com/stretchr/testify/assert" ) func TestGistCreate(t *testing.T) { @@ -43,10 +45,11 @@ func TestGistCreate(t *testing.T) { func TestGistCreate_stdin(t *testing.T) { fakeStdin := strings.NewReader("hey cool how is it going") - files, err := processFiles(fakeStdin, []string{"-"}) + files, err := processFiles(ioutil.NopCloser(fakeStdin), []string{"-"}) if err != nil { t.Fatalf("unexpected error processing files: %s", err) } - eq(t, files["gistfile0.txt"], "hey cool how is it going") + assert.Equal(t, 1, len(files)) + assert.Equal(t, "hey cool how is it going", files["gistfile0.txt"]) }