mirror of
https://github.com/QuivrHQ/quivr.git
synced 2024-12-14 17:03:29 +03:00
feat(citations): system added (#2498)
# Description Please include a summary of the changes and the related issue. Please also include relevant motivation and context. ## Checklist before requesting a review Please delete options that are not relevant. - [ ] My code follows the style guidelines of this project - [ ] I have performed a self-review of my code - [ ] I have commented hard-to-understand areas - [ ] I have ideally added tests that prove my fix is effective or that my feature works - [ ] New and existing unit tests pass locally with my changes - [ ] Any dependent changes have been merged ## Screenshots (if appropriate):
This commit is contained in:
parent
19365c4bb5
commit
b7ff2e77af
@ -1,4 +1,5 @@
|
||||
import json
|
||||
import logging
|
||||
from typing import AsyncIterable, List, Optional
|
||||
from uuid import UUID
|
||||
|
||||
@ -26,7 +27,7 @@ from modules.user.service.user_usage import UserUsage
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
from pydantic_settings import BaseSettings
|
||||
|
||||
logger = get_logger(__name__)
|
||||
logger = get_logger(__name__, log_level=logging.INFO)
|
||||
QUIVR_DEFAULT_PROMPT = "Your name is Quivr. You're a helpful assistant. If you don't know the answer, just say that you don't know, don't try to make up an answer."
|
||||
|
||||
|
||||
@ -43,7 +44,11 @@ def is_valid_uuid(uuid_to_test, version=4):
|
||||
return str(uuid_obj) == uuid_to_test
|
||||
|
||||
|
||||
def generate_source(source_documents, brain_id):
|
||||
def generate_source(source_documents, brain_id, citations: List[int] = None):
|
||||
"""
|
||||
Generate the sources list for the answer
|
||||
It takes in a list of sources documents and citations that points to the docs index that was used in the answer
|
||||
"""
|
||||
# Initialize an empty list for sources
|
||||
sources_list: List[Sources] = []
|
||||
|
||||
@ -51,26 +56,26 @@ def generate_source(source_documents, brain_id):
|
||||
generated_urls = {}
|
||||
|
||||
# remove duplicate sources with same name and create a list of unique sources
|
||||
source_documents = list(
|
||||
{v.metadata["file_name"]: v for v in source_documents}.values()
|
||||
)
|
||||
sources_url_cache = {}
|
||||
|
||||
# Get source documents from the result, default to an empty list if not found
|
||||
|
||||
# If source documents exist
|
||||
if source_documents:
|
||||
logger.info(f"Source documents found: {source_documents}")
|
||||
logger.info(f"Citations {citations}")
|
||||
# Iterate over each document
|
||||
for doc in source_documents:
|
||||
logger.info("Document: %s", doc)
|
||||
for doc, index in zip(source_documents, range(len(source_documents))):
|
||||
logger.info(f"Processing source document {doc.metadata['file_name']}")
|
||||
if citations is not None:
|
||||
if index not in citations:
|
||||
logger.info(f"Skipping source document {doc.metadata['file_name']}")
|
||||
continue
|
||||
# Check if 'url' is in the document metadata
|
||||
logger.info(f"Metadata 1: {doc.metadata}")
|
||||
is_url = (
|
||||
"original_file_name" in doc.metadata
|
||||
and doc.metadata["original_file_name"] is not None
|
||||
and doc.metadata["original_file_name"].startswith("http")
|
||||
)
|
||||
logger.info(f"Is URL: {is_url}")
|
||||
|
||||
# Determine the name based on whether it's a URL or a file
|
||||
name = (
|
||||
@ -91,11 +96,15 @@ def generate_source(source_documents, brain_id):
|
||||
if file_path in generated_urls:
|
||||
source_url = generated_urls[file_path]
|
||||
else:
|
||||
generated_url = generate_file_signed_url(file_path)
|
||||
if generated_url is not None:
|
||||
source_url = generated_url.get("signedURL", "")
|
||||
# Generate the URL
|
||||
if file_path in sources_url_cache:
|
||||
source_url = sources_url_cache[file_path]
|
||||
else:
|
||||
source_url = ""
|
||||
generated_url = generate_file_signed_url(file_path)
|
||||
if generated_url is not None:
|
||||
source_url = generated_url.get("signedURL", "")
|
||||
else:
|
||||
source_url = ""
|
||||
# Store the generated URL
|
||||
generated_urls[file_path] = source_url
|
||||
|
||||
@ -106,6 +115,7 @@ def generate_source(source_documents, brain_id):
|
||||
type=type_,
|
||||
source_url=source_url,
|
||||
original_file_name=name,
|
||||
citation=doc.page_content,
|
||||
)
|
||||
)
|
||||
else:
|
||||
@ -219,10 +229,6 @@ class KnowledgeBrainQA(BaseModel, QAInterface):
|
||||
|
||||
def calculate_pricing(self):
|
||||
|
||||
logger.info("Calculating pricing")
|
||||
logger.info(f"Model: {self.model}")
|
||||
logger.info(f"User settings: {self.user_settings}")
|
||||
logger.info(f"Models settings: {self.models_settings}")
|
||||
model_to_use = find_model_and_generate_metadata(
|
||||
self.chat_id,
|
||||
self.brain.model,
|
||||
@ -248,6 +254,8 @@ class KnowledgeBrainQA(BaseModel, QAInterface):
|
||||
self.initialize_streamed_chat_history(chat_id, question)
|
||||
)
|
||||
metadata = self.metadata or {}
|
||||
citations = None
|
||||
answer = ""
|
||||
model_response = conversational_qa_chain.invoke(
|
||||
{
|
||||
"question": question.question,
|
||||
@ -258,57 +266,23 @@ class KnowledgeBrainQA(BaseModel, QAInterface):
|
||||
}
|
||||
)
|
||||
|
||||
if self.model_compatible_with_function_calling(model=self.model):
|
||||
if model_response["answer"].tool_calls:
|
||||
citations = model_response["answer"].tool_calls[-1]["args"]["citations"]
|
||||
if citations:
|
||||
citations = citations
|
||||
answer = model_response["answer"].tool_calls[-1]["args"]["answer"]
|
||||
metadata["citations"] = citations
|
||||
else:
|
||||
answer = model_response["answer"].content
|
||||
sources = model_response["docs"] or []
|
||||
if len(sources) > 0:
|
||||
sources_list = generate_source(sources, self.brain_id)
|
||||
metadata["sources"] = sources_list
|
||||
sources_list = generate_source(sources, self.brain_id, citations=citations)
|
||||
serialized_sources_list = [source.dict() for source in sources_list]
|
||||
metadata["sources"] = serialized_sources_list
|
||||
|
||||
answer = model_response["answer"].content
|
||||
|
||||
if save_answer:
|
||||
# save the answer to the database or not -> add a variable
|
||||
new_chat = chat_service.update_chat_history(
|
||||
CreateChatHistory(
|
||||
**{
|
||||
"chat_id": chat_id,
|
||||
"user_message": question.question,
|
||||
"assistant": answer,
|
||||
"brain_id": self.brain.brain_id,
|
||||
"prompt_id": self.prompt_to_use_id,
|
||||
}
|
||||
)
|
||||
)
|
||||
|
||||
return GetChatHistoryOutput(
|
||||
**{
|
||||
"chat_id": chat_id,
|
||||
"user_message": question.question,
|
||||
"assistant": answer,
|
||||
"message_time": new_chat.message_time,
|
||||
"prompt_title": (
|
||||
self.prompt_to_use.title if self.prompt_to_use else None
|
||||
),
|
||||
"brain_name": self.brain.name if self.brain else None,
|
||||
"message_id": new_chat.message_id,
|
||||
"brain_id": str(self.brain.brain_id) if self.brain else None,
|
||||
"metadata": metadata,
|
||||
}
|
||||
)
|
||||
|
||||
return GetChatHistoryOutput(
|
||||
**{
|
||||
"chat_id": chat_id,
|
||||
"user_message": question.question,
|
||||
"assistant": answer,
|
||||
"message_time": None,
|
||||
"prompt_title": (
|
||||
self.prompt_to_use.title if self.prompt_to_use else None
|
||||
),
|
||||
"brain_name": None,
|
||||
"message_id": None,
|
||||
"brain_id": str(self.brain.brain_id) if self.brain else None,
|
||||
"metadata": metadata,
|
||||
}
|
||||
return self.save_non_streaming_answer(
|
||||
chat_id=chat_id, question=question, answer=answer, metadata=metadata
|
||||
)
|
||||
|
||||
async def generate_stream(
|
||||
@ -318,9 +292,10 @@ class KnowledgeBrainQA(BaseModel, QAInterface):
|
||||
transformed_history, streamed_chat_history = (
|
||||
self.initialize_streamed_chat_history(chat_id, question)
|
||||
)
|
||||
response_tokens = []
|
||||
response_tokens = ""
|
||||
sources = []
|
||||
|
||||
citations = []
|
||||
first = True
|
||||
async for chunk in conversational_qa_chain.astream(
|
||||
{
|
||||
"question": question.question,
|
||||
@ -330,18 +305,47 @@ class KnowledgeBrainQA(BaseModel, QAInterface):
|
||||
),
|
||||
}
|
||||
):
|
||||
if chunk.get("answer"):
|
||||
logger.info(f"Chunk: {chunk}")
|
||||
response_tokens.append(chunk["answer"].content)
|
||||
streamed_chat_history.assistant = chunk["answer"].content
|
||||
yield f"data: {json.dumps(streamed_chat_history.dict())}"
|
||||
if not streamed_chat_history.metadata:
|
||||
streamed_chat_history.metadata = {}
|
||||
if self.model_compatible_with_function_calling(model=self.model):
|
||||
if chunk.get("answer"):
|
||||
if first:
|
||||
gathered = chunk["answer"]
|
||||
first = False
|
||||
else:
|
||||
gathered = gathered + chunk["answer"]
|
||||
if (
|
||||
gathered.tool_calls
|
||||
and gathered.tool_calls[-1].get("args")
|
||||
and "answer" in gathered.tool_calls[-1]["args"]
|
||||
):
|
||||
# Only send the difference between answer and response_tokens which was the previous answer
|
||||
answer = gathered.tool_calls[-1]["args"]["answer"]
|
||||
difference = answer[len(response_tokens) :]
|
||||
streamed_chat_history.assistant = difference
|
||||
response_tokens = answer
|
||||
|
||||
yield f"data: {json.dumps(streamed_chat_history.dict())}"
|
||||
if (
|
||||
gathered.tool_calls
|
||||
and gathered.tool_calls[-1].get("args")
|
||||
and "citations" in gathered.tool_calls[-1]["args"]
|
||||
):
|
||||
citations = gathered.tool_calls[-1]["args"]["citations"]
|
||||
else:
|
||||
if chunk.get("answer"):
|
||||
response_tokens += chunk["answer"].content
|
||||
streamed_chat_history.assistant = chunk["answer"].content
|
||||
yield f"data: {json.dumps(streamed_chat_history.dict())}"
|
||||
|
||||
if chunk.get("docs"):
|
||||
sources = chunk["docs"]
|
||||
|
||||
sources_list = generate_source(sources, self.brain_id)
|
||||
if not streamed_chat_history.metadata:
|
||||
streamed_chat_history.metadata = {}
|
||||
# Serialize the sources list
|
||||
sources_list = generate_source(sources, self.brain_id, citations)
|
||||
|
||||
streamed_chat_history.metadata["citations"] = citations
|
||||
|
||||
# Serialize the sources list
|
||||
serialized_sources_list = [source.dict() for source in sources_list]
|
||||
streamed_chat_history.metadata["sources"] = serialized_sources_list
|
||||
yield f"data: {json.dumps(streamed_chat_history.dict())}"
|
||||
@ -398,7 +402,7 @@ class KnowledgeBrainQA(BaseModel, QAInterface):
|
||||
except Exception as e:
|
||||
logger.error("Error updating message by ID: %s", e)
|
||||
|
||||
def save_non_streaming_answer(self, chat_id, question, answer):
|
||||
def save_non_streaming_answer(self, chat_id, question, answer, metadata):
|
||||
new_chat = chat_service.update_chat_history(
|
||||
CreateChatHistory(
|
||||
**{
|
||||
@ -407,6 +411,7 @@ class KnowledgeBrainQA(BaseModel, QAInterface):
|
||||
"assistant": answer,
|
||||
"brain_id": self.brain.brain_id,
|
||||
"prompt_id": self.prompt_to_use_id,
|
||||
"metadata": metadata,
|
||||
}
|
||||
)
|
||||
)
|
||||
@ -423,5 +428,6 @@ class KnowledgeBrainQA(BaseModel, QAInterface):
|
||||
"brain_name": self.brain.name if self.brain else None,
|
||||
"message_id": new_chat.message_id,
|
||||
"brain_id": str(self.brain.brain_id) if self.brain else None,
|
||||
"metadata": metadata,
|
||||
}
|
||||
)
|
||||
|
@ -39,3 +39,20 @@ class QAInterface(ABC):
|
||||
raise NotImplementedError(
|
||||
"generate_stream is an abstract method and must be implemented"
|
||||
)
|
||||
|
||||
def model_compatible_with_function_calling(self, model: str) -> bool:
    """Tell whether ``model`` is one of the models listed as function-calling capable.

    Args:
        model: The model identifier to look up (for example ``"gpt-4"``).

    Returns:
        ``True`` when ``model`` exactly matches one of the identifiers
        below, ``False`` otherwise (including for ``None`` or an empty
        string).
    """
    # A tuple (not a set) keeps the comparison purely equality-based, so an
    # unhashable or unexpected value simply yields False instead of raising.
    return model in (
        "gpt-4-turbo",
        "gpt-4-turbo-2024-04-09",
        "gpt-4-turbo-preview",
        "gpt-4-0125-preview",
        "gpt-4-1106-preview",
        "gpt-4",
        "gpt-4-0613",
        "gpt-3.5-turbo",
        "gpt-3.5-turbo-0125",
        "gpt-3.5-turbo-1106",
        "gpt-3.5-turbo-0613",
    )
|
||||
|
@ -1,6 +1,7 @@
|
||||
import logging
|
||||
import os
|
||||
from operator import itemgetter
|
||||
from typing import Optional
|
||||
from typing import List, Optional
|
||||
from uuid import UUID
|
||||
|
||||
from langchain.chains import ConversationalRetrievalChain
|
||||
@ -14,8 +15,10 @@ from langchain_cohere import CohereRerank
|
||||
from langchain_community.chat_models import ChatLiteLLM
|
||||
from langchain_core.output_parsers import StrOutputParser
|
||||
from langchain_core.prompts import ChatPromptTemplate, PromptTemplate
|
||||
from langchain_core.pydantic_v1 import BaseModel as BaseModelV1
|
||||
from langchain_core.pydantic_v1 import Field as FieldV1
|
||||
from langchain_core.runnables import RunnableLambda, RunnablePassthrough
|
||||
from langchain_openai import OpenAIEmbeddings
|
||||
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
|
||||
from logger import get_logger
|
||||
from models import BrainSettings # Importing settings related to the 'brain'
|
||||
from modules.brain.service.brain_service import BrainService
|
||||
@ -26,7 +29,20 @@ from pydantic_settings import BaseSettings
|
||||
from supabase.client import Client, create_client
|
||||
from vectorstore.supabase import CustomSupabaseVectorStore
|
||||
|
||||
logger = get_logger(__name__)
|
||||
logger = get_logger(__name__, log_level=logging.INFO)
|
||||
|
||||
|
||||
# Schema for a structured answer: the answer text plus the integer IDs of
# the sources that justify it. This file binds the class as a tool via
# ``bind_tools([cited_answer], tool_choice="cited_answer")``, so the class
# name must keep matching that ``tool_choice`` string.
# NOTE(review): the docstring and field descriptions below are presumably
# forwarded to the model as the tool description — confirm before rewording.
class cited_answer(BaseModelV1):
    """Answer the user question based only on the given sources, and cite the sources used."""

    # Required field (``...`` means no default): the answer text.
    answer: str = FieldV1(
        ...,
        description="The answer to the user question, which is based only on the given sources.",
    )
    # Required field: integer source IDs.
    # NOTE(review): presumably these are the per-document ``index`` values
    # written into the document metadata elsewhere in this file — confirm.
    citations: List[int] = FieldV1(
        ...,
        description="The integer IDs of the SPECIFIC sources which justify the answer.",
    )
|
||||
|
||||
|
||||
# First step is to create the Rephrasing Prompt
|
||||
@ -66,7 +82,9 @@ ANSWER_PROMPT = ChatPromptTemplate.from_messages(
|
||||
|
||||
# How we format documents
|
||||
|
||||
DEFAULT_DOCUMENT_PROMPT = PromptTemplate.from_template(template="{page_content}")
|
||||
DEFAULT_DOCUMENT_PROMPT = PromptTemplate.from_template(
|
||||
template="Source: {index} \n {page_content}"
|
||||
)
|
||||
|
||||
|
||||
def is_valid_uuid(uuid_to_test, version=4):
|
||||
@ -116,6 +134,23 @@ class QuivrRAG(BaseModel):
|
||||
else:
|
||||
return None
|
||||
|
||||
def model_compatible_with_function_calling(self) -> bool:
    """Tell whether ``self.model`` is one of the models listed as function-calling capable.

    Returns:
        ``True`` when ``self.model`` exactly matches one of the identifiers
        below, ``False`` otherwise.
    """
    # A tuple (not a set) keeps the comparison purely equality-based, so an
    # unset or unexpected ``self.model`` value yields False instead of raising.
    return self.model in (
        "gpt-4-turbo",
        "gpt-4-turbo-2024-04-09",
        "gpt-4-turbo-preview",
        "gpt-4-0125-preview",
        "gpt-4-1106-preview",
        "gpt-4",
        "gpt-4-0613",
        "gpt-3.5-turbo",
        "gpt-3.5-turbo-0125",
        "gpt-3.5-turbo-1106",
        "gpt-3.5-turbo-0613",
    )
|
||||
|
||||
supabase_client: Optional[Client] = None
|
||||
vector_store: Optional[CustomSupabaseVectorStore] = None
|
||||
qa: Optional[ConversationalRetrievalChain] = None
|
||||
@ -197,6 +232,9 @@ class QuivrRAG(BaseModel):
|
||||
def _combine_documents(
    self, docs, document_prompt=DEFAULT_DOCUMENT_PROMPT, document_separator="\n\n"
):
    """Number the documents and join their formatted text into one string.

    Each document's ``metadata["index"]`` is set, in place, to its position
    in ``docs`` (starting at 0) so that the sources can be cited by index.

    Args:
        docs: The documents to combine; their ``metadata`` dict is mutated.
        document_prompt: Prompt passed to ``format_document`` for each document.
        document_separator: String placed between the formatted documents.

    Returns:
        The formatted documents joined by ``document_separator``.
    """
    # Stamp the position on every document before formatting so the index is
    # already present in the metadata when each document is rendered.
    for index, doc in enumerate(docs):
        doc.metadata["index"] = index
    doc_strings = [format_document(doc, document_prompt) for doc in docs]
    return document_separator.join(doc_strings)
|
||||
|
||||
@ -287,14 +325,27 @@ class QuivrRAG(BaseModel):
|
||||
"question": itemgetter("question"),
|
||||
"custom_instructions": itemgetter("custom_instructions"),
|
||||
}
|
||||
llm = ChatLiteLLM(
|
||||
max_tokens=self.max_tokens,
|
||||
model=self.model,
|
||||
temperature=self.temperature,
|
||||
api_base=api_base,
|
||||
)
|
||||
if self.model_compatible_with_function_calling():
|
||||
|
||||
# And finally, we do the part that returns the answers
|
||||
llm_function = ChatOpenAI(
|
||||
max_tokens=self.max_tokens,
|
||||
model=self.model,
|
||||
temperature=self.temperature,
|
||||
)
|
||||
llm = llm_function.bind_tools(
|
||||
[cited_answer],
|
||||
tool_choice="cited_answer",
|
||||
)
|
||||
|
||||
# And finally, we do the part that returns the answers
|
||||
answer = {
|
||||
"answer": final_inputs
|
||||
| ANSWER_PROMPT
|
||||
| ChatLiteLLM(
|
||||
max_tokens=self.max_tokens, model=self.model, api_base=api_base
|
||||
),
|
||||
"answer": final_inputs | ANSWER_PROMPT | llm,
|
||||
"docs": itemgetter("docs"),
|
||||
}
|
||||
|
||||
|
@ -4,7 +4,7 @@ from uuid import UUID
|
||||
|
||||
from modules.chat.dto.outputs import GetChatHistoryOutput
|
||||
from modules.notification.entity.notification import Notification
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class ChatMessage(BaseModel):
|
||||
@ -33,6 +33,7 @@ class Sources(BaseModel):
|
||||
source_url: str
|
||||
type: str
|
||||
original_file_name: str
|
||||
citation: str
|
||||
|
||||
|
||||
class ChatItemType(Enum):
|
||||
|
@ -11,6 +11,7 @@ class CreateChatHistory(BaseModel):
|
||||
assistant: str
|
||||
prompt_id: Optional[UUID] = None
|
||||
brain_id: Optional[UUID] = None
|
||||
metadata: Optional[dict] = {}
|
||||
|
||||
|
||||
class QuestionAndAnswer(BaseModel):
|
||||
|
@ -74,6 +74,7 @@ class Chats(ChatsInterface):
|
||||
"brain_id": (
|
||||
str(chat_history.brain_id) if chat_history.brain_id else None
|
||||
),
|
||||
"metadata": chat_history.metadata if chat_history.metadata else {},
|
||||
}
|
||||
)
|
||||
.execute()
|
||||
@ -104,7 +105,9 @@ class Chats(ChatsInterface):
|
||||
def delete_chat_history(self, chat_id):
    """Delete the ``chat_history`` rows whose ``chat_id`` matches ``chat_id``."""
    history_table = self.db.table("chat_history")
    deletion = history_table.delete().match({"chat_id": chat_id})
    # The result of the request is intentionally not returned.
    deletion.execute()
|
||||
|
||||
def update_chat_message(self, chat_id, message_id, chat_message_properties: ChatMessageProperties ):
|
||||
def update_chat_message(
|
||||
self, chat_id, message_id, chat_message_properties: ChatMessageProperties
|
||||
):
|
||||
response = (
|
||||
self.db.table("chat_history")
|
||||
.update(chat_message_properties)
|
||||
|
Loading…
Reference in New Issue
Block a user