diff --git a/server/gateway.go b/server/gateway.go index 5568e33a..1fb0ce9b 100644 --- a/server/gateway.go +++ b/server/gateway.go @@ -162,15 +162,8 @@ type srvGateway struct { m map[string]map[string]*sitally } - // This is to track recent subscriptions for a given connection + // This is to track recent subscriptions for a given account rsubs sync.Map - // This client will be used for SYSTEM clients when storing/looking up - // recent subscriptions in rsubs. This is because some code may not - // use the same actual *client object for SYSTEM client. For instance - // a raft node creates an internal client, that would be used to store - // the subscription in rsubs, but the sending part that checks for rsubs - // is using an internal client created in sendq.go's internalLoop. - sysCli *client resolver netResolver // Used to resolve host name before calling net.Dial() sqbsz int // Max buffer size to send queue subs protocol. Used for testing. @@ -366,7 +359,6 @@ func (s *Server) newGateway(opts *Options) error { resolver: opts.Gateway.resolver, runknown: opts.Gateway.RejectUnknown, oldHash: getOldHash(opts.Gateway.Name), - sysCli: &client{}, } gateway.Lock() defer gateway.Unlock() @@ -2367,16 +2359,13 @@ func (s *Server) gatewayUpdateSubInterest(accName string, sub *subscription, cha } if sub.client != nil { rsubs := &s.gateway.rsubs - c := sub.client - if c.kind == SYSTEM { - c = s.gateway.sysCli - } - sli, _ := rsubs.Load(c) + acc := sub.client.acc + sli, _ := rsubs.Load(acc) if change > 0 { var sl *Sublist if sli == nil { sl = NewSublistNoCache() - rsubs.Store(c, sl) + rsubs.Store(acc, sl) } else { sl = sli.(*Sublist) } @@ -2388,7 +2377,7 @@ func (s *Server) gatewayUpdateSubInterest(accName string, sub *subscription, cha sl := sli.(*Sublist) sl.Remove(sub) if sl.Count() == 0 { - rsubs.Delete(c) + rsubs.Delete(acc) } } } @@ -2427,21 +2416,10 @@ func hasGWRoutedReplyPrefix(subj []byte) bool { } // Evaluates if the given reply should be mapped or not. -func (g *srvGateway) shouldMapReplyForGatewaySend(c *client, acc *Account, reply []byte) bool { - // If the reply is a service reply (_R_), we will use the account's internal - // client instead of the client handed to us. This client holds the wildcard - // for all service replies. For other kind of connections, we still use the - // given `client` object. - if isServiceReply(reply) && c.kind == CLIENT { - acc.mu.Lock() - c = acc.internalClient() - acc.mu.Unlock() - } else if c.kind == SYSTEM { - c = g.sysCli - } - // If for this client there is a recent matching subscription interest +func (g *srvGateway) shouldMapReplyForGatewaySend(acc *Account, reply []byte) bool { + // If for this account there is a recent matching subscription interest // then we will map. - sli, _ := g.rsubs.Load(c) + sli, _ := g.rsubs.Load(acc) if sli == nil { return false } @@ -2566,7 +2544,7 @@ func (c *client) sendMsgToGateways(acc *Account, msg, subject, reply []byte, qgr // Assume we will use original mreply = reply // Decide if we should map. - if gw.shouldMapReplyForGatewaySend(c, acc, reply) { + if gw.shouldMapReplyForGatewaySend(acc, reply) { mreply = mreplya[:0] gwc.mu.Lock() useOldPrefix := gwc.gw.useOldPrefix diff --git a/server/jetstream_super_cluster_test.go b/server/jetstream_super_cluster_test.go index 1f1b756f..db701528 100644 --- a/server/jetstream_super_cluster_test.go +++ b/server/jetstream_super_cluster_test.go @@ -3761,3 +3761,77 @@ func TestJetStreamSuperClusterMixedModeSwitchToInterestOnlyOperatorConfig(t *tes waitForOutboundGateways(t, s, 2, 5*time.Second) check(s) } + +type captureGWRewriteLogger struct { + DummyLogger + ch chan string +} + +func (l *captureGWRewriteLogger) Tracef(format string, args ...interface{}) { + msg := fmt.Sprintf(format, args...) + if strings.Contains(msg, "$JS.SNAPSHOT.ACK.TEST") && strings.Contains(msg, gwReplyPrefix) { + select { + case l.ch <- msg: + default: + } + } +} + +func TestJetStreamSuperClusterGWReplyRewrite(t *testing.T) { + sc := createJetStreamSuperCluster(t, 3, 2) + defer sc.shutdown() + + nc, js := jsClientConnect(t, sc.serverByName("C1-S1")) + defer nc.Close() + + _, err := js.AddStream(&nats.StreamConfig{ + Name: "TEST", + Subjects: []string{"foo"}, + Replicas: 3, + }) + require_NoError(t, err) + sc.waitOnStreamLeader(globalAccountName, "TEST") + + for i := 0; i < 10; i++ { + sendStreamMsg(t, nc, "foo", "msg") + } + + nc2, _ := jsClientConnect(t, sc.serverByName("C2-S2")) + defer nc2.Close() + + s := sc.clusters[0].streamLeader(globalAccountName, "TEST") + var gws []*client + s.getOutboundGatewayConnections(&gws) + for _, gw := range gws { + gw.mu.Lock() + gw.trace = true + gw.mu.Unlock() + } + l := &captureGWRewriteLogger{ch: make(chan string, 1)} + s.SetLogger(l, false, true) + + // Send a request through the gateway + sreq := &JSApiStreamSnapshotRequest{ + DeliverSubject: nats.NewInbox(), + ChunkSize: 512, + } + natsSub(t, nc2, sreq.DeliverSubject, func(m *nats.Msg) { + m.Respond(nil) + }) + natsFlush(t, nc2) + req, _ := json.Marshal(sreq) + rmsg, err := nc2.Request(fmt.Sprintf(JSApiStreamSnapshotT, "TEST"), req, time.Second) + require_NoError(t, err) + var resp JSApiStreamSnapshotResponse + err = json.Unmarshal(rmsg.Data, &resp) + require_NoError(t, err) + if resp.Error != nil { + t.Fatalf("Did not get correct error response: %+v", resp.Error) + } + + // Now we just want to make sure that the reply has the gateway prefix + select { + case <-l.ch: + case <-time.After(10 * time.Second): + } +}