Use a fast SIMD implementation to XOR data going into the disk cache

This commit is contained in:
Kovid Goyal 2024-02-13 13:18:34 +05:30
parent 88f3c8c5ee
commit ad3ab877f8
No known key found for this signature in database
GPG Key ID: 06BC317B515ACE7C
5 changed files with 72 additions and 26 deletions

View File

@ -11,6 +11,7 @@
#include "disk-cache.h"
#include "safe-wrappers.h"
#include "kitty-uthash.h"
#include "simd-string.h"
#include "loop-utils.h"
#include "fast-file-copy.h"
#include "threading.h"
@ -57,17 +58,6 @@ typedef struct {
unsigned long long total_size;
} DiskCache;
static void
xor_data(const uint8_t* restrict key, const size_t key_sz, uint8_t* restrict data, const size_t data_sz) {
size_t unaligned_sz = data_sz % key_sz;
size_t aligned_sz = data_sz - unaligned_sz;
for (size_t offset = 0; offset < aligned_sz; offset += key_sz) {
for (size_t i = 0; i < key_sz; i++) data[offset + i] ^= key[i];
}
for (size_t i = 0; i < unaligned_sz; i++) data[aligned_sz + i] ^= key[i];
}
void
free_cache_entry(CacheEntry *e) {
if (e->hash_key) { free(e->hash_key); e->hash_key = NULL; }
@ -285,7 +275,7 @@ find_cache_entry_to_write(DiskCache *self) {
s->data = NULL;
self->currently_writing.data_sz = s->data_sz;
self->currently_writing.pos_in_cache_file = -1;
xor_data(s->encryption_key, sizeof(s->encryption_key), self->currently_writing.data, s->data_sz);
xor_data64(s->encryption_key, self->currently_writing.data, s->data_sz);
self->currently_writing.hash_keylen = MIN(s->hash_keylen, MAX_KEY_SIZE);
memcpy(self->currently_writing.hash_key, s->hash_key, self->currently_writing.hash_keylen);
find_hole_to_use(self, self->currently_writing.data_sz);
@ -612,11 +602,11 @@ read_from_disk_cache(PyObject *self_, const void *key, size_t key_sz, void*(allo
if (s->data) { memcpy(data, s->data, s->data_sz); }
else if (self->currently_writing.data && self->currently_writing.hash_key && self->currently_writing.hash_keylen == key_sz && memcmp(self->currently_writing.hash_key, key, key_sz) == 0) {
memcpy(data, self->currently_writing.data, s->data_sz);
xor_data(s->encryption_key, sizeof(s->encryption_key), data, s->data_sz);
xor_data64(s->encryption_key, data, s->data_sz);
}
else {
read_from_cache_entry(self, s, data);
xor_data(s->encryption_key, sizeof(s->encryption_key), data, s->data_sz);
xor_data64(s->encryption_key, data, s->data_sz);
}
if (store_in_ram && !s->data && s->data_sz) {
void *copy = malloc(s->data_sz);
@ -697,16 +687,15 @@ PYWRAP(ensure_state) {
Py_RETURN_NONE;
}
PYWRAP(xor_data) {
PYWRAP(xor_data64) {
(void) self;
const char *key, *data;
Py_ssize_t keylen, data_sz;
PA("y#y#", &key, &keylen, &data, &data_sz);
PyObject *ans = PyBytes_FromStringAndSize(NULL, data_sz);
PA("s#s#", &key, &keylen, &data, &data_sz);
if (keylen != 64) { PyErr_SetString(PyExc_TypeError, "key must be 64 bytes long"); return NULL; }
PyObject *ans = PyBytes_FromStringAndSize(data, data_sz);
if (ans == NULL) return NULL;
void *dest = PyBytes_AS_STRING(ans);
memcpy(dest, data, data_sz);
xor_data((const uint8_t*)key, keylen, dest, data_sz);
xor_data64((const uint8_t*)key, (uint8_t*)PyBytes_AS_STRING(ans), data_sz);
return ans;
}
@ -848,7 +837,7 @@ PyTypeObject DiskCache_Type = {
};
static PyMethodDef module_methods[] = {
MW(xor_data, METH_VARARGS),
MW(xor_data64, METH_VARARGS),
{NULL, NULL, 0, NULL} /* Sentinel */
};

View File

@ -7,6 +7,7 @@
#pragma once
#include "data-types.h"
#include "simd-string.h"
#include <stdalign.h>
#ifndef KITTY_SIMD_LEVEL
#define KITTY_SIMD_LEVEL 128
@ -19,6 +20,7 @@
#define NOSIMD { fatal("No SIMD implementations for this CPU"); }
bool FUNC(utf8_decode_to_esc)(UTF8Decoder *d UNUSED, const uint8_t *src UNUSED, size_t src_sz UNUSED) NOSIMD
const uint8_t* FUNC(find_either_of_two_bytes)(const uint8_t *haystack UNUSED, const size_t sz UNUSED, const uint8_t a UNUSED, const uint8_t b UNUSED) NOSIMD
void FUNC(xor_data64)(const uint8_t key[64], uint8_t* data, const size_t data_sz);
#undef NOSIMD
#else
@ -52,11 +54,13 @@ END_IGNORE_DIAGNOSTIC
#define load_unaligned simde_mm_loadu_si128
#define load_aligned(x) simde_mm_load_si128((const integer_t*)(x))
#define store_unaligned simde_mm_storeu_si128
#define store_aligned(dest, vec) simde_mm_store_si128((integer_t*)dest, vec)
#define cmpeq_epi8 simde_mm_cmpeq_epi8
#define cmplt_epi8 simde_mm_cmplt_epi8
#define cmpgt_epi8 simde_mm_cmpgt_epi8
#define or_si simde_mm_or_si128
#define and_si simde_mm_and_si128
#define xor_si simde_mm_xor_si128
#define andnot_si simde_mm_andnot_si128
#define movemask_epi8 simde_mm_movemask_epi8
#define extract_lower_quarter_as_chars simde_mm_cvtepu8_epi32
@ -118,11 +122,13 @@ w(left, sixteen_bytes, 16)
#define load_unaligned simde_mm256_loadu_si256
#define load_aligned(x) simde_mm256_load_si256((const integer_t*)(x))
#define store_unaligned simde_mm256_storeu_si256
#define store_aligned(dest, vec) simde_mm256_store_si256((integer_t*)dest, vec)
#define cmpeq_epi8 simde_mm256_cmpeq_epi8
#define cmpgt_epi8 simde_mm256_cmpgt_epi8
#define cmplt_epi8(a, b) cmpgt_epi8(b, a)
#define or_si simde_mm256_or_si256
#define and_si simde_mm256_and_si256
#define xor_si simde_mm256_xor_si256
#define andnot_si simde_mm256_andnot_si256
#define movemask_epi8 simde_mm256_movemask_epi8
#define extract_lower_half_as_chars simde_mm256_cvtepu8_epi32
@ -307,6 +313,42 @@ zero_last_n_bytes(const integer_t vec, const char n) {
return and_si(mask, vec);
}
#define KEY_SIZE 64
void
FUNC(xor_data64)(const uint8_t key[KEY_SIZE], uint8_t* data, const size_t data_sz) {
// First process unaligned bytes at the start of data
const uintptr_t unaligned_bytes = KEY_SIZE - ((uintptr_t)data & (KEY_SIZE - 1));
if (data_sz <= unaligned_bytes) { for (unsigned i = 0; i < data_sz; i++) data[i] ^= key[i]; return; }
for (unsigned i = 0; i < unaligned_bytes; i++) data[i] ^= key[i];
// Rotate the key by unaligned_bytes
alignas(sizeof(integer_t)) char aligned_key[KEY_SIZE];
memcpy(aligned_key, key + unaligned_bytes, KEY_SIZE - unaligned_bytes);
memcpy(aligned_key + KEY_SIZE - unaligned_bytes, key, unaligned_bytes);
const integer_t v1 = load_aligned(aligned_key), v2 = load_aligned(aligned_key + sizeof(integer_t));
#if KITTY_SIMD_LEVEL == 128
const integer_t v3 = load_aligned(aligned_key + 2*sizeof(integer_t)), v4 = load_aligned(aligned_key + 3 * sizeof(integer_t));
#endif
// Process KEY_SIZE aligned chunks using SIMD
integer_t d;
uint8_t *p = data + unaligned_bytes, *limit = data + data_sz;
const uintptr_t trailing_bytes = (uintptr_t)limit & (KEY_SIZE - 1);
limit -= trailing_bytes;
#define do_one(which) d = load_aligned(p); store_aligned(p, xor_si(which, d)); p += sizeof(integer_t);
while (p < limit) {
do_one(v1); do_one(v2);
#if KITTY_SIMD_LEVEL == 128
do_one(v3); do_one(v4);
#endif
}
#undef do_one
// Process remaining trailing_bytes
for (unsigned i = 0; i < trailing_bytes; i++) limit[i] ^= aligned_key[i];
zero_upper(); return;
}
#undef KEY_SIZE
#define check_chunk() if (n > -1) { \
const uint8_t *ans = haystack + n; \
zero_upper(); \
@ -716,11 +758,13 @@ start_classification:
#undef load_unaligned
#undef load_aligned
#undef store_unaligned
#undef store_aligned
#undef cmpeq_epi8
#undef cmplt_epi8
#undef cmpgt_epi8
#undef or_si
#undef and_si
#undef xor_si
#undef andnot_si
#undef movemask_epi8
#undef CONCAT

View File

@ -10,6 +10,12 @@
#include "simd-string.h"
static bool has_sse4_2 = false, has_avx2 = false;
// xor_data64 {{{
static void xor_data64_scalar(const uint8_t key[64], uint8_t* data, const size_t data_sz) { for (size_t i = 0; i < data_sz; i++) data[i] ^= key[i & 63]; }
static void (*xor_data64_impl)(const uint8_t key[64], uint8_t* data, const size_t data_sz) = xor_data64_scalar;
void xor_data64(const uint8_t key[64], uint8_t* data, const size_t data_sz) { xor_data64_impl(key, data, data_sz); }
// }}}
// find_either_of_two_bytes {{{
static const uint8_t*
find_either_of_two_bytes_scalar(const uint8_t *haystack, const size_t sz, const uint8_t x, const uint8_t y) {
@ -188,6 +194,7 @@ init_simd(void *x) {
A(has_avx2, True);
find_either_of_two_bytes_impl = find_either_of_two_bytes_256;
utf8_decode_to_esc_impl = utf8_decode_to_esc_256;
xor_data64_impl = xor_data64_256;
} else {
A(has_avx2, False);
}
@ -195,6 +202,7 @@ init_simd(void *x) {
A(has_sse4_2, True);
if (find_either_of_two_bytes_impl == find_either_of_two_bytes_scalar) find_either_of_two_bytes_impl = find_either_of_two_bytes_128;
if (utf8_decode_to_esc_impl == utf8_decode_to_esc_scalar) utf8_decode_to_esc_impl = utf8_decode_to_esc_128;
if (xor_data64_impl == xor_data64_scalar) xor_data64_impl = xor_data64_128;
} else {
A(has_sse4_2, False);
}

View File

@ -46,8 +46,13 @@ bool init_simd(void* module);
// two chars or NULL if not found.
const uint8_t* find_either_of_two_bytes(const uint8_t *haystack, const size_t sz, const uint8_t a, const uint8_t b);
// XOR data with the 64 byte key
void xor_data64(const uint8_t key[64], uint8_t* data, const size_t data_sz);
// SIMD implementations, internal use
bool utf8_decode_to_esc_128(UTF8Decoder *d, const uint8_t *src, size_t src_sz);
bool utf8_decode_to_esc_256(UTF8Decoder *d, const uint8_t *src, size_t src_sz);
const uint8_t* find_either_of_two_bytes_128(const uint8_t *haystack, const size_t sz, const uint8_t a, const uint8_t b);
const uint8_t* find_either_of_two_bytes_256(const uint8_t *haystack, const size_t sz, const uint8_t a, const uint8_t b);
void xor_data64_128(const uint8_t key[64], uint8_t* data, const size_t data_sz);
void xor_data64_256(const uint8_t key[64], uint8_t* data, const size_t data_sz);

View File

@ -12,7 +12,7 @@
from io import BytesIO
from itertools import cycle
from kitty.fast_data_types import base64_decode, base64_encode, load_png_data, shm_unlink, shm_write, xor_data
from kitty.fast_data_types import base64_decode, base64_encode, load_png_data, shm_unlink, shm_write, xor_data64
from . import BaseTest, parse_bytes
@ -189,12 +189,12 @@ def xor(skey, data):
ckey = cycle(bytearray(skey))
return bytes(bytearray(k ^ d for k, d in zip(ckey, bytearray(data))))
base_data = os.urandom(64)
key = os.urandom(len(base_data))
for base in (b'', base_data):
base_data = os.urandom(61)
key = os.urandom(64)
for base in (b'', base_data, base_data * 3):
for extra in range(len(base_data)):
data = base + base_data[:extra]
self.assertEqual(xor_data(key, data), xor(key, data))
self.assertEqual(xor(key, data), xor_data64(key, data))
def test_disk_cache(self):
s = self.create_screen()