quivr/core/quivr_core/quivr_rag.py
import logging
from operator import itemgetter
from typing import AsyncGenerator, Optional, Sequence

# TODO(@aminediro): this is the only dependency on the langchain package; we should remove it
from langchain.retrievers import ContextualCompressionRetriever
from langchain_core.callbacks import Callbacks
from langchain_core.documents import BaseDocumentCompressor, Document
from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.messages.ai import AIMessageChunk
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnableLambda, RunnablePassthrough
from langchain_core.vectorstores import VectorStore
from quivr_core.chat import ChatHistory
from quivr_core.config import RetrievalConfig
from quivr_core.llm import LLMEndpoint
from quivr_core.models import (
    ParsedRAGChunkResponse,
    ParsedRAGResponse,
    QuivrKnowledge,
    RAGResponseMetadata,
    cited_answer,
)
from quivr_core.prompts import custom_prompts
from quivr_core.utils import (
    combine_documents,
    format_file_list,
    get_chunk_metadata,
    parse_chunk_response,
    parse_response,
)

logger = logging.getLogger("quivr_core")


class IdempotentCompressor(BaseDocumentCompressor):
    """No-op compressor used as the default when no reranker is configured."""

    def compress_documents(
        self,
        documents: Sequence[Document],
        query: str,
        callbacks: Optional[Callbacks] = None,
    ) -> Sequence[Document]:
        return documents
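

# Hedged example (not part of the original module): a minimal sketch of a custom
# compressor that could be passed as `reranker=` to QuivrQARAG instead of the
# no-op IdempotentCompressor above. The top-k cutoff and its default value are
# assumptions for illustration only.
class _TopKCompressorSketch(BaseDocumentCompressor):
    top_k: int = 5  # hypothetical parameter

    def compress_documents(
        self,
        documents: Sequence[Document],
        query: str,
        callbacks: Optional[Callbacks] = None,
    ) -> Sequence[Document]:
        # Keep only the first `top_k` retrieved documents.
        return list(documents)[: self.top_k]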


class QuivrQARAG:
    """
    Retrieval-augmented QA pipeline for Quivr: retrieves documents from a vector
    store, optionally reranks them, and answers questions with the configured LLM.
    """

    def __init__(
        self,
        *,
        retrieval_config: RetrievalConfig,
        llm: LLMEndpoint,
        vector_store: VectorStore,
        reranker: BaseDocumentCompressor | None = None,
    ):
        self.retrieval_config = retrieval_config
        self.vector_store = vector_store
        self.llm_endpoint = llm
        self.reranker = reranker if reranker is not None else IdempotentCompressor()

    @property
    def retriever(self):
        """
        Default retriever built from the vector store.
        """
        return self.vector_store.as_retriever()
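
    # Hedged note: LangChain's VectorStore.as_retriever() accepts search
    # parameters, so a tuned retriever could look like the sketch below; the
    # value of "k" is an assumption, not something this module configures.
    #
    #     self.vector_store.as_retriever(search_kwargs={"k": 10})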

    def filter_history(
        self,
        chat_history: ChatHistory,
    ) -> list[AIMessage | HumanMessage]:
        """
        Filter the chat history down to the messages relevant to the current question.

        The history is consumed as (HumanMessage, AIMessage) pairs, e.g.
        [HumanMessage("Who is Chloé?"), AIMessage("Chloé is an AI Engineer at Quivr..."), ...].
        Pairs are kept, in priority order, while both limits hold: first the token
        budget (llm_config.max_output_tokens, counting roughly 4 characters per
        token), then max_history pairs.
        """
        total_tokens = 0
        total_pairs = 0
        filtered_chat_history: list[AIMessage | HumanMessage] = []
        for human_message, ai_message in chat_history.iter_pairs():
            # TODO: replace with tiktoken
            # Rough estimate: ~4 characters per token.
            message_tokens = (len(human_message.content) + len(ai_message.content)) // 4
            if (
                total_tokens + message_tokens
                > self.retrieval_config.llm_config.max_output_tokens
                or total_pairs >= self.retrieval_config.max_history
            ):
                break
            filtered_chat_history.append(human_message)
            filtered_chat_history.append(ai_message)
            total_tokens += message_tokens
            total_pairs += 1

        return filtered_chat_history[::-1]
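
    # Hedged sketch of the tiktoken-based counting mentioned in the TODO above
    # (hypothetical, not part of the original module); it would replace the
    # 4-characters-per-token heuristic, e.g.:
    #
    #     import tiktoken
    #
    #     enc = tiktoken.get_encoding("cl100k_base")  # encoding name is an assumption
    #     message_tokens = len(enc.encode(human_message.content)) + len(
    #         enc.encode(ai_message.content)
    #     )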

    def build_chain(self, files: str):
        """
        Build the conversational RAG chain: condense the question, retrieve and
        rerank documents, then generate an answer with citations when supported.
        """
        compression_retriever = ContextualCompressionRetriever(
            base_compressor=self.reranker, base_retriever=self.retriever
        )

        # Load the filtered chat history alongside the question.
        loaded_memory = RunnablePassthrough.assign(
            chat_history=RunnableLambda(
                lambda x: self.filter_history(x["chat_history"]),
            ),
            question=lambda x: x["question"],
        )

        # Condense the question and chat history into a standalone question.
        standalone_question = {
            "standalone_question": {
                "question": lambda x: x["question"],
                "chat_history": itemgetter("chat_history"),
            }
            | custom_prompts.CONDENSE_QUESTION_PROMPT
            | self.llm_endpoint._llm
            | StrOutputParser(),
        }

        # Now we retrieve the documents
        retrieved_documents = {
            "docs": itemgetter("standalone_question") | compression_retriever,
            "question": lambda x: x["standalone_question"],
            "custom_instructions": lambda x: self.retrieval_config.prompt,
        }

        final_inputs = {
            "context": lambda x: combine_documents(x["docs"]),
            "question": itemgetter("question"),
            "custom_instructions": itemgetter("custom_instructions"),
            "files": lambda _: files,  # TODO: shouldn't be here
        }

        # Bind the llm to cited_answer if the model supports function calling
        llm = self.llm_endpoint._llm
        if self.llm_endpoint.supports_func_calling():
            llm = self.llm_endpoint._llm.bind_tools(
                [cited_answer],
                tool_choice="any",
            )

        answer = {
            "answer": final_inputs | custom_prompts.ANSWER_PROMPT | llm,
            "docs": itemgetter("docs"),
        }

        return loaded_memory | standalone_question | retrieved_documents | answer
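
    # The assembled chain maps an input of the form
    #     {"question": str, "chat_history": ChatHistory, ...}
    # to an output of the form
    #     {"answer": <LLM message, a cited_answer tool call when supported>,
    #      "docs": list[Document]}
    # which is what answer() and answer_astream() below parse.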

    def answer(
        self,
        question: str,
        history: ChatHistory,
        list_files: list[QuivrKnowledge],
        metadata: dict[str, str] = {},
    ) -> ParsedRAGResponse:
        """
        Answers a question using the QuivrQA RAG synchronously.
        """
        concat_list_files = format_file_list(
            list_files, self.retrieval_config.max_files
        )
        conversational_qa_chain = self.build_chain(concat_list_files)
        raw_llm_response = conversational_qa_chain.invoke(
            {
                "question": question,
                "chat_history": history,
                "custom_instructions": self.retrieval_config.prompt,
            },
            config={"metadata": metadata},
        )
        response = parse_response(
            raw_llm_response, self.retrieval_config.llm_config.model
        )
        return response

    async def answer_astream(
        self,
        question: str,
        history: ChatHistory,
        list_files: list[QuivrKnowledge],
        metadata: dict[str, str] = {},
    ) -> AsyncGenerator[ParsedRAGChunkResponse, ParsedRAGChunkResponse]:
        """
        Answers a question using the QuivrQA RAG asynchronously, yielding the
        answer as a stream of ParsedRAGChunkResponse chunks.
        """
        concat_list_files = format_file_list(
            list_files, self.retrieval_config.max_files
        )
        conversational_qa_chain = self.build_chain(concat_list_files)

        rolling_message = AIMessageChunk(content="")
        sources: list[Document] = []
        prev_answer = ""
        chunk_id = 0

        async for chunk in conversational_qa_chain.astream(
            {
                "question": question,
                "chat_history": history,
                "custom_personality": self.retrieval_config.prompt,
            },
            config={"metadata": metadata},
        ):
            # The retrieved docs can arrive in any chunk, so keep them for the
            # final metadata chunk.
            if "docs" in chunk:
                sources = chunk["docs"]

            if "answer" in chunk:
                rolling_message, answer_str = parse_chunk_response(
                    rolling_message,
                    chunk,
                    self.llm_endpoint.supports_func_calling(),
                )

                if len(answer_str) > 0:
                    if self.llm_endpoint.supports_func_calling():
                        # With function calling the parsed answer is cumulative,
                        # so only yield the new suffix.
                        diff_answer = answer_str[len(prev_answer) :]
                        if len(diff_answer) > 0:
                            parsed_chunk = ParsedRAGChunkResponse(
                                answer=diff_answer,
                                metadata=RAGResponseMetadata(),
                            )
                            prev_answer += diff_answer
                            logger.debug(
                                f"answer_astream func_calling=True question={question} rolling_msg={rolling_message} chunk_id={chunk_id}, chunk={parsed_chunk}"
                            )
                            yield parsed_chunk
                    else:
                        parsed_chunk = ParsedRAGChunkResponse(
                            answer=answer_str,
                            metadata=RAGResponseMetadata(),
                        )
                        logger.debug(
                            f"answer_astream func_calling=False question={question} rolling_msg={rolling_message} chunk_id={chunk_id}, chunk={parsed_chunk}"
                        )
                        yield parsed_chunk

            chunk_id += 1

        # Last chunk provides metadata
        last_chunk = ParsedRAGChunkResponse(
            answer="",
            metadata=get_chunk_metadata(rolling_message, sources),
            last_chunk=True,
        )
        logger.debug(
            f"answer_astream last_chunk={last_chunk} question={question} rolling_msg={rolling_message} chunk_id={chunk_id}"
        )
        yield last_chunk
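

# Hedged usage sketch (not part of the original module). It assumes a
# RetrievalConfig, an LLMEndpoint, a LangChain VectorStore and a ChatHistory
# have already been constructed elsewhere; only the QuivrQARAG call signatures
# below come from this file.
#
#     rag = QuivrQARAG(
#         retrieval_config=retrieval_config,  # quivr_core.config.RetrievalConfig
#         llm=llm_endpoint,                   # quivr_core.llm.LLMEndpoint
#         vector_store=vector_store,          # any langchain_core VectorStore
#     )
#
#     # Synchronous call: returns a ParsedRAGResponse.
#     response = rag.answer("Who is Chloé?", chat_history, list_files=[])
#
#     # Streaming call: yields ParsedRAGChunkResponse chunks, the last one
#     # carrying the sources metadata.
#     async for chunk in rag.answer_astream("Who is Chloé?", chat_history, list_files=[]):
#         print(chunk.answer, end="", flush=True)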