From 7e6209a30b72ea53b72ae2f4838fc2f0b846e52b Mon Sep 17 00:00:00 2001 From: Mamadou DICKO <63923024+mamadoudicko@users.noreply.github.com> Date: Thu, 7 Dec 2023 14:52:37 +0100 Subject: [PATCH] feat: add generate_answer function to support non streamed response for api brain (#1847) Issue: https://github.com/StanGirard/quivr/issues/1841 Demo: https://github.com/StanGirard/quivr/assets/63923024/56dfa85c-6c2e-4e55-8eda-4723e58ced1d --- backend/llm/api_brain_qa.py | 56 +++++++++++++++++++++++++++++++------ backend/requirements.txt | 1 + 2 files changed, 48 insertions(+), 9 deletions(-) diff --git a/backend/llm/api_brain_qa.py b/backend/llm/api_brain_qa.py index 6fb967e16..7255f8d78 100644 --- a/backend/llm/api_brain_qa.py +++ b/backend/llm/api_brain_qa.py @@ -1,14 +1,11 @@ +import asyncio import json from typing import Optional from uuid import UUID +import nest_asyncio from fastapi import HTTPException from litellm import completion -from llm.knowledge_brain_qa import KnowledgeBrainQA -from llm.utils.call_brain_api import call_brain_api -from llm.utils.get_api_brain_definition_as_json_schema import ( - get_api_brain_definition_as_json_schema, -) from logger import get_logger from modules.brain.service.brain_service import BrainService from modules.chat.dto.chats import ChatQuestion @@ -16,6 +13,12 @@ from modules.chat.dto.inputs import CreateChatHistory from modules.chat.dto.outputs import GetChatHistoryOutput from modules.chat.service.chat_service import ChatService +from llm.knowledge_brain_qa import KnowledgeBrainQA +from llm.utils.call_brain_api import call_brain_api +from llm.utils.get_api_brain_definition_as_json_schema import ( + get_api_brain_definition_as_json_schema, +) + brain_service = BrainService() chat_service = ChatService() @@ -56,13 +59,15 @@ class APIBrainQA( functions, brain_id: UUID, recursive_count=0, + should_log_steps=False, ): if recursive_count > 5: - yield "🧠🧠" yield "The assistant is having issues and took more than 5 calls to the API. Please try again later or an other instruction." return - yield "🧠🧠" + if should_log_steps: + yield "🧠🧠" + response = completion( model=self.model, temperature=self.temperature, @@ -98,7 +103,9 @@ class APIBrainQA( except Exception: arguments = {} - yield f"🧠🧠" + + if should_log_steps: + yield f"🧠🧠" try: api_call_response = call_brain_api( @@ -125,6 +132,7 @@ class APIBrainQA( functions=functions, brain_id=brain_id, recursive_count=recursive_count + 1, + should_log_steps=should_log_steps, ): yield value @@ -140,7 +148,12 @@ class APIBrainQA( yield "**...**" break - async def generate_stream(self, chat_id: UUID, question: ChatQuestion): + async def generate_stream( + self, + chat_id: UUID, + question: ChatQuestion, + should_log_steps: Optional[bool] = True, + ): if not question.brain_id: raise HTTPException( status_code=400, detail="No brain id provided in the question" @@ -198,6 +211,7 @@ class APIBrainQA( messages=messages, functions=[get_api_brain_definition_as_json_schema(brain)], brain_id=question.brain_id, + should_log_steps=should_log_steps, ): streamed_chat_history.assistant = value response_tokens.append(value) @@ -212,3 +226,27 @@ class APIBrainQA( user_message=question.question, assistant="".join(response_tokens), ) + + def generate_answer(self, chat_id: UUID, question: ChatQuestion): + async def a_generate_answer(): + api_brain_question_answer: GetChatHistoryOutput = None + + async for answer in self.generate_stream( + chat_id, question, should_log_steps=False + ): + answer = answer.split("data: ")[1] + answer_parsed: GetChatHistoryOutput = GetChatHistoryOutput( + **json.loads(answer) + ) + if api_brain_question_answer is None: + api_brain_question_answer = answer_parsed + else: + api_brain_question_answer.assistant += answer_parsed.assistant + + return api_brain_question_answer + + nest_asyncio.apply() + loop = asyncio.get_event_loop() + result = loop.run_until_complete(a_generate_answer()) + + return result diff --git a/backend/requirements.txt b/backend/requirements.txt index 6d349a5f9..3749ca103 100644 --- a/backend/requirements.txt +++ b/backend/requirements.txt @@ -5,6 +5,7 @@ litellm==1.7.7 openai==1.1.1 GitPython==3.1.36 pdf2image==1.16.3 +nest_asyncio==1.5.6 pypdf==3.9.0 supabase==1.1.0 tiktoken==0.4.0