use anyhow::Result;
use dashmap::DashMap;
use rusqlite::Connection;
use std::collections::{HashMap, HashSet, VecDeque};
use std::sync::Arc;
use tokio::fs;
use tokio::sync::Mutex;

use lib::types::core::*;

// First-keyword allowlists used to classify an incoming SQL statement as a
// read or a write before it is executed.
lazy_static::lazy_static! {
    static ref READ_KEYWORDS: HashSet<String> = {
        let mut set = HashSet::new();
        let keywords = ["ANALYZE", "ATTACH", "BEGIN", "EXPLAIN", "PRAGMA", "SELECT", "VALUES", "WITH"];
        for &keyword in &keywords {
            set.insert(keyword.to_string());
        }
        set
    };

    static ref WRITE_KEYWORDS: HashSet<String> = {
        let mut set = HashSet::new();
        let keywords = ["ALTER", "ANALYZE", "COMMIT", "CREATE", "DELETE", "DETACH", "DROP", "END", "INSERT", "REINDEX", "RELEASE", "RENAME", "REPLACE", "ROLLBACK", "SAVEPOINT", "UPDATE", "VACUUM"];
        for &keyword in &keywords {
            set.insert(keyword.to_string());
        }
        set
    };
}

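// Requests are serde_json-encoded `SqliteRequest` values (package_id, db,
// action); bound SQL parameters, when present, arrive as a JSON array in the
// message's lazy_load_blob, e.g. `[1, "some text", null]` (see
// `get_json_params` below). The exact field encodings are defined by
// `SqliteRequest` / `SqliteAction` in lib::types::core, not in this file.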
pub async fn sqlite(
    our_node: String,
    send_to_loop: MessageSender,
    send_to_terminal: PrintSender,
    mut recv_from_loop: MessageReceiver,
    send_to_caps_oracle: CapMessageSender,
    home_directory_path: String,
) -> anyhow::Result<()> {
    let sqlite_path = format!("{}/sqlite", &home_directory_path);

    if let Err(e) = fs::create_dir_all(&sqlite_path).await {
        panic!("failed creating sqlite dir! {:?}", e);
    }

    let open_dbs: Arc<DashMap<(PackageId, String), Mutex<Connection>>> = Arc::new(DashMap::new());
    let txs: Arc<DashMap<u64, Vec<(String, Vec<SqlValue>)>>> = Arc::new(DashMap::new());

    let mut process_queues: HashMap<ProcessId, Arc<Mutex<VecDeque<KernelMessage>>>> =
        HashMap::new();

    loop {
        tokio::select! {
            Some(km) = recv_from_loop.recv() => {
                if our_node.clone() != km.source.node {
                    println!(
                        "sqlite: request must come from our_node={}, got: {}",
                        our_node,
                        km.source.node,
                    );
                    continue;
                }

                // each process gets its own FIFO queue, so requests from the
                // same process are handled in order while requests from
                // different processes can proceed concurrently
                let queue = process_queues
                    .entry(km.source.process.clone())
                    .or_insert_with(|| Arc::new(Mutex::new(VecDeque::new())))
                    .clone();

                {
                    let mut queue_lock = queue.lock().await;
                    queue_lock.push_back(km.clone());
                }

                // clone Arcs for the spawned handler
                let our_node = our_node.clone();
                let send_to_caps_oracle = send_to_caps_oracle.clone();
                let send_to_terminal = send_to_terminal.clone();
                let send_to_loop = send_to_loop.clone();
                let open_dbs = open_dbs.clone();

                let txs = txs.clone();
                let sqlite_path = sqlite_path.clone();

                tokio::spawn(async move {
                    // pop one queued message and handle it; any error is sent
                    // back to the requester as an SqliteResponse::Err
                    let mut queue_lock = queue.lock().await;
                    if let Some(km) = queue_lock.pop_front() {
                        if let Err(e) = handle_request(
                            our_node.clone(),
                            km.clone(),
                            open_dbs.clone(),
                            txs.clone(),
                            send_to_loop.clone(),
                            send_to_terminal.clone(),
                            send_to_caps_oracle.clone(),
                            sqlite_path.clone(),
                        )
                        .await
                        {
                            let _ = send_to_loop
                                .send(make_error_message(our_node.clone(), &km, e))
                                .await;
                        }
                    }
                });
            }
        }
    }
}

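/// Handles a single SqliteRequest: checks capabilities, performs the action
/// against the per-(package, db) connection, and replies to the rsvp target
/// (or to the requester, if a response is expected). Query results are
/// returned as a JSON array of row objects in the response's lazy_load_blob.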
async fn handle_request(
    our_node: String,
    km: KernelMessage,
    open_dbs: Arc<DashMap<(PackageId, String), Mutex<Connection>>>,
    txs: Arc<DashMap<u64, Vec<(String, Vec<SqlValue>)>>>,
    send_to_loop: MessageSender,
    send_to_terminal: PrintSender,
    send_to_caps_oracle: CapMessageSender,
    sqlite_path: String,
) -> Result<(), SqliteError> {
    let KernelMessage {
        id,
        source,
        message,
        lazy_load_blob: blob,
        ..
    } = km.clone();
    let Message::Request(Request {
        body,
        expects_response,
        metadata,
        ..
    }) = message.clone()
    else {
        return Err(SqliteError::InputError {
            error: "not a request".into(),
        });
    };

    let request: SqliteRequest = match serde_json::from_slice(&body) {
        Ok(r) => r,
        Err(e) => {
            println!("sqlite: got invalid Request: {}", e);
            return Err(SqliteError::InputError {
                error: "didn't deserialize to SqliteRequest.".into(),
            });
        }
    };

    check_caps(
        our_node.clone(),
        source.clone(),
        open_dbs.clone(),
        send_to_caps_oracle.clone(),
        &request,
        sqlite_path.clone(),
    )
    .await?;

    let (body, bytes) = match request.action {
        SqliteAction::Open => {
            // handled in check_caps
            (serde_json::to_vec(&SqliteResponse::Ok).unwrap(), None)
        }
        SqliteAction::RemoveDb => {
            // handled in check_caps
            (serde_json::to_vec(&SqliteResponse::Ok).unwrap(), None)
        }
        SqliteAction::Read { query } => {
            let db = match open_dbs.get(&(request.package_id, request.db)) {
                Some(db) => db,
                None => {
                    return Err(SqliteError::NoDb);
                }
            };
            let db = db.lock().await;

            // only allow statements that start with a read keyword
            let first_word = query
                .split_whitespace()
                .next()
                .map(|word| word.to_uppercase())
                .unwrap_or("".to_string());
            if !READ_KEYWORDS.contains(&first_word) {
                return Err(SqliteError::NotAReadKeyword);
            }

            let parameters = get_json_params(blob)?;

            let mut statement = db.prepare(&query)?;
            let column_names: Vec<String> = statement
                .column_names()
                .iter()
                .map(|c| c.to_string())
                .collect();

            // map each row to a { column_name: json_value } object
            let results: Vec<HashMap<String, serde_json::Value>> = statement
                .query_map(rusqlite::params_from_iter(parameters.iter()), |row| {
                    let mut map = HashMap::new();
                    for (i, column_name) in column_names.iter().enumerate() {
                        let value: SqlValue = row.get(i)?;
                        let value_json = match value {
                            SqlValue::Integer(int) => serde_json::Value::Number(int.into()),
                            SqlValue::Real(real) => serde_json::Value::Number(
                                serde_json::Number::from_f64(real).unwrap(),
                            ),
                            SqlValue::Text(text) => serde_json::Value::String(text),
                            // blobs are base64-encoded for JSON transport
                            SqlValue::Blob(blob) => serde_json::Value::String(base64::encode(blob)),
                            _ => serde_json::Value::Null,
                        };
                        map.insert(column_name.clone(), value_json);
                    }
                    Ok(map)
                })?
                .collect::<Result<Vec<_>, _>>()?;

            let results = serde_json::json!(results).to_string();
            let results_bytes = results.as_bytes().to_vec();

            (
                serde_json::to_vec(&SqliteResponse::Read).unwrap(),
                Some(results_bytes),
            )
        }
        SqliteAction::Write { statement, tx_id } => {
            let db = match open_dbs.get(&(request.package_id, request.db)) {
                Some(db) => db,
                None => {
                    return Err(SqliteError::NoDb);
                }
            };
            let db = db.lock().await;

            // only allow statements that start with a write keyword
            let first_word = statement
                .split_whitespace()
                .next()
                .map(|word| word.to_uppercase())
                .unwrap_or("".to_string());

            if !WRITE_KEYWORDS.contains(&first_word) {
                return Err(SqliteError::NotAWriteKeyword);
            }

            let parameters = get_json_params(blob)?;

            match tx_id {
                // if part of a transaction, queue the statement until Commit
                Some(tx_id) => {
                    txs.entry(tx_id)
                        .or_default()
                        .push((statement.clone(), parameters));
                }
                // otherwise, execute immediately
                None => {
                    let mut stmt = db.prepare(&statement)?;
                    stmt.execute(rusqlite::params_from_iter(parameters.iter()))?;
                }
            };
            (serde_json::to_vec(&SqliteResponse::Ok).unwrap(), None)
        }
        SqliteAction::BeginTx => {
            // allocate a fresh transaction id with an empty statement queue
            let tx_id = rand::random::<u64>();
            txs.insert(tx_id, Vec::new());

            (
                serde_json::to_vec(&SqliteResponse::BeginTx { tx_id }).unwrap(),
                None,
            )
        }
        SqliteAction::Commit { tx_id } => {
            let db = match open_dbs.get(&(request.package_id, request.db)) {
                Some(db) => db,
                None => {
                    return Err(SqliteError::NoDb);
                }
            };
            let mut db = db.lock().await;

            let txs = match txs.remove(&tx_id).map(|(_, tx)| tx) {
                None => {
                    return Err(SqliteError::NoTx);
                }
                Some(tx) => tx,
            };

            // replay the queued statements inside a single sqlite transaction
            let tx = db.transaction()?;
            for (query, params) in txs {
                tx.execute(&query, rusqlite::params_from_iter(params.iter()))?;
            }

            tx.commit()?;
            (serde_json::to_vec(&SqliteResponse::Ok).unwrap(), None)
        }
        SqliteAction::Backup => {
            // flush the WAL of every open database so the on-disk files are current
            for db_ref in open_dbs.iter() {
                let db = db_ref.value().lock().await;
                let result: rusqlite::Result<()> = db
                    .query_row("PRAGMA wal_checkpoint(TRUNCATE)", [], |_| Ok(()))
                    .map(|_| ());
                if let Err(e) = result {
                    return Err(SqliteError::RusqliteError {
                        error: e.to_string(),
                    });
                }
            }
            (serde_json::to_vec(&SqliteResponse::Ok).unwrap(), None)
        }
    };

    if let Some(target) = km.rsvp.or_else(|| {
        expects_response.map(|_| Address {
            node: our_node.clone(),
            process: source.process.clone(),
        })
    }) {
        let response = KernelMessage {
            id,
            source: Address {
                node: our_node.clone(),
                process: SQLITE_PROCESS_ID.clone(),
            },
            target,
            rsvp: None,
            message: Message::Response((
                Response {
                    inherit: false,
                    body,
                    metadata,
                    capabilities: vec![],
                },
                None,
            )),
            lazy_load_blob: bytes.map(|bytes| LazyLoadBlob {
                mime: Some("application/octet-stream".into()),
                bytes,
            }),
        };

        let _ = send_to_loop.send(response).await;
    } else {
        send_to_terminal
            .send(Printout {
                verbosity: 2,
                content: format!(
                    "sqlite: not sending response: {:?}",
                    serde_json::from_slice::<SqliteResponse>(&body)
                ),
            })
            .await
            .unwrap();
    }

    Ok(())
}

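/// Enforces the capability model: Read requires the "read" capability for the
/// target db; Write, BeginTx, and Commit require the "write" capability;
/// Open and RemoveDb are restricted to the package that owns the db. Open also
/// grants read/write capabilities to the caller and opens (or creates) the
/// connection; RemoveDb closes it and deletes the database directory.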
async fn check_caps(
    our_node: String,
    source: Address,
    open_dbs: Arc<DashMap<(PackageId, String), Mutex<Connection>>>,
    mut send_to_caps_oracle: CapMessageSender,
    request: &SqliteRequest,
    sqlite_path: String,
) -> Result<(), SqliteError> {
    let (send_cap_bool, recv_cap_bool) = tokio::sync::oneshot::channel();
    let src_package_id = PackageId::new(source.process.package(), source.process.publisher());

    match &request.action {
        SqliteAction::Write { .. } | SqliteAction::BeginTx | SqliteAction::Commit { .. } => {
            // require the "write" capability for this db
            send_to_caps_oracle
                .send(CapMessage::Has {
                    on: source.process.clone(),
                    cap: Capability {
                        issuer: Address {
                            node: our_node.clone(),
                            process: SQLITE_PROCESS_ID.clone(),
                        },
                        params: serde_json::to_string(&serde_json::json!({
                            "kind": "write",
                            "db": request.db.to_string(),
                        }))
                        .unwrap(),
                    },
                    responder: send_cap_bool,
                })
                .await?;
            let has_cap = recv_cap_bool.await?;
            if !has_cap {
                return Err(SqliteError::NoCap {
                    error: request.action.to_string(),
                });
            }
            Ok(())
        }
        SqliteAction::Read { .. } => {
            // require the "read" capability for this db
            send_to_caps_oracle
                .send(CapMessage::Has {
                    on: source.process.clone(),
                    cap: Capability {
                        issuer: Address {
                            node: our_node.clone(),
                            process: SQLITE_PROCESS_ID.clone(),
                        },
                        params: serde_json::to_string(&serde_json::json!({
                            "kind": "read",
                            "db": request.db.to_string(),
                        }))
                        .unwrap(),
                    },
                    responder: send_cap_bool,
                })
                .await?;
            let has_cap = recv_cap_bool.await?;
            if !has_cap {
                return Err(SqliteError::NoCap {
                    error: request.action.to_string(),
                });
            }
            Ok(())
        }
        SqliteAction::Open => {
            // only the owning package may open its databases
            if src_package_id != request.package_id {
                return Err(SqliteError::NoCap {
                    error: request.action.to_string(),
                });
            }

            // grant the caller read and write capabilities for this db
            add_capability(
                "read",
                &request.db.to_string(),
                &our_node,
                &source,
                &mut send_to_caps_oracle,
            )
            .await?;
            add_capability(
                "write",
                &request.db.to_string(),
                &our_node,
                &source,
                &mut send_to_caps_oracle,
            )
            .await?;

            if open_dbs.contains_key(&(request.package_id.clone(), request.db.clone())) {
                return Ok(());
            }

            // open (or create) the database file and switch it to WAL mode
            let db_path = format!("{}/{}/{}", sqlite_path, request.package_id, request.db);
            fs::create_dir_all(&db_path).await?;

            let db_file_path = format!("{}/{}.db", db_path, request.db);

            let db = Connection::open(db_file_path)?;
            let _ = db.execute("PRAGMA journal_mode=WAL", []);

            open_dbs.insert(
                (request.package_id.clone(), request.db.clone()),
                Mutex::new(db),
            );
            Ok(())
        }
        SqliteAction::RemoveDb => {
            // only the owning package may remove its databases
            if src_package_id != request.package_id {
                return Err(SqliteError::NoCap {
                    error: request.action.to_string(),
                });
            }

            let db_path = format!("{}/{}/{}", sqlite_path, request.package_id, request.db);
            open_dbs.remove(&(request.package_id.clone(), request.db.clone()));

            fs::remove_dir_all(&db_path).await?;
            Ok(())
        }
        SqliteAction::Backup => {
            // no capability check; the WAL flush itself happens in handle_request
            Ok(())
        }
    }
}

async fn add_capability(
    kind: &str,
    db: &str,
    our_node: &str,
    source: &Address,
    send_to_caps_oracle: &mut CapMessageSender,
) -> Result<(), SqliteError> {
    let cap = Capability {
        issuer: Address {
            node: our_node.to_string(),
            process: SQLITE_PROCESS_ID.clone(),
        },
        params: serde_json::to_string(&serde_json::json!({ "kind": kind, "db": db })).unwrap(),
    };
    let (send_cap_bool, recv_cap_bool) = tokio::sync::oneshot::channel();
    send_to_caps_oracle
        .send(CapMessage::Add {
            on: source.process.clone(),
            caps: vec![cap],
            responder: send_cap_bool,
        })
        .await?;
    let _ = recv_cap_bool.await?;
    Ok(())
}

fn json_to_sqlite(value: &serde_json::Value) -> Result<SqlValue, SqliteError> {
    match value {
        serde_json::Value::Number(n) => {
            if let Some(int_val) = n.as_i64() {
                Ok(SqlValue::Integer(int_val))
            } else if let Some(float_val) = n.as_f64() {
                Ok(SqlValue::Real(float_val))
            } else {
                Err(SqliteError::InvalidParameters)
            }
        }
        serde_json::Value::String(s) => {
            match base64::decode(s) {
                Ok(decoded_bytes) => {
                    // convert to SQLite Blob if it's a valid base64 string
                    Ok(SqlValue::Blob(decoded_bytes))
                }
                Err(_) => {
                    // if it's not base64, just use the string itself
                    Ok(SqlValue::Text(s.clone()))
                }
            }
        }
        serde_json::Value::Bool(b) => Ok(SqlValue::Boolean(*b)),
        serde_json::Value::Null => Ok(SqlValue::Null),
        _ => Err(SqliteError::InvalidParameters),
    }
}

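// A minimal sketch of the parameter mapping above (values are illustrative):
// a blob containing `[42, 3.14, "aGVsbG8=", "plain text", null]` becomes
// Integer(42), Real(3.14), Blob(b"hello") -- because "aGVsbG8=" happens to be
// valid base64 -- Text("plain text"), and Null. Note that any string which
// decodes as base64 is treated as a blob, so callers should base64-encode
// binary data and expect that ambiguity for text parameters.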
fn get_json_params(blob: Option<LazyLoadBlob>) -> Result<Vec<SqlValue>, SqliteError> {
    match blob {
        None => Ok(vec![]),
        Some(blob) => match serde_json::from_slice::<serde_json::Value>(&blob.bytes) {
            Ok(serde_json::Value::Array(vec)) => vec
                .iter()
                .map(json_to_sqlite)
                .collect::<Result<Vec<_>, _>>(),
            _ => Err(SqliteError::InvalidParameters),
        },
    }
}

fn make_error_message(our_name: String, km: &KernelMessage, error: SqliteError) -> KernelMessage {
    KernelMessage {
        id: km.id,
        source: Address {
            node: our_name.clone(),
            process: SQLITE_PROCESS_ID.clone(),
        },
        target: match &km.rsvp {
            None => km.source.clone(),
            Some(rsvp) => rsvp.clone(),
        },
        rsvp: None,
        message: Message::Response((
            Response {
                inherit: false,
                body: serde_json::to_vec(&SqliteResponse::Err { error }).unwrap(),
                metadata: None,
                capabilities: vec![],
            },
            None,
        )),
        lazy_load_blob: None,
    }
}

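// A minimal sketch of how this runtime task might be wired up by the kernel
// (the channel names below are illustrative, not defined in this file):
//
//     tokio::spawn(sqlite(
//         our_name.clone(),
//         kernel_message_sender.clone(),
//         print_sender.clone(),
//         sqlite_message_receiver,
//         caps_oracle_sender.clone(),
//         home_directory_path.clone(),
//     ));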