Handle trailing incomplete sequences

This commit is contained in:
Kovid Goyal 2024-01-12 14:47:30 +05:30
parent 4238fedee7
commit 4c8b8caead
No known key found for this signature in database
GPG Key ID: 06BC317B515ACE7C
3 changed files with 62 additions and 28 deletions

View File

@ -9,6 +9,7 @@
#endif #endif
#include "simd-string.h" #include "simd-string.h"
#include "charsets.h"
// Boilerplate {{{ // Boilerplate {{{
#ifdef __clang__ #ifdef __clang__
@ -231,7 +232,7 @@ FUNC(output_plain_ascii)(UTF8Decoder *d, integer_t vec, size_t src_sz) {
} }
} }
#endif #endif
d->output_sz = src_sz; d->output_sz += src_sz;
} }
static inline void static inline void
@ -292,38 +293,55 @@ sum_bytes_128(simde__m128i v) {
return lower_sum + upper_sum; // Final sum of all bytes return lower_sum + upper_sum; // Final sum of all bytes
} }
static void #define do_one_byte \
scalar_decode_to_accept(UTF8Decoder *d, const uint8_t *src, size_t src_sz) { const uint8_t ch = src[pos++]; \
while (d->num_consumed < src_sz && d->output_sz < arraysz(d->output) && d->state.cur != UTF8_ACCEPT) { switch (decode_utf8(&d->state.cur, &d->state.codep, ch)) { \
const uint8_t ch = src[d->num_consumed++]; case UTF8_ACCEPT: \
switch(ch) { d->output[d->output_sz++] = d->state.codep; \
case UTF8_ACCEPT: break; \
d->output[d->output_sz++] = d->state.codep; case UTF8_REJECT: { \
break; const bool prev_was_accept = d->state.prev == UTF8_ACCEPT; \
case UTF8_REJECT: { zero_at_ptr(&d->state); \
const bool prev_was_accept = d->state.prev == UTF8_ACCEPT; d->output[d->output_sz++] = 0xfffd; \
zero_at_ptr(&d->state); if (!prev_was_accept) { \
d->output[d->output_sz++] = 0xfffd; pos--; \
if (!prev_was_accept && d->num_consumed) { continue; /* so that prev is correct */ \
d->num_consumed--; } \
continue; // so that prev is correct } break; \
} } \
} break;
}
d->state.prev = d->state.cur; d->state.prev = d->state.cur;
static inline size_t
scalar_decode_to_accept(UTF8Decoder *d, const uint8_t *src, size_t src_sz) {
size_t pos = 0;
while (pos < src_sz && d->output_sz < arraysz(d->output) && d->state.cur != UTF8_ACCEPT) {
do_one_byte
} }
return pos;
} }
static inline size_t
scalar_decode_all(UTF8Decoder *d, const uint8_t *src, size_t src_sz) {
size_t pos = 0;
while (pos < src_sz && d->output_sz < arraysz(d->output)) {
do_one_byte
}
return pos;
}
#undef do_one_byte
#endif #endif
static inline bool static inline bool
FUNC(utf8_decode_to_esc)(UTF8Decoder *d, const uint8_t *src, size_t src_sz) { FUNC(utf8_decode_to_esc)(UTF8Decoder *d, const uint8_t *src, size_t src_sz) {
// Based on the algorithm described in: https://woboq.com/blog/utf-8-processing-using-simd.html // Based on the algorithm described in: https://woboq.com/blog/utf-8-processing-using-simd.html
d->output_sz = 0; d->num_consumed = 0; d->output_sz = 0; d->num_consumed = 0;
if (d->state.cur != UTF8_ACCEPT) { if (d->state.cur != UTF8_ACCEPT) {
scalar_decode_to_accept(d, src, src_sz); // Finish the trailing sequence only, we will be called again to process the rest allows use of aligned stores since output
// is not pre-filled.
d->num_consumed = scalar_decode_to_accept(d, src, src_sz);
src += d->num_consumed; src_sz -= d->num_consumed; src += d->num_consumed; src_sz -= d->num_consumed;
if (!src_sz) return false; return false;
} }
src_sz = MIN(src_sz, sizeof(integer_t)); src_sz = MIN(src_sz, sizeof(integer_t));
integer_t vec = load_unaligned((integer_t*)src); integer_t vec = load_unaligned((integer_t*)src);
@ -336,14 +354,25 @@ FUNC(utf8_decode_to_esc)(UTF8Decoder *d, const uint8_t *src, size_t src_sz) {
if (esc_test_mask && (num_of_bytes_to_first_esc = __builtin_ctz(esc_test_mask)) < src_sz) { if (esc_test_mask && (num_of_bytes_to_first_esc = __builtin_ctz(esc_test_mask)) < src_sz) {
sentinel_found = true; sentinel_found = true;
src_sz = num_of_bytes_to_first_esc; src_sz = num_of_bytes_to_first_esc;
d->num_consumed = src_sz + 1; // esc is also consumed d->num_consumed += src_sz + 1; // esc is also consumed
} else d->num_consumed = src_sz; } else d->num_consumed += src_sz;
// use scalar decode for short input
if (src_sz < 4) { scalar_decode_all(d, src, src_sz); return sentinel_found; }
// check for an incomplete trailing utf8 sequence
unsigned num_of_trailing_bytes = 0;
if (src[src_sz-1] >= 0xc0) num_of_trailing_bytes = 1; // 2-, 3- and 4-byte characters with only 1 byte left
else if (src[src_sz-2] >= 0xe0) num_of_trailing_bytes = 2; // 3- and 4-byte characters with only 1 byte left
else if (src[src_sz-3] >= 0xf0) num_of_trailing_bytes = 3; // 4-byte characters with only 3 bytes left
src_sz -= num_of_trailing_bytes;
if (src_sz < sizeof(integer_t)) vec = zero_last_n_bytes(vec, sizeof(integer_t) - src_sz); if (src_sz < sizeof(integer_t)) vec = zero_last_n_bytes(vec, sizeof(integer_t) - src_sz);
// Check if we have pure ASCII and use fast path // Check if we have pure ASCII and use fast path
print_register_as_bytes(vec); print_register_as_bytes(vec);
if (!movemask_epi8(vec)) { // no bytes with high bit (0x80) set, so just plain ASCII if (!movemask_epi8(vec)) { // no bytes with high bit (0x80) set, so just plain ASCII
FUNC(output_plain_ascii)(d, vec, src_sz); FUNC(output_plain_ascii)(d, vec, src_sz);
if (num_of_trailing_bytes) scalar_decode_all(d, src + src_sz, num_of_trailing_bytes);
return sentinel_found; return sentinel_found;
} }
// Classify the bytes // Classify the bytes
@ -453,6 +482,7 @@ FUNC(utf8_decode_to_esc)(UTF8Decoder *d, const uint8_t *src, size_t src_sz) {
const unsigned num_codepoints = src_sz - num_of_discarded_bytes; const unsigned num_codepoints = src_sz - num_of_discarded_bytes;
debug("num_of_discarded_bytes: %u num_codepoints: %u\n", num_of_discarded_bytes, num_codepoints); debug("num_of_discarded_bytes: %u num_codepoints: %u\n", num_of_discarded_bytes, num_codepoints);
FUNC(output_unicode)(d, output1, output2, output3, num_codepoints); FUNC(output_unicode)(d, output1, output2, output3, num_codepoints);
if (num_of_trailing_bytes) scalar_decode_all(d, src + src_sz, num_of_trailing_bytes);
return sentinel_found; return sentinel_found;
} }

View File

@ -205,7 +205,7 @@ class TestParser(BaseTest):
actual = parse_parts(1) actual = parse_parts(1)
reset_state() reset_state()
expected = parse_parts(which) expected = parse_parts(which)
self.ae(expected, actual, msg=f'Failed for {x!r} with {which=}\n{expected!r} !=\n{actual!r}') self.ae(expected, actual, msg=f'Failed for {a} with {which=}\n{expected!r} !=\n{actual!r}')
def double_test(x): def double_test(x):
for which in (2, 3): for which in (2, 3):
@ -222,8 +222,12 @@ class TestParser(BaseTest):
x('abcd1234efgh5678ijklABCDmnopEFGH') x('abcd1234efgh5678ijklABCDmnopEFGH')
for which in (2, 3): for which in (2, 3):
t('abcdef', 'ghij') x = partial(t, which=which)
t('2:α3', ':≤4:😸|') x('abcdef', 'ghijk')
x('2:α3', ':≤4:😸|')
# trailing incomplete sequence
x(b'abcd\xf0\x9f', b'\x98\xb81234')
def test_esc_codes(self): def test_esc_codes(self):
s = self.create_screen() s = self.create_screen()