2023-08-18 11:32:22 +03:00
import asyncio
import json
2023-08-22 15:23:27 +03:00
from typing import AsyncIterable , Awaitable , List , Optional
2023-08-18 11:32:22 +03:00
from uuid import UUID
from langchain . callbacks . streaming_aiter import AsyncIteratorCallbackHandler
from langchain . chains import LLMChain
2023-08-22 15:23:27 +03:00
from langchain . chat_models import ChatOpenAI
2023-08-25 15:03:57 +03:00
from langchain . chat_models . base import BaseChatModel
2023-08-18 11:32:22 +03:00
from langchain . prompts . chat import (
ChatPromptTemplate ,
HumanMessagePromptTemplate ,
)
2023-08-22 15:23:27 +03:00
from logger import get_logger
from models . chats import ChatQuestion
2023-08-18 11:32:22 +03:00
from models . databases . supabase . chats import CreateChatHistory
2023-08-22 15:23:27 +03:00
from models . prompt import Prompt
from pydantic import BaseModel
2023-08-21 13:25:16 +03:00
from repository . chat import (
2023-08-22 15:23:27 +03:00
GetChatHistoryOutput ,
2023-08-21 13:25:16 +03:00
format_chat_history ,
2023-08-22 15:23:27 +03:00
format_history_to_openai_mesages ,
2023-08-21 13:25:16 +03:00
get_chat_history ,
update_chat_history ,
2023-08-22 15:23:27 +03:00
update_message_by_id ,
2023-08-21 13:25:16 +03:00
)
2023-08-18 11:32:22 +03:00
2023-08-22 15:23:27 +03:00
from llm . utils . get_prompt_to_use import get_prompt_to_use
from llm . utils . get_prompt_to_use_id import get_prompt_to_use_id
2023-08-18 11:32:22 +03:00
# Module logger, obtained through the project's logging helper.
logger = get_logger(__name__)

# Fallback system prompt, used when no custom prompt is resolved for a chat.
SYSTEM_MESSAGE = (
    "Your name is Quivr. You're a helpful assistant. "
    "If you don't know the answer, just say that you don't know, "
    "don't try to make up an answer."
)
2023-08-18 11:32:22 +03:00
class HeadlessQA(BaseModel):
    """Answer chat questions with a bare chat model (no brain attached).

    The model sees the formatted history of ``self.chat_id``, a system
    prompt and the new question. The system prompt is the content of the
    prompt resolved from ``prompt_id``, falling back to ``SYSTEM_MESSAGE``.
    Every exchange is persisted through ``update_chat_history``.
    """

    model: str
    temperature: float = 0.0
    # NOTE(review): declared but not forwarded to the LLM; wiring it in
    # would cap answer length, so confirm intent before using it.
    max_tokens: int = 256
    user_openai_api_key: Optional[str] = None
    openai_api_key: Optional[str] = None
    streaming: bool = False
    chat_id: str
    callbacks: Optional[List[AsyncIteratorCallbackHandler]] = None
    prompt_id: Optional[UUID] = None

    def _determine_api_key(self, openai_api_key, user_openai_api_key):
        """Return ``user_openai_api_key`` if it is set, else ``openai_api_key``."""
        if user_openai_api_key is not None:
            return user_openai_api_key
        return openai_api_key

    def _determine_streaming(self, streaming: bool) -> bool:
        """Return the requested streaming flag unchanged.

        No model-specific capability check is performed here.
        """
        return streaming

    def _determine_callback_array(
        self, streaming
    ) -> List[AsyncIteratorCallbackHandler]:
        """Return ``[AsyncIteratorCallbackHandler()]`` when streaming, else ``[]``."""
        if streaming:
            return [AsyncIteratorCallbackHandler()]
        return []

    def __init__(self, **data):
        super().__init__(**data)
        # Trace construction via the logger instead of printing to stdout.
        logger.debug("in HeadlessQA")
        self.openai_api_key = self._determine_api_key(
            self.openai_api_key, self.user_openai_api_key
        )
        self.streaming = self._determine_streaming(self.streaming)
        self.callbacks = self._determine_callback_array(self.streaming)

    @property
    def prompt_to_use(self) -> Optional[Prompt]:
        """Prompt resolved from ``prompt_id`` (no brain); looked up on every access."""
        return get_prompt_to_use(None, self.prompt_id)

    @property
    def prompt_to_use_id(self) -> Optional[UUID]:
        """Id of the prompt resolved from ``prompt_id``; looked up on every access."""
        return get_prompt_to_use_id(None, self.prompt_id)

    def _create_llm(
        self, model, temperature=0, streaming=False, callbacks=None
    ) -> BaseChatModel:
        """
        Determine the language model to be used.
        :param model: Language model name to be used.
        :param temperature: Sampling temperature passed to the model.
        :param streaming: Whether to enable streaming of the model
        :param callbacks: Callbacks to be used for streaming
        :return: Language model instance
        """
        return ChatOpenAI(
            temperature=temperature,
            model=model,
            streaming=streaming,
            verbose=True,
            callbacks=callbacks,
            openai_api_key=self.openai_api_key,
        )

    def _create_prompt_template(self):
        """Return a chat template made of a single ``{question}`` human message."""
        messages = [
            HumanMessagePromptTemplate.from_template("{question}"),
        ]
        CHAT_PROMPT = ChatPromptTemplate.from_messages(messages)
        return CHAT_PROMPT

    def _build_messages(self, question: ChatQuestion, prompt: Optional[Prompt]):
        """Format this chat's history, system prompt and question into model messages.

        History is read from ``self.chat_id``. The system prompt is
        ``prompt.content`` when a prompt is given, else ``SYSTEM_MESSAGE``.
        """
        transformed_history = format_chat_history(get_chat_history(self.chat_id))
        prompt_content = prompt.content if prompt else SYSTEM_MESSAGE
        return format_history_to_openai_mesages(
            transformed_history, prompt_content, question.question
        )

    def generate_answer(
        self, chat_id: UUID, question: ChatQuestion
    ) -> GetChatHistoryOutput:
        """Answer ``question`` in a single (non-streaming) model call.

        NOTE(review): history is read from ``self.chat_id`` while the new
        exchange is stored under the ``chat_id`` argument; callers presumably
        pass the same chat.

        :param chat_id: Chat the new exchange is recorded under.
        :param question: Question whose ``question`` text is sent to the model.
        :return: The persisted exchange, including message id and time.
        """
        # Resolve the prompt once; each property access repeats the lookup.
        prompt = self.prompt_to_use
        prompt_id = self.prompt_to_use_id
        messages = self._build_messages(question, prompt)

        answering_llm = self._create_llm(
            model=self.model,
            temperature=self.temperature,
            streaming=False,
            callbacks=self.callbacks,
        )
        model_prediction = answering_llm.predict_messages(messages)
        answer = model_prediction.content

        new_chat = update_chat_history(
            CreateChatHistory(
                **{
                    "chat_id": chat_id,
                    "user_message": question.question,
                    "assistant": answer,
                    "brain_id": None,
                    "prompt_id": prompt_id,
                }
            )
        )
        return GetChatHistoryOutput(
            **{
                "chat_id": chat_id,
                "user_message": question.question,
                "assistant": answer,
                "message_time": new_chat.message_time,
                "prompt_title": prompt.title if prompt else None,
                "brain_name": None,
                "message_id": new_chat.message_id,
            }
        )

    async def generate_stream(
        self, chat_id: UUID, question: ChatQuestion
    ) -> AsyncIterable:
        """Stream the answer to ``question`` as ``data: <json>`` strings.

        An empty assistant message is stored first. Each streamed token is
        yielded as the ``assistant`` field of a serialized
        ``GetChatHistoryOutput``. Once the model finishes, the stored
        message is updated with the concatenated tokens.

        :param chat_id: Chat the new exchange is recorded under.
        :param question: Question whose ``question`` text is sent to the model.
        """
        callback = AsyncIteratorCallbackHandler()
        self.callbacks = [callback]

        # Resolve the prompt once; each property access repeats the lookup.
        prompt = self.prompt_to_use
        prompt_id = self.prompt_to_use_id
        messages = self._build_messages(question, prompt)

        answering_llm = self._create_llm(
            model=self.model,
            temperature=self.temperature,
            streaming=True,
            callbacks=self.callbacks,
        )
        CHAT_PROMPT = ChatPromptTemplate.from_messages(messages)
        headlessChain = LLMChain(llm=answering_llm, prompt=CHAT_PROMPT)
        response_tokens = []

        async def wrap_done(fn: Awaitable, event: asyncio.Event):
            """Await ``fn`` and always set ``event`` so the token iterator ends."""
            try:
                await fn
            except Exception as e:
                # Keep the traceback; the partial answer is still saved below.
                logger.exception("Caught exception: %s", e)
            finally:
                event.set()

        run = asyncio.create_task(
            wrap_done(
                headlessChain.acall({}),
                callback.done,
            ),
        )

        # Persist a placeholder exchange so streamed chunks carry a message id.
        streamed_chat_history = update_chat_history(
            CreateChatHistory(
                **{
                    "chat_id": chat_id,
                    "user_message": question.question,
                    "assistant": "",
                    "brain_id": None,
                    "prompt_id": prompt_id,
                }
            )
        )
        streamed_chat_history = GetChatHistoryOutput(
            **{
                "chat_id": str(chat_id),
                "message_id": streamed_chat_history.message_id,
                "message_time": streamed_chat_history.message_time,
                "user_message": question.question,
                "assistant": "",
                "prompt_title": prompt.title if prompt else None,
                "brain_name": None,
            }
        )

        async for token in callback.aiter():
            logger.info("Token: %s", token)
            response_tokens.append(token)
            streamed_chat_history.assistant = token
            yield f"data: {json.dumps(streamed_chat_history.dict())}"

        await run
        assistant = "".join(response_tokens)

        update_message_by_id(
            message_id=str(streamed_chat_history.message_id),
            user_message=question.question,
            assistant=assistant,
        )

    class Config:
        arbitrary_types_allowed = True