Merge pull request #250 from nats-io/wait_for_routes_go_routines

Ensure Shutdown() waits for outstanding go routines
This commit is contained in:
Derek Collison
2016-04-22 03:29:03 -07:00
5 changed files with 125 additions and 39 deletions

View File

@@ -113,7 +113,7 @@ func init() {
}
// Lock should be held
func (c *client) initClient(tlsConn bool) {
func (c *client) initClient() {
s := c.srv
c.cid = atomic.AddUint64(&s.gcid, 1)
c.bw = bufio.NewWriterSize(c.nc, startBufSize)
@@ -143,14 +143,6 @@ func (c *client) initClient(tlsConn bool) {
case ROUTER:
c.ncs = fmt.Sprintf("%s - rid:%d", conn, c.cid)
}
if !tlsConn {
// Set the Ping timer
c.setPingTimer()
// Spin up the read loop.
go c.readLoop()
}
}
func (c *client) readLoop() {
@@ -159,6 +151,7 @@ func (c *client) readLoop() {
c.mu.Lock()
nc := c.nc
s := c.srv
defer s.grWG.Done()
c.mu.Unlock()
if nc == nil {
@@ -1054,6 +1047,12 @@ func (c *client) closeConnection() {
srv.mu.Lock()
defer srv.mu.Unlock()
// It is possible that the server is being shutdown.
// If so, don't try to reconnect
if !srv.running {
return
}
if rid != "" && srv.remotes[rid] != nil {
Debugf("Not attempting reconnect for solicited route, already connected to \"%s\"", rid)
return
@@ -1062,7 +1061,9 @@ func (c *client) closeConnection() {
return
} else if rtype != Implicit || retryImplicit {
Debugf("Attempting reconnect for solicited route \"%s\"", rurl)
go srv.reConnectToRoute(rurl, rtype)
// Keep track of this go-routine so we can wait for it on
// server shutdown.
srv.startGoRoutine(func() { srv.reConnectToRoute(rurl, rtype) })
}
}
}

View File

@@ -191,7 +191,7 @@ func (s *Server) processImplicitRoute(info *Info) {
if info.AuthRequired {
r.User = url.UserPassword(s.opts.ClusterUsername, s.opts.ClusterPassword)
}
go s.connectToRoute(r, false)
s.startGoRoutine(func() { s.connectToRoute(r, false) })
}
// hasThisRouteConfigured returns true if info.Host:info.Port is present
@@ -277,7 +277,7 @@ func (s *Server) createRoute(conn net.Conn, rURL *url.URL) *client {
c.mu.Lock()
// Initialize
c.initClient(tlsRequired)
c.initClient()
c.Debugf("Route connection created")
@@ -332,17 +332,29 @@ func (s *Server) createRoute(conn net.Conn, rURL *url.URL) *client {
// Rewrap bw
c.bw = bufio.NewWriterSize(c.nc, startBufSize)
}
// Do final client initialization
// Do final client initialization
// Set the Ping timer
c.setPingTimer()
// Set the Ping timer
c.setPingTimer()
// Spin up the read loop.
go c.readLoop()
// For routes, the "client" is added to s.routes only when processing
// the INFO protocol, that is much later.
// In the meantime, if the server shuts down, there would be no reference
// to the client (connection) to be closed, leaving this readLoop
// uninterrupted, causing Shutdown() to wait indefinitely.
// We need to store the client in a special map, under a special lock.
s.grMu.Lock()
s.grTmpClients[c.cid] = c
s.grMu.Unlock()
// Spin up the read loop.
s.startGoRoutine(func() { c.readLoop() })
if tlsRequired {
c.Debugf("TLS handshake complete")
cs := conn.ConnectionState()
cs := c.nc.(*tls.Conn).ConnectionState()
c.Debugf("TLS version %s, cipher suite %s", tlsVersion(cs.Version), tlsCipher(cs.CipherSuite))
}
@@ -436,6 +448,11 @@ func (s *Server) addRoute(c *client, info *Info) (bool, bool) {
}
remote, exists := s.remotes[id]
if !exists {
// Remove from the temporary map
s.grMu.Lock()
delete(s.grTmpClients, c.cid)
s.grMu.Unlock()
s.routes[c.cid] = c
s.remotes[id] = c
@@ -550,7 +567,10 @@ func (s *Server) routeAcceptLoop(ch chan struct{}) {
continue
}
tmpDelay = ACCEPT_MIN_SLEEP
go s.createRoute(conn, nil)
s.startGoRoutine(func() {
s.createRoute(conn, nil)
s.grWG.Done()
})
}
Debugf("Router accept loop exiting..")
s.done <- true
@@ -598,6 +618,7 @@ func (s *Server) reConnectToRoute(rURL *url.URL, rtype RouteType) {
}
func (s *Server) connectToRoute(rURL *url.URL, tryForEver bool) {
defer s.grWG.Done()
for s.isRunning() && rURL != nil {
Debugf("Trying to connect to route on %s", rURL.Host)
conn, err := net.DialTimeout("tcp", rURL.Host, DEFAULT_ROUTE_DIAL)
@@ -628,7 +649,8 @@ func (c *client) isSolicitedRoute() bool {
func (s *Server) solicitRoutes() {
for _, r := range s.opts.Routes {
go s.connectToRoute(r, true)
route := r
s.startGoRoutine(func() { s.connectToRoute(route, true) })
}
}

View File

@@ -65,6 +65,10 @@ type Server struct {
routeInfo Info
routeInfoJSON []byte
rcQuit chan bool
grMu sync.Mutex
grTmpClients map[uint64]*client
grRunning bool
grWG sync.WaitGroup // to wait on various go routines
}
// Make sure all are 64bits for atomic use
@@ -113,6 +117,10 @@ func New(opts *Options) *Server {
// For tracking clients
s.clients = make(map[uint64]*client)
// For tracking connections that are not yet registered
// in s.routes, but for which readLoop has started.
s.grTmpClients = make(map[uint64]*client)
// For tracking routes and their remote ids
s.routes = make(map[uint64]*client)
s.remotes = make(map[string]*client)
@@ -205,6 +213,9 @@ func (s *Server) Start() {
Debugf("Go build version %s", s.info.GoVersion)
s.running = true
s.grMu.Lock()
s.grRunning = true
s.grMu.Unlock()
// Log the pid to a file
if s.opts.PidFile != _EMPTY_ {
@@ -251,6 +262,9 @@ func (s *Server) Shutdown() {
}
s.running = false
s.grMu.Lock()
s.grRunning = false
s.grMu.Unlock()
conns := make(map[uint64]*client)
@@ -258,6 +272,13 @@ func (s *Server) Shutdown() {
for i, c := range s.clients {
conns[i] = c
}
// Copy off the connections that are not yet registered
// in s.routes, but for which the readLoop has started
s.grMu.Lock()
for i, c := range s.grTmpClients {
conns[i] = c
}
s.grMu.Unlock()
// Copy off the routes
for i, r := range s.routes {
conns[i] = r
@@ -302,6 +323,9 @@ func (s *Server) Shutdown() {
<-s.done
doneExpected--
}
// Wait for go routines to be done.
s.grWG.Wait()
}
// AcceptLoop is exported for easier testing.
@@ -365,7 +389,10 @@ func (s *Server) AcceptLoop() {
continue
}
tmpDelay = ACCEPT_MIN_SLEEP
go s.createClient(conn)
s.startGoRoutine(func() {
s.createClient(conn)
s.grWG.Done()
})
}
Noticef("Server Exiting..")
s.done <- true
@@ -480,7 +507,7 @@ func (s *Server) createClient(conn net.Conn) *client {
c.mu.Lock()
// Initialize
c.initClient(tlsRequired)
c.initClient()
c.Debugf("Client connection created")
@@ -498,14 +525,22 @@ func (s *Server) createClient(conn net.Conn) *client {
// Register with the server.
s.mu.Lock()
// If server is not running, Shutdown() may have already gathered the
// list of connections to close. It won't contain this one, so we need
// to bail out now otherwise the readLoop started down there would not
// be interrupted.
if !s.running {
s.mu.Unlock()
return c
}
s.clients[c.cid] = c
s.mu.Unlock()
// Re-Grab lock
c.mu.Lock()
// Check for TLS
if tlsRequired {
// Re-Grab lock
c.mu.Lock()
c.Debugf("Starting TLS client connection handshake")
c.nc = tls.Server(c.nc, s.opts.TLSConfig)
conn := c.nc.(*tls.Conn)
@@ -528,24 +563,35 @@ func (s *Server) createClient(conn net.Conn) *client {
// Re-Grab lock
c.mu.Lock()
}
// The connection may have been closed
if c.nc == nil {
c.mu.Unlock()
return c
}
if tlsRequired {
// Rewrap bw
c.bw = bufio.NewWriterSize(c.nc, startBufSize)
// Do final client initialization
// Set the Ping timer
c.setPingTimer()
// Spin up the read loop.
go c.readLoop()
c.Debugf("TLS handshake complete")
cs := conn.ConnectionState()
c.Debugf("TLS version %s, cipher suite %s", tlsVersion(cs.Version), tlsCipher(cs.CipherSuite))
c.mu.Unlock()
}
// Do final client initialization
// Set the Ping timer
c.setPingTimer()
// Spin up the read loop.
s.startGoRoutine(func() { c.readLoop() })
if tlsRequired {
c.Debugf("TLS handshake complete")
cs := c.nc.(*tls.Conn).ConnectionState()
c.Debugf("TLS version %s, cipher suite %s", tlsVersion(cs.Version), tlsCipher(cs.CipherSuite))
}
c.mu.Unlock()
return c
}
@@ -757,9 +803,18 @@ func (s *Server) GetRouteListenEndpoint() string {
return net.JoinHostPort(host, strconv.Itoa(s.opts.ClusterPort))
}
// Id returns the server's ID
// ID returns the server's ID. The value is read from s.info while
// holding s.mu.
func (s *Server) ID() string {
	s.mu.Lock()
	id := s.info.ID
	s.mu.Unlock()
	return id
}
// startGoRoutine runs f in a new goroutine and counts it in s.grWG, but
// only while s.grRunning is true. When grRunning is false, f is not run
// and nothing is added to the wait group.
//
// The check, the Add(1) and the launch all happen with s.grMu held.
// startGoRoutine never calls s.grWG.Done() itself; f (or something f
// calls) has to do that.
func (s *Server) startGoRoutine(f func()) {
	s.grMu.Lock()
	defer s.grMu.Unlock()

	if !s.grRunning {
		return
	}
	s.grWG.Add(1)
	go f()
}

View File

@@ -108,6 +108,10 @@ func TestBasicClusterPubSub(t *testing.T) {
sendA("PING\r\n")
expectA(pongRe)
if err := checkExpectedSubs(1, srvA, srvB); err != nil {
t.Fatalf("%v", err)
}
sendB, expectB := setupConn(t, clientB)
sendB("PUB foo 2\r\nok\r\n")
sendB("PING\r\n")

View File

@@ -42,6 +42,10 @@ func TestBasicTLSClusterPubSub(t *testing.T) {
sendB("PING\r\n")
expectB(pongRe)
if err := checkExpectedSubs(1, srvA, srvB); err != nil {
t.Fatalf("%v", err)
}
expectMsgs := expectMsgsCommand(t, expectA)
matches := expectMsgs(1)