From a67b06a51172bca5dddd205485b29cd326fb0664 Mon Sep 17 00:00:00 2001 From: Alex Yatskov Date: Sat, 9 Feb 2019 16:00:49 -0800 Subject: [PATCH] work on bit reader --- streaming/bitreader.go | 90 +++++++++++++++++++++++++------------ streaming/streaming_test.go | 37 +++++++-------- 2 files changed, 79 insertions(+), 48 deletions(-) diff --git a/streaming/bitreader.go b/streaming/bitreader.go index 90706a5..dbf7f71 100644 --- a/streaming/bitreader.go +++ b/streaming/bitreader.go @@ -5,9 +5,9 @@ import ( ) type BitReader struct { - reader io.Reader - offset int - buffer [1]byte + reader io.Reader + bitOffset int + tailByte byte } func NewBitReader(reader io.Reader) *BitReader { @@ -66,31 +66,65 @@ func (r *BitReader) ReadInt64(count int) (int64, error) { } func (r *BitReader) ReadUint64(count int) (uint64, error) { - var value uint64 - for count > 0 { - bitOffset := r.offset % 8 - bitsLeft := 8 - bitOffset - if bitsLeft == 8 { - if _, err := r.reader.Read(r.buffer[:]); err != nil { - return 0, err - } - } - - bitsRead := count - if bitsRead > bitsLeft { - bitsRead = bitsLeft - } - - buffer := r.buffer[0] - buffer <<= uint(bitOffset) - buffer >>= (uint(bitOffset) + uint(bitsLeft-bitsRead)) - - value <<= uint(bitsRead) - value |= uint64(buffer) - - r.offset += bitsRead - count -= bitsRead + buffer, bitOffset, err := r.readBytes(count) + if err != nil { + return 0, err } - return value, nil + var result uint64 + + remainder := count + for byteOffset := 0; remainder > 0; byteOffset++ { + bitsRead := 8 - bitOffset + if bitsRead > remainder { + bitsRead = remainder + } + + bufferByte := buffer[byteOffset] + bufferByte >>= uint(bitOffset) + bufferByte &= ^(0xff << uint(bitsRead)) + + result |= (uint64(bufferByte) << uint(count-remainder)) + + remainder -= bitsRead + bitOffset = 0 + } + + return result, nil +} + +func (r *BitReader) readBytes(count int) ([]byte, int, error) { + if count == 0 { + return nil, 0, nil + } + + var ( + bitOffsetInByte = r.bitOffset % 8 + bitsLeftInByte = 8 - bitOffsetInByte + bytesNeeded = 1 + ) + + if bitsLeftInByte < count { + bitsOverrun := count - bitsLeftInByte + bytesNeeded += bitsOverrun / 8 + if bitsOverrun%8 != 0 { + bytesNeeded++ + } + } + + buffer := make([]byte, bytesNeeded) + bufferToRead := buffer + if bitsLeftInByte < 8 { + buffer[0] = r.tailByte + bufferToRead = buffer[1:] + } + + if _, err := io.ReadAtLeast(r.reader, bufferToRead, len(bufferToRead)); err != nil { + return nil, 0, err + } + + r.bitOffset += count + r.tailByte = buffer[bytesNeeded-1] + + return buffer, bitOffsetInByte, nil } diff --git a/streaming/streaming_test.go b/streaming/streaming_test.go index 9624296..2a1250e 100644 --- a/streaming/streaming_test.go +++ b/streaming/streaming_test.go @@ -2,40 +2,37 @@ package streaming import ( "bytes" + "fmt" "testing" ) func TestBitReader(t *testing.T) { data := []byte{ - 0x69, // 01101001 - 0x96, // 10010110 - 0xf0, // 11110000 - 0xaa, // 10101010 - 0x00, // 00000000 - 0xff, // 11111111 + 0x01, // 00000001 + 0x23, // 00100011 + 0x45, // 01000101 + 0x67, // 01100111 + 0x89, // 01100111 + 0xAB, // 10101011 + 0xCD, // 11001101 + 0xEF, // 11101111 } r := NewBitReader(bytes.NewReader(data)) readPass := func(c int, v uint64) { if value, err := r.ReadUint64(c); value != v || err != nil { - t.Fail() - } - } - - readFail := func(c int) { - if value, err := r.ReadUint64(c); value != 0 || err == nil { + fmt.Printf("%.16x (expected %.16x)\n", value, v) t.Fail() } } readPass(0, 0x00) - readPass(2, 0x01) - readPass(2, 0x02) - readPass(3, 0x04) - readPass(1, 0x01) - readPass(12, 0x096f) - readPass(8, 0x000a) - readPass(20, 0x0a00ff) - readFail(1) + readPass(8, 0x01) + readPass(16, 0x4523) + readPass(3, 0x67&0x07) + readPass(13, 0x8967>>3) + readPass(13, 0xcdab&0x1fff) + readPass(2, (0xcdab>>13)&3) + readPass(9, 0xefcdab>>15) }