From d8854442761fcb8a8df11a561cf24f831b549874 Mon Sep 17 00:00:00 2001 From: lopukhov Date: Mon, 26 Dec 2022 10:34:51 +0000 Subject: [PATCH] Replace the usage of `AtomicUsize` with `OneCell` clarifying the usage of `unsafe` in the process. - Using `usize` or `AtomicUsize` to store pointers is not recommended following the reasons detailed in the "strict provenance" proposal - Using an explicit function type and `Option` is more idiomatic - `OneCell` was already being used and it "tracks" in a more explicit way the uninitialized state of the pointer --- rust/src/lib.rs | 60 ++++++++++++++++++++++++++----------------------- 1 file changed, 32 insertions(+), 28 deletions(-) diff --git a/rust/src/lib.rs b/rust/src/lib.rs index e9663f7..c806dd1 100644 --- a/rust/src/lib.rs +++ b/rust/src/lib.rs @@ -34,8 +34,7 @@ use std::sync::atomic::{AtomicUsize, Ordering::SeqCst}; #[macro_export] macro_rules! progress { () => {{ - static COUNTER: $crate::Counter = - $crate::Counter::progress(concat!(file!(), ":", line!())); + static COUNTER: $crate::Counter = $crate::Counter::progress(concat!(file!(), ":", line!())); COUNTER.increment(); }}; ($name:expr) => {{ @@ -212,10 +211,10 @@ impl Counter { fn create_counter(&self) -> Option<&'static coz_counter_t> { let name = CString::new(self.name).unwrap(); let ptr = coz_get_counter(self.ty, &name); - if ptr.is_null() { - None - } else { - Some(unsafe { &*ptr }) + match ptr { + // SAFETY: Pointer to counter returned by `coz_get_counter` is not null and aligned. + Some(ptr) if !ptr.is_null() => Some(unsafe { &*ptr }), + _ => None, } } } @@ -224,14 +223,12 @@ impl Counter { /// coz calls to `begin` and `end` for the duration of a scope, regardless of how /// the scope was exited (e.g. by early return, `?` or panic). pub struct Guard<'t> { - counter: &'t Counter + counter: &'t Counter, } impl<'t> Guard<'t> { pub fn new(counter: &'t Counter) -> Self { - Guard { - counter - } + Guard { counter } } } @@ -247,30 +244,37 @@ struct coz_counter_t { backoff: libc::size_t, } +/// The type of `_coz_get_counter` as defined in `include/coz.h` +/// +/// `typedef coz_counter_t* (*coz_get_counter_t)(int, const char*);` +type GetCounterFn = unsafe extern "C" fn(libc::c_int, *const libc::c_char) -> *mut coz_counter_t; + #[cfg(target_os = "linux")] -fn coz_get_counter(ty: libc::c_int, name: &CStr) -> *mut coz_counter_t { - static PTR: AtomicUsize = AtomicUsize::new(1); - let mut ptr = PTR.load(SeqCst); - if ptr == 1 { +fn coz_get_counter(ty: libc::c_int, name: &CStr) -> Option<*mut coz_counter_t> { + static GET_COUNTER: OnceCell> = OnceCell::new(); + let func = GET_COUNTER.get_or_init(|| { let name = CStr::from_bytes_with_nul(b"_coz_get_counter\0").unwrap(); - ptr = unsafe { libc::dlsym(libc::RTLD_DEFAULT, name.as_ptr() as *const _) as usize }; - PTR.store(ptr, SeqCst); - } - if ptr == 0 { - return ptr::null_mut(); - } + // SAFETY: We are calling an external function that does exist in Linux. + // No specific invariants that we must uphold have been defined. + let func = unsafe { libc::dlsym(libc::RTLD_DEFAULT, name.as_ptr()) }; + if func.is_null() { + None + } else { + // SAFETY: If the pointer returned by dlsym is not null it is a valid pointer to the function + // identified by the provided symbol. The type of `_coz_get_counter` is defined in `include/coz.h` + // as [GetCounterFn]. + Some(unsafe { mem::transmute(func) }) + } + }); thread_init(); // just in case we haven't already - unsafe { - mem::transmute::< - usize, - unsafe extern "C" fn(libc::c_int, *const libc::c_char) -> *mut coz_counter_t, - >(ptr)(ty, name.as_ptr()) - } + // SAFETY: We are calling an external function which exists as it is not None + // No specific invariants that we must uphold have been defined. + func.map(|f| unsafe { f(ty, name.as_ptr()) }) } #[cfg(not(target_os = "linux"))] -fn coz_get_counter(_ty: libc::c_int, _name: &CStr) -> *mut coz_counter_t { - ptr::null_mut() +fn coz_get_counter(_ty: libc::c_int, _name: &CStr) -> Option<*mut coz_counter_t> { + None }