Handle trailing incomplete sequences

This commit is contained in:
Kovid Goyal 2024-01-12 14:47:30 +05:30
parent 4238fedee7
commit 4c8b8caead
No known key found for this signature in database
GPG Key ID: 06BC317B515ACE7C
3 changed files with 62 additions and 28 deletions

View File

@ -9,6 +9,7 @@
#endif #endif
#include "simd-string.h" #include "simd-string.h"
#include "charsets.h"
// Boilerplate {{{ // Boilerplate {{{
#ifdef __clang__ #ifdef __clang__
@ -231,7 +232,7 @@ FUNC(output_plain_ascii)(UTF8Decoder *d, integer_t vec, size_t src_sz) {
} }
} }
#endif #endif
d->output_sz = src_sz; d->output_sz += src_sz;
} }
static inline void static inline void
@ -292,38 +293,55 @@ sum_bytes_128(simde__m128i v) {
return lower_sum + upper_sum; // Final sum of all bytes return lower_sum + upper_sum; // Final sum of all bytes
} }
static void #define do_one_byte \
const uint8_t ch = src[pos++]; \
switch (decode_utf8(&d->state.cur, &d->state.codep, ch)) { \
case UTF8_ACCEPT: \
d->output[d->output_sz++] = d->state.codep; \
break; \
case UTF8_REJECT: { \
const bool prev_was_accept = d->state.prev == UTF8_ACCEPT; \
zero_at_ptr(&d->state); \
d->output[d->output_sz++] = 0xfffd; \
if (!prev_was_accept) { \
pos--; \
continue; /* so that prev is correct */ \
} \
} break; \
} \
d->state.prev = d->state.cur;
static inline size_t
scalar_decode_to_accept(UTF8Decoder *d, const uint8_t *src, size_t src_sz) { scalar_decode_to_accept(UTF8Decoder *d, const uint8_t *src, size_t src_sz) {
while (d->num_consumed < src_sz && d->output_sz < arraysz(d->output) && d->state.cur != UTF8_ACCEPT) { size_t pos = 0;
const uint8_t ch = src[d->num_consumed++]; while (pos < src_sz && d->output_sz < arraysz(d->output) && d->state.cur != UTF8_ACCEPT) {
switch(ch) { do_one_byte
case UTF8_ACCEPT:
d->output[d->output_sz++] = d->state.codep;
break;
case UTF8_REJECT: {
const bool prev_was_accept = d->state.prev == UTF8_ACCEPT;
zero_at_ptr(&d->state);
d->output[d->output_sz++] = 0xfffd;
if (!prev_was_accept && d->num_consumed) {
d->num_consumed--;
continue; // so that prev is correct
}
} break;
}
d->state.prev = d->state.cur;
} }
return pos;
} }
static inline size_t
scalar_decode_all(UTF8Decoder *d, const uint8_t *src, size_t src_sz) {
size_t pos = 0;
while (pos < src_sz && d->output_sz < arraysz(d->output)) {
do_one_byte
}
return pos;
}
#undef do_one_byte
#endif #endif
static inline bool static inline bool
FUNC(utf8_decode_to_esc)(UTF8Decoder *d, const uint8_t *src, size_t src_sz) { FUNC(utf8_decode_to_esc)(UTF8Decoder *d, const uint8_t *src, size_t src_sz) {
// Based on the algorithm described in: https://woboq.com/blog/utf-8-processing-using-simd.html // Based on the algorithm described in: https://woboq.com/blog/utf-8-processing-using-simd.html
d->output_sz = 0; d->num_consumed = 0; d->output_sz = 0; d->num_consumed = 0;
if (d->state.cur != UTF8_ACCEPT) { if (d->state.cur != UTF8_ACCEPT) {
scalar_decode_to_accept(d, src, src_sz); // Finish the trailing sequence only, we will be called again to process the rest allows use of aligned stores since output
// is not pre-filled.
d->num_consumed = scalar_decode_to_accept(d, src, src_sz);
src += d->num_consumed; src_sz -= d->num_consumed; src += d->num_consumed; src_sz -= d->num_consumed;
if (!src_sz) return false; return false;
} }
src_sz = MIN(src_sz, sizeof(integer_t)); src_sz = MIN(src_sz, sizeof(integer_t));
integer_t vec = load_unaligned((integer_t*)src); integer_t vec = load_unaligned((integer_t*)src);
@ -336,14 +354,25 @@ FUNC(utf8_decode_to_esc)(UTF8Decoder *d, const uint8_t *src, size_t src_sz) {
if (esc_test_mask && (num_of_bytes_to_first_esc = __builtin_ctz(esc_test_mask)) < src_sz) { if (esc_test_mask && (num_of_bytes_to_first_esc = __builtin_ctz(esc_test_mask)) < src_sz) {
sentinel_found = true; sentinel_found = true;
src_sz = num_of_bytes_to_first_esc; src_sz = num_of_bytes_to_first_esc;
d->num_consumed = src_sz + 1; // esc is also consumed d->num_consumed += src_sz + 1; // esc is also consumed
} else d->num_consumed = src_sz; } else d->num_consumed += src_sz;
// use scalar decode for short input
if (src_sz < 4) { scalar_decode_all(d, src, src_sz); return sentinel_found; }
// check for an incomplete trailing utf8 sequence
unsigned num_of_trailing_bytes = 0;
if (src[src_sz-1] >= 0xc0) num_of_trailing_bytes = 1; // 2-, 3- and 4-byte characters with only 1 byte left
else if (src[src_sz-2] >= 0xe0) num_of_trailing_bytes = 2; // 3- and 4-byte characters with only 1 byte left
else if (src[src_sz-3] >= 0xf0) num_of_trailing_bytes = 3; // 4-byte characters with only 3 bytes left
src_sz -= num_of_trailing_bytes;
if (src_sz < sizeof(integer_t)) vec = zero_last_n_bytes(vec, sizeof(integer_t) - src_sz); if (src_sz < sizeof(integer_t)) vec = zero_last_n_bytes(vec, sizeof(integer_t) - src_sz);
// Check if we have pure ASCII and use fast path // Check if we have pure ASCII and use fast path
print_register_as_bytes(vec); print_register_as_bytes(vec);
if (!movemask_epi8(vec)) { // no bytes with high bit (0x80) set, so just plain ASCII if (!movemask_epi8(vec)) { // no bytes with high bit (0x80) set, so just plain ASCII
FUNC(output_plain_ascii)(d, vec, src_sz); FUNC(output_plain_ascii)(d, vec, src_sz);
if (num_of_trailing_bytes) scalar_decode_all(d, src + src_sz, num_of_trailing_bytes);
return sentinel_found; return sentinel_found;
} }
// Classify the bytes // Classify the bytes
@ -453,6 +482,7 @@ FUNC(utf8_decode_to_esc)(UTF8Decoder *d, const uint8_t *src, size_t src_sz) {
const unsigned num_codepoints = src_sz - num_of_discarded_bytes; const unsigned num_codepoints = src_sz - num_of_discarded_bytes;
debug("num_of_discarded_bytes: %u num_codepoints: %u\n", num_of_discarded_bytes, num_codepoints); debug("num_of_discarded_bytes: %u num_codepoints: %u\n", num_of_discarded_bytes, num_codepoints);
FUNC(output_unicode)(d, output1, output2, output3, num_codepoints); FUNC(output_unicode)(d, output1, output2, output3, num_codepoints);
if (num_of_trailing_bytes) scalar_decode_all(d, src + src_sz, num_of_trailing_bytes);
return sentinel_found; return sentinel_found;
} }

View File

@ -83,7 +83,7 @@ test_utf8_decode_to_sentinel(PyObject *self UNUSED, PyObject *args) {
if (!PyArg_ParseTuple(args, "s#|i", &src, &src_sz, &which_function)) return NULL; if (!PyArg_ParseTuple(args, "s#|i", &src, &src_sz, &which_function)) return NULL;
bool found_sentinel = false; bool found_sentinel = false;
bool(*func)(UTF8Decoder*, const uint8_t*, size_t sz) = utf8_decode_to_esc; bool(*func)(UTF8Decoder*, const uint8_t*, size_t sz) = utf8_decode_to_esc;
switch(which_function) { switch (which_function) {
case -1: case -1:
zero_at_ptr(&d); Py_RETURN_NONE; zero_at_ptr(&d); Py_RETURN_NONE;
case 1: case 1:
@ -95,7 +95,7 @@ test_utf8_decode_to_sentinel(PyObject *self UNUSED, PyObject *args) {
} }
RAII_PyObject(ans, PyUnicode_FromString("")); RAII_PyObject(ans, PyUnicode_FromString(""));
ssize_t p = 0; ssize_t p = 0;
while(p < src_sz && !found_sentinel) { while (p < src_sz && !found_sentinel) {
found_sentinel = func(&d, src + p, src_sz - p); found_sentinel = func(&d, src + p, src_sz - p);
p += d.num_consumed; p += d.num_consumed;
if (d.output_sz) { if (d.output_sz) {

View File

@ -205,7 +205,7 @@ class TestParser(BaseTest):
actual = parse_parts(1) actual = parse_parts(1)
reset_state() reset_state()
expected = parse_parts(which) expected = parse_parts(which)
self.ae(expected, actual, msg=f'Failed for {x!r} with {which=}\n{expected!r} !=\n{actual!r}') self.ae(expected, actual, msg=f'Failed for {a} with {which=}\n{expected!r} !=\n{actual!r}')
def double_test(x): def double_test(x):
for which in (2, 3): for which in (2, 3):
@ -222,8 +222,12 @@ class TestParser(BaseTest):
x('abcd1234efgh5678ijklABCDmnopEFGH') x('abcd1234efgh5678ijklABCDmnopEFGH')
for which in (2, 3): for which in (2, 3):
t('abcdef', 'ghij') x = partial(t, which=which)
t('2:α3', ':≤4:😸|') x('abcdef', 'ghijk')
x('2:α3', ':≤4:😸|')
# trailing incomplete sequence
x(b'abcd\xf0\x9f', b'\x98\xb81234')
def test_esc_codes(self): def test_esc_codes(self):
s = self.create_screen() s = self.create_screen()