quivr/backend/modules/brain/integrations/Proxy/Brain.py
Stan Girard 0dd1d12e6a
Add config parameter to conversational_qa_chain (#2540)
This pull request adds a new config parameter to the
`conversational_qa_chain` function. The config parameter allows for
passing metadata, specifically the conversation ID, to the function.
This change ensures that the conversation ID is included in the metadata
when invoking the `conversational_qa_chain` function.
2024-05-04 11:05:29 -07:00

102 lines
3.2 KiB
Python

import json
from typing import AsyncIterable
from uuid import UUID
from langchain_community.chat_models import ChatLiteLLM
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from logger import get_logger
from modules.brain.knowledge_brain_qa import KnowledgeBrainQA
from modules.chat.dto.chats import ChatQuestion
from modules.chat.dto.outputs import GetChatHistoryOutput
from modules.chat.service.chat_service import ChatService
# Module-level logger obtained from the project's `get_logger` helper, keyed by
# this module's import name.
logger = get_logger(__name__)
# Module-level ChatService instance. It is not referenced by the code visible in
# this file. NOTE(review): confirm no other module imports it from here before
# considering removal.
chat_service = ChatService()
class ProxyBrain(KnowledgeBrainQA):
    """Brain that forwards the question directly to the configured chat model.

    The chain built here is only a prompt (system message, previous chat
    history, user question) piped into ``ChatLiteLLM``; this class itself adds
    no other processing step.

    It relies on members that are not defined in this class and are expected
    to come from ``KnowledgeBrainQA``: ``model``, ``max_tokens``,
    ``prompt_to_use``, ``initialize_streamed_chat_history``, ``save_answer``
    and ``save_non_streaming_answer``.
    """

    def __init__(
        self,
        **kwargs,
    ):
        """Forward every keyword argument unchanged to the parent constructor."""
        super().__init__(
            **kwargs,
        )

    def get_chain(self):
        """Build the runnable used to answer a question.

        Returns:
            The prompt template piped into a ``ChatLiteLLM`` instance created
            with ``self.model`` and ``self.max_tokens``. The prompt expects the
            input keys ``custom_personality``, ``chat_history`` and
            ``question``.
        """
        prompt = ChatPromptTemplate.from_messages(
            [
                (
                    "system",
                    "You are Quivr. You are an assistant. {custom_personality}",
                ),
                MessagesPlaceholder(variable_name="chat_history"),
                ("human", "{question}"),
            ]
        )
        llm = ChatLiteLLM(model=self.model, max_tokens=self.max_tokens)
        return prompt | llm

    def _build_chain_inputs(self, question: ChatQuestion, transformed_history) -> dict:
        """Assemble the input mapping expected by the prompt from ``get_chain``.

        Shared by the streaming and non-streaming paths so both always send
        the same variables to the chain.

        Args:
            question: Incoming chat question; only ``question.question`` is read.
            transformed_history: Value substituted for the ``chat_history``
                placeholder, as returned by
                ``initialize_streamed_chat_history``.

        Returns:
            A dict with the keys ``question``, ``chat_history`` and
            ``custom_personality``.
        """
        # Read the attribute once so it is evaluated a single time per call
        # even if it happens to be a computed property.
        custom_prompt = self.prompt_to_use
        # Use an empty string rather than None when no custom prompt is set:
        # None would be rendered into the system message as the text "None".
        personality = custom_prompt.content if custom_prompt else ""
        return {
            "question": question.question,
            "chat_history": transformed_history,
            "custom_personality": personality,
        }

    async def generate_stream(
        self, chat_id: UUID, question: ChatQuestion, save_answer: bool = True
    ) -> AsyncIterable:
        """Stream the model's answer chunk by chunk.

        Args:
            chat_id: Identifier of the chat; also passed to the chain as the
                ``conversation_id`` metadata entry.
            question: The question to answer.
            save_answer: Forwarded to ``self.save_answer`` once streaming ends.

        Yields:
            One ``"data: <json>"`` string per chunk, where the JSON is the
            serialized streamed chat history whose ``assistant`` field holds
            the content of the current chunk only (not the accumulated text).
        """
        conversational_qa_chain = self.get_chain()
        transformed_history, streamed_chat_history = (
            self.initialize_streamed_chat_history(chat_id, question)
        )
        response_tokens = []
        config = {"metadata": {"conversation_id": str(chat_id)}}

        async for chunk in conversational_qa_chain.astream(
            self._build_chain_inputs(question, transformed_history),
            config=config,
        ):
            response_tokens.append(chunk.content)
            streamed_chat_history.assistant = chunk.content
            yield f"data: {json.dumps(streamed_chat_history.dict())}"

        # Runs only after the stream has been fully consumed.
        self.save_answer(question, response_tokens, streamed_chat_history, save_answer)

    def generate_answer(
        self, chat_id: UUID, question: ChatQuestion, save_answer: bool = True
    ) -> GetChatHistoryOutput:
        """Answer the question with a single, non-streaming chain invocation.

        Args:
            chat_id: Identifier of the chat; also passed to the chain as the
                ``conversation_id`` metadata entry.
            question: The question to answer.
            save_answer: Accepted for interface compatibility.
                NOTE(review): this flag is not consulted here; the result is
                always handed to ``save_non_streaming_answer`` - confirm
                whether that is intended.

        Returns:
            Whatever ``self.save_non_streaming_answer`` returns for the
            model's answer.
        """
        conversational_qa_chain = self.get_chain()
        # Only the transformed history is needed on the non-streaming path.
        transformed_history, _ = self.initialize_streamed_chat_history(
            chat_id, question
        )
        config = {"metadata": {"conversation_id": str(chat_id)}}
        model_response = conversational_qa_chain.invoke(
            self._build_chain_inputs(question, transformed_history),
            config=config,
        )
        answer = model_response.content
        return self.save_non_streaming_answer(
            chat_id=chat_id,
            question=question,
            answer=answer,
        )