1
0
mirror of https://github.com/taigrr/bitcask synced 2025-01-18 04:03:17 -08:00

codec_index: check sizes, new tests for data corruption & refactor (#84)

* bitcask/codec_index: check key and data sizes

* codec_index: tests for key and data size overflows

* codec_index: simplify internal funcs for unused returns
This commit is contained in:
Ignacio Hagopian 2019-09-03 23:26:26 -03:00 committed by James Mills
parent 24ab3fbf27
commit 93cc1d409f
3 changed files with 86 additions and 40 deletions

View File

@ -351,7 +351,7 @@ func (b *Bitcask) reopen() error {
} }
defer f.Close() defer f.Close()
if err := internal.ReadIndex(f, t); err != nil { if err := internal.ReadIndex(f, t, b.config.maxKeySize, b.config.maxValueSize); err != nil {
return err return err
} }
} else { } else {

View File

@ -12,6 +12,8 @@ var (
errTruncatedKeySize = errors.New("key size is truncated") errTruncatedKeySize = errors.New("key size is truncated")
errTruncatedKeyData = errors.New("key data is truncated") errTruncatedKeyData = errors.New("key data is truncated")
errTruncatedData = errors.New("data is truncated") errTruncatedData = errors.New("data is truncated")
errKeySizeTooLarge = errors.New("key size too large")
errDataSizeTooLarge = errors.New("data size too large")
) )
const ( const (
@ -22,7 +24,7 @@ const (
sizeSize = int64Size sizeSize = int64Size
) )
func readKeyBytes(r io.Reader) ([]byte, error) { func readKeyBytes(r io.Reader, maxKeySize int) ([]byte, error) {
s := make([]byte, int32Size) s := make([]byte, int32Size)
_, err := io.ReadFull(r, s) _, err := io.ReadFull(r, s)
if err != nil { if err != nil {
@ -32,6 +34,10 @@ func readKeyBytes(r io.Reader) ([]byte, error) {
return nil, errors.Wrap(errTruncatedKeySize, err.Error()) return nil, errors.Wrap(errTruncatedKeySize, err.Error())
} }
size := binary.BigEndian.Uint32(s) size := binary.BigEndian.Uint32(s)
if size > uint32(maxKeySize) {
return nil, errKeySizeTooLarge
}
b := make([]byte, size) b := make([]byte, size)
_, err = io.ReadFull(r, b) _, err = io.ReadFull(r, b)
if err != nil { if err != nil {
@ -40,50 +46,54 @@ func readKeyBytes(r io.Reader) ([]byte, error) {
return b, nil return b, nil
} }
func writeBytes(b []byte, w io.Writer) (int, error) { func writeBytes(b []byte, w io.Writer) error {
s := make([]byte, int32Size) s := make([]byte, int32Size)
binary.BigEndian.PutUint32(s, uint32(len(b))) binary.BigEndian.PutUint32(s, uint32(len(b)))
n, err := w.Write(s) _, err := w.Write(s)
if err != nil { if err != nil {
return n, err return err
} }
m, err := w.Write(b) _, err = w.Write(b)
if err != nil { if err != nil {
return (n + m), err return err
} }
return (n + m), nil return nil
} }
func readItem(r io.Reader) (Item, error) { func readItem(r io.Reader, maxValueSize int) (Item, error) {
buf := make([]byte, (fileIDSize + offsetSize + sizeSize)) buf := make([]byte, (fileIDSize + offsetSize + sizeSize))
_, err := io.ReadFull(r, buf) _, err := io.ReadFull(r, buf)
if err != nil { if err != nil {
return Item{}, errors.Wrap(errTruncatedData, err.Error()) return Item{}, errors.Wrap(errTruncatedData, err.Error())
} }
size := int64(binary.BigEndian.Uint64(buf[(fileIDSize + offsetSize):]))
if size > int64(maxValueSize) {
return Item{}, errDataSizeTooLarge
}
return Item{ return Item{
FileID: int(binary.BigEndian.Uint32(buf[:fileIDSize])), FileID: int(binary.BigEndian.Uint32(buf[:fileIDSize])),
Offset: int64(binary.BigEndian.Uint64(buf[fileIDSize:(fileIDSize + offsetSize)])), Offset: int64(binary.BigEndian.Uint64(buf[fileIDSize:(fileIDSize + offsetSize)])),
Size: int64(binary.BigEndian.Uint64(buf[(fileIDSize + offsetSize):])), Size: size,
}, nil }, nil
} }
func writeItem(item Item, w io.Writer) (int, error) { func writeItem(item Item, w io.Writer) error {
buf := make([]byte, (fileIDSize + offsetSize + sizeSize)) buf := make([]byte, (fileIDSize + offsetSize + sizeSize))
binary.BigEndian.PutUint32(buf[:fileIDSize], uint32(item.FileID)) binary.BigEndian.PutUint32(buf[:fileIDSize], uint32(item.FileID))
binary.BigEndian.PutUint64(buf[fileIDSize:(fileIDSize+offsetSize)], uint64(item.Offset)) binary.BigEndian.PutUint64(buf[fileIDSize:(fileIDSize+offsetSize)], uint64(item.Offset))
binary.BigEndian.PutUint64(buf[(fileIDSize+offsetSize):], uint64(item.Size)) binary.BigEndian.PutUint64(buf[(fileIDSize+offsetSize):], uint64(item.Size))
n, err := w.Write(buf) _, err := w.Write(buf)
if err != nil { if err != nil {
return 0, err return err
} }
return n, nil return nil
} }
// ReadIndex reads a persisted tree from a io.Reader into a Tree // ReadIndex reads a persisted from a io.Reader into a Tree
func ReadIndex(r io.Reader, t art.Tree) error { func ReadIndex(r io.Reader, t art.Tree, maxKeySize, maxValueSize int) error {
for { for {
key, err := readKeyBytes(r) key, err := readKeyBytes(r, maxKeySize)
if err != nil { if err != nil {
if err == io.EOF { if err == io.EOF {
break break
@ -91,7 +101,7 @@ func ReadIndex(r io.Reader, t art.Tree) error {
return err return err
} }
item, err := readItem(r) item, err := readItem(r, maxValueSize)
if err != nil { if err != nil {
return err return err
} }
@ -105,13 +115,13 @@ func ReadIndex(r io.Reader, t art.Tree) error {
// WriteIndex persists a Tree into a io.Writer // WriteIndex persists a Tree into a io.Writer
func WriteIndex(t art.Tree, w io.Writer) (err error) { func WriteIndex(t art.Tree, w io.Writer) (err error) {
t.ForEach(func(node art.Node) bool { t.ForEach(func(node art.Node) bool {
_, err = writeBytes(node.Key(), w) err = writeBytes(node.Key(), w)
if err != nil { if err != nil {
return false return false
} }
item := node.Value().(Item) item := node.Value().(Item)
_, err := writeItem(item, w) err := writeItem(item, w)
if err != nil { if err != nil {
return false return false
} }

View File

@ -3,6 +3,7 @@ package internal
import ( import (
"bytes" "bytes"
"encoding/base64" "encoding/base64"
"encoding/binary"
"testing" "testing"
"github.com/pkg/errors" "github.com/pkg/errors"
@ -35,7 +36,7 @@ func TestReadIndex(t *testing.T) {
b := bytes.NewBuffer(sampleTreeBytes) b := bytes.NewBuffer(sampleTreeBytes)
at := art.New() at := art.New()
err := ReadIndex(b, at) err := ReadIndex(b, at, 1024, 1024)
if err != nil { if err != nil {
t.Fatalf("error while deserializing correct sample tree: %v", err) t.Fatalf("error while deserializing correct sample tree: %v", err)
} }
@ -55,27 +56,62 @@ func TestReadIndex(t *testing.T) {
func TestReadCorruptedData(t *testing.T) { func TestReadCorruptedData(t *testing.T) {
sampleBytes, _ := base64.StdEncoding.DecodeString(base64SampleTree) sampleBytes, _ := base64.StdEncoding.DecodeString(base64SampleTree)
table := []struct {
name string
err error
data []byte
}{
{name: "truncated-key-size-first-item", err: errTruncatedKeySize, data: sampleBytes[:2]},
{name: "truncated-key-data-second-item", err: errTruncatedKeyData, data: sampleBytes[:6]},
{name: "truncated-key-size-second-item", err: errTruncatedKeySize, data: sampleBytes[:(int32Size+4+fileIDSize+offsetSize+sizeSize)+2]},
{name: "truncated-key-data-second-item", err: errTruncatedKeyData, data: sampleBytes[:(int32Size+4+fileIDSize+offsetSize+sizeSize)+6]},
{name: "truncated-data", err: errTruncatedData, data: sampleBytes[:int32Size+4+(fileIDSize+offsetSize+sizeSize-3)]},
}
for i := range table { t.Run("truncated", func(t *testing.T) {
t.Run(table[i].name, func(t *testing.T) { table := []struct {
bf := bytes.NewBuffer(table[i].data) name string
err error
data []byte
}{
{name: "key-size-first-item", err: errTruncatedKeySize, data: sampleBytes[:2]},
{name: "key-data-second-item", err: errTruncatedKeyData, data: sampleBytes[:6]},
{name: "key-size-second-item", err: errTruncatedKeySize, data: sampleBytes[:(int32Size+4+fileIDSize+offsetSize+sizeSize)+2]},
{name: "key-data-second-item", err: errTruncatedKeyData, data: sampleBytes[:(int32Size+4+fileIDSize+offsetSize+sizeSize)+6]},
{name: "data", err: errTruncatedData, data: sampleBytes[:int32Size+4+(fileIDSize+offsetSize+sizeSize-3)]},
}
for i := range table {
t.Run(table[i].name, func(t *testing.T) {
bf := bytes.NewBuffer(table[i].data)
if err := ReadIndex(bf, art.New(), 1024, 1024); errors.Cause(err) != table[i].err {
t.Fatalf("expected %v, got %v", table[i].err, err)
}
})
}
})
t.Run("overflow", func(t *testing.T) {
overflowKeySize := make([]byte, len(sampleBytes))
copy(overflowKeySize, sampleBytes)
binary.BigEndian.PutUint32(overflowKeySize, 1025)
overflowDataSize := make([]byte, len(sampleBytes))
copy(overflowDataSize, sampleBytes)
binary.BigEndian.PutUint32(overflowDataSize[int32Size+4+fileIDSize+offsetSize:], 1025)
table := []struct {
name string
err error
maxKeySize int
maxValueSize int
data []byte
}{
{name: "key-data-overflow", err: errKeySizeTooLarge, maxKeySize: 1024, maxValueSize: 1024, data: overflowKeySize},
{name: "item-data-overflow", err: errDataSizeTooLarge, maxKeySize: 1024, maxValueSize: 1024, data: overflowDataSize},
}
for i := range table {
t.Run(table[i].name, func(t *testing.T) {
bf := bytes.NewBuffer(table[i].data)
if err := ReadIndex(bf, art.New(), table[i].maxKeySize, table[i].maxValueSize); errors.Cause(err) != table[i].err {
t.Fatalf("expected %v, got %v", table[i].err, err)
}
})
}
})
if err := ReadIndex(bf, art.New()); errors.Cause(err) != table[i].err {
t.Fatalf("expected %v, got %v", table[i].err, err)
}
})
}
} }
func getSampleTree() (art.Tree, int) { func getSampleTree() (art.Tree, int) {