diff --git a/backend/core/quivr_core/brain.py b/backend/core/quivr_core/brain.py
index 526f59626..7db99d5bd 100644
--- a/backend/core/quivr_core/brain.py
+++ b/backend/core/quivr_core/brain.py
@@ -6,8 +6,10 @@ from uuid import UUID, uuid4
 
 from langchain_core.documents import Document
 from langchain_core.embeddings import Embeddings
+from langchain_core.messages import AIMessage, HumanMessage
 from langchain_core.vectorstores import VectorStore
 
+from quivr_core.chat import ChatHistory
 from quivr_core.config import LLMEndpointConfig, RAGConfig
 from quivr_core.llm import LLMEndpoint
 from quivr_core.models import ParsedRAGResponse, SearchResult
@@ -108,13 +110,23 @@ class Brain:
         self.storage = storage
 
         # Chat history
-        self.chat_history: list[str] = []
+        self._chats = self._init_chats()
+        self.default_chat = list(self._chats.values())[0]
 
         # RAG dependencies:
         self.llm = llm
         self.vector_db = vector_db
         self.embedder = embedder
 
+    @property
+    def chat_history(self):
+        return self.default_chat
+
+    def _init_chats(self):
+        chat_id = uuid4()
+        default_chat = ChatHistory(chat_id=chat_id, brain_id=self.id)
+        return {chat_id: default_chat}
+
     @classmethod
     async def afrom_files(
         cls,
@@ -235,6 +247,9 @@
 
         return [SearchResult(chunk=d, score=s) for d, s in result]
 
+    def get_chat_history(self, chat_id: UUID):
+        return self._chats[chat_id]
+
     # TODO(@aminediro)
     def add_file(self) -> None:
         # add it to storage
@@ -242,7 +257,9 @@
         raise NotImplementedError
 
     def ask(
-        self, question: str, rag_config: RAGConfig | None = None
+        self,
+        question: str,
+        rag_config: RAGConfig | None = None,
     ) -> ParsedRAGResponse:
         llm = self.llm
 
@@ -257,9 +274,12 @@
             rag_config=rag_config, llm=llm, vector_store=self.vector_db
         )
 
-        # transformed_history = format_chat_history(history)
+        chat_history = self.default_chat
 
-        parsed_response = rag_pipeline.answer(question, [], [])
+        parsed_response = rag_pipeline.answer(question, chat_history, [])
+
+        chat_history.append(HumanMessage(content=question))
+        chat_history.append(AIMessage(content=parsed_response.answer))
 
         # Save answer to the chat history
         return parsed_response
diff --git a/backend/core/quivr_core/chat.py b/backend/core/quivr_core/chat.py
new file mode 100644
index 000000000..b1576650e
--- /dev/null
+++ b/backend/core/quivr_core/chat.py
@@ -0,0 +1,50 @@
+from datetime import datetime
+from typing import Any, Generator, Tuple
+from uuid import UUID, uuid4
+
+from langchain_core.messages import AIMessage, HumanMessage
+
+from quivr_core.models import ChatMessage
+
+
+class ChatHistory:
+    def __init__(self, chat_id: UUID, brain_id: UUID) -> None:
+        self.id = chat_id
+        self.brain_id = brain_id
+        # TODO(@aminediro): maybe use a deque() instead ?
+        self._msgs: list[ChatMessage] = []
+
+    def get_chat_history(self, newest_first: bool = False):
+        """Returns a ChatMessage list sorted by time
+
+        Returns:
+            list[ChatMessage]: list of chat messages
+        """
+        history = sorted(self._msgs, key=lambda msg: msg.message_time)
+        if newest_first:
+            return history[::-1]
+        return history
+
+    def __len__(self):
+        return len(self._msgs)
+
+    def append(
+        self, langchain_msg: AIMessage | HumanMessage, metadata: dict[str, Any] = {}
+    ):
+        chat_msg = ChatMessage(
+            chat_id=self.id,
+            message_id=uuid4(),
+            brain_id=self.brain_id,
+            msg=langchain_msg,
+            message_time=datetime.now(),
+            metadata=metadata,
+        )
+        self._msgs.append(chat_msg)
+
+    def iter_pairs(self) -> Generator[Tuple[HumanMessage, AIMessage], None, None]:
+        # Reverse the chat_history, newest first
+        it = iter(self.get_chat_history(newest_first=True))
+        for ai_message, human_message in zip(it, it):
+            assert isinstance(human_message.msg, HumanMessage)
+            assert isinstance(ai_message.msg, AIMessage)
+            yield (human_message.msg, ai_message.msg)
diff --git a/backend/core/quivr_core/config.py b/backend/core/quivr_core/config.py
index 1d03829c6..98d154e75 100644
--- a/backend/core/quivr_core/config.py
+++ b/backend/core/quivr_core/config.py
@@ -12,5 +12,6 @@ class LLMEndpointConfig(BaseModel):
 
 class RAGConfig(BaseModel):
     llm_config: LLMEndpointConfig = LLMEndpointConfig()
+    max_history: int = 10
     max_files: int = 20
     prompt: str | None = None
diff --git a/backend/core/quivr_core/models.py b/backend/core/quivr_core/models.py
index fa900d1a7..b3b1cce76 100644
--- a/backend/core/quivr_core/models.py
+++ b/backend/core/quivr_core/models.py
@@ -3,6 +3,7 @@ from typing import Any
 from uuid import UUID
 
 from langchain_core.documents import Document
+from langchain_core.messages import AIMessage, HumanMessage
 from langchain_core.pydantic_v1 import BaseModel as BaseModelV1
 from langchain_core.pydantic_v1 import Field as FieldV1
 from pydantic import BaseModel
@@ -33,17 +34,13 @@
     )
 
 
-class GetChatHistoryOutput(BaseModel):
+class ChatMessage(BaseModelV1):
     chat_id: UUID
     message_id: UUID
-    user_message: str
+    brain_id: UUID
+    msg: AIMessage | HumanMessage
     message_time: datetime
-    assistant: str | None = None
-    prompt_title: str | None = None
-    brain_name: str | None = None
-    brain_id: UUID | None = None  # string because UUID is not JSON serializable
-    metadata: dict | None = None
-    thumbs: bool | None = None
+    metadata: dict[str, Any]
 
 
 class Source(BaseModel):
diff --git a/backend/core/quivr_core/quivr_rag.py b/backend/core/quivr_core/quivr_rag.py
index d8d566094..7539f2b02 100644
--- a/backend/core/quivr_core/quivr_rag.py
+++ b/backend/core/quivr_core/quivr_rag.py
@@ -5,11 +5,13 @@ from typing import AsyncGenerator, Optional, Sequence
 from langchain.retrievers import ContextualCompressionRetriever
 from langchain_core.callbacks import Callbacks
 from langchain_core.documents import BaseDocumentCompressor, Document
+from langchain_core.messages import AIMessage, HumanMessage
 from langchain_core.messages.ai import AIMessageChunk
 from langchain_core.output_parsers import StrOutputParser
 from langchain_core.runnables import RunnableLambda, RunnablePassthrough
 from langchain_core.vectorstores import VectorStore
 
+from quivr_core.chat import ChatHistory
 from quivr_core.config import RAGConfig
 from quivr_core.llm import LLMEndpoint
 from quivr_core.models import (
@@ -59,7 +61,8 @@ class QuivrQARAG:
         return self.vector_store.as_retriever()
 
     def filter_history(
-        self, chat_history, max_history: int = 10, max_tokens: int = 2000
+        self,
+        chat_history: ChatHistory,
     ):
         """
         Filter out the chat history to only include the messages that are relevant to the current question
@@ -68,29 +71,23 @@
         Returns a filtered chat_history with in priority: first max_tokens, then max_history
         where a Human message and an AI message count as one pair
         a token is 4 characters
         """
-        chat_history = chat_history[::-1]
         total_tokens = 0
         total_pairs = 0
-        filtered_chat_history = []
-        for i in range(0, len(chat_history), 2):
-            if i + 1 < len(chat_history):
-                human_message = chat_history[i]
-                ai_message = chat_history[i + 1]
-                message_tokens = (
-                    len(human_message.content) + len(ai_message.content)
-                ) // 4
-                if (
-                    total_tokens + message_tokens > max_tokens
-                    or total_pairs >= max_history
-                ):
-                    break
-                filtered_chat_history.append(human_message)
-                filtered_chat_history.append(ai_message)
-                total_tokens += message_tokens
-                total_pairs += 1
-        chat_history = filtered_chat_history[::-1]
+        filtered_chat_history: list[AIMessage | HumanMessage] = []
+        for human_message, ai_message in chat_history.iter_pairs():
+            # TODO: replace with tiktoken
+            message_tokens = (len(human_message.content) + len(ai_message.content)) // 4
+            if (
+                total_tokens + message_tokens > self.rag_config.llm_config.max_tokens
+                or total_pairs >= self.rag_config.max_history
+            ):
+                break
+            filtered_chat_history.append(human_message)
+            filtered_chat_history.append(ai_message)
+            total_tokens += message_tokens
+            total_pairs += 1
 
-        return chat_history
+        return filtered_chat_history[::-1]
 
     def build_chain(self, files: str):
@@ -146,7 +143,7 @@
     def answer(
         self,
         question: str,
-        history: list[dict[str, str]],
+        history: ChatHistory,
         list_files: list[QuivrKnowledge],
         metadata: dict[str, str] = {},
     ) -> ParsedRAGResponse:
@@ -166,7 +163,7 @@
     async def answer_astream(
         self,
         question: str,
-        history: list[dict[str, str]],
+        history: ChatHistory,
         list_files: list[QuivrKnowledge],
         metadata: dict[str, str] = {},
     ) -> AsyncGenerator[ParsedRAGChunkResponse, ParsedRAGChunkResponse]:
diff --git a/backend/core/quivr_core/utils.py b/backend/core/quivr_core/utils.py
index 333f4b376..383ff83a5 100644
--- a/backend/core/quivr_core/utils.py
+++ b/backend/core/quivr_core/utils.py
@@ -11,7 +11,7 @@ from langchain.schema import (
 from langchain_core.messages.ai import AIMessageChunk
 
 from quivr_core.models import (
-    GetChatHistoryOutput,
+    ChatMessage,
     ParsedRAGChunkResponse,
     ParsedRAGResponse,
     QuivrKnowledge,
@@ -44,7 +44,7 @@ def model_supports_function_calling(model_name: str):
 
 
 def format_chat_history(
-    history: List[GetChatHistoryOutput],
+    history: List[ChatMessage],
 ) -> List[Dict[str, str]]:
     """Format the chat history into a list of HumanMessage and AIMessage"""
     formatted_history = []
diff --git a/backend/core/tests/conftest.py b/backend/core/tests/conftest.py
new file mode 100644
index 000000000..204521665
--- /dev/null
+++ b/backend/core/tests/conftest.py
@@ -0,0 +1,8 @@
+import os
+
+import pytest
+
+
+@pytest.fixture(scope="session", autouse=True)
+def openai_api_key():
+    os.environ["OPENAI_API_KEY"] = "abcd"
diff --git a/backend/core/tests/test_brain.py b/backend/core/tests/test_brain.py
index 67372f381..b221e3281 100644
--- a/backend/core/tests/test_brain.py
+++ b/backend/core/tests/test_brain.py
@@ -6,6 +6,7 @@ from langchain_core.embeddings import DeterministicFakeEmbedding, Embeddings
 from langchain_core.language_models import FakeListChatModel
 
 from quivr_core.brain import Brain
+from quivr_core.chat import ChatHistory
 from quivr_core.config import LLMEndpointConfig
 from quivr_core.llm import LLMEndpoint
 from quivr_core.storage.local_storage import TransparentStorage
@@ -25,7 +26,7 @@ def answers():
 
 
 @pytest.fixture(scope="function")
-def llm(answers: list[str]):
+def fake_llm(answers: list[str]):
     llm = FakeListChatModel(responses=answers)
     return LLMEndpoint(llm=llm, llm_config=LLMEndpointConfig(model="fake_model"))
 
@@ -41,23 +42,26 @@ def test_brain_empty_files():
         Brain.from_files(name="test_brain", file_paths=[])
 
 
-def test_brain_from_files_success(llm: LLMEndpoint, embedder, temp_data_file):
+def test_brain_from_files_success(fake_llm: LLMEndpoint, embedder, temp_data_file):
     brain = Brain.from_files(
-        name="test_brain", file_paths=[temp_data_file], embedder=embedder, llm=llm
+        name="test_brain", file_paths=[temp_data_file], embedder=embedder, llm=fake_llm
    )
     assert brain.name == "test_brain"
-    assert brain.chat_history == []
-    assert brain.llm == llm
+    assert len(brain.chat_history) == 0
+    assert brain.llm == fake_llm
     assert brain.vector_db.embeddings == embedder
+    assert isinstance(brain.default_chat, ChatHistory)
+    assert len(brain.default_chat) == 0
+
     # storage
     assert isinstance(brain.storage, TransparentStorage)
     assert len(brain.storage.get_files()) == 1
 
 
 @pytest.mark.asyncio
-async def test_brain_ask_txt(llm: LLMEndpoint, embedder, temp_data_file, answers):
+async def test_brain_ask_txt(fake_llm: LLMEndpoint, embedder, temp_data_file, answers):
     brain = await Brain.afrom_files(
-        name="test_brain", file_paths=[temp_data_file], embedder=embedder, llm=llm
+        name="test_brain", file_paths=[temp_data_file], embedder=embedder, llm=fake_llm
     )
 
     answer = brain.ask("question")
@@ -68,7 +72,7 @@ async def test_brain_ask_txt(llm: LLMEndpoint, embedder, temp_data_file, answers
 
 @pytest.mark.asyncio
-async def test_brain_from_langchain_docs(llm, embedder):
+async def test_brain_from_langchain_docs(embedder):
     chunk = Document("content_1", metadata={"id": uuid4()})
     brain = await Brain.afrom_langchain_documents(
         name="test", langchain_documents=[chunk], embedder=embedder
     )
@@ -92,3 +96,17 @@ async def test_brain_search(
     assert len(result) == 1
     assert result[0].chunk == chunk
     assert result[0].score == 0
+
+
+@pytest.mark.asyncio
+async def test_brain_get_history(
+    fake_llm: LLMEndpoint, embedder, temp_data_file, answers
+):
+    brain = await Brain.afrom_files(
+        name="test_brain", file_paths=[temp_data_file], embedder=embedder, llm=fake_llm
+    )
+
+    brain.ask("question")
+    brain.ask("question")
+
+    assert len(brain.default_chat) == 4
diff --git a/backend/core/tests/test_chat_history.py b/backend/core/tests/test_chat_history.py
new file mode 100644
index 000000000..8cb89e7c8
--- /dev/null
+++ b/backend/core/tests/test_chat_history.py
@@ -0,0 +1,77 @@
+from time import sleep
+from uuid import uuid4
+
+import pytest
+from langchain_core.messages import AIMessage, HumanMessage
+
+from quivr_core.chat import ChatHistory
+
+
+@pytest.fixture
+def ai_message():
+    return AIMessage("ai message")
+
+
+@pytest.fixture
+def human_message():
+    return HumanMessage("human message")
+
+
+def test_chat_history_constructor():
+    brain_id, chat_id = uuid4(), uuid4()
+    chat_history = ChatHistory(brain_id=brain_id, chat_id=chat_id)
+
+    assert chat_history.brain_id == brain_id
+    assert chat_history.id == chat_id
+    assert len(chat_history._msgs) == 0
+
+
+def test_chat_history_append(ai_message: AIMessage, human_message: HumanMessage):
+    chat_history = ChatHistory(uuid4(), uuid4())
+    chat_history.append(ai_message)
+
+    assert len(chat_history) == 1
+    chat_history.append(human_message)
+    assert len(chat_history) == 2
+
+
+def test_chat_history_get_history(ai_message: AIMessage, human_message: HumanMessage):
+    chat_history = ChatHistory(uuid4(), uuid4())
+    chat_history.append(ai_message)
+    chat_history.append(human_message)
+    chat_history.append(ai_message)
+    sleep(0.01)
+    chat_history.append(human_message)
+
+    msgs = chat_history.get_chat_history()
+
+    assert len(msgs) == 4
+    assert msgs[-1].message_time > msgs[0].message_time
+    assert isinstance(msgs[0].msg, AIMessage)
+    assert isinstance(msgs[1].msg, HumanMessage)
+
+    msgs = chat_history.get_chat_history(newest_first=True)
+    assert msgs[-1].message_time < msgs[0].message_time
+
+
+def test_chat_history_iter_pairs_invalid(
+    ai_message: AIMessage, human_message: HumanMessage
+):
+    with pytest.raises(AssertionError):
+        chat_history = ChatHistory(uuid4(), uuid4())
+        chat_history.append(ai_message)
+        chat_history.append(ai_message)
+        next(chat_history.iter_pairs())
+
+
+def test_chat_history_iter_pairs(ai_message: AIMessage, human_message: HumanMessage):
+    chat_history = ChatHistory(uuid4(), uuid4())
+
+    chat_history.append(human_message)
+    chat_history.append(ai_message)
+    chat_history.append(human_message)
+    chat_history.append(ai_message)
+
+    result = list(chat_history.iter_pairs())
+
+    assert result == [(human_message, ai_message), (human_message, ai_message)]
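
For reviewers, a minimal usage sketch of the `ChatHistory` API added in this diff (the ids and message texts below are made-up illustration values, not part of the change):

```python
from uuid import uuid4

from langchain_core.messages import AIMessage, HumanMessage

from quivr_core.chat import ChatHistory

# A history is keyed by a chat id and the id of the brain it belongs to.
history = ChatHistory(chat_id=uuid4(), brain_id=uuid4())

# append() wraps each langchain message in a ChatMessage carrying a
# message_id and a message_time timestamp.
history.append(HumanMessage(content="What is in the uploaded file?"))
history.append(AIMessage(content="It contains some test data."))
history.append(HumanMessage(content="Summarize it."))
history.append(AIMessage(content="A one-line summary."))

assert len(history) == 4

# get_chat_history() returns ChatMessage objects sorted oldest first;
# newest_first=True reverses the order.
for chat_msg in history.get_chat_history():
    print(type(chat_msg.msg).__name__, chat_msg.msg.content)

# iter_pairs() yields (HumanMessage, AIMessage) tuples, newest pair first.
# QuivrQARAG.filter_history() consumes these pairs when trimming the history
# against llm_config.max_tokens and the new RAGConfig.max_history setting.
for human, ai in history.iter_pairs():
    print("Q:", human.content, "| A:", ai.content)
```

`Brain.ask()` now appends each question/answer pair to `brain.default_chat`, which is why `test_brain_get_history` expects a length of 4 after two `ask()` calls.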