fix: remove knowledge and idle conn (#3165)

# Description

- Add a `SET SESSION` idle connection timeout (see the sketch below)
- Remove brain vectors from the knowledge routes (knowledge now cascades to
vectors), which should speed up knowledge deletion
- Add session rollback/commit to the worker

TODO: Refactor session handling in the worker. The async + Celery design is
starting to become a real problem.
AmineDiro 2024-09-06 15:56:07 +02:00 committed by GitHub
parent 50a940c67e
commit 89d9e2fcaf
7 changed files with 165 additions and 121 deletions
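
For context before the diffs: `idle_in_transaction_session_timeout` is a PostgreSQL setting that tells the server to terminate any session that sits idle inside an open transaction for longer than the given interval, so a hung request or worker cannot hold locks indefinitely. A minimal sketch of the pattern this commit applies everywhere (the connection URL is a placeholder; the real engine comes from the app settings):

```python
from sqlalchemy import create_engine
from sqlmodel import Session, text

# Placeholder DSN for illustration only.
engine = create_engine("postgresql://user:password@localhost:5432/quivr")

with Session(engine, expire_on_commit=False, autoflush=False) as session:
    # Ask PostgreSQL to kill this session if it idles inside an open
    # transaction for more than 5 minutes.
    session.execute(
        text("SET SESSION idle_in_transaction_session_timeout = '5min';")
    )
    ...  # queries go here
    session.commit()
```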


@@ -12,7 +12,7 @@ from langchain_openai import OpenAIEmbeddings
# from quivr_api.modules.vectorstore.supabase import CustomSupabaseVectorStore
from sqlalchemy import Engine, create_engine
from sqlalchemy.ext.asyncio import create_async_engine
from sqlmodel import Session
from sqlmodel import Session, text
from sqlmodel.ext.asyncio.session import AsyncSession
from quivr_api.logger import get_logger
@@ -77,16 +77,17 @@ async_engine = create_async_engine(
def get_sync_session() -> Generator[Session, None, None]:
with Session(sync_engine, expire_on_commit=False, autoflush=False) as session:
yield session
# def get_documents_vector_store(vector_service: VectorService) -> SupabaseVectorStore:
# embeddings = get_embedding_client()
# supabase_client: Client = get_supabase_client()
# documents_vector_store = CustomSupabaseVectorStore( # Modified by @chloe Check
# supabase_client, embeddings, table_name="vectors", vector_service=vector_service
# )
# return documents_vector_store
try:
session.execute(
text("SET SESSION idle_in_transaction_session_timeout = '5min';")
)
yield session
session.commit()
except Exception as e:
session.rollback()
raise e
finally:
session.close()
async def get_async_session() -> AsyncGenerator[AsyncSession, None]:
@@ -94,6 +95,9 @@ async def get_async_session() -> AsyncGenerator[AsyncSession, None]:
async_engine,
) as session:
try:
await session.execute(
text("SET SESSION idle_in_transaction_session_timeout = '5min';")
)
yield session
await session.commit()
except Exception as e:
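
Since the generators now commit on success and roll back on failure, route handlers receiving a session through dependency injection need no transaction bookkeeping of their own. A hedged usage sketch (the route is illustrative and assumes `get_sync_session` is exported from `quivr_api.modules.dependencies`):

```python
from fastapi import APIRouter, Depends
from sqlmodel import Session, text

from quivr_api.modules.dependencies import get_sync_session  # assumed location

router = APIRouter()

@router.get("/healthz/db")  # hypothetical route
def db_health(session: Session = Depends(get_sync_session)):
    # On success the dependency commits; if this raises, the dependency
    # rolls back and re-raises, so no try/except is needed here.
    session.execute(text("SELECT 1"))
    return {"status": "ok"}
```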


@@ -11,7 +11,6 @@ from quivr_api.modules.brain.service.brain_authorization_service import (
has_brain_authorization,
validate_brain_authorization,
)
from quivr_api.modules.brain.service.brain_vector_service import BrainVectorService
from quivr_api.modules.dependencies import get_service
from quivr_api.modules.knowledge.service.knowledge_service import KnowledgeService
from quivr_api.modules.upload.service.generate_file_signed_url import (
@@ -68,12 +67,6 @@ async def delete_endpoint(
file_name = knowledge.file_name if knowledge.file_name else knowledge.url
await knowledge_service.remove_knowledge(brain_id, knowledge_id)
brain_vector_service = BrainVectorService(brain_id)
if knowledge.file_name:
brain_vector_service.delete_file_from_brain(knowledge.file_name)
elif knowledge.url:
brain_vector_service.delete_file_url_from_brain(knowledge.url)
return {
"message": f"{file_name} of brain {brain_id} has been deleted by user {current_user.email}."
}
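
The `BrainVectorService` calls can be dropped because, per the description, deleting a knowledge row now cascades to its vectors at the database level. A sketch of what that cascade could look like in the models (illustrative: the actual quivr models and migration are not part of this diff):

```python
from uuid import UUID, uuid4

from sqlalchemy import Column, ForeignKey
from sqlmodel import Field, SQLModel

class KnowledgeDB(SQLModel, table=True):  # class names are assumptions
    __tablename__ = "knowledge"
    id: UUID = Field(default_factory=uuid4, primary_key=True)

class Vector(SQLModel, table=True):
    __tablename__ = "vectors"
    id: UUID = Field(default_factory=uuid4, primary_key=True)
    # ON DELETE CASCADE: removing a knowledge row removes its vectors,
    # so the delete endpoint no longer has to clean them up per brain.
    knowledge_id: UUID = Field(
        sa_column=Column(ForeignKey("knowledge.id", ondelete="CASCADE"))
    )
```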


@@ -49,7 +49,7 @@ class KnowledgeService(BaseService[KnowledgeRepository]):
return next(
f"{b.brain_id}/{file_name}"
for b in brains
if check_file_exists(b.brain_id, file_name)
if check_file_exists(str(b.brain_id), file_name)
)
except NoResultFound:
raise FileNotFoundError(f"No knowledge for file_name: {file_name}")
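
The `str(...)` cast is a typing fix: `b.brain_id` is a `UUID`, while `check_file_exists` evidently takes the brain id as a string (signature assumed; it is not shown in this diff):

```python
from uuid import UUID

def check_file_exists(brain_id: str, file_name: str) -> bool:
    """Assumed shape: look up f"{brain_id}/{file_name}" in storage."""
    ...
```

Passing a `UUID` where `str` is annotated may happen to work at runtime but fails static type checking, hence the explicit cast at the call site.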


@@ -26,7 +26,7 @@ from quivr_api.modules.vector.service.vector_service import VectorService
from quivr_api.utils.telemetry import maybe_send_telemetry
from sqlalchemy import Engine, create_engine
from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine
from sqlmodel import Session
from sqlmodel import Session, text
from sqlmodel.ext.asyncio.session import AsyncSession
from quivr_worker.check_premium import check_is_premium
@@ -143,26 +143,40 @@ async def aprocess_file_task(
global engine
assert engine
async with AsyncSession(async_engine) as async_session:
with Session(engine, expire_on_commit=False, autoflush=False) as session:
vector_repository = VectorRepository(session)
vector_service = VectorService(
vector_repository
) # FIXME @amine: fix to need AsyncSession in vector Service
knowledge_repository = KnowledgeRepository(async_session)
knowledge_service = KnowledgeService(knowledge_repository)
await process_uploaded_file(
supabase_client=supabase_client,
brain_service=brain_service,
vector_service=vector_service,
knowledge_service=knowledge_service,
file_name=file_name,
brain_id=brain_id,
file_original_name=file_original_name,
knowledge_id=knowledge_id,
integration=source,
integration_link=source_link,
delete_file=delete_file,
try:
await async_session.execute(
text("SET SESSION idle_in_transaction_session_timeout = '5min';")
)
with Session(engine, expire_on_commit=False, autoflush=False) as session:
session.execute(
text("SET SESSION idle_in_transaction_session_timeout = '5min';")
)
vector_repository = VectorRepository(session)
vector_service = VectorService(
vector_repository
) # FIXME @amine: fix to need AsyncSession in vector Service
knowledge_repository = KnowledgeRepository(async_session)
knowledge_service = KnowledgeService(knowledge_repository)
await process_uploaded_file(
supabase_client=supabase_client,
brain_service=brain_service,
vector_service=vector_service,
knowledge_service=knowledge_service,
file_name=file_name,
brain_id=brain_id,
file_original_name=file_original_name,
knowledge_id=knowledge_id,
integration=source,
integration_link=source_link,
delete_file=delete_file,
)
except Exception as e:
session.rollback()
await async_session.rollback()
raise e
finally:
session.close()
await async_session.close()
@celery.task(
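
The timeout/commit/rollback/close boilerplate now repeats in every task, which is what the TODO above alludes to. One way it could be factored out is a small async context manager owning the whole lifecycle; this would also avoid the edge case in the hunk above where `session.rollback()` in the `except` can reference a name that was never bound if the failure happens before the inner `with` opens. A sketch of a possible refactor (not code from this commit):

```python
from contextlib import asynccontextmanager

from sqlalchemy.ext.asyncio import AsyncEngine
from sqlmodel import text
from sqlmodel.ext.asyncio.session import AsyncSession

@asynccontextmanager
async def worker_session(engine: AsyncEngine):
    # Owns open -> idle timeout -> commit/rollback -> close for one task.
    async with AsyncSession(
        engine, expire_on_commit=False, autoflush=False
    ) as session:
        await session.execute(
            text("SET SESSION idle_in_transaction_session_timeout = '5min';")
        )
        try:
            yield session
            await session.commit()
        except Exception:
            await session.rollback()
            raise
```

Each task body would then reduce to `async with worker_session(async_engine) as session: ...`.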


@@ -17,6 +17,7 @@ from quivr_api.modules.sync.service.sync_notion import (
from quivr_api.modules.sync.service.sync_service import SyncService, SyncUserService
from quivr_api.modules.sync.utils.syncutils import SyncUtils
from sqlalchemy.ext.asyncio import AsyncEngine
from sqlmodel import text
from sqlmodel.ext.asyncio.session import AsyncSession
from quivr_worker.syncs.utils import SyncServices, build_syncs_utils
@@ -96,43 +97,53 @@ async def _process_all_active_syncs(
async def process_notion_sync(
async_engine: AsyncEngine,
):
async with AsyncSession(
async_engine, expire_on_commit=False, autoflush=False
) as session:
sync_user_service = SyncUserService()
notion_repository = NotionRepository(session)
notion_service = SyncNotionService(notion_repository)
# TODO: Add state in sync_user to check if the same fetching is running
# Get active tasks for all workers
active_tasks = celery_inspector.active()
is_uploading_task_running = any(
"fetch_and_store_notion_files" in task
for worker_tasks in active_tasks.values()
for task in worker_tasks
)
if is_uploading_task_running:
return None
# Get all notion syncs
notion_syncs = sync_user_service.get_all_notion_user_syncs()
for notion_sync in notion_syncs:
user_id = notion_sync["user_id"]
notion_client = Client(auth=notion_sync["credentials"]["access_token"])
# TODO: fetch last_sync_time from table
pages_to_update = fetch_limit_notion_pages(
notion_client,
datetime.now() - timedelta(hours=6),
try:
async with AsyncSession(
async_engine, expire_on_commit=False, autoflush=False
) as session:
await session.execute(
text("SET SESSION idle_in_transaction_session_timeout = '5min';")
)
logger.debug("Number of pages to update: %s", len(pages_to_update))
if not pages_to_update:
logger.info("No pages to update")
continue
sync_user_service = SyncUserService()
notion_repository = NotionRepository(session)
notion_service = SyncNotionService(notion_repository)
await update_notion_pages(
notion_service,
pages_to_update,
UUID(user_id),
notion_client, # type: ignore
# TODO: Add state in sync_user to check if the same fetching is running
# Get active tasks for all workers
active_tasks = celery_inspector.active()
is_uploading_task_running = any(
"fetch_and_store_notion_files" in task
for worker_tasks in active_tasks.values()
for task in worker_tasks
)
if is_uploading_task_running:
return None
# Get all notion syncs
notion_syncs = sync_user_service.get_all_notion_user_syncs()
for notion_sync in notion_syncs:
user_id = notion_sync["user_id"]
notion_client = Client(auth=notion_sync["credentials"]["access_token"])
# TODO: fetch last_sync_time from table
pages_to_update = fetch_limit_notion_pages(
notion_client,
datetime.now() - timedelta(hours=6),
)
logger.debug("Number of pages to update: %s", len(pages_to_update))
if not pages_to_update:
logger.info("No pages to update")
continue
await update_notion_pages(
notion_service,
pages_to_update,
UUID(user_id),
notion_client, # type: ignore
)
except Exception as e:
await session.rollback()
raise e
finally:
await session.close()
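
Aside from the session changes, this task skips itself when another worker is already fetching Notion files. The guard isolated, with one caveat: `active()` returns per-worker lists of task-info dicts (or `None` when no worker replies), so the diff's `"fetch_and_store_notion_files" in task` tests dict keys; matching on the task's `name` field, as below, is presumably what is intended:

```python
from celery import Celery

celery = Celery(broker="redis://localhost:6379/0")  # placeholder broker
celery_inspector = celery.control.inspect()

def is_task_running(task_name: str) -> bool:
    # active() -> {worker_name: [task_info_dict, ...]} or None.
    active_tasks = celery_inspector.active() or {}
    return any(
        task_name in task.get("name", "")
        for worker_tasks in active_tasks.values()
        for task in worker_tasks
    )

if is_task_running("fetch_and_store_notion_files"):
    ...  # another worker is already on it; skip this run
```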


@@ -10,6 +10,7 @@ from quivr_api.modules.sync.service.sync_notion import (
store_notion_pages,
)
from sqlalchemy.ext.asyncio import AsyncEngine
from sqlmodel import text
from sqlmodel.ext.asyncio.session import AsyncSession
logger = get_logger("celery_worker")
@@ -18,19 +19,29 @@ logger = get_logger("celery_worker")
async def fetch_and_store_notion_files_async(
async_engine: AsyncEngine, access_token: str, user_id: UUID
):
async with AsyncSession(
async_engine, expire_on_commit=False, autoflush=False
) as session:
notion_repository = NotionRepository(session)
notion_service = SyncNotionService(notion_repository)
notion_client = Client(auth=access_token)
all_search_result = fetch_limit_notion_pages(
notion_client,
last_sync_time=datetime(1970, 1, 1, 0, 0, 0), # UNIX EPOCH
)
logger.debug(f"Notion fetched {len(all_search_result)} pages")
pages = await store_notion_pages(all_search_result, notion_service, user_id)
if pages:
logger.info(f"stored {len(pages)} from notion for {user_id}")
else:
logger.warn("No notion page fetched")
try:
async with AsyncSession(
async_engine, expire_on_commit=False, autoflush=False
) as session:
await session.execute(
text("SET SESSION idle_in_transaction_session_timeout = '5min';")
)
notion_repository = NotionRepository(session)
notion_service = SyncNotionService(notion_repository)
notion_client = Client(auth=access_token)
all_search_result = fetch_limit_notion_pages(
notion_client,
last_sync_time=datetime(1970, 1, 1, 0, 0, 0), # UNIX EPOCH
)
logger.debug(f"Notion fetched {len(all_search_result)} pages")
pages = await store_notion_pages(all_search_result, notion_service, user_id)
if pages:
logger.info(f"stored {len(pages)} from notion for {user_id}")
else:
logger.warn("No notion page fetched")
except Exception as e:
await session.rollback()
raise e
finally:
await session.close()


@@ -26,6 +26,7 @@ from quivr_api.modules.sync.utils.sync import (
)
from quivr_api.modules.sync.utils.syncutils import SyncUtils
from sqlalchemy.ext.asyncio import AsyncEngine
from sqlmodel import text
from sqlmodel.ext.asyncio.session import AsyncSession
celery_inspector = celery.control.inspect()
@@ -48,34 +49,44 @@ class SyncServices:
async def build_syncs_utils(
deps: SyncServices,
) -> AsyncGenerator[dict[str, SyncUtils], None]:
async with AsyncSession(
deps.async_engine, expire_on_commit=False, autoflush=False
) as session:
# TODO pass services from celery_worker
notion_repository = NotionRepository(session)
notion_service = SyncNotionService(notion_repository)
knowledge_service = KnowledgeService(KnowledgeRepository(session))
mapping_sync_utils = {}
for provider_name, sync_cloud in [
("google", GoogleDriveSync()),
("azure", AzureDriveSync()),
("dropbox", DropboxSync()),
("github", GitHubSync()),
(
"notion",
NotionSync(notion_service=notion_service),
), # Fixed duplicate "github" key
]:
provider_sync_util = SyncUtils(
sync_user_service=deps.sync_user_service,
sync_active_service=deps.sync_active_service,
sync_files_repo=deps.sync_files_repo_service,
sync_cloud=sync_cloud,
notification_service=deps.notification_service,
brain_vectors=deps.brain_vectors,
knowledge_service=knowledge_service,
try:
async with AsyncSession(
deps.async_engine, expire_on_commit=False, autoflush=False
) as session:
await session.execute(
text("SET SESSION idle_in_transaction_session_timeout = '5min';")
)
mapping_sync_utils[provider_name] = provider_sync_util
# TODO pass services from celery_worker
notion_repository = NotionRepository(session)
notion_service = SyncNotionService(notion_repository)
knowledge_service = KnowledgeService(KnowledgeRepository(session))
yield mapping_sync_utils
mapping_sync_utils = {}
for provider_name, sync_cloud in [
("google", GoogleDriveSync()),
("azure", AzureDriveSync()),
("dropbox", DropboxSync()),
("github", GitHubSync()),
(
"notion",
NotionSync(notion_service=notion_service),
), # Fixed duplicate "github" key
]:
provider_sync_util = SyncUtils(
sync_user_service=deps.sync_user_service,
sync_active_service=deps.sync_active_service,
sync_files_repo=deps.sync_files_repo_service,
sync_cloud=sync_cloud,
notification_service=deps.notification_service,
brain_vectors=deps.brain_vectors,
knowledge_service=knowledge_service,
)
mapping_sync_utils[provider_name] = provider_sync_util
yield mapping_sync_utils
except Exception as e:
await session.rollback()
raise e
finally:
await session.close()
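
`build_syncs_utils` is an async generator that yields a single provider-name to `SyncUtils` mapping while keeping the session open for the caller. A hedged sketch of the calling side (assumed; the callers are not part of this diff):

```python
async def run_provider_syncs(deps: SyncServices) -> None:
    # Entering the generator opens the session and sets the idle timeout;
    # exhausting it commits (or rolls back on error) and closes.
    async for sync_utils_by_provider in build_syncs_utils(deps):
        for provider_name, sync_util in sync_utils_by_provider.items():
            ...  # drive the provider-specific sync here
```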