fix zero-ing of last n bytes

This commit is contained in:
Kovid Goyal 2024-01-10 22:50:20 +05:30
parent daa169b8ed
commit b28fbf6817
No known key found for this signature in database
GPG Key ID: 06BC317B515ACE7C
2 changed files with 26 additions and 21 deletions

View File

@ -57,6 +57,7 @@ _Pragma("clang diagnostic pop")
#define shift_right_by_bits32 simde_mm_srli_epi32
#define shuffle_epi8 simde_mm_shuffle_epi8
#define numbered_bytes() set_epi8(15,14,13,12,11,10,9,8,7,6,5,4,3,2,1,0)
#define reverse_numbered_bytes() simde_mm_setr_epi8(15,14,13,12,11,10,9,8,7,6,5,4,3,2,1,0)
// output[i] = MAX(0, a[i] - b[i])
#define subtract_saturate_epu8 simde_mm_subs_epu8
#define create_zero_integer simde_mm_setzero_si128
@ -93,6 +94,7 @@ _Pragma("clang diagnostic pop")
#define shift_left_by_eight_bytes(vec) simde_mm256_alignr_epi8(vec, simde_mm256_permute2x128_si256(vec, vec, _MM_SHUFFLE(2, 0, 0, 1)), 8)
#define shift_left_by_sixteen_bytes(vec) simde_mm256_permute2x128_si256(vec, vec, _MM_SHUFFLE(2, 0, 0, 1))
#define numbered_bytes() set_epi8(31,30,29,28,27,26,25,24,23,22,21,20,19,18,17,16,15,14,13,12,11,10,9,8,7,6,5,4,3,2,1,0)
#define reverse_numbered_bytes() simde_mm256_setr_epi8(31,30,29,28,27,26,25,24,23,22,21,20,19,18,17,16,15,14,13,12,11,10,9,8,7,6,5,4,3,2,1,0)
static inline integer_t shuffle_impl256(const integer_t value, const integer_t shuffle) {
#define K0 simde_mm256_setr_epi8( \
@ -114,13 +116,31 @@ static inline integer_t shuffle_impl256(const integer_t value, const integer_t s
#define shuffle_epi8 shuffle_impl256
// Adds sum_bytes_128() of the two 128-bit halves (extract index 0 and index 1)
// of the 256-bit vector x. The argument itself is used, rather than whatever
// variable named `vec` happens to be in scope at the call site.
#define sum_bytes(x) (sum_bytes_128(simde_mm256_extracti128_si256((x), 0)) + sum_bytes_128(simde_mm256_extracti128_si256((x), 1)))
#endif
// Debugging helpers. The `#if 1` below keeps them enabled; the #else branch
// turns both macros into no-ops.
//
// print_register_as_bytes(r): prints the expression text of r followed by a
// colon, then stores r into a 64-byte-aligned local buffer via store_aligned()
// and prints sizeof(integer_t) bytes from it. Bytes in the range ' ' .. 0x7e
// are printed as "_<char> ", all other bytes as two hex digits, followed by a
// newline. The buffer is sized with sizeof(r) while the loop bound is
// sizeof(integer_t); these match when r is an integer_t.
// The macro expands to a brace-enclosed block that declares the locals
// `data`, `i` and `ch`.
#if 1
#define print_register_as_bytes(r) { \
printf("%s:\n", #r); \
alignas(64) uint8_t data[sizeof(r)]; \
store_aligned((integer_t*)data, r); \
for (unsigned i = 0; i < sizeof(integer_t); i++) { \
uint8_t ch = data[i]; \
if (' ' <= ch && ch < 0x7f) printf("_%c ", ch); else printf("%.2x ", ch); \
} \
printf("\n"); \
}
// debug(...) is a plain alias for printf when debugging is enabled.
#define debug printf
#else
// Debugging disabled: both helpers expand to nothing.
#define print_register_as_bytes(r)
#define debug(...)
#endif
// }}}
// Returns vec combined with a per-byte mask through andnot_si(). The mask is
// cmpgt_epi8(threshold, index), where threshold is n broadcast by set1_epi8()
// and index is a vector of per-byte position numbers.
// NOTE(review): this span comes from a diff rendered without +/- markers, so
// it carries two signature lines and two `index` definitions. Presumably the
// `int n` / numbered_bytes() / single-expression return lines are the removed
// version and the `char n` / reverse_numbered_bytes() / `mask` lines are the
// replacement -- confirm against the actual file.
static inline integer_t
FUNC(zero_last_n_bytes)(integer_t vec, int n) {
FUNC(zero_last_n_bytes)(integer_t vec, char n) {
const integer_t threshold = set1_epi8(n);
const integer_t index = numbered_bytes();
return andnot_si(cmpgt_epi8(threshold, index), vec);
// reverse_numbered_bytes() is built with a setr_epi8 call listing the values
// from width-1 down to 0. Presumably byte i then holds (width - 1 - i), so
// "threshold > index" selects the last n bytes -- verify against the setr
// argument-order documentation.
const integer_t index = reverse_numbered_bytes();
const integer_t mask = cmpgt_epi8(threshold, index);
// Presumably andnot_si(mask, vec) computes (~mask) & vec, clearing the
// selected bytes; the macro is defined outside this view.
return andnot_si(mask, vec);
}
static inline const uint8_t*
@ -141,23 +161,6 @@ FUNC(find_either_of_two_bytes)(const uint8_t *haystack, const size_t sz, const u
return NULL;
}
// Debugging helpers; `#if 1` keeps them enabled, the #else branch makes both
// macros expand to nothing.
// print_register_as_bytes(r) prints the expression text of r, stores r into a
// 64-byte-aligned local buffer via store_aligned(), then prints
// sizeof(integer_t) bytes: bytes in ' ' .. 0x7e as "_<char> ", all others as
// two hex digits, ending with a newline.
#if 1
#define print_register_as_bytes(r) { \
printf("%s:\n", #r); \
alignas(64) uint8_t data[sizeof(r)]; \
store_aligned((integer_t*)data, r); \
for (unsigned i = 0; i < sizeof(integer_t); i++) { \
uint8_t ch = data[i]; \
if (' ' <= ch && ch < 0x7f) printf("_%c ", ch); else printf("%.2x ", ch); \
} \
printf("\n"); \
}
// debug(...) forwards directly to printf when enabled.
#define debug printf
#else
#define print_register_as_bytes(r)
#define debug(...)
#endif
static inline void
FUNC(output_plain_ascii)(UTF8Decoder *d, integer_t vec, size_t src_sz) {
#if BITS == 128
@ -293,7 +296,7 @@ FUNC(utf8_decode_to_esc)(UTF8Decoder *d, const uint8_t *src, size_t src_sz) {
src_sz = num_of_bytes_to_first_esc;
d->num_consumed = src_sz + 1; // esc is also consumed
} else d->num_consumed = src_sz;
if (src_sz < sizeof(integer_t)/8) zero_last_n_bytes(vec, sizeof(integer_t)/8 - src_sz);
if (src_sz < sizeof(integer_t)) vec = zero_last_n_bytes(vec, sizeof(integer_t) - src_sz);
const integer_t one = set1_epi8(1), two = set1_epi8(2), three = set1_epi8(3);
// Classify the bytes
@ -452,6 +455,7 @@ FUNC(utf8_decode_to_esc)(UTF8Decoder *d, const uint8_t *src, size_t src_sz) {
#undef create_zero_integer
#undef shuffle_epi8
#undef numbered_bytes
#undef reverse_numbered_bytes
#undef zero_last_n_bytes
#undef sum_bytes
#ifndef SIMD_STRING_IMPL_INCLUDED_ONCE

View File

@ -189,6 +189,7 @@ def t(x, which=2, reset=True):
self.ae(expected, actual, msg=f'Failed for {x!r} with {which=}\n{expected!r} !=\n{actual!r}')
for which in (2, 3):
x = partial(t, which=which)
x('2:α3')
x('2:α3:≤4:😸|')
x('abcd1234efgh5678')
x('abc\x1bd1234efgh5678')