Fix panic in ASCII sanitization (#6956)

This commit is contained in:
Sam Coe 2023-02-07 01:24:57 +11:00 committed by GitHub
parent 9f426bf615
commit 0dd7e9c36f
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 52 additions and 38 deletions

View file

@ -32,8 +32,8 @@ func AddASCIISanitizer(rt http.RoundTripper) http.RoundTripper {
// sanitizeASCIIReadCloser implements the ReadCloser interface.
// It wraps another ReadCloser and carries state between successive Read calls.
// NOTE(review): this view looks like a rendered diff, so the fields from before
// the change (addBackslash, previousWindow) and after it (addEscape, remainder)
// appear side by side; confirm against the real file which pair is current.
type sanitizeASCIIReadCloser struct {
// Embedded source reader; Read pulls raw bytes from it.
io.ReadCloser
// addBackslash is toggled by Read on every '\\' byte and cleared on any other
// byte; when set, Read prepends a backslash to the caret replacement.
addBackslash bool
// previousWindow is the trailing window saved at the end of a non-EOF Read
// and prepended to the data fetched by the next Read.
previousWindow []byte
// addEscape is toggled by Read on every '\\' byte and cleared on any other
// byte; when set, Read prepends a backslash to the caret replacement.
addEscape bool
// remainder holds bytes that could not be emitted during a Read (unprocessed
// input or replacement bytes that did not fit in out); the next Read prepends
// them to the newly read data. While it is non-empty Read reports a nil error.
remainder []byte
}
// Read uses a sliding window algorithm to detect C0 and C1
@ -41,14 +41,11 @@ type sanitizeASCIIReadCloser struct {
// with equivalent inert characters. Characters that are not part
// of a control sequence are not modified.
func (s *sanitizeASCIIReadCloser) Read(out []byte) (int, error) {
var readErr error
var outIndex int
var bufIndex int
var bufLen int
var window []byte
buf := make([]byte, len(out))
var bufIndex, outIndex int
outLen := len(out)
buf := make([]byte, outLen)
bufLen, readErr = s.ReadCloser.Read(buf)
bufLen, readErr := s.ReadCloser.Read(buf)
if readErr != nil && !errors.Is(readErr, io.EOF) {
if bufLen > 0 {
// Do not sanitize if there was a read error that is not EOF.
@ -56,38 +53,39 @@ func (s *sanitizeASCIIReadCloser) Read(out []byte) (int, error) {
}
return bufLen, readErr
}
buf = buf[:bufLen]
if s.previousWindow != nil {
buf = append(s.previousWindow, buf...)
bufLen += len(s.previousWindow)
if s.remainder != nil {
buf = append(s.remainder, buf...)
bufLen += len(s.remainder)
s.remainder = s.remainder[:0]
}
for {
remaining := min(6, (bufLen - bufIndex))
window = buf[bufIndex : bufIndex+remaining]
if remaining < 6 {
break
}
for bufIndex < bufLen-6 && outIndex < outLen {
window := buf[bufIndex : bufIndex+6]
if bytes.HasPrefix(window, []byte(`\u00`)) {
repl, _ := mapControlCharacterToCaret(window)
if s.addBackslash {
repl = append([]byte{92}, repl...)
if s.addEscape {
repl = append([]byte{'\\'}, repl...)
s.addEscape = false
}
l := len(repl)
for j := 0; j < l; j++ {
out[outIndex] = repl[j]
outIndex++
for j := 0; j < len(repl); j++ {
if outIndex < outLen {
out[outIndex] = repl[j]
outIndex++
} else {
s.remainder = append(s.remainder, repl[j])
}
}
bufIndex += 6
s.addBackslash = false
continue
}
if window[0] == '\\' {
s.addBackslash = !s.addBackslash
s.addEscape = !s.addEscape
} else {
s.addBackslash = false
s.addEscape = false
}
out[outIndex] = buf[bufIndex]
@ -98,12 +96,23 @@ func (s *sanitizeASCIIReadCloser) Read(out []byte) (int, error) {
if readErr != nil && errors.Is(readErr, io.EOF) {
remaining := bufLen - bufIndex
for j := 0; j < remaining; j++ {
out[outIndex] = window[j]
outIndex++
bufIndex++
if outIndex < outLen {
out[outIndex] = buf[bufIndex]
outIndex++
bufIndex++
} else {
s.remainder = append(s.remainder, buf[bufIndex])
bufIndex++
}
}
} else {
s.previousWindow = window
if bufIndex < bufLen {
s.remainder = append(s.remainder, buf[bufIndex:]...)
}
}
if len(s.remainder) != 0 {
readErr = nil
}
return outIndex, readErr
@ -184,10 +193,3 @@ func mapControlCharacterToCaret(b []byte) ([]byte, bool) {
}
return b, false
}
// min returns the smaller of the two integers a and b.
func min(a, b int) int {
	if b <= a {
		return b
	}
	return a
}

View file

@ -1,12 +1,14 @@
package api
import (
"bytes"
"encoding/json"
"fmt"
"io"
"net/http"
"net/http/httptest"
"testing"
"testing/iotest"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@ -48,3 +50,13 @@ func TestHTTPClient_SanitizeASCIIControlCharacters(t *testing.T) {
assert.Equal(t, "monalisa", issue.Author.Login)
assert.Equal(t, "Escaped ^[ \\^[ \\^[ \\\\^[", issue.ActiveLockReason)
}
// TestSanitizeASCIIReadCloser drains the sanitizer one byte at a time and
// checks that input without any control sequences comes back unchanged and
// without an error.
func TestSanitizeASCIIReadCloser(t *testing.T) {
	input := []byte(`"Assign},"L`)

	// Wrap the raw bytes in the sanitizer under test.
	var sanitized io.Reader = &sanitizeASCIIReadCloser{
		ReadCloser: io.NopCloser(bytes.NewReader(input)),
	}

	// OneByteReader forces every Read call to use a single-byte output slice.
	got, err := io.ReadAll(iotest.OneByteReader(sanitized))

	require.NoError(t, err)
	assert.Equal(t, input, got)
}