From 93cc1d409ff7bc66c27e71174a23d5171a7fcabc Mon Sep 17 00:00:00 2001 From: Ignacio Hagopian Date: Tue, 3 Sep 2019 23:26:26 -0300 Subject: [PATCH] codec_index: check sizes, new tests for data corruption & refactor (#84) * bitcask/codec_index: check key and data sizes * codec_index: tests for key and data size overflows * codec_index: simplify internal funcs for unused returns --- bitcask.go | 2 +- internal/codec_index.go | 48 ++++++++++++++--------- internal/codec_index_test.go | 76 ++++++++++++++++++++++++++---------- 3 files changed, 86 insertions(+), 40 deletions(-) diff --git a/bitcask.go b/bitcask.go index 30f7cb9..577acd0 100644 --- a/bitcask.go +++ b/bitcask.go @@ -351,7 +351,7 @@ func (b *Bitcask) reopen() error { } defer f.Close() - if err := internal.ReadIndex(f, t); err != nil { + if err := internal.ReadIndex(f, t, b.config.maxKeySize, b.config.maxValueSize); err != nil { return err } } else { diff --git a/internal/codec_index.go b/internal/codec_index.go index 5dfab16..ba0e959 100644 --- a/internal/codec_index.go +++ b/internal/codec_index.go @@ -12,6 +12,8 @@ var ( errTruncatedKeySize = errors.New("key size is truncated") errTruncatedKeyData = errors.New("key data is truncated") errTruncatedData = errors.New("data is truncated") + errKeySizeTooLarge = errors.New("key size too large") + errDataSizeTooLarge = errors.New("data size too large") ) const ( @@ -22,7 +24,7 @@ const ( sizeSize = int64Size ) -func readKeyBytes(r io.Reader) ([]byte, error) { +func readKeyBytes(r io.Reader, maxKeySize int) ([]byte, error) { s := make([]byte, int32Size) _, err := io.ReadFull(r, s) if err != nil { @@ -32,6 +34,10 @@ func readKeyBytes(r io.Reader) ([]byte, error) { return nil, errors.Wrap(errTruncatedKeySize, err.Error()) } size := binary.BigEndian.Uint32(s) + if size > uint32(maxKeySize) { + return nil, errKeySizeTooLarge + } + b := make([]byte, size) _, err = io.ReadFull(r, b) if err != nil { @@ -40,50 +46,54 @@ func readKeyBytes(r io.Reader) ([]byte, error) { 
return b, nil } -func writeBytes(b []byte, w io.Writer) (int, error) { +func writeBytes(b []byte, w io.Writer) error { s := make([]byte, int32Size) binary.BigEndian.PutUint32(s, uint32(len(b))) - n, err := w.Write(s) + _, err := w.Write(s) if err != nil { - return n, err + return err } - m, err := w.Write(b) + _, err = w.Write(b) if err != nil { - return (n + m), err + return err } - return (n + m), nil + return nil } -func readItem(r io.Reader) (Item, error) { +func readItem(r io.Reader, maxValueSize int) (Item, error) { buf := make([]byte, (fileIDSize + offsetSize + sizeSize)) _, err := io.ReadFull(r, buf) if err != nil { return Item{}, errors.Wrap(errTruncatedData, err.Error()) } + size := int64(binary.BigEndian.Uint64(buf[(fileIDSize + offsetSize):])) + if size > int64(maxValueSize) { + return Item{}, errDataSizeTooLarge + } return Item{ FileID: int(binary.BigEndian.Uint32(buf[:fileIDSize])), Offset: int64(binary.BigEndian.Uint64(buf[fileIDSize:(fileIDSize + offsetSize)])), - Size: int64(binary.BigEndian.Uint64(buf[(fileIDSize + offsetSize):])), + Size: size, }, nil } -func writeItem(item Item, w io.Writer) (int, error) { +func writeItem(item Item, w io.Writer) error { buf := make([]byte, (fileIDSize + offsetSize + sizeSize)) binary.BigEndian.PutUint32(buf[:fileIDSize], uint32(item.FileID)) binary.BigEndian.PutUint64(buf[fileIDSize:(fileIDSize+offsetSize)], uint64(item.Offset)) binary.BigEndian.PutUint64(buf[(fileIDSize+offsetSize):], uint64(item.Size)) - n, err := w.Write(buf) + _, err := w.Write(buf) if err != nil { - return 0, err + return err } - return n, nil + return nil } -// ReadIndex reads a persisted tree from a io.Reader into a Tree -func ReadIndex(r io.Reader, t art.Tree) error { +// ReadIndex reads a persisted tree from an io.Reader into a Tree +func ReadIndex(r io.Reader, t art.Tree, maxKeySize, maxValueSize int) error { for { - key, err := readKeyBytes(r) + key, err := readKeyBytes(r, maxKeySize) if err != nil { if err == io.EOF { break @@ -91,7 +101,7 
@@ func ReadIndex(r io.Reader, t art.Tree) error { return err } - item, err := readItem(r) + item, err := readItem(r, maxValueSize) if err != nil { return err } @@ -105,13 +115,13 @@ func ReadIndex(r io.Reader, t art.Tree) error { // WriteIndex persists a Tree into a io.Writer func WriteIndex(t art.Tree, w io.Writer) (err error) { t.ForEach(func(node art.Node) bool { - _, err = writeBytes(node.Key(), w) + err = writeBytes(node.Key(), w) if err != nil { return false } item := node.Value().(Item) - _, err := writeItem(item, w) + err := writeItem(item, w) if err != nil { return false } diff --git a/internal/codec_index_test.go b/internal/codec_index_test.go index bb48db1..b465e3c 100644 --- a/internal/codec_index_test.go +++ b/internal/codec_index_test.go @@ -3,6 +3,7 @@ package internal import ( "bytes" "encoding/base64" + "encoding/binary" "testing" "github.com/pkg/errors" @@ -35,7 +36,7 @@ func TestReadIndex(t *testing.T) { b := bytes.NewBuffer(sampleTreeBytes) at := art.New() - err := ReadIndex(b, at) + err := ReadIndex(b, at, 1024, 1024) if err != nil { t.Fatalf("error while deserializing correct sample tree: %v", err) } @@ -55,27 +56,62 @@ func TestReadIndex(t *testing.T) { func TestReadCorruptedData(t *testing.T) { sampleBytes, _ := base64.StdEncoding.DecodeString(base64SampleTree) - table := []struct { - name string - err error - data []byte - }{ - {name: "truncated-key-size-first-item", err: errTruncatedKeySize, data: sampleBytes[:2]}, - {name: "truncated-key-data-second-item", err: errTruncatedKeyData, data: sampleBytes[:6]}, - {name: "truncated-key-size-second-item", err: errTruncatedKeySize, data: sampleBytes[:(int32Size+4+fileIDSize+offsetSize+sizeSize)+2]}, - {name: "truncated-key-data-second-item", err: errTruncatedKeyData, data: sampleBytes[:(int32Size+4+fileIDSize+offsetSize+sizeSize)+6]}, - {name: "truncated-data", err: errTruncatedData, data: sampleBytes[:int32Size+4+(fileIDSize+offsetSize+sizeSize-3)]}, - } - for i := range table { - 
t.Run(table[i].name, func(t *testing.T) { - bf := bytes.NewBuffer(table[i].data) + t.Run("truncated", func(t *testing.T) { + table := []struct { + name string + err error + data []byte + }{ + {name: "key-size-first-item", err: errTruncatedKeySize, data: sampleBytes[:2]}, + {name: "key-data-first-item", err: errTruncatedKeyData, data: sampleBytes[:6]}, + {name: "key-size-second-item", err: errTruncatedKeySize, data: sampleBytes[:(int32Size+4+fileIDSize+offsetSize+sizeSize)+2]}, + {name: "key-data-second-item", err: errTruncatedKeyData, data: sampleBytes[:(int32Size+4+fileIDSize+offsetSize+sizeSize)+6]}, + {name: "data", err: errTruncatedData, data: sampleBytes[:int32Size+4+(fileIDSize+offsetSize+sizeSize-3)]}, + } + + for i := range table { + t.Run(table[i].name, func(t *testing.T) { + bf := bytes.NewBuffer(table[i].data) + + if err := ReadIndex(bf, art.New(), 1024, 1024); errors.Cause(err) != table[i].err { + t.Fatalf("expected %v, got %v", table[i].err, err) + } + }) + } + }) + + t.Run("overflow", func(t *testing.T) { + overflowKeySize := make([]byte, len(sampleBytes)) + copy(overflowKeySize, sampleBytes) + binary.BigEndian.PutUint32(overflowKeySize, 1025) + + overflowDataSize := make([]byte, len(sampleBytes)) + copy(overflowDataSize, sampleBytes) + binary.BigEndian.PutUint32(overflowDataSize[int32Size+4+fileIDSize+offsetSize:], 1025) + + table := []struct { + name string + err error + maxKeySize int + maxValueSize int + data []byte + }{ + {name: "key-data-overflow", err: errKeySizeTooLarge, maxKeySize: 1024, maxValueSize: 1024, data: overflowKeySize}, + {name: "item-data-overflow", err: errDataSizeTooLarge, maxKeySize: 1024, maxValueSize: 1024, data: overflowDataSize}, + } + + for i := range table { + t.Run(table[i].name, func(t *testing.T) { + bf := bytes.NewBuffer(table[i].data) + + if err := ReadIndex(bf, art.New(), table[i].maxKeySize, table[i].maxValueSize); errors.Cause(err) != table[i].err { + t.Fatalf("expected %v, got %v", table[i].err, err) + } + }) + 
} + }) - if err := ReadIndex(bf, art.New()); errors.Cause(err) != table[i].err { - t.Fatalf("expected %v, got %v", table[i].err, err) - } - }) - } } func getSampleTree() (art.Tree, int) {