feat: quivr core chat history (#2824)

# Description

- Defined quivr-core `ChatHistory`
- `ChatHistory` can be iterated over in tuples of
`HumanMessage,AIMessage`
-  Brain appends to the chatHistory once response is received
- Brain holds a dict of chats and defines the default chat (TODO: define
a system of selecting the chats)
- Wrote test 
- Updated `QuivrQARAG` to use `ChatHistory` as input
This commit is contained in:
AmineDiro 2024-07-10 15:22:59 +02:00 committed by GitHub
parent 442186c674
commit 847e161d80
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 213 additions and 45 deletions

View File

@ -6,8 +6,10 @@ from uuid import UUID, uuid4
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.vectorstores import VectorStore
from quivr_core.chat import ChatHistory
from quivr_core.config import LLMEndpointConfig, RAGConfig
from quivr_core.llm import LLMEndpoint
from quivr_core.models import ParsedRAGResponse, SearchResult
@ -108,13 +110,23 @@ class Brain:
self.storage = storage
# Chat history
self.chat_history: list[str] = []
self._chats = self._init_chats()
self.default_chat = list(self._chats.values())[0]
# RAG dependencies:
self.llm = llm
self.vector_db = vector_db
self.embedder = embedder
@property
def chat_history(self):
return self.default_chat
def _init_chats(self):
chat_id = uuid4()
default_chat = ChatHistory(chat_id=chat_id, brain_id=self.id)
return {chat_id: default_chat}
@classmethod
async def afrom_files(
cls,
@ -235,6 +247,9 @@ class Brain:
return [SearchResult(chunk=d, score=s) for d, s in result]
def get_chat_history(self, chat_id: UUID):
return self._chats[chat_id]
# TODO(@aminediro)
def add_file(self) -> None:
# add it to storage
@ -242,7 +257,9 @@ class Brain:
raise NotImplementedError
def ask(
self, question: str, rag_config: RAGConfig | None = None
self,
question: str,
rag_config: RAGConfig | None = None,
) -> ParsedRAGResponse:
llm = self.llm
@ -257,9 +274,12 @@ class Brain:
rag_config=rag_config, llm=llm, vector_store=self.vector_db
)
# transformed_history = format_chat_history(history)
chat_history = self.default_chat
parsed_response = rag_pipeline.answer(question, [], [])
parsed_response = rag_pipeline.answer(question, chat_history, [])
chat_history.append(HumanMessage(content=question))
chat_history.append(AIMessage(content=parsed_response.answer))
# Save answer to the chat history
return parsed_response

View File

@ -0,0 +1,50 @@
from datetime import datetime
from typing import Any, Generator, Tuple
from uuid import UUID, uuid4
from langchain_core.messages import AIMessage, HumanMessage
from quivr_core.models import ChatMessage
class ChatHistory:
def __init__(self, chat_id: UUID, brain_id: UUID) -> None:
self.id = chat_id
self.brain_id = brain_id
# TODO(@aminediro): maybe use a deque() instead ?
self._msgs: list[ChatMessage] = []
def get_chat_history(self, newest_first: bool = False):
"""Returns a ChatMessage list sorted by time
Returns:
list[ChatMessage]: list of chat messages
"""
history = sorted(self._msgs, key=lambda msg: msg.message_time)
if newest_first:
return history[::-1]
return history
def __len__(self):
return len(self._msgs)
def append(
self, langchain_msg: AIMessage | HumanMessage, metadata: dict[str, Any] = {}
):
chat_msg = ChatMessage(
chat_id=self.id,
message_id=uuid4(),
brain_id=self.brain_id,
msg=langchain_msg,
message_time=datetime.now(),
metadata=metadata,
)
self._msgs.append(chat_msg)
def iter_pairs(self) -> Generator[Tuple[HumanMessage, AIMessage], None, None]:
# Reverse the chat_history, newest first
it = iter(self.get_chat_history(newest_first=True))
for ai_message, human_message in zip(it, it):
assert isinstance(human_message.msg, HumanMessage)
assert isinstance(ai_message.msg, AIMessage)
yield (human_message.msg, ai_message.msg)

View File

@ -12,5 +12,6 @@ class LLMEndpointConfig(BaseModel):
class RAGConfig(BaseModel):
llm_config: LLMEndpointConfig = LLMEndpointConfig()
max_history: int = 10
max_files: int = 20
prompt: str | None = None

View File

@ -3,6 +3,7 @@ from typing import Any
from uuid import UUID
from langchain_core.documents import Document
from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.pydantic_v1 import BaseModel as BaseModelV1
from langchain_core.pydantic_v1 import Field as FieldV1
from pydantic import BaseModel
@ -33,17 +34,13 @@ class cited_answer(BaseModelV1):
)
class GetChatHistoryOutput(BaseModel):
class ChatMessage(BaseModelV1):
chat_id: UUID
message_id: UUID
user_message: str
brain_id: UUID
msg: AIMessage | HumanMessage
message_time: datetime
assistant: str | None = None
prompt_title: str | None = None
brain_name: str | None = None
brain_id: UUID | None = None # string because UUID is not JSON serializable
metadata: dict | None = None
thumbs: bool | None = None
metadata: dict[str, Any]
class Source(BaseModel):

View File

@ -5,11 +5,13 @@ from typing import AsyncGenerator, Optional, Sequence
from langchain.retrievers import ContextualCompressionRetriever
from langchain_core.callbacks import Callbacks
from langchain_core.documents import BaseDocumentCompressor, Document
from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.messages.ai import AIMessageChunk
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnableLambda, RunnablePassthrough
from langchain_core.vectorstores import VectorStore
from quivr_core.chat import ChatHistory
from quivr_core.config import RAGConfig
from quivr_core.llm import LLMEndpoint
from quivr_core.models import (
@ -59,7 +61,8 @@ class QuivrQARAG:
return self.vector_store.as_retriever()
def filter_history(
self, chat_history, max_history: int = 10, max_tokens: int = 2000
self,
chat_history: ChatHistory,
):
"""
Filter out the chat history to only include the messages that are relevant to the current question
@ -68,29 +71,23 @@ class QuivrQARAG:
Returns a filtered chat_history with in priority: first max_tokens, then max_history where a Human message and an AI message count as one pair
a token is 4 characters
"""
chat_history = chat_history[::-1]
total_tokens = 0
total_pairs = 0
filtered_chat_history = []
for i in range(0, len(chat_history), 2):
if i + 1 < len(chat_history):
human_message = chat_history[i]
ai_message = chat_history[i + 1]
message_tokens = (
len(human_message.content) + len(ai_message.content)
) // 4
filtered_chat_history: list[AIMessage | HumanMessage] = []
for human_message, ai_message in chat_history.iter_pairs():
# TODO: replace with tiktoken
message_tokens = (len(human_message.content) + len(ai_message.content)) // 4
if (
total_tokens + message_tokens > max_tokens
or total_pairs >= max_history
total_tokens + message_tokens > self.rag_config.llm_config.max_tokens
or total_pairs >= self.rag_config.max_history
):
break
filtered_chat_history.append(human_message)
filtered_chat_history.append(ai_message)
total_tokens += message_tokens
total_pairs += 1
chat_history = filtered_chat_history[::-1]
return chat_history
return filtered_chat_history[::-1]
def build_chain(self, files: str):
compression_retriever = ContextualCompressionRetriever(
@ -146,7 +143,7 @@ class QuivrQARAG:
def answer(
self,
question: str,
history: list[dict[str, str]],
history: ChatHistory,
list_files: list[QuivrKnowledge],
metadata: dict[str, str] = {},
) -> ParsedRAGResponse:
@ -166,7 +163,7 @@ class QuivrQARAG:
async def answer_astream(
self,
question: str,
history: list[dict[str, str]],
history: ChatHistory,
list_files: list[QuivrKnowledge],
metadata: dict[str, str] = {},
) -> AsyncGenerator[ParsedRAGChunkResponse, ParsedRAGChunkResponse]:

View File

@ -11,7 +11,7 @@ from langchain.schema import (
from langchain_core.messages.ai import AIMessageChunk
from quivr_core.models import (
GetChatHistoryOutput,
ChatMessage,
ParsedRAGChunkResponse,
ParsedRAGResponse,
QuivrKnowledge,
@ -44,7 +44,7 @@ def model_supports_function_calling(model_name: str):
def format_chat_history(
history: List[GetChatHistoryOutput],
history: List[ChatMessage],
) -> List[Dict[str, str]]:
"""Format the chat history into a list of HumanMessage and AIMessage"""
formatted_history = []

View File

@ -0,0 +1,8 @@
import os
import pytest
@pytest.fixture(scope="session", autouse=True)
def openai_api_key():
os.environ["OPENAI_API_KEY"] = "abcd"

View File

@ -6,6 +6,7 @@ from langchain_core.embeddings import DeterministicFakeEmbedding, Embeddings
from langchain_core.language_models import FakeListChatModel
from quivr_core.brain import Brain
from quivr_core.chat import ChatHistory
from quivr_core.config import LLMEndpointConfig
from quivr_core.llm import LLMEndpoint
from quivr_core.storage.local_storage import TransparentStorage
@ -25,7 +26,7 @@ def answers():
@pytest.fixture(scope="function")
def llm(answers: list[str]):
def fake_llm(answers: list[str]):
llm = FakeListChatModel(responses=answers)
return LLMEndpoint(llm=llm, llm_config=LLMEndpointConfig(model="fake_model"))
@ -41,23 +42,26 @@ def test_brain_empty_files():
Brain.from_files(name="test_brain", file_paths=[])
def test_brain_from_files_success(llm: LLMEndpoint, embedder, temp_data_file):
def test_brain_from_files_success(fake_llm: LLMEndpoint, embedder, temp_data_file):
brain = Brain.from_files(
name="test_brain", file_paths=[temp_data_file], embedder=embedder, llm=llm
name="test_brain", file_paths=[temp_data_file], embedder=embedder, llm=fake_llm
)
assert brain.name == "test_brain"
assert brain.chat_history == []
assert brain.llm == llm
assert len(brain.chat_history) == 0
assert brain.llm == fake_llm
assert brain.vector_db.embeddings == embedder
assert isinstance(brain.default_chat, ChatHistory)
assert len(brain.default_chat) == 0
# storage
assert isinstance(brain.storage, TransparentStorage)
assert len(brain.storage.get_files()) == 1
@pytest.mark.asyncio
async def test_brain_ask_txt(llm: LLMEndpoint, embedder, temp_data_file, answers):
async def test_brain_ask_txt(fake_llm: LLMEndpoint, embedder, temp_data_file, answers):
brain = await Brain.afrom_files(
name="test_brain", file_paths=[temp_data_file], embedder=embedder, llm=llm
name="test_brain", file_paths=[temp_data_file], embedder=embedder, llm=fake_llm
)
answer = brain.ask("question")
@ -68,7 +72,7 @@ async def test_brain_ask_txt(llm: LLMEndpoint, embedder, temp_data_file, answers
@pytest.mark.asyncio
async def test_brain_from_langchain_docs(llm, embedder):
async def test_brain_from_langchain_docs(embedder):
chunk = Document("content_1", metadata={"id": uuid4()})
brain = await Brain.afrom_langchain_documents(
name="test", langchain_documents=[chunk], embedder=embedder
@ -92,3 +96,17 @@ async def test_brain_search(
assert len(result) == 1
assert result[0].chunk == chunk
assert result[0].score == 0
@pytest.mark.asyncio
async def test_brain_get_history(
fake_llm: LLMEndpoint, embedder, temp_data_file, answers
):
brain = await Brain.afrom_files(
name="test_brain", file_paths=[temp_data_file], embedder=embedder, llm=fake_llm
)
brain.ask("question")
brain.ask("question")
assert len(brain.default_chat) == 4

View File

@ -0,0 +1,77 @@
from time import sleep
from uuid import uuid4
import pytest
from langchain_core.messages import AIMessage, HumanMessage
from quivr_core.chat import ChatHistory
@pytest.fixture
def ai_message():
return AIMessage("ai message")
@pytest.fixture
def human_message():
return HumanMessage("human message")
def test_chat_history_constructor():
brain_id, chat_id = uuid4(), uuid4()
chat_history = ChatHistory(brain_id=brain_id, chat_id=chat_id)
assert chat_history.brain_id == brain_id
assert chat_history.id == chat_id
assert len(chat_history._msgs) == 0
def test_chat_history_append(ai_message: AIMessage, human_message: HumanMessage):
chat_history = ChatHistory(uuid4(), uuid4())
chat_history.append(ai_message)
assert len(chat_history) == 1
chat_history.append(human_message)
assert len(chat_history) == 2
def test_chat_history_get_history(ai_message: AIMessage, human_message: HumanMessage):
chat_history = ChatHistory(uuid4(), uuid4())
chat_history.append(ai_message)
chat_history.append(human_message)
chat_history.append(ai_message)
sleep(0.01)
chat_history.append(human_message)
msgs = chat_history.get_chat_history()
assert len(msgs) == 4
assert msgs[-1].message_time > msgs[0].message_time
assert isinstance(msgs[0].msg, AIMessage)
assert isinstance(msgs[1].msg, HumanMessage)
msgs = chat_history.get_chat_history(newest_first=True)
assert msgs[-1].message_time < msgs[0].message_time
def test_chat_history_iter_pairs_invalid(
ai_message: AIMessage, human_message: HumanMessage
):
with pytest.raises(AssertionError):
chat_history = ChatHistory(uuid4(), uuid4())
chat_history.append(ai_message)
chat_history.append(ai_message)
next(chat_history.iter_pairs())
def test_chat_history_iter_pais(ai_message: AIMessage, human_message: HumanMessage):
chat_history = ChatHistory(uuid4(), uuid4())
chat_history.append(human_message)
chat_history.append(ai_message)
chat_history.append(human_message)
chat_history.append(ai_message)
result = list(chat_history.iter_pairs())
assert result == [(human_message, ai_message), (human_message, ai_message)]