From 285fe5b96065a19c74f0314557e5840d8722099e Mon Sep 17 00:00:00 2001 From: Jacopo Chevallard Date: Thu, 31 Oct 2024 17:57:54 +0100 Subject: [PATCH] feat: websearch, tool use, user intent, dynamic retrieval, multiple questions (#3424) # Description This PR includes far too many new features: - detection of user intent (closes CORE-211) - treating multiple questions in parallel (closes CORE-212) - using the chat history when answering a question (closes CORE-213) - filtering of retrieved chunks by relevance threshold (closes CORE-217) - dynamic retrieval of chunks (closes CORE-218) - enabling web search via Tavily (closes CORE-220) - enabling agent / assistant to activate tools when relevant to complete the user task (closes CORE-224) Also closes CORE-205 ## Checklist before requesting a review Please delete options that are not relevant. - [ ] My code follows the style guidelines of this project - [ ] I have performed a self-review of my code - [ ] I have commented hard-to-understand areas - [ ] I have ideally added tests that prove my fix is effective or that my feature works - [ ] New and existing unit tests pass locally with my changes - [ ] Any dependent changes have been merged ## Screenshots (if appropriate): --------- Co-authored-by: Stan Girard --- .github/workflows/backend-core-tests.yml | 2 - core/pyproject.toml | 3 +- core/quivr_core/brain/brain.py | 165 ++-- core/quivr_core/brain/brain_defaults.py | 6 +- core/quivr_core/brain/serialization.py | 4 +- core/quivr_core/config.py | 493 ---------- core/quivr_core/llm/llm_endpoint.py | 12 +- core/quivr_core/llm_tools/__init__.py | 0 core/quivr_core/llm_tools/entity.py | 36 + core/quivr_core/llm_tools/llm_tools.py | 33 + core/quivr_core/llm_tools/other_tools.py | 24 + core/quivr_core/llm_tools/web_search_tools.py | 73 ++ core/quivr_core/prompts.py | 119 --- core/quivr_core/quivr_rag_langgraph.py | 488 ---------- core/quivr_core/rag/__init__.py | 0 core/quivr_core/rag/entities/__init__.py | 0 core/quivr_core/{ => rag/entities}/chat.py | 19 +- core/quivr_core/rag/entities/config.py | 452 +++++++++ core/quivr_core/{ => rag/entities}/models.py | 8 +- core/quivr_core/rag/prompts.py | 265 ++++++ core/quivr_core/{ => rag}/quivr_rag.py | 10 +- core/quivr_core/rag/quivr_rag_langgraph.py | 891 ++++++++++++++++++ core/quivr_core/rag/utils.py | 188 ++++ core/quivr_core/utils.py | 167 ---- core/requirements-dev.lock | 6 +- core/requirements.lock | 4 +- core/tests/conftest.py | 2 +- core/tests/fixture_chunks.py | 6 +- core/tests/rag_config.yaml | 4 +- core/tests/rag_config_workflow.yaml | 4 +- core/tests/test_brain.py | 6 +- core/tests/test_chat_history.py | 2 +- core/tests/test_config.py | 29 +- core/tests/test_llm_endpoint.py | 2 +- core/tests/test_quivr_rag.py | 42 +- core/tests/test_utils.py | 15 +- examples/pdf_document_from_yaml.py | 2 +- examples/pdf_parsing_tika.py | 2 +- examples/simple_question/pyproject.toml | 1 + .../simple_question/requirements-dev.lock | 10 +- examples/simple_question/requirements.lock | 10 +- examples/simple_question/simple_question.py | 10 +- .../simple_question_streaming.py | 2 +- 43 files changed, 2165 insertions(+), 1452 deletions(-) create mode 100644 core/quivr_core/llm_tools/__init__.py create mode 100644 core/quivr_core/llm_tools/entity.py create mode 100644 core/quivr_core/llm_tools/llm_tools.py create mode 100644 core/quivr_core/llm_tools/other_tools.py create mode 100644 core/quivr_core/llm_tools/web_search_tools.py delete mode 100644 core/quivr_core/prompts.py delete mode 100644 
core/quivr_core/quivr_rag_langgraph.py create mode 100644 core/quivr_core/rag/__init__.py create mode 100644 core/quivr_core/rag/entities/__init__.py rename core/quivr_core/{ => rag/entities}/chat.py (86%) create mode 100644 core/quivr_core/rag/entities/config.py rename core/quivr_core/{ => rag/entities}/models.py (93%) create mode 100644 core/quivr_core/rag/prompts.py rename core/quivr_core/{ => rag}/quivr_rag.py (97%) create mode 100644 core/quivr_core/rag/quivr_rag_langgraph.py create mode 100644 core/quivr_core/rag/utils.py delete mode 100644 core/quivr_core/utils.py rename examples/{ => simple_question}/simple_question_streaming.py (93%) diff --git a/.github/workflows/backend-core-tests.yml b/.github/workflows/backend-core-tests.yml index 013eb7c10..ffb7021bc 100644 --- a/.github/workflows/backend-core-tests.yml +++ b/.github/workflows/backend-core-tests.yml @@ -41,6 +41,4 @@ jobs: sudo apt-get update sudo apt-get install -y libmagic-dev poppler-utils libreoffice tesseract-ocr pandoc cd core - rye run python -c "from unstructured.nlp.tokenize import download_nltk_packages; download_nltk_packages()" - rye run python -c "import nltk;nltk.download('punkt_tab'); nltk.download('averaged_perceptron_tagger_eng')" rye test -p quivr-core diff --git a/core/pyproject.toml b/core/pyproject.toml index cd77db6bc..6f3725555 100644 --- a/core/pyproject.toml +++ b/core/pyproject.toml @@ -9,7 +9,7 @@ dependencies = [ "pydantic>=2.8.2", "langchain-core>=0.2.38", "langchain>=0.2.14,<0.3.0", - "langgraph>=0.2.14", + "langgraph>=0.2.38", "httpx>=0.27.0", "rich>=13.7.1", "tiktoken>=0.7.0", @@ -21,6 +21,7 @@ dependencies = [ "types-pyyaml>=6.0.12.20240808", "transformers[sentencepiece]>=4.44.2", "faiss-cpu>=1.8.0.post1", + "rapidfuzz>=3.10.1", ] readme = "README.md" requires-python = ">= 3.11" diff --git a/core/quivr_core/brain/brain.py b/core/quivr_core/brain/brain.py index 228814cde..0c6220a8e 100644 --- a/core/quivr_core/brain/brain.py +++ b/core/quivr_core/brain/brain.py @@ -10,7 +10,9 @@ from langchain_core.documents import Document from langchain_core.embeddings import Embeddings from langchain_core.messages import AIMessage, HumanMessage from langchain_core.vectorstores import VectorStore +from quivr_core.rag.entities.models import ParsedRAGResponse from langchain_openai import OpenAIEmbeddings +from quivr_core.rag.quivr_rag import QuivrQARAG from rich.console import Console from rich.panel import Panel @@ -22,19 +24,17 @@ from quivr_core.brain.serialization import ( LocalStorageConfig, TransparentStorageConfig, ) -from quivr_core.chat import ChatHistory -from quivr_core.config import RetrievalConfig +from quivr_core.rag.entities.chat import ChatHistory +from quivr_core.rag.entities.config import RetrievalConfig from quivr_core.files.file import load_qfile from quivr_core.llm import LLMEndpoint -from quivr_core.models import ( +from quivr_core.rag.entities.models import ( ParsedRAGChunkResponse, - ParsedRAGResponse, QuivrKnowledge, SearchResult, ) from quivr_core.processor.registry import get_processor_class -from quivr_core.quivr_rag import QuivrQARAG -from quivr_core.quivr_rag_langgraph import QuivrQARAGLangGraph +from quivr_core.rag.quivr_rag_langgraph import QuivrQARAGLangGraph from quivr_core.storage.local_storage import LocalStorage, TransparentStorage from quivr_core.storage.storage_base import StorageBase @@ -49,19 +49,15 @@ async def process_files( """ Process files in storage. This function takes a StorageBase and return a list of langchain documents. 
- Args: storage (StorageBase): The storage containing the files to process. skip_file_error (bool): Whether to skip files that cannot be processed. processor_kwargs (dict[str, Any]): Additional arguments for the processor. - Returns: list[Document]: List of processed documents in the Langchain Document format. - Raises: ValueError: If a file cannot be processed and skip_file_error is False. Exception: If no processor is found for a file of a specific type and skip_file_error is False. - """ knowledge = [] @@ -91,23 +87,17 @@ async def process_files( class Brain: """ A class representing a Brain. - This class allows for the creation of a Brain, which is a collection of knowledge one wants to retrieve information from. - A Brain is set to: - * Store files in the storage of your choice (local, S3, etc.) * Process the files in the storage to extract text and metadata in a wide range of format. * Store the processed files in the vector store of your choice (FAISS, PGVector, etc.) - default to FAISS. * Create an index of the processed files. * Use the *Quivr* workflow for the retrieval augmented generation. - A Brain is able to: - * Search for information in the vector store. * Answer questions about the knowledges in the Brain. * Stream the answer to the question. - Attributes: name (str): The name of the brain. id (UUID): The unique identifier of the brain. @@ -115,16 +105,14 @@ class Brain: llm (LLMEndpoint): The language model used to generate the answer. vector_db (VectorStore): The vector store used to store the processed files. embedder (Embeddings): The embeddings used to create the index of the processed files. - - """ def __init__( self, *, name: str, - id: UUID, llm: LLMEndpoint, + id: UUID | None = None, vector_db: VectorStore | None = None, embedder: Embeddings | None = None, storage: StorageBase | None = None, @@ -156,19 +144,15 @@ class Brain: def load(cls, folder_path: str | Path) -> Self: """ Load a brain from a folder path. - Args: folder_path (str | Path): The path to the folder containing the brain. - Returns: Brain: The brain loaded from the folder path. - Example: ```python brain_loaded = Brain.load("path/to/brain") brain_loaded.print_info() ``` - """ if isinstance(folder_path, str): folder_path = Path(folder_path) @@ -217,16 +201,13 @@ class Brain: vector_db=vector_db, ) - async def save(self, folder_path: str | Path) -> str: + async def save(self, folder_path: str | Path): """ Save the brain to a folder path. - Args: folder_path (str | Path): The path to the folder where the brain will be saved. - Returns: str: The path to the folder where the brain was saved. - Example: ```python await brain.save("path/to/brain") @@ -324,10 +305,9 @@ class Brain: embedder: Embeddings | None = None, skip_file_error: bool = False, processor_kwargs: dict[str, Any] | None = None, - ) -> Self: + ): """ Create a brain from a list of file paths. - Args: name (str): The name of the brain. file_paths (list[str | Path]): The list of file paths to add to the brain. @@ -337,10 +317,8 @@ class Brain: embedder (Embeddings | None): The embeddings used to create the index of the processed files. skip_file_error (bool): Whether to skip files that cannot be processed. processor_kwargs (dict[str, Any] | None): Additional arguments for the processor. - Returns: Brain: The brain created from the file paths. 
- Example: ```python brain = await Brain.afrom_files(name="My Brain", file_paths=["file1.pdf", "file2.pdf"]) @@ -429,7 +407,6 @@ class Brain: ) -> Self: """ Create a brain from a list of langchain documents. - Args: name (str): The name of the brain. langchain_documents (list[Document]): The list of langchain documents to add to the brain. @@ -437,10 +414,8 @@ class Brain: storage (StorageBase): The storage used to store the files. llm (LLMEndpoint | None): The language model used to generate the answer. embedder (Embeddings | None): The embeddings used to create the index of the processed files. - Returns: Brain: The brain created from the langchain documents. - Example: ```python from langchain_core.documents import Document @@ -449,6 +424,7 @@ class Brain: brain.print_info() ``` """ + if llm is None: llm = default_llm() @@ -481,16 +457,13 @@ class Brain: ) -> list[SearchResult]: """ Search for relevant documents in the brain based on a query. - Args: query (str | Document): The query to search for. n_results (int): The number of results to return. filter (Callable | Dict[str, Any] | None): The filter to apply to the search. fetch_n_neighbors (int): The number of neighbors to fetch. - Returns: list[SearchResult]: The list of retrieved chunks. - Example: ```python brain = Brain.from_files(name="My Brain", file_paths=["file1.pdf", "file2.pdf"]) @@ -517,57 +490,6 @@ class Brain: # add it to vectorstore raise NotImplementedError - def ask( - self, - question: str, - retrieval_config: RetrievalConfig | None = None, - rag_pipeline: Type[Union[QuivrQARAG, QuivrQARAGLangGraph]] | None = None, - list_files: list[QuivrKnowledge] | None = None, - chat_history: ChatHistory | None = None, - ) -> ParsedRAGResponse: - """ - Ask a question to the brain and get a generated answer. - - Args: - question (str): The question to ask. - retrieval_config (RetrievalConfig | None): The retrieval configuration (see RetrievalConfig docs). - rag_pipeline (Type[Union[QuivrQARAG, QuivrQARAGLangGraph]] | None): The RAG pipeline to use. - list_files (list[QuivrKnowledge] | None): The list of files to include in the RAG pipeline. - chat_history (ChatHistory | None): The chat history to use. - - Returns: - ParsedRAGResponse: The generated answer. - - Example: - ```python - brain = Brain.from_files(name="My Brain", file_paths=["file1.pdf", "file2.pdf"]) - answer = brain.ask("What is the meaning of life?") - print(answer.answer) - ``` - """ - async def collect_streamed_response(): - full_answer = "" - async for response in self.ask_streaming( - question=question, - retrieval_config=retrieval_config, - rag_pipeline=rag_pipeline, - list_files=list_files, - chat_history=chat_history - ): - full_answer += response.answer - return full_answer - - # Run the async function in the event loop - loop = asyncio.get_event_loop() - full_answer = loop.run_until_complete(collect_streamed_response()) - - chat_history = self.default_chat if chat_history is None else chat_history - chat_history.append(HumanMessage(content=question)) - chat_history.append(AIMessage(content=full_answer)) - - # Return the final response - return ParsedRAGResponse(answer=full_answer) - async def ask_streaming( self, question: str, @@ -578,24 +500,20 @@ class Brain: ) -> AsyncGenerator[ParsedRAGChunkResponse, ParsedRAGChunkResponse]: """ Ask a question to the brain and get a streamed generated answer. - Args: question (str): The question to ask. retrieval_config (RetrievalConfig | None): The retrieval configuration (see RetrievalConfig docs). 
rag_pipeline (Type[Union[QuivrQARAG, QuivrQARAGLangGraph]] | None): The RAG pipeline to use. - list_files (list[QuivrKnowledge] | None): The list of files to include in the RAG pipeline. + list_files (list[QuivrKnowledge] | None): The list of files to include in the RAG pipeline. chat_history (ChatHistory | None): The chat history to use. - Returns: AsyncGenerator[ParsedRAGChunkResponse, ParsedRAGChunkResponse]: The streamed generated answer. - Example: ```python brain = Brain.from_files(name="My Brain", file_paths=["file1.pdf", "file2.pdf"]) async for chunk in brain.ask_streaming("What is the meaning of life?"): print(chunk.answer) ``` - """ llm = self.llm @@ -630,3 +548,64 @@ class Brain: chat_history.append(AIMessage(content=full_answer)) yield response + async def aask( + self, + question: str, + retrieval_config: RetrievalConfig | None = None, + rag_pipeline: Type[Union[QuivrQARAG, QuivrQARAGLangGraph]] | None = None, + list_files: list[QuivrKnowledge] | None = None, + chat_history: ChatHistory | None = None, + ) -> ParsedRAGResponse: + """ + Synchronous version that asks a question to the brain and gets a generated answer. + Args: + question (str): The question to ask. + retrieval_config (RetrievalConfig | None): The retrieval configuration (see RetrievalConfig docs). + rag_pipeline (Type[Union[QuivrQARAG, QuivrQARAGLangGraph]] | None): The RAG pipeline to use. + list_files (list[QuivrKnowledge] | None): The list of files to include in the RAG pipeline. + chat_history (ChatHistory | None): The chat history to use. + Returns: + ParsedRAGResponse: The generated answer. + """ + full_answer = "" + + async for response in self.ask_streaming( + question=question, + retrieval_config=retrieval_config, + rag_pipeline=rag_pipeline, + list_files=list_files, + chat_history=chat_history, + ): + full_answer += response.answer + + return ParsedRAGResponse(answer=full_answer) + + def ask( + self, + question: str, + retrieval_config: RetrievalConfig | None = None, + rag_pipeline: Type[Union[QuivrQARAG, QuivrQARAGLangGraph]] | None = None, + list_files: list[QuivrKnowledge] | None = None, + chat_history: ChatHistory | None = None, + ) -> ParsedRAGResponse: + """ + Fully synchronous version that asks a question to the brain and gets a generated answer. + Args: + question (str): The question to ask. + retrieval_config (RetrievalConfig | None): The retrieval configuration (see RetrievalConfig docs). + rag_pipeline (Type[Union[QuivrQARAG, QuivrQARAGLangGraph]] | None): The RAG pipeline to use. + list_files (list[QuivrKnowledge] | None): The list of files to include in the RAG pipeline. + chat_history (ChatHistory | None): The chat history to use. + Returns: + ParsedRAGResponse: The generated answer. 
+ """ + loop = asyncio.get_event_loop() + return loop.run_until_complete( + self.aask( + question=question, + retrieval_config=retrieval_config, + rag_pipeline=rag_pipeline, + list_files=list_files, + chat_history=chat_history, + ) + ) diff --git a/core/quivr_core/brain/brain_defaults.py b/core/quivr_core/brain/brain_defaults.py index a0cf71cde..5e613447f 100644 --- a/core/quivr_core/brain/brain_defaults.py +++ b/core/quivr_core/brain/brain_defaults.py @@ -4,7 +4,7 @@ from langchain_core.documents import Document from langchain_core.embeddings import Embeddings from langchain_core.vectorstores import VectorStore -from quivr_core.config import LLMEndpointConfig +from quivr_core.rag.entities.config import DefaultModelSuppliers, LLMEndpointConfig from quivr_core.llm import LLMEndpoint logger = logging.getLogger("quivr_core") @@ -46,7 +46,9 @@ def default_embedder() -> Embeddings: def default_llm() -> LLMEndpoint: try: logger.debug("Loaded ChatOpenAI as default LLM for brain") - llm = LLMEndpoint.from_config(LLMEndpointConfig()) + llm = LLMEndpoint.from_config( + LLMEndpointConfig(supplier=DefaultModelSuppliers.OPENAI, model="gpt-4o") + ) return llm except ImportError as e: diff --git a/core/quivr_core/brain/serialization.py b/core/quivr_core/brain/serialization.py index 7b2764a1f..25af26c2f 100644 --- a/core/quivr_core/brain/serialization.py +++ b/core/quivr_core/brain/serialization.py @@ -4,9 +4,9 @@ from uuid import UUID from pydantic import BaseModel, Field, SecretStr -from quivr_core.config import LLMEndpointConfig +from quivr_core.rag.entities.config import LLMEndpointConfig +from quivr_core.rag.entities.models import ChatMessage from quivr_core.files.file import QuivrFileSerialized -from quivr_core.models import ChatMessage class EmbedderConfig(BaseModel): diff --git a/core/quivr_core/config.py b/core/quivr_core/config.py index 47b7f9ecd..2f001c443 100644 --- a/core/quivr_core/config.py +++ b/core/quivr_core/config.py @@ -1,15 +1,8 @@ -import os from enum import Enum -from typing import Dict, List, Optional -from uuid import UUID import yaml from pydantic import BaseModel -from quivr_core.base_config import QuivrBaseConfig -from quivr_core.processor.splitter import SplitterConfig -from quivr_core.prompts import CustomPromptsModel - class PdfParser(str, Enum): LLAMA_PARSE = "llama_parse" @@ -32,489 +25,3 @@ class MegaparseConfig(MegaparseBaseConfig): strategy: str = "fast" llama_parse_api_key: str | None = None pdf_parser: PdfParser = PdfParser.UNSTRUCTURED - - -class BrainConfig(QuivrBaseConfig): - brain_id: UUID | None = None - name: str - - @property - def id(self) -> UUID | None: - return self.brain_id - - -class DefaultRerankers(str, Enum): - """ - Enum representing the default API-based reranker suppliers supported by the application. - - This enum defines the various reranker providers that can be used in the system. - Each enum value corresponds to a specific supplier's identifier and has an - associated default model. - - Attributes: - COHERE (str): Represents Cohere AI as a reranker supplier. - JINA (str): Represents Jina AI as a reranker supplier. - - Methods: - default_model (property): Returns the default model for the selected supplier. - """ - - COHERE = "cohere" - JINA = "jina" - - @property - def default_model(self) -> str: - """ - Get the default model for the selected reranker supplier. - - This property method returns the default model associated with the current - reranker supplier (COHERE or JINA). 
- - Returns: - str: The name of the default model for the selected supplier. - - Raises: - KeyError: If the current enum value doesn't have a corresponding default model. - """ - # Mapping of suppliers to their default models - return { - self.COHERE: "rerank-multilingual-v3.0", - self.JINA: "jina-reranker-v2-base-multilingual", - }[self] - - -class DefaultModelSuppliers(str, Enum): - """ - Enum representing the default model suppliers supported by the application. - - This enum defines the various AI model providers that can be used as sources - for LLMs in the system. Each enum value corresponds to a specific - supplier's identifier. - - Attributes: - OPENAI (str): Represents OpenAI as a model supplier. - AZURE (str): Represents Azure (Microsoft) as a model supplier. - ANTHROPIC (str): Represents Anthropic as a model supplier. - META (str): Represents Meta as a model supplier. - MISTRAL (str): Represents Mistral AI as a model supplier. - GROQ (str): Represents Groq as a model supplier. - """ - - OPENAI = "openai" - AZURE = "azure" - ANTHROPIC = "anthropic" - META = "meta" - MISTRAL = "mistral" - GROQ = "groq" - - -class LLMConfig(QuivrBaseConfig): - context: int | None = None - tokenizer_hub: str | None = None - - -class LLMModelConfig: - _model_defaults: Dict[DefaultModelSuppliers, Dict[str, LLMConfig]] = { - DefaultModelSuppliers.OPENAI: { - "gpt-4o": LLMConfig(context=128000, tokenizer_hub="Xenova/gpt-4o"), - "gpt-4o-mini": LLMConfig(context=128000, tokenizer_hub="Xenova/gpt-4o"), - "gpt-4-turbo": LLMConfig(context=128000, tokenizer_hub="Xenova/gpt-4"), - "gpt-4": LLMConfig(context=8192, tokenizer_hub="Xenova/gpt-4"), - "gpt-3.5-turbo": LLMConfig( - context=16385, tokenizer_hub="Xenova/gpt-3.5-turbo" - ), - "text-embedding-3-large": LLMConfig( - context=8191, tokenizer_hub="Xenova/text-embedding-ada-002" - ), - "text-embedding-3-small": LLMConfig( - context=8191, tokenizer_hub="Xenova/text-embedding-ada-002" - ), - "text-embedding-ada-002": LLMConfig( - context=8191, tokenizer_hub="Xenova/text-embedding-ada-002" - ), - }, - DefaultModelSuppliers.ANTHROPIC: { - "claude-3-5-sonnet": LLMConfig( - context=200000, tokenizer_hub="Xenova/claude-tokenizer" - ), - "claude-3-opus": LLMConfig( - context=200000, tokenizer_hub="Xenova/claude-tokenizer" - ), - "claude-3-sonnet": LLMConfig( - context=200000, tokenizer_hub="Xenova/claude-tokenizer" - ), - "claude-3-haiku": LLMConfig( - context=200000, tokenizer_hub="Xenova/claude-tokenizer" - ), - "claude-2-1": LLMConfig( - context=200000, tokenizer_hub="Xenova/claude-tokenizer" - ), - "claude-2-0": LLMConfig( - context=100000, tokenizer_hub="Xenova/claude-tokenizer" - ), - "claude-instant-1-2": LLMConfig( - context=100000, tokenizer_hub="Xenova/claude-tokenizer" - ), - }, - DefaultModelSuppliers.META: { - "llama-3.1": LLMConfig( - context=128000, tokenizer_hub="Xenova/Meta-Llama-3.1-Tokenizer" - ), - "llama-3": LLMConfig( - context=8192, tokenizer_hub="Xenova/llama3-tokenizer-new" - ), - "llama-2": LLMConfig(context=4096, tokenizer_hub="Xenova/llama2-tokenizer"), - "code-llama": LLMConfig( - context=16384, tokenizer_hub="Xenova/llama-code-tokenizer" - ), - }, - DefaultModelSuppliers.GROQ: { - "llama-3.1": LLMConfig( - context=128000, tokenizer_hub="Xenova/Meta-Llama-3.1-Tokenizer" - ), - "llama-3": LLMConfig( - context=8192, tokenizer_hub="Xenova/llama3-tokenizer-new" - ), - "llama-2": LLMConfig(context=4096, tokenizer_hub="Xenova/llama2-tokenizer"), - "code-llama": LLMConfig( - context=16384, tokenizer_hub="Xenova/llama-code-tokenizer" - ), - }, - 
DefaultModelSuppliers.MISTRAL: { - "mistral-large": LLMConfig( - context=128000, tokenizer_hub="Xenova/mistral-tokenizer-v3" - ), - "mistral-small": LLMConfig( - context=128000, tokenizer_hub="Xenova/mistral-tokenizer-v3" - ), - "mistral-nemo": LLMConfig( - context=128000, tokenizer_hub="Xenova/Mistral-Nemo-Instruct-Tokenizer" - ), - "codestral": LLMConfig( - context=32000, tokenizer_hub="Xenova/mistral-tokenizer-v3" - ), - }, - } - - @classmethod - def get_supplier_by_model_name(cls, model: str) -> DefaultModelSuppliers | None: - # Iterate over the suppliers and their models - for supplier, models in cls._model_defaults.items(): - # Check if the model name or a base part of the model name is in the supplier's models - for base_model_name in models: - if model.startswith(base_model_name): - return supplier - # Return None if no supplier matches the model name - return None - - @classmethod - def get_llm_model_config( - cls, supplier: DefaultModelSuppliers, model_name: str - ) -> Optional[LLMConfig]: - """Retrieve the LLMConfig (context and tokenizer_hub) for a given supplier and model.""" - supplier_defaults = cls._model_defaults.get(supplier) - if not supplier_defaults: - return None - - # Use startswith logic for matching model names - for key, config in supplier_defaults.items(): - if model_name.startswith(key): - return config - - return None - - -class LLMEndpointConfig(QuivrBaseConfig): - """ - Configuration class for Large Language Models (LLM) endpoints. - - This class defines the settings and parameters for interacting with various LLM providers. - It includes configuration for the model, API keys, token limits, and other relevant settings. - - Attributes: - supplier (DefaultModelSuppliers): The LLM provider (default: OPENAI). - model (str): The specific model to use (default: "gpt-4o"). - context_length (int | None): The maximum context length for the model. - tokenizer_hub (str | None): The tokenizer to use for this model. - llm_base_url (str | None): Base URL for the LLM API. - env_variable_name (str): Name of the environment variable for the API key. - llm_api_key (str | None): The API key for the LLM provider. - max_input_tokens (int): Maximum number of input tokens sent to the LLM (default: 2000). - max_output_tokens (int): Maximum number of output tokens returned by the LLM (default: 2000). - temperature (float): Temperature setting for text generation (default: 0.7). - streaming (bool): Whether to use streaming for responses (default: True). - prompt (CustomPromptsModel | None): Custom prompt configuration. - """ - - supplier: DefaultModelSuppliers = DefaultModelSuppliers.OPENAI - model: str = "gpt-4o" - context_length: int | None = None - tokenizer_hub: str | None = None - llm_base_url: str | None = None - env_variable_name: str = f"{supplier.upper()}_API_KEY" - llm_api_key: str | None = None - max_input_tokens: int = 2000 - max_output_tokens: int = 2000 - temperature: float = 0.7 - streaming: bool = True - prompt: CustomPromptsModel | None = None - - _FALLBACK_TOKENIZER = "cl100k_base" - - @property - def fallback_tokenizer(self) -> str: - """ - Get the fallback tokenizer. - - Returns: - str: The name of the fallback tokenizer. - """ - return self._FALLBACK_TOKENIZER - - def __init__(self, **data): - """ - Initialize the LLMEndpointConfig. - - This method sets up the initial configuration, including setting the LLM model - config and API key. 
- - """ - super().__init__(**data) - self.set_llm_model_config() - self.set_api_key() - - def set_api_key(self, force_reset: bool = False): - """ - Set the API key for the LLM provider. - - This method attempts to set the API key from the environment variable. - If the key is not found, it raises a ValueError. - - Args: - force_reset (bool): If True, forces a reset of the API key even if already set. - - Raises: - ValueError: If the API key is not set in the environment. - """ - if not self.llm_api_key or force_reset: - self.llm_api_key = os.getenv(self.env_variable_name) - - if not self.llm_api_key: - raise ValueError( - f"The API key for supplier '{self.supplier}' is not set. " - f"Please set the environment variable: {self.env_variable_name}" - ) - - def set_llm_model_config(self): - """ - Set the LLM model configuration. - - This method automatically sets the context_length and tokenizer_hub - based on the current supplier and model. - """ - llm_model_config = LLMModelConfig.get_llm_model_config( - self.supplier, self.model - ) - if llm_model_config: - self.context_length = llm_model_config.context - self.tokenizer_hub = llm_model_config.tokenizer_hub - - def set_llm_model(self, model: str): - """ - Set the LLM model and update related configurations. - - This method updates the supplier and model based on the provided model name, - then updates the model config and API key accordingly. - - Args: - model (str): The name of the model to set. - - Raises: - ValueError: If no corresponding supplier is found for the given model. - """ - supplier = LLMModelConfig.get_supplier_by_model_name(model) - if supplier is None: - raise ValueError( - f"Cannot find the corresponding supplier for model {model}" - ) - self.supplier = supplier - self.model = model - - self.set_llm_model_config() - self.set_api_key(force_reset=True) - - def set_from_sqlmodel(self, sqlmodel: BaseModel, mapping: Dict[str, str]): - """ - Set attributes in LLMEndpointConfig from SQLModel attributes using a field mapping. - - This method allows for dynamic setting of LLMEndpointConfig attributes based on - a provided SQLModel instance and a mapping dictionary. - - Args: - sqlmodel (SQLModel): An instance of the SQLModel class. - mapping (Dict[str, str]): A dictionary that maps SQLModel fields to LLMEndpointConfig fields. - Example: {"max_input": "max_input_tokens", "env_variable_name": "env_variable_name"} - - Raises: - AttributeError: If any field in the mapping doesn't exist in either the SQLModel or LLMEndpointConfig. - """ - for model_field, llm_field in mapping.items(): - if hasattr(sqlmodel, model_field) and hasattr(self, llm_field): - setattr(self, llm_field, getattr(sqlmodel, model_field)) - else: - raise AttributeError( - f"Invalid mapping: {model_field} or {llm_field} does not exist." - ) - - -# Cannot use Pydantic v2 field_validator because of conflicts with pydantic v1 still in use in LangChain -class RerankerConfig(QuivrBaseConfig): - """ - Configuration class for reranker models. - - This class defines the settings for reranker models used in the application, - including the supplier, model, and API key information. - - Attributes: - supplier (DefaultRerankers | None): The reranker supplier (e.g., COHERE). - model (str | None): The specific reranker model to use. - top_n (int): The number of top chunks returned by the reranker (default: 5). - api_key (str | None): The API key for the reranker service. 
- """ - - supplier: DefaultRerankers | None = None - model: str | None = None - top_n: int = 5 - api_key: str | None = None - - def __init__(self, **data): - """ - Initialize the RerankerConfig. - """ - super().__init__(**data) - self.validate_model() - - def validate_model(self): - """ - Validate and set up the reranker model configuration. - - This method ensures that a model is set (using the default if not provided) - and that the necessary API key is available in the environment. - - Raises: - ValueError: If the required API key is not set in the environment. - """ - if self.model is None and self.supplier is not None: - self.model = self.supplier.default_model - - if self.supplier: - api_key_var = f"{self.supplier.upper()}_API_KEY" - self.api_key = os.getenv(api_key_var) - - if self.api_key is None: - raise ValueError( - f"The API key for supplier '{self.supplier}' is not set. " - f"Please set the environment variable: {api_key_var}" - ) - - -class NodeConfig(QuivrBaseConfig): - """ - Configuration class for a node in an AI assistant workflow. - - This class represents a single node in a workflow configuration, - defining its name and connections to other nodes. - - Attributes: - name (str): The name of the node. - edges (List[str]): List of names of other nodes this node links to. - """ - - name: str - edges: List[str] - - -class WorkflowConfig(QuivrBaseConfig): - """ - Configuration class for an AI assistant workflow. - - This class represents the entire workflow configuration, - consisting of multiple interconnected nodes. - - Attributes: - name (str): The name of the workflow. - nodes (List[NodeConfig]): List of nodes in the workflow. - """ - - name: str - nodes: List[NodeConfig] - - -class RetrievalConfig(QuivrBaseConfig): - """ - Configuration class for the retrieval phase of a RAG assistant. - - This class defines the settings for the retrieval process, - including reranker and LLM configurations, as well as various limits and prompts. - - Attributes: - workflow_config (WorkflowConfig | None): Configuration for the workflow. - reranker_config (RerankerConfig): Configuration for the reranker. - llm_config (LLMEndpointConfig): Configuration for the LLM endpoint. - max_history (int): Maximum number of past conversation turns to pass to the LLM as context (default: 10). - max_files (int): Maximum number of files to process (default: 20). - prompt (str | None): Custom prompt for the retrieval process. - """ - - workflow_config: WorkflowConfig | None = None - reranker_config: RerankerConfig = RerankerConfig() - llm_config: LLMEndpointConfig = LLMEndpointConfig() - max_history: int = 10 - max_files: int = 20 - prompt: str | None = None - - -class ParserConfig(QuivrBaseConfig): - """ - Configuration class for the parser. - - This class defines the settings for the parsing process, - including configurations for the text splitter and Megaparse. - - Attributes: - splitter_config (SplitterConfig): Configuration for the text splitter. - megaparse_config (MegaparseConfig): Configuration for Megaparse. - """ - - splitter_config: SplitterConfig = SplitterConfig() - megaparse_config: MegaparseConfig = MegaparseConfig() - - -class IngestionConfig(QuivrBaseConfig): - """ - Configuration class for the data ingestion process. - - This class defines the settings for the data ingestion process, - including the parser configuration. - - Attributes: - parser_config (ParserConfig): Configuration for the parser. 
- """ - - parser_config: ParserConfig = ParserConfig() - - -class AssistantConfig(QuivrBaseConfig): - """ - Configuration class for an AI assistant. - - This class defines the overall configuration for an AI assistant, - including settings for retrieval and ingestion processes. - - Attributes: - retrieval_config (RetrievalConfig): Configuration for the retrieval process. - ingestion_config (IngestionConfig): Configuration for the ingestion process. - """ - - retrieval_config: RetrievalConfig = RetrievalConfig() - ingestion_config: IngestionConfig = IngestionConfig() diff --git a/core/quivr_core/llm/llm_endpoint.py b/core/quivr_core/llm/llm_endpoint.py index e26c0e6bf..d4ffcf87a 100644 --- a/core/quivr_core/llm/llm_endpoint.py +++ b/core/quivr_core/llm/llm_endpoint.py @@ -10,8 +10,8 @@ from langchain_openai import AzureChatOpenAI, ChatOpenAI from pydantic.v1 import SecretStr from quivr_core.brain.info import LLMInfo -from quivr_core.config import DefaultModelSuppliers, LLMEndpointConfig -from quivr_core.utils import model_supports_function_calling +from quivr_core.rag.entities.config import DefaultModelSuppliers, LLMEndpointConfig +from quivr_core.rag.utils import model_supports_function_calling logger = logging.getLogger("quivr_core") @@ -70,6 +70,7 @@ class LLMEndpoint: else None, azure_endpoint=azure_endpoint, max_tokens=config.max_output_tokens, + temperature=config.temperature, ) elif config.supplier == DefaultModelSuppliers.ANTHROPIC: _llm = ChatAnthropic( @@ -79,6 +80,7 @@ class LLMEndpoint: else None, base_url=config.llm_base_url, max_tokens=config.max_output_tokens, + temperature=config.temperature, ) elif config.supplier == DefaultModelSuppliers.OPENAI: _llm = ChatOpenAI( @@ -88,6 +90,7 @@ class LLMEndpoint: else None, base_url=config.llm_base_url, max_tokens=config.max_output_tokens, + temperature=config.temperature, ) else: _llm = ChatOpenAI( @@ -97,6 +100,7 @@ class LLMEndpoint: else None, base_url=config.llm_base_url, max_tokens=config.max_output_tokens, + temperature=config.temperature, ) return cls(llm=_llm, llm_config=config) @@ -118,3 +122,7 @@ class LLMEndpoint: max_tokens=self._config.max_output_tokens, supports_function_calling=self.supports_func_calling(), ) + + def clone_llm(self): + """Create a new instance of the LLM with the same configuration.""" + return self._llm.__class__(**self._llm.__dict__) diff --git a/core/quivr_core/llm_tools/__init__.py b/core/quivr_core/llm_tools/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/core/quivr_core/llm_tools/entity.py b/core/quivr_core/llm_tools/entity.py new file mode 100644 index 000000000..449377a3a --- /dev/null +++ b/core/quivr_core/llm_tools/entity.py @@ -0,0 +1,36 @@ +from quivr_core.base_config import QuivrBaseConfig +from typing import Callable +from langchain_core.tools import BaseTool +from typing import Dict, Any + + +class ToolsCategory(QuivrBaseConfig): + name: str + description: str + tools: list + default_tool: str | None = None + create_tool: Callable + + def __init__(self, **data): + super().__init__(**data) + self.name = self.name.lower() + + +class ToolWrapper: + def __init__(self, tool: BaseTool, format_input: Callable, format_output: Callable): + self.tool = tool + self.format_input = format_input + self.format_output = format_output + + +class ToolRegistry: + def __init__(self): + self._registry = {} + + def register_tool(self, tool_name: str, create_func: Callable): + self._registry[tool_name] = create_func + + def create_tool(self, tool_name: str, config: Dict[str, Any]) -> 
ToolWrapper: + if tool_name not in self._registry: + raise ValueError(f"Tool {tool_name} is not supported.") + return self._registry[tool_name](config) diff --git a/core/quivr_core/llm_tools/llm_tools.py b/core/quivr_core/llm_tools/llm_tools.py new file mode 100644 index 000000000..6e35bdcdc --- /dev/null +++ b/core/quivr_core/llm_tools/llm_tools.py @@ -0,0 +1,33 @@ +from typing import Dict, Any, Type, Union + +from quivr_core.llm_tools.entity import ToolWrapper + +from quivr_core.llm_tools.web_search_tools import ( + WebSearchTools, +) + +from quivr_core.llm_tools.other_tools import ( + OtherTools, +) + +TOOLS_CATEGORIES = { + WebSearchTools.name: WebSearchTools, + OtherTools.name: OtherTools, +} + +# Register all ToolsList enums +TOOLS_LISTS = { + **{tool.value: tool for tool in WebSearchTools.tools}, + **{tool.value: tool for tool in OtherTools.tools}, +} + + +class LLMToolFactory: + @staticmethod + def create_tool(tool_name: str, config: Dict[str, Any]) -> Union[ToolWrapper, Type]: + for category, tools_class in TOOLS_CATEGORIES.items(): + if tool_name in tools_class.tools: + return tools_class.create_tool(tool_name, config) + elif tool_name.lower() == category and tools_class.default_tool: + return tools_class.create_tool(tools_class.default_tool, config) + raise ValueError(f"Tool {tool_name} is not supported.") diff --git a/core/quivr_core/llm_tools/other_tools.py b/core/quivr_core/llm_tools/other_tools.py new file mode 100644 index 000000000..3efa8b5b9 --- /dev/null +++ b/core/quivr_core/llm_tools/other_tools.py @@ -0,0 +1,24 @@ +from enum import Enum +from typing import Dict, Any, Type, Union +from langchain_core.tools import BaseTool +from quivr_core.llm_tools.entity import ToolsCategory +from quivr_core.rag.entities.models import cited_answer + + +class OtherToolsList(str, Enum): + CITED_ANSWER = "cited_answer" + + +def create_other_tool(tool_name: str, config: Dict[str, Any]) -> Union[BaseTool, Type]: + if tool_name == OtherToolsList.CITED_ANSWER: + return cited_answer + else: + raise ValueError(f"Tool {tool_name} is not supported.") + + +OtherTools = ToolsCategory( + name="Other", + description="Other tools", + tools=[OtherToolsList.CITED_ANSWER], + create_tool=create_other_tool, +) diff --git a/core/quivr_core/llm_tools/web_search_tools.py b/core/quivr_core/llm_tools/web_search_tools.py new file mode 100644 index 000000000..fed56ce38 --- /dev/null +++ b/core/quivr_core/llm_tools/web_search_tools.py @@ -0,0 +1,73 @@ +from enum import Enum +from typing import Dict, List, Any +from langchain_community.tools import TavilySearchResults +from langchain_community.utilities.tavily_search import TavilySearchAPIWrapper +from quivr_core.llm_tools.entity import ToolsCategory +import os +from pydantic.v1 import SecretStr as SecretStrV1 # Ensure correct import +from quivr_core.llm_tools.entity import ToolWrapper, ToolRegistry +from langchain_core.documents import Document + + +class WebSearchToolsList(str, Enum): + TAVILY = "tavily" + + +def create_tavily_tool(config: Dict[str, Any]) -> ToolWrapper: + api_key = ( + config.pop("api_key") if "api_key" in config else os.getenv("TAVILY_API_KEY") + ) + if not api_key: + raise ValueError( + "Missing required config key 'api_key' or environment variable 'TAVILY_API_KEY'" + ) + + tavily_api_wrapper = TavilySearchAPIWrapper( + tavily_api_key=SecretStrV1(api_key), + ) + tool = TavilySearchResults( + api_wrapper=tavily_api_wrapper, + max_results=config.pop("max_results", 5), + search_depth=config.pop("search_depth", "advanced"), + 
include_answer=config.pop("include_answer", True), + **config, + ) + + tool.name = WebSearchToolsList.TAVILY.value + + def format_input(task: str) -> Dict[str, Any]: + return {"query": task} + + def format_output(response: Any) -> List[Document]: + metadata = {"integration": "", "integration_link": ""} + return [ + Document( + page_content=d["content"], + metadata={ + **metadata, + "file_name": d["url"], + "original_file_name": d["url"], + }, + ) + for d in response + ] + + return ToolWrapper(tool, format_input, format_output) + + +# Initialize the registry and register tools +web_search_tool_registry = ToolRegistry() +web_search_tool_registry.register_tool(WebSearchToolsList.TAVILY, create_tavily_tool) + + +def create_web_search_tool(tool_name: str, config: Dict[str, Any]) -> ToolWrapper: + return web_search_tool_registry.create_tool(tool_name, config) + + +WebSearchTools = ToolsCategory( + name="Web Search", + description="Tools for web searching", + tools=[WebSearchToolsList.TAVILY], + default_tool=WebSearchToolsList.TAVILY, + create_tool=create_web_search_tool, +) diff --git a/core/quivr_core/prompts.py b/core/quivr_core/prompts.py deleted file mode 100644 index 48ec90a05..000000000 --- a/core/quivr_core/prompts.py +++ /dev/null @@ -1,119 +0,0 @@ -import datetime - -from langchain_core.prompts import ( - ChatPromptTemplate, - HumanMessagePromptTemplate, - MessagesPlaceholder, - PromptTemplate, - SystemMessagePromptTemplate, -) -from langchain_core.prompts.base import BasePromptTemplate -from pydantic import ConfigDict, create_model - - -class CustomPromptsDict(dict): - def __init__(self, type, *args, **kwargs): - super().__init__(*args, **kwargs) - self._type = type - - def __setitem__(self, key, value): - # Automatically convert the value into a tuple (my_type, value) - super().__setitem__(key, (self._type, value)) - - -def _define_custom_prompts() -> CustomPromptsDict: - custom_prompts: CustomPromptsDict = CustomPromptsDict(type=BasePromptTemplate) - - today_date = datetime.datetime.now().strftime("%B %d, %Y") - - # --------------------------------------------------------------------------- - # Prompt for question rephrasing - # --------------------------------------------------------------------------- - _template = """Given the following conversation and a follow up question, rephrase the follow up question to be a standalone question, in its original language. Keep as much details as possible from previous messages. Keep entity names and all. - - Chat History: - {chat_history} - Follow Up Input: {question} - Standalone question:""" - - CONDENSE_QUESTION_PROMPT = PromptTemplate.from_template(_template) - custom_prompts["CONDENSE_QUESTION_PROMPT"] = CONDENSE_QUESTION_PROMPT - - # --------------------------------------------------------------------------- - # Prompt for RAG - # --------------------------------------------------------------------------- - system_message_template = ( - f"Your name is Quivr. You're a helpful assistant. Today's date is {today_date}." - ) - - system_message_template += """ - When answering use markdown. - Use markdown code blocks for code snippets. - Answer in a concise and clear manner. - Use the following pieces of context from files provided by the user to answer the users. - Answer in the same language as the user question. - If you don't know the answer with the context provided from the files, just say that you don't know, don't try to make up an answer. 
- Don't cite the source id in the answer objects, but you can use the source to answer the question. - You have access to the files to answer the user question (limited to first 20 files): - {files} - - If not None, User instruction to follow to answer: {custom_instructions} - Don't cite the source id in the answer objects, but you can use the source to answer the question. - """ - - template_answer = """ - Context: - {context} - - User Question: {question} - Answer: - """ - - RAG_ANSWER_PROMPT = ChatPromptTemplate.from_messages( - [ - SystemMessagePromptTemplate.from_template(system_message_template), - HumanMessagePromptTemplate.from_template(template_answer), - ] - ) - custom_prompts["RAG_ANSWER_PROMPT"] = RAG_ANSWER_PROMPT - - # --------------------------------------------------------------------------- - # Prompt for formatting documents - # --------------------------------------------------------------------------- - DEFAULT_DOCUMENT_PROMPT = PromptTemplate.from_template( - template="Source: {index} \n {page_content}" - ) - custom_prompts["DEFAULT_DOCUMENT_PROMPT"] = DEFAULT_DOCUMENT_PROMPT - - # --------------------------------------------------------------------------- - # Prompt for chatting directly with LLMs, without any document retrieval stage - # --------------------------------------------------------------------------- - system_message_template = ( - f"Your name is Quivr. You're a helpful assistant. Today's date is {today_date}." - ) - system_message_template += """ - If not None, also follow these user instructions when answering: {custom_instructions} - """ - - template_answer = """ - User Question: {question} - Answer: - """ - CHAT_LLM_PROMPT = ChatPromptTemplate.from_messages( - [ - SystemMessagePromptTemplate.from_template(system_message_template), - MessagesPlaceholder(variable_name="chat_history"), - HumanMessagePromptTemplate.from_template(template_answer), - ] - ) - custom_prompts["CHAT_LLM_PROMPT"] = CHAT_LLM_PROMPT - - return custom_prompts - - -_custom_prompts = _define_custom_prompts() -CustomPromptsModel = create_model( - "CustomPromptsModel", **_custom_prompts, __config__=ConfigDict(extra="forbid") -) - -custom_prompts = CustomPromptsModel() diff --git a/core/quivr_core/quivr_rag_langgraph.py b/core/quivr_core/quivr_rag_langgraph.py deleted file mode 100644 index 12d0bea45..000000000 --- a/core/quivr_core/quivr_rag_langgraph.py +++ /dev/null @@ -1,488 +0,0 @@ -import logging -from enum import Enum -from typing import Annotated, AsyncGenerator, Optional, Sequence, TypedDict -from uuid import uuid4 - -# TODO(@aminediro): this is the only dependency to langchain package, we should remove it -from langchain.retrievers import ContextualCompressionRetriever -from langchain_cohere import CohereRerank -from langchain_community.document_compressors import JinaRerank -from langchain_core.callbacks import Callbacks -from langchain_core.documents import BaseDocumentCompressor, Document -from langchain_core.messages import BaseMessage -from langchain_core.messages.ai import AIMessageChunk -from langchain_core.vectorstores import VectorStore -from langgraph.graph import END, START, StateGraph -from langgraph.graph.message import add_messages - -from quivr_core.chat import ChatHistory -from quivr_core.config import DefaultRerankers, RetrievalConfig -from quivr_core.llm import LLMEndpoint -from quivr_core.models import ( - ParsedRAGChunkResponse, - ParsedRAGResponse, - QuivrKnowledge, - RAGResponseMetadata, - cited_answer, -) -from quivr_core.prompts import custom_prompts 
-from quivr_core.utils import ( - combine_documents, - format_file_list, - get_chunk_metadata, - parse_chunk_response, - parse_response, -) - -logger = logging.getLogger("quivr_core") - - -class SpecialEdges(str, Enum): - START = "START" - END = "END" - - -class AgentState(TypedDict): - # The add_messages function defines how an update should be processed - # Default is to replace. add_messages says "append" - messages: Annotated[Sequence[BaseMessage], add_messages] - chat_history: ChatHistory - docs: list[Document] - files: str - final_response: dict - - -class IdempotentCompressor(BaseDocumentCompressor): - def compress_documents( - self, - documents: Sequence[Document], - query: str, - callbacks: Optional[Callbacks] = None, - ) -> Sequence[Document]: - """ - A no-op document compressor that simply returns the documents it is given. - - This is a placeholder until a more sophisticated document compression - algorithm is implemented. - """ - return documents - - -class QuivrQARAGLangGraph: - def __init__( - self, - *, - retrieval_config: RetrievalConfig, - llm: LLMEndpoint, - vector_store: VectorStore | None = None, - reranker: BaseDocumentCompressor | None = None, - ): - """ - Construct a QuivrQARAGLangGraph object. - - Args: - retrieval_config (RetrievalConfig): The configuration for the RAG model. - llm (LLMEndpoint): The LLM to use for generating text. - vector_store (VectorStore): The vector store to use for storing and retrieving documents. - reranker (BaseDocumentCompressor | None): The document compressor to use for re-ranking documents. Defaults to IdempotentCompressor if not provided. - """ - self.retrieval_config = retrieval_config - self.vector_store = vector_store - self.llm_endpoint = llm - - self.graph = None - - if reranker is not None: - self.reranker = reranker - elif self.retrieval_config.reranker_config.supplier == DefaultRerankers.COHERE: - self.reranker = CohereRerank( - model=self.retrieval_config.reranker_config.model, - top_n=self.retrieval_config.reranker_config.top_n, - cohere_api_key=self.retrieval_config.reranker_config.api_key, - ) - elif self.retrieval_config.reranker_config.supplier == DefaultRerankers.JINA: - self.reranker = JinaRerank( - model=self.retrieval_config.reranker_config.model, - top_n=self.retrieval_config.reranker_config.top_n, - jina_api_key=self.retrieval_config.reranker_config.api_key, - ) - else: - self.reranker = IdempotentCompressor() - - if self.vector_store: - self.compression_retriever = ContextualCompressionRetriever( - base_compressor=self.reranker, base_retriever=self.retriever - ) - - @property - def retriever(self): - """ - Returns a retriever that can retrieve documents from the vector store. - - Returns: - VectorStoreRetriever: The retriever. - """ - if self.vector_store: - return self.vector_store.as_retriever() - else: - raise ValueError("No vector store provided") - - def filter_history(self, state: AgentState) -> dict: - """ - Filter out the chat history to only include the messages that are relevant to the current question - - Takes in a chat_history= [HumanMessage(content='Qui est Chloé ? 
'), - AIMessage(content="Chloé est une salariée travaillant pour l'entreprise Quivr en tant qu'AI Engineer, - sous la direction de son supérieur hiérarchique, Stanislas Girard."), - HumanMessage(content='Dis moi en plus sur elle'), AIMessage(content=''), - HumanMessage(content='Dis moi en plus sur elle'), - AIMessage(content="Désolé, je n'ai pas d'autres informations sur Chloé à partir des fichiers fournis.")] - Returns a filtered chat_history with in priority: first max_tokens, then max_history where a Human message and an AI message count as one pair - a token is 4 characters - """ - chat_history = state["chat_history"] - total_tokens = 0 - total_pairs = 0 - _chat_id = uuid4() - _chat_history = ChatHistory(chat_id=_chat_id, brain_id=chat_history.brain_id) - for human_message, ai_message in reversed(list(chat_history.iter_pairs())): - # TODO: replace with tiktoken - message_tokens = self.llm_endpoint.count_tokens( - human_message.content - ) + self.llm_endpoint.count_tokens(ai_message.content) - if ( - total_tokens + message_tokens - > self.retrieval_config.llm_config.max_output_tokens - or total_pairs >= self.retrieval_config.max_history - ): - break - _chat_history.append(human_message) - _chat_history.append(ai_message) - total_tokens += message_tokens - total_pairs += 1 - - return {"chat_history": _chat_history} - - ### Nodes - def rewrite(self, state): - """ - Transform the query to produce a better question. - - Args: - state (messages): The current state - - Returns: - dict: The updated state with re-phrased question - """ - - # Grader - msg = custom_prompts.CONDENSE_QUESTION_PROMPT.format( - chat_history=state["chat_history"], - question=state["messages"][0].content, - ) - - model = self.llm_endpoint._llm - response = model.invoke(msg) - return {"messages": [response]} - - def retrieve(self, state): - """ - Retrieve relevent chunks - - Args: - state (messages): The current state - - Returns: - dict: The retrieved chunks - """ - question = state["messages"][-1].content - docs = self.compression_retriever.invoke(question) - return {"docs": docs} - - def generate_rag(self, state): - """ - Generate answer - - Args: - state (messages): The current state - - Returns: - dict: The updated state with re-phrased question - """ - messages = state["messages"] - user_question = messages[0].content - files = state["files"] - - docs = state["docs"] - - # Prompt - prompt = self.retrieval_config.prompt - - final_inputs = {} - final_inputs["context"] = combine_documents(docs) if docs else "None" - final_inputs["question"] = user_question - final_inputs["custom_instructions"] = prompt if prompt else "None" - final_inputs["files"] = files if files else "None" - - # LLM - llm = self.llm_endpoint._llm - if self.llm_endpoint.supports_func_calling(): - llm = self.llm_endpoint._llm.bind_tools( - [cited_answer], - tool_choice="any", - ) - - # Chain - rag_chain = custom_prompts.RAG_ANSWER_PROMPT | llm - - # Run - response = rag_chain.invoke(final_inputs) - formatted_response = { - "answer": response, # Assuming the last message contains the final answer - "docs": docs, - } - return {"messages": [response], "final_response": formatted_response} - - def generate_chat_llm(self, state): - """ - Generate answer - - Args: - state (messages): The current state - - Returns: - dict: The updated state with re-phrased question - """ - messages = state["messages"] - user_question = messages[0].content - - # Prompt - prompt = self.retrieval_config.prompt - - final_inputs = {} - final_inputs["question"] = user_question - 
final_inputs["custom_instructions"] = prompt if prompt else "None" - final_inputs["chat_history"] = state["chat_history"].to_list() - - # LLM - llm = self.llm_endpoint._llm - - # Chain - rag_chain = custom_prompts.CHAT_LLM_PROMPT | llm - - # Run - response = rag_chain.invoke(final_inputs) - formatted_response = { - "answer": response, # Assuming the last message contains the final answer - } - return {"messages": [response], "final_response": formatted_response} - - def build_chain(self): - """ - Builds the langchain chain for the given configuration. - - Returns: - Callable[[Dict], Dict]: The langchain chain. - """ - if not self.graph: - self.graph = self.create_graph() - - return self.graph - - def create_graph(self): - """ - Builds the langchain chain for the given configuration. - - This function creates a state machine which takes a chat history and a question - and produces an answer. The state machine consists of the following states: - - - filter_history: Filter the chat history (i.e., remove the last message) - - rewrite: Re-write the question using the filtered history - - retrieve: Retrieve documents related to the re-written question - - generate: Generate an answer using the retrieved documents - - The state machine starts in the filter_history state and transitions as follows: - filter_history -> rewrite -> retrieve -> generate -> END - - The final answer is returned as a dictionary with the answer and the list of documents - used to generate the answer. - - Returns: - Callable[[Dict], Dict]: The langchain chain. - """ - workflow = StateGraph(AgentState) - - if self.retrieval_config.workflow_config: - if SpecialEdges.START not in [ - node.name for node in self.retrieval_config.workflow_config.nodes - ]: - raise ValueError("The workflow should contain a 'START' node") - for node in self.retrieval_config.workflow_config.nodes: - if node.name not in SpecialEdges._value2member_map_: - workflow.add_node(node.name, getattr(self, node.name)) - - for node in self.retrieval_config.workflow_config.nodes: - for edge in node.edges: - if node.name == SpecialEdges.START: - workflow.add_edge(START, edge) - elif edge == SpecialEdges.END: - workflow.add_edge(node.name, END) - else: - workflow.add_edge(node.name, edge) - else: - # Define the nodes we will cycle between - workflow.add_node("filter_history", self.filter_history) - workflow.add_node("rewrite", self.rewrite) # Re-writing the question - workflow.add_node("retrieve", self.retrieve) # retrieval - workflow.add_node("generate", self.generate_rag) - - # Add node for filtering history - - workflow.set_entry_point("filter_history") - workflow.add_edge("filter_history", "rewrite") - workflow.add_edge("rewrite", "retrieve") - workflow.add_edge("retrieve", "generate") - workflow.add_edge( - "generate", END - ) # Add edge from generate to format_response - - # Compile - graph = workflow.compile() - return graph - - def answer( - self, - question: str, - history: ChatHistory, - list_files: list[QuivrKnowledge], - metadata: dict[str, str] = {}, - ) -> ParsedRAGResponse: - """ - Answer a question using the langgraph chain. - - Args: - question (str): The question to answer. - history (ChatHistory): The chat history to use for context. - list_files (list[QuivrKnowledge]): The list of files to use for retrieval. - metadata (dict[str, str], optional): The metadata to pass to the langchain invocation. Defaults to {}. - - Returns: - ParsedRAGResponse: The answer to the question. 
- """ - concat_list_files = format_file_list( - list_files, self.retrieval_config.max_files - ) - conversational_qa_chain = self.build_chain() - inputs = { - "messages": [ - ("user", question), - ], - "chat_history": history, - "files": concat_list_files, - } - raw_llm_response = conversational_qa_chain.invoke( - inputs, - config={"metadata": metadata}, - ) - response = parse_response( - raw_llm_response["final_response"], self.retrieval_config.llm_config.model - ) - return response - - async def answer_astream( - self, - question: str, - history: ChatHistory, - list_files: list[QuivrKnowledge], - metadata: dict[str, str] = {}, - ) -> AsyncGenerator[ParsedRAGChunkResponse, ParsedRAGChunkResponse]: - """ - Answer a question using the langgraph chain and yield each chunk of the answer separately. - - Args: - question (str): The question to answer. - history (ChatHistory): The chat history to use for context. - list_files (list[QuivrKnowledge]): The list of files to use for retrieval. - metadata (dict[str, str], optional): The metadata to pass to the langchain invocation. Defaults to {}. - - Yields: - ParsedRAGChunkResponse: Each chunk of the answer. - """ - concat_list_files = format_file_list( - list_files, self.retrieval_config.max_files - ) - conversational_qa_chain = self.build_chain() - - rolling_message = AIMessageChunk(content="") - sources: list[Document] | None = None - prev_answer = "" - chunk_id = 0 - - async for event in conversational_qa_chain.astream_events( - { - "messages": [ - ("user", question), - ], - "chat_history": history, - "files": concat_list_files, - }, - version="v1", - config={"metadata": metadata}, - ): - kind = event["event"] - if ( - not sources - and "output" in event["data"] - and "docs" in event["data"]["output"] - ): - sources = event["data"]["output"]["docs"] - - if ( - kind == "on_chat_model_stream" - and "generate" in event["metadata"]["langgraph_node"] - ): - chunk = event["data"]["chunk"] - rolling_message, answer_str = parse_chunk_response( - rolling_message, - chunk, - self.llm_endpoint.supports_func_calling(), - ) - if len(answer_str) > 0: - if ( - self.llm_endpoint.supports_func_calling() - and rolling_message.tool_calls - ): - diff_answer = answer_str[len(prev_answer) :] - if len(diff_answer) > 0: - parsed_chunk = ParsedRAGChunkResponse( - answer=diff_answer, - metadata=RAGResponseMetadata(), - ) - prev_answer += diff_answer - - logger.debug( - f"answer_astream func_calling=True question={question} rolling_msg={rolling_message} chunk_id={chunk_id}, chunk={parsed_chunk}" - ) - yield parsed_chunk - else: - parsed_chunk = ParsedRAGChunkResponse( - answer=answer_str, - metadata=RAGResponseMetadata(), - ) - logger.debug( - f"answer_astream func_calling=False question={question} rolling_msg={rolling_message} chunk_id={chunk_id}, chunk={parsed_chunk}" - ) - yield parsed_chunk - - chunk_id += 1 - - # Last chunk provides metadata - last_chunk = ParsedRAGChunkResponse( - answer="", - metadata=get_chunk_metadata(rolling_message, sources), - last_chunk=True, - ) - logger.debug( - f"answer_astream last_chunk={last_chunk} question={question} rolling_msg={rolling_message} chunk_id={chunk_id}" - ) - yield last_chunk diff --git a/core/quivr_core/rag/__init__.py b/core/quivr_core/rag/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/core/quivr_core/rag/entities/__init__.py b/core/quivr_core/rag/entities/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/core/quivr_core/chat.py b/core/quivr_core/rag/entities/chat.py 
similarity index 86% rename from core/quivr_core/chat.py rename to core/quivr_core/rag/entities/chat.py index f5ea0692c..ee3e10767 100644 --- a/core/quivr_core/chat.py +++ b/core/quivr_core/rag/entities/chat.py @@ -1,11 +1,10 @@ -from copy import deepcopy from datetime import datetime -from typing import Any, Generator, List, Tuple +from typing import Any, Generator, Tuple, List from uuid import UUID, uuid4 from langchain_core.messages import AIMessage, HumanMessage -from quivr_core.models import ChatMessage +from quivr_core.rag.entities.models import ChatMessage class ChatHistory: @@ -98,17 +97,3 @@ class ChatHistory: """ return [_msg.msg for _msg in self._msgs] - - def __deepcopy__(self, memo): - """ - Support for deepcopy of ChatHistory. - This method ensures that mutable objects (like lists) are copied deeply. - """ - # Create a new instance of ChatHistory - new_copy = ChatHistory(self.id, deepcopy(self.brain_id, memo)) - - # Perform a deepcopy of the _msgs list - new_copy._msgs = deepcopy(self._msgs, memo) - - # Return the deep copied instance - return new_copy diff --git a/core/quivr_core/rag/entities/config.py b/core/quivr_core/rag/entities/config.py new file mode 100644 index 000000000..cb455ab42 --- /dev/null +++ b/core/quivr_core/rag/entities/config.py @@ -0,0 +1,452 @@ +import os +import re +import logging +from enum import Enum +from typing import Dict, Hashable, List, Optional, Union, Any, Type +from uuid import UUID +from pydantic import BaseModel +from langgraph.graph import START, END +from langchain_core.tools import BaseTool +from quivr_core.config import MegaparseConfig +from rapidfuzz import process, fuzz + + +from quivr_core.base_config import QuivrBaseConfig +from quivr_core.processor.splitter import SplitterConfig +from quivr_core.rag.prompts import CustomPromptsModel +from quivr_core.llm_tools.llm_tools import LLMToolFactory, TOOLS_CATEGORIES, TOOLS_LISTS + +logger = logging.getLogger("quivr_core") + + +def normalize_to_env_variable_name(name: str) -> str: + # Replace any character that is not a letter, digit, or underscore with an underscore + env_variable_name = re.sub(r"[^A-Za-z0-9_]", "_", name).upper() + + # Check if the normalized name starts with a digit + if env_variable_name[0].isdigit(): + raise ValueError( + f"Invalid environment variable name '{env_variable_name}': Cannot start with a digit." 
+ ) + + return env_variable_name + + +class SpecialEdges(str, Enum): + start = "START" + end = "END" + + +class BrainConfig(QuivrBaseConfig): + brain_id: UUID | None = None + name: str + + @property + def id(self) -> UUID | None: + return self.brain_id + + +class DefaultWebSearchTool(str, Enum): + TAVILY = "tavily" + + +class DefaultRerankers(str, Enum): + COHERE = "cohere" + JINA = "jina" + # MIXEDBREAD = "mixedbread-ai" + + @property + def default_model(self) -> str: + # Mapping of suppliers to their default models + return { + self.COHERE: "rerank-multilingual-v3.0", + self.JINA: "jina-reranker-v2-base-multilingual", + # self.MIXEDBREAD: "rmxbai-rerank-large-v1", + }[self] + + +class DefaultModelSuppliers(str, Enum): + OPENAI = "openai" + AZURE = "azure" + ANTHROPIC = "anthropic" + META = "meta" + MISTRAL = "mistral" + GROQ = "groq" + + +class LLMConfig(QuivrBaseConfig): + context: int | None = None + tokenizer_hub: str | None = None + + +class LLMModelConfig: + _model_defaults: Dict[DefaultModelSuppliers, Dict[str, LLMConfig]] = { + DefaultModelSuppliers.OPENAI: { + "gpt-4o": LLMConfig(context=128000, tokenizer_hub="Xenova/gpt-4o"), + "gpt-4o-mini": LLMConfig(context=128000, tokenizer_hub="Xenova/gpt-4o"), + "gpt-4-turbo": LLMConfig(context=128000, tokenizer_hub="Xenova/gpt-4"), + "gpt-4": LLMConfig(context=8192, tokenizer_hub="Xenova/gpt-4"), + "gpt-3.5-turbo": LLMConfig( + context=16385, tokenizer_hub="Xenova/gpt-3.5-turbo" + ), + "text-embedding-3-large": LLMConfig( + context=8191, tokenizer_hub="Xenova/text-embedding-ada-002" + ), + "text-embedding-3-small": LLMConfig( + context=8191, tokenizer_hub="Xenova/text-embedding-ada-002" + ), + "text-embedding-ada-002": LLMConfig( + context=8191, tokenizer_hub="Xenova/text-embedding-ada-002" + ), + }, + DefaultModelSuppliers.ANTHROPIC: { + "claude-3-5-sonnet": LLMConfig( + context=200000, tokenizer_hub="Xenova/claude-tokenizer" + ), + "claude-3-opus": LLMConfig( + context=200000, tokenizer_hub="Xenova/claude-tokenizer" + ), + "claude-3-sonnet": LLMConfig( + context=200000, tokenizer_hub="Xenova/claude-tokenizer" + ), + "claude-3-haiku": LLMConfig( + context=200000, tokenizer_hub="Xenova/claude-tokenizer" + ), + "claude-2-1": LLMConfig( + context=200000, tokenizer_hub="Xenova/claude-tokenizer" + ), + "claude-2-0": LLMConfig( + context=100000, tokenizer_hub="Xenova/claude-tokenizer" + ), + "claude-instant-1-2": LLMConfig( + context=100000, tokenizer_hub="Xenova/claude-tokenizer" + ), + }, + DefaultModelSuppliers.META: { + "llama-3.1": LLMConfig( + context=128000, tokenizer_hub="Xenova/Meta-Llama-3.1-Tokenizer" + ), + "llama-3": LLMConfig( + context=8192, tokenizer_hub="Xenova/llama3-tokenizer-new" + ), + "llama-2": LLMConfig(context=4096, tokenizer_hub="Xenova/llama2-tokenizer"), + "code-llama": LLMConfig( + context=16384, tokenizer_hub="Xenova/llama-code-tokenizer" + ), + }, + DefaultModelSuppliers.GROQ: { + "llama-3.1": LLMConfig( + context=128000, tokenizer_hub="Xenova/Meta-Llama-3.1-Tokenizer" + ), + "llama-3": LLMConfig( + context=8192, tokenizer_hub="Xenova/llama3-tokenizer-new" + ), + "llama-2": LLMConfig(context=4096, tokenizer_hub="Xenova/llama2-tokenizer"), + "code-llama": LLMConfig( + context=16384, tokenizer_hub="Xenova/llama-code-tokenizer" + ), + }, + DefaultModelSuppliers.MISTRAL: { + "mistral-large": LLMConfig( + context=128000, tokenizer_hub="Xenova/mistral-tokenizer-v3" + ), + "mistral-small": LLMConfig( + context=128000, tokenizer_hub="Xenova/mistral-tokenizer-v3" + ), + "mistral-nemo": LLMConfig( + context=128000, 
tokenizer_hub="Xenova/Mistral-Nemo-Instruct-Tokenizer" + ), + "codestral": LLMConfig( + context=32000, tokenizer_hub="Xenova/mistral-tokenizer-v3" + ), + }, + } + + @classmethod + def get_supplier_by_model_name(cls, model: str) -> DefaultModelSuppliers | None: + # Iterate over the suppliers and their models + for supplier, models in cls._model_defaults.items(): + # Check if the model name or a base part of the model name is in the supplier's models + for base_model_name in models: + if model.startswith(base_model_name): + return supplier + # Return None if no supplier matches the model name + return None + + @classmethod + def get_llm_model_config( + cls, supplier: DefaultModelSuppliers, model_name: str + ) -> Optional[LLMConfig]: + """Retrieve the LLMConfig (context and tokenizer_hub) for a given supplier and model.""" + supplier_defaults = cls._model_defaults.get(supplier) + if not supplier_defaults: + return None + + # Use startswith logic for matching model names + for key, config in supplier_defaults.items(): + if model_name.startswith(key): + return config + + return None + + +class LLMEndpointConfig(QuivrBaseConfig): + supplier: DefaultModelSuppliers = DefaultModelSuppliers.OPENAI + model: str = "gpt-4o" + context_length: int | None = None + tokenizer_hub: str | None = None + llm_base_url: str | None = None + env_variable_name: str | None = None + llm_api_key: str | None = None + max_context_tokens: int = 2000 + max_output_tokens: int = 2000 + temperature: float = 0.7 + streaming: bool = True + prompt: CustomPromptsModel | None = None + + _FALLBACK_TOKENIZER = "cl100k_base" + + @property + def fallback_tokenizer(self) -> str: + return self._FALLBACK_TOKENIZER + + def __init__(self, **data): + super().__init__(**data) + self.set_llm_model_config() + self.set_api_key() + + def set_api_key(self, force_reset: bool = False): + if not self.supplier: + return + + # Check if the corresponding API key environment variable is set + if not self.env_variable_name: + self.env_variable_name = ( + f"{normalize_to_env_variable_name(self.supplier)}_API_KEY" + ) + + if not self.llm_api_key or force_reset: + self.llm_api_key = os.getenv(self.env_variable_name) + + if not self.llm_api_key: + logger.warning(f"The API key for supplier '{self.supplier}' is not set. ") + + def set_llm_model_config(self): + # Automatically set context_length and tokenizer_hub based on the supplier and model + llm_model_config = LLMModelConfig.get_llm_model_config( + self.supplier, self.model + ) + if llm_model_config: + self.context_length = llm_model_config.context + self.tokenizer_hub = llm_model_config.tokenizer_hub + + def set_llm_model(self, model: str): + supplier = LLMModelConfig.get_supplier_by_model_name(model) + if supplier is None: + raise ValueError( + f"Cannot find the corresponding supplier for model {model}" + ) + self.supplier = supplier + self.model = model + + self.set_llm_model_config() + self.set_api_key(force_reset=True) + + def set_from_sqlmodel(self, sqlmodel: BaseModel, mapping: Dict[str, str]): + """ + Set attributes in LLMEndpointConfig from Model attributes using a field mapping. + + :param model_instance: An instance of the Model class. + :param mapping: A dictionary that maps Model fields to LLMEndpointConfig fields. 
+ Example: {"max_input": "max_input_tokens", "env_variable_name": "env_variable_name"} + """ + for model_field, llm_field in mapping.items(): + if hasattr(sqlmodel, model_field) and hasattr(self, llm_field): + setattr(self, llm_field, getattr(sqlmodel, model_field)) + else: + raise AttributeError( + f"Invalid mapping: {model_field} or {llm_field} does not exist." + ) + + +# Cannot use Pydantic v2 field_validator because of conflicts with pydantic v1 still in use in LangChain +class RerankerConfig(QuivrBaseConfig): + supplier: DefaultRerankers | None = None + model: str | None = None + top_n: int = 5 # Number of chunks returned by the re-ranker + api_key: str | None = None + relevance_score_threshold: float | None = None + relevance_score_key: str = "relevance_score" + + def __init__(self, **data): + super().__init__(**data) # Call Pydantic's BaseModel init + self.validate_model() # Automatically call external validation + + def validate_model(self): + # If model is not provided, get default model based on supplier + if self.model is None and self.supplier is not None: + self.model = self.supplier.default_model + + # Check if the corresponding API key environment variable is set + if self.supplier: + api_key_var = f"{normalize_to_env_variable_name(self.supplier)}_API_KEY" + self.api_key = os.getenv(api_key_var) + + if self.api_key is None: + raise ValueError( + f"The API key for supplier '{self.supplier}' is not set. " + f"Please set the environment variable: {api_key_var}" + ) + + +class ConditionalEdgeConfig(QuivrBaseConfig): + routing_function: str + conditions: Union[list, Dict[Hashable, str]] + + def __init__(self, **data): + super().__init__(**data) + self.resolve_special_edges() + + def resolve_special_edges(self): + """Replace SpecialEdges enum values with their corresponding langgraph values.""" + + if isinstance(self.conditions, dict): + # If conditions is a dictionary, iterate through the key-value pairs + for key, value in self.conditions.items(): + if value == SpecialEdges.end: + self.conditions[key] = END + elif value == SpecialEdges.start: + self.conditions[key] = START + elif isinstance(self.conditions, list): + # If conditions is a list, iterate through the values + for index, value in enumerate(self.conditions): + if value == SpecialEdges.end: + self.conditions[index] = END + elif value == SpecialEdges.start: + self.conditions[index] = START + + +class NodeConfig(QuivrBaseConfig): + name: str + edges: List[str] | None = None + conditional_edge: ConditionalEdgeConfig | None = None + tools: List[Dict[str, Any]] | None = None + instantiated_tools: List[BaseTool | Type] | None = None + + def __init__(self, **data): + super().__init__(**data) + self._instantiate_tools() + self.resolve_special_edges_in_name_and_edges() + + def resolve_special_edges_in_name_and_edges(self): + """Replace SpecialEdges enum values in name and edges with corresponding langgraph values.""" + if self.name == SpecialEdges.start: + self.name = START + elif self.name == SpecialEdges.end: + self.name = END + + if self.edges: + for i, edge in enumerate(self.edges): + if edge == SpecialEdges.start: + self.edges[i] = START + elif edge == SpecialEdges.end: + self.edges[i] = END + + def _instantiate_tools(self): + """Instantiate tools based on the configuration.""" + if self.tools: + self.instantiated_tools = [ + LLMToolFactory.create_tool(tool_config.pop("name"), tool_config) + for tool_config in self.tools + ] + + +class DefaultWorkflow(str, Enum): + RAG = "rag" + + @property + def nodes(self) -> 
List[NodeConfig]: + # Mapping of workflow types to their default node configurations + workflows = { + self.RAG: [ + NodeConfig(name=START, edges=["filter_history"]), + NodeConfig(name="filter_history", edges=["rewrite"]), + NodeConfig(name="rewrite", edges=["retrieve"]), + NodeConfig(name="retrieve", edges=["generate_rag"]), + NodeConfig(name="generate_rag", edges=[END]), + ] + } + return workflows[self] + + +class WorkflowConfig(QuivrBaseConfig): + name: str | None = None + nodes: List[NodeConfig] = [] + available_tools: List[str] | None = None + validated_tools: List[BaseTool | Type] = [] + activated_tools: List[BaseTool | Type] = [] + + def __init__(self, **data): + super().__init__(**data) + self.check_first_node_is_start() + self.validate_available_tools() + + def check_first_node_is_start(self): + if self.nodes and self.nodes[0].name != START: + raise ValueError(f"The first node should be a {SpecialEdges.start} node") + + def get_node_tools(self, node_name: str) -> List[Any]: + """Get tools for a specific node.""" + for node in self.nodes: + if node.name == node_name and node.instantiated_tools: + return node.instantiated_tools + return [] + + def validate_available_tools(self): + if self.available_tools: + valid_tools = list(TOOLS_CATEGORIES.keys()) + list(TOOLS_LISTS.keys()) + for tool in self.available_tools: + if tool.lower() in valid_tools: + self.validated_tools.append( + LLMToolFactory.create_tool(tool, {}).tool + ) + else: + matches = process.extractOne( + tool.lower(), valid_tools, scorer=fuzz.WRatio + ) + if matches: + raise ValueError( + f"Tool {tool} is not a valid ToolsCategory or ToolsList. Did you mean {matches[0]}?" + ) + else: + raise ValueError( + f"Tool {tool} is not a valid ToolsCategory or ToolsList" + ) + + +class RetrievalConfig(QuivrBaseConfig): + reranker_config: RerankerConfig = RerankerConfig() + llm_config: LLMEndpointConfig = LLMEndpointConfig() + max_history: int = 10 + max_files: int = 20 + k: int = 40 # Number of chunks returned by the retriever + prompt: str | None = None + workflow_config: WorkflowConfig = WorkflowConfig(nodes=DefaultWorkflow.RAG.nodes) + + def __init__(self, **data): + super().__init__(**data) + self.llm_config.set_api_key(force_reset=True) + + +class ParserConfig(QuivrBaseConfig): + splitter_config: SplitterConfig = SplitterConfig() + megaparse_config: MegaparseConfig = MegaparseConfig() + + +class IngestionConfig(QuivrBaseConfig): + parser_config: ParserConfig = ParserConfig() + + +class AssistantConfig(QuivrBaseConfig): + retrieval_config: RetrievalConfig = RetrievalConfig() + ingestion_config: IngestionConfig = IngestionConfig() diff --git a/core/quivr_core/models.py b/core/quivr_core/rag/entities/models.py similarity index 93% rename from core/quivr_core/models.py rename to core/quivr_core/rag/entities/models.py index 0dc304c67..f87f49b66 100644 --- a/core/quivr_core/models.py +++ b/core/quivr_core/rag/entities/models.py @@ -7,7 +7,7 @@ from langchain_core.documents import Document from langchain_core.messages import AIMessage, HumanMessage from langchain_core.pydantic_v1 import BaseModel as BaseModelV1 from langchain_core.pydantic_v1 import Field as FieldV1 -from pydantic import BaseModel +from pydantic import BaseModel, Field from typing_extensions import TypedDict @@ -73,9 +73,9 @@ class ChatLLMMetadata(BaseModel): class RAGResponseMetadata(BaseModel): - citations: list[int] | None = None - followup_questions: list[str] | None = None - sources: list[Any] | None = None + citations: list[int] = Field(default_factory=list) + 
followup_questions: list[str] = Field(default_factory=list) + sources: list[Any] = Field(default_factory=list) metadata_model: ChatLLMMetadata | None = None diff --git a/core/quivr_core/rag/prompts.py b/core/quivr_core/rag/prompts.py new file mode 100644 index 000000000..256735568 --- /dev/null +++ b/core/quivr_core/rag/prompts.py @@ -0,0 +1,265 @@ +import datetime +from pydantic import ConfigDict, create_model + +from langchain_core.prompts.base import BasePromptTemplate +from langchain_core.prompts import ( + ChatPromptTemplate, + HumanMessagePromptTemplate, + PromptTemplate, + SystemMessagePromptTemplate, + MessagesPlaceholder, +) + + +class CustomPromptsDict(dict): + def __init__(self, type, *args, **kwargs): + super().__init__(*args, **kwargs) + self._type = type + + def __setitem__(self, key, value): + # Automatically convert the value into a tuple (my_type, value) + super().__setitem__(key, (self._type, value)) + + +def _define_custom_prompts() -> CustomPromptsDict: + custom_prompts: CustomPromptsDict = CustomPromptsDict(type=BasePromptTemplate) + + today_date = datetime.datetime.now().strftime("%B %d, %Y") + + # --------------------------------------------------------------------------- + # Prompt for question rephrasing + # --------------------------------------------------------------------------- + system_message_template = ( + "Given a chat history and the latest user question " + "which might reference context in the chat history, " + "formulate a standalone question which can be understood " + "without the chat history. Do NOT answer the question, " + "just reformulate it if needed and otherwise return it as is. " + "Do not output your reasoning, just the question." + ) + + template_answer = "User question: {question}\n Standalone question:" + + CONDENSE_QUESTION_PROMPT = ChatPromptTemplate.from_messages( + [ + SystemMessagePromptTemplate.from_template(system_message_template), + MessagesPlaceholder(variable_name="chat_history"), + HumanMessagePromptTemplate.from_template(template_answer), + ] + ) + + custom_prompts["CONDENSE_QUESTION_PROMPT"] = CONDENSE_QUESTION_PROMPT + + # --------------------------------------------------------------------------- + # Prompt for RAG + # --------------------------------------------------------------------------- + system_message_template = f"Your name is Quivr. You're a helpful assistant. Today's date is {today_date}. " + + system_message_template += ( + "- When answering use markdown. Use markdown code blocks for code snippets.\n" + "- Answer in a concise and clear manner.\n" + "- If no preferred language is provided, answer in the same language as the language used by the user.\n" + "- You must use ONLY the provided context to answer the question. " + "Do not use any prior knowledge or external information, even if you are certain of the answer.\n" + "- If you cannot provide an answer using ONLY the context provided, do not attempt to answer from your own knowledge." 
+ "Instead, inform the user that the answer isn't available in the context and suggest using the available tools {tools}.\n" + "- Do not apologize when providing an answer.\n" + "- Don't cite the source id in the answer objects, but you can use the source to answer the question.\n\n" + ) + + context_template = ( + "\n" + "- You have access to the following internal reasoning to provide an answer: {reasoning}\n" + "- You have access to the following files to answer the user question (limited to first 20 files): {files}\n" + "- You have access to the following context to answer the user question: {context}\n" + "- Follow these user instruction when crafting the answer: {custom_instructions}\n" + "- These user instructions shall take priority over any other previous instruction.\n" + "- Remember: if you cannot provide an answer using ONLY the provided context and CITING the sources, " + "inform the user that you don't have the answer and consider if any of the tools can help answer the question.\n" + "- Explain your reasoning about the potentiel tool usage in the answer.\n" + "- Only use binded tools to answer the question.\n" + # "OFFER the user the possibility to ACTIVATE a relevant tool among " + # "the tools which can be activated." + # "Tools which can be activated: {tools}. If any of these tools can help in providing an answer " + # "to the user question, you should offer the user the possibility to activate it. " + # "Remember, you shall NOT use the above tools, ONLY offer the user the possibility to activate them.\n" + ) + + template_answer = ( + "Original task: {question}\n" + "Rephrased and contextualized task: {rephrased_task}\n" + ) + + RAG_ANSWER_PROMPT = ChatPromptTemplate.from_messages( + [ + SystemMessagePromptTemplate.from_template(system_message_template), + MessagesPlaceholder(variable_name="chat_history"), + SystemMessagePromptTemplate.from_template(context_template), + HumanMessagePromptTemplate.from_template(template_answer), + ] + ) + custom_prompts["RAG_ANSWER_PROMPT"] = RAG_ANSWER_PROMPT + + # --------------------------------------------------------------------------- + # Prompt for formatting documents + # --------------------------------------------------------------------------- + DEFAULT_DOCUMENT_PROMPT = PromptTemplate.from_template( + template="Filename: {original_file_name}\nSource: {index} \n {page_content}" + ) + custom_prompts["DEFAULT_DOCUMENT_PROMPT"] = DEFAULT_DOCUMENT_PROMPT + + # --------------------------------------------------------------------------- + # Prompt for chatting directly with LLMs, without any document retrieval stage + # --------------------------------------------------------------------------- + system_message_template = ( + f"Your name is Quivr. You're a helpful assistant. Today's date is {today_date}." 
+ ) + system_message_template += """ + If not None, also follow these user instructions when answering: {custom_instructions} + """ + + template_answer = """ + User Question: {question} + Answer: + """ + CHAT_LLM_PROMPT = ChatPromptTemplate.from_messages( + [ + SystemMessagePromptTemplate.from_template(system_message_template), + MessagesPlaceholder(variable_name="chat_history"), + HumanMessagePromptTemplate.from_template(template_answer), + ] + ) + custom_prompts["CHAT_LLM_PROMPT"] = CHAT_LLM_PROMPT + + # --------------------------------------------------------------------------- + # Prompt to understand the user intent + # --------------------------------------------------------------------------- + system_message_template = ( + "Given the following user input, determine the user intent, in particular " + "whether the user is providing instructions to the system or is asking the system to " + "execute a task:\n" + " - if the user is providing direct instructions to modify the system behaviour (for instance, " + "'Can you reply in French?' or 'Answer in French' or 'You are an expert legal assistant' " + "or 'You will behave as...'), the user intent is 'prompt';\n" + " - in all other cases (asking questions, asking for summarising a text, asking for translating a text, ...), " + "the intent is 'task'.\n" + ) + + template_answer = "User input: {question}" + + USER_INTENT_PROMPT = ChatPromptTemplate.from_messages( + [ + SystemMessagePromptTemplate.from_template(system_message_template), + HumanMessagePromptTemplate.from_template(template_answer), + ] + ) + custom_prompts["USER_INTENT_PROMPT"] = USER_INTENT_PROMPT + + # --------------------------------------------------------------------------- + # Prompt to create a system prompt from user instructions + # --------------------------------------------------------------------------- + system_message_template = ( + "- Given the following user instruction, current system prompt, list of available tools " + "and list of activated tools, update the prompt to include the instruction and decide which tools to activate.\n" + "- The prompt shall only contain generic instructions which can be applied to any user task or question.\n" + "- The prompt shall be concise and clear.\n" + "- If the system prompt already contains the instruction, do not add it again.\n" + "- If the system prompt contradicts the user instruction, remove the contradictory " + "statement or statements in the system prompt.\n" + "- You shall return separately the updated system prompt and the reasoning that led to the update.\n" + "- If the system prompt refers to a tool, you shall add the tool to the list of activated tools.\n" + "- If no tool activation is needed, return empty lists.\n" + "- You shall also return the reasoning that led to the tool activation.\n" + "- Current system prompt: {system_prompt}\n" + "- List of available tools: {available_tools}\n" + "- List of activated tools: {activated_tools}\n\n" + ) + + template_answer = "User instructions: {instruction}\n" + + UPDATE_PROMPT = ChatPromptTemplate.from_messages( + [ + SystemMessagePromptTemplate.from_template(system_message_template), + HumanMessagePromptTemplate.from_template(template_answer), + ] + ) + custom_prompts["UPDATE_PROMPT"] = UPDATE_PROMPT + + # --------------------------------------------------------------------------- + # Prompt to split the user input into multiple questions / instructions + # --------------------------------------------------------------------------- + system_message_template = ( +
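+        # Consumed by QuivrQARAGLangGraph.routing and routing_split, which parse the
+        # LLM output into the SplittedInput schema (an optional instructions string
+        # plus a list of standalone tasks).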
"Given a chat history and the user input, split and rephrase the input into instructions and tasks.\n" + "- Instructions direct the system to behave in a certain way or to use specific tools: examples of instructions are " + "'Can you reply in French?', 'Answer in French', 'You are an expert legal assistant', " + "'You will behave as...', 'Use web search').\n" + "- You shall collect and condense all the instructions into a single string.\n" + "- The instructions shall be standalone and self-contained, so that they can be understood " + "without the chat history. If no instructions are found, return an empty string.\n" + "- Instructions to be understood may require considering the chat history.\n" + "- Tasks are often questions, but they can also be summarisation tasks, translation tasks, content generation tasks, etc.\n" + "- Tasks to be understood may require considering the chat history.\n" + "- If the user input contains different tasks, you shall split the input into multiple tasks.\n" + "- Each splitted task shall be a standalone, self-contained task which can be understood " + "without the chat history. You shall rephrase the tasks if needed.\n" + "- If no explicit task is present, you shall infer the tasks from the user input and the chat history.\n" + "- Do NOT try to solve the tasks or answer the questions, " + "just reformulate them if needed and otherwise return them as is.\n" + "- Remember, you shall NOT suggest or generate new tasks.\n" + "- As an example, the user input 'What is Apple? Who is its CEO? When was it founded?' " + "shall be split into the questions 'What is Apple?', 'Who is the CEO of Apple?' and 'When was Apple founded?'.\n" + "- If no tasks are found, return the user input as is in the task.\n" + ) + + template_answer = "User input: {user_input}" + + SPLIT_PROMPT = ChatPromptTemplate.from_messages( + [ + SystemMessagePromptTemplate.from_template(system_message_template), + MessagesPlaceholder(variable_name="chat_history"), + HumanMessagePromptTemplate.from_template(template_answer), + ] + ) + custom_prompts["SPLIT_PROMPT"] = SPLIT_PROMPT + + # --------------------------------------------------------------------------- + # Prompt to grade the relevance of an answer and decide whather to perform a web search + # --------------------------------------------------------------------------- + system_message_template = ( + "Given the following tasks you shall determine whether all tasks can be " + "completed fully and in the best possible way using the provided context and chat history. 
" + "You shall:\n" + "- Consider each task separately,\n" + "- Determine whether the context and chat history contain " + "all the information necessary to complete the task.\n" + "- If the context and chat history do not contain all the information necessary to complete the task, " + "consider ONLY the list of tools below and select the tool most appropriate to complete the task.\n" + "- If no tools are listed, return the tasks as is and no tool.\n" + "- If no relevant tool can be selected, return the tasks as is and no tool.\n" + "- Do not propose to use a tool if that tool is not listed among the available tools.\n" + ) + + context_template = "Context: {context}\n {activated_tools}\n" + + template_answer = "Tasks: {tasks}\n" + + TOOL_ROUTING_PROMPT = ChatPromptTemplate.from_messages( + [ + SystemMessagePromptTemplate.from_template(system_message_template), + MessagesPlaceholder(variable_name="chat_history"), + SystemMessagePromptTemplate.from_template(context_template), + HumanMessagePromptTemplate.from_template(template_answer), + ] + ) + + custom_prompts["TOOL_ROUTING_PROMPT"] = TOOL_ROUTING_PROMPT + + return custom_prompts + + +_custom_prompts = _define_custom_prompts() +CustomPromptsModel = create_model( + "CustomPromptsModel", **_custom_prompts, __config__=ConfigDict(extra="forbid") +) + +custom_prompts = CustomPromptsModel() diff --git a/core/quivr_core/quivr_rag.py b/core/quivr_core/rag/quivr_rag.py similarity index 97% rename from core/quivr_core/quivr_rag.py rename to core/quivr_core/rag/quivr_rag.py index a11b98bfc..38502eb4d 100644 --- a/core/quivr_core/quivr_rag.py +++ b/core/quivr_core/rag/quivr_rag.py @@ -12,18 +12,18 @@ from langchain_core.output_parsers import StrOutputParser from langchain_core.runnables import RunnableLambda, RunnablePassthrough from langchain_core.vectorstores import VectorStore -from quivr_core.chat import ChatHistory -from quivr_core.config import RetrievalConfig +from quivr_core.rag.entities.chat import ChatHistory +from quivr_core.rag.entities.config import RetrievalConfig from quivr_core.llm import LLMEndpoint -from quivr_core.models import ( +from quivr_core.rag.entities.models import ( ParsedRAGChunkResponse, ParsedRAGResponse, QuivrKnowledge, RAGResponseMetadata, cited_answer, ) -from quivr_core.prompts import custom_prompts -from quivr_core.utils import ( +from quivr_core.rag.prompts import custom_prompts +from quivr_core.rag.utils import ( combine_documents, format_file_list, get_chunk_metadata, diff --git a/core/quivr_core/rag/quivr_rag_langgraph.py b/core/quivr_core/rag/quivr_rag_langgraph.py new file mode 100644 index 000000000..49ba6d282 --- /dev/null +++ b/core/quivr_core/rag/quivr_rag_langgraph.py @@ -0,0 +1,891 @@ +import logging +from typing import ( + Annotated, + AsyncGenerator, + List, + Optional, + Sequence, + Tuple, + TypedDict, + Dict, + Any, + Type, +) +from uuid import uuid4 +import asyncio + +# TODO(@aminediro): this is the only dependency to langchain package, we should remove it +from langchain.retrievers import ContextualCompressionRetriever +from langchain_cohere import CohereRerank +from langchain_community.document_compressors import JinaRerank +from langchain_core.callbacks import Callbacks +from langchain_core.documents import BaseDocumentCompressor, Document +from langchain_core.messages import BaseMessage +from langchain_core.messages.ai import AIMessageChunk +from langchain_core.vectorstores import VectorStore +from langchain_core.prompts.base import BasePromptTemplate +from langgraph.graph import START, END, 
StateGraph +from langgraph.graph.message import add_messages +from langgraph.types import Send + + +from pydantic import BaseModel, Field +import openai + +from quivr_core.rag.entities.chat import ChatHistory +from quivr_core.rag.entities.config import DefaultRerankers, NodeConfig, RetrievalConfig +from quivr_core.llm import LLMEndpoint +from quivr_core.llm_tools.llm_tools import LLMToolFactory +from quivr_core.rag.entities.models import ( + ParsedRAGChunkResponse, + QuivrKnowledge, +) +from quivr_core.rag.prompts import custom_prompts +from quivr_core.rag.utils import ( + collect_tools, + combine_documents, + format_file_list, + get_chunk_metadata, + parse_chunk_response, +) + +logger = logging.getLogger("quivr_core") + + +class SplittedInput(BaseModel): + instructions_reasoning: Optional[str] = Field( + default=None, + description="The reasoning that leads to identifying the user instructions to the system", + ) + instructions: Optional[str] = Field( + default=None, description="The instructions to the system" + ) + + tasks_reasoning: Optional[str] = Field( + default=None, + description="The reasoning that leads to identifying the explicit or implicit user tasks and questions", + ) + tasks: Optional[List[str]] = Field( + default_factory=lambda: ["No explicit or implicit tasks found"], + description="The list of standalone, self-contained tasks or questions.", + ) + + +class TasksCompletion(BaseModel): + completable_tasks_reasoning: Optional[str] = Field( + default=None, + description="The reasoning that leads to identifying the user tasks or questions that can be completed using the provided context and chat history.", + ) + completable_tasks: Optional[List[str]] = Field( + default_factory=list, + description="The user tasks or questions that can be completed using the provided context and chat history.", + ) + + non_completable_tasks_reasoning: Optional[str] = Field( + default=None, + description="The reasoning that leads to identifying the user tasks or questions that cannot be completed using the provided context and chat history.", + ) + non_completable_tasks: Optional[List[str]] = Field( + default_factory=list, + description="The user tasks or questions that need a tool to be completed.", + ) + + tool_reasoning: Optional[str] = Field( + default=None, + description="The reasoning that leads to identifying the tool that shall be used to complete the tasks.", + ) + tool: Optional[str] = Field( + default_factory=list, + description="The tool that shall be used to complete the tasks.", + ) + + +class FinalAnswer(BaseModel): + reasoning_answer: str = Field( + description="The step-by-step reasoning that led to the final answer" + ) + answer: str = Field(description="The final answer to the user tasks/questions") + + all_tasks_completed: bool = Field( + description="Whether all tasks/questions have been successfully answered/completed or not. 
" + " If the final answer to the user is 'I don't know' or 'I don't have enough information' or 'I'm not sure', " + " this variable should be 'false'" + ) + + +class UpdatedPromptAndTools(BaseModel): + prompt_reasoning: Optional[str] = Field( + default=None, + description="The step-by-step reasoning that leads to the updated system prompt", + ) + prompt: Optional[str] = Field(default=None, description="The updated system prompt") + + tools_reasoning: Optional[str] = Field( + default=None, + description="The reasoning that leads to activating and deactivating the tools", + ) + tools_to_activate: Optional[List[str]] = Field( + default_factory=list, description="The list of tools to activate" + ) + tools_to_deactivate: Optional[List[str]] = Field( + default_factory=list, description="The list of tools to deactivate" + ) + + +class AgentState(TypedDict): + # The add_messages function defines how an update should be processed + # Default is to replace. add_messages says "append" + messages: Annotated[Sequence[BaseMessage], add_messages] + reasoning: List[str] + chat_history: ChatHistory + docs: list[Document] + files: str + tasks: List[str] + instructions: str + tool: str + + +class IdempotentCompressor(BaseDocumentCompressor): + def compress_documents( + self, + documents: Sequence[Document], + query: str, + callbacks: Optional[Callbacks] = None, + ) -> Sequence[Document]: + """ + A no-op document compressor that simply returns the documents it is given. + + This is a placeholder until a more sophisticated document compression + algorithm is implemented. + """ + return documents + + +class QuivrQARAGLangGraph: + def __init__( + self, + *, + retrieval_config: RetrievalConfig, + llm: LLMEndpoint, + vector_store: VectorStore | None = None, + ): + """ + Construct a QuivrQARAGLangGraph object. + + Args: + retrieval_config (RetrievalConfig): The configuration for the RAG model. + llm (LLMEndpoint): The LLM to use for generating text. + vector_store (VectorStore): The vector store to use for storing and retrieving documents. + reranker (BaseDocumentCompressor | None): The document compressor to use for re-ranking documents. Defaults to IdempotentCompressor if not provided. + """ + self.retrieval_config = retrieval_config + self.vector_store = vector_store + self.llm_endpoint = llm + + self.graph = None + + def get_reranker(self, **kwargs): + # Extract the reranker configuration from self + config = self.retrieval_config.reranker_config + + # Allow kwargs to override specific config values + supplier = kwargs.pop("supplier", config.supplier) + model = kwargs.pop("model", config.model) + top_n = kwargs.pop("top_n", config.top_n) + api_key = kwargs.pop("api_key", config.api_key) + + if supplier == DefaultRerankers.COHERE: + reranker = CohereRerank( + model=model, top_n=top_n, cohere_api_key=api_key, **kwargs + ) + elif supplier == DefaultRerankers.JINA: + reranker = JinaRerank( + model=model, top_n=top_n, jina_api_key=api_key, **kwargs + ) + else: + reranker = IdempotentCompressor() + + return reranker + + def get_retriever(self, **kwargs): + """ + Returns a retriever that can retrieve documents from the vector store. + + Returns: + VectorStoreRetriever: The retriever. + """ + if self.vector_store: + retriever = self.vector_store.as_retriever(**kwargs) + else: + raise ValueError("No vector store provided") + + return retriever + + def routing(self, state: AgentState) -> List[Send]: + """ + The routing function for the RAG model. + + Args: + state (AgentState): The current state of the agent. 
+ + Returns: + dict: The next state of the agent. + """ + + msg = custom_prompts.SPLIT_PROMPT.format( + user_input=state["messages"][0].content, + ) + + response: SplittedInput + + try: + structured_llm = self.llm_endpoint._llm.with_structured_output( + SplittedInput, method="json_schema" + ) + response = structured_llm.invoke(msg) + + except openai.BadRequestError: + structured_llm = self.llm_endpoint._llm.with_structured_output( + SplittedInput + ) + response = structured_llm.invoke(msg) + + send_list: List[Send] = [] + + instructions = ( + response.instructions + if response.instructions + else self.retrieval_config.prompt + ) + + if instructions: + send_list.append(Send("edit_system_prompt", {"instructions": instructions})) + elif response.tasks: + chat_history = state["chat_history"] + send_list.append( + Send( + "filter_history", + {"chat_history": chat_history, "tasks": response.tasks}, + ) + ) + + return send_list + + def routing_split(self, state: AgentState): + response = self.invoke_structured_output( + custom_prompts.SPLIT_PROMPT.format( + chat_history=state["chat_history"].to_list(), + user_input=state["messages"][0].content, + ), + SplittedInput, + ) + + instructions = response.instructions or self.retrieval_config.prompt + tasks = response.tasks or [] + + if instructions: + return [ + Send( + "edit_system_prompt", + {**state, "instructions": instructions, "tasks": tasks}, + ) + ] + elif tasks: + return [Send("filter_history", {**state, "tasks": tasks})] + + return [] + + def update_active_tools(self, updated_prompt_and_tools: UpdatedPromptAndTools): + if updated_prompt_and_tools.tools_to_activate: + for tool in updated_prompt_and_tools.tools_to_activate: + for ( + validated_tool + ) in self.retrieval_config.workflow_config.validated_tools: + if tool == validated_tool.name: + self.retrieval_config.workflow_config.activated_tools.append( + validated_tool + ) + + if updated_prompt_and_tools.tools_to_deactivate: + for tool in updated_prompt_and_tools.tools_to_deactivate: + for ( + activated_tool + ) in self.retrieval_config.workflow_config.activated_tools: + if tool == activated_tool.name: + self.retrieval_config.workflow_config.activated_tools.remove( + activated_tool + ) + + def edit_system_prompt(self, state: AgentState) -> AgentState: + user_instruction = state["instructions"] + prompt = self.retrieval_config.prompt + available_tools, activated_tools = collect_tools( + self.retrieval_config.workflow_config + ) + inputs = { + "instruction": user_instruction, + "system_prompt": prompt if prompt else "", + "available_tools": available_tools, + "activated_tools": activated_tools, + } + + msg = custom_prompts.UPDATE_PROMPT.format(**inputs) + + response: UpdatedPromptAndTools = self.invoke_structured_output( + msg, UpdatedPromptAndTools + ) + + self.update_active_tools(response) + self.retrieval_config.prompt = response.prompt + + reasoning = [response.prompt_reasoning] if response.prompt_reasoning else [] + reasoning += [response.tools_reasoning] if response.tools_reasoning else [] + + return {**state, "messages": [], "reasoning": reasoning} + + def filter_history(self, state: AgentState) -> AgentState: + """ + Filter out the chat history to only include the messages that are relevant to the current question + + Takes in a chat_history= [HumanMessage(content='Qui est Chloé ? 
'), + AIMessage(content="Chloé est une salariée travaillant pour l'entreprise Quivr en tant qu'AI Engineer, + sous la direction de son supérieur hiérarchique, Stanislas Girard."), + HumanMessage(content='Dis moi en plus sur elle'), AIMessage(content=''), + HumanMessage(content='Dis moi en plus sur elle'), + AIMessage(content="Désolé, je n'ai pas d'autres informations sur Chloé à partir des fichiers fournis.")] + Returns a filtered chat_history with in priority: first max_tokens, then max_history where a Human message and an AI message count as one pair + a token is 4 characters + """ + + chat_history = state["chat_history"] + total_tokens = 0 + total_pairs = 0 + _chat_id = uuid4() + _chat_history = ChatHistory(chat_id=_chat_id, brain_id=chat_history.brain_id) + for human_message, ai_message in reversed(list(chat_history.iter_pairs())): + # TODO: replace with tiktoken + message_tokens = self.llm_endpoint.count_tokens( + human_message.content + ) + self.llm_endpoint.count_tokens(ai_message.content) + if ( + total_tokens + message_tokens + > self.retrieval_config.llm_config.max_context_tokens + or total_pairs >= self.retrieval_config.max_history + ): + break + _chat_history.append(human_message) + _chat_history.append(ai_message) + total_tokens += message_tokens + total_pairs += 1 + + return {**state, "chat_history": _chat_history} + + ### Nodes + async def rewrite(self, state: AgentState) -> AgentState: + """ + Transform the query to produce a better question. + + Args: + state (messages): The current state + + Returns: + dict: The updated state with re-phrased question + """ + + tasks = ( + state["tasks"] + if "tasks" in state and state["tasks"] + else [state["messages"][0].content] + ) + + # Prepare the async tasks for all user tsks + async_tasks = [] + for task in tasks: + msg = custom_prompts.CONDENSE_QUESTION_PROMPT.format( + chat_history=state["chat_history"].to_list(), + question=task, + ) + + model = self.llm_endpoint._llm + # Asynchronously invoke the model for each question + async_tasks.append(model.ainvoke(msg)) + + # Gather all the responses asynchronously + responses = await asyncio.gather(*async_tasks) if async_tasks else [] + + # Replace each question with its condensed version + condensed_questions = [] + for response in responses: + condensed_questions.append(response.content) + + return {**state, "tasks": condensed_questions} + + def filter_chunks_by_relevance(self, chunks: List[Document], **kwargs): + config = self.retrieval_config.reranker_config + relevance_score_threshold = kwargs.get( + "relevance_score_threshold", config.relevance_score_threshold + ) + + if relevance_score_threshold is None: + return chunks + + filtered_chunks = [] + for chunk in chunks: + if config.relevance_score_key not in chunk.metadata: + logger.warning( + f"Relevance score key {config.relevance_score_key} not found in metadata, cannot filter chunks by relevance" + ) + filtered_chunks.append(chunk) + elif ( + chunk.metadata[config.relevance_score_key] >= relevance_score_threshold + ): + filtered_chunks.append(chunk) + + return filtered_chunks + + def tool_routing(self, state: AgentState): + tasks = state["tasks"] + if not tasks: + return [Send("generate_rag", state)] + + docs = state["docs"] + + _, activated_tools = collect_tools(self.retrieval_config.workflow_config) + + input = { + "chat_history": state["chat_history"].to_list(), + "tasks": state["tasks"], + "context": docs, + "activated_tools": activated_tools, + } + + input, _ = self.reduce_rag_context( + inputs=input, + 
prompt=custom_prompts.TOOL_ROUTING_PROMPT, + docs=docs, + # max_context_tokens=2000, + ) + + msg = custom_prompts.TOOL_ROUTING_PROMPT.format(**input) + + response: TasksCompletion = self.invoke_structured_output(msg, TasksCompletion) + + send_list: List[Send] = [] + + if response.non_completable_tasks and response.tool: + payload = { + **state, + "tasks": response.non_completable_tasks, + "tool": response.tool, + } + send_list.append(Send("run_tool", payload)) + else: + send_list.append(Send("generate_rag", state)) + + return send_list + + async def run_tool(self, state: AgentState) -> AgentState: + tool = state["tool"] + if tool not in [ + t.name for t in self.retrieval_config.workflow_config.activated_tools + ]: + raise ValueError(f"Tool {tool} not activated") + + tasks = state["tasks"] + + tool_wrapper = LLMToolFactory.create_tool(tool, {}) + + # Prepare the async tasks for all questions + async_tasks = [] + for task in tasks: + formatted_input = tool_wrapper.format_input(task) + # Asynchronously invoke the model for each question + async_tasks.append(tool_wrapper.tool.ainvoke(formatted_input)) + + # Gather all the responses asynchronously + responses = await asyncio.gather(*async_tasks) if async_tasks else [] + + docs = [] + for response in responses: + _docs = tool_wrapper.format_output(response) + _docs = self.filter_chunks_by_relevance(_docs) + docs += _docs + + return {**state, "docs": state["docs"] + docs} + + async def retrieve(self, state: AgentState) -> AgentState: + """ + Retrieve relevant chunks + + Args: + state (messages): The current state + + Returns: + dict: The retrieved chunks + """ + + tasks = state["tasks"] + if not tasks: + return {**state, "docs": []} + + kwargs = { + "search_kwargs": { + "k": self.retrieval_config.k, + } + } # type: ignore + base_retriever = self.get_retriever(**kwargs) + + kwargs = {"top_n": self.retrieval_config.reranker_config.top_n} # type: ignore + reranker = self.get_reranker(**kwargs) + + compression_retriever = ContextualCompressionRetriever( + base_compressor=reranker, base_retriever=base_retriever + ) + + # Prepare the async tasks for all questions + async_tasks = [] + for task in tasks: + # Asynchronously invoke the model for each question + async_tasks.append(compression_retriever.ainvoke(task)) + + # Gather all the responses asynchronously + responses = await asyncio.gather(*async_tasks) if async_tasks else [] + + docs = [] + for response in responses: + _docs = self.filter_chunks_by_relevance(response) + docs += _docs + + return {**state, "docs": docs} + + async def dynamic_retrieve(self, state: AgentState) -> AgentState: + """ + Iteratively retrieve relevant chunks, increasing top_n and k until the number of relevant chunks stops growing or the context limit is reached. + + Args: + state (messages): The current state + + Returns: + dict: The retrieved chunks + """ + + tasks = state["tasks"] + if not tasks: + return {**state, "docs": []} + + k = self.retrieval_config.k + top_n = self.retrieval_config.reranker_config.top_n + number_of_relevant_chunks = top_n + i = 1 + + while number_of_relevant_chunks == top_n: + top_n = self.retrieval_config.reranker_config.top_n * i + kwargs = {"top_n": top_n} + reranker = self.get_reranker(**kwargs) + + k = max([top_n * 2, self.retrieval_config.k]) + kwargs = {"search_kwargs": {"k": k}} # type: ignore + base_retriever = self.get_retriever(**kwargs) + + if i > 1: + logging.info( + f"Increasing top_n to {top_n} and k to {k} to retrieve more relevant chunks" + ) + + compression_retriever = ContextualCompressionRetriever( + base_compressor=reranker, base_retriever=base_retriever + ) + + # Prepare the async tasks for
all questions + async_tasks = [] + for task in tasks: + # Asynchronously invoke the model for each question + async_tasks.append(compression_retriever.ainvoke(task)) + + # Gather all the responses asynchronously + responses = await asyncio.gather(*async_tasks) if async_tasks else [] + + docs = [] + _n = [] + for response in responses: + _docs = self.filter_chunks_by_relevance(response) + _n.append(len(_docs)) + docs += _docs + + if not docs: + break + + context_length = self.get_rag_context_length(state, docs) + if context_length >= self.retrieval_config.llm_config.max_context_tokens: + logging.warning( + f"The context length is {context_length} which is greater than " + f"the max context tokens of {self.retrieval_config.llm_config.max_context_tokens}" + ) + break + + number_of_relevant_chunks = max(_n) + i += 1 + + return {**state, "docs": docs} + + def get_rag_context_length(self, state: AgentState, docs: List[Document]) -> int: + final_inputs = self._build_rag_prompt_inputs(state, docs) + msg = custom_prompts.RAG_ANSWER_PROMPT.format(**final_inputs) + return self.llm_endpoint.count_tokens(msg) + + def reduce_rag_context( + self, + inputs: Dict[str, Any], + prompt: BasePromptTemplate, + docs: List[Document] | None = None, + max_context_tokens: int | None = None, + ) -> Tuple[Dict[str, Any], List[Document] | None]: + MAX_ITERATION = 100 + SECURITY_FACTOR = 0.85 + iteration = 0 + + msg = prompt.format(**inputs) + n = self.llm_endpoint.count_tokens(msg) + + max_context_tokens = ( + max_context_tokens + if max_context_tokens + else self.retrieval_config.llm_config.max_context_tokens + ) + + while n > max_context_tokens * SECURITY_FACTOR: + chat_history = inputs["chat_history"] if "chat_history" in inputs else [] + + if len(chat_history) > 0: + inputs["chat_history"] = chat_history[2:] + elif docs and len(docs) > 1: + docs = docs[:-1] + else: + logging.warning( + f"Not enough context to reduce. 
The context length is {n} " + f"which is greater than the max context tokens of {max_context_tokens}" + ) + break + + if docs and "context" in inputs: + inputs["context"] = combine_documents(docs) + + msg = prompt.format(**inputs) + n = self.llm_endpoint.count_tokens(msg) + + iteration += 1 + if iteration > MAX_ITERATION: + logging.warning( + f"Attained the maximum number of iterations ({MAX_ITERATION})" + ) + break + + return inputs, docs + + def bind_tools_to_llm(self, node_name: str): + if self.llm_endpoint.supports_func_calling(): + tools = self.retrieval_config.workflow_config.get_node_tools(node_name) + if tools: # Only bind tools if there are any available + return self.llm_endpoint._llm.bind_tools(tools, tool_choice="any") + return self.llm_endpoint._llm + + def generate_rag(self, state: AgentState) -> AgentState: + docs: List[Document] | None = state["docs"] + final_inputs = self._build_rag_prompt_inputs(state, docs) + + reduced_inputs, docs = self.reduce_rag_context( + final_inputs, custom_prompts.RAG_ANSWER_PROMPT, docs + ) + + msg = custom_prompts.RAG_ANSWER_PROMPT.format(**reduced_inputs) + llm = self.bind_tools_to_llm(self.generate_rag.__name__) + response = llm.invoke(msg) + + return {**state, "messages": [response], "docs": docs if docs else []} + + def generate_chat_llm(self, state: AgentState) -> AgentState: + """ + Generate answer + + Args: + state (messages): The current state + + Returns: + dict: The updated state with re-phrased question + """ + messages = state["messages"] + user_question = messages[0].content + + # Prompt + prompt = self.retrieval_config.prompt + + final_inputs = {} + final_inputs["question"] = user_question + final_inputs["custom_instructions"] = prompt if prompt else "None" + final_inputs["chat_history"] = state["chat_history"].to_list() + + # LLM + llm = self.llm_endpoint._llm + + reduced_inputs, _ = self.reduce_rag_context( + final_inputs, custom_prompts.CHAT_LLM_PROMPT, None + ) + + msg = custom_prompts.CHAT_LLM_PROMPT.format(**reduced_inputs) + + # Run + response = llm.invoke(msg) + return {**state, "messages": [response]} + + def build_chain(self): + """ + Builds the langchain chain for the given configuration. + + Returns: + Callable[[Dict], Dict]: The langchain chain. 
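+
+        A minimal usage sketch (illustrative only; the retrieval_config, llm_endpoint
+        and vector_store instances are assumed to exist):
+
+            rag = QuivrQARAGLangGraph(
+                retrieval_config=retrieval_config,
+                llm=llm_endpoint,
+                vector_store=vector_store,
+            )
+            graph = rag.build_chain()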
+ """ + if not self.graph: + self.graph = self.create_graph() + + return self.graph + + def create_graph(self): + workflow = StateGraph(AgentState) + self.final_nodes = [] + + self._build_workflow(workflow) + + return workflow.compile() + + def _build_workflow(self, workflow: StateGraph): + for node in self.retrieval_config.workflow_config.nodes: + if node.name not in [START, END]: + workflow.add_node(node.name, getattr(self, node.name)) + + for node in self.retrieval_config.workflow_config.nodes: + self._add_node_edges(workflow, node) + + def _add_node_edges(self, workflow: StateGraph, node: NodeConfig): + if node.edges: + for edge in node.edges: + workflow.add_edge(node.name, edge) + if edge == END: + self.final_nodes.append(node.name) + elif node.conditional_edge: + routing_function = getattr(self, node.conditional_edge.routing_function) + workflow.add_conditional_edges( + node.name, routing_function, node.conditional_edge.conditions + ) + if END in node.conditional_edge.conditions: + self.final_nodes.append(node.name) + else: + raise ValueError("Node should have at least one edge or conditional_edge") + + async def answer_astream( + self, + question: str, + history: ChatHistory, + list_files: list[QuivrKnowledge], + metadata: dict[str, str] = {}, + ) -> AsyncGenerator[ParsedRAGChunkResponse, ParsedRAGChunkResponse]: + """ + Answer a question using the langgraph chain and yield each chunk of the answer separately. + """ + concat_list_files = format_file_list( + list_files, self.retrieval_config.max_files + ) + conversational_qa_chain = self.build_chain() + + rolling_message = AIMessageChunk(content="") + docs: list[Document] | None = None + previous_content = "" + + async for event in conversational_qa_chain.astream_events( + { + "messages": [("user", question)], + "chat_history": history, + "files": concat_list_files, + }, + version="v1", + config={"metadata": metadata}, + ): + if self._is_final_node_with_docs(event): + docs = event["data"]["output"]["docs"] + + if self._is_final_node_and_chat_model_stream(event): + chunk = event["data"]["chunk"] + rolling_message, new_content, previous_content = parse_chunk_response( + rolling_message, + chunk, + self.llm_endpoint.supports_func_calling(), + previous_content, + ) + + if new_content: + chunk_metadata = get_chunk_metadata(rolling_message, docs) + yield ParsedRAGChunkResponse( + answer=new_content, metadata=chunk_metadata + ) + + # Yield final metadata chunk + yield ParsedRAGChunkResponse( + answer="", + metadata=get_chunk_metadata(rolling_message, docs), + last_chunk=True, + ) + + def _is_final_node_with_docs(self, event: dict) -> bool: + return ( + "output" in event["data"] + and event["data"]["output"] is not None + and "docs" in event["data"]["output"] + and event["metadata"]["langgraph_node"] in self.final_nodes + ) + + def _is_final_node_and_chat_model_stream(self, event: dict) -> bool: + return ( + event["event"] == "on_chat_model_stream" + and "langgraph_node" in event["metadata"] + and event["metadata"]["langgraph_node"] in self.final_nodes + ) + + def invoke_structured_output( + self, prompt: str, output_class: Type[BaseModel] + ) -> Any: + try: + structured_llm = self.llm_endpoint._llm.with_structured_output( + output_class, method="json_schema" + ) + return structured_llm.invoke(prompt) + except openai.BadRequestError: + structured_llm = self.llm_endpoint._llm.with_structured_output(output_class) + return structured_llm.invoke(prompt) + + def _build_rag_prompt_inputs( + self, state: AgentState, docs: List[Document] | None + ) 
-> Dict[str, Any]: + """Build the input dictionary for RAG_ANSWER_PROMPT. + + Args: + state: Current agent state + docs: List of documents or None + + Returns: + Dictionary containing all inputs needed for RAG_ANSWER_PROMPT + """ + messages = state["messages"] + user_question = messages[0].content + files = state["files"] + prompt = self.retrieval_config.prompt + available_tools, _ = collect_tools(self.retrieval_config.workflow_config) + + return { + "context": combine_documents(docs) if docs else "None", + "question": user_question, + "rephrased_task": state["tasks"], + "custom_instructions": prompt if prompt else "None", + "files": files if files else "None", + "chat_history": state["chat_history"].to_list(), + "reasoning": state["reasoning"] if "reasoning" in state else "None", + "tools": available_tools, + } diff --git a/core/quivr_core/rag/utils.py b/core/quivr_core/rag/utils.py new file mode 100644 index 000000000..edf66a4a7 --- /dev/null +++ b/core/quivr_core/rag/utils.py @@ -0,0 +1,188 @@ +import logging +from typing import Any, List, Tuple, no_type_check + +from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage +from langchain_core.messages.ai import AIMessageChunk +from langchain_core.prompts import format_document + +from quivr_core.rag.entities.config import WorkflowConfig +from quivr_core.rag.entities.models import ( + ChatLLMMetadata, + ParsedRAGResponse, + QuivrKnowledge, + RAGResponseMetadata, + RawRAGResponse, +) +from quivr_core.rag.prompts import custom_prompts + +# TODO(@aminediro): define a types packages where we clearly define IO types +# This should be used for serialization/deseriallization later + + +logger = logging.getLogger("quivr_core") + + +def model_supports_function_calling(model_name: str): + models_not_supporting_function_calls: list[str] = ["llama2", "test", "ollama3"] + + return model_name not in models_not_supporting_function_calls + + +def format_history_to_openai_mesages( + tuple_history: List[Tuple[str, str]], system_message: str, question: str +) -> List[BaseMessage]: + """Format the chat history into a list of Base Messages""" + messages = [] + messages.append(SystemMessage(content=system_message)) + for human, ai in tuple_history: + messages.append(HumanMessage(content=human)) + messages.append(AIMessage(content=ai)) + messages.append(HumanMessage(content=question)) + return messages + + +def cited_answer_filter(tool): + return tool["name"] == "cited_answer" + + +def get_chunk_metadata( + msg: AIMessageChunk, sources: list[Any] | None = None +) -> RAGResponseMetadata: + metadata = {"sources": sources or []} + + if not msg.tool_calls: + return RAGResponseMetadata(**metadata, metadata_model=None) + + all_citations = [] + all_followup_questions = [] + + for tool_call in msg.tool_calls: + if tool_call.get("name") == "cited_answer" and "args" in tool_call: + args = tool_call["args"] + all_citations.extend(args.get("citations", [])) + all_followup_questions.extend(args.get("followup_questions", [])) + + metadata["citations"] = all_citations + metadata["followup_questions"] = all_followup_questions[:3] # Limit to 3 + + return RAGResponseMetadata(**metadata, metadata_model=None) + + +def get_prev_message_str(msg: AIMessageChunk) -> str: + if msg.tool_calls: + cited_answer = next(x for x in msg.tool_calls if cited_answer_filter(x)) + if "args" in cited_answer and "answer" in cited_answer["args"]: + return cited_answer["args"]["answer"] + return "" + + +# TODO: CONVOLUTED LOGIC ! 
+# TODO(@aminediro): redo this +@no_type_check +def parse_chunk_response( + rolling_msg: AIMessageChunk, + raw_chunk: AIMessageChunk, + supports_func_calling: bool, + previous_content: str = "", +) -> Tuple[AIMessageChunk, str, str]: + """Parse a chunk response + Args: + rolling_msg: The accumulated message so far + raw_chunk: The new chunk to add + supports_func_calling: Whether function calling is supported + previous_content: The previous content string + Returns: + Tuple of (updated rolling message, new content only, full content) + """ + rolling_msg += raw_chunk + + if not supports_func_calling or not rolling_msg.tool_calls: + new_content = raw_chunk.content # Just the new chunk's content + full_content = rolling_msg.content # The full accumulated content + return rolling_msg, new_content, full_content + + current_answers = get_answers_from_tool_calls(rolling_msg.tool_calls) + full_answer = "\n\n".join(current_answers) + new_content = full_answer[len(previous_content) :] + + return rolling_msg, new_content, full_answer + + +def get_answers_from_tool_calls(tool_calls): + answers = [] + for tool_call in tool_calls: + if tool_call.get("name") == "cited_answer" and "args" in tool_call: + answers.append(tool_call["args"].get("answer", "")) + return answers + + +@no_type_check +def parse_response(raw_response: RawRAGResponse, model_name: str) -> ParsedRAGResponse: + answers = [] + sources = raw_response["docs"] if "docs" in raw_response else [] + + metadata = RAGResponseMetadata( + sources=sources, metadata_model=ChatLLMMetadata(name=model_name) + ) + + if ( + model_supports_function_calling(model_name) + and "tool_calls" in raw_response["answer"] + and raw_response["answer"].tool_calls + ): + all_citations = [] + all_followup_questions = [] + for tool_call in raw_response["answer"].tool_calls: + if "args" in tool_call: + args = tool_call["args"] + if "citations" in args: + all_citations.extend(args["citations"]) + if "followup_questions" in args: + all_followup_questions.extend(args["followup_questions"]) + if "answer" in args: + answers.append(args["answer"]) + metadata.citations = all_citations + metadata.followup_questions = all_followup_questions + else: + answers.append(raw_response["answer"].content) + + answer_str = "\n".join(answers) + parsed_response = ParsedRAGResponse(answer=answer_str, metadata=metadata) + return parsed_response + + +def combine_documents( + docs, + document_prompt=custom_prompts.DEFAULT_DOCUMENT_PROMPT, + document_separator="\n\n", +): + # for each docs, add an index in the metadata to be able to cite the sources + for doc, index in zip(docs, range(len(docs)), strict=False): + doc.metadata["index"] = index + doc_strings = [format_document(doc, document_prompt) for doc in docs] + return document_separator.join(doc_strings) + + +def format_file_list( + list_files_array: list[QuivrKnowledge], max_files: int = 20 +) -> str: + list_files = [file.file_name or file.url for file in list_files_array] + files: list[str] = list(filter(lambda n: n is not None, list_files)) # type: ignore + files = files[:max_files] + + files_str = "\n".join(files) if list_files_array else "None" + return files_str + + +def collect_tools(workflow_config: WorkflowConfig): + validated_tools = "Available tools which can be activated:\n" + for i, tool in enumerate(workflow_config.validated_tools): + validated_tools += f"Tool {i+1} name: {tool.name}\n" + validated_tools += f"Tool {i+1} description: {tool.description}\n\n" + + activated_tools = "Activated tools which can be deactivated:\n" + for 
i, tool in enumerate(workflow_config.activated_tools): + activated_tools += f"Tool {i+1} name: {tool.name}\n" + activated_tools += f"Tool {i+1} description: {tool.description}\n\n" + + return validated_tools, activated_tools diff --git a/core/quivr_core/utils.py b/core/quivr_core/utils.py deleted file mode 100644 index 3239db104..000000000 --- a/core/quivr_core/utils.py +++ /dev/null @@ -1,167 +0,0 @@ -import logging -from typing import Any, List, Tuple, no_type_check - -from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage -from langchain_core.messages.ai import AIMessageChunk -from langchain_core.prompts import format_document - -from quivr_core.models import ( - ChatLLMMetadata, - ParsedRAGResponse, - QuivrKnowledge, - RAGResponseMetadata, - RawRAGResponse, -) -from quivr_core.prompts import custom_prompts - -# TODO(@aminediro): define a types packages where we clearly define IO types -# This should be used for serialization/deseriallization later - - -logger = logging.getLogger("quivr_core") - - -def model_supports_function_calling(model_name: str): - models_supporting_function_calls = [ - "gpt-4", - "gpt-4-1106-preview", - "gpt-4-0613", - "gpt-4o", - "gpt-3.5-turbo-1106", - "gpt-3.5-turbo-0613", - "gpt-4-0125-preview", - "gpt-3.5-turbo", - "gpt-4-turbo", - "gpt-4o", - "gpt-4o-mini", - ] - return model_name in models_supporting_function_calls - - -def format_history_to_openai_mesages( - tuple_history: List[Tuple[str, str]], system_message: str, question: str -) -> List[BaseMessage]: - """Format the chat history into a list of Base Messages""" - messages = [] - messages.append(SystemMessage(content=system_message)) - for human, ai in tuple_history: - messages.append(HumanMessage(content=human)) - messages.append(AIMessage(content=ai)) - messages.append(HumanMessage(content=question)) - return messages - - -def cited_answer_filter(tool): - return tool["name"] == "cited_answer" - - -def get_chunk_metadata( - msg: AIMessageChunk, sources: list[Any] | None = None -) -> RAGResponseMetadata: - # Initiate the source - metadata = {"sources": sources} if sources else {"sources": []} - if msg.tool_calls: - cited_answer = next(x for x in msg.tool_calls if cited_answer_filter(x)) - - if "args" in cited_answer: - gathered_args = cited_answer["args"] - if "citations" in gathered_args: - citations = gathered_args["citations"] - metadata["citations"] = citations - - if "followup_questions" in gathered_args: - followup_questions = gathered_args["followup_questions"] - metadata["followup_questions"] = followup_questions - - return RAGResponseMetadata(**metadata, metadata_model=None) - - -def get_prev_message_str(msg: AIMessageChunk) -> str: - if msg.tool_calls: - cited_answer = next(x for x in msg.tool_calls if cited_answer_filter(x)) - if "args" in cited_answer and "answer" in cited_answer["args"]: - return cited_answer["args"]["answer"] - return "" - - -# TODO: CONVOLUTED LOGIC ! 
-# TODO(@aminediro): redo this -@no_type_check -def parse_chunk_response( - rolling_msg: AIMessageChunk, - raw_chunk: dict[str, Any], - supports_func_calling: bool, -) -> Tuple[AIMessageChunk, str]: - # Init with sources - answer_str = "" - - if "answer" in raw_chunk: - answer = raw_chunk["answer"] - else: - answer = raw_chunk - - rolling_msg += answer - if supports_func_calling and rolling_msg.tool_calls: - cited_answer = next(x for x in rolling_msg.tool_calls if cited_answer_filter(x)) - if "args" in cited_answer and "answer" in cited_answer["args"]: - gathered_args = cited_answer["args"] - # Only send the difference between answer and response_tokens which was the previous answer - answer_str = gathered_args["answer"] - return rolling_msg, answer_str - - return rolling_msg, answer.content - - -@no_type_check -def parse_response(raw_response: RawRAGResponse, model_name: str) -> ParsedRAGResponse: - answer = "" - sources = raw_response["docs"] if "docs" in raw_response else [] - - metadata = RAGResponseMetadata( - sources=sources, metadata_model=ChatLLMMetadata(name=model_name) - ) - - if ( - model_supports_function_calling(model_name) - and "tool_calls" in raw_response["answer"] - and raw_response["answer"].tool_calls - ): - if "citations" in raw_response["answer"].tool_calls[-1]["args"]: - citations = raw_response["answer"].tool_calls[-1]["args"]["citations"] - metadata.citations = citations - followup_questions = raw_response["answer"].tool_calls[-1]["args"][ - "followup_questions" - ] - if followup_questions: - metadata.followup_questions = followup_questions - answer = raw_response["answer"].tool_calls[-1]["args"]["answer"] - else: - answer = raw_response["answer"].tool_calls[-1]["args"]["answer"] - else: - answer = raw_response["answer"].content - - parsed_response = ParsedRAGResponse(answer=answer, metadata=metadata) - return parsed_response - - -def combine_documents( - docs, - document_prompt=custom_prompts.DEFAULT_DOCUMENT_PROMPT, - document_separator="\n\n", -): - # for each docs, add an index in the metadata to be able to cite the sources - for doc, index in zip(docs, range(len(docs)), strict=False): - doc.metadata["index"] = index - doc_strings = [format_document(doc, document_prompt) for doc in docs] - return document_separator.join(doc_strings) - - -def format_file_list( - list_files_array: list[QuivrKnowledge], max_files: int = 20 -) -> str: - list_files = [file.file_name or file.url for file in list_files_array] - files: list[str] = list(filter(lambda n: n is not None, list_files)) # type: ignore - files = files[:max_files] - - files_str = "\n".join(files) if list_files_array else "None" - return files_str diff --git a/core/requirements-dev.lock b/core/requirements-dev.lock index a481a489c..0492b8e7d 100644 --- a/core/requirements-dev.lock +++ b/core/requirements-dev.lock @@ -27,6 +27,8 @@ anyio==4.6.2.post1 # via anthropic # via httpx # via openai +appnope==0.1.4 + # via ipykernel asttokens==2.4.1 # via stack-data attrs==24.2.0 @@ -80,8 +82,6 @@ frozenlist==1.4.1 # via aiosignal fsspec==2024.9.0 # via huggingface-hub -greenlet==3.1.1 - # via sqlalchemy h11==0.14.0 # via httpcore httpcore==1.0.6 @@ -278,6 +278,8 @@ pyyaml==6.0.2 pyzmq==26.2.0 # via ipykernel # via jupyter-client +rapidfuzz==3.10.1 + # via quivr-core regex==2024.9.11 # via tiktoken # via transformers diff --git a/core/requirements.lock b/core/requirements.lock index e82754528..c0cc9af4a 100644 --- a/core/requirements.lock +++ b/core/requirements.lock @@ -56,8 +56,6 @@ frozenlist==1.4.1 # via aiosignal 
fsspec==2024.9.0 # via huggingface-hub -greenlet==3.1.1 - # via sqlalchemy h11==0.14.0 # via httpcore httpcore==1.0.6 @@ -185,6 +183,8 @@ pyyaml==6.0.2 # via langchain-community # via langchain-core # via transformers +rapidfuzz==3.10.1 + # via quivr-core regex==2024.9.11 # via tiktoken # via transformers diff --git a/core/tests/conftest.py b/core/tests/conftest.py index a6e262e77..59d09f2ff 100644 --- a/core/tests/conftest.py +++ b/core/tests/conftest.py @@ -9,7 +9,7 @@ from langchain_core.language_models import FakeListChatModel from langchain_core.messages.ai import AIMessageChunk from langchain_core.runnables.utils import AddableDict from langchain_core.vectorstores import InMemoryVectorStore -from quivr_core.config import LLMEndpointConfig +from quivr_core.rag.entities.config import LLMEndpointConfig from quivr_core.files.file import FileExtension, QuivrFile from quivr_core.llm import LLMEndpoint diff --git a/core/tests/fixture_chunks.py b/core/tests/fixture_chunks.py index ae521f6ee..06e036669 100644 --- a/core/tests/fixture_chunks.py +++ b/core/tests/fixture_chunks.py @@ -5,10 +5,10 @@ from uuid import uuid4 from langchain_core.embeddings import DeterministicFakeEmbedding from langchain_core.messages.ai import AIMessageChunk from langchain_core.vectorstores import InMemoryVectorStore -from quivr_core.chat import ChatHistory -from quivr_core.config import LLMEndpointConfig, RetrievalConfig +from quivr_core.rag.entities.chat import ChatHistory +from quivr_core.rag.entities.config import LLMEndpointConfig, RetrievalConfig from quivr_core.llm import LLMEndpoint -from quivr_core.quivr_rag_langgraph import QuivrQARAGLangGraph +from quivr_core.rag.quivr_rag_langgraph import QuivrQARAGLangGraph async def main(): diff --git a/core/tests/rag_config.yaml b/core/tests/rag_config.yaml index 3c43d0cba..9cc68a7ea 100644 --- a/core/tests/rag_config.yaml +++ b/core/tests/rag_config.yaml @@ -27,9 +27,9 @@ retrieval_config: supplier: "openai" # The model to use for the LLM for the given supplier - model: "gpt-4o" + model: "gpt-3.5-turbo-0125" - max_input_tokens: 2000 + max_context_tokens: 2000 # Maximum number of tokens to pass to the LLM # as a context to generate the answer diff --git a/core/tests/rag_config_workflow.yaml b/core/tests/rag_config_workflow.yaml index f566299cb..4a346bc45 100644 --- a/core/tests/rag_config_workflow.yaml +++ b/core/tests/rag_config_workflow.yaml @@ -40,9 +40,9 @@ retrieval_config: supplier: "openai" # The model to use for the LLM for the given supplier - model: "gpt-4o" + model: "gpt-3.5-turbo-0125" - max_input_tokens: 2000 + max_context_tokens: 2000 # Maximum number of tokens to pass to the LLM # as a context to generate the answer diff --git a/core/tests/test_brain.py b/core/tests/test_brain.py index 367df9e07..8421eefc9 100644 --- a/core/tests/test_brain.py +++ b/core/tests/test_brain.py @@ -5,7 +5,7 @@ import pytest from langchain_core.documents import Document from langchain_core.embeddings import Embeddings from quivr_core.brain import Brain -from quivr_core.chat import ChatHistory +from quivr_core.rag.entities.chat import ChatHistory from quivr_core.llm import LLMEndpoint from quivr_core.storage.local_storage import TransparentStorage @@ -104,8 +104,8 @@ async def test_brain_get_history( vector_db=mem_vector_store, ) - brain.ask("question") - brain.ask("question") + await brain.aask("question") + await brain.aask("question") assert len(brain.default_chat) == 4 diff --git a/core/tests/test_chat_history.py b/core/tests/test_chat_history.py index 
b5af198a6..1df884e3c 100644 --- a/core/tests/test_chat_history.py +++ b/core/tests/test_chat_history.py @@ -3,7 +3,7 @@ from uuid import uuid4 import pytest from langchain_core.messages import AIMessage, HumanMessage -from quivr_core.chat import ChatHistory +from quivr_core.rag.entities.chat import ChatHistory @pytest.fixture diff --git a/core/tests/test_config.py b/core/tests/test_config.py index c7d2320cb..15ca48536 100644 --- a/core/tests/test_config.py +++ b/core/tests/test_config.py @@ -1,18 +1,21 @@ -from quivr_core.config import LLMEndpointConfig, RetrievalConfig +from quivr_core.rag.entities.config import LLMEndpointConfig, RetrievalConfig def test_default_llm_config(): config = LLMEndpointConfig() - assert config.model_dump(exclude={"llm_api_key"}) == LLMEndpointConfig( - model="gpt-4o", - llm_base_url=None, - llm_api_key=None, - max_input_tokens=2000, - max_output_tokens=2000, - temperature=0.7, - streaming=True, - ).model_dump(exclude={"llm_api_key"}) + assert ( + config.model_dump() + == LLMEndpointConfig( + model="gpt-4o", + llm_base_url=None, + llm_api_key=None, + max_context_tokens=2000, + max_output_tokens=2000, + temperature=0.7, + streaming=True, + ).model_dump() + ) def test_default_retrievalconfig(): @@ -20,6 +23,6 @@ def test_default_retrievalconfig(): assert config.max_files == 20 assert config.prompt is None - assert config.llm_config.model_dump( - exclude={"llm_api_key"} - ) == LLMEndpointConfig().model_dump(exclude={"llm_api_key"}) + print("\n\n", config.llm_config, "\n\n") + print("\n\n", LLMEndpointConfig(), "\n\n") + assert config.llm_config == LLMEndpointConfig() diff --git a/core/tests/test_llm_endpoint.py b/core/tests/test_llm_endpoint.py index 04c5556aa..b43fc2a03 100644 --- a/core/tests/test_llm_endpoint.py +++ b/core/tests/test_llm_endpoint.py @@ -3,7 +3,7 @@ import os import pytest from langchain_core.language_models import FakeListChatModel from pydantic.v1.error_wrappers import ValidationError -from quivr_core.config import LLMEndpointConfig +from quivr_core.rag.entities.config import LLMEndpointConfig from quivr_core.llm import LLMEndpoint diff --git a/core/tests/test_quivr_rag.py b/core/tests/test_quivr_rag.py index 629d808b8..f6184bf16 100644 --- a/core/tests/test_quivr_rag.py +++ b/core/tests/test_quivr_rag.py @@ -1,25 +1,51 @@ from uuid import uuid4 import pytest -from quivr_core.chat import ChatHistory -from quivr_core.config import LLMEndpointConfig, RetrievalConfig +from quivr_core.rag.entities.chat import ChatHistory +from quivr_core.rag.entities.config import LLMEndpointConfig, RetrievalConfig from quivr_core.llm import LLMEndpoint -from quivr_core.models import ParsedRAGChunkResponse, RAGResponseMetadata -from quivr_core.quivr_rag_langgraph import QuivrQARAGLangGraph +from quivr_core.rag.entities.models import ParsedRAGChunkResponse, RAGResponseMetadata +from quivr_core.rag.quivr_rag_langgraph import QuivrQARAGLangGraph @pytest.fixture(scope="function") def mock_chain_qa_stream(monkeypatch, chunks_stream_answer): class MockQAChain: async def astream_events(self, *args, **kwargs): - for c in chunks_stream_answer: + default_metadata = { + "langgraph_node": "generate", + "is_final_node": False, + "citations": None, + "followup_questions": None, + "sources": None, + "metadata_model": None, + } + + # Send all chunks except the last one + for chunk in chunks_stream_answer[:-1]: yield { "event": "on_chat_model_stream", - "metadata": {"langgraph_node": "generate"}, - "data": {"chunk": c}, + "metadata": default_metadata, + "data": {"chunk": 
chunk["answer"]}, } + # Send the last chunk + yield { + "event": "end", + "metadata": { + "langgraph_node": "generate", + "is_final_node": True, + "citations": [], + "followup_questions": None, + "sources": [], + "metadata_model": None, + }, + "data": {"chunk": chunks_stream_answer[-1]["answer"]}, + } + def mock_qa_chain(*args, **kwargs): + self = args[0] + self.final_nodes = ["generate"] return MockQAChain() monkeypatch.setattr(QuivrQARAGLangGraph, "build_chain", mock_qa_chain) @@ -48,11 +74,13 @@ async def test_quivrqaraglanggraph( ): stream_responses.append(resp) + # This assertion passed assert all( not r.last_chunk for r in stream_responses[:-1] ), "Some chunks before last have last_chunk=True" assert stream_responses[-1].last_chunk + # Let's check this assertion for idx, response in enumerate(stream_responses[1:-1]): assert ( len(response.answer) > 0 diff --git a/core/tests/test_utils.py b/core/tests/test_utils.py index 7847f94e1..0fb384c72 100644 --- a/core/tests/test_utils.py +++ b/core/tests/test_utils.py @@ -3,7 +3,7 @@ from uuid import uuid4 import pytest from langchain_core.messages.ai import AIMessageChunk from langchain_core.messages.tool import ToolCall -from quivr_core.utils import ( +from quivr_core.rag.utils import ( get_prev_message_str, model_supports_function_calling, parse_chunk_response, @@ -43,13 +43,9 @@ def test_get_prev_message_str(): def test_parse_chunk_response_nofunc_calling(): rolling_msg = AIMessageChunk(content="") - chunk = { - "answer": AIMessageChunk( - content="next ", - ) - } + chunk = AIMessageChunk(content="next ") for i in range(10): - rolling_msg, parsed_chunk = parse_chunk_response(rolling_msg, chunk, False) + rolling_msg, parsed_chunk, _ = parse_chunk_response(rolling_msg, chunk, False) assert rolling_msg.content == "next " * (i + 1) assert parsed_chunk == "next " @@ -70,8 +66,9 @@ def test_parse_chunk_response_func_calling(chunks_stream_answer): answer_str_history: list[str] = [] for chunk in chunks_stream_answer: - # This is done - rolling_msg, answer_str = parse_chunk_response(rolling_msg, chunk, True) + # Extract the AIMessageChunk from the chunk dictionary + chunk_msg = chunk["answer"] # Get the AIMessageChunk from the dict + rolling_msg, answer_str, _ = parse_chunk_response(rolling_msg, chunk_msg, True) rolling_msgs_history.append(rolling_msg) answer_str_history.append(answer_str) diff --git a/examples/pdf_document_from_yaml.py b/examples/pdf_document_from_yaml.py index 02406931f..2960c8293 100644 --- a/examples/pdf_document_from_yaml.py +++ b/examples/pdf_document_from_yaml.py @@ -5,7 +5,7 @@ from pathlib import Path import dotenv from quivr_core import Brain -from quivr_core.config import AssistantConfig +from quivr_core.rag.entities.config import AssistantConfig from rich.traceback import install as rich_install ConsoleOutputHandler = logging.StreamHandler() diff --git a/examples/pdf_parsing_tika.py b/examples/pdf_parsing_tika.py index b86a232a2..b3f576ced 100644 --- a/examples/pdf_parsing_tika.py +++ b/examples/pdf_parsing_tika.py @@ -1,7 +1,7 @@ from langchain_core.embeddings import DeterministicFakeEmbedding from langchain_core.language_models import FakeListChatModel from quivr_core import Brain -from quivr_core.config import LLMEndpointConfig +from quivr_core.rag.entities.config import LLMEndpointConfig from quivr_core.llm.llm_endpoint import LLMEndpoint from rich.console import Console from rich.panel import Panel diff --git a/examples/simple_question/pyproject.toml b/examples/simple_question/pyproject.toml index 
ce7161456..1b3287a8f 100644 --- a/examples/simple_question/pyproject.toml +++ b/examples/simple_question/pyproject.toml @@ -7,6 +7,7 @@ authors = [ ] dependencies = [ "quivr-core @ file:///${PROJECT_ROOT}/../../core", + "python-dotenv>=1.0.1", ] readme = "README.md" requires-python = ">= 3.11" diff --git a/examples/simple_question/requirements-dev.lock b/examples/simple_question/requirements-dev.lock index 2083a1da6..07a7448f1 100644 --- a/examples/simple_question/requirements-dev.lock +++ b/examples/simple_question/requirements-dev.lock @@ -55,8 +55,6 @@ frozenlist==1.4.1 # via aiosignal fsspec==2024.10.0 # via huggingface-hub -greenlet==3.1.1 - # via sqlalchemy h11==0.14.0 # via httpcore httpcore==1.0.6 @@ -169,6 +167,7 @@ pydantic==2.9.2 # via langsmith # via openai # via quivr-core + # via sqlmodel pydantic-core==2.23.4 # via cohere # via pydantic @@ -176,6 +175,8 @@ pygments==2.18.0 # via rich python-dateutil==2.9.0.post0 # via pandas +python-dotenv==1.0.1 + # via quivr-core pytz==2024.2 # via pandas pyyaml==6.0.2 @@ -185,6 +186,8 @@ pyyaml==6.0.2 # via langchain-core # via transformers quivr-core @ file:///${PROJECT_ROOT}/../../core +rapidfuzz==3.10.1 + # via quivr-core regex==2024.9.11 # via tiktoken # via transformers @@ -215,6 +218,9 @@ sniffio==1.3.1 sqlalchemy==2.0.36 # via langchain # via langchain-community + # via sqlmodel +sqlmodel==0.0.22 + # via quivr-core tabulate==0.9.0 # via langchain-cohere tenacity==8.5.0 diff --git a/examples/simple_question/requirements.lock b/examples/simple_question/requirements.lock index 2083a1da6..07a7448f1 100644 --- a/examples/simple_question/requirements.lock +++ b/examples/simple_question/requirements.lock @@ -55,8 +55,6 @@ frozenlist==1.4.1 # via aiosignal fsspec==2024.10.0 # via huggingface-hub -greenlet==3.1.1 - # via sqlalchemy h11==0.14.0 # via httpcore httpcore==1.0.6 @@ -169,6 +167,7 @@ pydantic==2.9.2 # via langsmith # via openai # via quivr-core + # via sqlmodel pydantic-core==2.23.4 # via cohere # via pydantic @@ -176,6 +175,8 @@ pygments==2.18.0 # via rich python-dateutil==2.9.0.post0 # via pandas +python-dotenv==1.0.1 + # via quivr-core pytz==2024.2 # via pandas pyyaml==6.0.2 @@ -185,6 +186,8 @@ pyyaml==6.0.2 # via langchain-core # via transformers quivr-core @ file:///${PROJECT_ROOT}/../../core +rapidfuzz==3.10.1 + # via quivr-core regex==2024.9.11 # via tiktoken # via transformers @@ -215,6 +218,9 @@ sniffio==1.3.1 sqlalchemy==2.0.36 # via langchain # via langchain-community + # via sqlmodel +sqlmodel==0.0.22 + # via quivr-core tabulate==0.9.0 # via langchain-cohere tenacity==8.5.0 diff --git a/examples/simple_question/simple_question.py b/examples/simple_question/simple_question.py index 4067635e9..a6a570d82 100644 --- a/examples/simple_question/simple_question.py +++ b/examples/simple_question/simple_question.py @@ -2,6 +2,10 @@ import tempfile from quivr_core import Brain +import dotenv + +dotenv.load_dotenv() + if __name__ == "__main__": with tempfile.NamedTemporaryFile(mode="w", suffix=".txt") as temp_file: temp_file.write("Gold is a liquid of blue-like colour.") @@ -12,7 +16,5 @@ if __name__ == "__main__": file_paths=[temp_file.name], ) - answer = brain.ask( - "what is gold? asnwer in french" - ) - print("answer QuivrQARAGLangGraph :", answer.answer) + answer = brain.ask("what is gold? 
answer in french") + print("answer QuivrQARAGLangGraph :", answer) diff --git a/examples/simple_question_streaming.py b/examples/simple_question/simple_question_streaming.py similarity index 93% rename from examples/simple_question_streaming.py rename to examples/simple_question/simple_question_streaming.py index acd75880c..15e37b42e 100644 --- a/examples/simple_question_streaming.py +++ b/examples/simple_question/simple_question_streaming.py @@ -4,7 +4,7 @@ import tempfile from dotenv import load_dotenv from quivr_core import Brain from quivr_core.quivr_rag import QuivrQARAG -from quivr_core.quivr_rag_langgraph import QuivrQARAGLangGraph +from quivr_core.rag.quivr_rag_langgraph import QuivrQARAGLangGraph async def main():