mirror of
https://github.com/QuivrHQ/quivr.git
synced 2024-12-15 01:21:48 +03:00
feat(qa): improve code (#886)
* feat(qa): improve code
* feat: 🎸 customprompt
now in system
This commit is contained in:
parent
fe9280bddc
commit
7028505571
@ -2,8 +2,6 @@ from abc import abstractmethod
|
||||
from typing import AsyncIterable, List
|
||||
|
||||
from langchain.callbacks.streaming_aiter import AsyncIteratorCallbackHandler
|
||||
from langchain.chains import ConversationalRetrievalChain, LLMChain
|
||||
from langchain.llms.base import LLM
|
||||
from logger import get_logger
|
||||
from models.settings import BrainSettings # Importing settings related to the 'brain'
|
||||
from pydantic import BaseModel # For data validation and settings management
|
||||
@ -73,75 +71,6 @@ class BaseBrainPicking(BaseModel):
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
# the below methods define the names, arguments and return types for the most useful functions for the child classes. These should be overwritten if they are used.
|
||||
@abstractmethod
|
||||
def _create_llm(self, model, temperature=0, streaming=False, callbacks=None) -> LLM:
|
||||
"""
|
||||
Determine and construct the language model.
|
||||
:param model: Language model name to be used.
|
||||
:return: Language model instance
|
||||
|
||||
This method should take into account the following:
|
||||
- Whether the model is streaming compatible
|
||||
- Whether the model is private
|
||||
- Whether the model should use an openai api key and use the _determine_api_key method
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def _create_question_chain(self, model) -> LLMChain:
|
||||
"""
|
||||
Determine and construct the question chain.
|
||||
:param model: Language model name to be used.
|
||||
:return: Question chain instance
|
||||
|
||||
This method should take into account the following:
|
||||
- Which prompt to use (normally CONDENSE_QUESTION_PROMPT)
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def _create_doc_chain(self, model) -> LLMChain:
|
||||
"""
|
||||
Determine and construct the document chain.
|
||||
:param model Language model name to be used.
|
||||
:return: Document chain instance
|
||||
|
||||
This method should take into account the following:
|
||||
- chain_type (normally "stuff")
|
||||
- Whether the model is streaming compatible and/or streaming is set (determine_streaming).
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def _create_qa(
|
||||
self, question_chain, document_chain
|
||||
) -> ConversationalRetrievalChain:
|
||||
"""
|
||||
Constructs a conversational retrieval chain .
|
||||
:param question_chain
|
||||
:param document_chain
|
||||
:return: ConversationalRetrievalChain instance
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def _call_chain(self, chain, question, history) -> str:
|
||||
"""
|
||||
Call a chain with a given question and history.
|
||||
:param chain: The chain eg QA (ConversationalRetrievalChain)
|
||||
:param question: The user prompt
|
||||
:param history: The chat history from DB
|
||||
:return: The answer.
|
||||
"""
|
||||
|
||||
async def _acall_chain(self, chain, question, history) -> str:
|
||||
"""
|
||||
Call a chain with a given question and history.
|
||||
:param chain: The chain eg qa (ConversationalRetrievalChain)
|
||||
:param question: The user prompt
|
||||
:param history: The chat history from DB
|
||||
:return: The answer.
|
||||
"""
|
||||
raise NotImplementedError(
|
||||
"Async generation not implemented for this BrainPicking Class."
|
||||
)
|
||||
|
||||
@abstractmethod
|
||||
def generate_answer(self, question: str) -> str:
|
||||
"""
|
||||
@ -153,7 +82,7 @@ class BaseBrainPicking(BaseModel):
|
||||
It should also update the chat_history in the DB.
|
||||
"""
|
||||
|
||||
|
||||
@abstractmethod
|
||||
async def generate_stream(self, question: str) -> AsyncIterable:
|
||||
"""
|
||||
Generate a streaming answer to a given question using QA Chain.
|
||||
|
@ -1,6 +1,4 @@
|
||||
from langchain.chat_models import ChatOpenAI
|
||||
from langchain.embeddings.openai import OpenAIEmbeddings
|
||||
from langchain.llms.base import BaseLLM
|
||||
from llm.qa_base import QABaseBrainPicking
|
||||
from logger import get_logger
|
||||
|
||||
@ -46,19 +44,4 @@ class OpenAIBrainPicking(QABaseBrainPicking):
|
||||
openai_api_key=self.openai_api_key
|
||||
) # pyright: ignore reportPrivateUsage=none
|
||||
|
||||
def _create_llm(self, model, temperature=0, streaming=False, callbacks=None) -> BaseLLM:
|
||||
"""
|
||||
Determine the language model to be used.
|
||||
:param model: Language model name to be used.
|
||||
:param streaming: Whether to enable streaming of the model
|
||||
:param callbacks: Callbacks to be used for streaming
|
||||
:return: Language model instance
|
||||
"""
|
||||
return ChatOpenAI(
|
||||
temperature=temperature,
|
||||
model=model,
|
||||
streaming=streaming,
|
||||
verbose=True,
|
||||
callbacks=callbacks,
|
||||
openai_api_key=self.openai_api_key,
|
||||
) # pyright: ignore reportPrivateUsage=none
|
||||
|
||||
|
@ -1,17 +1,15 @@
|
||||
import asyncio
|
||||
import json
|
||||
from abc import abstractmethod, abstractproperty
|
||||
from typing import AsyncIterable, Awaitable
|
||||
from uuid import UUID
|
||||
|
||||
from langchain import PromptTemplate
|
||||
from langchain.callbacks.streaming_aiter import AsyncIteratorCallbackHandler
|
||||
from langchain.chains import ConversationalRetrievalChain, LLMChain
|
||||
from langchain.chains.question_answering import load_qa_chain
|
||||
from langchain.embeddings.openai import OpenAIEmbeddings
|
||||
from langchain.llms.base import BaseLLM
|
||||
from logger import get_logger
|
||||
from models.chat import ChatHistory
|
||||
from langchain.llms.base import BaseLLM
|
||||
from langchain.chat_models import ChatOpenAI
|
||||
from repository.brain.get_brain_by_id import get_brain_by_id
|
||||
from repository.chat.format_chat_history import format_chat_history
|
||||
from repository.chat.get_chat_history import get_chat_history
|
||||
@ -19,6 +17,11 @@ from repository.chat.update_chat_history import update_chat_history
|
||||
from repository.chat.update_message_by_id import update_message_by_id
|
||||
from repository.prompt.get_prompt_by_id import get_prompt_by_id
|
||||
from supabase.client import Client, create_client
|
||||
from langchain.prompts.chat import (
|
||||
ChatPromptTemplate,
|
||||
SystemMessagePromptTemplate,
|
||||
HumanMessagePromptTemplate
|
||||
)
|
||||
from vectorstore.supabase import CustomSupabaseVectorStore
|
||||
|
||||
from .base import BaseBrainPicking
|
||||
@ -29,9 +32,16 @@ logger = get_logger(__name__)
|
||||
|
||||
class QABaseBrainPicking(BaseBrainPicking):
|
||||
"""
|
||||
Base class for the Brain Picking functionality using the Conversational Retrieval Chain (QA) from Langchain.
|
||||
It is not designed to be used directly, but to be subclassed by other classes which use the QA chain.
|
||||
Main class for the Brain Picking functionality.
|
||||
It allows to initialize a Chat model, generate questions and retrieve answers using ConversationalRetrievalChain.
|
||||
It has two main methods: `generate_question` and `generate_stream`.
|
||||
One is for generating questions in a single request, the other is for generating questions in a streaming fashion.
|
||||
Both are the same, except that the streaming version streams the last message as a stream.
|
||||
Each have the same prompt template, which is defined in the `prompt_template` property.
|
||||
"""
|
||||
supabase_client: Client = None
|
||||
vector_store: CustomSupabaseVectorStore = None
|
||||
qa: ConversationalRetrievalChain = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@ -40,11 +50,7 @@ class QABaseBrainPicking(BaseBrainPicking):
|
||||
chat_id: str,
|
||||
streaming: bool = False,
|
||||
**kwargs,
|
||||
) -> "QABaseBrainPicking": # pyright: ignore reportPrivateUsage=none
|
||||
"""
|
||||
Initialize the QA BrainPicking class by setting embeddings, supabase client, vector store, language model and chains.
|
||||
:return: QABrainPicking instance
|
||||
"""
|
||||
) -> "QABaseBrainPicking":
|
||||
super().__init__(
|
||||
model=model,
|
||||
brain_id=brain_id,
|
||||
@ -52,19 +58,17 @@ class QABaseBrainPicking(BaseBrainPicking):
|
||||
streaming=streaming,
|
||||
**kwargs,
|
||||
)
|
||||
self.supabase_client = self._create_supabase_client()
|
||||
self.vector_store = self._create_vector_store()
|
||||
|
||||
@abstractproperty
|
||||
def embeddings(self) -> OpenAIEmbeddings:
|
||||
raise NotImplementedError("This property should be overridden in a subclass.")
|
||||
|
||||
@property
|
||||
def supabase_client(self) -> Client:
|
||||
|
||||
def _create_supabase_client(self) -> Client:
|
||||
return create_client(
|
||||
self.brain_settings.supabase_url, self.brain_settings.supabase_service_key
|
||||
)
|
||||
|
||||
@property
|
||||
def vector_store(self) -> CustomSupabaseVectorStore:
|
||||
def _create_vector_store(self) -> CustomSupabaseVectorStore:
|
||||
return CustomSupabaseVectorStore(
|
||||
self.supabase_client,
|
||||
self.embeddings,
|
||||
@ -72,53 +76,7 @@ class QABaseBrainPicking(BaseBrainPicking):
|
||||
brain_id=self.brain_id,
|
||||
)
|
||||
|
||||
@property
|
||||
def question_llm(self):
|
||||
return self._create_llm(model=self.model, streaming=False)
|
||||
|
||||
@property
|
||||
def doc_llm(self):
|
||||
return self._create_llm(
|
||||
model=self.model, streaming=True, callbacks=self.callbacks
|
||||
)
|
||||
|
||||
@property
|
||||
def question_generator(self) -> LLMChain:
|
||||
return LLMChain(
|
||||
llm=self.question_llm, prompt=CONDENSE_QUESTION_PROMPT, verbose=True
|
||||
)
|
||||
|
||||
@property
|
||||
def doc_chain(self) -> LLMChain:
|
||||
prompt_template = """Use the following pieces of context to answer the question at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer.
|
||||
|
||||
{context}
|
||||
|
||||
Question: {question}
|
||||
Here is instructions on how to answer the question: {brain_prompt}
|
||||
Answer:"""
|
||||
PROMPT = PromptTemplate(
|
||||
template=prompt_template,
|
||||
input_variables=["context", "question", "brain_prompt"],
|
||||
)
|
||||
|
||||
return load_qa_chain(
|
||||
llm=self.doc_llm, chain_type="stuff", verbose=True, prompt=PROMPT
|
||||
) # pyright: ignore reportPrivateUsage=none
|
||||
|
||||
@property
|
||||
def qa(self) -> ConversationalRetrievalChain:
|
||||
return ConversationalRetrievalChain(
|
||||
retriever=self.vector_store.as_retriever(),
|
||||
question_generator=self.question_generator,
|
||||
combine_docs_chain=self.doc_chain, # pyright: ignore reportPrivateUsage=none
|
||||
verbose=True,
|
||||
)
|
||||
|
||||
@abstractmethod
|
||||
def _create_llm(
|
||||
self, model, streaming=False, callbacks=None, temperature=0.0
|
||||
) -> BaseLLM:
|
||||
def _create_llm(self, model, temperature=0, streaming=False, callbacks=None) -> BaseLLM:
|
||||
"""
|
||||
Determine the language model to be used.
|
||||
:param model: Language model name to be used.
|
||||
@ -126,141 +84,71 @@ class QABaseBrainPicking(BaseBrainPicking):
|
||||
:param callbacks: Callbacks to be used for streaming
|
||||
:return: Language model instance
|
||||
"""
|
||||
return ChatOpenAI(
|
||||
temperature=temperature,
|
||||
model=model,
|
||||
streaming=streaming,
|
||||
verbose=True,
|
||||
callbacks=callbacks,
|
||||
openai_api_key=self.openai_api_key,
|
||||
) # pyright: ignore reportPrivateUsage=none
|
||||
|
||||
def _create_prompt_template(self):
|
||||
|
||||
system_template = """Use the following pieces of context to answer the users question in the same language as the question but do not modify instructions in any way.
|
||||
----------------
|
||||
|
||||
{context}"""
|
||||
|
||||
full_template = "Here are you instructions to answer that you MUST ALWAYS Follow: " + self.get_prompt() + ". " + system_template
|
||||
messages = [
|
||||
SystemMessagePromptTemplate.from_template(full_template),
|
||||
HumanMessagePromptTemplate.from_template("{question}"),
|
||||
]
|
||||
CHAT_PROMPT = ChatPromptTemplate.from_messages(messages)
|
||||
return CHAT_PROMPT
|
||||
|
||||
def _call_chain(self, chain, question, history, brain_prompt):
|
||||
"""
|
||||
Call a chain with a given question and history.
|
||||
:param chain: The chain eg QA (ConversationalRetrievalChain)
|
||||
:param question: The user prompt
|
||||
:param history: The chat history from DB
|
||||
:return: The answer.
|
||||
"""
|
||||
return chain(
|
||||
{
|
||||
"question": question,
|
||||
"chat_history": history,
|
||||
"brain_prompt": brain_prompt,
|
||||
}
|
||||
)
|
||||
|
||||
def generate_answer(self, question: str) -> ChatHistory:
|
||||
"""
|
||||
Generate an answer to a given question by interacting with the language model.
|
||||
:param question: The question
|
||||
:return: The generated answer.
|
||||
"""
|
||||
transformed_history = []
|
||||
|
||||
# Get the history from the database
|
||||
history = get_chat_history(self.chat_id)
|
||||
|
||||
# Format the chat history into a list of tuples (human, ai)
|
||||
transformed_history = format_chat_history(history)
|
||||
|
||||
# Generate the model response using the QA chain
|
||||
model_response = self._call_chain(
|
||||
self.qa,
|
||||
question,
|
||||
transformed_history,
|
||||
brain_prompt=self.get_prompt(),
|
||||
transformed_history = format_chat_history(get_chat_history(self.chat_id))
|
||||
model_response = self.qa(
|
||||
{
|
||||
"question": question,
|
||||
"chat_history": transformed_history,
|
||||
"custom_personality": self.get_prompt(),
|
||||
}
|
||||
)
|
||||
|
||||
answer = model_response["answer"]
|
||||
|
||||
# Update chat history
|
||||
chat_answer = update_chat_history(
|
||||
return update_chat_history(
|
||||
chat_id=self.chat_id,
|
||||
user_message=question,
|
||||
assistant=answer,
|
||||
)
|
||||
|
||||
return chat_answer
|
||||
|
||||
async def _acall_chain(self, chain, question, history):
|
||||
"""
|
||||
Call a chain with a given question and history.
|
||||
:param chain: The chain eg QA (ConversationalRetrievalChain)
|
||||
:param question: The user prompt
|
||||
:param history: The chat history from DB
|
||||
:return: The answer.
|
||||
"""
|
||||
return chain.acall(
|
||||
{
|
||||
"question": question,
|
||||
"chat_history": history,
|
||||
}
|
||||
)
|
||||
|
||||
def get_prompt(self) -> str:
|
||||
brain = get_brain_by_id(UUID(self.brain_id))
|
||||
brain_prompt = "Your name is Quivr. You're a helpful assistant."
|
||||
|
||||
if brain and brain.prompt_id:
|
||||
brain_prompt_object = get_prompt_by_id(brain.prompt_id)
|
||||
if brain_prompt_object:
|
||||
brain_prompt = brain_prompt_object.content
|
||||
|
||||
return brain_prompt
|
||||
|
||||
async def generate_stream(self, question: str) -> AsyncIterable:
|
||||
"""
|
||||
Generate a streaming answer to a given question by interacting with the language model.
|
||||
:param question: The question
|
||||
:return: An async iterable which generates the answer.
|
||||
"""
|
||||
history = get_chat_history(self.chat_id)
|
||||
callback = AsyncIteratorCallbackHandler()
|
||||
self.callbacks = [callback]
|
||||
|
||||
# The Model used to answer the question with the context
|
||||
answering_llm = self._create_llm(
|
||||
model=self.model,
|
||||
streaming=True,
|
||||
callbacks=self.callbacks,
|
||||
temperature=self.temperature,
|
||||
)
|
||||
|
||||
# The Model used to create the standalone Question
|
||||
# Temperature = 0 means no randomness
|
||||
standalone_question_llm = self._create_llm(model=self.model)
|
||||
|
||||
# The Chain that generates the standalone question
|
||||
standalone_question_generator = LLMChain(
|
||||
llm=standalone_question_llm, prompt=CONDENSE_QUESTION_PROMPT
|
||||
)
|
||||
|
||||
prompt_template = """Use the following pieces of context to answer the question at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer.
|
||||
|
||||
{context}
|
||||
|
||||
Question: {question}
|
||||
Here is instructions on how to answer the question: {brain_prompt}
|
||||
Answer:"""
|
||||
PROMPT = PromptTemplate(
|
||||
template=prompt_template,
|
||||
input_variables=["context", "question", "brain_prompt"],
|
||||
)
|
||||
answering_llm = self._create_llm(model=self.model,streaming=True, callbacks=self.callbacks)
|
||||
|
||||
# The Chain that generates the answer to the question
|
||||
doc_chain = load_qa_chain(answering_llm, chain_type="stuff", prompt=PROMPT)
|
||||
doc_chain = load_qa_chain(answering_llm, chain_type="stuff", prompt=self._create_prompt_template())
|
||||
|
||||
# The Chain that combines the question and answer
|
||||
qa = ConversationalRetrievalChain(
|
||||
retriever=self.vector_store.as_retriever(),
|
||||
combine_docs_chain=doc_chain,
|
||||
question_generator=standalone_question_generator,
|
||||
question_generator=LLMChain(
|
||||
llm=self._create_llm(model=self.model), prompt=CONDENSE_QUESTION_PROMPT
|
||||
),
|
||||
verbose=True,
|
||||
)
|
||||
|
||||
transformed_history = []
|
||||
|
||||
# Format the chat history into a list of tuples (human, ai)
|
||||
transformed_history = format_chat_history(history)
|
||||
|
||||
# Initialize a list to hold the tokens
|
||||
response_tokens = []
|
||||
|
||||
# Wrap an awaitable with a event to signal when it's done or an exception is raised.
|
||||
|
||||
async def wrap_done(fn: Awaitable, event: asyncio.Event):
|
||||
try:
|
||||
await fn
|
||||
@ -269,15 +157,13 @@ class QABaseBrainPicking(BaseBrainPicking):
|
||||
finally:
|
||||
event.set()
|
||||
|
||||
# Begin a task that runs in the background.
|
||||
|
||||
run = asyncio.create_task(
|
||||
wrap_done(
|
||||
qa.acall(
|
||||
{
|
||||
"question": question,
|
||||
"chat_history": transformed_history,
|
||||
"brain_prompt": self.get_prompt(),
|
||||
"custom_personality": self.get_prompt(),
|
||||
}
|
||||
),
|
||||
callback.done,
|
||||
@ -290,18 +176,13 @@ class QABaseBrainPicking(BaseBrainPicking):
|
||||
assistant="",
|
||||
)
|
||||
|
||||
# Use the aiter method of the callback to stream the response with server-sent-events
|
||||
async for token in callback.aiter(): # pyright: ignore reportPrivateUsage=none
|
||||
async for token in callback.aiter():
|
||||
logger.info("Token: %s", token)
|
||||
|
||||
# Add the token to the response_tokens list
|
||||
response_tokens.append(token)
|
||||
streamed_chat_history.assistant = token
|
||||
|
||||
yield f"data: {json.dumps(streamed_chat_history.to_dict())}"
|
||||
|
||||
await run
|
||||
# Join the tokens to create the assistant's response
|
||||
assistant = "".join(response_tokens)
|
||||
|
||||
update_message_by_id(
|
||||
@ -309,3 +190,14 @@ class QABaseBrainPicking(BaseBrainPicking):
|
||||
user_message=question,
|
||||
assistant=assistant,
|
||||
)
|
||||
|
||||
def get_prompt(self) -> str:
|
||||
brain = get_brain_by_id(UUID(self.brain_id))
|
||||
brain_prompt = "Your name is Quivr. You're a helpful assistant. If you don't know the answer, just say that you don't know, don't try to make up an answer."
|
||||
|
||||
if brain and brain.prompt_id:
|
||||
brain_prompt_object = get_prompt_by_id(brain.prompt_id)
|
||||
if brain_prompt_object:
|
||||
brain_prompt = brain_prompt_object.content
|
||||
|
||||
return brain_prompt
|
||||
|
Loading…
Reference in New Issue
Block a user