Use transform.Transformer interface for ascii sanitization (#7117)
This commit is contained in:
parent
661d962112
commit
78ffa73f44
2 changed files with 69 additions and 82 deletions
|
|
@ -2,11 +2,12 @@ package api
|
|||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"io"
|
||||
"net/http"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"golang.org/x/text/transform"
|
||||
)
|
||||
|
||||
var jsonTypeRE = regexp.MustCompile(`[/+]json($|;)`)
|
||||
|
|
@ -26,114 +27,104 @@ func AddASCIISanitizer(rt http.RoundTripper) http.RoundTripper {
|
|||
if err != nil || !jsonTypeRE.MatchString(res.Header.Get("Content-Type")) {
|
||||
return res, err
|
||||
}
|
||||
res.Body = &sanitizeASCIIReadCloser{ReadCloser: res.Body}
|
||||
res.Body = sanitizedReadCloser(res.Body)
|
||||
return res, err
|
||||
}}
|
||||
}
|
||||
|
||||
// sanitizeASCIIReadCloser implements the ReadCloser interface.
|
||||
type sanitizeASCIIReadCloser struct {
|
||||
io.ReadCloser
|
||||
addEscape bool
|
||||
remainder []byte
|
||||
func sanitizedReadCloser(rc io.ReadCloser) io.ReadCloser {
|
||||
return struct {
|
||||
io.Reader
|
||||
io.Closer
|
||||
}{
|
||||
Reader: transform.NewReader(rc, &sanitizer{}),
|
||||
Closer: rc,
|
||||
}
|
||||
}
|
||||
|
||||
// Read uses a sliding window alogorithm to detect C0 and C1
|
||||
// Sanitizer implements transform.Transformer interface.
|
||||
type sanitizer struct {
|
||||
addEscape bool
|
||||
}
|
||||
|
||||
// Transform uses a sliding window alogorithm to detect C0 and C1
|
||||
// ASCII control sequences as they are read and replaces them
|
||||
// with equivelent inert characters. Characters that are not part
|
||||
// of a control sequence not modified.
|
||||
func (s *sanitizeASCIIReadCloser) Read(out []byte) (int, error) {
|
||||
var bufIndex, outIndex int
|
||||
outLen := len(out)
|
||||
buf := make([]byte, outLen)
|
||||
// of a control sequence are not modified.
|
||||
func (t *sanitizer) Transform(dst, src []byte, atEOF bool) (nDst, nSrc int, err error) {
|
||||
lSrc := len(src)
|
||||
lDst := len(dst)
|
||||
|
||||
bufLen, readErr := s.ReadCloser.Read(buf)
|
||||
if readErr != nil && !errors.Is(readErr, io.EOF) {
|
||||
if bufLen > 0 {
|
||||
// Do not sanitize if there was a read error that is not EOF.
|
||||
bufLen = copy(out, buf)
|
||||
}
|
||||
return bufLen, readErr
|
||||
}
|
||||
buf = buf[:bufLen]
|
||||
|
||||
if s.remainder != nil {
|
||||
buf = append(s.remainder, buf...)
|
||||
bufLen += len(s.remainder)
|
||||
s.remainder = s.remainder[:0]
|
||||
}
|
||||
|
||||
for bufIndex < bufLen-6 && outIndex < outLen {
|
||||
window := buf[bufIndex : bufIndex+6]
|
||||
for nSrc < lSrc-6 && nDst < lDst {
|
||||
window := src[nSrc : nSrc+6]
|
||||
|
||||
// Replace C1 Control Characters
|
||||
if window[0] == 0xC2 {
|
||||
repl, _ := mapC1ToCaret(window[:2])
|
||||
for j := 0; j < len(repl); j++ {
|
||||
if outIndex < outLen {
|
||||
out[outIndex] = repl[j]
|
||||
outIndex++
|
||||
} else {
|
||||
s.remainder = append(s.remainder, repl[j])
|
||||
}
|
||||
if len(repl)+nDst > lDst {
|
||||
err = transform.ErrShortDst
|
||||
return
|
||||
}
|
||||
bufIndex += 2
|
||||
for j := 0; j < len(repl); j++ {
|
||||
dst[nDst] = repl[j]
|
||||
nDst++
|
||||
}
|
||||
nSrc += 2
|
||||
continue
|
||||
}
|
||||
|
||||
// Replace C0 Control Characters
|
||||
if bytes.HasPrefix(window, []byte(`\u00`)) {
|
||||
repl, found := mapC0ToCaret(window)
|
||||
if s.addEscape && found {
|
||||
if t.addEscape && found {
|
||||
repl = append([]byte{'\\'}, repl...)
|
||||
}
|
||||
s.addEscape = false
|
||||
for j := 0; j < len(repl); j++ {
|
||||
if outIndex < outLen {
|
||||
out[outIndex] = repl[j]
|
||||
outIndex++
|
||||
} else {
|
||||
s.remainder = append(s.remainder, repl[j])
|
||||
}
|
||||
if len(repl)+nDst > lDst {
|
||||
err = transform.ErrShortDst
|
||||
return
|
||||
}
|
||||
bufIndex += 6
|
||||
for j := 0; j < len(repl); j++ {
|
||||
dst[nDst] = repl[j]
|
||||
nDst++
|
||||
}
|
||||
t.addEscape = false
|
||||
nSrc += 6
|
||||
continue
|
||||
}
|
||||
|
||||
if window[0] == '\\' {
|
||||
s.addEscape = !s.addEscape
|
||||
t.addEscape = !t.addEscape
|
||||
} else {
|
||||
s.addEscape = false
|
||||
t.addEscape = false
|
||||
}
|
||||
|
||||
out[outIndex] = buf[bufIndex]
|
||||
outIndex++
|
||||
bufIndex++
|
||||
dst[nDst] = src[nSrc]
|
||||
nDst++
|
||||
nSrc++
|
||||
}
|
||||
|
||||
if readErr != nil && errors.Is(readErr, io.EOF) {
|
||||
remaining := bufLen - bufIndex
|
||||
for j := 0; j < remaining; j++ {
|
||||
if outIndex < outLen {
|
||||
out[outIndex] = buf[bufIndex]
|
||||
outIndex++
|
||||
bufIndex++
|
||||
} else {
|
||||
s.remainder = append(s.remainder, buf[bufIndex])
|
||||
bufIndex++
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if bufIndex < bufLen {
|
||||
s.remainder = append(s.remainder, buf[bufIndex:]...)
|
||||
}
|
||||
if !atEOF {
|
||||
err = transform.ErrShortSrc
|
||||
return
|
||||
}
|
||||
|
||||
if len(s.remainder) != 0 {
|
||||
readErr = nil
|
||||
remaining := lSrc - nSrc
|
||||
if remaining+nDst > lDst {
|
||||
err = transform.ErrShortDst
|
||||
return
|
||||
}
|
||||
|
||||
return outIndex, readErr
|
||||
for j := 0; j < remaining; j++ {
|
||||
dst[nDst] = src[nSrc]
|
||||
nDst++
|
||||
nSrc++
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (t *sanitizer) Reset() {
|
||||
t.addEscape = false
|
||||
}
|
||||
|
||||
// mapC0ToCaret maps C0 control sequences to caret notation.
|
||||
|
|
@ -179,7 +170,7 @@ func mapC0ToCaret(b []byte) ([]byte, bool) {
|
|||
}
|
||||
|
||||
// mapC1ToCaret maps C1 control sequences to caret notation.
|
||||
// C1 control sequences are two bytes and start with 0xC2.
|
||||
// C1 control sequences are two bytes long where the first byte is 0xC2.
|
||||
func mapC1ToCaret(b []byte) ([]byte, bool) {
|
||||
if len(b) != 2 {
|
||||
return b, false
|
||||
|
|
|
|||
|
|
@ -86,12 +86,8 @@ func TestHTTPClientSanitizeASCIIControlCharactersC1(t *testing.T) {
|
|||
assert.Equal(t, "monalisa¡", issue.Author.Login)
|
||||
}
|
||||
|
||||
func TestSanitizeASCIIReadCloser(t *testing.T) {
|
||||
data := []byte(`"Assign},"L`)
|
||||
var r io.Reader = bytes.NewReader(data)
|
||||
r = &sanitizeASCIIReadCloser{ReadCloser: io.NopCloser(r)}
|
||||
r = iotest.OneByteReader(r)
|
||||
out, err := io.ReadAll(r)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, data, out)
|
||||
func TestSanitizedReadCloser(t *testing.T) {
|
||||
data := []byte(`the quick brown fox\njumped over the lazy dog\t`)
|
||||
rc := sanitizedReadCloser(io.NopCloser(bytes.NewReader(data)))
|
||||
assert.NoError(t, iotest.TestReader(rc, data))
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue