import asyncio
import json
from typing import AsyncIterable, Awaitable, List, Optional
from uuid import UUID

from langchain.callbacks.streaming_aiter import AsyncIteratorCallbackHandler
from langchain.chains import ConversationalRetrievalChain
from llm.qa_interface import QAInterface
from llm.rags.quivr_rag import QuivrRAG
from llm.rags.rag_interface import RAGInterface
from llm.utils.format_chat_history import format_chat_history
from llm.utils.get_prompt_to_use import get_prompt_to_use
from llm.utils.get_prompt_to_use_id import get_prompt_to_use_id
from logger import get_logger
from models import BrainSettings
from modules.brain.service.brain_service import BrainService
from modules.chat.dto.chats import ChatQuestion
from modules.chat.dto.inputs import CreateChatHistory
from modules.chat.dto.outputs import GetChatHistoryOutput
from modules.chat.service.chat_service import ChatService
from pydantic import BaseModel

logger = get_logger(__name__)

QUIVR_DEFAULT_PROMPT = "Your name is Quivr. You're a helpful assistant. If you don't know the answer, just say that you don't know, don't try to make up an answer."

brain_service = BrainService()
chat_service = ChatService()


class KnowledgeBrainQA(BaseModel, QAInterface):
    """
    Main class for the Brain Picking functionality.
    It initializes a chat model and retrieves answers over the brain's
    knowledge using a ConversationalRetrievalChain, which rephrases the
    question before retrieval.
    It has two main entry points: `generate_answer` and `generate_stream`.
    The first returns the complete answer in a single response; the second
    yields it token by token as a stream.
    Both accept an optional custom prompt via `prompt_id`; `generate_answer`
    falls back to `QUIVR_DEFAULT_PROMPT` when no prompt is set.
    (See the usage sketch at the bottom of this module.)
    """

    class Config:
        """Configuration of the Pydantic Object"""

        arbitrary_types_allowed = True

    # Instantiate settings
    brain_settings = BrainSettings()  # type: ignore other parameters are optional

    # Default class attributes
    model: str = None  # pyright: ignore reportPrivateUsage=none
    temperature: float = 0.1
    chat_id: str = None  # pyright: ignore reportPrivateUsage=none
    brain_id: str = None  # pyright: ignore reportPrivateUsage=none
    max_tokens: int = 256
    streaming: bool = False
    knowledge_qa: Optional[RAGInterface]
    callbacks: List[
        AsyncIteratorCallbackHandler
    ] = None  # pyright: ignore reportPrivateUsage=none

    prompt_id: Optional[UUID]

    def __init__(
        self,
        model: str,
        brain_id: str,
        chat_id: str,
        streaming: bool = False,
        prompt_id: Optional[UUID] = None,
        **kwargs,
    ):
        super().__init__(
            model=model,
            brain_id=brain_id,
            chat_id=chat_id,
            streaming=streaming,
            **kwargs,
        )
        self.prompt_id = prompt_id
        self.knowledge_qa = QuivrRAG(
            model=model,
            brain_id=brain_id,
            chat_id=chat_id,
            streaming=streaming,
            **kwargs,
        )

    @property
    def prompt_to_use(self):
        # TODO: move to prompt service or instruction or something
        return get_prompt_to_use(UUID(self.brain_id), self.prompt_id)

    @property
    def prompt_to_use_id(self) -> Optional[UUID]:
        # TODO: move to prompt service or instruction or something
        return get_prompt_to_use_id(UUID(self.brain_id), self.prompt_id)

    def generate_answer(
        self, chat_id: UUID, question: ChatQuestion, save_answer: bool = True
    ) -> GetChatHistoryOutput:
        transformed_history = format_chat_history(
            chat_service.get_chat_history(self.chat_id)
        )

        # The chain that combines question rephrasing and answer generation
        qa = ConversationalRetrievalChain(
            retriever=self.knowledge_qa.get_retriever(),
            combine_docs_chain=self.knowledge_qa.get_doc_chain(
                streaming=False,
            ),
            question_generator=self.knowledge_qa.get_question_generation_llm(),
            verbose=False,
            rephrase_question=False,
            return_source_documents=True,
        )

        prompt_content = (
            self.prompt_to_use.content if self.prompt_to_use else QUIVR_DEFAULT_PROMPT
        )

        model_response = qa(
            {
                "question": question.question,
                "chat_history": transformed_history,
                "custom_personality": prompt_content,
            }
        )

        answer = model_response["answer"]

        brain = None
        if question.brain_id:
            brain = brain_service.get_brain_by_id(question.brain_id)

        if save_answer:
            # Persist the exchange and return it with its database metadata
            new_chat = chat_service.update_chat_history(
                CreateChatHistory(
                    **{
                        "chat_id": chat_id,
                        "user_message": question.question,
                        "assistant": answer,
                        "brain_id": question.brain_id,
                        "prompt_id": self.prompt_to_use_id,
                    }
                )
            )

            return GetChatHistoryOutput(
                **{
                    "chat_id": chat_id,
                    "user_message": question.question,
                    "assistant": answer,
                    "message_time": new_chat.message_time,
                    "prompt_title": self.prompt_to_use.title
                    if self.prompt_to_use
                    else None,
                    "brain_name": brain.name if brain else None,
                    "message_id": new_chat.message_id,
                }
            )

        return GetChatHistoryOutput(
            **{
                "chat_id": chat_id,
                "user_message": question.question,
                "assistant": answer,
                "message_time": None,
                "prompt_title": self.prompt_to_use.title
                if self.prompt_to_use
                else None,
                "brain_name": None,
                "message_id": None,
            }
        )

    async def generate_stream(
        self, chat_id: UUID, question: ChatQuestion, save_answer: bool = True
    ) -> AsyncIterable:
        history = chat_service.get_chat_history(self.chat_id)
        callback = AsyncIteratorCallbackHandler()
        self.callbacks = [callback]

        # The chain that combines question rephrasing and answer generation
        qa = ConversationalRetrievalChain(
            retriever=self.knowledge_qa.get_retriever(),
            combine_docs_chain=self.knowledge_qa.get_doc_chain(
                callbacks=self.callbacks,
                streaming=True,
            ),
            question_generator=self.knowledge_qa.get_question_generation_llm(),
            verbose=False,
            rephrase_question=False,
            return_source_documents=True,
        )

        transformed_history = format_chat_history(history)
        response_tokens = []

        async def wrap_done(fn: Awaitable, event: asyncio.Event):
            try:
                return await fn
            except Exception as e:
                logger.error(f"Caught exception: {e}")
                return None  # Or some sentinel value that indicates failure
            finally:
                event.set()
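
        # Note: `wrap_done` sets `callback.done` in its `finally` block, so the
        # `async for token in callback.aiter()` loop below terminates even when
        # the chain raises instead of completing normally.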

        prompt_content = self.prompt_to_use.content if self.prompt_to_use else None

        run = asyncio.create_task(
            wrap_done(
                qa.acall(
                    {
                        "question": question.question,
                        "chat_history": transformed_history,
                        "custom_personality": prompt_content,
                    }
                ),
                callback.done,
            )
        )

        brain = None
        if question.brain_id:
            brain = brain_service.get_brain_by_id(question.brain_id)

        if save_answer:
            # Persist an empty assistant message first, so the streamed tokens
            # can be attached to a real message_id once the run completes
            streamed_chat_history = chat_service.update_chat_history(
                CreateChatHistory(
                    **{
                        "chat_id": chat_id,
                        "user_message": question.question,
                        "assistant": "",
                        "brain_id": question.brain_id,
                        "prompt_id": self.prompt_to_use_id,
                    }
                )
            )

            streamed_chat_history = GetChatHistoryOutput(
                **{
                    "chat_id": str(chat_id),
                    "message_id": streamed_chat_history.message_id,
                    "message_time": streamed_chat_history.message_time,
                    "user_message": question.question,
                    "assistant": "",
                    "prompt_title": self.prompt_to_use.title
                    if self.prompt_to_use
                    else None,
                    "brain_name": brain.name if brain else None,
                }
            )
        else:
            streamed_chat_history = GetChatHistoryOutput(
                **{
                    "chat_id": str(chat_id),
                    "message_id": None,
                    "message_time": None,
                    "user_message": question.question,
                    "assistant": "",
                    "prompt_title": self.prompt_to_use.title
                    if self.prompt_to_use
                    else None,
                    "brain_name": brain.name if brain else None,
                }
            )

        try:
            async for token in callback.aiter():
                logger.debug("Token: %s", token)
                response_tokens.append(token)
                streamed_chat_history.assistant = token
                yield f"data: {json.dumps(streamed_chat_history.dict())}"
        except Exception as e:
            logger.error("Error during streaming tokens: %s", e)

        sources_string = ""
        try:
            result = await run
            source_documents = result.get("source_documents", [])
            # Deduplicate source documents by file name
            source_documents = list(
                {doc.metadata["file_name"]: doc for doc in source_documents}.values()
            )

            if source_documents:
                # Format the sources as a single Markdown line rather than one
                # line per source
                sources_string = "\n\n**Sources:** " + ", ".join(
                    f"{doc.metadata.get('file_name', 'Unnamed Document')}"
                    for doc in source_documents
                )
                streamed_chat_history.assistant += sources_string
                yield f"data: {json.dumps(streamed_chat_history.dict())}"
            else:
                logger.info("No source documents found for this answer.")
        except Exception as e:
            logger.error("Error processing source documents: %s", e)

        # Combine all response tokens to form the final assistant message
        assistant = "".join(response_tokens)
        assistant += sources_string

        try:
            if save_answer:
                chat_service.update_message_by_id(
                    message_id=str(streamed_chat_history.message_id),
                    user_message=question.question,
                    assistant=assistant,
                )
        except Exception as e:
            logger.error("Error updating message by ID: %s", e)
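

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not executed on import). It assumes a fully
# configured environment (BrainSettings, database-backed brain and chat
# services) and uses placeholder UUIDs and model name; adapt them to your
# deployment.
#
#     import asyncio
#     from uuid import uuid4
#
#     async def demo():
#         chat_id = uuid4()
#         qa = KnowledgeBrainQA(
#             model="gpt-3.5-turbo",
#             brain_id=str(uuid4()),
#             chat_id=str(chat_id),
#             streaming=True,
#         )
#         question = ChatQuestion(question="What do my documents say about X?")
#         # Each yielded chunk is an SSE-style payload: 'data: {...json...}'
#         async for chunk in qa.generate_stream(chat_id, question):
#             print(chunk)
#
#     asyncio.run(demo())
# ---------------------------------------------------------------------------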