mirror of
https://github.com/StanGirard/quivr.git
synced 2024-11-21 16:12:42 +03:00
feat: websearch, tool use, user intent, dynamic retrieval, multiple questions (#3424)
# Description This PR includes far too many new features: - detection of user intent (closes CORE-211) - treating multiple questions in parallel (closes CORE-212) - using the chat history when answering a question (closes CORE-213) - filtering of retrieved chunks by relevance threshold (closes CORE-217) - dynamic retrieval of chunks (closes CORE-218) - enabling web search via Tavily (closes CORE-220) - enabling agent / assistant to activate tools when relevant to complete the user task (closes CORE-224) Also closes CORE-205 ## Checklist before requesting a review Please delete options that are not relevant. - [ ] My code follows the style guidelines of this project - [ ] I have performed a self-review of my code - [ ] I have commented hard-to-understand areas - [ ] I have ideally added tests that prove my fix is effective or that my feature works - [ ] New and existing unit tests pass locally with my changes - [ ] Any dependent changes have been merged ## Screenshots (if appropriate): --------- Co-authored-by: Stan Girard <stan@quivr.app>
This commit is contained in:
parent
5401c01ee2
commit
285fe5b960
2
.github/workflows/backend-core-tests.yml
vendored
2
.github/workflows/backend-core-tests.yml
vendored
@ -41,6 +41,4 @@ jobs:
|
||||
sudo apt-get update
|
||||
sudo apt-get install -y libmagic-dev poppler-utils libreoffice tesseract-ocr pandoc
|
||||
cd core
|
||||
rye run python -c "from unstructured.nlp.tokenize import download_nltk_packages; download_nltk_packages()"
|
||||
rye run python -c "import nltk;nltk.download('punkt_tab'); nltk.download('averaged_perceptron_tagger_eng')"
|
||||
rye test -p quivr-core
|
||||
|
@ -9,7 +9,7 @@ dependencies = [
|
||||
"pydantic>=2.8.2",
|
||||
"langchain-core>=0.2.38",
|
||||
"langchain>=0.2.14,<0.3.0",
|
||||
"langgraph>=0.2.14",
|
||||
"langgraph>=0.2.38",
|
||||
"httpx>=0.27.0",
|
||||
"rich>=13.7.1",
|
||||
"tiktoken>=0.7.0",
|
||||
@ -21,6 +21,7 @@ dependencies = [
|
||||
"types-pyyaml>=6.0.12.20240808",
|
||||
"transformers[sentencepiece]>=4.44.2",
|
||||
"faiss-cpu>=1.8.0.post1",
|
||||
"rapidfuzz>=3.10.1",
|
||||
]
|
||||
readme = "README.md"
|
||||
requires-python = ">= 3.11"
|
||||
|
@ -10,7 +10,9 @@ from langchain_core.documents import Document
|
||||
from langchain_core.embeddings import Embeddings
|
||||
from langchain_core.messages import AIMessage, HumanMessage
|
||||
from langchain_core.vectorstores import VectorStore
|
||||
from quivr_core.rag.entities.models import ParsedRAGResponse
|
||||
from langchain_openai import OpenAIEmbeddings
|
||||
from quivr_core.rag.quivr_rag import QuivrQARAG
|
||||
from rich.console import Console
|
||||
from rich.panel import Panel
|
||||
|
||||
@ -22,19 +24,17 @@ from quivr_core.brain.serialization import (
|
||||
LocalStorageConfig,
|
||||
TransparentStorageConfig,
|
||||
)
|
||||
from quivr_core.chat import ChatHistory
|
||||
from quivr_core.config import RetrievalConfig
|
||||
from quivr_core.rag.entities.chat import ChatHistory
|
||||
from quivr_core.rag.entities.config import RetrievalConfig
|
||||
from quivr_core.files.file import load_qfile
|
||||
from quivr_core.llm import LLMEndpoint
|
||||
from quivr_core.models import (
|
||||
from quivr_core.rag.entities.models import (
|
||||
ParsedRAGChunkResponse,
|
||||
ParsedRAGResponse,
|
||||
QuivrKnowledge,
|
||||
SearchResult,
|
||||
)
|
||||
from quivr_core.processor.registry import get_processor_class
|
||||
from quivr_core.quivr_rag import QuivrQARAG
|
||||
from quivr_core.quivr_rag_langgraph import QuivrQARAGLangGraph
|
||||
from quivr_core.rag.quivr_rag_langgraph import QuivrQARAGLangGraph
|
||||
from quivr_core.storage.local_storage import LocalStorage, TransparentStorage
|
||||
from quivr_core.storage.storage_base import StorageBase
|
||||
|
||||
@ -49,19 +49,15 @@ async def process_files(
|
||||
"""
|
||||
Process files in storage.
|
||||
This function takes a StorageBase and return a list of langchain documents.
|
||||
|
||||
Args:
|
||||
storage (StorageBase): The storage containing the files to process.
|
||||
skip_file_error (bool): Whether to skip files that cannot be processed.
|
||||
processor_kwargs (dict[str, Any]): Additional arguments for the processor.
|
||||
|
||||
Returns:
|
||||
list[Document]: List of processed documents in the Langchain Document format.
|
||||
|
||||
Raises:
|
||||
ValueError: If a file cannot be processed and skip_file_error is False.
|
||||
Exception: If no processor is found for a file of a specific type and skip_file_error is False.
|
||||
|
||||
"""
|
||||
|
||||
knowledge = []
|
||||
@ -91,23 +87,17 @@ async def process_files(
|
||||
class Brain:
|
||||
"""
|
||||
A class representing a Brain.
|
||||
|
||||
This class allows for the creation of a Brain, which is a collection of knowledge one wants to retrieve information from.
|
||||
|
||||
A Brain is set to:
|
||||
|
||||
* Store files in the storage of your choice (local, S3, etc.)
|
||||
* Process the files in the storage to extract text and metadata in a wide range of format.
|
||||
* Store the processed files in the vector store of your choice (FAISS, PGVector, etc.) - default to FAISS.
|
||||
* Create an index of the processed files.
|
||||
* Use the *Quivr* workflow for the retrieval augmented generation.
|
||||
|
||||
A Brain is able to:
|
||||
|
||||
* Search for information in the vector store.
|
||||
* Answer questions about the knowledges in the Brain.
|
||||
* Stream the answer to the question.
|
||||
|
||||
Attributes:
|
||||
name (str): The name of the brain.
|
||||
id (UUID): The unique identifier of the brain.
|
||||
@ -115,16 +105,14 @@ class Brain:
|
||||
llm (LLMEndpoint): The language model used to generate the answer.
|
||||
vector_db (VectorStore): The vector store used to store the processed files.
|
||||
embedder (Embeddings): The embeddings used to create the index of the processed files.
|
||||
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
name: str,
|
||||
id: UUID,
|
||||
llm: LLMEndpoint,
|
||||
id: UUID | None = None,
|
||||
vector_db: VectorStore | None = None,
|
||||
embedder: Embeddings | None = None,
|
||||
storage: StorageBase | None = None,
|
||||
@ -156,19 +144,15 @@ class Brain:
|
||||
def load(cls, folder_path: str | Path) -> Self:
|
||||
"""
|
||||
Load a brain from a folder path.
|
||||
|
||||
Args:
|
||||
folder_path (str | Path): The path to the folder containing the brain.
|
||||
|
||||
Returns:
|
||||
Brain: The brain loaded from the folder path.
|
||||
|
||||
Example:
|
||||
```python
|
||||
brain_loaded = Brain.load("path/to/brain")
|
||||
brain_loaded.print_info()
|
||||
```
|
||||
|
||||
"""
|
||||
if isinstance(folder_path, str):
|
||||
folder_path = Path(folder_path)
|
||||
@ -217,16 +201,13 @@ class Brain:
|
||||
vector_db=vector_db,
|
||||
)
|
||||
|
||||
async def save(self, folder_path: str | Path) -> str:
|
||||
async def save(self, folder_path: str | Path):
|
||||
"""
|
||||
Save the brain to a folder path.
|
||||
|
||||
Args:
|
||||
folder_path (str | Path): The path to the folder where the brain will be saved.
|
||||
|
||||
Returns:
|
||||
str: The path to the folder where the brain was saved.
|
||||
|
||||
Example:
|
||||
```python
|
||||
await brain.save("path/to/brain")
|
||||
@ -324,10 +305,9 @@ class Brain:
|
||||
embedder: Embeddings | None = None,
|
||||
skip_file_error: bool = False,
|
||||
processor_kwargs: dict[str, Any] | None = None,
|
||||
) -> Self:
|
||||
):
|
||||
"""
|
||||
Create a brain from a list of file paths.
|
||||
|
||||
Args:
|
||||
name (str): The name of the brain.
|
||||
file_paths (list[str | Path]): The list of file paths to add to the brain.
|
||||
@ -337,10 +317,8 @@ class Brain:
|
||||
embedder (Embeddings | None): The embeddings used to create the index of the processed files.
|
||||
skip_file_error (bool): Whether to skip files that cannot be processed.
|
||||
processor_kwargs (dict[str, Any] | None): Additional arguments for the processor.
|
||||
|
||||
Returns:
|
||||
Brain: The brain created from the file paths.
|
||||
|
||||
Example:
|
||||
```python
|
||||
brain = await Brain.afrom_files(name="My Brain", file_paths=["file1.pdf", "file2.pdf"])
|
||||
@ -429,7 +407,6 @@ class Brain:
|
||||
) -> Self:
|
||||
"""
|
||||
Create a brain from a list of langchain documents.
|
||||
|
||||
Args:
|
||||
name (str): The name of the brain.
|
||||
langchain_documents (list[Document]): The list of langchain documents to add to the brain.
|
||||
@ -437,10 +414,8 @@ class Brain:
|
||||
storage (StorageBase): The storage used to store the files.
|
||||
llm (LLMEndpoint | None): The language model used to generate the answer.
|
||||
embedder (Embeddings | None): The embeddings used to create the index of the processed files.
|
||||
|
||||
Returns:
|
||||
Brain: The brain created from the langchain documents.
|
||||
|
||||
Example:
|
||||
```python
|
||||
from langchain_core.documents import Document
|
||||
@ -449,6 +424,7 @@ class Brain:
|
||||
brain.print_info()
|
||||
```
|
||||
"""
|
||||
|
||||
if llm is None:
|
||||
llm = default_llm()
|
||||
|
||||
@ -481,16 +457,13 @@ class Brain:
|
||||
) -> list[SearchResult]:
|
||||
"""
|
||||
Search for relevant documents in the brain based on a query.
|
||||
|
||||
Args:
|
||||
query (str | Document): The query to search for.
|
||||
n_results (int): The number of results to return.
|
||||
filter (Callable | Dict[str, Any] | None): The filter to apply to the search.
|
||||
fetch_n_neighbors (int): The number of neighbors to fetch.
|
||||
|
||||
Returns:
|
||||
list[SearchResult]: The list of retrieved chunks.
|
||||
|
||||
Example:
|
||||
```python
|
||||
brain = Brain.from_files(name="My Brain", file_paths=["file1.pdf", "file2.pdf"])
|
||||
@ -517,57 +490,6 @@ class Brain:
|
||||
# add it to vectorstore
|
||||
raise NotImplementedError
|
||||
|
||||
def ask(
    self,
    question: str,
    retrieval_config: RetrievalConfig | None = None,
    rag_pipeline: Type[Union[QuivrQARAG, QuivrQARAGLangGraph]] | None = None,
    list_files: list[QuivrKnowledge] | None = None,
    chat_history: ChatHistory | None = None,
) -> ParsedRAGResponse:
    """
    Ask a question to the brain and get a generated answer.

    The streamed answer produced by ``ask_streaming`` is consumed to the end
    on an event loop, the ``answer`` of every chunk is concatenated, and the
    question/answer pair is appended to the chat history.

    Args:
        question (str): The question to ask.
        retrieval_config (RetrievalConfig | None): The retrieval configuration (see RetrievalConfig docs).
        rag_pipeline (Type[Union[QuivrQARAG, QuivrQARAGLangGraph]] | None): The RAG pipeline to use.
        list_files (list[QuivrKnowledge] | None): The list of files to include in the RAG pipeline.
        chat_history (ChatHistory | None): The chat history to use; ``self.default_chat`` is used when None.

    Returns:
        ParsedRAGResponse: The generated answer.

    Example:
    ```python
    brain = Brain.from_files(name="My Brain", file_paths=["file1.pdf", "file2.pdf"])
    answer = brain.ask("What is the meaning of life?")
    print(answer.answer)
    ```
    """

    async def collect_streamed_response() -> str:
        # Drain the stream and glue the chunk answers together.
        parts = []
        async for response in self.ask_streaming(
            question=question,
            retrieval_config=retrieval_config,
            rag_pipeline=rag_pipeline,
            list_files=list_files,
            chat_history=chat_history,
        ):
            parts.append(response.answer)
        return "".join(parts)

    # asyncio.get_event_loop() raises RuntimeError when the calling thread has
    # no current loop (worker threads, or after asyncio.run() cleared it), so
    # fall back to a fresh loop instead of failing.
    try:
        loop = asyncio.get_event_loop()
    except RuntimeError:
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
    full_answer = loop.run_until_complete(collect_streamed_response())

    # NOTE(review): ask_streaming appears to append the answer to the history
    # itself; confirm these two appends do not record the exchange twice.
    chat_history = self.default_chat if chat_history is None else chat_history
    chat_history.append(HumanMessage(content=question))
    chat_history.append(AIMessage(content=full_answer))

    return ParsedRAGResponse(answer=full_answer)
|
||||
|
||||
async def ask_streaming(
|
||||
self,
|
||||
question: str,
|
||||
@ -578,24 +500,20 @@ class Brain:
|
||||
) -> AsyncGenerator[ParsedRAGChunkResponse, ParsedRAGChunkResponse]:
|
||||
"""
|
||||
Ask a question to the brain and get a streamed generated answer.
|
||||
|
||||
Args:
|
||||
question (str): The question to ask.
|
||||
retrieval_config (RetrievalConfig | None): The retrieval configuration (see RetrievalConfig docs).
|
||||
rag_pipeline (Type[Union[QuivrQARAG, QuivrQARAGLangGraph]] | None): The RAG pipeline to use.
|
||||
list_files (list[QuivrKnowledge] | None): The list of files to include in the RAG pipeline.
|
||||
list_files (list[QuivrKnowledge] | None): The list of files to include in the RAG pipeline.
|
||||
chat_history (ChatHistory | None): The chat history to use.
|
||||
|
||||
Returns:
|
||||
AsyncGenerator[ParsedRAGChunkResponse, ParsedRAGChunkResponse]: The streamed generated answer.
|
||||
|
||||
Example:
|
||||
```python
|
||||
brain = Brain.from_files(name="My Brain", file_paths=["file1.pdf", "file2.pdf"])
|
||||
async for chunk in brain.ask_streaming("What is the meaning of life?"):
|
||||
print(chunk.answer)
|
||||
```
|
||||
|
||||
"""
|
||||
llm = self.llm
|
||||
|
||||
@ -630,3 +548,64 @@ class Brain:
|
||||
chat_history.append(AIMessage(content=full_answer))
|
||||
yield response
|
||||
|
||||
async def aask(
    self,
    question: str,
    retrieval_config: RetrievalConfig | None = None,
    rag_pipeline: Type[Union[QuivrQARAG, QuivrQARAGLangGraph]] | None = None,
    list_files: list[QuivrKnowledge] | None = None,
    chat_history: ChatHistory | None = None,
) -> ParsedRAGResponse:
    """
    Asynchronous version that asks a question to the brain and gets a generated answer.

    This coroutine consumes ``ask_streaming`` to the end and concatenates the
    ``answer`` of every streamed chunk into a single response.

    Args:
        question (str): The question to ask.
        retrieval_config (RetrievalConfig | None): The retrieval configuration (see RetrievalConfig docs).
        rag_pipeline (Type[Union[QuivrQARAG, QuivrQARAGLangGraph]] | None): The RAG pipeline to use.
        list_files (list[QuivrKnowledge] | None): The list of files to include in the RAG pipeline.
        chat_history (ChatHistory | None): The chat history to use.

    Returns:
        ParsedRAGResponse: The generated answer.
    """
    answer_parts = []

    async for response in self.ask_streaming(
        question=question,
        retrieval_config=retrieval_config,
        rag_pipeline=rag_pipeline,
        list_files=list_files,
        chat_history=chat_history,
    ):
        answer_parts.append(response.answer)

    # Join once at the end rather than growing a string chunk by chunk.
    return ParsedRAGResponse(answer="".join(answer_parts))
|
||||
|
||||
def ask(
    self,
    question: str,
    retrieval_config: RetrievalConfig | None = None,
    rag_pipeline: Type[Union[QuivrQARAG, QuivrQARAGLangGraph]] | None = None,
    list_files: list[QuivrKnowledge] | None = None,
    chat_history: ChatHistory | None = None,
) -> ParsedRAGResponse:
    """
    Fully synchronous version that asks a question to the brain and gets a generated answer.

    The call blocks until the ``aask`` coroutine has completed on the calling
    thread's event loop. It must not be called from code that is already
    running inside that loop; use ``await brain.aask(...)`` there instead.

    Args:
        question (str): The question to ask.
        retrieval_config (RetrievalConfig | None): The retrieval configuration (see RetrievalConfig docs).
        rag_pipeline (Type[Union[QuivrQARAG, QuivrQARAGLangGraph]] | None): The RAG pipeline to use.
        list_files (list[QuivrKnowledge] | None): The list of files to include in the RAG pipeline.
        chat_history (ChatHistory | None): The chat history to use.

    Returns:
        ParsedRAGResponse: The generated answer.
    """
    # asyncio.get_event_loop() raises RuntimeError when the calling thread has
    # no current loop (worker threads, or after asyncio.run() cleared it), so
    # create and register a fresh loop in that case. The loop stays registered
    # so later calls from the same thread reuse it.
    try:
        loop = asyncio.get_event_loop()
    except RuntimeError:
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)

    return loop.run_until_complete(
        self.aask(
            question=question,
            retrieval_config=retrieval_config,
            rag_pipeline=rag_pipeline,
            list_files=list_files,
            chat_history=chat_history,
        )
    )
|
||||
|
@ -4,7 +4,7 @@ from langchain_core.documents import Document
|
||||
from langchain_core.embeddings import Embeddings
|
||||
from langchain_core.vectorstores import VectorStore
|
||||
|
||||
from quivr_core.config import LLMEndpointConfig
|
||||
from quivr_core.rag.entities.config import DefaultModelSuppliers, LLMEndpointConfig
|
||||
from quivr_core.llm import LLMEndpoint
|
||||
|
||||
logger = logging.getLogger("quivr_core")
|
||||
@ -46,7 +46,9 @@ def default_embedder() -> Embeddings:
|
||||
def default_llm() -> LLMEndpoint:
|
||||
try:
|
||||
logger.debug("Loaded ChatOpenAI as default LLM for brain")
|
||||
llm = LLMEndpoint.from_config(LLMEndpointConfig())
|
||||
llm = LLMEndpoint.from_config(
|
||||
LLMEndpointConfig(supplier=DefaultModelSuppliers.OPENAI, model="gpt-4o")
|
||||
)
|
||||
return llm
|
||||
|
||||
except ImportError as e:
|
||||
|
@ -4,9 +4,9 @@ from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel, Field, SecretStr
|
||||
|
||||
from quivr_core.config import LLMEndpointConfig
|
||||
from quivr_core.rag.entities.config import LLMEndpointConfig
|
||||
from quivr_core.rag.entities.models import ChatMessage
|
||||
from quivr_core.files.file import QuivrFileSerialized
|
||||
from quivr_core.models import ChatMessage
|
||||
|
||||
|
||||
class EmbedderConfig(BaseModel):
|
||||
|
@ -1,15 +1,8 @@
|
||||
import os
|
||||
from enum import Enum
|
||||
from typing import Dict, List, Optional
|
||||
from uuid import UUID
|
||||
|
||||
import yaml
|
||||
from pydantic import BaseModel
|
||||
|
||||
from quivr_core.base_config import QuivrBaseConfig
|
||||
from quivr_core.processor.splitter import SplitterConfig
|
||||
from quivr_core.prompts import CustomPromptsModel
|
||||
|
||||
|
||||
class PdfParser(str, Enum):
|
||||
LLAMA_PARSE = "llama_parse"
|
||||
@ -32,489 +25,3 @@ class MegaparseConfig(MegaparseBaseConfig):
|
||||
strategy: str = "fast"
|
||||
llama_parse_api_key: str | None = None
|
||||
pdf_parser: PdfParser = PdfParser.UNSTRUCTURED
|
||||
|
||||
|
||||
class BrainConfig(QuivrBaseConfig):
    """
    Configuration of a brain: an optional identifier and a mandatory name.

    Attributes:
        brain_id (UUID | None): Identifier of the brain, None when not assigned.
        name (str): Name of the brain.
    """

    brain_id: UUID | None = None
    name: str

    @property
    def id(self) -> UUID | None:
        """Read-only alias for ``brain_id``."""
        return self.brain_id
|
||||
|
||||
|
||||
class DefaultRerankers(str, Enum):
    """
    API-based reranker suppliers known to the application.

    Every member is also a plain string (the supplier identifier) and exposes
    the model that is used when no reranker model is configured explicitly.

    Attributes:
        COHERE (str): Cohere as a reranker supplier.
        JINA (str): Jina AI as a reranker supplier.

    Methods:
        default_model (property): Default model name of the supplier.
    """

    COHERE = "cohere"
    JINA = "jina"

    @property
    def default_model(self) -> str:
        """
        Return the default reranker model of this supplier.

        Returns:
            str: The default model name.

        Raises:
            KeyError: If this member has no entry in the supplier-to-model table.
        """
        models_by_supplier = {
            DefaultRerankers.COHERE: "rerank-multilingual-v3.0",
            DefaultRerankers.JINA: "jina-reranker-v2-base-multilingual",
        }
        return models_by_supplier[self]
|
||||
|
||||
|
||||
class DefaultModelSuppliers(str, Enum):
    """
    LLM suppliers known to the application.

    Every member is also a plain string holding the supplier identifier.

    Attributes:
        OPENAI (str): OpenAI.
        AZURE (str): Azure (Microsoft).
        ANTHROPIC (str): Anthropic.
        META (str): Meta.
        MISTRAL (str): Mistral AI.
        GROQ (str): Groq.
    """

    OPENAI = "openai"
    AZURE = "azure"
    ANTHROPIC = "anthropic"
    META = "meta"
    MISTRAL = "mistral"
    GROQ = "groq"
|
||||
|
||||
|
||||
class LLMConfig(QuivrBaseConfig):
    """
    Per-model defaults: context size and tokenizer identifier.

    Attributes:
        context (int | None): Context size of the model, None when unknown.
        tokenizer_hub (str | None): Tokenizer identifier for the model, None when unknown.
    """

    context: int | None = None
    tokenizer_hub: str | None = None
|
||||
|
||||
|
||||
class LLMModelConfig:
    """
    Lookup table of known models per supplier, with prefix-based lookups.

    Model names are matched with ``str.startswith`` against the table keys in
    declaration order, so a more specific key must be listed before any
    shorter key that is a prefix of it (e.g. "gpt-4-turbo" before "gpt-4").
    """

    _model_defaults: Dict[DefaultModelSuppliers, Dict[str, LLMConfig]] = {
        DefaultModelSuppliers.OPENAI: {
            "gpt-4o": LLMConfig(context=128000, tokenizer_hub="Xenova/gpt-4o"),
            "gpt-4o-mini": LLMConfig(context=128000, tokenizer_hub="Xenova/gpt-4o"),
            "gpt-4-turbo": LLMConfig(context=128000, tokenizer_hub="Xenova/gpt-4"),
            "gpt-4": LLMConfig(context=8192, tokenizer_hub="Xenova/gpt-4"),
            "gpt-3.5-turbo": LLMConfig(
                context=16385, tokenizer_hub="Xenova/gpt-3.5-turbo"
            ),
            "text-embedding-3-large": LLMConfig(
                context=8191, tokenizer_hub="Xenova/text-embedding-ada-002"
            ),
            "text-embedding-3-small": LLMConfig(
                context=8191, tokenizer_hub="Xenova/text-embedding-ada-002"
            ),
            "text-embedding-ada-002": LLMConfig(
                context=8191, tokenizer_hub="Xenova/text-embedding-ada-002"
            ),
        },
        DefaultModelSuppliers.ANTHROPIC: {
            "claude-3-5-sonnet": LLMConfig(
                context=200000, tokenizer_hub="Xenova/claude-tokenizer"
            ),
            "claude-3-opus": LLMConfig(
                context=200000, tokenizer_hub="Xenova/claude-tokenizer"
            ),
            "claude-3-sonnet": LLMConfig(
                context=200000, tokenizer_hub="Xenova/claude-tokenizer"
            ),
            "claude-3-haiku": LLMConfig(
                context=200000, tokenizer_hub="Xenova/claude-tokenizer"
            ),
            "claude-2-1": LLMConfig(
                context=200000, tokenizer_hub="Xenova/claude-tokenizer"
            ),
            "claude-2-0": LLMConfig(
                context=100000, tokenizer_hub="Xenova/claude-tokenizer"
            ),
            "claude-instant-1-2": LLMConfig(
                context=100000, tokenizer_hub="Xenova/claude-tokenizer"
            ),
        },
        DefaultModelSuppliers.META: {
            "llama-3.1": LLMConfig(
                context=128000, tokenizer_hub="Xenova/Meta-Llama-3.1-Tokenizer"
            ),
            "llama-3": LLMConfig(
                context=8192, tokenizer_hub="Xenova/llama3-tokenizer-new"
            ),
            "llama-2": LLMConfig(context=4096, tokenizer_hub="Xenova/llama2-tokenizer"),
            "code-llama": LLMConfig(
                context=16384, tokenizer_hub="Xenova/llama-code-tokenizer"
            ),
        },
        DefaultModelSuppliers.GROQ: {
            "llama-3.1": LLMConfig(
                context=128000, tokenizer_hub="Xenova/Meta-Llama-3.1-Tokenizer"
            ),
            "llama-3": LLMConfig(
                context=8192, tokenizer_hub="Xenova/llama3-tokenizer-new"
            ),
            "llama-2": LLMConfig(context=4096, tokenizer_hub="Xenova/llama2-tokenizer"),
            "code-llama": LLMConfig(
                context=16384, tokenizer_hub="Xenova/llama-code-tokenizer"
            ),
        },
        DefaultModelSuppliers.MISTRAL: {
            "mistral-large": LLMConfig(
                context=128000, tokenizer_hub="Xenova/mistral-tokenizer-v3"
            ),
            "mistral-small": LLMConfig(
                context=128000, tokenizer_hub="Xenova/mistral-tokenizer-v3"
            ),
            "mistral-nemo": LLMConfig(
                context=128000, tokenizer_hub="Xenova/Mistral-Nemo-Instruct-Tokenizer"
            ),
            "codestral": LLMConfig(
                context=32000, tokenizer_hub="Xenova/mistral-tokenizer-v3"
            ),
        },
    }

    @classmethod
    def get_supplier_by_model_name(cls, model: str) -> DefaultModelSuppliers | None:
        """
        Find the first supplier that lists a model whose name prefixes ``model``.

        Suppliers are scanned in table order, so a name shared by several
        suppliers resolves to the one declared first.

        Args:
            model (str): Full model name to look up.

        Returns:
            DefaultModelSuppliers | None: The matching supplier, or None when
            no table key is a prefix of ``model``.
        """
        return next(
            (
                supplier
                for supplier, known_models in cls._model_defaults.items()
                if any(model.startswith(prefix) for prefix in known_models)
            ),
            None,
        )

    @classmethod
    def get_llm_model_config(
        cls, supplier: DefaultModelSuppliers, model_name: str
    ) -> Optional[LLMConfig]:
        """
        Retrieve the LLMConfig (context and tokenizer_hub) for a supplier and model.

        Args:
            supplier (DefaultModelSuppliers): Supplier whose table is searched.
            model_name (str): Full model name; matched by prefix against the table keys.

        Returns:
            Optional[LLMConfig]: The first entry whose key prefixes
            ``model_name``, or None when the supplier has no table or no key matches.
        """
        known_models = cls._model_defaults.get(supplier)
        if not known_models:
            return None

        return next(
            (
                config
                for prefix, config in known_models.items()
                if model_name.startswith(prefix)
            ),
            None,
        )
|
||||
|
||||
|
||||
class LLMEndpointConfig(QuivrBaseConfig):
    """
    Configuration class for Large Language Models (LLM) endpoints.

    This class defines the settings and parameters for interacting with various LLM providers.
    It includes configuration for the model, API keys, token limits, and other relevant settings.

    Attributes:
        supplier (DefaultModelSuppliers): The LLM provider (default: OPENAI).
        model (str): The specific model to use (default: "gpt-4o").
        context_length (int | None): The maximum context length for the model.
        tokenizer_hub (str | None): The tokenizer to use for this model.
        llm_base_url (str | None): Base URL for the LLM API.
        env_variable_name (str): Name of the environment variable for the API key.
            When not given explicitly it is derived from the supplier as
            ``"<SUPPLIER>_API_KEY"``.
        llm_api_key (str | None): The API key for the LLM provider.
        max_input_tokens (int): Maximum number of input tokens sent to the LLM (default: 2000).
        max_output_tokens (int): Maximum number of output tokens returned by the LLM (default: 2000).
        temperature (float): Temperature setting for text generation (default: 0.7).
        streaming (bool): Whether to use streaming for responses (default: True).
        prompt (CustomPromptsModel | None): Custom prompt configuration.
    """

    supplier: DefaultModelSuppliers = DefaultModelSuppliers.OPENAI
    model: str = "gpt-4o"
    context_length: int | None = None
    tokenizer_hub: str | None = None
    llm_base_url: str | None = None
    # This f-string is evaluated once, in the class body, against the default
    # supplier above, so the class-level default is always "OPENAI_API_KEY".
    # __init__ re-derives the name from the actual supplier.
    env_variable_name: str = f"{supplier.upper()}_API_KEY"
    llm_api_key: str | None = None
    max_input_tokens: int = 2000
    max_output_tokens: int = 2000
    temperature: float = 0.7
    streaming: bool = True
    prompt: CustomPromptsModel | None = None

    _FALLBACK_TOKENIZER = "cl100k_base"

    @property
    def fallback_tokenizer(self) -> str:
        """
        Get the fallback tokenizer.

        Returns:
            str: The name of the fallback tokenizer.
        """
        return self._FALLBACK_TOKENIZER

    def __init__(self, **data):
        """
        Initialize the LLMEndpointConfig.

        After field validation, the API-key environment variable name is
        derived from the supplier (unless the caller passed
        ``env_variable_name``), then the model defaults and the API key are set.

        Raises:
            ValueError: If no API key was given and none is found in the environment.
        """
        super().__init__(**data)
        if "env_variable_name" not in data:
            # Follow the supplier actually configured, not the class default.
            self.env_variable_name = f"{self.supplier.upper()}_API_KEY"
        self.set_llm_model_config()
        self.set_api_key()

    def set_api_key(self, force_reset: bool = False):
        """
        Set the API key for the LLM provider.

        When no key is set yet (or ``force_reset`` is True) the key is read
        from the environment variable named by ``env_variable_name``.

        Args:
            force_reset (bool): If True, forces a reset of the API key even if already set.

        Raises:
            ValueError: If the API key is not set in the environment.
        """
        if not self.llm_api_key or force_reset:
            self.llm_api_key = os.getenv(self.env_variable_name)

        if not self.llm_api_key:
            raise ValueError(
                f"The API key for supplier '{self.supplier}' is not set. "
                f"Please set the environment variable: {self.env_variable_name}"
            )

    def set_llm_model_config(self):
        """
        Set the LLM model configuration.

        When the current supplier/model pair is known to ``LLMModelConfig``,
        ``context_length`` and ``tokenizer_hub`` are overwritten with its
        values; otherwise both are left untouched.
        """
        llm_model_config = LLMModelConfig.get_llm_model_config(
            self.supplier, self.model
        )
        if llm_model_config:
            self.context_length = llm_model_config.context
            self.tokenizer_hub = llm_model_config.tokenizer_hub

    def set_llm_model(self, model: str):
        """
        Set the LLM model and update related configurations.

        The supplier is looked up from the model name. When it differs from
        the current supplier, ``env_variable_name`` is re-derived for the new
        supplier so the key is read from the matching variable. The model
        defaults are then refreshed and the API key is re-read from the
        environment.

        Args:
            model (str): The name of the model to set.

        Raises:
            ValueError: If no corresponding supplier is found for the given
                model, or if the API key for the supplier is not set.
        """
        supplier = LLMModelConfig.get_supplier_by_model_name(model)
        if supplier is None:
            raise ValueError(
                f"Cannot find the corresponding supplier for model {model}"
            )

        supplier_changed = supplier != self.supplier
        self.supplier = supplier
        self.model = model
        if supplier_changed:
            # The previous variable name belonged to the previous supplier.
            self.env_variable_name = f"{supplier.upper()}_API_KEY"

        self.set_llm_model_config()
        self.set_api_key(force_reset=True)

    def set_from_sqlmodel(self, sqlmodel: BaseModel, mapping: Dict[str, str]):
        """
        Set attributes in LLMEndpointConfig from SQLModel attributes using a field mapping.

        Entries are applied in mapping order; entries processed before an
        invalid one have already been copied when the error is raised.

        Args:
            sqlmodel (SQLModel): An instance of the SQLModel class.
            mapping (Dict[str, str]): A dictionary that maps SQLModel fields to LLMEndpointConfig fields.
                Example: {"max_input": "max_input_tokens", "env_variable_name": "env_variable_name"}

        Raises:
            AttributeError: If any field in the mapping doesn't exist in either the SQLModel or LLMEndpointConfig.
        """
        for model_field, llm_field in mapping.items():
            if hasattr(sqlmodel, model_field) and hasattr(self, llm_field):
                setattr(self, llm_field, getattr(sqlmodel, model_field))
            else:
                raise AttributeError(
                    f"Invalid mapping: {model_field} or {llm_field} does not exist."
                )
|
||||
|
||||
|
||||
# Cannot use Pydantic v2 field_validator because of conflicts with pydantic v1 still in use in LangChain
class RerankerConfig(QuivrBaseConfig):
    """
    Configuration class for reranker models.

    This class defines the settings for reranker models used in the application,
    including the supplier, model, and API key information.

    Attributes:
        supplier (DefaultRerankers | None): The reranker supplier (e.g., COHERE).
        model (str | None): The specific reranker model to use.
        top_n (int): The number of top chunks returned by the reranker (default: 5).
        api_key (str | None): The API key for the reranker service.
    """

    supplier: DefaultRerankers | None = None
    model: str | None = None
    top_n: int = 5
    api_key: str | None = None

    def __init__(self, **data):
        """
        Initialize the RerankerConfig and validate it right away.

        Raises:
            ValueError: If a supplier is set but no API key is available.
        """
        super().__init__(**data)
        self.validate_model()

    def validate_model(self):
        """
        Validate and set up the reranker model configuration.

        Without a supplier nothing is done. With a supplier, a missing model
        falls back to the supplier's default model, and a missing API key is
        read from the ``<SUPPLIER>_API_KEY`` environment variable. An API key
        passed explicitly is kept as is.

        Raises:
            ValueError: If a supplier is set and no API key was given or
                found in the environment.
        """
        if not self.supplier:
            return

        if self.model is None:
            self.model = self.supplier.default_model

        api_key_var = f"{self.supplier.upper()}_API_KEY"
        # Only consult the environment when the caller supplied no key, so an
        # explicit api_key is not silently replaced or rejected.
        if self.api_key is None:
            self.api_key = os.getenv(api_key_var)

        if self.api_key is None:
            raise ValueError(
                f"The API key for supplier '{self.supplier}' is not set. "
                f"Please set the environment variable: {api_key_var}"
            )
|
||||
|
||||
|
||||
class NodeConfig(QuivrBaseConfig):
    """
    Configuration class for a node in an AI assistant workflow.

    This class represents a single node in a workflow configuration,
    defining its name and connections to other nodes.

    Attributes:
        name (str): The name of the node. Required.
        edges (List[str]): Names of the nodes this node links to. Required;
            there is no default, so a node without outgoing edges must pass
            an empty list explicitly.
    """

    name: str
    edges: List[str]
|
||||
|
||||
|
||||
class WorkflowConfig(QuivrBaseConfig):
    """
    Configuration class for an AI assistant workflow.

    This class represents the entire workflow configuration,
    consisting of multiple interconnected nodes.

    Attributes:
        name (str): The name of the workflow. Required.
        nodes (List[NodeConfig]): The nodes of the workflow; each node lists
            its own outgoing edges. Required.
    """

    name: str
    nodes: List[NodeConfig]
|
||||
|
||||
|
||||
class RetrievalConfig(QuivrBaseConfig):
    """
    Configuration class for the retrieval phase of a RAG assistant.

    This class defines the settings for the retrieval process,
    including reranker and LLM configurations, as well as various limits and prompts.

    Attributes:
        workflow_config (WorkflowConfig | None): Configuration for the workflow
            (default: None).
        reranker_config (RerankerConfig): Configuration for the reranker. The
            default is a RerankerConfig built with no arguments, i.e. without
            a supplier.
        llm_config (LLMEndpointConfig): Configuration for the LLM endpoint. The
            default is an LLMEndpointConfig built with no arguments.
        max_history (int): Maximum number of past conversation turns to pass to the LLM as context (default: 10).
        max_files (int): Maximum number of files to process (default: 20).
        prompt (str | None): Custom prompt for the retrieval process.
    """

    workflow_config: WorkflowConfig | None = None
    # NOTE: the two default instances below are created once, when this class
    # body is executed.
    reranker_config: RerankerConfig = RerankerConfig()
    llm_config: LLMEndpointConfig = LLMEndpointConfig()
    max_history: int = 10
    max_files: int = 20
    prompt: str | None = None
|
||||
|
||||
|
||||
class ParserConfig(QuivrBaseConfig):
    """
    Configuration class for the parser.

    This class defines the settings for the parsing process,
    including configurations for the text splitter and Megaparse.

    Attributes:
        splitter_config (SplitterConfig): Configuration for the text splitter.
            Defaults to a SplitterConfig built with no arguments.
        megaparse_config (MegaparseConfig): Configuration for Megaparse.
            Defaults to a MegaparseConfig built with no arguments.
    """

    # NOTE: both default instances are created once, when this class body runs.
    splitter_config: SplitterConfig = SplitterConfig()
    megaparse_config: MegaparseConfig = MegaparseConfig()
|
||||
|
||||
|
||||
class IngestionConfig(QuivrBaseConfig):
    """
    Configuration class for the data ingestion process.

    This class defines the settings for the data ingestion process,
    including the parser configuration.

    Attributes:
        parser_config (ParserConfig): Configuration for the parser. Defaults
            to a ParserConfig built with no arguments.
    """

    # NOTE: the default instance is created once, when this class body runs.
    parser_config: ParserConfig = ParserConfig()
|
||||
|
||||
|
||||
class AssistantConfig(QuivrBaseConfig):
    """
    Configuration class for an AI assistant.

    This class defines the overall configuration for an AI assistant,
    including settings for retrieval and ingestion processes.

    Attributes:
        retrieval_config (RetrievalConfig): Configuration for the retrieval
            process. Defaults to a RetrievalConfig built with no arguments.
        ingestion_config (IngestionConfig): Configuration for the ingestion
            process. Defaults to an IngestionConfig built with no arguments.
    """

    # NOTE: both default instances are created once, when this class body runs.
    retrieval_config: RetrievalConfig = RetrievalConfig()
    ingestion_config: IngestionConfig = IngestionConfig()
|
||||
|
@ -10,8 +10,8 @@ from langchain_openai import AzureChatOpenAI, ChatOpenAI
|
||||
from pydantic.v1 import SecretStr
|
||||
|
||||
from quivr_core.brain.info import LLMInfo
|
||||
from quivr_core.config import DefaultModelSuppliers, LLMEndpointConfig
|
||||
from quivr_core.utils import model_supports_function_calling
|
||||
from quivr_core.rag.entities.config import DefaultModelSuppliers, LLMEndpointConfig
|
||||
from quivr_core.rag.utils import model_supports_function_calling
|
||||
|
||||
logger = logging.getLogger("quivr_core")
|
||||
|
||||
@ -70,6 +70,7 @@ class LLMEndpoint:
|
||||
else None,
|
||||
azure_endpoint=azure_endpoint,
|
||||
max_tokens=config.max_output_tokens,
|
||||
temperature=config.temperature,
|
||||
)
|
||||
elif config.supplier == DefaultModelSuppliers.ANTHROPIC:
|
||||
_llm = ChatAnthropic(
|
||||
@ -79,6 +80,7 @@ class LLMEndpoint:
|
||||
else None,
|
||||
base_url=config.llm_base_url,
|
||||
max_tokens=config.max_output_tokens,
|
||||
temperature=config.temperature,
|
||||
)
|
||||
elif config.supplier == DefaultModelSuppliers.OPENAI:
|
||||
_llm = ChatOpenAI(
|
||||
@ -88,6 +90,7 @@ class LLMEndpoint:
|
||||
else None,
|
||||
base_url=config.llm_base_url,
|
||||
max_tokens=config.max_output_tokens,
|
||||
temperature=config.temperature,
|
||||
)
|
||||
else:
|
||||
_llm = ChatOpenAI(
|
||||
@ -97,6 +100,7 @@ class LLMEndpoint:
|
||||
else None,
|
||||
base_url=config.llm_base_url,
|
||||
max_tokens=config.max_output_tokens,
|
||||
temperature=config.temperature,
|
||||
)
|
||||
return cls(llm=_llm, llm_config=config)
|
||||
|
||||
@ -118,3 +122,7 @@ class LLMEndpoint:
|
||||
max_tokens=self._config.max_output_tokens,
|
||||
supports_function_calling=self.supports_func_calling(),
|
||||
)
|
||||
|
||||
def clone_llm(self):
    """Build a second LLM object of the same class from the current one's attributes.

    The wrapped LLM's instance dictionary is passed as keyword arguments to
    its own class, so the attribute values themselves are shared with the
    original rather than deep-copied.
    """
    llm_class = type(self._llm)
    constructor_kwargs = vars(self._llm)
    return llm_class(**constructor_kwargs)
|
||||
|
0
core/quivr_core/llm_tools/__init__.py
Normal file
0
core/quivr_core/llm_tools/__init__.py
Normal file
36
core/quivr_core/llm_tools/entity.py
Normal file
36
core/quivr_core/llm_tools/entity.py
Normal file
@ -0,0 +1,36 @@
|
||||
from quivr_core.base_config import QuivrBaseConfig
|
||||
from typing import Callable
|
||||
from langchain_core.tools import BaseTool
|
||||
from typing import Dict, Any
|
||||
|
||||
|
||||
class ToolsCategory(QuivrBaseConfig):
    """Describe a named group of LLM tools together with a factory for them.

    The category name is normalised to lower case right after the model is
    constructed, so "Web Search" is stored as "web search".
    """

    name: str  # category name; stored lower-cased (see __init__)
    description: str  # free-text description of the category
    tools: list  # identifiers of the tools that belong to this category
    default_tool: str | None = None  # optional tool identifier; None when the category has no default
    create_tool: Callable  # factory callable; its expected signature is not enforced here

    def __init__(self, **data):
        """Build the category from keyword data, then lower-case ``name``."""
        super().__init__(**data)
        self.name = self.name.lower()
|
||||
|
||||
|
||||
class ToolWrapper:
    """Bundle a tool with the callables that adapt its input and output."""

    def __init__(self, tool: BaseTool, format_input: Callable, format_output: Callable):
        """Store the tool and its two formatting callables unchanged.

        Args:
            tool: The wrapped tool.
            format_input: Callable stored as ``self.format_input``.
            format_output: Callable stored as ``self.format_output``.
        """
        self.tool = tool
        self.format_input = format_input
        self.format_output = format_output
|
||||
|
||||
|
||||
class ToolRegistry:
    """Map tool names to the factory callables that build them."""

    def __init__(self):
        """Start with no registered tools."""
        self._registry = {}

    def register_tool(self, tool_name: str, create_func: Callable):
        """Associate ``tool_name`` with ``create_func``, replacing any previous entry."""
        self._registry[tool_name] = create_func

    def create_tool(self, tool_name: str, config: Dict[str, Any]) -> ToolWrapper:
        """Build the tool registered under ``tool_name``.

        Args:
            tool_name: Name the factory was registered with.
            config: Passed as the only argument to the factory.

        Returns:
            Whatever the registered factory returns.

        Raises:
            ValueError: If no factory is registered under ``tool_name``.
        """
        if tool_name in self._registry:
            factory = self._registry[tool_name]
            return factory(config)
        raise ValueError(f"Tool {tool_name} is not supported.")
|
33
core/quivr_core/llm_tools/llm_tools.py
Normal file
33
core/quivr_core/llm_tools/llm_tools.py
Normal file
@ -0,0 +1,33 @@
|
||||
from typing import Dict, Any, Type, Union
|
||||
|
||||
from quivr_core.llm_tools.entity import ToolWrapper
|
||||
|
||||
from quivr_core.llm_tools.web_search_tools import (
|
||||
WebSearchTools,
|
||||
)
|
||||
|
||||
from quivr_core.llm_tools.other_tools import (
|
||||
OtherTools,
|
||||
)
|
||||
|
||||
# Every known tool category, keyed by its (lower-cased) category name.
TOOLS_CATEGORIES = {
    tools_category.name: tools_category
    for tools_category in (WebSearchTools, OtherTools)
}

# Every individual tool enum member, keyed by its string value.
TOOLS_LISTS = {
    tool.value: tool
    for tools_category in (WebSearchTools, OtherTools)
    for tool in tools_category.tools
}
|
||||
|
||||
|
||||
class LLMToolFactory:
    """Resolve a tool name or a category name to a concrete tool."""

    @staticmethod
    def create_tool(tool_name: str, config: Dict[str, Any]) -> Union[ToolWrapper, Type]:
        """Create the tool designated by ``tool_name``.

        Categories are scanned in the order of ``TOOLS_CATEGORIES``. For each
        one, ``tool_name`` matches either when it is one of the category's
        tools, or when its lower-cased form equals the category key and the
        category declares a default tool (which is then the tool created).

        Args:
            tool_name: A tool identifier or a category name.
            config: Configuration handed to the category's factory.

        Returns:
            Whatever the matching category's ``create_tool`` returns.

        Raises:
            ValueError: If no category matches ``tool_name``.
        """
        for category_name, tools_category in TOOLS_CATEGORIES.items():
            if tool_name in tools_category.tools:
                selected_tool = tool_name
            elif tool_name.lower() == category_name and tools_category.default_tool:
                selected_tool = tools_category.default_tool
            else:
                continue
            return tools_category.create_tool(selected_tool, config)
        raise ValueError(f"Tool {tool_name} is not supported.")
|
24
core/quivr_core/llm_tools/other_tools.py
Normal file
24
core/quivr_core/llm_tools/other_tools.py
Normal file
@ -0,0 +1,24 @@
|
||||
from enum import Enum
|
||||
from typing import Dict, Any, Type, Union
|
||||
from langchain_core.tools import BaseTool
|
||||
from quivr_core.llm_tools.entity import ToolsCategory
|
||||
from quivr_core.rag.entities.models import cited_answer
|
||||
|
||||
|
||||
class OtherToolsList(str, Enum):
    """Identifiers of the tools in the "other" category (string-valued enum)."""

    CITED_ANSWER = "cited_answer"
|
||||
|
||||
|
||||
def create_other_tool(tool_name: str, config: Dict[str, Any]) -> Union[BaseTool, Type]:
    """Return the tool object for a tool of the "other" category.

    Args:
        tool_name: Identifier of the requested tool.
        config: Accepted for signature compatibility; not used here.

    Returns:
        ``cited_answer`` when ``tool_name`` equals ``OtherToolsList.CITED_ANSWER``.

    Raises:
        ValueError: For any other ``tool_name``.
    """
    if tool_name != OtherToolsList.CITED_ANSWER:
        raise ValueError(f"Tool {tool_name} is not supported.")
    return cited_answer
|
||||
|
||||
|
||||
# Category that groups the tools not covered by a more specific category.
# No default_tool is given, so the category has no default.
OtherTools = ToolsCategory(
    name="Other",
    description="Other tools",
    tools=[OtherToolsList.CITED_ANSWER],
    create_tool=create_other_tool,
)
|
73
core/quivr_core/llm_tools/web_search_tools.py
Normal file
73
core/quivr_core/llm_tools/web_search_tools.py
Normal file
@ -0,0 +1,73 @@
|
||||
from enum import Enum
|
||||
from typing import Dict, List, Any
|
||||
from langchain_community.tools import TavilySearchResults
|
||||
from langchain_community.utilities.tavily_search import TavilySearchAPIWrapper
|
||||
from quivr_core.llm_tools.entity import ToolsCategory
|
||||
import os
|
||||
from pydantic.v1 import SecretStr as SecretStrV1 # Ensure correct import
|
||||
from quivr_core.llm_tools.entity import ToolWrapper, ToolRegistry
|
||||
from langchain_core.documents import Document
|
||||
|
||||
|
||||
class WebSearchToolsList(str, Enum):
    """Identifiers of the available web search tools (string-valued enum)."""

    TAVILY = "tavily"
|
||||
|
||||
|
||||
def create_tavily_tool(config: Dict[str, Any]) -> ToolWrapper:
    """Build a Tavily web-search tool wrapped with input/output formatters.

    Args:
        config: Tool options. Recognised keys are ``api_key``, ``max_results``
            (default 5), ``search_depth`` (default "advanced") and
            ``include_answer`` (default True); any remaining keys are passed
            through to ``TavilySearchResults``. The dict is not modified.

    Returns:
        A ToolWrapper holding the tool, a formatter that turns a task string
        into ``{"query": task}``, and a formatter that turns the tool's
        response items into Documents.

    Raises:
        ValueError: If no API key is found in ``config`` or in the
            ``TAVILY_API_KEY`` environment variable.
    """
    # Work on a private copy: the pops below must not strip keys from the
    # caller's dict, which may be reused to build further tools.
    options = dict(config or {})

    # An explicit key wins; a missing or empty one falls back to the environment.
    api_key = options.pop("api_key", None) or os.getenv("TAVILY_API_KEY")
    if not api_key:
        raise ValueError(
            "Missing required config key 'api_key' or environment variable 'TAVILY_API_KEY'"
        )

    tavily_api_wrapper = TavilySearchAPIWrapper(
        tavily_api_key=SecretStrV1(api_key),
    )
    tool = TavilySearchResults(
        api_wrapper=tavily_api_wrapper,
        max_results=options.pop("max_results", 5),
        search_depth=options.pop("search_depth", "advanced"),
        include_answer=options.pop("include_answer", True),
        **options,
    )

    # Expose the tool under the same name as its enum identifier.
    tool.name = WebSearchToolsList.TAVILY.value

    def format_input(task: str) -> Dict[str, Any]:
        """Wrap the task text in the argument dict passed to the tool."""
        return {"query": task}

    def format_output(response: Any) -> List[Document]:
        """Convert response items into Documents, using each URL as the file name.

        NOTE(review): assumes ``response`` is an iterable of mappings with
        "content" and "url" keys — confirm against the tool's error path.
        """
        metadata = {"integration": "", "integration_link": ""}
        return [
            Document(
                page_content=d["content"],
                metadata={
                    **metadata,
                    "file_name": d["url"],
                    "original_file_name": d["url"],
                },
            )
            for d in response
        ]

    return ToolWrapper(tool, format_input, format_output)
|
||||
|
||||
|
||||
# Module-level registry of web search tool factories. It is populated at
# import time, with the TAVILY enum member as the key for create_tavily_tool.
web_search_tool_registry = ToolRegistry()
web_search_tool_registry.register_tool(WebSearchToolsList.TAVILY, create_tavily_tool)
|
||||
|
||||
|
||||
def create_web_search_tool(tool_name: str, config: Dict[str, Any]) -> ToolWrapper:
    """Create a web search tool by delegating to ``web_search_tool_registry``.

    Args:
        tool_name: Name of the tool to look up in the registry.
        config: Configuration forwarded to the registry.

    Returns:
        The result of ``web_search_tool_registry.create_tool``.
    """
    return web_search_tool_registry.create_tool(tool_name, config)
|
||||
|
||||
|
||||
# Category describing the web search tools; Tavily is both the only tool
# listed and the declared default.
WebSearchTools = ToolsCategory(
    name="Web Search",
    description="Tools for web searching",
    tools=[WebSearchToolsList.TAVILY],
    default_tool=WebSearchToolsList.TAVILY,
    create_tool=create_web_search_tool,
)
|
@ -1,119 +0,0 @@
|
||||
import datetime
|
||||
|
||||
from langchain_core.prompts import (
|
||||
ChatPromptTemplate,
|
||||
HumanMessagePromptTemplate,
|
||||
MessagesPlaceholder,
|
||||
PromptTemplate,
|
||||
SystemMessagePromptTemplate,
|
||||
)
|
||||
from langchain_core.prompts.base import BasePromptTemplate
|
||||
from pydantic import ConfigDict, create_model
|
||||
|
||||
|
||||
class CustomPromptsDict(dict):
    """Dict that stores each item assigned with ``d[key] = value`` as ``(type, value)``.

    Only item assignment is intercepted; entries passed to the constructor
    are stored as given.
    """

    def __init__(self, type, *args, **kwargs):
        """Remember ``type`` and initialise the dict from the remaining arguments."""
        self._type = type
        super().__init__(*args, **kwargs)

    def __setitem__(self, key, value):
        """Store ``value`` under ``key`` paired with the configured type."""
        typed_entry = (self._type, value)
        super().__setitem__(key, typed_entry)
|
||||
|
||||
|
||||
def _define_custom_prompts() -> CustomPromptsDict:
|
||||
custom_prompts: CustomPromptsDict = CustomPromptsDict(type=BasePromptTemplate)
|
||||
|
||||
today_date = datetime.datetime.now().strftime("%B %d, %Y")
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Prompt for question rephrasing
|
||||
# ---------------------------------------------------------------------------
|
||||
_template = """Given the following conversation and a follow up question, rephrase the follow up question to be a standalone question, in its original language. Keep as much details as possible from previous messages. Keep entity names and all.
|
||||
|
||||
Chat History:
|
||||
{chat_history}
|
||||
Follow Up Input: {question}
|
||||
Standalone question:"""
|
||||
|
||||
CONDENSE_QUESTION_PROMPT = PromptTemplate.from_template(_template)
|
||||
custom_prompts["CONDENSE_QUESTION_PROMPT"] = CONDENSE_QUESTION_PROMPT
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Prompt for RAG
|
||||
# ---------------------------------------------------------------------------
|
||||
system_message_template = (
|
||||
f"Your name is Quivr. You're a helpful assistant. Today's date is {today_date}."
|
||||
)
|
||||
|
||||
system_message_template += """
|
||||
When answering use markdown.
|
||||
Use markdown code blocks for code snippets.
|
||||
Answer in a concise and clear manner.
|
||||
Use the following pieces of context from files provided by the user to answer the users.
|
||||
Answer in the same language as the user question.
|
||||
If you don't know the answer with the context provided from the files, just say that you don't know, don't try to make up an answer.
|
||||
Don't cite the source id in the answer objects, but you can use the source to answer the question.
|
||||
You have access to the files to answer the user question (limited to first 20 files):
|
||||
{files}
|
||||
|
||||
If not None, User instruction to follow to answer: {custom_instructions}
|
||||
Don't cite the source id in the answer objects, but you can use the source to answer the question.
|
||||
"""
|
||||
|
||||
template_answer = """
|
||||
Context:
|
||||
{context}
|
||||
|
||||
User Question: {question}
|
||||
Answer:
|
||||
"""
|
||||
|
||||
RAG_ANSWER_PROMPT = ChatPromptTemplate.from_messages(
|
||||
[
|
||||
SystemMessagePromptTemplate.from_template(system_message_template),
|
||||
HumanMessagePromptTemplate.from_template(template_answer),
|
||||
]
|
||||
)
|
||||
custom_prompts["RAG_ANSWER_PROMPT"] = RAG_ANSWER_PROMPT
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Prompt for formatting documents
|
||||
# ---------------------------------------------------------------------------
|
||||
DEFAULT_DOCUMENT_PROMPT = PromptTemplate.from_template(
|
||||
template="Source: {index} \n {page_content}"
|
||||
)
|
||||
custom_prompts["DEFAULT_DOCUMENT_PROMPT"] = DEFAULT_DOCUMENT_PROMPT
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Prompt for chatting directly with LLMs, without any document retrieval stage
|
||||
# ---------------------------------------------------------------------------
|
||||
system_message_template = (
|
||||
f"Your name is Quivr. You're a helpful assistant. Today's date is {today_date}."
|
||||
)
|
||||
system_message_template += """
|
||||
If not None, also follow these user instructions when answering: {custom_instructions}
|
||||
"""
|
||||
|
||||
template_answer = """
|
||||
User Question: {question}
|
||||
Answer:
|
||||
"""
|
||||
CHAT_LLM_PROMPT = ChatPromptTemplate.from_messages(
|
||||
[
|
||||
SystemMessagePromptTemplate.from_template(system_message_template),
|
||||
MessagesPlaceholder(variable_name="chat_history"),
|
||||
HumanMessagePromptTemplate.from_template(template_answer),
|
||||
]
|
||||
)
|
||||
custom_prompts["CHAT_LLM_PROMPT"] = CHAT_LLM_PROMPT
|
||||
|
||||
return custom_prompts
|
||||
|
||||
|
||||
# Build the prompt definitions once at import time. Each entry is a
# (type, default) tuple, which create_model takes as a field definition.
_custom_prompts = _define_custom_prompts()
# Model with one field per prompt; extra="forbid" rejects unknown prompt names.
CustomPromptsModel = create_model(
    "CustomPromptsModel", **_custom_prompts, __config__=ConfigDict(extra="forbid")
)

# Shared instance carrying the default prompts.
custom_prompts = CustomPromptsModel()
|
@ -1,488 +0,0 @@
|
||||
import logging
|
||||
from enum import Enum
|
||||
from typing import Annotated, AsyncGenerator, Optional, Sequence, TypedDict
|
||||
from uuid import uuid4
|
||||
|
||||
# TODO(@aminediro): this is the only dependency to langchain package, we should remove it
|
||||
from langchain.retrievers import ContextualCompressionRetriever
|
||||
from langchain_cohere import CohereRerank
|
||||
from langchain_community.document_compressors import JinaRerank
|
||||
from langchain_core.callbacks import Callbacks
|
||||
from langchain_core.documents import BaseDocumentCompressor, Document
|
||||
from langchain_core.messages import BaseMessage
|
||||
from langchain_core.messages.ai import AIMessageChunk
|
||||
from langchain_core.vectorstores import VectorStore
|
||||
from langgraph.graph import END, START, StateGraph
|
||||
from langgraph.graph.message import add_messages
|
||||
|
||||
from quivr_core.chat import ChatHistory
|
||||
from quivr_core.config import DefaultRerankers, RetrievalConfig
|
||||
from quivr_core.llm import LLMEndpoint
|
||||
from quivr_core.models import (
|
||||
ParsedRAGChunkResponse,
|
||||
ParsedRAGResponse,
|
||||
QuivrKnowledge,
|
||||
RAGResponseMetadata,
|
||||
cited_answer,
|
||||
)
|
||||
from quivr_core.prompts import custom_prompts
|
||||
from quivr_core.utils import (
|
||||
combine_documents,
|
||||
format_file_list,
|
||||
get_chunk_metadata,
|
||||
parse_chunk_response,
|
||||
parse_response,
|
||||
)
|
||||
|
||||
logger = logging.getLogger("quivr_core")
|
||||
|
||||
|
||||
class SpecialEdges(str, Enum):
    """Reserved node names that stand for the start and end of a workflow."""

    START = "START"
    END = "END"
|
||||
|
||||
|
||||
class AgentState(TypedDict):
    """State dictionary passed between the nodes of the graph."""

    # The add_messages function defines how an update should be processed
    # Default is to replace. add_messages says "append"
    messages: Annotated[Sequence[BaseMessage], add_messages]
    # Conversation history supplied by the caller (trimmed by filter_history)
    chat_history: ChatHistory
    # Documents produced by the retrieve node
    docs: list[Document]
    # Formatted list of file names made available to the prompt
    files: str
    # Result dict written by the generate nodes ("answer", optionally "docs")
    final_response: dict
|
||||
|
||||
|
||||
class IdempotentCompressor(BaseDocumentCompressor):
    """Document compressor that leaves the documents untouched."""

    def compress_documents(
        self,
        documents: Sequence[Document],
        query: str,
        callbacks: Optional[Callbacks] = None,
    ) -> Sequence[Document]:
        """
        A no-op document compressor that simply returns the documents it is given.

        This is a placeholder until a more sophisticated document compression
        algorithm is implemented.

        Args:
            documents: Documents to "compress"; returned as-is.
            query: Ignored.
            callbacks: Ignored.

        Returns:
            The same ``documents`` object that was passed in.
        """
        return documents
|
||||
|
||||
|
||||
class QuivrQARAGLangGraph:
|
||||
def __init__(
    self,
    *,
    retrieval_config: RetrievalConfig,
    llm: LLMEndpoint,
    vector_store: VectorStore | None = None,
    reranker: BaseDocumentCompressor | None = None,
):
    """
    Set up the RAG graph runner.

    Args:
        retrieval_config (RetrievalConfig): Retrieval settings (reranker, LLM, limits, prompt).
        llm (LLMEndpoint): Endpoint used for every LLM call.
        vector_store (VectorStore | None): Store documents are retrieved from.
            When omitted, no compression retriever is created.
        reranker (BaseDocumentCompressor | None): Explicit reranker. When
            omitted, one is built from ``retrieval_config.reranker_config``
            (Cohere or Jina), falling back to IdempotentCompressor.
    """
    self.retrieval_config = retrieval_config
    self.vector_store = vector_store
    self.llm_endpoint = llm

    # The compiled graph is created lazily by build_chain().
    self.graph = None

    if reranker is not None:
        self.reranker = reranker
    else:
        reranker_cfg = self.retrieval_config.reranker_config
        if reranker_cfg.supplier == DefaultRerankers.COHERE:
            self.reranker = CohereRerank(
                model=reranker_cfg.model,
                top_n=reranker_cfg.top_n,
                cohere_api_key=reranker_cfg.api_key,
            )
        elif reranker_cfg.supplier == DefaultRerankers.JINA:
            self.reranker = JinaRerank(
                model=reranker_cfg.model,
                top_n=reranker_cfg.top_n,
                jina_api_key=reranker_cfg.api_key,
            )
        else:
            self.reranker = IdempotentCompressor()

    # Without a vector store there is nothing to retrieve from, so the
    # compression retriever attribute is simply not set.
    if self.vector_store:
        self.compression_retriever = ContextualCompressionRetriever(
            base_compressor=self.reranker, base_retriever=self.retriever
        )
|
||||
|
||||
@property
def retriever(self):
    """
    Retriever backed by the configured vector store.

    Returns:
        The result of ``self.vector_store.as_retriever()``.

    Raises:
        ValueError: If no vector store was provided.
    """
    if not self.vector_store:
        raise ValueError("No vector store provided")
    return self.vector_store.as_retriever()
|
||||
|
||||
def filter_history(self, state: AgentState) -> dict:
    """
    Trim the chat history to what fits the configured limits.

    Human/AI message pairs are taken from ``chat_history.iter_pairs()`` in
    reverse order and copied into a fresh ChatHistory (new chat id, same
    brain id). Copying stops at the first pair that would push the running
    token count above ``llm_config.max_output_tokens``, or once
    ``max_history`` pairs have been kept. Tokens are counted with
    ``llm_endpoint.count_tokens``.

    Args:
        state (AgentState): Current graph state; only "chat_history" is read.

    Returns:
        dict: ``{"chat_history": <trimmed ChatHistory>}``.
    """
    full_history = state["chat_history"]
    used_tokens = 0
    kept_pairs = 0
    trimmed_history = ChatHistory(chat_id=uuid4(), brain_id=full_history.brain_id)

    pairs = list(full_history.iter_pairs())
    for human_message, ai_message in reversed(pairs):
        # TODO: replace with tiktoken
        human_tokens = self.llm_endpoint.count_tokens(human_message.content)
        ai_tokens = self.llm_endpoint.count_tokens(ai_message.content)
        pair_tokens = human_tokens + ai_tokens

        over_token_budget = (
            used_tokens + pair_tokens
            > self.retrieval_config.llm_config.max_output_tokens
        )
        if over_token_budget or kept_pairs >= self.retrieval_config.max_history:
            break

        trimmed_history.append(human_message)
        trimmed_history.append(ai_message)
        used_tokens += pair_tokens
        kept_pairs += 1

    return {"chat_history": trimmed_history}
|
||||
|
||||
### Nodes
|
||||
def rewrite(self, state):
    """
    Rephrase the user question as a standalone question.

    The first message of the state and the chat history are formatted into
    CONDENSE_QUESTION_PROMPT and sent to the LLM.

    Args:
        state (messages): The current state

    Returns:
        dict: ``{"messages": [response]}`` with the LLM's rephrased question
    """
    condense_prompt = custom_prompts.CONDENSE_QUESTION_PROMPT.format(
        chat_history=state["chat_history"],
        question=state["messages"][0].content,
    )
    rephrased = self.llm_endpoint._llm.invoke(condense_prompt)
    return {"messages": [rephrased]}
|
||||
|
||||
def retrieve(self, state):
    """
    Retrieve relevant chunks for the latest message.

    Args:
        state (messages): The current state; the content of the last
            message is used as the query.

    Returns:
        dict: ``{"docs": ...}`` holding what the compression retriever returned
    """
    latest_message = state["messages"][-1]
    retrieved = self.compression_retriever.invoke(latest_message.content)
    return {"docs": retrieved}
|
||||
|
||||
def generate_rag(self, state):
    """
    Generate an answer from the retrieved documents.

    The first message of the state is used as the user question. Missing
    context, custom instructions or files are passed to the prompt as the
    string "None". When the endpoint supports function calling, the LLM is
    bound to the ``cited_answer`` tool with ``tool_choice="any"``.

    Args:
        state (messages): The current state ("messages", "files", "docs")

    Returns:
        dict: "messages" with the LLM response appended, and "final_response"
        holding the response under "answer" and the documents under "docs"
    """
    user_question = state["messages"][0].content
    files = state["files"]
    docs = state["docs"]

    custom_instructions = self.retrieval_config.prompt

    prompt_inputs = {
        "context": combine_documents(docs) if docs else "None",
        "question": user_question,
        "custom_instructions": custom_instructions if custom_instructions else "None",
        "files": files if files else "None",
    }

    # Bind the citation tool only when the model can call functions.
    llm = self.llm_endpoint._llm
    if self.llm_endpoint.supports_func_calling():
        llm = self.llm_endpoint._llm.bind_tools(
            [cited_answer],
            tool_choice="any",
        )

    response = (custom_prompts.RAG_ANSWER_PROMPT | llm).invoke(prompt_inputs)
    return {
        "messages": [response],
        "final_response": {"answer": response, "docs": docs},
    }
|
||||
|
||||
def generate_chat_llm(self, state):
    """
    Generate an answer by chatting with the LLM directly, without retrieval.

    The first message of the state is used as the user question, and the
    chat history is passed to the prompt as a list. Missing custom
    instructions are passed as the string "None".

    Args:
        state (messages): The current state ("messages", "chat_history")

    Returns:
        dict: "messages" with the LLM response appended, and "final_response"
        holding the response under "answer"
    """
    user_question = state["messages"][0].content
    custom_instructions = self.retrieval_config.prompt

    prompt_inputs = {
        "question": user_question,
        "custom_instructions": custom_instructions if custom_instructions else "None",
        "chat_history": state["chat_history"].to_list(),
    }

    chat_chain = custom_prompts.CHAT_LLM_PROMPT | self.llm_endpoint._llm
    response = chat_chain.invoke(prompt_inputs)
    return {
        "messages": [response],
        "final_response": {"answer": response},
    }
|
||||
|
||||
def build_chain(self):
    """
    Return the compiled graph, creating and caching it on first use.

    Returns:
        The value stored in ``self.graph``; ``create_graph()`` is called
        only while that attribute is still falsy.
    """
    compiled = self.graph
    if not compiled:
        compiled = self.create_graph()
        self.graph = compiled
    return compiled
|
||||
|
||||
def create_graph(self):
    """
    Build and compile the LangGraph state graph for the given configuration.

    Two modes are supported:

    - When ``retrieval_config.workflow_config`` is set, the graph is built
      from it. The configuration must contain a node named 'START'. Every
      node whose name is not a SpecialEdges value is added to the graph and
      bound to the method of the same name on this object. Edges leaving the
      'START' node are attached to the graph's START, edges pointing to 'END'
      are attached to the graph's END, and all others connect node to node.
    - Otherwise a default linear pipeline is built:
      filter_history -> rewrite -> retrieve -> generate -> END

      - filter_history: trim the chat history to the configured limits
      - rewrite: re-write the question using the filtered history
      - retrieve: retrieve documents related to the re-written question
      - generate: generate an answer using the retrieved documents

    Returns:
        The compiled graph.

    Raises:
        ValueError: If a workflow configuration is given without a 'START' node.
    """
    workflow = StateGraph(AgentState)

    if self.retrieval_config.workflow_config:
        if SpecialEdges.START not in [
            node.name for node in self.retrieval_config.workflow_config.nodes
        ]:
            raise ValueError("The workflow should contain a 'START' node")
        # First pass: register every real node (START/END are not real nodes).
        for node in self.retrieval_config.workflow_config.nodes:
            if node.name not in SpecialEdges._value2member_map_:
                workflow.add_node(node.name, getattr(self, node.name))

        # Second pass: wire the edges, mapping the special names onto the
        # graph's own START/END.
        for node in self.retrieval_config.workflow_config.nodes:
            for edge in node.edges:
                if node.name == SpecialEdges.START:
                    workflow.add_edge(START, edge)
                elif edge == SpecialEdges.END:
                    workflow.add_edge(node.name, END)
                else:
                    workflow.add_edge(node.name, edge)
    else:
        # Define the nodes we will cycle between
        workflow.add_node("filter_history", self.filter_history)
        workflow.add_node("rewrite", self.rewrite)  # Re-writing the question
        workflow.add_node("retrieve", self.retrieve)  # retrieval
        workflow.add_node("generate", self.generate_rag)

        # Add node for filtering history

        workflow.set_entry_point("filter_history")
        workflow.add_edge("filter_history", "rewrite")
        workflow.add_edge("rewrite", "retrieve")
        workflow.add_edge("retrieve", "generate")
        workflow.add_edge(
            "generate", END
        )  # generate is the last node: finish the graph after it

    # Compile
    graph = workflow.compile()
    return graph
|
||||
|
||||
def answer(
    self,
    question: str,
    history: ChatHistory,
    list_files: list[QuivrKnowledge],
    metadata: dict[str, str] | None = None,
) -> ParsedRAGResponse:
    """
    Answer a question using the langgraph chain.

    Args:
        question (str): The question to answer.
        history (ChatHistory): The chat history to use for context.
        list_files (list[QuivrKnowledge]): The list of files to use for retrieval.
        metadata (dict[str, str] | None, optional): The metadata to pass to
            the langchain invocation. Defaults to None, which is treated as
            an empty dict.

    Returns:
        ParsedRAGResponse: The answer to the question.
    """
    # A fresh dict per call avoids sharing one mutable default between calls.
    invocation_metadata = {} if metadata is None else metadata

    concat_list_files = format_file_list(
        list_files, self.retrieval_config.max_files
    )
    conversational_qa_chain = self.build_chain()
    inputs = {
        "messages": [
            ("user", question),
        ],
        "chat_history": history,
        "files": concat_list_files,
    }
    raw_llm_response = conversational_qa_chain.invoke(
        inputs,
        config={"metadata": invocation_metadata},
    )
    response = parse_response(
        raw_llm_response["final_response"], self.retrieval_config.llm_config.model
    )
    return response
|
||||
|
||||
async def answer_astream(
    self,
    question: str,
    history: ChatHistory,
    list_files: list[QuivrKnowledge],
    metadata: dict[str, str] = {},
) -> AsyncGenerator[ParsedRAGChunkResponse, ParsedRAGChunkResponse]:
    """
    Answer a question using the langgraph chain and yield each chunk of the answer separately.

    The chain is built with ``build_chain`` and consumed through ``astream_events``.
    Only ``on_chat_model_stream`` events emitted by a node whose name contains
    "generate" produce answer chunks. After the stream ends, one extra chunk with an
    empty answer and ``last_chunk=True`` is yielded, carrying the metadata computed
    from the accumulated message and the captured sources.

    Args:
        question (str): The question to answer.
        history (ChatHistory): The chat history to use for context.
        list_files (list[QuivrKnowledge]): The list of files to use for retrieval.
        metadata (dict[str, str], optional): The metadata to pass to the langchain invocation. Defaults to {}.
            The default dict is shared between calls; it is only read here, never mutated.

    Yields:
        ParsedRAGChunkResponse: Each chunk of the answer.
    """
    concat_list_files = format_file_list(
        list_files, self.retrieval_config.max_files
    )
    conversational_qa_chain = self.build_chain()

    # State accumulated across streamed events.
    rolling_message = AIMessageChunk(content="")
    sources: list[Document] | None = None
    prev_answer = ""
    chunk_id = 0

    async for event in conversational_qa_chain.astream_events(
        {
            "messages": [
                ("user", question),
            ],
            "chat_history": history,
            "files": concat_list_files,
        },
        version="v1",
        config={"metadata": metadata},
    ):
        kind = event["event"]
        # Capture the documents from the first event whose output exposes "docs";
        # later events are ignored once `sources` is non-empty.
        if (
            not sources
            and "output" in event["data"]
            and "docs" in event["data"]["output"]
        ):
            sources = event["data"]["output"]["docs"]

        # Only stream tokens coming from a generation node.
        if (
            kind == "on_chat_model_stream"
            and "generate" in event["metadata"]["langgraph_node"]
        ):
            chunk = event["data"]["chunk"]
            rolling_message, answer_str = parse_chunk_response(
                rolling_message,
                chunk,
                self.llm_endpoint.supports_func_calling(),
            )
            if len(answer_str) > 0:
                if (
                    self.llm_endpoint.supports_func_calling()
                    and rolling_message.tool_calls
                ):
                    # Presumably answer_str is the cumulative answer in this mode
                    # (TODO confirm against parse_chunk_response): only the suffix
                    # not yet emitted is yielded.
                    diff_answer = answer_str[len(prev_answer) :]
                    if len(diff_answer) > 0:
                        parsed_chunk = ParsedRAGChunkResponse(
                            answer=diff_answer,
                            metadata=RAGResponseMetadata(),
                        )
                        prev_answer += diff_answer

                        logger.debug(
                            f"answer_astream func_calling=True question={question} rolling_msg={rolling_message} chunk_id={chunk_id}, chunk={parsed_chunk}"
                        )
                        yield parsed_chunk
                else:
                    # Without tool calls the parsed string is yielded as-is.
                    parsed_chunk = ParsedRAGChunkResponse(
                        answer=answer_str,
                        metadata=RAGResponseMetadata(),
                    )
                    logger.debug(
                        f"answer_astream func_calling=False question={question} rolling_msg={rolling_message} chunk_id={chunk_id}, chunk={parsed_chunk}"
                    )
                    yield parsed_chunk

                chunk_id += 1

    # Last chunk provides metadata
    last_chunk = ParsedRAGChunkResponse(
        answer="",
        metadata=get_chunk_metadata(rolling_message, sources),
        last_chunk=True,
    )
    logger.debug(
        f"answer_astream last_chunk={last_chunk} question={question} rolling_msg={rolling_message} chunk_id={chunk_id}"
    )
    yield last_chunk
|
0
core/quivr_core/rag/__init__.py
Normal file
0
core/quivr_core/rag/__init__.py
Normal file
0
core/quivr_core/rag/entities/__init__.py
Normal file
0
core/quivr_core/rag/entities/__init__.py
Normal file
@ -1,11 +1,10 @@
|
||||
from copy import deepcopy
|
||||
from datetime import datetime
|
||||
from typing import Any, Generator, List, Tuple
|
||||
from typing import Any, Generator, Tuple, List
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
from langchain_core.messages import AIMessage, HumanMessage
|
||||
|
||||
from quivr_core.models import ChatMessage
|
||||
from quivr_core.rag.entities.models import ChatMessage
|
||||
|
||||
|
||||
class ChatHistory:
|
||||
@ -98,17 +97,3 @@ class ChatHistory:
|
||||
"""
|
||||
|
||||
return [_msg.msg for _msg in self._msgs]
|
||||
|
||||
def __deepcopy__(self, memo):
    """
    Support for deepcopy of ChatHistory.

    This method ensures that mutable objects (like lists) are copied deeply.

    Args:
        memo: The memo dictionary supplied by ``copy.deepcopy``; it is forwarded to
            the nested ``deepcopy`` calls.

    Returns:
        ChatHistory: A new instance built from this instance's ``id`` and a deep copy
        of its ``brain_id``, holding a deep copy of the message list.
    """
    # Create a new instance of ChatHistory (the id is passed through as-is)
    new_copy = ChatHistory(self.id, deepcopy(self.brain_id, memo))

    # Perform a deepcopy of the _msgs list
    new_copy._msgs = deepcopy(self._msgs, memo)

    # Return the deep copied instance
    return new_copy
|
452
core/quivr_core/rag/entities/config.py
Normal file
452
core/quivr_core/rag/entities/config.py
Normal file
@ -0,0 +1,452 @@
|
||||
import os
|
||||
import re
|
||||
import logging
|
||||
from enum import Enum
|
||||
from typing import Dict, Hashable, List, Optional, Union, Any, Type
|
||||
from uuid import UUID
|
||||
from pydantic import BaseModel
|
||||
from langgraph.graph import START, END
|
||||
from langchain_core.tools import BaseTool
|
||||
from quivr_core.config import MegaparseConfig
|
||||
from rapidfuzz import process, fuzz
|
||||
|
||||
|
||||
from quivr_core.base_config import QuivrBaseConfig
|
||||
from quivr_core.processor.splitter import SplitterConfig
|
||||
from quivr_core.rag.prompts import CustomPromptsModel
|
||||
from quivr_core.llm_tools.llm_tools import LLMToolFactory, TOOLS_CATEGORIES, TOOLS_LISTS
|
||||
|
||||
logger = logging.getLogger("quivr_core")
|
||||
|
||||
|
||||
def normalize_to_env_variable_name(name: str) -> str:
    """Convert an arbitrary name into an upper-case environment variable name.

    Every character that is not an ASCII letter, digit or underscore is replaced
    by an underscore, then the result is upper-cased.

    Args:
        name: The name to normalize (a plain string or a ``str``-based enum member).

    Returns:
        The normalized environment variable name.

    Raises:
        ValueError: If the name is empty or the normalized name starts with a digit.
    """
    # Replace any character that is not a letter, digit, or underscore with an underscore
    env_variable_name = re.sub(r"[^A-Za-z0-9_]", "_", name).upper()

    # An empty name would otherwise fail below with an unrelated IndexError
    if not env_variable_name:
        raise ValueError("Invalid environment variable name: name cannot be empty.")

    # Check if the normalized name starts with a digit
    if env_variable_name[0].isdigit():
        raise ValueError(
            f"Invalid environment variable name '{env_variable_name}': Cannot start with a digit."
        )

    return env_variable_name
|
||||
|
||||
|
||||
class SpecialEdges(str, Enum):
    """Placeholder names for the graph entry and exit points.

    ``ConditionalEdgeConfig`` and ``NodeConfig`` compare configured values against
    these members and replace matches with langgraph's ``START`` / ``END`` constants.
    """

    start = "START"
    end = "END"
|
||||
|
||||
|
||||
class BrainConfig(QuivrBaseConfig):
    """Configuration holding a brain's name and optional UUID."""

    brain_id: UUID | None = None
    name: str

    @property
    def id(self) -> UUID | None:
        """Return ``brain_id`` (read-only alias)."""
        return self.brain_id
|
||||
|
||||
|
||||
class DefaultWebSearchTool(str, Enum):
    """Identifiers of the web search tools known to the configuration."""

    TAVILY = "tavily"
|
||||
|
||||
|
||||
class DefaultRerankers(str, Enum):
    """Reranker suppliers known to the configuration, each with a default model."""

    COHERE = "cohere"
    JINA = "jina"
    # MIXEDBREAD = "mixedbread-ai"

    @property
    def default_model(self) -> str:
        """Return the name of the default reranking model for this supplier."""
        if self is DefaultRerankers.COHERE:
            return "rerank-multilingual-v3.0"
        if self is DefaultRerankers.JINA:
            return "jina-reranker-v2-base-multilingual"
        # MIXEDBREAD (disabled above) would map to "rmxbai-rerank-large-v1"
        raise KeyError(self)
|
||||
|
||||
|
||||
class DefaultModelSuppliers(str, Enum):
    """LLM suppliers known to the configuration.

    Members are used as keys of ``LLMModelConfig._model_defaults`` and, through
    ``normalize_to_env_variable_name``, to derive the ``<SUPPLIER>_API_KEY``
    environment variable name.
    """

    OPENAI = "openai"
    AZURE = "azure"
    ANTHROPIC = "anthropic"
    META = "meta"
    MISTRAL = "mistral"
    GROQ = "groq"
|
||||
|
||||
|
||||
class LLMConfig(QuivrBaseConfig):
    """Per-model defaults: context size and tokenizer hub identifier."""

    # Context length of the model; None when unknown.
    context: int | None = None
    # Tokenizer repository identifier (e.g. "Xenova/gpt-4o"); None when unknown.
    tokenizer_hub: str | None = None
|
||||
|
||||
|
||||
class LLMModelConfig:
    """Registry of default context sizes and tokenizers per supplier and model.

    Model names are matched by prefix, so a versioned name such as
    ``"gpt-4o-2024-08-06"`` resolves to the ``"gpt-4o"`` entry. When several keys
    are prefixes of the requested name, the longest key wins, so the result does not
    depend on the order in which entries are declared.
    """

    _model_defaults: Dict[DefaultModelSuppliers, Dict[str, LLMConfig]] = {
        DefaultModelSuppliers.OPENAI: {
            "gpt-4o": LLMConfig(context=128000, tokenizer_hub="Xenova/gpt-4o"),
            "gpt-4o-mini": LLMConfig(context=128000, tokenizer_hub="Xenova/gpt-4o"),
            "gpt-4-turbo": LLMConfig(context=128000, tokenizer_hub="Xenova/gpt-4"),
            "gpt-4": LLMConfig(context=8192, tokenizer_hub="Xenova/gpt-4"),
            "gpt-3.5-turbo": LLMConfig(
                context=16385, tokenizer_hub="Xenova/gpt-3.5-turbo"
            ),
            "text-embedding-3-large": LLMConfig(
                context=8191, tokenizer_hub="Xenova/text-embedding-ada-002"
            ),
            "text-embedding-3-small": LLMConfig(
                context=8191, tokenizer_hub="Xenova/text-embedding-ada-002"
            ),
            "text-embedding-ada-002": LLMConfig(
                context=8191, tokenizer_hub="Xenova/text-embedding-ada-002"
            ),
        },
        DefaultModelSuppliers.ANTHROPIC: {
            "claude-3-5-sonnet": LLMConfig(
                context=200000, tokenizer_hub="Xenova/claude-tokenizer"
            ),
            "claude-3-opus": LLMConfig(
                context=200000, tokenizer_hub="Xenova/claude-tokenizer"
            ),
            "claude-3-sonnet": LLMConfig(
                context=200000, tokenizer_hub="Xenova/claude-tokenizer"
            ),
            "claude-3-haiku": LLMConfig(
                context=200000, tokenizer_hub="Xenova/claude-tokenizer"
            ),
            "claude-2-1": LLMConfig(
                context=200000, tokenizer_hub="Xenova/claude-tokenizer"
            ),
            "claude-2-0": LLMConfig(
                context=100000, tokenizer_hub="Xenova/claude-tokenizer"
            ),
            "claude-instant-1-2": LLMConfig(
                context=100000, tokenizer_hub="Xenova/claude-tokenizer"
            ),
        },
        DefaultModelSuppliers.META: {
            "llama-3.1": LLMConfig(
                context=128000, tokenizer_hub="Xenova/Meta-Llama-3.1-Tokenizer"
            ),
            "llama-3": LLMConfig(
                context=8192, tokenizer_hub="Xenova/llama3-tokenizer-new"
            ),
            "llama-2": LLMConfig(context=4096, tokenizer_hub="Xenova/llama2-tokenizer"),
            "code-llama": LLMConfig(
                context=16384, tokenizer_hub="Xenova/llama-code-tokenizer"
            ),
        },
        DefaultModelSuppliers.GROQ: {
            "llama-3.1": LLMConfig(
                context=128000, tokenizer_hub="Xenova/Meta-Llama-3.1-Tokenizer"
            ),
            "llama-3": LLMConfig(
                context=8192, tokenizer_hub="Xenova/llama3-tokenizer-new"
            ),
            "llama-2": LLMConfig(context=4096, tokenizer_hub="Xenova/llama2-tokenizer"),
            "code-llama": LLMConfig(
                context=16384, tokenizer_hub="Xenova/llama-code-tokenizer"
            ),
        },
        DefaultModelSuppliers.MISTRAL: {
            "mistral-large": LLMConfig(
                context=128000, tokenizer_hub="Xenova/mistral-tokenizer-v3"
            ),
            "mistral-small": LLMConfig(
                context=128000, tokenizer_hub="Xenova/mistral-tokenizer-v3"
            ),
            "mistral-nemo": LLMConfig(
                context=128000, tokenizer_hub="Xenova/Mistral-Nemo-Instruct-Tokenizer"
            ),
            "codestral": LLMConfig(
                context=32000, tokenizer_hub="Xenova/mistral-tokenizer-v3"
            ),
        },
    }

    @staticmethod
    def _longest_prefix_key(model_name: str, candidates) -> str | None:
        """Return the longest key in ``candidates`` that prefixes ``model_name``, or None."""
        best = None
        for key in candidates:
            if model_name.startswith(key) and (best is None or len(key) > len(best)):
                best = key
        return best

    @classmethod
    def get_supplier_by_model_name(cls, model: str) -> DefaultModelSuppliers | None:
        """Return the supplier whose registry best matches ``model``.

        The supplier owning the longest matching key is returned. When two suppliers
        declare the same key (e.g. META and GROQ), the one declared first wins.
        Returns None if no supplier matches the model name.
        """
        best_supplier = None
        best_length = -1
        for supplier, models in cls._model_defaults.items():
            key = cls._longest_prefix_key(model, models)
            # Strict comparison keeps the first-declared supplier on ties
            if key is not None and len(key) > best_length:
                best_supplier = supplier
                best_length = len(key)
        return best_supplier

    @classmethod
    def get_llm_model_config(
        cls, supplier: DefaultModelSuppliers, model_name: str
    ) -> Optional[LLMConfig]:
        """Retrieve the LLMConfig (context and tokenizer_hub) for a given supplier and model.

        Returns None when the supplier has no registered models or when no registered
        key is a prefix of ``model_name``.
        """
        supplier_defaults = cls._model_defaults.get(supplier)
        if not supplier_defaults:
            return None

        # Prefix matching, preferring the most specific (longest) key
        key = cls._longest_prefix_key(model_name, supplier_defaults)
        if key is None:
            return None
        return supplier_defaults[key]
|
||||
|
||||
|
||||
class LLMEndpointConfig(QuivrBaseConfig):
    """Configuration of the LLM endpoint used to generate answers.

    On construction, the context length and tokenizer are filled in from
    ``LLMModelConfig`` when the supplier/model pair is known, and the API key is read
    from the environment when it was not supplied.
    """

    supplier: DefaultModelSuppliers = DefaultModelSuppliers.OPENAI
    model: str = "gpt-4o"
    context_length: int | None = None
    tokenizer_hub: str | None = None
    llm_base_url: str | None = None
    env_variable_name: str | None = None
    llm_api_key: str | None = None
    max_context_tokens: int = 2000
    max_output_tokens: int = 2000
    temperature: float = 0.7
    streaming: bool = True
    prompt: CustomPromptsModel | None = None

    _FALLBACK_TOKENIZER = "cl100k_base"

    @property
    def fallback_tokenizer(self) -> str:
        """Name of the tokenizer to use when ``tokenizer_hub`` is not available."""
        return self._FALLBACK_TOKENIZER

    def __init__(self, **data):
        super().__init__(**data)
        self.set_llm_model_config()
        self.set_api_key()

    def set_api_key(self, force_reset: bool = False):
        """Populate ``llm_api_key`` from the environment.

        The variable name defaults to ``<SUPPLIER>_API_KEY``. The environment is read
        only when no key is set yet or when ``force_reset`` is True. A warning is
        logged if no key is available afterwards.
        """
        if not self.supplier:
            return

        # Derive the environment variable name from the supplier when not configured
        if not self.env_variable_name:
            supplier_prefix = normalize_to_env_variable_name(self.supplier)
            self.env_variable_name = f"{supplier_prefix}_API_KEY"

        must_read_env = force_reset or not self.llm_api_key
        if must_read_env:
            self.llm_api_key = os.getenv(self.env_variable_name)

        if not self.llm_api_key:
            logger.warning(f"The API key for supplier '{self.supplier}' is not set. ")

    def set_llm_model_config(self):
        """Set ``context_length`` and ``tokenizer_hub`` from the known model defaults."""
        known_defaults = LLMModelConfig.get_llm_model_config(self.supplier, self.model)
        if not known_defaults:
            return
        self.context_length = known_defaults.context
        self.tokenizer_hub = known_defaults.tokenizer_hub

    def set_llm_model(self, model: str):
        """Switch to ``model``, updating supplier, model defaults and API key.

        Raises:
            ValueError: If no supplier can be found for the model name.
        """
        matched_supplier = LLMModelConfig.get_supplier_by_model_name(model)
        if matched_supplier is None:
            raise ValueError(
                f"Cannot find the corresponding supplier for model {model}"
            )

        self.supplier = matched_supplier
        self.model = model

        self.set_llm_model_config()
        self.set_api_key(force_reset=True)

    def set_from_sqlmodel(self, sqlmodel: BaseModel, mapping: Dict[str, str]):
        """
        Set attributes in LLMEndpointConfig from Model attributes using a field mapping.

        :param sqlmodel: An instance of the Model class.
        :param mapping: A dictionary that maps Model fields to LLMEndpointConfig fields.
                        Example: {"max_input": "max_input_tokens", "env_variable_name": "env_variable_name"}
        :raises AttributeError: If a mapped field is missing on either object.
        """
        for source_field, target_field in mapping.items():
            if not hasattr(sqlmodel, source_field) or not hasattr(self, target_field):
                raise AttributeError(
                    f"Invalid mapping: {source_field} or {target_field} does not exist."
                )
            setattr(self, target_field, getattr(sqlmodel, source_field))
|
||||
|
||||
|
||||
# Cannot use Pydantic v2 field_validator because of conflicts with pydantic v1 still in use in LangChain
|
||||
class RerankerConfig(QuivrBaseConfig):
    """Configuration of the optional re-ranking step.

    When a supplier is set, the model defaults to the supplier's default model and an
    API key is required: an explicitly provided ``api_key`` is used as-is, otherwise
    the key is read from the ``<SUPPLIER>_API_KEY`` environment variable.
    """

    supplier: DefaultRerankers | None = None
    model: str | None = None
    top_n: int = 5  # Number of chunks returned by the re-ranker
    api_key: str | None = None
    relevance_score_threshold: float | None = None
    relevance_score_key: str = "relevance_score"

    def __init__(self, **data):
        super().__init__(**data)  # Call Pydantic's BaseModel init
        self.validate_model()  # Automatically call external validation

    def validate_model(self):
        """Fill in the default model and resolve the API key.

        Raises:
            ValueError: If a supplier is set but no API key was provided and the
                corresponding environment variable is not set.
        """
        # If model is not provided, get default model based on supplier
        if self.model is None and self.supplier is not None:
            self.model = self.supplier.default_model

        # Check if the corresponding API key environment variable is set
        if self.supplier:
            api_key_var = f"{normalize_to_env_variable_name(self.supplier)}_API_KEY"

            # Keep a key passed explicitly; only fall back to the environment
            # when none was given, instead of overwriting it unconditionally.
            if self.api_key is None:
                self.api_key = os.getenv(api_key_var)

            if self.api_key is None:
                raise ValueError(
                    f"The API key for supplier '{self.supplier}' is not set. "
                    f"Please set the environment variable: {api_key_var}"
                )
|
||||
|
||||
|
||||
class ConditionalEdgeConfig(QuivrBaseConfig):
    """A conditional edge: a routing function name plus its possible targets.

    ``conditions`` is either a list of target names or a mapping from routing result
    to target name. Targets equal to a ``SpecialEdges`` member are swapped for the
    langgraph ``START`` / ``END`` constants at construction time.
    """

    routing_function: str
    conditions: Union[list, Dict[Hashable, str]]

    def __init__(self, **data):
        super().__init__(**data)
        self.resolve_special_edges()

    def resolve_special_edges(self):
        """Replace SpecialEdges enum values with their corresponding langgraph values."""
        # Snapshot (slot, target) pairs: keys for a dict, indices for a list
        if isinstance(self.conditions, dict):
            slots = list(self.conditions.items())
        elif isinstance(self.conditions, list):
            slots = list(enumerate(self.conditions))
        else:
            return

        for slot, target in slots:
            if target == SpecialEdges.end:
                self.conditions[slot] = END
            elif target == SpecialEdges.start:
                self.conditions[slot] = START
|
||||
|
||||
|
||||
class NodeConfig(QuivrBaseConfig):
    """Configuration of one workflow node: its name, outgoing edges and tools.

    On construction, the tools described in ``tools`` are instantiated through
    ``LLMToolFactory`` and any ``SpecialEdges`` value found in ``name`` or ``edges``
    is replaced by the langgraph ``START`` / ``END`` constant.
    """

    name: str
    edges: List[str] | None = None
    conditional_edge: ConditionalEdgeConfig | None = None
    tools: List[Dict[str, Any]] | None = None
    instantiated_tools: List[BaseTool | Type] | None = None

    def __init__(self, **data):
        super().__init__(**data)
        self._instantiate_tools()
        self.resolve_special_edges_in_name_and_edges()

    def resolve_special_edges_in_name_and_edges(self):
        """Replace SpecialEdges enum values in name and edges with corresponding langgraph values."""
        if self.name == SpecialEdges.start:
            self.name = START
        elif self.name == SpecialEdges.end:
            self.name = END

        if self.edges:
            for i, edge in enumerate(self.edges):
                if edge == SpecialEdges.start:
                    self.edges[i] = START
                elif edge == SpecialEdges.end:
                    self.edges[i] = END

    def _instantiate_tools(self):
        """Instantiate tools based on the configuration.

        Each entry of ``tools`` must contain a ``"name"`` key; the remaining keys are
        passed to the factory as the tool configuration. A missing ``"name"`` raises
        ``KeyError``.
        """
        if self.tools:
            instantiated = []
            for tool_config in self.tools:
                # Work on a copy so the stored config keeps its "name" entry and
                # the node can be dumped and re-validated.
                tool_params = dict(tool_config)
                tool_name = tool_params.pop("name")
                instantiated.append(LLMToolFactory.create_tool(tool_name, tool_params))
            self.instantiated_tools = instantiated
|
||||
|
||||
|
||||
class DefaultWorkflow(str, Enum):
    """Built-in workflows, each exposing its default list of nodes."""

    RAG = "rag"

    @property
    def nodes(self) -> List[NodeConfig]:
        """Return freshly built default nodes for this workflow."""
        # The RAG workflow is a linear chain: each step has one edge to the next.
        rag_steps = [
            START,
            "filter_history",
            "rewrite",
            "retrieve",
            "generate_rag",
            END,
        ]
        rag_nodes = [
            NodeConfig(name=source, edges=[target])
            for source, target in zip(rag_steps, rag_steps[1:])
        ]
        default_nodes = {DefaultWorkflow.RAG: rag_nodes}
        return default_nodes[self]
|
||||
|
||||
|
||||
class WorkflowConfig(QuivrBaseConfig):
    """Configuration of a workflow: its nodes and the tools that may be activated.

    On construction, the first node (if any) must be the START node, and every name
    in ``available_tools`` is validated and instantiated into ``validated_tools``.
    """

    name: str | None = None
    nodes: List[NodeConfig] = []
    available_tools: List[str] | None = None
    validated_tools: List[BaseTool | Type] = []
    activated_tools: List[BaseTool | Type] = []

    def __init__(self, **data):
        super().__init__(**data)
        self.check_first_node_is_start()
        self.validate_available_tools()

    def check_first_node_is_start(self):
        """Raise ValueError when nodes are configured and the first one is not START."""
        if not self.nodes:
            return
        first_node = self.nodes[0]
        if first_node.name != START:
            raise ValueError(f"The first node should be a {SpecialEdges.start} node")

    def get_node_tools(self, node_name: str) -> List[Any]:
        """Get tools for a specific node."""
        tools_of_matching_nodes = (
            node.instantiated_tools
            for node in self.nodes
            if node.name == node_name and node.instantiated_tools
        )
        # First matching node with tools wins; no match yields an empty list
        return next(tools_of_matching_nodes, [])

    def validate_available_tools(self):
        """Validate ``available_tools`` and instantiate each one into ``validated_tools``.

        Raises:
            ValueError: If a tool name is neither a tools category nor a tools list;
                the message suggests the closest known name when one is found.
        """
        if not self.available_tools:
            return

        valid_tools = list(TOOLS_CATEGORIES.keys()) + list(TOOLS_LISTS.keys())
        for tool in self.available_tools:
            lowered = tool.lower()
            if lowered in valid_tools:
                self.validated_tools.append(LLMToolFactory.create_tool(tool, {}).tool)
                continue

            # Unknown tool: look for the closest valid name to build a helpful error
            closest = process.extractOne(lowered, valid_tools, scorer=fuzz.WRatio)
            if closest:
                raise ValueError(
                    f"Tool {tool} is not a valid ToolsCategory or ToolsList. Did you mean {closest[0]}?"
                )
            raise ValueError(
                f"Tool {tool} is not a valid ToolsCategory or ToolsList"
            )
|
||||
|
||||
|
||||
class RetrievalConfig(QuivrBaseConfig):
    """Configuration of the retrieval pipeline: reranker, LLM, limits and workflow.

    The default sub-configurations are instantiated once, when this class body is
    executed.
    """

    reranker_config: RerankerConfig = RerankerConfig()
    llm_config: LLMEndpointConfig = LLMEndpointConfig()
    max_history: int = 10
    max_files: int = 20
    k: int = 40  # Number of chunks returned by the retriever
    prompt: str | None = None
    workflow_config: WorkflowConfig = WorkflowConfig(nodes=DefaultWorkflow.RAG.nodes)

    def __init__(self, **data):
        super().__init__(**data)
        # Re-read the LLM API key from the environment; with force_reset=True this
        # replaces any value already stored in llm_config.llm_api_key.
        self.llm_config.set_api_key(force_reset=True)
|
||||
|
||||
|
||||
class ParserConfig(QuivrBaseConfig):
    """Parsing configuration, grouping the splitter and Megaparse settings."""

    splitter_config: SplitterConfig = SplitterConfig()
    megaparse_config: MegaparseConfig = MegaparseConfig()
|
||||
|
||||
|
||||
class IngestionConfig(QuivrBaseConfig):
    """Ingestion configuration; currently only wraps the parser configuration."""

    parser_config: ParserConfig = ParserConfig()
|
||||
|
||||
|
||||
class AssistantConfig(QuivrBaseConfig):
    """Top-level configuration combining retrieval and ingestion settings."""

    retrieval_config: RetrievalConfig = RetrievalConfig()
    ingestion_config: IngestionConfig = IngestionConfig()
|
@ -7,7 +7,7 @@ from langchain_core.documents import Document
|
||||
from langchain_core.messages import AIMessage, HumanMessage
|
||||
from langchain_core.pydantic_v1 import BaseModel as BaseModelV1
|
||||
from langchain_core.pydantic_v1 import Field as FieldV1
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, Field
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
|
||||
@ -73,9 +73,9 @@ class ChatLLMMetadata(BaseModel):
|
||||
|
||||
|
||||
class RAGResponseMetadata(BaseModel):
|
||||
citations: list[int] | None = None
|
||||
followup_questions: list[str] | None = None
|
||||
sources: list[Any] | None = None
|
||||
citations: list[int] = Field(default_factory=list)
|
||||
followup_questions: list[str] = Field(default_factory=list)
|
||||
sources: list[Any] = Field(default_factory=list)
|
||||
metadata_model: ChatLLMMetadata | None = None
|
||||
|
||||
|
265
core/quivr_core/rag/prompts.py
Normal file
265
core/quivr_core/rag/prompts.py
Normal file
@ -0,0 +1,265 @@
|
||||
import datetime
|
||||
from pydantic import ConfigDict, create_model
|
||||
|
||||
from langchain_core.prompts.base import BasePromptTemplate
|
||||
from langchain_core.prompts import (
|
||||
ChatPromptTemplate,
|
||||
HumanMessagePromptTemplate,
|
||||
PromptTemplate,
|
||||
SystemMessagePromptTemplate,
|
||||
MessagesPlaceholder,
|
||||
)
|
||||
|
||||
|
||||
class CustomPromptsDict(dict):
    """Dictionary storing every value as a ``(type, value)`` tuple.

    The tuple form is applied consistently: on item assignment, on ``update`` and to
    any initial data passed to the constructor.
    """

    def __init__(self, type, *args, **kwargs):
        """Create the dictionary.

        Args:
            type: The type paired with every stored value.
            *args, **kwargs: Optional initial items, accepted as by ``dict``.
        """
        super().__init__()
        self._type = type
        # Route initial items through update() so they are wrapped like any other
        self.update(*args, **kwargs)

    def __setitem__(self, key, value):
        # Automatically convert the value into a tuple (my_type, value)
        super().__setitem__(key, (self._type, value))

    def update(self, *args, **kwargs):
        """Add items as ``dict.update`` does, wrapping each value into a tuple."""
        for key, value in dict(*args, **kwargs).items():
            self[key] = value
|
||||
|
||||
|
||||
def _define_custom_prompts() -> CustomPromptsDict:
|
||||
custom_prompts: CustomPromptsDict = CustomPromptsDict(type=BasePromptTemplate)
|
||||
|
||||
today_date = datetime.datetime.now().strftime("%B %d, %Y")
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Prompt for question rephrasing
|
||||
# ---------------------------------------------------------------------------
|
||||
system_message_template = (
|
||||
"Given a chat history and the latest user question "
|
||||
"which might reference context in the chat history, "
|
||||
"formulate a standalone question which can be understood "
|
||||
"without the chat history. Do NOT answer the question, "
|
||||
"just reformulate it if needed and otherwise return it as is. "
|
||||
"Do not output your reasoning, just the question."
|
||||
)
|
||||
|
||||
template_answer = "User question: {question}\n Standalone question:"
|
||||
|
||||
CONDENSE_QUESTION_PROMPT = ChatPromptTemplate.from_messages(
|
||||
[
|
||||
SystemMessagePromptTemplate.from_template(system_message_template),
|
||||
MessagesPlaceholder(variable_name="chat_history"),
|
||||
HumanMessagePromptTemplate.from_template(template_answer),
|
||||
]
|
||||
)
|
||||
|
||||
custom_prompts["CONDENSE_QUESTION_PROMPT"] = CONDENSE_QUESTION_PROMPT
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Prompt for RAG
|
||||
# ---------------------------------------------------------------------------
|
||||
system_message_template = f"Your name is Quivr. You're a helpful assistant. Today's date is {today_date}. "
|
||||
|
||||
system_message_template += (
|
||||
"- When answering use markdown. Use markdown code blocks for code snippets.\n"
|
||||
"- Answer in a concise and clear manner.\n"
|
||||
"- If no preferred language is provided, answer in the same language as the language used by the user.\n"
|
||||
"- You must use ONLY the provided context to answer the question. "
|
||||
"Do not use any prior knowledge or external information, even if you are certain of the answer.\n"
|
||||
"- If you cannot provide an answer using ONLY the context provided, do not attempt to answer from your own knowledge."
|
||||
"Instead, inform the user that the answer isn't available in the context and suggest using the available tools {tools}.\n"
|
||||
"- Do not apologize when providing an answer.\n"
|
||||
"- Don't cite the source id in the answer objects, but you can use the source to answer the question.\n\n"
|
||||
)
|
||||
|
||||
context_template = (
|
||||
"\n"
|
||||
"- You have access to the following internal reasoning to provide an answer: {reasoning}\n"
|
||||
"- You have access to the following files to answer the user question (limited to first 20 files): {files}\n"
|
||||
"- You have access to the following context to answer the user question: {context}\n"
|
||||
"- Follow these user instruction when crafting the answer: {custom_instructions}\n"
|
||||
"- These user instructions shall take priority over any other previous instruction.\n"
|
||||
"- Remember: if you cannot provide an answer using ONLY the provided context and CITING the sources, "
|
||||
"inform the user that you don't have the answer and consider if any of the tools can help answer the question.\n"
|
||||
"- Explain your reasoning about the potentiel tool usage in the answer.\n"
|
||||
"- Only use binded tools to answer the question.\n"
|
||||
# "OFFER the user the possibility to ACTIVATE a relevant tool among "
|
||||
# "the tools which can be activated."
|
||||
# "Tools which can be activated: {tools}. If any of these tools can help in providing an answer "
|
||||
# "to the user question, you should offer the user the possibility to activate it. "
|
||||
# "Remember, you shall NOT use the above tools, ONLY offer the user the possibility to activate them.\n"
|
||||
)
|
||||
|
||||
template_answer = (
|
||||
"Original task: {question}\n"
|
||||
"Rephrased and contextualized task: {rephrased_task}\n"
|
||||
)
|
||||
|
||||
RAG_ANSWER_PROMPT = ChatPromptTemplate.from_messages(
|
||||
[
|
||||
SystemMessagePromptTemplate.from_template(system_message_template),
|
||||
MessagesPlaceholder(variable_name="chat_history"),
|
||||
SystemMessagePromptTemplate.from_template(context_template),
|
||||
HumanMessagePromptTemplate.from_template(template_answer),
|
||||
]
|
||||
)
|
||||
custom_prompts["RAG_ANSWER_PROMPT"] = RAG_ANSWER_PROMPT
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Prompt for formatting documents
|
||||
# ---------------------------------------------------------------------------
|
||||
DEFAULT_DOCUMENT_PROMPT = PromptTemplate.from_template(
|
||||
template="Filename: {original_file_name}\nSource: {index} \n {page_content}"
|
||||
)
|
||||
custom_prompts["DEFAULT_DOCUMENT_PROMPT"] = DEFAULT_DOCUMENT_PROMPT
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Prompt for chatting directly with LLMs, without any document retrieval stage
|
||||
# ---------------------------------------------------------------------------
|
||||
system_message_template = (
|
||||
f"Your name is Quivr. You're a helpful assistant. Today's date is {today_date}."
|
||||
)
|
||||
system_message_template += """
|
||||
If not None, also follow these user instructions when answering: {custom_instructions}
|
||||
"""
|
||||
|
||||
template_answer = """
|
||||
User Question: {question}
|
||||
Answer:
|
||||
"""
|
||||
CHAT_LLM_PROMPT = ChatPromptTemplate.from_messages(
|
||||
[
|
||||
SystemMessagePromptTemplate.from_template(system_message_template),
|
||||
MessagesPlaceholder(variable_name="chat_history"),
|
||||
HumanMessagePromptTemplate.from_template(template_answer),
|
||||
]
|
||||
)
|
||||
custom_prompts["CHAT_LLM_PROMPT"] = CHAT_LLM_PROMPT
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Prompt to understand the user intent
|
||||
# ---------------------------------------------------------------------------
|
||||
system_message_template = (
|
||||
"Given the following user input, determine the user intent, in particular "
|
||||
"whether the user is providing instructions to the system or is asking the system to "
|
||||
"execute a task:\n"
|
||||
" - if the user is providing direct instructions to modify the system behaviour (for instance, "
|
||||
"'Can you reply in French?' or 'Answer in French' or 'You are an expert legal assistant' "
|
||||
"or 'You will behave as...'), the user intent is 'prompt';\n"
|
||||
" - in all other cases (asking questions, asking for summarising a text, asking for translating a text, ...), "
|
||||
"the intent is 'task'.\n"
|
||||
)
|
||||
|
||||
template_answer = "User input: {question}"
|
||||
|
||||
USER_INTENT_PROMPT = ChatPromptTemplate.from_messages(
|
||||
[
|
||||
SystemMessagePromptTemplate.from_template(system_message_template),
|
||||
HumanMessagePromptTemplate.from_template(template_answer),
|
||||
]
|
||||
)
|
||||
custom_prompts["USER_INTENT_PROMPT"] = USER_INTENT_PROMPT
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Prompt to create a system prompt from user instructions
|
||||
# ---------------------------------------------------------------------------
|
||||
system_message_template = (
|
||||
"- Given the following user instruction, current system prompt, list of available tools "
|
||||
"and list of activated tools, update the prompt to include the instruction and decide which tools to activate.\n"
|
||||
"- The prompt shall only contain generic instructions which can be applied to any user task or question.\n"
|
||||
"- The prompt shall be concise and clear.\n"
|
||||
"- If the system prompt already contains the instruction, do not add it again.\n"
|
||||
"- If the system prompt contradicts ther user instruction, remove the contradictory "
|
||||
"statement or statements in the system prompt.\n"
|
||||
"- You shall return separately the updated system prompt and the reasoning that led to the update.\n"
|
||||
"- If the system prompt refers to a tool, you shall add the tool to the list of activated tools.\n"
|
||||
"- If no tool activation is needed, return empty lists.\n"
|
||||
"- You shall also return the reasoning that led to the tool activation.\n"
|
||||
"- Current system prompt: {system_prompt}\n"
|
||||
"- List of available tools: {available_tools}\n"
|
||||
"- List of activated tools: {activated_tools}\n\n"
|
||||
)
|
||||
|
||||
template_answer = "User instructions: {instruction}\n"
|
||||
|
||||
UPDATE_PROMPT = ChatPromptTemplate.from_messages(
|
||||
[
|
||||
SystemMessagePromptTemplate.from_template(system_message_template),
|
||||
HumanMessagePromptTemplate.from_template(template_answer),
|
||||
]
|
||||
)
|
||||
custom_prompts["UPDATE_PROMPT"] = UPDATE_PROMPT
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Prompt to split the user input into multiple questions / instructions
|
||||
# ---------------------------------------------------------------------------
|
||||
system_message_template = (
|
||||
"Given a chat history and the user input, split and rephrase the input into instructions and tasks.\n"
|
||||
"- Instructions direct the system to behave in a certain way or to use specific tools: examples of instructions are "
|
||||
"'Can you reply in French?', 'Answer in French', 'You are an expert legal assistant', "
|
||||
"'You will behave as...', 'Use web search').\n"
|
||||
"- You shall collect and condense all the instructions into a single string.\n"
|
||||
"- The instructions shall be standalone and self-contained, so that they can be understood "
|
||||
"without the chat history. If no instructions are found, return an empty string.\n"
|
||||
"- Instructions to be understood may require considering the chat history.\n"
|
||||
"- Tasks are often questions, but they can also be summarisation tasks, translation tasks, content generation tasks, etc.\n"
|
||||
"- Tasks to be understood may require considering the chat history.\n"
|
||||
"- If the user input contains different tasks, you shall split the input into multiple tasks.\n"
|
||||
"- Each splitted task shall be a standalone, self-contained task which can be understood "
|
||||
"without the chat history. You shall rephrase the tasks if needed.\n"
|
||||
"- If no explicit task is present, you shall infer the tasks from the user input and the chat history.\n"
|
||||
"- Do NOT try to solve the tasks or answer the questions, "
|
||||
"just reformulate them if needed and otherwise return them as is.\n"
|
||||
"- Remember, you shall NOT suggest or generate new tasks.\n"
|
||||
"- As an example, the user input 'What is Apple? Who is its CEO? When was it founded?' "
|
||||
"shall be split into the questions 'What is Apple?', 'Who is the CEO of Apple?' and 'When was Apple founded?'.\n"
|
||||
"- If no tasks are found, return the user input as is in the task.\n"
|
||||
)
|
||||
|
||||
template_answer = "User input: {user_input}"
|
||||
|
||||
SPLIT_PROMPT = ChatPromptTemplate.from_messages(
|
||||
[
|
||||
SystemMessagePromptTemplate.from_template(system_message_template),
|
||||
MessagesPlaceholder(variable_name="chat_history"),
|
||||
HumanMessagePromptTemplate.from_template(template_answer),
|
||||
]
|
||||
)
|
||||
custom_prompts["SPLIT_PROMPT"] = SPLIT_PROMPT
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Prompt to grade the relevance of an answer and decide whether to perform a web search
|
||||
# ---------------------------------------------------------------------------
|
||||
system_message_template = (
|
||||
"Given the following tasks you shall determine whether all tasks can be "
|
||||
"completed fully and in the best possible way using the provided context and chat history. "
|
||||
"You shall:\n"
|
||||
"- Consider each task separately,\n"
|
||||
"- Determine whether the context and chat history contain "
|
||||
"all the information necessary to complete the task.\n"
|
||||
"- If the context and chat history do not contain all the information necessary to complete the task, "
|
||||
"consider ONLY the list of tools below and select the tool most appropriate to complete the task.\n"
|
||||
"- If no tools are listed, return the tasks as is and no tool.\n"
|
||||
"- If no relevant tool can be selected, return the tasks as is and no tool.\n"
|
||||
"- Do not propose to use a tool if that tool is not listed among the available tools.\n"
|
||||
)
|
||||
|
||||
context_template = "Context: {context}\n {activated_tools}\n"
|
||||
|
||||
template_answer = "Tasks: {tasks}\n"
|
||||
|
||||
TOOL_ROUTING_PROMPT = ChatPromptTemplate.from_messages(
|
||||
[
|
||||
SystemMessagePromptTemplate.from_template(system_message_template),
|
||||
MessagesPlaceholder(variable_name="chat_history"),
|
||||
SystemMessagePromptTemplate.from_template(context_template),
|
||||
HumanMessagePromptTemplate.from_template(template_answer),
|
||||
]
|
||||
)
|
||||
|
||||
custom_prompts["TOOL_ROUTING_PROMPT"] = TOOL_ROUTING_PROMPT
|
||||
|
||||
return custom_prompts
|
||||
|
||||
|
||||
# Build every prompt template once at import time, then expose them as
# attributes of a dynamically created pydantic model that rejects unknown
# fields (extra="forbid").
_custom_prompts = _define_custom_prompts()

CustomPromptsModel = create_model(
    "CustomPromptsModel",
    __config__=ConfigDict(extra="forbid"),
    **_custom_prompts,
)

# Shared instance used by the rest of the package
# (e.g. `custom_prompts.SPLIT_PROMPT`).
custom_prompts = CustomPromptsModel()
|
@ -12,18 +12,18 @@ from langchain_core.output_parsers import StrOutputParser
|
||||
from langchain_core.runnables import RunnableLambda, RunnablePassthrough
|
||||
from langchain_core.vectorstores import VectorStore
|
||||
|
||||
from quivr_core.chat import ChatHistory
|
||||
from quivr_core.config import RetrievalConfig
|
||||
from quivr_core.rag.entities.chat import ChatHistory
|
||||
from quivr_core.rag.entities.config import RetrievalConfig
|
||||
from quivr_core.llm import LLMEndpoint
|
||||
from quivr_core.models import (
|
||||
from quivr_core.rag.entities.models import (
|
||||
ParsedRAGChunkResponse,
|
||||
ParsedRAGResponse,
|
||||
QuivrKnowledge,
|
||||
RAGResponseMetadata,
|
||||
cited_answer,
|
||||
)
|
||||
from quivr_core.prompts import custom_prompts
|
||||
from quivr_core.utils import (
|
||||
from quivr_core.rag.prompts import custom_prompts
|
||||
from quivr_core.rag.utils import (
|
||||
combine_documents,
|
||||
format_file_list,
|
||||
get_chunk_metadata,
|
891
core/quivr_core/rag/quivr_rag_langgraph.py
Normal file
891
core/quivr_core/rag/quivr_rag_langgraph.py
Normal file
@ -0,0 +1,891 @@
|
||||
import logging
|
||||
from typing import (
|
||||
Annotated,
|
||||
AsyncGenerator,
|
||||
List,
|
||||
Optional,
|
||||
Sequence,
|
||||
Tuple,
|
||||
TypedDict,
|
||||
Dict,
|
||||
Any,
|
||||
Type,
|
||||
)
|
||||
from uuid import uuid4
|
||||
import asyncio
|
||||
|
||||
# TODO(@aminediro): this is the only dependency to langchain package, we should remove it
|
||||
from langchain.retrievers import ContextualCompressionRetriever
|
||||
from langchain_cohere import CohereRerank
|
||||
from langchain_community.document_compressors import JinaRerank
|
||||
from langchain_core.callbacks import Callbacks
|
||||
from langchain_core.documents import BaseDocumentCompressor, Document
|
||||
from langchain_core.messages import BaseMessage
|
||||
from langchain_core.messages.ai import AIMessageChunk
|
||||
from langchain_core.vectorstores import VectorStore
|
||||
from langchain_core.prompts.base import BasePromptTemplate
|
||||
from langgraph.graph import START, END, StateGraph
|
||||
from langgraph.graph.message import add_messages
|
||||
from langgraph.types import Send
|
||||
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
import openai
|
||||
|
||||
from quivr_core.rag.entities.chat import ChatHistory
|
||||
from quivr_core.rag.entities.config import DefaultRerankers, NodeConfig, RetrievalConfig
|
||||
from quivr_core.llm import LLMEndpoint
|
||||
from quivr_core.llm_tools.llm_tools import LLMToolFactory
|
||||
from quivr_core.rag.entities.models import (
|
||||
ParsedRAGChunkResponse,
|
||||
QuivrKnowledge,
|
||||
)
|
||||
from quivr_core.rag.prompts import custom_prompts
|
||||
from quivr_core.rag.utils import (
|
||||
collect_tools,
|
||||
combine_documents,
|
||||
format_file_list,
|
||||
get_chunk_metadata,
|
||||
parse_chunk_response,
|
||||
)
|
||||
|
||||
logger = logging.getLogger("quivr_core")
|
||||
|
||||
|
||||
# Structured-output schema for the input-splitting step: the model's reply to
# SPLIT_PROMPT is parsed into this class (see `routing` / `routing_split`).
# Documented with comments rather than a class docstring so that the schema
# generated from the model is left untouched.
class SplittedInput(BaseModel):
    # Free-text reasoning requested before the value it justifies.
    instructions_reasoning: Optional[str] = Field(
        default=None,
        description="The reasoning that leads to identifying the user instructions to the system",
    )
    # Condensed instructions; None when the reply carries none.
    instructions: Optional[str] = Field(
        default=None, description="The instructions to the system"
    )

    tasks_reasoning: Optional[str] = Field(
        default=None,
        description="The reasoning that leads to identifying the explicit or implicit user tasks and questions",
    )
    # Defaults to a single placeholder task when the field is absent from the reply.
    tasks: Optional[List[str]] = Field(
        default_factory=lambda: ["No explicit or implicit tasks found"],
        description="The list of standalone, self-contained tasks or questions.",
    )
|
||||
|
||||
|
||||
# Structured-output schema for the tool-routing step: the model's reply to
# TOOL_ROUTING_PROMPT is parsed into this class (see `tool_routing`).
# Documented with comments rather than a class docstring so that the schema
# generated from the model is left untouched.
class TasksCompletion(BaseModel):
    completable_tasks_reasoning: Optional[str] = Field(
        default=None,
        description="The reasoning that leads to identifying the user tasks or questions that can be completed using the provided context and chat history.",
    )
    completable_tasks: Optional[List[str]] = Field(
        default_factory=list,
        description="The user tasks or questions that can be completed using the provided context and chat history.",
    )

    non_completable_tasks_reasoning: Optional[str] = Field(
        default=None,
        description="The reasoning that leads to identifying the user tasks or questions that cannot be completed using the provided context and chat history.",
    )
    non_completable_tasks: Optional[List[str]] = Field(
        default_factory=list,
        description="The user tasks or questions that need a tool to be completed.",
    )

    tool_reasoning: Optional[str] = Field(
        default=None,
        description="The reasoning that leads to identifying the tool that shall be used to complete the tasks.",
    )
    # `tool` is a single tool name, so its "nothing selected" default is None.
    # It previously used default_factory=list, which produced `[]` for a field
    # annotated as Optional[str]. Both values are falsy, so the
    # `response.tool` check in `tool_routing` behaves the same.
    tool: Optional[str] = Field(
        default=None,
        description="The tool that shall be used to complete the tasks.",
    )
|
||||
|
||||
|
||||
# Structured-output schema describing a final answer together with the
# reasoning behind it and a completion flag.
# Documented with comments rather than a class docstring so that the schema
# generated from the model is left untouched.
class FinalAnswer(BaseModel):
    # All three fields are required (no defaults).
    reasoning_answer: str = Field(
        description="The step-by-step reasoning that led to the final answer"
    )
    answer: str = Field(description="The final answer to the user tasks/questions")

    # The description is built from adjacent string literals; the leading
    # spaces inside the 2nd and 3rd literals are part of the text.
    all_tasks_completed: bool = Field(
        description="Whether all tasks/questions have been successfully answered/completed or not. "
        " If the final answer to the user is 'I don't know' or 'I don't have enough information' or 'I'm not sure', "
        " this variable should be 'false'"
    )
|
||||
|
||||
|
||||
# Structured-output schema for the system-prompt update step: the model's
# reply to UPDATE_PROMPT is parsed into this class (see `edit_system_prompt`)
# and then applied by `update_active_tools`.
# Documented with comments rather than a class docstring so that the schema
# generated from the model is left untouched.
class UpdatedPromptAndTools(BaseModel):
    prompt_reasoning: Optional[str] = Field(
        default=None,
        description="The step-by-step reasoning that leads to the updated system prompt",
    )
    # `edit_system_prompt` stores this value as the new retrieval_config.prompt.
    prompt: Optional[str] = Field(default=None, description="The updated system prompt")

    tools_reasoning: Optional[str] = Field(
        default=None,
        description="The reasoning that leads to activating and deactivating the tools",
    )
    # Tool names; `update_active_tools` matches them against each tool's `.name`.
    tools_to_activate: Optional[List[str]] = Field(
        default_factory=list, description="The list of tools to activate"
    )
    tools_to_deactivate: Optional[List[str]] = Field(
        default_factory=list, description="The list of tools to deactivate"
    )
|
||||
|
||||
|
||||
class AgentState(TypedDict):
    """State dictionary handed from node to node in the LangGraph workflow."""

    # The add_messages function defines how an update should be processed
    # Default is to replace. add_messages says "append"
    messages: Annotated[Sequence[BaseMessage], add_messages]
    # Reasoning strings collected by `edit_system_prompt`.
    reasoning: List[str]
    # Conversation history; trimmed by `filter_history`.
    chat_history: ChatHistory
    # Chunks gathered by `retrieve` / `dynamic_retrieve` / `run_tool`.
    docs: list[Document]
    # Formatted file list built in `answer_astream`.
    files: str
    # Standalone tasks produced by the split / rewrite steps.
    tasks: List[str]
    # Instructions extracted from the user input by the split step.
    instructions: str
    # Name of the tool selected by `tool_routing`.
    tool: str
|
||||
|
||||
|
||||
class IdempotentCompressor(BaseDocumentCompressor):
    def compress_documents(
        self,
        documents: Sequence[Document],
        query: str,
        callbacks: Optional[Callbacks] = None,
    ) -> Sequence[Document]:
        """
        Pass-through compressor: hand back `documents` exactly as received.

        `query` and `callbacks` are accepted to satisfy the compressor
        interface but are ignored. `get_reranker` falls back to this class
        when the configured supplier is neither Cohere nor Jina.
        """
        return documents
|
||||
|
||||
|
||||
class QuivrQARAGLangGraph:
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
retrieval_config: RetrievalConfig,
|
||||
llm: LLMEndpoint,
|
||||
vector_store: VectorStore | None = None,
|
||||
):
|
||||
"""
|
||||
Construct a QuivrQARAGLangGraph object.
|
||||
|
||||
Args:
|
||||
retrieval_config (RetrievalConfig): The configuration for the RAG model.
|
||||
llm (LLMEndpoint): The LLM to use for generating text.
|
||||
vector_store (VectorStore): The vector store to use for storing and retrieving documents.
|
||||
reranker (BaseDocumentCompressor | None): The document compressor to use for re-ranking documents. Defaults to IdempotentCompressor if not provided.
|
||||
"""
|
||||
self.retrieval_config = retrieval_config
|
||||
self.vector_store = vector_store
|
||||
self.llm_endpoint = llm
|
||||
|
||||
self.graph = None
|
||||
|
||||
def get_reranker(self, **kwargs):
    """
    Build the reranker described by `retrieval_config.reranker_config`.

    `supplier`, `model`, `top_n` and `api_key` may be overridden through
    kwargs; any remaining kwargs are forwarded to the Cohere / Jina
    constructor. Any other supplier yields an `IdempotentCompressor`.
    """
    defaults = self.retrieval_config.reranker_config

    # Pop the overrides so that only genuinely extra kwargs are forwarded.
    supplier = kwargs.pop("supplier", defaults.supplier)
    model = kwargs.pop("model", defaults.model)
    top_n = kwargs.pop("top_n", defaults.top_n)
    api_key = kwargs.pop("api_key", defaults.api_key)

    if supplier == DefaultRerankers.COHERE:
        return CohereRerank(
            model=model, top_n=top_n, cohere_api_key=api_key, **kwargs
        )

    if supplier == DefaultRerankers.JINA:
        return JinaRerank(model=model, top_n=top_n, jina_api_key=api_key, **kwargs)

    return IdempotentCompressor()
|
||||
|
||||
def get_retriever(self, **kwargs):
|
||||
"""
|
||||
Returns a retriever that can retrieve documents from the vector store.
|
||||
|
||||
Returns:
|
||||
VectorStoreRetriever: The retriever.
|
||||
"""
|
||||
if self.vector_store:
|
||||
retriever = self.vector_store.as_retriever(**kwargs)
|
||||
else:
|
||||
raise ValueError("No vector store provided")
|
||||
|
||||
return retriever
|
||||
|
||||
def routing(self, state: AgentState) -> List[Send]:
    """
    Split the user input and decide which node to dispatch to.

    The first message of the state is split into instructions and tasks
    with SPLIT_PROMPT. If instructions are found (or a default prompt is
    configured) the flow goes to `edit_system_prompt`; otherwise, if tasks
    are found, it goes to `filter_history`.

    Args:
        state (AgentState): The current state of the agent.

    Returns:
        List[Send]: The dispatches to perform; empty when neither
        instructions nor tasks are available.
    """
    # SPLIT_PROMPT declares a `chat_history` messages placeholder, so the
    # history has to be supplied alongside the user input when formatting.
    msg = custom_prompts.SPLIT_PROMPT.format(
        chat_history=state["chat_history"].to_list(),
        user_input=state["messages"][0].content,
    )

    # Shared helper: tries the "json_schema" method first and falls back to
    # the default structured-output method on openai.BadRequestError.
    response: SplittedInput = self.invoke_structured_output(msg, SplittedInput)

    send_list: List[Send] = []

    instructions = response.instructions or self.retrieval_config.prompt

    if instructions:
        send_list.append(Send("edit_system_prompt", {"instructions": instructions}))
    elif response.tasks:
        send_list.append(
            Send(
                "filter_history",
                {"chat_history": state["chat_history"], "tasks": response.tasks},
            )
        )

    return send_list
|
||||
|
||||
def routing_split(self, state: AgentState):
    """
    Split the user input into instructions and tasks, then route.

    Returns a one-element list of `Send`: to `edit_system_prompt` when
    instructions (or a configured default prompt) exist, else to
    `filter_history` when tasks exist. Returns an empty list otherwise.
    The full state is forwarded with the extracted values merged in.
    """
    split_message = custom_prompts.SPLIT_PROMPT.format(
        chat_history=state["chat_history"].to_list(),
        user_input=state["messages"][0].content,
    )
    response = self.invoke_structured_output(split_message, SplittedInput)

    tasks = response.tasks if response.tasks else []
    instructions = (
        response.instructions
        if response.instructions
        else self.retrieval_config.prompt
    )

    if instructions:
        payload = {**state, "instructions": instructions, "tasks": tasks}
        return [Send("edit_system_prompt", payload)]

    if tasks:
        return [Send("filter_history", {**state, "tasks": tasks})]

    return []
|
||||
|
||||
def update_active_tools(self, updated_prompt_and_tools: UpdatedPromptAndTools):
|
||||
if updated_prompt_and_tools.tools_to_activate:
|
||||
for tool in updated_prompt_and_tools.tools_to_activate:
|
||||
for (
|
||||
validated_tool
|
||||
) in self.retrieval_config.workflow_config.validated_tools:
|
||||
if tool == validated_tool.name:
|
||||
self.retrieval_config.workflow_config.activated_tools.append(
|
||||
validated_tool
|
||||
)
|
||||
|
||||
if updated_prompt_and_tools.tools_to_deactivate:
|
||||
for tool in updated_prompt_and_tools.tools_to_deactivate:
|
||||
for (
|
||||
activated_tool
|
||||
) in self.retrieval_config.workflow_config.activated_tools:
|
||||
if tool == activated_tool.name:
|
||||
self.retrieval_config.workflow_config.activated_tools.remove(
|
||||
activated_tool
|
||||
)
|
||||
|
||||
def edit_system_prompt(self, state: AgentState) -> AgentState:
    """
    Update the system prompt and the active tools from the user instructions.

    Formats UPDATE_PROMPT with the instruction, the current prompt and the
    available / activated tools, parses the reply as
    `UpdatedPromptAndTools`, applies the tool changes, and stores the
    returned prompt on the retrieval config.

    Returns:
        The state with no new messages and the collected reasoning strings.
    """
    user_instruction = state["instructions"]
    current_prompt = self.retrieval_config.prompt
    available_tools, activated_tools = collect_tools(
        self.retrieval_config.workflow_config
    )

    msg = custom_prompts.UPDATE_PROMPT.format(
        instruction=user_instruction,
        system_prompt=current_prompt if current_prompt else "",
        available_tools=available_tools,
        activated_tools=activated_tools,
    )

    response: UpdatedPromptAndTools = self.invoke_structured_output(
        msg, UpdatedPromptAndTools
    )

    self.update_active_tools(response)
    self.retrieval_config.prompt = response.prompt

    # Keep only the reasoning strings that were actually provided.
    reasoning = [
        text
        for text in (response.prompt_reasoning, response.tools_reasoning)
        if text
    ]

    return {**state, "messages": [], "reasoning": reasoning}
|
||||
|
||||
def filter_history(self, state: AgentState) -> AgentState:
    """
    Trim the chat history to the most recent exchanges that fit the budget.

    Human/AI pairs are visited from newest to oldest and copied into a
    fresh `ChatHistory` until adding a pair would exceed
    `llm_config.max_context_tokens` or `max_history` pairs have been kept.
    Token counts come from `llm_endpoint.count_tokens`.

    Returns:
        The state with `chat_history` replaced by the trimmed history.
    """
    full_history = state["chat_history"]
    trimmed_history = ChatHistory(chat_id=uuid4(), brain_id=full_history.brain_id)

    used_tokens = 0
    kept_pairs = 0
    count_tokens = self.llm_endpoint.count_tokens

    for human_message, ai_message in reversed(list(full_history.iter_pairs())):
        # TODO: replace with tiktoken
        pair_tokens = count_tokens(human_message.content) + count_tokens(
            ai_message.content
        )

        exceeds_tokens = (
            used_tokens + pair_tokens
            > self.retrieval_config.llm_config.max_context_tokens
        )
        if exceeds_tokens or kept_pairs >= self.retrieval_config.max_history:
            break

        trimmed_history.append(human_message)
        trimmed_history.append(ai_message)
        used_tokens += pair_tokens
        kept_pairs += 1

    return {**state, "chat_history": trimmed_history}
|
||||
|
||||
### Nodes
|
||||
async def rewrite(self, state: AgentState) -> AgentState:
    """
    Condense each task into a standalone question using the chat history.

    Falls back to the first message's content when the state has no tasks.
    All tasks are sent to the LLM concurrently.

    Args:
        state (AgentState): The current state.

    Returns:
        The state with `tasks` replaced by the condensed questions.
    """
    if "tasks" in state and state["tasks"]:
        tasks = state["tasks"]
    else:
        tasks = [state["messages"][0].content]

    model = self.llm_endpoint._llm

    # One pending LLM call per user task.
    pending_calls = [
        model.ainvoke(
            custom_prompts.CONDENSE_QUESTION_PROMPT.format(
                chat_history=state["chat_history"].to_list(),
                question=task,
            )
        )
        for task in tasks
    ]

    # Run them concurrently; gather keeps the order of the tasks.
    responses = await asyncio.gather(*pending_calls) if pending_calls else []

    condensed_questions = [response.content for response in responses]

    return {**state, "tasks": condensed_questions}
|
||||
|
||||
def filter_chunks_by_relevance(self, chunks: List[Document], **kwargs):
|
||||
config = self.retrieval_config.reranker_config
|
||||
relevance_score_threshold = kwargs.get(
|
||||
"relevance_score_threshold", config.relevance_score_threshold
|
||||
)
|
||||
|
||||
if relevance_score_threshold is None:
|
||||
return chunks
|
||||
|
||||
filtered_chunks = []
|
||||
for chunk in chunks:
|
||||
if config.relevance_score_key not in chunk.metadata:
|
||||
logger.warning(
|
||||
f"Relevance score key {config.relevance_score_key} not found in metadata, cannot filter chunks by relevance"
|
||||
)
|
||||
filtered_chunks.append(chunk)
|
||||
elif (
|
||||
chunk.metadata[config.relevance_score_key] >= relevance_score_threshold
|
||||
):
|
||||
filtered_chunks.append(chunk)
|
||||
|
||||
return filtered_chunks
|
||||
|
||||
def tool_routing(self, state: AgentState):
    """
    Decide whether a tool is needed before generating the answer.

    The tasks, chat history, retrieved context and activated tools are fed
    to TOOL_ROUTING_PROMPT and the reply is parsed as `TasksCompletion`.
    If some tasks cannot be completed and a tool was selected, the flow
    goes to `run_tool` with those tasks; otherwise to `generate_rag`.

    Returns:
        List[Send]: A single dispatch to `run_tool` or `generate_rag`.
    """
    tasks = state["tasks"]
    if not tasks:
        return [Send("generate_rag", state)]

    docs = state["docs"]

    _, activated_tools = collect_tools(self.retrieval_config.workflow_config)

    # `prompt_inputs` rather than `input` to avoid shadowing the builtin.
    # The context is rendered with combine_documents, like the RAG answer
    # prompt inputs and like reduce_rag_context does when it trims docs, so
    # the prompt sees the same context format whether or not trimming occurs.
    prompt_inputs = {
        "chat_history": state["chat_history"].to_list(),
        "tasks": tasks,
        "context": combine_documents(docs) if docs else "None",
        "activated_tools": activated_tools,
    }

    prompt_inputs, _ = self.reduce_rag_context(
        inputs=prompt_inputs,
        prompt=custom_prompts.TOOL_ROUTING_PROMPT,
        docs=docs,
    )

    msg = custom_prompts.TOOL_ROUTING_PROMPT.format(**prompt_inputs)

    response: TasksCompletion = self.invoke_structured_output(msg, TasksCompletion)

    send_list: List[Send] = []

    if response.non_completable_tasks and response.tool:
        payload = {
            **state,
            "tasks": response.non_completable_tasks,
            "tool": response.tool,
        }
        send_list.append(Send("run_tool", payload))
    else:
        send_list.append(Send("generate_rag", state))

    return send_list
|
||||
|
||||
async def run_tool(self, state: AgentState) -> AgentState:
    """
    Run the selected tool once per task and add its output to the docs.

    Raises:
        ValueError: If the tool named in the state is not activated.

    Returns:
        The state with the filtered tool results appended to `docs`.
    """
    tool = state["tool"]
    activated_names = [
        activated.name
        for activated in self.retrieval_config.workflow_config.activated_tools
    ]
    if tool not in activated_names:
        raise ValueError(f"Tool {tool} not activated")

    tool_wrapper = LLMToolFactory.create_tool(tool, {})

    # One pending tool invocation per task, executed concurrently below.
    pending_calls = [
        tool_wrapper.tool.ainvoke(tool_wrapper.format_input(task))
        for task in state["tasks"]
    ]
    responses = await asyncio.gather(*pending_calls) if pending_calls else []

    tool_docs = []
    for response in responses:
        formatted = tool_wrapper.format_output(response)
        tool_docs += self.filter_chunks_by_relevance(formatted)

    return {**state, "docs": state["docs"] + tool_docs}
|
||||
|
||||
async def retrieve(self, state: AgentState) -> AgentState:
    """
    Retrieve relevant chunks for every task.

    Each task is sent concurrently through a retriever (k from the
    retrieval config) wrapped with the configured reranker; the results
    are filtered by relevance and concatenated.

    Args:
        state (AgentState): The current state.

    Returns:
        The state with `docs` set to the retrieved chunks (empty when there
        are no tasks).
    """
    tasks = state["tasks"]
    if not tasks:
        return {**state, "docs": []}

    base_retriever = self.get_retriever(
        search_kwargs={"k": self.retrieval_config.k}
    )
    reranker = self.get_reranker(top_n=self.retrieval_config.reranker_config.top_n)

    compression_retriever = ContextualCompressionRetriever(
        base_compressor=reranker, base_retriever=base_retriever
    )

    # One pending retrieval per task, executed concurrently.
    pending_calls = [compression_retriever.ainvoke(task) for task in tasks]
    responses = await asyncio.gather(*pending_calls) if pending_calls else []

    docs = []
    for response in responses:
        docs += self.filter_chunks_by_relevance(response)

    return {**state, "docs": docs}
|
||||
|
||||
async def dynamic_retrieve(self, state: AgentState) -> AgentState:
    """
    Retrieve relevant chunks, widening the search while results saturate.

    Starting from the configured reranker `top_n`, retrieval is repeated
    with `top_n` multiplied by 1, 2, 3, ... (and `k` at least twice
    `top_n`) for as long as some task still gets exactly `top_n` relevant
    chunks back. The loop also stops when nothing is retrieved or when the
    RAG prompt built from the chunks reaches `max_context_tokens`.

    Args:
        state (AgentState): The current state.

    Returns:
        The state with `docs` set to the chunks of the last iteration
        (empty when there are no tasks).
    """
    tasks = state["tasks"]
    if not tasks:
        return {**state, "docs": []}

    base_top_n = self.retrieval_config.reranker_config.top_n
    top_n = base_top_n
    # Equal to top_n on purpose so that the loop body runs at least once.
    number_of_relevant_chunks = top_n
    i = 1

    while number_of_relevant_chunks == top_n:
        top_n = base_top_n * i
        reranker = self.get_reranker(top_n=top_n)

        k = max(top_n * 2, self.retrieval_config.k)
        base_retriever = self.get_retriever(search_kwargs={"k": k})

        if i > 1:
            # Module logger, consistent with the rest of the file.
            logger.info(
                f"Increasing top_n to {top_n} and k to {k} to retrieve more relevant chunks"
            )

        compression_retriever = ContextualCompressionRetriever(
            base_compressor=reranker, base_retriever=base_retriever
        )

        # One pending retrieval per task, executed concurrently.
        async_tasks = [compression_retriever.ainvoke(task) for task in tasks]
        responses = await asyncio.gather(*async_tasks) if async_tasks else []

        docs = []
        relevant_counts = []
        for response in responses:
            _docs = self.filter_chunks_by_relevance(response)
            relevant_counts.append(len(_docs))
            docs += _docs

        if not docs:
            break

        context_length = self.get_rag_context_length(state, docs)
        if context_length >= self.retrieval_config.llm_config.max_context_tokens:
            logger.warning(
                f"The context length is {context_length} which is greater than "
                f"the max context tokens of {self.retrieval_config.llm_config.max_context_tokens}"
            )
            break

        # Saturation check: did any task fill its whole top_n quota?
        number_of_relevant_chunks = max(relevant_counts)
        i += 1

    return {**state, "docs": docs}
|
||||
|
||||
def get_rag_context_length(self, state: AgentState, docs: List[Document]) -> int:
    """Return the token count of RAG_ANSWER_PROMPT formatted with `docs`."""
    prompt_inputs = self._build_rag_prompt_inputs(state, docs)
    return self.llm_endpoint.count_tokens(
        custom_prompts.RAG_ANSWER_PROMPT.format(**prompt_inputs)
    )
|
||||
|
||||
def reduce_rag_context(
|
||||
self,
|
||||
inputs: Dict[str, Any],
|
||||
prompt: BasePromptTemplate,
|
||||
docs: List[Document] | None = None,
|
||||
max_context_tokens: int | None = None,
|
||||
) -> Tuple[Dict[str, Any], List[Document] | None]:
|
||||
MAX_ITERATION = 100
|
||||
SECURITY_FACTOR = 0.85
|
||||
iteration = 0
|
||||
|
||||
msg = prompt.format(**inputs)
|
||||
n = self.llm_endpoint.count_tokens(msg)
|
||||
|
||||
max_context_tokens = (
|
||||
max_context_tokens
|
||||
if max_context_tokens
|
||||
else self.retrieval_config.llm_config.max_context_tokens
|
||||
)
|
||||
|
||||
while n > max_context_tokens * SECURITY_FACTOR:
|
||||
chat_history = inputs["chat_history"] if "chat_history" in inputs else []
|
||||
|
||||
if len(chat_history) > 0:
|
||||
inputs["chat_history"] = chat_history[2:]
|
||||
elif docs and len(docs) > 1:
|
||||
docs = docs[:-1]
|
||||
else:
|
||||
logging.warning(
|
||||
f"Not enough context to reduce. The context length is {n} "
|
||||
f"which is greater than the max context tokens of {max_context_tokens}"
|
||||
)
|
||||
break
|
||||
|
||||
if docs and "context" in inputs:
|
||||
inputs["context"] = combine_documents(docs)
|
||||
|
||||
msg = prompt.format(**inputs)
|
||||
n = self.llm_endpoint.count_tokens(msg)
|
||||
|
||||
iteration += 1
|
||||
if iteration > MAX_ITERATION:
|
||||
logging.warning(
|
||||
f"Attained the maximum number of iterations ({MAX_ITERATION})"
|
||||
)
|
||||
break
|
||||
|
||||
return inputs, docs
|
||||
|
||||
def bind_tools_to_llm(self, node_name: str):
|
||||
if self.llm_endpoint.supports_func_calling():
|
||||
tools = self.retrieval_config.workflow_config.get_node_tools(node_name)
|
||||
if tools: # Only bind tools if there are any available
|
||||
return self.llm_endpoint._llm.bind_tools(tools, tool_choice="any")
|
||||
return self.llm_endpoint._llm
|
||||
|
||||
def generate_rag(self, state: AgentState) -> AgentState:
    """
    Generate the answer from the retrieved documents.

    Builds the RAG_ANSWER_PROMPT inputs, trims them to the token budget,
    and invokes the LLM (with this node's tools bound when available).

    Returns:
        The state with the LLM response as the new message and `docs` set
        to the documents that survived trimming (empty list if none).
    """
    docs: List[Document] | None = state["docs"]

    reduced_inputs, docs = self.reduce_rag_context(
        self._build_rag_prompt_inputs(state, docs),
        custom_prompts.RAG_ANSWER_PROMPT,
        docs,
    )

    msg = custom_prompts.RAG_ANSWER_PROMPT.format(**reduced_inputs)
    response = self.bind_tools_to_llm(self.generate_rag.__name__).invoke(msg)

    return {**state, "messages": [response], "docs": docs or []}
|
||||
|
||||
def generate_chat_llm(self, state: AgentState) -> AgentState:
    """
    Answer the user question with the LLM alone (no retrieved context).

    Args:
        state (AgentState): The current state.

    Returns:
        The state with the LLM response as the new message.
    """
    custom_instructions = self.retrieval_config.prompt

    prompt_inputs = {
        "question": state["messages"][0].content,
        "custom_instructions": custom_instructions
        if custom_instructions
        else "None",
        "chat_history": state["chat_history"].to_list(),
    }

    # Trim the chat history if the prompt would exceed the token budget.
    reduced_inputs, _ = self.reduce_rag_context(
        prompt_inputs, custom_prompts.CHAT_LLM_PROMPT, None
    )

    msg = custom_prompts.CHAT_LLM_PROMPT.format(**reduced_inputs)
    response = self.llm_endpoint._llm.invoke(msg)

    return {**state, "messages": [response]}
|
||||
|
||||
def build_chain(self):
|
||||
"""
|
||||
Builds the langchain chain for the given configuration.
|
||||
|
||||
Returns:
|
||||
Callable[[Dict], Dict]: The langchain chain.
|
||||
"""
|
||||
if not self.graph:
|
||||
self.graph = self.create_graph()
|
||||
|
||||
return self.graph
|
||||
|
||||
def create_graph(self):
    """Build the workflow from the config and return the compiled graph."""
    # Reset the list of nodes that lead to END; `_add_node_edges` fills it.
    self.final_nodes = []

    workflow = StateGraph(AgentState)
    self._build_workflow(workflow)

    return workflow.compile()
|
||||
|
||||
def _build_workflow(self, workflow: StateGraph):
    """Register every configured node on `workflow`, then wire their edges."""
    configured_nodes = self.retrieval_config.workflow_config.nodes

    # First pass: add the nodes. START and END are not added as nodes.
    for node in configured_nodes:
        if node.name in [START, END]:
            continue
        # The node's callable is the method of this class with the same name.
        workflow.add_node(node.name, getattr(self, node.name))

    # Second pass: add edges once all nodes are registered.
    for node in configured_nodes:
        self._add_node_edges(workflow, node)
|
||||
|
||||
def _add_node_edges(self, workflow: StateGraph, node: NodeConfig):
    """
    Add the outgoing edges of `node` to `workflow`.

    Plain edges take precedence over a conditional edge. Any node that can
    reach END is recorded in `self.final_nodes`.

    Raises:
        ValueError: If the node has neither edges nor a conditional edge.
    """
    if node.edges:
        for target in node.edges:
            workflow.add_edge(node.name, target)
            if target == END:
                self.final_nodes.append(node.name)
        return

    if node.conditional_edge:
        conditional = node.conditional_edge
        # The routing function is looked up by name on this class.
        routing_function = getattr(self, conditional.routing_function)
        workflow.add_conditional_edges(
            node.name, routing_function, conditional.conditions
        )
        if END in conditional.conditions:
            self.final_nodes.append(node.name)
        return

    raise ValueError("Node should have at least one edge or conditional_edge")
|
||||
|
||||
async def answer_astream(
    self,
    question: str,
    history: ChatHistory,
    list_files: list[QuivrKnowledge],
    metadata: dict[str, str] | None = None,
) -> AsyncGenerator[ParsedRAGChunkResponse, ParsedRAGChunkResponse]:
    """
    Answer a question using the langgraph chain and yield each chunk of the answer separately.

    Args:
        question: The user question.
        history: The chat history passed to the graph.
        list_files: Files to mention in the prompt (capped by `max_files`).
        metadata: Run metadata forwarded in the graph config. Defaults to
            an empty dict created per call.

    Yields:
        One `ParsedRAGChunkResponse` per new piece of content streamed by a
        final node, then a last empty chunk with `last_chunk=True` carrying
        the final metadata.
    """
    # A fresh dict per call avoids sharing one mutable default across calls.
    run_metadata = metadata if metadata is not None else {}

    concat_list_files = format_file_list(
        list_files, self.retrieval_config.max_files
    )
    conversational_qa_chain = self.build_chain()

    rolling_message = AIMessageChunk(content="")
    docs: list[Document] | None = None
    previous_content = ""

    async for event in conversational_qa_chain.astream_events(
        {
            "messages": [("user", question)],
            "chat_history": history,
            "files": concat_list_files,
        },
        version="v1",
        config={"metadata": run_metadata},
    ):
        # Remember the docs emitted by a final node for the chunk metadata.
        if self._is_final_node_with_docs(event):
            docs = event["data"]["output"]["docs"]

        if self._is_final_node_and_chat_model_stream(event):
            chunk = event["data"]["chunk"]
            rolling_message, new_content, previous_content = parse_chunk_response(
                rolling_message,
                chunk,
                self.llm_endpoint.supports_func_calling(),
                previous_content,
            )

            if new_content:
                chunk_metadata = get_chunk_metadata(rolling_message, docs)
                yield ParsedRAGChunkResponse(
                    answer=new_content, metadata=chunk_metadata
                )

    # Yield final metadata chunk
    yield ParsedRAGChunkResponse(
        answer="",
        metadata=get_chunk_metadata(rolling_message, docs),
        last_chunk=True,
    )
|
||||
|
||||
def _is_final_node_with_docs(self, event: dict) -> bool:
|
||||
return (
|
||||
"output" in event["data"]
|
||||
and event["data"]["output"] is not None
|
||||
and "docs" in event["data"]["output"]
|
||||
and event["metadata"]["langgraph_node"] in self.final_nodes
|
||||
)
|
||||
|
||||
def _is_final_node_and_chat_model_stream(self, event: dict) -> bool:
|
||||
return (
|
||||
event["event"] == "on_chat_model_stream"
|
||||
and "langgraph_node" in event["metadata"]
|
||||
and event["metadata"]["langgraph_node"] in self.final_nodes
|
||||
)
|
||||
|
||||
def invoke_structured_output(
    self, prompt: str, output_class: Type[BaseModel]
) -> Any:
    """Invoke the LLM and return its output structured as ``output_class``.

    The "json_schema" structured-output method is tried first. If that raises
    ``openai.BadRequestError``, the call is retried once with the default
    structured-output method.

    Args:
        prompt: The prompt passed to the structured LLM.
        output_class: The pydantic model describing the expected output.

    Returns:
        Whatever the structured LLM's ``invoke`` returns.
    """
    try:
        return self.llm_endpoint._llm.with_structured_output(
            output_class, method="json_schema"
        ).invoke(prompt)
    except openai.BadRequestError:
        fallback_llm = self.llm_endpoint._llm.with_structured_output(output_class)
        return fallback_llm.invoke(prompt)
|
||||
|
||||
def _build_rag_prompt_inputs(
    self, state: AgentState, docs: List[Document] | None
) -> Dict[str, Any]:
    """Assemble the input mapping consumed by RAG_ANSWER_PROMPT.

    Args:
        state: Current agent state; "messages", "files", "tasks" and
            "chat_history" are read, plus "reasoning" when present.
        docs: Retrieved documents, or None / empty when there are none.

    Returns:
        Dictionary with the keys "context", "question", "rephrased_task",
        "custom_instructions", "files", "chat_history", "reasoning" and
        "tools". Missing context, instructions, files or reasoning are
        rendered as the literal string "None".
    """
    first_message = state["messages"][0]
    question_text = first_message.content
    file_listing = state["files"]
    custom_prompt = self.retrieval_config.prompt
    # Only the "available tools" description is used; the activated-tools
    # description is discarded.
    tools_overview, _unused = collect_tools(self.retrieval_config.workflow_config)

    prompt_inputs: Dict[str, Any] = {}
    prompt_inputs["context"] = combine_documents(docs) if docs else "None"
    prompt_inputs["question"] = question_text
    prompt_inputs["rephrased_task"] = state["tasks"]
    prompt_inputs["custom_instructions"] = custom_prompt or "None"
    prompt_inputs["files"] = file_listing or "None"
    prompt_inputs["chat_history"] = state["chat_history"].to_list()
    if "reasoning" in state:
        prompt_inputs["reasoning"] = state["reasoning"]
    else:
        prompt_inputs["reasoning"] = "None"
    prompt_inputs["tools"] = tools_overview
    return prompt_inputs
|
188
core/quivr_core/rag/utils.py
Normal file
188
core/quivr_core/rag/utils.py
Normal file
@ -0,0 +1,188 @@
|
||||
import logging
|
||||
from typing import Any, List, Tuple, no_type_check
|
||||
|
||||
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage
|
||||
from langchain_core.messages.ai import AIMessageChunk
|
||||
from langchain_core.prompts import format_document
|
||||
|
||||
from quivr_core.rag.entities.config import WorkflowConfig
|
||||
from quivr_core.rag.entities.models import (
|
||||
ChatLLMMetadata,
|
||||
ParsedRAGResponse,
|
||||
QuivrKnowledge,
|
||||
RAGResponseMetadata,
|
||||
RawRAGResponse,
|
||||
)
|
||||
from quivr_core.rag.prompts import custom_prompts
|
||||
|
||||
# TODO(@aminediro): define a types packages where we clearly define IO types
|
||||
# This should be used for serialization/deserialization later
|
||||
|
||||
|
||||
logger = logging.getLogger("quivr_core")
|
||||
|
||||
|
||||
def model_supports_function_calling(model_name: str):
    """Return whether ``model_name`` is assumed to support function calling.

    Every model is treated as supporting it unless its name exactly matches
    one of the known exceptions listed below.
    """
    unsupported_models: list[str] = ["llama2", "test", "ollama3"]

    if model_name in unsupported_models:
        return False
    return True
|
||||
|
||||
|
||||
def format_history_to_openai_mesages(
    tuple_history: List[Tuple[str, str]], system_message: str, question: str
) -> List[BaseMessage]:
    """Build the message list for a chat completion from a tuple history.

    Args:
        tuple_history: Past exchanges as (human text, AI text) pairs, oldest first.
        system_message: Content of the leading system message.
        question: The new user question, appended last.

    Returns:
        The system message, then a human/AI message pair per history entry,
        then the question as a final human message.
    """
    conversation: List[BaseMessage] = [SystemMessage(content=system_message)]
    for user_turn, assistant_turn in tuple_history:
        conversation.extend(
            (HumanMessage(content=user_turn), AIMessage(content=assistant_turn))
        )
    conversation.append(HumanMessage(content=question))
    return conversation
|
||||
|
||||
|
||||
def cited_answer_filter(tool):
    """Return True when the tool call mapping is named "cited_answer".

    Raises:
        KeyError: If ``tool`` has no "name" entry.
    """
    tool_name = tool["name"]
    return tool_name == "cited_answer"
|
||||
|
||||
|
||||
def get_chunk_metadata(
    msg: AIMessageChunk, sources: list[Any] | None = None
) -> RAGResponseMetadata:
    """Build the response metadata for the message accumulated so far.

    Args:
        msg: The accumulated message; only its ``tool_calls`` are inspected.
        sources: Source documents to attach; None is treated as no sources.

    Returns:
        Metadata holding the sources. When the message has tool calls, it also
        holds the citations and at most three follow-up questions gathered
        from every "cited_answer" tool call.
    """
    metadata = {"sources": sources or []}

    if not msg.tool_calls:
        return RAGResponseMetadata(**metadata, metadata_model=None)

    all_citations = []
    all_followup_questions = []

    for tool_call in msg.tool_calls:
        if tool_call.get("name") == "cited_answer" and "args" in tool_call:
            # `or` guards: a key that is present but null must count as
            # empty, otherwise list.extend(None) raises TypeError.
            args = tool_call["args"] or {}
            all_citations.extend(args.get("citations") or [])
            all_followup_questions.extend(args.get("followup_questions") or [])

    metadata["citations"] = all_citations
    metadata["followup_questions"] = all_followup_questions[:3]  # Limit to 3

    return RAGResponseMetadata(**metadata, metadata_model=None)
|
||||
|
||||
|
||||
def get_prev_message_str(msg: AIMessageChunk) -> str:
    """Return the answer text of the first "cited_answer" tool call, if any.

    Args:
        msg: The message whose ``tool_calls`` are searched.

    Returns:
        The "answer" argument of the first tool call named "cited_answer".
        An empty string when the message has no tool calls, none of them is a
        "cited_answer" call, or that call carries no answer yet.
    """
    # Supplying a default to next() avoids a StopIteration when tool calls
    # exist but none of them is "cited_answer".
    cited_answer = next(
        (
            tool_call
            for tool_call in msg.tool_calls or ()
            if tool_call.get("name") == "cited_answer"
        ),
        None,
    )
    if cited_answer is None:
        return ""

    args = cited_answer.get("args") or {}
    return args.get("answer", "")
|
||||
|
||||
|
||||
# TODO: CONVOLUTED LOGIC !
|
||||
# TODO(@aminediro): redo this
|
||||
@no_type_check
def parse_chunk_response(
    rolling_msg: AIMessageChunk,
    raw_chunk: AIMessageChunk,
    supports_func_calling: bool,
    previous_content: str = "",
) -> Tuple[AIMessageChunk, str, str]:
    """Fold a streamed chunk into the rolling message and extract the new text.

    Args:
        rolling_msg: The accumulated message so far
        raw_chunk: The new chunk to add
        supports_func_calling: Whether function calling is supported
        previous_content: The full answer text returned by the previous call

    Returns:
        Tuple of (updated rolling message, new content only, full content).
        Without function calling, or while no tool call has appeared, the new
        content is the chunk's own content. Otherwise the full content is the
        "cited_answer" answers joined by blank lines, and the new content is
        whatever follows the first ``len(previous_content)`` characters.
    """
    rolling_msg += raw_chunk

    if supports_func_calling and rolling_msg.tool_calls:
        answers_so_far = get_answers_from_tool_calls(rolling_msg.tool_calls)
        accumulated_answer = "\n\n".join(answers_so_far)
        already_sent = len(previous_content)
        return rolling_msg, accumulated_answer[already_sent:], accumulated_answer

    # Plain-text path: the chunk content is the increment itself.
    return rolling_msg, raw_chunk.content, rolling_msg.content
|
||||
|
||||
|
||||
def get_answers_from_tool_calls(tool_calls):
    """Collect the answers of all "cited_answer" tool calls, in order.

    Tool calls with another name or without an "args" entry are skipped.
    A "cited_answer" call whose args have no "answer" yet contributes "".
    """
    return [
        call["args"].get("answer", "")
        for call in tool_calls
        if call.get("name") == "cited_answer" and "args" in call
    ]
|
||||
|
||||
|
||||
@no_type_check
def parse_response(raw_response: RawRAGResponse, model_name: str) -> ParsedRAGResponse:
    """Turn a raw (non-streamed) RAG response into a parsed response.

    Args:
        raw_response: Mapping with an "answer" message and optionally "docs".
        model_name: Name of the model that produced the answer.

    Returns:
        A ``ParsedRAGResponse`` whose metadata carries the sources and model
        name. On the tool-call path, the answers, citations and follow-up
        questions of every tool call with "args" are gathered and the answers
        joined with newlines. Otherwise the answer is the message content.
    """
    # NOTE(review): the `"tool_calls" in <message>` membership test is kept
    # exactly as written; confirm it behaves as intended for message objects.
    retrieved_docs = raw_response["docs"] if "docs" in raw_response else []
    metadata = RAGResponseMetadata(
        sources=retrieved_docs, metadata_model=ChatLLMMetadata(name=model_name)
    )

    answer_parts = []
    use_tool_calls = (
        model_supports_function_calling(model_name)
        and "tool_calls" in raw_response["answer"]
        and raw_response["answer"].tool_calls
    )

    if not use_tool_calls:
        answer_parts.append(raw_response["answer"].content)
    else:
        gathered_citations = []
        gathered_followups = []
        for call in raw_response["answer"].tool_calls:
            if "args" not in call:
                continue
            call_args = call["args"]
            if "citations" in call_args:
                gathered_citations.extend(call_args["citations"])
            if "followup_questions" in call_args:
                gathered_followups.extend(call_args["followup_questions"])
            if "answer" in call_args:
                answer_parts.append(call_args["answer"])
        metadata.citations = gathered_citations
        metadata.followup_questions = gathered_followups

    return ParsedRAGResponse(answer="\n".join(answer_parts), metadata=metadata)
|
||||
|
||||
|
||||
def combine_documents(
    docs,
    document_prompt=custom_prompts.DEFAULT_DOCUMENT_PROMPT,
    document_separator="\n\n",
):
    """Render the documents with a prompt and join them into one string.

    Each document's metadata is mutated in place: its position in ``docs`` is
    stored under "index" so that sources can be cited by number.

    Args:
        docs: The documents to combine.
        document_prompt: Prompt used to format each document.
        document_separator: String placed between formatted documents.

    Returns:
        The formatted documents joined by ``document_separator``.
    """
    # Number every document first so the index is available when formatting.
    for position, document in enumerate(docs):
        document.metadata["index"] = position

    rendered = []
    for document in docs:
        rendered.append(format_document(document, document_prompt))
    return document_separator.join(rendered)
|
||||
|
||||
|
||||
def format_file_list(
    list_files_array: list[QuivrKnowledge], max_files: int = 20
) -> str:
    """Format the names of the given files as a newline-separated listing.

    Each file is represented by its ``file_name``, falling back to its ``url``.
    Files with neither are skipped, and only the first ``max_files`` names are
    kept.

    Args:
        list_files_array: The knowledge entries to list.
        max_files: Maximum number of names to include.

    Returns:
        The names joined by newlines, or the literal string "None" when there
        is no usable name to list. That covers an empty input, entries without
        a name or url, and ``max_files`` of 0.
    """
    candidate_names = (file.file_name or file.url for file in list_files_array)
    names: list[str] = [name for name in candidate_names if name]
    names = names[:max_files]

    # Decide on the names actually kept, not on the raw input, so a
    # non-empty input without any usable name still yields "None".
    return "\n".join(names) if names else "None"
|
||||
|
||||
|
||||
def collect_tools(workflow_config: WorkflowConfig):
    """Describe the tools of a workflow configuration as two text blocks.

    Args:
        workflow_config: Configuration exposing ``validated_tools`` and
            ``activated_tools``, whose items have ``name`` and ``description``.

    Returns:
        A tuple (validated tools text, activated tools text). Each text is a
        header line followed by a numbered name/description entry per tool.
    """

    def _describe(header, tools):
        # One "name" line and one "description" line per tool, numbered from 1.
        entries = [
            f"Tool {position} name: {tool.name}\n"
            f"Tool {position} description: {tool.description}\n\n"
            for position, tool in enumerate(tools, start=1)
        ]
        return header + "".join(entries)

    validated_text = _describe(
        "Available tools which can be activated:\n", workflow_config.validated_tools
    )
    activated_text = _describe(
        "Activated tools which can be deactivated:\n", workflow_config.activated_tools
    )
    return validated_text, activated_text
|
@ -1,167 +0,0 @@
|
||||
import logging
|
||||
from typing import Any, List, Tuple, no_type_check
|
||||
|
||||
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage
|
||||
from langchain_core.messages.ai import AIMessageChunk
|
||||
from langchain_core.prompts import format_document
|
||||
|
||||
from quivr_core.models import (
|
||||
ChatLLMMetadata,
|
||||
ParsedRAGResponse,
|
||||
QuivrKnowledge,
|
||||
RAGResponseMetadata,
|
||||
RawRAGResponse,
|
||||
)
|
||||
from quivr_core.prompts import custom_prompts
|
||||
|
||||
# TODO(@aminediro): define a types packages where we clearly define IO types
|
||||
# This should be used for serialization/deseriallization later
|
||||
|
||||
|
||||
logger = logging.getLogger("quivr_core")
|
||||
|
||||
|
||||
def model_supports_function_calling(model_name: str):
|
||||
models_supporting_function_calls = [
|
||||
"gpt-4",
|
||||
"gpt-4-1106-preview",
|
||||
"gpt-4-0613",
|
||||
"gpt-4o",
|
||||
"gpt-3.5-turbo-1106",
|
||||
"gpt-3.5-turbo-0613",
|
||||
"gpt-4-0125-preview",
|
||||
"gpt-3.5-turbo",
|
||||
"gpt-4-turbo",
|
||||
"gpt-4o",
|
||||
"gpt-4o-mini",
|
||||
]
|
||||
return model_name in models_supporting_function_calls
|
||||
|
||||
|
||||
def format_history_to_openai_mesages(
|
||||
tuple_history: List[Tuple[str, str]], system_message: str, question: str
|
||||
) -> List[BaseMessage]:
|
||||
"""Format the chat history into a list of Base Messages"""
|
||||
messages = []
|
||||
messages.append(SystemMessage(content=system_message))
|
||||
for human, ai in tuple_history:
|
||||
messages.append(HumanMessage(content=human))
|
||||
messages.append(AIMessage(content=ai))
|
||||
messages.append(HumanMessage(content=question))
|
||||
return messages
|
||||
|
||||
|
||||
def cited_answer_filter(tool):
|
||||
return tool["name"] == "cited_answer"
|
||||
|
||||
|
||||
def get_chunk_metadata(
|
||||
msg: AIMessageChunk, sources: list[Any] | None = None
|
||||
) -> RAGResponseMetadata:
|
||||
# Initiate the source
|
||||
metadata = {"sources": sources} if sources else {"sources": []}
|
||||
if msg.tool_calls:
|
||||
cited_answer = next(x for x in msg.tool_calls if cited_answer_filter(x))
|
||||
|
||||
if "args" in cited_answer:
|
||||
gathered_args = cited_answer["args"]
|
||||
if "citations" in gathered_args:
|
||||
citations = gathered_args["citations"]
|
||||
metadata["citations"] = citations
|
||||
|
||||
if "followup_questions" in gathered_args:
|
||||
followup_questions = gathered_args["followup_questions"]
|
||||
metadata["followup_questions"] = followup_questions
|
||||
|
||||
return RAGResponseMetadata(**metadata, metadata_model=None)
|
||||
|
||||
|
||||
def get_prev_message_str(msg: AIMessageChunk) -> str:
|
||||
if msg.tool_calls:
|
||||
cited_answer = next(x for x in msg.tool_calls if cited_answer_filter(x))
|
||||
if "args" in cited_answer and "answer" in cited_answer["args"]:
|
||||
return cited_answer["args"]["answer"]
|
||||
return ""
|
||||
|
||||
|
||||
# TODO: CONVOLUTED LOGIC !
|
||||
# TODO(@aminediro): redo this
|
||||
@no_type_check
|
||||
def parse_chunk_response(
|
||||
rolling_msg: AIMessageChunk,
|
||||
raw_chunk: dict[str, Any],
|
||||
supports_func_calling: bool,
|
||||
) -> Tuple[AIMessageChunk, str]:
|
||||
# Init with sources
|
||||
answer_str = ""
|
||||
|
||||
if "answer" in raw_chunk:
|
||||
answer = raw_chunk["answer"]
|
||||
else:
|
||||
answer = raw_chunk
|
||||
|
||||
rolling_msg += answer
|
||||
if supports_func_calling and rolling_msg.tool_calls:
|
||||
cited_answer = next(x for x in rolling_msg.tool_calls if cited_answer_filter(x))
|
||||
if "args" in cited_answer and "answer" in cited_answer["args"]:
|
||||
gathered_args = cited_answer["args"]
|
||||
# Only send the difference between answer and response_tokens which was the previous answer
|
||||
answer_str = gathered_args["answer"]
|
||||
return rolling_msg, answer_str
|
||||
|
||||
return rolling_msg, answer.content
|
||||
|
||||
|
||||
@no_type_check
|
||||
def parse_response(raw_response: RawRAGResponse, model_name: str) -> ParsedRAGResponse:
|
||||
answer = ""
|
||||
sources = raw_response["docs"] if "docs" in raw_response else []
|
||||
|
||||
metadata = RAGResponseMetadata(
|
||||
sources=sources, metadata_model=ChatLLMMetadata(name=model_name)
|
||||
)
|
||||
|
||||
if (
|
||||
model_supports_function_calling(model_name)
|
||||
and "tool_calls" in raw_response["answer"]
|
||||
and raw_response["answer"].tool_calls
|
||||
):
|
||||
if "citations" in raw_response["answer"].tool_calls[-1]["args"]:
|
||||
citations = raw_response["answer"].tool_calls[-1]["args"]["citations"]
|
||||
metadata.citations = citations
|
||||
followup_questions = raw_response["answer"].tool_calls[-1]["args"][
|
||||
"followup_questions"
|
||||
]
|
||||
if followup_questions:
|
||||
metadata.followup_questions = followup_questions
|
||||
answer = raw_response["answer"].tool_calls[-1]["args"]["answer"]
|
||||
else:
|
||||
answer = raw_response["answer"].tool_calls[-1]["args"]["answer"]
|
||||
else:
|
||||
answer = raw_response["answer"].content
|
||||
|
||||
parsed_response = ParsedRAGResponse(answer=answer, metadata=metadata)
|
||||
return parsed_response
|
||||
|
||||
|
||||
def combine_documents(
|
||||
docs,
|
||||
document_prompt=custom_prompts.DEFAULT_DOCUMENT_PROMPT,
|
||||
document_separator="\n\n",
|
||||
):
|
||||
# for each docs, add an index in the metadata to be able to cite the sources
|
||||
for doc, index in zip(docs, range(len(docs)), strict=False):
|
||||
doc.metadata["index"] = index
|
||||
doc_strings = [format_document(doc, document_prompt) for doc in docs]
|
||||
return document_separator.join(doc_strings)
|
||||
|
||||
|
||||
def format_file_list(
|
||||
list_files_array: list[QuivrKnowledge], max_files: int = 20
|
||||
) -> str:
|
||||
list_files = [file.file_name or file.url for file in list_files_array]
|
||||
files: list[str] = list(filter(lambda n: n is not None, list_files)) # type: ignore
|
||||
files = files[:max_files]
|
||||
|
||||
files_str = "\n".join(files) if list_files_array else "None"
|
||||
return files_str
|
@ -27,6 +27,8 @@ anyio==4.6.2.post1
|
||||
# via anthropic
|
||||
# via httpx
|
||||
# via openai
|
||||
appnope==0.1.4
|
||||
# via ipykernel
|
||||
asttokens==2.4.1
|
||||
# via stack-data
|
||||
attrs==24.2.0
|
||||
@ -80,8 +82,6 @@ frozenlist==1.4.1
|
||||
# via aiosignal
|
||||
fsspec==2024.9.0
|
||||
# via huggingface-hub
|
||||
greenlet==3.1.1
|
||||
# via sqlalchemy
|
||||
h11==0.14.0
|
||||
# via httpcore
|
||||
httpcore==1.0.6
|
||||
@ -278,6 +278,8 @@ pyyaml==6.0.2
|
||||
pyzmq==26.2.0
|
||||
# via ipykernel
|
||||
# via jupyter-client
|
||||
rapidfuzz==3.10.1
|
||||
# via quivr-core
|
||||
regex==2024.9.11
|
||||
# via tiktoken
|
||||
# via transformers
|
||||
|
@ -56,8 +56,6 @@ frozenlist==1.4.1
|
||||
# via aiosignal
|
||||
fsspec==2024.9.0
|
||||
# via huggingface-hub
|
||||
greenlet==3.1.1
|
||||
# via sqlalchemy
|
||||
h11==0.14.0
|
||||
# via httpcore
|
||||
httpcore==1.0.6
|
||||
@ -185,6 +183,8 @@ pyyaml==6.0.2
|
||||
# via langchain-community
|
||||
# via langchain-core
|
||||
# via transformers
|
||||
rapidfuzz==3.10.1
|
||||
# via quivr-core
|
||||
regex==2024.9.11
|
||||
# via tiktoken
|
||||
# via transformers
|
||||
|
@ -9,7 +9,7 @@ from langchain_core.language_models import FakeListChatModel
|
||||
from langchain_core.messages.ai import AIMessageChunk
|
||||
from langchain_core.runnables.utils import AddableDict
|
||||
from langchain_core.vectorstores import InMemoryVectorStore
|
||||
from quivr_core.config import LLMEndpointConfig
|
||||
from quivr_core.rag.entities.config import LLMEndpointConfig
|
||||
from quivr_core.files.file import FileExtension, QuivrFile
|
||||
from quivr_core.llm import LLMEndpoint
|
||||
|
||||
|
@ -5,10 +5,10 @@ from uuid import uuid4
|
||||
from langchain_core.embeddings import DeterministicFakeEmbedding
|
||||
from langchain_core.messages.ai import AIMessageChunk
|
||||
from langchain_core.vectorstores import InMemoryVectorStore
|
||||
from quivr_core.chat import ChatHistory
|
||||
from quivr_core.config import LLMEndpointConfig, RetrievalConfig
|
||||
from quivr_core.rag.entities.chat import ChatHistory
|
||||
from quivr_core.rag.entities.config import LLMEndpointConfig, RetrievalConfig
|
||||
from quivr_core.llm import LLMEndpoint
|
||||
from quivr_core.quivr_rag_langgraph import QuivrQARAGLangGraph
|
||||
from quivr_core.rag.quivr_rag_langgraph import QuivrQARAGLangGraph
|
||||
|
||||
|
||||
async def main():
|
||||
|
@ -27,9 +27,9 @@ retrieval_config:
|
||||
supplier: "openai"
|
||||
|
||||
# The model to use for the LLM for the given supplier
|
||||
model: "gpt-4o"
|
||||
model: "gpt-3.5-turbo-0125"
|
||||
|
||||
max_input_tokens: 2000
|
||||
max_context_tokens: 2000
|
||||
|
||||
# Maximum number of tokens to pass to the LLM
|
||||
# as a context to generate the answer
|
||||
|
@ -40,9 +40,9 @@ retrieval_config:
|
||||
supplier: "openai"
|
||||
|
||||
# The model to use for the LLM for the given supplier
|
||||
model: "gpt-4o"
|
||||
model: "gpt-3.5-turbo-0125"
|
||||
|
||||
max_input_tokens: 2000
|
||||
max_context_tokens: 2000
|
||||
|
||||
# Maximum number of tokens to pass to the LLM
|
||||
# as a context to generate the answer
|
||||
|
@ -5,7 +5,7 @@ import pytest
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.embeddings import Embeddings
|
||||
from quivr_core.brain import Brain
|
||||
from quivr_core.chat import ChatHistory
|
||||
from quivr_core.rag.entities.chat import ChatHistory
|
||||
from quivr_core.llm import LLMEndpoint
|
||||
from quivr_core.storage.local_storage import TransparentStorage
|
||||
|
||||
@ -104,8 +104,8 @@ async def test_brain_get_history(
|
||||
vector_db=mem_vector_store,
|
||||
)
|
||||
|
||||
brain.ask("question")
|
||||
brain.ask("question")
|
||||
await brain.aask("question")
|
||||
await brain.aask("question")
|
||||
|
||||
assert len(brain.default_chat) == 4
|
||||
|
||||
|
@ -3,7 +3,7 @@ from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from langchain_core.messages import AIMessage, HumanMessage
|
||||
from quivr_core.chat import ChatHistory
|
||||
from quivr_core.rag.entities.chat import ChatHistory
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
@ -1,18 +1,21 @@
|
||||
from quivr_core.config import LLMEndpointConfig, RetrievalConfig
|
||||
from quivr_core.rag.entities.config import LLMEndpointConfig, RetrievalConfig
|
||||
|
||||
|
||||
def test_default_llm_config():
|
||||
config = LLMEndpointConfig()
|
||||
|
||||
assert config.model_dump(exclude={"llm_api_key"}) == LLMEndpointConfig(
|
||||
model="gpt-4o",
|
||||
llm_base_url=None,
|
||||
llm_api_key=None,
|
||||
max_input_tokens=2000,
|
||||
max_output_tokens=2000,
|
||||
temperature=0.7,
|
||||
streaming=True,
|
||||
).model_dump(exclude={"llm_api_key"})
|
||||
assert (
|
||||
config.model_dump()
|
||||
== LLMEndpointConfig(
|
||||
model="gpt-4o",
|
||||
llm_base_url=None,
|
||||
llm_api_key=None,
|
||||
max_context_tokens=2000,
|
||||
max_output_tokens=2000,
|
||||
temperature=0.7,
|
||||
streaming=True,
|
||||
).model_dump()
|
||||
)
|
||||
|
||||
|
||||
def test_default_retrievalconfig():
|
||||
@ -20,6 +23,6 @@ def test_default_retrievalconfig():
|
||||
|
||||
assert config.max_files == 20
|
||||
assert config.prompt is None
|
||||
assert config.llm_config.model_dump(
|
||||
exclude={"llm_api_key"}
|
||||
) == LLMEndpointConfig().model_dump(exclude={"llm_api_key"})
|
||||
print("\n\n", config.llm_config, "\n\n")
|
||||
print("\n\n", LLMEndpointConfig(), "\n\n")
|
||||
assert config.llm_config == LLMEndpointConfig()
|
||||
|
@ -3,7 +3,7 @@ import os
|
||||
import pytest
|
||||
from langchain_core.language_models import FakeListChatModel
|
||||
from pydantic.v1.error_wrappers import ValidationError
|
||||
from quivr_core.config import LLMEndpointConfig
|
||||
from quivr_core.rag.entities.config import LLMEndpointConfig
|
||||
from quivr_core.llm import LLMEndpoint
|
||||
|
||||
|
||||
|
@ -1,25 +1,51 @@
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from quivr_core.chat import ChatHistory
|
||||
from quivr_core.config import LLMEndpointConfig, RetrievalConfig
|
||||
from quivr_core.rag.entities.chat import ChatHistory
|
||||
from quivr_core.rag.entities.config import LLMEndpointConfig, RetrievalConfig
|
||||
from quivr_core.llm import LLMEndpoint
|
||||
from quivr_core.models import ParsedRAGChunkResponse, RAGResponseMetadata
|
||||
from quivr_core.quivr_rag_langgraph import QuivrQARAGLangGraph
|
||||
from quivr_core.rag.entities.models import ParsedRAGChunkResponse, RAGResponseMetadata
|
||||
from quivr_core.rag.quivr_rag_langgraph import QuivrQARAGLangGraph
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def mock_chain_qa_stream(monkeypatch, chunks_stream_answer):
|
||||
class MockQAChain:
|
||||
async def astream_events(self, *args, **kwargs):
|
||||
for c in chunks_stream_answer:
|
||||
default_metadata = {
|
||||
"langgraph_node": "generate",
|
||||
"is_final_node": False,
|
||||
"citations": None,
|
||||
"followup_questions": None,
|
||||
"sources": None,
|
||||
"metadata_model": None,
|
||||
}
|
||||
|
||||
# Send all chunks except the last one
|
||||
for chunk in chunks_stream_answer[:-1]:
|
||||
yield {
|
||||
"event": "on_chat_model_stream",
|
||||
"metadata": {"langgraph_node": "generate"},
|
||||
"data": {"chunk": c},
|
||||
"metadata": default_metadata,
|
||||
"data": {"chunk": chunk["answer"]},
|
||||
}
|
||||
|
||||
# Send the last chunk
|
||||
yield {
|
||||
"event": "end",
|
||||
"metadata": {
|
||||
"langgraph_node": "generate",
|
||||
"is_final_node": True,
|
||||
"citations": [],
|
||||
"followup_questions": None,
|
||||
"sources": [],
|
||||
"metadata_model": None,
|
||||
},
|
||||
"data": {"chunk": chunks_stream_answer[-1]["answer"]},
|
||||
}
|
||||
|
||||
def mock_qa_chain(*args, **kwargs):
|
||||
self = args[0]
|
||||
self.final_nodes = ["generate"]
|
||||
return MockQAChain()
|
||||
|
||||
monkeypatch.setattr(QuivrQARAGLangGraph, "build_chain", mock_qa_chain)
|
||||
@ -48,11 +74,13 @@ async def test_quivrqaraglanggraph(
|
||||
):
|
||||
stream_responses.append(resp)
|
||||
|
||||
# This assertion passed
|
||||
assert all(
|
||||
not r.last_chunk for r in stream_responses[:-1]
|
||||
), "Some chunks before last have last_chunk=True"
|
||||
assert stream_responses[-1].last_chunk
|
||||
|
||||
# Let's check this assertion
|
||||
for idx, response in enumerate(stream_responses[1:-1]):
|
||||
assert (
|
||||
len(response.answer) > 0
|
||||
|
@ -3,7 +3,7 @@ from uuid import uuid4
|
||||
import pytest
|
||||
from langchain_core.messages.ai import AIMessageChunk
|
||||
from langchain_core.messages.tool import ToolCall
|
||||
from quivr_core.utils import (
|
||||
from quivr_core.rag.utils import (
|
||||
get_prev_message_str,
|
||||
model_supports_function_calling,
|
||||
parse_chunk_response,
|
||||
@ -43,13 +43,9 @@ def test_get_prev_message_str():
|
||||
|
||||
def test_parse_chunk_response_nofunc_calling():
|
||||
rolling_msg = AIMessageChunk(content="")
|
||||
chunk = {
|
||||
"answer": AIMessageChunk(
|
||||
content="next ",
|
||||
)
|
||||
}
|
||||
chunk = AIMessageChunk(content="next ")
|
||||
for i in range(10):
|
||||
rolling_msg, parsed_chunk = parse_chunk_response(rolling_msg, chunk, False)
|
||||
rolling_msg, parsed_chunk, _ = parse_chunk_response(rolling_msg, chunk, False)
|
||||
assert rolling_msg.content == "next " * (i + 1)
|
||||
assert parsed_chunk == "next "
|
||||
|
||||
@ -70,8 +66,9 @@ def test_parse_chunk_response_func_calling(chunks_stream_answer):
|
||||
answer_str_history: list[str] = []
|
||||
|
||||
for chunk in chunks_stream_answer:
|
||||
# This is done
|
||||
rolling_msg, answer_str = parse_chunk_response(rolling_msg, chunk, True)
|
||||
# Extract the AIMessageChunk from the chunk dictionary
|
||||
chunk_msg = chunk["answer"] # Get the AIMessageChunk from the dict
|
||||
rolling_msg, answer_str, _ = parse_chunk_response(rolling_msg, chunk_msg, True)
|
||||
rolling_msgs_history.append(rolling_msg)
|
||||
answer_str_history.append(answer_str)
|
||||
|
||||
|
@ -5,7 +5,7 @@ from pathlib import Path
|
||||
|
||||
import dotenv
|
||||
from quivr_core import Brain
|
||||
from quivr_core.config import AssistantConfig
|
||||
from quivr_core.rag.entities.config import AssistantConfig
|
||||
from rich.traceback import install as rich_install
|
||||
|
||||
ConsoleOutputHandler = logging.StreamHandler()
|
||||
|
@ -1,7 +1,7 @@
|
||||
from langchain_core.embeddings import DeterministicFakeEmbedding
|
||||
from langchain_core.language_models import FakeListChatModel
|
||||
from quivr_core import Brain
|
||||
from quivr_core.config import LLMEndpointConfig
|
||||
from quivr_core.rag.entities.config import LLMEndpointConfig
|
||||
from quivr_core.llm.llm_endpoint import LLMEndpoint
|
||||
from rich.console import Console
|
||||
from rich.panel import Panel
|
||||
|
@ -7,6 +7,7 @@ authors = [
|
||||
]
|
||||
dependencies = [
|
||||
"quivr-core @ file:///${PROJECT_ROOT}/../../core",
|
||||
"python-dotenv>=1.0.1",
|
||||
]
|
||||
readme = "README.md"
|
||||
requires-python = ">= 3.11"
|
||||
|
@ -55,8 +55,6 @@ frozenlist==1.4.1
|
||||
# via aiosignal
|
||||
fsspec==2024.10.0
|
||||
# via huggingface-hub
|
||||
greenlet==3.1.1
|
||||
# via sqlalchemy
|
||||
h11==0.14.0
|
||||
# via httpcore
|
||||
httpcore==1.0.6
|
||||
@ -169,6 +167,7 @@ pydantic==2.9.2
|
||||
# via langsmith
|
||||
# via openai
|
||||
# via quivr-core
|
||||
# via sqlmodel
|
||||
pydantic-core==2.23.4
|
||||
# via cohere
|
||||
# via pydantic
|
||||
@ -176,6 +175,8 @@ pygments==2.18.0
|
||||
# via rich
|
||||
python-dateutil==2.9.0.post0
|
||||
# via pandas
|
||||
python-dotenv==1.0.1
|
||||
# via quivr-core
|
||||
pytz==2024.2
|
||||
# via pandas
|
||||
pyyaml==6.0.2
|
||||
@ -185,6 +186,8 @@ pyyaml==6.0.2
|
||||
# via langchain-core
|
||||
# via transformers
|
||||
quivr-core @ file:///${PROJECT_ROOT}/../../core
|
||||
rapidfuzz==3.10.1
|
||||
# via quivr-core
|
||||
regex==2024.9.11
|
||||
# via tiktoken
|
||||
# via transformers
|
||||
@ -215,6 +218,9 @@ sniffio==1.3.1
|
||||
sqlalchemy==2.0.36
|
||||
# via langchain
|
||||
# via langchain-community
|
||||
# via sqlmodel
|
||||
sqlmodel==0.0.22
|
||||
# via quivr-core
|
||||
tabulate==0.9.0
|
||||
# via langchain-cohere
|
||||
tenacity==8.5.0
|
||||
|
@ -55,8 +55,6 @@ frozenlist==1.4.1
|
||||
# via aiosignal
|
||||
fsspec==2024.10.0
|
||||
# via huggingface-hub
|
||||
greenlet==3.1.1
|
||||
# via sqlalchemy
|
||||
h11==0.14.0
|
||||
# via httpcore
|
||||
httpcore==1.0.6
|
||||
@ -169,6 +167,7 @@ pydantic==2.9.2
|
||||
# via langsmith
|
||||
# via openai
|
||||
# via quivr-core
|
||||
# via sqlmodel
|
||||
pydantic-core==2.23.4
|
||||
# via cohere
|
||||
# via pydantic
|
||||
@ -176,6 +175,8 @@ pygments==2.18.0
|
||||
# via rich
|
||||
python-dateutil==2.9.0.post0
|
||||
# via pandas
|
||||
python-dotenv==1.0.1
|
||||
# via quivr-core
|
||||
pytz==2024.2
|
||||
# via pandas
|
||||
pyyaml==6.0.2
|
||||
@ -185,6 +186,8 @@ pyyaml==6.0.2
|
||||
# via langchain-core
|
||||
# via transformers
|
||||
quivr-core @ file:///${PROJECT_ROOT}/../../core
|
||||
rapidfuzz==3.10.1
|
||||
# via quivr-core
|
||||
regex==2024.9.11
|
||||
# via tiktoken
|
||||
# via transformers
|
||||
@ -215,6 +218,9 @@ sniffio==1.3.1
|
||||
sqlalchemy==2.0.36
|
||||
# via langchain
|
||||
# via langchain-community
|
||||
# via sqlmodel
|
||||
sqlmodel==0.0.22
|
||||
# via quivr-core
|
||||
tabulate==0.9.0
|
||||
# via langchain-cohere
|
||||
tenacity==8.5.0
|
||||
|
@ -2,6 +2,10 @@ import tempfile
|
||||
|
||||
from quivr_core import Brain
|
||||
|
||||
import dotenv
|
||||
|
||||
dotenv.load_dotenv()
|
||||
|
||||
if __name__ == "__main__":
|
||||
with tempfile.NamedTemporaryFile(mode="w", suffix=".txt") as temp_file:
|
||||
temp_file.write("Gold is a liquid of blue-like colour.")
|
||||
@ -12,7 +16,5 @@ if __name__ == "__main__":
|
||||
file_paths=[temp_file.name],
|
||||
)
|
||||
|
||||
answer = brain.ask(
|
||||
"what is gold? asnwer in french"
|
||||
)
|
||||
print("answer QuivrQARAGLangGraph :", answer.answer)
|
||||
answer = brain.ask("what is gold? answer in french")
|
||||
print("answer QuivrQARAGLangGraph :", answer)
|
||||
|
@ -4,7 +4,7 @@ import tempfile
|
||||
from dotenv import load_dotenv
|
||||
from quivr_core import Brain
|
||||
from quivr_core.quivr_rag import QuivrQARAG
|
||||
from quivr_core.quivr_rag_langgraph import QuivrQARAGLangGraph
|
||||
from quivr_core.rag.quivr_rag_langgraph import QuivrQARAGLangGraph
|
||||
|
||||
|
||||
async def main():
|
Loading…
Reference in New Issue
Block a user