diff --git a/server/client.go b/server/client.go index 3626cc21..09d83bb2 100644 --- a/server/client.go +++ b/server/client.go @@ -447,12 +447,6 @@ func (c *client) processConnect(arg []byte) error { srv.mu.Unlock() } - // Check for max connections - if ok := srv.checkMaxConn(c); !ok { - c.maxConnLimit() - return ErrTooManyConnections - } - // Check for Auth if ok := srv.checkAuth(c); !ok { c.authViolation() @@ -496,14 +490,8 @@ func (c *client) authViolation() { c.closeConnection() } -func (c *client) maxConnLimit() { - if c.srv != nil && c.srv.opts.Users != nil { - c.Errorf("%s - User %q", - ErrTooManyConnections.Error(), - c.opts.Username) - } else { - c.Errorf(ErrTooManyConnections.Error()) - } +func (c *client) maxConnExceeded() { + c.Errorf(ErrTooManyConnections.Error()) c.sendErr(ErrTooManyConnections.Error()) c.closeConnection() } diff --git a/server/errors.go b/server/errors.go index 65ddeca1..fee22923 100644 --- a/server/errors.go +++ b/server/errors.go @@ -28,5 +28,5 @@ var ( // ErrTooManyConnections signals a client that the maximum number of connections supported by the // server has been reached. - ErrTooManyConnections = errors.New("Too many connections") + ErrTooManyConnections = errors.New("Maximum Connections Exceeded") ) diff --git a/server/server.go b/server/server.go index c0e937a2..14dd800c 100644 --- a/server/server.go +++ b/server/server.go @@ -546,6 +546,13 @@ func (s *Server) createClient(conn net.Conn) *client { s.mu.Unlock() return c } + // If there is a max connections specified, check that adding + // this new client would not push us over the max + if s.opts.MaxConn > 0 && len(s.clients) >= s.opts.MaxConn { + s.mu.Unlock() + c.maxConnExceeded() + return nil + } s.clients[c.cid] = c s.mu.Unlock() @@ -731,17 +738,6 @@ func (s *Server) checkAuth(c *client) bool { } } -// Check that number of clients is below Max connection setting. -func (s *Server) checkMaxConn(c *client) bool { - if c.typ == CLIENT { - s.mu.Lock() - ok := len(s.clients) <= s.opts.MaxConn - s.mu.Unlock() - return ok - } - return true -} - // Remove a client or route from our internal accounting. func (s *Server) removeClient(c *client) { var rID string diff --git a/server/server_test.go b/server/server_test.go index a113fcb0..141302ed 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -242,9 +242,7 @@ func TestMaxConnections(t *testing.T) { nc2, err := nats.Connect(addr) if err == nil { - if nc2 != nil { - nc2.Close() - } + nc2.Close() t.Fatal("Expected connection to fail") } }