From 78ffa73f4442e3ebfb1121dc13931e45a14a96e2 Mon Sep 17 00:00:00 2001 From: Sam Coe Date: Mon, 13 Mar 2023 08:36:46 +1100 Subject: [PATCH] Use transform.Transformer interface for ascii sanitization (#7117) --- api/sanitize_ascii.go | 139 +++++++++++++++++-------------------- api/sanitize_ascii_test.go | 12 ++-- 2 files changed, 69 insertions(+), 82 deletions(-) diff --git a/api/sanitize_ascii.go b/api/sanitize_ascii.go index 951cc1bf8..93ecb4751 100644 --- a/api/sanitize_ascii.go +++ b/api/sanitize_ascii.go @@ -2,11 +2,12 @@ package api import ( "bytes" - "errors" "io" "net/http" "regexp" "strings" + + "golang.org/x/text/transform" ) var jsonTypeRE = regexp.MustCompile(`[/+]json($|;)`) @@ -26,114 +27,104 @@ func AddASCIISanitizer(rt http.RoundTripper) http.RoundTripper { if err != nil || !jsonTypeRE.MatchString(res.Header.Get("Content-Type")) { return res, err } - res.Body = &sanitizeASCIIReadCloser{ReadCloser: res.Body} + res.Body = sanitizedReadCloser(res.Body) return res, err }} } -// sanitizeASCIIReadCloser implements the ReadCloser interface. -type sanitizeASCIIReadCloser struct { - io.ReadCloser - addEscape bool - remainder []byte +func sanitizedReadCloser(rc io.ReadCloser) io.ReadCloser { + return struct { + io.Reader + io.Closer + }{ + Reader: transform.NewReader(rc, &sanitizer{}), + Closer: rc, + } } -// Read uses a sliding window alogorithm to detect C0 and C1 +// Sanitizer implements transform.Transformer interface. +type sanitizer struct { + addEscape bool +} + +// Transform uses a sliding window alogorithm to detect C0 and C1 // ASCII control sequences as they are read and replaces them // with equivelent inert characters. Characters that are not part -// of a control sequence not modified. -func (s *sanitizeASCIIReadCloser) Read(out []byte) (int, error) { - var bufIndex, outIndex int - outLen := len(out) - buf := make([]byte, outLen) +// of a control sequence are not modified. +func (t *sanitizer) Transform(dst, src []byte, atEOF bool) (nDst, nSrc int, err error) { + lSrc := len(src) + lDst := len(dst) - bufLen, readErr := s.ReadCloser.Read(buf) - if readErr != nil && !errors.Is(readErr, io.EOF) { - if bufLen > 0 { - // Do not sanitize if there was a read error that is not EOF. - bufLen = copy(out, buf) - } - return bufLen, readErr - } - buf = buf[:bufLen] - - if s.remainder != nil { - buf = append(s.remainder, buf...) - bufLen += len(s.remainder) - s.remainder = s.remainder[:0] - } - - for bufIndex < bufLen-6 && outIndex < outLen { - window := buf[bufIndex : bufIndex+6] + for nSrc < lSrc-6 && nDst < lDst { + window := src[nSrc : nSrc+6] // Replace C1 Control Characters if window[0] == 0xC2 { repl, _ := mapC1ToCaret(window[:2]) - for j := 0; j < len(repl); j++ { - if outIndex < outLen { - out[outIndex] = repl[j] - outIndex++ - } else { - s.remainder = append(s.remainder, repl[j]) - } + if len(repl)+nDst > lDst { + err = transform.ErrShortDst + return } - bufIndex += 2 + for j := 0; j < len(repl); j++ { + dst[nDst] = repl[j] + nDst++ + } + nSrc += 2 continue } // Replace C0 Control Characters if bytes.HasPrefix(window, []byte(`\u00`)) { repl, found := mapC0ToCaret(window) - if s.addEscape && found { + if t.addEscape && found { repl = append([]byte{'\\'}, repl...) } - s.addEscape = false - for j := 0; j < len(repl); j++ { - if outIndex < outLen { - out[outIndex] = repl[j] - outIndex++ - } else { - s.remainder = append(s.remainder, repl[j]) - } + if len(repl)+nDst > lDst { + err = transform.ErrShortDst + return } - bufIndex += 6 + for j := 0; j < len(repl); j++ { + dst[nDst] = repl[j] + nDst++ + } + t.addEscape = false + nSrc += 6 continue } if window[0] == '\\' { - s.addEscape = !s.addEscape + t.addEscape = !t.addEscape } else { - s.addEscape = false + t.addEscape = false } - out[outIndex] = buf[bufIndex] - outIndex++ - bufIndex++ + dst[nDst] = src[nSrc] + nDst++ + nSrc++ } - if readErr != nil && errors.Is(readErr, io.EOF) { - remaining := bufLen - bufIndex - for j := 0; j < remaining; j++ { - if outIndex < outLen { - out[outIndex] = buf[bufIndex] - outIndex++ - bufIndex++ - } else { - s.remainder = append(s.remainder, buf[bufIndex]) - bufIndex++ - } - } - } else { - if bufIndex < bufLen { - s.remainder = append(s.remainder, buf[bufIndex:]...) - } + if !atEOF { + err = transform.ErrShortSrc + return } - if len(s.remainder) != 0 { - readErr = nil + remaining := lSrc - nSrc + if remaining+nDst > lDst { + err = transform.ErrShortDst + return } - return outIndex, readErr + for j := 0; j < remaining; j++ { + dst[nDst] = src[nSrc] + nDst++ + nSrc++ + } + + return +} + +func (t *sanitizer) Reset() { + t.addEscape = false } // mapC0ToCaret maps C0 control sequences to caret notation. @@ -179,7 +170,7 @@ func mapC0ToCaret(b []byte) ([]byte, bool) { } // mapC1ToCaret maps C1 control sequences to caret notation. -// C1 control sequences are two bytes and start with 0xC2. +// C1 control sequences are two bytes long where the first byte is 0xC2. func mapC1ToCaret(b []byte) ([]byte, bool) { if len(b) != 2 { return b, false diff --git a/api/sanitize_ascii_test.go b/api/sanitize_ascii_test.go index 6188e1e0e..9f5c37059 100644 --- a/api/sanitize_ascii_test.go +++ b/api/sanitize_ascii_test.go @@ -86,12 +86,8 @@ func TestHTTPClientSanitizeASCIIControlCharactersC1(t *testing.T) { assert.Equal(t, "monalisa¡", issue.Author.Login) } -func TestSanitizeASCIIReadCloser(t *testing.T) { - data := []byte(`"Assign},"L`) - var r io.Reader = bytes.NewReader(data) - r = &sanitizeASCIIReadCloser{ReadCloser: io.NopCloser(r)} - r = iotest.OneByteReader(r) - out, err := io.ReadAll(r) - require.NoError(t, err) - assert.Equal(t, data, out) +func TestSanitizedReadCloser(t *testing.T) { + data := []byte(`the quick brown fox\njumped over the lazy dog\t`) + rc := sanitizedReadCloser(io.NopCloser(bytes.NewReader(data))) + assert.NoError(t, iotest.TestReader(rc, data)) }