Mirror of https://github.com/StanGirard/quivr.git (synced 2024-11-22 03:13:00 +03:00)
feat: websearch, tool use, user intent, dynamic retrieval, multiple questions (#3424)
# Description

This PR includes far too many new features:

- detection of user intent (closes CORE-211)
- treating multiple questions in parallel (closes CORE-212)
- using the chat history when answering a question (closes CORE-213)
- filtering of retrieved chunks by relevance threshold (closes CORE-217)
- dynamic retrieval of chunks (closes CORE-218)
- enabling web search via Tavily (closes CORE-220)
- enabling the agent / assistant to activate tools when relevant to complete the user task (closes CORE-224)

Also closes CORE-205.

## Checklist before requesting a review

Please delete options that are not relevant.

- [ ] My code follows the style guidelines of this project
- [ ] I have performed a self-review of my code
- [ ] I have commented hard-to-understand areas
- [ ] I have ideally added tests that prove my fix is effective or that my feature works
- [ ] New and existing unit tests pass locally with my changes
- [ ] Any dependent changes have been merged

## Screenshots (if appropriate):

---------

Co-authored-by: Stan Girard <stan@quivr.app>
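Usage sketch (reviewer note, not part of the diff): the features above surface through the new async `Brain.aask` entry point added further down in this PR. A minimal example, assuming `Brain` is importable from the `quivr_core` package root and that the listed PDF files exist:

```python
import asyncio

from quivr_core import Brain  # assumed import path for the Brain class changed in this PR


async def main():
    # Build a brain from local files, then ask a question; retrieval behaviour
    # and chat history handling are driven by the optional RetrievalConfig.
    brain = await Brain.afrom_files(name="My Brain", file_paths=["file1.pdf", "file2.pdf"])
    answer = await brain.aask("What is the meaning of life?")
    print(answer.answer)


asyncio.run(main())
```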
This commit is contained in:
parent 5401c01ee2
commit 285fe5b960

.github/workflows/backend-core-tests.yml (vendored, 2 lines changed)
@@ -41,6 +41,4 @@ jobs:
      sudo apt-get update
      sudo apt-get install -y libmagic-dev poppler-utils libreoffice tesseract-ocr pandoc
      cd core
-     rye run python -c "from unstructured.nlp.tokenize import download_nltk_packages; download_nltk_packages()"
-     rye run python -c "import nltk;nltk.download('punkt_tab'); nltk.download('averaged_perceptron_tagger_eng')"
      rye test -p quivr-core
@@ -9,7 +9,7 @@ dependencies = [
    "pydantic>=2.8.2",
    "langchain-core>=0.2.38",
    "langchain>=0.2.14,<0.3.0",
-   "langgraph>=0.2.14",
+   "langgraph>=0.2.38",
    "httpx>=0.27.0",
    "rich>=13.7.1",
    "tiktoken>=0.7.0",
@@ -21,6 +21,7 @@ dependencies = [
    "types-pyyaml>=6.0.12.20240808",
    "transformers[sentencepiece]>=4.44.2",
    "faiss-cpu>=1.8.0.post1",
+   "rapidfuzz>=3.10.1",
]
readme = "README.md"
requires-python = ">= 3.11"
@@ -10,7 +10,9 @@ from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.vectorstores import VectorStore
+from quivr_core.rag.entities.models import ParsedRAGResponse
from langchain_openai import OpenAIEmbeddings
+from quivr_core.rag.quivr_rag import QuivrQARAG
from rich.console import Console
from rich.panel import Panel

@@ -22,19 +24,17 @@ from quivr_core.brain.serialization import (
    LocalStorageConfig,
    TransparentStorageConfig,
)
-from quivr_core.chat import ChatHistory
+from quivr_core.rag.entities.chat import ChatHistory
-from quivr_core.config import RetrievalConfig
+from quivr_core.rag.entities.config import RetrievalConfig
from quivr_core.files.file import load_qfile
from quivr_core.llm import LLMEndpoint
-from quivr_core.models import (
+from quivr_core.rag.entities.models import (
    ParsedRAGChunkResponse,
-   ParsedRAGResponse,
    QuivrKnowledge,
    SearchResult,
)
from quivr_core.processor.registry import get_processor_class
-from quivr_core.quivr_rag import QuivrQARAG
+from quivr_core.rag.quivr_rag_langgraph import QuivrQARAGLangGraph
-from quivr_core.quivr_rag_langgraph import QuivrQARAGLangGraph
from quivr_core.storage.local_storage import LocalStorage, TransparentStorage
from quivr_core.storage.storage_base import StorageBase

@@ -49,19 +49,15 @@ async def process_files(
    """
    Process files in storage.
    This function takes a StorageBase and return a list of langchain documents.

    Args:
        storage (StorageBase): The storage containing the files to process.
        skip_file_error (bool): Whether to skip files that cannot be processed.
        processor_kwargs (dict[str, Any]): Additional arguments for the processor.

    Returns:
        list[Document]: List of processed documents in the Langchain Document format.

    Raises:
        ValueError: If a file cannot be processed and skip_file_error is False.
        Exception: If no processor is found for a file of a specific type and skip_file_error is False.

    """

    knowledge = []
@@ -91,23 +87,17 @@ async def process_files(
class Brain:
    """
    A class representing a Brain.

    This class allows for the creation of a Brain, which is a collection of knowledge one wants to retrieve information from.

    A Brain is set to:

    * Store files in the storage of your choice (local, S3, etc.)
    * Process the files in the storage to extract text and metadata in a wide range of format.
    * Store the processed files in the vector store of your choice (FAISS, PGVector, etc.) - default to FAISS.
    * Create an index of the processed files.
    * Use the *Quivr* workflow for the retrieval augmented generation.

    A Brain is able to:

    * Search for information in the vector store.
    * Answer questions about the knowledges in the Brain.
    * Stream the answer to the question.

    Attributes:
        name (str): The name of the brain.
        id (UUID): The unique identifier of the brain.
@@ -115,16 +105,14 @@ class Brain:
        llm (LLMEndpoint): The language model used to generate the answer.
        vector_db (VectorStore): The vector store used to store the processed files.
        embedder (Embeddings): The embeddings used to create the index of the processed files.


    """

    def __init__(
        self,
        *,
        name: str,
-       id: UUID,
        llm: LLMEndpoint,
+       id: UUID | None = None,
        vector_db: VectorStore | None = None,
        embedder: Embeddings | None = None,
        storage: StorageBase | None = None,
@@ -156,19 +144,15 @@ class Brain:
    def load(cls, folder_path: str | Path) -> Self:
        """
        Load a brain from a folder path.

        Args:
            folder_path (str | Path): The path to the folder containing the brain.

        Returns:
            Brain: The brain loaded from the folder path.

        Example:
        ```python
        brain_loaded = Brain.load("path/to/brain")
        brain_loaded.print_info()
        ```

        """
        if isinstance(folder_path, str):
            folder_path = Path(folder_path)
@@ -217,16 +201,13 @@ class Brain:
            vector_db=vector_db,
        )

-   async def save(self, folder_path: str | Path) -> str:
+   async def save(self, folder_path: str | Path):
        """
        Save the brain to a folder path.

        Args:
            folder_path (str | Path): The path to the folder where the brain will be saved.

        Returns:
            str: The path to the folder where the brain was saved.

        Example:
        ```python
        await brain.save("path/to/brain")
@@ -324,10 +305,9 @@ class Brain:
        embedder: Embeddings | None = None,
        skip_file_error: bool = False,
        processor_kwargs: dict[str, Any] | None = None,
-   ) -> Self:
+   ):
        """
        Create a brain from a list of file paths.

        Args:
            name (str): The name of the brain.
            file_paths (list[str | Path]): The list of file paths to add to the brain.
@@ -337,10 +317,8 @@ class Brain:
            embedder (Embeddings | None): The embeddings used to create the index of the processed files.
            skip_file_error (bool): Whether to skip files that cannot be processed.
            processor_kwargs (dict[str, Any] | None): Additional arguments for the processor.

        Returns:
            Brain: The brain created from the file paths.

        Example:
        ```python
        brain = await Brain.afrom_files(name="My Brain", file_paths=["file1.pdf", "file2.pdf"])
@@ -429,7 +407,6 @@ class Brain:
    ) -> Self:
        """
        Create a brain from a list of langchain documents.

        Args:
            name (str): The name of the brain.
            langchain_documents (list[Document]): The list of langchain documents to add to the brain.
@@ -437,10 +414,8 @@ class Brain:
            storage (StorageBase): The storage used to store the files.
            llm (LLMEndpoint | None): The language model used to generate the answer.
            embedder (Embeddings | None): The embeddings used to create the index of the processed files.

        Returns:
            Brain: The brain created from the langchain documents.

        Example:
        ```python
        from langchain_core.documents import Document
@@ -449,6 +424,7 @@ class Brain:
        brain.print_info()
        ```
        """

        if llm is None:
            llm = default_llm()

@@ -481,16 +457,13 @@ class Brain:
    ) -> list[SearchResult]:
        """
        Search for relevant documents in the brain based on a query.

        Args:
            query (str | Document): The query to search for.
            n_results (int): The number of results to return.
            filter (Callable | Dict[str, Any] | None): The filter to apply to the search.
            fetch_n_neighbors (int): The number of neighbors to fetch.

        Returns:
            list[SearchResult]: The list of retrieved chunks.

        Example:
        ```python
        brain = Brain.from_files(name="My Brain", file_paths=["file1.pdf", "file2.pdf"])
@@ -517,57 +490,6 @@ class Brain:
        # add it to vectorstore
        raise NotImplementedError

-   def ask(
-       self,
-       question: str,
-       retrieval_config: RetrievalConfig | None = None,
-       rag_pipeline: Type[Union[QuivrQARAG, QuivrQARAGLangGraph]] | None = None,
-       list_files: list[QuivrKnowledge] | None = None,
-       chat_history: ChatHistory | None = None,
-   ) -> ParsedRAGResponse:
-       """
-       Ask a question to the brain and get a generated answer.

-       Args:
-           question (str): The question to ask.
-           retrieval_config (RetrievalConfig | None): The retrieval configuration (see RetrievalConfig docs).
-           rag_pipeline (Type[Union[QuivrQARAG, QuivrQARAGLangGraph]] | None): The RAG pipeline to use.
-           list_files (list[QuivrKnowledge] | None): The list of files to include in the RAG pipeline.
-           chat_history (ChatHistory | None): The chat history to use.

-       Returns:
-           ParsedRAGResponse: The generated answer.

-       Example:
-       ```python
-       brain = Brain.from_files(name="My Brain", file_paths=["file1.pdf", "file2.pdf"])
-       answer = brain.ask("What is the meaning of life?")
-       print(answer.answer)
-       ```
-       """
-       async def collect_streamed_response():
-           full_answer = ""
-           async for response in self.ask_streaming(
-               question=question,
-               retrieval_config=retrieval_config,
-               rag_pipeline=rag_pipeline,
-               list_files=list_files,
-               chat_history=chat_history
-           ):
-               full_answer += response.answer
-           return full_answer

-       # Run the async function in the event loop
-       loop = asyncio.get_event_loop()
-       full_answer = loop.run_until_complete(collect_streamed_response())

-       chat_history = self.default_chat if chat_history is None else chat_history
-       chat_history.append(HumanMessage(content=question))
-       chat_history.append(AIMessage(content=full_answer))

-       # Return the final response
-       return ParsedRAGResponse(answer=full_answer)

    async def ask_streaming(
        self,
        question: str,
@@ -578,24 +500,20 @@ class Brain:
    ) -> AsyncGenerator[ParsedRAGChunkResponse, ParsedRAGChunkResponse]:
        """
        Ask a question to the brain and get a streamed generated answer.

        Args:
            question (str): The question to ask.
            retrieval_config (RetrievalConfig | None): The retrieval configuration (see RetrievalConfig docs).
            rag_pipeline (Type[Union[QuivrQARAG, QuivrQARAGLangGraph]] | None): The RAG pipeline to use.
            list_files (list[QuivrKnowledge] | None): The list of files to include in the RAG pipeline.
            chat_history (ChatHistory | None): The chat history to use.

        Returns:
            AsyncGenerator[ParsedRAGChunkResponse, ParsedRAGChunkResponse]: The streamed generated answer.

        Example:
        ```python
        brain = Brain.from_files(name="My Brain", file_paths=["file1.pdf", "file2.pdf"])
        async for chunk in brain.ask_streaming("What is the meaning of life?"):
            print(chunk.answer)
        ```

        """
        llm = self.llm

@@ -630,3 +548,64 @@ class Brain:
            chat_history.append(AIMessage(content=full_answer))
            yield response

+   async def aask(
+       self,
+       question: str,
+       retrieval_config: RetrievalConfig | None = None,
+       rag_pipeline: Type[Union[QuivrQARAG, QuivrQARAGLangGraph]] | None = None,
+       list_files: list[QuivrKnowledge] | None = None,
+       chat_history: ChatHistory | None = None,
+   ) -> ParsedRAGResponse:
+       """
+       Synchronous version that asks a question to the brain and gets a generated answer.
+       Args:
+           question (str): The question to ask.
+           retrieval_config (RetrievalConfig | None): The retrieval configuration (see RetrievalConfig docs).
+           rag_pipeline (Type[Union[QuivrQARAG, QuivrQARAGLangGraph]] | None): The RAG pipeline to use.
+           list_files (list[QuivrKnowledge] | None): The list of files to include in the RAG pipeline.
+           chat_history (ChatHistory | None): The chat history to use.
+       Returns:
+           ParsedRAGResponse: The generated answer.
+       """
+       full_answer = ""

+       async for response in self.ask_streaming(
+           question=question,
+           retrieval_config=retrieval_config,
+           rag_pipeline=rag_pipeline,
+           list_files=list_files,
+           chat_history=chat_history,
+       ):
+           full_answer += response.answer

+       return ParsedRAGResponse(answer=full_answer)

+   def ask(
+       self,
+       question: str,
+       retrieval_config: RetrievalConfig | None = None,
+       rag_pipeline: Type[Union[QuivrQARAG, QuivrQARAGLangGraph]] | None = None,
+       list_files: list[QuivrKnowledge] | None = None,
+       chat_history: ChatHistory | None = None,
+   ) -> ParsedRAGResponse:
+       """
+       Fully synchronous version that asks a question to the brain and gets a generated answer.
+       Args:
+           question (str): The question to ask.
+           retrieval_config (RetrievalConfig | None): The retrieval configuration (see RetrievalConfig docs).
+           rag_pipeline (Type[Union[QuivrQARAG, QuivrQARAGLangGraph]] | None): The RAG pipeline to use.
+           list_files (list[QuivrKnowledge] | None): The list of files to include in the RAG pipeline.
+           chat_history (ChatHistory | None): The chat history to use.
+       Returns:
+           ParsedRAGResponse: The generated answer.
+       """
+       loop = asyncio.get_event_loop()
+       return loop.run_until_complete(
+           self.aask(
+               question=question,
+               retrieval_config=retrieval_config,
+               rag_pipeline=rag_pipeline,
+               list_files=list_files,
+               chat_history=chat_history,
+           )
+       )
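Reviewer note on the new `aask` / `ask` pair added above: `aask` is the native coroutine, while `ask` is a thin synchronous wrapper that drives it through `asyncio.get_event_loop().run_until_complete(...)`, so it is meant for plain scripts rather than code already running inside an event loop. A small sketch (import path assumed):

```python
from quivr_core import Brain  # assumed import path

brain = Brain.from_files(name="My Brain", file_paths=["file1.pdf", "file2.pdf"])

# Blocking call: internally runs aask() to completion on the current event loop,
# so call it from synchronous code only; prefer `await brain.aask(...)` in async code.
answer = brain.ask("What is the meaning of life?")
print(answer.answer)
```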
@@ -4,7 +4,7 @@ from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_core.vectorstores import VectorStore

-from quivr_core.config import LLMEndpointConfig
+from quivr_core.rag.entities.config import DefaultModelSuppliers, LLMEndpointConfig
from quivr_core.llm import LLMEndpoint

logger = logging.getLogger("quivr_core")
@@ -46,7 +46,9 @@ def default_embedder() -> Embeddings:
def default_llm() -> LLMEndpoint:
    try:
        logger.debug("Loaded ChatOpenAI as default LLM for brain")
-       llm = LLMEndpoint.from_config(LLMEndpointConfig())
+       llm = LLMEndpoint.from_config(
+           LLMEndpointConfig(supplier=DefaultModelSuppliers.OPENAI, model="gpt-4o")
+       )
        return llm

    except ImportError as e:
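Note on the `default_llm` change above: the default endpoint is now built with an explicit supplier and model. Per the `LLMEndpointConfig` code shown elsewhere in this diff, the API key is resolved from the `{SUPPLIER}_API_KEY` environment variable (here `OPENAI_API_KEY`) and a missing key raises `ValueError`. Building the same endpoint directly would look roughly like this:

```python
import os

from quivr_core.llm import LLMEndpoint
from quivr_core.rag.entities.config import DefaultModelSuppliers, LLMEndpointConfig

# LLMEndpointConfig reads the key from f"{supplier.upper()}_API_KEY" and raises
# ValueError if it is not set, so export OPENAI_API_KEY before running this.
assert os.getenv("OPENAI_API_KEY"), "OPENAI_API_KEY must be set"

llm = LLMEndpoint.from_config(
    LLMEndpointConfig(supplier=DefaultModelSuppliers.OPENAI, model="gpt-4o")
)
```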
@@ -4,9 +4,9 @@ from uuid import UUID

from pydantic import BaseModel, Field, SecretStr

-from quivr_core.config import LLMEndpointConfig
+from quivr_core.rag.entities.config import LLMEndpointConfig
+from quivr_core.rag.entities.models import ChatMessage
from quivr_core.files.file import QuivrFileSerialized
-from quivr_core.models import ChatMessage


class EmbedderConfig(BaseModel):
@@ -1,15 +1,8 @@
-import os
from enum import Enum
-from typing import Dict, List, Optional
-from uuid import UUID

import yaml
from pydantic import BaseModel

-from quivr_core.base_config import QuivrBaseConfig
-from quivr_core.processor.splitter import SplitterConfig
-from quivr_core.prompts import CustomPromptsModel


class PdfParser(str, Enum):
    LLAMA_PARSE = "llama_parse"
@@ -32,489 +25,3 @@ class MegaparseConfig(MegaparseBaseConfig):
    strategy: str = "fast"
    llama_parse_api_key: str | None = None
    pdf_parser: PdfParser = PdfParser.UNSTRUCTURED


-class BrainConfig(QuivrBaseConfig):
-    brain_id: UUID | None = None
-    name: str

-    @property
-    def id(self) -> UUID | None:
-        return self.brain_id


-class DefaultRerankers(str, Enum):
-    """
-    Enum representing the default API-based reranker suppliers supported by the application.

-    This enum defines the various reranker providers that can be used in the system.
-    Each enum value corresponds to a specific supplier's identifier and has an
-    associated default model.

-    Attributes:
-        COHERE (str): Represents Cohere AI as a reranker supplier.
-        JINA (str): Represents Jina AI as a reranker supplier.

-    Methods:
-        default_model (property): Returns the default model for the selected supplier.
-    """

-    COHERE = "cohere"
-    JINA = "jina"

-    @property
-    def default_model(self) -> str:
-        """
-        Get the default model for the selected reranker supplier.

-        This property method returns the default model associated with the current
-        reranker supplier (COHERE or JINA).

-        Returns:
-            str: The name of the default model for the selected supplier.

-        Raises:
-            KeyError: If the current enum value doesn't have a corresponding default model.
-        """
-        # Mapping of suppliers to their default models
-        return {
-            self.COHERE: "rerank-multilingual-v3.0",
-            self.JINA: "jina-reranker-v2-base-multilingual",
-        }[self]


-class DefaultModelSuppliers(str, Enum):
-    """
-    Enum representing the default model suppliers supported by the application.

-    This enum defines the various AI model providers that can be used as sources
-    for LLMs in the system. Each enum value corresponds to a specific
-    supplier's identifier.

-    Attributes:
-        OPENAI (str): Represents OpenAI as a model supplier.
-        AZURE (str): Represents Azure (Microsoft) as a model supplier.
-        ANTHROPIC (str): Represents Anthropic as a model supplier.
-        META (str): Represents Meta as a model supplier.
-        MISTRAL (str): Represents Mistral AI as a model supplier.
-        GROQ (str): Represents Groq as a model supplier.
-    """

-    OPENAI = "openai"
-    AZURE = "azure"
-    ANTHROPIC = "anthropic"
-    META = "meta"
-    MISTRAL = "mistral"
-    GROQ = "groq"


-class LLMConfig(QuivrBaseConfig):
-    context: int | None = None
-    tokenizer_hub: str | None = None


-class LLMModelConfig:
-    _model_defaults: Dict[DefaultModelSuppliers, Dict[str, LLMConfig]] = {
-        DefaultModelSuppliers.OPENAI: {
-            "gpt-4o": LLMConfig(context=128000, tokenizer_hub="Xenova/gpt-4o"),
-            "gpt-4o-mini": LLMConfig(context=128000, tokenizer_hub="Xenova/gpt-4o"),
-            "gpt-4-turbo": LLMConfig(context=128000, tokenizer_hub="Xenova/gpt-4"),
-            "gpt-4": LLMConfig(context=8192, tokenizer_hub="Xenova/gpt-4"),
-            "gpt-3.5-turbo": LLMConfig(
-                context=16385, tokenizer_hub="Xenova/gpt-3.5-turbo"
-            ),
-            "text-embedding-3-large": LLMConfig(
-                context=8191, tokenizer_hub="Xenova/text-embedding-ada-002"
-            ),
-            "text-embedding-3-small": LLMConfig(
-                context=8191, tokenizer_hub="Xenova/text-embedding-ada-002"
-            ),
-            "text-embedding-ada-002": LLMConfig(
-                context=8191, tokenizer_hub="Xenova/text-embedding-ada-002"
-            ),
-        },
-        DefaultModelSuppliers.ANTHROPIC: {
-            "claude-3-5-sonnet": LLMConfig(
-                context=200000, tokenizer_hub="Xenova/claude-tokenizer"
-            ),
-            "claude-3-opus": LLMConfig(
-                context=200000, tokenizer_hub="Xenova/claude-tokenizer"
-            ),
-            "claude-3-sonnet": LLMConfig(
-                context=200000, tokenizer_hub="Xenova/claude-tokenizer"
-            ),
-            "claude-3-haiku": LLMConfig(
-                context=200000, tokenizer_hub="Xenova/claude-tokenizer"
-            ),
-            "claude-2-1": LLMConfig(
-                context=200000, tokenizer_hub="Xenova/claude-tokenizer"
-            ),
-            "claude-2-0": LLMConfig(
-                context=100000, tokenizer_hub="Xenova/claude-tokenizer"
-            ),
-            "claude-instant-1-2": LLMConfig(
-                context=100000, tokenizer_hub="Xenova/claude-tokenizer"
-            ),
-        },
-        DefaultModelSuppliers.META: {
-            "llama-3.1": LLMConfig(
-                context=128000, tokenizer_hub="Xenova/Meta-Llama-3.1-Tokenizer"
-            ),
-            "llama-3": LLMConfig(
-                context=8192, tokenizer_hub="Xenova/llama3-tokenizer-new"
-            ),
-            "llama-2": LLMConfig(context=4096, tokenizer_hub="Xenova/llama2-tokenizer"),
-            "code-llama": LLMConfig(
-                context=16384, tokenizer_hub="Xenova/llama-code-tokenizer"
-            ),
-        },
-        DefaultModelSuppliers.GROQ: {
-            "llama-3.1": LLMConfig(
-                context=128000, tokenizer_hub="Xenova/Meta-Llama-3.1-Tokenizer"
-            ),
-            "llama-3": LLMConfig(
-                context=8192, tokenizer_hub="Xenova/llama3-tokenizer-new"
-            ),
-            "llama-2": LLMConfig(context=4096, tokenizer_hub="Xenova/llama2-tokenizer"),
-            "code-llama": LLMConfig(
-                context=16384, tokenizer_hub="Xenova/llama-code-tokenizer"
-            ),
-        },
-        DefaultModelSuppliers.MISTRAL: {
-            "mistral-large": LLMConfig(
-                context=128000, tokenizer_hub="Xenova/mistral-tokenizer-v3"
-            ),
-            "mistral-small": LLMConfig(
-                context=128000, tokenizer_hub="Xenova/mistral-tokenizer-v3"
-            ),
-            "mistral-nemo": LLMConfig(
-                context=128000, tokenizer_hub="Xenova/Mistral-Nemo-Instruct-Tokenizer"
-            ),
-            "codestral": LLMConfig(
-                context=32000, tokenizer_hub="Xenova/mistral-tokenizer-v3"
-            ),
-        },
-    }

-    @classmethod
-    def get_supplier_by_model_name(cls, model: str) -> DefaultModelSuppliers | None:
-        # Iterate over the suppliers and their models
-        for supplier, models in cls._model_defaults.items():
-            # Check if the model name or a base part of the model name is in the supplier's models
-            for base_model_name in models:
-                if model.startswith(base_model_name):
-                    return supplier
-        # Return None if no supplier matches the model name
-        return None

-    @classmethod
-    def get_llm_model_config(
-        cls, supplier: DefaultModelSuppliers, model_name: str
-    ) -> Optional[LLMConfig]:
-        """Retrieve the LLMConfig (context and tokenizer_hub) for a given supplier and model."""
-        supplier_defaults = cls._model_defaults.get(supplier)
-        if not supplier_defaults:
-            return None

-        # Use startswith logic for matching model names
-        for key, config in supplier_defaults.items():
-            if model_name.startswith(key):
-                return config

-        return None


-class LLMEndpointConfig(QuivrBaseConfig):
-    """
-    Configuration class for Large Language Models (LLM) endpoints.

-    This class defines the settings and parameters for interacting with various LLM providers.
-    It includes configuration for the model, API keys, token limits, and other relevant settings.

-    Attributes:
-        supplier (DefaultModelSuppliers): The LLM provider (default: OPENAI).
-        model (str): The specific model to use (default: "gpt-4o").
-        context_length (int | None): The maximum context length for the model.
-        tokenizer_hub (str | None): The tokenizer to use for this model.
-        llm_base_url (str | None): Base URL for the LLM API.
-        env_variable_name (str): Name of the environment variable for the API key.
-        llm_api_key (str | None): The API key for the LLM provider.
-        max_input_tokens (int): Maximum number of input tokens sent to the LLM (default: 2000).
-        max_output_tokens (int): Maximum number of output tokens returned by the LLM (default: 2000).
-        temperature (float): Temperature setting for text generation (default: 0.7).
-        streaming (bool): Whether to use streaming for responses (default: True).
-        prompt (CustomPromptsModel | None): Custom prompt configuration.
-    """

-    supplier: DefaultModelSuppliers = DefaultModelSuppliers.OPENAI
-    model: str = "gpt-4o"
-    context_length: int | None = None
-    tokenizer_hub: str | None = None
-    llm_base_url: str | None = None
-    env_variable_name: str = f"{supplier.upper()}_API_KEY"
-    llm_api_key: str | None = None
-    max_input_tokens: int = 2000
-    max_output_tokens: int = 2000
-    temperature: float = 0.7
-    streaming: bool = True
-    prompt: CustomPromptsModel | None = None

-    _FALLBACK_TOKENIZER = "cl100k_base"

-    @property
-    def fallback_tokenizer(self) -> str:
-        """
-        Get the fallback tokenizer.

-        Returns:
-            str: The name of the fallback tokenizer.
-        """
-        return self._FALLBACK_TOKENIZER

-    def __init__(self, **data):
-        """
-        Initialize the LLMEndpointConfig.

-        This method sets up the initial configuration, including setting the LLM model
-        config and API key.

-        """
-        super().__init__(**data)
-        self.set_llm_model_config()
-        self.set_api_key()

-    def set_api_key(self, force_reset: bool = False):
-        """
-        Set the API key for the LLM provider.

-        This method attempts to set the API key from the environment variable.
-        If the key is not found, it raises a ValueError.

-        Args:
-            force_reset (bool): If True, forces a reset of the API key even if already set.

-        Raises:
-            ValueError: If the API key is not set in the environment.
-        """
-        if not self.llm_api_key or force_reset:
-            self.llm_api_key = os.getenv(self.env_variable_name)

-        if not self.llm_api_key:
-            raise ValueError(
-                f"The API key for supplier '{self.supplier}' is not set. "
-                f"Please set the environment variable: {self.env_variable_name}"
-            )

-    def set_llm_model_config(self):
-        """
-        Set the LLM model configuration.

-        This method automatically sets the context_length and tokenizer_hub
-        based on the current supplier and model.
-        """
-        llm_model_config = LLMModelConfig.get_llm_model_config(
-            self.supplier, self.model
-        )
-        if llm_model_config:
-            self.context_length = llm_model_config.context
-            self.tokenizer_hub = llm_model_config.tokenizer_hub

-    def set_llm_model(self, model: str):
-        """
-        Set the LLM model and update related configurations.

-        This method updates the supplier and model based on the provided model name,
-        then updates the model config and API key accordingly.

-        Args:
-            model (str): The name of the model to set.

-        Raises:
-            ValueError: If no corresponding supplier is found for the given model.
-        """
-        supplier = LLMModelConfig.get_supplier_by_model_name(model)
-        if supplier is None:
-            raise ValueError(
-                f"Cannot find the corresponding supplier for model {model}"
-            )
-        self.supplier = supplier
-        self.model = model

-        self.set_llm_model_config()
-        self.set_api_key(force_reset=True)

-    def set_from_sqlmodel(self, sqlmodel: BaseModel, mapping: Dict[str, str]):
-        """
-        Set attributes in LLMEndpointConfig from SQLModel attributes using a field mapping.

-        This method allows for dynamic setting of LLMEndpointConfig attributes based on
-        a provided SQLModel instance and a mapping dictionary.

-        Args:
-            sqlmodel (SQLModel): An instance of the SQLModel class.
-            mapping (Dict[str, str]): A dictionary that maps SQLModel fields to LLMEndpointConfig fields.
-                Example: {"max_input": "max_input_tokens", "env_variable_name": "env_variable_name"}

-        Raises:
-            AttributeError: If any field in the mapping doesn't exist in either the SQLModel or LLMEndpointConfig.
-        """
-        for model_field, llm_field in mapping.items():
-            if hasattr(sqlmodel, model_field) and hasattr(self, llm_field):
-                setattr(self, llm_field, getattr(sqlmodel, model_field))
-            else:
-                raise AttributeError(
-                    f"Invalid mapping: {model_field} or {llm_field} does not exist."
-                )


-# Cannot use Pydantic v2 field_validator because of conflicts with pydantic v1 still in use in LangChain
-class RerankerConfig(QuivrBaseConfig):
-    """
-    Configuration class for reranker models.

-    This class defines the settings for reranker models used in the application,
-    including the supplier, model, and API key information.

-    Attributes:
-        supplier (DefaultRerankers | None): The reranker supplier (e.g., COHERE).
-        model (str | None): The specific reranker model to use.
-        top_n (int): The number of top chunks returned by the reranker (default: 5).
-        api_key (str | None): The API key for the reranker service.
-    """

-    supplier: DefaultRerankers | None = None
-    model: str | None = None
-    top_n: int = 5
-    api_key: str | None = None

-    def __init__(self, **data):
-        """
-        Initialize the RerankerConfig.
-        """
-        super().__init__(**data)
-        self.validate_model()

-    def validate_model(self):
-        """
-        Validate and set up the reranker model configuration.

-        This method ensures that a model is set (using the default if not provided)
-        and that the necessary API key is available in the environment.

-        Raises:
-            ValueError: If the required API key is not set in the environment.
-        """
-        if self.model is None and self.supplier is not None:
-            self.model = self.supplier.default_model

-        if self.supplier:
-            api_key_var = f"{self.supplier.upper()}_API_KEY"
-            self.api_key = os.getenv(api_key_var)

-            if self.api_key is None:
-                raise ValueError(
-                    f"The API key for supplier '{self.supplier}' is not set. "
-                    f"Please set the environment variable: {api_key_var}"
-                )


-class NodeConfig(QuivrBaseConfig):
-    """
-    Configuration class for a node in an AI assistant workflow.

-    This class represents a single node in a workflow configuration,
-    defining its name and connections to other nodes.

-    Attributes:
-        name (str): The name of the node.
-        edges (List[str]): List of names of other nodes this node links to.
-    """

-    name: str
-    edges: List[str]


-class WorkflowConfig(QuivrBaseConfig):
-    """
-    Configuration class for an AI assistant workflow.

-    This class represents the entire workflow configuration,
-    consisting of multiple interconnected nodes.

-    Attributes:
-        name (str): The name of the workflow.
-        nodes (List[NodeConfig]): List of nodes in the workflow.
-    """

-    name: str
-    nodes: List[NodeConfig]


-class RetrievalConfig(QuivrBaseConfig):
-    """
-    Configuration class for the retrieval phase of a RAG assistant.

-    This class defines the settings for the retrieval process,
-    including reranker and LLM configurations, as well as various limits and prompts.

-    Attributes:
-        workflow_config (WorkflowConfig | None): Configuration for the workflow.
-        reranker_config (RerankerConfig): Configuration for the reranker.
-        llm_config (LLMEndpointConfig): Configuration for the LLM endpoint.
-        max_history (int): Maximum number of past conversation turns to pass to the LLM as context (default: 10).
-        max_files (int): Maximum number of files to process (default: 20).
-        prompt (str | None): Custom prompt for the retrieval process.
-    """

-    workflow_config: WorkflowConfig | None = None
-    reranker_config: RerankerConfig = RerankerConfig()
-    llm_config: LLMEndpointConfig = LLMEndpointConfig()
-    max_history: int = 10
-    max_files: int = 20
-    prompt: str | None = None


-class ParserConfig(QuivrBaseConfig):
-    """
-    Configuration class for the parser.

-    This class defines the settings for the parsing process,
-    including configurations for the text splitter and Megaparse.

-    Attributes:
-        splitter_config (SplitterConfig): Configuration for the text splitter.
-        megaparse_config (MegaparseConfig): Configuration for Megaparse.
-    """

-    splitter_config: SplitterConfig = SplitterConfig()
-    megaparse_config: MegaparseConfig = MegaparseConfig()


-class IngestionConfig(QuivrBaseConfig):
-    """
-    Configuration class for the data ingestion process.

-    This class defines the settings for the data ingestion process,
-    including the parser configuration.

-    Attributes:
-        parser_config (ParserConfig): Configuration for the parser.
-    """

-    parser_config: ParserConfig = ParserConfig()


-class AssistantConfig(QuivrBaseConfig):
-    """
-    Configuration class for an AI assistant.

-    This class defines the overall configuration for an AI assistant,
-    including settings for retrieval and ingestion processes.

-    Attributes:
-        retrieval_config (RetrievalConfig): Configuration for the retrieval process.
-        ingestion_config (IngestionConfig): Configuration for the ingestion process.
-    """

-    retrieval_config: RetrievalConfig = RetrievalConfig()
-    ingestion_config: IngestionConfig = IngestionConfig()
@@ -10,8 +10,8 @@ from langchain_openai import AzureChatOpenAI, ChatOpenAI
from pydantic.v1 import SecretStr

from quivr_core.brain.info import LLMInfo
-from quivr_core.config import DefaultModelSuppliers, LLMEndpointConfig
+from quivr_core.rag.entities.config import DefaultModelSuppliers, LLMEndpointConfig
-from quivr_core.utils import model_supports_function_calling
+from quivr_core.rag.utils import model_supports_function_calling

logger = logging.getLogger("quivr_core")

@@ -70,6 +70,7 @@ class LLMEndpoint:
                else None,
                azure_endpoint=azure_endpoint,
                max_tokens=config.max_output_tokens,
+               temperature=config.temperature,
            )
        elif config.supplier == DefaultModelSuppliers.ANTHROPIC:
            _llm = ChatAnthropic(
@@ -79,6 +80,7 @@ class LLMEndpoint:
                else None,
                base_url=config.llm_base_url,
                max_tokens=config.max_output_tokens,
+               temperature=config.temperature,
            )
        elif config.supplier == DefaultModelSuppliers.OPENAI:
            _llm = ChatOpenAI(
@@ -88,6 +90,7 @@ class LLMEndpoint:
                else None,
                base_url=config.llm_base_url,
                max_tokens=config.max_output_tokens,
+               temperature=config.temperature,
            )
        else:
            _llm = ChatOpenAI(
@@ -97,6 +100,7 @@ class LLMEndpoint:
                else None,
                base_url=config.llm_base_url,
                max_tokens=config.max_output_tokens,
+               temperature=config.temperature,
            )
        return cls(llm=_llm, llm_config=config)

@@ -118,3 +122,7 @@ class LLMEndpoint:
            max_tokens=self._config.max_output_tokens,
            supports_function_calling=self.supports_func_calling(),
        )

+   def clone_llm(self):
+       """Create a new instance of the LLM with the same configuration."""
+       return self._llm.__class__(**self._llm.__dict__)
core/quivr_core/llm_tools/__init__.py (new file, 0 lines)
core/quivr_core/llm_tools/entity.py (new file, 36 lines)
@@ -0,0 +1,36 @@
+from quivr_core.base_config import QuivrBaseConfig
+from typing import Callable
+from langchain_core.tools import BaseTool
+from typing import Dict, Any


+class ToolsCategory(QuivrBaseConfig):
+    name: str
+    description: str
+    tools: list
+    default_tool: str | None = None
+    create_tool: Callable

+    def __init__(self, **data):
+        super().__init__(**data)
+        self.name = self.name.lower()


+class ToolWrapper:
+    def __init__(self, tool: BaseTool, format_input: Callable, format_output: Callable):
+        self.tool = tool
+        self.format_input = format_input
+        self.format_output = format_output


+class ToolRegistry:
+    def __init__(self):
+        self._registry = {}

+    def register_tool(self, tool_name: str, create_func: Callable):
+        self._registry[tool_name] = create_func

+    def create_tool(self, tool_name: str, config: Dict[str, Any]) -> ToolWrapper:
+        if tool_name not in self._registry:
+            raise ValueError(f"Tool {tool_name} is not supported.")
+        return self._registry[tool_name](config)
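Reviewer note: `ToolRegistry` maps a tool name onto a factory that returns a `ToolWrapper`, which is exactly how `web_search_tools.py` later in this diff registers Tavily. A self-contained sketch with a stand-in `echo` tool (the tool itself is hypothetical; only the registry/wrapper API comes from the file above):

```python
from langchain_core.tools import tool

from quivr_core.llm_tools.entity import ToolRegistry, ToolWrapper


@tool
def echo(query: str) -> str:
    """Stand-in tool that simply returns its query."""
    return query


def create_echo_tool(config: dict) -> ToolWrapper:
    # format_input maps the agent's task string onto the tool's input schema,
    # format_output post-processes the raw tool response (identity here).
    return ToolWrapper(echo, lambda task: {"query": task}, lambda response: response)


registry = ToolRegistry()
registry.register_tool("echo", create_echo_tool)

wrapper = registry.create_tool("echo", config={})
print(wrapper.tool.invoke(wrapper.format_input("hello")))  # -> "hello"
```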
core/quivr_core/llm_tools/llm_tools.py (new file, 33 lines)
@@ -0,0 +1,33 @@
+from typing import Dict, Any, Type, Union

+from quivr_core.llm_tools.entity import ToolWrapper

+from quivr_core.llm_tools.web_search_tools import (
+    WebSearchTools,
+)

+from quivr_core.llm_tools.other_tools import (
+    OtherTools,
+)

+TOOLS_CATEGORIES = {
+    WebSearchTools.name: WebSearchTools,
+    OtherTools.name: OtherTools,
+}

+# Register all ToolsList enums
+TOOLS_LISTS = {
+    **{tool.value: tool for tool in WebSearchTools.tools},
+    **{tool.value: tool for tool in OtherTools.tools},
+}


+class LLMToolFactory:
+    @staticmethod
+    def create_tool(tool_name: str, config: Dict[str, Any]) -> Union[ToolWrapper, Type]:
+        for category, tools_class in TOOLS_CATEGORIES.items():
+            if tool_name in tools_class.tools:
+                return tools_class.create_tool(tool_name, config)
+            elif tool_name.lower() == category and tools_class.default_tool:
+                return tools_class.create_tool(tools_class.default_tool, config)
+        raise ValueError(f"Tool {tool_name} is not supported.")
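Reviewer note: `LLMToolFactory.create_tool` accepts either an exact tool name or a category name; a category name falls back to that category's `default_tool`. A usage sketch, assuming `TAVILY_API_KEY` is set in the environment:

```python
from quivr_core.llm_tools.llm_tools import LLMToolFactory

# "tavily" (exact tool name) and "web search" (category name, resolved to the
# category's default_tool) both end up building the Tavily ToolWrapper.
wrapper = LLMToolFactory.create_tool("web search", {"max_results": 3})
print(wrapper.tool.name)  # "tavily"
```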
core/quivr_core/llm_tools/other_tools.py (new file, 24 lines)
@@ -0,0 +1,24 @@
+from enum import Enum
+from typing import Dict, Any, Type, Union
+from langchain_core.tools import BaseTool
+from quivr_core.llm_tools.entity import ToolsCategory
+from quivr_core.rag.entities.models import cited_answer


+class OtherToolsList(str, Enum):
+    CITED_ANSWER = "cited_answer"


+def create_other_tool(tool_name: str, config: Dict[str, Any]) -> Union[BaseTool, Type]:
+    if tool_name == OtherToolsList.CITED_ANSWER:
+        return cited_answer
+    else:
+        raise ValueError(f"Tool {tool_name} is not supported.")


+OtherTools = ToolsCategory(
+    name="Other",
+    description="Other tools",
+    tools=[OtherToolsList.CITED_ANSWER],
+    create_tool=create_other_tool,
+)
core/quivr_core/llm_tools/web_search_tools.py (new file, 73 lines)
@@ -0,0 +1,73 @@
+from enum import Enum
+from typing import Dict, List, Any
+from langchain_community.tools import TavilySearchResults
+from langchain_community.utilities.tavily_search import TavilySearchAPIWrapper
+from quivr_core.llm_tools.entity import ToolsCategory
+import os
+from pydantic.v1 import SecretStr as SecretStrV1  # Ensure correct import
+from quivr_core.llm_tools.entity import ToolWrapper, ToolRegistry
+from langchain_core.documents import Document


+class WebSearchToolsList(str, Enum):
+    TAVILY = "tavily"


+def create_tavily_tool(config: Dict[str, Any]) -> ToolWrapper:
+    api_key = (
+        config.pop("api_key") if "api_key" in config else os.getenv("TAVILY_API_KEY")
+    )
+    if not api_key:
+        raise ValueError(
+            "Missing required config key 'api_key' or environment variable 'TAVILY_API_KEY'"
+        )

+    tavily_api_wrapper = TavilySearchAPIWrapper(
+        tavily_api_key=SecretStrV1(api_key),
+    )
+    tool = TavilySearchResults(
+        api_wrapper=tavily_api_wrapper,
+        max_results=config.pop("max_results", 5),
+        search_depth=config.pop("search_depth", "advanced"),
+        include_answer=config.pop("include_answer", True),
+        **config,
+    )

+    tool.name = WebSearchToolsList.TAVILY.value

+    def format_input(task: str) -> Dict[str, Any]:
+        return {"query": task}

+    def format_output(response: Any) -> List[Document]:
+        metadata = {"integration": "", "integration_link": ""}
+        return [
+            Document(
+                page_content=d["content"],
+                metadata={
+                    **metadata,
+                    "file_name": d["url"],
+                    "original_file_name": d["url"],
+                },
+            )
+            for d in response
+        ]

+    return ToolWrapper(tool, format_input, format_output)


+# Initialize the registry and register tools
+web_search_tool_registry = ToolRegistry()
+web_search_tool_registry.register_tool(WebSearchToolsList.TAVILY, create_tavily_tool)


+def create_web_search_tool(tool_name: str, config: Dict[str, Any]) -> ToolWrapper:
+    return web_search_tool_registry.create_tool(tool_name, config)


+WebSearchTools = ToolsCategory(
+    name="Web Search",
+    description="Tools for web searching",
+    tools=[WebSearchToolsList.TAVILY],
+    default_tool=WebSearchToolsList.TAVILY,
+    create_tool=create_web_search_tool,
+)
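Reviewer note on the Tavily wrapper above: `format_input` turns the agent's task string into the `{"query": ...}` payload Tavily expects, and `format_output` converts the raw results into langchain `Document` objects whose `file_name` metadata carries the source URL. A usage sketch (requires `TAVILY_API_KEY`; the query string is arbitrary):

```python
from quivr_core.llm_tools.web_search_tools import create_tavily_tool

wrapper = create_tavily_tool({"max_results": 3})

# Run the underlying TavilySearchResults tool and normalise its output.
raw_results = wrapper.tool.invoke(wrapper.format_input("latest langgraph release"))
documents = wrapper.format_output(raw_results)

for doc in documents:
    print(doc.metadata["file_name"], doc.page_content[:80])
```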
@@ -1,119 +0,0 @@
import datetime

from langchain_core.prompts import (
    ChatPromptTemplate,
    HumanMessagePromptTemplate,
    MessagesPlaceholder,
    PromptTemplate,
    SystemMessagePromptTemplate,
)
from langchain_core.prompts.base import BasePromptTemplate
from pydantic import ConfigDict, create_model


class CustomPromptsDict(dict):
    def __init__(self, type, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._type = type

    def __setitem__(self, key, value):
        # Automatically convert the value into a tuple (my_type, value)
        super().__setitem__(key, (self._type, value))


def _define_custom_prompts() -> CustomPromptsDict:
    custom_prompts: CustomPromptsDict = CustomPromptsDict(type=BasePromptTemplate)

    today_date = datetime.datetime.now().strftime("%B %d, %Y")

    # ---------------------------------------------------------------------------
    # Prompt for question rephrasing
    # ---------------------------------------------------------------------------
    _template = """Given the following conversation and a follow up question, rephrase the follow up question to be a standalone question, in its original language. Keep as much details as possible from previous messages. Keep entity names and all.

Chat History:
{chat_history}
Follow Up Input: {question}
Standalone question:"""

    CONDENSE_QUESTION_PROMPT = PromptTemplate.from_template(_template)
    custom_prompts["CONDENSE_QUESTION_PROMPT"] = CONDENSE_QUESTION_PROMPT

    # ---------------------------------------------------------------------------
    # Prompt for RAG
    # ---------------------------------------------------------------------------
    system_message_template = (
        f"Your name is Quivr. You're a helpful assistant. Today's date is {today_date}."
    )

    system_message_template += """
When answering use markdown.
Use markdown code blocks for code snippets.
Answer in a concise and clear manner.
Use the following pieces of context from files provided by the user to answer the users.
Answer in the same language as the user question.
If you don't know the answer with the context provided from the files, just say that you don't know, don't try to make up an answer.
Don't cite the source id in the answer objects, but you can use the source to answer the question.
You have access to the files to answer the user question (limited to first 20 files):
{files}

If not None, User instruction to follow to answer: {custom_instructions}
Don't cite the source id in the answer objects, but you can use the source to answer the question.
"""

    template_answer = """
Context:
{context}

User Question: {question}
Answer:
"""

    RAG_ANSWER_PROMPT = ChatPromptTemplate.from_messages(
        [
            SystemMessagePromptTemplate.from_template(system_message_template),
            HumanMessagePromptTemplate.from_template(template_answer),
        ]
    )
    custom_prompts["RAG_ANSWER_PROMPT"] = RAG_ANSWER_PROMPT

    # ---------------------------------------------------------------------------
    # Prompt for formatting documents
    # ---------------------------------------------------------------------------
    DEFAULT_DOCUMENT_PROMPT = PromptTemplate.from_template(
        template="Source: {index} \n {page_content}"
    )
    custom_prompts["DEFAULT_DOCUMENT_PROMPT"] = DEFAULT_DOCUMENT_PROMPT

    # ---------------------------------------------------------------------------
    # Prompt for chatting directly with LLMs, without any document retrieval stage
    # ---------------------------------------------------------------------------
    system_message_template = (
        f"Your name is Quivr. You're a helpful assistant. Today's date is {today_date}."
    )
    system_message_template += """
If not None, also follow these user instructions when answering: {custom_instructions}
"""

    template_answer = """
User Question: {question}
Answer:
"""
    CHAT_LLM_PROMPT = ChatPromptTemplate.from_messages(
        [
            SystemMessagePromptTemplate.from_template(system_message_template),
            MessagesPlaceholder(variable_name="chat_history"),
            HumanMessagePromptTemplate.from_template(template_answer),
        ]
    )
    custom_prompts["CHAT_LLM_PROMPT"] = CHAT_LLM_PROMPT

    return custom_prompts


_custom_prompts = _define_custom_prompts()
CustomPromptsModel = create_model(
    "CustomPromptsModel", **_custom_prompts, __config__=ConfigDict(extra="forbid")
)

custom_prompts = CustomPromptsModel()
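For reference, a small sketch of how the generated `CustomPromptsModel` above is consumed elsewhere in this diff (each registered key becomes a model field holding the template; the chat-history and question strings are illustrative):

```python
# Prompts are accessed as attributes of the generated pydantic model,
# not as dict entries. Old module path, as shown in this (deleted) file.
from quivr_core.prompts import custom_prompts

msg = custom_prompts.CONDENSE_QUESTION_PROMPT.format(
    chat_history="Human: Who is Chloé?\nAI: An AI engineer at Quivr.",
    question="What does she work on?",
)
print(msg)  # standalone-question prompt, ready to send to the LLM
```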
@@ -1,488 +0,0 @@
import logging
from enum import Enum
from typing import Annotated, AsyncGenerator, Optional, Sequence, TypedDict
from uuid import uuid4

# TODO(@aminediro): this is the only dependency to langchain package, we should remove it
from langchain.retrievers import ContextualCompressionRetriever
from langchain_cohere import CohereRerank
from langchain_community.document_compressors import JinaRerank
from langchain_core.callbacks import Callbacks
from langchain_core.documents import BaseDocumentCompressor, Document
from langchain_core.messages import BaseMessage
from langchain_core.messages.ai import AIMessageChunk
from langchain_core.vectorstores import VectorStore
from langgraph.graph import END, START, StateGraph
from langgraph.graph.message import add_messages

from quivr_core.chat import ChatHistory
from quivr_core.config import DefaultRerankers, RetrievalConfig
from quivr_core.llm import LLMEndpoint
from quivr_core.models import (
    ParsedRAGChunkResponse,
    ParsedRAGResponse,
    QuivrKnowledge,
    RAGResponseMetadata,
    cited_answer,
)
from quivr_core.prompts import custom_prompts
from quivr_core.utils import (
    combine_documents,
    format_file_list,
    get_chunk_metadata,
    parse_chunk_response,
    parse_response,
)

logger = logging.getLogger("quivr_core")


class SpecialEdges(str, Enum):
    START = "START"
    END = "END"


class AgentState(TypedDict):
    # The add_messages function defines how an update should be processed
    # Default is to replace. add_messages says "append"
    messages: Annotated[Sequence[BaseMessage], add_messages]
    chat_history: ChatHistory
    docs: list[Document]
    files: str
    final_response: dict


class IdempotentCompressor(BaseDocumentCompressor):
    def compress_documents(
        self,
        documents: Sequence[Document],
        query: str,
        callbacks: Optional[Callbacks] = None,
    ) -> Sequence[Document]:
        """
        A no-op document compressor that simply returns the documents it is given.

        This is a placeholder until a more sophisticated document compression
        algorithm is implemented.
        """
        return documents


class QuivrQARAGLangGraph:
    def __init__(
        self,
        *,
        retrieval_config: RetrievalConfig,
        llm: LLMEndpoint,
        vector_store: VectorStore | None = None,
        reranker: BaseDocumentCompressor | None = None,
    ):
        """
        Construct a QuivrQARAGLangGraph object.

        Args:
            retrieval_config (RetrievalConfig): The configuration for the RAG model.
            llm (LLMEndpoint): The LLM to use for generating text.
            vector_store (VectorStore): The vector store to use for storing and retrieving documents.
            reranker (BaseDocumentCompressor | None): The document compressor to use for re-ranking documents. Defaults to IdempotentCompressor if not provided.
        """
        self.retrieval_config = retrieval_config
        self.vector_store = vector_store
        self.llm_endpoint = llm

        self.graph = None

        if reranker is not None:
            self.reranker = reranker
        elif self.retrieval_config.reranker_config.supplier == DefaultRerankers.COHERE:
            self.reranker = CohereRerank(
                model=self.retrieval_config.reranker_config.model,
                top_n=self.retrieval_config.reranker_config.top_n,
                cohere_api_key=self.retrieval_config.reranker_config.api_key,
            )
        elif self.retrieval_config.reranker_config.supplier == DefaultRerankers.JINA:
            self.reranker = JinaRerank(
                model=self.retrieval_config.reranker_config.model,
                top_n=self.retrieval_config.reranker_config.top_n,
                jina_api_key=self.retrieval_config.reranker_config.api_key,
            )
        else:
            self.reranker = IdempotentCompressor()

        if self.vector_store:
            self.compression_retriever = ContextualCompressionRetriever(
                base_compressor=self.reranker, base_retriever=self.retriever
            )

    @property
    def retriever(self):
        """
        Returns a retriever that can retrieve documents from the vector store.

        Returns:
            VectorStoreRetriever: The retriever.
        """
        if self.vector_store:
            return self.vector_store.as_retriever()
        else:
            raise ValueError("No vector store provided")

    def filter_history(self, state: AgentState) -> dict:
        """
        Filter out the chat history to only include the messages that are relevant to the current question.

        Takes in a chat_history = [HumanMessage(content='Qui est Chloé ? '),
        AIMessage(content="Chloé est une salariée travaillant pour l'entreprise Quivr en tant qu'AI Engineer,
        sous la direction de son supérieur hiérarchique, Stanislas Girard."),
        HumanMessage(content='Dis moi en plus sur elle'), AIMessage(content=''),
        HumanMessage(content='Dis moi en plus sur elle'),
        AIMessage(content="Désolé, je n'ai pas d'autres informations sur Chloé à partir des fichiers fournis.")]

        Returns a filtered chat_history with in priority: first max_tokens, then max_history,
        where a Human message and an AI message count as one pair; a token is 4 characters.
        """
        chat_history = state["chat_history"]
        total_tokens = 0
        total_pairs = 0
        _chat_id = uuid4()
        _chat_history = ChatHistory(chat_id=_chat_id, brain_id=chat_history.brain_id)
        for human_message, ai_message in reversed(list(chat_history.iter_pairs())):
            # TODO: replace with tiktoken
            message_tokens = self.llm_endpoint.count_tokens(
                human_message.content
            ) + self.llm_endpoint.count_tokens(ai_message.content)
            if (
                total_tokens + message_tokens
                > self.retrieval_config.llm_config.max_output_tokens
                or total_pairs >= self.retrieval_config.max_history
            ):
                break
            _chat_history.append(human_message)
            _chat_history.append(ai_message)
            total_tokens += message_tokens
            total_pairs += 1

        return {"chat_history": _chat_history}

    ### Nodes
    def rewrite(self, state):
        """
        Transform the query to produce a better question.

        Args:
            state (messages): The current state

        Returns:
            dict: The updated state with re-phrased question
        """

        # Grader
        msg = custom_prompts.CONDENSE_QUESTION_PROMPT.format(
            chat_history=state["chat_history"],
            question=state["messages"][0].content,
        )

        model = self.llm_endpoint._llm
        response = model.invoke(msg)
        return {"messages": [response]}

    def retrieve(self, state):
        """
        Retrieve relevant chunks.

        Args:
            state (messages): The current state

        Returns:
            dict: The retrieved chunks
        """
        question = state["messages"][-1].content
        docs = self.compression_retriever.invoke(question)
        return {"docs": docs}

    def generate_rag(self, state):
        """
        Generate answer.

        Args:
            state (messages): The current state

        Returns:
            dict: The updated state with the generated answer
        """
        messages = state["messages"]
        user_question = messages[0].content
        files = state["files"]

        docs = state["docs"]

        # Prompt
        prompt = self.retrieval_config.prompt

        final_inputs = {}
        final_inputs["context"] = combine_documents(docs) if docs else "None"
        final_inputs["question"] = user_question
        final_inputs["custom_instructions"] = prompt if prompt else "None"
        final_inputs["files"] = files if files else "None"

        # LLM
        llm = self.llm_endpoint._llm
        if self.llm_endpoint.supports_func_calling():
            llm = self.llm_endpoint._llm.bind_tools(
                [cited_answer],
                tool_choice="any",
            )

        # Chain
        rag_chain = custom_prompts.RAG_ANSWER_PROMPT | llm

        # Run
        response = rag_chain.invoke(final_inputs)
        formatted_response = {
            "answer": response,  # Assuming the last message contains the final answer
            "docs": docs,
        }
        return {"messages": [response], "final_response": formatted_response}

    def generate_chat_llm(self, state):
        """
        Generate answer.

        Args:
            state (messages): The current state

        Returns:
            dict: The updated state with the generated answer
        """
        messages = state["messages"]
        user_question = messages[0].content

        # Prompt
        prompt = self.retrieval_config.prompt

        final_inputs = {}
        final_inputs["question"] = user_question
        final_inputs["custom_instructions"] = prompt if prompt else "None"
        final_inputs["chat_history"] = state["chat_history"].to_list()

        # LLM
        llm = self.llm_endpoint._llm

        # Chain
        rag_chain = custom_prompts.CHAT_LLM_PROMPT | llm

        # Run
        response = rag_chain.invoke(final_inputs)
        formatted_response = {
            "answer": response,  # Assuming the last message contains the final answer
        }
        return {"messages": [response], "final_response": formatted_response}

    def build_chain(self):
        """
        Builds the langchain chain for the given configuration.

        Returns:
            Callable[[Dict], Dict]: The langchain chain.
        """
        if not self.graph:
            self.graph = self.create_graph()

        return self.graph

    def create_graph(self):
        """
        Builds the langchain chain for the given configuration.

        This function creates a state machine which takes a chat history and a question
        and produces an answer. The state machine consists of the following states:

        - filter_history: Filter the chat history (i.e., remove the last message)
        - rewrite: Re-write the question using the filtered history
        - retrieve: Retrieve documents related to the re-written question
        - generate: Generate an answer using the retrieved documents

        The state machine starts in the filter_history state and transitions as follows:
        filter_history -> rewrite -> retrieve -> generate -> END

        The final answer is returned as a dictionary with the answer and the list of documents
        used to generate the answer.

        Returns:
            Callable[[Dict], Dict]: The langchain chain.
        """
        workflow = StateGraph(AgentState)

        if self.retrieval_config.workflow_config:
            if SpecialEdges.START not in [
                node.name for node in self.retrieval_config.workflow_config.nodes
            ]:
                raise ValueError("The workflow should contain a 'START' node")
            for node in self.retrieval_config.workflow_config.nodes:
                if node.name not in SpecialEdges._value2member_map_:
                    workflow.add_node(node.name, getattr(self, node.name))

            for node in self.retrieval_config.workflow_config.nodes:
                for edge in node.edges:
                    if node.name == SpecialEdges.START:
                        workflow.add_edge(START, edge)
                    elif edge == SpecialEdges.END:
                        workflow.add_edge(node.name, END)
                    else:
                        workflow.add_edge(node.name, edge)
        else:
            # Define the nodes we will cycle between
            workflow.add_node("filter_history", self.filter_history)
            workflow.add_node("rewrite", self.rewrite)  # Re-writing the question
            workflow.add_node("retrieve", self.retrieve)  # retrieval
            workflow.add_node("generate", self.generate_rag)

            # Add node for filtering history
            workflow.set_entry_point("filter_history")
            workflow.add_edge("filter_history", "rewrite")
            workflow.add_edge("rewrite", "retrieve")
            workflow.add_edge("retrieve", "generate")
            workflow.add_edge(
                "generate", END
            )  # Add edge from generate to format_response

        # Compile
        graph = workflow.compile()
        return graph

    def answer(
        self,
        question: str,
        history: ChatHistory,
        list_files: list[QuivrKnowledge],
        metadata: dict[str, str] = {},
    ) -> ParsedRAGResponse:
        """
        Answer a question using the langgraph chain.

        Args:
            question (str): The question to answer.
            history (ChatHistory): The chat history to use for context.
            list_files (list[QuivrKnowledge]): The list of files to use for retrieval.
            metadata (dict[str, str], optional): The metadata to pass to the langchain invocation. Defaults to {}.

        Returns:
            ParsedRAGResponse: The answer to the question.
        """
        concat_list_files = format_file_list(
            list_files, self.retrieval_config.max_files
        )
        conversational_qa_chain = self.build_chain()
        inputs = {
            "messages": [
                ("user", question),
            ],
            "chat_history": history,
            "files": concat_list_files,
        }
        raw_llm_response = conversational_qa_chain.invoke(
            inputs,
            config={"metadata": metadata},
        )
        response = parse_response(
            raw_llm_response["final_response"], self.retrieval_config.llm_config.model
        )
        return response

    async def answer_astream(
        self,
        question: str,
        history: ChatHistory,
        list_files: list[QuivrKnowledge],
        metadata: dict[str, str] = {},
    ) -> AsyncGenerator[ParsedRAGChunkResponse, ParsedRAGChunkResponse]:
        """
        Answer a question using the langgraph chain and yield each chunk of the answer separately.

        Args:
            question (str): The question to answer.
            history (ChatHistory): The chat history to use for context.
            list_files (list[QuivrKnowledge]): The list of files to use for retrieval.
            metadata (dict[str, str], optional): The metadata to pass to the langchain invocation. Defaults to {}.

        Yields:
            ParsedRAGChunkResponse: Each chunk of the answer.
        """
        concat_list_files = format_file_list(
            list_files, self.retrieval_config.max_files
        )
        conversational_qa_chain = self.build_chain()

        rolling_message = AIMessageChunk(content="")
        sources: list[Document] | None = None
        prev_answer = ""
        chunk_id = 0

        async for event in conversational_qa_chain.astream_events(
            {
                "messages": [
                    ("user", question),
                ],
                "chat_history": history,
                "files": concat_list_files,
            },
            version="v1",
            config={"metadata": metadata},
        ):
            kind = event["event"]
            if (
                not sources
                and "output" in event["data"]
                and "docs" in event["data"]["output"]
            ):
                sources = event["data"]["output"]["docs"]

            if (
                kind == "on_chat_model_stream"
                and "generate" in event["metadata"]["langgraph_node"]
            ):
                chunk = event["data"]["chunk"]
                rolling_message, answer_str = parse_chunk_response(
                    rolling_message,
                    chunk,
                    self.llm_endpoint.supports_func_calling(),
                )
                if len(answer_str) > 0:
                    if (
                        self.llm_endpoint.supports_func_calling()
                        and rolling_message.tool_calls
                    ):
                        diff_answer = answer_str[len(prev_answer) :]
                        if len(diff_answer) > 0:
                            parsed_chunk = ParsedRAGChunkResponse(
                                answer=diff_answer,
                                metadata=RAGResponseMetadata(),
                            )
                            prev_answer += diff_answer

                            logger.debug(
                                f"answer_astream func_calling=True question={question} rolling_msg={rolling_message} chunk_id={chunk_id}, chunk={parsed_chunk}"
                            )
                            yield parsed_chunk
                    else:
                        parsed_chunk = ParsedRAGChunkResponse(
                            answer=answer_str,
                            metadata=RAGResponseMetadata(),
                        )
                        logger.debug(
                            f"answer_astream func_calling=False question={question} rolling_msg={rolling_message} chunk_id={chunk_id}, chunk={parsed_chunk}"
                        )
                        yield parsed_chunk

                chunk_id += 1

        # Last chunk provides metadata
        last_chunk = ParsedRAGChunkResponse(
            answer="",
            metadata=get_chunk_metadata(rolling_message, sources),
            last_chunk=True,
        )
        logger.debug(
            f"answer_astream last_chunk={last_chunk} question={question} rolling_msg={rolling_message} chunk_id={chunk_id}"
        )
        yield last_chunk
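A hedged, self-contained driver for the class above, useful to see how `answer()` is wired. The FAISS store built from fake embeddings is a toy stand-in, `LLMEndpoint.from_config` is assumed to be the usual constructor, and the actual call needs a valid `OPENAI_API_KEY`:

```python
# Sketch only: toy vector store, default configs, old (pre-refactor) import paths.
from uuid import uuid4

from langchain_community.vectorstores import FAISS
from langchain_core.embeddings import DeterministicFakeEmbedding

from quivr_core.chat import ChatHistory
from quivr_core.config import LLMEndpointConfig, RetrievalConfig
from quivr_core.llm import LLMEndpoint

# A real brain would index parsed document chunks; one toy chunk keeps this runnable.
vector_store = FAISS.from_texts(
    ["Quivr is an open-source RAG framework."],
    DeterministicFakeEmbedding(size=64),
)

rag = QuivrQARAGLangGraph(
    retrieval_config=RetrievalConfig(),
    llm=LLMEndpoint.from_config(LLMEndpointConfig()),  # assumed constructor; reads OPENAI_API_KEY
    vector_store=vector_store,
)
response = rag.answer(
    question="What is Quivr?",
    history=ChatHistory(chat_id=uuid4(), brain_id=None),
    list_files=[],
)
print(response.answer)
```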
core/quivr_core/rag/__init__.py (new, empty file)
core/quivr_core/rag/entities/__init__.py (new, empty file)
@@ -1,11 +1,10 @@
-from copy import deepcopy
 from datetime import datetime
-from typing import Any, Generator, List, Tuple
+from typing import Any, Generator, Tuple, List
 from uuid import UUID, uuid4

 from langchain_core.messages import AIMessage, HumanMessage

-from quivr_core.models import ChatMessage
+from quivr_core.rag.entities.models import ChatMessage


 class ChatHistory:
@@ -98,17 +97,3 @@ class ChatHistory:
         """
         return [_msg.msg for _msg in self._msgs]

-
-    def __deepcopy__(self, memo):
-        """
-        Support for deepcopy of ChatHistory.
-        This method ensures that mutable objects (like lists) are copied deeply.
-        """
-        # Create a new instance of ChatHistory
-        new_copy = ChatHistory(self.id, deepcopy(self.brain_id, memo))
-
-        # Perform a deepcopy of the _msgs list
-        new_copy._msgs = deepcopy(self._msgs, memo)
-
-        # Return the deep copied instance
-        return new_copy
core/quivr_core/rag/entities/config.py (new file, 452 lines)
@@ -0,0 +1,452 @@
import os
import re
import logging
from enum import Enum
from typing import Dict, Hashable, List, Optional, Union, Any, Type
from uuid import UUID
from pydantic import BaseModel
from langgraph.graph import START, END
from langchain_core.tools import BaseTool
from quivr_core.config import MegaparseConfig
from rapidfuzz import process, fuzz

from quivr_core.base_config import QuivrBaseConfig
from quivr_core.processor.splitter import SplitterConfig
from quivr_core.rag.prompts import CustomPromptsModel
from quivr_core.llm_tools.llm_tools import LLMToolFactory, TOOLS_CATEGORIES, TOOLS_LISTS

logger = logging.getLogger("quivr_core")


def normalize_to_env_variable_name(name: str) -> str:
    # Replace any character that is not a letter, digit, or underscore with an underscore
    env_variable_name = re.sub(r"[^A-Za-z0-9_]", "_", name).upper()

    # Check if the normalized name starts with a digit
    if env_variable_name[0].isdigit():
        raise ValueError(
            f"Invalid environment variable name '{env_variable_name}': Cannot start with a digit."
        )

    return env_variable_name


class SpecialEdges(str, Enum):
    start = "START"
    end = "END"


class BrainConfig(QuivrBaseConfig):
    brain_id: UUID | None = None
    name: str

    @property
    def id(self) -> UUID | None:
        return self.brain_id


class DefaultWebSearchTool(str, Enum):
    TAVILY = "tavily"


class DefaultRerankers(str, Enum):
    COHERE = "cohere"
    JINA = "jina"
    # MIXEDBREAD = "mixedbread-ai"

    @property
    def default_model(self) -> str:
        # Mapping of suppliers to their default models
        return {
            self.COHERE: "rerank-multilingual-v3.0",
            self.JINA: "jina-reranker-v2-base-multilingual",
            # self.MIXEDBREAD: "rmxbai-rerank-large-v1",
        }[self]


class DefaultModelSuppliers(str, Enum):
    OPENAI = "openai"
    AZURE = "azure"
    ANTHROPIC = "anthropic"
    META = "meta"
    MISTRAL = "mistral"
    GROQ = "groq"


class LLMConfig(QuivrBaseConfig):
    context: int | None = None
    tokenizer_hub: str | None = None


class LLMModelConfig:
    _model_defaults: Dict[DefaultModelSuppliers, Dict[str, LLMConfig]] = {
        DefaultModelSuppliers.OPENAI: {
            "gpt-4o": LLMConfig(context=128000, tokenizer_hub="Xenova/gpt-4o"),
            "gpt-4o-mini": LLMConfig(context=128000, tokenizer_hub="Xenova/gpt-4o"),
            "gpt-4-turbo": LLMConfig(context=128000, tokenizer_hub="Xenova/gpt-4"),
            "gpt-4": LLMConfig(context=8192, tokenizer_hub="Xenova/gpt-4"),
            "gpt-3.5-turbo": LLMConfig(
                context=16385, tokenizer_hub="Xenova/gpt-3.5-turbo"
            ),
            "text-embedding-3-large": LLMConfig(
                context=8191, tokenizer_hub="Xenova/text-embedding-ada-002"
            ),
            "text-embedding-3-small": LLMConfig(
                context=8191, tokenizer_hub="Xenova/text-embedding-ada-002"
            ),
            "text-embedding-ada-002": LLMConfig(
                context=8191, tokenizer_hub="Xenova/text-embedding-ada-002"
            ),
        },
        DefaultModelSuppliers.ANTHROPIC: {
            "claude-3-5-sonnet": LLMConfig(
                context=200000, tokenizer_hub="Xenova/claude-tokenizer"
            ),
            "claude-3-opus": LLMConfig(
                context=200000, tokenizer_hub="Xenova/claude-tokenizer"
            ),
            "claude-3-sonnet": LLMConfig(
                context=200000, tokenizer_hub="Xenova/claude-tokenizer"
            ),
            "claude-3-haiku": LLMConfig(
                context=200000, tokenizer_hub="Xenova/claude-tokenizer"
            ),
            "claude-2-1": LLMConfig(
                context=200000, tokenizer_hub="Xenova/claude-tokenizer"
            ),
            "claude-2-0": LLMConfig(
                context=100000, tokenizer_hub="Xenova/claude-tokenizer"
            ),
            "claude-instant-1-2": LLMConfig(
                context=100000, tokenizer_hub="Xenova/claude-tokenizer"
            ),
        },
        DefaultModelSuppliers.META: {
            "llama-3.1": LLMConfig(
                context=128000, tokenizer_hub="Xenova/Meta-Llama-3.1-Tokenizer"
            ),
            "llama-3": LLMConfig(
                context=8192, tokenizer_hub="Xenova/llama3-tokenizer-new"
            ),
            "llama-2": LLMConfig(context=4096, tokenizer_hub="Xenova/llama2-tokenizer"),
            "code-llama": LLMConfig(
                context=16384, tokenizer_hub="Xenova/llama-code-tokenizer"
            ),
        },
        DefaultModelSuppliers.GROQ: {
            "llama-3.1": LLMConfig(
                context=128000, tokenizer_hub="Xenova/Meta-Llama-3.1-Tokenizer"
            ),
            "llama-3": LLMConfig(
                context=8192, tokenizer_hub="Xenova/llama3-tokenizer-new"
            ),
            "llama-2": LLMConfig(context=4096, tokenizer_hub="Xenova/llama2-tokenizer"),
            "code-llama": LLMConfig(
                context=16384, tokenizer_hub="Xenova/llama-code-tokenizer"
            ),
        },
        DefaultModelSuppliers.MISTRAL: {
            "mistral-large": LLMConfig(
                context=128000, tokenizer_hub="Xenova/mistral-tokenizer-v3"
            ),
            "mistral-small": LLMConfig(
                context=128000, tokenizer_hub="Xenova/mistral-tokenizer-v3"
            ),
            "mistral-nemo": LLMConfig(
                context=128000, tokenizer_hub="Xenova/Mistral-Nemo-Instruct-Tokenizer"
            ),
            "codestral": LLMConfig(
                context=32000, tokenizer_hub="Xenova/mistral-tokenizer-v3"
            ),
        },
    }

    @classmethod
    def get_supplier_by_model_name(cls, model: str) -> DefaultModelSuppliers | None:
        # Iterate over the suppliers and their models
        for supplier, models in cls._model_defaults.items():
            # Check if the model name or a base part of the model name is in the supplier's models
            for base_model_name in models:
                if model.startswith(base_model_name):
                    return supplier
        # Return None if no supplier matches the model name
        return None

    @classmethod
    def get_llm_model_config(
        cls, supplier: DefaultModelSuppliers, model_name: str
    ) -> Optional[LLMConfig]:
        """Retrieve the LLMConfig (context and tokenizer_hub) for a given supplier and model."""
        supplier_defaults = cls._model_defaults.get(supplier)
        if not supplier_defaults:
            return None

        # Use startswith logic for matching model names
        for key, config in supplier_defaults.items():
            if model_name.startswith(key):
                return config

        return None


class LLMEndpointConfig(QuivrBaseConfig):
    supplier: DefaultModelSuppliers = DefaultModelSuppliers.OPENAI
    model: str = "gpt-4o"
    context_length: int | None = None
    tokenizer_hub: str | None = None
    llm_base_url: str | None = None
    env_variable_name: str | None = None
    llm_api_key: str | None = None
    max_context_tokens: int = 2000
    max_output_tokens: int = 2000
    temperature: float = 0.7
    streaming: bool = True
    prompt: CustomPromptsModel | None = None

    _FALLBACK_TOKENIZER = "cl100k_base"

    @property
    def fallback_tokenizer(self) -> str:
        return self._FALLBACK_TOKENIZER

    def __init__(self, **data):
        super().__init__(**data)
        self.set_llm_model_config()
        self.set_api_key()

    def set_api_key(self, force_reset: bool = False):
        if not self.supplier:
            return

        # Check if the corresponding API key environment variable is set
        if not self.env_variable_name:
            self.env_variable_name = (
                f"{normalize_to_env_variable_name(self.supplier)}_API_KEY"
            )

        if not self.llm_api_key or force_reset:
            self.llm_api_key = os.getenv(self.env_variable_name)

        if not self.llm_api_key:
            logger.warning(f"The API key for supplier '{self.supplier}' is not set. ")

    def set_llm_model_config(self):
        # Automatically set context_length and tokenizer_hub based on the supplier and model
        llm_model_config = LLMModelConfig.get_llm_model_config(
            self.supplier, self.model
        )
        if llm_model_config:
            self.context_length = llm_model_config.context
            self.tokenizer_hub = llm_model_config.tokenizer_hub

    def set_llm_model(self, model: str):
        supplier = LLMModelConfig.get_supplier_by_model_name(model)
        if supplier is None:
            raise ValueError(
                f"Cannot find the corresponding supplier for model {model}"
            )
        self.supplier = supplier
        self.model = model

        self.set_llm_model_config()
        self.set_api_key(force_reset=True)

    def set_from_sqlmodel(self, sqlmodel: BaseModel, mapping: Dict[str, str]):
        """
        Set attributes in LLMEndpointConfig from Model attributes using a field mapping.

        :param model_instance: An instance of the Model class.
        :param mapping: A dictionary that maps Model fields to LLMEndpointConfig fields.
        Example: {"max_input": "max_input_tokens", "env_variable_name": "env_variable_name"}
        """
        for model_field, llm_field in mapping.items():
            if hasattr(sqlmodel, model_field) and hasattr(self, llm_field):
                setattr(self, llm_field, getattr(sqlmodel, model_field))
            else:
                raise AttributeError(
                    f"Invalid mapping: {model_field} or {llm_field} does not exist."
                )


# Cannot use Pydantic v2 field_validator because of conflicts with pydantic v1 still in use in LangChain
class RerankerConfig(QuivrBaseConfig):
    supplier: DefaultRerankers | None = None
    model: str | None = None
    top_n: int = 5  # Number of chunks returned by the re-ranker
    api_key: str | None = None
    relevance_score_threshold: float | None = None
    relevance_score_key: str = "relevance_score"

    def __init__(self, **data):
        super().__init__(**data)  # Call Pydantic's BaseModel init
        self.validate_model()  # Automatically call external validation

    def validate_model(self):
        # If model is not provided, get default model based on supplier
        if self.model is None and self.supplier is not None:
            self.model = self.supplier.default_model

        # Check if the corresponding API key environment variable is set
        if self.supplier:
            api_key_var = f"{normalize_to_env_variable_name(self.supplier)}_API_KEY"
            self.api_key = os.getenv(api_key_var)

            if self.api_key is None:
                raise ValueError(
                    f"The API key for supplier '{self.supplier}' is not set. "
                    f"Please set the environment variable: {api_key_var}"
                )


class ConditionalEdgeConfig(QuivrBaseConfig):
    routing_function: str
    conditions: Union[list, Dict[Hashable, str]]

    def __init__(self, **data):
        super().__init__(**data)
        self.resolve_special_edges()

    def resolve_special_edges(self):
        """Replace SpecialEdges enum values with their corresponding langgraph values."""

        if isinstance(self.conditions, dict):
            # If conditions is a dictionary, iterate through the key-value pairs
            for key, value in self.conditions.items():
                if value == SpecialEdges.end:
                    self.conditions[key] = END
                elif value == SpecialEdges.start:
                    self.conditions[key] = START
        elif isinstance(self.conditions, list):
            # If conditions is a list, iterate through the values
            for index, value in enumerate(self.conditions):
                if value == SpecialEdges.end:
                    self.conditions[index] = END
                elif value == SpecialEdges.start:
                    self.conditions[index] = START


class NodeConfig(QuivrBaseConfig):
    name: str
    edges: List[str] | None = None
    conditional_edge: ConditionalEdgeConfig | None = None
    tools: List[Dict[str, Any]] | None = None
    instantiated_tools: List[BaseTool | Type] | None = None

    def __init__(self, **data):
        super().__init__(**data)
        self._instantiate_tools()
        self.resolve_special_edges_in_name_and_edges()

    def resolve_special_edges_in_name_and_edges(self):
        """Replace SpecialEdges enum values in name and edges with corresponding langgraph values."""
        if self.name == SpecialEdges.start:
            self.name = START
        elif self.name == SpecialEdges.end:
            self.name = END

        if self.edges:
            for i, edge in enumerate(self.edges):
                if edge == SpecialEdges.start:
                    self.edges[i] = START
                elif edge == SpecialEdges.end:
                    self.edges[i] = END

    def _instantiate_tools(self):
        """Instantiate tools based on the configuration."""
        if self.tools:
            self.instantiated_tools = [
                LLMToolFactory.create_tool(tool_config.pop("name"), tool_config)
                for tool_config in self.tools
            ]


class DefaultWorkflow(str, Enum):
    RAG = "rag"

    @property
    def nodes(self) -> List[NodeConfig]:
        # Mapping of workflow types to their default node configurations
        workflows = {
            self.RAG: [
                NodeConfig(name=START, edges=["filter_history"]),
                NodeConfig(name="filter_history", edges=["rewrite"]),
                NodeConfig(name="rewrite", edges=["retrieve"]),
                NodeConfig(name="retrieve", edges=["generate_rag"]),
                NodeConfig(name="generate_rag", edges=[END]),
            ]
        }
        return workflows[self]


class WorkflowConfig(QuivrBaseConfig):
    name: str | None = None
    nodes: List[NodeConfig] = []
    available_tools: List[str] | None = None
    validated_tools: List[BaseTool | Type] = []
    activated_tools: List[BaseTool | Type] = []

    def __init__(self, **data):
        super().__init__(**data)
        self.check_first_node_is_start()
        self.validate_available_tools()

    def check_first_node_is_start(self):
        if self.nodes and self.nodes[0].name != START:
            raise ValueError(f"The first node should be a {SpecialEdges.start} node")

    def get_node_tools(self, node_name: str) -> List[Any]:
        """Get tools for a specific node."""
        for node in self.nodes:
            if node.name == node_name and node.instantiated_tools:
                return node.instantiated_tools
        return []

    def validate_available_tools(self):
        if self.available_tools:
            valid_tools = list(TOOLS_CATEGORIES.keys()) + list(TOOLS_LISTS.keys())
            for tool in self.available_tools:
                if tool.lower() in valid_tools:
                    self.validated_tools.append(
                        LLMToolFactory.create_tool(tool, {}).tool
                    )
                else:
                    matches = process.extractOne(
                        tool.lower(), valid_tools, scorer=fuzz.WRatio
                    )
                    if matches:
                        raise ValueError(
                            f"Tool {tool} is not a valid ToolsCategory or ToolsList. Did you mean {matches[0]}?"
                        )
                    else:
                        raise ValueError(
                            f"Tool {tool} is not a valid ToolsCategory or ToolsList"
                        )


class RetrievalConfig(QuivrBaseConfig):
    reranker_config: RerankerConfig = RerankerConfig()
    llm_config: LLMEndpointConfig = LLMEndpointConfig()
    max_history: int = 10
    max_files: int = 20
    k: int = 40  # Number of chunks returned by the retriever
    prompt: str | None = None
    workflow_config: WorkflowConfig = WorkflowConfig(nodes=DefaultWorkflow.RAG.nodes)

    def __init__(self, **data):
        super().__init__(**data)
        self.llm_config.set_api_key(force_reset=True)


class ParserConfig(QuivrBaseConfig):
    splitter_config: SplitterConfig = SplitterConfig()
    megaparse_config: MegaparseConfig = MegaparseConfig()


class IngestionConfig(QuivrBaseConfig):
    parser_config: ParserConfig = ParserConfig()


class AssistantConfig(QuivrBaseConfig):
    retrieval_config: RetrievalConfig = RetrievalConfig()
    ingestion_config: IngestionConfig = IngestionConfig()
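A hedged sketch of how these config classes compose into a retrieval configuration with the default RAG workflow plus a declared web-search tool. The `"web search"` identifier is an assumption about the keys exposed by `TOOLS_CATEGORIES`, and validating it instantiates the Tavily tool, which needs `TAVILY_API_KEY`:

```python
# Sketch under assumptions: tool identifier and API-key availability are not guaranteed.
from quivr_core.rag.entities.config import (
    LLMEndpointConfig,
    NodeConfig,
    RetrievalConfig,
    WorkflowConfig,
)

workflow = WorkflowConfig(
    name="rag_with_web_search",
    nodes=[
        NodeConfig(name="START", edges=["filter_history"]),
        NodeConfig(name="filter_history", edges=["rewrite"]),
        NodeConfig(name="rewrite", edges=["retrieve"]),
        NodeConfig(name="retrieve", edges=["generate_rag"]),
        NodeConfig(name="generate_rag", edges=["END"]),
    ],
    available_tools=["web search"],  # fuzzy-validated against TOOLS_CATEGORIES / TOOLS_LISTS
)

config = RetrievalConfig(
    llm_config=LLMEndpointConfig(model="gpt-4o-mini"),
    max_history=10,
    k=40,
    workflow_config=workflow,
)
print(config.llm_config.tokenizer_hub)  # "Xenova/gpt-4o", resolved via LLMModelConfig
```

A misspelled tool name (e.g. `"web serch"`) would raise the "Did you mean ...?" error produced by the rapidfuzz lookup in `validate_available_tools`.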
@@ -7,7 +7,7 @@ from langchain_core.documents import Document
 from langchain_core.messages import AIMessage, HumanMessage
 from langchain_core.pydantic_v1 import BaseModel as BaseModelV1
 from langchain_core.pydantic_v1 import Field as FieldV1
-from pydantic import BaseModel
+from pydantic import BaseModel, Field
 from typing_extensions import TypedDict


@@ -73,9 +73,9 @@ class ChatLLMMetadata(BaseModel):


 class RAGResponseMetadata(BaseModel):
-    citations: list[int] | None = None
-    followup_questions: list[str] | None = None
-    sources: list[Any] | None = None
+    citations: list[int] = Field(default_factory=list)
+    followup_questions: list[str] = Field(default_factory=list)
+    sources: list[Any] = Field(default_factory=list)
     metadata_model: ChatLLMMetadata | None = None

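The switch from `None` defaults to `Field(default_factory=list)` means callers can mutate the metadata lists without guarding against `None`, for example:

```python
from quivr_core.rag.entities.models import RAGResponseMetadata

meta = RAGResponseMetadata()         # no arguments needed
meta.citations.append(2)             # safe: defaults are empty lists, not None
print(meta.citations, meta.sources)  # [2] []
```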
core/quivr_core/rag/prompts.py (new file, 265 lines)
@ -0,0 +1,265 @@
|
|||||||
|
import datetime
|
||||||
|
from pydantic import ConfigDict, create_model
|
||||||
|
|
||||||
|
from langchain_core.prompts.base import BasePromptTemplate
|
||||||
|
from langchain_core.prompts import (
|
||||||
|
ChatPromptTemplate,
|
||||||
|
HumanMessagePromptTemplate,
|
||||||
|
PromptTemplate,
|
||||||
|
SystemMessagePromptTemplate,
|
||||||
|
MessagesPlaceholder,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class CustomPromptsDict(dict):
|
||||||
|
def __init__(self, type, *args, **kwargs):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
self._type = type
|
||||||
|
|
||||||
|
def __setitem__(self, key, value):
|
||||||
|
# Automatically convert the value into a tuple (my_type, value)
|
||||||
|
super().__setitem__(key, (self._type, value))
|
||||||
|
|
||||||
|
|
||||||
|
def _define_custom_prompts() -> CustomPromptsDict:
|
||||||
|
custom_prompts: CustomPromptsDict = CustomPromptsDict(type=BasePromptTemplate)
|
||||||
|
|
||||||
|
today_date = datetime.datetime.now().strftime("%B %d, %Y")
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Prompt for question rephrasing
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
system_message_template = (
|
||||||
|
"Given a chat history and the latest user question "
|
||||||
|
"which might reference context in the chat history, "
|
||||||
|
"formulate a standalone question which can be understood "
|
||||||
|
"without the chat history. Do NOT answer the question, "
|
||||||
|
"just reformulate it if needed and otherwise return it as is. "
|
||||||
|
"Do not output your reasoning, just the question."
|
||||||
|
)
|
||||||
|
|
||||||
|
template_answer = "User question: {question}\n Standalone question:"
|
||||||
|
|
||||||
|
CONDENSE_QUESTION_PROMPT = ChatPromptTemplate.from_messages(
|
||||||
|
[
|
||||||
|
SystemMessagePromptTemplate.from_template(system_message_template),
|
||||||
|
MessagesPlaceholder(variable_name="chat_history"),
|
||||||
|
HumanMessagePromptTemplate.from_template(template_answer),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
custom_prompts["CONDENSE_QUESTION_PROMPT"] = CONDENSE_QUESTION_PROMPT
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Prompt for RAG
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
system_message_template = f"Your name is Quivr. You're a helpful assistant. Today's date is {today_date}. "
|
||||||
|
|
||||||
|
system_message_template += (
|
||||||
|
"- When answering use markdown. Use markdown code blocks for code snippets.\n"
|
||||||
|
"- Answer in a concise and clear manner.\n"
|
||||||
|
"- If no preferred language is provided, answer in the same language as the language used by the user.\n"
|
||||||
|
"- You must use ONLY the provided context to answer the question. "
|
||||||
|
"Do not use any prior knowledge or external information, even if you are certain of the answer.\n"
|
||||||
|
"- If you cannot provide an answer using ONLY the context provided, do not attempt to answer from your own knowledge."
|
||||||
|
"Instead, inform the user that the answer isn't available in the context and suggest using the available tools {tools}.\n"
|
||||||
|
"- Do not apologize when providing an answer.\n"
|
||||||
|
"- Don't cite the source id in the answer objects, but you can use the source to answer the question.\n\n"
|
||||||
|
)
|
||||||
|
|
||||||
|
context_template = (
|
||||||
|
"\n"
|
||||||
|
"- You have access to the following internal reasoning to provide an answer: {reasoning}\n"
|
||||||
|
"- You have access to the following files to answer the user question (limited to first 20 files): {files}\n"
|
||||||
|
"- You have access to the following context to answer the user question: {context}\n"
|
||||||
|
"- Follow these user instruction when crafting the answer: {custom_instructions}\n"
|
||||||
|
"- These user instructions shall take priority over any other previous instruction.\n"
|
||||||
|
"- Remember: if you cannot provide an answer using ONLY the provided context and CITING the sources, "
|
||||||
|
"inform the user that you don't have the answer and consider if any of the tools can help answer the question.\n"
|
||||||
|
"- Explain your reasoning about the potentiel tool usage in the answer.\n"
|
||||||
|
"- Only use binded tools to answer the question.\n"
|
||||||
|
# "OFFER the user the possibility to ACTIVATE a relevant tool among "
|
||||||
|
# "the tools which can be activated."
|
||||||
|
# "Tools which can be activated: {tools}. If any of these tools can help in providing an answer "
|
||||||
|
# "to the user question, you should offer the user the possibility to activate it. "
|
||||||
|
# "Remember, you shall NOT use the above tools, ONLY offer the user the possibility to activate them.\n"
|
||||||
|
)
|
||||||
|
|
||||||
|
template_answer = (
|
||||||
|
"Original task: {question}\n"
|
||||||
|
"Rephrased and contextualized task: {rephrased_task}\n"
|
||||||
|
)
|
||||||
|
|
||||||
|
RAG_ANSWER_PROMPT = ChatPromptTemplate.from_messages(
|
||||||
|
[
|
||||||
|
SystemMessagePromptTemplate.from_template(system_message_template),
|
||||||
|
MessagesPlaceholder(variable_name="chat_history"),
|
||||||
|
SystemMessagePromptTemplate.from_template(context_template),
|
||||||
|
HumanMessagePromptTemplate.from_template(template_answer),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
custom_prompts["RAG_ANSWER_PROMPT"] = RAG_ANSWER_PROMPT
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Prompt for formatting documents
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
DEFAULT_DOCUMENT_PROMPT = PromptTemplate.from_template(
|
||||||
|
template="Filename: {original_file_name}\nSource: {index} \n {page_content}"
|
||||||
|
)
|
||||||
|
custom_prompts["DEFAULT_DOCUMENT_PROMPT"] = DEFAULT_DOCUMENT_PROMPT
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Prompt for chatting directly with LLMs, without any document retrieval stage
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
system_message_template = (
|
||||||
|
f"Your name is Quivr. You're a helpful assistant. Today's date is {today_date}."
|
||||||
|
)
|
||||||
|
system_message_template += """
|
||||||
|
If not None, also follow these user instructions when answering: {custom_instructions}
|
||||||
|
"""
|
||||||
|
|
||||||
|
template_answer = """
|
||||||
|
User Question: {question}
|
||||||
|
Answer:
|
||||||
|
"""
|
||||||
|
CHAT_LLM_PROMPT = ChatPromptTemplate.from_messages(
|
||||||
|
[
|
||||||
|
SystemMessagePromptTemplate.from_template(system_message_template),
|
||||||
|
MessagesPlaceholder(variable_name="chat_history"),
|
||||||
|
HumanMessagePromptTemplate.from_template(template_answer),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
custom_prompts["CHAT_LLM_PROMPT"] = CHAT_LLM_PROMPT
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Prompt to understand the user intent
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
system_message_template = (
|
||||||
|
"Given the following user input, determine the user intent, in particular "
|
||||||
|
"whether the user is providing instructions to the system or is asking the system to "
|
||||||
|
"execute a task:\n"
|
||||||
|
" - if the user is providing direct instructions to modify the system behaviour (for instance, "
|
||||||
|
"'Can you reply in French?' or 'Answer in French' or 'You are an expert legal assistant' "
|
||||||
|
"or 'You will behave as...'), the user intent is 'prompt';\n"
|
||||||
|
" - in all other cases (asking questions, asking for summarising a text, asking for translating a text, ...), "
|
||||||
|
"the intent is 'task'.\n"
|
||||||
|
)
|
||||||
|
|
||||||
|
template_answer = "User input: {question}"
|
||||||
|
|
||||||
|
USER_INTENT_PROMPT = ChatPromptTemplate.from_messages(
|
||||||
|
[
|
||||||
|
SystemMessagePromptTemplate.from_template(system_message_template),
|
||||||
|
HumanMessagePromptTemplate.from_template(template_answer),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
custom_prompts["USER_INTENT_PROMPT"] = USER_INTENT_PROMPT
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Prompt to create a system prompt from user instructions
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
system_message_template = (
|
||||||
|
"- Given the following user instruction, current system prompt, list of available tools "
|
||||||
|
"and list of activated tools, update the prompt to include the instruction and decide which tools to activate.\n"
|
||||||
|
"- The prompt shall only contain generic instructions which can be applied to any user task or question.\n"
|
||||||
|
"- The prompt shall be concise and clear.\n"
|
||||||
|
"- If the system prompt already contains the instruction, do not add it again.\n"
|
||||||
|
"- If the system prompt contradicts ther user instruction, remove the contradictory "
|
||||||
|
"statement or statements in the system prompt.\n"
|
||||||
|
"- You shall return separately the updated system prompt and the reasoning that led to the update.\n"
|
||||||
|
"- If the system prompt refers to a tool, you shall add the tool to the list of activated tools.\n"
|
||||||
|
"- If no tool activation is needed, return empty lists.\n"
|
||||||
|
"- You shall also return the reasoning that led to the tool activation.\n"
|
||||||
|
"- Current system prompt: {system_prompt}\n"
|
||||||
|
"- List of available tools: {available_tools}\n"
|
||||||
|
"- List of activated tools: {activated_tools}\n\n"
|
||||||
|
)
|
||||||
|
|
||||||
|
template_answer = "User instructions: {instruction}\n"
|
||||||
|
|
||||||
|
UPDATE_PROMPT = ChatPromptTemplate.from_messages(
|
||||||
|
[
|
||||||
|
SystemMessagePromptTemplate.from_template(system_message_template),
|
||||||
|
HumanMessagePromptTemplate.from_template(template_answer),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
custom_prompts["UPDATE_PROMPT"] = UPDATE_PROMPT
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Prompt to split the user input into multiple questions / instructions
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
system_message_template = (
|
||||||
|
"Given a chat history and the user input, split and rephrase the input into instructions and tasks.\n"
|
||||||
|
"- Instructions direct the system to behave in a certain way or to use specific tools: examples of instructions are "
|
||||||
|
"'Can you reply in French?', 'Answer in French', 'You are an expert legal assistant', "
|
||||||
|
"'You will behave as...', 'Use web search').\n"
|
||||||
|
"- You shall collect and condense all the instructions into a single string.\n"
|
||||||
|
"- The instructions shall be standalone and self-contained, so that they can be understood "
|
||||||
|
"without the chat history. If no instructions are found, return an empty string.\n"
|
||||||
|
"- Instructions to be understood may require considering the chat history.\n"
|
||||||
|
"- Tasks are often questions, but they can also be summarisation tasks, translation tasks, content generation tasks, etc.\n"
|
||||||
|
"- Tasks to be understood may require considering the chat history.\n"
|
||||||
|
"- If the user input contains different tasks, you shall split the input into multiple tasks.\n"
|
||||||
|
"- Each splitted task shall be a standalone, self-contained task which can be understood "
|
||||||
|
"without the chat history. You shall rephrase the tasks if needed.\n"
|
||||||
|
"- If no explicit task is present, you shall infer the tasks from the user input and the chat history.\n"
|
||||||
|
"- Do NOT try to solve the tasks or answer the questions, "
|
||||||
|
"just reformulate them if needed and otherwise return them as is.\n"
|
||||||
|
"- Remember, you shall NOT suggest or generate new tasks.\n"
|
||||||
|
"- As an example, the user input 'What is Apple? Who is its CEO? When was it founded?' "
|
||||||
|
"shall be split into the questions 'What is Apple?', 'Who is the CEO of Apple?' and 'When was Apple founded?'.\n"
|
||||||
|
"- If no tasks are found, return the user input as is in the task.\n"
|
||||||
|
)
|
||||||
|
|
||||||
|
template_answer = "User input: {user_input}"
|
||||||
|
|
||||||
|
SPLIT_PROMPT = ChatPromptTemplate.from_messages(
|
||||||
|
[
|
||||||
|
SystemMessagePromptTemplate.from_template(system_message_template),
|
||||||
|
MessagesPlaceholder(variable_name="chat_history"),
|
||||||
|
HumanMessagePromptTemplate.from_template(template_answer),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
custom_prompts["SPLIT_PROMPT"] = SPLIT_PROMPT
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Prompt to grade the relevance of an answer and decide whether to perform a web search
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
system_message_template = (
|
||||||
|
"Given the following tasks you shall determine whether all tasks can be "
|
||||||
|
"completed fully and in the best possible way using the provided context and chat history. "
|
||||||
|
"You shall:\n"
|
||||||
|
"- Consider each task separately,\n"
|
||||||
|
"- Determine whether the context and chat history contain "
|
||||||
|
"all the information necessary to complete the task.\n"
|
||||||
|
"- If the context and chat history do not contain all the information necessary to complete the task, "
|
||||||
|
"consider ONLY the list of tools below and select the tool most appropriate to complete the task.\n"
|
||||||
|
"- If no tools are listed, return the tasks as is and no tool.\n"
|
||||||
|
"- If no relevant tool can be selected, return the tasks as is and no tool.\n"
|
||||||
|
"- Do not propose to use a tool if that tool is not listed among the available tools.\n"
|
||||||
|
)
|
||||||
|
|
||||||
|
context_template = "Context: {context}\n {activated_tools}\n"
|
||||||
|
|
||||||
|
template_answer = "Tasks: {tasks}\n"
|
||||||
|
|
||||||
|
TOOL_ROUTING_PROMPT = ChatPromptTemplate.from_messages(
|
||||||
|
[
|
||||||
|
SystemMessagePromptTemplate.from_template(system_message_template),
|
||||||
|
MessagesPlaceholder(variable_name="chat_history"),
|
||||||
|
SystemMessagePromptTemplate.from_template(context_template),
|
||||||
|
HumanMessagePromptTemplate.from_template(template_answer),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
custom_prompts["TOOL_ROUTING_PROMPT"] = TOOL_ROUTING_PROMPT
|
||||||
|
|
||||||
|
return custom_prompts
|
||||||
|
|
||||||
|
|
||||||
|
_custom_prompts = _define_custom_prompts()
|
||||||
|
CustomPromptsModel = create_model(
|
||||||
|
"CustomPromptsModel", **_custom_prompts, __config__=ConfigDict(extra="forbid")
|
||||||
|
)
|
||||||
|
|
||||||
|
custom_prompts = CustomPromptsModel()
|
@ -12,18 +12,18 @@ from langchain_core.output_parsers import StrOutputParser
|
|||||||
from langchain_core.runnables import RunnableLambda, RunnablePassthrough
|
from langchain_core.runnables import RunnableLambda, RunnablePassthrough
|
||||||
from langchain_core.vectorstores import VectorStore
|
from langchain_core.vectorstores import VectorStore
|
||||||
|
|
||||||
from quivr_core.chat import ChatHistory
|
from quivr_core.rag.entities.chat import ChatHistory
|
||||||
from quivr_core.config import RetrievalConfig
|
from quivr_core.rag.entities.config import RetrievalConfig
|
||||||
from quivr_core.llm import LLMEndpoint
|
from quivr_core.llm import LLMEndpoint
|
||||||
from quivr_core.models import (
|
from quivr_core.rag.entities.models import (
|
||||||
ParsedRAGChunkResponse,
|
ParsedRAGChunkResponse,
|
||||||
ParsedRAGResponse,
|
ParsedRAGResponse,
|
||||||
QuivrKnowledge,
|
QuivrKnowledge,
|
||||||
RAGResponseMetadata,
|
RAGResponseMetadata,
|
||||||
cited_answer,
|
cited_answer,
|
||||||
)
|
)
|
||||||
from quivr_core.prompts import custom_prompts
|
from quivr_core.rag.prompts import custom_prompts
|
||||||
from quivr_core.utils import (
|
from quivr_core.rag.utils import (
|
||||||
combine_documents,
|
combine_documents,
|
||||||
format_file_list,
|
format_file_list,
|
||||||
get_chunk_metadata,
|
get_chunk_metadata,
|
891
core/quivr_core/rag/quivr_rag_langgraph.py
Normal file
891
core/quivr_core/rag/quivr_rag_langgraph.py
Normal file
@ -0,0 +1,891 @@
|
|||||||
|
import logging
|
||||||
|
from typing import (
|
||||||
|
Annotated,
|
||||||
|
AsyncGenerator,
|
||||||
|
List,
|
||||||
|
Optional,
|
||||||
|
Sequence,
|
||||||
|
Tuple,
|
||||||
|
TypedDict,
|
||||||
|
Dict,
|
||||||
|
Any,
|
||||||
|
Type,
|
||||||
|
)
|
||||||
|
from uuid import uuid4
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
# TODO(@aminediro): this is the only dependency to langchain package, we should remove it
|
||||||
|
from langchain.retrievers import ContextualCompressionRetriever
|
||||||
|
from langchain_cohere import CohereRerank
|
||||||
|
from langchain_community.document_compressors import JinaRerank
|
||||||
|
from langchain_core.callbacks import Callbacks
|
||||||
|
from langchain_core.documents import BaseDocumentCompressor, Document
|
||||||
|
from langchain_core.messages import BaseMessage
|
||||||
|
from langchain_core.messages.ai import AIMessageChunk
|
||||||
|
from langchain_core.vectorstores import VectorStore
|
||||||
|
from langchain_core.prompts.base import BasePromptTemplate
|
||||||
|
from langgraph.graph import START, END, StateGraph
|
||||||
|
from langgraph.graph.message import add_messages
|
||||||
|
from langgraph.types import Send
|
||||||
|
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
import openai
|
||||||
|
|
||||||
|
from quivr_core.rag.entities.chat import ChatHistory
|
||||||
|
from quivr_core.rag.entities.config import DefaultRerankers, NodeConfig, RetrievalConfig
|
||||||
|
from quivr_core.llm import LLMEndpoint
|
||||||
|
from quivr_core.llm_tools.llm_tools import LLMToolFactory
|
||||||
|
from quivr_core.rag.entities.models import (
|
||||||
|
ParsedRAGChunkResponse,
|
||||||
|
QuivrKnowledge,
|
||||||
|
)
|
||||||
|
from quivr_core.rag.prompts import custom_prompts
|
||||||
|
from quivr_core.rag.utils import (
|
||||||
|
collect_tools,
|
||||||
|
combine_documents,
|
||||||
|
format_file_list,
|
||||||
|
get_chunk_metadata,
|
||||||
|
parse_chunk_response,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = logging.getLogger("quivr_core")
|
||||||
|
|
||||||
|
|
||||||
|
class SplittedInput(BaseModel):
|
||||||
|
instructions_reasoning: Optional[str] = Field(
|
||||||
|
default=None,
|
||||||
|
description="The reasoning that leads to identifying the user instructions to the system",
|
||||||
|
)
|
||||||
|
instructions: Optional[str] = Field(
|
||||||
|
default=None, description="The instructions to the system"
|
||||||
|
)
|
||||||
|
|
||||||
|
tasks_reasoning: Optional[str] = Field(
|
||||||
|
default=None,
|
||||||
|
description="The reasoning that leads to identifying the explicit or implicit user tasks and questions",
|
||||||
|
)
|
||||||
|
tasks: Optional[List[str]] = Field(
|
||||||
|
default_factory=lambda: ["No explicit or implicit tasks found"],
|
||||||
|
description="The list of standalone, self-contained tasks or questions.",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TasksCompletion(BaseModel):
|
||||||
|
completable_tasks_reasoning: Optional[str] = Field(
|
||||||
|
default=None,
|
||||||
|
description="The reasoning that leads to identifying the user tasks or questions that can be completed using the provided context and chat history.",
|
||||||
|
)
|
||||||
|
completable_tasks: Optional[List[str]] = Field(
|
||||||
|
default_factory=list,
|
||||||
|
description="The user tasks or questions that can be completed using the provided context and chat history.",
|
||||||
|
)
|
||||||
|
|
||||||
|
non_completable_tasks_reasoning: Optional[str] = Field(
|
||||||
|
default=None,
|
||||||
|
description="The reasoning that leads to identifying the user tasks or questions that cannot be completed using the provided context and chat history.",
|
||||||
|
)
|
||||||
|
non_completable_tasks: Optional[List[str]] = Field(
|
||||||
|
default_factory=list,
|
||||||
|
description="The user tasks or questions that need a tool to be completed.",
|
||||||
|
)
|
||||||
|
|
||||||
|
tool_reasoning: Optional[str] = Field(
|
||||||
|
default=None,
|
||||||
|
description="The reasoning that leads to identifying the tool that shall be used to complete the tasks.",
|
||||||
|
)
|
||||||
|
tool: Optional[str] = Field(
|
||||||
|
default=None,
|
||||||
|
description="The tool that shall be used to complete the tasks.",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class FinalAnswer(BaseModel):
|
||||||
|
reasoning_answer: str = Field(
|
||||||
|
description="The step-by-step reasoning that led to the final answer"
|
||||||
|
)
|
||||||
|
answer: str = Field(description="The final answer to the user tasks/questions")
|
||||||
|
|
||||||
|
all_tasks_completed: bool = Field(
|
||||||
|
description="Whether all tasks/questions have been successfully answered/completed or not. "
|
||||||
|
" If the final answer to the user is 'I don't know' or 'I don't have enough information' or 'I'm not sure', "
|
||||||
|
" this variable should be 'false'"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class UpdatedPromptAndTools(BaseModel):
|
||||||
|
prompt_reasoning: Optional[str] = Field(
|
||||||
|
default=None,
|
||||||
|
description="The step-by-step reasoning that leads to the updated system prompt",
|
||||||
|
)
|
||||||
|
prompt: Optional[str] = Field(default=None, description="The updated system prompt")
|
||||||
|
|
||||||
|
tools_reasoning: Optional[str] = Field(
|
||||||
|
default=None,
|
||||||
|
description="The reasoning that leads to activating and deactivating the tools",
|
||||||
|
)
|
||||||
|
tools_to_activate: Optional[List[str]] = Field(
|
||||||
|
default_factory=list, description="The list of tools to activate"
|
||||||
|
)
|
||||||
|
tools_to_deactivate: Optional[List[str]] = Field(
|
||||||
|
default_factory=list, description="The list of tools to deactivate"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class AgentState(TypedDict):
|
||||||
|
# The add_messages function defines how an update should be processed
|
||||||
|
# Default is to replace. add_messages says "append"
|
||||||
|
messages: Annotated[Sequence[BaseMessage], add_messages]
|
||||||
|
reasoning: List[str]
|
||||||
|
chat_history: ChatHistory
|
||||||
|
docs: list[Document]
|
||||||
|
files: str
|
||||||
|
tasks: List[str]
|
||||||
|
instructions: str
|
||||||
|
tool: str
|
||||||
|
|
||||||
|
|
||||||
|
class IdempotentCompressor(BaseDocumentCompressor):
|
||||||
|
def compress_documents(
|
||||||
|
self,
|
||||||
|
documents: Sequence[Document],
|
||||||
|
query: str,
|
||||||
|
callbacks: Optional[Callbacks] = None,
|
||||||
|
) -> Sequence[Document]:
|
||||||
|
"""
|
||||||
|
A no-op document compressor that simply returns the documents it is given.
|
||||||
|
|
||||||
|
This is a placeholder until a more sophisticated document compression
|
||||||
|
algorithm is implemented.
|
||||||
|
"""
|
||||||
|
return documents
|
||||||
|
|
||||||
|
|
||||||
|
class QuivrQARAGLangGraph:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
retrieval_config: RetrievalConfig,
|
||||||
|
llm: LLMEndpoint,
|
||||||
|
vector_store: VectorStore | None = None,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Construct a QuivrQARAGLangGraph object.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
retrieval_config (RetrievalConfig): The configuration for the RAG model.
|
||||||
|
llm (LLMEndpoint): The LLM to use for generating text.
|
||||||
|
vector_store (VectorStore): The vector store to use for storing and retrieving documents.
|
||||||
|
Note: the reranker is built from retrieval_config.reranker_config via get_reranker; an IdempotentCompressor is used when no supported reranker supplier is configured.
|
||||||
|
"""
|
||||||
|
self.retrieval_config = retrieval_config
|
||||||
|
self.vector_store = vector_store
|
||||||
|
self.llm_endpoint = llm
|
||||||
|
|
||||||
|
self.graph = None
|
||||||
|
|
||||||
|
def get_reranker(self, **kwargs):
|
||||||
|
# Extract the reranker configuration from self
|
||||||
|
config = self.retrieval_config.reranker_config
|
||||||
|
|
||||||
|
# Allow kwargs to override specific config values
|
||||||
|
supplier = kwargs.pop("supplier", config.supplier)
|
||||||
|
model = kwargs.pop("model", config.model)
|
||||||
|
top_n = kwargs.pop("top_n", config.top_n)
|
||||||
|
api_key = kwargs.pop("api_key", config.api_key)
|
||||||
|
|
||||||
|
if supplier == DefaultRerankers.COHERE:
|
||||||
|
reranker = CohereRerank(
|
||||||
|
model=model, top_n=top_n, cohere_api_key=api_key, **kwargs
|
||||||
|
)
|
||||||
|
elif supplier == DefaultRerankers.JINA:
|
||||||
|
reranker = JinaRerank(
|
||||||
|
model=model, top_n=top_n, jina_api_key=api_key, **kwargs
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
reranker = IdempotentCompressor()
|
||||||
|
|
||||||
|
return reranker
|
||||||
|
|
||||||
|
def get_retriever(self, **kwargs):
|
||||||
|
"""
|
||||||
|
Returns a retriever that can retrieve documents from the vector store.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
VectorStoreRetriever: The retriever.
|
||||||
|
"""
|
||||||
|
if self.vector_store:
|
||||||
|
retriever = self.vector_store.as_retriever(**kwargs)
|
||||||
|
else:
|
||||||
|
raise ValueError("No vector store provided")
|
||||||
|
|
||||||
|
return retriever
|
||||||
|
|
||||||
|
def routing(self, state: AgentState) -> List[Send]:
|
||||||
|
"""
|
||||||
|
The routing function for the RAG model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
state (AgentState): The current state of the agent.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: The next state of the agent.
|
||||||
|
"""
|
||||||
|
|
||||||
|
msg = custom_prompts.SPLIT_PROMPT.format(
|
||||||
|
user_input=state["messages"][0].content,
|
||||||
|
)
|
||||||
|
|
||||||
|
response: SplittedInput
|
||||||
|
|
||||||
|
try:
|
||||||
|
structured_llm = self.llm_endpoint._llm.with_structured_output(
|
||||||
|
SplittedInput, method="json_schema"
|
||||||
|
)
|
||||||
|
response = structured_llm.invoke(msg)
|
||||||
|
|
||||||
|
except openai.BadRequestError:
|
||||||
|
structured_llm = self.llm_endpoint._llm.with_structured_output(
|
||||||
|
SplittedInput
|
||||||
|
)
|
||||||
|
response = structured_llm.invoke(msg)
|
||||||
|
|
||||||
|
send_list: List[Send] = []
|
||||||
|
|
||||||
|
instructions = (
|
||||||
|
response.instructions
|
||||||
|
if response.instructions
|
||||||
|
else self.retrieval_config.prompt
|
||||||
|
)
|
||||||
|
|
||||||
|
if instructions:
|
||||||
|
send_list.append(Send("edit_system_prompt", {"instructions": instructions}))
|
||||||
|
elif response.tasks:
|
||||||
|
chat_history = state["chat_history"]
|
||||||
|
send_list.append(
|
||||||
|
Send(
|
||||||
|
"filter_history",
|
||||||
|
{"chat_history": chat_history, "tasks": response.tasks},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
return send_list
|
||||||
|
|
||||||
|
def routing_split(self, state: AgentState):
|
||||||
|
response = self.invoke_structured_output(
|
||||||
|
custom_prompts.SPLIT_PROMPT.format(
|
||||||
|
chat_history=state["chat_history"].to_list(),
|
||||||
|
user_input=state["messages"][0].content,
|
||||||
|
),
|
||||||
|
SplittedInput,
|
||||||
|
)
|
||||||
|
|
||||||
|
instructions = response.instructions or self.retrieval_config.prompt
|
||||||
|
tasks = response.tasks or []
|
||||||
|
|
||||||
|
if instructions:
|
||||||
|
return [
|
||||||
|
Send(
|
||||||
|
"edit_system_prompt",
|
||||||
|
{**state, "instructions": instructions, "tasks": tasks},
|
||||||
|
)
|
||||||
|
]
|
||||||
|
elif tasks:
|
||||||
|
return [Send("filter_history", {**state, "tasks": tasks})]
|
||||||
|
|
||||||
|
return []
|
||||||
|
|
||||||
|
def update_active_tools(self, updated_prompt_and_tools: UpdatedPromptAndTools):
|
||||||
|
if updated_prompt_and_tools.tools_to_activate:
|
||||||
|
for tool in updated_prompt_and_tools.tools_to_activate:
|
||||||
|
for (
|
||||||
|
validated_tool
|
||||||
|
) in self.retrieval_config.workflow_config.validated_tools:
|
||||||
|
if tool == validated_tool.name:
|
||||||
|
self.retrieval_config.workflow_config.activated_tools.append(
|
||||||
|
validated_tool
|
||||||
|
)
|
||||||
|
|
||||||
|
if updated_prompt_and_tools.tools_to_deactivate:
|
||||||
|
for tool in updated_prompt_and_tools.tools_to_deactivate:
|
||||||
|
for (
|
||||||
|
activated_tool
|
||||||
|
) in self.retrieval_config.workflow_config.activated_tools:
|
||||||
|
if tool == activated_tool.name:
|
||||||
|
self.retrieval_config.workflow_config.activated_tools.remove(
|
||||||
|
activated_tool
|
||||||
|
)
|
||||||
|
|
||||||
|
def edit_system_prompt(self, state: AgentState) -> AgentState:
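"""
Fold the user's standing instructions into the system prompt and let the LLM
decide, via UPDATE_PROMPT and the UpdatedPromptAndTools schema, which tools to
activate or deactivate before answering.
"""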
|
||||||
|
user_instruction = state["instructions"]
|
||||||
|
prompt = self.retrieval_config.prompt
|
||||||
|
available_tools, activated_tools = collect_tools(
|
||||||
|
self.retrieval_config.workflow_config
|
||||||
|
)
|
||||||
|
inputs = {
|
||||||
|
"instruction": user_instruction,
|
||||||
|
"system_prompt": prompt if prompt else "",
|
||||||
|
"available_tools": available_tools,
|
||||||
|
"activated_tools": activated_tools,
|
||||||
|
}
|
||||||
|
|
||||||
|
msg = custom_prompts.UPDATE_PROMPT.format(**inputs)
|
||||||
|
|
||||||
|
response: UpdatedPromptAndTools = self.invoke_structured_output(
|
||||||
|
msg, UpdatedPromptAndTools
|
||||||
|
)
|
||||||
|
|
||||||
|
self.update_active_tools(response)
|
||||||
|
self.retrieval_config.prompt = response.prompt
|
||||||
|
|
||||||
|
reasoning = [response.prompt_reasoning] if response.prompt_reasoning else []
|
||||||
|
reasoning += [response.tools_reasoning] if response.tools_reasoning else []
|
||||||
|
|
||||||
|
return {**state, "messages": [], "reasoning": reasoning}
|
||||||
|
|
||||||
|
def filter_history(self, state: AgentState) -> AgentState:
|
||||||
|
"""
|
||||||
|
Filter the chat history to keep only the messages that are relevant to the current question.
|
||||||
|
|
||||||
|
Takes in a chat_history= [HumanMessage(content='Qui est Chloé ? '),
|
||||||
|
AIMessage(content="Chloé est une salariée travaillant pour l'entreprise Quivr en tant qu'AI Engineer,
|
||||||
|
sous la direction de son supérieur hiérarchique, Stanislas Girard."),
|
||||||
|
HumanMessage(content='Dis moi en plus sur elle'), AIMessage(content=''),
|
||||||
|
HumanMessage(content='Dis moi en plus sur elle'),
|
||||||
|
AIMessage(content="Désolé, je n'ai pas d'autres informations sur Chloé à partir des fichiers fournis.")]
|
||||||
|
Returns a filtered chat_history, enforcing first max_tokens and then max_history, where a Human message and an AI message count as one pair;
|
||||||
|
a token is approximated as 4 characters.
|
||||||
|
"""
|
||||||
|
|
||||||
|
chat_history = state["chat_history"]
|
||||||
|
total_tokens = 0
|
||||||
|
total_pairs = 0
|
||||||
|
_chat_id = uuid4()
|
||||||
|
_chat_history = ChatHistory(chat_id=_chat_id, brain_id=chat_history.brain_id)
|
||||||
|
for human_message, ai_message in reversed(list(chat_history.iter_pairs())):
|
||||||
|
# TODO: replace with tiktoken
|
||||||
|
message_tokens = self.llm_endpoint.count_tokens(
|
||||||
|
human_message.content
|
||||||
|
) + self.llm_endpoint.count_tokens(ai_message.content)
|
||||||
|
if (
|
||||||
|
total_tokens + message_tokens
|
||||||
|
> self.retrieval_config.llm_config.max_context_tokens
|
||||||
|
or total_pairs >= self.retrieval_config.max_history
|
||||||
|
):
|
||||||
|
break
|
||||||
|
_chat_history.append(human_message)
|
||||||
|
_chat_history.append(ai_message)
|
||||||
|
total_tokens += message_tokens
|
||||||
|
total_pairs += 1
|
||||||
|
|
||||||
|
return {**state, "chat_history": _chat_history}
|
||||||
|
|
||||||
|
### Nodes
|
||||||
|
async def rewrite(self, state: AgentState) -> AgentState:
|
||||||
|
"""
|
||||||
|
Transform the query to produce a better question.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
state (messages): The current state
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: The updated state with re-phrased question
|
||||||
|
"""
|
||||||
|
|
||||||
|
tasks = (
|
||||||
|
state["tasks"]
|
||||||
|
if "tasks" in state and state["tasks"]
|
||||||
|
else [state["messages"][0].content]
|
||||||
|
)
|
||||||
|
|
||||||
|
# Prepare the async tasks for all user tasks
|
||||||
|
async_tasks = []
|
||||||
|
for task in tasks:
|
||||||
|
msg = custom_prompts.CONDENSE_QUESTION_PROMPT.format(
|
||||||
|
chat_history=state["chat_history"].to_list(),
|
||||||
|
question=task,
|
||||||
|
)
|
||||||
|
|
||||||
|
model = self.llm_endpoint._llm
|
||||||
|
# Asynchronously invoke the model for each question
|
||||||
|
async_tasks.append(model.ainvoke(msg))
|
||||||
|
|
||||||
|
# Gather all the responses asynchronously
|
||||||
|
responses = await asyncio.gather(*async_tasks) if async_tasks else []
|
||||||
|
|
||||||
|
# Replace each question with its condensed version
|
||||||
|
condensed_questions = []
|
||||||
|
for response in responses:
|
||||||
|
condensed_questions.append(response.content)
|
||||||
|
|
||||||
|
return {**state, "tasks": condensed_questions}
|
||||||
|
|
||||||
|
def filter_chunks_by_relevance(self, chunks: List[Document], **kwargs):
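"""
Drop chunks whose reranker relevance score falls below the configured
threshold; chunks without a score are kept and a warning is logged.
"""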
|
||||||
|
config = self.retrieval_config.reranker_config
|
||||||
|
relevance_score_threshold = kwargs.get(
|
||||||
|
"relevance_score_threshold", config.relevance_score_threshold
|
||||||
|
)
|
||||||
|
|
||||||
|
if relevance_score_threshold is None:
|
||||||
|
return chunks
|
||||||
|
|
||||||
|
filtered_chunks = []
|
||||||
|
for chunk in chunks:
|
||||||
|
if config.relevance_score_key not in chunk.metadata:
|
||||||
|
logger.warning(
|
||||||
|
f"Relevance score key {config.relevance_score_key} not found in metadata, cannot filter chunks by relevance"
|
||||||
|
)
|
||||||
|
filtered_chunks.append(chunk)
|
||||||
|
elif (
|
||||||
|
chunk.metadata[config.relevance_score_key] >= relevance_score_threshold
|
||||||
|
):
|
||||||
|
filtered_chunks.append(chunk)
|
||||||
|
|
||||||
|
return filtered_chunks
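# Illustrative input (assumed metadata layout): reranked chunks carry their
# score under reranker_config.relevance_score_key, e.g. "relevance_score".
# With a threshold of 0.2, the first and third chunks below are kept; the
# second is dropped and the unscored one is kept with a warning.
_example_chunks = [
    Document(page_content="relevant passage", metadata={"relevance_score": 0.82}),
    Document(page_content="weak passage", metadata={"relevance_score": 0.05}),
    Document(page_content="unscored passage", metadata={}),
]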
|
||||||
|
|
||||||
|
def tool_routing(self, state: AgentState):
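"""
Decide whether the retrieved context is enough to complete the tasks or
whether an activated tool should be called first, and route the state to
either run_tool or generate_rag accordingly.
"""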
|
||||||
|
tasks = state["tasks"]
|
||||||
|
if not tasks:
|
||||||
|
return [Send("generate_rag", state)]
|
||||||
|
|
||||||
|
docs = state["docs"]
|
||||||
|
|
||||||
|
_, activated_tools = collect_tools(self.retrieval_config.workflow_config)
|
||||||
|
|
||||||
|
input = {
|
||||||
|
"chat_history": state["chat_history"].to_list(),
|
||||||
|
"tasks": state["tasks"],
|
||||||
|
"context": docs,
|
||||||
|
"activated_tools": activated_tools,
|
||||||
|
}
|
||||||
|
|
||||||
|
input, _ = self.reduce_rag_context(
|
||||||
|
inputs=input,
|
||||||
|
prompt=custom_prompts.TOOL_ROUTING_PROMPT,
|
||||||
|
docs=docs,
|
||||||
|
# max_context_tokens=2000,
|
||||||
|
)
|
||||||
|
|
||||||
|
msg = custom_prompts.TOOL_ROUTING_PROMPT.format(**input)
|
||||||
|
|
||||||
|
response: TasksCompletion = self.invoke_structured_output(msg, TasksCompletion)
|
||||||
|
|
||||||
|
send_list: List[Send] = []
|
||||||
|
|
||||||
|
if response.non_completable_tasks and response.tool:
|
||||||
|
payload = {
|
||||||
|
**state,
|
||||||
|
"tasks": response.non_completable_tasks,
|
||||||
|
"tool": response.tool,
|
||||||
|
}
|
||||||
|
send_list.append(Send("run_tool", payload))
|
||||||
|
else:
|
||||||
|
send_list.append(Send("generate_rag", state))
|
||||||
|
|
||||||
|
return send_list
|
||||||
|
|
||||||
|
async def run_tool(self, state: AgentState) -> AgentState:
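"""
Invoke the selected tool once per task, filter the tool results by relevance
and append them to the retrieved documents.
"""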
|
||||||
|
tool = state["tool"]
|
||||||
|
if tool not in [
|
||||||
|
t.name for t in self.retrieval_config.workflow_config.activated_tools
|
||||||
|
]:
|
||||||
|
raise ValueError(f"Tool {tool} not activated")
|
||||||
|
|
||||||
|
tasks = state["tasks"]
|
||||||
|
|
||||||
|
tool_wrapper = LLMToolFactory.create_tool(tool, {})
|
||||||
|
|
||||||
|
# Prepare the async tasks for all questions
|
||||||
|
async_tasks = []
|
||||||
|
for task in tasks:
|
||||||
|
formatted_input = tool_wrapper.format_input(task)
|
||||||
|
# Asynchronously invoke the model for each question
|
||||||
|
async_tasks.append(tool_wrapper.tool.ainvoke(formatted_input))
|
||||||
|
|
||||||
|
# Gather all the responses asynchronously
|
||||||
|
responses = await asyncio.gather(*async_tasks) if async_tasks else []
|
||||||
|
|
||||||
|
docs = []
|
||||||
|
for response in responses:
|
||||||
|
_docs = tool_wrapper.format_output(response)
|
||||||
|
_docs = self.filter_chunks_by_relevance(_docs)
|
||||||
|
docs += _docs
|
||||||
|
|
||||||
|
return {**state, "docs": state["docs"] + docs}
|
||||||
|
|
||||||
|
async def retrieve(self, state: AgentState) -> AgentState:
|
||||||
|
"""
|
||||||
|
Retrieve relevant chunks
|
||||||
|
|
||||||
|
Args:
|
||||||
|
state (messages): The current state
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: The retrieved chunks
|
||||||
|
"""
|
||||||
|
|
||||||
|
tasks = state["tasks"]
|
||||||
|
if not tasks:
|
||||||
|
return {**state, "docs": []}
|
||||||
|
|
||||||
|
kwargs = {
|
||||||
|
"search_kwargs": {
|
||||||
|
"k": self.retrieval_config.k,
|
||||||
|
}
|
||||||
|
} # type: ignore
|
||||||
|
base_retriever = self.get_retriever(**kwargs)
|
||||||
|
|
||||||
|
kwargs = {"top_n": self.retrieval_config.reranker_config.top_n} # type: ignore
|
||||||
|
reranker = self.get_reranker(**kwargs)
|
||||||
|
|
||||||
|
compression_retriever = ContextualCompressionRetriever(
|
||||||
|
base_compressor=reranker, base_retriever=base_retriever
|
||||||
|
)
|
||||||
|
|
||||||
|
# Prepare the async tasks for all questions
|
||||||
|
async_tasks = []
|
||||||
|
for task in tasks:
|
||||||
|
# Asynchronously invoke the model for each question
|
||||||
|
async_tasks.append(compression_retriever.ainvoke(task))
|
||||||
|
|
||||||
|
# Gather all the responses asynchronously
|
||||||
|
responses = await asyncio.gather(*async_tasks) if async_tasks else []
|
||||||
|
|
||||||
|
docs = []
|
||||||
|
for response in responses:
|
||||||
|
_docs = self.filter_chunks_by_relevance(response)
|
||||||
|
docs += _docs
|
||||||
|
|
||||||
|
return {**state, "docs": docs}
|
||||||
|
|
||||||
|
async def dynamic_retrieve(self, state: AgentState) -> AgentState:
|
||||||
|
"""
|
||||||
|
Retrieve relevant chunks
|
||||||
|
|
||||||
|
Args:
|
||||||
|
state (messages): The current state
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: The retrieved chunks
|
||||||
|
"""
|
||||||
|
|
||||||
|
tasks = state["tasks"]
|
||||||
|
if not tasks:
|
||||||
|
return {**state, "docs": []}
|
||||||
|
|
||||||
|
k = self.retrieval_config.k
|
||||||
|
top_n = self.retrieval_config.reranker_config.top_n
|
||||||
|
number_of_relevant_chunks = top_n
|
||||||
|
i = 1
|
||||||
|
|
||||||
|
while number_of_relevant_chunks == top_n:
|
||||||
|
top_n = self.retrieval_config.reranker_config.top_n * i
|
||||||
|
kwargs = {"top_n": top_n}
|
||||||
|
reranker = self.get_reranker(**kwargs)
|
||||||
|
|
||||||
|
k = max([top_n * 2, self.retrieval_config.k])
|
||||||
|
kwargs = {"search_kwargs": {"k": k}} # type: ignore
|
||||||
|
base_retriever = self.get_retriever(**kwargs)
|
||||||
|
|
||||||
|
if i > 1:
|
||||||
|
logger.info(
|
||||||
|
f"Increasing top_n to {top_n} and k to {k} to retrieve more relevant chunks"
|
||||||
|
)
|
||||||
|
|
||||||
|
compression_retriever = ContextualCompressionRetriever(
|
||||||
|
base_compressor=reranker, base_retriever=base_retriever
|
||||||
|
)
|
||||||
|
|
||||||
|
# Prepare the async tasks for all questions
|
||||||
|
async_tasks = []
|
||||||
|
for task in tasks:
|
||||||
|
# Asynchronously invoke the model for each question
|
||||||
|
async_tasks.append(compression_retriever.ainvoke(task))
|
||||||
|
|
||||||
|
# Gather all the responses asynchronously
|
||||||
|
responses = await asyncio.gather(*async_tasks) if async_tasks else []
|
||||||
|
|
||||||
|
docs = []
|
||||||
|
_n = []
|
||||||
|
for response in responses:
|
||||||
|
_docs = self.filter_chunks_by_relevance(response)
|
||||||
|
_n.append(len(_docs))
|
||||||
|
docs += _docs
|
||||||
|
|
||||||
|
if not docs:
|
||||||
|
break
|
||||||
|
|
||||||
|
context_length = self.get_rag_context_length(state, docs)
|
||||||
|
if context_length >= self.retrieval_config.llm_config.max_context_tokens:
|
||||||
|
logger.warning(
|
||||||
|
f"The context length is {context_length} which is greater than "
|
||||||
|
f"the max context tokens of {self.retrieval_config.llm_config.max_context_tokens}"
|
||||||
|
)
|
||||||
|
break
|
||||||
|
|
||||||
|
number_of_relevant_chunks = max(_n)
|
||||||
|
i += 1
|
||||||
|
|
||||||
|
return {**state, "docs": docs}
|
||||||
|
|
||||||
|
def get_rag_context_length(self, state: AgentState, docs: List[Document]) -> int:
|
||||||
|
final_inputs = self._build_rag_prompt_inputs(state, docs)
|
||||||
|
msg = custom_prompts.RAG_ANSWER_PROMPT.format(**final_inputs)
|
||||||
|
return self.llm_endpoint.count_tokens(msg)
|
||||||
|
|
||||||
|
def reduce_rag_context(
|
||||||
|
self,
|
||||||
|
inputs: Dict[str, Any],
|
||||||
|
prompt: BasePromptTemplate,
|
||||||
|
docs: List[Document] | None = None,
|
||||||
|
max_context_tokens: int | None = None,
|
||||||
|
) -> Tuple[Dict[str, Any], List[Document] | None]:
|
||||||
|
MAX_ITERATION = 100
|
||||||
|
SECURITY_FACTOR = 0.85
|
||||||
|
iteration = 0
|
||||||
|
|
||||||
|
msg = prompt.format(**inputs)
|
||||||
|
n = self.llm_endpoint.count_tokens(msg)
|
||||||
|
|
||||||
|
max_context_tokens = (
|
||||||
|
max_context_tokens
|
||||||
|
if max_context_tokens
|
||||||
|
else self.retrieval_config.llm_config.max_context_tokens
|
||||||
|
)
|
||||||
|
|
||||||
|
while n > max_context_tokens * SECURITY_FACTOR:
|
||||||
|
chat_history = inputs["chat_history"] if "chat_history" in inputs else []
|
||||||
|
|
||||||
|
if len(chat_history) > 0:
|
||||||
|
inputs["chat_history"] = chat_history[2:]
|
||||||
|
elif docs and len(docs) > 1:
|
||||||
|
docs = docs[:-1]
|
||||||
|
else:
|
||||||
|
logger.warning(
|
||||||
|
f"Not enough context to reduce. The context length is {n} "
|
||||||
|
f"which is greater than the max context tokens of {max_context_tokens}"
|
||||||
|
)
|
||||||
|
break
|
||||||
|
|
||||||
|
if docs and "context" in inputs:
|
||||||
|
inputs["context"] = combine_documents(docs)
|
||||||
|
|
||||||
|
msg = prompt.format(**inputs)
|
||||||
|
n = self.llm_endpoint.count_tokens(msg)
|
||||||
|
|
||||||
|
iteration += 1
|
||||||
|
if iteration > MAX_ITERATION:
|
||||||
|
logger.warning(
|
||||||
|
f"Attained the maximum number of iterations ({MAX_ITERATION})"
|
||||||
|
)
|
||||||
|
break
|
||||||
|
|
||||||
|
return inputs, docs
|
||||||
|
|
||||||
|
def bind_tools_to_llm(self, node_name: str):
|
||||||
|
if self.llm_endpoint.supports_func_calling():
|
||||||
|
tools = self.retrieval_config.workflow_config.get_node_tools(node_name)
|
||||||
|
if tools: # Only bind tools if there are any available
|
||||||
|
return self.llm_endpoint._llm.bind_tools(tools, tool_choice="any")
|
||||||
|
return self.llm_endpoint._llm
|
||||||
|
|
||||||
|
def generate_rag(self, state: AgentState) -> AgentState:
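"""
Build the RAG prompt from the state and the retrieved documents, reduce it to
fit the context window, and invoke the LLM with the node's tools bound when
function calling is supported.
"""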
|
||||||
|
docs: List[Document] | None = state["docs"]
|
||||||
|
final_inputs = self._build_rag_prompt_inputs(state, docs)
|
||||||
|
|
||||||
|
reduced_inputs, docs = self.reduce_rag_context(
|
||||||
|
final_inputs, custom_prompts.RAG_ANSWER_PROMPT, docs
|
||||||
|
)
|
||||||
|
|
||||||
|
msg = custom_prompts.RAG_ANSWER_PROMPT.format(**reduced_inputs)
|
||||||
|
llm = self.bind_tools_to_llm(self.generate_rag.__name__)
|
||||||
|
response = llm.invoke(msg)
|
||||||
|
|
||||||
|
return {**state, "messages": [response], "docs": docs if docs else []}
|
||||||
|
|
||||||
|
def generate_chat_llm(self, state: AgentState) -> AgentState:
|
||||||
|
"""
|
||||||
|
Generate answer
|
||||||
|
|
||||||
|
Args:
|
||||||
|
state (messages): The current state
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: The updated state with the generated answer
|
||||||
|
"""
|
||||||
|
messages = state["messages"]
|
||||||
|
user_question = messages[0].content
|
||||||
|
|
||||||
|
# Prompt
|
||||||
|
prompt = self.retrieval_config.prompt
|
||||||
|
|
||||||
|
final_inputs = {}
|
||||||
|
final_inputs["question"] = user_question
|
||||||
|
final_inputs["custom_instructions"] = prompt if prompt else "None"
|
||||||
|
final_inputs["chat_history"] = state["chat_history"].to_list()
|
||||||
|
|
||||||
|
# LLM
|
||||||
|
llm = self.llm_endpoint._llm
|
||||||
|
|
||||||
|
reduced_inputs, _ = self.reduce_rag_context(
|
||||||
|
final_inputs, custom_prompts.CHAT_LLM_PROMPT, None
|
||||||
|
)
|
||||||
|
|
||||||
|
msg = custom_prompts.CHAT_LLM_PROMPT.format(**reduced_inputs)
|
||||||
|
|
||||||
|
# Run
|
||||||
|
response = llm.invoke(msg)
|
||||||
|
return {**state, "messages": [response]}
|
||||||
|
|
||||||
|
def build_chain(self):
|
||||||
|
"""
|
||||||
|
Builds the langgraph workflow for the given configuration.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The compiled langgraph workflow.
|
||||||
|
"""
|
||||||
|
if not self.graph:
|
||||||
|
self.graph = self.create_graph()
|
||||||
|
|
||||||
|
return self.graph
|
||||||
|
|
||||||
|
def create_graph(self):
|
||||||
|
workflow = StateGraph(AgentState)
|
||||||
|
self.final_nodes = []
|
||||||
|
|
||||||
|
self._build_workflow(workflow)
|
||||||
|
|
||||||
|
return workflow.compile()
|
||||||
|
|
||||||
|
def _build_workflow(self, workflow: StateGraph):
|
||||||
|
for node in self.retrieval_config.workflow_config.nodes:
|
||||||
|
if node.name not in [START, END]:
|
||||||
|
workflow.add_node(node.name, getattr(self, node.name))
|
||||||
|
|
||||||
|
for node in self.retrieval_config.workflow_config.nodes:
|
||||||
|
self._add_node_edges(workflow, node)
|
||||||
|
|
||||||
|
def _add_node_edges(self, workflow: StateGraph, node: NodeConfig):
|
||||||
|
if node.edges:
|
||||||
|
for edge in node.edges:
|
||||||
|
workflow.add_edge(node.name, edge)
|
||||||
|
if edge == END:
|
||||||
|
self.final_nodes.append(node.name)
|
||||||
|
elif node.conditional_edge:
|
||||||
|
routing_function = getattr(self, node.conditional_edge.routing_function)
|
||||||
|
workflow.add_conditional_edges(
|
||||||
|
node.name, routing_function, node.conditional_edge.conditions
|
||||||
|
)
|
||||||
|
if END in node.conditional_edge.conditions:
|
||||||
|
self.final_nodes.append(node.name)
|
||||||
|
else:
|
||||||
|
raise ValueError("Node should have at least one edge or conditional_edge")
|
||||||
|
|
||||||
|
async def answer_astream(
|
||||||
|
self,
|
||||||
|
question: str,
|
||||||
|
history: ChatHistory,
|
||||||
|
list_files: list[QuivrKnowledge],
|
||||||
|
metadata: dict[str, str] = {},
|
||||||
|
) -> AsyncGenerator[ParsedRAGChunkResponse, ParsedRAGChunkResponse]:
|
||||||
|
"""
|
||||||
|
Answer a question using the langgraph chain and yield each chunk of the answer separately.
|
||||||
|
"""
|
||||||
|
concat_list_files = format_file_list(
|
||||||
|
list_files, self.retrieval_config.max_files
|
||||||
|
)
|
||||||
|
conversational_qa_chain = self.build_chain()
|
||||||
|
|
||||||
|
rolling_message = AIMessageChunk(content="")
|
||||||
|
docs: list[Document] | None = None
|
||||||
|
previous_content = ""
|
||||||
|
|
||||||
|
async for event in conversational_qa_chain.astream_events(
|
||||||
|
{
|
||||||
|
"messages": [("user", question)],
|
||||||
|
"chat_history": history,
|
||||||
|
"files": concat_list_files,
|
||||||
|
},
|
||||||
|
version="v1",
|
||||||
|
config={"metadata": metadata},
|
||||||
|
):
|
||||||
|
if self._is_final_node_with_docs(event):
|
||||||
|
docs = event["data"]["output"]["docs"]
|
||||||
|
|
||||||
|
if self._is_final_node_and_chat_model_stream(event):
|
||||||
|
chunk = event["data"]["chunk"]
|
||||||
|
rolling_message, new_content, previous_content = parse_chunk_response(
|
||||||
|
rolling_message,
|
||||||
|
chunk,
|
||||||
|
self.llm_endpoint.supports_func_calling(),
|
||||||
|
previous_content,
|
||||||
|
)
|
||||||
|
|
||||||
|
if new_content:
|
||||||
|
chunk_metadata = get_chunk_metadata(rolling_message, docs)
|
||||||
|
yield ParsedRAGChunkResponse(
|
||||||
|
answer=new_content, metadata=chunk_metadata
|
||||||
|
)
|
||||||
|
|
||||||
|
# Yield final metadata chunk
|
||||||
|
yield ParsedRAGChunkResponse(
|
||||||
|
answer="",
|
||||||
|
metadata=get_chunk_metadata(rolling_message, docs),
|
||||||
|
last_chunk=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _is_final_node_with_docs(self, event: dict) -> bool:
|
||||||
|
return (
|
||||||
|
"output" in event["data"]
|
||||||
|
and event["data"]["output"] is not None
|
||||||
|
and "docs" in event["data"]["output"]
|
||||||
|
and event["metadata"]["langgraph_node"] in self.final_nodes
|
||||||
|
)
|
||||||
|
|
||||||
|
def _is_final_node_and_chat_model_stream(self, event: dict) -> bool:
|
||||||
|
return (
|
||||||
|
event["event"] == "on_chat_model_stream"
|
||||||
|
and "langgraph_node" in event["metadata"]
|
||||||
|
and event["metadata"]["langgraph_node"] in self.final_nodes
|
||||||
|
)
|
||||||
|
|
||||||
|
def invoke_structured_output(
|
||||||
|
self, prompt: str, output_class: Type[BaseModel]
|
||||||
|
) -> Any:
|
||||||
|
try:
|
||||||
|
structured_llm = self.llm_endpoint._llm.with_structured_output(
|
||||||
|
output_class, method="json_schema"
|
||||||
|
)
|
||||||
|
return structured_llm.invoke(prompt)
|
||||||
|
except openai.BadRequestError:
|
||||||
|
structured_llm = self.llm_endpoint._llm.with_structured_output(output_class)
|
||||||
|
return structured_llm.invoke(prompt)
|
||||||
|
|
||||||
|
def _build_rag_prompt_inputs(
|
||||||
|
self, state: AgentState, docs: List[Document] | None
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
"""Build the input dictionary for RAG_ANSWER_PROMPT.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
state: Current agent state
|
||||||
|
docs: List of documents or None
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary containing all inputs needed for RAG_ANSWER_PROMPT
|
||||||
|
"""
|
||||||
|
messages = state["messages"]
|
||||||
|
user_question = messages[0].content
|
||||||
|
files = state["files"]
|
||||||
|
prompt = self.retrieval_config.prompt
|
||||||
|
available_tools, _ = collect_tools(self.retrieval_config.workflow_config)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"context": combine_documents(docs) if docs else "None",
|
||||||
|
"question": user_question,
|
||||||
|
"rephrased_task": state["tasks"],
|
||||||
|
"custom_instructions": prompt if prompt else "None",
|
||||||
|
"files": files if files else "None",
|
||||||
|
"chat_history": state["chat_history"].to_list(),
|
||||||
|
"reasoning": state["reasoning"] if "reasoning" in state else "None",
|
||||||
|
"tools": available_tools,
|
||||||
|
}
|
188
core/quivr_core/rag/utils.py
Normal file
188
core/quivr_core/rag/utils.py
Normal file
@ -0,0 +1,188 @@
|
|||||||
|
import logging
|
||||||
|
from typing import Any, List, Tuple, no_type_check
|
||||||
|
|
||||||
|
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage
|
||||||
|
from langchain_core.messages.ai import AIMessageChunk
|
||||||
|
from langchain_core.prompts import format_document
|
||||||
|
|
||||||
|
from quivr_core.rag.entities.config import WorkflowConfig
|
||||||
|
from quivr_core.rag.entities.models import (
|
||||||
|
ChatLLMMetadata,
|
||||||
|
ParsedRAGResponse,
|
||||||
|
QuivrKnowledge,
|
||||||
|
RAGResponseMetadata,
|
||||||
|
RawRAGResponse,
|
||||||
|
)
|
||||||
|
from quivr_core.rag.prompts import custom_prompts
|
||||||
|
|
||||||
|
# TODO(@aminediro): define a types package where we clearly define IO types
|
||||||
|
# This should be used for serialization/deserialization later
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.getLogger("quivr_core")
|
||||||
|
|
||||||
|
|
||||||
|
def model_supports_function_calling(model_name: str):
|
||||||
|
models_not_supporting_function_calls: list[str] = ["llama2", "test", "ollama3"]
|
||||||
|
|
||||||
|
return model_name not in models_not_supporting_function_calls
|
||||||
|
|
||||||
|
|
||||||
|
def format_history_to_openai_mesages(
|
||||||
|
tuple_history: List[Tuple[str, str]], system_message: str, question: str
|
||||||
|
) -> List[BaseMessage]:
|
||||||
|
"""Format the chat history into a list of Base Messages"""
|
||||||
|
messages = []
|
||||||
|
messages.append(SystemMessage(content=system_message))
|
||||||
|
for human, ai in tuple_history:
|
||||||
|
messages.append(HumanMessage(content=human))
|
||||||
|
messages.append(AIMessage(content=ai))
|
||||||
|
messages.append(HumanMessage(content=question))
|
||||||
|
return messages
|
||||||
|
|
||||||
|
|
||||||
|
def cited_answer_filter(tool):
|
||||||
|
return tool["name"] == "cited_answer"
|
||||||
|
|
||||||
|
|
||||||
|
def get_chunk_metadata(
|
||||||
|
msg: AIMessageChunk, sources: list[Any] | None = None
|
||||||
|
) -> RAGResponseMetadata:
|
||||||
|
metadata = {"sources": sources or []}
|
||||||
|
|
||||||
|
if not msg.tool_calls:
|
||||||
|
return RAGResponseMetadata(**metadata, metadata_model=None)
|
||||||
|
|
||||||
|
all_citations = []
|
||||||
|
all_followup_questions = []
|
||||||
|
|
||||||
|
for tool_call in msg.tool_calls:
|
||||||
|
if tool_call.get("name") == "cited_answer" and "args" in tool_call:
|
||||||
|
args = tool_call["args"]
|
||||||
|
all_citations.extend(args.get("citations", []))
|
||||||
|
all_followup_questions.extend(args.get("followup_questions", []))
|
||||||
|
|
||||||
|
metadata["citations"] = all_citations
|
||||||
|
metadata["followup_questions"] = all_followup_questions[:3] # Limit to 3
|
||||||
|
|
||||||
|
return RAGResponseMetadata(**metadata, metadata_model=None)
|
||||||
|
|
||||||
|
|
||||||
|
def get_prev_message_str(msg: AIMessageChunk) -> str:
|
||||||
|
if msg.tool_calls:
|
||||||
|
cited_answer = next(x for x in msg.tool_calls if cited_answer_filter(x))
|
||||||
|
if "args" in cited_answer and "answer" in cited_answer["args"]:
|
||||||
|
return cited_answer["args"]["answer"]
|
||||||
|
return ""
|
||||||
|
|
||||||
|
|
||||||
|
# TODO: CONVOLUTED LOGIC !
|
||||||
|
# TODO(@aminediro): redo this
|
||||||
|
@no_type_check
|
||||||
|
def parse_chunk_response(
|
||||||
|
rolling_msg: AIMessageChunk,
|
||||||
|
raw_chunk: AIMessageChunk,
|
||||||
|
supports_func_calling: bool,
|
||||||
|
previous_content: str = "",
|
||||||
|
) -> Tuple[AIMessageChunk, str, str]:
|
||||||
|
"""Parse a chunk response
|
||||||
|
Args:
|
||||||
|
rolling_msg: The accumulated message so far
|
||||||
|
raw_chunk: The new chunk to add
|
||||||
|
supports_func_calling: Whether function calling is supported
|
||||||
|
previous_content: The previous content string
|
||||||
|
Returns:
|
||||||
|
Tuple of (updated rolling message, new content only, full content)
|
||||||
|
"""
|
||||||
|
rolling_msg += raw_chunk
|
||||||
|
|
||||||
|
if not supports_func_calling or not rolling_msg.tool_calls:
|
||||||
|
new_content = raw_chunk.content # Just the new chunk's content
|
||||||
|
full_content = rolling_msg.content # The full accumulated content
|
||||||
|
return rolling_msg, new_content, full_content
|
||||||
|
|
||||||
|
current_answers = get_answers_from_tool_calls(rolling_msg.tool_calls)
|
||||||
|
full_answer = "\n\n".join(current_answers)
|
||||||
|
new_content = full_answer[len(previous_content) :]
|
||||||
|
|
||||||
|
return rolling_msg, new_content, full_answer
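# Illustrative accumulation without function calling: each raw chunk is merged
# into the rolling message and only the delta is surfaced to the caller.
_rolling = AIMessageChunk(content="")
_prev = ""
for _raw in (AIMessageChunk(content="Hel"), AIMessageChunk(content="lo")):
    _rolling, _delta, _prev = parse_chunk_response(_rolling, _raw, False, _prev)
    # _delta is "Hel" then "lo"; _prev accumulates to "Hello".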
|
||||||
|
|
||||||
|
|
||||||
|
def get_answers_from_tool_calls(tool_calls):
|
||||||
|
answers = []
|
||||||
|
for tool_call in tool_calls:
|
||||||
|
if tool_call.get("name") == "cited_answer" and "args" in tool_call:
|
||||||
|
answers.append(tool_call["args"].get("answer", ""))
|
||||||
|
return answers
|
||||||
|
|
||||||
|
|
||||||
|
@no_type_check
|
||||||
|
def parse_response(raw_response: RawRAGResponse, model_name: str) -> ParsedRAGResponse:
|
||||||
|
answers = []
|
||||||
|
sources = raw_response["docs"] if "docs" in raw_response else []
|
||||||
|
|
||||||
|
metadata = RAGResponseMetadata(
|
||||||
|
sources=sources, metadata_model=ChatLLMMetadata(name=model_name)
|
||||||
|
)
|
||||||
|
|
||||||
|
if (
|
||||||
|
model_supports_function_calling(model_name)
|
||||||
|
and "tool_calls" in raw_response["answer"]
|
||||||
|
and raw_response["answer"].tool_calls
|
||||||
|
):
|
||||||
|
all_citations = []
|
||||||
|
all_followup_questions = []
|
||||||
|
for tool_call in raw_response["answer"].tool_calls:
|
||||||
|
if "args" in tool_call:
|
||||||
|
args = tool_call["args"]
|
||||||
|
if "citations" in args:
|
||||||
|
all_citations.extend(args["citations"])
|
||||||
|
if "followup_questions" in args:
|
||||||
|
all_followup_questions.extend(args["followup_questions"])
|
||||||
|
if "answer" in args:
|
||||||
|
answers.append(args["answer"])
|
||||||
|
metadata.citations = all_citations
|
||||||
|
metadata.followup_questions = all_followup_questions
|
||||||
|
else:
|
||||||
|
answers.append(raw_response["answer"].content)
|
||||||
|
|
||||||
|
answer_str = "\n".join(answers)
|
||||||
|
parsed_response = ParsedRAGResponse(answer=answer_str, metadata=metadata)
|
||||||
|
return parsed_response
|
||||||
|
|
||||||
|
|
||||||
|
def combine_documents(
|
||||||
|
docs,
|
||||||
|
document_prompt=custom_prompts.DEFAULT_DOCUMENT_PROMPT,
|
||||||
|
document_separator="\n\n",
|
||||||
|
):
|
||||||
|
# for each doc, add an index to the metadata so the sources can be cited
|
||||||
|
for doc, index in zip(docs, range(len(docs)), strict=False):
|
||||||
|
doc.metadata["index"] = index
|
||||||
|
doc_strings = [format_document(doc, document_prompt) for doc in docs]
|
||||||
|
return document_separator.join(doc_strings)
|
||||||
|
|
||||||
|
|
||||||
|
def format_file_list(
|
||||||
|
list_files_array: list[QuivrKnowledge], max_files: int = 20
|
||||||
|
) -> str:
|
||||||
|
list_files = [file.file_name or file.url for file in list_files_array]
|
||||||
|
files: list[str] = list(filter(lambda n: n is not None, list_files)) # type: ignore
|
||||||
|
files = files[:max_files]
|
||||||
|
|
||||||
|
files_str = "\n".join(files) if list_files_array else "None"
|
||||||
|
return files_str
|
||||||
|
|
||||||
|
|
||||||
|
def collect_tools(workflow_config: WorkflowConfig):
|
||||||
|
validated_tools = "Available tools which can be activated:\n"
|
||||||
|
for i, tool in enumerate(workflow_config.validated_tools):
|
||||||
|
validated_tools += f"Tool {i+1} name: {tool.name}\n"
|
||||||
|
validated_tools += f"Tool {i+1} description: {tool.description}\n\n"
|
||||||
|
|
||||||
|
activated_tools = "Activated tools which can be deactivated:\n"
|
||||||
|
for i, tool in enumerate(workflow_config.activated_tools):
|
||||||
|
activated_tools += f"Tool {i+1} name: {tool.name}\n"
|
||||||
|
activated_tools += f"Tool {i+1} description: {tool.description}\n\n"
|
||||||
|
|
||||||
|
return validated_tools, activated_tools
|
@ -1,167 +0,0 @@
|
|||||||
import logging
|
|
||||||
from typing import Any, List, Tuple, no_type_check
|
|
||||||
|
|
||||||
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage
|
|
||||||
from langchain_core.messages.ai import AIMessageChunk
|
|
||||||
from langchain_core.prompts import format_document
|
|
||||||
|
|
||||||
from quivr_core.models import (
|
|
||||||
ChatLLMMetadata,
|
|
||||||
ParsedRAGResponse,
|
|
||||||
QuivrKnowledge,
|
|
||||||
RAGResponseMetadata,
|
|
||||||
RawRAGResponse,
|
|
||||||
)
|
|
||||||
from quivr_core.prompts import custom_prompts
|
|
||||||
|
|
||||||
# TODO(@aminediro): define a types packages where we clearly define IO types
|
|
||||||
# This should be used for serialization/deseriallization later
|
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger("quivr_core")
|
|
||||||
|
|
||||||
|
|
||||||
def model_supports_function_calling(model_name: str):
|
|
||||||
models_supporting_function_calls = [
|
|
||||||
"gpt-4",
|
|
||||||
"gpt-4-1106-preview",
|
|
||||||
"gpt-4-0613",
|
|
||||||
"gpt-4o",
|
|
||||||
"gpt-3.5-turbo-1106",
|
|
||||||
"gpt-3.5-turbo-0613",
|
|
||||||
"gpt-4-0125-preview",
|
|
||||||
"gpt-3.5-turbo",
|
|
||||||
"gpt-4-turbo",
|
|
||||||
"gpt-4o",
|
|
||||||
"gpt-4o-mini",
|
|
||||||
]
|
|
||||||
return model_name in models_supporting_function_calls
|
|
||||||
|
|
||||||
|
|
||||||
def format_history_to_openai_mesages(
|
|
||||||
tuple_history: List[Tuple[str, str]], system_message: str, question: str
|
|
||||||
) -> List[BaseMessage]:
|
|
||||||
"""Format the chat history into a list of Base Messages"""
|
|
||||||
messages = []
|
|
||||||
messages.append(SystemMessage(content=system_message))
|
|
||||||
for human, ai in tuple_history:
|
|
||||||
messages.append(HumanMessage(content=human))
|
|
||||||
messages.append(AIMessage(content=ai))
|
|
||||||
messages.append(HumanMessage(content=question))
|
|
||||||
return messages
|
|
||||||
|
|
||||||
|
|
||||||
def cited_answer_filter(tool):
|
|
||||||
return tool["name"] == "cited_answer"
|
|
||||||
|
|
||||||
|
|
||||||
def get_chunk_metadata(
|
|
||||||
msg: AIMessageChunk, sources: list[Any] | None = None
|
|
||||||
) -> RAGResponseMetadata:
|
|
||||||
# Initiate the source
|
|
||||||
metadata = {"sources": sources} if sources else {"sources": []}
|
|
||||||
if msg.tool_calls:
|
|
||||||
cited_answer = next(x for x in msg.tool_calls if cited_answer_filter(x))
|
|
||||||
|
|
||||||
if "args" in cited_answer:
|
|
||||||
gathered_args = cited_answer["args"]
|
|
||||||
if "citations" in gathered_args:
|
|
||||||
citations = gathered_args["citations"]
|
|
||||||
metadata["citations"] = citations
|
|
||||||
|
|
||||||
if "followup_questions" in gathered_args:
|
|
||||||
followup_questions = gathered_args["followup_questions"]
|
|
||||||
metadata["followup_questions"] = followup_questions
|
|
||||||
|
|
||||||
return RAGResponseMetadata(**metadata, metadata_model=None)
|
|
||||||
|
|
||||||
|
|
||||||
def get_prev_message_str(msg: AIMessageChunk) -> str:
|
|
||||||
if msg.tool_calls:
|
|
||||||
cited_answer = next(x for x in msg.tool_calls if cited_answer_filter(x))
|
|
||||||
if "args" in cited_answer and "answer" in cited_answer["args"]:
|
|
||||||
return cited_answer["args"]["answer"]
|
|
||||||
return ""
|
|
||||||
|
|
||||||
|
|
||||||
# TODO: CONVOLUTED LOGIC !
|
|
||||||
# TODO(@aminediro): redo this
|
|
||||||
@no_type_check
|
|
||||||
def parse_chunk_response(
|
|
||||||
rolling_msg: AIMessageChunk,
|
|
||||||
raw_chunk: dict[str, Any],
|
|
||||||
supports_func_calling: bool,
|
|
||||||
) -> Tuple[AIMessageChunk, str]:
|
|
||||||
# Init with sources
|
|
||||||
answer_str = ""
|
|
||||||
|
|
||||||
if "answer" in raw_chunk:
|
|
||||||
answer = raw_chunk["answer"]
|
|
||||||
else:
|
|
||||||
answer = raw_chunk
|
|
||||||
|
|
||||||
rolling_msg += answer
|
|
||||||
if supports_func_calling and rolling_msg.tool_calls:
|
|
||||||
cited_answer = next(x for x in rolling_msg.tool_calls if cited_answer_filter(x))
|
|
||||||
if "args" in cited_answer and "answer" in cited_answer["args"]:
|
|
||||||
gathered_args = cited_answer["args"]
|
|
||||||
# Only send the difference between answer and response_tokens which was the previous answer
|
|
||||||
answer_str = gathered_args["answer"]
|
|
||||||
return rolling_msg, answer_str
|
|
||||||
|
|
||||||
return rolling_msg, answer.content
|
|
||||||
|
|
||||||
|
|
||||||
@no_type_check
|
|
||||||
def parse_response(raw_response: RawRAGResponse, model_name: str) -> ParsedRAGResponse:
|
|
||||||
answer = ""
|
|
||||||
sources = raw_response["docs"] if "docs" in raw_response else []
|
|
||||||
|
|
||||||
metadata = RAGResponseMetadata(
|
|
||||||
sources=sources, metadata_model=ChatLLMMetadata(name=model_name)
|
|
||||||
)
|
|
||||||
|
|
||||||
if (
|
|
||||||
model_supports_function_calling(model_name)
|
|
||||||
and "tool_calls" in raw_response["answer"]
|
|
||||||
and raw_response["answer"].tool_calls
|
|
||||||
):
|
|
||||||
if "citations" in raw_response["answer"].tool_calls[-1]["args"]:
|
|
||||||
citations = raw_response["answer"].tool_calls[-1]["args"]["citations"]
|
|
||||||
metadata.citations = citations
|
|
||||||
followup_questions = raw_response["answer"].tool_calls[-1]["args"][
|
|
||||||
"followup_questions"
|
|
||||||
]
|
|
||||||
if followup_questions:
|
|
||||||
metadata.followup_questions = followup_questions
|
|
||||||
answer = raw_response["answer"].tool_calls[-1]["args"]["answer"]
|
|
||||||
else:
|
|
||||||
answer = raw_response["answer"].tool_calls[-1]["args"]["answer"]
|
|
||||||
else:
|
|
||||||
answer = raw_response["answer"].content
|
|
||||||
|
|
||||||
parsed_response = ParsedRAGResponse(answer=answer, metadata=metadata)
|
|
||||||
return parsed_response
|
|
||||||
|
|
||||||
|
|
||||||
def combine_documents(
|
|
||||||
docs,
|
|
||||||
document_prompt=custom_prompts.DEFAULT_DOCUMENT_PROMPT,
|
|
||||||
document_separator="\n\n",
|
|
||||||
):
|
|
||||||
# for each docs, add an index in the metadata to be able to cite the sources
|
|
||||||
for doc, index in zip(docs, range(len(docs)), strict=False):
|
|
||||||
doc.metadata["index"] = index
|
|
||||||
doc_strings = [format_document(doc, document_prompt) for doc in docs]
|
|
||||||
return document_separator.join(doc_strings)
|
|
||||||
|
|
||||||
|
|
||||||
def format_file_list(
|
|
||||||
list_files_array: list[QuivrKnowledge], max_files: int = 20
|
|
||||||
) -> str:
|
|
||||||
list_files = [file.file_name or file.url for file in list_files_array]
|
|
||||||
files: list[str] = list(filter(lambda n: n is not None, list_files)) # type: ignore
|
|
||||||
files = files[:max_files]
|
|
||||||
|
|
||||||
files_str = "\n".join(files) if list_files_array else "None"
|
|
||||||
return files_str
|
|
@ -27,6 +27,8 @@ anyio==4.6.2.post1
|
|||||||
# via anthropic
|
# via anthropic
|
||||||
# via httpx
|
# via httpx
|
||||||
# via openai
|
# via openai
|
||||||
|
appnope==0.1.4
|
||||||
|
# via ipykernel
|
||||||
asttokens==2.4.1
|
asttokens==2.4.1
|
||||||
# via stack-data
|
# via stack-data
|
||||||
attrs==24.2.0
|
attrs==24.2.0
|
||||||
@ -80,8 +82,6 @@ frozenlist==1.4.1
|
|||||||
# via aiosignal
|
# via aiosignal
|
||||||
fsspec==2024.9.0
|
fsspec==2024.9.0
|
||||||
# via huggingface-hub
|
# via huggingface-hub
|
||||||
greenlet==3.1.1
|
|
||||||
# via sqlalchemy
|
|
||||||
h11==0.14.0
|
h11==0.14.0
|
||||||
# via httpcore
|
# via httpcore
|
||||||
httpcore==1.0.6
|
httpcore==1.0.6
|
||||||
@ -278,6 +278,8 @@ pyyaml==6.0.2
|
|||||||
pyzmq==26.2.0
|
pyzmq==26.2.0
|
||||||
# via ipykernel
|
# via ipykernel
|
||||||
# via jupyter-client
|
# via jupyter-client
|
||||||
|
rapidfuzz==3.10.1
|
||||||
|
# via quivr-core
|
||||||
regex==2024.9.11
|
regex==2024.9.11
|
||||||
# via tiktoken
|
# via tiktoken
|
||||||
# via transformers
|
# via transformers
|
||||||
|
@ -56,8 +56,6 @@ frozenlist==1.4.1
|
|||||||
# via aiosignal
|
# via aiosignal
|
||||||
fsspec==2024.9.0
|
fsspec==2024.9.0
|
||||||
# via huggingface-hub
|
# via huggingface-hub
|
||||||
greenlet==3.1.1
|
|
||||||
# via sqlalchemy
|
|
||||||
h11==0.14.0
|
h11==0.14.0
|
||||||
# via httpcore
|
# via httpcore
|
||||||
httpcore==1.0.6
|
httpcore==1.0.6
|
||||||
@ -185,6 +183,8 @@ pyyaml==6.0.2
|
|||||||
# via langchain-community
|
# via langchain-community
|
||||||
# via langchain-core
|
# via langchain-core
|
||||||
# via transformers
|
# via transformers
|
||||||
|
rapidfuzz==3.10.1
|
||||||
|
# via quivr-core
|
||||||
regex==2024.9.11
|
regex==2024.9.11
|
||||||
# via tiktoken
|
# via tiktoken
|
||||||
# via transformers
|
# via transformers
|
||||||
|
@@ -9,7 +9,7 @@ from langchain_core.language_models import FakeListChatModel
 from langchain_core.messages.ai import AIMessageChunk
 from langchain_core.runnables.utils import AddableDict
 from langchain_core.vectorstores import InMemoryVectorStore
-from quivr_core.config import LLMEndpointConfig
+from quivr_core.rag.entities.config import LLMEndpointConfig
 from quivr_core.files.file import FileExtension, QuivrFile
 from quivr_core.llm import LLMEndpoint
 
@@ -5,10 +5,10 @@ from uuid import uuid4
 from langchain_core.embeddings import DeterministicFakeEmbedding
 from langchain_core.messages.ai import AIMessageChunk
 from langchain_core.vectorstores import InMemoryVectorStore
-from quivr_core.chat import ChatHistory
-from quivr_core.config import LLMEndpointConfig, RetrievalConfig
+from quivr_core.rag.entities.chat import ChatHistory
+from quivr_core.rag.entities.config import LLMEndpointConfig, RetrievalConfig
 from quivr_core.llm import LLMEndpoint
-from quivr_core.quivr_rag_langgraph import QuivrQARAGLangGraph
+from quivr_core.rag.quivr_rag_langgraph import QuivrQARAGLangGraph
 
 
 async def main():
@@ -27,9 +27,9 @@ retrieval_config:
     supplier: "openai"
 
     # The model to use for the LLM for the given supplier
-    model: "gpt-4o"
+    model: "gpt-3.5-turbo-0125"
 
-    max_input_tokens: 2000
+    max_context_tokens: 2000
 
     # Maximum number of tokens to pass to the LLM
     # as a context to generate the answer
@@ -40,9 +40,9 @@ retrieval_config:
     supplier: "openai"
 
     # The model to use for the LLM for the given supplier
-    model: "gpt-4o"
+    model: "gpt-3.5-turbo-0125"
 
-    max_input_tokens: 2000
+    max_context_tokens: 2000
 
     # Maximum number of tokens to pass to the LLM
     # as a context to generate the answer
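Both example configurations switch the model to gpt-3.5-turbo-0125 and rename max_input_tokens to max_context_tokens. A minimal sketch of reading such a YAML file with PyYAML; the llm_config nesting is an assumption, since only the leaf keys shown in the hunks are visible here:

```python
import yaml

raw = """
retrieval_config:
  llm_config:  # the nesting key is an assumption; the leaf keys come from the diff
    supplier: "openai"
    model: "gpt-3.5-turbo-0125"
    max_context_tokens: 2000
"""

config = yaml.safe_load(raw)
llm = config["retrieval_config"]["llm_config"]
print(llm["model"], llm["max_context_tokens"])
```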
@@ -5,7 +5,7 @@ import pytest
 from langchain_core.documents import Document
 from langchain_core.embeddings import Embeddings
 from quivr_core.brain import Brain
-from quivr_core.chat import ChatHistory
+from quivr_core.rag.entities.chat import ChatHistory
 from quivr_core.llm import LLMEndpoint
 from quivr_core.storage.local_storage import TransparentStorage
 
@@ -104,8 +104,8 @@ async def test_brain_get_history(
         vector_db=mem_vector_store,
     )
 
-    brain.ask("question")
-    brain.ask("question")
+    await brain.aask("question")
+    await brain.aask("question")
 
     assert len(brain.default_chat) == 4
 
@@ -3,7 +3,7 @@ from uuid import uuid4
 
 import pytest
 from langchain_core.messages import AIMessage, HumanMessage
-from quivr_core.chat import ChatHistory
+from quivr_core.rag.entities.chat import ChatHistory
 
 
 @pytest.fixture
@@ -1,18 +1,21 @@
-from quivr_core.config import LLMEndpointConfig, RetrievalConfig
+from quivr_core.rag.entities.config import LLMEndpointConfig, RetrievalConfig
 
 
 def test_default_llm_config():
     config = LLMEndpointConfig()
 
-    assert config.model_dump(exclude={"llm_api_key"}) == LLMEndpointConfig(
-        model="gpt-4o",
-        llm_base_url=None,
-        llm_api_key=None,
-        max_input_tokens=2000,
-        max_output_tokens=2000,
-        temperature=0.7,
-        streaming=True,
-    ).model_dump(exclude={"llm_api_key"})
+    assert (
+        config.model_dump()
+        == LLMEndpointConfig(
+            model="gpt-4o",
+            llm_base_url=None,
+            llm_api_key=None,
+            max_context_tokens=2000,
+            max_output_tokens=2000,
+            temperature=0.7,
+            streaming=True,
+        ).model_dump()
+    )
 
 
 def test_default_retrievalconfig():
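The updated default-config test compares full model dumps and uses the renamed max_context_tokens field. The equivalent standalone check, with the default values taken from the test above (assumes quivr-core is installed and no API key is injected from the environment):

```python
from quivr_core.rag.entities.config import LLMEndpointConfig

config = LLMEndpointConfig()
expected = LLMEndpointConfig(
    model="gpt-4o",
    llm_base_url=None,
    llm_api_key=None,
    max_context_tokens=2000,
    max_output_tokens=2000,
    temperature=0.7,
    streaming=True,
)
# Mirrors the assertion in the test above.
assert config.model_dump() == expected.model_dump()
print(config.model_dump())
```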
@@ -20,6 +23,6 @@ def test_default_retrievalconfig():
 
     assert config.max_files == 20
     assert config.prompt is None
-    assert config.llm_config.model_dump(
-        exclude={"llm_api_key"}
-    ) == LLMEndpointConfig().model_dump(exclude={"llm_api_key"})
+    print("\n\n", config.llm_config, "\n\n")
+    print("\n\n", LLMEndpointConfig(), "\n\n")
+    assert config.llm_config == LLMEndpointConfig()
@@ -3,7 +3,7 @@ import os
 import pytest
 from langchain_core.language_models import FakeListChatModel
 from pydantic.v1.error_wrappers import ValidationError
-from quivr_core.config import LLMEndpointConfig
+from quivr_core.rag.entities.config import LLMEndpointConfig
 from quivr_core.llm import LLMEndpoint
 
 
@@ -1,25 +1,51 @@
 from uuid import uuid4
 
 import pytest
-from quivr_core.chat import ChatHistory
-from quivr_core.config import LLMEndpointConfig, RetrievalConfig
+from quivr_core.rag.entities.chat import ChatHistory
+from quivr_core.rag.entities.config import LLMEndpointConfig, RetrievalConfig
 from quivr_core.llm import LLMEndpoint
-from quivr_core.models import ParsedRAGChunkResponse, RAGResponseMetadata
-from quivr_core.quivr_rag_langgraph import QuivrQARAGLangGraph
+from quivr_core.rag.entities.models import ParsedRAGChunkResponse, RAGResponseMetadata
+from quivr_core.rag.quivr_rag_langgraph import QuivrQARAGLangGraph
 
 
 @pytest.fixture(scope="function")
 def mock_chain_qa_stream(monkeypatch, chunks_stream_answer):
     class MockQAChain:
         async def astream_events(self, *args, **kwargs):
-            for c in chunks_stream_answer:
+            default_metadata = {
+                "langgraph_node": "generate",
+                "is_final_node": False,
+                "citations": None,
+                "followup_questions": None,
+                "sources": None,
+                "metadata_model": None,
+            }
+
+            # Send all chunks except the last one
+            for chunk in chunks_stream_answer[:-1]:
                 yield {
                     "event": "on_chat_model_stream",
-                    "metadata": {"langgraph_node": "generate"},
-                    "data": {"chunk": c},
+                    "metadata": default_metadata,
+                    "data": {"chunk": chunk["answer"]},
                 }
 
+            # Send the last chunk
+            yield {
+                "event": "end",
+                "metadata": {
+                    "langgraph_node": "generate",
+                    "is_final_node": True,
+                    "citations": [],
+                    "followup_questions": None,
+                    "sources": [],
+                    "metadata_model": None,
+                },
+                "data": {"chunk": chunks_stream_answer[-1]["answer"]},
+            }
+
     def mock_qa_chain(*args, **kwargs):
+        self = args[0]
+        self.final_nodes = ["generate"]
         return MockQAChain()
 
     monkeypatch.setattr(QuivrQARAGLangGraph, "build_chain", mock_qa_chain)
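The mock now reproduces the event shape the LangGraph stream is expected to emit: intermediate on_chat_model_stream events carrying a metadata dict (langgraph_node, is_final_node, citations, followup_questions, sources, metadata_model) and a final end event with is_final_node set to True. A self-contained sketch of a consumer for events of that shape; the generator below is a stand-in, not the real chain:

```python
import asyncio


async def fake_stream():
    # Event payloads are made up; only the shape follows the mock above.
    yield {
        "event": "on_chat_model_stream",
        "metadata": {"langgraph_node": "generate", "is_final_node": False},
        "data": {"chunk": "Gold is "},
    }
    yield {
        "event": "end",
        "metadata": {"langgraph_node": "generate", "is_final_node": True, "citations": [], "sources": []},
        "data": {"chunk": "a metal."},
    }


async def consume() -> str:
    answer = ""
    async for event in fake_stream():
        answer += event["data"]["chunk"]
        if event["metadata"].get("is_final_node"):
            break
    return answer


print(asyncio.run(consume()))
```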
@@ -48,11 +74,13 @@ async def test_quivrqaraglanggraph(
     ):
         stream_responses.append(resp)
 
+    # This assertion passed
     assert all(
         not r.last_chunk for r in stream_responses[:-1]
     ), "Some chunks before last have last_chunk=True"
     assert stream_responses[-1].last_chunk
 
+    # Let's check this assertion
     for idx, response in enumerate(stream_responses[1:-1]):
         assert (
             len(response.answer) > 0
@@ -3,7 +3,7 @@ from uuid import uuid4
 import pytest
 from langchain_core.messages.ai import AIMessageChunk
 from langchain_core.messages.tool import ToolCall
-from quivr_core.utils import (
+from quivr_core.rag.utils import (
     get_prev_message_str,
     model_supports_function_calling,
     parse_chunk_response,
@@ -43,13 +43,9 @@ def test_get_prev_message_str():
 
 def test_parse_chunk_response_nofunc_calling():
     rolling_msg = AIMessageChunk(content="")
-    chunk = {
-        "answer": AIMessageChunk(
-            content="next ",
-        )
-    }
+    chunk = AIMessageChunk(content="next ")
     for i in range(10):
-        rolling_msg, parsed_chunk = parse_chunk_response(rolling_msg, chunk, False)
+        rolling_msg, parsed_chunk, _ = parse_chunk_response(rolling_msg, chunk, False)
         assert rolling_msg.content == "next " * (i + 1)
         assert parsed_chunk == "next "
 
@@ -70,8 +66,9 @@ def test_parse_chunk_response_func_calling(chunks_stream_answer):
     answer_str_history: list[str] = []
 
     for chunk in chunks_stream_answer:
-        # This is done
-        rolling_msg, answer_str = parse_chunk_response(rolling_msg, chunk, True)
+        # Extract the AIMessageChunk from the chunk dictionary
+        chunk_msg = chunk["answer"]  # Get the AIMessageChunk from the dict
+        rolling_msg, answer_str, _ = parse_chunk_response(rolling_msg, chunk_msg, True)
         rolling_msgs_history.append(rolling_msg)
         answer_str_history.append(answer_str)
 
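Per the updated tests, parse_chunk_response now takes the AIMessageChunk directly (callers extract chunk["answer"] themselves) and returns a three-element tuple. A usage sketch mirroring the no-function-calling test above (assumes quivr-core is installed):

```python
from langchain_core.messages.ai import AIMessageChunk
from quivr_core.rag.utils import parse_chunk_response

rolling_msg = AIMessageChunk(content="")
for token in ["Gold ", "is ", "a ", "metal."]:
    chunk = AIMessageChunk(content=token)
    # The third element of the tuple is ignored here, as in the test above.
    rolling_msg, answer_str, _ = parse_chunk_response(rolling_msg, chunk, False)
    print(answer_str, end="")
print()
```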
@@ -5,7 +5,7 @@ from pathlib import Path
 
 import dotenv
 from quivr_core import Brain
-from quivr_core.config import AssistantConfig
+from quivr_core.rag.entities.config import AssistantConfig
 from rich.traceback import install as rich_install
 
 ConsoleOutputHandler = logging.StreamHandler()
@@ -1,7 +1,7 @@
 from langchain_core.embeddings import DeterministicFakeEmbedding
 from langchain_core.language_models import FakeListChatModel
 from quivr_core import Brain
-from quivr_core.config import LLMEndpointConfig
+from quivr_core.rag.entities.config import LLMEndpointConfig
 from quivr_core.llm.llm_endpoint import LLMEndpoint
 from rich.console import Console
 from rich.panel import Panel
@@ -7,6 +7,7 @@ authors = [
 ]
 dependencies = [
     "quivr-core @ file:///${PROJECT_ROOT}/../../core",
+    "python-dotenv>=1.0.1",
 ]
 readme = "README.md"
 requires-python = ">= 3.11"
@@ -55,8 +55,6 @@ frozenlist==1.4.1
     # via aiosignal
 fsspec==2024.10.0
     # via huggingface-hub
-greenlet==3.1.1
-    # via sqlalchemy
 h11==0.14.0
     # via httpcore
 httpcore==1.0.6
@@ -169,6 +167,7 @@ pydantic==2.9.2
     # via langsmith
     # via openai
     # via quivr-core
+    # via sqlmodel
 pydantic-core==2.23.4
     # via cohere
     # via pydantic
@@ -176,6 +175,8 @@ pygments==2.18.0
     # via rich
 python-dateutil==2.9.0.post0
     # via pandas
+python-dotenv==1.0.1
+    # via quivr-core
 pytz==2024.2
     # via pandas
 pyyaml==6.0.2
@@ -185,6 +186,8 @@ pyyaml==6.0.2
     # via langchain-core
     # via transformers
 quivr-core @ file:///${PROJECT_ROOT}/../../core
+rapidfuzz==3.10.1
+    # via quivr-core
 regex==2024.9.11
     # via tiktoken
     # via transformers
@@ -215,6 +218,9 @@ sniffio==1.3.1
 sqlalchemy==2.0.36
     # via langchain
     # via langchain-community
+    # via sqlmodel
+sqlmodel==0.0.22
+    # via quivr-core
 tabulate==0.9.0
     # via langchain-cohere
 tenacity==8.5.0
@@ -55,8 +55,6 @@ frozenlist==1.4.1
     # via aiosignal
 fsspec==2024.10.0
     # via huggingface-hub
-greenlet==3.1.1
-    # via sqlalchemy
 h11==0.14.0
     # via httpcore
 httpcore==1.0.6
@@ -169,6 +167,7 @@ pydantic==2.9.2
     # via langsmith
     # via openai
     # via quivr-core
+    # via sqlmodel
 pydantic-core==2.23.4
     # via cohere
     # via pydantic
@@ -176,6 +175,8 @@ pygments==2.18.0
     # via rich
 python-dateutil==2.9.0.post0
     # via pandas
+python-dotenv==1.0.1
+    # via quivr-core
 pytz==2024.2
     # via pandas
 pyyaml==6.0.2
@@ -185,6 +186,8 @@ pyyaml==6.0.2
     # via langchain-core
     # via transformers
 quivr-core @ file:///${PROJECT_ROOT}/../../core
+rapidfuzz==3.10.1
+    # via quivr-core
 regex==2024.9.11
     # via tiktoken
     # via transformers
@@ -215,6 +218,9 @@ sniffio==1.3.1
 sqlalchemy==2.0.36
     # via langchain
     # via langchain-community
+    # via sqlmodel
+sqlmodel==0.0.22
+    # via quivr-core
 tabulate==0.9.0
     # via langchain-cohere
 tenacity==8.5.0
@@ -2,6 +2,10 @@ import tempfile
 
 from quivr_core import Brain
 
+import dotenv
+
+dotenv.load_dotenv()
+
 if __name__ == "__main__":
     with tempfile.NamedTemporaryFile(mode="w", suffix=".txt") as temp_file:
         temp_file.write("Gold is a liquid of blue-like colour.")
@@ -12,7 +16,5 @@ if __name__ == "__main__":
             file_paths=[temp_file.name],
         )
 
-        answer = brain.ask(
-            "what is gold? asnwer in french"
-        )
-        print("answer QuivrQARAGLangGraph :", answer.answer)
+        answer = brain.ask("what is gold? answer in french")
+        print("answer QuivrQARAGLangGraph :", answer)
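The simplified example now loads environment variables via dotenv and asks the brain synchronously, printing the response object directly. An async variant using brain.aask (introduced in the test changes earlier in this diff) might look like the sketch below; the name keyword on Brain.from_files and the flush call are assumptions, since only file_paths is visible in the hunk above.

```python
import asyncio
import tempfile

import dotenv
from quivr_core import Brain

dotenv.load_dotenv()


async def run(brain: Brain) -> None:
    # aask is the async counterpart of ask, as used in the updated brain tests.
    answer = await brain.aask("what is gold? answer in french")
    print("answer QuivrQARAGLangGraph :", answer)


if __name__ == "__main__":
    with tempfile.NamedTemporaryFile(mode="w", suffix=".txt") as temp_file:
        temp_file.write("Gold is a liquid of blue-like colour.")
        temp_file.flush()  # make sure the content is on disk before indexing
        # name= is an assumption; only file_paths appears in the hunk above.
        brain = Brain.from_files(name="test_brain", file_paths=[temp_file.name])
        asyncio.run(run(brain))
```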
@@ -4,7 +4,7 @@ import tempfile
 from dotenv import load_dotenv
 from quivr_core import Brain
 from quivr_core.quivr_rag import QuivrQARAG
-from quivr_core.quivr_rag_langgraph import QuivrQARAGLangGraph
+from quivr_core.rag.quivr_rag_langgraph import QuivrQARAGLangGraph
 
 
 async def main():