diff --git a/Cargo.lock b/Cargo.lock index 3ad201641a..55ab7f8d73 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -8689,6 +8689,7 @@ dependencies = [ "languages", "log", "open_ai", + "parking_lot", "project", "serde", "serde_json", diff --git a/crates/assistant2/examples/assistant_example.rs b/crates/assistant2/examples/assistant_example.rs index 260c3bc8f9..6dbbb27148 100644 --- a/crates/assistant2/examples/assistant_example.rs +++ b/crates/assistant2/examples/assistant_example.rs @@ -87,16 +87,14 @@ fn main() { let project_index = semantic_index.project_index(project.clone(), cx); - let mut tool_registry = ToolRegistry::new(); - tool_registry - .register(ProjectIndexTool::new(project_index.clone(), fs.clone())) - .context("failed to register ProjectIndexTool") - .log_err(); - - let tool_registry = Arc::new(tool_registry); - cx.open_window(WindowOptions::default(), |cx| { - cx.new_view(|cx| Example::new(language_registry, tool_registry, cx)) + let mut tool_registry = ToolRegistry::new(); + tool_registry + .register(ProjectIndexTool::new(project_index.clone(), fs.clone()), cx) + .context("failed to register ProjectIndexTool") + .log_err(); + + cx.new_view(|cx| Example::new(language_registry, Arc::new(tool_registry), cx)) }); cx.activate(true); }) diff --git a/crates/assistant2/examples/chat_with_functions.rs b/crates/assistant2/examples/chat_with_functions.rs index 6c2870e680..35dacacb78 100644 --- a/crates/assistant2/examples/chat_with_functions.rs +++ b/crates/assistant2/examples/chat_with_functions.rs @@ -135,7 +135,7 @@ impl LanguageModelTool for RollDiceTool { return Task::ready(Ok(DiceRoll { rolls })); } - fn new_view( + fn output_view( _tool_call_id: String, _input: Self::Input, result: Result, @@ -194,20 +194,20 @@ fn main() { cx.spawn(|cx| async move { cx.update(|cx| { - let mut tool_registry = ToolRegistry::new(); - tool_registry - .register(RollDiceTool::new()) - .context("failed to register DummyTool") - .log_err(); - - let tool_registry = Arc::new(tool_registry); - - println!("Tools registered"); - for definition in tool_registry.definitions() { - println!("{}", definition); - } - cx.open_window(WindowOptions::default(), |cx| { + let mut tool_registry = ToolRegistry::new(); + tool_registry + .register(RollDiceTool::new(), cx) + .context("failed to register DummyTool") + .log_err(); + + let tool_registry = Arc::new(tool_registry); + + println!("Tools registered"); + for definition in tool_registry.definitions() { + println!("{}", definition); + } + cx.new_view(|cx| Example::new(language_registry, tool_registry, cx)) }); cx.activate(true); diff --git a/crates/assistant2/examples/file_interactions.rs b/crates/assistant2/examples/file_interactions.rs index c810085b86..3d397aac7e 100644 --- a/crates/assistant2/examples/file_interactions.rs +++ b/crates/assistant2/examples/file_interactions.rs @@ -115,7 +115,7 @@ impl LanguageModelTool for FileBrowserTool { }) } - fn new_view( + fn output_view( _tool_call_id: String, _input: Self::Input, result: Result, @@ -174,20 +174,20 @@ fn main() { let fs = Arc::new(fs::RealFs::new(None)); let cwd = std::env::current_dir().expect("Failed to get current working directory"); - let mut tool_registry = ToolRegistry::new(); - tool_registry - .register(FileBrowserTool::new(fs, cwd)) - .context("failed to register FileBrowserTool") - .log_err(); - - let tool_registry = Arc::new(tool_registry); - - println!("Tools registered"); - for definition in tool_registry.definitions() { - println!("{}", definition); - } - cx.open_window(WindowOptions::default(), |cx| { + let mut tool_registry = ToolRegistry::new(); + tool_registry + .register(FileBrowserTool::new(fs, cwd), cx) + .context("failed to register FileBrowserTool") + .log_err(); + + let tool_registry = Arc::new(tool_registry); + + println!("Tools registered"); + for definition in tool_registry.definitions() { + println!("{}", definition); + } + cx.new_view(|cx| Example::new(language_registry, tool_registry, cx)) }); cx.activate(true); diff --git a/crates/assistant2/src/assistant2.rs b/crates/assistant2/src/assistant2.rs index 8204dc3654..68784862b8 100644 --- a/crates/assistant2/src/assistant2.rs +++ b/crates/assistant2/src/assistant2.rs @@ -8,22 +8,21 @@ use client::{proto, Client}; use completion_provider::*; use editor::Editor; use feature_flags::FeatureFlagAppExt as _; -use futures::{channel::oneshot, future::join_all, Future, FutureExt, StreamExt}; +use futures::{future::join_all, StreamExt}; use gpui::{ list, prelude::*, AnyElement, AppContext, AsyncWindowContext, EventEmitter, FocusHandle, - FocusableView, Global, ListAlignment, ListState, Model, Render, Task, View, WeakView, + FocusableView, Global, ListAlignment, ListState, Render, Task, View, WeakView, }; use language::{language_settings::SoftWrap, LanguageRegistry}; use open_ai::{FunctionContent, ToolCall, ToolCallContent}; -use project::Fs; use rich_text::RichText; -use semantic_index::{CloudEmbeddingProvider, ProjectIndex, SemanticIndex}; +use semantic_index::{CloudEmbeddingProvider, SemanticIndex}; use serde::Deserialize; use settings::Settings; -use std::{cmp, sync::Arc}; +use std::sync::Arc; use theme::ThemeSettings; use tools::ProjectIndexTool; -use ui::{popover_menu, prelude::*, ButtonLike, CollapsibleContainer, Color, ContextMenu, Tooltip}; +use ui::{popover_menu, prelude::*, ButtonLike, Color, ContextMenu, Tooltip}; use util::{paths::EMBEDDINGS_DIR, ResultExt}; use workspace::{ dock::{DockPosition, Panel, PanelEvent}, @@ -110,10 +109,10 @@ impl AssistantPanel { let mut tool_registry = ToolRegistry::new(); tool_registry - .register(ProjectIndexTool::new( - project_index.clone(), - app_state.fs.clone(), - )) + .register( + ProjectIndexTool::new(project_index.clone(), app_state.fs.clone()), + cx, + ) .context("failed to register ProjectIndexTool") .log_err(); @@ -447,11 +446,7 @@ impl AssistantChat { } editor }); - let message = ChatMessage::User(UserMessage { - id, - body, - contexts: Vec::new(), - }); + let message = ChatMessage::User(UserMessage { id, body }); self.push_message(message, cx); } @@ -525,11 +520,7 @@ impl AssistantChat { let is_last = ix == self.messages.len() - 1; match &self.messages[ix] { - ChatMessage::User(UserMessage { - body, - contexts: _contexts, - .. - }) => div() + ChatMessage::User(UserMessage { body, .. }) => div() .when(!is_last, |element| element.mb_2()) .child(div().p_2().child(Label::new("You").color(Color::Default))) .child( @@ -539,7 +530,7 @@ impl AssistantChat { .text_color(cx.theme().colors().editor_foreground) .font(ThemeSettings::get_global(cx).buffer_font.clone()) .bg(cx.theme().colors().editor_background) - .child(body.clone()), // .children(contexts.iter().map(|context| context.render(cx))), + .child(body.clone()), ) .into_any(), ChatMessage::Assistant(AssistantMessage { @@ -588,11 +579,11 @@ impl AssistantChat { for message in &self.messages { match message { - ChatMessage::User(UserMessage { body, contexts, .. }) => { - // setup context for model - contexts.iter().for_each(|context| { - completion_messages.extend(context.completion_messages(cx)) - }); + ChatMessage::User(UserMessage { body, .. }) => { + // When we re-introduce contexts like active file, we'll inject them here instead of relying on the model to request them + // contexts.iter().for_each(|context| { + // completion_messages.extend(context.completion_messages(cx)) + // }); // Show user's message last so that the assistant is grounded in the user's request completion_messages.push(CompletionMessage::User { @@ -712,6 +703,12 @@ impl Render for AssistantChat { .text_color(Color::Default.color(cx)) .child(self.render_model_dropdown(cx)) .child(list(self.list_state.clone()).flex_1()) + .child( + h_flex() + .mt_2() + .gap_2() + .children(self.tool_registry.status_views().iter().cloned()), + ) } } @@ -743,7 +740,6 @@ impl ChatMessage { struct UserMessage { id: MessageId, body: View, - contexts: Vec, } struct AssistantMessage { @@ -752,211 +748,3 @@ struct AssistantMessage { tool_calls: Vec, error: Option, } - -// Since we're swapping out for direct query usage, we might not need to use this injected context -// It will be useful though for when the user _definitely_ wants the model to see a specific file, -// query, error, etc. -#[allow(dead_code)] -enum AssistantContext { - Codebase(View), -} - -#[allow(dead_code)] -struct CodebaseExcerpt { - element_id: ElementId, - path: SharedString, - text: SharedString, - score: f32, - expanded: bool, -} - -impl AssistantContext { - #[allow(dead_code)] - fn render(&self, _cx: &mut ViewContext) -> AnyElement { - match self { - AssistantContext::Codebase(context) => context.clone().into_any_element(), - } - } - - fn completion_messages(&self, cx: &WindowContext) -> Vec { - match self { - AssistantContext::Codebase(context) => context.read(cx).completion_messages(), - } - } -} - -enum CodebaseContext { - Pending { _task: Task<()> }, - Done(Result>), -} - -impl CodebaseContext { - fn toggle_expanded(&mut self, element_id: ElementId, cx: &mut ViewContext) { - if let CodebaseContext::Done(Ok(excerpts)) = self { - if let Some(excerpt) = excerpts - .iter_mut() - .find(|excerpt| excerpt.element_id == element_id) - { - excerpt.expanded = !excerpt.expanded; - cx.notify(); - } - } - } -} - -impl Render for CodebaseContext { - fn render(&mut self, cx: &mut ViewContext) -> impl IntoElement { - match self { - CodebaseContext::Pending { .. } => div() - .h_flex() - .items_center() - .gap_1() - .child(Icon::new(IconName::Ai).color(Color::Muted).into_element()) - .child("Searching codebase..."), - CodebaseContext::Done(Ok(excerpts)) => { - div() - .v_flex() - .gap_2() - .children(excerpts.iter().map(|excerpt| { - let expanded = excerpt.expanded; - let element_id = excerpt.element_id.clone(); - - CollapsibleContainer::new(element_id.clone(), expanded) - .start_slot( - h_flex() - .gap_1() - .child(Icon::new(IconName::File).color(Color::Muted)) - .child(Label::new(excerpt.path.clone()).color(Color::Muted)), - ) - .on_click(cx.listener(move |this, _, cx| { - this.toggle_expanded(element_id.clone(), cx); - })) - .child( - div() - .p_2() - .rounded_md() - .bg(cx.theme().colors().editor_background) - .child( - excerpt.text.clone(), // todo!(): Show as an editor block - ), - ) - })) - } - CodebaseContext::Done(Err(error)) => div().child(error.to_string()), - } - } -} - -impl CodebaseContext { - #[allow(dead_code)] - fn new( - query: impl 'static + Future>, - populated: oneshot::Sender, - project_index: Model, - fs: Arc, - cx: &mut ViewContext, - ) -> Self { - let query = query.boxed_local(); - let _task = cx.spawn(|this, mut cx| async move { - let result = async { - let query = query.await?; - let results = this - .update(&mut cx, |_this, cx| { - project_index.read(cx).search(&query, 16, cx) - })? - .await; - - let excerpts = results.into_iter().map(|result| { - let abs_path = result - .worktree - .read_with(&cx, |worktree, _| worktree.abs_path().join(&result.path)); - let fs = fs.clone(); - - async move { - let path = result.path.clone(); - let text = fs.load(&abs_path?).await?; - // todo!("what should we do with stale ranges?"); - let range = cmp::min(result.range.start, text.len()) - ..cmp::min(result.range.end, text.len()); - - let text = SharedString::from(text[range].to_string()); - - anyhow::Ok(CodebaseExcerpt { - element_id: ElementId::Name(nanoid::nanoid!().into()), - path: path.to_string_lossy().to_string().into(), - text, - score: result.score, - expanded: false, - }) - } - }); - - anyhow::Ok( - futures::future::join_all(excerpts) - .await - .into_iter() - .filter_map(|result| result.log_err()) - .collect(), - ) - } - .await; - - this.update(&mut cx, |this, cx| { - this.populate(result, populated, cx); - }) - .ok(); - }); - - Self::Pending { _task } - } - - #[allow(dead_code)] - fn populate( - &mut self, - result: Result>, - populated: oneshot::Sender, - cx: &mut ViewContext, - ) { - let success = result.is_ok(); - *self = Self::Done(result); - populated.send(success).ok(); - cx.notify(); - } - - fn completion_messages(&self) -> Vec { - // One system message for the whole batch of excerpts: - - // Semantic search results for user query: - // - // Excerpt from $path: - // ~~~ - // `text` - // ~~~ - // - // Excerpt from $path: - - match self { - CodebaseContext::Done(Ok(excerpts)) => { - if excerpts.is_empty() { - return Vec::new(); - } - - let mut body = "Semantic search results for user query:\n".to_string(); - - for excerpt in excerpts { - body.push_str("Excerpt from "); - body.push_str(excerpt.path.as_ref()); - body.push_str(", score "); - body.push_str(&excerpt.score.to_string()); - body.push_str(":\n"); - body.push_str("~~~\n"); - body.push_str(excerpt.text.as_ref()); - body.push_str("~~~\n"); - } - - vec![CompletionMessage::System { content: body }] - } - _ => vec![], - } - } -} diff --git a/crates/assistant2/src/tools.rs b/crates/assistant2/src/tools.rs index 3e86e72168..6ffddfe51d 100644 --- a/crates/assistant2/src/tools.rs +++ b/crates/assistant2/src/tools.rs @@ -1,9 +1,9 @@ use anyhow::Result; use assistant_tooling::LanguageModelTool; -use gpui::{prelude::*, AppContext, Model, Task}; +use gpui::{prelude::*, AnyView, AppContext, Model, Task}; use project::Fs; use schemars::JsonSchema; -use semantic_index::ProjectIndex; +use semantic_index::{ProjectIndex, Status}; use serde::Deserialize; use std::sync::Arc; use ui::{ @@ -36,13 +36,14 @@ pub struct CodebaseQuery { pub struct ProjectIndexView { input: CodebaseQuery, - output: Result>, + output: Result, } impl ProjectIndexView { fn toggle_expanded(&mut self, element_id: ElementId, cx: &mut ViewContext) { - if let Ok(excerpts) = &mut self.output { - if let Some(excerpt) = excerpts + if let Ok(output) = &mut self.output { + if let Some(excerpt) = output + .excerpts .iter_mut() .find(|excerpt| excerpt.element_id == element_id) { @@ -59,11 +60,11 @@ impl Render for ProjectIndexView { let result = &self.output; - let excerpts = match result { + let output = match result { Err(err) => { return div().child(Label::new(format!("Error: {}", err)).color(Color::Error)); } - Ok(excerpts) => excerpts, + Ok(output) => output, }; div() @@ -80,7 +81,7 @@ impl Render for ProjectIndexView { .child(Label::new(query).color(Color::Muted)), ), ) - .children(excerpts.iter().map(|excerpt| { + .children(output.excerpts.iter().map(|excerpt| { let element_id = excerpt.element_id.clone(); let expanded = excerpt.expanded; @@ -99,9 +100,7 @@ impl Render for ProjectIndexView { .p_2() .rounded_md() .bg(cx.theme().colors().editor_background) - .child( - excerpt.text.clone(), // todo!(): Show as an editor block - ), + .child(excerpt.text.clone()), ) })) } @@ -112,8 +111,15 @@ pub struct ProjectIndexTool { fs: Arc, } +pub struct ProjectIndexOutput { + excerpts: Vec, + status: Status, +} + impl ProjectIndexTool { pub fn new(project_index: Model, fs: Arc) -> Self { + // Listen for project index status and update the ProjectIndexTool directly + // TODO: setup a better description based on the user's current codebase. Self { project_index, fs } } @@ -121,7 +127,7 @@ impl ProjectIndexTool { impl LanguageModelTool for ProjectIndexTool { type Input = CodebaseQuery; - type Output = Vec; + type Output = ProjectIndexOutput; type View = ProjectIndexView; fn name(&self) -> String { @@ -135,6 +141,7 @@ impl LanguageModelTool for ProjectIndexTool { fn execute(&self, query: &Self::Input, cx: &AppContext) -> Task> { let project_index = self.project_index.read(cx); + let status = project_index.status(); let results = project_index.search( query.query.as_str(), query.limit.unwrap_or(DEFAULT_SEARCH_LIMIT), @@ -180,11 +187,11 @@ impl LanguageModelTool for ProjectIndexTool { .into_iter() .filter_map(|result| result.log_err()) .collect(); - anyhow::Ok(excerpts) + anyhow::Ok(ProjectIndexOutput { excerpts, status }) }) } - fn new_view( + fn output_view( _tool_call_id: String, input: Self::Input, output: Result, @@ -193,16 +200,28 @@ impl LanguageModelTool for ProjectIndexTool { cx.new_view(|_cx| ProjectIndexView { input, output }) } + fn status_view(&self, cx: &mut WindowContext) -> Option { + Some( + cx.new_view(|cx| ProjectIndexStatusView::new(self.project_index.clone(), cx)) + .into(), + ) + } + fn format(_input: &Self::Input, output: &Result) -> String { match &output { - Ok(excerpts) => { - if excerpts.len() == 0 { - return "No results found".to_string(); - } - + Ok(output) => { let mut body = "Semantic search results:\n".to_string(); - for excerpt in excerpts { + if output.status != Status::Idle { + body.push_str("Still indexing. Results may be incomplete.\n"); + } + + if output.excerpts.is_empty() { + body.push_str("No results found"); + return body; + } + + for excerpt in &output.excerpts { body.push_str("Excerpt from "); body.push_str(excerpt.path.as_ref()); body.push_str(", score "); @@ -218,3 +237,31 @@ impl LanguageModelTool for ProjectIndexTool { } } } + +struct ProjectIndexStatusView { + project_index: Model, +} + +impl ProjectIndexStatusView { + pub fn new(project_index: Model, cx: &mut ViewContext) -> Self { + cx.subscribe(&project_index, |_this, _, _status: &Status, cx| { + cx.notify(); + }) + .detach(); + Self { project_index } + } +} + +impl Render for ProjectIndexStatusView { + fn render(&mut self, cx: &mut ViewContext) -> impl IntoElement { + let status = self.project_index.read(cx).status(); + + h_flex().gap_2().map(|element| match status { + Status::Idle => element.child(Label::new("Project index ready")), + Status::Loading => element.child(Label::new("Project index loading...")), + Status::Scanning { remaining_count } => element.child(Label::new(format!( + "Project index scanning: {remaining_count} remaining..." + ))), + }) + } +} diff --git a/crates/assistant_tooling/src/registry.rs b/crates/assistant_tooling/src/registry.rs index 6a3bc313cd..136f012d33 100644 --- a/crates/assistant_tooling/src/registry.rs +++ b/crates/assistant_tooling/src/registry.rs @@ -1,5 +1,5 @@ use anyhow::{anyhow, Result}; -use gpui::{Task, WindowContext}; +use gpui::{AnyView, Task, WindowContext}; use std::collections::HashMap; use crate::tool::{ @@ -12,6 +12,7 @@ pub struct ToolRegistry { Box Task>>, >, definitions: Vec, + status_views: Vec, } impl ToolRegistry { @@ -19,6 +20,7 @@ impl ToolRegistry { Self { tools: HashMap::new(), definitions: Vec::new(), + status_views: Vec::new(), } } @@ -26,8 +28,17 @@ impl ToolRegistry { &self.definitions } - pub fn register(&mut self, tool: T) -> Result<()> { + pub fn register( + &mut self, + tool: T, + cx: &mut WindowContext, + ) -> Result<()> { self.definitions.push(tool.definition()); + + if let Some(tool_view) = tool.status_view(cx) { + self.status_views.push(tool_view); + } + let name = tool.name(); let previous = self.tools.insert( name.clone(), @@ -52,7 +63,7 @@ impl ToolRegistry { cx.spawn(move |mut cx| async move { let result: Result = result.await; let for_model = T::format(&input, &result); - let view = cx.update(|cx| T::new_view(id.clone(), input, result, cx))?; + let view = cx.update(|cx| T::output_view(id.clone(), input, result, cx))?; Ok(ToolFunctionCall { id, @@ -100,6 +111,10 @@ impl ToolRegistry { tool(tool_call, cx) } + + pub fn status_views(&self) -> &[AnyView] { + &self.status_views + } } #[cfg(test)] @@ -165,7 +180,7 @@ mod test { Task::ready(Ok(weather)) } - fn new_view( + fn output_view( _tool_call_id: String, _input: Self::Input, result: Result, @@ -182,46 +197,6 @@ mod test { } } - #[gpui::test] - async fn test_function_registry(cx: &mut TestAppContext) { - cx.background_executor.run_until_parked(); - - let mut registry = ToolRegistry::new(); - - let tool = WeatherTool { - current_weather: WeatherResult { - location: "San Francisco".to_string(), - temperature: 21.0, - unit: "Celsius".to_string(), - }, - }; - - registry.register(tool).unwrap(); - - // let _result = cx - // .update(|cx| { - // registry.call( - // &ToolFunctionCall { - // name: "get_current_weather".to_string(), - // arguments: r#"{ "location": "San Francisco", "unit": "Celsius" }"# - // .to_string(), - // id: "test-123".to_string(), - // result: None, - // }, - // cx, - // ) - // }) - // .await; - - // assert!(result.is_ok()); - // let result = result.unwrap(); - - // let expected = r#"{"location":"San Francisco","temperature":21.0,"unit":"Celsius"}"#; - - // todo!(): Put this back in after the interface is stabilized - // assert_eq!(result, expected); - } - #[gpui::test] async fn test_openai_weather_example(cx: &mut TestAppContext) { cx.background_executor.run_until_parked(); diff --git a/crates/assistant_tooling/src/tool.rs b/crates/assistant_tooling/src/tool.rs index 8a1ffcf9d4..31ed8fdee8 100644 --- a/crates/assistant_tooling/src/tool.rs +++ b/crates/assistant_tooling/src/tool.rs @@ -95,10 +95,14 @@ pub trait LanguageModelTool { fn format(input: &Self::Input, output: &Result) -> String; - fn new_view( + fn output_view( tool_call_id: String, input: Self::Input, output: Result, cx: &mut WindowContext, ) -> View; + + fn status_view(&self, _cx: &mut WindowContext) -> Option { + None + } } diff --git a/crates/semantic_index/Cargo.toml b/crates/semantic_index/Cargo.toml index 5f06d4193f..a23f7853de 100644 --- a/crates/semantic_index/Cargo.toml +++ b/crates/semantic_index/Cargo.toml @@ -30,6 +30,7 @@ language.workspace = true log.workspace = true heed.workspace = true open_ai.workspace = true +parking_lot.workspace = true project.workspace = true settings.workspace = true serde.workspace = true diff --git a/crates/semantic_index/src/semantic_index.rs b/crates/semantic_index/src/semantic_index.rs index 097a050ee8..f14438fed6 100644 --- a/crates/semantic_index/src/semantic_index.rs +++ b/crates/semantic_index/src/semantic_index.rs @@ -3,7 +3,7 @@ mod embedding; use anyhow::{anyhow, Context as _, Result}; use chunking::{chunk_text, Chunk}; -use collections::{Bound, HashMap}; +use collections::{Bound, HashMap, HashSet}; pub use embedding::*; use fs::Fs; use futures::stream::StreamExt; @@ -14,15 +14,17 @@ use gpui::{ }; use heed::types::{SerdeBincode, Str}; use language::LanguageRegistry; -use project::{Entry, Project, UpdatedEntriesSet, Worktree}; +use parking_lot::Mutex; +use project::{Entry, Project, ProjectEntryId, UpdatedEntriesSet, Worktree}; use serde::{Deserialize, Serialize}; use smol::channel; use std::{ cmp::Ordering, future::Future, + num::NonZeroUsize, ops::Range, path::{Path, PathBuf}, - sync::Arc, + sync::{Arc, Weak}, time::{Duration, SystemTime}, }; use util::ResultExt; @@ -102,19 +104,16 @@ pub struct ProjectIndex { worktree_indices: HashMap, language_registry: Arc, fs: Arc, - pub last_status: Status, + last_status: Status, + status_tx: channel::Sender<()>, embedding_provider: Arc, + _maintain_status: Task<()>, _subscription: Subscription, } enum WorktreeIndexHandle { - Loading { - _task: Task>, - }, - Loaded { - index: Model, - _subscription: Subscription, - }, + Loading { _task: Task> }, + Loaded { index: Model }, } impl ProjectIndex { @@ -126,20 +125,36 @@ impl ProjectIndex { ) -> Self { let language_registry = project.read(cx).languages().clone(); let fs = project.read(cx).fs().clone(); + let (status_tx, mut status_rx) = channel::unbounded(); let mut this = ProjectIndex { db_connection, project: project.downgrade(), worktree_indices: HashMap::default(), language_registry, fs, + status_tx, last_status: Status::Idle, embedding_provider, _subscription: cx.subscribe(&project, Self::handle_project_event), + _maintain_status: cx.spawn(|this, mut cx| async move { + while status_rx.next().await.is_some() { + if this + .update(&mut cx, |this, cx| this.update_status(cx)) + .is_err() + { + break; + } + } + }), }; this.update_worktree_indices(cx); this } + pub fn status(&self) -> Status { + self.last_status + } + fn handle_project_event( &mut self, _: Model, @@ -180,19 +195,18 @@ impl ProjectIndex { self.db_connection.clone(), self.language_registry.clone(), self.fs.clone(), + self.status_tx.clone(), self.embedding_provider.clone(), cx, ); let load_worktree = cx.spawn(|this, mut cx| async move { - if let Some(index) = worktree_index.await.log_err() { - this.update(&mut cx, |this, cx| { + if let Some(worktree_index) = worktree_index.await.log_err() { + this.update(&mut cx, |this, _| { this.worktree_indices.insert( worktree_id, WorktreeIndexHandle::Loaded { - _subscription: cx - .observe(&index, |this, _, cx| this.update_status(cx)), - index, + index: worktree_index, }, ); })?; @@ -215,22 +229,29 @@ impl ProjectIndex { } fn update_status(&mut self, cx: &mut ModelContext) { - let mut status = Status::Idle; - for index in self.worktree_indices.values() { + let mut indexing_count = 0; + let mut any_loading = false; + + for index in self.worktree_indices.values_mut() { match index { WorktreeIndexHandle::Loading { .. } => { - status = Status::Scanning; + any_loading = true; break; } WorktreeIndexHandle::Loaded { index, .. } => { - if index.read(cx).status == Status::Scanning { - status = Status::Scanning; - break; - } + indexing_count += index.read(cx).entry_ids_being_indexed.len(); } } } + let status = if any_loading { + Status::Loading + } else if let Some(remaining_count) = NonZeroUsize::new(indexing_count) { + Status::Scanning { remaining_count } + } else { + Status::Idle + }; + if status != self.last_status { self.last_status = status; cx.emit(status); @@ -263,6 +284,17 @@ impl ProjectIndex { results }) } + + #[cfg(test)] + pub fn path_count(&self, cx: &AppContext) -> Result { + let mut result = 0; + for worktree_index in self.worktree_indices.values() { + if let WorktreeIndexHandle::Loaded { index, .. } = worktree_index { + result += index.read(cx).path_count()?; + } + } + Ok(result) + } } pub struct SearchResult { @@ -275,7 +307,8 @@ pub struct SearchResult { #[derive(Copy, Clone, Debug, Eq, PartialEq)] pub enum Status { Idle, - Scanning, + Loading, + Scanning { remaining_count: NonZeroUsize }, } impl EventEmitter for ProjectIndex {} @@ -287,7 +320,7 @@ struct WorktreeIndex { language_registry: Arc, fs: Arc, embedding_provider: Arc, - status: Status, + entry_ids_being_indexed: Arc, _index_entries: Task>, _subscription: Subscription, } @@ -298,6 +331,7 @@ impl WorktreeIndex { db_connection: heed::Env, language_registry: Arc, fs: Arc, + status_tx: channel::Sender<()>, embedding_provider: Arc, cx: &mut AppContext, ) -> Task>> { @@ -321,6 +355,7 @@ impl WorktreeIndex { worktree, db_connection, db, + status_tx, language_registry, fs, embedding_provider, @@ -330,10 +365,12 @@ impl WorktreeIndex { }) } + #[allow(clippy::too_many_arguments)] fn new( worktree: Model, db_connection: heed::Env, db: heed::Database>, + status: channel::Sender<()>, language_registry: Arc, fs: Arc, embedding_provider: Arc, @@ -353,7 +390,7 @@ impl WorktreeIndex { language_registry, fs, embedding_provider, - status: Status::Idle, + entry_ids_being_indexed: Arc::new(IndexingEntrySet::new(status)), _index_entries: cx.spawn(|this, cx| Self::index_entries(this, updated_entries_rx, cx)), _subscription, } @@ -364,28 +401,14 @@ impl WorktreeIndex { updated_entries: channel::Receiver, mut cx: AsyncAppContext, ) -> Result<()> { - let index = this.update(&mut cx, |this, cx| { - cx.notify(); - this.status = Status::Scanning; - this.index_entries_changed_on_disk(cx) - })?; + let index = this.update(&mut cx, |this, cx| this.index_entries_changed_on_disk(cx))?; index.await.log_err(); - this.update(&mut cx, |this, cx| { - this.status = Status::Idle; - cx.notify(); - })?; while let Ok(updated_entries) = updated_entries.recv().await { let index = this.update(&mut cx, |this, cx| { - cx.notify(); - this.status = Status::Scanning; this.index_updated_entries(updated_entries, cx) })?; index.await.log_err(); - this.update(&mut cx, |this, cx| { - this.status = Status::Idle; - cx.notify(); - })?; } Ok(()) @@ -426,6 +449,7 @@ impl WorktreeIndex { let (deleted_entry_ranges_tx, deleted_entry_ranges_rx) = channel::bounded(128); let db_connection = self.db_connection.clone(); let db = self.db; + let entries_being_indexed = self.entry_ids_being_indexed.clone(); let task = cx.background_executor().spawn(async move { let txn = db_connection .read_txn() @@ -476,7 +500,8 @@ impl WorktreeIndex { } if entry.mtime != saved_mtime { - updated_entries_tx.send(entry.clone()).await?; + let handle = entries_being_indexed.insert(&entry); + updated_entries_tx.send((entry.clone(), handle)).await?; } } @@ -505,6 +530,7 @@ impl WorktreeIndex { ) -> ScanEntries { let (updated_entries_tx, updated_entries_rx) = channel::bounded(512); let (deleted_entry_ranges_tx, deleted_entry_ranges_rx) = channel::bounded(128); + let entries_being_indexed = self.entry_ids_being_indexed.clone(); let task = cx.background_executor().spawn(async move { for (path, entry_id, status) in updated_entries.iter() { match status { @@ -513,7 +539,8 @@ impl WorktreeIndex { | project::PathChange::AddedOrUpdated => { if let Some(entry) = worktree.entry_for_id(*entry_id) { if entry.is_file() { - updated_entries_tx.send(entry.clone()).await?; + let handle = entries_being_indexed.insert(&entry); + updated_entries_tx.send((entry.clone(), handle)).await?; } } } @@ -542,7 +569,7 @@ impl WorktreeIndex { fn chunk_files( &self, worktree_abs_path: Arc, - entries: channel::Receiver, + entries: channel::Receiver<(Entry, IndexingEntryHandle)>, cx: &AppContext, ) -> ChunkFiles { let language_registry = self.language_registry.clone(); @@ -553,7 +580,7 @@ impl WorktreeIndex { .scoped(|cx| { for _ in 0..cx.num_cpus() { cx.spawn(async { - while let Ok(entry) = entries.recv().await { + while let Ok((entry, handle)) = entries.recv().await { let entry_abs_path = worktree_abs_path.join(&entry.path); let Some(text) = fs .load(&entry_abs_path) @@ -572,8 +599,8 @@ impl WorktreeIndex { let grammar = language.as_ref().and_then(|language| language.grammar()); let chunked_file = ChunkedFile { - worktree_root: worktree_abs_path.clone(), chunks: chunk_text(&text, grammar), + handle, entry, text, }; @@ -622,7 +649,11 @@ impl WorktreeIndex { let mut embeddings = Vec::new(); for embedding_batch in chunks.chunks(embedding_provider.batch_size()) { - embeddings.extend(embedding_provider.embed(embedding_batch).await?); + if let Some(batch_embeddings) = + embedding_provider.embed(embedding_batch).await.log_err() + { + embeddings.extend_from_slice(&batch_embeddings); + } } let mut embeddings = embeddings.into_iter(); @@ -643,7 +674,9 @@ impl WorktreeIndex { chunks: embedded_chunks, }; - embedded_files_tx.send(embedded_file).await?; + embedded_files_tx + .send((embedded_file, chunked_file.handle)) + .await?; } } Ok(()) @@ -658,7 +691,7 @@ impl WorktreeIndex { fn persist_embeddings( &self, mut deleted_entry_ranges: channel::Receiver<(Bound, Bound)>, - embedded_files: channel::Receiver, + embedded_files: channel::Receiver<(EmbeddedFile, IndexingEntryHandle)>, cx: &AppContext, ) -> Task> { let db_connection = self.db_connection.clone(); @@ -676,12 +709,15 @@ impl WorktreeIndex { let mut embedded_files = embedded_files.chunks_timeout(4096, Duration::from_secs(2)); while let Some(embedded_files) = embedded_files.next().await { let mut txn = db_connection.write_txn()?; - for file in embedded_files { + for (file, _) in &embedded_files { log::debug!("saving embedding for file {:?}", file.path); let key = db_key_for_path(&file.path); - db.put(&mut txn, &key, &file)?; + db.put(&mut txn, &key, file)?; } txn.commit()?; + eprintln!("committed {:?}", embedded_files.len()); + + drop(embedded_files); log::debug!("committed"); } @@ -789,10 +825,19 @@ impl WorktreeIndex { Ok(search_results) }) } + + #[cfg(test)] + fn path_count(&self) -> Result { + let txn = self + .db_connection + .read_txn() + .context("failed to create read transaction")?; + Ok(self.db.len(&txn)?) + } } struct ScanEntries { - updated_entries: channel::Receiver, + updated_entries: channel::Receiver<(Entry, IndexingEntryHandle)>, deleted_entry_ranges: channel::Receiver<(Bound, Bound)>, task: Task>, } @@ -803,15 +848,14 @@ struct ChunkFiles { } struct ChunkedFile { - #[allow(dead_code)] - pub worktree_root: Arc, pub entry: Entry, + pub handle: IndexingEntryHandle, pub text: String, pub chunks: Vec, } struct EmbedFiles { - files: channel::Receiver, + files: channel::Receiver<(EmbeddedFile, IndexingEntryHandle)>, task: Task>, } @@ -828,6 +872,47 @@ struct EmbeddedChunk { embedding: Embedding, } +struct IndexingEntrySet { + entry_ids: Mutex>, + tx: channel::Sender<()>, +} + +struct IndexingEntryHandle { + entry_id: ProjectEntryId, + set: Weak, +} + +impl IndexingEntrySet { + fn new(tx: channel::Sender<()>) -> Self { + Self { + entry_ids: Default::default(), + tx, + } + } + + fn insert(self: &Arc, entry: &project::Entry) -> IndexingEntryHandle { + self.entry_ids.lock().insert(entry.id); + self.tx.send_blocking(()).ok(); + IndexingEntryHandle { + entry_id: entry.id, + set: Arc::downgrade(self), + } + } + + pub fn len(&self) -> usize { + self.entry_ids.lock().len() + } +} + +impl Drop for IndexingEntryHandle { + fn drop(&mut self) { + if let Some(set) = self.set.upgrade() { + set.tx.send_blocking(()).ok(); + set.entry_ids.lock().remove(&self.entry_id); + } + } +} + fn db_key_for_path(path: &Arc) -> String { path.to_string_lossy().replace('/', "\0") } @@ -835,10 +920,7 @@ fn db_key_for_path(path: &Arc) -> String { #[cfg(test)] mod tests { use super::*; - - use futures::channel::oneshot; use futures::{future::BoxFuture, FutureExt}; - use gpui::{Global, TestAppContext}; use language::language_settings::AllLanguageSettings; use project::Project; @@ -922,18 +1004,13 @@ mod tests { let project_index = cx.update(|cx| semantic_index.project_index(project.clone(), cx)); - let (tx, rx) = oneshot::channel(); - let mut tx = Some(tx); - let subscription = cx.update(|cx| { - cx.subscribe(&project_index, move |_, event, _| { - if let Some(tx) = tx.take() { - _ = tx.send(*event); - } - }) - }); - - rx.await.expect("no event emitted"); - drop(subscription); + while project_index + .read_with(cx, |index, cx| index.path_count(cx)) + .unwrap() + == 0 + { + project_index.next_event(cx).await; + } let results = cx .update(|cx| {