Files
nats-server/server/avl/seqset.go
Derek Collison a778921b8c Fixed a bug where, after sequences were deleted and empty nodes were cleaned up, we would not recompute heights and balances.
This could cause a rotate operation to return nil and replace a non-empty root with nil.

Signed-off-by: Derek Collison <derek@nats.io>
2023-07-30 11:01:32 -07:00

678 lines
14 KiB
Go

// Copyright 2023 The NATS Authors
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package avl
import (
"encoding/binary"
"errors"
"math/bits"
"sort"
)
// SequenceSet is a memory and encoding optimized set for storing unsigned ints.
//
// SequenceSet is ~80-100 times more efficient memory wise than a map[uint64]struct{}.
// SequenceSet is ~1.75 times slower at inserts than the same map.
// SequenceSet is not thread safe.
//
// We use an AVL tree with nodes that hold bitmasks for set membership.
// Each node covers numEntries consecutive sequences starting at its base.
//
// Encoding will convert to a space optimized encoding using bitmasks.
type SequenceSet struct {
root *node // root node
size int // number of items
nodes int // number of nodes
// Having this here vs on the stack in Insert/Delete
// makes a difference in memory usage.
// It is scratch state: the node insert/delete helpers set it through a
// pointer when the set was actually modified, and Insert/Delete reset it
// to false before returning.
changed bool
}
// Insert adds seq to the set, rebalancing the tree inline as needed.
// Inserting a sequence that is already present leaves the size unchanged.
func (ss *SequenceSet) Insert(seq uint64) {
	ss.root = ss.root.insert(seq, &ss.changed, &ss.nodes)
	if !ss.changed {
		// Already a member, nothing to account for.
		return
	}
	// Reset the scratch flag so the next operation starts clean.
	ss.changed = false
	ss.size++
}
// Exists will return true iff the sequence is a member of this set.
//
// The tree is walked iteratively from the root: sequences below a node's
// base are looked up on the left, sequences past the node's range on the
// right, and anything in between is answered by that node's bitmask.
func (ss *SequenceSet) Exists(seq uint64) bool {
	for n := ss.root; n != nil; {
		switch {
		case seq < n.base:
			n = n.l
		case seq-n.base >= numEntries:
			// Compare the offset from base instead of computing
			// n.base+numEntries, which wraps around to a small value for
			// the node covering the top of the uint64 range and would send
			// valid members of that node down the right branch.
			n = n.r
		default:
			return n.exists(seq)
		}
	}
	return false
}
// SetInitialMin should be used to set the initial minimum sequence when known.
// It seeds the set with a single node whose base is min, which utilizes
// space more effectively than letting the first insert pick the base.
// Returns ErrSetNotEmpty if the set already has a root node.
func (ss *SequenceSet) SetInitialMin(min uint64) error {
	if ss.IsEmpty() {
		ss.root = &node{base: min, h: 1}
		ss.nodes = 1
		return nil
	}
	return ErrSetNotEmpty
}
// Delete will remove the sequence from the set.
// Nodes that become empty are removed and the tree is rebalanced.
// Returns whether the sequence was set.
// A nil or empty set always returns false.
func (ss *SequenceSet) Delete(seq uint64) bool {
	if ss == nil || ss.root == nil {
		return false
	}
	ss.root = ss.root.delete(seq, &ss.changed, &ss.nodes)
	removed := ss.changed
	if removed {
		// Reset the scratch flag for the next operation.
		ss.changed = false
		ss.size--
		if ss.size == 0 {
			// Last item gone, drop whatever is left of the tree.
			ss.Empty()
		}
	}
	return removed
}
// Size returns the number of items in the set.
// The receiver must be non-nil.
func (ss *SequenceSet) Size() int {
return ss.size
}
// Nodes returns the number of nodes in the tree.
// The receiver must be non-nil.
func (ss *SequenceSet) Nodes() int {
return ss.nodes
}
// Empty will clear all items from a set by dropping the tree
// and resetting the item and node counters.
func (ss *SequenceSet) Empty() {
	ss.root, ss.size, ss.nodes = nil, 0, 0
}
// IsEmpty is a fast check of the set being empty.
// A nil set, or one without a root node, is considered empty.
func (ss *SequenceSet) IsEmpty() bool {
	return ss == nil || ss.root == nil
}
// Range will invoke the given function for each item in the set.
// Items are visited in ascending order.
// If the callback returns false we terminate the iteration.
// A set without a root node results in no calls to f.
func (ss *SequenceSet) Range(f func(uint64) bool) {
ss.root.iter(f)
}
// Heights returns the heights of the root's left and right subtrees.
// A missing subtree, or a set without a root, reports a height of 0.
func (ss *SequenceSet) Heights() (l, r int) {
	top := ss.root
	if top == nil {
		return 0, 0
	}
	if left := top.l; left != nil {
		l = left.h
	}
	if right := top.r; right != nil {
		r = right.h
	}
	return l, r
}
// State returns the min, max and number of set items.
// A nil set, or one without a root node, reports all zeros.
func (ss *SequenceSet) State() (min, max, num uint64) {
	if ss != nil && ss.root != nil {
		min, max = ss.MinMax()
		num = uint64(ss.Size())
	}
	return min, max, num
}
// MinMax will return the minimum and maximum values in the set.
// The minimum comes from the leftmost node and the maximum from the
// rightmost node. A set without a root node returns 0, 0.
func (ss *SequenceSet) MinMax() (min, max uint64) {
	if ss.root == nil {
		return 0, 0
	}
	leftmost := ss.root
	for leftmost.l != nil {
		leftmost = leftmost.l
	}
	rightmost := ss.root
	for rightmost.r != nil {
		rightmost = rightmost.r
	}
	return leftmost.min(), rightmost.max()
}
// clone deep copies the subtree rooted at src and stores the copy in *target.
// A nil src leaves *target untouched.
func clone(src *node, target **node) {
	if src != nil {
		cp := &node{base: src.base, bits: src.bits, h: src.h}
		clone(src.l, &cp.l)
		clone(src.r, &cp.r)
		*target = cp
	}
}
// Clone will return a deep copy of the given SequenceSet.
// Cloning a nil set returns nil.
func (ss *SequenceSet) Clone() *SequenceSet {
	if ss == nil {
		return nil
	}
	var css SequenceSet
	css.size, css.nodes = ss.size, ss.nodes
	clone(ss.root, &css.root)
	return &css
}
// Union will union this SequenceSet with ssa, inserting every member
// of each given set into the receiver.
func (ss *SequenceSet) Union(ssa ...*SequenceSet) {
	// addNode inserts every sequence whose bit is set in n.
	addNode := func(n *node) {
		for bucket, word := range n.bits {
			start := n.base + uint64(bucket)*uint64(bitsPerBucket)
			// Shift the word down one bit at a time until no bits remain.
			for off := uint64(0); word != 0; off, word = off+1, word>>1 {
				if word&1 != 0 {
					ss.Insert(start + off)
				}
			}
		}
	}
	for _, other := range ssa {
		other.root.nodeIter(addNode)
	}
}
// Union will return a union of all sets as a new SequenceSet.
// Returns nil when no sets are given.
// Note that the ssa slice is sorted in place, largest set first.
func Union(ssa ...*SequenceSet) *SequenceSet {
	if len(ssa) == 0 {
		return nil
	}
	// Order by descending size so the largest set is the one we clone.
	sort.Slice(ssa, func(a, b int) bool { return ssa[b].Size() < ssa[a].Size() })
	merged := ssa[0].Clone()

	// Everything else is inserted item by item through Range.
	add := func(seq uint64) bool {
		merged.Insert(seq)
		return true
	}
	for _, other := range ssa[1:] {
		other.Range(add)
	}
	return merged
}
const (
// magic is the first byte of an encoded SequenceSet and is used to
// identify the encoded binary state.
magic = uint8(22)
// version is the encoding version, written as the second byte.
version = uint8(2)
// hdrLen is the length in bytes of the magic + version header.
hdrLen = 2
// minimum length of an encoded SequenceSet.
minLen = 2 + 8 // magic + version + num nodes + num entries.
)
// EncodeLen returns the bytes needed for encoding.
func (ss SequenceSet) EncodeLen() int {
	// Per node: 8 bytes of base, 8 bytes per bitmask bucket and 2 bytes of height.
	const perNode = (numBuckets+1)*8 + 2
	return minLen + ss.Nodes()*perNode
}
// Encode writes the set into buf and returns the encoded bytes.
// If buf does not have enough capacity a new buffer is allocated.
//
// The layout is little endian: magic and version bytes, the number of nodes
// and the number of items as uint32s, then for every node (in nodeIter
// order) its base, each bitmask bucket and its height as a uint16.
// The returned error is always nil.
func (ss SequenceSet) Encode(buf []byte) ([]byte, error) {
	size := ss.EncodeLen()
	if cap(buf) >= size {
		buf = buf[:size]
	} else {
		buf = make([]byte, size)
	}

	// TODO(dlc) - Go 1.19 introduced Append to not have to keep track.
	// Once 1.20 is out we could change this over.
	// Also binary.Write() is way slower, do not use.
	le := binary.LittleEndian

	buf[0] = magic
	buf[1] = version
	le.PutUint32(buf[hdrLen:], uint32(ss.Nodes()))
	le.PutUint32(buf[hdrLen+4:], uint32(ss.size))

	// off tracks the write position and is advanced by the closure below.
	off := hdrLen + 8
	ss.root.nodeIter(func(n *node) {
		le.PutUint64(buf[off:], n.base)
		off += 8
		for _, word := range n.bits {
			le.PutUint64(buf[off:], word)
			off += 8
		}
		le.PutUint16(buf[off:], uint16(n.h))
		off += 2
	})
	return buf[:off], nil
}
// Errors returned by this package.
var (
// ErrBadEncoding is returned when we can not decode properly.
ErrBadEncoding = errors.New("ss: bad encoding")
// ErrBadVersion is returned by Decode when the encoded version is not recognized.
ErrBadVersion = errors.New("ss: bad version")
// ErrSetNotEmpty is returned by SetInitialMin when the set is not empty.
ErrSetNotEmpty = errors.New("ss: set not empty")
)
// Decode returns the sequence set and number of bytes read from the buffer on success.
// On failure it returns a nil set, -1 and either ErrBadEncoding (buffer too
// short or wrong magic byte) or ErrBadVersion (unknown version byte).
func Decode(buf []byte) (*SequenceSet, int, error) {
	if len(buf) < minLen || buf[0] != magic {
		return nil, -1, ErrBadEncoding
	}
	// Dispatch on the version byte.
	v := buf[1]
	if v == 1 {
		return decodev1(buf)
	}
	if v == 2 {
		return decodev2(buf)
	}
	return nil, -1, ErrBadVersion
}
// Helper to decode v2.
// buf must already be validated by Decode to hold at least minLen bytes.
// Nodes are read in encoded order and linked with insertNode, trusting the
// encoded heights instead of recomputing them.
func decodev2(buf []byte) (*SequenceSet, int, error) {
	le := binary.LittleEndian

	// Header is followed by the node count and the item count.
	numNodes := int(le.Uint32(buf[2:]))
	numItems := int(le.Uint32(buf[6:]))

	if want := minLen + numNodes*((numBuckets+1)*8+2); len(buf) < want {
		return nil, -1, ErrBadEncoding
	}

	ss := &SequenceSet{size: numItems}
	// All nodes share one backing allocation.
	nodes := make([]node, numNodes)

	off := 10
	for i := range nodes {
		n := &nodes[i]
		n.base = le.Uint64(buf[off:])
		off += 8
		for b := range n.bits {
			n.bits[b] = le.Uint64(buf[off:])
			off += 8
		}
		n.h = int(le.Uint16(buf[off:]))
		off += 2
		ss.insertNode(n)
	}
	return ss, off, nil
}
// Helper to decode a v1 encoding into the current in-memory layout.
// v1 nodes held 64 buckets each versus numBuckets now, so nodes can not be
// copied over directly; every set bit is re-inserted as a sequence instead.
// buf must already be validated by Decode to hold at least minLen bytes.
func decodev1(buf []byte) (*SequenceSet, int, error) {
	const v1NumBuckets = 64

	le := binary.LittleEndian
	numNodes := int(le.Uint32(buf[2:]))
	wantSize := int(le.Uint32(buf[6:]))

	if len(buf) < minLen+numNodes*((v1NumBuckets+1)*8+2) {
		return nil, -1, ErrBadEncoding
	}

	ss := new(SequenceSet)
	off := 10
	for i := 0; i < numNodes; i++ {
		base := le.Uint64(buf[off:])
		off += 8
		for bucket := uint64(0); bucket < v1NumBuckets; bucket++ {
			start := base + bucket*uint64(bitsPerBucket)
			// Walk all set bits and insert sequences manually for this decode from v1.
			for word, pos := le.Uint64(buf[off:]), uint64(0); word != 0; word, pos = word>>1, pos+1 {
				if word&1 == 1 {
					ss.Insert(start + pos)
				}
			}
			off += 8
		}
		// Skip over encoded height.
		off += 2
	}

	// Sanity check against the encoded item count.
	if ss.Size() != wantSize {
		return nil, -1, ErrBadEncoding
	}
	return ss, off, nil
}
// insertNode places a decoded node into the tree.
// These should be done in tree order as defined by Encode()
// This allows us to not have to calculate height or do rebalancing.
// So much better performance this way.
func (ss *SequenceSet) insertNode(n *node) {
	ss.nodes++
	// Follow the child links down until we hit the empty slot for n.
	// Nodes with a smaller base go left, everything else goes right.
	link := &ss.root
	for *link != nil {
		if p := *link; n.base < p.base {
			link = &p.l
		} else {
			link = &p.r
		}
	}
	*link = n
}
// Sizing of the per node bitmask.
const (
bitsPerBucket = 64 // bits in uint64
// numBuckets is the number of uint64 buckets held by each node.
numBuckets = 32
// numEntries is the number of sequences a single node can track.
numEntries = numBuckets * bitsPerBucket
)
// node is a single AVL tree node. It tracks membership for the numEntries
// consecutive sequences starting at base, one bit per sequence.
type node struct {
// base is the first sequence covered by this node.
base uint64
// bits holds the membership bitmask, indexed by offset from base.
bits [numBuckets]uint64
// l and r are the left and right children.
l *node
r *node
// h is the height of this node, freshly created leaf nodes start at 1.
h int
}
// Set the proper bit.
// seq should have already been qualified and inserted should be non nil.
// inserted is only written (set to true) when the bit was not set before.
func (n *node) set(seq uint64, inserted *bool) {
	off := seq - n.base
	word := &n.bits[off/bitsPerBucket]
	bit := uint64(1) << (off % bitsPerBucket)
	if *word&bit != 0 {
		// Already a member.
		return
	}
	*word |= bit
	*inserted = true
}
// insert adds seq to the subtree rooted at n and returns the new subtree
// root after any rebalancing.
//
// A nil receiver creates a new leaf whose base is seq rounded down to a
// multiple of numEntries, and increments *nodes. *inserted is set to true
// only when seq was not already a member.
func (n *node) insert(seq uint64, inserted *bool, nodes *int) *node {
	if n == nil {
		base := (seq / numEntries) * numEntries
		n := &node{base: base, h: 1}
		n.set(seq, inserted)
		*nodes++
		return n
	}
	switch {
	case seq < n.base:
		n.l = n.l.insert(seq, inserted, nodes)
	case seq-n.base >= numEntries:
		// Compare the offset from base instead of computing
		// n.base+numEntries, which wraps around for the node covering the
		// top of the uint64 range and would create a duplicate node on the
		// right for sequences this node already covers.
		n.r = n.r.insert(seq, inserted, nodes)
	default:
		n.set(seq, inserted)
	}

	n.h = maxH(n) + 1

	// Don't make a function, impacts performance.
	if bf := balanceF(n); bf > 1 {
		// Left unbalanced.
		if balanceF(n.l) < 0 {
			n.l = n.l.rotateL()
		}
		return n.rotateR()
	} else if bf < -1 {
		// Right unbalanced.
		if balanceF(n.r) > 0 {
			n.r = n.r.rotateR()
		}
		return n.rotateL()
	}
	return n
}
// rotateL performs a left rotation around n and returns the new subtree
// root: n's right child is promoted, n becomes its left child and the
// promoted node's former left subtree becomes n's right subtree.
// Heights of both nodes are recomputed.
//
// If n has no right child there is nothing to rotate; n is returned
// unchanged (apart from a height refresh) rather than nil, so a caller
// assigning the result can never drop the subtree.
func (n *node) rotateL() *node {
	r := n.r
	if r == nil {
		n.h = maxH(n) + 1
		return n
	}
	n.r = r.l
	r.l = n
	// n is now below r, so its height must be refreshed first.
	n.h = maxH(n) + 1
	r.h = maxH(r) + 1
	return r
}
// rotateR performs a right rotation around n and returns the new subtree
// root: n's left child is promoted, n becomes its right child and the
// promoted node's former right subtree becomes n's left subtree.
// Heights of both nodes are recomputed.
//
// If n has no left child there is nothing to rotate; n is returned
// unchanged (apart from a height refresh) rather than nil, so a caller
// assigning the result can never drop the subtree.
func (n *node) rotateR() *node {
	l := n.l
	if l == nil {
		n.h = maxH(n) + 1
		return n
	}
	n.l = l.r
	l.r = n
	// n is now below l, so its height must be refreshed first.
	n.h = maxH(n) + 1
	l.h = maxH(l) + 1
	return l
}
// balanceF returns the balance factor of n: the height of its left child
// minus the height of its right child. Missing children count as height 0
// and a nil node has a balance factor of 0.
func balanceF(n *node) int {
	if n == nil {
		return 0
	}
	bf := 0
	if l := n.l; l != nil {
		bf += l.h
	}
	if r := n.r; r != nil {
		bf -= r.h
	}
	return bf
}
// maxH returns the larger of the heights of n's two children.
// Missing children count as height 0 and a nil node returns 0.
func maxH(n *node) int {
	if n == nil {
		return 0
	}
	lh, rh := 0, 0
	if l := n.l; l != nil {
		lh = l.h
	}
	if r := n.r; r != nil {
		rh = r.h
	}
	if rh >= lh {
		return rh
	}
	return lh
}
// Clear the proper bit.
// seq should have already been qualified and deleted should be non nil.
// deleted is only written (set to true) when the bit was set before.
// Will return true if this node is now empty.
func (n *node) clear(seq uint64, deleted *bool) bool {
	off := seq - n.base
	idx, bit := off/bitsPerBucket, uint64(1)<<(off%bitsPerBucket)
	if n.bits[idx]&bit != 0 {
		n.bits[idx] &^= bit
		*deleted = true
	}
	// The node is empty iff no bucket has any bit left.
	var remaining uint64
	for i := range n.bits {
		remaining |= n.bits[i]
	}
	return remaining == 0
}
// delete removes seq from the subtree rooted at n and returns the new
// subtree root after any node removal and rebalancing.
//
// *deleted is set to true only when seq was a member. When clearing the bit
// leaves a node empty the node is unlinked and *nodes is decremented: a
// node with a single child is replaced by that child, and for a node with
// two children the left subtree is re-attached as the smallest element of
// the right subtree, which then takes the node's place.
func (n *node) delete(seq uint64, deleted *bool, nodes *int) *node {
	if n == nil {
		return nil
	}
	switch {
	case seq < n.base:
		n.l = n.l.delete(seq, deleted, nodes)
	case seq-n.base >= numEntries:
		// Compare the offset from base instead of computing
		// n.base+numEntries, which wraps around for the node covering the
		// top of the uint64 range and would make its members undeletable.
		n.r = n.r.delete(seq, deleted, nodes)
	default:
		if empty := n.clear(seq, deleted); empty {
			*nodes--
			if n.l == nil {
				n = n.r
			} else if n.r == nil {
				n = n.l
			} else {
				// We have both children.
				n.r = n.r.insertNodePrev(n.l)
				n = n.r
			}
		}
	}

	// Heights must be refreshed before checking balance, children may have
	// been removed or replaced above.
	if n != nil {
		n.h = maxH(n) + 1
	}

	// Check balance.
	if bf := balanceF(n); bf > 1 {
		// Left unbalanced.
		if balanceF(n.l) < 0 {
			n.l = n.l.rotateL()
		}
		return n.rotateR()
	} else if bf < -1 {
		// right unbalanced.
		if balanceF(n.r) > 0 {
			n.r = n.r.rotateR()
		}
		return n.rotateL()
	}
	return n
}
// Will insert nn into the node assuming it is less than all other nodes in n.
// nn is attached at the leftmost position of the subtree.
// Will re-calculate height and balance on the way back up and return the
// new subtree root.
func (n *node) insertNodePrev(nn *node) *node {
	if n.l != nil {
		n.l = n.l.insertNodePrev(nn)
	} else {
		n.l = nn
	}
	n.h = maxH(n) + 1

	// Check balance.
	switch bf := balanceF(n); {
	case bf > 1:
		// Left unbalanced.
		if balanceF(n.l) < 0 {
			n.l = n.l.rotateL()
		}
		return n.rotateR()
	case bf < -1:
		// Right unbalanced.
		if balanceF(n.r) > 0 {
			n.r = n.r.rotateR()
		}
		return n.rotateL()
	}
	return n
}
// exists reports whether the bit for seq is set in this node.
// seq should have already been qualified to fall within this node's range.
func (n *node) exists(seq uint64) bool {
	off := seq - n.base
	word := n.bits[off/bitsPerBucket]
	return word&(uint64(1)<<(off%bitsPerBucket)) != 0
}
// Return minimum sequence in this node.
// This node can not be empty; an empty node yields 0.
func (n *node) min() uint64 {
	for i := 0; i < numBuckets; i++ {
		b := n.bits[i]
		if b == 0 {
			continue
		}
		// Lowest set bit of the first non-empty bucket.
		return n.base + uint64(i*bitsPerBucket) + uint64(bits.TrailingZeros64(b))
	}
	return 0
}
// Return maximum sequence in this node.
// This node can not be empty; an empty node yields 0.
func (n *node) max() uint64 {
	for i := numBuckets - 1; i >= 0; i-- {
		b := n.bits[i]
		if b == 0 {
			continue
		}
		// Highest set bit of the last non-empty bucket. For a non-zero
		// bucket bits.Len64(b) is one past the index of that bit.
		return n.base + uint64(i*bitsPerBucket) + uint64(bits.Len64(b)-1)
	}
	return 0
}
// nodeIter calls f for every node of the subtree rooted at n.
// This is done in tree order: the node itself first, then its left
// subtree, then its right subtree. A nil receiver is a no-op.
func (n *node) nodeIter(f func(n *node)) {
	if n != nil {
		f(n)
		n.l.nodeIter(f)
		n.r.nodeIter(f)
	}
}
// iter will iterate through the set's items in the subtree rooted at this
// node in ascending order: left subtree, this node's items, right subtree.
// If the supplied function returns false we terminate the iteration.
// Returns false if the iteration was terminated early, true otherwise.
// A nil receiver returns true without calling f.
func (n *node) iter(f func(uint64) bool) bool {
	if n == nil {
		return true
	}
	if !n.l.iter(f) {
		return false
	}
	// Loop over offsets rather than up to n.base+numEntries, which wraps
	// around for the node covering the top of the uint64 range and would
	// skip all of that node's items.
	for off := uint64(0); off < numEntries; off++ {
		if num := n.base + off; n.exists(num) && !f(num) {
			return false
		}
	}
	return n.r.iter(f)
}