feat: quivr core brain info + processors registry + (#2877)

# Description

- Created registry processor logic for automagically adding processors
to quivr_core based Entrypoints
- Added a langchain_community free `SimpleTxtParser` for the quivr_core
base package
- Added tests
- Added brain_info 
- Enriched parsed documents metadata based on quivr_file metadata

used Rich for `Brain.print_info()` to get a better output: 

![image](https://github.com/user-attachments/assets/dd9f2f03-d7d7-4be0-ba6c-3fe38e11c40f)
This commit is contained in:
AmineDiro 2024-07-19 09:47:39 +02:00 committed by GitHub
parent 3b68855a83
commit 3001fa1475
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
25 changed files with 845 additions and 155 deletions

View File

@ -1,19 +1,42 @@
from langchain_core.embeddings import DeterministicFakeEmbedding from langchain_core.embeddings import DeterministicFakeEmbedding
from langchain_core.language_models import FakeListChatModel from langchain_core.language_models import FakeListChatModel
from rich.console import Console
from rich.panel import Panel
from rich.prompt import Prompt
from quivr_core import Brain from quivr_core import Brain
from quivr_core.processor.default_parsers import DEFAULT_PARSERS from quivr_core.config import LLMEndpointConfig
from quivr_core.processor.pdf_processor import TikaParser from quivr_core.llm.llm_endpoint import LLMEndpoint
if __name__ == "__main__": if __name__ == "__main__":
pdf_paths = ["../tests/processor/data/dummy.pdf"]
brain = Brain.from_files( brain = Brain.from_files(
name="test_brain", name="test_brain",
file_paths=[], file_paths=["tests/processor/data/dummy.pdf"],
llm=FakeListChatModel(responses=["good"]), llm=LLMEndpoint(
llm=FakeListChatModel(responses=["good"]),
llm_config=LLMEndpointConfig(model="fake_model", llm_base_url="local"),
),
embedder=DeterministicFakeEmbedding(size=20), embedder=DeterministicFakeEmbedding(size=20),
processors_mapping={
**DEFAULT_PARSERS,
".pdf": TikaParser(),
},
) )
# Check brain info
brain.print_info()
console = Console()
console.print(Panel.fit("Ask your brain !", style="bold magenta"))
while True:
# Get user input
question = Prompt.ask("[bold cyan]Question[/bold cyan]")
# Check if user wants to exit
if question.lower() == "exit":
console.print(Panel("Goodbye!", style="bold yellow"))
break
answer = brain.ask(question)
# Print the answer with typing effect
console.print(f"[bold green]Quivr Assistant[/bold green]: {answer.answer}")
console.print("-" * console.width)
brain.print_info()

View File

@ -12,3 +12,5 @@ if __name__ == "__main__":
answer = brain.ask("Property of gold?") answer = brain.ask("Property of gold?")
print("answer :", answer.answer) print("answer :", answer.answer)
print("brain information: ", brain)

View File

@ -1,4 +1,4 @@
# This file is automatically @generated by Poetry 1.8.0 and should not be changed by hand. # This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand.
[[package]] [[package]]
name = "aiofiles" name = "aiofiles"
@ -1189,13 +1189,13 @@ tenacity = ">=8.1.0,<8.4.0 || >8.4.0,<9.0.0"
[[package]] [[package]]
name = "langchain-core" name = "langchain-core"
version = "0.2.19" version = "0.2.20"
description = "Building applications with LLMs through composability" description = "Building applications with LLMs through composability"
optional = false optional = false
python-versions = "<4.0,>=3.8.1" python-versions = "<4.0,>=3.8.1"
files = [ files = [
{file = "langchain_core-0.2.19-py3-none-any.whl", hash = "sha256:5b3cd34395be274c89e822c84f0e03c4da14168c177a83921c5b9414ac7a0651"}, {file = "langchain_core-0.2.20-py3-none-any.whl", hash = "sha256:16cc4da6f7ebf33accea7af45a70480733dc852ab291030fb6924865bd7caf76"},
{file = "langchain_core-0.2.19.tar.gz", hash = "sha256:13043a83e5c9ab58b9f5ce2a56896e7e88b752e8891b2958960a98e71801471e"}, {file = "langchain_core-0.2.20.tar.gz", hash = "sha256:a66c439e085d8c75f822f7650a5551d17bada4003521173c763d875d949e4ed5"},
] ]
[package.dependencies] [package.dependencies]
@ -1241,13 +1241,13 @@ langchain-core = ">=0.2.10,<0.3.0"
[[package]] [[package]]
name = "langsmith" name = "langsmith"
version = "0.1.85" version = "0.1.88"
description = "Client library to connect to the LangSmith LLM Tracing and Evaluation Platform." description = "Client library to connect to the LangSmith LLM Tracing and Evaluation Platform."
optional = false optional = false
python-versions = "<4.0,>=3.8.1" python-versions = "<4.0,>=3.8.1"
files = [ files = [
{file = "langsmith-0.1.85-py3-none-any.whl", hash = "sha256:c1f94384f10cea96f7b4d33fd3db7ec180c03c7468877d50846f881d2017ff94"}, {file = "langsmith-0.1.88-py3-none-any.whl", hash = "sha256:460ebb7de440afd150fcea8f54ca8779821f2228cd59e149e5845c9dbe06db16"},
{file = "langsmith-0.1.85.tar.gz", hash = "sha256:acff31f9e53efa48586cf8e32f65625a335c74d7c4fa306d1655ac18452296f6"}, {file = "langsmith-0.1.88.tar.gz", hash = "sha256:28a07dec19197f4808aa2628d5a3ccafcbe14cc137aef0e607bbd128e7907821"},
] ]
[package.dependencies] [package.dependencies]
@ -1258,6 +1258,30 @@ pydantic = [
] ]
requests = ">=2,<3" requests = ">=2,<3"
[[package]]
name = "markdown-it-py"
version = "3.0.0"
description = "Python port of markdown-it. Markdown parsing, done right!"
optional = false
python-versions = ">=3.8"
files = [
{file = "markdown-it-py-3.0.0.tar.gz", hash = "sha256:e3f60a94fa066dc52ec76661e37c851cb232d92f9886b15cb560aaada2df8feb"},
{file = "markdown_it_py-3.0.0-py3-none-any.whl", hash = "sha256:355216845c60bd96232cd8d8c40e8f9765cc86f46880e43a8fd22dc1a1a8cab1"},
]
[package.dependencies]
mdurl = ">=0.1,<1.0"
[package.extras]
benchmarking = ["psutil", "pytest", "pytest-benchmark"]
code-style = ["pre-commit (>=3.0,<4.0)"]
compare = ["commonmark (>=0.9,<1.0)", "markdown (>=3.4,<4.0)", "mistletoe (>=1.0,<2.0)", "mistune (>=2.0,<3.0)", "panflute (>=2.3,<3.0)"]
linkify = ["linkify-it-py (>=1,<3)"]
plugins = ["mdit-py-plugins"]
profiling = ["gprof2dot"]
rtd = ["jupyter_sphinx", "mdit-py-plugins", "myst-parser", "pyyaml", "sphinx", "sphinx-copybutton", "sphinx-design", "sphinx_book_theme"]
testing = ["coverage", "pytest", "pytest-cov", "pytest-regressions"]
[[package]] [[package]]
name = "marshmallow" name = "marshmallow"
version = "3.21.3" version = "3.21.3"
@ -1302,6 +1326,17 @@ files = [
{file = "mccabe-0.7.0.tar.gz", hash = "sha256:348e0240c33b60bbdf4e523192ef919f28cb2c3d7d5c7794f74009290f236325"}, {file = "mccabe-0.7.0.tar.gz", hash = "sha256:348e0240c33b60bbdf4e523192ef919f28cb2c3d7d5c7794f74009290f236325"},
] ]
[[package]]
name = "mdurl"
version = "0.1.2"
description = "Markdown URL utilities"
optional = false
python-versions = ">=3.7"
files = [
{file = "mdurl-0.1.2-py3-none-any.whl", hash = "sha256:84008a41e51615a49fc9966191ff91509e3c40b939176e643fd50a5c2196b8f8"},
{file = "mdurl-0.1.2.tar.gz", hash = "sha256:bb413d29f5eea38f31dd4754dd7377d4465116fb207585f97bf925588687c1ba"},
]
[[package]] [[package]]
name = "multidict" name = "multidict"
version = "6.0.5" version = "6.0.5"
@ -1527,13 +1562,13 @@ files = [
[[package]] [[package]]
name = "openai" name = "openai"
version = "1.35.13" version = "1.35.14"
description = "The official Python library for the openai API" description = "The official Python library for the openai API"
optional = true optional = true
python-versions = ">=3.7.1" python-versions = ">=3.7.1"
files = [ files = [
{file = "openai-1.35.13-py3-none-any.whl", hash = "sha256:36ec3e93e0d1f243f69be85c89b9221a471c3e450dfd9df16c9829e3cdf63e60"}, {file = "openai-1.35.14-py3-none-any.whl", hash = "sha256:adadf8c176e0b8c47ad782ed45dc20ef46438ee1f02c7103c4155cff79c8f68b"},
{file = "openai-1.35.13.tar.gz", hash = "sha256:c684f3945608baf7d2dcc0ef3ee6f3e27e4c66f21076df0b47be45d57e6ae6e4"}, {file = "openai-1.35.14.tar.gz", hash = "sha256:394ba1dfd12ecec1d634c50e512d24ff1858bbc2674ffcce309b822785a058de"},
] ]
[package.dependencies] [package.dependencies]
@ -2328,6 +2363,24 @@ urllib3 = ">=1.21.1,<3"
socks = ["PySocks (>=1.5.6,!=1.5.7)"] socks = ["PySocks (>=1.5.6,!=1.5.7)"]
use-chardet-on-py3 = ["chardet (>=3.0.2,<6)"] use-chardet-on-py3 = ["chardet (>=3.0.2,<6)"]
[[package]]
name = "rich"
version = "13.7.1"
description = "Render rich text, tables, progress bars, syntax highlighting, markdown and more to the terminal"
optional = false
python-versions = ">=3.7.0"
files = [
{file = "rich-13.7.1-py3-none-any.whl", hash = "sha256:4edbae314f59eb482f54e9e30bf00d33350aaa94f4bfcd4e9e3110e64d0d7222"},
{file = "rich-13.7.1.tar.gz", hash = "sha256:9be308cb1fe2f1f57d67ce99e95af38a1e2bc71ad9813b0e247cf7ffbcc3a432"},
]
[package.dependencies]
markdown-it-py = ">=2.2.0"
pygments = ">=2.13.0,<3.0.0"
[package.extras]
jupyter = ["ipywidgets (>=7.5.1,<9)"]
[[package]] [[package]]
name = "ruff" name = "ruff"
version = "0.4.10" version = "0.4.10"
@ -2788,4 +2841,4 @@ pdf = []
[metadata] [metadata]
lock-version = "2.0" lock-version = "2.0"
python-versions = "^3.11" python-versions = "^3.11"
content-hash = "f5ab0d3f93f3bd517382a8c4207a4a480f4a6dcb0aeaf60501c379a4a91a3ad0" content-hash = "cc08905f149df2f415e1d00010e5b89a371efdcf4059855d597b1b6e9973a536"

View File

@ -15,6 +15,7 @@ aiofiles = ">=23.0.0,<25.0.0"
faiss-cpu = { version = "^1.8.0.post1", optional = true } faiss-cpu = { version = "^1.8.0.post1", optional = true }
langchain-community = { version = "^0.2.6", optional = true } langchain-community = { version = "^0.2.6", optional = true }
langchain-openai = { version = "^0.1.14", optional = true } langchain-openai = { version = "^0.1.14", optional = true }
rich = "^13.7.1"
[tool.poetry.extras] [tool.poetry.extras]
base = ["langchain-community", "faiss-cpu", "langchain-openai"] base = ["langchain-community", "faiss-cpu", "langchain-openai"]
@ -81,6 +82,10 @@ known-first-party = []
[tool.pytest.ini_options] [tool.pytest.ini_options]
addopts = "--tb=short -ra -v" addopts = "--tb=short -ra -v"
filterwarnings = ["ignore::DeprecationWarning"] filterwarnings = ["ignore::DeprecationWarning"]
markers = [
"slow: marks tests as slow (deselect with '-m \"not slow\"')",
"base: these tests require quivr-core with extra `base` to be installed",
]
[build-system] [build-system]
requires = ["poetry-core"] requires = ["poetry-core"]

View File

@ -1,3 +1,35 @@
from .brain import Brain from importlib.metadata import entry_points
__all__ = ["Brain"] from .brain import Brain
from .processor.registry import register_processor, registry
__all__ = ["Brain", "registry", "register_processor"]
def register_entries():
if entry_points is not None:
try:
eps = entry_points()
except TypeError:
pass # importlib-metadata < 0.8
else:
if hasattr(eps, "select"): # Python 3.10+ / importlib_metadata >= 3.9.0
processors = eps.select(group="quivr_core.processor")
else:
processors = eps.get("quivr_core.processor", [])
registered_names = set()
for spec in processors:
err_msg = f"Unable to load processor from {spec}"
name = spec.name
if name in registered_names:
continue
registered_names.add(name)
register_processor(
name,
spec.value.replace(":", "."),
errtxt=err_msg,
override=True,
)
register_entries()

View File

@ -0,0 +1,3 @@
from .brain import Brain
__all__ = ["Brain"]

View File

@ -1,29 +1,34 @@
import asyncio import asyncio
import logging import logging
from pathlib import Path from pathlib import Path
from typing import Any, AsyncGenerator, Callable, Dict, Mapping, Self from pprint import PrettyPrinter
from typing import Any, AsyncGenerator, Callable, Dict, Self
from uuid import UUID, uuid4 from uuid import UUID, uuid4
from langchain_core.documents import Document from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings from langchain_core.embeddings import Embeddings
from langchain_core.messages import AIMessage, HumanMessage from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.vectorstores import VectorStore from langchain_core.vectorstores import VectorStore
from rich.console import Console
from rich.panel import Panel
from quivr_core.brain.info import BrainInfo, ChatHistoryInfo
from quivr_core.chat import ChatHistory from quivr_core.chat import ChatHistory
from quivr_core.config import LLMEndpointConfig, RAGConfig from quivr_core.config import LLMEndpointConfig, RAGConfig
from quivr_core.llm import LLMEndpoint from quivr_core.llm import LLMEndpoint
from quivr_core.models import ParsedRAGChunkResponse, ParsedRAGResponse, SearchResult from quivr_core.models import ParsedRAGChunkResponse, ParsedRAGResponse, SearchResult
from quivr_core.processor.default_parsers import DEFAULT_PARSERS from quivr_core.processor.registry import get_processor_class
from quivr_core.processor.processor_base import ProcessorBase
from quivr_core.quivr_rag import QuivrQARAG from quivr_core.quivr_rag import QuivrQARAG
from quivr_core.storage.file import QuivrFile from quivr_core.storage.file import load_qfile
from quivr_core.storage.local_storage import TransparentStorage from quivr_core.storage.local_storage import TransparentStorage
from quivr_core.storage.storage_base import StorageBase from quivr_core.storage.storage_base import StorageBase
logger = logging.getLogger("quivr_core") logger = logging.getLogger("quivr_core")
async def _default_vectordb(docs: list[Document], embedder: Embeddings) -> VectorStore: async def _build_default_vectordb(
docs: list[Document], embedder: Embeddings
) -> VectorStore:
try: try:
from langchain_community.vectorstores import FAISS from langchain_community.vectorstores import FAISS
@ -38,7 +43,7 @@ async def _default_vectordb(docs: list[Document], embedder: Embeddings) -> Vecto
except ImportError as e: except ImportError as e:
raise ImportError( raise ImportError(
"Please provide a valid vectore store or install quivr-core['base'] package for using the default one." "Please provide a valid vector store or install quivr-core['base'] package for using the default one."
) from e ) from e
@ -67,16 +72,15 @@ def _default_llm() -> LLMEndpoint:
) from e ) from e
async def _process_files( async def process_files(
storage: StorageBase, storage: StorageBase, skip_file_error: bool, **processor_kwargs: dict[str, Any]
skip_file_error: bool,
processors_mapping: Mapping[str, ProcessorBase],
) -> list[Document]: ) -> list[Document]:
knowledge = [] knowledge = []
for file in storage.get_files(): for file in await storage.get_files():
try: try:
if file.file_extension: if file.file_extension:
processor = processors_mapping[file.file_extension] processor_cls = get_processor_class(file.file_extension)
processor = processor_cls(**processor_kwargs)
docs = await processor.process_file(file) docs = await processor.process_file(file)
knowledge.extend(docs) knowledge.extend(docs)
else: else:
@ -118,11 +122,38 @@ class Brain:
self.vector_db = vector_db self.vector_db = vector_db
self.embedder = embedder self.embedder = embedder
def __repr__(self) -> str:
pp = PrettyPrinter(width=80, depth=None, compact=False, sort_dicts=False)
return pp.pformat(self.info())
def print_info(self):
console = Console()
tree = self.info().to_tree()
panel = Panel(tree, title="Brain Info", expand=False, border_style="bold")
console.print(panel)
def info(self) -> BrainInfo:
# TODO: dim of embedding
# "embedder": {},
chats_info = ChatHistoryInfo(
nb_chats=len(self._chats),
current_default_chat=self.default_chat.id,
current_chat_history_length=len(self.default_chat),
)
return BrainInfo(
brain_id=self.id,
brain_name=self.name,
files_info=self.storage.info(),
chats_info=chats_info,
llm_info=self.llm.info(),
)
@property @property
def chat_history(self): def chat_history(self) -> ChatHistory:
return self.default_chat return self.default_chat
def _init_chats(self): def _init_chats(self) -> Dict[UUID, ChatHistory]:
chat_id = uuid4() chat_id = uuid4()
default_chat = ChatHistory(chat_id=chat_id, brain_id=self.id) default_chat = ChatHistory(chat_id=chat_id, brain_id=self.id)
return {chat_id: default_chat} return {chat_id: default_chat}
@ -137,7 +168,6 @@ class Brain:
storage: StorageBase = TransparentStorage(), storage: StorageBase = TransparentStorage(),
llm: LLMEndpoint | None = None, llm: LLMEndpoint | None = None,
embedder: Embeddings | None = None, embedder: Embeddings | None = None,
processors_mapping: Mapping[str, ProcessorBase] = DEFAULT_PARSERS,
skip_file_error: bool = False, skip_file_error: bool = False,
): ):
if llm is None: if llm is None:
@ -148,20 +178,20 @@ class Brain:
brain_id = uuid4() brain_id = uuid4()
# TODO: run in parallel using tasks
for path in file_paths: for path in file_paths:
file = QuivrFile.from_path(brain_id, path) file = await load_qfile(brain_id, path)
storage.upload_file(file) await storage.upload_file(file)
# Parse files # Parse files
docs = await _process_files( docs = await process_files(
storage=storage, storage=storage,
processors_mapping=processors_mapping,
skip_file_error=skip_file_error, skip_file_error=skip_file_error,
) )
# Building brain's vectordb # Building brain's vectordb
if vector_db is None: if vector_db is None:
vector_db = await _default_vectordb(docs, embedder) vector_db = await _build_default_vectordb(docs, embedder)
else: else:
await vector_db.aadd_documents(docs) await vector_db.aadd_documents(docs)
@ -184,10 +214,10 @@ class Brain:
storage: StorageBase = TransparentStorage(), storage: StorageBase = TransparentStorage(),
llm: LLMEndpoint | None = None, llm: LLMEndpoint | None = None,
embedder: Embeddings | None = None, embedder: Embeddings | None = None,
processors_mapping: Mapping[str, ProcessorBase] = DEFAULT_PARSERS,
skip_file_error: bool = False, skip_file_error: bool = False,
) -> Self: ) -> Self:
return asyncio.run( loop = asyncio.get_event_loop()
return loop.run_until_complete(
cls.afrom_files( cls.afrom_files(
name=name, name=name,
file_paths=file_paths, file_paths=file_paths,
@ -195,7 +225,6 @@ class Brain:
storage=storage, storage=storage,
llm=llm, llm=llm,
embedder=embedder, embedder=embedder,
processors_mapping=processors_mapping,
skip_file_error=skip_file_error, skip_file_error=skip_file_error,
) )
) )
@ -221,7 +250,7 @@ class Brain:
# Building brain's vectordb # Building brain's vectordb
if vector_db is None: if vector_db is None:
vector_db = await _default_vectordb(langchain_documents, embedder) vector_db = await _build_default_vectordb(langchain_documents, embedder)
else: else:
await vector_db.aadd_documents(langchain_documents) await vector_db.aadd_documents(langchain_documents)

View File

@ -0,0 +1,74 @@
from dataclasses import dataclass
from uuid import UUID
from rich.tree import Tree
@dataclass
class ChatHistoryInfo:
nb_chats: int
current_default_chat: UUID
current_chat_history_length: int
def add_to_tree(self, chats_tree: Tree):
chats_tree.add(f"Number of Chats: [bold]{self.nb_chats}[/bold]")
chats_tree.add(
f"Current Default Chat: [bold magenta]{self.current_default_chat}[/bold magenta]"
)
chats_tree.add(
f"Current Chat History Length: [bold]{self.current_chat_history_length}[/bold]"
)
@dataclass
class LLMInfo:
model: str
llm_base_url: str
temperature: float
max_tokens: int
supports_function_calling: int
def add_to_tree(self, llm_tree: Tree):
llm_tree.add(f"Model: [italic]{self.model}[/italic]")
llm_tree.add(f"Base URL: [underline]{self.llm_base_url}[/underline]")
llm_tree.add(f"Temperature: [bold]{self.temperature}[/bold]")
llm_tree.add(f"Max Tokens: [bold]{self.max_tokens}[/bold]")
func_call_color = "green" if self.supports_function_calling else "red"
llm_tree.add(
f"Supports Function Calling: [bold {func_call_color}]{self.supports_function_calling}[/bold {func_call_color}]"
)
@dataclass
class StorageInfo:
storage_type: str
n_files: int
def add_to_tree(self, files_tree: Tree):
files_tree.add(f"Storage Type: [italic]{self.storage_type}[/italic]")
files_tree.add(f"Number of Files: [bold]{self.n_files}[/bold]")
@dataclass
class BrainInfo:
brain_id: UUID
brain_name: str
files_info: StorageInfo
chats_info: ChatHistoryInfo
llm_info: LLMInfo
def to_tree(self):
tree = Tree("📊 Brain Information")
tree.add(f"🆔 ID: [bold cyan]{self.brain_id}[/bold cyan]")
tree.add(f"🧠 Brain Name: [bold green]{self.brain_name}[/bold green]")
files_tree = tree.add("📁 Files")
self.files_info.add_to_tree(files_tree)
chats_tree = tree.add("💬 Chats")
self.chats_info.add_to_tree(chats_tree)
llm_tree = tree.add("🤖 LLM")
self.llm_info.add_to_tree(llm_tree)
return tree

View File

@ -1,6 +1,7 @@
from langchain_core.language_models.chat_models import BaseChatModel from langchain_core.language_models.chat_models import BaseChatModel
from pydantic.v1 import SecretStr from pydantic.v1 import SecretStr
from quivr_core.brain.info import LLMInfo
from quivr_core.config import LLMEndpointConfig from quivr_core.config import LLMEndpointConfig
from quivr_core.utils import model_supports_function_calling from quivr_core.utils import model_supports_function_calling
@ -35,3 +36,14 @@ class LLMEndpoint:
def supports_func_calling(self) -> bool: def supports_func_calling(self) -> bool:
return self._supports_func_calling return self._supports_func_calling
def info(self) -> LLMInfo:
return LLMInfo(
model=self._config.model,
llm_base_url=(
self._config.llm_base_url if self._config.llm_base_url else "openai"
),
temperature=self._config.temperature,
max_tokens=self._config.max_tokens,
supports_function_calling=self.supports_func_calling(),
)

View File

@ -1,6 +0,0 @@
from quivr_core.processor.processor_base import ProcessorBase
from quivr_core.processor.txt_parser import TxtProcessor
DEFAULT_PARSERS: dict[str, ProcessorBase] = {
".txt": TxtProcessor(),
}

View File

@ -1,27 +1,20 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Generic, TypeVar
from langchain_core.documents import Document from langchain_core.documents import Document
from quivr_core.storage.file import QuivrFile from quivr_core.storage.file import FileExtension, QuivrFile
# TODO: processors should be cached somewhere ?
# The processor should be cached by processor type
# The cache should use a single
class ProcessorBase(ABC): class ProcessorBase(ABC):
supported_extensions: list[str] supported_extensions: list[FileExtension | str]
@abstractmethod @abstractmethod
async def process_file(self, file: QuivrFile) -> list[Document]: async def process_file(self, file: QuivrFile) -> list[Document]:
pass raise NotImplementedError
def check_supported(self, file: QuivrFile):
P = TypeVar("P", bound=ProcessorBase) if file.file_extension not in self.supported_extensions:
raise ValueError(f"can't process a file of type {file.file_extension}")
class ProcessorsMapping(Generic[P]):
def __init__(self, mapping: dict[str, P]) -> None:
# Create an empty list with items of type T
self.ext_parser: dict[str, P] = mapping
def add_parser(self, extension: str, parser: P):
# TODO: deal with existing ext keys
self.ext_parser[extension] = parser

View File

@ -0,0 +1,107 @@
import importlib
import types
from typing import Type, TypedDict
from quivr_core.storage.file import FileExtension
from .processor_base import ProcessorBase
_registry: dict[str, Type[ProcessorBase]] = {}
# external, read only
registry = types.MappingProxyType(_registry)
class ProcEntry(TypedDict):
cls_mod: str
err: str | None
# Register based on mimetypes
known_processors: dict[FileExtension | str, ProcEntry] = {
FileExtension.txt: ProcEntry(
cls_mod="quivr_core.processor.simple_txt_processor.SimpleTxtProcessor",
err="Please install quivr_core[base] to use TikTokenTxtProcessor ",
),
FileExtension.pdf: ProcEntry(
cls_mod="quivr_core.processor.tika_processor.TikaProcessor",
err=None,
),
}
def get_processor_class(file_extension: FileExtension | str) -> Type[ProcessorBase]:
"""Fetch processor class from registry
The dict ``known_processors`` maps file extensions to the locations
of processors that could process them.
Loading of these classes is *Lazy*. Appropriate import will happen
the first time we try to process some file type.
Some processors need additional dependencies. If the import fails
we return the "err" field of the ProcEntry in ``known_processors``.
"""
if file_extension not in registry:
if file_extension not in known_processors:
raise ValueError(f"Extension not known: {file_extension}")
entry = known_processors[file_extension]
try:
register_processor(file_extension, _import_class(entry["cls_mod"]))
except ImportError as e:
raise ImportError(entry["err"]) from e
cls = registry[file_extension]
return cls
def register_processor(
file_type: FileExtension | str,
proc_cls: str | Type[ProcessorBase],
override: bool = False,
errtxt=None,
):
if isinstance(proc_cls, str):
if file_type in known_processors and override is False:
if proc_cls != known_processors[file_type]["cls_mod"]:
raise ValueError(
f"Processor for ({file_type}) already in the registry and override is False"
)
else:
known_processors[file_type] = ProcEntry(
cls_mod=proc_cls,
err=errtxt or f"{proc_cls} import failed for processor of {file_type}",
)
else:
if file_type in registry and override is False:
if _registry[file_type] is not proc_cls:
raise ValueError(
f"Processor for ({file_type}) already in the registry and override is False"
)
else:
_registry[file_type] = proc_cls
def _import_class(full_mod_path: str):
if ":" in full_mod_path:
mod_name, name = full_mod_path.rsplit(":", 1)
else:
mod_name, name = full_mod_path.rsplit(".", 1)
mod = importlib.import_module(mod_name)
for cls in name.split("."):
mod = getattr(mod, cls)
if not isinstance(mod, type):
raise TypeError(f"{full_mod_path} is not a class")
if not issubclass(mod, ProcessorBase):
raise TypeError(f"{full_mod_path} is not a subclass of ProcessorBase ")
return mod
def available_processors():
"""Return a list of the known processors."""
return list(known_processors)

View File

@ -0,0 +1,62 @@
from importlib.metadata import version
from uuid import uuid4
import aiofiles
from langchain_core.documents import Document
from quivr_core.processor.processor_base import ProcessorBase
from quivr_core.processor.registry import FileExtension
from quivr_core.processor.splitter import SplitterConfig
from quivr_core.storage.file import QuivrFile
def recursive_character_splitter(
doc: Document, chunk_size: int, chunk_overlap: int
) -> list[Document]:
assert chunk_overlap < chunk_size, "chunk_overlap is greater than chunk_size"
if len(doc.page_content) <= chunk_size:
return [doc]
chunk = Document(page_content=doc.page_content[:chunk_size], metadata=doc.metadata)
remaining = Document(
page_content=doc.page_content[chunk_size - chunk_overlap :],
metadata=doc.metadata,
)
return [chunk] + recursive_character_splitter(remaining, chunk_size, chunk_overlap)
class SimpleTxtProcessor(ProcessorBase):
supported_extensions = [FileExtension.txt]
def __init__(
self, splitter_config: SplitterConfig = SplitterConfig(), **kwargs
) -> None:
super().__init__(**kwargs)
self.splitter_config = splitter_config
async def process_file(self, file: QuivrFile) -> list[Document]:
self.check_supported(file)
file_metadata = file.metadata
async with aiofiles.open(file.path, mode="r") as f:
content = await f.read()
doc = Document(
page_content=content,
metadata={
"id": uuid4(),
"chunk_size": len(content),
"chunk_overlap": self.splitter_config.chunk_overlap,
"parser_name": self.__class__.__name__,
"quivr_core_version": version("quivr-core"),
**file_metadata,
},
)
docs = recursive_character_splitter(
doc, self.splitter_config.chunk_size, self.splitter_config.chunk_overlap
)
return docs

View File

@ -1,4 +1,5 @@
import logging import logging
from importlib.metadata import version
from typing import AsyncIterable from typing import AsyncIterable
import httpx import httpx
@ -6,14 +7,15 @@ from langchain_core.documents import Document
from langchain_text_splitters import RecursiveCharacterTextSplitter, TextSplitter from langchain_text_splitters import RecursiveCharacterTextSplitter, TextSplitter
from quivr_core.processor.processor_base import ProcessorBase from quivr_core.processor.processor_base import ProcessorBase
from quivr_core.processor.registry import FileExtension
from quivr_core.processor.splitter import SplitterConfig from quivr_core.processor.splitter import SplitterConfig
from quivr_core.storage.file import QuivrFile from quivr_core.storage.file import QuivrFile
logger = logging.getLogger("quivr_core") logger = logging.getLogger("quivr_core")
class TikaParser(ProcessorBase): class TikaProcessor(ProcessorBase):
supported_extensions = [".pdf"] supported_extensions = [FileExtension.pdf]
def __init__( def __init__(
self, self,
@ -51,12 +53,22 @@ class TikaParser(ProcessorBase):
raise RuntimeError("can't send parse request to tika server") raise RuntimeError("can't send parse request to tika server")
async def process_file(self, file: QuivrFile) -> list[Document]: async def process_file(self, file: QuivrFile) -> list[Document]:
assert file.file_extension in self.supported_extensions self.check_supported(file)
async with file.open() as f: async with file.open() as f:
txt = await self._send_parse_tika(f) txt = await self._send_parse_tika(f)
document = Document(page_content=txt) document = Document(page_content=txt)
# Use the default splitter # Use the default splitter
docs = self.text_splitter.split_documents([document]) docs = self.text_splitter.split_documents([document])
file_metadata = file.metadata
for doc in docs:
doc.metadata = {
"chunk_size": len(doc.page_content),
"chunk_overlap": self.splitter_config.chunk_overlap,
"parser_name": self.__class__.__name__,
"quivr_core_version": version("quivr-core"),
**file_metadata,
}
return docs return docs

View File

@ -1,19 +1,24 @@
from importlib.metadata import version
from uuid import uuid4
from langchain_community.document_loaders.text import TextLoader from langchain_community.document_loaders.text import TextLoader
from langchain_core.documents import Document from langchain_core.documents import Document
from langchain_text_splitters import RecursiveCharacterTextSplitter, TextSplitter from langchain_text_splitters import RecursiveCharacterTextSplitter, TextSplitter
from quivr_core.processor.processor_base import ProcessorBase from quivr_core.processor.processor_base import ProcessorBase
from quivr_core.processor.registry import FileExtension
from quivr_core.processor.splitter import SplitterConfig from quivr_core.processor.splitter import SplitterConfig
from quivr_core.storage.file import QuivrFile from quivr_core.storage.file import QuivrFile
class TxtProcessor(ProcessorBase): class TikTokenTxtProcessor(ProcessorBase):
supported_extensions = [FileExtension.txt]
def __init__( def __init__(
self, self,
splitter: TextSplitter | None = None, splitter: TextSplitter | None = None,
splitter_config: SplitterConfig = SplitterConfig(), splitter_config: SplitterConfig = SplitterConfig(),
) -> None: ) -> None:
self.supported_extensions = [".txt"]
self.loader_cls = TextLoader self.loader_cls = TextLoader
self.splitter_config = splitter_config self.splitter_config = splitter_config
@ -27,10 +32,22 @@ class TxtProcessor(ProcessorBase):
) )
async def process_file(self, file: QuivrFile) -> list[Document]: async def process_file(self, file: QuivrFile) -> list[Document]:
if file.file_extension not in self.supported_extensions: self.check_supported(file)
raise Exception(f"can't process a file of type {file.file_extension}")
loader = self.loader_cls(file.path) loader = self.loader_cls(file.path)
documents = await loader.aload() documents = await loader.aload()
docs = self.text_splitter.split_documents(documents) docs = self.text_splitter.split_documents(documents)
file_metadata = file.metadata
for doc in docs:
doc.metadata = {
"id": uuid4(),
"chunk_size": len(doc.page_content),
"chunk_overlap": self.splitter_config.chunk_overlap,
"parser_name": self.__class__.__name__,
"quivr_core_version": version("quivr-core"),
**file_metadata,
}
return docs return docs

View File

@ -1,21 +1,87 @@
import hashlib
import mimetypes
import os import os
import warnings
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from enum import Enum
from pathlib import Path from pathlib import Path
from typing import AsyncGenerator, AsyncIterable from typing import Any, AsyncGenerator, AsyncIterable
from uuid import UUID, uuid4 from uuid import UUID, uuid4
import aiofiles import aiofiles
class FileExtension(str, Enum):
txt = ".txt"
pdf = ".pdf"
docx = ".docx"
def get_file_extension(file_path: Path) -> FileExtension | str:
try:
mime_type, _ = mimetypes.guess_type(file_path.name)
if mime_type:
mime_ext = mimetypes.guess_extension(mime_type)
if mime_ext:
return FileExtension(mime_ext)
return FileExtension(file_path.suffix)
except ValueError:
warnings.warn(
f"File {file_path.name} extension isn't recognized. Make sure you have registered a parser for {file_path.suffix}",
stacklevel=2,
)
return file_path.suffix
async def load_qfile(brain_id: UUID, path: str | Path):
if not isinstance(path, Path):
path = Path(path)
if not path.exists():
raise FileExistsError(f"file {path} doesn't exist")
file_size = os.stat(path).st_size
async with aiofiles.open(path, mode="rb") as f:
file_md5 = hashlib.md5(await f.read()).hexdigest()
try:
# NOTE: when loading from existing storage, file name will be uuid
id = UUID(path.name)
except ValueError:
id = uuid4()
return QuivrFile(
id=id,
brain_id=brain_id,
path=path,
original_filename=path.name,
file_extension=get_file_extension(path),
file_size=file_size,
file_md5=file_md5,
)
class QuivrFile: class QuivrFile:
__slots__ = [
"id",
"brain_id",
"path",
"original_filename",
"file_size",
"file_extension",
"file_md5",
]
def __init__( def __init__(
self, self,
id: UUID, id: UUID,
original_filename: str, original_filename: str,
path: Path, path: Path,
brain_id: UUID, brain_id: UUID,
file_md5: str,
file_extension: FileExtension | str,
file_size: int | None = None, file_size: int | None = None,
file_extension: str | None = None,
) -> None: ) -> None:
self.id = id self.id = id
self.brain_id = brain_id self.brain_id = brain_id
@ -23,31 +89,7 @@ class QuivrFile:
self.original_filename = original_filename self.original_filename = original_filename
self.file_size = file_size self.file_size = file_size
self.file_extension = file_extension self.file_extension = file_extension
self.file_md5 = file_md5
@classmethod
def from_path(cls, brain_id: UUID, path: str | Path):
if not isinstance(path, Path):
path = Path(path)
if not path.exists():
raise FileExistsError(f"file {path} doesn't exist")
file_size = os.stat(path).st_size
try:
# NOTE: when loading from existing storage, file name will be uuid
id = UUID(path.name)
except ValueError:
id = uuid4()
return cls(
id=id,
brain_id=brain_id,
path=path,
original_filename=path.name,
file_size=file_size,
file_extension=path.suffix,
)
@asynccontextmanager @asynccontextmanager
async def open(self) -> AsyncGenerator[AsyncIterable[bytes], None]: async def open(self) -> AsyncGenerator[AsyncIterable[bytes], None]:
@ -57,3 +99,13 @@ class QuivrFile:
yield f yield f
finally: finally:
await f.close() await f.close()
@property
def metadata(self) -> dict[str, Any]:
return {
"qfile_id": self.id,
"qfile_path": self.path,
"original_file_name": self.original_filename,
"file_md4": self.file_md5,
"file_size": self.file_size,
}

View File

@ -1,6 +1,7 @@
import os import os
import shutil import shutil
from pathlib import Path from pathlib import Path
from typing import Set
from uuid import UUID from uuid import UUID
from quivr_core.storage.file import QuivrFile from quivr_core.storage.file import QuivrFile
@ -8,8 +9,11 @@ from quivr_core.storage.storage_base import StorageBase
class LocalStorage(StorageBase): class LocalStorage(StorageBase):
name: str = "local_storage"
def __init__(self, dir_path: Path | None = None, copy_flag: bool = True): def __init__(self, dir_path: Path | None = None, copy_flag: bool = True):
self.files: list[QuivrFile] = [] self.files: list[QuivrFile] = []
self.hashes: Set[str] = set()
self.copy_flag = copy_flag self.copy_flag = copy_flag
if dir_path is None: if dir_path is None:
@ -24,14 +28,19 @@ class LocalStorage(StorageBase):
# TODO(@aminediro): load existing files # TODO(@aminediro): load existing files
pass pass
def upload_file(self, file: QuivrFile, exists_ok: bool = False) -> None: def nb_files(self) -> int:
return len(self.files)
def info(self):
return {"directory_path": self.dir_path, **super().info()}
async def upload_file(self, file: QuivrFile, exists_ok: bool = False) -> None:
dst_path = os.path.join( dst_path = os.path.join(
self.dir_path, str(file.brain_id), f"{file.id}{file.file_extension}" self.dir_path, str(file.brain_id), f"{file.id}{file.file_extension}"
) )
# TODO(@aminediro): Check hash of file not file path if file.file_md5 in self.hashes and not exists_ok:
if os.path.exists(dst_path) and not exists_ok: raise FileExistsError(f"file {file.original_filename} already uploaded")
raise FileExistsError("file already exists")
if self.copy_flag: if self.copy_flag:
shutil.copy2(file.path, dst_path) shutil.copy2(file.path, dst_path)
@ -40,28 +49,31 @@ class LocalStorage(StorageBase):
file.path = Path(dst_path) file.path = Path(dst_path)
self.files.append(file) self.files.append(file)
self.hashes.add(file.file_md5)
def get_files(self) -> list[QuivrFile]: async def get_files(self) -> list[QuivrFile]:
return self.files return self.files
def remove_file(self, file_id: UUID) -> None: async def remove_file(self, file_id: UUID) -> None:
raise NotImplementedError raise NotImplementedError
class TransparentStorage(StorageBase): class TransparentStorage(StorageBase):
"""Transparent Storage. """Transparent Storage."""
uses default
""" name: str = "transparent_storage"
def __init__(self): def __init__(self):
self.files = [] self.id_files = {}
def upload_file(self, file: QuivrFile, exists_ok: bool = False) -> None: async def upload_file(self, file: QuivrFile, exists_ok: bool = False) -> None:
self.files.append(file) self.id_files[file.id] = file
def remove_file(self, file_id: UUID) -> None: def nb_files(self) -> int:
return len(self.id_files)
async def remove_file(self, file_id: UUID) -> None:
raise NotImplementedError raise NotImplementedError
def get_files(self) -> list[QuivrFile]: async def get_files(self) -> list[QuivrFile]:
return self.files return list(self.id_files.values())

View File

@ -1,18 +1,39 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from uuid import UUID from uuid import UUID
from quivr_core.brain.info import StorageInfo
from quivr_core.storage.local_storage import QuivrFile from quivr_core.storage.local_storage import QuivrFile
class StorageBase(ABC): class StorageBase(ABC):
name: str
def __init_subclass__(cls, **kwargs):
for required in ("name",):
if not getattr(cls, required):
raise TypeError(
f"Can't instantiate abstract class {cls.__name__} without {required} attribute defined"
)
return super().__init_subclass__(**kwargs)
@abstractmethod @abstractmethod
def get_files(self) -> list[QuivrFile]: def nb_files(self) -> int:
raise Exception("Unimplemented nb_files method")
@abstractmethod
async def get_files(self) -> list[QuivrFile]:
raise Exception("Unimplemented get_files method") raise Exception("Unimplemented get_files method")
@abstractmethod @abstractmethod
def upload_file(self, file: QuivrFile, exists_ok: bool = False) -> None: async def upload_file(self, file: QuivrFile, exists_ok: bool = False) -> None:
raise Exception("Unimplemented upload_file method") raise Exception("Unimplemented upload_file method")
@abstractmethod @abstractmethod
def remove_file(self, file_id: UUID) -> None: async def remove_file(self, file_id: UUID) -> None:
raise Exception("Unimplemented remove_file method") raise Exception("Unimplemented remove_file method")
def info(self) -> StorageInfo:
return StorageInfo(
storage_type=self.name,
n_files=self.nb_files(),
)

View File

@ -1,5 +1,7 @@
import json import json
import os import os
from pathlib import Path
from uuid import uuid4
import pytest import pytest
from langchain_core.embeddings import DeterministicFakeEmbedding from langchain_core.embeddings import DeterministicFakeEmbedding
@ -10,6 +12,39 @@ from langchain_core.vectorstores import InMemoryVectorStore
from quivr_core.config import LLMEndpointConfig from quivr_core.config import LLMEndpointConfig
from quivr_core.llm import LLMEndpoint from quivr_core.llm import LLMEndpoint
from quivr_core.storage.file import FileExtension, QuivrFile
@pytest.fixture(scope="function")
def temp_data_file(tmp_path):
data = "This is some test data."
temp_file = tmp_path / "data.txt"
temp_file.write_text(data)
return temp_file
@pytest.fixture(scope="function")
def quivr_txt(temp_data_file):
return QuivrFile(
id=uuid4(),
brain_id=uuid4(),
original_filename=temp_data_file.name,
path=temp_data_file,
file_extension=FileExtension.txt,
file_md5="123",
)
@pytest.fixture
def quivr_pdf():
return QuivrFile(
id=uuid4(),
brain_id=uuid4(),
original_filename="dummy.pdf",
path=Path("./tests/processor/data/dummy.pdf"),
file_extension=FileExtension.pdf,
file_md5="13bh234jh234",
)
@pytest.fixture @pytest.fixture
@ -36,14 +71,6 @@ def openai_api_key():
os.environ["OPENAI_API_KEY"] = "abcd" os.environ["OPENAI_API_KEY"] = "abcd"
@pytest.fixture(scope="function")
def temp_data_file(tmp_path):
data = "This is some test data."
temp_file = tmp_path / "data.txt"
temp_file.write_text(data)
return temp_file
@pytest.fixture @pytest.fixture
def answers(): def answers():
return [f"answer_{i}" for i in range(10)] return [f"answer_{i}" for i in range(10)]

View File

@ -0,0 +1,67 @@
import pytest
from langchain_core.documents import Document
from quivr_core import registry
from quivr_core.processor.processor_base import ProcessorBase
from quivr_core.processor.registry import (
_import_class,
get_processor_class,
register_processor,
)
from quivr_core.processor.simple_txt_processor import SimpleTxtProcessor
from quivr_core.processor.tika_processor import TikaProcessor
from quivr_core.storage.file import FileExtension, QuivrFile
def test_get_processor_cls():
cls = get_processor_class(FileExtension.txt)
assert cls == SimpleTxtProcessor
cls = get_processor_class(FileExtension.pdf)
assert cls == TikaProcessor
def test__import_class():
mod_path = "quivr_core.processor.tika_processor.TikaProcessor"
mod = _import_class(mod_path)
assert mod == TikaProcessor
with pytest.raises(TypeError, match=r".* is not a class"):
mod_path = "quivr_core.processor"
_import_class(mod_path)
with pytest.raises(TypeError, match=r".* ProcessorBase"):
mod_path = "quivr_core.Brain"
_import_class(mod_path)
def test_get_processor_cls_error():
with pytest.raises(ValueError):
get_processor_class(".docx")
def test_register_new_proc():
nprocs = len(registry)
class TestProcessor(ProcessorBase):
supported_extensions = [".test"]
async def process_file(self, file: QuivrFile) -> list[Document]:
return []
register_processor(".test", TestProcessor)
assert len(registry) == nprocs + 1
cls = get_processor_class(".test")
assert cls == TestProcessor
def test_register_override_proc():
class TestProcessor(ProcessorBase):
supported_extensions = [".pdf"]
async def process_file(self, file: QuivrFile) -> list[Document]:
return []
register_processor(".pdf", TestProcessor, override=True)
cls = get_processor_class(FileExtension.pdf)
assert cls == TestProcessor

View File

@ -0,0 +1,34 @@
import pytest
from langchain_core.documents import Document
from quivr_core.processor.simple_txt_processor import (
SimpleTxtProcessor,
recursive_character_splitter,
)
from quivr_core.processor.splitter import SplitterConfig
from quivr_core.storage.file import FileExtension
def test_recursive_character_splitter():
doc = Document(page_content="abcdefgh", metadata={"key": "value"})
docs = recursive_character_splitter(doc, chunk_size=2, chunk_overlap=1)
assert [d.page_content for d in docs] == ["ab", "bc", "cd", "de", "ef", "fg", "gh"]
assert [d.metadata for d in docs] == [doc.metadata] * len(docs)
@pytest.mark.asyncio
async def test_simple_processor(quivr_pdf, quivr_txt):
proc = SimpleTxtProcessor(
splitter_config=SplitterConfig(chunk_size=100, chunk_overlap=20)
)
assert proc.supported_extensions == [FileExtension.txt]
with pytest.raises(ValueError):
await proc.process_file(quivr_pdf)
docs = await proc.process_file(quivr_txt)
assert len(docs) == 1
assert docs[0].page_content == "This is some test data."

View File

@ -1,38 +1,23 @@
from pathlib import Path
from uuid import uuid4
import pytest import pytest
from quivr_core.processor.pdf_processor import TikaParser from quivr_core.processor.tika_processor import TikaProcessor
from quivr_core.storage.file import QuivrFile
# TODO: TIKA server should be set # TODO: TIKA server should be set
@pytest.fixture
def pdf():
return QuivrFile(
id=uuid4(),
brain_id=uuid4(),
original_filename="dummy.pdf",
path=Path("./tests/processor/data/dummy.pdf"),
file_extension=".pdf",
)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_process_file(pdf): async def test_process_file(quivr_pdf):
tparser = TikaParser() tparser = TikaProcessor()
doc = await tparser.process_file(pdf) doc = await tparser.process_file(quivr_pdf)
assert len(doc) > 0 assert len(doc) > 0
assert doc[0].page_content.strip("\n") == "Dummy PDF download" assert doc[0].page_content.strip("\n") == "Dummy PDF download"
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_send_parse_tika_exception(pdf): async def test_send_parse_tika_exception(quivr_pdf):
# TODO: Mock correct tika for retries # TODO: Mock correct tika for retries
tparser = TikaParser(tika_url="test.test") tparser = TikaProcessor(tika_url="test.test")
with pytest.raises(RuntimeError): with pytest.raises(RuntimeError):
doc = await tparser.process_file(pdf) doc = await tparser.process_file(quivr_pdf)
assert len(doc) > 0 assert len(doc) > 0
assert doc[0].page_content.strip("\n") == "Dummy PDF download" assert doc[0].page_content.strip("\n") == "Dummy PDF download"

View File

@ -0,0 +1,45 @@
from importlib.metadata import version
from uuid import uuid4
import pytest
from quivr_core.processor.splitter import SplitterConfig
from quivr_core.processor.txt_processor import TikTokenTxtProcessor
from quivr_core.storage.file import FileExtension, QuivrFile
# TODO: TIKA server should be set
@pytest.fixture
def txt_qfile(temp_data_file):
return QuivrFile(
id=uuid4(),
brain_id=uuid4(),
original_filename="data.txt",
path=temp_data_file,
file_extension=FileExtension.txt,
file_md5="hash",
)
@pytest.mark.base
@pytest.mark.asyncio
async def test_process_txt(txt_qfile):
tparser = TikTokenTxtProcessor(
splitter_config=SplitterConfig(chunk_size=20, chunk_overlap=0)
)
doc = await tparser.process_file(txt_qfile)
assert len(doc) > 0
assert doc[0].page_content == "This is some test data."
# assert dict1.items() <= dict2.items()
assert (
doc[0].metadata.items()
>= {
"chunk_size": len(doc[0].page_content),
"chunk_overlap": 0,
"parser_name": tparser.__class__.__name__,
"quivr_core_version": version("quivr-core"),
**txt_qfile.metadata,
}.items()
)

View File

@ -1,3 +1,4 @@
from dataclasses import asdict
from uuid import uuid4 from uuid import uuid4
import pytest import pytest
@ -16,8 +17,11 @@ def test_brain_empty_files():
Brain.from_files(name="test_brain", file_paths=[]) Brain.from_files(name="test_brain", file_paths=[])
def test_brain_from_files_success(fake_llm: LLMEndpoint, embedder, temp_data_file): @pytest.mark.asyncio
brain = Brain.from_files( async def test_brain_from_files_success(
fake_llm: LLMEndpoint, embedder, temp_data_file
):
brain = await Brain.afrom_files(
name="test_brain", file_paths=[temp_data_file], embedder=embedder, llm=fake_llm name="test_brain", file_paths=[temp_data_file], embedder=embedder, llm=fake_llm
) )
assert brain.name == "test_brain" assert brain.name == "test_brain"
@ -29,7 +33,7 @@ def test_brain_from_files_success(fake_llm: LLMEndpoint, embedder, temp_data_fil
# storage # storage
assert isinstance(brain.storage, TransparentStorage) assert isinstance(brain.storage, TransparentStorage)
assert len(brain.storage.get_files()) == 1 assert len(await brain.storage.get_files()) == 1
@pytest.mark.asyncio @pytest.mark.asyncio
@ -39,7 +43,7 @@ async def test_brain_from_langchain_docs(embedder):
name="test", langchain_documents=[chunk], embedder=embedder name="test", langchain_documents=[chunk], embedder=embedder
) )
# No appended files # No appended files
assert len(brain.storage.get_files()) == 0 assert len(await brain.storage.get_files()) == 0
assert len(brain.chat_history) == 0 assert len(brain.chat_history) == 0
@ -90,3 +94,28 @@ async def test_brain_ask_streaming(
response += chunk.answer response += chunk.answer
assert response == answers[1] assert response == answers[1]
def test_brain_info_empty(fake_llm: LLMEndpoint, embedder, mem_vector_store):
storage = TransparentStorage()
id = uuid4()
brain = Brain(
name="test",
id=id,
llm=fake_llm,
embedder=embedder,
storage=storage,
vector_db=mem_vector_store,
)
assert asdict(brain.info()) == {
"brain_id": id,
"brain_name": "test",
"files_info": asdict(storage.info()),
"chats_info": {
"nb_chats": 1, # start with a default chat
"current_default_chat": brain.default_chat.id,
"current_chat_history_length": 0,
},
"llm_info": asdict(fake_llm.info()),
}

View File