chore: rename and refactor AuthBearer

This commit is contained in:
Matt 2023-06-14 17:36:01 +01:00
parent 0bc4221b2c
commit c0b877c94e
9 changed files with 72 additions and 70 deletions

View File

@ -0,0 +1,36 @@
from datetime import datetime
from fastapi import HTTPException
from pydantic import DateError
from utils.vectors import CommonsDep
async def verify_api_key(api_key: str, commons: CommonsDep):
try:
# Use UTC time to avoid timezone issues
current_date = datetime.utcnow().date()
result = commons['supabase'].table('api_keys').select('api_key', 'creation_time').filter('api_key', 'eq', api_key).filter('is_active', 'eq', True).execute()
if result.data is not None and len(result.data) > 0:
api_key_creation_date = datetime.strptime(result.data[0]['creation_time'], "%Y-%m-%dT%H:%M:%S").date()
# Check if the API key was created today: Todo remove this check and use deleted_time instead.
if api_key_creation_date == current_date:
return True
return False
except DateError:
return False
async def get_user_from_api_key(api_key: str, commons: CommonsDep):
# Lookup the user_id from the api_keys table
user_id_data = commons['supabase'].table('api_keys').select('user_id').filter('api_key', 'eq', api_key).execute()
if not user_id_data.data:
raise HTTPException(status_code=400, detail="Invalid API key.")
user_id = user_id_data.data[0]['user_id']
# Lookup the email from the users table. Todo: remove and use user_id for credentials
user_email_data = commons['supabase'].table('users').select('email').filter('user_id', 'eq', user_id).execute()
return {'email': user_email_data.data[0]['email']} if user_email_data.data else {'email': None}

View File

@ -1,14 +1,15 @@
import os
from auth.api_key_handler import verify_api_key, get_user_from_api_key
from auth.jwt_token_handler import decode_access_token, verify_token
from typing import Optional
from fastapi import Depends, Request, HTTPException
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
import os
from models.users import User
from utils.vectors import CommonsDep
from asyncpg.exceptions import DataError
from auth.auth_handler import decode_access_token
from datetime import datetime
class JWTBearer(HTTPBearer):
class AuthBearer(HTTPBearer):
def __init__(self, auto_error: bool = True):
super().__init__(auto_error=auto_error)
@ -27,54 +28,15 @@ class JWTBearer(HTTPBearer):
async def authenticate(self, token, commons):
if os.environ.get("AUTHENTICATE") == "false":
return self.get_test_user()
elif self.verify_jwt(token):
return self.decode_jwt(token)
elif await self.verify_api_key(token, commons):
return await self.get_user_from_api_key(token, commons)
elif verify_token(token):
return decode_access_token(token)
elif await verify_api_key(token, commons):
return await get_user_from_api_key(token, commons)
else:
raise HTTPException(status_code=402, detail="Invalid token or expired token.")
def get_test_user(self):
return {'email': 'test@example.com'} # replace with test user information
def verify_jwt(self, jwtoken: str):
payload = decode_access_token(jwtoken)
return payload is not None
def decode_jwt(self, jwtoken: str):
return decode_access_token(jwtoken)
async def verify_api_key(self, api_key: str, commons: CommonsDep):
try:
# Use UTC time to avoid timezone issues
current_date = datetime.utcnow().date()
result = commons['supabase'].table('api_keys').select('api_key', 'creation_time').filter('api_key', 'eq', api_key).filter('is_active', 'eq', True).execute()
if result.data is not None and len(result.data) > 0:
api_key_creation_date = datetime.strptime(result.data[0]['creation_time'], "%Y-%m-%dT%H:%M:%S").date()
# Check if the API key was created today: Todo remove this check and use deleted_time instead.
if api_key_creation_date == current_date:
return True
return False
except DataError:
return False
async def get_user_from_api_key(self, api_key: str, commons: CommonsDep):
# Lookup the user_id from the api_keys table
user_id_data = commons['supabase'].table('api_keys').select('user_id').filter('api_key', 'eq', api_key).execute()
if not user_id_data.data:
raise HTTPException(status_code=400, detail="Invalid API key.")
user_id = user_id_data.data[0]['user_id']
# Lookup the email from the users table. Todo: remove and use user_id for credentials
user_email_data = commons['supabase'].table('users').select('email').filter('user_id', 'eq', user_id).execute()
return {'email': user_email_data.data[0]['email']} if user_email_data.data else {'email': None}
def get_current_user(credentials: dict = Depends(JWTBearer())) -> User:
def get_current_user(credentials: dict = Depends(AuthBearer())) -> User:
return User(email=credentials.get('email', 'none'))

View File

@ -24,6 +24,10 @@ def decode_access_token(token: str):
return payload
except JWTError as e:
return None
def verify_token(token: str):
payload = decode_access_token(token)
return payload is not None
def get_user_email_from_token(token: str):
payload = decode_access_token(token)

View File

@ -2,7 +2,7 @@ from datetime import datetime
import time
from typing import List
from pydantic import BaseModel
from auth.auth_bearer import JWTBearer, get_current_user
from auth.auth_bearer import AuthBearer, get_current_user
from fastapi import APIRouter, Depends
from utils.vectors import fetch_user_id_from_credentials
from models.users import User
@ -26,7 +26,7 @@ class ApiKey(BaseModel):
api_key_router = APIRouter()
@api_key_router.post("/api-key", response_model=ApiKey, dependencies=[Depends(JWTBearer())])
@api_key_router.post("/api-key", response_model=ApiKey, dependencies=[Depends(AuthBearer())])
async def create_api_key(commons: CommonsDep, current_user: User = Depends(get_current_user)):
date = time.strftime("%Y%m%d")
@ -60,7 +60,7 @@ async def create_api_key(commons: CommonsDep, current_user: User = Depends(get_c
return {"api_key": new_api_key}
@api_key_router.delete("/api-key/{key_id}", dependencies=[Depends(JWTBearer())])
@api_key_router.delete("/api-key/{key_id}", dependencies=[Depends(AuthBearer())])
async def delete_api_key(key_id: str, commons: CommonsDep, current_user: User = Depends(get_current_user)):
"""Delete (deactivate) an API key for current user."""
@ -72,7 +72,7 @@ async def delete_api_key(key_id: str, commons: CommonsDep, current_user: User =
return {"message": "API key deleted."}
@api_key_router.get("/api-keys", response_model=List[ApiKeyInfo], dependencies=[Depends(JWTBearer())])
@api_key_router.get("/api-keys", response_model=List[ApiKeyInfo], dependencies=[Depends(AuthBearer())])
async def get_api_keys(commons: CommonsDep, current_user: User = Depends(get_current_user)):
"""Get all active API keys for current user."""

View File

@ -2,7 +2,7 @@ import os
import time
from uuid import UUID
from auth.auth_bearer import JWTBearer, get_current_user
from auth.auth_bearer import AuthBearer, get_current_user
from fastapi import APIRouter, Depends, Request
from models.chats import ChatMessage
from models.users import User
@ -30,7 +30,7 @@ def fetch_user_stats(commons, user, date):
return userItem
# get all chats
@chat_router.get("/chat", dependencies=[Depends(JWTBearer())])
@chat_router.get("/chat", dependencies=[Depends(AuthBearer())])
async def get_chats(commons: CommonsDep, current_user: User = Depends(get_current_user)):
date = time.strftime("%Y%m%d")
user_id = fetch_user_id_from_credentials(commons, date, {"email": current_user.email})
@ -38,7 +38,7 @@ async def get_chats(commons: CommonsDep, current_user: User = Depends(get_curren
return {"chats": chats}
# get one chat
@chat_router.get("/chat/{chat_id}", dependencies=[Depends(JWTBearer())])
@chat_router.get("/chat/{chat_id}", dependencies=[Depends(AuthBearer())])
async def get_chats(commons: CommonsDep, chat_id: UUID):
chats = get_chat_details(commons, chat_id)
if len(chats) > 0:
@ -47,7 +47,7 @@ async def get_chats(commons: CommonsDep, chat_id: UUID):
return {"error": "Chat not found"}
# delete one chat
@chat_router.delete("/chat/{chat_id}", dependencies=[Depends(JWTBearer())])
@chat_router.delete("/chat/{chat_id}", dependencies=[Depends(AuthBearer())])
async def delete_chat(commons: CommonsDep, chat_id: UUID):
delete_chat_from_db(commons, chat_id)
return {"message": f"{chat_id} has been deleted."}
@ -91,11 +91,11 @@ def chat_handler(request, commons, chat_id, chat_message, email, is_new_chat=Fal
# update existing chat
@chat_router.put("/chat/{chat_id}", dependencies=[Depends(JWTBearer())])
@chat_router.put("/chat/{chat_id}", dependencies=[Depends(AuthBearer())])
async def chat_endpoint(request: Request, commons: CommonsDep, chat_id: UUID, chat_message: ChatMessage, current_user: User = Depends(get_current_user)):
return chat_handler(request, commons, chat_id, chat_message, current_user.email)
# create new chat
@chat_router.post("/chat", dependencies=[Depends(JWTBearer())])
@chat_router.post("/chat", dependencies=[Depends(AuthBearer())])
async def chat_endpoint(request: Request, commons: CommonsDep, chat_message: ChatMessage, current_user: User = Depends(get_current_user)):
return chat_handler(request, commons, None, chat_message, current_user.email, is_new_chat=True)

View File

@ -2,7 +2,7 @@ import os
import shutil
from tempfile import SpooledTemporaryFile
from auth.auth_bearer import JWTBearer, get_current_user
from auth.auth_bearer import AuthBearer, get_current_user
from crawl.crawler import CrawlWebsite
from fastapi import APIRouter, Depends, Request, UploadFile
from models.users import User
@ -23,7 +23,7 @@ def get_unique_user_data(commons, user):
user_unique_vectors = [dict(t) for t in set(tuple(d.items()) for d in documents)]
return user_unique_vectors
@crawl_router.post("/crawl/", dependencies=[Depends(JWTBearer())])
@crawl_router.post("/crawl/", dependencies=[Depends(AuthBearer())])
async def crawl_endpoint(request: Request,commons: CommonsDep, crawl_website: CrawlWebsite, enable_summarization: bool = False, current_user: User = Depends(get_current_user)):
max_brain_size = os.getenv("MAX_BRAIN_SIZE")
if request.headers.get('Openai-Api-Key'):

View File

@ -1,4 +1,4 @@
from auth.auth_bearer import JWTBearer, get_current_user
from auth.auth_bearer import AuthBearer, get_current_user
from fastapi import APIRouter, Depends
from models.users import User
from utils.vectors import CommonsDep
@ -13,15 +13,15 @@ def get_unique_user_data(commons, user):
unique_data = [dict(t) for t in set(tuple(d.items()) for d in documents)]
return unique_data
@explore_router.get("/explore", dependencies=[Depends(JWTBearer())])
@explore_router.get("/explore", dependencies=[Depends(AuthBearer())])
async def explore_endpoint(commons: CommonsDep, current_user: User = Depends(get_current_user)):
unique_data = get_unique_user_data(commons, current_user)
unique_data.sort(key=lambda x: int(x['size']), reverse=True)
return {"documents": unique_data}
@explore_router.delete("/explore/{file_name}", dependencies=[Depends(JWTBearer())])
async def delete_endpoint(commons: CommonsDep, file_name: str, credentials: dict = Depends(JWTBearer())):
@explore_router.delete("/explore/{file_name}", dependencies=[Depends(AuthBearer())])
async def delete_endpoint(commons: CommonsDep, file_name: str, credentials: dict = Depends(AuthBearer())):
user = User(email=credentials.get('email', 'none'))
# Cascade delete the summary from the database first, because it has a foreign key constraint
commons['supabase'].table("summaries").delete().match(
@ -30,7 +30,7 @@ async def delete_endpoint(commons: CommonsDep, file_name: str, credentials: dict
{"metadata->>file_name": file_name, "user_id": user.email}).execute()
return {"message": f"{file_name} of user {user.email} has been deleted."}
@explore_router.get("/explore/{file_name}", dependencies=[Depends(JWTBearer())])
@explore_router.get("/explore/{file_name}", dependencies=[Depends(AuthBearer())])
async def download_endpoint(commons: CommonsDep, file_name: str, current_user: User = Depends(get_current_user)):
response = commons['supabase'].table("vectors").select(
"metadata->>file_name, metadata->>file_size, metadata->>file_extension, metadata->>file_url", "content").match({"metadata->>file_name": file_name, "user_id": current_user.email}).execute()

View File

@ -1,6 +1,6 @@
import os
from auth.auth_bearer import JWTBearer, get_current_user
from auth.auth_bearer import AuthBearer, get_current_user
from fastapi import APIRouter, Depends, Request, UploadFile
from models.users import User
from utils.file import convert_bytes, get_file_size
@ -23,7 +23,7 @@ def calculate_remaining_space(request, max_brain_size, max_brain_size_with_own_k
remaining_free_space = float(max_brain_size_with_own_key) - current_brain_size if request.headers.get('Openai-Api-Key') else float(max_brain_size) - current_brain_size
return remaining_free_space
@upload_router.post("/upload", dependencies=[Depends(JWTBearer())])
@upload_router.post("/upload", dependencies=[Depends(AuthBearer())])
async def upload_file(request: Request, commons: CommonsDep, file: UploadFile, enable_summarization: bool = False, current_user: User = Depends(get_current_user)):
max_brain_size = os.getenv("MAX_BRAIN_SIZE")
max_brain_size_with_own_key = os.getenv("MAX_BRAIN_SIZE_WITH_KEY",209715200)

View File

@ -2,7 +2,7 @@ import os
import time
from fastapi import APIRouter, Depends, Request
from auth.auth_bearer import JWTBearer, get_current_user
from auth.auth_bearer import AuthBearer, get_current_user
from models.users import User
from utils.vectors import CommonsDep
@ -27,7 +27,7 @@ def get_user_request_stats(commons, email):
'*').filter("email", "eq", email).execute()
return requests_stats.data
@user_router.get("/user", dependencies=[Depends(JWTBearer())])
@user_router.get("/user", dependencies=[Depends(AuthBearer())])
async def get_user_endpoint(request: Request, commons: CommonsDep, current_user: User = Depends(get_current_user)):
user_vectors = get_user_vectors(commons, current_user.email)