mirror of
https://github.com/StanGirard/quivr.git
synced 2024-11-13 11:12:23 +03:00
feat: add generate_answer function to support non streamed response for api brain (#1847)
Issue: https://github.com/StanGirard/quivr/issues/1841 Demo: https://github.com/StanGirard/quivr/assets/63923024/56dfa85c-6c2e-4e55-8eda-4723e58ced1d
This commit is contained in:
parent
ad44a8c18d
commit
7e6209a30b
@ -1,14 +1,11 @@
|
||||
import asyncio
|
||||
import json
|
||||
from typing import Optional
|
||||
from uuid import UUID
|
||||
|
||||
import nest_asyncio
|
||||
from fastapi import HTTPException
|
||||
from litellm import completion
|
||||
from llm.knowledge_brain_qa import KnowledgeBrainQA
|
||||
from llm.utils.call_brain_api import call_brain_api
|
||||
from llm.utils.get_api_brain_definition_as_json_schema import (
|
||||
get_api_brain_definition_as_json_schema,
|
||||
)
|
||||
from logger import get_logger
|
||||
from modules.brain.service.brain_service import BrainService
|
||||
from modules.chat.dto.chats import ChatQuestion
|
||||
@ -16,6 +13,12 @@ from modules.chat.dto.inputs import CreateChatHistory
|
||||
from modules.chat.dto.outputs import GetChatHistoryOutput
|
||||
from modules.chat.service.chat_service import ChatService
|
||||
|
||||
from llm.knowledge_brain_qa import KnowledgeBrainQA
|
||||
from llm.utils.call_brain_api import call_brain_api
|
||||
from llm.utils.get_api_brain_definition_as_json_schema import (
|
||||
get_api_brain_definition_as_json_schema,
|
||||
)
|
||||
|
||||
brain_service = BrainService()
|
||||
chat_service = ChatService()
|
||||
|
||||
@ -56,13 +59,15 @@ class APIBrainQA(
|
||||
functions,
|
||||
brain_id: UUID,
|
||||
recursive_count=0,
|
||||
should_log_steps=False,
|
||||
):
|
||||
if recursive_count > 5:
|
||||
yield "🧠<Deciding what to do>🧠"
|
||||
yield "The assistant is having issues and took more than 5 calls to the API. Please try again later or an other instruction."
|
||||
return
|
||||
|
||||
yield "🧠<Deciding what to do>🧠"
|
||||
if should_log_steps:
|
||||
yield "🧠<Deciding what to do>🧠"
|
||||
|
||||
response = completion(
|
||||
model=self.model,
|
||||
temperature=self.temperature,
|
||||
@ -98,7 +103,9 @@ class APIBrainQA(
|
||||
|
||||
except Exception:
|
||||
arguments = {}
|
||||
yield f"🧠<Calling {brain_id} with arguments {arguments}>🧠"
|
||||
|
||||
if should_log_steps:
|
||||
yield f"🧠<Calling {brain_id} with arguments {arguments}>🧠"
|
||||
|
||||
try:
|
||||
api_call_response = call_brain_api(
|
||||
@ -125,6 +132,7 @@ class APIBrainQA(
|
||||
functions=functions,
|
||||
brain_id=brain_id,
|
||||
recursive_count=recursive_count + 1,
|
||||
should_log_steps=should_log_steps,
|
||||
):
|
||||
yield value
|
||||
|
||||
@ -140,7 +148,12 @@ class APIBrainQA(
|
||||
yield "**...**"
|
||||
break
|
||||
|
||||
async def generate_stream(self, chat_id: UUID, question: ChatQuestion):
|
||||
async def generate_stream(
|
||||
self,
|
||||
chat_id: UUID,
|
||||
question: ChatQuestion,
|
||||
should_log_steps: Optional[bool] = True,
|
||||
):
|
||||
if not question.brain_id:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="No brain id provided in the question"
|
||||
@ -198,6 +211,7 @@ class APIBrainQA(
|
||||
messages=messages,
|
||||
functions=[get_api_brain_definition_as_json_schema(brain)],
|
||||
brain_id=question.brain_id,
|
||||
should_log_steps=should_log_steps,
|
||||
):
|
||||
streamed_chat_history.assistant = value
|
||||
response_tokens.append(value)
|
||||
@ -212,3 +226,27 @@ class APIBrainQA(
|
||||
user_message=question.question,
|
||||
assistant="".join(response_tokens),
|
||||
)
|
||||
|
||||
def generate_answer(self, chat_id: UUID, question: ChatQuestion):
|
||||
async def a_generate_answer():
|
||||
api_brain_question_answer: GetChatHistoryOutput = None
|
||||
|
||||
async for answer in self.generate_stream(
|
||||
chat_id, question, should_log_steps=False
|
||||
):
|
||||
answer = answer.split("data: ")[1]
|
||||
answer_parsed: GetChatHistoryOutput = GetChatHistoryOutput(
|
||||
**json.loads(answer)
|
||||
)
|
||||
if api_brain_question_answer is None:
|
||||
api_brain_question_answer = answer_parsed
|
||||
else:
|
||||
api_brain_question_answer.assistant += answer_parsed.assistant
|
||||
|
||||
return api_brain_question_answer
|
||||
|
||||
nest_asyncio.apply()
|
||||
loop = asyncio.get_event_loop()
|
||||
result = loop.run_until_complete(a_generate_answer())
|
||||
|
||||
return result
|
||||
|
@ -5,6 +5,7 @@ litellm==1.7.7
|
||||
openai==1.1.1
|
||||
GitPython==3.1.36
|
||||
pdf2image==1.16.3
|
||||
nest_asyncio==1.5.6
|
||||
pypdf==3.9.0
|
||||
supabase==1.1.0
|
||||
tiktoken==0.4.0
|
||||
|
Loading…
Reference in New Issue
Block a user