mirror of
https://github.com/gogrlx/nats-server.git
synced 2026-04-02 03:38:42 -07:00
auth support, cleanup
This commit is contained in:
@@ -27,6 +27,7 @@ type client struct {
|
||||
srv *Server
|
||||
subs *hashmap.HashMap
|
||||
pcd map[*client]struct{}
|
||||
atmr *time.Timer
|
||||
cstats
|
||||
parseState
|
||||
}
|
||||
@@ -75,7 +76,7 @@ func (c *client) readLoop() {
|
||||
return
|
||||
}
|
||||
if err := c.parse(b[:n]); err != nil {
|
||||
Logf("Parse Error: %v\n", err)
|
||||
Log(err.Error(), clientConnStr(c.conn), c.cid)
|
||||
c.closeConnection()
|
||||
return
|
||||
}
|
||||
@@ -113,12 +114,35 @@ func (c *client) traceOp(op string, arg []byte) {
|
||||
|
||||
func (c *client) processConnect(arg []byte) error {
|
||||
c.traceOp("CONNECT", arg)
|
||||
|
||||
// This will be resolved regardless before we exit this func,
|
||||
// so we can just clear it here.
|
||||
if c.atmr != nil {
|
||||
c.atmr.Stop()
|
||||
c.atmr = nil
|
||||
}
|
||||
|
||||
// FIXME, check err
|
||||
err := json.Unmarshal(arg, &c.opts)
|
||||
if err := json.Unmarshal(arg, &c.opts); err != nil {
|
||||
return err
|
||||
}
|
||||
// Check for Auth
|
||||
if c.srv != nil {
|
||||
if ok := c.srv.checkAuth(c); !ok {
|
||||
c.sendErr("Authorization is Required")
|
||||
return fmt.Errorf("Authorization Error")
|
||||
}
|
||||
}
|
||||
if c.opts.Verbose {
|
||||
c.sendOK()
|
||||
}
|
||||
return err
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *client) authViolation() {
|
||||
c.sendErr("Authorization is Required")
|
||||
fmt.Printf("AUTH TIMER EXPIRED!!\n")
|
||||
c.closeConnection()
|
||||
}
|
||||
|
||||
func (c *client) sendErr(err string) {
|
||||
@@ -405,6 +429,25 @@ func (c *client) processMsg(msg []byte) {
|
||||
}
|
||||
}
|
||||
|
||||
// Lock should be held
|
||||
func (c *client) clearAuthTimer() {
|
||||
if c.atmr == nil {
|
||||
return
|
||||
}
|
||||
c.atmr.Stop()
|
||||
c.atmr = nil
|
||||
}
|
||||
|
||||
// Lock should be held
|
||||
func (c *client) clearConnection() {
|
||||
if c.conn == nil {
|
||||
return
|
||||
}
|
||||
c.bw.Flush()
|
||||
c.conn.Close()
|
||||
c.conn = nil
|
||||
}
|
||||
|
||||
func (c *client) closeConnection() {
|
||||
if c.conn == nil {
|
||||
return
|
||||
@@ -412,9 +455,8 @@ func (c *client) closeConnection() {
|
||||
Debug("Client connection closed", clientConnStr(c.conn), c.cid)
|
||||
|
||||
c.mu.Lock()
|
||||
c.bw.Flush()
|
||||
c.conn.Close()
|
||||
c.conn = nil
|
||||
c.clearAuthTimer()
|
||||
c.clearConnection()
|
||||
subs := c.subs.All()
|
||||
c.mu.Unlock()
|
||||
|
||||
|
||||
@@ -25,7 +25,7 @@ const (
|
||||
DEFAULT_MAX_CONNECTIONS = (64 * 1024)
|
||||
|
||||
// TLS/SSL wait time
|
||||
SSL_TIMEOUT = 500 * time.Millisecond
|
||||
SSL_TIMEOUT = 250 * time.Millisecond
|
||||
|
||||
// Authorization wait time
|
||||
AUTH_TIMEOUT = 2 * SSL_TIMEOUT
|
||||
|
||||
@@ -21,7 +21,7 @@ func (s *Server) LogInit() {
|
||||
if s.opts.Logtime {
|
||||
log.SetFlags(log.LstdFlags)
|
||||
}
|
||||
if s.opts.Trace {
|
||||
if s.opts.Debug {
|
||||
Log(s.opts)
|
||||
}
|
||||
if s.opts.Debug {
|
||||
|
||||
@@ -64,6 +64,9 @@ func (c *client) parse(buf []byte) error {
|
||||
for i, b = range buf {
|
||||
switch c.state {
|
||||
case OP_START:
|
||||
if c.atmr != nil && b != 'C' && b != 'c' {
|
||||
goto authErr
|
||||
}
|
||||
switch b {
|
||||
case 'C', 'c':
|
||||
c.state = OP_C
|
||||
@@ -345,6 +348,10 @@ func (c *client) parse(buf []byte) error {
|
||||
|
||||
parseErr:
|
||||
return fmt.Errorf("Parse Error [%d]: '%s'", c.state, buf[i:])
|
||||
|
||||
authErr:
|
||||
c.authViolation()
|
||||
return fmt.Errorf("Authorization Error")
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"fmt"
|
||||
"net"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/apcera/gnatsd/hashmap"
|
||||
"github.com/apcera/gnatsd/sublist"
|
||||
@@ -45,7 +46,8 @@ type Server struct {
|
||||
debug bool
|
||||
}
|
||||
|
||||
func optionDefaults(opt *Options) {
|
||||
func processOptions(opt *Options) {
|
||||
// Setup non-standard Go defaults
|
||||
if opt.Host == "" {
|
||||
opt.Host = DEFAULT_HOST
|
||||
}
|
||||
@@ -58,8 +60,8 @@ func optionDefaults(opt *Options) {
|
||||
}
|
||||
|
||||
func New(opts Options) *Server {
|
||||
optionDefaults(&opts)
|
||||
inf := Info{
|
||||
processOptions(&opts)
|
||||
info := Info{
|
||||
Id: genId(),
|
||||
Version: VERSION,
|
||||
Host: opts.Host,
|
||||
@@ -68,8 +70,12 @@ func New(opts Options) *Server {
|
||||
SslRequired: false,
|
||||
MaxPayload: MAX_PAYLOAD_SIZE,
|
||||
}
|
||||
// Check for Auth items
|
||||
if opts.Username != "" || opts.Authorization != "" {
|
||||
info.AuthRequired = true
|
||||
}
|
||||
s := &Server{
|
||||
info: inf,
|
||||
info: info,
|
||||
sl: sublist.New(),
|
||||
opts: opts,
|
||||
debug: opts.Debug,
|
||||
@@ -78,10 +84,13 @@ func New(opts Options) *Server {
|
||||
// Setup logging with flags
|
||||
s.LogInit()
|
||||
|
||||
/*
|
||||
if opts.Debug {
|
||||
b, _ := json.Marshal(opts)
|
||||
Debug(fmt.Sprintf("[%s]", b))
|
||||
}
|
||||
*/
|
||||
|
||||
|
||||
// Generate the info json
|
||||
b, err := json.Marshal(s.info)
|
||||
@@ -139,6 +148,9 @@ func (s *Server) createClient(conn net.Conn) *client {
|
||||
|
||||
s.sendInfo(c)
|
||||
go c.readLoop()
|
||||
if s.info.AuthRequired {
|
||||
c.atmr = time.AfterFunc(AUTH_TIMEOUT, func() { c.authViolation() })
|
||||
}
|
||||
return c
|
||||
}
|
||||
|
||||
@@ -146,3 +158,19 @@ func (s *Server) sendInfo(c *client) {
|
||||
// FIXME, err
|
||||
c.conn.Write(s.infoJson)
|
||||
}
|
||||
|
||||
// Check auth and return boolean indicating if client is ok
|
||||
func (s *Server) checkAuth(c *client) bool {
|
||||
if !s.info.AuthRequired {
|
||||
return true
|
||||
}
|
||||
// We require auth here, check the client
|
||||
// Authorization tokens trump username/password
|
||||
if s.opts.Authorization != "" {
|
||||
return s.opts.Authorization == c.opts.Authorization
|
||||
} else if s.opts.Username != c.opts.Username ||
|
||||
s.opts.Password != c.opts.Password {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
132
test/auth_test.go
Normal file
132
test/auth_test.go
Normal file
@@ -0,0 +1,132 @@
|
||||
// Copyright 2012 Apcera Inc. All rights reserved.
|
||||
|
||||
package test
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/apcera/gnatsd/server"
|
||||
)
|
||||
|
||||
const AUTH_PORT=10422
|
||||
|
||||
func doAuthConnect(t tLogger, c net.Conn, token, user, pass string) {
|
||||
cs := fmt.Sprintf("CONNECT {\"verbose\":true,\"auth_token\":\"%s\",\"user\":\"%s\",\"pass\":\"%s\"}\r\n", token, user, pass)
|
||||
sendProto(t, c, cs)
|
||||
}
|
||||
|
||||
func testInfoForAuth(t tLogger, infojs []byte) bool {
|
||||
var sinfo server.Info
|
||||
err := json.Unmarshal(infojs, &sinfo)
|
||||
if err != nil {
|
||||
t.Fatalf("Could not unmarshal INFO json: %v\n", err)
|
||||
}
|
||||
return sinfo.AuthRequired
|
||||
}
|
||||
|
||||
func expectAuthRequired(t tLogger, c net.Conn) {
|
||||
buf := expectResult(t, c, infoRe)
|
||||
infojs := infoRe.FindAllSubmatch(buf, 1)[0][1]
|
||||
if testInfoForAuth(t, infojs) != true {
|
||||
t.Fatalf("Expected server to require authorization: '%s'", infojs)
|
||||
}
|
||||
}
|
||||
|
||||
const AUTH_TOKEN = "_YZZ22_"
|
||||
|
||||
func TestStartupAuthToken(t *testing.T) {
|
||||
s = startServer(t, AUTH_PORT, fmt.Sprintf("--auth=%s", AUTH_TOKEN))
|
||||
}
|
||||
|
||||
func TestNoAuthClient(t *testing.T) {
|
||||
c := createClientConn(t, "localhost", AUTH_PORT)
|
||||
defer c.Close()
|
||||
expectAuthRequired(t, c)
|
||||
doAuthConnect(t, c, "", "", "")
|
||||
expectResult(t, c, errRe)
|
||||
}
|
||||
|
||||
func TestAuthClientBadToken(t *testing.T) {
|
||||
c := createClientConn(t, "localhost", AUTH_PORT)
|
||||
defer c.Close()
|
||||
expectAuthRequired(t, c)
|
||||
doAuthConnect(t, c, "ZZZ", "", "")
|
||||
expectResult(t, c, errRe)
|
||||
}
|
||||
|
||||
func TestAuthClientNoConnect(t *testing.T) {
|
||||
c := createClientConn(t, "localhost", AUTH_PORT)
|
||||
defer c.Close()
|
||||
expectAuthRequired(t, c)
|
||||
// This is timing dependent..
|
||||
time.Sleep(server.AUTH_TIMEOUT)
|
||||
expectResult(t, c, errRe)
|
||||
}
|
||||
|
||||
func TestAuthClientGoodConnect(t *testing.T) {
|
||||
c := createClientConn(t, "localhost", AUTH_PORT)
|
||||
defer c.Close()
|
||||
expectAuthRequired(t, c)
|
||||
doAuthConnect(t, c, AUTH_TOKEN, "", "")
|
||||
expectResult(t, c, okRe)
|
||||
}
|
||||
|
||||
func TestAuthClientFailOnEverythingElse(t *testing.T) {
|
||||
c := createClientConn(t, "localhost", AUTH_PORT)
|
||||
defer c.Close()
|
||||
expectAuthRequired(t, c)
|
||||
sendProto(t, c, "PUB foo 2\r\nok\r\n")
|
||||
expectResult(t, c, errRe)
|
||||
}
|
||||
|
||||
func TestStopServerAuthToken(t *testing.T) {
|
||||
s.stopServer()
|
||||
}
|
||||
|
||||
const AUTH_USER = "derek"
|
||||
const AUTH_PASS = "foobar"
|
||||
|
||||
// The username/password versions
|
||||
func TestStartupAuthPassword(t *testing.T) {
|
||||
s = startServer(t, AUTH_PORT, fmt.Sprintf("--user=%s --pass=%s", AUTH_USER, AUTH_PASS))
|
||||
}
|
||||
|
||||
func TestNoUserOrPasswordClient(t *testing.T) {
|
||||
c := createClientConn(t, "localhost", AUTH_PORT)
|
||||
defer c.Close()
|
||||
expectAuthRequired(t, c)
|
||||
doAuthConnect(t, c, "", "", "")
|
||||
expectResult(t, c, errRe)
|
||||
}
|
||||
|
||||
func TestBadUserClient(t *testing.T) {
|
||||
c := createClientConn(t, "localhost", AUTH_PORT)
|
||||
defer c.Close()
|
||||
expectAuthRequired(t, c)
|
||||
doAuthConnect(t, c, "", "derekzz", AUTH_PASS)
|
||||
expectResult(t, c, errRe)
|
||||
}
|
||||
|
||||
func TestBadPasswordClient(t *testing.T) {
|
||||
c := createClientConn(t, "localhost", AUTH_PORT)
|
||||
defer c.Close()
|
||||
expectAuthRequired(t, c)
|
||||
doAuthConnect(t, c, "", AUTH_USER, "ZZ")
|
||||
expectResult(t, c, errRe)
|
||||
}
|
||||
|
||||
func TestPasswordClientGoodConnect(t *testing.T) {
|
||||
c := createClientConn(t, "localhost", AUTH_PORT)
|
||||
defer c.Close()
|
||||
expectAuthRequired(t, c)
|
||||
doAuthConnect(t, c, "", AUTH_USER, AUTH_PASS)
|
||||
expectResult(t, c, okRe)
|
||||
}
|
||||
|
||||
func TestStopServerAuthPassword(t *testing.T) {
|
||||
s.stopServer()
|
||||
}
|
||||
@@ -47,27 +47,27 @@ func benchPub(b *testing.B, subject, payload string) {
|
||||
s.stopServer()
|
||||
}
|
||||
|
||||
func BenchmarkPubNoPayload(b *testing.B) {
|
||||
func Benchmark____PubNoPayload(b *testing.B) {
|
||||
benchPub(b, "a", "")
|
||||
}
|
||||
|
||||
func BenchmarkPubMinPayload(b *testing.B) {
|
||||
func Benchmark___PubMinPayload(b *testing.B) {
|
||||
benchPub(b, "a", "b")
|
||||
}
|
||||
|
||||
func BenchmarkPubTinyPayload(b *testing.B) {
|
||||
func Benchmark__PubTinyPayload(b *testing.B) {
|
||||
benchPub(b, "foo", "ok")
|
||||
}
|
||||
|
||||
func BenchmarkPubSmallPayload(b *testing.B) {
|
||||
func Benchmark_PubSmallPayload(b *testing.B) {
|
||||
benchPub(b, "foo", "hello world")
|
||||
}
|
||||
|
||||
func BenchmarkPubMedPayload(b *testing.B) {
|
||||
func Benchmark___PubMedPayload(b *testing.B) {
|
||||
benchPub(b, "foo", "The quick brown fox jumps over the lazy dog")
|
||||
}
|
||||
|
||||
func BenchmarkPubLrgPayload(b *testing.B) {
|
||||
func Benchmark_PubLargePayload(b *testing.B) {
|
||||
b.StopTimer()
|
||||
var p string
|
||||
for i := 0 ; i < 200 ; i++ {
|
||||
@@ -98,7 +98,7 @@ func drainConnection(b *testing.B, c net.Conn, ch chan bool, expected int) {
|
||||
ch <- true
|
||||
}
|
||||
|
||||
func BenchmarkPubSub(b *testing.B) {
|
||||
func Benchmark__________PubSub(b *testing.B) {
|
||||
b.StopTimer()
|
||||
s = startServer(b, PERF_PORT, "")
|
||||
c := createClientConn(b, "localhost", PERF_PORT)
|
||||
@@ -130,7 +130,7 @@ func BenchmarkPubSub(b *testing.B) {
|
||||
s.stopServer()
|
||||
}
|
||||
|
||||
func BenchmarkPubSubMultipleConnections(b *testing.B) {
|
||||
func Benchmark__PubSubTwoConns(b *testing.B) {
|
||||
b.StopTimer()
|
||||
s = startServer(b, PERF_PORT, "")
|
||||
c := createClientConn(b, "localhost", PERF_PORT)
|
||||
@@ -165,7 +165,7 @@ func BenchmarkPubSubMultipleConnections(b *testing.B) {
|
||||
s.stopServer()
|
||||
}
|
||||
|
||||
func BenchmarkPubTwoQueueSub(b *testing.B) {
|
||||
func Benchmark__PubTwoQueueSub(b *testing.B) {
|
||||
b.StopTimer()
|
||||
s = startServer(b, PERF_PORT, "")
|
||||
c := createClientConn(b, "localhost", PERF_PORT)
|
||||
@@ -198,7 +198,7 @@ func BenchmarkPubTwoQueueSub(b *testing.B) {
|
||||
s.stopServer()
|
||||
}
|
||||
|
||||
func BenchmarkPubFourQueueSub(b *testing.B) {
|
||||
func Benchmark_PubFourQueueSub(b *testing.B) {
|
||||
b.StopTimer()
|
||||
s = startServer(b, PERF_PORT, "")
|
||||
c := createClientConn(b, "localhost", PERF_PORT)
|
||||
|
||||
@@ -12,9 +12,11 @@ func TestStartupPedantic(t *testing.T) {
|
||||
|
||||
func TestPedanticSub(t *testing.T) {
|
||||
c := createClientConn(t, "localhost", PROTO_TEST_PORT)
|
||||
doConnect(t, c, true, true, false)
|
||||
defer c.Close()
|
||||
send := sendCommand(t, c)
|
||||
expect := expectCommand(t, c)
|
||||
doConnect(t, c, true, true, false)
|
||||
expect(okRe)
|
||||
|
||||
// Ping should still be same
|
||||
send("PING\r\n")
|
||||
@@ -53,9 +55,15 @@ func TestPedanticSub(t *testing.T) {
|
||||
|
||||
func TestPedanticPub(t *testing.T) {
|
||||
c := createClientConn(t, "localhost", PROTO_TEST_PORT)
|
||||
doConnect(t, c, true, true, false)
|
||||
defer c.Close()
|
||||
send := sendCommand(t, c)
|
||||
expect := expectCommand(t, c)
|
||||
doConnect(t, c, true, true, false)
|
||||
expect(okRe)
|
||||
|
||||
// Ping should still be same
|
||||
send("PING\r\n")
|
||||
expect(pongRe)
|
||||
|
||||
// Test malformed subjects for PUB
|
||||
// PUB subjects can not have wildcards
|
||||
|
||||
@@ -3,14 +3,8 @@
|
||||
package test
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net"
|
||||
"regexp"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/apcera/gnatsd/server"
|
||||
)
|
||||
|
||||
var s *natsServer
|
||||
@@ -21,126 +15,6 @@ func TestStartup(t *testing.T) {
|
||||
s = startServer(t, PROTO_TEST_PORT, "")
|
||||
}
|
||||
|
||||
type sendFun func(string)
|
||||
type expectFun func(*regexp.Regexp) []byte
|
||||
|
||||
// Closure version for easier reading
|
||||
func sendCommand(t tLogger, c net.Conn) sendFun {
|
||||
return func(op string) {
|
||||
sendProto(t, c, op)
|
||||
}
|
||||
}
|
||||
|
||||
// Closure version for easier reading
|
||||
func expectCommand(t tLogger, c net.Conn) expectFun {
|
||||
return func(re *regexp.Regexp) []byte {
|
||||
return expectResult(t, c, re)
|
||||
}
|
||||
}
|
||||
|
||||
// Send the protocol command to the server.
|
||||
func sendProto(t tLogger, c net.Conn, op string) {
|
||||
n, err := c.Write([]byte(op))
|
||||
if err != nil {
|
||||
t.Fatalf("Error writing command to conn: %v\n", err)
|
||||
}
|
||||
if n != len(op) {
|
||||
t.Fatalf("Partial write: %d vs %d\n", n, len(op))
|
||||
}
|
||||
}
|
||||
|
||||
// Reuse expect buffer
|
||||
var expBuf = make([]byte, 32768)
|
||||
|
||||
// Test result from server against regexp
|
||||
func expectResult(t tLogger, c net.Conn, re *regexp.Regexp) []byte {
|
||||
// Wait for commands to be processed and results queued for read
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
c.SetReadDeadline(time.Now().Add(500 * time.Millisecond))
|
||||
defer c.SetReadDeadline(time.Time{})
|
||||
|
||||
n, err := c.Read(expBuf)
|
||||
if err != nil {
|
||||
t.Fatalf("Error reading from conn: %v\n", err)
|
||||
}
|
||||
buf := expBuf[:n]
|
||||
if !re.Match(buf) {
|
||||
t.Fatalf("Response did not match expected: '%s' vs '%s'\n", buf, re)
|
||||
}
|
||||
return buf
|
||||
}
|
||||
|
||||
// This will check that we got what we expected.
|
||||
func checkMsg(t tLogger, m [][]byte, subject, sid, reply, len, msg string) {
|
||||
if string(m[SUB_INDEX]) != subject {
|
||||
t.Fatalf("Did not get correct subject: expected '%s' got '%s'\n", subject, m[SUB_INDEX])
|
||||
}
|
||||
if string(m[SID_INDEX]) != sid {
|
||||
t.Fatalf("Did not get correct sid: exepected '%s' got '%s'\n", sid, m[SID_INDEX])
|
||||
}
|
||||
if string(m[REPLY_INDEX]) != reply {
|
||||
t.Fatalf("Did not get correct reply: exepected '%s' got '%s'\n", reply, m[REPLY_INDEX])
|
||||
}
|
||||
if string(m[LEN_INDEX]) != len {
|
||||
t.Fatalf("Did not get correct msg length: expected '%s' got '%s'\n", len, m[LEN_INDEX])
|
||||
}
|
||||
if string(m[MSG_INDEX]) != msg {
|
||||
t.Fatalf("Did not get correct msg: expected '%s' got '%s'\n", msg, m[MSG_INDEX])
|
||||
}
|
||||
}
|
||||
|
||||
// Closure for expectMsgs
|
||||
func expectMsgsCommand(t tLogger, ef expectFun) func(int) [][][]byte {
|
||||
return func(expected int) [][][]byte {
|
||||
buf := ef(msgRe)
|
||||
matches := msgRe.FindAllSubmatch(buf, -1)
|
||||
if len(matches) != expected {
|
||||
t.Fatalf("Did not get correct # msgs: %d vs %d\n", len(matches), expected)
|
||||
}
|
||||
return matches
|
||||
}
|
||||
}
|
||||
|
||||
var (
|
||||
infoRe = regexp.MustCompile(`\AINFO\s+([^\r\n]+)\r\n`)
|
||||
pongRe = regexp.MustCompile(`\APONG\r\n`)
|
||||
msgRe = regexp.MustCompile(`(?:(?:MSG\s+([^\s]+)\s+([^\s]+)\s+(([^\s]+)[^\S\r\n]+)?(\d+)\r\n([^\\r\\n]*?)\r\n)+?)`)
|
||||
okRe = regexp.MustCompile(`\A\+OK\r\n`)
|
||||
errRe = regexp.MustCompile(`\A\-ERR\s+([^\r\n]+)\r\n`)
|
||||
)
|
||||
|
||||
const (
|
||||
SUB_INDEX = 1
|
||||
SID_INDEX = 2
|
||||
REPLY_INDEX = 4
|
||||
LEN_INDEX = 5
|
||||
MSG_INDEX = 6
|
||||
)
|
||||
|
||||
func doConnect(t tLogger, c net.Conn, verbose, pedantic, ssl bool) {
|
||||
cs := fmt.Sprintf("CONNECT {\"verbose\":%v,\"pedantic\":%v,\"ssl_required\":%v}\r\n", verbose, pedantic, ssl)
|
||||
sendProto(t, c, cs)
|
||||
buf := expectResult(t, c, infoRe)
|
||||
js := infoRe.FindAllSubmatch(buf, 1)[0][1]
|
||||
var sinfo server.Info
|
||||
err := json.Unmarshal(js, &sinfo)
|
||||
if err != nil {
|
||||
t.Fatalf("Could not unmarshal INFO json: %v\n", err)
|
||||
}
|
||||
}
|
||||
|
||||
func doDefaultConnect(t tLogger, c net.Conn) {
|
||||
// Basic Connect
|
||||
doConnect(t, c, false, false, false)
|
||||
}
|
||||
|
||||
func setupConn(t tLogger, c net.Conn) (sendFun, expectFun) {
|
||||
doDefaultConnect(t, c)
|
||||
send := sendCommand(t, c)
|
||||
expect := expectCommand(t, c)
|
||||
return send, expect
|
||||
}
|
||||
|
||||
func TestProtoBasics(t *testing.T) {
|
||||
c := createClientConn(t, "localhost", PROTO_TEST_PORT)
|
||||
send, expect := setupConn(t, c)
|
||||
@@ -191,6 +65,9 @@ func TestQueueSub(t *testing.T) {
|
||||
for i := 0; i < sent; i++ {
|
||||
send("PUB foo 2\r\nok\r\n")
|
||||
}
|
||||
// Wait for responses
|
||||
time.Sleep(250*time.Millisecond)
|
||||
|
||||
matches := expectMsgs(sent)
|
||||
sids := make(map[string]int)
|
||||
for _, m := range matches {
|
||||
@@ -221,6 +98,9 @@ func TestMultipleQueueSub(t *testing.T) {
|
||||
for i := 0; i < sent; i++ {
|
||||
send("PUB foo 2\r\nok\r\n")
|
||||
}
|
||||
// Wait for responses
|
||||
time.Sleep(250*time.Millisecond)
|
||||
|
||||
matches := expectMsgs(sent * 2)
|
||||
sids := make(map[string]int)
|
||||
for _, m := range matches {
|
||||
|
||||
125
test/test.go
125
test/test.go
@@ -3,11 +3,16 @@
|
||||
package test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net"
|
||||
"os/exec"
|
||||
"regexp"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/apcera/gnatsd/server"
|
||||
)
|
||||
|
||||
const natsServerExe = "../gnatsd"
|
||||
@@ -69,4 +74,124 @@ func createClientConn(t tLogger, host string, port int) net.Conn {
|
||||
return c
|
||||
}
|
||||
|
||||
func doConnect(t tLogger, c net.Conn, verbose, pedantic, ssl bool) {
|
||||
cs := fmt.Sprintf("CONNECT {\"verbose\":%v,\"pedantic\":%v,\"ssl_required\":%v}\r\n", verbose, pedantic, ssl)
|
||||
sendProto(t, c, cs)
|
||||
buf := expectResult(t, c, infoRe)
|
||||
js := infoRe.FindAllSubmatch(buf, 1)[0][1]
|
||||
var sinfo server.Info
|
||||
err := json.Unmarshal(js, &sinfo)
|
||||
if err != nil {
|
||||
t.Fatalf("Could not unmarshal INFO json: %v\n", err)
|
||||
}
|
||||
}
|
||||
|
||||
func doDefaultConnect(t tLogger, c net.Conn) {
|
||||
// Basic Connect
|
||||
doConnect(t, c, false, false, false)
|
||||
}
|
||||
|
||||
func setupConn(t tLogger, c net.Conn) (sendFun, expectFun) {
|
||||
doDefaultConnect(t, c)
|
||||
send := sendCommand(t, c)
|
||||
expect := expectCommand(t, c)
|
||||
return send, expect
|
||||
}
|
||||
|
||||
type sendFun func(string)
|
||||
type expectFun func(*regexp.Regexp) []byte
|
||||
|
||||
// Closure version for easier reading
|
||||
func sendCommand(t tLogger, c net.Conn) sendFun {
|
||||
return func(op string) {
|
||||
sendProto(t, c, op)
|
||||
}
|
||||
}
|
||||
|
||||
// Closure version for easier reading
|
||||
func expectCommand(t tLogger, c net.Conn) expectFun {
|
||||
return func(re *regexp.Regexp) []byte {
|
||||
return expectResult(t, c, re)
|
||||
}
|
||||
}
|
||||
|
||||
// Send the protocol command to the server.
|
||||
func sendProto(t tLogger, c net.Conn, op string) {
|
||||
n, err := c.Write([]byte(op))
|
||||
if err != nil {
|
||||
t.Fatalf("Error writing command to conn: %v\n", err)
|
||||
}
|
||||
if n != len(op) {
|
||||
t.Fatalf("Partial write: %d vs %d\n", n, len(op))
|
||||
}
|
||||
}
|
||||
|
||||
var (
|
||||
infoRe = regexp.MustCompile(`\AINFO\s+([^\r\n]+)\r\n`)
|
||||
pongRe = regexp.MustCompile(`\APONG\r\n`)
|
||||
msgRe = regexp.MustCompile(`(?:(?:MSG\s+([^\s]+)\s+([^\s]+)\s+(([^\s]+)[^\S\r\n]+)?(\d+)\r\n([^\\r\\n]*?)\r\n)+?)`)
|
||||
okRe = regexp.MustCompile(`\A\+OK\r\n`)
|
||||
errRe = regexp.MustCompile(`\A\-ERR\s+([^\r\n]+)\r\n`)
|
||||
)
|
||||
|
||||
const (
|
||||
SUB_INDEX = 1
|
||||
SID_INDEX = 2
|
||||
REPLY_INDEX = 4
|
||||
LEN_INDEX = 5
|
||||
MSG_INDEX = 6
|
||||
)
|
||||
|
||||
// Reuse expect buffer
|
||||
var expBuf = make([]byte, 32768)
|
||||
|
||||
// Test result from server against regexp
|
||||
func expectResult(t tLogger, c net.Conn, re *regexp.Regexp) []byte {
|
||||
// Wait for commands to be processed and results queued for read
|
||||
// time.Sleep(50 * time.Millisecond)
|
||||
c.SetReadDeadline(time.Now().Add(500 * time.Millisecond))
|
||||
defer c.SetReadDeadline(time.Time{})
|
||||
|
||||
n, err := c.Read(expBuf)
|
||||
if n <= 0 && err != nil {
|
||||
t.Fatalf("Error reading from conn: %v\n", err)
|
||||
}
|
||||
buf := expBuf[:n]
|
||||
|
||||
if !re.Match(buf) {
|
||||
buf = bytes.Replace(buf, []byte("\r\n"), []byte("\\r\\n"), -1)
|
||||
t.Fatalf("Response did not match expected: \n\tReceived:'%s'\n\tExpected:'%s'\n", buf, re)
|
||||
}
|
||||
return buf
|
||||
}
|
||||
|
||||
// This will check that we got what we expected.
|
||||
func checkMsg(t tLogger, m [][]byte, subject, sid, reply, len, msg string) {
|
||||
if string(m[SUB_INDEX]) != subject {
|
||||
t.Fatalf("Did not get correct subject: expected '%s' got '%s'\n", subject, m[SUB_INDEX])
|
||||
}
|
||||
if string(m[SID_INDEX]) != sid {
|
||||
t.Fatalf("Did not get correct sid: exepected '%s' got '%s'\n", sid, m[SID_INDEX])
|
||||
}
|
||||
if string(m[REPLY_INDEX]) != reply {
|
||||
t.Fatalf("Did not get correct reply: exepected '%s' got '%s'\n", reply, m[REPLY_INDEX])
|
||||
}
|
||||
if string(m[LEN_INDEX]) != len {
|
||||
t.Fatalf("Did not get correct msg length: expected '%s' got '%s'\n", len, m[LEN_INDEX])
|
||||
}
|
||||
if string(m[MSG_INDEX]) != msg {
|
||||
t.Fatalf("Did not get correct msg: expected '%s' got '%s'\n", msg, m[MSG_INDEX])
|
||||
}
|
||||
}
|
||||
|
||||
// Closure for expectMsgs
|
||||
func expectMsgsCommand(t tLogger, ef expectFun) func(int) [][][]byte {
|
||||
return func(expected int) [][][]byte {
|
||||
buf := ef(msgRe)
|
||||
matches := msgRe.FindAllSubmatch(buf, -1)
|
||||
if len(matches) != expected {
|
||||
t.Fatalf("Did not get correct # msgs: %d vs %d\n", len(matches), expected)
|
||||
}
|
||||
return matches
|
||||
}
|
||||
}
|
||||
|
||||
@@ -13,9 +13,13 @@ func TestStartupVerbose(t *testing.T) {
|
||||
func TestVerbosePing(t *testing.T) {
|
||||
c := createClientConn(t, "localhost", PROTO_TEST_PORT)
|
||||
doConnect(t, c, true, false, false)
|
||||
defer c.Close()
|
||||
|
||||
send := sendCommand(t, c)
|
||||
expect := expectCommand(t, c)
|
||||
|
||||
expect(okRe)
|
||||
|
||||
// Ping should still be same
|
||||
send("PING\r\n")
|
||||
expect(pongRe)
|
||||
@@ -24,9 +28,13 @@ func TestVerbosePing(t *testing.T) {
|
||||
func TestVerboseConnect(t *testing.T) {
|
||||
c := createClientConn(t, "localhost", PROTO_TEST_PORT)
|
||||
doConnect(t, c, true, false, false)
|
||||
defer c.Close()
|
||||
|
||||
send := sendCommand(t, c)
|
||||
expect := expectCommand(t, c)
|
||||
|
||||
expect(okRe)
|
||||
|
||||
// Connect
|
||||
send("CONNECT {\"verbose\":true,\"pedantic\":true,\"ssl_required\":false}\r\n")
|
||||
expect(okRe)
|
||||
@@ -38,6 +46,8 @@ func TestVerbosePubSub(t *testing.T) {
|
||||
send := sendCommand(t, c)
|
||||
expect := expectCommand(t, c)
|
||||
|
||||
expect(okRe)
|
||||
|
||||
// Pub
|
||||
send("PUB foo 2\r\nok\r\n")
|
||||
expect(okRe)
|
||||
|
||||
Reference in New Issue
Block a user