from abc import abstractmethod
from typing import AsyncIterable, List, Optional

from langchain.callbacks.streaming_aiter import AsyncIteratorCallbackHandler
from langchain.chains import ConversationalRetrievalChain, LLMChain
from langchain.llms.base import LLM
from pydantic import BaseModel  # For data validation and settings management

from logger import get_logger
from models.settings import BrainSettings  # Importing settings related to the 'brain'

# Module-level logger, obtained from the project-local `get_logger` helper using
# this module's __name__.
logger = get_logger(__name__)


class BaseBrainPicking(BaseModel):
    """
    Base class for BrainPicking. Allows you to interact with LLMs (large language models).

    Use this class to define the hooks, methods and properties common to all
    BrainPicking implementations. Child classes override the hook methods
    declared at the bottom of this class; the base versions only describe
    the expected contract.
    """

    # Instantiate settings (evaluated once, when this class body runs).
    brain_settings = BrainSettings()  # type: ignore other parameters are optional

    # Default class attributes
    model: Optional[str] = None
    temperature: float = 0.0
    chat_id: Optional[str] = None
    brain_id: Optional[str] = None
    max_tokens: int = 256
    user_openai_api_key: Optional[str] = None
    streaming: bool = False

    # Derived values: always (re)computed in __init__, so anything passed in is overwritten.
    openai_api_key: Optional[str] = None
    callbacks: Optional[List[AsyncIteratorCallbackHandler]] = None

    def _determine_api_key(self, openai_api_key, user_openai_api_key):
        """
        Pick the OpenAI API key to use.

        :param openai_api_key: Fallback key (``__init__`` passes
            ``brain_settings.openai_api_key``).
        :param user_openai_api_key: Key supplied by the user, or None.
        :return: ``user_openai_api_key`` if it is not None, else ``openai_api_key``.
        """
        # A user-supplied key wins whenever it is present (even an empty string).
        if user_openai_api_key is not None:
            return user_openai_api_key
        return openai_api_key

    def _determine_streaming(self, model: Optional[str], streaming: bool) -> bool:
        """
        Decide whether answers should be streamed.

        The base implementation does not inspect ``model``: it simply returns the
        requested ``streaming`` flag. Override it to disable streaming for models
        that do not support it.

        :param model: Language model name (unused here).
        :param streaming: Whether streaming was requested.
        :return: The effective streaming flag.
        """
        return streaming

    def _determine_callback_array(
        self, streaming
    ) -> Optional[List[AsyncIteratorCallbackHandler]]:
        """
        Build the callback list handed to the language model.

        :param streaming: Whether streaming is enabled.
        :return: A list holding a single new ``AsyncIteratorCallbackHandler`` when
            streaming is enabled, otherwise ``None`` (no callbacks).
        """
        if streaming:
            return [AsyncIteratorCallbackHandler()]
        # Explicit rather than falling off the end of the function.
        return None

    def __init__(self, **data):
        """
        Validate ``data`` through pydantic, then resolve the derived settings.

        Any ``openai_api_key`` or ``callbacks`` passed in ``data`` is replaced by
        the values computed here.
        """
        super().__init__(**data)
        self.openai_api_key = self._determine_api_key(
            self.brain_settings.openai_api_key, self.user_openai_api_key
        )
        self.streaming = self._determine_streaming(self.model, self.streaming)
        self.callbacks = self._determine_callback_array(self.streaming)

    class Config:
        """Configuration of the Pydantic object."""

        # Allow non-pydantic types (e.g. the langchain callback handler used in
        # ``callbacks``) as field types.
        arbitrary_types_allowed = True

    # The methods below define the names, arguments and return types of the hooks
    # most child classes need. The base versions raise NotImplementedError; child
    # classes must override the ones they use.

    def _create_llm(self, model, temperature=0, streaming=False, callbacks=None) -> LLM:
        """
        Determine and construct the language model.

        :param model: Language model name to be used.
        :param temperature: Sampling temperature.
        :param streaming: Whether the model should stream its output.
        :param callbacks: Callbacks to attach to the model.
        :return: Language model instance.

        This method should take into account the following:
        - Whether the model is streaming compatible
        - Whether the model is private
        - Whether the model should use an openai api key and use the _determine_api_key method
        """
        raise NotImplementedError(
            "_create_llm is not implemented for this BrainPicking Class."
        )

    def _create_question_chain(self, model) -> LLMChain:
        """
        Determine and construct the question chain.

        :param model: Language model name to be used.
        :return: Question chain instance.

        This method should take into account the following:
        - Which prompt to use (normally CONDENSE_QUESTION_PROMPT)
        """
        raise NotImplementedError(
            "_create_question_chain is not implemented for this BrainPicking Class."
        )

    def _create_doc_chain(self, model) -> LLMChain:
        """
        Determine and construct the document chain.

        :param model: Language model name to be used.
        :return: Document chain instance.

        This method should take into account the following:
        - chain_type (normally "stuff")
        - Whether the model is streaming compatible and/or streaming is set (_determine_streaming).
        """
        raise NotImplementedError(
            "_create_doc_chain is not implemented for this BrainPicking Class."
        )

    def _create_qa(
        self, question_chain, document_chain
    ) -> ConversationalRetrievalChain:
        """
        Construct a conversational retrieval chain.

        :param question_chain: Chain built by ``_create_question_chain``.
        :param document_chain: Chain built by ``_create_doc_chain``.
        :return: ConversationalRetrievalChain instance.
        """
        raise NotImplementedError(
            "_create_qa is not implemented for this BrainPicking Class."
        )

    def _call_chain(self, chain, question, history) -> str:
        """
        Call a chain with a given question and history.

        :param chain: The chain, e.g. QA (ConversationalRetrievalChain).
        :param question: The user prompt.
        :param history: The chat history from the DB.
        :return: The answer.
        """
        raise NotImplementedError(
            "_call_chain is not implemented for this BrainPicking Class."
        )

    async def _acall_chain(self, chain, question, history) -> str:
        """
        Asynchronously call a chain with a given question and history.

        :param chain: The chain, e.g. qa (ConversationalRetrievalChain).
        :param question: The user prompt.
        :param history: The chat history from the DB.
        :return: The answer.
        :raises NotImplementedError: Always, in the base class.
        """
        raise NotImplementedError(
            "Async generation not implemented for this BrainPicking Class."
        )

    def generate_answer(self, question: str) -> str:
        """
        Generate an answer to a given question using the QA chain.

        :param question: The question.
        :return: The generated answer.

        This function should also call: _create_qa, get_chat_history and format_chat_history.
        It should also update the chat_history in the DB.
        """
        raise NotImplementedError(
            "generate_answer is not implemented for this BrainPicking Class."
        )

    async def generate_stream(self, question: str) -> AsyncIterable:
        """
        Generate a streaming answer to a given question using the QA chain.

        :param question: The question.
        :return: An async iterable which generates the answer.
        :raises NotImplementedError: Always, in the base class.

        Implementations also have to:
        - Update the chat history in the DB with the chat details (chat_id, question) -> return a message_id and timestamp
        - Run the _acall_chain coroutine via asyncio.create_task, so it executes as a
          concurrent task on the same event loop (not on a separate thread) while tokens are consumed
        - Append each token to the chat_history object from the DB and yield it from the function
        - Append each token from the callback to an answer string -> used to update the chat history in the DB (update_message_by_id)
        """
        raise NotImplementedError(
            "Async generation not implemented for this BrainPicking Class."
        )