From a29ccb4ff83cd764182caebd092c81cdbc729499 Mon Sep 17 00:00:00 2001 From: Kay Simmons Date: Wed, 30 Nov 2022 10:54:01 -0800 Subject: [PATCH] make thread safe connection more thread safe Co-Authored-By: Mikayla Maki --- Cargo.lock | 2 + crates/db/Cargo.toml | 1 + crates/db/src/db.rs | 32 ++- crates/db/src/kvp.rs | 6 +- crates/sqlez/Cargo.toml | 1 + crates/sqlez/src/migrations.rs | 6 +- crates/sqlez/src/thread_safe_connection.rs | 230 +++++++++++++-------- crates/sqlez/src/util.rs | 4 + crates/sqlez_macros/src/sqlez_macros.rs | 2 +- crates/workspace/src/persistence.rs | 14 +- crates/workspace/src/workspace.rs | 17 +- crates/zed/src/zed.rs | 5 +- 12 files changed, 196 insertions(+), 124 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 9e3181575f..fd1bb4ea0a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1569,6 +1569,7 @@ dependencies = [ "log", "parking_lot 0.11.2", "serde", + "smol", "sqlez", "sqlez_macros", "tempdir", @@ -5596,6 +5597,7 @@ dependencies = [ "lazy_static", "libsqlite3-sys", "parking_lot 0.11.2", + "smol", "thread_local", ] diff --git a/crates/db/Cargo.toml b/crates/db/Cargo.toml index 2d88d4ece5..69c90e02f9 100644 --- a/crates/db/Cargo.toml +++ b/crates/db/Cargo.toml @@ -23,6 +23,7 @@ lazy_static = "1.4.0" log = { version = "0.4.16", features = ["kv_unstable_serde"] } parking_lot = "0.11.1" serde = { version = "1.0", features = ["derive"] } +smol = "1.2" [dev-dependencies] gpui = { path = "../gpui", features = ["test-support"] } diff --git a/crates/db/src/db.rs b/crates/db/src/db.rs index adf6f5c035..701aa57656 100644 --- a/crates/db/src/db.rs +++ b/crates/db/src/db.rs @@ -4,31 +4,36 @@ pub mod kvp; pub use anyhow; pub use indoc::indoc; pub use lazy_static; +pub use smol; pub use sqlez; pub use sqlez_macros; use sqlez::domain::Migrator; use sqlez::thread_safe_connection::ThreadSafeConnection; +use sqlez_macros::sql; use std::fs::{create_dir_all, remove_dir_all}; use std::path::Path; use std::sync::atomic::{AtomicBool, Ordering}; use 
util::channel::{ReleaseChannel, RELEASE_CHANNEL, RELEASE_CHANNEL_NAME}; use util::paths::DB_DIR; -const INITIALIZE_QUERY: &'static str = indoc! {" - PRAGMA journal_mode=WAL; +const CONNECTION_INITIALIZE_QUERY: &'static str = sql!( PRAGMA synchronous=NORMAL; PRAGMA busy_timeout=1; PRAGMA foreign_keys=TRUE; PRAGMA case_sensitive_like=TRUE; -"}; +); + +const DB_INITIALIZE_QUERY: &'static str = sql!( + PRAGMA journal_mode=WAL; +); lazy_static::lazy_static! { static ref DB_WIPED: AtomicBool = AtomicBool::new(false); } /// Open or create a database at the given directory path. -pub fn open_file_db() -> ThreadSafeConnection { +pub async fn open_file_db() -> ThreadSafeConnection { // Use 0 for now. Will implement incrementing and clearing of old db files soon TM let current_db_dir = (*DB_DIR).join(Path::new(&format!("0-{}", *RELEASE_CHANNEL_NAME))); @@ -43,12 +48,19 @@ pub fn open_file_db() -> ThreadSafeConnection { create_dir_all(&current_db_dir).expect("Should be able to create the database directory"); let db_path = current_db_dir.join(Path::new("db.sqlite")); - ThreadSafeConnection::new(db_path.to_string_lossy().as_ref(), true) - .with_initialize_query(INITIALIZE_QUERY) + ThreadSafeConnection::::builder(db_path.to_string_lossy().as_ref(), true) + .with_db_initialization_query(DB_INITIALIZE_QUERY) + .with_connection_initialize_query(CONNECTION_INITIALIZE_QUERY) + .build() + .await } -pub fn open_memory_db(db_name: &str) -> ThreadSafeConnection { - ThreadSafeConnection::new(db_name, false).with_initialize_query(INITIALIZE_QUERY) +pub async fn open_memory_db(db_name: &str) -> ThreadSafeConnection { + ThreadSafeConnection::::builder(db_name, false) + .with_db_initialization_query(DB_INITIALIZE_QUERY) + .with_connection_initialize_query(CONNECTION_INITIALIZE_QUERY) + .build() + .await } /// Implements a basic DB wrapper for a given domain @@ -67,9 +79,9 @@ macro_rules! connection { ::db::lazy_static::lazy_static! 
{ pub static ref $id: $t = $t(if cfg!(any(test, feature = "test-support")) { - ::db::open_memory_db(stringify!($id)) + $crate::smol::block_on(::db::open_memory_db(stringify!($id))) } else { - ::db::open_file_db() + $crate::smol::block_on(::db::open_file_db()) }); } }; diff --git a/crates/db/src/kvp.rs b/crates/db/src/kvp.rs index b3f2a716cb..da796fa469 100644 --- a/crates/db/src/kvp.rs +++ b/crates/db/src/kvp.rs @@ -15,9 +15,9 @@ impl std::ops::Deref for KeyValueStore { lazy_static::lazy_static! { pub static ref KEY_VALUE_STORE: KeyValueStore = KeyValueStore(if cfg!(any(test, feature = "test-support")) { - open_memory_db(stringify!($id)) + smol::block_on(open_memory_db("KEY_VALUE_STORE")) } else { - open_file_db() + smol::block_on(open_file_db()) }); } @@ -62,7 +62,7 @@ mod tests { #[gpui::test] async fn test_kvp() { - let db = KeyValueStore(crate::open_memory_db("test_kvp")); + let db = KeyValueStore(crate::open_memory_db("test_kvp").await); assert_eq!(db.read_kvp("key-1").unwrap(), None); diff --git a/crates/sqlez/Cargo.toml b/crates/sqlez/Cargo.toml index cab1af7d6c..8a7f1ba415 100644 --- a/crates/sqlez/Cargo.toml +++ b/crates/sqlez/Cargo.toml @@ -9,6 +9,7 @@ edition = "2021" anyhow = { version = "1.0.38", features = ["backtrace"] } indoc = "1.0.7" libsqlite3-sys = { version = "0.25.2", features = ["bundled"] } +smol = "1.2" thread_local = "1.1.4" lazy_static = "1.4" parking_lot = "0.11.1" diff --git a/crates/sqlez/src/migrations.rs b/crates/sqlez/src/migrations.rs index 6c0aafaf20..41c505f85b 100644 --- a/crates/sqlez/src/migrations.rs +++ b/crates/sqlez/src/migrations.rs @@ -15,9 +15,9 @@ impl Connection { // Setup the migrations table unconditionally self.exec(indoc! 
{" CREATE TABLE IF NOT EXISTS migrations ( - domain TEXT, - step INTEGER, - migration TEXT + domain TEXT, + step INTEGER, + migration TEXT )"})?()?; let completed_migrations = diff --git a/crates/sqlez/src/thread_safe_connection.rs b/crates/sqlez/src/thread_safe_connection.rs index 6c35d1e945..880a58d194 100644 --- a/crates/sqlez/src/thread_safe_connection.rs +++ b/crates/sqlez/src/thread_safe_connection.rs @@ -1,4 +1,4 @@ -use futures::{Future, FutureExt}; +use futures::{channel::oneshot, Future, FutureExt}; use lazy_static::lazy_static; use parking_lot::RwLock; use std::{collections::HashMap, marker::PhantomData, ops::Deref, sync::Arc, thread}; @@ -10,17 +10,25 @@ use crate::{ util::UnboundedSyncSender, }; -type QueuedWrite = Box; +const MIGRATION_RETRIES: usize = 10; +type QueuedWrite = Box; lazy_static! { + /// List of queues of tasks by database uri. This lets us serialize writes to the database + /// and have a single worker thread per db file. This means many thread safe connections + /// (possibly with different migrations) could all be communicating with the same background + /// thread. static ref QUEUES: RwLock, UnboundedSyncSender>> = Default::default(); } +/// Thread safe connection to a given database file or in memory db. This can be cloned, shared, static, +/// whatever. It derefs to a synchronous connection by thread that is read only. 
A write capable connection +/// may be accessed by passing a callback to the `write` function which will queue the callback pub struct ThreadSafeConnection { uri: Arc, persistent: bool, - initialize_query: Option<&'static str>, + connection_initialize_query: Option<&'static str>, connections: Arc>, _migrator: PhantomData, } @@ -28,87 +36,125 @@ pub struct ThreadSafeConnection { unsafe impl Send for ThreadSafeConnection {} unsafe impl Sync for ThreadSafeConnection {} -impl ThreadSafeConnection { - pub fn new(uri: &str, persistent: bool) -> Self { - Self { - uri: Arc::from(uri), - persistent, - initialize_query: None, - connections: Default::default(), - _migrator: PhantomData, - } +pub struct ThreadSafeConnectionBuilder { + db_initialize_query: Option<&'static str>, + connection: ThreadSafeConnection, +} + +impl ThreadSafeConnectionBuilder { + /// Sets the query to run every time a connection is opened. This must + /// be infallible (EG only use pragma statements) and not cause writes. + /// to the db or it will panic. + pub fn with_connection_initialize_query(mut self, initialize_query: &'static str) -> Self { + self.connection.connection_initialize_query = Some(initialize_query); + self } - /// Sets the query to run every time a connection is opened. This must - /// be infallible (EG only use pragma statements) - pub fn with_initialize_query(mut self, initialize_query: &'static str) -> Self { - self.initialize_query = Some(initialize_query); + /// Queues an initialization query for the database file. 
This must be infallible + /// but may cause changes to the database file such as with `PRAGMA journal_mode` + pub fn with_db_initialization_query(mut self, initialize_query: &'static str) -> Self { + self.db_initialize_query = Some(initialize_query); self } + pub async fn build(self) -> ThreadSafeConnection { + let db_initialize_query = self.db_initialize_query; + + self.connection + .write(move |connection| { + if let Some(db_initialize_query) = db_initialize_query { + connection.exec(db_initialize_query).expect(&format!( + "Db initialize query failed to execute: {}", + db_initialize_query + ))() + .unwrap(); + } + + let mut failure_result = None; + for _ in 0..MIGRATION_RETRIES { + failure_result = Some(M::migrate(connection)); + if failure_result.as_ref().unwrap().is_ok() { + break; + } + } + + failure_result.unwrap().expect("Migration failed"); + }) + .await; + + self.connection + } +} + +impl ThreadSafeConnection { + pub fn builder(uri: &str, persistent: bool) -> ThreadSafeConnectionBuilder { + ThreadSafeConnectionBuilder:: { + db_initialize_query: None, + connection: Self { + uri: Arc::from(uri), + persistent, + connection_initialize_query: None, + connections: Default::default(), + _migrator: PhantomData, + }, + } + } + /// Opens a new db connection with the initialized file path. This is internal and only /// called from the deref function. - /// If opening fails, the connection falls back to a shared memory connection fn open_file(&self) -> Connection { - // This unwrap is secured by a panic in the constructor. Be careful if you remove it! Connection::open_file(self.uri.as_ref()) } - /// Opens a shared memory connection using the file path as the identifier. This unwraps - /// as we expect it always to succeed + /// Opens a shared memory connection using the file path as the identifier. This is internal + /// and only called from the deref function. 
fn open_shared_memory(&self) -> Connection { Connection::open_memory(Some(self.uri.as_ref())) } - // Open a new connection for the given domain, leaving this - // connection intact. - pub fn for_domain(&self) -> ThreadSafeConnection { - ThreadSafeConnection { - uri: self.uri.clone(), - persistent: self.persistent, - initialize_query: self.initialize_query, - connections: Default::default(), - _migrator: PhantomData, - } - } - - pub fn write( - &self, - callback: impl 'static + Send + FnOnce(&Connection) -> T, - ) -> impl Future { + fn queue_write_task(&self, callback: QueuedWrite) { // Startup write thread for this database if one hasn't already // been started and insert a channel to queue work for it if !QUEUES.read().contains_key(&self.uri) { - use std::sync::mpsc::channel; - - let (sender, reciever) = channel::(); - let mut write_connection = self.create_connection(); - // Enable writes for this connection - write_connection.write = true; - thread::spawn(move || { - while let Ok(write) = reciever.recv() { - write(&write_connection) - } - }); - let mut queues = QUEUES.write(); - queues.insert(self.uri.clone(), UnboundedSyncSender::new(sender)); + if !queues.contains_key(&self.uri) { + use std::sync::mpsc::channel; + + let (sender, reciever) = channel::(); + let mut write_connection = self.create_connection(); + // Enable writes for this connection + write_connection.write = true; + thread::spawn(move || { + while let Ok(write) = reciever.recv() { + write(&write_connection) + } + }); + + queues.insert(self.uri.clone(), UnboundedSyncSender::new(sender)); + } } // Grab the queue for this database let queues = QUEUES.read(); let write_channel = queues.get(&self.uri).unwrap(); + write_channel + .send(callback) + .expect("Could not send write action to backgorund thread"); + } + + pub fn write( + &self, + callback: impl 'static + Send + FnOnce(&Connection) -> T, + ) -> impl Future { // Create a one shot channel for the result of the queued write // so we can await on 
the result - let (sender, reciever) = futures::channel::oneshot::channel(); - write_channel - .send(Box::new(move |connection| { - sender.send(callback(connection)).ok(); - })) - .expect("Could not send write action to background thread"); + let (sender, reciever) = oneshot::channel(); + self.queue_write_task(Box::new(move |connection| { + sender.send(callback(connection)).ok(); + })); - reciever.map(|response| response.expect("Background thread unexpectedly closed")) + reciever.map(|response| response.expect("Background writer thread unexpectedly closed")) } pub(crate) fn create_connection(&self) -> Connection { @@ -118,10 +164,11 @@ impl ThreadSafeConnection { self.open_shared_memory() }; - // Enable writes for the migrations and initialization queries - connection.write = true; + // Disallow writes on the connection. The only writes allowed for thread safe connections + // are from the background thread that can serialize them. + connection.write = false; - if let Some(initialize_query) = self.initialize_query { + if let Some(initialize_query) = self.connection_initialize_query { connection.exec(initialize_query).expect(&format!( "Initialize query failed to execute: {}", initialize_query @@ -129,20 +176,34 @@ impl ThreadSafeConnection { .unwrap() } - M::migrate(&connection).expect("Migrations failed"); - - // Disable db writes for normal thread local connection - connection.write = false; connection } } +impl ThreadSafeConnection<()> { + /// Special constructor for ThreadSafeConnection which disallows db initialization and migrations. + /// This allows construction to be infallible and not write to the db. 
+ pub fn new( + uri: &str, + persistent: bool, + connection_initialize_query: Option<&'static str>, + ) -> Self { + Self { + uri: Arc::from(uri), + persistent, + connection_initialize_query, + connections: Default::default(), + _migrator: PhantomData, + } + } +} + impl Clone for ThreadSafeConnection { fn clone(&self) -> Self { Self { uri: self.uri.clone(), persistent: self.persistent, - initialize_query: self.initialize_query.clone(), + connection_initialize_query: self.connection_initialize_query.clone(), connections: self.connections.clone(), _migrator: PhantomData, } @@ -163,11 +224,11 @@ impl Deref for ThreadSafeConnection { #[cfg(test)] mod test { - use std::{fs, ops::Deref, thread}; + use indoc::indoc; + use lazy_static::__Deref; + use std::thread; - use crate::domain::Domain; - - use super::ThreadSafeConnection; + use crate::{domain::Domain, thread_safe_connection::ThreadSafeConnection}; #[test] fn many_initialize_and_migrate_queries_at_once() { @@ -185,27 +246,22 @@ mod test { for _ in 0..100 { handles.push(thread::spawn(|| { - let _ = ThreadSafeConnection::::new("annoying-test.db", false) - .with_initialize_query( - " - PRAGMA journal_mode=WAL; - PRAGMA synchronous=NORMAL; - PRAGMA busy_timeout=1; - PRAGMA foreign_keys=TRUE; - PRAGMA case_sensitive_like=TRUE; - ", - ) - .deref(); + let builder = + ThreadSafeConnection::::builder("annoying-test.db", false) + .with_db_initialization_query("PRAGMA journal_mode=WAL") + .with_connection_initialize_query(indoc! 
{" + PRAGMA synchronous=NORMAL; + PRAGMA busy_timeout=1; + PRAGMA foreign_keys=TRUE; + PRAGMA case_sensitive_like=TRUE; + "}); + let _ = smol::block_on(builder.build()).deref(); })); } for handle in handles { let _ = handle.join(); } - - // fs::remove_file("annoying-test.db").unwrap(); - // fs::remove_file("annoying-test.db-shm").unwrap(); - // fs::remove_file("annoying-test.db-wal").unwrap(); } #[test] @@ -241,8 +297,10 @@ mod test { } } - let _ = ThreadSafeConnection::::new("wild_zed_lost_failure", false) - .with_initialize_query("PRAGMA FOREIGN_KEYS=true") - .deref(); + let builder = + ThreadSafeConnection::::builder("wild_zed_lost_failure", false) + .with_connection_initialize_query("PRAGMA FOREIGN_KEYS=true"); + + smol::block_on(builder.build()); } } diff --git a/crates/sqlez/src/util.rs b/crates/sqlez/src/util.rs index b5366cffc4..ce0353b15e 100644 --- a/crates/sqlez/src/util.rs +++ b/crates/sqlez/src/util.rs @@ -4,6 +4,10 @@ use std::sync::mpsc::Sender; use parking_lot::Mutex; use thread_local::ThreadLocal; +/// Unbounded standard library sender which is stored per thread to get around +/// the lack of sync on the standard library version while still being unbounded +/// Note: this locks on the cloneable sender, but its done once per thread, so it +/// shouldn't result in too much contention pub struct UnboundedSyncSender { clonable_sender: Mutex>, local_senders: ThreadLocal>, diff --git a/crates/sqlez_macros/src/sqlez_macros.rs b/crates/sqlez_macros/src/sqlez_macros.rs index 25249b89b6..532503a3e6 100644 --- a/crates/sqlez_macros/src/sqlez_macros.rs +++ b/crates/sqlez_macros/src/sqlez_macros.rs @@ -3,7 +3,7 @@ use sqlez::thread_safe_connection::ThreadSafeConnection; use syn::Error; lazy_static::lazy_static! 
{ - static ref SQLITE: ThreadSafeConnection = ThreadSafeConnection::new(":memory:", false); + static ref SQLITE: ThreadSafeConnection = ThreadSafeConnection::new(":memory:", false, None); } #[proc_macro] diff --git a/crates/workspace/src/persistence.rs b/crates/workspace/src/persistence.rs index 0d35c19d5d..c8b31cd254 100644 --- a/crates/workspace/src/persistence.rs +++ b/crates/workspace/src/persistence.rs @@ -395,7 +395,7 @@ mod tests { async fn test_next_id_stability() { env_logger::try_init().ok(); - let db = WorkspaceDb(open_memory_db("test_next_id_stability")); + let db = WorkspaceDb(open_memory_db("test_next_id_stability").await); db.write(|conn| { conn.migrate( @@ -442,7 +442,7 @@ mod tests { async fn test_workspace_id_stability() { env_logger::try_init().ok(); - let db = WorkspaceDb(open_memory_db("test_workspace_id_stability")); + let db = WorkspaceDb(open_memory_db("test_workspace_id_stability").await); db.write(|conn| { conn.migrate( @@ -523,7 +523,7 @@ mod tests { async fn test_full_workspace_serialization() { env_logger::try_init().ok(); - let db = WorkspaceDb(open_memory_db("test_full_workspace_serialization")); + let db = WorkspaceDb(open_memory_db("test_full_workspace_serialization").await); let dock_pane = crate::persistence::model::SerializedPane { children: vec![ @@ -597,7 +597,7 @@ mod tests { async fn test_workspace_assignment() { env_logger::try_init().ok(); - let db = WorkspaceDb(open_memory_db("test_basic_functionality")); + let db = WorkspaceDb(open_memory_db("test_basic_functionality").await); let workspace_1 = SerializedWorkspace { id: 1, @@ -689,7 +689,7 @@ mod tests { async fn test_basic_dock_pane() { env_logger::try_init().ok(); - let db = WorkspaceDb(open_memory_db("basic_dock_pane")); + let db = WorkspaceDb(open_memory_db("basic_dock_pane").await); let dock_pane = crate::persistence::model::SerializedPane::new( vec![ @@ -714,7 +714,7 @@ mod tests { async fn test_simple_split() { env_logger::try_init().ok(); - let db = 
WorkspaceDb(open_memory_db("simple_split")); + let db = WorkspaceDb(open_memory_db("simple_split").await); // ----------------- // | 1,2 | 5,6 | @@ -766,7 +766,7 @@ mod tests { async fn test_cleanup_panes() { env_logger::try_init().ok(); - let db = WorkspaceDb(open_memory_db("test_cleanup_panes")); + let db = WorkspaceDb(open_memory_db("test_cleanup_panes").await); let center_pane = SerializedPaneGroup::Group { axis: gpui::Axis::Horizontal, diff --git a/crates/workspace/src/workspace.rs b/crates/workspace/src/workspace.rs index 9755c2c6ca..584f6392d1 100644 --- a/crates/workspace/src/workspace.rs +++ b/crates/workspace/src/workspace.rs @@ -162,11 +162,7 @@ pub fn init(app_state: Arc, cx: &mut MutableAppContext) { let app_state = Arc::downgrade(&app_state); move |_: &NewFile, cx: &mut MutableAppContext| { if let Some(app_state) = app_state.upgrade() { - let task = open_new(&app_state, cx); - cx.spawn(|_| async { - task.await; - }) - .detach(); + open_new(&app_state, cx).detach(); } } }); @@ -174,11 +170,7 @@ pub fn init(app_state: Arc, cx: &mut MutableAppContext) { let app_state = Arc::downgrade(&app_state); move |_: &NewWindow, cx: &mut MutableAppContext| { if let Some(app_state) = app_state.upgrade() { - let task = open_new(&app_state, cx); - cx.spawn(|_| async { - task.await; - }) - .detach(); + open_new(&app_state, cx).detach(); } } }); @@ -2641,13 +2633,16 @@ pub fn open_paths( }) } -fn open_new(app_state: &Arc, cx: &mut MutableAppContext) -> Task<()> { +pub fn open_new(app_state: &Arc, cx: &mut MutableAppContext) -> Task<()> { let task = Workspace::new_local(Vec::new(), app_state.clone(), cx); cx.spawn(|mut cx| async move { + eprintln!("Open new task spawned"); let (workspace, opened_paths) = task.await; + eprintln!("workspace and path items created"); workspace.update(&mut cx, |_, cx| { if opened_paths.is_empty() { + eprintln!("new file redispatched"); cx.dispatch_action(NewFile); } }) diff --git a/crates/zed/src/zed.rs b/crates/zed/src/zed.rs index 
3693a5e580..0a25cfb66f 100644 --- a/crates/zed/src/zed.rs +++ b/crates/zed/src/zed.rs @@ -626,7 +626,7 @@ mod tests { use theme::ThemeRegistry; use workspace::{ item::{Item, ItemHandle}, - open_paths, pane, NewFile, Pane, SplitDirection, WorkspaceHandle, + open_new, open_paths, pane, NewFile, Pane, SplitDirection, WorkspaceHandle, }; #[gpui::test] @@ -762,8 +762,7 @@ mod tests { #[gpui::test] async fn test_new_empty_workspace(cx: &mut TestAppContext) { let app_state = init(cx); - cx.dispatch_global_action(workspace::NewFile); - cx.foreground().run_until_parked(); + cx.update(|cx| open_new(&app_state, cx)).await; let window_id = *cx.window_ids().first().unwrap(); let workspace = cx.root_view::(window_id).unwrap();