diff --git a/backend/core/models/user_identity.py b/backend/core/models/user_identity.py new file mode 100644 index 000000000..aac40b8cd --- /dev/null +++ b/backend/core/models/user_identity.py @@ -0,0 +1,9 @@ +from typing import Optional +from uuid import UUID + +from pydantic import BaseModel + + +class UserIdentity(BaseModel): + user_id: UUID + openai_api_key: Optional[str] = None diff --git a/backend/core/models/users.py b/backend/core/models/users.py index 0ebc05e5b..5294f1ac7 100644 --- a/backend/core/models/users.py +++ b/backend/core/models/users.py @@ -2,19 +2,20 @@ from typing import Optional from uuid import UUID from logger import get_logger -from models.settings import common_dependencies from pydantic import BaseModel +from models.settings import common_dependencies + logger = get_logger(__name__) +# [TODO] Rename the user table and its references to 'user_usage' class User(BaseModel): id: UUID email: Optional[str] user_openai_api_key: Optional[str] = None requests_count: int = 0 - # [TODO] Rename the user table and its references to 'user_usage' def create_user(self, date): """ Create a new user entry in the database diff --git a/backend/core/repository/user_identity/create_user_identity.py b/backend/core/repository/user_identity/create_user_identity.py new file mode 100644 index 000000000..26d95e8f6 --- /dev/null +++ b/backend/core/repository/user_identity/create_user_identity.py @@ -0,0 +1,13 @@ +from models.settings import common_dependencies +from models.user_identity import UserIdentity + + +def create_user_identity(user_identity: UserIdentity) -> UserIdentity: + commons = common_dependencies() + user_identity_dict = user_identity.dict() + user_identity_dict["user_id"] = str(user_identity.user_id) + response = ( + commons["supabase"].from_("user_identity").insert(user_identity_dict).execute() + ) + + return UserIdentity(**response.data[0]) diff --git a/backend/core/repository/user_identity/get_user_identity.py b/backend/core/repository/user_identity/get_user_identity.py new file mode 100644 index 000000000..0e3c2a236 --- /dev/null +++ b/backend/core/repository/user_identity/get_user_identity.py @@ -0,0 +1,21 @@ +from uuid import UUID + +from models.settings import common_dependencies +from models.user_identity import UserIdentity +from repository.user_identity.create_user_identity import create_user_identity + + +def get_user_identity(user_id: UUID) -> UserIdentity: + commons = common_dependencies() + response = ( + commons["supabase"] + .from_("user_identity") + .select("*") + .filter("user_id", "eq", user_id) + .execute() + ) + + if len(response.data) == 0: + return create_user_identity(UserIdentity(user_id=user_id)) + + return UserIdentity(**response.data[0]) diff --git a/backend/core/repository/user_identity/update_user_identity.py b/backend/core/repository/user_identity/update_user_identity.py new file mode 100644 index 000000000..a30136e2b --- /dev/null +++ b/backend/core/repository/user_identity/update_user_identity.py @@ -0,0 +1,36 @@ +from typing import Optional +from uuid import UUID + +from models.settings import common_dependencies +from models.user_identity import UserIdentity +from pydantic import BaseModel +from repository.user_identity.create_user_identity import ( + create_user_identity, +) + + +class UserIdentityUpdatableProperties(BaseModel): + openai_api_key: Optional[str] + + +def update_user_identity( + user_id: UUID, + user_identity_updatable_properties: UserIdentityUpdatableProperties, +) -> UserIdentity: + commons = common_dependencies() + response = ( + commons["supabase"] + .from_("user_identity") + .update(user_identity_updatable_properties.__dict__) + .filter("user_id", "eq", user_id) + .execute() + ) + + if len(response.data) == 0: + user_identity = UserIdentity( + user_id=user_id, + openai_api_key=user_identity_updatable_properties.openai_api_key, + ) + return create_user_identity(user_identity) + + return UserIdentity(**response.data[0]) diff --git a/backend/core/routes/user_routes.py b/backend/core/routes/user_routes.py index ae0b68c8f..c54fc455f 100644 --- a/backend/core/routes/user_routes.py +++ b/backend/core/routes/user_routes.py @@ -5,7 +5,13 @@ from auth import AuthBearer, get_current_user from fastapi import APIRouter, Depends, Request from models.brains import Brain, get_default_user_brain from models.settings import BrainRateLimiting +from models.user_identity import UserIdentity from models.users import User +from repository.user_identity.get_user_identity import get_user_identity +from repository.user_identity.update_user_identity import ( + UserIdentityUpdatableProperties, + update_user_identity, +) user_router = APIRouter() @@ -56,3 +62,32 @@ async def get_user_endpoint( "requests_stats": requests_stats, "date": date, } + + +@user_router.put( + "/user/identity", + dependencies=[Depends(AuthBearer())], + tags=["User"], +) +def update_user_identity_route( + user_identity_updatable_properties: UserIdentityUpdatableProperties, + current_user: User = Depends(get_current_user), +) -> UserIdentity: + """ + Update user identity. + """ + return update_user_identity(current_user.id, user_identity_updatable_properties) + + +@user_router.get( + "/user/identity", + dependencies=[Depends(AuthBearer())], + tags=["User"], +) +def get_user_identity_route( + current_user: User = Depends(get_current_user), +) -> UserIdentity: + """ + Get user identity. + """ + return get_user_identity(current_user.id) diff --git a/scripts/20230731172400_add_user_identity_table.sql b/scripts/20230731172400_add_user_identity_table.sql new file mode 100644 index 000000000..ae5195eb0 --- /dev/null +++ b/scripts/20230731172400_add_user_identity_table.sql @@ -0,0 +1,16 @@ +BEGIN; + +-- Create user_identity table if it doesn't exist +CREATE TABLE IF NOT EXISTS user_identity ( + user_id UUID PRIMARY KEY, + openai_api_key VARCHAR(255) +); + +-- Insert migration record if it doesn't exist +INSERT INTO migrations (name) +SELECT '20230731172400_add_user_identity_table' +WHERE NOT EXISTS ( + SELECT 1 FROM migrations WHERE name = '20230731172400_add_user_identity_table' +); + +COMMIT; diff --git a/scripts/tables.sql b/scripts/tables.sql index 2f0816baa..12c84fb41 100644 --- a/scripts/tables.sql +++ b/scripts/tables.sql @@ -167,6 +167,12 @@ CREATE TABLE IF NOT EXISTS brain_subscription_invitations ( FOREIGN KEY (brain_id) REFERENCES brains (brain_id) ); +--- Create user_identity table +CREATE TABLE IF NOT EXISTS user_identity ( + user_id UUID PRIMARY KEY, + openai_api_key VARCHAR(255) +); + CREATE OR REPLACE FUNCTION public.get_user_email_by_user_id(user_id uuid) RETURNS TABLE (email text) SECURITY definer @@ -194,7 +200,7 @@ CREATE TABLE IF NOT EXISTS migrations ( ); INSERT INTO migrations (name) -SELECT '202307241530031_add_fields_to_brain' +SELECT '20230731172400_add_user_identity_table' WHERE NOT EXISTS ( - SELECT 1 FROM migrations WHERE name = '202307241530031_add_fields_to_brain' + SELECT 1 FROM migrations WHERE name = '20230731172400_add_user_identity_table' );