Mirror of https://github.com/QuivrHQ/quivr.git, synced 2024-12-18 20:01:52 +03:00 (commit 21e239c208)
This pull request adds the ProxyBrain integration to the project. The ProxyBrain class is responsible for handling conversational QA and generating answers based on the provided chat history and question.
98 lines · 3.0 KiB · Python
import json
from typing import AsyncIterable
from uuid import UUID

from langchain_community.chat_models import ChatLiteLLM
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder

from logger import get_logger
from modules.brain.knowledge_brain_qa import KnowledgeBrainQA
from modules.chat.dto.chats import ChatQuestion
from modules.chat.dto.outputs import GetChatHistoryOutput
from modules.chat.service.chat_service import ChatService

logger = get_logger(__name__)

chat_service = ChatService()

class ProxyBrain(KnowledgeBrainQA):
    """This is the Proxy brain class.

    Args:
        KnowledgeBrainQA (_type_): A brain that stores its knowledge internally.
    """

    def __init__(
        self,
        **kwargs,
    ):
        super().__init__(
            **kwargs,
        )
    def get_chain(self):
        # Prompt layout: system persona (with optional custom personality),
        # prior chat history, then the incoming user question.
        prompt = ChatPromptTemplate.from_messages(
            [
                (
                    "system",
                    "You are Quivr. You are an assistant. {custom_personality}",
                ),
                MessagesPlaceholder(variable_name="chat_history"),
                ("human", "{question}"),
            ]
        )

        # Pipe the prompt into the LiteLLM chat model to build the runnable chain.
        chain = prompt | ChatLiteLLM(model=self.model, max_tokens=self.max_tokens)

        return chain
    async def generate_stream(
        self, chat_id: UUID, question: ChatQuestion, save_answer: bool = True
    ) -> AsyncIterable:
        conversational_qa_chain = self.get_chain()
        transformed_history, streamed_chat_history = (
            self.initialize_streamed_chat_history(chat_id, question)
        )
        response_tokens = []

        # Stream the model output chunk by chunk and forward each chunk
        # to the client as a server-sent event.
        async for chunk in conversational_qa_chain.astream(
            {
                "question": question.question,
                "chat_history": transformed_history,
                "custom_personality": (
                    self.prompt_to_use.content if self.prompt_to_use else None
                ),
            }
        ):
            response_tokens.append(chunk.content)
            streamed_chat_history.assistant = chunk.content
            yield f"data: {json.dumps(streamed_chat_history.dict())}"

        self.save_answer(question, response_tokens, streamed_chat_history, save_answer)
    def generate_answer(
        self, chat_id: UUID, question: ChatQuestion, save_answer: bool = True
    ) -> GetChatHistoryOutput:
        conversational_qa_chain = self.get_chain()
        transformed_history, streamed_chat_history = (
            self.initialize_streamed_chat_history(chat_id, question)
        )
        model_response = conversational_qa_chain.invoke(
            {
                "question": question.question,
                "chat_history": transformed_history,
                "custom_personality": (
                    self.prompt_to_use.content if self.prompt_to_use else None
                ),
            }
        )

        answer = model_response.content

        return self.save_non_streaming_answer(
            chat_id=chat_id,
            question=question,
            answer=answer,
        )
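For orientation only, here is a minimal usage sketch (not part of this pull request) of how the two entry points above might be driven. The import path, the constructor kwargs passed through to KnowledgeBrainQA, and the ChatQuestion fields are assumptions about the surrounding codebase, not facts confirmed by this file.

    # Hypothetical usage sketch; kwargs and paths are assumptions, not part of the diff.
    import asyncio
    from uuid import uuid4

    from modules.brain.integrations.Proxy.Brain import ProxyBrain  # assumed module path
    from modules.chat.dto.chats import ChatQuestion


    async def main() -> None:
        chat_id = uuid4()
        brain = ProxyBrain(
            chat_id=str(chat_id),   # assumed kwarg consumed by KnowledgeBrainQA
            brain_id=str(uuid4()),  # assumed kwarg
            model="gpt-3.5-turbo",  # any model name LiteLLM can route
            max_tokens=256,
        )
        question = ChatQuestion(question="What can you do?")  # assumed constructor fields

        # Non-streaming path: a single invoke() under the hood, persisted via
        # save_non_streaming_answer.
        history_entry = brain.generate_answer(chat_id, question, save_answer=True)
        print(history_entry)

        # Streaming path: iterate the async generator of SSE-formatted chunks.
        async for event in brain.generate_stream(chat_id, question, save_answer=True):
            print(event)


    if __name__ == "__main__":
        asyncio.run(main())

Both paths share get_chain(), so the prompt template and model configuration live in one place; the streaming path accumulates tokens and persists them through save_answer, while the non-streaming path delegates to save_non_streaming_answer.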