ur: refactors ur_bsr_bytes_any(), fixes edge-case bugs

This commit is contained in:
Joe Bryan 2020-08-31 15:56:55 -07:00
parent bb136b7981
commit c123e9413a
2 changed files with 49 additions and 44 deletions

View File

@ -1184,6 +1184,8 @@ _test_bsr_bytes_any_loop(const char *cap, uint8_t len, uint8_t val)
} }
free(bytes); free(bytes);
free(d);
free(c);
return ret; return ret;
} }

View File

@ -360,6 +360,8 @@ ur_bsr_bytes_any(ur_bsr_t *bsr, uint64_t len, uint8_t *out)
{ {
uint64_t left = bsr->left; uint64_t left = bsr->left;
bsr->bits += len;
if ( !left ) { if ( !left ) {
return; return;
} }
@ -368,33 +370,24 @@ ur_bsr_bytes_any(ur_bsr_t *bsr, uint64_t len, uint8_t *out)
uint8_t off = bsr->off; uint8_t off = bsr->off;
uint64_t len_byt = len >> 3; uint64_t len_byt = len >> 3;
uint8_t len_bit = ur_mask_3(len); uint8_t len_bit = ur_mask_3(len);
uint64_t need = len_byt + !!len_bit;
if ( !off ) { if ( !off ) {
uint8_t bits = off + len_bit;
uint64_t need = len_byt + (bits >> 3) + !!ur_mask_3(bits);
if ( need > left ) { if ( need > left ) {
memcpy(out, b, left); memcpy(out, b, left);
left = 0;
bsr->bytes = 0; bsr->bytes = 0;
bsr->left = 0;
} }
else { else {
memcpy(out, b, len_byt); memcpy(out, b, len_byt);
off = len_bit; off = len_bit;
left -= len_byt;
if ( !left ) {
bsr->bytes = 0;
}
else {
bsr->bytes += len_byt;
}
bsr->left = left;
if ( off ) { if ( off ) {
out[len_byt] = b[len_byt] & ((1 << off) - 1); out[len_byt] = b[len_byt] & ((1 << off) - 1);
} }
left -= len_byt;
bsr->bytes = ( left ) ? b + len_byt : 0;
} }
} }
// the most-significant bits from a byte in the stream // the most-significant bits from a byte in the stream
@ -403,61 +396,71 @@ ur_bsr_bytes_any(ur_bsr_t *bsr, uint64_t len, uint8_t *out)
else { else {
uint8_t rest = 8 - off; uint8_t rest = 8 - off;
uint8_t mask = (1 << off) - 1; uint8_t mask = (1 << off) - 1;
uint8_t byt = b[0]; uint8_t byt, l, m = *b >> off;
uint8_t l, m = byt >> off; uint64_t last = left - 1;
ur_bool_t end;
uint64_t i, max;
// loop over all the bytes we need (or all that remain)
//
// [l] holds [off] bits
// [m] holds [rest] bits
//
{ {
uint64_t need = len_byt + !!len_bit; uint64_t max = ur_min(last, len_byt);
end = need >= left; uint64_t i;
max = end ? (left - 1) : len_byt;
for ( i = 0; i < max; i++ ) {
byt = *++b;
l = byt & mask;
out[i] = m ^ (l << rest);
m = byt >> off;
}
} }
for ( i = 0; i < max; i++ ) { // we're reading into or beyond the last byte [bsr]
byt = b[1ULL + i]; //
l = byt & mask; // [m] holds all the remaining bits in [bsr],
out[i] = m ^ (l << rest); // but we might not need all of it
m = byt >> off; //
} if ( need >= left ) {
uint8_t bits = len - (last << 3);
if ( end ) { if ( bits < rest ) {
if ( len_bit && len_bit < rest ) { out[last] = m & ((1 << bits) - 1);
out[max] = m & ((1 << len_bit) - 1); bsr->bytes = b;
bsr->bytes += max; left = 1;
left -= max;
off += len_bit; off += len_bit;
} }
else { else {
out[max] = m; out[last] = m;
bsr->bytes = 0; bsr->bytes = 0;
left = 0; left = 0;
off = 0; off = 0;
} }
} }
// we need less than a byte, but it might span multiple bytes
//
else { else {
uint8_t bits = off + len_bit; uint8_t bits = off + len_bit;
uint64_t step = max + !!(bits >> 3); uint8_t step = !!(bits >> 3);
bsr->bytes += step; bsr->bytes = b + step;
left -= step; left -= len_byt + step;
off = ur_mask_3(bits); off = ur_mask_3(bits);
if ( len_bit ) { if ( len_bit ) {
if ( len_bit <= rest ) { if ( len_bit <= rest ) {
out[max] = m & ((1 << len_bit) - 1); out[len_byt] = m & ((1 << len_bit) - 1);
} }
else { else {
l = b[1ULL + max] & ((1 << off) - 1);; l = *++b & ((1 << off) - 1);
out[max] = m ^ (l << rest); out[len_byt] = m ^ (l << rest);
} }
} }
} }
} }
bsr->off = off; bsr->off = off;
bsr->left = left; bsr->left = left;
bsr->bits += len;
} }
} }