feat: quivr core minimal chat (#2803)

# Description

A minimal working example of `quivr-core` RAG with minimal dependencies.
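
For reference, the gist of the new API (a condensed version of the `examples/` scripts added below; with no `llm`/`embedder` arguments, `Brain` falls back to `ChatOpenAI` and `OpenAIEmbeddings`, so the `base` extra and an `OPENAI_API_KEY` are assumed):

```python
import tempfile

from quivr_core import Brain

with tempfile.NamedTemporaryFile(mode="w", suffix=".txt") as temp_file:
    temp_file.write("Gold is a metal.")
    temp_file.flush()
    brain = Brain.from_files(name="demo_brain", file_paths=[temp_file.name])
    print(brain.ask("Property of gold?").answer)
```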

---------

Co-authored-by: aminediro <aminedirhoussi@gmail.com>
AmineDiro 2024-07-09 15:22:16 +02:00 committed by GitHub
parent 296e9fd1b9
commit 1dc6d88f9b
23 changed files with 2412 additions and 189 deletions

.gitignore vendored (2 changes)

@@ -87,3 +87,5 @@ backend/modules/sync/controller/credentials.json
 backend/.env.test
+**/*.egg-info
+.coverage

@@ -0,0 +1,19 @@
from langchain_core.embeddings import DeterministicFakeEmbedding
from langchain_core.language_models import FakeListChatModel

from quivr_core import Brain
from quivr_core.processor.default_parsers import DEFAULT_PARSERS
from quivr_core.processor.pdf_processor import TikaParser

if __name__ == "__main__":
    pdf_paths = ["../tests/processor/data/dummy.pdf"]
    brain = Brain.from_files(
        name="test_brain",
        file_paths=pdf_paths,
        llm=FakeListChatModel(responses=["good"]),
        embedder=DeterministicFakeEmbedding(size=20),
        processors_mapping={
            **DEFAULT_PARSERS,
            ".pdf": TikaParser(),
        },
    )

@@ -0,0 +1,14 @@
import tempfile

from quivr_core import Brain

if __name__ == "__main__":
    with tempfile.NamedTemporaryFile(mode="w", suffix=".txt") as temp_file:
        temp_file.write("Gold is metal.")
        temp_file.flush()
        brain = Brain.from_files(name="test_brain", file_paths=[temp_file.name])
        answer = brain.ask("Property of gold?")
        print("answer :", answer.answer)

backend/core/poetry.lock generated (1906 changes)

File diff suppressed because it is too large

backend/core/pyproject.toml
@@ -10,7 +10,37 @@ repository = "https://github.com/langchain-ai/langchain"
python = "^3.11"
pydantic = "^2.7.4"
langchain-core = "^0.2.10"
langchain = "^0.2.6"
httpx = "^0.27.0"
faiss-cpu = { version = "^1.8.0.post1", optional = true }
langchain-community = { version = "^0.2.6", optional = true }
langchain-openai = { version = "^0.1.14", optional = true }
aiofiles = "^24.1.0"

[tool.poetry.extras]
base = ["langchain-community", "faiss-cpu", "langchain-openai"]
pdf = []

[tool.poetry.group.dev.dependencies]
mypy = "^1.10.0"
pre-commit = "^3.7.1"
ipykernel = "*"
ruff = "^0.4.8"
flake8 = "*"
flake8-black = "*"
pytest-cov = "^5.0.0"

[tool.poetry.group.test.dependencies]
pytest-asyncio = "^0.23.7"
pytest = "^8.2.2"

[tool.mypy]
disallow_untyped_defs = true
# Remove venv skip when integrated with pre-commit
exclude = ["_static", "build", "examples", "notebooks", "venv", ".venv"]
ignore_missing_imports = true
python_version = "3.11"

[tool.ruff]
line-length = 88

backend/core/quivr_core/__init__.py
@@ -0,0 +1,3 @@
from .brain import Brain

__all__ = ["Brain"]

backend/core/quivr_core/brain.py
@@ -0,0 +1,198 @@
import asyncio
import logging
from pathlib import Path
from typing import Mapping, Self
from uuid import UUID, uuid4

from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_core.language_models import BaseChatModel
from langchain_core.vectorstores import VectorStore

from quivr_core.config import RAGConfig
from quivr_core.models import ParsedRAGResponse
from quivr_core.processor.default_parsers import DEFAULT_PARSERS
from quivr_core.processor.processor_base import ProcessorBase
from quivr_core.quivr_rag import QuivrQARAG
from quivr_core.storage.file import QuivrFile
from quivr_core.storage.local_storage import TransparentStorage
from quivr_core.storage.storage_base import StorageBase

logger = logging.getLogger(__name__)


async def _process_files(
    storage: StorageBase,
    skip_file_error: bool,
    processors_mapping: Mapping[str, ProcessorBase],
) -> list[Document]:
    knowledge = []
    for file in storage.get_files():
        try:
            if file.file_extension:
                processor = processors_mapping[file.file_extension]
                docs = await processor.process_file(file)
                knowledge.extend(docs)
            else:
                logger.error(f"can't find processor for {file}")
                if skip_file_error:
                    continue
                else:
                    raise ValueError(f"can't parse {file}. can't find file extension")
        except KeyError as e:
            if skip_file_error:
                continue
            else:
                raise Exception(f"Can't parse {file}. No available processor") from e
    return knowledge


class Brain:
    def __init__(
        self,
        *,
        name: str,
        id: UUID,
        vector_db: VectorStore,
        llm: BaseChatModel,
        embedder: Embeddings,
        storage: StorageBase,
    ):
        self.id = id
        self.name = name
        self.storage = storage

        # Chat history
        self.chat_history: list[str] = []

        # RAG dependencies:
        self.llm = llm
        self.vector_db = vector_db
        self.embedder = embedder

    @classmethod
    async def afrom_files(
        cls,
        *,
        name: str,
        file_paths: list[str | Path],
        vector_db: VectorStore | None = None,
        storage: StorageBase = TransparentStorage(),
        llm: BaseChatModel | None = None,
        embedder: Embeddings | None = None,
        processors_mapping: Mapping[str, ProcessorBase] = DEFAULT_PARSERS,
        skip_file_error: bool = False,
    ):
        if llm is None:
            try:
                from langchain_openai import ChatOpenAI

                logger.debug("Loaded ChatOpenAI as default LLM for brain")
                llm = ChatOpenAI()
            except ImportError as e:
                raise ImportError(
                    "Please provide a valid BaseLLM or install quivr-core['base'] package"
                ) from e

        if embedder is None:
            try:
                from langchain_openai import OpenAIEmbeddings

                logger.debug("Loaded OpenAIEmbeddings as default embedder for brain")
                embedder = OpenAIEmbeddings()
            except ImportError as e:
                raise ImportError(
                    "Please provide a valid Embedder or install quivr-core['base'] package to use the default one."
                ) from e

        brain_id = uuid4()

        for path in file_paths:
            file = QuivrFile.from_path(brain_id, path)
            storage.upload_file(file)

        # Parse files
        docs = await _process_files(
            storage=storage,
            processors_mapping=processors_mapping,
            skip_file_error=skip_file_error,
        )

        # Building brain's vectordb
        if vector_db is None:
            try:
                from langchain_community.vectorstores import FAISS

                logger.debug("Using FAISS-CPU as vector store.")
                # TODO(@aminediro): embedding calls are not made concurrently for all
                # documents; we could gather on all processing instead
                if len(docs) > 0:
                    vector_db = await FAISS.afrom_documents(
                        documents=docs, embedding=embedder
                    )
                else:
                    raise ValueError("can't initialize brain without documents")
            except ImportError as e:
                raise ImportError(
                    "Please provide a valid vector store or install quivr-core['base'] package to use the default one."
                ) from e
        else:
            vector_db.add_documents(docs)

        return cls(
            id=brain_id,
            name=name,
            storage=storage,
            llm=llm,
            embedder=embedder,
            vector_db=vector_db,
        )

    @classmethod
    def from_files(
        cls,
        *,
        name: str,
        file_paths: list[str | Path],
        vector_db: VectorStore | None = None,
        storage: StorageBase = TransparentStorage(),
        llm: BaseChatModel | None = None,
        embedder: Embeddings | None = None,
        processors_mapping: Mapping[str, ProcessorBase] = DEFAULT_PARSERS,
        skip_file_error: bool = False,
    ) -> Self:
        return asyncio.run(
            cls.afrom_files(
                name=name,
                file_paths=file_paths,
                vector_db=vector_db,
                storage=storage,
                llm=llm,
                embedder=embedder,
                processors_mapping=processors_mapping,
                skip_file_error=skip_file_error,
            )
        )

    # TODO(@aminediro)
    def add_file(self) -> None:
        # add it to storage
        # add it to vectorstore
        raise NotImplementedError

    def ask(
        self, question: str, rag_config: RAGConfig = RAGConfig()
    ) -> ParsedRAGResponse:
        rag_pipeline = QuivrQARAG(
            rag_config=rag_config, llm=self.llm, vector_store=self.vector_db
        )
        # transformed_history = format_chat_history(history)
        parsed_response = rag_pipeline.answer(question, [], [])
        # Save answer to the chat history
        return parsed_response
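
A hedged sketch of driving the new async entrypoint directly (the file path is illustrative; `from_files` is just an `asyncio.run` wrapper around it):

```python
import asyncio

from quivr_core import Brain

async def main() -> None:
    # afrom_files uploads the files to storage, parses them, and builds
    # the vector store before returning the Brain
    brain = await Brain.afrom_files(name="docs_brain", file_paths=["notes.txt"])
    print(brain.ask("What do the notes say?").answer)

asyncio.run(main())
```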

backend/core/quivr_core/config.py
@@ -3,9 +3,9 @@ from pydantic import BaseModel, field_validator
 class RAGConfig(BaseModel):
     model: str = "gpt-3.5-turbo-0125"  # pyright: ignore reportPrivateUsage=none
-    temperature: float | None = 0.1
-    max_tokens: int | None = 2000
+    temperature: float = 0.7
+    max_input: int = 2000
+    max_tokens: int | None = 2000
     streaming: bool = False
     max_files: int = 20
     prompt: str | None = None
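
Since `temperature` is now a plain `float`, a custom `RAGConfig` can be handed straight to `Brain.ask`; a small sketch (field values are illustrative):

```python
from quivr_core.config import RAGConfig

config = RAGConfig(model="gpt-4o", temperature=0.2, max_tokens=1000)
# brain.ask("question", rag_config=config) would pass this to QuivrQARAG
print(config.model_dump())
```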

backend/core/quivr_core/processor/default_parsers.py
@@ -0,0 +1,6 @@
from quivr_core.processor.processor_base import ProcessorBase
from quivr_core.processor.txt_parser import TxtProcessor

DEFAULT_PARSERS: dict[str, ProcessorBase] = {
    ".txt": TxtProcessor(),
}

backend/core/quivr_core/processor/pdf_processor.py
@@ -0,0 +1,62 @@
import logging
from typing import AsyncIterable

import httpx
from langchain_core.documents import Document
from langchain_text_splitters import RecursiveCharacterTextSplitter, TextSplitter

from quivr_core.processor.processor_base import ProcessorBase
from quivr_core.processor.splitter import SplitterConfig
from quivr_core.storage.file import QuivrFile

logger = logging.getLogger(__name__)


class TikaParser(ProcessorBase):
    supported_extensions = [".pdf"]

    def __init__(
        self,
        tika_url: str = "http://localhost:9998/tika",
        splitter: TextSplitter | None = None,
        splitter_config: SplitterConfig = SplitterConfig(),
        timeout: float = 5.0,
        max_retries: int = 3,
    ) -> None:
        self.tika_url = tika_url
        self.max_retries = max_retries
        self._client = httpx.AsyncClient(timeout=timeout)
        self.splitter_config = splitter_config

        if splitter:
            self.text_splitter = splitter
        else:
            self.text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
                chunk_size=splitter_config.chunk_size,
                chunk_overlap=splitter_config.chunk_overlap,
            )

    async def _send_parse_tika(self, f: AsyncIterable[bytes]) -> str:
        retry = 0
        headers = {"Accept": "text/plain"}
        while retry < self.max_retries:
            try:
                resp = await self._client.put(self.tika_url, headers=headers, content=f)
                resp.raise_for_status()
                return resp.content.decode("utf-8")
            except Exception as e:
                retry += 1
                logger.debug(f"tika url error: {e}. retrying for the {retry} time...")
        raise RuntimeError("can't send parse request to tika server")

    async def process_file(self, file: QuivrFile) -> list[Document]:
        assert file.file_extension in self.supported_extensions

        async with file.open() as f:
            txt = await self._send_parse_tika(f)

        document = Document(page_content=txt)
        # Use the default splitter
        docs = self.text_splitter.split_documents([document])
        return docs
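
The parser only needs a reachable Tika endpoint; a sketch, assuming a local Apache Tika server (for instance the official `apache/tika` Docker image) on port 9998 and a PDF at the given path:

```python
import asyncio
from pathlib import Path
from uuid import uuid4

from quivr_core.processor.pdf_processor import TikaParser
from quivr_core.storage.file import QuivrFile

async def main() -> None:
    parser = TikaParser()  # defaults to http://localhost:9998/tika
    pdf = QuivrFile.from_path(uuid4(), Path("tests/processor/data/dummy.pdf"))
    docs = await parser.process_file(pdf)
    print(len(docs), docs[0].page_content[:80])

asyncio.run(main())
```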

backend/core/quivr_core/processor/processor_base.py
@@ -0,0 +1,27 @@
from abc import ABC, abstractmethod
from typing import Generic, TypeVar

from langchain_core.documents import Document

from quivr_core.storage.file import QuivrFile


class ProcessorBase(ABC):
    supported_extensions: list[str]

    @abstractmethod
    async def process_file(self, file: QuivrFile) -> list[Document]:
        pass


P = TypeVar("P", bound=ProcessorBase)


class ProcessorsMapping(Generic[P]):
    def __init__(self, mapping: dict[str, P]) -> None:
        # Mapping from file extension to its processor
        self.ext_parser: dict[str, P] = mapping

    def add_parser(self, extension: str, parser: P):
        # TODO: deal with existing ext keys
        self.ext_parser[extension] = parser
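
Any new file type just subclasses `ProcessorBase`; a minimal hypothetical example (this naive `.md` handler is not part of the PR):

```python
from langchain_core.documents import Document

from quivr_core.processor.processor_base import ProcessorBase
from quivr_core.storage.file import QuivrFile

class MarkdownProcessor(ProcessorBase):
    supported_extensions = [".md"]

    async def process_file(self, file: QuivrFile) -> list[Document]:
        # file.open() yields an async iterable of bytes chunks
        async with file.open() as f:
            content = b"".join([chunk async for chunk in f])
        return [Document(page_content=content.decode("utf-8"))]
```

It could then be registered via `processors_mapping={**DEFAULT_PARSERS, ".md": MarkdownProcessor()}`.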

backend/core/quivr_core/processor/splitter.py
@@ -0,0 +1,6 @@
from pydantic import BaseModel


class SplitterConfig(BaseModel):
    chunk_size: int = 400
    chunk_overlap: int = 100
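
These defaults can be overridden per processor; a sketch with smaller chunks:

```python
from quivr_core.processor.splitter import SplitterConfig
from quivr_core.processor.txt_parser import TxtProcessor

# 200-token chunks with 50 tokens of overlap instead of the 400/100 defaults
processor = TxtProcessor(
    splitter_config=SplitterConfig(chunk_size=200, chunk_overlap=50)
)
```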

backend/core/quivr_core/processor/txt_parser.py
@@ -0,0 +1,36 @@
from langchain_community.document_loaders.text import TextLoader
from langchain_core.documents import Document
from langchain_text_splitters import RecursiveCharacterTextSplitter, TextSplitter

from quivr_core.processor.processor_base import ProcessorBase
from quivr_core.processor.splitter import SplitterConfig
from quivr_core.storage.file import QuivrFile


class TxtProcessor(ProcessorBase):
    def __init__(
        self,
        splitter: TextSplitter | None = None,
        splitter_config: SplitterConfig = SplitterConfig(),
    ) -> None:
        self.supported_extensions = [".txt"]
        self.loader_cls = TextLoader
        self.splitter_config = splitter_config

        if splitter:
            self.text_splitter = splitter
        else:
            self.text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
                chunk_size=splitter_config.chunk_size,
                chunk_overlap=splitter_config.chunk_overlap,
            )

    async def process_file(self, file: QuivrFile) -> list[Document]:
        if file.file_extension not in self.supported_extensions:
            raise Exception(f"can't process a file of type {file.file_extension}")

        loader = self.loader_cls(file.path)
        documents = await loader.aload()
        docs = self.text_splitter.split_documents(documents)
        return docs

backend/core/quivr_core/quivr_rag.py
@@ -1,15 +1,17 @@
 import logging
 from operator import itemgetter
-from typing import AsyncGenerator
+from typing import AsyncGenerator, Optional, Sequence

 from langchain.retrievers import ContextualCompressionRetriever
-from langchain_cohere import CohereRerank
-from langchain_community.chat_models import ChatLiteLLM
+from langchain_core.callbacks import Callbacks
+from langchain_core.documents import BaseDocumentCompressor, Document
+from langchain_core.language_models import BaseChatModel
 from langchain_core.messages.ai import AIMessageChunk
 from langchain_core.output_parsers import StrOutputParser
 from langchain_core.runnables import RunnableLambda, RunnablePassthrough
 from langchain_core.vectorstores import VectorStore
 from langchain_openai import ChatOpenAI

 from quivr_core.config import RAGConfig
 from quivr_core.models import (
     ParsedRAGChunkResponse,
@@ -30,18 +32,31 @@ from quivr_core.utils import (
 logger = logging.getLogger(__name__)


+class IdempotentCompressor(BaseDocumentCompressor):
+    def compress_documents(
+        self,
+        documents: Sequence[Document],
+        query: str,
+        callbacks: Optional[Callbacks] = None,
+    ) -> Sequence[Document]:
+        return documents
+
+
 class QuivrQARAG:
     def __init__(
         self,
         *,
         rag_config: RAGConfig,
-        llm: ChatLiteLLM,
+        llm: BaseChatModel,
         vector_store: VectorStore,
+        reranker: BaseDocumentCompressor | None = None,
     ):
         self.rag_config = rag_config
         self.vector_store = vector_store
         self.llm = llm
-        self.reranker = self._create_reranker()
+        self.reranker = reranker if reranker is not None else IdempotentCompressor()
         self.supports_func_calling = model_supports_function_calling(
             self.rag_config.model
         )
@@ -50,20 +65,6 @@
     def retriever(self):
         return self.vector_store.as_retriever()

-    def _create_reranker(self):
-        # TODO: reranker config
-        # if os.getenv("COHERE_API_KEY"):
-        compressor = CohereRerank(top_n=20)
-        # else:
-        #     ranker_model_name = "ms-marco-TinyBERT-L-2-v2"
-        #     flashrank_client = Ranker(model_name=ranker_model_name)
-        #     compressor = FlashrankRerank(
-        #         client=flashrank_client, model=ranker_model_name, top_n=20
-        #     )
-        # TODO @stangirard fix
-        return compressor
-
     # TODO : refactor and simplify
     def filter_history(
         self, chat_history, max_history: int = 10, max_tokens: int = 2000
     ):
@@ -136,18 +137,13 @@
         # Override llm if we have an OpenAI model
         llm = self.llm
         if self.supports_func_calling:
-            if self.rag_config.temperature:
-                llm_function = ChatOpenAI(
-                    max_tokens=self.rag_config.max_tokens,
-                    model=self.rag_config.model,
-                    temperature=self.rag_config.temperature,
-                )
-            else:
-                llm_function = ChatOpenAI(
-                    max_tokens=self.rag_config.max_tokens,
-                    model=self.rag_config.model,
-                )
+            llm_function = ChatOpenAI(
+                model=self.rag_config.model,
+                max_tokens=self.rag_config.max_tokens,
+                temperature=self.rag_config.temperature,
+            )
             llm = llm_function.bind_tools(
                 [cited_answer],
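
With the reranker now injected, the Cohere behaviour removed here can be restored at the call site; a sketch, assuming `langchain_cohere` is installed and `COHERE_API_KEY`/`OPENAI_API_KEY` are set:

```python
from langchain_cohere import CohereRerank
from langchain_community.vectorstores import FAISS
from langchain_core.embeddings import DeterministicFakeEmbedding
from langchain_openai import ChatOpenAI

from quivr_core.config import RAGConfig
from quivr_core.quivr_rag import QuivrQARAG

# Any LangChain VectorStore works; a tiny FAISS index for illustration
vector_store = FAISS.from_texts(["Gold is a metal."], DeterministicFakeEmbedding(size=20))
rag = QuivrQARAG(
    rag_config=RAGConfig(),
    llm=ChatOpenAI(),
    vector_store=vector_store,
    reranker=CohereRerank(top_n=20),  # replaces the IdempotentCompressor default
)
answer = rag.answer("Property of gold?", [], [])
```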

backend/core/quivr_core/storage/file.py
@@ -0,0 +1,59 @@
import os
from contextlib import asynccontextmanager
from pathlib import Path
from typing import AsyncGenerator, AsyncIterable
from uuid import UUID, uuid4

import aiofiles


class QuivrFile:
    def __init__(
        self,
        id: UUID,
        original_filename: str,
        path: Path,
        brain_id: UUID,
        file_size: int | None = None,
        file_extension: str | None = None,
    ) -> None:
        self.id = id
        self.brain_id = brain_id
        self.path = path
        self.original_filename = original_filename
        self.file_size = file_size
        self.file_extension = file_extension

    @classmethod
    def from_path(cls, brain_id: UUID, path: str | Path):
        if not isinstance(path, Path):
            path = Path(path)

        if not path.exists():
            raise FileNotFoundError(f"file {path} doesn't exist")

        file_size = os.stat(path).st_size
        try:
            # NOTE: when loading from existing storage, file name will be uuid
            id = UUID(path.name)
        except ValueError:
            id = uuid4()

        return cls(
            id=id,
            brain_id=brain_id,
            path=path,
            original_filename=path.name,
            file_size=file_size,
            file_extension=path.suffix,
        )

    @asynccontextmanager
    async def open(self) -> AsyncGenerator[AsyncIterable[bytes], None]:
        # TODO(@aminediro) : match on path type
        f = await aiofiles.open(self.path, mode="rb")
        try:
            yield f
        finally:
            await f.close()
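
`QuivrFile.from_path` fills in the metadata from disk; a small sketch (the path is illustrative):

```python
import asyncio
from uuid import uuid4

from quivr_core.storage.file import QuivrFile

async def main() -> None:
    qfile = QuivrFile.from_path(uuid4(), "notes.txt")
    print(qfile.id, qfile.file_extension, qfile.file_size)
    async with qfile.open() as f:  # async iterable of bytes
        size = sum([len(chunk) async for chunk in f])
    print(size, "bytes read")

asyncio.run(main())
```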

backend/core/quivr_core/storage/local_storage.py
@@ -0,0 +1,67 @@
import os
import shutil
from pathlib import Path
from uuid import UUID

from quivr_core.storage.file import QuivrFile
from quivr_core.storage.storage_base import StorageBase


class LocalStorage(StorageBase):
    def __init__(self, dir_path: Path | None = None, copy_flag: bool = True):
        self.files: list[QuivrFile] = []
        self.copy_flag = copy_flag

        if dir_path is None:
            self.dir_path = Path(
                os.getenv("QUIVR_LOCAL_STORAGE", "~/.cache/quivr/files")
            ).expanduser()
        else:
            self.dir_path = dir_path
        os.makedirs(self.dir_path, exist_ok=True)

    def _load_files(self) -> None:
        # TODO(@aminediro): load existing files
        pass

    def upload_file(self, file: QuivrFile, exists_ok: bool = False) -> None:
        dst_path = os.path.join(
            self.dir_path, str(file.brain_id), f"{file.id}{file.file_extension}"
        )
        os.makedirs(os.path.dirname(dst_path), exist_ok=True)

        # TODO(@aminediro): Check hash of file not file path
        if os.path.exists(dst_path) and not exists_ok:
            raise FileExistsError("file already exists")

        if self.copy_flag:
            shutil.copy2(file.path, dst_path)
        else:
            os.symlink(file.path, dst_path)

        file.path = Path(dst_path)
        self.files.append(file)

    def get_files(self) -> list[QuivrFile]:
        return self.files

    def remove_file(self, file_id: UUID) -> None:
        raise NotImplementedError


class TransparentStorage(StorageBase):
    """Transparent storage: keeps file references in memory and leaves the files where they are."""

    def __init__(self):
        self.files = []

    def upload_file(self, file: QuivrFile, exists_ok: bool = False) -> None:
        self.files.append(file)

    def remove_file(self, file_id: UUID) -> None:
        raise NotImplementedError

    def get_files(self) -> list[QuivrFile]:
        return self.files
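
To keep uploaded files on disk instead of the in-memory `TransparentStorage` default, pass a `LocalStorage`; a sketch (the directory and file path are illustrative, and the OpenAI defaults are assumed):

```python
from pathlib import Path

from quivr_core import Brain
from quivr_core.storage.local_storage import LocalStorage

brain = Brain.from_files(
    name="persistent_brain",
    file_paths=["notes.txt"],
    storage=LocalStorage(dir_path=Path("/tmp/quivr_files")),
)
```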

backend/core/quivr_core/storage/storage_base.py
@@ -0,0 +1,18 @@
from abc import ABC, abstractmethod
from uuid import UUID

from quivr_core.storage.file import QuivrFile


class StorageBase(ABC):
    @abstractmethod
    def get_files(self) -> list[QuivrFile]:
        raise Exception("Unimplemented get_files method")

    @abstractmethod
    def upload_file(self, file: QuivrFile, exists_ok: bool = False) -> None:
        raise Exception("Unimplemented upload_file method")

    @abstractmethod
    def remove_file(self, file_id: UUID) -> None:
        raise Exception("Unimplemented remove_file method")

Binary file not shown (tests/processor/data/dummy.pdf)

@@ -0,0 +1,38 @@
from pathlib import Path
from uuid import uuid4

import pytest

from quivr_core.processor.pdf_processor import TikaParser
from quivr_core.storage.file import QuivrFile

# TODO: TIKA server should be set


@pytest.fixture
def pdf():
    return QuivrFile(
        id=uuid4(),
        brain_id=uuid4(),
        original_filename="dummy.pdf",
        path=Path("./tests/processor/data/dummy.pdf"),
        file_extension=".pdf",
    )


@pytest.mark.asyncio
async def test_process_file(pdf):
    tparser = TikaParser()
    doc = await tparser.process_file(pdf)
    assert len(doc) > 0
    assert doc[0].page_content.strip("\n") == "Dummy PDF download"


@pytest.mark.asyncio
async def test_send_parse_tika_exception(pdf):
    # TODO: Mock correct tika for retries
    tparser = TikaParser(tika_url="test.test")
    with pytest.raises(RuntimeError):
        doc = await tparser.process_file(pdf)
        assert len(doc) > 0
        assert doc[0].page_content.strip("\n") == "Dummy PDF download"

@@ -0,0 +1,48 @@
import pytest
from langchain_core.embeddings import DeterministicFakeEmbedding
from langchain_core.language_models import FakeListChatModel

from quivr_core.brain import Brain


@pytest.fixture
def temp_data_file(tmpdir):
    data = "This is some test data."
    temp_file = tmpdir.join("data.txt")
    temp_file.write(data)
    return temp_file


@pytest.fixture
def answers():
    return [f"answer_{i}" for i in range(10)]


@pytest.fixture(scope="function")
def llm(answers: list[str]):
    return FakeListChatModel(responses=answers)


@pytest.fixture(scope="function")
def embedder():
    return DeterministicFakeEmbedding(size=20)


def test_brain_from_files_exception():
    # Testing no files
    with pytest.raises(ValueError):
        Brain.from_files(name="test_brain", file_paths=[])


def test_brain_ask_txt(llm, embedder, temp_data_file, answers):
    brain = Brain.from_files(
        name="test_brain", file_paths=[temp_data_file], embedder=embedder, llm=llm
    )
    assert brain.llm == llm
    assert brain.vector_db.embeddings == embedder

    answer = brain.ask("question")
    assert answer.answer == answers[0]
    assert answer.metadata == answers[0]