mirror of
https://github.com/diesel-rs/diesel.git
synced 2024-10-04 01:28:13 +03:00
Expose sqlite's sqlite3_create_collation()
This commit is contained in:
parent
7e676025dd
commit
4307a11dd3
@ -56,6 +56,8 @@ for Rust libraries in [RFC #1105](https://github.com/rust-lang/rfcs/blob/master/
|
||||
|
||||
* Added the error position for PostgreSQL errors
|
||||
|
||||
* Added ability to create custom collation functions in SQLite.
|
||||
|
||||
### Removed
|
||||
|
||||
* All previously deprecated items have been removed.
|
||||
|
@ -250,6 +250,48 @@ impl SqliteConnection {
|
||||
functions::register_aggregate::<_, _, _, _, A>(&self.raw_connection, fn_name)
|
||||
}
|
||||
|
||||
/// Register a collation function.
|
||||
///
|
||||
/// `collation` must always return the same answer given the same inputs.
|
||||
/// If `collation` panics and unwinds the stack, the process is aborted, since it is used
|
||||
/// across a C FFI boundary, which cannot be unwound across.
|
||||
///
|
||||
/// If the name is already registered it will be overwritten.
|
||||
///
|
||||
/// This method will return an error if registering the function fails, either due to an
|
||||
/// out-of-memory situation or because a collation with that name already exists and is
|
||||
/// currently being used in parallel by a query.
|
||||
///
|
||||
/// The collation needs to be specified when creating a table:
|
||||
/// `CREATE TABLE my_table ( str TEXT COLLATE MY_COLLATION )`,
|
||||
/// where `MY_COLLATION` corresponds to `collation_name`.
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```rust
|
||||
/// # include!("../../doctest_setup.rs");
|
||||
/// #
|
||||
/// # fn main() {
|
||||
/// # run_test().unwrap();
|
||||
/// # }
|
||||
/// #
|
||||
/// # fn run_test() -> QueryResult<()> {
|
||||
/// # let conn = SqliteConnection::establish(":memory:").unwrap();
|
||||
/// // sqlite NOCASE only works for ASCII characters,
|
||||
/// // this collation allows handling UTF-8 (barring locale differences)
|
||||
/// conn.register_collation("RUSTNOCASE", |rhs, lhs| {
|
||||
/// rhs.to_lowercase().cmp(&lhs.to_lowercase())
|
||||
/// })
|
||||
/// # }
|
||||
/// ```
|
||||
pub fn register_collation<F>(&self, collation_name: &str, collation: F) -> QueryResult<()>
|
||||
where
|
||||
F: Fn(&str, &str) -> std::cmp::Ordering + Send + 'static,
|
||||
{
|
||||
self.raw_connection
|
||||
.register_collation_function(collation_name, collation)
|
||||
}
|
||||
|
||||
fn register_diesel_sql_functions(&self) -> QueryResult<()> {
|
||||
use crate::sql_types::{Integer, Text};
|
||||
|
||||
@ -522,4 +564,75 @@ mod tests {
|
||||
.unwrap();
|
||||
assert_eq!(Some(3), result);
|
||||
}
|
||||
|
||||
table! {
|
||||
my_collation_example {
|
||||
id -> Integer,
|
||||
value -> Text,
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn register_collation_function() {
|
||||
use self::my_collation_example::dsl::*;
|
||||
|
||||
let connection = SqliteConnection::establish(":memory:").unwrap();
|
||||
|
||||
connection
|
||||
.register_collation("RUSTNOCASE", |rhs, lhs| {
|
||||
rhs.to_lowercase().cmp(&lhs.to_lowercase())
|
||||
})
|
||||
.unwrap();
|
||||
|
||||
connection
|
||||
.execute(
|
||||
"CREATE TABLE my_collation_example (id integer primary key autoincrement, value text collate RUSTNOCASE)",
|
||||
)
|
||||
.unwrap();
|
||||
connection
|
||||
.execute("INSERT INTO my_collation_example (value) VALUES ('foo'), ('FOo'), ('f00')")
|
||||
.unwrap();
|
||||
|
||||
let result = my_collation_example
|
||||
.filter(value.eq("foo"))
|
||||
.select(value)
|
||||
.load::<String>(&connection);
|
||||
assert_eq!(
|
||||
Ok(&["foo".to_owned(), "FOo".to_owned()][..]),
|
||||
result.as_ref().map(|vec| vec.as_ref())
|
||||
);
|
||||
|
||||
let result = my_collation_example
|
||||
.filter(value.eq("FOO"))
|
||||
.select(value)
|
||||
.load::<String>(&connection);
|
||||
assert_eq!(
|
||||
Ok(&["foo".to_owned(), "FOo".to_owned()][..]),
|
||||
result.as_ref().map(|vec| vec.as_ref())
|
||||
);
|
||||
|
||||
let result = my_collation_example
|
||||
.filter(value.eq("f00"))
|
||||
.select(value)
|
||||
.load::<String>(&connection);
|
||||
assert_eq!(
|
||||
Ok(&["f00".to_owned()][..]),
|
||||
result.as_ref().map(|vec| vec.as_ref())
|
||||
);
|
||||
|
||||
let result = my_collation_example
|
||||
.filter(value.eq("F00"))
|
||||
.select(value)
|
||||
.load::<String>(&connection);
|
||||
assert_eq!(
|
||||
Ok(&["f00".to_owned()][..]),
|
||||
result.as_ref().map(|vec| vec.as_ref())
|
||||
);
|
||||
|
||||
let result = my_collation_example
|
||||
.filter(value.eq("oof"))
|
||||
.select(value)
|
||||
.load::<String>(&connection);
|
||||
assert_eq!(Ok(&[][..]), result.as_ref().map(|vec| vec.as_ref()));
|
||||
}
|
||||
}
|
||||
|
@ -102,7 +102,7 @@ impl RawConnection {
|
||||
Some(run_custom_function::<F>),
|
||||
None,
|
||||
None,
|
||||
Some(destroy_boxed_fn::<F>),
|
||||
Some(destroy_boxed::<F>),
|
||||
)
|
||||
};
|
||||
|
||||
@ -140,6 +140,35 @@ impl RawConnection {
|
||||
Self::process_sql_function_result(result)
|
||||
}
|
||||
|
||||
pub fn register_collation_function<F>(
|
||||
&self,
|
||||
collation_name: &str,
|
||||
collation: F,
|
||||
) -> QueryResult<()>
|
||||
where
|
||||
F: Fn(&str, &str) -> std::cmp::Ordering + Send + 'static,
|
||||
{
|
||||
let collation_name = Self::get_fn_name(collation_name)?;
|
||||
let callback_fn = Box::into_raw(Box::new(collation));
|
||||
|
||||
let result = unsafe {
|
||||
ffi::sqlite3_create_collation_v2(
|
||||
self.internal_connection.as_ptr(),
|
||||
collation_name.as_ptr(),
|
||||
ffi::SQLITE_UTF8,
|
||||
callback_fn as *mut _,
|
||||
Some(run_collation_function::<F>),
|
||||
Some(destroy_boxed::<F>),
|
||||
)
|
||||
};
|
||||
|
||||
let result = Self::process_sql_function_result(result);
|
||||
if result.is_err() {
|
||||
destroy_boxed::<F>(callback_fn as *mut _);
|
||||
}
|
||||
result
|
||||
}
|
||||
|
||||
fn get_fn_name(fn_name: &str) -> Result<CString, NulError> {
|
||||
Ok(CString::new(fn_name)?)
|
||||
}
|
||||
@ -379,12 +408,63 @@ unsafe fn null_aggregate_context_error(ctx: *mut ffi::sqlite3_context) {
|
||||
);
|
||||
}
|
||||
|
||||
extern "C" fn destroy_boxed_fn<F>(data: *mut libc::c_void)
|
||||
#[allow(warnings)]
|
||||
extern "C" fn run_collation_function<F>(
|
||||
user_ptr: *mut libc::c_void,
|
||||
lhs_len: libc::c_int,
|
||||
lhs_ptr: *const libc::c_void,
|
||||
rhs_len: libc::c_int,
|
||||
rhs_ptr: *const libc::c_void,
|
||||
) -> libc::c_int
|
||||
where
|
||||
F: FnMut(&RawConnection, &[*mut ffi::sqlite3_value]) -> QueryResult<SerializedValue>
|
||||
+ Send
|
||||
+ 'static,
|
||||
F: Fn(&str, &str) -> std::cmp::Ordering + Send + 'static,
|
||||
{
|
||||
let user_ptr = user_ptr as *const F;
|
||||
let f = unsafe { user_ptr.as_ref() }.unwrap_or_else(|| {
|
||||
eprintln!(
|
||||
"An unknown error occurred. user_ptr is a null pointer. This should never happen."
|
||||
);
|
||||
std::process::abort();
|
||||
});
|
||||
|
||||
for (ptr, len, side) in &[(rhs_ptr, rhs_len, "rhs"), (lhs_ptr, lhs_len, "lhs")] {
|
||||
if *len < 0 {
|
||||
eprintln!(
|
||||
"An unknown error occurred. {}_len is negative. This should never happen.",
|
||||
side
|
||||
);
|
||||
std::process::abort();
|
||||
}
|
||||
if ptr.is_null() {
|
||||
eprintln!(
|
||||
"An unknown error occurred. {}_ptr is a null pointer. This should never happen.",
|
||||
side
|
||||
);
|
||||
std::process::abort();
|
||||
}
|
||||
}
|
||||
|
||||
let (rhs, lhs) = unsafe {
|
||||
// Depending on the eTextRep-parameter to sqlite3_create_collation_v2() the strings can
|
||||
// have various encodings. register_collation_function() always selects SQLITE_UTF8, so the
|
||||
// pointers point to valid UTF-8 strings (assuming correct behavior of libsqlite3).
|
||||
(
|
||||
str::from_utf8_unchecked(slice::from_raw_parts(rhs_ptr as *const u8, rhs_len as _)),
|
||||
str::from_utf8_unchecked(slice::from_raw_parts(lhs_ptr as *const u8, lhs_len as _)),
|
||||
)
|
||||
};
|
||||
|
||||
// It doesn't matter if f is UnwindSafe, since we abort on panic.
|
||||
let f = std::panic::AssertUnwindSafe(|| match f(rhs, lhs) {
|
||||
std::cmp::Ordering::Greater => 1,
|
||||
std::cmp::Ordering::Equal => 0,
|
||||
std::cmp::Ordering::Less => -1,
|
||||
});
|
||||
let result = std::panic::catch_unwind(f);
|
||||
result.unwrap_or_else(|_| std::process::abort())
|
||||
}
|
||||
|
||||
extern "C" fn destroy_boxed<F>(data: *mut libc::c_void) {
|
||||
let ptr = data as *mut F;
|
||||
unsafe { Box::from_raw(ptr) };
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user