From 2ece00b08fbe12f1aba94c6a27d3e060887ae9e4 Mon Sep 17 00:00:00 2001 From: Neil Twigg Date: Fri, 21 Apr 2023 15:15:09 +0100 Subject: [PATCH] Buffer re-use in WebSocket code, fix race conditions Signed-off-by: Neil Twigg --- server/client.go | 38 +++++++++++++++++++++----------------- server/websocket.go | 44 +++++++++++++++++++++++++++++++------------- 2 files changed, 52 insertions(+), 30 deletions(-) diff --git a/server/client.go b/server/client.go index 71892c57..0089ae74 100644 --- a/server/client.go +++ b/server/client.go @@ -327,6 +327,22 @@ var nbPoolLarge = &sync.Pool{ }, } +func nbPoolGet(sz int) []byte { + var new []byte + switch { + case sz <= nbPoolSizeSmall: + ptr := nbPoolSmall.Get().(*[nbPoolSizeSmall]byte) + new = ptr[:0] + case sz <= nbPoolSizeMedium: + ptr := nbPoolMedium.Get().(*[nbPoolSizeMedium]byte) + new = ptr[:0] + default: + ptr := nbPoolLarge.Get().(*[nbPoolSizeLarge]byte) + new = ptr[:0] + } + return new +} + func nbPoolPut(b []byte) { switch cap(b) { case nbPoolSizeSmall: @@ -1520,7 +1536,7 @@ func (c *client) flushOutbound() bool { // Check for partial writes // TODO(dlc) - zero write with no error will cause lost message and the writeloop to spin. if n != attempted && n > 0 { - c.handlePartialWrite(c.out.wnb) + c.handlePartialWrite(c.out.nb) } // Check that if there is still data to send and writeLoop is in wait, @@ -2029,22 +2045,10 @@ func (c *client) queueOutbound(data []byte) { // in fixed size chunks. This ensures we don't go over the capacity of any // of the buffers and end up reallocating. for len(toBuffer) > 0 { - var new []byte - switch { - case len(toBuffer) <= nbPoolSizeSmall: - new = nbPoolSmall.Get().(*[nbPoolSizeSmall]byte)[:0] - case len(toBuffer) <= nbPoolSizeMedium: - new = nbPoolMedium.Get().(*[nbPoolSizeMedium]byte)[:0] - default: - new = nbPoolLarge.Get().(*[nbPoolSizeLarge]byte)[:0] - } - l := len(toBuffer) - if c := cap(new); l > c { - l = c - } - new = append(new[:0], toBuffer[:l]...) - c.out.nb = append(c.out.nb, new) - toBuffer = toBuffer[l:] + new := nbPoolGet(len(toBuffer)) + n := copy(new[:cap(new)], toBuffer) + c.out.nb = append(c.out.nb, new[:n]) + toBuffer = toBuffer[n:] } // Check for slow consumer via pending bytes limit. diff --git a/server/websocket.go b/server/websocket.go index fd06c986..5d19ee6b 100644 --- a/server/websocket.go +++ b/server/websocket.go @@ -452,7 +452,9 @@ func (c *client) wsHandleControlFrame(r *wsReadInfo, frameType wsOpCode, nc io.R } } } - c.wsEnqueueControlMessage(wsCloseMessage, wsCreateCloseMessage(status, body)) + clm := wsCreateCloseMessage(status, body) + c.wsEnqueueControlMessage(wsCloseMessage, clm) + nbPoolPut(clm) // wsEnqueueControlMessage has taken a copy. // Return io.EOF so that readLoop will close the connection as ClientClosed // after processing pending buffers. return pos, io.EOF @@ -502,7 +504,7 @@ func wsIsControlFrame(frameType wsOpCode) bool { // Create the frame header. // Encodes the frame type and optional compression flag, and the size of the payload. func wsCreateFrameHeader(useMasking, compressed bool, frameType wsOpCode, l int) ([]byte, []byte) { - fh := make([]byte, wsMaxFrameHeaderSize) + fh := nbPoolGet(wsMaxFrameHeaderSize)[:wsMaxFrameHeaderSize] n, key := wsFillFrameHeader(fh, useMasking, wsFirstFrame, wsFinalFrame, compressed, frameType, l) return fh[:n], key } @@ -596,11 +598,13 @@ func (c *client) wsEnqueueControlMessageLocked(controlMsg wsOpCode, payload []by if useMasking { sz += 4 } - cm := make([]byte, sz+len(payload)) + cm := nbPoolGet(sz + len(payload)) + cm = cm[:cap(cm)] n, key := wsFillFrameHeader(cm, useMasking, wsFirstFrame, wsFinalFrame, wsUncompressedFrame, controlMsg, len(payload)) + cm = cm[:n] // Note that payload is optional. if len(payload) > 0 { - copy(cm[n:], payload) + cm = append(cm, payload...) if useMasking { wsMaskBuf(key, cm[n:]) } @@ -646,6 +650,7 @@ func (c *client) wsEnqueueCloseMessage(reason ClosedState) { } body := wsCreateCloseMessage(status, reason.String()) c.wsEnqueueControlMessageLocked(wsCloseMessage, body) + nbPoolPut(body) // wsEnqueueControlMessageLocked has taken a copy. } // Create and then enqueue a close message with a protocol error and the @@ -655,6 +660,7 @@ func (c *client) wsEnqueueCloseMessage(reason ClosedState) { func (c *client) wsHandleProtocolError(message string) error { buf := wsCreateCloseMessage(wsCloseStatusProtocolError, message) c.wsEnqueueControlMessage(wsCloseMessage, buf) + nbPoolPut(buf) // wsEnqueueControlMessage has taken a copy. return fmt.Errorf(message) } @@ -671,7 +677,7 @@ func wsCreateCloseMessage(status int, body string) []byte { body = body[:wsMaxControlPayloadSize-5] body += "..." } - buf := make([]byte, 2+len(body)) + buf := nbPoolGet(2 + len(body))[:2+len(body)] // We need to have a 2 byte unsigned int that represents the error status code // https://tools.ietf.org/html/rfc6455#section-5.5.1 binary.BigEndian.PutUint16(buf[:2], uint16(status)) @@ -1298,6 +1304,7 @@ func (c *client) wsCollapsePtoNB() (net.Buffers, int64) { var csz int for _, b := range nb { cp.Write(b) + nbPoolPut(b) // No longer needed as contents written to compressor. } if err := cp.Flush(); err != nil { c.Errorf("Error during compression: %v", err) @@ -1314,24 +1321,33 @@ func (c *client) wsCollapsePtoNB() (net.Buffers, int64) { } else { final = true } - fh := make([]byte, wsMaxFrameHeaderSize) // Only the first frame should be marked as compressed, so pass // `first` for the compressed boolean. + fh := nbPoolGet(wsMaxFrameHeaderSize)[:wsMaxFrameHeaderSize] n, key := wsFillFrameHeader(fh, mask, first, final, first, wsBinaryMessage, lp) if mask { wsMaskBuf(key, p[:lp]) } - bufs = append(bufs, fh[:n], p[:lp]) + new := nbPoolGet(wsFrameSizeForBrowsers) + lp = copy(new[:wsFrameSizeForBrowsers], p[:lp]) + bufs = append(bufs, fh[:n], new[:lp]) csz += n + lp p = p[lp:] } } else { - h, key := wsCreateFrameHeader(mask, true, wsBinaryMessage, len(p)) + ol := len(p) + h, key := wsCreateFrameHeader(mask, true, wsBinaryMessage, ol) if mask { wsMaskBuf(key, p) } - bufs = append(bufs, h, p) - csz = len(h) + len(p) + bufs = append(bufs, h) + for len(p) > 0 { + new := nbPoolGet(len(p)) + n := copy(new[:cap(new)], p) + bufs = append(bufs, new[:n]) + p = p[n:] + } + csz = len(h) + ol } // Add to pb the compressed data size (including headers), but // remove the original uncompressed data size that was added @@ -1343,7 +1359,7 @@ func (c *client) wsCollapsePtoNB() (net.Buffers, int64) { if mfs > 0 { // We are limiting the frame size. startFrame := func() int { - bufs = append(bufs, make([]byte, wsMaxFrameHeaderSize)) + bufs = append(bufs, nbPoolGet(wsMaxFrameHeaderSize)[:wsMaxFrameHeaderSize]) return len(bufs) - 1 } endFrame := func(idx, size int) { @@ -1376,8 +1392,10 @@ func (c *client) wsCollapsePtoNB() (net.Buffers, int64) { if endStart { fhIdx = startFrame() } - bufs = append(bufs, b[:total]) - b = b[total:] + new := nbPoolGet(total) + n := copy(new[:cap(new)], b[:total]) + bufs = append(bufs, new[:n]) + b = b[n:] } } if total > 0 {