feat: save and load brain (#3202)
# Description

- Save and load a brain to disk:

```python
import asyncio
import tempfile

from quivr_core import Brain


async def main():
    with tempfile.NamedTemporaryFile(mode="w", suffix=".txt") as temp_file:
        temp_file.write("Gold is a liquid of blue-like colour.")
        temp_file.flush()

        brain = await Brain.afrom_files(name="test_brain", file_paths=[temp_file.name])
        save_path = await brain.save("/home/amine/.local/quivr")

        brain_loaded = Brain.load(save_path)
        brain_loaded.print_info()
```

# TODO

- Loading all chat history
- Loading from other vector stores; PG, for example, would be great ...
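As a minimal sketch of the round trip from a later session (assuming a brain was already saved as above; the folder name follows the `brain_<id>` pattern used by `Brain.save`, and the path/UUID below are placeholders):

```python
from quivr_core import Brain
from quivr_core.quivr_rag_langgraph import QuivrQARAGLangGraph

# Path returned by brain.save(...) in a previous run (placeholder UUID).
save_path = "/home/amine/.local/quivr/brain_00000000-0000-0000-0000-000000000000"

brain = Brain.load(save_path)
brain.print_info()

# A reloaded brain can be queried like a freshly built one.
answer = brain.ask("what is gold?", rag_pipeline=QuivrQARAGLangGraph)
print(answer.answer)
```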
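For the vector-store TODO, this commit already ships a `PGVectorConfig` model in `quivr_core.brain.serialization`, so a pgvector-backed brain could at least be described in the serialized config; the pgvector load path itself is not implemented yet. A hedged sketch, with illustrative field values only:

```python
from quivr_core.brain.serialization import PGVectorConfig

# Hypothetical: describe an existing pgvector table that a future Brain.load
# could reconnect to instead of reading a FAISS folder from disk.
pg_config = PGVectorConfig(
    pg_url="postgresql://localhost:5432/quivr",  # example DSN
    pg_user="quivr",
    pg_psswd="change-me",  # coerced to a pydantic SecretStr
    table_name="vectors",
    vector_dim=1536,  # e.g. OpenAI text-embedding-ada-002 dimensionality
)
```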
This commit is contained in:
parent
06f72eb451
commit
eda619f454
backend/core/examples/save_load_brain.py (new file, 22 lines)
@@ -0,0 +1,22 @@
import asyncio
import tempfile

from quivr_core import Brain


async def main():
    with tempfile.NamedTemporaryFile(mode="w", suffix=".txt") as temp_file:
        temp_file.write("Gold is a liquid of blue-like colour.")
        temp_file.flush()

        brain = await Brain.afrom_files(name="test_brain", file_paths=[temp_file.name])

        save_path = await brain.save("/home/amine/.local/quivr")

        brain_loaded = Brain.load(save_path)
        brain_loaded.print_info()


if __name__ == "__main__":
    # Run the main function in the existing event loop
    asyncio.run(main())
@@ -1,22 +1,23 @@
import tempfile

from quivr_core import Brain
from quivr_core.quivr_rag_langgraph import QuivrQARAGLangGraph
from quivr_core.quivr_rag import QuivrQARAG
from quivr_core.quivr_rag_langgraph import QuivrQARAGLangGraph

if __name__ == "__main__":
    with tempfile.NamedTemporaryFile(mode="w", suffix=".txt") as temp_file:
        temp_file.write("Gold is a liquid of blue-like colour.")
        temp_file.flush()

        brain = Brain.from_files(name="test_brain",
                                 file_paths=[temp_file.name],
                                 )
        brain = Brain.from_files(
            name="test_brain",
            file_paths=[temp_file.name],
        )

        answer = brain.ask("what is gold? asnwer in french",
                           rag_pipeline=QuivrQARAGLangGraph)
        answer = brain.ask(
            "what is gold? asnwer in french", rag_pipeline=QuivrQARAGLangGraph
        )
        print("answer QuivrQARAGLangGraph :", answer.answer)


        answer = brain.ask("what is gold? asnwer in french",
                           rag_pipeline=QuivrQARAG)
        print("answer QuivrQARAG :", answer.answer)
        answer = brain.ask("what is gold? asnwer in french", rag_pipeline=QuivrQARAG)
        print("answer QuivrQARAG :", answer.answer)
@@ -1,29 +1,34 @@
from dotenv import load_dotenv
import tempfile
import asyncio
import tempfile

from dotenv import load_dotenv
from quivr_core import Brain
from quivr_core.quivr_rag_langgraph import QuivrQARAGLangGraph
from quivr_core.quivr_rag import QuivrQARAG
from quivr_core.quivr_rag_langgraph import QuivrQARAGLangGraph


async def main():
    dotenv_path = "/Users/jchevall/Coding/QuivrHQ/quivr/.env"
    load_dotenv(dotenv_path)


    with tempfile.NamedTemporaryFile(mode="w", suffix=".txt") as temp_file:
        temp_file.write("Gold is a liquid of blue-like colour.")
        temp_file.flush()

        brain = await Brain.afrom_files(name="test_brain",
                                        file_paths=[temp_file.name])
        brain = await Brain.afrom_files(name="test_brain", file_paths=[temp_file.name])

        await brain.save("~/.local/quivr")

        question = "what is gold? answer in french"
        async for chunk in brain.ask_streaming(question, rag_pipeline=QuivrQARAG):
            print("answer QuivrQARAG:", chunk.answer)
            print("answer QuivrQARAG:", chunk.answer)

        async for chunk in brain.ask_streaming(question, rag_pipeline=QuivrQARAGLangGraph):
        async for chunk in brain.ask_streaming(
            question, rag_pipeline=QuivrQARAGLangGraph
        ):
            print("answer QuivrQARAGLangGraph:", chunk.answer)


if __name__ == "__main__":
    # Run the main function in the existing event loop
    asyncio.run(main())
    asyncio.run(main())
@@ -1,18 +1,27 @@
import asyncio
import logging
import os
from pathlib import Path
from pprint import PrettyPrinter
from typing import Any, AsyncGenerator, Callable, Dict, Self, Union, Type
from typing import Any, AsyncGenerator, Callable, Dict, Self, Type, Union
from uuid import UUID, uuid4

from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.vectorstores import VectorStore
from langchain_openai import OpenAIEmbeddings
from rich.console import Console
from rich.panel import Panel

from quivr_core.brain.info import BrainInfo, ChatHistoryInfo
from quivr_core.brain.serialization import (
    BrainSerialized,
    EmbedderConfig,
    FAISSConfig,
    LocalStorageConfig,
    TransparentStorageConfig,
)
from quivr_core.chat import ChatHistory
from quivr_core.config import RAGConfig
from quivr_core.files.file import load_qfile
@@ -20,8 +29,8 @@ from quivr_core.llm import LLMEndpoint
from quivr_core.models import ParsedRAGChunkResponse, ParsedRAGResponse, SearchResult
from quivr_core.processor.registry import get_processor_class
from quivr_core.quivr_rag import QuivrQARAG
from quivr_core.quivr_rag_langgraph import QuivrQARAGLangGraph
from quivr_core.storage.local_storage import TransparentStorage
from quivr_core.quivr_rag_langgraph import QuivrQARAGLangGraph
from quivr_core.storage.local_storage import LocalStorage, TransparentStorage
from quivr_core.storage.storage_base import StorageBase

from .brain_defaults import build_default_vectordb, default_embedder, default_llm
@@ -90,6 +99,108 @@ class Brain:
        panel = Panel(tree, title="Brain Info", expand=False, border_style="bold")
        console.print(panel)

    @classmethod
    def load(cls, folder_path: str | Path) -> Self:
        if isinstance(folder_path, str):
            folder_path = Path(folder_path)
        if not folder_path.exists():
            raise ValueError(f"path {folder_path} doesn't exist")

        # Load brainserialized
        with open(os.path.join(folder_path, "config.json"), "r") as f:
            bserialized = BrainSerialized.model_validate_json(f.read())

        # Loading storage
        if bserialized.storage_config.storage_type == "transparent_storage":
            storage: StorageBase = TransparentStorage.load(bserialized.storage_config)
        elif bserialized.storage_config.storage_type == "local_storage":
            storage: StorageBase = LocalStorage.load(bserialized.storage_config)
        else:
            raise ValueError("unknown storage")

        # Load Embedder
        if bserialized.embedding_config.embedder_type == "openai_embedding":
            from langchain_openai import OpenAIEmbeddings

            embedder = OpenAIEmbeddings(**bserialized.embedding_config.config)
        else:
            raise ValueError("unknown embedder")

        # Load vector db
        if bserialized.vectordb_config.vectordb_type == "faiss":
            from langchain_community.vectorstores import FAISS

            vector_db = FAISS.load_local(
                folder_path=bserialized.vectordb_config.vectordb_folder_path,
                embeddings=embedder,
                allow_dangerous_deserialization=True,
            )
        else:
            raise ValueError("Unsupported vectordb")

        return cls(
            id=bserialized.id,
            name=bserialized.name,
            embedder=embedder,
            llm=LLMEndpoint.from_config(bserialized.llm_config),
            storage=storage,
            vector_db=vector_db,
        )

    async def save(self, folder_path: str | Path):
        if isinstance(folder_path, str):
            folder_path = Path(folder_path)

        brain_path = os.path.join(folder_path, f"brain_{self.id}")
        os.makedirs(brain_path, exist_ok=True)

        from langchain_community.vectorstores import FAISS

        if isinstance(self.vector_db, FAISS):
            vectordb_path = os.path.join(brain_path, "vector_store")
            os.makedirs(vectordb_path, exist_ok=True)
            self.vector_db.save_local(folder_path=vectordb_path)
            vector_store = FAISSConfig(vectordb_folder_path=vectordb_path)
        else:
            raise Exception("can't serialize other vector stores for now")

        if isinstance(self.embedder, OpenAIEmbeddings):
            embedder_config = EmbedderConfig(
                config=self.embedder.dict(exclude={"openai_api_key"})
            )
        else:
            raise Exception("can't serialize embedder other than openai for now")

        # TODO : each instance should know how to serialize/deserialize itself
        if isinstance(self.storage, LocalStorage):
            serialized_files = {
                f.id: f.serialize() for f in await self.storage.get_files()
            }
            storage_config = LocalStorageConfig(
                storage_path=self.storage.dir_path, files=serialized_files
            )
        elif isinstance(self.storage, TransparentStorage):
            serialized_files = {
                f.id: f.serialize() for f in await self.storage.get_files()
            }
            storage_config = TransparentStorageConfig(files=serialized_files)
        else:
            raise Exception("can't serialize storage. not supported for now")

        bserialized = BrainSerialized(
            id=self.id,
            name=self.name,
            chat_history=self.chat_history.get_chat_history(),
            llm_config=self.llm.get_config(),
            vectordb_config=vector_store,
            embedding_config=embedder_config,
            storage_config=storage_config,
        )

        with open(os.path.join(brain_path, "config.json"), "w") as f:
            f.write(bserialized.model_dump_json())
        return brain_path

    def info(self) -> BrainInfo:
        # TODO: dim of embedding
        # "embedder": {},
@@ -177,7 +288,7 @@ class Brain:
        storage: StorageBase = TransparentStorage(),
        llm: LLMEndpoint | None = None,
        embedder: Embeddings | None = None,
        skip_file_error: bool = False
        skip_file_error: bool = False,
    ) -> Self:
        loop = asyncio.get_event_loop()
        return loop.run_until_complete(
@@ -223,7 +334,7 @@
            storage=storage,
            llm=llm,
            embedder=embedder,
            vector_db=vector_db
            vector_db=vector_db,
        )

    async def asearch(
@@ -33,7 +33,6 @@ class LLMInfo:
        llm_tree.add(f"Base URL: [underline]{self.llm_base_url}[/underline]")
        llm_tree.add(f"Temperature: [bold]{self.temperature}[/bold]")
        llm_tree.add(f"Max Tokens: [bold]{self.max_tokens}[/bold]")

        func_call_color = "green" if self.supports_function_calling else "red"
        llm_tree.add(
            f"Supports Function Calling: [bold {func_call_color}]{self.supports_function_calling}[/bold {func_call_color}]"
backend/core/quivr_core/brain/serialization.py (new file, 55 lines)
@@ -0,0 +1,55 @@
from pathlib import Path
from typing import Any, Dict, Literal, Union
from uuid import UUID

from pydantic import BaseModel, Field, SecretStr

from quivr_core.config import LLMEndpointConfig
from quivr_core.files.file import QuivrFileSerialized
from quivr_core.models import ChatMessage


class EmbedderConfig(BaseModel):
    embedder_type: Literal["openai_embedding"] = "openai_embedding"
    # TODO: type this correctly
    config: Dict[str, Any]


class PGVectorConfig(BaseModel):
    vectordb_type: Literal["pgvector"] = "pgvector"
    pg_url: str
    pg_user: str
    pg_psswd: SecretStr
    table_name: str
    vector_dim: int


class FAISSConfig(BaseModel):
    vectordb_type: Literal["faiss"] = "faiss"
    vectordb_folder_path: str


class LocalStorageConfig(BaseModel):
    storage_type: Literal["local_storage"] = "local_storage"
    storage_path: Path
    files: dict[UUID, QuivrFileSerialized]


class TransparentStorageConfig(BaseModel):
    storage_type: Literal["transparent_storage"] = "transparent_storage"
    files: dict[UUID, QuivrFileSerialized]


class BrainSerialized(BaseModel):
    id: UUID
    name: str
    chat_history: list[ChatMessage]
    vectordb_config: Union[FAISSConfig, PGVectorConfig] = Field(
        ..., discriminator="vectordb_type"
    )
    storage_config: Union[TransparentStorageConfig, LocalStorageConfig] = Field(
        ..., discriminator="storage_type"
    )

    llm_config: LLMEndpointConfig
    embedding_config: EmbedderConfig
@@ -5,10 +5,22 @@ import warnings
from contextlib import asynccontextmanager
from enum import Enum
from pathlib import Path
from typing import Any, AsyncGenerator, AsyncIterable
from typing import Any, AsyncGenerator, AsyncIterable, Self
from uuid import UUID, uuid4

import aiofiles
from openai import BaseModel


class QuivrFileSerialized(BaseModel):
    id: UUID
    brain_id: UUID
    path: Path
    original_filename: str
    file_size: int | None
    file_extension: str
    file_sha1: str
    additional_metadata: dict[str, Any]


class FileExtension(str, Enum):
@@ -137,3 +149,28 @@ class QuivrFile:
            "file_size": self.file_size,
            **self.additional_metadata,
        }

    def serialize(self) -> QuivrFileSerialized:
        return QuivrFileSerialized(
            id=self.id,
            brain_id=self.brain_id,
            path=self.path.absolute(),
            original_filename=self.original_filename,
            file_size=self.file_size,
            file_extension=self.file_extension,
            file_sha1=self.file_sha1,
            additional_metadata=self.additional_metadata,
        )

    @classmethod
    def deserialize(cls, serialized: QuivrFileSerialized) -> Self:
        return cls(
            id=serialized.id,
            brain_id=serialized.brain_id,
            path=serialized.path,
            original_filename=serialized.original_filename,
            file_size=serialized.file_size,
            file_extension=serialized.file_extension,
            file_sha1=serialized.file_sha1,
            metadata=serialized.additional_metadata,
        )
@@ -1,10 +1,10 @@
import logging
from urllib.parse import parse_qs, urlparse

from langchain_core.language_models.chat_models import BaseChatModel
from pydantic.v1 import SecretStr
from langchain_openai import AzureChatOpenAI, ChatOpenAI
from langchain_anthropic import ChatAnthropic
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_openai import AzureChatOpenAI, ChatOpenAI
from pydantic.v1 import SecretStr

from quivr_core.brain.info import LLMInfo
from quivr_core.config import LLMEndpointConfig
@@ -27,8 +27,6 @@ class LLMEndpoint:
    @classmethod
    def from_config(cls, config: LLMEndpointConfig = LLMEndpointConfig()):
        try:


            if config.model.startswith("azure/"):
                # Parse the URL
                parsed_url = urlparse(config.llm_base_url)
@@ -147,8 +147,6 @@ def get_processor_class(file_extension: FileExtension | str) -> Type[ProcessorBa
    if file_extension not in known_processors:
        raise ValueError(f"Extension not known: {file_extension}")
    entries = known_processors[file_extension]
    if file_extension == FileExtension.txt:
        print(entries)
    while entries:
        proc_entry = heappop(entries)
        try:
@@ -1,15 +1,15 @@
import logging
from typing import AsyncGenerator, Optional, Sequence, Annotated, Sequence, TypedDict
from typing import Annotated, AsyncGenerator, Optional, Sequence, TypedDict

# TODO(@aminediro): this is the only dependency to langchain package, we should remove it
from langchain.retrievers import ContextualCompressionRetriever
from langchain_core.callbacks import Callbacks
from langchain_core.documents import BaseDocumentCompressor, Document
from langchain_core.messages import AIMessage, HumanMessage, BaseMessage
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
from langchain_core.messages.ai import AIMessageChunk
from langchain_core.vectorstores import VectorStore

from langgraph.graph.message import add_messages
from langgraph.graph import END, StateGraph
from langgraph.graph.message import add_messages

from quivr_core.chat import ChatHistory
from quivr_core.config import RAGConfig
@@ -19,19 +19,20 @@ from quivr_core.models import (
    ParsedRAGResponse,
    QuivrKnowledge,
    RAGResponseMetadata,
    cited_answer
    cited_answer,
)
from quivr_core.prompts import CONDENSE_QUESTION_PROMPT, ANSWER_PROMPT
from quivr_core.prompts import ANSWER_PROMPT, CONDENSE_QUESTION_PROMPT
from quivr_core.utils import (
    combine_documents,
    format_file_list,
    get_chunk_metadata,
    parse_chunk_response,
    combine_documents,
    parse_response
    parse_response,
)

logger = logging.getLogger("quivr_core")


class AgentState(TypedDict):
    # The add_messages function defines how an update should be processed
    # Default is to replace. add_messages says "append"
@@ -43,6 +44,7 @@ class AgentState(TypedDict):
    files: str
    final_response: dict


class IdempotentCompressor(BaseDocumentCompressor):
    def compress_documents(
        self,
@@ -50,7 +52,6 @@ class IdempotentCompressor(BaseDocumentCompressor):
        query: str,
        callbacks: Optional[Callbacks] = None,
    ) -> Sequence[Document]:

        """
        A no-op document compressor that simply returns the documents it is given.

@@ -59,6 +60,7 @@ class IdempotentCompressor(BaseDocumentCompressor):
        """
        return documents


class QuivrQARAGLangGraph:
    def __init__(
        self,
@@ -96,18 +98,15 @@ class QuivrQARAGLangGraph:
        """
        return self.vector_store.as_retriever()

    def filter_history(
        self,
        state
    ):
    def filter_history(self, state):
        """
        Filter out the chat history to only include the messages that are relevant to the current question

        Takes in a chat_history= [HumanMessage(content='Qui est Chloé ? '),
        AIMessage(content="Chloé est une salariée travaillant pour l'entreprise Quivr en tant qu'AI Engineer,
        sous la direction de son supérieur hiérarchique, Stanislas Girard."),
        HumanMessage(content='Dis moi en plus sur elle'), AIMessage(content=''),
        HumanMessage(content='Dis moi en plus sur elle'),
        Takes in a chat_history= [HumanMessage(content='Qui est Chloé ? '),
        AIMessage(content="Chloé est une salariée travaillant pour l'entreprise Quivr en tant qu'AI Engineer,
        sous la direction de son supérieur hiérarchique, Stanislas Girard."),
        HumanMessage(content='Dis moi en plus sur elle'), AIMessage(content=''),
        HumanMessage(content='Dis moi en plus sur elle'),
        AIMessage(content="Désolé, je n'ai pas d'autres informations sur Chloé à partir des fichiers fournis.")]
        Returns a filtered chat_history with in priority: first max_tokens, then max_history where a Human message and an AI message count as one pair
        a token is 4 characters
@@ -131,7 +130,6 @@ class QuivrQARAGLangGraph:

        return {"filtered_chat_history": filtered_chat_history}


    ### Nodes
    def rewrite(self, state):
        """
@@ -145,7 +143,10 @@ class QuivrQARAGLangGraph:
        """

        # Grader
        msg = CONDENSE_QUESTION_PROMPT.format(chat_history=state['filtered_chat_history'], question=state["messages"][0].content)
        msg = CONDENSE_QUESTION_PROMPT.format(
            chat_history=state["filtered_chat_history"],
            question=state["messages"][0].content,
        )

        model = self.llm_endpoint._llm
        response = model.invoke(msg)
@@ -179,7 +180,7 @@ class QuivrQARAGLangGraph:
        question = messages[0].content
        files = state["files"]

        docs = state['docs']
        docs = state["docs"]

        # Prompt
        prompt = self.rag_config.prompt
@@ -206,11 +207,10 @@ class QuivrQARAGLangGraph:
        response = rag_chain.invoke(final_inputs)
        formatted_response = {
            "answer": response,  # Assuming the last message contains the final answer
            "docs": docs
            "docs": docs,
        }
        return {"messages": [response], "final_response": formatted_response}


    def build_langgraph_chain(self):
        """
        Builds the langchain chain for the given configuration.
@@ -247,7 +247,7 @@ class QuivrQARAGLangGraph:
        workflow.add_node("filter_history", self.filter_history)
        workflow.add_node("rewrite", self.rewrite)  # Re-writing the question
        workflow.add_node("retrieve", self.retrieve)  # retrieval
        workflow.add_node("generate", self.generate)
        workflow.add_node("generate", self.generate)

        # Add node for filtering history

@@ -293,7 +293,9 @@ class QuivrQARAGLangGraph:
            inputs,
            config={"metadata": metadata},
        )
        response = parse_response(raw_llm_response["final_response"], self.rag_config.llm_config.model)
        response = parse_response(
            raw_llm_response["final_response"], self.rag_config.llm_config.model
        )
        return response

    async def answer_astream(
@@ -303,7 +305,6 @@ class QuivrQARAGLangGraph:
        list_files: list[QuivrKnowledge],
        metadata: dict[str, str] = {},
    ) -> AsyncGenerator[ParsedRAGChunkResponse, ParsedRAGChunkResponse]:

        """
        Answer a question using the langgraph chain and yield each chunk of the answer separately.

@@ -337,10 +338,17 @@ class QuivrQARAGLangGraph:
        ):
            kind = event["event"]

            if not sources and "output" in event["data"] and "docs" in event["data"]["output"]:
            if (
                not sources
                and "output" in event["data"]
                and "docs" in event["data"]["output"]
            ):
                sources = event["data"]["output"]["docs"]

            if kind == "on_chat_model_stream" and event["metadata"]["langgraph_node"] == "generate":
            if (
                kind == "on_chat_model_stream"
                and event["metadata"]["langgraph_node"] == "generate"
            ):
                chunk = event["data"]["chunk"]

                rolling_message, answer_str = parse_chunk_response(
@@ -1,9 +1,10 @@
import os
import shutil
from pathlib import Path
from typing import Set
from typing import Self, Set
from uuid import UUID

from quivr_core.brain.serialization import LocalStorageConfig, TransparentStorageConfig
from quivr_core.files.file import QuivrFile
from quivr_core.storage.storage_base import StorageBase

@@ -57,6 +58,12 @@ class LocalStorage(StorageBase):
    async def remove_file(self, file_id: UUID) -> None:
        raise NotImplementedError

    @classmethod
    def load(cls, config: LocalStorageConfig) -> Self:
        tstorage = cls(dir_path=config.storage_path)
        tstorage.files = [QuivrFile.deserialize(f) for f in config.files.values()]
        return tstorage


class TransparentStorage(StorageBase):
    """Transparent Storage."""
@@ -77,3 +84,11 @@ class TransparentStorage(StorageBase):

    async def get_files(self) -> list[QuivrFile]:
        return list(self.id_files.values())

    @classmethod
    def load(cls, config: TransparentStorageConfig) -> Self:
        tstorage = cls()
        tstorage.id_files = {
            i: QuivrFile.deserialize(f) for i, f in config.files.items()
        }
        return tstorage
@@ -4,7 +4,6 @@ from uuid import uuid4
import pytest
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings

from quivr_core.brain import Brain
from quivr_core.chat import ChatHistory
from quivr_core.llm import LLMEndpoint