mirror of
https://github.com/QuivrHQ/quivr.git
synced 2024-12-15 01:21:48 +03:00
feat: Add get_model method to ModelRepository (#2949)
# Description Please include a summary of the changes and the related issue. Please also include relevant motivation and context. ## Checklist before requesting a review Please delete options that are not relevant. - [ ] My code follows the style guidelines of this project - [ ] I have performed a self-review of my code - [ ] I have commented hard-to-understand areas - [ ] I have ideally added tests that prove my fix is effective or that my feature works - [ ] New and existing unit tests pass locally with my changes - [ ] Any dependent changes have been merged ## Screenshots (if appropriate):
This commit is contained in:
parent
35eaf08680
commit
13e9fc490b
@ -3,6 +3,7 @@ from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, Request
|
||||
from fastapi.responses import StreamingResponse
|
||||
|
||||
from quivr_api.logger import get_logger
|
||||
from quivr_api.middlewares.auth import AuthBearer, get_current_user
|
||||
from quivr_api.modules.brain.entity.brain_entity import RoleEnum
|
||||
@ -22,6 +23,7 @@ from quivr_api.modules.chat.service.chat_service import ChatService
|
||||
from quivr_api.modules.chat_llm_service.chat_llm_service import ChatLLMService
|
||||
from quivr_api.modules.dependencies import get_service
|
||||
from quivr_api.modules.knowledge.repository.knowledges import KnowledgeRepository
|
||||
from quivr_api.modules.models.service.model_service import ModelService
|
||||
from quivr_api.modules.prompt.service.prompt_service import PromptService
|
||||
from quivr_api.modules.rag_service import RAGService
|
||||
from quivr_api.modules.user.entity.user_identity import UserIdentity
|
||||
@ -37,6 +39,7 @@ prompt_service = PromptService()
|
||||
|
||||
ChatServiceDep = Annotated[ChatService, Depends(get_service(ChatService))]
|
||||
UserIdentityDep = Annotated[UserIdentity, Depends(get_current_user)]
|
||||
ModelServiceDep = Annotated[ModelService, Depends(get_service(ModelService))]
|
||||
|
||||
|
||||
def validate_authorization(user_id, brain_id):
|
||||
@ -162,6 +165,7 @@ async def create_question_handler(
|
||||
chat_id: UUID,
|
||||
current_user: UserIdentityDep,
|
||||
chat_service: ChatServiceDep,
|
||||
model_service: ModelServiceDep,
|
||||
brain_id: Annotated[UUID | None, Query()] = None,
|
||||
):
|
||||
# TODO: check logic into middleware
|
||||
@ -184,6 +188,7 @@ async def create_question_handler(
|
||||
chat_question.model,
|
||||
chat_id,
|
||||
chat_service,
|
||||
model_service,
|
||||
)
|
||||
chat_answer = await service.generate_answer(chat_question.question)
|
||||
|
||||
@ -215,6 +220,7 @@ async def create_stream_question_handler(
|
||||
chat_id: UUID,
|
||||
chat_service: ChatServiceDep,
|
||||
current_user: UserIdentityDep,
|
||||
model_service: ModelServiceDep,
|
||||
brain_id: Annotated[UUID | None, Query()] = None,
|
||||
) -> StreamingResponse:
|
||||
validate_authorization(user_id=current_user.id, brain_id=brain_id)
|
||||
@ -241,6 +247,7 @@ async def create_stream_question_handler(
|
||||
chat_question.model,
|
||||
chat_id,
|
||||
chat_service,
|
||||
model_service,
|
||||
)
|
||||
maybe_send_telemetry("question_asked", {"streaming": True}, request)
|
||||
|
||||
|
@ -5,7 +5,7 @@ from quivr_core.chat import ChatHistory as ChatHistoryCore
|
||||
from quivr_core.chat_llm import ChatLLM
|
||||
from quivr_core.config import LLMEndpointConfig
|
||||
from quivr_core.llm.llm_endpoint import LLMEndpoint
|
||||
from quivr_core.models import ParsedRAGResponse, RAGResponseMetadata
|
||||
from quivr_core.models import ChatLLMMetadata, ParsedRAGResponse, RAGResponseMetadata
|
||||
|
||||
from quivr_api.logger import get_logger
|
||||
from quivr_api.models.settings import settings
|
||||
@ -20,6 +20,7 @@ from quivr_api.modules.chat.controller.chat.utils import (
|
||||
from quivr_api.modules.chat.dto.inputs import CreateChatHistory
|
||||
from quivr_api.modules.chat.dto.outputs import GetChatHistoryOutput
|
||||
from quivr_api.modules.chat.service.chat_service import ChatService
|
||||
from quivr_api.modules.models.service.model_service import ModelService
|
||||
from quivr_api.modules.user.entity.user_identity import UserIdentity
|
||||
from quivr_api.modules.user.service.user_usage import UserUsage
|
||||
|
||||
@ -33,9 +34,11 @@ class ChatLLMService:
|
||||
model_name: str,
|
||||
chat_id: UUID,
|
||||
chat_service: ChatService,
|
||||
model_service: ModelService,
|
||||
):
|
||||
# Services
|
||||
self.chat_service = chat_service
|
||||
self.model_service = model_service
|
||||
|
||||
# Base models
|
||||
self.current_user = current_user
|
||||
@ -130,12 +133,20 @@ class ChatLLMService:
|
||||
)
|
||||
chat_llm = self.build_llm()
|
||||
history = await self.chat_service.get_chat_history(self.chat_id)
|
||||
|
||||
model_metadata = await self.model_service.get_model(self.model_to_use.name)
|
||||
# Format the history, sanitize the input
|
||||
chat_history = self._build_chat_history(history)
|
||||
|
||||
parsed_response = chat_llm.answer(question, chat_history)
|
||||
|
||||
if parsed_response.metadata:
|
||||
parsed_response.metadata.metadata_model = ChatLLMMetadata(
|
||||
name=self.model_to_use.name,
|
||||
description=model_metadata.description,
|
||||
image_url=model_metadata.image_url,
|
||||
display_name=model_metadata.display_name,
|
||||
)
|
||||
|
||||
# Save the answer to db
|
||||
new_chat_entry = self.save_answer(question, parsed_response)
|
||||
|
||||
@ -167,6 +178,9 @@ class ChatLLMService:
|
||||
)
|
||||
# Build the rag config
|
||||
chat_llm = self.build_llm()
|
||||
|
||||
# Get model metadata
|
||||
model_metadata = await self.model_service.get_model(self.model_to_use.name)
|
||||
# Get chat history
|
||||
history = await self.chat_service.get_chat_history(self.chat_id)
|
||||
# Format the history, sanitize the input
|
||||
@ -203,12 +217,22 @@ class ChatLLMService:
|
||||
metadata=response.metadata.model_dump(),
|
||||
**message_metadata,
|
||||
)
|
||||
|
||||
metadata = RAGResponseMetadata(**streamed_chat_history.metadata) # type: ignore
|
||||
metadata.metadata_model = ChatLLMMetadata(
|
||||
name=self.model_to_use.name,
|
||||
description=model_metadata.description,
|
||||
image_url=model_metadata.image_url,
|
||||
display_name=model_metadata.display_name,
|
||||
)
|
||||
streamed_chat_history.metadata = metadata.model_dump()
|
||||
|
||||
logger.info("Last chunk before saving")
|
||||
self.save_answer(
|
||||
question,
|
||||
ParsedRAGResponse(
|
||||
answer=full_answer,
|
||||
metadata=RAGResponseMetadata(**streamed_chat_history.metadata),
|
||||
metadata=metadata,
|
||||
),
|
||||
)
|
||||
yield f"data: {streamed_chat_history.model_dump_json()}"
|
||||
|
@ -1,13 +1,16 @@
|
||||
from sqlmodel import Field, SQLModel
|
||||
|
||||
|
||||
class Model(SQLModel, table=True):
|
||||
__tablename__ = "models"
|
||||
class Model(SQLModel, table=True): # type: ignore
|
||||
__tablename__ = "models" # type: ignore
|
||||
|
||||
name: str = Field(primary_key=True)
|
||||
price: int = Field(default=1)
|
||||
max_input: int = Field(default=2000)
|
||||
max_output: int = Field(default=1000)
|
||||
description: str = Field(default="")
|
||||
display_name: str = Field(default="")
|
||||
image_url: str = Field(default="")
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
@ -1,10 +1,11 @@
|
||||
from typing import Sequence
|
||||
|
||||
from sqlmodel import select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from quivr_api.models.settings import get_supabase_client
|
||||
from quivr_api.modules.dependencies import BaseRepository
|
||||
from quivr_api.modules.models.entity.model import Model
|
||||
from sqlmodel import select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
|
||||
class ModelRepository(BaseRepository):
|
||||
@ -17,3 +18,8 @@ class ModelRepository(BaseRepository):
|
||||
query = select(Model)
|
||||
response = await self.session.exec(query)
|
||||
return response.all()
|
||||
|
||||
async def get_model(self, model_name: str) -> Model:
|
||||
query = select(Model).where(Model.name == model_name)
|
||||
response = await self.session.exec(query)
|
||||
return response.first()
|
||||
|
@ -10,3 +10,10 @@ class ModelsInterface(ABC):
|
||||
Get all models
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_model(self, model_name: str) -> Model:
|
||||
"""
|
||||
Get a model by name
|
||||
"""
|
||||
pass
|
||||
|
@ -16,6 +16,12 @@ class ModelService(BaseService[ModelRepository]):
|
||||
logger.info("Getting models")
|
||||
|
||||
models = await self.repository.get_models()
|
||||
logger.info(f"Insert response {models}")
|
||||
|
||||
return models
|
||||
return models # type: ignore
|
||||
|
||||
async def get_model(self, model_name: str) -> Model:
|
||||
logger.info(f"Getting model {model_name}")
|
||||
|
||||
model = await self.repository.get_model(model_name)
|
||||
|
||||
return model
|
||||
|
@ -10,6 +10,7 @@ from langchain_core.runnables import RunnableLambda, RunnablePassthrough
|
||||
from quivr_core.chat import ChatHistory
|
||||
from quivr_core.llm import LLMEndpoint
|
||||
from quivr_core.models import (
|
||||
ChatLLMMetadata,
|
||||
ParsedRAGChunkResponse,
|
||||
ParsedRAGResponse,
|
||||
RAGResponseMetadata,
|
||||
@ -139,7 +140,9 @@ class ChatLLM:
|
||||
metadata=get_chunk_metadata(rolling_message),
|
||||
last_chunk=True,
|
||||
)
|
||||
last_chunk.metadata.model_name = self.llm_endpoint._config.model
|
||||
last_chunk.metadata.metadata_model = ChatLLMMetadata(
|
||||
name=self.llm_endpoint._config.model,
|
||||
)
|
||||
logger.debug(
|
||||
f"answer_astream last_chunk={last_chunk} question={question} rolling_msg={rolling_message} chunk_id={chunk_id}"
|
||||
)
|
||||
|
@ -55,11 +55,18 @@ class RawRAGResponse(TypedDict):
|
||||
docs: dict[str, Any]
|
||||
|
||||
|
||||
class ChatLLMMetadata(BaseModel):
|
||||
name: str
|
||||
display_name: str | None = None
|
||||
description: str | None = None
|
||||
image_url: str | None = None
|
||||
|
||||
|
||||
class RAGResponseMetadata(BaseModel):
|
||||
citations: list[int] | None = None
|
||||
followup_questions: list[str] | None = None
|
||||
sources: list[Any] | None = None
|
||||
model_name: str | None = None
|
||||
metadata_model: ChatLLMMetadata | None = None
|
||||
|
||||
|
||||
class ParsedRAGResponse(BaseModel):
|
||||
|
@ -6,6 +6,7 @@ from langchain_core.messages.ai import AIMessageChunk
|
||||
from langchain_core.prompts import format_document
|
||||
|
||||
from quivr_core.models import (
|
||||
ChatLLMMetadata,
|
||||
ParsedRAGResponse,
|
||||
QuivrKnowledge,
|
||||
RAGResponseMetadata,
|
||||
@ -115,22 +116,21 @@ def parse_response(raw_response: RawRAGResponse, model_name: str) -> ParsedRAGRe
|
||||
answer = raw_response["answer"].content
|
||||
sources = raw_response["docs"] or []
|
||||
|
||||
metadata = {"sources": sources, "model_name":model_name}
|
||||
metadata["model_name"] = model_name
|
||||
metadata = RAGResponseMetadata(
|
||||
sources=sources, metadata_model=ChatLLMMetadata(name=model_name)
|
||||
)
|
||||
|
||||
if model_supports_function_calling(model_name):
|
||||
if raw_response["answer"].tool_calls:
|
||||
citations = raw_response["answer"].tool_calls[-1]["args"]["citations"]
|
||||
metadata["citations"] = citations
|
||||
metadata.citations = citations
|
||||
followup_questions = raw_response["answer"].tool_calls[-1]["args"][
|
||||
"followup_questions"
|
||||
]
|
||||
if followup_questions:
|
||||
metadata["followup_questions"] = followup_questions
|
||||
metadata.followup_questions = followup_questions
|
||||
|
||||
parsed_response = ParsedRAGResponse(
|
||||
answer=answer, metadata=RAGResponseMetadata(**metadata)
|
||||
)
|
||||
parsed_response = ParsedRAGResponse(answer=answer, metadata=metadata)
|
||||
return parsed_response
|
||||
|
||||
|
||||
|
@ -1,6 +1,6 @@
|
||||
import pytest
|
||||
|
||||
from quivr_core import ChatLLM
|
||||
from quivr_core.chat_llm import ChatLLM
|
||||
|
||||
|
||||
@pytest.mark.base
|
||||
@ -15,3 +15,5 @@ def test_chat_llm(fake_llm):
|
||||
assert answer.metadata.citations is None
|
||||
assert answer.metadata.followup_questions is None
|
||||
assert answer.metadata.sources == []
|
||||
assert answer.metadata.metadata_model is not None
|
||||
assert answer.metadata.metadata_model.name is not None
|
||||
|
@ -0,0 +1,5 @@
|
||||
alter table "public"."models" add column "description" text not null default 'Default Description'::text;
|
||||
|
||||
alter table "public"."models" add column "display_name" text not null default gen_random_uuid();
|
||||
|
||||
alter table "public"."models" add column "image_url" text not null default 'https://quivr-cms.s3.eu-west-3.amazonaws.com/logo_quivr_white_7e3c72620f.png'::text;
|
Loading…
Reference in New Issue
Block a user