More invalid utf-8 tests

This commit is contained in:
Kovid Goyal 2024-01-13 13:57:09 +05:30
parent 8a10fcaf5a
commit fa3579656b
No known key found for this signature in database
GPG Key ID: 06BC317B515ACE7C
2 changed files with 41 additions and 7 deletions

View File

@ -61,6 +61,7 @@ _Pragma("clang diagnostic pop")
#define reverse_numbered_bytes() simde_mm_setr_epi8(15,14,13,12,11,10,9,8,7,6,5,4,3,2,1,0) #define reverse_numbered_bytes() simde_mm_setr_epi8(15,14,13,12,11,10,9,8,7,6,5,4,3,2,1,0)
// output[i] = MAX(0, a[i] - b[i]) // output[i] = MAX(0, a[i] - b[i])
#define subtract_saturate_epu8 simde_mm_subs_epu8 #define subtract_saturate_epu8 simde_mm_subs_epu8
#define subtract_epi8 simde_mm_sub_epi8
#define create_zero_integer simde_mm_setzero_si128 #define create_zero_integer simde_mm_setzero_si128
#define sum_bytes sum_bytes_128 #define sum_bytes sum_bytes_128
@ -83,6 +84,7 @@ FUNC(is_zero)(const integer_t a) { return simde_mm_testz_si128(a, a); }
#define extract_lower_half_as_chars simde_mm256_cvtepu8_epi32 #define extract_lower_half_as_chars simde_mm256_cvtepu8_epi32
#define blendv_epi8 simde_mm256_blendv_epi8 #define blendv_epi8 simde_mm256_blendv_epi8
#define subtract_saturate_epu8 simde_mm256_subs_epu8 #define subtract_saturate_epu8 simde_mm256_subs_epu8
#define subtract_epi8 simde_mm256_sub_epi8
#define shift_left_by_bits16 simde_mm256_slli_epi16 #define shift_left_by_bits16 simde_mm256_slli_epi16
#define shift_right_by_bits32 simde_mm256_srli_epi32 #define shift_right_by_bits32 simde_mm256_srli_epi32
#define create_zero_integer simde_mm256_setzero_si256 #define create_zero_integer simde_mm256_setzero_si256
@ -363,8 +365,8 @@ FUNC(utf8_decode_to_esc)(UTF8Decoder *d, const uint8_t *src, size_t src_sz) {
d->num_consumed += src_sz + 1; // esc is also consumed d->num_consumed += src_sz + 1; // esc is also consumed
} else d->num_consumed += src_sz; } else d->num_consumed += src_sz;
// use scalar decode for short input and check that all bytes are less than 0xf4 // use scalar decode for short input
if (src_sz < 4 || !is_zero(subtract_saturate_epu8(vec, set1_epi8(0xf4)))) { if (src_sz < 4) {
scalar_decode_all(d, src, src_sz); return sentinel_found; scalar_decode_all(d, src, src_sz); return sentinel_found;
} }
@ -378,7 +380,8 @@ FUNC(utf8_decode_to_esc)(UTF8Decoder *d, const uint8_t *src, size_t src_sz) {
// Check if we have pure ASCII and use fast path // Check if we have pure ASCII and use fast path
print_register_as_bytes(vec); print_register_as_bytes(vec);
if (!movemask_epi8(vec)) { // no bytes with high bit (0x80) set, so just plain ASCII int32_t ascii_mask = movemask_epi8(vec);
if (!ascii_mask) { // no bytes with high bit (0x80) set, so just plain ASCII
FUNC(output_plain_ascii)(d, vec, src_sz); FUNC(output_plain_ascii)(d, vec, src_sz);
if (num_of_trailing_bytes) scalar_decode_all(d, src + src_sz, num_of_trailing_bytes); if (num_of_trailing_bytes) scalar_decode_all(d, src + src_sz, num_of_trailing_bytes);
return sentinel_found; return sentinel_found;
@ -400,7 +403,7 @@ FUNC(utf8_decode_to_esc)(UTF8Decoder *d, const uint8_t *src, size_t src_sz) {
print_register_as_bytes(mask); print_register_as_bytes(mask);
integer_t count = and_si(state, set1_epi8(0x7)); // keep lower 3 bits of state integer_t count = and_si(state, set1_epi8(0x7)); // keep lower 3 bits of state
print_register_as_bytes(count); print_register_as_bytes(count);
const integer_t one = set1_epi8(1), two = set1_epi8(2), three = set1_epi8(3); const integer_t zero = create_zero_integer(), one = set1_epi8(1), two = set1_epi8(2), three = set1_epi8(3);
// count contains the number of bytes in the sequence for the start byte of every sequence and zero elsewhere // count contains the number of bytes in the sequence for the start byte of every sequence and zero elsewhere
// shift 02 bytes by 1 and subtract 1 // shift 02 bytes by 1 and subtract 1
integer_t count_subs1 = subtract_saturate_epu8(count, one); integer_t count_subs1 = subtract_saturate_epu8(count, one);
@ -409,6 +412,11 @@ FUNC(utf8_decode_to_esc)(UTF8Decoder *d, const uint8_t *src, size_t src_sz) {
counts = add_epi8(counts, shift_right_by_two_bytes(subtract_saturate_epu8(counts, two))); counts = add_epi8(counts, shift_right_by_two_bytes(subtract_saturate_epu8(counts, two)));
// counts now contains the number of bytes remaining in each utf-8 sequence of 2 or more bytes // counts now contains the number of bytes remaining in each utf-8 sequence of 2 or more bytes
print_register_as_bytes(counts); print_register_as_bytes(counts);
// Only ASCII bytes should have a corresponding byte in counts that is equal to 0
if (ascii_mask ^ movemask_epi8(cmpgt_epi8(counts, zero))) goto invalid_utf8;
// The difference between a byte in counts and the next one should be negative,
// zero, or one. Any other value means there are not enough continuation bytes.
if (movemask_epi8(cmpgt_epi8(subtract_epi8(shift_right_by_one_byte(counts), counts), one))) goto invalid_utf8;
// Process the bytes storing the three resulting bytes that make up the unicode codepoint // Process the bytes storing the three resulting bytes that make up the unicode codepoint
// mask all control bits so that we have only useful bits left // mask all control bits so that we have only useful bits left
@ -419,7 +427,7 @@ FUNC(utf8_decode_to_esc)(UTF8Decoder *d, const uint8_t *src, size_t src_sz) {
// The lowest byte is made up of 6 bits from locations with counts == 1 and the lowest two bits from locations with count == 2 // The lowest byte is made up of 6 bits from locations with counts == 1 and the lowest two bits from locations with count == 2
// In addition, the ASCII bytes are copied unchanged from vec // In addition, the ASCII bytes are copied unchanged from vec
integer_t vec_non_ascii = andnot_si(cmpeq_epi8(counts, create_zero_integer()), vec); integer_t vec_non_ascii = andnot_si(cmpeq_epi8(counts, zero), vec);
print_register_as_bytes(vec_non_ascii); print_register_as_bytes(vec_non_ascii);
integer_t vec_right1 = shift_right_by_one_byte(vec_non_ascii); integer_t vec_right1 = shift_right_by_one_byte(vec_non_ascii);
integer_t output1 = blendv_epi8(vec, integer_t output1 = blendv_epi8(vec,
@ -492,6 +500,9 @@ FUNC(utf8_decode_to_esc)(UTF8Decoder *d, const uint8_t *src, size_t src_sz) {
FUNC(output_unicode)(d, output1, output2, output3, num_codepoints); FUNC(output_unicode)(d, output1, output2, output3, num_codepoints);
if (num_of_trailing_bytes) scalar_decode_all(d, src + src_sz, num_of_trailing_bytes); if (num_of_trailing_bytes) scalar_decode_all(d, src + src_sz, num_of_trailing_bytes);
return sentinel_found; return sentinel_found;
invalid_utf8:
scalar_decode_all(d, src, src_sz + num_of_trailing_bytes);
return sentinel_found;
} }
@ -529,6 +540,7 @@ FUNC(utf8_decode_to_esc)(UTF8Decoder *d, const uint8_t *src, size_t src_sz) {
#undef blendv_epi8 #undef blendv_epi8
#undef add_epi8 #undef add_epi8
#undef subtract_saturate_epu8 #undef subtract_saturate_epu8
#undef subtract_epi8
#undef create_zero_integer #undef create_zero_integer
#undef shuffle_epi8 #undef shuffle_epi8
#undef numbered_bytes #undef numbered_bytes

View File

@ -232,9 +232,31 @@ def double_test(x):
x(b'abcd\xc3', b'\xa41234') x(b'abcd\xc3', b'\xa41234')
x(b'abcd\xe2', b'\x89\xa41234') x(b'abcd\xe2', b'\x89\xa41234')
x(b'abcd\xe2\x89', b'\xa41234') x(b'abcd\xe2\x89', b'\xa41234')
# various invalid input
x(b'abcd\xf51234\xffABCD') # bytes > 0xf4
def test_invalid(src, expected, which=2):
    """Check how one piece of (possibly malformed) UTF-8 input is decoded.

    ``src`` is converted with ``asbytes`` and appended to ``b'filler'``; the
    result is passed to ``test_utf8_decode_to_sentinel`` together with
    ``which``. The second item of the returned pair must equal
    ``'filler' + expected``, otherwise the assertion fails with a message
    naming ``src`` and ``which``.

    NOTE(review): ``which`` presumably selects the decoder implementation
    under test (callers iterate over 1, 2 and 3) -- confirm against
    ``test_utf8_decode_to_sentinel``.
    """
    # Reset before every case; the call happens ahead of any decoding.
    reset_state()
    payload = b'filler' + asbytes(src)
    # Only the second element of the returned pair is compared.
    _unused, decoded = test_utf8_decode_to_sentinel(payload, which)
    wanted = 'filler' + expected
    self.ae(wanted, decoded, f'Failed for: {src!r} with {which=}')
# various invalid input
for which in (1, 2, 3):
pb = partial(test_invalid, which=which)
pb(b'abcd\xf51234', 'abcd\ufffd1234') # bytes > 0xf4
pb(b'abcd\xff1234', 'abcd\ufffd1234') # bytes > 0xf4
pb(b'"\xbf"', '"\ufffd"')
pb(b'"\x80"', '"\ufffd"')
pb(b'"\x80\xbf"', '"\ufffd\ufffd"')
pb(b'"\x80\xbf\x80"', '"\ufffd\ufffd\ufffd"')
pb(b'"\xc0 "', '"\ufffd "')
pb(b'"\xfe"', '"\ufffd"')
pb(b'"\xff"', '"\ufffd"')
pb(b'"\xff\xfe"', '"\ufffd\ufffd"')
pb(b'"\xfe\xfe\xff\xff"', '"\ufffd\ufffd\ufffd\ufffd"')
pb(b'"\xef\xbf"', '"\ufffd"')
pb(b'"\xe0\xa0"', '"\ufffd"')
pb(b'"\xf0\x9f\x98"', '"\ufffd"')
pb(b'"\xef\x93\x94\x95"', '"\uf4d4\ufffd"')
def test_esc_codes(self): def test_esc_codes(self):
s = self.create_screen() s = self.create_screen()