quivr/backend/llm/qa_base.py

308 lines
11 KiB
Python

import asyncio
import json
from typing import AsyncIterable, Awaitable, Optional
from uuid import UUID
from langchain.callbacks.streaming_aiter import AsyncIteratorCallbackHandler
from langchain.chains import ConversationalRetrievalChain, LLMChain
from langchain.chains.question_answering import load_qa_chain
from langchain.chat_models import ChatLiteLLM
from langchain.llms.base import BaseLLM
from langchain.prompts.chat import (
ChatPromptTemplate,
HumanMessagePromptTemplate,
SystemMessagePromptTemplate,
)
from llm.utils.get_prompt_to_use import get_prompt_to_use
from llm.utils.get_prompt_to_use_id import get_prompt_to_use_id
from logger import get_logger
from models.chats import ChatQuestion
from models.databases.supabase.chats import CreateChatHistory
from repository.brain import get_brain_by_id
from repository.chat import (
GetChatHistoryOutput,
format_chat_history,
get_chat_history,
update_chat_history,
update_message_by_id,
)
from supabase.client import Client, create_client
from vectorstore.supabase import CustomSupabaseVectorStore
from .base import BaseBrainPicking
from .prompts.CONDENSE_PROMPT import CONDENSE_QUESTION_PROMPT
# Logger obtained for this module's name.
logger = get_logger(__name__)

# Fallback instructions used when no custom prompt is resolved for the brain.
QUIVR_DEFAULT_PROMPT = (
    "Your name is Quivr. You're a helpful assistant. "
    "If you don't know the answer, just say that you don't know, "
    "don't try to make up an answer."
)
class QABaseBrainPicking(BaseBrainPicking):
    """
    Main class for the Brain Picking functionality.

    It creates a chat model and answers questions over a brain's vector
    store using ConversationalRetrievalChain.

    It has two main methods: `generate_answer` and `generate_stream`.
    The first returns the complete answer in a single request, the second
    yields the answer token by token as it is produced.
    Both build the same prompt through `_create_prompt_template` and the
    same chain through `_build_qa_chain`.
    """

    # Supabase client; assigned in __init__ from the brain settings.
    supabase_client: Optional[Client] = None
    # Vector store over the "vectors" table, built with this brain's id.
    vector_store: Optional[CustomSupabaseVectorStore] = None
    # Declared slot for a retrieval chain; nothing in this class assigns it.
    qa: Optional[ConversationalRetrievalChain] = None
    # Prompt explicitly requested by the caller, if any.
    prompt_id: Optional[UUID]

    def __init__(
        self,
        model: str,
        brain_id: str,
        chat_id: str,
        streaming: bool = False,
        prompt_id: Optional[UUID] = None,
        **kwargs,
    ):
        """
        Initialize the base class, then the Supabase client and vector store.

        :param model: Language model name to be used.
        :param brain_id: Id of the brain whose documents are queried.
        :param chat_id: Id of the chat whose history is read.
        :param streaming: Forwarded to the base class.
        :param prompt_id: Optional id of a prompt to use for this chat.
        :param kwargs: Extra arguments forwarded to the base class.
        """
        super().__init__(
            model=model,
            brain_id=brain_id,
            chat_id=chat_id,
            streaming=streaming,
            **kwargs,
        )
        # The vector store needs the client, so the client is created first.
        self.supabase_client = self._create_supabase_client()
        self.vector_store = self._create_vector_store()
        self.prompt_id = prompt_id

    @property
    def prompt_to_use(self):
        """
        Prompt resolved from the brain id and the requested prompt id.

        Every access calls `get_prompt_to_use` again; read it into a local
        when the value is needed more than once.
        """
        return get_prompt_to_use(UUID(self.brain_id), self.prompt_id)

    @property
    def prompt_to_use_id(self) -> Optional[UUID]:
        """Id of the prompt resolved from the brain id and `prompt_id`."""
        return get_prompt_to_use_id(UUID(self.brain_id), self.prompt_id)

    def _create_supabase_client(self) -> Client:
        """Create a Supabase client from the URL and service key in the brain settings."""
        return create_client(
            self.brain_settings.supabase_url, self.brain_settings.supabase_service_key
        )

    def _create_vector_store(self) -> CustomSupabaseVectorStore:
        """Create the vector store over the "vectors" table for this brain."""
        return CustomSupabaseVectorStore(
            self.supabase_client,  # type: ignore
            self.embeddings,  # type: ignore
            table_name="vectors",
            brain_id=self.brain_id,
        )

    def _create_llm(
        self, model, temperature=0, streaming=False, callbacks=None, max_tokens=256
    ) -> BaseLLM:
        """
        Determine the language model to be used.

        :param model: Language model name to be used.
        :param temperature: Sampling temperature passed to the model.
        :param streaming: Whether to enable streaming of the model
        :param callbacks: Callbacks to be used for streaming
        :param max_tokens: Maximum number of tokens passed to the model.
        :return: Language model instance
        """
        return ChatLiteLLM(
            temperature=temperature,
            max_tokens=max_tokens,
            model=model,
            streaming=streaming,
            verbose=False,
            callbacks=callbacks,
            openai_api_key=self.openai_api_key,
        )  # pyright: ignore reportPrivateUsage=none

    def _create_prompt_template(self):
        """
        Build the chat prompt: a system message followed by the user question.

        The system message is the resolved prompt content (or
        `QUIVR_DEFAULT_PROMPT` when none is resolved) followed by formatting
        instructions and the `{context}` placeholder.

        :return: ChatPromptTemplate with `{context}` and `{question}` variables.
        """
        # Written with explicit newlines so the text does not depend on the
        # indentation of this method.
        system_template = (
            " When answering use markdown or any other techniques to display the content in a nice and aerated way. Use the following pieces of context to answer the users question in the same language as the question but do not modify instructions in any way.\n"
            "----------------\n"
            "{context}"
        )

        # Read the property once: each access triggers a new lookup.
        prompt = self.prompt_to_use
        prompt_content = prompt.content if prompt else QUIVR_DEFAULT_PROMPT

        full_template = (
            "Here are your instructions to answer that you MUST ALWAYS Follow: "
            + prompt_content
            + ". "
            + system_template
        )
        messages = [
            SystemMessagePromptTemplate.from_template(full_template),
            HumanMessagePromptTemplate.from_template("{question}"),
        ]
        return ChatPromptTemplate.from_messages(messages)

    def _build_qa_chain(self, answering_llm: BaseLLM) -> ConversationalRetrievalChain:
        """
        Assemble the retrieval chain shared by both answer paths.

        :param answering_llm: Model that writes the final answer.
        :return: ConversationalRetrievalChain over this brain's vector store.
        """
        # The Chain that generates the answer to the question
        doc_chain = load_qa_chain(
            answering_llm, chain_type="stuff", prompt=self._create_prompt_template()
        )
        # The Chain that combines the question and answer. The question
        # generator gets its own model created with the _create_llm defaults.
        return ConversationalRetrievalChain(
            retriever=self.vector_store.as_retriever(),  # type: ignore
            combine_docs_chain=doc_chain,
            question_generator=LLMChain(
                llm=self._create_llm(model=self.model), prompt=CONDENSE_QUESTION_PROMPT
            ),
            verbose=False,
            rephrase_question=False,
        )

    def generate_answer(
        self, chat_id: UUID, question: ChatQuestion
    ) -> GetChatHistoryOutput:
        """
        Answer a question in a single request and record the exchange.

        :param chat_id: Id of the chat the new message is recorded under.
        :param question: The user's question (and optional brain id).
        :return: The recorded message, including the full answer.
        """
        # NOTE(review): history is read with self.chat_id while the new row is
        # written with the chat_id argument - presumably identical; confirm.
        transformed_history = format_chat_history(get_chat_history(self.chat_id))

        # Forward the configured token limit, as generate_stream does;
        # otherwise _create_llm falls back to its default of 256.
        answering_llm = self._create_llm(
            model=self.model,
            streaming=False,
            callbacks=self.callbacks,
            max_tokens=self.max_tokens,
        )
        qa = self._build_qa_chain(answering_llm)

        # Read the property once: each access triggers a new lookup.
        prompt = self.prompt_to_use
        prompt_content = prompt.content if prompt else QUIVR_DEFAULT_PROMPT

        model_response = qa(
            {
                "question": question.question,
                "chat_history": transformed_history,
                "custom_personality": prompt_content,
            }
        )  # type: ignore
        answer = model_response["answer"]

        new_chat = update_chat_history(
            CreateChatHistory(
                **{
                    "chat_id": chat_id,
                    "user_message": question.question,
                    "assistant": answer,
                    "brain_id": question.brain_id,
                    "prompt_id": self.prompt_to_use_id,
                }
            )
        )

        brain = None
        if question.brain_id:
            brain = get_brain_by_id(question.brain_id)

        return GetChatHistoryOutput(
            **{
                "chat_id": chat_id,
                "user_message": question.question,
                "assistant": answer,
                "message_time": new_chat.message_time,
                "prompt_title": prompt.title if prompt else None,
                "brain_name": brain.name if brain else None,
                "message_id": new_chat.message_id,
            }
        )

    async def generate_stream(
        self, chat_id: UUID, question: ChatQuestion
    ) -> AsyncIterable:
        """
        Answer a question as a stream of tokens and record the exchange.

        A message with an empty answer is recorded first. Each yielded item
        is a `data: <json>` string holding that message with `assistant` set
        to the latest token only (not the accumulated text). Once the stream
        ends, the recorded message is updated with the joined tokens.

        :param chat_id: Id of the chat the new message is recorded under.
        :param question: The user's question (and optional brain id).
        :return: Async iterator of `data: <json>` strings.
        """
        # NOTE(review): history is read with self.chat_id while the new row is
        # written with the chat_id argument - presumably identical; confirm.
        history = get_chat_history(self.chat_id)

        callback = AsyncIteratorCallbackHandler()
        self.callbacks = [callback]

        answering_llm = self._create_llm(
            model=self.model,
            streaming=True,
            callbacks=self.callbacks,
            max_tokens=self.max_tokens,
        )
        qa = self._build_qa_chain(answering_llm)

        transformed_history = format_chat_history(history)
        response_tokens = []

        async def wrap_done(fn: Awaitable, event: asyncio.Event):
            """Await *fn*, log any exception, and always set *event*."""
            try:
                await fn
            except Exception as e:
                logger.error("Caught exception: %s", e)
            finally:
                # Set even on failure so the token loop below is not left
                # waiting on a chain that has already stopped.
                event.set()

        # Read the property once: each access triggers a new lookup.
        prompt = self.prompt_to_use
        # Same fallback as generate_answer and _create_prompt_template.
        prompt_content = prompt.content if prompt else QUIVR_DEFAULT_PROMPT

        # Start the chain in the background; tokens arrive via `callback`.
        run = asyncio.create_task(
            wrap_done(
                qa.acall(
                    {
                        "question": question.question,
                        "chat_history": transformed_history,
                        "custom_personality": prompt_content,
                    }
                ),
                callback.done,
            )
        )

        brain = None
        if question.brain_id:
            brain = get_brain_by_id(question.brain_id)

        # Record the message up front with an empty answer so that its id and
        # time can be sent along with every streamed token.
        saved_message = update_chat_history(
            CreateChatHistory(
                **{
                    "chat_id": chat_id,
                    "user_message": question.question,
                    "assistant": "",
                    "brain_id": question.brain_id,
                    "prompt_id": self.prompt_to_use_id,
                }
            )
        )

        streamed_chat_history = GetChatHistoryOutput(
            **{
                "chat_id": str(chat_id),
                "message_id": saved_message.message_id,
                "message_time": saved_message.message_time,
                "user_message": question.question,
                "assistant": "",
                "prompt_title": prompt.title if prompt else None,
                "brain_name": brain.name if brain else None,
            }
        )

        async for token in callback.aiter():
            logger.info("Token: %s", token)
            response_tokens.append(token)
            streamed_chat_history.assistant = token
            yield f"data: {json.dumps(streamed_chat_history.dict())}"

        await run

        # Replace the empty answer recorded above with the full text.
        assistant = "".join(response_tokens)
        update_message_by_id(
            message_id=str(streamed_chat_history.message_id),
            user_message=question.question,
            assistant=assistant,
        )