feat: Add pricing calculation method to GPT4Brain class and update user usage in chat controller (#2216)

Reverts QuivrHQ/quivr#2215
Stan Girard 2024-02-19 17:29:45 -08:00 committed by GitHub
parent 874c21f7e4
commit 4edf670028
7 changed files with 104 additions and 100 deletions
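
At a glance: credit accounting moves out of the chat controller and into the brain classes. QAInterface gains an abstract calculate_pricing() hook, KnowledgeBrainQA charges the user at construction time through update_user_usage(), and the controller no longer resolves models or checks limits itself. A minimal sketch of the pattern, with KnowledgeBrainQA stripped of its Pydantic fields and the usage store reduced to a callback (only the method names come from the diff below; everything else is illustrative):

from abc import ABC, abstractmethod

class QAInterface(ABC):
    @abstractmethod
    def calculate_pricing(self) -> int:
        """Credits this brain charges per request."""

class KnowledgeBrainQA(QAInterface):
    def __init__(self, charge):
        # Billed as soon as the brain is built, mirroring increase_usage_user().
        charge(cost=self.calculate_pricing())

    def calculate_pricing(self) -> int:
        return 1  # illustrative base price; GPT4Brain returns 3 in the first hunk

KnowledgeBrainQA(charge=lambda cost: print(f"charged {cost} credit(s)"))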

View File

@@ -24,6 +24,9 @@ class GPT4Brain(KnowledgeBrainQA):
**kwargs,
)
def calculate_pricing(self):
return 3
def get_chain(self):
prompt = ChatPromptTemplate.from_messages(

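GPT4Brain now reports a flat price of 3 credits per request through the new hook. Any other brain class can override the same method; a hypothetical premium-priced subclass would look like:

# Hypothetical subclass: bills 5 credits per request instead of 3.
class PremiumBrain(KnowledgeBrainQA):
    def calculate_pricing(self) -> int:
        return 5
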
View File

@@ -8,10 +8,16 @@ from llm.utils.get_prompt_to_use import get_prompt_to_use
from llm.utils.get_prompt_to_use_id import get_prompt_to_use_id
from logger import get_logger
from models import BrainSettings
from models.user_usage import UserUsage
from modules.brain.entity.brain_entity import BrainEntity
from modules.brain.qa_interface import QAInterface
from modules.brain.rags.quivr_rag import QuivrRAG
from modules.brain.rags.rag_interface import RAGInterface
from modules.brain.service.brain_service import BrainService
from modules.chat.controller.chat.utils import (
find_model_and_generate_metadata,
update_user_usage,
)
from modules.chat.dto.chats import ChatQuestion, Sources
from modules.chat.dto.inputs import CreateChatHistory
from modules.chat.dto.outputs import GetChatHistoryOutput
@@ -116,7 +122,7 @@ class KnowledgeBrainQA(BaseModel, QAInterface):
brain_settings: BaseSettings = BrainSettings()
# Default class attributes
model: str = None # pyright: ignore reportPrivateUsage=none
model: str = "gpt-3.5-turbo-0125" # pyright: ignore reportPrivateUsage=none
temperature: float = 0.1
chat_id: str = None # pyright: ignore reportPrivateUsage=none
brain_id: str = None # pyright: ignore reportPrivateUsage=none
@@ -124,8 +130,13 @@ class KnowledgeBrainQA(BaseModel, QAInterface):
max_input: int = 2000
streaming: bool = False
knowledge_qa: Optional[RAGInterface] = None
metadata: Optional[dict] = None
brain: Optional[BrainEntity] = None
user_id: str = None
user_email: str = None
user_usage: Optional[UserUsage] = None
user_settings: Optional[dict] = None
models_settings: Optional[List[dict]] = None
metadata: Optional[dict] = None
callbacks: List[AsyncIteratorCallbackHandler] = (
None # pyright: ignore reportPrivateUsage=none
@@ -135,34 +146,43 @@ class KnowledgeBrainQA(BaseModel, QAInterface):
def __init__(
self,
model: str,
brain_id: str,
chat_id: str,
max_tokens: int,
streaming: bool = False,
prompt_id: Optional[UUID] = None,
metadata: Optional[dict] = None,
user_id: str = None,
user_email: str = None,
cost: int = 100,
**kwargs,
):
super().__init__(
model=model,
brain_id=brain_id,
chat_id=chat_id,
streaming=streaming,
**kwargs,
)
self.prompt_id = prompt_id
self.user_id = user_id
self.user_email = user_email
self.user_usage = UserUsage(
id=user_id,
email=user_email,
)
self.brain = brain_service.get_brain_by_id(brain_id)
self.user_settings = self.user_usage.get_user_settings()
# Get Model settings for the user
self.models_settings = self.user_usage.get_model_settings()
self.increase_usage_user()
self.knowledge_qa = QuivrRAG(
model=model,
model=self.brain.model,
brain_id=brain_id,
chat_id=chat_id,
streaming=streaming,
**kwargs,
)
self.metadata = metadata
self.max_tokens = max_tokens
self.user_id = user_id
@property
def prompt_to_use(self):
@@ -179,6 +199,38 @@ class KnowledgeBrainQA(BaseModel, QAInterface):
else:
return None
def increase_usage_user(self):
# Raises an error if the user has consumed all of his credits
update_user_usage(
usage=self.user_usage,
user_settings=self.user_settings,
cost=self.calculate_pricing(),
)
def calculate_pricing(self):
logger.info("Calculating pricing")
logger.info(f"Model: {self.model}")
logger.info(f"User settings: {self.user_settings}")
logger.info(f"Models settings: {self.models_settings}")
model_to_use = find_model_and_generate_metadata(
self.chat_id,
self.brain.model,
self.user_settings,
self.models_settings,
)
self.model = model_to_use.name
self.max_input = model_to_use.max_input
self.max_tokens = model_to_use.max_output
user_chosen_model_price = 1000
for model_setting in self.models_settings:
if model_setting["name"] == self.model:
user_chosen_model_price = model_setting["price"]
return user_chosen_model_price
def generate_answer(
self, chat_id: UUID, question: ChatQuestion, save_answer: bool = True
) -> GetChatHistoryOutput:
@@ -198,8 +250,6 @@ class KnowledgeBrainQA(BaseModel, QAInterface):
answer = model_response["answer"].content
brain = brain_service.get_brain_by_id(self.brain_id)
if save_answer:
# Save the answer to the database only when save_answer is set
new_chat = chat_service.update_chat_history(
@@ -208,7 +258,7 @@ class KnowledgeBrainQA(BaseModel, QAInterface):
"chat_id": chat_id,
"user_message": question.question,
"assistant": answer,
"brain_id": brain.brain_id,
"brain_id": self.brain.brain_id,
"prompt_id": self.prompt_to_use_id,
}
)
@@ -223,9 +273,9 @@ class KnowledgeBrainQA(BaseModel, QAInterface):
"prompt_title": (
self.prompt_to_use.title if self.prompt_to_use else None
),
"brain_name": brain.name if brain else None,
"brain_name": self.brain.name if self.brain else None,
"message_id": new_chat.message_id,
"brain_id": str(brain.brain_id) if brain else None,
"brain_id": str(self.brain.brain_id) if self.brain else None,
}
)
@@ -240,7 +290,7 @@ class KnowledgeBrainQA(BaseModel, QAInterface):
),
"brain_name": None,
"message_id": None,
"brain_id": str(brain.brain_id) if brain else None,
"brain_id": str(self.brain.brain_id) if self.brain else None,
}
)

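The base class's calculate_pricing above first resolves the model that will actually serve the request (via find_model_and_generate_metadata), then looks its price up in models_settings, falling back to 1000 credits for unknown models. The lookup reduces to the following standalone sketch; the shape of models_settings is taken from the tests further down:

def lookup_model_price(model_name: str, models_settings: list) -> int:
    """Per-request price for model_name; defaults to 1000 when the model is unknown."""
    for model_setting in models_settings:
        if model_setting["name"] == model_name:
            return model_setting["price"]
    return 1000

lookup_model_price("gpt-3.5-turbo", [{"name": "gpt-3.5-turbo", "price": 10}])  # -> 10
lookup_model_price("some-unlisted-model", [])                                  # -> 1000
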
View File

@@ -10,6 +10,12 @@ class QAInterface(ABC):
This can be used to implement custom answer generation logic.
"""
@abstractmethod
def calculate_pricing(self):
raise NotImplementedError(
"calculate_pricing is an abstract method and must be implemented"
)
@abstractmethod
def generate_answer(
self,

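Because calculate_pricing is declared with @abstractmethod, a concrete QAInterface subclass that omits it now fails at instantiation rather than at billing time. A hypothetical illustration, with QAInterface reduced to two abstract methods:

from abc import ABC, abstractmethod

class QAInterface(ABC):
    @abstractmethod
    def calculate_pricing(self): ...

    @abstractmethod
    def generate_answer(self, *args, **kwargs): ...

class IncompleteBrain(QAInterface):  # hypothetical: forgets calculate_pricing
    def generate_answer(self, *args, **kwargs):
        return None

IncompleteBrain()  # raises TypeError: Can't instantiate abstract class IncompleteBrain
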
View File

@@ -59,26 +59,20 @@ class BrainfulChat(ChatInterface):
brain,
chat_id,
model,
max_tokens,
max_input,
temperature,
streaming,
prompt_id,
user_id,
metadata,
user_email,
):
if brain and brain.brain_type == BrainType.DOC:
return KnowledgeBrainQA(
chat_id=chat_id,
model=model,
max_tokens=max_tokens,
max_input=max_input,
temperature=temperature,
brain_id=str(brain.brain_id),
streaming=streaming,
prompt_id=prompt_id,
metadata=metadata,
user_id=user_id,
user_email=user_email,
)
if brain.brain_type == BrainType.API:
@@ -88,18 +82,16 @@
return APIBrainQA(
chat_id=chat_id,
model=model,
max_tokens=max_tokens,
max_input=max_input,
temperature=temperature,
brain_id=str(brain.brain_id),
streaming=streaming,
prompt_id=prompt_id,
user_id=user_id,
metadata=metadata,
raw=(brain_definition.raw if brain_definition else None),
jq_instructions=(
brain_definition.jq_instructions if brain_definition else None
),
user_email=user_email,
)
if brain.brain_type == BrainType.INTEGRATION:
integration_brain = integration_brain_description_service.get_integration_description_by_user_brain_id(
@@ -113,12 +105,10 @@
return integration_class(
chat_id=chat_id,
model=model,
max_tokens=max_tokens,
max_input=max_input,
temperature=temperature,
brain_id=str(brain.brain_id),
streaming=streaming,
prompt_id=prompt_id,
metadata=metadata,
user_id=user_id,
user_email=user_email,
)

View File

@@ -7,8 +7,8 @@ import pytest
from fastapi import HTTPException
from models.databases.entity import LLMModels
from modules.chat.controller.chat.utils import (
check_user_requests_limit,
find_model_and_generate_metadata,
update_user_usage,
)
@@ -76,7 +76,7 @@ def test_find_model_and_generate_metadata_user_not_allowed(mock_chat_service):
@patch("modules.chat.controller.chat.utils.time")
def test_check_user_requests_limit_within_limit(mock_time):
def test_check_update_user_usage_within_limit(mock_time):
mock_time.strftime.return_value = "20220101"
usage = Mock()
usage.get_user_monthly_usage.return_value = 50
@@ -84,13 +84,13 @@ def test_check_user_requests_limit_within_limit(mock_time):
models_settings = [{"name": "gpt-3.5-turbo", "price": 10}]
model_name = "gpt-3.5-turbo"
check_user_requests_limit(usage, user_settings, models_settings, model_name)
update_user_usage(usage, user_settings, cost=10)
usage.handle_increment_user_request_count.assert_called_once_with("20220101", 10)
@patch("modules.chat.controller.chat.utils.time")
def test_check_user_requests_limit_exceeds_limit(mock_time):
def test_update_user_usage_exceeds_limit(mock_time):
mock_time.strftime.return_value = "20220101"
usage = Mock()
usage.get_user_monthly_usage.return_value = 100
@@ -99,7 +99,7 @@ def test_check_user_requests_limit_exceeds_limit(mock_time):
model_name = "gpt-3.5-turbo"
with pytest.raises(HTTPException) as exc_info:
check_user_requests_limit(usage, user_settings, models_settings, model_name)
update_user_usage(usage, user_settings, cost=10)
assert exc_info.value.status_code == 429
assert (

View File

@@ -31,44 +31,39 @@ class NullableUUID(UUID):
def find_model_and_generate_metadata(
chat_id: UUID,
brain,
brain_model: str,
user_settings,
models_settings,
metadata_brain,
):
# Add metadata_brain to metadata
metadata = {}
metadata = {**metadata, **metadata_brain}
follow_up_questions = chat_service.get_follow_up_question(chat_id)
metadata["follow_up_questions"] = follow_up_questions
# Default model is gpt-3.5-turbo-0125
default_model = "gpt-3.5-turbo-0125"
model_to_use = LLMModels( # TODO Implement default models in database
name=default_model, price=1, max_input=4000, max_output=1000
)
logger.info("Brain model: %s", brain.model)
logger.debug("Brain model: %s", brain_model)
# If brain_model is None, set it to the default_model
if brain.model is None:
brain.model = default_model
if brain_model is None:
brain_model = default_model
is_brain_model_available = any(
brain.model == model_dict.get("name") for model_dict in models_settings
brain_model == model_dict.get("name") for model_dict in models_settings
)
is_user_allowed_model = brain.model in user_settings.get(
is_user_allowed_model = brain_model in user_settings.get(
"models", [default_model]
) # Checks if the model is in the user's list of allowed models
logger.info(f"Brain model: {brain.model}")
logger.info(f"User models: {user_settings.get('models', [])}")
logger.info(f"Model available: {is_brain_model_available}")
logger.info(f"User allowed model: {is_user_allowed_model}")
logger.debug(f"Brain model: {brain_model}")
logger.debug(f"User models: {user_settings.get('models', [])}")
logger.debug(f"Model available: {is_brain_model_available}")
logger.debug(f"User allowed model: {is_user_allowed_model}")
if is_brain_model_available and is_user_allowed_model:
# Use the model from the brain
model_to_use.name = brain.model
model_to_use.name = brain_model
for model_dict in models_settings:
if model_dict.get("name") == model_to_use.name:
model_to_use.price = model_dict.get("price")
@@ -76,19 +71,12 @@ def find_model_and_generate_metadata(
model_to_use.max_output = model_dict.get("max_output")
break
metadata["model"] = model_to_use.name
metadata["max_tokens"] = model_to_use.max_output
metadata["max_input"] = model_to_use.max_input
logger.info(f"Model to use: {model_to_use}")
logger.info(f"Metadata: {metadata}")
return model_to_use, metadata
return model_to_use
def check_user_requests_limit(
usage: UserUsage, user_settings, models_settings, model_name: str
):
def update_user_usage(usage: UserUsage, user_settings, cost: int = 100):
"""Checks the user requests limit.
It checks the user requests limit and raises an exception if the user has reached the limit.
By default, the user has a limit of 100 requests per month. The limit can be increased by upgrading the plan.
@@ -105,18 +93,13 @@ def check_user_requests_limit(
date = time.strftime("%Y%m%d")
monthly_chat_credit = user_settings.get("monthly_chat_credit", 100)
daily_user_count = usage.get_user_monthly_usage(date)
user_choosen_model_price = 1000
monthly_usage = usage.get_user_monthly_usage(date)
for model_setting in models_settings:
if model_setting["name"] == model_name:
user_choosen_model_price = model_setting["price"]
if int(daily_user_count + user_choosen_model_price) > int(monthly_chat_credit):
if int(monthly_usage + cost) > int(monthly_chat_credit):
raise HTTPException(
status_code=429, # pyright: ignore reportPrivateUsage=none
detail=f"You have reached your monthly chat limit of {monthly_chat_credit} credits per month. Please upgrade your plan to increase your monthly limit.",
)
else:
usage.handle_increment_user_request_count(date, user_choosen_model_price)
usage.handle_increment_user_request_count(date, cost)
pass

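Worked through with the tests' numbers: with monthly_chat_credit = 100, a user who has consumed 50 credits can afford a cost-10 request (50 + 10 <= 100, so the counter is incremented), while a user already at 100 is rejected (100 + 10 > 100, HTTP 429). A standalone sketch of that check, with UserUsage reduced to a plain integer:

from fastapi import HTTPException

def charge(monthly_usage: int, cost: int, monthly_chat_credit: int = 100) -> int:
    """Return the new usage total, or raise 429 when the credit would be exceeded."""
    if monthly_usage + cost > monthly_chat_credit:
        raise HTTPException(status_code=429, detail="Monthly chat limit reached.")
    return monthly_usage + cost

charge(50, 10)     # -> 60
# charge(100, 10)  # would raise HTTPException(status_code=429)
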
View File

@@ -11,10 +11,6 @@ from models.settings import BrainSettings, get_supabase_client
from models.user_usage import UserUsage
from modules.brain.service.brain_service import BrainService
from modules.chat.controller.chat.brainful_chat import BrainfulChat
from modules.chat.controller.chat.utils import (
check_user_requests_limit,
find_model_and_generate_metadata,
)
from modules.chat.dto.chats import ChatItem, ChatQuestion
from modules.chat.dto.inputs import (
ChatUpdatableProperties,
@@ -76,12 +72,6 @@ def get_answer_generator(
# Get History
history = chat_service.get_chat_history(chat_id)
# Get user settings
user_settings = user_usage.get_user_settings()
# Get Model settings for the user
models_settings = user_usage.get_model_settings()
# Generic
brain, metadata_brain = brain_service.find_brain_from_question(
brain_id, chat_question.question, current_user, chat_id, history, vector_store
@@ -89,35 +79,17 @@
logger.info(f"Brain: {brain}")
model_to_use, metadata = find_model_and_generate_metadata(
chat_id,
brain,
user_settings,
models_settings,
metadata_brain,
)
# Raises an error if the user has consumed all of his credits
check_user_requests_limit(
usage=user_usage,
user_settings=user_settings,
models_settings=models_settings,
model_name=model_to_use.name,
)
send_telemetry("question_asked", {"model_name": model_to_use.name})
send_telemetry("question_asked", {"model_name": brain.model})
gpt_answer_generator = chat_instance.get_answer_generator(
brain=brain,
chat_id=str(chat_id),
model=model_to_use.name,
max_tokens=model_to_use.max_output,
max_input=model_to_use.max_input,
model=brain.model,
temperature=0.1,
streaming=True,
prompt_id=chat_question.prompt_id,
user_id=current_user.id,
metadata=metadata,
brain=brain,
user_email=current_user.email,
)
return gpt_answer_generator