diff --git a/crates/ai/src/providers/open_ai.rs b/crates/ai/src/providers/open_ai.rs
new file mode 100644
index 0000000000..9de21b8a60
--- /dev/null
+++ b/crates/ai/src/providers/open_ai.rs
@@ -0,0 +1,9 @@
+pub mod completion;
+pub mod embedding;
+pub mod model;
+
+pub use completion::*;
+pub use embedding::*;
+pub use model::OpenAiLanguageModel;
+
+pub const OPEN_AI_API_URL: &'static str = "https://api.openai.com/v1";
diff --git a/crates/ai/src/providers/open_ai/completion.rs b/crates/ai/src/providers/open_ai/completion.rs
index aa58950113..4bdb94d79b 100644
--- a/crates/ai/src/providers/open_ai/completion.rs
+++ b/crates/ai/src/providers/open_ai/completion.rs
@@ -21,7 +21,7 @@ use crate::{
     models::LanguageModel,
 };
 
-use crate::providers::open_ai::{OpenAILanguageModel, OPENAI_API_URL};
+use crate::providers::open_ai::{OpenAiLanguageModel, OPEN_AI_API_URL};
 
 #[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
 #[serde(rename_all = "lowercase")]
@@ -58,7 +58,7 @@ pub struct RequestMessage {
 }
 
 #[derive(Debug, Default, Serialize)]
-pub struct OpenAIRequest {
+pub struct OpenAiRequest {
     pub model: String,
     pub messages: Vec<RequestMessage>,
     pub stream: bool,
@@ -66,7 +66,7 @@ pub struct OpenAIRequest {
     pub temperature: f32,
 }
 
-impl CompletionRequest for OpenAIRequest {
+impl CompletionRequest for OpenAiRequest {
     fn data(&self) -> serde_json::Result<String> {
         serde_json::to_string(self)
     }
@@ -79,7 +79,7 @@ pub struct ResponseMessage {
 }
 
 #[derive(Deserialize, Debug)]
-pub struct OpenAIUsage {
+pub struct OpenAiUsage {
     pub prompt_tokens: u32,
     pub completion_tokens: u32,
     pub total_tokens: u32,
@@ -93,20 +93,20 @@ pub struct ChatChoiceDelta {
 }
 
 #[derive(Deserialize, Debug)]
-pub struct OpenAIResponseStreamEvent {
+pub struct OpenAiResponseStreamEvent {
     pub id: Option<String>,
     pub object: String,
     pub created: u32,
     pub model: String,
     pub choices: Vec<ChatChoiceDelta>,
-    pub usage: Option<OpenAIUsage>,
+    pub usage: Option<OpenAiUsage>,
 }
 
 pub async fn stream_completion(
     credential: ProviderCredential,
     executor: BackgroundExecutor,
     request: Box<dyn CompletionRequest>,
-) -> Result<impl Stream<Item = Result<OpenAIResponseStreamEvent>>> {
+) -> Result<impl Stream<Item = Result<OpenAiResponseStreamEvent>>> {
     let api_key = match credential {
         ProviderCredential::Credentials { api_key } => api_key,
         _ => {
@@ -114,10 +114,10 @@ pub async fn stream_completion(
         }
     };
 
-    let (tx, rx) = futures::channel::mpsc::unbounded::<Result<OpenAIResponseStreamEvent>>();
+    let (tx, rx) = futures::channel::mpsc::unbounded::<Result<OpenAiResponseStreamEvent>>();
 
     let json_data = request.data()?;
-    let mut response = Request::post(format!("{OPENAI_API_URL}/chat/completions"))
+    let mut response = Request::post(format!("{OPEN_AI_API_URL}/chat/completions"))
         .header("Content-Type", "application/json")
         .header("Authorization", format!("Bearer {}", api_key))
         .body(json_data)?
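Reviewer note (not part of the patch): since OpenAiRequest still derives Default, the rename is purely mechanical for callers. A minimal sketch of a post-rename call site, assuming the existing Role::User variant and public RequestMessage fields; the message content and the use of ..Default::default() are illustrative:

use ai::completion::CompletionRequest;
use ai::providers::open_ai::{OpenAiRequest, RequestMessage, Role};

// Illustrative only: build a streaming chat request with the renamed types.
fn example_request() -> Box<dyn CompletionRequest> {
    Box::new(OpenAiRequest {
        model: "gpt-4".into(),
        messages: vec![RequestMessage {
            role: Role::User,
            content: "Say hello.".into(),
        }],
        stream: true,
        // Any remaining fields (e.g. temperature) keep their Default values.
        ..Default::default()
    })
}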
@@ -132,7 +132,7 @@ pub async fn stream_completion(
 
         fn parse_line(
             line: Result<String, io::Error>,
-        ) -> Result<Option<OpenAIResponseStreamEvent>> {
+        ) -> Result<Option<OpenAiResponseStreamEvent>> {
             if let Some(data) = line?.strip_prefix("data: ") {
                 let event = serde_json::from_str(data)?;
                 Ok(Some(event))
@@ -169,16 +169,16 @@ pub async fn stream_completion(
         response.body_mut().read_to_string(&mut body).await?;
 
         #[derive(Deserialize)]
-        struct OpenAIResponse {
-            error: OpenAIError,
+        struct OpenAiResponse {
+            error: OpenAiError,
         }
 
         #[derive(Deserialize)]
-        struct OpenAIError {
+        struct OpenAiError {
             message: String,
         }
 
-        match serde_json::from_str::<OpenAIResponse>(&body) {
+        match serde_json::from_str::<OpenAiResponse>(&body) {
             Ok(response) if !response.error.message.is_empty() => Err(anyhow!(
                 "Failed to connect to OpenAI API: {}",
                 response.error.message,
@@ -194,16 +194,16 @@ pub async fn stream_completion(
 }
 
 #[derive(Clone)]
-pub struct OpenAICompletionProvider {
-    model: OpenAILanguageModel,
+pub struct OpenAiCompletionProvider {
+    model: OpenAiLanguageModel,
     credential: Arc<RwLock<ProviderCredential>>,
     executor: BackgroundExecutor,
 }
 
-impl OpenAICompletionProvider {
+impl OpenAiCompletionProvider {
     pub async fn new(model_name: String, executor: BackgroundExecutor) -> Self {
         let model = executor
-            .spawn(async move { OpenAILanguageModel::load(&model_name) })
+            .spawn(async move { OpenAiLanguageModel::load(&model_name) })
             .await;
         let credential = Arc::new(RwLock::new(ProviderCredential::NoCredentials));
         Self {
@@ -214,7 +214,7 @@ impl OpenAICompletionProvider {
     }
 }
 
-impl CredentialProvider for OpenAICompletionProvider {
+impl CredentialProvider for OpenAiCompletionProvider {
     fn has_credentials(&self) -> bool {
         match *self.credential.read() {
             ProviderCredential::Credentials { .. } => true,
@@ -232,7 +232,7 @@ impl CredentialProvider for OpenAICompletionProvider {
         if let Some(api_key) = env::var("OPENAI_API_KEY").log_err() {
             async move { ProviderCredential::Credentials { api_key } }.boxed()
         } else {
-            let credentials = cx.read_credentials(OPENAI_API_URL);
+            let credentials = cx.read_credentials(OPEN_AI_API_URL);
             async move {
                 if let Some(Some((_, api_key))) = credentials.await.log_err() {
                     if let Some(api_key) = String::from_utf8(api_key).log_err() {
@@ -266,7 +266,7 @@ impl CredentialProvider for OpenAICompletionProvider {
         let credential = credential.clone();
         let write_credentials = match credential {
             ProviderCredential::Credentials { api_key } => {
-                Some(cx.write_credentials(OPENAI_API_URL, "Bearer", api_key.as_bytes()))
+                Some(cx.write_credentials(OPEN_AI_API_URL, "Bearer", api_key.as_bytes()))
             }
             _ => None,
         };
@@ -281,7 +281,7 @@ impl CredentialProvider for OpenAICompletionProvider {
 
     fn delete_credentials(&self, cx: &mut AppContext) -> BoxFuture<()> {
         *self.credential.write() = ProviderCredential::NoCredentials;
-        let delete_credentials = cx.delete_credentials(OPENAI_API_URL);
+        let delete_credentials = cx.delete_credentials(OPEN_AI_API_URL);
         async move {
             delete_credentials.await.log_err();
         }
@@ -289,7 +289,7 @@ impl CredentialProvider for OpenAICompletionProvider {
     }
 }
 
-impl CompletionProvider for OpenAICompletionProvider {
+impl CompletionProvider for OpenAiCompletionProvider {
     fn base_model(&self) -> Box<dyn LanguageModel> {
         let model: Box<dyn LanguageModel> = Box::new(self.model.clone());
         model
diff --git a/crates/ai/src/providers/open_ai/embedding.rs b/crates/ai/src/providers/open_ai/embedding.rs
index 89aebb1b76..7480a454a1 100644
--- a/crates/ai/src/providers/open_ai/embedding.rs
+++ b/crates/ai/src/providers/open_ai/embedding.rs
@@ -25,17 +25,17 @@ use util::ResultExt;
 use crate::auth::{CredentialProvider, ProviderCredential};
 use crate::embedding::{Embedding, EmbeddingProvider};
 use crate::models::LanguageModel;
-use crate::providers::open_ai::OpenAILanguageModel;
+use crate::providers::open_ai::OpenAiLanguageModel;
 
-use crate::providers::open_ai::OPENAI_API_URL;
+use crate::providers::open_ai::OPEN_AI_API_URL;
 
 lazy_static! {
-    static ref OPENAI_BPE_TOKENIZER: CoreBPE = cl100k_base().unwrap();
+    static ref OPEN_AI_BPE_TOKENIZER: CoreBPE = cl100k_base().unwrap();
 }
 
 #[derive(Clone)]
-pub struct OpenAIEmbeddingProvider {
-    model: OpenAILanguageModel,
+pub struct OpenAiEmbeddingProvider {
+    model: OpenAiLanguageModel,
     credential: Arc<RwLock<ProviderCredential>>,
     pub client: Arc<dyn HttpClient>,
     pub executor: BackgroundExecutor,
@@ -44,42 +44,42 @@ pub struct OpenAIEmbeddingProvider {
 }
 
 #[derive(Serialize)]
-struct OpenAIEmbeddingRequest<'a> {
+struct OpenAiEmbeddingRequest<'a> {
     model: &'static str,
     input: Vec<&'a str>,
 }
 
 #[derive(Deserialize)]
-struct OpenAIEmbeddingResponse {
-    data: Vec<OpenAIEmbedding>,
-    usage: OpenAIEmbeddingUsage,
+struct OpenAiEmbeddingResponse {
+    data: Vec<OpenAiEmbedding>,
+    usage: OpenAiEmbeddingUsage,
 }
 
 #[derive(Debug, Deserialize)]
-struct OpenAIEmbedding {
+struct OpenAiEmbedding {
     embedding: Vec<f32>,
     index: usize,
     object: String,
 }
 
 #[derive(Deserialize)]
-struct OpenAIEmbeddingUsage {
+struct OpenAiEmbeddingUsage {
     prompt_tokens: usize,
     total_tokens: usize,
 }
 
-impl OpenAIEmbeddingProvider {
+impl OpenAiEmbeddingProvider {
     pub async fn new(client: Arc<dyn HttpClient>, executor: BackgroundExecutor) -> Self {
         let (rate_limit_count_tx, rate_limit_count_rx) = watch::channel_with(None);
         let rate_limit_count_tx = Arc::new(Mutex::new(rate_limit_count_tx));
 
         // Loading the model is expensive, so ensure this runs off the main thread.
         let model = executor
-            .spawn(async move { OpenAILanguageModel::load("text-embedding-ada-002") })
+            .spawn(async move { OpenAiLanguageModel::load("text-embedding-ada-002") })
             .await;
         let credential = Arc::new(RwLock::new(ProviderCredential::NoCredentials));
 
-        OpenAIEmbeddingProvider {
+        OpenAiEmbeddingProvider {
             model,
             credential,
             client,
@@ -140,7 +140,7 @@ impl OpenAIEmbeddingProvider {
             .header("Content-Type", "application/json")
             .header("Authorization", format!("Bearer {}", api_key))
             .body(
-                serde_json::to_string(&OpenAIEmbeddingRequest {
+                serde_json::to_string(&OpenAiEmbeddingRequest {
                     input: spans.clone(),
                     model: "text-embedding-ada-002",
                 })
@@ -152,7 +152,7 @@ impl OpenAIEmbeddingProvider {
     }
 }
 
-impl CredentialProvider for OpenAIEmbeddingProvider {
+impl CredentialProvider for OpenAiEmbeddingProvider {
     fn has_credentials(&self) -> bool {
         match *self.credential.read() {
             ProviderCredential::Credentials { .. } => true,
@@ -170,7 +170,7 @@ impl CredentialProvider for OpenAIEmbeddingProvider {
         if let Some(api_key) = env::var("OPENAI_API_KEY").log_err() {
             async move { ProviderCredential::Credentials { api_key } }.boxed()
         } else {
-            let credentials = cx.read_credentials(OPENAI_API_URL);
+            let credentials = cx.read_credentials(OPEN_AI_API_URL);
             async move {
                 if let Some(Some((_, api_key))) = credentials.await.log_err() {
                     if let Some(api_key) = String::from_utf8(api_key).log_err() {
@@ -204,7 +204,7 @@ impl CredentialProvider for OpenAIEmbeddingProvider {
         let credential = credential.clone();
         let write_credentials = match credential {
             ProviderCredential::Credentials { api_key } => {
-                Some(cx.write_credentials(OPENAI_API_URL, "Bearer", api_key.as_bytes()))
+                Some(cx.write_credentials(OPEN_AI_API_URL, "Bearer", api_key.as_bytes()))
             }
             _ => None,
         };
@@ -219,7 +219,7 @@ impl CredentialProvider for OpenAIEmbeddingProvider {
 
     fn delete_credentials(&self, cx: &mut AppContext) -> BoxFuture<()> {
         *self.credential.write() = ProviderCredential::NoCredentials;
-        let delete_credentials = cx.delete_credentials(OPENAI_API_URL);
+        let delete_credentials = cx.delete_credentials(OPEN_AI_API_URL);
         async move {
             delete_credentials.await.log_err();
         }
@@ -228,7 +228,7 @@ impl CredentialProvider for OpenAIEmbeddingProvider {
 }
 
 #[async_trait]
-impl EmbeddingProvider for OpenAIEmbeddingProvider {
+impl EmbeddingProvider for OpenAiEmbeddingProvider {
     fn base_model(&self) -> Box<dyn LanguageModel> {
         let model: Box<dyn LanguageModel> = Box::new(self.model.clone());
         model
@@ -270,7 +270,7 @@ impl EmbeddingProvider for OpenAIEmbeddingProvider {
             StatusCode::OK => {
                 let mut body = String::new();
                 response.body_mut().read_to_string(&mut body).await?;
-                let response: OpenAIEmbeddingResponse = serde_json::from_str(&body)?;
+                let response: OpenAiEmbeddingResponse = serde_json::from_str(&body)?;
 
                 log::trace!(
                     "openai embedding completed. tokens: {:?}",
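Reviewer note (not part of the patch): both providers resolve credentials in the same order, the OPENAI_API_KEY environment variable first, then the credential store entry keyed by OPEN_AI_API_URL. A condensed sketch of that lookup order, simplified from the async, log_err-based code above:

use std::env;

// Simplified from the two CredentialProvider impls: the environment variable
// wins; otherwise fall back to bytes read from the credential store.
fn resolve_api_key(stored: Option<Vec<u8>>) -> Option<String> {
    if let Ok(api_key) = env::var("OPENAI_API_KEY") {
        return Some(api_key);
    }
    stored.and_then(|bytes| String::from_utf8(bytes).ok())
}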
diff --git a/crates/ai/src/providers/open_ai/mod.rs b/crates/ai/src/providers/open_ai/mod.rs
deleted file mode 100644
index 7d2f86045d..0000000000
--- a/crates/ai/src/providers/open_ai/mod.rs
+++ /dev/null
@@ -1,9 +0,0 @@
-pub mod completion;
-pub mod embedding;
-pub mod model;
-
-pub use completion::*;
-pub use embedding::*;
-pub use model::OpenAILanguageModel;
-
-pub const OPENAI_API_URL: &'static str = "https://api.openai.com/v1";
diff --git a/crates/ai/src/providers/open_ai/model.rs b/crates/ai/src/providers/open_ai/model.rs
index 6e306c80b9..ba3488d7dd 100644
--- a/crates/ai/src/providers/open_ai/model.rs
+++ b/crates/ai/src/providers/open_ai/model.rs
@@ -5,22 +5,22 @@ use util::ResultExt;
 use crate::models::{LanguageModel, TruncationDirection};
 
 #[derive(Clone)]
-pub struct OpenAILanguageModel {
+pub struct OpenAiLanguageModel {
     name: String,
     bpe: Option<CoreBPE>,
 }
 
-impl OpenAILanguageModel {
+impl OpenAiLanguageModel {
     pub fn load(model_name: &str) -> Self {
         let bpe = tiktoken_rs::get_bpe_from_model(model_name).log_err();
-        OpenAILanguageModel {
+        OpenAiLanguageModel {
             name: model_name.to_string(),
             bpe,
         }
     }
 }
 
-impl LanguageModel for OpenAILanguageModel {
+impl LanguageModel for OpenAiLanguageModel {
     fn name(&self) -> String {
         self.name.clone()
     }
diff --git a/crates/ai/src/providers/open_ai/new.rs b/crates/ai/src/providers/open_ai/new.rs
deleted file mode 100644
index c7d67f2ba1..0000000000
--- a/crates/ai/src/providers/open_ai/new.rs
+++ /dev/null
@@ -1,11 +0,0 @@
-pub trait LanguageModel {
-    fn name(&self) -> String;
-    fn count_tokens(&self, content: &str) -> anyhow::Result<usize>;
-    fn truncate(
-        &self,
-        content: &str,
-        length: usize,
-        direction: TruncationDirection,
-    ) -> anyhow::Result<String>;
-    fn capacity(&self) -> anyhow::Result<usize>;
-}
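Reviewer note (not part of the patch): the deleted new.rs duplicated the LanguageModel trait that already lives in crate::models, which OpenAiLanguageModel implements. A hedged sketch of typical usage; the model name, the 4096 budget, and TruncationDirection::End are illustrative assumptions:

use ai::models::{LanguageModel, TruncationDirection};
use ai::providers::open_ai::OpenAiLanguageModel;

// Illustrative: clamp a prompt to a token budget using the renamed type.
fn clamp_prompt(content: &str) -> anyhow::Result<String> {
    let model = OpenAiLanguageModel::load("gpt-4");
    let budget = model.capacity()?.min(4096); // assumed budget
    if model.count_tokens(content)? <= budget {
        Ok(content.to_string())
    } else {
        model.truncate(content, budget, TruncationDirection::End)
    }
}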
diff --git a/crates/assistant/src/assistant.rs b/crates/assistant/src/assistant.rs
index 743c8b22e6..d86d889aff 100644
--- a/crates/assistant/src/assistant.rs
+++ b/crates/assistant/src/assistant.rs
@@ -7,7 +7,7 @@ mod streaming_diff;
 use ai::providers::open_ai::Role;
 use anyhow::Result;
 pub use assistant_panel::AssistantPanel;
-use assistant_settings::OpenAIModel;
+use assistant_settings::OpenAiModel;
 use chrono::{DateTime, Local};
 use collections::HashMap;
 use fs::Fs;
@@ -68,7 +68,7 @@ struct SavedConversation {
     messages: Vec<SavedMessage>,
     message_metadata: HashMap<MessageId, MessageMetadata>,
     summary: String,
-    model: OpenAIModel,
+    model: OpenAiModel,
 }
 
 impl SavedConversation {
diff --git a/crates/assistant/src/assistant_panel.rs b/crates/assistant/src/assistant_panel.rs
index 3fcbb9a3c9..2488e2c763 100644
--- a/crates/assistant/src/assistant_panel.rs
+++ b/crates/assistant/src/assistant_panel.rs
@@ -1,5 +1,5 @@
 use crate::{
-    assistant_settings::{AssistantDockPosition, AssistantSettings, OpenAIModel},
+    assistant_settings::{AssistantDockPosition, AssistantSettings, OpenAiModel},
     codegen::{self, Codegen, CodegenKind},
     prompts::generate_content_prompt,
     Assist, CycleMessageRole, InlineAssist, MessageId, MessageMetadata, MessageStatus,
@@ -10,7 +10,7 @@ use ai::prompts::repository_context::PromptCodeSnippet;
 use ai::{
     auth::ProviderCredential,
     completion::{CompletionProvider, CompletionRequest},
-    providers::open_ai::{OpenAICompletionProvider, OpenAIRequest, RequestMessage},
+    providers::open_ai::{OpenAiCompletionProvider, OpenAiRequest, RequestMessage},
 };
 use anyhow::{anyhow, Result};
 use chrono::{DateTime, Local};
@@ -123,7 +123,7 @@ impl AssistantPanel {
             .unwrap_or_default();
         // Defaulting currently to GPT4, allow for this to be set via config.
         let completion_provider =
-            OpenAICompletionProvider::new("gpt-4".into(), cx.background_executor().clone())
+            OpenAiCompletionProvider::new("gpt-4".into(), cx.background_executor().clone())
                 .await;
 
         // TODO: deserialize state.
@@ -717,7 +717,7 @@ impl AssistantPanel {
             content: prompt,
         });
 
-        let request = Box::new(OpenAIRequest {
+        let request = Box::new(OpenAiRequest {
             model: model.full_name().into(),
             messages,
             stream: true,
@@ -1393,7 +1393,7 @@ struct Conversation {
     pending_summary: Task<Option<()>>,
     completion_count: usize,
     pending_completions: Vec<PendingCompletion>,
-    model: OpenAIModel,
+    model: OpenAiModel,
     token_count: Option<usize>,
     max_token_count: usize,
     pending_token_count: Task<Option<()>>,
@@ -1501,7 +1501,7 @@ impl Conversation {
         };
         let model = saved_conversation.model;
         let completion_provider: Arc<dyn CompletionProvider> = Arc::new(
-            OpenAICompletionProvider::new(
+            OpenAiCompletionProvider::new(
                 model.full_name().into(),
                 cx.background_executor().clone(),
             )
@@ -1626,7 +1626,7 @@ impl Conversation {
         Some(self.max_token_count as isize - self.token_count? as isize)
     }
 
-    fn set_model(&mut self, model: OpenAIModel, cx: &mut ModelContext<Self>) {
+    fn set_model(&mut self, model: OpenAiModel, cx: &mut ModelContext<Self>) {
         self.model = model;
         self.count_remaining_tokens(cx);
         cx.notify();
@@ -1679,7 +1679,7 @@ impl Conversation {
             return Default::default();
         }
 
-        let request: Box<dyn CompletionRequest> = Box::new(OpenAIRequest {
+        let request: Box<dyn CompletionRequest> = Box::new(OpenAiRequest {
             model: self.model.full_name().to_string(),
             messages: self
                 .messages(cx)
@@ -1962,7 +1962,7 @@ impl Conversation {
             content: "Summarize the conversation into a short title without punctuation"
                 .into(),
         }));
-        let request: Box<dyn CompletionRequest> = Box::new(OpenAIRequest {
+        let request: Box<dyn CompletionRequest> = Box::new(OpenAiRequest {
             model: self.model.full_name().to_string(),
             messages: messages.collect(),
             stream: true,
diff --git a/crates/assistant/src/assistant_settings.rs b/crates/assistant/src/assistant_settings.rs
index b2a9231a57..4b37a1b2f6 100644
--- a/crates/assistant/src/assistant_settings.rs
+++ b/crates/assistant/src/assistant_settings.rs
@@ -5,7 +5,7 @@ use serde::{Deserialize, Serialize};
 use settings::Settings;
 
 #[derive(Clone, Debug, Serialize, Deserialize, JsonSchema, PartialEq)]
-pub enum OpenAIModel {
+pub enum OpenAiModel {
     #[serde(rename = "gpt-3.5-turbo-0613")]
     ThreePointFiveTurbo,
     #[serde(rename = "gpt-4-0613")]
@@ -14,28 +14,28 @@ pub enum OpenAIModel {
     FourTurbo,
 }
 
-impl OpenAIModel {
+impl OpenAiModel {
     pub fn full_name(&self) -> &'static str {
         match self {
-            OpenAIModel::ThreePointFiveTurbo => "gpt-3.5-turbo-0613",
-            OpenAIModel::Four => "gpt-4-0613",
-            OpenAIModel::FourTurbo => "gpt-4-1106-preview",
+            OpenAiModel::ThreePointFiveTurbo => "gpt-3.5-turbo-0613",
+            OpenAiModel::Four => "gpt-4-0613",
+            OpenAiModel::FourTurbo => "gpt-4-1106-preview",
         }
     }
 
     pub fn short_name(&self) -> &'static str {
         match self {
-            OpenAIModel::ThreePointFiveTurbo => "gpt-3.5-turbo",
-            OpenAIModel::Four => "gpt-4",
-            OpenAIModel::FourTurbo => "gpt-4-turbo",
+            OpenAiModel::ThreePointFiveTurbo => "gpt-3.5-turbo",
+            OpenAiModel::Four => "gpt-4",
+            OpenAiModel::FourTurbo => "gpt-4-turbo",
         }
     }
 
     pub fn cycle(&self) -> Self {
         match self {
-            OpenAIModel::ThreePointFiveTurbo => OpenAIModel::Four,
-            OpenAIModel::Four => OpenAIModel::FourTurbo,
-            OpenAIModel::FourTurbo => OpenAIModel::ThreePointFiveTurbo,
+            OpenAiModel::ThreePointFiveTurbo => OpenAiModel::Four,
+            OpenAiModel::Four => OpenAiModel::FourTurbo,
+            OpenAiModel::FourTurbo => OpenAiModel::ThreePointFiveTurbo,
         }
     }
 }
@@ -54,7 +54,7 @@ pub struct AssistantSettings {
     pub dock: AssistantDockPosition,
     pub default_width: Pixels,
     pub default_height: Pixels,
-    pub default_open_ai_model: OpenAIModel,
+    pub default_open_ai_model: OpenAiModel,
 }
 
 /// Assistant panel settings
@@ -79,7 +79,7 @@ pub struct AssistantSettingsContent {
     /// The default OpenAI model to use when starting new conversations.
     ///
     /// Default: gpt-4-1106-preview
-    pub default_open_ai_model: Option<OpenAIModel>,
+    pub default_open_ai_model: Option<OpenAiModel>,
 }
 
 impl Settings for AssistantSettings {
diff --git a/crates/assistant/src/prompts.rs b/crates/assistant/src/prompts.rs
index c88e257295..c9614a4851 100644
--- a/crates/assistant/src/prompts.rs
+++ b/crates/assistant/src/prompts.rs
@@ -4,7 +4,7 @@ use ai::prompts::file_context::FileContext;
 use ai::prompts::generate::GenerateInlineContent;
 use ai::prompts::preamble::EngineerPreamble;
 use ai::prompts::repository_context::{PromptCodeSnippet, RepositoryContext};
-use ai::providers::open_ai::OpenAILanguageModel;
+use ai::providers::open_ai::OpenAiLanguageModel;
 use language::{BufferSnapshot, OffsetRangeExt, ToOffset};
 use std::cmp::{self, Reverse};
 use std::ops::Range;
@@ -131,7 +131,7 @@ pub fn generate_content_prompt(
     project_name: Option<String>,
 ) -> anyhow::Result<String> {
     // Using new Prompt Templates
-    let openai_model: Arc<dyn LanguageModel> = Arc::new(OpenAILanguageModel::load(model));
+    let openai_model: Arc<dyn LanguageModel> = Arc::new(OpenAiLanguageModel::load(model));
     let lang_name = if let Some(language_name) = language_name {
         Some(language_name.to_string())
     } else {
diff --git a/crates/semantic_index/src/semantic_index.rs b/crates/semantic_index/src/semantic_index.rs
index 62773cced8..6725b5a93e 100644
--- a/crates/semantic_index/src/semantic_index.rs
+++ b/crates/semantic_index/src/semantic_index.rs
@@ -8,7 +8,7 @@ mod semantic_index_tests;
 
 use crate::semantic_index_settings::SemanticIndexSettings;
 use ai::embedding::{Embedding, EmbeddingProvider};
-use ai::providers::open_ai::OpenAIEmbeddingProvider;
+use ai::providers::open_ai::OpenAiEmbeddingProvider;
 use anyhow::{anyhow, Context as _, Result};
 use collections::{BTreeMap, HashMap, HashSet};
 use db::VectorDatabase;
@@ -91,7 +91,7 @@ pub fn init(
 
     cx.spawn(move |cx| async move {
         let embedding_provider =
-            OpenAIEmbeddingProvider::new(http_client, cx.background_executor().clone()).await;
+            OpenAiEmbeddingProvider::new(http_client, cx.background_executor().clone()).await;
         let semantic_index = SemanticIndex::new(
             fs,
             db_file_path,
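Reviewer note (not part of the patch): the #[serde(rename = ...)] attributes pin the serialized model names, so settings and saved conversations written before this change deserialize into OpenAiModel without migration. An illustrative check; the test and its placement inside the assistant crate are assumptions:

use crate::assistant_settings::OpenAiModel;

#[test]
fn serialized_model_names_survive_the_rename() {
    // The wire format comes from #[serde(rename = ...)], not the Rust identifier.
    let model: OpenAiModel = serde_json::from_str("\"gpt-4-0613\"").unwrap();
    assert_eq!(model.full_name(), "gpt-4-0613");
    assert_eq!(model.cycle().full_name(), "gpt-4-1106-preview");
}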