diff --git a/server/client.go b/server/client.go index 06d96f81..aa511390 100644 --- a/server/client.go +++ b/server/client.go @@ -536,12 +536,6 @@ func (c *client) deliverMsg(sub *subscription, mh, msg []byte) { goto writeErr } - // FIXME, this is already attached to original message - _, err = client.bw.WriteString(CR_LF) - if err != nil { - goto writeErr - } - if deadlineSet { client.nc.SetWriteDeadline(time.Time{}) } diff --git a/server/const.go b/server/const.go index ffca86d1..9d81423f 100644 --- a/server/const.go +++ b/server/const.go @@ -36,6 +36,7 @@ const ( // CRLF string CR_LF = "\r\n" + LEN_CR_LF = len(CR_LF) // Write/Flush Deadlines DEFAULT_FLUSH_DEADLINE = 500 * time.Millisecond diff --git a/server/parser.go b/server/parser.go index d28c5787..866c416b 100644 --- a/server/parser.go +++ b/server/parser.go @@ -68,6 +68,9 @@ const ( func (c *client) parse(buf []byte) error { var i int var b byte + + // snapshot this, and reset when we receive a + // proper CONNECT if needed. authSet := c.isAuthTimerSet() for i, b = range buf { @@ -145,21 +148,32 @@ func (c *client) parse(buf []byte) error { } case MSG_PAYLOAD: if c.msgBuf != nil { + c.msgBuf = append(c.msgBuf, b) if len(c.msgBuf) >= c.pa.size { - c.processMsg(c.msgBuf) - c.argBuf, c.msgBuf, c.state = nil, nil, MSG_END - } else { - c.msgBuf = append(c.msgBuf, b) + c.state = MSG_END } } else if i-c.as >= c.pa.size { - c.processMsg(buf[c.as:i]) - c.argBuf, c.msgBuf, c.state = nil, nil, MSG_END + c.state = MSG_END } case MSG_END: switch b { case '\n': + if c.msgBuf != nil { + c.msgBuf = append(c.msgBuf, b) + } else { + c.msgBuf = buf[c.as:i+1] + } + // strict check for proto + if len(c.msgBuf) != c.pa.size + LEN_CR_LF { + goto parseErr + } + c.processMsg(c.msgBuf) + c.argBuf, c.msgBuf = nil, nil c.drop, c.as, c.state = 0, i+1, OP_START default: + if c.msgBuf != nil { + c.msgBuf = append(c.msgBuf, b) + } continue } case OP_S: @@ -368,6 +382,8 @@ func (c *client) parse(buf []byte) error { return err } c.drop, c.state = 0, OP_START + // Reset notion on authSet + authSet = c.isAuthTimerSet() } case OP_M: switch b { @@ -429,7 +445,7 @@ func (c *client) parse(buf []byte) error { // FIXME, check max len } // Check for split msg - if c.state == MSG_PAYLOAD && c.msgBuf == nil { + if (c.state == MSG_PAYLOAD || c.state == MSG_END) && c.msgBuf == nil { // We need to clone the pubArg if it is still referencing the // read buffer and we are not able to process the msg. if c.argBuf == nil { diff --git a/server/parser_test.go b/server/parser_test.go index 86023cd4..3fd7172d 100644 --- a/server/parser_test.go +++ b/server/parser_test.go @@ -169,7 +169,8 @@ func TestParsePub(t *testing.T) { t.Fatalf("Did not parse msg size correctly: 5 vs %d\n", c.pa.size) } - c.state = OP_START + // Clear snapshots + c.argBuf, c.msgBuf, c.state = nil, nil, OP_START pub = []byte("PUB foo.bar INBOX.22 11\r\nhello world\r") err = c.parse(pub) @@ -243,7 +244,8 @@ func TestParseMsg(t *testing.T) { t.Fatalf("Did not parse sid correctly: 'RSID:1:2' vs '%s'\n", c.pa.sid) } - c.state = OP_START + // Clear snapshots + c.argBuf, c.msgBuf, c.state = nil, nil, OP_START pub = []byte("MSG foo.bar RSID:1:2 INBOX.22 11\r\nhello world\r") err = c.parse(pub) @@ -341,4 +343,12 @@ func TestShouldFail(t *testing.T) { if err := c.parse([]byte("SUB foo bar baz 22\r\n")); err == nil { t.Fatal("Should have received a parse error") } + c.state = OP_START + if err := c.parse([]byte("PUB foo 2\r\nok \r\n")); err == nil { + t.Fatal("Should have received a parse error") + } + c.state = OP_START + if err := c.parse([]byte("PUB foo 2\r\nok\r \n")); err == nil { + t.Fatal("Should have received a parse error") + } } diff --git a/server/split_test.go b/server/split_test.go index f7610d85..484b9908 100644 --- a/server/split_test.go +++ b/server/split_test.go @@ -217,3 +217,22 @@ func TestSplitBufferPubOp4(t *testing.T) { t.Fatalf("Unexpected size bytes: '%s' vs '%s'\n", c.pa.szb, "11") } } + +func TestSplitBufferPubOp5(t *testing.T) { + c := &client{subs: hashmap.New()} + pubAll := []byte("PUB foo 11\r\nhello world\r\n") + + // Splits need to be on MSG_END now too, so make sure we check that. + // Split between \r and \n + pub := pubAll[:len(pubAll)-1] + + if err := c.parse(pub); err != nil { + t.Fatalf("Unexpected parse error: %v\n", err) + } + if c.msgBuf == nil { + t.Fatalf("msgBuf should not be nil!\n") + } + if !bytes.Equal(c.msgBuf, []byte("hello world\r")) { + t.Fatalf("c.msgBuf did not snaphot the msg") + } +}