From 47c6e24bf1a7f7d96d436eb76b41d3180d2bdedf Mon Sep 17 00:00:00 2001 From: Stan Girard Date: Sat, 8 Jun 2024 11:55:11 +0200 Subject: [PATCH] fix: integrations (#2642) This pull request adds support for recursive folder retrieval in the `get_files_folder_user_sync` method. Previously, the method only retrieved files from the specified folder, but now it can also retrieve files from all subfolders recursively. This enhancement improves the functionality and flexibility of the method, allowing for more comprehensive file retrieval in sync operations. --- backend/modules/sync/dto/inputs.py | 4 +- backend/modules/sync/entity/sync.py | 10 + backend/modules/sync/repository/sync.py | 5 +- backend/modules/sync/repository/sync_files.py | 6 +- .../sync/repository/sync_interfaces.py | 6 +- backend/modules/sync/repository/sync_user.py | 10 +- backend/modules/sync/service/sync_service.py | 10 +- backend/modules/sync/utils/googleutils.py | 60 ++++- backend/modules/sync/utils/list_files.py | 221 ++++++++++++++---- backend/modules/sync/utils/sharepointutils.py | 77 ++++-- .../20240608095352_supported-sync.sql | 2 + backend/tach.yml | 124 ++++++++++ 12 files changed, 457 insertions(+), 78 deletions(-) create mode 100644 backend/supabase/migrations/20240608095352_supported-sync.sql create mode 100644 backend/tach.yml diff --git a/backend/modules/sync/dto/inputs.py b/backend/modules/sync/dto/inputs.py index b5e363271..1d6a0f9af 100644 --- a/backend/modules/sync/dto/inputs.py +++ b/backend/modules/sync/dto/inputs.py @@ -93,6 +93,7 @@ class SyncFileInput(BaseModel): syncs_active_id: int last_modified: str brain_id: str + supported: Optional[bool] = True class SyncFileUpdateInput(BaseModel): @@ -103,4 +104,5 @@ class SyncFileUpdateInput(BaseModel): last_modified (datetime.datetime): The updated last modified date and time. 
""" - last_modified: str + last_modified: Optional[str] = None + supported: Optional[bool] = None diff --git a/backend/modules/sync/entity/sync.py b/backend/modules/sync/entity/sync.py index 12e1b2efe..f0f320376 100644 --- a/backend/modules/sync/entity/sync.py +++ b/backend/modules/sync/entity/sync.py @@ -1,8 +1,16 @@ from datetime import datetime +from typing import Optional from pydantic import BaseModel +class SyncsUser(BaseModel): + id: int + user_id: str + name: str + provider: str + + class SyncsActive(BaseModel): id: int name: str @@ -12,6 +20,7 @@ class SyncsActive(BaseModel): last_synced: datetime sync_interval_minutes: int brain_id: str + syncs_user: Optional[SyncsUser] class SyncsFiles(BaseModel): @@ -20,3 +29,4 @@ class SyncsFiles(BaseModel): syncs_active_id: int last_modified: str brain_id: str + supported: bool diff --git a/backend/modules/sync/repository/sync.py b/backend/modules/sync/repository/sync.py index 3b10e9285..8fcab62f5 100644 --- a/backend/modules/sync/repository/sync.py +++ b/backend/modules/sync/repository/sync.py @@ -64,7 +64,10 @@ class Sync(SyncInterface): """ logger.info("Retrieving active syncs for user_id: %s", user_id) response = ( - self.db.from_("syncs_active").select("*").eq("user_id", user_id).execute() + self.db.from_("syncs_active") + .select("*, syncs_user(*)") + .eq("user_id", user_id) + .execute() ) if response.data: logger.info("Active syncs retrieved successfully: %s", response.data) diff --git a/backend/modules/sync/repository/sync_files.py b/backend/modules/sync/repository/sync_files.py index a2a73d5cc..d7d166970 100644 --- a/backend/modules/sync/repository/sync_files.py +++ b/backend/modules/sync/repository/sync_files.py @@ -81,9 +81,9 @@ class SyncFiles(SyncFileInterface): sync_file_id, sync_file_input, ) - self.db.from_("syncs_files").update(sync_file_input.model_dump()).eq( - "id", sync_file_id - ).execute() + self.db.from_("syncs_files").update( + sync_file_input.model_dump(exclude_unset=True) + ).eq("id", sync_file_id).execute() logger.info("Sync file updated successfully") def delete_sync_file(self, sync_file_id: int): diff --git a/backend/modules/sync/repository/sync_interfaces.py b/backend/modules/sync/repository/sync_interfaces.py index b46c53ed2..e0ef1ba72 100644 --- a/backend/modules/sync/repository/sync_interfaces.py +++ b/backend/modules/sync/repository/sync_interfaces.py @@ -45,7 +45,11 @@ class SyncUserInterface(ABC): @abstractmethod def get_files_folder_user_sync( - self, sync_active_id: int, user_id: str, folder_id: int = None + self, + sync_active_id: int, + user_id: str, + folder_id: int = None, + recursive: bool = False, ): pass diff --git a/backend/modules/sync/repository/sync_user.py b/backend/modules/sync/repository/sync_user.py index 861c28ed0..387b2a194 100644 --- a/backend/modules/sync/repository/sync_user.py +++ b/backend/modules/sync/repository/sync_user.py @@ -159,7 +159,7 @@ class SyncUser(SyncUserInterface): logger.info("Sync user updated successfully") def get_files_folder_user_sync( - self, sync_active_id: int, user_id: str, folder_id: str = None + self, sync_active_id: int, user_id: str, folder_id: str = None, recursive: bool = False ): """ Retrieve files from a user's sync folder, either from Google Drive or Azure. 
@@ -195,10 +195,12 @@ class SyncUser(SyncUserInterface):
         provider = sync_user["provider"].lower()
         if provider == "google":
             logger.info("Getting files for Google sync")
-            return get_google_drive_files(sync_user["credentials"], folder_id)
+            return {
+                "files": get_google_drive_files(sync_user["credentials"], folder_id)
+            }
         elif provider == "azure":
             logger.info("Getting files for Azure sync")
-            return list_azure_files(sync_user["credentials"], folder_id)
+            return {"files": list_azure_files(sync_user["credentials"], folder_id, recursive)}
         else:
-            logger.warning("No sync found for provider: %s", sync_user["provider"])
+            logger.warning("No sync found for provider: %s (recursive=%s)", sync_user["provider"], recursive)
             return "No sync found"
diff --git a/backend/modules/sync/service/sync_service.py b/backend/modules/sync/service/sync_service.py
index 6ee25aff9..6e2ad6f91 100644
--- a/backend/modules/sync/service/sync_service.py
+++ b/backend/modules/sync/service/sync_service.py
@@ -36,7 +36,7 @@ class SyncUserService:
 
     def get_sync_user_by_state(self, state: dict):
         return self.repository.get_sync_user_by_state(state)
-    
+
     def get_sync_user_by_id(self, sync_id: int):
         return self.repository.get_sync_user_by_id(sync_id)
 
@@ -46,10 +46,14 @@ class SyncUserService:
         return self.repository.update_sync_user(sync_user_id, state, sync_user_input)
 
     def get_files_folder_user_sync(
-        self, sync_active_id: int, user_id: str, folder_id: str = None
+        self,
+        sync_active_id: int,
+        user_id: str,
+        folder_id: str = None,
+        recursive: bool = False,
     ):
         return self.repository.get_files_folder_user_sync(
-            sync_active_id, user_id, folder_id
+            sync_active_id, user_id, folder_id, recursive
         )
 
diff --git a/backend/modules/sync/utils/googleutils.py b/backend/modules/sync/utils/googleutils.py
index c41d77e3b..3a6cc64a3 100644
--- a/backend/modules/sync/utils/googleutils.py
+++ b/backend/modules/sync/utils/googleutils.py
@@ -16,7 +16,10 @@ from modules.sync.dto.inputs import (
 )
 from modules.sync.repository.sync_files import SyncFiles
 from modules.sync.service.sync_service import SyncService, SyncUserService
-from modules.sync.utils.list_files import get_google_drive_files
+from modules.sync.utils.list_files import (
+    get_google_drive_files,
+    get_google_drive_files_by_id,
+)
 from modules.sync.utils.upload import upload_file
 from modules.upload.service.upload_file import check_file_exists
 from pydantic import BaseModel, ConfigDict
@@ -131,13 +134,15 @@ class GoogleSyncUtils(BaseModel):
                     filename=file_name,
                 )
 
-                await upload_file(to_upload_file, brain_id, current_user)  # type: ignore
-
                 # Check if the file already exists in the database
                 existing_files = self.sync_files_repo.get_sync_files(sync_active_id)
                 existing_file = next(
                     (f for f in existing_files if f.path == file_name), None
                 )
+                supported = False
+                if (existing_file and existing_file.supported) or not existing_file:
+                    supported = True
+                    await upload_file(to_upload_file, brain_id, current_user)  # type: ignore
 
                 if existing_file:
                     # Update the existing file record
@@ -145,6 +150,7 @@ class GoogleSyncUtils(BaseModel):
                         existing_file.id,
                         SyncFileUpdateInput(
                             last_modified=modified_time,
+                            supported=supported,
                         ),
                     )
                 else:
@@ -155,6 +161,7 @@ class GoogleSyncUtils(BaseModel):
                             syncs_active_id=sync_active_id,
                             last_modified=modified_time,
                             brain_id=brain_id,
+                            supported=supported,
                         )
                     )
 
@@ -164,6 +171,30 @@ class GoogleSyncUtils(BaseModel):
                 "An error occurred while downloading Google Drive files: %s",
                 error,
             )
+            # Check if the file already exists in the database
+            existing_files = 
self.sync_files_repo.get_sync_files(sync_active_id) + existing_file = next( + (f for f in existing_files if f.path == file["name"]), None + ) + # Update the existing file record + if existing_file: + self.sync_files_repo.update_sync_file( + existing_file.id, + SyncFileUpdateInput( + supported=False, + ), + ) + else: + # Create a new file record + self.sync_files_repo.create_sync_file( + SyncFileInput( + path=file["name"], + syncs_active_id=sync_active_id, + last_modified=file["last_modified"], + brain_id=brain_id, + supported=False, + ) + ) return {"downloaded_files": downloaded_files} async def sync(self, sync_active_id: int, user_id: str): @@ -231,12 +262,25 @@ class GoogleSyncUtils(BaseModel): sync_active_id, ) - # Get the folder id from the settings from sync_active settings = sync_active.get("settings", {}) folders = settings.get("folders", []) - files = get_google_drive_files( - sync_user["credentials"], folder_id=folders[0] if folders else None - ) + files_to_download = settings.get("files", []) + files = [] + if len(folders) > 0: + files = [] + for folder in folders: + files.extend( + get_google_drive_files( + sync_user["credentials"], + folder_id=folder, + recursive=True, + ) + ) + if len(files_to_download) > 0: + files_metadata = get_google_drive_files_by_id( + sync_user["credentials"], files_to_download + ) + files = files + files_metadata # type: ignore if "error" in files: logger.error( "Failed to download files from Google Drive for sync_active_id: %s", @@ -249,7 +293,7 @@ class GoogleSyncUtils(BaseModel): files_to_download = [ file - for file in files.get("files", []) + for file in files if not file["is_folder"] and ( ( diff --git a/backend/modules/sync/utils/list_files.py b/backend/modules/sync/utils/list_files.py index d52f1fb6e..984a842f0 100644 --- a/backend/modules/sync/utils/list_files.py +++ b/backend/modules/sync/utils/list_files.py @@ -1,4 +1,5 @@ import os +from typing import List import msal import requests @@ -12,13 +13,62 @@ from requests import HTTPError logger = get_logger(__name__) -def get_google_drive_files(credentials: dict, folder_id: str = None): +def get_google_drive_files_by_id(credentials: dict, file_ids: List[str]): + """ + Retrieve files from Google Drive by their IDs. + + Args: + credentials (dict): The credentials for accessing Google Drive. + file_ids (list): The list of file IDs to retrieve. + + Returns: + list: A list of dictionaries containing the metadata of each file or an error message. 
+    """
+    logger.info("Retrieving Google Drive files with file_ids: %s", file_ids)
+    creds = Credentials.from_authorized_user_info(credentials)
+    if creds.expired and creds.refresh_token:
+        creds.refresh(GoogleRequest())
+        logger.info("Google Drive credentials refreshed")
+
+    try:
+        service = build("drive", "v3", credentials=creds)
+        files = []
+
+        for file_id in file_ids:
+            result = (
+                service.files()
+                .get(fileId=file_id, fields="id, name, mimeType, modifiedTime")
+                .execute()
+            )
+
+            files.append(
+                {
+                    "name": result["name"],
+                    "id": result["id"],
+                    "is_folder": result["mimeType"]
+                    == "application/vnd.google-apps.folder",
+                    "last_modified": result["modifiedTime"],
+                    "mime_type": result["mimeType"],
+                }
+            )
+
+        logger.info("Google Drive files retrieved successfully: %s", len(files))
+        return files
+    except HTTPError as error:
+        logger.error("An error occurred while retrieving Google Drive files: %s", error)
+        return {"error": f"An error occurred: {error}"}
+
+
+def get_google_drive_files(
+    credentials: dict, folder_id: str = None, recursive: bool = False
+):
     """
     Retrieve files from Google Drive.
 
     Args:
         credentials (dict): The credentials for accessing Google Drive.
         folder_id (str, optional): The folder ID to filter files. Defaults to None.
+        recursive (bool, optional): If True, fetch files from all subfolders. Defaults to False.
 
     Returns:
         dict: A dictionary containing the list of files or an error message.
@@ -32,34 +82,61 @@
     try:
         service = build("drive", "v3", credentials=creds)
 
-        query = f"'{folder_id}' in parents" if folder_id else None
-        results = (
-            service.files()
-            .list(
-                q=query,
-                pageSize=10,
-                fields="nextPageToken, files(id, name, mimeType, modifiedTime)",
+        if folder_id:
+            query = f"'{folder_id}' in parents"
+        else:
+            query = "'root' in parents or sharedWithMe"
+        page_token = None
+        files = []
+
+        while True:
+            results = (
+                service.files()
+                .list(
+                    q=query,
+                    pageSize=100,
+                    fields="nextPageToken, files(id, name, mimeType, modifiedTime)",
+                    pageToken=page_token,
+                )
+                .execute()
             )
-            .execute()
-        )
-        items = results.get("files", [])
+            items = results.get("files", [])
 
-        if not items:
-            logger.info("No files found in Google Drive")
-            return {"files": "No files found."}
+            if not items:
+                logger.info("No files found in Google Drive")
+                break
 
-        files = [
-            {
-                "name": item["name"],
-                "id": item["id"],
-                "is_folder": item["mimeType"] == "application/vnd.google-apps.folder",
-                "last_modified": item["modifiedTime"],
-                "mime_type": item["mimeType"],
-            }
-            for item in items
-        ]
-        logger.info("Google Drive files retrieved successfully: %s", files)
-        return {"files": files}
+            for item in items:
+                files.append(
+                    {
+                        "name": item["name"],
+                        "id": item["id"],
+                        "is_folder": item["mimeType"]
+                        == "application/vnd.google-apps.folder",
+                        "last_modified": item["modifiedTime"],
+                        "mime_type": item["mimeType"],
+                    }
+                )
+
+                # If recursive is True and the item is a folder, get files from the folder
+                if (
+                    recursive
+                    and item["mimeType"] == "application/vnd.google-apps.folder"
+                ):
+                    logger.info(
+                        "Recursing into folder: %s",
+                        item["name"],
+                    )
+                    files.extend(
+                        get_google_drive_files(credentials, item["id"], recursive)
+                    )
+
+            page_token = results.get("nextPageToken", None)
+            if page_token is None:
+                break
+
+        logger.info("Google Drive files retrieved successfully: %s", len(files))
+        return files
     except HTTPError as error:
         logger.error("An error 
occurred while retrieving Google Drive files: %s", error) return {"error": f"An error occurred: {error}"} @@ -103,7 +182,17 @@ def get_azure_headers(token_data): } -def list_azure_files(credentials, folder_id=None): +def list_azure_files(credentials, folder_id=None, recursive=False): + def fetch_files(endpoint, headers): + response = requests.get(endpoint, headers=headers) + if response.status_code == 401: + token_data = refresh_azure_token(credentials) + headers = get_azure_headers(token_data) + response = requests.get(endpoint, headers=headers) + if response.status_code != 200: + return {"error": response.text} + return response.json().get("value", []) + token_data = get_azure_token_data(credentials) headers = get_azure_headers(token_data) endpoint = f"https://graph.microsoft.com/v1.0/me/drive/root/children" @@ -111,28 +200,76 @@ def list_azure_files(credentials, folder_id=None): endpoint = ( f"https://graph.microsoft.com/v1.0/me/drive/items/{folder_id}/children" ) - response = requests.get(endpoint, headers=headers) - if response.status_code == 401: - token_data = refresh_azure_token(credentials) - headers = get_azure_headers(token_data) - response = requests.get(endpoint, headers=headers) - if response.status_code != 200: - return {"error": response.text} - items = response.json().get("value", []) + + items = fetch_files(endpoint, headers) if not items: logger.info("No files found in Azure Drive") - return {"files": "No files found."} + return [] - files = [ - { + files = [] + for item in items: + file_data = { "name": item["name"], "id": item["id"], "is_folder": "folder" in item, "last_modified": item["lastModifiedDateTime"], "mime_type": item.get("file", {}).get("mimeType", "folder"), } - for item in items - ] - logger.info("Azure Drive files retrieved successfully: %s", files) - return {"files": files} + files.append(file_data) + + # If recursive option is enabled and the item is a folder, fetch files from it + if recursive and file_data["is_folder"]: + folder_files = list_azure_files( + credentials, folder_id=file_data["id"], recursive=True + ) + + files.extend(folder_files) + + logger.info("Azure Drive files retrieved successfully: %s", len(files)) + return files + + +def get_azure_files_by_id(credentials: dict, file_ids: List[str]): + """ + Retrieve files from Azure Drive by their IDs. + + Args: + credentials (dict): The credentials for accessing Azure Drive. + file_ids (list): The list of file IDs to retrieve. + + Returns: + list: A list of dictionaries containing the metadata of each file or an error message. 
+ """ + logger.info("Retrieving Azure Drive files with file_ids: %s", file_ids) + token_data = get_azure_token_data(credentials) + headers = get_azure_headers(token_data) + files = [] + + for file_id in file_ids: + endpoint = f"https://graph.microsoft.com/v1.0/me/drive/items/{file_id}" + response = requests.get(endpoint, headers=headers) + if response.status_code == 401: + token_data = refresh_azure_token(credentials) + headers = get_azure_headers(token_data) + response = requests.get(endpoint, headers=headers) + if response.status_code != 200: + logger.error( + "An error occurred while retrieving Azure Drive files: %s", + response.text, + ) + return {"error": response.text} + + result = response.json() + files.append( + { + "name": result["name"], + "id": result["id"], + "is_folder": "folder" in result, + "last_modified": result["lastModifiedDateTime"], + "mime_type": result.get("file", {}).get("mimeType", "folder"), + } + ) + + logger.info("Azure Drive files retrieved successfully: %s", len(files)) + return files diff --git a/backend/modules/sync/utils/sharepointutils.py b/backend/modules/sync/utils/sharepointutils.py index 64ea9f226..3adb2b375 100644 --- a/backend/modules/sync/utils/sharepointutils.py +++ b/backend/modules/sync/utils/sharepointutils.py @@ -15,7 +15,7 @@ from modules.sync.dto.inputs import ( ) from modules.sync.repository.sync_files import SyncFiles from modules.sync.service.sync_service import SyncService, SyncUserService -from modules.sync.utils.list_files import list_azure_files +from modules.sync.utils.list_files import get_azure_files_by_id, list_azure_files from modules.sync.utils.upload import upload_file from modules.upload.service.upload_file import check_file_exists from pydantic import BaseModel, ConfigDict @@ -75,9 +75,9 @@ class AzureSyncUtils(BaseModel): logger.info("Downloading Azure files with metadata: %s", files) headers = self.get_headers(token_data) - try: - downloaded_files = [] - for file in files: + downloaded_files = [] + for file in files: + try: file_id = file["id"] file_name = file["name"] modified_time = file["last_modified"] @@ -127,20 +127,24 @@ class AzureSyncUtils(BaseModel): filename=file_name, ) - await upload_file(to_upload_file, brain_id, current_user) - # Check if the file already exists in the database existing_files = self.sync_files_repo.get_sync_files(sync_active_id) existing_file = next( (f for f in existing_files if f.path == file_name), None ) + supported = False + if (existing_file and existing_file.supported) or not existing_file: + supported = True + await upload_file(to_upload_file, brain_id, current_user) + if existing_file: # Update the existing file record self.sync_files_repo.update_sync_file( existing_file.id, SyncFileUpdateInput( last_modified=modified_time, + supported=supported, ), ) else: @@ -151,14 +155,40 @@ class AzureSyncUtils(BaseModel): syncs_active_id=sync_active_id, last_modified=modified_time, brain_id=brain_id, + supported=supported, ) ) downloaded_files.append(file_name) - return {"downloaded_files": downloaded_files} - except Exception as error: - logger.error("An error occurred while downloading Azure files: %s", error) - return {"error": f"An error occurred: {error}"} + except Exception as error: + logger.error( + "An error occurred while downloading Azure files: %s", error + ) + # Check if the file already exists in the database + existing_files = self.sync_files_repo.get_sync_files(sync_active_id) + existing_file = next( + (f for f in existing_files if f.path == file["name"]), None + ) + # Update 
the existing file record + if existing_file: + self.sync_files_repo.update_sync_file( + existing_file.id, + SyncFileUpdateInput( + supported=False, + ), + ) + else: + # Create a new file record + self.sync_files_repo.create_sync_file( + SyncFileInput( + path=file["name"], + syncs_active_id=sync_active_id, + last_modified=file["last_modified"], + brain_id=brain_id, + supported=False, + ) + ) + return {"downloaded_files": downloaded_files} async def sync(self, sync_active_id: int, user_id: str): """ @@ -228,9 +258,25 @@ class AzureSyncUtils(BaseModel): # Get the folder id from the settings from sync_active settings = sync_active.get("settings", {}) folders = settings.get("folders", []) - files = list_azure_files( - sync_user["credentials"], folder_id=folders[0] if folders else None - ) + files_to_download = settings.get("files", []) + files = [] + if len(folders) > 0: + files = [] + for folder in folders: + files.extend( + list_azure_files( + sync_user["credentials"], + folder_id=folder, + recursive=True, + ) + ) + if len(files_to_download) > 0: + files_metadata = get_azure_files_by_id( + sync_user["credentials"], + files_to_download, + ) + files = files + files_metadata # type: ignore + if "error" in files: logger.error( "Failed to download files from Azure for sync_active_id: %s", @@ -244,10 +290,11 @@ class AzureSyncUtils(BaseModel): if last_synced else None ) - logger.info("Files retrieved from Azure: %s", files.get("files", [])) + logger.info("Files retrieved from Azure: %s", len(files)) + logger.info("Files retrieved from Azure: %s", files) files_to_download = [ file - for file in files.get("files", []) + for file in files if not file["is_folder"] and ( ( diff --git a/backend/supabase/migrations/20240608095352_supported-sync.sql b/backend/supabase/migrations/20240608095352_supported-sync.sql new file mode 100644 index 000000000..70625dbc5 --- /dev/null +++ b/backend/supabase/migrations/20240608095352_supported-sync.sql @@ -0,0 +1,2 @@ +alter table "public"."syncs_files" add column "supported" boolean not null default true; + diff --git a/backend/tach.yml b/backend/tach.yml new file mode 100644 index 000000000..01d486808 --- /dev/null +++ b/backend/tach.yml @@ -0,0 +1,124 @@ +modules: +- path: + depends_on: + - modules.analytics + - modules.api_key + - modules.assistant + - modules.brain + - modules.chat + - modules.contact_support + - modules.knowledge + - modules.misc + - modules.notification + - modules.onboarding + - modules.prompt + - modules.sync + - modules.upload + - modules.user + strict: false +- path: modules.analytics + depends_on: + - + - modules.brain + strict: false +- path: modules.api_key + depends_on: + - + - modules.user + strict: false +- path: modules.assistant + depends_on: + - + - modules.chat + - modules.contact_support + # - modules.upload + - modules.user + strict: false +- path: modules.authorization + depends_on: [] + strict: false +- path: modules.brain + depends_on: + - + - modules.chat + - modules.knowledge + - modules.prompt + - modules.tools + - modules.upload + - modules.user + strict: false +- path: modules.chat + depends_on: + - + - modules.brain + - modules.notification + - modules.prompt + - modules.user + strict: false +- path: modules.contact_support + depends_on: + - + strict: false +- path: modules.ingestion + depends_on: [] + strict: false +- path: modules.knowledge + depends_on: + - + - modules.brain + - modules.upload + - modules.user + strict: false +- path: modules.message + depends_on: [] + strict: false +- path: modules.misc + depends_on: 
[] + strict: false +- path: modules.notification + depends_on: + - + strict: false +- path: modules.onboarding + depends_on: + - + - modules.user + strict: false +- path: modules.prompt + depends_on: + - + - modules.brain + strict: false +- path: modules.sync + depends_on: + - + - modules.brain + - modules.knowledge + - modules.notification + - modules.upload + - modules.user + strict: false +- path: modules.tools + depends_on: + - + - modules.contact_support + strict: false +- path: modules.upload + depends_on: + - + - modules.brain + - modules.knowledge + - modules.notification + - modules.user + strict: false +- path: modules.user + depends_on: + - + - modules.brain + strict: false +exclude: +- docs +- tests +exact: false +disable_logging: false +ignore_type_checking_imports: false
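Note on the recursive retrieval described at the top of this patch: `get_google_drive_files` (and `list_azure_files` on the Azure side) now lists a folder's children page by page and descends into every child whose metadata marks it as a folder. The sketch below condenses that traversal pattern against the Google Drive v3 API for reference only; it is not code from this diff, the function name `list_drive_folder` is invented for the illustration, and it assumes `credentials` is a valid authorized-user payload of the kind used elsewhere in this patch.

from google.oauth2.credentials import Credentials
from googleapiclient.discovery import build


def list_drive_folder(credentials: dict, folder_id: str, recursive: bool = False) -> list:
    # Build a Drive v3 client from an authorized-user credentials dict (assumed valid).
    service = build(
        "drive", "v3", credentials=Credentials.from_authorized_user_info(credentials)
    )
    files, page_token = [], None
    while True:
        response = (
            service.files()
            .list(
                q=f"'{folder_id}' in parents",
                pageSize=100,
                fields="nextPageToken, files(id, name, mimeType, modifiedTime)",
                pageToken=page_token,
            )
            .execute()
        )
        for item in response.get("files", []):
            is_folder = item["mimeType"] == "application/vnd.google-apps.folder"
            files.append({"id": item["id"], "name": item["name"], "is_folder": is_folder})
            if recursive and is_folder:
                # Descend into the subfolder and merge its listing into the result.
                files.extend(list_drive_folder(credentials, item["id"], recursive=True))
        page_token = response.get("nextPageToken")
        if page_token is None:
            return files

The recursion depth is bounded by the folder hierarchy, and each level issues its own paged files().list call, so large trees are walked one folder at a time rather than in a single request.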