quivr/backend/models/files.py

import os
import tempfile
from typing import Any, Optional
from uuid import UUID
from fastapi import UploadFile
from langchain.text_splitter import RecursiveCharacterTextSplitter
from logger import get_logger
from models.databases.supabase.supabase import SupabaseDB
from models.settings import get_supabase_db
from modules.brain.service.brain_vector_service import BrainVectorService
from packages.files.file import compute_sha1_from_file
from pydantic import BaseModel

logger = get_logger(__name__)


class File(BaseModel):
    id: Optional[UUID] = None
    file: Optional[UploadFile]
    file_name: Optional[str] = ""
    file_size: Optional[int] = None
    file_sha1: Optional[str] = ""
    vectors_ids: Optional[list] = []
    file_extension: Optional[str] = ""
    content: Optional[Any] = None
    chunk_size: int = 500
    chunk_overlap: int = 0
    documents: Optional[Any] = None

    @property
    def supabase_db(self) -> SupabaseDB:
        return get_supabase_db()

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        if self.file:
            self.file_name = self.file.filename
            self.file_size = self.file.size  # pyright: ignore reportPrivateUsage=none
            self.file_extension = os.path.splitext(
                self.file.filename  # pyright: ignore reportPrivateUsage=none
            )[-1].lower()
2023-06-28 20:39:27 +03:00
async def compute_file_sha1(self):
2023-07-10 20:28:38 +03:00
"""
Compute the sha1 of the file using a temporary file
"""
with tempfile.NamedTemporaryFile(
delete=False,
suffix=self.file.filename, # pyright: ignore reportPrivateUsage=none
) as tmp_file:
await self.file.seek(0) # pyright: ignore reportPrivateUsage=none
self.content = (
await self.file.read() # pyright: ignore reportPrivateUsage=none
)
2023-06-28 20:39:27 +03:00
tmp_file.write(self.content)
tmp_file.flush()
self.file_sha1 = compute_sha1_from_file(tmp_file.name)
os.remove(tmp_file.name)
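
    # Note (illustrative, not part of the original module): compute_sha1_from_file
    # is imported from packages.files.file and is assumed to hash the bytes of the
    # file written to disk. A minimal standard-library sketch of such a helper
    # could look like:
    #
    #     import hashlib
    #
    #     def compute_sha1_from_file(path: str) -> str:
    #         with open(path, "rb") as f:
    #             return hashlib.sha1(f.read()).hexdigest()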

    def compute_documents(self, loader_class):
        """
        Compute the documents from the file

        Args:
            loader_class (class): The class of the loader to use to load the file
        """
        logger.info(f"Computing documents from file {self.file_name}")

        documents = []
        with tempfile.NamedTemporaryFile(
            delete=False,
            suffix=self.file.filename,  # pyright: ignore reportPrivateUsage=none
        ) as tmp_file:
            tmp_file.write(self.content)  # pyright: ignore reportPrivateUsage=none
            tmp_file.flush()
            loader = loader_class(tmp_file.name)
            documents = loader.load()

        os.remove(tmp_file.name)

        text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
            chunk_size=self.chunk_size, chunk_overlap=self.chunk_overlap
        )
        self.documents = text_splitter.split_documents(documents)
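
    # Note (illustrative, not part of the original module): the splitter above
    # chunks the loaded documents by token count. A standalone sketch of the same
    # pattern, assuming langchain's TextLoader and a hypothetical "notes.txt":
    #
    #     from langchain.document_loaders import TextLoader
    #
    #     docs = TextLoader("notes.txt").load()
    #     splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
    #         chunk_size=500, chunk_overlap=0
    #     )
    #     chunks = splitter.split_documents(docs)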

    def set_file_vectors_ids(self):
        """
        Set the vectors_ids property with the ids of the vectors
        that are associated with the file in the vectors table
        """
        self.vectors_ids = self.supabase_db.get_vectors_by_file_sha1(
            self.file_sha1
        ).data
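
    # Note (illustrative, not part of the original module): the exact query lives
    # in SupabaseDB. get_vectors_by_file_sha1 is assumed to be roughly equivalent
    # to a supabase-py call along these lines:
    #
    #     response = client.table("vectors").select("id").eq(
    #         "file_sha1", file_sha1
    #     ).execute()
    #     # response.data would then be a list of {"id": ...} rows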

    def file_already_exists(self):
        """
        Check if file already exists in vectors table
        """
        self.set_file_vectors_ids()

        # if the file does not exist in vectors then no need to go check in brains_vectors
        if len(self.vectors_ids) == 0:  # pyright: ignore reportPrivateUsage=none
            return False

        return True

    def file_already_exists_in_brain(self, brain_id):
        """
        Check if file already exists in a brain

        Args:
            brain_id (str): Brain id
        """
        response = self.supabase_db.get_brain_vectors_by_brain_id_and_file_sha1(
            brain_id, self.file_sha1  # type: ignore
        )

        if len(response.data) == 0:
            return False

        return True

    def file_is_empty(self):
        """
        Check if the file is empty by verifying that its size is less than one byte
        """
        return self.file.size < 1  # pyright: ignore reportPrivateUsage=none

    def link_file_to_brain(self, brain_id):
        self.set_file_vectors_ids()

        if self.vectors_ids is None:
            return

        brain_vector_service = BrainVectorService(brain_id)
        for vector_id in self.vectors_ids:  # pyright: ignore reportPrivateUsage=none
            brain_vector_service.create_brain_vector(vector_id["id"], self.file_sha1)
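

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original module): a minimal,
# hypothetical end-to-end flow for an uploaded file. The `_example_process_upload`
# name and the TextLoader import are assumptions made for illustration only.
# ---------------------------------------------------------------------------
async def _example_process_upload(upload: UploadFile, brain_id: str) -> None:
    from langchain.document_loaders import TextLoader  # assumed loader for plain text

    file = File(file=upload)
    await file.compute_file_sha1()

    # Skip empty uploads entirely.
    if file.file_is_empty():
        return

    if file.file_already_exists():
        # Vectors for this content already exist; just link them to this brain.
        if not file.file_already_exists_in_brain(brain_id):
            file.link_file_to_brain(brain_id)
        return

    # New file: split it into documents ready for embedding.
    file.compute_documents(TextLoader)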