diff --git a/backend/llm/openai.py b/backend/llm/openai.py
index 9857c6693..dd6ab9ef0 100644
--- a/backend/llm/openai.py
+++ b/backend/llm/openai.py
@@ -2,9 +2,8 @@ from typing import Optional
from uuid import UUID
from langchain.embeddings.openai import OpenAIEmbeddings
-from logger import get_logger
-
from llm.qa_base import QABaseBrainPicking
+from logger import get_logger
logger = get_logger(__name__)
@@ -16,7 +15,8 @@ class OpenAIBrainPicking(QABaseBrainPicking):
"""
# Default class attributes
- model: str = "gpt-3.5-turbo"
+ model: str
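+ # No default model; the chat routes now resolve one from the user's settings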
def __init__(
self,
diff --git a/backend/models/brains.py b/backend/models/brains.py
index ea3b88771..e695bcf80 100644
--- a/backend/models/brains.py
+++ b/backend/models/brains.py
@@ -16,7 +16,7 @@ class Brain(BaseModel):
name: Optional[str] = "Default brain"
description: Optional[str] = "This is a description"
status: Optional[str] = "private"
- model: Optional[str] = "gpt-3.5-turbo"
+ model: Optional[str]
temperature: Optional[float] = 0.0
max_tokens: Optional[int] = 256
openai_api_key: Optional[str] = None
diff --git a/backend/models/chats.py b/backend/models/chats.py
index 2b2d8a115..8674ad021 100644
--- a/backend/models/chats.py
+++ b/backend/models/chats.py
@@ -5,7 +5,7 @@ from pydantic import BaseModel
class ChatMessage(BaseModel):
- model: str = "gpt-3.5-turbo-16k"
+ model: str
question: str
# A list of tuples where each tuple is (speaker, text)
history: List[Tuple[str, str]]
diff --git a/backend/models/databases/supabase/brains.py b/backend/models/databases/supabase/brains.py
index 12303cbb1..26773d3d6 100644
--- a/backend/models/databases/supabase/brains.py
+++ b/backend/models/databases/supabase/brains.py
@@ -13,7 +13,7 @@ class CreateBrainProperties(BaseModel):
name: Optional[str] = "Default brain"
description: Optional[str] = "This is a description"
status: Optional[str] = "private"
- model: Optional[str] = "gpt-3.5-turbo"
+ model: Optional[str]
temperature: Optional[float] = 0.0
max_tokens: Optional[int] = 256
openai_api_key: Optional[str] = None
diff --git a/backend/routes/chat_routes.py b/backend/routes/chat_routes.py
index 8a2e128c0..ffff085d2 100644
--- a/backend/routes/chat_routes.py
+++ b/backend/routes/chat_routes.py
@@ -185,6 +185,15 @@ async def create_question_handler(
# Retrieve user's OpenAI API key
current_user.openai_api_key = request.headers.get("Openai-Api-Key")
brain = Brain(id=brain_id)
+ brain_details: BrainEntity | None = None
+
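+ # Look up the user's usage record; its settings list the models the user may call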
+ user_daily_usage = UserUsage(
+ id=current_user.id,
+ email=current_user.email,
+ openai_api_key=current_user.openai_api_key,
+ )
+ user_settings = user_daily_usage.get_user_settings()
if not current_user.openai_api_key and brain_id:
brain_details = get_brain_details(brain_id)
@@ -210,12 +219,13 @@ async def create_question_handler(
try:
check_user_requests_limit(current_user)
-
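+ # Validate the requested model against the user's settings, defaulting to gpt-3.5-turbo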
+ is_model_ok = (brain_details or chat_question).model in user_settings.get("models", ["gpt-3.5-turbo"]) # type: ignore
gpt_answer_generator: HeadlessQA | OpenAIBrainPicking
if brain_id:
gpt_answer_generator = OpenAIBrainPicking(
chat_id=str(chat_id),
- model=chat_question.model,
+ model=chat_question.model if is_model_ok else "gpt-3.5-turbo", # type: ignore
max_tokens=chat_question.max_tokens,
temperature=chat_question.temperature,
brain_id=str(brain_id),
@@ -224,7 +233,7 @@ async def create_question_handler(
)
else:
gpt_answer_generator = HeadlessQA(
- model=chat_question.model,
+ model=chat_question.model if is_model_ok else "gpt-3.5-turbo", # type: ignore
temperature=chat_question.temperature,
max_tokens=chat_question.max_tokens,
user_openai_api_key=current_user.openai_api_key,
@@ -264,6 +274,14 @@ async def create_stream_question_handler(
current_user.openai_api_key = request.headers.get("Openai-Api-Key")
brain = Brain(id=brain_id)
brain_details: BrainEntity | None = None
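+ # Fetch the user's settings so the requested model can be validated below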
+ user_daily_usage = UserUsage(
+ id=current_user.id,
+ email=current_user.email,
+ openai_api_key=current_user.openai_api_key,
+ )
+
+ user_settings = user_daily_usage.get_user_settings()
if not current_user.openai_api_key and brain_id:
brain_details = get_brain_details(brain_id)
if brain_details:
@@ -290,18 +308,14 @@ async def create_stream_question_handler(
logger.info(f"Streaming request for {chat_question.model}")
check_user_requests_limit(current_user)
gpt_answer_generator: HeadlessQA | OpenAIBrainPicking
+ # Only pass the requested model through if the user's settings allow it
+ is_model_ok = (brain_details or chat_question).model in user_settings.get("models", ["gpt-3.5-turbo"]) # type: ignore
if brain_id:
gpt_answer_generator = OpenAIBrainPicking(
chat_id=str(chat_id),
- model=(brain_details or chat_question).model
- if current_user.openai_api_key
- else "gpt-3.5-turbo", # type: ignore
- max_tokens=(brain_details or chat_question).max_tokens
- if current_user.openai_api_key
- else 256, # type: ignore
- temperature=(brain_details or chat_question).temperature
- if current_user.openai_api_key
- else 0, # type: ignore
+ model=(brain_details or chat_question).model if is_model_ok else "gpt-3.5-turbo", # type: ignore
+ max_tokens=(brain_details or chat_question).max_tokens, # type: ignore
+ temperature=(brain_details or chat_question).temperature, # type: ignore
brain_id=str(brain_id),
user_openai_api_key=current_user.openai_api_key, # pyright: ignore reportPrivateUsage=none
streaming=True,
@@ -309,15 +322,9 @@ async def create_stream_question_handler(
)
else:
gpt_answer_generator = HeadlessQA(
- model=chat_question.model
- if current_user.openai_api_key
- else "gpt-3.5-turbo",
- temperature=chat_question.temperature
- if current_user.openai_api_key
- else 0,
- max_tokens=chat_question.max_tokens
- if current_user.openai_api_key
- else 256,
+ model=chat_question.model if is_model_ok else "gpt-3.5-turbo", # type: ignore
+ temperature=chat_question.temperature,
+ max_tokens=chat_question.max_tokens,
user_openai_api_key=current_user.openai_api_key, # pyright: ignore reportPrivateUsage=none
chat_id=str(chat_id),
streaming=True,
diff --git a/frontend/lib/components/AddBrainModal/AddBrainModal.tsx b/frontend/lib/components/AddBrainModal/AddBrainModal.tsx
index fee22bbc6..340dd7560 100644
--- a/frontend/lib/components/AddBrainModal/AddBrainModal.tsx
+++ b/frontend/lib/components/AddBrainModal/AddBrainModal.tsx
@@ -1,5 +1,5 @@
/* eslint-disable @typescript-eslint/no-unsafe-assignment */
-/* eslint-disable max-lines */
+/* eslint-disable */
import { useTranslation } from "react-i18next";
import { MdAdd } from "react-icons/md";
@@ -8,11 +8,10 @@ import Button from "@/lib/components/ui/Button";
import Field from "@/lib/components/ui/Field";
import { Modal } from "@/lib/components/ui/Modal";
import { defineMaxTokens } from "@/lib/helpers/defineMaxTokens";
-import { freeModels, paidModels } from "@/lib/types/brainConfig";
-import { useAddBrainModal } from "./hooks/useAddBrainModal";
import { Divider } from "../ui/Divider";
import { TextArea } from "../ui/TextArea";
+import { useAddBrainModal } from "./hooks/useAddBrainModal";
export const AddBrainModal = (): JSX.Element => {
const { t } = useTranslation(["translation", "brain", "config"]);
@@ -27,6 +26,7 @@ export const AddBrainModal = (): JSX.Element => {
model,
isPending,
pickPublicPrompt,
+ accessibleModels,
} = useAddBrainModal();
return (
@@ -89,13 +89,11 @@ export const AddBrainModal = (): JSX.Element => {
{...register("model")}
className="px-5 py-2 dark:bg-gray-700 bg-gray-200 rounded-md"
>
- {(openAiKey !== undefined ? paidModels : freeModels).map(
- (availableModel) => (
- <option value={availableModel} key={availableModel}>
- {availableModel}
- </option>
- )
- )}
+ {accessibleModels.map((availableModel) => (
+ <option value={availableModel} key={availableModel}>
+ {availableModel}
+ </option>
+ ))}
diff --git a/frontend/lib/components/AddBrainModal/hooks/useAddBrainModal.ts b/frontend/lib/components/AddBrainModal/hooks/useAddBrainModal.ts
index 33ac04443..a60a8ed09 100644
--- a/frontend/lib/components/AddBrainModal/hooks/useAddBrainModal.ts
+++ b/frontend/lib/components/AddBrainModal/hooks/useAddBrainModal.ts
@@ -1,4 +1,4 @@
-/* eslint-disable max-lines */
+/* eslint-disable */
import axios from "axios";
import { useEffect, useState } from "react";
import { useForm } from "react-hook-form";
@@ -6,10 +6,14 @@ import { useTranslation } from "react-i18next";
import { useBrainApi } from "@/lib/api/brain/useBrainApi";
import { usePromptApi } from "@/lib/api/prompt/usePromptApi";
+import { USER_DATA_KEY } from "@/lib/api/user/config";
+import { useUserApi } from "@/lib/api/user/useUserApi";
import { defaultBrainConfig } from "@/lib/config/defaultBrainConfig";
import { useBrainContext } from "@/lib/context/BrainProvider/hooks/useBrainContext";
import { defineMaxTokens } from "@/lib/helpers/defineMaxTokens";
+import { getAccessibleModels } from "@/lib/helpers/getAccessibleModels";
import { useToast } from "@/lib/hooks";
+import { useQuery } from "@tanstack/react-query";
// eslint-disable-next-line @typescript-eslint/explicit-module-boundary-types
export const useAddBrainModal = () => {
@@ -21,6 +25,13 @@ export const useAddBrainModal = () => {
const { createPrompt } = usePromptApi();
const [isShareModalOpen, setIsShareModalOpen] = useState(false);
+ const { getUser } = useUserApi();
+
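+ // userData carries the models available to this user (consumed by getAccessibleModels below)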
+ const { data: userData } = useQuery({
+ queryKey: [USER_DATA_KEY],
+ queryFn: getUser,
+ });
const defaultValues = {
...defaultBrainConfig,
name: "",
@@ -41,6 +52,12 @@ export const useAddBrainModal = () => {
const temperature = watch("temperature");
const maxTokens = watch("maxTokens");
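+ // Only expose models the user can actually use, given their OpenAI key and user data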
+ const accessibleModels = getAccessibleModels({
+ openAiKey,
+ userData,
+ });
+
useEffect(() => {
setValue("maxTokens", Math.min(maxTokens, defineMaxTokens(model)));
}, [maxTokens, model, setValue]);
@@ -155,6 +172,7 @@ export const useAddBrainModal = () => {
temperature,
maxTokens,
isPending,
+ accessibleModels,
pickPublicPrompt,
};
};