Expose sqlite's sqlite3_create_collation()

This commit is contained in:
Alexander 'z33ky' Hirsch 2020-09-05 00:10:05 +02:00 committed by Georg Semmler
parent 7e676025dd
commit 4307a11dd3
No known key found for this signature in database
GPG Key ID: A87BCEE5205CE489
3 changed files with 200 additions and 5 deletions

View File

@ -56,6 +56,8 @@ for Rust libraries in [RFC #1105](https://github.com/rust-lang/rfcs/blob/master/
* Added the error position for PostgreSQL errors
* Added ability to create custom collation functions in SQLite.
### Removed
* All previously deprecated items have been removed.

View File

@ -250,6 +250,48 @@ impl SqliteConnection {
functions::register_aggregate::<_, _, _, _, A>(&self.raw_connection, fn_name)
}
/// Register a collation function.
///
/// The closure is boxed and handed to SQLite under `collation_name`; SQLite then calls it
/// whenever two text values have to be compared using that collation.
///
/// `collation` must always return the same answer given the same inputs.
/// If `collation` panics and unwinds the stack, the process is aborted, since it is used
/// across a C FFI boundary, which cannot be unwound across.
///
/// If the name is already registered it will be overwritten.
///
/// The collation needs to be specified when creating a table:
/// `CREATE TABLE my_table ( str TEXT COLLATE MY_COLLATION )`,
/// where `MY_COLLATION` corresponds to `collation_name`.
///
/// # Errors
///
/// This method will return an error if registering the function fails, either due to an
/// out-of-memory situation or because a collation with that name already exists and is
/// currently being used in parallel by a query.
/// An error is also returned when `collation_name` contains an interior NUL byte, because
/// the name has to be converted into a C string before it is passed to SQLite.
///
/// # Example
///
/// ```rust
/// # include!("../../doctest_setup.rs");
/// #
/// # fn main() {
/// # run_test().unwrap();
/// # }
/// #
/// # fn run_test() -> QueryResult<()> {
/// # let conn = SqliteConnection::establish(":memory:").unwrap();
/// // sqlite NOCASE only works for ASCII characters,
/// // this collation allows handling UTF-8 (barring locale differences)
/// conn.register_collation("RUSTNOCASE", |rhs, lhs| {
/// rhs.to_lowercase().cmp(&lhs.to_lowercase())
/// })
/// # }
/// ```
pub fn register_collation<F>(&self, collation_name: &str, collation: F) -> QueryResult<()>
where
F: Fn(&str, &str) -> std::cmp::Ordering + Send + 'static,
{
// All of the FFI work happens on the raw connection; this is only the public entry point.
self.raw_connection
.register_collation_function(collation_name, collation)
}
fn register_diesel_sql_functions(&self) -> QueryResult<()> {
use crate::sql_types::{Integer, Text};
@ -522,4 +564,75 @@ mod tests {
.unwrap();
assert_eq!(Some(3), result);
}
// Schema used by the `register_collation_function` test: it mirrors the
// `my_collation_example` table that the test creates with raw SQL, whose `value`
// column is declared with the custom `RUSTNOCASE` collation.
table! {
my_collation_example {
id -> Integer,
value -> Text,
}
}
#[test]
fn register_collation_function() {
    use self::my_collation_example::dsl::*;

    // Fresh in-memory database with a Unicode-aware, case-insensitive collation.
    let connection = SqliteConnection::establish(":memory:").unwrap();
    connection
        .register_collation("RUSTNOCASE", |a, b| a.to_lowercase().cmp(&b.to_lowercase()))
        .unwrap();

    connection
        .execute(
            "CREATE TABLE my_collation_example (id integer primary key autoincrement, value text collate RUSTNOCASE)",
        )
        .unwrap();
    connection
        .execute("INSERT INTO my_collation_example (value) VALUES ('foo'), ('FOo'), ('f00')")
        .unwrap();

    // (needle, rows expected to compare equal to it under RUSTNOCASE), checked in order.
    let cases: [(&str, &[&str]); 5] = [
        ("foo", &["foo", "FOo"][..]),
        ("FOO", &["foo", "FOo"][..]),
        ("f00", &["f00"][..]),
        ("F00", &["f00"][..]),
        ("oof", &[][..]),
    ];

    for (needle, expected_rows) in cases.iter() {
        let expected: Vec<String> = expected_rows.iter().map(|row| (*row).to_owned()).collect();
        let found = my_collation_example
            .filter(value.eq(*needle))
            .select(value)
            .load::<String>(&connection);
        assert_eq!(
            Ok(&expected[..]),
            found.as_ref().map(|rows| rows.as_ref())
        );
    }
}
}

View File

@ -102,7 +102,7 @@ impl RawConnection {
Some(run_custom_function::<F>),
None,
None,
Some(destroy_boxed_fn::<F>),
Some(destroy_boxed::<F>),
)
};
@ -140,6 +140,35 @@ impl RawConnection {
Self::process_sql_function_result(result)
}
pub fn register_collation_function<F>(
&self,
collation_name: &str,
collation: F,
) -> QueryResult<()>
where
F: Fn(&str, &str) -> std::cmp::Ordering + Send + 'static,
{
let collation_name = Self::get_fn_name(collation_name)?;
let callback_fn = Box::into_raw(Box::new(collation));
let result = unsafe {
ffi::sqlite3_create_collation_v2(
self.internal_connection.as_ptr(),
collation_name.as_ptr(),
ffi::SQLITE_UTF8,
callback_fn as *mut _,
Some(run_collation_function::<F>),
Some(destroy_boxed::<F>),
)
};
let result = Self::process_sql_function_result(result);
if result.is_err() {
destroy_boxed::<F>(callback_fn as *mut _);
}
result
}
/// Converts `fn_name` into an owned, NUL-terminated C string suitable for passing to
/// the SQLite C API.
///
/// Returns `NulError` if `fn_name` contains an interior NUL byte.
fn get_fn_name(fn_name: &str) -> Result<CString, NulError> {
    // `CString::new` already yields `Result<CString, NulError>`; no re-wrapping needed.
    CString::new(fn_name)
}
@ -379,12 +408,63 @@ unsafe fn null_aggregate_context_error(ctx: *mut ffi::sqlite3_context) {
);
}
extern "C" fn destroy_boxed_fn<F>(data: *mut libc::c_void)
/// C-ABI trampoline that SQLite invokes to compare two strings with a Rust collation.
///
/// * `user_ptr` - the user-data pointer registered alongside the collation; it is
///   interpreted as `*const F` and must point to a live closure.
/// * `lhs_len` / `lhs_ptr` - byte length and start of the first (left) string.
/// * `rhs_len` / `rhs_ptr` - byte length and start of the second (right) string.
///
/// Returns `-1`, `0` or `1` when the closure reports the left string as `Less`,
/// `Equal` or `Greater` than the right string. The closure is called as
/// `f(left, right)` so that the sign handed back to SQLite describes the first string
/// relative to the second, which is what SQLite's collation contract requires.
///
/// The process is aborted (after a message on stderr where one is available) if the
/// user pointer or a string pointer is null, a length is negative, a string is not
/// valid UTF-8, or the closure panics. Unwinding out of an `extern "C"` function is
/// not an option, and there is no error channel back to SQLite from a collation.
// NOTE(review): blanket `allow` kept from the original; narrow it to the specific
// lint(s) it was meant to silence once those are known.
#[allow(warnings)]
extern "C" fn run_collation_function<F>(
    user_ptr: *mut libc::c_void,
    lhs_len: libc::c_int,
    lhs_ptr: *const libc::c_void,
    rhs_len: libc::c_int,
    rhs_ptr: *const libc::c_void,
) -> libc::c_int
where
    F: Fn(&str, &str) -> std::cmp::Ordering + Send + 'static,
{
    let user_ptr = user_ptr as *const F;
    // SAFETY: a non-null `user_ptr` is the boxed closure that was registered together
    // with this callback and stays alive until the matching destructor runs.
    let f = unsafe { user_ptr.as_ref() }.unwrap_or_else(|| {
        eprintln!(
            "An unknown error occurred. user_ptr is a null pointer. This should never happen."
        );
        std::process::abort();
    });

    for (ptr, len, side) in &[(rhs_ptr, rhs_len, "rhs"), (lhs_ptr, lhs_len, "lhs")] {
        if *len < 0 {
            eprintln!(
                "An unknown error occurred. {}_len is negative. This should never happen.",
                side
            );
            std::process::abort();
        }
        if ptr.is_null() {
            eprintln!(
                "An unknown error occurred. {}_ptr is a null pointer. This should never happen.",
                side
            );
            std::process::abort();
        }
    }

    // SAFETY: both pointers are non-null and both lengths are non-negative (checked
    // above); SQLite provides buffers of exactly that many bytes which remain valid for
    // the duration of this call.
    let (lhs_bytes, rhs_bytes) = unsafe {
        (
            slice::from_raw_parts(lhs_ptr as *const u8, lhs_len as usize),
            slice::from_raw_parts(rhs_ptr as *const u8, rhs_len as usize),
        )
    };

    // The collation is registered with SQLITE_UTF8, but SQLite does not validate the
    // text it stores, so the bytes have to be checked: building a `&str` from
    // malformed data would be undefined behaviour inside safe user code.
    let lhs = str::from_utf8(lhs_bytes).unwrap_or_else(|_| {
        eprintln!("An error occurred. lhs is not valid UTF-8 and cannot be passed to the collation.");
        std::process::abort();
    });
    let rhs = str::from_utf8(rhs_bytes).unwrap_or_else(|_| {
        eprintln!("An error occurred. rhs is not valid UTF-8 and cannot be passed to the collation.");
        std::process::abort();
    });

    // It doesn't matter if f is UnwindSafe, since we abort on panic.
    let compare = std::panic::AssertUnwindSafe(|| match f(lhs, rhs) {
        std::cmp::Ordering::Greater => 1,
        std::cmp::Ordering::Equal => 0,
        std::cmp::Ordering::Less => -1,
    });
    std::panic::catch_unwind(compare).unwrap_or_else(|_| std::process::abort())
}
/// Destructor callback with C ABI: takes back ownership of the `Box<F>` behind `data`
/// and drops it, running `F`'s destructor and freeing the allocation.
///
/// `data` is not checked in any way; it has to be a pointer that came from
/// `Box::into_raw` on a `Box<F>` and that has not been freed yet.
extern "C" fn destroy_boxed<F>(data: *mut libc::c_void) {
    // SAFETY: relies on the caller contract described above.
    let reclaimed: Box<F> = unsafe { Box::from_raw(data as *mut F) };
    drop(reclaimed);
}