mirror of
https://github.com/QuivrHQ/quivr.git
synced 2024-12-15 09:32:22 +03:00
feat(sync): retrieve user email used for the connection (#2628)
# Description Please include a summary of the changes and the related issue. Please also include relevant motivation and context. ## Checklist before requesting a review Please delete options that are not relevant. - [ ] My code follows the style guidelines of this project - [ ] I have performed a self-review of my code - [ ] I have commented hard-to-understand areas - [ ] I have ideally added tests that prove my fix is effective or that my feature works - [ ] New and existing unit tests pass locally with my changes - [ ] Any dependent changes have been merged ## Screenshots (if appropriate):
This commit is contained in:
parent
a89db0cd5a
commit
043bcd17ce
@ -1,5 +1,6 @@
|
||||
import os
|
||||
|
||||
import requests
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request
|
||||
from logger import get_logger
|
||||
from middlewares.auth import AuthBearer, get_current_user
|
||||
@ -101,13 +102,27 @@ def oauth2callback_azure(request: Request):
|
||||
logger.error("Failed to acquire token")
|
||||
raise HTTPException(status_code=400, detail="Failed to acquire token")
|
||||
|
||||
access_token = result["access_token"]
|
||||
|
||||
creds = result
|
||||
logger.info(f"Fetched OAuth2 token for user: {current_user}")
|
||||
|
||||
# Fetch user email from Microsoft Graph API
|
||||
graph_url = "https://graph.microsoft.com/v1.0/me"
|
||||
headers = {"Authorization": f"Bearer {access_token}"}
|
||||
response = requests.get(graph_url, headers=headers)
|
||||
if response.status_code != 200:
|
||||
logger.error("Failed to fetch user profile from Microsoft Graph API")
|
||||
raise HTTPException(status_code=400, detail="Failed to fetch user profile")
|
||||
|
||||
user_info = response.json()
|
||||
user_email = user_info.get("mail") or user_info.get("userPrincipalName")
|
||||
logger.info(f"Retrieved email for user: {current_user} - {user_email}")
|
||||
|
||||
sync_user_input = SyncUserUpdateInput(
|
||||
credentials=creds,
|
||||
state={},
|
||||
credentials=result, state={}, email=user_email
|
||||
)
|
||||
|
||||
sync_user_service.update_sync_user(current_user, state_dict, sync_user_input)
|
||||
logger.info(f"Azure sync created successfully for user: {current_user}")
|
||||
return {"message": "Azure sync created successfully"}
|
||||
|
@ -3,6 +3,7 @@ import os
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request
|
||||
from google_auth_oauthlib.flow import Flow
|
||||
from googleapiclient.discovery import build
|
||||
from logger import get_logger
|
||||
from middlewares.auth import AuthBearer, get_current_user
|
||||
from modules.sync.dto.inputs import SyncsUserInput, SyncUserUpdateInput
|
||||
@ -26,6 +27,8 @@ google_sync_router = APIRouter()
|
||||
SCOPES = [
|
||||
"https://www.googleapis.com/auth/drive.metadata.readonly",
|
||||
"https://www.googleapis.com/auth/drive.readonly",
|
||||
"https://www.googleapis.com/auth/userinfo.email",
|
||||
"openid",
|
||||
]
|
||||
BACKEND_URL = os.getenv("BACKEND_URL", "http://localhost:5050")
|
||||
BASE_REDIRECT_URI = f"{BACKEND_URL}/sync/google/oauth2callback"
|
||||
@ -104,7 +107,7 @@ def oauth2callback_google(request: Request):
|
||||
sync_user_state = sync_user_service.get_sync_user_by_state(state_dict)
|
||||
logger.info(f"Retrieved sync user state: {sync_user_state}")
|
||||
|
||||
if state_dict != sync_user_state["state"]:
|
||||
if not sync_user_state or state_dict != sync_user_state.get("state"):
|
||||
logger.error("Invalid state parameter")
|
||||
raise HTTPException(status_code=400, detail="Invalid state parameter")
|
||||
if sync_user_state.get("user_id") != current_user:
|
||||
@ -122,9 +125,16 @@ def oauth2callback_google(request: Request):
|
||||
creds = flow.credentials
|
||||
logger.info(f"Fetched OAuth2 token for user: {current_user}")
|
||||
|
||||
# Use the credentials to get the user's email
|
||||
service = build("oauth2", "v2", credentials=creds)
|
||||
user_info = service.userinfo().get().execute()
|
||||
user_email = user_info.get("email")
|
||||
logger.info(f"Retrieved email for user: {current_user} - {user_email}")
|
||||
|
||||
sync_user_input = SyncUserUpdateInput(
|
||||
credentials=json.loads(creds.to_json()),
|
||||
state={},
|
||||
email=user_email,
|
||||
)
|
||||
sync_user_service.update_sync_user(current_user, state_dict, sync_user_input)
|
||||
logger.info(f"Google Drive sync created successfully for user: {current_user}")
|
||||
|
@ -83,6 +83,29 @@ async def get_user_syncs(current_user: UserIdentity = Depends(get_current_user))
|
||||
logger.debug(f"Fetching user syncs for user: {current_user.id}")
|
||||
return sync_user_service.get_syncs_user(str(current_user.id))
|
||||
|
||||
@sync_router.delete(
|
||||
"/sync/{sync_id}",
|
||||
status_code=status.HTTP_204_NO_CONTENT,
|
||||
dependencies=[Depends(AuthBearer())],
|
||||
tags=["Sync"],
|
||||
)
|
||||
async def delete_user_sync(
|
||||
sync_id: int, current_user: UserIdentity = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Delete a sync for the current user.
|
||||
|
||||
Args:
|
||||
sync_id (int): The ID of the sync to delete.
|
||||
current_user (UserIdentity): The current authenticated user.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
logger.debug(f"Deleting user sync for user: {current_user.id} with sync ID: {sync_id}")
|
||||
sync_user_service.delete_sync_user(sync_id, str(current_user.id))
|
||||
return None
|
||||
|
||||
|
||||
@sync_router.post(
|
||||
"/sync/active",
|
||||
|
@ -33,6 +33,7 @@ class SyncUserUpdateInput(BaseModel):
|
||||
|
||||
credentials: dict
|
||||
state: dict
|
||||
email: str
|
||||
|
||||
|
||||
class SyncActiveSettings(BaseModel):
|
||||
|
@ -118,7 +118,7 @@ class SyncUser(SyncUserInterface):
|
||||
logger.warning("No sync user found for state: %s", state)
|
||||
return []
|
||||
|
||||
def delete_sync_user(self, provider: str, user_id: str):
|
||||
def delete_sync_user(self, sync_id: str, user_id: str):
|
||||
"""
|
||||
Delete a sync user from the database.
|
||||
|
||||
@ -127,9 +127,9 @@ class SyncUser(SyncUserInterface):
|
||||
user_id (str): The user ID of the sync user.
|
||||
"""
|
||||
logger.info(
|
||||
"Deleting sync user with provider: %s, user_id: %s", provider, user_id
|
||||
"Deleting sync user with sync_id: %s, user_id: %s", sync_id, user_id
|
||||
)
|
||||
self.db.from_("syncs_user").delete().eq("provider", provider).eq(
|
||||
self.db.from_("syncs_user").delete().eq("id", sync_id).eq(
|
||||
"user_id", user_id
|
||||
).execute()
|
||||
logger.info("Sync user deleted successfully")
|
||||
|
@ -31,8 +31,8 @@ class SyncUserService:
|
||||
def create_sync_user(self, sync_user_input: SyncsUserInput):
|
||||
return self.repository.create_sync_user(sync_user_input)
|
||||
|
||||
def delete_sync_user(self, provider: str, user_id: str):
|
||||
return self.repository.delete_sync_user(provider, user_id)
|
||||
def delete_sync_user(self, sync_id: str, user_id: str):
|
||||
return self.repository.delete_sync_user(sync_id, user_id)
|
||||
|
||||
def get_sync_user_by_state(self, state: dict):
|
||||
return self.repository.get_sync_user_by_state(state)
|
||||
|
@ -0,0 +1,3 @@
|
||||
alter table "public"."syncs_user" add column "email" text;
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user