clean up warnings and fix tests in the ai crate

This commit is contained in:
KCaverly 2023-10-30 11:07:24 -04:00
parent a2c3971ad6
commit f3c113fe02
5 changed files with 102 additions and 143 deletions

View File

@ -13,4 +13,11 @@ pub trait CompletionProvider: CredentialProvider {
&self,
prompt: Box<dyn CompletionRequest>,
) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>>;
fn box_clone(&self) -> Box<dyn CompletionProvider>;
}
impl Clone for Box<dyn CompletionProvider> {
fn clone(&self) -> Box<dyn CompletionProvider> {
self.box_clone()
}
}

View File

@ -147,7 +147,7 @@ pub(crate) mod tests {
content = args.model.truncate(
&content,
max_token_length,
TruncationDirection::Start,
TruncationDirection::End,
)?;
token_count = max_token_length;
}
@ -172,7 +172,7 @@ pub(crate) mod tests {
content = args.model.truncate(
&content,
max_token_length,
TruncationDirection::Start,
TruncationDirection::End,
)?;
token_count = max_token_length;
}

View File

@ -193,6 +193,7 @@ pub async fn stream_completion(
}
}
#[derive(Clone)]
pub struct OpenAICompletionProvider {
model: OpenAILanguageModel,
credential: Arc<RwLock<ProviderCredential>>,
@ -271,6 +272,10 @@ impl CompletionProvider for OpenAICompletionProvider {
&self,
prompt: Box<dyn CompletionRequest>,
) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
// Currently the CompletionRequest for OpenAI, includes a 'model' parameter
// This means that the model is determined by the CompletionRequest and not the CompletionProvider,
// which is currently model based, due to the langauge model.
// At some point in the future we should rectify this.
let credential = self.credential.read().clone();
let request = stream_completion(credential, self.executor.clone(), prompt);
async move {
@ -287,4 +292,7 @@ impl CompletionProvider for OpenAICompletionProvider {
}
.boxed()
}
fn box_clone(&self) -> Box<dyn CompletionProvider> {
Box::new((*self).clone())
}
}

View File

@ -33,7 +33,10 @@ impl LanguageModel for FakeLanguageModel {
length: usize,
direction: TruncationDirection,
) -> anyhow::Result<String> {
println!("TRYING TO TRUNCATE: {:?}", length.clone());
if length > self.count_tokens(content)? {
println!("NOT TRUNCATING");
return anyhow::Ok(content.to_string());
}
@ -133,6 +136,14 @@ pub struct FakeCompletionProvider {
last_completion_tx: Mutex<Option<mpsc::Sender<String>>>,
}
impl Clone for FakeCompletionProvider {
fn clone(&self) -> Self {
Self {
last_completion_tx: Mutex::new(None),
}
}
}
impl FakeCompletionProvider {
pub fn new() -> Self {
Self {
@ -174,4 +185,7 @@ impl CompletionProvider for FakeCompletionProvider {
*self.last_completion_tx.lock() = Some(tx);
async move { Ok(rx.map(|rx| Ok(rx)).boxed()) }.boxed()
}
fn box_clone(&self) -> Box<dyn CompletionProvider> {
Box::new((*self).clone())
}
}

View File

@ -9,9 +9,7 @@ use crate::{
use ai::{
auth::ProviderCredential,
completion::{CompletionProvider, CompletionRequest},
providers::open_ai::{
stream_completion, OpenAICompletionProvider, OpenAIRequest, RequestMessage,
},
providers::open_ai::{OpenAICompletionProvider, OpenAIRequest, RequestMessage},
};
use ai::prompts::repository_context::PromptCodeSnippet;
@ -47,7 +45,7 @@ use search::BufferSearchBar;
use semantic_index::{SemanticIndex, SemanticIndexStatus};
use settings::SettingsStore;
use std::{
cell::{Cell, RefCell},
cell::Cell,
cmp,
fmt::Write,
iter,
@ -144,10 +142,8 @@ pub struct AssistantPanel {
zoomed: bool,
has_focus: bool,
toolbar: ViewHandle<Toolbar>,
credential: Rc<RefCell<ProviderCredential>>,
completion_provider: Box<dyn CompletionProvider>,
api_key_editor: Option<ViewHandle<Editor>>,
has_read_credentials: bool,
languages: Arc<LanguageRegistry>,
fs: Arc<dyn Fs>,
subscriptions: Vec<Subscription>,
@ -223,10 +219,8 @@ impl AssistantPanel {
zoomed: false,
has_focus: false,
toolbar,
credential: Rc::new(RefCell::new(ProviderCredential::NoCredentials)),
completion_provider,
api_key_editor: None,
has_read_credentials: false,
languages: workspace.app_state().languages.clone(),
fs: workspace.app_state().fs.clone(),
width: None,
@ -265,7 +259,7 @@ impl AssistantPanel {
cx: &mut ViewContext<Workspace>,
) {
let this = if let Some(this) = workspace.panel::<AssistantPanel>(cx) {
if this.update(cx, |assistant, cx| assistant.has_credentials(cx)) {
if this.update(cx, |assistant, _| assistant.has_credentials()) {
this
} else {
workspace.focus_panel::<AssistantPanel>(cx);
@ -331,6 +325,9 @@ impl AssistantPanel {
cx.background().clone(),
));
// Retrieve Credentials Authenticates the Provider
// provider.retrieve_credentials(cx);
let codegen = cx.add_model(|cx| {
Codegen::new(editor.read(cx).buffer().clone(), codegen_kind, provider, cx)
});
@ -814,7 +811,7 @@ impl AssistantPanel {
fn new_conversation(&mut self, cx: &mut ViewContext<Self>) -> ViewHandle<ConversationEditor> {
let editor = cx.add_view(|cx| {
ConversationEditor::new(
self.credential.clone(),
self.completion_provider.clone(),
self.languages.clone(),
self.fs.clone(),
self.workspace.clone(),
@ -883,9 +880,8 @@ impl AssistantPanel {
let credential = ProviderCredential::Credentials {
api_key: api_key.clone(),
};
self.completion_provider
.save_credentials(cx, credential.clone());
*self.credential.borrow_mut() = credential;
self.completion_provider.save_credentials(cx, credential);
self.api_key_editor.take();
cx.focus_self();
@ -898,7 +894,6 @@ impl AssistantPanel {
fn reset_credentials(&mut self, _: &ResetKey, cx: &mut ViewContext<Self>) {
self.completion_provider.delete_credentials(cx);
*self.credential.borrow_mut() = ProviderCredential::NoCredentials;
self.api_key_editor = Some(build_api_key_editor(cx));
cx.focus_self();
cx.notify();
@ -1157,19 +1152,12 @@ impl AssistantPanel {
let fs = self.fs.clone();
let workspace = self.workspace.clone();
let credential = self.credential.clone();
let languages = self.languages.clone();
cx.spawn(|this, mut cx| async move {
let saved_conversation = fs.load(&path).await?;
let saved_conversation = serde_json::from_str(&saved_conversation)?;
let conversation = cx.add_model(|cx| {
Conversation::deserialize(
saved_conversation,
path.clone(),
credential,
languages,
cx,
)
Conversation::deserialize(saved_conversation, path.clone(), languages, cx)
});
this.update(&mut cx, |this, cx| {
// If, by the time we've loaded the conversation, the user has already opened
@ -1193,39 +1181,12 @@ impl AssistantPanel {
.position(|editor| editor.read(cx).conversation.read(cx).path.as_deref() == Some(path))
}
fn has_credentials(&mut self, cx: &mut ViewContext<Self>) -> bool {
let credential = self.load_credentials(cx);
match credential {
ProviderCredential::Credentials { .. } => true,
ProviderCredential::NotNeeded => true,
ProviderCredential::NoCredentials => false,
}
fn has_credentials(&mut self) -> bool {
self.completion_provider.has_credentials()
}
fn load_credentials(&mut self, cx: &mut ViewContext<Self>) -> ProviderCredential {
let existing_credential = self.credential.clone();
let existing_credential = existing_credential.borrow().clone();
match existing_credential {
ProviderCredential::NoCredentials => {
if !self.has_read_credentials {
self.has_read_credentials = true;
let retrieved_credentials = self.completion_provider.retrieve_credentials(cx);
match retrieved_credentials {
ProviderCredential::NoCredentials {} => {
self.api_key_editor = Some(build_api_key_editor(cx));
cx.notify();
}
_ => {
*self.credential.borrow_mut() = retrieved_credentials;
}
}
}
}
_ => {}
}
self.credential.borrow().clone()
fn load_credentials(&mut self, cx: &mut ViewContext<Self>) {
self.completion_provider.retrieve_credentials(cx);
}
}
@ -1475,10 +1436,10 @@ struct Conversation {
token_count: Option<usize>,
max_token_count: usize,
pending_token_count: Task<Option<()>>,
credential: Rc<RefCell<ProviderCredential>>,
pending_save: Task<Result<()>>,
path: Option<PathBuf>,
_subscriptions: Vec<Subscription>,
completion_provider: Box<dyn CompletionProvider>,
}
impl Entity for Conversation {
@ -1487,10 +1448,9 @@ impl Entity for Conversation {
impl Conversation {
fn new(
credential: Rc<RefCell<ProviderCredential>>,
language_registry: Arc<LanguageRegistry>,
cx: &mut ModelContext<Self>,
completion_provider: Box<dyn CompletionProvider>,
) -> Self {
let markdown = language_registry.language_for_name("Markdown");
let buffer = cx.add_model(|cx| {
@ -1529,8 +1489,8 @@ impl Conversation {
_subscriptions: vec![cx.subscribe(&buffer, Self::handle_buffer_event)],
pending_save: Task::ready(Ok(())),
path: None,
credential,
buffer,
completion_provider,
};
let message = MessageAnchor {
id: MessageId(post_inc(&mut this.next_message_id.0)),
@ -1576,7 +1536,6 @@ impl Conversation {
fn deserialize(
saved_conversation: SavedConversation,
path: PathBuf,
credential: Rc<RefCell<ProviderCredential>>,
language_registry: Arc<LanguageRegistry>,
cx: &mut ModelContext<Self>,
) -> Self {
@ -1585,6 +1544,10 @@ impl Conversation {
None => Some(Uuid::new_v4().to_string()),
};
let model = saved_conversation.model;
let completion_provider: Box<dyn CompletionProvider> = Box::new(
OpenAICompletionProvider::new(model.full_name(), cx.background().clone()),
);
completion_provider.retrieve_credentials(cx);
let markdown = language_registry.language_for_name("Markdown");
let mut message_anchors = Vec::new();
let mut next_message_id = MessageId(0);
@ -1631,8 +1594,8 @@ impl Conversation {
_subscriptions: vec![cx.subscribe(&buffer, Self::handle_buffer_event)],
pending_save: Task::ready(Ok(())),
path: Some(path),
credential,
buffer,
completion_provider,
};
this.count_remaining_tokens(cx);
this
@ -1753,12 +1716,8 @@ impl Conversation {
}
if should_assist {
let credential = self.credential.borrow().clone();
match credential {
ProviderCredential::NoCredentials => {
return Default::default();
}
_ => {}
if !self.completion_provider.has_credentials() {
return Default::default();
}
let request: Box<dyn CompletionRequest> = Box::new(OpenAIRequest {
@ -1773,7 +1732,7 @@ impl Conversation {
temperature: 1.0,
});
let stream = stream_completion(credential, cx.background().clone(), request);
let stream = self.completion_provider.complete(request);
let assistant_message = self
.insert_message_after(last_message_id, Role::Assistant, MessageStatus::Pending, cx)
.unwrap();
@ -1791,33 +1750,28 @@ impl Conversation {
let mut messages = stream.await?;
while let Some(message) = messages.next().await {
let mut message = message?;
if let Some(choice) = message.choices.pop() {
this.upgrade(&cx)
.ok_or_else(|| anyhow!("conversation was dropped"))?
.update(&mut cx, |this, cx| {
let text: Arc<str> = choice.delta.content?.into();
let message_ix =
this.message_anchors.iter().position(|message| {
message.id == assistant_message_id
})?;
this.buffer.update(cx, |buffer, cx| {
let offset = this.message_anchors[message_ix + 1..]
.iter()
.find(|message| message.start.is_valid(buffer))
.map_or(buffer.len(), |message| {
message
.start
.to_offset(buffer)
.saturating_sub(1)
});
buffer.edit([(offset..offset, text)], None, cx);
});
cx.emit(ConversationEvent::StreamedCompletion);
let text = message?;
Some(())
this.upgrade(&cx)
.ok_or_else(|| anyhow!("conversation was dropped"))?
.update(&mut cx, |this, cx| {
let message_ix = this
.message_anchors
.iter()
.position(|message| message.id == assistant_message_id)?;
this.buffer.update(cx, |buffer, cx| {
let offset = this.message_anchors[message_ix + 1..]
.iter()
.find(|message| message.start.is_valid(buffer))
.map_or(buffer.len(), |message| {
message.start.to_offset(buffer).saturating_sub(1)
});
buffer.edit([(offset..offset, text)], None, cx);
});
}
cx.emit(ConversationEvent::StreamedCompletion);
Some(())
});
smol::future::yield_now().await;
}
@ -2039,13 +1993,8 @@ impl Conversation {
fn summarize(&mut self, cx: &mut ModelContext<Self>) {
if self.message_anchors.len() >= 2 && self.summary.is_none() {
let credential = self.credential.borrow().clone();
match credential {
ProviderCredential::NoCredentials => {
return;
}
_ => {}
if !self.completion_provider.has_credentials() {
return;
}
let messages = self
@ -2065,23 +2014,20 @@ impl Conversation {
temperature: 1.0,
});
let stream = stream_completion(credential, cx.background().clone(), request);
let stream = self.completion_provider.complete(request);
self.pending_summary = cx.spawn(|this, mut cx| {
async move {
let mut messages = stream.await?;
while let Some(message) = messages.next().await {
let mut message = message?;
if let Some(choice) = message.choices.pop() {
let text = choice.delta.content.unwrap_or_default();
this.update(&mut cx, |this, cx| {
this.summary
.get_or_insert(Default::default())
.text
.push_str(&text);
cx.emit(ConversationEvent::SummaryChanged);
});
}
let text = message?;
this.update(&mut cx, |this, cx| {
this.summary
.get_or_insert(Default::default())
.text
.push_str(&text);
cx.emit(ConversationEvent::SummaryChanged);
});
}
this.update(&mut cx, |this, cx| {
@ -2255,13 +2201,14 @@ struct ConversationEditor {
impl ConversationEditor {
fn new(
credential: Rc<RefCell<ProviderCredential>>,
completion_provider: Box<dyn CompletionProvider>,
language_registry: Arc<LanguageRegistry>,
fs: Arc<dyn Fs>,
workspace: WeakViewHandle<Workspace>,
cx: &mut ViewContext<Self>,
) -> Self {
let conversation = cx.add_model(|cx| Conversation::new(credential, language_registry, cx));
let conversation =
cx.add_model(|cx| Conversation::new(language_registry, cx, completion_provider));
Self::for_conversation(conversation, fs, workspace, cx)
}
@ -3450,6 +3397,7 @@ fn merge_ranges(ranges: &mut Vec<Range<Anchor>>, buffer: &MultiBufferSnapshot) {
mod tests {
use super::*;
use crate::MessageId;
use ai::test::FakeCompletionProvider;
use gpui::AppContext;
#[gpui::test]
@ -3457,13 +3405,9 @@ mod tests {
cx.set_global(SettingsStore::test(cx));
init(cx);
let registry = Arc::new(LanguageRegistry::test());
let conversation = cx.add_model(|cx| {
Conversation::new(
Rc::new(RefCell::new(ProviderCredential::NotNeeded)),
registry,
cx,
)
});
let completion_provider = Box::new(FakeCompletionProvider::new());
let conversation = cx.add_model(|cx| Conversation::new(registry, cx, completion_provider));
let buffer = conversation.read(cx).buffer.clone();
let message_1 = conversation.read(cx).message_anchors[0].clone();
@ -3591,13 +3535,9 @@ mod tests {
cx.set_global(SettingsStore::test(cx));
init(cx);
let registry = Arc::new(LanguageRegistry::test());
let conversation = cx.add_model(|cx| {
Conversation::new(
Rc::new(RefCell::new(ProviderCredential::NotNeeded)),
registry,
cx,
)
});
let completion_provider = Box::new(FakeCompletionProvider::new());
let conversation = cx.add_model(|cx| Conversation::new(registry, cx, completion_provider));
let buffer = conversation.read(cx).buffer.clone();
let message_1 = conversation.read(cx).message_anchors[0].clone();
@ -3693,13 +3633,8 @@ mod tests {
cx.set_global(SettingsStore::test(cx));
init(cx);
let registry = Arc::new(LanguageRegistry::test());
let conversation = cx.add_model(|cx| {
Conversation::new(
Rc::new(RefCell::new(ProviderCredential::NotNeeded)),
registry,
cx,
)
});
let completion_provider = Box::new(FakeCompletionProvider::new());
let conversation = cx.add_model(|cx| Conversation::new(registry, cx, completion_provider));
let buffer = conversation.read(cx).buffer.clone();
let message_1 = conversation.read(cx).message_anchors[0].clone();
@ -3781,13 +3716,9 @@ mod tests {
cx.set_global(SettingsStore::test(cx));
init(cx);
let registry = Arc::new(LanguageRegistry::test());
let conversation = cx.add_model(|cx| {
Conversation::new(
Rc::new(RefCell::new(ProviderCredential::NotNeeded)),
registry.clone(),
cx,
)
});
let completion_provider = Box::new(FakeCompletionProvider::new());
let conversation =
cx.add_model(|cx| Conversation::new(registry.clone(), cx, completion_provider));
let buffer = conversation.read(cx).buffer.clone();
let message_0 = conversation.read(cx).message_anchors[0].id;
let message_1 = conversation.update(cx, |conversation, cx| {
@ -3824,7 +3755,6 @@ mod tests {
Conversation::deserialize(
conversation.read(cx).serialize(cx),
Default::default(),
Rc::new(RefCell::new(ProviderCredential::NotNeeded)),
registry.clone(),
cx,
)