import asyncio
import os
from uuid import UUID

from celery.schedules import crontab
from celery.signals import worker_process_init
from dotenv import load_dotenv
from quivr_api.celery_config import celery
from quivr_api.logger import get_logger
from quivr_api.models.settings import settings
from quivr_api.modules.brain.integrations.Notion.Notion_connector import (
    NotionConnector,
)
from quivr_api.modules.brain.repository.brains_vectors import BrainsVectors
from quivr_api.modules.brain.service.brain_service import BrainService
from quivr_api.modules.dependencies import get_supabase_client
from quivr_api.modules.knowledge.repository.knowledges import KnowledgeRepository
from quivr_api.modules.knowledge.repository.storage import Storage
from quivr_api.modules.knowledge.service.knowledge_service import KnowledgeService
from quivr_api.modules.notification.service.notification_service import (
    NotificationService,
)
from quivr_api.modules.sync.repository.sync_files import SyncFilesRepository
from quivr_api.modules.sync.service.sync_notion import SyncNotionService
from quivr_api.modules.sync.service.sync_service import SyncService, SyncUserService
from quivr_api.modules.vector.repository.vectors_repository import VectorRepository
from quivr_api.modules.vector.service.vector_service import VectorService
from quivr_api.utils.telemetry import maybe_send_telemetry
from sqlalchemy import Engine, create_engine
from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine
from sqlmodel import Session
from sqlmodel.ext.asyncio.session import AsyncSession

from quivr_worker.check_premium import check_is_premium
from quivr_worker.process.process_s3_file import process_uploaded_file
from quivr_worker.process.process_url import process_url_func
from quivr_worker.syncs.process_active_syncs import (
    SyncServices,
    process_all_active_syncs,
    process_notion_sync,
    process_sync,
)
from quivr_worker.syncs.store_notion import fetch_and_store_notion_files_async
from quivr_worker.utils import _patch_json

load_dotenv()

get_logger("quivr_core")
logger = get_logger("celery_worker")
_patch_json()

# FIXME: load at init time
# Services
supabase_client = get_supabase_client()
# document_vector_store = get_documents_vector_store()
notification_service = NotificationService()
sync_active_service = SyncService()
sync_user_service = SyncUserService()
sync_files_repo_service = SyncFilesRepository()
brain_service = BrainService()
brain_vectors = BrainsVectors()
storage = Storage()
notion_service: SyncNotionService | None = None
async_engine: AsyncEngine | None = None
engine: Engine | None = None


@worker_process_init.connect
def init_worker(**kwargs):
    global async_engine
    global engine
    if not async_engine:
        async_engine = create_async_engine(
            settings.pg_database_async_url,
            echo=True if os.getenv("ORM_DEBUG") else False,
            future=True,
            # NOTE: pessimistic bound on
            pool_pre_ping=True,
            pool_size=10,  # NOTE: no bouncer for now, if 6 process workers => 6
            pool_recycle=1800,
        )
    if not engine:
        engine = create_engine(
            settings.pg_database_url,
            echo=True if os.getenv("ORM_DEBUG") else False,
            future=True,
            # NOTE: pessimistic bound on
            pool_pre_ping=True,
            pool_size=10,  # NOTE: no bouncer for now, if 6 process workers => 6
            pool_recycle=1800,
        )
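
# Illustrative sketch (commented out, not part of the worker): how the task
# bodies below obtain sessions from the two engines initialized above. The
# function name is hypothetical; it assumes init_worker() has already run in
# this process, which the worker_process_init signal provides on a normal
# Celery startup.
#
#   async def _example_session_usage():
#       assert engine is not None and async_engine is not None
#       async with AsyncSession(async_engine) as async_session:
#           ...  # async repositories (e.g. KnowledgeRepository) take this session
#       with Session(engine, expire_on_commit=False) as session:
#           ...  # sync repositories (e.g. VectorRepository) take this one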


@celery.task(
    retries=3,
    default_retry_delay=1,
    name="process_file_task",
    autoretry_for=(Exception,),
    dont_autoretry_for=(FileExistsError,),
)
def process_file_task(
    file_name: str,
    file_original_name: str,
    brain_id: UUID,
    notification_id: UUID,
    knowledge_id: UUID,
    source: str | None = None,
    source_link: str | None = None,
    delete_file: bool = False,
):
    if async_engine is None:
        init_worker()
    logger.info(
        f"Task process_file started for file_name={file_name}, knowledge_id={knowledge_id}, brain_id={brain_id}, notification_id={notification_id}"
    )
    loop = asyncio.get_event_loop()
    loop.run_until_complete(
        aprocess_file_task(
            file_name=file_name,
            file_original_name=file_original_name,
            brain_id=brain_id,
            notification_id=notification_id,
            knowledge_id=knowledge_id,
            source=source,
            source_link=source_link,
            delete_file=delete_file,
        )
    )


async def aprocess_file_task(
    file_name: str,
    file_original_name: str,
    brain_id: UUID,
    notification_id: UUID,
    knowledge_id: UUID,
    source: str | None = None,
    source_link: str | None = None,
    delete_file: bool = False,
):
    global engine
    assert engine
    assert async_engine  # used below alongside the sync engine
    async with AsyncSession(async_engine) as async_session:
        with Session(engine, expire_on_commit=False, autoflush=False) as session:
            vector_repository = VectorRepository(session)
            vector_service = VectorService(
                vector_repository
            )  # FIXME @amine: fix to need AsyncSession in vector Service
            knowledge_repository = KnowledgeRepository(async_session)
            knowledge_service = KnowledgeService(knowledge_repository)
            await process_uploaded_file(
                supabase_client=supabase_client,
                brain_service=brain_service,
                vector_service=vector_service,
                knowledge_service=knowledge_service,
                file_name=file_name,
                brain_id=brain_id,
                file_original_name=file_original_name,
                knowledge_id=knowledge_id,
                integration=source,
                integration_link=source_link,
                delete_file=delete_file,
            )


@celery.task(
    retries=3,
    default_retry_delay=1,
    name="process_crawl_task",
    autoretry_for=(Exception,),
)
def process_crawl_task(
    crawl_website_url: str,
    brain_id: UUID,
    knowledge_id: UUID,
    notification_id: UUID | None = None,
):
    logger.info(
        f"Task process_crawl_task started for url={crawl_website_url}, knowledge_id={knowledge_id}, brain_id={brain_id}, notification_id={notification_id}"
    )
    global engine
    assert engine
    with Session(engine, expire_on_commit=False, autoflush=False) as session:
        vector_repository = VectorRepository(session)
        vector_service = VectorService(vector_repository)
        loop = asyncio.get_event_loop()
        loop.run_until_complete(
            process_url_func(
                url=crawl_website_url,
                brain_id=brain_id,
                knowledge_id=knowledge_id,
                brain_service=brain_service,
                vector_service=vector_service,
            )
        )


@celery.task(name="NotionConnectorLoad")
def process_integration_brain_created_initial_load(brain_id, user_id):
    notion_connector = NotionConnector(brain_id=brain_id, user_id=user_id)
    pages = notion_connector.load()
    logger.info("Notion pages: %s", len(pages))


@celery.task
def process_integration_brain_sync_user_brain(brain_id, user_id):
    notion_connector = NotionConnector(brain_id=brain_id, user_id=user_id)
    notion_connector.poll()


@celery.task
def ping_telemetry():
    maybe_send_telemetry("ping", {"ping": "pong"})


@celery.task(name="check_is_premium_task")
def check_is_premium_task():
    check_is_premium(supabase_client)


@celery.task(name="process_sync_task")
def process_sync_task(
    sync_id: int, user_id: str, files_ids: list[str], folder_ids: list[str]
):
    global async_engine
    assert async_engine
    sync = next(
        filter(lambda s: s.id == sync_id, sync_active_service.get_syncs_active(user_id))
    )
    loop = asyncio.get_event_loop()
    loop.run_until_complete(
        process_sync(
            sync=sync,
            files_ids=files_ids,
            folder_ids=folder_ids,
            services=SyncServices(
                async_engine=async_engine,
                sync_active_service=sync_active_service,
                sync_user_service=sync_user_service,
                sync_files_repo_service=sync_files_repo_service,
                storage=storage,
                brain_vectors=brain_vectors,
                notification_service=notification_service,
            ),
        )
    )
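
# Illustrative example (commented out): how a producer could enqueue the sync
# task defined above by name. celery.send_task is standard Celery; the id
# values here are placeholders, not real data.
#
#   from quivr_api.celery_config import celery
#
#   celery.send_task(
#       "process_sync_task",
#       kwargs={
#           "sync_id": 1,
#           "user_id": "00000000-0000-0000-0000-000000000000",
#           "files_ids": [],
#           "folder_ids": [],
#       },
#   )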


@celery.task(name="process_active_syncs_task")
def process_active_syncs_task():
    global async_engine
    assert async_engine
    loop = asyncio.get_event_loop()
    loop.run_until_complete(
        process_all_active_syncs(
            SyncServices(
                async_engine=async_engine,
                sync_active_service=sync_active_service,
                sync_user_service=sync_user_service,
                sync_files_repo_service=sync_files_repo_service,
                storage=storage,
                brain_vectors=brain_vectors,
                notification_service=notification_service,
            ),
        )
    )


@celery.task(name="process_notion_sync_task")
def process_notion_sync_task():
    global async_engine
    assert async_engine
    loop = asyncio.get_event_loop()
    loop.run_until_complete(process_notion_sync(async_engine))


@celery.task(name="fetch_and_store_notion_files_task")
def fetch_and_store_notion_files_task(access_token: str, user_id: UUID):
    if async_engine is None:
        init_worker()
    assert async_engine
    logger.debug("Fetching and storing Notion files")
    loop = asyncio.get_event_loop()
    loop.run_until_complete(
        fetch_and_store_notion_files_async(async_engine, access_token, user_id)
    )


celery.conf.beat_schedule = {
    "ping_telemetry": {
        "task": f"{__name__}.ping_telemetry",
        "schedule": crontab(minute="*/30", hour="*"),
    },
    "process_active_syncs": {
        "task": "process_active_syncs_task",
        "schedule": crontab(minute="*/1", hour="*"),
    },
    "process_premium_users": {
        "task": "check_is_premium_task",
        "schedule": crontab(minute="*/1", hour="*"),
    },
    "process_notion_sync": {
        "task": "process_notion_sync_task",
        "schedule": crontab(minute="0", hour="*/6"),
    },
}
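
# Illustrative (commented out): the beat schedule above only fires when a beat
# process runs alongside the workers. The "-A" module path below is an
# assumption based on this package's layout, not taken from the source.
#
#   celery -A quivr_worker.celery_worker worker --loglevel=info
#   celery -A quivr_worker.celery_worker beat --loglevel=info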