mirror of
https://github.com/QuivrHQ/quivr.git
synced 2024-09-11 14:36:35 +03:00
chore: rename and refactor AuthBearer
This commit is contained in:
parent
0bc4221b2c
commit
c0b877c94e
36
backend/auth/api_key_handler.py
Normal file
36
backend/auth/api_key_handler.py
Normal file
@ -0,0 +1,36 @@
|
||||
|
||||
from datetime import datetime
|
||||
from fastapi import HTTPException
|
||||
|
||||
from pydantic import DateError
|
||||
from utils.vectors import CommonsDep
|
||||
|
||||
|
||||
async def verify_api_key(api_key: str, commons: CommonsDep):
|
||||
try:
|
||||
# Use UTC time to avoid timezone issues
|
||||
current_date = datetime.utcnow().date()
|
||||
result = commons['supabase'].table('api_keys').select('api_key', 'creation_time').filter('api_key', 'eq', api_key).filter('is_active', 'eq', True).execute()
|
||||
if result.data is not None and len(result.data) > 0:
|
||||
api_key_creation_date = datetime.strptime(result.data[0]['creation_time'], "%Y-%m-%dT%H:%M:%S").date()
|
||||
|
||||
# Check if the API key was created today: Todo remove this check and use deleted_time instead.
|
||||
if api_key_creation_date == current_date:
|
||||
return True
|
||||
return False
|
||||
except DateError:
|
||||
return False
|
||||
|
||||
async def get_user_from_api_key(api_key: str, commons: CommonsDep):
|
||||
# Lookup the user_id from the api_keys table
|
||||
user_id_data = commons['supabase'].table('api_keys').select('user_id').filter('api_key', 'eq', api_key).execute()
|
||||
|
||||
if not user_id_data.data:
|
||||
raise HTTPException(status_code=400, detail="Invalid API key.")
|
||||
|
||||
user_id = user_id_data.data[0]['user_id']
|
||||
|
||||
# Lookup the email from the users table. Todo: remove and use user_id for credentials
|
||||
user_email_data = commons['supabase'].table('users').select('email').filter('user_id', 'eq', user_id).execute()
|
||||
|
||||
return {'email': user_email_data.data[0]['email']} if user_email_data.data else {'email': None}
|
@ -1,14 +1,15 @@
|
||||
import os
|
||||
|
||||
from auth.api_key_handler import verify_api_key, get_user_from_api_key
|
||||
from auth.jwt_token_handler import decode_access_token, verify_token
|
||||
|
||||
from typing import Optional
|
||||
from fastapi import Depends, Request, HTTPException
|
||||
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
|
||||
import os
|
||||
from models.users import User
|
||||
from utils.vectors import CommonsDep
|
||||
from asyncpg.exceptions import DataError
|
||||
from auth.auth_handler import decode_access_token
|
||||
from datetime import datetime
|
||||
|
||||
class JWTBearer(HTTPBearer):
|
||||
class AuthBearer(HTTPBearer):
|
||||
def __init__(self, auto_error: bool = True):
|
||||
super().__init__(auto_error=auto_error)
|
||||
|
||||
@ -27,54 +28,15 @@ class JWTBearer(HTTPBearer):
|
||||
async def authenticate(self, token, commons):
|
||||
if os.environ.get("AUTHENTICATE") == "false":
|
||||
return self.get_test_user()
|
||||
elif self.verify_jwt(token):
|
||||
return self.decode_jwt(token)
|
||||
elif await self.verify_api_key(token, commons):
|
||||
return await self.get_user_from_api_key(token, commons)
|
||||
elif verify_token(token):
|
||||
return decode_access_token(token)
|
||||
elif await verify_api_key(token, commons):
|
||||
return await get_user_from_api_key(token, commons)
|
||||
else:
|
||||
raise HTTPException(status_code=402, detail="Invalid token or expired token.")
|
||||
|
||||
def get_test_user(self):
|
||||
return {'email': 'test@example.com'} # replace with test user information
|
||||
|
||||
def verify_jwt(self, jwtoken: str):
|
||||
payload = decode_access_token(jwtoken)
|
||||
return payload is not None
|
||||
|
||||
def decode_jwt(self, jwtoken: str):
|
||||
return decode_access_token(jwtoken)
|
||||
|
||||
async def verify_api_key(self, api_key: str, commons: CommonsDep):
|
||||
try:
|
||||
# Use UTC time to avoid timezone issues
|
||||
current_date = datetime.utcnow().date()
|
||||
result = commons['supabase'].table('api_keys').select('api_key', 'creation_time').filter('api_key', 'eq', api_key).filter('is_active', 'eq', True).execute()
|
||||
if result.data is not None and len(result.data) > 0:
|
||||
api_key_creation_date = datetime.strptime(result.data[0]['creation_time'], "%Y-%m-%dT%H:%M:%S").date()
|
||||
|
||||
# Check if the API key was created today: Todo remove this check and use deleted_time instead.
|
||||
if api_key_creation_date == current_date:
|
||||
return True
|
||||
return False
|
||||
except DataError:
|
||||
return False
|
||||
|
||||
|
||||
async def get_user_from_api_key(self, api_key: str, commons: CommonsDep):
|
||||
# Lookup the user_id from the api_keys table
|
||||
user_id_data = commons['supabase'].table('api_keys').select('user_id').filter('api_key', 'eq', api_key).execute()
|
||||
|
||||
if not user_id_data.data:
|
||||
raise HTTPException(status_code=400, detail="Invalid API key.")
|
||||
|
||||
user_id = user_id_data.data[0]['user_id']
|
||||
|
||||
# Lookup the email from the users table. Todo: remove and use user_id for credentials
|
||||
user_email_data = commons['supabase'].table('users').select('email').filter('user_id', 'eq', user_id).execute()
|
||||
|
||||
return {'email': user_email_data.data[0]['email']} if user_email_data.data else {'email': None}
|
||||
|
||||
|
||||
|
||||
def get_current_user(credentials: dict = Depends(JWTBearer())) -> User:
|
||||
def get_current_user(credentials: dict = Depends(AuthBearer())) -> User:
|
||||
return User(email=credentials.get('email', 'none'))
|
||||
|
@ -24,6 +24,10 @@ def decode_access_token(token: str):
|
||||
return payload
|
||||
except JWTError as e:
|
||||
return None
|
||||
|
||||
def verify_token(token: str):
|
||||
payload = decode_access_token(token)
|
||||
return payload is not None
|
||||
|
||||
def get_user_email_from_token(token: str):
|
||||
payload = decode_access_token(token)
|
@ -2,7 +2,7 @@ from datetime import datetime
|
||||
import time
|
||||
from typing import List
|
||||
from pydantic import BaseModel
|
||||
from auth.auth_bearer import JWTBearer, get_current_user
|
||||
from auth.auth_bearer import AuthBearer, get_current_user
|
||||
from fastapi import APIRouter, Depends
|
||||
from utils.vectors import fetch_user_id_from_credentials
|
||||
from models.users import User
|
||||
@ -26,7 +26,7 @@ class ApiKey(BaseModel):
|
||||
api_key_router = APIRouter()
|
||||
|
||||
|
||||
@api_key_router.post("/api-key", response_model=ApiKey, dependencies=[Depends(JWTBearer())])
|
||||
@api_key_router.post("/api-key", response_model=ApiKey, dependencies=[Depends(AuthBearer())])
|
||||
async def create_api_key(commons: CommonsDep, current_user: User = Depends(get_current_user)):
|
||||
|
||||
date = time.strftime("%Y%m%d")
|
||||
@ -60,7 +60,7 @@ async def create_api_key(commons: CommonsDep, current_user: User = Depends(get_c
|
||||
return {"api_key": new_api_key}
|
||||
|
||||
|
||||
@api_key_router.delete("/api-key/{key_id}", dependencies=[Depends(JWTBearer())])
|
||||
@api_key_router.delete("/api-key/{key_id}", dependencies=[Depends(AuthBearer())])
|
||||
async def delete_api_key(key_id: str, commons: CommonsDep, current_user: User = Depends(get_current_user)):
|
||||
"""Delete (deactivate) an API key for current user."""
|
||||
|
||||
@ -72,7 +72,7 @@ async def delete_api_key(key_id: str, commons: CommonsDep, current_user: User =
|
||||
return {"message": "API key deleted."}
|
||||
|
||||
|
||||
@api_key_router.get("/api-keys", response_model=List[ApiKeyInfo], dependencies=[Depends(JWTBearer())])
|
||||
@api_key_router.get("/api-keys", response_model=List[ApiKeyInfo], dependencies=[Depends(AuthBearer())])
|
||||
async def get_api_keys(commons: CommonsDep, current_user: User = Depends(get_current_user)):
|
||||
"""Get all active API keys for current user."""
|
||||
|
||||
|
@ -2,7 +2,7 @@ import os
|
||||
import time
|
||||
from uuid import UUID
|
||||
|
||||
from auth.auth_bearer import JWTBearer, get_current_user
|
||||
from auth.auth_bearer import AuthBearer, get_current_user
|
||||
from fastapi import APIRouter, Depends, Request
|
||||
from models.chats import ChatMessage
|
||||
from models.users import User
|
||||
@ -30,7 +30,7 @@ def fetch_user_stats(commons, user, date):
|
||||
return userItem
|
||||
|
||||
# get all chats
|
||||
@chat_router.get("/chat", dependencies=[Depends(JWTBearer())])
|
||||
@chat_router.get("/chat", dependencies=[Depends(AuthBearer())])
|
||||
async def get_chats(commons: CommonsDep, current_user: User = Depends(get_current_user)):
|
||||
date = time.strftime("%Y%m%d")
|
||||
user_id = fetch_user_id_from_credentials(commons, date, {"email": current_user.email})
|
||||
@ -38,7 +38,7 @@ async def get_chats(commons: CommonsDep, current_user: User = Depends(get_curren
|
||||
return {"chats": chats}
|
||||
|
||||
# get one chat
|
||||
@chat_router.get("/chat/{chat_id}", dependencies=[Depends(JWTBearer())])
|
||||
@chat_router.get("/chat/{chat_id}", dependencies=[Depends(AuthBearer())])
|
||||
async def get_chats(commons: CommonsDep, chat_id: UUID):
|
||||
chats = get_chat_details(commons, chat_id)
|
||||
if len(chats) > 0:
|
||||
@ -47,7 +47,7 @@ async def get_chats(commons: CommonsDep, chat_id: UUID):
|
||||
return {"error": "Chat not found"}
|
||||
|
||||
# delete one chat
|
||||
@chat_router.delete("/chat/{chat_id}", dependencies=[Depends(JWTBearer())])
|
||||
@chat_router.delete("/chat/{chat_id}", dependencies=[Depends(AuthBearer())])
|
||||
async def delete_chat(commons: CommonsDep, chat_id: UUID):
|
||||
delete_chat_from_db(commons, chat_id)
|
||||
return {"message": f"{chat_id} has been deleted."}
|
||||
@ -91,11 +91,11 @@ def chat_handler(request, commons, chat_id, chat_message, email, is_new_chat=Fal
|
||||
|
||||
|
||||
# update existing chat
|
||||
@chat_router.put("/chat/{chat_id}", dependencies=[Depends(JWTBearer())])
|
||||
@chat_router.put("/chat/{chat_id}", dependencies=[Depends(AuthBearer())])
|
||||
async def chat_endpoint(request: Request, commons: CommonsDep, chat_id: UUID, chat_message: ChatMessage, current_user: User = Depends(get_current_user)):
|
||||
return chat_handler(request, commons, chat_id, chat_message, current_user.email)
|
||||
|
||||
# create new chat
|
||||
@chat_router.post("/chat", dependencies=[Depends(JWTBearer())])
|
||||
@chat_router.post("/chat", dependencies=[Depends(AuthBearer())])
|
||||
async def chat_endpoint(request: Request, commons: CommonsDep, chat_message: ChatMessage, current_user: User = Depends(get_current_user)):
|
||||
return chat_handler(request, commons, None, chat_message, current_user.email, is_new_chat=True)
|
||||
|
@ -2,7 +2,7 @@ import os
|
||||
import shutil
|
||||
from tempfile import SpooledTemporaryFile
|
||||
|
||||
from auth.auth_bearer import JWTBearer, get_current_user
|
||||
from auth.auth_bearer import AuthBearer, get_current_user
|
||||
from crawl.crawler import CrawlWebsite
|
||||
from fastapi import APIRouter, Depends, Request, UploadFile
|
||||
from models.users import User
|
||||
@ -23,7 +23,7 @@ def get_unique_user_data(commons, user):
|
||||
user_unique_vectors = [dict(t) for t in set(tuple(d.items()) for d in documents)]
|
||||
return user_unique_vectors
|
||||
|
||||
@crawl_router.post("/crawl/", dependencies=[Depends(JWTBearer())])
|
||||
@crawl_router.post("/crawl/", dependencies=[Depends(AuthBearer())])
|
||||
async def crawl_endpoint(request: Request,commons: CommonsDep, crawl_website: CrawlWebsite, enable_summarization: bool = False, current_user: User = Depends(get_current_user)):
|
||||
max_brain_size = os.getenv("MAX_BRAIN_SIZE")
|
||||
if request.headers.get('Openai-Api-Key'):
|
||||
|
@ -1,4 +1,4 @@
|
||||
from auth.auth_bearer import JWTBearer, get_current_user
|
||||
from auth.auth_bearer import AuthBearer, get_current_user
|
||||
from fastapi import APIRouter, Depends
|
||||
from models.users import User
|
||||
from utils.vectors import CommonsDep
|
||||
@ -13,15 +13,15 @@ def get_unique_user_data(commons, user):
|
||||
unique_data = [dict(t) for t in set(tuple(d.items()) for d in documents)]
|
||||
return unique_data
|
||||
|
||||
@explore_router.get("/explore", dependencies=[Depends(JWTBearer())])
|
||||
@explore_router.get("/explore", dependencies=[Depends(AuthBearer())])
|
||||
async def explore_endpoint(commons: CommonsDep, current_user: User = Depends(get_current_user)):
|
||||
unique_data = get_unique_user_data(commons, current_user)
|
||||
unique_data.sort(key=lambda x: int(x['size']), reverse=True)
|
||||
return {"documents": unique_data}
|
||||
|
||||
|
||||
@explore_router.delete("/explore/{file_name}", dependencies=[Depends(JWTBearer())])
|
||||
async def delete_endpoint(commons: CommonsDep, file_name: str, credentials: dict = Depends(JWTBearer())):
|
||||
@explore_router.delete("/explore/{file_name}", dependencies=[Depends(AuthBearer())])
|
||||
async def delete_endpoint(commons: CommonsDep, file_name: str, credentials: dict = Depends(AuthBearer())):
|
||||
user = User(email=credentials.get('email', 'none'))
|
||||
# Cascade delete the summary from the database first, because it has a foreign key constraint
|
||||
commons['supabase'].table("summaries").delete().match(
|
||||
@ -30,7 +30,7 @@ async def delete_endpoint(commons: CommonsDep, file_name: str, credentials: dict
|
||||
{"metadata->>file_name": file_name, "user_id": user.email}).execute()
|
||||
return {"message": f"{file_name} of user {user.email} has been deleted."}
|
||||
|
||||
@explore_router.get("/explore/{file_name}", dependencies=[Depends(JWTBearer())])
|
||||
@explore_router.get("/explore/{file_name}", dependencies=[Depends(AuthBearer())])
|
||||
async def download_endpoint(commons: CommonsDep, file_name: str, current_user: User = Depends(get_current_user)):
|
||||
response = commons['supabase'].table("vectors").select(
|
||||
"metadata->>file_name, metadata->>file_size, metadata->>file_extension, metadata->>file_url", "content").match({"metadata->>file_name": file_name, "user_id": current_user.email}).execute()
|
||||
|
@ -1,6 +1,6 @@
|
||||
import os
|
||||
|
||||
from auth.auth_bearer import JWTBearer, get_current_user
|
||||
from auth.auth_bearer import AuthBearer, get_current_user
|
||||
from fastapi import APIRouter, Depends, Request, UploadFile
|
||||
from models.users import User
|
||||
from utils.file import convert_bytes, get_file_size
|
||||
@ -23,7 +23,7 @@ def calculate_remaining_space(request, max_brain_size, max_brain_size_with_own_k
|
||||
remaining_free_space = float(max_brain_size_with_own_key) - current_brain_size if request.headers.get('Openai-Api-Key') else float(max_brain_size) - current_brain_size
|
||||
return remaining_free_space
|
||||
|
||||
@upload_router.post("/upload", dependencies=[Depends(JWTBearer())])
|
||||
@upload_router.post("/upload", dependencies=[Depends(AuthBearer())])
|
||||
async def upload_file(request: Request, commons: CommonsDep, file: UploadFile, enable_summarization: bool = False, current_user: User = Depends(get_current_user)):
|
||||
max_brain_size = os.getenv("MAX_BRAIN_SIZE")
|
||||
max_brain_size_with_own_key = os.getenv("MAX_BRAIN_SIZE_WITH_KEY",209715200)
|
||||
|
@ -2,7 +2,7 @@ import os
|
||||
import time
|
||||
|
||||
from fastapi import APIRouter, Depends, Request
|
||||
from auth.auth_bearer import JWTBearer, get_current_user
|
||||
from auth.auth_bearer import AuthBearer, get_current_user
|
||||
from models.users import User
|
||||
from utils.vectors import CommonsDep
|
||||
|
||||
@ -27,7 +27,7 @@ def get_user_request_stats(commons, email):
|
||||
'*').filter("email", "eq", email).execute()
|
||||
return requests_stats.data
|
||||
|
||||
@user_router.get("/user", dependencies=[Depends(JWTBearer())])
|
||||
@user_router.get("/user", dependencies=[Depends(AuthBearer())])
|
||||
async def get_user_endpoint(request: Request, commons: CommonsDep, current_user: User = Depends(get_current_user)):
|
||||
|
||||
user_vectors = get_user_vectors(commons, current_user.email)
|
||||
|
Loading…
Reference in New Issue
Block a user