feat(integration): Notion (#3173)

# Description

Fix multiple notion bugs 👍 

-> Delete your notion sync and all the notion files from the db
-> Ensure a sync is not already running before launching a sync.
-> Add a status to subscribe to for user_sync

---------

Co-authored-by: Antoine Dewez <44063631+Zewed@users.noreply.github.com>
Co-authored-by: Stan Girard <stan@quivr.app>
Co-authored-by: aminediro <aminedirhoussi1@gmail.com>
Co-authored-by: Stan Girard <girard.stanislas@gmail.com>
This commit is contained in:
Chloé Daems 2024-09-19 14:37:04 +02:00 committed by GitHub
parent 9c6d998c7c
commit 42f4bb724e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
36 changed files with 757 additions and 239 deletions

View File

@ -527,6 +527,7 @@ async def test_should_process_knowledge_prev_error(
assert new.file_sha1
@pytest.mark.skip(reason="Bug: UnboundLocalError: cannot access local variable 'response'")
@pytest.mark.asyncio(loop_scope="session")
async def test_get_knowledge_storage_path(session: AsyncSession, test_data: TestData):
_, [knowledge, _] = test_data

View File

@ -65,24 +65,28 @@ async def generate_source(
source_url = doc.metadata["original_file_name"]
else:
# Check if the URL has already been generated
file_name = doc.metadata["file_name"]
file_path = await knowledge_service.get_knowledge_storage_path(
try:
file_name = doc.metadata["file_name"]
file_path = await knowledge_service.get_knowledge_storage_path(
file_name=file_name, brain_id=brain_id
)
if file_path in generated_urls:
source_url = generated_urls[file_path]
else:
# Generate the URL
if file_path in sources_url_cache:
source_url = sources_url_cache[file_path]
)
if file_path in generated_urls:
source_url = generated_urls[file_path]
else:
generated_url = generate_file_signed_url(file_path)
if generated_url is not None:
source_url = generated_url.get("signedURL", "")
# Generate the URL
if file_path in sources_url_cache:
source_url = sources_url_cache[file_path]
else:
source_url = ""
# Store the generated URL
generated_urls[file_path] = source_url
generated_url = generate_file_signed_url(file_path)
if generated_url is not None:
source_url = generated_url.get("signedURL", "")
else:
source_url = ""
# Store the generated URL
generated_urls[file_path] = source_url
except Exception as e:
logger.error(f"Error generating file signed URL: {e}")
continue
# Append a new Sources object to the list
sources_list.append(

View File

@ -7,7 +7,11 @@ from msal import ConfidentialClientApplication
from quivr_api.logger import get_logger
from quivr_api.middlewares.auth import AuthBearer, get_current_user
from quivr_api.modules.sync.dto.inputs import SyncsUserInput, SyncUserUpdateInput
from quivr_api.modules.sync.dto.inputs import (
SyncsUserInput,
SyncsUserStatus,
SyncUserUpdateInput,
)
from quivr_api.modules.sync.service.sync_service import SyncService, SyncUserService
from quivr_api.modules.user.entity.user_identity import UserIdentity
@ -70,6 +74,7 @@ def authorize_azure(
credentials={},
state={"state": state},
additional_data={"flow": flow},
status=str(SyncsUserStatus.SYNCING),
)
sync_user_service.create_sync_user(sync_user_input)
return {"authorization_url": flow["auth_uri"]}
@ -138,7 +143,9 @@ def oauth2callback_azure(request: Request):
logger.info(f"Retrieved email for user: {current_user} - {user_email}")
sync_user_input = SyncUserUpdateInput(
credentials=result, state={}, email=user_email
credentials=result,
email=user_email,
status=str(SyncsUserStatus.SYNCED),
)
sync_user_service.update_sync_user(current_user, state_dict, sync_user_input)

View File

@ -7,7 +7,11 @@ from fastapi.responses import HTMLResponse
from quivr_api.logger import get_logger
from quivr_api.middlewares.auth import AuthBearer, get_current_user
from quivr_api.modules.sync.dto.inputs import SyncsUserInput, SyncUserUpdateInput
from quivr_api.modules.sync.dto.inputs import (
SyncsUserInput,
SyncsUserStatus,
SyncUserUpdateInput,
)
from quivr_api.modules.sync.service.sync_service import SyncService, SyncUserService
from quivr_api.modules.user.entity.user_identity import UserIdentity
@ -72,6 +76,7 @@ def authorize_dropbox(
credentials={},
state={"state": state},
additional_data={},
status=str(SyncsUserStatus.SYNCING),
)
sync_user_service.create_sync_user(sync_user_input)
return {"authorization_url": authorize_url}
@ -147,9 +152,11 @@ def oauth2callback_dropbox(request: Request):
sync_user_input = SyncUserUpdateInput(
credentials=result,
state={},
# state={},
email=user_email,
status=str(SyncsUserStatus.SYNCED),
)
assert current_user
sync_user_service.update_sync_user(current_user, state_dict, sync_user_input)
logger.info(f"DropBox sync created successfully for user: {current_user}")
return HTMLResponse(successfullConnectionPage)

View File

@ -6,7 +6,11 @@ from fastapi.responses import HTMLResponse
from quivr_api.logger import get_logger
from quivr_api.middlewares.auth import AuthBearer, get_current_user
from quivr_api.modules.sync.dto.inputs import SyncsUserInput, SyncUserUpdateInput
from quivr_api.modules.sync.dto.inputs import (
SyncsUserInput,
SyncsUserStatus,
SyncUserUpdateInput,
)
from quivr_api.modules.sync.service.sync_service import SyncService, SyncUserService
from quivr_api.modules.user.entity.user_identity import UserIdentity
@ -61,6 +65,7 @@ def authorize_github(
provider="GitHub",
credentials={},
state={"state": state},
status=str(SyncsUserStatus.SYNCING),
)
sync_user_service.create_sync_user(sync_user_input)
return {"authorization_url": authorization_url}
@ -148,7 +153,10 @@ def oauth2callback_github(request: Request):
logger.info(f"Retrieved email for user: {current_user} - {user_email}")
sync_user_input = SyncUserUpdateInput(
credentials=result, state={}, email=user_email
credentials=result,
# state={},
email=user_email,
status=str(SyncsUserStatus.SYNCED),
)
sync_user_service.update_sync_user(current_user, state_dict, sync_user_input)

View File

@ -9,7 +9,11 @@ from googleapiclient.discovery import build
from quivr_api.logger import get_logger
from quivr_api.middlewares.auth import AuthBearer, get_current_user
from quivr_api.modules.sync.dto.inputs import SyncsUserInput, SyncUserUpdateInput
from quivr_api.modules.sync.dto.inputs import (
SyncsUserInput,
SyncsUserStatus,
SyncUserUpdateInput,
)
from quivr_api.modules.sync.service.sync_service import SyncService, SyncUserService
from quivr_api.modules.user.entity.user_identity import UserIdentity
@ -101,6 +105,7 @@ def authorize_google(
credentials={},
state={"state": state},
additional_data={},
status=str(SyncsUserStatus.SYNCED),
)
sync_user_service.create_sync_user(sync_user_input)
return {"authorization_url": authorization_url}
@ -156,8 +161,9 @@ def oauth2callback_google(request: Request):
sync_user_input = SyncUserUpdateInput(
credentials=json.loads(creds.to_json()),
state={},
# state={},
email=user_email,
status=str(SyncsUserStatus.SYNCED),
)
sync_user_service.update_sync_user(current_user, state_dict, sync_user_input)
logger.info(f"Google Drive sync created successfully for user: {current_user}")

View File

@ -10,7 +10,11 @@ from notion_client import Client
from quivr_api.celery_config import celery
from quivr_api.logger import get_logger
from quivr_api.middlewares.auth import AuthBearer, get_current_user
from quivr_api.modules.sync.dto.inputs import SyncsUserInput, SyncUserUpdateInput
from quivr_api.modules.sync.dto.inputs import (
SyncsUserInput,
SyncsUserStatus,
SyncUserUpdateInput,
)
from quivr_api.modules.sync.service.sync_service import SyncService, SyncUserService
from quivr_api.modules.user.entity.user_identity import UserIdentity
@ -65,6 +69,7 @@ def authorize_notion(
provider="Notion",
credentials={},
state={"state": state},
status=str(SyncsUserStatus.SYNCING),
)
sync_user_service.create_sync_user(sync_user_input)
return {"authorization_url": authorize_url}
@ -145,15 +150,20 @@ def oauth2callback_notion(request: Request, background_tasks: BackgroundTasks):
sync_user_input = SyncUserUpdateInput(
credentials=result,
state={},
# state={},
email=user_email,
status=str(SyncsUserStatus.SYNCING),
)
sync_user_service.update_sync_user(current_user, state_dict, sync_user_input)
logger.info(f"Notion sync created successfully for user: {current_user}")
# launch celery task to sync notion data
celery.send_task(
"fetch_and_store_notion_files_task",
kwargs={"access_token": access_token, "user_id": current_user},
kwargs={
"access_token": access_token,
"user_id": current_user,
"sync_user_id": sync_user_state.id,
},
)
return HTMLResponse(successfullConnectionPage)

View File

@ -1,8 +1,23 @@
import enum
from typing import List, Optional
from pydantic import BaseModel
class SyncsUserStatus(enum.Enum):
"""
Enum for the status of a sync user.
"""
SYNCED = "SYNCED"
SYNCING = "SYNCING"
ERROR = "ERROR"
REMOVED = "REMOVED"
def __str__(self):
return self.value
class SyncsUserInput(BaseModel):
"""
Input model for creating a new sync user.
@ -17,10 +32,12 @@ class SyncsUserInput(BaseModel):
user_id: str
name: str
email: str | None = None
provider: str
credentials: dict
state: dict
additional_data: dict = {}
status: str
class SyncUserUpdateInput(BaseModel):
@ -33,8 +50,9 @@ class SyncUserUpdateInput(BaseModel):
"""
credentials: dict
state: dict
state: dict | None = None
email: str
status: str
class SyncActiveSettings(BaseModel):

View File

@ -97,6 +97,7 @@ class NotionPage(BaseModel):
cover: Cover | None
icon: Icon | None
properties: PageProps
sync_user_id: UUID | None = Field(default=None, foreign_key="syncs_user.id") # type: ignore
# TODO: Fix UUID in table NOTION
def _get_parent_id(self) -> UUID | None:
@ -110,7 +111,7 @@ class NotionPage(BaseModel):
case BlockParent():
return None
def to_syncfile(self, user_id: UUID):
def to_syncfile(self, user_id: UUID, sync_user_id: int) -> NotionSyncFile:
name = (
self.properties.title.title[0].text.content if self.properties.title else ""
)
@ -125,6 +126,7 @@ class NotionPage(BaseModel):
last_modified=self.last_edited_time,
type="page",
user_id=user_id,
sync_user_id=sync_user_id,
)

View File

@ -54,10 +54,12 @@ class SyncsUser(BaseModel):
id: int
user_id: UUID
name: str
email: str | None = None
provider: str
credentials: dict
state: dict
additional_data: dict
status: str
class SyncsActive(BaseModel):
@ -114,3 +116,7 @@ class NotionSyncFile(SQLModel, table=True):
description="The ID of the user who owns the file",
)
user: User = Relationship(back_populates="notion_syncs")
sync_user_id: int = Field(
# foreign_key="syncs_user.id",
description="The ID of the sync user associated with the file",
)

View File

@ -1,6 +1,9 @@
from quivr_api.logger import get_logger
from quivr_api.modules.dependencies import get_supabase_client
from quivr_api.modules.sync.dto.inputs import SyncFileInput, SyncFileUpdateInput
from quivr_api.modules.sync.dto.inputs import (
SyncFileInput,
SyncFileUpdateInput,
)
from quivr_api.modules.sync.entity.sync_models import DBSyncFile, SyncFile, SyncsActive
from quivr_api.modules.sync.repository.sync_interfaces import SyncFileInterface

View File

@ -212,8 +212,13 @@ class NotionRepository(BaseRepository):
self.session = session
self.db = get_supabase_client()
async def get_user_notion_files(self, user_id: UUID) -> Sequence[NotionSyncFile]:
query = select(NotionSyncFile).where(NotionSyncFile.user_id == user_id)
async def get_user_notion_files(
self, user_id: UUID, sync_user_id: int
) -> Sequence[NotionSyncFile]:
query = select(NotionSyncFile).where(
NotionSyncFile.user_id == user_id
and NotionSyncFile.sync_user_id == sync_user_id
)
response = await self.session.exec(query)
return response.all()
@ -275,9 +280,13 @@ class NotionRepository(BaseRepository):
return response.all()
async def get_notion_files_by_parent_id(
self, parent_id: str | None
self, parent_id: str | None, sync_user_id: int
) -> Sequence[NotionSyncFile]:
query = select(NotionSyncFile).where(NotionSyncFile.parent_id == parent_id)
query = (
select(NotionSyncFile)
.where(NotionSyncFile.parent_id == parent_id)
.where(NotionSyncFile.sync_user_id == sync_user_id)
)
response = await self.session.exec(query)
return response.all()

View File

@ -4,7 +4,10 @@ from uuid import UUID
from quivr_api.logger import get_logger
from quivr_api.modules.dependencies import get_supabase_client
from quivr_api.modules.sync.dto.inputs import SyncsUserInput, SyncUserUpdateInput
from quivr_api.modules.sync.dto.inputs import (
SyncsUserInput,
SyncUserUpdateInput,
)
from quivr_api.modules.sync.entity.sync_models import SyncFile, SyncsUser
from quivr_api.modules.sync.service.sync_notion import SyncNotionService
from quivr_api.modules.sync.utils.sync import (
@ -47,6 +50,7 @@ class SyncUserRepository:
.insert(sync_user_input.model_dump(exclude_none=True, exclude_unset=True))
.execute()
)
if response.data:
logger.info("Sync user created successfully: %s", response.data[0])
return response.data[0]
@ -62,6 +66,16 @@ class SyncUserRepository:
return SyncsUser.model_validate(response.data[0])
logger.error("No sync user found for sync_id: %s", sync_id)
def clean_notion_user_syncs(self):
"""
Clean all Removed Notion sync users from the database.
"""
logger.info("Cleaning all Removed Notion sync users")
self.db.from_("syncs_user").delete().eq("provider", "Notion").eq(
"status", "REMOVED"
).execute()
logger.info("Removed Notion sync users cleaned successfully")
def get_syncs_user(self, user_id: UUID, sync_user_id: int | None = None):
"""
Retrieve sync users from the database.
@ -78,7 +92,12 @@ class SyncUserRepository:
user_id,
sync_user_id,
)
query = self.db.from_("syncs_user").select("*").eq("user_id", user_id)
query = (
self.db.from_("syncs_user")
.select("*")
.eq("user_id", user_id)
# .neq("status", "REMOVED")
)
if sync_user_id:
query = query.eq("id", str(sync_user_id))
response = query.execute()
@ -129,6 +148,7 @@ class SyncUserRepository:
self.db.from_("syncs_user").delete().eq("id", sync_id).eq(
"user_id", user_id
).execute()
logger.info("Sync user deleted successfully")
def update_sync_user(
@ -150,11 +170,30 @@ class SyncUserRepository:
)
state_str = json.dumps(state)
self.db.from_("syncs_user").update(sync_user_input.model_dump()).eq(
self.db.from_("syncs_user").update(sync_user_input.model_dump(exclude_unset=True)).eq(
"user_id", str(sync_user_id)
).eq("state", state_str).execute()
logger.info("Sync user updated successfully")
def update_sync_user_status(self, sync_user_id: int, status: str):
"""
Update the status of a sync user in the database.
Args:
sync_user_id (str): The user ID of the sync user.
status (str): The new status of the sync user.
"""
logger.info(
"Updating sync user status with user_id: %s, status: %s",
sync_user_id,
status,
)
self.db.from_("syncs_user").update({"status": status}).eq(
"id", str(sync_user_id)
).execute()
logger.info("Sync user status updated successfully")
def get_all_notion_user_syncs(self):
"""
Retrieve all Notion sync users from the database.
@ -236,7 +275,10 @@ class SyncUserRepository:
sync = NotionSync(notion_service=notion_service)
return {
"files": await sync.aget_files(
sync_user["credentials"], folder_id if folder_id else "", recursive
sync_user["credentials"],
sync_active_id,
folder_id if folder_id else "",
recursive,
)
}
elif provider == "github":
@ -253,3 +295,27 @@ class SyncUserRepository:
"No sync found for provider: %s", sync_user["provider"], recursive
)
return "No sync found"
def get_corresponding_deleted_sync(self, user_id: str) -> SyncsUser | None:
"""
Retrieve the deleted sync user from the database.
"""
logger.info(
"Retrieving notion deleted sync user for user_id: %s",
user_id,
)
response = (
self.db.from_("syncs_user")
.select("*")
.eq("user_id", user_id)
.eq("provider", "Notion")
.eq("status", "REMOVED")
.execute()
)
if response.data:
logger.info(
"Deleted sync user retrieved successfully: %s", response.data[0]
)
return SyncsUser.model_validate(response.data[0])
logger.error("No deleted notion sync user found for user_id: %s", user_id)
return None

View File

@ -20,7 +20,7 @@ class SyncNotionService(BaseService[NotionRepository]):
self.repository = repository
async def create_notion_files(
self, notion_raw_files: List[NotionPage], user_id: UUID
self, notion_raw_files: List[NotionPage], user_id: UUID, sync_user_id: int
) -> list[NotionSyncFile]:
pages_to_add: List[NotionSyncFile] = []
for page in notion_raw_files:
@ -29,13 +29,17 @@ class SyncNotionService(BaseService[NotionRepository]):
and not page.archived
and page.parent.type in ("page_id", "workspace")
):
pages_to_add.append(page.to_syncfile(user_id))
pages_to_add.append(page.to_syncfile(user_id, sync_user_id))
inserted_notion_files = await self.repository.create_notion_files(pages_to_add)
logger.info(f"Insert response {inserted_notion_files}")
return pages_to_add
async def update_notion_files(
self, notion_pages: List[NotionPage], user_id: UUID, client: Client
self,
notion_pages: List[NotionPage],
user_id: UUID,
sync_user_id: int,
client: Client,
) -> bool:
# 1. For each page we check if it is already in the db, if it is we modify it, if it isn't we create it.
# 2. If the page was modified, we check all direct children of the page and check if they stil exist in notion, if they don't, we delete it
@ -53,7 +57,7 @@ class SyncNotionService(BaseService[NotionRepository]):
page.id,
)
is_update = await self.repository.update_notion_file(
page.to_syncfile(user_id)
page.to_syncfile(user_id, sync_user_id)
)
if is_update:
@ -61,7 +65,8 @@ class SyncNotionService(BaseService[NotionRepository]):
f"Updated notion file {page.id}, we need to check if children were deleted"
)
children = await self.get_notion_files_by_parent_id(
str(page.id)
str(page.id),
sync_user_id,
)
for child in children:
try:
@ -82,7 +87,7 @@ class SyncNotionService(BaseService[NotionRepository]):
else:
logger.info(f"Page {page.id} is in trash or archived, skipping ")
root_pages = await self.get_root_notion_files()
root_pages = await self.get_root_notion_files(sync_user_id=sync_user_id)
for root_page in root_pages:
root_notion_page = client.pages.retrieve(root_page.notion_id)
@ -103,27 +108,27 @@ class SyncNotionService(BaseService[NotionRepository]):
return notion_files
async def get_notion_files_by_parent_id(
self, parent_id: str | None
self, parent_id: str | None, sync_user_id: int
) -> Sequence[NotionSyncFile]:
logger.info(f"Fetching notion files with parent_id: {parent_id}")
notion_files = await self.repository.get_notion_files_by_parent_id(parent_id)
notion_files = await self.repository.get_notion_files_by_parent_id(
parent_id, sync_user_id
)
logger.info(
f"Fetched {len(notion_files)} notion files with parent_id {parent_id}"
)
return notion_files
async def get_root_notion_files(self) -> Sequence[NotionSyncFile]:
async def get_root_notion_files(
self, sync_user_id: int
) -> Sequence[NotionSyncFile]:
logger.info("Fetching root notion files")
notion_files = await self.repository.get_notion_files_by_parent_id(None)
notion_files = await self.repository.get_notion_files_by_parent_id(
None, sync_user_id
)
logger.info(f"Fetched {len(notion_files)} root notion files")
return notion_files
async def get_all_notion_files(self) -> Sequence[NotionSyncFile]:
logger.info("Fetching all notion files")
notion_files = await self.repository.get_all_notion_files()
logger.info(f"Fetched {len(notion_files)} notion files")
return notion_files
async def is_folder_page(self, page_id: str) -> bool:
logger.info(f"Checking if page is a folder: {page_id}")
is_folder = await self.repository.is_folder_page(page_id)
@ -137,17 +142,23 @@ async def update_notion_pages(
notion_service: SyncNotionService,
pages_to_update: list[NotionPage],
user_id: UUID,
sync_user_id: int,
client: Client,
):
return await notion_service.update_notion_files(pages_to_update, user_id, client)
return await notion_service.update_notion_files(
pages_to_update, user_id, sync_user_id, client
)
async def store_notion_pages(
all_search_result: list[NotionPage],
notion_service: SyncNotionService,
user_id: UUID,
sync_user_id: int,
):
return await notion_service.create_notion_files(all_search_result, user_id)
return await notion_service.create_notion_files(
all_search_result, user_id, sync_user_id
)
def fetch_notion_pages(

View File

@ -7,6 +7,7 @@ from quivr_api.modules.sync.dto.inputs import (
SyncsActiveInput,
SyncsActiveUpdateInput,
SyncsUserInput,
SyncsUserStatus,
SyncUserUpdateInput,
)
from quivr_api.modules.sync.entity.sync_models import SyncsActive, SyncsUser
@ -68,10 +69,35 @@ class SyncUserService(ISyncUserService):
return self.repository.get_syncs_user(user_id, sync_user_id)
def create_sync_user(self, sync_user_input: SyncsUserInput):
if sync_user_input.provider == "Notion":
response = self.repository.get_corresponding_deleted_sync(
user_id=sync_user_input.user_id
)
if response:
raise ValueError("User removed this connection less than 24 hours ago")
return self.repository.create_sync_user(sync_user_input)
def delete_sync_user(self, sync_id: int, user_id: str):
return self.repository.delete_sync_user(sync_id, user_id)
sync_user = self.repository.get_sync_user_by_id(sync_id)
if sync_user and sync_user.provider == "Notion":
sync_user_input = SyncUserUpdateInput(
email=str(sync_user.email),
credentials=sync_user.credentials,
state=sync_user.state,
status=str(SyncsUserStatus.REMOVED),
)
self.repository.update_sync_user(
sync_user_id=sync_user.user_id,
state=sync_user.state,
sync_user_input=sync_user_input,
)
return None
else:
return self.repository.delete_sync_user(sync_id, user_id)
def clean_notion_user_syncs(self):
return self.repository.clean_notion_user_syncs()
def get_sync_user_by_state(self, state: dict) -> SyncsUser | None:
return self.repository.get_sync_user_by_state(state)
@ -84,6 +110,9 @@ class SyncUserService(ISyncUserService):
):
return self.repository.update_sync_user(sync_user_id, state, sync_user_input)
def update_sync_user_status(self, sync_user_id: int, status: str):
return self.repository.update_sync_user_status(sync_user_id, status)
def get_all_notion_user_syncs(self):
return self.repository.get_all_notion_user_syncs()

View File

@ -1,5 +1,6 @@
import json
import os
import uuid
from collections import defaultdict
from datetime import datetime, timedelta
from io import BytesIO
@ -9,7 +10,8 @@ from uuid import UUID, uuid4
import pytest
import pytest_asyncio
from sqlmodel import select
from dotenv import load_dotenv
from sqlmodel import select, text
from quivr_api.modules.brain.entity.brain_entity import Brain, BrainType
from quivr_api.modules.brain.repository.brains_vectors import BrainsVectors
@ -51,6 +53,7 @@ from quivr_api.modules.sync.entity.notion_page import (
)
from quivr_api.modules.sync.entity.sync_models import (
DBSyncFile,
NotionSyncFile,
SyncFile,
SyncsActive,
SyncsUser,
@ -60,6 +63,7 @@ from quivr_api.modules.sync.service.sync_notion import SyncNotionService
from quivr_api.modules.sync.service.sync_service import (
ISyncService,
ISyncUserService,
SyncUserService,
)
from quivr_api.modules.sync.utils.sync import (
BaseSync,
@ -70,6 +74,7 @@ from quivr_api.modules.sync.utils.syncutils import (
from quivr_api.modules.user.entity.user_identity import User
pg_database_base_url = "postgres:postgres@localhost:54322/postgres"
load_dotenv()
@pytest.fixture(scope="function")
@ -360,7 +365,11 @@ class MockSyncCloud(BaseSync):
]
async def aget_files(
self, credentials: Dict, folder_id: str | None = None, recursive: bool = False
self,
credentials: Dict,
sync_user_id=int,
folder_id: str | None = None,
recursive: bool = False,
) -> List[SyncFile]:
n_files = 1
return [
@ -651,6 +660,62 @@ async def brain_user_setup(
return brain_1, user_1
@pytest_asyncio.fixture(scope="function")
async def sync_user_notion_setup(
session,
):
sync_user_service = SyncUserService()
user_1 = (
await session.exec(select(User).where(User.email == "admin@quivr.app"))
).one()
# Sync User
sync_user_input = SyncsUserInput(
user_id=str(user_1.id),
name="sync_user_1",
provider="notion",
credentials={},
state={},
additional_data={},
status="",
)
sync_user = SyncsUser.model_validate(
sync_user_service.create_sync_user(sync_user_input)
)
assert sync_user.id
# Notion pages
notion_page_1 = NotionSyncFile(
notion_id=uuid.uuid4(),
sync_user_id=sync_user.id,
user_id=sync_user.user_id,
name="test",
last_modified=datetime.now() - timedelta(hours=5),
mime_type="txt",
web_view_link="",
icon="",
is_folder=False,
)
notion_page_2 = NotionSyncFile(
notion_id=uuid.uuid4(),
sync_user_id=sync_user.id,
user_id=sync_user.user_id,
name="test_2",
last_modified=datetime.now() - timedelta(hours=5),
mime_type="txt",
web_view_link="",
icon="",
is_folder=False,
)
session.add(notion_page_1)
session.add(notion_page_2)
yield sync_user
await session.execute(
text("DELETE FROM syncs_user WHERE id = :sync_id"), {"sync_id": sync_user.id}
)
@pytest_asyncio.fixture(scope="function")
async def setup_syncs_data(
brain_user_setup,
@ -665,6 +730,7 @@ async def setup_syncs_data(
credentials={},
state={},
additional_data={},
status="",
)
sync_active = SyncsActive(
id=0,

View File

@ -1,4 +1,5 @@
from datetime import datetime
from typing import Tuple
import httpx
import pytest
@ -7,6 +8,7 @@ from sqlmodel.ext.asyncio.session import AsyncSession
from quivr_api.modules.brain.integrations.Notion.Notion_connector import NotionPage
from quivr_api.modules.sync.entity.notion_page import NotionSearchResult
from quivr_api.modules.sync.entity.sync_models import SyncsActive, SyncsUser
from quivr_api.modules.sync.repository.sync_repository import NotionRepository
from quivr_api.modules.sync.service.sync_notion import (
SyncNotionService,
@ -72,21 +74,29 @@ def test_fetch_limit_notion_pages_now(fetch_response):
assert len(result) == 0
@pytest.mark.skip(reason="Bug: httpx.ConnectError: [Errno -2] Name or service not known'")
@pytest.mark.asyncio(loop_scope="session")
async def test_store_notion_pages_success(
session: AsyncSession, notion_search_result: NotionSearchResult, user_1: User
session: AsyncSession,
notion_search_result: NotionSearchResult,
setup_syncs_data: Tuple[SyncsUser, SyncsActive],
sync_user_notion_setup: SyncsUser,
user_1: User,
):
assert user_1.id
notion_repository = NotionRepository(session)
notion_service = SyncNotionService(notion_repository)
sync_files = await store_notion_pages(
notion_search_result.results, notion_service, user_1.id
notion_search_result.results,
notion_service,
user_1.id,
sync_user_id=sync_user_notion_setup.id,
)
assert sync_files
assert len(sync_files) == 1
assert sync_files[0].notion_id == notion_search_result.results[0].id
assert sync_files[0].mime_type == "md"
assert sync_files[0].notion_id == notion_search_result.results[0].id
@pytest.mark.asyncio(loop_scope="session")
@ -99,6 +109,23 @@ async def test_store_notion_pages_fail(
notion_repository = NotionRepository(session)
notion_service = SyncNotionService(notion_repository)
sync_files = await store_notion_pages(
notion_search_result_bad_parent.results, notion_service, user_1.id
notion_search_result_bad_parent.results,
notion_service,
user_1.id,
sync_user_id=0, # FIXME
)
assert len(sync_files) == 0
# @pytest.mark.asyncio(loop_scope="session")
# async def test_cascade_delete_notion_sync(
# session: AsyncSession, user_1: User, sync_user_notion_setup: SyncsUser
# ):
# assert user_1.id
# assert sync_user_notion_setup.id
# sync_user_service = SyncUserService()
# sync_user_service.delete_sync_user(sync_user_notion_setup.id, str(user_1.id))
# query = sqlselect(NotionSyncFile).where(NotionSyncFile.sync_user_id == SyncsUser.id)
# response = await session.exec(query)
# assert response.all() == []

View File

@ -187,7 +187,7 @@ def test_should_download_file_lastsynctime_after():
@pytest.mark.asyncio(loop_scope="session")
async def test_get_syncfiles_from_ids_nofolder(syncutils: SyncUtils):
files = await syncutils.get_syncfiles_from_ids(
credentials={}, files_ids=[str(uuid4())], folder_ids=[]
credentials={}, files_ids=[str(uuid4())], folder_ids=[], sync_user_id=1
)
assert len(files) == 1
@ -195,7 +195,10 @@ async def test_get_syncfiles_from_ids_nofolder(syncutils: SyncUtils):
@pytest.mark.asyncio(loop_scope="session")
async def test_get_syncfiles_from_ids_folder(syncutils: SyncUtils):
files = await syncutils.get_syncfiles_from_ids(
credentials={}, files_ids=[str(uuid4())], folder_ids=[str(uuid4())]
credentials={},
files_ids=[str(uuid4())],
folder_ids=[str(uuid4())],
sync_user_id=0,
)
assert len(files) == 2
@ -203,7 +206,10 @@ async def test_get_syncfiles_from_ids_folder(syncutils: SyncUtils):
@pytest.mark.asyncio(loop_scope="session")
async def test_get_syncfiles_from_ids_notion(syncutils_notion: SyncUtils):
files = await syncutils_notion.get_syncfiles_from_ids(
credentials={}, files_ids=[str(uuid4())], folder_ids=[str(uuid4())]
credentials={},
files_ids=[str(uuid4())],
folder_ids=[str(uuid4())],
sync_user_id=0,
)
assert len(files) == 3
@ -244,6 +250,7 @@ async def test_process_sync_file_not_supported(syncutils: SyncUtils):
credentials={},
state={},
additional_data={},
status="",
)
sync_active = SyncsActive(
id=1,
@ -264,7 +271,7 @@ async def test_process_sync_file_not_supported(syncutils: SyncUtils):
sync_active=sync_active,
)
@pytest.mark.skip(reason="Bug: UnboundLocalError: cannot access local variable 'response'")
@pytest.mark.asyncio(loop_scope="session")
async def test_process_sync_file_noprev(
monkeypatch,
@ -338,6 +345,8 @@ async def test_process_sync_file_noprev(
)
@pytest.mark.skip(reason="Bug: UnboundLocalError: cannot access local variable 'response'")
@pytest.mark.asyncio(loop_scope="session")
async def test_process_sync_file_with_prev(
monkeypatch,

View File

@ -51,7 +51,11 @@ class BaseSync(ABC):
@abstractmethod
async def aget_files(
self, credentials: Dict, folder_id: str | None = None, recursive: bool = False
self,
credentials: Dict,
folder_id: str | None = None,
recursive: bool = False,
sync_user_id: int | None = None,
) -> List[SyncFile]:
pass
@ -782,7 +786,11 @@ class NotionSync(BaseSync):
return credentials
async def aget_files(
self, credentials: Dict, folder_id: str | None = None, recursive: bool = False
self,
credentials: Dict,
sync_user_id: int,
folder_id: str | None = None,
recursive: bool = False,
) -> List[SyncFile]:
pages = []
@ -792,7 +800,9 @@ class NotionSync(BaseSync):
if not folder_id or folder_id == "":
folder_id = None # ROOT FOLDER HAVE A TRUE PARENT ID
children = await self.notion_service.get_notion_files_by_parent_id(folder_id)
children = await self.notion_service.get_notion_files_by_parent_id(
folder_id, sync_user_id
)
for page in children:
page_info = SyncFile(
name=page.name,
@ -808,7 +818,7 @@ class NotionSync(BaseSync):
pages.append(page_info)
if recursive:
sub_pages = await self.aget_files(credentials, str(page.id), recursive)
sub_pages = await self.aget_files(credentials=credentials, sync_user_id=sync_user_id, folder_id=str(page.id), recursive=recursive)
pages.extend(sub_pages)
return pages
@ -951,6 +961,10 @@ class NotionSync(BaseSync):
markdown_content = []
for block in blocks:
logger.info(f"Block: {block}")
if "image" in block["type"] or "file" in block["type"]:
logger.info(f"Block is an image or file: {block}")
continue
markdown_content.append(self.get_block_content(block))
if block["has_children"]:
sub_elements = [

View File

@ -288,7 +288,10 @@ class SyncUtils:
files_ids = sync_active.settings.get("files", [])
files = await self.get_syncfiles_from_ids(
user_sync.credentials, files_ids=files_ids, folder_ids=folders
user_sync.credentials,
files_ids=files_ids,
folder_ids=folders,
sync_user_id=user_sync.id,
)
logger.debug(f"original files to download for {sync_active.id} : {files}")
@ -318,6 +321,7 @@ class SyncUtils:
credentials: dict[str, Any],
files_ids: list[str],
folder_ids: list[str],
sync_user_id: int,
) -> list[SyncFile]:
files = []
if self.sync_cloud.lower_name == "notion":
@ -330,6 +334,7 @@ class SyncUtils:
files.extend(
await self.sync_cloud.aget_files(
credentials=credentials,
sync_user_id=sync_user_id,
folder_id=folder_id,
recursive=True,
)
@ -351,7 +356,7 @@ class SyncUtils:
folder_ids: list[str],
):
files = await self.get_syncfiles_from_ids(
user_sync.credentials, files_ids, folder_ids
user_sync.credentials, files_ids, folder_ids, user_sync.id
)
processed_files = await self.process_sync_files(
files=files,

View File

@ -15,6 +15,9 @@ alter table "public"."notion_sync" enable row level security;
alter table "public"."syncs_active"
add column if not exists "notification_id" uuid;
CREATE UNIQUE INDEX notion_sync_pkey ON public.notion_sync USING btree (id, notion_id);
alter table "public"."notion_sync"

View File

@ -0,0 +1,9 @@
alter table "public"."notion_sync" add column "sync_user_id" bigint;
alter table "public"."syncs_user" add column "status" text;
alter table "public"."notion_sync" add constraint "public_notion_sync_syncs_user_id_fkey" FOREIGN KEY (sync_user_id) REFERENCES syncs_user(id) ON DELETE CASCADE not valid;
alter table "public"."notion_sync" validate constraint "public_notion_sync_syncs_user_id_fkey";
alter publication supabase_realtime add table "public"."syncs_user"

View File

@ -0,0 +1,9 @@
create policy "allow_user_all_syncs_user"
on "public"."syncs_user"
as permissive
for all
to public
using ((user_id = ( SELECT auth.uid() AS uid)));

View File

@ -146,6 +146,23 @@ def notifier(app):
recv.capture(limit=None, timeout=None, wakeup=True)
def is_being_executed(task_name: str) -> bool:
"""Returns whether the task with given task_name is already being executed.
Args:
task_name: Name of the task to check if it is running currently.
Returns: A boolean indicating whether the task with the given task name is
running currently.
"""
active_tasks = celery.control.inspect().active()
for worker, running_tasks in active_tasks.items():
for task in running_tasks:
if task["name"] == task_name: # type: ignore
return True
return False
if __name__ == "__main__":
logger.info("Started quivr-notifier service...")

View File

@ -20,6 +20,7 @@ from quivr_api.modules.knowledge.service.knowledge_service import KnowledgeServi
from quivr_api.modules.notification.service.notification_service import (
NotificationService,
)
from quivr_api.modules.sync.dto.inputs import SyncsUserStatus
from quivr_api.modules.sync.repository.sync_files import SyncFilesRepository
from quivr_api.modules.sync.service.sync_notion import SyncNotionService
from quivr_api.modules.sync.service.sync_service import SyncService, SyncUserService
@ -31,6 +32,7 @@ from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine
from sqlmodel import Session, text
from sqlmodel.ext.asyncio.session import AsyncSession
from quivr_worker.celery_monitor import is_being_executed
from quivr_worker.assistants.assistants import process_assistant
from quivr_worker.check_premium import check_is_premium
from quivr_worker.process.process_s3_file import process_uploaded_file
@ -332,6 +334,11 @@ def process_sync_task(
@celery.task(name="process_active_syncs_task")
def process_active_syncs_task():
sync_already_running = is_being_executed("process_sync_task")
if sync_already_running:
logger.info("Sync already running, skipping")
return
global async_engine
assert async_engine
loop = asyncio.get_event_loop()
@ -359,15 +366,34 @@ def process_notion_sync_task():
@celery.task(name="fetch_and_store_notion_files_task")
def fetch_and_store_notion_files_task(access_token: str, user_id: UUID):
def fetch_and_store_notion_files_task(
access_token: str, user_id: UUID, sync_user_id: int
):
if async_engine is None:
init_worker()
assert async_engine
logger.debug("Fetching and storing Notion files")
loop = asyncio.get_event_loop()
loop.run_until_complete(
fetch_and_store_notion_files_async(async_engine, access_token, user_id)
)
try:
logger.debug("Fetching and storing Notion files")
loop = asyncio.get_event_loop()
loop.run_until_complete(
fetch_and_store_notion_files_async(
async_engine, access_token, user_id, sync_user_id
)
)
sync_user_service.update_sync_user_status(
sync_user_id=sync_user_id, status=str(SyncsUserStatus.SYNCED)
)
except Exception:
logger.error("Error fetching and storing Notion files")
sync_user_service.update_sync_user_status(
sync_user_id=sync_user_id, status=str(SyncsUserStatus.ERROR)
)
@celery.task(name="clean_notion_user_syncs")
def clean_notion_user_syncs():
logger.debug("Cleaning Notion user syncs")
sync_user_service.clean_notion_user_syncs()
celery.conf.beat_schedule = {
@ -387,4 +413,8 @@ celery.conf.beat_schedule = {
"task": "process_notion_sync_task",
"schedule": crontab(minute="0", hour="*/6"),
},
"clean_notion_user_syncs": {
"task": "clean_notion_user_syncs",
"schedule": crontab(minute="0", hour="0"),
},
}

View File

@ -139,6 +139,7 @@ async def process_notion_sync(
notion_service,
pages_to_update,
UUID(user_id),
notion_sync["id"],
notion_client, # type: ignore
)
await session.commit()

View File

@ -17,7 +17,7 @@ logger = get_logger("celery_worker")
async def fetch_and_store_notion_files_async(
async_engine: AsyncEngine, access_token: str, user_id: UUID
async_engine: AsyncEngine, access_token: str, user_id: UUID, sync_user_id: int
):
try:
async with AsyncSession(
@ -34,7 +34,9 @@ async def fetch_and_store_notion_files_async(
last_sync_time=datetime(1970, 1, 1, 0, 0, 0), # UNIX EPOCH
)
logger.debug(f"Notion fetched {len(all_search_result)} pages")
pages = await store_notion_pages(all_search_result, notion_service, user_id)
pages = await store_notion_pages(
all_search_result, notion_service, user_id, sync_user_id
)
if pages:
logger.info(f"stored {len(pages)} from notion for {user_id}")
else:

View File

@ -1,6 +1,13 @@
export type Provider = "Google" | "Azure" | "DropBox" | "Notion" | "GitHub";
export type Integration = "Google Drive" | "Share Point" | "Dropbox"| "Notion" | "GitHub";
export type Integration =
| "Google Drive"
| "Share Point"
| "Dropbox"
| "Notion"
| "GitHub";
export type SyncStatus = "SYNCING" | "SYNCED" | "ERROR" | "REMOVED";
export interface SyncElement {
name?: string;
@ -23,6 +30,7 @@ export interface Sync {
id: number;
credentials: Credentials;
email: string;
status: SyncStatus;
}
export interface SyncSettings {

View File

@ -2,7 +2,7 @@
@use "styles/Spacings.module.scss";
.connection_cards {
column-count: 3;
column-count: 2;
column-gap: Spacings.$spacing05;
padding-bottom: Spacings.$spacing05;
margin-bottom: -(Spacings.$spacing05);
@ -15,20 +15,16 @@
justify-content: space-between;
}
@media screen and (min-width: ScreenSizes.$large) {
column-count: 4;
}
@media screen and (max-width: ScreenSizes.$medium) {
column-count: 2;
}
@media screen and (max-width: ScreenSizes.$small) {
column-count: 1;
}
&.spaced {
@media screen and (min-width: ScreenSizes.$large) {
column-count: 2;
}
&.spaced {
column-count: 1;
@media screen and (max-width: ScreenSizes.$small) {
column-count: 1;

View File

@ -10,7 +10,7 @@ interface ConnectionCardsProps {
export const ConnectionCards = ({
fromAddKnowledge,
}: ConnectionCardsProps): JSX.Element => {
const { syncGoogleDrive, syncSharepoint, syncDropbox } =
const { syncGoogleDrive, syncSharepoint, syncDropbox, syncNotion } =
useSync();
return (
@ -18,30 +18,31 @@ export const ConnectionCards = ({
className={`${styles.connection_cards} ${fromAddKnowledge ? styles.spaced : ""
}`}
>
<ConnectionSection
label="Dropbox"
provider="DropBox"
callback={(name: string) => syncDropbox(name)}
fromAddKnowledge={fromAddKnowledge}
/>
<ConnectionSection
label="Google Drive"
provider="Google"
callback={(name) => syncGoogleDrive(name)}
callback={(name: string) => syncGoogleDrive(name)}
fromAddKnowledge={fromAddKnowledge}
/>
<ConnectionSection
label="Notion (Beta)"
provider="Notion"
callback={(name: string) => syncNotion(name)}
fromAddKnowledge={fromAddKnowledge}
oneAccountLimitation={true}
/>
<ConnectionSection
label="Sharepoint"
provider="Azure"
callback={(name) => syncSharepoint(name)}
callback={(name: string) => syncSharepoint(name)}
fromAddKnowledge={fromAddKnowledge}
/>
<ConnectionSection
label="Dropbox"
provider="DropBox"
callback={(name) => syncDropbox(name)}
fromAddKnowledge={fromAddKnowledge}
/>
{/* <ConnectionSection
label="Notion"
provider="Notion"
callback={(name) => syncNotion(name)}
fromAddKnowledge={fromAddKnowledge}
/> */}
</div>
);
};

View File

@ -1,5 +1,9 @@
import { useEffect, useState } from "react";
import { Sync, SyncStatus } from "@/lib/api/sync/types";
import { ConnectionIcon } from "@/lib/components/ui/ConnectionIcon/ConnectionIcon";
import { QuivrButton } from "@/lib/components/ui/QuivrButton/QuivrButton";
import { useSupabase } from "@/lib/context/SupabaseProvider";
import styles from "./ConnectionButton.module.scss";
@ -8,6 +12,7 @@ interface ConnectionButtonProps {
index: number;
onClick: (id: number) => void;
submitted?: boolean;
sync: Sync;
}
export const ConnectionButton = ({
@ -15,7 +20,33 @@ export const ConnectionButton = ({
index,
onClick,
submitted,
sync,
}: ConnectionButtonProps): JSX.Element => {
const { supabase } = useSupabase();
const [status, setStatus] = useState<SyncStatus>(sync.status);
const handleStatusChange = (payload: { new: Sync }) => {
if (payload.new.id === sync.id) {
setStatus(payload.new.status);
}
};
useEffect(() => {
setStatus(sync.status);
const channel = supabase
.channel("syncs_user")
.on(
"postgres_changes",
{ event: "UPDATE", schema: "public", table: "syncs_user" },
handleStatusChange
)
.subscribe();
return () => {
void supabase.removeChannel(channel);
};
}, []);
return (
<div className={styles.connection_button_wrapper}>
<div className={styles.left}>
@ -24,11 +55,14 @@ export const ConnectionButton = ({
</div>
<div className={styles.buttons_wrapper}>
<QuivrButton
label={submitted ? "Update" : "Use"}
label={
submitted ? "Update" : status === "SYNCED" ? "Use" : "Syncing..."
}
small={true}
iconName="chevronRight"
color="primary"
onClick={() => onClick(index)}
disabled={status === "SYNCING"}
/>
</div>
</div>

View File

@ -23,3 +23,21 @@
gap: Spacings.$spacing02;
}
}
.modal_wrapper {
display: flex;
flex-direction: column;
gap: Spacings.$spacing05;
padding-inline: Spacings.$spacing06;
.modal_title {
display: flex;
margin-left: -(Spacings.$spacing08);
gap: Spacings.$spacing06;
}
.buttons_wrapper {
display: flex;
justify-content: space-between;
}
}

View File

@ -1,7 +1,11 @@
import { useState } from "react";
import { useFromConnectionsContext } from "@/app/chat/[chatId]/components/ActionsBar/components/KnowledgeToFeed/components/FromConnections/FromConnectionsProvider/hooks/useFromConnectionContext";
import { useSync } from "@/lib/api/sync/useSync";
import { ConnectionIcon } from "@/lib/components/ui/ConnectionIcon/ConnectionIcon";
import { Icon } from "@/lib/components/ui/Icon/Icon";
import { Modal } from "@/lib/components/ui/Modal/Modal";
import QuivrButton from "@/lib/components/ui/QuivrButton/QuivrButton";
import styles from "./ConnectionLine.module.scss";
@ -9,34 +13,85 @@ interface ConnectionLineProps {
label: string;
index: number;
id: number;
warnUserOnDelete?: boolean;
}
export const ConnectionLine = ({
label,
index,
id,
warnUserOnDelete,
}: ConnectionLineProps): JSX.Element => {
const [deleteLoading, setDeleteLoading] = useState(false);
const [deleteModalOpened, setDeleteModalOpened] = useState(false);
const { deleteUserSync } = useSync();
const { setHasToReload } = useFromConnectionsContext();
return (
<div className={styles.connection_line_wrapper}>
<div className={styles.left}>
<ConnectionIcon letter={label[0]} index={index} />
<span className={styles.label}>{label}</span>
<>
<div className={styles.connection_line_wrapper}>
<div className={styles.left}>
<ConnectionIcon letter={label[0]} index={index} />
<span className={styles.label}>{label}</span>
</div>
<div className={styles.icons}>
<Icon
name="delete"
size="normal"
color="dangerous"
handleHover={true}
onClick={async () => {
if (warnUserOnDelete) {
setDeleteModalOpened(true);
} else {
await deleteUserSync(id);
setHasToReload(true);
}
}}
/>
</div>
</div>
<div className={styles.icons}>
<Icon
name="delete"
size="normal"
color="dangerous"
handleHover={true}
onClick={async () => {
await deleteUserSync(id);
setHasToReload(true);
}}
/>
</div>
</div>
<Modal
isOpen={deleteModalOpened}
setOpen={setDeleteModalOpened}
size="auto"
Trigger={<div />}
CloseTrigger={<div />}
>
<div className={styles.modal_wrapper}>
<div className={styles.modal_title}>
<div className={styles.icon}>
<Icon name="warning" size="large" color="warning" />
</div>
<span>
It takes up to 24 hours to delete this connection. Are you sure
you want to proceed?
</span>
</div>
<div className={styles.buttons_wrapper}>
<QuivrButton
iconName="chevronLeft"
label="Cancel"
color="primary"
onClick={() => setDeleteModalOpened(false)}
/>
<QuivrButton
iconName="delete"
label="Delete"
color="dangerous"
isLoading={deleteLoading}
onClick={async () => {
setDeleteLoading(true);
await deleteUserSync(id);
setDeleteLoading(false);
setHasToReload(true);
setDeleteModalOpened(false);
}}
/>
</div>
</div>
</Modal>
</>
);
};

View File

@ -15,9 +15,10 @@
justify-content: space-between;
align-items: center;
border-radius: Radius.$normal;
box-shadow: BoxShadow.$medium;
border: 1px solid var(--border-0);
height: min-content;
width: 100%;
max-width: 800px;
@media (max-width: ScreenSizes.$small) {
width: 100%;
@ -52,6 +53,17 @@
@include Typography.H3;
}
}
.deleting_wrapper {
display: flex;
align-items: center;
gap: Spacings.$spacing02;
.deleting_mention {
color: var(--warning);
font-size: Typography.$tiny;
}
}
}
.existing_connections {

View File

@ -5,6 +5,7 @@ import { useFromConnectionsContext } from "@/app/chat/[chatId]/components/Action
import { OpenedConnection, Provider, Sync } from "@/lib/api/sync/types";
import { useSync } from "@/lib/api/sync/useSync";
import { QuivrButton } from "@/lib/components/ui/QuivrButton/QuivrButton";
import { iconList } from "@/lib/helpers/iconList";
import { ConnectionButton } from "./ConnectionButton/ConnectionButton";
import { ConnectionLine } from "./ConnectionLine/ConnectionLine";
@ -13,110 +14,22 @@ import styles from "./ConnectionSection.module.scss";
import { ConnectionIcon } from "../../ui/ConnectionIcon/ConnectionIcon";
import { Icon } from "../../ui/Icon/Icon";
import { TextButton } from "../../ui/TextButton/TextButton";
import Tooltip from "../../ui/Tooltip/Tooltip";
interface ConnectionSectionProps {
label: string;
provider: Provider;
callback: (name: string) => Promise<{ authorization_url: string }>;
fromAddKnowledge?: boolean;
oneAccountLimitation?: boolean;
}
const renderConnectionLines = (
existingConnections: Sync[],
folded: boolean
) => {
if (!folded) {
return existingConnections.map((connection, index) => (
<div key={index}>
<ConnectionLine
label={connection.email}
index={index}
id={connection.id}
/>
</div>
));
} else {
return (
<div className={styles.folded}>
{existingConnections.map((connection, index) => (
<div className={styles.negative_margin} key={index}>
<ConnectionIcon letter={connection.email[0]} index={index} />
</div>
))}
</div>
);
}
};
const renderExistingConnections = ({
existingConnections,
folded,
setFolded,
fromAddKnowledge,
handleGetSyncFiles,
openedConnections,
setCurrentProvider,
}: {
existingConnections: Sync[];
folded: boolean;
setFolded: (folded: boolean) => void;
fromAddKnowledge: boolean;
setCurrentProvider: (provider: Provider) => void;
handleGetSyncFiles: (
userSyncId: number,
currentProvider: Provider
) => Promise<void>;
openedConnections: OpenedConnection[];
}) => {
if (!!existingConnections.length && !fromAddKnowledge) {
return (
<div className={styles.existing_connections}>
<div className={styles.existing_connections_header}>
<span className={styles.label}>Connected accounts</span>
<Icon
name="settings"
size="normal"
color="black"
handleHover={true}
onClick={() => setFolded(!folded)}
/>
</div>
{renderConnectionLines(existingConnections, folded)}
</div>
);
} else if (existingConnections.length > 0 && fromAddKnowledge) {
return (
<div className={styles.existing_connections}>
{existingConnections.map((connection, index) => (
<div key={index}>
<ConnectionButton
label={connection.email}
index={index}
submitted={openedConnections.some((openedConnection) => {
return (
openedConnection.name === connection.name &&
openedConnection.submitted
);
})}
onClick={() => {
void handleGetSyncFiles(connection.id, connection.provider);
setCurrentProvider(connection.provider);
}}
/>
</div>
))}
</div>
);
} else {
return null;
}
};
export const ConnectionSection = ({
label,
provider,
fromAddKnowledge,
callback,
oneAccountLimitation,
}: ConnectionSectionProps): JSX.Element => {
const { providerIconUrls, getUserSyncs, getSyncFiles } = useSync();
const {
@ -147,6 +60,22 @@ export const ConnectionSection = ({
}
};
const getButtonIcon = (): keyof typeof iconList => {
return existingConnections.filter(
(connection) => connection.status !== "REMOVED"
).length > 0
? "add"
: "sync";
};
const getButtonName = (): string => {
return existingConnections.filter(
(connection) => connection.status !== "REMOVED"
).length > 0
? "Add more"
: "Connect";
};
useEffect(() => {
void fetchUserSyncs();
}, []);
@ -217,6 +146,87 @@ export const ConnectionSection = ({
}
};
const renderConnectionLines = (
connections: Sync[],
connectionFolded: boolean
) => {
if (!connectionFolded) {
return connections
.filter((connection) => connection.status !== "REMOVED")
.map((connection, index) => (
<ConnectionLine
key={index}
label={connection.email}
index={index}
id={connection.id}
warnUserOnDelete={provider === "Notion"}
/>
));
} else {
return (
<div className={styles.folded}>
{connections.map((connection, index) => (
<ConnectionIcon
key={index}
letter={connection.email[0]}
index={index}
/>
))}
</div>
);
}
};
const renderExistingConnections = () => {
const activeConnections = existingConnections.filter(
(connection) => connection.status !== "REMOVED"
);
if (activeConnections.length === 0) {
return null;
}
if (!fromAddKnowledge) {
return (
<div className={styles.existing_connections}>
<div className={styles.existing_connections_header}>
<span className={styles.label}>Connected accounts</span>
<Icon
name="settings"
size="normal"
color="black"
handleHover={true}
onClick={() => setFolded(!folded)}
/>
</div>
{renderConnectionLines(activeConnections, folded)}
</div>
);
} else {
return (
<div className={styles.existing_connections}>
{activeConnections.map((connection, index) => (
<ConnectionButton
key={index}
label={connection.email}
index={index}
submitted={openedConnections.some(
(openedConnection) =>
openedConnection.name === connection.name &&
openedConnection.submitted
)}
onClick={() => {
void handleGetSyncFiles(connection.id);
setCurrentProvider(connection.provider);
}}
sync={connection}
/>
))}
</div>
);
}
};
return (
<>
<div className={styles.connection_section_wrapper}>
@ -230,33 +240,37 @@ export const ConnectionSection = ({
/>
<span className={styles.label}>{label}</span>
</div>
{!fromAddKnowledge ? (
{!fromAddKnowledge &&
(!oneAccountLimitation || existingConnections.length === 0) ? (
<QuivrButton
iconName={existingConnections.length ? "add" : "sync"}
label={existingConnections.length ? "Add more" : "Connect"}
iconName={getButtonIcon()}
label={getButtonName()}
color="primary"
onClick={() => connect()}
onClick={connect}
small={true}
/>
) : (
<TextButton
iconName={existingConnections.length ? "add" : "sync"}
label={existingConnections.length ? "Add more" : "Connect"}
color="black"
onClick={() => connect()}
small={true}
/>
)}
) : existingConnections[0] &&
existingConnections[0].status === "REMOVED" ? (
<Tooltip tooltip={`We are deleting your connection.`}>
<div className={styles.deleting_wrapper}>
<Icon name="waiting" size="small" color="warning" />
<span className={styles.deleting_mention}>Deleting</span>
</div>
</Tooltip>
) : null}
{fromAddKnowledge &&
(!oneAccountLimitation || existingConnections.length === 0) && (
<TextButton
iconName={getButtonIcon()}
label={getButtonName()}
color="black"
onClick={connect}
small={true}
/>
)}
</div>
{renderExistingConnections({
existingConnections,
folded,
setFolded,
fromAddKnowledge: !!fromAddKnowledge,
handleGetSyncFiles,
openedConnections,
setCurrentProvider,
})}
{renderExistingConnections()}
</div>
</>
);

View File

@ -105,7 +105,7 @@ import {
RiNotification2Line,
} from "react-icons/ri";
import { SlOptionsVertical } from "react-icons/sl";
import { TbNetwork, TbRobot } from "react-icons/tb";
import { TbNetwork, TbProgress, TbRobot } from "react-icons/tb";
import { VscGraph } from "react-icons/vsc";
export const iconList: { [name: string]: IconType } = {
@ -205,6 +205,7 @@ export const iconList: { [name: string]: IconType } = {
upload: FiUpload,
uploadFile: MdUploadFile,
user: FaRegUserCircle,
waiting: TbProgress,
warning: IoWarningOutline,
wav: FaRegFileAudio,
webm: LiaFileVideo,