From 7fa088c804cccffab7a4472ca340bd8d11f90fbe Mon Sep 17 00:00:00 2001 From: jnmoyne Date: Tue, 1 Mar 2022 16:51:02 -0800 Subject: [PATCH] Adds deterministic subject tokens to partition mapping introduces 'Moustache' style subject mapping format (e.g. foo.*.* -> foo.{{wildcard(1)}}.{{wildcard(2)}}.{{partition(10,1,2)}}) --- server/accounts.go | 123 ++++++++++++++++++++++----- server/accounts_test.go | 25 ++++++ test/accounts_cycles_test.go | 159 +++++++++++++++++++++++++++++++++++ 3 files changed, 286 insertions(+), 21 deletions(-) diff --git a/server/accounts.go b/server/accounts.go index 38b20270..6bb40df9 100644 --- a/server/accounts.go +++ b/server/accounts.go @@ -18,6 +18,7 @@ import ( "encoding/hex" "errors" "fmt" + "hash/fnv" "hash/maphash" "io/ioutil" "math" @@ -25,6 +26,7 @@ import ( "net/http" "net/textproto" "reflect" + "regexp" "sort" "strconv" "strings" @@ -4114,18 +4116,73 @@ type transform struct { src, dest string dtoks []string stoks []string - dtpi []int8 + dtpi [][]int // destination token position indexes + dtpinp []int32 // destination token position index number of partitions } -// Helper to pull raw place holder index. Returns -1 if not a place holder. -func placeHolderIndex(token string) int { - if len(token) > 1 && token[0] == '$' { - var tp int - if n, err := fmt.Sscanf(token, "$%d", &tp); err == nil && n == 1 { - return tp +func getMappingFunctionArgs(functionName string, token string) ([]string, error) { + regEx := `{{` + functionName + ` *\((.*)\)}}` + r1, err := regexp.Compile(regEx) + if err != nil { + return nil, err + } + commandStrings := r1.FindStringSubmatch(token) + if len(commandStrings) > 1 { + ra, err := regexp.Compile(`,\s*`) + if err != nil { + return nil, err + } + args := ra.Split(commandStrings[1], -1) + return args, nil + } + return nil, nil +} + +// Helper to pull raw place holder indexes and number of partitions. Returns []int{-1}, -1 if the token is not a place holder. 
+func placeHolderIndex(token string) ([]int, int32) { + if len(token) > 1 { + // old $1, $2, etc... mapping format still supported to maintain backwards compatibility + if token[0] == '$' { // simple non-partition mapping + tp, err := strconv.Atoi(token[1:]) + if err == nil { + return []int{tp}, -1 + } + } + + // New 'moustache' style mapping + // wildcard(wildcard token index) (equivalent to $) + args, err := getMappingFunctionArgs("wildcard", token) + if err == nil && args != nil { + if len(args) == 1 { + tp, err := strconv.Atoi(strings.Trim(args[0], " ")) + if err == nil { + return []int{tp}, -1 + } + } + } + + // partition(number of partitions, token1, token2, ...) + args, err = getMappingFunctionArgs("partition", token) + if err == nil && args != nil { + if len(args) >= 2 { + tphnp, err := strconv.Atoi(strings.Trim(args[0], " ")) + if err == nil { + var numPositions = len(args[1:]) + tps := make([]int, numPositions) + for ti, t := range args[1:] { + i, err := strconv.Atoi(strings.Trim(t, " ")) + if err == nil { + tps[ti] = i + } else { + return []int{-1}, -1 + } + } + return tps, int32(tphnp) + } + } } } - return -1 + return []int{-1}, -1 } // newTransform will create a new transform checking the src and dest subjects for accuracy. @@ -4139,7 +4196,8 @@ func newTransform(src, dest string) (*transform, error) { return nil, ErrBadSubject } - var dtpi []int8 + var dtpi [][]int + var dtpinb []int32 // If the src has partial wildcards then the dest needs to have the token place markers. if npwcs > 0 || hasFwc { @@ -4153,25 +4211,30 @@ func newTransform(src, dest string) (*transform, error) { nphs := 0 for _, token := range dtokens { - tp := placeHolderIndex(token) - if tp >= 0 { - if tp > npwcs { - return nil, ErrBadSubject - } + tp, nb := placeHolderIndex(token) + if tp[0] >= 0 { nphs++ // Now build up our runtime mapping from dest to source tokens. 
- dtpi = append(dtpi, int8(sti[tp])) + var stis []int + for _, position := range tp { + if position > npwcs { + return nil, ErrBadSubject + } + stis = append(stis, sti[position]) + } + dtpi = append(dtpi, stis) + dtpinb = append(dtpinb, nb) } else { - dtpi = append(dtpi, -1) + dtpi = append(dtpi, []int{-1}) + dtpinb = append(dtpinb, -1) } } - - if nphs != npwcs { + if nphs < npwcs { return nil, ErrBadSubject } } - return &transform{src: src, dest: dest, dtoks: dtokens, stoks: stokens, dtpi: dtpi}, nil + return &transform{src: src, dest: dest, dtoks: dtokens, stoks: stokens, dtpi: dtpi, dtpinp: dtpinb}, nil } // match will take a literal published subject that is associated with a client and will match and transform @@ -4215,6 +4278,16 @@ func (tr *transform) transformSubject(subject string) (string, error) { return tr.transform(tts) } +func (tr *transform) getHashBucket(key string, numBuckets int) string { + h := fnv.New32a() + _, err := h.Write([]byte(key)) + if err != nil { + return fmt.Sprintf("error: %v", err) + } + + return fmt.Sprintf("%d", h.Sum32()%uint32(numBuckets)) +} + // Do a transform on the subject to the dest subject. func (tr *transform) transform(tokens []string) (string, error) { if len(tr.dtpi) == 0 { @@ -4230,7 +4303,7 @@ func (tr *transform) transform(tokens []string) (string, error) { li := len(tr.dtpi) - 1 for i, index := range tr.dtpi { // <0 means use destination token. - if index < 0 { + if index[0] < 0 { token = tr.dtoks[i] // Break if fwc if len(token) == 1 && token[0] == fwc { @@ -4238,7 +4311,15 @@ func (tr *transform) transform(tokens []string) (string, error) { } } else { // >= 0 means use source map index to figure out which source token to pull. - token = tokens[index] + if tr.dtpinp[i] > 0 { // there is a valid (i.e. 
not -1) value for number of partitions, this is a partition transform token + var keyForHashing string + for _, sourceToken := range tr.dtpi[i] { + keyForHashing = keyForHashing + tokens[sourceToken] + } + token = tr.getHashBucket(keyForHashing, int(tr.dtpinp[i])) + } else { // back to normal substitution + token = tokens[tr.dtpi[i][0]] + } } b.WriteString(token) if i < li { diff --git a/server/accounts_test.go b/server/accounts_test.go index 6a3ee7a8..dc5842c8 100644 --- a/server/accounts_test.go +++ b/server/accounts_test.go @@ -18,6 +18,7 @@ import ( "encoding/json" "fmt" "net/http" + "reflect" "strconv" "strings" "sync" @@ -46,6 +47,30 @@ func simpleAccountServer(t *testing.T) (*Server, *Account, *Account) { return s, f, b } +func TestPlaceHolderIndex(t *testing.T) { + testString := "$1" + indexes, nbPartitions := placeHolderIndex(testString) + + if len(indexes) != 1 || indexes[0] != 1 || nbPartitions != -1 { + t.Fatalf("Error parsing %s", testString) + } + + testString = "{{partition(10,1,2,3)}}" + + indexes, nbPartitions = placeHolderIndex(testString) + + if !reflect.DeepEqual(indexes, []int{1, 2, 3}) || nbPartitions != 10 { + t.Fatalf("Error parsing %s", testString) + } + + testString = "{{wildcard(2)}}" + indexes, nbPartitions = placeHolderIndex(testString) + + if len(indexes) != 1 || indexes[0] != 2 || nbPartitions != -1 { + t.Fatalf("Error parsing %s", testString) + } +} + func TestRegisterDuplicateAccounts(t *testing.T) { s, _, _ := simpleAccountServer(t) if _, err := s.RegisterAccount("$foo"); err == nil { diff --git a/test/accounts_cycles_test.go b/test/accounts_cycles_test.go index c4263b89..672c16a4 100644 --- a/test/accounts_cycles_test.go +++ b/test/accounts_cycles_test.go @@ -15,6 +15,7 @@ package test import ( "fmt" + "strconv" "strings" "testing" "time" @@ -351,6 +352,164 @@ func TestAccountCycleDepthLimit(t *testing.T) { } } +// Test token and partition subject mapping within an account +func TestAccountSubjectMapping(t *testing.T) { + conf := 
createConfFile(t, []byte(` + port: -1 + mappings = { + "foo.*.*" : "foo.$1.{{wildcard(2)}}.{{partition(10,1,2)}}" + } + `)) + defer removeFile(t, conf) + + s, _ := RunServerWithConfig(conf) + defer s.Shutdown() + + nc1 := clientConnectToServer(t, s) + defer nc1.Close() + + numMessages := 100 + subjectsReceived := make(chan string) + + msg := []byte("HELLO") + sub1, err := nc1.Subscribe("foo.*.*.*", func(m *nats.Msg) { + subjectsReceived <- m.Subject + }) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + sub1.AutoUnsubscribe(numMessages * 2) + + nc2 := clientConnectToServer(t, s) + defer nc2.Close() + + // publish numMessages with an increasing id (should map to partition numbers within the range of 10 partitions) - twice + for j := 0; j < 2; j++ { + for i := 0; i < numMessages; i++ { + err = nc2.Publish(fmt.Sprintf("foo.%d.%d", i, numMessages-i), msg) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + } + } + + // verify all the partition numbers are in the expected range + partitionsReceived := make([]int, numMessages) + + for i := 0; i < numMessages; i++ { + subject := <-subjectsReceived + sTokens := strings.Split(subject, ".") + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + t1, _ := strconv.Atoi(sTokens[1]) + t2, _ := strconv.Atoi(sTokens[2]) + partitionsReceived[i], err = strconv.Atoi(sTokens[3]) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + if partitionsReceived[i] > 9 || partitionsReceived[i] < 0 || t1 != i || t2 != numMessages-i { + t.Fatalf("Error received unexpected %d.%d to partition %d", t1, t2, partitionsReceived[i]) + } + } + + // verify hashing is deterministic by checking it produces the same exact result twice + for i := 0; i < numMessages; i++ { + subject := <-subjectsReceived + partitionNumber, err := strconv.Atoi(strings.Split(subject, ".")[3]) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + if partitionsReceived[i] != partitionNumber { + t.Fatalf("Error: same id mapped 
to two different partitions") + } + } +} + +// Test token and partition subject mapping on a stream import between accounts +// Account A imports from account B with subject mapping +func TestAccountImportSubjectMapping(t *testing.T) { + conf := createConfFile(t, []byte(` + port: -1 + accounts { + A { + users: [{user: a, pass: x}] + imports [ {stream: {account: B, subject: "foo.*.*"}, to : "foo.$1.{{wildcard(2)}}.{{partition(10,1,2)}}"}] + } + B { + users: [{user: b, pass x}] + exports [ { stream: ">" } ] + } + } + `)) + defer removeFile(t, conf) + + s, opts := RunServerWithConfig(conf) + + defer s.Shutdown() + ncA := clientConnectToServerWithUP(t, opts, "a", "x") + defer ncA.Close() + + numMessages := 100 + subjectsReceived := make(chan string) + + msg := []byte("HELLO") + sub1, err := ncA.Subscribe("foo.*.*.*", func(m *nats.Msg) { + subjectsReceived <- m.Subject + }) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + sub1.AutoUnsubscribe(numMessages * 2) + + ncB := clientConnectToServerWithUP(t, opts, "b", "x") + defer ncB.Close() + + // publish numMessages with an increasing id (should map to partition numbers within the range of 10 partitions) - twice + for j := 0; j < 2; j++ { + for i := 0; i < numMessages; i++ { + err = ncB.Publish(fmt.Sprintf("foo.%d.%d", i, numMessages-i), msg) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + } + } + + // verify all the partition numbers are in the expected range + partitionsReceived := make([]int, numMessages) + + for i := 0; i < numMessages; i++ { + subject := <-subjectsReceived + sTokens := strings.Split(subject, ".") + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + t1, _ := strconv.Atoi(sTokens[1]) + t2, _ := strconv.Atoi(sTokens[2]) + partitionsReceived[i], err = strconv.Atoi(sTokens[3]) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + if partitionsReceived[i] > 9 || partitionsReceived[i] < 0 || t1 != i || t2 != numMessages-i { + t.Fatalf("Error received unexpected %d.%d to partition %d", 
t1, t2, partitionsReceived[i]) + } + } + + // verify hashing is deterministic by checking it produces the same exact result twice + for i := 0; i < numMessages; i++ { + subject := <-subjectsReceived + partitionNumber, err := strconv.Atoi(strings.Split(subject, ".")[3]) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + if partitionsReceived[i] != partitionNumber { + t.Fatalf("Error: same id mapped to two different partitions") + } + } +} + func clientConnectToServer(t *testing.T, s *server.Server) *nats.Conn { t.Helper() nc, err := nats.Connect(s.ClientURL(),