2023-08-07 17:31:00 +03:00
import json
2024-02-15 01:01:35 +03:00
from typing import AsyncIterable , List , Optional
2023-08-07 17:31:00 +03:00
from uuid import UUID
from langchain . callbacks . streaming_aiter import AsyncIteratorCallbackHandler
2023-08-10 11:25:08 +03:00
from logger import get_logger
2023-12-11 18:46:45 +03:00
from models import BrainSettings
2024-02-20 04:29:45 +03:00
from modules . brain . entity . brain_entity import BrainEntity
2024-02-06 08:02:46 +03:00
from modules . brain . qa_interface import QAInterface
from modules . brain . rags . quivr_rag import QuivrRAG
from modules . brain . rags . rag_interface import RAGInterface
2023-12-01 00:29:28 +03:00
from modules . brain . service . brain_service import BrainService
2024-04-21 14:09:52 +03:00
from modules . brain . service . utils . format_chat_history import format_chat_history
from modules . brain . service . utils . get_prompt_to_use_id import get_prompt_to_use_id
2024-02-20 04:29:45 +03:00
from modules . chat . controller . chat . utils import (
find_model_and_generate_metadata ,
update_user_usage ,
)
2024-01-26 05:56:54 +03:00
from modules . chat . dto . chats import ChatQuestion , Sources
2023-12-04 20:38:54 +03:00
from modules . chat . dto . inputs import CreateChatHistory
from modules . chat . dto . outputs import GetChatHistoryOutput
from modules . chat . service . chat_service import ChatService
2024-04-21 14:09:52 +03:00
from modules . prompt . service . get_prompt_to_use import get_prompt_to_use
from modules . upload . service . generate_file_signed_url import generate_file_signed_url
from modules . user . service . user_usage import UserUsage
2024-02-15 01:01:35 +03:00
from pydantic import BaseModel , ConfigDict
from pydantic_settings import BaseSettings
2023-07-11 21:15:56 +03:00
logger = get_logger(__name__)

# Persona text exposed as the default Quivr prompt (implicit concatenation
# keeps the exact same string value on readable line lengths).
QUIVR_DEFAULT_PROMPT = (
    "Your name is Quivr. You're a helpful assistant. "
    "If you don't know the answer, just say that you don't know, "
    "don't try to make up an answer."
)

# Module-level service instances used by the helpers and the class below.
brain_service = BrainService()
chat_service = ChatService()
2023-12-01 00:29:28 +03:00
2024-01-20 07:34:30 +03:00
def is_valid_uuid(uuid_to_test, version=4):
    """Return True if *uuid_to_test* is the canonical string form of a UUID.

    The value is parsed with ``UUID(value, version=version)``, which forces the
    version (and RFC 4122 variant) bits, and the canonical lowercase,
    hyphenated rendering is compared back to the input. Only strings that
    already are in that exact form, with matching version bits, are accepted.

    Args:
        uuid_to_test: Value to check. Anything that is not a ``str`` (``None``,
            ``UUID`` instances, ints, ...) yields False instead of raising.
        version: UUID version the string must carry (default 4).

    Returns:
        bool: True when the value round-trips unchanged, False otherwise.
    """
    # UUID() raises AttributeError/TypeError (not ValueError) for non-str input;
    # reject those up front so callers always get a boolean.
    if not isinstance(uuid_to_test, str):
        return False
    try:
        uuid_obj = UUID(uuid_to_test, version=version)
    except ValueError:
        return False
    return str(uuid_obj) == uuid_to_test
2024-02-15 07:07:53 +03:00
def generate_source(source_documents, brain_id):
    """Build the ``Sources`` entries shown to the user for retrieved documents.

    Documents are de-duplicated on ``metadata["file_name"]``: the last document
    seen for a given name wins, while ordering follows first occurrence.
    A document whose ``metadata["original_file_name"]`` is a string starting
    with ``"http"`` is reported as a URL source pointing at that address.
    Every other document is reported as a file whose URL is the
    ``"signedURL"`` returned by ``generate_file_signed_url`` for
    ``"{brain_id}/{file_name}"`` (empty string when none is returned).

    Args:
        source_documents: Retrieved documents exposing a ``metadata`` dict that
            contains ``"file_name"``. ``None`` is treated as an empty list.
        brain_id: Brain owning the documents; prefixes the storage path.

    Returns:
        List[Sources]: One entry per unique file name; empty when there are no
        documents.
    """
    sources_list: List[Sources] = []
    # Storage path -> signed URL, so a given path is signed at most once.
    generated_urls = {}

    # Remove documents sharing a file name; ``or []`` tolerates None input.
    unique_documents = list(
        {doc.metadata["file_name"]: doc for doc in source_documents or []}.values()
    )

    if not unique_documents:
        logger.info("No source documents found or source_documents is not a list.")
        return sources_list

    logger.info("Source documents found: %s", unique_documents)
    for doc in unique_documents:
        logger.info("Document: %s", doc)
        logger.info("Metadata 1: %s", doc.metadata)

        original_file_name = doc.metadata.get("original_file_name")
        # Only a string can be a URL; a non-str value previously raised
        # AttributeError on ``.startswith``.
        is_url = isinstance(original_file_name, str) and original_file_name.startswith(
            "http"
        )
        logger.info("Is URL: %s", is_url)

        if is_url:
            name = original_file_name
            type_ = "url"
            source_url = original_file_name
        else:
            name = doc.metadata["file_name"]
            type_ = "file"
            file_path = f"{brain_id}/{name}"
            if file_path not in generated_urls:
                generated_url = generate_file_signed_url(file_path)
                generated_urls[file_path] = (
                    generated_url.get("signedURL", "")
                    if generated_url is not None
                    else ""
                )
            source_url = generated_urls[file_path]

        sources_list.append(
            Sources(
                name=name,
                type=type_,
                source_url=source_url,
                original_file_name=name,
            )
        )
    return sources_list
2023-12-11 18:46:45 +03:00
class KnowledgeBrainQA(BaseModel, QAInterface):
    """
    Question answering over the knowledge stored in a single brain.

    On construction the instance loads the brain, the user's settings and
    model settings, records the request's cost against the user's usage
    (``increase_usage_user``) and builds a ``QuivrRAG`` whose chain answers
    questions.

    Two entry points are exposed:
    - ``generate_answer``: invokes the chain once and returns the full answer.
    - ``generate_stream``: yields ``data: <json>`` strings while the answer is
      produced, then persists the assembled answer.
    Both feed the chain the same inputs: the question, the formatted chat
    history and the content of the prompt returned by ``prompt_to_use``.
    """

    model_config = ConfigDict(arbitrary_types_allowed=True)

    # Instantiate settings
    brain_settings: BaseSettings = BrainSettings()

    # Default class attributes
    model: str = "gpt-3.5-turbo-0125"  # pyright: ignore reportPrivateUsage=none
    temperature: float = 0.1
    chat_id: Optional[str] = None
    brain_id: Optional[str] = None
    # Overwritten by calculate_pricing with the resolved model's limits.
    max_tokens: int = 2000
    max_input: int = 2000
    streaming: bool = False
    knowledge_qa: Optional[RAGInterface] = None
    brain: Optional[BrainEntity] = None
    user_id: Optional[str] = None
    user_email: Optional[str] = None
    user_usage: Optional[UserUsage] = None
    user_settings: Optional[dict] = None
    models_settings: Optional[List[dict]] = None
    # Caller-supplied metadata copied into every answer's ``metadata``.
    metadata: Optional[dict] = None

    callbacks: Optional[List[AsyncIteratorCallbackHandler]] = None

    prompt_id: Optional[UUID] = None

    def __init__(
        self,
        brain_id: str,
        chat_id: str,
        streaming: bool = False,
        prompt_id: Optional[UUID] = None,
        metadata: Optional[dict] = None,
        user_id: Optional[str] = None,
        user_email: Optional[str] = None,
        cost: int = 100,
        **kwargs,
    ):
        """Load brain and user context, record usage and build the RAG chain.

        Args:
            brain_id: Id of the brain to query.
            chat_id: Id of the chat the questions belong to.
            streaming: Stored on the instance and forwarded to ``QuivrRAG``.
            prompt_id: Optional explicit prompt; see ``prompt_to_use``.
            metadata: Extra metadata merged into each answer's ``metadata``.
            user_id: Id of the user whose usage is recorded.
            user_email: Email of that user.
            cost: Accepted for backward compatibility and not used; the
                charged amount comes from ``calculate_pricing``.
            **kwargs: Forwarded to ``BaseModel.__init__`` and ``QuivrRAG``.
        """
        super().__init__(
            brain_id=brain_id,
            chat_id=chat_id,
            streaming=streaming,
            **kwargs,
        )
        self.prompt_id = prompt_id
        # ``metadata`` is a named parameter, so it is not part of ``kwargs``;
        # it must be stored explicitly or it is silently lost.
        self.metadata = metadata
        self.user_id = user_id
        self.user_email = user_email
        self.user_usage = UserUsage(
            id=user_id,
            email=user_email,
        )
        self.brain = brain_service.get_brain_by_id(brain_id)
        self.user_settings = self.user_usage.get_user_settings()
        # Get Model settings for the user
        self.models_settings = self.user_usage.get_model_settings()
        # Must run before QuivrRAG is built: calculate_pricing (called from
        # here) sets self.model, self.max_input and self.max_tokens.
        self.increase_usage_user()

        self.knowledge_qa = QuivrRAG(
            # NOTE(review): this prefers the brain's own model over self.model
            # as resolved by calculate_pricing; confirm that is intended.
            model=self.brain.model if self.brain.model else self.model,
            brain_id=brain_id,
            chat_id=chat_id,
            streaming=streaming,
            max_input=self.max_input,
            max_tokens=self.max_tokens,
            **kwargs,
        )

    @property
    def prompt_to_use(self):
        """Prompt from ``get_prompt_to_use`` for this brain and ``prompt_id``.

        Returns None when ``brain_id`` is not a valid UUID string. Each access
        calls ``get_prompt_to_use`` again (nothing is cached).
        """
        if self.brain_id and is_valid_uuid(self.brain_id):
            return get_prompt_to_use(UUID(self.brain_id), self.prompt_id)
        else:
            return None

    @property
    def prompt_to_use_id(self) -> Optional[UUID]:
        """Id from ``get_prompt_to_use_id``, or None for an invalid ``brain_id``."""
        # TODO: move to prompt service or instruction or something
        if self.brain_id and is_valid_uuid(self.brain_id):
            return get_prompt_to_use_id(UUID(self.brain_id), self.prompt_id)
        else:
            return None

    def increase_usage_user(self):
        """Record this request's price against the user's usage."""
        # Raises an error if the user has consumed all of his credits
        update_user_usage(
            usage=self.user_usage,
            user_settings=self.user_settings,
            cost=self.calculate_pricing(),
        )

    def calculate_pricing(self):
        """Resolve the model for this request and return its price.

        Side effects: overwrites ``self.model``, ``self.max_input`` and
        ``self.max_tokens`` with the values of the model returned by
        ``find_model_and_generate_metadata``.

        Returns:
            The ``"price"`` of the entry in ``self.models_settings`` whose
            ``"name"`` matches the resolved model (the last one if several
            match), or 1000 when none matches.
        """
        logger.info("Calculating pricing")
        logger.info(f"Model: {self.model}")
        logger.info(f"User settings: {self.user_settings}")
        logger.info(f"Models settings: {self.models_settings}")
        model_to_use = find_model_and_generate_metadata(
            self.chat_id,
            self.brain.model,
            self.user_settings,
            self.models_settings,
        )
        self.model = model_to_use.name
        self.max_input = model_to_use.max_input
        self.max_tokens = model_to_use.max_output

        user_choosen_model_price = 1000
        # ``or []``: missing model settings fall back to the default price
        # instead of raising TypeError.
        for model_setting in self.models_settings or []:
            if model_setting["name"] == self.model:
                user_choosen_model_price = model_setting["price"]
        return user_choosen_model_price

    def _get_transformed_history(self):
        """Fetch this chat's history, formatted for the chain's ``chat_history``."""
        history = chat_service.get_chat_history(self.chat_id)
        return format_chat_history(history)

    def _save_chat_message(self, chat_id, question, answer):
        """Persist *question*/*answer* via ``chat_service.update_chat_history``.

        Returns the object returned by ``update_chat_history``; callers read
        its ``message_id`` and ``message_time``.
        """
        return chat_service.update_chat_history(
            CreateChatHistory(
                **{
                    "chat_id": chat_id,
                    "user_message": question.question,
                    "assistant": answer,
                    "brain_id": self.brain.brain_id,
                    "prompt_id": self.prompt_to_use_id,
                }
            )
        )

    def generate_answer(
        self, chat_id: UUID, question: ChatQuestion, save_answer: bool = True
    ) -> GetChatHistoryOutput:
        """Answer *question* with a single, non-streaming chain invocation.

        Args:
            chat_id: Chat the question belongs to.
            question: Question to answer; ``question.question`` is sent to the
                chain.
            save_answer: When True the question/answer pair is persisted and the
                output carries the stored ``message_id``/``message_time``.
                When False nothing is written and those fields are None.

        Returns:
            GetChatHistoryOutput holding the answer. When the chain returned
            documents, ``metadata["sources"]`` holds the ``generate_source``
            result.
        """
        conversational_qa_chain = self.knowledge_qa.get_chain()
        # Only the formatted history is needed. initialize_streamed_chat_history
        # would also insert an empty placeholder message that this method never
        # fills in, even when save_answer is False.
        transformed_history = self._get_transformed_history()
        # Work on a copy so adding "sources" never mutates self.metadata.
        metadata = dict(self.metadata or {})
        # Resolve once: every access to prompt_to_use calls get_prompt_to_use.
        prompt = self.prompt_to_use

        model_response = conversational_qa_chain.invoke(
            {
                "question": question.question,
                "chat_history": transformed_history,
                "custom_personality": prompt.content if prompt else None,
            }
        )

        sources = model_response.get("docs") or []
        if sources:
            metadata["sources"] = generate_source(sources, self.brain_id)

        answer = model_response["answer"].content
        prompt_title = prompt.title if prompt else None
        brain_id = str(self.brain.brain_id) if self.brain else None

        if save_answer:
            # save the answer to the database or not -> add a variable
            new_chat = self._save_chat_message(chat_id, question, answer)
            return GetChatHistoryOutput(
                **{
                    "chat_id": chat_id,
                    "user_message": question.question,
                    "assistant": answer,
                    "message_time": new_chat.message_time,
                    "prompt_title": prompt_title,
                    "brain_name": self.brain.name if self.brain else None,
                    "message_id": new_chat.message_id,
                    "brain_id": brain_id,
                    "metadata": metadata,
                }
            )
        return GetChatHistoryOutput(
            **{
                "chat_id": chat_id,
                "user_message": question.question,
                "assistant": answer,
                "message_time": None,
                "prompt_title": prompt_title,
                "brain_name": None,
                "message_id": None,
                "brain_id": brain_id,
                "metadata": metadata,
            }
        )

    async def generate_stream(
        self, chat_id: UUID, question: ChatQuestion, save_answer: bool = True
    ) -> AsyncIterable:
        """Stream the answer to *question* as ``data: <json>`` strings.

        One string is yielded per answer chunk, with ``assistant`` set to that
        chunk's content. After the chain finishes, a final string is yielded
        whose ``metadata["sources"]`` holds the serialized sources. The joined
        answer is then saved via ``save_answer``.
        """
        conversational_qa_chain = self.knowledge_qa.get_chain()
        transformed_history, streamed_chat_history = (
            self.initialize_streamed_chat_history(chat_id, question)
        )
        response_tokens = []
        sources = []
        prompt = self.prompt_to_use

        async for chunk in conversational_qa_chain.astream(
            {
                "question": question.question,
                "chat_history": transformed_history,
                "custom_personality": prompt.content if prompt else None,
            }
        ):
            if chunk.get("answer"):
                logger.info("Chunk: %s", chunk)
                response_tokens.append(chunk["answer"].content)
                streamed_chat_history.assistant = chunk["answer"].content
                yield f"data: {json.dumps(streamed_chat_history.dict())}"
            if chunk.get("docs"):
                sources = chunk["docs"]

        sources_list = generate_source(sources, self.brain_id)
        if not streamed_chat_history.metadata:
            streamed_chat_history.metadata = {}
        # Serialize the sources so the payload is JSON-encodable.
        streamed_chat_history.metadata["sources"] = [
            source.dict() for source in sources_list
        ]
        yield f"data: {json.dumps(streamed_chat_history.dict())}"
        self.save_answer(question, response_tokens, streamed_chat_history, save_answer)

    def initialize_streamed_chat_history(self, chat_id, question):
        """Create the message a streamed answer will later be written into.

        A message with an empty ``assistant`` text is stored via
        ``chat_service.update_chat_history``. Its ``message_id`` is carried in
        the returned output and later used by ``save_answer``.

        Returns:
            tuple: ``(transformed_history, streamed_chat_history)``, where the
            second item is a ``GetChatHistoryOutput`` for the new message.
        """
        transformed_history = self._get_transformed_history()
        # Reuse the brain loaded in __init__ instead of fetching it again.
        brain = (
            self.brain
            if self.brain is not None
            else brain_service.get_brain_by_id(self.brain_id)
        )
        prompt = self.prompt_to_use

        new_chat = chat_service.update_chat_history(
            CreateChatHistory(
                **{
                    "chat_id": chat_id,
                    "user_message": question.question,
                    "assistant": "",
                    "brain_id": brain.brain_id,
                    "prompt_id": self.prompt_to_use_id,
                }
            )
        )

        streamed_chat_history = GetChatHistoryOutput(
            **{
                "chat_id": str(chat_id),
                "message_id": new_chat.message_id,
                "message_time": new_chat.message_time,
                "user_message": question.question,
                "assistant": "",
                "prompt_title": prompt.title if prompt else None,
                "brain_name": brain.name if brain else None,
                "brain_id": str(brain.brain_id) if brain else None,
                # Copy so later edits (e.g. adding "sources") cannot leak into
                # self.metadata.
                "metadata": (
                    dict(self.metadata) if self.metadata is not None else None
                ),
            }
        )

        return transformed_history, streamed_chat_history

    def save_answer(
        self, question, response_tokens, streamed_chat_history, save_answer
    ):
        """Write the joined streamed answer into the placeholder message.

        Best-effort: failures are logged (with traceback) and not re-raised.
        Does nothing when *save_answer* is False.
        """
        assistant = "".join(response_tokens)
        try:
            if save_answer:
                chat_service.update_message_by_id(
                    message_id=str(streamed_chat_history.message_id),
                    user_message=question.question,
                    assistant=assistant,
                    metadata=streamed_chat_history.metadata,
                )
        except Exception as e:
            logger.exception("Error updating message by ID: %s", e)

    def save_non_streaming_answer(self, chat_id, question, answer):
        """Persist an already computed *answer* and return it as chat output."""
        new_chat = self._save_chat_message(chat_id, question, answer)
        prompt = self.prompt_to_use
        return GetChatHistoryOutput(
            **{
                "chat_id": chat_id,
                "user_message": question.question,
                "assistant": answer,
                "message_time": new_chat.message_time,
                "prompt_title": prompt.title if prompt else None,
                "brain_name": self.brain.name if self.brain else None,
                "message_id": new_chat.message_id,
                "brain_id": str(self.brain.brain_id) if self.brain else None,
            }
        )