diff --git a/server/client.go b/server/client.go index 8545fab6..f84f9c3e 100644 --- a/server/client.go +++ b/server/client.go @@ -588,6 +588,21 @@ func (c *client) subsAtLimit() bool { return c.msubs != jwt.NoLimit && len(c.subs) >= int(c.msubs) } +func minLimit(value *int32, limit int32) bool { + if *value != jwt.NoLimit { + if limit != jwt.NoLimit { + if limit < *value { + *value = limit + return true + } + } + } else if limit != jwt.NoLimit { + *value = limit + return true + } + return false +} + // Apply account limits // Lock is held on entry. // FIXME(dlc) - Should server be able to override here? @@ -595,30 +610,35 @@ func (c *client) applyAccountLimits() { if c.acc == nil || (c.kind != CLIENT && c.kind != LEAF) { return } - - // Set here, will need to fo checks for NoLimit. - if c.acc.msubs != jwt.NoLimit { - c.msubs = c.acc.msubs + c.mpay = jwt.NoLimit + c.msubs = jwt.NoLimit + if c.opts.JWT != "" { // user jwt implies account + if uc, _ := jwt.DecodeUserClaims(c.opts.JWT); uc != nil { + c.mpay = int32(uc.Limits.Payload) + c.msubs = int32(uc.Limits.Subs) + } } - if c.acc.mpay != jwt.NoLimit { - c.mpay = c.acc.mpay - } - + minLimit(&c.mpay, c.acc.mpay) + minLimit(&c.msubs, c.acc.msubs) s := c.srv opts := s.getOpts() - - // We check here if the server has an option set that is lower than the account limit. - if c.mpay != jwt.NoLimit && opts.MaxPayload != 0 && int32(opts.MaxPayload) < c.acc.mpay { - c.Errorf("Max Payload set to %d from server config which overrides %d from account claims", opts.MaxPayload, c.acc.mpay) - c.mpay = int32(opts.MaxPayload) + mPay := opts.MaxPayload + // options encode unlimited differently + if mPay == 0 { + mPay = jwt.NoLimit } - - // We check here if the server has an option set that is lower than the account limit. - if c.msubs != jwt.NoLimit && opts.MaxSubs != 0 && opts.MaxSubs < int(c.acc.msubs) { - c.Errorf("Max Subscriptions set to %d from server config which overrides %d from account claims", opts.MaxSubs, c.acc.msubs) - c.msubs = int32(opts.MaxSubs) + mSubs := int32(opts.MaxSubs) + if mSubs == 0 { + mSubs = jwt.NoLimit + } + wasUnlimited := c.mpay == jwt.NoLimit + if minLimit(&c.mpay, mPay) && !wasUnlimited { + c.Errorf("Max Payload set to %d from server overrides account or user config", opts.MaxPayload) + } + wasUnlimited = c.msubs == jwt.NoLimit + if minLimit(&c.msubs, mSubs) && !wasUnlimited { + c.Errorf("Max Subscriptions set to %d from server overrides account or user config", opts.MaxSubs) } - if c.subsAtLimit() { go func() { c.maxSubsExceeded() diff --git a/server/client_test.go b/server/client_test.go index 39468c1a..99040f2d 100644 --- a/server/client_test.go +++ b/server/client_test.go @@ -33,7 +33,9 @@ import ( "crypto/rand" "crypto/tls" + "github.com/nats-io/jwt/v2" "github.com/nats-io/nats.go" + "github.com/nats-io/nkeys" ) type serverInfo struct { @@ -2460,3 +2462,65 @@ func TestClientConnectionName(t *testing.T) { }) } } + +func TestClientLimits(t *testing.T) { + accKp, err := nkeys.CreateAccount() + if err != nil { + t.Fatalf("Error creating account key: %v", err) + } + uKp, err := nkeys.CreateUser() + if err != nil { + t.Fatalf("Error creating user key: %v", err) + } + uPub, err := uKp.PublicKey() + if err != nil { + t.Fatalf("Error obtaining publicKey: %v", err) + } + s, err := NewServer(DefaultOptions()) + if err != nil { + t.Fatalf("Error creating server: %v", err) + } + for _, test := range []struct { + client int32 + acc int32 + srv int32 + expect int32 + }{ + // all identical + {1, 1, 1, 1}, + {-1, -1, 0, -1}, + // only one value unlimited + {1, -1, 0, 1}, + {-1, 1, 0, 1}, + {-1, -1, 1, 1}, + // all combinations of distinct values + {1, 2, 3, 1}, + {1, 3, 2, 1}, + {2, 1, 3, 1}, + {2, 3, 1, 1}, + {3, 1, 2, 1}, + {3, 2, 1, 1}, + } { + t.Run("", func(t *testing.T) { + s.opts.MaxPayload = test.srv + s.opts.MaxSubs = int(test.srv) + c := &client{srv: s, acc: &Account{ + limits: limits{mpay: test.acc, msubs: test.acc}, + }} + uc := jwt.NewUserClaims(uPub) + uc.Limits.Subs = int64(test.client) + uc.Limits.Payload = int64(test.client) + c.opts.JWT, err = uc.Encode(accKp) + if err != nil { + t.Fatalf("Error encoding jwt: %v", err) + } + c.applyAccountLimits() + if c.mpay != test.expect { + t.Fatalf("payload %d not as ecpected %d", c.mpay, test.expect) + } + if c.msubs != test.expect { + t.Fatalf("subscriber %d not as ecpected %d", c.msubs, test.expect) + } + }) + } +} diff --git a/server/jwt_test.go b/server/jwt_test.go index b0f104c7..51157174 100644 --- a/server/jwt_test.go +++ b/server/jwt_test.go @@ -3438,3 +3438,64 @@ func TestJWTTimeExpiration(t *testing.T) { c.Close() }) } + +func TestJWTSubLimits(t *testing.T) { + doNotExpire := time.Now().AddDate(1, 0, 0) + // create account + kp, _ := nkeys.CreateAccount() + aPub, _ := kp.PublicKey() + claim := jwt.NewAccountClaims(aPub) + aJwt, err := claim.Encode(oKp) + require_NoError(t, err) + conf := createConfFile(t, []byte(fmt.Sprintf(` + listen: -1 + operator: %s + resolver: MEM + resolver_preload: { + %s: %s + } + `, ojwt, aPub, aJwt))) + defer os.Remove(conf) + sA, _ := RunServerWithConfig(conf) + defer sA.Shutdown() + errChan := make(chan struct{}) + defer close(errChan) + t.Run("subs", func(t *testing.T) { + creds := createUserWithLimit(t, kp, doNotExpire, func(j *jwt.Limits) { j.Subs = 1 }) + defer os.Remove(creds) + c := natsConnect(t, sA.ClientURL(), nats.UserCredentials(creds), + nats.DisconnectErrHandler(func(conn *nats.Conn, err error) { + if e := conn.LastError(); e != nil && strings.Contains(e.Error(), "maximum subscriptions exceeded") { + errChan <- struct{}{} + } + }), + ) + if _, err := c.Subscribe("foo", func(msg *nats.Msg) {}); err != nil { + t.Fatalf("couldn't subscribe: %v", err) + } + if _, err = c.Subscribe("bar", func(msg *nats.Msg) {}); err != nil { + t.Fatalf("expected error got: %v", err) + } + <-errChan + c.Close() + }) + t.Run("payload", func(t *testing.T) { + creds := createUserWithLimit(t, kp, doNotExpire, func(j *jwt.Limits) { j.Payload = 5 }) + defer os.Remove(creds) + c := natsConnect(t, sA.ClientURL(), nats.UserCredentials(creds), + nats.DisconnectErrHandler(func(conn *nats.Conn, err error) { + if e := conn.LastError(); e != nil && strings.Contains(e.Error(), "Maximum Payload Violation") { + errChan <- struct{}{} + } + }), + ) + if err := c.Publish("foo", []byte("world")); err != nil { + t.Fatalf("couldn't publish: %v", err) + } + if err := c.Publish("foo", []byte("worldX")); err != nil { + t.Fatalf("couldn't publish: %v", err) + } + <-errChan + c.Close() + }) +}