From 33fdc5274d6e194ea1ca24a576858575451b6ba5 Mon Sep 17 00:00:00 2001 From: Derek Collison Date: Wed, 12 Jun 2013 00:12:52 -0700 Subject: [PATCH] Fixed data races --- server/client.go | 22 +++++++++++++++------- server/log.go | 25 +++++++++++++------------ server/parser.go | 2 +- server/server.go | 40 +++++++++++++++++++++++++++++++++------- 4 files changed, 62 insertions(+), 27 deletions(-) diff --git a/server/client.go b/server/client.go index 378897ea..539dcd56 100644 --- a/server/client.go +++ b/server/client.go @@ -1,4 +1,4 @@ -// Copyright 2012 Apcera Inc. All rights reserved. +// Copyright 2012-2013 Apcera Inc. All rights reserved. package server @@ -123,7 +123,7 @@ func (c *client) traceMsg(msg []byte) { } func (c *client) traceOp(op string, arg []byte) { - if !trace { + if trace == 0 { return } opa := []interface{}{fmt.Sprintf("%s OP", op)} @@ -203,7 +203,7 @@ func (c *client) processPong() { const argsLenMax = 3 func (c *client) processPub(arg []byte) error { - if trace { + if trace > 0 { c.traceOp("PUB", arg) } @@ -444,7 +444,7 @@ func (c *client) processMsg(msg []byte) { atomic.AddInt64(&c.srv.inBytes, int64(len(msg))) } - if trace { + if trace > 0 { c.traceMsg(msg) } if c.srv == nil { @@ -554,6 +554,7 @@ func (c *client) clearPingTimer() { c.ptmr = nil } +// Lock should be held func (c *client) setAuthTimer(d time.Duration) { c.atmr = time.AfterFunc(d, func() { c.authViolation() }) } @@ -567,6 +568,12 @@ func (c *client) clearAuthTimer() { c.atmr = nil } +func (c *client) isAuthTimerSet() bool { + c.mu.Lock() + defer c.mu.Unlock() + return c.atmr != nil +} + // Lock should be held func (c *client) clearConnection() { if c.conn == nil { @@ -590,16 +597,17 @@ func (c *client) closeConnection() { c.clearPingTimer() c.clearConnection() subs := c.subs.All() + srv := c.srv c.mu.Unlock() - if c.srv != nil { + if srv != nil { // Unregister - c.srv.removeClient(c) + srv.removeClient(c) // Remove subscriptions. for _, s := range subs { if sub, ok := s.(*subscription); ok { - c.srv.sl.Remove(sub.subject, sub) + srv.sl.Remove(sub.subject, sub) } } } diff --git a/server/log.go b/server/log.go index 6d0198da..9f0bc406 100644 --- a/server/log.go +++ b/server/log.go @@ -1,4 +1,4 @@ -// Copyright 2012 Apcera Inc. All rights reserved. +// Copyright 2012-2013 Apcera Inc. All rights reserved. package server @@ -6,13 +6,14 @@ import ( "fmt" "log" "strings" + "sync/atomic" ) // logging functionality, compatible with original nats-server -var trace bool -var debug bool -var nolog bool +var trace int32 +var debug int32 +var nolog int32 func LogSetup() { log.SetFlags(0) @@ -23,15 +24,15 @@ func (s *Server) LogInit() { log.SetFlags(log.LstdFlags) } if s.opts.NoLog { - nolog = true + atomic.StoreInt32(&nolog, 1) } if s.opts.Debug { Log(s.opts) - debug = true + atomic.StoreInt32(&debug, 1) Log("DEBUG is on") } if s.opts.Trace { - trace = true + atomic.StoreInt32(&trace, 1) Log("TRACE is on") } } @@ -59,7 +60,7 @@ func logStr(v []interface{}) string { } func Log(v ...interface{}) { - if !nolog { + if nolog == 0 { log.Print(logStr(v)) } } @@ -77,25 +78,25 @@ func Fatalf(format string, v ...interface{}) { } func Debug(v ...interface{}) { - if debug { + if debug > 0 { Log(v...) } } func Debugf(format string, v ...interface{}) { - if debug { + if debug > 0 { Debug(fmt.Sprintf(format, v...)) } } func Trace(v ...interface{}) { - if trace { + if trace > 0 { Log(v...) } } func Tracef(format string, v ...interface{}) { - if trace { + if trace > 0 { Trace(fmt.Sprintf(format, v...)) } } diff --git a/server/parser.go b/server/parser.go index 6d65fe1e..434a3adb 100644 --- a/server/parser.go +++ b/server/parser.go @@ -66,7 +66,7 @@ func (c *client) parse(buf []byte) error { for i, b = range buf { switch c.state { case OP_START: - if c.atmr != nil && b != 'C' && b != 'c' { + if c.isAuthTimerSet() && b != 'C' && b != 'c' { goto authErr } switch b { diff --git a/server/server.go b/server/server.go index 921d590f..271c1352 100644 --- a/server/server.go +++ b/server/server.go @@ -1,4 +1,4 @@ -// Copyright 2012 Apcera Inc. All rights reserved. +// Copyright 2012-2013 Apcera Inc. All rights reserved. package server @@ -76,6 +76,10 @@ func New(opts *Options) *Server { done: make(chan bool, 1), start: time.Now(), } + + s.mu.Lock() + defer s.mu.Unlock() + // Setup logging with flags s.LogInit() @@ -117,20 +121,37 @@ func (s *Server) handleSignals() { }() } +// Protected check on running state +func (s *Server) isRunning() bool { + s.mu.Lock() + defer s.mu.Unlock() + return s.running +} + // Shutdown will shutdown the server instance by kicking out the AcceptLoop // and closing all associated clients. func (s *Server) Shutdown() { + s.mu.Lock() s.running = false - // Close client connections - // FIXME(dlc) lock? will call back into remove.. - for _, c := range s.clients { - c.closeConnection() + + // Copy off the clients + clients := make(map[uint64]*client) + for i, c := range s.clients { + clients[i] = c } + // Kick AcceptLoop() if s.listener != nil { s.listener.Close() s.listener = nil } + s.mu.Unlock() + + // Close client connections + for _, c := range clients { + c.closeConnection() + } + <-s.done } @@ -147,11 +168,13 @@ func (s *Server) AcceptLoop() { Logf("nats-server is ready") // Setup state that can enable shutdown + s.mu.Lock() s.listener = l s.running = true + s.mu.Unlock() - for s.running { - conn, err := s.listener.Accept() + for s.isRunning() { + conn, err := l.Accept() if err != nil { if ne, ok := err.(net.Error); ok && ne.Temporary() { Logf("Accept error: %v", err) @@ -186,6 +209,8 @@ func (s *Server) StartHTTPMonitoring() { func (s *Server) createClient(conn net.Conn) *client { c := &client{srv: s, conn: conn, opts: defaultOpts} + + c.mu.Lock() c.cid = atomic.AddUint64(&s.gcid, 1) c.bw = bufio.NewWriterSize(c.conn, defaultBufSize) @@ -210,6 +235,7 @@ func (s *Server) createClient(conn net.Conn) *client { } // Set the Ping timer c.setPingTimer() + c.mu.Unlock() // Register with the server. s.mu.Lock()