Adding user jwt payload and subscriber limits

Addresses part of #1552

Signed-off-by: Matthias Hanel <mh@synadia.com>
This commit is contained in:
Matthias Hanel
2020-08-24 16:00:18 -04:00
parent 8c60b10d5f
commit 9d1526cbb8
3 changed files with 164 additions and 19 deletions

View File

@@ -588,6 +588,21 @@ func (c *client) subsAtLimit() bool {
return c.msubs != jwt.NoLimit && len(c.subs) >= int(c.msubs)
}
func minLimit(value *int32, limit int32) bool {
if *value != jwt.NoLimit {
if limit != jwt.NoLimit {
if limit < *value {
*value = limit
return true
}
}
} else if limit != jwt.NoLimit {
*value = limit
return true
}
return false
}
// Apply account limits
// Lock is held on entry.
// FIXME(dlc) - Should server be able to override here?
@@ -595,30 +610,35 @@ func (c *client) applyAccountLimits() {
if c.acc == nil || (c.kind != CLIENT && c.kind != LEAF) {
return
}
// Set here, will need to fo checks for NoLimit.
if c.acc.msubs != jwt.NoLimit {
c.msubs = c.acc.msubs
c.mpay = jwt.NoLimit
c.msubs = jwt.NoLimit
if c.opts.JWT != "" { // user jwt implies account
if uc, _ := jwt.DecodeUserClaims(c.opts.JWT); uc != nil {
c.mpay = int32(uc.Limits.Payload)
c.msubs = int32(uc.Limits.Subs)
}
}
if c.acc.mpay != jwt.NoLimit {
c.mpay = c.acc.mpay
}
minLimit(&c.mpay, c.acc.mpay)
minLimit(&c.msubs, c.acc.msubs)
s := c.srv
opts := s.getOpts()
// We check here if the server has an option set that is lower than the account limit.
if c.mpay != jwt.NoLimit && opts.MaxPayload != 0 && int32(opts.MaxPayload) < c.acc.mpay {
c.Errorf("Max Payload set to %d from server config which overrides %d from account claims", opts.MaxPayload, c.acc.mpay)
c.mpay = int32(opts.MaxPayload)
mPay := opts.MaxPayload
// options encode unlimited differently
if mPay == 0 {
mPay = jwt.NoLimit
}
// We check here if the server has an option set that is lower than the account limit.
if c.msubs != jwt.NoLimit && opts.MaxSubs != 0 && opts.MaxSubs < int(c.acc.msubs) {
c.Errorf("Max Subscriptions set to %d from server config which overrides %d from account claims", opts.MaxSubs, c.acc.msubs)
c.msubs = int32(opts.MaxSubs)
mSubs := int32(opts.MaxSubs)
if mSubs == 0 {
mSubs = jwt.NoLimit
}
wasUnlimited := c.mpay == jwt.NoLimit
if minLimit(&c.mpay, mPay) && !wasUnlimited {
c.Errorf("Max Payload set to %d from server overrides account or user config", opts.MaxPayload)
}
wasUnlimited = c.msubs == jwt.NoLimit
if minLimit(&c.msubs, mSubs) && !wasUnlimited {
c.Errorf("Max Subscriptions set to %d from server overrides account or user config", opts.MaxSubs)
}
if c.subsAtLimit() {
go func() {
c.maxSubsExceeded()

View File

@@ -33,7 +33,9 @@ import (
"crypto/rand"
"crypto/tls"
"github.com/nats-io/jwt/v2"
"github.com/nats-io/nats.go"
"github.com/nats-io/nkeys"
)
type serverInfo struct {
@@ -2460,3 +2462,65 @@ func TestClientConnectionName(t *testing.T) {
})
}
}
func TestClientLimits(t *testing.T) {
accKp, err := nkeys.CreateAccount()
if err != nil {
t.Fatalf("Error creating account key: %v", err)
}
uKp, err := nkeys.CreateUser()
if err != nil {
t.Fatalf("Error creating user key: %v", err)
}
uPub, err := uKp.PublicKey()
if err != nil {
t.Fatalf("Error obtaining publicKey: %v", err)
}
s, err := NewServer(DefaultOptions())
if err != nil {
t.Fatalf("Error creating server: %v", err)
}
for _, test := range []struct {
client int32
acc int32
srv int32
expect int32
}{
// all identical
{1, 1, 1, 1},
{-1, -1, 0, -1},
// only one value unlimited
{1, -1, 0, 1},
{-1, 1, 0, 1},
{-1, -1, 1, 1},
// all combinations of distinct values
{1, 2, 3, 1},
{1, 3, 2, 1},
{2, 1, 3, 1},
{2, 3, 1, 1},
{3, 1, 2, 1},
{3, 2, 1, 1},
} {
t.Run("", func(t *testing.T) {
s.opts.MaxPayload = test.srv
s.opts.MaxSubs = int(test.srv)
c := &client{srv: s, acc: &Account{
limits: limits{mpay: test.acc, msubs: test.acc},
}}
uc := jwt.NewUserClaims(uPub)
uc.Limits.Subs = int64(test.client)
uc.Limits.Payload = int64(test.client)
c.opts.JWT, err = uc.Encode(accKp)
if err != nil {
t.Fatalf("Error encoding jwt: %v", err)
}
c.applyAccountLimits()
if c.mpay != test.expect {
t.Fatalf("payload %d not as ecpected %d", c.mpay, test.expect)
}
if c.msubs != test.expect {
t.Fatalf("subscriber %d not as ecpected %d", c.msubs, test.expect)
}
})
}
}

View File

@@ -3438,3 +3438,64 @@ func TestJWTTimeExpiration(t *testing.T) {
c.Close()
})
}
func TestJWTSubLimits(t *testing.T) {
doNotExpire := time.Now().AddDate(1, 0, 0)
// create account
kp, _ := nkeys.CreateAccount()
aPub, _ := kp.PublicKey()
claim := jwt.NewAccountClaims(aPub)
aJwt, err := claim.Encode(oKp)
require_NoError(t, err)
conf := createConfFile(t, []byte(fmt.Sprintf(`
listen: -1
operator: %s
resolver: MEM
resolver_preload: {
%s: %s
}
`, ojwt, aPub, aJwt)))
defer os.Remove(conf)
sA, _ := RunServerWithConfig(conf)
defer sA.Shutdown()
errChan := make(chan struct{})
defer close(errChan)
t.Run("subs", func(t *testing.T) {
creds := createUserWithLimit(t, kp, doNotExpire, func(j *jwt.Limits) { j.Subs = 1 })
defer os.Remove(creds)
c := natsConnect(t, sA.ClientURL(), nats.UserCredentials(creds),
nats.DisconnectErrHandler(func(conn *nats.Conn, err error) {
if e := conn.LastError(); e != nil && strings.Contains(e.Error(), "maximum subscriptions exceeded") {
errChan <- struct{}{}
}
}),
)
if _, err := c.Subscribe("foo", func(msg *nats.Msg) {}); err != nil {
t.Fatalf("couldn't subscribe: %v", err)
}
if _, err = c.Subscribe("bar", func(msg *nats.Msg) {}); err != nil {
t.Fatalf("expected error got: %v", err)
}
<-errChan
c.Close()
})
t.Run("payload", func(t *testing.T) {
creds := createUserWithLimit(t, kp, doNotExpire, func(j *jwt.Limits) { j.Payload = 5 })
defer os.Remove(creds)
c := natsConnect(t, sA.ClientURL(), nats.UserCredentials(creds),
nats.DisconnectErrHandler(func(conn *nats.Conn, err error) {
if e := conn.LastError(); e != nil && strings.Contains(e.Error(), "Maximum Payload Violation") {
errChan <- struct{}{}
}
}),
)
if err := c.Publish("foo", []byte("world")); err != nil {
t.Fatalf("couldn't publish: %v", err)
}
if err := c.Publish("foo", []byte("worldX")); err != nil {
t.Fatalf("couldn't publish: %v", err)
}
<-errChan
c.Close()
})
}