mirror of
https://github.com/gogrlx/nats-server.git
synced 2026-04-16 19:14:41 -07:00
Merge pull request #250 from nats-io/wait_for_routes_go_routines
Ensure Shutdown() waits for outstanding go routines
This commit is contained in:
@@ -113,7 +113,7 @@ func init() {
|
||||
}
|
||||
|
||||
// Lock should be held
|
||||
func (c *client) initClient(tlsConn bool) {
|
||||
func (c *client) initClient() {
|
||||
s := c.srv
|
||||
c.cid = atomic.AddUint64(&s.gcid, 1)
|
||||
c.bw = bufio.NewWriterSize(c.nc, startBufSize)
|
||||
@@ -143,14 +143,6 @@ func (c *client) initClient(tlsConn bool) {
|
||||
case ROUTER:
|
||||
c.ncs = fmt.Sprintf("%s - rid:%d", conn, c.cid)
|
||||
}
|
||||
|
||||
if !tlsConn {
|
||||
// Set the Ping timer
|
||||
c.setPingTimer()
|
||||
|
||||
// Spin up the read loop.
|
||||
go c.readLoop()
|
||||
}
|
||||
}
|
||||
|
||||
func (c *client) readLoop() {
|
||||
@@ -159,6 +151,7 @@ func (c *client) readLoop() {
|
||||
c.mu.Lock()
|
||||
nc := c.nc
|
||||
s := c.srv
|
||||
defer s.grWG.Done()
|
||||
c.mu.Unlock()
|
||||
|
||||
if nc == nil {
|
||||
@@ -1054,6 +1047,12 @@ func (c *client) closeConnection() {
|
||||
srv.mu.Lock()
|
||||
defer srv.mu.Unlock()
|
||||
|
||||
// It is possible that the server is being shutdown.
|
||||
// If so, don't try to reconnect
|
||||
if !srv.running {
|
||||
return
|
||||
}
|
||||
|
||||
if rid != "" && srv.remotes[rid] != nil {
|
||||
Debugf("Not attempting reconnect for solicited route, already connected to \"%s\"", rid)
|
||||
return
|
||||
@@ -1062,7 +1061,9 @@ func (c *client) closeConnection() {
|
||||
return
|
||||
} else if rtype != Implicit || retryImplicit {
|
||||
Debugf("Attempting reconnect for solicited route \"%s\"", rurl)
|
||||
go srv.reConnectToRoute(rurl, rtype)
|
||||
// Keep track of this go-routine so we can wait for it on
|
||||
// server shutdown.
|
||||
srv.startGoRoutine(func() { srv.reConnectToRoute(rurl, rtype) })
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -191,7 +191,7 @@ func (s *Server) processImplicitRoute(info *Info) {
|
||||
if info.AuthRequired {
|
||||
r.User = url.UserPassword(s.opts.ClusterUsername, s.opts.ClusterPassword)
|
||||
}
|
||||
go s.connectToRoute(r, false)
|
||||
s.startGoRoutine(func() { s.connectToRoute(r, false) })
|
||||
}
|
||||
|
||||
// hasThisRouteConfigured returns true if info.Host:info.Port is present
|
||||
@@ -277,7 +277,7 @@ func (s *Server) createRoute(conn net.Conn, rURL *url.URL) *client {
|
||||
c.mu.Lock()
|
||||
|
||||
// Initialize
|
||||
c.initClient(tlsRequired)
|
||||
c.initClient()
|
||||
|
||||
c.Debugf("Route connection created")
|
||||
|
||||
@@ -332,17 +332,29 @@ func (s *Server) createRoute(conn net.Conn, rURL *url.URL) *client {
|
||||
|
||||
// Rewrap bw
|
||||
c.bw = bufio.NewWriterSize(c.nc, startBufSize)
|
||||
}
|
||||
|
||||
// Do final client initialization
|
||||
// Do final client initialization
|
||||
|
||||
// Set the Ping timer
|
||||
c.setPingTimer()
|
||||
// Set the Ping timer
|
||||
c.setPingTimer()
|
||||
|
||||
// Spin up the read loop.
|
||||
go c.readLoop()
|
||||
// For routes, the "client" is added to s.routes only when processing
|
||||
// the INFO protocol, that is much later.
|
||||
// In the meantime, if the server shuts down, there would be no reference
|
||||
// to the client (connection) to be closed, leaving this readLoop
|
||||
// uninterrupted, causing the Shutdown() to wait indefinitely.
|
||||
// We need to store the client in a special map, under a special lock.
|
||||
s.grMu.Lock()
|
||||
s.grTmpClients[c.cid] = c
|
||||
s.grMu.Unlock()
|
||||
|
||||
// Spin up the read loop.
|
||||
s.startGoRoutine(func() { c.readLoop() })
|
||||
|
||||
if tlsRequired {
|
||||
c.Debugf("TLS handshake complete")
|
||||
cs := conn.ConnectionState()
|
||||
cs := c.nc.(*tls.Conn).ConnectionState()
|
||||
c.Debugf("TLS version %s, cipher suite %s", tlsVersion(cs.Version), tlsCipher(cs.CipherSuite))
|
||||
}
|
||||
|
||||
@@ -436,6 +448,11 @@ func (s *Server) addRoute(c *client, info *Info) (bool, bool) {
|
||||
}
|
||||
remote, exists := s.remotes[id]
|
||||
if !exists {
|
||||
// Remove from the temporary map
|
||||
s.grMu.Lock()
|
||||
delete(s.grTmpClients, c.cid)
|
||||
s.grMu.Unlock()
|
||||
|
||||
s.routes[c.cid] = c
|
||||
s.remotes[id] = c
|
||||
|
||||
@@ -550,7 +567,10 @@ func (s *Server) routeAcceptLoop(ch chan struct{}) {
|
||||
continue
|
||||
}
|
||||
tmpDelay = ACCEPT_MIN_SLEEP
|
||||
go s.createRoute(conn, nil)
|
||||
s.startGoRoutine(func() {
|
||||
s.createRoute(conn, nil)
|
||||
s.grWG.Done()
|
||||
})
|
||||
}
|
||||
Debugf("Router accept loop exiting..")
|
||||
s.done <- true
|
||||
@@ -598,6 +618,7 @@ func (s *Server) reConnectToRoute(rURL *url.URL, rtype RouteType) {
|
||||
}
|
||||
|
||||
func (s *Server) connectToRoute(rURL *url.URL, tryForEver bool) {
|
||||
defer s.grWG.Done()
|
||||
for s.isRunning() && rURL != nil {
|
||||
Debugf("Trying to connect to route on %s", rURL.Host)
|
||||
conn, err := net.DialTimeout("tcp", rURL.Host, DEFAULT_ROUTE_DIAL)
|
||||
@@ -628,7 +649,8 @@ func (c *client) isSolicitedRoute() bool {
|
||||
|
||||
func (s *Server) solicitRoutes() {
|
||||
for _, r := range s.opts.Routes {
|
||||
go s.connectToRoute(r, true)
|
||||
route := r
|
||||
s.startGoRoutine(func() { s.connectToRoute(route, true) })
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -65,6 +65,10 @@ type Server struct {
|
||||
routeInfo Info
|
||||
routeInfoJSON []byte
|
||||
rcQuit chan bool
|
||||
grMu sync.Mutex
|
||||
grTmpClients map[uint64]*client
|
||||
grRunning bool
|
||||
grWG sync.WaitGroup // to wait on various go routines
|
||||
}
|
||||
|
||||
// Make sure all are 64bits for atomic use
|
||||
@@ -113,6 +117,10 @@ func New(opts *Options) *Server {
|
||||
// For tracking clients
|
||||
s.clients = make(map[uint64]*client)
|
||||
|
||||
// For tracking connections that are not yet registered
|
||||
// in s.routes, but for which readLoop has started.
|
||||
s.grTmpClients = make(map[uint64]*client)
|
||||
|
||||
// For tracking routes and their remote ids
|
||||
s.routes = make(map[uint64]*client)
|
||||
s.remotes = make(map[string]*client)
|
||||
@@ -205,6 +213,9 @@ func (s *Server) Start() {
|
||||
Debugf("Go build version %s", s.info.GoVersion)
|
||||
|
||||
s.running = true
|
||||
s.grMu.Lock()
|
||||
s.grRunning = true
|
||||
s.grMu.Unlock()
|
||||
|
||||
// Log the pid to a file
|
||||
if s.opts.PidFile != _EMPTY_ {
|
||||
@@ -251,6 +262,9 @@ func (s *Server) Shutdown() {
|
||||
}
|
||||
|
||||
s.running = false
|
||||
s.grMu.Lock()
|
||||
s.grRunning = false
|
||||
s.grMu.Unlock()
|
||||
|
||||
conns := make(map[uint64]*client)
|
||||
|
||||
@@ -258,6 +272,13 @@ func (s *Server) Shutdown() {
|
||||
for i, c := range s.clients {
|
||||
conns[i] = c
|
||||
}
|
||||
// Copy off the connections that are not yet registered
|
||||
// in s.routes, but for which the readLoop has started
|
||||
s.grMu.Lock()
|
||||
for i, c := range s.grTmpClients {
|
||||
conns[i] = c
|
||||
}
|
||||
s.grMu.Unlock()
|
||||
// Copy off the routes
|
||||
for i, r := range s.routes {
|
||||
conns[i] = r
|
||||
@@ -302,6 +323,9 @@ func (s *Server) Shutdown() {
|
||||
<-s.done
|
||||
doneExpected--
|
||||
}
|
||||
|
||||
// Wait for go routines to be done.
|
||||
s.grWG.Wait()
|
||||
}
|
||||
|
||||
// AcceptLoop is exported for easier testing.
|
||||
@@ -365,7 +389,10 @@ func (s *Server) AcceptLoop() {
|
||||
continue
|
||||
}
|
||||
tmpDelay = ACCEPT_MIN_SLEEP
|
||||
go s.createClient(conn)
|
||||
s.startGoRoutine(func() {
|
||||
s.createClient(conn)
|
||||
s.grWG.Done()
|
||||
})
|
||||
}
|
||||
Noticef("Server Exiting..")
|
||||
s.done <- true
|
||||
@@ -480,7 +507,7 @@ func (s *Server) createClient(conn net.Conn) *client {
|
||||
c.mu.Lock()
|
||||
|
||||
// Initialize
|
||||
c.initClient(tlsRequired)
|
||||
c.initClient()
|
||||
|
||||
c.Debugf("Client connection created")
|
||||
|
||||
@@ -498,14 +525,22 @@ func (s *Server) createClient(conn net.Conn) *client {
|
||||
|
||||
// Register with the server.
|
||||
s.mu.Lock()
|
||||
// If server is not running, Shutdown() may have already gathered the
|
||||
// list of connections to close. It won't contain this one, so we need
|
||||
// to bail out now otherwise the readLoop started down there would not
|
||||
// be interrupted.
|
||||
if !s.running {
|
||||
s.mu.Unlock()
|
||||
return c
|
||||
}
|
||||
s.clients[c.cid] = c
|
||||
s.mu.Unlock()
|
||||
|
||||
// Re-Grab lock
|
||||
c.mu.Lock()
|
||||
|
||||
// Check for TLS
|
||||
if tlsRequired {
|
||||
// Re-Grab lock
|
||||
c.mu.Lock()
|
||||
|
||||
c.Debugf("Starting TLS client connection handshake")
|
||||
c.nc = tls.Server(c.nc, s.opts.TLSConfig)
|
||||
conn := c.nc.(*tls.Conn)
|
||||
@@ -528,24 +563,35 @@ func (s *Server) createClient(conn net.Conn) *client {
|
||||
|
||||
// Re-Grab lock
|
||||
c.mu.Lock()
|
||||
}
|
||||
|
||||
// The connection may have been closed
|
||||
if c.nc == nil {
|
||||
c.mu.Unlock()
|
||||
return c
|
||||
}
|
||||
|
||||
if tlsRequired {
|
||||
// Rewrap bw
|
||||
c.bw = bufio.NewWriterSize(c.nc, startBufSize)
|
||||
|
||||
// Do final client initialization
|
||||
|
||||
// Set the Ping timer
|
||||
c.setPingTimer()
|
||||
|
||||
// Spin up the read loop.
|
||||
go c.readLoop()
|
||||
|
||||
c.Debugf("TLS handshake complete")
|
||||
cs := conn.ConnectionState()
|
||||
c.Debugf("TLS version %s, cipher suite %s", tlsVersion(cs.Version), tlsCipher(cs.CipherSuite))
|
||||
c.mu.Unlock()
|
||||
}
|
||||
|
||||
// Do final client initialization
|
||||
|
||||
// Set the Ping timer
|
||||
c.setPingTimer()
|
||||
|
||||
// Spin up the read loop.
|
||||
s.startGoRoutine(func() { c.readLoop() })
|
||||
|
||||
if tlsRequired {
|
||||
c.Debugf("TLS handshake complete")
|
||||
cs := c.nc.(*tls.Conn).ConnectionState()
|
||||
c.Debugf("TLS version %s, cipher suite %s", tlsVersion(cs.Version), tlsCipher(cs.CipherSuite))
|
||||
}
|
||||
|
||||
c.mu.Unlock()
|
||||
|
||||
return c
|
||||
}
|
||||
|
||||
@@ -757,9 +803,18 @@ func (s *Server) GetRouteListenEndpoint() string {
|
||||
return net.JoinHostPort(host, strconv.Itoa(s.opts.ClusterPort))
|
||||
}
|
||||
|
||||
// Id returns the server's ID
|
||||
// ID returns the server's ID
|
||||
func (s *Server) ID() string {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
return s.info.ID
|
||||
}
|
||||
|
||||
// startGoRoutine launches f in a new goroutine and registers it with s.grWG,
// but only while s.grRunning is true. The check, the WaitGroup increment and
// the launch all happen while holding s.grMu. If s.grRunning is false, f is
// not run at all.
//
// startGoRoutine itself never calls s.grWG.Done(); the function passed in
// (or something it calls) must do so, otherwise s.grWG.Wait() will block.
func (s *Server) startGoRoutine(f func()) {
|
||||
s.grMu.Lock()
|
||||
if s.grRunning {
|
||||
// Count the goroutine before starting it so a concurrent Wait() cannot miss it.
s.grWG.Add(1)
|
||||
go f()
|
||||
}
|
||||
s.grMu.Unlock()
|
||||
}
|
||||
|
||||
@@ -108,6 +108,10 @@ func TestBasicClusterPubSub(t *testing.T) {
|
||||
sendA("PING\r\n")
|
||||
expectA(pongRe)
|
||||
|
||||
if err := checkExpectedSubs(1, srvA, srvB); err != nil {
|
||||
t.Fatalf("%v", err)
|
||||
}
|
||||
|
||||
sendB, expectB := setupConn(t, clientB)
|
||||
sendB("PUB foo 2\r\nok\r\n")
|
||||
sendB("PING\r\n")
|
||||
|
||||
@@ -42,6 +42,10 @@ func TestBasicTLSClusterPubSub(t *testing.T) {
|
||||
sendB("PING\r\n")
|
||||
expectB(pongRe)
|
||||
|
||||
if err := checkExpectedSubs(1, srvA, srvB); err != nil {
|
||||
t.Fatalf("%v", err)
|
||||
}
|
||||
|
||||
expectMsgs := expectMsgsCommand(t, expectA)
|
||||
|
||||
matches := expectMsgs(1)
|
||||
|
||||
Reference in New Issue
Block a user