Mirror of https://github.com/QuivrHQ/quivr.git, synced 2024-12-14 17:03:29 +03:00
feat(history): max tokens in the history provided (#2487)
# Description

Adds a `filter_history` method to `QuivrRAG` that trims the chat history passed to the chain: the most recent Human/AI message pairs are kept, bounded first by a token budget (`max_tokens`, approximated at 4 characters per token) and then by a pair budget (`max_history`). The filtered history is wired into the chain through a `RunnableLambda`.

## Checklist before requesting a review

- [ ] My code follows the style guidelines of this project
- [ ] I have performed a self-review of my code
- [ ] I have commented hard-to-understand areas
- [ ] I have ideally added tests that prove my fix is effective or that my feature works
- [ ] New and existing unit tests pass locally with my changes
- [ ] Any dependent changes have been merged
This commit is contained in:
parent f656dbcb42
commit 5c5e022990
```diff
@@ -14,7 +14,7 @@ from langchain_cohere import CohereRerank
 from langchain_community.chat_models import ChatLiteLLM
 from langchain_core.output_parsers import StrOutputParser
 from langchain_core.prompts import ChatPromptTemplate, PromptTemplate
-from langchain_core.runnables import RunnablePassthrough
+from langchain_core.runnables import RunnableLambda, RunnablePassthrough
 from langchain_openai import OpenAIEmbeddings
 from logger import get_logger
 from models import BrainSettings  # Importing settings related to the 'brain'
```
```diff
@@ -203,6 +203,40 @@ class QuivrRAG(BaseModel):
     def get_retriever(self):
         return self.vector_store.as_retriever()
 
+    def filter_history(
+        self, chat_history, max_history: int = 10, max_tokens: int = 2000
+    ):
+        """
+        Trim the chat history so that only the most recent exchanges are passed on.
+
+        Takes in a chat_history such as [HumanMessage(content='Qui est Chloé ? '), AIMessage(content="Chloé est une salariée travaillant pour l'entreprise Quivr en tant qu'AI Engineer, sous la direction de son supérieur hiérarchique, Stanislas Girard."), HumanMessage(content='Dis moi en plus sur elle'), AIMessage(content=''), HumanMessage(content='Dis moi en plus sur elle'), AIMessage(content="Désolé, je n'ai pas d'autres informations sur Chloé à partir des fichiers fournis.")]
+
+        Returns a filtered chat_history, applying the max_tokens budget first and
+        then max_history, where one Human message and one AI message count as a
+        pair. A token is approximated as 4 characters.
+        """
+        chat_history = chat_history[::-1]
+        total_tokens = 0
+        total_pairs = 0
+        filtered_chat_history = []
+        for i in range(0, len(chat_history), 2):
+            if i + 1 < len(chat_history):
+                human_message = chat_history[i]
+                ai_message = chat_history[i + 1]
+                message_tokens = (
+                    len(human_message.content) + len(ai_message.content)
+                ) // 4
+                if (
+                    total_tokens + message_tokens > max_tokens
+                    or total_pairs >= max_history
+                ):
+                    break
+                filtered_chat_history.append(human_message)
+                filtered_chat_history.append(ai_message)
+                total_tokens += message_tokens
+                total_pairs += 1
+        chat_history = filtered_chat_history[::-1]
+
+        return chat_history
+
     def get_chain(self):
         compressor = None
         if os.getenv("COHERE_API_KEY"):
```
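To make the budgeting behavior concrete, here is a minimal standalone sketch of the same logic (illustration only, not part of the commit; it assumes `langchain_core` is installed and uses toy message contents). Note that because the list is reversed before pairing, the even index actually holds the most recent AI message, not a Human one; the closing reversal restores chronological order either way, so the committed code behaves correctly despite the `human_message`/`ai_message` names being swapped.

```python
# Standalone sketch of the filter_history budgeting logic (illustration only).
from langchain_core.messages import AIMessage, HumanMessage


def filter_history(chat_history, max_history=10, max_tokens=2000):
    chat_history = chat_history[::-1]  # newest message first
    total_tokens = 0
    total_pairs = 0
    kept = []
    for i in range(0, len(chat_history), 2):
        if i + 1 < len(chat_history):
            first, second = chat_history[i], chat_history[i + 1]
            # Approximate tokens as 4 characters each, as in the committed code.
            pair_tokens = (len(first.content) + len(second.content)) // 4
            if total_tokens + pair_tokens > max_tokens or total_pairs >= max_history:
                break
            kept += [first, second]
            total_tokens += pair_tokens
            total_pairs += 1
    return kept[::-1]  # restore oldest-first order


history = [
    HumanMessage(content="Qui est Chloé ?"),
    AIMessage(content="Chloé est une AI Engineer chez Quivr."),
    HumanMessage(content="Dis moi en plus sur elle"),
    AIMessage(content="Désolé, je n'ai pas d'autres informations."),
]
# With max_history=1, only the most recent exchange survives.
print(filter_history(history, max_history=1))
```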
```diff
@@ -216,7 +250,9 @@ class QuivrRAG(BaseModel):
         )
 
         loaded_memory = RunnablePassthrough.assign(
-            chat_history=lambda x: x["chat_history"],
+            chat_history=RunnableLambda(
+                lambda x: self.filter_history(x["chat_history"]),
+            ),
             question=lambda x: x["question"],
         )
 
```
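`RunnablePassthrough.assign` computes each given key and merges the result back over the input mapping, so downstream steps see the filtered history under the same `chat_history` key. A small self-contained sketch of that behavior, where `truncate` is a hypothetical stand-in for the bound `self.filter_history` call:

```python
from langchain_core.runnables import RunnableLambda, RunnablePassthrough


def truncate(history):
    # Hypothetical stand-in: keep only the last exchange.
    return history[-2:]


loaded_memory = RunnablePassthrough.assign(
    chat_history=RunnableLambda(lambda x: truncate(x["chat_history"])),
    question=lambda x: x["question"],
)

result = loaded_memory.invoke(
    {"chat_history": ["q1", "a1", "q2", "a2"], "question": "q3"}
)
print(result)  # {'chat_history': ['q2', 'a2'], 'question': 'q3'}
```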
```diff
@@ -227,7 +263,7 @@ class QuivrRAG(BaseModel):
         standalone_question = {
             "standalone_question": {
                 "question": lambda x: x["question"],
-                "chat_history": lambda x: x["chat_history"],
+                "chat_history": itemgetter("chat_history"),
             }
             | CONDENSE_QUESTION_PROMPT
             | ChatLiteLLM(temperature=0, model=self.model, api_base=api_base)
```
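The `itemgetter("chat_history")` swap is behavior-preserving: it performs the same dict lookup as `lambda x: x["chat_history"]`, and LCEL coerces both to runnables; `itemgetter` is simply the more idiomatic (and picklable) form. A one-line check of the equivalence:

```python
from operator import itemgetter

# Both callables extract the same key from the chain's input dict.
assert itemgetter("chat_history")({"chat_history": [], "question": "hi"}) == []
```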