diff --git a/server/parser.go b/server/parser.go index a6134f3d..5fd417fd 100644 --- a/server/parser.go +++ b/server/parser.go @@ -1,4 +1,4 @@ -// Copyright 2012 Apcera Inc. All rights reserved. +// Copyright 2012-2014 Apcera Inc. All rights reserved. package server @@ -398,12 +398,23 @@ func (c *client) parse(buf []byte) error { case '\r': c.drop = 1 case '\n': - if err := c.processConnect(buf[c.as : i-c.drop]); err != nil { + var arg []byte + if c.argBuf != nil { + arg = c.argBuf + } else { + arg = buf[c.as : i-c.drop] + } + if err := c.processConnect(arg); err != nil { return err } c.drop, c.state = 0, OP_START + c.argBuf = nil // Reset notion on authSet authSet = c.isAuthTimerSet() + default: + if c.argBuf != nil { + c.argBuf = append(c.argBuf, b) + } } case OP_M: switch b { @@ -573,7 +584,8 @@ func (c *client) parse(buf []byte) error { } // Check for split buffer scenarios for any ARG state. if (c.state == SUB_ARG || c.state == UNSUB_ARG || c.state == PUB_ARG || - c.state == MSG_ARG || c.state == MINUS_ERR_ARG) && c.argBuf == nil { + c.state == MSG_ARG || c.state == MINUS_ERR_ARG || + c.state == CONNECT_ARG) && c.argBuf == nil { c.argBuf = c.scratch[:0] c.argBuf = append(c.argBuf, buf[c.as:(i+1)-c.drop]...) // FIXME, check max len @@ -585,11 +597,10 @@ func (c *client) parse(buf []byte) error { if c.argBuf == nil { c.clonePubArg() } - // FIXME: copy better here? Make whole buf if large? - //c.msgBuf = c.scratch[:0] c.msgBuf = c.scratch[len(c.argBuf):len(c.argBuf)] c.msgBuf = append(c.msgBuf, (buf[c.as:])...) } + return nil authErr: diff --git a/server/split_test.go b/server/split_test.go index 2dea8d0d..7864d349 100644 --- a/server/split_test.go +++ b/server/split_test.go @@ -241,3 +241,51 @@ func TestSplitBufferPubOp5(t *testing.T) { t.Fatalf("c.msgBuf did not snaphot the msg") } } + +func TestSplitConnectArg(t *testing.T) { + c := &client{subs: hashmap.New()} + connectAll := []byte("CONNECT {\"verbose\":false,\"ssl_required\":false," + + "\"user\":\"test\",\"pedantic\":true,\"pass\":\"pass\"}\r\n") + + argJson := connectAll[8:] + + c1 := connectAll[:5] + c2 := connectAll[5:22] + c3 := connectAll[22 : len(connectAll)-2] + c4 := connectAll[len(connectAll)-2:] + + if err := c.parse(c1); err != nil { + t.Fatalf("Unexpected parse error: %v\n", err) + } + if c.argBuf != nil { + t.Fatalf("Unexpected argBug placeholder.\n") + } + + if err := c.parse(c2); err != nil { + t.Fatalf("Unexpected parse error: %v\n", err) + } + if c.argBuf == nil { + t.Fatalf("Expected argBug to not be nil.\n") + } + if !bytes.Equal(c.argBuf, argJson[:14]) { + t.Fatalf("argBuf not correct, received %q, wanted %q\n", argJson[:14], c.argBuf) + } + + if err := c.parse(c3); err != nil { + t.Fatalf("Unexpected parse error: %v\n", err) + } + if c.argBuf == nil { + t.Fatalf("Expected argBug to not be nil.\n") + } + if !bytes.Equal(c.argBuf, argJson[:len(argJson)-2]) { + t.Fatalf("argBuf not correct, received %q, wanted %q\n", + argJson[:len(argJson)-2], c.argBuf) + } + + if err := c.parse(c4); err != nil { + t.Fatalf("Unexpected parse error: %v\n", err) + } + if c.argBuf != nil { + t.Fatalf("Unexpected argBug placeholder.\n") + } +} diff --git a/test/proto_test.go b/test/proto_test.go index 20023c56..440a2ebc 100644 --- a/test/proto_test.go +++ b/test/proto_test.go @@ -173,3 +173,23 @@ func TestSubToArgState(t *testing.T) { send("SUBZZZ foo 1\r\n") expect(errRe) } + +// Issue #63 +func TestProtoCrash(t *testing.T) { + s := runProtoServer() + defer s.Shutdown() + + c := createClientConn(t, "localhost", PROTO_TEST_PORT) + defer c.Close() + + send, expect := sendCommand(t, c), expectCommand(t, c) + + checkInfoMsg(t, c) + + send("CONNECT {\"verbose\":true,\"ssl_required\":false,\"user\":\"test\",\"pedantic\":true,\"pass\":\"password\"}") + + time.Sleep(100 * time.Millisecond) + + send("\r\n") + expect(okRe) +} diff --git a/test/test.go b/test/test.go index f237ad0d..f92affa2 100644 --- a/test/test.go +++ b/test/test.go @@ -184,7 +184,7 @@ func checkSocket(t tLogger, addr string, wait time.Duration) { t.Fatalf("Failed to connect to the socket: %q", addr) } -func doConnect(t tLogger, c net.Conn, verbose, pedantic, ssl bool) { +func checkInfoMsg(t tLogger, c net.Conn) { buf := expectResult(t, c, infoRe) js := infoRe.FindAllSubmatch(buf, 1)[0][1] var sinfo server.Info @@ -192,6 +192,10 @@ func doConnect(t tLogger, c net.Conn, verbose, pedantic, ssl bool) { if err != nil { stackFatalf(t, "Could not unmarshal INFO json: %v\n", err) } +} + +func doConnect(t tLogger, c net.Conn, verbose, pedantic, ssl bool) { + checkInfoMsg(t, c) cs := fmt.Sprintf("CONNECT {\"verbose\":%v,\"pedantic\":%v,\"ssl_required\":%v}\r\n", verbose, pedantic, ssl) sendProto(t, c, cs) }