mirror of
https://github.com/QuivrHQ/quivr.git
synced 2024-12-14 17:03:29 +03:00
feat: the good user management (#1158)
* feat(user_management): added user management * feat(user_management): added user management * feat(user_management): removed print * feat: use tanstack query for user data fecthing * feat: add getUser to sdk * feat: improve user page ux use tanstack query * feat: fetch models from backend on brains settings page * feat: update model selection on chat page * feat: update tests --------- Co-authored-by: mamadoudicko <mamadoudicko100@gmail.com>
This commit is contained in:
parent
9eaba81288
commit
322ee318be
@ -2,13 +2,12 @@ from typing import Any, List, Optional
|
||||
from uuid import UUID
|
||||
|
||||
from logger import get_logger
|
||||
from models.databases.supabase.supabase import SupabaseDB
|
||||
from models.settings import get_supabase_client, get_supabase_db
|
||||
from pydantic import BaseModel
|
||||
from supabase.client import Client
|
||||
from utils.vectors import get_unique_files_from_vector_ids
|
||||
|
||||
from models.databases.supabase.supabase import SupabaseDB
|
||||
from models.settings import BrainRateLimiting, get_supabase_client, get_supabase_db
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@ -22,7 +21,6 @@ class Brain(BaseModel):
|
||||
max_tokens: Optional[int] = 256
|
||||
openai_api_key: Optional[str] = None
|
||||
files: List[Any] = []
|
||||
max_brain_size = BrainRateLimiting().max_brain_size
|
||||
prompt_id: Optional[UUID] = None
|
||||
|
||||
class Config:
|
||||
@ -43,13 +41,6 @@ class Brain(BaseModel):
|
||||
|
||||
return current_brain_size
|
||||
|
||||
@property
|
||||
def remaining_brain_size(self):
|
||||
return (
|
||||
float(self.max_brain_size) # pyright: ignore reportPrivateUsage=none
|
||||
- self.brain_size # pyright: ignore reportPrivateUsage=none
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def create(cls, *args, **kwargs):
|
||||
commons = {"supabase": get_supabase_client()}
|
||||
|
@ -25,6 +25,32 @@ class UserUsage(Repository):
|
||||
.execute()
|
||||
)
|
||||
|
||||
def get_user_settings(self, user_id):
|
||||
"""
|
||||
Fetch the user settings from the database
|
||||
"""
|
||||
response = (
|
||||
self.db.from_("user_settings")
|
||||
.select("*")
|
||||
.filter("user_id", "eq", str(user_id))
|
||||
.execute()
|
||||
).data
|
||||
|
||||
if len(response) == 0:
|
||||
# Create the user settings
|
||||
result = (
|
||||
self.db.table("user_settings")
|
||||
.insert({"user_id": str(user_id)})
|
||||
.execute()
|
||||
)
|
||||
if result:
|
||||
return self.get_user_settings(user_id)
|
||||
else:
|
||||
raise ValueError("User settings could not be created")
|
||||
if response and len(response) > 0:
|
||||
return response[0]
|
||||
return None
|
||||
|
||||
def get_user_usage(self, user_id):
|
||||
"""
|
||||
Fetch the user request stats from the database
|
||||
|
@ -6,7 +6,6 @@ from vectorstore.supabase import SupabaseVectorStore
|
||||
|
||||
|
||||
class BrainRateLimiting(BaseSettings):
|
||||
max_brain_size: int = 52428800
|
||||
max_brain_per_user: int = 5
|
||||
|
||||
|
||||
|
@ -24,6 +24,14 @@ class UserUsage(UserIdentity):
|
||||
|
||||
return request
|
||||
|
||||
def get_user_settings(self):
|
||||
"""
|
||||
Fetch the user settings from the database
|
||||
"""
|
||||
request = self.supabase_db.get_user_settings(self.id)
|
||||
|
||||
return request
|
||||
|
||||
def handle_increment_user_request_count(self, date):
|
||||
"""
|
||||
Increment the user request count in the database
|
||||
|
@ -3,17 +3,23 @@ from uuid import UUID
|
||||
from auth import AuthBearer, get_current_user
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from logger import get_logger
|
||||
from models import BrainRateLimiting, UserIdentity
|
||||
from models.databases.supabase.brains import (BrainQuestionRequest,
|
||||
BrainUpdatableProperties,
|
||||
CreateBrainProperties)
|
||||
from repository.brain import (create_brain, create_brain_user,
|
||||
get_brain_details,
|
||||
get_default_user_brain_or_create_new,
|
||||
get_question_context_from_brain, get_user_brains,
|
||||
get_user_default_brain,
|
||||
set_as_default_brain_for_user,
|
||||
update_brain_by_id)
|
||||
from models import UserIdentity, UserUsage
|
||||
from models.databases.supabase.brains import (
|
||||
BrainQuestionRequest,
|
||||
BrainUpdatableProperties,
|
||||
CreateBrainProperties,
|
||||
)
|
||||
from repository.brain import (
|
||||
create_brain,
|
||||
create_brain_user,
|
||||
get_brain_details,
|
||||
get_default_user_brain_or_create_new,
|
||||
get_question_context_from_brain,
|
||||
get_user_brains,
|
||||
get_user_default_brain,
|
||||
set_as_default_brain_for_user,
|
||||
update_brain_by_id,
|
||||
)
|
||||
from repository.prompt import delete_prompt_by_id, get_prompt_by_id
|
||||
from routes.authorizations.brain_authorization import has_brain_authorization
|
||||
from routes.authorizations.types import RoleEnum
|
||||
@ -105,12 +111,17 @@ async def create_brain_endpoint(
|
||||
"""
|
||||
|
||||
user_brains = get_user_brains(current_user.id)
|
||||
max_brain_per_user = BrainRateLimiting().max_brain_per_user
|
||||
userDailyUsage = UserUsage(
|
||||
id=current_user.id,
|
||||
email=current_user.email,
|
||||
openai_api_key=current_user.openai_api_key,
|
||||
)
|
||||
userSettings = userDailyUsage.get_user_settings()
|
||||
|
||||
if len(user_brains) >= max_brain_per_user:
|
||||
if len(user_brains) >= userSettings.get("max_brains", 5):
|
||||
raise HTTPException(
|
||||
status_code=429,
|
||||
detail=f"Maximum number of brains reached ({max_brain_per_user}).",
|
||||
detail=f"Maximum number of brains reached ({userSettings.get('max_brains', 5)}).",
|
||||
)
|
||||
|
||||
new_brain = create_brain(
|
||||
|
@ -1,4 +1,3 @@
|
||||
import os
|
||||
import time
|
||||
from typing import List
|
||||
from uuid import UUID
|
||||
@ -7,9 +6,6 @@ from venv import logger
|
||||
from auth import AuthBearer, get_current_user
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, Request
|
||||
from fastapi.responses import StreamingResponse
|
||||
from repository.notification.remove_chat_notifications import (
|
||||
remove_chat_notifications,
|
||||
)
|
||||
from llm.openai import OpenAIBrainPicking
|
||||
from llm.qa_headless import HeadlessQA
|
||||
from models import (
|
||||
@ -36,6 +32,7 @@ from repository.chat.get_chat_history_with_notifications import (
|
||||
ChatItem,
|
||||
get_chat_history_with_notifications,
|
||||
)
|
||||
from repository.notification.remove_chat_notifications import remove_chat_notifications
|
||||
from repository.user_identity import get_user_identity
|
||||
|
||||
chat_router = APIRouter()
|
||||
@ -76,11 +73,13 @@ def check_user_requests_limit(
|
||||
id=user.id, email=user.email, openai_api_key=user.openai_api_key
|
||||
)
|
||||
|
||||
userSettings = userDailyUsage.get_user_settings()
|
||||
|
||||
date = time.strftime("%Y%m%d")
|
||||
userDailyUsage.handle_increment_user_request_count(date)
|
||||
|
||||
if user.openai_api_key is None:
|
||||
max_requests_number = int(os.getenv("MAX_REQUESTS_NUMBER", 1))
|
||||
max_requests_number = userSettings.get("max_requests_number", 0)
|
||||
if int(userDailyUsage.daily_requests_count) >= int(max_requests_number):
|
||||
raise HTTPException(
|
||||
status_code=429, # pyright: ignore reportPrivateUsage=none
|
||||
|
@ -1,4 +1,3 @@
|
||||
import os
|
||||
import shutil
|
||||
from tempfile import SpooledTemporaryFile
|
||||
from typing import Optional
|
||||
@ -7,7 +6,7 @@ from uuid import UUID
|
||||
from auth import AuthBearer, get_current_user
|
||||
from crawl.crawler import CrawlWebsite
|
||||
from fastapi import APIRouter, Depends, Query, Request, UploadFile
|
||||
from models import Brain, File, UserIdentity
|
||||
from models import Brain, File, UserIdentity, UserUsage
|
||||
from models.databases.supabase.notifications import (
|
||||
CreateNotificationProperties,
|
||||
NotificationUpdatableProperties,
|
||||
@ -15,9 +14,7 @@ from models.databases.supabase.notifications import (
|
||||
from models.notifications import NotificationsStatusEnum
|
||||
from parsers.github import process_github
|
||||
from repository.notification.add_notification import add_notification
|
||||
from repository.notification.update_notification import (
|
||||
update_notification_by_id,
|
||||
)
|
||||
from repository.notification.update_notification import update_notification_by_id
|
||||
from utils.file import convert_bytes
|
||||
from utils.processors import filter_file
|
||||
|
||||
@ -45,12 +42,19 @@ async def crawl_endpoint(
|
||||
# [TODO] check if the user is the owner/editor of the brain
|
||||
brain = Brain(id=brain_id)
|
||||
|
||||
userDailyUsage = UserUsage(
|
||||
id=current_user.id,
|
||||
email=current_user.email,
|
||||
openai_api_key=current_user.openai_api_key,
|
||||
)
|
||||
userSettings = userDailyUsage.get_user_settings()
|
||||
|
||||
# [TODO] rate limiting of user for crawl
|
||||
if request.headers.get("Openai-Api-Key"):
|
||||
brain.max_brain_size = int(os.getenv("MAX_BRAIN_SIZE_WITH_KEY", 209715200))
|
||||
brain.max_brain_size = userSettings.get("max_brain_size", 1000000000)
|
||||
|
||||
file_size = 1000000
|
||||
remaining_free_space = brain.remaining_brain_size
|
||||
remaining_free_space = userSettings.get("max_brain_size", 1000000000)
|
||||
|
||||
if remaining_free_space - file_size < 0:
|
||||
message = {
|
||||
|
@ -1,10 +1,9 @@
|
||||
import os
|
||||
from typing import Optional
|
||||
from uuid import UUID
|
||||
|
||||
from auth import AuthBearer, get_current_user
|
||||
from fastapi import APIRouter, Depends, Query, Request, UploadFile
|
||||
from models import Brain, File, UserIdentity
|
||||
from models import Brain, File, UserIdentity, UserUsage
|
||||
from models.databases.supabase.notifications import (
|
||||
CreateNotificationProperties,
|
||||
NotificationUpdatableProperties,
|
||||
@ -12,17 +11,14 @@ from models.databases.supabase.notifications import (
|
||||
from models.notifications import NotificationsStatusEnum
|
||||
from repository.brain import get_brain_details
|
||||
from repository.notification.add_notification import add_notification
|
||||
from repository.notification.update_notification import (
|
||||
update_notification_by_id,
|
||||
)
|
||||
from repository.notification.update_notification import update_notification_by_id
|
||||
from repository.user_identity import get_user_identity
|
||||
from utils.file import convert_bytes, get_file_size
|
||||
from utils.processors import filter_file
|
||||
|
||||
from routes.authorizations.brain_authorization import (
|
||||
RoleEnum,
|
||||
validate_brain_authorization,
|
||||
)
|
||||
from utils.file import convert_bytes, get_file_size
|
||||
from utils.processors import filter_file
|
||||
|
||||
upload_router = APIRouter()
|
||||
|
||||
@ -58,11 +54,17 @@ async def upload_file(
|
||||
)
|
||||
|
||||
brain = Brain(id=brain_id)
|
||||
userDailyUsage = UserUsage(
|
||||
id=current_user.id,
|
||||
email=current_user.email,
|
||||
openai_api_key=current_user.openai_api_key,
|
||||
)
|
||||
userSettings = userDailyUsage.get_user_settings()
|
||||
|
||||
if request.headers.get("Openai-Api-Key"):
|
||||
brain.max_brain_size = int(os.getenv("MAX_BRAIN_SIZE_WITH_KEY", 209715200))
|
||||
brain.max_brain_size = userSettings.get("max_brain_size", 1000000000)
|
||||
|
||||
remaining_free_space = brain.remaining_brain_size
|
||||
remaining_free_space = userSettings.get("max_brain_size", 1000000000)
|
||||
|
||||
file_size = get_file_size(uploadFile)
|
||||
|
||||
|
@ -1,9 +1,8 @@
|
||||
import os
|
||||
import time
|
||||
|
||||
from auth import AuthBearer, get_current_user
|
||||
from fastapi import APIRouter, Depends, Request
|
||||
from models import Brain, BrainRateLimiting, UserIdentity, UserUsage
|
||||
from models import Brain, UserIdentity, UserUsage
|
||||
from repository.brain import get_user_default_brain
|
||||
from repository.user_identity.get_user_identity import get_user_identity
|
||||
from repository.user_identity.update_user_properties import (
|
||||
@ -13,8 +12,6 @@ from repository.user_identity.update_user_properties import (
|
||||
|
||||
user_router = APIRouter()
|
||||
|
||||
MAX_BRAIN_SIZE_WITH_OWN_KEY = int(os.getenv("MAX_BRAIN_SIZE_WITH_KEY", 209715200))
|
||||
|
||||
|
||||
@user_router.get("/user", dependencies=[Depends(AuthBearer())], tags=["User"])
|
||||
async def get_user_endpoint(
|
||||
@ -32,13 +29,16 @@ async def get_user_endpoint(
|
||||
information about the user's API usage.
|
||||
"""
|
||||
|
||||
max_brain_size = BrainRateLimiting().max_brain_size
|
||||
|
||||
if request.headers.get("Openai-Api-Key"):
|
||||
max_brain_size = MAX_BRAIN_SIZE_WITH_OWN_KEY
|
||||
userDailyUsage = UserUsage(
|
||||
id=current_user.id,
|
||||
email=current_user.email,
|
||||
openai_api_key=current_user.openai_api_key,
|
||||
)
|
||||
userSettings = userDailyUsage.get_user_settings()
|
||||
max_brain_size = userSettings.get("max_brain_size", 1000000000)
|
||||
|
||||
date = time.strftime("%Y%m%d")
|
||||
max_requests_number = os.getenv("MAX_REQUESTS_NUMBER")
|
||||
max_requests_number = userSettings.get("max_requests_number", 10)
|
||||
|
||||
userDailyUsage = UserUsage(id=current_user.id)
|
||||
requests_stats = userDailyUsage.get_user_usage()
|
||||
@ -55,6 +55,7 @@ async def get_user_endpoint(
|
||||
"current_brain_size": defaul_brain_size,
|
||||
"max_requests_number": max_requests_number,
|
||||
"requests_stats": requests_stats,
|
||||
"models": userSettings.get("models", []),
|
||||
"date": date,
|
||||
"id": current_user.id,
|
||||
}
|
||||
|
@ -32,7 +32,7 @@ services:
|
||||
context: backend
|
||||
dockerfile: Dockerfile
|
||||
container_name: backend-core
|
||||
command: uvicorn main:app --host 0.0.0.0 --port 5050
|
||||
command: uvicorn main:app --reload --host 0.0.0.0 --port 5050
|
||||
restart: always
|
||||
volumes:
|
||||
- ./backend/:/code/
|
||||
@ -49,7 +49,7 @@ services:
|
||||
context: backend
|
||||
dockerfile: Dockerfile
|
||||
container_name: backend-chat
|
||||
command: uvicorn chat_service:app --host 0.0.0.0 --port 5050
|
||||
command: uvicorn chat_service:app --reload --host 0.0.0.0 --port 5050
|
||||
restart: always
|
||||
volumes:
|
||||
- ./backend/:/code/
|
||||
@ -66,7 +66,7 @@ services:
|
||||
context: backend
|
||||
dockerfile: Dockerfile
|
||||
container_name: backend-crawl
|
||||
command: uvicorn crawl_service:app --host 0.0.0.0 --port 5050
|
||||
command: uvicorn crawl_service:app --reload --host 0.0.0.0 --port 5050
|
||||
restart: always
|
||||
volumes:
|
||||
- ./backend/:/code/
|
||||
@ -83,7 +83,7 @@ services:
|
||||
context: backend
|
||||
dockerfile: Dockerfile
|
||||
container_name: backend-upload
|
||||
command: uvicorn upload_service:app --host 0.0.0.0 --port 5050
|
||||
command: uvicorn upload_service:app --reload --host 0.0.0.0 --port 5050
|
||||
restart: always
|
||||
volumes:
|
||||
- ./backend/:/code/
|
||||
|
@ -9,7 +9,6 @@ import { Divider } from "@/lib/components/ui/Divider";
|
||||
import Field from "@/lib/components/ui/Field";
|
||||
import { TextArea } from "@/lib/components/ui/TextArea";
|
||||
import { defineMaxTokens } from "@/lib/helpers/defineMaxTokens";
|
||||
import { freeModels, paidModels } from "@/lib/types/brainConfig";
|
||||
import { SaveButton } from "@/shared/SaveButton";
|
||||
|
||||
import { PublicPrompts } from "./components/PublicPrompts/PublicPrompts";
|
||||
@ -24,7 +23,7 @@ export const SettingsTab = ({ brainId }: SettingsTabProps): JSX.Element => {
|
||||
const {
|
||||
handleSubmit,
|
||||
register,
|
||||
openAiKey,
|
||||
|
||||
temperature,
|
||||
maxTokens,
|
||||
model,
|
||||
@ -36,6 +35,7 @@ export const SettingsTab = ({ brainId }: SettingsTabProps): JSX.Element => {
|
||||
promptId,
|
||||
pickPublicPrompt,
|
||||
removeBrainPrompt,
|
||||
accessibleModels,
|
||||
} = useSettingsTab({ brainId });
|
||||
|
||||
return (
|
||||
@ -102,13 +102,11 @@ export const SettingsTab = ({ brainId }: SettingsTabProps): JSX.Element => {
|
||||
void handleSubmit(false); // Trigger form submission
|
||||
}}
|
||||
>
|
||||
{(openAiKey !== undefined ? paidModels : freeModels).map(
|
||||
(availableModel) => (
|
||||
<option value={availableModel} key={availableModel}>
|
||||
{availableModel}
|
||||
</option>
|
||||
)
|
||||
)}
|
||||
{accessibleModels.map((availableModel) => (
|
||||
<option value={availableModel} key={availableModel}>
|
||||
{availableModel}
|
||||
</option>
|
||||
))}
|
||||
</select>
|
||||
</fieldset>
|
||||
<fieldset className="w-full flex mt-4">
|
||||
|
@ -10,10 +10,13 @@ import { useTranslation } from "react-i18next";
|
||||
import { getBrainDataKey } from "@/lib/api/brain/config";
|
||||
import { useBrainApi } from "@/lib/api/brain/useBrainApi";
|
||||
import { usePromptApi } from "@/lib/api/prompt/usePromptApi";
|
||||
import { USER_DATA_KEY } from "@/lib/api/user/config";
|
||||
import { useUserApi } from "@/lib/api/user/useUserApi";
|
||||
import { defaultBrainConfig } from "@/lib/config/defaultBrainConfig";
|
||||
import { useBrainContext } from "@/lib/context/BrainProvider/hooks/useBrainContext";
|
||||
import { Brain } from "@/lib/context/BrainProvider/types";
|
||||
import { defineMaxTokens } from "@/lib/helpers/defineMaxTokens";
|
||||
import { getAccessibleModels } from "@/lib/helpers/getAccessibleModels";
|
||||
import { useToast } from "@/lib/hooks";
|
||||
|
||||
import { validateOpenAIKey } from "../utils/validateOpenAIKey";
|
||||
@ -33,6 +36,12 @@ export const useSettingsTab = ({ brainId }: UseSettingsTabProps) => {
|
||||
const { fetchAllBrains, fetchDefaultBrain, defaultBrainId } =
|
||||
useBrainContext();
|
||||
const { getPrompt, updatePrompt, createPrompt } = usePromptApi();
|
||||
const { getUser } = useUserApi();
|
||||
|
||||
const { data: userData } = useQuery({
|
||||
queryKey: [USER_DATA_KEY],
|
||||
queryFn: getUser,
|
||||
});
|
||||
|
||||
const defaultValues = {
|
||||
...defaultBrainConfig,
|
||||
@ -69,6 +78,11 @@ export const useSettingsTab = ({ brainId }: UseSettingsTabProps) => {
|
||||
const temperature = watch("temperature");
|
||||
const maxTokens = watch("maxTokens");
|
||||
|
||||
const accessibleModels = getAccessibleModels({
|
||||
openAiKey,
|
||||
userData,
|
||||
});
|
||||
|
||||
const updateFormValues = useCallback(() => {
|
||||
if (brain === undefined) {
|
||||
return;
|
||||
@ -336,7 +350,7 @@ export const useSettingsTab = ({ brainId }: UseSettingsTabProps) => {
|
||||
return {
|
||||
handleSubmit,
|
||||
register,
|
||||
openAiKey,
|
||||
|
||||
model,
|
||||
temperature,
|
||||
maxTokens,
|
||||
@ -348,5 +362,6 @@ export const useSettingsTab = ({ brainId }: UseSettingsTabProps) => {
|
||||
promptId,
|
||||
removeBrainPrompt,
|
||||
pickPublicPrompt,
|
||||
accessibleModels,
|
||||
};
|
||||
};
|
||||
|
@ -39,6 +39,18 @@ vi.mock("@/lib/api/chat/useChatApi", () => ({
|
||||
getHistory: () => [],
|
||||
}),
|
||||
}));
|
||||
vi.mock("@tanstack/react-query", async () => {
|
||||
const actual = await vi.importActual<typeof import("@tanstack/react-query")>(
|
||||
"@tanstack/react-query"
|
||||
);
|
||||
|
||||
return {
|
||||
...actual,
|
||||
useQuery: () => ({
|
||||
data: {},
|
||||
}),
|
||||
};
|
||||
});
|
||||
|
||||
describe("Chat page", () => {
|
||||
it("should render chat page correctly", () => {
|
||||
|
@ -4,7 +4,6 @@ import { MdCheck, MdSettings } from "react-icons/md";
|
||||
import Button from "@/lib/components/ui/Button";
|
||||
import { Modal } from "@/lib/components/ui/Modal";
|
||||
import { defineMaxTokens } from "@/lib/helpers/defineMaxTokens";
|
||||
import { freeModels } from "@/lib/types/brainConfig";
|
||||
|
||||
import { useConfigModal } from "./hooks/useConfigModal";
|
||||
|
||||
@ -17,6 +16,7 @@ export const ConfigModal = ({ chatId }: { chatId?: string }): JSX.Element => {
|
||||
temperature,
|
||||
maxTokens,
|
||||
model,
|
||||
accessibleModels,
|
||||
} = useConfigModal(chatId);
|
||||
|
||||
if (chatId === undefined) {
|
||||
@ -56,7 +56,7 @@ export const ConfigModal = ({ chatId }: { chatId?: string }): JSX.Element => {
|
||||
{...register("model")}
|
||||
className="px-5 py-2 dark:bg-gray-700 bg-gray-200 rounded-md"
|
||||
>
|
||||
{freeModels.map((availableModel) => (
|
||||
{accessibleModels.map((availableModel) => (
|
||||
<option value={availableModel} key={availableModel}>
|
||||
{availableModel}
|
||||
</option>
|
||||
|
@ -1,4 +1,5 @@
|
||||
/* eslint-disable max-lines */
|
||||
import { useQuery } from "@tanstack/react-query";
|
||||
import { FormEvent, useEffect, useState } from "react";
|
||||
import { useForm } from "react-hook-form";
|
||||
|
||||
@ -7,10 +8,13 @@ import {
|
||||
getChatConfigFromLocalStorage,
|
||||
saveChatConfigInLocalStorage,
|
||||
} from "@/lib/api/chat/chat.local";
|
||||
import { USER_DATA_KEY, USER_IDENTITY_DATA_KEY } from "@/lib/api/user/config";
|
||||
import { useUserApi } from "@/lib/api/user/useUserApi";
|
||||
import { defaultBrainConfig } from "@/lib/config/defaultBrainConfig";
|
||||
import { useBrainContext } from "@/lib/context/BrainProvider/hooks/useBrainContext";
|
||||
import { ChatConfig } from "@/lib/context/ChatProvider/types";
|
||||
import { defineMaxTokens } from "@/lib/helpers/defineMaxTokens";
|
||||
import { getAccessibleModels } from "@/lib/helpers/getAccessibleModels";
|
||||
import { useToast } from "@/lib/hooks";
|
||||
|
||||
// eslint-disable-next-line @typescript-eslint/explicit-module-boundary-types
|
||||
@ -19,6 +23,16 @@ export const useConfigModal = (chatId?: string) => {
|
||||
const [isConfigModalOpen, setIsConfigModalOpen] = useState(false);
|
||||
const { getBrain } = useBrainApi();
|
||||
const { currentBrain } = useBrainContext();
|
||||
const { getUser, getUserIdentity } = useUserApi();
|
||||
|
||||
const { data: userData } = useQuery({
|
||||
queryKey: [USER_DATA_KEY],
|
||||
queryFn: getUser,
|
||||
});
|
||||
const { data: userIdentityData } = useQuery({
|
||||
queryKey: [USER_IDENTITY_DATA_KEY],
|
||||
queryFn: getUserIdentity,
|
||||
});
|
||||
|
||||
const defaultValues: ChatConfig = {};
|
||||
|
||||
@ -30,6 +44,11 @@ export const useConfigModal = (chatId?: string) => {
|
||||
const temperature = watch("temperature");
|
||||
const maxTokens = watch("maxTokens");
|
||||
|
||||
const accessibleModels = getAccessibleModels({
|
||||
openAiKey: userIdentityData?.openai_api_key,
|
||||
userData,
|
||||
});
|
||||
|
||||
useEffect(() => {
|
||||
const fetchChatConfig = async () => {
|
||||
if (chatId === undefined) {
|
||||
@ -104,5 +123,6 @@ export const useConfigModal = (chatId?: string) => {
|
||||
model,
|
||||
temperature,
|
||||
maxTokens,
|
||||
accessibleModels,
|
||||
};
|
||||
};
|
||||
|
@ -1,3 +1,4 @@
|
||||
/* eslint-disable max-lines */
|
||||
import { act, renderHook } from "@testing-library/react";
|
||||
import { afterEach, describe, expect, it, vi } from "vitest";
|
||||
|
||||
@ -46,6 +47,21 @@ vi.mock("@/lib/hooks", async () => {
|
||||
};
|
||||
});
|
||||
|
||||
vi.mock("@tanstack/react-query", async () => {
|
||||
const actual = await vi.importActual<typeof import("@tanstack/react-query")>(
|
||||
"@tanstack/react-query"
|
||||
);
|
||||
|
||||
return {
|
||||
...actual,
|
||||
useQuery: () => ({
|
||||
data: {},
|
||||
}),
|
||||
useQueryClient: () => ({
|
||||
invalidateQueries: vi.fn(),
|
||||
}),
|
||||
};
|
||||
});
|
||||
describe("useApiKeyConfig", () => {
|
||||
afterEach(() => {
|
||||
vi.restoreAllMocks();
|
||||
|
@ -1,9 +1,11 @@
|
||||
/* eslint-disable max-lines */
|
||||
import { useQuery, useQueryClient } from "@tanstack/react-query";
|
||||
import { useEffect, useState } from "react";
|
||||
import { useTranslation } from "react-i18next";
|
||||
|
||||
import { validateOpenAIKey } from "@/app/brains-management/[brainId]/components/BrainManagementTabs/components/SettingsTab/utils/validateOpenAIKey";
|
||||
import { useAuthApi } from "@/lib/api/auth/useAuthApi";
|
||||
import { USER_IDENTITY_DATA_KEY } from "@/lib/api/user/config";
|
||||
import { useUserApi } from "@/lib/api/user/useUserApi";
|
||||
import { UserIdentity } from "@/lib/api/user/user";
|
||||
import { useToast } from "@/lib/hooks";
|
||||
@ -23,13 +25,17 @@ export const useApiKeyConfig = () => {
|
||||
const { publish } = useToast();
|
||||
const [userIdentity, setUserIdentity] = useState<UserIdentity>();
|
||||
const { t } = useTranslation(["config"]);
|
||||
const queryClient = useQueryClient();
|
||||
const { data: userData } = useQuery({
|
||||
queryKey: [USER_IDENTITY_DATA_KEY],
|
||||
queryFn: getUserIdentity,
|
||||
});
|
||||
|
||||
const fetchUserIdentity = async () => {
|
||||
setUserIdentity(await getUserIdentity());
|
||||
};
|
||||
useEffect(() => {
|
||||
void fetchUserIdentity();
|
||||
}, []);
|
||||
if (userData !== undefined) {
|
||||
setUserIdentity(userData);
|
||||
}
|
||||
}, [userData]);
|
||||
|
||||
const handleCreateClick = async () => {
|
||||
try {
|
||||
@ -80,7 +86,10 @@ export const useApiKeyConfig = () => {
|
||||
await updateUserIdentity({
|
||||
openai_api_key: openAiApiKey,
|
||||
});
|
||||
void fetchUserIdentity();
|
||||
void queryClient.invalidateQueries({
|
||||
queryKey: [USER_IDENTITY_DATA_KEY],
|
||||
});
|
||||
|
||||
publish({
|
||||
variant: "success",
|
||||
text: "OpenAI API Key updated",
|
||||
@ -104,7 +113,9 @@ export const useApiKeyConfig = () => {
|
||||
text: "OpenAI API Key removed",
|
||||
});
|
||||
|
||||
void fetchUserIdentity();
|
||||
void queryClient.invalidateQueries({
|
||||
queryKey: [USER_IDENTITY_DATA_KEY],
|
||||
});
|
||||
} catch (error) {
|
||||
console.error(error);
|
||||
} finally {
|
||||
|
@ -4,46 +4,35 @@ import { useEffect, useState } from "react";
|
||||
import { useTranslation } from "react-i18next";
|
||||
|
||||
import Spinner from "@/lib/components/ui/Spinner";
|
||||
import { useAxios } from "@/lib/hooks";
|
||||
import { UserStats } from "@/lib/types/User";
|
||||
|
||||
import { USER_DATA_KEY } from "@/lib/api/user/config";
|
||||
import { useUserApi } from "@/lib/api/user/useUserApi";
|
||||
import { useSupabase } from "@/lib/context/SupabaseProvider";
|
||||
import { redirectToLogin } from "@/lib/router/redirectToLogin";
|
||||
import { useQuery } from "@tanstack/react-query";
|
||||
import { UserStatistics } from "./components/UserStatistics";
|
||||
|
||||
const UserPage = (): JSX.Element => {
|
||||
const [userStats, setUserStats] = useState<UserStats>();
|
||||
const { session } = useSupabase();
|
||||
const { axiosInstance } = useAxios();
|
||||
const { t } = useTranslation(["translation","user"]);
|
||||
const { t } = useTranslation(["translation", "user"]);
|
||||
const { getUser } = useUserApi();
|
||||
|
||||
const { data: userData } = useQuery({
|
||||
queryKey: [USER_DATA_KEY],
|
||||
queryFn: getUser,
|
||||
});
|
||||
|
||||
useEffect(() => {
|
||||
if (userData !== undefined) {
|
||||
setUserStats(userData);
|
||||
}
|
||||
}, [userData]);
|
||||
if (session === null) {
|
||||
redirectToLogin();
|
||||
}
|
||||
|
||||
useEffect(() => {
|
||||
const fetchUserStats = async () => {
|
||||
try {
|
||||
console.log(
|
||||
`Fetching user stats from ${process.env.NEXT_PUBLIC_BACKEND_URL}/user`
|
||||
);
|
||||
const response = await axiosInstance.get<UserStats>(
|
||||
`${process.env.NEXT_PUBLIC_BACKEND_URL}/user`,
|
||||
{
|
||||
headers: {
|
||||
Authorization: `Bearer ${session.access_token}`,
|
||||
},
|
||||
}
|
||||
);
|
||||
setUserStats(response.data);
|
||||
} catch (error) {
|
||||
console.error("Error fetching user stats", error);
|
||||
setUserStats(undefined);
|
||||
}
|
||||
};
|
||||
fetchUserStats();
|
||||
}, [session.access_token]);
|
||||
|
||||
return (
|
||||
<main className="w-full flex flex-col pt-10">
|
||||
<section className="flex flex-col justify-center items-center flex-1 gap-5 h-full">
|
||||
@ -54,7 +43,7 @@ const UserPage = (): JSX.Element => {
|
||||
</>
|
||||
) : (
|
||||
<div className="flex items-center justify-center">
|
||||
<span>{t("fetching", {ns: "user"})}</span>
|
||||
<span>{t("fetching", { ns: "user" })}</span>
|
||||
<Spinner />
|
||||
</div>
|
||||
)}
|
||||
|
@ -1,5 +1,5 @@
|
||||
import { renderHook } from "@testing-library/react";
|
||||
import { describe, expect, it, vi } from "vitest";
|
||||
import { afterEach, describe, expect, it, vi } from "vitest";
|
||||
|
||||
import { useUserApi } from "../useUserApi";
|
||||
import { UserIdentityUpdatableProperties } from "../user";
|
||||
@ -17,6 +17,9 @@ vi.mock("@/lib/hooks", () => ({
|
||||
}));
|
||||
|
||||
describe("useUserApi", () => {
|
||||
afterEach(() => {
|
||||
vi.restoreAllMocks();
|
||||
});
|
||||
it("should call updateUserIdentity with the correct parameters", async () => {
|
||||
const {
|
||||
result: {
|
||||
@ -45,4 +48,15 @@ describe("useUserApi", () => {
|
||||
expect(axiosGetMock).toHaveBeenCalledTimes(1);
|
||||
expect(axiosGetMock).toHaveBeenCalledWith(`/user/identity`);
|
||||
});
|
||||
it("should call getUser with the correct parameters", async () => {
|
||||
const {
|
||||
result: {
|
||||
current: { getUser },
|
||||
},
|
||||
} = renderHook(() => useUserApi());
|
||||
await getUser();
|
||||
|
||||
expect(axiosGetMock).toHaveBeenCalledTimes(1);
|
||||
expect(axiosGetMock).toHaveBeenCalledWith(`/user`);
|
||||
});
|
||||
});
|
||||
|
2
frontend/lib/api/user/config.ts
Normal file
2
frontend/lib/api/user/config.ts
Normal file
@ -0,0 +1,2 @@
|
||||
export const USER_IDENTITY_DATA_KEY = "user-identity-data";
|
||||
export const USER_DATA_KEY = "user-data";
|
@ -1,6 +1,7 @@
|
||||
import { useAxios } from "@/lib/hooks";
|
||||
|
||||
import {
|
||||
getUser,
|
||||
getUserIdentity,
|
||||
updateUserIdentity,
|
||||
UserIdentityUpdatableProperties,
|
||||
@ -15,5 +16,6 @@ export const useUserApi = () => {
|
||||
userIdentityUpdatableProperties: UserIdentityUpdatableProperties
|
||||
) => updateUserIdentity(userIdentityUpdatableProperties, axiosInstance),
|
||||
getUserIdentity: async () => getUserIdentity(axiosInstance),
|
||||
getUser: async () => getUser(axiosInstance),
|
||||
};
|
||||
};
|
||||
|
@ -1,6 +1,8 @@
|
||||
import { AxiosInstance } from "axios";
|
||||
import { UUID } from "crypto";
|
||||
|
||||
import { UserStats } from "@/lib/types/User";
|
||||
|
||||
export type UserIdentityUpdatableProperties = {
|
||||
openai_api_key?: string | null;
|
||||
};
|
||||
@ -23,3 +25,7 @@ export const getUserIdentity = async (
|
||||
|
||||
return data;
|
||||
};
|
||||
|
||||
export const getUser = async (
|
||||
axiosInstance: AxiosInstance
|
||||
): Promise<UserStats> => (await axiosInstance.get<UserStats>("/user")).data;
|
||||
|
20
frontend/lib/helpers/getAccessibleModels.ts
Normal file
20
frontend/lib/helpers/getAccessibleModels.ts
Normal file
@ -0,0 +1,20 @@
|
||||
import { UserStats } from "@/lib/types/User";
|
||||
import { freeModels, paidModels } from "@/lib/types/brainConfig";
|
||||
|
||||
type GetAccessibleModelsInput = {
|
||||
openAiKey?: string | null;
|
||||
userData?: UserStats;
|
||||
};
|
||||
export const getAccessibleModels = ({
|
||||
openAiKey,
|
||||
userData,
|
||||
}: GetAccessibleModelsInput): string[] => {
|
||||
if (userData?.models !== undefined) {
|
||||
return userData.models;
|
||||
}
|
||||
if (openAiKey !== undefined && openAiKey !== null) {
|
||||
return paidModels as unknown as string[];
|
||||
}
|
||||
|
||||
return freeModels as unknown as string[];
|
||||
};
|
@ -11,4 +11,5 @@ export interface UserStats {
|
||||
max_requests_number: number;
|
||||
requests_stats: RequestStat[];
|
||||
date: string;
|
||||
models: string[];
|
||||
}
|
||||
|
19
scripts/202309127004032_add_user_limits.sql
Normal file
19
scripts/202309127004032_add_user_limits.sql
Normal file
@ -0,0 +1,19 @@
|
||||
-- Assuming you have a table named "prompts" with columns: "title", "content", "status"
|
||||
|
||||
CREATE TABLE IF NOT EXISTS user_settings (
|
||||
user_id UUID PRIMARY KEY,
|
||||
models JSONB DEFAULT '["gpt-3.5-turbo"]'::jsonb,
|
||||
max_requests_number INT DEFAULT 50,
|
||||
max_brains INT DEFAULT 5,
|
||||
max_brain_size INT DEFAULT 1000000
|
||||
);
|
||||
|
||||
|
||||
-- Update migrations table
|
||||
INSERT INTO migrations (name)
|
||||
SELECT '202309127004032_add_user_limits'
|
||||
WHERE NOT EXISTS (
|
||||
SELECT 1 FROM migrations WHERE name = '202309127004032_add_user_limits'
|
||||
);
|
||||
|
||||
COMMIT;
|
@ -226,9 +226,16 @@ CREATE TABLE IF NOT EXISTS migrations (
|
||||
executed_at TIMESTAMPTZ DEFAULT current_timestamp
|
||||
);
|
||||
|
||||
INSERT INTO migrations (name)
|
||||
SELECT '20230906151400_add_notifications_table'
|
||||
WHERE NOT EXISTS (
|
||||
SELECT 1 FROM migrations WHERE name = '20230906151400_add_notifications_table'
|
||||
CREATE TABLE IF NOT EXISTS user_settings (
|
||||
user_id UUID PRIMARY KEY,
|
||||
models JSONB DEFAULT '["gpt-3.5-turbo"]'::jsonb,
|
||||
max_requests_number INT DEFAULT 50,
|
||||
max_brains INT DEFAULT 5,
|
||||
max_brain_size INT DEFAULT 1000000
|
||||
);
|
||||
|
||||
INSERT INTO migrations (name)
|
||||
SELECT '202309127004032_add_user_limits'
|
||||
WHERE NOT EXISTS (
|
||||
SELECT 1 FROM migrations WHERE name = '202309127004032_add_user_limits'
|
||||
);
|
||||
|
Loading…
Reference in New Issue
Block a user