diff --git a/server/client.go b/server/client.go index 6042a837..1f2cdb38 100644 --- a/server/client.go +++ b/server/client.go @@ -131,6 +131,7 @@ type client struct { // Here first because of use of atomics, and memory alignment. stats mpay int64 + msubs int mu sync.Mutex typ int cid uint64 @@ -775,6 +776,11 @@ func (c *client) maxConnExceeded() { c.closeConnection(MaxConnectionsExceeded) } +func (c *client) maxSubsExceeded() { + c.Errorf(ErrTooManySubs.Error()) + c.sendErr(ErrTooManySubs.Error()) +} + func (c *client) maxPayloadViolation(sz int, max int64) { c.Errorf("%s: %d vs %d", ErrMaxPayload.Error(), sz, max) c.sendErr("Maximum Payload Violation") @@ -1141,6 +1147,13 @@ func (c *client) processSub(argo []byte) (err error) { return nil } + if c.msubs > 0 && len(c.subs) >= c.msubs { + c.mu.Unlock() + c.maxSubsExceeded() + return nil + } + + // Check if we have a maximum on the number of subscriptions. // We can have two SUB protocols coming from a route due to some // race conditions. We should make sure that we process only one. sid := string(sub.sid) diff --git a/server/configs/test.conf b/server/configs/test.conf index 7d2814b1..01c75f00 100644 --- a/server/configs/test.conf +++ b/server/configs/test.conf @@ -26,6 +26,9 @@ prof_port: 6543 # max_connections max_connections: 100 +# max_subscriptions (per connection) +max_subscriptions: 1000 + # max_pending max_pending: 10000000 diff --git a/server/errors.go b/server/errors.go index b93a9e5c..c722bfc4 100644 --- a/server/errors.go +++ b/server/errors.go @@ -41,6 +41,10 @@ var ( // server has been reached. ErrTooManyConnections = errors.New("Maximum Connections Exceeded") + // ErrTooManySubs signals a client that the maximum number of subscriptions per connection + // has been reached. + ErrTooManySubs = errors.New("Maximum Subscriptions Exceeded") + // ErrClientConnectedToRoutePort represents an error condition when a client // attempted to connect to the route listen port. ErrClientConnectedToRoutePort = errors.New("Attempted To Connect To Route Port") diff --git a/server/opts.go b/server/opts.go index b9619ca1..d54079fa 100644 --- a/server/opts.go +++ b/server/opts.go @@ -59,6 +59,7 @@ type Options struct { NoSigs bool `json:"-"` Logtime bool `json:"-"` MaxConn int `json:"max_connections"` + MaxSubs int `json:"max_subscriptions,omitempty"` Users []*User `json:"-"` Username string `json:"-"` Password string `json:"-"` @@ -295,6 +296,8 @@ func (o *Options) ProcessConfigFile(configFile string) error { o.MaxPending = v.(int64) case "max_connections", "max_conn": o.MaxConn = int(v.(int64)) + case "max_subscriptions", "max_subs": + o.MaxSubs = int(v.(int64)) case "ping_interval": o.PingInterval = time.Duration(int(v.(int64))) * time.Second case "ping_max": diff --git a/server/opts_test.go b/server/opts_test.go index 8fd3eb12..9c391c5a 100644 --- a/server/opts_test.go +++ b/server/opts_test.go @@ -82,6 +82,7 @@ func TestConfigFile(t *testing.T) { MaxControlLine: 2048, MaxPayload: 65536, MaxConn: 100, + MaxSubs: 1000, MaxPending: 10000000, PingInterval: 60 * time.Second, MaxPingsOut: 3, @@ -238,6 +239,7 @@ func TestMergeOverrides(t *testing.T) { MaxControlLine: 2048, MaxPayload: 65536, MaxConn: 100, + MaxSubs: 1000, MaxPending: 10000000, PingInterval: 60 * time.Second, MaxPingsOut: 3, diff --git a/server/reload_test.go b/server/reload_test.go index b062bd24..0585def9 100644 --- a/server/reload_test.go +++ b/server/reload_test.go @@ -1621,6 +1621,29 @@ func TestConfigReloadClusterNoAdvertise(t *testing.T) { } } +func TestConfigReloadMaxSubsUnsupported(t *testing.T) { + conf := "maxsubs.conf" + if err := ioutil.WriteFile(conf, []byte(`max_subs: 1`), 0666); err != nil { + t.Fatalf("Error creating config file: %v", err) + } + defer os.Remove(conf) + opts, err := ProcessConfigFile(conf) + if err != nil { + stackFatalf(t, "Error processing config file: %v", err) + } + opts.NoLog = true + opts.NoSigs = true + s := RunServer(opts) + defer s.Shutdown() + + if err := ioutil.WriteFile(conf, []byte(`max_subs: 10`), 0666); err != nil { + t.Fatalf("Error writing config file: %v", err) + } + if err := s.Reload(); err == nil { + t.Fatal("Expected Reload to return an error") + } +} + func TestConfigReloadClientAdvertise(t *testing.T) { conf := "clientadv.conf" if err := ioutil.WriteFile(conf, []byte(`listen: "0.0.0.0:-1"`), 0666); err != nil { diff --git a/server/server.go b/server/server.go index 496679fd..383c47f5 100644 --- a/server/server.go +++ b/server/server.go @@ -766,6 +766,11 @@ func (s *Server) createClient(conn net.Conn) *client { return c } + // If there is a max subscriptions specified, add to the client. + if opts.MaxSubs > 0 { + c.msubs = opts.MaxSubs + } + // If there is a max connections specified, check that adding // this new client would not push us over the max if opts.MaxConn > 0 && len(s.clients) >= opts.MaxConn { diff --git a/server/server_test.go b/server/server_test.go index fe501e30..20b8dbe2 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -386,6 +386,33 @@ func TestMaxConnections(t *testing.T) { } } +func TestMaxSubscriptions(t *testing.T) { + opts := DefaultOptions() + opts.MaxSubs = 10 + s := RunServer(opts) + defer s.Shutdown() + + addr := fmt.Sprintf("nats://%s:%d", opts.Host, opts.Port) + nc, err := nats.Connect(addr) + if err != nil { + t.Fatalf("Error creating client: %v\n", err) + } + defer nc.Close() + + for i := 0; i < 10; i++ { + _, err := nc.Subscribe(fmt.Sprintf("foo.%d", i), func(*nats.Msg) {}) + if err != nil { + t.Fatalf("Error subscribing: %v\n", err) + } + } + // This should cause the error. + nc.Subscribe("foo.22", func(*nats.Msg) {}) + nc.Flush() + if err := nc.LastError(); err == nil { + t.Fatal("Expected an error but got none\n") + } +} + func TestProcessCommandLineArgs(t *testing.T) { var host string var port int