feat: quivr-core ask streaming (#2828)

# Description

Added a streaming answer API (`Brain.ask_streaming`) to quivr-core, plus tests.
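
As a usage illustration (not part of the diff below): a minimal sketch of how a caller might consume the new async generator. The helper name and the top-level import path are assumptions; only the `answer` attribute read from each chunk comes from this commit.

```python
import asyncio

from quivr_core import Brain  # top-level import path is an assumption


async def stream_answer(brain: Brain, question: str) -> str:
    """Illustrative helper: print partial answers as they stream in, return the full text."""
    full_answer = ""
    async for chunk in brain.ask_streaming(question):
        # Each yielded chunk carries a partial answer in `chunk.answer`
        print(chunk.answer, end="", flush=True)
        full_answer += chunk.answer
    return full_answer


# Example wiring (assuming `brain` was built, e.g. via `await Brain.afrom_files(...)`):
# answer = asyncio.run(stream_answer(brain, "What do my documents say?"))
```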
AmineDiro 2024-07-10 17:52:07 +02:00 committed by GitHub
parent 6056450bd6
commit 0658d4947c
2 changed files with 50 additions and 15 deletions


@@ -1,7 +1,7 @@
 import asyncio
 import logging
 from pathlib import Path
-from typing import Any, Callable, Dict, Mapping, Self
+from typing import Any, AsyncGenerator, Callable, Dict, Mapping, Self
 from uuid import UUID, uuid4

 from langchain_core.documents import Document
@@ -12,7 +12,7 @@ from langchain_core.vectorstores import VectorStore
 from quivr_core.chat import ChatHistory
 from quivr_core.config import LLMEndpointConfig, RAGConfig
 from quivr_core.llm import LLMEndpoint
-from quivr_core.models import ParsedRAGResponse, SearchResult
+from quivr_core.models import ParsedRAGChunkResponse, ParsedRAGResponse, SearchResult
 from quivr_core.processor.default_parsers import DEFAULT_PARSERS
 from quivr_core.processor.processor_base import ProcessorBase
 from quivr_core.quivr_rag import QuivrQARAG
@@ -283,3 +283,36 @@ class Brain:
         # Save answer to the chat history
         return parsed_response

+    async def ask_streaming(
+        self,
+        question: str,
+        rag_config: RAGConfig | None = None,
+    ) -> AsyncGenerator[ParsedRAGChunkResponse, ParsedRAGChunkResponse]:
+        llm = self.llm
+
+        # If a different LLM config was passed, override the brain's default LLM
+        if rag_config:
+            if rag_config.llm_config != self.llm.get_config():
+                llm = LLMEndpoint.from_config(config=rag_config.llm_config)
+        else:
+            rag_config = RAGConfig(llm_config=self.llm.get_config())
+
+        rag_pipeline = QuivrQARAG(
+            rag_config=rag_config, llm=llm, vector_store=self.vector_db
+        )
+        chat_history = self.default_chat
+
+        # TODO: List of files
+        full_answer = ""
+        async for response in rag_pipeline.answer_astream(question, chat_history, []):
+            # Stream intermediate chunks as they arrive; the final chunk is yielded after the loop
+            if not response.last_chunk:
+                yield response
+            full_answer += response.answer
+
+        # TODO: add sources, metadata, etc.
+        chat_history.append(HumanMessage(content=question))
+        chat_history.append(AIMessage(content=full_answer))
+        yield response
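
A possible consumer of this generator on the service side, purely as an illustration and not part of this commit: the sketch assumes FastAPI is available, assumes `ParsedRAGChunkResponse` is a Pydantic v2 model (so `model_dump_json()` exists), and the `build_app` factory is hypothetical.

```python
# Illustrative only: streaming the chunks as NDJSON over HTTP with FastAPI.
from fastapi import FastAPI
from fastapi.responses import StreamingResponse


def build_app(brain) -> FastAPI:
    # `brain` is an already-built Brain instance (e.g. from Brain.afrom_files)
    app = FastAPI()

    @app.get("/ask")
    async def ask(question: str) -> StreamingResponse:
        async def chunk_stream():
            async for chunk in brain.ask_streaming(question):
                # One JSON object per line so clients can parse incrementally
                yield chunk.model_dump_json() + "\n"

        return StreamingResponse(chunk_stream(), media_type="application/x-ndjson")

    return app
```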


@@ -58,19 +58,6 @@ def test_brain_from_files_success(fake_llm: LLMEndpoint, embedder, temp_data_fil
     assert len(brain.storage.get_files()) == 1

-@pytest.mark.asyncio
-async def test_brain_ask_txt(fake_llm: LLMEndpoint, embedder, temp_data_file, answers):
-    brain = await Brain.afrom_files(
-        name="test_brain", file_paths=[temp_data_file], embedder=embedder, llm=fake_llm
-    )
-    answer = brain.ask("question")
-
-    assert answer.answer == answers[1]
-    assert answer.metadata is not None
-    assert answer.metadata.sources is not None
-    assert answer.metadata.sources[0].metadata["source"] == str(temp_data_file)

 @pytest.mark.asyncio
 async def test_brain_from_langchain_docs(embedder):
     chunk = Document("content_1", metadata={"id": uuid4()})
@@ -110,3 +97,18 @@ async def test_brain_get_history(
     brain.ask("question")

     assert len(brain.default_chat) == 4

+@pytest.mark.asyncio
+async def test_brain_ask_streaming(
+    fake_llm: LLMEndpoint, embedder, temp_data_file, answers
+):
+    brain = await Brain.afrom_files(
+        name="test_brain", file_paths=[temp_data_file], embedder=embedder, llm=fake_llm
+    )
+
+    response = ""
+    async for chunk in brain.ask_streaming("question"):
+        response += chunk.answer
+
+    assert response == answers[1]