diff --git a/server/client.go b/server/client.go index b8069185..804fca75 100644 --- a/server/client.go +++ b/server/client.go @@ -27,6 +27,7 @@ type client struct { srv *Server subs *hashmap.HashMap pcd map[*client]struct{} + atmr *time.Timer cstats parseState } @@ -75,7 +76,7 @@ func (c *client) readLoop() { return } if err := c.parse(b[:n]); err != nil { - Logf("Parse Error: %v\n", err) + Log(err.Error(), clientConnStr(c.conn), c.cid) c.closeConnection() return } @@ -113,12 +114,35 @@ func (c *client) traceOp(op string, arg []byte) { func (c *client) processConnect(arg []byte) error { c.traceOp("CONNECT", arg) + + // This will be resolved regardless before we exit this func, + // so we can just clear it here. + if c.atmr != nil { + c.atmr.Stop() + c.atmr = nil + } + // FIXME, check err - err := json.Unmarshal(arg, &c.opts) + if err := json.Unmarshal(arg, &c.opts); err != nil { + return err + } + // Check for Auth + if c.srv != nil { + if ok := c.srv.checkAuth(c); !ok { + c.sendErr("Authorization is Required") + return fmt.Errorf("Authorization Error") + } + } if c.opts.Verbose { c.sendOK() } - return err + return nil +} + +func (c *client) authViolation() { + c.sendErr("Authorization is Required") + fmt.Printf("AUTH TIMER EXPIRED!!\n") + c.closeConnection() } func (c *client) sendErr(err string) { @@ -405,6 +429,25 @@ func (c *client) processMsg(msg []byte) { } } +// Lock should be held +func (c *client) clearAuthTimer() { + if c.atmr == nil { + return + } + c.atmr.Stop() + c.atmr = nil +} + +// Lock should be held +func (c *client) clearConnection() { + if c.conn == nil { + return + } + c.bw.Flush() + c.conn.Close() + c.conn = nil +} + func (c *client) closeConnection() { if c.conn == nil { return @@ -412,9 +455,8 @@ func (c *client) closeConnection() { Debug("Client connection closed", clientConnStr(c.conn), c.cid) c.mu.Lock() - c.bw.Flush() - c.conn.Close() - c.conn = nil + c.clearAuthTimer() + c.clearConnection() subs := c.subs.All() c.mu.Unlock() diff --git a/server/const.go b/server/const.go index a85b6643..652e24a7 100644 --- a/server/const.go +++ b/server/const.go @@ -25,7 +25,7 @@ const ( DEFAULT_MAX_CONNECTIONS = (64 * 1024) // TLS/SSL wait time - SSL_TIMEOUT = 500 * time.Millisecond + SSL_TIMEOUT = 250 * time.Millisecond // Authorization wait time AUTH_TIMEOUT = 2 * SSL_TIMEOUT diff --git a/server/log.go b/server/log.go index 15dc98c9..005fff10 100644 --- a/server/log.go +++ b/server/log.go @@ -21,7 +21,7 @@ func (s *Server) LogInit() { if s.opts.Logtime { log.SetFlags(log.LstdFlags) } - if s.opts.Trace { + if s.opts.Debug { Log(s.opts) } if s.opts.Debug { diff --git a/server/parser.go b/server/parser.go index fced0871..5bd4d7cd 100644 --- a/server/parser.go +++ b/server/parser.go @@ -64,6 +64,9 @@ func (c *client) parse(buf []byte) error { for i, b = range buf { switch c.state { case OP_START: + if c.atmr != nil && b != 'C' && b != 'c' { + goto authErr + } switch b { case 'C', 'c': c.state = OP_C @@ -345,6 +348,10 @@ func (c *client) parse(buf []byte) error { parseErr: return fmt.Errorf("Parse Error [%d]: '%s'", c.state, buf[i:]) + +authErr: + c.authViolation() + return fmt.Errorf("Authorization Error") } diff --git a/server/server.go b/server/server.go index bd4e15a1..9a58287b 100644 --- a/server/server.go +++ b/server/server.go @@ -8,6 +8,7 @@ import ( "fmt" "net" "sync/atomic" + "time" "github.com/apcera/gnatsd/hashmap" "github.com/apcera/gnatsd/sublist" @@ -45,7 +46,8 @@ type Server struct { debug bool } -func optionDefaults(opt *Options) { +func processOptions(opt *Options) { + // Setup non-standard Go defaults if opt.Host == "" { opt.Host = DEFAULT_HOST } @@ -58,8 +60,8 @@ func optionDefaults(opt *Options) { } func New(opts Options) *Server { - optionDefaults(&opts) - inf := Info{ + processOptions(&opts) + info := Info{ Id: genId(), Version: VERSION, Host: opts.Host, @@ -68,8 +70,12 @@ func New(opts Options) *Server { SslRequired: false, MaxPayload: MAX_PAYLOAD_SIZE, } + // Check for Auth items + if opts.Username != "" || opts.Authorization != "" { + info.AuthRequired = true + } s := &Server{ - info: inf, + info: info, sl: sublist.New(), opts: opts, debug: opts.Debug, @@ -78,10 +84,13 @@ func New(opts Options) *Server { // Setup logging with flags s.LogInit() + /* if opts.Debug { b, _ := json.Marshal(opts) Debug(fmt.Sprintf("[%s]", b)) } + */ + // Generate the info json b, err := json.Marshal(s.info) @@ -139,6 +148,9 @@ func (s *Server) createClient(conn net.Conn) *client { s.sendInfo(c) go c.readLoop() + if s.info.AuthRequired { + c.atmr = time.AfterFunc(AUTH_TIMEOUT, func() { c.authViolation() }) + } return c } @@ -146,3 +158,19 @@ func (s *Server) sendInfo(c *client) { // FIXME, err c.conn.Write(s.infoJson) } + +// Check auth and return boolean indicating if client is ok +func (s *Server) checkAuth(c *client) bool { + if !s.info.AuthRequired { + return true + } + // We require auth here, check the client + // Authorization tokens trump username/password + if s.opts.Authorization != "" { + return s.opts.Authorization == c.opts.Authorization + } else if s.opts.Username != c.opts.Username || + s.opts.Password != c.opts.Password { + return false + } + return true +} diff --git a/test/auth_test.go b/test/auth_test.go new file mode 100644 index 00000000..611dfbcf --- /dev/null +++ b/test/auth_test.go @@ -0,0 +1,132 @@ +// Copyright 2012 Apcera Inc. All rights reserved. + +package test + +import ( + "encoding/json" + "fmt" + "net" + "testing" + "time" + + "github.com/apcera/gnatsd/server" +) + +const AUTH_PORT=10422 + +func doAuthConnect(t tLogger, c net.Conn, token, user, pass string) { + cs := fmt.Sprintf("CONNECT {\"verbose\":true,\"auth_token\":\"%s\",\"user\":\"%s\",\"pass\":\"%s\"}\r\n", token, user, pass) + sendProto(t, c, cs) +} + +func testInfoForAuth(t tLogger, infojs []byte) bool { + var sinfo server.Info + err := json.Unmarshal(infojs, &sinfo) + if err != nil { + t.Fatalf("Could not unmarshal INFO json: %v\n", err) + } + return sinfo.AuthRequired +} + +func expectAuthRequired(t tLogger, c net.Conn) { + buf := expectResult(t, c, infoRe) + infojs := infoRe.FindAllSubmatch(buf, 1)[0][1] + if testInfoForAuth(t, infojs) != true { + t.Fatalf("Expected server to require authorization: '%s'", infojs) + } +} + +const AUTH_TOKEN = "_YZZ22_" + +func TestStartupAuthToken(t *testing.T) { + s = startServer(t, AUTH_PORT, fmt.Sprintf("--auth=%s", AUTH_TOKEN)) +} + +func TestNoAuthClient(t *testing.T) { + c := createClientConn(t, "localhost", AUTH_PORT) + defer c.Close() + expectAuthRequired(t, c) + doAuthConnect(t, c, "", "", "") + expectResult(t, c, errRe) +} + +func TestAuthClientBadToken(t *testing.T) { + c := createClientConn(t, "localhost", AUTH_PORT) + defer c.Close() + expectAuthRequired(t, c) + doAuthConnect(t, c, "ZZZ", "", "") + expectResult(t, c, errRe) +} + +func TestAuthClientNoConnect(t *testing.T) { + c := createClientConn(t, "localhost", AUTH_PORT) + defer c.Close() + expectAuthRequired(t, c) + // This is timing dependent.. + time.Sleep(server.AUTH_TIMEOUT) + expectResult(t, c, errRe) +} + +func TestAuthClientGoodConnect(t *testing.T) { + c := createClientConn(t, "localhost", AUTH_PORT) + defer c.Close() + expectAuthRequired(t, c) + doAuthConnect(t, c, AUTH_TOKEN, "", "") + expectResult(t, c, okRe) +} + +func TestAuthClientFailOnEverythingElse(t *testing.T) { + c := createClientConn(t, "localhost", AUTH_PORT) + defer c.Close() + expectAuthRequired(t, c) + sendProto(t, c, "PUB foo 2\r\nok\r\n") + expectResult(t, c, errRe) +} + +func TestStopServerAuthToken(t *testing.T) { + s.stopServer() +} + +const AUTH_USER = "derek" +const AUTH_PASS = "foobar" + +// The username/password versions +func TestStartupAuthPassword(t *testing.T) { + s = startServer(t, AUTH_PORT, fmt.Sprintf("--user=%s --pass=%s", AUTH_USER, AUTH_PASS)) +} + +func TestNoUserOrPasswordClient(t *testing.T) { + c := createClientConn(t, "localhost", AUTH_PORT) + defer c.Close() + expectAuthRequired(t, c) + doAuthConnect(t, c, "", "", "") + expectResult(t, c, errRe) +} + +func TestBadUserClient(t *testing.T) { + c := createClientConn(t, "localhost", AUTH_PORT) + defer c.Close() + expectAuthRequired(t, c) + doAuthConnect(t, c, "", "derekzz", AUTH_PASS) + expectResult(t, c, errRe) +} + +func TestBadPasswordClient(t *testing.T) { + c := createClientConn(t, "localhost", AUTH_PORT) + defer c.Close() + expectAuthRequired(t, c) + doAuthConnect(t, c, "", AUTH_USER, "ZZ") + expectResult(t, c, errRe) +} + +func TestPasswordClientGoodConnect(t *testing.T) { + c := createClientConn(t, "localhost", AUTH_PORT) + defer c.Close() + expectAuthRequired(t, c) + doAuthConnect(t, c, "", AUTH_USER, AUTH_PASS) + expectResult(t, c, okRe) +} + +func TestStopServerAuthPassword(t *testing.T) { + s.stopServer() +} diff --git a/test/bench_test.go b/test/bench_test.go index fc9038b5..a315ce37 100644 --- a/test/bench_test.go +++ b/test/bench_test.go @@ -47,27 +47,27 @@ func benchPub(b *testing.B, subject, payload string) { s.stopServer() } -func BenchmarkPubNoPayload(b *testing.B) { +func Benchmark____PubNoPayload(b *testing.B) { benchPub(b, "a", "") } -func BenchmarkPubMinPayload(b *testing.B) { +func Benchmark___PubMinPayload(b *testing.B) { benchPub(b, "a", "b") } -func BenchmarkPubTinyPayload(b *testing.B) { +func Benchmark__PubTinyPayload(b *testing.B) { benchPub(b, "foo", "ok") } -func BenchmarkPubSmallPayload(b *testing.B) { +func Benchmark_PubSmallPayload(b *testing.B) { benchPub(b, "foo", "hello world") } -func BenchmarkPubMedPayload(b *testing.B) { +func Benchmark___PubMedPayload(b *testing.B) { benchPub(b, "foo", "The quick brown fox jumps over the lazy dog") } -func BenchmarkPubLrgPayload(b *testing.B) { +func Benchmark_PubLargePayload(b *testing.B) { b.StopTimer() var p string for i := 0 ; i < 200 ; i++ { @@ -98,7 +98,7 @@ func drainConnection(b *testing.B, c net.Conn, ch chan bool, expected int) { ch <- true } -func BenchmarkPubSub(b *testing.B) { +func Benchmark__________PubSub(b *testing.B) { b.StopTimer() s = startServer(b, PERF_PORT, "") c := createClientConn(b, "localhost", PERF_PORT) @@ -130,7 +130,7 @@ func BenchmarkPubSub(b *testing.B) { s.stopServer() } -func BenchmarkPubSubMultipleConnections(b *testing.B) { +func Benchmark__PubSubTwoConns(b *testing.B) { b.StopTimer() s = startServer(b, PERF_PORT, "") c := createClientConn(b, "localhost", PERF_PORT) @@ -165,7 +165,7 @@ func BenchmarkPubSubMultipleConnections(b *testing.B) { s.stopServer() } -func BenchmarkPubTwoQueueSub(b *testing.B) { +func Benchmark__PubTwoQueueSub(b *testing.B) { b.StopTimer() s = startServer(b, PERF_PORT, "") c := createClientConn(b, "localhost", PERF_PORT) @@ -198,7 +198,7 @@ func BenchmarkPubTwoQueueSub(b *testing.B) { s.stopServer() } -func BenchmarkPubFourQueueSub(b *testing.B) { +func Benchmark_PubFourQueueSub(b *testing.B) { b.StopTimer() s = startServer(b, PERF_PORT, "") c := createClientConn(b, "localhost", PERF_PORT) diff --git a/test/pedantic_test.go b/test/pedantic_test.go index 775a4072..7f8b78bd 100644 --- a/test/pedantic_test.go +++ b/test/pedantic_test.go @@ -12,9 +12,11 @@ func TestStartupPedantic(t *testing.T) { func TestPedanticSub(t *testing.T) { c := createClientConn(t, "localhost", PROTO_TEST_PORT) - doConnect(t, c, true, true, false) + defer c.Close() send := sendCommand(t, c) expect := expectCommand(t, c) + doConnect(t, c, true, true, false) + expect(okRe) // Ping should still be same send("PING\r\n") @@ -53,9 +55,15 @@ func TestPedanticSub(t *testing.T) { func TestPedanticPub(t *testing.T) { c := createClientConn(t, "localhost", PROTO_TEST_PORT) - doConnect(t, c, true, true, false) + defer c.Close() send := sendCommand(t, c) expect := expectCommand(t, c) + doConnect(t, c, true, true, false) + expect(okRe) + + // Ping should still be same + send("PING\r\n") + expect(pongRe) // Test malformed subjects for PUB // PUB subjects can not have wildcards diff --git a/test/proto_test.go b/test/proto_test.go index 8020083d..46a69860 100644 --- a/test/proto_test.go +++ b/test/proto_test.go @@ -3,14 +3,8 @@ package test import ( - "encoding/json" - "fmt" - "net" - "regexp" "testing" "time" - - "github.com/apcera/gnatsd/server" ) var s *natsServer @@ -21,126 +15,6 @@ func TestStartup(t *testing.T) { s = startServer(t, PROTO_TEST_PORT, "") } -type sendFun func(string) -type expectFun func(*regexp.Regexp) []byte - -// Closure version for easier reading -func sendCommand(t tLogger, c net.Conn) sendFun { - return func(op string) { - sendProto(t, c, op) - } -} - -// Closure version for easier reading -func expectCommand(t tLogger, c net.Conn) expectFun { - return func(re *regexp.Regexp) []byte { - return expectResult(t, c, re) - } -} - -// Send the protocol command to the server. -func sendProto(t tLogger, c net.Conn, op string) { - n, err := c.Write([]byte(op)) - if err != nil { - t.Fatalf("Error writing command to conn: %v\n", err) - } - if n != len(op) { - t.Fatalf("Partial write: %d vs %d\n", n, len(op)) - } -} - -// Reuse expect buffer -var expBuf = make([]byte, 32768) - -// Test result from server against regexp -func expectResult(t tLogger, c net.Conn, re *regexp.Regexp) []byte { - // Wait for commands to be processed and results queued for read - time.Sleep(100 * time.Millisecond) - c.SetReadDeadline(time.Now().Add(500 * time.Millisecond)) - defer c.SetReadDeadline(time.Time{}) - - n, err := c.Read(expBuf) - if err != nil { - t.Fatalf("Error reading from conn: %v\n", err) - } - buf := expBuf[:n] - if !re.Match(buf) { - t.Fatalf("Response did not match expected: '%s' vs '%s'\n", buf, re) - } - return buf -} - -// This will check that we got what we expected. -func checkMsg(t tLogger, m [][]byte, subject, sid, reply, len, msg string) { - if string(m[SUB_INDEX]) != subject { - t.Fatalf("Did not get correct subject: expected '%s' got '%s'\n", subject, m[SUB_INDEX]) - } - if string(m[SID_INDEX]) != sid { - t.Fatalf("Did not get correct sid: exepected '%s' got '%s'\n", sid, m[SID_INDEX]) - } - if string(m[REPLY_INDEX]) != reply { - t.Fatalf("Did not get correct reply: exepected '%s' got '%s'\n", reply, m[REPLY_INDEX]) - } - if string(m[LEN_INDEX]) != len { - t.Fatalf("Did not get correct msg length: expected '%s' got '%s'\n", len, m[LEN_INDEX]) - } - if string(m[MSG_INDEX]) != msg { - t.Fatalf("Did not get correct msg: expected '%s' got '%s'\n", msg, m[MSG_INDEX]) - } -} - -// Closure for expectMsgs -func expectMsgsCommand(t tLogger, ef expectFun) func(int) [][][]byte { - return func(expected int) [][][]byte { - buf := ef(msgRe) - matches := msgRe.FindAllSubmatch(buf, -1) - if len(matches) != expected { - t.Fatalf("Did not get correct # msgs: %d vs %d\n", len(matches), expected) - } - return matches - } -} - -var ( - infoRe = regexp.MustCompile(`\AINFO\s+([^\r\n]+)\r\n`) - pongRe = regexp.MustCompile(`\APONG\r\n`) - msgRe = regexp.MustCompile(`(?:(?:MSG\s+([^\s]+)\s+([^\s]+)\s+(([^\s]+)[^\S\r\n]+)?(\d+)\r\n([^\\r\\n]*?)\r\n)+?)`) - okRe = regexp.MustCompile(`\A\+OK\r\n`) - errRe = regexp.MustCompile(`\A\-ERR\s+([^\r\n]+)\r\n`) -) - -const ( - SUB_INDEX = 1 - SID_INDEX = 2 - REPLY_INDEX = 4 - LEN_INDEX = 5 - MSG_INDEX = 6 -) - -func doConnect(t tLogger, c net.Conn, verbose, pedantic, ssl bool) { - cs := fmt.Sprintf("CONNECT {\"verbose\":%v,\"pedantic\":%v,\"ssl_required\":%v}\r\n", verbose, pedantic, ssl) - sendProto(t, c, cs) - buf := expectResult(t, c, infoRe) - js := infoRe.FindAllSubmatch(buf, 1)[0][1] - var sinfo server.Info - err := json.Unmarshal(js, &sinfo) - if err != nil { - t.Fatalf("Could not unmarshal INFO json: %v\n", err) - } -} - -func doDefaultConnect(t tLogger, c net.Conn) { - // Basic Connect - doConnect(t, c, false, false, false) -} - -func setupConn(t tLogger, c net.Conn) (sendFun, expectFun) { - doDefaultConnect(t, c) - send := sendCommand(t, c) - expect := expectCommand(t, c) - return send, expect -} - func TestProtoBasics(t *testing.T) { c := createClientConn(t, "localhost", PROTO_TEST_PORT) send, expect := setupConn(t, c) @@ -191,6 +65,9 @@ func TestQueueSub(t *testing.T) { for i := 0; i < sent; i++ { send("PUB foo 2\r\nok\r\n") } + // Wait for responses + time.Sleep(250*time.Millisecond) + matches := expectMsgs(sent) sids := make(map[string]int) for _, m := range matches { @@ -221,6 +98,9 @@ func TestMultipleQueueSub(t *testing.T) { for i := 0; i < sent; i++ { send("PUB foo 2\r\nok\r\n") } + // Wait for responses + time.Sleep(250*time.Millisecond) + matches := expectMsgs(sent * 2) sids := make(map[string]int) for _, m := range matches { diff --git a/test/test.go b/test/test.go index 9e566aac..acad09b1 100644 --- a/test/test.go +++ b/test/test.go @@ -3,11 +3,16 @@ package test import ( + "bytes" + "encoding/json" "fmt" "net" "os/exec" + "regexp" "strings" "time" + + "github.com/apcera/gnatsd/server" ) const natsServerExe = "../gnatsd" @@ -69,4 +74,124 @@ func createClientConn(t tLogger, host string, port int) net.Conn { return c } +func doConnect(t tLogger, c net.Conn, verbose, pedantic, ssl bool) { + cs := fmt.Sprintf("CONNECT {\"verbose\":%v,\"pedantic\":%v,\"ssl_required\":%v}\r\n", verbose, pedantic, ssl) + sendProto(t, c, cs) + buf := expectResult(t, c, infoRe) + js := infoRe.FindAllSubmatch(buf, 1)[0][1] + var sinfo server.Info + err := json.Unmarshal(js, &sinfo) + if err != nil { + t.Fatalf("Could not unmarshal INFO json: %v\n", err) + } +} +func doDefaultConnect(t tLogger, c net.Conn) { + // Basic Connect + doConnect(t, c, false, false, false) +} + +func setupConn(t tLogger, c net.Conn) (sendFun, expectFun) { + doDefaultConnect(t, c) + send := sendCommand(t, c) + expect := expectCommand(t, c) + return send, expect +} + +type sendFun func(string) +type expectFun func(*regexp.Regexp) []byte + +// Closure version for easier reading +func sendCommand(t tLogger, c net.Conn) sendFun { + return func(op string) { + sendProto(t, c, op) + } +} + +// Closure version for easier reading +func expectCommand(t tLogger, c net.Conn) expectFun { + return func(re *regexp.Regexp) []byte { + return expectResult(t, c, re) + } +} + +// Send the protocol command to the server. +func sendProto(t tLogger, c net.Conn, op string) { + n, err := c.Write([]byte(op)) + if err != nil { + t.Fatalf("Error writing command to conn: %v\n", err) + } + if n != len(op) { + t.Fatalf("Partial write: %d vs %d\n", n, len(op)) + } +} + +var ( + infoRe = regexp.MustCompile(`\AINFO\s+([^\r\n]+)\r\n`) + pongRe = regexp.MustCompile(`\APONG\r\n`) + msgRe = regexp.MustCompile(`(?:(?:MSG\s+([^\s]+)\s+([^\s]+)\s+(([^\s]+)[^\S\r\n]+)?(\d+)\r\n([^\\r\\n]*?)\r\n)+?)`) + okRe = regexp.MustCompile(`\A\+OK\r\n`) + errRe = regexp.MustCompile(`\A\-ERR\s+([^\r\n]+)\r\n`) +) + +const ( + SUB_INDEX = 1 + SID_INDEX = 2 + REPLY_INDEX = 4 + LEN_INDEX = 5 + MSG_INDEX = 6 +) + +// Reuse expect buffer +var expBuf = make([]byte, 32768) + +// Test result from server against regexp +func expectResult(t tLogger, c net.Conn, re *regexp.Regexp) []byte { + // Wait for commands to be processed and results queued for read + // time.Sleep(50 * time.Millisecond) + c.SetReadDeadline(time.Now().Add(500 * time.Millisecond)) + defer c.SetReadDeadline(time.Time{}) + + n, err := c.Read(expBuf) + if n <= 0 && err != nil { + t.Fatalf("Error reading from conn: %v\n", err) + } + buf := expBuf[:n] + + if !re.Match(buf) { + buf = bytes.Replace(buf, []byte("\r\n"), []byte("\\r\\n"), -1) + t.Fatalf("Response did not match expected: \n\tReceived:'%s'\n\tExpected:'%s'\n", buf, re) + } + return buf +} + +// This will check that we got what we expected. +func checkMsg(t tLogger, m [][]byte, subject, sid, reply, len, msg string) { + if string(m[SUB_INDEX]) != subject { + t.Fatalf("Did not get correct subject: expected '%s' got '%s'\n", subject, m[SUB_INDEX]) + } + if string(m[SID_INDEX]) != sid { + t.Fatalf("Did not get correct sid: exepected '%s' got '%s'\n", sid, m[SID_INDEX]) + } + if string(m[REPLY_INDEX]) != reply { + t.Fatalf("Did not get correct reply: exepected '%s' got '%s'\n", reply, m[REPLY_INDEX]) + } + if string(m[LEN_INDEX]) != len { + t.Fatalf("Did not get correct msg length: expected '%s' got '%s'\n", len, m[LEN_INDEX]) + } + if string(m[MSG_INDEX]) != msg { + t.Fatalf("Did not get correct msg: expected '%s' got '%s'\n", msg, m[MSG_INDEX]) + } +} + +// Closure for expectMsgs +func expectMsgsCommand(t tLogger, ef expectFun) func(int) [][][]byte { + return func(expected int) [][][]byte { + buf := ef(msgRe) + matches := msgRe.FindAllSubmatch(buf, -1) + if len(matches) != expected { + t.Fatalf("Did not get correct # msgs: %d vs %d\n", len(matches), expected) + } + return matches + } +} diff --git a/test/verbose_test.go b/test/verbose_test.go index 1c199433..af41b297 100644 --- a/test/verbose_test.go +++ b/test/verbose_test.go @@ -13,9 +13,13 @@ func TestStartupVerbose(t *testing.T) { func TestVerbosePing(t *testing.T) { c := createClientConn(t, "localhost", PROTO_TEST_PORT) doConnect(t, c, true, false, false) + defer c.Close() + send := sendCommand(t, c) expect := expectCommand(t, c) + expect(okRe) + // Ping should still be same send("PING\r\n") expect(pongRe) @@ -24,9 +28,13 @@ func TestVerbosePing(t *testing.T) { func TestVerboseConnect(t *testing.T) { c := createClientConn(t, "localhost", PROTO_TEST_PORT) doConnect(t, c, true, false, false) + defer c.Close() + send := sendCommand(t, c) expect := expectCommand(t, c) + expect(okRe) + // Connect send("CONNECT {\"verbose\":true,\"pedantic\":true,\"ssl_required\":false}\r\n") expect(okRe) @@ -38,6 +46,8 @@ func TestVerbosePubSub(t *testing.T) { send := sendCommand(t, c) expect := expectCommand(t, c) + expect(okRe) + // Pub send("PUB foo 2\r\nok\r\n") expect(okRe)