feat(history): max tokens in the history provided (#2487)

# Description

Adds a `filter_history` method to `QuivrRAG` that caps the chat history passed into the chain at `max_history` question/answer pairs (default 10) and an approximate `max_tokens` budget (default 2000, estimated at roughly 4 characters per token). The most recent pairs are kept and wired into the chain through a `RunnableLambda`, so downstream prompts (e.g. the condense-question step) no longer receive an unbounded history. As a rough worked example with the defaults: a question/answer pair totalling 1,000 characters counts as ~250 tokens, so only the latest pairs fitting within about 8,000 characters (and at most 10 pairs) are forwarded.

## Checklist before requesting a review

Please delete options that are not relevant.

- [ ] My code follows the style guidelines of this project
- [ ] I have performed a self-review of my code
- [ ] I have commented hard-to-understand areas
- [ ] I have ideally added tests that prove my fix is effective or that
my feature works
- [ ] New and existing unit tests pass locally with my changes
- [ ] Any dependent changes have been merged

## Screenshots (if appropriate):
Commit 5c5e022990 (parent f656dbcb42) by Stan Girard, 2024-04-24 14:09:55 -07:00, committed by GitHub.

@@ -14,7 +14,7 @@ from langchain_cohere import CohereRerank
 from langchain_community.chat_models import ChatLiteLLM
 from langchain_core.output_parsers import StrOutputParser
 from langchain_core.prompts import ChatPromptTemplate, PromptTemplate
-from langchain_core.runnables import RunnablePassthrough
+from langchain_core.runnables import RunnableLambda, RunnablePassthrough
 from langchain_openai import OpenAIEmbeddings
 from logger import get_logger
 from models import BrainSettings  # Importing settings related to the 'brain'
@@ -203,6 +203,40 @@ class QuivrRAG(BaseModel):
     def get_retriever(self):
         return self.vector_store.as_retriever()

+    def filter_history(
+        self, chat_history, max_history: int = 10, max_tokens: int = 2000
+    ):
+        """
+        Filter the chat history down to the most recent messages that fit within the limits.
+
+        Takes a chat_history such as
+        [HumanMessage(content='Qui est Chloé ? '), AIMessage(content="Chloé est une salariée travaillant pour l'entreprise Quivr en tant qu'AI Engineer, sous la direction de son supérieur hiérarchique, Stanislas Girard."), HumanMessage(content='Dis moi en plus sur elle'), AIMessage(content=''), HumanMessage(content='Dis moi en plus sur elle'), AIMessage(content="Désolé, je n'ai pas d'autres informations sur Chloé à partir des fichiers fournis.")]
+
+        Returns a filtered chat_history, applying the max_tokens budget first and then
+        max_history, where a Human message and an AI message count as one pair.
+        A token is approximated as 4 characters.
+        """
+        chat_history = chat_history[::-1]
+        total_tokens = 0
+        total_pairs = 0
+        filtered_chat_history = []
+        for i in range(0, len(chat_history), 2):
+            if i + 1 < len(chat_history):
+                human_message = chat_history[i]
+                ai_message = chat_history[i + 1]
+                message_tokens = (
+                    len(human_message.content) + len(ai_message.content)
+                ) // 4
+                if (
+                    total_tokens + message_tokens > max_tokens
+                    or total_pairs >= max_history
+                ):
+                    break
+                filtered_chat_history.append(human_message)
+                filtered_chat_history.append(ai_message)
+                total_tokens += message_tokens
+                total_pairs += 1
+        chat_history = filtered_chat_history[::-1]
+        return chat_history
+
     def get_chain(self):
         compressor = None
         if os.getenv("COHERE_API_KEY"):
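
As a sanity check, here is a minimal, self-contained sketch of how the truncation behaves on a sample history. `trim_pairs` is an illustrative standalone re-implementation of the same loop (not the method added above), and the sample messages are made up; it uses the same ~4-characters-per-token estimate.

```python
# Illustrative sketch only: a standalone version of the truncation logic above,
# assuming langchain_core message objects with a `.content` attribute.
from langchain_core.messages import AIMessage, HumanMessage


def trim_pairs(chat_history, max_history=10, max_tokens=2000):
    """Keep the most recent Human/AI pairs that fit the pair and ~token budgets."""
    newest_first = chat_history[::-1]
    kept, tokens, pairs = [], 0, 0
    for i in range(0, len(newest_first) - 1, 2):
        first, second = newest_first[i], newest_first[i + 1]
        cost = (len(first.content) + len(second.content)) // 4  # ~4 chars per token
        if tokens + cost > max_tokens or pairs >= max_history:
            break
        kept += [first, second]
        tokens += cost
        pairs += 1
    return kept[::-1]  # restore chronological order


history = [
    HumanMessage(content="Who is Chloé?"),
    AIMessage(content="Chloé is an AI Engineer at Quivr."),
    HumanMessage(content="Tell me more about her"),
    AIMessage(content="Sorry, I have no further information about Chloé."),
]

# With a tiny budget of ~20 tokens (≈80 characters), only the most recent pair fits.
print(trim_pairs(history, max_tokens=20))
```

Because the history is scanned newest-first, once either budget is exceeded every older pair is dropped.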
@@ -216,7 +250,9 @@ class QuivrRAG(BaseModel):
         )
         loaded_memory = RunnablePassthrough.assign(
-            chat_history=lambda x: x["chat_history"],
+            chat_history=RunnableLambda(
+                lambda x: self.filter_history(x["chat_history"]),
+            ),
             question=lambda x: x["question"],
         )
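
For context, a small sketch of what this wiring does, continuing the illustrative names from the sketch above (`trim_pairs` and `history` stand in for `self.filter_history` and the real chat history): `RunnablePassthrough.assign` passes the input dict through and overwrites `chat_history` with the `RunnableLambda` output.

```python
from langchain_core.runnables import RunnableLambda, RunnablePassthrough

loaded_memory = RunnablePassthrough.assign(
    # trim_pairs stands in for self.filter_history from the diff above.
    chat_history=RunnableLambda(lambda x: trim_pairs(x["chat_history"])),
    question=lambda x: x["question"],
)

result = loaded_memory.invoke({"chat_history": history, "question": "Who is Chloé?"})
# result["question"] is unchanged; result["chat_history"] now holds only the
# most recent pairs that fit the pair/token budgets.
```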
@@ -227,7 +263,7 @@ class QuivrRAG(BaseModel):
         standalone_question = {
             "standalone_question": {
                 "question": lambda x: x["question"],
-                "chat_history": lambda x: x["chat_history"],
+                "chat_history": itemgetter("chat_history"),
             }
             | CONDENSE_QUESTION_PROMPT
             | ChatLiteLLM(temperature=0, model=self.model, api_base=api_base)
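
The `itemgetter("chat_history")` change is a drop-in replacement for the previous lambda; `operator.itemgetter` is a common LCEL idiom for pulling one key out of the chain input. A quick sketch of the equivalence (the `inputs` dict is illustrative):

```python
from operator import itemgetter

get_history = itemgetter("chat_history")

inputs = {"chat_history": ["..."], "question": "Who is Chloé?"}
assert get_history(inputs) == (lambda x: x["chat_history"])(inputs)
```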