feat: Add pricing calculation method to GPT4Brain class and update user usage in chat controller (#2210)

This pull request adds a new method called `calculate_pricing` to the
`GPT4Brain` class in the codebase. This method calculates the pricing
for the GPT4Brain model. Additionally, the user usage in the chat
controller has been updated to include the new pricing calculation
method.
This commit is contained in:
Stan Girard 2024-02-18 23:05:13 -08:00 committed by GitHub
parent aa4e85fc32
commit 2c71e0edc7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 94 additions and 87 deletions

View File

@ -24,6 +24,9 @@ class GPT4Brain(KnowledgeBrainQA):
**kwargs,
)
def calculate_pricing(self):
return 3
def get_chain(self):
prompt = ChatPromptTemplate.from_messages(

View File

@ -8,10 +8,15 @@ from llm.utils.get_prompt_to_use import get_prompt_to_use
from llm.utils.get_prompt_to_use_id import get_prompt_to_use_id
from logger import get_logger
from models import BrainSettings
from models.user_usage import UserUsage
from modules.brain.qa_interface import QAInterface
from modules.brain.rags.quivr_rag import QuivrRAG
from modules.brain.rags.rag_interface import RAGInterface
from modules.brain.service.brain_service import BrainService
from modules.chat.controller.chat.utils import (
find_model_and_generate_metadata,
update_user_usage,
)
from modules.chat.dto.chats import ChatQuestion, Sources
from modules.chat.dto.inputs import CreateChatHistory
from modules.chat.dto.outputs import GetChatHistoryOutput
@ -124,8 +129,12 @@ class KnowledgeBrainQA(BaseModel, QAInterface):
max_input: int = 2000
streaming: bool = False
knowledge_qa: Optional[RAGInterface] = None
metadata: Optional[dict] = None
user_id: str = None
user_email: str = None
user_usage: Optional[UserUsage] = None
user_settings: Optional[dict] = None
models_settings: Optional[List[dict]] = None
metadata: Optional[dict] = None
callbacks: List[AsyncIteratorCallbackHandler] = (
None # pyright: ignore reportPrivateUsage=none
@ -138,11 +147,12 @@ class KnowledgeBrainQA(BaseModel, QAInterface):
model: str,
brain_id: str,
chat_id: str,
max_tokens: int,
streaming: bool = False,
prompt_id: Optional[UUID] = None,
metadata: Optional[dict] = None,
user_id: str = None,
user_email: str = None,
cost: int = 100,
**kwargs,
):
super().__init__(
@ -160,9 +170,17 @@ class KnowledgeBrainQA(BaseModel, QAInterface):
streaming=streaming,
**kwargs,
)
self.metadata = metadata
self.max_tokens = max_tokens
self.user_id = user_id
self.user_email = user_email
self.user_usage = UserUsage(
id=user_id,
email=user_email,
)
self.user_settings = self.user_usage.get_user_settings()
# Get Model settings for the user
self.models_settings = self.user_usage.get_model_settings()
self.increase_usage_user()
@property
def prompt_to_use(self):
@ -179,6 +197,39 @@ class KnowledgeBrainQA(BaseModel, QAInterface):
else:
return None
def increase_usage_user(self):
# Raises an error if the user has consumed all of his credits
update_user_usage(
usage=self.user_usage,
user_settings=self.user_settings,
cost=self.calculate_pricing(),
)
def calculate_pricing(self):
logger.info("Calculating pricing")
logger.info(f"Model: {self.model}")
logger.info(f"User settings: {self.user_settings}")
logger.info(f"Models settings: {self.models_settings}")
model_to_use = find_model_and_generate_metadata(
self.chat_id,
self.model,
self.user_settings,
self.models_settings,
)
self.model = model_to_use.name
self.max_input = model_to_use.max_input
self.max_tokens = model_to_use.max_output
user_choosen_model_price = 1000
for model_setting in self.models_settings:
if model_setting["name"] == self.model:
user_choosen_model_price = model_setting["price"]
return user_choosen_model_price
def generate_answer(
self, chat_id: UUID, question: ChatQuestion, save_answer: bool = True
) -> GetChatHistoryOutput:

View File

@ -10,6 +10,12 @@ class QAInterface(ABC):
This can be used to implement custom answer generation logic.
"""
@abstractmethod
def calculate_pricing(self):
raise NotImplementedError(
"calculate_pricing is an abstract method and must be implemented"
)
@abstractmethod
def generate_answer(
self,

View File

@ -59,26 +59,22 @@ class BrainfulChat(ChatInterface):
brain,
chat_id,
model,
max_tokens,
max_input,
temperature,
streaming,
prompt_id,
user_id,
metadata,
user_email,
):
if brain and brain.brain_type == BrainType.DOC:
return KnowledgeBrainQA(
chat_id=chat_id,
model=model,
max_tokens=max_tokens,
max_input=max_input,
temperature=temperature,
brain_id=str(brain.brain_id),
streaming=streaming,
prompt_id=prompt_id,
metadata=metadata,
user_id=user_id,
user_email=user_email,
)
if brain.brain_type == BrainType.API:
@ -88,18 +84,16 @@ class BrainfulChat(ChatInterface):
return APIBrainQA(
chat_id=chat_id,
model=model,
max_tokens=max_tokens,
max_input=max_input,
temperature=temperature,
brain_id=str(brain.brain_id),
streaming=streaming,
prompt_id=prompt_id,
user_id=user_id,
metadata=metadata,
raw=(brain_definition.raw if brain_definition else None),
jq_instructions=(
brain_definition.jq_instructions if brain_definition else None
),
user_email=user_email,
)
if brain.brain_type == BrainType.INTEGRATION:
integration_brain = integration_brain_description_service.get_integration_description_by_user_brain_id(
@ -113,12 +107,10 @@ class BrainfulChat(ChatInterface):
return integration_class(
chat_id=chat_id,
model=model,
max_tokens=max_tokens,
max_input=max_input,
temperature=temperature,
brain_id=str(brain.brain_id),
streaming=streaming,
prompt_id=prompt_id,
metadata=metadata,
user_id=user_id,
user_email=user_email,
)

View File

@ -7,8 +7,8 @@ import pytest
from fastapi import HTTPException
from models.databases.entity import LLMModels
from modules.chat.controller.chat.utils import (
check_user_requests_limit,
find_model_and_generate_metadata,
update_user_usage,
)
@ -76,7 +76,7 @@ def test_find_model_and_generate_metadata_user_not_allowed(mock_chat_service):
@patch("modules.chat.controller.chat.utils.time")
def test_check_user_requests_limit_within_limit(mock_time):
def test_check_update_user_usage_within_limit(mock_time):
mock_time.strftime.return_value = "20220101"
usage = Mock()
usage.get_user_monthly_usage.return_value = 50
@ -84,13 +84,13 @@ def test_check_user_requests_limit_within_limit(mock_time):
models_settings = [{"name": "gpt-3.5-turbo", "price": 10}]
model_name = "gpt-3.5-turbo"
check_user_requests_limit(usage, user_settings, models_settings, model_name)
update_user_usage(usage, user_settings, models_settings, model_name)
usage.handle_increment_user_request_count.assert_called_once_with("20220101", 10)
@patch("modules.chat.controller.chat.utils.time")
def test_check_user_requests_limit_exceeds_limit(mock_time):
def test_update_user_usage_exceeds_limit(mock_time):
mock_time.strftime.return_value = "20220101"
usage = Mock()
usage.get_user_monthly_usage.return_value = 100
@ -99,7 +99,7 @@ def test_check_user_requests_limit_exceeds_limit(mock_time):
model_name = "gpt-3.5-turbo"
with pytest.raises(HTTPException) as exc_info:
check_user_requests_limit(usage, user_settings, models_settings, model_name)
update_user_usage(usage, user_settings, models_settings, model_name)
assert exc_info.value.status_code == 429
assert (

View File

@ -31,44 +31,39 @@ class NullableUUID(UUID):
def find_model_and_generate_metadata(
chat_id: UUID,
brain,
brain_model: str,
user_settings,
models_settings,
metadata_brain,
):
# Add metadata_brain to metadata
metadata = {}
metadata = {**metadata, **metadata_brain}
follow_up_questions = chat_service.get_follow_up_question(chat_id)
metadata["follow_up_questions"] = follow_up_questions
# Default model is gpt-3.5-turbo-0125
default_model = "gpt-3.5-turbo-0125"
model_to_use = LLMModels( # TODO Implement default models in database
name=default_model, price=1, max_input=4000, max_output=1000
)
logger.info("Brain model: %s", brain.model)
logger.debug("Brain model: %s", brain_model)
# If brain.model is None, set it to the default_model
if brain.model is None:
brain.model = default_model
if brain_model is None:
brain_model = default_model
is_brain_model_available = any(
brain.model == model_dict.get("name") for model_dict in models_settings
brain_model == model_dict.get("name") for model_dict in models_settings
)
is_user_allowed_model = brain.model in user_settings.get(
is_user_allowed_model = brain_model in user_settings.get(
"models", [default_model]
) # Checks if the model is available in the list of models
logger.info(f"Brain model: {brain.model}")
logger.info(f"User models: {user_settings.get('models', [])}")
logger.info(f"Model available: {is_brain_model_available}")
logger.info(f"User allowed model: {is_user_allowed_model}")
logger.debug(f"Brain model: {brain_model}")
logger.debug(f"User models: {user_settings.get('models', [])}")
logger.debug(f"Model available: {is_brain_model_available}")
logger.debug(f"User allowed model: {is_user_allowed_model}")
if is_brain_model_available and is_user_allowed_model:
# Use the model from the brain
model_to_use.name = brain.model
model_to_use.name = brain_model
for model_dict in models_settings:
if model_dict.get("name") == model_to_use.name:
model_to_use.price = model_dict.get("price")
@ -76,19 +71,12 @@ def find_model_and_generate_metadata(
model_to_use.max_output = model_dict.get("max_output")
break
metadata["model"] = model_to_use.name
metadata["max_tokens"] = model_to_use.max_output
metadata["max_input"] = model_to_use.max_input
logger.info(f"Model to use: {model_to_use}")
logger.info(f"Metadata: {metadata}")
return model_to_use, metadata
return model_to_use
def check_user_requests_limit(
usage: UserUsage, user_settings, models_settings, model_name: str
):
def update_user_usage(usage: UserUsage, user_settings, cost: int = 100):
"""Checks the user requests limit.
It checks the user requests limit and raises an exception if the user has reached the limit.
By default, the user has a limit of 100 requests per month. The limit can be increased by upgrading the plan.
@ -105,18 +93,13 @@ def check_user_requests_limit(
date = time.strftime("%Y%m%d")
monthly_chat_credit = user_settings.get("monthly_chat_credit", 100)
daily_user_count = usage.get_user_monthly_usage(date)
user_choosen_model_price = 1000
montly_usage = usage.get_user_monthly_usage(date)
for model_setting in models_settings:
if model_setting["name"] == model_name:
user_choosen_model_price = model_setting["price"]
if int(daily_user_count + user_choosen_model_price) > int(monthly_chat_credit):
if int(montly_usage + cost) > int(monthly_chat_credit):
raise HTTPException(
status_code=429, # pyright: ignore reportPrivateUsage=none
detail=f"You have reached your monthly chat limit of {monthly_chat_credit} requests per months. Please upgrade your plan to increase your daily chat limit.",
)
else:
usage.handle_increment_user_request_count(date, user_choosen_model_price)
usage.handle_increment_user_request_count(date, cost)
pass

View File

@ -11,10 +11,6 @@ from models.settings import BrainSettings, get_supabase_client
from models.user_usage import UserUsage
from modules.brain.service.brain_service import BrainService
from modules.chat.controller.chat.brainful_chat import BrainfulChat
from modules.chat.controller.chat.utils import (
check_user_requests_limit,
find_model_and_generate_metadata,
)
from modules.chat.dto.chats import ChatItem, ChatQuestion
from modules.chat.dto.inputs import (
ChatUpdatableProperties,
@ -76,12 +72,6 @@ def get_answer_generator(
# Get History
history = chat_service.get_chat_history(chat_id)
# Get user settings
user_settings = user_usage.get_user_settings()
# Get Model settings for the user
models_settings = user_usage.get_model_settings()
# Generic
brain, metadata_brain = brain_service.find_brain_from_question(
brain_id, chat_question.question, current_user, chat_id, history, vector_store
@ -89,35 +79,17 @@ def get_answer_generator(
logger.info(f"Brain: {brain}")
model_to_use, metadata = find_model_and_generate_metadata(
chat_id,
brain,
user_settings,
models_settings,
metadata_brain,
)
# Raises an error if the user has consumed all of his credits
check_user_requests_limit(
usage=user_usage,
user_settings=user_settings,
models_settings=models_settings,
model_name=model_to_use.name,
)
send_telemetry("question_asked", {"model_name": model_to_use.name})
send_telemetry("question_asked", {"model_name": brain.model})
gpt_answer_generator = chat_instance.get_answer_generator(
brain=brain,
chat_id=str(chat_id),
model=model_to_use.name,
max_tokens=model_to_use.max_output,
max_input=model_to_use.max_input,
model=brain.model,
temperature=0.1,
streaming=True,
prompt_id=chat_question.prompt_id,
user_id=current_user.id,
metadata=metadata,
brain=brain,
user_email=current_user.email,
)
return gpt_answer_generator