mirror of
https://github.com/gogrlx/nats-server.git
synced 2026-04-17 11:24:44 -07:00
Merge branch 'main' into dev
This commit is contained in:
@@ -303,7 +303,8 @@ type outbound struct {
|
||||
stc chan struct{} // Stall chan we create to slow down producers on overrun, e.g. fan-in.
|
||||
}
|
||||
|
||||
const nbPoolSizeSmall = 4096 // Underlying array size of small buffer
|
||||
const nbPoolSizeSmall = 512 // Underlying array size of small buffer
|
||||
const nbPoolSizeMedium = 4096 // Underlying array size of medium buffer
|
||||
const nbPoolSizeLarge = 65536 // Underlying array size of large buffer
|
||||
|
||||
var nbPoolSmall = &sync.Pool{
|
||||
@@ -313,6 +314,13 @@ var nbPoolSmall = &sync.Pool{
|
||||
},
|
||||
}
|
||||
|
||||
var nbPoolMedium = &sync.Pool{
|
||||
New: func() any {
|
||||
b := [nbPoolSizeMedium]byte{}
|
||||
return &b
|
||||
},
|
||||
}
|
||||
|
||||
var nbPoolLarge = &sync.Pool{
|
||||
New: func() any {
|
||||
b := [nbPoolSizeLarge]byte{}
|
||||
@@ -320,11 +328,30 @@ var nbPoolLarge = &sync.Pool{
|
||||
},
|
||||
}
|
||||
|
||||
func nbPoolGet(sz int) []byte {
|
||||
var new []byte
|
||||
switch {
|
||||
case sz <= nbPoolSizeSmall:
|
||||
ptr := nbPoolSmall.Get().(*[nbPoolSizeSmall]byte)
|
||||
new = ptr[:0]
|
||||
case sz <= nbPoolSizeMedium:
|
||||
ptr := nbPoolMedium.Get().(*[nbPoolSizeMedium]byte)
|
||||
new = ptr[:0]
|
||||
default:
|
||||
ptr := nbPoolLarge.Get().(*[nbPoolSizeLarge]byte)
|
||||
new = ptr[:0]
|
||||
}
|
||||
return new
|
||||
}
|
||||
|
||||
func nbPoolPut(b []byte) {
|
||||
switch cap(b) {
|
||||
case nbPoolSizeSmall:
|
||||
b := (*[nbPoolSizeSmall]byte)(b[0:nbPoolSizeSmall])
|
||||
nbPoolSmall.Put(b)
|
||||
case nbPoolSizeMedium:
|
||||
b := (*[nbPoolSizeMedium]byte)(b[0:nbPoolSizeMedium])
|
||||
nbPoolMedium.Put(b)
|
||||
case nbPoolSizeLarge:
|
||||
b := (*[nbPoolSizeLarge]byte)(b[0:nbPoolSizeLarge])
|
||||
nbPoolLarge.Put(b)
|
||||
@@ -1548,7 +1575,7 @@ func (c *client) flushOutbound() bool {
|
||||
if err != nil && err != io.ErrShortWrite {
|
||||
// Handle timeout error (slow consumer) differently
|
||||
if ne, ok := err.(net.Error); ok && ne.Timeout() {
|
||||
if closed := c.handleWriteTimeout(n, attempted, len(c.out.nb)); closed {
|
||||
if closed := c.handleWriteTimeout(n, attempted, len(orig)); closed {
|
||||
return true
|
||||
}
|
||||
} else {
|
||||
@@ -2081,43 +2108,14 @@ func (c *client) queueOutbound(data []byte) {
|
||||
// without affecting the original "data" slice.
|
||||
toBuffer := data
|
||||
|
||||
// All of the queued []byte have a fixed capacity, so if there's a []byte
|
||||
// at the tail of the buffer list that isn't full yet, we should top that
|
||||
// up first. This helps to ensure we aren't pulling more []bytes from the
|
||||
// pool than we need to.
|
||||
if len(c.out.nb) > 0 {
|
||||
last := &c.out.nb[len(c.out.nb)-1]
|
||||
if free := cap(*last) - len(*last); free > 0 {
|
||||
if l := len(toBuffer); l < free {
|
||||
free = l
|
||||
}
|
||||
*last = append(*last, toBuffer[:free]...)
|
||||
toBuffer = toBuffer[free:]
|
||||
}
|
||||
}
|
||||
|
||||
// Now we can push the rest of the data into new []bytes from the pool
|
||||
// in fixed size chunks. This ensures we don't go over the capacity of any
|
||||
// of the buffers and end up reallocating.
|
||||
for len(toBuffer) > 0 {
|
||||
var new []byte
|
||||
if len(c.out.nb) == 0 && len(toBuffer) <= nbPoolSizeSmall {
|
||||
// If the buffer is empty, try to allocate a small buffer if the
|
||||
// message will fit in it. This will help for cases like pings.
|
||||
new = nbPoolSmall.Get().(*[nbPoolSizeSmall]byte)[:0]
|
||||
} else {
|
||||
// If "nb" isn't empty, default to large buffers in all cases as
|
||||
// this means we are always coalescing future messages into
|
||||
// larger buffers. Reduces the number of buffers into writev.
|
||||
new = nbPoolLarge.Get().(*[nbPoolSizeLarge]byte)[:0]
|
||||
}
|
||||
l := len(toBuffer)
|
||||
if c := cap(new); l > c {
|
||||
l = c
|
||||
}
|
||||
new = append(new, toBuffer[:l]...)
|
||||
c.out.nb = append(c.out.nb, new)
|
||||
toBuffer = toBuffer[l:]
|
||||
new := nbPoolGet(len(toBuffer))
|
||||
n := copy(new[:cap(new)], toBuffer)
|
||||
c.out.nb = append(c.out.nb, new[:n])
|
||||
toBuffer = toBuffer[n:]
|
||||
}
|
||||
|
||||
// Check for slow consumer via pending bytes limit.
|
||||
|
||||
@@ -1483,64 +1483,69 @@ func TestWildcardCharsInLiteralSubjectWorks(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// This test ensures that coalescing into the fixed-size output
|
||||
// queues works as expected. When bytes are queued up, they should
|
||||
// not overflow a buffer until the capacity is exceeded, at which
|
||||
// point a new buffer should be added.
|
||||
func TestClientOutboundQueueCoalesce(t *testing.T) {
|
||||
// This test ensures that outbound queues don't cause a run on
|
||||
// memory when sending something to lots of clients.
|
||||
func TestClientOutboundQueueMemory(t *testing.T) {
|
||||
opts := DefaultOptions()
|
||||
s := RunServer(opts)
|
||||
defer s.Shutdown()
|
||||
|
||||
nc, err := nats.Connect(fmt.Sprintf("nats://%s:%d", opts.Host, opts.Port))
|
||||
var before runtime.MemStats
|
||||
var after runtime.MemStats
|
||||
|
||||
var err error
|
||||
clients := make([]*nats.Conn, 50000)
|
||||
wait := &sync.WaitGroup{}
|
||||
wait.Add(len(clients))
|
||||
|
||||
for i := 0; i < len(clients); i++ {
|
||||
clients[i], err = nats.Connect(fmt.Sprintf("nats://%s:%d", opts.Host, opts.Port), nats.InProcessServer(s))
|
||||
if err != nil {
|
||||
t.Fatalf("Error on connect: %v", err)
|
||||
}
|
||||
defer clients[i].Close()
|
||||
|
||||
clients[i].Subscribe("test", func(m *nats.Msg) {
|
||||
wait.Done()
|
||||
})
|
||||
}
|
||||
|
||||
runtime.GC()
|
||||
runtime.ReadMemStats(&before)
|
||||
|
||||
nc, err := nats.Connect(fmt.Sprintf("nats://%s:%d", opts.Host, opts.Port), nats.InProcessServer(s))
|
||||
if err != nil {
|
||||
t.Fatalf("Error on connect: %v", err)
|
||||
}
|
||||
defer nc.Close()
|
||||
|
||||
clients := s.GlobalAccount().getClients()
|
||||
if len(clients) != 1 {
|
||||
t.Fatal("Expecting a client to exist")
|
||||
}
|
||||
client := clients[0]
|
||||
client.mu.Lock()
|
||||
defer client.mu.Unlock()
|
||||
|
||||
// First up, queue something small into the queue.
|
||||
client.queueOutbound([]byte{1, 2, 3, 4, 5})
|
||||
|
||||
if len(client.out.nb) != 1 {
|
||||
t.Fatal("Expecting a single queued buffer")
|
||||
}
|
||||
if l := len(client.out.nb[0]); l != 5 {
|
||||
t.Fatalf("Expecting only 5 bytes in the first queued buffer, found %d instead", l)
|
||||
var m [48000]byte
|
||||
if err = nc.Publish("test", m[:]); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Then queue up a few more bytes, but not enough
|
||||
// to overflow into the next buffer.
|
||||
client.queueOutbound([]byte{6, 7, 8, 9, 10})
|
||||
wait.Wait()
|
||||
|
||||
if len(client.out.nb) != 1 {
|
||||
t.Fatal("Expecting a single queued buffer")
|
||||
}
|
||||
if l := len(client.out.nb[0]); l != 10 {
|
||||
t.Fatalf("Expecting 10 bytes in the first queued buffer, found %d instead", l)
|
||||
}
|
||||
runtime.GC()
|
||||
runtime.ReadMemStats(&after)
|
||||
|
||||
// Finally, queue up something that is guaranteed
|
||||
// to overflow.
|
||||
b := nbPoolSmall.Get().(*[nbPoolSizeSmall]byte)[:]
|
||||
b = b[:cap(b)]
|
||||
client.queueOutbound(b)
|
||||
if len(client.out.nb) != 2 {
|
||||
t.Fatal("Expecting buffer to have overflowed")
|
||||
}
|
||||
if l := len(client.out.nb[0]); l != cap(b) {
|
||||
t.Fatalf("Expecting %d bytes in the first queued buffer, found %d instead", cap(b), l)
|
||||
}
|
||||
if l := len(client.out.nb[1]); l != 10 {
|
||||
t.Fatalf("Expecting 10 bytes in the second queued buffer, found %d instead", l)
|
||||
}
|
||||
hb, ha := float64(before.HeapAlloc), float64(after.HeapAlloc)
|
||||
ms := float64(len(m))
|
||||
diff := float64(ha) - float64(hb)
|
||||
inc := (diff / float64(hb)) * 100
|
||||
|
||||
fmt.Printf("Message size: %.1fKB\n", ms/1024)
|
||||
fmt.Printf("Subscribed clients: %d\n", len(clients))
|
||||
fmt.Printf("Heap allocs before: %.1fMB\n", hb/1024/1024)
|
||||
fmt.Printf("Heap allocs after: %.1fMB\n", ha/1024/1024)
|
||||
fmt.Printf("Heap allocs delta: %.1f%%\n", inc)
|
||||
|
||||
// TODO: What threshold makes sense here for a failure?
|
||||
/*
|
||||
if inc > 10 {
|
||||
t.Fatalf("memory increase was %.1f%% (should be <= 10%%)", inc)
|
||||
}
|
||||
*/
|
||||
}
|
||||
|
||||
func TestClientTraceRace(t *testing.T) {
|
||||
|
||||
@@ -3420,6 +3420,69 @@ func (o *consumer) hbTimer() (time.Duration, *time.Timer) {
|
||||
return o.cfg.Heartbeat, time.NewTimer(o.cfg.Heartbeat)
|
||||
}
|
||||
|
||||
// Check here for conditions when our ack floor may have drifted below the streams first sequence.
|
||||
// In general this is accounted for in normal operations, but if the consumer misses the signal from
|
||||
// the stream it will not clear the message and move the ack state.
|
||||
// Should only be called from consumer leader.
|
||||
func (o *consumer) checkAckFloor() {
|
||||
o.mu.RLock()
|
||||
mset, closed, asflr := o.mset, o.closed, o.asflr
|
||||
o.mu.RUnlock()
|
||||
|
||||
if closed || mset == nil {
|
||||
return
|
||||
}
|
||||
|
||||
var ss StreamState
|
||||
mset.store.FastState(&ss)
|
||||
|
||||
// If our floor is equal or greater that is normal and nothing for us to do.
|
||||
if ss.FirstSeq == 0 || asflr >= ss.FirstSeq-1 {
|
||||
return
|
||||
}
|
||||
|
||||
// Process all messages that no longer exist.
|
||||
for seq := asflr + 1; seq < ss.FirstSeq; seq++ {
|
||||
// Check if this message was pending.
|
||||
o.mu.RLock()
|
||||
p, isPending := o.pending[seq]
|
||||
var rdc uint64 = 1
|
||||
if o.rdc != nil {
|
||||
rdc = o.rdc[seq]
|
||||
}
|
||||
o.mu.RUnlock()
|
||||
// If it was pending for us, get rid of it.
|
||||
if isPending {
|
||||
o.processTerm(seq, p.Sequence, rdc)
|
||||
}
|
||||
}
|
||||
|
||||
// Do one final check here.
|
||||
o.mu.Lock()
|
||||
defer o.mu.Unlock()
|
||||
|
||||
// If we are here, and this should be rare, we still are off with our ack floor.
|
||||
// We will set it explicitly to 1 behind our current lowest in pending, or if
|
||||
// pending is empty, to our current delivered -1.
|
||||
if o.asflr < ss.FirstSeq-1 {
|
||||
var psseq, pdseq uint64
|
||||
for seq, p := range o.pending {
|
||||
if psseq == 0 || seq < psseq {
|
||||
psseq, pdseq = seq, p.Sequence
|
||||
}
|
||||
}
|
||||
// If we still have none, set to current delivered -1.
|
||||
if psseq == 0 {
|
||||
psseq, pdseq = o.sseq-1, o.dseq-1
|
||||
// If still not adjusted.
|
||||
if psseq < ss.FirstSeq-1 {
|
||||
psseq, pdseq = ss.FirstSeq-1, ss.FirstSeq-1
|
||||
}
|
||||
}
|
||||
o.asflr, o.adflr = psseq, pdseq
|
||||
}
|
||||
}
|
||||
|
||||
func (o *consumer) processInboundAcks(qch chan struct{}) {
|
||||
// Grab the server lock to watch for server quit.
|
||||
o.mu.RLock()
|
||||
@@ -3427,6 +3490,12 @@ func (o *consumer) processInboundAcks(qch chan struct{}) {
|
||||
hasInactiveThresh := o.cfg.InactiveThreshold > 0
|
||||
o.mu.RUnlock()
|
||||
|
||||
// We will check this on entry and periodically.
|
||||
o.checkAckFloor()
|
||||
|
||||
// How often we will check for ack floor drift.
|
||||
var ackFloorCheck = 30 * time.Second
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-o.ackMsgs.ch:
|
||||
@@ -3440,6 +3509,8 @@ func (o *consumer) processInboundAcks(qch chan struct{}) {
|
||||
if hasInactiveThresh {
|
||||
o.suppressDeletion()
|
||||
}
|
||||
case <-time.After(ackFloorCheck):
|
||||
o.checkAckFloor()
|
||||
case <-qch:
|
||||
return
|
||||
case <-s.quitCh:
|
||||
|
||||
@@ -3511,3 +3511,145 @@ func TestJetStreamNoLeadersDuringLameDuck(t *testing.T) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// If a consumer has not been registered (possible in heavily loaded systems with lots of assets)
|
||||
// it could miss the signal of a message going away. If that message was pending and expires the
|
||||
// ack floor could fall below the stream first sequence. This test will force that condition and
|
||||
// make sure the system resolves itself.
|
||||
func TestJetStreamConsumerAckFloorDrift(t *testing.T) {
|
||||
c := createJetStreamClusterExplicit(t, "R3S", 3)
|
||||
defer c.shutdown()
|
||||
|
||||
nc, js := jsClientConnect(t, c.randomServer())
|
||||
defer nc.Close()
|
||||
|
||||
_, err := js.AddStream(&nats.StreamConfig{
|
||||
Name: "TEST",
|
||||
Subjects: []string{"*"},
|
||||
Replicas: 3,
|
||||
MaxAge: 200 * time.Millisecond,
|
||||
MaxMsgs: 10,
|
||||
})
|
||||
require_NoError(t, err)
|
||||
|
||||
sub, err := js.PullSubscribe("foo", "C")
|
||||
require_NoError(t, err)
|
||||
|
||||
for i := 0; i < 10; i++ {
|
||||
sendStreamMsg(t, nc, "foo", "HELLO")
|
||||
}
|
||||
|
||||
// No-op but will surface as delivered.
|
||||
_, err = sub.Fetch(10)
|
||||
require_NoError(t, err)
|
||||
|
||||
// We will grab the state with delivered being 10 and ackfloor being 0 directly.
|
||||
cl := c.consumerLeader(globalAccountName, "TEST", "C")
|
||||
require_NotNil(t, cl)
|
||||
|
||||
mset, err := cl.GlobalAccount().lookupStream("TEST")
|
||||
require_NoError(t, err)
|
||||
o := mset.lookupConsumer("C")
|
||||
require_NotNil(t, o)
|
||||
o.mu.RLock()
|
||||
state, err := o.store.State()
|
||||
o.mu.RUnlock()
|
||||
require_NoError(t, err)
|
||||
require_NotNil(t, state)
|
||||
|
||||
// Now let messages expire.
|
||||
checkFor(t, time.Second, 100*time.Millisecond, func() error {
|
||||
si, err := js.StreamInfo("TEST")
|
||||
require_NoError(t, err)
|
||||
if si.State.Msgs == 0 {
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("stream still has msgs")
|
||||
})
|
||||
|
||||
// Set state to ackfloor of 5 and no pending.
|
||||
state.AckFloor.Consumer = 5
|
||||
state.AckFloor.Stream = 5
|
||||
state.Pending = nil
|
||||
|
||||
// Now put back the state underneath of the consumers.
|
||||
for _, s := range c.servers {
|
||||
mset, err := s.GlobalAccount().lookupStream("TEST")
|
||||
require_NoError(t, err)
|
||||
o := mset.lookupConsumer("C")
|
||||
require_NotNil(t, o)
|
||||
o.mu.Lock()
|
||||
err = o.setStoreState(state)
|
||||
cfs := o.store.(*consumerFileStore)
|
||||
o.mu.Unlock()
|
||||
require_NoError(t, err)
|
||||
// The lower layer will ignore, so set more directly.
|
||||
cfs.mu.Lock()
|
||||
cfs.state = *state
|
||||
cfs.mu.Unlock()
|
||||
// Also snapshot to remove any raft entries that could affect it.
|
||||
snap, err := o.store.EncodedState()
|
||||
require_NoError(t, err)
|
||||
require_NoError(t, o.raftNode().InstallSnapshot(snap))
|
||||
}
|
||||
|
||||
cl.JetStreamStepdownConsumer(globalAccountName, "TEST", "C")
|
||||
c.waitOnConsumerLeader(globalAccountName, "TEST", "C")
|
||||
|
||||
checkFor(t, 5*time.Second, 100*time.Millisecond, func() error {
|
||||
ci, err := js.ConsumerInfo("TEST", "C")
|
||||
require_NoError(t, err)
|
||||
// Make sure we catch this and adjust.
|
||||
if ci.AckFloor.Stream == 10 && ci.AckFloor.Consumer == 10 {
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("AckFloor not correct, expected 10, got %+v", ci.AckFloor)
|
||||
})
|
||||
}
|
||||
|
||||
func TestJetStreamClusterInterestStreamFilteredConsumersWithNoInterest(t *testing.T) {
|
||||
c := createJetStreamClusterExplicit(t, "R5S", 5)
|
||||
defer c.shutdown()
|
||||
|
||||
nc, js := jsClientConnect(t, c.randomServer())
|
||||
defer nc.Close()
|
||||
|
||||
_, err := js.AddStream(&nats.StreamConfig{
|
||||
Name: "TEST",
|
||||
Subjects: []string{"*"},
|
||||
Retention: nats.InterestPolicy,
|
||||
Replicas: 3,
|
||||
})
|
||||
require_NoError(t, err)
|
||||
|
||||
// Create three subscribers.
|
||||
ackCb := func(m *nats.Msg) { m.Ack() }
|
||||
|
||||
_, err = js.Subscribe("foo", ackCb, nats.BindStream("TEST"), nats.ManualAck())
|
||||
require_NoError(t, err)
|
||||
|
||||
_, err = js.Subscribe("bar", ackCb, nats.BindStream("TEST"), nats.ManualAck())
|
||||
require_NoError(t, err)
|
||||
|
||||
_, err = js.Subscribe("baz", ackCb, nats.BindStream("TEST"), nats.ManualAck())
|
||||
require_NoError(t, err)
|
||||
|
||||
// Now send 100 messages, randomly picking foo or bar, but never baz.
|
||||
for i := 0; i < 100; i++ {
|
||||
if rand.Intn(2) > 0 {
|
||||
sendStreamMsg(t, nc, "foo", "HELLO")
|
||||
} else {
|
||||
sendStreamMsg(t, nc, "bar", "WORLD")
|
||||
}
|
||||
}
|
||||
|
||||
// Messages are expected to go to 0.
|
||||
checkFor(t, time.Second, 100*time.Millisecond, func() error {
|
||||
si, err := js.StreamInfo("TEST")
|
||||
require_NoError(t, err)
|
||||
if si.State.Msgs == 0 {
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("stream still has msgs")
|
||||
})
|
||||
}
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// Copyright 2020 The NATS Authors
|
||||
// Copyright 2020-2023 The NATS Authors
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
@@ -15,7 +15,6 @@ package server
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"compress/flate"
|
||||
"crypto/rand"
|
||||
"crypto/sha1"
|
||||
"crypto/tls"
|
||||
@@ -34,6 +33,8 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
"unicode/utf8"
|
||||
|
||||
"github.com/klauspost/compress/flate"
|
||||
)
|
||||
|
||||
type wsOpCode int
|
||||
@@ -452,7 +453,9 @@ func (c *client) wsHandleControlFrame(r *wsReadInfo, frameType wsOpCode, nc io.R
|
||||
}
|
||||
}
|
||||
}
|
||||
c.wsEnqueueControlMessage(wsCloseMessage, wsCreateCloseMessage(status, body))
|
||||
clm := wsCreateCloseMessage(status, body)
|
||||
c.wsEnqueueControlMessage(wsCloseMessage, clm)
|
||||
nbPoolPut(clm) // wsEnqueueControlMessage has taken a copy.
|
||||
// Return io.EOF so that readLoop will close the connection as ClientClosed
|
||||
// after processing pending buffers.
|
||||
return pos, io.EOF
|
||||
@@ -502,7 +505,7 @@ func wsIsControlFrame(frameType wsOpCode) bool {
|
||||
// Create the frame header.
|
||||
// Encodes the frame type and optional compression flag, and the size of the payload.
|
||||
func wsCreateFrameHeader(useMasking, compressed bool, frameType wsOpCode, l int) ([]byte, []byte) {
|
||||
fh := make([]byte, wsMaxFrameHeaderSize)
|
||||
fh := nbPoolGet(wsMaxFrameHeaderSize)[:wsMaxFrameHeaderSize]
|
||||
n, key := wsFillFrameHeader(fh, useMasking, wsFirstFrame, wsFinalFrame, compressed, frameType, l)
|
||||
return fh[:n], key
|
||||
}
|
||||
@@ -596,11 +599,13 @@ func (c *client) wsEnqueueControlMessageLocked(controlMsg wsOpCode, payload []by
|
||||
if useMasking {
|
||||
sz += 4
|
||||
}
|
||||
cm := make([]byte, sz+len(payload))
|
||||
cm := nbPoolGet(sz + len(payload))
|
||||
cm = cm[:cap(cm)]
|
||||
n, key := wsFillFrameHeader(cm, useMasking, wsFirstFrame, wsFinalFrame, wsUncompressedFrame, controlMsg, len(payload))
|
||||
cm = cm[:n]
|
||||
// Note that payload is optional.
|
||||
if len(payload) > 0 {
|
||||
copy(cm[n:], payload)
|
||||
cm = append(cm, payload...)
|
||||
if useMasking {
|
||||
wsMaskBuf(key, cm[n:])
|
||||
}
|
||||
@@ -646,6 +651,7 @@ func (c *client) wsEnqueueCloseMessage(reason ClosedState) {
|
||||
}
|
||||
body := wsCreateCloseMessage(status, reason.String())
|
||||
c.wsEnqueueControlMessageLocked(wsCloseMessage, body)
|
||||
nbPoolPut(body) // wsEnqueueControlMessageLocked has taken a copy.
|
||||
}
|
||||
|
||||
// Create and then enqueue a close message with a protocol error and the
|
||||
@@ -655,6 +661,7 @@ func (c *client) wsEnqueueCloseMessage(reason ClosedState) {
|
||||
func (c *client) wsHandleProtocolError(message string) error {
|
||||
buf := wsCreateCloseMessage(wsCloseStatusProtocolError, message)
|
||||
c.wsEnqueueControlMessage(wsCloseMessage, buf)
|
||||
nbPoolPut(buf) // wsEnqueueControlMessage has taken a copy.
|
||||
return fmt.Errorf(message)
|
||||
}
|
||||
|
||||
@@ -671,7 +678,7 @@ func wsCreateCloseMessage(status int, body string) []byte {
|
||||
body = body[:wsMaxControlPayloadSize-5]
|
||||
body += "..."
|
||||
}
|
||||
buf := make([]byte, 2+len(body))
|
||||
buf := nbPoolGet(2 + len(body))[:2+len(body)]
|
||||
// We need to have a 2 byte unsigned int that represents the error status code
|
||||
// https://tools.ietf.org/html/rfc6455#section-5.5.1
|
||||
binary.BigEndian.PutUint16(buf[:2], uint16(status))
|
||||
@@ -1298,6 +1305,7 @@ func (c *client) wsCollapsePtoNB() (net.Buffers, int64) {
|
||||
var csz int
|
||||
for _, b := range nb {
|
||||
cp.Write(b)
|
||||
nbPoolPut(b) // No longer needed as contents written to compressor.
|
||||
}
|
||||
if err := cp.Flush(); err != nil {
|
||||
c.Errorf("Error during compression: %v", err)
|
||||
@@ -1314,24 +1322,33 @@ func (c *client) wsCollapsePtoNB() (net.Buffers, int64) {
|
||||
} else {
|
||||
final = true
|
||||
}
|
||||
fh := make([]byte, wsMaxFrameHeaderSize)
|
||||
// Only the first frame should be marked as compressed, so pass
|
||||
// `first` for the compressed boolean.
|
||||
fh := nbPoolGet(wsMaxFrameHeaderSize)[:wsMaxFrameHeaderSize]
|
||||
n, key := wsFillFrameHeader(fh, mask, first, final, first, wsBinaryMessage, lp)
|
||||
if mask {
|
||||
wsMaskBuf(key, p[:lp])
|
||||
}
|
||||
bufs = append(bufs, fh[:n], p[:lp])
|
||||
new := nbPoolGet(wsFrameSizeForBrowsers)
|
||||
lp = copy(new[:wsFrameSizeForBrowsers], p[:lp])
|
||||
bufs = append(bufs, fh[:n], new[:lp])
|
||||
csz += n + lp
|
||||
p = p[lp:]
|
||||
}
|
||||
} else {
|
||||
h, key := wsCreateFrameHeader(mask, true, wsBinaryMessage, len(p))
|
||||
ol := len(p)
|
||||
h, key := wsCreateFrameHeader(mask, true, wsBinaryMessage, ol)
|
||||
if mask {
|
||||
wsMaskBuf(key, p)
|
||||
}
|
||||
bufs = append(bufs, h, p)
|
||||
csz = len(h) + len(p)
|
||||
bufs = append(bufs, h)
|
||||
for len(p) > 0 {
|
||||
new := nbPoolGet(len(p))
|
||||
n := copy(new[:cap(new)], p)
|
||||
bufs = append(bufs, new[:n])
|
||||
p = p[n:]
|
||||
}
|
||||
csz = len(h) + ol
|
||||
}
|
||||
// Add to pb the compressed data size (including headers), but
|
||||
// remove the original uncompressed data size that was added
|
||||
@@ -1343,7 +1360,7 @@ func (c *client) wsCollapsePtoNB() (net.Buffers, int64) {
|
||||
if mfs > 0 {
|
||||
// We are limiting the frame size.
|
||||
startFrame := func() int {
|
||||
bufs = append(bufs, make([]byte, wsMaxFrameHeaderSize))
|
||||
bufs = append(bufs, nbPoolGet(wsMaxFrameHeaderSize)[:wsMaxFrameHeaderSize])
|
||||
return len(bufs) - 1
|
||||
}
|
||||
endFrame := func(idx, size int) {
|
||||
@@ -1376,8 +1393,10 @@ func (c *client) wsCollapsePtoNB() (net.Buffers, int64) {
|
||||
if endStart {
|
||||
fhIdx = startFrame()
|
||||
}
|
||||
bufs = append(bufs, b[:total])
|
||||
b = b[total:]
|
||||
new := nbPoolGet(total)
|
||||
n := copy(new[:cap(new)], b[:total])
|
||||
bufs = append(bufs, new[:n])
|
||||
b = b[n:]
|
||||
}
|
||||
}
|
||||
if total > 0 {
|
||||
|
||||
@@ -16,7 +16,6 @@ package server
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"compress/flate"
|
||||
"crypto/tls"
|
||||
"encoding/base64"
|
||||
"encoding/binary"
|
||||
@@ -36,6 +35,8 @@ import (
|
||||
|
||||
"github.com/nats-io/jwt/v2"
|
||||
"github.com/nats-io/nkeys"
|
||||
|
||||
"github.com/klauspost/compress/flate"
|
||||
)
|
||||
|
||||
type testReader struct {
|
||||
@@ -2863,11 +2864,11 @@ func (wc *testWSWrappedConn) Write(p []byte) (int, error) {
|
||||
}
|
||||
|
||||
func TestWSCompressionBasic(t *testing.T) {
|
||||
payload := "This is the content of a message that will be compresseddddddddddddddddddddd."
|
||||
payload := "This is the content of a message that will be compresseddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddd."
|
||||
msgProto := fmt.Sprintf("MSG foo 1 %d\r\n%s\r\n", len(payload), payload)
|
||||
|
||||
cbuf := &bytes.Buffer{}
|
||||
compressor, _ := flate.NewWriter(cbuf, flate.BestSpeed)
|
||||
compressor, err := flate.NewWriter(cbuf, flate.BestSpeed)
|
||||
require_NoError(t, err)
|
||||
compressor.Write([]byte(msgProto))
|
||||
compressor.Flush()
|
||||
compressed := cbuf.Bytes()
|
||||
@@ -2890,14 +2891,14 @@ func TestWSCompressionBasic(t *testing.T) {
|
||||
}
|
||||
|
||||
var wc *testWSWrappedConn
|
||||
s.mu.Lock()
|
||||
s.mu.RLock()
|
||||
for _, c := range s.clients {
|
||||
c.mu.Lock()
|
||||
wc = &testWSWrappedConn{Conn: c.nc, buf: &bytes.Buffer{}}
|
||||
c.nc = wc
|
||||
c.mu.Unlock()
|
||||
}
|
||||
s.mu.Unlock()
|
||||
s.mu.RUnlock()
|
||||
|
||||
nc := natsConnect(t, s.ClientURL())
|
||||
defer nc.Close()
|
||||
|
||||
Reference in New Issue
Block a user