import asyncio
import json
from abc import abstractmethod, abstractproperty
from typing import AsyncIterable, Awaitable
from uuid import UUID
from langchain import PromptTemplate
from langchain.callbacks.streaming_aiter import AsyncIteratorCallbackHandler
from langchain.chains import ConversationalRetrievalChain, LLMChain
from langchain.chains.question_answering import load_qa_chain
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.llms.base import BaseLLM
from logger import get_logger
from models.chat import ChatHistory
from repository.brain.get_brain_by_id import get_brain_by_id
from repository.chat.format_chat_history import format_chat_history
from repository.chat.get_chat_history import get_chat_history
from repository.chat.update_chat_history import update_chat_history
from repository.chat.update_message_by_id import update_message_by_id
from repository.prompt.get_prompt_by_id import get_prompt_by_id
from supabase.client import Client, create_client
from vectorstore.supabase import CustomSupabaseVectorStore

from .base import BaseBrainPicking
from .prompts.CONDENSE_PROMPT import CONDENSE_QUESTION_PROMPT

logger = get_logger(__name__)


class QABaseBrainPicking(BaseBrainPicking):
    """
    Base class for the Brain Picking functionality using the Conversational Retrieval Chain (QA) from Langchain.
    It is not designed to be used directly, but to be subclassed by other classes which use the QA chain.
    """

    def __init__(
        self,
        model: str,
        brain_id: str,
        chat_id: str,
        streaming: bool = False,
        **kwargs,
    ) -> "QABaseBrainPicking":  # pyright: ignore reportPrivateUsage=none
        """
        Initialize the QA BrainPicking class by setting embeddings, supabase client, vector store, language model and chains.
        :return: QABrainPicking instance
        """
        super().__init__(
            model=model,
            brain_id=brain_id,
            chat_id=chat_id,
            streaming=streaming,
            **kwargs,
        )

    @abstractproperty
    def embeddings(self) -> OpenAIEmbeddings:
        raise NotImplementedError("This property should be overridden in a subclass.")
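
    # A concrete subclass is expected to provide the embeddings object used by the
    # vector store. A minimal sketch of such an override (assuming an OpenAI-based
    # subclass and that the settings object exposes `openai_api_key`, which is not
    # guaranteed by this base class):
    #
    #     @property
    #     def embeddings(self) -> OpenAIEmbeddings:
    #         return OpenAIEmbeddings(openai_api_key=self.brain_settings.openai_api_key)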

    @property
    def supabase_client(self) -> Client:
        """Supabase client built from the brain settings (URL and service key)."""
        return create_client(
            self.brain_settings.supabase_url, self.brain_settings.supabase_service_key
        )

    @property
    def vector_store(self) -> CustomSupabaseVectorStore:
        """Vector store over the `vectors` table, scoped to this brain's documents."""
        return CustomSupabaseVectorStore(
            self.supabase_client,
            self.embeddings,
            table_name="vectors",
            brain_id=self.brain_id,
        )

    @property
    def question_llm(self):
        """Non-streaming LLM used to condense the chat history into a standalone question."""
        return self._create_llm(model=self.model, streaming=False)

    @property
    def doc_llm(self):
        """Streaming LLM used to answer the question from the retrieved documents."""
        return self._create_llm(
            model=self.model, streaming=True, callbacks=self.callbacks
        )

    @property
    def question_generator(self) -> LLMChain:
        """Chain that rewrites the follow-up question into a standalone question."""
        return LLMChain(
            llm=self.question_llm, prompt=CONDENSE_QUESTION_PROMPT, verbose=True
        )

    @property
    def doc_chain(self) -> LLMChain:
        """'Stuff' QA chain that answers the question from the retrieved context, following the brain prompt."""
        prompt_template = """Use the following pieces of context to answer the question at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer.
{context}
Question: {question}
Here are instructions on how to answer the question: {brain_prompt}
Answer:"""
        PROMPT = PromptTemplate(
            template=prompt_template,
            input_variables=["context", "question", "brain_prompt"],
        )
        return load_qa_chain(
            llm=self.doc_llm, chain_type="stuff", verbose=True, prompt=PROMPT
        )  # pyright: ignore reportPrivateUsage=none

    @property
    def qa(self) -> ConversationalRetrievalChain:
        """Conversational retrieval chain combining the question generator and the document chain."""
        return ConversationalRetrievalChain(
            retriever=self.vector_store.as_retriever(),
            question_generator=self.question_generator,
            combine_docs_chain=self.doc_chain,  # pyright: ignore reportPrivateUsage=none
            verbose=True,
        )

    @abstractmethod
    def _create_llm(
        self, model, streaming=False, callbacks=None, temperature=0.0
    ) -> BaseLLM:
        """
        Determine the language model to be used.
        :param model: Language model name to be used.
        :param streaming: Whether to enable streaming of the model
        :param callbacks: Callbacks to be used for streaming
        :param temperature: Sampling temperature (0.0 means deterministic output)
        :return: Language model instance
        """

    def _call_chain(self, chain, question, history, brain_prompt):
        """
        Call a chain with a given question and history.
        :param chain: The chain eg QA (ConversationalRetrievalChain)
        :param question: The user prompt
        :param history: The chat history from DB
        :param brain_prompt: Instructions from the brain's custom prompt
        :return: The answer.
        """
        return chain(
            {
                "question": question,
                "chat_history": history,
                "brain_prompt": brain_prompt,
            }
        )

    def generate_answer(self, question: str) -> ChatHistory:
        """
        Generate an answer to a given question by interacting with the language model.
        :param question: The question
        :return: The generated answer.
        """
        # Get the history from the database
        history = get_chat_history(self.chat_id)

        # Format the chat history into a list of tuples (human, ai)
        transformed_history = format_chat_history(history)

        # Generate the model response using the QA chain
        model_response = self._call_chain(
            self.qa,
            question,
            transformed_history,
            brain_prompt=self.get_prompt(),
        )
        answer = model_response["answer"]

        # Update chat history
        chat_answer = update_chat_history(
            chat_id=self.chat_id,
            user_message=question,
            assistant=answer,
        )

        return chat_answer

    async def _acall_chain(self, chain, question, history):
        """
        Call a chain asynchronously with a given question and history.
        :param chain: The chain eg QA (ConversationalRetrievalChain)
        :param question: The user prompt
        :param history: The chat history from DB
        :return: The answer.
        """
        return await chain.acall(
            {
                "question": question,
                "chat_history": history,
            }
        )

    def get_prompt(self) -> str:
        """Return the brain's custom prompt content, or a default assistant prompt."""
        brain = get_brain_by_id(UUID(self.brain_id))
        brain_prompt = "Your name is Quivr. You're a helpful assistant."

        if brain and brain.prompt_id:
            brain_prompt_object = get_prompt_by_id(brain.prompt_id)
            if brain_prompt_object:
                brain_prompt = brain_prompt_object.content

        return brain_prompt

    async def generate_stream(self, question: str) -> AsyncIterable:
        """
        Generate a streaming answer to a given question by interacting with the language model.
        :param question: The question
        :return: An async iterable which generates the answer.
        """
        history = get_chat_history(self.chat_id)
        callback = AsyncIteratorCallbackHandler()
        self.callbacks = [callback]

        # The Model used to answer the question with the context
        answering_llm = self._create_llm(
            model=self.model,
            streaming=True,
            callbacks=self.callbacks,
            temperature=self.temperature,
        )

        # The Model used to create the standalone question
        # Temperature = 0 means no randomness
        standalone_question_llm = self._create_llm(model=self.model)

        # The Chain that generates the standalone question
        standalone_question_generator = LLMChain(
            llm=standalone_question_llm, prompt=CONDENSE_QUESTION_PROMPT
        )

        prompt_template = """Use the following pieces of context to answer the question at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer.
{context}
Question: {question}
Here are instructions on how to answer the question: {brain_prompt}
Answer:"""
        PROMPT = PromptTemplate(
            template=prompt_template,
            input_variables=["context", "question", "brain_prompt"],
        )

        # The Chain that generates the answer to the question
        doc_chain = load_qa_chain(answering_llm, chain_type="stuff", prompt=PROMPT)

        # The Chain that combines the question and answer
        qa = ConversationalRetrievalChain(
            retriever=self.vector_store.as_retriever(),
            combine_docs_chain=doc_chain,
            question_generator=standalone_question_generator,
        )

        # Format the chat history into a list of tuples (human, ai)
        transformed_history = format_chat_history(history)

        # Initialize a list to hold the tokens
        response_tokens = []

        # Wrap an awaitable with an event to signal when it's done or an exception is raised.
        async def wrap_done(fn: Awaitable, event: asyncio.Event):
            try:
                await fn
            except Exception as e:
                logger.error(f"Caught exception: {e}")
            finally:
                event.set()

        # Begin a task that runs in the background.
        run = asyncio.create_task(
            wrap_done(
                qa.acall(
                    {
                        "question": question,
                        "chat_history": transformed_history,
                        "brain_prompt": self.get_prompt(),
                    }
                ),
                callback.done,
            )
        )

        streamed_chat_history = update_chat_history(
            chat_id=self.chat_id,
            user_message=question,
            assistant="",
        )

        # Use the aiter method of the callback to stream the response with server-sent events
        async for token in callback.aiter():  # pyright: ignore reportPrivateUsage=none
            logger.info("Token: %s", token)

            # Add the token to the response_tokens list
            response_tokens.append(token)
            streamed_chat_history.assistant = token

            yield f"data: {json.dumps(streamed_chat_history.to_dict())}"

        await run

        # Join the tokens to create the assistant's response
        assistant = "".join(response_tokens)

        update_message_by_id(
            message_id=streamed_chat_history.message_id,
            user_message=question,
            assistant=assistant,
        )
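

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of this module). A concrete
# subclass has to provide `embeddings` and `_create_llm`; the subclass name
# `OpenAIBrainPicking`, the model name, and the surrounding wiring below are
# assumptions:
#
#     brain_qa = OpenAIBrainPicking(
#         model="gpt-3.5-turbo",
#         brain_id=str(brain_id),
#         chat_id=str(chat_id),
#         streaming=True,
#     )
#
#     # Synchronous, non-streaming answer:
#     chat_answer = brain_qa.generate_answer(question)
#
#     # Streaming answer, e.g. behind a server-sent-events endpoint:
#     async def event_source():
#         async for sse_chunk in brain_qa.generate_stream(question):
#             yield sse_chunk
# ---------------------------------------------------------------------------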