from operator import itemgetter
from typing import Optional
from uuid import UUID

from langchain.chains import ConversationalRetrievalChain
from langchain.embeddings.ollama import OllamaEmbeddings
from langchain.llms.base import BaseLLM
from langchain.schema import format_document
from langchain_community.chat_models import ChatLiteLLM
from langchain_core.messages import get_buffer_string
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate, PromptTemplate
from langchain_core.runnables import RunnableParallel, RunnablePassthrough
from langchain_openai import ChatOpenAI, OpenAIEmbeddings

from utils.get_prompt_to_use import get_prompt_to_use
from logger import get_logger
from models import BrainSettings  # Importing settings related to the 'brain'
from modules.brain.service.brain_service import BrainService
from modules.chat.service.chat_service import ChatService
from pydantic import BaseModel, ConfigDict
from supabase.client import Client, create_client
from vectorstore.supabase import CustomSupabaseVectorStore

logger = get_logger(__name__)

# First step is to create the Rephrasing Prompt
_template = """Given the following conversation and a follow up question, rephrase the follow up question to be a standalone question, in its original language.

Chat History:
{chat_history}
Follow Up Input: {question}
Standalone question:"""
CONDENSE_QUESTION_PROMPT = PromptTemplate.from_template(_template)

# Next is the answering prompt
template = """Answer the question based only on the following context from files:

{context}

Question: {question}
"""
ANSWER_PROMPT = ChatPromptTemplate.from_template(template)

# How we format documents
DEFAULT_DOCUMENT_PROMPT = PromptTemplate.from_template(
    template="File {file_name}: {page_content}"
)


def is_valid_uuid(uuid_to_test, version=4):
    try:
        uuid_obj = UUID(uuid_to_test, version=version)
    except ValueError:
        return False
    return str(uuid_obj) == uuid_to_test


brain_service = BrainService()
chat_service = ChatService()


class QuivrRAG(BaseModel):
    """
    Quivr implementation of the RAGInterface.
    """

    model_config = ConfigDict(arbitrary_types_allowed=True)

    # Instantiate settings
    brain_settings = BrainSettings()  # type: ignore other parameters are optional

    # Default class attributes
    model: str = None  # pyright: ignore reportPrivateUsage=none
    temperature: float = 0.1
    chat_id: str = None  # pyright: ignore reportPrivateUsage=none
    brain_id: str = None  # pyright: ignore reportPrivateUsage=none
    max_tokens: int = 2000  # Output length
    max_input: int = 2000
    streaming: bool = False

    @property
    def embeddings(self):
        if self.brain_settings.ollama_api_base_url:
            return OllamaEmbeddings(
                base_url=self.brain_settings.ollama_api_base_url
            )  # pyright: ignore reportPrivateUsage=none
        else:
            return OpenAIEmbeddings()

    @property
    def prompt_to_use(self):
        if self.brain_id and is_valid_uuid(self.brain_id):
            return get_prompt_to_use(UUID(self.brain_id), self.prompt_id)
        else:
            return None

    supabase_client: Optional[Client] = None
    vector_store: Optional[CustomSupabaseVectorStore] = None
    qa: Optional[ConversationalRetrievalChain] = None
    prompt_id: Optional[UUID] = None

    def __init__(
        self,
        model: str,
        brain_id: str,
        chat_id: str,
        streaming: bool = False,
        prompt_id: Optional[UUID] = None,
        max_tokens: int = 2000,
        max_input: int = 2000,
        **kwargs,
    ):
        super().__init__(
            model=model,
            brain_id=brain_id,
            chat_id=chat_id,
            streaming=streaming,
            max_tokens=max_tokens,
            max_input=max_input,
            **kwargs,
        )
        self.supabase_client = self._create_supabase_client()
        self.vector_store = self._create_vector_store()
        self.prompt_id = prompt_id
        self.max_tokens = max_tokens
        self.max_input = max_input
        self.model = model
        self.brain_id = brain_id
        self.chat_id = chat_id
        self.streaming = streaming
        logger.info(f"QuivrRAG initialized with model {model} and brain {brain_id}")
        logger.info("Max input length: " + str(self.max_input))

    def _create_supabase_client(self) -> Client:
        return create_client(
            self.brain_settings.supabase_url, self.brain_settings.supabase_service_key
        )

    def _create_vector_store(self) -> CustomSupabaseVectorStore:
        return CustomSupabaseVectorStore(
            self.supabase_client,
            self.embeddings,
            table_name="vectors",
            brain_id=self.brain_id,
            max_input=self.max_input,
        )

    def _create_llm(
        self,
        callbacks,
        model,
        streaming=False,
        temperature=0,
    ) -> BaseLLM:
        """
        Create an LLM with the given parameters.
        """
        if streaming and callbacks is None:
            raise ValueError(
                "Callbacks must be provided when using streaming language models"
            )

        api_base = None
        if self.brain_settings.ollama_api_base_url and model.startswith("ollama"):
            api_base = self.brain_settings.ollama_api_base_url

        return ChatLiteLLM(
            temperature=temperature,
            max_tokens=self.max_tokens,
            model=model,
            streaming=streaming,
            verbose=False,
            callbacks=callbacks,
            api_base=api_base,
        )

    def _combine_documents(
        self, docs, document_prompt=DEFAULT_DOCUMENT_PROMPT, document_separator="\n\n"
    ):
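        """Format each document with DEFAULT_DOCUMENT_PROMPT and join them into one string."""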
        doc_strings = [format_document(doc, document_prompt) for doc in docs]
        return document_separator.join(doc_strings)

    def get_retriever(self):
        return self.vector_store.as_retriever()

    def get_chain(self):
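        """
        Build the LCEL pipeline: condense the follow-up question into a
        standalone question, retrieve matching documents from the vector
        store, then answer from that context only.
        """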
        retriever = self.get_retriever()
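
        # Step 1: rewrite the follow-up question into a standalone question,
        # using the buffered chat history and a non-streaming ChatOpenAI call.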
        _inputs = RunnableParallel(
            standalone_question=RunnablePassthrough.assign(
                chat_history=lambda x: get_buffer_string(x["chat_history"])
            )
            | CONDENSE_QUESTION_PROMPT
            | ChatOpenAI(temperature=0)
            | StrOutputParser(),
        )
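
        # Step 2: use the standalone question to retrieve documents and format
        # them into a single context string for the answering prompt.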
        _context = {
            "context": itemgetter("standalone_question")
            | retriever
            | self._combine_documents,
            "question": lambda x: x["standalone_question"],
        }
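
        # Step 3: compose the full chain: condense -> retrieve -> answer.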
        conversational_qa_chain = _inputs | _context | ANSWER_PROMPT | ChatOpenAI()

        return conversational_qa_chain
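

if __name__ == "__main__":
    # Minimal usage sketch, not part of the production code path. It assumes
    # valid Supabase credentials in BrainSettings and an OpenAI API key in the
    # environment; the model name and UUIDs below are placeholders.
    rag = QuivrRAG(
        model="gpt-3.5-turbo",
        brain_id="00000000-0000-0000-0000-000000000000",
        chat_id="00000000-0000-0000-0000-000000000000",
    )
    chain = rag.get_chain()
    # The chain expects the raw question plus the message history used by the
    # condense step; an empty list is fine for a first turn.
    answer = chain.invoke({"question": "What is in my files?", "chat_history": []})
    print(answer)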