From b7ff2e77af7c3bcc783622b273589b5668938e11 Mon Sep 17 00:00:00 2001
From: Stan Girard
Date: Fri, 26 Apr 2024 08:11:01 -0700
Subject: [PATCH] feat(citations): system added (#2498)

# Description

Adds a citation system to the RAG pipeline. For models that support function calling, the answer is now produced through a `cited_answer` tool that returns both the answer text and the integer IDs of the retrieved documents that were actually used. `generate_source` then returns only the cited documents, each source carries the cited page content in a new `citation` field, and the citations are stored in the chat history metadata for both streaming and non-streaming answers. Models without function calling keep the previous plain-completion behaviour.
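For context, here is a minimal, self-contained sketch of the tool-calling flow the change relies on. This is not code from the diff: the model name, prompt, and sources below are invented for illustration; only the `cited_answer` schema and the `bind_tools(..., tool_choice="cited_answer")` / `tool_calls[-1]["args"]` pattern mirror what the RAG chain does for function-calling-compatible models.

```python
from typing import List

from langchain_core.prompts import ChatPromptTemplate
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_openai import ChatOpenAI


class cited_answer(BaseModel):
    """Answer the user question based only on the given sources, and cite the sources used."""

    answer: str = Field(..., description="The answer to the user question, based only on the given sources.")
    citations: List[int] = Field(..., description="The integer IDs of the sources which justify the answer.")


# Force every response through the cited_answer tool, as the chain does for
# function-calling-compatible models.
llm = ChatOpenAI(model="gpt-3.5-turbo", temperature=0).bind_tools(
    [cited_answer], tool_choice="cited_answer"
)

# Each retrieved chunk is rendered as "Source: <index>" so the model can cite it by ID.
prompt = ChatPromptTemplate.from_template(
    "Answer using only these sources:\n{context}\n\nQuestion: {question}"
)
chain = prompt | llm

context = "Source: 0 \n Quivr stores uploaded files in Supabase.\n\nSource: 1 \n Quivr can stream answers."
message = chain.invoke({"context": context, "question": "Where are uploaded files stored?"})

tool_args = message.tool_calls[-1]["args"]
print(tool_args["answer"])     # the generated answer text
print(tool_args["citations"])  # e.g. [0] -> only source 0 is kept by generate_source()
```

`generate_source` keeps only the retrieved documents whose index appears in `citations` and attaches the cited `page_content` to each returned source.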
## Checklist before requesting a review

Please delete options that are not relevant.

- [ ] My code follows the style guidelines of this project
- [ ] I have performed a self-review of my code
- [ ] I have commented hard-to-understand areas
- [ ] I have ideally added tests that prove my fix is effective or that my feature works
- [ ] New and existing unit tests pass locally with my changes
- [ ] Any dependent changes have been merged

## Screenshots (if appropriate):

---
 backend/modules/brain/knowledge_brain_qa.py | 162 ++++++++++----------
 backend/modules/brain/qa_interface.py       |  17 ++
 backend/modules/brain/rags/quivr_rag.py     |  71 +++++++--
 backend/modules/chat/dto/chats.py           |   3 +-
 backend/modules/chat/dto/inputs.py          |   1 +
 backend/modules/chat/repository/chats.py    |   5 +-
 6 files changed, 169 insertions(+), 90 deletions(-)

diff --git a/backend/modules/brain/knowledge_brain_qa.py b/backend/modules/brain/knowledge_brain_qa.py
index e0eae4d01..92dc4f120 100644
--- a/backend/modules/brain/knowledge_brain_qa.py
+++ b/backend/modules/brain/knowledge_brain_qa.py
@@ -1,4 +1,5 @@
 import json
+import logging
 from typing import AsyncIterable, List, Optional
 from uuid import UUID

@@ -26,7 +27,7 @@ from modules.user.service.user_usage import UserUsage
 from pydantic import BaseModel, ConfigDict
 from pydantic_settings import BaseSettings

-logger = get_logger(__name__)
+logger = get_logger(__name__, log_level=logging.INFO)

 QUIVR_DEFAULT_PROMPT = "Your name is Quivr. You're a helpful assistant. If you don't know the answer, just say that you don't know, don't try to make up an answer."
@@ -43,7 +44,11 @@ def is_valid_uuid(uuid_to_test, version=4):
     return str(uuid_obj) == uuid_to_test


-def generate_source(source_documents, brain_id):
+def generate_source(source_documents, brain_id, citations: List[int] = None):
+    """
+    Generate the sources list for the answer.
+    Takes the list of source documents and the citations, which point to the indexes of the documents used in the answer.
+    """
     # Initialize an empty list for sources
     sources_list: List[Sources] = []
@@ -51,26 +56,26 @@ def generate_source(source_documents, brain_id):
     generated_urls = {}

     # remove duplicate sources with same name and create a list of unique sources
-    source_documents = list(
-        {v.metadata["file_name"]: v for v in source_documents}.values()
-    )
+    sources_url_cache = {}

     # Get source documents from the result, default to an empty list if not found
     # If source documents exist
     if source_documents:
-        logger.info(f"Source documents found: {source_documents}")
+        logger.info(f"Citations {citations}")
         # Iterate over each document
-        for doc in source_documents:
-            logger.info("Document: %s", doc)
+        for doc, index in zip(source_documents, range(len(source_documents))):
+            logger.info(f"Processing source document {doc.metadata['file_name']}")
+            if citations is not None:
+                if index not in citations:
+                    logger.info(f"Skipping source document {doc.metadata['file_name']}")
+                    continue
             # Check if 'url' is in the document metadata
-            logger.info(f"Metadata 1: {doc.metadata}")
             is_url = (
                 "original_file_name" in doc.metadata
                 and doc.metadata["original_file_name"] is not None
                 and doc.metadata["original_file_name"].startswith("http")
             )
-            logger.info(f"Is URL: {is_url}")

             # Determine the name based on whether it's a URL or a file
             name = (
@@ -91,11 +96,15 @@ def generate_source(source_documents, brain_id):
                 if file_path in generated_urls:
                     source_url = generated_urls[file_path]
                 else:
-                    generated_url = generate_file_signed_url(file_path)
-                    if generated_url is not None:
-                        source_url = generated_url.get("signedURL", "")
+                    # Generate the URL
+                    if file_path in sources_url_cache:
+                        source_url = sources_url_cache[file_path]
                     else:
-                        source_url = ""
+                        generated_url = generate_file_signed_url(file_path)
+                        if generated_url is not None:
+                            source_url = generated_url.get("signedURL", "")
+                        else:
+                            source_url = ""

                     # Store the generated URL
                     generated_urls[file_path] = source_url
@@ -106,6 +115,7 @@ def generate_source(source_documents, brain_id):
                         type=type_,
                         source_url=source_url,
                         original_file_name=name,
+                        citation=doc.page_content,
                     )
                 )
         else:
@@ -219,10 +229,6 @@ class KnowledgeBrainQA(BaseModel, QAInterface):

     def calculate_pricing(self):
-        logger.info("Calculating pricing")
-        logger.info(f"Model: {self.model}")
-        logger.info(f"User settings: {self.user_settings}")
-        logger.info(f"Models settings: {self.models_settings}")
         model_to_use = find_model_and_generate_metadata(
             self.chat_id,
             self.brain.model,
@@ -248,6 +254,8 @@ class KnowledgeBrainQA(BaseModel, QAInterface):
             self.initialize_streamed_chat_history(chat_id, question)
         )
         metadata = self.metadata or {}
+        citations = None
+        answer = ""
         model_response = conversational_qa_chain.invoke(
             {
                 "question": question.question,
@@ -258,57 +266,23 @@ class KnowledgeBrainQA(BaseModel, QAInterface):
             }
         )

+        if self.model_compatible_with_function_calling(model=self.model):
+            if model_response["answer"].tool_calls:
+                citations = model_response["answer"].tool_calls[-1]["args"]["citations"]
+                if citations:
+                    citations = citations
+                answer = model_response["answer"].tool_calls[-1]["args"]["answer"]
+                metadata["citations"] = citations
+        else:
+            answer = model_response["answer"].content

         sources = model_response["docs"] or []

         if len(sources) > 0:
-            sources_list = generate_source(sources, self.brain_id)
-            metadata["sources"] = sources_list
+            sources_list = generate_source(sources, self.brain_id, citations=citations)
+            serialized_sources_list = [source.dict() for source in sources_list]
+            metadata["sources"] = serialized_sources_list

-        answer = model_response["answer"].content
-
-        if save_answer:
-            # save the answer to the database or not -> add a variable
-            new_chat = chat_service.update_chat_history(
-                CreateChatHistory(
-                    **{
-                        "chat_id": chat_id,
-                        "user_message": question.question,
-                        "assistant": answer,
-                        "brain_id": self.brain.brain_id,
-                        "prompt_id": self.prompt_to_use_id,
-                    }
-                )
-            )
-
-            return GetChatHistoryOutput(
-                **{
-                    "chat_id": chat_id,
-                    "user_message": question.question,
-                    "assistant": answer,
-                    "message_time": new_chat.message_time,
-                    "prompt_title": (
-                        self.prompt_to_use.title if self.prompt_to_use else None
-                    ),
-                    "brain_name": self.brain.name if self.brain else None,
-                    "message_id": new_chat.message_id,
-                    "brain_id": str(self.brain.brain_id) if self.brain else None,
-                    "metadata": metadata,
-                }
-            )
-
-        return GetChatHistoryOutput(
-            **{
-                "chat_id": chat_id,
-                "user_message": question.question,
-                "assistant": answer,
-                "message_time": None,
-                "prompt_title": (
-                    self.prompt_to_use.title if self.prompt_to_use else None
-                ),
-                "brain_name": None,
-                "message_id": None,
-                "brain_id": str(self.brain.brain_id) if self.brain else None,
-                "metadata": metadata,
-            }
+        return self.save_non_streaming_answer(
+            chat_id=chat_id, question=question, answer=answer, metadata=metadata
         )

     async def generate_stream(
@@ -318,9 +292,10 @@ class KnowledgeBrainQA(BaseModel, QAInterface):
         transformed_history, streamed_chat_history = (
             self.initialize_streamed_chat_history(chat_id, question)
         )
-        response_tokens = []
+        response_tokens = ""
         sources = []
-
+        citations = []
+        first = True
         async for chunk in conversational_qa_chain.astream(
             {
                 "question": question.question,
@@ -330,18 +305,47 @@ class KnowledgeBrainQA(BaseModel, QAInterface):
                 ),
             }
         ):
-            if chunk.get("answer"):
-                logger.info(f"Chunk: {chunk}")
-                response_tokens.append(chunk["answer"].content)
-                streamed_chat_history.assistant = chunk["answer"].content
-                yield f"data: {json.dumps(streamed_chat_history.dict())}"
+            if not streamed_chat_history.metadata:
+                streamed_chat_history.metadata = {}
+            if self.model_compatible_with_function_calling(model=self.model):
+                if chunk.get("answer"):
+                    if first:
+                        gathered = chunk["answer"]
+                        first = False
+                    else:
+                        gathered = gathered + chunk["answer"]
+                    if (
+                        gathered.tool_calls
+                        and gathered.tool_calls[-1].get("args")
+                        and "answer" in gathered.tool_calls[-1]["args"]
+                    ):
+                        # Only send the difference between the gathered answer and response_tokens, which holds the previously sent answer
+                        answer = gathered.tool_calls[-1]["args"]["answer"]
+                        difference = answer[len(response_tokens) :]
+                        streamed_chat_history.assistant = difference
+                        response_tokens = answer
+
+                        yield f"data: {json.dumps(streamed_chat_history.dict())}"
+                    if (
+                        gathered.tool_calls
+                        and gathered.tool_calls[-1].get("args")
+                        and "citations" in gathered.tool_calls[-1]["args"]
+                    ):
+                        citations = gathered.tool_calls[-1]["args"]["citations"]
+            else:
+                if chunk.get("answer"):
+                    response_tokens += chunk["answer"].content
+                    streamed_chat_history.assistant = chunk["answer"].content
+                    yield f"data: {json.dumps(streamed_chat_history.dict())}"
+
             if chunk.get("docs"):
                 sources = chunk["docs"]
-                sources_list = generate_source(sources, self.brain_id)
-                if not streamed_chat_history.metadata:
-                    streamed_chat_history.metadata = {}
-                # Serialize the sources list
+                sources_list = generate_source(sources, self.brain_id, citations)
+
+                streamed_chat_history.metadata["citations"] = citations
+
+                # Serialize the sources list
                 serialized_sources_list = [source.dict() for source in sources_list]
                 streamed_chat_history.metadata["sources"] = serialized_sources_list
                 yield f"data: {json.dumps(streamed_chat_history.dict())}"
@@ -398,7 +402,7 @@ class KnowledgeBrainQA(BaseModel, QAInterface):
         except Exception as e:
             logger.error("Error updating message by ID: %s", e)

-    def save_non_streaming_answer(self, chat_id, question, answer):
+    def save_non_streaming_answer(self, chat_id, question, answer, metadata):
         new_chat = chat_service.update_chat_history(
             CreateChatHistory(
                 **{
@@ -407,6 +411,7 @@ class KnowledgeBrainQA(BaseModel, QAInterface):
                     "assistant": answer,
                     "brain_id": self.brain.brain_id,
                     "prompt_id": self.prompt_to_use_id,
+                    "metadata": metadata,
                 }
             )
         )
@@ -423,5 +428,6 @@ class KnowledgeBrainQA(BaseModel, QAInterface):
                 "brain_name": self.brain.name if self.brain else None,
                 "message_id": new_chat.message_id,
                 "brain_id": str(self.brain.brain_id) if self.brain else None,
+                "metadata": metadata,
             }
         )
diff --git a/backend/modules/brain/qa_interface.py b/backend/modules/brain/qa_interface.py
index 69b9f2c91..9a5a131a2 100644
--- a/backend/modules/brain/qa_interface.py
+++ b/backend/modules/brain/qa_interface.py
@@ -39,3 +39,20 @@ class QAInterface(ABC):
         raise NotImplementedError(
             "generate_stream is an abstract method and must be implemented"
         )
+
+    def model_compatible_with_function_calling(self, model: str):
+        if model in [
+            "gpt-4-turbo",
+            "gpt-4-turbo-2024-04-09",
+            "gpt-4-turbo-preview",
+            "gpt-4-0125-preview",
+            "gpt-4-1106-preview",
+            "gpt-4",
+            "gpt-4-0613",
+            "gpt-3.5-turbo",
+            "gpt-3.5-turbo-0125",
+            "gpt-3.5-turbo-1106",
+            "gpt-3.5-turbo-0613",
+        ]:
+            return True
+        return False
diff --git a/backend/modules/brain/rags/quivr_rag.py b/backend/modules/brain/rags/quivr_rag.py
index 8487bc33b..9de3d4667 100644
--- a/backend/modules/brain/rags/quivr_rag.py
+++ b/backend/modules/brain/rags/quivr_rag.py
@@ -1,6 +1,7 @@
+import logging
 import os
 from operator import itemgetter
-from typing import Optional
+from typing import List, Optional
 from uuid import UUID

 from langchain.chains import ConversationalRetrievalChain
@@ -14,8 +15,10 @@ from langchain_cohere import CohereRerank
 from langchain_community.chat_models import ChatLiteLLM
 from langchain_core.output_parsers import StrOutputParser
 from langchain_core.prompts import ChatPromptTemplate, PromptTemplate
+from langchain_core.pydantic_v1 import BaseModel as BaseModelV1
+from langchain_core.pydantic_v1 import Field as FieldV1
 from langchain_core.runnables import RunnableLambda, RunnablePassthrough
-from langchain_openai import OpenAIEmbeddings
+from langchain_openai import ChatOpenAI, OpenAIEmbeddings
 from logger import get_logger
 from models import BrainSettings  # Importing settings related to the 'brain'
 from modules.brain.service.brain_service import BrainService
@@ -26,7 +29,20 @@ from pydantic_settings import BaseSettings
 from supabase.client import Client, create_client
 from vectorstore.supabase import CustomSupabaseVectorStore

-logger = get_logger(__name__)
+logger = get_logger(__name__, log_level=logging.INFO)
+
+
+class cited_answer(BaseModelV1):
+    """Answer the user question based only on the given sources, and cite the sources used."""
+
+    answer: str = FieldV1(
+        ...,
+        description="The answer to the user question, which is based only on the given sources.",
+    )
+    citations: List[int] = FieldV1(
+        ...,
+        description="The integer IDs of the SPECIFIC sources which justify the answer.",
+    )


 # First step is to create the Rephrasing Prompt
@@ -66,7 +82,9 @@ ANSWER_PROMPT = ChatPromptTemplate.from_messages(


 # How we format documents
-DEFAULT_DOCUMENT_PROMPT = PromptTemplate.from_template(template="{page_content}")
+DEFAULT_DOCUMENT_PROMPT = PromptTemplate.from_template(
+    template="Source: {index} \n {page_content}"
+)


 def is_valid_uuid(uuid_to_test, version=4):
@@ -116,6 +134,23 @@ class QuivrRAG(BaseModel):
         else:
             return None

+    def model_compatible_with_function_calling(self):
+        if self.model in [
+            "gpt-4-turbo",
+            "gpt-4-turbo-2024-04-09",
+            "gpt-4-turbo-preview",
+            "gpt-4-0125-preview",
+            "gpt-4-1106-preview",
+            "gpt-4",
+            "gpt-4-0613",
+            "gpt-3.5-turbo",
+            "gpt-3.5-turbo-0125",
+            "gpt-3.5-turbo-1106",
+            "gpt-3.5-turbo-0613",
+        ]:
+            return True
+        return False
+
     supabase_client: Optional[Client] = None
     vector_store: Optional[CustomSupabaseVectorStore] = None
     qa: Optional[ConversationalRetrievalChain] = None
@@ -197,6 +232,9 @@ class QuivrRAG(BaseModel):
     def _combine_documents(
         self, docs, document_prompt=DEFAULT_DOCUMENT_PROMPT, document_separator="\n\n"
     ):
+        # for each doc, add an index to the metadata so the sources can be cited
+        for doc, index in zip(docs, range(len(docs))):
+            doc.metadata["index"] = index
         doc_strings = [format_document(doc, document_prompt) for doc in docs]
         return document_separator.join(doc_strings)
@@ -287,14 +325,27 @@ class QuivrRAG(BaseModel):
             "question": itemgetter("question"),
             "custom_instructions": itemgetter("custom_instructions"),
         }
+        llm = ChatLiteLLM(
+            max_tokens=self.max_tokens,
+            model=self.model,
+            temperature=self.temperature,
+            api_base=api_base,
+        )
+        if self.model_compatible_with_function_calling():
+
+            # And finally, we do the part that returns the answers
+            llm_function = ChatOpenAI(
+                max_tokens=self.max_tokens,
+                model=self.model,
+                temperature=self.temperature,
+            )
+            llm = llm_function.bind_tools(
+                [cited_answer],
+                tool_choice="cited_answer",
+            )

-        # And finally, we do the part that returns the answers
         answer = {
-            "answer": final_inputs
-            | ANSWER_PROMPT
-            | ChatLiteLLM(
-                max_tokens=self.max_tokens, model=self.model, api_base=api_base
-            ),
+            "answer": final_inputs | ANSWER_PROMPT | llm,
             "docs": itemgetter("docs"),
         }
diff --git a/backend/modules/chat/dto/chats.py b/backend/modules/chat/dto/chats.py
index a3eaa0f69..ad03a510f 100644
--- a/backend/modules/chat/dto/chats.py
+++ b/backend/modules/chat/dto/chats.py
@@ -4,7 +4,7 @@ from uuid import UUID

 from modules.chat.dto.outputs import GetChatHistoryOutput
 from modules.notification.entity.notification import Notification
-from pydantic import BaseModel, ConfigDict
+from pydantic import BaseModel


 class ChatMessage(BaseModel):
@@ -33,6 +33,7 @@ class Sources(BaseModel):
     source_url: str
     type: str
     original_file_name: str
+    citation: str


 class ChatItemType(Enum):
diff --git a/backend/modules/chat/dto/inputs.py b/backend/modules/chat/dto/inputs.py
index 47966e58e..d3acab722 100644
--- a/backend/modules/chat/dto/inputs.py
+++ b/backend/modules/chat/dto/inputs.py
@@ -11,6 +11,7 @@ class CreateChatHistory(BaseModel):
     assistant: str
     prompt_id: Optional[UUID] = None
     brain_id: Optional[UUID] = None
+    metadata: Optional[dict] = {}


 class QuestionAndAnswer(BaseModel):
diff --git a/backend/modules/chat/repository/chats.py b/backend/modules/chat/repository/chats.py
index fbd7ff215..3e32772d5 100644
--- a/backend/modules/chat/repository/chats.py
+++ b/backend/modules/chat/repository/chats.py
@@ -74,6 +74,7 @@ class Chats(ChatsInterface):
                     "brain_id": (
                         str(chat_history.brain_id) if chat_history.brain_id else None
                     ),
+                    "metadata": chat_history.metadata if chat_history.metadata else {},
                 }
             )
             .execute()
@@ -104,7 +105,9 @@ class Chats(ChatsInterface):
     def delete_chat_history(self, chat_id):
         self.db.table("chat_history").delete().match({"chat_id": chat_id}).execute()

-    def update_chat_message(self, chat_id, message_id, chat_message_properties: ChatMessageProperties ):
+    def update_chat_message(
+        self, chat_id, message_id, chat_message_properties: ChatMessageProperties
+    ):
         response = (
             self.db.table("chat_history")
             .update(chat_message_properties)
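For frontend reviewers, this is roughly the shape of one streamed frame after this change. The values are invented for illustration and unrelated chat-history fields (chat_id, message_id, etc.) are omitted; the field names come from the `Sources` model and the metadata keys set in `generate_stream`.

```python
import json

# Illustrative only: approximate shape of one "data: ..." SSE frame emitted by
# generate_stream once citations are enabled (values are made up).
frame = {
    "assistant": " in Supabase.",  # only the newly generated part of the answer
    "metadata": {
        "citations": [0],  # indexes of the retrieved documents the model cited
        "sources": [
            {
                "name": "architecture.md",
                "type": "file",
                "source_url": "https://example.com/signed-url",
                "original_file_name": "architecture.md",
                "citation": "Quivr stores uploaded files in Supabase.",
            }
        ],
    },
}
print(f"data: {json.dumps(frame)}")
```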