diff --git a/api/sanitize_ascii.go b/api/sanitize_ascii.go index 92741a147..6033a07a6 100644 --- a/api/sanitize_ascii.go +++ b/api/sanitize_ascii.go @@ -32,8 +32,8 @@ func AddASCIISanitizer(rt http.RoundTripper) http.RoundTripper { // sanitizeASCIIReadCloser implements the ReadCloser interface. type sanitizeASCIIReadCloser struct { io.ReadCloser - addBackslash bool - previousWindow []byte + addEscape bool + remainder []byte } // Read uses a sliding window alogorithm to detect C0 and C1 @@ -41,14 +41,11 @@ type sanitizeASCIIReadCloser struct { // with equivelent inert characters. Characters that are not part // of a control sequence not modified. func (s *sanitizeASCIIReadCloser) Read(out []byte) (int, error) { - var readErr error - var outIndex int - var bufIndex int - var bufLen int - var window []byte - buf := make([]byte, len(out)) + var bufIndex, outIndex int + outLen := len(out) + buf := make([]byte, outLen) - bufLen, readErr = s.ReadCloser.Read(buf) + bufLen, readErr := s.ReadCloser.Read(buf) if readErr != nil && !errors.Is(readErr, io.EOF) { if bufLen > 0 { // Do not sanitize if there was a read error that is not EOF. @@ -56,38 +53,39 @@ func (s *sanitizeASCIIReadCloser) Read(out []byte) (int, error) { } return bufLen, readErr } + buf = buf[:bufLen] - if s.previousWindow != nil { - buf = append(s.previousWindow, buf...) - bufLen += len(s.previousWindow) + if s.remainder != nil { + buf = append(s.remainder, buf...) + bufLen += len(s.remainder) + s.remainder = s.remainder[:0] } - for { - remaining := min(6, (bufLen - bufIndex)) - window = buf[bufIndex : bufIndex+remaining] - if remaining < 6 { - break - } + for bufIndex < bufLen-6 && outIndex < outLen { + window := buf[bufIndex : bufIndex+6] if bytes.HasPrefix(window, []byte(`\u00`)) { repl, _ := mapControlCharacterToCaret(window) - if s.addBackslash { - repl = append([]byte{92}, repl...) + if s.addEscape { + repl = append([]byte{'\\'}, repl...) + s.addEscape = false } - l := len(repl) - for j := 0; j < l; j++ { - out[outIndex] = repl[j] - outIndex++ + for j := 0; j < len(repl); j++ { + if outIndex < outLen { + out[outIndex] = repl[j] + outIndex++ + } else { + s.remainder = append(s.remainder, repl[j]) + } } bufIndex += 6 - s.addBackslash = false continue } if window[0] == '\\' { - s.addBackslash = !s.addBackslash + s.addEscape = !s.addEscape } else { - s.addBackslash = false + s.addEscape = false } out[outIndex] = buf[bufIndex] @@ -98,12 +96,23 @@ func (s *sanitizeASCIIReadCloser) Read(out []byte) (int, error) { if readErr != nil && errors.Is(readErr, io.EOF) { remaining := bufLen - bufIndex for j := 0; j < remaining; j++ { - out[outIndex] = window[j] - outIndex++ - bufIndex++ + if outIndex < outLen { + out[outIndex] = buf[bufIndex] + outIndex++ + bufIndex++ + } else { + s.remainder = append(s.remainder, buf[bufIndex]) + bufIndex++ + } } } else { - s.previousWindow = window + if bufIndex < bufLen { + s.remainder = append(s.remainder, buf[bufIndex:]...) + } + } + + if len(s.remainder) != 0 { + readErr = nil } return outIndex, readErr @@ -184,10 +193,3 @@ func mapControlCharacterToCaret(b []byte) ([]byte, bool) { } return b, false } - -func min(a, b int) int { - if a < b { - return a - } - return b -} diff --git a/api/sanitize_ascii_test.go b/api/sanitize_ascii_test.go index 9b405edc8..ff43f9287 100644 --- a/api/sanitize_ascii_test.go +++ b/api/sanitize_ascii_test.go @@ -1,12 +1,14 @@ package api import ( + "bytes" "encoding/json" "fmt" "io" "net/http" "net/http/httptest" "testing" + "testing/iotest" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -48,3 +50,13 @@ func TestHTTPClient_SanitizeASCIIControlCharacters(t *testing.T) { assert.Equal(t, "monalisa", issue.Author.Login) assert.Equal(t, "Escaped ^[ \\^[ \\^[ \\\\^[", issue.ActiveLockReason) } + +func TestSanitizeASCIIReadCloser(t *testing.T) { + data := []byte(`"Assign},"L`) + var r io.Reader = bytes.NewReader(data) + r = &sanitizeASCIIReadCloser{ReadCloser: io.NopCloser(r)} + r = iotest.OneByteReader(r) + out, err := io.ReadAll(r) + require.NoError(t, err) + assert.Equal(t, data, out) +}