import asyncio
import json
from typing import AsyncIterable, Awaitable, List, Optional
from uuid import UUID

from langchain.callbacks.streaming_aiter import AsyncIteratorCallbackHandler
from langchain.chains import LLMChain
from langchain.chat_models import ChatLiteLLM
from langchain.chat_models.base import BaseChatModel
from langchain.prompts.chat import ChatPromptTemplate, HumanMessagePromptTemplate

from llm.qa_interface import QAInterface
from llm.utils.format_chat_history import (
    format_chat_history,
    format_history_to_openai_mesages,
)
from llm.utils.get_prompt_to_use import get_prompt_to_use
from llm.utils.get_prompt_to_use_id import get_prompt_to_use_id
from logger import get_logger
from models import BrainSettings # Importing settings related to the 'brain'
from modules.chat.dto.chats import ChatQuestion
from modules.chat.dto.inputs import CreateChatHistory
from modules.chat.dto.outputs import GetChatHistoryOutput
from modules.chat.service.chat_service import ChatService
from modules.prompt.entity.prompt import Prompt
from pydantic import BaseModel

logger = get_logger(__name__)

SYSTEM_MESSAGE = "Your name is Quivr. You're a helpful assistant. If you don't know the answer, just say that you don't know, don't try to make up an answer. When answering, use markdown or any other techniques to display the content in a nice and aerated way."

chat_service = ChatService()


class HeadlessQA(BaseModel, QAInterface):
    brain_settings = BrainSettings()
    model: str
    temperature: float = 0.0
    max_tokens: int = 2000
    streaming: bool = False
    chat_id: str
    callbacks: Optional[List[AsyncIteratorCallbackHandler]] = None
    prompt_id: Optional[UUID] = None

    def _determine_streaming(self, streaming: bool) -> bool:
        """If the model name allows for streaming and streaming is declared, set streaming to True."""
        return streaming

    def _determine_callback_array(
        self, streaming
    ) -> List[AsyncIteratorCallbackHandler]:
        """If streaming is set, set the AsyncIteratorCallbackHandler as the only callback."""
        if streaming:
            return [AsyncIteratorCallbackHandler()]
        else:
            return []

    def __init__(self, **data):
        super().__init__(**data)
        self.streaming = self._determine_streaming(self.streaming)
        self.callbacks = self._determine_callback_array(self.streaming)

    @property
    def prompt_to_use(self) -> Optional[Prompt]:
        return get_prompt_to_use(None, self.prompt_id)

    @property
    def prompt_to_use_id(self) -> Optional[UUID]:
        return get_prompt_to_use_id(None, self.prompt_id)

    def _create_llm(
        self,
        model,
        temperature=0,
        streaming=False,
        callbacks=None,
    ) -> BaseChatModel:
        """
        Determine the language model to be used.
        :param model: Language model name to be used.
        :param temperature: Sampling temperature for the model.
        :param streaming: Whether to enable streaming of the model.
        :param callbacks: Callbacks to be used for streaming.
        :return: Language model instance.
        """
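        # Route models whose name starts with "ollama" through the configured
        # Ollama API base URL; other models keep the provider default.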
        api_base = None
        if self.brain_settings.ollama_api_base_url and model.startswith("ollama"):
            api_base = self.brain_settings.ollama_api_base_url
        return ChatLiteLLM(
            temperature=temperature,
            model=model,
            streaming=streaming,
            verbose=True,
            callbacks=callbacks,
            max_tokens=self.max_tokens,
            api_base=api_base,
        )

    def _create_prompt_template(self):
        messages = [
            HumanMessagePromptTemplate.from_template("{question}"),
        ]
        CHAT_PROMPT = ChatPromptTemplate.from_messages(messages)
        return CHAT_PROMPT

    def generate_answer(
        self, chat_id: UUID, question: ChatQuestion, save_answer: bool = True
    ) -> GetChatHistoryOutput:
        # Move format_chat_history to chat service?
        transformed_history = format_chat_history(
            chat_service.get_chat_history(self.chat_id)
        )
        prompt_content = (
            self.prompt_to_use.content if self.prompt_to_use else SYSTEM_MESSAGE
        )
        messages = format_history_to_openai_mesages(
            transformed_history, prompt_content, question.question
        )
        answering_llm = self._create_llm(
            model=self.model,
            streaming=False,
            callbacks=self.callbacks,
        )
        model_prediction = answering_llm.predict_messages(messages)
        answer = model_prediction.content
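
        # Persist the question/answer pair to chat history unless the caller opts out.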
        if save_answer:
            new_chat = chat_service.update_chat_history(
                CreateChatHistory(
                    **{
                        "chat_id": chat_id,
                        "user_message": question.question,
                        "assistant": answer,
                        "brain_id": None,
                        "prompt_id": self.prompt_to_use_id,
                    }
                )
            )

            return GetChatHistoryOutput(
                **{
                    "chat_id": chat_id,
                    "user_message": question.question,
                    "assistant": answer,
                    "message_time": new_chat.message_time,
                    "prompt_title": self.prompt_to_use.title
                    if self.prompt_to_use
                    else None,
                    "brain_name": None,
                    "message_id": new_chat.message_id,
                }
            )
        else:
            return GetChatHistoryOutput(
                **{
                    "chat_id": chat_id,
                    "user_message": question.question,
                    "assistant": answer,
                    "message_time": None,
                    "prompt_title": self.prompt_to_use.title
                    if self.prompt_to_use
                    else None,
                    "brain_name": None,
                    "message_id": None,
                }
            )

    async def generate_stream(
        self, chat_id: UUID, question: ChatQuestion, save_answer: bool = True
    ) -> AsyncIterable:
        callback = AsyncIteratorCallbackHandler()
        self.callbacks = [callback]

        transformed_history = format_chat_history(
            chat_service.get_chat_history(self.chat_id)
        )
        prompt_content = (
            self.prompt_to_use.content if self.prompt_to_use else SYSTEM_MESSAGE
        )
        messages = format_history_to_openai_mesages(
            transformed_history, prompt_content, question.question
        )
        answering_llm = self._create_llm(
            model=self.model,
            streaming=True,
            callbacks=self.callbacks,
        )
        CHAT_PROMPT = ChatPromptTemplate.from_messages(messages)
        headlessChain = LLMChain(llm=answering_llm, prompt=CHAT_PROMPT)

        response_tokens = []

        async def wrap_done(fn: Awaitable, event: asyncio.Event):
            try:
                await fn
            except Exception as e:
                logger.error(f"Caught exception: {e}")
            finally:
                event.set()
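
        # Run the chain as a background task; tokens are consumed below through
        # the callback's async iterator while the task completes.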
        run = asyncio.create_task(
            wrap_done(
                headlessChain.acall({}),
                callback.done,
            ),
        )

        if save_answer:
            streamed_chat_history = chat_service.update_chat_history(
                CreateChatHistory(
                    **{
                        "chat_id": chat_id,
                        "user_message": question.question,
                        "assistant": "",
                        "brain_id": None,
                        "prompt_id": self.prompt_to_use_id,
                    }
                )
            )
            streamed_chat_history = GetChatHistoryOutput(
                **{
                    "chat_id": str(chat_id),
                    "message_id": streamed_chat_history.message_id,
                    "message_time": streamed_chat_history.message_time,
                    "user_message": question.question,
                    "assistant": "",
                    "prompt_title": self.prompt_to_use.title
                    if self.prompt_to_use
                    else None,
                    "brain_name": None,
                }
            )
        else:
            streamed_chat_history = GetChatHistoryOutput(
                **{
                    "chat_id": str(chat_id),
                    "message_id": None,
                    "message_time": None,
                    "user_message": question.question,
                    "assistant": "",
                    "prompt_title": self.prompt_to_use.title
                    if self.prompt_to_use
                    else None,
                    "brain_name": None,
                }
            )
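
        # Stream each token to the client as a server-sent event carrying the
        # chat history payload.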
        async for token in callback.aiter():
            response_tokens.append(token)
            streamed_chat_history.assistant = token
            yield f"data: {json.dumps(streamed_chat_history.dict())}"

        await run

        assistant = "".join(response_tokens)
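
        # Once streaming has finished, persist the full assistant answer.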
        if save_answer:
            chat_service.update_message_by_id(
                message_id=str(streamed_chat_history.message_id),
                user_message=question.question,
                assistant=assistant,
            )

    class Config:
        arbitrary_types_allowed = True