mirror of
https://github.com/gogrlx/nats-server.git
synced 2026-04-02 03:38:42 -07:00
Limit search depth for account cycles for imports
Signed-off-by: Derek Collison <derek@nats.io>
This commit is contained in:
@@ -1346,42 +1346,47 @@ func (a *Account) AddServiceImportWithClaim(destination *Account, from, to strin
|
||||
}
|
||||
|
||||
// Check if this introduces a cycle before proceeding.
|
||||
if a.serviceImportFormsCycle(destination, from) {
|
||||
return ErrImportFormsCycle
|
||||
if err := a.serviceImportFormsCycle(destination, from); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err := a.addServiceImport(destination, from, to, imClaim)
|
||||
return err
|
||||
}
|
||||
|
||||
func (a *Account) serviceImportFormsCycle(dest *Account, from string) bool {
|
||||
const MaxAccountCycleSearchDepth = 1024
|
||||
|
||||
func (a *Account) serviceImportFormsCycle(dest *Account, from string) error {
|
||||
return dest.checkServiceImportsForCycles(from, map[string]bool{a.Name: true})
|
||||
}
|
||||
|
||||
func (a *Account) checkServiceImportsForCycles(from string, visited map[string]bool) bool {
|
||||
func (a *Account) checkServiceImportsForCycles(from string, visited map[string]bool) error {
|
||||
if len(visited) >= MaxAccountCycleSearchDepth {
|
||||
return ErrCycleSearchDepth
|
||||
}
|
||||
a.mu.RLock()
|
||||
for _, si := range a.imports.services {
|
||||
if SubjectsCollide(from, si.to) {
|
||||
a.mu.RUnlock()
|
||||
if visited[si.acc.Name] {
|
||||
return true
|
||||
return ErrImportFormsCycle
|
||||
}
|
||||
// Push ourselves and check si.acc
|
||||
visited[a.Name] = true
|
||||
if subjectIsSubsetMatch(si.from, from) {
|
||||
from = si.from
|
||||
}
|
||||
if si.acc.checkServiceImportsForCycles(from, visited) {
|
||||
return true
|
||||
if err := si.acc.checkServiceImportsForCycles(from, visited); err != nil {
|
||||
return err
|
||||
}
|
||||
a.mu.RLock()
|
||||
}
|
||||
}
|
||||
a.mu.RUnlock()
|
||||
return false
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *Account) streamImportFormsCycle(dest *Account, to string) bool {
|
||||
func (a *Account) streamImportFormsCycle(dest *Account, to string) error {
|
||||
return dest.checkStreamImportsForCycles(to, map[string]bool{a.Name: true})
|
||||
}
|
||||
|
||||
@@ -1395,33 +1400,37 @@ func (a *Account) hasStreamExportMatching(to string) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (a *Account) checkStreamImportsForCycles(to string, visited map[string]bool) bool {
|
||||
func (a *Account) checkStreamImportsForCycles(to string, visited map[string]bool) error {
|
||||
if len(visited) >= MaxAccountCycleSearchDepth {
|
||||
return ErrCycleSearchDepth
|
||||
}
|
||||
|
||||
a.mu.RLock()
|
||||
|
||||
if !a.hasStreamExportMatching(to) {
|
||||
a.mu.RUnlock()
|
||||
return false
|
||||
return nil
|
||||
}
|
||||
|
||||
for _, si := range a.imports.streams {
|
||||
if SubjectsCollide(to, si.to) {
|
||||
a.mu.RUnlock()
|
||||
if visited[si.acc.Name] {
|
||||
return true
|
||||
return ErrImportFormsCycle
|
||||
}
|
||||
// Push ourselves and check si.acc
|
||||
visited[a.Name] = true
|
||||
if subjectIsSubsetMatch(si.to, to) {
|
||||
to = si.to
|
||||
}
|
||||
if si.acc.checkStreamImportsForCycles(to, visited) {
|
||||
return true
|
||||
if err := si.acc.checkStreamImportsForCycles(to, visited); err != nil {
|
||||
return err
|
||||
}
|
||||
a.mu.RLock()
|
||||
}
|
||||
}
|
||||
a.mu.RUnlock()
|
||||
return false
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetServiceImportSharing will allow sharing of information about requests with the export account.
|
||||
@@ -2220,8 +2229,8 @@ func (a *Account) AddMappedStreamImportWithClaim(account *Account, from, to stri
|
||||
}
|
||||
|
||||
// Check if this forms a cycle.
|
||||
if a.streamImportFormsCycle(account, to) {
|
||||
return ErrImportFormsCycle
|
||||
if err := a.streamImportFormsCycle(account, to); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var (
|
||||
|
||||
@@ -128,6 +128,9 @@ var (
|
||||
// ErrImportFormsCycle is returned when an import would form a cycle.
|
||||
ErrImportFormsCycle = errors.New("import forms a cycle")
|
||||
|
||||
// ErrCycleSearchDepth is returned when we have exceeded our maximum search depth..
|
||||
ErrCycleSearchDepth = errors.New("search cycle depth exhausted")
|
||||
|
||||
// ErrClientOrRouteConnectedToGatewayPort represents an error condition when
|
||||
// a client or route attempted to connect to the Gateway port.
|
||||
ErrClientOrRouteConnectedToGatewayPort = errors.New("attempted to connect to gateway port")
|
||||
|
||||
@@ -14,6 +14,7 @@
|
||||
package test
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
@@ -211,3 +212,55 @@ func TestAccountCycleServiceNonCycleChain(t *testing.T) {
|
||||
t.Fatalf("Expected no error but got %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Go's stack are infinite sans memory, but not call depth. However its good to limit.
|
||||
func TestAccountCycleDepthLimit(t *testing.T) {
|
||||
var last *server.Account
|
||||
chainLen := server.MaxAccountCycleSearchDepth + 1
|
||||
|
||||
// Services
|
||||
for i := 1; i <= chainLen; i++ {
|
||||
acc := server.NewAccount(fmt.Sprintf("ACC-%d", i))
|
||||
if err := acc.AddServiceExport("*", nil); err != nil {
|
||||
t.Fatalf("Error adding service export to '*': %v", err)
|
||||
}
|
||||
if last != nil {
|
||||
err := acc.AddServiceImport(last, "foo", "foo")
|
||||
switch i {
|
||||
case chainLen:
|
||||
if err != server.ErrCycleSearchDepth {
|
||||
t.Fatalf("Expected last import to fail with '%v', but got '%v'", server.ErrCycleSearchDepth, err)
|
||||
}
|
||||
default:
|
||||
if err != nil {
|
||||
t.Fatalf("Error adding service import to 'foo': %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
last = acc
|
||||
}
|
||||
|
||||
last = nil
|
||||
|
||||
// Streams
|
||||
for i := 1; i <= chainLen; i++ {
|
||||
acc := server.NewAccount(fmt.Sprintf("ACC-%d", i))
|
||||
if err := acc.AddStreamExport("foo", nil); err != nil {
|
||||
t.Fatalf("Error adding stream export to '*': %v", err)
|
||||
}
|
||||
if last != nil {
|
||||
err := acc.AddStreamImport(last, "foo", "")
|
||||
switch i {
|
||||
case chainLen:
|
||||
if err != server.ErrCycleSearchDepth {
|
||||
t.Fatalf("Expected last import to fail with '%v', but got '%v'", server.ErrCycleSearchDepth, err)
|
||||
}
|
||||
default:
|
||||
if err != nil {
|
||||
t.Fatalf("Error adding stream import to 'foo': %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
last = acc
|
||||
}
|
||||
}
|
||||
|
||||
@@ -122,7 +122,7 @@ func (sc *supercluster) setupLatencyTracking(t *testing.T, p int) {
|
||||
t.Fatalf("Error adding latency tracking to 'FOO': %v", err)
|
||||
}
|
||||
if err := bar.AddServiceImport(foo, "ngs.usage", "ngs.usage.bar"); err != nil {
|
||||
t.Fatalf("Error adding latency tracking to 'FOO': %v", err)
|
||||
t.Fatalf("Error adding service import to 'ngs.usage': %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user