Mirror of https://github.com/QuivrHQ/quivr.git (synced 2024-12-14 07:59:00 +03:00)

feat(reranker): Add flashrank and contextual compression retriever (#2480)

This pull request adds the FlashRank reranker and a contextual compression retriever to the codebase. The FlashRank model acts as the document compressor, and LangChain's ContextualCompressionRetriever combines that base compressor with the existing base retriever to improve the relevance of retrieved documents.

Parent: 7ead787626
Commit: f656dbcb42
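
The core of the change is easiest to see end to end. Below is a minimal, self-contained sketch of the new retrieval flow: a FlashRank compressor wrapped in a ContextualCompressionRetriever. The `StaticRetriever` is a hypothetical stand-in for Quivr's real vector-store retriever; the import paths, model name, and `top_n=5` come straight from this diff. Running it requires the `flashrank` package (the model is downloaded on first use).

```python
from typing import List

from langchain.retrievers import ContextualCompressionRetriever
from langchain.retrievers.document_compressors import FlashrankRerank
from langchain_core.callbacks import CallbackManagerForRetrieverRun
from langchain_core.documents import Document
from langchain_core.retrievers import BaseRetriever


class StaticRetriever(BaseRetriever):
    """Toy base retriever standing in for Quivr's vector store."""

    docs: List[Document]

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        return self.docs


base_retriever = StaticRetriever(
    docs=[
        Document(page_content="FlashRank reranks retrieved passages locally."),
        Document(page_content="Celery schedules Quivr's ingestion jobs."),
        Document(page_content="Contextual compression keeps only relevant docs."),
    ]
)

# The compressor reranks what the base retriever returns, keeping the top 5.
compressor = FlashrankRerank(model="ms-marco-TinyBERT-L-2-v2", top_n=5)
compression_retriever = ContextualCompressionRetriever(
    base_compressor=compressor, base_retriever=base_retriever
)
print(compression_retriever.get_relevant_documents("How does reranking work?"))
```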
Environment configuration:

```diff
@@ -34,6 +34,8 @@ TELEMETRY_ENABLED=true
 CELERY_BROKER_URL=redis://redis:6379/0
 CELEBRY_BROKER_QUEUE_NAME=quivr-preview.fifo
 QUIVR_DOMAIN=http://localhost:3000/
+#COHERE_API_KEY=CHANGE_ME
```
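
The commented-out key above is what toggles the reranker backend added later in this diff. For reference, the selection logic (lifted from the `get_chain` hunk further down) reads:

```python
import os

from langchain.retrievers.document_compressors import FlashrankRerank
from langchain_cohere import CohereRerank

# Use Cohere's hosted reranker when an API key is configured in the
# environment; otherwise fall back to the local FlashRank model.
if os.getenv("COHERE_API_KEY"):
    compressor = CohereRerank(top_n=5)
else:
    compressor = FlashrankRerank(model="ms-marco-TinyBERT-L-2-v2", top_n=5)
```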
Pipfile (2 changes):
```diff
@@ -60,6 +60,8 @@ datasets = "*"
 pytest-dotenv = "*"
 fpdf2 = "*"
 unidecode = "*"
+flashrank = "*"
+langchain-cohere = "*"

 [dev-packages]
 black = "*"
```
Pipfile.lock (786 changes, generated): file diff suppressed because it is too large.
File model (chunking defaults):

```diff
@@ -25,8 +25,8 @@ class File(BaseModel):
     vectors_ids: Optional[list] = []
     file_extension: Optional[str] = ""
     content: Optional[Any] = None
-    chunk_size: int = 250
-    chunk_overlap: int = 0
+    chunk_size: int = 800
+    chunk_overlap: int = 300
     documents: Optional[Document] = None

     @property
```
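
The defaults move from many small chunks (size 250, no overlap) to fewer, larger ones (size 800, overlap 300), which gives the reranker more context per candidate passage. A minimal sketch of how such values typically drive splitting, assuming LangChain's RecursiveCharacterTextSplitter (the splitter itself is not shown in this diff):

```python
from langchain_text_splitters import RecursiveCharacterTextSplitter

# New File-model defaults; RecursiveCharacterTextSplitter measures
# chunk_size and chunk_overlap in characters by default.
splitter = RecursiveCharacterTextSplitter(chunk_size=800, chunk_overlap=300)
text = "Quivr splits uploaded files into overlapping chunks. " * 150
chunks = splitter.split_text(text)
print(len(chunks), max(len(c) for c in chunks))  # every chunk is <= 800 chars
```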
QuivrRAG module (imports and rephrasing prompt):

```diff
@@ -1,3 +1,4 @@
+import os
 from operator import itemgetter
 from typing import Optional
 from uuid import UUID
@@ -5,20 +6,21 @@ from uuid import UUID
 from langchain.chains import ConversationalRetrievalChain
 from langchain.embeddings.ollama import OllamaEmbeddings
 from langchain.llms.base import BaseLLM
 from langchain.memory import ConversationBufferMemory
 from langchain.prompts import HumanMessagePromptTemplate, SystemMessagePromptTemplate
+from langchain.retrievers import ContextualCompressionRetriever
+from langchain.retrievers.document_compressors import FlashrankRerank
 from langchain.schema import format_document
+from langchain_cohere import CohereRerank
 from langchain_community.chat_models import ChatLiteLLM
 from langchain_core.messages import get_buffer_string
 from langchain_core.output_parsers import StrOutputParser
 from langchain_core.prompts import ChatPromptTemplate, PromptTemplate
-from langchain_core.runnables import RunnableLambda, RunnablePassthrough
+from langchain_core.runnables import RunnablePassthrough
 from langchain_openai import OpenAIEmbeddings
-from modules.prompt.service.get_prompt_to_use import get_prompt_to_use
 from logger import get_logger
 from models import BrainSettings  # Importing settings related to the 'brain'
 from modules.brain.service.brain_service import BrainService
 from modules.chat.service.chat_service import ChatService
+from modules.prompt.service.get_prompt_to_use import get_prompt_to_use
 from pydantic import BaseModel, ConfigDict
 from pydantic_settings import BaseSettings
 from supabase.client import Client, create_client
@@ -28,7 +30,7 @@ logger = get_logger(__name__)


 # First step is to create the Rephrasing Prompt
-_template = """Given the following conversation and a follow up question, rephrase the follow up question to be a standalone question, in its original language.
+_template = """Given the following conversation and a follow up question, rephrase the follow up question to be a standalone question, in its original language. Keep as much details as possible from previous messages. Keep entity names and all.

 Chat History:
 {chat_history}
```
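
The hunk cuts the template off after the `{chat_history}` placeholder. A hedged sketch of how such a template is typically compiled into the CONDENSE_QUESTION_PROMPT that the chain uses below; the template tail ("Follow Up Input" / "Standalone question") is an assumption based on the standard condense-question pattern and is not shown in this diff:

```python
from langchain_core.prompts import PromptTemplate

# The last two template lines are assumed; the diff truncates after
# {chat_history}.
_template = """Given the following conversation and a follow up question, rephrase the follow up question to be a standalone question, in its original language. Keep as much details as possible from previous messages. Keep entity names and all.

Chat History:
{chat_history}
Follow Up Input: {question}
Standalone question:"""

CONDENSE_QUESTION_PROMPT = PromptTemplate.from_template(_template)
print(
    CONDENSE_QUESTION_PROMPT.format(
        chat_history="Human: Tell me about Quivr.\nAI: It is a RAG platform.",
        question="Which reranker does it use?",
    )
)
```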
QuivrRAG.get_chain (reranker wiring):

```diff
@@ -202,14 +204,20 @@ class QuivrRAG(BaseModel):
         return self.vector_store.as_retriever()

     def get_chain(self):
+        compressor = None
+        if os.getenv("COHERE_API_KEY"):
+            compressor = CohereRerank(top_n=5)
+        else:
+            compressor = FlashrankRerank(model="ms-marco-TinyBERT-L-2-v2", top_n=5)
+
         retriever_doc = self.get_retriever()
-        memory = ConversationBufferMemory(
-            return_messages=True, output_key="answer", input_key="question"
+        compression_retriever = ContextualCompressionRetriever(
+            base_compressor=compressor, base_retriever=retriever_doc
         )

         loaded_memory = RunnablePassthrough.assign(
-            chat_history=RunnableLambda(memory.load_memory_variables)
-            | itemgetter("history"),
+            chat_history=lambda x: x["chat_history"],
+            question=lambda x: x["question"],
         )

         api_base = None
@@ -219,7 +227,7 @@ class QuivrRAG(BaseModel):
         standalone_question = {
             "standalone_question": {
                 "question": lambda x: x["question"],
-                "chat_history": lambda x: get_buffer_string(x["chat_history"]),
+                "chat_history": lambda x: x["chat_history"],
             }
             | CONDENSE_QUESTION_PROMPT
             | ChatLiteLLM(temperature=0, model=self.model, api_base=api_base)
@@ -233,7 +241,7 @@ class QuivrRAG(BaseModel):

         # Now we retrieve the documents
         retrieved_documents = {
-            "docs": itemgetter("standalone_question") | retriever_doc,
+            "docs": itemgetter("standalone_question") | compression_retriever,
             "question": lambda x: x["standalone_question"],
             "custom_instructions": lambda x: prompt_to_use,
         }
```
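
The `docs` entry above is plain LCEL: `itemgetter("standalone_question")` pulls the condensed question out of the chain state and pipes it into the compression retriever. A runnable sketch of that step, with a hypothetical RunnableLambda standing in for the retriever (only langchain-core is needed):

```python
from operator import itemgetter

from langchain_core.runnables import RunnableLambda, RunnableParallel

# Stand-in for compression_retriever; any Runnable can sit after the pipe.
fake_retriever = RunnableLambda(lambda q: [f"doc matching {q!r}"])

retrieved_documents = RunnableParallel(
    docs=itemgetter("standalone_question") | fake_retriever,
    question=lambda x: x["standalone_question"],
)
print(retrieved_documents.invoke({"standalone_question": "what reranker is used?"}))
```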
Pinned dependencies:

```diff
@@ -17,8 +17,8 @@ backoff==2.2.1; python_version >= '3.7' and python_version < '4.0'
 beautifulsoup4==4.12.3; python_full_version >= '3.6.0'
 billiard==4.2.0; python_version >= '3.7'
 black==24.4.0; python_version >= '3.8'
-boto3==1.34.86; python_version >= '3.8'
-botocore==1.34.86; python_version >= '3.8'
+boto3==1.34.90; python_version >= '3.8'
+botocore==1.34.90; python_version >= '3.8'
 celery[redis,sqs]==5.4.0; python_version >= '3.8'
 certifi==2024.2.2; python_version >= '3.6'
 cffi==1.16.0; platform_python_implementation != 'PyPy'
@@ -28,6 +28,7 @@ click==8.1.7; python_version >= '3.7'
 click-didyoumean==0.3.1; python_full_version >= '3.6.2'
 click-plugins==1.1.1
 click-repl==0.3.0; python_version >= '3.6'
+cohere==5.3.3; python_version >= '3.8' and python_version < '4.0'
 coloredlogs==15.0.1; python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2, 3.3, 3.4'
 colorlog==6.8.2; python_version >= '3.6'
 contourpy==1.2.1; python_version >= '3.9'
@@ -35,10 +36,10 @@ cryptography==42.0.5; python_version >= '3.7'
 cssselect==1.2.0; python_version >= '3.7'
 cycler==0.12.1; python_version >= '3.8'
 dataclasses-json==0.6.4; python_version >= '3.7' and python_version < '4.0'
 dataclasses-json-speakeasy==0.5.11; python_version >= '3.7' and python_version < '4.0'
-datasets==2.18.0; python_full_version >= '3.8.0'
+datasets==2.19.0; python_full_version >= '3.8.0'
 debugpy==1.8.1; python_version >= '3.8'
 decorator==5.1.1; python_version >= '3.5'
 deepdiff==7.0.1; python_version >= '3.8'
 defusedxml==0.7.1; python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2, 3.3, 3.4'
 deprecated==1.2.14; python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2, 3.3'
 deprecation==2.1.0
@@ -50,22 +51,24 @@ docx2txt==0.8
 duckdb==0.10.2; python_full_version >= '3.7.0'
 ecdsa==0.19.0; python_version >= '2.6' and python_version not in '3.0, 3.1, 3.2, 3.3, 3.4'
 effdet==0.4.1
-emoji==2.11.0; python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2, 3.3'
+emoji==2.11.1; python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2, 3.3'
 et-xmlfile==1.1.0; python_version >= '3.6'
 faker==19.13.0; python_version >= '3.8'
-fastapi==0.110.1; python_version >= '3.8'
+fastapi==0.110.2; python_version >= '3.8'
 fastavro==1.9.4; python_version >= '3.8'
 feedfinder2==0.0.4
 feedparser==6.0.11; python_version >= '3.6'
 filelock==3.13.4; python_version >= '3.8'
 filetype==1.2.0
 flake8==7.0.0; python_full_version >= '3.8.1'
 flake8-black==0.3.6; python_version >= '3.7'
+flashrank==0.2.0; python_version >= '3.6'
 flatbuffers==24.3.25
 flower==2.0.1; python_version >= '3.7'
 fonttools==4.51.0; python_version >= '3.8'
 fpdf2==2.7.8; python_version >= '3.7'
 frozenlist==1.4.1; python_version >= '3.8'
-fsspec[http]==2024.2.0; python_version >= '3.8'
+fsspec[http]==2024.3.1; python_version >= '3.8'
 gitdb==4.0.11; python_version >= '3.7'
 gitpython==3.1.43; python_version >= '3.7'
 gotrue==2.4.2; python_version >= '3.8' and python_version < '4.0'
@@ -74,6 +77,7 @@ h11==0.14.0; python_version >= '3.7'
 html5lib==1.1; python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2, 3.3, 3.4'
 httpcore==1.0.5; python_version >= '3.8'
 httpx==0.27.0; python_version >= '3.8'
+httpx-sse==0.4.0; python_version >= '3.8'
 huggingface-hub==0.22.2; python_full_version >= '3.8.0'
 humanfriendly==10.0; python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2, 3.3, 3.4'
 humanize==4.9.0; python_version >= '3.8'
@@ -92,25 +96,26 @@ jsonpointer==2.4; python_version >= '2.7' and python_version not in '3.0, 3.1, 3
 kiwisolver==1.4.5; python_version >= '3.7'
 kombu[sqs]==5.3.7; python_version >= '3.8'
 langchain==0.1.16; python_version < '4.0' and python_full_version >= '3.8.1'
-langchain-community==0.0.33; python_version < '4.0' and python_full_version >= '3.8.1'
-langchain-core==0.1.44; python_version < '4.0' and python_full_version >= '3.8.1'
+langchain-cohere==0.1.3; python_version < '4.0' and python_full_version >= '3.8.1'
+langchain-community==0.0.34; python_version < '4.0' and python_full_version >= '3.8.1'
+langchain-core==0.1.45; python_version < '4.0' and python_full_version >= '3.8.1'
 langchain-openai==0.1.3; python_version < '4.0' and python_full_version >= '3.8.1'
 langchain-text-splitters==0.0.1; python_version < '4.0' and python_full_version >= '3.8.1'
 langdetect==1.0.9
-langfuse==2.26.3; python_version < '4.0' and python_full_version >= '3.8.1'
-langsmith==0.1.48; python_version < '4.0' and python_full_version >= '3.8.1'
+langfuse==2.27.1; python_version < '4.0' and python_full_version >= '3.8.1'
+langsmith==0.1.50; python_version < '4.0' and python_full_version >= '3.8.1'
 layoutparser[layoutmodels,tesseract]==0.3.4; python_version >= '3.6'
-litellm==1.35.10; python_version not in '2.7, 3.0, 3.1, 3.2, 3.3, 3.4, 3.5, 3.6, 3.7' and python_version >= '3.8'
-llama-index==0.10.29; python_version < '4.0' and python_full_version >= '3.8.1'
-llama-index-agent-openai==0.2.2; python_version < '4.0' and python_full_version >= '3.8.1'
-llama-index-cli==0.1.11; python_version < '4.0' and python_full_version >= '3.8.1'
-llama-index-core==0.10.29; python_version < '4.0' and python_full_version >= '3.8.1'
-llama-index-embeddings-openai==0.1.7; python_version < '4.0' and python_full_version >= '3.8.1'
+litellm==1.35.21; python_version not in '2.7, 3.0, 3.1, 3.2, 3.3, 3.4, 3.5, 3.6, 3.7' and python_version >= '3.8'
+llama-index==0.10.31; python_version < '4.0' and python_full_version >= '3.8.1'
+llama-index-agent-openai==0.2.3; python_version < '4.0' and python_full_version >= '3.8.1'
+llama-index-cli==0.1.12; python_version < '4.0' and python_full_version >= '3.8.1'
+llama-index-core==0.10.31; python_version < '4.0' and python_full_version >= '3.8.1'
+llama-index-embeddings-openai==0.1.8; python_version < '4.0' and python_full_version >= '3.8.1'
 llama-index-indices-managed-llama-cloud==0.1.5; python_version < '4.0' and python_full_version >= '3.8.1'
 llama-index-legacy==0.9.48; python_version < '4.0' and python_full_version >= '3.8.1'
-llama-index-llms-openai==0.1.15; python_version < '4.0' and python_full_version >= '3.8.1'
+llama-index-llms-openai==0.1.16; python_version < '4.0' and python_full_version >= '3.8.1'
 llama-index-multi-modal-llms-openai==0.1.5; python_version < '4.0' and python_full_version >= '3.8.1'
-llama-index-program-openai==0.1.5; python_version < '4.0' and python_full_version >= '3.8.1'
+llama-index-program-openai==0.1.6; python_version < '4.0' and python_full_version >= '3.8.1'
 llama-index-question-gen-openai==0.1.3; python_version < '4.0' and python_full_version >= '3.8.1'
 llama-index-readers-file==0.1.19; python_version < '4.0' and python_full_version >= '3.8.1'
 llama-index-readers-llama-parse==0.1.4; python_version < '4.0' and python_full_version >= '3.8.1'
@@ -138,14 +143,15 @@ numpy==1.26.4; python_version >= '3.9'
 olefile==0.47; python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2, 3.3, 3.4'
 omegaconf==2.3.0; python_version >= '3.6'
 onnx==1.16.0
-onnxruntime==1.15.1
-openai==1.21.1; python_full_version >= '3.7.1'
+onnxruntime==1.17.3
+openai==1.23.3; python_full_version >= '3.7.1'
 opencv-python==4.9.0.80; python_version >= '3.6'
 openpyxl==3.1.2
 ordered-set==4.1.0; python_version >= '3.7'
 orjson==3.10.1; python_version >= '3.8'
 packaging==23.2; python_version >= '3.7'
 pandas==1.5.3; python_version >= '3.8'
-pandasai==2.0.33; python_version not in '2.7, 3.0, 3.1, 3.2, 3.3, 3.4, 3.5, 3.6, 3.7, 3.8' and python_version >= '3.9'
+pandasai==2.0.35; python_version not in '2.7, 3.0, 3.1, 3.2, 3.3, 3.4, 3.5, 3.6, 3.7, 3.8' and python_version >= '3.9'
 pathspec==0.12.1; python_version >= '3.8'
 pdf2image==1.17.0
 pdfminer.six==20231228
@@ -153,8 +159,8 @@ pdfplumber==0.11.0; python_version >= '3.8'
 pikepdf==8.15.1
 pillow==10.3.0; python_version >= '3.8'
 pillow-heif==0.16.0
-platformdirs==4.2.0; python_version >= '3.8'
-pluggy==1.4.0; python_version >= '3.8'
+platformdirs==4.2.1; python_version >= '3.8'
+pluggy==1.5.0; python_version >= '3.8'
 portalocker==2.8.2; python_version >= '3.8'
 postgrest==0.16.3; python_version >= '3.8' and python_version < '4.0'
 posthog==3.5.0
@@ -165,22 +171,22 @@ psutil==5.9.8; python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2,
 psycopg2==2.9.9; python_version >= '3.7'
 psycopg2-binary==2.9.9; python_version >= '3.7'
 py==1.11.0; python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2, 3.3, 3.4'
-pyarrow==15.0.2; python_version >= '3.8'
+pyarrow==16.0.0; python_version >= '3.8'
 pyarrow-hotfix==0.6; python_version >= '3.5'
 pyasn1==0.6.0; python_version >= '3.8'
 pycocotools==2.0.7; python_version >= '3.5'
 pycodestyle==2.11.1; python_version >= '3.8'
 pycparser==2.22; python_version >= '3.8'
 pycurl==7.45.3
-pydantic==2.7.0; python_version >= '3.8'
-pydantic-core==2.18.1; python_version >= '3.8'
+pydantic==2.7.1; python_version >= '3.8'
+pydantic-core==2.18.2; python_version >= '3.8'
 pydantic-settings==2.2.1; python_version >= '3.8'
 pyflakes==3.2.0; python_version >= '3.8'
 pypandoc==1.13; python_version >= '3.6'
 pyparsing==3.1.2; python_full_version >= '3.6.8'
 pypdf==4.2.0; python_version >= '3.6'
 pypdfium2==4.29.0; python_version >= '3.6'
-pyright==1.1.359; python_version >= '3.7'
+pyright==1.1.360; python_version >= '3.7'
 pysbd==0.3.4; python_version >= '3'
 pytesseract==0.3.10; python_version >= '3.7'
 pytest==8.1.1; python_version >= '3.8'
@@ -201,7 +207,7 @@ pyyaml==6.0.1; python_version >= '3.6'
 ragas==0.1.7
 rapidfuzz==3.8.1; python_version >= '3.8'
 realtime==1.0.4; python_version >= '3.8' and python_version < '4.0'
-redis==5.0.3; python_version >= '3.7'
+redis==5.0.4; python_version >= '3.7'
 regex==2024.4.16; python_version >= '3.7'
 requests==2.31.0; python_version >= '3.7'
 requests-file==2.0.0
@@ -232,21 +238,22 @@ tiktoken==0.6.0; python_version >= '3.8'
 timm==0.9.16; python_version >= '3.8'
 tinysegmenter==0.3
 tldextract==5.1.2; python_version >= '3.8'
-tokenizers==0.15.2; python_version >= '3.7'
+tokenizers==0.19.1; python_version >= '3.7'
 torch==2.2.2
 torchvision==0.17.2
 tornado==6.4; python_version >= '3.8'
 tqdm==4.66.2; python_version >= '3.7'
-transformers==4.39.3; python_full_version >= '3.8.0'
+transformers==4.40.1; python_full_version >= '3.8.0'
 types-requests==2.31.0.20240406; python_version >= '3.8'
 typing-extensions==4.11.0; python_version >= '3.8'
 typing-inspect==0.9.0
 tzdata==2024.1; python_version >= '2'
 unidecode==1.3.8; python_version >= '3.5'
-unstructured[all-docs]==0.13.2; python_version < '3.12' and python_full_version >= '3.9.0'
-unstructured-client==0.18.0; python_version >= '3.8'
-unstructured-inference==0.7.25
+unstructured[all-docs]==0.13.3; python_version < '3.12' and python_full_version >= '3.9.0'
+unstructured-client==0.22.0; python_version >= '3.8'
+unstructured-inference==0.7.27
 unstructured.pytesseract==0.3.12
-urllib3==2.2.1; python_version >= '3.8'
+urllib3==2.2.1; python_version >= '3.10'
 uvicorn==0.29.0; python_version >= '3.8'
 vine==5.1.0; python_version >= '3.6'
 watchdog==4.0.0; python_version >= '3.8'
```
RAG evaluation script:

```diff
@@ -22,9 +22,9 @@ from modules.brain.rags.quivr_rag import QuivrRAG
 from modules.brain.service.brain_service import BrainService
 from modules.knowledge.dto.inputs import CreateKnowledgeProperties
 from modules.knowledge.service.knowledge_service import KnowledgeService
-from modules.upload.service.upload_file import upload_file_storage
 from ragas import evaluate
 from ragas.embeddings.base import LangchainEmbeddingsWrapper
+from modules.upload.service.upload_file import upload_file_storage


 def main(
@@ -176,7 +176,7 @@ if __name__ == "__main__":
         "--model", type=str, default="gpt-3.5-turbo-0125", help="Model to use"
     )
     parser.add_argument(
-        "--context_size", type=int, default=4000, help="Context size for the model"
+        "--context_size", type=int, default=10000, help="Context size for the model"
    )
     parser.add_argument(
         "--metrics",
```