# Description
Hey,
Here's a breakdown of what I've done:
- Reducing the number of open file descriptors and the memory footprint: previously, for each uploaded file we opened a temporary `NamedTemporaryFile` to write the content read from Supabase. Because of the dependency on the `langchain` loader classes, we couldn't pass in-memory buffers to the loaders. With these changes we now open a single temporary file per `process_file_and_notify`, cutting down on the excessive file opening, read syscalls, and memory buffer usage that could cause stability issues when ingesting and processing large volumes of documents (a sketch of the pattern is shown after this list). Some code paths still reopen temporary files, but that can be improved in later work.
- Removing the `UploadFile` class from `File`: `UploadFile` (a FastAPI abstraction over a `SpooledTemporaryFile` used for multipart uploads) was redundant in our `File` setup, since we already download the file from remote storage, read it into memory, and write it to a temp file. Removing this abstraction streamlines the code and eliminates unnecessary complexity.
- `async` function adjustments: I've removed the `async` labeling from functions that aren't truly asynchronous. For instance, calling `filter_file` to process files isn't genuinely async, since "async" file reading isn't actually asynchronous either; it [uses a threadpool for reading the file](9f16bf5c25/starlette/datastructures.py#L458). Given that we already leverage `celery` for parallelism (one worker per core), we want reading and processing to happen in the same thread, or at least to minimize thread spawning; a sketch of that setup is included at the end of this description. Additionally, since the rest of the code isn't inherently asynchronous, our bottleneck lies in CPU operations rather than in asynchronous processing.
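
To make the first point concrete, here is a minimal sketch of the "one temporary file per task" pattern, assuming a plain-Python setup; `process_file`, `run_langchain_loader`, and the byte payload are hypothetical placeholders, not the actual Quivr or `langchain` APIs:

```python
import tempfile
from pathlib import Path


def run_langchain_loader(path: str) -> list[str]:
    # Placeholder for whichever langchain document loader handles the file type;
    # here we just read lines so the sketch stays runnable.
    return Path(path).read_text(errors="ignore").splitlines()


def process_file(file_name: str, downloaded: bytes) -> list[str]:
    # One NamedTemporaryFile for the whole task: the loaders need a real path,
    # so the bytes downloaded from storage are written to disk exactly once.
    with tempfile.NamedTemporaryFile(suffix=Path(file_name).suffix) as tmp:
        tmp.write(downloaded)
        tmp.flush()
        return run_langchain_loader(tmp.name)


if __name__ == "__main__":
    print(process_file("notes.txt", b"hello\nworld"))
```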
These changes aim to improve performance and streamline our codebase.
Let me know if you have any questions or suggestions for further
improvements!
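
As a follow-up to the `async` point above, here is a hedged sketch of how synchronous processing inside a `celery` task keeps the download, read, and parse steps in one thread; the task name, function body, and broker URL are illustrative placeholders, not the real Quivr ones:

```python
from celery import Celery

app = Celery("worker", broker="redis://localhost:6379/0")


def filter_file(payload: bytes) -> int:
    # Plain synchronous, CPU-bound work: no `async def`, no extra thread pool.
    return len(payload.split())


@app.task
def process_file_and_notify(payload: bytes) -> int:
    # Celery already runs one worker process per core, so reading and processing
    # stay in the same thread instead of bouncing through an event loop.
    return filter_file(payload)
```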
## Checklist before requesting a review
- [x] My code follows the style guidelines of this project
- [x] I have performed a self-review of my code
- [x] I have ideally added tests that prove my fix is effective or that
my feature works
---------
Signed-off-by: aminediro <aminediro@github.com>
Co-authored-by: aminediro <aminediro@github.com>
Co-authored-by: Stan Girard <girard.stanislas@gmail.com>
399 lines · 14 KiB · Python

```python
import datetime
import os
from operator import itemgetter
from typing import List, Optional
from uuid import UUID

from langchain.chains import ConversationalRetrievalChain
from langchain.llms.base import BaseLLM
from langchain.prompts import HumanMessagePromptTemplate, SystemMessagePromptTemplate
from langchain.retrievers import ContextualCompressionRetriever
from langchain.retrievers.document_compressors import FlashrankRerank
from langchain.schema import format_document
from langchain_cohere import CohereRerank
from langchain_community.chat_models import ChatLiteLLM
from langchain_community.embeddings import OllamaEmbeddings
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate, PromptTemplate
from langchain_core.pydantic_v1 import BaseModel as BaseModelV1
from langchain_core.pydantic_v1 import Field as FieldV1
from langchain_core.runnables import RunnableLambda, RunnablePassthrough
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from logger import get_logger
from models import BrainSettings  # Importing settings related to the 'brain'
from models.settings import get_supabase_client
from modules.brain.service.brain_service import BrainService
from modules.chat.service.chat_service import ChatService
from modules.knowledge.repository.knowledges import Knowledges
from modules.prompt.service.get_prompt_to_use import get_prompt_to_use
from pydantic import BaseModel, ConfigDict
from pydantic_settings import BaseSettings
from supabase.client import Client
from vectorstore.supabase import CustomSupabaseVectorStore

logger = get_logger(__name__)


class cited_answer(BaseModelV1):
    """Answer the user question based only on the given sources, and cite the sources used."""

    thoughts: str = FieldV1(
        ...,
        description="""Description of the thought process, based only on the given sources.
        Cite the text as much as possible and give the document name it appears in. In the format : 'Doc_name states : cited_text'. Be the most
        procedural as possible. Write all the steps needed to find the answer until you find it.""",
    )
    answer: str = FieldV1(
        ...,
        description="The answer to the user question, which is based only on the given sources.",
    )
    citations: List[int] = FieldV1(
        ...,
        description="The integer IDs of the SPECIFIC sources which justify the answer.",
    )

    thoughts: str = FieldV1(
        ...,
        description="Explain shortly what you did to find the answer and what you used by citing the sources by their name.",
    )
    followup_questions: List[str] = FieldV1(
        ...,
        description="Generate up to 3 follow-up questions that could be asked based on the answer given or context provided.",
    )


# First step is to create the Rephrasing Prompt
_template = """Given the following conversation and a follow up question, rephrase the follow up question to be a standalone question, in its original language. Keep as much details as possible from previous messages. Keep entity names and all.

Chat History:
{chat_history}
Follow Up Input: {question}
Standalone question:"""
CONDENSE_QUESTION_PROMPT = PromptTemplate.from_template(_template)

# Next is the answering prompt

template_answer = """
Context:
{context}

User Question: {question}
Answer:
"""

today_date = datetime.datetime.now().strftime("%B %d, %Y")

system_message_template = (
    f"Your name is Quivr. You're a helpful assistant. Today's date is {today_date}."
)

system_message_template += """
When answering use markdown.
Use markdown code blocks for code snippets.
Answer in a concise and clear manner.
Use the following pieces of context from files provided by the user to answer the users.
Answer in the same language as the user question.
If you don't know the answer with the context provided from the files, just say that you don't know, don't try to make up an answer.
Don't cite the source id in the answer objects, but you can use the source to answer the question.
You have access to the files to answer the user question (limited to first 20 files):
{files}

If not None, User instruction to follow to answer: {custom_instructions}
Don't cite the source id in the answer objects, but you can use the source to answer the question.
"""


ANSWER_PROMPT = ChatPromptTemplate.from_messages(
    [
        SystemMessagePromptTemplate.from_template(system_message_template),
        HumanMessagePromptTemplate.from_template(template_answer),
    ]
)


# How we format documents

DEFAULT_DOCUMENT_PROMPT = PromptTemplate.from_template(
    template="Source: {index} \n {page_content}"
)


def is_valid_uuid(uuid_to_test, version=4):
    try:
        uuid_obj = UUID(uuid_to_test, version=version)
    except ValueError:
        return False

    return str(uuid_obj) == uuid_to_test


brain_service = BrainService()
chat_service = ChatService()


class QuivrRAG(BaseModel):
    """
    Quivr implementation of the RAGInterface.
    """

    model_config = ConfigDict(arbitrary_types_allowed=True)

    # Instantiate settings
    brain_settings: BaseSettings = BrainSettings()
    # Default class attributes
    model: str = None  # pyright: ignore reportPrivateUsage=none
    temperature: float = 0.1
    chat_id: str = None  # pyright: ignore reportPrivateUsage=none
    brain_id: str = None  # pyright: ignore reportPrivateUsage=none
    max_tokens: int = 2000  # Output length
    max_input: int = 2000
    streaming: bool = False
    knowledge_service: Knowledges = None

    @property
    def embeddings(self):
        if self.brain_settings.ollama_api_base_url:
            return OllamaEmbeddings(
                base_url=self.brain_settings.ollama_api_base_url
            )  # pyright: ignore reportPrivateUsage=none
        else:
            return OpenAIEmbeddings()

    def prompt_to_use(self):
        if self.brain_id and is_valid_uuid(self.brain_id):
            return get_prompt_to_use(UUID(self.brain_id), self.prompt_id)
        else:
            return None

    def model_compatible_with_function_calling(self):
        if self.model in [
            "gpt-4o",
            "gpt-4-turbo",
            "gpt-4-turbo-2024-04-09",
            "gpt-4-turbo-preview",
            "gpt-4-0125-preview",
            "gpt-4-1106-preview",
            "gpt-4",
            "gpt-4-0613",
            "gpt-3.5-turbo",
            "gpt-3.5-turbo-0125",
            "gpt-3.5-turbo-1106",
            "gpt-3.5-turbo-0613",
        ]:
            return True
        return False

    supabase_client: Optional[Client] = None
    vector_store: Optional[CustomSupabaseVectorStore] = None
    qa: Optional[ConversationalRetrievalChain] = None
    prompt_id: Optional[UUID] = None

    def __init__(
        self,
        model: str,
        brain_id: str,
        chat_id: str,
        streaming: bool = False,
        prompt_id: Optional[UUID] = None,
        max_tokens: int = 2000,
        max_input: int = 2000,
        **kwargs,
    ):
        super().__init__(
            model=model,
            brain_id=brain_id,
            chat_id=chat_id,
            streaming=streaming,
            max_tokens=max_tokens,
            max_input=max_input,
            **kwargs,
        )
        self.supabase_client = self._create_supabase_client()
        self.vector_store = self._create_vector_store()
        self.prompt_id = prompt_id
        self.max_tokens = max_tokens
        self.max_input = max_input
        self.model = model
        self.brain_id = brain_id
        self.chat_id = chat_id
        self.streaming = streaming
        self.knowledge_service = Knowledges()

    def _create_supabase_client(self) -> Client:
        return get_supabase_client()

    def _create_vector_store(self) -> CustomSupabaseVectorStore:
        return CustomSupabaseVectorStore(
            self.supabase_client,
            self.embeddings,
            table_name="vectors",
            brain_id=self.brain_id,
            max_input=self.max_input,
        )

    def _create_llm(
        self,
        callbacks,
        model,
        streaming=False,
        temperature=0,
    ) -> BaseLLM:
        """
        Create a LLM with the given parameters
        """
        if streaming and callbacks is None:
            raise ValueError(
                "Callbacks must be provided when using streaming language models"
            )

        api_base = None
        if self.brain_settings.ollama_api_base_url and model.startswith("ollama"):
            api_base = (
                self.brain_settings.ollama_api_base_url  # pyright: ignore reportPrivateUsage=none
            )

        return ChatLiteLLM(
            temperature=temperature,
            max_tokens=self.max_tokens,
            model=model,
            streaming=streaming,
            verbose=False,
            callbacks=callbacks,
            api_base=api_base,
        )  # pyright: ignore reportPrivateUsage=none

    def _combine_documents(
        self, docs, document_prompt=DEFAULT_DOCUMENT_PROMPT, document_separator="\n\n"
    ):
        # for each docs, add an index in the metadata to be able to cite the sources
        for doc, index in zip(docs, range(len(docs))):
            doc.metadata["index"] = index
        doc_strings = [format_document(doc, document_prompt) for doc in docs]
        return document_separator.join(doc_strings)

    def get_retriever(self):
        return self.vector_store.as_retriever()

    def filter_history(
        self, chat_history, max_history: int = 10, max_tokens: int = 2000
    ):
        """
        Filter out the chat history to only include the messages that are relevant to the current question

        Takes in a chat_history= [HumanMessage(content='Qui est Chloé ? '), AIMessage(content="Chloé est une salariée travaillant pour l'entreprise Quivr en tant qu'AI Engineer, sous la direction de son supérieur hiérarchique, Stanislas Girard."), HumanMessage(content='Dis moi en plus sur elle'), AIMessage(content=''), HumanMessage(content='Dis moi en plus sur elle'), AIMessage(content="Désolé, je n'ai pas d'autres informations sur Chloé à partir des fichiers fournis.")]
        Returns a filtered chat_history with in priority: first max_tokens, then max_history where a Human message and an AI message count as one pair
        a token is 4 characters
        """
        chat_history = chat_history[::-1]
        total_tokens = 0
        total_pairs = 0
        filtered_chat_history = []
        for i in range(0, len(chat_history), 2):
            if i + 1 < len(chat_history):
                human_message = chat_history[i]
                ai_message = chat_history[i + 1]
                message_tokens = (
                    len(human_message.content) + len(ai_message.content)
                ) // 4
                if (
                    total_tokens + message_tokens > max_tokens
                    or total_pairs >= max_history
                ):
                    break
                filtered_chat_history.append(human_message)
                filtered_chat_history.append(ai_message)
                total_tokens += message_tokens
                total_pairs += 1
        chat_history = filtered_chat_history[::-1]

        return chat_history

    def get_chain(self):

        list_files_array = self.knowledge_service.get_all_knowledge_in_brain(
            self.brain_id
        )  # pyright: ignore reportPrivateUsage=none

        list_files_array = [file.file_name or file.url for file in list_files_array]
        # Keep at most the first 20 files
        if len(list_files_array) > 20:
            list_files_array = list_files_array[:20]

        list_files = "\n".join(list_files_array) if list_files_array else "None"

        compressor = None
        if os.getenv("COHERE_API_KEY"):
            compressor = CohereRerank(top_n=20)
        else:
            compressor = FlashrankRerank(model="ms-marco-TinyBERT-L-2-v2", top_n=20)

        retriever_doc = self.get_retriever()
        compression_retriever = ContextualCompressionRetriever(
            base_compressor=compressor, base_retriever=retriever_doc
        )

        loaded_memory = RunnablePassthrough.assign(
            chat_history=RunnableLambda(
                lambda x: self.filter_history(x["chat_history"]),
            ),
            question=lambda x: x["question"],
        )

        api_base = None
        if self.brain_settings.ollama_api_base_url and self.model.startswith("ollama"):
            api_base = self.brain_settings.ollama_api_base_url

        standalone_question = {
            "standalone_question": {
                "question": lambda x: x["question"],
                "chat_history": itemgetter("chat_history"),
            }
            | CONDENSE_QUESTION_PROMPT
            | ChatLiteLLM(temperature=0, model=self.model, api_base=api_base)
            | StrOutputParser(),
        }

        prompt_custom_user = self.prompt_to_use()
        prompt_to_use = "None"
        if prompt_custom_user:
            prompt_to_use = prompt_custom_user.content

        # Now we retrieve the documents
        retrieved_documents = {
            "docs": itemgetter("standalone_question") | compression_retriever,
            "question": lambda x: x["standalone_question"],
            "custom_instructions": lambda x: prompt_to_use,
        }

        final_inputs = {
            "context": lambda x: self._combine_documents(x["docs"]),
            "question": itemgetter("question"),
            "custom_instructions": itemgetter("custom_instructions"),
            "files": lambda x: list_files,
        }
        llm = ChatLiteLLM(
            max_tokens=self.max_tokens,
            model=self.model,
            temperature=self.temperature,
            api_base=api_base,
        )  # pyright: ignore reportPrivateUsage=none
        if self.model_compatible_with_function_calling():

            # And finally, we do the part that returns the answers
            llm_function = ChatOpenAI(
                max_tokens=self.max_tokens,
                model=self.model,
                temperature=self.temperature,
            )
            llm = llm_function.bind_tools(
                [cited_answer],
                tool_choice="cited_answer",
            )

        answer = {
            "answer": final_inputs | ANSWER_PROMPT | llm,
            "docs": itemgetter("docs"),
        }

        return loaded_memory | standalone_question | retrieved_documents | answer
```
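
For reference, here is a hedged sketch of how the chain built by `get_chain` might be invoked. The model name, IDs, and chat history are made up, `QuivrRAG` is assumed to be importable from the module above, and a configured Supabase backend, populated vector store, and valid API keys are assumed to exist:

```python
from uuid import uuid4

from langchain_core.messages import AIMessage, HumanMessage

# Hypothetical usage of the QuivrRAG class defined above; inputs are illustrative only.
rag = QuivrRAG(
    model="gpt-3.5-turbo",
    brain_id=str(uuid4()),
    chat_id=str(uuid4()),
    max_tokens=2000,
    max_input=2000,
)

chain = rag.get_chain()
result = chain.invoke(
    {
        "question": "What does the onboarding document say about vacations?",
        "chat_history": [
            HumanMessage(content="Hi"),
            AIMessage(content="Hello! How can I help?"),
        ],
    }
)

# The chain returns a dict with the model's answer and the reranked source documents.
print(result["answer"])
print(len(result["docs"]))
```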