mirror of
https://github.com/QuivrHQ/quivr.git
synced 2025-01-05 23:03:53 +03:00
feat: 🎸 user-limits (#2104)
optimized number of calls # Description Please include a summary of the changes and the related issue. Please also include relevant motivation and context. ## Checklist before requesting a review Please delete options that are not relevant. - [ ] My code follows the style guidelines of this project - [ ] I have performed a self-review of my code - [ ] I have commented hard-to-understand areas - [ ] I have ideally added tests that prove my fix is effective or that my feature works - [ ] New and existing unit tests pass locally with my changes - [ ] Any dependent changes have been merged ## Screenshots (if appropriate):
This commit is contained in:
parent
bc5545e4cf
commit
2e06b5c7f2
@ -288,16 +288,12 @@ class CompositeBrainQA(
|
|||||||
question=function_args["question"], brain_id=function_name
|
question=function_args["question"], brain_id=function_name
|
||||||
)
|
)
|
||||||
|
|
||||||
print("querying brain", function_name)
|
|
||||||
# TODO: extract chat_id from generate_answer function of XBrainQA
|
# TODO: extract chat_id from generate_answer function of XBrainQA
|
||||||
function_response = function_to_call(
|
function_response = function_to_call(
|
||||||
chat_id=chat_id,
|
chat_id=chat_id,
|
||||||
question=question,
|
question=question,
|
||||||
save_answer=False,
|
save_answer=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
print("brain_answer", function_response.assistant)
|
|
||||||
|
|
||||||
messages.append(
|
messages.append(
|
||||||
{
|
{
|
||||||
"tool_call_id": tool_call.id,
|
"tool_call_id": tool_call.id,
|
||||||
|
@ -0,0 +1,72 @@
|
|||||||
|
import uuid
|
||||||
|
from datetime import datetime
|
||||||
|
from unittest.mock import create_autospec
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from modules.brain.dto.inputs import CreateBrainProperties
|
||||||
|
from modules.brain.entity.brain_entity import BrainEntity, BrainType
|
||||||
|
from modules.brain.repository.interfaces.brains_interface import BrainsInterface
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_brains_interface():
|
||||||
|
return create_autospec(BrainsInterface)
|
||||||
|
|
||||||
|
|
||||||
|
def test_create_brain(mock_brains_interface):
|
||||||
|
brain = CreateBrainProperties()
|
||||||
|
mock_brains_interface.create_brain.return_value = BrainEntity(
|
||||||
|
brain_id=uuid.uuid4(), # generate a valid UUID
|
||||||
|
name="test_name",
|
||||||
|
last_update=datetime.now().isoformat(), # convert datetime to string
|
||||||
|
brain_type=BrainType.DOC,
|
||||||
|
)
|
||||||
|
result = mock_brains_interface.create_brain(brain)
|
||||||
|
mock_brains_interface.create_brain.assert_called_once_with(brain)
|
||||||
|
assert isinstance(result, BrainEntity)
|
||||||
|
|
||||||
|
|
||||||
|
def test_brain_entity_creation():
|
||||||
|
brain_id = uuid.uuid4()
|
||||||
|
name = "test_name"
|
||||||
|
last_update = datetime.now().isoformat()
|
||||||
|
brain_type = BrainType.DOC
|
||||||
|
|
||||||
|
brain_entity = BrainEntity(
|
||||||
|
brain_id=brain_id, name=name, last_update=last_update, brain_type=brain_type
|
||||||
|
)
|
||||||
|
|
||||||
|
assert brain_entity.brain_id == brain_id
|
||||||
|
assert brain_entity.name == name
|
||||||
|
assert brain_entity.last_update == last_update
|
||||||
|
assert brain_entity.brain_type == brain_type
|
||||||
|
|
||||||
|
|
||||||
|
def test_brain_entity_id_property():
|
||||||
|
brain_id = uuid.uuid4()
|
||||||
|
name = "test_name"
|
||||||
|
last_update = datetime.now().isoformat()
|
||||||
|
brain_type = BrainType.DOC
|
||||||
|
|
||||||
|
brain_entity = BrainEntity(
|
||||||
|
brain_id=brain_id, name=name, last_update=last_update, brain_type=brain_type
|
||||||
|
)
|
||||||
|
|
||||||
|
assert brain_entity.id == brain_id
|
||||||
|
|
||||||
|
|
||||||
|
def test_brain_entity_dict_method():
|
||||||
|
brain_id = uuid.uuid4()
|
||||||
|
name = "test_name"
|
||||||
|
last_update = datetime.now().isoformat()
|
||||||
|
brain_type = BrainType.DOC
|
||||||
|
|
||||||
|
brain_entity = BrainEntity(
|
||||||
|
brain_id=brain_id, name=name, last_update=last_update, brain_type=brain_type
|
||||||
|
)
|
||||||
|
|
||||||
|
brain_dict = brain_entity.dict()
|
||||||
|
assert brain_dict["id"] == brain_id
|
||||||
|
assert brain_dict["name"] == name
|
||||||
|
assert brain_dict["last_update"] == last_update
|
||||||
|
assert brain_dict["brain_type"] == brain_type
|
@ -4,7 +4,7 @@ from uuid import UUID
|
|||||||
from fastapi import HTTPException
|
from fastapi import HTTPException
|
||||||
from langchain.embeddings.ollama import OllamaEmbeddings
|
from langchain.embeddings.ollama import OllamaEmbeddings
|
||||||
from langchain.embeddings.openai import OpenAIEmbeddings
|
from langchain.embeddings.openai import OpenAIEmbeddings
|
||||||
from vectorstore.supabase import CustomSupabaseVectorStore
|
from logger import get_logger
|
||||||
from models.settings import BrainSettings, get_supabase_client
|
from models.settings import BrainSettings, get_supabase_client
|
||||||
from modules.brain.dto.inputs import BrainUpdatableProperties, CreateBrainProperties
|
from modules.brain.dto.inputs import BrainUpdatableProperties, CreateBrainProperties
|
||||||
from modules.brain.entity.brain_entity import BrainEntity, BrainType, PublicBrain
|
from modules.brain.entity.brain_entity import BrainEntity, BrainType, PublicBrain
|
||||||
@ -25,8 +25,7 @@ from modules.brain.repository.interfaces import (
|
|||||||
from modules.brain.service.api_brain_definition_service import ApiBrainDefinitionService
|
from modules.brain.service.api_brain_definition_service import ApiBrainDefinitionService
|
||||||
from modules.brain.service.utils.validate_brain import validate_api_brain
|
from modules.brain.service.utils.validate_brain import validate_api_brain
|
||||||
from modules.knowledge.service.knowledge_service import KnowledgeService
|
from modules.knowledge.service.knowledge_service import KnowledgeService
|
||||||
|
from vectorstore.supabase import CustomSupabaseVectorStore
|
||||||
from logger import get_logger
|
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
@ -84,6 +83,7 @@ class BrainService:
|
|||||||
# Init
|
# Init
|
||||||
|
|
||||||
brain_id_to_use = brain_id
|
brain_id_to_use = brain_id
|
||||||
|
brain_to_use = None
|
||||||
|
|
||||||
# Get the first question from the chat_question
|
# Get the first question from the chat_question
|
||||||
|
|
||||||
@ -97,6 +97,7 @@ class BrainService:
|
|||||||
|
|
||||||
if history and not brain_id:
|
if history and not brain_id:
|
||||||
brain_id_to_use = history[0].brain_id
|
brain_id_to_use = history[0].brain_id
|
||||||
|
brain_to_use = self.get_brain_by_id(brain_id_to_use)
|
||||||
|
|
||||||
# Calculate the closest brains to the question
|
# Calculate the closest brains to the question
|
||||||
list_brains = vector_store.find_brain_closest_query(user.id, question)
|
list_brains = vector_store.find_brain_closest_query(user.id, question)
|
||||||
@ -111,10 +112,11 @@ class BrainService:
|
|||||||
|
|
||||||
metadata["close_brains"] = unique_list_brains[:5]
|
metadata["close_brains"] = unique_list_brains[:5]
|
||||||
|
|
||||||
if list_brains and not brain_id_to_use:
|
if list_brains and not brain_to_use:
|
||||||
brain_id_to_use = list_brains[0]["id"]
|
brain_id_to_use = list_brains[0]["id"]
|
||||||
|
brain_to_use = self.get_brain_by_id(brain_id_to_use)
|
||||||
|
|
||||||
return brain_id_to_use, metadata
|
return brain_to_use, metadata
|
||||||
|
|
||||||
def create_brain(
|
def create_brain(
|
||||||
self,
|
self,
|
||||||
|
108
backend/modules/chat/controller/chat/test_utils.py
Normal file
108
backend/modules/chat/controller/chat/test_utils.py
Normal file
@ -0,0 +1,108 @@
|
|||||||
|
# FILEPATH: /Users/stan/Dev/Padok/secondbrain/backend/modules/chat/controller/chat/test_utils.py
|
||||||
|
|
||||||
|
import uuid
|
||||||
|
from unittest.mock import Mock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from fastapi import HTTPException
|
||||||
|
from models.databases.entity import LLMModels
|
||||||
|
from modules.chat.controller.chat.utils import (
|
||||||
|
check_user_requests_limit,
|
||||||
|
find_model_and_generate_metadata,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@patch("modules.chat.controller.chat.utils.chat_service")
|
||||||
|
def test_find_model_and_generate_metadata(mock_chat_service):
|
||||||
|
chat_id = uuid.uuid4()
|
||||||
|
brain = Mock()
|
||||||
|
brain.model = "gpt-3.5-turbo-1106"
|
||||||
|
user_settings = {"models": ["gpt-3.5-turbo-1106"]}
|
||||||
|
models_settings = [
|
||||||
|
{"name": "gpt-3.5-turbo-1106", "max_input": 512, "max_output": 512}
|
||||||
|
]
|
||||||
|
metadata_brain = {"key": "value"}
|
||||||
|
|
||||||
|
mock_chat_service.get_follow_up_question.return_value = []
|
||||||
|
|
||||||
|
model_to_use, metadata = find_model_and_generate_metadata(
|
||||||
|
chat_id, brain, user_settings, models_settings, metadata_brain
|
||||||
|
)
|
||||||
|
|
||||||
|
assert isinstance(model_to_use, LLMModels)
|
||||||
|
assert model_to_use.name == "gpt-3.5-turbo-1106"
|
||||||
|
assert model_to_use.max_input == 512
|
||||||
|
assert model_to_use.max_output == 512
|
||||||
|
assert metadata == {
|
||||||
|
"key": "value",
|
||||||
|
"follow_up_questions": [],
|
||||||
|
"model": "gpt-3.5-turbo-1106",
|
||||||
|
"max_tokens": 512,
|
||||||
|
"max_input": 512,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@patch("modules.chat.controller.chat.utils.chat_service")
|
||||||
|
def test_find_model_and_generate_metadata_user_not_allowed(mock_chat_service):
|
||||||
|
chat_id = uuid.uuid4()
|
||||||
|
brain = Mock()
|
||||||
|
brain.model = "gpt-3.5-turbo-1106"
|
||||||
|
user_settings = {
|
||||||
|
"models": ["gpt-3.5-turbo-1107"]
|
||||||
|
} # User is not allowed to use the brain's model
|
||||||
|
models_settings = [
|
||||||
|
{"name": "gpt-3.5-turbo-1106", "max_input": 512, "max_output": 512},
|
||||||
|
{"name": "gpt-3.5-turbo-1107", "max_input": 512, "max_output": 512},
|
||||||
|
]
|
||||||
|
metadata_brain = {"key": "value"}
|
||||||
|
|
||||||
|
mock_chat_service.get_follow_up_question.return_value = []
|
||||||
|
|
||||||
|
model_to_use, metadata = find_model_and_generate_metadata(
|
||||||
|
chat_id, brain, user_settings, models_settings, metadata_brain
|
||||||
|
)
|
||||||
|
|
||||||
|
assert isinstance(model_to_use, LLMModels)
|
||||||
|
assert model_to_use.name == "gpt-3.5-turbo-1106" # Default model is used
|
||||||
|
assert model_to_use.max_input == 512
|
||||||
|
assert model_to_use.max_output == 512
|
||||||
|
assert metadata == {
|
||||||
|
"key": "value",
|
||||||
|
"follow_up_questions": [],
|
||||||
|
"model": "gpt-3.5-turbo-1106",
|
||||||
|
"max_tokens": 512,
|
||||||
|
"max_input": 512,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@patch("modules.chat.controller.chat.utils.time")
|
||||||
|
def test_check_user_requests_limit_within_limit(mock_time):
|
||||||
|
mock_time.strftime.return_value = "20220101"
|
||||||
|
usage = Mock()
|
||||||
|
usage.get_user_monthly_usage.return_value = 50
|
||||||
|
user_settings = {"monthly_chat_credit": 100}
|
||||||
|
models_settings = [{"name": "gpt-3.5-turbo", "price": 10}]
|
||||||
|
model_name = "gpt-3.5-turbo"
|
||||||
|
|
||||||
|
check_user_requests_limit(usage, user_settings, models_settings, model_name)
|
||||||
|
|
||||||
|
usage.handle_increment_user_request_count.assert_called_once_with("20220101", 10)
|
||||||
|
|
||||||
|
|
||||||
|
@patch("modules.chat.controller.chat.utils.time")
|
||||||
|
def test_check_user_requests_limit_exceeds_limit(mock_time):
|
||||||
|
mock_time.strftime.return_value = "20220101"
|
||||||
|
usage = Mock()
|
||||||
|
usage.get_user_monthly_usage.return_value = 100
|
||||||
|
user_settings = {"monthly_chat_credit": 100}
|
||||||
|
models_settings = [{"name": "gpt-3.5-turbo", "price": 10}]
|
||||||
|
model_name = "gpt-3.5-turbo"
|
||||||
|
|
||||||
|
with pytest.raises(HTTPException) as exc_info:
|
||||||
|
check_user_requests_limit(usage, user_settings, models_settings, model_name)
|
||||||
|
|
||||||
|
assert exc_info.value.status_code == 429
|
||||||
|
assert (
|
||||||
|
"You have reached your monthly chat limit of 100 requests per months."
|
||||||
|
in str(exc_info.value.detail)
|
||||||
|
)
|
@ -4,9 +4,13 @@ from uuid import UUID
|
|||||||
from fastapi import HTTPException
|
from fastapi import HTTPException
|
||||||
from logger import get_logger
|
from logger import get_logger
|
||||||
from models import UserUsage
|
from models import UserUsage
|
||||||
from modules.user.entity.user_identity import UserIdentity
|
from models.databases.entity import LLMModels
|
||||||
|
from modules.brain.service.brain_service import BrainService
|
||||||
|
from modules.chat.service.chat_service import ChatService
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
brain_service = BrainService()
|
||||||
|
chat_service = ChatService()
|
||||||
|
|
||||||
|
|
||||||
class NullableUUID(UUID):
|
class NullableUUID(UUID):
|
||||||
@ -24,8 +28,56 @@ class NullableUUID(UUID):
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
def check_user_requests_limit(user: UserIdentity, model: str):
|
def find_model_and_generate_metadata(
|
||||||
# TODO : Pass objects to avoid multiple calls to the database
|
chat_id: UUID,
|
||||||
|
brain,
|
||||||
|
user_settings,
|
||||||
|
models_settings,
|
||||||
|
metadata_brain,
|
||||||
|
):
|
||||||
|
# Add metadata_brain to metadata
|
||||||
|
metadata = {}
|
||||||
|
metadata = {**metadata, **metadata_brain}
|
||||||
|
follow_up_questions = chat_service.get_follow_up_question(chat_id)
|
||||||
|
metadata["follow_up_questions"] = follow_up_questions
|
||||||
|
# Default model is gpt-3.5-turbo-1106
|
||||||
|
model_to_use = LLMModels(
|
||||||
|
name="gpt-3.5-turbo-1106", price=1, max_input=512, max_output=512
|
||||||
|
)
|
||||||
|
|
||||||
|
is_brain_model_available = any(
|
||||||
|
brain.model == model_dict.get("name") for model_dict in models_settings
|
||||||
|
)
|
||||||
|
|
||||||
|
is_user_allowed_model = brain.model in user_settings.get(
|
||||||
|
"models", ["gpt-3.5-turbo-1106"]
|
||||||
|
) # Checks if the model is available in the list of models
|
||||||
|
|
||||||
|
logger.info(f"Brain model: {brain.model}")
|
||||||
|
logger.info(f"User models: {user_settings.get('models', [])}")
|
||||||
|
logger.info(f"Model available: {is_brain_model_available}")
|
||||||
|
logger.info(f"User allowed model: {is_user_allowed_model}")
|
||||||
|
|
||||||
|
if is_brain_model_available and is_user_allowed_model:
|
||||||
|
# Use the model from the brain
|
||||||
|
model_to_use.name = brain.model
|
||||||
|
for model_dict in models_settings:
|
||||||
|
if model_dict.get("name") == model_to_use.name:
|
||||||
|
logger.info(f"Using model {model_to_use.name}")
|
||||||
|
model_to_use.max_input = model_dict.get("max_input")
|
||||||
|
model_to_use.max_output = model_dict.get("max_output")
|
||||||
|
break
|
||||||
|
|
||||||
|
metadata["model"] = model_to_use.name
|
||||||
|
metadata["max_tokens"] = model_to_use.max_output
|
||||||
|
metadata["max_input"] = model_to_use.max_input
|
||||||
|
|
||||||
|
return model_to_use, metadata
|
||||||
|
|
||||||
|
|
||||||
|
def check_user_requests_limit(
|
||||||
|
usage: UserUsage, user_settings, models_settings, model_name: str
|
||||||
|
):
|
||||||
"""Checks the user requests limit.
|
"""Checks the user requests limit.
|
||||||
It checks the user requests limit and raises an exception if the user has reached the limit.
|
It checks the user requests limit and raises an exception if the user has reached the limit.
|
||||||
By default, the user has a limit of 100 requests per month. The limit can be increased by upgrading the plan.
|
By default, the user has a limit of 100 requests per month. The limit can be increased by upgrading the plan.
|
||||||
@ -37,19 +89,16 @@ def check_user_requests_limit(user: UserIdentity, model: str):
|
|||||||
Raises:
|
Raises:
|
||||||
HTTPException: Raises a 429 error if the user has reached the limit.
|
HTTPException: Raises a 429 error if the user has reached the limit.
|
||||||
"""
|
"""
|
||||||
userDailyUsage = UserUsage(id=user.id, email=user.email)
|
usage
|
||||||
|
|
||||||
userSettings = userDailyUsage.get_user_settings()
|
|
||||||
|
|
||||||
date = time.strftime("%Y%m%d")
|
date = time.strftime("%Y%m%d")
|
||||||
|
|
||||||
monthly_chat_credit = userSettings.get("monthly_chat_credit", 100)
|
monthly_chat_credit = user_settings.get("monthly_chat_credit", 100)
|
||||||
daily_user_count = userDailyUsage.get_user_monthly_usage(date)
|
daily_user_count = usage.get_user_monthly_usage(date)
|
||||||
models_price = userDailyUsage.get_model_settings()
|
|
||||||
user_choosen_model_price = 1000
|
user_choosen_model_price = 1000
|
||||||
|
|
||||||
for model_setting in models_price:
|
for model_setting in models_settings:
|
||||||
if model_setting["name"] == model:
|
if model_setting["name"] == model_name:
|
||||||
user_choosen_model_price = model_setting["price"]
|
user_choosen_model_price = model_setting["price"]
|
||||||
|
|
||||||
if int(daily_user_count + user_choosen_model_price) > int(monthly_chat_credit):
|
if int(daily_user_count + user_choosen_model_price) > int(monthly_chat_credit):
|
||||||
@ -58,7 +107,5 @@ def check_user_requests_limit(user: UserIdentity, model: str):
|
|||||||
detail=f"You have reached your monthly chat limit of {monthly_chat_credit} requests per months. Please upgrade your plan to increase your daily chat limit.",
|
detail=f"You have reached your monthly chat limit of {monthly_chat_credit} requests per months. Please upgrade your plan to increase your daily chat limit.",
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
userDailyUsage.handle_increment_user_request_count(
|
usage.handle_increment_user_request_count(date, user_choosen_model_price)
|
||||||
date, user_choosen_model_price
|
|
||||||
)
|
|
||||||
pass
|
pass
|
||||||
|
@ -3,13 +3,16 @@ from uuid import UUID
|
|||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException, Query, Request
|
from fastapi import APIRouter, Depends, HTTPException, Query, Request
|
||||||
from fastapi.responses import StreamingResponse
|
from fastapi.responses import StreamingResponse
|
||||||
|
from logger import get_logger
|
||||||
from middlewares.auth import AuthBearer, get_current_user
|
from middlewares.auth import AuthBearer, get_current_user
|
||||||
from models.databases.entity import LLMModels
|
|
||||||
from models.user_usage import UserUsage
|
from models.user_usage import UserUsage
|
||||||
from modules.brain.service.brain_service import BrainService
|
from modules.brain.service.brain_service import BrainService
|
||||||
from modules.chat.controller.chat.brainful_chat import BrainfulChat
|
from modules.chat.controller.chat.brainful_chat import BrainfulChat
|
||||||
from modules.chat.controller.chat.factory import get_chat_strategy
|
from modules.chat.controller.chat.utils import (
|
||||||
from modules.chat.controller.chat.utils import NullableUUID, check_user_requests_limit
|
NullableUUID,
|
||||||
|
check_user_requests_limit,
|
||||||
|
find_model_and_generate_metadata,
|
||||||
|
)
|
||||||
from modules.chat.dto.chats import ChatItem, ChatQuestion
|
from modules.chat.dto.chats import ChatItem, ChatQuestion
|
||||||
from modules.chat.dto.inputs import (
|
from modules.chat.dto.inputs import (
|
||||||
ChatUpdatableProperties,
|
ChatUpdatableProperties,
|
||||||
@ -21,8 +24,6 @@ from modules.chat.service.chat_service import ChatService
|
|||||||
from modules.notification.service.notification_service import NotificationService
|
from modules.notification.service.notification_service import NotificationService
|
||||||
from modules.user.entity.user_identity import UserIdentity
|
from modules.user.entity.user_identity import UserIdentity
|
||||||
|
|
||||||
from logger import get_logger
|
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
chat_router = APIRouter()
|
chat_router = APIRouter()
|
||||||
@ -32,6 +33,65 @@ brain_service = BrainService()
|
|||||||
chat_service = ChatService()
|
chat_service = ChatService()
|
||||||
|
|
||||||
|
|
||||||
|
def get_answer_generator(
|
||||||
|
chat_id: UUID,
|
||||||
|
chat_question: ChatQuestion,
|
||||||
|
brain_id: UUID,
|
||||||
|
current_user: UserIdentity,
|
||||||
|
):
|
||||||
|
chat_instance = BrainfulChat()
|
||||||
|
chat_instance.validate_authorization(user_id=current_user.id, brain_id=brain_id)
|
||||||
|
|
||||||
|
user_usage = UserUsage(
|
||||||
|
id=current_user.id,
|
||||||
|
email=current_user.email,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Get History
|
||||||
|
history = chat_service.get_chat_history(chat_id)
|
||||||
|
|
||||||
|
# Get user settings
|
||||||
|
user_settings = user_usage.get_user_settings()
|
||||||
|
|
||||||
|
# Get Model settings for the user
|
||||||
|
models_settings = user_usage.get_model_settings()
|
||||||
|
|
||||||
|
# Generic
|
||||||
|
brain, metadata_brain = brain_service.find_brain_from_question(
|
||||||
|
brain_id, chat_question.question, current_user, chat_id, history
|
||||||
|
)
|
||||||
|
|
||||||
|
model_to_use, metadata = find_model_and_generate_metadata(
|
||||||
|
chat_id,
|
||||||
|
brain,
|
||||||
|
user_settings,
|
||||||
|
models_settings,
|
||||||
|
metadata_brain,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Raises an error if the user has consumed all of of his credits
|
||||||
|
check_user_requests_limit(
|
||||||
|
usage=user_usage,
|
||||||
|
user_settings=user_settings,
|
||||||
|
models_settings=models_settings,
|
||||||
|
model_name=model_to_use.name,
|
||||||
|
)
|
||||||
|
gpt_answer_generator = chat_instance.get_answer_generator(
|
||||||
|
chat_id=str(chat_id),
|
||||||
|
model=model_to_use.name,
|
||||||
|
max_tokens=model_to_use.max_output,
|
||||||
|
max_input=model_to_use.max_input,
|
||||||
|
temperature=0.1,
|
||||||
|
streaming=True,
|
||||||
|
prompt_id=chat_question.prompt_id,
|
||||||
|
user_id=current_user.id,
|
||||||
|
metadata=metadata,
|
||||||
|
brain=brain,
|
||||||
|
)
|
||||||
|
|
||||||
|
return gpt_answer_generator
|
||||||
|
|
||||||
|
|
||||||
@chat_router.get("/chat/healthz", tags=["Health"])
|
@chat_router.get("/chat/healthz", tags=["Health"])
|
||||||
async def healthz():
|
async def healthz():
|
||||||
return {"status": "ok"}
|
return {"status": "ok"}
|
||||||
@ -123,56 +183,9 @@ async def create_question_handler(
|
|||||||
| None = Query(..., description="The ID of the brain"),
|
| None = Query(..., description="The ID of the brain"),
|
||||||
current_user: UserIdentity = Depends(get_current_user),
|
current_user: UserIdentity = Depends(get_current_user),
|
||||||
):
|
):
|
||||||
"""
|
|
||||||
Add a new question to the chat.
|
|
||||||
"""
|
|
||||||
|
|
||||||
chat_instance = get_chat_strategy(brain_id)
|
|
||||||
|
|
||||||
chat_instance.validate_authorization(user_id=current_user.id, brain_id=brain_id)
|
|
||||||
|
|
||||||
fallback_model = "gpt-3.5-turbo-1106"
|
|
||||||
fallback_temperature = 0.1
|
|
||||||
fallback_max_tokens = 512
|
|
||||||
|
|
||||||
user_daily_usage = UserUsage(
|
|
||||||
id=current_user.id,
|
|
||||||
email=current_user.email,
|
|
||||||
)
|
|
||||||
user_settings = user_daily_usage.get_user_settings()
|
|
||||||
is_model_ok = (chat_question).model in user_settings.get("models", ["gpt-3.5-turbo-1106"]) # type: ignore
|
|
||||||
|
|
||||||
# Retrieve chat model (temperature, max_tokens, model)
|
|
||||||
if (
|
|
||||||
not chat_question.model
|
|
||||||
or not chat_question.temperature
|
|
||||||
or not chat_question.max_tokens
|
|
||||||
):
|
|
||||||
if brain_id:
|
|
||||||
brain = brain_service.get_brain_by_id(brain_id)
|
|
||||||
if brain:
|
|
||||||
fallback_model = brain.model or fallback_model
|
|
||||||
fallback_temperature = brain.temperature or fallback_temperature
|
|
||||||
fallback_max_tokens = brain.max_tokens or fallback_max_tokens
|
|
||||||
|
|
||||||
chat_question.model = chat_question.model or fallback_model
|
|
||||||
chat_question.temperature = chat_question.temperature or fallback_temperature
|
|
||||||
chat_question.max_tokens = chat_question.max_tokens or fallback_max_tokens
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
check_user_requests_limit(current_user, chat_question.model)
|
gpt_answer_generator = get_answer_generator(
|
||||||
is_model_ok = (chat_question).model in user_settings.get("models", ["gpt-3.5-turbo-1106"]) # type: ignore
|
chat_id, chat_question, brain_id, current_user
|
||||||
gpt_answer_generator = chat_instance.get_answer_generator(
|
|
||||||
chat_id=str(chat_id),
|
|
||||||
model=chat_question.model if is_model_ok else "gpt-3.5-turbo-1106", # type: ignore
|
|
||||||
max_tokens=chat_question.max_tokens,
|
|
||||||
temperature=chat_question.temperature,
|
|
||||||
streaming=False,
|
|
||||||
prompt_id=chat_question.prompt_id,
|
|
||||||
user_id=current_user.id,
|
|
||||||
max_input=2000,
|
|
||||||
brain=brain_service.get_brain_by_id(brain_id),
|
|
||||||
metadata={},
|
|
||||||
)
|
)
|
||||||
|
|
||||||
chat_answer = gpt_answer_generator.generate_answer(
|
chat_answer = gpt_answer_generator.generate_answer(
|
||||||
@ -203,86 +216,11 @@ async def create_stream_question_handler(
|
|||||||
| None = Query(..., description="The ID of the brain"),
|
| None = Query(..., description="The ID of the brain"),
|
||||||
current_user: UserIdentity = Depends(get_current_user),
|
current_user: UserIdentity = Depends(get_current_user),
|
||||||
) -> StreamingResponse:
|
) -> StreamingResponse:
|
||||||
chat_instance = BrainfulChat()
|
gpt_answer_generator = get_answer_generator(
|
||||||
chat_instance.validate_authorization(user_id=current_user.id, brain_id=brain_id)
|
chat_id, chat_question, brain_id, current_user
|
||||||
|
|
||||||
user_usage = UserUsage(
|
|
||||||
id=current_user.id,
|
|
||||||
email=current_user.email,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Get History
|
|
||||||
history = chat_service.get_chat_history(chat_id)
|
|
||||||
|
|
||||||
# Get user settings
|
|
||||||
user_settings = user_usage.get_user_settings()
|
|
||||||
|
|
||||||
# Get Model settings for the user
|
|
||||||
models_settings = user_usage.get_model_settings()
|
|
||||||
|
|
||||||
# Generic
|
|
||||||
brain_id_to_use, metadata_brain = brain_service.find_brain_from_question(
|
|
||||||
brain_id, chat_question.question, current_user, chat_id, history
|
|
||||||
)
|
|
||||||
|
|
||||||
# Add metadata_brain to metadata
|
|
||||||
metadata = {}
|
|
||||||
metadata = {**metadata, **metadata_brain}
|
|
||||||
follow_up_questions = chat_service.get_follow_up_question(chat_id)
|
|
||||||
metadata["follow_up_questions"] = follow_up_questions
|
|
||||||
|
|
||||||
# Get the Brain settings
|
|
||||||
brain = brain_service.get_brain_by_id(brain_id_to_use)
|
|
||||||
|
|
||||||
logger.info(f"Brain model: {brain.model}")
|
|
||||||
logger.info(f"Brain is : {str(brain)}")
|
|
||||||
try:
|
try:
|
||||||
# Default model is gpt-3.5-turbo-1106
|
|
||||||
model_to_use = LLMModels(
|
|
||||||
name="gpt-3.5-turbo-1106", price=1, max_input=512, max_output=512
|
|
||||||
)
|
|
||||||
|
|
||||||
is_brain_model_available = any(
|
|
||||||
brain.model == model_dict.get("name") for model_dict in models_settings
|
|
||||||
)
|
|
||||||
|
|
||||||
is_user_allowed_model = brain.model in user_settings.get(
|
|
||||||
"models", ["gpt-3.5-turbo-1106"]
|
|
||||||
) # Checks if the model is available in the list of models
|
|
||||||
|
|
||||||
logger.info(f"Brain model: {brain.model}")
|
|
||||||
logger.info(f"User models: {user_settings.get('models', [])}")
|
|
||||||
logger.info(f"Model available: {is_brain_model_available}")
|
|
||||||
logger.info(f"User allowed model: {is_user_allowed_model}")
|
|
||||||
|
|
||||||
if is_brain_model_available and is_user_allowed_model:
|
|
||||||
# Use the model from the brain
|
|
||||||
model_to_use.name = brain.model
|
|
||||||
for model_dict in models_settings:
|
|
||||||
if model_dict.get("name") == model_to_use.name:
|
|
||||||
logger.info(f"Using model {model_to_use.name}")
|
|
||||||
model_to_use.max_input = model_dict.get("max_input")
|
|
||||||
model_to_use.max_output = model_dict.get("max_output")
|
|
||||||
break
|
|
||||||
|
|
||||||
metadata["model"] = model_to_use.name
|
|
||||||
metadata["max_tokens"] = model_to_use.max_output
|
|
||||||
metadata["max_input"] = model_to_use.max_input
|
|
||||||
|
|
||||||
check_user_requests_limit(current_user, model_to_use.name)
|
|
||||||
gpt_answer_generator = chat_instance.get_answer_generator(
|
|
||||||
chat_id=str(chat_id),
|
|
||||||
model=model_to_use.name,
|
|
||||||
max_tokens=model_to_use.max_output,
|
|
||||||
max_input=model_to_use.max_input,
|
|
||||||
temperature=0.1,
|
|
||||||
streaming=True,
|
|
||||||
prompt_id=chat_question.prompt_id,
|
|
||||||
user_id=current_user.id,
|
|
||||||
metadata=metadata,
|
|
||||||
brain=brain,
|
|
||||||
)
|
|
||||||
|
|
||||||
return StreamingResponse(
|
return StreamingResponse(
|
||||||
gpt_answer_generator.generate_stream(
|
gpt_answer_generator.generate_stream(
|
||||||
chat_id, chat_question, save_answer=True
|
chat_id, chat_question, save_answer=True
|
||||||
|
@ -53,25 +53,12 @@ def test_create_chat_and_talk(client, api_key):
|
|||||||
"model": "gpt-3.5-turbo-1106",
|
"model": "gpt-3.5-turbo-1106",
|
||||||
"question": "Hello, how are you?",
|
"question": "Hello, how are you?",
|
||||||
"temperature": "0",
|
"temperature": "0",
|
||||||
"max_tokens": "256",
|
"max_tokens": "2000",
|
||||||
},
|
},
|
||||||
headers={"Authorization": "Bearer " + api_key},
|
headers={"Authorization": "Bearer " + api_key},
|
||||||
)
|
)
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
|
|
||||||
response = client.post(
|
|
||||||
f"/chat/{chat_id}/question?brain_id={default_brain_id}",
|
|
||||||
json={
|
|
||||||
"model": "gpt-4",
|
|
||||||
"question": "Hello, how are you?",
|
|
||||||
"temperature": "0",
|
|
||||||
"max_tokens": "256",
|
|
||||||
},
|
|
||||||
headers={"Authorization": "Bearer " + api_key},
|
|
||||||
)
|
|
||||||
print(response)
|
|
||||||
assert response.status_code == 200
|
|
||||||
|
|
||||||
# Now, let's delete the chat
|
# Now, let's delete the chat
|
||||||
delete_response = client.delete(
|
delete_response = client.delete(
|
||||||
"/chat/" + chat_id, headers={"Authorization": "Bearer " + api_key}
|
"/chat/" + chat_id, headers={"Authorization": "Bearer " + api_key}
|
||||||
|
Loading…
Reference in New Issue
Block a user