mirror of
https://github.com/StanGirard/quivr.git
synced 2024-11-26 03:15:19 +03:00
feat: add generate_answer function to support non streamed response for api brain (#1847)
Issue: https://github.com/StanGirard/quivr/issues/1841 Demo: https://github.com/StanGirard/quivr/assets/63923024/56dfa85c-6c2e-4e55-8eda-4723e58ced1d
This commit is contained in:
parent
ad44a8c18d
commit
7e6209a30b
@ -1,14 +1,11 @@
|
|||||||
|
import asyncio
|
||||||
import json
|
import json
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
|
|
||||||
|
import nest_asyncio
|
||||||
from fastapi import HTTPException
|
from fastapi import HTTPException
|
||||||
from litellm import completion
|
from litellm import completion
|
||||||
from llm.knowledge_brain_qa import KnowledgeBrainQA
|
|
||||||
from llm.utils.call_brain_api import call_brain_api
|
|
||||||
from llm.utils.get_api_brain_definition_as_json_schema import (
|
|
||||||
get_api_brain_definition_as_json_schema,
|
|
||||||
)
|
|
||||||
from logger import get_logger
|
from logger import get_logger
|
||||||
from modules.brain.service.brain_service import BrainService
|
from modules.brain.service.brain_service import BrainService
|
||||||
from modules.chat.dto.chats import ChatQuestion
|
from modules.chat.dto.chats import ChatQuestion
|
||||||
@ -16,6 +13,12 @@ from modules.chat.dto.inputs import CreateChatHistory
|
|||||||
from modules.chat.dto.outputs import GetChatHistoryOutput
|
from modules.chat.dto.outputs import GetChatHistoryOutput
|
||||||
from modules.chat.service.chat_service import ChatService
|
from modules.chat.service.chat_service import ChatService
|
||||||
|
|
||||||
|
from llm.knowledge_brain_qa import KnowledgeBrainQA
|
||||||
|
from llm.utils.call_brain_api import call_brain_api
|
||||||
|
from llm.utils.get_api_brain_definition_as_json_schema import (
|
||||||
|
get_api_brain_definition_as_json_schema,
|
||||||
|
)
|
||||||
|
|
||||||
brain_service = BrainService()
|
brain_service = BrainService()
|
||||||
chat_service = ChatService()
|
chat_service = ChatService()
|
||||||
|
|
||||||
@ -56,13 +59,15 @@ class APIBrainQA(
|
|||||||
functions,
|
functions,
|
||||||
brain_id: UUID,
|
brain_id: UUID,
|
||||||
recursive_count=0,
|
recursive_count=0,
|
||||||
|
should_log_steps=False,
|
||||||
):
|
):
|
||||||
if recursive_count > 5:
|
if recursive_count > 5:
|
||||||
yield "🧠<Deciding what to do>🧠"
|
|
||||||
yield "The assistant is having issues and took more than 5 calls to the API. Please try again later or an other instruction."
|
yield "The assistant is having issues and took more than 5 calls to the API. Please try again later or an other instruction."
|
||||||
return
|
return
|
||||||
|
|
||||||
yield "🧠<Deciding what to do>🧠"
|
if should_log_steps:
|
||||||
|
yield "🧠<Deciding what to do>🧠"
|
||||||
|
|
||||||
response = completion(
|
response = completion(
|
||||||
model=self.model,
|
model=self.model,
|
||||||
temperature=self.temperature,
|
temperature=self.temperature,
|
||||||
@ -98,7 +103,9 @@ class APIBrainQA(
|
|||||||
|
|
||||||
except Exception:
|
except Exception:
|
||||||
arguments = {}
|
arguments = {}
|
||||||
yield f"🧠<Calling {brain_id} with arguments {arguments}>🧠"
|
|
||||||
|
if should_log_steps:
|
||||||
|
yield f"🧠<Calling {brain_id} with arguments {arguments}>🧠"
|
||||||
|
|
||||||
try:
|
try:
|
||||||
api_call_response = call_brain_api(
|
api_call_response = call_brain_api(
|
||||||
@ -125,6 +132,7 @@ class APIBrainQA(
|
|||||||
functions=functions,
|
functions=functions,
|
||||||
brain_id=brain_id,
|
brain_id=brain_id,
|
||||||
recursive_count=recursive_count + 1,
|
recursive_count=recursive_count + 1,
|
||||||
|
should_log_steps=should_log_steps,
|
||||||
):
|
):
|
||||||
yield value
|
yield value
|
||||||
|
|
||||||
@ -140,7 +148,12 @@ class APIBrainQA(
|
|||||||
yield "**...**"
|
yield "**...**"
|
||||||
break
|
break
|
||||||
|
|
||||||
async def generate_stream(self, chat_id: UUID, question: ChatQuestion):
|
async def generate_stream(
|
||||||
|
self,
|
||||||
|
chat_id: UUID,
|
||||||
|
question: ChatQuestion,
|
||||||
|
should_log_steps: Optional[bool] = True,
|
||||||
|
):
|
||||||
if not question.brain_id:
|
if not question.brain_id:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=400, detail="No brain id provided in the question"
|
status_code=400, detail="No brain id provided in the question"
|
||||||
@ -198,6 +211,7 @@ class APIBrainQA(
|
|||||||
messages=messages,
|
messages=messages,
|
||||||
functions=[get_api_brain_definition_as_json_schema(brain)],
|
functions=[get_api_brain_definition_as_json_schema(brain)],
|
||||||
brain_id=question.brain_id,
|
brain_id=question.brain_id,
|
||||||
|
should_log_steps=should_log_steps,
|
||||||
):
|
):
|
||||||
streamed_chat_history.assistant = value
|
streamed_chat_history.assistant = value
|
||||||
response_tokens.append(value)
|
response_tokens.append(value)
|
||||||
@ -212,3 +226,27 @@ class APIBrainQA(
|
|||||||
user_message=question.question,
|
user_message=question.question,
|
||||||
assistant="".join(response_tokens),
|
assistant="".join(response_tokens),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def generate_answer(self, chat_id: UUID, question: ChatQuestion):
|
||||||
|
async def a_generate_answer():
|
||||||
|
api_brain_question_answer: GetChatHistoryOutput = None
|
||||||
|
|
||||||
|
async for answer in self.generate_stream(
|
||||||
|
chat_id, question, should_log_steps=False
|
||||||
|
):
|
||||||
|
answer = answer.split("data: ")[1]
|
||||||
|
answer_parsed: GetChatHistoryOutput = GetChatHistoryOutput(
|
||||||
|
**json.loads(answer)
|
||||||
|
)
|
||||||
|
if api_brain_question_answer is None:
|
||||||
|
api_brain_question_answer = answer_parsed
|
||||||
|
else:
|
||||||
|
api_brain_question_answer.assistant += answer_parsed.assistant
|
||||||
|
|
||||||
|
return api_brain_question_answer
|
||||||
|
|
||||||
|
nest_asyncio.apply()
|
||||||
|
loop = asyncio.get_event_loop()
|
||||||
|
result = loop.run_until_complete(a_generate_answer())
|
||||||
|
|
||||||
|
return result
|
||||||
|
@ -5,6 +5,7 @@ litellm==1.7.7
|
|||||||
openai==1.1.1
|
openai==1.1.1
|
||||||
GitPython==3.1.36
|
GitPython==3.1.36
|
||||||
pdf2image==1.16.3
|
pdf2image==1.16.3
|
||||||
|
nest_asyncio==1.5.6
|
||||||
pypdf==3.9.0
|
pypdf==3.9.0
|
||||||
supabase==1.1.0
|
supabase==1.1.0
|
||||||
tiktoken==0.4.0
|
tiktoken==0.4.0
|
||||||
|
Loading…
Reference in New Issue
Block a user