More work on utf8 SIMD decode

This commit is contained in:
Kovid Goyal 2024-01-10 22:19:04 +05:30
parent a5251bedc9
commit daa169b8ed
No known key found for this signature in database
GPG Key ID: 06BC317B515ACE7C
2 changed files with 42 additions and 20 deletions

View File

@ -60,6 +60,7 @@ _Pragma("clang diagnostic pop")
// output[i] = MAX(0, a[i] - b[i])
#define subtract_saturate_epu8 simde_mm_subs_epu8
#define create_zero_integer simde_mm_setzero_si128
#define sum_bytes sum_bytes_128
#else
@ -109,7 +110,9 @@ static inline integer_t shuffle_impl256(const integer_t value, const integer_t s
#undef K0
#undef K1
}
#define shuffle_epi8 shuffle_impl256
// Sum all the bytes of a 256-bit register by summing each of its two 128-bit
// lanes with sum_bytes_128() and adding the results. The argument is expanded
// twice, so it must be free of side effects.
#define sum_bytes(x) (sum_bytes_128(simde_mm256_extracti128_si256((x), 0)) + sum_bytes_128(simde_mm256_extracti128_si256((x), 1)))
#endif
// }}}
@ -149,8 +152,10 @@ FUNC(find_either_of_two_bytes)(const uint8_t *haystack, const size_t sz, const u
} \
printf("\n"); \
}
#define debug printf
#else
#define print_register_as_bytes(r)
#define debug(...)
#endif
static inline void
@ -190,9 +195,10 @@ FUNC(output_unicode)(UTF8Decoder *d, integer_t output1, integer_t output2, integ
#if BITS == 128
for (const uint32_t *limit = d->output + num_codepoints, *p = d->output; p < limit; p += sizeof(integer_t)/sizeof(uint32_t)) {
const integer_t unpacked1 = extract_lower_quarter_as_chars(output1);
const integer_t unpacked2 = shift_left_by_one_byte(extract_lower_quarter_as_chars(output2));
const integer_t unpacked3 = shift_left_by_two_bytes(extract_lower_quarter_as_chars(output3));
store_aligned((integer_t*)p, or_si(or_si(unpacked1, unpacked2), unpacked3));
const integer_t unpacked2 = shift_right_by_one_byte(extract_lower_quarter_as_chars(output2));
const integer_t unpacked3 = shift_right_by_two_bytes(extract_lower_quarter_as_chars(output3));
const integer_t unpacked = or_si(or_si(unpacked1, unpacked2), unpacked3);
store_aligned((integer_t*)p, unpacked);
output1 = shift_right_by_bytes128(output1, sizeof(integer_t)/sizeof(d->output[0]));
output2 = shift_right_by_bytes128(output2, sizeof(integer_t)/sizeof(d->output[0]));
output3 = shift_right_by_bytes128(output3, sizeof(integer_t)/sizeof(d->output[0]));
@ -203,8 +209,8 @@ FUNC(output_unicode)(UTF8Decoder *d, integer_t output1, integer_t output2, integ
simde__m128i x1, x2, x3;
#define chunk() { \
const integer_t unpacked1 = extract_lower_half_as_chars(x1); \
const integer_t unpacked2 = shift_left_by_one_byte(extract_lower_half_as_chars(x2)); \
const integer_t unpacked3 = shift_left_by_two_bytes(extract_lower_half_as_chars(x3)); \
const integer_t unpacked2 = shift_right_by_one_byte(extract_lower_half_as_chars(x2)); \
const integer_t unpacked3 = shift_right_by_two_bytes(extract_lower_half_as_chars(x3)); \
store_aligned((integer_t*)p, or_si(or_si(unpacked1, unpacked2), unpacked3)); \
p += sizeof(integer_t)/sizeof(uint32_t); \
}
@ -228,6 +234,20 @@ FUNC(output_unicode)(UTF8Decoder *d, integer_t output1, integer_t output2, integ
}
#ifndef SIMD_STRING_IMPL_INCLUDED_ONCE
static inline unsigned
sum_bytes_128(simde__m128i v) {
    // A sum of absolute differences against an all-zero register adds up the
    // unsigned bytes of each 64-bit lane independently: the total of bytes 0-7
    // lands in the low 64-bit lane and the total of bytes 8-15 in the high one.
    const simde__m128i lane_totals = simde_mm_sad_epu8(v, simde_mm_setzero_si128());
    // Move the high 64-bit lane down into the low position so that it can be
    // read out with the same 32-bit extraction used for the low lane.
    const simde__m128i high_lane = simde_mm_srli_si128(lane_totals, 8);
    // Each lane total is at most 8 * 255, so the low 32 bits of a lane hold
    // the whole value.
    const unsigned low_total = simde_mm_cvtsi128_si32(lane_totals);
    const unsigned high_total = simde_mm_cvtsi128_si32(high_lane);
    return low_total + high_total; // sum of all 16 bytes
}
static void
scalar_decode_to_accept(UTF8Decoder *d, const uint8_t *src, size_t src_sz) {
while (d->num_consumed < src_sz && d->output_sz < arraysz(d->output) && d->state.cur != UTF8_ACCEPT) {
@ -301,15 +321,14 @@ FUNC(utf8_decode_to_esc)(UTF8Decoder *d, const uint8_t *src, size_t src_sz) {
print_register_as_bytes(mask);
integer_t count = and_si(state, set1_epi8(0x7)); // keep lower 3 bits of state
print_register_as_bytes(count);
// count contains 0 for ASCII and number of bytes in sequence for other bytes
#define subtract_shift_and_add(target, amt, s) add_epi8(target, shift_right_by_##amt(subtract_saturate_epu8(target, s)))
// count contains the number of bytes in the sequence for the start byte of every sequence and zero elsewhere
// shift 02 bytes by 1 and subtract 1
integer_t counts = subtract_shift_and_add(count, one_byte, one);
integer_t count_subs1 = subtract_saturate_epu8(count, one);
integer_t counts = add_epi8(count, shift_right_by_one_byte(count_subs1));
// shift 03 and 04 bytes by 2 and subtract 2
counts = subtract_shift_and_add(counts, two_bytes, two);
counts = add_epi8(counts, shift_right_by_two_bytes(subtract_saturate_epu8(counts, two)));
// counts now contains the number of bytes remaining in each utf-8 sequence of 2 or more bytes
print_register_as_bytes(counts);
#undef subtract_shift_and_add
// Process the bytes storing the three resulting bytes that make up the unicode codepoint
// mask all control bits so that we have only useful bits left
@ -350,7 +369,7 @@ FUNC(utf8_decode_to_esc)(UTF8Decoder *d, const uint8_t *src, size_t src_sz) {
print_register_as_bytes(output3);
// Shuffle bytes to remove continuation bytes
integer_t shifts = subtract_saturate_epu8(count, one); // number of bytes we need to skip for each UTF-8 sequence
integer_t shifts = count_subs1; // number of bytes we need to skip for each UTF-8 sequence
// propagate the shifts to all subsequent bytes by shift and add
shifts = add_epi8(shifts, shift_right_by_one_byte(shifts));
shifts = add_epi8(shifts, shift_right_by_two_bytes(shifts));
@ -387,6 +406,11 @@ FUNC(utf8_decode_to_esc)(UTF8Decoder *d, const uint8_t *src, size_t src_sz) {
print_register_as_bytes(output1);
print_register_as_bytes(output2);
print_register_as_bytes(output3);
const unsigned num_of_discarded_bytes = sum_bytes(count_subs1);
const unsigned num_codepoints = src_sz - num_of_discarded_bytes;
debug("num_of_discarded_bytes: %u num_codepoints: %u\n", num_of_discarded_bytes, num_codepoints);
FUNC(output_unicode)(d, output1, output2, output3, num_codepoints);
return sentinel_found;
}
@ -429,6 +453,7 @@ FUNC(utf8_decode_to_esc)(UTF8Decoder *d, const uint8_t *src, size_t src_sz) {
#undef shuffle_epi8
#undef numbered_bytes
#undef zero_last_n_bytes
#undef sum_bytes
#ifndef SIMD_STRING_IMPL_INCLUDED_ONCE
#define SIMD_STRING_IMPL_INCLUDED_ONCE
#endif

View File

@ -181,21 +181,18 @@ def test_utf8_parsing(self):
pb(b'"\xf0\x9f\x98"', '"\ufffd"')
def test_utf8_simd_decode(self):
test_utf8_decode_to_sentinel('2:α3:≤4:😸|', 2)
return
def t(x, which=2, reset=True):
if reset:
test_utf8_decode_to_sentinel(b'', -1)
expected = test_utf8_decode_to_sentinel(x, 1)
actual = test_utf8_decode_to_sentinel(x, which)
self.ae(expected, actual)
self.ae(expected, actual, msg=f'Failed for {x!r} with {which=}\n{expected!r} !=\n{actual!r}')
for which in (2, 3):
with self.subTest(which=which):
x = partial(t, which=which)
x('abcd1234efgh5678')
x('abc\x1bd1234efgh5678')
x('abcd1234efgh5678ijklABCDmnopEFGH')
x = partial(t, which=which)
x('2:α3:≤4:😸|')
x('abcd1234efgh5678')
x('abc\x1bd1234efgh5678')
x('abcd1234efgh5678ijklABCDmnopEFGH')
def test_esc_codes(self):
s = self.create_screen()