Fix/requests limiting (#980)

* 🗃️ Rename users table into user_daily_usage

* 💥 replace User model with UserIdentity model

* 🗃️ New UserDailyUsage class for database interaction

* 🐛 fix daily requests rate limiting per user

* 🐛 fix user stats and properties update

* ✏️ add typing and linting

* 🚚 rename user_dialy_usage Class  into user_usage & requests_count into daily_requests_count

* 🚑 fix some rebase errors
This commit is contained in:
Zineb El Bachiri 2023-08-21 14:05:13 +02:00 committed by GitHub
parent f61b70a34f
commit 9aaedcff51
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
36 changed files with 438 additions and 376 deletions

View File

@ -3,7 +3,7 @@ from uuid import UUID
from fastapi import HTTPException
from models.settings import get_supabase_db
from models.users import User
from models.user_identity import UserIdentity
from pydantic import DateError
@ -33,7 +33,7 @@ async def verify_api_key(
async def get_user_from_api_key(
api_key: str,
) -> User:
) -> UserIdentity:
supabase_db = get_supabase_db()
# Lookup the user_id from the api_keys table
@ -45,7 +45,6 @@ async def get_user_from_api_key(
user_id = user_id_data.data[0]["user_id"]
# Lookup the email from the users table. Todo: remove and use user_id for credentials
user_email_data = supabase_db.get_user_email(user_id)
email = user_email_data.data[0]["email"] if user_email_data.data else None
email = supabase_db.get_user_email(user_id)
return User(email=email, id=user_id)
return UserIdentity(email=email, id=user_id)

View File

@ -5,7 +5,7 @@ from auth.api_key_handler import get_user_from_api_key, verify_api_key
from auth.jwt_token_handler import decode_access_token, verify_token
from fastapi import Depends, HTTPException, Request
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from models import User
from models import UserIdentity
class AuthBearer(HTTPBearer):
@ -36,7 +36,7 @@ class AuthBearer(HTTPBearer):
async def authenticate(
self,
token: str,
) -> User:
) -> UserIdentity:
if os.environ.get("AUTHENTICATE") == "false":
return self.get_test_user()
elif verify_token(token):
@ -50,11 +50,11 @@ class AuthBearer(HTTPBearer):
else:
raise HTTPException(status_code=401, detail="Invalid token or api key.")
def get_test_user(self) -> User:
return User(
def get_test_user(self) -> UserIdentity:
return UserIdentity(
email="test@example.com", id="XXXXXXXX-XXXX-XXXX-XXXX-XXXXXXXXXXXX" # type: ignore
) # replace with test user information
def get_current_user(user: User = Depends(AuthBearer())) -> User:
def get_current_user(user: UserIdentity = Depends(AuthBearer())) -> UserIdentity:
return user

View File

@ -4,7 +4,7 @@ from typing import Optional
from jose import jwt
from jose.exceptions import JWTError
from models import User
from models import UserIdentity
SECRET_KEY = os.environ.get("JWT_SECRET_KEY")
ALGORITHM = "HS256"
@ -24,7 +24,7 @@ def create_access_token(data: dict, expires_delta: Optional[timedelta] = None):
return encoded_jwt
def decode_access_token(token: str) -> User:
def decode_access_token(token: str) -> UserIdentity:
try:
payload = jwt.decode(
token, SECRET_KEY, algorithms=[ALGORITHM], options={"verify_aud": False}
@ -32,7 +32,7 @@ def decode_access_token(token: str) -> User:
except JWTError:
return None # pyright: ignore reportPrivateUsage=none
return User(
return UserIdentity(
email=payload.get("email"),
id=payload.get("sub"), # pyright: ignore reportPrivateUsage=none
)

View File

@ -44,6 +44,7 @@ class BaseBrainPicking(BaseModel):
def _determine_streaming(self, model: str, streaming: bool) -> bool:
"""If the model name allows for streaming and streaming is declared, set streaming to True."""
return streaming
def _determine_callback_array(
self, streaming
) -> List[AsyncIteratorCallbackHandler]: # pyright: ignore reportPrivateUsage=none
@ -83,7 +84,7 @@ class BaseBrainPicking(BaseModel):
This function should also call: _create_qa, get_chat_history and format_chat_history.
It should also update the chat_history in the DB.
"""
@abstractmethod
async def generate_stream(self, question: str) -> AsyncIterable:
"""

View File

@ -1,21 +1,15 @@
from .files import File
from .users import User
from .brains import Brain
from .chat import Chat, ChatHistory
from .user_identity import UserIdentity
from .prompt import Prompt, PromptStatusEnum
from .chats import ChatQuestion, ChatMessage
from .brain_entity import BrainEntity, MinimalBrainEntity
from .brains import Brain
from .brains_subscription_invitations import BrainSubscription
from .settings import (
BrainRateLimiting,
BrainSettings,
LLMSettings,
get_supabase_db,
get_supabase_client,
get_embeddings,
get_documents_vector_store
)
from .chat import Chat, ChatHistory
from .chats import ChatMessage, ChatQuestion
from .files import File
from .prompt import Prompt, PromptStatusEnum
from .settings import (BrainRateLimiting, BrainSettings, LLMSettings,
get_documents_vector_store, get_embeddings,
get_supabase_client, get_supabase_db)
from .user_identity import UserIdentity
from .user_usage import UserUsage
# TODO uncomment the below import when start using SQLalchemy
# from .sqlalchemy_repository import (

View File

@ -65,19 +65,25 @@ class Repository(ABC):
pass
@abstractmethod
def create_user(self, user_id: UUID, user_email: str, date: datetime):
def create_user_daily_usage(self, user_id: UUID, user_email: str, date: datetime):
pass
@abstractmethod
def get_user_request_stats(self, user_id: UUID):
def get_user_usage(self, user_id: UUID):
pass
@abstractmethod
def fetch_user_requests_count(self, user_id: UUID, date: str):
def get_user_requests_count_for_day(self, user_id: UUID, date: datetime):
pass
@abstractmethod
def update_user_request_count(self, date: str):
def update_user_request_count(self, user_id: UUID, date: str):
pass
@abstractmethod
def increment_user_request_count(
self, user_id: UUID, date: str, current_request_count
):
pass
@abstractmethod
@ -128,10 +134,6 @@ class Repository(ABC):
def get_user_id_by_api_key(self, api_key: UUID):
pass
@abstractmethod
def get_user_stats(self, user_email: str, date: datetime):
pass
@abstractmethod
def create_chat(self, new_chat):
pass

View File

@ -1,8 +1,9 @@
from models.databases.supabase.api_key_handler import ApiKeyHandler
from models.databases.supabase.brains import Brain
from models.databases.supabase.brains_subscription_invitations import BrainSubscription
from models.databases.supabase.brains_subscription_invitations import \
BrainSubscription
from models.databases.supabase.chats import Chats
from models.databases.supabase.files import File
from models.databases.supabase.prompts import Prompts
from models.databases.supabase.users import User
from models.databases.supabase.vectors import Vector
from models.databases.supabase.user_usage import UserUsage
from models.databases.supabase.vectors import Vector

View File

@ -1,5 +1,7 @@
from models.databases.repository import Repository
from datetime import datetime
from uuid import UUID
from models.databases.repository import Repository
class ApiKeyHandler(Repository):
@ -26,7 +28,7 @@ class ApiKeyHandler(Repository):
)
return response
def delete_api_key(self, key_id, user_id):
def delete_api_key(self, key_id: str, user_id: UUID):
return (
self.db.table("api_keys")
.update(
@ -39,7 +41,7 @@ class ApiKeyHandler(Repository):
.execute()
)
def get_active_api_key(self, api_key):
def get_active_api_key(self, api_key: str):
response = (
self.db.table("api_keys")
.select("api_key", "creation_time")
@ -49,7 +51,7 @@ class ApiKeyHandler(Repository):
)
return response
def get_user_id_by_api_key(self, api_key):
def get_user_id_by_api_key(self, api_key: str):
response = (
self.db.table("api_keys")
.select("user_id")
@ -58,7 +60,7 @@ class ApiKeyHandler(Repository):
)
return response
def get_user_api_keys(self, user_id):
def get_user_api_keys(self, user_id: UUID):
response = (
self.db.table("api_keys")
.select("key_id, creation_time")

View File

@ -6,7 +6,7 @@ from models.databases.supabase import (
Chats,
File,
Prompts,
User,
UserUsage,
Vector,
)
@ -14,12 +14,19 @@ logger = get_logger(__name__)
class SupabaseDB(
Brain, User, File, BrainSubscription, ApiKeyHandler, Chats, Vector, Prompts
Brain,
UserUsage,
File,
BrainSubscription,
ApiKeyHandler,
Chats,
Vector,
Prompts,
):
def __init__(self, supabase_client):
self.db = supabase_client
Brain.__init__(self, supabase_client)
User.__init__(self, supabase_client)
UserUsage.__init__(self, supabase_client)
File.__init__(self, supabase_client)
BrainSubscription.__init__(self, supabase_client)
ApiKeyHandler.__init__(self, supabase_client)

View File

@ -0,0 +1,88 @@
from datetime import datetime
from uuid import UUID
from logger import get_logger
from models.databases.repository import Repository
logger = get_logger(__name__)
class UserUsage(Repository):
def __init__(self, supabase_client):
self.db = supabase_client
def create_user_daily_usage(self, user_id: UUID, user_email: str, date: datetime):
return (
self.db.table("user_daily_usage")
.insert(
{
"user_id": str(user_id),
"email": user_email,
"date": date,
"daily_requests_count": 1,
}
)
.execute()
)
def get_user_usage(self, user_id):
"""
Fetch the user request stats from the database
"""
requests_stats = (
self.db.from_("user_daily_usage")
.select("*")
.filter("user_id", "eq", user_id)
.execute()
)
return requests_stats.data
def get_user_requests_count_for_day(self, user_id, date):
"""
Fetch the user request count from the database
"""
response = (
self.db.from_("user_daily_usage")
.select("daily_requests_count")
.filter("user_id", "eq", user_id)
.filter("date", "eq", date)
.execute()
).data
if response and len(response) > 0:
return response[0]["daily_requests_count"]
return None
def increment_user_request_count(self, user_id, date, current_requests_count: int):
"""
Increment the user's requests count for a specific day
"""
self.update_user_request_count(
user_id, daily_requests_count=current_requests_count + 1, date=date
)
def update_user_request_count(self, user_id, daily_requests_count, date):
response = (
self.db.table("user_daily_usage")
.update({"daily_requests_count": daily_requests_count})
.match({"user_id": user_id, "date": date})
.execute()
)
return response
def get_user_email(self, user_id):
"""
Fetch the user email from the database
"""
response = (
self.db.from_("user_daily_usage")
.select("email")
.filter("user_id", "eq", user_id)
.execute()
)
if response and len(response) > 0:
return response[0]["email"]
return None

View File

@ -1,85 +0,0 @@
from models.databases.repository import Repository
from logger import get_logger
logger = get_logger(__name__)
class User(Repository):
def __init__(self, supabase_client):
self.db = supabase_client
# [TODO] Rename the user table and its references to 'user_usage'
def create_user(self, user_id, user_email, date):
return (
self.db.table("users")
.insert(
{
"user_id": user_id,
"email": user_email,
"date": date,
"requests_count": 1,
}
)
.execute()
)
def get_user_request_stats(self, user_id):
"""
Fetch the user request stats from the database
"""
requests_stats = (
self.db.from_("users")
.select("*")
.filter("user_id", "eq", user_id)
.execute()
)
return requests_stats
def fetch_user_requests_count(self, user_id, date):
"""
Fetch the user request count from the database
"""
response = (
self.db.from_("users")
.select("*")
.filter("user_id", "eq", user_id)
.filter("date", "eq", date)
.execute()
)
return response
def update_user_request_count(self, user_id, requests_count, date):
response = (
self.db.table("users")
.update({"requests_count": requests_count})
.match({"user_id": user_id, "date": date})
.execute()
)
return response
def get_user_email(self, user_id):
"""
Fetch the user email from the database
"""
response = (
self.db.from_("users")
.select("email")
.filter("user_id", "eq", user_id)
.execute()
)
return response
def get_user_stats(self, user_email, date):
response = (
self.db.from_("users")
.select("*")
.filter("email", "eq", user_email)
.filter("date", "eq", date)
.execute()
)
return response

View File

@ -9,63 +9,65 @@ Base = declarative_base()
class User(Base):
__tablename__ = 'users'
__tablename__ = "users"
user_id = Column(String, primary_key=True)
email = Column(String)
date = Column(DateTime)
requests_count = Column(Integer)
daily_requests_count = Column(Integer)
class Brain(Base):
__tablename__ = 'brains'
__tablename__ = "brains"
brain_id = Column(Integer, primary_key=True)
name = Column(String)
users = relationship('BrainUser', back_populates='brain')
vectors = relationship('BrainVector', back_populates='brain')
users = relationship("BrainUser", back_populates="brain")
vectors = relationship("BrainVector", back_populates="brain")
class BrainUser(Base):
__tablename__ = 'brains_users'
__tablename__ = "brains_users"
id = Column(Integer, primary_key=True)
user_id = Column(Integer, ForeignKey('users.user_id'))
brain_id = Column(Integer, ForeignKey('brains.brain_id'))
user_id = Column(Integer, ForeignKey("users.user_id"))
brain_id = Column(Integer, ForeignKey("brains.brain_id"))
rights = Column(String)
user = relationship('User')
brain = relationship('Brain', back_populates='users')
user = relationship("User")
brain = relationship("Brain", back_populates="users")
class BrainVector(Base):
__tablename__ = 'brains_vectors'
__tablename__ = "brains_vectors"
vector_id = Column(String, primary_key=True, default=lambda: str(uuid4()))
brain_id = Column(Integer, ForeignKey('brains.brain_id'))
brain_id = Column(Integer, ForeignKey("brains.brain_id"))
file_sha1 = Column(String)
brain = relationship('Brain', back_populates='vectors')
brain = relationship("Brain", back_populates="vectors")
class BrainSubscriptionInvitation(Base):
__tablename__ = 'brain_subscription_invitations'
__tablename__ = "brain_subscription_invitations"
id = Column(Integer, primary_key=True) # Assuming an integer primary key named 'id'
brain_id = Column(String, ForeignKey('brains.brain_id'))
email = Column(String, ForeignKey('users.email'))
brain_id = Column(String, ForeignKey("brains.brain_id"))
email = Column(String, ForeignKey("users.email"))
rights = Column(String)
brain = relationship('Brain')
user = relationship('User', foreign_keys=[email])
brain = relationship("Brain")
user = relationship("User", foreign_keys=[email])
class ApiKey(Base):
__tablename__ = 'api_keys'
__tablename__ = "api_keys"
key_id = Column(String, primary_key=True, default=lambda: str(uuid4()))
user_id = Column(Integer, ForeignKey('users.user_id'))
user_id = Column(Integer, ForeignKey("users.user_id"))
api_key = Column(String, unique=True)
creation_time = Column(DateTime, default=datetime.utcnow)
is_active = Column(Boolean, default=True)
deleted_time = Column(DateTime, nullable=True)
user = relationship('User')
user = relationship("User")

View File

@ -5,5 +5,6 @@ from pydantic import BaseModel
class UserIdentity(BaseModel):
user_id: UUID
id: UUID
email: Optional[str] = None
openai_api_key: Optional[str] = None

View File

@ -0,0 +1,54 @@
from logger import get_logger
from models.databases.supabase.supabase import SupabaseDB
from models.settings import get_supabase_db
from models.user_identity import UserIdentity
logger = get_logger(__name__)
class UserUsage(UserIdentity):
daily_requests_count: int = 0
def __init__(self, **data):
super().__init__(**data)
@property
def supabase_db(self) -> SupabaseDB:
return get_supabase_db()
def get_user_usage(self):
"""
Fetch the user request stats from the database
"""
request = self.supabase_db.get_user_usage(self.id)
return request
def handle_increment_user_request_count(self, date):
"""
Increment the user request count in the database
"""
current_requests_count = self.supabase_db.get_user_requests_count_for_day(
self.id, date
)
if current_requests_count is None:
if self.email is None:
raise ValueError("User Email should be defined for daily usage table")
self.supabase_db.create_user_daily_usage(
user_id=self.id, date=date, user_email=self.email
)
self.daily_requests_count = 1
return
self.supabase_db.increment_user_request_count(
user_id=self.id,
date=date,
current_requests_count=current_requests_count,
)
self.daily_requests_count = current_requests_count
logger.info(
f"User {self.email} request count updated to {current_requests_count}"
)

View File

@ -1,55 +0,0 @@
from typing import Optional
from uuid import UUID
from logger import get_logger
from models.databases.supabase.supabase import SupabaseDB
from models.settings import get_supabase_db
from pydantic import BaseModel
logger = get_logger(__name__)
# [TODO] Rename the user table and its references to 'user_usage'
class User(BaseModel):
id: UUID
email: Optional[str]
user_openai_api_key: Optional[str] = None
requests_count: int = 0
@property
def supabase_db(self) -> SupabaseDB:
return get_supabase_db()
# [TODO] Rename the user table and its references to 'user_usage'
def create_user(self, date):
"""
Create a new user entry in the database
Args:
date (str): Date of the request
"""
logger.info(f"New user entry in db document for user {self.email}")
return self.supabase_db.create_user(self.id, self.email, date)
def get_user_request_stats(self):
"""
Fetch the user request stats from the database
"""
request = self.supabase_db.get_user_request_stats(self.id)
return request.data
def increment_user_request_count(self, date):
"""
Increment the user request count in the database
"""
response = self.supabase_db.fetch_user_requests_count(self.id, date)
userItem = next(iter(response.data or []), {"requests_count": 0})
requests_count = userItem["requests_count"] + 1
logger.info(f"User {self.email} request count updated to {requests_count}")
self.supabase_db.update_user_request_count(self.id, requests_count, date)
self.requests_count = requests_count

View File

@ -1,15 +1,13 @@
from models import BrainEntity, UserIdentity
from models.databases.supabase.brains import CreateBrainProperties
from models import BrainEntity, User
from repository.brain import create_brain, create_brain_user, get_user_default_brain
from repository.brain.create_brain import create_brain
from repository.brain.create_brain_user import create_brain_user
from repository.brain.get_default_user_brain import get_user_default_brain
from routes.authorizations.types import RoleEnum
from repository.brain import (
create_brain,
create_brain_user,
get_user_default_brain
)
def get_default_user_brain_or_create_new(user: User) -> BrainEntity:
def get_default_user_brain_or_create_new(user: UserIdentity) -> BrainEntity:
default_brain = get_user_default_brain(user.id)
if not default_brain:

View File

@ -1,3 +1,3 @@
from .get_user_identity import get_user_identity
from .create_user_identity import create_user_identity
from .update_user_identity import update_user_identity, UserIdentityUpdatableProperties
from .update_user_properties import update_user_properties, UserUpdatableProperties

View File

@ -1,12 +1,23 @@
from models import get_supabase_client, UserIdentity
from typing import Optional
from uuid import UUID
from models import UserIdentity, get_supabase_client
def create_user_identity(user_identity: UserIdentity) -> UserIdentity:
def create_user_identity(id: UUID, openai_api_key: Optional[str]) -> UserIdentity:
supabase_client = get_supabase_client()
user_identity_dict = user_identity.dict()
user_identity_dict["user_id"] = str(user_identity.user_id)
response = (
supabase_client.from_("user_identity").insert(user_identity_dict).execute()
)
return UserIdentity(**response.data[0])
response = (
supabase_client.from_("user_identity")
.insert(
{
"user_id": str(id),
"openai_api_key": openai_api_key,
}
)
.execute()
)
user_identity = response.data[0]
return UserIdentity(
id=user_identity.user_id, openai_api_key=user_identity.openai_api_key
)

View File

@ -1,8 +1,11 @@
from multiprocessing import get_logger
from uuid import UUID
from models import get_supabase_client, UserIdentity
from repository.user_identity.create_user_identity import create_user_identity
logger = get_logger()
def get_user_identity(user_id: UUID) -> UserIdentity:
supabase_client = get_supabase_client()
@ -14,6 +17,9 @@ def get_user_identity(user_id: UUID) -> UserIdentity:
)
if len(response.data) == 0:
return create_user_identity(UserIdentity(user_id=user_id))
return create_user_identity(user_id, openai_api_key=None)
return UserIdentity(**response.data[0])
user_identity = response.data[0]
openai_api_key = user_identity["openai_api_key"]
return UserIdentity(id=user_id, openai_api_key=openai_api_key)

View File

@ -1,32 +0,0 @@
from typing import Optional
from uuid import UUID
from models import get_supabase_client, UserIdentity
from pydantic import BaseModel
from repository.user_identity import create_user_identity
class UserIdentityUpdatableProperties(BaseModel):
openai_api_key: Optional[str]
def update_user_identity(
user_id: UUID,
user_identity_updatable_properties: UserIdentityUpdatableProperties,
) -> UserIdentity:
supabase_client = get_supabase_client()
response = (
supabase_client.from_("user_identity")
.update(user_identity_updatable_properties.__dict__)
.filter("user_id", "eq", user_id)
.execute()
)
if len(response.data) == 0:
user_identity = UserIdentity(
user_id=user_id,
openai_api_key=user_identity_updatable_properties.openai_api_key,
)
return create_user_identity(user_identity)
return UserIdentity(**response.data[0])

View File

@ -0,0 +1,34 @@
from typing import Optional
from uuid import UUID
from models.settings import get_supabase_client
from models.user_identity import UserIdentity
from pydantic import BaseModel
from repository.user_identity.create_user_identity import create_user_identity
class UserUpdatableProperties(BaseModel):
openai_api_key: Optional[str]
def update_user_properties(
user_id: UUID,
user_identity_updatable_properties: UserUpdatableProperties,
) -> UserIdentity:
supabase_client = get_supabase_client()
response = (
supabase_client.from_("user_identity")
.update(user_identity_updatable_properties.__dict__)
.filter("user_id", "eq", user_id)
.execute()
)
if len(response.data) == 0:
return create_user_identity(
user_id, openai_api_key=user_identity_updatable_properties.openai_api_key
)
user_identity = response.data[0]
openai_api_key = user_identity["openai_api_key"]
return UserIdentity(id=user_id, openai_api_key=openai_api_key)

View File

@ -6,13 +6,9 @@ from asyncpg.exceptions import UniqueViolationError
from auth import AuthBearer, get_current_user
from fastapi import APIRouter, Depends
from logger import get_logger
from models import get_supabase_db, User
from models import UserIdentity, get_supabase_db
from pydantic import BaseModel
from models import get_supabase_db, User
logger = get_logger(__name__)
@ -35,7 +31,7 @@ api_key_router = APIRouter()
dependencies=[Depends(AuthBearer())],
tags=["API Key"],
)
async def create_api_key(current_user: User = Depends(get_current_user)):
async def create_api_key(current_user: UserIdentity = Depends(get_current_user)):
"""
Create new API key for the current user.
@ -71,7 +67,9 @@ async def create_api_key(current_user: User = Depends(get_current_user)):
@api_key_router.delete(
"/api-key/{key_id}", dependencies=[Depends(AuthBearer())], tags=["API Key"]
)
async def delete_api_key(key_id: str, current_user: User = Depends(get_current_user)):
async def delete_api_key(
key_id: str, current_user: UserIdentity = Depends(get_current_user)
):
"""
Delete (deactivate) an API key for the current user.
@ -93,7 +91,7 @@ async def delete_api_key(key_id: str, current_user: User = Depends(get_current_u
dependencies=[Depends(AuthBearer())],
tags=["API Key"],
)
async def get_api_keys(current_user: User = Depends(get_current_user)):
async def get_api_keys(current_user: UserIdentity = Depends(get_current_user)):
"""
Get all active API keys for the current user.

View File

@ -3,9 +3,8 @@ from uuid import UUID
from auth.auth_bearer import get_current_user
from fastapi import Depends, HTTPException, status
from models import User
from models import UserIdentity
from repository.brain import get_brain_for_user
from routes.authorizations.types import RoleEnum
@ -18,7 +17,9 @@ def has_brain_authorization(
return: A wrapper function that checks the authorization
"""
async def wrapper(brain_id: UUID, current_user: User = Depends(get_current_user)):
async def wrapper(
brain_id: UUID, current_user: UserIdentity = Depends(get_current_user)
):
nonlocal required_roles
if isinstance(required_roles, str):
required_roles = [required_roles] # Convert single role to a list

View File

@ -3,28 +3,23 @@ from uuid import UUID
from auth import AuthBearer, get_current_user
from fastapi import APIRouter, Depends, HTTPException
from logger import get_logger
from models import BrainRateLimiting, UserIdentity
from models.databases.supabase.brains import (
BrainUpdatableProperties,
CreateBrainProperties,
)
from models import BrainRateLimiting, User
from repository.brain import (
create_brain,
get_user_brains,
get_brain_details,
create_brain_user,
update_brain_by_id,
get_brain_details,
get_default_user_brain_or_create_new,
get_user_brains,
get_user_default_brain,
set_as_default_brain_for_user,
get_default_user_brain_or_create_new,
)
from repository.prompt import get_prompt_by_id, delete_prompt_by_id
from routes.authorizations.brain_authorization import (
has_brain_authorization,
update_brain_by_id,
)
from repository.prompt import delete_prompt_by_id, get_prompt_by_id
from routes.authorizations.brain_authorization import has_brain_authorization
from routes.authorizations.types import RoleEnum
logger = get_logger(__name__)
@ -34,7 +29,7 @@ brain_router = APIRouter()
# get all brains
@brain_router.get("/brains/", dependencies=[Depends(AuthBearer())], tags=["Brain"])
async def brain_endpoint(current_user: User = Depends(get_current_user)):
async def brain_endpoint(current_user: UserIdentity = Depends(get_current_user)):
"""
Retrieve all brains for the current user.
@ -52,7 +47,9 @@ async def brain_endpoint(current_user: User = Depends(get_current_user)):
@brain_router.get(
"/brains/default/", dependencies=[Depends(AuthBearer())], tags=["Brain"]
)
async def get_default_brain_endpoint(current_user: User = Depends(get_current_user)):
async def get_default_brain_endpoint(
current_user: UserIdentity = Depends(get_current_user),
):
"""
Retrieve the default brain for the current user. If the user doesnt have one, it creates one.
@ -99,7 +96,7 @@ async def get_brain_endpoint(
@brain_router.post("/brains/", dependencies=[Depends(AuthBearer())], tags=["Brain"])
async def create_brain_endpoint(
brain: CreateBrainProperties,
current_user: User = Depends(get_current_user),
current_user: UserIdentity = Depends(get_current_user),
):
"""
Create a new brain with given
@ -201,7 +198,7 @@ async def update_brain_endpoint(
)
async def set_as_default_brain_endpoint(
brain_id: UUID,
user: User = Depends(get_current_user),
user: UserIdentity = Depends(get_current_user),
):
"""
Set a brain as default for the current user.

View File

@ -14,14 +14,10 @@ from models import (
BrainEntity,
Chat,
ChatQuestion,
LLMSettings,
User,
UserIdentity,
UserUsage,
get_supabase_db,
)
from models.brain_entity import BrainEntity
from models.brains import Brain
from models.chat import Chat
from models.chats import ChatQuestion
from models.databases.supabase.supabase import SupabaseDB
from repository.brain import get_brain_details
from repository.chat import (
@ -67,15 +63,19 @@ def delete_chat_from_db(supabase_db: SupabaseDB, chat_id):
pass
def check_user_limit(
user: User,
def check_user_requests_limit(
user: UserIdentity,
):
if user.user_openai_api_key is None:
date = time.strftime("%Y%m%d")
max_requests_number = int(os.getenv("MAX_REQUESTS_NUMBER", 1000))
userDailyUsage = UserUsage(
id=user.id, email=user.email, openai_api_key=user.openai_api_key
)
user.increment_user_request_count(date)
if int(user.requests_count) >= int(max_requests_number):
date = time.strftime("%Y%m%d")
userDailyUsage.handle_increment_user_request_count(date)
if user.openai_api_key is None:
max_requests_number = int(os.getenv("MAX_REQUESTS_NUMBER", 1))
if int(userDailyUsage.daily_requests_count) >= int(max_requests_number):
raise HTTPException(
status_code=429, # pyright: ignore reportPrivateUsage=none
detail="You have reached the maximum number of requests for today.", # pyright: ignore reportPrivateUsage=none
@ -91,7 +91,7 @@ async def healthz():
# get all chats
@chat_router.get("/chat", dependencies=[Depends(AuthBearer())], tags=["Chat"])
async def get_chats(current_user: User = Depends(get_current_user)):
async def get_chats(current_user: UserIdentity = Depends(get_current_user)):
"""
Retrieve all chats for the current user.
@ -125,7 +125,7 @@ async def delete_chat(chat_id: UUID):
async def update_chat_metadata_handler(
chat_data: ChatUpdatableProperties,
chat_id: UUID,
current_user: User = Depends(get_current_user),
current_user: UserIdentity = Depends(get_current_user),
) -> Chat:
"""
Update chat attributes
@ -144,7 +144,7 @@ async def update_chat_metadata_handler(
@chat_router.post("/chat", dependencies=[Depends(AuthBearer())], tags=["Chat"])
async def create_chat_handler(
chat_data: CreateChatProperties,
current_user: User = Depends(get_current_user),
current_user: UserIdentity = Depends(get_current_user),
):
"""
Create a new chat with initial chat messages.
@ -170,25 +170,25 @@ async def create_question_handler(
brain_id: NullableUUID
| UUID
| None = Query(..., description="The ID of the brain"),
current_user: User = Depends(get_current_user),
current_user: UserIdentity = Depends(get_current_user),
) -> GetChatHistoryOutput:
"""
Add a new question to the chat.
"""
# Retrieve user's OpenAI API key
current_user.user_openai_api_key = request.headers.get("Openai-Api-Key")
current_user.openai_api_key = request.headers.get("Openai-Api-Key")
brain = Brain(id=brain_id)
if not current_user.user_openai_api_key and brain_id:
if not current_user.openai_api_key and brain_id:
brain_details = get_brain_details(brain_id)
if brain_details:
current_user.user_openai_api_key = brain_details.openai_api_key
current_user.openai_api_key = brain_details.openai_api_key
if not current_user.user_openai_api_key:
if not current_user.openai_api_key:
user_identity = get_user_identity(current_user.id)
if user_identity is not None:
current_user.user_openai_api_key = user_identity.openai_api_key
current_user.openai_api_key = user_identity.openai_api_key
# Retrieve chat model (temperature, max_tokens, model)
if (
@ -202,8 +202,7 @@ async def create_question_handler(
chat_question.max_tokens = chat_question.max_tokens or brain.max_tokens or 256
try:
check_user_limit(current_user)
LLMSettings()
check_user_requests_limit(current_user)
gpt_answer_generator: HeadlessQA | OpenAIBrainPicking
if brain_id:
@ -213,14 +212,14 @@ async def create_question_handler(
max_tokens=chat_question.max_tokens,
temperature=chat_question.temperature,
brain_id=str(brain_id),
user_openai_api_key=current_user.user_openai_api_key, # pyright: ignore reportPrivateUsage=none
user_openai_api_key=current_user.openai_api_key, # pyright: ignore reportPrivateUsage=none
)
else:
gpt_answer_generator = HeadlessQA(
model=chat_question.model,
temperature=chat_question.temperature,
max_tokens=chat_question.max_tokens,
user_openai_api_key=current_user.user_openai_api_key, # pyright: ignore reportPrivateUsage=none
user_openai_api_key=current_user.openai_api_key, # pyright: ignore reportPrivateUsage=none
chat_id=str(chat_id),
)
@ -248,24 +247,24 @@ async def create_stream_question_handler(
brain_id: NullableUUID
| UUID
| None = Query(..., description="The ID of the brain"),
current_user: User = Depends(get_current_user),
current_user: UserIdentity = Depends(get_current_user),
) -> StreamingResponse:
# TODO: check if the user has access to the brain
# Retrieve user's OpenAI API key
current_user.user_openai_api_key = request.headers.get("Openai-Api-Key")
current_user.openai_api_key = request.headers.get("Openai-Api-Key")
brain = Brain(id=brain_id)
brain_details: BrainEntity | None = None
if not current_user.user_openai_api_key and brain_id:
if not current_user.openai_api_key and brain_id:
brain_details = get_brain_details(brain_id)
if brain_details:
current_user.user_openai_api_key = brain_details.openai_api_key
current_user.openai_api_key = brain_details.openai_api_key
if not current_user.user_openai_api_key:
if not current_user.openai_api_key:
user_identity = get_user_identity(current_user.id)
if user_identity is not None:
current_user.user_openai_api_key = user_identity.openai_api_key
current_user.openai_api_key = user_identity.openai_api_key
# Retrieve chat model (temperature, max_tokens, model)
if (
@ -280,36 +279,36 @@ async def create_stream_question_handler(
try:
logger.info(f"Streaming request for {chat_question.model}")
check_user_limit(current_user)
check_user_requests_limit(current_user)
gpt_answer_generator: HeadlessQA | OpenAIBrainPicking
if brain_id:
gpt_answer_generator = OpenAIBrainPicking(
chat_id=str(chat_id),
model=(brain_details or chat_question).model
if current_user.user_openai_api_key
if current_user.openai_api_key
else "gpt-3.5-turbo",
max_tokens=(brain_details or chat_question).max_tokens
if current_user.user_openai_api_key
if current_user.openai_api_key
else 0,
temperature=(brain_details or chat_question).temperature
if current_user.user_openai_api_key
if current_user.openai_api_key
else 256,
brain_id=str(brain_id),
user_openai_api_key=current_user.user_openai_api_key, # pyright: ignore reportPrivateUsage=none
user_openai_api_key=current_user.openai_api_key, # pyright: ignore reportPrivateUsage=none
streaming=True,
)
else:
gpt_answer_generator = HeadlessQA(
model=chat_question.model
if current_user.user_openai_api_key
if current_user.openai_api_key
else "gpt-3.5-turbo",
temperature=chat_question.temperature
if current_user.user_openai_api_key
if current_user.openai_api_key
else 256,
max_tokens=chat_question.max_tokens
if current_user.user_openai_api_key
if current_user.openai_api_key
else 0,
user_openai_api_key=current_user.user_openai_api_key, # pyright: ignore reportPrivateUsage=none
user_openai_api_key=current_user.openai_api_key, # pyright: ignore reportPrivateUsage=none
chat_id=str(chat_id),
streaming=True,
)

View File

@ -6,7 +6,7 @@ from uuid import UUID
from auth import AuthBearer, get_current_user
from crawl.crawler import CrawlWebsite
from fastapi import APIRouter, Depends, Query, Request, UploadFile
from models import User, Brain, File
from models import Brain, File, UserIdentity
from parsers.github import process_github
from utils.file import convert_bytes
from utils.processors import filter_file
@ -25,7 +25,7 @@ async def crawl_endpoint(
crawl_website: CrawlWebsite,
brain_id: UUID = Query(..., description="The ID of the brain"),
enable_summarization: bool = False,
current_user: User = Depends(get_current_user),
current_user: UserIdentity = Depends(get_current_user),
):
"""
Crawl a website and process the crawled data.
@ -34,6 +34,7 @@ async def crawl_endpoint(
# [TODO] check if the user is the owner/editor of the brain
brain = Brain(id=brain_id)
# [TODO] rate limiting of user for crawl
if request.headers.get("Openai-Api-Key"):
brain.max_brain_size = int(os.getenv("MAX_BRAIN_SIZE_WITH_KEY", 209715200))
@ -42,7 +43,7 @@ async def crawl_endpoint(
if remaining_free_space - file_size < 0:
message = {
"message": f"❌ User's brain will exceed maximum capacity with this upload. Maximum file allowed is : {convert_bytes(remaining_free_space)}",
"message": f"❌ UserIdentity's brain will exceed maximum capacity with this upload. Maximum file allowed is : {convert_bytes(remaining_free_space)}",
"type": "error",
}
else:

View File

@ -2,7 +2,7 @@ from uuid import UUID
from auth import AuthBearer, get_current_user
from fastapi import APIRouter, Depends, Query
from models import User, get_supabase_db, Brain
from models import Brain, UserIdentity, get_supabase_db
from routes.authorizations.brain_authorization import (
RoleEnum,
has_brain_authorization,
@ -36,7 +36,7 @@ async def explore_endpoint(
)
async def delete_endpoint(
file_name: str,
current_user: User = Depends(get_current_user),
current_user: UserIdentity = Depends(get_current_user),
brain_id: UUID = Query(..., description="The ID of the brain"),
):
"""
@ -54,7 +54,7 @@ async def delete_endpoint(
"/explore/{file_name}/", dependencies=[Depends(AuthBearer())], tags=["Explore"]
)
async def download_endpoint(
file_name: str, current_user: User = Depends(get_current_user)
file_name: str, current_user: UserIdentity = Depends(get_current_user)
):
"""
Download a specific user file by file name.

View File

@ -3,11 +3,11 @@ from uuid import UUID
from auth.auth_bearer import AuthBearer, get_current_user
from fastapi import APIRouter, Depends, HTTPException
from models import User, BrainSubscription, Brain, PromptStatusEnum
from models import Brain, BrainSubscription, PromptStatusEnum, UserIdentity
from pydantic import BaseModel
from repository.brain import (
get_brain_by_id,
create_brain_user,
get_brain_by_id,
get_brain_details,
get_brain_for_user,
update_brain_user_rights,
@ -17,7 +17,7 @@ from repository.brain_subscription import (
resend_invitation_email,
)
from repository.prompt import delete_prompt_by_id, get_prompt_by_id
from repository.user import get_user_id_by_user_email, get_user_email_by_user_id
from repository.user import get_user_email_by_user_id, get_user_id_by_user_email
from routes.authorizations.brain_authorization import (
RoleEnum,
has_brain_authorization,
@ -44,7 +44,7 @@ def invite_users_to_brain(
brain_id: UUID,
users: List[dict],
origin: str = Depends(get_origin_header),
current_user: User = Depends(get_current_user),
current_user: UserIdentity = Depends(get_current_user),
):
"""
Invite multiple users to a brain by their emails. This function creates
@ -114,7 +114,7 @@ def get_brain_users(
"/brains/{brain_id}/subscription",
)
async def remove_user_subscription(
brain_id: UUID, current_user: User = Depends(get_current_user)
brain_id: UUID, current_user: UserIdentity = Depends(get_current_user)
):
"""
Remove a user's subscription to a brain
@ -163,13 +163,15 @@ async def remove_user_subscription(
dependencies=[Depends(AuthBearer())],
tags=["BrainSubscription"],
)
def get_user_invitation(brain_id: UUID, current_user: User = Depends(get_current_user)):
def get_user_invitation(
brain_id: UUID, current_user: UserIdentity = Depends(get_current_user)
):
"""
Get an invitation to a brain for a user. This function checks if the user
has been invited to the brain and returns the invitation status.
"""
if not current_user.email:
raise HTTPException(status_code=400, detail="User email is not defined")
raise HTTPException(status_code=400, detail="UserIdentity email is not defined")
subscription = BrainSubscription(brain_id=brain_id, email=current_user.email)
@ -197,7 +199,7 @@ def get_user_invitation(brain_id: UUID, current_user: User = Depends(get_current
tags=["Brain"],
)
async def accept_invitation(
brain_id: UUID, current_user: User = Depends(get_current_user)
brain_id: UUID, current_user: UserIdentity = Depends(get_current_user)
):
"""
Accept an invitation to a brain for a user. This function removes the
@ -205,7 +207,7 @@ async def accept_invitation(
brain users.
"""
if not current_user.email:
raise HTTPException(status_code=400, detail="User email is not defined")
raise HTTPException(status_code=400, detail="UserIdentity email is not defined")
subscription = BrainSubscription(brain_id=brain_id, email=current_user.email)
@ -240,14 +242,14 @@ async def accept_invitation(
tags=["Brain"],
)
async def decline_invitation(
brain_id: UUID, current_user: User = Depends(get_current_user)
brain_id: UUID, current_user: UserIdentity = Depends(get_current_user)
):
"""
Decline an invitation to a brain for a user. This function removes the
invitation from the subscription invitations.
"""
if not current_user.email:
raise HTTPException(status_code=400, detail="User email is not defined")
raise HTTPException(status_code=400, detail="UserIdentity email is not defined")
subscription = BrainSubscription(brain_id=brain_id, email=current_user.email)
@ -282,7 +284,7 @@ class BrainSubscriptionUpdatableProperties(BaseModel):
def update_brain_subscription(
brain_id: UUID,
subscription: BrainSubscriptionUpdatableProperties,
current_user: User = Depends(get_current_user),
current_user: UserIdentity = Depends(get_current_user),
):
user_email = subscription.email
if user_email == current_user.email:

View File

@ -3,7 +3,7 @@ from uuid import UUID
from auth import AuthBearer, get_current_user
from fastapi import APIRouter, Depends, Query, Request, UploadFile
from models import User, File, Brain
from models import Brain, File, UserIdentity
from repository.brain import get_brain_details
from repository.user_identity import get_user_identity
from routes.authorizations.brain_authorization import (
@ -27,7 +27,7 @@ async def upload_file(
uploadFile: UploadFile,
brain_id: UUID = Query(..., description="The ID of the brain"),
enable_summarization: bool = False,
current_user: User = Depends(get_current_user),
current_user: UserIdentity = Depends(get_current_user),
):
"""
Upload a file to the user's storage.
@ -57,7 +57,7 @@ async def upload_file(
file = File(file=uploadFile)
if remaining_free_space - file_size < 0:
message = {
"message": f"❌ User's brain will exceed maximum capacity with this upload. Maximum file allowed is : {convert_bytes(remaining_free_space)}",
"message": f"❌ UserIdentity's brain will exceed maximum capacity with this upload. Maximum file allowed is : {convert_bytes(remaining_free_space)}",
"type": "error",
}
else:

View File

@ -3,12 +3,12 @@ import time
from auth import AuthBearer, get_current_user
from fastapi import APIRouter, Depends, Request
from models import User, Brain, BrainRateLimiting, UserIdentity
from models import Brain, BrainRateLimiting, UserIdentity, UserUsage
from repository.brain import get_user_default_brain
from repository.user_identity import (
UserIdentityUpdatableProperties,
update_user_identity,
get_user_identity,
from repository.user_identity.get_user_identity import get_user_identity
from repository.user_identity.update_user_properties import (
UserUpdatableProperties,
update_user_properties,
)
user_router = APIRouter()
@ -18,7 +18,7 @@ MAX_BRAIN_SIZE_WITH_OWN_KEY = int(os.getenv("MAX_BRAIN_SIZE_WITH_KEY", 209715200
@user_router.get("/user", dependencies=[Depends(AuthBearer())], tags=["User"])
async def get_user_endpoint(
request: Request, current_user: User = Depends(get_current_user)
request: Request, current_user: UserIdentity = Depends(get_current_user)
):
"""
Get user information and statistics.
@ -39,7 +39,9 @@ async def get_user_endpoint(
date = time.strftime("%Y%m%d")
max_requests_number = os.getenv("MAX_REQUESTS_NUMBER")
requests_stats = current_user.get_user_request_stats()
userDailyUsage = UserUsage(id=current_user.id)
requests_stats = userDailyUsage.get_user_usage()
default_brain = get_user_default_brain(current_user.id)
if default_brain:
@ -64,13 +66,13 @@ async def get_user_endpoint(
tags=["User"],
)
def update_user_identity_route(
user_identity_updatable_properties: UserIdentityUpdatableProperties,
current_user: User = Depends(get_current_user),
user_identity_updatable_properties: UserUpdatableProperties,
current_user: UserIdentity = Depends(get_current_user),
) -> UserIdentity:
"""
Update user identity.
"""
return update_user_identity(current_user.id, user_identity_updatable_properties)
return update_user_properties(current_user.id, user_identity_updatable_properties)
@user_router.get(
@ -79,7 +81,7 @@ def update_user_identity_route(
tags=["User"],
)
def get_user_identity_route(
current_user: User = Depends(get_current_user),
current_user: UserIdentity = Depends(get_current_user),
) -> UserIdentity:
"""
Get user identity.

View File

@ -12,7 +12,7 @@ import {
type RequestStat = {
date: string;
requests_count: number;
daily_requests_count: number;
user_id: string;
};
@ -31,7 +31,7 @@ export const RequestsPerDayChart = ({
return {
date: format(date, "MM/dd/yyyy"),
requests_count: stat ? stat.requests_count : 0,
daily_requests_count: stat ? stat.daily_requests_count : 0,
};
})
.reverse();
@ -57,7 +57,7 @@ export const RequestsPerDayChart = ({
}}
/>
<VictoryAxis dependentAxis />
<VictoryLine data={data} x="date" y="requests_count" />
<VictoryLine data={data} x="date" y="daily_requests_count" />
</VictoryChart>
);
};

View File

@ -18,14 +18,14 @@ import { RequestsPerDayChart } from "./Graphs/RequestsPerDayChart";
export const UserStatistics = (userStats: UserStats): JSX.Element => {
const { email, current_brain_size, max_brain_size, date, requests_stats } =
userStats;
const { t } = useTranslation(["translation","user"]);
const { t } = useTranslation(["translation", "user"]);
return (
<>
<div className="flex flex-col sm:flex-row sm:items-center py-10 gap-5">
<div className="flex-1 flex flex-col">
<h1 className="text-4xl font-semibold">
{t("title", { user: email.split("@")[0], ns: "user"})}
{t("title", { user: email.split("@")[0], ns: "user" })}
</h1>
<p className="opacity-50">{email}</p>
<Link className="mt-2" href={"/logout"}>
@ -42,7 +42,10 @@ export const UserStatistics = (userStats: UserStats): JSX.Element => {
<div>
<h1 className="text-2xl font-semibold">
{/* The last element corresponds to today's request_count */}
{t("requestsCount",{ count: requests_stats.at(-1)?.requests_count, ns: "user"})}
{t("requestsCount", {
count: requests_stats.at(-1)?.daily_requests_count,
ns: "user",
})}
</h1>
<DateComponent date={date} />
</div>
@ -53,7 +56,9 @@ export const UserStatistics = (userStats: UserStats): JSX.Element => {
<UserStatisticsCard>
<div>
<h1 className="text-2xl font-semibold">{t("brainSize",{ns: "user"})}</h1>
<h1 className="text-2xl font-semibold">
{t("brainSize", { ns: "user" })}
</h1>
<p>
{/* How much brain space is left */}
{prettyBytes(max_brain_size - current_brain_size, {

View File

@ -1,6 +1,6 @@
export type RequestStat = {
date: string;
requests_count: number;
daily_requests_count: number;
user_id: string;
};

View File

@ -0,0 +1,29 @@
-- Create a new user_daily_usage table
create table if not exists
user_daily_usage (
user_id uuid references auth.users (id),
email text,
date text,
daily_requests_count int,
primary key (user_id, date)
);
-- Drop the old users table
drop table if exists users;
-- Update migrations table
insert into
migrations (name)
select
'202308181004030_rename_users_table'
where
not exists (
select
1
from
migrations
where
name = '202308181004030_rename_users_table'
);
commit;

View File

@ -1,11 +1,11 @@
CREATE EXTENSION IF NOT EXISTS "uuid-ossp";
-- Create users table
CREATE TABLE IF NOT EXISTS users(
CREATE TABLE IF NOT EXISTS user_daily_usage(
user_id UUID DEFAULT uuid_generate_v4() PRIMARY KEY,
email TEXT,
date TEXT,
requests_count INT
daily_requests_count INT
);
-- Create chats table

View File

@ -1,9 +1,9 @@
-- Create users table
CREATE TABLE IF NOT EXISTS users(
CREATE TABLE IF NOT EXISTS user_daily_usage(
user_id UUID REFERENCES auth.users (id),
email TEXT,
date TEXT,
requests_count INT,
daily_requests_count INT,
PRIMARY KEY (user_id, date)
);
@ -215,7 +215,7 @@ CREATE TABLE IF NOT EXISTS migrations (
);
INSERT INTO migrations (name)
SELECT '20230809154300_add_prompt_id_brain_id_to_chat_history_table'
SELECT '202308181004030_rename_users_table'
WHERE NOT EXISTS (
SELECT 1 FROM migrations WHERE name = '20230809154300_add_prompt_id_brain_id_to_chat_history_table'
SELECT 1 FROM migrations WHERE name = '202308181004030_rename_users_table'
);