Merge pull request #642 from kinode-dao/dr/sqlite-overhaul

Dr/sqlite overhaul
This commit is contained in:
doria 2024-12-22 16:45:57 -05:00 committed by GitHub
commit 7a4c2d5168
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 231 additions and 187 deletions

View File

@ -4,8 +4,9 @@ use dashmap::DashMap;
use lib::types::core::{
Address, CapMessage, CapMessageSender, Capability, FdManagerRequest, KernelMessage,
LazyLoadBlob, Message, MessageReceiver, MessageSender, PackageId, PrintSender, Printout,
ProcessId, Request, Response, SqlValue, SqliteAction, SqliteError, SqliteRequest,
SqliteResponse, FD_MANAGER_PROCESS_ID, SQLITE_PROCESS_ID,
ProcessId, Request, Response, SqlValue, SqliteAction, SqliteCapabilityKind,
SqliteCapabilityParams, SqliteError, SqliteRequest, SqliteResponse, FD_MANAGER_PROCESS_ID,
SQLITE_PROCESS_ID,
};
use rusqlite::Connection;
use std::{
@ -54,51 +55,46 @@ impl SqliteState {
}
}
pub async fn open_db(&mut self, package_id: PackageId, db: String) -> Result<(), SqliteError> {
let key = (package_id.clone(), db.clone());
if self.open_dbs.contains_key(&key) {
pub async fn open_db(&mut self, key: &(PackageId, String)) -> Result<(), SqliteError> {
if self.open_dbs.contains_key(key) {
let mut access_order = self.access_order.lock().await;
access_order.remove(&key);
access_order.push_back(key);
access_order.remove(key);
access_order.push_back(key.clone());
return Ok(());
}
if self.open_dbs.len() as u64 >= self.fds_limit {
// close least recently used db
let key = self.access_order.lock().await.pop_front().unwrap();
self.remove_db(key.0, key.1).await;
let to_close = self.access_order.lock().await.pop_front().unwrap();
self.remove_db(&to_close).await;
}
#[cfg(unix)]
let db_path = self.sqlite_path.join(format!("{package_id}")).join(&db);
let db_path = self.sqlite_path.join(format!("{}", key.0)).join(&key.1);
#[cfg(target_os = "windows")]
let db_path = self
.sqlite_path
.join(format!(
"{}_{}",
package_id._package(),
package_id._publisher()
))
.join(&db);
.join(format!("{}_{}", key.0._package(), key.0._publisher()))
.join(&key.1);
fs::create_dir_all(&db_path).await?;
let db_file_path = db_path.join(format!("{}.db", db));
let db_file_path = db_path.join(format!("{}.db", key.1));
let db_conn = Connection::open(db_file_path)?;
let _: String = db_conn.query_row("PRAGMA journal_mode=WAL", [], |row| row.get(0))?;
self.open_dbs.insert(key, Mutex::new(db_conn));
self.open_dbs.insert(key.clone(), Mutex::new(db_conn));
let mut access_order = self.access_order.lock().await;
access_order.push_back((package_id, db));
access_order.push_back(key.clone());
Ok(())
}
pub async fn remove_db(&mut self, package_id: PackageId, db: String) {
self.open_dbs.remove(&(package_id.clone(), db.to_string()));
pub async fn remove_db(&mut self, key: &(PackageId, String)) {
self.open_dbs.remove(key);
let mut access_order = self.access_order.lock().await;
access_order.remove(&(package_id, db));
access_order.remove(key);
}
pub async fn remove_least_recently_used_dbs(&mut self, n: u64) {
@ -106,7 +102,7 @@ impl SqliteState {
let mut lock = self.access_order.lock().await;
let key = lock.pop_front().unwrap();
drop(lock);
self.remove_db(key.0, key.1).await;
self.remove_db(&key).await;
}
}
}
@ -176,8 +172,7 @@ pub async fn sqlite(
tokio::spawn(async move {
let mut queue_lock = queue.lock().await;
if let Some(km) = queue_lock.pop_front() {
let (km_id, km_rsvp) =
(km.id.clone(), km.rsvp.clone().unwrap_or(km.source.clone()));
let (km_id, km_rsvp) = (km.id, km.rsvp.clone().unwrap_or(km.source.clone()));
if let Err(e) = handle_request(km, &mut state, &send_to_caps_oracle).await {
Printout::new(1, SQLITE_PROCESS_ID.clone(), format!("sqlite: {e}"))
@ -226,27 +221,31 @@ async fn handle_request(
..
}) = message
else {
return Err(SqliteError::InputError {
error: "not a request".into(),
});
// we got a response -- safe to ignore
return Ok(());
};
let request: SqliteRequest = match serde_json::from_slice(&body) {
Ok(r) => r,
Err(e) => {
println!("sqlite: got invalid Request: {}", e);
return Err(SqliteError::InputError {
error: "didn't serialize to SqliteRequest.".into(),
});
println!("sqlite: got invalid request: {e}");
return Err(SqliteError::MalformedRequest);
}
};
check_caps(&source, state, send_to_caps_oracle, &request).await?;
let db_key = (request.package_id, request.db);
check_caps(
&source,
state,
send_to_caps_oracle,
&request.action,
&db_key,
)
.await?;
// always open to ensure db exists
state
.open_db(request.package_id.clone(), request.db.clone())
.await?;
state.open_db(&db_key).await?;
let (body, bytes) = match request.action {
SqliteAction::Open => {
@ -257,11 +256,11 @@ async fn handle_request(
// handled in check_caps
(serde_json::to_vec(&SqliteResponse::Ok).unwrap(), None)
}
SqliteAction::Read { query } => {
let db = match state.open_dbs.get(&(request.package_id, request.db)) {
SqliteAction::Query(query) => {
let db = match state.open_dbs.get(&db_key) {
Some(db) => db,
None => {
return Err(SqliteError::NoDb);
return Err(SqliteError::NoDb(db_key.0, db_key.1));
}
};
let db = db.lock().await;
@ -314,10 +313,10 @@ async fn handle_request(
)
}
SqliteAction::Write { statement, tx_id } => {
let db = match state.open_dbs.get(&(request.package_id, request.db)) {
let db = match state.open_dbs.get(&db_key) {
Some(db) => db,
None => {
return Err(SqliteError::NoDb);
return Err(SqliteError::NoDb(db_key.0, db_key.1));
}
};
let db = db.lock().await;
@ -359,17 +358,17 @@ async fn handle_request(
)
}
SqliteAction::Commit { tx_id } => {
let db = match state.open_dbs.get(&(request.package_id, request.db)) {
let db = match state.open_dbs.get(&db_key) {
Some(db) => db,
None => {
return Err(SqliteError::NoDb);
return Err(SqliteError::NoDb(db_key.0, db_key.1));
}
};
let mut db = db.lock().await;
let txs = match state.txs.remove(&tx_id).map(|(_, tx)| tx) {
None => {
return Err(SqliteError::NoTx);
return Err(SqliteError::NoTx(tx_id));
}
Some(tx) => tx,
};
@ -382,20 +381,6 @@ async fn handle_request(
tx.commit()?;
(serde_json::to_vec(&SqliteResponse::Ok).unwrap(), None)
}
SqliteAction::Backup => {
for db_ref in state.open_dbs.iter() {
let db = db_ref.value().lock().await;
let result: rusqlite::Result<()> = db
.query_row("PRAGMA wal_checkpoint(TRUNCATE)", [], |_| Ok(()))
.map(|_| ());
if let Err(e) = result {
return Err(SqliteError::RusqliteError {
error: e.to_string(),
});
}
}
(serde_json::to_vec(&SqliteResponse::Ok).unwrap(), None)
}
};
if let Some(target) = km.rsvp.or_else(|| expects_response.map(|_| source)) {
@ -429,128 +414,110 @@ async fn check_caps(
source: &Address,
state: &mut SqliteState,
send_to_caps_oracle: &CapMessageSender,
request: &SqliteRequest,
action: &SqliteAction,
db_key: &(PackageId, String),
) -> Result<(), SqliteError> {
let (send_cap_bool, recv_cap_bool) = tokio::sync::oneshot::channel();
let src_package_id = PackageId::new(source.process.package(), source.process.publisher());
match &request.action {
match action {
SqliteAction::Write { .. } | SqliteAction::BeginTx | SqliteAction::Commit { .. } => {
send_to_caps_oracle
let Ok(()) = send_to_caps_oracle
.send(CapMessage::Has {
on: source.process.clone(),
cap: Capability::new(
state.our.as_ref().clone(),
serde_json::json!({
"kind": "write",
"db": request.db.to_string(),
serde_json::to_string(&SqliteCapabilityParams {
kind: SqliteCapabilityKind::Write,
db_key: db_key.clone(),
})
.to_string(),
.unwrap(),
),
responder: send_cap_bool,
})
.await?;
let has_cap = recv_cap_bool.await?;
if !has_cap {
return Err(SqliteError::NoCap {
error: request.action.to_string(),
});
}
.await
else {
return Err(SqliteError::AddCapFailed);
};
let Ok(_) = recv_cap_bool.await else {
return Err(SqliteError::AddCapFailed);
};
Ok(())
}
SqliteAction::Read { .. } => {
send_to_caps_oracle
SqliteAction::Query { .. } => {
let Ok(()) = send_to_caps_oracle
.send(CapMessage::Has {
on: source.process.clone(),
cap: Capability::new(
state.our.as_ref().clone(),
serde_json::json!({
"kind": "read",
"db": request.db.to_string(),
serde_json::to_string(&SqliteCapabilityParams {
kind: SqliteCapabilityKind::Read,
db_key: db_key.clone(),
})
.to_string(),
.unwrap(),
),
responder: send_cap_bool,
})
.await?;
let has_cap = recv_cap_bool.await?;
if !has_cap {
return Err(SqliteError::NoCap {
error: request.action.to_string(),
});
}
.await
else {
return Err(SqliteError::AddCapFailed);
};
let Ok(_) = recv_cap_bool.await else {
return Err(SqliteError::AddCapFailed);
};
Ok(())
}
SqliteAction::Open => {
if src_package_id != request.package_id {
return Err(SqliteError::NoCap {
error: request.action.to_string(),
});
if src_package_id != db_key.0 {
return Err(SqliteError::MismatchingPackageId);
}
add_capability(
"read",
&request.db.to_string(),
SqliteCapabilityKind::Read,
db_key,
&state.our,
&source,
send_to_caps_oracle,
)
.await?;
add_capability(
"write",
&request.db.to_string(),
SqliteCapabilityKind::Write,
db_key,
&state.our,
&source,
send_to_caps_oracle,
)
.await?;
if state
.open_dbs
.contains_key(&(request.package_id.clone(), request.db.clone()))
{
if state.open_dbs.contains_key(db_key) {
return Ok(());
}
state
.open_db(request.package_id.clone(), request.db.clone())
.await?;
state.open_db(db_key).await?;
Ok(())
}
SqliteAction::RemoveDb => {
if src_package_id != request.package_id {
return Err(SqliteError::NoCap {
error: request.action.to_string(),
});
if src_package_id != db_key.0 {
return Err(SqliteError::MismatchingPackageId);
}
state
.remove_db(request.package_id.clone(), request.db.clone())
.await;
state.remove_db(db_key).await;
#[cfg(unix)]
let db_path = state
.sqlite_path
.join(format!("{}", request.package_id))
.join(&request.db);
.join(format!("{}", db_key.0))
.join(&db_key.1);
#[cfg(target_os = "windows")]
let db_path = state
.sqlite_path
.join(format!(
"{}_{}",
request.package_id._package(),
request.package_id._publisher()
))
.join(&request.db);
.join(format!("{}_{}", db_key.0._package(), db_key.0._publisher()))
.join(&db_key.1);
fs::remove_dir_all(&db_path).await?;
Ok(())
}
SqliteAction::Backup => {
// flushing WALs for backup
Ok(())
}
}
}
@ -559,9 +526,7 @@ async fn handle_fd_request(km: KernelMessage, state: &mut SqliteState) -> anyhow
return Err(anyhow::anyhow!("not a request"));
};
let request: FdManagerRequest = serde_json::from_slice(&body)?;
match request {
match serde_json::from_slice(&body)? {
FdManagerRequest::FdsLimit(new_fds_limit) => {
state.fds_limit = new_fds_limit;
if state.open_dbs.len() as u64 >= state.fds_limit {
@ -581,25 +546,34 @@ async fn handle_fd_request(km: KernelMessage, state: &mut SqliteState) -> anyhow
}
async fn add_capability(
kind: &str,
db: &str,
kind: SqliteCapabilityKind,
db_key: &(PackageId, String),
our: &Address,
source: &Address,
send_to_caps_oracle: &CapMessageSender,
) -> Result<(), SqliteError> {
let cap = Capability {
issuer: our.clone(),
params: serde_json::json!({ "kind": kind, "db": db }).to_string(),
params: serde_json::to_string(&SqliteCapabilityParams {
kind,
db_key: db_key.clone(),
})
.unwrap(),
};
let (send_cap_bool, recv_cap_bool) = tokio::sync::oneshot::channel();
send_to_caps_oracle
let Ok(()) = send_to_caps_oracle
.send(CapMessage::Add {
on: source.process.clone(),
caps: vec![cap],
responder: Some(send_cap_bool),
})
.await?;
let _ = recv_cap_bool.await?;
.await
else {
return Err(SqliteError::AddCapFailed);
};
let Ok(_) = recv_cap_bool.await else {
return Err(SqliteError::AddCapFailed);
};
Ok(())
}

View File

@ -1,43 +1,117 @@
use crate::types::core::{CapMessage, PackageId};
use crate::types::core::PackageId;
use rusqlite::types::{FromSql, FromSqlError, ToSql, ValueRef};
use serde::{Deserialize, Serialize};
use thiserror::Error;
/// IPC Request format for the sqlite:distro:sys runtime module.
#[derive(Debug, Serialize, Deserialize)]
/// Actions are sent to a specific SQLite database. `db` is the name,
/// `package_id` is the [`PackageId`] that created the database. Capabilities
/// are checked: you can access another process's database if it has given
/// you the read and/or write capability to do so.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct SqliteRequest {
pub package_id: PackageId,
pub db: String,
pub action: SqliteAction,
}
#[derive(Debug, Serialize, Deserialize)]
/// IPC Action format representing operations that can be performed on the
/// SQLite runtime module. These actions are included in a [`SqliteRequest`]
/// sent to the `sqlite:distro:sys` runtime module.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub enum SqliteAction {
/// Opens an existing key-value database or creates a new one if it doesn't exist.
/// Requires `package_id` in [`SqliteRequest`] to match the package ID of the sender.
/// The sender will own the database and can remove it with [`SqliteAction::RemoveDb`].
///
/// A successful open will respond with [`SqliteResponse::Ok`]. Any error will be
/// contained in the [`SqliteResponse::Err`] variant.
Open,
/// Permanently deletes the entire key-value database.
/// Requires `package_id` in [`SqliteRequest`] to match the package ID of the sender.
/// Only the owner can remove the database.
///
/// A successful remove will respond with [`SqliteResponse::Ok`]. Any error will be
/// contained in the [`SqliteResponse::Err`] variant.
RemoveDb,
/// Executes a write statement (INSERT/UPDATE/DELETE)
///
/// * `statement` - SQL statement to execute
/// * `tx_id` - Optional transaction ID
/// * blob: Vec<SqlValue> - Parameters for the SQL statement, where SqlValue can be:
/// - null
/// - boolean
/// - i64
/// - f64
/// - String
/// - Vec<u8> (binary data)
///
/// Using this action requires the sender to have the write capability
/// for the database.
///
/// A successful write will respond with [`SqliteResponse::Ok`]. Any error will be
/// contained in the [`SqliteResponse::Err`] variant.
Write {
statement: String,
tx_id: Option<u64>,
},
Read {
query: String,
},
/// Executes a read query (SELECT)
///
/// * blob: Vec<SqlValue> - Parameters for the SQL query, where SqlValue can be:
/// - null
/// - boolean
/// - i64
/// - f64
/// - String
/// - Vec<u8> (binary data)
///
/// Using this action requires the sender to have the read capability
/// for the database.
///
/// A successful query will respond with [`SqliteResponse::Query`], where the
/// response blob contains the results of the query. Any error will be contained
/// in the [`SqliteResponse::Err`] variant.
Query(String),
/// Begins a new transaction for atomic operations.
///
/// Sending this will prompt a [`SqliteResponse::BeginTx`] response with the
/// transaction ID. Any error will be contained in the [`SqliteResponse::Err`] variant.
BeginTx,
Commit {
tx_id: u64,
},
Backup,
/// Commits all operations in the specified transaction.
///
/// # Parameters
/// * `tx_id` - The ID of the transaction to commit
///
/// A successful commit will respond with [`SqliteResponse::Ok`]. Any error will be
/// contained in the [`SqliteResponse::Err`] variant.
Commit { tx_id: u64 },
}
#[derive(Debug, Serialize, Deserialize)]
#[derive(Clone, Debug, Serialize, Deserialize)]
pub enum SqliteResponse {
/// Indicates successful completion of an operation.
/// Sent in response to actions Open, RemoveDb, Write, Query, BeginTx, and Commit.
Ok,
/// Returns the results of a query.
///
/// * blob: Vec<Vec<SqlValue>> - Array of rows, where each row contains SqlValue types:
/// - null
/// - boolean
/// - i64
/// - f64
/// - String
/// - Vec<u8> (binary data)
Read,
/// Returns the transaction ID for a newly created transaction.
///
/// # Fields
/// * `tx_id` - The ID of the newly created transaction
BeginTx { tx_id: u64 },
/// Indicates an error occurred during the operation.
Err(SqliteError),
}
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
/// Used in blobs to represent array row values in SQLite.
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
pub enum SqlValue {
Integer(i64),
Real(f64),
@ -47,28 +121,50 @@ pub enum SqlValue {
Null,
}
#[derive(Debug, Serialize, Deserialize, Error)]
#[derive(Clone, Debug, Serialize, Deserialize, Error)]
pub enum SqliteError {
#[error("sqlite: DbDoesNotExist")]
NoDb,
#[error("sqlite: NoTx")]
NoTx,
#[error("sqlite: No capability: {error}")]
NoCap { error: String },
#[error("sqlite: UnexpectedResponse")]
UnexpectedResponse,
#[error("sqlite: NotAWriteKeyword")]
#[error("db [{0}, {1}] does not exist")]
NoDb(PackageId, String),
#[error("no transaction {0} found")]
NoTx(u64),
#[error("no write capability for requested DB")]
NoWriteCap,
#[error("no read capability for requested DB")]
NoReadCap,
#[error("request to open or remove DB with mismatching package ID")]
MismatchingPackageId,
#[error("failed to generate capability for new DB")]
AddCapFailed,
#[error("write statement started with non-existent write keyword")]
NotAWriteKeyword,
#[error("sqlite: NotAReadKeyword")]
#[error("read query started with non-existent read keyword")]
NotAReadKeyword,
#[error("sqlite: Invalid Parameters")]
#[error("parameters blob in read/write was misshapen or contained invalid JSON objects")]
InvalidParameters,
#[error("sqlite: IO error: {error}")]
IOError { error: String },
#[error("sqlite: rusqlite error: {error}")]
RusqliteError { error: String },
#[error("sqlite: input bytes/json/key error: {error}")]
InputError { error: String },
#[error("sqlite got a malformed request that failed to deserialize")]
MalformedRequest,
#[error("rusqlite error: {0}")]
RusqliteError(String),
#[error("IO error: {0}")]
IOError(String),
}
/// The JSON parameters contained in all capabilities issued by `sqlite:distro:sys`.
///
/// # Fields
/// * `kind` - The kind of capability, either [`SqliteCapabilityKind::Read`] or [`SqliteCapabilityKind::Write`]
/// * `db_key` - The database key, a tuple of the [`PackageId`] that created the database and the database name
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct SqliteCapabilityParams {
pub kind: SqliteCapabilityKind,
pub db_key: (PackageId, String),
}
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum SqliteCapabilityKind {
Read,
Write,
}
impl ToSql for SqlValue {
@ -101,40 +197,14 @@ impl FromSql for SqlValue {
}
}
impl std::fmt::Display for SqliteAction {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
write!(f, "{:?}", self)
}
}
impl From<std::io::Error> for SqliteError {
fn from(err: std::io::Error) -> Self {
SqliteError::IOError {
error: err.to_string(),
}
SqliteError::IOError(err.to_string())
}
}
impl From<rusqlite::Error> for SqliteError {
fn from(err: rusqlite::Error) -> Self {
SqliteError::RusqliteError {
error: err.to_string(),
}
}
}
impl From<tokio::sync::oneshot::error::RecvError> for SqliteError {
fn from(err: tokio::sync::oneshot::error::RecvError) -> Self {
SqliteError::NoCap {
error: err.to_string(),
}
}
}
impl From<tokio::sync::mpsc::error::SendError<CapMessage>> for SqliteError {
fn from(err: tokio::sync::mpsc::error::SendError<CapMessage>) -> Self {
SqliteError::NoCap {
error: err.to_string(),
}
SqliteError::RusqliteError(err.to_string())
}
}