ares/rust/ares_guard/c-src/guard.c

266 lines
6.2 KiB
C
Raw Normal View History

2024-02-10 08:43:20 +03:00
#include <assert.h>
#include <errno.h>
#include <setjmp.h>
#include <signal.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <sys/mman.h>
#include <unistd.h>
#include "guard.h"
// Guard pages are 2^GD_PAGE_BITS = 16 KiB. GD_PAGE_ROUND_DOWN aligns an
// address down to a multiple of that size.
// NOTE(review): mprotect() needs system-page alignment, which holds whenever
// the system page size divides 16 KiB.
#define GD_PAGE_BITS 14ULL
#define GD_PAGE_SIZE (1ULL << GD_PAGE_BITS) // 16 KB
#define GD_PAGE_MASK (GD_PAGE_SIZE - 1)
// The argument is parenthesized so that expressions whose operators bind
// more loosely than `&` (e.g. `a | b`) are rounded as a whole.
#define GD_PAGE_ROUND_DOWN(foo) ((foo) & ~GD_PAGE_MASK)
2024-02-19 21:26:21 +03:00
// Process-wide bookkeeping for the guard page mechanism.
typedef struct GD_state {
  uintptr_t guard_p;                // base of the current PROT_NONE guard page; 0 when no guard is active
  const uintptr_t *stack_pp;        // location the current stack pointer is read from
  const uintptr_t *alloc_pp;        // location the current allocation pointer is read from
  GD_buflistnode *buffer_list;      // jump buffers of active guard() calls, innermost first
  struct sigaction prev_sigsegv_sa; // SIGSEGV disposition replaced by _register_handlers
  struct sigaction prev_sigbus_sa;  // SIGBUS disposition replaced by _register_handlers
} GD_state;
2024-02-19 21:47:39 +03:00
// The single guard state instance. It starts fully zeroed: no guard page,
// no registered stack/alloc pointers, no jump buffers, and no saved signal
// dispositions.
static GD_state gd = { 0 };
2024-02-10 08:43:20 +03:00
2024-02-16 21:23:17 +03:00
static guard_result
2024-02-10 08:43:20 +03:00
_prot_page(void *address, int prot)
{
if (mprotect(address, GD_PAGE_SIZE, prot)) {
fprintf(stderr, "guard: prot: mprotect error %d\r\n", errno);
fprintf(stderr, "%s\r\n", strerror(errno));
return guard_mprotect | errno;
}
2024-02-19 20:54:43 +03:00
return guard_success;
2024-02-10 08:43:20 +03:00
}
2024-02-16 21:23:17 +03:00
static guard_result
2024-02-10 08:43:20 +03:00
_mark_page(void *address)
{
return _prot_page(address, PROT_NONE);
}
2024-02-16 21:23:17 +03:00
static guard_result
2024-02-10 08:43:20 +03:00
_unmark_page(void *address)
{
return _prot_page(address, PROT_READ | PROT_WRITE);
}
// Center the guard page.
2024-02-16 21:23:17 +03:00
static guard_result
_focus_guard()
{
2024-02-19 21:26:21 +03:00
uintptr_t stack_p = *gd.stack_pp;
uintptr_t alloc_p = *gd.alloc_pp;
uintptr_t old_guard_p = gd.guard_p;
uintptr_t new_guard_p;
2024-02-16 21:23:17 +03:00
guard_result err = 0;
2024-02-10 08:43:20 +03:00
if (stack_p == 0 || alloc_p == 0) {
fprintf(stderr, "guard: focus: stack or alloc pointer is null\r\n");
return guard_null;
} else if (stack_p == alloc_p) {
return guard_oom;
}
2024-02-19 20:54:43 +03:00
// Compute new guard page.
2024-02-10 08:43:20 +03:00
new_guard_p = GD_PAGE_ROUND_DOWN((stack_p + alloc_p) / 2);
if (new_guard_p == old_guard_p) {
return guard_oom;
}
2024-02-19 20:54:43 +03:00
// Mark new guard page.
2024-02-10 08:43:20 +03:00
if ((err = _mark_page((void *)new_guard_p))) {
fprintf(stderr, "guard: focus: mark error\r\n");
return err;
}
2024-02-19 20:54:43 +03:00
// Update guard page tracker.
2024-02-19 21:26:21 +03:00
gd.guard_p = new_guard_p;
2024-02-19 20:54:43 +03:00
// Unmark the old guard page if there is one.
if (old_guard_p) {
if ((err = _unmark_page((void *)old_guard_p))) {
fprintf(stderr, "guard: focus: unmark error\r\n");
return err;
}
}
2024-02-19 20:54:43 +03:00
return guard_success;
2024-02-10 08:43:20 +03:00
}
static void
2024-02-10 08:43:20 +03:00
_signal_handler(int sig, siginfo_t *si, void *unused)
{
uintptr_t sig_addr;
2024-02-16 21:23:17 +03:00
guard_result err = 0;
2024-02-10 08:43:20 +03:00
2024-02-19 21:26:21 +03:00
assert(gd.guard_p);
2024-02-18 04:27:13 +03:00
if (sig != SIGSEGV && sig != SIGBUS) {
fprintf(stderr, "guard: handler: invalid signal: %d\r\n", sig);
2024-02-16 21:23:17 +03:00
assert(0);
2024-02-10 08:43:20 +03:00
}
sig_addr = (uintptr_t)si->si_addr;
2024-02-19 21:26:21 +03:00
if (sig_addr >= gd.guard_p &&
sig_addr < gd.guard_p + GD_PAGE_SIZE)
2024-02-10 08:43:20 +03:00
{
err = _focus_guard();
2024-02-10 08:43:20 +03:00
if (err) {
2024-02-19 21:26:21 +03:00
siglongjmp(gd.buffer_list->buffer, err);
2024-02-10 08:43:20 +03:00
}
}
else {
2024-02-18 04:27:13 +03:00
switch (sig) {
case SIGSEGV: {
2024-02-19 21:26:21 +03:00
if (gd.prev_sigsegv_sa.sa_sigaction != NULL) {
gd.prev_sigsegv_sa.sa_sigaction(sig, si, unused);
} else if (gd.prev_sigsegv_sa.sa_handler != NULL) {
gd.prev_sigsegv_sa.sa_handler(sig);
2024-02-18 04:27:13 +03:00
} else {
assert(0);
}
break;
}
case SIGBUS: {
2024-02-19 21:26:21 +03:00
if (gd.prev_sigbus_sa.sa_sigaction != NULL) {
gd.prev_sigbus_sa.sa_sigaction(sig, si, unused);
} else if (gd.prev_sigbus_sa.sa_handler != NULL) {
gd.prev_sigbus_sa.sa_handler(sig);
2024-02-18 04:27:13 +03:00
} else {
assert(0);
}
}
2024-02-10 08:43:20 +03:00
}
}
}
2024-02-19 20:54:43 +03:00
// Registers the same handler function for SIGSEGV and SIGBUS.
2024-02-16 21:23:17 +03:00
static guard_result
2024-02-19 20:54:43 +03:00
_register_handlers()
2024-02-10 08:43:20 +03:00
{
struct sigaction sa;
sa.sa_flags = SA_SIGINFO;
sa.sa_sigaction = _signal_handler;
2024-02-18 04:27:13 +03:00
2024-02-19 21:26:21 +03:00
if (sigaction(SIGSEGV, &sa, &gd.prev_sigsegv_sa)) {
2024-02-18 04:27:13 +03:00
fprintf(stderr, "guard: register: sigaction error\r\n");
fprintf(stderr, "%s\r\n", strerror(errno));
return guard_sigaction | errno;
}
2024-02-19 21:26:21 +03:00
if (sigaction(SIGBUS, &sa, &gd.prev_sigbus_sa)) {
2024-02-10 08:43:20 +03:00
fprintf(stderr, "guard: register: sigaction error\r\n");
fprintf(stderr, "%s\r\n", strerror(errno));
return guard_sigaction | errno;
}
2024-02-19 20:54:43 +03:00
return guard_success;
2024-02-10 08:43:20 +03:00
}
2024-02-16 21:23:17 +03:00
guard_result
guard(
2024-02-16 21:23:17 +03:00
void *(*f)(void *),
void *closure,
const uintptr_t *const s_pp,
const uintptr_t *const a_pp,
2024-02-19 20:54:43 +03:00
void **ret
2024-02-10 08:43:20 +03:00
) {
2024-02-16 21:23:17 +03:00
GD_buflistnode *new_buffer;
guard_result err = 0;
guard_result td_err = 0;
2024-02-10 08:43:20 +03:00
2024-02-19 21:26:21 +03:00
if (gd.guard_p == 0) {
assert(gd.buffer_list == NULL);
2024-02-19 21:26:21 +03:00
gd.stack_pp = s_pp;
gd.alloc_pp = a_pp;
2024-02-19 20:54:43 +03:00
// Initialize the guard page.
if ((err = _focus_guard())) {
2024-02-16 21:32:11 +03:00
fprintf(stderr, "guard: initial focus error\r\n");
goto exit;
}
2024-02-19 20:54:43 +03:00
// Register guard page signal handler.
if ((err = _register_handlers())) {
2024-02-16 21:32:11 +03:00
fprintf(stderr, "guard: registration error\r\n");
goto clean;
}
} else {
2024-02-19 21:26:21 +03:00
assert(gd.buffer_list != NULL);
}
2024-02-19 20:54:43 +03:00
// Setup new longjmp buffer.
2024-02-16 21:23:17 +03:00
new_buffer = (GD_buflistnode *)malloc(sizeof(GD_buflistnode));
if (new_buffer == NULL) {
fprintf(stderr, "guard: malloc error\r\n");
fprintf(stderr, "%s\r\n", strerror(errno));
err = guard_malloc | errno;
goto skip;
2024-02-10 08:43:20 +03:00
}
2024-02-19 21:26:21 +03:00
new_buffer->next = gd.buffer_list;
gd.buffer_list = new_buffer;
2024-02-10 08:43:20 +03:00
2024-02-19 20:54:43 +03:00
// Run given closure.
2024-02-19 21:26:21 +03:00
if (!(err = sigsetjmp(gd.buffer_list->buffer, 1))) {
2024-02-10 08:43:20 +03:00
*ret = f(closure);
}
2024-02-10 08:43:20 +03:00
2024-02-19 20:54:43 +03:00
// Restore previous longjmp buffer.
2024-02-19 21:26:21 +03:00
gd.buffer_list = gd.buffer_list->next;
free((void *)new_buffer);
skip:
2024-02-19 21:26:21 +03:00
if (gd.buffer_list == NULL) {
if (sigaction(SIGSEGV, &gd.prev_sigsegv_sa, NULL)) {
2024-02-19 20:54:43 +03:00
fprintf(stderr, "guard: error replacing sigsegv handler\r\n");
2024-02-18 04:27:13 +03:00
fprintf(stderr, "%s\r\n", strerror(errno));
td_err = guard_sigaction | errno;
if (!err) {
err = td_err;
}
}
2024-02-19 21:26:21 +03:00
if (sigaction(SIGBUS, &gd.prev_sigbus_sa, NULL)) {
2024-02-19 20:54:43 +03:00
fprintf(stderr, "guard: error replacing sigbus handler\r\n");
fprintf(stderr, "%s\r\n", strerror(errno));
td_err = guard_sigaction | errno;
if (!err) {
err = td_err;
}
}
clean:
2024-02-19 20:54:43 +03:00
// Unmark guard page.
2024-02-19 21:26:21 +03:00
assert(gd.guard_p != 0);
td_err = _unmark_page((void *)gd.guard_p);
if (td_err) {
fprintf(stderr, "guard: unmark error\r\n");
fprintf(stderr, "%s\r\n", strerror(errno));
if (!err) {
err = td_err;
}
}
2024-02-19 21:26:21 +03:00
gd.guard_p = 0;
2024-02-10 08:43:20 +03:00
}
exit:
return err;
2024-02-10 08:43:20 +03:00
}