mirror of
https://github.com/StanGirard/quivr.git
synced 2025-01-05 18:57:48 +03:00
feat: Add pricing calculation method to GPT4Brain class and update user usage in chat controller (#2210)
This pull request adds a new method called `calculate_pricing` to the `GPT4Brain` class in the codebase. This method calculates the pricing for the GPT4Brain model. Additionally, the user usage in the chat controller has been updated to include the new pricing calculation method.
This commit is contained in:
parent
aa4e85fc32
commit
2c71e0edc7
@ -24,6 +24,9 @@ class GPT4Brain(KnowledgeBrainQA):
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def calculate_pricing(self):
|
||||
return 3
|
||||
|
||||
def get_chain(self):
|
||||
|
||||
prompt = ChatPromptTemplate.from_messages(
|
||||
|
@ -8,10 +8,15 @@ from llm.utils.get_prompt_to_use import get_prompt_to_use
|
||||
from llm.utils.get_prompt_to_use_id import get_prompt_to_use_id
|
||||
from logger import get_logger
|
||||
from models import BrainSettings
|
||||
from models.user_usage import UserUsage
|
||||
from modules.brain.qa_interface import QAInterface
|
||||
from modules.brain.rags.quivr_rag import QuivrRAG
|
||||
from modules.brain.rags.rag_interface import RAGInterface
|
||||
from modules.brain.service.brain_service import BrainService
|
||||
from modules.chat.controller.chat.utils import (
|
||||
find_model_and_generate_metadata,
|
||||
update_user_usage,
|
||||
)
|
||||
from modules.chat.dto.chats import ChatQuestion, Sources
|
||||
from modules.chat.dto.inputs import CreateChatHistory
|
||||
from modules.chat.dto.outputs import GetChatHistoryOutput
|
||||
@ -124,8 +129,12 @@ class KnowledgeBrainQA(BaseModel, QAInterface):
|
||||
max_input: int = 2000
|
||||
streaming: bool = False
|
||||
knowledge_qa: Optional[RAGInterface] = None
|
||||
metadata: Optional[dict] = None
|
||||
user_id: str = None
|
||||
user_email: str = None
|
||||
user_usage: Optional[UserUsage] = None
|
||||
user_settings: Optional[dict] = None
|
||||
models_settings: Optional[List[dict]] = None
|
||||
metadata: Optional[dict] = None
|
||||
|
||||
callbacks: List[AsyncIteratorCallbackHandler] = (
|
||||
None # pyright: ignore reportPrivateUsage=none
|
||||
@ -138,11 +147,12 @@ class KnowledgeBrainQA(BaseModel, QAInterface):
|
||||
model: str,
|
||||
brain_id: str,
|
||||
chat_id: str,
|
||||
max_tokens: int,
|
||||
streaming: bool = False,
|
||||
prompt_id: Optional[UUID] = None,
|
||||
metadata: Optional[dict] = None,
|
||||
user_id: str = None,
|
||||
user_email: str = None,
|
||||
cost: int = 100,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(
|
||||
@ -160,9 +170,17 @@ class KnowledgeBrainQA(BaseModel, QAInterface):
|
||||
streaming=streaming,
|
||||
**kwargs,
|
||||
)
|
||||
self.metadata = metadata
|
||||
self.max_tokens = max_tokens
|
||||
self.user_id = user_id
|
||||
self.user_email = user_email
|
||||
self.user_usage = UserUsage(
|
||||
id=user_id,
|
||||
email=user_email,
|
||||
)
|
||||
self.user_settings = self.user_usage.get_user_settings()
|
||||
|
||||
# Get Model settings for the user
|
||||
self.models_settings = self.user_usage.get_model_settings()
|
||||
self.increase_usage_user()
|
||||
|
||||
@property
|
||||
def prompt_to_use(self):
|
||||
@ -179,6 +197,39 @@ class KnowledgeBrainQA(BaseModel, QAInterface):
|
||||
else:
|
||||
return None
|
||||
|
||||
def increase_usage_user(self):
|
||||
# Raises an error if the user has consumed all of of his credits
|
||||
|
||||
update_user_usage(
|
||||
usage=self.user_usage,
|
||||
user_settings=self.user_settings,
|
||||
cost=self.calculate_pricing(),
|
||||
)
|
||||
|
||||
def calculate_pricing(self):
|
||||
|
||||
logger.info("Calculating pricing")
|
||||
logger.info(f"Model: {self.model}")
|
||||
logger.info(f"User settings: {self.user_settings}")
|
||||
logger.info(f"Models settings: {self.models_settings}")
|
||||
model_to_use = find_model_and_generate_metadata(
|
||||
self.chat_id,
|
||||
self.model,
|
||||
self.user_settings,
|
||||
self.models_settings,
|
||||
)
|
||||
|
||||
self.model = model_to_use.name
|
||||
self.max_input = model_to_use.max_input
|
||||
self.max_tokens = model_to_use.max_output
|
||||
user_choosen_model_price = 1000
|
||||
|
||||
for model_setting in self.models_settings:
|
||||
if model_setting["name"] == self.model:
|
||||
user_choosen_model_price = model_setting["price"]
|
||||
|
||||
return user_choosen_model_price
|
||||
|
||||
def generate_answer(
|
||||
self, chat_id: UUID, question: ChatQuestion, save_answer: bool = True
|
||||
) -> GetChatHistoryOutput:
|
||||
|
@ -10,6 +10,12 @@ class QAInterface(ABC):
|
||||
This can be used to implement custom answer generation logic.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def calculate_pricing(self):
|
||||
raise NotImplementedError(
|
||||
"calculate_pricing is an abstract method and must be implemented"
|
||||
)
|
||||
|
||||
@abstractmethod
|
||||
def generate_answer(
|
||||
self,
|
||||
|
@ -59,26 +59,22 @@ class BrainfulChat(ChatInterface):
|
||||
brain,
|
||||
chat_id,
|
||||
model,
|
||||
max_tokens,
|
||||
max_input,
|
||||
temperature,
|
||||
streaming,
|
||||
prompt_id,
|
||||
user_id,
|
||||
metadata,
|
||||
user_email,
|
||||
):
|
||||
if brain and brain.brain_type == BrainType.DOC:
|
||||
return KnowledgeBrainQA(
|
||||
chat_id=chat_id,
|
||||
model=model,
|
||||
max_tokens=max_tokens,
|
||||
max_input=max_input,
|
||||
temperature=temperature,
|
||||
brain_id=str(brain.brain_id),
|
||||
streaming=streaming,
|
||||
prompt_id=prompt_id,
|
||||
metadata=metadata,
|
||||
user_id=user_id,
|
||||
user_email=user_email,
|
||||
)
|
||||
|
||||
if brain.brain_type == BrainType.API:
|
||||
@ -88,18 +84,16 @@ class BrainfulChat(ChatInterface):
|
||||
return APIBrainQA(
|
||||
chat_id=chat_id,
|
||||
model=model,
|
||||
max_tokens=max_tokens,
|
||||
max_input=max_input,
|
||||
temperature=temperature,
|
||||
brain_id=str(brain.brain_id),
|
||||
streaming=streaming,
|
||||
prompt_id=prompt_id,
|
||||
user_id=user_id,
|
||||
metadata=metadata,
|
||||
raw=(brain_definition.raw if brain_definition else None),
|
||||
jq_instructions=(
|
||||
brain_definition.jq_instructions if brain_definition else None
|
||||
),
|
||||
user_email=user_email,
|
||||
)
|
||||
if brain.brain_type == BrainType.INTEGRATION:
|
||||
integration_brain = integration_brain_description_service.get_integration_description_by_user_brain_id(
|
||||
@ -113,12 +107,10 @@ class BrainfulChat(ChatInterface):
|
||||
return integration_class(
|
||||
chat_id=chat_id,
|
||||
model=model,
|
||||
max_tokens=max_tokens,
|
||||
max_input=max_input,
|
||||
temperature=temperature,
|
||||
brain_id=str(brain.brain_id),
|
||||
streaming=streaming,
|
||||
prompt_id=prompt_id,
|
||||
metadata=metadata,
|
||||
user_id=user_id,
|
||||
user_email=user_email,
|
||||
)
|
||||
|
@ -7,8 +7,8 @@ import pytest
|
||||
from fastapi import HTTPException
|
||||
from models.databases.entity import LLMModels
|
||||
from modules.chat.controller.chat.utils import (
|
||||
check_user_requests_limit,
|
||||
find_model_and_generate_metadata,
|
||||
update_user_usage,
|
||||
)
|
||||
|
||||
|
||||
@ -76,7 +76,7 @@ def test_find_model_and_generate_metadata_user_not_allowed(mock_chat_service):
|
||||
|
||||
|
||||
@patch("modules.chat.controller.chat.utils.time")
|
||||
def test_check_user_requests_limit_within_limit(mock_time):
|
||||
def test_check_update_user_usage_within_limit(mock_time):
|
||||
mock_time.strftime.return_value = "20220101"
|
||||
usage = Mock()
|
||||
usage.get_user_monthly_usage.return_value = 50
|
||||
@ -84,13 +84,13 @@ def test_check_user_requests_limit_within_limit(mock_time):
|
||||
models_settings = [{"name": "gpt-3.5-turbo", "price": 10}]
|
||||
model_name = "gpt-3.5-turbo"
|
||||
|
||||
check_user_requests_limit(usage, user_settings, models_settings, model_name)
|
||||
update_user_usage(usage, user_settings, models_settings, model_name)
|
||||
|
||||
usage.handle_increment_user_request_count.assert_called_once_with("20220101", 10)
|
||||
|
||||
|
||||
@patch("modules.chat.controller.chat.utils.time")
|
||||
def test_check_user_requests_limit_exceeds_limit(mock_time):
|
||||
def test_update_user_usage_exceeds_limit(mock_time):
|
||||
mock_time.strftime.return_value = "20220101"
|
||||
usage = Mock()
|
||||
usage.get_user_monthly_usage.return_value = 100
|
||||
@ -99,7 +99,7 @@ def test_check_user_requests_limit_exceeds_limit(mock_time):
|
||||
model_name = "gpt-3.5-turbo"
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
check_user_requests_limit(usage, user_settings, models_settings, model_name)
|
||||
update_user_usage(usage, user_settings, models_settings, model_name)
|
||||
|
||||
assert exc_info.value.status_code == 429
|
||||
assert (
|
||||
|
@ -31,44 +31,39 @@ class NullableUUID(UUID):
|
||||
|
||||
def find_model_and_generate_metadata(
|
||||
chat_id: UUID,
|
||||
brain,
|
||||
brain_model: str,
|
||||
user_settings,
|
||||
models_settings,
|
||||
metadata_brain,
|
||||
):
|
||||
# Add metadata_brain to metadata
|
||||
metadata = {}
|
||||
metadata = {**metadata, **metadata_brain}
|
||||
follow_up_questions = chat_service.get_follow_up_question(chat_id)
|
||||
metadata["follow_up_questions"] = follow_up_questions
|
||||
|
||||
# Default model is gpt-3.5-turbo-0125
|
||||
default_model = "gpt-3.5-turbo-0125"
|
||||
model_to_use = LLMModels( # TODO Implement default models in database
|
||||
name=default_model, price=1, max_input=4000, max_output=1000
|
||||
)
|
||||
|
||||
logger.info("Brain model: %s", brain.model)
|
||||
logger.debug("Brain model: %s", brain_model)
|
||||
|
||||
# If brain.model is None, set it to the default_model
|
||||
if brain.model is None:
|
||||
brain.model = default_model
|
||||
if brain_model is None:
|
||||
brain_model = default_model
|
||||
|
||||
is_brain_model_available = any(
|
||||
brain.model == model_dict.get("name") for model_dict in models_settings
|
||||
brain_model == model_dict.get("name") for model_dict in models_settings
|
||||
)
|
||||
|
||||
is_user_allowed_model = brain.model in user_settings.get(
|
||||
is_user_allowed_model = brain_model in user_settings.get(
|
||||
"models", [default_model]
|
||||
) # Checks if the model is available in the list of models
|
||||
|
||||
logger.info(f"Brain model: {brain.model}")
|
||||
logger.info(f"User models: {user_settings.get('models', [])}")
|
||||
logger.info(f"Model available: {is_brain_model_available}")
|
||||
logger.info(f"User allowed model: {is_user_allowed_model}")
|
||||
logger.debug(f"Brain model: {brain_model}")
|
||||
logger.debug(f"User models: {user_settings.get('models', [])}")
|
||||
logger.debug(f"Model available: {is_brain_model_available}")
|
||||
logger.debug(f"User allowed model: {is_user_allowed_model}")
|
||||
|
||||
if is_brain_model_available and is_user_allowed_model:
|
||||
# Use the model from the brain
|
||||
model_to_use.name = brain.model
|
||||
model_to_use.name = brain_model
|
||||
for model_dict in models_settings:
|
||||
if model_dict.get("name") == model_to_use.name:
|
||||
model_to_use.price = model_dict.get("price")
|
||||
@ -76,19 +71,12 @@ def find_model_and_generate_metadata(
|
||||
model_to_use.max_output = model_dict.get("max_output")
|
||||
break
|
||||
|
||||
metadata["model"] = model_to_use.name
|
||||
metadata["max_tokens"] = model_to_use.max_output
|
||||
metadata["max_input"] = model_to_use.max_input
|
||||
|
||||
logger.info(f"Model to use: {model_to_use}")
|
||||
logger.info(f"Metadata: {metadata}")
|
||||
|
||||
return model_to_use, metadata
|
||||
return model_to_use
|
||||
|
||||
|
||||
def check_user_requests_limit(
|
||||
usage: UserUsage, user_settings, models_settings, model_name: str
|
||||
):
|
||||
def update_user_usage(usage: UserUsage, user_settings, cost: int = 100):
|
||||
"""Checks the user requests limit.
|
||||
It checks the user requests limit and raises an exception if the user has reached the limit.
|
||||
By default, the user has a limit of 100 requests per month. The limit can be increased by upgrading the plan.
|
||||
@ -105,18 +93,13 @@ def check_user_requests_limit(
|
||||
date = time.strftime("%Y%m%d")
|
||||
|
||||
monthly_chat_credit = user_settings.get("monthly_chat_credit", 100)
|
||||
daily_user_count = usage.get_user_monthly_usage(date)
|
||||
user_choosen_model_price = 1000
|
||||
montly_usage = usage.get_user_monthly_usage(date)
|
||||
|
||||
for model_setting in models_settings:
|
||||
if model_setting["name"] == model_name:
|
||||
user_choosen_model_price = model_setting["price"]
|
||||
|
||||
if int(daily_user_count + user_choosen_model_price) > int(monthly_chat_credit):
|
||||
if int(montly_usage + cost) > int(monthly_chat_credit):
|
||||
raise HTTPException(
|
||||
status_code=429, # pyright: ignore reportPrivateUsage=none
|
||||
detail=f"You have reached your monthly chat limit of {monthly_chat_credit} requests per months. Please upgrade your plan to increase your daily chat limit.",
|
||||
)
|
||||
else:
|
||||
usage.handle_increment_user_request_count(date, user_choosen_model_price)
|
||||
usage.handle_increment_user_request_count(date, cost)
|
||||
pass
|
||||
|
@ -11,10 +11,6 @@ from models.settings import BrainSettings, get_supabase_client
|
||||
from models.user_usage import UserUsage
|
||||
from modules.brain.service.brain_service import BrainService
|
||||
from modules.chat.controller.chat.brainful_chat import BrainfulChat
|
||||
from modules.chat.controller.chat.utils import (
|
||||
check_user_requests_limit,
|
||||
find_model_and_generate_metadata,
|
||||
)
|
||||
from modules.chat.dto.chats import ChatItem, ChatQuestion
|
||||
from modules.chat.dto.inputs import (
|
||||
ChatUpdatableProperties,
|
||||
@ -76,12 +72,6 @@ def get_answer_generator(
|
||||
# Get History
|
||||
history = chat_service.get_chat_history(chat_id)
|
||||
|
||||
# Get user settings
|
||||
user_settings = user_usage.get_user_settings()
|
||||
|
||||
# Get Model settings for the user
|
||||
models_settings = user_usage.get_model_settings()
|
||||
|
||||
# Generic
|
||||
brain, metadata_brain = brain_service.find_brain_from_question(
|
||||
brain_id, chat_question.question, current_user, chat_id, history, vector_store
|
||||
@ -89,35 +79,17 @@ def get_answer_generator(
|
||||
|
||||
logger.info(f"Brain: {brain}")
|
||||
|
||||
model_to_use, metadata = find_model_and_generate_metadata(
|
||||
chat_id,
|
||||
brain,
|
||||
user_settings,
|
||||
models_settings,
|
||||
metadata_brain,
|
||||
)
|
||||
|
||||
# Raises an error if the user has consumed all of of his credits
|
||||
check_user_requests_limit(
|
||||
usage=user_usage,
|
||||
user_settings=user_settings,
|
||||
models_settings=models_settings,
|
||||
model_name=model_to_use.name,
|
||||
)
|
||||
|
||||
send_telemetry("question_asked", {"model_name": model_to_use.name})
|
||||
send_telemetry("question_asked", {"model_name": brain.model})
|
||||
|
||||
gpt_answer_generator = chat_instance.get_answer_generator(
|
||||
brain=brain,
|
||||
chat_id=str(chat_id),
|
||||
model=model_to_use.name,
|
||||
max_tokens=model_to_use.max_output,
|
||||
max_input=model_to_use.max_input,
|
||||
model=brain.model,
|
||||
temperature=0.1,
|
||||
streaming=True,
|
||||
prompt_id=chat_question.prompt_id,
|
||||
user_id=current_user.id,
|
||||
metadata=metadata,
|
||||
brain=brain,
|
||||
user_email=current_user.email,
|
||||
)
|
||||
|
||||
return gpt_answer_generator
|
||||
|
Loading…
Reference in New Issue
Block a user