diff --git a/server/raft_helpers_test.go b/server/raft_helpers_test.go new file mode 100644 index 00000000..6d235d35 --- /dev/null +++ b/server/raft_helpers_test.go @@ -0,0 +1,276 @@ +// Copyright 2023 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Do not exlude this file with the !skip_js_tests since those helpers +// are also used by MQTT. + +package server + +import ( + "encoding/binary" + "fmt" + "math/rand" + "sync" + "testing" + "time" +) + +type stateMachine interface { + server() *Server + node() RaftNode + // This will call forward as needed so can be called on any node. + propose(data []byte) + // When entries have been committed and can be applied. + applyEntry(ce *CommittedEntry) + // When a leader change happens. + leaderChange(isLeader bool) + // Stop the raft group. + stop() + // Restart + restart() +} + +// Factory function needed for constructor. +type smFactory func(s *Server, cfg *RaftConfig, node RaftNode) stateMachine + +type smGroup []stateMachine + +// Leader of the group. +func (sg smGroup) leader() stateMachine { + for _, sm := range sg { + if sm.node().Leader() { + return sm + } + } + return nil +} + +// Wait on a leader to be elected. +func (sg smGroup) waitOnLeader() { + expires := time.Now().Add(10 * time.Second) + for time.Now().Before(expires) { + for _, sm := range sg { + if sm.node().Leader() { + return + } + } + time.Sleep(100 * time.Millisecond) + } +} + +// Pick a random member. +func (sg smGroup) randomMember() stateMachine { + return sg[rand.Intn(len(sg))] +} + +// Return a non-leader +func (sg smGroup) nonLeader() stateMachine { + for _, sm := range sg { + if !sm.node().Leader() { + return sm + } + } + return nil +} + +// Create a raft group and place on numMembers servers at random. +func (c *cluster) createRaftGroup(name string, numMembers int, smf smFactory) smGroup { + c.t.Helper() + if numMembers > len(c.servers) { + c.t.Fatalf("Members > Peers: %d vs %d", numMembers, len(c.servers)) + } + servers := append([]*Server{}, c.servers...) + rand.Shuffle(len(servers), func(i, j int) { servers[i], servers[j] = servers[j], servers[i] }) + return c.createRaftGroupWithPeers(name, servers[:numMembers], smf) +} + +func (c *cluster) createRaftGroupWithPeers(name string, servers []*Server, smf smFactory) smGroup { + c.t.Helper() + + var sg smGroup + var peers []string + + for _, s := range servers { + // generate peer names. + s.mu.RLock() + peers = append(peers, s.sys.shash) + s.mu.RUnlock() + } + + for _, s := range servers { + fs, err := newFileStore( + FileStoreConfig{StoreDir: c.t.TempDir(), BlockSize: defaultMediumBlockSize, AsyncFlush: false, SyncInterval: 5 * time.Minute}, + StreamConfig{Name: name, Storage: FileStorage}, + ) + require_NoError(c.t, err) + cfg := &RaftConfig{Name: name, Store: c.t.TempDir(), Log: fs} + s.bootstrapRaftNode(cfg, peers, true) + n, err := s.startRaftNode(globalAccountName, cfg) + require_NoError(c.t, err) + sm := smf(s, cfg, n) + sg = append(sg, sm) + go smLoop(sm) + } + return sg +} + +// Driver program for the state machine. +// Should be run in its own go routine. +func smLoop(sm stateMachine) { + s, n := sm.server(), sm.node() + qch, lch, aq := n.QuitC(), n.LeadChangeC(), n.ApplyQ() + + for { + select { + case <-s.quitCh: + return + case <-qch: + return + case <-aq.ch: + ces := aq.pop() + for _, ce := range ces { + sm.applyEntry(ce) + } + aq.recycle(&ces) + + case isLeader := <-lch: + sm.leaderChange(isLeader) + } + } +} + +// Simple implementation of a replicated state. +// The adder state just sums up int64 values. +type stateAdder struct { + sync.Mutex + s *Server + n RaftNode + cfg *RaftConfig + sum int64 +} + +// Simple getters for server and the raft node. +func (a *stateAdder) server() *Server { + a.Lock() + defer a.Unlock() + return a.s +} +func (a *stateAdder) node() RaftNode { + a.Lock() + defer a.Unlock() + return a.n +} + +func (a *stateAdder) propose(data []byte) { + a.Lock() + defer a.Unlock() + a.n.ForwardProposal(data) +} + +func (a *stateAdder) applyEntry(ce *CommittedEntry) { + a.Lock() + defer a.Unlock() + if ce == nil { + // This means initial state is done/replayed. + return + } + for _, e := range ce.Entries { + if e.Type == EntryNormal { + delta, _ := binary.Varint(e.Data) + a.sum += delta + } else if e.Type == EntrySnapshot { + a.sum, _ = binary.Varint(e.Data) + } + } + // Update applied. + a.n.Applied(ce.Index) +} + +func (a *stateAdder) leaderChange(isLeader bool) {} + +// Adder specific to change the total. +func (a *stateAdder) proposeDelta(delta int64) { + data := make([]byte, binary.MaxVarintLen64) + n := binary.PutVarint(data, int64(delta)) + a.propose(data[:n]) +} + +// Stop the group. +func (a *stateAdder) stop() { + a.Lock() + defer a.Unlock() + a.n.Stop() +} + +// Restart the group +func (a *stateAdder) restart() { + a.Lock() + defer a.Unlock() + + if a.n.State() != Closed { + return + } + + // The filestore is stopped as well, so need to extract the parts to recreate it. + rn := a.n.(*raft) + fs := rn.wal.(*fileStore) + + var err error + a.cfg.Log, err = newFileStore(fs.fcfg, fs.cfg.StreamConfig) + if err != nil { + panic(err) + } + a.n, err = a.s.startRaftNode(globalAccountName, a.cfg) + if err != nil { + panic(err) + } + // Finally restart the driver. + go smLoop(a) +} + +// Total for the adder state machine. +func (a *stateAdder) total() int64 { + a.Lock() + defer a.Unlock() + return a.sum +} + +// Install a snapshot. +func (a *stateAdder) snapshot(t *testing.T) { + a.Lock() + defer a.Unlock() + data := make([]byte, binary.MaxVarintLen64) + n := binary.PutVarint(data, a.sum) + snap := data[:n] + require_NoError(t, a.n.InstallSnapshot(snap)) +} + +// Helper to wait for a certain state. +func (rg smGroup) waitOnTotal(t *testing.T, expected int64) { + t.Helper() + checkFor(t, 20*time.Second, 200*time.Millisecond, func() error { + for _, sm := range rg { + asm := sm.(*stateAdder) + if total := asm.total(); total != expected { + return fmt.Errorf("Adder on %v has wrong total: %d vs %d", + asm.server(), total, expected) + } + } + return nil + }) +} + +// Factory function. +func newStateAdder(s *Server, cfg *RaftConfig, n RaftNode) stateMachine { + return &stateAdder{s: s, n: n, cfg: cfg} +}