mirror of
https://github.com/gogrlx/nats-server.git
synced 2026-04-16 19:14:41 -07:00
Merge pull request #1570 from nats-io/limits
Adding user jwt payload and subscriber limits
This commit is contained in:
@@ -588,6 +588,21 @@ func (c *client) subsAtLimit() bool {
|
||||
return c.msubs != jwt.NoLimit && len(c.subs) >= int(c.msubs)
|
||||
}
|
||||
|
||||
func minLimit(value *int32, limit int32) bool {
|
||||
if *value != jwt.NoLimit {
|
||||
if limit != jwt.NoLimit {
|
||||
if limit < *value {
|
||||
*value = limit
|
||||
return true
|
||||
}
|
||||
}
|
||||
} else if limit != jwt.NoLimit {
|
||||
*value = limit
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// Apply account limits
|
||||
// Lock is held on entry.
|
||||
// FIXME(dlc) - Should server be able to override here?
|
||||
@@ -595,30 +610,35 @@ func (c *client) applyAccountLimits() {
|
||||
if c.acc == nil || (c.kind != CLIENT && c.kind != LEAF) {
|
||||
return
|
||||
}
|
||||
|
||||
// Set here, will need to fo checks for NoLimit.
|
||||
if c.acc.msubs != jwt.NoLimit {
|
||||
c.msubs = c.acc.msubs
|
||||
c.mpay = jwt.NoLimit
|
||||
c.msubs = jwt.NoLimit
|
||||
if c.opts.JWT != "" { // user jwt implies account
|
||||
if uc, _ := jwt.DecodeUserClaims(c.opts.JWT); uc != nil {
|
||||
c.mpay = int32(uc.Limits.Payload)
|
||||
c.msubs = int32(uc.Limits.Subs)
|
||||
}
|
||||
}
|
||||
if c.acc.mpay != jwt.NoLimit {
|
||||
c.mpay = c.acc.mpay
|
||||
}
|
||||
|
||||
minLimit(&c.mpay, c.acc.mpay)
|
||||
minLimit(&c.msubs, c.acc.msubs)
|
||||
s := c.srv
|
||||
opts := s.getOpts()
|
||||
|
||||
// We check here if the server has an option set that is lower than the account limit.
|
||||
if c.mpay != jwt.NoLimit && opts.MaxPayload != 0 && int32(opts.MaxPayload) < c.acc.mpay {
|
||||
c.Errorf("Max Payload set to %d from server config which overrides %d from account claims", opts.MaxPayload, c.acc.mpay)
|
||||
c.mpay = int32(opts.MaxPayload)
|
||||
mPay := opts.MaxPayload
|
||||
// options encode unlimited differently
|
||||
if mPay == 0 {
|
||||
mPay = jwt.NoLimit
|
||||
}
|
||||
|
||||
// We check here if the server has an option set that is lower than the account limit.
|
||||
if c.msubs != jwt.NoLimit && opts.MaxSubs != 0 && opts.MaxSubs < int(c.acc.msubs) {
|
||||
c.Errorf("Max Subscriptions set to %d from server config which overrides %d from account claims", opts.MaxSubs, c.acc.msubs)
|
||||
c.msubs = int32(opts.MaxSubs)
|
||||
mSubs := int32(opts.MaxSubs)
|
||||
if mSubs == 0 {
|
||||
mSubs = jwt.NoLimit
|
||||
}
|
||||
wasUnlimited := c.mpay == jwt.NoLimit
|
||||
if minLimit(&c.mpay, mPay) && !wasUnlimited {
|
||||
c.Errorf("Max Payload set to %d from server overrides account or user config", opts.MaxPayload)
|
||||
}
|
||||
wasUnlimited = c.msubs == jwt.NoLimit
|
||||
if minLimit(&c.msubs, mSubs) && !wasUnlimited {
|
||||
c.Errorf("Max Subscriptions set to %d from server overrides account or user config", opts.MaxSubs)
|
||||
}
|
||||
|
||||
if c.subsAtLimit() {
|
||||
go func() {
|
||||
c.maxSubsExceeded()
|
||||
|
||||
@@ -33,7 +33,9 @@ import (
|
||||
"crypto/rand"
|
||||
"crypto/tls"
|
||||
|
||||
"github.com/nats-io/jwt/v2"
|
||||
"github.com/nats-io/nats.go"
|
||||
"github.com/nats-io/nkeys"
|
||||
)
|
||||
|
||||
type serverInfo struct {
|
||||
@@ -2460,3 +2462,65 @@ func TestClientConnectionName(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestClientLimits(t *testing.T) {
|
||||
accKp, err := nkeys.CreateAccount()
|
||||
if err != nil {
|
||||
t.Fatalf("Error creating account key: %v", err)
|
||||
}
|
||||
uKp, err := nkeys.CreateUser()
|
||||
if err != nil {
|
||||
t.Fatalf("Error creating user key: %v", err)
|
||||
}
|
||||
uPub, err := uKp.PublicKey()
|
||||
if err != nil {
|
||||
t.Fatalf("Error obtaining publicKey: %v", err)
|
||||
}
|
||||
s, err := NewServer(DefaultOptions())
|
||||
if err != nil {
|
||||
t.Fatalf("Error creating server: %v", err)
|
||||
}
|
||||
for _, test := range []struct {
|
||||
client int32
|
||||
acc int32
|
||||
srv int32
|
||||
expect int32
|
||||
}{
|
||||
// all identical
|
||||
{1, 1, 1, 1},
|
||||
{-1, -1, 0, -1},
|
||||
// only one value unlimited
|
||||
{1, -1, 0, 1},
|
||||
{-1, 1, 0, 1},
|
||||
{-1, -1, 1, 1},
|
||||
// all combinations of distinct values
|
||||
{1, 2, 3, 1},
|
||||
{1, 3, 2, 1},
|
||||
{2, 1, 3, 1},
|
||||
{2, 3, 1, 1},
|
||||
{3, 1, 2, 1},
|
||||
{3, 2, 1, 1},
|
||||
} {
|
||||
t.Run("", func(t *testing.T) {
|
||||
s.opts.MaxPayload = test.srv
|
||||
s.opts.MaxSubs = int(test.srv)
|
||||
c := &client{srv: s, acc: &Account{
|
||||
limits: limits{mpay: test.acc, msubs: test.acc},
|
||||
}}
|
||||
uc := jwt.NewUserClaims(uPub)
|
||||
uc.Limits.Subs = int64(test.client)
|
||||
uc.Limits.Payload = int64(test.client)
|
||||
c.opts.JWT, err = uc.Encode(accKp)
|
||||
if err != nil {
|
||||
t.Fatalf("Error encoding jwt: %v", err)
|
||||
}
|
||||
c.applyAccountLimits()
|
||||
if c.mpay != test.expect {
|
||||
t.Fatalf("payload %d not as ecpected %d", c.mpay, test.expect)
|
||||
}
|
||||
if c.msubs != test.expect {
|
||||
t.Fatalf("subscriber %d not as ecpected %d", c.msubs, test.expect)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -52,6 +52,15 @@ func init() {
|
||||
}
|
||||
}
|
||||
|
||||
func chanRecv(t *testing.T, recvChan <-chan struct{}, limit time.Duration) {
|
||||
t.Helper()
|
||||
select {
|
||||
case <-recvChan:
|
||||
case <-time.After(limit):
|
||||
t.Fatal("Should have received from channel")
|
||||
}
|
||||
}
|
||||
|
||||
func opTrustBasicSetup() *Server {
|
||||
kp, _ := nkeys.FromSeed(oSeed)
|
||||
pub, _ := kp.PublicKey()
|
||||
@@ -1844,8 +1853,8 @@ func TestAccountURLResolverFetchFailureInCluster(t *testing.T) {
|
||||
// startup cluster
|
||||
checkClusterFormed(t, sA, sB)
|
||||
// Both server observed one fetch on startup
|
||||
<-chanImpA
|
||||
<-chanImpB
|
||||
chanRecv(t, chanImpA, 10*time.Second)
|
||||
chanRecv(t, chanImpB, 10*time.Second)
|
||||
assertChanLen(0, chanImpA, chanImpB, chanExpA, chanExpB)
|
||||
// Create first client, directly connects to A
|
||||
urlA := fmt.Sprintf("nats://%s:%d", sA.opts.Host, sA.opts.Port)
|
||||
@@ -1870,8 +1879,8 @@ func TestAccountURLResolverFetchFailureInCluster(t *testing.T) {
|
||||
}
|
||||
defer subA.Unsubscribe()
|
||||
// Connect of client triggered a fetch by Server A
|
||||
<-chanImpA
|
||||
<-chanExpA
|
||||
chanRecv(t, chanImpA, 10*time.Second)
|
||||
chanRecv(t, chanExpA, 10*time.Second)
|
||||
assertChanLen(0, chanImpA, chanImpB, chanExpA, chanExpB)
|
||||
//time.Sleep(10 * time.Second)
|
||||
// create second client, directly connect to B
|
||||
@@ -1882,8 +1891,8 @@ func TestAccountURLResolverFetchFailureInCluster(t *testing.T) {
|
||||
}
|
||||
defer ncB.Close()
|
||||
// Connect of client triggered a fetch by Server B
|
||||
<-chanImpB
|
||||
<-chanExpB
|
||||
chanRecv(t, chanImpB, 10*time.Second)
|
||||
chanRecv(t, chanExpB, 10*time.Second)
|
||||
assertChanLen(0, chanImpA, chanImpB, chanExpA, chanExpB)
|
||||
checkClusterFormed(t, sA, sB)
|
||||
// the route subscription was lost due to the failed fetch
|
||||
@@ -3358,8 +3367,8 @@ func TestJWTTimeExpiration(t *testing.T) {
|
||||
errChan <- struct{}{}
|
||||
}
|
||||
}))
|
||||
<-errChan
|
||||
<-disconnectChan
|
||||
chanRecv(t, errChan, 10*time.Second)
|
||||
chanRecv(t, disconnectChan, 10*time.Second)
|
||||
require_True(t, c.IsReconnecting())
|
||||
require_False(t, c.IsConnected())
|
||||
c.Close()
|
||||
@@ -3397,11 +3406,11 @@ func TestJWTTimeExpiration(t *testing.T) {
|
||||
errChan <- struct{}{}
|
||||
}
|
||||
}))
|
||||
<-errChan
|
||||
<-reConnectChan
|
||||
chanRecv(t, errChan, 10*time.Second)
|
||||
chanRecv(t, reConnectChan, 10*time.Second)
|
||||
require_False(t, c.IsReconnecting())
|
||||
require_True(t, c.IsConnected())
|
||||
<-errChan
|
||||
chanRecv(t, errChan, 10*time.Second)
|
||||
c.Close()
|
||||
})
|
||||
t.Run("lower jwt expiration overwrites time", func(t *testing.T) {
|
||||
@@ -3431,10 +3440,67 @@ func TestJWTTimeExpiration(t *testing.T) {
|
||||
errChan <- struct{}{}
|
||||
}
|
||||
}))
|
||||
<-errChan
|
||||
<-disconnectChan
|
||||
chanRecv(t, errChan, 10*time.Second)
|
||||
chanRecv(t, disconnectChan, 10*time.Second)
|
||||
require_True(t, c.IsReconnecting())
|
||||
require_False(t, c.IsConnected())
|
||||
c.Close()
|
||||
})
|
||||
}
|
||||
|
||||
func TestJWTLimits(t *testing.T) {
|
||||
doNotExpire := time.Now().AddDate(1, 0, 0)
|
||||
// create account
|
||||
kp, _ := nkeys.CreateAccount()
|
||||
aPub, _ := kp.PublicKey()
|
||||
claim := jwt.NewAccountClaims(aPub)
|
||||
aJwt, err := claim.Encode(oKp)
|
||||
require_NoError(t, err)
|
||||
conf := createConfFile(t, []byte(fmt.Sprintf(`
|
||||
listen: -1
|
||||
operator: %s
|
||||
resolver: MEM
|
||||
resolver_preload: {
|
||||
%s: %s
|
||||
}
|
||||
`, ojwt, aPub, aJwt)))
|
||||
defer os.Remove(conf)
|
||||
sA, _ := RunServerWithConfig(conf)
|
||||
defer sA.Shutdown()
|
||||
errChan := make(chan struct{})
|
||||
defer close(errChan)
|
||||
t.Run("subs", func(t *testing.T) {
|
||||
creds := createUserWithLimit(t, kp, doNotExpire, func(j *jwt.Limits) { j.Subs = 1 })
|
||||
defer os.Remove(creds)
|
||||
c := natsConnect(t, sA.ClientURL(), nats.UserCredentials(creds),
|
||||
nats.DisconnectErrHandler(func(conn *nats.Conn, err error) {
|
||||
if e := conn.LastError(); e != nil && strings.Contains(e.Error(), "maximum subscriptions exceeded") {
|
||||
errChan <- struct{}{}
|
||||
}
|
||||
}),
|
||||
)
|
||||
defer c.Close()
|
||||
if _, err := c.Subscribe("foo", func(msg *nats.Msg) {}); err != nil {
|
||||
t.Fatalf("couldn't subscribe: %v", err)
|
||||
}
|
||||
if _, err = c.Subscribe("bar", func(msg *nats.Msg) {}); err != nil {
|
||||
t.Fatalf("expected error got: %v", err)
|
||||
}
|
||||
chanRecv(t, errChan, time.Second)
|
||||
})
|
||||
t.Run("payload", func(t *testing.T) {
|
||||
creds := createUserWithLimit(t, kp, doNotExpire, func(j *jwt.Limits) { j.Payload = 5 })
|
||||
defer os.Remove(creds)
|
||||
c := natsConnect(t, sA.ClientURL(), nats.UserCredentials(creds))
|
||||
defer c.Close()
|
||||
if err := c.Flush(); err != nil {
|
||||
t.Fatalf("flush failed %v", err)
|
||||
}
|
||||
if err := c.Publish("foo", []byte("world")); err != nil {
|
||||
t.Fatalf("couldn't publish: %v", err)
|
||||
}
|
||||
if err := c.Publish("foo", []byte("worldX")); err != nats.ErrMaxPayload {
|
||||
t.Fatalf("couldn't publish: %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user