fix: integrations (#2642)

This pull request adds support for recursive folder retrieval in the
`get_files_folder_user_sync` method. Previously, the method only
retrieved files from the specified folder, but now it can also retrieve
files from all subfolders recursively. This enhancement improves the
functionality and flexibility of the method, allowing for more
comprehensive file retrieval in sync operations.
This commit is contained in:
Stan Girard 2024-06-08 11:55:11 +02:00 committed by GitHub
parent ca6341372d
commit 47c6e24bf1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
12 changed files with 457 additions and 78 deletions

View File

@ -93,6 +93,7 @@ class SyncFileInput(BaseModel):
syncs_active_id: int
last_modified: str
brain_id: str
supported: Optional[bool] = True
class SyncFileUpdateInput(BaseModel):
@ -103,4 +104,5 @@ class SyncFileUpdateInput(BaseModel):
last_modified (datetime.datetime): The updated last modified date and time.
"""
last_modified: str
last_modified: Optional[str] = None
supported: Optional[bool] = None

View File

@ -1,8 +1,16 @@
from datetime import datetime
from typing import Optional
from pydantic import BaseModel
class SyncsUser(BaseModel):
id: int
user_id: str
name: str
provider: str
class SyncsActive(BaseModel):
id: int
name: str
@ -12,6 +20,7 @@ class SyncsActive(BaseModel):
last_synced: datetime
sync_interval_minutes: int
brain_id: str
syncs_user: Optional[SyncsUser]
class SyncsFiles(BaseModel):
@ -20,3 +29,4 @@ class SyncsFiles(BaseModel):
syncs_active_id: int
last_modified: str
brain_id: str
supported: bool

View File

@ -64,7 +64,10 @@ class Sync(SyncInterface):
"""
logger.info("Retrieving active syncs for user_id: %s", user_id)
response = (
self.db.from_("syncs_active").select("*").eq("user_id", user_id).execute()
self.db.from_("syncs_active")
.select("*, syncs_user(*)")
.eq("user_id", user_id)
.execute()
)
if response.data:
logger.info("Active syncs retrieved successfully: %s", response.data)

View File

@ -81,9 +81,9 @@ class SyncFiles(SyncFileInterface):
sync_file_id,
sync_file_input,
)
self.db.from_("syncs_files").update(sync_file_input.model_dump()).eq(
"id", sync_file_id
).execute()
self.db.from_("syncs_files").update(
sync_file_input.model_dump(exclude_unset=True)
).eq("id", sync_file_id).execute()
logger.info("Sync file updated successfully")
def delete_sync_file(self, sync_file_id: int):

View File

@ -45,7 +45,11 @@ class SyncUserInterface(ABC):
@abstractmethod
def get_files_folder_user_sync(
self, sync_active_id: int, user_id: str, folder_id: int = None
self,
sync_active_id: int,
user_id: str,
folder_id: int = None,
recursive: bool = False,
):
pass

View File

@ -159,7 +159,7 @@ class SyncUser(SyncUserInterface):
logger.info("Sync user updated successfully")
def get_files_folder_user_sync(
self, sync_active_id: int, user_id: str, folder_id: str = None
self, sync_active_id: int, user_id: str, folder_id: str = None, recursive: bool = False
):
"""
Retrieve files from a user's sync folder, either from Google Drive or Azure.
@ -195,10 +195,12 @@ class SyncUser(SyncUserInterface):
provider = sync_user["provider"].lower()
if provider == "google":
logger.info("Getting files for Google sync")
return get_google_drive_files(sync_user["credentials"], folder_id)
return {
"files": get_google_drive_files(sync_user["credentials"], folder_id)
}
elif provider == "azure":
logger.info("Getting files for Azure sync")
return list_azure_files(sync_user["credentials"], folder_id)
return {"files": list_azure_files(sync_user["credentials"], folder_id, recursive)}
else:
logger.warning("No sync found for provider: %s", sync_user["provider"])
logger.warning("No sync found for provider: %s", sync_user["provider"], recursive)
return "No sync found"

View File

@ -46,10 +46,14 @@ class SyncUserService:
return self.repository.update_sync_user(sync_user_id, state, sync_user_input)
def get_files_folder_user_sync(
self, sync_active_id: int, user_id: str, folder_id: str = None
self,
sync_active_id: int,
user_id: str,
folder_id: str = None,
recursive: bool = False,
):
return self.repository.get_files_folder_user_sync(
sync_active_id, user_id, folder_id
sync_active_id, user_id, folder_id, recursive
)

View File

@ -16,7 +16,10 @@ from modules.sync.dto.inputs import (
)
from modules.sync.repository.sync_files import SyncFiles
from modules.sync.service.sync_service import SyncService, SyncUserService
from modules.sync.utils.list_files import get_google_drive_files
from modules.sync.utils.list_files import (
get_google_drive_files,
get_google_drive_files_by_id,
)
from modules.sync.utils.upload import upload_file
from modules.upload.service.upload_file import check_file_exists
from pydantic import BaseModel, ConfigDict
@ -131,13 +134,15 @@ class GoogleSyncUtils(BaseModel):
filename=file_name,
)
await upload_file(to_upload_file, brain_id, current_user) # type: ignore
# Check if the file already exists in the database
existing_files = self.sync_files_repo.get_sync_files(sync_active_id)
existing_file = next(
(f for f in existing_files if f.path == file_name), None
)
supported = False
if (existing_file and existing_file.supported) or not existing_file:
supported = True
await upload_file(to_upload_file, brain_id, current_user) # type: ignore
if existing_file:
# Update the existing file record
@ -145,6 +150,7 @@ class GoogleSyncUtils(BaseModel):
existing_file.id,
SyncFileUpdateInput(
last_modified=modified_time,
supported=supported,
),
)
else:
@ -155,6 +161,7 @@ class GoogleSyncUtils(BaseModel):
syncs_active_id=sync_active_id,
last_modified=modified_time,
brain_id=brain_id,
supported=supported,
)
)
@ -164,6 +171,30 @@ class GoogleSyncUtils(BaseModel):
"An error occurred while downloading Google Drive files: %s",
error,
)
# Check if the file already exists in the database
existing_files = self.sync_files_repo.get_sync_files(sync_active_id)
existing_file = next(
(f for f in existing_files if f.path == file["name"]), None
)
# Update the existing file record
if existing_file:
self.sync_files_repo.update_sync_file(
existing_file.id,
SyncFileUpdateInput(
supported=False,
),
)
else:
# Create a new file record
self.sync_files_repo.create_sync_file(
SyncFileInput(
path=file["name"],
syncs_active_id=sync_active_id,
last_modified=file["last_modified"],
brain_id=brain_id,
supported=False,
)
)
return {"downloaded_files": downloaded_files}
async def sync(self, sync_active_id: int, user_id: str):
@ -231,12 +262,25 @@ class GoogleSyncUtils(BaseModel):
sync_active_id,
)
# Get the folder id from the settings from sync_active
settings = sync_active.get("settings", {})
folders = settings.get("folders", [])
files = get_google_drive_files(
sync_user["credentials"], folder_id=folders[0] if folders else None
files_to_download = settings.get("files", [])
files = []
if len(folders) > 0:
files = []
for folder in folders:
files.extend(
get_google_drive_files(
sync_user["credentials"],
folder_id=folder,
recursive=True,
)
)
if len(files_to_download) > 0:
files_metadata = get_google_drive_files_by_id(
sync_user["credentials"], files_to_download
)
files = files + files_metadata # type: ignore
if "error" in files:
logger.error(
"Failed to download files from Google Drive for sync_active_id: %s",
@ -249,7 +293,7 @@ class GoogleSyncUtils(BaseModel):
files_to_download = [
file
for file in files.get("files", [])
for file in files
if not file["is_folder"]
and (
(

View File

@ -1,4 +1,5 @@
import os
from typing import List
import msal
import requests
@ -12,13 +13,62 @@ from requests import HTTPError
logger = get_logger(__name__)
def get_google_drive_files(credentials: dict, folder_id: str = None):
def get_google_drive_files_by_id(credentials: dict, file_ids: List[str]):
    """
    Retrieve files from Google Drive by their IDs.

    Args:
        credentials (dict): The credentials for accessing Google Drive.
        file_ids (list): The list of file IDs to retrieve.

    Returns:
        list: A list of dictionaries containing the metadata of each file
            (name, id, is_folder, last_modified, mime_type), or a dict with
            an "error" key if the Drive API call raises an HTTPError.
    """
    logger.info("Retrieving Google Drive files with file_ids: %s", file_ids)
    creds = Credentials.from_authorized_user_info(credentials)
    # Refresh expired OAuth credentials in place before calling the API.
    if creds.expired and creds.refresh_token:
        creds.refresh(GoogleRequest())
        logger.info("Google Drive credentials refreshed")
    try:
        service = build("drive", "v3", credentials=creds)
        files = []
        # One metadata GET per file id; only the fields used below are requested.
        for file_id in file_ids:
            result = (
                service.files()
                .get(fileId=file_id, fields="id, name, mimeType, modifiedTime")
                .execute()
            )
            files.append(
                {
                    "name": result["name"],
                    "id": result["id"],
                    # Folders are identified by the Google Apps folder MIME type.
                    "is_folder": result["mimeType"]
                    == "application/vnd.google-apps.folder",
                    "last_modified": result["modifiedTime"],
                    "mime_type": result["mimeType"],
                }
            )
        logger.info("Google Drive files retrieved successfully: %s", len(files))
        return files
    except HTTPError as error:
        logger.error("An error occurred while retrieving Google Drive files: %s", error)
        return {"error": f"An error occurred: {error}"}
def get_google_drive_files(
credentials: dict, folder_id: str = None, recursive: bool = False
):
"""
Retrieve files from Google Drive.
Args:
credentials (dict): The credentials for accessing Google Drive.
folder_id (str, optional): The folder ID to filter files. Defaults to None.
recursive (bool, optional): If True, fetch files from all subfolders. Defaults to False.
Returns:
dict: A dictionary containing the list of files or an error message.
@ -32,13 +82,21 @@ def get_google_drive_files(credentials: dict, folder_id: str = None):
try:
service = build("drive", "v3", credentials=creds)
query = f"'{folder_id}' in parents" if folder_id else None
if folder_id:
query = f"'{folder_id}' in parents"
else:
query = "'root' in parents or sharedWithMe"
page_token = None
files = []
while True:
results = (
service.files()
.list(
q=query,
pageSize=10,
pageSize=100,
fields="nextPageToken, files(id, name, mimeType, modifiedTime)",
pageToken=page_token,
)
.execute()
)
@ -46,20 +104,41 @@ def get_google_drive_files(credentials: dict, folder_id: str = None):
if not items:
logger.info("No files found in Google Drive")
return {"files": "No files found."}
break
files = [
for item in items:
files.append(
{
"name": item["name"],
"id": item["id"],
"is_folder": item["mimeType"] == "application/vnd.google-apps.folder",
"is_folder": item["mimeType"]
== "application/vnd.google-apps.folder",
"last_modified": item["modifiedTime"],
"mime_type": item["mimeType"],
}
for item in items
]
logger.info("Google Drive files retrieved successfully: %s", files)
return {"files": files}
)
# If recursive is True and the item is a folder, get files from the folder
if item["name"] == "Monotype":
logger.warning(item)
if (
recursive
and item["mimeType"] == "application/vnd.google-apps.folder"
):
logger.warning(
"Calling Recursive for folder: %s",
item["name"],
)
files.extend(
get_google_drive_files(credentials, item["id"], recursive)
)
page_token = results.get("nextPageToken", None)
if page_token is None:
break
logger.info("Google Drive files retrieved successfully: %s", len(files))
return files
except HTTPError as error:
logger.error("An error occurred while retrieving Google Drive files: %s", error)
return {"error": f"An error occurred: {error}"}
@ -103,14 +182,8 @@ def get_azure_headers(token_data):
}
def list_azure_files(credentials, folder_id=None, recursive=False):
    """
    List files in an Azure (OneDrive / Microsoft Graph) drive folder.

    Args:
        credentials: The credentials for accessing the Microsoft Graph API.
        folder_id (str, optional): Folder to list; defaults to the drive root.
        recursive (bool, optional): If True, also list the contents of every
            subfolder encountered, flattened into the same list.
            Defaults to False.

    Returns:
        list: Dictionaries with name, id, is_folder, last_modified and
            mime_type for each item, or a dict with an "error" key if the
            Graph request fails.
    """

    def fetch_files(endpoint, headers):
        # Single Graph call; a 401 means the token expired, so refresh once
        # and retry before giving up.
        response = requests.get(endpoint, headers=headers)
        if response.status_code == 401:
            token_data = refresh_azure_token(credentials)
            headers = get_azure_headers(token_data)
            response = requests.get(endpoint, headers=headers)
        if response.status_code != 200:
            return {"error": response.text}
        return response.json().get("value", [])

    token_data = get_azure_token_data(credentials)
    headers = get_azure_headers(token_data)
    endpoint = "https://graph.microsoft.com/v1.0/me/drive/root/children"
    if folder_id:
        endpoint = (
            f"https://graph.microsoft.com/v1.0/me/drive/items/{folder_id}/children"
        )
    items = fetch_files(endpoint, headers)

    # Propagate request failures to the caller instead of iterating the
    # error dict below, which would raise a TypeError on item["name"].
    if isinstance(items, dict) and "error" in items:
        return items
    if not items:
        logger.info("No files found in Azure Drive")
        return []

    files = []
    for item in items:
        file_data = {
            "name": item["name"],
            "id": item["id"],
            # Graph items representing folders carry a "folder" facet.
            "is_folder": "folder" in item,
            "last_modified": item["lastModifiedDateTime"],
            "mime_type": item.get("file", {}).get("mimeType", "folder"),
        }
        files.append(file_data)

        # If recursive option is enabled and the item is a folder, fetch
        # files from it and flatten them into the same result list.
        if recursive and file_data["is_folder"]:
            folder_files = list_azure_files(
                credentials, folder_id=file_data["id"], recursive=True
            )
            files.extend(folder_files)
    logger.info("Azure Drive files retrieved successfully: %s", len(files))
    return files
def get_azure_files_by_id(credentials: dict, file_ids: List[str]):
    """
    Retrieve files from Azure Drive by their IDs.

    Args:
        credentials (dict): The credentials for accessing Azure Drive.
        file_ids (list): The list of file IDs to retrieve.

    Returns:
        list: A list of dictionaries containing the metadata of each file,
            or a dict with an "error" key on the first failed request.
            NOTE(review): on failure, metadata already fetched for earlier
            ids is discarded — confirm callers expect all-or-nothing.
    """
    logger.info("Retrieving Azure Drive files with file_ids: %s", file_ids)
    token_data = get_azure_token_data(credentials)
    headers = get_azure_headers(token_data)
    files = []
    for file_id in file_ids:
        endpoint = f"https://graph.microsoft.com/v1.0/me/drive/items/{file_id}"
        response = requests.get(endpoint, headers=headers)
        # A 401 means the access token expired: refresh it once and retry.
        if response.status_code == 401:
            token_data = refresh_azure_token(credentials)
            headers = get_azure_headers(token_data)
            response = requests.get(endpoint, headers=headers)
        if response.status_code != 200:
            logger.error(
                "An error occurred while retrieving Azure Drive files: %s",
                response.text,
            )
            return {"error": response.text}
        result = response.json()
        files.append(
            {
                "name": result["name"],
                "id": result["id"],
                # Graph items representing folders carry a "folder" facet.
                "is_folder": "folder" in result,
                "last_modified": result["lastModifiedDateTime"],
                "mime_type": result.get("file", {}).get("mimeType", "folder"),
            }
        )
    logger.info("Azure Drive files retrieved successfully: %s", len(files))
    return files

View File

@ -15,7 +15,7 @@ from modules.sync.dto.inputs import (
)
from modules.sync.repository.sync_files import SyncFiles
from modules.sync.service.sync_service import SyncService, SyncUserService
from modules.sync.utils.list_files import list_azure_files
from modules.sync.utils.list_files import get_azure_files_by_id, list_azure_files
from modules.sync.utils.upload import upload_file
from modules.upload.service.upload_file import check_file_exists
from pydantic import BaseModel, ConfigDict
@ -75,9 +75,9 @@ class AzureSyncUtils(BaseModel):
logger.info("Downloading Azure files with metadata: %s", files)
headers = self.get_headers(token_data)
try:
downloaded_files = []
for file in files:
try:
file_id = file["id"]
file_name = file["name"]
modified_time = file["last_modified"]
@ -127,20 +127,24 @@ class AzureSyncUtils(BaseModel):
filename=file_name,
)
await upload_file(to_upload_file, brain_id, current_user)
# Check if the file already exists in the database
existing_files = self.sync_files_repo.get_sync_files(sync_active_id)
existing_file = next(
(f for f in existing_files if f.path == file_name), None
)
supported = False
if (existing_file and existing_file.supported) or not existing_file:
supported = True
await upload_file(to_upload_file, brain_id, current_user)
if existing_file:
# Update the existing file record
self.sync_files_repo.update_sync_file(
existing_file.id,
SyncFileUpdateInput(
last_modified=modified_time,
supported=supported,
),
)
else:
@ -151,14 +155,40 @@ class AzureSyncUtils(BaseModel):
syncs_active_id=sync_active_id,
last_modified=modified_time,
brain_id=brain_id,
supported=supported,
)
)
downloaded_files.append(file_name)
return {"downloaded_files": downloaded_files}
except Exception as error:
logger.error("An error occurred while downloading Azure files: %s", error)
return {"error": f"An error occurred: {error}"}
logger.error(
"An error occurred while downloading Azure files: %s", error
)
# Check if the file already exists in the database
existing_files = self.sync_files_repo.get_sync_files(sync_active_id)
existing_file = next(
(f for f in existing_files if f.path == file["name"]), None
)
# Update the existing file record
if existing_file:
self.sync_files_repo.update_sync_file(
existing_file.id,
SyncFileUpdateInput(
supported=False,
),
)
else:
# Create a new file record
self.sync_files_repo.create_sync_file(
SyncFileInput(
path=file["name"],
syncs_active_id=sync_active_id,
last_modified=file["last_modified"],
brain_id=brain_id,
supported=False,
)
)
return {"downloaded_files": downloaded_files}
async def sync(self, sync_active_id: int, user_id: str):
"""
@ -228,9 +258,25 @@ class AzureSyncUtils(BaseModel):
# Get the folder id from the settings from sync_active
settings = sync_active.get("settings", {})
folders = settings.get("folders", [])
files = list_azure_files(
sync_user["credentials"], folder_id=folders[0] if folders else None
files_to_download = settings.get("files", [])
files = []
if len(folders) > 0:
files = []
for folder in folders:
files.extend(
list_azure_files(
sync_user["credentials"],
folder_id=folder,
recursive=True,
)
)
if len(files_to_download) > 0:
files_metadata = get_azure_files_by_id(
sync_user["credentials"],
files_to_download,
)
files = files + files_metadata # type: ignore
if "error" in files:
logger.error(
"Failed to download files from Azure for sync_active_id: %s",
@ -244,10 +290,11 @@ class AzureSyncUtils(BaseModel):
if last_synced
else None
)
logger.info("Files retrieved from Azure: %s", files.get("files", []))
logger.info("Files retrieved from Azure: %s", len(files))
logger.info("Files retrieved from Azure: %s", files)
files_to_download = [
file
for file in files.get("files", [])
for file in files
if not file["is_folder"]
and (
(

View File

@ -0,0 +1,2 @@
-- Add a "supported" boolean column to syncs_files (NOT NULL, default true),
-- so existing rows are marked supported and new rows can opt out.
alter table "public"."syncs_files" add column "supported" boolean not null default true;

124
backend/tach.yml Normal file
View File

@ -0,0 +1,124 @@
# tach module-dependency configuration.
# Each entry declares a module path and the modules it is allowed to import
# from (depends_on). strict: false means undeclared dependencies are reported
# but the module's public interface is not enforced.
modules:
  - path: <root>
    depends_on:
      - modules.analytics
      - modules.api_key
      - modules.assistant
      - modules.brain
      - modules.chat
      - modules.contact_support
      - modules.knowledge
      - modules.misc
      - modules.notification
      - modules.onboarding
      - modules.prompt
      - modules.sync
      - modules.upload
      - modules.user
    strict: false
  - path: modules.analytics
    depends_on:
      - <root>
      - modules.brain
    strict: false
  - path: modules.api_key
    depends_on:
      - <root>
      - modules.user
    strict: false
  - path: modules.assistant
    depends_on:
      - <root>
      - modules.chat
      - modules.contact_support
      # - modules.upload
      - modules.user
    strict: false
  - path: modules.authorization
    depends_on: []
    strict: false
  - path: modules.brain
    depends_on:
      - <root>
      - modules.chat
      - modules.knowledge
      - modules.prompt
      - modules.tools
      - modules.upload
      - modules.user
    strict: false
  - path: modules.chat
    depends_on:
      - <root>
      - modules.brain
      - modules.notification
      - modules.prompt
      - modules.user
    strict: false
  - path: modules.contact_support
    depends_on:
      - <root>
    strict: false
  - path: modules.ingestion
    depends_on: []
    strict: false
  - path: modules.knowledge
    depends_on:
      - <root>
      - modules.brain
      - modules.upload
      - modules.user
    strict: false
  - path: modules.message
    depends_on: []
    strict: false
  - path: modules.misc
    depends_on: []
    strict: false
  - path: modules.notification
    depends_on:
      - <root>
    strict: false
  - path: modules.onboarding
    depends_on:
      - <root>
      - modules.user
    strict: false
  - path: modules.prompt
    depends_on:
      - <root>
      - modules.brain
    strict: false
  - path: modules.sync
    depends_on:
      - <root>
      - modules.brain
      - modules.knowledge
      - modules.notification
      - modules.upload
      - modules.user
    strict: false
  - path: modules.tools
    depends_on:
      - <root>
      - modules.contact_support
    strict: false
  - path: modules.upload
    depends_on:
      - <root>
      - modules.brain
      - modules.knowledge
      - modules.notification
      - modules.user
    strict: false
  - path: modules.user
    depends_on:
      - <root>
      - modules.brain
    strict: false
# Paths excluded from dependency analysis.
exclude:
  - docs
  - tests
exact: false
disable_logging: false
ignore_type_checking_imports: false