feat: add generate_answer function to support non streamed response for api brain (#1847)

Issue: https://github.com/StanGirard/quivr/issues/1841

Demo:


https://github.com/StanGirard/quivr/assets/63923024/56dfa85c-6c2e-4e55-8eda-4723e58ced1d
This commit is contained in:
Mamadou DICKO 2023-12-07 14:52:37 +01:00 committed by GitHub
parent ad44a8c18d
commit 7e6209a30b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 48 additions and 9 deletions

View File

@ -1,14 +1,11 @@
import asyncio
import json
from typing import Optional
from uuid import UUID
import nest_asyncio
from fastapi import HTTPException
from litellm import completion
from llm.knowledge_brain_qa import KnowledgeBrainQA
from llm.utils.call_brain_api import call_brain_api
from llm.utils.get_api_brain_definition_as_json_schema import (
get_api_brain_definition_as_json_schema,
)
from logger import get_logger
from modules.brain.service.brain_service import BrainService
from modules.chat.dto.chats import ChatQuestion
@ -16,6 +13,12 @@ from modules.chat.dto.inputs import CreateChatHistory
from modules.chat.dto.outputs import GetChatHistoryOutput
from modules.chat.service.chat_service import ChatService
from llm.knowledge_brain_qa import KnowledgeBrainQA
from llm.utils.call_brain_api import call_brain_api
from llm.utils.get_api_brain_definition_as_json_schema import (
get_api_brain_definition_as_json_schema,
)
brain_service = BrainService()
chat_service = ChatService()
@ -56,13 +59,15 @@ class APIBrainQA(
functions,
brain_id: UUID,
recursive_count=0,
should_log_steps=False,
):
if recursive_count > 5:
yield "🧠<Deciding what to do>🧠"
yield "The assistant is having issues and took more than 5 calls to the API. Please try again later or an other instruction."
return
if should_log_steps:
yield "🧠<Deciding what to do>🧠"
response = completion(
model=self.model,
temperature=self.temperature,
@ -98,6 +103,8 @@ class APIBrainQA(
except Exception:
arguments = {}
if should_log_steps:
yield f"🧠<Calling {brain_id} with arguments {arguments}>🧠"
try:
@ -125,6 +132,7 @@ class APIBrainQA(
functions=functions,
brain_id=brain_id,
recursive_count=recursive_count + 1,
should_log_steps=should_log_steps,
):
yield value
@ -140,7 +148,12 @@ class APIBrainQA(
yield "**...**"
break
async def generate_stream(self, chat_id: UUID, question: ChatQuestion):
async def generate_stream(
self,
chat_id: UUID,
question: ChatQuestion,
should_log_steps: Optional[bool] = True,
):
if not question.brain_id:
raise HTTPException(
status_code=400, detail="No brain id provided in the question"
@ -198,6 +211,7 @@ class APIBrainQA(
messages=messages,
functions=[get_api_brain_definition_as_json_schema(brain)],
brain_id=question.brain_id,
should_log_steps=should_log_steps,
):
streamed_chat_history.assistant = value
response_tokens.append(value)
@ -212,3 +226,27 @@ class APIBrainQA(
user_message=question.question,
assistant="".join(response_tokens),
)
def generate_answer(self, chat_id: UUID, question: ChatQuestion):
async def a_generate_answer():
api_brain_question_answer: GetChatHistoryOutput = None
async for answer in self.generate_stream(
chat_id, question, should_log_steps=False
):
answer = answer.split("data: ")[1]
answer_parsed: GetChatHistoryOutput = GetChatHistoryOutput(
**json.loads(answer)
)
if api_brain_question_answer is None:
api_brain_question_answer = answer_parsed
else:
api_brain_question_answer.assistant += answer_parsed.assistant
return api_brain_question_answer
nest_asyncio.apply()
loop = asyncio.get_event_loop()
result = loop.run_until_complete(a_generate_answer())
return result

View File

@ -5,6 +5,7 @@ litellm==1.7.7
openai==1.1.1
GitPython==3.1.36
pdf2image==1.16.3
nest_asyncio==1.5.6
pypdf==3.9.0
supabase==1.1.0
tiktoken==0.4.0