diff --git a/v2/bitcask.go b/v2/bitcask.go index 481f789..43e01e1 100644 --- a/v2/bitcask.go +++ b/v2/bitcask.go @@ -2,6 +2,8 @@ package bitcask import ( "bytes" + "context" + "errors" "fmt" "hash/crc32" "io" @@ -32,6 +34,8 @@ const ( ttlIndexFile = "ttl_index" ) +var ErrContextDeadlineExceeded = errors.New("Context deadline exceeded.") + // Bitcask is a struct that represents a on-disk LSM and WAL data structure // and in-memory hash of key/value pairs as per the Bitcask paper and seen // in the Riak database. @@ -60,7 +64,7 @@ type Stats struct { // Stats returns statistics about the database including the number of // data files, keys and overall size on disk of the data -func (b *Bitcask) Stats() (stats Stats, err error) { +func (b *Bitcask) Stats(ctx context.Context) (stats Stats, err error) { if stats.Size, err = internal.DirSize(b.path); err != nil { return } @@ -118,7 +122,7 @@ func (b *Bitcask) Sync() error { } // Get fetches value for a key -func (b *Bitcask) Get(key []byte) ([]byte, error) { +func (b *Bitcask) Get(ctx context.Context, key []byte) ([]byte, error) { b.mu.RLock() defer b.mu.RUnlock() e, err := b.get(key) @@ -129,7 +133,7 @@ func (b *Bitcask) Get(key []byte) ([]byte, error) { } // Has returns true if the key exists in the database, false otherwise. -func (b *Bitcask) Has(key []byte) bool { +func (b *Bitcask) Has(ctx context.Context, key []byte) bool { b.mu.RLock() defer b.mu.RUnlock() _, found := b.trie.Search(key) @@ -140,7 +144,7 @@ func (b *Bitcask) Has(key []byte) bool { } // Put stores the key and value in the database. -func (b *Bitcask) Put(key, value []byte) error { +func (b *Bitcask) Put(ctx context.Context, key, value []byte) error { if len(key) == 0 { return ErrEmptyKey } @@ -178,7 +182,7 @@ func (b *Bitcask) Put(key, value []byte) error { } // PutWithTTL stores the key and value in the database with the given TTL -func (b *Bitcask) PutWithTTL(key, value []byte, ttl time.Duration) error { +func (b *Bitcask) PutWithTTL(ctx context.Context, key, value []byte, ttl time.Duration) error { if len(key) == 0 { return ErrEmptyKey } @@ -246,25 +250,33 @@ func (b *Bitcask) delete(key []byte) error { // deleted from the database. // If the function returns an error on any key, no further keys are processed, no // keys are deleted, and the first error is returned. -func (b *Bitcask) Sift(f func(key []byte) (bool, error)) (err error) { +func (b *Bitcask) Sift(ctx context.Context, f func(ctx context.Context, key []byte) (bool, error)) (err error) { keysToDelete := art.New() b.mu.RLock() b.trie.ForEach(func(node art.Node) bool { - if b.isExpired(node.Key()) { - keysToDelete.Insert(node.Key(), true) + select { + case <-ctx.Done(): + err = ErrContextDeadlineExceeded + return false + default: + if b.isExpired(node.Key()) { + keysToDelete.Insert(node.Key(), true) + return true + } + var shouldDelete bool + if shouldDelete, err = f(ctx, node.Key()); err != nil { + return false + } else if shouldDelete { + keysToDelete.Insert(node.Key(), true) + } return true } - var shouldDelete bool - if shouldDelete, err = f(node.Key()); err != nil { - return false - } else if shouldDelete { - keysToDelete.Insert(node.Key(), true) - } - return true }) b.mu.RUnlock() - + if err != nil { + return err + } b.mu.Lock() defer b.mu.Unlock() keysToDelete.ForEach(func(node art.Node) (cont bool) { @@ -297,20 +309,26 @@ func (b *Bitcask) DeleteAll() (err error) { // Scan performs a prefix scan of keys matching the given prefix and calling // the function `f` with the keys found. If the function returns an error // no further keys are processed and the first error is returned. -func (b *Bitcask) Scan(prefix []byte, f func(key []byte) error) (err error) { +func (b *Bitcask) Scan(ctx context.Context, prefix []byte, f func(ctx context.Context, key []byte) error) (err error) { b.mu.RLock() defer b.mu.RUnlock() b.trie.ForEachPrefix(prefix, func(node art.Node) bool { - // Skip the root node - if len(node.Key()) == 0 { + select { + case <-ctx.Done(): + err = ErrContextDeadlineExceeded + return false + default: + // Skip the root node + if len(node.Key()) == 0 { + return true + } + + if err = f(ctx, node.Key()); err != nil { + return false + } return true } - - if err = f(node.Key()); err != nil { - return false - } - return true }) return } @@ -320,29 +338,39 @@ func (b *Bitcask) Scan(prefix []byte, f func(key []byte) error) (err error) { // the function returns true, that key is deleted from the database. // If the function returns an error on any key, no further keys are processed, // no keys are deleted, and the first error is returned. -func (b *Bitcask) SiftScan(prefix []byte, f func(key []byte) (bool, error)) (err error) { +func (b *Bitcask) SiftScan(ctx context.Context, prefix []byte, f func(key []byte) (bool, error)) (err error) { keysToDelete := art.New() b.mu.RLock() b.trie.ForEachPrefix(prefix, func(node art.Node) bool { - // Skip the root node - if len(node.Key()) == 0 { - return true - } - if b.isExpired(node.Key()) { - keysToDelete.Insert(node.Key(), true) - return true - } - var shouldDelete bool - if shouldDelete, err = f(node.Key()); err != nil { + select { + case <-ctx.Done(): + err = ErrContextDeadlineExceeded return false - } else if shouldDelete { - keysToDelete.Insert(node.Key(), true) + default: + // Skip the root node + if len(node.Key()) == 0 { + return true + } + if b.isExpired(node.Key()) { + keysToDelete.Insert(node.Key(), true) + return true + } + var shouldDelete bool + if shouldDelete, err = f(node.Key()); err != nil { + return false + } else if shouldDelete { + keysToDelete.Insert(node.Key(), true) + } + return true } - return true }) b.mu.RUnlock() + if err != nil { + return + } + b.mu.Lock() defer b.mu.Unlock() keysToDelete.ForEach(func(node art.Node) (cont bool) { @@ -356,7 +384,7 @@ func (b *Bitcask) SiftScan(prefix []byte, f func(key []byte) (bool, error)) (err // start key and end key and calling the function `f` with the keys found. // If the function returns an error no further keys are processed and the // first error returned. -func (b *Bitcask) Range(start, end []byte, f func(key []byte) error) (err error) { +func (b *Bitcask) Range(ctx context.Context, start, end []byte, f func(ctx context.Context, key []byte) error) (err error) { if bytes.Compare(start, end) == 1 { return ErrInvalidRange } @@ -370,15 +398,21 @@ func (b *Bitcask) Range(start, end []byte, f func(key []byte) error) (err error) defer b.mu.RUnlock() b.trie.ForEachPrefix(commonPrefix, func(node art.Node) bool { - if bytes.Compare(node.Key(), start) >= 0 && bytes.Compare(node.Key(), end) <= 0 { - if err = f(node.Key()); err != nil { + select { + case <-ctx.Done(): + err = ErrContextDeadlineExceeded + return false + default: + if bytes.Compare(node.Key(), start) >= 0 && bytes.Compare(node.Key(), end) <= 0 { + if err = f(ctx, node.Key()); err != nil { + return false + } + return true + } else if bytes.Compare(node.Key(), start) >= 0 && bytes.Compare(node.Key(), end) > 0 { return false } return true - } else if bytes.Compare(node.Key(), start) >= 0 && bytes.Compare(node.Key(), end) > 0 { - return false } - return true }) return } @@ -389,7 +423,7 @@ func (b *Bitcask) Range(start, end []byte, f func(key []byte) error) (err error) // from the database. // If the function returns an error on any key, no further keys are processed, no // keys are deleted, and the first error is returned. -func (b *Bitcask) SiftRange(start, end []byte, f func(key []byte) (bool, error)) (err error) { +func (b *Bitcask) SiftRange(ctx context.Context, start, end []byte, f func(ctx context.Context, key []byte) (bool, error)) (err error) { if bytes.Compare(start, end) == 1 { return ErrInvalidRange } @@ -403,25 +437,35 @@ func (b *Bitcask) SiftRange(start, end []byte, f func(key []byte) (bool, error)) b.mu.RLock() b.trie.ForEachPrefix(commonPrefix, func(node art.Node) bool { - if bytes.Compare(node.Key(), start) >= 0 && bytes.Compare(node.Key(), end) <= 0 { - if b.isExpired(node.Key()) { - keysToDelete.Insert(node.Key(), true) + select { + case <-ctx.Done(): + err = ErrContextDeadlineExceeded + return false + default: + if bytes.Compare(node.Key(), start) >= 0 && bytes.Compare(node.Key(), end) <= 0 { + if b.isExpired(node.Key()) { + keysToDelete.Insert(node.Key(), true) + return true + } + var shouldDelete bool + if shouldDelete, err = f(ctx, node.Key()); err != nil { + return false + } else if shouldDelete { + keysToDelete.Insert(node.Key(), true) + } return true - } - var shouldDelete bool - if shouldDelete, err = f(node.Key()); err != nil { + } else if bytes.Compare(node.Key(), start) >= 0 && bytes.Compare(node.Key(), end) > 0 { return false - } else if shouldDelete { - keysToDelete.Insert(node.Key(), true) } return true - } else if bytes.Compare(node.Key(), start) >= 0 && bytes.Compare(node.Key(), end) > 0 { - return false } - return true }) b.mu.RUnlock() + if err != nil { + return + } + b.mu.Lock() defer b.mu.Unlock() @@ -434,7 +478,7 @@ func (b *Bitcask) SiftRange(start, end []byte, f func(key []byte) (bool, error)) } // Len returns the total number of keys in the database -func (b *Bitcask) Len() int { +func (b *Bitcask) Len(ctx context.Context) int { b.mu.RLock() defer b.mu.RUnlock() return b.trie.Size() @@ -493,15 +537,21 @@ func (b *Bitcask) runGC() (err error) { // Fold iterates over all keys in the database calling the function `f` for // each key. If the function returns an error, no further keys are processed // and the error is returned. -func (b *Bitcask) Fold(f func(key []byte) error) (err error) { +func (b *Bitcask) Fold(ctx context.Context, f func(ctx context.Context, key []byte) error) (err error) { b.mu.RLock() defer b.mu.RUnlock() b.trie.ForEach(func(node art.Node) bool { - if err = f(node.Key()); err != nil { + select { + case <-ctx.Done(): + err = ErrContextDeadlineExceeded return false + default: + if err = f(ctx, node.Key()); err != nil { + return false + } + return true } - return true }) return @@ -731,7 +781,7 @@ func (b *Bitcask) Merge() error { // Rewrite all key/value pairs into merged database // Doing this automatically strips deleted keys and // old key/value pairs - err = b.Fold(func(key []byte) error { + err = b.Fold(context.Background(), func(ctx context.Context, key []byte) error { item, _ := b.trie.Search(key) // if key was updated after start of merge operation, nothing to do if item.(internal.Item).FileID > filesToMerge[len(filesToMerge)-1] { @@ -743,11 +793,11 @@ func (b *Bitcask) Merge() error { } if e.Expiry != nil { - if err := mdb.PutWithTTL(key, e.Value, time.Until(*e.Expiry)); err != nil { + if err := mdb.PutWithTTL(ctx, key, e.Value, time.Until(*e.Expiry)); err != nil { return err } } else { - if err := mdb.Put(key, e.Value); err != nil { + if err := mdb.Put(ctx, key, e.Value); err != nil { return err } }