From 4edf670028be9611cfd1b2d931dd775c8359c3b3 Mon Sep 17 00:00:00 2001 From: Stan Girard Date: Mon, 19 Feb 2024 17:29:45 -0800 Subject: [PATCH] feat: Add pricing calculation method to GPT4Brain class and update user usage in chat controller (#2216) Reverts QuivrHQ/quivr#2215 --- .../modules/brain/integrations/GPT4/Brain.py | 3 + backend/modules/brain/knowledge_brain_qa.py | 80 +++++++++++++++---- backend/modules/brain/qa_interface.py | 6 ++ .../chat/controller/chat/brainful_chat.py | 18 +---- .../chat/controller/chat/test_utils.py | 10 +-- backend/modules/chat/controller/chat/utils.py | 51 ++++-------- .../modules/chat/controller/chat_routes.py | 36 +-------- 7 files changed, 104 insertions(+), 100 deletions(-) diff --git a/backend/modules/brain/integrations/GPT4/Brain.py b/backend/modules/brain/integrations/GPT4/Brain.py index 505b4d4f9..dea230129 100644 --- a/backend/modules/brain/integrations/GPT4/Brain.py +++ b/backend/modules/brain/integrations/GPT4/Brain.py @@ -24,6 +24,9 @@ class GPT4Brain(KnowledgeBrainQA): **kwargs, ) + def calculate_pricing(self): + return 3 + def get_chain(self): prompt = ChatPromptTemplate.from_messages( diff --git a/backend/modules/brain/knowledge_brain_qa.py b/backend/modules/brain/knowledge_brain_qa.py index f710c0c1c..3dfc1619f 100644 --- a/backend/modules/brain/knowledge_brain_qa.py +++ b/backend/modules/brain/knowledge_brain_qa.py @@ -8,10 +8,16 @@ from llm.utils.get_prompt_to_use import get_prompt_to_use from llm.utils.get_prompt_to_use_id import get_prompt_to_use_id from logger import get_logger from models import BrainSettings +from models.user_usage import UserUsage +from modules.brain.entity.brain_entity import BrainEntity from modules.brain.qa_interface import QAInterface from modules.brain.rags.quivr_rag import QuivrRAG from modules.brain.rags.rag_interface import RAGInterface from modules.brain.service.brain_service import BrainService +from modules.chat.controller.chat.utils import ( + find_model_and_generate_metadata, + update_user_usage, +) from modules.chat.dto.chats import ChatQuestion, Sources from modules.chat.dto.inputs import CreateChatHistory from modules.chat.dto.outputs import GetChatHistoryOutput @@ -116,7 +122,7 @@ class KnowledgeBrainQA(BaseModel, QAInterface): brain_settings: BaseSettings = BrainSettings() # Default class attributes - model: str = None # pyright: ignore reportPrivateUsage=none + model: str = "gpt-3.5-turbo-0125" # pyright: ignore reportPrivateUsage=none temperature: float = 0.1 chat_id: str = None # pyright: ignore reportPrivateUsage=none brain_id: str = None # pyright: ignore reportPrivateUsage=none @@ -124,8 +130,13 @@ class KnowledgeBrainQA(BaseModel, QAInterface): max_input: int = 2000 streaming: bool = False knowledge_qa: Optional[RAGInterface] = None - metadata: Optional[dict] = None + brain: Optional[BrainEntity] = None user_id: str = None + user_email: str = None + user_usage: Optional[UserUsage] = None + user_settings: Optional[dict] = None + models_settings: Optional[List[dict]] = None + metadata: Optional[dict] = None callbacks: List[AsyncIteratorCallbackHandler] = ( None # pyright: ignore reportPrivateUsage=none @@ -135,34 +146,43 @@ class KnowledgeBrainQA(BaseModel, QAInterface): def __init__( self, - model: str, brain_id: str, chat_id: str, - max_tokens: int, streaming: bool = False, prompt_id: Optional[UUID] = None, metadata: Optional[dict] = None, user_id: str = None, + user_email: str = None, + cost: int = 100, **kwargs, ): super().__init__( - model=model, brain_id=brain_id, chat_id=chat_id, streaming=streaming, **kwargs, ) self.prompt_id = prompt_id + self.user_id = user_id + self.user_email = user_email + self.user_usage = UserUsage( + id=user_id, + email=user_email, + ) + self.brain = brain_service.get_brain_by_id(brain_id) + + self.user_settings = self.user_usage.get_user_settings() + + # Get Model settings for the user + self.models_settings = self.user_usage.get_model_settings() + self.increase_usage_user() self.knowledge_qa = QuivrRAG( - model=model, + model=self.brain.model, brain_id=brain_id, chat_id=chat_id, streaming=streaming, **kwargs, ) - self.metadata = metadata - self.max_tokens = max_tokens - self.user_id = user_id @property def prompt_to_use(self): @@ -179,6 +199,38 @@ class KnowledgeBrainQA(BaseModel, QAInterface): else: return None + def increase_usage_user(self): + # Raises an error if the user has consumed all of of his credits + + update_user_usage( + usage=self.user_usage, + user_settings=self.user_settings, + cost=self.calculate_pricing(), + ) + + def calculate_pricing(self): + + logger.info("Calculating pricing") + logger.info(f"Model: {self.model}") + logger.info(f"User settings: {self.user_settings}") + logger.info(f"Models settings: {self.models_settings}") + model_to_use = find_model_and_generate_metadata( + self.chat_id, + self.brain.model, + self.user_settings, + self.models_settings, + ) + self.model = model_to_use.name + self.max_input = model_to_use.max_input + self.max_tokens = model_to_use.max_output + user_choosen_model_price = 1000 + + for model_setting in self.models_settings: + if model_setting["name"] == self.model: + user_choosen_model_price = model_setting["price"] + + return user_choosen_model_price + def generate_answer( self, chat_id: UUID, question: ChatQuestion, save_answer: bool = True ) -> GetChatHistoryOutput: @@ -198,8 +250,6 @@ class KnowledgeBrainQA(BaseModel, QAInterface): answer = model_response["answer"].content - brain = brain_service.get_brain_by_id(self.brain_id) - if save_answer: # save the answer to the database or not -> add a variable new_chat = chat_service.update_chat_history( @@ -208,7 +258,7 @@ class KnowledgeBrainQA(BaseModel, QAInterface): "chat_id": chat_id, "user_message": question.question, "assistant": answer, - "brain_id": brain.brain_id, + "brain_id": self.brain.brain_id, "prompt_id": self.prompt_to_use_id, } ) @@ -223,9 +273,9 @@ class KnowledgeBrainQA(BaseModel, QAInterface): "prompt_title": ( self.prompt_to_use.title if self.prompt_to_use else None ), - "brain_name": brain.name if brain else None, + "brain_name": self.brain.name if self.brain else None, "message_id": new_chat.message_id, - "brain_id": str(brain.brain_id) if brain else None, + "brain_id": str(self.brain.brain_id) if self.brain else None, } ) @@ -240,7 +290,7 @@ class KnowledgeBrainQA(BaseModel, QAInterface): ), "brain_name": None, "message_id": None, - "brain_id": str(brain.brain_id) if brain else None, + "brain_id": str(self.brain.brain_id) if self.brain else None, } ) diff --git a/backend/modules/brain/qa_interface.py b/backend/modules/brain/qa_interface.py index a4c18fcde..69b9f2c91 100644 --- a/backend/modules/brain/qa_interface.py +++ b/backend/modules/brain/qa_interface.py @@ -10,6 +10,12 @@ class QAInterface(ABC): This can be used to implement custom answer generation logic. """ + @abstractmethod + def calculate_pricing(self): + raise NotImplementedError( + "calculate_pricing is an abstract method and must be implemented" + ) + @abstractmethod def generate_answer( self, diff --git a/backend/modules/chat/controller/chat/brainful_chat.py b/backend/modules/chat/controller/chat/brainful_chat.py index 58b536482..eb85ba6a8 100644 --- a/backend/modules/chat/controller/chat/brainful_chat.py +++ b/backend/modules/chat/controller/chat/brainful_chat.py @@ -59,26 +59,20 @@ class BrainfulChat(ChatInterface): brain, chat_id, model, - max_tokens, - max_input, temperature, streaming, prompt_id, user_id, - metadata, + user_email, ): if brain and brain.brain_type == BrainType.DOC: return KnowledgeBrainQA( chat_id=chat_id, - model=model, - max_tokens=max_tokens, - max_input=max_input, - temperature=temperature, brain_id=str(brain.brain_id), streaming=streaming, prompt_id=prompt_id, - metadata=metadata, user_id=user_id, + user_email=user_email, ) if brain.brain_type == BrainType.API: @@ -88,18 +82,16 @@ class BrainfulChat(ChatInterface): return APIBrainQA( chat_id=chat_id, model=model, - max_tokens=max_tokens, - max_input=max_input, temperature=temperature, brain_id=str(brain.brain_id), streaming=streaming, prompt_id=prompt_id, user_id=user_id, - metadata=metadata, raw=(brain_definition.raw if brain_definition else None), jq_instructions=( brain_definition.jq_instructions if brain_definition else None ), + user_email=user_email, ) if brain.brain_type == BrainType.INTEGRATION: integration_brain = integration_brain_description_service.get_integration_description_by_user_brain_id( @@ -113,12 +105,10 @@ class BrainfulChat(ChatInterface): return integration_class( chat_id=chat_id, model=model, - max_tokens=max_tokens, - max_input=max_input, temperature=temperature, brain_id=str(brain.brain_id), streaming=streaming, prompt_id=prompt_id, - metadata=metadata, user_id=user_id, + user_email=user_email, ) diff --git a/backend/modules/chat/controller/chat/test_utils.py b/backend/modules/chat/controller/chat/test_utils.py index abbf24e57..beb66cb50 100644 --- a/backend/modules/chat/controller/chat/test_utils.py +++ b/backend/modules/chat/controller/chat/test_utils.py @@ -7,8 +7,8 @@ import pytest from fastapi import HTTPException from models.databases.entity import LLMModels from modules.chat.controller.chat.utils import ( - check_user_requests_limit, find_model_and_generate_metadata, + update_user_usage, ) @@ -76,7 +76,7 @@ def test_find_model_and_generate_metadata_user_not_allowed(mock_chat_service): @patch("modules.chat.controller.chat.utils.time") -def test_check_user_requests_limit_within_limit(mock_time): +def test_check_update_user_usage_within_limit(mock_time): mock_time.strftime.return_value = "20220101" usage = Mock() usage.get_user_monthly_usage.return_value = 50 @@ -84,13 +84,13 @@ def test_check_user_requests_limit_within_limit(mock_time): models_settings = [{"name": "gpt-3.5-turbo", "price": 10}] model_name = "gpt-3.5-turbo" - check_user_requests_limit(usage, user_settings, models_settings, model_name) + update_user_usage(usage, user_settings, models_settings, model_name) usage.handle_increment_user_request_count.assert_called_once_with("20220101", 10) @patch("modules.chat.controller.chat.utils.time") -def test_check_user_requests_limit_exceeds_limit(mock_time): +def test_update_user_usage_exceeds_limit(mock_time): mock_time.strftime.return_value = "20220101" usage = Mock() usage.get_user_monthly_usage.return_value = 100 @@ -99,7 +99,7 @@ def test_check_user_requests_limit_exceeds_limit(mock_time): model_name = "gpt-3.5-turbo" with pytest.raises(HTTPException) as exc_info: - check_user_requests_limit(usage, user_settings, models_settings, model_name) + update_user_usage(usage, user_settings, models_settings, model_name) assert exc_info.value.status_code == 429 assert ( diff --git a/backend/modules/chat/controller/chat/utils.py b/backend/modules/chat/controller/chat/utils.py index a38f6950e..7fe8eba81 100644 --- a/backend/modules/chat/controller/chat/utils.py +++ b/backend/modules/chat/controller/chat/utils.py @@ -31,44 +31,39 @@ class NullableUUID(UUID): def find_model_and_generate_metadata( chat_id: UUID, - brain, + brain_model: str, user_settings, models_settings, - metadata_brain, ): - # Add metadata_brain to metadata - metadata = {} - metadata = {**metadata, **metadata_brain} - follow_up_questions = chat_service.get_follow_up_question(chat_id) - metadata["follow_up_questions"] = follow_up_questions + # Default model is gpt-3.5-turbo-0125 default_model = "gpt-3.5-turbo-0125" model_to_use = LLMModels( # TODO Implement default models in database name=default_model, price=1, max_input=4000, max_output=1000 ) - logger.info("Brain model: %s", brain.model) + logger.debug("Brain model: %s", brain_model) # If brain.model is None, set it to the default_model - if brain.model is None: - brain.model = default_model + if brain_model is None: + brain_model = default_model is_brain_model_available = any( - brain.model == model_dict.get("name") for model_dict in models_settings + brain_model == model_dict.get("name") for model_dict in models_settings ) - is_user_allowed_model = brain.model in user_settings.get( + is_user_allowed_model = brain_model in user_settings.get( "models", [default_model] ) # Checks if the model is available in the list of models - logger.info(f"Brain model: {brain.model}") - logger.info(f"User models: {user_settings.get('models', [])}") - logger.info(f"Model available: {is_brain_model_available}") - logger.info(f"User allowed model: {is_user_allowed_model}") + logger.debug(f"Brain model: {brain_model}") + logger.debug(f"User models: {user_settings.get('models', [])}") + logger.debug(f"Model available: {is_brain_model_available}") + logger.debug(f"User allowed model: {is_user_allowed_model}") if is_brain_model_available and is_user_allowed_model: # Use the model from the brain - model_to_use.name = brain.model + model_to_use.name = brain_model for model_dict in models_settings: if model_dict.get("name") == model_to_use.name: model_to_use.price = model_dict.get("price") @@ -76,19 +71,12 @@ def find_model_and_generate_metadata( model_to_use.max_output = model_dict.get("max_output") break - metadata["model"] = model_to_use.name - metadata["max_tokens"] = model_to_use.max_output - metadata["max_input"] = model_to_use.max_input - logger.info(f"Model to use: {model_to_use}") - logger.info(f"Metadata: {metadata}") - return model_to_use, metadata + return model_to_use -def check_user_requests_limit( - usage: UserUsage, user_settings, models_settings, model_name: str -): +def update_user_usage(usage: UserUsage, user_settings, cost: int = 100): """Checks the user requests limit. It checks the user requests limit and raises an exception if the user has reached the limit. By default, the user has a limit of 100 requests per month. The limit can be increased by upgrading the plan. @@ -105,18 +93,13 @@ def check_user_requests_limit( date = time.strftime("%Y%m%d") monthly_chat_credit = user_settings.get("monthly_chat_credit", 100) - daily_user_count = usage.get_user_monthly_usage(date) - user_choosen_model_price = 1000 + montly_usage = usage.get_user_monthly_usage(date) - for model_setting in models_settings: - if model_setting["name"] == model_name: - user_choosen_model_price = model_setting["price"] - - if int(daily_user_count + user_choosen_model_price) > int(monthly_chat_credit): + if int(montly_usage + cost) > int(monthly_chat_credit): raise HTTPException( status_code=429, # pyright: ignore reportPrivateUsage=none detail=f"You have reached your monthly chat limit of {monthly_chat_credit} requests per months. Please upgrade your plan to increase your daily chat limit.", ) else: - usage.handle_increment_user_request_count(date, user_choosen_model_price) + usage.handle_increment_user_request_count(date, cost) pass diff --git a/backend/modules/chat/controller/chat_routes.py b/backend/modules/chat/controller/chat_routes.py index 2b266cba8..c4857b873 100644 --- a/backend/modules/chat/controller/chat_routes.py +++ b/backend/modules/chat/controller/chat_routes.py @@ -11,10 +11,6 @@ from models.settings import BrainSettings, get_supabase_client from models.user_usage import UserUsage from modules.brain.service.brain_service import BrainService from modules.chat.controller.chat.brainful_chat import BrainfulChat -from modules.chat.controller.chat.utils import ( - check_user_requests_limit, - find_model_and_generate_metadata, -) from modules.chat.dto.chats import ChatItem, ChatQuestion from modules.chat.dto.inputs import ( ChatUpdatableProperties, @@ -76,12 +72,6 @@ def get_answer_generator( # Get History history = chat_service.get_chat_history(chat_id) - # Get user settings - user_settings = user_usage.get_user_settings() - - # Get Model settings for the user - models_settings = user_usage.get_model_settings() - # Generic brain, metadata_brain = brain_service.find_brain_from_question( brain_id, chat_question.question, current_user, chat_id, history, vector_store @@ -89,35 +79,17 @@ def get_answer_generator( logger.info(f"Brain: {brain}") - model_to_use, metadata = find_model_and_generate_metadata( - chat_id, - brain, - user_settings, - models_settings, - metadata_brain, - ) - - # Raises an error if the user has consumed all of of his credits - check_user_requests_limit( - usage=user_usage, - user_settings=user_settings, - models_settings=models_settings, - model_name=model_to_use.name, - ) - - send_telemetry("question_asked", {"model_name": model_to_use.name}) + send_telemetry("question_asked", {"model_name": brain.model}) gpt_answer_generator = chat_instance.get_answer_generator( + brain=brain, chat_id=str(chat_id), - model=model_to_use.name, - max_tokens=model_to_use.max_output, - max_input=model_to_use.max_input, + model=brain.model, temperature=0.1, streaming=True, prompt_id=chat_question.prompt_id, user_id=current_user.id, - metadata=metadata, - brain=brain, + user_email=current_user.email, ) return gpt_answer_generator