from dataclasses import asdict
from uuid import uuid4

import pytest
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings

from quivr_core.brain import Brain
from quivr_core.chat import ChatHistory
from quivr_core.llm import LLMEndpoint
from quivr_core.storage.local_storage import TransparentStorage
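
# The fixtures used below (fake_llm, embedder, temp_data_file, mem_vector_store,
# answers) are expected to come from the suite's conftest.py. A minimal sketch of
# what two of them could look like (implementations here are assumptions for
# illustration, not the project's actual conftest):
#
#   @pytest.fixture
#   def embedder():
#       from langchain_core.embeddings import DeterministicFakeEmbedding
#       return DeterministicFakeEmbedding(size=20)
#
#   @pytest.fixture
#   def mem_vector_store(embedder):
#       from langchain_core.vectorstores import InMemoryVectorStore
#       return InMemoryVectorStore(embedding=embedder)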


@pytest.mark.base
def test_brain_empty_files_no_vectordb(fake_llm, embedder):
    # Building a brain from an empty file list without a vector store should fail
    with pytest.raises(ValueError):
        Brain.from_files(
            name="test_brain",
            file_paths=[],
            llm=fake_llm,
            embedder=embedder,
        )


def test_brain_empty_files(fake_llm, embedder, mem_vector_store):
    brain = Brain.from_files(
        name="test_brain",
        file_paths=[],
        llm=fake_llm,
        embedder=embedder,
        vector_db=mem_vector_store,
    )
    assert brain


@pytest.mark.asyncio
async def test_brain_from_files_success(
    fake_llm: LLMEndpoint, embedder, temp_data_file, mem_vector_store
):
    brain = await Brain.afrom_files(
        name="test_brain",
        file_paths=[temp_data_file],
        embedder=embedder,
        llm=fake_llm,
        vector_db=mem_vector_store,
    )
    assert brain.name == "test_brain"
    assert len(brain.chat_history) == 0
    assert brain.llm == fake_llm
    assert brain.vector_db.embeddings == embedder
    assert isinstance(brain.default_chat, ChatHistory)
    assert len(brain.default_chat) == 0

    # storage
    assert isinstance(brain.storage, TransparentStorage)
    assert len(await brain.storage.get_files()) == 1


@pytest.mark.asyncio
async def test_brain_from_langchain_docs(embedder, fake_llm, mem_vector_store):
    chunk = Document("content_1", metadata={"id": uuid4()})
    brain = await Brain.afrom_langchain_documents(
        name="test",
        llm=fake_llm,
        langchain_documents=[chunk],
        embedder=embedder,
        vector_db=mem_vector_store,
    )
    # Building from pre-chunked documents registers no files in storage
    assert len(await brain.storage.get_files()) == 0
    assert len(brain.chat_history) == 0


@pytest.mark.base
@pytest.mark.asyncio
async def test_brain_search(
    embedder: Embeddings,
):
    chunk1 = Document("content_1", metadata={"id": uuid4()})
    chunk2 = Document("content_2", metadata={"id": uuid4()})
    brain = await Brain.afrom_langchain_documents(
        name="test", langchain_documents=[chunk1, chunk2], embedder=embedder
    )

    k = 2
    result = await brain.asearch("content_1", n_results=k)

    assert len(result) == k
    # The exact match is ranked first with zero distance
    assert result[0].chunk == chunk1
    assert result[1].chunk == chunk2
    assert result[0].distance == 0
    assert result[1].distance > result[0].distance


@pytest.mark.asyncio
async def test_brain_get_history(
    fake_llm: LLMEndpoint, embedder, temp_data_file, mem_vector_store
):
    brain = await Brain.afrom_files(
        name="test_brain",
        file_paths=[temp_data_file],
        embedder=embedder,
        llm=fake_llm,
        vector_db=mem_vector_store,
    )

    brain.ask("question")
    brain.ask("question")

    # Each ask() appends a user message and an assistant answer to the default chat
    assert len(brain.default_chat) == 4


@pytest.mark.base
@pytest.mark.asyncio
async def test_brain_ask_streaming(
    fake_llm: LLMEndpoint, embedder, temp_data_file, answers
):
    brain = await Brain.afrom_files(
        name="test_brain", file_paths=[temp_data_file], embedder=embedder, llm=fake_llm
    )

    response = ""
    async for chunk in brain.ask_streaming("question"):
        response += chunk.answer

    assert response == answers[1]


def test_brain_info_empty(fake_llm: LLMEndpoint, embedder, mem_vector_store):
    storage = TransparentStorage()
    id = uuid4()
    brain = Brain(
        name="test",
        id=id,
        llm=fake_llm,
        embedder=embedder,
        storage=storage,
        vector_db=mem_vector_store,
    )

    assert asdict(brain.info()) == {
        "brain_id": id,
        "brain_name": "test",
        "files_info": asdict(storage.info()),
        "chats_info": {
            "nb_chats": 1,  # start with a default chat
            "current_default_chat": brain.default_chat.id,
            "current_chat_history_length": 0,
        },
        "llm_info": asdict(fake_llm.info()),
    }