diff --git a/server/accounts.go b/server/accounts.go index 0f17687b..b9606f4f 100644 --- a/server/accounts.go +++ b/server/accounts.go @@ -104,7 +104,7 @@ type serviceImport struct { acc *Account claim *jwt.Import se *serviceExport - sub *subscription + sid []byte from string to string exsub string @@ -1153,8 +1153,8 @@ func (a *Account) removeServiceImport(subject string) { c := a.ic if ok && si != nil { - if a.ic != nil && si.sub != nil && si.sub.sid != nil { - sid = si.sub.sid + if a.ic != nil && si.sid != nil { + sid = si.sid } } a.mu.Unlock() @@ -1355,46 +1355,33 @@ func (a *Account) subscribeInternal(subject string, cb msgHandler) (*subscriptio return nil, fmt.Errorf("no internal account client") } - sub, err := c.processSub([]byte(subject+" "+sid), false) - if err != nil { - return nil, err - } - - sub.icb = cb - return sub, nil + return c.processSub([]byte(subject), nil, []byte(sid), cb, false) } // This will add an account subscription that matches the "from" from a service import entry. func (a *Account) addServiceImportSub(si *serviceImport) error { a.mu.Lock() c := a.internalClient() - sid := strconv.FormatUint(a.isid+1, 10) - a.mu.Unlock() - // This will happen in parsing when the account has not been properly setup. if c == nil { + a.mu.Unlock() return nil } - - if si.sub != nil { + if si.sid != nil { + a.mu.Unlock() return fmt.Errorf("duplicate call to create subscription for service import") } - - sub, err := c.processSub([]byte(si.from+" "+sid), true) - if err != nil { - return err - } - - sub.icb = func(sub *subscription, c *client, subject, reply string, msg []byte) { - c.processServiceImport(si, a, msg) - } - - a.mu.Lock() a.isid++ - si.sub = sub + sid := strconv.FormatUint(a.isid, 10) + si.sid = []byte(sid) + subject := si.from a.mu.Unlock() - return nil + cb := func(sub *subscription, c *client, subject, reply string, msg []byte) { + c.processServiceImport(si, a, msg) + } + _, err := c.processSub([]byte(subject), nil, []byte(sid), cb, true) + return err } // Remove all the subscriptions associated with service imports. @@ -1402,9 +1389,9 @@ func (a *Account) removeAllServiceImportSubs() { a.mu.RLock() var sids [][]byte for _, si := range a.imports.services { - if si.sub != nil && si.sub.sid != nil { - sids = append(sids, si.sub.sid) - si.sub = nil + if si.sid != nil { + sids = append(sids, si.sid) + si.sid = nil } } c := a.ic @@ -1486,14 +1473,12 @@ func (a *Account) createRespWildcard() []byte { pre := a.siReply wcsub := append(a.siReply, '>') c := a.internalClient() - a.isid += 1 + a.isid++ sid := strconv.FormatUint(a.isid, 10) a.mu.Unlock() // Create subscription and internal callback for all the wildcard response subjects. - if sub, _ := c.processSub([]byte(string(wcsub)+" "+sid), false); sub != nil { - sub.icb = a.processServiceImportResponse - } + c.processSub(wcsub, nil, []byte(sid), a.processServiceImportResponse, false) return pre } @@ -2451,8 +2436,8 @@ func (s *Server) UpdateAccountClaims(a *Account, ac *jwt.AccountClaims) { a.mu.RLock() c := a.ic for _, si := range old.imports.services { - if c != nil && si.sub != nil && si.sub.sid != nil { - sids = append(sids, si.sub.sid) + if c != nil && si.sid != nil { + sids = append(sids, si.sid) } } a.mu.RUnlock() diff --git a/server/accounts_test.go b/server/accounts_test.go index 24fa0fb9..4d5f730a 100644 --- a/server/accounts_test.go +++ b/server/accounts_test.go @@ -1193,6 +1193,46 @@ func TestServiceExportWithWildcards(t *testing.T) { } } +func TestAccountAddServiceImportRace(t *testing.T) { + s, fooAcc, barAcc := simpleAccountServer(t) + defer s.Shutdown() + + if err := fooAcc.AddServiceExport("foo.*", nil); err != nil { + t.Fatalf("Error adding account service export to client foo: %v", err) + } + + total := 100 + errCh := make(chan error, total) + for i := 0; i < 100; i++ { + go func(i int) { + err := barAcc.AddServiceImport(fooAcc, fmt.Sprintf("foo.%d", i), "") + errCh <- err // nil is a valid value. + + }(i) + } + + for i := 0; i < 100; i++ { + err := <-errCh + if err != nil { + t.Fatalf("Error adding account service import: %v", err) + } + } + + barAcc.mu.Lock() + lens := len(barAcc.imports.services) + c := barAcc.internalClient() + barAcc.mu.Unlock() + if lens != total { + t.Fatalf("Expected %d imported services, got %d", total, lens) + } + c.mu.Lock() + lens = len(c.subs) + c.mu.Unlock() + if lens != total { + t.Fatalf("Expected %d subscriptions in internal client, got %d", total, lens) + } +} + func TestServiceImportWithWildcards(t *testing.T) { s, fooAcc, barAcc := simpleAccountServer(t) defer s.Shutdown() diff --git a/server/client.go b/server/client.go index 24e2b1d9..fcdf08f1 100644 --- a/server/client.go +++ b/server/client.go @@ -2104,25 +2104,39 @@ func splitArg(arg []byte) [][]byte { return args } -func (c *client) processSub(argo []byte, noForward bool) (*subscription, error) { +func (c *client) parseSub(argo []byte, noForward bool) error { // Copy so we do not reference a potentially large buffer // FIXME(dlc) - make more efficient. arg := make([]byte, len(argo)) copy(arg, argo) args := splitArg(arg) - sub := &subscription{client: c} + var ( + subject []byte + queue []byte + sid []byte + ) switch len(args) { case 2: - sub.subject = args[0] - sub.queue = nil - sub.sid = args[1] + subject = args[0] + queue = nil + sid = args[1] case 3: - sub.subject = args[0] - sub.queue = args[1] - sub.sid = args[2] + subject = args[0] + queue = args[1] + sid = args[2] default: - return nil, fmt.Errorf("processSub Parse Error: '%s'", arg) + return fmt.Errorf("processSub Parse Error: '%s'", arg) } + // If there was an error, it has been sent to the client. We don't return an + // error here to not close the connection as a parsing error. + c.processSub(subject, queue, sid, nil, noForward) + return nil +} + +func (c *client) processSub(subject, queue, bsid []byte, cb msgHandler, noForward bool) (*subscription, error) { + + // Create the subscription + sub := &subscription{client: c, subject: subject, queue: queue, sid: bsid, icb: cb} c.mu.Lock() @@ -2155,12 +2169,12 @@ func (c *client) processSub(argo []byte, noForward bool) (*subscription, error) if !c.canQueueSubscribe(string(sub.subject), string(sub.queue)) { c.mu.Unlock() c.subPermissionViolation(sub) - return nil, nil + return nil, ErrSubscribePermissionViolation } } else if !c.canSubscribe(string(sub.subject)) { c.mu.Unlock() c.subPermissionViolation(sub) - return nil, nil + return nil, ErrSubscribePermissionViolation } } @@ -2168,7 +2182,7 @@ func (c *client) processSub(argo []byte, noForward bool) (*subscription, error) if c.subsAtLimit() { c.mu.Unlock() c.maxSubsExceeded() - return nil, nil + return nil, ErrTooManySubs } var updateGWs bool @@ -2192,7 +2206,7 @@ func (c *client) processSub(argo []byte, noForward bool) (*subscription, error) if err != nil { c.sendErr("Invalid Subject") - return nil, nil + return nil, ErrMalformedSubject } else if c.opts.Verbose && kind != SYSTEM { c.sendOK() } diff --git a/server/errors.go b/server/errors.go index e323ce06..e3dfc58f 100644 --- a/server/errors.go +++ b/server/errors.go @@ -154,6 +154,12 @@ var ( // ErrClusterNameRemoteConflict signals that a remote server has a different cluster name. ErrClusterNameRemoteConflict = errors.New("cluster name from remote server conflicts") + + // ErrMalformedSubject is returned when a subscription is made with a subject that does not conform to subject rules. + ErrMalformedSubject = errors.New("malformed subject") + + // ErrSubscribePermissionViolation is returned when processing of a subscription fails due to permissions. + ErrSubscribePermissionViolation = errors.New("subscribe permission viloation") ) // configErr is a configuration error. diff --git a/server/events.go b/server/events.go index 5e7ad5a4..1e5ec2fd 100644 --- a/server/events.go +++ b/server/events.go @@ -1227,20 +1227,12 @@ func (s *Server) systemSubscribe(subject string, internalOnly bool, cb msgHandle sid := strconv.Itoa(s.sys.sid) s.mu.Unlock() - arg := []byte(subject + " " + sid) if trace { - c.traceInOp("SUB", arg) + c.traceInOp("SUB", []byte(subject+" "+sid)) } // Now create the subscription - sub, err := c.processSub(arg, internalOnly) - if err != nil { - return nil, err - } - c.mu.Lock() - sub.icb = cb - c.mu.Unlock() - return sub, nil + return c.processSub([]byte(subject), nil, []byte(sid), cb, internalOnly) } func (s *Server) sysUnsubscribe(sub *subscription) { diff --git a/server/events_test.go b/server/events_test.go index 43f48674..5fc6ef4d 100644 --- a/server/events_test.go +++ b/server/events_test.go @@ -563,6 +563,60 @@ func TestSystemAccountDisconnectBadLogin(t *testing.T) { } } +func TestSysSubscribeRace(t *testing.T) { + s, opts := runTrustedServer(t) + defer s.Shutdown() + + acc, akp := createAccount(s) + s.setSystemAccount(acc) + + url := fmt.Sprintf("nats://%s:%d", opts.Host, opts.Port) + + nc, err := nats.Connect(url, createUserCreds(t, s, akp)) + if err != nil { + t.Fatalf("Error on connect: %v", err) + } + defer nc.Close() + + done := make(chan struct{}) + wg := sync.WaitGroup{} + wg.Add(1) + go func() { + defer wg.Done() + for { + nc.Publish("foo", []byte("hello")) + select { + case <-done: + return + default: + } + } + }() + + time.Sleep(10 * time.Millisecond) + + received := make(chan struct{}) + // Create message callback handler. + cb := func(sub *subscription, producer *client, subject, reply string, msg []byte) { + select { + case received <- struct{}{}: + default: + } + } + // Now create an internal subscription + sub, err := s.sysSubscribe("foo", cb) + if sub == nil || err != nil { + t.Fatalf("Expected to subscribe, got %v", err) + } + select { + case <-received: + close(done) + case <-time.After(time.Second): + t.Fatalf("Did not receive the message") + } + wg.Wait() +} + func TestSystemAccountInternalSubscriptions(t *testing.T) { s, opts := runTrustedServer(t) defer s.Shutdown() diff --git a/server/jetstream.go b/server/jetstream.go index 91929ea0..cd7a9b50 100644 --- a/server/jetstream.go +++ b/server/jetstream.go @@ -1056,14 +1056,10 @@ func (t *StreamTemplate) createTemplateSubscriptions() error { sid := 1 for _, subject := range t.Config.Subjects { // Now create the subscription - sub, err := c.processSub([]byte(subject+" "+strconv.Itoa(sid)), false) - if err != nil { + if _, err := c.processSub([]byte(subject), nil, []byte(strconv.Itoa(sid)), t.processInboundTemplateMsg, false); err != nil { c.acc.DeleteStreamTemplate(t.Name) return err } - c.mu.Lock() - sub.icb = t.processInboundTemplateMsg - c.mu.Unlock() sid++ } return nil diff --git a/server/norace_test.go b/server/norace_test.go index 5509a151..535d35dd 100644 --- a/server/norace_test.go +++ b/server/norace_test.go @@ -910,3 +910,11 @@ func TestNoRaceLeafNodeClusterNameConflictDeadlock(t *testing.T) { checkLeafNodeConnected(t, s3) checkClusterFormed(t, s1, s2, s3) } + +// This test is same than TestAccountAddServiceImportRace but running +// without the -race flag, it would capture more easily the possible +// duplicate sid, resulting in less than expected number of subscriptions +// in the account's internal subscriptions map. +func TestNoRaceAccountAddServiceImportRace(t *testing.T) { + TestAccountAddServiceImportRace(t) +} diff --git a/server/parser.go b/server/parser.go index 4595776b..b3ac091b 100644 --- a/server/parser.go +++ b/server/parser.go @@ -580,7 +580,7 @@ func (c *client) parse(buf []byte) error { if trace { c.traceInOp("SUB", arg) } - _, err = c.processSub(arg, false) + err = c.parseSub(arg, false) case ROUTER: switch c.op { case 'R', 'r': diff --git a/server/reload_test.go b/server/reload_test.go index eb6b4cc7..39cc9ea3 100644 --- a/server/reload_test.go +++ b/server/reload_test.go @@ -3845,7 +3845,7 @@ func TestConfigReloadLeafNodeWithRemotesNoChanges(t *testing.T) { s1, o1 := RunServerWithConfig(conf1) defer s1.Shutdown() - u, err := url.Parse(fmt.Sprintf("nats://localhost:%d", o1.LeafNode.Port)) + u, err := url.Parse(fmt.Sprintf("nats://127.0.0.1:%d", o1.LeafNode.Port)) if err != nil { t.Fatalf("Error creating url: %v", err) } diff --git a/server/stream.go b/server/stream.go index 48651128..7b26b25a 100644 --- a/server/stream.go +++ b/server/stream.go @@ -576,16 +576,7 @@ func (mset *Stream) subscribeInternal(subject string, cb msgHandler) (*subscript mset.sid++ // Now create the subscription - sub, err := c.processSub([]byte(subject+" "+strconv.Itoa(mset.sid)), false) - if err != nil { - return nil, err - } else if sub == nil { - return nil, fmt.Errorf("malformed subject") - } - c.mu.Lock() - sub.icb = cb - c.mu.Unlock() - return sub, nil + return c.processSub([]byte(subject), nil, []byte(strconv.Itoa(mset.sid)), cb, false) } // Helper for unlocked stream. diff --git a/test/jetstream_test.go b/test/jetstream_test.go index f1702fe4..142758cf 100644 --- a/test/jetstream_test.go +++ b/test/jetstream_test.go @@ -688,7 +688,7 @@ func TestJetStreamAddStreamBadSubjects(t *testing.T) { t.Fatalf("Unexpected error: %v", err) } e := scResp.Error - if e == nil || e.Code != 500 || e.Description != "malformed subject" { + if e == nil || e.Code != 500 || e.Description != server.ErrMalformedSubject.Error() { t.Fatalf("Did not get proper error response: %+v", e) } } diff --git a/test/leafnode_test.go b/test/leafnode_test.go index 5d8cfb84..456e9943 100644 --- a/test/leafnode_test.go +++ b/test/leafnode_test.go @@ -3558,6 +3558,8 @@ func TestServiceExportWithMultipleAccounts(t *testing.T) { }) nc2.Flush() + checkSubInterest(t, srvB, "INTERNAL", "foo", time.Second) + nc, err := nats.Connect(fmt.Sprintf("nats://good:pwd@%s:%d", optsB.Host, optsB.Port)) if err != nil { t.Fatalf("Error on connect: %v", err)