feat: Add get_model method to ModelRepository (#2949)

# Description

This PR adds a `get_model(model_name)` method to `ModelRepository`, `ModelsInterface`, and `ModelService` so a single model can be fetched by name. `ChatLLMService` and the chat question endpoints now receive a `ModelService` dependency and use it to attach per-model metadata (`display_name`, `description`, `image_url`) to answers through a new `ChatLLMMetadata` object in `quivr_core`. A new migration adds the matching `description`, `display_name`, and `image_url` columns to the `models` table.
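
For reviewers, a minimal sketch of how the new lookup can be consumed; the helper below and its missing-row handling are illustrative, not part of this diff:

```python
from quivr_api.modules.models.service.model_service import ModelService


async def describe_model(model_service: ModelService, name: str) -> str:
    # Fetch a single model row by its primary-key name.
    model = await model_service.get_model(name)
    # The repository uses `.first()`, so a missing row surfaces as None;
    # callers should be prepared for that case.
    if model is None:
        return f"{name}: no metadata registered"
    return f"{model.display_name or name}: {model.description}"
```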

## Checklist before requesting a review

Please delete options that are not relevant.

- [ ] My code follows the style guidelines of this project
- [ ] I have performed a self-review of my code
- [ ] I have commented hard-to-understand areas
- [ ] I have ideally added tests that prove my fix is effective or that
my feature works
- [ ] New and existing unit tests pass locally with my changes
- [ ] Any dependent changes have been merged

Commit 13e9fc490b (parent 35eaf08680)
Authored by Stan Girard, 2024-08-06 17:44:12 +02:00; committed by GitHub
11 changed files with 89 additions and 19 deletions

View File

@@ -3,6 +3,7 @@ from uuid import UUID
from fastapi import APIRouter, Depends, HTTPException, Query, Request
from fastapi.responses import StreamingResponse
from quivr_api.logger import get_logger
from quivr_api.middlewares.auth import AuthBearer, get_current_user
from quivr_api.modules.brain.entity.brain_entity import RoleEnum
@@ -22,6 +23,7 @@ from quivr_api.modules.chat.service.chat_service import ChatService
from quivr_api.modules.chat_llm_service.chat_llm_service import ChatLLMService
from quivr_api.modules.dependencies import get_service
from quivr_api.modules.knowledge.repository.knowledges import KnowledgeRepository
from quivr_api.modules.models.service.model_service import ModelService
from quivr_api.modules.prompt.service.prompt_service import PromptService
from quivr_api.modules.rag_service import RAGService
from quivr_api.modules.user.entity.user_identity import UserIdentity
@@ -37,6 +39,7 @@ prompt_service = PromptService()
ChatServiceDep = Annotated[ChatService, Depends(get_service(ChatService))]
UserIdentityDep = Annotated[UserIdentity, Depends(get_current_user)]
ModelServiceDep = Annotated[ModelService, Depends(get_service(ModelService))]
def validate_authorization(user_id, brain_id):
@@ -162,6 +165,7 @@ async def create_question_handler(
chat_id: UUID,
current_user: UserIdentityDep,
chat_service: ChatServiceDep,
model_service: ModelServiceDep,
brain_id: Annotated[UUID | None, Query()] = None,
):
# TODO: check logic into middleware
@@ -184,6 +188,7 @@ async def create_question_handler(
chat_question.model,
chat_id,
chat_service,
model_service,
)
chat_answer = await service.generate_answer(chat_question.question)
@@ -215,6 +220,7 @@ async def create_stream_question_handler(
chat_id: UUID,
chat_service: ChatServiceDep,
current_user: UserIdentityDep,
model_service: ModelServiceDep,
brain_id: Annotated[UUID | None, Query()] = None,
) -> StreamingResponse:
validate_authorization(user_id=current_user.id, brain_id=brain_id)
@@ -241,6 +247,7 @@ async def create_stream_question_handler(
chat_question.model,
chat_id,
chat_service,
model_service,
)
maybe_send_telemetry("question_asked", {"streaming": True}, request)

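The new `ModelServiceDep` alias follows the same `Annotated[..., Depends(...)]` pattern as the existing dependencies in this router; a self-contained sketch of that pattern with made-up names (quivr_api's real `get_service` factory also wires in a database session):

```python
from typing import Annotated

from fastapi import Depends, FastAPI

app = FastAPI()


class DemoService:
    def greet(self) -> str:
        return "hello"


def get_demo_service() -> DemoService:
    # Stand-in for get_service(ModelService); it just builds the object directly.
    return DemoService()


DemoServiceDep = Annotated[DemoService, Depends(get_demo_service)]


@app.get("/demo")
async def demo(service: DemoServiceDep) -> dict[str, str]:
    # FastAPI resolves the Annotated dependency per request and injects it here.
    return {"message": service.greet()}
```
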
View File

@@ -5,7 +5,7 @@ from quivr_core.chat import ChatHistory as ChatHistoryCore
from quivr_core.chat_llm import ChatLLM
from quivr_core.config import LLMEndpointConfig
from quivr_core.llm.llm_endpoint import LLMEndpoint
from quivr_core.models import ParsedRAGResponse, RAGResponseMetadata
from quivr_core.models import ChatLLMMetadata, ParsedRAGResponse, RAGResponseMetadata
from quivr_api.logger import get_logger
from quivr_api.models.settings import settings
@@ -20,6 +20,7 @@ from quivr_api.modules.chat.controller.chat.utils import (
from quivr_api.modules.chat.dto.inputs import CreateChatHistory
from quivr_api.modules.chat.dto.outputs import GetChatHistoryOutput
from quivr_api.modules.chat.service.chat_service import ChatService
from quivr_api.modules.models.service.model_service import ModelService
from quivr_api.modules.user.entity.user_identity import UserIdentity
from quivr_api.modules.user.service.user_usage import UserUsage
@@ -33,9 +34,11 @@ class ChatLLMService:
model_name: str,
chat_id: UUID,
chat_service: ChatService,
model_service: ModelService,
):
# Services
self.chat_service = chat_service
self.model_service = model_service
# Base models
self.current_user = current_user
@@ -130,12 +133,20 @@ class ChatLLMService:
)
chat_llm = self.build_llm()
history = await self.chat_service.get_chat_history(self.chat_id)
model_metadata = await self.model_service.get_model(self.model_to_use.name)
# Format the history, sanitize the input
chat_history = self._build_chat_history(history)
parsed_response = chat_llm.answer(question, chat_history)
if parsed_response.metadata:
parsed_response.metadata.metadata_model = ChatLLMMetadata(
name=self.model_to_use.name,
description=model_metadata.description,
image_url=model_metadata.image_url,
display_name=model_metadata.display_name,
)
# Save the answer to db
new_chat_entry = self.save_answer(question, parsed_response)
@@ -167,6 +178,9 @@ class ChatLLMService:
)
# Build the rag config
chat_llm = self.build_llm()
# Get model metadata
model_metadata = await self.model_service.get_model(self.model_to_use.name)
# Get chat history
history = await self.chat_service.get_chat_history(self.chat_id)
# Format the history, sanitize the input
@@ -203,12 +217,22 @@ class ChatLLMService:
metadata=response.metadata.model_dump(),
**message_metadata,
)
metadata = RAGResponseMetadata(**streamed_chat_history.metadata) # type: ignore
metadata.metadata_model = ChatLLMMetadata(
name=self.model_to_use.name,
description=model_metadata.description,
image_url=model_metadata.image_url,
display_name=model_metadata.display_name,
)
streamed_chat_history.metadata = metadata.model_dump()
logger.info("Last chunk before saving")
self.save_answer(
question,
ParsedRAGResponse(
answer=full_answer,
metadata=RAGResponseMetadata(**streamed_chat_history.metadata),
metadata=metadata,
),
)
yield f"data: {streamed_chat_history.model_dump_json()}"

View File

@@ -1,13 +1,16 @@
from sqlmodel import Field, SQLModel
class Model(SQLModel, table=True):
__tablename__ = "models"
class Model(SQLModel, table=True): # type: ignore
__tablename__ = "models" # type: ignore
name: str = Field(primary_key=True)
price: int = Field(default=1)
max_input: int = Field(default=2000)
max_output: int = Field(default=1000)
description: str = Field(default="")
display_name: str = Field(default="")
image_url: str = Field(default="")
class Config:
arbitrary_types_allowed = True

View File

@@ -1,10 +1,11 @@
from typing import Sequence
from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession
from quivr_api.models.settings import get_supabase_client
from quivr_api.modules.dependencies import BaseRepository
from quivr_api.modules.models.entity.model import Model
from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession
class ModelRepository(BaseRepository):
@@ -17,3 +18,8 @@ class ModelRepository(BaseRepository):
query = select(Model)
response = await self.session.exec(query)
return response.all()
async def get_model(self, model_name: str) -> Model:
query = select(Model).where(Model.name == model_name)
response = await self.session.exec(query)
return response.first()

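Because `name` is the primary key on `Model`, the same lookup could also be expressed with `session.get`; a hedged alternative sketch (engine and session wiring assumed):

```python
from sqlmodel.ext.asyncio.session import AsyncSession

from quivr_api.modules.models.entity.model import Model


async def get_model_by_pk(session: AsyncSession, model_name: str) -> Model | None:
    # Primary-key lookup; returns None when no row matches, which mirrors
    # what `response.first()` can return in the select-based version above.
    return await session.get(Model, model_name)
```
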
View File

@@ -10,3 +10,10 @@ class ModelsInterface(ABC):
Get all models
"""
pass
@abstractmethod
def get_model(self, model_name: str) -> Model:
"""
Get a model by name
"""
pass

View File

@@ -16,6 +16,12 @@ class ModelService(BaseService[ModelRepository]):
logger.info("Getting models")
models = await self.repository.get_models()
logger.info(f"Insert response {models}")
return models
return models # type: ignore
async def get_model(self, model_name: str) -> Model:
logger.info(f"Getting model {model_name}")
model = await self.repository.get_model(model_name)
return model

View File

@@ -10,6 +10,7 @@ from langchain_core.runnables import RunnableLambda, RunnablePassthrough
from quivr_core.chat import ChatHistory
from quivr_core.llm import LLMEndpoint
from quivr_core.models import (
ChatLLMMetadata,
ParsedRAGChunkResponse,
ParsedRAGResponse,
RAGResponseMetadata,
@@ -139,7 +140,9 @@ class ChatLLM:
metadata=get_chunk_metadata(rolling_message),
last_chunk=True,
)
last_chunk.metadata.model_name = self.llm_endpoint._config.model
last_chunk.metadata.metadata_model = ChatLLMMetadata(
name=self.llm_endpoint._config.model,
)
logger.debug(
f"answer_astream last_chunk={last_chunk} question={question} rolling_msg={rolling_message} chunk_id={chunk_id}"
)

View File

@@ -55,11 +55,18 @@ class RawRAGResponse(TypedDict):
docs: dict[str, Any]
class ChatLLMMetadata(BaseModel):
name: str
display_name: str | None = None
description: str | None = None
image_url: str | None = None
class RAGResponseMetadata(BaseModel):
citations: list[int] | None = None
followup_questions: list[str] | None = None
sources: list[Any] | None = None
model_name: str | None = None
metadata_model: ChatLLMMetadata | None = None
class ParsedRAGResponse(BaseModel):

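For illustration only, the new metadata classes compose and serialize like any other pydantic models (all values below are made up):

```python
from quivr_core.models import ChatLLMMetadata, RAGResponseMetadata

metadata = RAGResponseMetadata(
    sources=[],
    metadata_model=ChatLLMMetadata(
        name="gpt-4o-mini",
        display_name="GPT-4o mini",
        description="Example description",
        image_url="https://example.com/model.png",
    ),
)
# metadata_model is optional and nests inside the usual RAG metadata payload.
print(metadata.model_dump())
```
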
View File

@@ -6,6 +6,7 @@ from langchain_core.messages.ai import AIMessageChunk
from langchain_core.prompts import format_document
from quivr_core.models import (
ChatLLMMetadata,
ParsedRAGResponse,
QuivrKnowledge,
RAGResponseMetadata,
@@ -115,22 +116,21 @@ def parse_response(raw_response: RawRAGResponse, model_name: str) -> ParsedRAGRe
answer = raw_response["answer"].content
sources = raw_response["docs"] or []
metadata = {"sources": sources, "model_name":model_name}
metadata["model_name"] = model_name
metadata = RAGResponseMetadata(
sources=sources, metadata_model=ChatLLMMetadata(name=model_name)
)
if model_supports_function_calling(model_name):
if raw_response["answer"].tool_calls:
citations = raw_response["answer"].tool_calls[-1]["args"]["citations"]
metadata["citations"] = citations
metadata.citations = citations
followup_questions = raw_response["answer"].tool_calls[-1]["args"][
"followup_questions"
]
if followup_questions:
metadata["followup_questions"] = followup_questions
metadata.followup_questions = followup_questions
parsed_response = ParsedRAGResponse(
answer=answer, metadata=RAGResponseMetadata(**metadata)
)
parsed_response = ParsedRAGResponse(answer=answer, metadata=metadata)
return parsed_response

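With `parse_response` now building a typed `RAGResponseMetadata` up front instead of a plain dict, downstream code reads attributes rather than keys; a small hedged illustration (model name and values are examples):

```python
from quivr_core.models import ChatLLMMetadata, ParsedRAGResponse, RAGResponseMetadata

response = ParsedRAGResponse(
    answer="Hi there",
    metadata=RAGResponseMetadata(
        sources=[],
        citations=[1, 2],
        metadata_model=ChatLLMMetadata(name="gpt-4o"),
    ),
)
# Attribute access replaces the earlier metadata["citations"] / metadata["model_name"] style.
if response.metadata and response.metadata.metadata_model:
    print(response.metadata.citations, response.metadata.metadata_model.name)
```
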
View File

@@ -1,6 +1,6 @@
import pytest
from quivr_core import ChatLLM
from quivr_core.chat_llm import ChatLLM
@pytest.mark.base
@@ -15,3 +15,5 @@ def test_chat_llm(fake_llm):
assert answer.metadata.citations is None
assert answer.metadata.followup_questions is None
assert answer.metadata.sources == []
assert answer.metadata.metadata_model is not None
assert answer.metadata.metadata_model.name is not None

View File

@@ -0,0 +1,5 @@
alter table "public"."models" add column "description" text not null default 'Default Description'::text;
alter table "public"."models" add column "display_name" text not null default gen_random_uuid();
alter table "public"."models" add column "image_url" text not null default 'https://quivr-cms.s3.eu-west-3.amazonaws.com/logo_quivr_white_7e3c72620f.png'::text;