From 22833c8d1a827bf06e48b8a71344cf8293c859e2 Mon Sep 17 00:00:00 2001 From: Ivan Kozlovic Date: Mon, 3 Aug 2020 11:18:02 -0600 Subject: [PATCH] Fix sysSubscribe races Made changes to processSub() to accept subscription properties, including the icb callback so that it is set prior to add the subscription to the account's sublist, which prevent races. Fixed some other racy conditions, notably in addServiceImportSub() Signed-off-by: Ivan Kozlovic --- server/accounts.go | 59 +++++++++++++++-------------------------- server/accounts_test.go | 40 ++++++++++++++++++++++++++++ server/client.go | 40 +++++++++++++++++++--------- server/errors.go | 6 +++++ server/events.go | 12 ++------- server/events_test.go | 54 +++++++++++++++++++++++++++++++++++++ server/jetstream.go | 6 +---- server/norace_test.go | 8 ++++++ server/parser.go | 2 +- server/reload_test.go | 2 +- server/stream.go | 11 +------- test/jetstream_test.go | 2 +- test/leafnode_test.go | 2 ++ 13 files changed, 166 insertions(+), 78 deletions(-) diff --git a/server/accounts.go b/server/accounts.go index 0f17687b..b9606f4f 100644 --- a/server/accounts.go +++ b/server/accounts.go @@ -104,7 +104,7 @@ type serviceImport struct { acc *Account claim *jwt.Import se *serviceExport - sub *subscription + sid []byte from string to string exsub string @@ -1153,8 +1153,8 @@ func (a *Account) removeServiceImport(subject string) { c := a.ic if ok && si != nil { - if a.ic != nil && si.sub != nil && si.sub.sid != nil { - sid = si.sub.sid + if a.ic != nil && si.sid != nil { + sid = si.sid } } a.mu.Unlock() @@ -1355,46 +1355,33 @@ func (a *Account) subscribeInternal(subject string, cb msgHandler) (*subscriptio return nil, fmt.Errorf("no internal account client") } - sub, err := c.processSub([]byte(subject+" "+sid), false) - if err != nil { - return nil, err - } - - sub.icb = cb - return sub, nil + return c.processSub([]byte(subject), nil, []byte(sid), cb, false) } // This will add an account subscription that matches the "from" from a service import entry. func (a *Account) addServiceImportSub(si *serviceImport) error { a.mu.Lock() c := a.internalClient() - sid := strconv.FormatUint(a.isid+1, 10) - a.mu.Unlock() - // This will happen in parsing when the account has not been properly setup. if c == nil { + a.mu.Unlock() return nil } - - if si.sub != nil { + if si.sid != nil { + a.mu.Unlock() return fmt.Errorf("duplicate call to create subscription for service import") } - - sub, err := c.processSub([]byte(si.from+" "+sid), true) - if err != nil { - return err - } - - sub.icb = func(sub *subscription, c *client, subject, reply string, msg []byte) { - c.processServiceImport(si, a, msg) - } - - a.mu.Lock() a.isid++ - si.sub = sub + sid := strconv.FormatUint(a.isid, 10) + si.sid = []byte(sid) + subject := si.from a.mu.Unlock() - return nil + cb := func(sub *subscription, c *client, subject, reply string, msg []byte) { + c.processServiceImport(si, a, msg) + } + _, err := c.processSub([]byte(subject), nil, []byte(sid), cb, true) + return err } // Remove all the subscriptions associated with service imports. @@ -1402,9 +1389,9 @@ func (a *Account) removeAllServiceImportSubs() { a.mu.RLock() var sids [][]byte for _, si := range a.imports.services { - if si.sub != nil && si.sub.sid != nil { - sids = append(sids, si.sub.sid) - si.sub = nil + if si.sid != nil { + sids = append(sids, si.sid) + si.sid = nil } } c := a.ic @@ -1486,14 +1473,12 @@ func (a *Account) createRespWildcard() []byte { pre := a.siReply wcsub := append(a.siReply, '>') c := a.internalClient() - a.isid += 1 + a.isid++ sid := strconv.FormatUint(a.isid, 10) a.mu.Unlock() // Create subscription and internal callback for all the wildcard response subjects. - if sub, _ := c.processSub([]byte(string(wcsub)+" "+sid), false); sub != nil { - sub.icb = a.processServiceImportResponse - } + c.processSub(wcsub, nil, []byte(sid), a.processServiceImportResponse, false) return pre } @@ -2451,8 +2436,8 @@ func (s *Server) UpdateAccountClaims(a *Account, ac *jwt.AccountClaims) { a.mu.RLock() c := a.ic for _, si := range old.imports.services { - if c != nil && si.sub != nil && si.sub.sid != nil { - sids = append(sids, si.sub.sid) + if c != nil && si.sid != nil { + sids = append(sids, si.sid) } } a.mu.RUnlock() diff --git a/server/accounts_test.go b/server/accounts_test.go index 24fa0fb9..4d5f730a 100644 --- a/server/accounts_test.go +++ b/server/accounts_test.go @@ -1193,6 +1193,46 @@ func TestServiceExportWithWildcards(t *testing.T) { } } +func TestAccountAddServiceImportRace(t *testing.T) { + s, fooAcc, barAcc := simpleAccountServer(t) + defer s.Shutdown() + + if err := fooAcc.AddServiceExport("foo.*", nil); err != nil { + t.Fatalf("Error adding account service export to client foo: %v", err) + } + + total := 100 + errCh := make(chan error, total) + for i := 0; i < 100; i++ { + go func(i int) { + err := barAcc.AddServiceImport(fooAcc, fmt.Sprintf("foo.%d", i), "") + errCh <- err // nil is a valid value. + + }(i) + } + + for i := 0; i < 100; i++ { + err := <-errCh + if err != nil { + t.Fatalf("Error adding account service import: %v", err) + } + } + + barAcc.mu.Lock() + lens := len(barAcc.imports.services) + c := barAcc.internalClient() + barAcc.mu.Unlock() + if lens != total { + t.Fatalf("Expected %d imported services, got %d", total, lens) + } + c.mu.Lock() + lens = len(c.subs) + c.mu.Unlock() + if lens != total { + t.Fatalf("Expected %d subscriptions in internal client, got %d", total, lens) + } +} + func TestServiceImportWithWildcards(t *testing.T) { s, fooAcc, barAcc := simpleAccountServer(t) defer s.Shutdown() diff --git a/server/client.go b/server/client.go index 24e2b1d9..fcdf08f1 100644 --- a/server/client.go +++ b/server/client.go @@ -2104,25 +2104,39 @@ func splitArg(arg []byte) [][]byte { return args } -func (c *client) processSub(argo []byte, noForward bool) (*subscription, error) { +func (c *client) parseSub(argo []byte, noForward bool) error { // Copy so we do not reference a potentially large buffer // FIXME(dlc) - make more efficient. arg := make([]byte, len(argo)) copy(arg, argo) args := splitArg(arg) - sub := &subscription{client: c} + var ( + subject []byte + queue []byte + sid []byte + ) switch len(args) { case 2: - sub.subject = args[0] - sub.queue = nil - sub.sid = args[1] + subject = args[0] + queue = nil + sid = args[1] case 3: - sub.subject = args[0] - sub.queue = args[1] - sub.sid = args[2] + subject = args[0] + queue = args[1] + sid = args[2] default: - return nil, fmt.Errorf("processSub Parse Error: '%s'", arg) + return fmt.Errorf("processSub Parse Error: '%s'", arg) } + // If there was an error, it has been sent to the client. We don't return an + // error here to not close the connection as a parsing error. + c.processSub(subject, queue, sid, nil, noForward) + return nil +} + +func (c *client) processSub(subject, queue, bsid []byte, cb msgHandler, noForward bool) (*subscription, error) { + + // Create the subscription + sub := &subscription{client: c, subject: subject, queue: queue, sid: bsid, icb: cb} c.mu.Lock() @@ -2155,12 +2169,12 @@ func (c *client) processSub(argo []byte, noForward bool) (*subscription, error) if !c.canQueueSubscribe(string(sub.subject), string(sub.queue)) { c.mu.Unlock() c.subPermissionViolation(sub) - return nil, nil + return nil, ErrSubscribePermissionViolation } } else if !c.canSubscribe(string(sub.subject)) { c.mu.Unlock() c.subPermissionViolation(sub) - return nil, nil + return nil, ErrSubscribePermissionViolation } } @@ -2168,7 +2182,7 @@ func (c *client) processSub(argo []byte, noForward bool) (*subscription, error) if c.subsAtLimit() { c.mu.Unlock() c.maxSubsExceeded() - return nil, nil + return nil, ErrTooManySubs } var updateGWs bool @@ -2192,7 +2206,7 @@ func (c *client) processSub(argo []byte, noForward bool) (*subscription, error) if err != nil { c.sendErr("Invalid Subject") - return nil, nil + return nil, ErrMalformedSubject } else if c.opts.Verbose && kind != SYSTEM { c.sendOK() } diff --git a/server/errors.go b/server/errors.go index e323ce06..e3dfc58f 100644 --- a/server/errors.go +++ b/server/errors.go @@ -154,6 +154,12 @@ var ( // ErrClusterNameRemoteConflict signals that a remote server has a different cluster name. ErrClusterNameRemoteConflict = errors.New("cluster name from remote server conflicts") + + // ErrMalformedSubject is returned when a subscription is made with a subject that does not conform to subject rules. + ErrMalformedSubject = errors.New("malformed subject") + + // ErrSubscribePermissionViolation is returned when processing of a subscription fails due to permissions. + ErrSubscribePermissionViolation = errors.New("subscribe permission viloation") ) // configErr is a configuration error. diff --git a/server/events.go b/server/events.go index 5e7ad5a4..1e5ec2fd 100644 --- a/server/events.go +++ b/server/events.go @@ -1227,20 +1227,12 @@ func (s *Server) systemSubscribe(subject string, internalOnly bool, cb msgHandle sid := strconv.Itoa(s.sys.sid) s.mu.Unlock() - arg := []byte(subject + " " + sid) if trace { - c.traceInOp("SUB", arg) + c.traceInOp("SUB", []byte(subject+" "+sid)) } // Now create the subscription - sub, err := c.processSub(arg, internalOnly) - if err != nil { - return nil, err - } - c.mu.Lock() - sub.icb = cb - c.mu.Unlock() - return sub, nil + return c.processSub([]byte(subject), nil, []byte(sid), cb, internalOnly) } func (s *Server) sysUnsubscribe(sub *subscription) { diff --git a/server/events_test.go b/server/events_test.go index 43f48674..5fc6ef4d 100644 --- a/server/events_test.go +++ b/server/events_test.go @@ -563,6 +563,60 @@ func TestSystemAccountDisconnectBadLogin(t *testing.T) { } } +func TestSysSubscribeRace(t *testing.T) { + s, opts := runTrustedServer(t) + defer s.Shutdown() + + acc, akp := createAccount(s) + s.setSystemAccount(acc) + + url := fmt.Sprintf("nats://%s:%d", opts.Host, opts.Port) + + nc, err := nats.Connect(url, createUserCreds(t, s, akp)) + if err != nil { + t.Fatalf("Error on connect: %v", err) + } + defer nc.Close() + + done := make(chan struct{}) + wg := sync.WaitGroup{} + wg.Add(1) + go func() { + defer wg.Done() + for { + nc.Publish("foo", []byte("hello")) + select { + case <-done: + return + default: + } + } + }() + + time.Sleep(10 * time.Millisecond) + + received := make(chan struct{}) + // Create message callback handler. + cb := func(sub *subscription, producer *client, subject, reply string, msg []byte) { + select { + case received <- struct{}{}: + default: + } + } + // Now create an internal subscription + sub, err := s.sysSubscribe("foo", cb) + if sub == nil || err != nil { + t.Fatalf("Expected to subscribe, got %v", err) + } + select { + case <-received: + close(done) + case <-time.After(time.Second): + t.Fatalf("Did not receive the message") + } + wg.Wait() +} + func TestSystemAccountInternalSubscriptions(t *testing.T) { s, opts := runTrustedServer(t) defer s.Shutdown() diff --git a/server/jetstream.go b/server/jetstream.go index 91929ea0..cd7a9b50 100644 --- a/server/jetstream.go +++ b/server/jetstream.go @@ -1056,14 +1056,10 @@ func (t *StreamTemplate) createTemplateSubscriptions() error { sid := 1 for _, subject := range t.Config.Subjects { // Now create the subscription - sub, err := c.processSub([]byte(subject+" "+strconv.Itoa(sid)), false) - if err != nil { + if _, err := c.processSub([]byte(subject), nil, []byte(strconv.Itoa(sid)), t.processInboundTemplateMsg, false); err != nil { c.acc.DeleteStreamTemplate(t.Name) return err } - c.mu.Lock() - sub.icb = t.processInboundTemplateMsg - c.mu.Unlock() sid++ } return nil diff --git a/server/norace_test.go b/server/norace_test.go index 5509a151..535d35dd 100644 --- a/server/norace_test.go +++ b/server/norace_test.go @@ -910,3 +910,11 @@ func TestNoRaceLeafNodeClusterNameConflictDeadlock(t *testing.T) { checkLeafNodeConnected(t, s3) checkClusterFormed(t, s1, s2, s3) } + +// This test is same than TestAccountAddServiceImportRace but running +// without the -race flag, it would capture more easily the possible +// duplicate sid, resulting in less than expected number of subscriptions +// in the account's internal subscriptions map. +func TestNoRaceAccountAddServiceImportRace(t *testing.T) { + TestAccountAddServiceImportRace(t) +} diff --git a/server/parser.go b/server/parser.go index 4595776b..b3ac091b 100644 --- a/server/parser.go +++ b/server/parser.go @@ -580,7 +580,7 @@ func (c *client) parse(buf []byte) error { if trace { c.traceInOp("SUB", arg) } - _, err = c.processSub(arg, false) + err = c.parseSub(arg, false) case ROUTER: switch c.op { case 'R', 'r': diff --git a/server/reload_test.go b/server/reload_test.go index eb6b4cc7..39cc9ea3 100644 --- a/server/reload_test.go +++ b/server/reload_test.go @@ -3845,7 +3845,7 @@ func TestConfigReloadLeafNodeWithRemotesNoChanges(t *testing.T) { s1, o1 := RunServerWithConfig(conf1) defer s1.Shutdown() - u, err := url.Parse(fmt.Sprintf("nats://localhost:%d", o1.LeafNode.Port)) + u, err := url.Parse(fmt.Sprintf("nats://127.0.0.1:%d", o1.LeafNode.Port)) if err != nil { t.Fatalf("Error creating url: %v", err) } diff --git a/server/stream.go b/server/stream.go index 48651128..7b26b25a 100644 --- a/server/stream.go +++ b/server/stream.go @@ -576,16 +576,7 @@ func (mset *Stream) subscribeInternal(subject string, cb msgHandler) (*subscript mset.sid++ // Now create the subscription - sub, err := c.processSub([]byte(subject+" "+strconv.Itoa(mset.sid)), false) - if err != nil { - return nil, err - } else if sub == nil { - return nil, fmt.Errorf("malformed subject") - } - c.mu.Lock() - sub.icb = cb - c.mu.Unlock() - return sub, nil + return c.processSub([]byte(subject), nil, []byte(strconv.Itoa(mset.sid)), cb, false) } // Helper for unlocked stream. diff --git a/test/jetstream_test.go b/test/jetstream_test.go index f1702fe4..142758cf 100644 --- a/test/jetstream_test.go +++ b/test/jetstream_test.go @@ -688,7 +688,7 @@ func TestJetStreamAddStreamBadSubjects(t *testing.T) { t.Fatalf("Unexpected error: %v", err) } e := scResp.Error - if e == nil || e.Code != 500 || e.Description != "malformed subject" { + if e == nil || e.Code != 500 || e.Description != server.ErrMalformedSubject.Error() { t.Fatalf("Did not get proper error response: %+v", e) } } diff --git a/test/leafnode_test.go b/test/leafnode_test.go index 5d8cfb84..456e9943 100644 --- a/test/leafnode_test.go +++ b/test/leafnode_test.go @@ -3558,6 +3558,8 @@ func TestServiceExportWithMultipleAccounts(t *testing.T) { }) nc2.Flush() + checkSubInterest(t, srvB, "INTERNAL", "foo", time.Second) + nc, err := nats.Connect(fmt.Sprintf("nats://good:pwd@%s:%d", optsB.Host, optsB.Port)) if err != nil { t.Fatalf("Error on connect: %v", err)