update job handle to ensure file count is consistent

Co-authored-by: Piotr <piotr@zed.dev>
This commit is contained in:
KCaverly 2023-08-21 12:47:43 +02:00
parent 1cae4758cc
commit def215af9f

View File

@ -98,9 +98,16 @@ struct ProjectState {
#[derive(Clone)]
struct JobHandle {
tx: Weak<Mutex<watch::Sender<usize>>>,
tx: Arc<Weak<Mutex<watch::Sender<usize>>>>,
}
impl JobHandle {
fn new(tx: &Arc<Mutex<watch::Sender<usize>>>) -> Self {
Self {
tx: Arc::new(Arc::downgrade(&tx)),
}
}
}
impl ProjectState {
fn db_id_for_worktree_id(&self, id: WorktreeId) -> Option<i64> {
self.worktree_db_ids
@ -651,7 +658,7 @@ impl SemanticIndex {
count += 1;
*job_count_tx.lock().borrow_mut() += 1;
let job_handle = JobHandle {
tx: Arc::downgrade(&job_count_tx),
tx: Arc::new(Arc::downgrade(&job_count_tx)),
};
parsing_files_tx
.try_send(PendingFile {
@ -726,6 +733,7 @@ impl SemanticIndex {
let database_url = self.database_url.clone();
let fs = self.fs.clone();
cx.spawn(|this, mut cx| async move {
let t0 = Instant::now();
let database = VectorDatabase::new(fs.clone(), database_url.clone()).await?;
let phrase_embedding = embedding_provider
@ -735,6 +743,11 @@ impl SemanticIndex {
.next()
.unwrap();
log::trace!(
"Embedding search phrase took: {:?} milliseconds",
t0.elapsed().as_millis()
);
let file_ids =
database.retrieve_included_file_ids(&worktree_db_ids, &includes, &excludes)?;
@ -809,6 +822,11 @@ impl SemanticIndex {
let buffers = futures::future::join_all(tasks).await;
log::trace!(
"Semantic Searching took: {:?} milliseconds in total",
t0.elapsed().as_millis()
);
Ok(buffers
.into_iter()
.zip(ranges)
@ -830,12 +848,25 @@ impl Entity for SemanticIndex {
impl Drop for JobHandle {
fn drop(&mut self) {
if let Some(tx) = self.tx.upgrade() {
let mut tx = tx.lock();
// Manage for overflow, cause we are cloning the Job Handle
if *tx.borrow() > 0 {
if let Some(inner) = Arc::get_mut(&mut self.tx) {
if let Some(tx) = inner.upgrade() {
let mut tx = tx.lock();
*tx.borrow_mut() -= 1;
};
}
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_job_handle() {
let (job_count_tx, job_count_rx) = watch::channel_with(0);
let tx = Arc::new(Mutex::new(job_count_tx));
let job_handle = JobHandle::new(tx);
assert_eq!(1, *job_count_rx.borrow_mut());
}
}