feat(upload): async improved (#2544)

# Description
Hey,

Here's a breakdown of what I've done:

- Reducing the number of open file descriptors and the memory footprint:
Previously, for each uploaded file we opened a `NamedTemporaryFile` and
wrote into it the content downloaded from Supabase; because of the
dependency on `langchain` loader classes, we couldn't feed the loaders
from in-memory buffers. With this change we open a single temporary file
per `process_file_and_notify` call, cutting down on excessive file
opening, read syscalls, and memory buffer usage. The previous behavior
could cause stability issues when ingesting and processing large volumes
of documents. Some code paths still reopen temporary files, but that can
be improved in later work (see the worker sketch after this list).
- Removing `UploadFile` from `File`: `UploadFile` (a FastAPI abstraction
over a `SpooledTemporaryFile` for multipart uploads) was redundant in
our `File` setup, since we already download the file from remote
storage, read it into memory, and write it to a temporary file. Removing
this abstraction streamlines the code and eliminates unnecessary
complexity (a condensed view of the new `File` model follows the
description).
- `async` function adjustments: I've removed the `async` keyword from
functions that aren't truly asynchronous. For instance, `filter_file`
isn't genuinely async, as Starlette's async file reading isn't actually
asynchronous; it [uses a threadpool for reading the
file](9f16bf5c25/starlette/datastructures.py#L458). Given that we
already rely on `celery` for parallelism (one worker per core), we want
reading and processing to happen in the same thread, or at least to
minimize thread spawning. Since the rest of the code isn't inherently
asynchronous either, the bottleneck is CPU-bound processing rather than
I/O, so an event loop adds no benefit.
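
To make the intent concrete, here is a minimal sketch of the new worker
flow, simplified from the diff below: the celery task decorator,
notifications, error handling, and the crawl/GitHub paths are omitted,
and `get_supabase_client` is assumed to be the existing project helper.

```python
import os
from tempfile import NamedTemporaryFile

from models.files import File                      # reworked model (see diff)
from packages.files.processors import filter_file  # now a plain function
# get_supabase_client comes from the project's settings module


def process_file_and_notify(file_name: str, brain_id, file_original_name: str):
    supabase_client = get_supabase_client()
    base_file_name = os.path.basename(file_name)
    _, file_extension = os.path.splitext(base_file_name)

    # One temporary file per task: the downloaded bytes are written once,
    # and the langchain loaders later read from this path on disk.
    with NamedTemporaryFile(suffix="_" + file_name.replace("/", "_")) as tmp_file:
        content = supabase_client.storage.from_("quivr").download(file_name)
        tmp_file.write(content)
        tmp_file.flush()

        file_instance = File(
            file_name=base_file_name,
            tmp_file_path=tmp_file.name,
            bytes_content=content,
            file_size=len(content),
            file_extension=file_extension,
        )

        # filter_file is now a plain synchronous call: no event loop and no
        # extra threadpool, everything stays on the celery worker thread.
        return filter_file(
            file=file_instance,
            brain_id=brain_id,
            original_file_name=file_original_name,
        )
```

The only file descriptor the task opens is the temporary file itself,
and its lifetime is bounded by the `with` block.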

These changes aim to improve performance and streamline our codebase. 
Let me know if you have any questions or suggestions for further
improvements!
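
For reference, a condensed view of the reworked `File` model as it
appears in the diff, trimmed to the fields discussed above (sha1
computation, text splitting, and the vector-store helpers are omitted):

```python
from pathlib import Path
from typing import List, Optional

from langchain_core.documents import Document
from pydantic import BaseModel


class File(BaseModel):
    file_name: str
    tmp_file_path: Path   # single on-disk temp file written by the task
    bytes_content: bytes  # raw bytes, used to compute the sha1 up front
    file_size: int
    file_extension: str
    file_sha1: Optional[str] = None
    documents: List[Document] = []

    def compute_documents(self, loader_class):
        # Loaders read straight from the existing temp file; no re-spooling
        # through UploadFile / SpooledTemporaryFile anymore.
        loader = loader_class(self.tmp_file_path)
        self.documents = loader.load()
```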

## Checklist before requesting a review
- [x] My code follows the style guidelines of this project
- [x] I have performed a self-review of my code
- [x] I have ideally added tests that prove my fix is effective or that
my feature works

---------

Signed-off-by: aminediro <aminediro@github.com>
Co-authored-by: aminediro <aminediro@github.com>
Co-authored-by: Stan Girard <girard.stanislas@gmail.com>
14 changed files with 103 additions and 177 deletions

View File

@ -1,8 +1,11 @@
# celery_config.py
import os
import dotenv
from celery import Celery
dotenv.load_dotenv()
CELERY_BROKER_URL = os.getenv("CELERY_BROKER_URL", "")
CELERY_BROKER_QUEUE_NAME = os.getenv("CELERY_BROKER_QUEUE_NAME", "quivr")

View File

@ -1,11 +1,10 @@
import asyncio
import io
import os
from datetime import datetime, timezone
from tempfile import NamedTemporaryFile
from uuid import UUID
from celery.schedules import crontab
from celery_config import celery
from fastapi import UploadFile
from logger import get_logger
from middlewares.auth.auth_bearer import AuthBearer
from models.files import File
@ -18,7 +17,7 @@ from modules.notification.dto.inputs import NotificationUpdatableProperties
from modules.notification.entity.notification import NotificationsStatusEnum
from modules.notification.service.notification_service import NotificationService
from modules.onboarding.service.onboarding_service import OnboardingService
from packages.files.crawl.crawler import CrawlWebsite
from packages.files.crawl.crawler import CrawlWebsite, slugify
from packages.files.parsers.github import process_github
from packages.files.processors import filter_file
from packages.utils.telemetry import maybe_send_telemetry
@ -42,39 +41,36 @@ def process_file_and_notify(
):
try:
supabase_client = get_supabase_client()
tmp_file_name = "tmp-file-" + file_name
tmp_file_name = tmp_file_name.replace("/", "_")
tmp_name = file_name.replace("/", "_")
base_file_name = os.path.basename(file_name)
_, file_extension = os.path.splitext(base_file_name)
with open(tmp_file_name, "wb+") as f:
with NamedTemporaryFile(
suffix="_" + tmp_name, # pyright: ignore reportPrivateUsage=none
) as tmp_file:
res = supabase_client.storage.from_("quivr").download(file_name)
f.write(res)
f.seek(0)
file_content = f.read()
upload_file = UploadFile(
file=f, filename=file_name.split("/")[-1], size=len(file_content)
tmp_file.write(res)
tmp_file.flush()
file_instance = File(
file_name=base_file_name,
tmp_file_path=tmp_file.name,
bytes_content=res,
file_size=len(res),
file_extension=file_extension,
)
file_instance = File(file=upload_file)
loop = asyncio.get_event_loop()
brain_vector_service = BrainVectorService(brain_id)
if delete_file: # TODO fix bug
brain_vector_service.delete_file_from_brain(
file_original_name, only_vectors=True
)
message = loop.run_until_complete(
filter_file(
file=file_instance,
brain_id=brain_id,
original_file_name=file_original_name,
)
message = filter_file(
file=file_instance,
brain_id=brain_id,
original_file_name=file_original_name,
)
f.close()
os.remove(tmp_file_name)
if notification_id:
notification_service.update_notification_by_id(
notification_id,
NotificationUpdatableProperties(
@ -85,10 +81,12 @@ def process_file_and_notify(
brain_service.update_brain_last_update_time(brain_id)
return True
except TimeoutError:
logger.error("TimeoutError")
except Exception as e:
logger.exception(e)
notification_service.update_notification_by_id(
notification_id,
NotificationUpdatableProperties(
@ -96,52 +94,51 @@ def process_file_and_notify(
description=f"An error occurred while processing the file: {e}",
),
)
return False
@celery.task(name="process_crawl_and_notify")
def process_crawl_and_notify(
crawl_website_url,
brain_id,
crawl_website_url: str,
brain_id: UUID,
notification_id=None,
):
crawl_website = CrawlWebsite(url=crawl_website_url)
if not crawl_website.checkGithub():
file_path, file_name = crawl_website.process()
# Build file data
extracted_content = crawl_website.process()
extracted_content_bytes = extracted_content.encode("utf-8")
file_name = slugify(crawl_website.url) + ".txt"
with open(file_path, "rb") as f:
file_content = f.read()
# Create a file-like object in memory using BytesIO
file_object = io.BytesIO(file_content)
upload_file = UploadFile(
file=file_object, filename=file_name, size=len(file_content)
)
file_instance = File(file=upload_file)
loop = asyncio.get_event_loop()
message = loop.run_until_complete(
filter_file(
with NamedTemporaryFile(
suffix="_" + file_name, # pyright: ignore reportPrivateUsage=none
) as tmp_file:
tmp_file.write(extracted_content_bytes)
tmp_file.flush()
file_instance = File(
file_name=file_name,
tmp_file_path=tmp_file.name,
bytes_content=extracted_content_bytes,
file_size=len(extracted_content),
file_extension=".txt",
)
message = filter_file(
file=file_instance,
brain_id=brain_id,
original_file_name=crawl_website_url,
)
)
notification_service.update_notification_by_id(
notification_id,
NotificationUpdatableProperties(
status=NotificationsStatusEnum.SUCCESS,
description=f"Your URL has been properly crawled!",
),
)
else:
loop = asyncio.get_event_loop()
message = loop.run_until_complete(
process_github(
repo=crawl_website.url,
brain_id=brain_id,
notification_service.update_notification_by_id(
notification_id,
NotificationUpdatableProperties(
status=NotificationsStatusEnum.SUCCESS,
description="Your URL has been properly crawled!",
),
)
else:
message = process_github(
repo=crawl_website.url,
brain_id=brain_id,
)
if notification_id:

View File

@ -1,15 +1,9 @@
import os
if __name__ == "__main__":
# import needed here when running main.py to debug backend
# you will need to run pip install python-dotenv
from dotenv import load_dotenv # type: ignore
load_dotenv()
import logging
import os
import litellm
import sentry_sdk
from dotenv import load_dotenv # type: ignore
from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import HTMLResponse, JSONResponse
from logger import get_logger
@ -35,6 +29,8 @@ from routes.subscription_routes import subscription_router
from sentry_sdk.integrations.fastapi import FastApiIntegration
from sentry_sdk.integrations.starlette import StarletteIntegration
load_dotenv()
# Set the logging level for all loggers to WARNING
logging.basicConfig(level=logging.INFO)
logging.getLogger("httpx").setLevel(logging.WARNING)

View File

@ -1,66 +1,38 @@
import os
import tempfile
from typing import Any, Optional
from uuid import UUID
from pathlib import Path
from typing import List, Optional
from fastapi import UploadFile
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_core.documents import Document
from logger import get_logger
from models.databases.supabase.supabase import SupabaseDB
from models.settings import get_supabase_db
from modules.brain.service.brain_vector_service import BrainVectorService
from packages.files.file import compute_sha1_from_file
from packages.files.file import compute_sha1_from_content
from pydantic import BaseModel
logger = get_logger(__name__)
class File(BaseModel):
id: Optional[UUID] = None
file: Optional[UploadFile] = None
file_name: Optional[str] = ""
file_size: Optional[int] = None
file_sha1: Optional[str] = ""
vectors_ids: Optional[list] = []
file_extension: Optional[str] = ""
content: Optional[Any] = None
file_name: str
tmp_file_path: Path
bytes_content: bytes
file_size: int
file_extension: str
chunk_size: int = 400
chunk_overlap: int = 100
documents: Optional[Document] = None
documents: List[Document] = []
file_sha1: Optional[str] = None
vectors_ids: Optional[list] = []
def __init__(self, **data):
super().__init__(**data)
data["file_sha1"] = compute_sha1_from_content(data["bytes_content"])
@property
def supabase_db(self) -> SupabaseDB:
return get_supabase_db()
def __init__(self, **kwargs):
super().__init__(**kwargs)
if self.file:
self.file_name = self.file.filename
self.file_size = self.file.size # pyright: ignore reportPrivateUsage=none
self.file_extension = os.path.splitext(
self.file.filename # pyright: ignore reportPrivateUsage=none
)[-1].lower()
async def compute_file_sha1(self):
"""
Compute the sha1 of the file using a temporary file
"""
with tempfile.NamedTemporaryFile(
delete=False,
suffix=self.file.filename, # pyright: ignore reportPrivateUsage=none
) as tmp_file:
await self.file.seek(0) # pyright: ignore reportPrivateUsage=none
self.content = (
await self.file.read() # pyright: ignore reportPrivateUsage=none
)
tmp_file.write(self.content)
tmp_file.flush()
self.file_sha1 = compute_sha1_from_file(tmp_file.name)
os.remove(tmp_file.name)
def compute_documents(self, loader_class):
"""
Compute the documents from the file
@ -69,18 +41,8 @@ class File(BaseModel):
loader_class (class): The class of the loader to use to load the file
"""
logger.info(f"Computing documents from file {self.file_name}")
documents = []
with tempfile.NamedTemporaryFile(
delete=False,
suffix=self.file.filename, # pyright: ignore reportPrivateUsage=none
) as tmp_file:
tmp_file.write(self.content) # pyright: ignore reportPrivateUsage=none
tmp_file.flush()
loader = loader_class(tmp_file.name)
documents = loader.load()
os.remove(tmp_file.name)
loader = loader_class(self.tmp_file_path)
documents = loader.load()
text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
chunk_size=self.chunk_size, chunk_overlap=self.chunk_overlap
@ -129,7 +91,7 @@ class File(BaseModel):
"""
Check if file is empty by checking if the file pointer is at the beginning of the file
"""
return self.file.size < 1 # pyright: ignore reportPrivateUsage=none
return self.file_size < 1 # pyright: ignore reportPrivateUsage=none
def link_file_to_brain(self, brain_id):
self.set_file_vectors_ids()

View File

@ -314,7 +314,7 @@ class QuivrRAG(BaseModel):
self.brain_id
) # pyright: ignore reportPrivateUsage=none
list_files_array = [file.file_name for file in list_files_array]
list_files_array = [file.file_name or file.url for file in list_files_array]
# Max first 10 files
if len(list_files_array) > 20:
list_files_array = list_files_array[:20]

View File

@ -74,7 +74,7 @@ async def upload_file(
filename_with_brain_id = str(brain_id) + "/" + str(uploadFile.filename)
try:
file_in_storage = upload_file_storage(file_content, filename_with_brain_id)
upload_file_storage(file_content, filename_with_brain_id)
except Exception as e:
print(e)
@ -104,7 +104,7 @@ async def upload_file(
)[-1].lower(),
)
added_knowledge = knowledge_service.add_knowledge(knowledge_to_add)
knowledge_service.add_knowledge(knowledge_to_add)
process_file_and_notify.delay(
file_name=filename_with_brain_id,

View File

@ -1,6 +1,5 @@
import os
import re
import tempfile
import unicodedata
from langchain_community.document_loaders import PlaywrightURLLoader
@ -17,27 +16,21 @@ class CrawlWebsite(BaseModel):
max_pages: int = 100
max_time: int = 60
def process(self):
def process(self) -> str:
# Extract and combine content recursively
loader = PlaywrightURLLoader(
urls=[self.url], remove_selectors=["header", "footer"]
)
data = loader.load()
data = loader.load()
# Now turn the data into a string
logger.info(f"Extracted content from {len(data)} pages")
logger.info(data)
logger.debug(f"Extracted data : {data}")
extracted_content = ""
for page in data:
extracted_content += page.page_content
# Create a file
file_name = slugify(self.url) + ".txt"
temp_file_path = os.path.join(tempfile.gettempdir(), file_name)
with open(temp_file_path, "w") as temp_file:
temp_file.write(extracted_content) # type: ignore
return temp_file_path, file_name
return extracted_content
def checkGithub(self):
return "github.com" in self.url

View File

@ -1,5 +1,3 @@
import os
import tempfile
import time
import openai
@ -9,33 +7,13 @@ from models import File, get_documents_vector_store
from packages.files.file import compute_sha1_from_content
async def process_audio(
file: File, user, original_file_name, integration=None, integration_link=None
):
temp_filename = None
file_sha = ""
def process_audio(file: File, **kwargs):
dateshort = time.strftime("%Y%m%d-%H%M%S")
file_meta_name = f"audiotranscript_{dateshort}.txt"
documents_vector_store = get_documents_vector_store()
try:
upload_file = file.file
with tempfile.NamedTemporaryFile(
delete=False,
suffix=upload_file.filename, # pyright: ignore reportPrivateUsage=none
) as tmp_file:
await upload_file.seek(0) # pyright: ignore reportPrivateUsage=none
content = (
await upload_file.read() # pyright: ignore reportPrivateUsage=none
)
tmp_file.write(content)
tmp_file.flush()
tmp_file.close()
temp_filename = tmp_file.name
with open(tmp_file.name, "rb") as audio_file:
transcript = openai.Audio.transcribe("whisper-1", audio_file)
with open(file.tmp_file_path, "rb") as audio_file:
transcript = openai.Audio.transcribe("whisper-1", audio_file)
file_sha = compute_sha1_from_content(
transcript.text.encode("utf-8") # pyright: ignore reportPrivateUsage=none
@ -70,7 +48,3 @@ async def process_audio(
]
documents_vector_store.add_documents(docs_with_metadata)
finally:
if temp_filename and os.path.exists(temp_filename):
os.remove(temp_filename)

View File

@ -4,10 +4,10 @@ from models import File
from .common import process_file
async def process_python(
def process_python(
file: File, brain_id, original_file_name, integration=None, integration_link=None
):
return await process_file(
return process_file(
file=file,
loader_class=PythonLoader,
brain_id=brain_id,

View File

@ -21,7 +21,7 @@ if not isinstance(asyncio.get_event_loop(), uvloop.Loop):
logger = get_logger(__name__)
async def process_file(
def process_file(
file: File,
loader_class,
brain_id,

View File

@ -9,7 +9,7 @@ from packages.embeddings.vectors import Neurons
from packages.files.file import compute_sha1_from_content
async def process_github(
def process_github(
repo,
brain_id,
):

View File

@ -5,7 +5,11 @@ from .common import process_file
def process_pdf(
file: File, brain_id, original_file_name, integration=None, integration_link=None
file: File,
brain_id,
original_file_name,
integration=None,
integration_link=None,
):
return process_file(
file=file,

View File

@ -4,10 +4,10 @@ from models import File
from .common import process_file
async def process_txt(
def process_txt(
file: File, brain_id, original_file_name, integration=None, integration_link=None
):
return await process_file(
return process_file(
file=file,
loader_class=TextLoader,
brain_id=brain_id,

View File

@ -1,4 +1,3 @@
from fastapi import HTTPException
from modules.brain.service.brain_service import BrainService
from .parsers.audio import process_audio
@ -52,16 +51,14 @@ brain_service = BrainService()
# TODO: Move filter_file to a file service to avoid circular imports from models/files.py for File class
async def filter_file(
def filter_file(
file,
brain_id,
original_file_name=None,
):
await file.compute_file_sha1()
file_exists = file.file_already_exists()
file_exists_in_brain = file.file_already_exists_in_brain(brain_id)
using_file_name = original_file_name or file.file.filename if file.file else ""
using_file_name = file.file_name
brain = brain_service.get_brain_by_id(brain_id)
if brain is None:
@ -86,7 +83,7 @@ async def filter_file(
if file.file_extension in file_processors:
try:
result = await file_processors[file.file_extension](
result = file_processors[file.file_extension](
file=file,
brain_id=brain_id,
original_file_name=original_file_name,