Mirror of https://github.com/QuivrHQ/quivr.git, synced 2024-12-14 07:59:00 +03:00
feat(models): all models by default (#2983)
# Description

Please include a summary of the changes and the related issue. Please also include relevant motivation and context.

## Checklist before requesting a review

Please delete options that are not relevant.

- [ ] My code follows the style guidelines of this project
- [ ] I have performed a self-review of my code
- [ ] I have commented hard-to-understand areas
- [ ] I have ideally added tests that prove my fix is effective or that my feature works
- [ ] New and existing unit tests pass locally with my changes
- [ ] Any dependent changes have been merged

## Screenshots (if appropriate):
Parent: b888524dbf
Commit: 4ef5f30aa9
@@ -5,6 +5,7 @@ from uuid import UUID
from celery.schedules import crontab
from pytz import timezone

from quivr_api.celery_config import celery
from quivr_api.logger import get_logger
from quivr_api.middlewares.auth.auth_bearer import AuthBearer
@@ -241,7 +242,6 @@ def check_if_is_premium_user():
                "max_brain_size": product["max_brain_size"],
                "monthly_chat_credit": product["monthly_chat_credit"],
                "api_access": product["api_access"],
-               "models": product["models"],
                "is_premium": True,
                "last_stripe_check": current_time_str,
            }
@@ -65,6 +65,19 @@ async def retrieve_all_brains_for_user(
    """Retrieve all brains for the current user."""
    brains = brain_user_service.get_user_brains(current_user.id)
    models = await model_service.get_models()
    default_model = await model_service.get_default_model()

    for brain in brains:
        # find the brain.model in models and set the brain.price to the model.price
        found = False
        if brain.model:
            for model in models:
                if model.name == brain.model:
                    brain.price = model.price
                    found = True
                    break
        if not found:
            brain.price = default_model.price

    for model in models:
        brains.append(
@@ -85,6 +98,7 @@ async def retrieve_all_brains_for_user(
                max_files=0,
            )
        )

    return {"brains": brains}
@@ -111,6 +111,7 @@ class BrainUser(BaseModel):
class MinimalUserBrainEntity(BaseModel):
    id: UUID
    name: str
+   brain_model: Optional[str] = None
    rights: RoleEnum
    status: str
    brain_type: BrainType
@@ -39,7 +39,7 @@ class BrainsUsers(BrainsUsersInterface):
        response = (
            self.db.from_("brains_users")
            .select(
-               "id:brain_id, rights, brains (brain_id, name, status, brain_type, description, meaning, integrations_user (brain_id, integration_id, integrations (id, integration_name, integration_logo_url, max_files)))"
+               "id:brain_id, rights, brains (brain_id, name, status, brain_type, model, description, meaning, integrations_user (brain_id, integration_id, integrations (id, integration_name, integration_logo_url, max_files)))"
            )
            .filter("user_id", "eq", user_id)
            .execute()
@@ -60,6 +60,7 @@ class BrainsUsers(BrainsUsersInterface):
            user_brains.append(
                MinimalUserBrainEntity(
                    id=item["brains"]["brain_id"],
+                   brain_model=item["brains"]["model"],
                    name=item["brains"]["name"],
                    rights=item["rights"],
                    status=item["brains"]["status"],
@@ -2,8 +2,11 @@ import time
from uuid import UUID

from fastapi import HTTPException

from quivr_api.logger import get_logger
from quivr_api.models.databases.llm_models import LLMModel
from quivr_api.modules.models.entity.model import Model
from quivr_api.modules.models.service.model_service import ModelService
from quivr_api.modules.user.entity.user_identity import UserIdentity
from quivr_api.modules.user.service.user_usage import UserUsage

logger = get_logger(__name__)
@@ -26,59 +29,14 @@ class NullableUUID(UUID):


-# TODO: rewrite
-def compute_cost(model_to_use, models_settings):
-    model = model_to_use.name
-    user_choosen_model_price = 1000
-    for model_setting in models_settings:
-        if model_setting["name"] == model:
-            user_choosen_model_price = model_setting["price"]
-    return user_choosen_model_price
-
-
-# TODO: rewrite
-def find_model_and_generate_metadata(
+async def find_model_and_generate_metadata(
    brain_model: str | None,
-    user_settings,
-    models_settings,
-):
-    # Default model is gpt-3.5-turbo-0125
-    default_model = "gpt-3.5-turbo-0125"
-    model_to_use = LLMModel(  # TODO Implement default models in database
-        name=default_model, price=1, max_input=4000, max_output=1000
-    )
-
-    logger.debug("Brain model: %s", brain_model)
-
-    # If brain.model is None, set it to the default_model
-    if brain_model is None:
-        brain_model = default_model
-
-    is_brain_model_available = any(
-        brain_model == model_dict.get("name") for model_dict in models_settings
-    )
-
-    is_user_allowed_model = brain_model in user_settings.get(
-        "models", [default_model]
-    )  # Checks if the model is available in the list of models
-
-    logger.debug(f"Brain model: {brain_model}")
-    logger.debug(f"User models: {user_settings.get('models', [])}")
-    logger.debug(f"Model available: {is_brain_model_available}")
-    logger.debug(f"User allowed model: {is_user_allowed_model}")
-
-    if is_brain_model_available and is_user_allowed_model:
-        # Use the model from the brain
-        model_to_use.name = brain_model
-        for model_dict in models_settings:
-            if model_dict.get("name") == model_to_use.name:
-                model_to_use.price = model_dict.get("price")
-                model_to_use.max_input = model_dict.get("max_input")
-                model_to_use.max_output = model_dict.get("max_output")
-                break
-
-    logger.info(f"Model to use: {model_to_use}")
-
-    return model_to_use
+    model_service: ModelService,
+) -> Model:
+    model = await model_service.get_model(brain_model)
+    if model is None:
+        model = await model_service.get_default_model()
+    return model
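The helper is now a thin async lookup against the models table with a default-row fallback, replacing the hard-coded gpt-3.5-turbo-0125 defaults and the user_settings allow-list. A hedged usage sketch (the wrapper and the model name are illustrative; only the helper call matches the diff):

```python
# Hypothetical usage sketch -- not part of the diff; model_service is assumed injected.
async def demo(model_service: ModelService) -> None:
    # A known name resolves to its row in the models table...
    m1 = await find_model_and_generate_metadata("gpt-4o", model_service)  # name is illustrative
    # ...while None (a brain without a model) or an unknown name falls back
    # to the row flagged `default = true`.
    m2 = await find_model_and_generate_metadata(None, model_service)
    print(m1.name, m1.price, m2.name)
```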
def update_user_usage(usage: UserUsage, user_settings, cost: int = 100):
@@ -107,3 +65,30 @@ def update_user_usage(usage: UserUsage, user_settings, cost: int = 100):
    else:
        usage.handle_increment_user_request_count(date, cost)
        pass
+
+
+async def check_and_update_user_usage(
+    user: UserIdentity, model_name: str, model_service: ModelService
+):
+    """Check user limits and raise if the user has reached them:
+    1. Raise if one of the conditions:
+        - User doesn't have access to brains
+        - Model of brain is not in user_settings.models
+        - Latest sum_30d(user_daily_user) < user_settings.max_monthly_usage
+        - Check sum(user_settings.daily_user_count) + model_price < user_settings.monthly_chat_credits
+    2. Updates user usage
+    """
+    # TODO(@aminediro) : THIS is bug prone, should retrieve it from DB here
+    user_usage = UserUsage(id=user.id, email=user.email)
+    user_settings = user_usage.get_user_settings()
+
+    # Get the model to use
+    model = await model_service.get_model(model_name)
+    logger.info(f"Model 🔥: {model}")
+    if model is None:
+        model = await model_service.get_default_model()
+        logger.info(f"Model 🔥: {model}")
+
+    # Raises HTTP if user usage exceeds limits
+    update_user_usage(user_usage, user_settings, model.price)  # noqa: F821
+    return model
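A hedged sketch of how a route handler is expected to call this helper (route path, model name, and service wiring are illustrative; only the helper call mirrors the controllers below):

```python
# Hypothetical caller sketch -- wiring is assumed, the helper call matches the diff.
from uuid import UUID

from fastapi import APIRouter, Depends

router = APIRouter()

@router.post("/chat/{chat_id}/question")
async def ask_question(
    chat_id: UUID,
    current_user: UserIdentity = Depends(get_current_user),  # assumed auth dependency
):
    # model_service is assumed injected/module-level here.
    # Raises HTTPException once monthly_chat_credit is exhausted; otherwise
    # records model.price against today's usage and returns the Model row.
    model = await check_and_update_user_usage(current_user, "gpt-4o", model_service)
    return {"model": model.name}
```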
@@ -11,6 +11,7 @@ from quivr_api.modules.brain.service.brain_authorization_service import (
    validate_brain_authorization,
)
from quivr_api.modules.brain.service.brain_service import BrainService
+from quivr_api.modules.chat.controller.chat.utils import check_and_update_user_usage
from quivr_api.modules.chat.dto.chats import ChatItem, ChatQuestion
from quivr_api.modules.chat.dto.inputs import (
    ChatMessageProperties,
@@ -179,31 +180,44 @@ async def create_question_handler(
        if brain_id == generate_uuid_from_string(model.name):
            model_to_use = model
            break

    try:
-       service = None
+       service = None | RAGService | ChatLLMService
        if not model_to_use:
            # TODO: check logic into middleware
            brain = brain_service.get_brain_details(brain_id, current_user.id)  # type: ignore
            model = await check_and_update_user_usage(
                current_user, str(brain.model), model_service
            )  # type: ignore
            assert model is not None  # type: ignore
            assert brain is not None  # type: ignore

            brain.model = model.name
            validate_authorization(user_id=current_user.id, brain_id=brain_id)
            service = RAGService(
                current_user,
                brain_id,
                brain,
                chat_id,
                brain_service,
                prompt_service,
                chat_service,
                knowledge_service,
                model_service,
            )
        else:
            await check_and_update_user_usage(
                current_user, model_to_use.name, model_service
            )  # type: ignore
            service = ChatLLMService(
                current_user,
                model_to_use.name,
                chat_id,
                chat_service,
                model_service,
-           )
+           )  # type: ignore
        assert service is not None  # type: ignore
        maybe_send_telemetry("question_asked", {"streaming": True}, request)
        chat_answer = await service.generate_answer(chat_question.question)

        maybe_send_telemetry("question_asked", {"streaming": False}, request)
        return chat_answer

    except AssertionError:
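The routing rule used by this handler (and by the streaming handler in the next hunk): a brain_id equal to the UUID deterministically derived from a model name is served directly by ChatLLMService against that model; any other brain is loaded and served through RAGService. A condensed sketch of the matching step (the helper name is ours, not the diff's):

```python
# Condensed sketch of the dispatch test -- simplified, not the literal diff.
from uuid import UUID

from quivr_api.modules.models.entity.model import Model
from quivr_api.packages.utils.uuid_generator import generate_uuid_from_string

def match_model_brain(brain_id: UUID, models: list[Model]) -> Model | None:
    """Return the model whose name-derived UUID equals brain_id, if any."""
    for model in models:
        if brain_id == generate_uuid_from_string(model.name):
            return model
    return None
```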
@@ -252,24 +266,37 @@ async def create_stream_question_handler(
    try:
        service = None
        if not model_to_use:
            brain = brain_service.get_brain_details(brain_id, current_user.id)  # type: ignore
            model = await check_and_update_user_usage(
                current_user, str(brain.model), model_service
            )  # type: ignore
            assert model is not None  # type: ignore
            assert brain is not None  # type: ignore

            brain.model = model.name
            validate_authorization(user_id=current_user.id, brain_id=brain_id)
            service = RAGService(
                current_user,
                brain_id,
                brain,
                chat_id,
                brain_service,
                prompt_service,
                chat_service,
                knowledge_service,
                model_service,
            )
        else:
            await check_and_update_user_usage(
                current_user, model_to_use.name, model_service
            )  # type: ignore
            service = ChatLLMService(
                current_user,
                model_to_use.name,
                chat_id,
                chat_service,
                model_service,
-           )
+           )  # type: ignore
        assert service is not None  # type: ignore
        maybe_send_telemetry("question_asked", {"streaming": True}, request)

        return StreamingResponse(
@@ -1,4 +1,5 @@
import datetime
+import os
from uuid import UUID, uuid4

from quivr_core.chat import ChatHistory as ChatHistoryCore
@@ -8,21 +9,14 @@ from quivr_core.llm.llm_endpoint import LLMEndpoint
from quivr_core.models import ChatLLMMetadata, ParsedRAGResponse, RAGResponseMetadata

from quivr_api.logger import get_logger
-from quivr_api.models.settings import settings
from quivr_api.modules.brain.service.utils.format_chat_history import (
    format_chat_history,
)
-from quivr_api.modules.chat.controller.chat.utils import (
-    compute_cost,
-    find_model_and_generate_metadata,
-    update_user_usage,
-)
from quivr_api.modules.chat.dto.inputs import CreateChatHistory
from quivr_api.modules.chat.dto.outputs import GetChatHistoryOutput
from quivr_api.modules.chat.service.chat_service import ChatService
from quivr_api.modules.models.service.model_service import ModelService
from quivr_api.modules.user.entity.user_identity import UserIdentity
-from quivr_api.modules.user.service.user_usage import UserUsage
from quivr_api.packages.utils.uuid_generator import generate_uuid_from_string

logger = get_logger(__name__)
@@ -46,9 +40,7 @@ class ChatLLMService:
        self.chat_id = chat_id

        # check at init time
-       self.model_to_use = self.check_and_update_user_usage(
-           self.current_user, model_name
-       )
+       self.model_to_use = model_name

    def _build_chat_history(
        self,
@@ -60,56 +52,26 @@ class ChatLLMService:
        [chat_history.append(m) for m in transformed_history]
        return chat_history

-   def build_llm(self) -> ChatLLM:
-       ollama_url = (
-           settings.ollama_api_base_url
-           if settings.ollama_api_base_url
-           and self.model_to_use.name.startswith("ollama")
-           else None
-       )
-
+   async def build_llm(self) -> ChatLLM:
+       model = await self.model_service.get_model(self.model_to_use)
+       api_key = os.getenv(model.env_variable_name, "not-defined")
        chat_llm = ChatLLM(
            llm=LLMEndpoint.from_config(
                LLMEndpointConfig(
-                   model=self.model_to_use.name,
-                   llm_base_url=ollama_url,
-                   llm_api_key="abc-123" if ollama_url else None,
+                   model=self.model_to_use,
+                   llm_base_url=model.endpoint_url,
+                   llm_api_key=api_key,
                    temperature=(LLMEndpointConfig.model_fields["temperature"].default),
-                   max_input=self.model_to_use.max_input,
-                   max_tokens=self.model_to_use.max_output,
+                   max_input=model.max_input,
+                   max_tokens=model.max_output,
                ),
            )
        )
        return chat_llm
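build_llm is now async and table-driven: both the endpoint URL and the name of the environment variable holding the API key come from the Model row instead of the Ollama-specific settings. A small sketch of just that resolution step (the function name is ours; the fallback string matches the code above):

```python
# Sketch of the credential resolution pattern used above; names come from the
# Model columns added by the migrations at the end of this diff.
import os

def resolve_llm_credentials(model) -> tuple[str, str]:
    """Return (base_url, api_key) for a Model row, as build_llm does above."""
    base_url = model.endpoint_url  # e.g. 'https://api.openai.com/v1' per the migration default
    api_key = os.getenv(model.env_variable_name, "not-defined")  # e.g. OPENAI_API_KEY
    return base_url, api_key
```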

-   def check_and_update_user_usage(self, user: UserIdentity, model_name: str):
-       """Check user limits and raises if user reached his limits:
-       1. Raise if one of the conditions :
-           - User doesn't have access to brains
-           - Model of brain is not is user_settings.models
-           - Latest sum_30d(user_daily_user) < user_settings.max_monthly_usage
-           - Check sum(user_settings.daily_user_count)+ model_price < user_settings.monthly_chat_credits
-       2. Updates user usage
-       """
-       # TODO(@aminediro) : THIS is bug prone, should retrieve it from DB here
-       user_usage = UserUsage(id=user.id, email=user.email)
-       user_settings = user_usage.get_user_settings()
-       all_models = user_usage.get_models()
-
-       # TODO(@aminediro): refactor this function
-       model_to_use = find_model_and_generate_metadata(
-           model_name,
-           user_settings,
-           all_models,
-       )
-       cost = compute_cost(model_to_use, all_models)
-       # Raises HTTP if user usage exceeds limits
-       update_user_usage(user_usage, user_settings, cost)  # noqa: F821
-       return model_to_use
-
    def save_answer(self, question: str, answer: ParsedRAGResponse):
        logger.info(
-           f"Saving answer for chat {self.chat_id} with model {self.model_to_use.name}"
+           f"Saving answer for chat {self.chat_id} with model {self.model_to_use}"
        )
        logger.info(answer)
        return self.chat_service.update_chat_history(
@@ -130,11 +92,11 @@ class ChatLLMService:
        question: str,
    ):
        logger.info(
-           f"Creating question for chat {self.chat_id} with model {self.model_to_use.name} "
+           f"Creating question for chat {self.chat_id} with model {self.model_to_use} "
        )
-       chat_llm = self.build_llm()
+       chat_llm = await self.build_llm()
        history = await self.chat_service.get_chat_history(self.chat_id)
-       model_metadata = await self.model_service.get_model(self.model_to_use.name)
+       model_metadata = await self.model_service.get_model(self.model_to_use)
        # Format the history, sanitize the input
        chat_history = self._build_chat_history(history)

@@ -143,12 +105,12 @@ class ChatLLMService:
        if parsed_response.metadata:
            # TODO: check if this is the right way to do it
            parsed_response.metadata.metadata_model = ChatLLMMetadata(
-               name=self.model_to_use.name,
+               name=self.model_to_use,
                description=model_metadata.description,
                image_url=model_metadata.image_url,
                display_name=model_metadata.display_name,
-               brain_id=str(generate_uuid_from_string(self.model_to_use.name)),
-               brain_name=self.model_to_use.name,
+               brain_id=str(generate_uuid_from_string(self.model_to_use)),
+               brain_name=self.model_to_use,
            )

        # Save the answer to db
@@ -178,13 +140,13 @@ class ChatLLMService:
        question: str,
    ):
        logger.info(
-           f"Creating question for chat {self.chat_id} with model {self.model_to_use.name} "
+           f"Creating question for chat {self.chat_id} with model {self.model_to_use} "
        )
        # Build the rag config
-       chat_llm = self.build_llm()
+       chat_llm = await self.build_llm()

        # Get model metadata
-       model_metadata = await self.model_service.get_model(self.model_to_use.name)
+       model_metadata = await self.model_service.get_model(self.model_to_use)
        # Get chat history
        history = await self.chat_service.get_chat_history(self.chat_id)
        # Format the history, sanitize the input
@@ -202,12 +164,12 @@ class ChatLLMService:
            "brain_id": None,
        }
        metadata_model = ChatLLMMetadata(
-           name=self.model_to_use.name,
+           name=self.model_to_use,
            description=model_metadata.description,
            image_url=model_metadata.image_url,
            display_name=model_metadata.display_name,
-           brain_id=str(generate_uuid_from_string(self.model_to_use.name)),
-           brain_name=self.model_to_use.name,
+           brain_id=str(generate_uuid_from_string(self.model_to_use)),
+           brain_name=self.model_to_use,
        )

        async for response in chat_llm.answer_astream(question, chat_history):
@@ -233,12 +195,12 @@ class ChatLLMService:

        metadata = RAGResponseMetadata(**streamed_chat_history.metadata)  # type: ignore
        metadata.metadata_model = ChatLLMMetadata(
-           name=self.model_to_use.name,
+           name=self.model_to_use,
            description=model_metadata.description,
            image_url=model_metadata.image_url,
            display_name=model_metadata.display_name,
-           brain_id=str(generate_uuid_from_string(self.model_to_use.name)),
-           brain_name=self.model_to_use.name,
+           brain_id=str(generate_uuid_from_string(self.model_to_use)),
+           brain_name=self.model_to_use,
        )
        streamed_chat_history.metadata = metadata.model_dump()
@@ -11,6 +11,9 @@ class Model(SQLModel, table=True):  # type: ignore
    description: str = Field(default="")
    display_name: str = Field(default="")
    image_url: str = Field(default="")
+   endpoint_url: str = Field(default="")
+   env_variable_name: str = Field(default="")
+   default: bool = Field(default=False)

    class Config:
        arbitrary_types_allowed = True
@@ -19,7 +19,12 @@ class ModelRepository(BaseRepository):
        response = await self.session.exec(query)
        return response.all()

-   async def get_model(self, model_name: str) -> Model:
+   async def get_model(self, model_name: str) -> Model | None:
        query = select(Model).where(Model.name == model_name)
        response = await self.session.exec(query)
        return response.first()

+   async def get_default_model(self) -> Model:
+       query = select(Model).where(Model.default == True)  # noqa: E712
+       response = await self.session.exec(query)
+       return response.first()
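Worth flagging: get_model is now honestly annotated Model | None, while get_default_model keeps -> Model even though response.first() also returns None when no row is flagged default = true (the service layer below re-types it as Model | None). A hedged sketch of the combined lookup, assuming an async SQLModel session:

```python
# Hypothetical helper, assuming an async SQLModel/SQLAlchemy session.
from sqlmodel import select

async def pick_model(session, name: str) -> Model:
    """Resolve a model name, falling back to the default row."""
    model = (await session.exec(select(Model).where(Model.name == name))).first()
    if model is None:  # unknown name -> default row
        model = (await session.exec(select(Model).where(Model.default == True))).first()  # noqa: E712
    assert model is not None, "the models table should contain a row with default = true"
    return model
```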
@@ -17,3 +17,10 @@ class ModelsInterface(ABC):
        Get a model by name
        """
        pass

+   @abstractmethod
+   def get_default_model(self) -> Model:
+       """
+       Get the default model
+       """
+       pass
@@ -25,3 +25,10 @@ class ModelService(BaseService[ModelRepository]):
        model = await self.repository.get_model(model_name)

        return model

+   async def get_default_model(self) -> Model | None:
+       logger.info("Getting default model")
+
+       model = await self.repository.get_default_model()
+
+       return model
@@ -1,4 +1,5 @@
import datetime
+import os
from uuid import UUID, uuid4

from quivr_core.chat import ChatHistory as ChatHistoryCore
@@ -11,26 +12,20 @@ from quivr_api.logger import get_logger
from quivr_api.models.settings import (
    get_embedding_client,
    get_supabase_client,
-   settings,
)
from quivr_api.modules.brain.entity.brain_entity import BrainEntity
from quivr_api.modules.brain.service.brain_service import BrainService
from quivr_api.modules.brain.service.utils.format_chat_history import (
    format_chat_history,
)
-from quivr_api.modules.chat.controller.chat.utils import (
-    compute_cost,
-    find_model_and_generate_metadata,
-    update_user_usage,
-)
from quivr_api.modules.chat.dto.inputs import CreateChatHistory
from quivr_api.modules.chat.dto.outputs import GetChatHistoryOutput
from quivr_api.modules.chat.service.chat_service import ChatService
from quivr_api.modules.knowledge.repository.knowledges import KnowledgeRepository
+from quivr_api.modules.models.service.model_service import ModelService
from quivr_api.modules.prompt.entity.prompt import Prompt
from quivr_api.modules.prompt.service.prompt_service import PromptService
from quivr_api.modules.user.entity.user_identity import UserIdentity
-from quivr_api.modules.user.service.user_usage import UserUsage
from quivr_api.vectorstore.supabase import CustomSupabaseVectorStore

from .utils import generate_source
@@ -42,29 +37,30 @@ class RAGService:
    def __init__(
        self,
        current_user: UserIdentity,
        brain_id: UUID | None,
+       brain: BrainEntity,
        chat_id: UUID,
        brain_service: BrainService,
        prompt_service: PromptService,
        chat_service: ChatService,
        knowledge_service: KnowledgeRepository,
+       model_service: ModelService,
    ):
        # Services
        self.brain_service = brain_service
        self.prompt_service = prompt_service
        self.chat_service = chat_service
        self.knowledge_service = knowledge_service
+       self.model_service = model_service

        # Base models
        self.current_user = current_user
        self.chat_id = chat_id
-       self.brain = self.get_or_create_brain(brain_id, self.current_user.id)
+       self.brain = brain
        self.prompt = self.get_brain_prompt(self.brain)

        # check at init time
-       self.model_to_use = self.check_and_update_user_usage(
-           self.current_user, self.brain
-       )
+       self.model_to_use = brain.model
+       assert self.model_to_use is not None

    def get_brain_prompt(self, brain: BrainEntity) -> Prompt | None:
        return (
@@ -85,30 +81,18 @@ class RAGService:
        [chat_history.append(m) for m in transformed_history]
        return chat_history

-   def _build_rag_config(self) -> RAGConfig:
-       ollama_url = (
-           settings.ollama_api_base_url
-           if settings.ollama_api_base_url
-           and self.model_to_use.name.startswith("ollama")
-           else None
-       )
+   async def _build_rag_config(self) -> RAGConfig:
+       model = await self.model_service.get_model(self.model_to_use)  # type: ignore
+       api_key = os.getenv(model.env_variable_name, "not-defined")

        rag_config = RAGConfig(
            llm_config=LLMEndpointConfig(
-               model=self.model_to_use.name,
-               llm_base_url=ollama_url,
-               llm_api_key="abc-123" if ollama_url else None,
-               temperature=(
-                   self.brain.temperature
-                   if self.brain.temperature
-                   else LLMEndpointConfig.model_fields["temperature"].default
-               ),
-               max_input=self.model_to_use.max_input,
-               max_tokens=(
-                   self.brain.max_tokens
-                   if self.brain.max_tokens
-                   else LLMEndpointConfig.model_fields["max_tokens"].default
-               ),
+               model=self.model_to_use,  # type: ignore
+               llm_base_url=model.endpoint_url,
+               llm_api_key=api_key,
+               temperature=(LLMEndpointConfig.model_fields["temperature"].default),
+               max_input=model.max_input,
+               max_tokens=model.max_output,
            ),
            prompt=self.prompt.content if self.prompt else None,
        )
@@ -117,44 +101,6 @@ class RAGService:
    def get_llm(self, rag_config: RAGConfig):
        return LLMEndpoint.from_config(rag_config.llm_config)

-   def get_or_create_brain(self, brain_id: UUID | None, user_id: UUID) -> BrainEntity:
-       brain = None
-       if brain_id is not None:
-           brain = self.brain_service.get_brain_details(brain_id, user_id)
-
-       # TODO: Create if doesn't exist
-       assert brain
-
-       if brain.integration:
-           # TODO: entity should be UUID
-           assert brain.integration.user_id == str(user_id)
-       return brain
-
-   def check_and_update_user_usage(self, user: UserIdentity, brain: BrainEntity):
-       """Check user limits and raises if user reached his limits:
-       1. Raise if one of the conditions :
-           - User doesn't have access to brains
-           - Model of brain is not is user_settings.models
-           - Latest sum_30d(user_daily_user) < user_settings.max_monthly_usage
-           - Check sum(user_settings.daily_user_count)+ model_price < user_settings.monthly_chat_credits
-       2. Updates user usage
-       """
-       # TODO(@aminediro) : THIS is bug prone, should retrieve it from DB here
-       user_usage = UserUsage(id=user.id, email=user.email)
-       user_settings = user_usage.get_user_settings()
-       all_models = user_usage.get_models()
-
-       # TODO(@aminediro): refactor this function
-       model_to_use = find_model_and_generate_metadata(
-           brain.model,
-           user_settings,
-           all_models,
-       )
-       cost = compute_cost(model_to_use, all_models)
-       # Raises HTTP if user usage exceeds limits
-       update_user_usage(user_usage, user_settings, cost)  # noqa: F821
-       return model_to_use
-
    def create_vector_store(
        self, brain_id: UUID, max_input: int
    ) -> CustomSupabaseVectorStore:
@@ -190,7 +136,7 @@ class RAGService:
        logger.info(
            f"Creating question for chat {self.chat_id} with brain {self.brain.brain_id} "
        )
-       rag_config = self._build_rag_config()
+       rag_config = await self._build_rag_config()
        logger.debug(f"generate_answer with config : {rag_config.model_dump()}")
        history = await self.chat_service.get_chat_history(self.chat_id)
        # Get list of files
@@ -241,7 +187,7 @@ class RAGService:
            f"Creating question for chat {self.chat_id} with brain {self.brain.brain_id} "
        )
        # Build the rag config
-       rag_config = self._build_rag_config()
+       rag_config = await self._build_rag_config()
        # Get chat history
        history = await self.chat_service.get_chat_history(self.chat_id)
        # Format the history, sanitize the input
@@ -1,6 +1,11 @@
from typing import Annotated

from fastapi import APIRouter, Depends, Request

from quivr_api.middlewares.auth import AuthBearer, get_current_user
from quivr_api.modules.brain.service.brain_user_service import BrainUserService
from quivr_api.modules.dependencies import get_service
from quivr_api.modules.models.service.model_service import ModelService
from quivr_api.modules.user.dto.inputs import UserUpdatableProperties
from quivr_api.modules.user.entity.user_identity import UserIdentity
from quivr_api.modules.user.repository.users import Users
@@ -8,12 +13,15 @@ from quivr_api.modules.user.service.user_usage import UserUsage

user_router = APIRouter()
brain_user_service = BrainUserService()
+ModelServiceDep = Annotated[ModelService, Depends(get_service(ModelService))]
user_repository = Users()


@user_router.get("/user", dependencies=[Depends(AuthBearer())], tags=["User"])
async def get_user_endpoint(
-   request: Request, current_user: UserIdentity = Depends(get_current_user)
+   request: Request,
+   model_service: ModelServiceDep,
+   current_user: UserIdentity = Depends(get_current_user),
):
    """
    Get user information and statistics.
@@ -38,14 +46,15 @@ async def get_user_endpoint(
    monthly_chat_credit = user_settings.get("monthly_chat_credit", 10)

    user_daily_usage = UserUsage(id=current_user.id)

+   models = await model_service.get_models()
+   models_names = [model.name for model in models]
    return {
        "email": current_user.email,
        "max_brain_size": max_brain_size,
        "max_brains": max_brains,
        "current_brain_size": 0,
        "monthly_chat_credit": monthly_chat_credit,
-       "models": user_settings.get("models", []),
+       "models": models_names,
        "id": current_user.id,
        "is_premium": user_settings["is_premium"],
    }
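The endpoint receives ModelService through the Annotated dependency alias declared above, so the user's "models" field now reflects the whole models table rather than the per-user list formerly stored in user_settings. A minimal, self-contained sketch of that FastAPI pattern with a stand-in service (all names here are ours):

```python
# Minimal sketch of the Annotated dependency pattern -- stand-in service, not quivr code.
from typing import Annotated

from fastapi import Depends, FastAPI

app = FastAPI()

class EchoService:
    def get(self) -> str:
        return "ok"

def provide_echo_service() -> EchoService:
    return EchoService()

EchoDep = Annotated[EchoService, Depends(provide_echo_service)]

@app.get("/ping")
async def ping(service: EchoDep) -> str:  # FastAPI resolves the dependency per request
    return service.get()
```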
@@ -0,0 +1,9 @@
alter table "public"."models" add column "default" boolean not null default false;

alter table "public"."models" add column "endpoint_url" text not null default 'https://api.openai.com/v1/models'::text;

alter table "public"."models" add column "env_variable_name" text not null default 'OPENAI_API_KEY'::text;

alter table "public"."user_settings" drop column "models";
@@ -0,0 +1 @@
alter table "public"."models" alter column "endpoint_url" set default 'https://api.openai.com/v1'::text;
File diff suppressed because one or more lines are too long
@@ -49,13 +49,16 @@ services:
      - 5050:5050

  notifier:
-   pull_policy: never
-   image: backend-base:latest
+   pull_policy: if_not_present
+   image: stangirard/quivr-backend-prebuilt:latest
+   extra_hosts:
+     - "host.docker.internal:host-gateway"
    env_file:
      - .env
    container_name: notifier
    build:
      context: backend
      dockerfile: Dockerfile
    command:
      - "python"
      - "/code/api/quivr_api/celery_monitor.py"
@@ -122,7 +125,3 @@ services:
      - beat
    ports:
      - 5555:5555

-networks:
-  quivr-network:
-    driver: bridge