diff --git a/Cargo.lock b/Cargo.lock index e8b2dfadd4..9c4628e9fb 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -116,6 +116,7 @@ dependencies = [ "serde_json", "settings", "theme", + "tiktoken-rs", "util", "workspace", ] @@ -745,6 +746,21 @@ dependencies = [ "which", ] +[[package]] +name = "bit-set" +version = "0.5.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0700ddab506f33b20a03b13996eccd309a48e5ff77d0d95926aa0210fb4e95f1" +dependencies = [ + "bit-vec", +] + +[[package]] +name = "bit-vec" +version = "0.6.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "349f9b6a179ed607305526ca489b34ad0a41aed5f7980fa90eb03160b69598fb" + [[package]] name = "bitflags" version = "1.3.2" @@ -870,6 +886,8 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c3d4260bcc2e8fc9df1eac4919a720effeb63a3f0952f5bf4944adfa18897f09" dependencies = [ "memchr", + "once_cell", + "regex-automata", "serde", ] @@ -2220,6 +2238,16 @@ version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4443176a9f2c162692bd3d352d745ef9413eec5782a80d8fd6f8a1ac692a07f7" +[[package]] +name = "fancy-regex" +version = "0.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b95f7c0680e4142284cf8b22c14a476e87d61b004a3a0861872b32ef7ead40a2" +dependencies = [ + "bit-set", + "regex", +] + [[package]] name = "fastrand" version = "1.9.0" @@ -6969,6 +6997,21 @@ dependencies = [ "weezl", ] +[[package]] +name = "tiktoken-rs" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8ba161c549e2c0686f35f5d920e63fad5cafba2c28ad2caceaf07e5d9fa6e8c4" +dependencies = [ + "anyhow", + "base64 0.21.0", + "bstr", + "fancy-regex", + "lazy_static", + "parking_lot 0.12.1", + "rustc-hash", +] + [[package]] name = "time" version = "0.1.45" diff --git a/crates/ai/Cargo.toml b/crates/ai/Cargo.toml index 9052b1e5ed..e36df880d9 100644 --- a/crates/ai/Cargo.toml +++ b/crates/ai/Cargo.toml @@ -29,6 +29,7 @@ isahc.workspace = true schemars.workspace = true serde.workspace = true serde_json.workspace = true +tiktoken-rs = "0.4" [dev-dependencies] editor = { path = "../editor", features = ["test-support"] } diff --git a/crates/ai/src/assistant.rs b/crates/ai/src/assistant.rs index 39e9a6ba15..68f722f1ee 100644 --- a/crates/ai/src/assistant.rs +++ b/crates/ai/src/assistant.rs @@ -16,7 +16,8 @@ use gpui::{ use isahc::{http::StatusCode, Request, RequestExt}; use language::{language_settings::SoftWrap, Buffer, LanguageRegistry}; use settings::SettingsStore; -use std::{cell::Cell, io, rc::Rc, sync::Arc}; +use std::{cell::Cell, io, rc::Rc, sync::Arc, time::Duration}; +use tiktoken_rs::model::get_context_size; use util::{post_inc, ResultExt, TryFutureExt}; use workspace::{ dock::{DockPosition, Panel}, @@ -398,7 +399,12 @@ struct Assistant { completion_count: usize, pending_completions: Vec, languages: Arc, + model: String, + token_count: Option, + max_token_count: usize, + pending_token_count: Task>, api_key: Rc>>, + _subscriptions: Vec, } impl Entity for Assistant { @@ -411,19 +417,78 @@ impl Assistant { language_registry: Arc, cx: &mut ModelContext, ) -> Self { + let model = "gpt-3.5-turbo"; + let buffer = cx.add_model(|_| MultiBuffer::new(0)); let mut this = Self { - buffer: cx.add_model(|_| MultiBuffer::new(0)), messages: Default::default(), messages_by_id: Default::default(), completion_count: Default::default(), pending_completions: Default::default(), languages: language_registry, + token_count: None, + max_token_count: get_context_size(model), + pending_token_count: Task::ready(None), + model: model.into(), + _subscriptions: vec![cx.subscribe(&buffer, Self::handle_buffer_event)], api_key, + buffer, }; this.push_message(Role::User, cx); + this.count_remaining_tokens(cx); this } + fn handle_buffer_event( + &mut self, + _: ModelHandle, + event: &editor::multi_buffer::Event, + cx: &mut ModelContext, + ) { + match event { + editor::multi_buffer::Event::ExcerptsAdded { .. } + | editor::multi_buffer::Event::ExcerptsRemoved { .. } + | editor::multi_buffer::Event::Edited => self.count_remaining_tokens(cx), + _ => {} + } + } + + fn count_remaining_tokens(&mut self, cx: &mut ModelContext) { + let messages = self + .messages + .iter() + .map(|message| tiktoken_rs::ChatCompletionRequestMessage { + role: match message.role { + Role::User => "user".into(), + Role::Assistant => "assistant".into(), + Role::System => "system".into(), + }, + content: message.content.read(cx).text(), + name: None, + }) + .collect::>(); + let model = self.model.clone(); + self.pending_token_count = cx.spawn(|this, mut cx| { + async move { + cx.background().timer(Duration::from_millis(200)).await; + let token_count = cx + .background() + .spawn(async move { tiktoken_rs::num_tokens_from_messages(&model, &messages) }) + .await?; + + this.update(&mut cx, |this, cx| { + this.token_count = Some(token_count); + cx.notify() + }); + anyhow::Ok(()) + } + .log_err() + }); + } + + fn remaining_tokens(&self) -> Option { + Some(self.max_token_count as isize - self.token_count? as isize) + } + fn assist(&mut self, cx: &mut ModelContext) { let messages = self .messages @@ -434,7 +499,7 @@ impl Assistant { }) .collect(); let request = OpenAIRequest { - model: "gpt-3.5-turbo".into(), + model: self.model.clone(), messages, stream: true, }; @@ -530,6 +595,7 @@ struct PendingCompletion { struct AssistantEditor { assistant: ModelHandle, editor: ViewHandle, + _subscriptions: Vec, } impl AssistantEditor { @@ -590,7 +656,11 @@ impl AssistantEditor { ); editor }); - Self { assistant, editor } + Self { + _subscriptions: vec![cx.observe(&assistant, |_, _, cx| cx.notify())], + assistant, + editor, + } } fn assist(&mut self, _: &Assist, cx: &mut ViewContext) { @@ -684,10 +754,34 @@ impl View for AssistantEditor { fn render(&mut self, cx: &mut ViewContext) -> AnyElement { let theme = &theme::current(cx).assistant; + let remaining_tokens = self + .assistant + .read(cx) + .remaining_tokens() + .map(|remaining_tokens| { + let remaining_tokens_style = if remaining_tokens <= 0 { + &theme.no_remaining_tokens + } else { + &theme.remaining_tokens + }; + Label::new( + remaining_tokens.to_string(), + remaining_tokens_style.text.clone(), + ) + .contained() + .with_style(remaining_tokens_style.container) + .aligned() + .top() + .right() + }); - ChildView::new(&self.editor, cx) - .contained() - .with_style(theme.container) + Stack::new() + .with_child( + ChildView::new(&self.editor, cx) + .contained() + .with_style(theme.container), + ) + .with_children(remaining_tokens) .into_any() } diff --git a/crates/editor/src/editor.rs b/crates/editor/src/editor.rs index ee06443068..453468349b 100644 --- a/crates/editor/src/editor.rs +++ b/crates/editor/src/editor.rs @@ -10,7 +10,7 @@ pub mod items; mod link_go_to_definition; mod mouse_context_menu; pub mod movement; -mod multi_buffer; +pub mod multi_buffer; mod persistence; pub mod scroll; pub mod selections_collection; diff --git a/crates/theme/src/theme.rs b/crates/theme/src/theme.rs index 8282336ba5..97aac92afd 100644 --- a/crates/theme/src/theme.rs +++ b/crates/theme/src/theme.rs @@ -976,6 +976,8 @@ pub struct AssistantStyle { pub sent_at: ContainedText, pub user_sender: ContainedText, pub assistant_sender: ContainedText, + pub remaining_tokens: ContainedText, + pub no_remaining_tokens: ContainedText, pub api_key_editor: FieldEditor, pub api_key_prompt: ContainedText, } diff --git a/styles/src/styleTree/assistant.ts b/styles/src/styleTree/assistant.ts index 085e43071c..3d21ee8519 100644 --- a/styles/src/styleTree/assistant.ts +++ b/styles/src/styleTree/assistant.ts @@ -23,6 +23,20 @@ export default function assistant(colorScheme: ColorScheme) { margin: { top: 2, left: 8 }, ...text(layer, "sans", "default", { size: "2xs" }), }, + remaining_tokens: { + padding: 4, + margin: { right: 16, top: 4 }, + background: background(layer, "on"), + cornerRadius: 4, + ...text(layer, "sans", "positive", { size: "xs" }), + }, + no_remaining_tokens: { + padding: 4, + margin: { right: 16, top: 4 }, + background: background(layer, "on"), + cornerRadius: 4, + ...text(layer, "sans", "negative", { size: "xs" }), + }, apiKeyEditor: { background: background(layer, "on"), cornerRadius: 6,