fix(gpt4): Refactor GPT4Brain and KnowledgeBrainQA classes to add non-streaming-saving-answer (#2460)

This pull request refactors the GPT4Brain and KnowledgeBrainQA classes
to add the functionality of saving non-streaming answers. It includes
changes to the `generate_answer` method and the addition of the
`save_non_streaming_answer` method. This enhancement improves the
overall functionality and performance of the code.
This commit is contained in:
Stan Girard 2024-04-21 04:09:52 -07:00 committed by GitHub
parent ad83d7a927
commit 1f48043bb9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 41 additions and 31 deletions

View File

@ -4,12 +4,14 @@ from uuid import UUID
from langchain_community.chat_models import ChatLiteLLM
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from logger import get_logger
from modules.brain.knowledge_brain_qa import KnowledgeBrainQA
from modules.chat.dto.chats import ChatQuestion
from modules.chat.dto.inputs import CreateChatHistory
from modules.chat.dto.outputs import GetChatHistoryOutput
from modules.chat.service.chat_service import ChatService
logger = get_logger(__name__)
chat_service = ChatService()
@ -92,30 +94,10 @@ class GPT4Brain(KnowledgeBrainQA):
}
)
answer = model_response["answer"].content
new_chat = chat_service.update_chat_history(
CreateChatHistory(
**{
"chat_id": chat_id,
"user_message": question.question,
"assistant": answer,
"brain_id": self.brain.brain_id,
"prompt_id": self.prompt_to_use_id,
}
)
)
answer = model_response.content
return GetChatHistoryOutput(
**{
"chat_id": chat_id,
"user_message": question.question,
"assistant": answer,
"message_time": new_chat.message_time,
"prompt_title": (
self.prompt_to_use.title if self.prompt_to_use else None
),
"brain_name": self.brain.name if self.brain else None,
"message_id": new_chat.message_id,
"brain_id": str(self.brain.brain_id) if self.brain else None,
}
return self.save_non_streaming_answer(
chat_id=chat_id,
question=question,
answer=answer,
)

View File

@ -3,17 +3,15 @@ from typing import AsyncIterable, List, Optional
from uuid import UUID
from langchain.callbacks.streaming_aiter import AsyncIteratorCallbackHandler
from modules.brain.service.utils.format_chat_history import format_chat_history
from modules.prompt.service.get_prompt_to_use import get_prompt_to_use
from modules.brain.service.utils.get_prompt_to_use_id import get_prompt_to_use_id
from logger import get_logger
from models import BrainSettings
from modules.user.service.user_usage import UserUsage
from modules.brain.entity.brain_entity import BrainEntity
from modules.brain.qa_interface import QAInterface
from modules.brain.rags.quivr_rag import QuivrRAG
from modules.brain.rags.rag_interface import RAGInterface
from modules.brain.service.brain_service import BrainService
from modules.brain.service.utils.format_chat_history import format_chat_history
from modules.brain.service.utils.get_prompt_to_use_id import get_prompt_to_use_id
from modules.chat.controller.chat.utils import (
find_model_and_generate_metadata,
update_user_usage,
@ -22,9 +20,11 @@ from modules.chat.dto.chats import ChatQuestion, Sources
from modules.chat.dto.inputs import CreateChatHistory
from modules.chat.dto.outputs import GetChatHistoryOutput
from modules.chat.service.chat_service import ChatService
from modules.prompt.service.get_prompt_to_use import get_prompt_to_use
from modules.upload.service.generate_file_signed_url import generate_file_signed_url
from modules.user.service.user_usage import UserUsage
from pydantic import BaseModel, ConfigDict
from pydantic_settings import BaseSettings
from modules.upload.service.generate_file_signed_url import generate_file_signed_url
logger = get_logger(__name__)
QUIVR_DEFAULT_PROMPT = "Your name is Quivr. You're a helpful assistant. If you don't know the answer, just say that you don't know, don't try to make up an answer."
@ -384,3 +384,31 @@ class KnowledgeBrainQA(BaseModel, QAInterface):
)
except Exception as e:
logger.error("Error updating message by ID: %s", e)
def save_non_streaming_answer(self, chat_id, question, answer):
new_chat = chat_service.update_chat_history(
CreateChatHistory(
**{
"chat_id": chat_id,
"user_message": question.question,
"assistant": answer,
"brain_id": self.brain.brain_id,
"prompt_id": self.prompt_to_use_id,
}
)
)
return GetChatHistoryOutput(
**{
"chat_id": chat_id,
"user_message": question.question,
"assistant": answer,
"message_time": new_chat.message_time,
"prompt_title": (
self.prompt_to_use.title if self.prompt_to_use else None
),
"brain_name": self.brain.name if self.brain else None,
"message_id": new_chat.message_id,
"brain_id": str(self.brain.brain_id) if self.brain else None,
}
)