diff --git a/streaming/bitreader.go b/streaming/bitreader.go index 8217aa0..d4ab8c2 100644 --- a/streaming/bitreader.go +++ b/streaming/bitreader.go @@ -41,6 +41,9 @@ func (r *BitReader) ReadBits(count int) (uint64, error) { value <<= uint(bitsRead) value |= uint64(buffer) + + r.offset += bitsRead + count -= bitsRead } return value, nil diff --git a/streaming/streaming_test.go b/streaming/streaming_test.go index 142f772..66debab 100644 --- a/streaming/streaming_test.go +++ b/streaming/streaming_test.go @@ -2,7 +2,6 @@ package streaming import ( "bytes" - "fmt" "testing" ) @@ -16,6 +15,28 @@ func TestBitReader(t *testing.T) { 0xff, // 11111111 } - reader := NewReader(bytes.NewReader(data)) - fmt.Println(reader.ReadBits(2)) + r := NewReader(bytes.NewReader(data)) + + readPass := func(c int, v uint64) { + if value, err := r.ReadBits(c); value != v || err != nil { + t.Fail() + } + } + + readFail := func(c int) { + if value, err := r.ReadBits(c); value != 0 || err == nil { + t.Fail() + } + } + + readPass(0, 0x00) + readFail(65) + readPass(2, 0x01) + readPass(2, 0x02) + readPass(3, 0x04) + readPass(1, 0x01) + readPass(12, 0x096f) + readPass(8, 0x000a) + readPass(20, 0x0a00ff) + readFail(1) }