import asyncio
import json
from typing import Optional
from uuid import UUID

import nest_asyncio
from fastapi import HTTPException
from litellm import completion

from llm.knowledge_brain_qa import KnowledgeBrainQA
from llm.utils.call_brain_api import call_brain_api
from llm.utils.get_api_brain_definition_as_json_schema import (
    get_api_brain_definition_as_json_schema,
)
from logger import get_logger
from modules.brain.service.brain_service import BrainService
from modules.chat.dto.chats import ChatQuestion
from modules.chat.dto.inputs import CreateChatHistory
from modules.chat.dto.outputs import GetChatHistoryOutput
from modules.chat.service.chat_service import ChatService
# Module-level singleton services shared by APIBrainQA below.
brain_service = BrainService()
chat_service = ChatService()

logger = get_logger(__name__)
class APIBrainQA(
    KnowledgeBrainQA,
):
    """Question-answering over an "API brain".

    Extends KnowledgeBrainQA: instead of answering purely from knowledge, the
    LLM is given one function (the brain's API definition as a JSON schema) and
    may call it via ``call_brain_api``; the function result is fed back into the
    conversation and completion is retried recursively.
    """

    # ID of the user on whose behalf API calls are made (required, see __init__).
    user_id: UUID

    def __init__(
        self,
        model: str,
        brain_id: str,
        chat_id: str,
        streaming: bool = False,
        prompt_id: Optional[UUID] = None,
        **kwargs,
    ):
        """Validate that a ``user_id`` kwarg is present, then defer to the base class.

        Raises:
            HTTPException: 400 if ``user_id`` is missing from kwargs.
        """
        user_id = kwargs.get("user_id")
        if not user_id:
            raise HTTPException(status_code=400, detail="Cannot find user id")

        super().__init__(
            model=model,
            brain_id=brain_id,
            chat_id=chat_id,
            streaming=streaming,
            prompt_id=prompt_id,
            **kwargs,
        )
        self.user_id = user_id

    async def make_completion(
        self,
        messages,
        functions,
        brain_id: UUID,
        recursive_count=0,
        should_log_steps=False,
    ):
        """Stream completion tokens, transparently executing function calls.

        Yields plain content tokens. When the model asks for a function call,
        the brain API is invoked, its answer is appended as a ``function``
        message, and the completion is retried (at most 5 recursions).
        When ``should_log_steps`` is true, progress markers wrapped in
        ``🧠<...>🧠`` are also yielded (filtered out later by generate_stream).

        Raises:
            HTTPException: 400 if the brain API call fails.
        """
        # Hard recursion cap so a misbehaving model cannot loop forever.
        if recursive_count > 5:
            yield "The assistant is having issues and took more than 5 calls to the API. Please try again later or an other instruction."
            return

        if should_log_steps:
            yield "🧠<Deciding what to do>🧠"

        response = completion(
            model=self.model,
            temperature=self.temperature,
            max_tokens=self.max_tokens,
            messages=messages,
            functions=functions,
            stream=True,
            function_call="auto",
        )

        # Accumulator for a streamed function call: litellm delivers the name
        # once and the JSON arguments in fragments across chunks.
        function_call = {
            "name": None,
            "arguments": "",
        }
        for chunk in response:
            finish_reason = chunk.choices[0].finish_reason
            if finish_reason == "stop":
                break
            if (
                "function_call" in chunk.choices[0].delta
                and chunk.choices[0].delta["function_call"]
            ):
                if chunk.choices[0].delta["function_call"].name:
                    function_call["name"] = chunk.choices[0].delta["function_call"].name
                if chunk.choices[0].delta["function_call"].arguments:
                    function_call["arguments"] += (
                        chunk.choices[0].delta["function_call"].arguments
                    )

            elif finish_reason == "function_call":
                try:
                    arguments = json.loads(function_call["arguments"])
                except Exception:
                    # Model emitted malformed JSON; call the API with no args
                    # rather than aborting the whole stream.
                    arguments = {}

                if should_log_steps:
                    yield f"🧠<Calling {brain_id} with arguments {arguments}>🧠"

                try:
                    api_call_response = call_brain_api(
                        brain_id=brain_id,
                        user_id=self.user_id,
                        arguments=arguments,
                    )
                except Exception as e:
                    raise HTTPException(
                        status_code=400,
                        detail=f"Error while calling API: {e}",
                    )

                function_name = function_call["name"]
                messages.append(
                    {
                        "role": "function",
                        "name": function_call["name"],
                        "content": f"The function {function_name} was called and gave The following answer:(data from function) {api_call_response} (end of data from function). Don't call this function again unless there was an error or extremely necessary and asked specifically by the user.",
                    }
                )
                # Re-enter completion with the function result in context.
                async for value in self.make_completion(
                    messages=messages,
                    functions=functions,
                    brain_id=brain_id,
                    recursive_count=recursive_count + 1,
                    should_log_steps=should_log_steps,
                ):
                    yield value
            else:
                if (
                    hasattr(chunk.choices[0], "delta")
                    and chunk.choices[0].delta
                    and hasattr(chunk.choices[0].delta, "content")
                ):
                    content = chunk.choices[0].delta.content
                    yield content
                else:  # pragma: no cover
                    yield "**...**"
                    break

    async def generate_stream(
        self,
        chat_id: UUID,
        question: ChatQuestion,
        should_log_steps: Optional[bool] = True,
    ):
        """Answer ``question`` as a server-sent-event stream.

        Builds the message history (system prompt + prior chat turns + the new
        question), streams tokens from make_completion, and yields each partial
        answer as a ``data: {json}`` SSE line. The final answer (with 🧠 log
        markers filtered out) is persisted back to the chat history.

        Raises:
            HTTPException: 400 if the question has no brain_id, 404 if the
                brain does not exist.
        """
        if not question.brain_id:
            raise HTTPException(
                status_code=400, detail="No brain id provided in the question"
            )

        brain = brain_service.get_brain_by_id(question.brain_id)
        if not brain:
            raise HTTPException(status_code=404, detail="Brain not found")

        prompt_content = "You are a helpful assistant that can access functions to help answer questions. If there are information missing in the question, you can ask follow up questions to get more information to the user. Once all the information is available, you can call the function to get the answer."

        if self.prompt_to_use:
            prompt_content += self.prompt_to_use.content

        messages = [{"role": "system", "content": prompt_content}]

        history = chat_service.get_chat_history(self.chat_id)

        for message in history:
            formatted_message = [
                {"role": "user", "content": message.user_message},
                {"role": "assistant", "content": message.assistant},
            ]
            messages.extend(formatted_message)

        messages.append({"role": "user", "content": question.question})

        # Persist a placeholder history entry first so we have a message_id
        # to stream against and to update once the answer is complete.
        streamed_chat_history = chat_service.update_chat_history(
            CreateChatHistory(
                **{
                    "chat_id": chat_id,
                    "user_message": question.question,
                    "assistant": "",
                    "brain_id": question.brain_id,
                    "prompt_id": self.prompt_to_use_id,
                }
            )
        )

        streamed_chat_history = GetChatHistoryOutput(
            **{
                "chat_id": str(chat_id),
                "message_id": streamed_chat_history.message_id,
                "message_time": streamed_chat_history.message_time,
                "user_message": question.question,
                "assistant": "",
                "prompt_title": self.prompt_to_use.title
                if self.prompt_to_use
                else None,
                "brain_name": brain.name if brain else None,
            }
        )

        response_tokens = []

        async for value in self.make_completion(
            messages=messages,
            functions=[get_api_brain_definition_as_json_schema(brain)],
            brain_id=question.brain_id,
            should_log_steps=should_log_steps,
        ):
            streamed_chat_history.assistant = value
            response_tokens.append(value)

            yield f"data: {json.dumps(streamed_chat_history.dict())}"

        # Drop the 🧠<...>🧠 step-log markers before persisting the answer.
        response_tokens = [
            token
            for token in response_tokens
            if not token.startswith("🧠<") and not token.endswith(">🧠")
        ]

        chat_service.update_message_by_id(
            message_id=str(streamed_chat_history.message_id),
            user_message=question.question,
            assistant="".join(response_tokens),
        )

    def generate_answer(self, chat_id: UUID, question: ChatQuestion):
        """Synchronous wrapper: drain generate_stream and return one answer.

        Concatenates all streamed assistant fragments into a single
        GetChatHistoryOutput. Uses nest_asyncio so it can run even when an
        event loop is already active (e.g. inside FastAPI).

        Returns:
            GetChatHistoryOutput | None: the aggregated answer, or None if the
            stream produced no chunks.
        """

        async def a_generate_answer():
            api_brain_question_answer: GetChatHistoryOutput = None
            async for answer in self.generate_stream(
                chat_id, question, should_log_steps=False
            ):
                # Each streamed item is an SSE line: "data: {json payload}".
                answer = answer.split("data: ")[1]
                answer_parsed: GetChatHistoryOutput = GetChatHistoryOutput(
                    **json.loads(answer)
                )
                if api_brain_question_answer is None:
                    api_brain_question_answer = answer_parsed
                else:
                    api_brain_question_answer.assistant += answer_parsed.assistant
            return api_brain_question_answer

        nest_asyncio.apply()
        loop = asyncio.get_event_loop()
        result = loop.run_until_complete(a_generate_answer())
        return result