fix(RBAC): use dependencies (#629)

This commit is contained in:
Mamadou DICKO 2023-07-13 17:54:23 +02:00 committed by GitHub
parent 83fe9430d0
commit f65044e152
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 26 additions and 35 deletions

View File

@ -1,41 +1,40 @@
from functools import wraps
from typing import Optional from typing import Optional
from uuid import UUID from uuid import UUID
from fastapi import HTTPException, status from auth.auth_bearer import get_current_user
from fastapi import Depends, HTTPException, status
from models.brains import Brain from models.brains import Brain
from models.users import User from models.users import User
def has_brain_authorization(required_role: Optional[str] = "Owner"): def has_brain_authorization(required_role: Optional[str] = "Owner"):
def decorator(func): """
@wraps(func) Decorator to check if the user has the required role for the brain
async def wrapper(current_user: User, *args, **kwargs): param: required_role: The role required to access the brain
brain_id: Optional[UUID] = kwargs.get("brain_id") return: A wrapper function that checks the authorization
user_id = current_user.id """
if brain_id is None:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Missing brain ID",
)
async def wrapper(brain_id: UUID, current_user: User = Depends(get_current_user)):
validate_brain_authorization( validate_brain_authorization(
brain_id, user_id=user_id, required_role=required_role brain_id=brain_id, user_id=current_user.id, required_role=required_role
) )
return await func(*args, **kwargs)
return wrapper return wrapper
return decorator
def validate_brain_authorization( def validate_brain_authorization(
brain_id: UUID, brain_id: UUID,
user_id: UUID, user_id: UUID,
required_role: Optional[str] = "Owner", required_role: Optional[str] = "Owner",
): ):
"""
Function to check if the user has the required role for the brain
param: brain_id: The id of the brain
param: user_id: The id of the user
param: required_role: The role required to access the brain
return: None
"""
if required_role is None: if required_role is None:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, status_code=status.HTTP_400_BAD_REQUEST,

View File

@ -10,7 +10,7 @@ from models.users import User
from pydantic import BaseModel from pydantic import BaseModel
from routes.authorizations.brain_authorization import ( from routes.authorizations.brain_authorization import (
validate_brain_authorization, has_brain_authorization,
) )
logger = get_logger(__name__) logger = get_logger(__name__)
@ -78,14 +78,11 @@ async def get_default_brain_endpoint(current_user: User = Depends(get_current_us
# get one brain # get one brain
@brain_router.get( @brain_router.get(
"/brains/{brain_id}/", "/brains/{brain_id}/",
dependencies=[ dependencies=[Depends(AuthBearer()), Depends(has_brain_authorization())],
Depends(AuthBearer()),
],
tags=["Brain"], tags=["Brain"],
) )
async def get_brain_endpoint( async def get_brain_endpoint(
brain_id: UUID, brain_id: UUID,
current_user: User = Depends(get_current_user),
): ):
""" """
Retrieve details of a specific brain by brain ID. Retrieve details of a specific brain by brain ID.
@ -96,7 +93,6 @@ async def get_brain_endpoint(
This endpoint retrieves the details of a specific brain identified by the provided brain ID. It returns the brain ID and its This endpoint retrieves the details of a specific brain identified by the provided brain ID. It returns the brain ID and its
history, which includes the brain messages exchanged in the brain. history, which includes the brain messages exchanged in the brain.
""" """
validate_brain_authorization(brain_id, current_user.id)
brain = Brain(id=brain_id) brain = Brain(id=brain_id)
brains = brain.get_brain_details() brains = brain.get_brain_details()
if len(brains) > 0: if len(brains) > 0:
@ -112,9 +108,7 @@ async def get_brain_endpoint(
# delete one brain # delete one brain
@brain_router.delete( @brain_router.delete(
"/brains/{brain_id}/", "/brains/{brain_id}/",
dependencies=[ dependencies=[Depends(AuthBearer()), Depends(has_brain_authorization())],
Depends(AuthBearer()),
],
tags=["Brain"], tags=["Brain"],
) )
async def delete_brain_endpoint( async def delete_brain_endpoint(
@ -124,8 +118,6 @@ async def delete_brain_endpoint(
""" """
Delete a specific brain by brain ID. Delete a specific brain by brain ID.
""" """
# [TODO] check if the user is the owner of the brain
validate_brain_authorization(brain_id, current_user.id)
brain = Brain(id=brain_id) brain = Brain(id=brain_id)
brain.delete_brain(current_user.id) brain.delete_brain(current_user.id)
@ -188,13 +180,13 @@ async def create_brain_endpoint(
Depends( Depends(
AuthBearer(), AuthBearer(),
), ),
Depends(has_brain_authorization()),
], ],
tags=["Brain"], tags=["Brain"],
) )
async def update_brain_endpoint( async def update_brain_endpoint(
brain_id: UUID, brain_id: UUID,
input_brain: Brain, input_brain: Brain,
current_user: User = Depends(get_current_user),
): ):
""" """
Update an existing brain with new brain parameters/files. Update an existing brain with new brain parameters/files.
@ -204,7 +196,6 @@ async def update_brain_endpoint(
name, status, model, max_tokens, temperature name, status, model, max_tokens, temperature
Return modified brain ? No need -> do an optimistic update Return modified brain ? No need -> do an optimistic update
""" """
validate_brain_authorization(brain_id, current_user.id)
commons = common_dependencies() commons = common_dependencies()
brain = Brain(id=brain_id) brain = Brain(id=brain_id)

View File

@ -7,6 +7,7 @@ from models.settings import common_dependencies
from models.users import User from models.users import User
from routes.authorizations.brain_authorization import ( from routes.authorizations.brain_authorization import (
has_brain_authorization,
validate_brain_authorization, validate_brain_authorization,
) )
@ -31,6 +32,7 @@ async def explore_endpoint(
"/explore/{file_name}/", "/explore/{file_name}/",
dependencies=[ dependencies=[
Depends(AuthBearer()), Depends(AuthBearer()),
Depends(has_brain_authorization()),
], ],
tags=["Explore"], tags=["Explore"],
) )
@ -42,7 +44,6 @@ async def delete_endpoint(
""" """
Delete a specific user file by file name. Delete a specific user file by file name.
""" """
validate_brain_authorization(brain_id, current_user.id)
brain = Brain(id=brain_id) brain = Brain(id=brain_id)
brain.delete_file_from_brain(file_name) brain.delete_file_from_brain(file_name)