diff --git a/kitty/simd-string-impl.h b/kitty/simd-string-impl.h index b6ba3976d..9b2555c97 100644 --- a/kitty/simd-string-impl.h +++ b/kitty/simd-string-impl.h @@ -9,6 +9,7 @@ #endif #include "simd-string.h" +#include "charsets.h" // Boilerplate {{{ #ifdef __clang__ @@ -231,7 +232,7 @@ FUNC(output_plain_ascii)(UTF8Decoder *d, integer_t vec, size_t src_sz) { } } #endif - d->output_sz = src_sz; + d->output_sz += src_sz; } static inline void @@ -292,38 +293,55 @@ sum_bytes_128(simde__m128i v) { return lower_sum + upper_sum; // Final sum of all bytes } -static void +#define do_one_byte \ + const uint8_t ch = src[pos++]; \ + switch (decode_utf8(&d->state.cur, &d->state.codep, ch)) { \ + case UTF8_ACCEPT: \ + d->output[d->output_sz++] = d->state.codep; \ + break; \ + case UTF8_REJECT: { \ + const bool prev_was_accept = d->state.prev == UTF8_ACCEPT; \ + zero_at_ptr(&d->state); \ + d->output[d->output_sz++] = 0xfffd; \ + if (!prev_was_accept) { \ + pos--; \ + continue; /* so that prev is correct */ \ + } \ + } break; \ + } \ + d->state.prev = d->state.cur; + +static inline size_t scalar_decode_to_accept(UTF8Decoder *d, const uint8_t *src, size_t src_sz) { - while (d->num_consumed < src_sz && d->output_sz < arraysz(d->output) && d->state.cur != UTF8_ACCEPT) { - const uint8_t ch = src[d->num_consumed++]; - switch(ch) { - case UTF8_ACCEPT: - d->output[d->output_sz++] = d->state.codep; - break; - case UTF8_REJECT: { - const bool prev_was_accept = d->state.prev == UTF8_ACCEPT; - zero_at_ptr(&d->state); - d->output[d->output_sz++] = 0xfffd; - if (!prev_was_accept && d->num_consumed) { - d->num_consumed--; - continue; // so that prev is correct - } - } break; - } - d->state.prev = d->state.cur; + size_t pos = 0; + while (pos < src_sz && d->output_sz < arraysz(d->output) && d->state.cur != UTF8_ACCEPT) { + do_one_byte } + return pos; } +static inline size_t +scalar_decode_all(UTF8Decoder *d, const uint8_t *src, size_t src_sz) { + size_t pos = 0; + while (pos < src_sz && d->output_sz < arraysz(d->output)) { + do_one_byte + } + return pos; +} +#undef do_one_byte #endif static inline bool FUNC(utf8_decode_to_esc)(UTF8Decoder *d, const uint8_t *src, size_t src_sz) { // Based on the algorithm described in: https://woboq.com/blog/utf-8-processing-using-simd.html + d->output_sz = 0; d->num_consumed = 0; if (d->state.cur != UTF8_ACCEPT) { - scalar_decode_to_accept(d, src, src_sz); + // Finish the trailing sequence only, we will be called again to process the rest allows use of aligned stores since output + // is not pre-filled. + d->num_consumed = scalar_decode_to_accept(d, src, src_sz); src += d->num_consumed; src_sz -= d->num_consumed; - if (!src_sz) return false; + return false; } src_sz = MIN(src_sz, sizeof(integer_t)); integer_t vec = load_unaligned((integer_t*)src); @@ -336,14 +354,25 @@ FUNC(utf8_decode_to_esc)(UTF8Decoder *d, const uint8_t *src, size_t src_sz) { if (esc_test_mask && (num_of_bytes_to_first_esc = __builtin_ctz(esc_test_mask)) < src_sz) { sentinel_found = true; src_sz = num_of_bytes_to_first_esc; - d->num_consumed = src_sz + 1; // esc is also consumed - } else d->num_consumed = src_sz; + d->num_consumed += src_sz + 1; // esc is also consumed + } else d->num_consumed += src_sz; + + // use scalar decode for short input + if (src_sz < 4) { scalar_decode_all(d, src, src_sz); return sentinel_found; } + + // check for an incomplete trailing utf8 sequence + unsigned num_of_trailing_bytes = 0; + if (src[src_sz-1] >= 0xc0) num_of_trailing_bytes = 1; // 2-, 3- and 4-byte characters with only 1 byte left + else if (src[src_sz-2] >= 0xe0) num_of_trailing_bytes = 2; // 3- and 4-byte characters with only 1 byte left + else if (src[src_sz-3] >= 0xf0) num_of_trailing_bytes = 3; // 4-byte characters with only 3 bytes left + src_sz -= num_of_trailing_bytes; if (src_sz < sizeof(integer_t)) vec = zero_last_n_bytes(vec, sizeof(integer_t) - src_sz); // Check if we have pure ASCII and use fast path print_register_as_bytes(vec); if (!movemask_epi8(vec)) { // no bytes with high bit (0x80) set, so just plain ASCII FUNC(output_plain_ascii)(d, vec, src_sz); + if (num_of_trailing_bytes) scalar_decode_all(d, src + src_sz, num_of_trailing_bytes); return sentinel_found; } // Classify the bytes @@ -453,6 +482,7 @@ FUNC(utf8_decode_to_esc)(UTF8Decoder *d, const uint8_t *src, size_t src_sz) { const unsigned num_codepoints = src_sz - num_of_discarded_bytes; debug("num_of_discarded_bytes: %u num_codepoints: %u\n", num_of_discarded_bytes, num_codepoints); FUNC(output_unicode)(d, output1, output2, output3, num_codepoints); + if (num_of_trailing_bytes) scalar_decode_all(d, src + src_sz, num_of_trailing_bytes); return sentinel_found; } diff --git a/kitty/simd-string.c b/kitty/simd-string.c index c05a7428d..272c2ca8d 100644 --- a/kitty/simd-string.c +++ b/kitty/simd-string.c @@ -83,7 +83,7 @@ test_utf8_decode_to_sentinel(PyObject *self UNUSED, PyObject *args) { if (!PyArg_ParseTuple(args, "s#|i", &src, &src_sz, &which_function)) return NULL; bool found_sentinel = false; bool(*func)(UTF8Decoder*, const uint8_t*, size_t sz) = utf8_decode_to_esc; - switch(which_function) { + switch (which_function) { case -1: zero_at_ptr(&d); Py_RETURN_NONE; case 1: @@ -95,7 +95,7 @@ test_utf8_decode_to_sentinel(PyObject *self UNUSED, PyObject *args) { } RAII_PyObject(ans, PyUnicode_FromString("")); ssize_t p = 0; - while(p < src_sz && !found_sentinel) { + while (p < src_sz && !found_sentinel) { found_sentinel = func(&d, src + p, src_sz - p); p += d.num_consumed; if (d.output_sz) { diff --git a/kitty_tests/parser.py b/kitty_tests/parser.py index d07a3935a..6a4f93560 100644 --- a/kitty_tests/parser.py +++ b/kitty_tests/parser.py @@ -205,7 +205,7 @@ class TestParser(BaseTest): actual = parse_parts(1) reset_state() expected = parse_parts(which) - self.ae(expected, actual, msg=f'Failed for {x!r} with {which=}\n{expected!r} !=\n{actual!r}') + self.ae(expected, actual, msg=f'Failed for {a} with {which=}\n{expected!r} !=\n{actual!r}') def double_test(x): for which in (2, 3): @@ -222,8 +222,12 @@ class TestParser(BaseTest): x('abcd1234efgh5678ijklABCDmnopEFGH') for which in (2, 3): - t('abcdef', 'ghij') - t('2:α3', ':≤4:😸|') + x = partial(t, which=which) + x('abcdef', 'ghijk') + x('2:α3', ':≤4:😸|') + # trailing incomplete sequence + x(b'abcd\xf0\x9f', b'\x98\xb81234') + def test_esc_codes(self): s = self.create_screen()