2023-11-08 18:07:21 +03:00
import json
from typing import Optional
from uuid import UUID
2024-01-26 02:56:46 +03:00
import jq
import requests
2023-11-09 18:58:51 +03:00
from fastapi import HTTPException
2023-11-08 18:07:21 +03:00
from litellm import completion
2023-12-07 16:52:37 +03:00
from llm . utils . call_brain_api import call_brain_api
from llm . utils . get_api_brain_definition_as_json_schema import (
get_api_brain_definition_as_json_schema ,
)
2023-12-15 13:43:41 +03:00
from logger import get_logger
2024-02-06 08:02:46 +03:00
from modules . brain . knowledge_brain_qa import KnowledgeBrainQA
from modules . brain . qa_interface import QAInterface
2023-12-15 13:43:41 +03:00
from modules . brain . service . brain_service import BrainService
from modules . chat . dto . chats import ChatQuestion
from modules . chat . dto . inputs import CreateChatHistory
from modules . chat . dto . outputs import GetChatHistoryOutput
from modules . chat . service . chat_service import ChatService
2023-12-07 16:52:37 +03:00
2023-12-01 00:29:28 +03:00
# Module-level service singletons used by the APIBrainQA methods below for
# brain lookups and chat-history persistence, plus this module's logger.
brain_service, chat_service = BrainService(), ChatService()

logger = get_logger(__name__)
2023-11-08 18:07:21 +03:00
2023-12-04 20:38:54 +03:00
2024-01-26 02:56:46 +03:00
class UUIDEncoder(json.JSONEncoder):
    """JSON encoder that also knows how to serialise ``uuid.UUID`` values.

    UUIDs are emitted as their canonical string form; every other
    unsupported type is delegated to ``json.JSONEncoder.default``, which
    raises ``TypeError``.
    """

    def default(self, obj):
        if not isinstance(obj, UUID):
            return super().default(obj)
        # A UUID is rendered as its plain string value.
        return str(obj)
2023-12-11 18:46:45 +03:00
# Markers that make_completion wraps around a raw API payload so that
# generate_stream can tell it apart from ordinary completion tokens.
_RAW_RESPONSE_PREFIX = "````raw_response:"
_RAW_RESPONSE_SUFFIX = "````"


class APIBrainQA(KnowledgeBrainQA, QAInterface):
    """Question answering for a brain backed by an external HTTP API.

    The brain's API definition is offered to the LLM as a callable function
    (built by ``get_api_brain_definition_as_json_schema``). When the model
    calls it, ``call_brain_api`` is executed and the result is fed back to
    the model. In raw mode on the streaming path, the API payload is instead
    returned to the client directly, optionally filtered by a jq program.
    """

    user_id: UUID
    # When True, the streaming path returns the API payload itself instead
    # of letting the model rephrase it.
    raw: bool = False
    # Optional jq program applied to a raw JSON payload before streaming it.
    jq_instructions: Optional[str] = None

    def __init__(
        self,
        model: str,
        brain_id: str,
        chat_id: str,
        streaming: bool = False,
        prompt_id: Optional[UUID] = None,
        raw: bool = False,
        jq_instructions: Optional[str] = None,
        **kwargs,
    ):
        """Create the QA object.

        ``kwargs`` must contain a truthy ``user_id``. All ``kwargs`` are also
        forwarded to the parent constructor.

        Raises:
            HTTPException: 400 when no ``user_id`` is supplied.
        """
        user_id = kwargs.get("user_id")
        if not user_id:
            raise HTTPException(status_code=400, detail="Cannot find user id")

        super().__init__(
            model=model,
            brain_id=brain_id,
            chat_id=chat_id,
            streaming=streaming,
            prompt_id=prompt_id,
            **kwargs,
        )
        self.user_id = user_id
        self.raw = raw
        self.jq_instructions = jq_instructions

    def get_api_call_response_as_text(
        self, method, api_url, params, search_params, secrets
    ) -> Optional[str]:
        """Perform an HTTP request and return the response body as text.

        Args:
            method: HTTP verb (case-insensitive). For GET/DELETE, ``params``
                is sent as query parameters. For POST/PUT/PATCH, ``params``
                is sent as a JSON body.
            api_url: Endpoint URL.
            params: Query parameters or JSON body, depending on ``method``.
            search_params: Extra query-string parameters appended to the URL.
            secrets: Mapping copied verbatim into the request headers.

        Returns:
            The response body, or ``None`` when the method is unsupported or
            the request raised. Both failure cases are logged.
        """
        from urllib.parse import urlencode

        if isinstance(method, str):
            method = method.upper()

        headers = dict(secrets or {})

        api_url_with_search_params = api_url
        if search_params:
            # urlencode escapes '&', '=', spaces, etc. in keys and values, so
            # they cannot corrupt the query string. Use '&' if the URL
            # already has a query string.
            separator = "&" if "?" in api_url else "?"
            api_url_with_search_params += separator + urlencode(search_params)

        if method in ("GET", "DELETE"):
            body_kwargs = {"params": params or None}
        elif method in ("POST", "PUT", "PATCH"):
            body_kwargs = {"json": params or None}
        else:
            logger.error(f"Error calling API: Invalid method: {method}")
            return None

        try:
            response = requests.request(
                method,
                url=api_url_with_search_params,
                headers=headers or None,
                **body_kwargs,
            )
        except Exception as e:
            logger.error(f"Error calling API: {e}")
            return None
        return response.text

    def log_steps(self, message: str, type: str):
        """Append a numbered step to ``self.metadata["api"]["steps"]``.

        Args:
            message: Human-readable description of the step.
            type: Step category, e.g. ``"info"`` or ``"error"``. The name
                shadows the builtin but is kept for keyword callers.
        """
        steps = self.metadata.setdefault("api", {}).setdefault("steps", [])
        steps.append({"number": len(steps), "type": type, "message": message})

    async def make_completion(
        self,
        messages,
        functions,
        brain_id: UUID,
        recursive_count=0,
        should_log_steps=True,
    ):
        """Stream a completion, executing the brain's API function when asked.

        This is an async generator that yields text chunks. When the model
        requests a function call, the brain API is called, its answer is
        appended to ``messages``, and the method recurses. In raw mode it
        instead yields a single marker-wrapped payload and stops. After
        more than 5 recursive calls it logs an error step and stops.

        Yields:
            str: completion text, or
            ``"````raw_response: <payload> ````"`` in raw mode.
        """
        if recursive_count > 5:
            self.log_steps(
                "The assistant is having issues and took more than 5 calls to "
                "the API. Please try again later or an other instruction.",
                "error",
            )
            return

        # Record whether raw mode is active for this answer.
        self.metadata.setdefault("api", {})["raw_enabled"] = self.raw

        response = completion(
            model=self.model,
            temperature=self.temperature,
            max_tokens=self.max_tokens,
            messages=messages,
            functions=functions,
            stream=True,
            function_call="auto",
        )

        # The function name and its JSON arguments arrive split across chunks.
        function_call = {"name": None, "arguments": ""}
        for chunk in response:
            choice = chunk.choices[0]
            finish_reason = choice.finish_reason
            if finish_reason == "stop":
                self.log_steps("Quivr has finished", "info")
                break

            delta = choice.delta
            if "function_call" in delta and delta["function_call"]:
                if delta["function_call"].name:
                    function_call["name"] = delta["function_call"].name
                if delta["function_call"].arguments:
                    function_call["arguments"] += delta["function_call"].arguments
            elif finish_reason == "function_call":
                try:
                    arguments = json.loads(function_call["arguments"])
                except (TypeError, ValueError):
                    # `arguments` is unbound here, so log the raw text instead.
                    self.log_steps(
                        f"Issues with {function_call['arguments']}", "error"
                    )
                    arguments = {}

                self.log_steps(f"Calling {brain_id} with arguments {arguments}", "info")

                try:
                    api_call_response = call_brain_api(
                        brain_id=brain_id,
                        user_id=self.user_id,
                        arguments=arguments,
                    )
                except Exception as e:
                    # Hand the error to the model so it can report it.
                    logger.error(f"Error while calling API: {e}")
                    api_call_response = f"Error while calling API: {e}"

                function_name = function_call["name"]
                self.log_steps("Quivr has called the API", "info")
                messages.append(
                    {
                        "role": "function",
                        "name": function_name,
                        "content": f"The function {function_name} was called and gave The following answer:(data from function) {api_call_response} (end of data from function). Don't call this function again unless there was an error or extremely necessary and asked specifically by the user. If an error, display it to the user in raw.",
                    }
                )

                try:
                    raw_response = json.loads(api_call_response)
                except (TypeError, ValueError):
                    # Error strings and non-JSON bodies are stored verbatim
                    # instead of aborting the stream.
                    raw_response = api_call_response
                self.metadata["api"]["raw_response"] = raw_response

                if self.raw:
                    # Wrapped in markers so generate_stream can detect it.
                    yield f"{_RAW_RESPONSE_PREFIX} {api_call_response} {_RAW_RESPONSE_SUFFIX}"
                    return

                async for value in self.make_completion(
                    messages=messages,
                    functions=functions,
                    brain_id=brain_id,
                    recursive_count=recursive_count + 1,
                    should_log_steps=should_log_steps,
                ):
                    yield value
            else:
                if delta and hasattr(delta, "content"):
                    content = delta.content
                    # Some chunks (e.g. a final finish_reason="length" chunk)
                    # carry no text. Yielding None would break consumers that
                    # call string methods on every value.
                    if content is not None:
                        yield content
                else:  # pragma: no cover
                    yield "**...**"
                    break

    def _build_messages(self, question: ChatQuestion) -> list:
        """Return the system prompt, prior turns of ``self.chat_id``, and ``question``."""
        prompt_content = (
            "You are a helpful assistant that can access functions to help "
            "answer questions. If there are information missing in the "
            "question, you can ask follow up questions to get more "
            "information to the user. Once all the information is available, "
            "you can call the function to get the answer."
        )
        if self.prompt_to_use:
            prompt_content += self.prompt_to_use.content

        messages = [{"role": "system", "content": prompt_content}]
        for message in chat_service.get_chat_history(self.chat_id):
            messages.append({"role": "user", "content": message.user_message})
            messages.append({"role": "assistant", "content": message.assistant})
        messages.append({"role": "user", "content": question.question})
        return messages

    def _extract_raw_value(self, value: str):
        """Strip the raw-response markers and apply ``jq_instructions`` if set.

        If the payload is not JSON, or the jq program fails or produces no
        output, the error is logged and the unfiltered payload is returned.
        This avoids aborting a stream that is already in progress.
        """
        raw_value = (
            value.removeprefix(_RAW_RESPONSE_PREFIX)
            .removesuffix(_RAW_RESPONSE_SUFFIX)
            .strip()
        )
        logger.info(f"Raw response: {raw_value}")
        if not self.jq_instructions:
            return raw_value
        try:
            return (
                jq.compile(self.jq_instructions)
                .input_value(json.loads(raw_value))
                .first()
            )
        except (ValueError, StopIteration) as e:
            # StopIteration must be caught here. If it escaped into the
            # async generator, Python would turn it into a RuntimeError.
            logger.error(
                f"Could not apply jq instructions {self.jq_instructions!r}: {e}"
            )
            return raw_value

    async def generate_stream(
        self,
        chat_id: UUID,
        question: ChatQuestion,
        save_answer: bool = True,
        should_log_steps: Optional[bool] = True,
    ):
        """Stream an answer as server-sent-event ``data:`` lines.

        When ``save_answer`` is True, an empty history entry is created
        first and filled with the full answer once streaming ends.

        Raises:
            HTTPException: 404 when the brain does not exist.
        """
        brain = brain_service.get_brain_by_id(self.brain_id)
        if not brain:
            raise HTTPException(status_code=404, detail="Brain not found")

        messages = self._build_messages(question)

        message_id = None
        message_time = None
        if save_answer:
            new_chat = chat_service.update_chat_history(
                CreateChatHistory(
                    **{
                        "chat_id": chat_id,
                        "user_message": question.question,
                        "assistant": "",
                        "brain_id": self.brain_id,
                        "prompt_id": self.prompt_to_use_id,
                    }
                )
            )
            message_id = new_chat.message_id
            message_time = new_chat.message_time

        streamed_chat_history = GetChatHistoryOutput(
            **{
                "chat_id": str(chat_id),
                "message_id": message_id,
                "message_time": message_time,
                "user_message": question.question,
                "assistant": "",
                "prompt_title": self.prompt_to_use.title
                if self.prompt_to_use
                else None,
                "brain_name": brain.name,
                "brain_id": str(self.brain_id),
                "metadata": self.metadata,
            }
        )

        response_tokens = []
        async for value in self.make_completion(
            messages=messages,
            functions=[get_api_brain_definition_as_json_schema(brain)],
            brain_id=self.brain_id,
            should_log_steps=should_log_steps,
        ):
            if value is None:
                continue
            if isinstance(value, str) and value.startswith(_RAW_RESPONSE_PREFIX):
                value = self._extract_raw_value(value)
            streamed_chat_history.assistant = value
            response_tokens.append(value)
            yield f"data: {json.dumps(streamed_chat_history.dict(), cls=UUIDEncoder)}"

        if save_answer:
            chat_service.update_message_by_id(
                message_id=str(streamed_chat_history.message_id),
                user_message=question.question,
                # jq may return non-string values, hence str().
                assistant="".join(str(token) for token in response_tokens),
                metadata=self.metadata,
            )

    def make_completion_without_streaming(
        self,
        messages,
        functions,
        brain_id: UUID,
        recursive_count=0,
        should_log_steps=False,
        raw: bool = False,
    ):
        """Non-streaming counterpart of ``make_completion``.

        ``raw`` is accepted (``generate_answer`` passes it) and threaded
        through the recursion. NOTE(review): this path does not yet
        short-circuit to the raw API payload the way ``make_completion`` does.

        Returns:
            The final assistant message object, or ``None`` when the
            recursion limit is exceeded or the model stops for a reason other
            than ``"stop"``.

        Raises:
            HTTPException: 400 when calling the brain API fails.
        """
        if recursive_count > 5:
            logger.error(
                "The assistant is having issues and took more than 5 calls to "
                "the API. Please try again later or an other instruction."
            )
            return None

        if should_log_steps:
            logger.info("🧠<Deciding what to do>🧠")

        response = completion(
            model=self.model,
            temperature=self.temperature,
            max_tokens=self.max_tokens,
            messages=messages,
            functions=functions,
            stream=False,
            function_call="auto",
        )

        response_message = response.choices[0].message
        finish_reason = response.choices[0].finish_reason

        if finish_reason == "function_call":
            function_call = response_message.function_call
            try:
                arguments = json.loads(function_call.arguments)
            except (TypeError, ValueError):
                arguments = {}

            if should_log_steps:
                self.log_steps(f"Calling {brain_id} with arguments {arguments}", "info")

            try:
                api_call_response = call_brain_api(
                    brain_id=brain_id,
                    user_id=self.user_id,
                    arguments=arguments,
                )
            except Exception as e:
                raise HTTPException(
                    status_code=400,
                    detail=f"Error while calling API: {e}",
                ) from e

            function_name = function_call.name
            messages.append(
                {
                    "role": "function",
                    "name": function_name,
                    "content": f"The function {function_name} was called and gave The following answer:(data from function) {api_call_response} (end of data from function). Don't call this function again unless there was an error or extremely necessary and asked specifically by the user.",
                }
            )

            return self.make_completion_without_streaming(
                messages=messages,
                functions=functions,
                brain_id=brain_id,
                recursive_count=recursive_count + 1,
                should_log_steps=should_log_steps,
                raw=raw,
            )

        if finish_reason == "stop":
            return response_message

        logger.warning("Never ending completion")
        return None

    def generate_answer(
        self,
        chat_id: UUID,
        question: ChatQuestion,
        save_answer: bool = True,
        raw: bool = True,
    ):
        """Produce a complete (non-streamed) answer, optionally saving it.

        Raises:
            HTTPException: 400 when no brain id is set, 404 when the brain
                does not exist, 500 when the model produced no answer.
        """
        if not self.brain_id:
            raise HTTPException(
                status_code=400, detail="No brain id provided in the question"
            )

        brain = brain_service.get_brain_by_id(self.brain_id)
        if not brain:
            raise HTTPException(status_code=404, detail="Brain not found")

        messages = self._build_messages(question)
        response = self.make_completion_without_streaming(
            messages=messages,
            functions=[get_api_brain_definition_as_json_schema(brain)],
            brain_id=self.brain_id,
            should_log_steps=False,
            raw=raw,
        )
        if response is None:
            # Recursion limit hit, or the model stopped for a reason other
            # than "stop". Previously this surfaced as an AttributeError.
            raise HTTPException(
                status_code=500,
                detail="The assistant could not produce an answer. Please try again later.",
            )
        answer = response.content

        if save_answer:
            new_chat = chat_service.update_chat_history(
                CreateChatHistory(
                    **{
                        "chat_id": chat_id,
                        "user_message": question.question,
                        "assistant": answer,
                        "brain_id": self.brain_id,
                        "prompt_id": self.prompt_to_use_id,
                    }
                )
            )
            return GetChatHistoryOutput(
                **{
                    "chat_id": chat_id,
                    "user_message": question.question,
                    "assistant": answer,
                    "message_time": new_chat.message_time,
                    "prompt_title": self.prompt_to_use.title
                    if self.prompt_to_use
                    else None,
                    "brain_name": brain.name,
                    "message_id": new_chat.message_id,
                    "metadata": self.metadata,
                    "brain_id": str(self.brain_id),
                }
            )

        return GetChatHistoryOutput(
            **{
                "chat_id": chat_id,
                "user_message": question.question,
                "assistant": answer,
                # NOTE(review): placeholder timestamp for unsaved answers;
                # confirm whether GetChatHistoryOutput accepts None here
                # (generate_stream passes None in the same situation).
                "message_time": "123",
                "prompt_title": None,
                "brain_name": brain.name,
                "message_id": None,
                "metadata": self.metadata,
                "brain_id": str(self.brain_id),
            }
        )