feat: add custom rag first abstraction layer (#1858)

- Add `QAInterface`, which should be implemented by all custom answer generators to be compatible with Quivr
- Add `RAGInterface`, which should be implemented by all RAG classes
Mamadou DICKO 2023-12-11 16:46:45 +01:00 committed by GitHub
parent e0362e7122
commit 512b9b4f37
6 changed files with 279 additions and 148 deletions
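
For illustration, a custom answer generator only needs to satisfy `QAInterface` to plug into the same chat flow as the built-in `KnowledgeBrainQA`, `HeadlessQA` and `APIBrainQA` classes updated below. A minimal sketch (the class name and canned answer are hypothetical, and saving the exchange to chat history is omitted):

```python
from uuid import UUID

from llm.qa_interface import QAInterface
from modules.chat.dto.chats import ChatQuestion


class CannedAnswerQA(QAInterface):
    """Hypothetical generator that always returns the same answer (sketch only)."""

    def generate_answer(self, chat_id: UUID, question: ChatQuestion, *custom_params: tuple):
        # A real implementation would persist the exchange via chat_service and
        # return a GetChatHistoryOutput, as KnowledgeBrainQA does below.
        return "I only know one answer."

    async def generate_stream(self, chat_id: UUID, question: ChatQuestion, *custom_params: tuple):
        # The built-in implementations stream asynchronously through an
        # AsyncIteratorCallbackHandler; this sketch just yields the answer once.
        yield "I only know one answer."
```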

View File

@@ -14,6 +14,7 @@ from modules.chat.dto.outputs import GetChatHistoryOutput
from modules.chat.service.chat_service import ChatService
from llm.knowledge_brain_qa import KnowledgeBrainQA
from llm.qa_interface import QAInterface
from llm.utils.call_brain_api import call_brain_api
from llm.utils.get_api_brain_definition_as_json_schema import (
get_api_brain_definition_as_json_schema,
@@ -25,9 +26,7 @@ chat_service = ChatService()
logger = get_logger(__name__)
class APIBrainQA(
KnowledgeBrainQA,
):
class APIBrainQA(KnowledgeBrainQA, QAInterface):
user_id: UUID
def __init__(

View File

@@ -4,32 +4,22 @@ from typing import AsyncIterable, Awaitable, List, Optional
from uuid import UUID
from langchain.callbacks.streaming_aiter import AsyncIteratorCallbackHandler
from langchain.chains import ConversationalRetrievalChain, LLMChain
from langchain.chains.question_answering import load_qa_chain
from langchain.chat_models import ChatLiteLLM
from langchain.embeddings.ollama import OllamaEmbeddings
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.llms.base import BaseLLM
from langchain.prompts.chat import (
ChatPromptTemplate,
HumanMessagePromptTemplate,
SystemMessagePromptTemplate,
)
from llm.utils.format_chat_history import format_chat_history
from llm.utils.get_prompt_to_use import get_prompt_to_use
from llm.utils.get_prompt_to_use_id import get_prompt_to_use_id
from langchain.chains import ConversationalRetrievalChain
from logger import get_logger
from models import BrainSettings # Importing settings related to the 'brain'
from models import BrainSettings
from modules.brain.service.brain_service import BrainService
from modules.chat.dto.chats import ChatQuestion
from modules.chat.dto.inputs import CreateChatHistory
from modules.chat.dto.outputs import GetChatHistoryOutput
from modules.chat.service.chat_service import ChatService
from pydantic import BaseModel
from supabase.client import Client, create_client
from vectorstore.supabase import CustomSupabaseVectorStore
from .prompts.CONDENSE_PROMPT import CONDENSE_QUESTION_PROMPT
from llm.qa_interface import QAInterface
from llm.rags.quivr_rag import QuivrRAG
from llm.rags.rag_interface import RAGInterface
from llm.utils.format_chat_history import format_chat_history
from llm.utils.get_prompt_to_use import get_prompt_to_use
from llm.utils.get_prompt_to_use_id import get_prompt_to_use_id
logger = get_logger(__name__)
QUIVR_DEFAULT_PROMPT = "Your name is Quivr. You're a helpful assistant. If you don't know the answer, just say that you don't know, don't try to make up an answer."
@@ -39,7 +29,7 @@ brain_service = BrainService()
chat_service = ChatService()
class KnowledgeBrainQA(BaseModel):
class KnowledgeBrainQA(BaseModel, QAInterface):
"""
Main class for the Brain Picking functionality.
It allows you to initialize a Chat model, generate questions and retrieve answers using ConversationalRetrievalChain.
@@ -52,7 +42,6 @@ class KnowledgeBrainQA(BaseModel):
class Config:
"""Configuration of the Pydantic Object"""
# Allowing arbitrary types for class validation
arbitrary_types_allowed = True
# Instantiate settings
@@ -65,36 +54,12 @@ class KnowledgeBrainQA(BaseModel):
brain_id: str = None # pyright: ignore reportPrivateUsage=none
max_tokens: int = 256
streaming: bool = False
knowledge_qa: Optional[RAGInterface]
callbacks: List[
AsyncIteratorCallbackHandler
] = None # pyright: ignore reportPrivateUsage=none
def _determine_streaming(self, model: str, streaming: bool) -> bool:
"""If the model name allows for streaming and streaming is declared, set streaming to True."""
return streaming
def _determine_callback_array(
self, streaming
) -> List[AsyncIteratorCallbackHandler]: # pyright: ignore reportPrivateUsage=none
"""If streaming is set, set the AsyncIteratorCallbackHandler as the only callback."""
if streaming:
return [
AsyncIteratorCallbackHandler() # pyright: ignore reportPrivateUsage=none
]
@property
def embeddings(self):
if self.brain_settings.ollama_api_base_url:
return OllamaEmbeddings(
base_url=self.brain_settings.ollama_api_base_url
) # pyright: ignore reportPrivateUsage=none
else:
return OpenAIEmbeddings()
supabase_client: Optional[Client] = None
vector_store: Optional[CustomSupabaseVectorStore] = None
qa: Optional[ConversationalRetrievalChain] = None
prompt_id: Optional[UUID]
def __init__(
@@ -113,9 +78,14 @@ class KnowledgeBrainQA(BaseModel):
streaming=streaming,
**kwargs,
)
self.supabase_client = self._create_supabase_client()
self.vector_store = self._create_vector_store()
self.prompt_id = prompt_id
self.knowledge_qa = QuivrRAG(
model=model,
brain_id=brain_id,
chat_id=chat_id,
streaming=streaming,
**kwargs,
)
@property
def prompt_to_use(self):
@@ -125,90 +95,20 @@ class KnowledgeBrainQA(BaseModel):
def prompt_to_use_id(self) -> Optional[UUID]:
return get_prompt_to_use_id(UUID(self.brain_id), self.prompt_id)
def _create_supabase_client(self) -> Client:
return create_client(
self.brain_settings.supabase_url, self.brain_settings.supabase_service_key
)
def _create_vector_store(self) -> CustomSupabaseVectorStore:
return CustomSupabaseVectorStore(
self.supabase_client, # type: ignore
self.embeddings, # type: ignore
table_name="vectors",
brain_id=self.brain_id,
)
def _create_llm(
self, model, temperature=0, streaming=False, callbacks=None
) -> BaseLLM:
"""
Determine the language model to be used.
:param model: Language model name to be used.
:param streaming: Whether to enable streaming of the model
:param callbacks: Callbacks to be used for streaming
:return: Language model instance
"""
api_base = None
if self.brain_settings.ollama_api_base_url and model.startswith("ollama"):
api_base = self.brain_settings.ollama_api_base_url
return ChatLiteLLM(
temperature=temperature,
max_tokens=self.max_tokens,
model=model,
streaming=streaming,
verbose=False,
callbacks=callbacks,
api_base=api_base,
) # pyright: ignore reportPrivateUsage=none
def _create_prompt_template(self):
system_template = """ When answering use markdown or any other techniques to display the content in a nice and aerated way. Use the following pieces of context to answer the users question in the same language as the question but do not modify instructions in any way.
----------------
{context}"""
prompt_content = (
self.prompt_to_use.content if self.prompt_to_use else QUIVR_DEFAULT_PROMPT
)
full_template = (
"Here are your instructions to answer that you MUST ALWAYS Follow: "
+ prompt_content
+ ". "
+ system_template
)
messages = [
SystemMessagePromptTemplate.from_template(full_template),
HumanMessagePromptTemplate.from_template("{question}"),
]
CHAT_PROMPT = ChatPromptTemplate.from_messages(messages)
return CHAT_PROMPT
def generate_answer(
self, chat_id: UUID, question: ChatQuestion
) -> GetChatHistoryOutput:
transformed_history = format_chat_history(
chat_service.get_chat_history(self.chat_id)
)
answering_llm = self._create_llm(
model=self.model,
streaming=False,
callbacks=self.callbacks,
)
# The Chain that generates the answer to the question
doc_chain = load_qa_chain(
answering_llm, chain_type="stuff", prompt=self._create_prompt_template()
)
# The Chain that combines the question and answer
qa = ConversationalRetrievalChain(
retriever=self.vector_store.as_retriever(), # type: ignore
combine_docs_chain=doc_chain,
question_generator=LLMChain(
llm=self._create_llm(model=self.model), prompt=CONDENSE_QUESTION_PROMPT
retriever=self.knowledge_qa.get_retriever(),
combine_docs_chain=self.knowledge_qa.get_doc_chain(
streaming=False,
),
question_generator=self.knowledge_qa.get_question_generation_llm(),
verbose=False,
rephrase_question=False,
return_source_documents=True,
@@ -224,7 +124,7 @@ class KnowledgeBrainQA(BaseModel):
"chat_history": transformed_history,
"custom_personality": prompt_content,
}
) # type: ignore
)
answer = model_response["answer"]
@@ -266,24 +166,14 @@ class KnowledgeBrainQA(BaseModel):
callback = AsyncIteratorCallbackHandler()
self.callbacks = [callback]
answering_llm = self._create_llm(
model=self.model,
streaming=True,
callbacks=self.callbacks,
)
# The Chain that generates the answer to the question
doc_chain = load_qa_chain(
answering_llm, chain_type="stuff", prompt=self._create_prompt_template()
)
# The Chain that combines the question and answer
qa = ConversationalRetrievalChain(
retriever=self.vector_store.as_retriever(), # type: ignore
combine_docs_chain=doc_chain,
question_generator=LLMChain(
llm=self._create_llm(model=self.model), prompt=CONDENSE_QUESTION_PROMPT
retriever=self.knowledge_qa.get_retriever(),
combine_docs_chain=self.knowledge_qa.get_doc_chain(
callbacks=self.callbacks,
streaming=True,
),
question_generator=self.knowledge_qa.get_question_generation_llm(),
verbose=False,
rephrase_question=False,
return_source_documents=True,
@@ -359,7 +249,7 @@ class KnowledgeBrainQA(BaseModel):
try:
result = await run
source_documents = result.get("source_documents", [])
## Deduplicate source documents
# Deduplicate source documents
source_documents = list(
{doc.metadata["file_name"]: doc for doc in source_documents}.values()
)

View File

@@ -8,12 +8,6 @@ from langchain.chains import LLMChain
from langchain.chat_models import ChatLiteLLM
from langchain.chat_models.base import BaseChatModel
from langchain.prompts.chat import ChatPromptTemplate, HumanMessagePromptTemplate
from llm.utils.format_chat_history import (
format_chat_history,
format_history_to_openai_mesages,
)
from llm.utils.get_prompt_to_use import get_prompt_to_use
from llm.utils.get_prompt_to_use_id import get_prompt_to_use_id
from logger import get_logger
from models import BrainSettings # Importing settings related to the 'brain'
from modules.chat.dto.chats import ChatQuestion
@@ -23,12 +17,20 @@ from modules.chat.service.chat_service import ChatService
from modules.prompt.entity.prompt import Prompt
from pydantic import BaseModel
from llm.qa_interface import QAInterface
from llm.utils.format_chat_history import (
format_chat_history,
format_history_to_openai_mesages,
)
from llm.utils.get_prompt_to_use import get_prompt_to_use
from llm.utils.get_prompt_to_use_id import get_prompt_to_use_id
logger = get_logger(__name__)
SYSTEM_MESSAGE = "Your name is Quivr. You're a helpful assistant. If you don't know the answer, just say that you don't know, don't try to make up an answer.When answering use markdown or any other techniques to display the content in a nice and aerated way."
chat_service = ChatService()
class HeadlessQA(BaseModel):
class HeadlessQA(BaseModel, QAInterface):
brain_settings = BrainSettings()
model: str
temperature: float = 0.0

View File

@@ -0,0 +1,27 @@
from abc import ABC, abstractmethod
from uuid import UUID
from modules.chat.dto.chats import ChatQuestion
class QAInterface(ABC):
"""
Abstract class for all QA interfaces.
This can be used to implement custom answer generation logic.
"""
@abstractmethod
def generate_answer(
self, chat_id: UUID, question: ChatQuestion, *custom_params: tuple
):
raise NotImplementedError(
"generate_answer is an abstract method and must be implemented"
)
@abstractmethod
def generate_stream(
self, chat_id: UUID, question: ChatQuestion, *custom_params: tuple
):
raise NotImplementedError(
"generate_stream is an abstract method and must be implemented"
)

View File

@@ -0,0 +1,182 @@
from typing import Optional
from uuid import UUID
from langchain.chains import ConversationalRetrievalChain, LLMChain
from langchain.chains.question_answering import load_qa_chain
from langchain.chat_models import ChatLiteLLM
from langchain.embeddings.ollama import OllamaEmbeddings
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.llms.base import BaseLLM
from langchain.prompts.chat import (
ChatPromptTemplate,
HumanMessagePromptTemplate,
SystemMessagePromptTemplate,
)
from llm.rags.rag_interface import RAGInterface
from llm.utils.get_prompt_to_use import get_prompt_to_use
from logger import get_logger
from models import BrainSettings # Importing settings related to the 'brain'
from modules.brain.service.brain_service import BrainService
from modules.chat.service.chat_service import ChatService
from pydantic import BaseModel
from supabase.client import Client, create_client
from vectorstore.supabase import CustomSupabaseVectorStore
from ..prompts.CONDENSE_PROMPT import CONDENSE_QUESTION_PROMPT
logger = get_logger(__name__)
QUIVR_DEFAULT_PROMPT = "Your name is Quivr. You're a helpful assistant. If you don't know the answer, just say that you don't know, don't try to make up an answer."
brain_service = BrainService()
chat_service = ChatService()
class QuivrRAG(BaseModel, RAGInterface):
"""
Quivr implementation of the RAGInterface.
"""
class Config:
"""Configuration of the Pydantic Object"""
# Allowing arbitrary types for class validation
arbitrary_types_allowed = True
# Instantiate settings
brain_settings = BrainSettings() # type: ignore other parameters are optional
# Default class attributes
model: str = None # pyright: ignore reportPrivateUsage=none
temperature: float = 0.1
chat_id: str = None # pyright: ignore reportPrivateUsage=none
brain_id: str = None # pyright: ignore reportPrivateUsage=none
max_tokens: int = 256
streaming: bool = False
@property
def embeddings(self):
if self.brain_settings.ollama_api_base_url:
return OllamaEmbeddings(
base_url=self.brain_settings.ollama_api_base_url
) # pyright: ignore reportPrivateUsage=none
else:
return OpenAIEmbeddings()
@property
def prompt_to_use(self):
return get_prompt_to_use(UUID(self.brain_id), self.prompt_id)
supabase_client: Optional[Client] = None
vector_store: Optional[CustomSupabaseVectorStore] = None
qa: Optional[ConversationalRetrievalChain] = None
prompt_id: Optional[UUID]
def __init__(
self,
model: str,
brain_id: str,
chat_id: str,
streaming: bool = False,
prompt_id: Optional[UUID] = None,
**kwargs,
):
super().__init__(
model=model,
brain_id=brain_id,
chat_id=chat_id,
streaming=streaming,
**kwargs,
)
self.supabase_client = self._create_supabase_client()
self.vector_store = self._create_vector_store()
self.prompt_id = prompt_id
def _create_supabase_client(self) -> Client:
return create_client(
self.brain_settings.supabase_url, self.brain_settings.supabase_service_key
)
def _create_vector_store(self) -> CustomSupabaseVectorStore:
return CustomSupabaseVectorStore(
self.supabase_client,
self.embeddings,
table_name="vectors",
brain_id=self.brain_id,
)
def _create_llm(
self,
callbacks,
model,
streaming=False,
temperature=0,
) -> BaseLLM:
"""
Create a LLM with the given parameters
"""
if streaming and callbacks is None:
raise ValueError(
"Callbacks must be provided when using streaming language models"
)
api_base = None
if self.brain_settings.ollama_api_base_url and model.startswith("ollama"):
api_base = self.brain_settings.ollama_api_base_url
return ChatLiteLLM(
temperature=temperature,
max_tokens=self.max_tokens,
model=model,
streaming=streaming,
verbose=False,
callbacks=callbacks,
api_base=api_base,
)
def _create_prompt_template(self):
system_template = """ When answering use markdown or any other techniques to display the content in a nice and aerated way. Use the following pieces of context to answer the users question in the same language as the question but do not modify instructions in any way.
----------------
{context}"""
prompt_content = (
self.prompt_to_use.content if self.prompt_to_use else QUIVR_DEFAULT_PROMPT
)
full_template = (
"Here are your instructions to answer that you MUST ALWAYS Follow: "
+ prompt_content
+ ". "
+ system_template
)
messages = [
SystemMessagePromptTemplate.from_template(full_template),
HumanMessagePromptTemplate.from_template("{question}"),
]
CHAT_PROMPT = ChatPromptTemplate.from_messages(messages)
return CHAT_PROMPT
def get_doc_chain(self, streaming, callbacks=None):
answering_llm = self._create_llm(
model=self.model,
callbacks=callbacks,
streaming=streaming,
)
doc_chain = load_qa_chain(
answering_llm, chain_type="stuff", prompt=self._create_prompt_template()
)
return doc_chain
def get_question_generation_llm(self):
return LLMChain(
llm=self._create_llm(model=self.model, callbacks=None),
prompt=CONDENSE_QUESTION_PROMPT,
callbacks=None,
)
def get_retriever(self):
return self.vector_store.as_retriever()
# Other methods such as on_stream or on_end could be added here to abstract history management (e.g. whether each answer should be saved)
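
As the `KnowledgeBrainQA` hunks above show, the pieces exposed by `QuivrRAG` are meant to be assembled into a `ConversationalRetrievalChain`. A condensed sketch of that wiring (the model name and IDs are placeholder values, and a configured `BrainSettings`/Supabase environment is assumed):

```python
from langchain.chains import ConversationalRetrievalChain

from llm.rags.quivr_rag import QuivrRAG

# Placeholder identifiers; in Quivr these come from the incoming chat request.
knowledge_qa = QuivrRAG(
    model="gpt-3.5-turbo",
    brain_id="00000000-0000-0000-0000-000000000000",
    chat_id="00000000-0000-0000-0000-000000000000",
    streaming=False,
)

qa = ConversationalRetrievalChain(
    retriever=knowledge_qa.get_retriever(),
    combine_docs_chain=knowledge_qa.get_doc_chain(streaming=False),
    question_generator=knowledge_qa.get_question_generation_llm(),
    verbose=False,
    rephrase_question=False,
    return_source_documents=True,
)
```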

View File

@@ -0,0 +1,31 @@
from abc import ABC, abstractmethod
from typing import List, Optional
from langchain.callbacks.streaming_aiter import AsyncIteratorCallbackHandler
from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
from langchain.chains.llm import LLMChain
from langchain_core.retrievers import BaseRetriever
class RAGInterface(ABC):
@abstractmethod
def get_doc_chain(
self,
streaming: bool,
callbacks: Optional[List[AsyncIteratorCallbackHandler]] = None,
) -> BaseCombineDocumentsChain:
raise NotImplementedError(
"get_doc_chain is an abstract method and must be implemented"
)
@abstractmethod
def get_question_generation_llm(self) -> LLMChain:
raise NotImplementedError(
"get_question_generation_llm is an abstract method and must be implemented"
)
@abstractmethod
def get_retriever(self) -> BaseRetriever:
raise NotImplementedError(
"get_retriever is an abstract method and must be implemented"
)
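
To show what the abstraction buys, here is a hedged sketch of an alternative `RAGInterface` implementation backed by an in-memory FAISS index instead of Supabase. The class name, model choice and FAISS/OpenAI dependencies are assumptions and not part of this commit:

```python
from typing import List, Optional

from langchain.callbacks.streaming_aiter import AsyncIteratorCallbackHandler
from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
from langchain.chains.conversational_retrieval.prompts import CONDENSE_QUESTION_PROMPT
from langchain.chains.llm import LLMChain
from langchain.chains.question_answering import load_qa_chain
from langchain.chat_models import ChatOpenAI
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.vectorstores.faiss import FAISS
from langchain_core.retrievers import BaseRetriever

from llm.rags.rag_interface import RAGInterface


class InMemoryRAG(RAGInterface):
    """Hypothetical RAG over a small in-memory FAISS index (sketch only)."""

    def __init__(self, texts: List[str], model: str = "gpt-3.5-turbo"):
        self.model = model
        # Embed the given texts once and keep the index in memory.
        self.vector_store = FAISS.from_texts(texts, OpenAIEmbeddings())

    def get_doc_chain(
        self,
        streaming: bool,
        callbacks: Optional[List[AsyncIteratorCallbackHandler]] = None,
    ) -> BaseCombineDocumentsChain:
        llm = ChatOpenAI(model=self.model, streaming=streaming, callbacks=callbacks)
        # Default "stuff" QA prompt; QuivrRAG builds a custom prompt template instead.
        return load_qa_chain(llm, chain_type="stuff")

    def get_question_generation_llm(self) -> LLMChain:
        return LLMChain(llm=ChatOpenAI(model=self.model), prompt=CONDENSE_QUESTION_PROMPT)

    def get_retriever(self) -> BaseRetriever:
        return self.vector_store.as_retriever()
```

Plugged into the same `ConversationalRetrievalChain` assembly shown above, such a class would replace the Supabase-backed `QuivrRAG` without touching the answer-generation logic in `KnowledgeBrainQA`.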