feat(reranker): Add flashrank and contextual compression retriever (#2480)

This pull request adds a FlashRank reranker and a contextual compression
retriever to the codebase. The FlashRank reranker model serves as the
document compressor (with CohereRerank used instead when a COHERE_API_KEY
is configured), and the contextual compression retriever combines the base
compressor and base retriever to improve document retrieval.
Stan Girard 2024-04-24 10:44:31 -07:00 committed by GitHub
parent 7ead787626
commit f656dbcb42
7 changed files with 501 additions and 404 deletions


@@ -34,6 +34,8 @@ TELEMETRY_ENABLED=true
 CELERY_BROKER_URL=redis://redis:6379/0
 CELEBRY_BROKER_QUEUE_NAME=quivr-preview.fifo
 QUIVR_DOMAIN=http://localhost:3000/
+#COHERE_API_KEY=CHANGE_ME
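The commented-out key is the switch for the new reranking path: as the QuivrRAG changes below show, get_chain uses CohereRerank when COHERE_API_KEY is set and falls back to the local FlashRank model otherwise.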


@@ -60,6 +60,8 @@ datasets = "*"
 pytest-dotenv = "*"
 fpdf2 = "*"
 unidecode = "*"
+flashrank = "*"
+langchain-cohere = "*"

 [dev-packages]
 black = "*"
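The two new dependencies back the two reranking paths: flashrank ships a small local cross-encoder reranker (it runs on onnxruntime, so no API key is needed), and langchain-cohere provides the CohereRerank compressor used when a Cohere key is configured.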

Pipfile.lock (generated): 786 lines changed

File diff suppressed because it is too large.


@@ -25,8 +25,8 @@ class File(BaseModel):
     vectors_ids: Optional[list] = []
     file_extension: Optional[str] = ""
     content: Optional[Any] = None
-    chunk_size: int = 250
-    chunk_overlap: int = 0
+    chunk_size: int = 800
+    chunk_overlap: int = 300
     documents: Optional[Document] = None

     @property
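The new defaults produce fewer, larger, heavily overlapping chunks, giving a reranker more surrounding context per candidate. A minimal sketch of what the values imply, assuming the fields feed a standard LangChain splitter (the splitter class here is illustrative, not taken from this diff):

```python
from langchain.text_splitter import RecursiveCharacterTextSplitter

# Hypothetical illustration of the new File defaults: 800-character chunks
# that overlap by 300 characters, so consecutive chunks share context.
text = "lorem ipsum dolor sit amet " * 400  # stand-in for File.content
splitter = RecursiveCharacterTextSplitter(chunk_size=800, chunk_overlap=300)
chunks = splitter.split_text(text)
print(len(chunks), "chunks, each at most 800 characters")
```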


@@ -1,3 +1,4 @@
+import os
 from operator import itemgetter
 from typing import Optional
 from uuid import UUID
@@ -5,20 +6,21 @@ from uuid import UUID
 from langchain.chains import ConversationalRetrievalChain
 from langchain.embeddings.ollama import OllamaEmbeddings
 from langchain.llms.base import BaseLLM
-from langchain.memory import ConversationBufferMemory
 from langchain.prompts import HumanMessagePromptTemplate, SystemMessagePromptTemplate
+from langchain.retrievers import ContextualCompressionRetriever
+from langchain.retrievers.document_compressors import FlashrankRerank
 from langchain.schema import format_document
+from langchain_cohere import CohereRerank
 from langchain_community.chat_models import ChatLiteLLM
-from langchain_core.messages import get_buffer_string
 from langchain_core.output_parsers import StrOutputParser
 from langchain_core.prompts import ChatPromptTemplate, PromptTemplate
-from langchain_core.runnables import RunnableLambda, RunnablePassthrough
+from langchain_core.runnables import RunnablePassthrough
 from langchain_openai import OpenAIEmbeddings
-from modules.prompt.service.get_prompt_to_use import get_prompt_to_use
 from logger import get_logger
 from models import BrainSettings  # Importing settings related to the 'brain'
 from modules.brain.service.brain_service import BrainService
 from modules.chat.service.chat_service import ChatService
+from modules.prompt.service.get_prompt_to_use import get_prompt_to_use
 from pydantic import BaseModel, ConfigDict
 from pydantic_settings import BaseSettings
 from supabase.client import Client, create_client
@@ -28,7 +30,7 @@ logger = get_logger(__name__)

 # First step is to create the Rephrasing Prompt
-_template = """Given the following conversation and a follow up question, rephrase the follow up question to be a standalone question, in its original language.
+_template = """Given the following conversation and a follow up question, rephrase the follow up question to be a standalone question, in its original language. Keep as much details as possible from previous messages. Keep entity names and all.

 Chat History:
 {chat_history}
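For context, the template above is compiled into the CONDENSE_QUESTION_PROMPT used later in the chain. That line sits outside this hunk, so the following is an assumption about the surrounding file, not part of the diff:

```python
from langchain_core.prompts import PromptTemplate

# Assumed from surrounding context (not shown in this diff):
CONDENSE_QUESTION_PROMPT = PromptTemplate.from_template(_template)
```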
@@ -202,14 +204,20 @@ class QuivrRAG(BaseModel):
         return self.vector_store.as_retriever()

     def get_chain(self):
+        compressor = None
+        if os.getenv("COHERE_API_KEY"):
+            compressor = CohereRerank(top_n=5)
+        else:
+            compressor = FlashrankRerank(model="ms-marco-TinyBERT-L-2-v2", top_n=5)
+
         retriever_doc = self.get_retriever()
-        memory = ConversationBufferMemory(
-            return_messages=True, output_key="answer", input_key="question"
+        compression_retriever = ContextualCompressionRetriever(
+            base_compressor=compressor, base_retriever=retriever_doc
         )

         loaded_memory = RunnablePassthrough.assign(
-            chat_history=RunnableLambda(memory.load_memory_variables)
-            | itemgetter("history"),
+            chat_history=lambda x: x["chat_history"],
+            question=lambda x: x["question"],
         )

         api_base = None
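Taken together, the added lines wire reranking in as a document compressor: the base retriever (the vector store's, per get_retriever above) fetches candidates, and the compressor reorders them and keeps the top 5. A self-contained sketch of the same pattern, with the model name and top_n copied from the diff and everything else illustrative:

```python
import os

from langchain.retrievers import ContextualCompressionRetriever
from langchain.retrievers.document_compressors import FlashrankRerank
from langchain_cohere import CohereRerank


def build_compression_retriever(base_retriever):
    # Prefer the hosted Cohere reranker when a key is configured;
    # otherwise fall back to the local FlashRank cross-encoder.
    if os.getenv("COHERE_API_KEY"):
        compressor = CohereRerank(top_n=5)
    else:
        compressor = FlashrankRerank(model="ms-marco-TinyBERT-L-2-v2", top_n=5)
    return ContextualCompressionRetriever(
        base_compressor=compressor, base_retriever=base_retriever
    )


# Usage (illustrative): wrap any LangChain retriever.
# retriever = build_compression_retriever(vector_store.as_retriever())
# top_docs = retriever.invoke("user question")  # at most 5 reranked documents
```

Reranking trades a little latency for precision: the vector store can over-fetch, and the cross-encoder keeps only the five most relevant documents.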
@@ -219,7 +227,7 @@ class QuivrRAG(BaseModel):
         standalone_question = {
             "standalone_question": {
                 "question": lambda x: x["question"],
-                "chat_history": lambda x: get_buffer_string(x["chat_history"]),
+                "chat_history": lambda x: x["chat_history"],
             }
             | CONDENSE_QUESTION_PROMPT
             | ChatLiteLLM(temperature=0, model=self.model, api_base=api_base)
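With get_buffer_string dropped, chat_history reaches the condense prompt exactly as the caller supplies it; the chain no longer flattens message objects into a transcript string itself.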
@@ -233,7 +241,7 @@ class QuivrRAG(BaseModel):

         # Now we retrieve the documents
         retrieved_documents = {
-            "docs": itemgetter("standalone_question") | retriever_doc,
+            "docs": itemgetter("standalone_question") | compression_retriever,
             "question": lambda x: x["standalone_question"],
             "custom_instructions": lambda x: prompt_to_use,
         }
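With ConversationBufferMemory removed, callers of get_chain now pass question and chat_history in directly, and the docs step yields the compression retriever's reranked top 5 rather than the raw vector-store hits.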


@@ -17,8 +17,8 @@ backoff==2.2.1; python_version >= '3.7' and python_version < '4.0'
 beautifulsoup4==4.12.3; python_full_version >= '3.6.0'
 billiard==4.2.0; python_version >= '3.7'
 black==24.4.0; python_version >= '3.8'
-boto3==1.34.86; python_version >= '3.8'
-botocore==1.34.86; python_version >= '3.8'
+boto3==1.34.90; python_version >= '3.8'
+botocore==1.34.90; python_version >= '3.8'
 celery[redis,sqs]==5.4.0; python_version >= '3.8'
 certifi==2024.2.2; python_version >= '3.6'
 cffi==1.16.0; platform_python_implementation != 'PyPy'
@@ -28,6 +28,7 @@ click==8.1.7; python_version >= '3.7'
 click-didyoumean==0.3.1; python_full_version >= '3.6.2'
 click-plugins==1.1.1
 click-repl==0.3.0; python_version >= '3.6'
+cohere==5.3.3; python_version >= '3.8' and python_version < '4.0'
 coloredlogs==15.0.1; python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2, 3.3, 3.4'
 colorlog==6.8.2; python_version >= '3.6'
 contourpy==1.2.1; python_version >= '3.9'
@@ -35,10 +36,10 @@ cryptography==42.0.5; python_version >= '3.7'
 cssselect==1.2.0; python_version >= '3.7'
 cycler==0.12.1; python_version >= '3.8'
 dataclasses-json==0.6.4; python_version >= '3.7' and python_version < '4.0'
-dataclasses-json-speakeasy==0.5.11; python_version >= '3.7' and python_version < '4.0'
-datasets==2.18.0; python_full_version >= '3.8.0'
+datasets==2.19.0; python_full_version >= '3.8.0'
 debugpy==1.8.1; python_version >= '3.8'
 decorator==5.1.1; python_version >= '3.5'
+deepdiff==7.0.1; python_version >= '3.8'
 defusedxml==0.7.1; python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2, 3.3, 3.4'
 deprecated==1.2.14; python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2, 3.3'
 deprecation==2.1.0
@@ -50,22 +51,24 @@ docx2txt==0.8
 duckdb==0.10.2; python_full_version >= '3.7.0'
 ecdsa==0.19.0; python_version >= '2.6' and python_version not in '3.0, 3.1, 3.2, 3.3, 3.4'
 effdet==0.4.1
-emoji==2.11.0; python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2, 3.3'
+emoji==2.11.1; python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2, 3.3'
 et-xmlfile==1.1.0; python_version >= '3.6'
 faker==19.13.0; python_version >= '3.8'
-fastapi==0.110.1; python_version >= '3.8'
+fastapi==0.110.2; python_version >= '3.8'
 fastavro==1.9.4; python_version >= '3.8'
 feedfinder2==0.0.4
 feedparser==6.0.11; python_version >= '3.6'
 filelock==3.13.4; python_version >= '3.8'
 filetype==1.2.0
 flake8==7.0.0; python_full_version >= '3.8.1'
 flake8-black==0.3.6; python_version >= '3.7'
+flashrank==0.2.0; python_version >= '3.6'
 flatbuffers==24.3.25
 flower==2.0.1; python_version >= '3.7'
 fonttools==4.51.0; python_version >= '3.8'
 fpdf2==2.7.8; python_version >= '3.7'
 frozenlist==1.4.1; python_version >= '3.8'
-fsspec[http]==2024.2.0; python_version >= '3.8'
+fsspec[http]==2024.3.1; python_version >= '3.8'
 gitdb==4.0.11; python_version >= '3.7'
 gitpython==3.1.43; python_version >= '3.7'
 gotrue==2.4.2; python_version >= '3.8' and python_version < '4.0'
@@ -74,6 +77,7 @@ h11==0.14.0; python_version >= '3.7'
 html5lib==1.1; python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2, 3.3, 3.4'
 httpcore==1.0.5; python_version >= '3.8'
 httpx==0.27.0; python_version >= '3.8'
+httpx-sse==0.4.0; python_version >= '3.8'
 huggingface-hub==0.22.2; python_full_version >= '3.8.0'
 humanfriendly==10.0; python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2, 3.3, 3.4'
 humanize==4.9.0; python_version >= '3.8'
@@ -92,25 +96,26 @@ jsonpointer==2.4; python_version >= '2.7' and python_version not in '3.0, 3.1, 3
 kiwisolver==1.4.5; python_version >= '3.7'
 kombu[sqs]==5.3.7; python_version >= '3.8'
 langchain==0.1.16; python_version < '4.0' and python_full_version >= '3.8.1'
-langchain-community==0.0.33; python_version < '4.0' and python_full_version >= '3.8.1'
-langchain-core==0.1.44; python_version < '4.0' and python_full_version >= '3.8.1'
+langchain-cohere==0.1.3; python_version < '4.0' and python_full_version >= '3.8.1'
+langchain-community==0.0.34; python_version < '4.0' and python_full_version >= '3.8.1'
+langchain-core==0.1.45; python_version < '4.0' and python_full_version >= '3.8.1'
 langchain-openai==0.1.3; python_version < '4.0' and python_full_version >= '3.8.1'
 langchain-text-splitters==0.0.1; python_version < '4.0' and python_full_version >= '3.8.1'
 langdetect==1.0.9
-langfuse==2.26.3; python_version < '4.0' and python_full_version >= '3.8.1'
-langsmith==0.1.48; python_version < '4.0' and python_full_version >= '3.8.1'
+langfuse==2.27.1; python_version < '4.0' and python_full_version >= '3.8.1'
+langsmith==0.1.50; python_version < '4.0' and python_full_version >= '3.8.1'
 layoutparser[layoutmodels,tesseract]==0.3.4; python_version >= '3.6'
-litellm==1.35.10; python_version not in '2.7, 3.0, 3.1, 3.2, 3.3, 3.4, 3.5, 3.6, 3.7' and python_version >= '3.8'
-llama-index==0.10.29; python_version < '4.0' and python_full_version >= '3.8.1'
-llama-index-agent-openai==0.2.2; python_version < '4.0' and python_full_version >= '3.8.1'
-llama-index-cli==0.1.11; python_version < '4.0' and python_full_version >= '3.8.1'
-llama-index-core==0.10.29; python_version < '4.0' and python_full_version >= '3.8.1'
-llama-index-embeddings-openai==0.1.7; python_version < '4.0' and python_full_version >= '3.8.1'
+litellm==1.35.21; python_version not in '2.7, 3.0, 3.1, 3.2, 3.3, 3.4, 3.5, 3.6, 3.7' and python_version >= '3.8'
+llama-index==0.10.31; python_version < '4.0' and python_full_version >= '3.8.1'
+llama-index-agent-openai==0.2.3; python_version < '4.0' and python_full_version >= '3.8.1'
+llama-index-cli==0.1.12; python_version < '4.0' and python_full_version >= '3.8.1'
+llama-index-core==0.10.31; python_version < '4.0' and python_full_version >= '3.8.1'
+llama-index-embeddings-openai==0.1.8; python_version < '4.0' and python_full_version >= '3.8.1'
 llama-index-indices-managed-llama-cloud==0.1.5; python_version < '4.0' and python_full_version >= '3.8.1'
 llama-index-legacy==0.9.48; python_version < '4.0' and python_full_version >= '3.8.1'
-llama-index-llms-openai==0.1.15; python_version < '4.0' and python_full_version >= '3.8.1'
+llama-index-llms-openai==0.1.16; python_version < '4.0' and python_full_version >= '3.8.1'
 llama-index-multi-modal-llms-openai==0.1.5; python_version < '4.0' and python_full_version >= '3.8.1'
-llama-index-program-openai==0.1.5; python_version < '4.0' and python_full_version >= '3.8.1'
+llama-index-program-openai==0.1.6; python_version < '4.0' and python_full_version >= '3.8.1'
 llama-index-question-gen-openai==0.1.3; python_version < '4.0' and python_full_version >= '3.8.1'
 llama-index-readers-file==0.1.19; python_version < '4.0' and python_full_version >= '3.8.1'
 llama-index-readers-llama-parse==0.1.4; python_version < '4.0' and python_full_version >= '3.8.1'
@@ -138,14 +143,15 @@ numpy==1.26.4; python_version >= '3.9'
 olefile==0.47; python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2, 3.3, 3.4'
 omegaconf==2.3.0; python_version >= '3.6'
 onnx==1.16.0
-onnxruntime==1.15.1
-openai==1.21.1; python_full_version >= '3.7.1'
+onnxruntime==1.17.3
+openai==1.23.3; python_full_version >= '3.7.1'
 opencv-python==4.9.0.80; python_version >= '3.6'
 openpyxl==3.1.2
+ordered-set==4.1.0; python_version >= '3.7'
 orjson==3.10.1; python_version >= '3.8'
 packaging==23.2; python_version >= '3.7'
 pandas==1.5.3; python_version >= '3.8'
-pandasai==2.0.33; python_version not in '2.7, 3.0, 3.1, 3.2, 3.3, 3.4, 3.5, 3.6, 3.7, 3.8' and python_version >= '3.9'
+pandasai==2.0.35; python_version not in '2.7, 3.0, 3.1, 3.2, 3.3, 3.4, 3.5, 3.6, 3.7, 3.8' and python_version >= '3.9'
 pathspec==0.12.1; python_version >= '3.8'
 pdf2image==1.17.0
 pdfminer.six==20231228
@@ -153,8 +159,8 @@ pdfplumber==0.11.0; python_version >= '3.8'
 pikepdf==8.15.1
 pillow==10.3.0; python_version >= '3.8'
 pillow-heif==0.16.0
-platformdirs==4.2.0; python_version >= '3.8'
-pluggy==1.4.0; python_version >= '3.8'
+platformdirs==4.2.1; python_version >= '3.8'
+pluggy==1.5.0; python_version >= '3.8'
 portalocker==2.8.2; python_version >= '3.8'
 postgrest==0.16.3; python_version >= '3.8' and python_version < '4.0'
 posthog==3.5.0
@@ -165,22 +171,22 @@ psutil==5.9.8; python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2,
 psycopg2==2.9.9; python_version >= '3.7'
 psycopg2-binary==2.9.9; python_version >= '3.7'
 py==1.11.0; python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2, 3.3, 3.4'
-pyarrow==15.0.2; python_version >= '3.8'
+pyarrow==16.0.0; python_version >= '3.8'
 pyarrow-hotfix==0.6; python_version >= '3.5'
 pyasn1==0.6.0; python_version >= '3.8'
 pycocotools==2.0.7; python_version >= '3.5'
 pycodestyle==2.11.1; python_version >= '3.8'
 pycparser==2.22; python_version >= '3.8'
 pycurl==7.45.3
-pydantic==2.7.0; python_version >= '3.8'
-pydantic-core==2.18.1; python_version >= '3.8'
+pydantic==2.7.1; python_version >= '3.8'
+pydantic-core==2.18.2; python_version >= '3.8'
 pydantic-settings==2.2.1; python_version >= '3.8'
 pyflakes==3.2.0; python_version >= '3.8'
 pypandoc==1.13; python_version >= '3.6'
 pyparsing==3.1.2; python_full_version >= '3.6.8'
 pypdf==4.2.0; python_version >= '3.6'
 pypdfium2==4.29.0; python_version >= '3.6'
-pyright==1.1.359; python_version >= '3.7'
+pyright==1.1.360; python_version >= '3.7'
 pysbd==0.3.4; python_version >= '3'
 pytesseract==0.3.10; python_version >= '3.7'
 pytest==8.1.1; python_version >= '3.8'
@@ -201,7 +207,7 @@ pyyaml==6.0.1; python_version >= '3.6'
 ragas==0.1.7
 rapidfuzz==3.8.1; python_version >= '3.8'
 realtime==1.0.4; python_version >= '3.8' and python_version < '4.0'
-redis==5.0.3; python_version >= '3.7'
+redis==5.0.4; python_version >= '3.7'
 regex==2024.4.16; python_version >= '3.7'
 requests==2.31.0; python_version >= '3.7'
 requests-file==2.0.0
@@ -232,21 +238,22 @@ tiktoken==0.6.0; python_version >= '3.8'
 timm==0.9.16; python_version >= '3.8'
 tinysegmenter==0.3
 tldextract==5.1.2; python_version >= '3.8'
-tokenizers==0.15.2; python_version >= '3.7'
+tokenizers==0.19.1; python_version >= '3.7'
 torch==2.2.2
 torchvision==0.17.2
 tornado==6.4; python_version >= '3.8'
 tqdm==4.66.2; python_version >= '3.7'
-transformers==4.39.3; python_full_version >= '3.8.0'
+transformers==4.40.1; python_full_version >= '3.8.0'
 types-requests==2.31.0.20240406; python_version >= '3.8'
 typing-extensions==4.11.0; python_version >= '3.8'
 typing-inspect==0.9.0
 tzdata==2024.1; python_version >= '2'
 unidecode==1.3.8; python_version >= '3.5'
-unstructured[all-docs]==0.13.2; python_version < '3.12' and python_full_version >= '3.9.0'
-unstructured-client==0.18.0; python_version >= '3.8'
-unstructured-inference==0.7.25
+unstructured[all-docs]==0.13.3; python_version < '3.12' and python_full_version >= '3.9.0'
+unstructured-client==0.22.0; python_version >= '3.8'
+unstructured-inference==0.7.27
 unstructured.pytesseract==0.3.12
-urllib3==2.2.1; python_version >= '3.8'
+urllib3==2.2.1; python_version >= '3.10'
 uvicorn==0.29.0; python_version >= '3.8'
 vine==5.1.0; python_version >= '3.6'
 watchdog==4.0.0; python_version >= '3.8'
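Beyond routine version bumps, the regenerated requirements pick up the new reranking stack: cohere, langchain-cohere, flashrank, and httpx-sse (apparently pulled in by the Cohere client), while deepdiff and ordered-set appear to arrive alongside the unstructured-client upgrade.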


@@ -22,9 +22,9 @@ from modules.brain.rags.quivr_rag import QuivrRAG
 from modules.brain.service.brain_service import BrainService
 from modules.knowledge.dto.inputs import CreateKnowledgeProperties
 from modules.knowledge.service.knowledge_service import KnowledgeService
-from modules.upload.service.upload_file import upload_file_storage
 from ragas import evaluate
 from ragas.embeddings.base import LangchainEmbeddingsWrapper
+from modules.upload.service.upload_file import upload_file_storage


 def main(
@@ -176,7 +176,7 @@ if __name__ == "__main__":
         "--model", type=str, default="gpt-3.5-turbo-0125", help="Model to use"
     )
     parser.add_argument(
-        "--context_size", type=int, default=4000, help="Context size for the model"
+        "--context_size", type=int, default=10000, help="Context size for the model"
     )
     parser.add_argument(
         "--metrics",