feat: save and load brain (#3202)

# Description
- Save and load a brain to disk:
```python
import asyncio
import tempfile

from quivr_core import Brain


async def main():
    with tempfile.NamedTemporaryFile(mode="w", suffix=".txt") as temp_file:
        temp_file.write("Gold is a liquid of blue-like colour.")
        temp_file.flush()

        brain = await Brain.afrom_files(name="test_brain", file_paths=[temp_file.name])

        save_path = await brain.save("/home/amine/.local/quivr")

        brain_loaded = Brain.load(save_path)
        brain_loaded.print_info()

```
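
`brain.save(...)` returns the concrete save location. Judging from the `save()` implementation in the diff below, that location is a `brain_<id>` folder containing `config.json` (the serialized brain) plus a `vector_store/` directory (the FAISS index). A minimal sketch to check the layout; the `brain_<uuid>` path segment is a placeholder for the brain's actual UUID:

```python
# Minimal sketch: inspect the folder returned by brain.save(...).
# Expected layout per the save() implementation: config.json + vector_store/
import os


def show_layout(save_path: str) -> None:
    for entry in sorted(os.listdir(save_path)):
        print(entry)  # expected: config.json, vector_store


show_layout("/home/amine/.local/quivr/brain_<uuid>")  # placeholder path
```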

# TODO:
- Load the full chat history when loading a brain
- Load from other vector stores; PGVector, for example, would be a great addition (see the sketch below)
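
A hypothetical sketch of what the PGVector path could look like, reusing the `PGVectorConfig` model added in this commit; the `langchain_postgres` dependency and the `load_pgvector` helper name are assumptions, not part of this change:

```python
# Hypothetical sketch only: pgvector loading is NOT implemented in this commit.
# Assumes the langchain_postgres package; PGVectorConfig is defined in
# quivr_core.brain.serialization (see the diff below).
from langchain_core.embeddings import Embeddings
from langchain_core.vectorstores import VectorStore
from langchain_postgres import PGVector  # assumed dependency

from quivr_core.brain.serialization import PGVectorConfig


def load_pgvector(config: PGVectorConfig, embedder: Embeddings) -> VectorStore:
    # Build a SQLAlchemy-style connection string from the serialized config
    connection = (
        f"postgresql+psycopg://{config.pg_user}:"
        f"{config.pg_psswd.get_secret_value()}@{config.pg_url}"
    )
    return PGVector(
        embeddings=embedder,
        collection_name=config.table_name,
        connection=connection,
    )
```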
AmineDiro 2024-09-13 15:35:28 +02:00 committed by GitHub
parent 06f72eb451
commit eda619f454
12 changed files with 311 additions and 63 deletions

View File

@@ -0,0 +1,22 @@
import asyncio
import tempfile

from quivr_core import Brain


async def main():
    with tempfile.NamedTemporaryFile(mode="w", suffix=".txt") as temp_file:
        temp_file.write("Gold is a liquid of blue-like colour.")
        temp_file.flush()

        brain = await Brain.afrom_files(name="test_brain", file_paths=[temp_file.name])

        save_path = await brain.save("/home/amine/.local/quivr")

        brain_loaded = Brain.load(save_path)
        brain_loaded.print_info()


if __name__ == "__main__":
    # Run the main function in the existing event loop
    asyncio.run(main())

View File

@@ -1,22 +1,23 @@
import tempfile

from quivr_core import Brain
from quivr_core.quivr_rag import QuivrQARAG
from quivr_core.quivr_rag_langgraph import QuivrQARAGLangGraph

if __name__ == "__main__":
    with tempfile.NamedTemporaryFile(mode="w", suffix=".txt") as temp_file:
        temp_file.write("Gold is a liquid of blue-like colour.")
        temp_file.flush()
        brain = Brain.from_files(
            name="test_brain",
            file_paths=[temp_file.name],
        )

        answer = brain.ask(
            "what is gold? answer in french", rag_pipeline=QuivrQARAGLangGraph
        )
        print("answer QuivrQARAGLangGraph :", answer.answer)

        answer = brain.ask("what is gold? answer in french", rag_pipeline=QuivrQARAG)
        print("answer QuivrQARAG :", answer.answer)

View File

@@ -1,29 +1,34 @@
import asyncio
import tempfile

from dotenv import load_dotenv

from quivr_core import Brain
from quivr_core.quivr_rag import QuivrQARAG
from quivr_core.quivr_rag_langgraph import QuivrQARAGLangGraph


async def main():
    dotenv_path = "/Users/jchevall/Coding/QuivrHQ/quivr/.env"
    load_dotenv(dotenv_path)

    with tempfile.NamedTemporaryFile(mode="w", suffix=".txt") as temp_file:
        temp_file.write("Gold is a liquid of blue-like colour.")
        temp_file.flush()

        brain = await Brain.afrom_files(name="test_brain", file_paths=[temp_file.name])

        await brain.save("~/.local/quivr")

        question = "what is gold? answer in french"
        async for chunk in brain.ask_streaming(question, rag_pipeline=QuivrQARAG):
            print("answer QuivrQARAG:", chunk.answer)

        async for chunk in brain.ask_streaming(
            question, rag_pipeline=QuivrQARAGLangGraph
        ):
            print("answer QuivrQARAGLangGraph:", chunk.answer)


if __name__ == "__main__":
    # Run the main function in the existing event loop
    asyncio.run(main())

View File

@@ -1,18 +1,27 @@
import asyncio
import logging
import os
from pathlib import Path
from pprint import PrettyPrinter
from typing import Any, AsyncGenerator, Callable, Dict, Self, Type, Union
from uuid import UUID, uuid4

from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.vectorstores import VectorStore
from langchain_openai import OpenAIEmbeddings
from rich.console import Console
from rich.panel import Panel

from quivr_core.brain.info import BrainInfo, ChatHistoryInfo
from quivr_core.brain.serialization import (
    BrainSerialized,
    EmbedderConfig,
    FAISSConfig,
    LocalStorageConfig,
    TransparentStorageConfig,
)
from quivr_core.chat import ChatHistory
from quivr_core.config import RAGConfig
from quivr_core.files.file import load_qfile
@@ -20,8 +29,8 @@ from quivr_core.llm import LLMEndpoint
from quivr_core.models import ParsedRAGChunkResponse, ParsedRAGResponse, SearchResult
from quivr_core.processor.registry import get_processor_class
from quivr_core.quivr_rag import QuivrQARAG
from quivr_core.quivr_rag_langgraph import QuivrQARAGLangGraph
from quivr_core.storage.local_storage import LocalStorage, TransparentStorage
from quivr_core.storage.storage_base import StorageBase

from .brain_defaults import build_default_vectordb, default_embedder, default_llm
@@ -90,6 +99,108 @@ class Brain:
        panel = Panel(tree, title="Brain Info", expand=False, border_style="bold")
        console.print(panel)

    @classmethod
    def load(cls, folder_path: str | Path) -> Self:
        if isinstance(folder_path, str):
            folder_path = Path(folder_path)
        if not folder_path.exists():
            raise ValueError(f"path {folder_path} doesn't exist")

        # Load the serialized brain
        with open(os.path.join(folder_path, "config.json"), "r") as f:
            bserialized = BrainSerialized.model_validate_json(f.read())

        # Load storage
        if bserialized.storage_config.storage_type == "transparent_storage":
            storage: StorageBase = TransparentStorage.load(bserialized.storage_config)
        elif bserialized.storage_config.storage_type == "local_storage":
            storage = LocalStorage.load(bserialized.storage_config)
        else:
            raise ValueError("unknown storage")

        # Load embedder
        if bserialized.embedding_config.embedder_type == "openai_embedding":
            from langchain_openai import OpenAIEmbeddings

            embedder = OpenAIEmbeddings(**bserialized.embedding_config.config)
        else:
            raise ValueError("unknown embedder")

        # Load vector db
        if bserialized.vectordb_config.vectordb_type == "faiss":
            from langchain_community.vectorstores import FAISS

            vector_db = FAISS.load_local(
                folder_path=bserialized.vectordb_config.vectordb_folder_path,
                embeddings=embedder,
                allow_dangerous_deserialization=True,
            )
        else:
            raise ValueError("Unsupported vectordb")

        return cls(
            id=bserialized.id,
            name=bserialized.name,
            embedder=embedder,
            llm=LLMEndpoint.from_config(bserialized.llm_config),
            storage=storage,
            vector_db=vector_db,
        )

    async def save(self, folder_path: str | Path):
        if isinstance(folder_path, str):
            folder_path = Path(folder_path)
        brain_path = os.path.join(folder_path, f"brain_{self.id}")
        os.makedirs(brain_path, exist_ok=True)

        from langchain_community.vectorstores import FAISS

        if isinstance(self.vector_db, FAISS):
            vectordb_path = os.path.join(brain_path, "vector_store")
            os.makedirs(vectordb_path, exist_ok=True)
            self.vector_db.save_local(folder_path=vectordb_path)
            vector_store = FAISSConfig(vectordb_folder_path=vectordb_path)
        else:
            raise Exception("can't serialize other vector stores for now")

        if isinstance(self.embedder, OpenAIEmbeddings):
            embedder_config = EmbedderConfig(
                config=self.embedder.dict(exclude={"openai_api_key"})
            )
        else:
            raise Exception("can't serialize embedder other than openai for now")

        # TODO: each instance should know how to serialize/deserialize itself
        if isinstance(self.storage, LocalStorage):
            serialized_files = {
                f.id: f.serialize() for f in await self.storage.get_files()
            }
            storage_config = LocalStorageConfig(
                storage_path=self.storage.dir_path, files=serialized_files
            )
        elif isinstance(self.storage, TransparentStorage):
            serialized_files = {
                f.id: f.serialize() for f in await self.storage.get_files()
            }
            storage_config = TransparentStorageConfig(files=serialized_files)
        else:
            raise Exception("can't serialize storage. not supported for now")

        bserialized = BrainSerialized(
            id=self.id,
            name=self.name,
            chat_history=self.chat_history.get_chat_history(),
            llm_config=self.llm.get_config(),
            vectordb_config=vector_store,
            embedding_config=embedder_config,
            storage_config=storage_config,
        )

        with open(os.path.join(brain_path, "config.json"), "w") as f:
            f.write(bserialized.model_dump_json())

        return brain_path

    def info(self) -> BrainInfo:
        # TODO: dim of embedding
        # "embedder": {},
@@ -177,7 +288,7 @@ class Brain:
        storage: StorageBase = TransparentStorage(),
        llm: LLMEndpoint | None = None,
        embedder: Embeddings | None = None,
        skip_file_error: bool = False,
    ) -> Self:
        loop = asyncio.get_event_loop()
        return loop.run_until_complete(
@@ -223,7 +334,7 @@
            storage=storage,
            llm=llm,
            embedder=embedder,
            vector_db=vector_db,
        )

    async def asearch(

View File

@@ -33,7 +33,6 @@ class LLMInfo:
        llm_tree.add(f"Base URL: [underline]{self.llm_base_url}[/underline]")
        llm_tree.add(f"Temperature: [bold]{self.temperature}[/bold]")
        llm_tree.add(f"Max Tokens: [bold]{self.max_tokens}[/bold]")
        func_call_color = "green" if self.supports_function_calling else "red"
        llm_tree.add(
            f"Supports Function Calling: [bold {func_call_color}]{self.supports_function_calling}[/bold {func_call_color}]"

View File

@@ -0,0 +1,55 @@
from pathlib import Path
from typing import Any, Dict, Literal, Union
from uuid import UUID

from pydantic import BaseModel, Field, SecretStr

from quivr_core.config import LLMEndpointConfig
from quivr_core.files.file import QuivrFileSerialized
from quivr_core.models import ChatMessage


class EmbedderConfig(BaseModel):
    embedder_type: Literal["openai_embedding"] = "openai_embedding"
    # TODO: type this correctly
    config: Dict[str, Any]


class PGVectorConfig(BaseModel):
    vectordb_type: Literal["pgvector"] = "pgvector"
    pg_url: str
    pg_user: str
    pg_psswd: SecretStr
    table_name: str
    vector_dim: int


class FAISSConfig(BaseModel):
    vectordb_type: Literal["faiss"] = "faiss"
    vectordb_folder_path: str


class LocalStorageConfig(BaseModel):
    storage_type: Literal["local_storage"] = "local_storage"
    storage_path: Path
    files: dict[UUID, QuivrFileSerialized]


class TransparentStorageConfig(BaseModel):
    storage_type: Literal["transparent_storage"] = "transparent_storage"
    files: dict[UUID, QuivrFileSerialized]


class BrainSerialized(BaseModel):
    id: UUID
    name: str
    chat_history: list[ChatMessage]
    vectordb_config: Union[FAISSConfig, PGVectorConfig] = Field(
        ..., discriminator="vectordb_type"
    )
    storage_config: Union[TransparentStorageConfig, LocalStorageConfig] = Field(
        ..., discriminator="storage_type"
    )
    llm_config: LLMEndpointConfig
    embedding_config: EmbedderConfig
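
Since `BrainSerialized` is a plain pydantic model with discriminated unions, a saved `config.json` can be re-validated by hand, which is exactly what `Brain.load` does; a small sketch (the file path is assumed to point inside a saved `brain_<id>` folder):

```python
# Sketch: re-validate a saved brain config with the BrainSerialized model above.
from quivr_core.brain.serialization import BrainSerialized

with open("config.json") as f:  # assumed path inside a saved brain_<id> folder
    bserialized = BrainSerialized.model_validate_json(f.read())

# The discriminator fields pick the concrete config classes automatically
print(bserialized.vectordb_config.vectordb_type)  # "faiss" for brains saved by this commit
print(bserialized.storage_config.storage_type)
```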

View File

@@ -5,10 +5,22 @@ import warnings
from contextlib import asynccontextmanager
from enum import Enum
from pathlib import Path
from typing import Any, AsyncGenerator, AsyncIterable, Self
from uuid import UUID, uuid4

import aiofiles
from pydantic import BaseModel


class QuivrFileSerialized(BaseModel):
    id: UUID
    brain_id: UUID
    path: Path
    original_filename: str
    file_size: int | None
    file_extension: str
    file_sha1: str
    additional_metadata: dict[str, Any]


class FileExtension(str, Enum):
@@ -137,3 +149,28 @@ class QuivrFile:
            "file_size": self.file_size,
            **self.additional_metadata,
        }

    def serialize(self) -> QuivrFileSerialized:
        return QuivrFileSerialized(
            id=self.id,
            brain_id=self.brain_id,
            path=self.path.absolute(),
            original_filename=self.original_filename,
            file_size=self.file_size,
            file_extension=self.file_extension,
            file_sha1=self.file_sha1,
            additional_metadata=self.additional_metadata,
        )

    @classmethod
    def deserialize(cls, serialized: QuivrFileSerialized) -> Self:
        return cls(
            id=serialized.id,
            brain_id=serialized.brain_id,
            path=serialized.path,
            original_filename=serialized.original_filename,
            file_size=serialized.file_size,
            file_extension=serialized.file_extension,
            file_sha1=serialized.file_sha1,
            metadata=serialized.additional_metadata,
        )
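
`serialize()`/`deserialize()` are meant to round-trip, which the storage `load` classmethods below rely on. `QuivrFileSerialized` itself is a plain pydantic model, so per-file metadata also survives a JSON round trip inside `config.json`; a self-contained sketch with illustrative field values:

```python
# Sketch: JSON round trip of QuivrFileSerialized (illustrative values).
from pathlib import Path
from uuid import uuid4

from quivr_core.files.file import QuivrFileSerialized

s = QuivrFileSerialized(
    id=uuid4(),
    brain_id=uuid4(),
    path=Path("/tmp/doc.txt").absolute(),
    original_filename="doc.txt",
    file_size=42,
    file_extension=".txt",
    file_sha1="0" * 40,
    additional_metadata={},
)
assert QuivrFileSerialized.model_validate_json(s.model_dump_json()) == s
```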

View File

@@ -1,10 +1,10 @@
import logging
from urllib.parse import parse_qs, urlparse

from langchain_anthropic import ChatAnthropic
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_openai import AzureChatOpenAI, ChatOpenAI
from pydantic.v1 import SecretStr

from quivr_core.brain.info import LLMInfo
from quivr_core.config import LLMEndpointConfig
@@ -27,8 +27,6 @@ class LLMEndpoint:
    @classmethod
    def from_config(cls, config: LLMEndpointConfig = LLMEndpointConfig()):
        try:
            if config.model.startswith("azure/"):
                # Parse the URL
                parsed_url = urlparse(config.llm_base_url)

View File

@@ -147,8 +147,6 @@ def get_processor_class(file_extension: FileExtension | str) -> Type[ProcessorBase]:
    if file_extension not in known_processors:
        raise ValueError(f"Extension not known: {file_extension}")
    entries = known_processors[file_extension]
    while entries:
        proc_entry = heappop(entries)
        try:

View File

@@ -1,15 +1,15 @@
import logging
from typing import Annotated, AsyncGenerator, Optional, Sequence, TypedDict

# TODO(@aminediro): this is the only dependency to langchain package, we should remove it
from langchain.retrievers import ContextualCompressionRetriever
from langchain_core.callbacks import Callbacks
from langchain_core.documents import BaseDocumentCompressor, Document
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
from langchain_core.messages.ai import AIMessageChunk
from langchain_core.vectorstores import VectorStore
from langgraph.graph import END, StateGraph
from langgraph.graph.message import add_messages

from quivr_core.chat import ChatHistory
from quivr_core.config import RAGConfig
@@ -19,19 +19,20 @@ from quivr_core.models import (
    ParsedRAGResponse,
    QuivrKnowledge,
    RAGResponseMetadata,
    cited_answer,
)
from quivr_core.prompts import ANSWER_PROMPT, CONDENSE_QUESTION_PROMPT
from quivr_core.utils import (
    combine_documents,
    format_file_list,
    get_chunk_metadata,
    parse_chunk_response,
    parse_response,
)

logger = logging.getLogger("quivr_core")


class AgentState(TypedDict):
    # The add_messages function defines how an update should be processed
    # Default is to replace. add_messages says "append"
@@ -43,6 +44,7 @@ class AgentState(TypedDict):
    files: str
    final_response: dict


class IdempotentCompressor(BaseDocumentCompressor):
    def compress_documents(
        self,
@@ -50,7 +52,6 @@ class IdempotentCompressor(BaseDocumentCompressor):
        query: str,
        callbacks: Optional[Callbacks] = None,
    ) -> Sequence[Document]:
        """
        A no-op document compressor that simply returns the documents it is given.
@@ -59,6 +60,7 @@ class IdempotentCompressor(BaseDocumentCompressor):
        """
        return documents


class QuivrQARAGLangGraph:
    def __init__(
        self,
@@ -96,18 +98,15 @@ class QuivrQARAGLangGraph:
        """
        return self.vector_store.as_retriever()

    def filter_history(self, state):
        """
        Filter out the chat history to only include the messages that are relevant to the current question

        Takes in a chat_history, e.g. [HumanMessage(content='Qui est Chloé ? '),
        AIMessage(content="Chloé est une salariée travaillant pour l'entreprise Quivr en tant qu'AI Engineer,
        sous la direction de son supérieur hiérarchique, Stanislas Girard."),
        HumanMessage(content='Dis moi en plus sur elle'), AIMessage(content=''),
        HumanMessage(content='Dis moi en plus sur elle'),
        AIMessage(content="Désolé, je n'ai pas d'autres informations sur Chloé à partir des fichiers fournis.")]

        Returns a filtered chat_history with in priority: first max_tokens, then max_history where a Human message and an AI message count as one pair;
        a token is 4 characters
@@ -131,7 +130,6 @@ class QuivrQARAGLangGraph:
        return {"filtered_chat_history": filtered_chat_history}

    ### Nodes
    def rewrite(self, state):
        """
@@ -145,7 +143,10 @@ class QuivrQARAGLangGraph:
        """

        # Grader
        msg = CONDENSE_QUESTION_PROMPT.format(
            chat_history=state["filtered_chat_history"],
            question=state["messages"][0].content,
        )

        model = self.llm_endpoint._llm
        response = model.invoke(msg)
@@ -179,7 +180,7 @@ class QuivrQARAGLangGraph:
        question = messages[0].content
        files = state["files"]
        docs = state["docs"]

        # Prompt
        prompt = self.rag_config.prompt
@@ -206,11 +207,10 @@ class QuivrQARAGLangGraph:
        response = rag_chain.invoke(final_inputs)
        formatted_response = {
            "answer": response,  # Assuming the last message contains the final answer
            "docs": docs,
        }
        return {"messages": [response], "final_response": formatted_response}

    def build_langgraph_chain(self):
        """
        Builds the langchain chain for the given configuration.
@@ -247,7 +247,7 @@ class QuivrQARAGLangGraph:
        workflow.add_node("filter_history", self.filter_history)
        workflow.add_node("rewrite", self.rewrite)  # Re-writing the question
        workflow.add_node("retrieve", self.retrieve)  # retrieval
        workflow.add_node("generate", self.generate)

        # Add node for filtering history
@@ -293,7 +293,9 @@ class QuivrQARAGLangGraph:
            inputs,
            config={"metadata": metadata},
        )
        response = parse_response(
            raw_llm_response["final_response"], self.rag_config.llm_config.model
        )
        return response

    async def answer_astream(
@@ -303,7 +305,6 @@ class QuivrQARAGLangGraph:
        list_files: list[QuivrKnowledge],
        metadata: dict[str, str] = {},
    ) -> AsyncGenerator[ParsedRAGChunkResponse, ParsedRAGChunkResponse]:
        """
        Answer a question using the langgraph chain and yield each chunk of the answer separately.
@@ -337,10 +338,17 @@ class QuivrQARAGLangGraph:
        ):
            kind = event["event"]

            if (
                not sources
                and "output" in event["data"]
                and "docs" in event["data"]["output"]
            ):
                sources = event["data"]["output"]["docs"]

            if (
                kind == "on_chat_model_stream"
                and event["metadata"]["langgraph_node"] == "generate"
            ):
                chunk = event["data"]["chunk"]
                rolling_message, answer_str = parse_chunk_response(

View File

@@ -1,9 +1,10 @@
import os
import shutil
from pathlib import Path
from typing import Self, Set
from uuid import UUID

from quivr_core.brain.serialization import LocalStorageConfig, TransparentStorageConfig
from quivr_core.files.file import QuivrFile
from quivr_core.storage.storage_base import StorageBase
@@ -57,6 +58,12 @@ class LocalStorage(StorageBase):
    async def remove_file(self, file_id: UUID) -> None:
        raise NotImplementedError

    @classmethod
    def load(cls, config: LocalStorageConfig) -> Self:
        tstorage = cls(dir_path=config.storage_path)
        tstorage.files = [QuivrFile.deserialize(f) for f in config.files.values()]
        return tstorage


class TransparentStorage(StorageBase):
    """Transparent Storage."""
@@ -77,3 +84,11 @@ class TransparentStorage(StorageBase):
    async def get_files(self) -> list[QuivrFile]:
        return list(self.id_files.values())

    @classmethod
    def load(cls, config: TransparentStorageConfig) -> Self:
        tstorage = cls()
        tstorage.id_files = {
            i: QuivrFile.deserialize(f) for i, f in config.files.items()
        }
        return tstorage

View File

@@ -4,7 +4,6 @@ from uuid import uuid4
import pytest
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings

from quivr_core.brain import Brain
from quivr_core.chat import ChatHistory
from quivr_core.llm import LLMEndpoint