mirror of
https://github.com/QuivrHQ/quivr.git
synced 2025-01-05 23:03:53 +03:00
c643157b75
# Description Please include a summary of the changes and the related issue. Please also include relevant motivation and context. ## Checklist before requesting a review Please delete options that are not relevant. - [ ] My code follows the style guidelines of this project - [ ] I have performed a self-review of my code - [ ] I have commented hard-to-understand areas - [ ] I have ideally added tests that prove my fix is effective or that my feature works - [ ] New and existing unit tests pass locally with my changes - [ ] Any dependent changes have been merged ## Screenshots (if appropriate): Co-authored-by: Stan Girard <stan@quivr.app>
284 lines
9.2 KiB
Python
284 lines
9.2 KiB
Python
import os
|
|
from datetime import datetime, timedelta
|
|
from tempfile import NamedTemporaryFile
|
|
from uuid import UUID
|
|
|
|
from celery.schedules import crontab
|
|
from pytz import timezone
|
|
from quivr_api.celery_config import celery
|
|
from quivr_api.logger import get_logger
|
|
from quivr_api.middlewares.auth.auth_bearer import AuthBearer
|
|
from quivr_api.models.files import File
|
|
from quivr_api.models.settings import get_supabase_client, get_supabase_db
|
|
from quivr_api.modules.brain.integrations.Notion.Notion_connector import NotionConnector
|
|
from quivr_api.modules.brain.service.brain_service import BrainService
|
|
from quivr_api.modules.brain.service.brain_vector_service import BrainVectorService
|
|
from quivr_api.modules.notification.service.notification_service import (
|
|
NotificationService,
|
|
)
|
|
from quivr_api.modules.onboarding.service.onboarding_service import OnboardingService
|
|
from quivr_api.packages.files.crawl.crawler import CrawlWebsite, slugify
|
|
from quivr_api.packages.files.processors import filter_file
|
|
from quivr_api.packages.utils.telemetry import maybe_send_telemetry
|
|
|
|
logger = get_logger(__name__)
|
|
|
|
onboardingService = OnboardingService()
|
|
notification_service = NotificationService()
|
|
brain_service = BrainService()
|
|
auth_bearer = AuthBearer()
|
|
|
|
|
|
@celery.task(
|
|
retries=3,
|
|
default_retry_delay=1,
|
|
name="process_file_and_notify",
|
|
autoretry_for=(Exception,),
|
|
)
|
|
def process_file_and_notify(
|
|
file_name: str,
|
|
file_original_name: str,
|
|
brain_id,
|
|
notification_id: UUID,
|
|
knowledge_id: UUID,
|
|
integration=None,
|
|
delete_file=False,
|
|
):
|
|
logger.debug(
|
|
f"process_file file_name={file_name}, knowledge_id={knowledge_id}, brain_id={brain_id}, notification_id={notification_id}"
|
|
)
|
|
supabase_client = get_supabase_client()
|
|
tmp_name = file_name.replace("/", "_")
|
|
base_file_name = os.path.basename(file_name)
|
|
_, file_extension = os.path.splitext(base_file_name)
|
|
|
|
with NamedTemporaryFile(
|
|
suffix="_" + tmp_name, # pyright: ignore reportPrivateUsage=none
|
|
) as tmp_file:
|
|
res = supabase_client.storage.from_("quivr").download(file_name)
|
|
tmp_file.write(res)
|
|
tmp_file.flush()
|
|
file_instance = File(
|
|
file_name=base_file_name,
|
|
tmp_file_path=tmp_file.name,
|
|
bytes_content=res,
|
|
file_size=len(res),
|
|
file_extension=file_extension,
|
|
)
|
|
brain_vector_service = BrainVectorService(brain_id)
|
|
if delete_file: # TODO fix bug
|
|
brain_vector_service.delete_file_from_brain(
|
|
file_original_name, only_vectors=True
|
|
)
|
|
|
|
filter_file(
|
|
file=file_instance,
|
|
brain_id=brain_id,
|
|
original_file_name=file_original_name,
|
|
)
|
|
|
|
brain_service.update_brain_last_update_time(brain_id)
|
|
|
|
|
|
@celery.task(
|
|
retries=3,
|
|
default_retry_delay=1,
|
|
name="process_crawl_and_notify",
|
|
autoretry_for=(Exception,),
|
|
)
|
|
def process_crawl_and_notify(
|
|
crawl_website_url: str,
|
|
brain_id: UUID,
|
|
knowledge_id: UUID,
|
|
notification_id=None,
|
|
):
|
|
crawl_website = CrawlWebsite(url=crawl_website_url)
|
|
# Build file data
|
|
extracted_content = crawl_website.process()
|
|
extracted_content_bytes = extracted_content.encode("utf-8")
|
|
file_name = slugify(crawl_website.url) + ".txt"
|
|
|
|
with NamedTemporaryFile(
|
|
suffix="_" + file_name, # pyright: ignore reportPrivateUsage=none
|
|
) as tmp_file:
|
|
tmp_file.write(extracted_content_bytes)
|
|
tmp_file.flush()
|
|
file_instance = File(
|
|
file_name=file_name,
|
|
tmp_file_path=tmp_file.name,
|
|
bytes_content=extracted_content_bytes,
|
|
file_size=len(extracted_content),
|
|
file_extension=".txt",
|
|
)
|
|
filter_file(
|
|
file=file_instance,
|
|
brain_id=brain_id,
|
|
original_file_name=crawl_website_url,
|
|
)
|
|
|
|
|
|
@celery.task
|
|
def remove_onboarding_more_than_x_days_task():
|
|
onboardingService.remove_onboarding_more_than_x_days(7)
|
|
|
|
|
|
@celery.task(name="NotionConnectorLoad")
|
|
def process_integration_brain_created_initial_load(brain_id, user_id):
|
|
notion_connector = NotionConnector(brain_id=brain_id, user_id=user_id)
|
|
|
|
pages = notion_connector.load()
|
|
|
|
print("pages: ", len(pages))
|
|
|
|
|
|
@celery.task
|
|
def process_integration_brain_sync_user_brain(brain_id, user_id):
|
|
notion_connector = NotionConnector(brain_id=brain_id, user_id=user_id)
|
|
|
|
notion_connector.poll()
|
|
|
|
|
|
@celery.task
|
|
def ping_telemetry():
|
|
maybe_send_telemetry("ping", {"ping": "pong"})
|
|
|
|
|
|
@celery.task(name="check_if_is_premium_user")
|
|
def check_if_is_premium_user():
|
|
supabase = get_supabase_db()
|
|
supabase_db = supabase.db
|
|
|
|
paris_tz = timezone("Europe/Paris")
|
|
current_time = datetime.now(paris_tz)
|
|
current_time_str = current_time.strftime("%Y-%m-%d %H:%M:%S.%f")
|
|
logger.debug(f"Current time: {current_time_str}")
|
|
|
|
# Define the memoization period (e.g., 1 hour)
|
|
memoization_period = timedelta(hours=1)
|
|
memoization_cutoff = current_time - memoization_period
|
|
|
|
# Fetch all necessary data in bulk
|
|
subscriptions = (
|
|
supabase_db.table("subscriptions")
|
|
.select("*")
|
|
.filter("current_period_end", "gt", current_time_str)
|
|
.execute()
|
|
).data
|
|
|
|
customers = (supabase_db.table("customers").select("*").execute()).data
|
|
|
|
customer_emails = [customer["email"] for customer in customers]
|
|
|
|
# Split customer emails into batches of 50
|
|
email_batches = [
|
|
customer_emails[i : i + 20] for i in range(0, len(customer_emails), 20)
|
|
]
|
|
|
|
users = []
|
|
for email_batch in email_batches:
|
|
batch_users = (
|
|
supabase_db.table("users")
|
|
.select("id, email")
|
|
.in_("email", email_batch)
|
|
.execute()
|
|
).data
|
|
users.extend(batch_users)
|
|
|
|
product_features = (
|
|
supabase_db.table("product_to_features").select("*").execute()
|
|
).data
|
|
|
|
user_settings = (supabase_db.table("user_settings").select("*").execute()).data
|
|
|
|
# Create lookup dictionaries for faster access
|
|
user_dict = {user["email"]: user["id"] for user in users}
|
|
customer_dict = {customer["id"]: customer for customer in customers}
|
|
product_dict = {
|
|
product["stripe_product_id"]: product for product in product_features
|
|
}
|
|
settings_dict = {setting["user_id"]: setting for setting in user_settings}
|
|
|
|
# Process subscriptions and update user settings
|
|
premium_user_ids = set()
|
|
settings_to_upsert = {}
|
|
for sub in subscriptions:
|
|
if sub["attrs"]["status"] != "active":
|
|
continue
|
|
|
|
customer = customer_dict.get(sub["customer"])
|
|
if not customer:
|
|
continue
|
|
|
|
user_id = user_dict.get(customer["email"])
|
|
if not user_id:
|
|
continue
|
|
|
|
current_settings = settings_dict.get(user_id, {})
|
|
last_check = current_settings.get("last_stripe_check")
|
|
|
|
# Skip if the user was checked recently
|
|
if last_check and datetime.fromisoformat(last_check) > memoization_cutoff:
|
|
premium_user_ids.add(user_id)
|
|
continue
|
|
|
|
user_id = str(user_id) # Ensure user_id is a string
|
|
premium_user_ids.add(user_id)
|
|
|
|
product_id = sub["attrs"]["items"]["data"][0]["plan"]["product"]
|
|
product = product_dict.get(product_id)
|
|
if not product:
|
|
logger.warning(f"No matching product found for subscription: {sub['id']}")
|
|
continue
|
|
|
|
settings_to_upsert[user_id] = {
|
|
"user_id": user_id,
|
|
"max_brains": product["max_brains"],
|
|
"max_brain_size": product["max_brain_size"],
|
|
"monthly_chat_credit": product["monthly_chat_credit"],
|
|
"api_access": product["api_access"],
|
|
"models": product["models"],
|
|
"is_premium": True,
|
|
"last_stripe_check": current_time_str,
|
|
}
|
|
|
|
# Bulk upsert premium user settings in batches of 10
|
|
settings_list = list(settings_to_upsert.values())
|
|
for i in range(0, len(settings_list), 10):
|
|
batch = settings_list[i : i + 10]
|
|
supabase_db.table("user_settings").upsert(batch).execute()
|
|
|
|
# Delete settings for non-premium users in batches of 10
|
|
settings_to_delete = [
|
|
setting["user_id"]
|
|
for setting in user_settings
|
|
if setting["user_id"] not in premium_user_ids and setting.get("is_premium")
|
|
]
|
|
for i in range(0, len(settings_to_delete), 10):
|
|
batch = settings_to_delete[i : i + 10]
|
|
supabase_db.table("user_settings").delete().in_("user_id", batch).execute()
|
|
|
|
logger.info(
|
|
f"Updated {len(settings_to_upsert)} premium users, deleted settings for {len(settings_to_delete)} non-premium users"
|
|
)
|
|
return True
|
|
|
|
|
|
celery.conf.beat_schedule = {
|
|
"remove_onboarding_more_than_x_days_task": {
|
|
"task": f"{__name__}.remove_onboarding_more_than_x_days_task",
|
|
"schedule": crontab(minute="0", hour="0"),
|
|
},
|
|
"ping_telemetry": {
|
|
"task": f"{__name__}.ping_telemetry",
|
|
"schedule": crontab(minute="*/30", hour="*"),
|
|
},
|
|
"process_sync_active": {
|
|
"task": "process_sync_active",
|
|
"schedule": crontab(minute="*/1", hour="*"),
|
|
},
|
|
"process_premium_users": {
|
|
"task": "check_if_is_premium_user",
|
|
"schedule": crontab(minute="*/1", hour="*"),
|
|
},
|
|
}
|