From fccd197511d8594db257bfddf757bf0d28f7239d Mon Sep 17 00:00:00 2001 From: Stan Girard Date: Tue, 6 Aug 2024 14:51:27 +0200 Subject: [PATCH] feat: add chat with models (#2933) # Description Adds the ability to chat directly with a model, without going through a brain. When no brain_id is supplied, the chat routes now build a ChatLLMService (backed by the new quivr_core.ChatLLM class) instead of a RAGService. The PR also introduces a models module in the API (Model entity, repository, service, and a GET /models route), makes brain_id optional in chat history and core models, adds gpt-4o-mini to the function-calling models, ships a chat_llm example and tests, and switches the default Python formatter to Ruff. ## Checklist before requesting a review - [ ] My code follows the style guidelines of this project - [ ] I have performed a self-review of my code - [ ] I have commented hard-to-understand areas - [ ] I have ideally added tests that prove my fix is effective or that my feature works - [ ] New and existing unit tests pass locally with my changes - [ ] Any dependent changes have been merged --------- Co-authored-by: AmineDiro --- .vscode/settings.json | 2 +- Makefile | 5 +- backend/api/quivr_api/main.py | 3 +- .../modules/chat/controller/chat_routes.py | 59 +++-- .../modules/chat/service/chat_service.py | 12 +- .../modules/chat_llm_service/__init__.py | 3 + .../chat_llm_service/chat_llm_service.py | 214 ++++++++++++++++++ .../modules/models/controller/__init__.py | 3 + .../modules/models/controller/model_routes.py | 30 +++ .../quivr_api/modules/models/entity/model.py | 13 ++ .../modules/models/repository/model.py | 19 ++ .../models/repository/model_interface.py | 12 + .../modules/models/service/model_service.py | 21 ++ .../modules/models/tests/test_models.py | 70 ++++++ backend/core/examples/chat_llm.py | 12 + backend/core/quivr_core/__init__.py | 3 +- backend/core/quivr_core/chat.py | 3 +- backend/core/quivr_core/chat_llm.py | 146 ++++++++++++ backend/core/quivr_core/llm/llm_endpoint.py | 1 - backend/core/quivr_core/llm/models.py | 0 backend/core/quivr_core/models.py | 3 +- backend/core/quivr_core/utils.py | 5 +- backend/core/tests/test_chat_llm.py | 17 ++ 23 files changed, 624 insertions(+), 32 deletions(-) create mode 100644 backend/api/quivr_api/modules/chat_llm_service/__init__.py create mode 100644 backend/api/quivr_api/modules/chat_llm_service/chat_llm_service.py create mode 100644 backend/api/quivr_api/modules/models/controller/__init__.py create mode 100644 backend/api/quivr_api/modules/models/controller/model_routes.py create mode 100644 backend/api/quivr_api/modules/models/entity/model.py create mode 100644 backend/api/quivr_api/modules/models/repository/model.py create mode 100644 backend/api/quivr_api/modules/models/repository/model_interface.py create mode 100644 backend/api/quivr_api/modules/models/service/model_service.py create mode 100644 backend/api/quivr_api/modules/models/tests/test_models.py create mode 100644 backend/core/examples/chat_llm.py create mode 100644 backend/core/quivr_core/chat_llm.py delete mode 100644 backend/core/quivr_core/llm/models.py create mode 100644 backend/core/tests/test_chat_llm.py diff --git a/.vscode/settings.json b/.vscode/settings.json index 7563f37b2..700a8799b 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -18,7 +18,7 @@ }, "json.sortOnSave.enable": true, "[python]": { - "editor.defaultFormatter": "ms-python.black-formatter", + "editor.defaultFormatter": "charliermarsh.ruff", "editor.formatOnSave": true, "editor.codeActionsOnSave": { "source.organizeImports": "explicit", diff --git a/Makefile b/Makefile index 6cc077c27..432a5081c 100644 --- a/Makefile +++ b/Makefile @@ -32,4 +32,7 @@ test-type: fi front: - cd frontend && yarn build && yarn start \ No newline at end of file + cd frontend && yarn build && yarn start + +test: + cd backend/core 
&& ./scripts/run_tests.sh \ No newline at end of file diff --git a/backend/api/quivr_api/main.py b/backend/api/quivr_api/main.py index 821054b0c..d869381dc 100644 --- a/backend/api/quivr_api/main.py +++ b/backend/api/quivr_api/main.py @@ -16,6 +16,7 @@ from quivr_api.modules.brain.controller import brain_router from quivr_api.modules.chat.controller import chat_router from quivr_api.modules.knowledge.controller import knowledge_router from quivr_api.modules.misc.controller import misc_router +from quivr_api.modules.models.controller.model_routes import model_router from quivr_api.modules.onboarding.controller import onboarding_router from quivr_api.modules.prompt.controller import prompt_router from quivr_api.modules.sync.controller import sync_router @@ -78,13 +79,13 @@ app.include_router(sync_router) app.include_router(onboarding_router) app.include_router(misc_router) app.include_router(analytics_router) - app.include_router(upload_router) app.include_router(user_router) app.include_router(api_key_router) app.include_router(subscription_router) app.include_router(prompt_router) app.include_router(knowledge_router) +app.include_router(model_router) PROFILING = os.getenv("PROFILING", "false").lower() == "true" diff --git a/backend/api/quivr_api/modules/chat/controller/chat_routes.py b/backend/api/quivr_api/modules/chat/controller/chat_routes.py index fafe94f59..c6dd6dcda 100644 --- a/backend/api/quivr_api/modules/chat/controller/chat_routes.py +++ b/backend/api/quivr_api/modules/chat/controller/chat_routes.py @@ -19,6 +19,7 @@ from quivr_api.modules.chat.dto.inputs import ( ) from quivr_api.modules.chat.entity.chat import Chat from quivr_api.modules.chat.service.chat_service import ChatService +from quivr_api.modules.chat_llm_service.chat_llm_service import ChatLLMService from quivr_api.modules.dependencies import get_service from quivr_api.modules.knowledge.repository.knowledges import KnowledgeRepository from quivr_api.modules.prompt.service.prompt_service import PromptService @@ -166,16 +167,25 @@ async def create_question_handler( # TODO: check logic into middleware validate_authorization(user_id=current_user.id, brain_id=brain_id) try: - rag_service = RAGService( - current_user, - brain_id, - chat_id, - brain_service, - prompt_service, - chat_service, - knowledge_service, - ) - chat_answer = await rag_service.generate_answer(chat_question.question) + service = None + if brain_id: + service = RAGService( + current_user, + brain_id, + chat_id, + brain_service, + prompt_service, + chat_service, + knowledge_service, + ) + else: + service = ChatLLMService( + current_user, + chat_question.model, + chat_id, + chat_service, + ) + chat_answer = await service.generate_answer(chat_question.question) maybe_send_telemetry("question_asked", {"streaming": False}, request) return chat_answer @@ -214,19 +224,28 @@ async def create_stream_question_handler( ) try: - rag_service = RAGService( - current_user, - brain_id, - chat_id, - brain_service, - prompt_service, - chat_service, - knowledge_service, - ) + service = None + if brain_id: + service = RAGService( + current_user, + brain_id, + chat_id, + brain_service, + prompt_service, + chat_service, + knowledge_service, + ) + else: + service = ChatLLMService( + current_user, + chat_question.model, + chat_id, + chat_service, + ) maybe_send_telemetry("question_asked", {"streaming": True}, request) return StreamingResponse( - rag_service.generate_answer_stream(chat_question.question), + service.generate_answer_stream(chat_question.question), 
media_type="text/event-stream", ) diff --git a/backend/api/quivr_api/modules/chat/service/chat_service.py b/backend/api/quivr_api/modules/chat/service/chat_service.py index 9aa763093..f7de7cec4 100644 --- a/backend/api/quivr_api/modules/chat/service/chat_service.py +++ b/backend/api/quivr_api/modules/chat/service/chat_service.py @@ -53,7 +53,7 @@ class ChatService(BaseService[ChatRepository]): return inserted_chat def get_follow_up_question( - self, brain_id: UUID | None = None, question: str = None + self, brain_id: UUID | None = None, question: str | None = None ) -> [str]: follow_up = [ "Summarize the conversation", @@ -87,9 +87,15 @@ class ChatService(BaseService[ChatRepository]): enriched_history: List[GetChatHistoryOutput] = [] if len(history) == 0: return enriched_history - brain: Brain = await history[0].awaitable_attrs.brain - prompt: Prompt = await brain.awaitable_attrs.prompt for message in history: + brain: Brain | None = ( + await message.awaitable_attrs.brain if message.brain_id else None + ) + prompt: Prompt | None = None + if brain: + prompt = ( + await brain.awaitable_attrs.prompt if message.prompt_id else None + ) enriched_history.append( # TODO : WHY bother with having ids here ?? GetChatHistoryOutput( diff --git a/backend/api/quivr_api/modules/chat_llm_service/__init__.py b/backend/api/quivr_api/modules/chat_llm_service/__init__.py new file mode 100644 index 000000000..d3f79a025 --- /dev/null +++ b/backend/api/quivr_api/modules/chat_llm_service/__init__.py @@ -0,0 +1,3 @@ +from .chat_llm_service import ChatLLMService + +__all__ = ["ChatLLMService"] diff --git a/backend/api/quivr_api/modules/chat_llm_service/chat_llm_service.py b/backend/api/quivr_api/modules/chat_llm_service/chat_llm_service.py new file mode 100644 index 000000000..c7c17861f --- /dev/null +++ b/backend/api/quivr_api/modules/chat_llm_service/chat_llm_service.py @@ -0,0 +1,214 @@ +import datetime +from uuid import UUID, uuid4 + +from quivr_core.chat import ChatHistory as ChatHistoryCore +from quivr_core.chat_llm import ChatLLM +from quivr_core.config import LLMEndpointConfig +from quivr_core.llm.llm_endpoint import LLMEndpoint +from quivr_core.models import ParsedRAGResponse, RAGResponseMetadata + +from quivr_api.logger import get_logger +from quivr_api.models.settings import settings +from quivr_api.modules.brain.service.utils.format_chat_history import ( + format_chat_history, +) +from quivr_api.modules.chat.controller.chat.utils import ( + compute_cost, + find_model_and_generate_metadata, + update_user_usage, +) +from quivr_api.modules.chat.dto.inputs import CreateChatHistory +from quivr_api.modules.chat.dto.outputs import GetChatHistoryOutput +from quivr_api.modules.chat.service.chat_service import ChatService +from quivr_api.modules.user.entity.user_identity import UserIdentity +from quivr_api.modules.user.service.user_usage import UserUsage + +logger = get_logger(__name__) + + +class ChatLLMService: + def __init__( + self, + current_user: UserIdentity, + model_name: str, + chat_id: UUID, + chat_service: ChatService, + ): + # Services + self.chat_service = chat_service + + # Base models + self.current_user = current_user + self.chat_id = chat_id + + # check at init time + self.model_to_use = self.check_and_update_user_usage( + self.current_user, model_name + ) + + def _build_chat_history( + self, + history: list[GetChatHistoryOutput], + ) -> ChatHistoryCore: + transformed_history = format_chat_history(history) + chat_history = ChatHistoryCore(brain_id=None, chat_id=self.chat_id) + + 
for message in transformed_history: chat_history.append(message) + return chat_history + + def build_llm(self) -> ChatLLM: + ollama_url = ( + settings.ollama_api_base_url + if settings.ollama_api_base_url + and self.model_to_use.name.startswith("ollama") + else None + ) + + chat_llm = ChatLLM( + llm=LLMEndpoint.from_config( + LLMEndpointConfig( + model=self.model_to_use.name, + llm_base_url=ollama_url, + llm_api_key="abc-123" if ollama_url else None, + temperature=(LLMEndpointConfig.model_fields["temperature"].default), + max_input=self.model_to_use.max_input, + max_tokens=self.model_to_use.max_output, + ), + ) + ) + return chat_llm + + def check_and_update_user_usage(self, user: UserIdentity, model_name: str): + """Check user limits and raise if the user has reached them: + 1. Raise if one of these conditions holds: + - User doesn't have access to brains + - Requested model is not in user_settings.models + - Latest sum_30d(user_daily_user) < user_settings.max_monthly_usage + - sum(user_settings.daily_user_count) + model_price < user_settings.monthly_chat_credits + 2. Update user usage + """ + # TODO(@aminediro) : THIS is bug prone, should retrieve it from DB here + user_usage = UserUsage(id=user.id, email=user.email) + user_settings = user_usage.get_user_settings() + all_models = user_usage.get_models() + + # TODO(@aminediro): refactor this function + model_to_use = find_model_and_generate_metadata( + model_name, + user_settings, + all_models, + ) + cost = compute_cost(model_to_use, all_models) + # Raises HTTP if user usage exceeds limits + update_user_usage(user_usage, user_settings, cost) # noqa: F821 + return model_to_use + + def save_answer(self, question: str, answer: ParsedRAGResponse): + logger.info( + f"Saving answer for chat {self.chat_id} with model {self.model_to_use.name}" + ) + logger.info(answer) + return self.chat_service.update_chat_history( + CreateChatHistory( + **{ + "chat_id": self.chat_id, + "user_message": question, + "assistant": answer.answer, + "brain_id": None, + "prompt_id": None, + "metadata": answer.metadata.model_dump() if answer.metadata else {}, + } + ) + ) + + async def generate_answer( + self, + question: str, + ): + logger.info( + f"Creating question for chat {self.chat_id} with model {self.model_to_use.name}" + ) + chat_llm = self.build_llm() + history = await self.chat_service.get_chat_history(self.chat_id) + + # Format the history, sanitize the input + chat_history = self._build_chat_history(history) + + parsed_response = chat_llm.answer(question, chat_history) + + # Save the answer to db + new_chat_entry = self.save_answer(question, parsed_response) + + # Format output to be correct + return GetChatHistoryOutput( + **{ + "chat_id": self.chat_id, + "user_message": question, + "assistant": parsed_response.answer, + "message_time": new_chat_entry.message_time, + "prompt_title": None, + "brain_name": None, + "message_id": new_chat_entry.message_id, + "brain_id": None, + "metadata": ( + parsed_response.metadata.model_dump() + if parsed_response.metadata + else {} + ), + } + ) + + async def generate_answer_stream( + self, + question: str, + ): + logger.info( + f"Creating question for chat {self.chat_id} with model {self.model_to_use.name}" + ) + # Build the LLM + chat_llm = self.build_llm() + # Get chat history + history = await self.chat_service.get_chat_history(self.chat_id) + # Format the history, sanitize the input + chat_history = self._build_chat_history(history) + + full_answer = "" + + message_metadata = { + "chat_id": self.chat_id, + "message_id": 
uuid4(), # do we need it? + "user_message": question, # TODO: define result + "message_time": datetime.datetime.now(), # TODO: define result + "prompt_title": None, + "brain_name": None, + "brain_id": None, + } + + async for response in chat_llm.answer_astream(question, chat_history): + # Format output to be correct + if not response.last_chunk: + streamed_chat_history = GetChatHistoryOutput( + assistant=response.answer, + metadata=response.metadata.model_dump(), + **message_metadata, + ) + full_answer += response.answer + yield f"data: {streamed_chat_history.model_dump_json()}" + if response.last_chunk and full_answer == "": + full_answer += response.answer + + # For the last chunk, parse the sources and the full answer + streamed_chat_history = GetChatHistoryOutput( + assistant=full_answer, + metadata=response.metadata.model_dump(), + **message_metadata, + ) + logger.info("Last chunk before saving") + self.save_answer( + question, + ParsedRAGResponse( + answer=full_answer, + metadata=RAGResponseMetadata(**streamed_chat_history.metadata), + ), + ) + yield f"data: {streamed_chat_history.model_dump_json()}" diff --git a/backend/api/quivr_api/modules/models/controller/__init__.py b/backend/api/quivr_api/modules/models/controller/__init__.py new file mode 100644 index 000000000..7e4b0c523 --- /dev/null +++ b/backend/api/quivr_api/modules/models/controller/__init__.py @@ -0,0 +1,3 @@ +from .model_routes import model_router + +__all__ = ["model_router"] diff --git a/backend/api/quivr_api/modules/models/controller/model_routes.py b/backend/api/quivr_api/modules/models/controller/model_routes.py new file mode 100644 index 000000000..a5370c90f --- /dev/null +++ b/backend/api/quivr_api/modules/models/controller/model_routes.py @@ -0,0 +1,30 @@ +from typing import Annotated, List + +from fastapi import APIRouter, Depends +from quivr_api.logger import get_logger +from quivr_api.middlewares.auth import AuthBearer, get_current_user +from quivr_api.modules.dependencies import get_service +from quivr_api.modules.models.entity.model import Model +from quivr_api.modules.models.service.model_service import ModelService +from quivr_api.modules.user.entity.user_identity import UserIdentity + +logger = get_logger(__name__) +model_router = APIRouter() + +ModelServiceDep = Annotated[ModelService, Depends(get_service(ModelService))] +UserIdentityDep = Annotated[UserIdentity, Depends(get_current_user)] + + +# get all models +@model_router.get( + "/models", + response_model=List[Model], + dependencies=[Depends(AuthBearer())], + tags=["Models"], +) +async def get_models(current_user: UserIdentityDep, model_service: ModelServiceDep): + """ + Retrieve all models for the current user. 
+ """ + models = await model_service.get_models() + return models diff --git a/backend/api/quivr_api/modules/models/entity/model.py b/backend/api/quivr_api/modules/models/entity/model.py new file mode 100644 index 000000000..58115086c --- /dev/null +++ b/backend/api/quivr_api/modules/models/entity/model.py @@ -0,0 +1,13 @@ +from sqlmodel import Field, SQLModel + + +class Model(SQLModel, table=True): + __tablename__ = "models" + + name: str = Field(primary_key=True) + price: int = Field(default=1) + max_input: int = Field(default=2000) + max_output: int = Field(default=1000) + + class Config: + arbitrary_types_allowed = True diff --git a/backend/api/quivr_api/modules/models/repository/model.py b/backend/api/quivr_api/modules/models/repository/model.py new file mode 100644 index 000000000..47581fdb9 --- /dev/null +++ b/backend/api/quivr_api/modules/models/repository/model.py @@ -0,0 +1,19 @@ +from typing import Sequence + +from quivr_api.models.settings import get_supabase_client +from quivr_api.modules.dependencies import BaseRepository +from quivr_api.modules.models.entity.model import Model +from sqlmodel import select +from sqlmodel.ext.asyncio.session import AsyncSession + + +class ModelRepository(BaseRepository): + def __init__(self, session: AsyncSession): + super().__init__(session) + # TODO: for now use it instead of session + self.db = get_supabase_client() + + async def get_models(self) -> Sequence[Model]: + query = select(Model) + response = await self.session.exec(query) + return response.all() diff --git a/backend/api/quivr_api/modules/models/repository/model_interface.py b/backend/api/quivr_api/modules/models/repository/model_interface.py new file mode 100644 index 000000000..fcaf9d127 --- /dev/null +++ b/backend/api/quivr_api/modules/models/repository/model_interface.py @@ -0,0 +1,12 @@ +from abc import ABC, abstractmethod + +from quivr_api.modules.models.entity.model import Model + + +class ModelsInterface(ABC): + @abstractmethod + def get_models(self) -> list[Model]: + """ + Get all models + """ + pass diff --git a/backend/api/quivr_api/modules/models/service/model_service.py b/backend/api/quivr_api/modules/models/service/model_service.py new file mode 100644 index 000000000..b999ef7c3 --- /dev/null +++ b/backend/api/quivr_api/modules/models/service/model_service.py @@ -0,0 +1,21 @@ +from quivr_api.logger import get_logger +from quivr_api.modules.dependencies import BaseService +from quivr_api.modules.models.entity.model import Model +from quivr_api.modules.models.repository.model import ModelRepository + +logger = get_logger(__name__) + + +class ModelService(BaseService[ModelRepository]): + repository_cls = ModelRepository + + def __init__(self, repository: ModelRepository): + self.repository = repository + + async def get_models(self) -> list[Model]: + logger.info("Getting models") + + models = await self.repository.get_models() + logger.info(f"Insert response {models}") + + return models diff --git a/backend/api/quivr_api/modules/models/tests/test_models.py b/backend/api/quivr_api/modules/models/tests/test_models.py new file mode 100644 index 000000000..7ec8467ca --- /dev/null +++ b/backend/api/quivr_api/modules/models/tests/test_models.py @@ -0,0 +1,70 @@ +import pytest +import pytest_asyncio +from quivr_api.modules.models.entity.model import Model + + +@pytest_asyncio.fixture() +async def sample_models(): + return [ + Model(name="gpt-3.5-turbo", price=1, max_input=4000, max_output=2000), + Model(name="gpt-4", price=5, max_input=8000, max_output=4000), + ] + + 
+@pytest.mark.asyncio +async def test_model_creation(): + model = Model(name="test-model", price=2, max_input=1000, max_output=500) + assert model.name == "test-model" + assert model.price == 2 + assert model.max_input == 1000 + assert model.max_output == 500 + + +@pytest.mark.asyncio +async def test_model_attributes(sample_models): + model = sample_models[0] + assert hasattr(model, "name") + assert hasattr(model, "price") + assert hasattr(model, "max_input") + assert hasattr(model, "max_output") + + +@pytest.mark.asyncio +async def test_model_validation(): + # Test valid model creation + valid_model = Model(name="valid-model", price=3, max_input=5000, max_output=2500) + assert valid_model.name == "valid-model" + assert valid_model.price == 3 + assert valid_model.max_input == 5000 + assert valid_model.max_output == 2500 + + +@pytest.mark.asyncio +async def test_model_default_values(): + default_model = Model(name="default-model") + assert default_model.name == "default-model" + assert default_model.price == 1 + assert default_model.max_input == 2000 + assert default_model.max_output == 1000 + + +@pytest.mark.asyncio +async def test_model_comparison(): + model1 = Model(name="model1", price=2, max_input=3000, max_output=1500) + model2 = Model(name="model2", price=3, max_input=4000, max_output=2000) + model3 = Model(name="model1", price=2, max_input=3000, max_output=1500) + + assert model1 != model2 + assert model1 == model3 + + +@pytest.mark.asyncio +async def test_model_dict_representation(): + model = Model(name="test-model", price=2, max_input=3000, max_output=1500) + expected_dict = { + "name": "test-model", + "price": 2, + "max_input": 3000, + "max_output": 1500, + } + assert model.dict() == expected_dict diff --git a/backend/core/examples/chat_llm.py b/backend/core/examples/chat_llm.py new file mode 100644 index 000000000..969f4b466 --- /dev/null +++ b/backend/core/examples/chat_llm.py @@ -0,0 +1,12 @@ +from quivr_core import ChatLLM +from quivr_core.config import LLMEndpointConfig +from quivr_core.llm import LLMEndpoint + +if __name__ == "__main__": + llm_endpoint = LLMEndpoint.from_config(LLMEndpointConfig(model="gpt-4o-mini")) + chat_llm = ChatLLM( + llm=llm_endpoint, + ) + print(chat_llm.llm_endpoint.info()) + response = chat_llm.answer("Hello,what is your model?") + print(response) diff --git a/backend/core/quivr_core/__init__.py b/backend/core/quivr_core/__init__.py index 1fcda2808..5ef621c01 100644 --- a/backend/core/quivr_core/__init__.py +++ b/backend/core/quivr_core/__init__.py @@ -1,9 +1,10 @@ from importlib.metadata import entry_points from .brain import Brain +from .chat_llm import ChatLLM from .processor.registry import register_processor, registry -__all__ = ["Brain", "registry", "register_processor"] +__all__ = ["Brain", "ChatLLM", "registry", "register_processor"] def register_entries(): diff --git a/backend/core/quivr_core/chat.py b/backend/core/quivr_core/chat.py index 59d0b4ac2..7247dc812 100644 --- a/backend/core/quivr_core/chat.py +++ b/backend/core/quivr_core/chat.py @@ -3,12 +3,11 @@ from typing import Any, Generator, Tuple from uuid import UUID, uuid4 from langchain_core.messages import AIMessage, HumanMessage - from quivr_core.models import ChatMessage class ChatHistory: - def __init__(self, chat_id: UUID, brain_id: UUID) -> None: + def __init__(self, chat_id: UUID, brain_id: UUID | None) -> None: self.id = chat_id self.brain_id = brain_id # TODO(@aminediro): maybe use a deque() instead ? 
diff --git a/backend/core/quivr_core/chat_llm.py b/backend/core/quivr_core/chat_llm.py new file mode 100644 index 000000000..a4e8ae534 --- /dev/null +++ b/backend/core/quivr_core/chat_llm.py @@ -0,0 +1,146 @@ +import logging +from operator import itemgetter +from typing import AsyncGenerator + +from langchain_core.messages import AIMessage, HumanMessage +from langchain_core.messages.ai import AIMessageChunk +from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder +from langchain_core.runnables import RunnableLambda, RunnablePassthrough + +from quivr_core.chat import ChatHistory +from quivr_core.llm import LLMEndpoint +from quivr_core.models import ( + ParsedRAGChunkResponse, + ParsedRAGResponse, + RAGResponseMetadata, +) +from quivr_core.utils import get_chunk_metadata, parse_chunk_response, parse_response + +logger = logging.getLogger("quivr_core") + + +class ChatLLM: + def __init__(self, *, llm: LLMEndpoint): + self.llm_endpoint = llm + + def filter_history( + self, + chat_history: ChatHistory | None, + ): + """ + Filter out the chat history to only include the messages that are relevant to the current question + + Returns a filtered chat_history with in priority: first max_tokens, then max_history where a Human message and an AI message count as one pair + a token is 4 characters + """ + total_tokens = 0 + total_pairs = 0 + filtered_chat_history: list[AIMessage | HumanMessage] = [] + if chat_history is None: + return filtered_chat_history + for human_message, ai_message in chat_history.iter_pairs(): + # TODO: replace with tiktoken + message_tokens = (len(human_message.content) + len(ai_message.content)) // 4 + if ( + total_tokens + message_tokens > self.llm_endpoint._config.max_input + or total_pairs >= 10 + ): + break + filtered_chat_history.append(human_message) + filtered_chat_history.append(ai_message) + total_tokens += message_tokens + total_pairs += 1 + + return filtered_chat_history[::-1] + + def build_chain(self): + loaded_memory = RunnablePassthrough.assign( + chat_history=RunnableLambda( + lambda x: self.filter_history(x["chat_history"]), + ), + question=lambda x: x["question"], + ) + logger.info(f"loaded_memory: {loaded_memory}") + prompt = ChatPromptTemplate.from_messages( + [ + ( + "system", + "You are Quivr. 
You are an assistant.", + ), + MessagesPlaceholder(variable_name="chat_history"), + ("human", "{question}"), + ] + ) + + final_inputs = { + "question": itemgetter("question"), + "chat_history": itemgetter("chat_history"), + } + llm = self.llm_endpoint._llm + + answer = {"answer": final_inputs | prompt | llm, "docs": lambda _: []} + + return loaded_memory | answer + + def answer( + self, question: str, history: ChatHistory | None = None + ) -> ParsedRAGResponse: + chain = self.build_chain() + raw_llm_response = chain.invoke({"question": question, "chat_history": history}) + + response = parse_response(raw_llm_response, self.llm_endpoint._config.model) + return response + + async def answer_astream( + self, question: str, history: ChatHistory | None = None + ) -> AsyncGenerator[ParsedRAGChunkResponse, ParsedRAGChunkResponse]: + chain = self.build_chain() + rolling_message = AIMessageChunk(content="") + prev_answer = "" + chunk_id = 0 + + async for chunk in chain.astream( + {"question": question, "chat_history": history} + ): + if "answer" in chunk: + rolling_message, answer_str = parse_chunk_response( + rolling_message, + chunk, + self.llm_endpoint.supports_func_calling(), + ) + if len(answer_str) > 0: + if self.llm_endpoint.supports_func_calling(): + diff_answer = answer_str[len(prev_answer) :] + if len(diff_answer) > 0: + parsed_chunk = ParsedRAGChunkResponse( + answer=diff_answer, + metadata=RAGResponseMetadata(), + ) + prev_answer += diff_answer + + logger.debug( + f"answer_astream func_calling=True question={question} rolling_msg={rolling_message} chunk_id={chunk_id}, chunk={parsed_chunk}" + ) + yield parsed_chunk + else: + parsed_chunk = ParsedRAGChunkResponse( + answer=answer_str, + metadata=RAGResponseMetadata(), + ) + logger.debug( + f"answer_astream func_calling=False question={question} rolling_msg={rolling_message} chunk_id={chunk_id}, chunk={parsed_chunk}" + ) + yield parsed_chunk + + chunk_id += 1 + # Last chunk provides metadata + last_chunk = ParsedRAGChunkResponse( + answer=rolling_message.content, + metadata=get_chunk_metadata(rolling_message), + last_chunk=True, + ) + last_chunk.metadata.model_name = self.llm_endpoint._config.model + logger.debug( + f"answer_astream last_chunk={last_chunk} question={question} rolling_msg={rolling_message} chunk_id={chunk_id}" + ) + yield last_chunk diff --git a/backend/core/quivr_core/llm/llm_endpoint.py b/backend/core/quivr_core/llm/llm_endpoint.py index 45df28696..b3eac0a13 100644 --- a/backend/core/quivr_core/llm/llm_endpoint.py +++ b/backend/core/quivr_core/llm/llm_endpoint.py @@ -1,6 +1,5 @@ from langchain_core.language_models.chat_models import BaseChatModel from pydantic.v1 import SecretStr - from quivr_core.brain.info import LLMInfo from quivr_core.config import LLMEndpointConfig from quivr_core.utils import model_supports_function_calling diff --git a/backend/core/quivr_core/llm/models.py b/backend/core/quivr_core/llm/models.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/backend/core/quivr_core/models.py b/backend/core/quivr_core/models.py index 3993e56e2..fc8bcbdb4 100644 --- a/backend/core/quivr_core/models.py +++ b/backend/core/quivr_core/models.py @@ -31,7 +31,7 @@ class cited_answer(BaseModelV1): class ChatMessage(BaseModelV1): chat_id: UUID message_id: UUID - brain_id: UUID + brain_id: UUID | None msg: AIMessage | HumanMessage message_time: datetime metadata: dict[str, Any] @@ -59,6 +59,7 @@ class RAGResponseMetadata(BaseModel): citations: list[int] | None = None followup_questions: list[str] | None = 
None sources: list[Any] | None = None + model_name: str | None = None class ParsedRAGResponse(BaseModel): diff --git a/backend/core/quivr_core/utils.py b/backend/core/quivr_core/utils.py index 63736b654..d7b049c0e 100644 --- a/backend/core/quivr_core/utils.py +++ b/backend/core/quivr_core/utils.py @@ -4,6 +4,7 @@ from typing import Any, List, Tuple, no_type_check from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage from langchain_core.messages.ai import AIMessageChunk from langchain_core.prompts import format_document + from quivr_core.models import ( ParsedRAGResponse, QuivrKnowledge, @@ -31,6 +32,7 @@ def model_supports_function_calling(model_name: str): "gpt-3.5-turbo", "gpt-4-turbo", "gpt-4o", + "gpt-4o-mini", ] return model_name in models_supporting_function_calls @@ -113,7 +115,8 @@ def parse_response(raw_response: RawRAGResponse, model_name: str) -> ParsedRAGRe answer = raw_response["answer"].content sources = raw_response["docs"] or [] - metadata = {"sources": sources} + metadata = {"sources": sources, "model_name":model_name} + metadata["model_name"] = model_name if model_supports_function_calling(model_name): if raw_response["answer"].tool_calls: diff --git a/backend/core/tests/test_chat_llm.py b/backend/core/tests/test_chat_llm.py new file mode 100644 index 000000000..769afabca --- /dev/null +++ b/backend/core/tests/test_chat_llm.py @@ -0,0 +1,17 @@ +import pytest +from quivr_core import ChatLLM +from quivr_core.chat_llm import ChatLLM + + +@pytest.mark.base +def test_chat_llm(fake_llm): + chat_llm = ChatLLM( + llm=fake_llm, + ) + answer = chat_llm.answer("Hello, how are you?") + + assert len(answer.answer) > 0 + assert answer.metadata is not None + assert answer.metadata.citations is None + assert answer.metadata.followup_questions is None + assert answer.metadata.sources == []
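
A minimal client-side sketch of the new flow, assuming a local API at http://localhost:5050 and an AuthBearer token: list models through the GET /models route added in model_routes.py, then post a question without a brain_id so the handler builds a ChatLLMService from chat_question.model. The /chat/{chat_id}/question path and the exact payload shape are assumptions based on the existing chat routes, which are not shown in full in this diff.

```python
import requests

BASE_URL = "http://localhost:5050"  # assumed local API URL
HEADERS = {"Authorization": "Bearer <token>"}  # assumed AuthBearer token

# List available models via the new GET /models route (model_routes.py).
models = requests.get(f"{BASE_URL}/models", headers=HEADERS).json()
print([m["name"] for m in models])

# Ask a question without a brain_id: chat_routes.py then builds a ChatLLMService
# from chat_question.model instead of a RAGService.
# The path and payload shape below are assumptions, not shown in full in this diff.
chat_id = "00000000-0000-0000-0000-000000000000"  # hypothetical existing chat
payload = {"question": "Hello, what model are you?", "model": models[0]["name"]}
answer = requests.post(
    f"{BASE_URL}/chat/{chat_id}/question", json=payload, headers=HEADERS
).json()
print(answer["assistant"])
```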
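
The history filtering in ChatLLM.filter_history budgets roughly 4 characters per token and keeps at most 10 human/AI pairs within the endpoint's max_input. A standalone sketch of that rule, with illustrative defaults and the assumption that pairs arrive newest-first (which the final reversal in filter_history suggests):

```python
from langchain_core.messages import AIMessage, HumanMessage


def filter_pairs(
    pairs: list[tuple[HumanMessage, AIMessage]],
    max_input_tokens: int = 2000,
    max_pairs: int = 10,
) -> list[HumanMessage | AIMessage]:
    """Keep the most recent human/AI pairs that fit a rough token budget,
    mirroring the ~4 characters-per-token rule in ChatLLM.filter_history."""
    total_tokens = 0
    kept: list[tuple[HumanMessage, AIMessage]] = []
    for human, ai in pairs:  # assumed newest-first
        pair_tokens = (len(human.content) + len(ai.content)) // 4
        if total_tokens + pair_tokens > max_input_tokens or len(kept) >= max_pairs:
            break
        kept.append((human, ai))
        total_tokens += pair_tokens
    # Oldest pair first, human before AI within each pair.
    return [message for pair in reversed(kept) for message in pair]


history = [
    (HumanMessage(content="What is Quivr?"), AIMessage(content="A RAG platform.")),
    (HumanMessage(content="Hi"), AIMessage(content="Hello!")),
]
print(filter_pairs(history))
```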
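
generate_answer_stream yields server-sent events of the form data: {GetChatHistoryOutput as JSON}, repeating the accumulated answer and metadata in the last chunk. A sketch of consuming that stream with httpx; the streaming route path, payload shape, and auth header are assumptions, as in the sketch above:

```python
import json

import httpx  # assumed HTTP client; any SSE-capable client works


async def stream_answer(
    base_url: str, chat_id: str, token: str, question: str, model: str
) -> None:
    """Print a streamed answer chunk by chunk; each event carries a serialized
    GetChatHistoryOutput, so the growing answer is in the "assistant" field."""
    async with httpx.AsyncClient(timeout=None) as client:
        async with client.stream(
            "POST",
            f"{base_url}/chat/{chat_id}/question/stream",  # assumed route path
            json={"question": question, "model": model},  # assumed payload shape
            headers={"Authorization": f"Bearer {token}"},
        ) as response:
            async for line in response.aiter_lines():
                if line.startswith("data: "):
                    chunk = json.loads(line[len("data: "):])
                    print(chunk["assistant"], end="", flush=True)


# Run with: asyncio.run(stream_answer("http://localhost:5050", "<chat-id>", "<token>", "Hi", "gpt-4o-mini"))
```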