2023-08-07 17:31:00 +03:00
import json
2024-02-15 01:01:35 +03:00
from typing import AsyncIterable , List , Optional
2023-08-07 17:31:00 +03:00
from uuid import UUID
from langchain . callbacks . streaming_aiter import AsyncIteratorCallbackHandler
2023-12-15 13:43:41 +03:00
from llm . utils . format_chat_history import format_chat_history
from llm . utils . get_prompt_to_use import get_prompt_to_use
from llm . utils . get_prompt_to_use_id import get_prompt_to_use_id
2023-08-10 11:25:08 +03:00
from logger import get_logger
2023-12-11 18:46:45 +03:00
from models import BrainSettings
2024-02-06 08:02:46 +03:00
from modules . brain . qa_interface import QAInterface
from modules . brain . rags . quivr_rag import QuivrRAG
from modules . brain . rags . rag_interface import RAGInterface
2023-12-01 00:29:28 +03:00
from modules . brain . service . brain_service import BrainService
2024-01-26 05:56:54 +03:00
from modules . chat . dto . chats import ChatQuestion , Sources
2023-12-04 20:38:54 +03:00
from modules . chat . dto . inputs import CreateChatHistory
from modules . chat . dto . outputs import GetChatHistoryOutput
from modules . chat . service . chat_service import ChatService
2024-02-15 01:01:35 +03:00
from pydantic import BaseModel , ConfigDict
from pydantic_settings import BaseSettings
2024-01-29 02:37:14 +03:00
from repository . files . generate_file_signed_url import generate_file_signed_url
2023-07-11 21:15:56 +03:00
# Module-level logger, named after this module via the project's get_logger helper.
logger = get_logger ( __name__ )
2023-08-22 15:23:27 +03:00
# Default assistant instructions. Not referenced by the code visible in this
# file section - presumably consumed elsewhere as a fallback prompt; confirm
# before changing or removing.
QUIVR_DEFAULT_PROMPT = " Your name is Quivr. You ' re a helpful assistant. If you don ' t know the answer, just say that you don ' t know, don ' t try to make up an answer. "
2023-07-11 21:15:56 +03:00
2023-12-01 00:29:28 +03:00
# Shared BrainService instance, created once at import time; used below to
# look up a brain by id.
brain_service = BrainService ( )
2023-12-04 20:38:54 +03:00
# Shared ChatService instance, created once at import time; used below to
# read chat history and to create/update chat messages.
chat_service = ChatService ( )
2023-12-01 00:29:28 +03:00
2024-01-20 07:34:30 +03:00
def is_valid_uuid(uuid_to_test, version=4):
    """
    Return True if ``uuid_to_test`` is the canonical string form of a UUID.

    The value must be a ``str`` that parses as a UUID and that is already in
    the exact form ``str(UUID(...))`` produces for the requested ``version``
    (lower-case, hyphenated, with matching version/variant bits).

    Args:
        uuid_to_test: Candidate value. Anything that is not a ``str``
            (``None``, ``int``, a ``UUID`` instance, ...) is reported as
            invalid instead of raising.
        version: UUID version the string must conform to (default 4).

    Returns:
        bool: True only for a canonical UUID string of the given version.
    """
    # UUID() raises TypeError/AttributeError (not ValueError) for non-string
    # input, so reject those up front rather than letting them escape.
    if not isinstance(uuid_to_test, str):
        return False

    try:
        uuid_obj = UUID(uuid_to_test, version=version)
    except ValueError:
        # Malformed hex string, or an unsupported ``version`` number.
        return False

    # UUID() forces the version/variant bits and normalises the text, so a
    # round-trip mismatch means the input was not canonical for this version.
    return str(uuid_obj) == uuid_to_test
2024-02-15 07:07:53 +03:00
def generate_source(source_documents, brain_id):
    """
    Build the ``Sources`` entries describing the documents behind an answer.

    Every item of ``source_documents`` is indexed with ``[0]`` to obtain the
    object whose ``metadata`` mapping is inspected. An entry is treated as a
    URL when its ``original_file_name`` metadata is present, not ``None`` and
    starts with ``"http"``; otherwise it is treated as a stored file and a
    signed URL is requested for ``"<brain_id>/<file_name>"``.

    Args:
        source_documents: Sequence of retrieved items (may be empty/None).
        brain_id: Identifier used as the path prefix for stored files.

    Returns:
        List[Sources]: One entry per item, in input order; empty when no
        documents were supplied.
    """
    if not source_documents:
        logger.info("No source documents found or source_documents is not a list.")
        return []

    collected: List[Sources] = []
    # Signed URLs already requested in this call, keyed by storage path, so
    # the same file is only signed once.
    signed_url_cache = {}

    logger.info(f"Source documents found: {source_documents}")
    for entry in source_documents:
        document = entry[0]
        logger.info("Document: %s", document)
        logger.info(f"Metadata 1: {document.metadata}")

        points_to_url = (
            "original_file_name" in document.metadata
            and document.metadata["original_file_name"] is not None
            and document.metadata["original_file_name"].startswith("http")
        )
        logger.info(f"Is URL: {points_to_url}")

        if points_to_url:
            # For URLs the original name doubles as both label and link.
            display_name = document.metadata["original_file_name"]
            kind = "url"
            link = document.metadata["original_file_name"]
        else:
            display_name = document.metadata["file_name"]
            kind = "file"
            storage_path = f"{brain_id}/{document.metadata['file_name']}"
            if storage_path not in signed_url_cache:
                signed_url_cache[storage_path] = generate_file_signed_url(
                    storage_path
                ).get("signedURL", "")
            link = signed_url_cache[storage_path]

        collected.append(
            Sources(
                name=display_name,
                type=kind,
                source_url=link,
                original_file_name=display_name,
            )
        )

    return collected
2023-12-11 18:46:45 +03:00
class KnowledgeBrainQA(BaseModel, QAInterface):
    """
    Question answering over a single brain.

    The class wraps a RAG chain (a ``QuivrRAG`` stored in ``knowledge_qa``)
    and offers two ways of answering a ``ChatQuestion``:

    * ``generate_answer`` invokes the chain once and returns the complete
      answer as a ``GetChatHistoryOutput``.
    * ``generate_stream`` is an async generator that yields
      ``"data: <json>"`` strings, one per answer chunk, followed by a final
      message carrying the serialized sources.

    Both feed the chain the same inputs: the question text, the formatted
    chat history and the content of the prompt selected for the brain.
    """

    model_config = ConfigDict(arbitrary_types_allowed=True)

    # Instantiate settings
    brain_settings: BaseSettings = BrainSettings()

    # Default class attributes
    model: str = None  # pyright: ignore reportPrivateUsage=none
    temperature: float = 0.1
    chat_id: str = None  # pyright: ignore reportPrivateUsage=none
    brain_id: str = None  # pyright: ignore reportPrivateUsage=none
    max_tokens: int = 2000
    max_input: int = 2000
    streaming: bool = False
    knowledge_qa: Optional[RAGInterface] = None
    metadata: Optional[dict] = None
    user_id: Optional[str] = None

    callbacks: List[AsyncIteratorCallbackHandler] = (
        None  # pyright: ignore reportPrivateUsage=none
    )

    prompt_id: Optional[UUID] = None

    def __init__(
        self,
        model: str,
        brain_id: str,
        chat_id: str,
        max_tokens: int,
        streaming: bool = False,
        prompt_id: Optional[UUID] = None,
        metadata: Optional[dict] = None,
        user_id: Optional[str] = None,
        **kwargs,
    ):
        """
        Create the QA object and its underlying ``QuivrRAG`` chain factory.

        Args:
            model: Name of the LLM to use.
            brain_id: Identifier of the brain to query (string form).
            chat_id: Identifier of the chat whose history is loaded.
            max_tokens: Stored on the instance as ``max_tokens``.
            streaming: Forwarded to the model base class and to ``QuivrRAG``.
            prompt_id: Optional prompt override used by ``prompt_to_use``.
            metadata: Optional metadata attached to streamed chat messages.
            user_id: Optional identifier of the requesting user.
            **kwargs: Forwarded unchanged to both the base-class initializer
                and ``QuivrRAG``.
        """
        super().__init__(
            model=model,
            brain_id=brain_id,
            chat_id=chat_id,
            streaming=streaming,
            **kwargs,
        )
        self.prompt_id = prompt_id
        self.knowledge_qa = QuivrRAG(
            model=model,
            brain_id=brain_id,
            chat_id=chat_id,
            streaming=streaming,
            **kwargs,
        )
        self.metadata = metadata
        self.max_tokens = max_tokens
        self.user_id = user_id

    @property
    def prompt_to_use(self):
        """Prompt resolved for this brain/prompt id, or None if ``brain_id`` is not a valid UUID string."""
        if self.brain_id and is_valid_uuid(self.brain_id):
            return get_prompt_to_use(UUID(self.brain_id), self.prompt_id)
        return None

    @property
    def prompt_to_use_id(self) -> Optional[UUID]:
        """Id of the prompt resolved for this brain, or None if ``brain_id`` is not a valid UUID string."""
        # TODO: move to prompt service or instruction or something
        if self.brain_id and is_valid_uuid(self.brain_id):
            return get_prompt_to_use_id(UUID(self.brain_id), self.prompt_id)
        return None

    def _chain_inputs(self, question: ChatQuestion, transformed_history) -> dict:
        """Build the input mapping passed to the RAG chain for ``question``."""
        # Resolve the prompt once; the property performs a lookup on every access.
        prompt = self.prompt_to_use
        return {
            "question": question.question,
            "chat_history": transformed_history,
            "custom_personality": prompt.content if prompt else None,
        }

    def generate_answer(
        self, chat_id: UUID, question: ChatQuestion, save_answer: bool = True
    ) -> GetChatHistoryOutput:
        """
        Answer ``question`` in one shot.

        Args:
            chat_id: Chat the message belongs to.
            question: The user's question.
            save_answer: When True the question/answer pair is persisted and
                the returned object carries its ``message_id`` and
                ``message_time``; when False nothing is written and those
                fields (and ``brain_name``) are None.

        Returns:
            GetChatHistoryOutput describing the exchange.
        """
        conversational_qa_chain = self.knowledge_qa.get_chain()
        # The placeholder message is only created when the answer is to be
        # saved, so a single history row exists per saved exchange and none
        # is left behind when save_answer is False.
        transformed_history, chat_entry = self.initialize_streamed_chat_history(
            chat_id, question, persist=save_answer
        )

        model_response = conversational_qa_chain.invoke(
            self._chain_inputs(question, transformed_history)
        )
        answer = model_response["answer"].content

        if save_answer:
            # Fill in the placeholder created above instead of inserting a
            # second message (same call the streaming path uses).
            chat_service.update_message_by_id(
                message_id=str(chat_entry.message_id),
                user_message=question.question,
                assistant=answer,
                metadata=chat_entry.metadata or {},
            )

        return GetChatHistoryOutput(
            chat_id=chat_id,
            user_message=question.question,
            assistant=answer,
            message_time=chat_entry.message_time if save_answer else None,
            prompt_title=chat_entry.prompt_title,
            brain_name=chat_entry.brain_name if save_answer else None,
            message_id=chat_entry.message_id if save_answer else None,
            brain_id=chat_entry.brain_id,
        )

    async def generate_stream(
        self, chat_id: UUID, question: ChatQuestion, save_answer: bool = True
    ) -> AsyncIterable:
        """
        Answer ``question`` as a stream of ``"data: <json>"`` strings.

        One string is yielded per answer chunk (``assistant`` holds that
        chunk's content), then a final string whose ``metadata["sources"]``
        lists the serialized sources. Afterwards the concatenated answer is
        stored via ``save_answer`` when ``save_answer`` is True.

        Args:
            chat_id: Chat the message belongs to.
            question: The user's question.
            save_answer: Whether to store the completed answer.

        Yields:
            str: ``"data: "`` followed by the JSON-encoded chat message.
        """
        conversational_qa_chain = self.knowledge_qa.get_chain()
        # NOTE(review): the placeholder message is created here regardless of
        # ``save_answer`` (unchanged behavior) - confirm whether streaming
        # clients rely on the message id in that case.
        transformed_history, streamed_chat_history = (
            self.initialize_streamed_chat_history(chat_id, question)
        )
        response_tokens = []
        sources = []

        async for chunk in conversational_qa_chain.astream(
            self._chain_inputs(question, transformed_history)
        ):
            if chunk.get("answer"):
                logger.info("Chunk: %s", chunk)
                token = chunk["answer"].content
                response_tokens.append(token)
                streamed_chat_history.assistant = token
                yield f"data: {json.dumps(streamed_chat_history.dict())}"
            if chunk.get("docs"):
                sources = chunk["docs"]

        sources_list = generate_source(sources, self.brain_id)
        if not streamed_chat_history.metadata:
            streamed_chat_history.metadata = {}
        # Serialize the sources list
        streamed_chat_history.metadata["sources"] = [
            source.dict() for source in sources_list
        ]
        yield f"data: {json.dumps(streamed_chat_history.dict())}"

        self.save_answer(question, response_tokens, streamed_chat_history, save_answer)

    def initialize_streamed_chat_history(self, chat_id, question, persist: bool = True):
        """
        Load the chat history and prepare the message object for this question.

        Args:
            chat_id: Chat the new message belongs to.
            question: The user's question.
            persist: When True (default) a placeholder message with an empty
                answer is created through ``chat_service.update_chat_history``
                and its id/time are used; when False nothing is written and
                ``message_id``/``message_time`` are None.

        Returns:
            tuple: ``(transformed_history, streamed_chat_history)`` - the
            formatted prior history and a ``GetChatHistoryOutput`` with an
            empty ``assistant`` field.
        """
        # History is read before the placeholder is written, so it does not
        # contain the current question.
        # NOTE(review): history is loaded for ``self.chat_id`` while the new
        # message uses the ``chat_id`` argument - presumably the same value.
        history = chat_service.get_chat_history(self.chat_id)
        transformed_history = format_chat_history(history)
        brain = brain_service.get_brain_by_id(self.brain_id)
        prompt = self.prompt_to_use

        message_id = None
        message_time = None
        if persist:
            placeholder = chat_service.update_chat_history(
                CreateChatHistory(
                    chat_id=chat_id,
                    user_message=question.question,
                    assistant="",
                    # Guarded like the other brain accesses below; assumes
                    # CreateChatHistory tolerates a null brain_id - confirm.
                    brain_id=brain.brain_id if brain else None,
                    prompt_id=self.prompt_to_use_id,
                )
            )
            message_id = placeholder.message_id
            message_time = placeholder.message_time

        streamed_chat_history = GetChatHistoryOutput(
            chat_id=str(chat_id),
            message_id=message_id,
            message_time=message_time,
            user_message=question.question,
            assistant="",
            prompt_title=prompt.title if prompt else None,
            brain_name=brain.name if brain else None,
            brain_id=str(brain.brain_id) if brain else None,
            metadata=self.metadata,
        )

        return transformed_history, streamed_chat_history

    def save_answer(
        self, question, response_tokens, streamed_chat_history, save_answer
    ):
        """
        Store the completed streamed answer on its existing message.

        Args:
            question: The user's question.
            response_tokens: Answer chunks, concatenated without separator.
            streamed_chat_history: Message object providing ``message_id``
                and ``metadata``.
            save_answer: When False, nothing is written.

        Errors raised by the update are logged and swallowed (best effort).
        """
        if not save_answer:
            return

        assistant = "".join(response_tokens)
        try:
            chat_service.update_message_by_id(
                message_id=str(streamed_chat_history.message_id),
                user_message=question.question,
                assistant=assistant,
                metadata=streamed_chat_history.metadata,
            )
        except Exception as e:
            logger.error("Error updating message by ID: %s", e)