Mirror of https://github.com/QuivrHQ/quivr.git (synced 2024-12-15 01:21:48 +03:00)
feat: quivr-core ask streaming (#2828)
# Description

Added streaming response to quivr-core brain + tests.
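For reference, a minimal sketch of how the new streaming API can be consumed from caller code. It assumes `Brain` is importable from `quivr_core`, and that `my_llm` / `my_embedder` are an `LLMEndpoint` and an embedder the caller already has (mirroring the `fake_llm` / `embedder` fixtures used in the tests below); the brain name, file path, and question are placeholders.

```python
import asyncio

from quivr_core import Brain  # assumed package-level export


async def stream_answer(my_llm, my_embedder) -> str:
    # Build a brain the same way the new test does, with caller-provided components
    brain = await Brain.afrom_files(
        name="demo_brain",
        file_paths=["./docs/notes.txt"],  # placeholder path
        embedder=my_embedder,
        llm=my_llm,
    )

    full_answer = ""
    async for chunk in brain.ask_streaming("What is this document about?"):
        # Each chunk is a ParsedRAGChunkResponse; chunk.answer holds the newly streamed text
        print(chunk.answer, end="", flush=True)
        full_answer += chunk.answer
    return full_answer


# asyncio.run(stream_answer(my_llm, my_embedder))
```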
This commit is contained in:
parent 6056450bd6 · commit 0658d4947c
@@ -1,7 +1,7 @@
 import asyncio
 import logging
 from pathlib import Path
-from typing import Any, Callable, Dict, Mapping, Self
+from typing import Any, AsyncGenerator, Callable, Dict, Mapping, Self
 from uuid import UUID, uuid4
 
 from langchain_core.documents import Document
@@ -12,7 +12,7 @@ from langchain_core.vectorstores import VectorStore
 from quivr_core.chat import ChatHistory
 from quivr_core.config import LLMEndpointConfig, RAGConfig
 from quivr_core.llm import LLMEndpoint
-from quivr_core.models import ParsedRAGResponse, SearchResult
+from quivr_core.models import ParsedRAGChunkResponse, ParsedRAGResponse, SearchResult
 from quivr_core.processor.default_parsers import DEFAULT_PARSERS
 from quivr_core.processor.processor_base import ProcessorBase
 from quivr_core.quivr_rag import QuivrQARAG
@@ -283,3 +283,36 @@ class Brain:
 
         # Save answer to the chat history
         return parsed_response
+
+    async def ask_streaming(
+        self,
+        question: str,
+        rag_config: RAGConfig | None = None,
+    ) -> AsyncGenerator[ParsedRAGChunkResponse, ParsedRAGChunkResponse]:
+        llm = self.llm
+
+        # If a different llm config was passed, it overrides the brain's own
+        if rag_config:
+            if rag_config.llm_config != self.llm.get_config():
+                llm = LLMEndpoint.from_config(config=rag_config.llm_config)
+        else:
+            rag_config = RAGConfig(llm_config=self.llm.get_config())
+
+        rag_pipeline = QuivrQARAG(
+            rag_config=rag_config, llm=llm, vector_store=self.vector_db
+        )
+
+        chat_history = self.default_chat
+
+        # TODO: List of files
+        full_answer = ""
+        async for response in rag_pipeline.answer_astream(question, chat_history, []):
+            # Format output: stream every chunk except the last one
+            if not response.last_chunk:
+                yield response
+                full_answer += response.answer
+
+        # TODO: add sources, metadata etc ...
+        chat_history.append(HumanMessage(content=question))
+        chat_history.append(AIMessage(content=full_answer))
+        yield response
@@ -58,19 +58,6 @@ def test_brain_from_files_success(fake_llm: LLMEndpoint, embedder, temp_data_fil
     assert len(brain.storage.get_files()) == 1
 
 
-@pytest.mark.asyncio
-async def test_brain_ask_txt(fake_llm: LLMEndpoint, embedder, temp_data_file, answers):
-    brain = await Brain.afrom_files(
-        name="test_brain", file_paths=[temp_data_file], embedder=embedder, llm=fake_llm
-    )
-    answer = brain.ask("question")
-
-    assert answer.answer == answers[1]
-    assert answer.metadata is not None
-    assert answer.metadata.sources is not None
-    assert answer.metadata.sources[0].metadata["source"] == str(temp_data_file)
-
-
 @pytest.mark.asyncio
 async def test_brain_from_langchain_docs(embedder):
     chunk = Document("content_1", metadata={"id": uuid4()})
@@ -110,3 +97,18 @@ async def test_brain_get_history(
     brain.ask("question")
 
     assert len(brain.default_chat) == 4
+
+
+@pytest.mark.asyncio
+async def test_brain_ask_streaming(
+    fake_llm: LLMEndpoint, embedder, temp_data_file, answers
+):
+    brain = await Brain.afrom_files(
+        name="test_brain", file_paths=[temp_data_file], embedder=embedder, llm=fake_llm
+    )
+
+    response = ""
+    async for chunk in brain.ask_streaming("question"):
+        response += chunk.answer
+
+    assert response == answers[1]
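If useful, the new test can be run on its own; a minimal sketch, assuming pytest and pytest-asyncio are installed (the `-k` filter matches the test name added above).

```python
import pytest

# Select only the new streaming test by name; pytest-asyncio runs the
# coroutine test marked with @pytest.mark.asyncio.
pytest.main(["-k", "test_brain_ask_streaming", "-q"])
```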