diff --git a/backend/llm/openai.py b/backend/llm/openai.py index 9857c6693..dd6ab9ef0 100644 --- a/backend/llm/openai.py +++ b/backend/llm/openai.py @@ -2,9 +2,8 @@ from typing import Optional from uuid import UUID from langchain.embeddings.openai import OpenAIEmbeddings -from logger import get_logger - from llm.qa_base import QABaseBrainPicking +from logger import get_logger logger = get_logger(__name__) @@ -16,7 +15,7 @@ class OpenAIBrainPicking(QABaseBrainPicking): """ # Default class attributes - model: str = "gpt-3.5-turbo" + model: str def __init__( self, diff --git a/backend/models/brains.py b/backend/models/brains.py index ea3b88771..e695bcf80 100644 --- a/backend/models/brains.py +++ b/backend/models/brains.py @@ -16,7 +16,7 @@ class Brain(BaseModel): name: Optional[str] = "Default brain" description: Optional[str] = "This is a description" status: Optional[str] = "private" - model: Optional[str] = "gpt-3.5-turbo" + model: Optional[str] temperature: Optional[float] = 0.0 max_tokens: Optional[int] = 256 openai_api_key: Optional[str] = None diff --git a/backend/models/chats.py b/backend/models/chats.py index 2b2d8a115..8674ad021 100644 --- a/backend/models/chats.py +++ b/backend/models/chats.py @@ -5,7 +5,7 @@ from pydantic import BaseModel class ChatMessage(BaseModel): - model: str = "gpt-3.5-turbo-16k" + model: str question: str # A list of tuples where each tuple is (speaker, text) history: List[Tuple[str, str]] diff --git a/backend/models/databases/supabase/brains.py b/backend/models/databases/supabase/brains.py index 12303cbb1..26773d3d6 100644 --- a/backend/models/databases/supabase/brains.py +++ b/backend/models/databases/supabase/brains.py @@ -13,7 +13,7 @@ class CreateBrainProperties(BaseModel): name: Optional[str] = "Default brain" description: Optional[str] = "This is a description" status: Optional[str] = "private" - model: Optional[str] = "gpt-3.5-turbo" + model: Optional[str] temperature: Optional[float] = 0.0 max_tokens: Optional[int] = 256 openai_api_key: Optional[str] = None diff --git a/backend/routes/chat_routes.py b/backend/routes/chat_routes.py index 8a2e128c0..ffff085d2 100644 --- a/backend/routes/chat_routes.py +++ b/backend/routes/chat_routes.py @@ -185,6 +185,15 @@ async def create_question_handler( # Retrieve user's OpenAI API key current_user.openai_api_key = request.headers.get("Openai-Api-Key") brain = Brain(id=brain_id) + brain_details: BrainEntity | None = None + + userDailyUsage = UserUsage( + id=current_user.id, + email=current_user.email, + openai_api_key=current_user.openai_api_key, + ) + userSettings = userDailyUsage.get_user_settings() + is_model_ok = (brain_details or chat_question).model in userSettings.models # type: ignore if not current_user.openai_api_key and brain_id: brain_details = get_brain_details(brain_id) @@ -210,12 +219,12 @@ async def create_question_handler( try: check_user_requests_limit(current_user) - + is_model_ok = (brain_details or chat_question).model in userSettings.get("models", ["gpt-3.5-turbo"]) # type: ignore gpt_answer_generator: HeadlessQA | OpenAIBrainPicking if brain_id: gpt_answer_generator = OpenAIBrainPicking( chat_id=str(chat_id), - model=chat_question.model, + model=chat_question.model if is_model_ok else "gpt-3.5-turbo", # type: ignore max_tokens=chat_question.max_tokens, temperature=chat_question.temperature, brain_id=str(brain_id), @@ -224,7 +233,7 @@ async def create_question_handler( ) else: gpt_answer_generator = HeadlessQA( - model=chat_question.model, + model=chat_question.model if is_model_ok else "gpt-3.5-turbo", # type: ignore temperature=chat_question.temperature, max_tokens=chat_question.max_tokens, user_openai_api_key=current_user.openai_api_key, @@ -264,6 +273,13 @@ async def create_stream_question_handler( current_user.openai_api_key = request.headers.get("Openai-Api-Key") brain = Brain(id=brain_id) brain_details: BrainEntity | None = None + userDailyUsage = UserUsage( + id=current_user.id, + email=current_user.email, + openai_api_key=current_user.openai_api_key, + ) + + userSettings = userDailyUsage.get_user_settings() if not current_user.openai_api_key and brain_id: brain_details = get_brain_details(brain_id) if brain_details: @@ -290,18 +306,15 @@ async def create_stream_question_handler( logger.info(f"Streaming request for {chat_question.model}") check_user_requests_limit(current_user) gpt_answer_generator: HeadlessQA | OpenAIBrainPicking + # TODO check if model is in the list of models available for the user + print(userSettings.get("models", ["gpt-3.5-turbo"])) # type: ignore + is_model_ok = (brain_details or chat_question).model in userSettings.get("models", ["gpt-3.5-turbo"]) # type: ignore if brain_id: gpt_answer_generator = OpenAIBrainPicking( chat_id=str(chat_id), - model=(brain_details or chat_question).model - if current_user.openai_api_key - else "gpt-3.5-turbo", # type: ignore - max_tokens=(brain_details or chat_question).max_tokens - if current_user.openai_api_key - else 256, # type: ignore - temperature=(brain_details or chat_question).temperature - if current_user.openai_api_key - else 0, # type: ignore + model=(brain_details or chat_question).model if is_model_ok else "gpt-3.5-turbo", # type: ignore + max_tokens=(brain_details or chat_question).max_tokens, # type: ignore + temperature=(brain_details or chat_question).temperature, # type: ignore brain_id=str(brain_id), user_openai_api_key=current_user.openai_api_key, # pyright: ignore reportPrivateUsage=none streaming=True, @@ -309,15 +322,9 @@ async def create_stream_question_handler( ) else: gpt_answer_generator = HeadlessQA( - model=chat_question.model - if current_user.openai_api_key - else "gpt-3.5-turbo", - temperature=chat_question.temperature - if current_user.openai_api_key - else 0, - max_tokens=chat_question.max_tokens - if current_user.openai_api_key - else 256, + model=chat_question.model if is_model_ok else "gpt-3.5-turbo", # type: ignore + temperature=chat_question.temperature, + max_tokens=chat_question.max_tokens, user_openai_api_key=current_user.openai_api_key, # pyright: ignore reportPrivateUsage=none chat_id=str(chat_id), streaming=True, diff --git a/frontend/lib/components/AddBrainModal/AddBrainModal.tsx b/frontend/lib/components/AddBrainModal/AddBrainModal.tsx index fee22bbc6..340dd7560 100644 --- a/frontend/lib/components/AddBrainModal/AddBrainModal.tsx +++ b/frontend/lib/components/AddBrainModal/AddBrainModal.tsx @@ -1,5 +1,5 @@ /* eslint-disable @typescript-eslint/no-unsafe-assignment */ -/* eslint-disable max-lines */ +/* eslint-disable */ import { useTranslation } from "react-i18next"; import { MdAdd } from "react-icons/md"; @@ -8,11 +8,10 @@ import Button from "@/lib/components/ui/Button"; import Field from "@/lib/components/ui/Field"; import { Modal } from "@/lib/components/ui/Modal"; import { defineMaxTokens } from "@/lib/helpers/defineMaxTokens"; -import { freeModels, paidModels } from "@/lib/types/brainConfig"; -import { useAddBrainModal } from "./hooks/useAddBrainModal"; import { Divider } from "../ui/Divider"; import { TextArea } from "../ui/TextArea"; +import { useAddBrainModal } from "./hooks/useAddBrainModal"; export const AddBrainModal = (): JSX.Element => { const { t } = useTranslation(["translation", "brain", "config"]); @@ -27,6 +26,7 @@ export const AddBrainModal = (): JSX.Element => { model, isPending, pickPublicPrompt, + accessibleModels, } = useAddBrainModal(); return ( @@ -89,13 +89,11 @@ export const AddBrainModal = (): JSX.Element => { {...register("model")} className="px-5 py-2 dark:bg-gray-700 bg-gray-200 rounded-md" > - {(openAiKey !== undefined ? paidModels : freeModels).map( - (availableModel) => ( - - ) - )} + {accessibleModels.map((availableModel) => ( + + ))} diff --git a/frontend/lib/components/AddBrainModal/hooks/useAddBrainModal.ts b/frontend/lib/components/AddBrainModal/hooks/useAddBrainModal.ts index 33ac04443..a60a8ed09 100644 --- a/frontend/lib/components/AddBrainModal/hooks/useAddBrainModal.ts +++ b/frontend/lib/components/AddBrainModal/hooks/useAddBrainModal.ts @@ -1,4 +1,4 @@ -/* eslint-disable max-lines */ +/* eslint-disable */ import axios from "axios"; import { useEffect, useState } from "react"; import { useForm } from "react-hook-form"; @@ -6,10 +6,17 @@ import { useTranslation } from "react-i18next"; import { useBrainApi } from "@/lib/api/brain/useBrainApi"; import { usePromptApi } from "@/lib/api/prompt/usePromptApi"; +import { USER_DATA_KEY } from "@/lib/api/user/config"; +import { useUserApi } from "@/lib/api/user/useUserApi"; import { defaultBrainConfig } from "@/lib/config/defaultBrainConfig"; import { useBrainContext } from "@/lib/context/BrainProvider/hooks/useBrainContext"; import { defineMaxTokens } from "@/lib/helpers/defineMaxTokens"; +import { getAccessibleModels } from "@/lib/helpers/getAccessibleModels"; import { useToast } from "@/lib/hooks"; +import { useQuery } from "@tanstack/react-query"; + + + // eslint-disable-next-line @typescript-eslint/explicit-module-boundary-types export const useAddBrainModal = () => { @@ -21,6 +28,13 @@ export const useAddBrainModal = () => { const { createPrompt } = usePromptApi(); const [isShareModalOpen, setIsShareModalOpen] = useState(false); + const { getUser } = useUserApi(); + + + const { data: userData } = useQuery({ + queryKey: [USER_DATA_KEY], + queryFn: getUser, + }); const defaultValues = { ...defaultBrainConfig, name: "", @@ -32,6 +46,10 @@ export const useAddBrainModal = () => { }, }; + + + + const { register, getValues, reset, watch, setValue } = useForm({ defaultValues, }); @@ -41,6 +59,11 @@ export const useAddBrainModal = () => { const temperature = watch("temperature"); const maxTokens = watch("maxTokens"); + const accessibleModels = getAccessibleModels({ + openAiKey, + userData, + }); + useEffect(() => { setValue("maxTokens", Math.min(maxTokens, defineMaxTokens(model))); }, [maxTokens, model, setValue]); @@ -155,6 +178,7 @@ export const useAddBrainModal = () => { temperature, maxTokens, isPending, + accessibleModels, pickPublicPrompt, }; };