Add a default_open_ai_model setting for the assistant (#2876)
[This PR has been sitting around for a bit](https://github.com/zed-industries/zed/pull/2845). The team had mixed opinions on how this setting should work, e.g. whether it should use the full model names or some simpler form of them. I went ahead and made the following decisions:

- Use the full model names in settings, e.g. `gpt-4-0613`
- Default to `gpt-4-0613` when no setting is present
- Save the full model names in the conversation history files (as before), e.g. `gpt-4-0613`
- Display the shortened model names in the assistant, e.g. `gpt-4`
- Don't add an option for custom models yet (can be added in a follow-up PR)
- Don't query which models are available to the user via their API key yet (can be added in a follow-up PR)

Release Notes:

- Added a `default_open_ai_model` setting for the assistant (defaults to `gpt-4-0613`).

Co-authored-by: Mikayla <mikayla@zed.dev>
This commit is contained in:
parent 3a13795021
commit a836f9c23d
@@ -138,7 +138,13 @@
     // Default width when the assistant is docked to the left or right.
     "default_width": 640,
     // Default height when the assistant is docked to the bottom.
-    "default_height": 320
+    "default_height": 320,
+    // The default OpenAI model to use when starting new conversations. This
+    // setting can take two values:
+    //
+    // 1. "gpt-3.5-turbo-0613"
+    // 2. "gpt-4-0613"
+    "default_open_ai_model": "gpt-4-0613"
   },
   // Whether the screen sharing icon is shown in the os status bar.
   "show_call_status_icon": true,
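Usage note (not part of the diff): a user who prefers the cheaper model can override the new default in their own settings file. A minimal sketch, assuming the key lives under the same "assistant" section as the dock and size settings above (the section name is inferred from context, not shown in this hunk):

    // User settings override (illustrative)
    {
      "assistant": {
        "default_open_ai_model": "gpt-3.5-turbo-0613"
      }
    }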
@@ -3,6 +3,7 @@ mod assistant_settings;
 
 use anyhow::Result;
 pub use assistant::AssistantPanel;
+use assistant_settings::OpenAIModel;
 use chrono::{DateTime, Local};
 use collections::HashMap;
 use fs::Fs;
@@ -60,7 +61,7 @@ struct SavedConversation {
     messages: Vec<SavedMessage>,
     message_metadata: HashMap<MessageId, MessageMetadata>,
     summary: String,
-    model: String,
+    model: OpenAIModel,
 }
 
 impl SavedConversation {
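Worth calling out: `SavedConversation::model` changes from `String` to `OpenAIModel`, and previously saved conversation files keep loading only because the serde renames on `OpenAIModel` (added later in this diff) match the full model strings that were stored before. A minimal round-trip sketch, assuming `serde_json` and the PR's `OpenAIModel` are in scope:

    // A previously saved conversation stored the model as a plain string.
    let old_value = "\"gpt-4-0613\"";
    // With the serde rename, that same string now parses into the enum...
    let model: OpenAIModel = serde_json::from_str(old_value).unwrap();
    // ...and serializing it back produces the identical full model name.
    assert_eq!(serde_json::to_string(&model).unwrap(), old_value);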
@@ -1,5 +1,5 @@
 use crate::{
-    assistant_settings::{AssistantDockPosition, AssistantSettings},
+    assistant_settings::{AssistantDockPosition, AssistantSettings, OpenAIModel},
     MessageId, MessageMetadata, MessageStatus, OpenAIRequest, OpenAIResponseStreamEvent,
     RequestMessage, Role, SavedConversation, SavedConversationMetadata, SavedMessage,
 };
@@ -833,7 +833,7 @@ struct Conversation {
     pending_summary: Task<Option<()>>,
     completion_count: usize,
     pending_completions: Vec<PendingCompletion>,
-    model: String,
+    model: OpenAIModel,
     token_count: Option<usize>,
     max_token_count: usize,
     pending_token_count: Task<Option<()>>,
@@ -853,7 +853,6 @@ impl Conversation {
         language_registry: Arc<LanguageRegistry>,
         cx: &mut ModelContext<Self>,
     ) -> Self {
-        let model = "gpt-3.5-turbo-0613";
         let markdown = language_registry.language_for_name("Markdown");
         let buffer = cx.add_model(|cx| {
             let mut buffer = Buffer::new(0, "", cx);
@@ -872,6 +871,9 @@ impl Conversation {
             buffer
         });
 
+        let settings = settings::get::<AssistantSettings>(cx);
+        let model = settings.default_open_ai_model.clone();
+
         let mut this = Self {
             message_anchors: Default::default(),
             messages_metadata: Default::default(),
@@ -881,9 +883,9 @@ impl Conversation {
             completion_count: Default::default(),
             pending_completions: Default::default(),
             token_count: None,
-            max_token_count: tiktoken_rs::model::get_context_size(model),
+            max_token_count: tiktoken_rs::model::get_context_size(&model.full_name()),
             pending_token_count: Task::ready(None),
-            model: model.into(),
+            model: model.clone(),
             _subscriptions: vec![cx.subscribe(&buffer, Self::handle_buffer_event)],
             pending_save: Task::ready(Ok(())),
             path: None,
@@ -977,7 +979,7 @@ impl Conversation {
             completion_count: Default::default(),
             pending_completions: Default::default(),
             token_count: None,
-            max_token_count: tiktoken_rs::model::get_context_size(&model),
+            max_token_count: tiktoken_rs::model::get_context_size(&model.full_name()),
             pending_token_count: Task::ready(None),
             model,
             _subscriptions: vec![cx.subscribe(&buffer, Self::handle_buffer_event)],
@@ -1031,13 +1033,16 @@ impl Conversation {
             cx.background().timer(Duration::from_millis(200)).await;
             let token_count = cx
                 .background()
-                .spawn(async move { tiktoken_rs::num_tokens_from_messages(&model, &messages) })
+                .spawn(async move {
+                    tiktoken_rs::num_tokens_from_messages(&model.full_name(), &messages)
+                })
                 .await?;
 
             this.upgrade(&cx)
                 .ok_or_else(|| anyhow!("conversation was dropped"))?
                 .update(&mut cx, |this, cx| {
-                    this.max_token_count = tiktoken_rs::model::get_context_size(&this.model);
+                    this.max_token_count =
+                        tiktoken_rs::model::get_context_size(&this.model.full_name());
                     this.token_count = Some(token_count);
                     cx.notify()
                 });
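The tiktoken_rs calls above are the reason `full_name()` exists: the crate resolves context windows and token counts from a model string, so the code hands it the full dated identifier rather than the short label shown in the UI. A rough sketch of the remaining-token math this feeds (hypothetical helper, assuming the PR's `OpenAIModel` is in scope):

    // Hypothetical helper mirroring the counter logic in this file:
    // the context window for the full model name minus the tokens already used.
    fn remaining_tokens(model: &OpenAIModel, used_tokens: usize) -> isize {
        let max = tiktoken_rs::model::get_context_size(model.full_name());
        max as isize - used_tokens as isize
    }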
@@ -1051,7 +1056,7 @@ impl Conversation {
         Some(self.max_token_count as isize - self.token_count? as isize)
     }
 
-    fn set_model(&mut self, model: String, cx: &mut ModelContext<Self>) {
+    fn set_model(&mut self, model: OpenAIModel, cx: &mut ModelContext<Self>) {
         self.model = model;
         self.count_remaining_tokens(cx);
         cx.notify();
@@ -1093,7 +1098,7 @@ impl Conversation {
             }
         } else {
             let request = OpenAIRequest {
-                model: self.model.clone(),
+                model: self.model.full_name().to_string(),
                 messages: self
                     .messages(cx)
                     .filter(|message| matches!(message.status, MessageStatus::Done))
@@ -1419,7 +1424,7 @@ impl Conversation {
                 .into(),
         }));
         let request = OpenAIRequest {
-            model: self.model.clone(),
+            model: self.model.full_name().to_string(),
             messages: messages.collect(),
             stream: true,
         };
@@ -2023,11 +2028,8 @@ impl ConversationEditor {
 
     fn cycle_model(&mut self, cx: &mut ViewContext<Self>) {
         self.conversation.update(cx, |conversation, cx| {
-            let new_model = match conversation.model.as_str() {
-                "gpt-4-0613" => "gpt-3.5-turbo-0613",
-                _ => "gpt-4-0613",
-            };
-            conversation.set_model(new_model.into(), cx);
+            let new_model = conversation.model.cycle();
+            conversation.set_model(new_model, cx);
         });
     }
 
@@ -2049,7 +2051,8 @@ impl ConversationEditor {
 
         MouseEventHandler::new::<Model, _>(0, cx, |state, cx| {
             let style = style.model.style_for(state);
-            Label::new(self.conversation.read(cx).model.clone(), style.text.clone())
+            let model_display_name = self.conversation.read(cx).model.short_name();
+            Label::new(model_display_name, style.text.clone())
                 .contained()
                 .with_style(style.container)
         })
@@ -2238,6 +2241,8 @@ mod tests {
 
     #[gpui::test]
     fn test_inserting_and_removing_messages(cx: &mut AppContext) {
+        cx.set_global(SettingsStore::test(cx));
+        init(cx);
         let registry = Arc::new(LanguageRegistry::test());
         let conversation = cx.add_model(|cx| Conversation::new(Default::default(), registry, cx));
         let buffer = conversation.read(cx).buffer.clone();
@@ -2364,6 +2369,8 @@ mod tests {
 
     #[gpui::test]
     fn test_message_splitting(cx: &mut AppContext) {
+        cx.set_global(SettingsStore::test(cx));
+        init(cx);
         let registry = Arc::new(LanguageRegistry::test());
         let conversation = cx.add_model(|cx| Conversation::new(Default::default(), registry, cx));
         let buffer = conversation.read(cx).buffer.clone();
@@ -2458,6 +2465,8 @@ mod tests {
 
     #[gpui::test]
     fn test_messages_for_offsets(cx: &mut AppContext) {
+        cx.set_global(SettingsStore::test(cx));
+        init(cx);
         let registry = Arc::new(LanguageRegistry::test());
         let conversation = cx.add_model(|cx| Conversation::new(Default::default(), registry, cx));
         let buffer = conversation.read(cx).buffer.clone();
@@ -2538,6 +2547,8 @@ mod tests {
 
     #[gpui::test]
     fn test_serialization(cx: &mut AppContext) {
+        cx.set_global(SettingsStore::test(cx));
+        init(cx);
         let registry = Arc::new(LanguageRegistry::test());
         let conversation =
             cx.add_model(|cx| Conversation::new(Default::default(), registry.clone(), cx));
@@ -3,6 +3,37 @@ use schemars::JsonSchema;
 use serde::{Deserialize, Serialize};
 use settings::Setting;
 
+#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema, PartialEq)]
+pub enum OpenAIModel {
+    #[serde(rename = "gpt-3.5-turbo-0613")]
+    ThreePointFiveTurbo,
+    #[serde(rename = "gpt-4-0613")]
+    Four,
+}
+
+impl OpenAIModel {
+    pub fn full_name(&self) -> &'static str {
+        match self {
+            OpenAIModel::ThreePointFiveTurbo => "gpt-3.5-turbo-0613",
+            OpenAIModel::Four => "gpt-4-0613",
+        }
+    }
+
+    pub fn short_name(&self) -> &'static str {
+        match self {
+            OpenAIModel::ThreePointFiveTurbo => "gpt-3.5-turbo",
+            OpenAIModel::Four => "gpt-4",
+        }
+    }
+
+    pub fn cycle(&self) -> Self {
+        match self {
+            OpenAIModel::ThreePointFiveTurbo => OpenAIModel::Four,
+            OpenAIModel::Four => OpenAIModel::ThreePointFiveTurbo,
+        }
+    }
+}
+
 #[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
 #[serde(rename_all = "snake_case")]
 pub enum AssistantDockPosition {
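To make the naming split concrete, here is a small usage sketch (not from the PR) of the helpers above, assuming only that `OpenAIModel` is in scope:

    let model = OpenAIModel::Four;
    // Settings and saved conversations store the full, dated identifier...
    assert_eq!(model.full_name(), "gpt-4-0613");
    // ...while the assistant displays the shortened name.
    assert_eq!(model.short_name(), "gpt-4");
    // cycle() toggles between the two supported models.
    assert_eq!(model.cycle(), OpenAIModel::ThreePointFiveTurbo);
    assert_eq!(model.cycle().cycle(), OpenAIModel::Four);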
@@ -17,6 +48,7 @@ pub struct AssistantSettings {
     pub dock: AssistantDockPosition,
     pub default_width: f32,
     pub default_height: f32,
+    pub default_open_ai_model: OpenAIModel,
 }
 
 #[derive(Clone, Default, Serialize, Deserialize, JsonSchema, Debug)]
@@ -25,6 +57,7 @@ pub struct AssistantSettingsContent {
     pub dock: Option<AssistantDockPosition>,
     pub default_width: Option<f32>,
     pub default_height: Option<f32>,
+    pub default_open_ai_model: Option<OpenAIModel>,
 }
 
 impl Setting for AssistantSettings {
@@ -1,4 +1,4 @@
-use anyhow::{anyhow, Result};
+use anyhow::{anyhow, Context, Result};
 use collections::{btree_map, hash_map, BTreeMap, HashMap};
 use gpui::AppContext;
 use lazy_static::lazy_static;
@@ -162,6 +162,7 @@ impl SettingsStore {
 
         if let Some(setting) = setting_value
             .load_setting(&default_settings, &user_values_stack, cx)
+            .context("A default setting must be added to the `default.json` file")
             .log_err()
         {
             setting_value.set_global_value(setting);