feat(qa): improve code (#886)

* feat(qa): improve code

* feat: 🎸 customprompt

now in system
Stan Girard 2023-08-07 19:53:04 +02:00 committed by GitHub
parent fe9280bddc
commit 7028505571
3 changed files with 75 additions and 271 deletions

View File

@ -2,8 +2,6 @@ from abc import abstractmethod
from typing import AsyncIterable, List
from langchain.callbacks.streaming_aiter import AsyncIteratorCallbackHandler
from langchain.chains import ConversationalRetrievalChain, LLMChain
from langchain.llms.base import LLM
from logger import get_logger
from models.settings import BrainSettings # Importing settings related to the 'brain'
from pydantic import BaseModel # For data validation and settings management
@ -73,75 +71,6 @@ class BaseBrainPicking(BaseModel):
arbitrary_types_allowed = True
# The methods below define the names, arguments, and return types of the most useful functions for the child classes. These should be overridden if they are used.
@abstractmethod
def _create_llm(self, model, temperature=0, streaming=False, callbacks=None) -> LLM:
"""
Determine and construct the language model.
:param model: Language model name to be used.
:return: Language model instance
This method should take into account the following:
- Whether the model is streaming compatible
- Whether the model is private
- Whether the model should use an openai api key and use the _determine_api_key method
"""
@abstractmethod
def _create_question_chain(self, model) -> LLMChain:
"""
Determine and construct the question chain.
:param model: Language model name to be used.
:return: Question chain instance
This method should take into account the following:
- Which prompt to use (normally CONDENSE_QUESTION_PROMPT)
"""
@abstractmethod
def _create_doc_chain(self, model) -> LLMChain:
"""
Determine and construct the document chain.
:param model: Language model name to be used.
:return: Document chain instance
This method should take into account the following:
- chain_type (normally "stuff")
- Whether the model is streaming compatible and/or streaming is set (determine_streaming).
"""
@abstractmethod
def _create_qa(
self, question_chain, document_chain
) -> ConversationalRetrievalChain:
"""
Constructs a conversational retrieval chain.
:param question_chain
:param document_chain
:return: ConversationalRetrievalChain instance
"""
@abstractmethod
def _call_chain(self, chain, question, history) -> str:
"""
Call a chain with a given question and history.
:param chain: The chain, e.g. QA (ConversationalRetrievalChain)
:param question: The user prompt
:param history: The chat history from DB
:return: The answer.
"""
async def _acall_chain(self, chain, question, history) -> str:
"""
Call a chain with a given question and history.
:param chain: The chain, e.g. qa (ConversationalRetrievalChain)
:param question: The user prompt
:param history: The chat history from DB
:return: The answer.
"""
raise NotImplementedError(
"Async generation not implemented for this BrainPicking Class."
)
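For orientation, here is a minimal sketch of how a subclass might implement these hooks with LangChain, mirroring the concrete code that appears later in this diff; the class name is invented, the CONDENSE_QUESTION_PROMPT import path is the usual LangChain location, and the vector_store attribute is assumed to be provided by the subclass (as QABaseBrainPicking does below).
from langchain.chains import ConversationalRetrievalChain, LLMChain
from langchain.chains.conversational_retrieval.prompts import CONDENSE_QUESTION_PROMPT
from langchain.chains.question_answering import load_qa_chain
class ExampleBrainPicking(BaseBrainPicking):
    # Only the chain-building hooks are sketched here; _create_llm and the
    # generate_* methods would still need to be provided by the subclass.
    def _create_question_chain(self, model) -> LLMChain:
        # Non-streaming LLM that condenses the follow-up into a standalone question.
        return LLMChain(llm=self._create_llm(model=model), prompt=CONDENSE_QUESTION_PROMPT)
    def _create_doc_chain(self, model) -> LLMChain:
        # "stuff" chain: retrieved documents are stuffed into a single prompt.
        llm = self._create_llm(model=model, streaming=self.streaming, callbacks=self.callbacks)
        return load_qa_chain(llm=llm, chain_type="stuff")
    def _create_qa(self, question_chain, document_chain) -> ConversationalRetrievalChain:
        # vector_store is assumed to be set up by the subclass, as in QABaseBrainPicking.
        return ConversationalRetrievalChain(
            retriever=self.vector_store.as_retriever(),
            question_generator=question_chain,
            combine_docs_chain=document_chain,
        )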
@abstractmethod
def generate_answer(self, question: str) -> str:
"""
@ -153,7 +82,7 @@ class BaseBrainPicking(BaseModel):
It should also update the chat_history in the DB.
"""
@abstractmethod
async def generate_stream(self, question: str) -> AsyncIterable:
"""
Generate a streaming answer to a given question using QA Chain.

View File

@ -1,6 +1,4 @@
from langchain.chat_models import ChatOpenAI
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.llms.base import BaseLLM
from llm.qa_base import QABaseBrainPicking
from logger import get_logger
@ -46,19 +44,4 @@ class OpenAIBrainPicking(QABaseBrainPicking):
openai_api_key=self.openai_api_key
) # pyright: ignore reportPrivateUsage=none
def _create_llm(self, model, temperature=0, streaming=False, callbacks=None) -> BaseLLM:
"""
Determine the language model to be used.
:param model: Language model name to be used.
:param streaming: Whether to enable streaming of the model
:param callbacks: Callbacks to be used for streaming
:return: Language model instance
"""
return ChatOpenAI(
temperature=temperature,
model=model,
streaming=streaming,
verbose=True,
callbacks=callbacks,
openai_api_key=self.openai_api_key,
) # pyright: ignore reportPrivateUsage=none
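A rough usage sketch of OpenAIBrainPicking follows; the model name and IDs are placeholders, the constructor signature is taken from QABaseBrainPicking below, and any extra keyword arguments (such as a user-supplied OpenAI key) are omitted.
brain = OpenAIBrainPicking(
    model="gpt-3.5-turbo",                            # placeholder model name
    brain_id="00000000-0000-0000-0000-000000000000",  # placeholder brain UUID
    chat_id="11111111-1111-1111-1111-111111111111",   # placeholder chat UUID
    streaming=False,
)
chat_entry = brain.generate_answer("What do my documents say about onboarding?")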

View File

@ -1,17 +1,15 @@
import asyncio
import json
from abc import abstractmethod, abstractproperty
from typing import AsyncIterable, Awaitable
from uuid import UUID
from langchain import PromptTemplate
from langchain.callbacks.streaming_aiter import AsyncIteratorCallbackHandler
from langchain.chains import ConversationalRetrievalChain, LLMChain
from langchain.chains.question_answering import load_qa_chain
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.llms.base import BaseLLM
from logger import get_logger
from models.chat import ChatHistory
from langchain.llms.base import BaseLLM
from langchain.chat_models import ChatOpenAI
from repository.brain.get_brain_by_id import get_brain_by_id
from repository.chat.format_chat_history import format_chat_history
from repository.chat.get_chat_history import get_chat_history
@ -19,6 +17,11 @@ from repository.chat.update_chat_history import update_chat_history
from repository.chat.update_message_by_id import update_message_by_id
from repository.prompt.get_prompt_by_id import get_prompt_by_id
from supabase.client import Client, create_client
from langchain.prompts.chat import (
ChatPromptTemplate,
SystemMessagePromptTemplate,
HumanMessagePromptTemplate
)
from vectorstore.supabase import CustomSupabaseVectorStore
from .base import BaseBrainPicking
@ -29,9 +32,16 @@ logger = get_logger(__name__)
class QABaseBrainPicking(BaseBrainPicking):
"""
Base class for the Brain Picking functionality using the Conversational Retrieval Chain (QA) from Langchain.
It is not designed to be used directly, but to be subclassed by other classes which use the QA chain.
Main class for the Brain Picking functionality.
It initializes a chat model and uses ConversationalRetrievalChain to generate standalone questions and retrieve answers.
It has two main methods: `generate_answer` and `generate_stream`.
One returns the answer in a single response; the other streams it back.
Both follow the same flow, except that the streaming version emits the final answer as a stream.
Both use the same prompt template, built by `_create_prompt_template`.
"""
supabase_client: Client = None
vector_store: CustomSupabaseVectorStore = None
qa: ConversationalRetrievalChain = None
def __init__(
self,
@ -40,11 +50,7 @@ class QABaseBrainPicking(BaseBrainPicking):
chat_id: str,
streaming: bool = False,
**kwargs,
) -> "QABaseBrainPicking": # pyright: ignore reportPrivateUsage=none
"""
Initialize the QA BrainPicking class by setting embeddings, supabase client, vector store, language model and chains.
:return: QABrainPicking instance
"""
) -> "QABaseBrainPicking":
super().__init__(
model=model,
brain_id=brain_id,
@ -52,19 +58,17 @@ class QABaseBrainPicking(BaseBrainPicking):
streaming=streaming,
**kwargs,
)
self.supabase_client = self._create_supabase_client()
self.vector_store = self._create_vector_store()
@abstractproperty
def embeddings(self) -> OpenAIEmbeddings:
raise NotImplementedError("This property should be overridden in a subclass.")
@property
def supabase_client(self) -> Client:
def _create_supabase_client(self) -> Client:
return create_client(
self.brain_settings.supabase_url, self.brain_settings.supabase_service_key
)
@property
def vector_store(self) -> CustomSupabaseVectorStore:
def _create_vector_store(self) -> CustomSupabaseVectorStore:
return CustomSupabaseVectorStore(
self.supabase_client,
self.embeddings,
@ -72,53 +76,7 @@ class QABaseBrainPicking(BaseBrainPicking):
brain_id=self.brain_id,
)
@property
def question_llm(self):
return self._create_llm(model=self.model, streaming=False)
@property
def doc_llm(self):
return self._create_llm(
model=self.model, streaming=True, callbacks=self.callbacks
)
@property
def question_generator(self) -> LLMChain:
return LLMChain(
llm=self.question_llm, prompt=CONDENSE_QUESTION_PROMPT, verbose=True
)
@property
def doc_chain(self) -> LLMChain:
prompt_template = """Use the following pieces of context to answer the question at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer.
{context}
Question: {question}
Here are instructions on how to answer the question: {brain_prompt}
Answer:"""
PROMPT = PromptTemplate(
template=prompt_template,
input_variables=["context", "question", "brain_prompt"],
)
return load_qa_chain(
llm=self.doc_llm, chain_type="stuff", verbose=True, prompt=PROMPT
) # pyright: ignore reportPrivateUsage=none
@property
def qa(self) -> ConversationalRetrievalChain:
return ConversationalRetrievalChain(
retriever=self.vector_store.as_retriever(),
question_generator=self.question_generator,
combine_docs_chain=self.doc_chain, # pyright: ignore reportPrivateUsage=none
verbose=True,
)
@abstractmethod
def _create_llm(
self, model, streaming=False, callbacks=None, temperature=0.0
) -> BaseLLM:
def _create_llm(self, model, temperature=0, streaming=False, callbacks=None) -> BaseLLM:
"""
Determine the language model to be used.
:param model: Language model name to be used.
@ -126,141 +84,71 @@ class QABaseBrainPicking(BaseBrainPicking):
:param callbacks: Callbacks to be used for streaming
:return: Language model instance
"""
return ChatOpenAI(
temperature=temperature,
model=model,
streaming=streaming,
verbose=True,
callbacks=callbacks,
openai_api_key=self.openai_api_key,
) # pyright: ignore reportPrivateUsage=none
def _create_prompt_template(self):
system_template = """Use the following pieces of context to answer the users question in the same language as the question but do not modify instructions in any way.
----------------
{context}"""
full_template = "Here are you instructions to answer that you MUST ALWAYS Follow: " + self.get_prompt() + ". " + system_template
messages = [
SystemMessagePromptTemplate.from_template(full_template),
HumanMessagePromptTemplate.from_template("{question}"),
]
CHAT_PROMPT = ChatPromptTemplate.from_messages(messages)
return CHAT_PROMPT
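To make the resulting prompt concrete, a small illustration of what the template returned above produces when formatted; the context and question strings are invented.
# Inside QABaseBrainPicking, illustration only:
chat_prompt = self._create_prompt_template()
formatted = chat_prompt.format_messages(
    context="<retrieved document chunks>",
    question="What is Quivr?",
)
# formatted[0] is a SystemMessage holding the brain's instructions plus the injected context,
# formatted[1] is a HumanMessage holding only the question.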
def _call_chain(self, chain, question, history, brain_prompt):
"""
Call a chain with a given question and history.
:param chain: The chain, e.g. QA (ConversationalRetrievalChain)
:param question: The user prompt
:param history: The chat history from DB
:return: The answer.
"""
return chain(
{
"question": question,
"chat_history": history,
"brain_prompt": brain_prompt,
}
)
def generate_answer(self, question: str) -> ChatHistory:
"""
Generate an answer to a given question by interacting with the language model.
:param question: The question
:return: The generated answer.
"""
transformed_history = []
# Get the history from the database
history = get_chat_history(self.chat_id)
# Format the chat history into a list of tuples (human, ai)
transformed_history = format_chat_history(history)
# Generate the model response using the QA chain
model_response = self._call_chain(
self.qa,
question,
transformed_history,
brain_prompt=self.get_prompt(),
transformed_history = format_chat_history(get_chat_history(self.chat_id))
model_response = self.qa(
{
"question": question,
"chat_history": transformed_history,
"custom_personality": self.get_prompt(),
}
)
answer = model_response["answer"]
# Update chat history
chat_answer = update_chat_history(
return update_chat_history(
chat_id=self.chat_id,
user_message=question,
assistant=answer,
)
return chat_answer
async def _acall_chain(self, chain, question, history):
"""
Call a chain with a given question and history.
:param chain: The chain, e.g. QA (ConversationalRetrievalChain)
:param question: The user prompt
:param history: The chat history from DB
:return: The answer.
"""
return chain.acall(
{
"question": question,
"chat_history": history,
}
)
def get_prompt(self) -> str:
brain = get_brain_by_id(UUID(self.brain_id))
brain_prompt = "Your name is Quivr. You're a helpful assistant."
if brain and brain.prompt_id:
brain_prompt_object = get_prompt_by_id(brain.prompt_id)
if brain_prompt_object:
brain_prompt = brain_prompt_object.content
return brain_prompt
async def generate_stream(self, question: str) -> AsyncIterable:
"""
Generate a streaming answer to a given question by interacting with the language model.
:param question: The question
:return: An async iterable which generates the answer.
"""
history = get_chat_history(self.chat_id)
callback = AsyncIteratorCallbackHandler()
self.callbacks = [callback]
# The Model used to answer the question with the context
answering_llm = self._create_llm(
model=self.model,
streaming=True,
callbacks=self.callbacks,
temperature=self.temperature,
)
# The Model used to create the standalone Question
# Temperature = 0 means no randomness
standalone_question_llm = self._create_llm(model=self.model)
# The Chain that generates the standalone question
standalone_question_generator = LLMChain(
llm=standalone_question_llm, prompt=CONDENSE_QUESTION_PROMPT
)
prompt_template = """Use the following pieces of context to answer the question at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer.
{context}
Question: {question}
Here are instructions on how to answer the question: {brain_prompt}
Answer:"""
PROMPT = PromptTemplate(
template=prompt_template,
input_variables=["context", "question", "brain_prompt"],
)
answering_llm = self._create_llm(model=self.model, streaming=True, callbacks=self.callbacks)
# The Chain that generates the answer to the question
doc_chain = load_qa_chain(answering_llm, chain_type="stuff", prompt=PROMPT)
doc_chain = load_qa_chain(answering_llm, chain_type="stuff", prompt=self._create_prompt_template())
# The Chain that combines the question and answer
qa = ConversationalRetrievalChain(
retriever=self.vector_store.as_retriever(),
combine_docs_chain=doc_chain,
question_generator=standalone_question_generator,
question_generator=LLMChain(
llm=self._create_llm(model=self.model), prompt=CONDENSE_QUESTION_PROMPT
),
verbose=True,
)
transformed_history = []
# Format the chat history into a list of tuples (human, ai)
transformed_history = format_chat_history(history)
# Initialize a list to hold the tokens
response_tokens = []
# Wrap an awaitable with an event to signal when it's done or an exception is raised.
async def wrap_done(fn: Awaitable, event: asyncio.Event):
try:
await fn
@ -269,15 +157,13 @@ class QABaseBrainPicking(BaseBrainPicking):
finally:
event.set()
# Begin a task that runs in the background.
run = asyncio.create_task(
wrap_done(
qa.acall(
{
"question": question,
"chat_history": transformed_history,
"brain_prompt": self.get_prompt(),
"custom_personality": self.get_prompt(),
}
),
callback.done,
@ -290,18 +176,13 @@ class QABaseBrainPicking(BaseBrainPicking):
assistant="",
)
# Use the aiter method of the callback to stream the response with server-sent events
async for token in callback.aiter(): # pyright: ignore reportPrivateUsage=none
async for token in callback.aiter():
logger.info("Token: %s", token)
# Add the token to the response_tokens list
response_tokens.append(token)
streamed_chat_history.assistant = token
yield f"data: {json.dumps(streamed_chat_history.to_dict())}"
await run
# Join the tokens to create the assistant's response
assistant = "".join(response_tokens)
update_message_by_id(
@ -309,3 +190,14 @@ class QABaseBrainPicking(BaseBrainPicking):
user_message=question,
assistant=assistant,
)
def get_prompt(self) -> str:
brain = get_brain_by_id(UUID(self.brain_id))
brain_prompt = "Your name is Quivr. You're a helpful assistant. If you don't know the answer, just say that you don't know, don't try to make up an answer."
if brain and brain.prompt_id:
brain_prompt_object = get_prompt_by_id(brain.prompt_id)
if brain_prompt_object:
brain_prompt = brain_prompt_object.content
return brain_prompt
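Finally, a hedged sketch of how generate_stream might be consumed by an HTTP layer; the FastAPI route shown here is an assumption for illustration and is not part of this diff.
from fastapi import APIRouter
from fastapi.responses import StreamingResponse
from llm.openai import OpenAIBrainPicking  # import path assumed
router = APIRouter()
@router.post("/chat/{chat_id}/question/stream")
async def stream_question(chat_id: str, brain_id: str, question: str):
    # Hypothetical wiring: build the brain picking instance and proxy its
    # "data: ..." chunks straight through as server-sent events.
    brain = OpenAIBrainPicking(
        model="gpt-3.5-turbo",  # placeholder model name
        brain_id=brain_id,
        chat_id=chat_id,
        streaming=True,
    )
    return StreamingResponse(
        brain.generate_stream(question),
        media_type="text/event-stream",
    )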