mirror of
https://github.com/diesel-rs/diesel.git
synced 2024-10-04 01:28:13 +03:00
Cleanup catch_unwind handling
This fixes the handling of panics in custom SQLite functions by having a `catch_unwind` before ffi the boundary and by correctly setting the required `UnwindSafe` trait bounds in a similar way as it is done by rusqlite.
This commit is contained in:
parent
3f90c9b886
commit
bfcf3d58ee
@ -205,6 +205,8 @@ for Rust libraries in [RFC #1105](https://github.com/rust-lang/rfcs/blob/master/
|
||||
|
||||
* Queries containing a `distinct on` clause check now on compile time that a compatible order clause was set.
|
||||
|
||||
* Implementations of custom SQLite SQL functions now probably check for panics
|
||||
|
||||
### Deprecated
|
||||
|
||||
* `diesel_(prefix|postfix|infix)_operator!` have been deprecated. These macros
|
||||
|
@ -17,7 +17,7 @@ pub fn register<ArgsSqlType, RetSqlType, Args, Ret, F>(
|
||||
mut f: F,
|
||||
) -> QueryResult<()>
|
||||
where
|
||||
F: FnMut(&RawConnection, Args) -> Ret + Send + 'static + std::panic::RefUnwindSafe,
|
||||
F: FnMut(&RawConnection, Args) -> Ret + std::panic::UnwindSafe + Send + 'static,
|
||||
Args: FromSqlRow<ArgsSqlType, Sqlite> + StaticallySizedRow<ArgsSqlType, Sqlite>,
|
||||
Ret: ToSql<RetSqlType, Sqlite>,
|
||||
Sqlite: HasSqlType<RetSqlType>,
|
||||
@ -37,6 +37,7 @@ where
|
||||
|
||||
process_sql_function_result::<RetSqlType, Ret>(result)
|
||||
})?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@ -45,14 +46,8 @@ pub fn register_aggregate<ArgsSqlType, RetSqlType, Args, Ret, A>(
|
||||
fn_name: &str,
|
||||
) -> QueryResult<()>
|
||||
where
|
||||
A: SqliteAggregateFunction<Args, Output = Ret>
|
||||
+ 'static
|
||||
+ Send
|
||||
+ std::panic::UnwindSafe
|
||||
+ std::panic::RefUnwindSafe,
|
||||
Args: FromSqlRow<ArgsSqlType, Sqlite>
|
||||
+ StaticallySizedRow<ArgsSqlType, Sqlite>
|
||||
+ std::panic::UnwindSafe,
|
||||
A: SqliteAggregateFunction<Args, Output = Ret> + 'static + Send + std::panic::UnwindSafe,
|
||||
Args: FromSqlRow<ArgsSqlType, Sqlite> + StaticallySizedRow<ArgsSqlType, Sqlite>,
|
||||
Ret: ToSql<RetSqlType, Sqlite>,
|
||||
Sqlite: HasSqlType<RetSqlType>,
|
||||
{
|
||||
|
@ -223,10 +223,8 @@ impl SqliteConnection {
|
||||
mut f: F,
|
||||
) -> QueryResult<()>
|
||||
where
|
||||
F: FnMut(Args) -> Ret + Send + 'static + std::panic::RefUnwindSafe,
|
||||
Args: FromSqlRow<ArgsSqlType, Sqlite>
|
||||
+ StaticallySizedRow<ArgsSqlType, Sqlite>
|
||||
+ std::panic::UnwindSafe,
|
||||
F: FnMut(Args) -> Ret + std::panic::UnwindSafe + Send + 'static,
|
||||
Args: FromSqlRow<ArgsSqlType, Sqlite> + StaticallySizedRow<ArgsSqlType, Sqlite>,
|
||||
Ret: ToSql<RetSqlType, Sqlite>,
|
||||
Sqlite: HasSqlType<RetSqlType>,
|
||||
{
|
||||
@ -244,14 +242,8 @@ impl SqliteConnection {
|
||||
fn_name: &str,
|
||||
) -> QueryResult<()>
|
||||
where
|
||||
A: SqliteAggregateFunction<Args, Output = Ret>
|
||||
+ 'static
|
||||
+ Send
|
||||
+ std::panic::UnwindSafe
|
||||
+ std::panic::RefUnwindSafe,
|
||||
Args: FromSqlRow<ArgsSqlType, Sqlite>
|
||||
+ StaticallySizedRow<ArgsSqlType, Sqlite>
|
||||
+ std::panic::UnwindSafe,
|
||||
A: SqliteAggregateFunction<Args, Output = Ret> + 'static + Send + std::panic::UnwindSafe,
|
||||
Args: FromSqlRow<ArgsSqlType, Sqlite> + StaticallySizedRow<ArgsSqlType, Sqlite>,
|
||||
Ret: ToSql<RetSqlType, Sqlite>,
|
||||
Sqlite: HasSqlType<RetSqlType>,
|
||||
{
|
||||
@ -262,7 +254,8 @@ impl SqliteConnection {
|
||||
///
|
||||
/// `collation` must always return the same answer given the same inputs.
|
||||
/// If `collation` panics and unwinds the stack, the process is aborted, since it is used
|
||||
/// across a C FFI boundary, which cannot be unwound across.
|
||||
/// across a C FFI boundary, which cannot be unwound across and there is no way to
|
||||
/// signal failures via the SQLite interface in this case..
|
||||
///
|
||||
/// If the name is already registered it will be overwritten.
|
||||
///
|
||||
|
@ -98,13 +98,16 @@ impl RawConnection {
|
||||
) -> QueryResult<()>
|
||||
where
|
||||
F: FnMut(&Self, &[*mut ffi::sqlite3_value]) -> QueryResult<SerializedValue>
|
||||
+ std::panic::UnwindSafe
|
||||
+ Send
|
||||
+ 'static
|
||||
+ std::panic::RefUnwindSafe,
|
||||
+ 'static,
|
||||
{
|
||||
let callback_fn = Box::into_raw(Box::new(CustomFunctionUserPtr {
|
||||
callback: f,
|
||||
function_name: fn_name.to_owned(),
|
||||
}));
|
||||
let fn_name = Self::get_fn_name(fn_name)?;
|
||||
let flags = Self::get_flags(deterministic);
|
||||
let callback_fn = Box::into_raw(Box::new(f));
|
||||
|
||||
let result = unsafe {
|
||||
ffi::sqlite3_create_function_v2(
|
||||
@ -129,12 +132,8 @@ impl RawConnection {
|
||||
num_args: usize,
|
||||
) -> QueryResult<()>
|
||||
where
|
||||
A: SqliteAggregateFunction<Args, Output = Ret>
|
||||
+ 'static
|
||||
+ Send
|
||||
+ std::panic::UnwindSafe
|
||||
+ std::panic::RefUnwindSafe,
|
||||
Args: FromSqlRow<ArgsSqlType, Sqlite> + std::panic::UnwindSafe,
|
||||
A: SqliteAggregateFunction<Args, Output = Ret> + 'static + Send + std::panic::UnwindSafe,
|
||||
Args: FromSqlRow<ArgsSqlType, Sqlite>,
|
||||
Ret: ToSql<RetSqlType, Sqlite>,
|
||||
Sqlite: HasSqlType<RetSqlType>,
|
||||
{
|
||||
@ -164,10 +163,13 @@ impl RawConnection {
|
||||
collation: F,
|
||||
) -> QueryResult<()>
|
||||
where
|
||||
F: Fn(&str, &str) -> std::cmp::Ordering + Send + 'static,
|
||||
F: Fn(&str, &str) -> std::cmp::Ordering + std::panic::UnwindSafe + Send + 'static,
|
||||
{
|
||||
let callback_fn = Box::into_raw(Box::new(CollationUserPtr {
|
||||
callback: collation,
|
||||
collation_name: collation_name.to_owned(),
|
||||
}));
|
||||
let collation_name = Self::get_fn_name(collation_name)?;
|
||||
let callback_fn = Box::into_raw(Box::new(collation));
|
||||
|
||||
let result = unsafe {
|
||||
ffi::sqlite3_create_collation_v2(
|
||||
@ -243,58 +245,95 @@ fn convert_to_string_and_free(err_msg: *const libc::c_char) -> String {
|
||||
msg
|
||||
}
|
||||
|
||||
enum SqliteCallbackError {
|
||||
Abort(&'static str),
|
||||
DieselError(crate::result::Error),
|
||||
Panic(Box<dyn std::any::Any + Send>, String),
|
||||
}
|
||||
|
||||
impl SqliteCallbackError {
|
||||
fn emit(&self, ctx: *mut ffi::sqlite3_context) {
|
||||
let s;
|
||||
let msg = match self {
|
||||
SqliteCallbackError::Abort(msg) => *msg,
|
||||
SqliteCallbackError::DieselError(e) => {
|
||||
s = e.to_string();
|
||||
&s
|
||||
}
|
||||
SqliteCallbackError::Panic(_, msg) => &msg,
|
||||
};
|
||||
unsafe {
|
||||
context_error_str(ctx, msg);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<crate::result::Error> for SqliteCallbackError {
|
||||
fn from(e: crate::result::Error) -> Self {
|
||||
Self::DieselError(e)
|
||||
}
|
||||
}
|
||||
|
||||
struct CustomFunctionUserPtr<F> {
|
||||
callback: F,
|
||||
function_name: String,
|
||||
}
|
||||
|
||||
#[allow(warnings)]
|
||||
extern "C" fn run_custom_function<F>(
|
||||
extern "C" fn run_custom_function<'b, F>(
|
||||
ctx: *mut ffi::sqlite3_context,
|
||||
num_args: libc::c_int,
|
||||
value_ptr: *mut *mut ffi::sqlite3_value,
|
||||
) where
|
||||
F: FnMut(&RawConnection, &[*mut ffi::sqlite3_value]) -> QueryResult<SerializedValue>
|
||||
+ std::panic::UnwindSafe
|
||||
+ Send
|
||||
+ 'static
|
||||
+ std::panic::RefUnwindSafe,
|
||||
+ 'static,
|
||||
{
|
||||
use std::ops::Deref;
|
||||
static NULL_DATA_ERR: &str = "An unknown error occurred. sqlite3_user_data returned a null pointer. This should never happen.";
|
||||
static NULL_CONN_ERR: &str = "An unknown error occurred. sqlite3_context_db_handle returned a null pointer. This should never happen.";
|
||||
|
||||
let data_ptr = unsafe { ffi::sqlite3_user_data(ctx) };
|
||||
let data_ptr = data_ptr as *mut F;
|
||||
let f = match unsafe { data_ptr.as_mut() } {
|
||||
Some(f) => f,
|
||||
None => {
|
||||
unsafe { context_error_str(ctx, NULL_DATA_ERR) };
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
let args = unsafe { slice::from_raw_parts(value_ptr, num_args as _) };
|
||||
let conn = match unsafe { NonNull::new(ffi::sqlite3_context_db_handle(ctx)) } {
|
||||
Some(conn) => RawConnection {
|
||||
// We use `ManuallyDrop` here because we do not want to run the
|
||||
// Drop impl of `RawConnection` as this would close the connection
|
||||
Some(conn) => mem::ManuallyDrop::new(RawConnection {
|
||||
internal_connection: conn,
|
||||
},
|
||||
}),
|
||||
None => {
|
||||
unsafe { context_error_str(ctx, NULL_DATA_ERR) };
|
||||
unsafe { context_error_str(ctx, NULL_CONN_ERR) };
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
let mut f = std::panic::AssertUnwindSafe(f);
|
||||
let result = std::panic::catch_unwind(move || {
|
||||
use std::ops::DerefMut as _;
|
||||
let result = f.deref_mut()(&conn, args);
|
||||
mem::forget(conn);
|
||||
result
|
||||
});
|
||||
let data_ptr = unsafe { ffi::sqlite3_user_data(ctx) };
|
||||
|
||||
let mut data_ptr = match NonNull::new(data_ptr as *mut CustomFunctionUserPtr<F>) {
|
||||
None => unsafe {
|
||||
context_error_str(ctx, NULL_DATA_ERR);
|
||||
return;
|
||||
},
|
||||
Some(mut f) => f,
|
||||
};
|
||||
let data_ptr = unsafe { data_ptr.as_mut() };
|
||||
|
||||
// We need this to move the reference into the catch_unwind part
|
||||
// this is sound as `F` itself and the stored string is `UnwindSafe`
|
||||
let callback = std::panic::AssertUnwindSafe(&mut data_ptr.callback);
|
||||
|
||||
let result =
|
||||
std::panic::catch_unwind(move || Ok((callback.0)(&*conn, args)?)).unwrap_or_else(|p| {
|
||||
Err(SqliteCallbackError::Panic(
|
||||
p,
|
||||
data_ptr.function_name.clone(),
|
||||
))
|
||||
});
|
||||
match result {
|
||||
Ok(Ok(value)) => value.result_of(ctx),
|
||||
Ok(Err(e)) => {
|
||||
let msg = e.to_string();
|
||||
unsafe { context_error_str(ctx, &msg) };
|
||||
}
|
||||
Err(_) => {
|
||||
let msg = format!("{} panicked", std::any::type_name::<F>());
|
||||
unsafe { context_error_str(ctx, &msg) };
|
||||
Ok(value) => value.result_of(ctx),
|
||||
Err(e) => {
|
||||
e.emit(ctx);
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -314,11 +353,39 @@ extern "C" fn run_aggregator_step_function<ArgsSqlType, RetSqlType, Args, Ret, A
|
||||
num_args: libc::c_int,
|
||||
value_ptr: *mut *mut ffi::sqlite3_value,
|
||||
) where
|
||||
A: SqliteAggregateFunction<Args, Output = Ret> + 'static + Send + std::panic::RefUnwindSafe,
|
||||
Args: FromSqlRow<ArgsSqlType, Sqlite> + std::panic::UnwindSafe,
|
||||
A: SqliteAggregateFunction<Args, Output = Ret> + 'static + Send + std::panic::UnwindSafe,
|
||||
Args: FromSqlRow<ArgsSqlType, Sqlite>,
|
||||
Ret: ToSql<RetSqlType, Sqlite>,
|
||||
Sqlite: HasSqlType<RetSqlType>,
|
||||
{
|
||||
let args = unsafe { slice::from_raw_parts(value_ptr, num_args as _) };
|
||||
let result =
|
||||
std::panic::catch_unwind(move || run_aggregator_step::<A, Args, ArgsSqlType>(ctx, args))
|
||||
.unwrap_or_else(|e| {
|
||||
Err(SqliteCallbackError::Panic(
|
||||
e,
|
||||
format!("{}::step() paniced", std::any::type_name::<A>()),
|
||||
))
|
||||
});
|
||||
|
||||
match result {
|
||||
Ok(()) => {}
|
||||
Err(e) => e.emit(ctx),
|
||||
}
|
||||
}
|
||||
|
||||
fn run_aggregator_step<A, Args, ArgsSqlType>(
|
||||
ctx: *mut ffi::sqlite3_context,
|
||||
args: &[*mut ffi::sqlite3_value],
|
||||
) -> Result<(), SqliteCallbackError>
|
||||
where
|
||||
A: SqliteAggregateFunction<Args>,
|
||||
Args: FromSqlRow<ArgsSqlType, Sqlite>,
|
||||
{
|
||||
static NULL_AG_CTX_ERR: &str = "An unknown error occurred. sqlite3_aggregate_context returned a null pointer. This should never happen.";
|
||||
static NULL_CTX_ERR: &str =
|
||||
"We've written the aggregator to the aggregate context, but it could not be retrieved.";
|
||||
|
||||
let aggregate_context = unsafe {
|
||||
// This block of unsafe code makes the following assumptions:
|
||||
//
|
||||
@ -343,52 +410,37 @@ extern "C" fn run_aggregator_step_function<ArgsSqlType, RetSqlType, Args, Ret, A
|
||||
// we cannot guarantee it, so better be safe than sorry)
|
||||
ffi::sqlite3_aggregate_context(ctx, std::mem::size_of::<OptionalAggregator<A>>() as i32)
|
||||
};
|
||||
let mut aggregate_context = NonNull::new(aggregate_context as *mut OptionalAggregator<A>);
|
||||
let aggregate_context = NonNull::new(aggregate_context as *mut OptionalAggregator<A>);
|
||||
let aggregator = unsafe {
|
||||
match aggregate_context.map(|a| &mut *a.as_ptr()) {
|
||||
Some(&mut OptionalAggregator::Some(ref mut agg)) => agg,
|
||||
Some(mut a_ptr @ &mut OptionalAggregator::None) => {
|
||||
Some(a_ptr @ &mut OptionalAggregator::None) => {
|
||||
ptr::write_unaligned(a_ptr as *mut _, OptionalAggregator::Some(A::default()));
|
||||
if let &mut OptionalAggregator::Some(ref mut agg) = a_ptr {
|
||||
agg
|
||||
} else {
|
||||
assert_fail!("We've written the aggregator to the aggregate context, but it could not be retrieved.");
|
||||
return Err(SqliteCallbackError::Abort(NULL_CTX_ERR));
|
||||
}
|
||||
}
|
||||
None => {
|
||||
null_aggregate_context_error(ctx);
|
||||
return;
|
||||
return Err(SqliteCallbackError::Abort(NULL_AG_CTX_ERR));
|
||||
}
|
||||
}
|
||||
};
|
||||
let args = build_sql_function_args::<ArgsSqlType, Args>(args)?;
|
||||
|
||||
let args = unsafe { slice::from_raw_parts(value_ptr, num_args as _) };
|
||||
let mut aggregator = std::panic::AssertUnwindSafe(aggregator);
|
||||
let result = std::panic::catch_unwind(move || {
|
||||
build_sql_function_args::<ArgsSqlType, Args>(args).map(|args| Ok(aggregator.step(args)))
|
||||
})
|
||||
.unwrap_or_else(|e| Ok(Err(e)));
|
||||
match result {
|
||||
Ok(Ok(())) => (),
|
||||
Ok(Err(_)) => {
|
||||
let msg = format!("{}::step() panicked", std::any::type_name::<A>());
|
||||
unsafe { context_error_str(ctx, &msg) };
|
||||
}
|
||||
Err(e) => {
|
||||
let msg = e.to_string();
|
||||
unsafe { context_error_str(ctx, &msg) };
|
||||
}
|
||||
};
|
||||
Ok(aggregator.step(args))
|
||||
}
|
||||
|
||||
extern "C" fn run_aggregator_final_function<ArgsSqlType, RetSqlType, Args, Ret, A>(
|
||||
ctx: *mut ffi::sqlite3_context,
|
||||
) where
|
||||
A: SqliteAggregateFunction<Args, Output = Ret> + 'static + Send + std::panic::UnwindSafe,
|
||||
Args: FromSqlRow<ArgsSqlType, Sqlite> + std::panic::UnwindSafe,
|
||||
A: SqliteAggregateFunction<Args, Output = Ret> + 'static + Send,
|
||||
Args: FromSqlRow<ArgsSqlType, Sqlite>,
|
||||
Ret: ToSql<RetSqlType, Sqlite>,
|
||||
Sqlite: HasSqlType<RetSqlType>,
|
||||
{
|
||||
static NO_AGGREGATOR_FOUND: &str = "We've written to the aggregator in the xStep callback. If xStep was never called, then ffi::sqlite_aggregate_context() would have returned a NULL pointer.";
|
||||
let aggregate_context = unsafe {
|
||||
// Within the xFinal callback, it is customary to set nBytes to 0 so no pointless memory
|
||||
// allocations occur, a null pointer is returned in this case
|
||||
@ -399,43 +451,46 @@ extern "C" fn run_aggregator_final_function<ArgsSqlType, RetSqlType, Args, Ret,
|
||||
ffi::sqlite3_aggregate_context(ctx, 0)
|
||||
};
|
||||
|
||||
let mut aggregate_context = NonNull::new(aggregate_context as *mut OptionalAggregator<A>);
|
||||
let aggregator = aggregate_context.as_mut().map(|a| {
|
||||
let a = unsafe { a.as_mut() };
|
||||
match std::mem::replace(a, OptionalAggregator::None) {
|
||||
OptionalAggregator::Some(agg) => agg,
|
||||
OptionalAggregator::None => {
|
||||
assert_fail!("We've written to the aggregator in the xStep callback. If xStep was never called, then ffi::sqlite_aggregate_context() would have returned a NULL pointer.");
|
||||
let result = std::panic::catch_unwind(|| {
|
||||
let mut aggregate_context = NonNull::new(aggregate_context as *mut OptionalAggregator<A>);
|
||||
|
||||
let aggregator = if let Some(a) = aggregate_context.as_mut() {
|
||||
let a = unsafe { a.as_mut() };
|
||||
match std::mem::replace(a, OptionalAggregator::None) {
|
||||
OptionalAggregator::None => {
|
||||
return Err(SqliteCallbackError::Abort(NO_AGGREGATOR_FOUND));
|
||||
}
|
||||
OptionalAggregator::Some(a) => Some(a),
|
||||
}
|
||||
}
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let res = A::finalize(aggregator);
|
||||
Ok(process_sql_function_result::<RetSqlType, Ret>(res)?)
|
||||
})
|
||||
.unwrap_or_else(|e| {
|
||||
Err(SqliteCallbackError::Panic(
|
||||
e,
|
||||
format!("{}::finalize() paniced", std::any::type_name::<A>()),
|
||||
))
|
||||
});
|
||||
|
||||
let result = std::panic::catch_unwind(|| A::finalize(aggregator))
|
||||
.map(process_sql_function_result::<RetSqlType, Ret>);
|
||||
|
||||
match result {
|
||||
Ok(Ok(value)) => value.result_of(ctx),
|
||||
Ok(Err(e)) => {
|
||||
let msg = e.to_string();
|
||||
unsafe { context_error_str(ctx, &msg) };
|
||||
}
|
||||
Err(_) => {
|
||||
let msg = format!("{}::finalize() panicked", std::any::type_name::<A>());
|
||||
unsafe { context_error_str(ctx, &msg) };
|
||||
}
|
||||
Ok(value) => value.result_of(ctx),
|
||||
Err(e) => e.emit(ctx),
|
||||
}
|
||||
}
|
||||
|
||||
unsafe fn null_aggregate_context_error(ctx: *mut ffi::sqlite3_context) {
|
||||
static NULL_AG_CTX_ERR: &str = "An unknown error occurred. sqlite3_aggregate_context returned a null pointer. This should never happen.";
|
||||
|
||||
context_error_str(ctx, NULL_AG_CTX_ERR)
|
||||
}
|
||||
|
||||
unsafe fn context_error_str(ctx: *mut ffi::sqlite3_context, error: &str) {
|
||||
ffi::sqlite3_result_error(ctx, error.as_ptr() as *const _, error.len() as _);
|
||||
}
|
||||
|
||||
struct CollationUserPtr<F> {
|
||||
callback: F,
|
||||
collation_name: String,
|
||||
}
|
||||
|
||||
#[allow(warnings)]
|
||||
extern "C" fn run_collation_function<F>(
|
||||
user_ptr: *mut libc::c_void,
|
||||
@ -445,48 +500,87 @@ extern "C" fn run_collation_function<F>(
|
||||
rhs_ptr: *const libc::c_void,
|
||||
) -> libc::c_int
|
||||
where
|
||||
F: Fn(&str, &str) -> std::cmp::Ordering + Send + 'static,
|
||||
F: Fn(&str, &str) -> std::cmp::Ordering + Send + std::panic::UnwindSafe + 'static,
|
||||
{
|
||||
let user_ptr = user_ptr as *const F;
|
||||
let f = unsafe { user_ptr.as_ref() }.unwrap_or_else(|| {
|
||||
assert_fail!(
|
||||
"An unknown error occurred. user_ptr is a null pointer. This should never happen."
|
||||
);
|
||||
});
|
||||
let user_ptr = user_ptr as *const CollationUserPtr<F>;
|
||||
let user_ptr = std::panic::AssertUnwindSafe(unsafe { user_ptr.as_ref() });
|
||||
|
||||
for (ptr, len, side) in &[(rhs_ptr, rhs_len, "rhs"), (lhs_ptr, lhs_len, "lhs")] {
|
||||
if *len < 0 {
|
||||
assert_fail!(
|
||||
"An unknown error occurred. {}_len is negative. This should never happen.",
|
||||
side
|
||||
);
|
||||
}
|
||||
if ptr.is_null() {
|
||||
assert_fail!(
|
||||
let result = std::panic::catch_unwind(|| {
|
||||
let user_ptr = user_ptr.ok_or_else(|| {
|
||||
SqliteCallbackError::Abort(
|
||||
"Got a null pointer as data pointer. This should never happen",
|
||||
)
|
||||
})?;
|
||||
for (ptr, len, side) in &[(rhs_ptr, rhs_len, "rhs"), (lhs_ptr, lhs_len, "lhs")] {
|
||||
if *len < 0 {
|
||||
assert_fail!(
|
||||
"An unknown error occurred. {}_len is negative. This should never happen.",
|
||||
side
|
||||
);
|
||||
}
|
||||
if ptr.is_null() {
|
||||
assert_fail!(
|
||||
"An unknown error occurred. {}_ptr is a null pointer. This should never happen.",
|
||||
side
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
let (rhs, lhs) = unsafe {
|
||||
// Depending on the eTextRep-parameter to sqlite3_create_collation_v2() the strings can
|
||||
// have various encodings. register_collation_function() always selects SQLITE_UTF8, so the
|
||||
// pointers point to valid UTF-8 strings (assuming correct behavior of libsqlite3).
|
||||
(
|
||||
str::from_utf8(slice::from_raw_parts(rhs_ptr as *const u8, rhs_len as _)),
|
||||
str::from_utf8(slice::from_raw_parts(lhs_ptr as *const u8, lhs_len as _)),
|
||||
)
|
||||
};
|
||||
|
||||
let rhs =
|
||||
rhs.map_err(|_| SqliteCallbackError::Abort("Got an invalid UTF-8 string for rhs"))?;
|
||||
let lhs =
|
||||
lhs.map_err(|_| SqliteCallbackError::Abort("Got an invalid UTF-8 string for lhs"))?;
|
||||
|
||||
Ok((user_ptr.callback)(rhs, lhs))
|
||||
})
|
||||
.unwrap_or_else(|p| {
|
||||
Err(SqliteCallbackError::Panic(
|
||||
p,
|
||||
user_ptr
|
||||
.map(|u| u.collation_name.clone())
|
||||
.unwrap_or_default(),
|
||||
))
|
||||
});
|
||||
|
||||
match result {
|
||||
Ok(std::cmp::Ordering::Less) => -1,
|
||||
Ok(std::cmp::Ordering::Equal) => 0,
|
||||
Ok(std::cmp::Ordering::Greater) => 1,
|
||||
Err(SqliteCallbackError::Abort(a)) => {
|
||||
eprintln!(
|
||||
"Collation function {} failed with: {}",
|
||||
user_ptr
|
||||
.map(|c| &c.collation_name as &str)
|
||||
.unwrap_or_default(),
|
||||
a
|
||||
);
|
||||
std::process::abort()
|
||||
}
|
||||
Err(SqliteCallbackError::DieselError(e)) => {
|
||||
eprintln!(
|
||||
"Collation function {} failed with: {}",
|
||||
user_ptr
|
||||
.map(|c| &c.collation_name as &str)
|
||||
.unwrap_or_default(),
|
||||
e
|
||||
);
|
||||
std::process::abort()
|
||||
}
|
||||
Err(SqliteCallbackError::Panic(_, msg)) => {
|
||||
eprintln!("Collation function {} paniced", msg);
|
||||
std::process::abort()
|
||||
}
|
||||
}
|
||||
|
||||
let (rhs, lhs) = unsafe {
|
||||
// Depending on the eTextRep-parameter to sqlite3_create_collation_v2() the strings can
|
||||
// have various encodings. register_collation_function() always selects SQLITE_UTF8, so the
|
||||
// pointers point to valid UTF-8 strings (assuming correct behavior of libsqlite3).
|
||||
(
|
||||
str::from_utf8_unchecked(slice::from_raw_parts(rhs_ptr as *const u8, rhs_len as _)),
|
||||
str::from_utf8_unchecked(slice::from_raw_parts(lhs_ptr as *const u8, lhs_len as _)),
|
||||
)
|
||||
};
|
||||
|
||||
// It doesn't matter if f is UnwindSafe, since we abort on panic.
|
||||
let f = std::panic::AssertUnwindSafe(|| match f(rhs, lhs) {
|
||||
std::cmp::Ordering::Greater => 1,
|
||||
std::cmp::Ordering::Equal => 0,
|
||||
std::cmp::Ordering::Less => -1,
|
||||
});
|
||||
let result = std::panic::catch_unwind(f);
|
||||
result.unwrap_or_else(|_| std::process::abort())
|
||||
}
|
||||
|
||||
extern "C" fn destroy_boxed<F>(data: *mut libc::c_void) {
|
||||
|
@ -1017,11 +1017,8 @@ pub fn derive_valid_grouping(input: TokenStream) -> TokenStream {
|
||||
/// If an implementation of the custom function panics and unwinding is enabled, the panic is
|
||||
/// caught and the function returns to libsqlite with an error. It cannot propagate the panics due
|
||||
/// to the FFI bounary.
|
||||
/// The function or closure is internally wrapped in an
|
||||
/// [`AssertUnwindSafe`](std::panic::AssertUnwindSafe). Its implementation must take care of
|
||||
/// unwind-safety to avoid logic bugs on panic.
|
||||
///
|
||||
/// This also holds for [custom aggregate functions](#custom-aggregate-functions).
|
||||
/// This is is the same for [custom aggregate functions](#custom-aggregate-functions).
|
||||
///
|
||||
/// ## Custom Aggregate Functions
|
||||
///
|
||||
|
@ -263,10 +263,9 @@ pub(crate) fn expand(input: SqlFunctionDecl) -> Result<TokenStream, Diagnostic>
|
||||
f: F,
|
||||
) -> QueryResult<()>
|
||||
where
|
||||
F: Fn(#(#arg_name,)*) -> Ret + Send + 'static + ::std::panic::RefUnwindSafe,
|
||||
F: Fn(#(#arg_name,)*) -> Ret + std::panic::UnwindSafe + Send + 'static,
|
||||
(#(#arg_name,)*): FromSqlRow<(#(#arg_type,)*), Sqlite> +
|
||||
StaticallySizedRow<(#(#arg_type,)*), Sqlite> +
|
||||
::std::panic::UnwindSafe,
|
||||
StaticallySizedRow<(#(#arg_type,)*), Sqlite>,
|
||||
Ret: ToSql<#return_type, Sqlite>,
|
||||
{
|
||||
conn.register_sql_function::<(#(#arg_type,)*), #return_type, _, _, _>(
|
||||
@ -290,10 +289,9 @@ pub(crate) fn expand(input: SqlFunctionDecl) -> Result<TokenStream, Diagnostic>
|
||||
mut f: F,
|
||||
) -> QueryResult<()>
|
||||
where
|
||||
F: FnMut(#(#arg_name,)*) -> Ret + Send + 'static + ::std::panic::RefUnwindSafe,
|
||||
F: FnMut(#(#arg_name,)*) -> Ret + std::panic::UnwindSafe + Send + 'static,
|
||||
(#(#arg_name,)*): FromSqlRow<(#(#arg_type,)*), Sqlite> +
|
||||
StaticallySizedRow<(#(#arg_type,)*), Sqlite> +
|
||||
::std::panic::UnwindSafe,
|
||||
StaticallySizedRow<(#(#arg_type,)*), Sqlite>,
|
||||
Ret: ToSql<#return_type, Sqlite>,
|
||||
{
|
||||
conn.register_sql_function::<(#(#arg_type,)*), #return_type, _, _, _>(
|
||||
|
Loading…
Reference in New Issue
Block a user