[FIXED] Websocket compression/decompression issue with continuation frames

For compression, continuation frames had the compress bit set, which is
wrong since only the first frame should.
For decompression, continuation frames were decompressed individually
instead of assembling the full payload and then decompressing.

Resolves #2287

Signed-off-by: Ivan Kozlovic <ivan@synadia.com>
This commit is contained in:
Ivan Kozlovic
2021-06-21 11:09:19 -06:00
parent 6129562b63
commit 189336417f
2 changed files with 156 additions and 34 deletions

View File

@@ -75,7 +75,6 @@ const (
wsFirstFrame = true
wsContFrame = false
wsFinalFrame = true
wsCompressedFrame = true
wsUncompressedFrame = false
wsSchemePrefix = "ws"
@@ -92,6 +91,7 @@ const (
)
var decompressorPool sync.Pool
var compressLastBlock = []byte{0x00, 0x00, 0xff, 0xff, 0x01, 0x00, 0x00, 0xff, 0xff}
// From https://tools.ietf.org/html/rfc6455#section-1.3
var wsGUID = []byte("258EAFA5-E914-47DA-95CA-C5AB0DC85B11")
@@ -144,7 +144,8 @@ type wsReadInfo struct {
mask bool // Incoming leafnode connections may not have masking.
mkpos byte
mkey [4]byte
buf []byte
cbufs [][]byte
coff int
}
func (r *wsReadInfo) init() {
@@ -292,42 +293,118 @@ func (c *client) wsRead(r *wsReadInfo, ior io.Reader, buf []byte) ([][]byte, err
b = buf[pos : pos+n]
pos += n
r.rem -= n
if r.fc {
r.buf = append(r.buf, b...)
b = r.buf
// If needed, unmask the buffer
if r.mask {
r.unmask(b)
}
if !r.fc || r.rem == 0 {
if r.mask {
r.unmask(b)
}
if r.fc {
// As per https://tools.ietf.org/html/rfc7692#section-7.2.2
// add 0x00, 0x00, 0xff, 0xff and then a final block so that flate reader
// does not report unexpected EOF.
b = append(b, 0x00, 0x00, 0xff, 0xff, 0x01, 0x00, 0x00, 0xff, 0xff)
br := bytes.NewBuffer(b)
d, _ := decompressorPool.Get().(io.ReadCloser)
if d == nil {
d = flate.NewReader(br)
} else {
d.(flate.Resetter).Reset(br, nil)
}
b, err = ioutil.ReadAll(d)
decompressorPool.Put(d)
addToBufs := true
// Handle compressed message
if r.fc {
// Assume that we may have continuation frames or not the full payload.
addToBufs = false
// Make a copy of the buffer before adding it to the list
// of compressed fragments.
r.cbufs = append(r.cbufs, append([]byte(nil), b...))
// When we have the final frame and we have read the full payload,
// we can decompress it.
if r.ff && r.rem == 0 {
b, err = r.decompress()
if err != nil {
return bufs, err
}
r.fc = false
// Now we can add to `bufs`
addToBufs = true
}
}
// For non compressed frames, or when we have decompressed the
// whole message.
if addToBufs {
bufs = append(bufs, b)
if r.rem == 0 {
r.fs, r.fc, r.buf = true, false, nil
}
}
// If payload has been fully read, then indicate that next
// is the start of a frame.
if r.rem == 0 {
r.fs = true
}
}
}
return bufs, nil
}
func (r *wsReadInfo) Read(dst []byte) (int, error) {
if len(dst) == 0 {
return 0, nil
}
if len(r.cbufs) == 0 {
return 0, io.EOF
}
copied := 0
rem := len(dst)
for buf := r.cbufs[0]; buf != nil && rem > 0; {
n := len(buf[r.coff:])
if n > rem {
n = rem
}
copy(dst[copied:], buf[r.coff:r.coff+n])
copied += n
rem -= n
r.coff += n
buf = r.nextCBuf()
}
return copied, nil
}
func (r *wsReadInfo) nextCBuf() []byte {
// We still have remaining data in the first buffer
if r.coff != len(r.cbufs[0]) {
return r.cbufs[0]
}
// We read the full first buffer. Reset offset.
r.coff = 0
// We were at the last buffer, so we are done.
if len(r.cbufs) == 1 {
r.cbufs = nil
return nil
}
// Here we move to the next buffer.
r.cbufs = r.cbufs[1:]
return r.cbufs[0]
}
func (r *wsReadInfo) ReadByte() (byte, error) {
if len(r.cbufs) == 0 {
return 0, io.EOF
}
b := r.cbufs[0][r.coff]
r.coff++
r.nextCBuf()
return b, nil
}
func (r *wsReadInfo) decompress() ([]byte, error) {
r.coff = 0
// As per https://tools.ietf.org/html/rfc7692#section-7.2.2
// add 0x00, 0x00, 0xff, 0xff and then a final block so that flate reader
// does not report unexpected EOF.
r.cbufs = append(r.cbufs, compressLastBlock)
// Get a decompressor from the pool and bind it to this object (wsReadInfo)
// that provides Read() and ReadByte() APIs that will consume the compressed
// buffers (r.cbufs).
d, _ := decompressorPool.Get().(io.ReadCloser)
if d == nil {
d = flate.NewReader(r)
} else {
d.(flate.Resetter).Reset(r, nil)
}
// This will do the decompression.
b, err := ioutil.ReadAll(d)
decompressorPool.Put(d)
// Now reset the compressed buffers list.
r.cbufs = nil
return b, err
}
// Handles the PING, PONG and CLOSE websocket control frames.
//
// Client lock MUST NOT be held on entry.
@@ -1211,7 +1288,9 @@ func (c *client) wsCollapsePtoNB() (net.Buffers, int64) {
final = true
}
fh := make([]byte, wsMaxFrameHeaderSize)
n, key := wsFillFrameHeader(fh, mask, first, final, wsCompressedFrame, wsBinaryMessage, lp)
// Only the first frame should be marked as compressed, so pass
// `first` for the compressed boolean.
n, key := wsFillFrameHeader(fh, mask, first, final, first, wsBinaryMessage, lp)
if mask {
wsMaskBuf(key, p[:lp])
}

View File

@@ -434,7 +434,10 @@ func TestWSReadCompressedFrames(t *testing.T) {
// Stress the fact that we use a pool and want to make sure
// that if we get a decompressor from the pool, it is properly reset
// with the buffer to decompress.
for i := 0; i < 9; i++ {
// Since we unmask the read buffer, reset it now and fill it
// with 10 compressed frames.
rb = nil
for i := 0; i < 10; i++ {
rb = append(rb, wsmsg1...)
}
bufs, err = c.wsRead(ri, tr, rb)
@@ -444,6 +447,31 @@ func TestWSReadCompressedFrames(t *testing.T) {
if n := len(bufs); n != 10 {
t.Fatalf("Unexpected buffer returned: %v", n)
}
// Compress a message and send it in several frames.
buf := &bytes.Buffer{}
compressor, _ := flate.NewWriter(buf, 1)
compressor.Write(uncompressed)
compressor.Flush()
compressed := buf.Bytes()
// The last 4 bytes are dropped
compressed = compressed[:len(compressed)-4]
ncomp := 10
frag1 := testWSCreateClientMsg(wsBinaryMessage, 1, false, false, compressed[:ncomp])
frag1[0] |= wsRsv1Bit
frag2 := testWSCreateClientMsg(wsBinaryMessage, 2, true, false, compressed[ncomp:])
rb = append([]byte(nil), frag1...)
rb = append(rb, frag2...)
bufs, err = c.wsRead(ri, tr, rb)
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
if n := len(bufs); n != 1 {
t.Fatalf("Unexpected buffer returned: %v", n)
}
if !bytes.Equal(bufs[0], uncompressed) {
t.Fatalf("Unexpected content: %s", bufs[0])
}
}
func TestWSReadCompressedFrameCorrupted(t *testing.T) {
@@ -827,15 +855,20 @@ func TestWSReadErrors(t *testing.T) {
},
{
func() []byte {
frag1 := testWSCreateClientMsg(wsBinaryMessage, 1, false, true, []byte("frag1"))
frag2 := testWSCreateClientMsg(wsBinaryMessage, 2, false, true, []byte("frag2"))
frag2[0] |= wsRsv1Bit
all := append([]byte(nil), frag1...)
all = append(all, frag2...)
frame := testWSCreateClientMsg(wsBinaryMessage, 1, true, false, []byte("frame"))
frag := testWSCreateClientMsg(wsBinaryMessage, 2, false, false, []byte("continuation"))
all := append([]byte(nil), frame...)
all = append(all, frag...)
return all
},
"invalid continuation frame", 2,
},
{
func() []byte {
return testWSCreateClientMsg(wsBinaryMessage, 2, false, true, []byte("frame"))
},
"invalid continuation frame", 1,
},
{
func() []byte {
return testWSCreateClientMsg(99, 1, false, false, []byte("hello"))
@@ -2914,7 +2947,17 @@ func TestWSCompressionFrameSizeLimit(t *testing.T) {
}
}
// Check frame headers for the proper formatting.
if i%2 == 1 {
if i%2 == 0 {
// Only the first frame should have the compress bit set.
if b[0]&wsRsv1Bit != 0 {
if i > 0 {
t.Fatalf("Compressed bit should not be in continuation frame")
}
} else if i == 0 {
t.Fatalf("Compressed bit missing")
}
} else {
// Collect the payload
bb.Write(b)
}
}