quivr/backend/modules/brain/rags/new_quivr_rag.py
Stan Girard 08e015af6c
feat(lcel): migrated to lcel and pydantic (#2185)
# Description

Please include a summary of the changes and the related issue. Please
also include relevant motivation and context.

## Checklist before requesting a review

Please delete options that are not relevant.

- [ ] My code follows the style guidelines of this project
- [ ] I have performed a self-review of my code
- [ ] I have commented hard-to-understand areas
- [ ] I have ideally added tests that prove my fix is effective or that
my feature works
- [ ] New and existing unit tests pass locally with my changes
- [ ] Any dependent changes have been merged

## Screenshots (if appropriate):

---------

Co-authored-by: Antoine Dewez <44063631+Zewed@users.noreply.github.com>
2024-02-14 14:01:35 -08:00

211 lines
6.6 KiB
Python

from operator import itemgetter
from typing import ClassVar, Optional
from uuid import UUID

from langchain.chains import ConversationalRetrievalChain
from langchain.embeddings.ollama import OllamaEmbeddings
from langchain.llms.base import BaseLLM
from langchain.schema import format_document
from langchain_community.chat_models import ChatLiteLLM
from langchain_core.messages import get_buffer_string
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate, PromptTemplate
from langchain_core.runnables import RunnableParallel, RunnablePassthrough
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from pydantic import BaseModel, ConfigDict
from supabase.client import Client, create_client

from llm.utils.get_prompt_to_use import get_prompt_to_use
from logger import get_logger
from models import BrainSettings  # Importing settings related to the 'brain'
from modules.brain.service.brain_service import BrainService
from modules.chat.service.chat_service import ChatService
from vectorstore.supabase import CustomSupabaseVectorStore
# Module-level logger, named after this module.
logger = get_logger(__name__)

# First step is to create the Rephrasing Prompt.
# It takes two variables, {chat_history} and {question}, and asks the model to
# rewrite the follow-up question as a standalone question in its original
# language.
_template = """Given the following conversation and a follow up question, rephrase the follow up question to be a standalone question, in its original language.
Chat History:
{chat_history}
Follow Up Input: {question}
Standalone question:"""
CONDENSE_QUESTION_PROMPT = PromptTemplate.from_template(_template)

# Next is the answering prompt.
# It takes {context} (the combined document text) and {question}, and
# instructs the model to answer using only that context.
template = """Answer the question based only on the following context from files:
{context}
Question: {question}
"""
ANSWER_PROMPT = ChatPromptTemplate.from_template(template)

# How we format documents: each document is rendered through this template,
# which needs a {file_name} and a {page_content} value, before the rendered
# strings are joined into the {context} of the answering prompt.
DEFAULT_DOCUMENT_PROMPT = PromptTemplate.from_template(
    template="File {file_name}: {page_content}"
)
def is_valid_uuid(uuid_to_test, version=4):
    """Return True if *uuid_to_test* is the canonical string form of a UUID.

    The value is parsed with ``uuid.UUID`` using the requested *version* and
    then rendered back to text; it is considered valid only when the rendered
    form equals the input exactly. As a consequence, upper-case, brace-wrapped
    or dash-less spellings, and UUIDs of a different version, are rejected.

    Args:
        uuid_to_test: Candidate value. Anything that is not a ``str``
            (``None``, ``int``, ``bytes``, a ``UUID`` instance, ...) is
            reported as invalid instead of raising.
        version: UUID version to validate against (defaults to 4).

    Returns:
        bool: True when the input is a canonical UUID string of that version.
    """
    # Non-string input would make UUID() raise TypeError/AttributeError rather
    # than ValueError, so reject it up front.
    if not isinstance(uuid_to_test, str):
        return False
    try:
        uuid_obj = UUID(uuid_to_test, version=version)
    except ValueError:
        # Malformed hex string or an unsupported version number.
        return False
    # UUID() overwrites the version bits, so a mismatch here also catches
    # well-formed UUIDs of another version.
    return str(uuid_obj) == uuid_to_test
# Service instances created once, at import time, and shared module-wide.
# NOTE(review): neither is referenced in the visible part of this module —
# confirm whether other modules import them from here.
brain_service = BrainService()
chat_service = ChatService()
class QuivrRAG(BaseModel):
    """
    Quivr implementation of the RAGInterface.

    On construction the instance creates a Supabase client and a
    ``CustomSupabaseVectorStore`` scoped to ``brain_id``. ``get_chain`` then
    assembles an LCEL pipeline that condenses the chat history and follow-up
    question into a standalone question, retrieves matching documents from the
    vector store, and answers from those documents only.
    """

    model_config = ConfigDict(arbitrary_types_allowed=True)

    # Instantiate settings
    # Annotated as a ClassVar so pydantic keeps it as a plain, shared class
    # attribute instead of rejecting it as a non-annotated model field.
    brain_settings: ClassVar[BrainSettings] = BrainSettings()  # type: ignore other parameters are optional

    # Default class attributes
    model: str = None  # pyright: ignore reportPrivateUsage=none
    temperature: float = 0.1
    chat_id: str = None  # pyright: ignore reportPrivateUsage=none
    brain_id: str = None  # pyright: ignore reportPrivateUsage=none
    max_tokens: int = 2000  # Output length
    max_input: int = 2000
    streaming: bool = False

    # Populated in __init__ (supabase_client, vector_store); `qa` is declared
    # but never assigned inside this class.
    supabase_client: Optional[Client] = None
    vector_store: Optional[CustomSupabaseVectorStore] = None
    qa: Optional[ConversationalRetrievalChain] = None
    # Explicit default: without it pydantic v2 treats the field as required.
    prompt_id: Optional[UUID] = None

    @property
    def embeddings(self):
        """Return the embeddings backend selected by the brain settings.

        Ollama embeddings are used when ``ollama_api_base_url`` is set in the
        settings; otherwise OpenAI embeddings are used.
        """
        if self.brain_settings.ollama_api_base_url:
            return OllamaEmbeddings(
                base_url=self.brain_settings.ollama_api_base_url
            )  # pyright: ignore reportPrivateUsage=none
        else:
            return OpenAIEmbeddings()

    @property
    def prompt_to_use(self):
        """Return the result of ``get_prompt_to_use`` for this brain.

        Returns None when ``brain_id`` is empty or is not a valid UUID string.
        """
        if self.brain_id and is_valid_uuid(self.brain_id):
            return get_prompt_to_use(UUID(self.brain_id), self.prompt_id)
        else:
            return None

    def __init__(
        self,
        model: str,
        brain_id: str,
        chat_id: str,
        streaming: bool = False,
        prompt_id: Optional[UUID] = None,
        max_tokens: int = 2000,
        max_input: int = 2000,
        **kwargs,
    ):
        """Validate the fields, then create the Supabase client and vector store.

        Args:
            model: Name of the model, stored on the instance.
            brain_id: Identifier of the brain whose vectors are queried.
            chat_id: Identifier of the chat, stored on the instance.
            streaming: Whether streaming is requested, stored on the instance.
            prompt_id: Optional identifier passed to ``get_prompt_to_use``.
            max_tokens: Output length limit handed to ``_create_llm``.
            max_input: Value forwarded to the vector store as ``max_input``.
            **kwargs: Any other model field (for example ``temperature``).
        """
        # Every field, including prompt_id, goes through pydantic validation
        # here, so no manual re-assignment is needed afterwards.
        super().__init__(
            model=model,
            brain_id=brain_id,
            chat_id=chat_id,
            streaming=streaming,
            prompt_id=prompt_id,
            max_tokens=max_tokens,
            max_input=max_input,
            **kwargs,
        )
        # The vector store reads self.supabase_client, so the client is
        # created first.
        self.supabase_client = self._create_supabase_client()
        self.vector_store = self._create_vector_store()

        logger.info(f"QuivrRAG initialized with model {model} and brain {brain_id}")
        logger.info("Max input length: " + str(self.max_input))

    def _create_supabase_client(self) -> Client:
        """Create a Supabase client from the URL and service key in the settings."""
        return create_client(
            self.brain_settings.supabase_url, self.brain_settings.supabase_service_key
        )

    def _create_vector_store(self) -> CustomSupabaseVectorStore:
        """Create the vector store over the ``vectors`` table for this brain."""
        return CustomSupabaseVectorStore(
            self.supabase_client,
            self.embeddings,
            table_name="vectors",
            brain_id=self.brain_id,
            max_input=self.max_input,
        )

    def _create_llm(
        self,
        callbacks,
        model,
        streaming=False,
        temperature=0,
    ) -> ChatLiteLLM:
        """
        Create a LLM with the given parameters

        Args:
            callbacks: Callbacks passed to the model; required when
                ``streaming`` is true.
            model: Model name passed to ``ChatLiteLLM``.
            streaming: Whether the model should stream its output.
            temperature: Sampling temperature.

        Returns:
            A ``ChatLiteLLM`` limited to ``self.max_tokens`` output tokens.

        Raises:
            ValueError: If ``streaming`` is true and ``callbacks`` is None.
        """
        if streaming and callbacks is None:
            raise ValueError(
                "Callbacks must be provided when using streaming language models"
            )

        # Only models whose name starts with "ollama" are pointed at the
        # configured Ollama endpoint; all others keep the default API base.
        api_base = None
        if self.brain_settings.ollama_api_base_url and model.startswith("ollama"):
            api_base = self.brain_settings.ollama_api_base_url

        return ChatLiteLLM(
            temperature=temperature,
            max_tokens=self.max_tokens,
            model=model,
            streaming=streaming,
            verbose=False,
            callbacks=callbacks,
            api_base=api_base,
        )

    # A staticmethod because the function takes no `self`: it can then be
    # reached as `self._combine_documents` (as get_chain does) without the
    # instance being bound to `docs`.
    @staticmethod
    def _combine_documents(
        docs, document_prompt=DEFAULT_DOCUMENT_PROMPT, document_separator="\n\n"
    ):
        """Render each document with ``document_prompt`` and join the results.

        Args:
            docs: Iterable of documents accepted by ``format_document``.
            document_prompt: Prompt template used to render one document.
            document_separator: String placed between rendered documents.

        Returns:
            str: The rendered documents joined by ``document_separator``
            (an empty string when ``docs`` is empty).
        """
        doc_strings = [format_document(doc, document_prompt) for doc in docs]
        return document_separator.join(doc_strings)

    def get_retriever(self):
        """Return the vector store exposed as a retriever."""
        return self.vector_store.as_retriever()

    def get_chain(self):
        """Build the conversational question-answering chain.

        The chain input is a mapping with a ``chat_history`` entry (turned
        into text with ``get_buffer_string``) and the ``question`` placeholder
        required by ``CONDENSE_QUESTION_PROMPT``.

        Returns:
            The composed runnable: condense question -> retrieve and combine
            documents -> ``ANSWER_PROMPT`` -> chat model.
        """
        retriever = self.get_retriever()

        # NOTE(review): both steps instantiate ChatOpenAI directly, so
        # self.model, self.temperature, self.max_tokens and self.streaming are
        # not applied to this chain — confirm whether _create_llm should be
        # used here instead.

        # Step 1: rewrite the follow-up question as a standalone question.
        _inputs = RunnableParallel(
            standalone_question=RunnablePassthrough.assign(
                chat_history=lambda x: get_buffer_string(x["chat_history"])
            )
            | CONDENSE_QUESTION_PROMPT
            | ChatOpenAI(temperature=0)
            | StrOutputParser(),
        )

        # Step 2: fetch documents for the standalone question and flatten them
        # into one context string; the question is passed along unchanged.
        _context = {
            "context": itemgetter("standalone_question")
            | retriever
            | self._combine_documents,
            "question": lambda x: x["standalone_question"],
        }

        # Step 3: answer from the retrieved context only.
        conversational_qa_chain = _inputs | _context | ANSWER_PROMPT | ChatOpenAI()
        return conversational_qa_chain