diff --git a/.github/workflows/backend-tests.yml b/.github/workflows/backend-tests.yml index 013d207c3..cc77192c6 100644 --- a/.github/workflows/backend-tests.yml +++ b/.github/workflows/backend-tests.yml @@ -9,7 +9,9 @@ on: jobs: test: runs-on: ubuntu-latest - + strategy: + matrix: + project: [quivr-api, quivr-worker] steps: - name: 👀 Checkout code uses: actions/checkout@v2 @@ -65,4 +67,4 @@ jobs: supabase start rye run python -c "from unstructured.nlp.tokenize import download_nltk_packages; download_nltk_packages()" rye run python -c "import nltk;nltk.download('punkt_tab'); nltk.download('averaged_perceptron_tagger_eng')" - rye test -p quivr-api -p quivr-worker + rye test -p ${{ matrix.project }} diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index aabdcae30..9496988b0 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -32,13 +32,6 @@ repos: - id: mypy name: mypy additional_dependencies: ["types-aiofiles"] - - repo: https://github.com/python-poetry/poetry - rev: "1.8.0" - hooks: - - id: poetry-check - args: ["-C", "./backend/core"] - - id: poetry-lock - args: ["-C", "./backend/core"] ci: autofix_commit_msg: | [pre-commit.ci] auto fixes from pre-commit.com hooks diff --git a/backend/api/quivr_api/middlewares/auth/auth_bearer.py b/backend/api/quivr_api/middlewares/auth/auth_bearer.py index 3001b7f45..73e3867cf 100644 --- a/backend/api/quivr_api/middlewares/auth/auth_bearer.py +++ b/backend/api/quivr_api/middlewares/auth/auth_bearer.py @@ -3,6 +3,7 @@ from typing import Optional from fastapi import Depends, HTTPException, Request from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer + from quivr_api.middlewares.auth.jwt_token_handler import ( decode_access_token, verify_token, @@ -57,9 +58,13 @@ class AuthBearer(HTTPBearer): def get_test_user(self) -> UserIdentity: return UserIdentity( - email="admin@quivr.app", id="39418e3b-0258-4452-af60-7acfcc1263ff" # type: ignore + email="admin@quivr.app", + 
id="39418e3b-0258-4452-af60-7acfcc1263ff", # type: ignore ) # replace with test user information -def get_current_user(user: UserIdentity = Depends(AuthBearer())) -> UserIdentity: +auth_bearer = AuthBearer() + + +def get_current_user(user: UserIdentity = Depends(auth_bearer)) -> UserIdentity: return user diff --git a/backend/api/quivr_api/modules/brain/entity/brain_entity.py b/backend/api/quivr_api/modules/brain/entity/brain_entity.py index 6a722bda7..0b8e3460c 100644 --- a/backend/api/quivr_api/modules/brain/entity/brain_entity.py +++ b/backend/api/quivr_api/modules/brain/entity/brain_entity.py @@ -69,6 +69,7 @@ class Brain(AsyncAttrs, SQLModel, table=True): back_populates="brains", link_model=KnowledgeBrain ) + # TODO : add # "meaning" "public"."vector", # "tags" "public"."tags"[] diff --git a/backend/api/quivr_api/modules/brain/service/brain_vector_service.py b/backend/api/quivr_api/modules/brain/service/brain_vector_service.py index 4016b8a7e..ec514cdd7 100644 --- a/backend/api/quivr_api/modules/brain/service/brain_vector_service.py +++ b/backend/api/quivr_api/modules/brain/service/brain_vector_service.py @@ -2,7 +2,7 @@ from uuid import UUID from quivr_api.logger import get_logger from quivr_api.modules.brain.repository.brains_vectors import BrainsVectors -from quivr_api.modules.knowledge.repository.storage import Storage +from quivr_api.modules.knowledge.repository.storage import SupabaseS3Storage logger = get_logger(__name__) @@ -11,7 +11,7 @@ class BrainVectorService: def __init__(self, brain_id: UUID): self.repository = BrainsVectors() self.brain_id = brain_id - self.storage = Storage() + self.storage = SupabaseS3Storage() def create_brain_vector(self, vector_id: str, file_sha1: str): return self.repository.create_brain_vector(self.brain_id, vector_id, file_sha1) # type: ignore @@ -26,10 +26,10 @@ class BrainVectorService: for vector_id in vector_ids: self.create_brain_vector(vector_id, file_sha1) - def delete_file_from_brain(self, file_name: str, 
only_vectors: bool = False): + async def delete_file_from_brain(self, file_name: str, only_vectors: bool = False): file_name_with_brain_id = f"{self.brain_id}/{file_name}" if not only_vectors: - self.storage.remove_file(file_name_with_brain_id) + await self.storage.remove_file(file_name_with_brain_id) return self.repository.delete_file_from_brain(self.brain_id, file_name) # type: ignore def delete_file_url_from_brain(self, file_name: str): diff --git a/backend/api/quivr_api/modules/conftest.py b/backend/api/quivr_api/modules/conftest.py index 721eacaeb..d9def549c 100644 --- a/backend/api/quivr_api/modules/conftest.py +++ b/backend/api/quivr_api/modules/conftest.py @@ -24,9 +24,6 @@ async_engine = create_async_engine( "postgresql+asyncpg://" + pg_database_base_url, echo=True if os.getenv("ORM_DEBUG") else False, future=True, - pool_pre_ping=True, - pool_size=10, - pool_recycle=0.1, ) diff --git a/backend/api/quivr_api/modules/dependencies.py b/backend/api/quivr_api/modules/dependencies.py index edb6f728b..fd71696cd 100644 --- a/backend/api/quivr_api/modules/dependencies.py +++ b/backend/api/quivr_api/modules/dependencies.py @@ -7,8 +7,7 @@ from langchain.embeddings.base import Embeddings from langchain_community.embeddings.ollama import OllamaEmbeddings # from langchain_community.vectorstores.supabase import SupabaseVectorStore -from langchain_openai import OpenAIEmbeddings -from langchain_openai import AzureOpenAIEmbeddings +from langchain_openai import AzureOpenAIEmbeddings, OpenAIEmbeddings # from quivr_api.modules.vector.service.vector_service import VectorService # from quivr_api.modules.vectorstore.supabase import CustomSupabaseVectorStore @@ -22,7 +21,6 @@ from quivr_api.models.databases.supabase.supabase import SupabaseDB from quivr_api.models.settings import BrainSettings from supabase.client import AsyncClient, Client, create_async_client, create_client - # Global variables to store the Supabase client and database instances _supabase_client: 
Optional[Client] = None _supabase_async_client: Optional[AsyncClient] = None diff --git a/backend/api/quivr_api/modules/knowledge/controller/knowledge_routes.py b/backend/api/quivr_api/modules/knowledge/controller/knowledge_routes.py index 5003eb8fb..68d01afb0 100644 --- a/backend/api/quivr_api/modules/knowledge/controller/knowledge_routes.py +++ b/backend/api/quivr_api/modules/knowledge/controller/knowledge_routes.py @@ -1,8 +1,8 @@ from http import HTTPStatus -from typing import Annotated +from typing import Annotated, List, Optional from uuid import UUID -from fastapi import APIRouter, Depends, HTTPException, Query +from fastapi import APIRouter, Depends, File, HTTPException, Query, UploadFile, status from quivr_api.logger import get_logger from quivr_api.middlewares.auth import AuthBearer, get_current_user @@ -12,6 +12,14 @@ from quivr_api.modules.brain.service.brain_authorization_service import ( validate_brain_authorization, ) from quivr_api.modules.dependencies import get_service +from quivr_api.modules.knowledge.dto.inputs import AddKnowledge +from quivr_api.modules.knowledge.entity.knowledge import Knowledge, KnowledgeUpdate +from quivr_api.modules.knowledge.service.knowledge_exceptions import ( + KnowledgeDeleteError, + KnowledgeForbiddenAccess, + KnowledgeNotFoundException, + UploadError, +) from quivr_api.modules.knowledge.service.knowledge_service import KnowledgeService from quivr_api.modules.upload.service.generate_file_signed_url import ( generate_file_signed_url, @@ -21,9 +29,8 @@ from quivr_api.modules.user.entity.user_identity import UserIdentity knowledge_router = APIRouter() logger = get_logger(__name__) -KnowledgeServiceDep = Annotated[ - KnowledgeService, Depends(get_service(KnowledgeService)) -] +get_km_service = get_service(KnowledgeService) +KnowledgeServiceDep = Annotated[KnowledgeService, Depends(get_km_service)] @knowledge_router.get( @@ -53,7 +60,7 @@ async def list_knowledge_in_brain_endpoint( ], tags=["Knowledge"], ) -async def 
delete_endpoint( +async def delete_knowledge_brain( knowledge_id: UUID, knowledge_service: KnowledgeServiceDep, current_user: UserIdentity = Depends(get_current_user), @@ -65,7 +72,7 @@ async def delete_endpoint( knowledge = await knowledge_service.get_knowledge(knowledge_id) file_name = knowledge.file_name if knowledge.file_name else knowledge.url - await knowledge_service.remove_knowledge(brain_id, knowledge_id) + await knowledge_service.remove_knowledge_brain(brain_id, knowledge_id) return { "message": f"{file_name} of brain {brain_id} has been deleted by user {current_user.email}." @@ -88,13 +95,13 @@ async def generate_signed_url_endpoint( knowledge = await knowledge_service.get_knowledge(knowledge_id) - if len(knowledge.brain_ids) == 0: + if len(knowledge.brains) == 0: raise HTTPException( status_code=HTTPStatus.NOT_FOUND, detail="knowledge not associated with brains yet.", ) - brain_id = knowledge.brain_ids[0] + brain_id = knowledge.brains[0]["brain_id"] validate_brain_authorization(brain_id=brain_id, user_id=current_user.id) @@ -108,3 +115,153 @@ async def generate_signed_url_endpoint( file_signed_url = generate_file_signed_url(file_path_in_storage) return file_signed_url + + +@knowledge_router.post( + "/knowledge/", + tags=["Knowledge"], + response_model=Knowledge, +) +async def create_knowledge( + knowledge_data: str = File(...), + file: Optional[UploadFile] = None, + knowledge_service: KnowledgeService = Depends(get_km_service), + current_user: UserIdentity = Depends(get_current_user), +): + knowledge = AddKnowledge.model_validate_json(knowledge_data) + if not knowledge.file_name and not knowledge.url: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Either file_name or url must be provided", + ) + try: + km = await knowledge_service.create_knowledge( + knowledge_to_add=knowledge, upload_file=file, user_id=current_user.id + ) + km_dto = await km.to_dto() + return km_dto + except ValueError: + raise HTTPException( + 
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, + detail="Unprocessable knowledge ", + ) + except FileExistsError: + raise HTTPException( + status_code=status.HTTP_409_CONFLICT, detail="Existing knowledge" + ) + except UploadError: + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="Error occured uploading knowledge", + ) + except Exception: + raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR) + + +@knowledge_router.get( + "/knowledge/children", + response_model=List[Knowledge] | None, + tags=["Knowledge"], +) +async def list_knowledge( + parent_id: UUID | None = None, + knowledge_service: KnowledgeService = Depends(get_km_service), + current_user: UserIdentity = Depends(get_current_user), +): + try: + # TODO: Returns one level of children + children = await knowledge_service.list_knowledge(parent_id, current_user.id) + return [await c.to_dto(get_children=False) for c in children] + except KnowledgeNotFoundException as e: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, detail=f"{e.message}" + ) + except KnowledgeForbiddenAccess as e: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, detail=f"{e.message}" + ) + except Exception: + raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR) + + +@knowledge_router.get( + "/knowledge/{knowledge_id}", + response_model=Knowledge, + tags=["Knowledge"], +) +async def get_knowledge( + knowledge_id: UUID, + knowledge_service: KnowledgeService = Depends(get_km_service), + current_user: UserIdentity = Depends(get_current_user), +): + try: + km = await knowledge_service.get_knowledge(knowledge_id) + if km.user_id != current_user.id: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="You do not have permission to access this knowledge.", + ) + return await km.to_dto() + except KnowledgeNotFoundException as e: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, detail=f"{e.message}" + ) + except 
Exception: + raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR) + + +@knowledge_router.patch( + "/knowledge/{knowledge_id}", + status_code=status.HTTP_202_ACCEPTED, + response_model=Knowledge, + tags=["Knowledge"], +) +async def update_knowledge( + knowledge_id: UUID, + payload: KnowledgeUpdate, + knowledge_service: KnowledgeService = Depends(get_km_service), + current_user: UserIdentity = Depends(get_current_user), +): + try: + km = await knowledge_service.get_knowledge(knowledge_id) + if km.user_id != current_user.id: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="You do not have permission to access this knowledge.", + ) + km = await knowledge_service.update_knowledge(km, payload) + return km + except KnowledgeNotFoundException as e: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, detail=f"{e.message}" + ) + except Exception: + raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR) + + +@knowledge_router.delete( + "/knowledge/{knowledge_id}", + status_code=status.HTTP_202_ACCEPTED, + tags=["Knowledge"], +) +async def delete_knowledge( + knowledge_id: UUID, + knowledge_service: KnowledgeService = Depends(get_km_service), + current_user: UserIdentity = Depends(get_current_user), +): + try: + km = await knowledge_service.get_knowledge(knowledge_id) + + if km.user_id != current_user.id: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="You do not have permission to remove this knowledge.", + ) + delete_response = await knowledge_service.remove_knowledge(km) + return delete_response + except KnowledgeNotFoundException as e: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, detail=f"{e.message}" + ) + except KnowledgeDeleteError: + raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR) diff --git a/backend/api/quivr_api/modules/knowledge/dto/inputs.py b/backend/api/quivr_api/modules/knowledge/dto/inputs.py index a943ee6b4..85a2438e9 
100644 --- a/backend/api/quivr_api/modules/knowledge/dto/inputs.py +++ b/backend/api/quivr_api/modules/knowledge/dto/inputs.py @@ -16,8 +16,16 @@ class CreateKnowledgeProperties(BaseModel): file_size: Optional[int] = None file_sha1: Optional[str] = None metadata: Optional[Dict[str, str]] = None + is_folder: bool = False + parent_id: Optional[UUID] = None - def dict(self, *args, **kwargs): - knowledge_dict = super().dict(*args, **kwargs) - knowledge_dict["brain_id"] = str(knowledge_dict.get("brain_id")) - return knowledge_dict + +class AddKnowledge(BaseModel): + file_name: Optional[str] = None + url: Optional[str] = None + extension: str = ".txt" + source: str = "local" + source_link: Optional[str] = None + metadata: Optional[Dict[str, str]] = None + is_folder: bool = False + parent_id: Optional[UUID] = None diff --git a/backend/api/quivr_api/modules/knowledge/dto/outputs.py b/backend/api/quivr_api/modules/knowledge/dto/outputs.py index a020dbece..20218dfce 100644 --- a/backend/api/quivr_api/modules/knowledge/dto/outputs.py +++ b/backend/api/quivr_api/modules/knowledge/dto/outputs.py @@ -4,6 +4,6 @@ from pydantic import BaseModel class DeleteKnowledgeResponse(BaseModel): - file_name: str - status: str = "delete" + file_name: str | None = None + status: str = "DELETED" knowledge_id: UUID diff --git a/backend/api/quivr_api/modules/knowledge/entity/knowledge.py b/backend/api/quivr_api/modules/knowledge/entity/knowledge.py index def4e42f5..d890ee42d 100644 --- a/backend/api/quivr_api/modules/knowledge/entity/knowledge.py +++ b/backend/api/quivr_api/modules/knowledge/entity/knowledge.py @@ -1,5 +1,6 @@ from datetime import datetime -from typing import Dict, List, Optional +from enum import Enum +from typing import Any, Dict, List, Optional from uuid import UUID from pydantic import BaseModel @@ -12,20 +13,44 @@ from sqlmodel import Field, Relationship, SQLModel from quivr_api.modules.knowledge.entity.knowledge_brain import KnowledgeBrain +class KnowledgeSource(str, 
Enum): + LOCAL = "local" + WEB = "web" + GDRIVE = "google drive" + DROPBOX = "dropbox" + SHAREPOINT = "sharepoint" + + class Knowledge(BaseModel): id: UUID + file_size: int = 0 + status: KnowledgeStatus file_name: Optional[str] = None url: Optional[str] = None extension: str = ".txt" - status: str + is_folder: bool = False + updated_at: datetime + created_at: datetime source: Optional[str] = None source_link: Optional[str] = None - file_size: Optional[int] = None file_sha1: Optional[str] = None - updated_at: Optional[datetime] = None - created_at: Optional[datetime] = None metadata: Optional[Dict[str, str]] = None - brain_ids: list[UUID] + user_id: UUID + brains: List[Dict[str, Any]] + parent: Optional["Knowledge"] + children: Optional[list["Knowledge"]] + + +class KnowledgeUpdate(BaseModel): + file_name: Optional[str] = None + status: Optional[KnowledgeStatus] = None + url: Optional[str] = None + file_sha1: Optional[str] = None + extension: Optional[str] = None + parent_id: Optional[UUID] = None + source: Optional[str] = None + source_link: Optional[str] = None + metadata: Optional[Dict[str, str]] = None class KnowledgeDB(AsyncAttrs, SQLModel, table=True): @@ -49,13 +74,6 @@ class KnowledgeDB(AsyncAttrs, SQLModel, table=True): file_sha1: Optional[str] = Field( max_length=40 ) # FIXME: Should not be optional @chloedia - updated_at: datetime | None = Field( - default=None, - sa_column=Column( - TIMESTAMP(timezone=False), - server_default=text("CURRENT_TIMESTAMP"), - ), - ) created_at: datetime | None = Field( default=None, sa_column=Column( @@ -63,9 +81,18 @@ class KnowledgeDB(AsyncAttrs, SQLModel, table=True): server_default=text("CURRENT_TIMESTAMP"), ), ) + updated_at: datetime | None = Field( + default=None, + sa_column=Column( + TIMESTAMP(timezone=False), + server_default=text("CURRENT_TIMESTAMP"), + onupdate=datetime.utcnow, + ), + ) metadata_: Optional[Dict[str, str]] = Field( default=None, sa_column=Column("metadata", JSON) ) + is_folder: bool = 
Field(default=False) user_id: UUID = Field(foreign_key="users.id", nullable=False) brains: List["Brain"] = Relationship( back_populates="knowledges", @@ -73,10 +100,35 @@ class KnowledgeDB(AsyncAttrs, SQLModel, table=True): sa_relationship_kwargs={"lazy": "select"}, ) - async def to_dto(self) -> Knowledge: + parent_id: UUID | None = Field( + default=None, foreign_key="knowledge.id", ondelete="CASCADE" + ) + parent: Optional["KnowledgeDB"] = Relationship( + back_populates="children", + sa_relationship_kwargs={"remote_side": "KnowledgeDB.id"}, + ) + children: list["KnowledgeDB"] = Relationship( + back_populates="parent", + sa_relationship_kwargs={ + "cascade": "all, delete-orphan", + }, + ) + + # TODO: nested folder search + async def to_dto(self, get_children: bool = True) -> Knowledge: + assert ( + self.updated_at + ), "knowledge should be inserted before transforming to dto" + assert ( + self.created_at + ), "knowledge should be inserted before transforming to dto" brains = await self.awaitable_attrs.brains - size = self.file_size if self.file_size else 0 - sha1 = self.file_sha1 if self.file_sha1 else "" + children: list[KnowledgeDB] = ( + await self.awaitable_attrs.children if get_children else [] + ) + parent = await self.awaitable_attrs.parent + parent = await parent.to_dto(get_children=False) if parent else None + return Knowledge( id=self.id, # type: ignore file_name=self.file_name, @@ -85,10 +137,14 @@ class KnowledgeDB(AsyncAttrs, SQLModel, table=True): status=KnowledgeStatus(self.status), source=self.source, source_link=self.source_link, - file_size=size, - file_sha1=sha1, + is_folder=self.is_folder, + file_size=self.file_size or 0, + file_sha1=self.file_sha1, updated_at=self.updated_at, created_at=self.created_at, metadata=self.metadata_, # type: ignore - brain_ids=[brain.brain_id for brain in brains], + brains=[b.model_dump() for b in brains], + parent=parent, + children=[await c.to_dto(get_children=False) for c in children], + user_id=self.user_id, ) 
diff --git a/backend/api/quivr_api/modules/knowledge/repository/knowledges.py b/backend/api/quivr_api/modules/knowledge/repository/knowledges.py index 436e24061..427b3be06 100644 --- a/backend/api/quivr_api/modules/knowledge/repository/knowledges.py +++ b/backend/api/quivr_api/modules/knowledge/repository/knowledges.py @@ -1,9 +1,10 @@ -from typing import Sequence +from typing import Any, Sequence from uuid import UUID from fastapi import HTTPException from quivr_core.models import KnowledgeStatus from sqlalchemy.exc import IntegrityError, NoResultFound +from sqlalchemy.orm import joinedload from sqlmodel import select, text from sqlmodel.ext.asyncio.session import AsyncSession @@ -11,7 +12,15 @@ from quivr_api.logger import get_logger from quivr_api.modules.brain.entity.brain_entity import Brain from quivr_api.modules.dependencies import BaseRepository, get_supabase_client from quivr_api.modules.knowledge.dto.outputs import DeleteKnowledgeResponse -from quivr_api.modules.knowledge.entity.knowledge import KnowledgeDB +from quivr_api.modules.knowledge.entity.knowledge import ( + Knowledge, + KnowledgeDB, + KnowledgeUpdate, +) +from quivr_api.modules.knowledge.service.knowledge_exceptions import ( + KnowledgeNotFoundException, + KnowledgeUpdateError, +) logger = get_logger(__name__) @@ -22,7 +31,43 @@ class KnowledgeRepository(BaseRepository): supabase_client = get_supabase_client() self.db = supabase_client - async def insert_knowledge( + async def create_knowledge(self, knowledge: KnowledgeDB) -> KnowledgeDB: + try: + self.session.add(knowledge) + await self.session.commit() + await self.session.refresh(knowledge) + except IntegrityError: + await self.session.rollback() + raise + except Exception: + await self.session.rollback() + raise + return knowledge + + async def update_knowledge( + self, + knowledge: KnowledgeDB, + payload: Knowledge | KnowledgeUpdate | dict[str, Any], + ) -> KnowledgeDB: + try: + logger.debug(f"updating {knowledge.id} with payload 
{payload}") + if isinstance(payload, dict): + update_data = payload + else: + update_data = payload.model_dump(exclude_unset=True) + for field in update_data: + setattr(knowledge, field, update_data[field]) + + self.session.add(knowledge) + await self.session.commit() + await self.session.refresh(knowledge) + return knowledge + except IntegrityError as e: + await self.session.rollback() + logger.error(f"Error updating knowledge {e}") + raise KnowledgeUpdateError + + async def insert_knowledge_brain( self, knowledge: KnowledgeDB, brain_id: UUID ) -> KnowledgeDB: logger.debug(f"Inserting knowledge {knowledge}") @@ -69,6 +114,14 @@ class KnowledgeRepository(BaseRepository): await self.session.refresh(knowledge) return knowledge + async def remove_knowledge(self, knowledge: KnowledgeDB) -> DeleteKnowledgeResponse: + assert knowledge.id + await self.session.delete(knowledge) + await self.session.commit() + return DeleteKnowledgeResponse( + status="deleted", knowledge_id=knowledge.id, file_name=knowledge.file_name + ) + async def remove_knowledge_by_id( self, knowledge_id: UUID ) -> DeleteKnowledgeResponse: @@ -126,14 +179,70 @@ class KnowledgeRepository(BaseRepository): return knowledge - async def get_knowledge_by_id(self, knowledge_id: UUID) -> KnowledgeDB: - query = select(KnowledgeDB).where(KnowledgeDB.id == knowledge_id) + async def get_all_children(self, parent_id: UUID) -> list[KnowledgeDB]: + query = text(""" + WITH RECURSIVE knowledge_tree AS ( + SELECT * + FROM knowledge + WHERE parent_id = :parent_id + UNION ALL + SELECT k.* + FROM knowledge k + JOIN knowledge_tree kt ON k.parent_id = kt.id + ) + SELECT * FROM knowledge_tree + """) + + result = await self.session.execute(query, params={"parent_id": parent_id}) + rows = result.fetchall() + knowledge_list = [] + for row in rows: + knowledge = KnowledgeDB( + id=row.id, + parent_id=row.parent_id, + file_name=row.file_name, + url=row.url, + extension=row.extension, + status=row.status, + source=row.source, + 
source_link=row.source_link, + file_size=row.file_size, + file_sha1=row.file_sha1, + created_at=row.created_at, + updated_at=row.updated_at, + metadata_=row.metadata, + is_folder=row.is_folder, + user_id=row.user_id, + ) + knowledge_list.append(knowledge) + + return knowledge_list + + async def get_root_knowledge_user(self, user_id: UUID) -> list[KnowledgeDB]: + query = ( + select(KnowledgeDB) + .where(KnowledgeDB.parent_id.is_(None)) # type: ignore + .where(KnowledgeDB.user_id == user_id) + .options(joinedload(KnowledgeDB.parent), joinedload(KnowledgeDB.children)) # type: ignore + ) + result = await self.session.exec(query) + kms = result.unique().all() + return list(kms) + + async def get_knowledge_by_id( + self, knowledge_id: UUID, user_id: UUID | None = None + ) -> KnowledgeDB: + query = ( + select(KnowledgeDB) + .where(KnowledgeDB.id == knowledge_id) + .options(joinedload(KnowledgeDB.parent), joinedload(KnowledgeDB.children)) # type: ignore + ) + if user_id: + query = query.where(KnowledgeDB.user_id == user_id) result = await self.session.exec(query) knowledge = result.first() - if not knowledge: - raise NoResultFound("Knowledge not found") - + raise KnowledgeNotFoundException("Knowledge not found") return knowledge async def get_brain_by_id(self, brain_id: UUID) -> Brain: diff --git a/backend/api/quivr_api/modules/knowledge/repository/storage.py b/backend/api/quivr_api/modules/knowledge/repository/storage.py index 47120ba5b..0e58e25d9 100644 --- a/backend/api/quivr_api/modules/knowledge/repository/storage.py +++ b/backend/api/quivr_api/modules/knowledge/repository/storage.py @@ -1,29 +1,87 @@ +import mimetypes +from io import BufferedReader, FileIO + from quivr_api.logger import get_logger -from quivr_api.modules.dependencies import get_supabase_client +from quivr_api.modules.dependencies import get_supabase_async_client +from quivr_api.modules.knowledge.entity.knowledge import KnowledgeDB from quivr_api.modules.knowledge.repository.storage_interface import 
StorageInterface logger = get_logger(__name__) -class Storage(StorageInterface): +class SupabaseS3Storage(StorageInterface): def __init__(self): - supabase_client = get_supabase_client() - self.db = supabase_client + self.client = None - def upload_file(self, file_name: str): - """ - Upload file to storage - """ - self.db.storage.from_("quivr").download(file_name) + async def _set_client(self): + if self.client is None: + self.client = await get_supabase_async_client() - def remove_file(self, file_name: str): + def get_storage_path( + self, + knowledge: KnowledgeDB, + ) -> str: + if knowledge.id is None: + raise ValueError("knowledge should have a valid id") + return str(knowledge.id) + + async def upload_file_storage( + self, + knowledge: KnowledgeDB, + knowledge_data: FileIO | BufferedReader | bytes, + upsert: bool = False, + ): + await self._set_client() + assert self.client + + mime_type = "application/html" + if knowledge.file_name: + guessed_mime_type, _ = mimetypes.guess_type(knowledge.file_name) + mime_type = guessed_mime_type or mime_type + + storage_path = self.get_storage_path(knowledge) + logger.info( + f"Uploading file to s3://quivr/{storage_path} using supabase. 
upsert={upsert}, mimetype={mime_type}" + ) + + if upsert: + _ = await self.client.storage.from_("quivr").update( + storage_path, + knowledge_data, + file_options={ + "content-type": mime_type, + "upsert": "true", + "cache-control": "3600", + }, + ) + return storage_path + else: + # check if file sha1 is already in storage + try: + _ = await self.client.storage.from_("quivr").upload( + storage_path, + knowledge_data, + file_options={ + "content-type": mime_type, + "upsert": "false", + "cache-control": "3600", + }, + ) + return storage_path + + except Exception as e: + if "The resource already exists" in str(e) and not upsert: + raise FileExistsError(f"File {storage_path} already exists") + raise e + + async def remove_file(self, storage_path: str): """ Remove file from storage """ + await self._set_client() + assert self.client try: - response = self.db.storage.from_("quivr").remove([file_name]) + response = await self.client.storage.from_("quivr").remove([storage_path]) return response except Exception as e: logger.error(e) - # raise e - diff --git a/backend/api/quivr_api/modules/knowledge/repository/storage_interface.py b/backend/api/quivr_api/modules/knowledge/repository/storage_interface.py index 228c99827..bd5a3debc 100644 --- a/backend/api/quivr_api/modules/knowledge/repository/storage_interface.py +++ b/backend/api/quivr_api/modules/knowledge/repository/storage_interface.py @@ -1,10 +1,26 @@ from abc import ABC, abstractmethod +from io import BufferedReader, FileIO + +from quivr_api.modules.knowledge.entity.knowledge import KnowledgeDB class StorageInterface(ABC): @abstractmethod - def remove_file(self, file_name: str): - """ - Remove file from storage - """ + def get_storage_path( + self, + knowledge: KnowledgeDB, + ) -> str: + pass + + @abstractmethod + async def upload_file_storage( + self, + knowledge: KnowledgeDB, + knowledge_data: FileIO | BufferedReader | bytes, + upsert: bool = False, + ): + pass + + @abstractmethod + async def remove_file(self, 
storage_path: str): pass diff --git a/backend/api/quivr_api/modules/knowledge/service/knowledge_exceptions.py b/backend/api/quivr_api/modules/knowledge/service/knowledge_exceptions.py new file mode 100644 index 000000000..c95cefa45 --- /dev/null +++ b/backend/api/quivr_api/modules/knowledge/service/knowledge_exceptions.py @@ -0,0 +1,34 @@ +class KnowledgeException(Exception): + def __init__(self, message="A knowledge-related error occurred"): + self.message = message + super().__init__(self.message) + + +class UploadError(KnowledgeException): + def __init__(self, message="An error occurred while uploading"): + super().__init__(message) + + +class KnowledgeCreationError(KnowledgeException): + def __init__(self, message="An error occurred while creating"): + super().__init__(message) + + +class KnowledgeUpdateError(KnowledgeException): + def __init__(self, message="An error occurred while updating"): + super().__init__(message) + + +class KnowledgeDeleteError(KnowledgeException): + def __init__(self, message="An error occurred while deleting"): + super().__init__(message) + + +class KnowledgeForbiddenAccess(KnowledgeException): + def __init__(self, message="You do not have permission to access this knowledge."): + super().__init__(message) + + +class KnowledgeNotFoundException(KnowledgeException): + def __init__(self, message="The requested knowledge was not found"): + super().__init__(message) diff --git a/backend/api/quivr_api/modules/knowledge/service/knowledge_service.py b/backend/api/quivr_api/modules/knowledge/service/knowledge_service.py index 987bff7d1..cfc88884b 100644 --- a/backend/api/quivr_api/modules/knowledge/service/knowledge_service.py +++ b/backend/api/quivr_api/modules/knowledge/service/knowledge_service.py @@ -1,18 +1,33 @@ -from typing import List +import asyncio +import io +from typing import Any, List from uuid import UUID +from fastapi import UploadFile from quivr_core.models import KnowledgeStatus from sqlalchemy.exc import NoResultFound from 
quivr_api.logger import get_logger from quivr_api.modules.dependencies import BaseService from quivr_api.modules.knowledge.dto.inputs import ( + AddKnowledge, CreateKnowledgeProperties, ) from quivr_api.modules.knowledge.dto.outputs import DeleteKnowledgeResponse -from quivr_api.modules.knowledge.entity.knowledge import Knowledge, KnowledgeDB +from quivr_api.modules.knowledge.entity.knowledge import ( + Knowledge, + KnowledgeDB, + KnowledgeSource, + KnowledgeUpdate, +) from quivr_api.modules.knowledge.repository.knowledges import KnowledgeRepository -from quivr_api.modules.knowledge.repository.storage import Storage +from quivr_api.modules.knowledge.repository.storage import SupabaseS3Storage +from quivr_api.modules.knowledge.repository.storage_interface import StorageInterface +from quivr_api.modules.knowledge.service.knowledge_exceptions import ( + KnowledgeDeleteError, + KnowledgeForbiddenAccess, + UploadError, +) from quivr_api.modules.sync.entity.sync_models import ( DBSyncFile, DownloadedSyncFile, @@ -26,9 +41,13 @@ logger = get_logger(__name__) class KnowledgeService(BaseService[KnowledgeRepository]): repository_cls = KnowledgeRepository - def __init__(self, repository: KnowledgeRepository): + def __init__( + self, + repository: KnowledgeRepository, + storage: StorageInterface = SupabaseS3Storage(), + ): self.repository = repository - self.storage = Storage() + self.storage = storage async def get_knowledge_sync(self, sync_id: int) -> Knowledge: km = await self.repository.get_knowledge_by_sync_id(sync_id) @@ -54,19 +73,37 @@ class KnowledgeService(BaseService[KnowledgeRepository]): except NoResultFound: raise FileNotFoundError(f"No knowledge for file_name: {file_name}") - async def get_knowledge(self, knowledge_id: UUID) -> Knowledge: - inserted_knowledge_db_instance = await self.repository.get_knowledge_by_id( - knowledge_id - ) - assert inserted_knowledge_db_instance.id, "Knowledge ID not generated" - km = await inserted_knowledge_db_instance.to_dto() - 
return km + async def list_knowledge( + self, knowledge_id: UUID | None, user_id: UUID | None = None + ) -> list[KnowledgeDB]: + if knowledge_id is not None: + km = await self.repository.get_knowledge_by_id(knowledge_id, user_id) + return km.children + else: + if user_id is None: + raise KnowledgeForbiddenAccess( + "can't get root knowledges without user_id" + ) + return await self.repository.get_root_knowledge_user(user_id) + async def get_knowledge( + self, knowledge_id: UUID, user_id: UUID | None = None + ) -> KnowledgeDB: + return await self.repository.get_knowledge_by_id(knowledge_id, user_id) + + async def update_knowledge( + self, + knowledge: KnowledgeDB, + payload: Knowledge | KnowledgeUpdate | dict[str, Any], + ): + return await self.repository.update_knowledge(knowledge, payload) + + # TODO: Remove all of this # TODO (@aminediro): Replace with ON CONFLICT smarter query... # there is a chance of race condition but for now we let it crash in worker # the tasks will be dealt with on retry async def update_sha1_conflict( - self, knowledge: Knowledge, brain_id: UUID, file_sha1: str + self, knowledge: KnowledgeDB, brain_id: UUID, file_sha1: str ) -> bool: assert knowledge.id knowledge.file_sha1 = file_sha1 @@ -89,12 +126,12 @@ class KnowledgeService(BaseService[KnowledgeRepository]): ) else: await self.repository.link_to_brain(existing_knowledge, brain_id) - await self.remove_knowledge(brain_id, knowledge.id) + await self.remove_knowledge_brain(brain_id, knowledge.id) return False else: logger.debug(f"Removing previous errored file {existing_knowledge.id}") assert existing_knowledge.id - await self.remove_knowledge(brain_id, existing_knowledge.id) + await self.remove_knowledge_brain(brain_id, existing_knowledge.id) await self.update_file_sha1_knowledge(knowledge.id, knowledge.file_sha1) return True except NoResultFound: @@ -104,7 +141,47 @@ class KnowledgeService(BaseService[KnowledgeRepository]): await self.update_file_sha1_knowledge(knowledge.id, 
knowledge.file_sha1) return True - async def insert_knowledge( + async def create_knowledge( + self, + user_id: UUID, + knowledge_to_add: AddKnowledge, + upload_file: UploadFile | None = None, + ) -> KnowledgeDB: + knowledgedb = KnowledgeDB( + user_id=user_id, + file_name=knowledge_to_add.file_name, + is_folder=knowledge_to_add.is_folder, + url=knowledge_to_add.url, + extension=knowledge_to_add.extension, + source=knowledge_to_add.source, + source_link=knowledge_to_add.source_link, + file_size=upload_file.size if upload_file else 0, + metadata_=knowledge_to_add.metadata, # type: ignore + status=KnowledgeStatus.RESERVED, + parent_id=knowledge_to_add.parent_id, + ) + knowledge_db = await self.repository.create_knowledge(knowledgedb) + try: + if knowledgedb.source == KnowledgeSource.LOCAL and upload_file: + # NOTE(@aminediro): Unnecessary mem buffer because supabase doesnt accept FileIO.. + buff_reader = io.BufferedReader(upload_file.file) # type: ignore + storage_path = await self.storage.upload_file_storage( + knowledgedb, buff_reader + ) + knowledgedb.source_link = storage_path + knowledge_db = await self.repository.update_knowledge( + knowledge_db, + KnowledgeUpdate(status=KnowledgeStatus.UPLOADED), # type: ignore + ) + return knowledge_db + except Exception as e: + logger.exception( + f"Error uploading knowledge {knowledgedb.id} to storage : {e}" + ) + await self.repository.remove_knowledge(knowledge=knowledge_db) + raise UploadError() + + async def insert_knowledge_brain( self, user_id: UUID, knowledge_to_add: CreateKnowledgeProperties, # FIXME: (later) @Amine brain id should not be in CreateKnowledgeProperties but since storage is brain_id/file_name @@ -122,7 +199,7 @@ class KnowledgeService(BaseService[KnowledgeRepository]): user_id=user_id, ) - knowledge_db = await self.repository.insert_knowledge( + knowledge_db = await self.repository.insert_knowledge_brain( knowledge, brain_id=knowledge_to_add.brain_id ) @@ -150,7 +227,7 @@ class 
KnowledgeService(BaseService[KnowledgeRepository]): assert isinstance(knowledge.file_name, str), "file_name should be a string" file_name_with_brain_id = f"{brain_id}/{knowledge.file_name}" try: - self.storage.remove_file(file_name_with_brain_id) + await self.storage.remove_file(file_name_with_brain_id) except Exception as e: logger.error( f"Error while removing file {file_name_with_brain_id}: {e}" @@ -161,29 +238,52 @@ class KnowledgeService(BaseService[KnowledgeRepository]): async def update_file_sha1_knowledge(self, knowledge_id: UUID, file_sha1: str): return await self.repository.update_file_sha1_knowledge(knowledge_id, file_sha1) - async def remove_knowledge( + async def remove_knowledge(self, knowledge: KnowledgeDB) -> DeleteKnowledgeResponse: + assert knowledge.id + + try: + # TODO: + # - Notion folders are special, they are themselves files and should be removed from storage + children = await self.repository.get_all_children(knowledge.id) + km_paths = [ + self.storage.get_storage_path(k) for k in children if not k.is_folder + ] + if not knowledge.is_folder: + km_paths.append(self.storage.get_storage_path(knowledge)) + + # recursively deletes files + deleted_km = await self.repository.remove_knowledge(knowledge) + await asyncio.gather(*[self.storage.remove_file(p) for p in km_paths]) + + return deleted_km + except Exception as e: + logger.error(f"Error while remove knowledge : {e}") + raise KnowledgeDeleteError + + async def remove_knowledge_brain( self, brain_id: UUID, knowledge_id: UUID, # FIXME: @amine when name in storage change no need for brain id ) -> DeleteKnowledgeResponse: # TODO: fix KMS # REDO ALL THIS - knowledge = await self.get_knowledge(knowledge_id) - if len(knowledge.brain_ids) > 1: + knowledge = await self.repository.get_knowledge_by_id(knowledge_id) + km_brains = await knowledge.awaitable_attrs.brains + if len(km_brains) > 1: km = await self.repository.remove_knowledge_from_brain( knowledge_id, brain_id ) + assert km.id return 
DeleteKnowledgeResponse(file_name=km.file_name, knowledge_id=km.id) else: message = await self.repository.remove_knowledge_by_id(knowledge_id) file_name_with_brain_id = f"{brain_id}/{message.file_name}" try: - self.storage.remove_file(file_name_with_brain_id) + await self.storage.remove_file(file_name_with_brain_id) except Exception as e: logger.error( f"Error while removing file {file_name_with_brain_id}: {e}" ) - return message async def remove_all_knowledges_from_brain(self, brain_id: UUID) -> None: @@ -210,7 +310,7 @@ class KnowledgeService(BaseService[KnowledgeRepository]): # TODO: THIS IS A HACK!! Remove all of this if prev_sync_file: prev_knowledge = await self.get_knowledge_sync(sync_id=prev_sync_file.id) - if len(prev_knowledge.brain_ids) > 1: + if len(prev_knowledge.brains) > 1: await self.repository.remove_knowledge_from_brain( prev_knowledge.id, brain_id ) @@ -231,7 +331,7 @@ class KnowledgeService(BaseService[KnowledgeRepository]): file_sha1=None, metadata={"sync_file_id": str(sync_id)}, ) - added_knowledge = await self.insert_knowledge( + added_knowledge = await self.insert_knowledge_brain( knowledge_to_add=knowledge_to_add, user_id=user_id ) return added_knowledge diff --git a/backend/api/quivr_api/modules/knowledge/tests/conftest.py b/backend/api/quivr_api/modules/knowledge/tests/conftest.py new file mode 100644 index 000000000..2074110f6 --- /dev/null +++ b/backend/api/quivr_api/modules/knowledge/tests/conftest.py @@ -0,0 +1,67 @@ +from io import BufferedReader, FileIO + +from quivr_api.modules.knowledge.entity.knowledge import Knowledge, KnowledgeDB +from quivr_api.modules.knowledge.repository.storage_interface import StorageInterface + + +class ErrorStorage(StorageInterface): + async def upload_file_storage( + self, + knowledge: KnowledgeDB, + knowledge_data: FileIO | BufferedReader | bytes, + upsert: bool = False, + ): + raise SystemError + + def get_storage_path( + self, + knowledge: KnowledgeDB | Knowledge, + ) -> str: + if knowledge.id is 
None: + raise ValueError("knowledge should have a valid id") + return str(knowledge.id) + + async def remove_file(self, storage_path: str): + raise SystemError + + +class FakeStorage(StorageInterface): + def __init__(self): + self.storage = {} + + def get_storage_path( + self, + knowledge: KnowledgeDB | Knowledge, + ) -> str: + if knowledge.id is None: + raise ValueError("knowledge should have a valid id") + return str(knowledge.id) + + async def upload_file_storage( + self, + knowledge: KnowledgeDB, + knowledge_data: FileIO | BufferedReader | bytes, + upsert: bool = False, + ): + storage_path = f"{knowledge.id}" + if not upsert and storage_path in self.storage: + raise ValueError(f"File already exists at {storage_path}") + self.storage[storage_path] = knowledge_data + return storage_path + + async def remove_file(self, storage_path: str): + if storage_path not in self.storage: + raise FileNotFoundError(f"File not found at {storage_path}") + del self.storage[storage_path] + + # Additional helper methods for testing + def get_file(self, storage_path: str) -> FileIO | BufferedReader | bytes: + if storage_path not in self.storage: + raise FileNotFoundError(f"File not found at {storage_path}") + return self.storage[storage_path] + + def knowledge_exists(self, knowledge: KnowledgeDB | Knowledge) -> bool: + return self.get_storage_path(knowledge) in self.storage + + def clear_storage(self): + self.storage.clear() diff --git a/backend/api/quivr_api/modules/knowledge/tests/test_knowledge_controller.py b/backend/api/quivr_api/modules/knowledge/tests/test_knowledge_controller.py new file mode 100644 index 000000000..cf6313e97 --- /dev/null +++ b/backend/api/quivr_api/modules/knowledge/tests/test_knowledge_controller.py @@ -0,0 +1,74 @@ +import json + +import pytest +import pytest_asyncio +from httpx import ASGITransport, AsyncClient +from sqlmodel import select +from sqlmodel.ext.asyncio.session import AsyncSession + +from quivr_api.main import app +from 
quivr_api.middlewares.auth.auth_bearer import get_current_user +from quivr_api.modules.knowledge.controller.knowledge_routes import get_km_service +from quivr_api.modules.knowledge.repository.knowledges import KnowledgeRepository +from quivr_api.modules.knowledge.service.knowledge_service import KnowledgeService +from quivr_api.modules.knowledge.tests.conftest import FakeStorage +from quivr_api.modules.user.entity.user_identity import User, UserIdentity + + +@pytest_asyncio.fixture(scope="function") +async def user(session: AsyncSession) -> User: + user_1 = ( + await session.exec(select(User).where(User.email == "admin@quivr.app")) + ).one() + assert user_1.id + return user_1 + + +@pytest_asyncio.fixture(scope="function") +async def test_client(session: AsyncSession, user: User): + def default_current_user() -> UserIdentity: + assert user.id + return UserIdentity(email=user.email, id=user.id) + + async def test_service(): + storage = FakeStorage() + repository = KnowledgeRepository(session) + return KnowledgeService(repository, storage) + + app.dependency_overrides[get_current_user] = default_current_user + app.dependency_overrides[get_km_service] = test_service + # app.dependency_overrides[get_async_session] = lambda: session + + async with AsyncClient( + transport=ASGITransport(app=app), base_url="http://test" + ) as ac: + yield ac + app.dependency_overrides = {} + + +@pytest.mark.asyncio(loop_scope="session") +async def test_post_knowledge(test_client: AsyncClient): + km_data = { + "file_name": "test_file.txt", + "source": "local", + "is_folder": False, + "parent_id": None, + } + + multipart_data = { + "knowledge_data": (None, json.dumps(km_data), "application/json"), + "file": ("test_file.txt", b"Test file content", "application/octet-stream"), + } + + response = await test_client.post( + "/knowledge/", + files=multipart_data, + ) + + assert response.status_code == 200 + + +@pytest.mark.asyncio(loop_scope="session") +async def 
test_add_knowledge_invalid_input(test_client): + response = await test_client.post("/knowledge/", files={}) + assert response.status_code == 422 diff --git a/backend/api/quivr_api/modules/knowledge/tests/test_knowledge_entity.py b/backend/api/quivr_api/modules/knowledge/tests/test_knowledge_entity.py new file mode 100644 index 000000000..7376559eb --- /dev/null +++ b/backend/api/quivr_api/modules/knowledge/tests/test_knowledge_entity.py @@ -0,0 +1,229 @@ +from typing import List, Tuple +from uuid import uuid4 + +import pytest +import pytest_asyncio +from quivr_core.models import KnowledgeStatus +from sqlmodel import select, text +from sqlmodel.ext.asyncio.session import AsyncSession + +from quivr_api.modules.brain.entity.brain_entity import Brain, BrainType +from quivr_api.modules.knowledge.entity.knowledge import KnowledgeDB +from quivr_api.modules.user.entity.user_identity import User + +TestData = Tuple[Brain, List[KnowledgeDB]] + + +@pytest_asyncio.fixture(scope="function") +async def other_user(session: AsyncSession): + sql = text( + """ + INSERT INTO "auth"."users" ("instance_id", "id", "aud", "role", "email", "encrypted_password", "email_confirmed_at", "invited_at", "confirmation_token", "confirmation_sent_at", "recovery_token", "recovery_sent_at", "email_change_token_new", "email_change", "email_change_sent_at", "last_sign_in_at", "raw_app_meta_data", "raw_user_meta_data", "is_super_admin", "created_at", "updated_at", "phone", "phone_confirmed_at", "phone_change", "phone_change_token", "phone_change_sent_at", "email_change_token_current", "email_change_confirm_status", "banned_until", "reauthentication_token", "reauthentication_sent_at", "is_sso_user", "deleted_at") VALUES + ('00000000-0000-0000-0000-000000000000', :id , 'authenticated', 'authenticated', 'other@quivr.app', '$2a$10$vwKX0eMLlrOZvxQEA3Vl4e5V4/hOuxPjGYn9QK1yqeaZxa.42Uhze', '2024-01-22 22:27:00.166861+00', NULL, '', NULL, 'e91d41043ca2c83c3be5a6ee7a4abc8a4f4fb1afc0a8453c502af931', '2024-03-05 
16:22:13.780421+00', '', '', NULL, '2024-03-30 23:21:12.077887+00', '{"provider": "email", "providers": ["email"]}', '{}', NULL, '2024-01-22 22:27:00.158026+00', '2024-04-01 17:40:15.332205+00', NULL, NULL, '', '', NULL, '', 0, NULL, '', NULL, false, NULL); + """ + ) + await session.execute(sql, params={"id": uuid4()}) + + other_user = ( + await session.exec(select(User).where(User.email == "other@quivr.app")) + ).one() + return other_user + + +@pytest_asyncio.fixture(scope="function") +async def user(session): + user_1 = ( + await session.exec(select(User).where(User.email == "admin@quivr.app")) + ).one() + return user_1 + + +@pytest_asyncio.fixture(scope="function") +async def brain(session): + brain_1 = Brain( + name="test_brain", + description="this is a test brain", + brain_type=BrainType.integration, + ) + session.add(brain_1) + await session.commit() + return brain_1 + + +@pytest_asyncio.fixture(scope="function") +async def folder(session, user): + folder = KnowledgeDB( + file_name="folder_1", + extension="", + status="UPLOADED", + source="local", + source_link="local", + file_size=4, + file_sha1=None, + brains=[], + children=[], + user_id=user.id, + is_folder=True, + ) + + session.add(folder) + await session.commit() + await session.refresh(folder) + return folder + + +@pytest.mark.asyncio(loop_scope="session") +async def test_knowledge_default_file(session, folder, user): + km = KnowledgeDB( + file_name="test_file_1.txt", + extension=".txt", + status="UPLOADED", + source="test_source", + source_link="test_source_link", + file_size=100, + file_sha1="test_sha1", + brains=[], + user_id=user.id, + parent_id=folder.id, + ) + session.add(km) + await session.commit() + await session.refresh(km) + + assert not km.is_folder + + +@pytest.mark.asyncio(loop_scope="session") +async def test_knowledge_parent(session: AsyncSession, user: User): + assert user.id + + km = KnowledgeDB( + file_name="test_file_1.txt", + extension=".txt", + status="UPLOADED", + 
source="test_source", + source_link="test_source_link", + file_size=100, + file_sha1="test_sha1", + brains=[], + user_id=user.id, + ) + + folder = KnowledgeDB( + file_name="folder_1", + extension="", + is_folder=True, + status="UPLOADED", + source="local", + source_link="local", + file_size=-1, + file_sha1=None, + brains=[], + children=[km], + user_id=user.id, + ) + + session.add(folder) + await session.commit() + await session.refresh(folder) + await session.refresh(km) + + parent = await km.awaitable_attrs.parent + assert km.parent_id == folder.id, "parent_id isn't set to folder id" + assert parent.id == folder.id, "parent_id isn't set to folder id" + assert parent.is_folder + + query = select(KnowledgeDB).where(KnowledgeDB.id == folder.id) + folder = (await session.exec(query)).first() + assert folder + + children = await folder.awaitable_attrs.children + assert len(children) > 0 + + assert children[0].id == km.id + + +@pytest.mark.asyncio(loop_scope="session") +async def test_knowledge_remove_folder_cascade( + session: AsyncSession, + folder: KnowledgeDB, + user, +): + km = KnowledgeDB( + file_name="test_file_1.txt", + extension=".txt", + status="UPLOADED", + source="test_source", + source_link="test_source_link", + file_size=100, + file_sha1="test_sha1", + brains=[], + user_id=user.id, + parent_id=folder.id, + ) + session.add(km) + await session.commit() + await session.refresh(km) + + # Check all removed + await session.delete(folder) + await session.commit() + + statement = select(KnowledgeDB) + results = (await session.exec(statement)).all() + assert results == [] + + +@pytest.mark.asyncio(loop_scope="session") +async def test_knowledge_dto(session, user, brain): + # add folder in brain + folder = KnowledgeDB( + file_name="folder_1", + extension="", + status="UPLOADED", + source="local", + source_link="local", + file_size=4, + file_sha1=None, + brains=[brain], + children=[], + user_id=user.id, + is_folder=True, + ) + km = KnowledgeDB( + 
file_name="test_file_1.txt", + extension=".txt", + status="UPLOADED", + source="test_source", + source_link="test_source_link", + file_size=100, + file_sha1="test_sha1", + user_id=user.id, + brains=[brain], + parent=folder, + ) + session.add(km) + session.add(km) + await session.commit() + await session.refresh(km) + + km_dto = await km.to_dto() + + assert km_dto.file_name == km.file_name + assert km_dto.url == km.url + assert km_dto.extension == km.extension + assert km_dto.status == KnowledgeStatus(km.status) + assert km_dto.source == km.source + assert km_dto.source_link == km.source_link + assert km_dto.is_folder == km.is_folder + assert km_dto.file_size == km.file_size + assert km_dto.file_sha1 == km.file_sha1 + assert km_dto.updated_at == km.updated_at + assert km_dto.created_at == km.created_at + assert km_dto.metadata == km.metadata_ # type: ignore + assert km_dto.parent + assert km_dto.parent.id == folder.id + + folder_dto = await folder.to_dto() + assert folder_dto.brains[0] == brain.model_dump() + assert folder_dto.children == [await km.to_dto()] diff --git a/backend/api/quivr_api/modules/knowledge/tests/test_knowledge_service.py b/backend/api/quivr_api/modules/knowledge/tests/test_knowledge_service.py new file mode 100644 index 000000000..169b9bef2 --- /dev/null +++ b/backend/api/quivr_api/modules/knowledge/tests/test_knowledge_service.py @@ -0,0 +1,1019 @@ +import os +from io import BytesIO +from typing import List, Tuple +from uuid import uuid4 + +import pytest +import pytest_asyncio +from fastapi import UploadFile +from sqlalchemy.exc import NoResultFound +from sqlmodel import select, text +from sqlmodel.ext.asyncio.session import AsyncSession + +from quivr_api.modules.brain.entity.brain_entity import Brain, BrainType +from quivr_api.modules.knowledge.dto.inputs import AddKnowledge, KnowledgeStatus +from quivr_api.modules.knowledge.entity.knowledge import KnowledgeDB, KnowledgeUpdate +from quivr_api.modules.knowledge.entity.knowledge_brain import
KnowledgeBrain +from quivr_api.modules.knowledge.repository.knowledges import KnowledgeRepository +from quivr_api.modules.knowledge.service.knowledge_exceptions import ( + KnowledgeNotFoundException, + KnowledgeUpdateError, + UploadError, +) +from quivr_api.modules.knowledge.service.knowledge_service import KnowledgeService +from quivr_api.modules.knowledge.tests.conftest import ErrorStorage, FakeStorage +from quivr_api.modules.upload.service.upload_file import upload_file_storage +from quivr_api.modules.user.entity.user_identity import User +from quivr_api.modules.vector.entity.vector import Vector + +TestData = Tuple[Brain, List[KnowledgeDB]] + + +@pytest_asyncio.fixture(scope="function") +async def other_user(session: AsyncSession): + sql = text( + """ + INSERT INTO "auth"."users" ("instance_id", "id", "aud", "role", "email", "encrypted_password", "email_confirmed_at", "invited_at", "confirmation_token", "confirmation_sent_at", "recovery_token", "recovery_sent_at", "email_change_token_new", "email_change", "email_change_sent_at", "last_sign_in_at", "raw_app_meta_data", "raw_user_meta_data", "is_super_admin", "created_at", "updated_at", "phone", "phone_confirmed_at", "phone_change", "phone_change_token", "phone_change_sent_at", "email_change_token_current", "email_change_confirm_status", "banned_until", "reauthentication_token", "reauthentication_sent_at", "is_sso_user", "deleted_at") VALUES + ('00000000-0000-0000-0000-000000000000', :id , 'authenticated', 'authenticated', 'other@quivr.app', '$2a$10$vwKX0eMLlrOZvxQEA3Vl4e5V4/hOuxPjGYn9QK1yqeaZxa.42Uhze', '2024-01-22 22:27:00.166861+00', NULL, '', NULL, 'e91d41043ca2c83c3be5a6ee7a4abc8a4f4fb1afc0a8453c502af931', '2024-03-05 16:22:13.780421+00', '', '', NULL, '2024-03-30 23:21:12.077887+00', '{"provider": "email", "providers": ["email"]}', '{}', NULL, '2024-01-22 22:27:00.158026+00', '2024-04-01 17:40:15.332205+00', NULL, NULL, '', '', NULL, '', 0, NULL, '', NULL, false, NULL); + """ + ) + await 
session.execute(sql, params={"id": uuid4()}) + + other_user = ( + await session.exec(select(User).where(User.email == "other@quivr.app")) + ).one() + return other_user + + +@pytest_asyncio.fixture(scope="function") +async def user(session: AsyncSession) -> User: + user_1 = ( + await session.exec(select(User).where(User.email == "admin@quivr.app")) + ).one() + assert user_1.id + return user_1 + + +@pytest_asyncio.fixture(scope="function") +async def test_data(session: AsyncSession) -> TestData: + user_1 = ( + await session.exec(select(User).where(User.email == "admin@quivr.app")) + ).one() + assert user_1.id + # Brain data + brain_1 = Brain( + name="test_brain", + description="this is a test brain", + brain_type=BrainType.integration, + ) + + knowledge_brain_1 = KnowledgeDB( + file_name="test_file_1.txt", + extension=".txt", + status="UPLOADED", + source="test_source", + source_link="test_source_link", + file_size=100, + file_sha1="test_sha1", + brains=[brain_1], + user_id=user_1.id, + ) + + knowledge_brain_2 = KnowledgeDB( + file_name="test_file_2.txt", + extension=".txt", + status="UPLOADED", + source="test_source", + source_link="test_source_link", + file_size=100, + file_sha1="test_sha2", + brains=[], + user_id=user_1.id, + ) + + session.add(brain_1) + session.add(knowledge_brain_1) + session.add(knowledge_brain_2) + await session.commit() + return brain_1, [knowledge_brain_1, knowledge_brain_2] + + +@pytest_asyncio.fixture(scope="function") +async def folder_km_nested(session: AsyncSession, user: User): + assert user.id + + nested_folder = KnowledgeDB( + file_name="folder_1", + extension="", + status="UPLOADED", + source="local", + source_link="local", + file_size=4, + file_sha1=None, + brains=[], + children=[], + user_id=user.id, + is_folder=True, + ) + folder = KnowledgeDB( + file_name="folder_2", + extension="", + status="UPLOADED", + source="local", + source_link="local", + file_size=4, + file_sha1=None, + brains=[], + children=[], + user_id=user.id, + 
is_folder=True, + parent=nested_folder, + ) + + knowledge_folder = KnowledgeDB( + file_name="file.txt", + extension=".txt", + status="UPLOADED", + source="test_source", + source_link="test_source_link", + file_size=100, + file_sha1="test_sha2", + brains=[], + user_id=user.id, + parent=folder, + ) + + session.add(nested_folder) + session.add(folder) + session.add(knowledge_folder) + await session.commit() + await session.refresh(folder) + return nested_folder + + +@pytest_asyncio.fixture(scope="function") +async def folder_km(session: AsyncSession, user: User): + assert user.id + folder = KnowledgeDB( + file_name="folder_1", + extension="", + status="UPLOADED", + source="local", + source_link="local", + file_size=4, + file_sha1=None, + brains=[], + children=[], + user_id=user.id, + is_folder=True, + ) + + knowledge_folder = KnowledgeDB( + file_name="file.txt", + extension=".txt", + status="UPLOADED", + source="test_source", + source_link="test_source_link", + file_size=100, + file_sha1="test_sha2", + brains=[], + user_id=user.id, + parent=folder, + ) + + session.add(folder) + session.add(knowledge_folder) + await session.commit() + await session.refresh(folder) + return folder + + +@pytest.mark.asyncio(loop_scope="session") +async def test_updates_knowledge_status(session: AsyncSession, test_data: TestData): + brain, knowledges = test_data + assert brain.brain_id + assert knowledges[0].id + repo = KnowledgeRepository(session) + await repo.update_status_knowledge(knowledges[0].id, KnowledgeStatus.ERROR) + knowledge = await repo.get_knowledge_by_id(knowledges[0].id) + assert knowledge.status == KnowledgeStatus.ERROR + + +@pytest.mark.asyncio(loop_scope="session") +async def test_updates_knowledge_status_no_knowledge( + session: AsyncSession, test_data: TestData +): + brain, knowledges = test_data + assert brain.brain_id + assert knowledges[0].id + repo = KnowledgeRepository(session) + with pytest.raises(NoResultFound): + await repo.update_status_knowledge(uuid4(), 
KnowledgeStatus.UPLOADED) + + +@pytest.mark.asyncio(loop_scope="session") +async def test_update_knowledge_source_link(session: AsyncSession, test_data: TestData): + brain, knowledges = test_data + assert brain.brain_id + assert knowledges[0].id + repo = KnowledgeRepository(session) + await repo.update_source_link_knowledge(knowledges[0].id, "new_source_link") + knowledge = await repo.get_knowledge_by_id(knowledges[0].id) + assert knowledge.source_link == "new_source_link" + + +@pytest.mark.asyncio(loop_scope="session") +async def test_remove_knowledge_from_brain(session: AsyncSession, test_data: TestData): + brain, knowledges = test_data + assert brain.brain_id + assert knowledges[0].id + repo = KnowledgeRepository(session) + knowledge = await repo.remove_knowledge_from_brain(knowledges[0].id, brain.brain_id) + assert brain.brain_id not in [ + b.brain_id for b in await knowledge.awaitable_attrs.brains + ] + + +@pytest.mark.asyncio(loop_scope="session") +async def test_cascade_remove_knowledge_by_id( + session: AsyncSession, test_data: TestData +): + brain, knowledges = test_data + assert brain.brain_id + assert knowledges[0].id + repo = KnowledgeRepository(session) + await repo.remove_knowledge_by_id(knowledges[0].id) + with pytest.raises(KnowledgeNotFoundException): + await repo.get_knowledge_by_id(knowledges[0].id) + + query = select(KnowledgeBrain).where( + KnowledgeBrain.knowledge_id == knowledges[0].id + ) + result = await session.exec(query) + knowledge_brain = result.first() + assert knowledge_brain is None + + query = select(Vector).where(Vector.knowledge_id == knowledges[0].id) + result = await session.exec(query) + vector = result.first() + assert vector is None + + +@pytest.mark.asyncio(loop_scope="session") +async def test_remove_all_knowledges_from_brain( + session: AsyncSession, test_data: TestData +): + brain, knowledges = test_data + assert brain.brain_id + + # supabase_client = get_supabase_client() + # db = supabase_client + # storage = 
db.storage.from_("quivr") + + # storage.upload(f"{brain.brain_id}/test_file_1", b"test_content") + + repo = KnowledgeRepository(session) + service = KnowledgeService(repo) + await repo.remove_all_knowledges_from_brain(brain.brain_id) + knowledges = await service.get_all_knowledge_in_brain(brain.brain_id) + assert len(knowledges) == 0 + + # response = storage.list(path=f"{brain.brain_id}") + # assert response == [] + # FIXME @aminediro &chloedia raise an error when trying to interact with storage UnboundLocalError: cannot access local variable 'response' where it is not associated with a value + + +@pytest.mark.asyncio(loop_scope="session") +async def test_duplicate_sha1_knowledge_same_user( + session: AsyncSession, test_data: TestData +): + brain, [existing_knowledge, _] = test_data + assert brain.brain_id + assert existing_knowledge.id + assert existing_knowledge.file_sha1 + repo = KnowledgeRepository(session) + knowledge = KnowledgeDB( + file_name="test_file_2", + extension="txt", + status="UPLOADED", + source="test_source", + source_link="test_source_link", + file_size=100, + file_sha1=existing_knowledge.file_sha1, + brains=[brain], + user_id=existing_knowledge.user_id, + ) + + await repo.insert_knowledge_brain(knowledge, brain.brain_id) + + +@pytest.mark.asyncio(loop_scope="session") +async def test_duplicate_sha1_knowledge_diff_user( + session: AsyncSession, test_data: TestData, other_user: User +): + brain, knowledges = test_data + assert other_user.id + assert brain.brain_id + assert knowledges[0].id + repo = KnowledgeRepository(session) + knowledge = KnowledgeDB( + file_name="test_file_2", + extension="txt", + status="UPLOADED", + source="test_source", + source_link="test_source_link", + file_size=100, + file_sha1=knowledges[0].file_sha1, + brains=[brain], + user_id=other_user.id, # random user id + ) + + result = await repo.insert_knowledge_brain(knowledge, brain.brain_id) + assert result + + +@pytest.mark.asyncio(loop_scope="session") +async def 
test_add_knowledge_to_brain(session: AsyncSession, test_data: TestData): + brain, knowledges = test_data + assert brain.brain_id + assert knowledges[1].id + repo = KnowledgeRepository(session) + await repo.link_to_brain(knowledges[1], brain.brain_id) + knowledge = await repo.get_knowledge_by_id(knowledges[1].id) + brains_of_knowledge = [b.brain_id for b in await knowledge.awaitable_attrs.brains] + assert brain.brain_id in brains_of_knowledge + + query = select(KnowledgeBrain).where( + KnowledgeBrain.knowledge_id == knowledges[0].id + and KnowledgeBrain.brain_id == brain.brain_id + ) + result = await session.exec(query) + knowledge_brain = result.first() + assert knowledge_brain + + +# Knowledge Service +@pytest.mark.asyncio(loop_scope="session") +async def test_get_knowledge_in_brain(session: AsyncSession, test_data: TestData): + brain, knowledges = test_data + assert brain.brain_id + repo = KnowledgeRepository(session) + service = KnowledgeService(repo) + list_knowledge = await service.get_all_knowledge_in_brain(brain.brain_id) + assert len(list_knowledge) == 1 + brains_of_knowledge = [ + b.brain_id for b in await knowledges[0].awaitable_attrs.brains + ] + assert list_knowledge[0].id == knowledges[0].id + assert list_knowledge[0].file_name == knowledges[0].file_name + assert brain.brain_id in brains_of_knowledge + + +@pytest.mark.asyncio(loop_scope="session") +async def test_should_process_knowledge_exists( + session: AsyncSession, test_data: TestData +): + brain, [existing_knowledge, _] = test_data + assert brain.brain_id + new = KnowledgeDB( + file_name="new", + extension="txt", + status="PROCESSING", + source="test_source", + source_link="test_source_link", + file_size=100, + file_sha1=None, + brains=[brain], + user_id=existing_knowledge.user_id, + ) + session.add(new) + await session.commit() + await session.refresh(new) + repo = KnowledgeRepository(session) + service = KnowledgeService(repo) + assert existing_knowledge.file_sha1 + with 
pytest.raises(FileExistsError): + await service.update_sha1_conflict( + new, brain.brain_id, file_sha1=existing_knowledge.file_sha1 + ) + + +@pytest.mark.asyncio(loop_scope="session") +async def test_should_process_knowledge_link_brain( + session: AsyncSession, test_data: TestData +): + repo = KnowledgeRepository(session) + service = KnowledgeService(repo) + brain, [existing_knowledge, _] = test_data + user_id = existing_knowledge.user_id + assert brain.brain_id + prev = KnowledgeDB( + file_name="prev", + extension=".txt", + status=KnowledgeStatus.UPLOADED, + source="test_source", + source_link="test_source_link", + file_size=100, + file_sha1="test1", + brains=[brain], + user_id=user_id, + ) + brain_2 = Brain( + name="test_brain", + description="this is a test brain", + brain_type=BrainType.integration, + ) + session.add(brain_2) + session.add(prev) + await session.commit() + await session.refresh(prev) + await session.refresh(brain_2) + + assert prev.id + assert brain_2.brain_id + + new = KnowledgeDB( + file_name="new", + extension="txt", + status="PROCESSING", + source="test_source", + source_link="test_source_link", + file_size=100, + file_sha1=None, + brains=[brain_2], + user_id=user_id, + ) + session.add(new) + await session.commit() + await session.refresh(new) + + incoming_knowledge = await new.to_dto() + assert prev.file_sha1 + + should_process = await service.update_sha1_conflict( + incoming_knowledge, brain_2.brain_id, file_sha1=prev.file_sha1 + ) + assert not should_process + + # Check prev knowledge was linked + assert incoming_knowledge.file_sha1 + prev_knowledge = await service.repository.get_knowledge_by_id(prev.id) + prev_brains = await prev_knowledge.awaitable_attrs.brains + assert {b.brain_id for b in prev_brains} == { + brain.brain_id, + brain_2.brain_id, + } + # Check new knowledge was removed + assert new.id + with pytest.raises(KnowledgeNotFoundException): + await service.repository.get_knowledge_by_id(new.id) + + 
+@pytest.mark.asyncio(loop_scope="session") +async def test_should_process_knowledge_prev_error( + session: AsyncSession, test_data: TestData +): + repo = KnowledgeRepository(session) + service = KnowledgeService(repo) + brain, [existing_knowledge, _] = test_data + user_id = existing_knowledge.user_id + assert brain.brain_id + prev = KnowledgeDB( + file_name="prev", + extension="txt", + status=KnowledgeStatus.ERROR, + source="test_source", + source_link="test_source_link", + file_size=100, + file_sha1="test1", + brains=[brain], + user_id=user_id, + ) + session.add(prev) + await session.commit() + await session.refresh(prev) + + assert prev.id + + new = KnowledgeDB( + file_name="new", + extension="txt", + status="PROCESSING", + source="test_source", + source_link="test_source_link", + file_size=100, + file_sha1=None, + brains=[brain], + user_id=user_id, + ) + session.add(new) + await session.commit() + await session.refresh(new) + + incoming_knowledge = await new.to_dto() + assert prev.file_sha1 + should_process = await service.update_sha1_conflict( + incoming_knowledge, brain.brain_id, file_sha1=prev.file_sha1 + ) + + # Checks we should process this file + assert should_process + # Previous errored file is cleaned up + with pytest.raises(KnowledgeNotFoundException): + await service.repository.get_knowledge_by_id(prev.id) + + assert new.id + new = await service.repository.get_knowledge_by_id(new.id) + assert new.file_sha1 + + +@pytest.mark.asyncio(loop_scope="session") +async def test_get_knowledge_storage_path(session: AsyncSession, test_data: TestData): + _, [knowledge, _] = test_data + assert knowledge.file_name + repository = KnowledgeRepository(session) + service = KnowledgeService(repository) + brain_2 = Brain( + name="test_brain", + description="this is a test brain", + brain_type=BrainType.integration, + ) + session.add(brain_2) + await session.commit() + await session.refresh(brain_2) + assert brain_2.brain_id + km_data = os.urandom(128) + km_path = 
f"{str(knowledge.brains[0].brain_id)}/{knowledge.file_name}" + await upload_file_storage(km_data, km_path) + # Link knowledge to two brains + await repository.link_to_brain(knowledge, brain_2.brain_id) + storage_path = await service.get_knowledge_storage_path( + knowledge.file_name, brain_2.brain_id + ) + assert storage_path == km_path + + +@pytest.mark.asyncio(loop_scope="session") +async def test_create_knowledge_file(session: AsyncSession, user: User): + assert user.id + storage = FakeStorage() + repository = KnowledgeRepository(session) + service = KnowledgeService(repository, storage) + + km_to_add = AddKnowledge( + file_name="test", + source="local", + is_folder=False, + parent_id=None, + ) + km_data = BytesIO(os.urandom(128)) + + km = await service.create_knowledge( + user_id=user.id, + knowledge_to_add=km_to_add, + upload_file=UploadFile(file=km_data, size=128, filename=km_to_add.file_name), + ) + + assert km.file_name == km_to_add.file_name + assert km.id + assert km.status == KnowledgeStatus.UPLOADED + assert not km.is_folder + # km in storage + storage.knowledge_exists(km) + + +@pytest.mark.asyncio(loop_scope="session") +async def test_create_knowledge_folder(session: AsyncSession, user: User): + assert user.id + storage = FakeStorage() + repository = KnowledgeRepository(session) + service = KnowledgeService(repository, storage) + + km_to_add = AddKnowledge( + file_name="test", + source="local", + is_folder=True, + parent_id=None, + ) + km_data = BytesIO(os.urandom(128)) + + km = await service.create_knowledge( + user_id=user.id, + knowledge_to_add=km_to_add, + upload_file=UploadFile(file=km_data, size=128, filename=km_to_add.file_name), + ) + + assert km.file_name == km_to_add.file_name + assert km.id + # Knowledge properties + assert km.file_name == km_to_add.file_name + assert km.is_folder == km_to_add.is_folder + assert km.url == km_to_add.url + assert km.extension == km_to_add.extension + assert km.source == km_to_add.source + assert km.file_size == 
128 + assert km.metadata_ == km_to_add.metadata + assert km.is_folder == km_to_add.is_folder + assert km.status == KnowledgeStatus.UPLOADED + # Knowledge was saved + assert storage.knowledge_exists(km) + + +@pytest.mark.asyncio(loop_scope="session") +async def test_create_knowledge_upload_error(session: AsyncSession, user: User): + assert user.id + storage = ErrorStorage() + repository = KnowledgeRepository(session) + service = KnowledgeService(repository, storage) + + km_to_add = AddKnowledge( + file_name="test", + source="local", + is_folder=True, + parent_id=None, + ) + km_data = BytesIO(os.urandom(128)) + + with pytest.raises(UploadError): + await service.create_knowledge( + user_id=user.id, + knowledge_to_add=km_to_add, + upload_file=UploadFile( + file=km_data, size=128, filename=km_to_add.file_name + ), + ) + # Check removed knowledge + statement = select(KnowledgeDB) + results = (await session.exec(statement)).all() + assert results == [] + + +@pytest.mark.asyncio(loop_scope="session") +async def test_get_knowledge(session: AsyncSession, folder_km: KnowledgeDB, user: User): + assert user.id + assert folder_km.id + storage = ErrorStorage() + repository = KnowledgeRepository(session) + service = KnowledgeService(repository, storage) + + result = await service.get_knowledge(folder_km.id) + assert result.id == folder_km.id + assert result.children + assert len(result.children) > 0 + assert result.children[0] == folder_km.children[0] + + +@pytest.mark.asyncio(loop_scope="session") +async def test_get_knowledge_nested( + session: AsyncSession, folder_km_nested: KnowledgeDB, user: User +): + assert user.id + assert folder_km_nested.id + storage = ErrorStorage() + repository = KnowledgeRepository(session) + service = KnowledgeService(repository, storage) + + result = await service.get_knowledge(folder_km_nested.id) + assert result.id == folder_km_nested.id + assert result.children + assert len(result.children) > 0 + assert result.children[0].is_folder + assert 
result.children[0] == folder_km_nested.children[0] + + +@pytest.mark.asyncio(loop_scope="session") +async def test_update_knowledge_rename( + session: AsyncSession, folder_km: KnowledgeDB, user: User +): + assert user.id + assert folder_km.id + storage = ErrorStorage() + repository = KnowledgeRepository(session) + service = KnowledgeService(repository, storage) + + new_km = await service.update_knowledge( + folder_km, + KnowledgeUpdate(file_name="change_name"), # type: ignore + ) + assert new_km.file_name == "change_name" + + +@pytest.mark.asyncio(loop_scope="session") +async def test_update_knowledge_move( + session: AsyncSession, folder_km: KnowledgeDB, user: User +): + assert user.id + assert folder_km.id + folder_2 = KnowledgeDB( + file_name="folder_2", + extension="", + status="UPLOADED", + source="local", + source_link="local", + file_size=4, + file_sha1=None, + brains=[], + children=[], + user_id=user.id, + is_folder=True, + ) + session.add(folder_2) + await session.commit() + await session.refresh(folder_2) + + storage = FakeStorage() + repository = KnowledgeRepository(session) + service = KnowledgeService(repository, storage) + + new_km = await service.update_knowledge( + folder_km, + KnowledgeUpdate(parent_id=folder_2.id), # type: ignore + ) + assert new_km.parent_id == folder_2.id + + +@pytest.mark.asyncio(loop_scope="session") +async def test_update_knowledge_move_error(session: AsyncSession, user: User): + assert user.id + file_1 = KnowledgeDB( + file_name="file_1", + extension="", + status="UPLOADED", + source="local", + source_link="local", + file_size=4, + file_sha1=None, + brains=[], + children=[], + user_id=user.id, + is_folder=False, + ) + file_2 = KnowledgeDB( + file_name="file_2", + extension="", + status="UPLOADED", + source="local", + source_link="local", + file_size=4, + file_sha1=None, + brains=[], + children=[], + user_id=user.id, + is_folder=False, + ) + session.add(file_1) + session.add(file_2) + await session.commit() + await 
session.refresh(file_1) + await session.refresh(file_2) + + storage = FakeStorage() + repository = KnowledgeRepository(session) + service = KnowledgeService(repository, storage) + + with pytest.raises(KnowledgeUpdateError): + await service.update_knowledge( + file_2, + KnowledgeUpdate(parent_id=file_1.id), # type: ignore + ) + + +@pytest.mark.asyncio(loop_scope="session") +async def test_update_knowledge_multiple(session: AsyncSession, user: User): + assert user.id + file = KnowledgeDB( + file_name="file", + extension="", + status="UPLOADED", + source="local", + source_link="local", + file_size=None, + file_sha1=None, + user_id=user.id, + ) + folder = KnowledgeDB( + file_name="folder_2", + extension="", + status="UPLOADED", + source="local", + source_link="local", + file_size=4, + file_sha1=None, + brains=[], + children=[], + user_id=user.id, + is_folder=True, + ) + session.add(file) + session.add(folder) + await session.commit() + await session.refresh(folder) + + storage = ErrorStorage() + repository = KnowledgeRepository(session) + service = KnowledgeService(repository, storage) + + await service.update_knowledge( + file, + KnowledgeUpdate(parent_id=folder.id, status="UPLOADED", file_sha1="sha1"), # type: ignore + ) + + km = ( + await session.exec(select(KnowledgeDB).where(KnowledgeDB.id == file.id)) + ).first() + assert km + assert km.parent_id == folder.id + assert km.status == "UPLOADED" + assert km.file_sha1 == "sha1" + + +@pytest.mark.asyncio(loop_scope="session") +async def test_remove_knowledge(session: AsyncSession, user: User): + assert user.id + storage = FakeStorage() + repository = KnowledgeRepository(session) + service = KnowledgeService(repository, storage) + + km_to_add = AddKnowledge( + file_name="test", + source="local", + is_folder=False, + parent_id=None, + ) + km_data = BytesIO(os.urandom(128)) + + # Create the knowledge + km = await service.create_knowledge( + user_id=user.id, + knowledge_to_add=km_to_add, + 
upload_file=UploadFile(file=km_data, size=128, filename=km_to_add.file_name), + ) + + # Remove knowledge + response = await service.remove_knowledge(knowledge=km) + + assert response.knowledge_id == km.id + assert response.file_name == km.file_name + + assert not storage.knowledge_exists(km) + assert ( + await session.exec(select(KnowledgeDB).where(KnowledgeDB.id == km.id)) + ).first() is None + + +@pytest.mark.asyncio(loop_scope="session") +async def test_remove_knowledge_folder(session: AsyncSession, user: User): + assert user.id + storage = FakeStorage() + repository = KnowledgeRepository(session) + service = KnowledgeService(repository, storage) + + folder_add = AddKnowledge( + file_name="folder", + source="local", + is_folder=True, + parent_id=None, + ) + + # Create the knowledge + folder = await service.create_knowledge( + user_id=user.id, knowledge_to_add=folder_add, upload_file=None + ) + file_add = AddKnowledge( + file_name="file", + source="local", + is_folder=False, + parent_id=folder.id, + ) + + km_data = BytesIO(os.urandom(128)) + file = await service.create_knowledge( + user_id=user.id, + knowledge_to_add=file_add, + upload_file=UploadFile(file=km_data, size=128, filename=file_add.file_name), + ) + assert storage.knowledge_exists(file) + + # Remove knowledge + await service.remove_knowledge(knowledge=folder) + + assert not storage.knowledge_exists(folder) + assert not storage.knowledge_exists(file) + assert ( + await session.exec(select(KnowledgeDB).where(KnowledgeDB.id == folder.id)) + ).first() is None + assert ( + await session.exec(select(KnowledgeDB).where(KnowledgeDB.id == file.id)) + ).first() is None + + +@pytest.mark.asyncio(loop_scope="session") +async def test_list_knowledge_root(session: AsyncSession, user: User): + assert user.id + root_file = KnowledgeDB( + file_name="file_1", + extension="", + status="UPLOADED", + source="local", + source_link="local", + file_size=None, + file_sha1=None, + user_id=user.id, + ) + + root_folder = 
KnowledgeDB( + file_name="folder", + extension="", + status="UPLOADED", + source="local", + source_link="local", + file_size=4, + file_sha1=None, + brains=[], + children=[], + user_id=user.id, + is_folder=True, + ) + nested_file = KnowledgeDB( + file_name="file_2", + extension="", + status="UPLOADED", + source="local", + source_link="local", + file_size=10, + file_sha1=None, + user_id=user.id, + parent=root_folder, + ) + session.add(nested_file) + session.add(root_file) + session.add(root_folder) + await session.commit() + await session.refresh(root_folder) + await session.refresh(root_file) + await session.refresh(nested_file) + + storage = FakeStorage() + repository = KnowledgeRepository(session) + service = KnowledgeService(repository, storage) + + root_kms = await service.list_knowledge(knowledge_id=None, user_id=user.id) + + assert len(root_kms) == 2 + assert {k.id for k in root_kms} == {root_folder.id, root_file.id} + + +@pytest.mark.asyncio(loop_scope="session") +async def test_list_knowledge(session: AsyncSession, user: User): + assert user.id + root_file = KnowledgeDB( + file_name="file_1", + extension="", + status="UPLOADED", + source="local", + source_link="local", + file_size=None, + file_sha1=None, + user_id=user.id, + ) + + root_folder = KnowledgeDB( + file_name="folder", + extension="", + status="UPLOADED", + source="local", + source_link="local", + file_size=4, + file_sha1=None, + brains=[], + children=[], + user_id=user.id, + is_folder=True, + ) + nested_file = KnowledgeDB( + file_name="file_2", + extension="", + status="UPLOADED", + source="local", + source_link="local", + file_size=10, + file_sha1=None, + user_id=user.id, + parent=root_folder, + ) + session.add(nested_file) + session.add(root_file) + session.add(root_folder) + await session.commit() + await session.refresh(root_folder) + await session.refresh(root_file) + await session.refresh(nested_file) + + storage = FakeStorage() + repository = KnowledgeRepository(session) + service = 
KnowledgeService(repository, storage) + + kms = await service.list_knowledge(knowledge_id=root_folder.id, user_id=user.id) + + assert len(kms) == 1 + assert kms[0].id == nested_file.id diff --git a/backend/api/quivr_api/modules/knowledge/tests/test_knowledges.py b/backend/api/quivr_api/modules/knowledge/tests/test_knowledges.py deleted file mode 100644 index 749bc0199..000000000 --- a/backend/api/quivr_api/modules/knowledge/tests/test_knowledges.py +++ /dev/null @@ -1,450 +0,0 @@ -import os -from typing import List, Tuple -from uuid import uuid4 - -import pytest -import pytest_asyncio -from sqlalchemy.exc import IntegrityError, NoResultFound -from sqlmodel import select, text -from sqlmodel.ext.asyncio.session import AsyncSession - -from quivr_api.modules.brain.entity.brain_entity import Brain, BrainType -from quivr_api.modules.knowledge.dto.inputs import KnowledgeStatus -from quivr_api.modules.knowledge.entity.knowledge import KnowledgeDB -from quivr_api.modules.knowledge.entity.knowledge_brain import KnowledgeBrain -from quivr_api.modules.knowledge.repository.knowledges import KnowledgeRepository -from quivr_api.modules.knowledge.service.knowledge_service import KnowledgeService -from quivr_api.modules.upload.service.upload_file import upload_file_storage -from quivr_api.modules.user.entity.user_identity import User -from quivr_api.modules.vector.entity.vector import Vector - -pg_database_base_url = "postgres:postgres@localhost:54322/postgres" - -TestData = Tuple[Brain, List[KnowledgeDB]] - - -@pytest_asyncio.fixture(scope="function") -async def other_user(session: AsyncSession): - sql = text( - """ - INSERT INTO "auth"."users" ("instance_id", "id", "aud", "role", "email", "encrypted_password", "email_confirmed_at", "invited_at", "confirmation_token", "confirmation_sent_at", "recovery_token", "recovery_sent_at", "email_change_token_new", "email_change", "email_change_sent_at", "last_sign_in_at", "raw_app_meta_data", "raw_user_meta_data", "is_super_admin", 
"created_at", "updated_at", "phone", "phone_confirmed_at", "phone_change", "phone_change_token", "phone_change_sent_at", "email_change_token_current", "email_change_confirm_status", "banned_until", "reauthentication_token", "reauthentication_sent_at", "is_sso_user", "deleted_at") VALUES - ('00000000-0000-0000-0000-000000000000', :id , 'authenticated', 'authenticated', 'other@quivr.app', '$2a$10$vwKX0eMLlrOZvxQEA3Vl4e5V4/hOuxPjGYn9QK1yqeaZxa.42Uhze', '2024-01-22 22:27:00.166861+00', NULL, '', NULL, 'e91d41043ca2c83c3be5a6ee7a4abc8a4f4fb1afc0a8453c502af931', '2024-03-05 16:22:13.780421+00', '', '', NULL, '2024-03-30 23:21:12.077887+00', '{"provider": "email", "providers": ["email"]}', '{}', NULL, '2024-01-22 22:27:00.158026+00', '2024-04-01 17:40:15.332205+00', NULL, NULL, '', '', NULL, '', 0, NULL, '', NULL, false, NULL); - """ - ) - await session.execute(sql, params={"id": uuid4()}) - - other_user = ( - await session.exec(select(User).where(User.email == "other@quivr.app")) - ).one() - return other_user - - -@pytest_asyncio.fixture(scope="function") -async def test_data(session: AsyncSession) -> TestData: - user_1 = ( - await session.exec(select(User).where(User.email == "admin@quivr.app")) - ).one() - assert user_1.id - # Brain data - brain_1 = Brain( - name="test_brain", - description="this is a test brain", - brain_type=BrainType.integration, - ) - - knowledge_brain_1 = KnowledgeDB( - file_name="test_file_1.txt", - extension=".txt", - status="UPLOADED", - source="test_source", - source_link="test_source_link", - file_size=100, - file_sha1="test_sha1", - brains=[brain_1], - user_id=user_1.id, - ) - - knowledge_brain_2 = KnowledgeDB( - file_name="test_file_2.txt", - extension=".txt", - status="UPLOADED", - source="test_source", - source_link="test_source_link", - file_size=100, - file_sha1="test_sha2", - brains=[], - user_id=user_1.id, - ) - - session.add(brain_1) - session.add(knowledge_brain_1) - session.add(knowledge_brain_2) - await session.commit() - return 
brain_1, [knowledge_brain_1, knowledge_brain_2] - - -@pytest.mark.asyncio(loop_scope="session") -async def test_updates_knowledge_status(session: AsyncSession, test_data: TestData): - brain, knowledges = test_data - assert brain.brain_id - assert knowledges[0].id - repo = KnowledgeRepository(session) - await repo.update_status_knowledge(knowledges[0].id, KnowledgeStatus.ERROR) - knowledge = await repo.get_knowledge_by_id(knowledges[0].id) - assert knowledge.status == KnowledgeStatus.ERROR - - -@pytest.mark.asyncio(loop_scope="session") -async def test_updates_knowledge_status_no_knowledge( - session: AsyncSession, test_data: TestData -): - brain, knowledges = test_data - assert brain.brain_id - assert knowledges[0].id - repo = KnowledgeRepository(session) - with pytest.raises(NoResultFound): - await repo.update_status_knowledge(uuid4(), KnowledgeStatus.UPLOADED) - - -@pytest.mark.asyncio(loop_scope="session") -async def test_update_knowledge_source_link(session: AsyncSession, test_data: TestData): - brain, knowledges = test_data - assert brain.brain_id - assert knowledges[0].id - repo = KnowledgeRepository(session) - await repo.update_source_link_knowledge(knowledges[0].id, "new_source_link") - knowledge = await repo.get_knowledge_by_id(knowledges[0].id) - assert knowledge.source_link == "new_source_link" - - -@pytest.mark.asyncio(loop_scope="session") -async def test_remove_knowledge_from_brain(session: AsyncSession, test_data: TestData): - brain, knowledges = test_data - assert brain.brain_id - assert knowledges[0].id - repo = KnowledgeRepository(session) - knowledge = await repo.remove_knowledge_from_brain(knowledges[0].id, brain.brain_id) - assert brain.brain_id not in [ - b.brain_id for b in await knowledge.awaitable_attrs.brains - ] - - -@pytest.mark.asyncio(loop_scope="session") -async def test_cascade_remove_knowledge_by_id( - session: AsyncSession, test_data: TestData -): - brain, knowledges = test_data - assert brain.brain_id - assert knowledges[0].id - 
repo = KnowledgeRepository(session) - await repo.remove_knowledge_by_id(knowledges[0].id) - with pytest.raises(NoResultFound): - await repo.get_knowledge_by_id(knowledges[0].id) - - query = select(KnowledgeBrain).where( - KnowledgeBrain.knowledge_id == knowledges[0].id - ) - result = await session.exec(query) - knowledge_brain = result.first() - assert knowledge_brain is None - - query = select(Vector).where(Vector.knowledge_id == knowledges[0].id) - result = await session.exec(query) - vector = result.first() - assert vector is None - - -@pytest.mark.asyncio(loop_scope="session") -async def test_remove_all_knowledges_from_brain( - session: AsyncSession, test_data: TestData -): - brain, knowledges = test_data - assert brain.brain_id - - # supabase_client = get_supabase_client() - # db = supabase_client - # storage = db.storage.from_("quivr") - - # storage.upload(f"{brain.brain_id}/test_file_1", b"test_content") - - repo = KnowledgeRepository(session) - service = KnowledgeService(repo) - await repo.remove_all_knowledges_from_brain(brain.brain_id) - knowledges = await service.get_all_knowledge_in_brain(brain.brain_id) - assert len(knowledges) == 0 - - # response = storage.list(path=f"{brain.brain_id}") - # assert response == [] - # FIXME @aminediro &chloedia raise an error when trying to interact with storage UnboundLocalError: cannot access local variable 'response' where it is not associated with a value - - -@pytest.mark.asyncio(loop_scope="session") -async def test_duplicate_sha1_knowledge_same_user( - session: AsyncSession, test_data: TestData -): - brain, [existing_knowledge, _] = test_data - assert brain.brain_id - assert existing_knowledge.id - assert existing_knowledge.file_sha1 - repo = KnowledgeRepository(session) - knowledge = KnowledgeDB( - file_name="test_file_2", - extension="txt", - status="UPLOADED", - source="test_source", - source_link="test_source_link", - file_size=100, - file_sha1=existing_knowledge.file_sha1, - brains=[brain], - 
user_id=existing_knowledge.user_id, - ) - - with pytest.raises(IntegrityError): # FIXME: Should raise IntegrityError - await repo.insert_knowledge(knowledge, brain.brain_id) - - -@pytest.mark.asyncio(loop_scope="session") -async def test_duplicate_sha1_knowledge_diff_user( - session: AsyncSession, test_data: TestData, other_user: User -): - brain, knowledges = test_data - assert other_user.id - assert brain.brain_id - assert knowledges[0].id - repo = KnowledgeRepository(session) - knowledge = KnowledgeDB( - file_name="test_file_2", - extension="txt", - status="UPLOADED", - source="test_source", - source_link="test_source_link", - file_size=100, - file_sha1=knowledges[0].file_sha1, - brains=[brain], - user_id=other_user.id, # random user id - ) - - result = await repo.insert_knowledge(knowledge, brain.brain_id) - assert result - - -@pytest.mark.asyncio(loop_scope="session") -async def test_add_knowledge_to_brain(session: AsyncSession, test_data: TestData): - brain, knowledges = test_data - assert brain.brain_id - assert knowledges[1].id - repo = KnowledgeRepository(session) - await repo.link_to_brain(knowledges[1], brain.brain_id) - knowledge = await repo.get_knowledge_by_id(knowledges[1].id) - brains_of_knowledge = [b.brain_id for b in await knowledge.awaitable_attrs.brains] - assert brain.brain_id in brains_of_knowledge - - query = select(KnowledgeBrain).where( - KnowledgeBrain.knowledge_id == knowledges[0].id - and KnowledgeBrain.brain_id == brain.brain_id - ) - result = await session.exec(query) - knowledge_brain = result.first() - assert knowledge_brain - - -# Knowledge Service -@pytest.mark.asyncio(loop_scope="session") -async def test_get_knowledge_in_brain(session: AsyncSession, test_data: TestData): - brain, knowledges = test_data - assert brain.brain_id - repo = KnowledgeRepository(session) - service = KnowledgeService(repo) - list_knowledge = await service.get_all_knowledge_in_brain(brain.brain_id) - assert len(list_knowledge) == 1 - brains_of_knowledge = 
[ - b.brain_id for b in await knowledges[0].awaitable_attrs.brains - ] - assert list_knowledge[0].id == knowledges[0].id - assert list_knowledge[0].file_name == knowledges[0].file_name - assert brain.brain_id in brains_of_knowledge - - -@pytest.mark.asyncio(loop_scope="session") -async def test_should_process_knowledge_exists( - session: AsyncSession, test_data: TestData -): - brain, [existing_knowledge, _] = test_data - assert brain.brain_id - new = KnowledgeDB( - file_name="new", - extension="txt", - status="PROCESSING", - source="test_source", - source_link="test_source_link", - file_size=100, - file_sha1=None, - brains=[brain], - user_id=existing_knowledge.user_id, - ) - session.add(new) - await session.commit() - await session.refresh(new) - incoming_knowledge = await new.to_dto() - repo = KnowledgeRepository(session) - service = KnowledgeService(repo) - assert existing_knowledge.file_sha1 - with pytest.raises(FileExistsError): - await service.update_sha1_conflict( - incoming_knowledge, brain.brain_id, file_sha1=existing_knowledge.file_sha1 - ) - - -@pytest.mark.asyncio(loop_scope="session") -async def test_should_process_knowledge_link_brain( - session: AsyncSession, test_data: TestData -): - repo = KnowledgeRepository(session) - service = KnowledgeService(repo) - brain, [existing_knowledge, _] = test_data - user_id = existing_knowledge.user_id - assert brain.brain_id - prev = KnowledgeDB( - file_name="prev", - extension=".txt", - status=KnowledgeStatus.UPLOADED, - source="test_source", - source_link="test_source_link", - file_size=100, - file_sha1="test1", - brains=[brain], - user_id=user_id, - ) - brain_2 = Brain( - name="test_brain", - description="this is a test brain", - brain_type=BrainType.integration, - ) - session.add(brain_2) - session.add(prev) - await session.commit() - await session.refresh(prev) - await session.refresh(brain_2) - - assert prev.id - assert brain_2.brain_id - - new = KnowledgeDB( - file_name="new", - extension="txt", - 
status="PROCESSING", - source="test_source", - source_link="test_source_link", - file_size=100, - file_sha1=None, - brains=[brain_2], - user_id=user_id, - ) - session.add(new) - await session.commit() - await session.refresh(new) - - incoming_knowledge = await new.to_dto() - assert prev.file_sha1 - - should_process = await service.update_sha1_conflict( - incoming_knowledge, brain_2.brain_id, file_sha1=prev.file_sha1 - ) - assert not should_process - - # Check prev knowledge was linked - assert incoming_knowledge.file_sha1 - prev_knowledge = await service.repository.get_knowledge_by_id(prev.id) - prev_brains = await prev_knowledge.awaitable_attrs.brains - assert {b.brain_id for b in prev_brains} == { - brain.brain_id, - brain_2.brain_id, - } - # Check new knowledge was removed - assert new.id - with pytest.raises(NoResultFound): - await service.repository.get_knowledge_by_id(new.id) - - -@pytest.mark.asyncio(loop_scope="session") -async def test_should_process_knowledge_prev_error( - session: AsyncSession, test_data: TestData -): - repo = KnowledgeRepository(session) - service = KnowledgeService(repo) - brain, [existing_knowledge, _] = test_data - user_id = existing_knowledge.user_id - assert brain.brain_id - prev = KnowledgeDB( - file_name="prev", - extension="txt", - status=KnowledgeStatus.ERROR, - source="test_source", - source_link="test_source_link", - file_size=100, - file_sha1="test1", - brains=[brain], - user_id=user_id, - ) - session.add(prev) - await session.commit() - await session.refresh(prev) - - assert prev.id - - new = KnowledgeDB( - file_name="new", - extension="txt", - status="PROCESSING", - source="test_source", - source_link="test_source_link", - file_size=100, - file_sha1=None, - brains=[brain], - user_id=user_id, - ) - session.add(new) - await session.commit() - await session.refresh(new) - - incoming_knowledge = await new.to_dto() - assert prev.file_sha1 - should_process = await service.update_sha1_conflict( - incoming_knowledge, 
brain.brain_id, file_sha1=prev.file_sha1 - ) - - # Checks we should process this file - assert should_process - # Previous errored file is cleaned up - with pytest.raises(NoResultFound): - await service.repository.get_knowledge_by_id(prev.id) - - assert new.id - new = await service.repository.get_knowledge_by_id(new.id) - assert new.file_sha1 - - -@pytest.mark.asyncio(loop_scope="session") -async def test_get_knowledge_storage_path(session: AsyncSession, test_data: TestData): - brain, [knowledge, _] = test_data - assert knowledge.file_name - repository = KnowledgeRepository(session) - service = KnowledgeService(repository) - brain_2 = Brain( - name="test_brain", - description="this is a test brain", - brain_type=BrainType.integration, - ) - session.add(brain_2) - await session.commit() - await session.refresh(brain_2) - assert brain_2.brain_id - km_data = os.urandom(128) - km_path = f"{str(knowledge.brains[0].brain_id)}/{knowledge.file_name}" - await upload_file_storage(km_data, km_path) - # Link knowledge to two brains - await repository.link_to_brain(knowledge, brain_2.brain_id) - storage_path = await service.get_knowledge_storage_path( - knowledge.file_name, brain_2.brain_id - ) - assert storage_path == km_path diff --git a/backend/api/quivr_api/modules/sync/tests/test_syncutils.py b/backend/api/quivr_api/modules/sync/tests/test_syncutils.py index 0c16ad09d..63b212128 100644 --- a/backend/api/quivr_api/modules/sync/tests/test_syncutils.py +++ b/backend/api/quivr_api/modules/sync/tests/test_syncutils.py @@ -317,10 +317,11 @@ async def test_process_sync_file_noprev( created_km = all_km[0] assert created_km.file_name == sync_file.name assert created_km.extension == ".txt" - assert created_km.file_sha1 is not None + assert created_km.file_sha1 is None assert created_km.created_at is not None assert created_km.metadata == {"sync_file_id": "1"} - assert created_km.brain_ids == [brain_1.brain_id] + assert len(created_km.brains)> 0 + assert 
created_km.brains[0]["brain_id"]== brain_1.brain_id # Assert celery task in correct assert task["args"] == ("process_file_task",) @@ -409,12 +410,12 @@ async def test_process_sync_file_with_prev( created_km = all_km[0] assert created_km.file_name == sync_file.name assert created_km.extension == ".txt" - assert created_km.file_sha1 is not None + assert created_km.file_sha1 is None assert created_km.updated_at assert created_km.created_at assert created_km.updated_at == created_km.created_at # new line assert created_km.metadata == {"sync_file_id": str(dbfiles[0].id)} - assert created_km.brain_ids == [brain_1.brain_id] + assert created_km.brains[0]["brain_id"]== brain_1.brain_id # Check file content changed assert check_file_exists(str(brain_1.brain_id), sync_file.name) diff --git a/backend/api/quivr_api/modules/upload/controller/upload_routes.py b/backend/api/quivr_api/modules/upload/controller/upload_routes.py index 0f614c2af..0bf6e952d 100644 --- a/backend/api/quivr_api/modules/upload/controller/upload_routes.py +++ b/backend/api/quivr_api/modules/upload/controller/upload_routes.py @@ -53,12 +53,10 @@ AsyncClientDep = Annotated[AsyncClient, Depends(get_supabase_async_client)] @upload_router.post("/upload", dependencies=[Depends(AuthBearer())], tags=["Upload"]) async def upload_file( uploadFile: UploadFile, - client: AsyncClientDep, - background_tasks: BackgroundTasks, knowledge_service: KnowledgeServiceDep, + background_tasks: BackgroundTasks, bulk_id: Optional[UUID] = Query(None, description="The ID of the bulk upload"), brain_id: UUID = Query(..., description="The ID of the brain"), - chat_id: Optional[UUID] = Query(None, description="The ID of the chat"), current_user: UserIdentity = Depends(get_current_user), integration: Optional[str] = None, integration_link: Optional[str] = None, @@ -121,7 +119,7 @@ async def upload_file( file_size=uploadFile.size, file_sha1=None, ) - knowledge = await knowledge_service.insert_knowledge( + knowledge = await 
knowledge_service.insert_knowledge_brain( user_id=current_user.id, knowledge_to_add=knowledge_to_add ) # type: ignore diff --git a/backend/api/quivr_api/routes/crawl_routes.py b/backend/api/quivr_api/routes/crawl_routes.py index e4d06d61a..804c379af 100644 --- a/backend/api/quivr_api/routes/crawl_routes.py +++ b/backend/api/quivr_api/routes/crawl_routes.py @@ -87,7 +87,7 @@ async def crawl_endpoint( source_link=crawl_website.url, ) - added_knowledge = await knowledge_service.insert_knowledge( + added_knowledge = await knowledge_service.insert_knowledge_brain( knowledge_to_add=knowledge_to_add, user_id=current_user.id ) logger.info(f"Knowledge {added_knowledge} added successfully") diff --git a/backend/api/quivr_api/utils/partial.py b/backend/api/quivr_api/utils/partial.py new file mode 100644 index 000000000..138a36c71 --- /dev/null +++ b/backend/api/quivr_api/utils/partial.py @@ -0,0 +1,50 @@ +from copy import deepcopy +from typing import Any, Callable, Optional, Type, TypeVar +from uuid import UUID + +from pydantic import BaseModel, create_model +from pydantic.fields import FieldInfo + +Model = TypeVar("Model", bound=Type[BaseModel]) + + +def all_optional(without_fields: list[str] | None = None) -> Callable[[Model], Model]: + if without_fields is None: + without_fields = [] + + def wrapper(model: Type[Model]) -> Type[Model]: + base_model: Type[Model] = model + + def make_field_optional( + field: FieldInfo, default: Any = None + ) -> tuple[Any, FieldInfo]: + new = deepcopy(field) + new.default = default + new.annotation = Optional[field.annotation] + return new.annotation, new + + if without_fields: + base_model = BaseModel + + return create_model( + model.__name__, + __base__=base_model, + __module__=model.__module__, + **{ + field_name: make_field_optional(field_info) + for field_name, field_info in model.model_fields.items() + if field_name not in without_fields + }, + ) + + return wrapper + + +class Test(BaseModel): + id: UUID + name: Optional[str] = None + + 
+@all_optional() +class TestUpdate(Test): + pass diff --git a/backend/core/quivr_core/models.py b/backend/core/quivr_core/models.py index 053642c21..8ebf2bbe2 100644 --- a/backend/core/quivr_core/models.py +++ b/backend/core/quivr_core/models.py @@ -42,6 +42,7 @@ class KnowledgeStatus(str, Enum): PROCESSING = "PROCESSING" UPLOADED = "UPLOADED" ERROR = "ERROR" + RESERVED = "RESERVED" class Source(BaseModel): diff --git a/backend/supabase/migrations/20240905153004_knowledge-folders.sql b/backend/supabase/migrations/20240905153004_knowledge-folders.sql new file mode 100644 index 000000000..5b2ac3165 --- /dev/null +++ b/backend/supabase/migrations/20240905153004_knowledge-folders.sql @@ -0,0 +1,31 @@ +ALTER USER postgres +SET idle_session_timeout = '3min'; +ALTER USER postgres +SET idle_in_transaction_session_timeout = '3min'; +-- Drop previous contraint +alter table "public"."knowledge" drop constraint "unique_file_sha1_user_id"; +alter table "public"."knowledge" +add column "is_folder" boolean default false; +-- Update the knowledge to backfill knowledge to is_folder = false +UPDATE "public"."knowledge" +SET is_folder = false; +-- Add parent_id -> folder +alter table "public"."knowledge" +add column "parent_id" uuid; +alter table "public"."knowledge" +add constraint "public_knowledge_parent_id_fkey" FOREIGN KEY (parent_id) REFERENCES knowledge(id) ON DELETE CASCADE; +-- Add constraint must be folder for parent_id +CREATE FUNCTION is_parent_folder(folder_id uuid) RETURNS boolean AS $$ BEGIN RETURN ( + SELECT k.is_folder + FROM public.knowledge k + WHERE k.id = folder_id +); +END; +$$ LANGUAGE plpgsql; +ALTER TABLE public.knowledge +ADD CONSTRAINT check_parent_is_folder CHECK ( + parent_id IS NULL + OR is_parent_folder(parent_id) + ); +-- Index on parent_id +CREATE INDEX knowledge_parent_id_idx ON public.knowledge USING btree (parent_id); diff --git a/backend/worker/quivr_worker/celery_worker.py b/backend/worker/quivr_worker/celery_worker.py index c438c742d..ceb1632c8 
100644 --- a/backend/worker/quivr_worker/celery_worker.py +++ b/backend/worker/quivr_worker/celery_worker.py @@ -13,7 +13,7 @@ from quivr_api.modules.brain.repository.brains_vectors import BrainsVectors from quivr_api.modules.brain.service.brain_service import BrainService from quivr_api.modules.dependencies import get_supabase_client from quivr_api.modules.knowledge.repository.knowledges import KnowledgeRepository -from quivr_api.modules.knowledge.repository.storage import Storage +from quivr_api.modules.knowledge.repository.storage import SupabaseS3Storage from quivr_api.modules.knowledge.service.knowledge_service import KnowledgeService from quivr_api.modules.notification.service.notification_service import ( NotificationService, @@ -58,7 +58,7 @@ sync_user_service = SyncUserService() sync_files_repo_service = SyncFilesRepository() brain_service = BrainService() brain_vectors = BrainsVectors() -storage = Storage() +storage = SupabaseS3Storage() notion_service: SyncNotionService | None = None async_engine: AsyncEngine | None = None engine: Engine | None = None @@ -170,6 +170,8 @@ async def aprocess_file_task( integration_link=source_link, delete_file=delete_file, ) + session.commit() + await async_session.commit() except Exception as e: session.rollback() await async_session.rollback() @@ -196,19 +198,29 @@ def process_crawl_task( ) global engine assert engine - with Session(engine, expire_on_commit=False, autoflush=False) as session: - vector_repository = VectorRepository(session) - vector_service = VectorService(vector_repository) - loop = asyncio.get_event_loop() - loop.run_until_complete( - process_url_func( - url=crawl_website_url, - brain_id=brain_id, - knowledge_id=knowledge_id, - brain_service=brain_service, - vector_service=vector_service, + try: + with Session(engine, expire_on_commit=False, autoflush=False) as session: + session.execute( + text("SET SESSION idle_in_transaction_session_timeout = '5min';") ) - ) + vector_repository = 
VectorRepository(session) + vector_service = VectorService(vector_repository) + loop = asyncio.get_event_loop() + loop.run_until_complete( + process_url_func( + url=crawl_website_url, + brain_id=brain_id, + knowledge_id=knowledge_id, + brain_service=brain_service, + vector_service=vector_service, + ) + ) + session.commit() + except Exception as e: + session.rollback() + raise e + finally: + session.close() @celery.task(name="NotionConnectorLoad") diff --git a/backend/worker/quivr_worker/process/process_s3_file.py b/backend/worker/quivr_worker/process/process_s3_file.py index a85465794..99bc4e736 100644 --- a/backend/worker/quivr_worker/process/process_s3_file.py +++ b/backend/worker/quivr_worker/process/process_s3_file.py @@ -2,6 +2,7 @@ from uuid import UUID from quivr_api.logger import get_logger from quivr_api.modules.brain.service.brain_service import BrainService +from quivr_api.modules.knowledge.entity.knowledge import KnowledgeUpdate from quivr_api.modules.knowledge.service.knowledge_service import KnowledgeService from quivr_api.modules.vector.service.vector_service import VectorService @@ -41,17 +42,15 @@ async def process_uploaded_file( # If we have some knowledge with error with build_file(file_data, knowledge_id, file_name) as file_instance: knowledge = await knowledge_service.get_knowledge(knowledge_id=knowledge_id) - should_process = await knowledge_service.update_sha1_conflict( - knowledge=knowledge, - brain_id=brain.brain_id, - file_sha1=file_instance.file_sha1, + await knowledge_service.update_knowledge( + knowledge, + KnowledgeUpdate(file_sha1=file_instance.file_sha1), # type: ignore + ) + await process_file( + file_instance=file_instance, + brain=brain, + brain_service=brain_service, + vector_service=vector_service, + integration=integration, + integration_link=integration_link, ) - if should_process: - await process_file( - file_instance=file_instance, - brain=brain, - brain_service=brain_service, - vector_service=vector_service, - 
integration=integration, - integration_link=integration_link, - ) diff --git a/backend/worker/quivr_worker/syncs/process_active_syncs.py b/backend/worker/quivr_worker/syncs/process_active_syncs.py index 299d9f2b0..d190c2191 100644 --- a/backend/worker/quivr_worker/syncs/process_active_syncs.py +++ b/backend/worker/quivr_worker/syncs/process_active_syncs.py @@ -141,7 +141,7 @@ async def process_notion_sync( UUID(user_id), notion_client, # type: ignore ) - + await session.commit() except Exception as e: await session.rollback() raise e diff --git a/backend/worker/quivr_worker/syncs/store_notion.py b/backend/worker/quivr_worker/syncs/store_notion.py index 82925e773..821de8874 100644 --- a/backend/worker/quivr_worker/syncs/store_notion.py +++ b/backend/worker/quivr_worker/syncs/store_notion.py @@ -40,6 +40,8 @@ async def fetch_and_store_notion_files_async( else: logger.warn("No notion page fetched") + # Commit all before exiting + await session.commit() except Exception as e: await session.rollback() raise e diff --git a/backend/worker/quivr_worker/syncs/utils.py b/backend/worker/quivr_worker/syncs/utils.py index e1523b29b..bbc3c75f8 100644 --- a/backend/worker/quivr_worker/syncs/utils.py +++ b/backend/worker/quivr_worker/syncs/utils.py @@ -6,7 +6,7 @@ from quivr_api.celery_config import celery from quivr_api.logger import get_logger from quivr_api.modules.brain.repository.brains_vectors import BrainsVectors from quivr_api.modules.knowledge.repository.knowledges import KnowledgeRepository -from quivr_api.modules.knowledge.repository.storage import Storage +from quivr_api.modules.knowledge.repository.storage import SupabaseS3Storage from quivr_api.modules.knowledge.service.knowledge_service import KnowledgeService from quivr_api.modules.notification.service.notification_service import ( NotificationService, @@ -42,7 +42,7 @@ class SyncServices: sync_files_repo_service: SyncFilesRepository notification_service: NotificationService brain_vectors: BrainsVectors - storage: 
Storage + storage: SupabaseS3Storage @asynccontextmanager @@ -56,7 +56,6 @@ async def build_syncs_utils( await session.execute( text("SET SESSION idle_in_transaction_session_timeout = '5min';") ) - # TODO pass services from celery_worker notion_repository = NotionRepository(session) notion_service = SyncNotionService(notion_repository) knowledge_service = KnowledgeService(KnowledgeRepository(session)) @@ -84,7 +83,7 @@ async def build_syncs_utils( mapping_sync_utils[provider_name] = provider_sync_util yield mapping_sync_utils - + await session.commit() except Exception as e: await session.rollback() raise e