quivr/backend/core/quivr_core/chat.py
Jacopo Chevallard ef90e8e672
feat: introducing configurable retrieval workflows (#3227)
# Description

Major PR which, among other things, introduces the possibility of easily
customizing the retrieval workflows. Workflows are based on LangGraph,
and can be customized using a [yaml configuration
file](core/tests/test_llm_endpoint.py), and adding the implementation of
the nodes logic into
[quivr_rag_langgraph.py](1a0c98437a/backend/core/quivr_core/quivr_rag_langgraph.py)

This is a first, simple implementation that will significantly evolve in
the coming weeks to enable more complex workflows (for instance, with
conditional nodes). We also plan to adopt a similar approach for the
ingestion part, i.e. to enable users to easily customize the ingestion
pipeline.

Closes CORE-195, CORE-203, CORE-204

## Checklist before requesting a review

Please delete options that are not relevant.

- [X] My code follows the style guidelines of this project
- [X] I have performed a self-review of my code
- [X] I have commented hard-to-understand areas
- [X] I have ideally added tests that prove my fix is effective or that
my feature works
- [X] New and existing unit tests pass locally with my changes
- [X] Any dependent changes have been merged

## Screenshots (if appropriate):
2024-09-23 09:11:06 -07:00

85 lines
2.7 KiB
Python

from datetime import datetime
from typing import Any, Generator, Tuple, List
from uuid import UUID, uuid4
from copy import deepcopy
from langchain_core.messages import AIMessage, HumanMessage
from quivr_core.models import ChatMessage
class ChatHistory:
    """
    Chat history is a list of ChatMessage.
    It is used to store the chat history of a chat.
    """

    def __init__(self, chat_id: UUID, brain_id: UUID | None) -> None:
        """Create an empty history bound to a chat (and optionally a brain).

        Args:
            chat_id: identifier of the chat this history belongs to.
            brain_id: identifier of the brain, if any.
        """
        self.id = chat_id
        self.brain_id = brain_id
        # TODO(@aminediro): maybe use a deque() instead ?
        self._msgs: list[ChatMessage] = []

    def get_chat_history(self, newest_first: bool = False) -> list[ChatMessage]:
        """Return the chat messages sorted by ``message_time``.

        Args:
            newest_first: if True, the most recent message comes first.

        Returns:
            list[ChatMessage]: list of chat messages sorted by time.
        """
        history = sorted(self._msgs, key=lambda msg: msg.message_time)
        if newest_first:
            return history[::-1]
        return history

    def __len__(self) -> int:
        """Number of messages currently stored in the history."""
        return len(self._msgs)

    def append(
        self,
        langchain_msg: AIMessage | HumanMessage,
        metadata: dict[str, Any] | None = None,
    ) -> None:
        """
        Append a message to the chat history.

        Args:
            langchain_msg: the langchain message to wrap and store.
            metadata: optional per-message metadata. Defaults to an empty
                dict created per call (the previous ``={}`` default was a
                shared mutable object across all calls/messages).
        """
        chat_msg = ChatMessage(
            chat_id=self.id,
            message_id=uuid4(),
            brain_id=self.brain_id,
            msg=langchain_msg,
            message_time=datetime.now(),
            # Fresh dict per message when no metadata is supplied, so
            # mutating one message's metadata cannot leak into others.
            metadata={} if metadata is None else metadata,
        )
        self._msgs.append(chat_msg)

    def iter_pairs(self) -> Generator[Tuple[HumanMessage, AIMessage], None, None]:
        """
        Iterate over the chat history as pairs of HumanMessage and AIMessage,
        most recent pair first. Assumes messages strictly alternate
        (AI, Human) when read newest-first; a trailing unpaired message is
        silently dropped by zip(strict=False).
        """
        # Reverse the chat_history, newest first
        it = iter(self.get_chat_history(newest_first=True))
        for ai_message, human_message in zip(it, it, strict=False):
            assert isinstance(
                human_message.msg, HumanMessage
            ), f"msg {human_message} is not HumanMessage"
            # Fixed: this failure message previously interpolated
            # human_message instead of the offending ai_message.
            assert isinstance(
                ai_message.msg, AIMessage
            ), f"msg {ai_message} is not AIMessage"
            yield (human_message.msg, ai_message.msg)

    def to_list(self) -> List[HumanMessage | AIMessage]:
        """Format the chat history into a list of HumanMessage and AIMessage"""
        return [_msg.msg for _msg in self._msgs]

    def __deepcopy__(self, memo):
        """
        Support for deepcopy of ChatHistory.
        This method ensures that mutable objects (like lists) are copied deeply.
        """
        # Create a new instance of ChatHistory
        new_copy = ChatHistory(self.id, deepcopy(self.brain_id, memo))
        # Register in memo BEFORE copying children, per the copy protocol,
        # so cyclic references back to this history resolve correctly.
        memo[id(self)] = new_copy
        # Perform a deepcopy of the _msgs list
        new_copy._msgs = deepcopy(self._msgs, memo)
        # Return the deep copied instance
        return new_copy