diff --git a/server/websocket.go b/server/websocket.go index b699631d..ab4ce822 100644 --- a/server/websocket.go +++ b/server/websocket.go @@ -75,7 +75,6 @@ const ( wsFirstFrame = true wsContFrame = false wsFinalFrame = true - wsCompressedFrame = true wsUncompressedFrame = false wsSchemePrefix = "ws" @@ -92,6 +91,7 @@ const ( ) var decompressorPool sync.Pool +var compressLastBlock = []byte{0x00, 0x00, 0xff, 0xff, 0x01, 0x00, 0x00, 0xff, 0xff} // From https://tools.ietf.org/html/rfc6455#section-1.3 var wsGUID = []byte("258EAFA5-E914-47DA-95CA-C5AB0DC85B11") @@ -144,7 +144,8 @@ type wsReadInfo struct { mask bool // Incoming leafnode connections may not have masking. mkpos byte mkey [4]byte - buf []byte + cbufs [][]byte + coff int } func (r *wsReadInfo) init() { @@ -292,42 +293,118 @@ func (c *client) wsRead(r *wsReadInfo, ior io.Reader, buf []byte) ([][]byte, err b = buf[pos : pos+n] pos += n r.rem -= n - if r.fc { - r.buf = append(r.buf, b...) - b = r.buf + // If needed, unmask the buffer + if r.mask { + r.unmask(b) } - if !r.fc || r.rem == 0 { - if r.mask { - r.unmask(b) - } - if r.fc { - // As per https://tools.ietf.org/html/rfc7692#section-7.2.2 - // add 0x00, 0x00, 0xff, 0xff and then a final block so that flate reader - // does not report unexpected EOF. - b = append(b, 0x00, 0x00, 0xff, 0xff, 0x01, 0x00, 0x00, 0xff, 0xff) - br := bytes.NewBuffer(b) - d, _ := decompressorPool.Get().(io.ReadCloser) - if d == nil { - d = flate.NewReader(br) - } else { - d.(flate.Resetter).Reset(br, nil) - } - b, err = ioutil.ReadAll(d) - decompressorPool.Put(d) + addToBufs := true + // Handle compressed message + if r.fc { + // Assume that we may have continuation frames or not the full payload. + addToBufs = false + // Make a copy of the buffer before adding it to the list + // of compressed fragments. + r.cbufs = append(r.cbufs, append([]byte(nil), b...)) + // When we have the final frame and we have read the full payload, + // we can decompress it. + if r.ff && r.rem == 0 { + b, err = r.decompress() if err != nil { return bufs, err } + r.fc = false + // Now we can add to `bufs` + addToBufs = true } + } + // For non compressed frames, or when we have decompressed the + // whole message. + if addToBufs { bufs = append(bufs, b) - if r.rem == 0 { - r.fs, r.fc, r.buf = true, false, nil - } + } + // If payload has been fully read, then indicate that next + // is the start of a frame. + if r.rem == 0 { + r.fs = true } } } return bufs, nil } +func (r *wsReadInfo) Read(dst []byte) (int, error) { + if len(dst) == 0 { + return 0, nil + } + if len(r.cbufs) == 0 { + return 0, io.EOF + } + copied := 0 + rem := len(dst) + for buf := r.cbufs[0]; buf != nil && rem > 0; { + n := len(buf[r.coff:]) + if n > rem { + n = rem + } + copy(dst[copied:], buf[r.coff:r.coff+n]) + copied += n + rem -= n + r.coff += n + buf = r.nextCBuf() + } + return copied, nil +} + +func (r *wsReadInfo) nextCBuf() []byte { + // We still have remaining data in the first buffer + if r.coff != len(r.cbufs[0]) { + return r.cbufs[0] + } + // We read the full first buffer. Reset offset. + r.coff = 0 + // We were at the last buffer, so we are done. + if len(r.cbufs) == 1 { + r.cbufs = nil + return nil + } + // Here we move to the next buffer. + r.cbufs = r.cbufs[1:] + return r.cbufs[0] +} + +func (r *wsReadInfo) ReadByte() (byte, error) { + if len(r.cbufs) == 0 { + return 0, io.EOF + } + b := r.cbufs[0][r.coff] + r.coff++ + r.nextCBuf() + return b, nil +} + +func (r *wsReadInfo) decompress() ([]byte, error) { + r.coff = 0 + // As per https://tools.ietf.org/html/rfc7692#section-7.2.2 + // add 0x00, 0x00, 0xff, 0xff and then a final block so that flate reader + // does not report unexpected EOF. + r.cbufs = append(r.cbufs, compressLastBlock) + // Get a decompressor from the pool and bind it to this object (wsReadInfo) + // that provides Read() and ReadByte() APIs that will consume the compressed + // buffers (r.cbufs). + d, _ := decompressorPool.Get().(io.ReadCloser) + if d == nil { + d = flate.NewReader(r) + } else { + d.(flate.Resetter).Reset(r, nil) + } + // This will do the decompression. + b, err := ioutil.ReadAll(d) + decompressorPool.Put(d) + // Now reset the compressed buffers list. + r.cbufs = nil + return b, err +} + // Handles the PING, PONG and CLOSE websocket control frames. // // Client lock MUST NOT be held on entry. @@ -1211,7 +1288,9 @@ func (c *client) wsCollapsePtoNB() (net.Buffers, int64) { final = true } fh := make([]byte, wsMaxFrameHeaderSize) - n, key := wsFillFrameHeader(fh, mask, first, final, wsCompressedFrame, wsBinaryMessage, lp) + // Only the first frame should be marked as compressed, so pass + // `first` for the compressed boolean. + n, key := wsFillFrameHeader(fh, mask, first, final, first, wsBinaryMessage, lp) if mask { wsMaskBuf(key, p[:lp]) } diff --git a/server/websocket_test.go b/server/websocket_test.go index d3a5de57..8c52b560 100644 --- a/server/websocket_test.go +++ b/server/websocket_test.go @@ -434,7 +434,10 @@ func TestWSReadCompressedFrames(t *testing.T) { // Stress the fact that we use a pool and want to make sure // that if we get a decompressor from the pool, it is properly reset // with the buffer to decompress. - for i := 0; i < 9; i++ { + // Since we unmask the read buffer, reset it now and fill it + // with 10 compressed frames. + rb = nil + for i := 0; i < 10; i++ { rb = append(rb, wsmsg1...) } bufs, err = c.wsRead(ri, tr, rb) @@ -444,6 +447,31 @@ func TestWSReadCompressedFrames(t *testing.T) { if n := len(bufs); n != 10 { t.Fatalf("Unexpected buffer returned: %v", n) } + + // Compress a message and send it in several frames. + buf := &bytes.Buffer{} + compressor, _ := flate.NewWriter(buf, 1) + compressor.Write(uncompressed) + compressor.Flush() + compressed := buf.Bytes() + // The last 4 bytes are dropped + compressed = compressed[:len(compressed)-4] + ncomp := 10 + frag1 := testWSCreateClientMsg(wsBinaryMessage, 1, false, false, compressed[:ncomp]) + frag1[0] |= wsRsv1Bit + frag2 := testWSCreateClientMsg(wsBinaryMessage, 2, true, false, compressed[ncomp:]) + rb = append([]byte(nil), frag1...) + rb = append(rb, frag2...) + bufs, err = c.wsRead(ri, tr, rb) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + if n := len(bufs); n != 1 { + t.Fatalf("Unexpected buffer returned: %v", n) + } + if !bytes.Equal(bufs[0], uncompressed) { + t.Fatalf("Unexpected content: %s", bufs[0]) + } } func TestWSReadCompressedFrameCorrupted(t *testing.T) { @@ -827,15 +855,20 @@ func TestWSReadErrors(t *testing.T) { }, { func() []byte { - frag1 := testWSCreateClientMsg(wsBinaryMessage, 1, false, true, []byte("frag1")) - frag2 := testWSCreateClientMsg(wsBinaryMessage, 2, false, true, []byte("frag2")) - frag2[0] |= wsRsv1Bit - all := append([]byte(nil), frag1...) - all = append(all, frag2...) + frame := testWSCreateClientMsg(wsBinaryMessage, 1, true, false, []byte("frame")) + frag := testWSCreateClientMsg(wsBinaryMessage, 2, false, false, []byte("continuation")) + all := append([]byte(nil), frame...) + all = append(all, frag...) return all }, "invalid continuation frame", 2, }, + { + func() []byte { + return testWSCreateClientMsg(wsBinaryMessage, 2, false, true, []byte("frame")) + }, + "invalid continuation frame", 1, + }, { func() []byte { return testWSCreateClientMsg(99, 1, false, false, []byte("hello")) @@ -2914,7 +2947,17 @@ func TestWSCompressionFrameSizeLimit(t *testing.T) { } } // Check frame headers for the proper formatting. - if i%2 == 1 { + if i%2 == 0 { + // Only the first frame should have the compress bit set. + if b[0]&wsRsv1Bit != 0 { + if i > 0 { + t.Fatalf("Compressed bit should not be in continuation frame") + } + } else if i == 0 { + t.Fatalf("Compressed bit missing") + } + } else { + // Collect the payload bb.Write(b) } }