mirror of
https://github.com/StanGirard/quivr.git
synced 2025-01-05 18:57:48 +03:00
feat: Add pricing calculation method to GPT4Brain class and update user usage in chat controller (#2216)
Reverts QuivrHQ/quivr#2215
This commit is contained in:
parent
874c21f7e4
commit
4edf670028
@ -24,6 +24,9 @@ class GPT4Brain(KnowledgeBrainQA):
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def calculate_pricing(self):
|
||||
return 3
|
||||
|
||||
def get_chain(self):
|
||||
|
||||
prompt = ChatPromptTemplate.from_messages(
|
||||
|
@ -8,10 +8,16 @@ from llm.utils.get_prompt_to_use import get_prompt_to_use
|
||||
from llm.utils.get_prompt_to_use_id import get_prompt_to_use_id
|
||||
from logger import get_logger
|
||||
from models import BrainSettings
|
||||
from models.user_usage import UserUsage
|
||||
from modules.brain.entity.brain_entity import BrainEntity
|
||||
from modules.brain.qa_interface import QAInterface
|
||||
from modules.brain.rags.quivr_rag import QuivrRAG
|
||||
from modules.brain.rags.rag_interface import RAGInterface
|
||||
from modules.brain.service.brain_service import BrainService
|
||||
from modules.chat.controller.chat.utils import (
|
||||
find_model_and_generate_metadata,
|
||||
update_user_usage,
|
||||
)
|
||||
from modules.chat.dto.chats import ChatQuestion, Sources
|
||||
from modules.chat.dto.inputs import CreateChatHistory
|
||||
from modules.chat.dto.outputs import GetChatHistoryOutput
|
||||
@ -116,7 +122,7 @@ class KnowledgeBrainQA(BaseModel, QAInterface):
|
||||
brain_settings: BaseSettings = BrainSettings()
|
||||
|
||||
# Default class attributes
|
||||
model: str = None # pyright: ignore reportPrivateUsage=none
|
||||
model: str = "gpt-3.5-turbo-0125" # pyright: ignore reportPrivateUsage=none
|
||||
temperature: float = 0.1
|
||||
chat_id: str = None # pyright: ignore reportPrivateUsage=none
|
||||
brain_id: str = None # pyright: ignore reportPrivateUsage=none
|
||||
@ -124,8 +130,13 @@ class KnowledgeBrainQA(BaseModel, QAInterface):
|
||||
max_input: int = 2000
|
||||
streaming: bool = False
|
||||
knowledge_qa: Optional[RAGInterface] = None
|
||||
metadata: Optional[dict] = None
|
||||
brain: Optional[BrainEntity] = None
|
||||
user_id: str = None
|
||||
user_email: str = None
|
||||
user_usage: Optional[UserUsage] = None
|
||||
user_settings: Optional[dict] = None
|
||||
models_settings: Optional[List[dict]] = None
|
||||
metadata: Optional[dict] = None
|
||||
|
||||
callbacks: List[AsyncIteratorCallbackHandler] = (
|
||||
None # pyright: ignore reportPrivateUsage=none
|
||||
@ -135,34 +146,43 @@ class KnowledgeBrainQA(BaseModel, QAInterface):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: str,
|
||||
brain_id: str,
|
||||
chat_id: str,
|
||||
max_tokens: int,
|
||||
streaming: bool = False,
|
||||
prompt_id: Optional[UUID] = None,
|
||||
metadata: Optional[dict] = None,
|
||||
user_id: str = None,
|
||||
user_email: str = None,
|
||||
cost: int = 100,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(
|
||||
model=model,
|
||||
brain_id=brain_id,
|
||||
chat_id=chat_id,
|
||||
streaming=streaming,
|
||||
**kwargs,
|
||||
)
|
||||
self.prompt_id = prompt_id
|
||||
self.user_id = user_id
|
||||
self.user_email = user_email
|
||||
self.user_usage = UserUsage(
|
||||
id=user_id,
|
||||
email=user_email,
|
||||
)
|
||||
self.brain = brain_service.get_brain_by_id(brain_id)
|
||||
|
||||
self.user_settings = self.user_usage.get_user_settings()
|
||||
|
||||
# Get Model settings for the user
|
||||
self.models_settings = self.user_usage.get_model_settings()
|
||||
self.increase_usage_user()
|
||||
self.knowledge_qa = QuivrRAG(
|
||||
model=model,
|
||||
model=self.brain.model,
|
||||
brain_id=brain_id,
|
||||
chat_id=chat_id,
|
||||
streaming=streaming,
|
||||
**kwargs,
|
||||
)
|
||||
self.metadata = metadata
|
||||
self.max_tokens = max_tokens
|
||||
self.user_id = user_id
|
||||
|
||||
@property
|
||||
def prompt_to_use(self):
|
||||
@ -179,6 +199,38 @@ class KnowledgeBrainQA(BaseModel, QAInterface):
|
||||
else:
|
||||
return None
|
||||
|
||||
def increase_usage_user(self):
|
||||
# Raises an error if the user has consumed all of of his credits
|
||||
|
||||
update_user_usage(
|
||||
usage=self.user_usage,
|
||||
user_settings=self.user_settings,
|
||||
cost=self.calculate_pricing(),
|
||||
)
|
||||
|
||||
def calculate_pricing(self):
|
||||
|
||||
logger.info("Calculating pricing")
|
||||
logger.info(f"Model: {self.model}")
|
||||
logger.info(f"User settings: {self.user_settings}")
|
||||
logger.info(f"Models settings: {self.models_settings}")
|
||||
model_to_use = find_model_and_generate_metadata(
|
||||
self.chat_id,
|
||||
self.brain.model,
|
||||
self.user_settings,
|
||||
self.models_settings,
|
||||
)
|
||||
self.model = model_to_use.name
|
||||
self.max_input = model_to_use.max_input
|
||||
self.max_tokens = model_to_use.max_output
|
||||
user_choosen_model_price = 1000
|
||||
|
||||
for model_setting in self.models_settings:
|
||||
if model_setting["name"] == self.model:
|
||||
user_choosen_model_price = model_setting["price"]
|
||||
|
||||
return user_choosen_model_price
|
||||
|
||||
def generate_answer(
|
||||
self, chat_id: UUID, question: ChatQuestion, save_answer: bool = True
|
||||
) -> GetChatHistoryOutput:
|
||||
@ -198,8 +250,6 @@ class KnowledgeBrainQA(BaseModel, QAInterface):
|
||||
|
||||
answer = model_response["answer"].content
|
||||
|
||||
brain = brain_service.get_brain_by_id(self.brain_id)
|
||||
|
||||
if save_answer:
|
||||
# save the answer to the database or not -> add a variable
|
||||
new_chat = chat_service.update_chat_history(
|
||||
@ -208,7 +258,7 @@ class KnowledgeBrainQA(BaseModel, QAInterface):
|
||||
"chat_id": chat_id,
|
||||
"user_message": question.question,
|
||||
"assistant": answer,
|
||||
"brain_id": brain.brain_id,
|
||||
"brain_id": self.brain.brain_id,
|
||||
"prompt_id": self.prompt_to_use_id,
|
||||
}
|
||||
)
|
||||
@ -223,9 +273,9 @@ class KnowledgeBrainQA(BaseModel, QAInterface):
|
||||
"prompt_title": (
|
||||
self.prompt_to_use.title if self.prompt_to_use else None
|
||||
),
|
||||
"brain_name": brain.name if brain else None,
|
||||
"brain_name": self.brain.name if self.brain else None,
|
||||
"message_id": new_chat.message_id,
|
||||
"brain_id": str(brain.brain_id) if brain else None,
|
||||
"brain_id": str(self.brain.brain_id) if self.brain else None,
|
||||
}
|
||||
)
|
||||
|
||||
@ -240,7 +290,7 @@ class KnowledgeBrainQA(BaseModel, QAInterface):
|
||||
),
|
||||
"brain_name": None,
|
||||
"message_id": None,
|
||||
"brain_id": str(brain.brain_id) if brain else None,
|
||||
"brain_id": str(self.brain.brain_id) if self.brain else None,
|
||||
}
|
||||
)
|
||||
|
||||
|
@ -10,6 +10,12 @@ class QAInterface(ABC):
|
||||
This can be used to implement custom answer generation logic.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def calculate_pricing(self):
|
||||
raise NotImplementedError(
|
||||
"calculate_pricing is an abstract method and must be implemented"
|
||||
)
|
||||
|
||||
@abstractmethod
|
||||
def generate_answer(
|
||||
self,
|
||||
|
@ -59,26 +59,20 @@ class BrainfulChat(ChatInterface):
|
||||
brain,
|
||||
chat_id,
|
||||
model,
|
||||
max_tokens,
|
||||
max_input,
|
||||
temperature,
|
||||
streaming,
|
||||
prompt_id,
|
||||
user_id,
|
||||
metadata,
|
||||
user_email,
|
||||
):
|
||||
if brain and brain.brain_type == BrainType.DOC:
|
||||
return KnowledgeBrainQA(
|
||||
chat_id=chat_id,
|
||||
model=model,
|
||||
max_tokens=max_tokens,
|
||||
max_input=max_input,
|
||||
temperature=temperature,
|
||||
brain_id=str(brain.brain_id),
|
||||
streaming=streaming,
|
||||
prompt_id=prompt_id,
|
||||
metadata=metadata,
|
||||
user_id=user_id,
|
||||
user_email=user_email,
|
||||
)
|
||||
|
||||
if brain.brain_type == BrainType.API:
|
||||
@ -88,18 +82,16 @@ class BrainfulChat(ChatInterface):
|
||||
return APIBrainQA(
|
||||
chat_id=chat_id,
|
||||
model=model,
|
||||
max_tokens=max_tokens,
|
||||
max_input=max_input,
|
||||
temperature=temperature,
|
||||
brain_id=str(brain.brain_id),
|
||||
streaming=streaming,
|
||||
prompt_id=prompt_id,
|
||||
user_id=user_id,
|
||||
metadata=metadata,
|
||||
raw=(brain_definition.raw if brain_definition else None),
|
||||
jq_instructions=(
|
||||
brain_definition.jq_instructions if brain_definition else None
|
||||
),
|
||||
user_email=user_email,
|
||||
)
|
||||
if brain.brain_type == BrainType.INTEGRATION:
|
||||
integration_brain = integration_brain_description_service.get_integration_description_by_user_brain_id(
|
||||
@ -113,12 +105,10 @@ class BrainfulChat(ChatInterface):
|
||||
return integration_class(
|
||||
chat_id=chat_id,
|
||||
model=model,
|
||||
max_tokens=max_tokens,
|
||||
max_input=max_input,
|
||||
temperature=temperature,
|
||||
brain_id=str(brain.brain_id),
|
||||
streaming=streaming,
|
||||
prompt_id=prompt_id,
|
||||
metadata=metadata,
|
||||
user_id=user_id,
|
||||
user_email=user_email,
|
||||
)
|
||||
|
@ -7,8 +7,8 @@ import pytest
|
||||
from fastapi import HTTPException
|
||||
from models.databases.entity import LLMModels
|
||||
from modules.chat.controller.chat.utils import (
|
||||
check_user_requests_limit,
|
||||
find_model_and_generate_metadata,
|
||||
update_user_usage,
|
||||
)
|
||||
|
||||
|
||||
@ -76,7 +76,7 @@ def test_find_model_and_generate_metadata_user_not_allowed(mock_chat_service):
|
||||
|
||||
|
||||
@patch("modules.chat.controller.chat.utils.time")
|
||||
def test_check_user_requests_limit_within_limit(mock_time):
|
||||
def test_check_update_user_usage_within_limit(mock_time):
|
||||
mock_time.strftime.return_value = "20220101"
|
||||
usage = Mock()
|
||||
usage.get_user_monthly_usage.return_value = 50
|
||||
@ -84,13 +84,13 @@ def test_check_user_requests_limit_within_limit(mock_time):
|
||||
models_settings = [{"name": "gpt-3.5-turbo", "price": 10}]
|
||||
model_name = "gpt-3.5-turbo"
|
||||
|
||||
check_user_requests_limit(usage, user_settings, models_settings, model_name)
|
||||
update_user_usage(usage, user_settings, models_settings, model_name)
|
||||
|
||||
usage.handle_increment_user_request_count.assert_called_once_with("20220101", 10)
|
||||
|
||||
|
||||
@patch("modules.chat.controller.chat.utils.time")
|
||||
def test_check_user_requests_limit_exceeds_limit(mock_time):
|
||||
def test_update_user_usage_exceeds_limit(mock_time):
|
||||
mock_time.strftime.return_value = "20220101"
|
||||
usage = Mock()
|
||||
usage.get_user_monthly_usage.return_value = 100
|
||||
@ -99,7 +99,7 @@ def test_check_user_requests_limit_exceeds_limit(mock_time):
|
||||
model_name = "gpt-3.5-turbo"
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
check_user_requests_limit(usage, user_settings, models_settings, model_name)
|
||||
update_user_usage(usage, user_settings, models_settings, model_name)
|
||||
|
||||
assert exc_info.value.status_code == 429
|
||||
assert (
|
||||
|
@ -31,44 +31,39 @@ class NullableUUID(UUID):
|
||||
|
||||
def find_model_and_generate_metadata(
|
||||
chat_id: UUID,
|
||||
brain,
|
||||
brain_model: str,
|
||||
user_settings,
|
||||
models_settings,
|
||||
metadata_brain,
|
||||
):
|
||||
# Add metadata_brain to metadata
|
||||
metadata = {}
|
||||
metadata = {**metadata, **metadata_brain}
|
||||
follow_up_questions = chat_service.get_follow_up_question(chat_id)
|
||||
metadata["follow_up_questions"] = follow_up_questions
|
||||
|
||||
# Default model is gpt-3.5-turbo-0125
|
||||
default_model = "gpt-3.5-turbo-0125"
|
||||
model_to_use = LLMModels( # TODO Implement default models in database
|
||||
name=default_model, price=1, max_input=4000, max_output=1000
|
||||
)
|
||||
|
||||
logger.info("Brain model: %s", brain.model)
|
||||
logger.debug("Brain model: %s", brain_model)
|
||||
|
||||
# If brain.model is None, set it to the default_model
|
||||
if brain.model is None:
|
||||
brain.model = default_model
|
||||
if brain_model is None:
|
||||
brain_model = default_model
|
||||
|
||||
is_brain_model_available = any(
|
||||
brain.model == model_dict.get("name") for model_dict in models_settings
|
||||
brain_model == model_dict.get("name") for model_dict in models_settings
|
||||
)
|
||||
|
||||
is_user_allowed_model = brain.model in user_settings.get(
|
||||
is_user_allowed_model = brain_model in user_settings.get(
|
||||
"models", [default_model]
|
||||
) # Checks if the model is available in the list of models
|
||||
|
||||
logger.info(f"Brain model: {brain.model}")
|
||||
logger.info(f"User models: {user_settings.get('models', [])}")
|
||||
logger.info(f"Model available: {is_brain_model_available}")
|
||||
logger.info(f"User allowed model: {is_user_allowed_model}")
|
||||
logger.debug(f"Brain model: {brain_model}")
|
||||
logger.debug(f"User models: {user_settings.get('models', [])}")
|
||||
logger.debug(f"Model available: {is_brain_model_available}")
|
||||
logger.debug(f"User allowed model: {is_user_allowed_model}")
|
||||
|
||||
if is_brain_model_available and is_user_allowed_model:
|
||||
# Use the model from the brain
|
||||
model_to_use.name = brain.model
|
||||
model_to_use.name = brain_model
|
||||
for model_dict in models_settings:
|
||||
if model_dict.get("name") == model_to_use.name:
|
||||
model_to_use.price = model_dict.get("price")
|
||||
@ -76,19 +71,12 @@ def find_model_and_generate_metadata(
|
||||
model_to_use.max_output = model_dict.get("max_output")
|
||||
break
|
||||
|
||||
metadata["model"] = model_to_use.name
|
||||
metadata["max_tokens"] = model_to_use.max_output
|
||||
metadata["max_input"] = model_to_use.max_input
|
||||
|
||||
logger.info(f"Model to use: {model_to_use}")
|
||||
logger.info(f"Metadata: {metadata}")
|
||||
|
||||
return model_to_use, metadata
|
||||
return model_to_use
|
||||
|
||||
|
||||
def check_user_requests_limit(
|
||||
usage: UserUsage, user_settings, models_settings, model_name: str
|
||||
):
|
||||
def update_user_usage(usage: UserUsage, user_settings, cost: int = 100):
|
||||
"""Checks the user requests limit.
|
||||
It checks the user requests limit and raises an exception if the user has reached the limit.
|
||||
By default, the user has a limit of 100 requests per month. The limit can be increased by upgrading the plan.
|
||||
@ -105,18 +93,13 @@ def check_user_requests_limit(
|
||||
date = time.strftime("%Y%m%d")
|
||||
|
||||
monthly_chat_credit = user_settings.get("monthly_chat_credit", 100)
|
||||
daily_user_count = usage.get_user_monthly_usage(date)
|
||||
user_choosen_model_price = 1000
|
||||
montly_usage = usage.get_user_monthly_usage(date)
|
||||
|
||||
for model_setting in models_settings:
|
||||
if model_setting["name"] == model_name:
|
||||
user_choosen_model_price = model_setting["price"]
|
||||
|
||||
if int(daily_user_count + user_choosen_model_price) > int(monthly_chat_credit):
|
||||
if int(montly_usage + cost) > int(monthly_chat_credit):
|
||||
raise HTTPException(
|
||||
status_code=429, # pyright: ignore reportPrivateUsage=none
|
||||
detail=f"You have reached your monthly chat limit of {monthly_chat_credit} requests per months. Please upgrade your plan to increase your daily chat limit.",
|
||||
)
|
||||
else:
|
||||
usage.handle_increment_user_request_count(date, user_choosen_model_price)
|
||||
usage.handle_increment_user_request_count(date, cost)
|
||||
pass
|
||||
|
@ -11,10 +11,6 @@ from models.settings import BrainSettings, get_supabase_client
|
||||
from models.user_usage import UserUsage
|
||||
from modules.brain.service.brain_service import BrainService
|
||||
from modules.chat.controller.chat.brainful_chat import BrainfulChat
|
||||
from modules.chat.controller.chat.utils import (
|
||||
check_user_requests_limit,
|
||||
find_model_and_generate_metadata,
|
||||
)
|
||||
from modules.chat.dto.chats import ChatItem, ChatQuestion
|
||||
from modules.chat.dto.inputs import (
|
||||
ChatUpdatableProperties,
|
||||
@ -76,12 +72,6 @@ def get_answer_generator(
|
||||
# Get History
|
||||
history = chat_service.get_chat_history(chat_id)
|
||||
|
||||
# Get user settings
|
||||
user_settings = user_usage.get_user_settings()
|
||||
|
||||
# Get Model settings for the user
|
||||
models_settings = user_usage.get_model_settings()
|
||||
|
||||
# Generic
|
||||
brain, metadata_brain = brain_service.find_brain_from_question(
|
||||
brain_id, chat_question.question, current_user, chat_id, history, vector_store
|
||||
@ -89,35 +79,17 @@ def get_answer_generator(
|
||||
|
||||
logger.info(f"Brain: {brain}")
|
||||
|
||||
model_to_use, metadata = find_model_and_generate_metadata(
|
||||
chat_id,
|
||||
brain,
|
||||
user_settings,
|
||||
models_settings,
|
||||
metadata_brain,
|
||||
)
|
||||
|
||||
# Raises an error if the user has consumed all of of his credits
|
||||
check_user_requests_limit(
|
||||
usage=user_usage,
|
||||
user_settings=user_settings,
|
||||
models_settings=models_settings,
|
||||
model_name=model_to_use.name,
|
||||
)
|
||||
|
||||
send_telemetry("question_asked", {"model_name": model_to_use.name})
|
||||
send_telemetry("question_asked", {"model_name": brain.model})
|
||||
|
||||
gpt_answer_generator = chat_instance.get_answer_generator(
|
||||
brain=brain,
|
||||
chat_id=str(chat_id),
|
||||
model=model_to_use.name,
|
||||
max_tokens=model_to_use.max_output,
|
||||
max_input=model_to_use.max_input,
|
||||
model=brain.model,
|
||||
temperature=0.1,
|
||||
streaming=True,
|
||||
prompt_id=chat_question.prompt_id,
|
||||
user_id=current_user.id,
|
||||
metadata=metadata,
|
||||
brain=brain,
|
||||
user_email=current_user.email,
|
||||
)
|
||||
|
||||
return gpt_answer_generator
|
||||
|
Loading…
Reference in New Issue
Block a user