quivr/backend/core/tests/test_quivr_rag.py
Jacopo Chevallard ef90e8e672
feat: introducing configurable retrieval workflows (#3227)
# Description

This is a major PR which, among other things, introduces the ability to easily
customize retrieval workflows. Workflows are based on LangGraph and can be
customized using a [yaml configuration
file](core/tests/test_llm_endpoint.py) and by adding the implementation of
the node logic to
[quivr_rag_langgraph.py](1a0c98437a/backend/core/quivr_core/quivr_rag_langgraph.py).
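
For reference, the snippet below is a minimal sketch of how the resulting pipeline can be driven end to end, mirroring the API exercised in the test further down. The vector store and embeddings used here (`InMemoryVectorStore`, `OpenAIEmbeddings`) are placeholder assumptions, not part of this PR; any LangChain-compatible vector store populated with your documents should work.

```python
# Hedged sketch: streaming an answer through the LangGraph-based RAG pipeline.
import asyncio
from uuid import uuid4

from langchain_core.vectorstores import InMemoryVectorStore  # placeholder assumption
from langchain_openai import OpenAIEmbeddings  # placeholder assumption

from quivr_core.chat import ChatHistory
from quivr_core.config import LLMEndpointConfig, RetrievalConfig
from quivr_core.llm import LLMEndpoint
from quivr_core.quivr_rag_langgraph import QuivrQARAGLangGraph


async def main() -> None:
    llm_config = LLMEndpointConfig(model="gpt-4o")
    llm = LLMEndpoint.from_config(llm_config)
    retrieval_config = RetrievalConfig(llm_config=llm_config)

    # Placeholder store; swap in your own vector store with ingested documents.
    vector_store = InMemoryVectorStore(OpenAIEmbeddings())
    chat_history = ChatHistory(uuid4(), uuid4())

    rag = QuivrQARAGLangGraph(
        retrieval_config=retrieval_config, llm=llm, vector_store=vector_store
    )

    # Stream partial answers; the final chunk carries sources/citations metadata.
    async for chunk in rag.answer_astream("What is Quivr?", chat_history, []):
        print(chunk.answer, end="", flush=True)
        if chunk.last_chunk:
            print("\nSources:", chunk.metadata.sources)


asyncio.run(main())
```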

This is a first, simple implementation that will significantly evolve in
the coming weeks to enable more complex workflows (for instance, with
conditional nodes). We also plan to adopt a similar approach for the
ingestion part, i.e. to enable users to easily customize the ingestion
pipeline.

Closes CORE-195, CORE-203, CORE-204

## Checklist before requesting a review

Please delete options that are not relevant.

- [X] My code follows the style guidelines of this project
- [X] I have performed a self-review of my code
- [X] I have commented hard-to-understand areas
- [X] I have ideally added tests that prove my fix is effective or that
my feature works
- [X] New and existing unit tests pass locally with my changes
- [X] Any dependent changes have been merged

## Screenshots (if appropriate):
2024-09-23 09:11:06 -07:00

from uuid import uuid4

import pytest
from quivr_core.chat import ChatHistory
from quivr_core.config import LLMEndpointConfig, RetrievalConfig
from quivr_core.llm import LLMEndpoint
from quivr_core.models import ParsedRAGChunkResponse, RAGResponseMetadata
from quivr_core.quivr_rag_langgraph import QuivrQARAGLangGraph


@pytest.fixture(scope="function")
def mock_chain_qa_stream(monkeypatch, chunks_stream_answer):
    # Replace the real LangGraph chain with a stub that replays the
    # pre-recorded streaming chunks as "on_chat_model_stream" events.
    class MockQAChain:
        async def astream_events(self, *args, **kwargs):
            for c in chunks_stream_answer:
                yield {
                    "event": "on_chat_model_stream",
                    "metadata": {"langgraph_node": "generate"},
                    "data": {"chunk": c},
                }

    def mock_qa_chain(*args, **kwargs):
        return MockQAChain()

    monkeypatch.setattr(QuivrQARAGLangGraph, "build_chain", mock_qa_chain)
@pytest.mark.base
@pytest.mark.asyncio
async def test_quivrqaraglanggraph(
    mem_vector_store, full_response, mock_chain_qa_stream, openai_api_key
):
    # Making sure the model supports function calling
    llm_config = LLMEndpointConfig(model="gpt-4o")
    llm = LLMEndpoint.from_config(llm_config)
    retrieval_config = RetrievalConfig(llm_config=llm_config)
    chat_history = ChatHistory(uuid4(), uuid4())
    rag_pipeline = QuivrQARAGLangGraph(
        retrieval_config=retrieval_config, llm=llm, vector_store=mem_vector_store
    )

    stream_responses: list[ParsedRAGChunkResponse] = []

    # Making sure that we are calling the func_calling code path
    assert rag_pipeline.llm_endpoint.supports_func_calling()

    async for resp in rag_pipeline.answer_astream(
        "answer in bullet points. tell me something", chat_history, []
    ):
        stream_responses.append(resp)

    # Only the final chunk should be flagged as the last one
    assert all(
        not r.last_chunk for r in stream_responses[:-1]
    ), "Some chunks before last have last_chunk=True"
    assert stream_responses[-1].last_chunk

    for idx, response in enumerate(stream_responses[1:-1]):
        assert (
            len(response.answer) > 0
        ), f"Sent an empty answer {response} at index {idx+1}"

    # Verify metadata: intermediate chunks carry default metadata only
    default_metadata = RAGResponseMetadata().model_dump()
    assert all(
        r.metadata.model_dump() == default_metadata for r in stream_responses[:-1]
    )

    last_response = stream_responses[-1]
    # TODO(@aminediro) : test responses with sources
    assert last_response.metadata.sources == []
    assert last_response.metadata.citations == []

    # Assert whole response makes sense
    assert "".join([r.answer for r in stream_responses]) == full_response