diff --git a/pkg/cmd/api/api.go b/pkg/cmd/api/api.go index b828c607c..b4a8dbdff 100644 --- a/pkg/cmd/api/api.go +++ b/pkg/cmd/api/api.go @@ -103,13 +103,16 @@ func apiRun(opts *ApiOptions) error { } if opts.RequestInputFile != "" { - file, err := openUserFile(opts.RequestInputFile, opts.IO.In) + file, size, err := openUserFile(opts.RequestInputFile, opts.IO.In) if err != nil { return err } defer file.Close() requestPath = addQuery(requestPath, params) requestBody = file + if size >= 0 { + requestHeaders = append([]string{fmt.Sprintf("Content-Length: %d", size)}, requestHeaders...) + } } httpClient, err := opts.HttpClient() @@ -240,19 +243,36 @@ func magicFieldValue(v string, stdin io.ReadCloser) (interface{}, error) { } func readUserFile(fn string, stdin io.ReadCloser) ([]byte, error) { - r, err := openUserFile(fn, stdin) - if err != nil { - return nil, err + var r io.ReadCloser + if fn == "-" { + r = stdin + } else { + var err error + r, err = os.Open(fn) + if err != nil { + return nil, err + } } defer r.Close() return ioutil.ReadAll(r) } -func openUserFile(fn string, stdin io.ReadCloser) (io.ReadCloser, error) { +func openUserFile(fn string, stdin io.ReadCloser) (io.ReadCloser, int64, error) { if fn == "-" { - return stdin, nil + return stdin, -1, nil } - return os.Open(fn) + + r, err := os.Open(fn) + if err != nil { + return r, -1, err + } + + s, err := os.Stat(fn) + if err != nil { + return r, -1, err + } + + return r, s.Size(), nil } func parseErrorResponse(r io.Reader, statusCode int) (io.Reader, string, error) { diff --git a/pkg/cmd/api/api_test.go b/pkg/cmd/api/api_test.go index c121455db..100c1257e 100644 --- a/pkg/cmd/api/api_test.go +++ b/pkg/cmd/api/api_test.go @@ -247,41 +247,78 @@ func Test_apiRun(t *testing.T) { } func Test_apiRun_inputFile(t *testing.T) { - io, stdin, _, _ := iostreams.Test() - resp := &http.Response{StatusCode: 204} + tests := []struct { + name string + inputFile string + inputContents []byte - options := ApiOptions{ - RequestPath: "hello", - RequestInputFile: "-", - RawFields: []string{"a=b", "c=d"}, - - IO: io, - HttpClient: func() (*http.Client, error) { - var tr roundTripper = func(req *http.Request) (*http.Response, error) { - resp.Request = req - return resp, nil - } - return &http.Client{Transport: tr}, nil + contentLength int64 + expectedContents []byte + }{ + { + name: "stdin", + inputFile: "-", + inputContents: []byte("I WORK OUT"), + contentLength: 0, + }, + { + name: "from file", + inputFile: "gh-test-file", + inputContents: []byte("I WORK OUT"), + contentLength: 10, }, } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + io, stdin, _, _ := iostreams.Test() + resp := &http.Response{StatusCode: 204} - fmt.Fprintln(stdin, "I WORK OUT") + inputFile := tt.inputFile + if tt.inputFile == "-" { + _, _ = stdin.Write(tt.inputContents) + } else { + f, err := ioutil.TempFile("", tt.inputFile) + if err != nil { + t.Fatal(err) + } + _, _ = f.Write(tt.inputContents) + f.Close() + t.Cleanup(func() { os.Remove(f.Name()) }) + inputFile = f.Name() + } - err := apiRun(&options) - if err != nil { - t.Errorf("got error %v", err) + var bodyBytes []byte + options := ApiOptions{ + RequestPath: "hello", + RequestInputFile: inputFile, + RawFields: []string{"a=b", "c=d"}, + + IO: io, + HttpClient: func() (*http.Client, error) { + var tr roundTripper = func(req *http.Request) (*http.Response, error) { + var err error + if bodyBytes, err = ioutil.ReadAll(req.Body); err != nil { + return nil, err + } + resp.Request = req + return resp, nil + } + return &http.Client{Transport: tr}, nil + }, + } + + err := apiRun(&options) + if err != nil { + t.Errorf("got error %v", err) + } + + assert.Equal(t, "POST", resp.Request.Method) + assert.Equal(t, "/hello?a=b&c=d", resp.Request.URL.RequestURI()) + assert.Equal(t, tt.contentLength, resp.Request.ContentLength) + assert.Equal(t, "", resp.Request.Header.Get("Content-Type")) + assert.Equal(t, tt.inputContents, bodyBytes) + }) } - - assert.Equal(t, "POST", resp.Request.Method) - assert.Equal(t, "/hello?a=b&c=d", resp.Request.URL.RequestURI()) - assert.Equal(t, "", resp.Request.Header.Get("Content-Length")) - assert.Equal(t, "", resp.Request.Header.Get("Content-Type")) - - bb, err := ioutil.ReadAll(resp.Request.Body) - if err != nil { - t.Errorf("got error %v", err) - } - assert.Equal(t, "I WORK OUT\n", string(bb)) } func Test_parseFields(t *testing.T) { @@ -400,7 +437,7 @@ func Test_openUserFile(t *testing.T) { f.Close() t.Cleanup(func() { os.Remove(f.Name()) }) - file, err := openUserFile(f.Name(), nil) + file, length, err := openUserFile(f.Name(), nil) if err != nil { t.Fatal(err) } @@ -411,5 +448,6 @@ func Test_openUserFile(t *testing.T) { t.Fatal(err) } + assert.Equal(t, int64(13), length) assert.Equal(t, "file contents", string(fb)) } diff --git a/pkg/cmd/api/http.go b/pkg/cmd/api/http.go index 4db21b286..e393bb593 100644 --- a/pkg/cmd/api/http.go +++ b/pkg/cmd/api/http.go @@ -7,6 +7,7 @@ import ( "io" "net/http" "net/url" + "strconv" "strings" ) @@ -62,7 +63,16 @@ func httpRequest(client *http.Client, method string, p string, params interface{ if idx == -1 { return nil, fmt.Errorf("header %q requires a value separated by ':'", h) } - req.Header.Add(h[0:idx], strings.TrimSpace(h[idx+1:])) + name, value := h[0:idx], strings.TrimSpace(h[idx+1:]) + if strings.EqualFold(name, "Content-Length") { + length, err := strconv.ParseInt(value, 10, 0) + if err != nil { + return nil, err + } + req.ContentLength = length + } else { + req.Header.Add(name, value) + } } if bodyIsJSON && req.Header.Get("Content-Type") == "" { req.Header.Set("Content-Type", "application/json; charset=utf-8")