use std::{collections::HashMap, sync::Arc}; use serde::Serialize; use tokio::sync::{oneshot, Mutex}; use crate::{id::Id, virtual_branches::BranchId}; pub struct AskpassRequest { sender: oneshot::Sender>, } #[derive(Debug, Clone, serde::Serialize)] // This is needed to end up with a struct with either `branch_id` or `action` #[serde(untagged)] pub enum Context { Push { branch_id: Option }, Fetch { action: String }, } #[derive(Clone)] pub struct AskpassBroker { pending_requests: Arc, AskpassRequest>>>, submit_prompt_event: Arc) + Send + Sync>, } #[derive(Debug, Clone, serde::Serialize)] pub struct PromptEvent { id: Id, prompt: String, context: C, } impl AskpassBroker { pub fn init(submit_prompt: impl Fn(PromptEvent) + Send + Sync + 'static) -> Self { Self { pending_requests: Arc::new(Mutex::new(HashMap::new())), submit_prompt_event: Arc::new(submit_prompt), } } pub async fn submit_prompt(&self, prompt: String, context: Context) -> Option { let (sender, receiver) = oneshot::channel(); let id = Id::generate(); let request = AskpassRequest { sender }; self.pending_requests.lock().await.insert(id, request); (self.submit_prompt_event)(PromptEvent { id, prompt, context, }); receiver.await.unwrap() } pub async fn handle_response(&self, id: Id, response: Option) { let mut pending_requests = self.pending_requests.lock().await; if let Some(request) = pending_requests.remove(&id) { let _ = request.sender.send(response); } else { log::warn!("received response for unknown askpass request: {}", id); } } }