diff --git a/server/client.go b/server/client.go index 9eeb9bb7..b9da9c68 100644 --- a/server/client.go +++ b/server/client.go @@ -310,23 +310,28 @@ func (c *client) processConnect(arg []byte) error { c.mu.Lock() c.clearAuthTimer() c.last = time.Now() + typ := c.typ + r := c.route + srv := c.srv c.mu.Unlock() if err := json.Unmarshal(arg, &c.opts); err != nil { return err } - if c.srv != nil { + if srv != nil { // Check for Auth - if ok := c.srv.checkAuth(c); !ok { + if ok := srv.checkAuth(c); !ok { c.authViolation() return ErrAuthorization } } // Grab connection name of remote route. - if c.typ == ROUTER && c.route != nil { + if typ == ROUTER && r != nil { + c.mu.Lock() c.route.remoteID = c.opts.Name + c.mu.Unlock() } if c.opts.Verbose { diff --git a/server/server.go b/server/server.go index a0b648fa..1cde2790 100644 --- a/server/server.go +++ b/server/server.go @@ -690,10 +690,14 @@ func (s *Server) checkAuth(c *client) bool { // Remove a client or route from our internal accounting. func (s *Server) removeClient(c *client) { + var rID string c.mu.Lock() cid := c.cid typ := c.typ r := c.route + if r != nil { + rID = r.remoteID + } c.mu.Unlock() s.mu.Lock() @@ -703,10 +707,10 @@ func (s *Server) removeClient(c *client) { case ROUTER: delete(s.routes, cid) if r != nil { - rc, ok := s.remotes[r.remoteID] + rc, ok := s.remotes[rID] // Only delete it if it is us.. if ok && c == rc { - delete(s.remotes, r.remoteID) + delete(s.remotes, rID) } } } diff --git a/test/routes_test.go b/test/routes_test.go index d6901a26..683954e1 100644 --- a/test/routes_test.go +++ b/test/routes_test.go @@ -9,6 +9,7 @@ import ( "net" "runtime" "strings" + "sync" "testing" "time" @@ -611,3 +612,47 @@ func TestAutoUnsubPropagation(t *testing.T) { routeExpect(unsubnomaxRe) } + +type ignoreLogger struct { +} + +func (l *ignoreLogger) Fatalf(f string, args ...interface{}) { +} +func (l *ignoreLogger) Errorf(f string, args ...interface{}) { +} + +func TestRouteConnectOnShutdownRace(t *testing.T) { + s, opts := runRouteServer(t) + defer s.Shutdown() + + l := &ignoreLogger{} + + var wg sync.WaitGroup + + cQuit := make(chan bool, 1) + + wg.Add(1) + + go func() { + defer wg.Done() + for { + route := createRouteConn(l, opts.ClusterHost, opts.ClusterPort) + if route != nil { + setupRouteEx(l, route, opts, "ROUTE:4222") + route.Close() + } + select { + case <-cQuit: + return + default: + } + } + }() + + time.Sleep(5 * time.Millisecond) + s.Shutdown() + + cQuit <- true + + wg.Wait() +}