mirror of
https://github.com/zed-industries/zed.git
synced 2024-12-26 07:12:03 +03:00
Show remaining tokens
This commit is contained in:
parent
3750e64d9f
commit
f00f16fe37
43
Cargo.lock
generated
43
Cargo.lock
generated
@ -116,6 +116,7 @@ dependencies = [
|
||||
"serde_json",
|
||||
"settings",
|
||||
"theme",
|
||||
"tiktoken-rs",
|
||||
"util",
|
||||
"workspace",
|
||||
]
|
||||
@ -745,6 +746,21 @@ dependencies = [
|
||||
"which",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "bit-set"
|
||||
version = "0.5.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "0700ddab506f33b20a03b13996eccd309a48e5ff77d0d95926aa0210fb4e95f1"
|
||||
dependencies = [
|
||||
"bit-vec",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "bit-vec"
|
||||
version = "0.6.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "349f9b6a179ed607305526ca489b34ad0a41aed5f7980fa90eb03160b69598fb"
|
||||
|
||||
[[package]]
|
||||
name = "bitflags"
|
||||
version = "1.3.2"
|
||||
@ -870,6 +886,8 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c3d4260bcc2e8fc9df1eac4919a720effeb63a3f0952f5bf4944adfa18897f09"
|
||||
dependencies = [
|
||||
"memchr",
|
||||
"once_cell",
|
||||
"regex-automata",
|
||||
"serde",
|
||||
]
|
||||
|
||||
@ -2220,6 +2238,16 @@ version = "0.2.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "4443176a9f2c162692bd3d352d745ef9413eec5782a80d8fd6f8a1ac692a07f7"
|
||||
|
||||
[[package]]
|
||||
name = "fancy-regex"
|
||||
version = "0.11.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b95f7c0680e4142284cf8b22c14a476e87d61b004a3a0861872b32ef7ead40a2"
|
||||
dependencies = [
|
||||
"bit-set",
|
||||
"regex",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "fastrand"
|
||||
version = "1.9.0"
|
||||
@ -6969,6 +6997,21 @@ dependencies = [
|
||||
"weezl",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tiktoken-rs"
|
||||
version = "0.4.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "8ba161c549e2c0686f35f5d920e63fad5cafba2c28ad2caceaf07e5d9fa6e8c4"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"base64 0.21.0",
|
||||
"bstr",
|
||||
"fancy-regex",
|
||||
"lazy_static",
|
||||
"parking_lot 0.12.1",
|
||||
"rustc-hash",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "time"
|
||||
version = "0.1.45"
|
||||
|
@ -29,6 +29,7 @@ isahc.workspace = true
|
||||
schemars.workspace = true
|
||||
serde.workspace = true
|
||||
serde_json.workspace = true
|
||||
tiktoken-rs = "0.4"
|
||||
|
||||
[dev-dependencies]
|
||||
editor = { path = "../editor", features = ["test-support"] }
|
||||
|
@ -16,7 +16,8 @@ use gpui::{
|
||||
use isahc::{http::StatusCode, Request, RequestExt};
|
||||
use language::{language_settings::SoftWrap, Buffer, LanguageRegistry};
|
||||
use settings::SettingsStore;
|
||||
use std::{cell::Cell, io, rc::Rc, sync::Arc};
|
||||
use std::{cell::Cell, io, rc::Rc, sync::Arc, time::Duration};
|
||||
use tiktoken_rs::model::get_context_size;
|
||||
use util::{post_inc, ResultExt, TryFutureExt};
|
||||
use workspace::{
|
||||
dock::{DockPosition, Panel},
|
||||
@ -398,7 +399,12 @@ struct Assistant {
|
||||
completion_count: usize,
|
||||
pending_completions: Vec<PendingCompletion>,
|
||||
languages: Arc<LanguageRegistry>,
|
||||
model: String,
|
||||
token_count: Option<usize>,
|
||||
max_token_count: usize,
|
||||
pending_token_count: Task<Option<()>>,
|
||||
api_key: Rc<Cell<Option<String>>>,
|
||||
_subscriptions: Vec<Subscription>,
|
||||
}
|
||||
|
||||
impl Entity for Assistant {
|
||||
@ -411,19 +417,78 @@ impl Assistant {
|
||||
language_registry: Arc<LanguageRegistry>,
|
||||
cx: &mut ModelContext<Self>,
|
||||
) -> Self {
|
||||
let model = "gpt-3.5-turbo";
|
||||
let buffer = cx.add_model(|_| MultiBuffer::new(0));
|
||||
let mut this = Self {
|
||||
buffer: cx.add_model(|_| MultiBuffer::new(0)),
|
||||
messages: Default::default(),
|
||||
messages_by_id: Default::default(),
|
||||
completion_count: Default::default(),
|
||||
pending_completions: Default::default(),
|
||||
languages: language_registry,
|
||||
token_count: None,
|
||||
max_token_count: get_context_size(model),
|
||||
pending_token_count: Task::ready(None),
|
||||
model: model.into(),
|
||||
_subscriptions: vec![cx.subscribe(&buffer, Self::handle_buffer_event)],
|
||||
api_key,
|
||||
buffer,
|
||||
};
|
||||
this.push_message(Role::User, cx);
|
||||
this.count_remaining_tokens(cx);
|
||||
this
|
||||
}
|
||||
|
||||
fn handle_buffer_event(
|
||||
&mut self,
|
||||
_: ModelHandle<MultiBuffer>,
|
||||
event: &editor::multi_buffer::Event,
|
||||
cx: &mut ModelContext<Self>,
|
||||
) {
|
||||
match event {
|
||||
editor::multi_buffer::Event::ExcerptsAdded { .. }
|
||||
| editor::multi_buffer::Event::ExcerptsRemoved { .. }
|
||||
| editor::multi_buffer::Event::Edited => self.count_remaining_tokens(cx),
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
fn count_remaining_tokens(&mut self, cx: &mut ModelContext<Self>) {
|
||||
let messages = self
|
||||
.messages
|
||||
.iter()
|
||||
.map(|message| tiktoken_rs::ChatCompletionRequestMessage {
|
||||
role: match message.role {
|
||||
Role::User => "user".into(),
|
||||
Role::Assistant => "assistant".into(),
|
||||
Role::System => "system".into(),
|
||||
},
|
||||
content: message.content.read(cx).text(),
|
||||
name: None,
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
let model = self.model.clone();
|
||||
self.pending_token_count = cx.spawn(|this, mut cx| {
|
||||
async move {
|
||||
cx.background().timer(Duration::from_millis(200)).await;
|
||||
let token_count = cx
|
||||
.background()
|
||||
.spawn(async move { tiktoken_rs::num_tokens_from_messages(&model, &messages) })
|
||||
.await?;
|
||||
|
||||
this.update(&mut cx, |this, cx| {
|
||||
this.token_count = Some(token_count);
|
||||
cx.notify()
|
||||
});
|
||||
anyhow::Ok(())
|
||||
}
|
||||
.log_err()
|
||||
});
|
||||
}
|
||||
|
||||
fn remaining_tokens(&self) -> Option<isize> {
|
||||
Some(self.max_token_count as isize - self.token_count? as isize)
|
||||
}
|
||||
|
||||
fn assist(&mut self, cx: &mut ModelContext<Self>) {
|
||||
let messages = self
|
||||
.messages
|
||||
@ -434,7 +499,7 @@ impl Assistant {
|
||||
})
|
||||
.collect();
|
||||
let request = OpenAIRequest {
|
||||
model: "gpt-3.5-turbo".into(),
|
||||
model: self.model.clone(),
|
||||
messages,
|
||||
stream: true,
|
||||
};
|
||||
@ -530,6 +595,7 @@ struct PendingCompletion {
|
||||
struct AssistantEditor {
|
||||
assistant: ModelHandle<Assistant>,
|
||||
editor: ViewHandle<Editor>,
|
||||
_subscriptions: Vec<Subscription>,
|
||||
}
|
||||
|
||||
impl AssistantEditor {
|
||||
@ -590,7 +656,11 @@ impl AssistantEditor {
|
||||
);
|
||||
editor
|
||||
});
|
||||
Self { assistant, editor }
|
||||
Self {
|
||||
_subscriptions: vec![cx.observe(&assistant, |_, _, cx| cx.notify())],
|
||||
assistant,
|
||||
editor,
|
||||
}
|
||||
}
|
||||
|
||||
fn assist(&mut self, _: &Assist, cx: &mut ViewContext<Self>) {
|
||||
@ -684,10 +754,34 @@ impl View for AssistantEditor {
|
||||
|
||||
fn render(&mut self, cx: &mut ViewContext<Self>) -> AnyElement<Self> {
|
||||
let theme = &theme::current(cx).assistant;
|
||||
let remaining_tokens = self
|
||||
.assistant
|
||||
.read(cx)
|
||||
.remaining_tokens()
|
||||
.map(|remaining_tokens| {
|
||||
let remaining_tokens_style = if remaining_tokens <= 0 {
|
||||
&theme.no_remaining_tokens
|
||||
} else {
|
||||
&theme.remaining_tokens
|
||||
};
|
||||
Label::new(
|
||||
remaining_tokens.to_string(),
|
||||
remaining_tokens_style.text.clone(),
|
||||
)
|
||||
.contained()
|
||||
.with_style(remaining_tokens_style.container)
|
||||
.aligned()
|
||||
.top()
|
||||
.right()
|
||||
});
|
||||
|
||||
ChildView::new(&self.editor, cx)
|
||||
.contained()
|
||||
.with_style(theme.container)
|
||||
Stack::new()
|
||||
.with_child(
|
||||
ChildView::new(&self.editor, cx)
|
||||
.contained()
|
||||
.with_style(theme.container),
|
||||
)
|
||||
.with_children(remaining_tokens)
|
||||
.into_any()
|
||||
}
|
||||
|
||||
|
@ -10,7 +10,7 @@ pub mod items;
|
||||
mod link_go_to_definition;
|
||||
mod mouse_context_menu;
|
||||
pub mod movement;
|
||||
mod multi_buffer;
|
||||
pub mod multi_buffer;
|
||||
mod persistence;
|
||||
pub mod scroll;
|
||||
pub mod selections_collection;
|
||||
|
@ -976,6 +976,8 @@ pub struct AssistantStyle {
|
||||
pub sent_at: ContainedText,
|
||||
pub user_sender: ContainedText,
|
||||
pub assistant_sender: ContainedText,
|
||||
pub remaining_tokens: ContainedText,
|
||||
pub no_remaining_tokens: ContainedText,
|
||||
pub api_key_editor: FieldEditor,
|
||||
pub api_key_prompt: ContainedText,
|
||||
}
|
||||
|
@ -23,6 +23,20 @@ export default function assistant(colorScheme: ColorScheme) {
|
||||
margin: { top: 2, left: 8 },
|
||||
...text(layer, "sans", "default", { size: "2xs" }),
|
||||
},
|
||||
remaining_tokens: {
|
||||
padding: 4,
|
||||
margin: { right: 16, top: 4 },
|
||||
background: background(layer, "on"),
|
||||
cornerRadius: 4,
|
||||
...text(layer, "sans", "positive", { size: "xs" }),
|
||||
},
|
||||
no_remaining_tokens: {
|
||||
padding: 4,
|
||||
margin: { right: 16, top: 4 },
|
||||
background: background(layer, "on"),
|
||||
cornerRadius: 4,
|
||||
...text(layer, "sans", "negative", { size: "xs" }),
|
||||
},
|
||||
apiKeyEditor: {
|
||||
background: background(layer, "on"),
|
||||
cornerRadius: 6,
|
||||
|
Loading…
Reference in New Issue
Block a user