diff --git a/server/accounts_test.go b/server/accounts_test.go
index 61904e28..e54f8903 100644
--- a/server/accounts_test.go
+++ b/server/accounts_test.go
@@ -1,4 +1,4 @@
-// Copyright 2018-2022 The NATS Authors
+// Copyright 2018-2023 The NATS Authors
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
 // You may obtain a copy of the License at
@@ -3683,3 +3683,75 @@ func TestAccountImportDuplicateResponseDeliveryWithLeafnodes(t *testing.T) {
 		t.Fatalf("Expected only 1 response, got %d", n)
 	}
 }
+
+func TestAccountReloadServiceImportPanic(t *testing.T) {
+	conf := createConfFile(t, []byte(`
+		listen: 127.0.0.1:-1
+		accounts {
+			A {
+				users = [ { user: "a", pass: "p" } ]
+				exports [ { service: "HELP" } ]
+			}
+			B {
+				users = [ { user: "b", pass: "p" } ]
+				imports [ { service: { account: A, subject: "HELP"} } ]
+			}
+			$SYS { users = [ { user: "admin", pass: "s3cr3t!" } ] }
+		}
+	`))
+
+	s, _ := RunServerWithConfig(conf)
+	defer s.Shutdown()
+
+	// Now connect up the subscriber for HELP. No-op for this test.
+	nc, _ := jsClientConnect(t, s, nats.UserInfo("a", "p"))
+	_, err := nc.Subscribe("HELP", func(m *nats.Msg) { m.Respond([]byte("OK")) })
+	require_NoError(t, err)
+	defer nc.Close()
+
+	// Now create connection to account b where we will publish to HELP.
+	nc, _ = jsClientConnect(t, s, nats.UserInfo("b", "p"))
+	require_NoError(t, err)
+	defer nc.Close()
+
+	// We want to continually be publishing messages that will trigger the service import while calling reload.
+	done := make(chan bool)
+	var wg sync.WaitGroup
+	wg.Add(1)
+
+	var requests, responses atomic.Uint64
+	reply := nats.NewInbox()
+	_, err = nc.Subscribe(reply, func(m *nats.Msg) { responses.Add(1) })
+	require_NoError(t, err)
+
+	go func() {
+		defer wg.Done()
+		for {
+			select {
+			case <-done:
+				return
+			default:
+				nc.PublishRequest("HELP", reply, []byte("HELP"))
+				requests.Add(1)
+			}
+		}
+	}()
+
+	// Perform a bunch of reloads.
+	for i := 0; i < 1000; i++ {
+		err := s.Reload()
+		require_NoError(t, err)
+	}
+
+	close(done)
+	wg.Wait()
+
+	totalRequests := requests.Load()
+	checkFor(t, 10*time.Second, 250*time.Millisecond, func() error {
+		resp := responses.Load()
+		if resp == totalRequests {
+			return nil
+		}
+		return fmt.Errorf("Have not received all responses, want %d got %d", totalRequests, resp)
+	})
+}
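The test above leans on helpers from the server test suite (`createConfFile`, `RunServerWithConfig`, `jsClientConnect`, `require_NoError`, `checkFor`). For readers outside the repo, here is a minimal sketch of what a `checkFor`-style polling helper does, assuming only the standard library; the name `checkForSketch` and its exact behavior are illustrative, not the repo's implementation:

```go
package sketch

import (
	"testing"
	"time"
)

// checkForSketch polls check until it returns nil or totalWait elapses,
// sleeping sleepDur between attempts and failing the test with the last
// error on timeout. This mirrors how the test above waits for all
// responses to arrive after the reload loop stops.
func checkForSketch(t *testing.T, totalWait, sleepDur time.Duration, check func() error) {
	t.Helper()
	deadline := time.Now().Add(totalWait)
	var err error
	for time.Now().Before(deadline) {
		if err = check(); err == nil {
			return
		}
		time.Sleep(sleepDur)
	}
	t.Fatalf("timed out waiting for condition: %v", err)
}
```

Polling matters here because requests published just before `close(done)` may still be in flight; an immediate equality assertion on the two counters would flake.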
diff --git a/server/client.go b/server/client.go
index f3752482..4886a18e 100644
--- a/server/client.go
+++ b/server/client.go
@@ -789,15 +789,16 @@ func (c *client) subsAtLimit() bool {
 }
 
 func minLimit(value *int32, limit int32) bool {
-	if *value != jwt.NoLimit {
+	v := atomic.LoadInt32(value)
+	if v != jwt.NoLimit {
 		if limit != jwt.NoLimit {
-			if limit < *value {
-				*value = limit
+			if limit < v {
+				atomic.StoreInt32(value, limit)
 				return true
 			}
 		}
 	} else if limit != jwt.NoLimit {
-		*value = limit
+		atomic.StoreInt32(value, limit)
 		return true
 	}
 	return false
@@ -810,7 +811,7 @@ func (c *client) applyAccountLimits() {
 	if c.acc == nil || (c.kind != CLIENT && c.kind != LEAF) {
 		return
 	}
-	c.mpay = jwt.NoLimit
+	atomic.StoreInt32(&c.mpay, jwt.NoLimit)
 	c.msubs = jwt.NoLimit
 	if c.opts.JWT != _EMPTY_ { // user jwt implies account
 		if uc, _ := jwt.DecodeUserClaims(c.opts.JWT); uc != nil {
@@ -3576,15 +3577,21 @@ func (c *client) processInboundClientMsg(msg []byte) (bool, bool) {
 	}
 
 	// Mostly under testing scenarios.
+	c.mu.Lock()
 	if c.srv == nil || c.acc == nil {
+		c.mu.Unlock()
 		return false, false
 	}
+	acc := c.acc
+	genidAddr := &acc.sl.genid
 
 	// Check pub permissions
-	if c.perms != nil && (c.perms.pub.allow != nil || c.perms.pub.deny != nil) && !c.pubAllowed(string(c.pa.subject)) {
+	if c.perms != nil && (c.perms.pub.allow != nil || c.perms.pub.deny != nil) && !c.pubAllowedFullCheck(string(c.pa.subject), true, true) {
+		c.mu.Unlock()
 		c.pubPermissionViolation(c.pa.subject)
 		return false, true
 	}
+	c.mu.Unlock()
 
 	// Now check for reserved replies. These are used for service imports.
 	if c.kind == CLIENT && len(c.pa.reply) > 0 && isReservedReply(c.pa.reply) {
@@ -3605,10 +3612,10 @@ func (c *client) processInboundClientMsg(msg []byte) (bool, bool) {
 	// performance impact reported in our bench)
 	var isGWRouted bool
 	if c.kind != CLIENT {
-		if atomic.LoadInt32(&c.acc.gwReplyMapping.check) > 0 {
-			c.acc.mu.RLock()
-			c.pa.subject, isGWRouted = c.acc.gwReplyMapping.get(c.pa.subject)
-			c.acc.mu.RUnlock()
+		if atomic.LoadInt32(&acc.gwReplyMapping.check) > 0 {
+			acc.mu.RLock()
+			c.pa.subject, isGWRouted = acc.gwReplyMapping.get(c.pa.subject)
+			acc.mu.RUnlock()
 		}
 	} else if atomic.LoadInt32(&c.gwReplyMapping.check) > 0 {
 		c.mu.Lock()
@@ -3651,7 +3658,7 @@ func (c *client) processInboundClientMsg(msg []byte) (bool, bool) {
 	var r *SublistResult
 	var ok bool
 
-	genid := atomic.LoadUint64(&c.acc.sl.genid)
+	genid := atomic.LoadUint64(genidAddr)
 	if genid == c.in.genid && c.in.results != nil {
 		r, ok = c.in.results[string(c.pa.subject)]
 	} else {
@@ -3662,7 +3669,7 @@ func (c *client) processInboundClientMsg(msg []byte) (bool, bool) {
 
 	// Go back to the sublist data structure.
 	if !ok {
-		r = c.acc.sl.Match(string(c.pa.subject))
+		r = acc.sl.Match(string(c.pa.subject))
 		c.in.results[string(c.pa.subject)] = r
 		// Prune the results cache. Keeps us from unbounded growth. Random delete.
 		if len(c.in.results) > maxResultCacheSize {
@@ -3693,7 +3700,7 @@ func (c *client) processInboundClientMsg(msg []byte) (bool, bool) {
 			atomic.LoadInt64(&c.srv.gateway.totalQSubs) > 0 {
 			flag |= pmrCollectQueueNames
 		}
-		didDeliver, qnames = c.processMsgResults(c.acc, r, msg, c.pa.deliver, c.pa.subject, c.pa.reply, flag)
+		didDeliver, qnames = c.processMsgResults(acc, r, msg, c.pa.deliver, c.pa.subject, c.pa.reply, flag)
 	}
 
 	// Now deal with gateways
@@ -3703,7 +3710,7 @@ func (c *client) processInboundClientMsg(msg []byte) (bool, bool) {
 			reply = append(reply, '@')
 			reply = append(reply, c.pa.deliver...)
 		}
-		didDeliver = c.sendMsgToGateways(c.acc, msg, c.pa.subject, reply, qnames) || didDeliver
+		didDeliver = c.sendMsgToGateways(acc, msg, c.pa.subject, reply, qnames) || didDeliver
 	}
 
 	// Check to see if we did not deliver to anyone and the client has a reply subject set
@@ -3909,6 +3916,7 @@ func (c *client) processServiceImport(si *serviceImport, acc *Account, msg []byt
 			checkJS = true
 		}
 	}
+	siAcc := si.acc
 	acc.mu.RUnlock()
 
 	// We have a special case where JetStream pulls in all service imports through one export.
@@ -3939,7 +3947,7 @@ func (c *client) processServiceImport(si *serviceImport, acc *Account, msg []byt
 		}
 	} else if !isResponse && si.latency != nil && tracking {
 		// Check to see if this was a bad request with no reply and we were supposed to be tracking.
-		si.acc.sendBadRequestTrackingLatency(si, c, headers)
+		siAcc.sendBadRequestTrackingLatency(si, c, headers)
 	}
 
 	// Send tracking info here if we are tracking this response.
@@ -3967,7 +3975,7 @@ func (c *client) processServiceImport(si *serviceImport, acc *Account, msg []byt
 	// Now check to see if this account has mappings that could affect the service import.
 	// Can't use non-locked trick like in processInboundClientMsg, so just call into selectMappedSubject
 	// so we only lock once.
-	nsubj, changed := si.acc.selectMappedSubject(to)
+	nsubj, changed := siAcc.selectMappedSubject(to)
 	if changed {
 		c.pa.mapped = []byte(to)
 		to = nsubj
@@ -3984,7 +3992,7 @@ func (c *client) processServiceImport(si *serviceImport, acc *Account, msg []byt
 	// Place our client info for the request in the original message.
 	// This will survive going across routes, etc.
 	if !isResponse {
-		isSysImport := si.acc == c.srv.SystemAccount()
+		isSysImport := siAcc == c.srv.SystemAccount()
 		var ci *ClientInfo
 		if hadPrevSi && c.pa.hdr >= 0 {
 			var cis ClientInfo
@@ -4025,11 +4033,11 @@ func (c *client) processServiceImport(si *serviceImport, acc *Account, msg []byt
 	c.pa.reply = nrr
 
 	if changed && c.isMqtt() && c.pa.hdr > 0 {
-		c.srv.mqttStoreQoS1MsgForAccountOnNewSubject(c.pa.hdr, msg, si.acc.GetName(), to)
+		c.srv.mqttStoreQoS1MsgForAccountOnNewSubject(c.pa.hdr, msg, siAcc.GetName(), to)
 	}
 
 	// FIXME(dlc) - Do L1 cache trick like normal client?
-	rr := si.acc.sl.Match(to)
+	rr := siAcc.sl.Match(to)
 
 	// If we are a route or gateway or leafnode and this message is flipped to a queue subscriber we
 	// need to handle that since the processMsgResults will want a queue filter.
@@ -4054,10 +4062,10 @@ func (c *client) processServiceImport(si *serviceImport, acc *Account, msg []byt
 	if c.srv.gateway.enabled {
 		flags |= pmrCollectQueueNames
 		var queues [][]byte
-		didDeliver, queues = c.processMsgResults(si.acc, rr, msg, c.pa.deliver, []byte(to), nrr, flags)
-		didDeliver = c.sendMsgToGateways(si.acc, msg, []byte(to), nrr, queues) || didDeliver
+		didDeliver, queues = c.processMsgResults(siAcc, rr, msg, c.pa.deliver, []byte(to), nrr, flags)
+		didDeliver = c.sendMsgToGateways(siAcc, msg, []byte(to), nrr, queues) || didDeliver
 	} else {
-		didDeliver, _ = c.processMsgResults(si.acc, rr, msg, c.pa.deliver, []byte(to), nrr, flags)
+		didDeliver, _ = c.processMsgResults(siAcc, rr, msg, c.pa.deliver, []byte(to), nrr, flags)
 	}
 
 	// Restore to original values.
@@ -4090,7 +4098,7 @@ func (c *client) processServiceImport(si *serviceImport, acc *Account, msg []byt
 		} else {
 			// This is a main import and since we could not even deliver to the exporting account
 			// go ahead and remove the respServiceImport we created above.
-			si.acc.removeRespServiceImport(rsi, reason)
+			siAcc.removeRespServiceImport(rsi, reason)
 		}
 	}
 }
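The `minLimit` change at the top of this file is the standard fix for mixed plain and atomic access: once a reload can rewrite `c.mpay` via `applyAccountLimits` while the inbound path reads it concurrently, every access must go through `sync/atomic`, because pairing a plain `*value` dereference with atomic writes is a data race under the Go memory model. A standalone sketch of the pattern, with `noLimit` as a stand-in for `jwt.NoLimit` and a hypothetical `payloadExceeds` reader:

```go
package sketch

import "sync/atomic"

const noLimit int32 = -1 // stand-in for jwt.NoLimit

// minLimit lowers *value to limit when limit is the stricter setting.
// Both the read and the write are atomic, so a reload can tighten the
// limit while another goroutine is loading it on the hot path.
func minLimit(value *int32, limit int32) bool {
	v := atomic.LoadInt32(value)
	if v != noLimit {
		if limit != noLimit && limit < v {
			atomic.StoreInt32(value, limit)
			return true
		}
	} else if limit != noLimit {
		atomic.StoreInt32(value, limit)
		return true
	}
	return false
}

// payloadExceeds shows how a reader on the message path pairs with it.
func payloadExceeds(mpay *int32, size int32) bool {
	limit := atomic.LoadInt32(mpay)
	return limit != noLimit && size > limit
}
```

The same reasoning drives the `acc := c.acc` and `genidAddr := &acc.sl.genid` captures in `processInboundClientMsg`: the account pointer is read once under the client lock, so the rest of the function cannot observe `c.acc` being swapped out mid-delivery by a concurrent reload.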
diff --git a/server/server.go b/server/server.go
index ec623990..eeb12c27 100644
--- a/server/server.go
+++ b/server/server.go
@@ -753,6 +753,12 @@ func (s *Server) configureAccounts(reloading bool) (map[string]struct{}, error)
 
 	opts := s.getOpts()
 
+	// We need to track service imports since we can not swap them out (unsub and re-sub)
+	// until the proper server struct accounts have been swapped in properly. Doing it in
+	// place could lead to data loss or server panic since account under new si has no real
+	// account and hence no sublist, so will panic on inbound message.
+	siMap := make(map[*Account][][]byte)
+
 	// Check opts and walk through them. We need to copy them here
 	// so that we do not keep a real one sitting in the options.
 	for _, acc := range opts.Accounts {
@@ -773,12 +779,16 @@ func (s *Server) configureAccounts(reloading bool) (map[string]struct{}, error)
 			// Collect the sids for the service imports since we are going to
 			// replace with new ones.
 			var sids [][]byte
-			c := a.ic
 			for _, si := range a.imports.services {
-				if c != nil && si.sid != nil {
+				if si.sid != nil {
 					sids = append(sids, si.sid)
 				}
 			}
+			// Setup to process later if needed.
+			if len(sids) > 0 || len(acc.imports.services) > 0 {
+				siMap[a] = sids
+			}
+
 			// Now reset all export/imports fields since they are going to be
 			// filled in shallowCopy()
 			a.imports.streams, a.imports.services = nil, nil
@@ -787,14 +797,6 @@ func (s *Server) configureAccounts(reloading bool) (map[string]struct{}, error)
 			// and pass `a` (our existing account) to get it updated.
 			acc.shallowCopy(a)
 			a.mu.Unlock()
-			// Need to release the lock for this.
-			s.mu.Unlock()
-			for _, sid := range sids {
-				c.processUnsub(sid)
-			}
-			// Add subscriptions for existing service imports.
-			a.addAllServiceImportSubs()
-			s.mu.Lock()
 			create = false
 		}
 	}
@@ -862,6 +864,7 @@ func (s *Server) configureAccounts(reloading bool) (map[string]struct{}, error)
 		for _, si := range acc.imports.services {
 			if v, ok := s.accounts.Load(si.acc.Name); ok {
 				si.acc = v.(*Account)
+
 				// It is possible to allow for latency tracking inside your
 				// own account, so lock only when not the same account.
 				if si.acc == acc {
@@ -889,6 +892,16 @@ func (s *Server) configureAccounts(reloading bool) (map[string]struct{}, error)
 		return true
 	})
 
+	// Check if we need to process service imports pending from above.
+	// This processing needs to be after we swap in the real accounts above.
+	for acc, sids := range siMap {
+		c := acc.ic
+		for _, sid := range sids {
+			c.processUnsub(sid)
+		}
+		acc.addAllServiceImportSubs()
+	}
+
 	// Set the system account if it was configured.
 	// Otherwise create a default one.
 	if opts.SystemAccount != _EMPTY_ {
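Stripped of NATS specifics, the server.go change is a collect-then-apply reload: while accounts are being shallow-copied, the old service-import subscription ids are only recorded in `siMap`, and the unsub/re-sub work runs after every entry in the accounts map points at a real, fully wired `*Account` with a live sublist. A minimal sketch of that shape, with all types and names illustrative:

```go
package sketch

// account stands in for the server's Account type: it owns
// subscriptions identified by sids that must be torn down and
// re-created when configuration is reloaded.
type account struct {
	sids [][]byte
}

func (a *account) unsubscribe(sid []byte) { /* drop the old sub */ }
func (a *account) addServiceImportSubs()  { /* re-sub against the live sublist */ }
func (a *account) swapInRealState()       { /* wire the copy into the real account */ }

// reconfigure mirrors the two-phase shape of configureAccounts:
// phase 1 refreshes account state and only records pending
// subscription work; phase 2 performs the unsub/re-sub once every
// account is consistent, so no inbound message can match a
// half-built account.
func reconfigure(accts []*account) {
	pending := make(map[*account][][]byte)

	// Phase 1: copy/refresh state, defer all subscription churn.
	for _, a := range accts {
		if len(a.sids) > 0 {
			pending[a] = a.sids
		}
		a.swapInRealState()
	}

	// Phase 2: now safe to drop old subs and install new ones.
	for a, sids := range pending {
		for _, sid := range sids {
			a.unsubscribe(sid)
		}
		a.addServiceImportSubs()
	}
}
```

The removed block in the `@@ -787,14 +797,6 @@` hunk did both steps in place while other accounts were still half-swapped; deferring phase 2 means a message arriving mid-reload can never hit a service import whose account has no sublist yet, which is exactly the panic the new test reproduces.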