quivr/backend/modules/brain/knowledge_brain_qa.py

434 lines
17 KiB
Python
Raw Normal View History

import json
import logging
from typing import AsyncIterable, List, Optional
from uuid import UUID
from langchain.callbacks.streaming_aiter import AsyncIteratorCallbackHandler
from logger import get_logger
from models import BrainSettings
from modules.brain.entity.brain_entity import BrainEntity
from modules.brain.qa_interface import QAInterface
from modules.brain.rags.quivr_rag import QuivrRAG
from modules.brain.rags.rag_interface import RAGInterface
from modules.brain.service.brain_service import BrainService
from modules.brain.service.utils.format_chat_history import format_chat_history
from modules.brain.service.utils.get_prompt_to_use_id import get_prompt_to_use_id
from modules.chat.controller.chat.utils import (
find_model_and_generate_metadata,
update_user_usage,
)
from modules.chat.dto.chats import ChatQuestion, Sources
from modules.chat.dto.inputs import CreateChatHistory
from modules.chat.dto.outputs import GetChatHistoryOutput
from modules.chat.service.chat_service import ChatService
from modules.prompt.service.get_prompt_to_use import get_prompt_to_use
from modules.upload.service.generate_file_signed_url import generate_file_signed_url
from modules.user.service.user_usage import UserUsage
from pydantic import BaseModel, ConfigDict
from pydantic_settings import BaseSettings
logger = get_logger(__name__, log_level=logging.INFO)
QUIVR_DEFAULT_PROMPT = "Your name is Quivr. You're a helpful assistant. If you don't know the answer, just say that you don't know, don't try to make up an answer."
brain_service = BrainService()
chat_service = ChatService()
def is_valid_uuid(uuid_to_test, version=4):
try:
uuid_obj = UUID(uuid_to_test, version=version)
except ValueError:
return False
return str(uuid_obj) == uuid_to_test
def generate_source(source_documents, brain_id, citations: List[int] = None):
"""
Generate the sources list for the answer
It takes in a list of sources documents and citations that points to the docs index that was used in the answer
"""
# Initialize an empty list for sources
sources_list: List[Sources] = []
# Initialize a dictionary for storing generated URLs
generated_urls = {}
# remove duplicate sources with same name and create a list of unique sources
sources_url_cache = {}
# Get source documents from the result, default to an empty list if not found
# If source documents exist
if source_documents:
logger.info(f"Citations {citations}")
# Iterate over each document
for doc, index in zip(source_documents, range(len(source_documents))):
logger.info(f"Processing source document {doc.metadata['file_name']}")
if citations is not None:
if index not in citations:
logger.info(f"Skipping source document {doc.metadata['file_name']}")
continue
# Check if 'url' is in the document metadata
is_url = (
"original_file_name" in doc.metadata
and doc.metadata["original_file_name"] is not None
and doc.metadata["original_file_name"].startswith("http")
)
# Determine the name based on whether it's a URL or a file
name = (
doc.metadata["original_file_name"]
if is_url
else doc.metadata["file_name"]
)
# Determine the type based on whether it's a URL or a file
type_ = "url" if is_url else "file"
# Determine the source URL based on whether it's a URL or a file
if is_url:
source_url = doc.metadata["original_file_name"]
else:
file_path = f"{brain_id}/{doc.metadata['file_name']}"
# Check if the URL has already been generated
if file_path in generated_urls:
source_url = generated_urls[file_path]
else:
# Generate the URL
if file_path in sources_url_cache:
source_url = sources_url_cache[file_path]
else:
generated_url = generate_file_signed_url(file_path)
if generated_url is not None:
source_url = generated_url.get("signedURL", "")
else:
source_url = ""
# Store the generated URL
generated_urls[file_path] = source_url
# Append a new Sources object to the list
sources_list.append(
Sources(
name=name,
type=type_,
source_url=source_url,
original_file_name=name,
citation=doc.page_content,
)
)
else:
logger.info("No source documents found or source_documents is not a list.")
return sources_list
class KnowledgeBrainQA(BaseModel, QAInterface):
"""
Main class for the Brain Picking functionality.
It allows to initialize a Chat model, generate questions and retrieve answers using ConversationalRetrievalChain.
It has two main methods: `generate_question` and `generate_stream`.
One is for generating questions in a single request, the other is for generating questions in a streaming fashion.
Both are the same, except that the streaming version streams the last message as a stream.
Each have the same prompt template, which is defined in the `prompt_template` property.
"""
model_config = ConfigDict(arbitrary_types_allowed=True)
# Instantiate settings
brain_settings: BaseSettings = BrainSettings()
# Default class attributes
model: str = "gpt-3.5-turbo-0125" # pyright: ignore reportPrivateUsage=none
temperature: float = 0.1
chat_id: str = None # pyright: ignore reportPrivateUsage=none
brain_id: str = None # pyright: ignore reportPrivateUsage=none
max_tokens: int = 2000
max_input: int = 2000
streaming: bool = False
knowledge_qa: Optional[RAGInterface] = None
brain: Optional[BrainEntity] = None
user_id: str = None
user_email: str = None
user_usage: Optional[UserUsage] = None
user_settings: Optional[dict] = None
models_settings: Optional[List[dict]] = None
metadata: Optional[dict] = None
callbacks: List[AsyncIteratorCallbackHandler] = (
None # pyright: ignore reportPrivateUsage=none
)
prompt_id: Optional[UUID] = None
def __init__(
self,
brain_id: str,
chat_id: str,
streaming: bool = False,
prompt_id: Optional[UUID] = None,
metadata: Optional[dict] = None,
user_id: str = None,
user_email: str = None,
cost: int = 100,
**kwargs,
):
super().__init__(
brain_id=brain_id,
chat_id=chat_id,
streaming=streaming,
**kwargs,
)
self.prompt_id = prompt_id
self.user_id = user_id
self.user_email = user_email
self.user_usage = UserUsage(
id=user_id,
email=user_email,
)
self.brain = brain_service.get_brain_by_id(brain_id)
self.user_settings = self.user_usage.get_user_settings()
# Get Model settings for the user
self.models_settings = self.user_usage.get_model_settings()
self.increase_usage_user()
self.knowledge_qa = QuivrRAG(
model=self.brain.model if self.brain.model else self.model,
brain_id=brain_id,
chat_id=chat_id,
streaming=streaming,
max_input=self.max_input,
max_tokens=self.max_tokens,
**kwargs,
)
@property
def prompt_to_use(self):
if self.brain_id and is_valid_uuid(self.brain_id):
return get_prompt_to_use(UUID(self.brain_id), self.prompt_id)
else:
return None
@property
def prompt_to_use_id(self) -> Optional[UUID]:
# TODO: move to prompt service or instruction or something
if self.brain_id and is_valid_uuid(self.brain_id):
return get_prompt_to_use_id(UUID(self.brain_id), self.prompt_id)
else:
return None
def increase_usage_user(self):
# Raises an error if the user has consumed all of of his credits
update_user_usage(
usage=self.user_usage,
user_settings=self.user_settings,
cost=self.calculate_pricing(),
)
def calculate_pricing(self):
model_to_use = find_model_and_generate_metadata(
self.chat_id,
self.brain.model,
self.user_settings,
self.models_settings,
)
self.model = model_to_use.name
self.max_input = model_to_use.max_input
self.max_tokens = model_to_use.max_output
user_choosen_model_price = 1000
for model_setting in self.models_settings:
if model_setting["name"] == self.model:
user_choosen_model_price = model_setting["price"]
return user_choosen_model_price
def generate_answer(
self, chat_id: UUID, question: ChatQuestion, save_answer: bool = True
) -> GetChatHistoryOutput:
conversational_qa_chain = self.knowledge_qa.get_chain()
transformed_history, streamed_chat_history = (
self.initialize_streamed_chat_history(chat_id, question)
)
metadata = self.metadata or {}
citations = None
answer = ""
model_response = conversational_qa_chain.invoke(
{
2023-08-10 19:35:30 +03:00
"question": question.question,
"chat_history": transformed_history,
"custom_personality": (
self.prompt_to_use.content if self.prompt_to_use else None
),
}
)
if self.model_compatible_with_function_calling(model=self.model):
if model_response["answer"].tool_calls:
citations = model_response["answer"].tool_calls[-1]["args"]["citations"]
if citations:
citations = citations
answer = model_response["answer"].tool_calls[-1]["args"]["answer"]
metadata["citations"] = citations
else:
answer = model_response["answer"].content
sources = model_response["docs"] or []
if len(sources) > 0:
sources_list = generate_source(sources, self.brain_id, citations=citations)
serialized_sources_list = [source.dict() for source in sources_list]
metadata["sources"] = serialized_sources_list
return self.save_non_streaming_answer(
chat_id=chat_id, question=question, answer=answer, metadata=metadata
)
async def generate_stream(
self, chat_id: UUID, question: ChatQuestion, save_answer: bool = True
) -> AsyncIterable:
conversational_qa_chain = self.knowledge_qa.get_chain()
transformed_history, streamed_chat_history = (
self.initialize_streamed_chat_history(chat_id, question)
)
response_tokens = ""
sources = []
citations = []
first = True
async for chunk in conversational_qa_chain.astream(
{
"question": question.question,
"chat_history": transformed_history,
"custom_personality": (
self.prompt_to_use.content if self.prompt_to_use else None
),
}
):
if not streamed_chat_history.metadata:
streamed_chat_history.metadata = {}
if self.model_compatible_with_function_calling(model=self.model):
if chunk.get("answer"):
if first:
gathered = chunk["answer"]
first = False
else:
gathered = gathered + chunk["answer"]
if (
gathered.tool_calls
and gathered.tool_calls[-1].get("args")
and "answer" in gathered.tool_calls[-1]["args"]
):
# Only send the difference between answer and response_tokens which was the previous answer
answer = gathered.tool_calls[-1]["args"]["answer"]
difference = answer[len(response_tokens) :]
streamed_chat_history.assistant = difference
response_tokens = answer
yield f"data: {json.dumps(streamed_chat_history.dict())}"
if (
gathered.tool_calls
and gathered.tool_calls[-1].get("args")
and "citations" in gathered.tool_calls[-1]["args"]
):
citations = gathered.tool_calls[-1]["args"]["citations"]
else:
if chunk.get("answer"):
response_tokens += chunk["answer"].content
streamed_chat_history.assistant = chunk["answer"].content
yield f"data: {json.dumps(streamed_chat_history.dict())}"
if chunk.get("docs"):
sources = chunk["docs"]
sources_list = generate_source(sources, self.brain_id, citations)
streamed_chat_history.metadata["citations"] = citations
# Serialize the sources list
serialized_sources_list = [source.dict() for source in sources_list]
streamed_chat_history.metadata["sources"] = serialized_sources_list
yield f"data: {json.dumps(streamed_chat_history.dict())}"
self.save_answer(question, response_tokens, streamed_chat_history, save_answer)
def initialize_streamed_chat_history(self, chat_id, question):
history = chat_service.get_chat_history(self.chat_id)
transformed_history = format_chat_history(history)
brain = brain_service.get_brain_by_id(self.brain_id)
streamed_chat_history = chat_service.update_chat_history(
CreateChatHistory(
**{
"chat_id": chat_id,
"user_message": question.question,
"assistant": "",
"brain_id": brain.brain_id,
"prompt_id": self.prompt_to_use_id,
}
)
)
streamed_chat_history = GetChatHistoryOutput(
**{
"chat_id": str(chat_id),
"message_id": streamed_chat_history.message_id,
"message_time": streamed_chat_history.message_time,
"user_message": question.question,
"assistant": "",
"prompt_title": (
self.prompt_to_use.title if self.prompt_to_use else None
),
"brain_name": brain.name if brain else None,
"brain_id": str(brain.brain_id) if brain else None,
"metadata": self.metadata,
}
)
2024-01-29 02:37:14 +03:00
return transformed_history, streamed_chat_history
def save_answer(
self, question, response_tokens, streamed_chat_history, save_answer
):
assistant = "".join(response_tokens)
try:
if save_answer:
chat_service.update_message_by_id(
message_id=str(streamed_chat_history.message_id),
user_message=question.question,
assistant=assistant,
metadata=streamed_chat_history.metadata,
)
except Exception as e:
logger.error("Error updating message by ID: %s", e)
def save_non_streaming_answer(self, chat_id, question, answer, metadata):
new_chat = chat_service.update_chat_history(
CreateChatHistory(
**{
"chat_id": chat_id,
"user_message": question.question,
"assistant": answer,
"brain_id": self.brain.brain_id,
"prompt_id": self.prompt_to_use_id,
"metadata": metadata,
}
)
)
return GetChatHistoryOutput(
**{
"chat_id": chat_id,
"user_message": question.question,
"assistant": answer,
"message_time": new_chat.message_time,
"prompt_title": (
self.prompt_to_use.title if self.prompt_to_use else None
),
"brain_name": self.brain.name if self.brain else None,
"message_id": new_chat.message_id,
"brain_id": str(self.brain.brain_id) if self.brain else None,
"metadata": metadata,
}
)