Merge pull request #1570 from nats-io/limits

Adding user jwt payload and subscriber limits
This commit is contained in:
Ivan Kozlovic
2020-08-25 09:59:51 -06:00
committed by GitHub
3 changed files with 182 additions and 32 deletions

View File

@@ -588,6 +588,21 @@ func (c *client) subsAtLimit() bool {
return c.msubs != jwt.NoLimit && len(c.subs) >= int(c.msubs)
}
func minLimit(value *int32, limit int32) bool {
if *value != jwt.NoLimit {
if limit != jwt.NoLimit {
if limit < *value {
*value = limit
return true
}
}
} else if limit != jwt.NoLimit {
*value = limit
return true
}
return false
}
// Apply account limits
// Lock is held on entry.
// FIXME(dlc) - Should server be able to override here?
@@ -595,30 +610,35 @@ func (c *client) applyAccountLimits() {
if c.acc == nil || (c.kind != CLIENT && c.kind != LEAF) {
return
}
// Set here, will need to fo checks for NoLimit.
if c.acc.msubs != jwt.NoLimit {
c.msubs = c.acc.msubs
c.mpay = jwt.NoLimit
c.msubs = jwt.NoLimit
if c.opts.JWT != "" { // user jwt implies account
if uc, _ := jwt.DecodeUserClaims(c.opts.JWT); uc != nil {
c.mpay = int32(uc.Limits.Payload)
c.msubs = int32(uc.Limits.Subs)
}
}
if c.acc.mpay != jwt.NoLimit {
c.mpay = c.acc.mpay
}
minLimit(&c.mpay, c.acc.mpay)
minLimit(&c.msubs, c.acc.msubs)
s := c.srv
opts := s.getOpts()
// We check here if the server has an option set that is lower than the account limit.
if c.mpay != jwt.NoLimit && opts.MaxPayload != 0 && int32(opts.MaxPayload) < c.acc.mpay {
c.Errorf("Max Payload set to %d from server config which overrides %d from account claims", opts.MaxPayload, c.acc.mpay)
c.mpay = int32(opts.MaxPayload)
mPay := opts.MaxPayload
// options encode unlimited differently
if mPay == 0 {
mPay = jwt.NoLimit
}
// We check here if the server has an option set that is lower than the account limit.
if c.msubs != jwt.NoLimit && opts.MaxSubs != 0 && opts.MaxSubs < int(c.acc.msubs) {
c.Errorf("Max Subscriptions set to %d from server config which overrides %d from account claims", opts.MaxSubs, c.acc.msubs)
c.msubs = int32(opts.MaxSubs)
mSubs := int32(opts.MaxSubs)
if mSubs == 0 {
mSubs = jwt.NoLimit
}
wasUnlimited := c.mpay == jwt.NoLimit
if minLimit(&c.mpay, mPay) && !wasUnlimited {
c.Errorf("Max Payload set to %d from server overrides account or user config", opts.MaxPayload)
}
wasUnlimited = c.msubs == jwt.NoLimit
if minLimit(&c.msubs, mSubs) && !wasUnlimited {
c.Errorf("Max Subscriptions set to %d from server overrides account or user config", opts.MaxSubs)
}
if c.subsAtLimit() {
go func() {
c.maxSubsExceeded()

View File

@@ -33,7 +33,9 @@ import (
"crypto/rand"
"crypto/tls"
"github.com/nats-io/jwt/v2"
"github.com/nats-io/nats.go"
"github.com/nats-io/nkeys"
)
type serverInfo struct {
@@ -2460,3 +2462,65 @@ func TestClientConnectionName(t *testing.T) {
})
}
}
func TestClientLimits(t *testing.T) {
accKp, err := nkeys.CreateAccount()
if err != nil {
t.Fatalf("Error creating account key: %v", err)
}
uKp, err := nkeys.CreateUser()
if err != nil {
t.Fatalf("Error creating user key: %v", err)
}
uPub, err := uKp.PublicKey()
if err != nil {
t.Fatalf("Error obtaining publicKey: %v", err)
}
s, err := NewServer(DefaultOptions())
if err != nil {
t.Fatalf("Error creating server: %v", err)
}
for _, test := range []struct {
client int32
acc int32
srv int32
expect int32
}{
// all identical
{1, 1, 1, 1},
{-1, -1, 0, -1},
// only one value unlimited
{1, -1, 0, 1},
{-1, 1, 0, 1},
{-1, -1, 1, 1},
// all combinations of distinct values
{1, 2, 3, 1},
{1, 3, 2, 1},
{2, 1, 3, 1},
{2, 3, 1, 1},
{3, 1, 2, 1},
{3, 2, 1, 1},
} {
t.Run("", func(t *testing.T) {
s.opts.MaxPayload = test.srv
s.opts.MaxSubs = int(test.srv)
c := &client{srv: s, acc: &Account{
limits: limits{mpay: test.acc, msubs: test.acc},
}}
uc := jwt.NewUserClaims(uPub)
uc.Limits.Subs = int64(test.client)
uc.Limits.Payload = int64(test.client)
c.opts.JWT, err = uc.Encode(accKp)
if err != nil {
t.Fatalf("Error encoding jwt: %v", err)
}
c.applyAccountLimits()
if c.mpay != test.expect {
t.Fatalf("payload %d not as ecpected %d", c.mpay, test.expect)
}
if c.msubs != test.expect {
t.Fatalf("subscriber %d not as ecpected %d", c.msubs, test.expect)
}
})
}
}

View File

@@ -52,6 +52,15 @@ func init() {
}
}
func chanRecv(t *testing.T, recvChan <-chan struct{}, limit time.Duration) {
t.Helper()
select {
case <-recvChan:
case <-time.After(limit):
t.Fatal("Should have received from channel")
}
}
func opTrustBasicSetup() *Server {
kp, _ := nkeys.FromSeed(oSeed)
pub, _ := kp.PublicKey()
@@ -1844,8 +1853,8 @@ func TestAccountURLResolverFetchFailureInCluster(t *testing.T) {
// startup cluster
checkClusterFormed(t, sA, sB)
// Both server observed one fetch on startup
<-chanImpA
<-chanImpB
chanRecv(t, chanImpA, 10*time.Second)
chanRecv(t, chanImpB, 10*time.Second)
assertChanLen(0, chanImpA, chanImpB, chanExpA, chanExpB)
// Create first client, directly connects to A
urlA := fmt.Sprintf("nats://%s:%d", sA.opts.Host, sA.opts.Port)
@@ -1870,8 +1879,8 @@ func TestAccountURLResolverFetchFailureInCluster(t *testing.T) {
}
defer subA.Unsubscribe()
// Connect of client triggered a fetch by Server A
<-chanImpA
<-chanExpA
chanRecv(t, chanImpA, 10*time.Second)
chanRecv(t, chanExpA, 10*time.Second)
assertChanLen(0, chanImpA, chanImpB, chanExpA, chanExpB)
//time.Sleep(10 * time.Second)
// create second client, directly connect to B
@@ -1882,8 +1891,8 @@ func TestAccountURLResolverFetchFailureInCluster(t *testing.T) {
}
defer ncB.Close()
// Connect of client triggered a fetch by Server B
<-chanImpB
<-chanExpB
chanRecv(t, chanImpB, 10*time.Second)
chanRecv(t, chanExpB, 10*time.Second)
assertChanLen(0, chanImpA, chanImpB, chanExpA, chanExpB)
checkClusterFormed(t, sA, sB)
// the route subscription was lost due to the failed fetch
@@ -3358,8 +3367,8 @@ func TestJWTTimeExpiration(t *testing.T) {
errChan <- struct{}{}
}
}))
<-errChan
<-disconnectChan
chanRecv(t, errChan, 10*time.Second)
chanRecv(t, disconnectChan, 10*time.Second)
require_True(t, c.IsReconnecting())
require_False(t, c.IsConnected())
c.Close()
@@ -3397,11 +3406,11 @@ func TestJWTTimeExpiration(t *testing.T) {
errChan <- struct{}{}
}
}))
<-errChan
<-reConnectChan
chanRecv(t, errChan, 10*time.Second)
chanRecv(t, reConnectChan, 10*time.Second)
require_False(t, c.IsReconnecting())
require_True(t, c.IsConnected())
<-errChan
chanRecv(t, errChan, 10*time.Second)
c.Close()
})
t.Run("lower jwt expiration overwrites time", func(t *testing.T) {
@@ -3431,10 +3440,67 @@ func TestJWTTimeExpiration(t *testing.T) {
errChan <- struct{}{}
}
}))
<-errChan
<-disconnectChan
chanRecv(t, errChan, 10*time.Second)
chanRecv(t, disconnectChan, 10*time.Second)
require_True(t, c.IsReconnecting())
require_False(t, c.IsConnected())
c.Close()
})
}
func TestJWTLimits(t *testing.T) {
doNotExpire := time.Now().AddDate(1, 0, 0)
// create account
kp, _ := nkeys.CreateAccount()
aPub, _ := kp.PublicKey()
claim := jwt.NewAccountClaims(aPub)
aJwt, err := claim.Encode(oKp)
require_NoError(t, err)
conf := createConfFile(t, []byte(fmt.Sprintf(`
listen: -1
operator: %s
resolver: MEM
resolver_preload: {
%s: %s
}
`, ojwt, aPub, aJwt)))
defer os.Remove(conf)
sA, _ := RunServerWithConfig(conf)
defer sA.Shutdown()
errChan := make(chan struct{})
defer close(errChan)
t.Run("subs", func(t *testing.T) {
creds := createUserWithLimit(t, kp, doNotExpire, func(j *jwt.Limits) { j.Subs = 1 })
defer os.Remove(creds)
c := natsConnect(t, sA.ClientURL(), nats.UserCredentials(creds),
nats.DisconnectErrHandler(func(conn *nats.Conn, err error) {
if e := conn.LastError(); e != nil && strings.Contains(e.Error(), "maximum subscriptions exceeded") {
errChan <- struct{}{}
}
}),
)
defer c.Close()
if _, err := c.Subscribe("foo", func(msg *nats.Msg) {}); err != nil {
t.Fatalf("couldn't subscribe: %v", err)
}
if _, err = c.Subscribe("bar", func(msg *nats.Msg) {}); err != nil {
t.Fatalf("expected error got: %v", err)
}
chanRecv(t, errChan, time.Second)
})
t.Run("payload", func(t *testing.T) {
creds := createUserWithLimit(t, kp, doNotExpire, func(j *jwt.Limits) { j.Payload = 5 })
defer os.Remove(creds)
c := natsConnect(t, sA.ClientURL(), nats.UserCredentials(creds))
defer c.Close()
if err := c.Flush(); err != nil {
t.Fatalf("flush failed %v", err)
}
if err := c.Publish("foo", []byte("world")); err != nil {
t.Fatalf("couldn't publish: %v", err)
}
if err := c.Publish("foo", []byte("worldX")); err != nats.ErrMaxPayload {
t.Fatalf("couldn't publish: %v", err)
}
})
}