mirror of
https://github.com/StanGirard/quivr.git
synced 2024-11-22 20:09:40 +03:00
feat: quivr core chat history (#2824)
# Description - Defined quivr-core `ChatHistory` - `ChatHistory` can be iterated over in tuples of `HumanMessage,AIMessage` - Brain appends to the chatHistory once response is received - Brain holds a dict of chats and defines the default chat (TODO: define a system of selecting the chats) - Wrote test - Updated `QuivrQARAG` to use `ChatHistory` as input
This commit is contained in:
parent
442186c674
commit
847e161d80
@ -6,8 +6,10 @@ from uuid import UUID, uuid4
|
||||
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.embeddings import Embeddings
|
||||
from langchain_core.messages import AIMessage, HumanMessage
|
||||
from langchain_core.vectorstores import VectorStore
|
||||
|
||||
from quivr_core.chat import ChatHistory
|
||||
from quivr_core.config import LLMEndpointConfig, RAGConfig
|
||||
from quivr_core.llm import LLMEndpoint
|
||||
from quivr_core.models import ParsedRAGResponse, SearchResult
|
||||
@ -108,13 +110,23 @@ class Brain:
|
||||
self.storage = storage
|
||||
|
||||
# Chat history
|
||||
self.chat_history: list[str] = []
|
||||
self._chats = self._init_chats()
|
||||
self.default_chat = list(self._chats.values())[0]
|
||||
|
||||
# RAG dependencies:
|
||||
self.llm = llm
|
||||
self.vector_db = vector_db
|
||||
self.embedder = embedder
|
||||
|
||||
@property
|
||||
def chat_history(self):
|
||||
return self.default_chat
|
||||
|
||||
def _init_chats(self):
|
||||
chat_id = uuid4()
|
||||
default_chat = ChatHistory(chat_id=chat_id, brain_id=self.id)
|
||||
return {chat_id: default_chat}
|
||||
|
||||
@classmethod
|
||||
async def afrom_files(
|
||||
cls,
|
||||
@ -235,6 +247,9 @@ class Brain:
|
||||
|
||||
return [SearchResult(chunk=d, score=s) for d, s in result]
|
||||
|
||||
def get_chat_history(self, chat_id: UUID):
|
||||
return self._chats[chat_id]
|
||||
|
||||
# TODO(@aminediro)
|
||||
def add_file(self) -> None:
|
||||
# add it to storage
|
||||
@ -242,7 +257,9 @@ class Brain:
|
||||
raise NotImplementedError
|
||||
|
||||
def ask(
|
||||
self, question: str, rag_config: RAGConfig | None = None
|
||||
self,
|
||||
question: str,
|
||||
rag_config: RAGConfig | None = None,
|
||||
) -> ParsedRAGResponse:
|
||||
llm = self.llm
|
||||
|
||||
@ -257,9 +274,12 @@ class Brain:
|
||||
rag_config=rag_config, llm=llm, vector_store=self.vector_db
|
||||
)
|
||||
|
||||
# transformed_history = format_chat_history(history)
|
||||
chat_history = self.default_chat
|
||||
|
||||
parsed_response = rag_pipeline.answer(question, [], [])
|
||||
parsed_response = rag_pipeline.answer(question, chat_history, [])
|
||||
|
||||
chat_history.append(HumanMessage(content=question))
|
||||
chat_history.append(AIMessage(content=parsed_response.answer))
|
||||
|
||||
# Save answer to the chat history
|
||||
return parsed_response
|
||||
|
50
backend/core/quivr_core/chat.py
Normal file
50
backend/core/quivr_core/chat.py
Normal file
@ -0,0 +1,50 @@
|
||||
from datetime import datetime
|
||||
from typing import Any, Generator, Tuple
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
from langchain_core.messages import AIMessage, HumanMessage
|
||||
|
||||
from quivr_core.models import ChatMessage
|
||||
|
||||
|
||||
class ChatHistory:
|
||||
def __init__(self, chat_id: UUID, brain_id: UUID) -> None:
|
||||
self.id = chat_id
|
||||
self.brain_id = brain_id
|
||||
# TODO(@aminediro): maybe use a deque() instead ?
|
||||
self._msgs: list[ChatMessage] = []
|
||||
|
||||
def get_chat_history(self, newest_first: bool = False):
|
||||
"""Returns a ChatMessage list sorted by time
|
||||
|
||||
Returns:
|
||||
list[ChatMessage]: list of chat messages
|
||||
"""
|
||||
history = sorted(self._msgs, key=lambda msg: msg.message_time)
|
||||
if newest_first:
|
||||
return history[::-1]
|
||||
return history
|
||||
|
||||
def __len__(self):
|
||||
return len(self._msgs)
|
||||
|
||||
def append(
|
||||
self, langchain_msg: AIMessage | HumanMessage, metadata: dict[str, Any] = {}
|
||||
):
|
||||
chat_msg = ChatMessage(
|
||||
chat_id=self.id,
|
||||
message_id=uuid4(),
|
||||
brain_id=self.brain_id,
|
||||
msg=langchain_msg,
|
||||
message_time=datetime.now(),
|
||||
metadata=metadata,
|
||||
)
|
||||
self._msgs.append(chat_msg)
|
||||
|
||||
def iter_pairs(self) -> Generator[Tuple[HumanMessage, AIMessage], None, None]:
|
||||
# Reverse the chat_history, newest first
|
||||
it = iter(self.get_chat_history(newest_first=True))
|
||||
for ai_message, human_message in zip(it, it):
|
||||
assert isinstance(human_message.msg, HumanMessage)
|
||||
assert isinstance(ai_message.msg, AIMessage)
|
||||
yield (human_message.msg, ai_message.msg)
|
@ -12,5 +12,6 @@ class LLMEndpointConfig(BaseModel):
|
||||
|
||||
class RAGConfig(BaseModel):
|
||||
llm_config: LLMEndpointConfig = LLMEndpointConfig()
|
||||
max_history: int = 10
|
||||
max_files: int = 20
|
||||
prompt: str | None = None
|
||||
|
@ -3,6 +3,7 @@ from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.messages import AIMessage, HumanMessage
|
||||
from langchain_core.pydantic_v1 import BaseModel as BaseModelV1
|
||||
from langchain_core.pydantic_v1 import Field as FieldV1
|
||||
from pydantic import BaseModel
|
||||
@ -33,17 +34,13 @@ class cited_answer(BaseModelV1):
|
||||
)
|
||||
|
||||
|
||||
class GetChatHistoryOutput(BaseModel):
|
||||
class ChatMessage(BaseModelV1):
|
||||
chat_id: UUID
|
||||
message_id: UUID
|
||||
user_message: str
|
||||
brain_id: UUID
|
||||
msg: AIMessage | HumanMessage
|
||||
message_time: datetime
|
||||
assistant: str | None = None
|
||||
prompt_title: str | None = None
|
||||
brain_name: str | None = None
|
||||
brain_id: UUID | None = None # string because UUID is not JSON serializable
|
||||
metadata: dict | None = None
|
||||
thumbs: bool | None = None
|
||||
metadata: dict[str, Any]
|
||||
|
||||
|
||||
class Source(BaseModel):
|
||||
|
@ -5,11 +5,13 @@ from typing import AsyncGenerator, Optional, Sequence
|
||||
from langchain.retrievers import ContextualCompressionRetriever
|
||||
from langchain_core.callbacks import Callbacks
|
||||
from langchain_core.documents import BaseDocumentCompressor, Document
|
||||
from langchain_core.messages import AIMessage, HumanMessage
|
||||
from langchain_core.messages.ai import AIMessageChunk
|
||||
from langchain_core.output_parsers import StrOutputParser
|
||||
from langchain_core.runnables import RunnableLambda, RunnablePassthrough
|
||||
from langchain_core.vectorstores import VectorStore
|
||||
|
||||
from quivr_core.chat import ChatHistory
|
||||
from quivr_core.config import RAGConfig
|
||||
from quivr_core.llm import LLMEndpoint
|
||||
from quivr_core.models import (
|
||||
@ -59,7 +61,8 @@ class QuivrQARAG:
|
||||
return self.vector_store.as_retriever()
|
||||
|
||||
def filter_history(
|
||||
self, chat_history, max_history: int = 10, max_tokens: int = 2000
|
||||
self,
|
||||
chat_history: ChatHistory,
|
||||
):
|
||||
"""
|
||||
Filter out the chat history to only include the messages that are relevant to the current question
|
||||
@ -68,29 +71,23 @@ class QuivrQARAG:
|
||||
Returns a filtered chat_history with in priority: first max_tokens, then max_history where a Human message and an AI message count as one pair
|
||||
a token is 4 characters
|
||||
"""
|
||||
chat_history = chat_history[::-1]
|
||||
total_tokens = 0
|
||||
total_pairs = 0
|
||||
filtered_chat_history = []
|
||||
for i in range(0, len(chat_history), 2):
|
||||
if i + 1 < len(chat_history):
|
||||
human_message = chat_history[i]
|
||||
ai_message = chat_history[i + 1]
|
||||
message_tokens = (
|
||||
len(human_message.content) + len(ai_message.content)
|
||||
) // 4
|
||||
if (
|
||||
total_tokens + message_tokens > max_tokens
|
||||
or total_pairs >= max_history
|
||||
):
|
||||
break
|
||||
filtered_chat_history.append(human_message)
|
||||
filtered_chat_history.append(ai_message)
|
||||
total_tokens += message_tokens
|
||||
total_pairs += 1
|
||||
chat_history = filtered_chat_history[::-1]
|
||||
filtered_chat_history: list[AIMessage | HumanMessage] = []
|
||||
for human_message, ai_message in chat_history.iter_pairs():
|
||||
# TODO: replace with tiktoken
|
||||
message_tokens = (len(human_message.content) + len(ai_message.content)) // 4
|
||||
if (
|
||||
total_tokens + message_tokens > self.rag_config.llm_config.max_tokens
|
||||
or total_pairs >= self.rag_config.max_history
|
||||
):
|
||||
break
|
||||
filtered_chat_history.append(human_message)
|
||||
filtered_chat_history.append(ai_message)
|
||||
total_tokens += message_tokens
|
||||
total_pairs += 1
|
||||
|
||||
return chat_history
|
||||
return filtered_chat_history[::-1]
|
||||
|
||||
def build_chain(self, files: str):
|
||||
compression_retriever = ContextualCompressionRetriever(
|
||||
@ -146,7 +143,7 @@ class QuivrQARAG:
|
||||
def answer(
|
||||
self,
|
||||
question: str,
|
||||
history: list[dict[str, str]],
|
||||
history: ChatHistory,
|
||||
list_files: list[QuivrKnowledge],
|
||||
metadata: dict[str, str] = {},
|
||||
) -> ParsedRAGResponse:
|
||||
@ -166,7 +163,7 @@ class QuivrQARAG:
|
||||
async def answer_astream(
|
||||
self,
|
||||
question: str,
|
||||
history: list[dict[str, str]],
|
||||
history: ChatHistory,
|
||||
list_files: list[QuivrKnowledge],
|
||||
metadata: dict[str, str] = {},
|
||||
) -> AsyncGenerator[ParsedRAGChunkResponse, ParsedRAGChunkResponse]:
|
||||
|
@ -11,7 +11,7 @@ from langchain.schema import (
|
||||
from langchain_core.messages.ai import AIMessageChunk
|
||||
|
||||
from quivr_core.models import (
|
||||
GetChatHistoryOutput,
|
||||
ChatMessage,
|
||||
ParsedRAGChunkResponse,
|
||||
ParsedRAGResponse,
|
||||
QuivrKnowledge,
|
||||
@ -44,7 +44,7 @@ def model_supports_function_calling(model_name: str):
|
||||
|
||||
|
||||
def format_chat_history(
|
||||
history: List[GetChatHistoryOutput],
|
||||
history: List[ChatMessage],
|
||||
) -> List[Dict[str, str]]:
|
||||
"""Format the chat history into a list of HumanMessage and AIMessage"""
|
||||
formatted_history = []
|
||||
|
8
backend/core/tests/conftest.py
Normal file
8
backend/core/tests/conftest.py
Normal file
@ -0,0 +1,8 @@
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.fixture(scope="session", autouse=True)
|
||||
def openai_api_key():
|
||||
os.environ["OPENAI_API_KEY"] = "abcd"
|
@ -6,6 +6,7 @@ from langchain_core.embeddings import DeterministicFakeEmbedding, Embeddings
|
||||
from langchain_core.language_models import FakeListChatModel
|
||||
|
||||
from quivr_core.brain import Brain
|
||||
from quivr_core.chat import ChatHistory
|
||||
from quivr_core.config import LLMEndpointConfig
|
||||
from quivr_core.llm import LLMEndpoint
|
||||
from quivr_core.storage.local_storage import TransparentStorage
|
||||
@ -25,7 +26,7 @@ def answers():
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def llm(answers: list[str]):
|
||||
def fake_llm(answers: list[str]):
|
||||
llm = FakeListChatModel(responses=answers)
|
||||
return LLMEndpoint(llm=llm, llm_config=LLMEndpointConfig(model="fake_model"))
|
||||
|
||||
@ -41,23 +42,26 @@ def test_brain_empty_files():
|
||||
Brain.from_files(name="test_brain", file_paths=[])
|
||||
|
||||
|
||||
def test_brain_from_files_success(llm: LLMEndpoint, embedder, temp_data_file):
|
||||
def test_brain_from_files_success(fake_llm: LLMEndpoint, embedder, temp_data_file):
|
||||
brain = Brain.from_files(
|
||||
name="test_brain", file_paths=[temp_data_file], embedder=embedder, llm=llm
|
||||
name="test_brain", file_paths=[temp_data_file], embedder=embedder, llm=fake_llm
|
||||
)
|
||||
assert brain.name == "test_brain"
|
||||
assert brain.chat_history == []
|
||||
assert brain.llm == llm
|
||||
assert len(brain.chat_history) == 0
|
||||
assert brain.llm == fake_llm
|
||||
assert brain.vector_db.embeddings == embedder
|
||||
assert isinstance(brain.default_chat, ChatHistory)
|
||||
assert len(brain.default_chat) == 0
|
||||
|
||||
# storage
|
||||
assert isinstance(brain.storage, TransparentStorage)
|
||||
assert len(brain.storage.get_files()) == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_brain_ask_txt(llm: LLMEndpoint, embedder, temp_data_file, answers):
|
||||
async def test_brain_ask_txt(fake_llm: LLMEndpoint, embedder, temp_data_file, answers):
|
||||
brain = await Brain.afrom_files(
|
||||
name="test_brain", file_paths=[temp_data_file], embedder=embedder, llm=llm
|
||||
name="test_brain", file_paths=[temp_data_file], embedder=embedder, llm=fake_llm
|
||||
)
|
||||
answer = brain.ask("question")
|
||||
|
||||
@ -68,7 +72,7 @@ async def test_brain_ask_txt(llm: LLMEndpoint, embedder, temp_data_file, answers
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_brain_from_langchain_docs(llm, embedder):
|
||||
async def test_brain_from_langchain_docs(embedder):
|
||||
chunk = Document("content_1", metadata={"id": uuid4()})
|
||||
brain = await Brain.afrom_langchain_documents(
|
||||
name="test", langchain_documents=[chunk], embedder=embedder
|
||||
@ -92,3 +96,17 @@ async def test_brain_search(
|
||||
assert len(result) == 1
|
||||
assert result[0].chunk == chunk
|
||||
assert result[0].score == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_brain_get_history(
|
||||
fake_llm: LLMEndpoint, embedder, temp_data_file, answers
|
||||
):
|
||||
brain = await Brain.afrom_files(
|
||||
name="test_brain", file_paths=[temp_data_file], embedder=embedder, llm=fake_llm
|
||||
)
|
||||
|
||||
brain.ask("question")
|
||||
brain.ask("question")
|
||||
|
||||
assert len(brain.default_chat) == 4
|
||||
|
77
backend/core/tests/test_chat_history.py
Normal file
77
backend/core/tests/test_chat_history.py
Normal file
@ -0,0 +1,77 @@
|
||||
from time import sleep
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from langchain_core.messages import AIMessage, HumanMessage
|
||||
|
||||
from quivr_core.chat import ChatHistory
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def ai_message():
|
||||
return AIMessage("ai message")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def human_message():
|
||||
return HumanMessage("human message")
|
||||
|
||||
|
||||
def test_chat_history_constructor():
|
||||
brain_id, chat_id = uuid4(), uuid4()
|
||||
chat_history = ChatHistory(brain_id=brain_id, chat_id=chat_id)
|
||||
|
||||
assert chat_history.brain_id == brain_id
|
||||
assert chat_history.id == chat_id
|
||||
assert len(chat_history._msgs) == 0
|
||||
|
||||
|
||||
def test_chat_history_append(ai_message: AIMessage, human_message: HumanMessage):
|
||||
chat_history = ChatHistory(uuid4(), uuid4())
|
||||
chat_history.append(ai_message)
|
||||
|
||||
assert len(chat_history) == 1
|
||||
chat_history.append(human_message)
|
||||
assert len(chat_history) == 2
|
||||
|
||||
|
||||
def test_chat_history_get_history(ai_message: AIMessage, human_message: HumanMessage):
|
||||
chat_history = ChatHistory(uuid4(), uuid4())
|
||||
chat_history.append(ai_message)
|
||||
chat_history.append(human_message)
|
||||
chat_history.append(ai_message)
|
||||
sleep(0.01)
|
||||
chat_history.append(human_message)
|
||||
|
||||
msgs = chat_history.get_chat_history()
|
||||
|
||||
assert len(msgs) == 4
|
||||
assert msgs[-1].message_time > msgs[0].message_time
|
||||
assert isinstance(msgs[0].msg, AIMessage)
|
||||
assert isinstance(msgs[1].msg, HumanMessage)
|
||||
|
||||
msgs = chat_history.get_chat_history(newest_first=True)
|
||||
assert msgs[-1].message_time < msgs[0].message_time
|
||||
|
||||
|
||||
def test_chat_history_iter_pairs_invalid(
|
||||
ai_message: AIMessage, human_message: HumanMessage
|
||||
):
|
||||
with pytest.raises(AssertionError):
|
||||
chat_history = ChatHistory(uuid4(), uuid4())
|
||||
chat_history.append(ai_message)
|
||||
chat_history.append(ai_message)
|
||||
next(chat_history.iter_pairs())
|
||||
|
||||
|
||||
def test_chat_history_iter_pais(ai_message: AIMessage, human_message: HumanMessage):
|
||||
chat_history = ChatHistory(uuid4(), uuid4())
|
||||
|
||||
chat_history.append(human_message)
|
||||
chat_history.append(ai_message)
|
||||
chat_history.append(human_message)
|
||||
chat_history.append(ai_message)
|
||||
|
||||
result = list(chat_history.iter_pairs())
|
||||
|
||||
assert result == [(human_message, ai_message), (human_message, ai_message)]
|
Loading…
Reference in New Issue
Block a user