mirror of
https://github.com/gogrlx/nats-server.git
synced 2026-04-02 03:38:42 -07:00
Adding user jwt payload and subscriber limits
Addresses part of #1552 Signed-off-by: Matthias Hanel <mh@synadia.com>
This commit is contained in:
@@ -588,6 +588,21 @@ func (c *client) subsAtLimit() bool {
|
||||
return c.msubs != jwt.NoLimit && len(c.subs) >= int(c.msubs)
|
||||
}
|
||||
|
||||
func minLimit(value *int32, limit int32) bool {
|
||||
if *value != jwt.NoLimit {
|
||||
if limit != jwt.NoLimit {
|
||||
if limit < *value {
|
||||
*value = limit
|
||||
return true
|
||||
}
|
||||
}
|
||||
} else if limit != jwt.NoLimit {
|
||||
*value = limit
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// Apply account limits
|
||||
// Lock is held on entry.
|
||||
// FIXME(dlc) - Should server be able to override here?
|
||||
@@ -595,30 +610,35 @@ func (c *client) applyAccountLimits() {
|
||||
if c.acc == nil || (c.kind != CLIENT && c.kind != LEAF) {
|
||||
return
|
||||
}
|
||||
|
||||
// Set here, will need to fo checks for NoLimit.
|
||||
if c.acc.msubs != jwt.NoLimit {
|
||||
c.msubs = c.acc.msubs
|
||||
c.mpay = jwt.NoLimit
|
||||
c.msubs = jwt.NoLimit
|
||||
if c.opts.JWT != "" { // user jwt implies account
|
||||
if uc, _ := jwt.DecodeUserClaims(c.opts.JWT); uc != nil {
|
||||
c.mpay = int32(uc.Limits.Payload)
|
||||
c.msubs = int32(uc.Limits.Subs)
|
||||
}
|
||||
}
|
||||
if c.acc.mpay != jwt.NoLimit {
|
||||
c.mpay = c.acc.mpay
|
||||
}
|
||||
|
||||
minLimit(&c.mpay, c.acc.mpay)
|
||||
minLimit(&c.msubs, c.acc.msubs)
|
||||
s := c.srv
|
||||
opts := s.getOpts()
|
||||
|
||||
// We check here if the server has an option set that is lower than the account limit.
|
||||
if c.mpay != jwt.NoLimit && opts.MaxPayload != 0 && int32(opts.MaxPayload) < c.acc.mpay {
|
||||
c.Errorf("Max Payload set to %d from server config which overrides %d from account claims", opts.MaxPayload, c.acc.mpay)
|
||||
c.mpay = int32(opts.MaxPayload)
|
||||
mPay := opts.MaxPayload
|
||||
// options encode unlimited differently
|
||||
if mPay == 0 {
|
||||
mPay = jwt.NoLimit
|
||||
}
|
||||
|
||||
// We check here if the server has an option set that is lower than the account limit.
|
||||
if c.msubs != jwt.NoLimit && opts.MaxSubs != 0 && opts.MaxSubs < int(c.acc.msubs) {
|
||||
c.Errorf("Max Subscriptions set to %d from server config which overrides %d from account claims", opts.MaxSubs, c.acc.msubs)
|
||||
c.msubs = int32(opts.MaxSubs)
|
||||
mSubs := int32(opts.MaxSubs)
|
||||
if mSubs == 0 {
|
||||
mSubs = jwt.NoLimit
|
||||
}
|
||||
wasUnlimited := c.mpay == jwt.NoLimit
|
||||
if minLimit(&c.mpay, mPay) && !wasUnlimited {
|
||||
c.Errorf("Max Payload set to %d from server overrides account or user config", opts.MaxPayload)
|
||||
}
|
||||
wasUnlimited = c.msubs == jwt.NoLimit
|
||||
if minLimit(&c.msubs, mSubs) && !wasUnlimited {
|
||||
c.Errorf("Max Subscriptions set to %d from server overrides account or user config", opts.MaxSubs)
|
||||
}
|
||||
|
||||
if c.subsAtLimit() {
|
||||
go func() {
|
||||
c.maxSubsExceeded()
|
||||
|
||||
@@ -33,7 +33,9 @@ import (
|
||||
"crypto/rand"
|
||||
"crypto/tls"
|
||||
|
||||
"github.com/nats-io/jwt/v2"
|
||||
"github.com/nats-io/nats.go"
|
||||
"github.com/nats-io/nkeys"
|
||||
)
|
||||
|
||||
type serverInfo struct {
|
||||
@@ -2460,3 +2462,65 @@ func TestClientConnectionName(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestClientLimits(t *testing.T) {
|
||||
accKp, err := nkeys.CreateAccount()
|
||||
if err != nil {
|
||||
t.Fatalf("Error creating account key: %v", err)
|
||||
}
|
||||
uKp, err := nkeys.CreateUser()
|
||||
if err != nil {
|
||||
t.Fatalf("Error creating user key: %v", err)
|
||||
}
|
||||
uPub, err := uKp.PublicKey()
|
||||
if err != nil {
|
||||
t.Fatalf("Error obtaining publicKey: %v", err)
|
||||
}
|
||||
s, err := NewServer(DefaultOptions())
|
||||
if err != nil {
|
||||
t.Fatalf("Error creating server: %v", err)
|
||||
}
|
||||
for _, test := range []struct {
|
||||
client int32
|
||||
acc int32
|
||||
srv int32
|
||||
expect int32
|
||||
}{
|
||||
// all identical
|
||||
{1, 1, 1, 1},
|
||||
{-1, -1, 0, -1},
|
||||
// only one value unlimited
|
||||
{1, -1, 0, 1},
|
||||
{-1, 1, 0, 1},
|
||||
{-1, -1, 1, 1},
|
||||
// all combinations of distinct values
|
||||
{1, 2, 3, 1},
|
||||
{1, 3, 2, 1},
|
||||
{2, 1, 3, 1},
|
||||
{2, 3, 1, 1},
|
||||
{3, 1, 2, 1},
|
||||
{3, 2, 1, 1},
|
||||
} {
|
||||
t.Run("", func(t *testing.T) {
|
||||
s.opts.MaxPayload = test.srv
|
||||
s.opts.MaxSubs = int(test.srv)
|
||||
c := &client{srv: s, acc: &Account{
|
||||
limits: limits{mpay: test.acc, msubs: test.acc},
|
||||
}}
|
||||
uc := jwt.NewUserClaims(uPub)
|
||||
uc.Limits.Subs = int64(test.client)
|
||||
uc.Limits.Payload = int64(test.client)
|
||||
c.opts.JWT, err = uc.Encode(accKp)
|
||||
if err != nil {
|
||||
t.Fatalf("Error encoding jwt: %v", err)
|
||||
}
|
||||
c.applyAccountLimits()
|
||||
if c.mpay != test.expect {
|
||||
t.Fatalf("payload %d not as ecpected %d", c.mpay, test.expect)
|
||||
}
|
||||
if c.msubs != test.expect {
|
||||
t.Fatalf("subscriber %d not as ecpected %d", c.msubs, test.expect)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3438,3 +3438,64 @@ func TestJWTTimeExpiration(t *testing.T) {
|
||||
c.Close()
|
||||
})
|
||||
}
|
||||
|
||||
func TestJWTSubLimits(t *testing.T) {
|
||||
doNotExpire := time.Now().AddDate(1, 0, 0)
|
||||
// create account
|
||||
kp, _ := nkeys.CreateAccount()
|
||||
aPub, _ := kp.PublicKey()
|
||||
claim := jwt.NewAccountClaims(aPub)
|
||||
aJwt, err := claim.Encode(oKp)
|
||||
require_NoError(t, err)
|
||||
conf := createConfFile(t, []byte(fmt.Sprintf(`
|
||||
listen: -1
|
||||
operator: %s
|
||||
resolver: MEM
|
||||
resolver_preload: {
|
||||
%s: %s
|
||||
}
|
||||
`, ojwt, aPub, aJwt)))
|
||||
defer os.Remove(conf)
|
||||
sA, _ := RunServerWithConfig(conf)
|
||||
defer sA.Shutdown()
|
||||
errChan := make(chan struct{})
|
||||
defer close(errChan)
|
||||
t.Run("subs", func(t *testing.T) {
|
||||
creds := createUserWithLimit(t, kp, doNotExpire, func(j *jwt.Limits) { j.Subs = 1 })
|
||||
defer os.Remove(creds)
|
||||
c := natsConnect(t, sA.ClientURL(), nats.UserCredentials(creds),
|
||||
nats.DisconnectErrHandler(func(conn *nats.Conn, err error) {
|
||||
if e := conn.LastError(); e != nil && strings.Contains(e.Error(), "maximum subscriptions exceeded") {
|
||||
errChan <- struct{}{}
|
||||
}
|
||||
}),
|
||||
)
|
||||
if _, err := c.Subscribe("foo", func(msg *nats.Msg) {}); err != nil {
|
||||
t.Fatalf("couldn't subscribe: %v", err)
|
||||
}
|
||||
if _, err = c.Subscribe("bar", func(msg *nats.Msg) {}); err != nil {
|
||||
t.Fatalf("expected error got: %v", err)
|
||||
}
|
||||
<-errChan
|
||||
c.Close()
|
||||
})
|
||||
t.Run("payload", func(t *testing.T) {
|
||||
creds := createUserWithLimit(t, kp, doNotExpire, func(j *jwt.Limits) { j.Payload = 5 })
|
||||
defer os.Remove(creds)
|
||||
c := natsConnect(t, sA.ClientURL(), nats.UserCredentials(creds),
|
||||
nats.DisconnectErrHandler(func(conn *nats.Conn, err error) {
|
||||
if e := conn.LastError(); e != nil && strings.Contains(e.Error(), "Maximum Payload Violation") {
|
||||
errChan <- struct{}{}
|
||||
}
|
||||
}),
|
||||
)
|
||||
if err := c.Publish("foo", []byte("world")); err != nil {
|
||||
t.Fatalf("couldn't publish: %v", err)
|
||||
}
|
||||
if err := c.Publish("foo", []byte("worldX")); err != nil {
|
||||
t.Fatalf("couldn't publish: %v", err)
|
||||
}
|
||||
<-errChan
|
||||
c.Close()
|
||||
})
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user