mirror of
https://github.com/QuivrHQ/quivr.git
synced 2024-12-15 17:43:03 +03:00
ef90e8e672
# Description
Major PR which, among other things, introduces the possibility of easily
customizing the retrieval workflows. Workflows are based on LangGraph,
and can be customized using a [yaml configuration
file](core/tests/test_llm_endpoint.py), and adding the implementation of
the nodes logic into
[quivr_rag_langgraph.py](1a0c98437a/backend/core/quivr_core/quivr_rag_langgraph.py).
This is a first, simple implementation that will significantly evolve in
the coming weeks to enable more complex workflows (for instance, with
conditional nodes). We also plan to adopt a similar approach for the
ingestion part, i.e. to enable users to easily customize the ingestion
pipeline.
Closes CORE-195, CORE-203, CORE-204
## Checklist before requesting a review
Please delete options that are not relevant.
- [X] My code follows the style guidelines of this project
- [X] I have performed a self-review of my code
- [X] I have commented hard-to-understand areas
- [X] I have ideally added tests that prove my fix is effective or that
my feature works
- [X] New and existing unit tests pass locally with my changes
- [X] Any dependent changes have been merged
## Screenshots (if appropriate):
73 lines
2.5 KiB
Python
73 lines
2.5 KiB
Python
from uuid import uuid4
|
|
|
|
import pytest
|
|
from quivr_core.chat import ChatHistory
|
|
from quivr_core.config import LLMEndpointConfig, RetrievalConfig
|
|
from quivr_core.llm import LLMEndpoint
|
|
from quivr_core.models import ParsedRAGChunkResponse, RAGResponseMetadata
|
|
from quivr_core.quivr_rag_langgraph import QuivrQARAGLangGraph
|
|
|
|
|
|
@pytest.fixture(scope="function")
def mock_chain_qa_stream(monkeypatch, chunks_stream_answer):
    """Replace ``QuivrQARAGLangGraph.build_chain`` with a canned streaming stub.

    The patched ``build_chain`` ignores its arguments and returns an object
    whose ``astream_events`` async generator yields one
    ``on_chat_model_stream`` event per item of ``chunks_stream_answer``,
    each tagged as coming from the ``generate`` node.
    """

    class _StubQAChain:
        """Stand-in chain exposing only the ``astream_events`` method."""

        async def astream_events(self, *args, **kwargs):
            # Arguments are accepted and ignored; the stream is fully
            # determined by the fixture-provided chunks.
            for chunk in chunks_stream_answer:
                event = {
                    "event": "on_chat_model_stream",
                    "metadata": {"langgraph_node": "generate"},
                    "data": {"chunk": chunk},
                }
                yield event

    monkeypatch.setattr(
        QuivrQARAGLangGraph,
        "build_chain",
        lambda *args, **kwargs: _StubQAChain(),
    )
|
|
|
|
|
|
@pytest.mark.base
@pytest.mark.asyncio
async def test_quivrqaraglanggraph(
    mem_vector_store, full_response, mock_chain_qa_stream, openai_api_key
):
    """Stream an answer through ``QuivrQARAGLangGraph`` and validate the chunks.

    Checks the ``last_chunk`` flags, that intermediate chunks carry a
    non-empty answer, that metadata is default on all but the final chunk,
    and that the concatenated answers equal ``full_response``.

    ``mock_chain_qa_stream`` and ``openai_api_key`` are requested but never
    referenced in the body (presumably needed for their setup side effects —
    confirm against conftest).
    """
    # Build the pipeline around a "gpt-4o" endpoint; the assertion further
    # down checks that this endpoint reports function-calling support.
    llm_config = LLMEndpointConfig(model="gpt-4o")
    llm = LLMEndpoint.from_config(llm_config)
    retrieval_config = RetrievalConfig(llm_config=llm_config)
    chat_history = ChatHistory(uuid4(), uuid4())
    rag_pipeline = QuivrQARAGLangGraph(
        retrieval_config=retrieval_config,
        llm=llm,
        vector_store=mem_vector_store,
    )

    # Making sure that we are calling the func_calling code path
    assert rag_pipeline.llm_endpoint.supports_func_calling()

    question = "answer in bullet points. tell me something"
    stream_responses: list[ParsedRAGChunkResponse] = [
        resp async for resp in rag_pipeline.answer_astream(question, chat_history, [])
    ]

    # Only the final chunk may be flagged as the last one.
    all_but_last = stream_responses[:-1]
    assert not any(
        r.last_chunk for r in all_but_last
    ), "Some chunks before last have last_chunk=True"
    assert stream_responses[-1].last_chunk

    # The first and last chunks are excluded from the non-empty check.
    middle_chunks = stream_responses[1:-1]
    for idx, response in enumerate(middle_chunks):
        assert (
            len(response.answer) > 0
        ), f"Sent an empty answer {response} at index {idx+1}"

    # Verify metadata: every chunk except the final one carries defaults.
    default_metadata = RAGResponseMetadata().model_dump()
    dumped_metadata = (r.metadata.model_dump() for r in all_but_last)
    assert all(dumped == default_metadata for dumped in dumped_metadata)

    last_response = stream_responses[-1]
    # TODO(@aminediro) : test responses with sources
    assert last_response.metadata.sources == []
    assert last_response.metadata.citations == []

    # Assert whole response makes sense
    streamed_answer = "".join(r.answer for r in stream_responses)
    assert streamed_answer == full_response
|