import json
from typing import Optional
from uuid import UUID

from fastapi import HTTPException
from litellm import completion

from llm.qa_base import QABaseBrainPicking
from llm.utils.call_brain_api import call_brain_api
from llm.utils.get_api_brain_definition_as_json_schema import (
    get_api_brain_definition_as_json_schema,
)
from logger import get_logger
from models.chats import ChatQuestion
from models.databases.supabase.chats import CreateChatHistory
from repository.brain.get_brain_by_id import get_brain_by_id
from repository.chat.get_chat_history import GetChatHistoryOutput, get_chat_history
from repository.chat.update_chat_history import update_chat_history
from repository.chat.update_message_by_id import update_message_by_id

logger = get_logger(__name__)
class APIBrainQA(
    QABaseBrainPicking,
):
    """Question-answering brain that can call an external API via LLM function calling.

    Streams completions from litellm; when the model decides to invoke the
    brain's API function, the call is executed through ``call_brain_api`` and
    its result is fed back into the conversation so the model can answer.
    """

    # Identifier of the user on whose behalf the brain API is called.
    user_id: UUID

    def __init__(
        self,
        model: str,
        brain_id: str,
        chat_id: str,
        streaming: bool = False,
        prompt_id: Optional[UUID] = None,
        **kwargs,
    ):
        """Validate the caller-supplied ``user_id`` then defer to the base class.

        Raises:
            HTTPException: 400 when ``user_id`` is missing from ``kwargs``.
        """
        user_id = kwargs.get("user_id")
        if not user_id:
            raise HTTPException(status_code=400, detail="Cannot find user id")

        super().__init__(
            model=model,
            brain_id=brain_id,
            chat_id=chat_id,
            streaming=streaming,
            prompt_id=prompt_id,
            **kwargs,
        )
        self.user_id = user_id

    async def make_completion(
        self,
        messages,
        functions,
        brain_id: UUID,
    ):
        """Yield completion tokens, transparently executing model function calls.

        Status updates are yielded as strings wrapped in 🧠<...>🧠 markers so
        the consumer can filter them out of the persisted answer.

        Raises:
            HTTPException: 400 when the brain API call itself fails.
        """
        yield "🧠<Deciding what to do>🧠"
        response = completion(
            model=self.model,
            temperature=self.temperature,
            max_tokens=self.max_tokens,
            messages=messages,
            functions=functions,
            stream=True,
            function_call="auto",
        )

        # The function-call name and JSON arguments arrive spread over many
        # stream chunks; accumulate them here until finish_reason fires.
        function_call = {
            "name": None,
            "arguments": "",
        }
        for chunk in response:
            finish_reason = chunk.choices[0].finish_reason
            if finish_reason == "stop":
                break

            delta = chunk.choices[0].delta
            if "function_call" in delta and delta["function_call"]:
                if "name" in delta["function_call"]:
                    function_call["name"] = delta["function_call"]["name"]
                if "arguments" in delta["function_call"]:
                    function_call["arguments"] += delta["function_call"]["arguments"]

            elif finish_reason == "function_call":
                try:
                    logger.info(f"Function call: {function_call}")
                    arguments = json.loads(function_call["arguments"])
                except Exception:
                    # The model sometimes emits malformed JSON; log it and fall
                    # back to an empty argument set rather than abort the stream.
                    logger.exception("Failed to parse function call arguments")
                    arguments = {}

                yield f"🧠<Calling {brain_id} with arguments {arguments}>🧠"

                try:
                    api_call_response = call_brain_api(
                        brain_id=brain_id,
                        user_id=self.user_id,
                        arguments=arguments,
                    )
                except Exception as e:
                    raise HTTPException(
                        status_code=400,
                        detail=f"Error while calling API: {e}",
                    )

                messages.append(
                    {
                        "role": "function",
                        "name": str(brain_id),
                        "content": api_call_response,
                    }
                )
                # Recurse so the model can produce an answer that uses the
                # API response we just appended to the conversation.
                async for value in self.make_completion(
                    messages=messages,
                    functions=functions,
                    brain_id=brain_id,
                ):
                    yield value

            else:
                if (
                    hasattr(chunk.choices[0], "delta")
                    and chunk.choices[0].delta
                    and hasattr(chunk.choices[0].delta, "content")
                ):
                    content = chunk.choices[0].delta.content
                    yield content
                else:
                    # No delta content left: emit a placeholder and stop.
                    yield "**...**"
                    break

    async def generate_stream(self, chat_id: UUID, question: ChatQuestion):
        """Stream an answer to ``question`` as ``data: {...}`` SSE lines.

        Persists the question immediately, streams tokens into the chat
        history output, then stores the final assistant message (with the
        🧠<...>🧠 status markers stripped).

        Raises:
            HTTPException: 400 when the question carries no brain id,
                404 when the brain cannot be found.
        """
        if not question.brain_id:
            raise HTTPException(
                status_code=400, detail="No brain id provided in the question"
            )

        brain = get_brain_by_id(question.brain_id)
        if not brain:
            raise HTTPException(status_code=404, detail="Brain not found")

        prompt_content = "You are a helpful assistant which can call APIs. Feel free to call the API when you need to. Don't force APIs call, do it when necessary. If it seems like you should call the API and there are missing parameters, ask user for them."

        if self.prompt_to_use:
            prompt_content += self.prompt_to_use.content

        messages = [{"role": "system", "content": prompt_content}]

        history = get_chat_history(self.chat_id)
        for message in history:
            formatted_message = [
                {"role": "user", "content": message.user_message},
                {"role": "assistant", "content": message.assistant},
            ]
            messages.extend(formatted_message)

        messages.append({"role": "user", "content": question.question})

        # Persist the question right away so a message id exists while
        # tokens are still streaming.
        streamed_chat_history = update_chat_history(
            CreateChatHistory(
                **{
                    "chat_id": chat_id,
                    "user_message": question.question,
                    "assistant": "",
                    "brain_id": question.brain_id,
                    "prompt_id": self.prompt_to_use_id,
                }
            )
        )

        streamed_chat_history = GetChatHistoryOutput(
            **{
                "chat_id": str(chat_id),
                "message_id": streamed_chat_history.message_id,
                "message_time": streamed_chat_history.message_time,
                "user_message": question.question,
                "assistant": "",
                "prompt_title": self.prompt_to_use.title
                if self.prompt_to_use
                else None,
                "brain_name": brain.name if brain else None,
            }
        )

        response_tokens = []
        async for value in self.make_completion(
            messages=messages,
            functions=[get_api_brain_definition_as_json_schema(brain)],
            brain_id=question.brain_id,
        ):
            streamed_chat_history.assistant = value
            response_tokens.append(value)
            yield f"data: {json.dumps(streamed_chat_history.dict())}"

        # Strip the 🧠<...>🧠 status markers before saving the final answer.
        response_tokens = [
            token
            for token in response_tokens
            if not token.startswith("🧠<") and not token.endswith(">🧠")
        ]
        update_message_by_id(
            message_id=str(streamed_chat_history.message_id),
            user_message=question.question,
            assistant="".join(response_tokens),
        )