diff --git a/bitcask_test.go b/bitcask_test.go index 7e8958e..a2a916d 100644 --- a/bitcask_test.go +++ b/bitcask_test.go @@ -200,41 +200,29 @@ func TestMerge(t *testing.T) { } func TestConcurrent(t *testing.T) { + var ( + db *Bitcask + err error + ) + assert := assert.New(t) testdir, err := ioutil.TempDir("", "bitcask") assert.NoError(err) t.Run("Setup", func(t *testing.T) { - var ( - db *Bitcask - err error - ) - t.Run("Open", func(t *testing.T) { db, err = Open(testdir) assert.NoError(err) }) t.Run("Put", func(t *testing.T) { - for i := 0; i < 1024; i++ { - err = db.Put(string(i), []byte(strings.Repeat(" ", 1024))) - assert.NoError(err) - } + err = db.Put("foo", []byte("bar")) + assert.NoError(err) }) }) t.Run("Concurrent", func(t *testing.T) { - var ( - db *Bitcask - err error - ) - - t.Run("Open", func(t *testing.T) { - db, err = Open(testdir) - assert.NoError(err) - }) - t.Run("Put", func(t *testing.T) { f := func(wg *sync.WaitGroup, x int) { defer func() { @@ -261,6 +249,29 @@ func TestConcurrent(t *testing.T) { wg.Wait() }) + t.Run("Get", func(t *testing.T) { + f := func(wg *sync.WaitGroup, N int) { + defer func() { + wg.Done() + }() + for i := 0; i <= N; i++ { + value, err := db.Get("foo") + assert.NoError(err) + assert.Equal([]byte("bar"), value) + } + } + + wg := &sync.WaitGroup{} + + go f(wg, 100) + wg.Add(1) + + go f(wg, 100) + wg.Add(1) + + wg.Wait() + }) + t.Run("Close", func(t *testing.T) { err = db.Close() assert.NoError(err) diff --git a/datafile.go b/datafile.go index 8e4fa0d..ae24d70 100644 --- a/datafile.go +++ b/datafile.go @@ -21,7 +21,7 @@ var ( ) type Datafile struct { - sync.Mutex + sync.RWMutex id int r *os.File @@ -105,17 +105,23 @@ func (df *Datafile) Size() (int64, error) { return stat.Size(), nil } -func (df *Datafile) Read() (pb.Entry, error) { - var e pb.Entry +func (df *Datafile) Read() (e pb.Entry, err error) { + df.Lock() + defer df.Unlock() + return e, df.dec.Decode(&e) } func (df *Datafile) ReadAt(index int64) (e pb.Entry, err error) { + df.Lock() + defer df.Unlock() + _, err = df.r.Seek(index, os.SEEK_SET) if err != nil { return } - return df.Read() + + return e, df.dec.Decode(&e) } func (df *Datafile) Write(e pb.Entry) (int64, error) { diff --git a/streampb/stream.go b/streampb/stream.go index 25d1602..1efda01 100644 --- a/streampb/stream.go +++ b/streampb/stream.go @@ -16,26 +16,27 @@ const ( // NewEncoder creates a streaming protobuf encoder. func NewEncoder(w io.Writer) *Encoder { - return &Encoder{w: w, prefixBuf: make([]byte, prefixSize)} + return &Encoder{w} } // Encoder wraps an underlying io.Writer and allows you to stream // proto encodings on it. type Encoder struct { - w io.Writer - prefixBuf []byte + w io.Writer } // Encode takes any proto.Message and streams it to the underlying writer. // Messages are framed with a length prefix. func (e *Encoder) Encode(msg proto.Message) error { + prefixBuf := make([]byte, prefixSize) + buf, err := proto.Marshal(msg) if err != nil { return err } - binary.BigEndian.PutUint64(e.prefixBuf, uint64(len(buf))) + binary.BigEndian.PutUint64(prefixBuf, uint64(len(buf))) - if _, err := e.w.Write(e.prefixBuf); err != nil { + if _, err := e.w.Write(prefixBuf); err != nil { return errors.Wrap(err, "failed writing length prefix") } @@ -45,28 +46,26 @@ func (e *Encoder) Encode(msg proto.Message) error { // NewDecoder creates a streaming protobuf decoder. func NewDecoder(r io.Reader) *Decoder { - return &Decoder{ - r: r, - prefixBuf: make([]byte, prefixSize), - } + return &Decoder{r: r} } // Decoder wraps an underlying io.Reader and allows you to stream // proto decodings on it. type Decoder struct { - r io.Reader - prefixBuf []byte + r io.Reader } // Decode takes a proto.Message and unmarshals the next payload in the // underlying io.Reader. It returns an EOF when it's done. func (d *Decoder) Decode(v proto.Message) error { - _, err := io.ReadFull(d.r, d.prefixBuf) + prefixBuf := make([]byte, prefixSize) + + _, err := io.ReadFull(d.r, prefixBuf) if err != nil { return err } - n := binary.BigEndian.Uint64(d.prefixBuf) + n := binary.BigEndian.Uint64(prefixBuf) buf := make([]byte, n)