fix: update max token overwrite logic (#1694)

Issue: https://github.com/StanGirard/quivr/issues/1690
Mamadou DICKO authored 2023-11-23 17:36:11 +01:00, committed by GitHub
parent d893ec7d97
commit 4bb1a0dc8a
2 changed files with 47 additions and 29 deletions

View File

@@ -55,7 +55,7 @@ class APIBrainQA(
         response = completion(
             model=self.model,
             temperature=self.temperature,
-            max_tokens=2000,
+            max_tokens=self.max_tokens,
             messages=messages,
             functions=functions,
             stream=True,
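
This one-line change is the heart of the fix: `completion()` was called with a hardcoded `max_tokens=2000`, silently overriding whatever limit the user or brain had configured, while `self.max_tokens` was stored on the instance but ignored at the call site. A minimal sketch of the before/after behaviour (`fake_completion` is a hypothetical stand-in for litellm's `completion`, not the real call):

# Minimal sketch of the bug and the fix; `fake_completion` is a
# hypothetical stand-in for litellm's `completion`.
def fake_completion(model: str, max_tokens: int) -> str:
    return f"{model} capped at {max_tokens} tokens"


class BeforeFix:
    def __init__(self, model: str, max_tokens: int):
        self.model = model
        self.max_tokens = max_tokens  # stored but ignored below

    def ask(self) -> str:
        return fake_completion(self.model, max_tokens=2000)  # hardcoded override


class AfterFix(BeforeFix):
    def ask(self) -> str:
        return fake_completion(self.model, max_tokens=self.max_tokens)  # honours config


print(BeforeFix("gpt-3.5-turbo", 512).ask())  # gpt-3.5-turbo capped at 2000 tokens
print(AfterFix("gpt-3.5-turbo", 512).ask())   # gpt-3.5-turbo capped at 512 tokens
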
@@ -109,10 +109,14 @@ class APIBrainQA(
                     yield value
             else:
-                if hasattr(chunk.choices[0], 'delta') and chunk.choices[0].delta and hasattr(chunk.choices[0].delta, 'content'):
+                if (
+                    hasattr(chunk.choices[0], "delta")
+                    and chunk.choices[0].delta
+                    and hasattr(chunk.choices[0].delta, "content")
+                ):
                     content = chunk.choices[0].delta.content
                     yield content
-                else: # pragma: no cover
+                else:  # pragma: no cover
                     yield "**...**"
                     break
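
The second hunk is mostly Black-style reformatting, but the guard it reshapes is what keeps the stream well-behaved: a chunk is forwarded only if its first choice carries a `delta` with a `content` attribute; otherwise the route yields a placeholder and stops. A self-contained sketch of that guard (the `Delta`/`Choice`/`Chunk` dataclasses are stand-ins for litellm's streaming objects, shapes assumed):

# Standalone sketch of the streaming guard above; Delta/Choice/Chunk are
# hypothetical stand-ins for litellm streaming response objects.
from dataclasses import dataclass, field
from typing import Iterator, List, Optional


@dataclass
class Delta:
    content: Optional[str] = None


@dataclass
class Choice:
    delta: Optional[Delta] = None


@dataclass
class Chunk:
    choices: List[Choice] = field(default_factory=list)


def stream_contents(chunks: List[Chunk]) -> Iterator[str]:
    for chunk in chunks:
        choice = chunk.choices[0]
        if hasattr(choice, "delta") and choice.delta and hasattr(choice.delta, "content"):
            yield choice.delta.content or ""
        else:
            yield "**...**"  # same placeholder the handler emits before breaking
            break


chunks = [Chunk([Choice(Delta("Hel"))]), Chunk([Choice(Delta("lo"))]), Chunk([Choice()])]
print("".join(stream_contents(chunks)))  # Hello**...**
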

View File

@@ -7,9 +7,10 @@ from fastapi.responses import StreamingResponse
 from llm.qa_base import QABaseBrainPicking
 from llm.qa_headless import HeadlessQA
 from middlewares.auth import AuthBearer, get_current_user
-from models import Brain, BrainEntity, Chat, ChatQuestion, UserUsage, get_supabase_db
+from models import Chat, ChatQuestion, UserUsage, get_supabase_db
 from models.databases.supabase.chats import QuestionAndAnswer
 from modules.user.entity.user_identity import UserIdentity
+from repository.brain.get_brain_by_id import get_brain_by_id
 from repository.chat import (
     ChatUpdatableProperties,
     CreateChatProperties,
@@ -25,6 +26,7 @@ from repository.chat.get_chat_history_with_notifications import (
     get_chat_history_with_notifications,
 )
 from repository.notification.remove_chat_notifications import remove_chat_notifications
+from routes.chat.factory import get_chat_strategy
 from routes.chat.utils import (
     NullableUUID,
@@ -133,15 +135,16 @@ async def create_question_handler(
     chat_instance.validate_authorization(user_id=current_user.id, brain_id=brain_id)

-    brain = Brain(id=brain_id)
-    brain_details: BrainEntity | None = None
+    fallback_model = "gpt-3.5-turbo"
+    fallback_temperature = 0.1
+    fallback_max_tokens = 512

-    userDailyUsage = UserUsage(
+    user_daily_usage = UserUsage(
         id=current_user.id,
         email=current_user.email,
     )

-    userSettings = userDailyUsage.get_user_settings()
-    is_model_ok = (brain_details or chat_question).model in userSettings.get("models", ["gpt-3.5-turbo"])  # type: ignore
+    user_settings = user_daily_usage.get_user_settings()
+    is_model_ok = (chat_question).model in user_settings.get("models", ["gpt-3.5-turbo"])  # type: ignore

     # Retrieve chat model (temperature, max_tokens, model)
     if (
@@ -149,16 +152,20 @@ async def create_question_handler(
         or not chat_question.temperature
         or not chat_question.max_tokens
     ):
-        # TODO: create ChatConfig class (pick config from brain or user or chat) and use it here
-        chat_question.model = chat_question.model or brain.model or "gpt-3.5-turbo"
-        chat_question.temperature = (
-            chat_question.temperature or brain.temperature or 0.1
-        )
-        chat_question.max_tokens = chat_question.max_tokens or brain.max_tokens or 512
+        if brain_id:
+            brain = get_brain_by_id(brain_id)
+            if brain:
+                fallback_model = brain.model or fallback_model
+                fallback_temperature = brain.temperature or fallback_temperature
+                fallback_max_tokens = brain.max_tokens or fallback_max_tokens
+
+        chat_question.model = chat_question.model or fallback_model
+        chat_question.temperature = chat_question.temperature or fallback_temperature
+        chat_question.max_tokens = chat_question.max_tokens or fallback_max_tokens

     try:
         check_user_requests_limit(current_user)
-        is_model_ok = (brain_details or chat_question).model in userSettings.get("models", ["gpt-3.5-turbo"])  # type: ignore
+        is_model_ok = (chat_question).model in user_settings.get("models", ["gpt-3.5-turbo"])  # type: ignore
         gpt_answer_generator = chat_instance.get_answer_generator(
             chat_id=str(chat_id),
             model=chat_question.model if is_model_ok else "gpt-3.5-turbo",  # type: ignore
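
Both handlers now resolve each setting through the same precedence chain: an explicit value on the incoming `chat_question` wins, then the brain's stored value (fetched lazily via `get_brain_by_id` only when a `brain_id` was supplied), then a hardcoded default. A minimal sketch of that chain (`resolve_setting` is a hypothetical helper, not part of the codebase):

# Hypothetical helper showing the precedence the diff implements:
# request value -> brain value -> hardcoded default.
from typing import Optional, TypeVar

T = TypeVar("T")


def resolve_setting(requested: Optional[T], from_brain: Optional[T], default: T) -> T:
    return requested or from_brain or default


print(resolve_setting(None, 1024, 512))  # 1024: brain setting fills the gap
print(resolve_setting(256, 1024, 512))   # 256: explicit request wins
print(resolve_setting(None, None, 512))  # 512: route default as last resort

One caveat of the `or` chain is that falsy values such as 0 are treated as unset, which matters for temperature; see the note before the final hunk below.
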
@@ -199,14 +206,12 @@ async def create_stream_question_handler(
     chat_instance = get_chat_strategy(brain_id)
     chat_instance.validate_authorization(user_id=current_user.id, brain_id=brain_id)

-    brain = Brain(id=brain_id)
-    brain_details: BrainEntity | None = None
-    userDailyUsage = UserUsage(
+    user_daily_usage = UserUsage(
         id=current_user.id,
         email=current_user.email,
     )

-    userSettings = userDailyUsage.get_user_settings()
+    user_settings = user_daily_usage.get_user_settings()

     # Retrieve chat model (temperature, max_tokens, model)
     if (
@@ -214,10 +219,20 @@ async def create_stream_question_handler(
         or chat_question.temperature is None
         or not chat_question.max_tokens
     ):
-        # TODO: create ChatConfig class (pick config from brain or user or chat) and use it here
-        chat_question.model = chat_question.model or brain.model or "gpt-3.5-turbo"
-        chat_question.temperature = chat_question.temperature or brain.temperature or 0
-        chat_question.max_tokens = chat_question.max_tokens or brain.max_tokens or 256
+        fallback_model = "gpt-3.5-turbo"
+        fallback_temperature = 0
+        fallback_max_tokens = 256
+
+        if brain_id:
+            brain = get_brain_by_id(brain_id)
+            if brain:
+                fallback_model = brain.model or fallback_model
+                fallback_temperature = brain.temperature or fallback_temperature
+                fallback_max_tokens = brain.max_tokens or fallback_max_tokens
+
+        chat_question.model = chat_question.model or fallback_model
+        chat_question.temperature = chat_question.temperature or fallback_temperature
+        chat_question.max_tokens = chat_question.max_tokens or fallback_max_tokens

     try:
         logger.info(f"Streaming request for {chat_question.model}")
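
Worth noting before the final hunk: the streaming guard checks `chat_question.temperature is None` where the non-streaming handler uses `not chat_question.temperature`. The difference matters because a temperature of 0 is falsy yet is a legitimate, explicitly set value (and is the streaming default here). A two-line illustration:

temperature = 0  # explicitly requested deterministic sampling
print(not temperature)      # True  -> `not x` would treat the explicit 0 as unset
print(temperature is None)  # False -> `is None` preserves it

The `or`-based fallback assignments inside the block still coerce a falsy value to the fallback, which is harmless here only because the streaming fallback temperature is itself 0.
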
@@ -225,13 +240,12 @@ async def create_stream_question_handler(
         gpt_answer_generator: HeadlessQA | QABaseBrainPicking

-        # TODO check if model is in the list of models available for the user
-        is_model_ok = (brain_details or chat_question).model in userSettings.get("models", ["gpt-3.5-turbo"])  # type: ignore
+        is_model_ok = chat_question.model in user_settings.get("models", ["gpt-3.5-turbo"])  # type: ignore

         gpt_answer_generator = chat_instance.get_answer_generator(
             chat_id=str(chat_id),
-            model=(brain_details or chat_question).model if is_model_ok else "gpt-3.5-turbo",  # type: ignore
-            max_tokens=(brain_details or chat_question).max_tokens,  # type: ignore
-            temperature=(brain_details or chat_question).temperature,  # type: ignore
+            model=chat_question.model if is_model_ok else "gpt-3.5-turbo",  # type: ignore
+            max_tokens=chat_question.max_tokens,
+            temperature=chat_question.temperature,  # type: ignore
             streaming=True,
             prompt_id=chat_question.prompt_id,
             brain_id=str(brain_id),
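
Finally, the `(brain_details or chat_question)` indirection disappears: this diff shows `brain_details` being initialised to `None` and never reassigned on these paths, so the expression always resolved to `chat_question` anyway. The generator now reads its settings from the question directly, gated by the user's model allow-list. A sketch of that gate (the `user_settings` dict shape is assumed from the route's `user_settings.get("models", ...)` usage):

# Sketch of the model allow-list gate; the user_settings dict shape is
# assumed from how the route calls user_settings.get("models", ...).
DEFAULT_MODEL = "gpt-3.5-turbo"


def pick_model(requested: str, user_settings: dict) -> str:
    allowed = user_settings.get("models", [DEFAULT_MODEL])
    return requested if requested in allowed else DEFAULT_MODEL


print(pick_model("gpt-4", {"models": ["gpt-3.5-turbo", "gpt-4"]}))  # gpt-4
print(pick_model("gpt-4", {}))  # falls back to gpt-3.5-turbo
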