mirror of
https://github.com/StanGirard/quivr.git
synced 2024-11-26 03:15:19 +03:00
feat: quivr core brain info + processors registry + (#2877)
# Description - Created processor registry logic for automatically adding processors to quivr_core based on entry points - Added a langchain_community-free `SimpleTxtProcessor` for the quivr_core base package - Added tests - Added brain_info - Enriched parsed documents' metadata based on quivr_file metadata - Used Rich for `Brain.print_info()` to get a better output: ![image](https://github.com/user-attachments/assets/dd9f2f03-d7d7-4be0-ba6c-3fe38e11c40f)
This commit is contained in:
parent
3b68855a83
commit
3001fa1475
@ -1,19 +1,42 @@
|
||||
from langchain_core.embeddings import DeterministicFakeEmbedding
|
||||
from langchain_core.language_models import FakeListChatModel
|
||||
from rich.console import Console
|
||||
from rich.panel import Panel
|
||||
from rich.prompt import Prompt
|
||||
|
||||
from quivr_core import Brain
|
||||
from quivr_core.processor.default_parsers import DEFAULT_PARSERS
|
||||
from quivr_core.processor.pdf_processor import TikaParser
|
||||
from quivr_core.config import LLMEndpointConfig
|
||||
from quivr_core.llm.llm_endpoint import LLMEndpoint
|
||||
|
||||
if __name__ == "__main__":
|
||||
pdf_paths = ["../tests/processor/data/dummy.pdf"]
|
||||
brain = Brain.from_files(
|
||||
name="test_brain",
|
||||
file_paths=[],
|
||||
llm=FakeListChatModel(responses=["good"]),
|
||||
file_paths=["tests/processor/data/dummy.pdf"],
|
||||
llm=LLMEndpoint(
|
||||
llm=FakeListChatModel(responses=["good"]),
|
||||
llm_config=LLMEndpointConfig(model="fake_model", llm_base_url="local"),
|
||||
),
|
||||
embedder=DeterministicFakeEmbedding(size=20),
|
||||
processors_mapping={
|
||||
**DEFAULT_PARSERS,
|
||||
".pdf": TikaParser(),
|
||||
},
|
||||
)
|
||||
# Check brain info
|
||||
brain.print_info()
|
||||
|
||||
console = Console()
|
||||
console.print(Panel.fit("Ask your brain !", style="bold magenta"))
|
||||
|
||||
while True:
|
||||
# Get user input
|
||||
question = Prompt.ask("[bold cyan]Question[/bold cyan]")
|
||||
|
||||
# Check if user wants to exit
|
||||
if question.lower() == "exit":
|
||||
console.print(Panel("Goodbye!", style="bold yellow"))
|
||||
break
|
||||
|
||||
answer = brain.ask(question)
|
||||
# Print the answer with typing effect
|
||||
console.print(f"[bold green]Quivr Assistant[/bold green]: {answer.answer}")
|
||||
|
||||
console.print("-" * console.width)
|
||||
|
||||
brain.print_info()
|
||||
|
@ -12,3 +12,5 @@ if __name__ == "__main__":
|
||||
answer = brain.ask("Property of gold?")
|
||||
|
||||
print("answer :", answer.answer)
|
||||
|
||||
print("brain information: ", brain)
|
||||
|
75
backend/core/poetry.lock
generated
75
backend/core/poetry.lock
generated
@ -1,4 +1,4 @@
|
||||
# This file is automatically @generated by Poetry 1.8.0 and should not be changed by hand.
|
||||
# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand.
|
||||
|
||||
[[package]]
|
||||
name = "aiofiles"
|
||||
@ -1189,13 +1189,13 @@ tenacity = ">=8.1.0,<8.4.0 || >8.4.0,<9.0.0"
|
||||
|
||||
[[package]]
|
||||
name = "langchain-core"
|
||||
version = "0.2.19"
|
||||
version = "0.2.20"
|
||||
description = "Building applications with LLMs through composability"
|
||||
optional = false
|
||||
python-versions = "<4.0,>=3.8.1"
|
||||
files = [
|
||||
{file = "langchain_core-0.2.19-py3-none-any.whl", hash = "sha256:5b3cd34395be274c89e822c84f0e03c4da14168c177a83921c5b9414ac7a0651"},
|
||||
{file = "langchain_core-0.2.19.tar.gz", hash = "sha256:13043a83e5c9ab58b9f5ce2a56896e7e88b752e8891b2958960a98e71801471e"},
|
||||
{file = "langchain_core-0.2.20-py3-none-any.whl", hash = "sha256:16cc4da6f7ebf33accea7af45a70480733dc852ab291030fb6924865bd7caf76"},
|
||||
{file = "langchain_core-0.2.20.tar.gz", hash = "sha256:a66c439e085d8c75f822f7650a5551d17bada4003521173c763d875d949e4ed5"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
@ -1241,13 +1241,13 @@ langchain-core = ">=0.2.10,<0.3.0"
|
||||
|
||||
[[package]]
|
||||
name = "langsmith"
|
||||
version = "0.1.85"
|
||||
version = "0.1.88"
|
||||
description = "Client library to connect to the LangSmith LLM Tracing and Evaluation Platform."
|
||||
optional = false
|
||||
python-versions = "<4.0,>=3.8.1"
|
||||
files = [
|
||||
{file = "langsmith-0.1.85-py3-none-any.whl", hash = "sha256:c1f94384f10cea96f7b4d33fd3db7ec180c03c7468877d50846f881d2017ff94"},
|
||||
{file = "langsmith-0.1.85.tar.gz", hash = "sha256:acff31f9e53efa48586cf8e32f65625a335c74d7c4fa306d1655ac18452296f6"},
|
||||
{file = "langsmith-0.1.88-py3-none-any.whl", hash = "sha256:460ebb7de440afd150fcea8f54ca8779821f2228cd59e149e5845c9dbe06db16"},
|
||||
{file = "langsmith-0.1.88.tar.gz", hash = "sha256:28a07dec19197f4808aa2628d5a3ccafcbe14cc137aef0e607bbd128e7907821"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
@ -1258,6 +1258,30 @@ pydantic = [
|
||||
]
|
||||
requests = ">=2,<3"
|
||||
|
||||
[[package]]
|
||||
name = "markdown-it-py"
|
||||
version = "3.0.0"
|
||||
description = "Python port of markdown-it. Markdown parsing, done right!"
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
files = [
|
||||
{file = "markdown-it-py-3.0.0.tar.gz", hash = "sha256:e3f60a94fa066dc52ec76661e37c851cb232d92f9886b15cb560aaada2df8feb"},
|
||||
{file = "markdown_it_py-3.0.0-py3-none-any.whl", hash = "sha256:355216845c60bd96232cd8d8c40e8f9765cc86f46880e43a8fd22dc1a1a8cab1"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
mdurl = ">=0.1,<1.0"
|
||||
|
||||
[package.extras]
|
||||
benchmarking = ["psutil", "pytest", "pytest-benchmark"]
|
||||
code-style = ["pre-commit (>=3.0,<4.0)"]
|
||||
compare = ["commonmark (>=0.9,<1.0)", "markdown (>=3.4,<4.0)", "mistletoe (>=1.0,<2.0)", "mistune (>=2.0,<3.0)", "panflute (>=2.3,<3.0)"]
|
||||
linkify = ["linkify-it-py (>=1,<3)"]
|
||||
plugins = ["mdit-py-plugins"]
|
||||
profiling = ["gprof2dot"]
|
||||
rtd = ["jupyter_sphinx", "mdit-py-plugins", "myst-parser", "pyyaml", "sphinx", "sphinx-copybutton", "sphinx-design", "sphinx_book_theme"]
|
||||
testing = ["coverage", "pytest", "pytest-cov", "pytest-regressions"]
|
||||
|
||||
[[package]]
|
||||
name = "marshmallow"
|
||||
version = "3.21.3"
|
||||
@ -1302,6 +1326,17 @@ files = [
|
||||
{file = "mccabe-0.7.0.tar.gz", hash = "sha256:348e0240c33b60bbdf4e523192ef919f28cb2c3d7d5c7794f74009290f236325"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "mdurl"
|
||||
version = "0.1.2"
|
||||
description = "Markdown URL utilities"
|
||||
optional = false
|
||||
python-versions = ">=3.7"
|
||||
files = [
|
||||
{file = "mdurl-0.1.2-py3-none-any.whl", hash = "sha256:84008a41e51615a49fc9966191ff91509e3c40b939176e643fd50a5c2196b8f8"},
|
||||
{file = "mdurl-0.1.2.tar.gz", hash = "sha256:bb413d29f5eea38f31dd4754dd7377d4465116fb207585f97bf925588687c1ba"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "multidict"
|
||||
version = "6.0.5"
|
||||
@ -1527,13 +1562,13 @@ files = [
|
||||
|
||||
[[package]]
|
||||
name = "openai"
|
||||
version = "1.35.13"
|
||||
version = "1.35.14"
|
||||
description = "The official Python library for the openai API"
|
||||
optional = true
|
||||
python-versions = ">=3.7.1"
|
||||
files = [
|
||||
{file = "openai-1.35.13-py3-none-any.whl", hash = "sha256:36ec3e93e0d1f243f69be85c89b9221a471c3e450dfd9df16c9829e3cdf63e60"},
|
||||
{file = "openai-1.35.13.tar.gz", hash = "sha256:c684f3945608baf7d2dcc0ef3ee6f3e27e4c66f21076df0b47be45d57e6ae6e4"},
|
||||
{file = "openai-1.35.14-py3-none-any.whl", hash = "sha256:adadf8c176e0b8c47ad782ed45dc20ef46438ee1f02c7103c4155cff79c8f68b"},
|
||||
{file = "openai-1.35.14.tar.gz", hash = "sha256:394ba1dfd12ecec1d634c50e512d24ff1858bbc2674ffcce309b822785a058de"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
@ -2328,6 +2363,24 @@ urllib3 = ">=1.21.1,<3"
|
||||
socks = ["PySocks (>=1.5.6,!=1.5.7)"]
|
||||
use-chardet-on-py3 = ["chardet (>=3.0.2,<6)"]
|
||||
|
||||
[[package]]
|
||||
name = "rich"
|
||||
version = "13.7.1"
|
||||
description = "Render rich text, tables, progress bars, syntax highlighting, markdown and more to the terminal"
|
||||
optional = false
|
||||
python-versions = ">=3.7.0"
|
||||
files = [
|
||||
{file = "rich-13.7.1-py3-none-any.whl", hash = "sha256:4edbae314f59eb482f54e9e30bf00d33350aaa94f4bfcd4e9e3110e64d0d7222"},
|
||||
{file = "rich-13.7.1.tar.gz", hash = "sha256:9be308cb1fe2f1f57d67ce99e95af38a1e2bc71ad9813b0e247cf7ffbcc3a432"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
markdown-it-py = ">=2.2.0"
|
||||
pygments = ">=2.13.0,<3.0.0"
|
||||
|
||||
[package.extras]
|
||||
jupyter = ["ipywidgets (>=7.5.1,<9)"]
|
||||
|
||||
[[package]]
|
||||
name = "ruff"
|
||||
version = "0.4.10"
|
||||
@ -2788,4 +2841,4 @@ pdf = []
|
||||
[metadata]
|
||||
lock-version = "2.0"
|
||||
python-versions = "^3.11"
|
||||
content-hash = "f5ab0d3f93f3bd517382a8c4207a4a480f4a6dcb0aeaf60501c379a4a91a3ad0"
|
||||
content-hash = "cc08905f149df2f415e1d00010e5b89a371efdcf4059855d597b1b6e9973a536"
|
||||
|
@ -15,6 +15,7 @@ aiofiles = ">=23.0.0,<25.0.0"
|
||||
faiss-cpu = { version = "^1.8.0.post1", optional = true }
|
||||
langchain-community = { version = "^0.2.6", optional = true }
|
||||
langchain-openai = { version = "^0.1.14", optional = true }
|
||||
rich = "^13.7.1"
|
||||
|
||||
[tool.poetry.extras]
|
||||
base = ["langchain-community", "faiss-cpu", "langchain-openai"]
|
||||
@ -81,6 +82,10 @@ known-first-party = []
|
||||
[tool.pytest.ini_options]
|
||||
addopts = "--tb=short -ra -v"
|
||||
filterwarnings = ["ignore::DeprecationWarning"]
|
||||
markers = [
|
||||
"slow: marks tests as slow (deselect with '-m \"not slow\"')",
|
||||
"base: these tests require quivr-core with extra `base` to be installed",
|
||||
]
|
||||
|
||||
[build-system]
|
||||
requires = ["poetry-core"]
|
||||
|
@ -1,3 +1,35 @@
|
||||
from .brain import Brain
|
||||
from importlib.metadata import entry_points
|
||||
|
||||
__all__ = ["Brain"]
|
||||
from .brain import Brain
|
||||
from .processor.registry import register_processor, registry
|
||||
|
||||
__all__ = ["Brain", "registry", "register_processor"]
|
||||
|
||||
|
||||
def register_entries():
    """Register processors advertised in the ``quivr_core.processor`` entry-point group.

    Each entry point found in the group is passed to ``register_processor``
    with ``override=True``: the entry point's name is used as the file type
    and its value, with ``:`` rewritten to ``.``, as the class path.  When
    several entry points share a name, only the first one seen is registered.
    """
    if entry_points is None:
        return

    try:
        discovered = entry_points()
    except TypeError:
        # importlib-metadata < 0.8
        return

    group = "quivr_core.processor"
    if hasattr(discovered, "select"):
        # Python 3.10+ / importlib_metadata >= 3.9.0
        candidates = discovered.select(group=group)
    else:
        candidates = discovered.get(group, [])

    seen = set()
    for entry in candidates:
        entry_name = entry.name
        if entry_name in seen:
            continue
        seen.add(entry_name)
        register_processor(
            entry_name,
            entry.value.replace(":", "."),
            errtxt=f"Unable to load processor from {entry}",
            override=True,
        )
|
||||
|
||||
|
||||
register_entries()
|
||||
|
3
backend/core/quivr_core/brain/__init__.py
Normal file
3
backend/core/quivr_core/brain/__init__.py
Normal file
@ -0,0 +1,3 @@
|
||||
from .brain import Brain
|
||||
|
||||
__all__ = ["Brain"]
|
@ -1,29 +1,34 @@
|
||||
import asyncio
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Any, AsyncGenerator, Callable, Dict, Mapping, Self
|
||||
from pprint import PrettyPrinter
|
||||
from typing import Any, AsyncGenerator, Callable, Dict, Self
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.embeddings import Embeddings
|
||||
from langchain_core.messages import AIMessage, HumanMessage
|
||||
from langchain_core.vectorstores import VectorStore
|
||||
from rich.console import Console
|
||||
from rich.panel import Panel
|
||||
|
||||
from quivr_core.brain.info import BrainInfo, ChatHistoryInfo
|
||||
from quivr_core.chat import ChatHistory
|
||||
from quivr_core.config import LLMEndpointConfig, RAGConfig
|
||||
from quivr_core.llm import LLMEndpoint
|
||||
from quivr_core.models import ParsedRAGChunkResponse, ParsedRAGResponse, SearchResult
|
||||
from quivr_core.processor.default_parsers import DEFAULT_PARSERS
|
||||
from quivr_core.processor.processor_base import ProcessorBase
|
||||
from quivr_core.processor.registry import get_processor_class
|
||||
from quivr_core.quivr_rag import QuivrQARAG
|
||||
from quivr_core.storage.file import QuivrFile
|
||||
from quivr_core.storage.file import load_qfile
|
||||
from quivr_core.storage.local_storage import TransparentStorage
|
||||
from quivr_core.storage.storage_base import StorageBase
|
||||
|
||||
logger = logging.getLogger("quivr_core")
|
||||
|
||||
|
||||
async def _default_vectordb(docs: list[Document], embedder: Embeddings) -> VectorStore:
|
||||
async def _build_default_vectordb(
|
||||
docs: list[Document], embedder: Embeddings
|
||||
) -> VectorStore:
|
||||
try:
|
||||
from langchain_community.vectorstores import FAISS
|
||||
|
||||
@ -38,7 +43,7 @@ async def _default_vectordb(docs: list[Document], embedder: Embeddings) -> Vecto
|
||||
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"Please provide a valid vectore store or install quivr-core['base'] package for using the default one."
|
||||
"Please provide a valid vector store or install quivr-core['base'] package for using the default one."
|
||||
) from e
|
||||
|
||||
|
||||
@ -67,16 +72,15 @@ def _default_llm() -> LLMEndpoint:
|
||||
) from e
|
||||
|
||||
|
||||
async def _process_files(
|
||||
storage: StorageBase,
|
||||
skip_file_error: bool,
|
||||
processors_mapping: Mapping[str, ProcessorBase],
|
||||
async def process_files(
|
||||
storage: StorageBase, skip_file_error: bool, **processor_kwargs: dict[str, Any]
|
||||
) -> list[Document]:
|
||||
knowledge = []
|
||||
for file in storage.get_files():
|
||||
for file in await storage.get_files():
|
||||
try:
|
||||
if file.file_extension:
|
||||
processor = processors_mapping[file.file_extension]
|
||||
processor_cls = get_processor_class(file.file_extension)
|
||||
processor = processor_cls(**processor_kwargs)
|
||||
docs = await processor.process_file(file)
|
||||
knowledge.extend(docs)
|
||||
else:
|
||||
@ -118,11 +122,38 @@ class Brain:
|
||||
self.vector_db = vector_db
|
||||
self.embedder = embedder
|
||||
|
||||
def __repr__(self) -> str:
|
||||
pp = PrettyPrinter(width=80, depth=None, compact=False, sort_dicts=False)
|
||||
return pp.pformat(self.info())
|
||||
|
||||
def print_info(self):
|
||||
console = Console()
|
||||
tree = self.info().to_tree()
|
||||
panel = Panel(tree, title="Brain Info", expand=False, border_style="bold")
|
||||
console.print(panel)
|
||||
|
||||
def info(self) -> BrainInfo:
|
||||
# TODO: dim of embedding
|
||||
# "embedder": {},
|
||||
chats_info = ChatHistoryInfo(
|
||||
nb_chats=len(self._chats),
|
||||
current_default_chat=self.default_chat.id,
|
||||
current_chat_history_length=len(self.default_chat),
|
||||
)
|
||||
|
||||
return BrainInfo(
|
||||
brain_id=self.id,
|
||||
brain_name=self.name,
|
||||
files_info=self.storage.info(),
|
||||
chats_info=chats_info,
|
||||
llm_info=self.llm.info(),
|
||||
)
|
||||
|
||||
@property
|
||||
def chat_history(self):
|
||||
def chat_history(self) -> ChatHistory:
|
||||
return self.default_chat
|
||||
|
||||
def _init_chats(self):
|
||||
def _init_chats(self) -> Dict[UUID, ChatHistory]:
|
||||
chat_id = uuid4()
|
||||
default_chat = ChatHistory(chat_id=chat_id, brain_id=self.id)
|
||||
return {chat_id: default_chat}
|
||||
@ -137,7 +168,6 @@ class Brain:
|
||||
storage: StorageBase = TransparentStorage(),
|
||||
llm: LLMEndpoint | None = None,
|
||||
embedder: Embeddings | None = None,
|
||||
processors_mapping: Mapping[str, ProcessorBase] = DEFAULT_PARSERS,
|
||||
skip_file_error: bool = False,
|
||||
):
|
||||
if llm is None:
|
||||
@ -148,20 +178,20 @@ class Brain:
|
||||
|
||||
brain_id = uuid4()
|
||||
|
||||
# TODO: run in parallel using tasks
|
||||
for path in file_paths:
|
||||
file = QuivrFile.from_path(brain_id, path)
|
||||
storage.upload_file(file)
|
||||
file = await load_qfile(brain_id, path)
|
||||
await storage.upload_file(file)
|
||||
|
||||
# Parse files
|
||||
docs = await _process_files(
|
||||
docs = await process_files(
|
||||
storage=storage,
|
||||
processors_mapping=processors_mapping,
|
||||
skip_file_error=skip_file_error,
|
||||
)
|
||||
|
||||
# Building brain's vectordb
|
||||
if vector_db is None:
|
||||
vector_db = await _default_vectordb(docs, embedder)
|
||||
vector_db = await _build_default_vectordb(docs, embedder)
|
||||
else:
|
||||
await vector_db.aadd_documents(docs)
|
||||
|
||||
@ -184,10 +214,10 @@ class Brain:
|
||||
storage: StorageBase = TransparentStorage(),
|
||||
llm: LLMEndpoint | None = None,
|
||||
embedder: Embeddings | None = None,
|
||||
processors_mapping: Mapping[str, ProcessorBase] = DEFAULT_PARSERS,
|
||||
skip_file_error: bool = False,
|
||||
) -> Self:
|
||||
return asyncio.run(
|
||||
loop = asyncio.get_event_loop()
|
||||
return loop.run_until_complete(
|
||||
cls.afrom_files(
|
||||
name=name,
|
||||
file_paths=file_paths,
|
||||
@ -195,7 +225,6 @@ class Brain:
|
||||
storage=storage,
|
||||
llm=llm,
|
||||
embedder=embedder,
|
||||
processors_mapping=processors_mapping,
|
||||
skip_file_error=skip_file_error,
|
||||
)
|
||||
)
|
||||
@ -221,7 +250,7 @@ class Brain:
|
||||
|
||||
# Building brain's vectordb
|
||||
if vector_db is None:
|
||||
vector_db = await _default_vectordb(langchain_documents, embedder)
|
||||
vector_db = await _build_default_vectordb(langchain_documents, embedder)
|
||||
else:
|
||||
await vector_db.aadd_documents(langchain_documents)
|
||||
|
74
backend/core/quivr_core/brain/info.py
Normal file
74
backend/core/quivr_core/brain/info.py
Normal file
@ -0,0 +1,74 @@
|
||||
from dataclasses import dataclass
|
||||
from uuid import UUID
|
||||
|
||||
from rich.tree import Tree
|
||||
|
||||
|
||||
@dataclass
class ChatHistoryInfo:
    """Chat-related counters of a brain, renderable as rich tree nodes."""

    nb_chats: int
    current_default_chat: UUID
    current_chat_history_length: int

    def add_to_tree(self, chats_tree: Tree):
        """Append one child node per field to ``chats_tree``."""
        labels = (
            f"Number of Chats: [bold]{self.nb_chats}[/bold]",
            f"Current Default Chat: [bold magenta]{self.current_default_chat}[/bold magenta]",
            f"Current Chat History Length: [bold]{self.current_chat_history_length}[/bold]",
        )
        for label in labels:
            chats_tree.add(label)
|
||||
|
||||
|
||||
@dataclass
class LLMInfo:
    """Summary of an LLM endpoint configuration, renderable as rich tree nodes."""

    model: str
    llm_base_url: str
    temperature: float
    max_tokens: int
    # A true/false flag (it selects the green/red colour below); it was
    # previously annotated ``int``.
    supports_function_calling: bool

    def add_to_tree(self, llm_tree: Tree):
        """Append one child node per field to ``llm_tree``."""
        llm_tree.add(f"Model: [italic]{self.model}[/italic]")
        llm_tree.add(f"Base URL: [underline]{self.llm_base_url}[/underline]")
        llm_tree.add(f"Temperature: [bold]{self.temperature}[/bold]")
        llm_tree.add(f"Max Tokens: [bold]{self.max_tokens}[/bold]")

        # Green when function calling is supported, red otherwise.
        func_call_color = "green" if self.supports_function_calling else "red"
        llm_tree.add(
            f"Supports Function Calling: [bold {func_call_color}]{self.supports_function_calling}[/bold {func_call_color}]"
        )
|
||||
|
||||
|
||||
@dataclass
class StorageInfo:
    """Storage backend summary, renderable as rich tree nodes."""

    storage_type: str
    n_files: int

    def add_to_tree(self, files_tree: Tree):
        """Append the storage type and the file count to ``files_tree``."""
        for label in (
            f"Storage Type: [italic]{self.storage_type}[/italic]",
            f"Number of Files: [bold]{self.n_files}[/bold]",
        ):
            files_tree.add(label)
|
||||
|
||||
|
||||
@dataclass
class BrainInfo:
    """Aggregated description of a brain: identity, files, chats and LLM."""

    brain_id: UUID
    brain_name: str
    files_info: StorageInfo
    chats_info: ChatHistoryInfo
    llm_info: LLMInfo

    def to_tree(self):
        """Build and return a rich ``Tree`` with one sub-tree per info section."""
        root = Tree("📊 Brain Information")
        root.add(f"🆔 ID: [bold cyan]{self.brain_id}[/bold cyan]")
        root.add(f"🧠 Brain Name: [bold green]{self.brain_name}[/bold green]")

        # Each section gets its own branch, filled in by the section object.
        sections = (
            ("📁 Files", self.files_info),
            ("💬 Chats", self.chats_info),
            ("🤖 LLM", self.llm_info),
        )
        for title, section in sections:
            section.add_to_tree(root.add(title))

        return root
|
@ -1,6 +1,7 @@
|
||||
from langchain_core.language_models.chat_models import BaseChatModel
|
||||
from pydantic.v1 import SecretStr
|
||||
|
||||
from quivr_core.brain.info import LLMInfo
|
||||
from quivr_core.config import LLMEndpointConfig
|
||||
from quivr_core.utils import model_supports_function_calling
|
||||
|
||||
@ -35,3 +36,14 @@ class LLMEndpoint:
|
||||
|
||||
def supports_func_calling(self) -> bool:
|
||||
return self._supports_func_calling
|
||||
|
||||
def info(self) -> LLMInfo:
|
||||
return LLMInfo(
|
||||
model=self._config.model,
|
||||
llm_base_url=(
|
||||
self._config.llm_base_url if self._config.llm_base_url else "openai"
|
||||
),
|
||||
temperature=self._config.temperature,
|
||||
max_tokens=self._config.max_tokens,
|
||||
supports_function_calling=self.supports_func_calling(),
|
||||
)
|
||||
|
@ -1,6 +0,0 @@
|
||||
from quivr_core.processor.processor_base import ProcessorBase
|
||||
from quivr_core.processor.txt_parser import TxtProcessor
|
||||
|
||||
DEFAULT_PARSERS: dict[str, ProcessorBase] = {
|
||||
".txt": TxtProcessor(),
|
||||
}
|
@ -1,27 +1,20 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Generic, TypeVar
|
||||
|
||||
from langchain_core.documents import Document
|
||||
|
||||
from quivr_core.storage.file import QuivrFile
|
||||
from quivr_core.storage.file import FileExtension, QuivrFile
|
||||
|
||||
|
||||
# TODO: processors should be cached somewhere ?
|
||||
# The processor should be cached by processor type
|
||||
# The cache should use a single
|
||||
class ProcessorBase(ABC):
|
||||
supported_extensions: list[str]
|
||||
supported_extensions: list[FileExtension | str]
|
||||
|
||||
@abstractmethod
|
||||
async def process_file(self, file: QuivrFile) -> list[Document]:
|
||||
pass
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
P = TypeVar("P", bound=ProcessorBase)
|
||||
|
||||
|
||||
class ProcessorsMapping(Generic[P]):
|
||||
def __init__(self, mapping: dict[str, P]) -> None:
|
||||
# Create an empty list with items of type T
|
||||
self.ext_parser: dict[str, P] = mapping
|
||||
|
||||
def add_parser(self, extension: str, parser: P):
|
||||
# TODO: deal with existing ext keys
|
||||
self.ext_parser[extension] = parser
|
||||
def check_supported(self, file: QuivrFile):
|
||||
if file.file_extension not in self.supported_extensions:
|
||||
raise ValueError(f"can't process a file of type {file.file_extension}")
|
||||
|
107
backend/core/quivr_core/processor/registry.py
Normal file
107
backend/core/quivr_core/processor/registry.py
Normal file
@ -0,0 +1,107 @@
|
||||
import importlib
|
||||
import types
|
||||
from typing import Type, TypedDict
|
||||
|
||||
from quivr_core.storage.file import FileExtension
|
||||
|
||||
from .processor_base import ProcessorBase
|
||||
|
||||
_registry: dict[str, Type[ProcessorBase]] = {}
|
||||
|
||||
# external, read only
|
||||
registry = types.MappingProxyType(_registry)
|
||||
|
||||
|
||||
class ProcEntry(TypedDict):
|
||||
cls_mod: str
|
||||
err: str | None
|
||||
|
||||
|
||||
# Maps a file extension to the processor class that handles it.  Classes are
# referenced by import path and only imported on first use; ``err`` is the
# message of the ImportError raised when that import fails.
known_processors: dict[FileExtension | str, ProcEntry] = {
    FileExtension.txt: ProcEntry(
        cls_mod="quivr_core.processor.simple_txt_processor.SimpleTxtProcessor",
        # The message now names the class this entry actually points at.
        err="SimpleTxtProcessor import failed; check your quivr_core installation",
    ),
    FileExtension.pdf: ProcEntry(
        cls_mod="quivr_core.processor.tika_processor.TikaProcessor",
        # A real message instead of ``None``, which surfaced as ImportError(None).
        err="TikaProcessor import failed; it requires httpx and langchain_text_splitters",
    ),
}
|
||||
|
||||
|
||||
def get_processor_class(file_extension: FileExtension | str) -> Type[ProcessorBase]:
    """Fetch the processor class registered for ``file_extension``.

    The dict ``known_processors`` maps file extensions to the locations
    of processors that could process them.
    Loading of these classes is *lazy*: the import happens the first time
    a processor is requested for a given file type, and the class is then
    stored in the registry.

    Some processors need additional dependencies. If the import fails,
    an ``ImportError`` carrying the "err" field of the ProcEntry in
    ``known_processors`` is raised (with a generic message when that
    field is empty).

    Raises:
        ValueError: the extension is neither registered nor known.
        ImportError: the known processor class could not be imported.
    """
    if file_extension not in registry:
        if file_extension not in known_processors:
            raise ValueError(f"Extension not known: {file_extension}")

        entry = known_processors[file_extension]
        try:
            register_processor(file_extension, _import_class(entry["cls_mod"]))
        except ImportError as e:
            # ``err`` may be None; never raise ImportError(None).
            message = entry["err"] or (
                f"{entry['cls_mod']} import failed for processor of {file_extension}"
            )
            raise ImportError(message) from e

    return registry[file_extension]
|
||||
|
||||
|
||||
def register_processor(
    file_type: FileExtension | str,
    proc_cls: str | Type[ProcessorBase],
    override: bool = False,
    errtxt=None,
):
    """Register a processor for ``file_type``.

    Args:
        file_type: file extension the processor handles.
        proc_cls: either the processor class itself, or its import path as a
            string. A string is recorded in ``known_processors`` and imported
            lazily; a class is stored in the registry directly.
        override: when ``False``, registering a *different* processor for an
            already registered ``file_type`` raises; registering the same one
            again is a no-op.
        errtxt: message to raise if a string ``proc_cls`` later fails to import.

    Raises:
        ValueError: a different processor is already registered for
            ``file_type`` and ``override`` is ``False``.
    """
    if isinstance(proc_cls, str):
        if file_type in known_processors and override is False:
            if proc_cls != known_processors[file_type]["cls_mod"]:
                raise ValueError(
                    f"Processor for ({file_type}) already in the registry and override is False"
                )
            return

        known_processors[file_type] = ProcEntry(
            cls_mod=proc_cls,
            err=errtxt or f"{proc_cls} import failed for processor of {file_type}",
        )
        if override is not False:
            # Drop any class already imported for this file type; otherwise the
            # registry would keep serving the old class and the override would
            # silently have no effect.
            _registry.pop(file_type, None)
        return

    if file_type in registry and override is False:
        if _registry[file_type] is not proc_cls:
            raise ValueError(
                f"Processor for ({file_type}) already in the registry and override is False"
            )
        return

    _registry[file_type] = proc_cls
|
||||
|
||||
|
||||
def _import_class(full_mod_path: str):
    """Import and return the ``ProcessorBase`` subclass found at ``full_mod_path``.

    The path is either ``package.module:Class`` or ``package.module.Class``;
    it is split on the last ``:`` when one is present, otherwise on the last
    ``.``.  The part after the split may itself be dotted, in which case the
    attributes are resolved one after the other.

    Raises:
        TypeError: the resolved object is not a class, or is a class that
            does not derive from ``ProcessorBase``.
    """
    separator = ":" if ":" in full_mod_path else "."
    mod_name, name = full_mod_path.rsplit(separator, 1)

    target = importlib.import_module(mod_name)
    for attribute in name.split("."):
        target = getattr(target, attribute)

    if not isinstance(target, type):
        raise TypeError(f"{full_mod_path} is not a class")
    if not issubclass(target, ProcessorBase):
        raise TypeError(f"{full_mod_path} is not a subclass of ProcessorBase ")

    return target
|
||||
|
||||
|
||||
def available_processors():
    """Return the keys of ``known_processors`` as a list."""
    return [*known_processors]
|
62
backend/core/quivr_core/processor/simple_txt_processor.py
Normal file
62
backend/core/quivr_core/processor/simple_txt_processor.py
Normal file
@ -0,0 +1,62 @@
|
||||
from importlib.metadata import version
|
||||
from uuid import uuid4
|
||||
|
||||
import aiofiles
|
||||
from langchain_core.documents import Document
|
||||
|
||||
from quivr_core.processor.processor_base import ProcessorBase
|
||||
from quivr_core.processor.registry import FileExtension
|
||||
from quivr_core.processor.splitter import SplitterConfig
|
||||
from quivr_core.storage.file import QuivrFile
|
||||
|
||||
|
||||
def recursive_character_splitter(
    doc: Document, chunk_size: int, chunk_overlap: int
) -> list[Document]:
    """Split ``doc`` into chunks of at most ``chunk_size`` characters.

    Consecutive chunks start ``chunk_size - chunk_overlap`` characters apart,
    so each chunk repeats the last ``chunk_overlap`` characters of the
    previous one.  Every chunk is created with ``doc.metadata``.  A document
    that already fits in one chunk is returned as-is.

    The split is done with a loop rather than one recursive call per chunk,
    so large documents do not hit the interpreter recursion limit.

    Raises:
        AssertionError: ``chunk_overlap`` is not smaller than ``chunk_size``.
    """
    if chunk_overlap >= chunk_size:
        # Explicit raise instead of ``assert`` so the guard survives
        # ``python -O``; without it the loop below would never advance.
        raise AssertionError("chunk_overlap is greater than chunk_size")

    content = doc.page_content
    if len(content) <= chunk_size:
        return [doc]

    step = chunk_size - chunk_overlap
    chunks = []
    start = 0
    # Emit full-size chunks while more than one chunk's worth of text remains.
    while len(content) - start > chunk_size:
        chunks.append(
            Document(
                page_content=content[start : start + chunk_size],
                metadata=doc.metadata,
            )
        )
        start += step

    # Whatever is left fits in a single, final chunk.
    chunks.append(Document(page_content=content[start:], metadata=doc.metadata))
    return chunks
|
||||
|
||||
|
||||
class SimpleTxtProcessor(ProcessorBase):
    """Processor for ``.txt`` files that splits text by character count."""

    supported_extensions = [FileExtension.txt]

    # NOTE(review): the default ``SplitterConfig()`` is created once at
    # definition time and shared by every instance — confirm it is never mutated.
    def __init__(
        self, splitter_config: SplitterConfig = SplitterConfig(), **kwargs
    ) -> None:
        super().__init__(**kwargs)
        self.splitter_config = splitter_config

    async def process_file(self, file: QuivrFile) -> list[Document]:
        """Read ``file`` and return its content as chunked documents.

        Each returned chunk carries its own metadata: a unique ``id``, its own
        ``chunk_size``, the configured overlap, the parser name, the
        quivr-core version, and the file's metadata.

        Raises:
            ValueError: the file extension is not supported by this processor
                (raised by ``check_supported``).
        """
        self.check_supported(file)
        file_metadata = file.metadata

        async with aiofiles.open(file.path, mode="r") as f:
            content = await f.read()

        docs = recursive_character_splitter(
            Document(page_content=content),
            self.splitter_config.chunk_size,
            self.splitter_config.chunk_overlap,
        )

        parser_name = self.__class__.__name__
        core_version = version("quivr-core")
        chunk_overlap = self.splitter_config.chunk_overlap

        # Metadata is attached after splitting so that ``id`` and
        # ``chunk_size`` describe each chunk rather than the whole file.
        for doc in docs:
            doc.metadata = {
                "id": uuid4(),
                "chunk_size": len(doc.page_content),
                "chunk_overlap": chunk_overlap,
                "parser_name": parser_name,
                "quivr_core_version": core_version,
                **file_metadata,
            }

        return docs
|
@ -1,4 +1,5 @@
|
||||
import logging
|
||||
from importlib.metadata import version
|
||||
from typing import AsyncIterable
|
||||
|
||||
import httpx
|
||||
@ -6,14 +7,15 @@ from langchain_core.documents import Document
|
||||
from langchain_text_splitters import RecursiveCharacterTextSplitter, TextSplitter
|
||||
|
||||
from quivr_core.processor.processor_base import ProcessorBase
|
||||
from quivr_core.processor.registry import FileExtension
|
||||
from quivr_core.processor.splitter import SplitterConfig
|
||||
from quivr_core.storage.file import QuivrFile
|
||||
|
||||
logger = logging.getLogger("quivr_core")
|
||||
|
||||
|
||||
class TikaParser(ProcessorBase):
|
||||
supported_extensions = [".pdf"]
|
||||
class TikaProcessor(ProcessorBase):
|
||||
supported_extensions = [FileExtension.pdf]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@ -51,12 +53,22 @@ class TikaParser(ProcessorBase):
|
||||
raise RuntimeError("can't send parse request to tika server")
|
||||
|
||||
async def process_file(self, file: QuivrFile) -> list[Document]:
|
||||
assert file.file_extension in self.supported_extensions
|
||||
self.check_supported(file)
|
||||
|
||||
async with file.open() as f:
|
||||
txt = await self._send_parse_tika(f)
|
||||
document = Document(page_content=txt)
|
||||
|
||||
# Use the default splitter
|
||||
docs = self.text_splitter.split_documents([document])
|
||||
file_metadata = file.metadata
|
||||
|
||||
for doc in docs:
|
||||
doc.metadata = {
|
||||
"chunk_size": len(doc.page_content),
|
||||
"chunk_overlap": self.splitter_config.chunk_overlap,
|
||||
"parser_name": self.__class__.__name__,
|
||||
"quivr_core_version": version("quivr-core"),
|
||||
**file_metadata,
|
||||
}
|
||||
|
||||
return docs
|
@ -1,19 +1,24 @@
|
||||
from importlib.metadata import version
|
||||
from uuid import uuid4
|
||||
|
||||
from langchain_community.document_loaders.text import TextLoader
|
||||
from langchain_core.documents import Document
|
||||
from langchain_text_splitters import RecursiveCharacterTextSplitter, TextSplitter
|
||||
|
||||
from quivr_core.processor.processor_base import ProcessorBase
|
||||
from quivr_core.processor.registry import FileExtension
|
||||
from quivr_core.processor.splitter import SplitterConfig
|
||||
from quivr_core.storage.file import QuivrFile
|
||||
|
||||
|
||||
class TxtProcessor(ProcessorBase):
|
||||
class TikTokenTxtProcessor(ProcessorBase):
|
||||
supported_extensions = [FileExtension.txt]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
splitter: TextSplitter | None = None,
|
||||
splitter_config: SplitterConfig = SplitterConfig(),
|
||||
) -> None:
|
||||
self.supported_extensions = [".txt"]
|
||||
self.loader_cls = TextLoader
|
||||
|
||||
self.splitter_config = splitter_config
|
||||
@ -27,10 +32,22 @@ class TxtProcessor(ProcessorBase):
|
||||
)
|
||||
|
||||
async def process_file(self, file: QuivrFile) -> list[Document]:
|
||||
if file.file_extension not in self.supported_extensions:
|
||||
raise Exception(f"can't process a file of type {file.file_extension}")
|
||||
self.check_supported(file)
|
||||
|
||||
loader = self.loader_cls(file.path)
|
||||
documents = await loader.aload()
|
||||
docs = self.text_splitter.split_documents(documents)
|
||||
|
||||
file_metadata = file.metadata
|
||||
|
||||
for doc in docs:
|
||||
doc.metadata = {
|
||||
"id": uuid4(),
|
||||
"chunk_size": len(doc.page_content),
|
||||
"chunk_overlap": self.splitter_config.chunk_overlap,
|
||||
"parser_name": self.__class__.__name__,
|
||||
"quivr_core_version": version("quivr-core"),
|
||||
**file_metadata,
|
||||
}
|
||||
|
||||
return docs
|
@ -1,21 +1,87 @@
|
||||
import hashlib
|
||||
import mimetypes
|
||||
import os
|
||||
import warnings
|
||||
from contextlib import asynccontextmanager
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import AsyncGenerator, AsyncIterable
|
||||
from typing import Any, AsyncGenerator, AsyncIterable
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
import aiofiles
|
||||
|
||||
|
||||
class FileExtension(str, Enum):
    """File suffixes, leading dot included, that have a dedicated enum member.

    Members also subclass ``str``, so each one compares equal to its raw
    suffix string.
    """

    txt = ".txt"
    pdf = ".pdf"
    docx = ".docx"
|
||||
|
||||
|
||||
def get_file_extension(file_path: Path) -> FileExtension | str:
    """Resolve the extension of ``file_path`` to a ``FileExtension`` when possible.

    The extension guessed from the file name's MIME type is preferred; the
    path's own suffix is used when no MIME type or extension can be guessed.

    Args:
        file_path: Path whose name and suffix are inspected.

    Returns:
        The matching ``FileExtension`` member, or the raw suffix string (after
        emitting a warning) when the candidate is not a ``FileExtension`` value.
    """
    try:
        mime_type = mimetypes.guess_type(file_path.name)[0]
        guessed_ext = mimetypes.guess_extension(mime_type) if mime_type else None
        # An empty or missing guess falls back to the suffix found on the path.
        return FileExtension(guessed_ext or file_path.suffix)
    except ValueError:
        warnings.warn(
            f"File {file_path.name} extension isn't recognized. Make sure you have registered a parser for {file_path.suffix}",
            stacklevel=2,
        )
        return file_path.suffix
|
||||
|
||||
|
||||
async def load_qfile(brain_id: UUID, path: str | Path):
    """Build a ``QuivrFile`` describing the file stored at ``path``.

    Args:
        brain_id: Identifier of the brain the file belongs to.
        path: Location of the file, as a string or a ``Path``.

    Returns:
        A ``QuivrFile`` carrying the file's id, size, MD5 digest and extension.

    Raises:
        FileExistsError: If ``path`` does not exist. ``FileNotFoundError``
            would describe the situation better, but the type is kept so
            existing ``except`` clauses keep working.
    """
    path = Path(path)

    if not path.exists():
        raise FileExistsError(f"file {path} doesn't exist")

    file_size = path.stat().st_size

    # Hash in fixed-size chunks so a large file is never held in memory whole.
    read_size = 1 << 20
    digest = hashlib.md5()
    async with aiofiles.open(path, mode="rb") as f:
        while chunk := await f.read(read_size):
            digest.update(chunk)
    file_md5 = digest.hexdigest()

    try:
        # NOTE: files copied into local storage are named "<uuid><extension>",
        # so the id is read from the stem. A bare "<uuid>" name has no suffix
        # and therefore still matches.
        file_id = UUID(path.stem)
    except ValueError:
        file_id = uuid4()

    return QuivrFile(
        id=file_id,
        brain_id=brain_id,
        path=path,
        original_filename=path.name,
        file_extension=get_file_extension(path),
        file_size=file_size,
        file_md5=file_md5,
    )
|
||||
|
||||
|
||||
class QuivrFile:
|
||||
__slots__ = [
|
||||
"id",
|
||||
"brain_id",
|
||||
"path",
|
||||
"original_filename",
|
||||
"file_size",
|
||||
"file_extension",
|
||||
"file_md5",
|
||||
]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
id: UUID,
|
||||
original_filename: str,
|
||||
path: Path,
|
||||
brain_id: UUID,
|
||||
file_md5: str,
|
||||
file_extension: FileExtension | str,
|
||||
file_size: int | None = None,
|
||||
file_extension: str | None = None,
|
||||
) -> None:
|
||||
self.id = id
|
||||
self.brain_id = brain_id
|
||||
@ -23,31 +89,7 @@ class QuivrFile:
|
||||
self.original_filename = original_filename
|
||||
self.file_size = file_size
|
||||
self.file_extension = file_extension
|
||||
|
||||
@classmethod
|
||||
def from_path(cls, brain_id: UUID, path: str | Path):
|
||||
if not isinstance(path, Path):
|
||||
path = Path(path)
|
||||
|
||||
if not path.exists():
|
||||
raise FileExistsError(f"file {path} doesn't exist")
|
||||
|
||||
file_size = os.stat(path).st_size
|
||||
|
||||
try:
|
||||
# NOTE: when loading from existing storage, file name will be uuid
|
||||
id = UUID(path.name)
|
||||
except ValueError:
|
||||
id = uuid4()
|
||||
|
||||
return cls(
|
||||
id=id,
|
||||
brain_id=brain_id,
|
||||
path=path,
|
||||
original_filename=path.name,
|
||||
file_size=file_size,
|
||||
file_extension=path.suffix,
|
||||
)
|
||||
self.file_md5 = file_md5
|
||||
|
||||
@asynccontextmanager
|
||||
async def open(self) -> AsyncGenerator[AsyncIterable[bytes], None]:
|
||||
@ -57,3 +99,13 @@ class QuivrFile:
|
||||
yield f
|
||||
finally:
|
||||
await f.close()
|
||||
|
||||
@property
def metadata(self) -> dict[str, Any]:
    """Describe this file as a flat dictionary.

    Returns:
        The file id, path, original name, MD5 digest and size. The digest is
        exposed as ``file_md5``; the misspelled ``file_md4`` key is kept with
        the same value so existing readers of that key keep working.
    """
    return {
        "qfile_id": self.id,
        "qfile_path": self.path,
        "original_file_name": self.original_filename,
        "file_md5": self.file_md5,
        # Deprecated alias: the digest key was originally misspelled.
        "file_md4": self.file_md5,
        "file_size": self.file_size,
    }
|
||||
|
@ -1,6 +1,7 @@
|
||||
import os
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
from typing import Set
|
||||
from uuid import UUID
|
||||
|
||||
from quivr_core.storage.file import QuivrFile
|
||||
@ -8,8 +9,11 @@ from quivr_core.storage.storage_base import StorageBase
|
||||
|
||||
|
||||
class LocalStorage(StorageBase):
|
||||
name: str = "local_storage"
|
||||
|
||||
def __init__(self, dir_path: Path | None = None, copy_flag: bool = True):
|
||||
self.files: list[QuivrFile] = []
|
||||
self.hashes: Set[str] = set()
|
||||
self.copy_flag = copy_flag
|
||||
|
||||
if dir_path is None:
|
||||
@ -24,14 +28,19 @@ class LocalStorage(StorageBase):
|
||||
# TODO(@aminediro): load existing files
|
||||
pass
|
||||
|
||||
def upload_file(self, file: QuivrFile, exists_ok: bool = False) -> None:
|
||||
def nb_files(self) -> int:
|
||||
return len(self.files)
|
||||
|
||||
def info(self):
|
||||
return {"directory_path": self.dir_path, **super().info()}
|
||||
|
||||
async def upload_file(self, file: QuivrFile, exists_ok: bool = False) -> None:
|
||||
dst_path = os.path.join(
|
||||
self.dir_path, str(file.brain_id), f"{file.id}{file.file_extension}"
|
||||
)
|
||||
|
||||
# TODO(@aminediro): Check hash of file not file path
|
||||
if os.path.exists(dst_path) and not exists_ok:
|
||||
raise FileExistsError("file already exists")
|
||||
if file.file_md5 in self.hashes and not exists_ok:
|
||||
raise FileExistsError(f"file {file.original_filename} already uploaded")
|
||||
|
||||
if self.copy_flag:
|
||||
shutil.copy2(file.path, dst_path)
|
||||
@ -40,28 +49,31 @@ class LocalStorage(StorageBase):
|
||||
|
||||
file.path = Path(dst_path)
|
||||
self.files.append(file)
|
||||
self.hashes.add(file.file_md5)
|
||||
|
||||
def get_files(self) -> list[QuivrFile]:
|
||||
async def get_files(self) -> list[QuivrFile]:
|
||||
return self.files
|
||||
|
||||
def remove_file(self, file_id: UUID) -> None:
|
||||
async def remove_file(self, file_id: UUID) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class TransparentStorage(StorageBase):
    """Transparent Storage."""

    name: str = "transparent_storage"

    def __init__(self):
        # Uploaded files keyed by their id; nothing is copied or moved.
        self.id_files = {}

    async def upload_file(self, file: QuivrFile, exists_ok: bool = False) -> None:
        """Record ``file`` under its id, replacing any entry with the same id."""
        self.id_files[file.id] = file

    def nb_files(self) -> int:
        """Return how many files are currently recorded."""
        return len(self.id_files)

    async def remove_file(self, file_id: UUID) -> None:
        """Not supported yet; always raises ``NotImplementedError``."""
        raise NotImplementedError

    async def get_files(self) -> list[QuivrFile]:
        """Return the recorded files as a new list."""
        return [*self.id_files.values()]
|
||||
|
@ -1,18 +1,39 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from uuid import UUID
|
||||
|
||||
from quivr_core.brain.info import StorageInfo
|
||||
from quivr_core.storage.local_storage import QuivrFile
|
||||
|
||||
|
||||
class StorageBase(ABC):
    """Abstract interface implemented by the file storage backends.

    Subclasses must define a non-empty ``name`` class attribute and implement
    ``nb_files``, ``get_files``, ``upload_file`` and ``remove_file``.
    """

    # Identifier of the backend; reported as ``storage_type`` by ``info``.
    name: str

    def __init_subclass__(cls, **kwargs):
        """Reject subclasses that do not define the required class attributes.

        Raises:
            TypeError: If ``name`` is missing or falsy on the subclass.
        """
        for required in ("name",):
            # ``name`` is only annotated on this base class, so a subclass may
            # lack it altogether. Defaulting to None makes that case raise the
            # TypeError below rather than an AttributeError.
            if not getattr(cls, required, None):
                raise TypeError(
                    f"Can't instantiate abstract class {cls.__name__} without {required} attribute defined"
                )
        return super().__init_subclass__(**kwargs)

    @abstractmethod
    def nb_files(self) -> int:
        """Return the number of files held by the storage."""
        raise Exception("Unimplemented nb_files method")

    @abstractmethod
    async def get_files(self) -> list[QuivrFile]:
        """Return the files held by the storage."""
        raise Exception("Unimplemented get_files method")

    @abstractmethod
    async def upload_file(self, file: QuivrFile, exists_ok: bool = False) -> None:
        """Add ``file`` to the storage; ``exists_ok`` is interpreted by the backend."""
        raise Exception("Unimplemented upload_file method")

    @abstractmethod
    async def remove_file(self, file_id: UUID) -> None:
        """Remove the file identified by ``file_id`` from the storage."""
        raise Exception("Unimplemented remove_file method")

    def info(self) -> StorageInfo:
        """Summarise the storage as a ``StorageInfo`` (backend name and file count)."""
        return StorageInfo(
            storage_type=self.name,
            n_files=self.nb_files(),
        )
|
||||
|
@ -1,5 +1,7 @@
|
||||
import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from langchain_core.embeddings import DeterministicFakeEmbedding
|
||||
@ -10,6 +12,39 @@ from langchain_core.vectorstores import InMemoryVectorStore
|
||||
|
||||
from quivr_core.config import LLMEndpointConfig
|
||||
from quivr_core.llm import LLMEndpoint
|
||||
from quivr_core.storage.file import FileExtension, QuivrFile
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
def temp_data_file(tmp_path):
    """Write a fixed sentence to ``data.txt`` under ``tmp_path`` and return its path."""
    target = tmp_path / "data.txt"
    target.write_text("This is some test data.")
    return target
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
def quivr_txt(temp_data_file):
    """Return a ``QuivrFile`` wrapping the temporary text file, with a dummy MD5."""
    file_id = uuid4()
    owner_brain_id = uuid4()
    return QuivrFile(
        id=file_id,
        brain_id=owner_brain_id,
        original_filename=temp_data_file.name,
        path=temp_data_file,
        file_extension=FileExtension.txt,
        file_md5="123",
    )
|
||||
|
||||
|
||||
@pytest.fixture
def quivr_pdf():
    """Return a ``QuivrFile`` pointing at the ``dummy.pdf`` test asset, with a dummy MD5."""
    file_id = uuid4()
    owner_brain_id = uuid4()
    # The path is relative, so it is resolved against the working directory.
    pdf_path = Path("./tests/processor/data/dummy.pdf")
    return QuivrFile(
        id=file_id,
        brain_id=owner_brain_id,
        original_filename="dummy.pdf",
        path=pdf_path,
        file_extension=FileExtension.pdf,
        file_md5="13bh234jh234",
    )
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@ -36,14 +71,6 @@ def openai_api_key():
|
||||
os.environ["OPENAI_API_KEY"] = "abcd"
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def temp_data_file(tmp_path):
|
||||
data = "This is some test data."
|
||||
temp_file = tmp_path / "data.txt"
|
||||
temp_file.write_text(data)
|
||||
return temp_file
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def answers():
|
||||
return [f"answer_{i}" for i in range(10)]
|
||||
|
67
backend/core/tests/processor/test_registry.py
Normal file
67
backend/core/tests/processor/test_registry.py
Normal file
@ -0,0 +1,67 @@
|
||||
import pytest
|
||||
from langchain_core.documents import Document
|
||||
|
||||
from quivr_core import registry
|
||||
from quivr_core.processor.processor_base import ProcessorBase
|
||||
from quivr_core.processor.registry import (
|
||||
_import_class,
|
||||
get_processor_class,
|
||||
register_processor,
|
||||
)
|
||||
from quivr_core.processor.simple_txt_processor import SimpleTxtProcessor
|
||||
from quivr_core.processor.tika_processor import TikaProcessor
|
||||
from quivr_core.storage.file import FileExtension, QuivrFile
|
||||
|
||||
|
||||
def test_get_processor_cls():
    """The txt and pdf extensions resolve to SimpleTxtProcessor and TikaProcessor."""
    expected = (
        (FileExtension.txt, SimpleTxtProcessor),
        (FileExtension.pdf, TikaProcessor),
    )
    for extension, processor_cls in expected:
        assert get_processor_class(extension) == processor_cls
|
||||
|
||||
|
||||
def test__import_class():
    """``_import_class`` resolves a dotted class path and raises TypeError otherwise."""
    resolved = _import_class("quivr_core.processor.tika_processor.TikaProcessor")
    assert resolved == TikaProcessor

    # (dotted path, pattern the TypeError message must match)
    rejected = (
        ("quivr_core.processor", r".* is not a class"),
        ("quivr_core.Brain", r".* ProcessorBase"),
    )
    for dotted_path, message_pattern in rejected:
        with pytest.raises(TypeError, match=message_pattern):
            _import_class(dotted_path)
|
||||
|
||||
|
||||
def test_get_processor_cls_error():
    """Looking up ``.docx`` raises ``ValueError``."""
    unregistered_extension = ".docx"
    with pytest.raises(ValueError):
        get_processor_class(unregistered_extension)
|
||||
|
||||
|
||||
def test_register_new_proc():
    """Registering a processor for ``.test`` adds one registry entry that resolves to it."""
    size_before = len(registry)

    class TestProcessor(ProcessorBase):
        supported_extensions = [".test"]

        async def process_file(self, file: QuivrFile) -> list[Document]:
            return []

    new_extension = ".test"
    register_processor(new_extension, TestProcessor)

    assert len(registry) == size_before + 1
    assert get_processor_class(new_extension) == TestProcessor
|
||||
|
||||
|
||||
def test_register_override_proc():
    """``override=True`` replaces the processor registered for ``.pdf``.

    The previously registered processor is put back afterwards, so the
    override cannot leak into tests that run later in the same session.
    """

    class TestProcessor(ProcessorBase):
        supported_extensions = [".pdf"]

        async def process_file(self, file: QuivrFile) -> list[Document]:
            return []

    previous_cls = get_processor_class(FileExtension.pdf)
    try:
        register_processor(".pdf", TestProcessor, override=True)
        cls = get_processor_class(FileExtension.pdf)
        assert cls == TestProcessor
    finally:
        # Restore the original mapping even when the assertion above fails.
        register_processor(".pdf", previous_cls, override=True)
|
34
backend/core/tests/processor/test_simple_txt_processor.py
Normal file
34
backend/core/tests/processor/test_simple_txt_processor.py
Normal file
@ -0,0 +1,34 @@
|
||||
import pytest
|
||||
from langchain_core.documents import Document
|
||||
|
||||
from quivr_core.processor.simple_txt_processor import (
|
||||
SimpleTxtProcessor,
|
||||
recursive_character_splitter,
|
||||
)
|
||||
from quivr_core.processor.splitter import SplitterConfig
|
||||
from quivr_core.storage.file import FileExtension
|
||||
|
||||
|
||||
def test_recursive_character_splitter():
    """Chunks of size 2 with overlap 1 slide one character at a time and keep the metadata."""
    source = Document(page_content="abcdefgh", metadata={"key": "value"})

    chunks = recursive_character_splitter(source, chunk_size=2, chunk_overlap=1)

    contents = [chunk.page_content for chunk in chunks]
    assert contents == ["ab", "bc", "cd", "de", "ef", "fg", "gh"]
    # Every chunk carries metadata equal to the source document's.
    assert all(chunk.metadata == source.metadata for chunk in chunks)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_simple_processor(quivr_pdf, quivr_txt):
    """SimpleTxtProcessor rejects a pdf and returns a short txt file as one chunk."""
    config = SplitterConfig(chunk_size=100, chunk_overlap=20)
    proc = SimpleTxtProcessor(splitter_config=config)
    assert proc.supported_extensions == [FileExtension.txt]

    with pytest.raises(ValueError):
        await proc.process_file(quivr_pdf)

    chunks = await proc.process_file(quivr_txt)

    assert len(chunks) == 1
    first_chunk = chunks[0]
    assert first_chunk.page_content == "This is some test data."
|
@ -1,38 +1,23 @@
|
||||
from pathlib import Path
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from quivr_core.processor.pdf_processor import TikaParser
|
||||
from quivr_core.storage.file import QuivrFile
|
||||
from quivr_core.processor.tika_processor import TikaProcessor
|
||||
|
||||
# TODO: TIKA server should be set
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def pdf():
|
||||
return QuivrFile(
|
||||
id=uuid4(),
|
||||
brain_id=uuid4(),
|
||||
original_filename="dummy.pdf",
|
||||
path=Path("./tests/processor/data/dummy.pdf"),
|
||||
file_extension=".pdf",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_process_file(quivr_pdf):
    """TikaProcessor extracts the text of the dummy pdf."""
    processor = TikaProcessor()
    parsed = await processor.process_file(quivr_pdf)

    assert len(parsed) > 0
    first_page_text = parsed[0].page_content.strip("\n")
    assert first_page_text == "Dummy PDF download"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_send_parse_tika_exception(quivr_pdf):
    """Processing with an invalid Tika URL raises ``RuntimeError``."""
    # TODO: Mock correct tika for retries
    tparser = TikaProcessor(tika_url="test.test")
    with pytest.raises(RuntimeError):
        # Nothing may follow this call inside the block: once it raises, any
        # later statement would be unreachable.
        await tparser.process_file(quivr_pdf)
|
||||
|
45
backend/core/tests/processor/test_txt_processor.py
Normal file
45
backend/core/tests/processor/test_txt_processor.py
Normal file
@ -0,0 +1,45 @@
|
||||
from importlib.metadata import version
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from quivr_core.processor.splitter import SplitterConfig
|
||||
from quivr_core.processor.txt_processor import TikTokenTxtProcessor
|
||||
from quivr_core.storage.file import FileExtension, QuivrFile
|
||||
|
||||
# TODO: TIKA server should be set
|
||||
|
||||
|
||||
@pytest.fixture
def txt_qfile(temp_data_file):
    """Return a ``QuivrFile`` named ``data.txt`` that wraps the temporary text file."""
    file_id = uuid4()
    owner_brain_id = uuid4()
    return QuivrFile(
        id=file_id,
        brain_id=owner_brain_id,
        original_filename="data.txt",
        path=temp_data_file,
        file_extension=FileExtension.txt,
        file_md5="hash",
    )
|
||||
|
||||
|
||||
@pytest.mark.base
@pytest.mark.asyncio
async def test_process_txt(txt_qfile):
    """TikTokenTxtProcessor returns the file text and tags it with the expected metadata."""
    config = SplitterConfig(chunk_size=20, chunk_overlap=0)
    tparser = TikTokenTxtProcessor(splitter_config=config)

    doc = await tparser.process_file(txt_qfile)

    assert len(doc) > 0
    first_chunk = doc[0]
    assert first_chunk.page_content == "This is some test data."

    expected_subset = {
        "chunk_size": len(first_chunk.page_content),
        "chunk_overlap": 0,
        "parser_name": tparser.__class__.__name__,
        "quivr_core_version": version("quivr-core"),
        **txt_qfile.metadata,
    }
    # Subset check: the chunk may carry extra keys (such as its generated id).
    assert expected_subset.items() <= first_chunk.metadata.items()
|
@ -1,3 +1,4 @@
|
||||
from dataclasses import asdict
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
@ -16,8 +17,11 @@ def test_brain_empty_files():
|
||||
Brain.from_files(name="test_brain", file_paths=[])
|
||||
|
||||
|
||||
def test_brain_from_files_success(fake_llm: LLMEndpoint, embedder, temp_data_file):
|
||||
brain = Brain.from_files(
|
||||
@pytest.mark.asyncio
|
||||
async def test_brain_from_files_success(
|
||||
fake_llm: LLMEndpoint, embedder, temp_data_file
|
||||
):
|
||||
brain = await Brain.afrom_files(
|
||||
name="test_brain", file_paths=[temp_data_file], embedder=embedder, llm=fake_llm
|
||||
)
|
||||
assert brain.name == "test_brain"
|
||||
@ -29,7 +33,7 @@ def test_brain_from_files_success(fake_llm: LLMEndpoint, embedder, temp_data_fil
|
||||
|
||||
# storage
|
||||
assert isinstance(brain.storage, TransparentStorage)
|
||||
assert len(brain.storage.get_files()) == 1
|
||||
assert len(await brain.storage.get_files()) == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@ -39,7 +43,7 @@ async def test_brain_from_langchain_docs(embedder):
|
||||
name="test", langchain_documents=[chunk], embedder=embedder
|
||||
)
|
||||
# No appended files
|
||||
assert len(brain.storage.get_files()) == 0
|
||||
assert len(await brain.storage.get_files()) == 0
|
||||
assert len(brain.chat_history) == 0
|
||||
|
||||
|
||||
@ -90,3 +94,28 @@ async def test_brain_ask_streaming(
|
||||
response += chunk.answer
|
||||
|
||||
assert response == answers[1]
|
||||
|
||||
|
||||
def test_brain_info_empty(fake_llm: LLMEndpoint, embedder, mem_vector_store):
    """``Brain.info()`` of a brain without files reports its id, name, storage, chats and llm."""
    storage = TransparentStorage()
    brain_id = uuid4()
    brain = Brain(
        name="test",
        id=brain_id,
        llm=fake_llm,
        embedder=embedder,
        storage=storage,
        vector_db=mem_vector_store,
    )

    # Read the info first, then assemble the expectation from the same objects.
    actual = asdict(brain.info())
    expected = {
        "brain_id": brain_id,
        "brain_name": "test",
        "files_info": asdict(storage.info()),
        "chats_info": {
            "nb_chats": 1,  # start with a default chat
            "current_default_chat": brain.default_chat.id,
            "current_chat_history_length": 0,
        },
        "llm_info": asdict(fake_llm.info()),
    }

    assert actual == expected
|
||||
|
0
backend/core/tests/test_quivr_file.py
Normal file
0
backend/core/tests/test_quivr_file.py
Normal file
Loading…
Reference in New Issue
Block a user