Use transform.Transformer interface for ascii sanitization (#7117)

This commit is contained in:
Sam Coe 2023-03-13 08:36:46 +11:00 committed by GitHub
parent 661d962112
commit 78ffa73f44
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 69 additions and 82 deletions

View file

@ -2,11 +2,12 @@ package api
import (
"bytes"
"errors"
"io"
"net/http"
"regexp"
"strings"
"golang.org/x/text/transform"
)
var jsonTypeRE = regexp.MustCompile(`[/+]json($|;)`)
@ -26,114 +27,104 @@ func AddASCIISanitizer(rt http.RoundTripper) http.RoundTripper {
if err != nil || !jsonTypeRE.MatchString(res.Header.Get("Content-Type")) {
return res, err
}
res.Body = &sanitizeASCIIReadCloser{ReadCloser: res.Body}
res.Body = sanitizedReadCloser(res.Body)
return res, err
}}
}
// sanitizeASCIIReadCloser implements the ReadCloser interface.
type sanitizeASCIIReadCloser struct {
io.ReadCloser
addEscape bool
remainder []byte
// sanitizedReadCloser wraps rc so that everything read from it is passed
// through a fresh sanitizer via transform.NewReader, while Close is still
// delegated to the original rc.
func sanitizedReadCloser(rc io.ReadCloser) io.ReadCloser {
return struct {
io.Reader
io.Closer
}{
Reader: transform.NewReader(rc, &sanitizer{}),
Closer: rc,
}
}
// Read uses a sliding window algorithm to detect C0 and C1
// sanitizer implements the transform.Transformer interface.
type sanitizer struct {
// addEscape is toggled by Transform on every backslash byte it copies and
// cleared by any other byte; when it is set and a C0 escape sequence is
// replaced, the replacement is prefixed with an extra backslash.
addEscape bool
}
// Transform uses a sliding window algorithm to detect C0 and C1
// ASCII control sequences as they are read and replaces them
// with equivalent inert characters. Characters that are not part
// of a control sequence not modified.
func (s *sanitizeASCIIReadCloser) Read(out []byte) (int, error) {
var bufIndex, outIndex int
outLen := len(out)
buf := make([]byte, outLen)
// of a control sequence are not modified.
func (t *sanitizer) Transform(dst, src []byte, atEOF bool) (nDst, nSrc int, err error) {
lSrc := len(src)
lDst := len(dst)
bufLen, readErr := s.ReadCloser.Read(buf)
if readErr != nil && !errors.Is(readErr, io.EOF) {
if bufLen > 0 {
// Do not sanitize if there was a read error that is not EOF.
bufLen = copy(out, buf)
}
return bufLen, readErr
}
buf = buf[:bufLen]
if s.remainder != nil {
buf = append(s.remainder, buf...)
bufLen += len(s.remainder)
s.remainder = s.remainder[:0]
}
for bufIndex < bufLen-6 && outIndex < outLen {
window := buf[bufIndex : bufIndex+6]
for nSrc < lSrc-6 && nDst < lDst {
window := src[nSrc : nSrc+6]
// Replace C1 Control Characters
if window[0] == 0xC2 {
repl, _ := mapC1ToCaret(window[:2])
for j := 0; j < len(repl); j++ {
if outIndex < outLen {
out[outIndex] = repl[j]
outIndex++
} else {
s.remainder = append(s.remainder, repl[j])
}
if len(repl)+nDst > lDst {
err = transform.ErrShortDst
return
}
bufIndex += 2
for j := 0; j < len(repl); j++ {
dst[nDst] = repl[j]
nDst++
}
nSrc += 2
continue
}
// Replace C0 Control Characters
if bytes.HasPrefix(window, []byte(`\u00`)) {
repl, found := mapC0ToCaret(window)
if s.addEscape && found {
if t.addEscape && found {
repl = append([]byte{'\\'}, repl...)
}
s.addEscape = false
for j := 0; j < len(repl); j++ {
if outIndex < outLen {
out[outIndex] = repl[j]
outIndex++
} else {
s.remainder = append(s.remainder, repl[j])
}
if len(repl)+nDst > lDst {
err = transform.ErrShortDst
return
}
bufIndex += 6
for j := 0; j < len(repl); j++ {
dst[nDst] = repl[j]
nDst++
}
t.addEscape = false
nSrc += 6
continue
}
if window[0] == '\\' {
s.addEscape = !s.addEscape
t.addEscape = !t.addEscape
} else {
s.addEscape = false
t.addEscape = false
}
out[outIndex] = buf[bufIndex]
outIndex++
bufIndex++
dst[nDst] = src[nSrc]
nDst++
nSrc++
}
if readErr != nil && errors.Is(readErr, io.EOF) {
remaining := bufLen - bufIndex
for j := 0; j < remaining; j++ {
if outIndex < outLen {
out[outIndex] = buf[bufIndex]
outIndex++
bufIndex++
} else {
s.remainder = append(s.remainder, buf[bufIndex])
bufIndex++
}
}
} else {
if bufIndex < bufLen {
s.remainder = append(s.remainder, buf[bufIndex:]...)
}
if !atEOF {
err = transform.ErrShortSrc
return
}
if len(s.remainder) != 0 {
readErr = nil
remaining := lSrc - nSrc
if remaining+nDst > lDst {
err = transform.ErrShortDst
return
}
return outIndex, readErr
for j := 0; j < remaining; j++ {
dst[nDst] = src[nSrc]
nDst++
nSrc++
}
return
}
// Reset clears the backslash-tracking state so the sanitizer starts from
// its initial state. It is part of the transform.Transformer interface.
func (t *sanitizer) Reset() {
t.addEscape = false
}
// mapC0ToCaret maps C0 control sequences to caret notation.
@ -179,7 +170,7 @@ func mapC0ToCaret(b []byte) ([]byte, bool) {
}
// mapC1ToCaret maps C1 control sequences to caret notation.
// C1 control sequences are two bytes and start with 0xC2.
// C1 control sequences are two bytes long where the first byte is 0xC2.
func mapC1ToCaret(b []byte) ([]byte, bool) {
if len(b) != 2 {
return b, false

View file

@ -86,12 +86,8 @@ func TestHTTPClientSanitizeASCIIControlCharactersC1(t *testing.T) {
assert.Equal(t, "monalisa¡", issue.Author.Login)
}
func TestSanitizeASCIIReadCloser(t *testing.T) {
data := []byte(`"Assign},"L`)
var r io.Reader = bytes.NewReader(data)
r = &sanitizeASCIIReadCloser{ReadCloser: io.NopCloser(r)}
r = iotest.OneByteReader(r)
out, err := io.ReadAll(r)
require.NoError(t, err)
assert.Equal(t, data, out)
// TestSanitizedReadCloser runs iotest.TestReader against a sanitized reader
// and expects the input to come back unchanged. The input is a raw string
// literal, so `\n` and `\t` are literal backslash-letter pairs rather than
// control characters.
func TestSanitizedReadCloser(t *testing.T) {
data := []byte(`the quick brown fox\njumped over the lazy dog\t`)
rc := sanitizedReadCloser(io.NopCloser(bytes.NewReader(data)))
assert.NoError(t, iotest.TestReader(rc, data))
}