From a366d2b104a210b813dc60a2b7d1219f3d1ff092 Mon Sep 17 00:00:00 2001 From: Alex Yatskov Date: Wed, 23 Jan 2019 20:06:32 -0800 Subject: [PATCH] add read methods for basic types --- formats/dcc/dcc.go | 38 ++++++++++-------------- streaming/bitreader.go | 59 ++++++++++++++++++++++++------------- streaming/streaming_test.go | 5 ++-- 3 files changed, 55 insertions(+), 47 deletions(-) diff --git a/formats/dcc/dcc.go b/formats/dcc/dcc.go index 66c2335..d4ab79e 100644 --- a/formats/dcc/dcc.go +++ b/formats/dcc/dcc.go @@ -88,69 +88,61 @@ func NewFromReader(reader io.ReadSeeker) (*Sprite, error) { func readDirectionHeader(reader io.ReadSeeker) (*directionHeader, error) { r := streaming.NewBitReader(reader) - codedSize, err := r.ReadBitsUnsigned(32) + var ( + header directionHeader + err error + ) + + header.CodedSize, err = r.ReadUint32(32) if err != nil { return nil, err } - hasRawPixelEncoding, err := r.ReadBitsUnsigned(1) + header.HasRawPixelEncoding, err = r.ReadBool() if err != nil { return nil, err } - compressEqualCells, err := r.ReadBitsUnsigned(1) + header.CompressEqualCells, err = r.ReadBool() if err != nil { return nil, err } - variable0Bits, err := r.ReadBitsUnsigned(4) + header.Variable0Bits, err = r.ReadUint32(4) if err != nil { return nil, err } - widthBits, err := r.ReadBitsUnsigned(4) + header.WidthBits, err = r.ReadUint32(4) if err != nil { return nil, err } - heightBits, err := r.ReadBitsUnsigned(4) + header.HeightBits, err = r.ReadUint32(4) if err != nil { return nil, err } - offsetXBits, err := r.ReadBitsUnsigned(4) + header.OffsetXBits, err = r.ReadInt32(4) if err != nil { return nil, err } - offsetYBits, err := r.ReadBitsUnsigned(4) + header.OffsetYBits, err = r.ReadInt32(4) if err != nil { return nil, err } - optionalBytesBits, err := r.ReadBitsUnsigned(4) + header.OptionalBytesBits, err = r.ReadUint32(4) if err != nil { return nil, err } - codedBytesBits, err := r.ReadBitsUnsigned(4) + header.CodedBytesBits, err = r.ReadUint32(4) if err != nil { return nil, err } - header := directionHeader{ - CodedSize: uint32(codedSize), - HasRawPixelEncoding: hasRawPixelEncoding == 1, - CompressEqualCells: compressEqualCells == 1, - Variable0Bits: uint32(variable0Bits), - WidthBits: uint32(widthBits), - HeightBits: uint32(heightBits), - OffsetXBits: int32(offsetXBits), - OffsetYBits: int32(offsetYBits), - OptionalBytesBits: uint32(optionalBytesBits), - CodedBytesBits: uint32(codedBytesBits), - } - return &header, nil } diff --git a/streaming/bitreader.go b/streaming/bitreader.go index 8c82a91..90706a5 100644 --- a/streaming/bitreader.go +++ b/streaming/bitreader.go @@ -1,7 +1,6 @@ package streaming import ( - "errors" "io" ) @@ -15,8 +14,43 @@ func NewBitReader(reader io.Reader) *BitReader { return &BitReader{reader: reader} } -func (r *BitReader) ReadBitsSigned(count int) (int64, error) { - value, err := r.readBits(count) +func (r *BitReader) ReadBool() (bool, error) { + value, err := r.ReadUint64(1) + return value == 1, err +} + +func (r *BitReader) ReadInt8(count int) (int8, error) { + value, err := r.ReadInt64(count) + return int8(value), err +} + +func (r *BitReader) ReadUint8(count int) (uint8, error) { + value, err := r.ReadUint64(count) + return uint8(value), err +} + +func (r *BitReader) ReadInt16(count int) (int16, error) { + value, err := r.ReadInt64(count) + return int16(value), err +} + +func (r *BitReader) ReadUint16(count int) (uint16, error) { + value, err := r.ReadUint64(count) + return uint16(value), err +} + +func (r *BitReader) ReadInt32(count int) (int32, error) { + value, err := r.ReadInt64(count) + return int32(value), err +} + +func (r *BitReader) ReadUint32(count int) (uint32, error) { + value, err := r.ReadUint64(count) + return uint32(value), err +} + +func (r *BitReader) ReadInt64(count int) (int64, error) { + value, err := r.ReadUint64(count) if err != nil { return 0, err } @@ -31,24 +65,7 @@ func (r *BitReader) ReadBitsSigned(count int) (int64, error) { return int64(value), nil } -func (r *BitReader) ReadBitsUnsigned(count int) (uint64, error) { - return r.readBits(count) -} - -func (r *BitReader) ReadBitFlag() (bool, error) { - value, err := r.readBits(1) - if err != nil { - return false, err - } - - return value == 1, nil -} - -func (r *BitReader) readBits(count int) (uint64, error) { - if count > 64 { - return 0, errors.New("cannot read more than 64 bits at a time") - } - +func (r *BitReader) ReadUint64(count int) (uint64, error) { var value uint64 for count > 0 { bitOffset := r.offset % 8 diff --git a/streaming/streaming_test.go b/streaming/streaming_test.go index 726e7c4..9624296 100644 --- a/streaming/streaming_test.go +++ b/streaming/streaming_test.go @@ -18,19 +18,18 @@ func TestBitReader(t *testing.T) { r := NewBitReader(bytes.NewReader(data)) readPass := func(c int, v uint64) { - if value, err := r.ReadBitsUnsigned(c); value != v || err != nil { + if value, err := r.ReadUint64(c); value != v || err != nil { t.Fail() } } readFail := func(c int) { - if value, err := r.ReadBitsUnsigned(c); value != 0 || err == nil { + if value, err := r.ReadUint64(c); value != 0 || err == nil { t.Fail() } } readPass(0, 0x00) - readFail(65) readPass(2, 0x01) readPass(2, 0x02) readPass(3, 0x04)