mirror of https://github.com/zed-industries/zed.git (synced 2024-11-08 07:35:01 +03:00)
Populate project search results multi-buffer from semantic search
Co-authored-by: Kyle <kyle@zed.dev>
parent 80ef92a3e1
commit 8d0614ce74
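
This change moves semantic search out of the search view and into the ProjectSearch model: a new semantic_search method asks the SemanticIndex for matches and streams them into the results multi-buffer via stream_excerpts_with_context_lines. To support that, SearchResult now carries a buffer handle and an anchor range instead of a worktree id, symbol name, and byte range, and the vector database queries drop the now-unused document name column.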
@@ -2,7 +2,7 @@ use crate::{
     SearchOption, SelectNextMatch, SelectPrevMatch, ToggleCaseSensitive, ToggleRegex,
     ToggleWholeWord,
 };
-use anyhow::{Context, Result};
+use anyhow::Result;
 use collections::HashMap;
 use editor::{
     items::active_match_index, scroll::autoscroll::Autoscroll, Anchor, Editor, MultiBuffer,
@@ -187,6 +187,53 @@ impl ProjectSearch {
         }));
         cx.notify();
     }
 
+    fn semantic_search(&mut self, query: String, cx: &mut ModelContext<Self>) -> Option<()> {
+        let project = self.project.clone();
+        let semantic_index = SemanticIndex::global(cx)?;
+        let search_task = semantic_index.update(cx, |semantic_index, cx| {
+            semantic_index.search_project(project, query.clone(), 10, cx)
+        });
+
+        self.search_id += 1;
+        // self.active_query = Some(query);
+        self.match_ranges.clear();
+        self.pending_search = Some(cx.spawn(|this, mut cx| async move {
+            let results = search_task.await.log_err()?;
+
+            let (_task, mut match_ranges) = this.update(&mut cx, |this, cx| {
+                this.excerpts.update(cx, |excerpts, cx| {
+                    excerpts.clear(cx);
+
+                    let matches = results
+                        .into_iter()
+                        .map(|result| (result.buffer, vec![result.range]))
+                        .collect();
+
+                    excerpts.stream_excerpts_with_context_lines(matches, 3, cx)
+                })
+            });
+
+            while let Some(match_range) = match_ranges.next().await {
+                this.update(&mut cx, |this, cx| {
+                    this.match_ranges.push(match_range);
+                    while let Ok(Some(match_range)) = match_ranges.try_next() {
+                        this.match_ranges.push(match_range);
+                    }
+                    cx.notify();
+                });
+            }
+
+            this.update(&mut cx, |this, cx| {
+                this.pending_search.take();
+                cx.notify();
+            });
+
+            None
+        }));
+
+        Some(())
+    }
 }
 
 pub enum ViewEvent {
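
The async block above waits for the first streamed excerpt, then opportunistically drains any already-buffered ones with try_next before calling cx.notify(), so a burst of results triggers one redraw instead of many. A minimal standalone sketch of that pattern with the futures crate (toy channel and data, not Zed's types):

```rust
// Sketch of the batch-then-notify pattern used in `semantic_search`:
// block on the first item, then drain whatever else is already queued
// before doing the (expensive) notification. Assumes the `futures` crate.
use futures::{channel::mpsc, executor::block_on, StreamExt};

fn main() {
    let (tx, mut rx) = mpsc::unbounded::<i32>();
    for n in 0..5 {
        tx.unbounded_send(n).unwrap();
    }
    drop(tx);

    block_on(async {
        let mut batch = Vec::new();
        while let Some(first) = rx.next().await {
            batch.push(first);
            // Drain everything already buffered without awaiting again.
            while let Ok(Some(more)) = rx.try_next() {
                batch.push(more);
            }
            println!("processed a batch of {} items", batch.len());
            batch.clear();
        }
    });
}
```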
@@ -595,27 +642,9 @@ impl ProjectSearchView {
             return;
         }
 
-        let search_phrase = self.query_editor.read(cx).text(cx);
-        let project = self.model.read(cx).project.clone();
-        if let Some(semantic_index) = SemanticIndex::global(cx) {
-            let search_task = semantic_index.update(cx, |semantic_index, cx| {
-                semantic_index.search_project(project, search_phrase, 10, cx)
-            });
-            semantic.search_task = Some(cx.spawn(|this, mut cx| async move {
-                let results = search_task.await.context("search task")?;
-
-                this.update(&mut cx, |this, cx| {
-                    dbg!(&results);
-                    // TODO: Update results
-
-                    if let Some(semantic) = &mut this.semantic {
-                        semantic.search_task = None;
-                    }
-                })?;
-
-                anyhow::Ok(())
-            }));
-        }
+            let query = self.query_editor.read(cx).text(cx);
+            self.model
+                .update(cx, |model, cx| model.semantic_search(query, cx));
+            return;
         }
 
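With this hunk the view no longer spawns its own search task or dbg!-prints the results; it forwards the query to the model, so semantic matches flow through the same match_ranges plumbing as regular project search.
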
@@ -252,7 +252,7 @@ impl VectorDatabase {
         worktree_ids: &[i64],
         query_embedding: &Vec<f32>,
         limit: usize,
-    ) -> Result<Vec<(i64, PathBuf, Range<usize>, String)>> {
+    ) -> Result<Vec<(i64, PathBuf, Range<usize>)>> {
         let mut results = Vec::<(i64, f32)>::with_capacity(limit + 1);
         self.for_each_document(&worktree_ids, |id, embedding| {
             let similarity = dot(&embedding, &query_embedding);
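
The dot helper used for ranking here is not part of this diff; a plausible minimal version over two f32 slices could look like the following (an assumption about its shape; the real implementation may be vectorized):

```rust
// Hypothetical sketch of a `dot` helper like the one called above; the
// actual implementation in the repository is not shown in this diff.
fn dot(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len());
    a.iter().zip(b).map(|(x, y)| x * y).sum()
}

fn main() {
    // Cosine-style ranking reduces to a plain dot product when the
    // embeddings are already unit-length, as OpenAI embeddings are.
    let query = [0.6f32, 0.8];
    let doc = [0.8f32, 0.6];
    println!("similarity = {}", dot(&query, &doc)); // 0.96
}
```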
@@ -296,10 +296,7 @@ impl VectorDatabase {
         Ok(())
     }
 
-    fn get_documents_by_ids(
-        &self,
-        ids: &[i64],
-    ) -> Result<Vec<(i64, PathBuf, Range<usize>, String)>> {
+    fn get_documents_by_ids(&self, ids: &[i64]) -> Result<Vec<(i64, PathBuf, Range<usize>)>> {
         let mut statement = self.db.prepare(
             "
             SELECT
@@ -307,7 +304,7 @@ impl VectorDatabase {
             files.worktree_id,
             files.relative_path,
             documents.start_byte,
-            documents.end_byte, documents.name
+            documents.end_byte
         FROM
             documents, files
         WHERE
@@ -322,14 +319,13 @@ impl VectorDatabase {
                 row.get::<_, i64>(1)?,
                 row.get::<_, String>(2)?.into(),
                 row.get(3)?..row.get(4)?,
-                row.get(5)?,
             ))
         })?;
 
-        let mut values_by_id = HashMap::<i64, (i64, PathBuf, Range<usize>, String)>::default();
+        let mut values_by_id = HashMap::<i64, (i64, PathBuf, Range<usize>)>::default();
         for row in result_iter {
-            let (id, worktree_id, path, range, name) = row?;
-            values_by_id.insert(id, (worktree_id, path, range, name));
+            let (id, worktree_id, path, range) = row?;
+            values_by_id.insert(id, (worktree_id, path, range));
         }
 
         let mut results = Vec::with_capacity(ids.len());
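
The row-mapping pattern above folds the start_byte and end_byte columns into a Range<usize> while iterating. A self-contained sketch of the same idea with rusqlite, using a made-up in-memory schema (not the repository's):

```rust
// Sketch of mapping two integer columns into a Range<usize> with
// rusqlite, mirroring the `row.get(3)?..row.get(4)?` pattern above.
use anyhow::Result;
use rusqlite::Connection;
use std::ops::Range;

fn main() -> Result<()> {
    let db = Connection::open_in_memory()?;
    db.execute_batch(
        "CREATE TABLE documents (id INTEGER, start_byte INTEGER, end_byte INTEGER);
         INSERT INTO documents VALUES (1, 0, 42);",
    )?;
    let mut statement = db.prepare("SELECT id, start_byte, end_byte FROM documents")?;
    let rows = statement.query_map([], |row| {
        Ok((row.get::<_, i64>(0)?, row.get(1)?..row.get::<_, usize>(2)?))
    })?;
    for row in rows {
        let (id, range): (i64, Range<usize>) = row?;
        println!("document {id} spans bytes {range:?}");
    }
    Ok(())
}
```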
@@ -70,10 +70,6 @@ impl EmbeddingProvider for DummyEmbeddings {
 const OPENAI_INPUT_LIMIT: usize = 8190;
 
 impl OpenAIEmbeddings {
-    pub fn new(client: Arc<dyn HttpClient>, executor: Arc<Background>) -> Self {
-        Self { client, executor }
-    }
-
     fn truncate(span: String) -> String {
         let mut tokens = OPENAI_BPE_TOKENIZER.encode_with_special_tokens(span.as_ref());
         if tokens.len() > OPENAI_INPUT_LIMIT {
@@ -81,7 +77,6 @@ impl OpenAIEmbeddings {
             let result = OPENAI_BPE_TOKENIZER.decode(tokens.clone());
             if result.is_ok() {
                 let transformed = result.unwrap();
-                // assert_ne!(transformed, span);
                 return transformed;
             }
         }
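
For reference, truncate caps the input at OPENAI_INPUT_LIMIT tokens before embedding. A standalone sketch of the same logic, assuming the tiktoken-rs crate stands in for OPENAI_BPE_TOKENIZER (which is defined elsewhere in this file and not shown in the diff):

```rust
// Standalone sketch of token-limit truncation; `cl100k_base` from the
// tiktoken-rs crate is an assumption standing in for Zed's tokenizer.
use tiktoken_rs::cl100k_base;

const OPENAI_INPUT_LIMIT: usize = 8190;

fn truncate(span: String) -> String {
    let bpe = cl100k_base().unwrap();
    let mut tokens = bpe.encode_with_special_tokens(&span);
    if tokens.len() > OPENAI_INPUT_LIMIT {
        tokens.truncate(OPENAI_INPUT_LIMIT);
        // Decoding can fail on a truncated multi-byte sequence; fall back
        // to the original span in that case, as the code above does.
        if let Ok(truncated) = bpe.decode(tokens) {
            return truncated;
        }
    }
    span
}

fn main() {
    println!("{}", truncate("hello world".to_string()));
}
```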
@@ -12,7 +12,7 @@ use db::VectorDatabase;
 use embedding::{EmbeddingProvider, OpenAIEmbeddings};
 use futures::{channel::oneshot, Future};
 use gpui::{AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle, Task, WeakModelHandle};
-use language::{Language, LanguageRegistry};
+use language::{Anchor, Buffer, Language, LanguageRegistry};
 use parking_lot::Mutex;
 use parsing::{CodeContextRetriever, Document, PARSEABLE_ENTIRE_FILE_TYPES};
 use postage::watch;
@@ -93,7 +93,7 @@ pub struct SemanticIndex {
 struct ProjectState {
     worktree_db_ids: Vec<(WorktreeId, i64)>,
     outstanding_job_count_rx: watch::Receiver<usize>,
-    outstanding_job_count_tx: Arc<Mutex<watch::Sender<usize>>>,
+    _outstanding_job_count_tx: Arc<Mutex<watch::Sender<usize>>>,
 }
 
 struct JobHandle {
@@ -135,12 +135,9 @@ pub struct PendingFile {
     job_handle: JobHandle,
 }
 
 #[derive(Debug, Clone)]
 pub struct SearchResult {
-    pub worktree_id: WorktreeId,
-    pub name: String,
-    pub byte_range: Range<usize>,
-    pub file_path: PathBuf,
+    pub buffer: ModelHandle<Buffer>,
+    pub range: Range<Anchor>,
 }
 
 enum DbOperation {
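
Replacing the worktree id, name, and byte range with a buffer handle and Range<Anchor> matters because raw byte offsets go stale as soon as the buffer is edited, while anchors are repositioned with the text. A toy illustration of the stale-offset problem (plain strings, not Zed's anchor API):

```rust
// Toy model (not Zed code) of why anchors beat raw byte offsets:
// an edit before a match shifts every byte range, while an anchor-like
// marker can be repositioned from the edit itself.
fn main() {
    let text = String::from("fn aaa() {}");
    let match_range = 3..6; // "aaa"
    assert_eq!(&text[match_range.clone()], "aaa");

    // Insert "pub " at the front; the old byte range now points elsewhere.
    let edited = format!("pub {text}");
    assert_eq!(&edited[match_range.clone()], " fn"); // stale!

    // An anchor-style fix: shift ranges at/after the edit point.
    let (edit_offset, inserted) = (0, "pub ".len());
    let adjusted = if match_range.start >= edit_offset {
        match_range.start + inserted..match_range.end + inserted
    } else {
        match_range
    };
    assert_eq!(&edited[adjusted], "aaa");
}
```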
@@ -520,7 +517,7 @@ impl SemanticIndex {
                     .map(|(a, b)| (*a, *b))
                     .collect(),
                 outstanding_job_count_rx: job_count_rx.clone(),
-                outstanding_job_count_tx: job_count_tx.clone(),
+                _outstanding_job_count_tx: job_count_tx.clone(),
             },
         );
     });
@@ -623,7 +620,7 @@ impl SemanticIndex {
         let embedding_provider = self.embedding_provider.clone();
         let database_url = self.database_url.clone();
         let fs = self.fs.clone();
-        cx.spawn(|this, cx| async move {
+        cx.spawn(|this, mut cx| async move {
             let documents = cx
                 .background()
                 .spawn(async move {
@@ -640,26 +637,39 @@ impl SemanticIndex {
                 })
                 .await?;
 
-            this.read_with(&cx, |this, _| {
-                let project_state = if let Some(state) = this.projects.get(&project.downgrade()) {
-                    state
-                } else {
-                    return Err(anyhow!("project not added"));
-                };
-
-                Ok(documents
-                    .into_iter()
-                    .filter_map(|(worktree_db_id, file_path, byte_range, name)| {
-                        let worktree_id = project_state.worktree_id_for_db_id(worktree_db_id)?;
-                        Some(SearchResult {
-                            worktree_id,
-                            name,
-                            byte_range,
-                            file_path,
-                        })
-                    })
-                    .collect())
-            })
+            let mut tasks = Vec::new();
+            let mut ranges = Vec::new();
+            let weak_project = project.downgrade();
+            project.update(&mut cx, |project, cx| {
+                for (worktree_db_id, file_path, byte_range) in documents {
+                    let project_state =
+                        if let Some(state) = this.read(cx).projects.get(&weak_project) {
+                            state
+                        } else {
+                            return Err(anyhow!("project not added"));
+                        };
+                    if let Some(worktree_id) = project_state.worktree_id_for_db_id(worktree_db_id) {
+                        tasks.push(project.open_buffer((worktree_id, file_path), cx));
+                        ranges.push(byte_range);
+                    }
+                }
+                Ok(())
+            })?;
+
+            let buffers = futures::future::join_all(tasks).await;
+
+            Ok(buffers
+                .into_iter()
+                .zip(ranges)
+                .filter_map(|(buffer, range)| {
+                    let buffer = buffer.log_err()?;
+                    let range = buffer.read_with(&cx, |buffer, _| {
+                        buffer.anchor_before(range.start)..buffer.anchor_after(range.end)
+                    });
+                    Some(SearchResult { buffer, range })
+                })
+                .collect::<Vec<_>>())
         })
     }
 }
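
The rewritten tail opens every matching buffer concurrently and awaits them together, dropping any that fail to open. A minimal sketch of that join_all + filter_map shape with the futures crate (toy futures standing in for buffer opens):

```rust
// Sketch of the pattern above: kick off all the opens, await them as a
// group, and silently drop the failures.
use futures::future::join_all;

async fn demo() {
    let tasks = (0..3).map(|n| async move {
        if n == 1 { Err("failed to open") } else { Ok(n * 10) }
    });
    let results: Vec<i32> = join_all(tasks)
        .await
        .into_iter()
        .filter_map(|r: Result<i32, &str>| r.ok())
        .collect();
    assert_eq!(results, vec![0, 20]);
}

fn main() {
    futures::executor::block_on(demo());
}
```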
@@ -8,7 +8,7 @@ use crate::{
 use anyhow::Result;
 use async_trait::async_trait;
 use gpui::{Task, TestAppContext};
-use language::{Language, LanguageConfig, LanguageRegistry};
+use language::{Language, LanguageConfig, LanguageRegistry, ToOffset};
 use project::{project_settings::ProjectSettings, FakeFs, Fs, Project};
 use rand::{rngs::StdRng, Rng};
 use serde_json::json;
@@ -85,9 +85,6 @@ async fn test_semantic_index(cx: &mut TestAppContext) {
         .unwrap();
 
     let project = Project::test(fs.clone(), ["/the-root".as_ref()], cx).await;
-    let worktree_id = project.read_with(cx, |project, cx| {
-        project.worktrees(cx).next().unwrap().read(cx).id()
-    });
     let (file_count, outstanding_file_count) = store
         .update(cx, |store, cx| store.index_project(project.clone(), cx))
         .await
@@ -103,9 +100,13 @@ async fn test_semantic_index(cx: &mut TestAppContext) {
         .await
         .unwrap();
 
-    assert_eq!(search_results[0].byte_range.start, 0);
-    assert_eq!(search_results[0].name, "aaa");
-    assert_eq!(search_results[0].worktree_id, worktree_id);
+    search_results[0].buffer.read_with(cx, |buffer, _cx| {
+        assert_eq!(search_results[0].range.start.to_offset(buffer), 0);
+        assert_eq!(
+            buffer.file().unwrap().path().as_ref(),
+            Path::new("file1.rs")
+        );
+    });
 
     fs.save(
         "/the-root/src/file2.rs".as_ref(),