diff --git a/formats/dcc/dcc.go b/formats/dcc/dcc.go index 85a8120..66c2335 100644 --- a/formats/dcc/dcc.go +++ b/formats/dcc/dcc.go @@ -88,52 +88,52 @@ func NewFromReader(reader io.ReadSeeker) (*Sprite, error) { func readDirectionHeader(reader io.ReadSeeker) (*directionHeader, error) { r := streaming.NewBitReader(reader) - codedSize, err := r.ReadBits(32) + codedSize, err := r.ReadBitsUnsigned(32) if err != nil { return nil, err } - hasRawPixelEncoding, err := r.ReadBits(1) + hasRawPixelEncoding, err := r.ReadBitsUnsigned(1) if err != nil { return nil, err } - compressEqualCells, err := r.ReadBits(1) + compressEqualCells, err := r.ReadBitsUnsigned(1) if err != nil { return nil, err } - variable0Bits, err := r.ReadBits(4) + variable0Bits, err := r.ReadBitsUnsigned(4) if err != nil { return nil, err } - widthBits, err := r.ReadBits(4) + widthBits, err := r.ReadBitsUnsigned(4) if err != nil { return nil, err } - heightBits, err := r.ReadBits(4) + heightBits, err := r.ReadBitsUnsigned(4) if err != nil { return nil, err } - offsetXBits, err := r.ReadBits(4) + offsetXBits, err := r.ReadBitsUnsigned(4) if err != nil { return nil, err } - offsetYBits, err := r.ReadBits(4) + offsetYBits, err := r.ReadBitsUnsigned(4) if err != nil { return nil, err } - optionalBytesBits, err := r.ReadBits(4) + optionalBytesBits, err := r.ReadBitsUnsigned(4) if err != nil { return nil, err } - codedBytesBits, err := r.ReadBits(4) + codedBytesBits, err := r.ReadBitsUnsigned(4) if err != nil { return nil, err } diff --git a/streaming/bitreader.go b/streaming/bitreader.go index 4bb3dec..8c82a91 100644 --- a/streaming/bitreader.go +++ b/streaming/bitreader.go @@ -15,7 +15,36 @@ func NewBitReader(reader io.Reader) *BitReader { return &BitReader{reader: reader} } -func (r *BitReader) ReadBits(count int) (uint64, error) { +func (r *BitReader) ReadBitsSigned(count int) (int64, error) { + value, err := r.readBits(count) + if err != nil { + return 0, err + } + + if count > 0 { + valueMasked := value &^ (1 << uint(count-1)) + if valueMasked != value { + return -int64(valueMasked), nil + } + } + + return int64(value), nil +} + +func (r *BitReader) ReadBitsUnsigned(count int) (uint64, error) { + return r.readBits(count) +} + +func (r *BitReader) ReadBitFlag() (bool, error) { + value, err := r.readBits(1) + if err != nil { + return false, err + } + + return value == 1, nil +} + +func (r *BitReader) readBits(count int) (uint64, error) { if count > 64 { return 0, errors.New("cannot read more than 64 bits at a time") } diff --git a/streaming/streaming_test.go b/streaming/streaming_test.go index c0cbe9c..726e7c4 100644 --- a/streaming/streaming_test.go +++ b/streaming/streaming_test.go @@ -18,13 +18,13 @@ func TestBitReader(t *testing.T) { r := NewBitReader(bytes.NewReader(data)) readPass := func(c int, v uint64) { - if value, err := r.ReadBits(c); value != v || err != nil { + if value, err := r.ReadBitsUnsigned(c); value != v || err != nil { t.Fail() } } readFail := func(c int) { - if value, err := r.ReadBits(c); value != 0 || err == nil { + if value, err := r.ReadBitsUnsigned(c); value != 0 || err == nil { t.Fail() } }