fix: refresh local ai state when opening workspace (#5961)

* chore: fix local ai state when open other workspace

* chore: fix duplicate message
This commit is contained in:
Nathan.fooo 2024-08-14 16:58:56 +08:00 committed by GitHub
parent 6d496b2088
commit fa230907ca
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 58 additions and 42 deletions

View File

@ -66,6 +66,11 @@ impl AIManager {
}
}
pub async fn initialize(&self, _workspace_id: &str) -> Result<(), FlowyError> {
self.local_ai_controller.refresh().await?;
Ok(())
}
pub async fn open_chat(&self, chat_id: &str) -> Result<(), FlowyError> {
trace!("open chat: {}", chat_id);
self.chats.entry(chat_id.to_string()).or_insert_with(|| {

View File

@ -140,11 +140,7 @@ impl Chat {
let _ = question_sink.send(StreamMessage::Done.to_string()).await;
// Save message to disk
save_chat_message(
self.user_service.sqlite_connection(uid)?,
&self.chat_id,
vec![question.clone()],
)?;
save_and_notify_message(uid, &self.chat_id, &self.user_service, question.clone())?;
let stop_stream = self.stop_stream.clone();
let chat_id = self.chat_id.clone();
@ -222,7 +218,7 @@ impl Chat {
let answer = cloud_service
.create_answer(&workspace_id, &chat_id, &content, question_id, metadata)
.await?;
Self::save_answer(uid, &chat_id, &user_service, answer)?;
save_and_notify_message(uid, &chat_id, &user_service, answer)?;
Ok::<(), FlowyError>(())
});
@ -230,26 +226,6 @@ impl Chat {
Ok(question_pb)
}
fn save_answer(
uid: i64,
chat_id: &str,
user_service: &Arc<dyn AIUserService>,
answer: ChatMessage,
) -> Result<(), FlowyError> {
trace!("[Chat] save answer: answer={:?}", answer);
save_chat_message(
user_service.sqlite_connection(uid)?,
chat_id,
vec![answer.clone()],
)?;
let pb = ChatMessagePB::from(answer);
make_notification(chat_id, ChatNotification::DidReceiveChatMessage)
.payload(pb)
.send();
Ok(())
}
/// Load chat messages for a given `chat_id`.
///
/// 1. When opening a chat:
@ -453,7 +429,7 @@ impl Chat {
.get_answer(&workspace_id, &self.chat_id, question_message_id)
.await?;
Self::save_answer(self.uid, &self.chat_id, &self.user_service, answer.clone())?;
save_and_notify_message(self.uid, &self.chat_id, &self.user_service, answer.clone())?;
let pb = ChatMessagePB::from(answer);
Ok(pb)
}
@ -581,3 +557,23 @@ impl StringBuffer {
std::mem::take(&mut self.content)
}
}
pub(crate) fn save_and_notify_message(
uid: i64,
chat_id: &str,
user_service: &Arc<dyn AIUserService>,
message: ChatMessage,
) -> Result<(), FlowyError> {
trace!("[Chat] save answer: answer={:?}", message);
save_chat_message(
user_service.sqlite_connection(uid)?,
chat_id,
vec![message.clone()],
)?;
let pb = ChatMessagePB::from(message);
make_notification(chat_id, ChatNotification::DidReceiveChatMessage)
.payload(pb)
.send();
Ok(())
}

View File

@ -150,7 +150,7 @@ pub(crate) async fn refresh_local_ai_info_handler(
ai_manager: AFPluginState<Weak<AIManager>>,
) -> DataResult<LLMModelInfoPB, FlowyError> {
let ai_manager = upgrade_ai_manager(ai_manager)?;
let model_info = ai_manager.local_ai_controller.refresh().await;
let model_info = ai_manager.local_ai_controller.refresh_model_info().await;
if model_info.is_err() {
if let Some(llm_model) = ai_manager.local_ai_controller.get_current_model() {
let model_info = LLMModelInfo {

View File

@ -49,6 +49,7 @@ pub struct LocalAIController {
local_ai_resource: Arc<LocalAIResourceController>,
current_chat_id: Mutex<Option<String>>,
store_preferences: Arc<KVStorePreferences>,
user_service: Arc<dyn AIUserService>,
}
impl Deref for LocalAIController {
@ -74,7 +75,11 @@ impl LocalAIController {
};
let (tx, mut rx) = tokio::sync::mpsc::channel(1);
let llm_res = Arc::new(LocalAIResourceController::new(user_service, res_impl, tx));
let llm_res = Arc::new(LocalAIResourceController::new(
user_service.clone(),
res_impl,
tx,
));
let current_chat_id = Mutex::new(None);
let mut running_state_rx = local_ai.subscribe_running_state();
@ -101,6 +106,7 @@ impl LocalAIController {
local_ai_resource: llm_res,
current_chat_id,
store_preferences,
user_service,
};
let rag_enabled = this.is_rag_enabled();
@ -142,7 +148,13 @@ impl LocalAIController {
this
}
pub async fn refresh(&self) -> FlowyResult<LLMModelInfo> {
pub async fn refresh(&self) -> FlowyResult<()> {
let is_enabled = self.is_enabled();
self.enable_chat_plugin(is_enabled).await?;
Ok(())
}
pub async fn refresh_model_info(&self) -> FlowyResult<LLMModelInfo> {
self.local_ai_resource.refresh_llm_resource().await
}
@ -158,10 +170,16 @@ impl LocalAIController {
/// Indicate whether the local AI is enabled.
pub fn is_enabled(&self) -> bool {
self
.store_preferences
.get_bool(APPFLOWY_LOCAL_AI_ENABLED)
.unwrap_or(true)
if let Ok(key) = self.local_ai_enabled_key() {
self.store_preferences.get_bool(&key).unwrap_or(true)
} else {
false
}
}
fn local_ai_enabled_key(&self) -> FlowyResult<String> {
let workspace_id = self.user_service.workspace_id()?;
Ok(format!("{}:{}", APPFLOWY_LOCAL_AI_ENABLED, workspace_id))
}
/// Indicate whether the local AI chat is enabled. In the future, we can support multiple
@ -297,13 +315,9 @@ impl LocalAIController {
}
pub async fn toggle_local_ai(&self) -> FlowyResult<bool> {
let enabled = !self
.store_preferences
.get_bool(APPFLOWY_LOCAL_AI_ENABLED)
.unwrap_or(true);
self
.store_preferences
.set_bool(APPFLOWY_LOCAL_AI_ENABLED, enabled)?;
let key = self.local_ai_enabled_key()?;
let enabled = !self.store_preferences.get_bool(&key).unwrap_or(true);
self.store_preferences.set_bool(&key, enabled)?;
// when enable local ai. we need to check if chat is enabled, if enabled, we need to init chat plugin
// otherwise, we need to destroy the plugin

View File

@ -182,13 +182,14 @@ impl UserStatusCallback for UserStatusCallbackImpl {
Ok(())
}
async fn open_workspace(&self, user_id: i64, _user_workspace: &UserWorkspace) -> FlowyResult<()> {
async fn open_workspace(&self, user_id: i64, user_workspace: &UserWorkspace) -> FlowyResult<()> {
self
.folder_manager
.initialize_with_workspace_id(user_id)
.await?;
self.database_manager.initialize(user_id).await?;
self.document_manager.initialize(user_id).await?;
self.ai_manager.initialize(&user_workspace.id).await?;
Ok(())
}