2023-07-11 21:15:56 +03:00
import asyncio
2023-08-07 17:31:00 +03:00
import json
2023-10-19 16:52:20 +03:00
from typing import AsyncIterable , Awaitable , List , Optional
2023-08-07 17:31:00 +03:00
from uuid import UUID
from langchain . callbacks . streaming_aiter import AsyncIteratorCallbackHandler
2023-12-11 18:46:45 +03:00
from langchain . chains import ConversationalRetrievalChain
2023-12-15 13:43:41 +03:00
from llm . qa_interface import QAInterface
from llm . rags . quivr_rag import QuivrRAG
from llm . rags . rag_interface import RAGInterface
from llm . utils . format_chat_history import format_chat_history
from llm . utils . get_prompt_to_use import get_prompt_to_use
from llm . utils . get_prompt_to_use_id import get_prompt_to_use_id
2023-08-10 11:25:08 +03:00
from logger import get_logger
2023-12-11 18:46:45 +03:00
from models import BrainSettings
2023-12-01 00:29:28 +03:00
from modules . brain . service . brain_service import BrainService
2024-01-26 05:56:54 +03:00
from modules . chat . dto . chats import ChatQuestion , Sources
2023-12-04 20:38:54 +03:00
from modules . chat . dto . inputs import CreateChatHistory
from modules . chat . dto . outputs import GetChatHistoryOutput
from modules . chat . service . chat_service import ChatService
2023-10-19 16:52:20 +03:00
from pydantic import BaseModel
2024-01-29 02:37:14 +03:00
from repository . files . generate_file_signed_url import generate_file_signed_url
2023-07-11 21:15:56 +03:00
# Module-level logger for this file.
logger = get_logger(__name__)

# Fallback system prompt used when the brain has no custom prompt configured.
QUIVR_DEFAULT_PROMPT = "Your name is Quivr. You're a helpful assistant. If you don't know the answer, just say that you don't know, don't try to make up an answer."

# Shared service singletons used by the QA class below.
brain_service = BrainService()
chat_service = ChatService()
def is_valid_uuid(uuid_to_test, version=4):
    """Check whether *uuid_to_test* is a valid, canonical UUID string.

    Args:
        uuid_to_test: Candidate value, expected to be a string.
        version: UUID version to validate against (defaults to 4).

    Returns:
        bool: True if the value parses as a UUID of the given version and is
        already in canonical (lowercase, hyphenated) form, False otherwise.
    """
    try:
        uuid_obj = UUID(uuid_to_test, version=version)
    except (ValueError, TypeError, AttributeError):
        # ValueError: malformed string. TypeError/AttributeError: non-string
        # input such as None or an int -- the original only caught ValueError,
        # so those inputs crashed instead of returning False.
        return False
    return str(uuid_obj) == uuid_to_test
def generate_source(result, brain):
    """Build the list of ``Sources`` referenced by a chain result.

    Args:
        result: Mapping returned by the retrieval chain; source documents are
            read from its ``"source_documents"`` key (missing key -> no sources).
        brain: Brain entity whose ``brain_id`` prefixes the storage path used
            to build signed URLs for file-based sources.

    Returns:
        List[Sources]: one entry per source document; empty when none found.
    """
    sources_list: List[Sources] = []
    # Get source documents from the result, default to an empty list if not found
    source_documents = result.get("source_documents", [])
    if not source_documents:
        logger.info("No source documents found or source_documents is not a list.")
        return sources_list
    # Lazy %-style args match the logging convention used elsewhere in this file.
    logger.info("Source documents found: %s", source_documents)
    for doc in source_documents:
        logger.info("Metadata 1: %s", doc.metadata)
        # A document is URL-backed when its original_file_name is an http(s) link.
        is_url = (
            "original_file_name" in doc.metadata
            and doc.metadata["original_file_name"] is not None
            and doc.metadata["original_file_name"].startswith("http")
        )
        logger.info("Is URL: %s", is_url)
        # Display name: the URL itself for links, the stored file name otherwise.
        name = (
            doc.metadata["original_file_name"]
            if is_url
            else doc.metadata["file_name"]
        )
        type_ = "url" if is_url else "file"
        if is_url:
            source_url = doc.metadata["original_file_name"]
        else:
            # Files are exposed through a time-limited signed URL.
            source_url = generate_file_signed_url(
                f"{brain.brain_id}/{doc.metadata['file_name']}"
            ).get("signedURL", "")
        sources_list.append(
            Sources(
                name=name,
                type=type_,
                source_url=source_url,
                original_file_name=name,
            )
        )
    return sources_list
class KnowledgeBrainQA(BaseModel, QAInterface):
    """
    Main class for the Brain Picking functionality.
    It allows to initialize a Chat model, generate questions and retrieve answers using ConversationalRetrievalChain.
    It has two main methods: `generate_question` and `generate_stream`.
    One is for generating questions in a single request, the other is for generating questions in a streaming fashion.
    Both are the same, except that the streaming version streams the last message as a stream.
    Each have the same prompt template, which is defined in the `prompt_template` property.
    """

    class Config:
        """Configuration of the Pydantic Object"""

        # Needed so non-pydantic types (callback handlers, RAG objects) can be fields.
        arbitrary_types_allowed = True

    # Instantiate settings
    brain_settings = BrainSettings()  # type: ignore other parameters are optional

    # Default class attributes
    model: Optional[str] = None  # pyright: ignore reportPrivateUsage=none
    temperature: float = 0.1
    chat_id: Optional[str] = None  # pyright: ignore reportPrivateUsage=none
    brain_id: str  # pyright: ignore reportPrivateUsage=none
    # Generation / input-size limits (tokens).
    max_tokens: int = 2000
    max_input: int = 2000
    streaming: bool = False
    # RAG backend providing retriever/doc-chain/question LLM (set in __init__).
    knowledge_qa: Optional[RAGInterface]
    # Extra metadata attached to streamed chat history entries.
    metadata: Optional[dict] = None
    # Async callback handlers used to stream tokens (set in generate_stream).
    callbacks: Optional[
        List[AsyncIteratorCallbackHandler]
    ] = None  # pyright: ignore reportPrivateUsage=none
    prompt_id: Optional[UUID]
    def __init__(
        self,
        model: str,
        brain_id: str,
        chat_id: str,
        max_tokens: int,
        streaming: bool = False,
        prompt_id: Optional[UUID] = None,
        metadata: Optional[dict] = None,
        **kwargs,
    ):
        """Initialize the QA object and its underlying RAG backend.

        Args:
            model: Name of the LLM used to answer questions.
            brain_id: Identifier of the brain (knowledge base) to query.
            chat_id: Identifier of the chat this QA session belongs to.
            max_tokens: Maximum number of tokens for the generated answer.
            streaming: Whether answers should be streamed token by token.
            prompt_id: Optional custom prompt id overriding the default prompt.
            metadata: Optional metadata attached to streamed chat entries.
            **kwargs: Forwarded to the pydantic base model and to QuivrRAG.
        """
        super().__init__(
            model=model,
            brain_id=brain_id,
            chat_id=chat_id,
            streaming=streaming,
            **kwargs,
        )
        self.prompt_id = prompt_id
        # Concrete RAG implementation for this brain/chat pair.
        self.knowledge_qa = QuivrRAG(
            model=model,
            brain_id=brain_id,
            chat_id=chat_id,
            streaming=streaming,
            **kwargs,
        )
        self.metadata = metadata
        # NOTE(review): max_tokens is stored on self but not forwarded to
        # QuivrRAG above -- confirm this is intentional.
        self.max_tokens = max_tokens
2023-08-22 15:23:27 +03:00
@property
def prompt_to_use ( self ) :
2024-01-20 07:34:30 +03:00
if self . brain_id and is_valid_uuid ( self . brain_id ) :
return get_prompt_to_use ( UUID ( self . brain_id ) , self . prompt_id )
else :
return None
2023-08-22 15:23:27 +03:00
@property
def prompt_to_use_id ( self ) - > Optional [ UUID ] :
2023-12-15 13:43:41 +03:00
# TODO: move to prompt service or instruction or something
2024-01-20 07:34:30 +03:00
if self . brain_id and is_valid_uuid ( self . brain_id ) :
return get_prompt_to_use_id ( UUID ( self . brain_id ) , self . prompt_id )
else :
return None
2023-07-11 21:15:56 +03:00
2023-08-10 11:25:08 +03:00
def generate_answer (
2023-12-15 13:43:41 +03:00
self , chat_id : UUID , question : ChatQuestion , save_answer : bool = True
2023-08-10 11:25:08 +03:00
) - > GetChatHistoryOutput :
2023-12-04 20:38:54 +03:00
transformed_history = format_chat_history (
chat_service . get_chat_history ( self . chat_id )
)
2023-08-08 18:15:43 +03:00
# The Chain that combines the question and answer
qa = ConversationalRetrievalChain (
2023-12-11 18:46:45 +03:00
retriever = self . knowledge_qa . get_retriever ( ) ,
combine_docs_chain = self . knowledge_qa . get_doc_chain (
streaming = False ,
2023-08-08 18:15:43 +03:00
) ,
2023-12-11 18:46:45 +03:00
question_generator = self . knowledge_qa . get_question_generation_llm ( ) ,
2023-08-18 11:18:29 +03:00
verbose = False ,
2023-09-19 13:11:03 +03:00
rephrase_question = False ,
2023-11-06 14:09:18 +03:00
return_source_documents = True ,
2023-08-08 18:15:43 +03:00
)
2023-08-22 15:23:27 +03:00
prompt_content = (
self . prompt_to_use . content if self . prompt_to_use else QUIVR_DEFAULT_PROMPT
)
2023-08-10 11:25:08 +03:00
model_response = qa (
{
2023-08-10 19:35:30 +03:00
" question " : question . question ,
2023-08-07 20:53:04 +03:00
" chat_history " : transformed_history ,
2023-08-22 15:23:27 +03:00
" custom_personality " : prompt_content ,
2023-08-10 11:25:08 +03:00
}
2023-12-11 18:46:45 +03:00
)
2023-08-10 11:25:08 +03:00
2023-07-11 21:15:56 +03:00
answer = model_response [ " answer " ]
2023-08-10 11:25:08 +03:00
2024-01-20 07:34:30 +03:00
brain = brain_service . get_brain_by_id ( self . brain_id )
2023-12-15 13:43:41 +03:00
if save_answer :
# save the answer to the database or not -> add a variable
new_chat = chat_service . update_chat_history (
CreateChatHistory (
* * {
" chat_id " : chat_id ,
" user_message " : question . question ,
" assistant " : answer ,
2024-01-20 07:34:30 +03:00
" brain_id " : brain . brain_id ,
2023-12-15 13:43:41 +03:00
" prompt_id " : self . prompt_to_use_id ,
}
)
)
return GetChatHistoryOutput (
2023-08-10 11:25:08 +03:00
* * {
" chat_id " : chat_id ,
" user_message " : question . question ,
" assistant " : answer ,
2023-12-15 13:43:41 +03:00
" message_time " : new_chat . message_time ,
" prompt_title " : self . prompt_to_use . title
if self . prompt_to_use
else None ,
" brain_name " : brain . name if brain else None ,
" message_id " : new_chat . message_id ,
2024-01-26 02:56:46 +03:00
" brain_id " : str ( brain . brain_id ) if brain else None ,
2023-08-10 11:25:08 +03:00
}
)
return GetChatHistoryOutput (
* * {
" chat_id " : chat_id ,
" user_message " : question . question ,
2023-08-15 16:59:30 +03:00
" assistant " : answer ,
2023-12-15 13:43:41 +03:00
" message_time " : None ,
2023-08-22 15:23:27 +03:00
" prompt_title " : self . prompt_to_use . title
if self . prompt_to_use
else None ,
2023-12-15 13:43:41 +03:00
" brain_name " : None ,
" message_id " : None ,
2024-01-26 02:56:46 +03:00
" brain_id " : str ( brain . brain_id ) if brain else None ,
2023-08-10 11:25:08 +03:00
}
)
    async def generate_stream(
        self, chat_id: UUID, question: ChatQuestion, save_answer: bool = True
    ) -> AsyncIterable:
        """Answer *question* as a server-sent-event stream of chat history JSON.

        Yields ``data: <json>`` strings: one per generated token, then one
        final payload carrying the source documents in ``metadata["sources"]``.

        Args:
            chat_id: Chat the question belongs to.
            question: The user question to answer.
            save_answer: When True, create a history row up front and update
                it with the full answer at the end; when False, nothing is
                persisted and message id/time are None.
        """
        history = chat_service.get_chat_history(self.chat_id)
        # The callback collects tokens emitted by the doc chain; we iterate it
        # below to stream them out as they arrive.
        callback = AsyncIteratorCallbackHandler()
        self.callbacks = [callback]

        # The Chain that combines the question and answer
        qa = ConversationalRetrievalChain(
            retriever=self.knowledge_qa.get_retriever(),
            combine_docs_chain=self.knowledge_qa.get_doc_chain(
                callbacks=self.callbacks,
                streaming=True,
            ),
            question_generator=self.knowledge_qa.get_question_generation_llm(),
            verbose=False,
            rephrase_question=False,
            return_source_documents=True,
        )

        transformed_history = format_chat_history(history)
        response_tokens = []

        async def wrap_done(fn: Awaitable, event: asyncio.Event):
            # Await the chain, always signalling `event` so the token iterator
            # below terminates even on failure.
            try:
                return await fn
            except Exception as e:
                logger.error(f"Caught exception: {e}")
                return None  # Or some sentinel value that indicates failure
            finally:
                event.set()

        # Unlike generate_answer, no default prompt is substituted here --
        # NOTE(review): confirm the None fallback (vs QUIVR_DEFAULT_PROMPT)
        # is intentional.
        prompt_content = self.prompt_to_use.content if self.prompt_to_use else None

        # Kick off the chain concurrently; tokens flow through `callback`.
        run = asyncio.create_task(
            wrap_done(
                qa.acall(
                    {
                        "question": question.question,
                        "chat_history": transformed_history,
                        "custom_personality": prompt_content,
                    }
                ),
                callback.done,
            )
        )

        brain = brain_service.get_brain_by_id(self.brain_id)

        if save_answer:
            # Create the history row now (empty assistant text) so the client
            # gets a message_id to track while tokens stream in.
            streamed_chat_history = chat_service.update_chat_history(
                CreateChatHistory(
                    **{
                        "chat_id": chat_id,
                        "user_message": question.question,
                        "assistant": "",
                        "brain_id": brain.brain_id,
                        "prompt_id": self.prompt_to_use_id,
                    }
                )
            )

            streamed_chat_history = GetChatHistoryOutput(
                **{
                    "chat_id": str(chat_id),
                    "message_id": streamed_chat_history.message_id,
                    "message_time": streamed_chat_history.message_time,
                    "user_message": question.question,
                    "assistant": "",
                    "prompt_title": self.prompt_to_use.title
                    if self.prompt_to_use
                    else None,
                    "brain_name": brain.name if brain else None,
                    "brain_id": str(brain.brain_id) if brain else None,
                    "metadata": self.metadata,
                }
            )
        else:
            streamed_chat_history = GetChatHistoryOutput(
                **{
                    "chat_id": str(chat_id),
                    "message_id": None,
                    "message_time": None,
                    "user_message": question.question,
                    "assistant": "",
                    "prompt_title": self.prompt_to_use.title
                    if self.prompt_to_use
                    else None,
                    "brain_name": brain.name if brain else None,
                    "brain_id": str(brain.brain_id) if brain else None,
                    "metadata": self.metadata,
                }
            )

        try:
            # Stream each token as an SSE payload; `assistant` holds only the
            # latest token, the full text is rebuilt from response_tokens below.
            async for token in callback.aiter():
                logger.debug("Token: %s", token)
                response_tokens.append(token)
                streamed_chat_history.assistant = token
                yield f"data: {json.dumps(streamed_chat_history.dict())}"
        except Exception as e:
            logger.error("Error during streaming tokens: %s", e)

        try:
            # Await the chain result to collect the source documents.
            result = await run

            sources_list = generate_source(result, brain)
            # Create metadata if it doesn't exist
            if not streamed_chat_history.metadata:
                streamed_chat_history.metadata = {}
            # Serialize the sources list
            serialized_sources_list = [source.dict() for source in sources_list]
            streamed_chat_history.metadata["sources"] = serialized_sources_list
            # Final SSE payload carries the sources in metadata.
            yield f"data: {json.dumps(streamed_chat_history.dict())}"
        except Exception as e:
            logger.error("Error processing source documents: %s", e)

        # Combine all response tokens to form the final assistant message
        assistant = "".join(response_tokens)

        try:
            if save_answer:
                # Replace the empty placeholder row with the full answer.
                chat_service.update_message_by_id(
                    message_id=str(streamed_chat_history.message_id),
                    user_message=question.question,
                    assistant=assistant,
                    metadata=streamed_chat_history.metadata,
                )
        except Exception as e:
            logger.error("Error updating message by ID: %s", e)