feat(citations): system added (#2498)

# Description

Please include a summary of the changes and the related issue. Please
also include relevant motivation and context.

## Checklist before requesting a review

Please delete options that are not relevant.

- [ ] My code follows the style guidelines of this project
- [ ] I have performed a self-review of my code
- [ ] I have commented hard-to-understand areas
- [ ] I have ideally added tests that prove my fix is effective or that
my feature works
- [ ] New and existing unit tests pass locally with my changes
- [ ] Any dependent changes have been merged

## Screenshots (if appropriate):
This commit is contained in:
Stan Girard 2024-04-26 08:11:01 -07:00 committed by GitHub
parent 19365c4bb5
commit b7ff2e77af
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 169 additions and 90 deletions

View File

@ -1,4 +1,5 @@
import json
import logging
from typing import AsyncIterable, List, Optional
from uuid import UUID
@ -26,7 +27,7 @@ from modules.user.service.user_usage import UserUsage
from pydantic import BaseModel, ConfigDict
from pydantic_settings import BaseSettings
logger = get_logger(__name__)
logger = get_logger(__name__, log_level=logging.INFO)
QUIVR_DEFAULT_PROMPT = "Your name is Quivr. You're a helpful assistant. If you don't know the answer, just say that you don't know, don't try to make up an answer."
@ -43,7 +44,11 @@ def is_valid_uuid(uuid_to_test, version=4):
return str(uuid_obj) == uuid_to_test
def generate_source(source_documents, brain_id):
def generate_source(source_documents, brain_id, citations: List[int] = None):
"""
Generate the sources list for the answer
It takes in a list of sources documents and citations that points to the docs index that was used in the answer
"""
# Initialize an empty list for sources
sources_list: List[Sources] = []
@ -51,26 +56,26 @@ def generate_source(source_documents, brain_id):
generated_urls = {}
# remove duplicate sources with same name and create a list of unique sources
source_documents = list(
{v.metadata["file_name"]: v for v in source_documents}.values()
)
sources_url_cache = {}
# Get source documents from the result, default to an empty list if not found
# If source documents exist
if source_documents:
logger.info(f"Source documents found: {source_documents}")
logger.info(f"Citations {citations}")
# Iterate over each document
for doc in source_documents:
logger.info("Document: %s", doc)
for doc, index in zip(source_documents, range(len(source_documents))):
logger.info(f"Processing source document {doc.metadata['file_name']}")
if citations is not None:
if index not in citations:
logger.info(f"Skipping source document {doc.metadata['file_name']}")
continue
# Check if 'url' is in the document metadata
logger.info(f"Metadata 1: {doc.metadata}")
is_url = (
"original_file_name" in doc.metadata
and doc.metadata["original_file_name"] is not None
and doc.metadata["original_file_name"].startswith("http")
)
logger.info(f"Is URL: {is_url}")
# Determine the name based on whether it's a URL or a file
name = (
@ -90,6 +95,10 @@ def generate_source(source_documents, brain_id):
# Check if the URL has already been generated
if file_path in generated_urls:
source_url = generated_urls[file_path]
else:
# Generate the URL
if file_path in sources_url_cache:
source_url = sources_url_cache[file_path]
else:
generated_url = generate_file_signed_url(file_path)
if generated_url is not None:
@ -106,6 +115,7 @@ def generate_source(source_documents, brain_id):
type=type_,
source_url=source_url,
original_file_name=name,
citation=doc.page_content,
)
)
else:
@ -219,10 +229,6 @@ class KnowledgeBrainQA(BaseModel, QAInterface):
def calculate_pricing(self):
logger.info("Calculating pricing")
logger.info(f"Model: {self.model}")
logger.info(f"User settings: {self.user_settings}")
logger.info(f"Models settings: {self.models_settings}")
model_to_use = find_model_and_generate_metadata(
self.chat_id,
self.brain.model,
@ -248,6 +254,8 @@ class KnowledgeBrainQA(BaseModel, QAInterface):
self.initialize_streamed_chat_history(chat_id, question)
)
metadata = self.metadata or {}
citations = None
answer = ""
model_response = conversational_qa_chain.invoke(
{
"question": question.question,
@ -258,57 +266,23 @@ class KnowledgeBrainQA(BaseModel, QAInterface):
}
)
if self.model_compatible_with_function_calling(model=self.model):
if model_response["answer"].tool_calls:
citations = model_response["answer"].tool_calls[-1]["args"]["citations"]
if citations:
citations = citations
answer = model_response["answer"].tool_calls[-1]["args"]["answer"]
metadata["citations"] = citations
else:
answer = model_response["answer"].content
sources = model_response["docs"] or []
if len(sources) > 0:
sources_list = generate_source(sources, self.brain_id)
metadata["sources"] = sources_list
sources_list = generate_source(sources, self.brain_id, citations=citations)
serialized_sources_list = [source.dict() for source in sources_list]
metadata["sources"] = serialized_sources_list
answer = model_response["answer"].content
if save_answer:
# save the answer to the database or not -> add a variable
new_chat = chat_service.update_chat_history(
CreateChatHistory(
**{
"chat_id": chat_id,
"user_message": question.question,
"assistant": answer,
"brain_id": self.brain.brain_id,
"prompt_id": self.prompt_to_use_id,
}
)
)
return GetChatHistoryOutput(
**{
"chat_id": chat_id,
"user_message": question.question,
"assistant": answer,
"message_time": new_chat.message_time,
"prompt_title": (
self.prompt_to_use.title if self.prompt_to_use else None
),
"brain_name": self.brain.name if self.brain else None,
"message_id": new_chat.message_id,
"brain_id": str(self.brain.brain_id) if self.brain else None,
"metadata": metadata,
}
)
return GetChatHistoryOutput(
**{
"chat_id": chat_id,
"user_message": question.question,
"assistant": answer,
"message_time": None,
"prompt_title": (
self.prompt_to_use.title if self.prompt_to_use else None
),
"brain_name": None,
"message_id": None,
"brain_id": str(self.brain.brain_id) if self.brain else None,
"metadata": metadata,
}
return self.save_non_streaming_answer(
chat_id=chat_id, question=question, answer=answer, metadata=metadata
)
async def generate_stream(
@ -318,9 +292,10 @@ class KnowledgeBrainQA(BaseModel, QAInterface):
transformed_history, streamed_chat_history = (
self.initialize_streamed_chat_history(chat_id, question)
)
response_tokens = []
response_tokens = ""
sources = []
citations = []
first = True
async for chunk in conversational_qa_chain.astream(
{
"question": question.question,
@ -330,17 +305,46 @@ class KnowledgeBrainQA(BaseModel, QAInterface):
),
}
):
if not streamed_chat_history.metadata:
streamed_chat_history.metadata = {}
if self.model_compatible_with_function_calling(model=self.model):
if chunk.get("answer"):
logger.info(f"Chunk: {chunk}")
response_tokens.append(chunk["answer"].content)
if first:
gathered = chunk["answer"]
first = False
else:
gathered = gathered + chunk["answer"]
if (
gathered.tool_calls
and gathered.tool_calls[-1].get("args")
and "answer" in gathered.tool_calls[-1]["args"]
):
# Only send the difference between answer and response_tokens which was the previous answer
answer = gathered.tool_calls[-1]["args"]["answer"]
difference = answer[len(response_tokens) :]
streamed_chat_history.assistant = difference
response_tokens = answer
yield f"data: {json.dumps(streamed_chat_history.dict())}"
if (
gathered.tool_calls
and gathered.tool_calls[-1].get("args")
and "citations" in gathered.tool_calls[-1]["args"]
):
citations = gathered.tool_calls[-1]["args"]["citations"]
else:
if chunk.get("answer"):
response_tokens += chunk["answer"].content
streamed_chat_history.assistant = chunk["answer"].content
yield f"data: {json.dumps(streamed_chat_history.dict())}"
if chunk.get("docs"):
sources = chunk["docs"]
sources_list = generate_source(sources, self.brain_id)
if not streamed_chat_history.metadata:
streamed_chat_history.metadata = {}
sources_list = generate_source(sources, self.brain_id, citations)
streamed_chat_history.metadata["citations"] = citations
# Serialize the sources list
serialized_sources_list = [source.dict() for source in sources_list]
streamed_chat_history.metadata["sources"] = serialized_sources_list
@ -398,7 +402,7 @@ class KnowledgeBrainQA(BaseModel, QAInterface):
except Exception as e:
logger.error("Error updating message by ID: %s", e)
def save_non_streaming_answer(self, chat_id, question, answer):
def save_non_streaming_answer(self, chat_id, question, answer, metadata):
new_chat = chat_service.update_chat_history(
CreateChatHistory(
**{
@ -407,6 +411,7 @@ class KnowledgeBrainQA(BaseModel, QAInterface):
"assistant": answer,
"brain_id": self.brain.brain_id,
"prompt_id": self.prompt_to_use_id,
"metadata": metadata,
}
)
)
@ -423,5 +428,6 @@ class KnowledgeBrainQA(BaseModel, QAInterface):
"brain_name": self.brain.name if self.brain else None,
"message_id": new_chat.message_id,
"brain_id": str(self.brain.brain_id) if self.brain else None,
"metadata": metadata,
}
)

View File

@ -39,3 +39,20 @@ class QAInterface(ABC):
raise NotImplementedError(
"generate_stream is an abstract method and must be implemented"
)
def model_compatible_with_function_calling(self, model: str):
if model in [
"gpt-4-turbo",
"gpt-4-turbo-2024-04-09",
"gpt-4-turbo-preview",
"gpt-4-0125-preview",
"gpt-4-1106-preview",
"gpt-4",
"gpt-4-0613",
"gpt-3.5-turbo",
"gpt-3.5-turbo-0125",
"gpt-3.5-turbo-1106",
"gpt-3.5-turbo-0613",
]:
return True
return False

View File

@ -1,6 +1,7 @@
import logging
import os
from operator import itemgetter
from typing import Optional
from typing import List, Optional
from uuid import UUID
from langchain.chains import ConversationalRetrievalChain
@ -14,8 +15,10 @@ from langchain_cohere import CohereRerank
from langchain_community.chat_models import ChatLiteLLM
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate, PromptTemplate
from langchain_core.pydantic_v1 import BaseModel as BaseModelV1
from langchain_core.pydantic_v1 import Field as FieldV1
from langchain_core.runnables import RunnableLambda, RunnablePassthrough
from langchain_openai import OpenAIEmbeddings
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from logger import get_logger
from models import BrainSettings # Importing settings related to the 'brain'
from modules.brain.service.brain_service import BrainService
@ -26,7 +29,20 @@ from pydantic_settings import BaseSettings
from supabase.client import Client, create_client
from vectorstore.supabase import CustomSupabaseVectorStore
logger = get_logger(__name__)
logger = get_logger(__name__, log_level=logging.INFO)
class cited_answer(BaseModelV1):
"""Answer the user question based only on the given sources, and cite the sources used."""
answer: str = FieldV1(
...,
description="The answer to the user question, which is based only on the given sources.",
)
citations: List[int] = FieldV1(
...,
description="The integer IDs of the SPECIFIC sources which justify the answer.",
)
# First step is to create the Rephrasing Prompt
@ -66,7 +82,9 @@ ANSWER_PROMPT = ChatPromptTemplate.from_messages(
# How we format documents
DEFAULT_DOCUMENT_PROMPT = PromptTemplate.from_template(template="{page_content}")
DEFAULT_DOCUMENT_PROMPT = PromptTemplate.from_template(
template="Source: {index} \n {page_content}"
)
def is_valid_uuid(uuid_to_test, version=4):
@ -116,6 +134,23 @@ class QuivrRAG(BaseModel):
else:
return None
def model_compatible_with_function_calling(self):
if self.model in [
"gpt-4-turbo",
"gpt-4-turbo-2024-04-09",
"gpt-4-turbo-preview",
"gpt-4-0125-preview",
"gpt-4-1106-preview",
"gpt-4",
"gpt-4-0613",
"gpt-3.5-turbo",
"gpt-3.5-turbo-0125",
"gpt-3.5-turbo-1106",
"gpt-3.5-turbo-0613",
]:
return True
return False
supabase_client: Optional[Client] = None
vector_store: Optional[CustomSupabaseVectorStore] = None
qa: Optional[ConversationalRetrievalChain] = None
@ -197,6 +232,9 @@ class QuivrRAG(BaseModel):
def _combine_documents(
self, docs, document_prompt=DEFAULT_DOCUMENT_PROMPT, document_separator="\n\n"
):
# for each docs, add an index in the metadata to be able to cite the sources
for doc, index in zip(docs, range(len(docs))):
doc.metadata["index"] = index
doc_strings = [format_document(doc, document_prompt) for doc in docs]
return document_separator.join(doc_strings)
@ -287,14 +325,27 @@ class QuivrRAG(BaseModel):
"question": itemgetter("question"),
"custom_instructions": itemgetter("custom_instructions"),
}
llm = ChatLiteLLM(
max_tokens=self.max_tokens,
model=self.model,
temperature=self.temperature,
api_base=api_base,
)
if self.model_compatible_with_function_calling():
# And finally, we do the part that returns the answers
llm_function = ChatOpenAI(
max_tokens=self.max_tokens,
model=self.model,
temperature=self.temperature,
)
llm = llm_function.bind_tools(
[cited_answer],
tool_choice="cited_answer",
)
answer = {
"answer": final_inputs
| ANSWER_PROMPT
| ChatLiteLLM(
max_tokens=self.max_tokens, model=self.model, api_base=api_base
),
"answer": final_inputs | ANSWER_PROMPT | llm,
"docs": itemgetter("docs"),
}

View File

@ -4,7 +4,7 @@ from uuid import UUID
from modules.chat.dto.outputs import GetChatHistoryOutput
from modules.notification.entity.notification import Notification
from pydantic import BaseModel, ConfigDict
from pydantic import BaseModel
class ChatMessage(BaseModel):
@ -33,6 +33,7 @@ class Sources(BaseModel):
source_url: str
type: str
original_file_name: str
citation: str
class ChatItemType(Enum):

View File

@ -11,6 +11,7 @@ class CreateChatHistory(BaseModel):
assistant: str
prompt_id: Optional[UUID] = None
brain_id: Optional[UUID] = None
metadata: Optional[dict] = {}
class QuestionAndAnswer(BaseModel):

View File

@ -74,6 +74,7 @@ class Chats(ChatsInterface):
"brain_id": (
str(chat_history.brain_id) if chat_history.brain_id else None
),
"metadata": chat_history.metadata if chat_history.metadata else {},
}
)
.execute()
@ -104,7 +105,9 @@ class Chats(ChatsInterface):
def delete_chat_history(self, chat_id):
self.db.table("chat_history").delete().match({"chat_id": chat_id}).execute()
def update_chat_message(self, chat_id, message_id, chat_message_properties: ChatMessageProperties ):
def update_chat_message(
self, chat_id, message_id, chat_message_properties: ChatMessageProperties
):
response = (
self.db.table("chat_history")
.update(chat_message_properties)