Mirror of https://github.com/StanGirard/quivr.git (synced 2024-11-26 12:55:01 +03:00)
feat: quivr core brain info + processors registry + (#2877)
# Description

- Created registry processor logic for automagically adding processors to quivr_core, based on entry points
- Added a langchain_community-free `SimpleTxtParser` for the quivr_core base package
- Added tests
- Added brain_info
- Enriched parsed documents' metadata based on quivr_file metadata
- Used Rich for `Brain.print_info()` to get a better output:

![image](https://github.com/user-attachments/assets/dd9f2f03-d7d7-4be0-ba6c-3fe38e11c40f)
This commit is contained in: parent 3b68855a83, commit 3001fa1475
@@ -1,19 +1,42 @@
 from langchain_core.embeddings import DeterministicFakeEmbedding
 from langchain_core.language_models import FakeListChatModel
+from rich.console import Console
+from rich.panel import Panel
+from rich.prompt import Prompt

 from quivr_core import Brain
-from quivr_core.processor.default_parsers import DEFAULT_PARSERS
-from quivr_core.processor.pdf_processor import TikaParser
+from quivr_core.config import LLMEndpointConfig
+from quivr_core.llm.llm_endpoint import LLMEndpoint

 if __name__ == "__main__":
-    pdf_paths = ["../tests/processor/data/dummy.pdf"]
     brain = Brain.from_files(
         name="test_brain",
-        file_paths=[],
-        llm=FakeListChatModel(responses=["good"]),
+        file_paths=["tests/processor/data/dummy.pdf"],
+        llm=LLMEndpoint(
+            llm=FakeListChatModel(responses=["good"]),
+            llm_config=LLMEndpointConfig(model="fake_model", llm_base_url="local"),
+        ),
         embedder=DeterministicFakeEmbedding(size=20),
-        processors_mapping={
-            **DEFAULT_PARSERS,
-            ".pdf": TikaParser(),
-        },
     )
+    # Check brain info
+    brain.print_info()
+
+    console = Console()
+    console.print(Panel.fit("Ask your brain !", style="bold magenta"))
+
+    while True:
+        # Get user input
+        question = Prompt.ask("[bold cyan]Question[/bold cyan]")
+
+        # Check if user wants to exit
+        if question.lower() == "exit":
+            console.print(Panel("Goodbye!", style="bold yellow"))
+            break
+
+        answer = brain.ask(question)
+        # Print the answer with typing effect
+        console.print(f"[bold green]Quivr Assistant[/bold green]: {answer.answer}")
+
+        console.print("-" * console.width)
+
+    brain.print_info()

@@ -12,3 +12,5 @@ if __name__ == "__main__":
     answer = brain.ask("Property of gold?")

     print("answer :", answer.answer)
+
+    print("brain information: ", brain)
backend/core/poetry.lock (generated, 75 lines changed)
@@ -1,4 +1,4 @@
-# This file is automatically @generated by Poetry 1.8.0 and should not be changed by hand.
+# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand.

 [[package]]
 name = "aiofiles"

@@ -1189,13 +1189,13 @@ tenacity = ">=8.1.0,<8.4.0 || >8.4.0,<9.0.0"

 [[package]]
 name = "langchain-core"
-version = "0.2.19"
+version = "0.2.20"
 description = "Building applications with LLMs through composability"
 optional = false
 python-versions = "<4.0,>=3.8.1"
 files = [
-    {file = "langchain_core-0.2.19-py3-none-any.whl", hash = "sha256:5b3cd34395be274c89e822c84f0e03c4da14168c177a83921c5b9414ac7a0651"},
-    {file = "langchain_core-0.2.19.tar.gz", hash = "sha256:13043a83e5c9ab58b9f5ce2a56896e7e88b752e8891b2958960a98e71801471e"},
+    {file = "langchain_core-0.2.20-py3-none-any.whl", hash = "sha256:16cc4da6f7ebf33accea7af45a70480733dc852ab291030fb6924865bd7caf76"},
+    {file = "langchain_core-0.2.20.tar.gz", hash = "sha256:a66c439e085d8c75f822f7650a5551d17bada4003521173c763d875d949e4ed5"},
 ]

 [package.dependencies]

@@ -1241,13 +1241,13 @@ langchain-core = ">=0.2.10,<0.3.0"

 [[package]]
 name = "langsmith"
-version = "0.1.85"
+version = "0.1.88"
 description = "Client library to connect to the LangSmith LLM Tracing and Evaluation Platform."
 optional = false
 python-versions = "<4.0,>=3.8.1"
 files = [
-    {file = "langsmith-0.1.85-py3-none-any.whl", hash = "sha256:c1f94384f10cea96f7b4d33fd3db7ec180c03c7468877d50846f881d2017ff94"},
-    {file = "langsmith-0.1.85.tar.gz", hash = "sha256:acff31f9e53efa48586cf8e32f65625a335c74d7c4fa306d1655ac18452296f6"},
+    {file = "langsmith-0.1.88-py3-none-any.whl", hash = "sha256:460ebb7de440afd150fcea8f54ca8779821f2228cd59e149e5845c9dbe06db16"},
+    {file = "langsmith-0.1.88.tar.gz", hash = "sha256:28a07dec19197f4808aa2628d5a3ccafcbe14cc137aef0e607bbd128e7907821"},
 ]

 [package.dependencies]

@@ -1258,6 +1258,30 @@ pydantic = [
 ]
 requests = ">=2,<3"

+[[package]]
+name = "markdown-it-py"
+version = "3.0.0"
+description = "Python port of markdown-it. Markdown parsing, done right!"
+optional = false
+python-versions = ">=3.8"
+files = [
+    {file = "markdown-it-py-3.0.0.tar.gz", hash = "sha256:e3f60a94fa066dc52ec76661e37c851cb232d92f9886b15cb560aaada2df8feb"},
+    {file = "markdown_it_py-3.0.0-py3-none-any.whl", hash = "sha256:355216845c60bd96232cd8d8c40e8f9765cc86f46880e43a8fd22dc1a1a8cab1"},
+]
+
+[package.dependencies]
+mdurl = ">=0.1,<1.0"
+
+[package.extras]
+benchmarking = ["psutil", "pytest", "pytest-benchmark"]
+code-style = ["pre-commit (>=3.0,<4.0)"]
+compare = ["commonmark (>=0.9,<1.0)", "markdown (>=3.4,<4.0)", "mistletoe (>=1.0,<2.0)", "mistune (>=2.0,<3.0)", "panflute (>=2.3,<3.0)"]
+linkify = ["linkify-it-py (>=1,<3)"]
+plugins = ["mdit-py-plugins"]
+profiling = ["gprof2dot"]
+rtd = ["jupyter_sphinx", "mdit-py-plugins", "myst-parser", "pyyaml", "sphinx", "sphinx-copybutton", "sphinx-design", "sphinx_book_theme"]
+testing = ["coverage", "pytest", "pytest-cov", "pytest-regressions"]
+
 [[package]]
 name = "marshmallow"
 version = "3.21.3"

@@ -1302,6 +1326,17 @@ files = [
     {file = "mccabe-0.7.0.tar.gz", hash = "sha256:348e0240c33b60bbdf4e523192ef919f28cb2c3d7d5c7794f74009290f236325"},
 ]

+[[package]]
+name = "mdurl"
+version = "0.1.2"
+description = "Markdown URL utilities"
+optional = false
+python-versions = ">=3.7"
+files = [
+    {file = "mdurl-0.1.2-py3-none-any.whl", hash = "sha256:84008a41e51615a49fc9966191ff91509e3c40b939176e643fd50a5c2196b8f8"},
+    {file = "mdurl-0.1.2.tar.gz", hash = "sha256:bb413d29f5eea38f31dd4754dd7377d4465116fb207585f97bf925588687c1ba"},
+]
+
 [[package]]
 name = "multidict"
 version = "6.0.5"

@@ -1527,13 +1562,13 @@ files = [

 [[package]]
 name = "openai"
-version = "1.35.13"
+version = "1.35.14"
 description = "The official Python library for the openai API"
 optional = true
 python-versions = ">=3.7.1"
 files = [
-    {file = "openai-1.35.13-py3-none-any.whl", hash = "sha256:36ec3e93e0d1f243f69be85c89b9221a471c3e450dfd9df16c9829e3cdf63e60"},
-    {file = "openai-1.35.13.tar.gz", hash = "sha256:c684f3945608baf7d2dcc0ef3ee6f3e27e4c66f21076df0b47be45d57e6ae6e4"},
+    {file = "openai-1.35.14-py3-none-any.whl", hash = "sha256:adadf8c176e0b8c47ad782ed45dc20ef46438ee1f02c7103c4155cff79c8f68b"},
+    {file = "openai-1.35.14.tar.gz", hash = "sha256:394ba1dfd12ecec1d634c50e512d24ff1858bbc2674ffcce309b822785a058de"},
 ]

 [package.dependencies]

@@ -2328,6 +2363,24 @@ urllib3 = ">=1.21.1,<3"
 socks = ["PySocks (>=1.5.6,!=1.5.7)"]
 use-chardet-on-py3 = ["chardet (>=3.0.2,<6)"]

+[[package]]
+name = "rich"
+version = "13.7.1"
+description = "Render rich text, tables, progress bars, syntax highlighting, markdown and more to the terminal"
+optional = false
+python-versions = ">=3.7.0"
+files = [
+    {file = "rich-13.7.1-py3-none-any.whl", hash = "sha256:4edbae314f59eb482f54e9e30bf00d33350aaa94f4bfcd4e9e3110e64d0d7222"},
+    {file = "rich-13.7.1.tar.gz", hash = "sha256:9be308cb1fe2f1f57d67ce99e95af38a1e2bc71ad9813b0e247cf7ffbcc3a432"},
+]
+
+[package.dependencies]
+markdown-it-py = ">=2.2.0"
+pygments = ">=2.13.0,<3.0.0"
+
+[package.extras]
+jupyter = ["ipywidgets (>=7.5.1,<9)"]
+
 [[package]]
 name = "ruff"
 version = "0.4.10"

@@ -2788,4 +2841,4 @@ pdf = []
 [metadata]
 lock-version = "2.0"
 python-versions = "^3.11"
-content-hash = "f5ab0d3f93f3bd517382a8c4207a4a480f4a6dcb0aeaf60501c379a4a91a3ad0"
+content-hash = "cc08905f149df2f415e1d00010e5b89a371efdcf4059855d597b1b6e9973a536"
backend/core/pyproject.toml

@@ -15,6 +15,7 @@ aiofiles = ">=23.0.0,<25.0.0"
 faiss-cpu = { version = "^1.8.0.post1", optional = true }
 langchain-community = { version = "^0.2.6", optional = true }
 langchain-openai = { version = "^0.1.14", optional = true }
+rich = "^13.7.1"

 [tool.poetry.extras]
 base = ["langchain-community", "faiss-cpu", "langchain-openai"]

@@ -81,6 +82,10 @@ known-first-party = []
 [tool.pytest.ini_options]
 addopts = "--tb=short -ra -v"
 filterwarnings = ["ignore::DeprecationWarning"]
+markers = [
+    "slow: marks tests as slow (deselect with '-m \"not slow\"')",
+    "base: these tests require quivr-core with extra `base` to be installed",
+]

 [build-system]
 requires = ["poetry-core"]
backend/core/quivr_core/__init__.py

@@ -1,3 +1,35 @@
-from .brain import Brain
+from importlib.metadata import entry_points

-__all__ = ["Brain"]
+from .brain import Brain
+from .processor.registry import register_processor, registry
+
+__all__ = ["Brain", "registry", "register_processor"]
+
+
+def register_entries():
+    if entry_points is not None:
+        try:
+            eps = entry_points()
+        except TypeError:
+            pass  # importlib-metadata < 0.8
+        else:
+            if hasattr(eps, "select"):  # Python 3.10+ / importlib_metadata >= 3.9.0
+                processors = eps.select(group="quivr_core.processor")
+            else:
+                processors = eps.get("quivr_core.processor", [])
+            registered_names = set()
+            for spec in processors:
+                err_msg = f"Unable to load processor from {spec}"
+                name = spec.name
+                if name in registered_names:
+                    continue
+                registered_names.add(name)
+                register_processor(
+                    name,
+                    spec.value.replace(":", "."),
+                    errtxt=err_msg,
+                    override=True,
+                )
+
+
+register_entries()
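For context, this is how a third-party package could hook into the new registry: it declares an entry point in the `quivr_core.processor` group, and `register_entries()` registers the advertised class lazily at import time. The sketch below reproduces that effect by hand; the `my_pkg.processors.MarkdownProcessor` path and the `.md` extension are hypothetical.

```python
# Hypothetical sketch of what register_entries() does for an entry point
# declared by a third-party package, e.g. in its pyproject.toml:
#   [tool.poetry.plugins."quivr_core.processor"]
#   ".md" = "my_pkg.processors:MarkdownProcessor"
from quivr_core.processor.registry import register_processor

# register_entries() rewrites "my_pkg.processors:MarkdownProcessor" into a
# dotted path and registers it lazily; the import only happens on first use.
register_processor(".md", "my_pkg.processors.MarkdownProcessor", override=True)
```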
backend/core/quivr_core/brain/__init__.py (new file, 3 lines)

@@ -0,0 +1,3 @@
+from .brain import Brain
+
+__all__ = ["Brain"]
backend/core/quivr_core/brain/brain.py

@@ -1,29 +1,34 @@
 import asyncio
 import logging
 from pathlib import Path
-from typing import Any, AsyncGenerator, Callable, Dict, Mapping, Self
+from pprint import PrettyPrinter
+from typing import Any, AsyncGenerator, Callable, Dict, Self
 from uuid import UUID, uuid4

 from langchain_core.documents import Document
 from langchain_core.embeddings import Embeddings
 from langchain_core.messages import AIMessage, HumanMessage
 from langchain_core.vectorstores import VectorStore
+from rich.console import Console
+from rich.panel import Panel

+from quivr_core.brain.info import BrainInfo, ChatHistoryInfo
 from quivr_core.chat import ChatHistory
 from quivr_core.config import LLMEndpointConfig, RAGConfig
 from quivr_core.llm import LLMEndpoint
 from quivr_core.models import ParsedRAGChunkResponse, ParsedRAGResponse, SearchResult
-from quivr_core.processor.default_parsers import DEFAULT_PARSERS
-from quivr_core.processor.processor_base import ProcessorBase
+from quivr_core.processor.registry import get_processor_class
 from quivr_core.quivr_rag import QuivrQARAG
-from quivr_core.storage.file import QuivrFile
+from quivr_core.storage.file import load_qfile
 from quivr_core.storage.local_storage import TransparentStorage
 from quivr_core.storage.storage_base import StorageBase

 logger = logging.getLogger("quivr_core")


-async def _default_vectordb(docs: list[Document], embedder: Embeddings) -> VectorStore:
+async def _build_default_vectordb(
+    docs: list[Document], embedder: Embeddings
+) -> VectorStore:
     try:
         from langchain_community.vectorstores import FAISS

@@ -38,7 +43,7 @@ async def _default_vectordb(docs: list[Document], embedder: Embeddings) -> Vecto

     except ImportError as e:
         raise ImportError(
-            "Please provide a valid vectore store or install quivr-core['base'] package for using the default one."
+            "Please provide a valid vector store or install quivr-core['base'] package for using the default one."
         ) from e


@@ -67,16 +72,15 @@ def _default_llm() -> LLMEndpoint:
     ) from e


-async def _process_files(
-    storage: StorageBase,
-    skip_file_error: bool,
-    processors_mapping: Mapping[str, ProcessorBase],
+async def process_files(
+    storage: StorageBase, skip_file_error: bool, **processor_kwargs: dict[str, Any]
 ) -> list[Document]:
     knowledge = []
-    for file in storage.get_files():
+    for file in await storage.get_files():
         try:
             if file.file_extension:
-                processor = processors_mapping[file.file_extension]
+                processor_cls = get_processor_class(file.file_extension)
+                processor = processor_cls(**processor_kwargs)
                 docs = await processor.process_file(file)
                 knowledge.extend(docs)
             else:

@@ -118,11 +122,38 @@ class Brain:
         self.vector_db = vector_db
         self.embedder = embedder

+    def __repr__(self) -> str:
+        pp = PrettyPrinter(width=80, depth=None, compact=False, sort_dicts=False)
+        return pp.pformat(self.info())
+
+    def print_info(self):
+        console = Console()
+        tree = self.info().to_tree()
+        panel = Panel(tree, title="Brain Info", expand=False, border_style="bold")
+        console.print(panel)
+
+    def info(self) -> BrainInfo:
+        # TODO: dim of embedding
+        # "embedder": {},
+        chats_info = ChatHistoryInfo(
+            nb_chats=len(self._chats),
+            current_default_chat=self.default_chat.id,
+            current_chat_history_length=len(self.default_chat),
+        )
+
+        return BrainInfo(
+            brain_id=self.id,
+            brain_name=self.name,
+            files_info=self.storage.info(),
+            chats_info=chats_info,
+            llm_info=self.llm.info(),
+        )
+
     @property
-    def chat_history(self):
+    def chat_history(self) -> ChatHistory:
         return self.default_chat

-    def _init_chats(self):
+    def _init_chats(self) -> Dict[UUID, ChatHistory]:
         chat_id = uuid4()
         default_chat = ChatHistory(chat_id=chat_id, brain_id=self.id)
         return {chat_id: default_chat}

@@ -137,7 +168,6 @@ class Brain:
         storage: StorageBase = TransparentStorage(),
         llm: LLMEndpoint | None = None,
         embedder: Embeddings | None = None,
-        processors_mapping: Mapping[str, ProcessorBase] = DEFAULT_PARSERS,
         skip_file_error: bool = False,
     ):
         if llm is None:

@@ -148,20 +178,20 @@ class Brain:

         brain_id = uuid4()

+        # TODO: run in parallel using tasks
         for path in file_paths:
-            file = QuivrFile.from_path(brain_id, path)
-            storage.upload_file(file)
+            file = await load_qfile(brain_id, path)
+            await storage.upload_file(file)

         # Parse files
-        docs = await _process_files(
+        docs = await process_files(
             storage=storage,
-            processors_mapping=processors_mapping,
             skip_file_error=skip_file_error,
         )

         # Building brain's vectordb
         if vector_db is None:
-            vector_db = await _default_vectordb(docs, embedder)
+            vector_db = await _build_default_vectordb(docs, embedder)
         else:
             await vector_db.aadd_documents(docs)

@@ -184,10 +214,10 @@ class Brain:
         storage: StorageBase = TransparentStorage(),
         llm: LLMEndpoint | None = None,
         embedder: Embeddings | None = None,
-        processors_mapping: Mapping[str, ProcessorBase] = DEFAULT_PARSERS,
         skip_file_error: bool = False,
     ) -> Self:
-        return asyncio.run(
+        loop = asyncio.get_event_loop()
+        return loop.run_until_complete(
             cls.afrom_files(
                 name=name,
                 file_paths=file_paths,

@@ -195,7 +225,6 @@ class Brain:
                 storage=storage,
                 llm=llm,
                 embedder=embedder,
-                processors_mapping=processors_mapping,
                 skip_file_error=skip_file_error,
             )
         )

@@ -221,7 +250,7 @@ class Brain:

         # Building brain's vectordb
         if vector_db is None:
-            vector_db = await _default_vectordb(langchain_documents, embedder)
+            vector_db = await _build_default_vectordb(langchain_documents, embedder)
         else:
             await vector_db.aadd_documents(langchain_documents)
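A minimal sketch of the refactored ingestion flow, assuming a local Tika server is reachable for the `.pdf` processor: files are uploaded into a storage, and `process_files` resolves a processor class from the registry for each file extension (extra keyword arguments are forwarded to the processor's constructor).

```python
# Sketch only: exercises load_qfile -> storage -> process_files end to end.
import asyncio
from uuid import uuid4

from quivr_core.brain.brain import process_files
from quivr_core.storage.file import load_qfile
from quivr_core.storage.local_storage import TransparentStorage


async def main():
    storage = TransparentStorage()
    qfile = await load_qfile(uuid4(), "tests/processor/data/dummy.pdf")
    await storage.upload_file(qfile)
    # The processor for ".pdf" is looked up in the registry and instantiated here.
    docs = await process_files(storage=storage, skip_file_error=False)
    print(len(docs))


asyncio.run(main())
```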
backend/core/quivr_core/brain/info.py (new file, 74 lines)

@@ -0,0 +1,74 @@
+from dataclasses import dataclass
+from uuid import UUID
+
+from rich.tree import Tree
+
+
+@dataclass
+class ChatHistoryInfo:
+    nb_chats: int
+    current_default_chat: UUID
+    current_chat_history_length: int
+
+    def add_to_tree(self, chats_tree: Tree):
+        chats_tree.add(f"Number of Chats: [bold]{self.nb_chats}[/bold]")
+        chats_tree.add(
+            f"Current Default Chat: [bold magenta]{self.current_default_chat}[/bold magenta]"
+        )
+        chats_tree.add(
+            f"Current Chat History Length: [bold]{self.current_chat_history_length}[/bold]"
+        )
+
+
+@dataclass
+class LLMInfo:
+    model: str
+    llm_base_url: str
+    temperature: float
+    max_tokens: int
+    supports_function_calling: int
+
+    def add_to_tree(self, llm_tree: Tree):
+        llm_tree.add(f"Model: [italic]{self.model}[/italic]")
+        llm_tree.add(f"Base URL: [underline]{self.llm_base_url}[/underline]")
+        llm_tree.add(f"Temperature: [bold]{self.temperature}[/bold]")
+        llm_tree.add(f"Max Tokens: [bold]{self.max_tokens}[/bold]")
+
+        func_call_color = "green" if self.supports_function_calling else "red"
+        llm_tree.add(
+            f"Supports Function Calling: [bold {func_call_color}]{self.supports_function_calling}[/bold {func_call_color}]"
+        )
+
+
+@dataclass
+class StorageInfo:
+    storage_type: str
+    n_files: int
+
+    def add_to_tree(self, files_tree: Tree):
+        files_tree.add(f"Storage Type: [italic]{self.storage_type}[/italic]")
+        files_tree.add(f"Number of Files: [bold]{self.n_files}[/bold]")
+
+
+@dataclass
+class BrainInfo:
+    brain_id: UUID
+    brain_name: str
+    files_info: StorageInfo
+    chats_info: ChatHistoryInfo
+    llm_info: LLMInfo
+
+    def to_tree(self):
+        tree = Tree("📊 Brain Information")
+        tree.add(f"🆔 ID: [bold cyan]{self.brain_id}[/bold cyan]")
+        tree.add(f"🧠 Brain Name: [bold green]{self.brain_name}[/bold green]")
+
+        files_tree = tree.add("📁 Files")
+        self.files_info.add_to_tree(files_tree)
+
+        chats_tree = tree.add("💬 Chats")
+        self.chats_info.add_to_tree(chats_tree)
+
+        llm_tree = tree.add("🤖 LLM")
+        self.llm_info.add_to_tree(llm_tree)
+        return tree
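A minimal sketch of how these dataclasses compose, assuming the field values below (the temperature and max_tokens are made up for illustration): `Brain.info()` fills a `BrainInfo`, and `to_tree()` renders it with rich, which is what `Brain.print_info()` wraps in a panel.

```python
from uuid import uuid4

from rich.console import Console

from quivr_core.brain.info import BrainInfo, ChatHistoryInfo, LLMInfo, StorageInfo

info = BrainInfo(
    brain_id=uuid4(),
    brain_name="demo",
    files_info=StorageInfo(storage_type="transparent_storage", n_files=1),
    chats_info=ChatHistoryInfo(
        nb_chats=1, current_default_chat=uuid4(), current_chat_history_length=0
    ),
    llm_info=LLMInfo(
        model="fake_model",
        llm_base_url="local",
        temperature=0.7,  # made-up value for the sketch
        max_tokens=256,   # made-up value for the sketch
        supports_function_calling=False,
    ),
)
Console().print(info.to_tree())
```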
backend/core/quivr_core/llm/llm_endpoint.py

@@ -1,6 +1,7 @@
 from langchain_core.language_models.chat_models import BaseChatModel
 from pydantic.v1 import SecretStr

+from quivr_core.brain.info import LLMInfo
 from quivr_core.config import LLMEndpointConfig
 from quivr_core.utils import model_supports_function_calling

@@ -35,3 +36,14 @@ class LLMEndpoint:

     def supports_func_calling(self) -> bool:
         return self._supports_func_calling
+
+    def info(self) -> LLMInfo:
+        return LLMInfo(
+            model=self._config.model,
+            llm_base_url=(
+                self._config.llm_base_url if self._config.llm_base_url else "openai"
+            ),
+            temperature=self._config.temperature,
+            max_tokens=self._config.max_tokens,
+            supports_function_calling=self.supports_func_calling(),
+        )
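A short sketch of the new `info()` accessor, using the same fake endpoint as the example script above:

```python
from langchain_core.language_models import FakeListChatModel

from quivr_core.config import LLMEndpointConfig
from quivr_core.llm import LLMEndpoint

endpoint = LLMEndpoint(
    llm=FakeListChatModel(responses=["good"]),
    llm_config=LLMEndpointConfig(model="fake_model", llm_base_url="local"),
)
# Returns an LLMInfo dataclass; llm_base_url falls back to "openai" when unset.
print(endpoint.info())
```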
backend/core/quivr_core/processor/default_parsers.py (deleted)

@@ -1,6 +0,0 @@
-from quivr_core.processor.processor_base import ProcessorBase
-from quivr_core.processor.txt_parser import TxtProcessor
-
-DEFAULT_PARSERS: dict[str, ProcessorBase] = {
-    ".txt": TxtProcessor(),
-}
backend/core/quivr_core/processor/processor_base.py

@@ -1,27 +1,20 @@
 from abc import ABC, abstractmethod
-from typing import Generic, TypeVar

 from langchain_core.documents import Document

-from quivr_core.storage.file import QuivrFile
+from quivr_core.storage.file import FileExtension, QuivrFile


+# TODO: processors should be cached somewhere ?
+# The processor should be cached by processor type
+# The cache should use a single
 class ProcessorBase(ABC):
-    supported_extensions: list[str]
+    supported_extensions: list[FileExtension | str]

     @abstractmethod
     async def process_file(self, file: QuivrFile) -> list[Document]:
-        pass
+        raise NotImplementedError

-
-P = TypeVar("P", bound=ProcessorBase)
-
-
-class ProcessorsMapping(Generic[P]):
-    def __init__(self, mapping: dict[str, P]) -> None:
-        # Create an empty list with items of type T
-        self.ext_parser: dict[str, P] = mapping
-
-    def add_parser(self, extension: str, parser: P):
-        # TODO: deal with existing ext keys
-        self.ext_parser[extension] = parser
+    def check_supported(self, file: QuivrFile):
+        if file.file_extension not in self.supported_extensions:
+            raise ValueError(f"can't process a file of type {file.file_extension}")
backend/core/quivr_core/processor/registry.py (new file, 107 lines)

@@ -0,0 +1,107 @@
+import importlib
+import types
+from typing import Type, TypedDict
+
+from quivr_core.storage.file import FileExtension
+
+from .processor_base import ProcessorBase
+
+_registry: dict[str, Type[ProcessorBase]] = {}
+
+# external, read only
+registry = types.MappingProxyType(_registry)
+
+
+class ProcEntry(TypedDict):
+    cls_mod: str
+    err: str | None
+
+
+# Register based on mimetypes
+known_processors: dict[FileExtension | str, ProcEntry] = {
+    FileExtension.txt: ProcEntry(
+        cls_mod="quivr_core.processor.simple_txt_processor.SimpleTxtProcessor",
+        err="Please install quivr_core[base] to use TikTokenTxtProcessor ",
+    ),
+    FileExtension.pdf: ProcEntry(
+        cls_mod="quivr_core.processor.tika_processor.TikaProcessor",
+        err=None,
+    ),
+}
+
+
+def get_processor_class(file_extension: FileExtension | str) -> Type[ProcessorBase]:
+    """Fetch processor class from registry
+
+    The dict ``known_processors`` maps file extensions to the locations
+    of processors that could process them.
+    Loading of these classes is *Lazy*. Appropriate import will happen
+    the first time we try to process some file type.
+
+    Some processors need additional dependencies. If the import fails
+    we return the "err" field of the ProcEntry in ``known_processors``.
+    """
+
+    if file_extension not in registry:
+        if file_extension not in known_processors:
+            raise ValueError(f"Extension not known: {file_extension}")
+        entry = known_processors[file_extension]
+        try:
+            register_processor(file_extension, _import_class(entry["cls_mod"]))
+        except ImportError as e:
+            raise ImportError(entry["err"]) from e
+
+    cls = registry[file_extension]
+    return cls
+
+
+def register_processor(
+    file_type: FileExtension | str,
+    proc_cls: str | Type[ProcessorBase],
+    override: bool = False,
+    errtxt=None,
+):
+    if isinstance(proc_cls, str):
+        if file_type in known_processors and override is False:
+            if proc_cls != known_processors[file_type]["cls_mod"]:
+                raise ValueError(
+                    f"Processor for ({file_type}) already in the registry and override is False"
+                )
+        else:
+            known_processors[file_type] = ProcEntry(
+                cls_mod=proc_cls,
+                err=errtxt or f"{proc_cls} import failed for processor of {file_type}",
+            )
+    else:
+        if file_type in registry and override is False:
+            if _registry[file_type] is not proc_cls:
+                raise ValueError(
+                    f"Processor for ({file_type}) already in the registry and override is False"
+                )
+        else:
+            _registry[file_type] = proc_cls
+
+
+def _import_class(full_mod_path: str):
+    if ":" in full_mod_path:
+        mod_name, name = full_mod_path.rsplit(":", 1)
+    else:
+        mod_name, name = full_mod_path.rsplit(".", 1)
+
+    mod = importlib.import_module(mod_name)
+
+    for cls in name.split("."):
+        mod = getattr(mod, cls)
+
+    if not isinstance(mod, type):
+        raise TypeError(f"{full_mod_path} is not a class")
+
+    if not issubclass(mod, ProcessorBase):
+        raise TypeError(f"{full_mod_path} is not a subclass of ProcessorBase ")
+
+    return mod
+
+
+def available_processors():
+    """Return a list of the known processors."""
+    return list(known_processors)
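Usage sketch, mirroring `tests/processor/test_registry.py`: registering a class directly (rather than a module path) makes it available immediately, and `get_processor_class` resolves it back. The `.csv` extension and `CsvProcessor` name are hypothetical.

```python
from langchain_core.documents import Document

from quivr_core.processor.processor_base import ProcessorBase
from quivr_core.processor.registry import get_processor_class, register_processor


class CsvProcessor(ProcessorBase):  # hypothetical processor for the sketch
    supported_extensions = [".csv"]

    async def process_file(self, file) -> list[Document]:
        return []


register_processor(".csv", CsvProcessor)
assert get_processor_class(".csv") is CsvProcessor
```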
backend/core/quivr_core/processor/simple_txt_processor.py (new file, 62 lines)

@@ -0,0 +1,62 @@
+from importlib.metadata import version
+from uuid import uuid4
+
+import aiofiles
+from langchain_core.documents import Document
+
+from quivr_core.processor.processor_base import ProcessorBase
+from quivr_core.processor.registry import FileExtension
+from quivr_core.processor.splitter import SplitterConfig
+from quivr_core.storage.file import QuivrFile
+
+
+def recursive_character_splitter(
+    doc: Document, chunk_size: int, chunk_overlap: int
+) -> list[Document]:
+    assert chunk_overlap < chunk_size, "chunk_overlap is greater than chunk_size"
+
+    if len(doc.page_content) <= chunk_size:
+        return [doc]
+
+    chunk = Document(page_content=doc.page_content[:chunk_size], metadata=doc.metadata)
+    remaining = Document(
+        page_content=doc.page_content[chunk_size - chunk_overlap :],
+        metadata=doc.metadata,
+    )
+
+    return [chunk] + recursive_character_splitter(remaining, chunk_size, chunk_overlap)
+
+
+class SimpleTxtProcessor(ProcessorBase):
+    supported_extensions = [FileExtension.txt]
+
+    def __init__(
+        self, splitter_config: SplitterConfig = SplitterConfig(), **kwargs
+    ) -> None:
+        super().__init__(**kwargs)
+        self.splitter_config = splitter_config
+
+    async def process_file(self, file: QuivrFile) -> list[Document]:
+        self.check_supported(file)
+        file_metadata = file.metadata
+
+        async with aiofiles.open(file.path, mode="r") as f:
+            content = await f.read()
+
+        doc = Document(
+            page_content=content,
+            metadata={
+                "id": uuid4(),
+                "chunk_size": len(content),
+                "chunk_overlap": self.splitter_config.chunk_overlap,
+                "parser_name": self.__class__.__name__,
+                "quivr_core_version": version("quivr-core"),
+                **file_metadata,
+            },
+        )
+
+        docs = recursive_character_splitter(
+            doc, self.splitter_config.chunk_size, self.splitter_config.chunk_overlap
+        )
+
+        return docs
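The splitter is easy to check in isolation; this snippet reproduces the behaviour pinned down in `tests/processor/test_simple_txt_processor.py` (chunks of size 2 overlapping by 1):

```python
from langchain_core.documents import Document

from quivr_core.processor.simple_txt_processor import recursive_character_splitter

doc = Document(page_content="abcdefgh", metadata={"key": "value"})
docs = recursive_character_splitter(doc, chunk_size=2, chunk_overlap=1)
# Each chunk re-reads the last `chunk_overlap` characters of the previous one.
assert [d.page_content for d in docs] == ["ab", "bc", "cd", "de", "ef", "fg", "gh"]
```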
backend/core/quivr_core/processor/tika_processor.py (renamed from pdf_processor.py)

@@ -1,4 +1,5 @@
 import logging
+from importlib.metadata import version
 from typing import AsyncIterable

 import httpx

@@ -6,14 +7,15 @@ from langchain_core.documents import Document
 from langchain_text_splitters import RecursiveCharacterTextSplitter, TextSplitter

 from quivr_core.processor.processor_base import ProcessorBase
+from quivr_core.processor.registry import FileExtension
 from quivr_core.processor.splitter import SplitterConfig
 from quivr_core.storage.file import QuivrFile

 logger = logging.getLogger("quivr_core")


-class TikaParser(ProcessorBase):
-    supported_extensions = [".pdf"]
+class TikaProcessor(ProcessorBase):
+    supported_extensions = [FileExtension.pdf]

     def __init__(
         self,

@@ -51,12 +53,22 @@ class TikaParser(ProcessorBase):
         raise RuntimeError("can't send parse request to tika server")

     async def process_file(self, file: QuivrFile) -> list[Document]:
-        assert file.file_extension in self.supported_extensions
+        self.check_supported(file)

         async with file.open() as f:
             txt = await self._send_parse_tika(f)
         document = Document(page_content=txt)

         # Use the default splitter
         docs = self.text_splitter.split_documents([document])
+        file_metadata = file.metadata
+
+        for doc in docs:
+            doc.metadata = {
+                "chunk_size": len(doc.page_content),
+                "chunk_overlap": self.splitter_config.chunk_overlap,
+                "parser_name": self.__class__.__name__,
+                "quivr_core_version": version("quivr-core"),
+                **file_metadata,
+            }
         return docs
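A minimal sketch of running the renamed processor on its own, assuming a reachable Apache Tika server (the tests point it at the default URL):

```python
import asyncio
from uuid import uuid4

from quivr_core.processor.tika_processor import TikaProcessor
from quivr_core.storage.file import load_qfile


async def main():
    qfile = await load_qfile(uuid4(), "tests/processor/data/dummy.pdf")
    docs = await TikaProcessor().process_file(qfile)
    # Chunk metadata now carries parser_name, chunk_size and the QuivrFile metadata.
    print(docs[0].metadata)


asyncio.run(main())
```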
backend/core/quivr_core/processor/txt_processor.py

@@ -1,19 +1,24 @@
+from importlib.metadata import version
+from uuid import uuid4
+
 from langchain_community.document_loaders.text import TextLoader
 from langchain_core.documents import Document
 from langchain_text_splitters import RecursiveCharacterTextSplitter, TextSplitter

 from quivr_core.processor.processor_base import ProcessorBase
+from quivr_core.processor.registry import FileExtension
 from quivr_core.processor.splitter import SplitterConfig
 from quivr_core.storage.file import QuivrFile


-class TxtProcessor(ProcessorBase):
+class TikTokenTxtProcessor(ProcessorBase):
+    supported_extensions = [FileExtension.txt]
+
     def __init__(
         self,
         splitter: TextSplitter | None = None,
         splitter_config: SplitterConfig = SplitterConfig(),
     ) -> None:
-        self.supported_extensions = [".txt"]
         self.loader_cls = TextLoader

         self.splitter_config = splitter_config

@@ -27,10 +32,22 @@ class TxtProcessor(ProcessorBase):
         )

     async def process_file(self, file: QuivrFile) -> list[Document]:
-        if file.file_extension not in self.supported_extensions:
-            raise Exception(f"can't process a file of type {file.file_extension}")
+        self.check_supported(file)

         loader = self.loader_cls(file.path)
         documents = await loader.aload()
         docs = self.text_splitter.split_documents(documents)
+
+        file_metadata = file.metadata
+
+        for doc in docs:
+            doc.metadata = {
+                "id": uuid4(),
+                "chunk_size": len(doc.page_content),
+                "chunk_overlap": self.splitter_config.chunk_overlap,
+                "parser_name": self.__class__.__name__,
+                "quivr_core_version": version("quivr-core"),
+                **file_metadata,
+            }
         return docs
backend/core/quivr_core/storage/file.py

@@ -1,21 +1,87 @@
+import hashlib
+import mimetypes
 import os
+import warnings
 from contextlib import asynccontextmanager
+from enum import Enum
 from pathlib import Path
-from typing import AsyncGenerator, AsyncIterable
+from typing import Any, AsyncGenerator, AsyncIterable
 from uuid import UUID, uuid4

 import aiofiles


+class FileExtension(str, Enum):
+    txt = ".txt"
+    pdf = ".pdf"
+    docx = ".docx"
+
+
+def get_file_extension(file_path: Path) -> FileExtension | str:
+    try:
+        mime_type, _ = mimetypes.guess_type(file_path.name)
+        if mime_type:
+            mime_ext = mimetypes.guess_extension(mime_type)
+            if mime_ext:
+                return FileExtension(mime_ext)
+        return FileExtension(file_path.suffix)
+    except ValueError:
+        warnings.warn(
+            f"File {file_path.name} extension isn't recognized. Make sure you have registered a parser for {file_path.suffix}",
+            stacklevel=2,
+        )
+        return file_path.suffix
+
+
+async def load_qfile(brain_id: UUID, path: str | Path):
+    if not isinstance(path, Path):
+        path = Path(path)
+
+    if not path.exists():
+        raise FileExistsError(f"file {path} doesn't exist")
+
+    file_size = os.stat(path).st_size
+
+    async with aiofiles.open(path, mode="rb") as f:
+        file_md5 = hashlib.md5(await f.read()).hexdigest()
+
+    try:
+        # NOTE: when loading from existing storage, file name will be uuid
+        id = UUID(path.name)
+    except ValueError:
+        id = uuid4()
+
+    return QuivrFile(
+        id=id,
+        brain_id=brain_id,
+        path=path,
+        original_filename=path.name,
+        file_extension=get_file_extension(path),
+        file_size=file_size,
+        file_md5=file_md5,
+    )
+
+
 class QuivrFile:
+    __slots__ = [
+        "id",
+        "brain_id",
+        "path",
+        "original_filename",
+        "file_size",
+        "file_extension",
+        "file_md5",
+    ]
+
     def __init__(
         self,
         id: UUID,
         original_filename: str,
         path: Path,
         brain_id: UUID,
+        file_md5: str,
+        file_extension: FileExtension | str,
         file_size: int | None = None,
-        file_extension: str | None = None,
     ) -> None:
         self.id = id
         self.brain_id = brain_id

@@ -23,31 +89,7 @@ class QuivrFile:
         self.original_filename = original_filename
         self.file_size = file_size
         self.file_extension = file_extension
+        self.file_md5 = file_md5

-    @classmethod
-    def from_path(cls, brain_id: UUID, path: str | Path):
-        if not isinstance(path, Path):
-            path = Path(path)
-
-        if not path.exists():
-            raise FileExistsError(f"file {path} doesn't exist")
-
-        file_size = os.stat(path).st_size
-
-        try:
-            # NOTE: when loading from existing storage, file name will be uuid
-            id = UUID(path.name)
-        except ValueError:
-            id = uuid4()
-
-        return cls(
-            id=id,
-            brain_id=brain_id,
-            path=path,
-            original_filename=path.name,
-            file_size=file_size,
-            file_extension=path.suffix,
-        )
-
     @asynccontextmanager
     async def open(self) -> AsyncGenerator[AsyncIterable[bytes], None]:

@@ -57,3 +99,13 @@ class QuivrFile:
             yield f
         finally:
             await f.close()
+
+    @property
+    def metadata(self) -> dict[str, Any]:
+        return {
+            "qfile_id": self.id,
+            "qfile_path": self.path,
+            "original_file_name": self.original_filename,
+            "file_md4": self.file_md5,
+            "file_size": self.file_size,
+        }
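A short sketch of the new loading path: `load_qfile` stats and hashes the file up front, and the `metadata` property is what processors now merge into every produced `Document`:

```python
import asyncio
from uuid import uuid4

from quivr_core.storage.file import load_qfile


async def main():
    qfile = await load_qfile(uuid4(), "tests/processor/data/dummy.pdf")
    print(qfile.file_extension)  # FileExtension.pdf, via get_file_extension()
    print(qfile.metadata)        # qfile_id, original_file_name, md5 hash, size, ...


asyncio.run(main())
```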
backend/core/quivr_core/storage/local_storage.py

@@ -1,6 +1,7 @@
 import os
 import shutil
 from pathlib import Path
+from typing import Set
 from uuid import UUID

 from quivr_core.storage.file import QuivrFile

@@ -8,8 +9,11 @@ from quivr_core.storage.storage_base import StorageBase


 class LocalStorage(StorageBase):
+    name: str = "local_storage"
+
     def __init__(self, dir_path: Path | None = None, copy_flag: bool = True):
         self.files: list[QuivrFile] = []
+        self.hashes: Set[str] = set()
         self.copy_flag = copy_flag

         if dir_path is None:

@@ -24,14 +28,19 @@ class LocalStorage(StorageBase):
         # TODO(@aminediro): load existing files
         pass

-    def upload_file(self, file: QuivrFile, exists_ok: bool = False) -> None:
+    def nb_files(self) -> int:
+        return len(self.files)
+
+    def info(self):
+        return {"directory_path": self.dir_path, **super().info()}
+
+    async def upload_file(self, file: QuivrFile, exists_ok: bool = False) -> None:
         dst_path = os.path.join(
             self.dir_path, str(file.brain_id), f"{file.id}{file.file_extension}"
         )

-        # TODO(@aminediro): Check hash of file not file path
-        if os.path.exists(dst_path) and not exists_ok:
-            raise FileExistsError("file already exists")
+        if file.file_md5 in self.hashes and not exists_ok:
+            raise FileExistsError(f"file {file.original_filename} already uploaded")

         if self.copy_flag:
             shutil.copy2(file.path, dst_path)

@@ -40,28 +49,31 @@ class LocalStorage(StorageBase):

         file.path = Path(dst_path)
         self.files.append(file)
+        self.hashes.add(file.file_md5)

-    def get_files(self) -> list[QuivrFile]:
+    async def get_files(self) -> list[QuivrFile]:
         return self.files

-    def remove_file(self, file_id: UUID) -> None:
+    async def remove_file(self, file_id: UUID) -> None:
         raise NotImplementedError


 class TransparentStorage(StorageBase):
-    """Transparent Storage.
-
-    uses default
-    """
+    """Transparent Storage."""
+
+    name: str = "transparent_storage"

     def __init__(self):
-        self.files = []
+        self.id_files = {}

-    def upload_file(self, file: QuivrFile, exists_ok: bool = False) -> None:
-        self.files.append(file)
+    async def upload_file(self, file: QuivrFile, exists_ok: bool = False) -> None:
+        self.id_files[file.id] = file

-    def remove_file(self, file_id: UUID) -> None:
+    def nb_files(self) -> int:
+        return len(self.id_files)
+
+    async def remove_file(self, file_id: UUID) -> None:
         raise NotImplementedError

-    def get_files(self) -> list[QuivrFile]:
-        return self.files
+    async def get_files(self) -> list[QuivrFile]:
+        return list(self.id_files.values())
backend/core/quivr_core/storage/storage_base.py

@@ -1,18 +1,39 @@
 from abc import ABC, abstractmethod
 from uuid import UUID

+from quivr_core.brain.info import StorageInfo
 from quivr_core.storage.local_storage import QuivrFile


 class StorageBase(ABC):
+    name: str
+
+    def __init_subclass__(cls, **kwargs):
+        for required in ("name",):
+            if not getattr(cls, required):
+                raise TypeError(
+                    f"Can't instantiate abstract class {cls.__name__} without {required} attribute defined"
+                )
+        return super().__init_subclass__(**kwargs)
+
     @abstractmethod
-    def get_files(self) -> list[QuivrFile]:
+    def nb_files(self) -> int:
+        raise Exception("Unimplemented nb_files method")
+
+    @abstractmethod
+    async def get_files(self) -> list[QuivrFile]:
         raise Exception("Unimplemented get_files method")

     @abstractmethod
-    def upload_file(self, file: QuivrFile, exists_ok: bool = False) -> None:
+    async def upload_file(self, file: QuivrFile, exists_ok: bool = False) -> None:
         raise Exception("Unimplemented upload_file method")

     @abstractmethod
-    def remove_file(self, file_id: UUID) -> None:
+    async def remove_file(self, file_id: UUID) -> None:
         raise Exception("Unimplemented remove_file method")
+
+    def info(self) -> StorageInfo:
+        return StorageInfo(
+            storage_type=self.name,
+            n_files=self.nb_files(),
+        )
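A sketch of what the stricter base class now demands of implementers: a class-level `name` (checked by the new `__init_subclass__` hook and surfaced through `info()`), plus async file operations. The `InMemoryStorage` class is hypothetical.

```python
from uuid import UUID

from quivr_core.storage.file import QuivrFile
from quivr_core.storage.storage_base import StorageBase


class InMemoryStorage(StorageBase):  # hypothetical minimal implementation
    name = "in_memory_storage"  # required; omitting it is rejected at class creation

    def __init__(self) -> None:
        self.files: list[QuivrFile] = []

    def nb_files(self) -> int:
        return len(self.files)

    async def get_files(self) -> list[QuivrFile]:
        return self.files

    async def upload_file(self, file: QuivrFile, exists_ok: bool = False) -> None:
        self.files.append(file)

    async def remove_file(self, file_id: UUID) -> None:
        self.files = [f for f in self.files if f.id != file_id]
```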
backend/core/tests/conftest.py

@@ -1,5 +1,7 @@
 import json
 import os
+from pathlib import Path
+from uuid import uuid4

 import pytest
 from langchain_core.embeddings import DeterministicFakeEmbedding

@@ -10,6 +12,39 @@ from langchain_core.vectorstores import InMemoryVectorStore

 from quivr_core.config import LLMEndpointConfig
 from quivr_core.llm import LLMEndpoint
+from quivr_core.storage.file import FileExtension, QuivrFile
+
+
+@pytest.fixture(scope="function")
+def temp_data_file(tmp_path):
+    data = "This is some test data."
+    temp_file = tmp_path / "data.txt"
+    temp_file.write_text(data)
+    return temp_file
+
+
+@pytest.fixture(scope="function")
+def quivr_txt(temp_data_file):
+    return QuivrFile(
+        id=uuid4(),
+        brain_id=uuid4(),
+        original_filename=temp_data_file.name,
+        path=temp_data_file,
+        file_extension=FileExtension.txt,
+        file_md5="123",
+    )
+
+
+@pytest.fixture
+def quivr_pdf():
+    return QuivrFile(
+        id=uuid4(),
+        brain_id=uuid4(),
+        original_filename="dummy.pdf",
+        path=Path("./tests/processor/data/dummy.pdf"),
+        file_extension=FileExtension.pdf,
+        file_md5="13bh234jh234",
+    )
+
+
 @pytest.fixture

@@ -36,14 +71,6 @@ def openai_api_key():
     os.environ["OPENAI_API_KEY"] = "abcd"


-@pytest.fixture(scope="function")
-def temp_data_file(tmp_path):
-    data = "This is some test data."
-    temp_file = tmp_path / "data.txt"
-    temp_file.write_text(data)
-    return temp_file
-
-
 @pytest.fixture
 def answers():
     return [f"answer_{i}" for i in range(10)]
backend/core/tests/processor/test_registry.py (new file, 67 lines)

@@ -0,0 +1,67 @@
+import pytest
+from langchain_core.documents import Document
+
+from quivr_core import registry
+from quivr_core.processor.processor_base import ProcessorBase
+from quivr_core.processor.registry import (
+    _import_class,
+    get_processor_class,
+    register_processor,
+)
+from quivr_core.processor.simple_txt_processor import SimpleTxtProcessor
+from quivr_core.processor.tika_processor import TikaProcessor
+from quivr_core.storage.file import FileExtension, QuivrFile
+
+
+def test_get_processor_cls():
+    cls = get_processor_class(FileExtension.txt)
+    assert cls == SimpleTxtProcessor
+    cls = get_processor_class(FileExtension.pdf)
+    assert cls == TikaProcessor
+
+
+def test__import_class():
+    mod_path = "quivr_core.processor.tika_processor.TikaProcessor"
+    mod = _import_class(mod_path)
+    assert mod == TikaProcessor
+
+    with pytest.raises(TypeError, match=r".* is not a class"):
+        mod_path = "quivr_core.processor"
+        _import_class(mod_path)
+
+    with pytest.raises(TypeError, match=r".* ProcessorBase"):
+        mod_path = "quivr_core.Brain"
+        _import_class(mod_path)
+
+
+def test_get_processor_cls_error():
+    with pytest.raises(ValueError):
+        get_processor_class(".docx")
+
+
+def test_register_new_proc():
+    nprocs = len(registry)
+
+    class TestProcessor(ProcessorBase):
+        supported_extensions = [".test"]
+
+        async def process_file(self, file: QuivrFile) -> list[Document]:
+            return []
+
+    register_processor(".test", TestProcessor)
+    assert len(registry) == nprocs + 1
+
+    cls = get_processor_class(".test")
+    assert cls == TestProcessor
+
+
+def test_register_override_proc():
+    class TestProcessor(ProcessorBase):
+        supported_extensions = [".pdf"]
+
+        async def process_file(self, file: QuivrFile) -> list[Document]:
+            return []
+
+    register_processor(".pdf", TestProcessor, override=True)
+    cls = get_processor_class(FileExtension.pdf)
+    assert cls == TestProcessor
34
backend/core/tests/processor/test_simple_txt_processor.py
Normal file
34
backend/core/tests/processor/test_simple_txt_processor.py
Normal file
@ -0,0 +1,34 @@
|
|||||||
|
import pytest
from langchain_core.documents import Document

from quivr_core.processor.simple_txt_processor import (
    SimpleTxtProcessor,
    recursive_character_splitter,
)
from quivr_core.processor.splitter import SplitterConfig
from quivr_core.storage.file import FileExtension


def test_recursive_character_splitter():
    doc = Document(page_content="abcdefgh", metadata={"key": "value"})

    docs = recursive_character_splitter(doc, chunk_size=2, chunk_overlap=1)

    assert [d.page_content for d in docs] == ["ab", "bc", "cd", "de", "ef", "fg", "gh"]
    assert [d.metadata for d in docs] == [doc.metadata] * len(docs)


@pytest.mark.asyncio
async def test_simple_processor(quivr_pdf, quivr_txt):
    proc = SimpleTxtProcessor(
        splitter_config=SplitterConfig(chunk_size=100, chunk_overlap=20)
    )
    assert proc.supported_extensions == [FileExtension.txt]

    with pytest.raises(ValueError):
        await proc.process_file(quivr_pdf)

    docs = await proc.process_file(quivr_txt)

    assert len(docs) == 1
    assert docs[0].page_content == "This is some test data."
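The first test pins down the sliding-window arithmetic: the window advances by chunk_size - chunk_overlap characters, so "abcdefgh" with size 2 and overlap 1 yields seven two-character chunks. A standalone sketch of that behaviour follows; it illustrates the arithmetic only and is not the library's recursive_character_splitter:

def sliding_chunks(text: str, chunk_size: int, chunk_overlap: int) -> list[str]:
    step = chunk_size - chunk_overlap  # window advances by size minus overlap
    return [text[i : i + chunk_size] for i in range(0, len(text) - chunk_overlap, step)]


assert sliding_chunks("abcdefgh", chunk_size=2, chunk_overlap=1) == [
    "ab", "bc", "cd", "de", "ef", "fg", "gh"
]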
@@ -1,38 +1,23 @@
-from pathlib import Path
-from uuid import uuid4
-
 import pytest
 
-from quivr_core.processor.pdf_processor import TikaParser
-from quivr_core.storage.file import QuivrFile
+from quivr_core.processor.tika_processor import TikaProcessor
 
 # TODO: TIKA server should be set
 
 
-@pytest.fixture
-def pdf():
-    return QuivrFile(
-        id=uuid4(),
-        brain_id=uuid4(),
-        original_filename="dummy.pdf",
-        path=Path("./tests/processor/data/dummy.pdf"),
-        file_extension=".pdf",
-    )
-
-
 @pytest.mark.asyncio
-async def test_process_file(pdf):
-    tparser = TikaParser()
-    doc = await tparser.process_file(pdf)
+async def test_process_file(quivr_pdf):
+    tparser = TikaProcessor()
+    doc = await tparser.process_file(quivr_pdf)
     assert len(doc) > 0
     assert doc[0].page_content.strip("\n") == "Dummy PDF download"
 
 
 @pytest.mark.asyncio
-async def test_send_parse_tika_exception(pdf):
+async def test_send_parse_tika_exception(quivr_pdf):
     # TODO: Mock correct tika for retries
-    tparser = TikaParser(tika_url="test.test")
+    tparser = TikaProcessor(tika_url="test.test")
     with pytest.raises(RuntimeError):
-        doc = await tparser.process_file(pdf)
+        doc = await tparser.process_file(quivr_pdf)
         assert len(doc) > 0
         assert doc[0].page_content.strip("\n") == "Dummy PDF download"
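The local `pdf` fixture deleted here is replaced by a shared `quivr_pdf` fixture. Driving the renamed `TikaProcessor` directly looks roughly like the sketch below, which rebuilds what the deleted fixture did; it assumes a Tika server is reachable at the default `tika_url`:

import asyncio
from pathlib import Path
from uuid import uuid4

from quivr_core.processor.tika_processor import TikaProcessor
from quivr_core.storage.file import QuivrFile

# Same construction as the removed fixture.
qfile = QuivrFile(
    id=uuid4(),
    brain_id=uuid4(),
    original_filename="dummy.pdf",
    path=Path("./tests/processor/data/dummy.pdf"),
    file_extension=".pdf",
)

# process_file is async; assumes a running Tika server (see TODO above).
docs = asyncio.run(TikaProcessor().process_file(qfile))
print(docs[0].page_content)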
45 backend/core/tests/processor/test_txt_processor.py (new file)
@@ -0,0 +1,45 @@
from importlib.metadata import version
from uuid import uuid4

import pytest

from quivr_core.processor.splitter import SplitterConfig
from quivr_core.processor.txt_processor import TikTokenTxtProcessor
from quivr_core.storage.file import FileExtension, QuivrFile

# TODO: TIKA server should be set


@pytest.fixture
def txt_qfile(temp_data_file):
    return QuivrFile(
        id=uuid4(),
        brain_id=uuid4(),
        original_filename="data.txt",
        path=temp_data_file,
        file_extension=FileExtension.txt,
        file_md5="hash",
    )


@pytest.mark.base
@pytest.mark.asyncio
async def test_process_txt(txt_qfile):
    tparser = TikTokenTxtProcessor(
        splitter_config=SplitterConfig(chunk_size=20, chunk_overlap=0)
    )
    doc = await tparser.process_file(txt_qfile)
    assert len(doc) > 0
    assert doc[0].page_content == "This is some test data."
    # assert dict1.items() <= dict2.items()

    assert (
        doc[0].metadata.items()
        >= {
            "chunk_size": len(doc[0].page_content),
            "chunk_overlap": 0,
            "parser_name": tparser.__class__.__name__,
            "quivr_core_version": version("quivr-core"),
            **txt_qfile.metadata,
        }.items()
    )
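The metadata assertion uses dict view set semantics, as the inline comment hints: `a.items() >= b.items()` is true when every key/value pair of `b` also appears in `a`, so the enriched metadata may carry extra keys without failing the test. A tiny standalone illustration:

a = {"chunk_size": 23, "chunk_overlap": 0, "extra": "ignored"}
b = {"chunk_size": 23, "chunk_overlap": 0}

assert a.items() >= b.items()        # every pair of b is in a
assert not (a.items() <= b.items())  # a has pairs that b lacks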
@@ -1,3 +1,4 @@
+from dataclasses import asdict
 from uuid import uuid4
 
 import pytest
@@ -16,8 +17,11 @@ def test_brain_empty_files():
         Brain.from_files(name="test_brain", file_paths=[])
 
 
-def test_brain_from_files_success(fake_llm: LLMEndpoint, embedder, temp_data_file):
-    brain = Brain.from_files(
+@pytest.mark.asyncio
+async def test_brain_from_files_success(
+    fake_llm: LLMEndpoint, embedder, temp_data_file
+):
+    brain = await Brain.afrom_files(
         name="test_brain", file_paths=[temp_data_file], embedder=embedder, llm=fake_llm
     )
     assert brain.name == "test_brain"
@@ -29,7 +33,7 @@ def test_brain_from_files_success(fake_llm: LLMEndpoint, embedder, temp_data_fil
 
     # storage
     assert isinstance(brain.storage, TransparentStorage)
-    assert len(brain.storage.get_files()) == 1
+    assert len(await brain.storage.get_files()) == 1
 
 
 @pytest.mark.asyncio
@@ -39,7 +43,7 @@ async def test_brain_from_langchain_docs(embedder):
         name="test", langchain_documents=[chunk], embedder=embedder
     )
     # No appended files
-    assert len(brain.storage.get_files()) == 0
+    assert len(await brain.storage.get_files()) == 0
     assert len(brain.chat_history) == 0
@@ -90,3 +94,28 @@ async def test_brain_ask_streaming(
         response += chunk.answer
 
     assert response == answers[1]
+
+
+def test_brain_info_empty(fake_llm: LLMEndpoint, embedder, mem_vector_store):
+    storage = TransparentStorage()
+    id = uuid4()
+    brain = Brain(
+        name="test",
+        id=id,
+        llm=fake_llm,
+        embedder=embedder,
+        storage=storage,
+        vector_db=mem_vector_store,
+    )
+
+    assert asdict(brain.info()) == {
+        "brain_id": id,
+        "brain_name": "test",
+        "files_info": asdict(storage.info()),
+        "chats_info": {
+            "nb_chats": 1,  # start with a default chat
+            "current_default_chat": brain.default_chat.id,
+            "current_chat_history_length": 0,
+        },
+        "llm_info": asdict(fake_llm.info()),
+    }
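The new `test_brain_info_empty` compares `brain.info()` by flattening it with `dataclasses.asdict`, which recurses into nested dataclasses while leaving plain values such as UUIDs untouched. A minimal sketch of that pattern; `FakeInfo` is illustrative, not the actual brain-info dataclass:

from dataclasses import asdict, dataclass
from uuid import UUID, uuid4


@dataclass
class FakeInfo:
    brain_id: UUID
    brain_name: str


info = FakeInfo(brain_id=uuid4(), brain_name="test")
# asdict converts the dataclass to a plain dict, keeping the UUID as-is.
assert asdict(info) == {"brain_id": info.brain_id, "brain_name": info.brain_name}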
0 backend/core/tests/test_quivr_file.py (new, empty file)