diff --git a/Kernel/Heap/kmalloc.cpp b/Kernel/Heap/kmalloc.cpp index 613669d5b45..9171004a475 100644 --- a/Kernel/Heap/kmalloc.cpp +++ b/Kernel/Heap/kmalloc.cpp @@ -305,15 +305,19 @@ size_t kmalloc_good_size(size_t size) return size; } -[[gnu::malloc, gnu::alloc_size(1), gnu::alloc_align(2)]] static void* kmalloc_aligned_cxx(size_t size, size_t alignment) +void* kmalloc_aligned(size_t size, size_t alignment) { VERIFY(alignment <= 4096); - void* ptr = kmalloc(size + alignment + sizeof(ptrdiff_t)); + Checked real_allocation_size = size; + real_allocation_size += alignment; + real_allocation_size += sizeof(ptrdiff_t) + sizeof(size_t); + void* ptr = kmalloc(real_allocation_size.value()); if (ptr == nullptr) return nullptr; size_t max_addr = (size_t)ptr + alignment; void* aligned_ptr = (void*)(max_addr - (max_addr % alignment)); ((ptrdiff_t*)aligned_ptr)[-1] = (ptrdiff_t)((u8*)aligned_ptr - (u8*)ptr); + ((size_t*)aligned_ptr)[-2] = real_allocation_size.value(); return aligned_ptr; } @@ -331,14 +335,14 @@ void* operator new(size_t size, const std::nothrow_t&) noexcept void* operator new(size_t size, std::align_val_t al) { - void* ptr = kmalloc_aligned_cxx(size, (size_t)al); + void* ptr = kmalloc_aligned(size, (size_t)al); VERIFY(ptr); return ptr; } void* operator new(size_t size, std::align_val_t al, const std::nothrow_t&) noexcept { - return kmalloc_aligned_cxx(size, (size_t)al); + return kmalloc_aligned(size, (size_t)al); } void* operator new[](size_t size) diff --git a/Kernel/Heap/kmalloc.h b/Kernel/Heap/kmalloc.h index 73b9880ca8d..5ebfe918071 100644 --- a/Kernel/Heap/kmalloc.h +++ b/Kernel/Heap/kmalloc.h @@ -17,11 +17,11 @@ public: \ [[nodiscard]] void* operator new(size_t) \ { \ - void* ptr = kmalloc_aligned(sizeof(type)); \ + void* ptr = kmalloc_aligned(sizeof(type), alignment); \ VERIFY(ptr); \ return ptr; \ } \ - [[nodiscard]] void* operator new(size_t, const std::nothrow_t&) noexcept { return kmalloc_aligned(sizeof(type)); } \ + [[nodiscard]] void* operator new(size_t, const std::nothrow_t&) noexcept { return kmalloc_aligned(sizeof(type), alignment); } \ void operator delete(void* ptr) noexcept { kfree_aligned(ptr); } \ \ private: @@ -76,25 +76,13 @@ void operator delete[](void* ptr, size_t) noexcept; [[gnu::malloc, gnu::alloc_size(1)]] void* kmalloc(size_t); -template -[[gnu::malloc, gnu::alloc_size(1)]] inline void* kmalloc_aligned(size_t size) -{ - static_assert(ALIGNMENT > sizeof(ptrdiff_t)); - static_assert(ALIGNMENT <= 4096); - void* ptr = kmalloc(size + ALIGNMENT + sizeof(ptrdiff_t)); - if (ptr == nullptr) - return ptr; - size_t max_addr = (size_t)ptr + ALIGNMENT; - void* aligned_ptr = (void*)(max_addr - (max_addr % ALIGNMENT)); - ((ptrdiff_t*)aligned_ptr)[-1] = (ptrdiff_t)((u8*)aligned_ptr - (u8*)ptr); - return aligned_ptr; -} +[[gnu::malloc, gnu::alloc_size(1), gnu::alloc_align(2)]] void* kmalloc_aligned(size_t size, size_t alignment); inline void kfree_aligned(void* ptr) { if (ptr == nullptr) return; - kfree((u8*)ptr - ((const ptrdiff_t*)ptr)[-1]); + kfree_sized((u8*)ptr - ((ptrdiff_t const*)ptr)[-1], ((size_t const*)ptr)[-2]); } size_t kmalloc_good_size(size_t);