diff --git a/backend/modules/brain/rags/quivr_rag.py b/backend/modules/brain/rags/quivr_rag.py
index 3a925a61d..8487bc33b 100644
--- a/backend/modules/brain/rags/quivr_rag.py
+++ b/backend/modules/brain/rags/quivr_rag.py
@@ -14,7 +14,7 @@ from langchain_cohere import CohereRerank
 from langchain_community.chat_models import ChatLiteLLM
 from langchain_core.output_parsers import StrOutputParser
 from langchain_core.prompts import ChatPromptTemplate, PromptTemplate
-from langchain_core.runnables import RunnablePassthrough
+from langchain_core.runnables import RunnableLambda, RunnablePassthrough
 from langchain_openai import OpenAIEmbeddings
 from logger import get_logger
 from models import BrainSettings  # Importing settings related to the 'brain'
@@ -203,6 +203,40 @@ class QuivrRAG(BaseModel):
     def get_retriever(self):
         return self.vector_store.as_retriever()
 
+    def filter_history(
+        self, chat_history, max_history: int = 10, max_tokens: int = 2000
+    ):
+        """
+        Filter the chat history to keep only the most recent messages, within token and pair limits.
+
+        Takes in a chat_history, e.g. [HumanMessage(content='Qui est Chloé ? '), AIMessage(content="Chloé est une salariée travaillant pour l'entreprise Quivr en tant qu'AI Engineer, sous la direction de son supérieur hiérarchique, Stanislas Girard."), HumanMessage(content='Dis moi en plus sur elle'), AIMessage(content=''), HumanMessage(content='Dis moi en plus sur elle'), AIMessage(content="Désolé, je n'ai pas d'autres informations sur Chloé à partir des fichiers fournis.")]
+        Returns a filtered chat_history, keeping in priority the most recent messages up to max_tokens, then up to max_history pairs, where a HumanMessage and an AIMessage count as one pair.
+        A token is approximated as 4 characters.
+        """
+        chat_history = chat_history[::-1]
+        total_tokens = 0
+        total_pairs = 0
+        filtered_chat_history = []
+        for i in range(0, len(chat_history), 2):
+            if i + 1 < len(chat_history):
+                human_message = chat_history[i]
+                ai_message = chat_history[i + 1]
+                message_tokens = (
+                    len(human_message.content) + len(ai_message.content)
+                ) // 4
+                if (
+                    total_tokens + message_tokens > max_tokens
+                    or total_pairs >= max_history
+                ):
+                    break
+                filtered_chat_history.append(human_message)
+                filtered_chat_history.append(ai_message)
+                total_tokens += message_tokens
+                total_pairs += 1
+        chat_history = filtered_chat_history[::-1]
+
+        return chat_history
+
     def get_chain(self):
         compressor = None
         if os.getenv("COHERE_API_KEY"):
@@ -216,7 +250,9 @@ class QuivrRAG(BaseModel):
         )
 
         loaded_memory = RunnablePassthrough.assign(
-            chat_history=lambda x: x["chat_history"],
+            chat_history=RunnableLambda(
+                lambda x: self.filter_history(x["chat_history"]),
+            ),
             question=lambda x: x["question"],
         )
 
@@ -227,7 +263,7 @@ class QuivrRAG(BaseModel):
         standalone_question = {
             "standalone_question": {
                 "question": lambda x: x["question"],
-                "chat_history": lambda x: x["chat_history"],
+                "chat_history": itemgetter("chat_history"),
             }
             | CONDENSE_QUESTION_PROMPT
             | ChatLiteLLM(temperature=0, model=self.model, api_base=api_base)
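
For reference, below is a minimal, self-contained sketch of what the new filter_history logic does to a sample history. It mirrors the pair/token counting from the diff above outside the QuivrRAG class purely for illustration; the helper name, the sample messages, and the limits passed at the end are made up for demonstration.

# Illustrative sketch only: re-implements the same pair/token counting as
# QuivrRAG.filter_history so its effect can be shown on a sample history.
# filter_history_sketch, the sample messages, and the limits are hypothetical.
from langchain_core.messages import AIMessage, HumanMessage


def filter_history_sketch(chat_history, max_history=10, max_tokens=2000):
    chat_history = chat_history[::-1]  # walk the history newest-first
    total_tokens = 0
    total_pairs = 0
    filtered = []
    for i in range(0, len(chat_history), 2):
        if i + 1 < len(chat_history):
            first, second = chat_history[i], chat_history[i + 1]
            # ~4 characters per token, counted per Human/AI pair
            message_tokens = (len(first.content) + len(second.content)) // 4
            if total_tokens + message_tokens > max_tokens or total_pairs >= max_history:
                break
            filtered.extend([first, second])
            total_tokens += message_tokens
            total_pairs += 1
    return filtered[::-1]  # restore chronological order


history = [
    HumanMessage(content="Qui est Chloé ?"),
    AIMessage(content="Chloé est AI Engineer chez Quivr."),
    HumanMessage(content="Dis moi en plus sur elle"),
    AIMessage(content="Désolé, je n'ai pas d'autres informations."),
]

# With max_history=1, only the most recent Human/AI pair survives:
# [HumanMessage('Dis moi en plus sur elle'), AIMessage('Désolé, ...')]
print(filter_history_sketch(history, max_history=1))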