2024-07-10 16:22:59 +03:00
|
|
|
from datetime import datetime
|
|
|
|
from typing import Any, Generator, Tuple
|
|
|
|
from uuid import UUID, uuid4
|
|
|
|
|
|
|
|
from langchain_core.messages import AIMessage, HumanMessage
|
|
|
|
from quivr_core.models import ChatMessage
|
|
|
|
|
|
|
|
|
|
|
|
class ChatHistory:
    """In-memory, append-only history of the messages exchanged in one chat.

    Each appended langchain message is wrapped in a ``ChatMessage`` carrying
    the chat id, an optional brain id, a fresh ``uuid4`` message id, the
    creation time (``datetime.now()``) and caller-supplied metadata.
    """

    def __init__(self, chat_id: UUID, brain_id: UUID | None) -> None:
        """Create an empty history.

        Args:
            chat_id: Identifier stamped on every message appended to this history.
            brain_id: Optional brain identifier stamped on every message.
        """
        self.id = chat_id
        self.brain_id = brain_id
        # TODO(@aminediro): maybe use a deque() instead ?
        self._msgs: list[ChatMessage] = []

    def get_chat_history(self, newest_first: bool = False) -> list[ChatMessage]:
        """Return a new list of the stored messages sorted by ``message_time``.

        Args:
            newest_first: When True, return the messages newest first;
                otherwise oldest first.

        Returns:
            list[ChatMessage]: A new list. The internal storage is never exposed.
        """
        history = sorted(self._msgs, key=lambda msg: msg.message_time)
        if newest_first:
            # Reversing the stable ascending sort, rather than calling
            # sorted(reverse=True), puts messages with identical timestamps in
            # reverse insertion order. The later-appended message therefore
            # comes first, which iter_pairs relies on when a human message and
            # its AI answer share the same timestamp.
            return history[::-1]
        return history

    def __len__(self) -> int:
        """Return the number of stored messages."""
        return len(self._msgs)

    def append(
        self,
        langchain_msg: AIMessage | HumanMessage,
        metadata: dict[str, Any] | None = None,
    ) -> None:
        """Wrap ``langchain_msg`` in a ``ChatMessage`` and store it.

        Args:
            langchain_msg: The human or AI message to record.
            metadata: Optional extra data attached to the stored message.
                Defaults to a new empty dict for each call.
        """
        # A literal ``{}`` default would be a single dict shared by every
        # message created without explicit metadata. Build a fresh one instead.
        if metadata is None:
            metadata = {}
        chat_msg = ChatMessage(
            chat_id=self.id,
            message_id=uuid4(),
            brain_id=self.brain_id,
            msg=langchain_msg,
            # NOTE(review): naive local time; confirm whether consumers expect
            # timezone-aware timestamps before changing this.
            message_time=datetime.now(),
            metadata=metadata,
        )
        self._msgs.append(chat_msg)

    def iter_pairs(self) -> Generator[Tuple[HumanMessage, AIMessage], None, None]:
        """Yield ``(human, ai)`` langchain message pairs, newest pair first.

        The history is walked newest-first and consumed two messages at a time,
        so it is expected to alternate human/AI and end with an AI message.
        If the message count is odd, ``zip`` silently drops the oldest message.
        If the newest message is not an ``AIMessage`` (for example a human
        message still awaiting its answer), the first assertion fails.

        Raises:
            AssertionError: If a pair does not have the expected message types
                (only when assertions are enabled).
        """
        # Reverse the chat_history, newest first
        it = iter(self.get_chat_history(newest_first=True))
        # zip over the same iterator twice consumes consecutive items:
        # (newest, second newest), (third, fourth), ...
        for ai_message, human_message in zip(it, it):
            assert isinstance(
                human_message.msg, HumanMessage
            ), f"msg {human_message} is not HumanMessage"
            assert isinstance(
                ai_message.msg, AIMessage
            ), f"msg {ai_message} is not AIMessage"
            yield (human_message.msg, ai_message.msg)
|