diff --git a/backend/api/quivr_api/modules/chat/controller/chat_routes.py b/backend/api/quivr_api/modules/chat/controller/chat_routes.py
index c6dd6dcda..3356edab1 100644
--- a/backend/api/quivr_api/modules/chat/controller/chat_routes.py
+++ b/backend/api/quivr_api/modules/chat/controller/chat_routes.py
@@ -3,6 +3,7 @@ from uuid import UUID
 
 from fastapi import APIRouter, Depends, HTTPException, Query, Request
 from fastapi.responses import StreamingResponse
+
 from quivr_api.logger import get_logger
 from quivr_api.middlewares.auth import AuthBearer, get_current_user
 from quivr_api.modules.brain.entity.brain_entity import RoleEnum
@@ -22,6 +23,7 @@ from quivr_api.modules.chat.service.chat_service import ChatService
 from quivr_api.modules.chat_llm_service.chat_llm_service import ChatLLMService
 from quivr_api.modules.dependencies import get_service
 from quivr_api.modules.knowledge.repository.knowledges import KnowledgeRepository
+from quivr_api.modules.models.service.model_service import ModelService
 from quivr_api.modules.prompt.service.prompt_service import PromptService
 from quivr_api.modules.rag_service import RAGService
 from quivr_api.modules.user.entity.user_identity import UserIdentity
@@ -37,6 +39,7 @@ prompt_service = PromptService()
 
 ChatServiceDep = Annotated[ChatService, Depends(get_service(ChatService))]
 UserIdentityDep = Annotated[UserIdentity, Depends(get_current_user)]
+ModelServiceDep = Annotated[ModelService, Depends(get_service(ModelService))]
 
 
 def validate_authorization(user_id, brain_id):
@@ -162,6 +165,7 @@ async def create_question_handler(
     chat_id: UUID,
     current_user: UserIdentityDep,
     chat_service: ChatServiceDep,
+    model_service: ModelServiceDep,
     brain_id: Annotated[UUID | None, Query()] = None,
 ):
     # TODO: check logic into middleware
@@ -184,6 +188,7 @@ async def create_question_handler(
             chat_question.model,
             chat_id,
             chat_service,
+            model_service,
         )
 
         chat_answer = await service.generate_answer(chat_question.question)
@@ -215,6 +220,7 @@ async def create_stream_question_handler(
     chat_id: UUID,
     chat_service: ChatServiceDep,
     current_user: UserIdentityDep,
+    model_service: ModelServiceDep,
     brain_id: Annotated[UUID | None, Query()] = None,
 ) -> StreamingResponse:
     validate_authorization(user_id=current_user.id, brain_id=brain_id)
@@ -241,6 +247,7 @@ async def create_stream_question_handler(
             chat_question.model,
             chat_id,
             chat_service,
+            model_service,
         )
 
         maybe_send_telemetry("question_asked", {"streaming": True}, request)
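Editor's note: the hunks above all serve one dependency-injection pattern: declare `ModelServiceDep` once as an `Annotated[..., Depends(...)]` alias, then list it as a plain parameter on each handler so FastAPI builds and injects the service per request. Below is a minimal, self-contained sketch of that pattern; `FakeModelService`, `get_fake_service`, and the `/models/{name}` route are hypothetical stand-ins, not code from this patch.

from typing import Annotated

from fastapi import Depends, FastAPI

app = FastAPI()


class FakeModelService:
    # Hypothetical stand-in for ModelService; the real one queries the DB.
    async def get_model(self, name: str) -> dict:
        return {"name": name, "display_name": name.upper()}


def get_fake_service() -> FakeModelService:
    # Stand-in for get_service(ModelService), which wires an async DB session.
    return FakeModelService()


# Declared once at module level, like ModelServiceDep in the hunk above.
FakeModelServiceDep = Annotated[FakeModelService, Depends(get_fake_service)]


@app.get("/models/{name}")
async def read_model(name: str, model_service: FakeModelServiceDep):
    # FastAPI resolves the Annotated dependency and injects the instance.
    return await model_service.get_model(name)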
diff --git a/backend/api/quivr_api/modules/chat_llm_service/chat_llm_service.py b/backend/api/quivr_api/modules/chat_llm_service/chat_llm_service.py
index c7c17861f..3d84ac35c 100644
--- a/backend/api/quivr_api/modules/chat_llm_service/chat_llm_service.py
+++ b/backend/api/quivr_api/modules/chat_llm_service/chat_llm_service.py
@@ -5,7 +5,7 @@ from quivr_core.chat import ChatHistory as ChatHistoryCore
 from quivr_core.chat_llm import ChatLLM
 from quivr_core.config import LLMEndpointConfig
 from quivr_core.llm.llm_endpoint import LLMEndpoint
-from quivr_core.models import ParsedRAGResponse, RAGResponseMetadata
+from quivr_core.models import ChatLLMMetadata, ParsedRAGResponse, RAGResponseMetadata
 
 from quivr_api.logger import get_logger
 from quivr_api.models.settings import settings
@@ -20,6 +20,7 @@ from quivr_api.modules.chat.controller.chat.utils import (
 from quivr_api.modules.chat.dto.inputs import CreateChatHistory
 from quivr_api.modules.chat.dto.outputs import GetChatHistoryOutput
 from quivr_api.modules.chat.service.chat_service import ChatService
+from quivr_api.modules.models.service.model_service import ModelService
 from quivr_api.modules.user.entity.user_identity import UserIdentity
 from quivr_api.modules.user.service.user_usage import UserUsage
 
@@ -33,9 +34,11 @@ class ChatLLMService:
         model_name: str,
         chat_id: UUID,
         chat_service: ChatService,
+        model_service: ModelService,
     ):
         # Services
         self.chat_service = chat_service
+        self.model_service = model_service
 
         # Base models
         self.current_user = current_user
@@ -130,12 +133,20 @@ class ChatLLMService:
         )
         chat_llm = self.build_llm()
         history = await self.chat_service.get_chat_history(self.chat_id)
-
+        model_metadata = await self.model_service.get_model(self.model_to_use.name)
         # Format the history, sanitize the input
         chat_history = self._build_chat_history(history)
 
         parsed_response = chat_llm.answer(question, chat_history)
 
+        if parsed_response.metadata:
+            parsed_response.metadata.metadata_model = ChatLLMMetadata(
+                name=self.model_to_use.name,
+                description=model_metadata.description,
+                image_url=model_metadata.image_url,
+                display_name=model_metadata.display_name,
+            )
+
         # Save the answer to db
         new_chat_entry = self.save_answer(question, parsed_response)
 
@@ -167,6 +178,9 @@ class ChatLLMService:
         )
         # Build the rag config
         chat_llm = self.build_llm()
+
+        # Get model metadata
+        model_metadata = await self.model_service.get_model(self.model_to_use.name)
         # Get chat history
         history = await self.chat_service.get_chat_history(self.chat_id)
         # Format the history, sanitize the input
@@ -203,12 +217,22 @@ class ChatLLMService:
                     metadata=response.metadata.model_dump(),
                     **message_metadata,
                 )
+
+                metadata = RAGResponseMetadata(**streamed_chat_history.metadata)  # type: ignore
+                metadata.metadata_model = ChatLLMMetadata(
+                    name=self.model_to_use.name,
+                    description=model_metadata.description,
+                    image_url=model_metadata.image_url,
+                    display_name=model_metadata.display_name,
+                )
+                streamed_chat_history.metadata = metadata.model_dump()
+
                 logger.info("Last chunk before saving")
                 self.save_answer(
                     question,
                     ParsedRAGResponse(
                         answer=full_answer,
-                        metadata=RAGResponseMetadata(**streamed_chat_history.metadata),
+                        metadata=metadata,
                     ),
                 )
                 yield f"data: {streamed_chat_history.model_dump_json()}"
diff --git a/backend/api/quivr_api/modules/models/entity/model.py b/backend/api/quivr_api/modules/models/entity/model.py
index 58115086c..22c9ef2cf 100644
--- a/backend/api/quivr_api/modules/models/entity/model.py
+++ b/backend/api/quivr_api/modules/models/entity/model.py
@@ -1,13 +1,16 @@
 from sqlmodel import Field, SQLModel
 
 
-class Model(SQLModel, table=True):
-    __tablename__ = "models"
+class Model(SQLModel, table=True):  # type: ignore
+    __tablename__ = "models"  # type: ignore
 
     name: str = Field(primary_key=True)
     price: int = Field(default=1)
     max_input: int = Field(default=2000)
     max_output: int = Field(default=1000)
+    description: str = Field(default="")
+    display_name: str = Field(default="")
+    image_url: str = Field(default="")
 
     class Config:
         arbitrary_types_allowed = True
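Editor's note on the entity change just above: the three new columns default to empty strings on the Python side, so existing constructor calls keep working, while the SQL migration at the end of this patch supplies the database-side defaults. A small sketch of that behavior follows (table mapping omitted so it runs without a database; this mirrors, but is not, the patched class):

from sqlmodel import Field, SQLModel


class Model(SQLModel):  # table=True dropped for a DB-free illustration
    name: str = Field(primary_key=True)
    price: int = Field(default=1)
    max_input: int = Field(default=2000)
    max_output: int = Field(default=1000)
    description: str = Field(default="")
    display_name: str = Field(default="")
    image_url: str = Field(default="")


# Constructing without the new fields falls back to the defaults.
m = Model(name="gpt-4o")
assert m.description == "" and m.display_name == "" and m.image_url == ""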
diff --git a/backend/api/quivr_api/modules/models/repository/model.py b/backend/api/quivr_api/modules/models/repository/model.py
index 47581fdb9..ddb24c074 100644
--- a/backend/api/quivr_api/modules/models/repository/model.py
+++ b/backend/api/quivr_api/modules/models/repository/model.py
@@ -1,10 +1,11 @@
 from typing import Sequence
 
+from sqlmodel import select
+from sqlmodel.ext.asyncio.session import AsyncSession
+
 from quivr_api.models.settings import get_supabase_client
 from quivr_api.modules.dependencies import BaseRepository
 from quivr_api.modules.models.entity.model import Model
-from sqlmodel import select
-from sqlmodel.ext.asyncio.session import AsyncSession
 
 
 class ModelRepository(BaseRepository):
@@ -17,3 +18,8 @@ class ModelRepository(BaseRepository):
         query = select(Model)
         response = await self.session.exec(query)
         return response.all()
+
+    async def get_model(self, model_name: str) -> Model:
+        query = select(Model).where(Model.name == model_name)
+        response = await self.session.exec(query)
+        return response.first()
diff --git a/backend/api/quivr_api/modules/models/repository/model_interface.py b/backend/api/quivr_api/modules/models/repository/model_interface.py
index fcaf9d127..eda4beea3 100644
--- a/backend/api/quivr_api/modules/models/repository/model_interface.py
+++ b/backend/api/quivr_api/modules/models/repository/model_interface.py
@@ -10,3 +10,10 @@ class ModelsInterface(ABC):
         Get all models
         """
         pass
+
+    @abstractmethod
+    def get_model(self, model_name: str) -> Model:
+        """
+        Get a model by name
+        """
+        pass
diff --git a/backend/api/quivr_api/modules/models/service/model_service.py b/backend/api/quivr_api/modules/models/service/model_service.py
index b999ef7c3..8f3eaf8ae 100644
--- a/backend/api/quivr_api/modules/models/service/model_service.py
+++ b/backend/api/quivr_api/modules/models/service/model_service.py
@@ -16,6 +16,12 @@ class ModelService(BaseService[ModelRepository]):
         logger.info("Getting models")
 
         models = await self.repository.get_models()
 
-        logger.info(f"Insert response {models}")
-        return models
+        return models  # type: ignore
+
+    async def get_model(self, model_name: str) -> Model:
+        logger.info(f"Getting model {model_name}")
+
+        model = await self.repository.get_model(model_name)
+
+        return model
diff --git a/backend/core/quivr_core/chat_llm.py b/backend/core/quivr_core/chat_llm.py
index a4e8ae534..9ea9407df 100644
--- a/backend/core/quivr_core/chat_llm.py
+++ b/backend/core/quivr_core/chat_llm.py
@@ -10,6 +10,7 @@ from langchain_core.runnables import RunnableLambda, RunnablePassthrough
 from quivr_core.chat import ChatHistory
 from quivr_core.llm import LLMEndpoint
 from quivr_core.models import (
+    ChatLLMMetadata,
     ParsedRAGChunkResponse,
     ParsedRAGResponse,
     RAGResponseMetadata,
@@ -139,7 +140,9 @@ class ChatLLM:
                 metadata=get_chunk_metadata(rolling_message),
                 last_chunk=True,
             )
-            last_chunk.metadata.model_name = self.llm_endpoint._config.model
+            last_chunk.metadata.metadata_model = ChatLLMMetadata(
+                name=self.llm_endpoint._config.model,
+            )
             logger.debug(
                 f"answer_astream last_chunk={last_chunk} question={question} rolling_msg={rolling_message} chunk_id={chunk_id}"
             )
diff --git a/backend/core/quivr_core/models.py b/backend/core/quivr_core/models.py
index fc8bcbdb4..418cd704a 100644
--- a/backend/core/quivr_core/models.py
+++ b/backend/core/quivr_core/models.py
@@ -55,11 +55,18 @@ class RawRAGResponse(TypedDict):
     docs: dict[str, Any]
 
 
+class ChatLLMMetadata(BaseModel):
+    name: str
+    display_name: str | None = None
+    description: str | None = None
+    image_url: str | None = None
+
+
 class RAGResponseMetadata(BaseModel):
     citations: list[int] | None = None
     followup_questions: list[str] | None = None
     sources: list[Any] | None = None
-    model_name: str | None = None
+    metadata_model: ChatLLMMetadata | None = None
 
 
 class ParsedRAGResponse(BaseModel):
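Editor's note: the models.py hunk above is the pivot of the whole patch. The bare `model_name: str | None` field on `RAGResponseMetadata` becomes a nested `metadata_model: ChatLLMMetadata | None`, so consumers read `metadata.metadata_model.name` instead of `metadata.model_name`. A short usage sketch, assuming the patched `quivr_core` is importable (the `display_name` value is illustrative):

from quivr_core.models import ChatLLMMetadata, RAGResponseMetadata

metadata = RAGResponseMetadata(
    sources=[],
    metadata_model=ChatLLMMetadata(
        name="gpt-4o",
        display_name="GPT-4o",  # optional enrichment from the models table
    ),
)

# The old flat access path (metadata.model_name) no longer exists.
assert metadata.metadata_model is not None
print(metadata.metadata_model.name)  # "gpt-4o"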
diff --git a/backend/core/quivr_core/utils.py b/backend/core/quivr_core/utils.py
index d7b049c0e..9654638f4 100644
--- a/backend/core/quivr_core/utils.py
+++ b/backend/core/quivr_core/utils.py
@@ -6,6 +6,7 @@ from langchain_core.messages.ai import AIMessageChunk
 from langchain_core.prompts import format_document
 
 from quivr_core.models import (
+    ChatLLMMetadata,
     ParsedRAGResponse,
     QuivrKnowledge,
     RAGResponseMetadata,
@@ -115,22 +116,21 @@ def parse_response(raw_response: RawRAGResponse, model_name: str) -> ParsedRAGRe
     answer = raw_response["answer"].content
     sources = raw_response["docs"] or []
-    metadata = {"sources": sources, "model_name":model_name}
-    metadata["model_name"] = model_name
+    metadata = RAGResponseMetadata(
+        sources=sources, metadata_model=ChatLLMMetadata(name=model_name)
+    )
 
     if model_supports_function_calling(model_name):
         if raw_response["answer"].tool_calls:
             citations = raw_response["answer"].tool_calls[-1]["args"]["citations"]
-            metadata["citations"] = citations
+            metadata.citations = citations
 
             followup_questions = raw_response["answer"].tool_calls[-1]["args"][
                 "followup_questions"
             ]
             if followup_questions:
-                metadata["followup_questions"] = followup_questions
+                metadata.followup_questions = followup_questions
 
-    parsed_response = ParsedRAGResponse(
-        answer=answer, metadata=RAGResponseMetadata(**metadata)
-    )
+    parsed_response = ParsedRAGResponse(answer=answer, metadata=metadata)
     return parsed_response
diff --git a/backend/core/tests/test_chat_llm.py b/backend/core/tests/test_chat_llm.py
index 769afabca..7eeeb9730 100644
--- a/backend/core/tests/test_chat_llm.py
+++ b/backend/core/tests/test_chat_llm.py
@@ -1,6 +1,6 @@
 import pytest
+
 from quivr_core import ChatLLM
-from quivr_core.chat_llm import ChatLLM
 
 
 @pytest.mark.base
@@ -15,3 +15,5 @@ def test_chat_llm(fake_llm):
     assert answer.metadata.citations is None
     assert answer.metadata.followup_questions is None
     assert answer.metadata.sources == []
+    assert answer.metadata.metadata_model is not None
+    assert answer.metadata.metadata_model.name is not None
diff --git a/backend/supabase/migrations/20240806153621_model-description.sql b/backend/supabase/migrations/20240806153621_model-description.sql
new file mode 100644
index 000000000..3d1681005
--- /dev/null
+++ b/backend/supabase/migrations/20240806153621_model-description.sql
@@ -0,0 +1,5 @@
+alter table "public"."models" add column "description" text not null default 'Default Description'::text;
+
+alter table "public"."models" add column "display_name" text not null default gen_random_uuid();
+
+alter table "public"."models" add column "image_url" text not null default 'https://quivr-cms.s3.eu-west-3.amazonaws.com/logo_quivr_white_7e3c72620f.png'::text;
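Editor's note, one caveat for callers of the new lookup: `ModelRepository.get_model` returns `response.first()`, which is `None` when no row matches, even though the annotation promises `Model`. A hedged sketch of a guard a caller might add; `get_model_or_default` and the fallback behavior are hypothetical, not part of the patch, though the query mirrors the repository method above:

from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession

from quivr_api.modules.models.entity.model import Model


async def get_model_or_default(session: AsyncSession, model_name: str) -> Model:
    response = await session.exec(select(Model).where(Model.name == model_name))
    model = response.first()
    if model is None:
        # Fall back to the entity's column defaults instead of propagating None.
        return Model(name=model_name)
    return model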