Limit search depth for account cycles for imports

Signed-off-by: Derek Collison <derek@nats.io>
This commit is contained in:
Derek Collison
2020-12-02 11:44:27 -08:00
parent 9b107c0f4b
commit cddf23c200
4 changed files with 83 additions and 18 deletions

View File

@@ -1346,42 +1346,47 @@ func (a *Account) AddServiceImportWithClaim(destination *Account, from, to strin
}
// Check if this introduces a cycle before proceeding.
if a.serviceImportFormsCycle(destination, from) {
return ErrImportFormsCycle
if err := a.serviceImportFormsCycle(destination, from); err != nil {
return err
}
_, err := a.addServiceImport(destination, from, to, imClaim)
return err
}
func (a *Account) serviceImportFormsCycle(dest *Account, from string) bool {
const MaxAccountCycleSearchDepth = 1024
func (a *Account) serviceImportFormsCycle(dest *Account, from string) error {
return dest.checkServiceImportsForCycles(from, map[string]bool{a.Name: true})
}
func (a *Account) checkServiceImportsForCycles(from string, visited map[string]bool) bool {
func (a *Account) checkServiceImportsForCycles(from string, visited map[string]bool) error {
if len(visited) >= MaxAccountCycleSearchDepth {
return ErrCycleSearchDepth
}
a.mu.RLock()
for _, si := range a.imports.services {
if SubjectsCollide(from, si.to) {
a.mu.RUnlock()
if visited[si.acc.Name] {
return true
return ErrImportFormsCycle
}
// Push ourselves and check si.acc
visited[a.Name] = true
if subjectIsSubsetMatch(si.from, from) {
from = si.from
}
if si.acc.checkServiceImportsForCycles(from, visited) {
return true
if err := si.acc.checkServiceImportsForCycles(from, visited); err != nil {
return err
}
a.mu.RLock()
}
}
a.mu.RUnlock()
return false
return nil
}
func (a *Account) streamImportFormsCycle(dest *Account, to string) bool {
func (a *Account) streamImportFormsCycle(dest *Account, to string) error {
return dest.checkStreamImportsForCycles(to, map[string]bool{a.Name: true})
}
@@ -1395,33 +1400,37 @@ func (a *Account) hasStreamExportMatching(to string) bool {
return false
}
func (a *Account) checkStreamImportsForCycles(to string, visited map[string]bool) bool {
func (a *Account) checkStreamImportsForCycles(to string, visited map[string]bool) error {
if len(visited) >= MaxAccountCycleSearchDepth {
return ErrCycleSearchDepth
}
a.mu.RLock()
if !a.hasStreamExportMatching(to) {
a.mu.RUnlock()
return false
return nil
}
for _, si := range a.imports.streams {
if SubjectsCollide(to, si.to) {
a.mu.RUnlock()
if visited[si.acc.Name] {
return true
return ErrImportFormsCycle
}
// Push ourselves and check si.acc
visited[a.Name] = true
if subjectIsSubsetMatch(si.to, to) {
to = si.to
}
if si.acc.checkStreamImportsForCycles(to, visited) {
return true
if err := si.acc.checkStreamImportsForCycles(to, visited); err != nil {
return err
}
a.mu.RLock()
}
}
a.mu.RUnlock()
return false
return nil
}
// SetServiceImportSharing will allow sharing of information about requests with the export account.
@@ -2220,8 +2229,8 @@ func (a *Account) AddMappedStreamImportWithClaim(account *Account, from, to stri
}
// Check if this forms a cycle.
if a.streamImportFormsCycle(account, to) {
return ErrImportFormsCycle
if err := a.streamImportFormsCycle(account, to); err != nil {
return err
}
var (

View File

@@ -128,6 +128,9 @@ var (
// ErrImportFormsCycle is returned when an import would form a cycle.
ErrImportFormsCycle = errors.New("import forms a cycle")
// ErrCycleSearchDepth is returned when we have exceeded our maximum search depth..
ErrCycleSearchDepth = errors.New("search cycle depth exhausted")
// ErrClientOrRouteConnectedToGatewayPort represents an error condition when
// a client or route attempted to connect to the Gateway port.
ErrClientOrRouteConnectedToGatewayPort = errors.New("attempted to connect to gateway port")

View File

@@ -14,6 +14,7 @@
package test
import (
"fmt"
"os"
"strings"
"testing"
@@ -211,3 +212,55 @@ func TestAccountCycleServiceNonCycleChain(t *testing.T) {
t.Fatalf("Expected no error but got %s", err)
}
}
// Go's stack are infinite sans memory, but not call depth. However its good to limit.
func TestAccountCycleDepthLimit(t *testing.T) {
var last *server.Account
chainLen := server.MaxAccountCycleSearchDepth + 1
// Services
for i := 1; i <= chainLen; i++ {
acc := server.NewAccount(fmt.Sprintf("ACC-%d", i))
if err := acc.AddServiceExport("*", nil); err != nil {
t.Fatalf("Error adding service export to '*': %v", err)
}
if last != nil {
err := acc.AddServiceImport(last, "foo", "foo")
switch i {
case chainLen:
if err != server.ErrCycleSearchDepth {
t.Fatalf("Expected last import to fail with '%v', but got '%v'", server.ErrCycleSearchDepth, err)
}
default:
if err != nil {
t.Fatalf("Error adding service import to 'foo': %v", err)
}
}
}
last = acc
}
last = nil
// Streams
for i := 1; i <= chainLen; i++ {
acc := server.NewAccount(fmt.Sprintf("ACC-%d", i))
if err := acc.AddStreamExport("foo", nil); err != nil {
t.Fatalf("Error adding stream export to '*': %v", err)
}
if last != nil {
err := acc.AddStreamImport(last, "foo", "")
switch i {
case chainLen:
if err != server.ErrCycleSearchDepth {
t.Fatalf("Expected last import to fail with '%v', but got '%v'", server.ErrCycleSearchDepth, err)
}
default:
if err != nil {
t.Fatalf("Error adding stream import to 'foo': %v", err)
}
}
}
last = acc
}
}

View File

@@ -122,7 +122,7 @@ func (sc *supercluster) setupLatencyTracking(t *testing.T, p int) {
t.Fatalf("Error adding latency tracking to 'FOO': %v", err)
}
if err := bar.AddServiceImport(foo, "ngs.usage", "ngs.usage.bar"); err != nil {
t.Fatalf("Error adding latency tracking to 'FOO': %v", err)
t.Fatalf("Error adding service import to 'ngs.usage': %v", err)
}
}
}