diff --git a/pkg/bitio/bitio.go b/pkg/bitio/bitio.go index 73bdeae4..4ffd72ae 100644 --- a/pkg/bitio/bitio.go +++ b/pkg/bitio/bitio.go @@ -8,6 +8,9 @@ import ( "io" ) +var ErrOffset = errors.New("invalid seek offset") +var ErrNegativeNBits = errors.New("negative number of bits") + type BitReaderAt interface { ReadBitsAt(p []byte, nBits int, bitOff int64) (n int, err error) } @@ -114,6 +117,10 @@ func BitsByteCount(nBits int64) int64 { } func readFull(p []byte, nBits int, bitOff int64, fn func(p []byte, nBits int, bitOff int64) (int, error)) (int, error) { + if nBits < 0 { + return 0, ErrNegativeNBits + } + readBitOffset := 0 for readBitOffset < nBits { byteOffset := readBitOffset / 8 @@ -180,7 +187,6 @@ func EndPos(rs BitSeeker) (int64, error) { } // Reader is a BitReadSeeker and BitReaderAt reading from a io.ReadSeeker -// TODO: private? type Reader struct { bitPos int64 rs io.ReadSeeker @@ -195,6 +201,10 @@ func NewReaderFromReadSeeker(rs io.ReadSeeker) *Reader { } func (r *Reader) ReadBitsAt(p []byte, nBits int, bitOffset int64) (int, error) { + if nBits < 0 { + return 0, ErrNegativeNBits + } + readBytePos := bitOffset / 8 readSkipBits := int(bitOffset % 8) wantReadBits := readSkipBits + nBits @@ -311,8 +321,6 @@ func (r *SectionBitReader) ReadBits(p []byte, nBits int) (n int, err error) { return rBits, err } -var errOffset = errors.New("invalid seek offset") - func (r *SectionBitReader) SeekBits(bitOff int64, whence int) (int64, error) { switch whence { case io.SeekStart: @@ -325,7 +333,7 @@ func (r *SectionBitReader) SeekBits(bitOff int64, whence int) (int64, error) { panic("unknown whence") } if bitOff < r.bitBase { - return 0, errOffset + return 0, ErrOffset } r.bitOff = bitOff return bitOff - r.bitBase, nil @@ -411,7 +419,7 @@ func (m *MultiBitReader) SeekBits(bitOff int64, whence int) (int64, error) { panic("unknown whence") } if p < 0 || p > end { - return 0, errOffset + return 0, ErrOffset } m.pos = p diff --git a/pkg/bitio/rw64.go b/pkg/bitio/rw64.go index ddaa8a4c..4e562539 100644 --- a/pkg/bitio/rw64.go +++ b/pkg/bitio/rw64.go @@ -8,8 +8,8 @@ import ( // Read64 read nBits bits large unsigned integer from buf starting from firstBit. // Integer is read most significant bit first. func Read64(buf []byte, firstBit int, nBits int) uint64 { - if nBits > 64 { - panic(fmt.Sprintf("only supports =< 64 bits (%d)", nBits)) + if nBits < 0 || nBits > 64 { + panic(fmt.Sprintf("nBits must be 0-64 (%d)", nBits)) } be := binary.BigEndian @@ -91,8 +91,8 @@ func Read64(buf []byte, firstBit int, nBits int) uint64 { } func Write64(v uint64, nBits int, buf []byte, firstBit int) { - if nBits > 64 { - panic(fmt.Sprintf("only supports =< 64 bits (%d)", nBits)) + if nBits < 0 || nBits > 64 { + panic(fmt.Sprintf("nBits must be 0-64 (%d)", nBits)) } be := binary.BigEndian diff --git a/pkg/decode/numbers.go b/pkg/decode/numbers.go index 98ebaf18..cb4d89ec 100644 --- a/pkg/decode/numbers.go +++ b/pkg/decode/numbers.go @@ -37,6 +37,9 @@ func (d *D) TrySE(nBits int, endian Endian) (int64, error) { if err != nil { return 0, err } + if nBits == 0 { + return 0, nil + } if endian == LittleEndian { n = bitio.Uint64ReverseBytes(nBits, n) }