fix(RBAC): use dependencies (#629)

This commit is contained in:
Mamadou DICKO 2023-07-13 17:54:23 +02:00 committed by GitHub
parent 83fe9430d0
commit f65044e152
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 26 additions and 35 deletions

View File

@ -1,41 +1,40 @@
from functools import wraps
from typing import Optional
from uuid import UUID
from fastapi import HTTPException, status
from auth.auth_bearer import get_current_user
from fastapi import Depends, HTTPException, status
from models.brains import Brain
from models.users import User
def has_brain_authorization(required_role: Optional[str] = "Owner"):
def decorator(func):
@wraps(func)
async def wrapper(current_user: User, *args, **kwargs):
brain_id: Optional[UUID] = kwargs.get("brain_id")
user_id = current_user.id
if brain_id is None:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Missing brain ID",
)
"""
Decorator to check if the user has the required role for the brain
param: required_role: The role required to access the brain
return: A wrapper function that checks the authorization
"""
async def wrapper(brain_id: UUID, current_user: User = Depends(get_current_user)):
validate_brain_authorization(
brain_id, user_id=user_id, required_role=required_role
brain_id=brain_id, user_id=current_user.id, required_role=required_role
)
return await func(*args, **kwargs)
return wrapper
return decorator
def validate_brain_authorization(
brain_id: UUID,
user_id: UUID,
required_role: Optional[str] = "Owner",
):
"""
Function to check if the user has the required role for the brain
param: brain_id: The id of the brain
param: user_id: The id of the user
param: required_role: The role required to access the brain
return: None
"""
if required_role is None:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,

View File

@ -10,7 +10,7 @@ from models.users import User
from pydantic import BaseModel
from routes.authorizations.brain_authorization import (
validate_brain_authorization,
has_brain_authorization,
)
logger = get_logger(__name__)
@ -78,14 +78,11 @@ async def get_default_brain_endpoint(current_user: User = Depends(get_current_us
# get one brain
@brain_router.get(
"/brains/{brain_id}/",
dependencies=[
Depends(AuthBearer()),
],
dependencies=[Depends(AuthBearer()), Depends(has_brain_authorization())],
tags=["Brain"],
)
async def get_brain_endpoint(
brain_id: UUID,
current_user: User = Depends(get_current_user),
):
"""
Retrieve details of a specific brain by brain ID.
@ -96,7 +93,6 @@ async def get_brain_endpoint(
This endpoint retrieves the details of a specific brain identified by the provided brain ID. It returns the brain ID and its
history, which includes the brain messages exchanged in the brain.
"""
validate_brain_authorization(brain_id, current_user.id)
brain = Brain(id=brain_id)
brains = brain.get_brain_details()
if len(brains) > 0:
@ -112,9 +108,7 @@ async def get_brain_endpoint(
# delete one brain
@brain_router.delete(
"/brains/{brain_id}/",
dependencies=[
Depends(AuthBearer()),
],
dependencies=[Depends(AuthBearer()), Depends(has_brain_authorization())],
tags=["Brain"],
)
async def delete_brain_endpoint(
@ -124,8 +118,6 @@ async def delete_brain_endpoint(
"""
Delete a specific brain by brain ID.
"""
# [TODO] check if the user is the owner of the brain
validate_brain_authorization(brain_id, current_user.id)
brain = Brain(id=brain_id)
brain.delete_brain(current_user.id)
@ -188,13 +180,13 @@ async def create_brain_endpoint(
Depends(
AuthBearer(),
),
Depends(has_brain_authorization()),
],
tags=["Brain"],
)
async def update_brain_endpoint(
brain_id: UUID,
input_brain: Brain,
current_user: User = Depends(get_current_user),
):
"""
Update an existing brain with new brain parameters/files.
@ -204,7 +196,6 @@ async def update_brain_endpoint(
name, status, model, max_tokens, temperature
Return modified brain ? No need -> do an optimistic update
"""
validate_brain_authorization(brain_id, current_user.id)
commons = common_dependencies()
brain = Brain(id=brain_id)

View File

@ -7,6 +7,7 @@ from models.settings import common_dependencies
from models.users import User
from routes.authorizations.brain_authorization import (
has_brain_authorization,
validate_brain_authorization,
)
@ -31,6 +32,7 @@ async def explore_endpoint(
"/explore/{file_name}/",
dependencies=[
Depends(AuthBearer()),
Depends(has_brain_authorization()),
],
tags=["Explore"],
)
@ -42,7 +44,6 @@ async def delete_endpoint(
"""
Delete a specific user file by file name.
"""
validate_brain_authorization(brain_id, current_user.id)
brain = Brain(id=brain_id)
brain.delete_file_from_brain(file_name)