diff --git a/backend/modules/brain/knowledge_brain_qa.py b/backend/modules/brain/knowledge_brain_qa.py
index 39afdb6a4..aef2fe072 100644
--- a/backend/modules/brain/knowledge_brain_qa.py
+++ b/backend/modules/brain/knowledge_brain_qa.py
@@ -181,6 +181,8 @@ class KnowledgeBrainQA(BaseModel, QAInterface):
             brain_id=brain_id,
             chat_id=chat_id,
             streaming=streaming,
+            max_input=self.max_input,
+            max_tokens=self.max_tokens,
             **kwargs,
         )
 
diff --git a/backend/modules/brain/rags/quivr_rag.py b/backend/modules/brain/rags/quivr_rag.py
index a41c080d3..f1b39ffc3 100644
--- a/backend/modules/brain/rags/quivr_rag.py
+++ b/backend/modules/brain/rags/quivr_rag.py
@@ -5,14 +5,17 @@ from uuid import UUID
 from langchain.chains import ConversationalRetrievalChain
 from langchain.embeddings.ollama import OllamaEmbeddings
 from langchain.llms.base import BaseLLM
-from langchain.memory import ConversationBufferMemory
 from langchain.prompts import HumanMessagePromptTemplate
 from langchain.schema import format_document
 from langchain_community.chat_models import ChatLiteLLM
-from langchain_core.messages import SystemMessage, get_buffer_string
+from langchain_core.messages import SystemMessage
 from langchain_core.output_parsers import StrOutputParser
-from langchain_core.prompts import ChatPromptTemplate, PromptTemplate
-from langchain_core.runnables import RunnableLambda, RunnablePassthrough
+from langchain_core.prompts import (
+    ChatPromptTemplate,
+    MessagesPlaceholder,
+    PromptTemplate,
+)
+from langchain_core.runnables import RunnablePassthrough
 from langchain_openai import OpenAIEmbeddings
 from llm.utils.get_prompt_to_use import get_prompt_to_use
 from logger import get_logger
@@ -53,6 +56,7 @@ ANSWER_PROMPT = ChatPromptTemplate.from_messages(
                 "When answering use markdown or any other techniques to display the content in a nice and aerated way. Use the following pieces of context from files provided by the user to answer the users question in the same language as the user question. Your name is Quivr. You're a helpful assistant. If you don't know the answer with the context provided from the files, just say that you don't know, don't try to make up an answer."
             )
         ),
+        MessagesPlaceholder(variable_name="chat_history", optional=False),
         HumanMessagePromptTemplate.from_template(template_answer),
     ]
 )
@@ -201,23 +205,17 @@ class QuivrRAG(BaseModel):
     def get_chain(self):
         retriever_doc = self.get_retriever()
 
-        memory = ConversationBufferMemory(
-            return_messages=True, output_key="answer", input_key="question"
-        )
-
-        loaded_memory = RunnablePassthrough.assign(
-            chat_history=RunnableLambda(memory.load_memory_variables)
-            | itemgetter("history"),
-        )
+        _inputs = RunnablePassthrough()
 
         standalone_question = {
             "standalone_question": {
-                "question": lambda x: x["question"],
-                "chat_history": lambda x: get_buffer_string(x["chat_history"]),
+                "question": itemgetter("question"),
+                "chat_history": itemgetter("chat_history"),
             }
             | CONDENSE_QUESTION_PROMPT
             | ChatLiteLLM(temperature=0, model=self.model)
             | StrOutputParser(),
+            "chat_history": itemgetter("chat_history"),
         }
 
         prompt_custom_user = self.prompt_to_use()
@@ -230,12 +228,14 @@
             "docs": itemgetter("standalone_question") | retriever_doc,
             "question": lambda x: x["standalone_question"],
             "custom_instructions": lambda x: prompt_to_use,
+            "chat_history": itemgetter("chat_history"),
         }
 
         final_inputs = {
             "context": lambda x: self._combine_documents(x["docs"]),
             "question": itemgetter("question"),
             "custom_instructions": itemgetter("custom_instructions"),
+            "chat_history": itemgetter("chat_history"),
         }
 
         # And finally, we do the part that returns the answers
@@ -246,4 +246,4 @@
             "docs": itemgetter("docs"),
         }
 
-        return loaded_memory | standalone_question | retrieved_documents | answer
+        return _inputs | standalone_question | retrieved_documents | answer
diff --git a/backend/modules/chat/dto/chats.py b/backend/modules/chat/dto/chats.py
index a3eaa0f69..f31909dda 100644
--- a/backend/modules/chat/dto/chats.py
+++ b/backend/modules/chat/dto/chats.py
@@ -4,7 +4,7 @@ from uuid import UUID
 
 from modules.chat.dto.outputs import GetChatHistoryOutput
 from modules.notification.entity.notification import Notification
-from pydantic import BaseModel, ConfigDict
+from pydantic import BaseModel
 
 
 class ChatMessage(BaseModel):
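With ConversationBufferMemory removed, the chain no longer loads conversation history itself: callers must pass chat_history in the input dict, where it feeds both the question-condensing step and the MessagesPlaceholder in ANSWER_PROMPT. A minimal invocation sketch follows; the `quivr_rag` instance and the sample messages are hypothetical, not part of this diff:

```python
# Sketch only: assumes a configured QuivrRAG instance named `quivr_rag`.
from langchain_core.messages import AIMessage, HumanMessage

chain = quivr_rag.get_chain()

result = chain.invoke(
    {
        "question": "And what does it say about termination?",
        # Prior turns are now supplied by the caller as message objects,
        # matching MessagesPlaceholder(variable_name="chat_history").
        "chat_history": [
            HumanMessage(content="Summarize the uploaded contract."),
            AIMessage(content="It is a 12-month service agreement..."),
        ],
    }
)

# The final step of the chain returns a dict with the model's answer
# and the documents retrieved for the condensed question.
print(result["answer"].content)
print(len(result["docs"]))
```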