feat: websearch, tool use, user intent, dynamic retrieval, multiple questions (#3424)

# Description

This PR includes far too many new features:

- detection of user intent (closes CORE-211)
- parallel handling of multiple questions (closes CORE-212)
- use of the chat history when answering a question (closes CORE-213)
- filtering of retrieved chunks by a relevance threshold (closes CORE-217)
- dynamic retrieval of chunks (closes CORE-218)
- web search via Tavily (closes CORE-220)
- letting the agent / assistant activate tools when relevant to completing
  the user task (closes CORE-224)

Also closes CORE-205
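
A minimal usage sketch of the new entry points listed above (not part of the diff; it assumes `quivr_core` still re-exports `Brain`, requires `OPENAI_API_KEY` for the default LLM, and web search via Tavily additionally expects `TAVILY_API_KEY`):

```python
# Hedged sketch only: exercises the new async aask() entry point added in this PR.
import asyncio

from quivr_core import Brain


async def main() -> None:
    # Build a brain from local files, then ask it a question.
    brain = await Brain.afrom_files(
        name="My Brain", file_paths=["file1.pdf", "file2.pdf"]
    )
    answer = await brain.aask("What is the meaning of life?")
    print(answer.answer)


asyncio.run(main())
```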

## Checklist before requesting a review

Please delete options that are not relevant.

- [ ] My code follows the style guidelines of this project
- [ ] I have performed a self-review of my code
- [ ] I have commented hard-to-understand areas
- [ ] I have ideally added tests that prove my fix is effective or that
my feature works
- [ ] New and existing unit tests pass locally with my changes
- [ ] Any dependent changes have been merged

## Screenshots (if appropriate):

---------

Co-authored-by: Stan Girard <stan@quivr.app>
Authored by Jacopo Chevallard on 2024-10-31 17:57:54 +01:00, committed by GitHub
parent 5401c01ee2
commit 285fe5b960
43 changed files with 2165 additions and 1452 deletions

View File

@@ -41,6 +41,4 @@ jobs:
          sudo apt-get update
          sudo apt-get install -y libmagic-dev poppler-utils libreoffice tesseract-ocr pandoc
          cd core
-          rye run python -c "from unstructured.nlp.tokenize import download_nltk_packages; download_nltk_packages()"
-          rye run python -c "import nltk;nltk.download('punkt_tab'); nltk.download('averaged_perceptron_tagger_eng')"
          rye test -p quivr-core

View File

@@ -9,7 +9,7 @@ dependencies = [
    "pydantic>=2.8.2",
    "langchain-core>=0.2.38",
    "langchain>=0.2.14,<0.3.0",
-    "langgraph>=0.2.14",
+    "langgraph>=0.2.38",
    "httpx>=0.27.0",
    "rich>=13.7.1",
    "tiktoken>=0.7.0",
@@ -21,6 +21,7 @@ dependencies = [
    "types-pyyaml>=6.0.12.20240808",
    "transformers[sentencepiece]>=4.44.2",
    "faiss-cpu>=1.8.0.post1",
+    "rapidfuzz>=3.10.1",
]
readme = "README.md"
requires-python = ">= 3.11"

View File

@@ -10,7 +10,9 @@ from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.vectorstores import VectorStore
+from quivr_core.rag.entities.models import ParsedRAGResponse
from langchain_openai import OpenAIEmbeddings
+from quivr_core.rag.quivr_rag import QuivrQARAG
from rich.console import Console
from rich.panel import Panel
@@ -22,19 +24,17 @@ from quivr_core.brain.serialization import (
    LocalStorageConfig,
    TransparentStorageConfig,
)
-from quivr_core.chat import ChatHistory
+from quivr_core.rag.entities.chat import ChatHistory
-from quivr_core.config import RetrievalConfig
+from quivr_core.rag.entities.config import RetrievalConfig
from quivr_core.files.file import load_qfile
from quivr_core.llm import LLMEndpoint
-from quivr_core.models import (
+from quivr_core.rag.entities.models import (
    ParsedRAGChunkResponse,
-    ParsedRAGResponse,
    QuivrKnowledge,
    SearchResult,
)
from quivr_core.processor.registry import get_processor_class
-from quivr_core.quivr_rag import QuivrQARAG
-from quivr_core.quivr_rag_langgraph import QuivrQARAGLangGraph
+from quivr_core.rag.quivr_rag_langgraph import QuivrQARAGLangGraph
from quivr_core.storage.local_storage import LocalStorage, TransparentStorage
from quivr_core.storage.storage_base import StorageBase
@@ -49,19 +49,15 @@ async def process_files(
    """
    Process files in storage.
    This function takes a StorageBase and return a list of langchain documents.
    Args:
        storage (StorageBase): The storage containing the files to process.
        skip_file_error (bool): Whether to skip files that cannot be processed.
        processor_kwargs (dict[str, Any]): Additional arguments for the processor.
    Returns:
        list[Document]: List of processed documents in the Langchain Document format.
    Raises:
        ValueError: If a file cannot be processed and skip_file_error is False.
        Exception: If no processor is found for a file of a specific type and skip_file_error is False.
    """
    knowledge = []
@@ -91,23 +87,17 @@ async def process_files(
class Brain:
    """
    A class representing a Brain.
    This class allows for the creation of a Brain, which is a collection of knowledge one wants to retrieve information from.
    A Brain is set to:
    * Store files in the storage of your choice (local, S3, etc.)
    * Process the files in the storage to extract text and metadata in a wide range of format.
    * Store the processed files in the vector store of your choice (FAISS, PGVector, etc.) - default to FAISS.
    * Create an index of the processed files.
    * Use the *Quivr* workflow for the retrieval augmented generation.
    A Brain is able to:
    * Search for information in the vector store.
    * Answer questions about the knowledges in the Brain.
    * Stream the answer to the question.
    Attributes:
        name (str): The name of the brain.
        id (UUID): The unique identifier of the brain.
@@ -115,16 +105,14 @@ class Brain:
        llm (LLMEndpoint): The language model used to generate the answer.
        vector_db (VectorStore): The vector store used to store the processed files.
        embedder (Embeddings): The embeddings used to create the index of the processed files.
    """
    def __init__(
        self,
        *,
        name: str,
-        id: UUID,
        llm: LLMEndpoint,
+        id: UUID | None = None,
        vector_db: VectorStore | None = None,
        embedder: Embeddings | None = None,
        storage: StorageBase | None = None,
@@ -156,19 +144,15 @@ class Brain:
    def load(cls, folder_path: str | Path) -> Self:
        """
        Load a brain from a folder path.
        Args:
            folder_path (str | Path): The path to the folder containing the brain.
        Returns:
            Brain: The brain loaded from the folder path.
        Example:
        ```python
        brain_loaded = Brain.load("path/to/brain")
        brain_loaded.print_info()
        ```
        """
        if isinstance(folder_path, str):
            folder_path = Path(folder_path)
@@ -217,16 +201,13 @@ class Brain:
            vector_db=vector_db,
        )
-    async def save(self, folder_path: str | Path) -> str:
+    async def save(self, folder_path: str | Path):
        """
        Save the brain to a folder path.
        Args:
            folder_path (str | Path): The path to the folder where the brain will be saved.
        Returns:
            str: The path to the folder where the brain was saved.
        Example:
        ```python
        await brain.save("path/to/brain")
@@ -324,10 +305,9 @@ class Brain:
        embedder: Embeddings | None = None,
        skip_file_error: bool = False,
        processor_kwargs: dict[str, Any] | None = None,
-    ) -> Self:
+    ):
        """
        Create a brain from a list of file paths.
        Args:
            name (str): The name of the brain.
            file_paths (list[str | Path]): The list of file paths to add to the brain.
@@ -337,10 +317,8 @@ class Brain:
            embedder (Embeddings | None): The embeddings used to create the index of the processed files.
            skip_file_error (bool): Whether to skip files that cannot be processed.
            processor_kwargs (dict[str, Any] | None): Additional arguments for the processor.
        Returns:
            Brain: The brain created from the file paths.
        Example:
        ```python
        brain = await Brain.afrom_files(name="My Brain", file_paths=["file1.pdf", "file2.pdf"])
@@ -429,7 +407,6 @@ class Brain:
    ) -> Self:
        """
        Create a brain from a list of langchain documents.
        Args:
            name (str): The name of the brain.
            langchain_documents (list[Document]): The list of langchain documents to add to the brain.
@@ -437,10 +414,8 @@ class Brain:
            storage (StorageBase): The storage used to store the files.
            llm (LLMEndpoint | None): The language model used to generate the answer.
            embedder (Embeddings | None): The embeddings used to create the index of the processed files.
        Returns:
            Brain: The brain created from the langchain documents.
        Example:
        ```python
        from langchain_core.documents import Document
@@ -449,6 +424,7 @@ class Brain:
        brain.print_info()
        ```
        """
        if llm is None:
            llm = default_llm()
@@ -481,16 +457,13 @@ class Brain:
    ) -> list[SearchResult]:
        """
        Search for relevant documents in the brain based on a query.
        Args:
            query (str | Document): The query to search for.
            n_results (int): The number of results to return.
            filter (Callable | Dict[str, Any] | None): The filter to apply to the search.
            fetch_n_neighbors (int): The number of neighbors to fetch.
        Returns:
            list[SearchResult]: The list of retrieved chunks.
        Example:
        ```python
        brain = Brain.from_files(name="My Brain", file_paths=["file1.pdf", "file2.pdf"])
@@ -517,57 +490,6 @@ class Brain:
        # add it to vectorstore
        raise NotImplementedError
def ask(
self,
question: str,
retrieval_config: RetrievalConfig | None = None,
rag_pipeline: Type[Union[QuivrQARAG, QuivrQARAGLangGraph]] | None = None,
list_files: list[QuivrKnowledge] | None = None,
chat_history: ChatHistory | None = None,
) -> ParsedRAGResponse:
"""
Ask a question to the brain and get a generated answer.
Args:
question (str): The question to ask.
retrieval_config (RetrievalConfig | None): The retrieval configuration (see RetrievalConfig docs).
rag_pipeline (Type[Union[QuivrQARAG, QuivrQARAGLangGraph]] | None): The RAG pipeline to use.
list_files (list[QuivrKnowledge] | None): The list of files to include in the RAG pipeline.
chat_history (ChatHistory | None): The chat history to use.
Returns:
ParsedRAGResponse: The generated answer.
Example:
```python
brain = Brain.from_files(name="My Brain", file_paths=["file1.pdf", "file2.pdf"])
answer = brain.ask("What is the meaning of life?")
print(answer.answer)
```
"""
async def collect_streamed_response():
full_answer = ""
async for response in self.ask_streaming(
question=question,
retrieval_config=retrieval_config,
rag_pipeline=rag_pipeline,
list_files=list_files,
chat_history=chat_history
):
full_answer += response.answer
return full_answer
# Run the async function in the event loop
loop = asyncio.get_event_loop()
full_answer = loop.run_until_complete(collect_streamed_response())
chat_history = self.default_chat if chat_history is None else chat_history
chat_history.append(HumanMessage(content=question))
chat_history.append(AIMessage(content=full_answer))
# Return the final response
return ParsedRAGResponse(answer=full_answer)
    async def ask_streaming(
        self,
        question: str,
@@ -578,24 +500,20 @@ class Brain:
    ) -> AsyncGenerator[ParsedRAGChunkResponse, ParsedRAGChunkResponse]:
        """
        Ask a question to the brain and get a streamed generated answer.
        Args:
            question (str): The question to ask.
            retrieval_config (RetrievalConfig | None): The retrieval configuration (see RetrievalConfig docs).
            rag_pipeline (Type[Union[QuivrQARAG, QuivrQARAGLangGraph]] | None): The RAG pipeline to use.
            list_files (list[QuivrKnowledge] | None): The list of files to include in the RAG pipeline.
            chat_history (ChatHistory | None): The chat history to use.
        Returns:
            AsyncGenerator[ParsedRAGChunkResponse, ParsedRAGChunkResponse]: The streamed generated answer.
        Example:
        ```python
        brain = Brain.from_files(name="My Brain", file_paths=["file1.pdf", "file2.pdf"])
        async for chunk in brain.ask_streaming("What is the meaning of life?"):
            print(chunk.answer)
        ```
        """
        llm = self.llm
@@ -630,3 +548,64 @@ class Brain:
        chat_history.append(AIMessage(content=full_answer))
        yield response
async def aask(
self,
question: str,
retrieval_config: RetrievalConfig | None = None,
rag_pipeline: Type[Union[QuivrQARAG, QuivrQARAGLangGraph]] | None = None,
list_files: list[QuivrKnowledge] | None = None,
chat_history: ChatHistory | None = None,
) -> ParsedRAGResponse:
"""
Asynchronous version that asks a question to the brain and gets a generated answer.
Args:
question (str): The question to ask.
retrieval_config (RetrievalConfig | None): The retrieval configuration (see RetrievalConfig docs).
rag_pipeline (Type[Union[QuivrQARAG, QuivrQARAGLangGraph]] | None): The RAG pipeline to use.
list_files (list[QuivrKnowledge] | None): The list of files to include in the RAG pipeline.
chat_history (ChatHistory | None): The chat history to use.
Returns:
ParsedRAGResponse: The generated answer.
"""
full_answer = ""
async for response in self.ask_streaming(
question=question,
retrieval_config=retrieval_config,
rag_pipeline=rag_pipeline,
list_files=list_files,
chat_history=chat_history,
):
full_answer += response.answer
return ParsedRAGResponse(answer=full_answer)
def ask(
self,
question: str,
retrieval_config: RetrievalConfig | None = None,
rag_pipeline: Type[Union[QuivrQARAG, QuivrQARAGLangGraph]] | None = None,
list_files: list[QuivrKnowledge] | None = None,
chat_history: ChatHistory | None = None,
) -> ParsedRAGResponse:
"""
Fully synchronous version that asks a question to the brain and gets a generated answer.
Args:
question (str): The question to ask.
retrieval_config (RetrievalConfig | None): The retrieval configuration (see RetrievalConfig docs).
rag_pipeline (Type[Union[QuivrQARAG, QuivrQARAGLangGraph]] | None): The RAG pipeline to use.
list_files (list[QuivrKnowledge] | None): The list of files to include in the RAG pipeline.
chat_history (ChatHistory | None): The chat history to use.
Returns:
ParsedRAGResponse: The generated answer.
"""
loop = asyncio.get_event_loop()
return loop.run_until_complete(
self.aask(
question=question,
retrieval_config=retrieval_config,
rag_pipeline=rag_pipeline,
list_files=list_files,
chat_history=chat_history,
)
)
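
A note on the pair of entry points added above: `aask` is the native async path, while `ask` drives it through `asyncio.get_event_loop().run_until_complete`, so it is intended for plain synchronous scripts. A hedged sketch of the intended split (file names are placeholders):

```python
import asyncio

from quivr_core import Brain

brain = Brain.from_files(name="My Brain", file_paths=["file1.pdf"])

# Plain synchronous code: the blocking wrapper drives the event loop itself.
print(brain.ask("Summarize the uploaded documents").answer)


async def handler() -> None:
    # Inside code that already runs on an event loop (notebook cell, web
    # handler, ...), prefer the async variant: ask() would call
    # run_until_complete() on a running loop and fail.
    response = await brain.aask("Summarize the uploaded documents")
    print(response.answer)


asyncio.run(handler())
```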

View File

@@ -4,7 +4,7 @@ from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_core.vectorstores import VectorStore
-from quivr_core.config import LLMEndpointConfig
+from quivr_core.rag.entities.config import DefaultModelSuppliers, LLMEndpointConfig
from quivr_core.llm import LLMEndpoint
logger = logging.getLogger("quivr_core")
@@ -46,7 +46,9 @@ def default_embedder() -> Embeddings:
def default_llm() -> LLMEndpoint:
    try:
        logger.debug("Loaded ChatOpenAI as default LLM for brain")
-        llm = LLMEndpoint.from_config(LLMEndpointConfig())
+        llm = LLMEndpoint.from_config(
+            LLMEndpointConfig(supplier=DefaultModelSuppliers.OPENAI, model="gpt-4o")
+        )
        return llm
    except ImportError as e:
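
The default LLM for a brain is now pinned explicitly to OpenAI's `gpt-4o` instead of relying on the bare `LLMEndpointConfig()` defaults. Assuming the relocated config keeps the fields of the version removed further down (API key resolved from `f"{supplier.upper()}_API_KEY"`), the same mechanism can pin a different supplier, for example:

```python
# Hedged sketch: pin Anthropic instead of the default OpenAI endpoint.
# Resolves its key from ANTHROPIC_API_KEY; raises ValueError if it is unset.
from quivr_core.llm import LLMEndpoint
from quivr_core.rag.entities.config import DefaultModelSuppliers, LLMEndpointConfig

llm = LLMEndpoint.from_config(
    LLMEndpointConfig(
        supplier=DefaultModelSuppliers.ANTHROPIC,
        model="claude-3-5-sonnet",
        temperature=0.2,  # now forwarded to the chat model (see the llm diff below)
    )
)
```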

View File

@@ -4,9 +4,9 @@ from uuid import UUID
from pydantic import BaseModel, Field, SecretStr
-from quivr_core.config import LLMEndpointConfig
+from quivr_core.rag.entities.config import LLMEndpointConfig
+from quivr_core.rag.entities.models import ChatMessage
from quivr_core.files.file import QuivrFileSerialized
-from quivr_core.models import ChatMessage
class EmbedderConfig(BaseModel):

View File

@@ -1,15 +1,8 @@
-import os
from enum import Enum
-from typing import Dict, List, Optional
-from uuid import UUID
import yaml
from pydantic import BaseModel
-from quivr_core.base_config import QuivrBaseConfig
-from quivr_core.processor.splitter import SplitterConfig
-from quivr_core.prompts import CustomPromptsModel
class PdfParser(str, Enum):
    LLAMA_PARSE = "llama_parse"
@@ -32,489 +25,3 @@ class MegaparseConfig(MegaparseBaseConfig):
    strategy: str = "fast"
    llama_parse_api_key: str | None = None
    pdf_parser: PdfParser = PdfParser.UNSTRUCTURED
class BrainConfig(QuivrBaseConfig):
brain_id: UUID | None = None
name: str
@property
def id(self) -> UUID | None:
return self.brain_id
class DefaultRerankers(str, Enum):
"""
Enum representing the default API-based reranker suppliers supported by the application.
This enum defines the various reranker providers that can be used in the system.
Each enum value corresponds to a specific supplier's identifier and has an
associated default model.
Attributes:
COHERE (str): Represents Cohere AI as a reranker supplier.
JINA (str): Represents Jina AI as a reranker supplier.
Methods:
default_model (property): Returns the default model for the selected supplier.
"""
COHERE = "cohere"
JINA = "jina"
@property
def default_model(self) -> str:
"""
Get the default model for the selected reranker supplier.
This property method returns the default model associated with the current
reranker supplier (COHERE or JINA).
Returns:
str: The name of the default model for the selected supplier.
Raises:
KeyError: If the current enum value doesn't have a corresponding default model.
"""
# Mapping of suppliers to their default models
return {
self.COHERE: "rerank-multilingual-v3.0",
self.JINA: "jina-reranker-v2-base-multilingual",
}[self]
class DefaultModelSuppliers(str, Enum):
"""
Enum representing the default model suppliers supported by the application.
This enum defines the various AI model providers that can be used as sources
for LLMs in the system. Each enum value corresponds to a specific
supplier's identifier.
Attributes:
OPENAI (str): Represents OpenAI as a model supplier.
AZURE (str): Represents Azure (Microsoft) as a model supplier.
ANTHROPIC (str): Represents Anthropic as a model supplier.
META (str): Represents Meta as a model supplier.
MISTRAL (str): Represents Mistral AI as a model supplier.
GROQ (str): Represents Groq as a model supplier.
"""
OPENAI = "openai"
AZURE = "azure"
ANTHROPIC = "anthropic"
META = "meta"
MISTRAL = "mistral"
GROQ = "groq"
class LLMConfig(QuivrBaseConfig):
context: int | None = None
tokenizer_hub: str | None = None
class LLMModelConfig:
_model_defaults: Dict[DefaultModelSuppliers, Dict[str, LLMConfig]] = {
DefaultModelSuppliers.OPENAI: {
"gpt-4o": LLMConfig(context=128000, tokenizer_hub="Xenova/gpt-4o"),
"gpt-4o-mini": LLMConfig(context=128000, tokenizer_hub="Xenova/gpt-4o"),
"gpt-4-turbo": LLMConfig(context=128000, tokenizer_hub="Xenova/gpt-4"),
"gpt-4": LLMConfig(context=8192, tokenizer_hub="Xenova/gpt-4"),
"gpt-3.5-turbo": LLMConfig(
context=16385, tokenizer_hub="Xenova/gpt-3.5-turbo"
),
"text-embedding-3-large": LLMConfig(
context=8191, tokenizer_hub="Xenova/text-embedding-ada-002"
),
"text-embedding-3-small": LLMConfig(
context=8191, tokenizer_hub="Xenova/text-embedding-ada-002"
),
"text-embedding-ada-002": LLMConfig(
context=8191, tokenizer_hub="Xenova/text-embedding-ada-002"
),
},
DefaultModelSuppliers.ANTHROPIC: {
"claude-3-5-sonnet": LLMConfig(
context=200000, tokenizer_hub="Xenova/claude-tokenizer"
),
"claude-3-opus": LLMConfig(
context=200000, tokenizer_hub="Xenova/claude-tokenizer"
),
"claude-3-sonnet": LLMConfig(
context=200000, tokenizer_hub="Xenova/claude-tokenizer"
),
"claude-3-haiku": LLMConfig(
context=200000, tokenizer_hub="Xenova/claude-tokenizer"
),
"claude-2-1": LLMConfig(
context=200000, tokenizer_hub="Xenova/claude-tokenizer"
),
"claude-2-0": LLMConfig(
context=100000, tokenizer_hub="Xenova/claude-tokenizer"
),
"claude-instant-1-2": LLMConfig(
context=100000, tokenizer_hub="Xenova/claude-tokenizer"
),
},
DefaultModelSuppliers.META: {
"llama-3.1": LLMConfig(
context=128000, tokenizer_hub="Xenova/Meta-Llama-3.1-Tokenizer"
),
"llama-3": LLMConfig(
context=8192, tokenizer_hub="Xenova/llama3-tokenizer-new"
),
"llama-2": LLMConfig(context=4096, tokenizer_hub="Xenova/llama2-tokenizer"),
"code-llama": LLMConfig(
context=16384, tokenizer_hub="Xenova/llama-code-tokenizer"
),
},
DefaultModelSuppliers.GROQ: {
"llama-3.1": LLMConfig(
context=128000, tokenizer_hub="Xenova/Meta-Llama-3.1-Tokenizer"
),
"llama-3": LLMConfig(
context=8192, tokenizer_hub="Xenova/llama3-tokenizer-new"
),
"llama-2": LLMConfig(context=4096, tokenizer_hub="Xenova/llama2-tokenizer"),
"code-llama": LLMConfig(
context=16384, tokenizer_hub="Xenova/llama-code-tokenizer"
),
},
DefaultModelSuppliers.MISTRAL: {
"mistral-large": LLMConfig(
context=128000, tokenizer_hub="Xenova/mistral-tokenizer-v3"
),
"mistral-small": LLMConfig(
context=128000, tokenizer_hub="Xenova/mistral-tokenizer-v3"
),
"mistral-nemo": LLMConfig(
context=128000, tokenizer_hub="Xenova/Mistral-Nemo-Instruct-Tokenizer"
),
"codestral": LLMConfig(
context=32000, tokenizer_hub="Xenova/mistral-tokenizer-v3"
),
},
}
@classmethod
def get_supplier_by_model_name(cls, model: str) -> DefaultModelSuppliers | None:
# Iterate over the suppliers and their models
for supplier, models in cls._model_defaults.items():
# Check if the model name or a base part of the model name is in the supplier's models
for base_model_name in models:
if model.startswith(base_model_name):
return supplier
# Return None if no supplier matches the model name
return None
@classmethod
def get_llm_model_config(
cls, supplier: DefaultModelSuppliers, model_name: str
) -> Optional[LLMConfig]:
"""Retrieve the LLMConfig (context and tokenizer_hub) for a given supplier and model."""
supplier_defaults = cls._model_defaults.get(supplier)
if not supplier_defaults:
return None
# Use startswith logic for matching model names
for key, config in supplier_defaults.items():
if model_name.startswith(key):
return config
return None
class LLMEndpointConfig(QuivrBaseConfig):
"""
Configuration class for Large Language Models (LLM) endpoints.
This class defines the settings and parameters for interacting with various LLM providers.
It includes configuration for the model, API keys, token limits, and other relevant settings.
Attributes:
supplier (DefaultModelSuppliers): The LLM provider (default: OPENAI).
model (str): The specific model to use (default: "gpt-4o").
context_length (int | None): The maximum context length for the model.
tokenizer_hub (str | None): The tokenizer to use for this model.
llm_base_url (str | None): Base URL for the LLM API.
env_variable_name (str): Name of the environment variable for the API key.
llm_api_key (str | None): The API key for the LLM provider.
max_input_tokens (int): Maximum number of input tokens sent to the LLM (default: 2000).
max_output_tokens (int): Maximum number of output tokens returned by the LLM (default: 2000).
temperature (float): Temperature setting for text generation (default: 0.7).
streaming (bool): Whether to use streaming for responses (default: True).
prompt (CustomPromptsModel | None): Custom prompt configuration.
"""
supplier: DefaultModelSuppliers = DefaultModelSuppliers.OPENAI
model: str = "gpt-4o"
context_length: int | None = None
tokenizer_hub: str | None = None
llm_base_url: str | None = None
env_variable_name: str = f"{supplier.upper()}_API_KEY"
llm_api_key: str | None = None
max_input_tokens: int = 2000
max_output_tokens: int = 2000
temperature: float = 0.7
streaming: bool = True
prompt: CustomPromptsModel | None = None
_FALLBACK_TOKENIZER = "cl100k_base"
@property
def fallback_tokenizer(self) -> str:
"""
Get the fallback tokenizer.
Returns:
str: The name of the fallback tokenizer.
"""
return self._FALLBACK_TOKENIZER
def __init__(self, **data):
"""
Initialize the LLMEndpointConfig.
This method sets up the initial configuration, including setting the LLM model
config and API key.
"""
super().__init__(**data)
self.set_llm_model_config()
self.set_api_key()
def set_api_key(self, force_reset: bool = False):
"""
Set the API key for the LLM provider.
This method attempts to set the API key from the environment variable.
If the key is not found, it raises a ValueError.
Args:
force_reset (bool): If True, forces a reset of the API key even if already set.
Raises:
ValueError: If the API key is not set in the environment.
"""
if not self.llm_api_key or force_reset:
self.llm_api_key = os.getenv(self.env_variable_name)
if not self.llm_api_key:
raise ValueError(
f"The API key for supplier '{self.supplier}' is not set. "
f"Please set the environment variable: {self.env_variable_name}"
)
def set_llm_model_config(self):
"""
Set the LLM model configuration.
This method automatically sets the context_length and tokenizer_hub
based on the current supplier and model.
"""
llm_model_config = LLMModelConfig.get_llm_model_config(
self.supplier, self.model
)
if llm_model_config:
self.context_length = llm_model_config.context
self.tokenizer_hub = llm_model_config.tokenizer_hub
def set_llm_model(self, model: str):
"""
Set the LLM model and update related configurations.
This method updates the supplier and model based on the provided model name,
then updates the model config and API key accordingly.
Args:
model (str): The name of the model to set.
Raises:
ValueError: If no corresponding supplier is found for the given model.
"""
supplier = LLMModelConfig.get_supplier_by_model_name(model)
if supplier is None:
raise ValueError(
f"Cannot find the corresponding supplier for model {model}"
)
self.supplier = supplier
self.model = model
self.set_llm_model_config()
self.set_api_key(force_reset=True)
def set_from_sqlmodel(self, sqlmodel: BaseModel, mapping: Dict[str, str]):
"""
Set attributes in LLMEndpointConfig from SQLModel attributes using a field mapping.
This method allows for dynamic setting of LLMEndpointConfig attributes based on
a provided SQLModel instance and a mapping dictionary.
Args:
sqlmodel (SQLModel): An instance of the SQLModel class.
mapping (Dict[str, str]): A dictionary that maps SQLModel fields to LLMEndpointConfig fields.
Example: {"max_input": "max_input_tokens", "env_variable_name": "env_variable_name"}
Raises:
AttributeError: If any field in the mapping doesn't exist in either the SQLModel or LLMEndpointConfig.
"""
for model_field, llm_field in mapping.items():
if hasattr(sqlmodel, model_field) and hasattr(self, llm_field):
setattr(self, llm_field, getattr(sqlmodel, model_field))
else:
raise AttributeError(
f"Invalid mapping: {model_field} or {llm_field} does not exist."
)
# Cannot use Pydantic v2 field_validator because of conflicts with pydantic v1 still in use in LangChain
class RerankerConfig(QuivrBaseConfig):
"""
Configuration class for reranker models.
This class defines the settings for reranker models used in the application,
including the supplier, model, and API key information.
Attributes:
supplier (DefaultRerankers | None): The reranker supplier (e.g., COHERE).
model (str | None): The specific reranker model to use.
top_n (int): The number of top chunks returned by the reranker (default: 5).
api_key (str | None): The API key for the reranker service.
"""
supplier: DefaultRerankers | None = None
model: str | None = None
top_n: int = 5
api_key: str | None = None
def __init__(self, **data):
"""
Initialize the RerankerConfig.
"""
super().__init__(**data)
self.validate_model()
def validate_model(self):
"""
Validate and set up the reranker model configuration.
This method ensures that a model is set (using the default if not provided)
and that the necessary API key is available in the environment.
Raises:
ValueError: If the required API key is not set in the environment.
"""
if self.model is None and self.supplier is not None:
self.model = self.supplier.default_model
if self.supplier:
api_key_var = f"{self.supplier.upper()}_API_KEY"
self.api_key = os.getenv(api_key_var)
if self.api_key is None:
raise ValueError(
f"The API key for supplier '{self.supplier}' is not set. "
f"Please set the environment variable: {api_key_var}"
)
class NodeConfig(QuivrBaseConfig):
"""
Configuration class for a node in an AI assistant workflow.
This class represents a single node in a workflow configuration,
defining its name and connections to other nodes.
Attributes:
name (str): The name of the node.
edges (List[str]): List of names of other nodes this node links to.
"""
name: str
edges: List[str]
class WorkflowConfig(QuivrBaseConfig):
"""
Configuration class for an AI assistant workflow.
This class represents the entire workflow configuration,
consisting of multiple interconnected nodes.
Attributes:
name (str): The name of the workflow.
nodes (List[NodeConfig]): List of nodes in the workflow.
"""
name: str
nodes: List[NodeConfig]
class RetrievalConfig(QuivrBaseConfig):
"""
Configuration class for the retrieval phase of a RAG assistant.
This class defines the settings for the retrieval process,
including reranker and LLM configurations, as well as various limits and prompts.
Attributes:
workflow_config (WorkflowConfig | None): Configuration for the workflow.
reranker_config (RerankerConfig): Configuration for the reranker.
llm_config (LLMEndpointConfig): Configuration for the LLM endpoint.
max_history (int): Maximum number of past conversation turns to pass to the LLM as context (default: 10).
max_files (int): Maximum number of files to process (default: 20).
prompt (str | None): Custom prompt for the retrieval process.
"""
workflow_config: WorkflowConfig | None = None
reranker_config: RerankerConfig = RerankerConfig()
llm_config: LLMEndpointConfig = LLMEndpointConfig()
max_history: int = 10
max_files: int = 20
prompt: str | None = None
class ParserConfig(QuivrBaseConfig):
"""
Configuration class for the parser.
This class defines the settings for the parsing process,
including configurations for the text splitter and Megaparse.
Attributes:
splitter_config (SplitterConfig): Configuration for the text splitter.
megaparse_config (MegaparseConfig): Configuration for Megaparse.
"""
splitter_config: SplitterConfig = SplitterConfig()
megaparse_config: MegaparseConfig = MegaparseConfig()
class IngestionConfig(QuivrBaseConfig):
"""
Configuration class for the data ingestion process.
This class defines the settings for the data ingestion process,
including the parser configuration.
Attributes:
parser_config (ParserConfig): Configuration for the parser.
"""
parser_config: ParserConfig = ParserConfig()
class AssistantConfig(QuivrBaseConfig):
"""
Configuration class for an AI assistant.
This class defines the overall configuration for an AI assistant,
including settings for retrieval and ingestion processes.
Attributes:
retrieval_config (RetrievalConfig): Configuration for the retrieval process.
ingestion_config (IngestionConfig): Configuration for the ingestion process.
"""
retrieval_config: RetrievalConfig = RetrievalConfig()
ingestion_config: IngestionConfig = IngestionConfig()

View File

@@ -10,8 +10,8 @@ from langchain_openai import AzureChatOpenAI, ChatOpenAI
from pydantic.v1 import SecretStr
from quivr_core.brain.info import LLMInfo
-from quivr_core.config import DefaultModelSuppliers, LLMEndpointConfig
+from quivr_core.rag.entities.config import DefaultModelSuppliers, LLMEndpointConfig
-from quivr_core.utils import model_supports_function_calling
+from quivr_core.rag.utils import model_supports_function_calling
logger = logging.getLogger("quivr_core")
@@ -70,6 +70,7 @@ class LLMEndpoint:
                else None,
                azure_endpoint=azure_endpoint,
                max_tokens=config.max_output_tokens,
+                temperature=config.temperature,
            )
        elif config.supplier == DefaultModelSuppliers.ANTHROPIC:
            _llm = ChatAnthropic(
@@ -79,6 +80,7 @@ class LLMEndpoint:
                else None,
                base_url=config.llm_base_url,
                max_tokens=config.max_output_tokens,
+                temperature=config.temperature,
            )
        elif config.supplier == DefaultModelSuppliers.OPENAI:
            _llm = ChatOpenAI(
@@ -88,6 +90,7 @@ class LLMEndpoint:
                else None,
                base_url=config.llm_base_url,
                max_tokens=config.max_output_tokens,
+                temperature=config.temperature,
            )
        else:
            _llm = ChatOpenAI(
@@ -97,6 +100,7 @@ class LLMEndpoint:
                else None,
                base_url=config.llm_base_url,
                max_tokens=config.max_output_tokens,
+                temperature=config.temperature,
            )
        return cls(llm=_llm, llm_config=config)
@@ -118,3 +122,7 @@ class LLMEndpoint:
            max_tokens=self._config.max_output_tokens,
            supports_function_calling=self.supports_func_calling(),
        )
+
+    def clone_llm(self):
+        """Create a new instance of the LLM with the same configuration."""
+        return self._llm.__class__(**self._llm.__dict__)

View File

@ -0,0 +1,36 @@
from quivr_core.base_config import QuivrBaseConfig
from typing import Callable
from langchain_core.tools import BaseTool
from typing import Dict, Any
class ToolsCategory(QuivrBaseConfig):
name: str
description: str
tools: list
default_tool: str | None = None
create_tool: Callable
def __init__(self, **data):
super().__init__(**data)
self.name = self.name.lower()
class ToolWrapper:
def __init__(self, tool: BaseTool, format_input: Callable, format_output: Callable):
self.tool = tool
self.format_input = format_input
self.format_output = format_output
class ToolRegistry:
def __init__(self):
self._registry = {}
def register_tool(self, tool_name: str, create_func: Callable):
self._registry[tool_name] = create_func
def create_tool(self, tool_name: str, config: Dict[str, Any]) -> ToolWrapper:
if tool_name not in self._registry:
raise ValueError(f"Tool {tool_name} is not supported.")
return self._registry[tool_name](config)
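
These three small classes are the glue for the tool modules that follow: `ToolsCategory` groups tools and knows how to create them, `ToolWrapper` pairs a LangChain `BaseTool` with input/output adapters, and `ToolRegistry` maps tool names to factory functions. A self-contained sketch with a toy tool (hypothetical names, just to show the wiring; the real registration lives in `web_search_tools.py` below):

```python
from langchain_core.documents import Document
from langchain_core.tools import tool

from quivr_core.llm_tools.entity import ToolRegistry, ToolWrapper


@tool
def echo_search(query: str) -> str:
    """Toy tool standing in for a real search backend (illustration only)."""
    return f"results for: {query}"


def create_echo_tool(config: dict) -> ToolWrapper:
    # Factory signature mirrors the ones registered in web_search_tools.py.
    return ToolWrapper(
        echo_search,
        format_input=lambda task: {"query": task},
        format_output=lambda raw: [Document(page_content=raw)],
    )


registry = ToolRegistry()
registry.register_tool("echo", create_echo_tool)

wrapper = registry.create_tool("echo", {})
docs = wrapper.format_output(wrapper.tool.invoke(wrapper.format_input("quivr")))
```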

View File

@ -0,0 +1,33 @@
from typing import Dict, Any, Type, Union
from quivr_core.llm_tools.entity import ToolWrapper
from quivr_core.llm_tools.web_search_tools import (
WebSearchTools,
)
from quivr_core.llm_tools.other_tools import (
OtherTools,
)
TOOLS_CATEGORIES = {
WebSearchTools.name: WebSearchTools,
OtherTools.name: OtherTools,
}
# Register all ToolsList enums
TOOLS_LISTS = {
**{tool.value: tool for tool in WebSearchTools.tools},
**{tool.value: tool for tool in OtherTools.tools},
}
class LLMToolFactory:
@staticmethod
def create_tool(tool_name: str, config: Dict[str, Any]) -> Union[ToolWrapper, Type]:
for category, tools_class in TOOLS_CATEGORIES.items():
if tool_name in tools_class.tools:
return tools_class.create_tool(tool_name, config)
elif tool_name.lower() == category and tools_class.default_tool:
return tools_class.create_tool(tools_class.default_tool, config)
raise ValueError(f"Tool {tool_name} is not supported.")
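
The factory dispatches either on an exact tool name or on a category name, which falls back to that category's `default_tool`. A hedged sketch of both paths (the factory's module path is an assumption; both calls need `TAVILY_API_KEY` since they resolve to the Tavily wrapper):

```python
# Hedged sketch: the factory module path below is assumed, not confirmed by the diff.
from quivr_core.llm_tools.llm_tools import LLMToolFactory

by_name = LLMToolFactory.create_tool("tavily", {"max_results": 3})
by_category = LLMToolFactory.create_tool("Web Search", {})  # category -> default_tool
```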

View File

@ -0,0 +1,24 @@
from enum import Enum
from typing import Dict, Any, Type, Union
from langchain_core.tools import BaseTool
from quivr_core.llm_tools.entity import ToolsCategory
from quivr_core.rag.entities.models import cited_answer
class OtherToolsList(str, Enum):
CITED_ANSWER = "cited_answer"
def create_other_tool(tool_name: str, config: Dict[str, Any]) -> Union[BaseTool, Type]:
if tool_name == OtherToolsList.CITED_ANSWER:
return cited_answer
else:
raise ValueError(f"Tool {tool_name} is not supported.")
OtherTools = ToolsCategory(
name="Other",
description="Other tools",
tools=[OtherToolsList.CITED_ANSWER],
create_tool=create_other_tool,
)

View File

@ -0,0 +1,73 @@
from enum import Enum
from typing import Dict, List, Any
from langchain_community.tools import TavilySearchResults
from langchain_community.utilities.tavily_search import TavilySearchAPIWrapper
from quivr_core.llm_tools.entity import ToolsCategory
import os
from pydantic.v1 import SecretStr as SecretStrV1 # Ensure correct import
from quivr_core.llm_tools.entity import ToolWrapper, ToolRegistry
from langchain_core.documents import Document
class WebSearchToolsList(str, Enum):
TAVILY = "tavily"
def create_tavily_tool(config: Dict[str, Any]) -> ToolWrapper:
api_key = (
config.pop("api_key") if "api_key" in config else os.getenv("TAVILY_API_KEY")
)
if not api_key:
raise ValueError(
"Missing required config key 'api_key' or environment variable 'TAVILY_API_KEY'"
)
tavily_api_wrapper = TavilySearchAPIWrapper(
tavily_api_key=SecretStrV1(api_key),
)
tool = TavilySearchResults(
api_wrapper=tavily_api_wrapper,
max_results=config.pop("max_results", 5),
search_depth=config.pop("search_depth", "advanced"),
include_answer=config.pop("include_answer", True),
**config,
)
tool.name = WebSearchToolsList.TAVILY.value
def format_input(task: str) -> Dict[str, Any]:
return {"query": task}
def format_output(response: Any) -> List[Document]:
metadata = {"integration": "", "integration_link": ""}
return [
Document(
page_content=d["content"],
metadata={
**metadata,
"file_name": d["url"],
"original_file_name": d["url"],
},
)
for d in response
]
return ToolWrapper(tool, format_input, format_output)
# Initialize the registry and register tools
web_search_tool_registry = ToolRegistry()
web_search_tool_registry.register_tool(WebSearchToolsList.TAVILY, create_tavily_tool)
def create_web_search_tool(tool_name: str, config: Dict[str, Any]) -> ToolWrapper:
return web_search_tool_registry.create_tool(tool_name, config)
WebSearchTools = ToolsCategory(
name="Web Search",
description="Tools for web searching",
tools=[WebSearchToolsList.TAVILY],
default_tool=WebSearchToolsList.TAVILY,
create_tool=create_web_search_tool,
)
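
End to end, the wrapper turns a task string into Tavily query kwargs and folds the JSON results back into LangChain `Document`s, storing the source URL under `file_name` so downstream citation handling keeps working. A hedged usage sketch outside the graph (requires `TAVILY_API_KEY`):

```python
from quivr_core.llm_tools.web_search_tools import (
    WebSearchToolsList,
    create_web_search_tool,
)

wrapper = create_web_search_tool(WebSearchToolsList.TAVILY, {"max_results": 3})
raw = wrapper.tool.invoke(wrapper.format_input("what is quivr-core?"))
for doc in wrapper.format_output(raw):
    print(doc.metadata["file_name"], doc.page_content[:80])
```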

View File

@ -1,119 +0,0 @@
import datetime
from langchain_core.prompts import (
ChatPromptTemplate,
HumanMessagePromptTemplate,
MessagesPlaceholder,
PromptTemplate,
SystemMessagePromptTemplate,
)
from langchain_core.prompts.base import BasePromptTemplate
from pydantic import ConfigDict, create_model
class CustomPromptsDict(dict):
def __init__(self, type, *args, **kwargs):
super().__init__(*args, **kwargs)
self._type = type
def __setitem__(self, key, value):
# Automatically convert the value into a tuple (my_type, value)
super().__setitem__(key, (self._type, value))
def _define_custom_prompts() -> CustomPromptsDict:
custom_prompts: CustomPromptsDict = CustomPromptsDict(type=BasePromptTemplate)
today_date = datetime.datetime.now().strftime("%B %d, %Y")
# ---------------------------------------------------------------------------
# Prompt for question rephrasing
# ---------------------------------------------------------------------------
_template = """Given the following conversation and a follow up question, rephrase the follow up question to be a standalone question, in its original language. Keep as much details as possible from previous messages. Keep entity names and all.
Chat History:
{chat_history}
Follow Up Input: {question}
Standalone question:"""
CONDENSE_QUESTION_PROMPT = PromptTemplate.from_template(_template)
custom_prompts["CONDENSE_QUESTION_PROMPT"] = CONDENSE_QUESTION_PROMPT
# ---------------------------------------------------------------------------
# Prompt for RAG
# ---------------------------------------------------------------------------
system_message_template = (
f"Your name is Quivr. You're a helpful assistant. Today's date is {today_date}."
)
system_message_template += """
When answering use markdown.
Use markdown code blocks for code snippets.
Answer in a concise and clear manner.
Use the following pieces of context from files provided by the user to answer the users.
Answer in the same language as the user question.
If you don't know the answer with the context provided from the files, just say that you don't know, don't try to make up an answer.
Don't cite the source id in the answer objects, but you can use the source to answer the question.
You have access to the files to answer the user question (limited to first 20 files):
{files}
If not None, User instruction to follow to answer: {custom_instructions}
Don't cite the source id in the answer objects, but you can use the source to answer the question.
"""
template_answer = """
Context:
{context}
User Question: {question}
Answer:
"""
RAG_ANSWER_PROMPT = ChatPromptTemplate.from_messages(
[
SystemMessagePromptTemplate.from_template(system_message_template),
HumanMessagePromptTemplate.from_template(template_answer),
]
)
custom_prompts["RAG_ANSWER_PROMPT"] = RAG_ANSWER_PROMPT
# ---------------------------------------------------------------------------
# Prompt for formatting documents
# ---------------------------------------------------------------------------
DEFAULT_DOCUMENT_PROMPT = PromptTemplate.from_template(
template="Source: {index} \n {page_content}"
)
custom_prompts["DEFAULT_DOCUMENT_PROMPT"] = DEFAULT_DOCUMENT_PROMPT
# ---------------------------------------------------------------------------
# Prompt for chatting directly with LLMs, without any document retrieval stage
# ---------------------------------------------------------------------------
system_message_template = (
f"Your name is Quivr. You're a helpful assistant. Today's date is {today_date}."
)
system_message_template += """
If not None, also follow these user instructions when answering: {custom_instructions}
"""
template_answer = """
User Question: {question}
Answer:
"""
CHAT_LLM_PROMPT = ChatPromptTemplate.from_messages(
[
SystemMessagePromptTemplate.from_template(system_message_template),
MessagesPlaceholder(variable_name="chat_history"),
HumanMessagePromptTemplate.from_template(template_answer),
]
)
custom_prompts["CHAT_LLM_PROMPT"] = CHAT_LLM_PROMPT
return custom_prompts
_custom_prompts = _define_custom_prompts()
CustomPromptsModel = create_model(
"CustomPromptsModel", **_custom_prompts, __config__=ConfigDict(extra="forbid")
)
custom_prompts = CustomPromptsModel()

View File

@ -1,488 +0,0 @@
import logging
from enum import Enum
from typing import Annotated, AsyncGenerator, Optional, Sequence, TypedDict
from uuid import uuid4
# TODO(@aminediro): this is the only dependency to langchain package, we should remove it
from langchain.retrievers import ContextualCompressionRetriever
from langchain_cohere import CohereRerank
from langchain_community.document_compressors import JinaRerank
from langchain_core.callbacks import Callbacks
from langchain_core.documents import BaseDocumentCompressor, Document
from langchain_core.messages import BaseMessage
from langchain_core.messages.ai import AIMessageChunk
from langchain_core.vectorstores import VectorStore
from langgraph.graph import END, START, StateGraph
from langgraph.graph.message import add_messages
from quivr_core.chat import ChatHistory
from quivr_core.config import DefaultRerankers, RetrievalConfig
from quivr_core.llm import LLMEndpoint
from quivr_core.models import (
ParsedRAGChunkResponse,
ParsedRAGResponse,
QuivrKnowledge,
RAGResponseMetadata,
cited_answer,
)
from quivr_core.prompts import custom_prompts
from quivr_core.utils import (
combine_documents,
format_file_list,
get_chunk_metadata,
parse_chunk_response,
parse_response,
)
logger = logging.getLogger("quivr_core")
class SpecialEdges(str, Enum):
START = "START"
END = "END"
class AgentState(TypedDict):
# The add_messages function defines how an update should be processed
# Default is to replace. add_messages says "append"
messages: Annotated[Sequence[BaseMessage], add_messages]
chat_history: ChatHistory
docs: list[Document]
files: str
final_response: dict
class IdempotentCompressor(BaseDocumentCompressor):
def compress_documents(
self,
documents: Sequence[Document],
query: str,
callbacks: Optional[Callbacks] = None,
) -> Sequence[Document]:
"""
A no-op document compressor that simply returns the documents it is given.
This is a placeholder until a more sophisticated document compression
algorithm is implemented.
"""
return documents
class QuivrQARAGLangGraph:
def __init__(
self,
*,
retrieval_config: RetrievalConfig,
llm: LLMEndpoint,
vector_store: VectorStore | None = None,
reranker: BaseDocumentCompressor | None = None,
):
"""
Construct a QuivrQARAGLangGraph object.
Args:
retrieval_config (RetrievalConfig): The configuration for the RAG model.
llm (LLMEndpoint): The LLM to use for generating text.
vector_store (VectorStore): The vector store to use for storing and retrieving documents.
reranker (BaseDocumentCompressor | None): The document compressor to use for re-ranking documents. Defaults to IdempotentCompressor if not provided.
"""
self.retrieval_config = retrieval_config
self.vector_store = vector_store
self.llm_endpoint = llm
self.graph = None
if reranker is not None:
self.reranker = reranker
elif self.retrieval_config.reranker_config.supplier == DefaultRerankers.COHERE:
self.reranker = CohereRerank(
model=self.retrieval_config.reranker_config.model,
top_n=self.retrieval_config.reranker_config.top_n,
cohere_api_key=self.retrieval_config.reranker_config.api_key,
)
elif self.retrieval_config.reranker_config.supplier == DefaultRerankers.JINA:
self.reranker = JinaRerank(
model=self.retrieval_config.reranker_config.model,
top_n=self.retrieval_config.reranker_config.top_n,
jina_api_key=self.retrieval_config.reranker_config.api_key,
)
else:
self.reranker = IdempotentCompressor()
if self.vector_store:
self.compression_retriever = ContextualCompressionRetriever(
base_compressor=self.reranker, base_retriever=self.retriever
)
@property
def retriever(self):
"""
Returns a retriever that can retrieve documents from the vector store.
Returns:
VectorStoreRetriever: The retriever.
"""
if self.vector_store:
return self.vector_store.as_retriever()
else:
raise ValueError("No vector store provided")
def filter_history(self, state: AgentState) -> dict:
"""
Filter out the chat history to only include the messages that are relevant to the current question
Takes in a chat_history= [HumanMessage(content='Qui est Chloé ? '),
AIMessage(content="Chloé est une salariée travaillant pour l'entreprise Quivr en tant qu'AI Engineer,
sous la direction de son supérieur hiérarchique, Stanislas Girard."),
HumanMessage(content='Dis moi en plus sur elle'), AIMessage(content=''),
HumanMessage(content='Dis moi en plus sur elle'),
AIMessage(content="Désolé, je n'ai pas d'autres informations sur Chloé à partir des fichiers fournis.")]
Returns a filtered chat_history with in priority: first max_tokens, then max_history where a Human message and an AI message count as one pair
a token is 4 characters
"""
chat_history = state["chat_history"]
total_tokens = 0
total_pairs = 0
_chat_id = uuid4()
_chat_history = ChatHistory(chat_id=_chat_id, brain_id=chat_history.brain_id)
for human_message, ai_message in reversed(list(chat_history.iter_pairs())):
# TODO: replace with tiktoken
message_tokens = self.llm_endpoint.count_tokens(
human_message.content
) + self.llm_endpoint.count_tokens(ai_message.content)
if (
total_tokens + message_tokens
> self.retrieval_config.llm_config.max_output_tokens
or total_pairs >= self.retrieval_config.max_history
):
break
_chat_history.append(human_message)
_chat_history.append(ai_message)
total_tokens += message_tokens
total_pairs += 1
return {"chat_history": _chat_history}
### Nodes
def rewrite(self, state):
"""
Transform the query to produce a better question.
Args:
state (messages): The current state
Returns:
dict: The updated state with re-phrased question
"""
# Grader
msg = custom_prompts.CONDENSE_QUESTION_PROMPT.format(
chat_history=state["chat_history"],
question=state["messages"][0].content,
)
model = self.llm_endpoint._llm
response = model.invoke(msg)
return {"messages": [response]}
def retrieve(self, state):
"""
Retrieve relevent chunks
Args:
state (messages): The current state
Returns:
dict: The retrieved chunks
"""
question = state["messages"][-1].content
docs = self.compression_retriever.invoke(question)
return {"docs": docs}
def generate_rag(self, state):
"""
Generate answer
Args:
state (messages): The current state
Returns:
dict: The updated state with re-phrased question
"""
messages = state["messages"]
user_question = messages[0].content
files = state["files"]
docs = state["docs"]
# Prompt
prompt = self.retrieval_config.prompt
final_inputs = {}
final_inputs["context"] = combine_documents(docs) if docs else "None"
final_inputs["question"] = user_question
final_inputs["custom_instructions"] = prompt if prompt else "None"
final_inputs["files"] = files if files else "None"
# LLM
llm = self.llm_endpoint._llm
if self.llm_endpoint.supports_func_calling():
llm = self.llm_endpoint._llm.bind_tools(
[cited_answer],
tool_choice="any",
)
# Chain
rag_chain = custom_prompts.RAG_ANSWER_PROMPT | llm
# Run
response = rag_chain.invoke(final_inputs)
formatted_response = {
"answer": response, # Assuming the last message contains the final answer
"docs": docs,
}
return {"messages": [response], "final_response": formatted_response}
def generate_chat_llm(self, state):
"""
Generate answer
Args:
state (messages): The current state
Returns:
dict: The updated state with re-phrased question
"""
messages = state["messages"]
user_question = messages[0].content
# Prompt
prompt = self.retrieval_config.prompt
final_inputs = {}
final_inputs["question"] = user_question
final_inputs["custom_instructions"] = prompt if prompt else "None"
final_inputs["chat_history"] = state["chat_history"].to_list()
# LLM
llm = self.llm_endpoint._llm
# Chain
rag_chain = custom_prompts.CHAT_LLM_PROMPT | llm
# Run
response = rag_chain.invoke(final_inputs)
formatted_response = {
"answer": response, # Assuming the last message contains the final answer
}
return {"messages": [response], "final_response": formatted_response}
def build_chain(self):
"""
Builds the langchain chain for the given configuration.
Returns:
Callable[[Dict], Dict]: The langchain chain.
"""
if not self.graph:
self.graph = self.create_graph()
return self.graph
def create_graph(self):
"""
Builds the langchain chain for the given configuration.
This function creates a state machine which takes a chat history and a question
and produces an answer. The state machine consists of the following states:
- filter_history: Filter the chat history (i.e., remove the last message)
- rewrite: Re-write the question using the filtered history
- retrieve: Retrieve documents related to the re-written question
- generate: Generate an answer using the retrieved documents
The state machine starts in the filter_history state and transitions as follows:
filter_history -> rewrite -> retrieve -> generate -> END
The final answer is returned as a dictionary with the answer and the list of documents
used to generate the answer.
Returns:
Callable[[Dict], Dict]: The langchain chain.
"""
workflow = StateGraph(AgentState)
if self.retrieval_config.workflow_config:
if SpecialEdges.START not in [
node.name for node in self.retrieval_config.workflow_config.nodes
]:
raise ValueError("The workflow should contain a 'START' node")
for node in self.retrieval_config.workflow_config.nodes:
if node.name not in SpecialEdges._value2member_map_:
workflow.add_node(node.name, getattr(self, node.name))
for node in self.retrieval_config.workflow_config.nodes:
for edge in node.edges:
if node.name == SpecialEdges.START:
workflow.add_edge(START, edge)
elif edge == SpecialEdges.END:
workflow.add_edge(node.name, END)
else:
workflow.add_edge(node.name, edge)
else:
# Define the nodes we will cycle between
workflow.add_node("filter_history", self.filter_history)
workflow.add_node("rewrite", self.rewrite) # Re-writing the question
workflow.add_node("retrieve", self.retrieve) # retrieval
workflow.add_node("generate", self.generate_rag)
# Add node for filtering history
workflow.set_entry_point("filter_history")
workflow.add_edge("filter_history", "rewrite")
workflow.add_edge("rewrite", "retrieve")
workflow.add_edge("retrieve", "generate")
workflow.add_edge(
"generate", END
) # Add edge from generate to END
# Compile
graph = workflow.compile()
return graph
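For reference, the same default pipeline can be sketched declaratively with the WorkflowConfig / NodeConfig entities introduced elsewhere in this change. The snippet below is illustrative only and mirrors the DefaultWorkflow.RAG node list (node names must match method names, hence generate_rag rather than generate); the "START" / "END" strings are resolved to the langgraph constants by NodeConfig itself.

from quivr_core.rag.entities.config import NodeConfig, WorkflowConfig

default_rag_workflow = WorkflowConfig(
    nodes=[
        NodeConfig(name="START", edges=["filter_history"]),
        NodeConfig(name="filter_history", edges=["rewrite"]),
        NodeConfig(name="rewrite", edges=["retrieve"]),
        NodeConfig(name="retrieve", edges=["generate_rag"]),
        NodeConfig(name="generate_rag", edges=["END"]),
    ]
)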
def answer(
self,
question: str,
history: ChatHistory,
list_files: list[QuivrKnowledge],
metadata: dict[str, str] = {},
) -> ParsedRAGResponse:
"""
Answer a question using the langgraph chain.
Args:
question (str): The question to answer.
history (ChatHistory): The chat history to use for context.
list_files (list[QuivrKnowledge]): The list of files to use for retrieval.
metadata (dict[str, str], optional): The metadata to pass to the langchain invocation. Defaults to {}.
Returns:
ParsedRAGResponse: The answer to the question.
"""
concat_list_files = format_file_list(
list_files, self.retrieval_config.max_files
)
conversational_qa_chain = self.build_chain()
inputs = {
"messages": [
("user", question),
],
"chat_history": history,
"files": concat_list_files,
}
raw_llm_response = conversational_qa_chain.invoke(
inputs,
config={"metadata": metadata},
)
response = parse_response(
raw_llm_response["final_response"], self.retrieval_config.llm_config.model
)
return response
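A minimal usage sketch of the synchronous API above. The class name, the ChatHistory constructor arguments, the question and the empty file list are placeholder assumptions; retrieval_config, llm_endpoint and vector_store stand for pre-built instances.

from uuid import uuid4

rag = QuivrQARAGLangGraph(
    retrieval_config=retrieval_config,  # a RetrievalConfig instance
    llm=llm_endpoint,                   # an LLMEndpoint instance
    vector_store=vector_store,          # any LangChain VectorStore
)
history = ChatHistory(chat_id=uuid4(), brain_id=uuid4())
response = rag.answer(question="What is Quivr?", history=history, list_files=[])
print(response.answer)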
async def answer_astream(
self,
question: str,
history: ChatHistory,
list_files: list[QuivrKnowledge],
metadata: dict[str, str] = {},
) -> AsyncGenerator[ParsedRAGChunkResponse, ParsedRAGChunkResponse]:
"""
Answer a question using the langgraph chain and yield each chunk of the answer separately.
Args:
question (str): The question to answer.
history (ChatHistory): The chat history to use for context.
list_files (list[QuivrKnowledge]): The list of files to use for retrieval.
metadata (dict[str, str], optional): The metadata to pass to the langchain invocation. Defaults to {}.
Yields:
ParsedRAGChunkResponse: Each chunk of the answer.
"""
concat_list_files = format_file_list(
list_files, self.retrieval_config.max_files
)
conversational_qa_chain = self.build_chain()
rolling_message = AIMessageChunk(content="")
sources: list[Document] | None = None
prev_answer = ""
chunk_id = 0
async for event in conversational_qa_chain.astream_events(
{
"messages": [
("user", question),
],
"chat_history": history,
"files": concat_list_files,
},
version="v1",
config={"metadata": metadata},
):
kind = event["event"]
if (
not sources
and "output" in event["data"]
and "docs" in event["data"]["output"]
):
sources = event["data"]["output"]["docs"]
if (
kind == "on_chat_model_stream"
and "generate" in event["metadata"]["langgraph_node"]
):
chunk = event["data"]["chunk"]
rolling_message, answer_str = parse_chunk_response(
rolling_message,
chunk,
self.llm_endpoint.supports_func_calling(),
)
if len(answer_str) > 0:
if (
self.llm_endpoint.supports_func_calling()
and rolling_message.tool_calls
):
diff_answer = answer_str[len(prev_answer) :]
if len(diff_answer) > 0:
parsed_chunk = ParsedRAGChunkResponse(
answer=diff_answer,
metadata=RAGResponseMetadata(),
)
prev_answer += diff_answer
logger.debug(
f"answer_astream func_calling=True question={question} rolling_msg={rolling_message} chunk_id={chunk_id}, chunk={parsed_chunk}"
)
yield parsed_chunk
else:
parsed_chunk = ParsedRAGChunkResponse(
answer=answer_str,
metadata=RAGResponseMetadata(),
)
logger.debug(
f"answer_astream func_calling=False question={question} rolling_msg={rolling_message} chunk_id={chunk_id}, chunk={parsed_chunk}"
)
yield parsed_chunk
chunk_id += 1
# Last chunk provides metadata
last_chunk = ParsedRAGChunkResponse(
answer="",
metadata=get_chunk_metadata(rolling_message, sources),
last_chunk=True,
)
logger.debug(
f"answer_astream last_chunk={last_chunk} question={question} rolling_msg={rolling_message} chunk_id={chunk_id}"
)
yield last_chunk
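And a hedged sketch of consuming the streaming variant, reusing the rag and history placeholders from the previous example; only the final chunk carries the metadata (citations, sources, ...).

import asyncio

async def stream_answer() -> None:
    async for chunk in rag.answer_astream(
        question="Summarise the indexed documents",
        history=history,
        list_files=[],
    ):
        if not chunk.last_chunk:
            print(chunk.answer, end="", flush=True)
        else:
            # the final chunk carries only metadata
            print("\nSources:", chunk.metadata.sources)

asyncio.run(stream_answer())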

View File


@@ -1,11 +1,10 @@
from copy import deepcopy
from datetime import datetime
-from typing import Any, Generator, List, Tuple
+from typing import Any, Generator, Tuple, List
from uuid import UUID, uuid4
from langchain_core.messages import AIMessage, HumanMessage
-from quivr_core.models import ChatMessage
+from quivr_core.rag.entities.models import ChatMessage
class ChatHistory:
@@ -98,17 +97,3 @@ class ChatHistory:
"""
return [_msg.msg for _msg in self._msgs]
def __deepcopy__(self, memo):
"""
Support for deepcopy of ChatHistory.
This method ensures that mutable objects (like lists) are copied deeply.
"""
# Create a new instance of ChatHistory
new_copy = ChatHistory(self.id, deepcopy(self.brain_id, memo))
# Perform a deepcopy of the _msgs list
new_copy._msgs = deepcopy(self._msgs, memo)
# Return the deep copied instance
return new_copy

View File

@@ -0,0 +1,452 @@
import os
import re
import logging
from enum import Enum
from typing import Dict, Hashable, List, Optional, Union, Any, Type
from uuid import UUID
from pydantic import BaseModel
from langgraph.graph import START, END
from langchain_core.tools import BaseTool
from quivr_core.config import MegaparseConfig
from rapidfuzz import process, fuzz
from quivr_core.base_config import QuivrBaseConfig
from quivr_core.processor.splitter import SplitterConfig
from quivr_core.rag.prompts import CustomPromptsModel
from quivr_core.llm_tools.llm_tools import LLMToolFactory, TOOLS_CATEGORIES, TOOLS_LISTS
logger = logging.getLogger("quivr_core")
def normalize_to_env_variable_name(name: str) -> str:
# Replace any character that is not a letter, digit, or underscore with an underscore
env_variable_name = re.sub(r"[^A-Za-z0-9_]", "_", name).upper()
# Check if the normalized name starts with a digit
if env_variable_name[0].isdigit():
raise ValueError(
f"Invalid environment variable name '{env_variable_name}': Cannot start with a digit."
)
return env_variable_name
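A few illustrative inputs and outputs for this helper (the names are made up):

assert normalize_to_env_variable_name("openai") == "OPENAI"
assert normalize_to_env_variable_name("mistral-ai") == "MISTRAL_AI"
# normalize_to_env_variable_name("3rd-party") would raise ValueError (starts with a digit)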
class SpecialEdges(str, Enum):
start = "START"
end = "END"
class BrainConfig(QuivrBaseConfig):
brain_id: UUID | None = None
name: str
@property
def id(self) -> UUID | None:
return self.brain_id
class DefaultWebSearchTool(str, Enum):
TAVILY = "tavily"
class DefaultRerankers(str, Enum):
COHERE = "cohere"
JINA = "jina"
# MIXEDBREAD = "mixedbread-ai"
@property
def default_model(self) -> str:
# Mapping of suppliers to their default models
return {
self.COHERE: "rerank-multilingual-v3.0",
self.JINA: "jina-reranker-v2-base-multilingual",
# self.MIXEDBREAD: "rmxbai-rerank-large-v1",
}[self]
class DefaultModelSuppliers(str, Enum):
OPENAI = "openai"
AZURE = "azure"
ANTHROPIC = "anthropic"
META = "meta"
MISTRAL = "mistral"
GROQ = "groq"
class LLMConfig(QuivrBaseConfig):
context: int | None = None
tokenizer_hub: str | None = None
class LLMModelConfig:
_model_defaults: Dict[DefaultModelSuppliers, Dict[str, LLMConfig]] = {
DefaultModelSuppliers.OPENAI: {
"gpt-4o": LLMConfig(context=128000, tokenizer_hub="Xenova/gpt-4o"),
"gpt-4o-mini": LLMConfig(context=128000, tokenizer_hub="Xenova/gpt-4o"),
"gpt-4-turbo": LLMConfig(context=128000, tokenizer_hub="Xenova/gpt-4"),
"gpt-4": LLMConfig(context=8192, tokenizer_hub="Xenova/gpt-4"),
"gpt-3.5-turbo": LLMConfig(
context=16385, tokenizer_hub="Xenova/gpt-3.5-turbo"
),
"text-embedding-3-large": LLMConfig(
context=8191, tokenizer_hub="Xenova/text-embedding-ada-002"
),
"text-embedding-3-small": LLMConfig(
context=8191, tokenizer_hub="Xenova/text-embedding-ada-002"
),
"text-embedding-ada-002": LLMConfig(
context=8191, tokenizer_hub="Xenova/text-embedding-ada-002"
),
},
DefaultModelSuppliers.ANTHROPIC: {
"claude-3-5-sonnet": LLMConfig(
context=200000, tokenizer_hub="Xenova/claude-tokenizer"
),
"claude-3-opus": LLMConfig(
context=200000, tokenizer_hub="Xenova/claude-tokenizer"
),
"claude-3-sonnet": LLMConfig(
context=200000, tokenizer_hub="Xenova/claude-tokenizer"
),
"claude-3-haiku": LLMConfig(
context=200000, tokenizer_hub="Xenova/claude-tokenizer"
),
"claude-2-1": LLMConfig(
context=200000, tokenizer_hub="Xenova/claude-tokenizer"
),
"claude-2-0": LLMConfig(
context=100000, tokenizer_hub="Xenova/claude-tokenizer"
),
"claude-instant-1-2": LLMConfig(
context=100000, tokenizer_hub="Xenova/claude-tokenizer"
),
},
DefaultModelSuppliers.META: {
"llama-3.1": LLMConfig(
context=128000, tokenizer_hub="Xenova/Meta-Llama-3.1-Tokenizer"
),
"llama-3": LLMConfig(
context=8192, tokenizer_hub="Xenova/llama3-tokenizer-new"
),
"llama-2": LLMConfig(context=4096, tokenizer_hub="Xenova/llama2-tokenizer"),
"code-llama": LLMConfig(
context=16384, tokenizer_hub="Xenova/llama-code-tokenizer"
),
},
DefaultModelSuppliers.GROQ: {
"llama-3.1": LLMConfig(
context=128000, tokenizer_hub="Xenova/Meta-Llama-3.1-Tokenizer"
),
"llama-3": LLMConfig(
context=8192, tokenizer_hub="Xenova/llama3-tokenizer-new"
),
"llama-2": LLMConfig(context=4096, tokenizer_hub="Xenova/llama2-tokenizer"),
"code-llama": LLMConfig(
context=16384, tokenizer_hub="Xenova/llama-code-tokenizer"
),
},
DefaultModelSuppliers.MISTRAL: {
"mistral-large": LLMConfig(
context=128000, tokenizer_hub="Xenova/mistral-tokenizer-v3"
),
"mistral-small": LLMConfig(
context=128000, tokenizer_hub="Xenova/mistral-tokenizer-v3"
),
"mistral-nemo": LLMConfig(
context=128000, tokenizer_hub="Xenova/Mistral-Nemo-Instruct-Tokenizer"
),
"codestral": LLMConfig(
context=32000, tokenizer_hub="Xenova/mistral-tokenizer-v3"
),
},
}
@classmethod
def get_supplier_by_model_name(cls, model: str) -> DefaultModelSuppliers | None:
# Iterate over the suppliers and their models
for supplier, models in cls._model_defaults.items():
# Check if the model name or a base part of the model name is in the supplier's models
for base_model_name in models:
if model.startswith(base_model_name):
return supplier
# Return None if no supplier matches the model name
return None
@classmethod
def get_llm_model_config(
cls, supplier: DefaultModelSuppliers, model_name: str
) -> Optional[LLMConfig]:
"""Retrieve the LLMConfig (context and tokenizer_hub) for a given supplier and model."""
supplier_defaults = cls._model_defaults.get(supplier)
if not supplier_defaults:
return None
# Use startswith logic for matching model names
for key, config in supplier_defaults.items():
if model_name.startswith(key):
return config
return None
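A short sketch of how the prefix matching above behaves; the dated model string is a made-up example.

supplier = LLMModelConfig.get_supplier_by_model_name("gpt-4o-2024-08-06")
# -> DefaultModelSuppliers.OPENAI, since the name starts with the known "gpt-4o" prefix

config = LLMModelConfig.get_llm_model_config(DefaultModelSuppliers.OPENAI, "gpt-4o-mini")
# -> LLMConfig(context=128000, tokenizer_hub="Xenova/gpt-4o"), matched on the "gpt-4o" prefix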
class LLMEndpointConfig(QuivrBaseConfig):
supplier: DefaultModelSuppliers = DefaultModelSuppliers.OPENAI
model: str = "gpt-4o"
context_length: int | None = None
tokenizer_hub: str | None = None
llm_base_url: str | None = None
env_variable_name: str | None = None
llm_api_key: str | None = None
max_context_tokens: int = 2000
max_output_tokens: int = 2000
temperature: float = 0.7
streaming: bool = True
prompt: CustomPromptsModel | None = None
_FALLBACK_TOKENIZER = "cl100k_base"
@property
def fallback_tokenizer(self) -> str:
return self._FALLBACK_TOKENIZER
def __init__(self, **data):
super().__init__(**data)
self.set_llm_model_config()
self.set_api_key()
def set_api_key(self, force_reset: bool = False):
if not self.supplier:
return
# Check if the corresponding API key environment variable is set
if not self.env_variable_name:
self.env_variable_name = (
f"{normalize_to_env_variable_name(self.supplier)}_API_KEY"
)
if not self.llm_api_key or force_reset:
self.llm_api_key = os.getenv(self.env_variable_name)
if not self.llm_api_key:
logger.warning(f"The API key for supplier '{self.supplier}' is not set. ")
def set_llm_model_config(self):
# Automatically set context_length and tokenizer_hub based on the supplier and model
llm_model_config = LLMModelConfig.get_llm_model_config(
self.supplier, self.model
)
if llm_model_config:
self.context_length = llm_model_config.context
self.tokenizer_hub = llm_model_config.tokenizer_hub
def set_llm_model(self, model: str):
supplier = LLMModelConfig.get_supplier_by_model_name(model)
if supplier is None:
raise ValueError(
f"Cannot find the corresponding supplier for model {model}"
)
self.supplier = supplier
self.model = model
self.set_llm_model_config()
self.set_api_key(force_reset=True)
def set_from_sqlmodel(self, sqlmodel: BaseModel, mapping: Dict[str, str]):
"""
Set attributes in LLMEndpointConfig from Model attributes using a field mapping.
:param model_instance: An instance of the Model class.
:param mapping: A dictionary that maps Model fields to LLMEndpointConfig fields.
Example: {"max_input": "max_input_tokens", "env_variable_name": "env_variable_name"}
"""
for model_field, llm_field in mapping.items():
if hasattr(sqlmodel, model_field) and hasattr(self, llm_field):
setattr(self, llm_field, getattr(sqlmodel, model_field))
else:
raise AttributeError(
f"Invalid mapping: {model_field} or {llm_field} does not exist."
)
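A hedged sketch of the mapping described above; model_row stands for any object (for instance an SQLModel row) exposing the mapped attributes, and the mapping keys themselves are illustrative.

llm_config = LLMEndpointConfig()
llm_config.set_from_sqlmodel(
    model_row,  # hypothetical row object with max_input and env_variable_name attributes
    mapping={"max_input": "max_context_tokens", "env_variable_name": "env_variable_name"},
)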
# Cannot use Pydantic v2 field_validator because of conflicts with pydantic v1 still in use in LangChain
class RerankerConfig(QuivrBaseConfig):
supplier: DefaultRerankers | None = None
model: str | None = None
top_n: int = 5 # Number of chunks returned by the re-ranker
api_key: str | None = None
relevance_score_threshold: float | None = None
relevance_score_key: str = "relevance_score"
def __init__(self, **data):
super().__init__(**data) # Call Pydantic's BaseModel init
self.validate_model() # Automatically call external validation
def validate_model(self):
# If model is not provided, get default model based on supplier
if self.model is None and self.supplier is not None:
self.model = self.supplier.default_model
# Check if the corresponding API key environment variable is set
if self.supplier:
api_key_var = f"{normalize_to_env_variable_name(self.supplier)}_API_KEY"
self.api_key = os.getenv(api_key_var)
if self.api_key is None:
raise ValueError(
f"The API key for supplier '{self.supplier}' is not set. "
f"Please set the environment variable: {api_key_var}"
)
class ConditionalEdgeConfig(QuivrBaseConfig):
routing_function: str
conditions: Union[list, Dict[Hashable, str]]
def __init__(self, **data):
super().__init__(**data)
self.resolve_special_edges()
def resolve_special_edges(self):
"""Replace SpecialEdges enum values with their corresponding langgraph values."""
if isinstance(self.conditions, dict):
# If conditions is a dictionary, iterate through the key-value pairs
for key, value in self.conditions.items():
if value == SpecialEdges.end:
self.conditions[key] = END
elif value == SpecialEdges.start:
self.conditions[key] = START
elif isinstance(self.conditions, list):
# If conditions is a list, iterate through the values
for index, value in enumerate(self.conditions):
if value == SpecialEdges.end:
self.conditions[index] = END
elif value == SpecialEdges.start:
self.conditions[index] = START
class NodeConfig(QuivrBaseConfig):
name: str
edges: List[str] | None = None
conditional_edge: ConditionalEdgeConfig | None = None
tools: List[Dict[str, Any]] | None = None
instantiated_tools: List[BaseTool | Type] | None = None
def __init__(self, **data):
super().__init__(**data)
self._instantiate_tools()
self.resolve_special_edges_in_name_and_edges()
def resolve_special_edges_in_name_and_edges(self):
"""Replace SpecialEdges enum values in name and edges with corresponding langgraph values."""
if self.name == SpecialEdges.start:
self.name = START
elif self.name == SpecialEdges.end:
self.name = END
if self.edges:
for i, edge in enumerate(self.edges):
if edge == SpecialEdges.start:
self.edges[i] = START
elif edge == SpecialEdges.end:
self.edges[i] = END
def _instantiate_tools(self):
"""Instantiate tools based on the configuration."""
if self.tools:
self.instantiated_tools = [
LLMToolFactory.create_tool(tool_config.pop("name"), tool_config)
for tool_config in self.tools
]
class DefaultWorkflow(str, Enum):
RAG = "rag"
@property
def nodes(self) -> List[NodeConfig]:
# Mapping of workflow types to their default node configurations
workflows = {
self.RAG: [
NodeConfig(name=START, edges=["filter_history"]),
NodeConfig(name="filter_history", edges=["rewrite"]),
NodeConfig(name="rewrite", edges=["retrieve"]),
NodeConfig(name="retrieve", edges=["generate_rag"]),
NodeConfig(name="generate_rag", edges=[END]),
]
}
return workflows[self]
class WorkflowConfig(QuivrBaseConfig):
name: str | None = None
nodes: List[NodeConfig] = []
available_tools: List[str] | None = None
validated_tools: List[BaseTool | Type] = []
activated_tools: List[BaseTool | Type] = []
def __init__(self, **data):
super().__init__(**data)
self.check_first_node_is_start()
self.validate_available_tools()
def check_first_node_is_start(self):
if self.nodes and self.nodes[0].name != START:
raise ValueError(f"The first node should be a {SpecialEdges.start} node")
def get_node_tools(self, node_name: str) -> List[Any]:
"""Get tools for a specific node."""
for node in self.nodes:
if node.name == node_name and node.instantiated_tools:
return node.instantiated_tools
return []
def validate_available_tools(self):
if self.available_tools:
valid_tools = list(TOOLS_CATEGORIES.keys()) + list(TOOLS_LISTS.keys())
for tool in self.available_tools:
if tool.lower() in valid_tools:
self.validated_tools.append(
LLMToolFactory.create_tool(tool, {}).tool
)
else:
matches = process.extractOne(
tool.lower(), valid_tools, scorer=fuzz.WRatio
)
if matches:
raise ValueError(
f"Tool {tool} is not a valid ToolsCategory or ToolsList. Did you mean {matches[0]}?"
)
else:
raise ValueError(
f"Tool {tool} is not a valid ToolsCategory or ToolsList"
)
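A usage sketch of the fuzzy validation above; it assumes a "web search" entry exists among TOOLS_CATEGORIES / TOOLS_LISTS, so a typo'd tool name is rejected with a suggestion.

try:
    WorkflowConfig(nodes=[], available_tools=["web serch"])
except ValueError as exc:
    print(exc)  # e.g. "Tool web serch is not a valid ToolsCategory or ToolsList. Did you mean web search?"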
class RetrievalConfig(QuivrBaseConfig):
reranker_config: RerankerConfig = RerankerConfig()
llm_config: LLMEndpointConfig = LLMEndpointConfig()
max_history: int = 10
max_files: int = 20
k: int = 40 # Number of chunks returned by the retriever
prompt: str | None = None
workflow_config: WorkflowConfig = WorkflowConfig(nodes=DefaultWorkflow.RAG.nodes)
def __init__(self, **data):
super().__init__(**data)
self.llm_config.set_api_key(force_reset=True)
class ParserConfig(QuivrBaseConfig):
splitter_config: SplitterConfig = SplitterConfig()
megaparse_config: MegaparseConfig = MegaparseConfig()
class IngestionConfig(QuivrBaseConfig):
parser_config: ParserConfig = ParserConfig()
class AssistantConfig(QuivrBaseConfig):
retrieval_config: RetrievalConfig = RetrievalConfig()
ingestion_config: IngestionConfig = IngestionConfig()
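A minimal sketch of assembling these configs programmatically; all values are illustrative, and in practice the configuration is typically loaded from a YAML file.

retrieval_config = RetrievalConfig(
    llm_config=LLMEndpointConfig(supplier=DefaultModelSuppliers.OPENAI, model="gpt-4o-mini"),
    reranker_config=RerankerConfig(),  # no supplier: retrieval falls back to the no-op reranker
    max_history=10,
    k=40,
)
assistant_config = AssistantConfig(retrieval_config=retrieval_config)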

View File

@@ -7,7 +7,7 @@ from langchain_core.documents import Document
from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.pydantic_v1 import BaseModel as BaseModelV1
from langchain_core.pydantic_v1 import Field as FieldV1
-from pydantic import BaseModel
+from pydantic import BaseModel, Field
from typing_extensions import TypedDict
@@ -73,9 +73,9 @@ class ChatLLMMetadata(BaseModel):
class RAGResponseMetadata(BaseModel):
-citations: list[int] | None = None
-followup_questions: list[str] | None = None
-sources: list[Any] | None = None
+citations: list[int] = Field(default_factory=list)
+followup_questions: list[str] = Field(default_factory=list)
+sources: list[Any] = Field(default_factory=list)
metadata_model: ChatLLMMetadata | None = None

View File

@@ -0,0 +1,265 @@
import datetime
from pydantic import ConfigDict, create_model
from langchain_core.prompts.base import BasePromptTemplate
from langchain_core.prompts import (
ChatPromptTemplate,
HumanMessagePromptTemplate,
PromptTemplate,
SystemMessagePromptTemplate,
MessagesPlaceholder,
)
class CustomPromptsDict(dict):
def __init__(self, type, *args, **kwargs):
super().__init__(*args, **kwargs)
self._type = type
def __setitem__(self, key, value):
# Automatically convert the value into a tuple (my_type, value)
super().__setitem__(key, (self._type, value))
def _define_custom_prompts() -> CustomPromptsDict:
custom_prompts: CustomPromptsDict = CustomPromptsDict(type=BasePromptTemplate)
today_date = datetime.datetime.now().strftime("%B %d, %Y")
# ---------------------------------------------------------------------------
# Prompt for question rephrasing
# ---------------------------------------------------------------------------
system_message_template = (
"Given a chat history and the latest user question "
"which might reference context in the chat history, "
"formulate a standalone question which can be understood "
"without the chat history. Do NOT answer the question, "
"just reformulate it if needed and otherwise return it as is. "
"Do not output your reasoning, just the question."
)
template_answer = "User question: {question}\n Standalone question:"
CONDENSE_QUESTION_PROMPT = ChatPromptTemplate.from_messages(
[
SystemMessagePromptTemplate.from_template(system_message_template),
MessagesPlaceholder(variable_name="chat_history"),
HumanMessagePromptTemplate.from_template(template_answer),
]
)
custom_prompts["CONDENSE_QUESTION_PROMPT"] = CONDENSE_QUESTION_PROMPT
# ---------------------------------------------------------------------------
# Prompt for RAG
# ---------------------------------------------------------------------------
system_message_template = f"Your name is Quivr. You're a helpful assistant. Today's date is {today_date}. "
system_message_template += (
"- When answering use markdown. Use markdown code blocks for code snippets.\n"
"- Answer in a concise and clear manner.\n"
"- If no preferred language is provided, answer in the same language as the language used by the user.\n"
"- You must use ONLY the provided context to answer the question. "
"Do not use any prior knowledge or external information, even if you are certain of the answer.\n"
"- If you cannot provide an answer using ONLY the context provided, do not attempt to answer from your own knowledge."
"Instead, inform the user that the answer isn't available in the context and suggest using the available tools {tools}.\n"
"- Do not apologize when providing an answer.\n"
"- Don't cite the source id in the answer objects, but you can use the source to answer the question.\n\n"
)
context_template = (
"\n"
"- You have access to the following internal reasoning to provide an answer: {reasoning}\n"
"- You have access to the following files to answer the user question (limited to first 20 files): {files}\n"
"- You have access to the following context to answer the user question: {context}\n"
"- Follow these user instruction when crafting the answer: {custom_instructions}\n"
"- These user instructions shall take priority over any other previous instruction.\n"
"- Remember: if you cannot provide an answer using ONLY the provided context and CITING the sources, "
"inform the user that you don't have the answer and consider if any of the tools can help answer the question.\n"
"- Explain your reasoning about the potentiel tool usage in the answer.\n"
"- Only use binded tools to answer the question.\n"
# "OFFER the user the possibility to ACTIVATE a relevant tool among "
# "the tools which can be activated."
# "Tools which can be activated: {tools}. If any of these tools can help in providing an answer "
# "to the user question, you should offer the user the possibility to activate it. "
# "Remember, you shall NOT use the above tools, ONLY offer the user the possibility to activate them.\n"
)
template_answer = (
"Original task: {question}\n"
"Rephrased and contextualized task: {rephrased_task}\n"
)
RAG_ANSWER_PROMPT = ChatPromptTemplate.from_messages(
[
SystemMessagePromptTemplate.from_template(system_message_template),
MessagesPlaceholder(variable_name="chat_history"),
SystemMessagePromptTemplate.from_template(context_template),
HumanMessagePromptTemplate.from_template(template_answer),
]
)
custom_prompts["RAG_ANSWER_PROMPT"] = RAG_ANSWER_PROMPT
# ---------------------------------------------------------------------------
# Prompt for formatting documents
# ---------------------------------------------------------------------------
DEFAULT_DOCUMENT_PROMPT = PromptTemplate.from_template(
template="Filename: {original_file_name}\nSource: {index} \n {page_content}"
)
custom_prompts["DEFAULT_DOCUMENT_PROMPT"] = DEFAULT_DOCUMENT_PROMPT
# ---------------------------------------------------------------------------
# Prompt for chatting directly with LLMs, without any document retrieval stage
# ---------------------------------------------------------------------------
system_message_template = (
f"Your name is Quivr. You're a helpful assistant. Today's date is {today_date}."
)
system_message_template += """
If not None, also follow these user instructions when answering: {custom_instructions}
"""
template_answer = """
User Question: {question}
Answer:
"""
CHAT_LLM_PROMPT = ChatPromptTemplate.from_messages(
[
SystemMessagePromptTemplate.from_template(system_message_template),
MessagesPlaceholder(variable_name="chat_history"),
HumanMessagePromptTemplate.from_template(template_answer),
]
)
custom_prompts["CHAT_LLM_PROMPT"] = CHAT_LLM_PROMPT
# ---------------------------------------------------------------------------
# Prompt to understand the user intent
# ---------------------------------------------------------------------------
system_message_template = (
"Given the following user input, determine the user intent, in particular "
"whether the user is providing instructions to the system or is asking the system to "
"execute a task:\n"
" - if the user is providing direct instructions to modify the system behaviour (for instance, "
"'Can you reply in French?' or 'Answer in French' or 'You are an expert legal assistant' "
"or 'You will behave as...'), the user intent is 'prompt';\n"
" - in all other cases (asking questions, asking for summarising a text, asking for translating a text, ...), "
"the intent is 'task'.\n"
)
template_answer = "User input: {question}"
USER_INTENT_PROMPT = ChatPromptTemplate.from_messages(
[
SystemMessagePromptTemplate.from_template(system_message_template),
HumanMessagePromptTemplate.from_template(template_answer),
]
)
custom_prompts["USER_INTENT_PROMPT"] = USER_INTENT_PROMPT
# ---------------------------------------------------------------------------
# Prompt to create a system prompt from user instructions
# ---------------------------------------------------------------------------
system_message_template = (
"- Given the following user instruction, current system prompt, list of available tools "
"and list of activated tools, update the prompt to include the instruction and decide which tools to activate.\n"
"- The prompt shall only contain generic instructions which can be applied to any user task or question.\n"
"- The prompt shall be concise and clear.\n"
"- If the system prompt already contains the instruction, do not add it again.\n"
"- If the system prompt contradicts ther user instruction, remove the contradictory "
"statement or statements in the system prompt.\n"
"- You shall return separately the updated system prompt and the reasoning that led to the update.\n"
"- If the system prompt refers to a tool, you shall add the tool to the list of activated tools.\n"
"- If no tool activation is needed, return empty lists.\n"
"- You shall also return the reasoning that led to the tool activation.\n"
"- Current system prompt: {system_prompt}\n"
"- List of available tools: {available_tools}\n"
"- List of activated tools: {activated_tools}\n\n"
)
template_answer = "User instructions: {instruction}\n"
UPDATE_PROMPT = ChatPromptTemplate.from_messages(
[
SystemMessagePromptTemplate.from_template(system_message_template),
HumanMessagePromptTemplate.from_template(template_answer),
]
)
custom_prompts["UPDATE_PROMPT"] = UPDATE_PROMPT
# ---------------------------------------------------------------------------
# Prompt to split the user input into multiple questions / instructions
# ---------------------------------------------------------------------------
system_message_template = (
"Given a chat history and the user input, split and rephrase the input into instructions and tasks.\n"
"- Instructions direct the system to behave in a certain way or to use specific tools: examples of instructions are "
"'Can you reply in French?', 'Answer in French', 'You are an expert legal assistant', "
"'You will behave as...', 'Use web search').\n"
"- You shall collect and condense all the instructions into a single string.\n"
"- The instructions shall be standalone and self-contained, so that they can be understood "
"without the chat history. If no instructions are found, return an empty string.\n"
"- Instructions to be understood may require considering the chat history.\n"
"- Tasks are often questions, but they can also be summarisation tasks, translation tasks, content generation tasks, etc.\n"
"- Tasks to be understood may require considering the chat history.\n"
"- If the user input contains different tasks, you shall split the input into multiple tasks.\n"
"- Each splitted task shall be a standalone, self-contained task which can be understood "
"without the chat history. You shall rephrase the tasks if needed.\n"
"- If no explicit task is present, you shall infer the tasks from the user input and the chat history.\n"
"- Do NOT try to solve the tasks or answer the questions, "
"just reformulate them if needed and otherwise return them as is.\n"
"- Remember, you shall NOT suggest or generate new tasks.\n"
"- As an example, the user input 'What is Apple? Who is its CEO? When was it founded?' "
"shall be split into the questions 'What is Apple?', 'Who is the CEO of Apple?' and 'When was Apple founded?'.\n"
"- If no tasks are found, return the user input as is in the task.\n"
)
template_answer = "User input: {user_input}"
SPLIT_PROMPT = ChatPromptTemplate.from_messages(
[
SystemMessagePromptTemplate.from_template(system_message_template),
MessagesPlaceholder(variable_name="chat_history"),
HumanMessagePromptTemplate.from_template(template_answer),
]
)
custom_prompts["SPLIT_PROMPT"] = SPLIT_PROMPT
# ---------------------------------------------------------------------------
# Prompt to grade the relevance of an answer and decide whether to perform a web search
# ---------------------------------------------------------------------------
system_message_template = (
"Given the following tasks you shall determine whether all tasks can be "
"completed fully and in the best possible way using the provided context and chat history. "
"You shall:\n"
"- Consider each task separately,\n"
"- Determine whether the context and chat history contain "
"all the information necessary to complete the task.\n"
"- If the context and chat history do not contain all the information necessary to complete the task, "
"consider ONLY the list of tools below and select the tool most appropriate to complete the task.\n"
"- If no tools are listed, return the tasks as is and no tool.\n"
"- If no relevant tool can be selected, return the tasks as is and no tool.\n"
"- Do not propose to use a tool if that tool is not listed among the available tools.\n"
)
context_template = "Context: {context}\n {activated_tools}\n"
template_answer = "Tasks: {tasks}\n"
TOOL_ROUTING_PROMPT = ChatPromptTemplate.from_messages(
[
SystemMessagePromptTemplate.from_template(system_message_template),
MessagesPlaceholder(variable_name="chat_history"),
SystemMessagePromptTemplate.from_template(context_template),
HumanMessagePromptTemplate.from_template(template_answer),
]
)
custom_prompts["TOOL_ROUTING_PROMPT"] = TOOL_ROUTING_PROMPT
return custom_prompts
_custom_prompts = _define_custom_prompts()
CustomPromptsModel = create_model(
"CustomPromptsModel", **_custom_prompts, __config__=ConfigDict(extra="forbid")
)
custom_prompts = CustomPromptsModel()
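The dynamically created model exposes each template as an attribute, so a prompt can be formatted directly. In the sketch below the question is a placeholder and an empty list stands in for ChatHistory.to_list().

msg = custom_prompts.CONDENSE_QUESTION_PROMPT.format(
    chat_history=[],
    question="Who is the CEO of Apple?",
)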

View File

@@ -12,18 +12,18 @@ from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnableLambda, RunnablePassthrough
from langchain_core.vectorstores import VectorStore
-from quivr_core.chat import ChatHistory
-from quivr_core.config import RetrievalConfig
+from quivr_core.rag.entities.chat import ChatHistory
+from quivr_core.rag.entities.config import RetrievalConfig
from quivr_core.llm import LLMEndpoint
-from quivr_core.models import (
+from quivr_core.rag.entities.models import (
ParsedRAGChunkResponse,
ParsedRAGResponse,
QuivrKnowledge,
RAGResponseMetadata,
cited_answer,
)
-from quivr_core.prompts import custom_prompts
-from quivr_core.utils import (
+from quivr_core.rag.prompts import custom_prompts
+from quivr_core.rag.utils import (
combine_documents,
format_file_list,
get_chunk_metadata,

View File

@@ -0,0 +1,891 @@
import logging
from typing import (
Annotated,
AsyncGenerator,
List,
Optional,
Sequence,
Tuple,
TypedDict,
Dict,
Any,
Type,
)
from uuid import uuid4
import asyncio
# TODO(@aminediro): this is the only dependency on the langchain package, we should remove it
from langchain.retrievers import ContextualCompressionRetriever
from langchain_cohere import CohereRerank
from langchain_community.document_compressors import JinaRerank
from langchain_core.callbacks import Callbacks
from langchain_core.documents import BaseDocumentCompressor, Document
from langchain_core.messages import BaseMessage
from langchain_core.messages.ai import AIMessageChunk
from langchain_core.vectorstores import VectorStore
from langchain_core.prompts.base import BasePromptTemplate
from langgraph.graph import START, END, StateGraph
from langgraph.graph.message import add_messages
from langgraph.types import Send
from pydantic import BaseModel, Field
import openai
from quivr_core.rag.entities.chat import ChatHistory
from quivr_core.rag.entities.config import DefaultRerankers, NodeConfig, RetrievalConfig
from quivr_core.llm import LLMEndpoint
from quivr_core.llm_tools.llm_tools import LLMToolFactory
from quivr_core.rag.entities.models import (
ParsedRAGChunkResponse,
QuivrKnowledge,
)
from quivr_core.rag.prompts import custom_prompts
from quivr_core.rag.utils import (
collect_tools,
combine_documents,
format_file_list,
get_chunk_metadata,
parse_chunk_response,
)
logger = logging.getLogger("quivr_core")
class SplittedInput(BaseModel):
instructions_reasoning: Optional[str] = Field(
default=None,
description="The reasoning that leads to identifying the user instructions to the system",
)
instructions: Optional[str] = Field(
default=None, description="The instructions to the system"
)
tasks_reasoning: Optional[str] = Field(
default=None,
description="The reasoning that leads to identifying the explicit or implicit user tasks and questions",
)
tasks: Optional[List[str]] = Field(
default_factory=lambda: ["No explicit or implicit tasks found"],
description="The list of standalone, self-contained tasks or questions.",
)
class TasksCompletion(BaseModel):
completable_tasks_reasoning: Optional[str] = Field(
default=None,
description="The reasoning that leads to identifying the user tasks or questions that can be completed using the provided context and chat history.",
)
completable_tasks: Optional[List[str]] = Field(
default_factory=list,
description="The user tasks or questions that can be completed using the provided context and chat history.",
)
non_completable_tasks_reasoning: Optional[str] = Field(
default=None,
description="The reasoning that leads to identifying the user tasks or questions that cannot be completed using the provided context and chat history.",
)
non_completable_tasks: Optional[List[str]] = Field(
default_factory=list,
description="The user tasks or questions that need a tool to be completed.",
)
tool_reasoning: Optional[str] = Field(
default=None,
description="The reasoning that leads to identifying the tool that shall be used to complete the tasks.",
)
tool: Optional[str] = Field(
default=None,
description="The tool that shall be used to complete the tasks.",
)
class FinalAnswer(BaseModel):
reasoning_answer: str = Field(
description="The step-by-step reasoning that led to the final answer"
)
answer: str = Field(description="The final answer to the user tasks/questions")
all_tasks_completed: bool = Field(
description="Whether all tasks/questions have been successfully answered/completed or not. "
" If the final answer to the user is 'I don't know' or 'I don't have enough information' or 'I'm not sure', "
" this variable should be 'false'"
)
class UpdatedPromptAndTools(BaseModel):
prompt_reasoning: Optional[str] = Field(
default=None,
description="The step-by-step reasoning that leads to the updated system prompt",
)
prompt: Optional[str] = Field(default=None, description="The updated system prompt")
tools_reasoning: Optional[str] = Field(
default=None,
description="The reasoning that leads to activating and deactivating the tools",
)
tools_to_activate: Optional[List[str]] = Field(
default_factory=list, description="The list of tools to activate"
)
tools_to_deactivate: Optional[List[str]] = Field(
default_factory=list, description="The list of tools to deactivate"
)
class AgentState(TypedDict):
# The add_messages function defines how an update should be processed
# Default is to replace. add_messages says "append"
messages: Annotated[Sequence[BaseMessage], add_messages]
reasoning: List[str]
chat_history: ChatHistory
docs: list[Document]
files: str
tasks: List[str]
instructions: str
tool: str
class IdempotentCompressor(BaseDocumentCompressor):
def compress_documents(
self,
documents: Sequence[Document],
query: str,
callbacks: Optional[Callbacks] = None,
) -> Sequence[Document]:
"""
A no-op document compressor that simply returns the documents it is given.
This is a placeholder until a more sophisticated document compression
algorithm is implemented.
"""
return documents
class QuivrQARAGLangGraph:
def __init__(
self,
*,
retrieval_config: RetrievalConfig,
llm: LLMEndpoint,
vector_store: VectorStore | None = None,
):
"""
Construct a QuivrQARAGLangGraph object.
Args:
retrieval_config (RetrievalConfig): The configuration for the RAG model.
llm (LLMEndpoint): The LLM to use for generating text.
vector_store (VectorStore | None): The vector store to use for storing and retrieving documents. Defaults to None.
Note: re-ranking is driven by retrieval_config.reranker_config; if no reranker supplier is set, the no-op IdempotentCompressor is used.
"""
self.retrieval_config = retrieval_config
self.vector_store = vector_store
self.llm_endpoint = llm
self.graph = None
def get_reranker(self, **kwargs):
# Extract the reranker configuration from self
config = self.retrieval_config.reranker_config
# Allow kwargs to override specific config values
supplier = kwargs.pop("supplier", config.supplier)
model = kwargs.pop("model", config.model)
top_n = kwargs.pop("top_n", config.top_n)
api_key = kwargs.pop("api_key", config.api_key)
if supplier == DefaultRerankers.COHERE:
reranker = CohereRerank(
model=model, top_n=top_n, cohere_api_key=api_key, **kwargs
)
elif supplier == DefaultRerankers.JINA:
reranker = JinaRerank(
model=model, top_n=top_n, jina_api_key=api_key, **kwargs
)
else:
reranker = IdempotentCompressor()
return reranker
def get_retriever(self, **kwargs):
"""
Returns a retriever that can retrieve documents from the vector store.
Returns:
VectorStoreRetriever: The retriever.
"""
if self.vector_store:
retriever = self.vector_store.as_retriever(**kwargs)
else:
raise ValueError("No vector store provided")
return retriever
def routing(self, state: AgentState) -> List[Send]:
"""
The routing function for the RAG model.
Args:
state (AgentState): The current state of the agent.
Returns:
dict: The next state of the agent.
"""
msg = custom_prompts.SPLIT_PROMPT.format(
user_input=state["messages"][0].content,
)
response: SplittedInput
try:
structured_llm = self.llm_endpoint._llm.with_structured_output(
SplittedInput, method="json_schema"
)
response = structured_llm.invoke(msg)
except openai.BadRequestError:
structured_llm = self.llm_endpoint._llm.with_structured_output(
SplittedInput
)
response = structured_llm.invoke(msg)
send_list: List[Send] = []
instructions = (
response.instructions
if response.instructions
else self.retrieval_config.prompt
)
if instructions:
send_list.append(Send("edit_system_prompt", {"instructions": instructions}))
elif response.tasks:
chat_history = state["chat_history"]
send_list.append(
Send(
"filter_history",
{"chat_history": chat_history, "tasks": response.tasks},
)
)
return send_list
def routing_split(self, state: AgentState):
response = self.invoke_structured_output(
custom_prompts.SPLIT_PROMPT.format(
chat_history=state["chat_history"].to_list(),
user_input=state["messages"][0].content,
),
SplittedInput,
)
instructions = response.instructions or self.retrieval_config.prompt
tasks = response.tasks or []
if instructions:
return [
Send(
"edit_system_prompt",
{**state, "instructions": instructions, "tasks": tasks},
)
]
elif tasks:
return [Send("filter_history", {**state, "tasks": tasks})]
return []
def update_active_tools(self, updated_prompt_and_tools: UpdatedPromptAndTools):
if updated_prompt_and_tools.tools_to_activate:
for tool in updated_prompt_and_tools.tools_to_activate:
for (
validated_tool
) in self.retrieval_config.workflow_config.validated_tools:
if tool == validated_tool.name:
self.retrieval_config.workflow_config.activated_tools.append(
validated_tool
)
if updated_prompt_and_tools.tools_to_deactivate:
for tool in updated_prompt_and_tools.tools_to_deactivate:
for (
activated_tool
) in self.retrieval_config.workflow_config.activated_tools:
if tool == activated_tool.name:
self.retrieval_config.workflow_config.activated_tools.remove(
activated_tool
)
def edit_system_prompt(self, state: AgentState) -> AgentState:
user_instruction = state["instructions"]
prompt = self.retrieval_config.prompt
available_tools, activated_tools = collect_tools(
self.retrieval_config.workflow_config
)
inputs = {
"instruction": user_instruction,
"system_prompt": prompt if prompt else "",
"available_tools": available_tools,
"activated_tools": activated_tools,
}
msg = custom_prompts.UPDATE_PROMPT.format(**inputs)
response: UpdatedPromptAndTools = self.invoke_structured_output(
msg, UpdatedPromptAndTools
)
self.update_active_tools(response)
self.retrieval_config.prompt = response.prompt
reasoning = [response.prompt_reasoning] if response.prompt_reasoning else []
reasoning += [response.tools_reasoning] if response.tools_reasoning else []
return {**state, "messages": [], "reasoning": reasoning}
def filter_history(self, state: AgentState) -> AgentState:
"""
Filter out the chat history to only include the messages that are relevant to the current question
Takes in a chat_history= [HumanMessage(content='Qui est Chloé ? '),
AIMessage(content="Chloé est une salariée travaillant pour l'entreprise Quivr en tant qu'AI Engineer,
sous la direction de son supérieur hiérarchique, Stanislas Girard."),
HumanMessage(content='Dis moi en plus sur elle'), AIMessage(content=''),
HumanMessage(content='Dis moi en plus sur elle'),
AIMessage(content="Désolé, je n'ai pas d'autres informations sur Chloé à partir des fichiers fournis.")]
Returns a filtered chat_history that keeps, in order of priority, at most max_context_tokens worth of tokens and at most max_history message pairs, where a Human message and an AI message count as one pair (tokens are counted with the LLM endpoint's tokenizer).
"""
chat_history = state["chat_history"]
total_tokens = 0
total_pairs = 0
_chat_id = uuid4()
_chat_history = ChatHistory(chat_id=_chat_id, brain_id=chat_history.brain_id)
for human_message, ai_message in reversed(list(chat_history.iter_pairs())):
# TODO: replace with tiktoken
message_tokens = self.llm_endpoint.count_tokens(
human_message.content
) + self.llm_endpoint.count_tokens(ai_message.content)
if (
total_tokens + message_tokens
> self.retrieval_config.llm_config.max_context_tokens
or total_pairs >= self.retrieval_config.max_history
):
break
_chat_history.append(human_message)
_chat_history.append(ai_message)
total_tokens += message_tokens
total_pairs += 1
return {**state, "chat_history": _chat_history}
### Nodes
async def rewrite(self, state: AgentState) -> AgentState:
"""
Transform the query to produce a better question.
Args:
state (messages): The current state
Returns:
dict: The updated state with re-phrased question
"""
tasks = (
state["tasks"]
if "tasks" in state and state["tasks"]
else [state["messages"][0].content]
)
# Prepare the async tasks for all user tasks
async_tasks = []
for task in tasks:
msg = custom_prompts.CONDENSE_QUESTION_PROMPT.format(
chat_history=state["chat_history"].to_list(),
question=task,
)
model = self.llm_endpoint._llm
# Asynchronously invoke the model for each question
async_tasks.append(model.ainvoke(msg))
# Gather all the responses asynchronously
responses = await asyncio.gather(*async_tasks) if async_tasks else []
# Replace each question with its condensed version
condensed_questions = []
for response in responses:
condensed_questions.append(response.content)
return {**state, "tasks": condensed_questions}
def filter_chunks_by_relevance(self, chunks: List[Document], **kwargs):
config = self.retrieval_config.reranker_config
relevance_score_threshold = kwargs.get(
"relevance_score_threshold", config.relevance_score_threshold
)
if relevance_score_threshold is None:
return chunks
filtered_chunks = []
for chunk in chunks:
if config.relevance_score_key not in chunk.metadata:
logger.warning(
f"Relevance score key {config.relevance_score_key} not found in metadata, cannot filter chunks by relevance"
)
filtered_chunks.append(chunk)
elif (
chunk.metadata[config.relevance_score_key] >= relevance_score_threshold
):
filtered_chunks.append(chunk)
return filtered_chunks
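A hedged sketch of the filtering behaviour above, using placeholder documents and an assumed rag instance: chunks without the configured score key are kept (with a warning), chunks below the threshold are dropped.

from langchain_core.documents import Document

chunks = [
    Document(page_content="on-topic passage", metadata={"relevance_score": 0.92}),
    Document(page_content="off-topic passage", metadata={"relevance_score": 0.12}),
    Document(page_content="unscored passage", metadata={}),
]
kept = rag.filter_chunks_by_relevance(chunks, relevance_score_threshold=0.5)
# -> keeps the 0.92 chunk and the unscored chunk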
def tool_routing(self, state: AgentState):
tasks = state["tasks"]
if not tasks:
return [Send("generate_rag", state)]
docs = state["docs"]
_, activated_tools = collect_tools(self.retrieval_config.workflow_config)
input = {
"chat_history": state["chat_history"].to_list(),
"tasks": state["tasks"],
"context": docs,
"activated_tools": activated_tools,
}
input, _ = self.reduce_rag_context(
inputs=input,
prompt=custom_prompts.TOOL_ROUTING_PROMPT,
docs=docs,
# max_context_tokens=2000,
)
msg = custom_prompts.TOOL_ROUTING_PROMPT.format(**input)
response: TasksCompletion = self.invoke_structured_output(msg, TasksCompletion)
send_list: List[Send] = []
if response.non_completable_tasks and response.tool:
payload = {
**state,
"tasks": response.non_completable_tasks,
"tool": response.tool,
}
send_list.append(Send("run_tool", payload))
else:
send_list.append(Send("generate_rag", state))
return send_list
async def run_tool(self, state: AgentState) -> AgentState:
tool = state["tool"]
if tool not in [
t.name for t in self.retrieval_config.workflow_config.activated_tools
]:
raise ValueError(f"Tool {tool} not activated")
tasks = state["tasks"]
tool_wrapper = LLMToolFactory.create_tool(tool, {})
# Prepare the async tasks for all questions
async_tasks = []
for task in tasks:
formatted_input = tool_wrapper.format_input(task)
# Asynchronously invoke the model for each question
async_tasks.append(tool_wrapper.tool.ainvoke(formatted_input))
# Gather all the responses asynchronously
responses = await asyncio.gather(*async_tasks) if async_tasks else []
docs = []
for response in responses:
_docs = tool_wrapper.format_output(response)
_docs = self.filter_chunks_by_relevance(_docs)
docs += _docs
return {**state, "docs": state["docs"] + docs}
async def retrieve(self, state: AgentState) -> AgentState:
"""
Retrieve relevant chunks
Args:
state (messages): The current state
Returns:
dict: The retrieved chunks
"""
tasks = state["tasks"]
if not tasks:
return {**state, "docs": []}
kwargs = {
"search_kwargs": {
"k": self.retrieval_config.k,
}
} # type: ignore
base_retriever = self.get_retriever(**kwargs)
kwargs = {"top_n": self.retrieval_config.reranker_config.top_n} # type: ignore
reranker = self.get_reranker(**kwargs)
compression_retriever = ContextualCompressionRetriever(
base_compressor=reranker, base_retriever=base_retriever
)
# Prepare the async tasks for all questions
async_tasks = []
for task in tasks:
# Asynchronously invoke the model for each question
async_tasks.append(compression_retriever.ainvoke(task))
# Gather all the responses asynchronously
responses = await asyncio.gather(*async_tasks) if async_tasks else []
docs = []
for response in responses:
_docs = self.filter_chunks_by_relevance(response)
docs += _docs
return {**state, "docs": docs}
async def dynamic_retrieve(self, state: AgentState) -> AgentState:
"""
Retrieve relevant chunks
Args:
state (messages): The current state
Returns:
dict: The retrieved chunks
"""
tasks = state["tasks"]
if not tasks:
return {**state, "docs": []}
k = self.retrieval_config.k
top_n = self.retrieval_config.reranker_config.top_n
number_of_relevant_chunks = top_n
i = 1
while number_of_relevant_chunks == top_n:
top_n = self.retrieval_config.reranker_config.top_n * i
kwargs = {"top_n": top_n}
reranker = self.get_reranker(**kwargs)
k = max([top_n * 2, self.retrieval_config.k])
kwargs = {"search_kwargs": {"k": k}} # type: ignore
base_retriever = self.get_retriever(**kwargs)
if i > 1:
logger.info(
f"Increasing top_n to {top_n} and k to {k} to retrieve more relevant chunks"
)
compression_retriever = ContextualCompressionRetriever(
base_compressor=reranker, base_retriever=base_retriever
)
# Prepare the async tasks for all questions
async_tasks = []
for task in tasks:
# Asynchronously invoke the model for each question
async_tasks.append(compression_retriever.ainvoke(task))
# Gather all the responses asynchronously
responses = await asyncio.gather(*async_tasks) if async_tasks else []
docs = []
_n = []
for response in responses:
_docs = self.filter_chunks_by_relevance(response)
_n.append(len(_docs))
docs += _docs
if not docs:
break
context_length = self.get_rag_context_length(state, docs)
if context_length >= self.retrieval_config.llm_config.max_context_tokens:
logger.warning(
f"The context length is {context_length} which is greater than "
f"the max context tokens of {self.retrieval_config.llm_config.max_context_tokens}"
)
break
number_of_relevant_chunks = max(_n)
i += 1
return {**state, "docs": docs}
def get_rag_context_length(self, state: AgentState, docs: List[Document]) -> int:
final_inputs = self._build_rag_prompt_inputs(state, docs)
msg = custom_prompts.RAG_ANSWER_PROMPT.format(**final_inputs)
return self.llm_endpoint.count_tokens(msg)
def reduce_rag_context(
self,
inputs: Dict[str, Any],
prompt: BasePromptTemplate,
docs: List[Document] | None = None,
max_context_tokens: int | None = None,
) -> Tuple[Dict[str, Any], List[Document] | None]:
MAX_ITERATION = 100
SECURITY_FACTOR = 0.85
iteration = 0
msg = prompt.format(**inputs)
n = self.llm_endpoint.count_tokens(msg)
max_context_tokens = (
max_context_tokens
if max_context_tokens
else self.retrieval_config.llm_config.max_context_tokens
)
while n > max_context_tokens * SECURITY_FACTOR:
chat_history = inputs["chat_history"] if "chat_history" in inputs else []
if len(chat_history) > 0:
inputs["chat_history"] = chat_history[2:]
elif docs and len(docs) > 1:
docs = docs[:-1]
else:
logger.warning(
f"Not enough context to reduce. The context length is {n} "
f"which is greater than the max context tokens of {max_context_tokens}"
)
break
if docs and "context" in inputs:
inputs["context"] = combine_documents(docs)
msg = prompt.format(**inputs)
n = self.llm_endpoint.count_tokens(msg)
iteration += 1
if iteration > MAX_ITERATION:
logger.warning(
f"Attained the maximum number of iterations ({MAX_ITERATION})"
)
break
return inputs, docs
def bind_tools_to_llm(self, node_name: str):
if self.llm_endpoint.supports_func_calling():
tools = self.retrieval_config.workflow_config.get_node_tools(node_name)
if tools: # Only bind tools if there are any available
return self.llm_endpoint._llm.bind_tools(tools, tool_choice="any")
return self.llm_endpoint._llm
def generate_rag(self, state: AgentState) -> AgentState:
docs: List[Document] | None = state["docs"]
final_inputs = self._build_rag_prompt_inputs(state, docs)
reduced_inputs, docs = self.reduce_rag_context(
final_inputs, custom_prompts.RAG_ANSWER_PROMPT, docs
)
msg = custom_prompts.RAG_ANSWER_PROMPT.format(**reduced_inputs)
llm = self.bind_tools_to_llm(self.generate_rag.__name__)
response = llm.invoke(msg)
return {**state, "messages": [response], "docs": docs if docs else []}
def generate_chat_llm(self, state: AgentState) -> AgentState:
"""
Generate answer
Args:
state (messages): The current state
Returns:
dict: The updated state with the generated answer
"""
messages = state["messages"]
user_question = messages[0].content
# Prompt
prompt = self.retrieval_config.prompt
final_inputs = {}
final_inputs["question"] = user_question
final_inputs["custom_instructions"] = prompt if prompt else "None"
final_inputs["chat_history"] = state["chat_history"].to_list()
# LLM
llm = self.llm_endpoint._llm
reduced_inputs, _ = self.reduce_rag_context(
final_inputs, custom_prompts.CHAT_LLM_PROMPT, None
)
msg = custom_prompts.CHAT_LLM_PROMPT.format(**reduced_inputs)
# Run
response = llm.invoke(msg)
return {**state, "messages": [response]}
def build_chain(self):
"""
Builds the langchain chain for the given configuration.
Returns:
Callable[[Dict], Dict]: The langchain chain.
"""
if not self.graph:
self.graph = self.create_graph()
return self.graph
def create_graph(self):
workflow = StateGraph(AgentState)
self.final_nodes = []
self._build_workflow(workflow)
return workflow.compile()
def _build_workflow(self, workflow: StateGraph):
for node in self.retrieval_config.workflow_config.nodes:
if node.name not in [START, END]:
workflow.add_node(node.name, getattr(self, node.name))
for node in self.retrieval_config.workflow_config.nodes:
self._add_node_edges(workflow, node)
def _add_node_edges(self, workflow: StateGraph, node: NodeConfig):
if node.edges:
for edge in node.edges:
workflow.add_edge(node.name, edge)
if edge == END:
self.final_nodes.append(node.name)
elif node.conditional_edge:
routing_function = getattr(self, node.conditional_edge.routing_function)
workflow.add_conditional_edges(
node.name, routing_function, node.conditional_edge.conditions
)
if END in node.conditional_edge.conditions:
self.final_nodes.append(node.name)
else:
raise ValueError("Node should have at least one edge or conditional_edge")
async def answer_astream(
self,
question: str,
history: ChatHistory,
list_files: list[QuivrKnowledge],
metadata: dict[str, str] = {},
) -> AsyncGenerator[ParsedRAGChunkResponse, ParsedRAGChunkResponse]:
"""
Answer a question using the langgraph chain and yield each chunk of the answer separately.
"""
concat_list_files = format_file_list(
list_files, self.retrieval_config.max_files
)
conversational_qa_chain = self.build_chain()
rolling_message = AIMessageChunk(content="")
docs: list[Document] | None = None
previous_content = ""
async for event in conversational_qa_chain.astream_events(
{
"messages": [("user", question)],
"chat_history": history,
"files": concat_list_files,
},
version="v1",
config={"metadata": metadata},
):
if self._is_final_node_with_docs(event):
docs = event["data"]["output"]["docs"]
if self._is_final_node_and_chat_model_stream(event):
chunk = event["data"]["chunk"]
rolling_message, new_content, previous_content = parse_chunk_response(
rolling_message,
chunk,
self.llm_endpoint.supports_func_calling(),
previous_content,
)
if new_content:
chunk_metadata = get_chunk_metadata(rolling_message, docs)
yield ParsedRAGChunkResponse(
answer=new_content, metadata=chunk_metadata
)
# Yield final metadata chunk
yield ParsedRAGChunkResponse(
answer="",
metadata=get_chunk_metadata(rolling_message, docs),
last_chunk=True,
)
def _is_final_node_with_docs(self, event: dict) -> bool:
return (
"output" in event["data"]
and event["data"]["output"] is not None
and "docs" in event["data"]["output"]
and event["metadata"]["langgraph_node"] in self.final_nodes
)
def _is_final_node_and_chat_model_stream(self, event: dict) -> bool:
return (
event["event"] == "on_chat_model_stream"
and "langgraph_node" in event["metadata"]
and event["metadata"]["langgraph_node"] in self.final_nodes
)
def invoke_structured_output(
self, prompt: str, output_class: Type[BaseModel]
) -> Any:
try:
structured_llm = self.llm_endpoint._llm.with_structured_output(
output_class, method="json_schema"
)
return structured_llm.invoke(prompt)
except openai.BadRequestError:
structured_llm = self.llm_endpoint._llm.with_structured_output(output_class)
return structured_llm.invoke(prompt)
def _build_rag_prompt_inputs(
self, state: AgentState, docs: List[Document] | None
) -> Dict[str, Any]:
"""Build the input dictionary for RAG_ANSWER_PROMPT.
Args:
state: Current agent state
docs: List of documents or None
Returns:
Dictionary containing all inputs needed for RAG_ANSWER_PROMPT
"""
messages = state["messages"]
user_question = messages[0].content
files = state["files"]
prompt = self.retrieval_config.prompt
available_tools, _ = collect_tools(self.retrieval_config.workflow_config)
return {
"context": combine_documents(docs) if docs else "None",
"question": user_question,
"rephrased_task": state["tasks"],
"custom_instructions": prompt if prompt else "None",
"files": files if files else "None",
"chat_history": state["chat_history"].to_list(),
"reasoning": state["reasoning"] if "reasoning" in state else "None",
"tools": available_tools,
}

View File

@@ -0,0 +1,188 @@
import logging
from typing import Any, List, Tuple, no_type_check
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage
from langchain_core.messages.ai import AIMessageChunk
from langchain_core.prompts import format_document
from quivr_core.rag.entities.config import WorkflowConfig
from quivr_core.rag.entities.models import (
ChatLLMMetadata,
ParsedRAGResponse,
QuivrKnowledge,
RAGResponseMetadata,
RawRAGResponse,
)
from quivr_core.rag.prompts import custom_prompts
# TODO(@aminediro): define a types package where we clearly define IO types
# This should be used for serialization/deserialization later
logger = logging.getLogger("quivr_core")
def model_supports_function_calling(model_name: str):
models_not_supporting_function_calls: list[str] = ["llama2", "test", "ollama3"]
return model_name not in models_not_supporting_function_calls
def format_history_to_openai_mesages(
tuple_history: List[Tuple[str, str]], system_message: str, question: str
) -> List[BaseMessage]:
"""Format the chat history into a list of Base Messages"""
messages = []
messages.append(SystemMessage(content=system_message))
for human, ai in tuple_history:
messages.append(HumanMessage(content=human))
messages.append(AIMessage(content=ai))
messages.append(HumanMessage(content=question))
return messages
def cited_answer_filter(tool):
return tool["name"] == "cited_answer"
def get_chunk_metadata(
msg: AIMessageChunk, sources: list[Any] | None = None
) -> RAGResponseMetadata:
metadata = {"sources": sources or []}
if not msg.tool_calls:
return RAGResponseMetadata(**metadata, metadata_model=None)
all_citations = []
all_followup_questions = []
for tool_call in msg.tool_calls:
if tool_call.get("name") == "cited_answer" and "args" in tool_call:
args = tool_call["args"]
all_citations.extend(args.get("citations", []))
all_followup_questions.extend(args.get("followup_questions", []))
metadata["citations"] = all_citations
metadata["followup_questions"] = all_followup_questions[:3] # Limit to 3
return RAGResponseMetadata(**metadata, metadata_model=None)
def get_prev_message_str(msg: AIMessageChunk) -> str:
if msg.tool_calls:
cited_answer = next(x for x in msg.tool_calls if cited_answer_filter(x))
if "args" in cited_answer and "answer" in cited_answer["args"]:
return cited_answer["args"]["answer"]
return ""
# TODO: CONVOLUTED LOGIC !
# TODO(@aminediro): redo this
@no_type_check
def parse_chunk_response(
rolling_msg: AIMessageChunk,
raw_chunk: AIMessageChunk,
supports_func_calling: bool,
previous_content: str = "",
) -> Tuple[AIMessageChunk, str, str]:
"""Parse a chunk response
Args:
rolling_msg: The accumulated message so far
raw_chunk: The new chunk to add
supports_func_calling: Whether function calling is supported
previous_content: The previous content string
Returns:
Tuple of (updated rolling message, new content only, full content)
"""
rolling_msg += raw_chunk
if not supports_func_calling or not rolling_msg.tool_calls:
new_content = raw_chunk.content # Just the new chunk's content
full_content = rolling_msg.content # The full accumulated content
return rolling_msg, new_content, full_content
current_answers = get_answers_from_tool_calls(rolling_msg.tool_calls)
full_answer = "\n\n".join(current_answers)
new_content = full_answer[len(previous_content) :]
return rolling_msg, new_content, full_answer
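In the non-function-calling path the helper simply forwards each raw chunk and accumulates the rolling message, which is also what the updated `test_parse_chunk_response_nofunc_calling` further down in this diff exercises. A minimal sketch:

```python
from langchain_core.messages.ai import AIMessageChunk

rolling = AIMessageChunk(content="")
previous = ""
for token in ["Gold ", "is ", "a ", "metal."]:
    chunk = AIMessageChunk(content=token)
    rolling, new_content, previous = parse_chunk_response(
        rolling, chunk, supports_func_calling=False, previous_content=previous
    )
    print(new_content, end="")  # streams "Gold is a metal." token by token

print("\nfull answer:", rolling.content)
```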
def get_answers_from_tool_calls(tool_calls):
answers = []
for tool_call in tool_calls:
if tool_call.get("name") == "cited_answer" and "args" in tool_call:
answers.append(tool_call["args"].get("answer", ""))
return answers
@no_type_check
def parse_response(raw_response: RawRAGResponse, model_name: str) -> ParsedRAGResponse:
answers = []
sources = raw_response["docs"] if "docs" in raw_response else []
metadata = RAGResponseMetadata(
sources=sources, metadata_model=ChatLLMMetadata(name=model_name)
)
if (
model_supports_function_calling(model_name)
and "tool_calls" in raw_response["answer"]
and raw_response["answer"].tool_calls
):
all_citations = []
all_followup_questions = []
for tool_call in raw_response["answer"].tool_calls:
if "args" in tool_call:
args = tool_call["args"]
if "citations" in args:
all_citations.extend(args["citations"])
if "followup_questions" in args:
all_followup_questions.extend(args["followup_questions"])
if "answer" in args:
answers.append(args["answer"])
metadata.citations = all_citations
metadata.followup_questions = all_followup_questions
else:
answers.append(raw_response["answer"].content)
answer_str = "\n".join(answers)
parsed_response = ParsedRAGResponse(answer=answer_str, metadata=metadata)
return parsed_response
def combine_documents(
docs,
document_prompt=custom_prompts.DEFAULT_DOCUMENT_PROMPT,
document_separator="\n\n",
):
# for each doc, add an index to its metadata so that sources can be cited
for doc, index in zip(docs, range(len(docs)), strict=False):
doc.metadata["index"] = index
doc_strings = [format_document(doc, document_prompt) for doc in docs]
return document_separator.join(doc_strings)
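`combine_documents` tags every retrieved chunk with a positional `index` in its metadata so the `cited_answer` tool can point back at it. A small sketch; the explicit `PromptTemplate` stands in for `custom_prompts.DEFAULT_DOCUMENT_PROMPT`, whose exact template lives elsewhere in quivr-core:

```python
from langchain_core.documents import Document
from langchain_core.prompts import PromptTemplate

docs = [
    Document(page_content="Gold melts at 1064 °C.", metadata={"source": "a.txt"}),
    Document(page_content="Silver melts at 962 °C.", metadata={"source": "b.txt"}),
]
doc_prompt = PromptTemplate.from_template("[{index}] {page_content}")
print(combine_documents(docs, document_prompt=doc_prompt))
# [0] Gold melts at 1064 °C.
#
# [1] Silver melts at 962 °C.
```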
def format_file_list(
list_files_array: list[QuivrKnowledge], max_files: int = 20
) -> str:
list_files = [file.file_name or file.url for file in list_files_array]
files: list[str] = list(filter(lambda n: n is not None, list_files)) # type: ignore
files = files[:max_files]
files_str = "\n".join(files) if list_files_array else "None"
return files_str
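`format_file_list` only reads the `file_name` / `url` attributes and renders at most `max_files` entries as a newline-separated block for the prompt. A sketch with stand-in objects (real callers pass `QuivrKnowledge` instances):

```python
from types import SimpleNamespace

knowledge = [
    SimpleNamespace(file_name="handbook.pdf", url=None),
    SimpleNamespace(file_name=None, url="https://example.com/faq"),
]
print(format_file_list(knowledge, max_files=20))
# handbook.pdf
# https://example.com/faq
```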
def collect_tools(workflow_config: WorkflowConfig):
validated_tools = "Available tools which can be activated:\n"
for i, tool in enumerate(workflow_config.validated_tools):
validated_tools += f"Tool {i+1} name: {tool.name}\n"
validated_tools += f"Tool {i+1} description: {tool.description}\n\n"
activated_tools = "Activated tools which can be deactivated:\n"
for i, tool in enumerate(workflow_config.activated_tools):
activated_tools += f"Tool {i+1} name: {tool.name}\n"
activated_tools += f"Tool {i+1} description: {tool.description}\n\n"
return validated_tools, activated_tools

View File

@@ -1,167 +0,0 @@
import logging
from typing import Any, List, Tuple, no_type_check
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage
from langchain_core.messages.ai import AIMessageChunk
from langchain_core.prompts import format_document
from quivr_core.models import (
ChatLLMMetadata,
ParsedRAGResponse,
QuivrKnowledge,
RAGResponseMetadata,
RawRAGResponse,
)
from quivr_core.prompts import custom_prompts
# TODO(@aminediro): define a types packages where we clearly define IO types
# This should be used for serialization/deseriallization later
logger = logging.getLogger("quivr_core")
def model_supports_function_calling(model_name: str):
models_supporting_function_calls = [
"gpt-4",
"gpt-4-1106-preview",
"gpt-4-0613",
"gpt-4o",
"gpt-3.5-turbo-1106",
"gpt-3.5-turbo-0613",
"gpt-4-0125-preview",
"gpt-3.5-turbo",
"gpt-4-turbo",
"gpt-4o",
"gpt-4o-mini",
]
return model_name in models_supporting_function_calls
def format_history_to_openai_mesages(
tuple_history: List[Tuple[str, str]], system_message: str, question: str
) -> List[BaseMessage]:
"""Format the chat history into a list of Base Messages"""
messages = []
messages.append(SystemMessage(content=system_message))
for human, ai in tuple_history:
messages.append(HumanMessage(content=human))
messages.append(AIMessage(content=ai))
messages.append(HumanMessage(content=question))
return messages
def cited_answer_filter(tool):
return tool["name"] == "cited_answer"
def get_chunk_metadata(
msg: AIMessageChunk, sources: list[Any] | None = None
) -> RAGResponseMetadata:
# Initiate the source
metadata = {"sources": sources} if sources else {"sources": []}
if msg.tool_calls:
cited_answer = next(x for x in msg.tool_calls if cited_answer_filter(x))
if "args" in cited_answer:
gathered_args = cited_answer["args"]
if "citations" in gathered_args:
citations = gathered_args["citations"]
metadata["citations"] = citations
if "followup_questions" in gathered_args:
followup_questions = gathered_args["followup_questions"]
metadata["followup_questions"] = followup_questions
return RAGResponseMetadata(**metadata, metadata_model=None)
def get_prev_message_str(msg: AIMessageChunk) -> str:
if msg.tool_calls:
cited_answer = next(x for x in msg.tool_calls if cited_answer_filter(x))
if "args" in cited_answer and "answer" in cited_answer["args"]:
return cited_answer["args"]["answer"]
return ""
# TODO: CONVOLUTED LOGIC !
# TODO(@aminediro): redo this
@no_type_check
def parse_chunk_response(
rolling_msg: AIMessageChunk,
raw_chunk: dict[str, Any],
supports_func_calling: bool,
) -> Tuple[AIMessageChunk, str]:
# Init with sources
answer_str = ""
if "answer" in raw_chunk:
answer = raw_chunk["answer"]
else:
answer = raw_chunk
rolling_msg += answer
if supports_func_calling and rolling_msg.tool_calls:
cited_answer = next(x for x in rolling_msg.tool_calls if cited_answer_filter(x))
if "args" in cited_answer and "answer" in cited_answer["args"]:
gathered_args = cited_answer["args"]
# Only send the difference between answer and response_tokens which was the previous answer
answer_str = gathered_args["answer"]
return rolling_msg, answer_str
return rolling_msg, answer.content
@no_type_check
def parse_response(raw_response: RawRAGResponse, model_name: str) -> ParsedRAGResponse:
answer = ""
sources = raw_response["docs"] if "docs" in raw_response else []
metadata = RAGResponseMetadata(
sources=sources, metadata_model=ChatLLMMetadata(name=model_name)
)
if (
model_supports_function_calling(model_name)
and "tool_calls" in raw_response["answer"]
and raw_response["answer"].tool_calls
):
if "citations" in raw_response["answer"].tool_calls[-1]["args"]:
citations = raw_response["answer"].tool_calls[-1]["args"]["citations"]
metadata.citations = citations
followup_questions = raw_response["answer"].tool_calls[-1]["args"][
"followup_questions"
]
if followup_questions:
metadata.followup_questions = followup_questions
answer = raw_response["answer"].tool_calls[-1]["args"]["answer"]
else:
answer = raw_response["answer"].tool_calls[-1]["args"]["answer"]
else:
answer = raw_response["answer"].content
parsed_response = ParsedRAGResponse(answer=answer, metadata=metadata)
return parsed_response
def combine_documents(
docs,
document_prompt=custom_prompts.DEFAULT_DOCUMENT_PROMPT,
document_separator="\n\n",
):
# for each docs, add an index in the metadata to be able to cite the sources
for doc, index in zip(docs, range(len(docs)), strict=False):
doc.metadata["index"] = index
doc_strings = [format_document(doc, document_prompt) for doc in docs]
return document_separator.join(doc_strings)
def format_file_list(
list_files_array: list[QuivrKnowledge], max_files: int = 20
) -> str:
list_files = [file.file_name or file.url for file in list_files_array]
files: list[str] = list(filter(lambda n: n is not None, list_files)) # type: ignore
files = files[:max_files]
files_str = "\n".join(files) if list_files_array else "None"
return files_str

View File

@@ -27,6 +27,8 @@ anyio==4.6.2.post1
 # via anthropic
 # via httpx
 # via openai
+appnope==0.1.4
+# via ipykernel
 asttokens==2.4.1
 # via stack-data
 attrs==24.2.0
@@ -80,8 +82,6 @@ frozenlist==1.4.1
 # via aiosignal
 fsspec==2024.9.0
 # via huggingface-hub
-greenlet==3.1.1
-# via sqlalchemy
 h11==0.14.0
 # via httpcore
 httpcore==1.0.6
@@ -278,6 +278,8 @@ pyyaml==6.0.2
 pyzmq==26.2.0
 # via ipykernel
 # via jupyter-client
+rapidfuzz==3.10.1
+# via quivr-core
 regex==2024.9.11
 # via tiktoken
 # via transformers

View File

@@ -56,8 +56,6 @@ frozenlist==1.4.1
 # via aiosignal
 fsspec==2024.9.0
 # via huggingface-hub
-greenlet==3.1.1
-# via sqlalchemy
 h11==0.14.0
 # via httpcore
 httpcore==1.0.6
@@ -185,6 +183,8 @@ pyyaml==6.0.2
 # via langchain-community
 # via langchain-core
 # via transformers
+rapidfuzz==3.10.1
+# via quivr-core
 regex==2024.9.11
 # via tiktoken
 # via transformers

View File

@@ -9,7 +9,7 @@ from langchain_core.language_models import FakeListChatModel
 from langchain_core.messages.ai import AIMessageChunk
 from langchain_core.runnables.utils import AddableDict
 from langchain_core.vectorstores import InMemoryVectorStore
-from quivr_core.config import LLMEndpointConfig
+from quivr_core.rag.entities.config import LLMEndpointConfig
 from quivr_core.files.file import FileExtension, QuivrFile
 from quivr_core.llm import LLMEndpoint

View File

@@ -5,10 +5,10 @@ from uuid import uuid4
 from langchain_core.embeddings import DeterministicFakeEmbedding
 from langchain_core.messages.ai import AIMessageChunk
 from langchain_core.vectorstores import InMemoryVectorStore
-from quivr_core.chat import ChatHistory
+from quivr_core.rag.entities.chat import ChatHistory
-from quivr_core.config import LLMEndpointConfig, RetrievalConfig
+from quivr_core.rag.entities.config import LLMEndpointConfig, RetrievalConfig
 from quivr_core.llm import LLMEndpoint
-from quivr_core.quivr_rag_langgraph import QuivrQARAGLangGraph
+from quivr_core.rag.quivr_rag_langgraph import QuivrQARAGLangGraph
 async def main():

View File

@@ -27,9 +27,9 @@ retrieval_config:
 supplier: "openai"
 # The model to use for the LLM for the given supplier
-model: "gpt-4o"
+model: "gpt-3.5-turbo-0125"
-max_input_tokens: 2000
+max_context_tokens: 2000
 # Maximum number of tokens to pass to the LLM
 # as a context to generate the answer

View File

@@ -40,9 +40,9 @@ retrieval_config:
 supplier: "openai"
 # The model to use for the LLM for the given supplier
-model: "gpt-4o"
+model: "gpt-3.5-turbo-0125"
-max_input_tokens: 2000
+max_context_tokens: 2000
 # Maximum number of tokens to pass to the LLM
 # as a context to generate the answer

View File

@@ -5,7 +5,7 @@ import pytest
 from langchain_core.documents import Document
 from langchain_core.embeddings import Embeddings
 from quivr_core.brain import Brain
-from quivr_core.chat import ChatHistory
+from quivr_core.rag.entities.chat import ChatHistory
 from quivr_core.llm import LLMEndpoint
 from quivr_core.storage.local_storage import TransparentStorage
@@ -104,8 +104,8 @@ async def test_brain_get_history(
         vector_db=mem_vector_store,
     )
-    brain.ask("question")
-    brain.ask("question")
+    await brain.aask("question")
+    await brain.aask("question")
     assert len(brain.default_chat) == 4
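The test now goes through the async entry point. A hedged sketch of the same call pattern outside pytest; the `brain` instance is assumed to have been built elsewhere (the test builds one with an in-memory vector store):

```python
import asyncio


async def ask_twice(brain) -> None:
    # `brain` is an already-built quivr_core Brain instance.
    first = await brain.aask("question")   # async variant used by the updated test
    second = await brain.aask("question")
    # Each round trip appends a Human/AI message pair to the default chat,
    # which is why the test asserts len(brain.default_chat) == 4 afterwards.
    print(first, second)


# asyncio.run(ask_twice(brain))
```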

View File

@@ -3,7 +3,7 @@ from uuid import uuid4
 import pytest
 from langchain_core.messages import AIMessage, HumanMessage
-from quivr_core.chat import ChatHistory
+from quivr_core.rag.entities.chat import ChatHistory
 @pytest.fixture

View File

@@ -1,18 +1,21 @@
-from quivr_core.config import LLMEndpointConfig, RetrievalConfig
+from quivr_core.rag.entities.config import LLMEndpointConfig, RetrievalConfig
 def test_default_llm_config():
     config = LLMEndpointConfig()
-    assert config.model_dump(exclude={"llm_api_key"}) == LLMEndpointConfig(
-        model="gpt-4o",
-        llm_base_url=None,
-        llm_api_key=None,
-        max_input_tokens=2000,
-        max_output_tokens=2000,
-        temperature=0.7,
-        streaming=True,
-    ).model_dump(exclude={"llm_api_key"})
+    assert (
+        config.model_dump()
+        == LLMEndpointConfig(
+            model="gpt-4o",
+            llm_base_url=None,
+            llm_api_key=None,
+            max_context_tokens=2000,
+            max_output_tokens=2000,
+            temperature=0.7,
+            streaming=True,
+        ).model_dump()
+    )
 def test_default_retrievalconfig():
@@ -20,6 +23,6 @@ def test_default_retrievalconfig():
     assert config.max_files == 20
     assert config.prompt is None
-    assert config.llm_config.model_dump(
-        exclude={"llm_api_key"}
-    ) == LLMEndpointConfig().model_dump(exclude={"llm_api_key"})
+    print("\n\n", config.llm_config, "\n\n")
+    print("\n\n", LLMEndpointConfig(), "\n\n")
+    assert config.llm_config == LLMEndpointConfig()
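The updated assertions also reflect the rename of `max_input_tokens` to `max_context_tokens` on `LLMEndpointConfig` (the same rename shows up in the YAML examples above). Based on the defaults asserted here, constructing the config explicitly now looks roughly like this:

```python
from quivr_core.rag.entities.config import LLMEndpointConfig

config = LLMEndpointConfig(
    model="gpt-4o",
    llm_base_url=None,
    llm_api_key=None,
    max_context_tokens=2000,  # formerly max_input_tokens
    max_output_tokens=2000,
    temperature=0.7,
    streaming=True,
)
assert config.model_dump() == LLMEndpointConfig().model_dump()
```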

View File

@@ -3,7 +3,7 @@ import os
 import pytest
 from langchain_core.language_models import FakeListChatModel
 from pydantic.v1.error_wrappers import ValidationError
-from quivr_core.config import LLMEndpointConfig
+from quivr_core.rag.entities.config import LLMEndpointConfig
 from quivr_core.llm import LLMEndpoint

View File

@@ -1,25 +1,51 @@
 from uuid import uuid4
 import pytest
-from quivr_core.chat import ChatHistory
+from quivr_core.rag.entities.chat import ChatHistory
-from quivr_core.config import LLMEndpointConfig, RetrievalConfig
+from quivr_core.rag.entities.config import LLMEndpointConfig, RetrievalConfig
 from quivr_core.llm import LLMEndpoint
-from quivr_core.models import ParsedRAGChunkResponse, RAGResponseMetadata
+from quivr_core.rag.entities.models import ParsedRAGChunkResponse, RAGResponseMetadata
-from quivr_core.quivr_rag_langgraph import QuivrQARAGLangGraph
+from quivr_core.rag.quivr_rag_langgraph import QuivrQARAGLangGraph
 @pytest.fixture(scope="function")
 def mock_chain_qa_stream(monkeypatch, chunks_stream_answer):
     class MockQAChain:
         async def astream_events(self, *args, **kwargs):
-            for c in chunks_stream_answer:
+            default_metadata = {
+                "langgraph_node": "generate",
+                "is_final_node": False,
+                "citations": None,
+                "followup_questions": None,
+                "sources": None,
+                "metadata_model": None,
+            }
+            # Send all chunks except the last one
+            for chunk in chunks_stream_answer[:-1]:
                 yield {
                     "event": "on_chat_model_stream",
-                    "metadata": {"langgraph_node": "generate"},
+                    "metadata": default_metadata,
-                    "data": {"chunk": c},
+                    "data": {"chunk": chunk["answer"]},
                 }
+            # Send the last chunk
+            yield {
+                "event": "end",
+                "metadata": {
+                    "langgraph_node": "generate",
+                    "is_final_node": True,
+                    "citations": [],
+                    "followup_questions": None,
+                    "sources": [],
+                    "metadata_model": None,
+                },
+                "data": {"chunk": chunks_stream_answer[-1]["answer"]},
+            }
     def mock_qa_chain(*args, **kwargs):
+        self = args[0]
+        self.final_nodes = ["generate"]
         return MockQAChain()
     monkeypatch.setattr(QuivrQARAGLangGraph, "build_chain", mock_qa_chain)
@@ -48,11 +74,13 @@ async def test_quivrqaraglanggraph(
     ):
         stream_responses.append(resp)
+    # This assertion passed
     assert all(
         not r.last_chunk for r in stream_responses[:-1]
     ), "Some chunks before last have last_chunk=True"
     assert stream_responses[-1].last_chunk
+    # Let's check this assertion
     for idx, response in enumerate(stream_responses[1:-1]):
         assert (
             len(response.answer) > 0

View File

@@ -3,7 +3,7 @@ from uuid import uuid4
 import pytest
 from langchain_core.messages.ai import AIMessageChunk
 from langchain_core.messages.tool import ToolCall
-from quivr_core.utils import (
+from quivr_core.rag.utils import (
     get_prev_message_str,
     model_supports_function_calling,
     parse_chunk_response,
@@ -43,13 +43,9 @@ def test_get_prev_message_str():
 def test_parse_chunk_response_nofunc_calling():
     rolling_msg = AIMessageChunk(content="")
-    chunk = {
-        "answer": AIMessageChunk(
-            content="next ",
-        )
-    }
+    chunk = AIMessageChunk(content="next ")
     for i in range(10):
-        rolling_msg, parsed_chunk = parse_chunk_response(rolling_msg, chunk, False)
+        rolling_msg, parsed_chunk, _ = parse_chunk_response(rolling_msg, chunk, False)
         assert rolling_msg.content == "next " * (i + 1)
         assert parsed_chunk == "next "
@@ -70,8 +66,9 @@ def test_parse_chunk_response_func_calling(chunks_stream_answer):
     answer_str_history: list[str] = []
     for chunk in chunks_stream_answer:
-        # This is done
-        rolling_msg, answer_str = parse_chunk_response(rolling_msg, chunk, True)
+        # Extract the AIMessageChunk from the chunk dictionary
+        chunk_msg = chunk["answer"]  # Get the AIMessageChunk from the dict
+        rolling_msg, answer_str, _ = parse_chunk_response(rolling_msg, chunk_msg, True)
         rolling_msgs_history.append(rolling_msg)
         answer_str_history.append(answer_str)

View File

@@ -5,7 +5,7 @@ from pathlib import Path
 import dotenv
 from quivr_core import Brain
-from quivr_core.config import AssistantConfig
+from quivr_core.rag.entities.config import AssistantConfig
 from rich.traceback import install as rich_install
 ConsoleOutputHandler = logging.StreamHandler()

View File

@@ -1,7 +1,7 @@
 from langchain_core.embeddings import DeterministicFakeEmbedding
 from langchain_core.language_models import FakeListChatModel
 from quivr_core import Brain
-from quivr_core.config import LLMEndpointConfig
+from quivr_core.rag.entities.config import LLMEndpointConfig
 from quivr_core.llm.llm_endpoint import LLMEndpoint
 from rich.console import Console
 from rich.panel import Panel

View File

@@ -7,6 +7,7 @@ authors = [
 ]
 dependencies = [
     "quivr-core @ file:///${PROJECT_ROOT}/../../core",
+    "python-dotenv>=1.0.1",
 ]
 readme = "README.md"
 requires-python = ">= 3.11"

View File

@@ -55,8 +55,6 @@ frozenlist==1.4.1
 # via aiosignal
 fsspec==2024.10.0
 # via huggingface-hub
-greenlet==3.1.1
-# via sqlalchemy
 h11==0.14.0
 # via httpcore
 httpcore==1.0.6
@@ -169,6 +167,7 @@ pydantic==2.9.2
 # via langsmith
 # via openai
 # via quivr-core
+# via sqlmodel
 pydantic-core==2.23.4
 # via cohere
 # via pydantic
@@ -176,6 +175,8 @@ pygments==2.18.0
 # via rich
 python-dateutil==2.9.0.post0
 # via pandas
+python-dotenv==1.0.1
+# via quivr-core
 pytz==2024.2
 # via pandas
 pyyaml==6.0.2
@@ -185,6 +186,8 @@ pyyaml==6.0.2
 # via langchain-core
 # via transformers
 quivr-core @ file:///${PROJECT_ROOT}/../../core
+rapidfuzz==3.10.1
+# via quivr-core
 regex==2024.9.11
 # via tiktoken
 # via transformers
@@ -215,6 +218,9 @@ sniffio==1.3.1
 sqlalchemy==2.0.36
 # via langchain
 # via langchain-community
+# via sqlmodel
+sqlmodel==0.0.22
+# via quivr-core
 tabulate==0.9.0
 # via langchain-cohere
 tenacity==8.5.0

View File

@@ -55,8 +55,6 @@ frozenlist==1.4.1
 # via aiosignal
 fsspec==2024.10.0
 # via huggingface-hub
-greenlet==3.1.1
-# via sqlalchemy
 h11==0.14.0
 # via httpcore
 httpcore==1.0.6
@@ -169,6 +167,7 @@ pydantic==2.9.2
 # via langsmith
 # via openai
 # via quivr-core
+# via sqlmodel
 pydantic-core==2.23.4
 # via cohere
 # via pydantic
@@ -176,6 +175,8 @@ pygments==2.18.0
 # via rich
 python-dateutil==2.9.0.post0
 # via pandas
+python-dotenv==1.0.1
+# via quivr-core
 pytz==2024.2
 # via pandas
 pyyaml==6.0.2
@@ -185,6 +186,8 @@ pyyaml==6.0.2
 # via langchain-core
 # via transformers
 quivr-core @ file:///${PROJECT_ROOT}/../../core
+rapidfuzz==3.10.1
+# via quivr-core
 regex==2024.9.11
 # via tiktoken
 # via transformers
@@ -215,6 +218,9 @@ sniffio==1.3.1
 sqlalchemy==2.0.36
 # via langchain
 # via langchain-community
+# via sqlmodel
+sqlmodel==0.0.22
+# via quivr-core
 tabulate==0.9.0
 # via langchain-cohere
 tenacity==8.5.0

View File

@@ -2,6 +2,10 @@ import tempfile
 from quivr_core import Brain
+import dotenv
+dotenv.load_dotenv()
 if __name__ == "__main__":
     with tempfile.NamedTemporaryFile(mode="w", suffix=".txt") as temp_file:
         temp_file.write("Gold is a liquid of blue-like colour.")
@@ -12,7 +16,5 @@ if __name__ == "__main__":
         file_paths=[temp_file.name],
     )
-    answer = brain.ask(
-        "what is gold? asnwer in french"
-    )
-    print("answer QuivrQARAGLangGraph :", answer.answer)
+    answer = brain.ask("what is gold? answer in french")
+    print("answer QuivrQARAGLangGraph :", answer)

View File

@@ -4,7 +4,7 @@ import tempfile
 from dotenv import load_dotenv
 from quivr_core import Brain
 from quivr_core.quivr_rag import QuivrQARAG
-from quivr_core.quivr_rag_langgraph import QuivrQARAGLangGraph
+from quivr_core.rag.quivr_rag_langgraph import QuivrQARAGLangGraph
 async def main():