mirror of
https://github.com/QuivrHQ/quivr.git
synced 2024-11-09 20:47:28 +03:00
fix(gpt4): Refactor GPT4Brain and KnowledgeBrainQA classes to add non-streaming-saving-answer (#2460)
This pull request refactors the GPT4Brain and KnowledgeBrainQA classes to add the functionality of saving non-streaming answers. It includes changes to the `generate_answer` method and the addition of the `save_non_streaming_answer` method. This enhancement improves the overall functionality and performance of the code.
This commit is contained in:
parent
ad83d7a927
commit
1f48043bb9
@ -4,12 +4,14 @@ from uuid import UUID
|
||||
|
||||
from langchain_community.chat_models import ChatLiteLLM
|
||||
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
|
||||
from logger import get_logger
|
||||
from modules.brain.knowledge_brain_qa import KnowledgeBrainQA
|
||||
from modules.chat.dto.chats import ChatQuestion
|
||||
from modules.chat.dto.inputs import CreateChatHistory
|
||||
from modules.chat.dto.outputs import GetChatHistoryOutput
|
||||
from modules.chat.service.chat_service import ChatService
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
chat_service = ChatService()
|
||||
|
||||
|
||||
@ -92,30 +94,10 @@ class GPT4Brain(KnowledgeBrainQA):
|
||||
}
|
||||
)
|
||||
|
||||
answer = model_response["answer"].content
|
||||
new_chat = chat_service.update_chat_history(
|
||||
CreateChatHistory(
|
||||
**{
|
||||
"chat_id": chat_id,
|
||||
"user_message": question.question,
|
||||
"assistant": answer,
|
||||
"brain_id": self.brain.brain_id,
|
||||
"prompt_id": self.prompt_to_use_id,
|
||||
}
|
||||
)
|
||||
)
|
||||
answer = model_response.content
|
||||
|
||||
return GetChatHistoryOutput(
|
||||
**{
|
||||
"chat_id": chat_id,
|
||||
"user_message": question.question,
|
||||
"assistant": answer,
|
||||
"message_time": new_chat.message_time,
|
||||
"prompt_title": (
|
||||
self.prompt_to_use.title if self.prompt_to_use else None
|
||||
),
|
||||
"brain_name": self.brain.name if self.brain else None,
|
||||
"message_id": new_chat.message_id,
|
||||
"brain_id": str(self.brain.brain_id) if self.brain else None,
|
||||
}
|
||||
return self.save_non_streaming_answer(
|
||||
chat_id=chat_id,
|
||||
question=question,
|
||||
answer=answer,
|
||||
)
|
||||
|
@ -3,17 +3,15 @@ from typing import AsyncIterable, List, Optional
|
||||
from uuid import UUID
|
||||
|
||||
from langchain.callbacks.streaming_aiter import AsyncIteratorCallbackHandler
|
||||
from modules.brain.service.utils.format_chat_history import format_chat_history
|
||||
from modules.prompt.service.get_prompt_to_use import get_prompt_to_use
|
||||
from modules.brain.service.utils.get_prompt_to_use_id import get_prompt_to_use_id
|
||||
from logger import get_logger
|
||||
from models import BrainSettings
|
||||
from modules.user.service.user_usage import UserUsage
|
||||
from modules.brain.entity.brain_entity import BrainEntity
|
||||
from modules.brain.qa_interface import QAInterface
|
||||
from modules.brain.rags.quivr_rag import QuivrRAG
|
||||
from modules.brain.rags.rag_interface import RAGInterface
|
||||
from modules.brain.service.brain_service import BrainService
|
||||
from modules.brain.service.utils.format_chat_history import format_chat_history
|
||||
from modules.brain.service.utils.get_prompt_to_use_id import get_prompt_to_use_id
|
||||
from modules.chat.controller.chat.utils import (
|
||||
find_model_and_generate_metadata,
|
||||
update_user_usage,
|
||||
@ -22,9 +20,11 @@ from modules.chat.dto.chats import ChatQuestion, Sources
|
||||
from modules.chat.dto.inputs import CreateChatHistory
|
||||
from modules.chat.dto.outputs import GetChatHistoryOutput
|
||||
from modules.chat.service.chat_service import ChatService
|
||||
from modules.prompt.service.get_prompt_to_use import get_prompt_to_use
|
||||
from modules.upload.service.generate_file_signed_url import generate_file_signed_url
|
||||
from modules.user.service.user_usage import UserUsage
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
from pydantic_settings import BaseSettings
|
||||
from modules.upload.service.generate_file_signed_url import generate_file_signed_url
|
||||
|
||||
logger = get_logger(__name__)
|
||||
QUIVR_DEFAULT_PROMPT = "Your name is Quivr. You're a helpful assistant. If you don't know the answer, just say that you don't know, don't try to make up an answer."
|
||||
@ -384,3 +384,31 @@ class KnowledgeBrainQA(BaseModel, QAInterface):
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error("Error updating message by ID: %s", e)
|
||||
|
||||
def save_non_streaming_answer(self, chat_id, question, answer):
|
||||
new_chat = chat_service.update_chat_history(
|
||||
CreateChatHistory(
|
||||
**{
|
||||
"chat_id": chat_id,
|
||||
"user_message": question.question,
|
||||
"assistant": answer,
|
||||
"brain_id": self.brain.brain_id,
|
||||
"prompt_id": self.prompt_to_use_id,
|
||||
}
|
||||
)
|
||||
)
|
||||
|
||||
return GetChatHistoryOutput(
|
||||
**{
|
||||
"chat_id": chat_id,
|
||||
"user_message": question.question,
|
||||
"assistant": answer,
|
||||
"message_time": new_chat.message_time,
|
||||
"prompt_title": (
|
||||
self.prompt_to_use.title if self.prompt_to_use else None
|
||||
),
|
||||
"brain_name": self.brain.name if self.brain else None,
|
||||
"message_id": new_chat.message_id,
|
||||
"brain_id": str(self.brain.brain_id) if self.brain else None,
|
||||
}
|
||||
)
|
||||
|
Loading…
Reference in New Issue
Block a user