fix: update max token overwrite logic (#1694)
Issue: https://github.com/StanGirard/quivr/issues/1690
Parent: d893ec7d97
Commit: 4bb1a0dc8a
@@ -55,7 +55,7 @@ class APIBrainQA(
         response = completion(
             model=self.model,
             temperature=self.temperature,
-            max_tokens=2000,
+            max_tokens=self.max_tokens,
             messages=messages,
             functions=functions,
             stream=True,
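This hunk stops hardcoding max_tokens=2000 and instead forwards the instance's configured self.max_tokens to litellm's completion call. A minimal sketch of the pattern, with illustrative names (StreamingBrainQA is not the project's class; only the completion kwargs mirror the diff):

    from litellm import completion

    class StreamingBrainQA:
        def __init__(self, model="gpt-3.5-turbo", temperature=0.1, max_tokens=512):
            self.model = model
            self.temperature = temperature
            # Previously a hardcoded 2000 at the call site; now configurable.
            self.max_tokens = max_tokens

        def stream(self, messages, functions):
            # litellm.completion takes OpenAI-style kwargs; stream=True yields chunks.
            return completion(
                model=self.model,
                temperature=self.temperature,
                max_tokens=self.max_tokens,
                messages=messages,
                functions=functions,
                stream=True,
            )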
@@ -109,10 +109,14 @@ class APIBrainQA(
                         yield value

                 else:
-                    if hasattr(chunk.choices[0], 'delta') and chunk.choices[0].delta and hasattr(chunk.choices[0].delta, 'content'):
+                    if (
+                        hasattr(chunk.choices[0], "delta")
+                        and chunk.choices[0].delta
+                        and hasattr(chunk.choices[0].delta, "content")
+                    ):
                         content = chunk.choices[0].delta.content
                         yield content
-                    else: # pragma: no cover
+                    else:  # pragma: no cover
                         yield "**...**"
                         break

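The reformatted guard above only dereferences chunk.choices[0].delta.content when the streamed chunk actually carries it. The same defensive pattern as a standalone generator, a sketch assuming chunks shaped like litellm/OpenAI streaming responses:

    def iter_stream_content(response):
        # Yield text deltas from a streamed completion; stop on a malformed chunk.
        for chunk in response:
            choice = chunk.choices[0]
            if (
                hasattr(choice, "delta")
                and choice.delta
                and hasattr(choice.delta, "content")
            ):
                yield choice.delta.content
            else:  # pragma: no cover
                yield "**...**"
                break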
@@ -7,9 +7,10 @@ from fastapi.responses import StreamingResponse
 from llm.qa_base import QABaseBrainPicking
 from llm.qa_headless import HeadlessQA
 from middlewares.auth import AuthBearer, get_current_user
-from models import Brain, BrainEntity, Chat, ChatQuestion, UserUsage, get_supabase_db
+from models import Chat, ChatQuestion, UserUsage, get_supabase_db
 from models.databases.supabase.chats import QuestionAndAnswer
 from modules.user.entity.user_identity import UserIdentity
+from repository.brain.get_brain_by_id import get_brain_by_id
 from repository.chat import (
     ChatUpdatableProperties,
     CreateChatProperties,
@@ -25,6 +26,7 @@ from repository.chat.get_chat_history_with_notifications import (
     get_chat_history_with_notifications,
 )
 from repository.notification.remove_chat_notifications import remove_chat_notifications
+
 from routes.chat.factory import get_chat_strategy
 from routes.chat.utils import (
     NullableUUID,
@@ -133,15 +135,16 @@ async def create_question_handler(

     chat_instance.validate_authorization(user_id=current_user.id, brain_id=brain_id)

-    brain = Brain(id=brain_id)
-    brain_details: BrainEntity | None = None
+    fallback_model = "gpt-3.5-turbo"
+    fallback_temperature = 0.1
+    fallback_max_tokens = 512

-    userDailyUsage = UserUsage(
+    user_daily_usage = UserUsage(
         id=current_user.id,
         email=current_user.email,
     )
-    userSettings = userDailyUsage.get_user_settings()
-    is_model_ok = (brain_details or chat_question).model in userSettings.get("models", ["gpt-3.5-turbo"])  # type: ignore
+    user_settings = user_daily_usage.get_user_settings()
+    is_model_ok = (chat_question).model in user_settings.get("models", ["gpt-3.5-turbo"])  # type: ignore

     # Retrieve chat model (temperature, max_tokens, model)
     if (
@@ -149,16 +152,20 @@ async def create_question_handler(
         or not chat_question.temperature
         or not chat_question.max_tokens
     ):
-        # TODO: create ChatConfig class (pick config from brain or user or chat) and use it here
-        chat_question.model = chat_question.model or brain.model or "gpt-3.5-turbo"
-        chat_question.temperature = (
-            chat_question.temperature or brain.temperature or 0.1
-        )
-        chat_question.max_tokens = chat_question.max_tokens or brain.max_tokens or 512
+        if brain_id:
+            brain = get_brain_by_id(brain_id)
+            if brain:
+                fallback_model = brain.model or fallback_model
+                fallback_temperature = brain.temperature or fallback_temperature
+                fallback_max_tokens = brain.max_tokens or fallback_max_tokens
+
+        chat_question.model = chat_question.model or fallback_model
+        chat_question.temperature = chat_question.temperature or fallback_temperature
+        chat_question.max_tokens = chat_question.max_tokens or fallback_max_tokens

     try:
         check_user_requests_limit(current_user)
-        is_model_ok = (brain_details or chat_question).model in userSettings.get("models", ["gpt-3.5-turbo"])  # type: ignore
+        is_model_ok = (chat_question).model in user_settings.get("models", ["gpt-3.5-turbo"])  # type: ignore
         gpt_answer_generator = chat_instance.get_answer_generator(
             chat_id=str(chat_id),
             model=chat_question.model if is_model_ok else "gpt-3.5-turbo",  # type: ignore
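The removed TODO asked for a ChatConfig class that picks configuration from the brain, the user, or the chat; the new code inlines that precedence with fallback_* variables (question value, else brain value, else hard default). A hypothetical sketch of what such a helper could look like; ChatConfig and resolve_chat_config are not in the codebase, and chat_question/brain are duck-typed here:

    from dataclasses import dataclass

    @dataclass
    class ChatConfig:
        model: str = "gpt-3.5-turbo"
        temperature: float = 0.1
        max_tokens: int = 512

    def resolve_chat_config(chat_question, brain=None) -> ChatConfig:
        # Hypothetical helper: same precedence as the handler code above,
        # question value > brain value > hard-coded default.
        defaults = ChatConfig()
        return ChatConfig(
            model=chat_question.model or (brain and brain.model) or defaults.model,
            temperature=chat_question.temperature
            or (brain and brain.temperature)
            or defaults.temperature,
            max_tokens=chat_question.max_tokens
            or (brain and brain.max_tokens)
            or defaults.max_tokens,
        )

Note that, like the handler code, the or-chains treat falsy values as unset, so a requested temperature of 0 falls through to the fallback; that is also why the streaming handler's guard checks temperature is None explicitly.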
@@ -199,14 +206,12 @@ async def create_stream_question_handler(
     chat_instance = get_chat_strategy(brain_id)
     chat_instance.validate_authorization(user_id=current_user.id, brain_id=brain_id)

-    brain = Brain(id=brain_id)
-    brain_details: BrainEntity | None = None
-    userDailyUsage = UserUsage(
+    user_daily_usage = UserUsage(
         id=current_user.id,
         email=current_user.email,
     )

-    userSettings = userDailyUsage.get_user_settings()
+    user_settings = user_daily_usage.get_user_settings()

     # Retrieve chat model (temperature, max_tokens, model)
     if (
@@ -214,10 +219,20 @@ async def create_stream_question_handler(
         or chat_question.temperature is None
         or not chat_question.max_tokens
     ):
-        # TODO: create ChatConfig class (pick config from brain or user or chat) and use it here
-        chat_question.model = chat_question.model or brain.model or "gpt-3.5-turbo"
-        chat_question.temperature = chat_question.temperature or brain.temperature or 0
-        chat_question.max_tokens = chat_question.max_tokens or brain.max_tokens or 256
+        fallback_model = "gpt-3.5-turbo"
+        fallback_temperature = 0
+        fallback_max_tokens = 256
+
+        if brain_id:
+            brain = get_brain_by_id(brain_id)
+            if brain:
+                fallback_model = brain.model or fallback_model
+                fallback_temperature = brain.temperature or fallback_temperature
+                fallback_max_tokens = brain.max_tokens or fallback_max_tokens
+
+        chat_question.model = chat_question.model or fallback_model
+        chat_question.temperature = chat_question.temperature or fallback_temperature
+        chat_question.max_tokens = chat_question.max_tokens or fallback_max_tokens

     try:
         logger.info(f"Streaming request for {chat_question.model}")
@@ -225,13 +240,12 @@ async def create_stream_question_handler(
         gpt_answer_generator: HeadlessQA | QABaseBrainPicking
         # TODO check if model is in the list of models available for the user

-        is_model_ok = (brain_details or chat_question).model in userSettings.get("models", ["gpt-3.5-turbo"])  # type: ignore
-
+        is_model_ok = chat_question.model in user_settings.get("models", ["gpt-3.5-turbo"])  # type: ignore
         gpt_answer_generator = chat_instance.get_answer_generator(
             chat_id=str(chat_id),
-            model=(brain_details or chat_question).model if is_model_ok else "gpt-3.5-turbo",  # type: ignore
-            max_tokens=(brain_details or chat_question).max_tokens,  # type: ignore
-            temperature=(brain_details or chat_question).temperature,  # type: ignore
+            model=chat_question.model if is_model_ok else "gpt-3.5-turbo",  # type: ignore
+            max_tokens=chat_question.max_tokens,
+            temperature=chat_question.temperature,  # type: ignore
             streaming=True,
             prompt_id=chat_question.prompt_id,
             brain_id=str(brain_id),
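The last hunk drops the brain_details indirection: the generator now takes its model, max_tokens, and temperature straight from chat_question, with the model still validated against the user's allow-list. The same guard as a small helper, a sketch with assumed names:

    def pick_allowed_model(requested, user_settings, default="gpt-3.5-turbo"):
        # Mirrors: model=chat_question.model if is_model_ok else "gpt-3.5-turbo"
        allowed = user_settings.get("models", [default])
        return requested if requested in allowed else default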