From 34ce97bb8c53cca13ddecdc2a5d092d00c90a246 Mon Sep 17 00:00:00 2001 From: Derek Collison Date: Thu, 16 Jan 2020 20:17:57 -0800 Subject: [PATCH] Added support for wildcards for service imports Signed-off-by: Derek Collison --- server/accounts.go | 45 +++++++++++++----- server/accounts_test.go | 100 ++++++++++++++++++++++++++++++++++++++++ server/client.go | 29 +++++++++--- 3 files changed, 156 insertions(+), 18 deletions(-) diff --git a/server/accounts.go b/server/accounts.go index d4b50fa6..83b9ef37 100644 --- a/server/accounts.go +++ b/server/accounts.go @@ -1,4 +1,4 @@ -// Copyright 2018-2019 The NATS Authors +// Copyright 2018-2020 The NATS Authors // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at @@ -182,6 +182,7 @@ type exportMap struct { type importMap struct { streams []*streamImport services map[string]*serviceImport // TODO(dlc) sync.Map may be better. + hasWC bool // This is for service import wildcards. } // NewAccount creates a new unlimited account with the given name. @@ -831,13 +832,15 @@ func (a *Account) AddServiceImportWithClaim(destination *Account, from, to strin if destination == nil { return ErrMissingAccount } + // Empty means use from. Also means we can use wildcards since we are not doing remapping. + if !IsValidSubject(from) || (to != "" && (!IsValidLiteralSubject(from) || !IsValidLiteralSubject(to))) { + return ErrInvalidSubject + } + // Empty means use from. if to == "" { to = from } - if !IsValidLiteralSubject(from) || !IsValidLiteralSubject(to) { - return ErrInvalidSubject - } // First check to see if the account has authorized us to route to the "to" subject. if !destination.checkServiceImportAuthorized(a, to, imClaim) { return ErrServiceImportAuthorization @@ -865,11 +868,25 @@ func (a *Account) NumServiceImports() int { // removeServiceImport will remove the route by subject. func (a *Account) removeServiceImport(subject string) { a.mu.Lock() + si, ok := a.imports.services[subject] - if ok && si != nil && si.ae { - a.nae-- - } delete(a.imports.services, subject) + + if ok && si != nil { + if si.ae { + a.nae-- + } + if a.imports.hasWC && subjectHasWildcard(subject) { + // Need to still make sure we have a wildcard entry. + a.imports.hasWC = false + for subject, _ := range a.imports.services { + if subjectHasWildcard(subject) { + a.imports.hasWC = true + break + } + } + } + } a.mu.Unlock() } @@ -997,7 +1014,7 @@ func (a *Account) SetMaxResponseMaps(max int) { a.maxnrm = int32(max) } -// Add a route to connect from an implicit route created for a response to a request. +// Add a service import to connect from an implicit import created for a response to a request. // This does no checks and should be only called by the msg processing code. Use // AddServiceImport from above if responding to user input or config changes, etc. func (a *Account) addServiceImport(dest *Account, from, to string, claim *jwt.Import) (*serviceImport, error) { @@ -1019,6 +1036,12 @@ func (a *Account) addServiceImport(dest *Account, from, to string, claim *jwt.Im return nil, fmt.Errorf("duplicate service import subject %q, previously used in import for account %q, subject %q", from, dup.acc.Name, dup.to) } + if subjectHasWildcard(from) { + a.imports.hasWC = true + } + if to == "" { + to = from + } si := &serviceImport{dest, claim, from, to, 0, rt, lat, nil, false, false, false, false} a.imports.services[from] = si a.mu.Unlock() @@ -1414,8 +1437,8 @@ func (a *Account) getServiceExport(subj string) *serviceExport { // This helper is used when trying to match a serviceExport record that is // represented by a wildcard. // Lock should be held on entry. -func (a *Account) getWildcardServiceExport(to string) *serviceExport { - tokens := strings.Split(to, tsep) +func (a *Account) getWildcardServiceExport(from string) *serviceExport { + tokens := strings.Split(from, tsep) for subj, ea := range a.exports.services { if isSubsetMatch(tokens, subj) { return ea @@ -1648,7 +1671,7 @@ func (a *Account) checkServiceImportAuthorized(account *Account, subject string, // Check if another account is authorized to route requests to this service. func (a *Account) checkServiceImportAuthorizedNoLock(account *Account, subject string, imClaim *jwt.Import) bool { // Find the subject in the services list. - if a.exports.services == nil || !IsValidLiteralSubject(subject) { + if a.exports.services == nil { return false } return a.checkServiceExportApproved(account, subject, imClaim) diff --git a/server/accounts_test.go b/server/accounts_test.go index f10c74d3..378076b6 100644 --- a/server/accounts_test.go +++ b/server/accounts_test.go @@ -1195,6 +1195,106 @@ func TestServiceExportWithWildcards(t *testing.T) { } } +func TestServiceImportWithWildcards(t *testing.T) { + s, fooAcc, barAcc := simpleAccountServer(t) + defer s.Shutdown() + + if err := fooAcc.AddServiceExport("test.*", nil); err != nil { + t.Fatalf("Error adding account service export to client foo: %v", err) + } + // We can not map wildcards atm, so if we supply a to mapping and a wildcard we should fail. + if err := barAcc.AddServiceImport(fooAcc, "test.*", "foo"); err == nil { + t.Fatalf("Expected error adding account service import with wildcard and mapping, got none") + } + if err := barAcc.AddServiceImport(fooAcc, "test.>", ""); err == nil { + t.Fatalf("Expected error adding account service import with broader wildcard, got none") + } + // This should work. + if err := barAcc.AddServiceImport(fooAcc, "test.*", ""); err != nil { + t.Fatalf("Error adding account service import: %v", err) + } + // Make sure we can send and receive. + cfoo, crFoo, _ := newClientForServer(s) + defer cfoo.nc.Close() + + if err := cfoo.registerWithAccount(fooAcc); err != nil { + t.Fatalf("Error registering client with 'foo' account: %v", err) + } + + // Now setup the resonder under cfoo + cfoo.parse([]byte("SUB test.* 1\r\n")) + + cbar, crBar, _ := newClientForServer(s) + defer cbar.nc.Close() + + if err := cbar.registerWithAccount(barAcc); err != nil { + t.Fatalf("Error registering client with 'bar' account: %v", err) + } + + // Now send the request. + go cbar.parseAndFlush([]byte("SUB bar 11\r\nPUB test.22 bar 4\r\nhelp\r\n")) + + // Now read the request from crFoo + l, err := crFoo.ReadString('\n') + if err != nil { + t.Fatalf("Error reading from client 'bar': %v", err) + } + + mraw := msgPat.FindAllStringSubmatch(l, -1) + if len(mraw) == 0 { + t.Fatalf("No message received") + } + matches := mraw[0] + if matches[SUB_INDEX] != "test.22" { + t.Fatalf("Did not get correct subject: '%s'", matches[SUB_INDEX]) + } + if matches[SID_INDEX] != "1" { + t.Fatalf("Did not get correct sid: '%s'", matches[SID_INDEX]) + } + // Make sure this looks like _INBOX + if !strings.HasPrefix(matches[REPLY_INDEX], "_R_.") { + t.Fatalf("Expected an _R_.* like reply, got '%s'", matches[REPLY_INDEX]) + } + checkPayload(crFoo, []byte("help\r\n"), t) + + replyOp := fmt.Sprintf("PUB %s 2\r\n22\r\n", matches[REPLY_INDEX]) + go cfoo.parseAndFlush([]byte(replyOp)) + + // Now read the response from crBar + l, err = crBar.ReadString('\n') + if err != nil { + t.Fatalf("Error reading from client 'bar': %v", err) + } + mraw = msgPat.FindAllStringSubmatch(l, -1) + if len(mraw) == 0 { + t.Fatalf("No message received") + } + matches = mraw[0] + if matches[SUB_INDEX] != "bar" { + t.Fatalf("Did not get correct subject: '%s'", matches[SUB_INDEX]) + } + if matches[SID_INDEX] != "11" { + t.Fatalf("Did not get correct sid: '%s'", matches[SID_INDEX]) + } + if matches[REPLY_INDEX] != "" { + t.Fatalf("Did not get correct sid: '%s'", matches[SID_INDEX]) + } + checkPayload(crBar, []byte("22\r\n"), t) + + // Remove the service import with the wildcard and make sure hasWC is cleared. + barAcc.removeServiceImport("test.*") + + barAcc.mu.Lock() + defer barAcc.mu.Unlock() + + if len(barAcc.imports.services) != 0 { + t.Fatalf("Expected no imported services, got %d", len(barAcc.imports.services)) + } + if barAcc.imports.hasWC { + t.Fatalf("Expected the hasWC flag to be cleared") + } +} + // Make sure the AddStreamExport function is additive if called multiple times. func TestAddStreamExport(t *testing.T) { s, fooAcc, barAcc := simpleAccountServer(t) diff --git a/server/client.go b/server/client.go index 471c3b47..a64fe0ef 100644 --- a/server/client.go +++ b/server/client.go @@ -2958,13 +2958,22 @@ func (c *client) checkForImportServices(acc *Account, msg []byte) bool { return false } + var didDeliver, matchedWC bool + acc.mu.RLock() si := acc.imports.services[string(c.pa.subject)] + if si == nil && acc.imports.hasWC { + // TODO(dlc) - this will be slow with large number of service imports, may need to revisit and optimize. + for subject, tsi := range acc.imports.services { + if subjectHasWildcard(subject) && subjectIsSubsetMatch(string(c.pa.subject), subject) { + si, matchedWC = tsi, true + break + } + } + } invalid := si != nil && si.invalid acc.mu.RUnlock() - var didDeliver bool - // Get the results from the other account for the mapped "to" subject. // If we have been marked invalid simply return here. if si != nil && !invalid && si.acc != nil && si.acc.sl != nil { @@ -2987,8 +2996,14 @@ func (c *client) checkForImportServices(acc *Account, msg []byte) bool { c.sendRTTPing() } } + + // Pick correct to subject. If we matched on a wildcard use the literal publish subject. + to := si.to + if matchedWC { + to = string(c.pa.subject) + } // FIXME(dlc) - Do L1 cache trick from above. - rr := si.acc.sl.Match(si.to) + rr := si.acc.sl.Match(to) // This gives us a notion that we have interest in this message. didDeliver = len(rr.psubs)+len(rr.qsubs) > 0 @@ -2997,7 +3012,7 @@ func (c *client) checkForImportServices(acc *Account, msg []byte) bool { // If so we need to clean that up. if !didDeliver && si.internal { // We may also have a response entry, so go through that way. - si.acc.checkForRespEntry(si.to) + si.acc.checkForRespEntry(to) } // If we are a route or gateway or leafnode and this message is flipped to a queue subscriber we @@ -3011,10 +3026,10 @@ func (c *client) checkForImportServices(acc *Account, msg []byte) bool { // try to send this converted message to all gateways. if c.srv.gateway.enabled { flags |= pmrCollectQueueNames - queues := c.processMsgResults(si.acc, rr, msg, []byte(si.to), nrr, flags) - didDeliver = c.sendMsgToGateways(si.acc, msg, []byte(si.to), nrr, queues) || didDeliver + queues := c.processMsgResults(si.acc, rr, msg, []byte(to), nrr, flags) + didDeliver = c.sendMsgToGateways(si.acc, msg, []byte(to), nrr, queues) || didDeliver } else { - c.processMsgResults(si.acc, rr, msg, []byte(si.to), nrr, flags) + c.processMsgResults(si.acc, rr, msg, []byte(to), nrr, flags) } shouldRemove := si.ae