diff --git a/server/accounts.go b/server/accounts.go index 36de78d2..cd3a3677 100644 --- a/server/accounts.go +++ b/server/accounts.go @@ -18,6 +18,7 @@ import ( "encoding/hex" "errors" "fmt" + "hash/fnv" "hash/maphash" "io/ioutil" "math" @@ -25,6 +26,7 @@ import ( "net/http" "net/textproto" "reflect" + "regexp" "sort" "strconv" "strings" @@ -161,6 +163,10 @@ const ( Chunked ) +var commaSeparatorRegEx = regexp.MustCompile(`,\s*`) +var partitionMappingFunctionRegEx = regexp.MustCompile(`{{\s*partition\s*\((.*)\)\s*}}`) +var wildcardMappingFunctionRegEx = regexp.MustCompile(`{{\s*wildcard\s*\((.*)\)\s*}}`) + // String helper. func (rt ServiceRespType) String() string { switch rt { @@ -4132,18 +4138,65 @@ type transform struct { src, dest string dtoks []string stoks []string - dtpi []int8 + dtpi [][]int // destination token position indexes + dtpinp []int32 // destination token position index number of partitions } -// Helper to pull raw place holder index. Returns -1 if not a place holder. -func placeHolderIndex(token string) int { - if len(token) > 1 && token[0] == '$' { - var tp int - if n, err := fmt.Sscanf(token, "$%d", &tp); err == nil && n == 1 { - return tp +func getMappingFunctionArgs(functionRegEx *regexp.Regexp, token string) []string { + commandStrings := functionRegEx.FindStringSubmatch(token) + if len(commandStrings) > 1 { + return commaSeparatorRegEx.Split(commandStrings[1], -1) + } + return nil +} + +// Helper to pull raw place holder indexes and number of partitions. Returns -1 if not a place holder. +func placeHolderIndex(token string) ([]int, int32, error) { + if len(token) > 1 { + // old $1, $2, etc... mapping format still supported to maintain backwards compatibility + if token[0] == '$' { // simple non-partition mapping + tp, err := strconv.Atoi(token[1:]) + if err != nil { + return []int{-1}, -1, nil + } + return []int{tp}, -1, nil + } + + // New 'moustache' style mapping + // wildcard(wildcard token index) (equivalent to $) + args := getMappingFunctionArgs(wildcardMappingFunctionRegEx, token) + if args != nil { + if len(args) == 1 { + tp, err := strconv.Atoi(strings.Trim(args[0], " ")) + if err != nil { + return []int{}, -1, err + } + return []int{tp}, -1, nil + } + } + + // partition(number of partitions, token1, token2, ...) + args = getMappingFunctionArgs(partitionMappingFunctionRegEx, token) + if args != nil { + if len(args) >= 2 { + tphnp, err := strconv.Atoi(strings.Trim(args[0], " ")) + if err != nil { + return []int{}, -1, err + } + var numPositions = len(args[1:]) + tps := make([]int, numPositions) + for ti, t := range args[1:] { + i, err := strconv.Atoi(strings.Trim(t, " ")) + if err != nil { + return []int{}, -1, err + } + tps[ti] = i + } + return tps, int32(tphnp), nil + } } } - return -1 + return []int{-1}, -1, nil } // newTransform will create a new transform checking the src and dest subjects for accuracy. @@ -4157,7 +4210,8 @@ func newTransform(src, dest string) (*transform, error) { return nil, ErrBadSubject } - var dtpi []int8 + var dtpi [][]int + var dtpinb []int32 // If the src has partial wildcards then the dest needs to have the token place markers. if npwcs > 0 || hasFwc { @@ -4171,25 +4225,33 @@ func newTransform(src, dest string) (*transform, error) { nphs := 0 for _, token := range dtokens { - tp := placeHolderIndex(token) - if tp >= 0 { - if tp > npwcs { - return nil, ErrBadSubject - } + tp, nb, err := placeHolderIndex(token) + if err != nil { + return nil, ErrBadSubjectMappingDestination + } + if tp[0] >= 0 { nphs++ // Now build up our runtime mapping from dest to source tokens. - dtpi = append(dtpi, int8(sti[tp])) + var stis []int + for _, position := range tp { + if position > npwcs { + return nil, ErrBadSubjectMappingDestination + } + stis = append(stis, sti[position]) + } + dtpi = append(dtpi, stis) + dtpinb = append(dtpinb, nb) } else { - dtpi = append(dtpi, -1) + dtpi = append(dtpi, []int{-1}) + dtpinb = append(dtpinb, -1) } } - - if nphs != npwcs { - return nil, ErrBadSubject + if nphs < npwcs { + return nil, ErrBadSubjectMappingDestination } } - return &transform{src: src, dest: dest, dtoks: dtokens, stoks: stokens, dtpi: dtpi}, nil + return &transform{src: src, dest: dest, dtoks: dtokens, stoks: stokens, dtpi: dtpi, dtpinp: dtpinb}, nil } // match will take a literal published subject that is associated with a client and will match and transform @@ -4233,6 +4295,13 @@ func (tr *transform) transformSubject(subject string) (string, error) { return tr.transform(tts) } +func (tr *transform) getHashPartition(key []byte, numBuckets int) string { + h := fnv.New32a() + h.Write(key) + + return strconv.Itoa(int(h.Sum32() % uint32(numBuckets))) +} + // Do a transform on the subject to the dest subject. func (tr *transform) transform(tokens []string) (string, error) { if len(tr.dtpi) == 0 { @@ -4248,7 +4317,7 @@ func (tr *transform) transform(tokens []string) (string, error) { li := len(tr.dtpi) - 1 for i, index := range tr.dtpi { // <0 means use destination token. - if index < 0 { + if index[0] < 0 { token = tr.dtoks[i] // Break if fwc if len(token) == 1 && token[0] == fwc { @@ -4256,7 +4325,18 @@ func (tr *transform) transform(tokens []string) (string, error) { } } else { // >= 0 means use source map index to figure out which source token to pull. - token = tokens[index] + if tr.dtpinp[i] > 0 { // there is a valid (i.e. not -1) value for number of partitions, this is a partition transform token + var ( + _buffer [64]byte + keyForHashing = _buffer[:0] + ) + for _, sourceToken := range tr.dtpi[i] { + keyForHashing = append(keyForHashing, []byte(tokens[sourceToken])...) + } + token = tr.getHashPartition(keyForHashing, int(tr.dtpinp[i])) + } else { // back to normal substitution + token = tokens[tr.dtpi[i][0]] + } } b.WriteString(token) if i < li { diff --git a/server/accounts_test.go b/server/accounts_test.go index 6a3ee7a8..458cf0cb 100644 --- a/server/accounts_test.go +++ b/server/accounts_test.go @@ -18,6 +18,7 @@ import ( "encoding/json" "fmt" "net/http" + "reflect" "strconv" "strings" "sync" @@ -46,6 +47,53 @@ func simpleAccountServer(t *testing.T) (*Server, *Account, *Account) { return s, f, b } +func TestPlaceHolderIndex(t *testing.T) { + testString := "$1" + indexes, nbPartitions, err := placeHolderIndex(testString) + + if err != nil || len(indexes) != 1 || indexes[0] != 1 || nbPartitions != -1 { + t.Fatalf("Error parsing %s", testString) + } + + testString = "{{partition(10,1,2,3)}}" + + indexes, nbPartitions, err = placeHolderIndex(testString) + + if err != nil || !reflect.DeepEqual(indexes, []int{1, 2, 3}) || nbPartitions != 10 { + t.Fatalf("Error parsing %s", testString) + } + + testString = "{{ partition(10,1,2,3) }}" + + indexes, nbPartitions, err = placeHolderIndex(testString) + + if err != nil || !reflect.DeepEqual(indexes, []int{1, 2, 3}) || nbPartitions != 10 { + t.Fatalf("Error parsing %s", testString) + } + + testString = "{{partition (10,1,2,3)}}" + + indexes, nbPartitions, err = placeHolderIndex(testString) + + if err != nil || !reflect.DeepEqual(indexes, []int{1, 2, 3}) || nbPartitions != 10 { + t.Fatalf("Error parsing %s", testString) + } + + testString = "{{wildcard(2)}}" + indexes, nbPartitions, err = placeHolderIndex(testString) + + if err != nil || len(indexes) != 1 || indexes[0] != 2 || nbPartitions != -1 { + t.Fatalf("Error parsing %s", testString) + } + + testString = "{{ wildcard (2) }}" + indexes, nbPartitions, err = placeHolderIndex(testString) + + if err != nil || len(indexes) != 1 || indexes[0] != 2 || nbPartitions != -1 { + t.Fatalf("Error parsing %s", testString) + } +} + func TestRegisterDuplicateAccounts(t *testing.T) { s, _, _ := simpleAccountServer(t) if _, err := s.RegisterAccount("$foo"); err == nil { @@ -3131,7 +3179,7 @@ func TestSamplingHeader(t *testing.T) { func TestSubjectTransforms(t *testing.T) { shouldErr := func(src, dest string) { t.Helper() - if _, err := newTransform(src, dest); err != ErrBadSubject { + if _, err := newTransform(src, dest); err != ErrBadSubject && err != ErrBadSubjectMappingDestination { t.Fatalf("Did not get an error for src=%q and dest=%q", src, dest) } } diff --git a/server/errors.go b/server/errors.go index 0edb5b9e..a5ba0ac9 100644 --- a/server/errors.go +++ b/server/errors.go @@ -46,6 +46,9 @@ var ( // ErrBadSubject represents an error condition for an invalid subject. ErrBadSubject = errors.New("invalid subject") + // ErrBadSubjectMappingDestination is used to error on a bad transform destination mapping + ErrBadSubjectMappingDestination = errors.New("invalid subject mapping destination") + // ErrBadQualifier is used to error on a bad qualifier for a transform. ErrBadQualifier = errors.New("bad qualifier") diff --git a/server/opts.go b/server/opts.go index f1b5fbce..6cf87719 100644 --- a/server/opts.go +++ b/server/opts.go @@ -2372,7 +2372,7 @@ func parseAccountMappings(v interface{}, acc *Account, errors *[]error, warnings switch vv := v.(type) { case string: if err := acc.AddMapping(subj, v.(string)); err != nil { - err := &configErr{tk, fmt.Sprintf("Error adding mapping for %q: %v", subj, err)} + err := &configErr{tk, fmt.Sprintf("Error adding mapping for %q to %q : %v", subj, v.(string), err)} *errors = append(*errors, err) continue } @@ -2389,7 +2389,7 @@ func parseAccountMappings(v interface{}, acc *Account, errors *[]error, warnings // Now add them in.. if err := acc.AddWeightedMappings(subj, mappings...); err != nil { - err := &configErr{tk, fmt.Sprintf("Error adding mapping for %q: %v", subj, err)} + err := &configErr{tk, fmt.Sprintf("Error adding mapping for %q to %q : %v", subj, v.(string), err)} *errors = append(*errors, err) continue } @@ -2401,7 +2401,7 @@ func parseAccountMappings(v interface{}, acc *Account, errors *[]error, warnings } // Now add it in.. if err := acc.AddWeightedMappings(subj, mdest); err != nil { - err := &configErr{tk, fmt.Sprintf("Error adding mapping for %q: %v", subj, err)} + err := &configErr{tk, fmt.Sprintf("Error adding mapping for %q to %q : %v", subj, v.(string), err)} *errors = append(*errors, err) continue } diff --git a/test/accounts_cycles_test.go b/test/accounts_cycles_test.go index c4263b89..672c16a4 100644 --- a/test/accounts_cycles_test.go +++ b/test/accounts_cycles_test.go @@ -15,6 +15,7 @@ package test import ( "fmt" + "strconv" "strings" "testing" "time" @@ -351,6 +352,164 @@ func TestAccountCycleDepthLimit(t *testing.T) { } } +// Test token and partition subject mapping within an account +func TestAccountSubjectMapping(t *testing.T) { + conf := createConfFile(t, []byte(` + port: -1 + mappings = { + "foo.*.*" : "foo.$1.{{wildcard(2)}}.{{partition(10,1,2)}}" + } + `)) + defer removeFile(t, conf) + + s, _ := RunServerWithConfig(conf) + defer s.Shutdown() + + nc1 := clientConnectToServer(t, s) + defer nc1.Close() + + numMessages := 100 + subjectsReceived := make(chan string) + + msg := []byte("HELLO") + sub1, err := nc1.Subscribe("foo.*.*.*", func(m *nats.Msg) { + subjectsReceived <- m.Subject + }) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + sub1.AutoUnsubscribe(numMessages * 2) + + nc2 := clientConnectToServer(t, s) + defer nc2.Close() + + // publish numMessages with an increasing id (should map to partition numbers with the range of 10 partitions) - twice + for j := 0; j < 2; j++ { + for i := 0; i < numMessages; i++ { + err = nc2.Publish(fmt.Sprintf("foo.%d.%d", i, numMessages-i), msg) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + } + } + + // verify all the partition numbers are in the expected range + partitionsReceived := make([]int, numMessages) + + for i := 0; i < numMessages; i++ { + subject := <-subjectsReceived + sTokens := strings.Split(subject, ".") + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + t1, _ := strconv.Atoi(sTokens[1]) + t2, _ := strconv.Atoi(sTokens[2]) + partitionsReceived[i], err = strconv.Atoi(sTokens[3]) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + if partitionsReceived[i] > 9 || partitionsReceived[i] < 0 || t1 != i || t2 != numMessages-i { + t.Fatalf("Error received unexpected %d.%d to partition %d", t1, t2, partitionsReceived[i]) + } + } + + // verify hashing is deterministic by checking it produces the same exact result twice + for i := 0; i < numMessages; i++ { + subject := <-subjectsReceived + partitionNumber, err := strconv.Atoi(strings.Split(subject, ".")[3]) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + if partitionsReceived[i] != partitionNumber { + t.Fatalf("Error: same id mapped to two different partitions") + } + } +} + +// test token and partition subject mapping within an account +// Alice imports from Bob with subject mapping +func TestAccountImportSubjectMapping(t *testing.T) { + conf := createConfFile(t, []byte(` + port: -1 + accounts { + A { + users: [{user: a, pass: x}] + imports [ {stream: {account: B, subject: "foo.*.*"}, to : "foo.$1.{{wildcard(2)}}.{{partition(10,1,2)}}"}] + } + B { + users: [{user: b, pass x}] + exports [ { stream: ">" } ] + } + } + `)) + defer removeFile(t, conf) + + s, opts := RunServerWithConfig(conf) + + defer s.Shutdown() + ncA := clientConnectToServerWithUP(t, opts, "a", "x") + defer ncA.Close() + + numMessages := 100 + subjectsReceived := make(chan string) + + msg := []byte("HELLO") + sub1, err := ncA.Subscribe("foo.*.*.*", func(m *nats.Msg) { + subjectsReceived <- m.Subject + }) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + sub1.AutoUnsubscribe(numMessages * 2) + + ncB := clientConnectToServerWithUP(t, opts, "b", "x") + defer ncB.Close() + + // publish numMessages with an increasing id (should map to partition numbers with the range of 10 partitions) - twice + for j := 0; j < 2; j++ { + for i := 0; i < numMessages; i++ { + err = ncB.Publish(fmt.Sprintf("foo.%d.%d", i, numMessages-i), msg) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + } + } + + // verify all the partition numbers are in the expected range + partitionsReceived := make([]int, numMessages) + + for i := 0; i < numMessages; i++ { + subject := <-subjectsReceived + sTokens := strings.Split(subject, ".") + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + t1, _ := strconv.Atoi(sTokens[1]) + t2, _ := strconv.Atoi(sTokens[2]) + partitionsReceived[i], err = strconv.Atoi(sTokens[3]) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + if partitionsReceived[i] > 9 || partitionsReceived[i] < 0 || t1 != i || t2 != numMessages-i { + t.Fatalf("Error received unexpected %d.%d to partition %d", t1, t2, partitionsReceived[i]) + } + } + + // verify hashing is deterministic by checking it produces the same exact result twice + for i := 0; i < numMessages; i++ { + subject := <-subjectsReceived + partitionNumber, err := strconv.Atoi(strings.Split(subject, ".")[3]) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + if partitionsReceived[i] != partitionNumber { + t.Fatalf("Error: same id mapped to two different partitions") + } + } +} + func clientConnectToServer(t *testing.T, s *server.Server) *nats.Conn { t.Helper() nc, err := nats.Connect(s.ClientURL(),