feat(citations): system added (#2498)

# Description

Please include a summary of the changes and the related issue. Please
also include relevant motivation and context.

## Checklist before requesting a review

Please delete options that are not relevant.

- [ ] My code follows the style guidelines of this project
- [ ] I have performed a self-review of my code
- [ ] I have commented hard-to-understand areas
- [ ] I have ideally added tests that prove my fix is effective or that
my feature works
- [ ] New and existing unit tests pass locally with my changes
- [ ] Any dependent changes have been merged

## Screenshots (if appropriate):
This commit is contained in:
Stan Girard 2024-04-26 08:11:01 -07:00 committed by GitHub
parent 19365c4bb5
commit b7ff2e77af
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 169 additions and 90 deletions

View File

@ -1,4 +1,5 @@
import json import json
import logging
from typing import AsyncIterable, List, Optional from typing import AsyncIterable, List, Optional
from uuid import UUID from uuid import UUID
@ -26,7 +27,7 @@ from modules.user.service.user_usage import UserUsage
from pydantic import BaseModel, ConfigDict from pydantic import BaseModel, ConfigDict
from pydantic_settings import BaseSettings from pydantic_settings import BaseSettings
logger = get_logger(__name__) logger = get_logger(__name__, log_level=logging.INFO)
QUIVR_DEFAULT_PROMPT = "Your name is Quivr. You're a helpful assistant. If you don't know the answer, just say that you don't know, don't try to make up an answer." QUIVR_DEFAULT_PROMPT = "Your name is Quivr. You're a helpful assistant. If you don't know the answer, just say that you don't know, don't try to make up an answer."
@ -43,7 +44,11 @@ def is_valid_uuid(uuid_to_test, version=4):
return str(uuid_obj) == uuid_to_test return str(uuid_obj) == uuid_to_test
def generate_source(source_documents, brain_id): def generate_source(source_documents, brain_id, citations: List[int] = None):
"""
Generate the sources list for the answer
It takes in a list of sources documents and citations that points to the docs index that was used in the answer
"""
# Initialize an empty list for sources # Initialize an empty list for sources
sources_list: List[Sources] = [] sources_list: List[Sources] = []
@ -51,26 +56,26 @@ def generate_source(source_documents, brain_id):
generated_urls = {} generated_urls = {}
# remove duplicate sources with same name and create a list of unique sources # remove duplicate sources with same name and create a list of unique sources
source_documents = list( sources_url_cache = {}
{v.metadata["file_name"]: v for v in source_documents}.values()
)
# Get source documents from the result, default to an empty list if not found # Get source documents from the result, default to an empty list if not found
# If source documents exist # If source documents exist
if source_documents: if source_documents:
logger.info(f"Source documents found: {source_documents}") logger.info(f"Citations {citations}")
# Iterate over each document # Iterate over each document
for doc in source_documents: for doc, index in zip(source_documents, range(len(source_documents))):
logger.info("Document: %s", doc) logger.info(f"Processing source document {doc.metadata['file_name']}")
if citations is not None:
if index not in citations:
logger.info(f"Skipping source document {doc.metadata['file_name']}")
continue
# Check if 'url' is in the document metadata # Check if 'url' is in the document metadata
logger.info(f"Metadata 1: {doc.metadata}")
is_url = ( is_url = (
"original_file_name" in doc.metadata "original_file_name" in doc.metadata
and doc.metadata["original_file_name"] is not None and doc.metadata["original_file_name"] is not None
and doc.metadata["original_file_name"].startswith("http") and doc.metadata["original_file_name"].startswith("http")
) )
logger.info(f"Is URL: {is_url}")
# Determine the name based on whether it's a URL or a file # Determine the name based on whether it's a URL or a file
name = ( name = (
@ -90,6 +95,10 @@ def generate_source(source_documents, brain_id):
# Check if the URL has already been generated # Check if the URL has already been generated
if file_path in generated_urls: if file_path in generated_urls:
source_url = generated_urls[file_path] source_url = generated_urls[file_path]
else:
# Generate the URL
if file_path in sources_url_cache:
source_url = sources_url_cache[file_path]
else: else:
generated_url = generate_file_signed_url(file_path) generated_url = generate_file_signed_url(file_path)
if generated_url is not None: if generated_url is not None:
@ -106,6 +115,7 @@ def generate_source(source_documents, brain_id):
type=type_, type=type_,
source_url=source_url, source_url=source_url,
original_file_name=name, original_file_name=name,
citation=doc.page_content,
) )
) )
else: else:
@ -219,10 +229,6 @@ class KnowledgeBrainQA(BaseModel, QAInterface):
def calculate_pricing(self): def calculate_pricing(self):
logger.info("Calculating pricing")
logger.info(f"Model: {self.model}")
logger.info(f"User settings: {self.user_settings}")
logger.info(f"Models settings: {self.models_settings}")
model_to_use = find_model_and_generate_metadata( model_to_use = find_model_and_generate_metadata(
self.chat_id, self.chat_id,
self.brain.model, self.brain.model,
@ -248,6 +254,8 @@ class KnowledgeBrainQA(BaseModel, QAInterface):
self.initialize_streamed_chat_history(chat_id, question) self.initialize_streamed_chat_history(chat_id, question)
) )
metadata = self.metadata or {} metadata = self.metadata or {}
citations = None
answer = ""
model_response = conversational_qa_chain.invoke( model_response = conversational_qa_chain.invoke(
{ {
"question": question.question, "question": question.question,
@ -258,57 +266,23 @@ class KnowledgeBrainQA(BaseModel, QAInterface):
} }
) )
if self.model_compatible_with_function_calling(model=self.model):
if model_response["answer"].tool_calls:
citations = model_response["answer"].tool_calls[-1]["args"]["citations"]
if citations:
citations = citations
answer = model_response["answer"].tool_calls[-1]["args"]["answer"]
metadata["citations"] = citations
else:
answer = model_response["answer"].content
sources = model_response["docs"] or [] sources = model_response["docs"] or []
if len(sources) > 0: if len(sources) > 0:
sources_list = generate_source(sources, self.brain_id) sources_list = generate_source(sources, self.brain_id, citations=citations)
metadata["sources"] = sources_list serialized_sources_list = [source.dict() for source in sources_list]
metadata["sources"] = serialized_sources_list
answer = model_response["answer"].content return self.save_non_streaming_answer(
chat_id=chat_id, question=question, answer=answer, metadata=metadata
if save_answer:
# save the answer to the database or not -> add a variable
new_chat = chat_service.update_chat_history(
CreateChatHistory(
**{
"chat_id": chat_id,
"user_message": question.question,
"assistant": answer,
"brain_id": self.brain.brain_id,
"prompt_id": self.prompt_to_use_id,
}
)
)
return GetChatHistoryOutput(
**{
"chat_id": chat_id,
"user_message": question.question,
"assistant": answer,
"message_time": new_chat.message_time,
"prompt_title": (
self.prompt_to_use.title if self.prompt_to_use else None
),
"brain_name": self.brain.name if self.brain else None,
"message_id": new_chat.message_id,
"brain_id": str(self.brain.brain_id) if self.brain else None,
"metadata": metadata,
}
)
return GetChatHistoryOutput(
**{
"chat_id": chat_id,
"user_message": question.question,
"assistant": answer,
"message_time": None,
"prompt_title": (
self.prompt_to_use.title if self.prompt_to_use else None
),
"brain_name": None,
"message_id": None,
"brain_id": str(self.brain.brain_id) if self.brain else None,
"metadata": metadata,
}
) )
async def generate_stream( async def generate_stream(
@ -318,9 +292,10 @@ class KnowledgeBrainQA(BaseModel, QAInterface):
transformed_history, streamed_chat_history = ( transformed_history, streamed_chat_history = (
self.initialize_streamed_chat_history(chat_id, question) self.initialize_streamed_chat_history(chat_id, question)
) )
response_tokens = [] response_tokens = ""
sources = [] sources = []
citations = []
first = True
async for chunk in conversational_qa_chain.astream( async for chunk in conversational_qa_chain.astream(
{ {
"question": question.question, "question": question.question,
@ -330,17 +305,46 @@ class KnowledgeBrainQA(BaseModel, QAInterface):
), ),
} }
): ):
if not streamed_chat_history.metadata:
streamed_chat_history.metadata = {}
if self.model_compatible_with_function_calling(model=self.model):
if chunk.get("answer"): if chunk.get("answer"):
logger.info(f"Chunk: {chunk}") if first:
response_tokens.append(chunk["answer"].content) gathered = chunk["answer"]
first = False
else:
gathered = gathered + chunk["answer"]
if (
gathered.tool_calls
and gathered.tool_calls[-1].get("args")
and "answer" in gathered.tool_calls[-1]["args"]
):
# Only send the difference between answer and response_tokens which was the previous answer
answer = gathered.tool_calls[-1]["args"]["answer"]
difference = answer[len(response_tokens) :]
streamed_chat_history.assistant = difference
response_tokens = answer
yield f"data: {json.dumps(streamed_chat_history.dict())}"
if (
gathered.tool_calls
and gathered.tool_calls[-1].get("args")
and "citations" in gathered.tool_calls[-1]["args"]
):
citations = gathered.tool_calls[-1]["args"]["citations"]
else:
if chunk.get("answer"):
response_tokens += chunk["answer"].content
streamed_chat_history.assistant = chunk["answer"].content streamed_chat_history.assistant = chunk["answer"].content
yield f"data: {json.dumps(streamed_chat_history.dict())}" yield f"data: {json.dumps(streamed_chat_history.dict())}"
if chunk.get("docs"): if chunk.get("docs"):
sources = chunk["docs"] sources = chunk["docs"]
sources_list = generate_source(sources, self.brain_id) sources_list = generate_source(sources, self.brain_id, citations)
if not streamed_chat_history.metadata:
streamed_chat_history.metadata = {} streamed_chat_history.metadata["citations"] = citations
# Serialize the sources list # Serialize the sources list
serialized_sources_list = [source.dict() for source in sources_list] serialized_sources_list = [source.dict() for source in sources_list]
streamed_chat_history.metadata["sources"] = serialized_sources_list streamed_chat_history.metadata["sources"] = serialized_sources_list
@ -398,7 +402,7 @@ class KnowledgeBrainQA(BaseModel, QAInterface):
except Exception as e: except Exception as e:
logger.error("Error updating message by ID: %s", e) logger.error("Error updating message by ID: %s", e)
def save_non_streaming_answer(self, chat_id, question, answer): def save_non_streaming_answer(self, chat_id, question, answer, metadata):
new_chat = chat_service.update_chat_history( new_chat = chat_service.update_chat_history(
CreateChatHistory( CreateChatHistory(
**{ **{
@ -407,6 +411,7 @@ class KnowledgeBrainQA(BaseModel, QAInterface):
"assistant": answer, "assistant": answer,
"brain_id": self.brain.brain_id, "brain_id": self.brain.brain_id,
"prompt_id": self.prompt_to_use_id, "prompt_id": self.prompt_to_use_id,
"metadata": metadata,
} }
) )
) )
@ -423,5 +428,6 @@ class KnowledgeBrainQA(BaseModel, QAInterface):
"brain_name": self.brain.name if self.brain else None, "brain_name": self.brain.name if self.brain else None,
"message_id": new_chat.message_id, "message_id": new_chat.message_id,
"brain_id": str(self.brain.brain_id) if self.brain else None, "brain_id": str(self.brain.brain_id) if self.brain else None,
"metadata": metadata,
} }
) )

View File

@ -39,3 +39,20 @@ class QAInterface(ABC):
raise NotImplementedError( raise NotImplementedError(
"generate_stream is an abstract method and must be implemented" "generate_stream is an abstract method and must be implemented"
) )
def model_compatible_with_function_calling(self, model: str):
if model in [
"gpt-4-turbo",
"gpt-4-turbo-2024-04-09",
"gpt-4-turbo-preview",
"gpt-4-0125-preview",
"gpt-4-1106-preview",
"gpt-4",
"gpt-4-0613",
"gpt-3.5-turbo",
"gpt-3.5-turbo-0125",
"gpt-3.5-turbo-1106",
"gpt-3.5-turbo-0613",
]:
return True
return False

View File

@ -1,6 +1,7 @@
import logging
import os import os
from operator import itemgetter from operator import itemgetter
from typing import Optional from typing import List, Optional
from uuid import UUID from uuid import UUID
from langchain.chains import ConversationalRetrievalChain from langchain.chains import ConversationalRetrievalChain
@ -14,8 +15,10 @@ from langchain_cohere import CohereRerank
from langchain_community.chat_models import ChatLiteLLM from langchain_community.chat_models import ChatLiteLLM
from langchain_core.output_parsers import StrOutputParser from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate, PromptTemplate from langchain_core.prompts import ChatPromptTemplate, PromptTemplate
from langchain_core.pydantic_v1 import BaseModel as BaseModelV1
from langchain_core.pydantic_v1 import Field as FieldV1
from langchain_core.runnables import RunnableLambda, RunnablePassthrough from langchain_core.runnables import RunnableLambda, RunnablePassthrough
from langchain_openai import OpenAIEmbeddings from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from logger import get_logger from logger import get_logger
from models import BrainSettings # Importing settings related to the 'brain' from models import BrainSettings # Importing settings related to the 'brain'
from modules.brain.service.brain_service import BrainService from modules.brain.service.brain_service import BrainService
@ -26,7 +29,20 @@ from pydantic_settings import BaseSettings
from supabase.client import Client, create_client from supabase.client import Client, create_client
from vectorstore.supabase import CustomSupabaseVectorStore from vectorstore.supabase import CustomSupabaseVectorStore
logger = get_logger(__name__) logger = get_logger(__name__, log_level=logging.INFO)
class cited_answer(BaseModelV1):
"""Answer the user question based only on the given sources, and cite the sources used."""
answer: str = FieldV1(
...,
description="The answer to the user question, which is based only on the given sources.",
)
citations: List[int] = FieldV1(
...,
description="The integer IDs of the SPECIFIC sources which justify the answer.",
)
# First step is to create the Rephrasing Prompt # First step is to create the Rephrasing Prompt
@ -66,7 +82,9 @@ ANSWER_PROMPT = ChatPromptTemplate.from_messages(
# How we format documents # How we format documents
DEFAULT_DOCUMENT_PROMPT = PromptTemplate.from_template(template="{page_content}") DEFAULT_DOCUMENT_PROMPT = PromptTemplate.from_template(
template="Source: {index} \n {page_content}"
)
def is_valid_uuid(uuid_to_test, version=4): def is_valid_uuid(uuid_to_test, version=4):
@ -116,6 +134,23 @@ class QuivrRAG(BaseModel):
else: else:
return None return None
def model_compatible_with_function_calling(self):
if self.model in [
"gpt-4-turbo",
"gpt-4-turbo-2024-04-09",
"gpt-4-turbo-preview",
"gpt-4-0125-preview",
"gpt-4-1106-preview",
"gpt-4",
"gpt-4-0613",
"gpt-3.5-turbo",
"gpt-3.5-turbo-0125",
"gpt-3.5-turbo-1106",
"gpt-3.5-turbo-0613",
]:
return True
return False
supabase_client: Optional[Client] = None supabase_client: Optional[Client] = None
vector_store: Optional[CustomSupabaseVectorStore] = None vector_store: Optional[CustomSupabaseVectorStore] = None
qa: Optional[ConversationalRetrievalChain] = None qa: Optional[ConversationalRetrievalChain] = None
@ -197,6 +232,9 @@ class QuivrRAG(BaseModel):
def _combine_documents( def _combine_documents(
self, docs, document_prompt=DEFAULT_DOCUMENT_PROMPT, document_separator="\n\n" self, docs, document_prompt=DEFAULT_DOCUMENT_PROMPT, document_separator="\n\n"
): ):
# for each docs, add an index in the metadata to be able to cite the sources
for doc, index in zip(docs, range(len(docs))):
doc.metadata["index"] = index
doc_strings = [format_document(doc, document_prompt) for doc in docs] doc_strings = [format_document(doc, document_prompt) for doc in docs]
return document_separator.join(doc_strings) return document_separator.join(doc_strings)
@ -287,14 +325,27 @@ class QuivrRAG(BaseModel):
"question": itemgetter("question"), "question": itemgetter("question"),
"custom_instructions": itemgetter("custom_instructions"), "custom_instructions": itemgetter("custom_instructions"),
} }
llm = ChatLiteLLM(
max_tokens=self.max_tokens,
model=self.model,
temperature=self.temperature,
api_base=api_base,
)
if self.model_compatible_with_function_calling():
# And finally, we do the part that returns the answers # And finally, we do the part that returns the answers
llm_function = ChatOpenAI(
max_tokens=self.max_tokens,
model=self.model,
temperature=self.temperature,
)
llm = llm_function.bind_tools(
[cited_answer],
tool_choice="cited_answer",
)
answer = { answer = {
"answer": final_inputs "answer": final_inputs | ANSWER_PROMPT | llm,
| ANSWER_PROMPT
| ChatLiteLLM(
max_tokens=self.max_tokens, model=self.model, api_base=api_base
),
"docs": itemgetter("docs"), "docs": itemgetter("docs"),
} }

View File

@ -4,7 +4,7 @@ from uuid import UUID
from modules.chat.dto.outputs import GetChatHistoryOutput from modules.chat.dto.outputs import GetChatHistoryOutput
from modules.notification.entity.notification import Notification from modules.notification.entity.notification import Notification
from pydantic import BaseModel, ConfigDict from pydantic import BaseModel
class ChatMessage(BaseModel): class ChatMessage(BaseModel):
@ -33,6 +33,7 @@ class Sources(BaseModel):
source_url: str source_url: str
type: str type: str
original_file_name: str original_file_name: str
citation: str
class ChatItemType(Enum): class ChatItemType(Enum):

View File

@ -11,6 +11,7 @@ class CreateChatHistory(BaseModel):
assistant: str assistant: str
prompt_id: Optional[UUID] = None prompt_id: Optional[UUID] = None
brain_id: Optional[UUID] = None brain_id: Optional[UUID] = None
metadata: Optional[dict] = {}
class QuestionAndAnswer(BaseModel): class QuestionAndAnswer(BaseModel):

View File

@ -74,6 +74,7 @@ class Chats(ChatsInterface):
"brain_id": ( "brain_id": (
str(chat_history.brain_id) if chat_history.brain_id else None str(chat_history.brain_id) if chat_history.brain_id else None
), ),
"metadata": chat_history.metadata if chat_history.metadata else {},
} }
) )
.execute() .execute()
@ -104,7 +105,9 @@ class Chats(ChatsInterface):
def delete_chat_history(self, chat_id): def delete_chat_history(self, chat_id):
self.db.table("chat_history").delete().match({"chat_id": chat_id}).execute() self.db.table("chat_history").delete().match({"chat_id": chat_id}).execute()
def update_chat_message(self, chat_id, message_id, chat_message_properties: ChatMessageProperties ): def update_chat_message(
self, chat_id, message_id, chat_message_properties: ChatMessageProperties
):
response = ( response = (
self.db.table("chat_history") self.db.table("chat_history")
.update(chat_message_properties) .update(chat_message_properties)