From 22046ef9a7f3129abdb9b3328e24fbf05fefdfaf Mon Sep 17 00:00:00 2001
From: Antonio Scandurra
Date: Wed, 24 Jan 2024 13:36:44 +0100
Subject: [PATCH] Load language models in the background

---
 crates/ai/src/providers/open_ai/completion.rs |   6 +-
 crates/ai/src/providers/open_ai/embedding.rs  |   7 +-
 crates/assistant/src/assistant_panel.rs       | 132 ++++++++++--------
 crates/semantic_index/src/semantic_index.rs   |   7 +-
 4 files changed, 83 insertions(+), 69 deletions(-)

diff --git a/crates/ai/src/providers/open_ai/completion.rs b/crates/ai/src/providers/open_ai/completion.rs
index 0e325ee624..fda4d69816 100644
--- a/crates/ai/src/providers/open_ai/completion.rs
+++ b/crates/ai/src/providers/open_ai/completion.rs
@@ -201,8 +201,10 @@ pub struct OpenAICompletionProvider {
 }
 
 impl OpenAICompletionProvider {
-    pub fn new(model_name: &str, executor: BackgroundExecutor) -> Self {
-        let model = OpenAILanguageModel::load(model_name);
+    pub async fn new(model_name: String, executor: BackgroundExecutor) -> Self {
+        let model = executor
+            .spawn(async move { OpenAILanguageModel::load(&model_name) })
+            .await;
         let credential = Arc::new(RwLock::new(ProviderCredential::NoCredentials));
         Self {
             model,
diff --git a/crates/ai/src/providers/open_ai/embedding.rs b/crates/ai/src/providers/open_ai/embedding.rs
index 0a9b6ba969..1dca571733 100644
--- a/crates/ai/src/providers/open_ai/embedding.rs
+++ b/crates/ai/src/providers/open_ai/embedding.rs
@@ -67,11 +67,14 @@ struct OpenAIEmbeddingUsage {
 }
 
 impl OpenAIEmbeddingProvider {
-    pub fn new(client: Arc<dyn HttpClient>, executor: BackgroundExecutor) -> Self {
+    pub async fn new(client: Arc<dyn HttpClient>, executor: BackgroundExecutor) -> Self {
         let (rate_limit_count_tx, rate_limit_count_rx) = watch::channel_with(None);
         let rate_limit_count_tx = Arc::new(Mutex::new(rate_limit_count_tx));
 
-        let model = OpenAILanguageModel::load("text-embedding-ada-002");
+        // Loading the model is expensive, so ensure this runs off the main thread.
+        let model = executor
+            .spawn(async move { OpenAILanguageModel::load("text-embedding-ada-002") })
+            .await;
         let credential = Arc::new(RwLock::new(ProviderCredential::NoCredentials));
 
         OpenAIEmbeddingProvider {
diff --git a/crates/assistant/src/assistant_panel.rs b/crates/assistant/src/assistant_panel.rs
index 1f57e52032..097c6424d7 100644
--- a/crates/assistant/src/assistant_panel.rs
+++ b/crates/assistant/src/assistant_panel.rs
@@ -31,9 +31,9 @@ use fs::Fs;
 use futures::StreamExt;
 use gpui::{
     canvas, div, point, relative, rems, uniform_list, Action, AnyElement, AppContext,
-    AsyncWindowContext, AvailableSpace, ClipboardItem, Context, EventEmitter, FocusHandle,
-    FocusableView, FontStyle, FontWeight, HighlightStyle, InteractiveElement, IntoElement, Model,
-    ModelContext, ParentElement, Pixels, PromptLevel, Render, SharedString,
+    AsyncAppContext, AsyncWindowContext, AvailableSpace, ClipboardItem, Context, EventEmitter,
+    FocusHandle, FocusableView, FontStyle, FontWeight, HighlightStyle, InteractiveElement,
+    IntoElement, Model, ModelContext, ParentElement, Pixels, PromptLevel, Render, SharedString,
     StatefulInteractiveElement, Styled, Subscription, Task, TextStyle, UniformListScrollHandle,
     View, ViewContext, VisualContext, WeakModel, WeakView, WhiteSpace, WindowContext,
 };
@@ -123,6 +123,10 @@ impl AssistantPanel {
                 .await
                 .log_err()
                 .unwrap_or_default();
+            // Defaulting currently to GPT4, allow for this to be set via config.
+            let completion_provider =
+                OpenAICompletionProvider::new("gpt-4".into(), cx.background_executor().clone())
+                    .await;
 
             // TODO: deserialize state.
             let workspace_handle = workspace.clone();
@@ -156,11 +160,6 @@ impl AssistantPanel {
                     });
 
                     let semantic_index = SemanticIndex::global(cx);
-                    // Defaulting currently to GPT4, allow for this to be set via config.
-                    let completion_provider = Arc::new(OpenAICompletionProvider::new(
-                        "gpt-4",
-                        cx.background_executor().clone(),
-                    ));
 
                     let focus_handle = cx.focus_handle();
                     cx.on_focus_in(&focus_handle, Self::focus_in).detach();
@@ -176,7 +175,7 @@ impl AssistantPanel {
                         zoomed: false,
                         focus_handle,
                         toolbar,
-                        completion_provider,
+                        completion_provider: Arc::new(completion_provider),
                         api_key_editor: None,
                         languages: workspace.app_state().languages.clone(),
                         fs: workspace.app_state().fs.clone(),
@@ -1079,9 +1078,9 @@ impl AssistantPanel {
         cx.spawn(|this, mut cx| async move {
             let saved_conversation = fs.load(&path).await?;
             let saved_conversation = serde_json::from_str(&saved_conversation)?;
-            let conversation = cx.new_model(|cx| {
-                Conversation::deserialize(saved_conversation, path.clone(), languages, cx)
-            })?;
+            let conversation =
+                Conversation::deserialize(saved_conversation, path.clone(), languages, &mut cx)
+                    .await?;
             this.update(&mut cx, |this, cx| {
                 // If, by the time we've loaded the conversation, the user has already opened
                 // the same conversation, we don't want to open it again.
@@ -1462,21 +1461,25 @@ impl Conversation {
         }
     }
 
-    fn deserialize(
+    async fn deserialize(
         saved_conversation: SavedConversation,
         path: PathBuf,
         language_registry: Arc<LanguageRegistry>,
-        cx: &mut ModelContext<Self>,
-    ) -> Self {
+        cx: &mut AsyncAppContext,
+    ) -> Result<Model<Self>> {
         let id = match saved_conversation.id {
             Some(id) => Some(id),
             None => Some(Uuid::new_v4().to_string()),
         };
         let model = saved_conversation.model;
         let completion_provider: Arc<dyn CompletionProvider> = Arc::new(
-            OpenAICompletionProvider::new(model.full_name(), cx.background_executor().clone()),
+            OpenAICompletionProvider::new(
+                model.full_name().into(),
+                cx.background_executor().clone(),
+            )
+            .await,
         );
-        completion_provider.retrieve_credentials(cx);
+        cx.update(|cx| completion_provider.retrieve_credentials(cx))?;
         let markdown = language_registry.language_for_name("Markdown");
         let mut message_anchors = Vec::new();
         let mut next_message_id = MessageId(0);
@@ -1499,32 +1502,34 @@ impl Conversation {
             })
             .detach_and_log_err(cx);
             buffer
-        });
+        })?;
 
-        let mut this = Self {
-            id,
-            message_anchors,
-            messages_metadata: saved_conversation.message_metadata,
-            next_message_id,
-            summary: Some(Summary {
-                text: saved_conversation.summary,
-                done: true,
-            }),
-            pending_summary: Task::ready(None),
-            completion_count: Default::default(),
-            pending_completions: Default::default(),
-            token_count: None,
-            max_token_count: tiktoken_rs::model::get_context_size(&model.full_name()),
-            pending_token_count: Task::ready(None),
-            model,
-            _subscriptions: vec![cx.subscribe(&buffer, Self::handle_buffer_event)],
-            pending_save: Task::ready(Ok(())),
-            path: Some(path),
-            buffer,
-            completion_provider,
-        };
-        this.count_remaining_tokens(cx);
-        this
+        cx.new_model(|cx| {
+            let mut this = Self {
+                id,
+                message_anchors,
+                messages_metadata: saved_conversation.message_metadata,
+                next_message_id,
+                summary: Some(Summary {
+                    text: saved_conversation.summary,
+                    done: true,
+                }),
+                pending_summary: Task::ready(None),
+                completion_count: Default::default(),
+                pending_completions: Default::default(),
+                token_count: None,
+                max_token_count: tiktoken_rs::model::get_context_size(&model.full_name()),
+                pending_token_count: Task::ready(None),
+                model,
+                _subscriptions: vec![cx.subscribe(&buffer, Self::handle_buffer_event)],
+                pending_save: Task::ready(Ok(())),
+                path: Some(path),
+                buffer,
+                completion_provider,
+            };
+            this.count_remaining_tokens(cx);
+            this
+        })
     }
 
     fn handle_buffer_event(
@@ -3169,7 +3174,7 @@ mod tests {
     use super::*;
     use crate::MessageId;
     use ai::test::FakeCompletionProvider;
-    use gpui::AppContext;
+    use gpui::{AppContext, TestAppContext};
     use settings::SettingsStore;
 
     #[gpui::test]
@@ -3487,16 +3492,17 @@ mod tests {
     }
 
     #[gpui::test]
-    fn test_serialization(cx: &mut AppContext) {
-        let settings_store = SettingsStore::test(cx);
+    async fn test_serialization(cx: &mut TestAppContext) {
+        let settings_store = cx.update(SettingsStore::test);
         cx.set_global(settings_store);
-        init(cx);
+        cx.update(init);
         let registry = Arc::new(LanguageRegistry::test());
         let completion_provider = Arc::new(FakeCompletionProvider::new());
         let conversation =
             cx.new_model(|cx| Conversation::new(registry.clone(), cx, completion_provider));
-        let buffer = conversation.read(cx).buffer.clone();
-        let message_0 = conversation.read(cx).message_anchors[0].id;
+        let buffer = conversation.read_with(cx, |conversation, _| conversation.buffer.clone());
+        let message_0 =
+            conversation.read_with(cx, |conversation, _| conversation.message_anchors[0].id);
         let message_1 = conversation.update(cx, |conversation, cx| {
             conversation
                 .insert_message_after(message_0, Role::Assistant, MessageStatus::Done, cx)
@@ -3517,9 +3523,9 @@ mod tests {
                 .unwrap()
         });
         buffer.update(cx, |buffer, cx| buffer.undo(cx));
-        assert_eq!(buffer.read(cx).text(), "a\nb\nc\n");
+        assert_eq!(buffer.read_with(cx, |buffer, _| buffer.text()), "a\nb\nc\n");
         assert_eq!(
-            messages(&conversation, cx),
+            cx.read(|cx| messages(&conversation, cx)),
             [
                 (message_0, Role::User, 0..2),
                 (message_1.id, Role::Assistant, 2..6),
@@ -3527,18 +3533,22 @@ mod tests {
             ]
         );
 
-        let deserialized_conversation = cx.new_model(|cx| {
-            Conversation::deserialize(
-                conversation.read(cx).serialize(cx),
-                Default::default(),
-                registry.clone(),
-                cx,
-            )
-        });
-        let deserialized_buffer = deserialized_conversation.read(cx).buffer.clone();
-        assert_eq!(deserialized_buffer.read(cx).text(), "a\nb\nc\n");
+        let deserialized_conversation = Conversation::deserialize(
+            conversation.read_with(cx, |conversation, cx| conversation.serialize(cx)),
+            Default::default(),
+            registry.clone(),
+            &mut cx.to_async(),
+        )
+        .await
+        .unwrap();
+        let deserialized_buffer =
+            deserialized_conversation.read_with(cx, |conversation, _| conversation.buffer.clone());
         assert_eq!(
-            messages(&deserialized_conversation, cx),
+            deserialized_buffer.read_with(cx, |buffer, _| buffer.text()),
+            "a\nb\nc\n"
+        );
+        assert_eq!(
+            cx.read(|cx| messages(&deserialized_conversation, cx)),
             [
                 (message_0, Role::User, 0..2),
                 (message_1.id, Role::Assistant, 2..6),
diff --git a/crates/semantic_index/src/semantic_index.rs b/crates/semantic_index/src/semantic_index.rs
index a556986f9b..475ab079dc 100644
--- a/crates/semantic_index/src/semantic_index.rs
+++ b/crates/semantic_index/src/semantic_index.rs
@@ -90,13 +90,12 @@ pub fn init(
     .detach();
 
     cx.spawn(move |cx| async move {
+        let embedding_provider =
+            OpenAIEmbeddingProvider::new(http_client, cx.background_executor().clone()).await;
         let semantic_index = SemanticIndex::new(
            fs,
            db_file_path,
-            Arc::new(OpenAIEmbeddingProvider::new(
-                http_client,
-                cx.background_executor().clone(),
-            )),
+            Arc::new(embedding_provider),
            language_registry,
            cx.clone(),
        )
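-- 
Note on the pattern (commentary, not part of the commit): every call site above follows the same recipe — the synchronous `OpenAILanguageModel::load` call, which the new comment in embedding.rs notes is expensive, moves into `executor.spawn(...)` on gpui's `BackgroundExecutor` and the resulting task is awaited, so the main thread stays responsive while the model loads. A minimal standalone sketch of the recipe follows; `Tokenizer` and `load_tokenizer` are hypothetical stand-ins for the expensive load, not items from this codebase.

    use gpui::BackgroundExecutor;

    // Hypothetical stand-in for an expensive-to-load tokenizer/model.
    struct Tokenizer;

    fn load_tokenizer(model_name: &str) -> Tokenizer {
        // Stand-in for the blocking work (e.g. reading and parsing a large
        // vocabulary file); it runs entirely on a background thread.
        let _ = model_name;
        Tokenizer
    }

    // Same shape as the new `OpenAICompletionProvider::new`: take the name by
    // value so the spawned closure can own it, run the blocking load on the
    // background executor, and await the task without blocking the caller.
    async fn load_in_background(model_name: String, executor: BackgroundExecutor) -> Tokenizer {
        executor
            .spawn(async move { load_tokenizer(&model_name) })
            .await
    }

The visible cost of this design is that the provider constructors become `async fn`s, and the await ripples outward through the diff: `Conversation::deserialize` now takes `&mut AsyncAppContext` and returns `Result<Model<Self>>` so it can await the provider before entering `cx.new_model`, and `test_serialization` switches to `TestAppContext` with `read_with`/`update` so the test can drive the async construction.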