quivr/backend/modules/brain/rags/quivr_rag.py
AmineDiro 675885c762
feat(upload): async improved (#2544)
# Description
Hey,

Here's a breakdown of what I've done:

- Reducing the number of open file descriptors and the memory footprint: Previously, for
each uploaded file we opened a temporary `NamedTemporaryFile` to write
the content read from Supabase; because of the dependency on `langchain`
loader classes, we couldn't feed the loaders from in-memory buffers.
Now we only open a single temporary file per `process_file_and_notify`
call, cutting down on the excessive file opening, read syscalls, and
memory buffer usage that could cause stability issues when ingesting and
processing large volumes of documents (see the sketch after this list).
Unfortunately, some code paths still reopen temporary files, but this
can be improved further in later work.
- Removing the `UploadFile` class from `File`: `UploadFile` (a FastAPI
abstraction over a `SpooledTemporaryFile` for multipart uploads) was
redundant in our `File` setup, since we already download the file from
remote storage, read it into memory, and write it to a temp file.
Removing this abstraction streamlines the code and eliminates
unnecessary complexity.
- `async` function adjustments: I've removed the `async` labeling from
functions that aren't truly asynchronous. For instance, calling
`filter_file` to process files isn't genuinely async, as "async" file
reading isn't actually asynchronous: it [uses a threadpool for reading
the file](9f16bf5c25/starlette/datastructures.py (L458)).
Given that we already rely on `celery` for parallelism (one worker per
core), we want reading and processing to happen in the same thread, or
at least to minimize thread spawning. Additionally, since the rest of
the code isn't inherently asynchronous, our bottleneck lies in CPU
operations rather than in asynchronous processing.
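
For illustration, here is a minimal sketch of the single-temporary-file pattern from the first point. It is an assumption-laden sketch, not the actual `process_file_and_notify` implementation: `process_file_sketch` is a hypothetical name, and `TextLoader` merely stands in for whichever path-based `langchain` loader the real code selects per file type.

```python
import os
import tempfile

from langchain_community.document_loaders import TextLoader


def process_file_sketch(file_name: str, content: bytes) -> list:
    """Hypothetical sketch: one temporary file per processing call.

    `content` is assumed to already be downloaded from remote storage
    (Supabase); TextLoader stands in for whichever langchain loader
    matches the file type in the real pipeline.
    """
    # Write the downloaded bytes once, hand the path to the loader, and let
    # the context manager delete the temp file when processing is done.
    with tempfile.NamedTemporaryFile(
        suffix=os.path.splitext(file_name)[-1]
    ) as tmp:
        tmp.write(content)
        tmp.flush()
        loader = TextLoader(tmp.name)
        return loader.load()
```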

These changes aim to improve performance and streamline our codebase. 
Let me know if you have any questions or suggestions for further
improvements!

## Checklist before requesting a review
- [x] My code follows the style guidelines of this project
- [x] I have performed a self-review of my code
- [x] I have ideally added tests that prove my fix is effective or that
my feature works

---------

Signed-off-by: aminediro <aminediro@github.com>
Co-authored-by: aminediro <aminediro@github.com>
Co-authored-by: Stan Girard <girard.stanislas@gmail.com>
2024-06-04 06:29:27 -07:00

399 lines
14 KiB
Python

import datetime
import os
from operator import itemgetter
from typing import List, Optional
from uuid import UUID
from langchain.chains import ConversationalRetrievalChain
from langchain.llms.base import BaseLLM
from langchain.prompts import HumanMessagePromptTemplate, SystemMessagePromptTemplate
from langchain.retrievers import ContextualCompressionRetriever
from langchain.retrievers.document_compressors import FlashrankRerank
from langchain.schema import format_document
from langchain_cohere import CohereRerank
from langchain_community.chat_models import ChatLiteLLM
from langchain_community.embeddings import OllamaEmbeddings
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate, PromptTemplate
from langchain_core.pydantic_v1 import BaseModel as BaseModelV1
from langchain_core.pydantic_v1 import Field as FieldV1
from langchain_core.runnables import RunnableLambda, RunnablePassthrough
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from logger import get_logger
from models import BrainSettings # Importing settings related to the 'brain'
from models.settings import get_supabase_client
from modules.brain.service.brain_service import BrainService
from modules.chat.service.chat_service import ChatService
from modules.knowledge.repository.knowledges import Knowledges
from modules.prompt.service.get_prompt_to_use import get_prompt_to_use
from pydantic import BaseModel, ConfigDict
from pydantic_settings import BaseSettings
from supabase.client import Client
from vectorstore.supabase import CustomSupabaseVectorStore
logger = get_logger(__name__)
class cited_answer(BaseModelV1):
"""Answer the user question based only on the given sources, and cite the sources used."""
thoughts: str = FieldV1(
...,
description="""Description of the thought process, based only on the given sources.
Cite the text as much as possible and give the document name it appears in. In the format: 'Doc_name states : cited_text'. Be as
procedural as possible. Write all the steps needed to find the answer until you find it.""",
)
answer: str = FieldV1(
...,
description="The answer to the user question, which is based only on the given sources.",
)
citations: List[int] = FieldV1(
...,
description="The integer IDs of the SPECIFIC sources which justify the answer.",
)
followup_questions: List[str] = FieldV1(
...,
description="Generate up to 3 follow-up questions that could be asked based on the answer given or context provided.",
)
# First step is to create the Rephrasing Prompt
_template = """Given the following conversation and a follow up question, rephrase the follow up question to be a standalone question, in its original language. Keep as much details as possible from previous messages. Keep entity names and all.
Chat History:
{chat_history}
Follow Up Input: {question}
Standalone question:"""
CONDENSE_QUESTION_PROMPT = PromptTemplate.from_template(_template)
# Next is the answering prompt
template_answer = """
Context:
{context}
User Question: {question}
Answer:
"""
today_date = datetime.datetime.now().strftime("%B %d, %Y")
system_message_template = (
f"Your name is Quivr. You're a helpful assistant. Today's date is {today_date}."
)
system_message_template += """
When answering use markdown.
Use markdown code blocks for code snippets.
Answer in a concise and clear manner.
Use the following pieces of context from files provided by the user to answer the user's question.
Answer in the same language as the user question.
If you don't know the answer with the context provided from the files, just say that you don't know, don't try to make up an answer.
Don't cite the source id in the answer objects, but you can use the source to answer the question.
You have access to the files to answer the user question (limited to first 20 files):
{files}
If not None, User instruction to follow to answer: {custom_instructions}
Don't cite the source id in the answer objects, but you can use the source to answer the question.
"""
ANSWER_PROMPT = ChatPromptTemplate.from_messages(
[
SystemMessagePromptTemplate.from_template(system_message_template),
HumanMessagePromptTemplate.from_template(template_answer),
]
)
# How we format documents
DEFAULT_DOCUMENT_PROMPT = PromptTemplate.from_template(
template="Source: {index} \n {page_content}"
)
def is_valid_uuid(uuid_to_test, version=4):
try:
uuid_obj = UUID(uuid_to_test, version=version)
except ValueError:
return False
return str(uuid_obj) == uuid_to_test
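# Module-level service singletons, instantiated once and shared across QuivrRAG instances.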
brain_service = BrainService()
chat_service = ChatService()
class QuivrRAG(BaseModel):
"""
Quivr implementation of the RAGInterface.
"""
model_config = ConfigDict(arbitrary_types_allowed=True)
# Instantiate settings
brain_settings: BaseSettings = BrainSettings()
# Default class attributes
model: str = None # pyright: ignore reportPrivateUsage=none
temperature: float = 0.1
chat_id: str = None # pyright: ignore reportPrivateUsage=none
brain_id: str = None # pyright: ignore reportPrivateUsage=none
max_tokens: int = 2000 # Output length
max_input: int = 2000
streaming: bool = False
knowledge_service: Knowledges = None
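# Use Ollama embeddings when an Ollama API base URL is configured, otherwise default to OpenAI embeddings.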
@property
def embeddings(self):
if self.brain_settings.ollama_api_base_url:
return OllamaEmbeddings(
base_url=self.brain_settings.ollama_api_base_url
) # pyright: ignore reportPrivateUsage=none
else:
return OpenAIEmbeddings()
def prompt_to_use(self):
if self.brain_id and is_valid_uuid(self.brain_id):
return get_prompt_to_use(UUID(self.brain_id), self.prompt_id)
else:
return None
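# OpenAI models known to support tool / function calling.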
def model_compatible_with_function_calling(self):
if self.model in [
"gpt-4o",
"gpt-4-turbo",
"gpt-4-turbo-2024-04-09",
"gpt-4-turbo-preview",
"gpt-4-0125-preview",
"gpt-4-1106-preview",
"gpt-4",
"gpt-4-0613",
"gpt-3.5-turbo",
"gpt-3.5-turbo-0125",
"gpt-3.5-turbo-1106",
"gpt-3.5-turbo-0613",
]:
return True
return False
supabase_client: Optional[Client] = None
vector_store: Optional[CustomSupabaseVectorStore] = None
qa: Optional[ConversationalRetrievalChain] = None
prompt_id: Optional[UUID] = None
def __init__(
self,
model: str,
brain_id: str,
chat_id: str,
streaming: bool = False,
prompt_id: Optional[UUID] = None,
max_tokens: int = 2000,
max_input: int = 2000,
**kwargs,
):
super().__init__(
model=model,
brain_id=brain_id,
chat_id=chat_id,
streaming=streaming,
max_tokens=max_tokens,
max_input=max_input,
**kwargs,
)
self.supabase_client = self._create_supabase_client()
self.vector_store = self._create_vector_store()
self.prompt_id = prompt_id
self.max_tokens = max_tokens
self.max_input = max_input
self.model = model
self.brain_id = brain_id
self.chat_id = chat_id
self.streaming = streaming
self.knowledge_service = Knowledges()
def _create_supabase_client(self) -> Client:
return get_supabase_client()
def _create_vector_store(self) -> CustomSupabaseVectorStore:
return CustomSupabaseVectorStore(
self.supabase_client,
self.embeddings,
table_name="vectors",
brain_id=self.brain_id,
max_input=self.max_input,
)
def _create_llm(
self,
callbacks,
model,
streaming=False,
temperature=0,
) -> BaseLLM:
"""
Create a LLM with the given parameters
"""
if streaming and callbacks is None:
raise ValueError(
"Callbacks must be provided when using streaming language models"
)
api_base = None
if self.brain_settings.ollama_api_base_url and model.startswith("ollama"):
api_base = (
self.brain_settings.ollama_api_base_url # pyright: ignore reportPrivateUsage=none
)
return ChatLiteLLM(
temperature=temperature,
max_tokens=self.max_tokens,
model=model,
streaming=streaming,
verbose=False,
callbacks=callbacks,
api_base=api_base,
) # pyright: ignore reportPrivateUsage=none
def _combine_documents(
self, docs, document_prompt=DEFAULT_DOCUMENT_PROMPT, document_separator="\n\n"
):
# For each doc, add an index to the metadata so the sources can be cited by id
for doc, index in zip(docs, range(len(docs))):
doc.metadata["index"] = index
doc_strings = [format_document(doc, document_prompt) for doc in docs]
return document_separator.join(doc_strings)
def get_retriever(self):
return self.vector_store.as_retriever()
def filter_history(
self, chat_history, max_history: int = 10, max_tokens: int = 2000
):
"""
Filter out the chat history to only include the messages that are relevant to the current question
Takes in a chat_history= [HumanMessage(content='Qui est Chloé ? '), AIMessage(content="Chloé est une salariée travaillant pour l'entreprise Quivr en tant qu'AI Engineer, sous la direction de son supérieur hiérarchique, Stanislas Girard."), HumanMessage(content='Dis moi en plus sur elle'), AIMessage(content=''), HumanMessage(content='Dis moi en plus sur elle'), AIMessage(content="Désolé, je n'ai pas d'autres informations sur Chloé à partir des fichiers fournis.")]
Returns a filtered chat_history bounded first by max_tokens, then by max_history, where a Human message and an AI message count as one pair.
A token is approximated as 4 characters.
"""
chat_history = chat_history[::-1]
total_tokens = 0
total_pairs = 0
filtered_chat_history = []
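# Walk the reversed history in (Human, AI) pairs, newest first, until either budget is exhausted.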
for i in range(0, len(chat_history), 2):
if i + 1 < len(chat_history):
human_message = chat_history[i]
ai_message = chat_history[i + 1]
message_tokens = (
len(human_message.content) + len(ai_message.content)
) // 4
if (
total_tokens + message_tokens > max_tokens
or total_pairs >= max_history
):
break
filtered_chat_history.append(human_message)
filtered_chat_history.append(ai_message)
total_tokens += message_tokens
total_pairs += 1
chat_history = filtered_chat_history[::-1]
return chat_history
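# Build the RAG chain (LCEL): attach filtered history, condense the question, retrieve and rerank documents, then answer with citations.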
def get_chain(self):
list_files_array = self.knowledge_service.get_all_knowledge_in_brain(
self.brain_id
) # pyright: ignore reportPrivateUsage=none
list_files_array = [file.file_name or file.url for file in list_files_array]
# Keep at most the first 20 files
if len(list_files_array) > 20:
list_files_array = list_files_array[:20]
list_files = "\n".join(list_files_array) if list_files_array else "None"
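# Rerank retrieved documents: use Cohere when an API key is configured, otherwise the local FlashRank model.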
compressor = None
if os.getenv("COHERE_API_KEY"):
compressor = CohereRerank(top_n=20)
else:
compressor = FlashrankRerank(model="ms-marco-TinyBERT-L-2-v2", top_n=20)
retriever_doc = self.get_retriever()
compression_retriever = ContextualCompressionRetriever(
base_compressor=compressor, base_retriever=retriever_doc
)
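# Pass the raw question through and attach the filtered chat history to the chain inputs.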
loaded_memory = RunnablePassthrough.assign(
chat_history=RunnableLambda(
lambda x: self.filter_history(x["chat_history"]),
),
question=lambda x: x["question"],
)
api_base = None
if self.brain_settings.ollama_api_base_url and self.model.startswith("ollama"):
api_base = self.brain_settings.ollama_api_base_url
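# Rephrase the follow-up question, given the chat history, into a standalone question.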
standalone_question = {
"standalone_question": {
"question": lambda x: x["question"],
"chat_history": itemgetter("chat_history"),
}
| CONDENSE_QUESTION_PROMPT
| ChatLiteLLM(temperature=0, model=self.model, api_base=api_base)
| StrOutputParser(),
}
prompt_custom_user = self.prompt_to_use()
prompt_to_use = "None"
if prompt_custom_user:
prompt_to_use = prompt_custom_user.content
# Now we retrieve the documents
retrieved_documents = {
"docs": itemgetter("standalone_question") | compression_retriever,
"question": lambda x: x["standalone_question"],
"custom_instructions": lambda x: prompt_to_use,
}
final_inputs = {
"context": lambda x: self._combine_documents(x["docs"]),
"question": itemgetter("question"),
"custom_instructions": itemgetter("custom_instructions"),
"files": lambda x: list_files,
}
llm = ChatLiteLLM(
max_tokens=self.max_tokens,
model=self.model,
temperature=self.temperature,
api_base=api_base,
) # pyright: ignore reportPrivateUsage=none
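# For models that support function calling, use ChatOpenAI and bind the cited_answer tool so the answer comes back with structured citations and follow-up questions.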
if self.model_compatible_with_function_calling():
# And finally, we do the part that returns the answers
llm_function = ChatOpenAI(
max_tokens=self.max_tokens,
model=self.model,
temperature=self.temperature,
)
llm = llm_function.bind_tools(
[cited_answer],
tool_choice="cited_answer",
)
answer = {
"answer": final_inputs | ANSWER_PROMPT | llm,
"docs": itemgetter("docs"),
}
return loaded_memory | standalone_question | retrieved_documents | answer