mirror of
https://github.com/QuivrHQ/quivr.git
synced 2024-12-15 09:32:22 +03:00
b767f19f28
# Description Please include a summary of the changes and the related issue. Please also include relevant motivation and context. ## Checklist before requesting a review Please delete options that are not relevant. - [ ] My code follows the style guidelines of this project - [ ] I have performed a self-review of my code - [ ] I have commented hard-to-understand areas - [ ] I have ideally added tests that prove my fix is effective or that my feature works - [ ] New and existing unit tests pass locally with my changes - [ ] Any dependent changes have been merged ## Screenshots (if appropriate): --------- Co-authored-by: Zewed <dewez.antoine2@gmail.com>
427 lines
14 KiB
Python
427 lines
14 KiB
Python
import asyncio
|
|
import os
|
|
from uuid import UUID
|
|
|
|
import torch
|
|
from celery.schedules import crontab
|
|
from celery.signals import worker_process_init
|
|
from dotenv import load_dotenv
|
|
from quivr_api.celery_config import celery
|
|
from quivr_api.logger import get_logger
|
|
from quivr_api.models.settings import settings
|
|
from quivr_api.modules.assistant.repository.tasks import TasksRepository
|
|
from quivr_api.modules.assistant.services.tasks_service import TasksService
|
|
from quivr_api.modules.brain.integrations.Notion.Notion_connector import NotionConnector
|
|
from quivr_api.modules.brain.repository.brains_vectors import BrainsVectors
|
|
from quivr_api.modules.brain.service.brain_service import BrainService
|
|
from quivr_api.modules.dependencies import get_supabase_client
|
|
from quivr_api.modules.knowledge.repository.knowledges import KnowledgeRepository
|
|
from quivr_api.modules.knowledge.repository.storage import SupabaseS3Storage
|
|
from quivr_api.modules.knowledge.service.knowledge_service import KnowledgeService
|
|
from quivr_api.modules.notification.service.notification_service import (
|
|
NotificationService,
|
|
)
|
|
from quivr_api.modules.sync.dto.inputs import SyncsUserStatus
|
|
from quivr_api.modules.sync.repository.sync_files import SyncFilesRepository
|
|
from quivr_api.modules.sync.service.sync_notion import SyncNotionService
|
|
from quivr_api.modules.sync.service.sync_service import SyncService, SyncUserService
|
|
from quivr_api.modules.vector.repository.vectors_repository import VectorRepository
|
|
from quivr_api.modules.vector.service.vector_service import VectorService
|
|
from quivr_api.utils.telemetry import maybe_send_telemetry
|
|
from sqlalchemy import Engine, create_engine
|
|
from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine
|
|
from sqlmodel import Session, text
|
|
from sqlmodel.ext.asyncio.session import AsyncSession
|
|
|
|
from quivr_worker.assistants.assistants import process_assistant
|
|
from quivr_worker.celery_monitor import is_being_executed
|
|
from quivr_worker.check_premium import check_is_premium
|
|
from quivr_worker.process.process_s3_file import process_uploaded_file
|
|
from quivr_worker.process.process_url import process_url_func
|
|
from quivr_worker.syncs.process_active_syncs import (
|
|
SyncServices,
|
|
process_all_active_syncs,
|
|
process_notion_sync,
|
|
process_sync,
|
|
)
|
|
from quivr_worker.syncs.store_notion import fetch_and_store_notion_files_async
|
|
from quivr_worker.utils.utils import _patch_json
|
|
|
|
# Limit torch to a single intra-op thread in this process.
torch.set_num_threads(1)

# Load environment variables from a .env file; this runs before the service
# singletons below are constructed.
load_dotenv()

# The first call is made for its side effects only (return value unused);
# the second provides the logger used throughout this module.
get_logger("quivr_core")
logger = get_logger("celery_worker")
_patch_json()


# FIXME: load at init time
# Services (built once at import time and shared by every task in this module)
supabase_client = get_supabase_client()
# document_vector_store = get_documents_vector_store()
notification_service = NotificationService()
sync_active_service = SyncService()
sync_user_service = SyncUserService()
sync_files_repo_service = SyncFilesRepository()
brain_service = BrainService()
brain_vectors = BrainsVectors()
storage = SupabaseS3Storage()

# Placeholders: `async_engine` and `engine` are filled in by `init_worker`.
notion_service: SyncNotionService | None = None
async_engine: AsyncEngine | None = None
engine: Engine | None = None
|
|
|
|
|
|
@worker_process_init.connect
def init_worker(**kwargs):
    """Create the module-level SQLAlchemy engines if they are not set yet.

    Connected to Celery's ``worker_process_init`` signal, and also called
    directly by tasks in this module that find ``async_engine`` unset.
    Calling it again is harmless: an engine that already exists is kept.

    Args:
        **kwargs: Signal arguments supplied by Celery; unused.
    """
    global async_engine, engine

    # Both engines share the same pool configuration.
    engine_options = dict(
        # Echo SQL statements only when ORM_DEBUG is set to a non-empty value.
        echo=bool(os.getenv("ORM_DEBUG")),
        future=True,
        # Pessimistic disconnect handling: test a connection on checkout.
        pool_pre_ping=True,
        # NOTE: no bouncer for now, if 6 process workers => 6 pools of this size
        pool_size=10,
        pool_recycle=1800,
    )

    if async_engine is None:
        async_engine = create_async_engine(
            settings.pg_database_async_url, **engine_options
        )

    if engine is None:
        engine = create_engine(settings.pg_database_url, **engine_options)
|
|
|
|
|
|
@celery.task(
    # `max_retries` is the option Celery actually reads; the previous
    # `retries=3` keyword was not a recognised task option.
    max_retries=3,
    default_retry_delay=1,
    name="process_assistant_task",
    autoretry_for=(Exception,),
)
def process_assistant_task(
    assistant_id: str,
    notification_uuid: str,
    task_id: int,
    user_id: str,
):
    """Celery entry point that runs `aprocess_assistant_task` to completion.

    Any exception triggers an automatic retry (``autoretry_for``).

    Args:
        assistant_id: Identifier forwarded to `process_assistant`.
        notification_uuid: Notification identifier forwarded unchanged.
        task_id: Task identifier forwarded unchanged.
        user_id: User identifier forwarded unchanged.
    """
    logger.info(
        f"process_assistant_task started for assistant_id={assistant_id}, notification_uuid={notification_uuid}, task_id={task_id}"
    )

    # Same fallback as the other tasks in this module: build the engines when
    # the worker_process_init signal has not populated them.
    if async_engine is None:
        init_worker()

    loop = asyncio.get_event_loop()
    loop.run_until_complete(
        aprocess_assistant_task(
            assistant_id,
            notification_uuid,
            task_id,
            user_id,
        )
    )
|
|
|
|
|
|
async def aprocess_assistant_task(
    assistant_id: str,
    notification_uuid: str,
    task_id: int,
    user_id: str,
):
    """Run `process_assistant` inside a fresh async database session.

    The session gets a 5 minute ``idle_in_transaction_session_timeout`` and
    backs the `TasksService` handed to `process_assistant`. On any exception
    the session is rolled back and the exception is re-raised.

    Args:
        assistant_id: Forwarded to `process_assistant`.
        notification_uuid: Forwarded to `process_assistant`.
        task_id: Forwarded to `process_assistant`.
        user_id: Forwarded to `process_assistant`.

    Raises:
        AssertionError: If the module-level ``async_engine`` is not set.
    """
    global async_engine
    assert async_engine

    async with AsyncSession(async_engine) as db:
        try:
            timeout_stmt = text(
                "SET SESSION idle_in_transaction_session_timeout = '5min';"
            )
            await db.execute(timeout_stmt)

            service = TasksService(TasksRepository(db))
            await process_assistant(
                assistant_id,
                notification_uuid,
                task_id,
                service,
                user_id,
            )
        except Exception:
            await db.rollback()
            raise
        finally:
            await db.close()
|
|
|
|
|
|
@celery.task(
    # `max_retries` is the option Celery actually reads; the previous
    # `retries=3` keyword was not a recognised task option.
    max_retries=3,
    default_retry_delay=1,
    name="process_file_task",
    autoretry_for=(Exception,),
    dont_autoretry_for=(FileExistsError,),
)
def process_file_task(
    file_name: str,
    file_original_name: str,
    brain_id: UUID,
    notification_id: UUID,
    knowledge_id: UUID,
    source: str | None = None,
    source_link: str | None = None,
    delete_file: bool = False,
):
    """Celery entry point that runs `aprocess_file_task` to completion.

    Any exception except ``FileExistsError`` triggers an automatic retry.
    All arguments are forwarded unchanged, by keyword, to
    `aprocess_file_task`.
    """
    # Build the engines here when the worker_process_init signal has not
    # populated them.
    if async_engine is None:
        init_worker()

    loop = asyncio.get_event_loop()
    loop.run_until_complete(
        aprocess_file_task(
            file_name=file_name,
            file_original_name=file_original_name,
            brain_id=brain_id,
            notification_id=notification_id,
            knowledge_id=knowledge_id,
            source=source,
            source_link=source_link,
            delete_file=delete_file,
        )
    )
|
|
|
|
|
|
async def aprocess_file_task(
    file_name: str,
    file_original_name: str,
    brain_id: UUID,
    notification_id: UUID,
    knowledge_id: UUID,
    source: str | None = None,
    source_link: str | None = None,
    delete_file: bool = False,
):
    """Process an uploaded file using one async and one sync database session.

    The async session backs `KnowledgeService`; the sync session backs
    `VectorService`. Both get a 5 minute
    ``idle_in_transaction_session_timeout``. When `process_uploaded_file`
    succeeds, the sync session is committed first, then the async one. On
    failure every session opened so far is rolled back and the original
    exception propagates.

    Args:
        file_name: Forwarded to `process_uploaded_file`.
        file_original_name: Forwarded to `process_uploaded_file`.
        brain_id: Forwarded to `process_uploaded_file`.
        notification_id: Accepted for interface compatibility; not used by
            this coroutine.
        knowledge_id: Forwarded to `process_uploaded_file`.
        source: Forwarded as ``integration``.
        source_link: Forwarded as ``integration_link``.
        delete_file: Forwarded to `process_uploaded_file`.

    Raises:
        AssertionError: If either module-level engine is not set.
    """
    global engine
    assert engine
    assert async_engine

    # Both context managers close their session on exit, so no explicit
    # close() calls are needed.
    async with AsyncSession(async_engine) as async_session:
        try:
            await async_session.execute(
                text("SET SESSION idle_in_transaction_session_timeout = '5min';")
            )
            with Session(engine, expire_on_commit=False, autoflush=False) as session:
                # The sync rollback lives inside the `with` so it only runs
                # once `session` is bound. Previously a failure before this
                # point raised UnboundLocalError from the cleanup code and
                # hid the real error.
                try:
                    session.execute(
                        text(
                            "SET SESSION idle_in_transaction_session_timeout = '5min';"
                        )
                    )
                    vector_repository = VectorRepository(session)
                    vector_service = VectorService(
                        vector_repository
                    )  # FIXME @amine: fix to need AsyncSession in vector Service
                    knowledge_repository = KnowledgeRepository(async_session)
                    knowledge_service = KnowledgeService(knowledge_repository)
                    await process_uploaded_file(
                        supabase_client=supabase_client,
                        brain_service=brain_service,
                        vector_service=vector_service,
                        knowledge_service=knowledge_service,
                        file_name=file_name,
                        brain_id=brain_id,
                        file_original_name=file_original_name,
                        knowledge_id=knowledge_id,
                        integration=source,
                        integration_link=source_link,
                        delete_file=delete_file,
                    )
                    session.commit()
                    await async_session.commit()
                except Exception:
                    session.rollback()
                    raise
        except Exception:
            await async_session.rollback()
            raise
|
|
|
|
|
|
@celery.task(
    # `max_retries` is the option Celery actually reads; the previous
    # `retries=3` keyword was not a recognised task option.
    max_retries=3,
    default_retry_delay=1,
    name="process_crawl_task",
    autoretry_for=(Exception,),
)
def process_crawl_task(
    crawl_website_url: str,
    brain_id: UUID,
    knowledge_id: UUID,
    notification_id: UUID | None = None,
):
    """Process a URL with `process_url_func` inside a sync database session.

    The session gets a 5 minute ``idle_in_transaction_session_timeout`` and
    backs the `VectorService` passed to `process_url_func`. It is committed
    on success and rolled back on any exception, which is then re-raised so
    Celery can retry (``autoretry_for``).

    Args:
        crawl_website_url: URL passed to `process_url_func`.
        brain_id: Forwarded to `process_url_func`.
        knowledge_id: Forwarded to `process_url_func`.
        notification_id: Only included in the start-up log line.

    Raises:
        AssertionError: If the sync engine is still unset after the fallback.
    """
    logger.info(
        f"Task process_crawl_task started for url={crawl_website_url}, knowledge_id={knowledge_id}, brain_id={brain_id}, notification_id={notification_id}"
    )
    global engine
    # Same fallback as the other tasks in this module: build the engines when
    # the worker_process_init signal has not populated them.
    if engine is None:
        init_worker()
    assert engine

    # The `with` block closes the session on exit. Rollback is handled inside
    # it so `session` is always bound when the handler runs.
    with Session(engine, expire_on_commit=False, autoflush=False) as session:
        try:
            session.execute(
                text("SET SESSION idle_in_transaction_session_timeout = '5min';")
            )
            vector_repository = VectorRepository(session)
            vector_service = VectorService(vector_repository)
            loop = asyncio.get_event_loop()
            loop.run_until_complete(
                process_url_func(
                    url=crawl_website_url,
                    brain_id=brain_id,
                    knowledge_id=knowledge_id,
                    brain_service=brain_service,
                    vector_service=vector_service,
                )
            )
            session.commit()
        except Exception:
            session.rollback()
            raise
|
|
|
|
|
|
@celery.task(name="NotionConnectorLoad")
def process_integration_brain_created_initial_load(brain_id, user_id):
    """Run `NotionConnector.load` for a brain/user pair and log the page count.

    Args:
        brain_id: Passed to `NotionConnector`.
        user_id: Passed to `NotionConnector`.
    """
    notion_connector = NotionConnector(brain_id=brain_id, user_id=user_id)
    pages = notion_connector.load()
    # The count needs a %s placeholder; passing it as a bare extra argument
    # made the logging call fail to format the message.
    logger.info("Notion pages: %s", len(pages))
|
|
|
|
|
|
@celery.task
def process_integration_brain_sync_user_brain(brain_id, user_id):
    """Build a `NotionConnector` for the brain/user pair and call `poll` on it."""
    NotionConnector(brain_id=brain_id, user_id=user_id).poll()
|
|
|
|
|
|
@celery.task
def ping_telemetry():
    """Send a "ping" telemetry event through `maybe_send_telemetry`."""
    payload = {"ping": "pong"}
    maybe_send_telemetry("ping", payload)
|
|
|
|
|
|
@celery.task(name="check_is_premium_task")
def check_is_premium_task():
    """Run `check_is_premium` with the module-level Supabase client."""
    check_is_premium(supabase_client)
|
|
|
|
|
|
@celery.task(name="process_sync_task")
def process_sync_task(
    sync_id: int, user_id: str, files_ids: list[str], folder_ids: list[str]
):
    """Run `process_sync` for one of a user's active syncs.

    Args:
        sync_id: Id of the sync to process; looked up among the syncs returned
            by ``sync_active_service.get_syncs_active(user_id)``.
        user_id: User whose active syncs are searched.
        files_ids: Forwarded to `process_sync`.
        folder_ids: Forwarded to `process_sync`.

    Raises:
        ValueError: If no active sync with ``sync_id`` exists for the user.
        AssertionError: If the async engine is still unset after the fallback.
    """
    global async_engine
    # Same fallback as the other tasks in this module: build the engines when
    # the worker_process_init signal has not populated them.
    if async_engine is None:
        init_worker()
    assert async_engine

    # Fail with an explicit error rather than a bare StopIteration when the
    # requested sync is not among the user's active syncs.
    sync = next(
        (s for s in sync_active_service.get_syncs_active(user_id) if s.id == sync_id),
        None,
    )
    if sync is None:
        raise ValueError(
            f"No active sync with id={sync_id} found for user_id={user_id}"
        )

    loop = asyncio.get_event_loop()
    loop.run_until_complete(
        process_sync(
            sync=sync,
            files_ids=files_ids,
            folder_ids=folder_ids,
            services=SyncServices(
                async_engine=async_engine,
                sync_active_service=sync_active_service,
                sync_user_service=sync_user_service,
                sync_files_repo_service=sync_files_repo_service,
                storage=storage,
                brain_vectors=brain_vectors,
                notification_service=notification_service,
            ),
        )
    )
|
|
|
|
|
|
@celery.task(name="process_active_syncs_task")
def process_active_syncs_task():
    """Run `process_all_active_syncs`, unless a sync task is already executing.

    Does nothing (beyond logging) when `is_being_executed` reports that
    "process_sync_task" is running.

    Raises:
        AssertionError: If the module-level ``async_engine`` is not set.
    """
    if is_being_executed("process_sync_task"):
        logger.info("Sync already running, skipping")
        return

    global async_engine
    assert async_engine

    services = SyncServices(
        async_engine=async_engine,
        sync_active_service=sync_active_service,
        sync_user_service=sync_user_service,
        sync_files_repo_service=sync_files_repo_service,
        storage=storage,
        brain_vectors=brain_vectors,
        notification_service=notification_service,
    )
    event_loop = asyncio.get_event_loop()
    event_loop.run_until_complete(process_all_active_syncs(services))
|
|
|
|
|
|
@celery.task(name="process_notion_sync_task")
def process_notion_sync_task():
    """Run `process_notion_sync` to completion with the module's async engine.

    Raises:
        AssertionError: If the module-level ``async_engine`` is not set.
    """
    global async_engine
    assert async_engine

    event_loop = asyncio.get_event_loop()
    notion_sync = process_notion_sync(async_engine)
    event_loop.run_until_complete(notion_sync)
|
|
|
|
|
|
@celery.task(name="fetch_and_store_notion_files_task")
def fetch_and_store_notion_files_task(
    access_token: str, user_id: UUID, sync_user_id: int
):
    """Fetch and store a user's Notion files, then record the outcome.

    Runs `fetch_and_store_notion_files_async` to completion and sets the sync
    user's status to ``SYNCED``. Any exception is logged with its traceback
    and the status is set to ``ERROR`` instead; the exception is deliberately
    not re-raised.

    Args:
        access_token: Forwarded to `fetch_and_store_notion_files_async`.
        user_id: Forwarded to `fetch_and_store_notion_files_async`.
        sync_user_id: Forwarded to the fetch and used for the status update.

    Raises:
        AssertionError: If the async engine is still unset after the fallback.
    """
    # Build the engines here when the worker_process_init signal has not
    # populated them.
    if async_engine is None:
        init_worker()
    assert async_engine

    try:
        logger.debug("Fetching and storing Notion files")
        loop = asyncio.get_event_loop()
        loop.run_until_complete(
            fetch_and_store_notion_files_async(
                async_engine, access_token, user_id, sync_user_id
            )
        )
        sync_user_service.update_sync_user_status(
            sync_user_id=sync_user_id, status=str(SyncsUserStatus.SYNCED)
        )
    except Exception:
        # logger.exception keeps the traceback, which the previous plain
        # error message discarded.
        logger.exception("Error fetching and storing Notion files")
        sync_user_service.update_sync_user_status(
            sync_user_id=sync_user_id, status=str(SyncsUserStatus.ERROR)
        )
|
|
|
|
|
|
@celery.task(name="clean_notion_user_syncs")
def clean_notion_user_syncs():
    """Call `sync_user_service.clean_notion_user_syncs` after a debug log line."""
    logger.debug("Cleaning Notion user syncs")
    sync_user_service.clean_notion_user_syncs()
|
|
|
|
|
|
# Periodic tasks run by Celery beat. Each entry names a registered task and
# the crontab schedule on which it is enqueued.
celery.conf.beat_schedule = {
    # Every 30 minutes. `ping_telemetry` is registered without an explicit
    # name, so it is referenced through this module's name.
    "ping_telemetry": {
        "task": f"{__name__}.ping_telemetry",
        "schedule": crontab(minute="*/30", hour="*"),
    },
    # Every minute.
    "process_active_syncs": {
        "task": "process_active_syncs_task",
        "schedule": crontab(minute="*/1", hour="*"),
    },
    # Every minute.
    "process_premium_users": {
        "task": "check_is_premium_task",
        "schedule": crontab(minute="*/1", hour="*"),
    },
    # At minute 0 of every sixth hour.
    "process_notion_sync": {
        "task": "process_notion_sync_task",
        "schedule": crontab(minute="0", hour="*/6"),
    },
    # Once a day, at 00:00.
    "clean_notion_user_syncs": {
        "task": "clean_notion_user_syncs",
        "schedule": crontab(minute="0", hour="0"),
    },
}
|