Mirror of https://github.com/QuivrHQ/quivr.git (synced 2024-12-14 17:03:29 +03:00)
feat: 🎸 user-limits (#2104)
Optimized the number of database calls made per chat request. `check_user_requests_limit` now receives a pre-fetched `UserUsage` object plus user and model settings instead of querying them itself, model selection and response metadata move into a new `find_model_and_generate_metadata` helper, and the question and streaming endpoints share a single `get_answer_generator` function. Unit tests are added for the new chat utils and for `BrainEntity`.
This commit is contained in:
parent bc5545e4cf
commit 2e06b5c7f2
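In short, the limit check no longer re-queries the database on every request; a before/after sketch of the call site (all names appear in the diffs below):

    # Before: the helper fetched usage, settings, and model prices itself
    check_user_requests_limit(current_user, chat_question.model)

    # After: the caller fetches them once and passes them down
    check_user_requests_limit(
        usage=user_usage,
        user_settings=user_settings,
        models_settings=models_settings,
        model_name=model_to_use.name,
    )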
@@ -288,16 +288,12 @@ class CompositeBrainQA(
                question=function_args["question"], brain_id=function_name
            )

            print("querying brain", function_name)
            # TODO: extract chat_id from generate_answer function of XBrainQA
            function_response = function_to_call(
                chat_id=chat_id,
                question=question,
                save_answer=False,
            )

            print("brain_answer", function_response.assistant)

            messages.append(
                {
                    "tool_call_id": tool_call.id,
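For context, this hunk sits inside an OpenAI-style tool-calling loop where each brain is dispatched as a callable function; a hedged sketch of the surrounding pattern (only the names shown in the hunk above are from the source, the rest is assumed):

    import json

    for tool_call in response_message.tool_calls:  # assumed loop variable
        function_name = tool_call.function.name
        function_args = json.loads(tool_call.function.arguments)
        function_to_call = available_functions[function_name]  # hypothetical registry

        question = ChatQuestion(  # assumed construction of the question object
            question=function_args["question"], brain_id=function_name
        )
        function_response = function_to_call(
            chat_id=chat_id,
            question=question,
            save_answer=False,
        )

        # Feed each brain's answer back to the model as a tool message
        messages.append(
            {
                "tool_call_id": tool_call.id,
                "role": "tool",
                "name": function_name,
                "content": function_response.assistant,
            }
        )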
@@ -0,0 +1,72 @@
import uuid
from datetime import datetime
from unittest.mock import create_autospec

import pytest
from modules.brain.dto.inputs import CreateBrainProperties
from modules.brain.entity.brain_entity import BrainEntity, BrainType
from modules.brain.repository.interfaces.brains_interface import BrainsInterface


@pytest.fixture
def mock_brains_interface():
    return create_autospec(BrainsInterface)


def test_create_brain(mock_brains_interface):
    brain = CreateBrainProperties()
    mock_brains_interface.create_brain.return_value = BrainEntity(
        brain_id=uuid.uuid4(),  # generate a valid UUID
        name="test_name",
        last_update=datetime.now().isoformat(),  # convert datetime to string
        brain_type=BrainType.DOC,
    )
    result = mock_brains_interface.create_brain(brain)
    mock_brains_interface.create_brain.assert_called_once_with(brain)
    assert isinstance(result, BrainEntity)


def test_brain_entity_creation():
    brain_id = uuid.uuid4()
    name = "test_name"
    last_update = datetime.now().isoformat()
    brain_type = BrainType.DOC

    brain_entity = BrainEntity(
        brain_id=brain_id, name=name, last_update=last_update, brain_type=brain_type
    )

    assert brain_entity.brain_id == brain_id
    assert brain_entity.name == name
    assert brain_entity.last_update == last_update
    assert brain_entity.brain_type == brain_type


def test_brain_entity_id_property():
    brain_id = uuid.uuid4()
    name = "test_name"
    last_update = datetime.now().isoformat()
    brain_type = BrainType.DOC

    brain_entity = BrainEntity(
        brain_id=brain_id, name=name, last_update=last_update, brain_type=brain_type
    )

    assert brain_entity.id == brain_id


def test_brain_entity_dict_method():
    brain_id = uuid.uuid4()
    name = "test_name"
    last_update = datetime.now().isoformat()
    brain_type = BrainType.DOC

    brain_entity = BrainEntity(
        brain_id=brain_id, name=name, last_update=last_update, brain_type=brain_type
    )

    brain_dict = brain_entity.dict()
    assert brain_dict["id"] == brain_id
    assert brain_dict["name"] == name
    assert brain_dict["last_update"] == last_update
    assert brain_dict["brain_type"] == brain_type
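A note on the fixture above: `create_autospec(BrainsInterface)` builds a mock that enforces the interface's real method signatures, unlike a bare `Mock`; a small illustrative sketch with a hypothetical interface standing in for `BrainsInterface`:

    from unittest.mock import Mock, create_autospec

    class Repo:  # hypothetical stand-in for BrainsInterface
        def create_brain(self, brain): ...

    strict = create_autospec(Repo)
    strict.create_brain("b")          # OK: matches the real signature
    # strict.create_brain("b", "x")   # would raise TypeError: wrong arity

    loose = Mock()
    loose.create_brain("b", "x", 1)   # silently accepted, typos go unnoticed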
@@ -4,7 +4,7 @@ from uuid import UUID
 from fastapi import HTTPException
 from langchain.embeddings.ollama import OllamaEmbeddings
 from langchain.embeddings.openai import OpenAIEmbeddings
+from vectorstore.supabase import CustomSupabaseVectorStore
 from logger import get_logger
 from models.settings import BrainSettings, get_supabase_client
 from modules.brain.dto.inputs import BrainUpdatableProperties, CreateBrainProperties
 from modules.brain.entity.brain_entity import BrainEntity, BrainType, PublicBrain

@@ -25,8 +25,7 @@ from modules.brain.repository.interfaces import (
 from modules.brain.service.api_brain_definition_service import ApiBrainDefinitionService
 from modules.brain.service.utils.validate_brain import validate_api_brain
 from modules.knowledge.service.knowledge_service import KnowledgeService

-from logger import get_logger
-from vectorstore.supabase import CustomSupabaseVectorStore

 logger = get_logger(__name__)

@@ -84,6 +83,7 @@ class BrainService:
         # Init

         brain_id_to_use = brain_id
+        brain_to_use = None

         # Get the first question from the chat_question

@@ -97,6 +97,7 @@ class BrainService:
         if history and not brain_id:
             brain_id_to_use = history[0].brain_id
+            brain_to_use = self.get_brain_by_id(brain_id_to_use)

         # Calculate the closest brains to the question
         list_brains = vector_store.find_brain_closest_query(user.id, question)

@@ -111,10 +112,11 @@ class BrainService:
         metadata["close_brains"] = unique_list_brains[:5]

-        if list_brains and not brain_id_to_use:
+        if list_brains and not brain_to_use:
             brain_id_to_use = list_brains[0]["id"]
             brain_to_use = self.get_brain_by_id(brain_id_to_use)

-        return brain_id_to_use, metadata
+        return brain_to_use, metadata

     def create_brain(
         self,
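The service now returns the resolved brain entity rather than its id, which lets callers drop a follow-up lookup; a before/after sketch of the call site (both forms appear verbatim in the chat_routes diff further down):

    # Before: the route got an id back, then re-fetched the entity
    brain_id_to_use, metadata_brain = brain_service.find_brain_from_question(
        brain_id, chat_question.question, current_user, chat_id, history
    )
    brain = brain_service.get_brain_by_id(brain_id_to_use)  # extra round-trip

    # After: the service resolves the entity once and returns it
    brain, metadata_brain = brain_service.find_brain_from_question(
        brain_id, chat_question.question, current_user, chat_id, history
    )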
backend/modules/chat/controller/chat/test_utils.py (new file, 108 lines)
@@ -0,0 +1,108 @@
# FILEPATH: /Users/stan/Dev/Padok/secondbrain/backend/modules/chat/controller/chat/test_utils.py

import uuid
from unittest.mock import Mock, patch

import pytest
from fastapi import HTTPException
from models.databases.entity import LLMModels
from modules.chat.controller.chat.utils import (
    check_user_requests_limit,
    find_model_and_generate_metadata,
)


@patch("modules.chat.controller.chat.utils.chat_service")
def test_find_model_and_generate_metadata(mock_chat_service):
    chat_id = uuid.uuid4()
    brain = Mock()
    brain.model = "gpt-3.5-turbo-1106"
    user_settings = {"models": ["gpt-3.5-turbo-1106"]}
    models_settings = [
        {"name": "gpt-3.5-turbo-1106", "max_input": 512, "max_output": 512}
    ]
    metadata_brain = {"key": "value"}

    mock_chat_service.get_follow_up_question.return_value = []

    model_to_use, metadata = find_model_and_generate_metadata(
        chat_id, brain, user_settings, models_settings, metadata_brain
    )

    assert isinstance(model_to_use, LLMModels)
    assert model_to_use.name == "gpt-3.5-turbo-1106"
    assert model_to_use.max_input == 512
    assert model_to_use.max_output == 512
    assert metadata == {
        "key": "value",
        "follow_up_questions": [],
        "model": "gpt-3.5-turbo-1106",
        "max_tokens": 512,
        "max_input": 512,
    }


@patch("modules.chat.controller.chat.utils.chat_service")
def test_find_model_and_generate_metadata_user_not_allowed(mock_chat_service):
    chat_id = uuid.uuid4()
    brain = Mock()
    brain.model = "gpt-3.5-turbo-1106"
    user_settings = {
        "models": ["gpt-3.5-turbo-1107"]
    }  # User is not allowed to use the brain's model
    models_settings = [
        {"name": "gpt-3.5-turbo-1106", "max_input": 512, "max_output": 512},
        {"name": "gpt-3.5-turbo-1107", "max_input": 512, "max_output": 512},
    ]
    metadata_brain = {"key": "value"}

    mock_chat_service.get_follow_up_question.return_value = []

    model_to_use, metadata = find_model_and_generate_metadata(
        chat_id, brain, user_settings, models_settings, metadata_brain
    )

    assert isinstance(model_to_use, LLMModels)
    assert model_to_use.name == "gpt-3.5-turbo-1106"  # Default model is used
    assert model_to_use.max_input == 512
    assert model_to_use.max_output == 512
    assert metadata == {
        "key": "value",
        "follow_up_questions": [],
        "model": "gpt-3.5-turbo-1106",
        "max_tokens": 512,
        "max_input": 512,
    }


@patch("modules.chat.controller.chat.utils.time")
def test_check_user_requests_limit_within_limit(mock_time):
    mock_time.strftime.return_value = "20220101"
    usage = Mock()
    usage.get_user_monthly_usage.return_value = 50
    user_settings = {"monthly_chat_credit": 100}
    models_settings = [{"name": "gpt-3.5-turbo", "price": 10}]
    model_name = "gpt-3.5-turbo"

    check_user_requests_limit(usage, user_settings, models_settings, model_name)

    usage.handle_increment_user_request_count.assert_called_once_with("20220101", 10)


@patch("modules.chat.controller.chat.utils.time")
def test_check_user_requests_limit_exceeds_limit(mock_time):
    mock_time.strftime.return_value = "20220101"
    usage = Mock()
    usage.get_user_monthly_usage.return_value = 100
    user_settings = {"monthly_chat_credit": 100}
    models_settings = [{"name": "gpt-3.5-turbo", "price": 10}]
    model_name = "gpt-3.5-turbo"

    with pytest.raises(HTTPException) as exc_info:
        check_user_requests_limit(usage, user_settings, models_settings, model_name)

    assert exc_info.value.status_code == 429
    assert (
        "You have reached your monthly chat limit of 100 requests per months."
        in str(exc_info.value.detail)
    )
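These tests patch the module-level singletons where they are looked up, not where their classes are defined, because utils.py instantiates `chat_service = ChatService()` at import time; a minimal reminder of the pattern:

    from unittest.mock import patch

    # Patch the name inside the module under test; patching ChatService itself
    # would not affect the already-created utils.chat_service instance.
    with patch("modules.chat.controller.chat.utils.chat_service") as mock_chat_service:
        mock_chat_service.get_follow_up_question.return_value = []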
@@ -4,9 +4,13 @@ from uuid import UUID
 from fastapi import HTTPException
 from logger import get_logger
 from models import UserUsage
-from modules.user.entity.user_identity import UserIdentity
+from models.databases.entity import LLMModels
+from modules.brain.service.brain_service import BrainService
+from modules.chat.service.chat_service import ChatService

 logger = get_logger(__name__)
+brain_service = BrainService()
+chat_service = ChatService()


 class NullableUUID(UUID):

@@ -24,8 +28,56 @@ class NullableUUID(UUID):
         return None


-def check_user_requests_limit(user: UserIdentity, model: str):
+# TODO : Pass objects to avoid multiple calls to the database
+def find_model_and_generate_metadata(
+    chat_id: UUID,
+    brain,
+    user_settings,
+    models_settings,
+    metadata_brain,
+):
+    # Add metadata_brain to metadata
+    metadata = {}
+    metadata = {**metadata, **metadata_brain}
+    follow_up_questions = chat_service.get_follow_up_question(chat_id)
+    metadata["follow_up_questions"] = follow_up_questions
+    # Default model is gpt-3.5-turbo-1106
+    model_to_use = LLMModels(
+        name="gpt-3.5-turbo-1106", price=1, max_input=512, max_output=512
+    )
+
+    is_brain_model_available = any(
+        brain.model == model_dict.get("name") for model_dict in models_settings
+    )
+
+    is_user_allowed_model = brain.model in user_settings.get(
+        "models", ["gpt-3.5-turbo-1106"]
+    )  # Checks if the model is available in the list of models
+
+    logger.info(f"Brain model: {brain.model}")
+    logger.info(f"User models: {user_settings.get('models', [])}")
+    logger.info(f"Model available: {is_brain_model_available}")
+    logger.info(f"User allowed model: {is_user_allowed_model}")
+
+    if is_brain_model_available and is_user_allowed_model:
+        # Use the model from the brain
+        model_to_use.name = brain.model
+        for model_dict in models_settings:
+            if model_dict.get("name") == model_to_use.name:
+                logger.info(f"Using model {model_to_use.name}")
+                model_to_use.max_input = model_dict.get("max_input")
+                model_to_use.max_output = model_dict.get("max_output")
+                break
+
+    metadata["model"] = model_to_use.name
+    metadata["max_tokens"] = model_to_use.max_output
+    metadata["max_input"] = model_to_use.max_input
+
+    return model_to_use, metadata
+
+
+def check_user_requests_limit(
+    usage: UserUsage, user_settings, models_settings, model_name: str
+):
     """Checks the user requests limit.
     It checks the user requests limit and raises an exception if the user has reached the limit.
     By default, the user has a limit of 100 requests per month. The limit can be increased by upgrading the plan.

@@ -37,19 +89,16 @@ def check_user_requests_limit(user: UserIdentity, model: str):
     Raises:
         HTTPException: Raises a 429 error if the user has reached the limit.
     """
-    userDailyUsage = UserUsage(id=user.id, email=user.email)
-
-    userSettings = userDailyUsage.get_user_settings()

     date = time.strftime("%Y%m%d")

-    monthly_chat_credit = userSettings.get("monthly_chat_credit", 100)
-    daily_user_count = userDailyUsage.get_user_monthly_usage(date)
-    models_price = userDailyUsage.get_model_settings()
+    monthly_chat_credit = user_settings.get("monthly_chat_credit", 100)
+    daily_user_count = usage.get_user_monthly_usage(date)
     user_choosen_model_price = 1000

-    for model_setting in models_price:
-        if model_setting["name"] == model:
+    for model_setting in models_settings:
+        if model_setting["name"] == model_name:
             user_choosen_model_price = model_setting["price"]

     if int(daily_user_count + user_choosen_model_price) > int(monthly_chat_credit):

@@ -58,7 +107,5 @@ def check_user_requests_limit(user: UserIdentity, model: str):
             detail=f"You have reached your monthly chat limit of {monthly_chat_credit} requests per months. Please upgrade your plan to increase your daily chat limit.",
         )
     else:
-        userDailyUsage.handle_increment_user_request_count(
-            date, user_choosen_model_price
-        )
+        usage.handle_increment_user_request_count(date, user_choosen_model_price)
         pass
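The limit check itself is simple additive accounting; an illustrative run with hypothetical numbers:

    # Hypothetical values: 95 credits used this month, chosen model costs 10,
    # monthly_chat_credit is 100, so 95 + 10 > 100 and the request is rejected.
    monthly_chat_credit = 100
    daily_user_count = 95          # usage.get_user_monthly_usage(date)
    user_choosen_model_price = 10  # matched from models_settings by name

    if int(daily_user_count + user_choosen_model_price) > int(monthly_chat_credit):
        print("HTTP 429: monthly chat limit reached")
    else:
        print("allowed; usage incremented by", user_choosen_model_price)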
@@ -3,13 +3,16 @@ from uuid import UUID

 from fastapi import APIRouter, Depends, HTTPException, Query, Request
 from fastapi.responses import StreamingResponse
 from logger import get_logger
 from middlewares.auth import AuthBearer, get_current_user
-from models.databases.entity import LLMModels
 from models.user_usage import UserUsage
 from modules.brain.service.brain_service import BrainService
 from modules.chat.controller.chat.brainful_chat import BrainfulChat
 from modules.chat.controller.chat.factory import get_chat_strategy
-from modules.chat.controller.chat.utils import NullableUUID, check_user_requests_limit
+from modules.chat.controller.chat.utils import (
+    NullableUUID,
+    check_user_requests_limit,
+    find_model_and_generate_metadata,
+)
 from modules.chat.dto.chats import ChatItem, ChatQuestion
 from modules.chat.dto.inputs import (
     ChatUpdatableProperties,

@@ -21,8 +24,6 @@ from modules.chat.service.chat_service import ChatService
 from modules.notification.service.notification_service import NotificationService
 from modules.user.entity.user_identity import UserIdentity

-from logger import get_logger
-
 logger = get_logger(__name__)

 chat_router = APIRouter()

@@ -32,6 +33,65 @@ brain_service = BrainService()
 chat_service = ChatService()


+def get_answer_generator(
+    chat_id: UUID,
+    chat_question: ChatQuestion,
+    brain_id: UUID,
+    current_user: UserIdentity,
+):
+    chat_instance = BrainfulChat()
+    chat_instance.validate_authorization(user_id=current_user.id, brain_id=brain_id)
+
+    user_usage = UserUsage(
+        id=current_user.id,
+        email=current_user.email,
+    )
+
+    # Get History
+    history = chat_service.get_chat_history(chat_id)
+
+    # Get user settings
+    user_settings = user_usage.get_user_settings()
+
+    # Get Model settings for the user
+    models_settings = user_usage.get_model_settings()
+
+    # Generic
+    brain, metadata_brain = brain_service.find_brain_from_question(
+        brain_id, chat_question.question, current_user, chat_id, history
+    )
+
+    model_to_use, metadata = find_model_and_generate_metadata(
+        chat_id,
+        brain,
+        user_settings,
+        models_settings,
+        metadata_brain,
+    )
+
+    # Raises an error if the user has consumed all of his credits
+    check_user_requests_limit(
+        usage=user_usage,
+        user_settings=user_settings,
+        models_settings=models_settings,
+        model_name=model_to_use.name,
+    )
+    gpt_answer_generator = chat_instance.get_answer_generator(
+        chat_id=str(chat_id),
+        model=model_to_use.name,
+        max_tokens=model_to_use.max_output,
+        max_input=model_to_use.max_input,
+        temperature=0.1,
+        streaming=True,
+        prompt_id=chat_question.prompt_id,
+        user_id=current_user.id,
+        metadata=metadata,
+        brain=brain,
+    )
+
+    return gpt_answer_generator


 @chat_router.get("/chat/healthz", tags=["Health"])
 async def healthz():
     return {"status": "ok"}
@@ -123,56 +183,9 @@ async def create_question_handler(
     | None = Query(..., description="The ID of the brain"),
     current_user: UserIdentity = Depends(get_current_user),
 ):
     """
     Add a new question to the chat.
     """
-    chat_instance = get_chat_strategy(brain_id)
-
-    chat_instance.validate_authorization(user_id=current_user.id, brain_id=brain_id)
-
-    fallback_model = "gpt-3.5-turbo-1106"
-    fallback_temperature = 0.1
-    fallback_max_tokens = 512
-
-    user_daily_usage = UserUsage(
-        id=current_user.id,
-        email=current_user.email,
-    )
-    user_settings = user_daily_usage.get_user_settings()
-    is_model_ok = (chat_question).model in user_settings.get("models", ["gpt-3.5-turbo-1106"])  # type: ignore
-
-    # Retrieve chat model (temperature, max_tokens, model)
-    if (
-        not chat_question.model
-        or not chat_question.temperature
-        or not chat_question.max_tokens
-    ):
-        if brain_id:
-            brain = brain_service.get_brain_by_id(brain_id)
-            if brain:
-                fallback_model = brain.model or fallback_model
-                fallback_temperature = brain.temperature or fallback_temperature
-                fallback_max_tokens = brain.max_tokens or fallback_max_tokens
-
-        chat_question.model = chat_question.model or fallback_model
-        chat_question.temperature = chat_question.temperature or fallback_temperature
-        chat_question.max_tokens = chat_question.max_tokens or fallback_max_tokens
-
     try:
-        check_user_requests_limit(current_user, chat_question.model)
-        is_model_ok = (chat_question).model in user_settings.get("models", ["gpt-3.5-turbo-1106"])  # type: ignore
-        gpt_answer_generator = chat_instance.get_answer_generator(
-            chat_id=str(chat_id),
-            model=chat_question.model if is_model_ok else "gpt-3.5-turbo-1106",  # type: ignore
-            max_tokens=chat_question.max_tokens,
-            temperature=chat_question.temperature,
-            streaming=False,
-            prompt_id=chat_question.prompt_id,
-            user_id=current_user.id,
-            max_input=2000,
-            brain=brain_service.get_brain_by_id(brain_id),
-            metadata={},
+        gpt_answer_generator = get_answer_generator(
+            chat_id, chat_question, brain_id, current_user
         )

         chat_answer = gpt_answer_generator.generate_answer(
@@ -203,86 +216,11 @@ async def create_stream_question_handler(
     | None = Query(..., description="The ID of the brain"),
     current_user: UserIdentity = Depends(get_current_user),
 ) -> StreamingResponse:
-    chat_instance = BrainfulChat()
-    chat_instance.validate_authorization(user_id=current_user.id, brain_id=brain_id)
-
-    user_usage = UserUsage(
-        id=current_user.id,
-        email=current_user.email,
+    gpt_answer_generator = get_answer_generator(
+        chat_id, chat_question, brain_id, current_user
     )
-
-    # Get History
-    history = chat_service.get_chat_history(chat_id)
-
-    # Get user settings
-    user_settings = user_usage.get_user_settings()
-
-    # Get Model settings for the user
-    models_settings = user_usage.get_model_settings()
-
-    # Generic
-    brain_id_to_use, metadata_brain = brain_service.find_brain_from_question(
-        brain_id, chat_question.question, current_user, chat_id, history
-    )
-
-    # Add metadata_brain to metadata
-    metadata = {}
-    metadata = {**metadata, **metadata_brain}
-    follow_up_questions = chat_service.get_follow_up_question(chat_id)
-    metadata["follow_up_questions"] = follow_up_questions
-
-    # Get the Brain settings
-    brain = brain_service.get_brain_by_id(brain_id_to_use)
-
-    logger.info(f"Brain model: {brain.model}")
-    logger.info(f"Brain is : {str(brain)}")
     try:
-        # Default model is gpt-3.5-turbo-1106
-        model_to_use = LLMModels(
-            name="gpt-3.5-turbo-1106", price=1, max_input=512, max_output=512
-        )
-
-        is_brain_model_available = any(
-            brain.model == model_dict.get("name") for model_dict in models_settings
-        )
-
-        is_user_allowed_model = brain.model in user_settings.get(
-            "models", ["gpt-3.5-turbo-1106"]
-        )  # Checks if the model is available in the list of models
-
-        logger.info(f"Brain model: {brain.model}")
-        logger.info(f"User models: {user_settings.get('models', [])}")
-        logger.info(f"Model available: {is_brain_model_available}")
-        logger.info(f"User allowed model: {is_user_allowed_model}")
-
-        if is_brain_model_available and is_user_allowed_model:
-            # Use the model from the brain
-            model_to_use.name = brain.model
-            for model_dict in models_settings:
-                if model_dict.get("name") == model_to_use.name:
-                    logger.info(f"Using model {model_to_use.name}")
-                    model_to_use.max_input = model_dict.get("max_input")
-                    model_to_use.max_output = model_dict.get("max_output")
-                    break
-
-        metadata["model"] = model_to_use.name
-        metadata["max_tokens"] = model_to_use.max_output
-        metadata["max_input"] = model_to_use.max_input
-
-        check_user_requests_limit(current_user, model_to_use.name)
-        gpt_answer_generator = chat_instance.get_answer_generator(
-            chat_id=str(chat_id),
-            model=model_to_use.name,
-            max_tokens=model_to_use.max_output,
-            max_input=model_to_use.max_input,
-            temperature=0.1,
-            streaming=True,
-            prompt_id=chat_question.prompt_id,
-            user_id=current_user.id,
-            metadata=metadata,
-            brain=brain,
-        )

         return StreamingResponse(
             gpt_answer_generator.generate_stream(
                 chat_id, chat_question, save_answer=True
@@ -53,25 +53,12 @@ def test_create_chat_and_talk(client, api_key):
             "model": "gpt-3.5-turbo-1106",
             "question": "Hello, how are you?",
             "temperature": "0",
-            "max_tokens": "256",
+            "max_tokens": "2000",
         },
         headers={"Authorization": "Bearer " + api_key},
     )
     assert response.status_code == 200

-    response = client.post(
-        f"/chat/{chat_id}/question?brain_id={default_brain_id}",
-        json={
-            "model": "gpt-4",
-            "question": "Hello, how are you?",
-            "temperature": "0",
-            "max_tokens": "256",
-        },
-        headers={"Authorization": "Bearer " + api_key},
-    )
-    print(response)
-    assert response.status_code == 200
-
     # Now, let's delete the chat
     delete_response = client.delete(
         "/chat/" + chat_id, headers={"Authorization": "Bearer " + api_key}