diff --git a/bitcask.go b/bitcask.go index 5780236..2cc5303 100644 --- a/bitcask.go +++ b/bitcask.go @@ -1,6 +1,7 @@ package bitcask import ( + "bytes" "errors" "fmt" "hash/crc32" @@ -270,6 +271,24 @@ func (b *Bitcask) Scan(prefix []byte, f func(key []byte) error) (err error) { return } +// Range performs a range scan of keys matching a range of keys between the +// start key and end key and calling the function `f` with the keys found. +// If the function returns an error no further keys are processed and the +// first error returned. +func (b *Bitcask) Range(start, end []byte, f func(key []byte) error) (err error) { + b.trie.ForEach(func(node art.Node) bool { + if bytes.Compare(node.Key(), start) >= 0 && bytes.Compare(node.Key(), end) <= 0 { + if err = f(node.Key()); err != nil { + return false + } + return true + } else { + return false + } + }) + return +} + // Len returns the total number of keys in the database func (b *Bitcask) Len() int { b.mu.RLock() diff --git a/bitcask_test.go b/bitcask_test.go index a1b42c7..807c69a 100644 --- a/bitcask_test.go +++ b/bitcask_test.go @@ -1626,6 +1626,61 @@ func TestScan(t *testing.T) { }) } +func TestRange(t *testing.T) { + assert := assert.New(t) + + testdir, err := ioutil.TempDir("", "bitcask") + assert.NoError(err) + + var db *Bitcask + + t.Run("Setup", func(t *testing.T) { + t.Run("Open", func(t *testing.T) { + db, err = Open(testdir) + assert.NoError(err) + }) + + t.Run("Put", func(t *testing.T) { + for i := 1; i < 10; i++ { + key := []byte(fmt.Sprintf("foo_%d", i)) + val := []byte(fmt.Sprintf("%d", i)) + err = db.Put(key, val) + assert.NoError(err) + } + }) + }) + + t.Run("Range", func(t *testing.T) { + var ( + vals [][]byte + expected = [][]byte{ + []byte("3"), + []byte("4"), + []byte("5"), + []byte("6"), + []byte("7"), + } + ) + + err = db.Range([]byte("foo_3"), []byte("foo_7"), func(key []byte) error { + val, err := db.Get(key) + assert.NoError(err) + vals = append(vals, val) + return nil + }) + vals = SortByteArrays(vals) + assert.Equal(expected, vals) + }) + + t.Run("RangeErrors", func(t *testing.T) { + err = db.Range([]byte("foo_3"), []byte("foo_7"), func(key []byte) error { + return ErrMockError + }) + assert.Error(err) + assert.Equal(ErrMockError, err) + }) +} + func TestLocking(t *testing.T) { assert := assert.New(t) diff --git a/cmd/bitcask/range.go b/cmd/bitcask/range.go new file mode 100644 index 0000000..191262b --- /dev/null +++ b/cmd/bitcask/range.go @@ -0,0 +1,61 @@ +package main + +import ( + "fmt" + "os" + + log "github.com/sirupsen/logrus" + "github.com/spf13/cobra" + "github.com/spf13/viper" + + "github.com/prologic/bitcask" +) + +var rangeCmd = &cobra.Command{ + Use: "range ", + Aliases: []string{}, + Short: "Perform a range scan for keys from a start to end key", + Long: `This performa a range scan for keys starting with the given start +key and ending with the end key. This uses a Trie to search for matching keys +within the range and returns all matched keys.`, + Args: cobra.ExactArgs(2), + Run: func(cmd *cobra.Command, args []string) { + path := viper.GetString("path") + + start := args[0] + end := args[1] + + os.Exit(_range(path, start, end)) + }, +} + +func init() { + RootCmd.AddCommand(rangeCmd) +} + +func _range(path, start, end string) int { + db, err := bitcask.Open(path) + if err != nil { + log.WithError(err).Error("error opening database") + return 1 + } + defer db.Close() + + err = db.Range([]byte(start), []byte(end), func(key []byte) error { + value, err := db.Get(key) + if err != nil { + log.WithError(err).Error("error reading key") + return err + } + + fmt.Printf("%s\n", string(value)) + log.WithField("key", key).WithField("value", value).Debug("key/value") + return nil + }) + if err != nil { + log.WithError(err).Error("error rangening keys") + return 1 + } + + return 0 +}