From cddf23c2005b4ea8947c188c70cfc9305b57bd1b Mon Sep 17 00:00:00 2001 From: Derek Collison Date: Wed, 2 Dec 2020 11:44:27 -0800 Subject: [PATCH] Limit search depth for account cycles for imports Signed-off-by: Derek Collison --- server/accounts.go | 43 +++++++++++++++++------------ server/errors.go | 3 ++ test/accounts_cycles_test.go | 53 ++++++++++++++++++++++++++++++++++++ test/service_latency_test.go | 2 +- 4 files changed, 83 insertions(+), 18 deletions(-) diff --git a/server/accounts.go b/server/accounts.go index 44987535..159d4b9a 100644 --- a/server/accounts.go +++ b/server/accounts.go @@ -1346,42 +1346,47 @@ func (a *Account) AddServiceImportWithClaim(destination *Account, from, to strin } // Check if this introduces a cycle before proceeding. - if a.serviceImportFormsCycle(destination, from) { - return ErrImportFormsCycle + if err := a.serviceImportFormsCycle(destination, from); err != nil { + return err } _, err := a.addServiceImport(destination, from, to, imClaim) return err } -func (a *Account) serviceImportFormsCycle(dest *Account, from string) bool { +const MaxAccountCycleSearchDepth = 1024 + +func (a *Account) serviceImportFormsCycle(dest *Account, from string) error { return dest.checkServiceImportsForCycles(from, map[string]bool{a.Name: true}) } -func (a *Account) checkServiceImportsForCycles(from string, visited map[string]bool) bool { +func (a *Account) checkServiceImportsForCycles(from string, visited map[string]bool) error { + if len(visited) >= MaxAccountCycleSearchDepth { + return ErrCycleSearchDepth + } a.mu.RLock() for _, si := range a.imports.services { if SubjectsCollide(from, si.to) { a.mu.RUnlock() if visited[si.acc.Name] { - return true + return ErrImportFormsCycle } // Push ourselves and check si.acc visited[a.Name] = true if subjectIsSubsetMatch(si.from, from) { from = si.from } - if si.acc.checkServiceImportsForCycles(from, visited) { - return true + if err := si.acc.checkServiceImportsForCycles(from, visited); err != nil { + return err } a.mu.RLock() } } a.mu.RUnlock() - return false + return nil } -func (a *Account) streamImportFormsCycle(dest *Account, to string) bool { +func (a *Account) streamImportFormsCycle(dest *Account, to string) error { return dest.checkStreamImportsForCycles(to, map[string]bool{a.Name: true}) } @@ -1395,33 +1400,37 @@ func (a *Account) hasStreamExportMatching(to string) bool { return false } -func (a *Account) checkStreamImportsForCycles(to string, visited map[string]bool) bool { +func (a *Account) checkStreamImportsForCycles(to string, visited map[string]bool) error { + if len(visited) >= MaxAccountCycleSearchDepth { + return ErrCycleSearchDepth + } + a.mu.RLock() if !a.hasStreamExportMatching(to) { a.mu.RUnlock() - return false + return nil } for _, si := range a.imports.streams { if SubjectsCollide(to, si.to) { a.mu.RUnlock() if visited[si.acc.Name] { - return true + return ErrImportFormsCycle } // Push ourselves and check si.acc visited[a.Name] = true if subjectIsSubsetMatch(si.to, to) { to = si.to } - if si.acc.checkStreamImportsForCycles(to, visited) { - return true + if err := si.acc.checkStreamImportsForCycles(to, visited); err != nil { + return err } a.mu.RLock() } } a.mu.RUnlock() - return false + return nil } // SetServiceImportSharing will allow sharing of information about requests with the export account. @@ -2220,8 +2229,8 @@ func (a *Account) AddMappedStreamImportWithClaim(account *Account, from, to stri } // Check if this forms a cycle. - if a.streamImportFormsCycle(account, to) { - return ErrImportFormsCycle + if err := a.streamImportFormsCycle(account, to); err != nil { + return err } var ( diff --git a/server/errors.go b/server/errors.go index 480113d0..3610352e 100644 --- a/server/errors.go +++ b/server/errors.go @@ -128,6 +128,9 @@ var ( // ErrImportFormsCycle is returned when an import would form a cycle. ErrImportFormsCycle = errors.New("import forms a cycle") + // ErrCycleSearchDepth is returned when we have exceeded our maximum search depth.. + ErrCycleSearchDepth = errors.New("search cycle depth exhausted") + // ErrClientOrRouteConnectedToGatewayPort represents an error condition when // a client or route attempted to connect to the Gateway port. ErrClientOrRouteConnectedToGatewayPort = errors.New("attempted to connect to gateway port") diff --git a/test/accounts_cycles_test.go b/test/accounts_cycles_test.go index cdcce45b..4974cbae 100644 --- a/test/accounts_cycles_test.go +++ b/test/accounts_cycles_test.go @@ -14,6 +14,7 @@ package test import ( + "fmt" "os" "strings" "testing" @@ -211,3 +212,55 @@ func TestAccountCycleServiceNonCycleChain(t *testing.T) { t.Fatalf("Expected no error but got %s", err) } } + +// Go's stack are infinite sans memory, but not call depth. However its good to limit. +func TestAccountCycleDepthLimit(t *testing.T) { + var last *server.Account + chainLen := server.MaxAccountCycleSearchDepth + 1 + + // Services + for i := 1; i <= chainLen; i++ { + acc := server.NewAccount(fmt.Sprintf("ACC-%d", i)) + if err := acc.AddServiceExport("*", nil); err != nil { + t.Fatalf("Error adding service export to '*': %v", err) + } + if last != nil { + err := acc.AddServiceImport(last, "foo", "foo") + switch i { + case chainLen: + if err != server.ErrCycleSearchDepth { + t.Fatalf("Expected last import to fail with '%v', but got '%v'", server.ErrCycleSearchDepth, err) + } + default: + if err != nil { + t.Fatalf("Error adding service import to 'foo': %v", err) + } + } + } + last = acc + } + + last = nil + + // Streams + for i := 1; i <= chainLen; i++ { + acc := server.NewAccount(fmt.Sprintf("ACC-%d", i)) + if err := acc.AddStreamExport("foo", nil); err != nil { + t.Fatalf("Error adding stream export to '*': %v", err) + } + if last != nil { + err := acc.AddStreamImport(last, "foo", "") + switch i { + case chainLen: + if err != server.ErrCycleSearchDepth { + t.Fatalf("Expected last import to fail with '%v', but got '%v'", server.ErrCycleSearchDepth, err) + } + default: + if err != nil { + t.Fatalf("Error adding stream import to 'foo': %v", err) + } + } + } + last = acc + } +} diff --git a/test/service_latency_test.go b/test/service_latency_test.go index 00287a20..33225b16 100644 --- a/test/service_latency_test.go +++ b/test/service_latency_test.go @@ -122,7 +122,7 @@ func (sc *supercluster) setupLatencyTracking(t *testing.T, p int) { t.Fatalf("Error adding latency tracking to 'FOO': %v", err) } if err := bar.AddServiceImport(foo, "ngs.usage", "ngs.usage.bar"); err != nil { - t.Fatalf("Error adding latency tracking to 'FOO': %v", err) + t.Fatalf("Error adding service import to 'ngs.usage': %v", err) } } }