Turns out we do need explicit Content-Length for file uploads

This reverts commit 141388fd23.
This commit is contained in:
Mislav Marohnić 2020-06-10 17:34:01 +02:00
parent 4d95349081
commit 74a39f3ed1
3 changed files with 106 additions and 38 deletions

View file

@ -103,13 +103,16 @@ func apiRun(opts *ApiOptions) error {
}
if opts.RequestInputFile != "" {
file, err := openUserFile(opts.RequestInputFile, opts.IO.In)
file, size, err := openUserFile(opts.RequestInputFile, opts.IO.In)
if err != nil {
return err
}
defer file.Close()
requestPath = addQuery(requestPath, params)
requestBody = file
if size >= 0 {
requestHeaders = append([]string{fmt.Sprintf("Content-Length: %d", size)}, requestHeaders...)
}
}
httpClient, err := opts.HttpClient()
@ -240,19 +243,36 @@ func magicFieldValue(v string, stdin io.ReadCloser) (interface{}, error) {
}
func readUserFile(fn string, stdin io.ReadCloser) ([]byte, error) {
r, err := openUserFile(fn, stdin)
if err != nil {
return nil, err
var r io.ReadCloser
if fn == "-" {
r = stdin
} else {
var err error
r, err = os.Open(fn)
if err != nil {
return nil, err
}
}
defer r.Close()
return ioutil.ReadAll(r)
}
func openUserFile(fn string, stdin io.ReadCloser) (io.ReadCloser, error) {
func openUserFile(fn string, stdin io.ReadCloser) (io.ReadCloser, int64, error) {
if fn == "-" {
return stdin, nil
return stdin, -1, nil
}
return os.Open(fn)
r, err := os.Open(fn)
if err != nil {
return r, -1, err
}
s, err := os.Stat(fn)
if err != nil {
return r, -1, err
}
return r, s.Size(), nil
}
func parseErrorResponse(r io.Reader, statusCode int) (io.Reader, string, error) {

View file

@ -247,41 +247,78 @@ func Test_apiRun(t *testing.T) {
}
func Test_apiRun_inputFile(t *testing.T) {
io, stdin, _, _ := iostreams.Test()
resp := &http.Response{StatusCode: 204}
tests := []struct {
name string
inputFile string
inputContents []byte
options := ApiOptions{
RequestPath: "hello",
RequestInputFile: "-",
RawFields: []string{"a=b", "c=d"},
IO: io,
HttpClient: func() (*http.Client, error) {
var tr roundTripper = func(req *http.Request) (*http.Response, error) {
resp.Request = req
return resp, nil
}
return &http.Client{Transport: tr}, nil
contentLength int64
expectedContents []byte
}{
{
name: "stdin",
inputFile: "-",
inputContents: []byte("I WORK OUT"),
contentLength: 0,
},
{
name: "from file",
inputFile: "gh-test-file",
inputContents: []byte("I WORK OUT"),
contentLength: 10,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
io, stdin, _, _ := iostreams.Test()
resp := &http.Response{StatusCode: 204}
fmt.Fprintln(stdin, "I WORK OUT")
inputFile := tt.inputFile
if tt.inputFile == "-" {
_, _ = stdin.Write(tt.inputContents)
} else {
f, err := ioutil.TempFile("", tt.inputFile)
if err != nil {
t.Fatal(err)
}
_, _ = f.Write(tt.inputContents)
f.Close()
t.Cleanup(func() { os.Remove(f.Name()) })
inputFile = f.Name()
}
err := apiRun(&options)
if err != nil {
t.Errorf("got error %v", err)
var bodyBytes []byte
options := ApiOptions{
RequestPath: "hello",
RequestInputFile: inputFile,
RawFields: []string{"a=b", "c=d"},
IO: io,
HttpClient: func() (*http.Client, error) {
var tr roundTripper = func(req *http.Request) (*http.Response, error) {
var err error
if bodyBytes, err = ioutil.ReadAll(req.Body); err != nil {
return nil, err
}
resp.Request = req
return resp, nil
}
return &http.Client{Transport: tr}, nil
},
}
err := apiRun(&options)
if err != nil {
t.Errorf("got error %v", err)
}
assert.Equal(t, "POST", resp.Request.Method)
assert.Equal(t, "/hello?a=b&c=d", resp.Request.URL.RequestURI())
assert.Equal(t, tt.contentLength, resp.Request.ContentLength)
assert.Equal(t, "", resp.Request.Header.Get("Content-Type"))
assert.Equal(t, tt.inputContents, bodyBytes)
})
}
assert.Equal(t, "POST", resp.Request.Method)
assert.Equal(t, "/hello?a=b&c=d", resp.Request.URL.RequestURI())
assert.Equal(t, "", resp.Request.Header.Get("Content-Length"))
assert.Equal(t, "", resp.Request.Header.Get("Content-Type"))
bb, err := ioutil.ReadAll(resp.Request.Body)
if err != nil {
t.Errorf("got error %v", err)
}
assert.Equal(t, "I WORK OUT\n", string(bb))
}
func Test_parseFields(t *testing.T) {
@ -400,7 +437,7 @@ func Test_openUserFile(t *testing.T) {
f.Close()
t.Cleanup(func() { os.Remove(f.Name()) })
file, err := openUserFile(f.Name(), nil)
file, length, err := openUserFile(f.Name(), nil)
if err != nil {
t.Fatal(err)
}
@ -411,5 +448,6 @@ func Test_openUserFile(t *testing.T) {
t.Fatal(err)
}
assert.Equal(t, int64(13), length)
assert.Equal(t, "file contents", string(fb))
}

View file

@ -7,6 +7,7 @@ import (
"io"
"net/http"
"net/url"
"strconv"
"strings"
)
@ -62,7 +63,16 @@ func httpRequest(client *http.Client, method string, p string, params interface{
if idx == -1 {
return nil, fmt.Errorf("header %q requires a value separated by ':'", h)
}
req.Header.Add(h[0:idx], strings.TrimSpace(h[idx+1:]))
name, value := h[0:idx], strings.TrimSpace(h[idx+1:])
if strings.EqualFold(name, "Content-Length") {
length, err := strconv.ParseInt(value, 10, 0)
if err != nil {
return nil, err
}
req.ContentLength = length
} else {
req.Header.Add(name, value)
}
}
if bodyIsJSON && req.Header.Get("Content-Type") == "" {
req.Header.Set("Content-Type", "application/json; charset=utf-8")