feat: CRUD KMS (no syncs) (#3162)

# Description

closes #3056.
closes #3198 


- Create knowledge route
- Get knowledge route
- List knowledge route : accepts knowledge_id | None. None to list root
knowledge for use
- Update (patch) knowledge to rename and move knowledge
- Remove knowledge: Cascade if parent_id in knowledge and cleanup
storage
- Link storage upload to knowledge_service
- Relax sha1 file constraint
- Tests to all repository / service

---------

Co-authored-by: Stan Girard <girard.stanislas@gmail.com>
This commit is contained in:
AmineDiro 2024-09-16 13:31:09 +02:00 committed by GitHub
parent edc4118ba1
commit 71edca572f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
32 changed files with 2162 additions and 596 deletions

View File

@ -9,7 +9,9 @@ on:
jobs: jobs:
test: test:
runs-on: ubuntu-latest runs-on: ubuntu-latest
strategy:
matrix:
project: [quivr-api, quivr-worker]
steps: steps:
- name: 👀 Checkout code - name: 👀 Checkout code
uses: actions/checkout@v2 uses: actions/checkout@v2
@ -65,4 +67,4 @@ jobs:
supabase start supabase start
rye run python -c "from unstructured.nlp.tokenize import download_nltk_packages; download_nltk_packages()" rye run python -c "from unstructured.nlp.tokenize import download_nltk_packages; download_nltk_packages()"
rye run python -c "import nltk;nltk.download('punkt_tab'); nltk.download('averaged_perceptron_tagger_eng')" rye run python -c "import nltk;nltk.download('punkt_tab'); nltk.download('averaged_perceptron_tagger_eng')"
rye test -p quivr-api -p quivr-worker rye test -p ${{ matrix.project }}

View File

@ -32,13 +32,6 @@ repos:
- id: mypy - id: mypy
name: mypy name: mypy
additional_dependencies: ["types-aiofiles"] additional_dependencies: ["types-aiofiles"]
- repo: https://github.com/python-poetry/poetry
rev: "1.8.0"
hooks:
- id: poetry-check
args: ["-C", "./backend/core"]
- id: poetry-lock
args: ["-C", "./backend/core"]
ci: ci:
autofix_commit_msg: | autofix_commit_msg: |
[pre-commit.ci] auto fixes from pre-commit.com hooks [pre-commit.ci] auto fixes from pre-commit.com hooks

View File

@ -3,6 +3,7 @@ from typing import Optional
from fastapi import Depends, HTTPException, Request from fastapi import Depends, HTTPException, Request
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from quivr_api.middlewares.auth.jwt_token_handler import ( from quivr_api.middlewares.auth.jwt_token_handler import (
decode_access_token, decode_access_token,
verify_token, verify_token,
@ -57,9 +58,13 @@ class AuthBearer(HTTPBearer):
def get_test_user(self) -> UserIdentity: def get_test_user(self) -> UserIdentity:
return UserIdentity( return UserIdentity(
email="admin@quivr.app", id="39418e3b-0258-4452-af60-7acfcc1263ff" # type: ignore email="admin@quivr.app",
id="39418e3b-0258-4452-af60-7acfcc1263ff", # type: ignore
) # replace with test user information ) # replace with test user information
def get_current_user(user: UserIdentity = Depends(AuthBearer())) -> UserIdentity: auth_bearer = AuthBearer()
def get_current_user(user: UserIdentity = Depends(auth_bearer)) -> UserIdentity:
return user return user

View File

@ -69,6 +69,7 @@ class Brain(AsyncAttrs, SQLModel, table=True):
back_populates="brains", link_model=KnowledgeBrain back_populates="brains", link_model=KnowledgeBrain
) )
# TODO : add # TODO : add
# "meaning" "public"."vector", # "meaning" "public"."vector",
# "tags" "public"."tags"[] # "tags" "public"."tags"[]

View File

@ -2,7 +2,7 @@ from uuid import UUID
from quivr_api.logger import get_logger from quivr_api.logger import get_logger
from quivr_api.modules.brain.repository.brains_vectors import BrainsVectors from quivr_api.modules.brain.repository.brains_vectors import BrainsVectors
from quivr_api.modules.knowledge.repository.storage import Storage from quivr_api.modules.knowledge.repository.storage import SupabaseS3Storage
logger = get_logger(__name__) logger = get_logger(__name__)
@ -11,7 +11,7 @@ class BrainVectorService:
def __init__(self, brain_id: UUID): def __init__(self, brain_id: UUID):
self.repository = BrainsVectors() self.repository = BrainsVectors()
self.brain_id = brain_id self.brain_id = brain_id
self.storage = Storage() self.storage = SupabaseS3Storage()
def create_brain_vector(self, vector_id: str, file_sha1: str): def create_brain_vector(self, vector_id: str, file_sha1: str):
return self.repository.create_brain_vector(self.brain_id, vector_id, file_sha1) # type: ignore return self.repository.create_brain_vector(self.brain_id, vector_id, file_sha1) # type: ignore
@ -26,10 +26,10 @@ class BrainVectorService:
for vector_id in vector_ids: for vector_id in vector_ids:
self.create_brain_vector(vector_id, file_sha1) self.create_brain_vector(vector_id, file_sha1)
def delete_file_from_brain(self, file_name: str, only_vectors: bool = False): async def delete_file_from_brain(self, file_name: str, only_vectors: bool = False):
file_name_with_brain_id = f"{self.brain_id}/{file_name}" file_name_with_brain_id = f"{self.brain_id}/{file_name}"
if not only_vectors: if not only_vectors:
self.storage.remove_file(file_name_with_brain_id) await self.storage.remove_file(file_name_with_brain_id)
return self.repository.delete_file_from_brain(self.brain_id, file_name) # type: ignore return self.repository.delete_file_from_brain(self.brain_id, file_name) # type: ignore
def delete_file_url_from_brain(self, file_name: str): def delete_file_url_from_brain(self, file_name: str):

View File

@ -24,9 +24,6 @@ async_engine = create_async_engine(
"postgresql+asyncpg://" + pg_database_base_url, "postgresql+asyncpg://" + pg_database_base_url,
echo=True if os.getenv("ORM_DEBUG") else False, echo=True if os.getenv("ORM_DEBUG") else False,
future=True, future=True,
pool_pre_ping=True,
pool_size=10,
pool_recycle=0.1,
) )

View File

@ -7,8 +7,7 @@ from langchain.embeddings.base import Embeddings
from langchain_community.embeddings.ollama import OllamaEmbeddings from langchain_community.embeddings.ollama import OllamaEmbeddings
# from langchain_community.vectorstores.supabase import SupabaseVectorStore # from langchain_community.vectorstores.supabase import SupabaseVectorStore
from langchain_openai import OpenAIEmbeddings from langchain_openai import AzureOpenAIEmbeddings, OpenAIEmbeddings
from langchain_openai import AzureOpenAIEmbeddings
# from quivr_api.modules.vector.service.vector_service import VectorService # from quivr_api.modules.vector.service.vector_service import VectorService
# from quivr_api.modules.vectorstore.supabase import CustomSupabaseVectorStore # from quivr_api.modules.vectorstore.supabase import CustomSupabaseVectorStore
@ -22,7 +21,6 @@ from quivr_api.models.databases.supabase.supabase import SupabaseDB
from quivr_api.models.settings import BrainSettings from quivr_api.models.settings import BrainSettings
from supabase.client import AsyncClient, Client, create_async_client, create_client from supabase.client import AsyncClient, Client, create_async_client, create_client
# Global variables to store the Supabase client and database instances # Global variables to store the Supabase client and database instances
_supabase_client: Optional[Client] = None _supabase_client: Optional[Client] = None
_supabase_async_client: Optional[AsyncClient] = None _supabase_async_client: Optional[AsyncClient] = None

View File

@ -1,8 +1,8 @@
from http import HTTPStatus from http import HTTPStatus
from typing import Annotated from typing import Annotated, List, Optional
from uuid import UUID from uuid import UUID
from fastapi import APIRouter, Depends, HTTPException, Query from fastapi import APIRouter, Depends, File, HTTPException, Query, UploadFile, status
from quivr_api.logger import get_logger from quivr_api.logger import get_logger
from quivr_api.middlewares.auth import AuthBearer, get_current_user from quivr_api.middlewares.auth import AuthBearer, get_current_user
@ -12,6 +12,14 @@ from quivr_api.modules.brain.service.brain_authorization_service import (
validate_brain_authorization, validate_brain_authorization,
) )
from quivr_api.modules.dependencies import get_service from quivr_api.modules.dependencies import get_service
from quivr_api.modules.knowledge.dto.inputs import AddKnowledge
from quivr_api.modules.knowledge.entity.knowledge import Knowledge, KnowledgeUpdate
from quivr_api.modules.knowledge.service.knowledge_exceptions import (
KnowledgeDeleteError,
KnowledgeForbiddenAccess,
KnowledgeNotFoundException,
UploadError,
)
from quivr_api.modules.knowledge.service.knowledge_service import KnowledgeService from quivr_api.modules.knowledge.service.knowledge_service import KnowledgeService
from quivr_api.modules.upload.service.generate_file_signed_url import ( from quivr_api.modules.upload.service.generate_file_signed_url import (
generate_file_signed_url, generate_file_signed_url,
@ -21,9 +29,8 @@ from quivr_api.modules.user.entity.user_identity import UserIdentity
knowledge_router = APIRouter() knowledge_router = APIRouter()
logger = get_logger(__name__) logger = get_logger(__name__)
KnowledgeServiceDep = Annotated[ get_km_service = get_service(KnowledgeService)
KnowledgeService, Depends(get_service(KnowledgeService)) KnowledgeServiceDep = Annotated[KnowledgeService, Depends(get_km_service)]
]
@knowledge_router.get( @knowledge_router.get(
@ -53,7 +60,7 @@ async def list_knowledge_in_brain_endpoint(
], ],
tags=["Knowledge"], tags=["Knowledge"],
) )
async def delete_endpoint( async def delete_knowledge_brain(
knowledge_id: UUID, knowledge_id: UUID,
knowledge_service: KnowledgeServiceDep, knowledge_service: KnowledgeServiceDep,
current_user: UserIdentity = Depends(get_current_user), current_user: UserIdentity = Depends(get_current_user),
@ -65,7 +72,7 @@ async def delete_endpoint(
knowledge = await knowledge_service.get_knowledge(knowledge_id) knowledge = await knowledge_service.get_knowledge(knowledge_id)
file_name = knowledge.file_name if knowledge.file_name else knowledge.url file_name = knowledge.file_name if knowledge.file_name else knowledge.url
await knowledge_service.remove_knowledge(brain_id, knowledge_id) await knowledge_service.remove_knowledge_brain(brain_id, knowledge_id)
return { return {
"message": f"{file_name} of brain {brain_id} has been deleted by user {current_user.email}." "message": f"{file_name} of brain {brain_id} has been deleted by user {current_user.email}."
@ -88,13 +95,13 @@ async def generate_signed_url_endpoint(
knowledge = await knowledge_service.get_knowledge(knowledge_id) knowledge = await knowledge_service.get_knowledge(knowledge_id)
if len(knowledge.brain_ids) == 0: if len(knowledge.brains) == 0:
raise HTTPException( raise HTTPException(
status_code=HTTPStatus.NOT_FOUND, status_code=HTTPStatus.NOT_FOUND,
detail="knowledge not associated with brains yet.", detail="knowledge not associated with brains yet.",
) )
brain_id = knowledge.brain_ids[0] brain_id = knowledge.brains[0]["brain_id"]
validate_brain_authorization(brain_id=brain_id, user_id=current_user.id) validate_brain_authorization(brain_id=brain_id, user_id=current_user.id)
@ -108,3 +115,153 @@ async def generate_signed_url_endpoint(
file_signed_url = generate_file_signed_url(file_path_in_storage) file_signed_url = generate_file_signed_url(file_path_in_storage)
return file_signed_url return file_signed_url
@knowledge_router.post(
"/knowledge/",
tags=["Knowledge"],
response_model=Knowledge,
)
async def create_knowledge(
knowledge_data: str = File(...),
file: Optional[UploadFile] = None,
knowledge_service: KnowledgeService = Depends(get_km_service),
current_user: UserIdentity = Depends(get_current_user),
):
knowledge = AddKnowledge.model_validate_json(knowledge_data)
if not knowledge.file_name and not knowledge.url:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Either file_name or url must be provided",
)
try:
km = await knowledge_service.create_knowledge(
knowledge_to_add=knowledge, upload_file=file, user_id=current_user.id
)
km_dto = await km.to_dto()
return km_dto
except ValueError:
raise HTTPException(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
detail="Unprocessable knowledge ",
)
except FileExistsError:
raise HTTPException(
status_code=status.HTTP_409_CONFLICT, detail="Existing knowledge"
)
except UploadError:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Error occured uploading knowledge",
)
except Exception:
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR)
@knowledge_router.get(
"/knowledge/children",
response_model=List[Knowledge] | None,
tags=["Knowledge"],
)
async def list_knowledge(
parent_id: UUID | None = None,
knowledge_service: KnowledgeService = Depends(get_km_service),
current_user: UserIdentity = Depends(get_current_user),
):
try:
# TODO: Returns one level of children
children = await knowledge_service.list_knowledge(parent_id, current_user.id)
return [await c.to_dto(get_children=False) for c in children]
except KnowledgeNotFoundException as e:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, detail=f"{e.message}"
)
except KnowledgeForbiddenAccess as e:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail=f"{e.message}"
)
except Exception:
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR)
@knowledge_router.get(
"/knowledge/{knowledge_id}",
response_model=Knowledge,
tags=["Knowledge"],
)
async def get_knowledge(
knowledge_id: UUID,
knowledge_service: KnowledgeService = Depends(get_km_service),
current_user: UserIdentity = Depends(get_current_user),
):
try:
km = await knowledge_service.get_knowledge(knowledge_id)
if km.user_id != current_user.id:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="You do not have permission to access this knowledge.",
)
return await km.to_dto()
except KnowledgeNotFoundException as e:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail=f"{e.message}"
)
except Exception:
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR)
@knowledge_router.patch(
"/knowledge/{knowledge_id}",
status_code=status.HTTP_202_ACCEPTED,
response_model=Knowledge,
tags=["Knowledge"],
)
async def update_knowledge(
knowledge_id: UUID,
payload: KnowledgeUpdate,
knowledge_service: KnowledgeService = Depends(get_km_service),
current_user: UserIdentity = Depends(get_current_user),
):
try:
km = await knowledge_service.get_knowledge(knowledge_id)
if km.user_id != current_user.id:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="You do not have permission to access this knowledge.",
)
km = await knowledge_service.update_knowledge(km, payload)
return km
except KnowledgeNotFoundException as e:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail=f"{e.message}"
)
except Exception:
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR)
@knowledge_router.delete(
"/knowledge/{knowledge_id}",
status_code=status.HTTP_202_ACCEPTED,
tags=["Knowledge"],
)
async def delete_knowledge(
knowledge_id: UUID,
knowledge_service: KnowledgeService = Depends(get_km_service),
current_user: UserIdentity = Depends(get_current_user),
):
try:
km = await knowledge_service.get_knowledge(knowledge_id)
if km.user_id != current_user.id:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="You do not have permission to remove this knowledge.",
)
delete_response = await knowledge_service.remove_knowledge(km)
return delete_response
except KnowledgeNotFoundException as e:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail=f"{e.message}"
)
except KnowledgeDeleteError:
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR)

View File

@ -16,8 +16,16 @@ class CreateKnowledgeProperties(BaseModel):
file_size: Optional[int] = None file_size: Optional[int] = None
file_sha1: Optional[str] = None file_sha1: Optional[str] = None
metadata: Optional[Dict[str, str]] = None metadata: Optional[Dict[str, str]] = None
is_folder: bool = False
parent_id: Optional[UUID] = None
def dict(self, *args, **kwargs):
knowledge_dict = super().dict(*args, **kwargs) class AddKnowledge(BaseModel):
knowledge_dict["brain_id"] = str(knowledge_dict.get("brain_id")) file_name: Optional[str] = None
return knowledge_dict url: Optional[str] = None
extension: str = ".txt"
source: str = "local"
source_link: Optional[str] = None
metadata: Optional[Dict[str, str]] = None
is_folder: bool = False
parent_id: Optional[UUID] = None

View File

@ -4,6 +4,6 @@ from pydantic import BaseModel
class DeleteKnowledgeResponse(BaseModel): class DeleteKnowledgeResponse(BaseModel):
file_name: str file_name: str | None = None
status: str = "delete" status: str = "DELETED"
knowledge_id: UUID knowledge_id: UUID

View File

@ -1,5 +1,6 @@
from datetime import datetime from datetime import datetime
from typing import Dict, List, Optional from enum import Enum
from typing import Any, Dict, List, Optional
from uuid import UUID from uuid import UUID
from pydantic import BaseModel from pydantic import BaseModel
@ -12,20 +13,44 @@ from sqlmodel import Field, Relationship, SQLModel
from quivr_api.modules.knowledge.entity.knowledge_brain import KnowledgeBrain from quivr_api.modules.knowledge.entity.knowledge_brain import KnowledgeBrain
class KnowledgeSource(str, Enum):
LOCAL = "local"
WEB = "web"
GDRIVE = "google drive"
DROPBOX = "dropbox"
SHAREPOINT = "sharepoint"
class Knowledge(BaseModel): class Knowledge(BaseModel):
id: UUID id: UUID
file_size: int = 0
status: KnowledgeStatus
file_name: Optional[str] = None file_name: Optional[str] = None
url: Optional[str] = None url: Optional[str] = None
extension: str = ".txt" extension: str = ".txt"
status: str is_folder: bool = False
updated_at: datetime
created_at: datetime
source: Optional[str] = None source: Optional[str] = None
source_link: Optional[str] = None source_link: Optional[str] = None
file_size: Optional[int] = None
file_sha1: Optional[str] = None file_sha1: Optional[str] = None
updated_at: Optional[datetime] = None
created_at: Optional[datetime] = None
metadata: Optional[Dict[str, str]] = None metadata: Optional[Dict[str, str]] = None
brain_ids: list[UUID] user_id: UUID
brains: List[Dict[str, Any]]
parent: Optional["Knowledge"]
children: Optional[list["Knowledge"]]
class KnowledgeUpdate(BaseModel):
file_name: Optional[str] = None
status: Optional[KnowledgeStatus] = None
url: Optional[str] = None
file_sha1: Optional[str] = None
extension: Optional[str] = None
parent_id: Optional[UUID] = None
source: Optional[str] = None
source_link: Optional[str] = None
metadata: Optional[Dict[str, str]] = None
class KnowledgeDB(AsyncAttrs, SQLModel, table=True): class KnowledgeDB(AsyncAttrs, SQLModel, table=True):
@ -49,13 +74,6 @@ class KnowledgeDB(AsyncAttrs, SQLModel, table=True):
file_sha1: Optional[str] = Field( file_sha1: Optional[str] = Field(
max_length=40 max_length=40
) # FIXME: Should not be optional @chloedia ) # FIXME: Should not be optional @chloedia
updated_at: datetime | None = Field(
default=None,
sa_column=Column(
TIMESTAMP(timezone=False),
server_default=text("CURRENT_TIMESTAMP"),
),
)
created_at: datetime | None = Field( created_at: datetime | None = Field(
default=None, default=None,
sa_column=Column( sa_column=Column(
@ -63,9 +81,18 @@ class KnowledgeDB(AsyncAttrs, SQLModel, table=True):
server_default=text("CURRENT_TIMESTAMP"), server_default=text("CURRENT_TIMESTAMP"),
), ),
) )
updated_at: datetime | None = Field(
default=None,
sa_column=Column(
TIMESTAMP(timezone=False),
server_default=text("CURRENT_TIMESTAMP"),
onupdate=datetime.utcnow,
),
)
metadata_: Optional[Dict[str, str]] = Field( metadata_: Optional[Dict[str, str]] = Field(
default=None, sa_column=Column("metadata", JSON) default=None, sa_column=Column("metadata", JSON)
) )
is_folder: bool = Field(default=False)
user_id: UUID = Field(foreign_key="users.id", nullable=False) user_id: UUID = Field(foreign_key="users.id", nullable=False)
brains: List["Brain"] = Relationship( brains: List["Brain"] = Relationship(
back_populates="knowledges", back_populates="knowledges",
@ -73,10 +100,35 @@ class KnowledgeDB(AsyncAttrs, SQLModel, table=True):
sa_relationship_kwargs={"lazy": "select"}, sa_relationship_kwargs={"lazy": "select"},
) )
async def to_dto(self) -> Knowledge: parent_id: UUID | None = Field(
default=None, foreign_key="knowledge.id", ondelete="CASCADE"
)
parent: Optional["KnowledgeDB"] = Relationship(
back_populates="children",
sa_relationship_kwargs={"remote_side": "KnowledgeDB.id"},
)
children: list["KnowledgeDB"] = Relationship(
back_populates="parent",
sa_relationship_kwargs={
"cascade": "all, delete-orphan",
},
)
# TODO: nested folder search
async def to_dto(self, get_children: bool = True) -> Knowledge:
assert (
self.updated_at
), "knowledge should be inserted before transforming to dto"
assert (
self.created_at
), "knowledge should be inserted before transforming to dto"
brains = await self.awaitable_attrs.brains brains = await self.awaitable_attrs.brains
size = self.file_size if self.file_size else 0 children: list[KnowledgeDB] = (
sha1 = self.file_sha1 if self.file_sha1 else "" await self.awaitable_attrs.children if get_children else []
)
parent = await self.awaitable_attrs.parent
parent = await parent.to_dto(get_children=False) if parent else None
return Knowledge( return Knowledge(
id=self.id, # type: ignore id=self.id, # type: ignore
file_name=self.file_name, file_name=self.file_name,
@ -85,10 +137,14 @@ class KnowledgeDB(AsyncAttrs, SQLModel, table=True):
status=KnowledgeStatus(self.status), status=KnowledgeStatus(self.status),
source=self.source, source=self.source,
source_link=self.source_link, source_link=self.source_link,
file_size=size, is_folder=self.is_folder,
file_sha1=sha1, file_size=self.file_size or 0,
file_sha1=self.file_sha1,
updated_at=self.updated_at, updated_at=self.updated_at,
created_at=self.created_at, created_at=self.created_at,
metadata=self.metadata_, # type: ignore metadata=self.metadata_, # type: ignore
brain_ids=[brain.brain_id for brain in brains], brains=[b.model_dump() for b in brains],
parent=parent,
children=[await c.to_dto(get_children=False) for c in children],
user_id=self.user_id,
) )

View File

@ -1,9 +1,10 @@
from typing import Sequence from typing import Any, Sequence
from uuid import UUID from uuid import UUID
from fastapi import HTTPException from fastapi import HTTPException
from quivr_core.models import KnowledgeStatus from quivr_core.models import KnowledgeStatus
from sqlalchemy.exc import IntegrityError, NoResultFound from sqlalchemy.exc import IntegrityError, NoResultFound
from sqlalchemy.orm import joinedload
from sqlmodel import select, text from sqlmodel import select, text
from sqlmodel.ext.asyncio.session import AsyncSession from sqlmodel.ext.asyncio.session import AsyncSession
@ -11,7 +12,15 @@ from quivr_api.logger import get_logger
from quivr_api.modules.brain.entity.brain_entity import Brain from quivr_api.modules.brain.entity.brain_entity import Brain
from quivr_api.modules.dependencies import BaseRepository, get_supabase_client from quivr_api.modules.dependencies import BaseRepository, get_supabase_client
from quivr_api.modules.knowledge.dto.outputs import DeleteKnowledgeResponse from quivr_api.modules.knowledge.dto.outputs import DeleteKnowledgeResponse
from quivr_api.modules.knowledge.entity.knowledge import KnowledgeDB from quivr_api.modules.knowledge.entity.knowledge import (
Knowledge,
KnowledgeDB,
KnowledgeUpdate,
)
from quivr_api.modules.knowledge.service.knowledge_exceptions import (
KnowledgeNotFoundException,
KnowledgeUpdateError,
)
logger = get_logger(__name__) logger = get_logger(__name__)
@ -22,7 +31,43 @@ class KnowledgeRepository(BaseRepository):
supabase_client = get_supabase_client() supabase_client = get_supabase_client()
self.db = supabase_client self.db = supabase_client
async def insert_knowledge( async def create_knowledge(self, knowledge: KnowledgeDB) -> KnowledgeDB:
try:
self.session.add(knowledge)
await self.session.commit()
await self.session.refresh(knowledge)
except IntegrityError:
await self.session.rollback()
raise
except Exception:
await self.session.rollback()
raise
return knowledge
async def update_knowledge(
self,
knowledge: KnowledgeDB,
payload: Knowledge | KnowledgeUpdate | dict[str, Any],
) -> KnowledgeDB:
try:
logger.debug(f"updating {knowledge.id} with payload {payload}")
if isinstance(payload, dict):
update_data = payload
else:
update_data = payload.model_dump(exclude_unset=True)
for field in update_data:
setattr(knowledge, field, update_data[field])
self.session.add(knowledge)
await self.session.commit()
await self.session.refresh(knowledge)
return knowledge
except IntegrityError as e:
await self.session.rollback()
logger.error(f"Error updating knowledge {e}")
raise KnowledgeUpdateError
async def insert_knowledge_brain(
self, knowledge: KnowledgeDB, brain_id: UUID self, knowledge: KnowledgeDB, brain_id: UUID
) -> KnowledgeDB: ) -> KnowledgeDB:
logger.debug(f"Inserting knowledge {knowledge}") logger.debug(f"Inserting knowledge {knowledge}")
@ -69,6 +114,14 @@ class KnowledgeRepository(BaseRepository):
await self.session.refresh(knowledge) await self.session.refresh(knowledge)
return knowledge return knowledge
async def remove_knowledge(self, knowledge: KnowledgeDB) -> DeleteKnowledgeResponse:
assert knowledge.id
await self.session.delete(knowledge)
await self.session.commit()
return DeleteKnowledgeResponse(
status="deleted", knowledge_id=knowledge.id, file_name=knowledge.file_name
)
async def remove_knowledge_by_id( async def remove_knowledge_by_id(
self, knowledge_id: UUID self, knowledge_id: UUID
) -> DeleteKnowledgeResponse: ) -> DeleteKnowledgeResponse:
@ -126,14 +179,70 @@ class KnowledgeRepository(BaseRepository):
return knowledge return knowledge
async def get_knowledge_by_id(self, knowledge_id: UUID) -> KnowledgeDB: async def get_all_children(self, parent_id: UUID) -> list[KnowledgeDB]:
query = select(KnowledgeDB).where(KnowledgeDB.id == knowledge_id) query = text("""
WITH RECURSIVE knowledge_tree AS (
SELECT *
FROM knowledge
WHERE parent_id = :parent_id
UNION ALL
SELECT k.*
FROM knowledge k
JOIN knowledge_tree kt ON k.parent_id = kt.id
)
SELECT * FROM knowledge_tree
""")
result = await self.session.execute(query, params={"parent_id": parent_id})
rows = result.fetchall()
knowledge_list = []
for row in rows:
knowledge = KnowledgeDB(
id=row.id,
parent_id=row.parent_id,
file_name=row.file_name,
url=row.url,
extension=row.extension,
status=row.status,
source=row.source,
source_link=row.source_link,
file_size=row.file_size,
file_sha1=row.file_sha1,
created_at=row.created_at,
updated_at=row.updated_at,
metadata_=row.metadata,
is_folder=row.is_folder,
user_id=row.user_id,
)
knowledge_list.append(knowledge)
return knowledge_list
async def get_root_knowledge_user(self, user_id: UUID) -> list[KnowledgeDB]:
query = (
select(KnowledgeDB)
.where(KnowledgeDB.parent_id.is_(None)) # type: ignore
.where(KnowledgeDB.user_id == user_id)
.options(joinedload(KnowledgeDB.parent), joinedload(KnowledgeDB.children)) # type: ignore
)
result = await self.session.exec(query)
kms = result.unique().all()
return list(kms)
async def get_knowledge_by_id(
self, knowledge_id: UUID, user_id: UUID | None = None
) -> KnowledgeDB:
query = (
select(KnowledgeDB)
.where(KnowledgeDB.id == knowledge_id)
.options(joinedload(KnowledgeDB.parent), joinedload(KnowledgeDB.children)) # type: ignore
)
if user_id:
query = query.where(KnowledgeDB.user_id == user_id)
result = await self.session.exec(query) result = await self.session.exec(query)
knowledge = result.first() knowledge = result.first()
if not knowledge: if not knowledge:
raise NoResultFound("Knowledge not found") raise KnowledgeNotFoundException("Knowledge not found")
return knowledge return knowledge
async def get_brain_by_id(self, brain_id: UUID) -> Brain: async def get_brain_by_id(self, brain_id: UUID) -> Brain:

View File

@ -1,29 +1,87 @@
import mimetypes
from io import BufferedReader, FileIO
from quivr_api.logger import get_logger from quivr_api.logger import get_logger
from quivr_api.modules.dependencies import get_supabase_client from quivr_api.modules.dependencies import get_supabase_async_client
from quivr_api.modules.knowledge.entity.knowledge import KnowledgeDB
from quivr_api.modules.knowledge.repository.storage_interface import StorageInterface from quivr_api.modules.knowledge.repository.storage_interface import StorageInterface
logger = get_logger(__name__) logger = get_logger(__name__)
class Storage(StorageInterface): class SupabaseS3Storage(StorageInterface):
def __init__(self): def __init__(self):
supabase_client = get_supabase_client() self.client = None
self.db = supabase_client
def upload_file(self, file_name: str): async def _set_client(self):
""" if self.client is None:
Upload file to storage self.client = await get_supabase_async_client()
"""
self.db.storage.from_("quivr").download(file_name)
def remove_file(self, file_name: str): def get_storage_path(
self,
knowledge: KnowledgeDB,
) -> str:
if knowledge.id is None:
raise ValueError("knowledge should have a valid id")
return str(knowledge.id)
async def upload_file_storage(
self,
knowledge: KnowledgeDB,
knowledge_data: FileIO | BufferedReader | bytes,
upsert: bool = False,
):
await self._set_client()
assert self.client
mime_type = "application/html"
if knowledge.file_name:
guessed_mime_type, _ = mimetypes.guess_type(knowledge.file_name)
mime_type = guessed_mime_type or mime_type
storage_path = self.get_storage_path(knowledge)
logger.info(
f"Uploading file to s3://quivr/{storage_path} using supabase. upsert={upsert}, mimetype={mime_type}"
)
if upsert:
_ = await self.client.storage.from_("quivr").update(
storage_path,
knowledge_data,
file_options={
"content-type": mime_type,
"upsert": "true",
"cache-control": "3600",
},
)
return storage_path
else:
# check if file sha1 is already in storage
try:
_ = await self.client.storage.from_("quivr").upload(
storage_path,
knowledge_data,
file_options={
"content-type": mime_type,
"upsert": "false",
"cache-control": "3600",
},
)
return storage_path
except Exception as e:
if "The resource already exists" in str(e) and not upsert:
raise FileExistsError(f"File {storage_path} already exists")
raise e
async def remove_file(self, storage_path: str):
""" """
Remove file from storage Remove file from storage
""" """
await self._set_client()
assert self.client
try: try:
response = self.db.storage.from_("quivr").remove([file_name]) response = await self.client.storage.from_("quivr").remove([storage_path])
return response return response
except Exception as e: except Exception as e:
logger.error(e) logger.error(e)
# raise e

View File

@ -1,10 +1,26 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from io import BufferedReader, FileIO
from quivr_api.modules.knowledge.entity.knowledge import KnowledgeDB
class StorageInterface(ABC): class StorageInterface(ABC):
@abstractmethod @abstractmethod
def remove_file(self, file_name: str): def get_storage_path(
""" self,
Remove file from storage knowledge: KnowledgeDB,
""" ) -> str:
pass
@abstractmethod
async def upload_file_storage(
self,
knowledge: KnowledgeDB,
knowledge_data: FileIO | BufferedReader | bytes,
upsert: bool = False,
):
pass
@abstractmethod
async def remove_file(self, storage_path: str):
pass pass

View File

@ -0,0 +1,34 @@
class KnowledgeException(Exception):
def __init__(self, message="A knowledge-related error occurred"):
self.message = message
super().__init__(self.message)
class UploadError(KnowledgeException):
def __init__(self, message="An error occurred while uploading"):
super().__init__(message)
class KnowledgeCreationError(KnowledgeException):
def __init__(self, message="An error occurred while creating"):
super().__init__(message)
class KnowledgeUpdateError(KnowledgeException):
def __init__(self, message="An error occurred while updating"):
super().__init__(message)
class KnowledgeDeleteError(KnowledgeException):
def __init__(self, message="An error occurred while deleting"):
super().__init__(message)
class KnowledgeForbiddenAccess(KnowledgeException):
def __init__(self, message="You do not have permission to access this knowledge."):
super().__init__(message)
class KnowledgeNotFoundException(KnowledgeException):
def __init__(self, message="The requested knowledge was not found"):
super().__init__(message)

View File

@ -1,18 +1,33 @@
from typing import List import asyncio
import io
from typing import Any, List
from uuid import UUID from uuid import UUID
from fastapi import UploadFile
from quivr_core.models import KnowledgeStatus from quivr_core.models import KnowledgeStatus
from sqlalchemy.exc import NoResultFound from sqlalchemy.exc import NoResultFound
from quivr_api.logger import get_logger from quivr_api.logger import get_logger
from quivr_api.modules.dependencies import BaseService from quivr_api.modules.dependencies import BaseService
from quivr_api.modules.knowledge.dto.inputs import ( from quivr_api.modules.knowledge.dto.inputs import (
AddKnowledge,
CreateKnowledgeProperties, CreateKnowledgeProperties,
) )
from quivr_api.modules.knowledge.dto.outputs import DeleteKnowledgeResponse from quivr_api.modules.knowledge.dto.outputs import DeleteKnowledgeResponse
from quivr_api.modules.knowledge.entity.knowledge import Knowledge, KnowledgeDB from quivr_api.modules.knowledge.entity.knowledge import (
Knowledge,
KnowledgeDB,
KnowledgeSource,
KnowledgeUpdate,
)
from quivr_api.modules.knowledge.repository.knowledges import KnowledgeRepository from quivr_api.modules.knowledge.repository.knowledges import KnowledgeRepository
from quivr_api.modules.knowledge.repository.storage import Storage from quivr_api.modules.knowledge.repository.storage import SupabaseS3Storage
from quivr_api.modules.knowledge.repository.storage_interface import StorageInterface
from quivr_api.modules.knowledge.service.knowledge_exceptions import (
KnowledgeDeleteError,
KnowledgeForbiddenAccess,
UploadError,
)
from quivr_api.modules.sync.entity.sync_models import ( from quivr_api.modules.sync.entity.sync_models import (
DBSyncFile, DBSyncFile,
DownloadedSyncFile, DownloadedSyncFile,
@ -26,9 +41,13 @@ logger = get_logger(__name__)
class KnowledgeService(BaseService[KnowledgeRepository]): class KnowledgeService(BaseService[KnowledgeRepository]):
repository_cls = KnowledgeRepository repository_cls = KnowledgeRepository
def __init__(self, repository: KnowledgeRepository): def __init__(
self,
repository: KnowledgeRepository,
storage: StorageInterface = SupabaseS3Storage(),
):
self.repository = repository self.repository = repository
self.storage = Storage() self.storage = storage
async def get_knowledge_sync(self, sync_id: int) -> Knowledge: async def get_knowledge_sync(self, sync_id: int) -> Knowledge:
km = await self.repository.get_knowledge_by_sync_id(sync_id) km = await self.repository.get_knowledge_by_sync_id(sync_id)
@ -54,19 +73,37 @@ class KnowledgeService(BaseService[KnowledgeRepository]):
except NoResultFound: except NoResultFound:
raise FileNotFoundError(f"No knowledge for file_name: {file_name}") raise FileNotFoundError(f"No knowledge for file_name: {file_name}")
async def get_knowledge(self, knowledge_id: UUID) -> Knowledge: async def list_knowledge(
inserted_knowledge_db_instance = await self.repository.get_knowledge_by_id( self, knowledge_id: UUID | None, user_id: UUID | None = None
knowledge_id ) -> list[KnowledgeDB]:
if knowledge_id is not None:
km = await self.repository.get_knowledge_by_id(knowledge_id, user_id)
return km.children
else:
if user_id is None:
raise KnowledgeForbiddenAccess(
"can't get root knowledges without user_id"
) )
assert inserted_knowledge_db_instance.id, "Knowledge ID not generated" return await self.repository.get_root_knowledge_user(user_id)
km = await inserted_knowledge_db_instance.to_dto()
return km
async def get_knowledge(
self, knowledge_id: UUID, user_id: UUID | None = None
) -> KnowledgeDB:
return await self.repository.get_knowledge_by_id(knowledge_id, user_id)
async def update_knowledge(
self,
knowledge: KnowledgeDB,
payload: Knowledge | KnowledgeUpdate | dict[str, Any],
):
return await self.repository.update_knowledge(knowledge, payload)
# TODO: Remove all of this
# TODO (@aminediro): Replace with ON CONFLICT smarter query... # TODO (@aminediro): Replace with ON CONFLICT smarter query...
# there is a chance of race condition but for now we let it crash in worker # there is a chance of race condition but for now we let it crash in worker
# the tasks will be dealt with on retry # the tasks will be dealt with on retry
async def update_sha1_conflict( async def update_sha1_conflict(
self, knowledge: Knowledge, brain_id: UUID, file_sha1: str self, knowledge: KnowledgeDB, brain_id: UUID, file_sha1: str
) -> bool: ) -> bool:
assert knowledge.id assert knowledge.id
knowledge.file_sha1 = file_sha1 knowledge.file_sha1 = file_sha1
@ -89,12 +126,12 @@ class KnowledgeService(BaseService[KnowledgeRepository]):
) )
else: else:
await self.repository.link_to_brain(existing_knowledge, brain_id) await self.repository.link_to_brain(existing_knowledge, brain_id)
await self.remove_knowledge(brain_id, knowledge.id) await self.remove_knowledge_brain(brain_id, knowledge.id)
return False return False
else: else:
logger.debug(f"Removing previous errored file {existing_knowledge.id}") logger.debug(f"Removing previous errored file {existing_knowledge.id}")
assert existing_knowledge.id assert existing_knowledge.id
await self.remove_knowledge(brain_id, existing_knowledge.id) await self.remove_knowledge_brain(brain_id, existing_knowledge.id)
await self.update_file_sha1_knowledge(knowledge.id, knowledge.file_sha1) await self.update_file_sha1_knowledge(knowledge.id, knowledge.file_sha1)
return True return True
except NoResultFound: except NoResultFound:
@ -104,7 +141,47 @@ class KnowledgeService(BaseService[KnowledgeRepository]):
await self.update_file_sha1_knowledge(knowledge.id, knowledge.file_sha1) await self.update_file_sha1_knowledge(knowledge.id, knowledge.file_sha1)
return True return True
async def insert_knowledge( async def create_knowledge(
self,
user_id: UUID,
knowledge_to_add: AddKnowledge,
upload_file: UploadFile | None = None,
) -> KnowledgeDB:
knowledgedb = KnowledgeDB(
user_id=user_id,
file_name=knowledge_to_add.file_name,
is_folder=knowledge_to_add.is_folder,
url=knowledge_to_add.url,
extension=knowledge_to_add.extension,
source=knowledge_to_add.source,
source_link=knowledge_to_add.source_link,
file_size=upload_file.size if upload_file else 0,
metadata_=knowledge_to_add.metadata, # type: ignore
status=KnowledgeStatus.RESERVED,
parent_id=knowledge_to_add.parent_id,
)
knowledge_db = await self.repository.create_knowledge(knowledgedb)
try:
if knowledgedb.source == KnowledgeSource.LOCAL and upload_file:
# NOTE(@aminediro): Unnecessary mem buffer because supabase doesnt accept FileIO..
buff_reader = io.BufferedReader(upload_file.file) # type: ignore
storage_path = await self.storage.upload_file_storage(
knowledgedb, buff_reader
)
knowledgedb.source_link = storage_path
knowledge_db = await self.repository.update_knowledge(
knowledge_db,
KnowledgeUpdate(status=KnowledgeStatus.UPLOADED), # type: ignore
)
return knowledge_db
except Exception as e:
logger.exception(
f"Error uploading knowledge {knowledgedb.id} to storage : {e}"
)
await self.repository.remove_knowledge(knowledge=knowledge_db)
raise UploadError()
async def insert_knowledge_brain(
self, self,
user_id: UUID, user_id: UUID,
knowledge_to_add: CreateKnowledgeProperties, # FIXME: (later) @Amine brain id should not be in CreateKnowledgeProperties but since storage is brain_id/file_name knowledge_to_add: CreateKnowledgeProperties, # FIXME: (later) @Amine brain id should not be in CreateKnowledgeProperties but since storage is brain_id/file_name
@ -122,7 +199,7 @@ class KnowledgeService(BaseService[KnowledgeRepository]):
user_id=user_id, user_id=user_id,
) )
knowledge_db = await self.repository.insert_knowledge( knowledge_db = await self.repository.insert_knowledge_brain(
knowledge, brain_id=knowledge_to_add.brain_id knowledge, brain_id=knowledge_to_add.brain_id
) )
@ -150,7 +227,7 @@ class KnowledgeService(BaseService[KnowledgeRepository]):
assert isinstance(knowledge.file_name, str), "file_name should be a string" assert isinstance(knowledge.file_name, str), "file_name should be a string"
file_name_with_brain_id = f"{brain_id}/{knowledge.file_name}" file_name_with_brain_id = f"{brain_id}/{knowledge.file_name}"
try: try:
self.storage.remove_file(file_name_with_brain_id) await self.storage.remove_file(file_name_with_brain_id)
except Exception as e: except Exception as e:
logger.error( logger.error(
f"Error while removing file {file_name_with_brain_id}: {e}" f"Error while removing file {file_name_with_brain_id}: {e}"
@ -161,29 +238,52 @@ class KnowledgeService(BaseService[KnowledgeRepository]):
async def update_file_sha1_knowledge(self, knowledge_id: UUID, file_sha1: str): async def update_file_sha1_knowledge(self, knowledge_id: UUID, file_sha1: str):
return await self.repository.update_file_sha1_knowledge(knowledge_id, file_sha1) return await self.repository.update_file_sha1_knowledge(knowledge_id, file_sha1)
async def remove_knowledge( async def remove_knowledge(self, knowledge: KnowledgeDB) -> DeleteKnowledgeResponse:
assert knowledge.id
try:
# TODO:
# - Notion folders are special, they are themselves files and should be removed from storage
children = await self.repository.get_all_children(knowledge.id)
km_paths = [
self.storage.get_storage_path(k) for k in children if not k.is_folder
]
if not knowledge.is_folder:
km_paths.append(self.storage.get_storage_path(knowledge))
# recursively deletes files
deleted_km = await self.repository.remove_knowledge(knowledge)
await asyncio.gather(*[self.storage.remove_file(p) for p in km_paths])
return deleted_km
except Exception as e:
logger.error(f"Error while remove knowledge : {e}")
raise KnowledgeDeleteError
async def remove_knowledge_brain(
self, self,
brain_id: UUID, brain_id: UUID,
knowledge_id: UUID, # FIXME: @amine when name in storage change no need for brain id knowledge_id: UUID, # FIXME: @amine when name in storage change no need for brain id
) -> DeleteKnowledgeResponse: ) -> DeleteKnowledgeResponse:
# TODO: fix KMS # TODO: fix KMS
# REDO ALL THIS # REDO ALL THIS
knowledge = await self.get_knowledge(knowledge_id) knowledge = await self.repository.get_knowledge_by_id(knowledge_id)
if len(knowledge.brain_ids) > 1: km_brains = await knowledge.awaitable_attrs.brains
if len(km_brains) > 1:
km = await self.repository.remove_knowledge_from_brain( km = await self.repository.remove_knowledge_from_brain(
knowledge_id, brain_id knowledge_id, brain_id
) )
assert km.id
return DeleteKnowledgeResponse(file_name=km.file_name, knowledge_id=km.id) return DeleteKnowledgeResponse(file_name=km.file_name, knowledge_id=km.id)
else: else:
message = await self.repository.remove_knowledge_by_id(knowledge_id) message = await self.repository.remove_knowledge_by_id(knowledge_id)
file_name_with_brain_id = f"{brain_id}/{message.file_name}" file_name_with_brain_id = f"{brain_id}/{message.file_name}"
try: try:
self.storage.remove_file(file_name_with_brain_id) await self.storage.remove_file(file_name_with_brain_id)
except Exception as e: except Exception as e:
logger.error( logger.error(
f"Error while removing file {file_name_with_brain_id}: {e}" f"Error while removing file {file_name_with_brain_id}: {e}"
) )
return message return message
async def remove_all_knowledges_from_brain(self, brain_id: UUID) -> None: async def remove_all_knowledges_from_brain(self, brain_id: UUID) -> None:
@ -210,7 +310,7 @@ class KnowledgeService(BaseService[KnowledgeRepository]):
# TODO: THIS IS A HACK!! Remove all of this # TODO: THIS IS A HACK!! Remove all of this
if prev_sync_file: if prev_sync_file:
prev_knowledge = await self.get_knowledge_sync(sync_id=prev_sync_file.id) prev_knowledge = await self.get_knowledge_sync(sync_id=prev_sync_file.id)
if len(prev_knowledge.brain_ids) > 1: if len(prev_knowledge.brains) > 1:
await self.repository.remove_knowledge_from_brain( await self.repository.remove_knowledge_from_brain(
prev_knowledge.id, brain_id prev_knowledge.id, brain_id
) )
@ -231,7 +331,7 @@ class KnowledgeService(BaseService[KnowledgeRepository]):
file_sha1=None, file_sha1=None,
metadata={"sync_file_id": str(sync_id)}, metadata={"sync_file_id": str(sync_id)},
) )
added_knowledge = await self.insert_knowledge( added_knowledge = await self.insert_knowledge_brain(
knowledge_to_add=knowledge_to_add, user_id=user_id knowledge_to_add=knowledge_to_add, user_id=user_id
) )
return added_knowledge return added_knowledge

View File

@ -0,0 +1,67 @@
from io import BufferedReader, FileIO
from quivr_api.modules.knowledge.entity.knowledge import Knowledge, KnowledgeDB
from quivr_api.modules.knowledge.repository.storage_interface import StorageInterface
class ErrorStorage(StorageInterface):
async def upload_file_storage(
self,
knowledge: KnowledgeDB,
knowledge_data: FileIO | BufferedReader | bytes,
upsert: bool = False,
):
raise SystemError
def get_storage_path(
self,
knowledge: KnowledgeDB | Knowledge,
) -> str:
if knowledge.id is None:
raise ValueError("knowledge should have a valid id")
return str(knowledge.id)
async def remove_file(self, storage_path: str):
raise SystemError
class FakeStorage(StorageInterface):
def __init__(self):
self.storage = {}
def get_storage_path(
self,
knowledge: KnowledgeDB | Knowledge,
) -> str:
if knowledge.id is None:
raise ValueError("knowledge should have a valid id")
return str(knowledge.id)
async def upload_file_storage(
self,
knowledge: KnowledgeDB,
knowledge_data: FileIO | BufferedReader | bytes,
upsert: bool = False,
):
storage_path = f"{knowledge.id}"
if not upsert and storage_path in self.storage:
raise ValueError(f"File already exists at {storage_path}")
self.storage[storage_path] = knowledge_data
return storage_path
async def remove_file(self, storage_path: str):
if storage_path not in self.storage:
raise FileNotFoundError(f"File not found at {storage_path}")
del self.storage[storage_path]
# Additional helper methods for testing
def get_file(self, storage_path: str) -> FileIO | BufferedReader | bytes:
if storage_path not in self.storage:
raise FileNotFoundError(f"File not found at {storage_path}")
return self.storage[storage_path]
def knowledge_exists(self, knowledge: KnowledgeDB | Knowledge) -> bool:
return self.get_storage_path(knowledge) in self.storage
def clear_storage(self):
self.storage.clear()

View File

@ -0,0 +1,74 @@
import json
import pytest
import pytest_asyncio
from httpx import ASGITransport, AsyncClient
from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession
from quivr_api.main import app
from quivr_api.middlewares.auth.auth_bearer import get_current_user
from quivr_api.modules.knowledge.controller.knowledge_routes import get_km_service
from quivr_api.modules.knowledge.repository.knowledges import KnowledgeRepository
from quivr_api.modules.knowledge.service.knowledge_service import KnowledgeService
from quivr_api.modules.knowledge.tests.conftest import FakeStorage
from quivr_api.modules.user.entity.user_identity import User, UserIdentity
@pytest_asyncio.fixture(scope="function")
async def user(session: AsyncSession) -> User:
user_1 = (
await session.exec(select(User).where(User.email == "admin@quivr.app"))
).one()
assert user_1.id
return user_1
@pytest_asyncio.fixture(scope="function")
async def test_client(session: AsyncSession, user: User):
def default_current_user() -> UserIdentity:
assert user.id
return UserIdentity(email=user.email, id=user.id)
async def test_service():
storage = FakeStorage()
repository = KnowledgeRepository(session)
return KnowledgeService(repository, storage)
app.dependency_overrides[get_current_user] = default_current_user
app.dependency_overrides[get_km_service] = test_service
# app.dependency_overrides[get_async_session] = lambda: session
async with AsyncClient(
transport=ASGITransport(app=app), base_url="http://test"
) as ac:
yield ac
app.dependency_overrides = {}
@pytest.mark.asyncio(loop_scope="session")
async def test_post_knowledge(test_client: AsyncClient):
km_data = {
"file_name": "test_file.txt",
"source": "local",
"is_folder": False,
"parent_id": None,
}
multipart_data = {
"knowledge_data": (None, json.dumps(km_data), "application/json"),
"file": ("test_file.txt", b"Test file content", "application/octet-stream"),
}
response = await test_client.post(
"/knowledge/",
files=multipart_data,
)
assert response.status_code == 200
@pytest.mark.asyncio(loop_scope="session")
async def test_add_knowledge_invalid_input(test_client):
response = await test_client.post("/knowledge/", files={})
assert response.status_code == 422

View File

@ -0,0 +1,229 @@
from typing import List, Tuple
from uuid import uuid4
import pytest
import pytest_asyncio
from quivr_core.models import KnowledgeStatus
from sqlmodel import select, text
from sqlmodel.ext.asyncio.session import AsyncSession
from quivr_api.modules.brain.entity.brain_entity import Brain, BrainType
from quivr_api.modules.knowledge.entity.knowledge import KnowledgeDB
from quivr_api.modules.user.entity.user_identity import User
TestData = Tuple[Brain, List[KnowledgeDB]]
@pytest_asyncio.fixture(scope="function")
async def other_user(session: AsyncSession):
sql = text(
"""
INSERT INTO "auth"."users" ("instance_id", "id", "aud", "role", "email", "encrypted_password", "email_confirmed_at", "invited_at", "confirmation_token", "confirmation_sent_at", "recovery_token", "recovery_sent_at", "email_change_token_new", "email_change", "email_change_sent_at", "last_sign_in_at", "raw_app_meta_data", "raw_user_meta_data", "is_super_admin", "created_at", "updated_at", "phone", "phone_confirmed_at", "phone_change", "phone_change_token", "phone_change_sent_at", "email_change_token_current", "email_change_confirm_status", "banned_until", "reauthentication_token", "reauthentication_sent_at", "is_sso_user", "deleted_at") VALUES
('00000000-0000-0000-0000-000000000000', :id , 'authenticated', 'authenticated', 'other@quivr.app', '$2a$10$vwKX0eMLlrOZvxQEA3Vl4e5V4/hOuxPjGYn9QK1yqeaZxa.42Uhze', '2024-01-22 22:27:00.166861+00', NULL, '', NULL, 'e91d41043ca2c83c3be5a6ee7a4abc8a4f4fb1afc0a8453c502af931', '2024-03-05 16:22:13.780421+00', '', '', NULL, '2024-03-30 23:21:12.077887+00', '{"provider": "email", "providers": ["email"]}', '{}', NULL, '2024-01-22 22:27:00.158026+00', '2024-04-01 17:40:15.332205+00', NULL, NULL, '', '', NULL, '', 0, NULL, '', NULL, false, NULL);
"""
)
await session.execute(sql, params={"id": uuid4()})
other_user = (
await session.exec(select(User).where(User.email == "other@quivr.app"))
).one()
return other_user
@pytest_asyncio.fixture(scope="function")
async def user(session):
user_1 = (
await session.exec(select(User).where(User.email == "admin@quivr.app"))
).one()
return user_1
@pytest_asyncio.fixture(scope="function")
async def brain(session):
brain_1 = Brain(
name="test_brain",
description="this is a test brain",
brain_type=BrainType.integration,
)
session.add(brain_1)
await session.commit()
return brain_1
@pytest_asyncio.fixture(scope="function")
async def folder(session, user):
folder = KnowledgeDB(
file_name="folder_1",
extension="",
status="UPLOADED",
source="local",
source_link="local",
file_size=4,
file_sha1=None,
brains=[],
children=[],
user_id=user.id,
is_folder=True,
)
session.add(folder)
await session.commit()
await session.refresh(folder)
return folder
@pytest.mark.asyncio(loop_scope="session")
async def test_knowledge_default_file(session, folder, user):
km = KnowledgeDB(
file_name="test_file_1.txt",
extension=".txt",
status="UPLOADED",
source="test_source",
source_link="test_source_link",
file_size=100,
file_sha1="test_sha1",
brains=[],
user_id=user.id,
parent_id=folder.id,
)
session.add(km)
await session.commit()
await session.refresh(km)
assert not km.is_folder
@pytest.mark.asyncio(loop_scope="session")
async def test_knowledge_parent(session: AsyncSession, user: User):
assert user.id
km = KnowledgeDB(
file_name="test_file_1.txt",
extension=".txt",
status="UPLOADED",
source="test_source",
source_link="test_source_link",
file_size=100,
file_sha1="test_sha1",
brains=[],
user_id=user.id,
)
folder = KnowledgeDB(
file_name="folder_1",
extension="",
is_folder=True,
status="UPLOADED",
source="local",
source_link="local",
file_size=-1,
file_sha1=None,
brains=[],
children=[km],
user_id=user.id,
)
session.add(folder)
await session.commit()
await session.refresh(folder)
await session.refresh(km)
parent = await km.awaitable_attrs.parent
assert km.parent_id == folder.id, "parent_id isn't set to folder id"
assert parent.id == folder.id, "parent_id isn't set to folder id"
assert parent.is_folder
query = select(KnowledgeDB).where(KnowledgeDB.id == folder.id)
folder = (await session.exec(query)).first()
assert folder
children = await folder.awaitable_attrs.children
assert len(children) > 0
assert children[0].id == km.id
@pytest.mark.asyncio(loop_scope="session")
async def test_knowledge_remove_folder_cascade(
session: AsyncSession,
folder: KnowledgeDB,
user,
):
km = KnowledgeDB(
file_name="test_file_1.txt",
extension=".txt",
status="UPLOADED",
source="test_source",
source_link="test_source_link",
file_size=100,
file_sha1="test_sha1",
brains=[],
user_id=user.id,
parent_id=folder.id,
)
session.add(km)
await session.commit()
await session.refresh(km)
# Check all removed
await session.delete(folder)
await session.commit()
statement = select(KnowledgeDB)
results = (await session.exec(statement)).all()
assert results == []
@pytest.mark.asyncio(loop_scope="session")
async def test_knowledge_dto(session, user, brain):
# add folder in brain
folder = KnowledgeDB(
file_name="folder_1",
extension="",
status="UPLOADED",
source="local",
source_link="local",
file_size=4,
file_sha1=None,
brains=[brain],
children=[],
user_id=user.id,
is_folder=True,
)
km = KnowledgeDB(
file_name="test_file_1.txt",
extension=".txt",
status="UPLOADED",
source="test_source",
source_link="test_source_link",
file_size=100,
file_sha1="test_sha1",
user_id=user.id,
brains=[brain],
parent=folder,
)
session.add(km)
session.add(km)
await session.commit()
await session.refresh(km)
km_dto = await km.to_dto()
assert km_dto.file_name == km.file_name
assert km_dto.url == km.url
assert km_dto.extension == km.extension
assert km_dto.status == KnowledgeStatus(km.status)
assert km_dto.source == km.source
assert km_dto.source_link == km.source_link
assert km_dto.is_folder == km.is_folder
assert km_dto.file_size == km.file_size
assert km_dto.file_sha1 == km.file_sha1
assert km_dto.updated_at == km.updated_at
assert km_dto.created_at == km.created_at
assert km_dto.metadata == km.metadata_ # type: ignor
assert km_dto.parent
assert km_dto.parent.id == folder.id
folder_dto = await folder.to_dto()
assert folder_dto.brains[0] == brain.model_dump()
assert folder_dto.children == [await km.to_dto()]

File diff suppressed because it is too large Load Diff

View File

@ -1,450 +0,0 @@
import os
from typing import List, Tuple
from uuid import uuid4
import pytest
import pytest_asyncio
from sqlalchemy.exc import IntegrityError, NoResultFound
from sqlmodel import select, text
from sqlmodel.ext.asyncio.session import AsyncSession
from quivr_api.modules.brain.entity.brain_entity import Brain, BrainType
from quivr_api.modules.knowledge.dto.inputs import KnowledgeStatus
from quivr_api.modules.knowledge.entity.knowledge import KnowledgeDB
from quivr_api.modules.knowledge.entity.knowledge_brain import KnowledgeBrain
from quivr_api.modules.knowledge.repository.knowledges import KnowledgeRepository
from quivr_api.modules.knowledge.service.knowledge_service import KnowledgeService
from quivr_api.modules.upload.service.upload_file import upload_file_storage
from quivr_api.modules.user.entity.user_identity import User
from quivr_api.modules.vector.entity.vector import Vector
pg_database_base_url = "postgres:postgres@localhost:54322/postgres"
TestData = Tuple[Brain, List[KnowledgeDB]]
@pytest_asyncio.fixture(scope="function")
async def other_user(session: AsyncSession):
sql = text(
"""
INSERT INTO "auth"."users" ("instance_id", "id", "aud", "role", "email", "encrypted_password", "email_confirmed_at", "invited_at", "confirmation_token", "confirmation_sent_at", "recovery_token", "recovery_sent_at", "email_change_token_new", "email_change", "email_change_sent_at", "last_sign_in_at", "raw_app_meta_data", "raw_user_meta_data", "is_super_admin", "created_at", "updated_at", "phone", "phone_confirmed_at", "phone_change", "phone_change_token", "phone_change_sent_at", "email_change_token_current", "email_change_confirm_status", "banned_until", "reauthentication_token", "reauthentication_sent_at", "is_sso_user", "deleted_at") VALUES
('00000000-0000-0000-0000-000000000000', :id , 'authenticated', 'authenticated', 'other@quivr.app', '$2a$10$vwKX0eMLlrOZvxQEA3Vl4e5V4/hOuxPjGYn9QK1yqeaZxa.42Uhze', '2024-01-22 22:27:00.166861+00', NULL, '', NULL, 'e91d41043ca2c83c3be5a6ee7a4abc8a4f4fb1afc0a8453c502af931', '2024-03-05 16:22:13.780421+00', '', '', NULL, '2024-03-30 23:21:12.077887+00', '{"provider": "email", "providers": ["email"]}', '{}', NULL, '2024-01-22 22:27:00.158026+00', '2024-04-01 17:40:15.332205+00', NULL, NULL, '', '', NULL, '', 0, NULL, '', NULL, false, NULL);
"""
)
await session.execute(sql, params={"id": uuid4()})
other_user = (
await session.exec(select(User).where(User.email == "other@quivr.app"))
).one()
return other_user
@pytest_asyncio.fixture(scope="function")
async def test_data(session: AsyncSession) -> TestData:
user_1 = (
await session.exec(select(User).where(User.email == "admin@quivr.app"))
).one()
assert user_1.id
# Brain data
brain_1 = Brain(
name="test_brain",
description="this is a test brain",
brain_type=BrainType.integration,
)
knowledge_brain_1 = KnowledgeDB(
file_name="test_file_1.txt",
extension=".txt",
status="UPLOADED",
source="test_source",
source_link="test_source_link",
file_size=100,
file_sha1="test_sha1",
brains=[brain_1],
user_id=user_1.id,
)
knowledge_brain_2 = KnowledgeDB(
file_name="test_file_2.txt",
extension=".txt",
status="UPLOADED",
source="test_source",
source_link="test_source_link",
file_size=100,
file_sha1="test_sha2",
brains=[],
user_id=user_1.id,
)
session.add(brain_1)
session.add(knowledge_brain_1)
session.add(knowledge_brain_2)
await session.commit()
return brain_1, [knowledge_brain_1, knowledge_brain_2]
@pytest.mark.asyncio(loop_scope="session")
async def test_updates_knowledge_status(session: AsyncSession, test_data: TestData):
brain, knowledges = test_data
assert brain.brain_id
assert knowledges[0].id
repo = KnowledgeRepository(session)
await repo.update_status_knowledge(knowledges[0].id, KnowledgeStatus.ERROR)
knowledge = await repo.get_knowledge_by_id(knowledges[0].id)
assert knowledge.status == KnowledgeStatus.ERROR
@pytest.mark.asyncio(loop_scope="session")
async def test_updates_knowledge_status_no_knowledge(
session: AsyncSession, test_data: TestData
):
brain, knowledges = test_data
assert brain.brain_id
assert knowledges[0].id
repo = KnowledgeRepository(session)
with pytest.raises(NoResultFound):
await repo.update_status_knowledge(uuid4(), KnowledgeStatus.UPLOADED)
@pytest.mark.asyncio(loop_scope="session")
async def test_update_knowledge_source_link(session: AsyncSession, test_data: TestData):
brain, knowledges = test_data
assert brain.brain_id
assert knowledges[0].id
repo = KnowledgeRepository(session)
await repo.update_source_link_knowledge(knowledges[0].id, "new_source_link")
knowledge = await repo.get_knowledge_by_id(knowledges[0].id)
assert knowledge.source_link == "new_source_link"
@pytest.mark.asyncio(loop_scope="session")
async def test_remove_knowledge_from_brain(session: AsyncSession, test_data: TestData):
brain, knowledges = test_data
assert brain.brain_id
assert knowledges[0].id
repo = KnowledgeRepository(session)
knowledge = await repo.remove_knowledge_from_brain(knowledges[0].id, brain.brain_id)
assert brain.brain_id not in [
b.brain_id for b in await knowledge.awaitable_attrs.brains
]
@pytest.mark.asyncio(loop_scope="session")
async def test_cascade_remove_knowledge_by_id(
session: AsyncSession, test_data: TestData
):
brain, knowledges = test_data
assert brain.brain_id
assert knowledges[0].id
repo = KnowledgeRepository(session)
await repo.remove_knowledge_by_id(knowledges[0].id)
with pytest.raises(NoResultFound):
await repo.get_knowledge_by_id(knowledges[0].id)
query = select(KnowledgeBrain).where(
KnowledgeBrain.knowledge_id == knowledges[0].id
)
result = await session.exec(query)
knowledge_brain = result.first()
assert knowledge_brain is None
query = select(Vector).where(Vector.knowledge_id == knowledges[0].id)
result = await session.exec(query)
vector = result.first()
assert vector is None
@pytest.mark.asyncio(loop_scope="session")
async def test_remove_all_knowledges_from_brain(
session: AsyncSession, test_data: TestData
):
brain, knowledges = test_data
assert brain.brain_id
# supabase_client = get_supabase_client()
# db = supabase_client
# storage = db.storage.from_("quivr")
# storage.upload(f"{brain.brain_id}/test_file_1", b"test_content")
repo = KnowledgeRepository(session)
service = KnowledgeService(repo)
await repo.remove_all_knowledges_from_brain(brain.brain_id)
knowledges = await service.get_all_knowledge_in_brain(brain.brain_id)
assert len(knowledges) == 0
# response = storage.list(path=f"{brain.brain_id}")
# assert response == []
# FIXME @aminediro &chloedia raise an error when trying to interact with storage UnboundLocalError: cannot access local variable 'response' where it is not associated with a value
@pytest.mark.asyncio(loop_scope="session")
async def test_duplicate_sha1_knowledge_same_user(
session: AsyncSession, test_data: TestData
):
brain, [existing_knowledge, _] = test_data
assert brain.brain_id
assert existing_knowledge.id
assert existing_knowledge.file_sha1
repo = KnowledgeRepository(session)
knowledge = KnowledgeDB(
file_name="test_file_2",
extension="txt",
status="UPLOADED",
source="test_source",
source_link="test_source_link",
file_size=100,
file_sha1=existing_knowledge.file_sha1,
brains=[brain],
user_id=existing_knowledge.user_id,
)
with pytest.raises(IntegrityError): # FIXME: Should raise IntegrityError
await repo.insert_knowledge(knowledge, brain.brain_id)
@pytest.mark.asyncio(loop_scope="session")
async def test_duplicate_sha1_knowledge_diff_user(
session: AsyncSession, test_data: TestData, other_user: User
):
brain, knowledges = test_data
assert other_user.id
assert brain.brain_id
assert knowledges[0].id
repo = KnowledgeRepository(session)
knowledge = KnowledgeDB(
file_name="test_file_2",
extension="txt",
status="UPLOADED",
source="test_source",
source_link="test_source_link",
file_size=100,
file_sha1=knowledges[0].file_sha1,
brains=[brain],
user_id=other_user.id, # random user id
)
result = await repo.insert_knowledge(knowledge, brain.brain_id)
assert result
@pytest.mark.asyncio(loop_scope="session")
async def test_add_knowledge_to_brain(session: AsyncSession, test_data: TestData):
brain, knowledges = test_data
assert brain.brain_id
assert knowledges[1].id
repo = KnowledgeRepository(session)
await repo.link_to_brain(knowledges[1], brain.brain_id)
knowledge = await repo.get_knowledge_by_id(knowledges[1].id)
brains_of_knowledge = [b.brain_id for b in await knowledge.awaitable_attrs.brains]
assert brain.brain_id in brains_of_knowledge
query = select(KnowledgeBrain).where(
KnowledgeBrain.knowledge_id == knowledges[0].id
and KnowledgeBrain.brain_id == brain.brain_id
)
result = await session.exec(query)
knowledge_brain = result.first()
assert knowledge_brain
# Knowledge Service
@pytest.mark.asyncio(loop_scope="session")
async def test_get_knowledge_in_brain(session: AsyncSession, test_data: TestData):
brain, knowledges = test_data
assert brain.brain_id
repo = KnowledgeRepository(session)
service = KnowledgeService(repo)
list_knowledge = await service.get_all_knowledge_in_brain(brain.brain_id)
assert len(list_knowledge) == 1
brains_of_knowledge = [
b.brain_id for b in await knowledges[0].awaitable_attrs.brains
]
assert list_knowledge[0].id == knowledges[0].id
assert list_knowledge[0].file_name == knowledges[0].file_name
assert brain.brain_id in brains_of_knowledge
@pytest.mark.asyncio(loop_scope="session")
async def test_should_process_knowledge_exists(
session: AsyncSession, test_data: TestData
):
brain, [existing_knowledge, _] = test_data
assert brain.brain_id
new = KnowledgeDB(
file_name="new",
extension="txt",
status="PROCESSING",
source="test_source",
source_link="test_source_link",
file_size=100,
file_sha1=None,
brains=[brain],
user_id=existing_knowledge.user_id,
)
session.add(new)
await session.commit()
await session.refresh(new)
incoming_knowledge = await new.to_dto()
repo = KnowledgeRepository(session)
service = KnowledgeService(repo)
assert existing_knowledge.file_sha1
with pytest.raises(FileExistsError):
await service.update_sha1_conflict(
incoming_knowledge, brain.brain_id, file_sha1=existing_knowledge.file_sha1
)
@pytest.mark.asyncio(loop_scope="session")
async def test_should_process_knowledge_link_brain(
session: AsyncSession, test_data: TestData
):
repo = KnowledgeRepository(session)
service = KnowledgeService(repo)
brain, [existing_knowledge, _] = test_data
user_id = existing_knowledge.user_id
assert brain.brain_id
prev = KnowledgeDB(
file_name="prev",
extension=".txt",
status=KnowledgeStatus.UPLOADED,
source="test_source",
source_link="test_source_link",
file_size=100,
file_sha1="test1",
brains=[brain],
user_id=user_id,
)
brain_2 = Brain(
name="test_brain",
description="this is a test brain",
brain_type=BrainType.integration,
)
session.add(brain_2)
session.add(prev)
await session.commit()
await session.refresh(prev)
await session.refresh(brain_2)
assert prev.id
assert brain_2.brain_id
new = KnowledgeDB(
file_name="new",
extension="txt",
status="PROCESSING",
source="test_source",
source_link="test_source_link",
file_size=100,
file_sha1=None,
brains=[brain_2],
user_id=user_id,
)
session.add(new)
await session.commit()
await session.refresh(new)
incoming_knowledge = await new.to_dto()
assert prev.file_sha1
should_process = await service.update_sha1_conflict(
incoming_knowledge, brain_2.brain_id, file_sha1=prev.file_sha1
)
assert not should_process
# Check prev knowledge was linked
assert incoming_knowledge.file_sha1
prev_knowledge = await service.repository.get_knowledge_by_id(prev.id)
prev_brains = await prev_knowledge.awaitable_attrs.brains
assert {b.brain_id for b in prev_brains} == {
brain.brain_id,
brain_2.brain_id,
}
# Check new knowledge was removed
assert new.id
with pytest.raises(NoResultFound):
await service.repository.get_knowledge_by_id(new.id)
@pytest.mark.asyncio(loop_scope="session")
async def test_should_process_knowledge_prev_error(
session: AsyncSession, test_data: TestData
):
repo = KnowledgeRepository(session)
service = KnowledgeService(repo)
brain, [existing_knowledge, _] = test_data
user_id = existing_knowledge.user_id
assert brain.brain_id
prev = KnowledgeDB(
file_name="prev",
extension="txt",
status=KnowledgeStatus.ERROR,
source="test_source",
source_link="test_source_link",
file_size=100,
file_sha1="test1",
brains=[brain],
user_id=user_id,
)
session.add(prev)
await session.commit()
await session.refresh(prev)
assert prev.id
new = KnowledgeDB(
file_name="new",
extension="txt",
status="PROCESSING",
source="test_source",
source_link="test_source_link",
file_size=100,
file_sha1=None,
brains=[brain],
user_id=user_id,
)
session.add(new)
await session.commit()
await session.refresh(new)
incoming_knowledge = await new.to_dto()
assert prev.file_sha1
should_process = await service.update_sha1_conflict(
incoming_knowledge, brain.brain_id, file_sha1=prev.file_sha1
)
# Checks we should process this file
assert should_process
# Previous errored file is cleaned up
with pytest.raises(NoResultFound):
await service.repository.get_knowledge_by_id(prev.id)
assert new.id
new = await service.repository.get_knowledge_by_id(new.id)
assert new.file_sha1
@pytest.mark.asyncio(loop_scope="session")
async def test_get_knowledge_storage_path(session: AsyncSession, test_data: TestData):
brain, [knowledge, _] = test_data
assert knowledge.file_name
repository = KnowledgeRepository(session)
service = KnowledgeService(repository)
brain_2 = Brain(
name="test_brain",
description="this is a test brain",
brain_type=BrainType.integration,
)
session.add(brain_2)
await session.commit()
await session.refresh(brain_2)
assert brain_2.brain_id
km_data = os.urandom(128)
km_path = f"{str(knowledge.brains[0].brain_id)}/{knowledge.file_name}"
await upload_file_storage(km_data, km_path)
# Link knowledge to two brains
await repository.link_to_brain(knowledge, brain_2.brain_id)
storage_path = await service.get_knowledge_storage_path(
knowledge.file_name, brain_2.brain_id
)
assert storage_path == km_path

View File

@ -317,10 +317,11 @@ async def test_process_sync_file_noprev(
created_km = all_km[0] created_km = all_km[0]
assert created_km.file_name == sync_file.name assert created_km.file_name == sync_file.name
assert created_km.extension == ".txt" assert created_km.extension == ".txt"
assert created_km.file_sha1 is not None assert created_km.file_sha1 is None
assert created_km.created_at is not None assert created_km.created_at is not None
assert created_km.metadata == {"sync_file_id": "1"} assert created_km.metadata == {"sync_file_id": "1"}
assert created_km.brain_ids == [brain_1.brain_id] assert len(created_km.brains)> 0
assert created_km.brains[0]["brain_id"]== brain_1.brain_id
# Assert celery task in correct # Assert celery task in correct
assert task["args"] == ("process_file_task",) assert task["args"] == ("process_file_task",)
@ -409,12 +410,12 @@ async def test_process_sync_file_with_prev(
created_km = all_km[0] created_km = all_km[0]
assert created_km.file_name == sync_file.name assert created_km.file_name == sync_file.name
assert created_km.extension == ".txt" assert created_km.extension == ".txt"
assert created_km.file_sha1 is not None assert created_km.file_sha1 is None
assert created_km.updated_at assert created_km.updated_at
assert created_km.created_at assert created_km.created_at
assert created_km.updated_at == created_km.created_at # new line assert created_km.updated_at == created_km.created_at # new line
assert created_km.metadata == {"sync_file_id": str(dbfiles[0].id)} assert created_km.metadata == {"sync_file_id": str(dbfiles[0].id)}
assert created_km.brain_ids == [brain_1.brain_id] assert created_km.brains[0]["brain_id"]== brain_1.brain_id
# Check file content changed # Check file content changed
assert check_file_exists(str(brain_1.brain_id), sync_file.name) assert check_file_exists(str(brain_1.brain_id), sync_file.name)

View File

@ -53,12 +53,10 @@ AsyncClientDep = Annotated[AsyncClient, Depends(get_supabase_async_client)]
@upload_router.post("/upload", dependencies=[Depends(AuthBearer())], tags=["Upload"]) @upload_router.post("/upload", dependencies=[Depends(AuthBearer())], tags=["Upload"])
async def upload_file( async def upload_file(
uploadFile: UploadFile, uploadFile: UploadFile,
client: AsyncClientDep,
background_tasks: BackgroundTasks,
knowledge_service: KnowledgeServiceDep, knowledge_service: KnowledgeServiceDep,
background_tasks: BackgroundTasks,
bulk_id: Optional[UUID] = Query(None, description="The ID of the bulk upload"), bulk_id: Optional[UUID] = Query(None, description="The ID of the bulk upload"),
brain_id: UUID = Query(..., description="The ID of the brain"), brain_id: UUID = Query(..., description="The ID of the brain"),
chat_id: Optional[UUID] = Query(None, description="The ID of the chat"),
current_user: UserIdentity = Depends(get_current_user), current_user: UserIdentity = Depends(get_current_user),
integration: Optional[str] = None, integration: Optional[str] = None,
integration_link: Optional[str] = None, integration_link: Optional[str] = None,
@ -121,7 +119,7 @@ async def upload_file(
file_size=uploadFile.size, file_size=uploadFile.size,
file_sha1=None, file_sha1=None,
) )
knowledge = await knowledge_service.insert_knowledge( knowledge = await knowledge_service.insert_knowledge_brain(
user_id=current_user.id, knowledge_to_add=knowledge_to_add user_id=current_user.id, knowledge_to_add=knowledge_to_add
) # type: ignore ) # type: ignore

View File

@ -87,7 +87,7 @@ async def crawl_endpoint(
source_link=crawl_website.url, source_link=crawl_website.url,
) )
added_knowledge = await knowledge_service.insert_knowledge( added_knowledge = await knowledge_service.insert_knowledge_brain(
knowledge_to_add=knowledge_to_add, user_id=current_user.id knowledge_to_add=knowledge_to_add, user_id=current_user.id
) )
logger.info(f"Knowledge {added_knowledge} added successfully") logger.info(f"Knowledge {added_knowledge} added successfully")

View File

@ -0,0 +1,50 @@
from copy import deepcopy
from typing import Any, Callable, Optional, Type, TypeVar
from uuid import UUID
from pydantic import BaseModel, create_model
from pydantic.fields import FieldInfo
Model = TypeVar("Model", bound=Type[BaseModel])
def all_optional(without_fields: list[str] | None = None) -> Callable[[Model], Model]:
if without_fields is None:
without_fields = []
def wrapper(model: Type[Model]) -> Type[Model]:
base_model: Type[Model] = model
def make_field_optional(
field: FieldInfo, default: Any = None
) -> tuple[Any, FieldInfo]:
new = deepcopy(field)
new.default = default
new.annotation = Optional[field.annotation]
return new.annotation, new
if without_fields:
base_model = BaseModel
return create_model(
model.__name__,
__base__=base_model,
__module__=model.__module__,
**{
field_name: make_field_optional(field_info)
for field_name, field_info in model.model_fields.items()
if field_name not in without_fields
},
)
return wrapper
class Test(BaseModel):
id: UUID
name: Optional[str] = None
@all_optional()
class TestUpdate(Test):
pass

View File

@ -42,6 +42,7 @@ class KnowledgeStatus(str, Enum):
PROCESSING = "PROCESSING" PROCESSING = "PROCESSING"
UPLOADED = "UPLOADED" UPLOADED = "UPLOADED"
ERROR = "ERROR" ERROR = "ERROR"
RESERVED = "RESERVED"
class Source(BaseModel): class Source(BaseModel):

View File

@ -0,0 +1,31 @@
ALTER USER postgres
SET idle_session_timeout = '3min';
ALTER USER postgres
SET idle_in_transaction_session_timeout = '3min';
-- Drop previous contraint
alter table "public"."knowledge" drop constraint "unique_file_sha1_user_id";
alter table "public"."knowledge"
add column "is_folder" boolean default false;
-- Update the knowledge to backfill knowledge to is_folder = false
UPDATE "public"."knowledge"
SET is_folder = false;
-- Add parent_id -> folder
alter table "public"."knowledge"
add column "parent_id" uuid;
alter table "public"."knowledge"
add constraint "public_knowledge_parent_id_fkey" FOREIGN KEY (parent_id) REFERENCES knowledge(id) ON DELETE CASCADE;
-- Add constraint must be folder for parent_id
CREATE FUNCTION is_parent_folder(folder_id uuid) RETURNS boolean AS $$ BEGIN RETURN (
SELECT k.is_folder
FROM public.knowledge k
WHERE k.id = folder_id
);
END;
$$ LANGUAGE plpgsql;
ALTER TABLE public.knowledge
ADD CONSTRAINT check_parent_is_folder CHECK (
parent_id IS NULL
OR is_parent_folder(parent_id)
);
-- Index on parent_id
CREATE INDEX knowledge_parent_id_idx ON public.knowledge USING btree (parent_id);

View File

@ -13,7 +13,7 @@ from quivr_api.modules.brain.repository.brains_vectors import BrainsVectors
from quivr_api.modules.brain.service.brain_service import BrainService from quivr_api.modules.brain.service.brain_service import BrainService
from quivr_api.modules.dependencies import get_supabase_client from quivr_api.modules.dependencies import get_supabase_client
from quivr_api.modules.knowledge.repository.knowledges import KnowledgeRepository from quivr_api.modules.knowledge.repository.knowledges import KnowledgeRepository
from quivr_api.modules.knowledge.repository.storage import Storage from quivr_api.modules.knowledge.repository.storage import SupabaseS3Storage
from quivr_api.modules.knowledge.service.knowledge_service import KnowledgeService from quivr_api.modules.knowledge.service.knowledge_service import KnowledgeService
from quivr_api.modules.notification.service.notification_service import ( from quivr_api.modules.notification.service.notification_service import (
NotificationService, NotificationService,
@ -58,7 +58,7 @@ sync_user_service = SyncUserService()
sync_files_repo_service = SyncFilesRepository() sync_files_repo_service = SyncFilesRepository()
brain_service = BrainService() brain_service = BrainService()
brain_vectors = BrainsVectors() brain_vectors = BrainsVectors()
storage = Storage() storage = SupabaseS3Storage()
notion_service: SyncNotionService | None = None notion_service: SyncNotionService | None = None
async_engine: AsyncEngine | None = None async_engine: AsyncEngine | None = None
engine: Engine | None = None engine: Engine | None = None
@ -170,6 +170,8 @@ async def aprocess_file_task(
integration_link=source_link, integration_link=source_link,
delete_file=delete_file, delete_file=delete_file,
) )
session.commit()
await async_session.commit()
except Exception as e: except Exception as e:
session.rollback() session.rollback()
await async_session.rollback() await async_session.rollback()
@ -196,7 +198,11 @@ def process_crawl_task(
) )
global engine global engine
assert engine assert engine
try:
with Session(engine, expire_on_commit=False, autoflush=False) as session: with Session(engine, expire_on_commit=False, autoflush=False) as session:
session.execute(
text("SET SESSION idle_in_transaction_session_timeout = '5min';")
)
vector_repository = VectorRepository(session) vector_repository = VectorRepository(session)
vector_service = VectorService(vector_repository) vector_service = VectorService(vector_repository)
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
@ -209,6 +215,12 @@ def process_crawl_task(
vector_service=vector_service, vector_service=vector_service,
) )
) )
session.commit()
except Exception as e:
session.rollback()
raise e
finally:
session.close()
@celery.task(name="NotionConnectorLoad") @celery.task(name="NotionConnectorLoad")

View File

@ -2,6 +2,7 @@ from uuid import UUID
from quivr_api.logger import get_logger from quivr_api.logger import get_logger
from quivr_api.modules.brain.service.brain_service import BrainService from quivr_api.modules.brain.service.brain_service import BrainService
from quivr_api.modules.knowledge.entity.knowledge import KnowledgeUpdate
from quivr_api.modules.knowledge.service.knowledge_service import KnowledgeService from quivr_api.modules.knowledge.service.knowledge_service import KnowledgeService
from quivr_api.modules.vector.service.vector_service import VectorService from quivr_api.modules.vector.service.vector_service import VectorService
@ -41,12 +42,10 @@ async def process_uploaded_file(
# If we have some knowledge with error # If we have some knowledge with error
with build_file(file_data, knowledge_id, file_name) as file_instance: with build_file(file_data, knowledge_id, file_name) as file_instance:
knowledge = await knowledge_service.get_knowledge(knowledge_id=knowledge_id) knowledge = await knowledge_service.get_knowledge(knowledge_id=knowledge_id)
should_process = await knowledge_service.update_sha1_conflict( await knowledge_service.update_knowledge(
knowledge=knowledge, knowledge,
brain_id=brain.brain_id, KnowledgeUpdate(file_sha1=file_instance.file_sha1), # type: ignore
file_sha1=file_instance.file_sha1,
) )
if should_process:
await process_file( await process_file(
file_instance=file_instance, file_instance=file_instance,
brain=brain, brain=brain,

View File

@ -141,7 +141,7 @@ async def process_notion_sync(
UUID(user_id), UUID(user_id),
notion_client, # type: ignore notion_client, # type: ignore
) )
await session.commit()
except Exception as e: except Exception as e:
await session.rollback() await session.rollback()
raise e raise e

View File

@ -40,6 +40,8 @@ async def fetch_and_store_notion_files_async(
else: else:
logger.warn("No notion page fetched") logger.warn("No notion page fetched")
# Commit all before exiting
await session.commit()
except Exception as e: except Exception as e:
await session.rollback() await session.rollback()
raise e raise e

View File

@ -6,7 +6,7 @@ from quivr_api.celery_config import celery
from quivr_api.logger import get_logger from quivr_api.logger import get_logger
from quivr_api.modules.brain.repository.brains_vectors import BrainsVectors from quivr_api.modules.brain.repository.brains_vectors import BrainsVectors
from quivr_api.modules.knowledge.repository.knowledges import KnowledgeRepository from quivr_api.modules.knowledge.repository.knowledges import KnowledgeRepository
from quivr_api.modules.knowledge.repository.storage import Storage from quivr_api.modules.knowledge.repository.storage import SupabaseS3Storage
from quivr_api.modules.knowledge.service.knowledge_service import KnowledgeService from quivr_api.modules.knowledge.service.knowledge_service import KnowledgeService
from quivr_api.modules.notification.service.notification_service import ( from quivr_api.modules.notification.service.notification_service import (
NotificationService, NotificationService,
@ -42,7 +42,7 @@ class SyncServices:
sync_files_repo_service: SyncFilesRepository sync_files_repo_service: SyncFilesRepository
notification_service: NotificationService notification_service: NotificationService
brain_vectors: BrainsVectors brain_vectors: BrainsVectors
storage: Storage storage: SupabaseS3Storage
@asynccontextmanager @asynccontextmanager
@ -56,7 +56,6 @@ async def build_syncs_utils(
await session.execute( await session.execute(
text("SET SESSION idle_in_transaction_session_timeout = '5min';") text("SET SESSION idle_in_transaction_session_timeout = '5min';")
) )
# TODO pass services from celery_worker
notion_repository = NotionRepository(session) notion_repository = NotionRepository(session)
notion_service = SyncNotionService(notion_repository) notion_service = SyncNotionService(notion_repository)
knowledge_service = KnowledgeService(KnowledgeRepository(session)) knowledge_service = KnowledgeService(KnowledgeRepository(session))
@ -84,7 +83,7 @@ async def build_syncs_utils(
mapping_sync_utils[provider_name] = provider_sync_util mapping_sync_utils[provider_name] = provider_sync_util
yield mapping_sync_utils yield mapping_sync_utils
await session.commit()
except Exception as e: except Exception as e:
await session.rollback() await session.rollback()
raise e raise e