mirror of
https://github.com/QuivrHQ/quivr.git
synced 2024-12-14 07:59:00 +03:00
feat: CRUD KMS (no syncs) (#3162)
# Description closes #3056. closes #3198 - Create knowledge route - Get knowledge route - List knowledge route : accepts knowledge_id | None. None to list root knowledge for use - Update (patch) knowledge to rename and move knowledge - Remove knowledge: Cascade if parent_id in knowledge and cleanup storage - Link storage upload to knowledge_service - Relax sha1 file constraint - Tests to all repository / service --------- Co-authored-by: Stan Girard <girard.stanislas@gmail.com>
This commit is contained in:
parent
edc4118ba1
commit
71edca572f
6
.github/workflows/backend-tests.yml
vendored
6
.github/workflows/backend-tests.yml
vendored
@ -9,7 +9,9 @@ on:
|
||||
jobs:
|
||||
test:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
strategy:
|
||||
matrix:
|
||||
project: [quivr-api, quivr-worker]
|
||||
steps:
|
||||
- name: 👀 Checkout code
|
||||
uses: actions/checkout@v2
|
||||
@ -65,4 +67,4 @@ jobs:
|
||||
supabase start
|
||||
rye run python -c "from unstructured.nlp.tokenize import download_nltk_packages; download_nltk_packages()"
|
||||
rye run python -c "import nltk;nltk.download('punkt_tab'); nltk.download('averaged_perceptron_tagger_eng')"
|
||||
rye test -p quivr-api -p quivr-worker
|
||||
rye test -p ${{ matrix.project }}
|
||||
|
@ -32,13 +32,6 @@ repos:
|
||||
- id: mypy
|
||||
name: mypy
|
||||
additional_dependencies: ["types-aiofiles"]
|
||||
- repo: https://github.com/python-poetry/poetry
|
||||
rev: "1.8.0"
|
||||
hooks:
|
||||
- id: poetry-check
|
||||
args: ["-C", "./backend/core"]
|
||||
- id: poetry-lock
|
||||
args: ["-C", "./backend/core"]
|
||||
ci:
|
||||
autofix_commit_msg: |
|
||||
[pre-commit.ci] auto fixes from pre-commit.com hooks
|
||||
|
@ -3,6 +3,7 @@ from typing import Optional
|
||||
|
||||
from fastapi import Depends, HTTPException, Request
|
||||
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
|
||||
|
||||
from quivr_api.middlewares.auth.jwt_token_handler import (
|
||||
decode_access_token,
|
||||
verify_token,
|
||||
@ -57,9 +58,13 @@ class AuthBearer(HTTPBearer):
|
||||
|
||||
def get_test_user(self) -> UserIdentity:
|
||||
return UserIdentity(
|
||||
email="admin@quivr.app", id="39418e3b-0258-4452-af60-7acfcc1263ff" # type: ignore
|
||||
email="admin@quivr.app",
|
||||
id="39418e3b-0258-4452-af60-7acfcc1263ff", # type: ignore
|
||||
) # replace with test user information
|
||||
|
||||
|
||||
def get_current_user(user: UserIdentity = Depends(AuthBearer())) -> UserIdentity:
|
||||
auth_bearer = AuthBearer()
|
||||
|
||||
|
||||
def get_current_user(user: UserIdentity = Depends(auth_bearer)) -> UserIdentity:
|
||||
return user
|
||||
|
@ -69,6 +69,7 @@ class Brain(AsyncAttrs, SQLModel, table=True):
|
||||
back_populates="brains", link_model=KnowledgeBrain
|
||||
)
|
||||
|
||||
|
||||
# TODO : add
|
||||
# "meaning" "public"."vector",
|
||||
# "tags" "public"."tags"[]
|
||||
|
@ -2,7 +2,7 @@ from uuid import UUID
|
||||
|
||||
from quivr_api.logger import get_logger
|
||||
from quivr_api.modules.brain.repository.brains_vectors import BrainsVectors
|
||||
from quivr_api.modules.knowledge.repository.storage import Storage
|
||||
from quivr_api.modules.knowledge.repository.storage import SupabaseS3Storage
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@ -11,7 +11,7 @@ class BrainVectorService:
|
||||
def __init__(self, brain_id: UUID):
|
||||
self.repository = BrainsVectors()
|
||||
self.brain_id = brain_id
|
||||
self.storage = Storage()
|
||||
self.storage = SupabaseS3Storage()
|
||||
|
||||
def create_brain_vector(self, vector_id: str, file_sha1: str):
|
||||
return self.repository.create_brain_vector(self.brain_id, vector_id, file_sha1) # type: ignore
|
||||
@ -26,10 +26,10 @@ class BrainVectorService:
|
||||
for vector_id in vector_ids:
|
||||
self.create_brain_vector(vector_id, file_sha1)
|
||||
|
||||
def delete_file_from_brain(self, file_name: str, only_vectors: bool = False):
|
||||
async def delete_file_from_brain(self, file_name: str, only_vectors: bool = False):
|
||||
file_name_with_brain_id = f"{self.brain_id}/{file_name}"
|
||||
if not only_vectors:
|
||||
self.storage.remove_file(file_name_with_brain_id)
|
||||
await self.storage.remove_file(file_name_with_brain_id)
|
||||
return self.repository.delete_file_from_brain(self.brain_id, file_name) # type: ignore
|
||||
|
||||
def delete_file_url_from_brain(self, file_name: str):
|
||||
|
@ -24,9 +24,6 @@ async_engine = create_async_engine(
|
||||
"postgresql+asyncpg://" + pg_database_base_url,
|
||||
echo=True if os.getenv("ORM_DEBUG") else False,
|
||||
future=True,
|
||||
pool_pre_ping=True,
|
||||
pool_size=10,
|
||||
pool_recycle=0.1,
|
||||
)
|
||||
|
||||
|
||||
|
@ -7,8 +7,7 @@ from langchain.embeddings.base import Embeddings
|
||||
from langchain_community.embeddings.ollama import OllamaEmbeddings
|
||||
|
||||
# from langchain_community.vectorstores.supabase import SupabaseVectorStore
|
||||
from langchain_openai import OpenAIEmbeddings
|
||||
from langchain_openai import AzureOpenAIEmbeddings
|
||||
from langchain_openai import AzureOpenAIEmbeddings, OpenAIEmbeddings
|
||||
|
||||
# from quivr_api.modules.vector.service.vector_service import VectorService
|
||||
# from quivr_api.modules.vectorstore.supabase import CustomSupabaseVectorStore
|
||||
@ -22,7 +21,6 @@ from quivr_api.models.databases.supabase.supabase import SupabaseDB
|
||||
from quivr_api.models.settings import BrainSettings
|
||||
from supabase.client import AsyncClient, Client, create_async_client, create_client
|
||||
|
||||
|
||||
# Global variables to store the Supabase client and database instances
|
||||
_supabase_client: Optional[Client] = None
|
||||
_supabase_async_client: Optional[AsyncClient] = None
|
||||
|
@ -1,8 +1,8 @@
|
||||
from http import HTTPStatus
|
||||
from typing import Annotated
|
||||
from typing import Annotated, List, Optional
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from fastapi import APIRouter, Depends, File, HTTPException, Query, UploadFile, status
|
||||
|
||||
from quivr_api.logger import get_logger
|
||||
from quivr_api.middlewares.auth import AuthBearer, get_current_user
|
||||
@ -12,6 +12,14 @@ from quivr_api.modules.brain.service.brain_authorization_service import (
|
||||
validate_brain_authorization,
|
||||
)
|
||||
from quivr_api.modules.dependencies import get_service
|
||||
from quivr_api.modules.knowledge.dto.inputs import AddKnowledge
|
||||
from quivr_api.modules.knowledge.entity.knowledge import Knowledge, KnowledgeUpdate
|
||||
from quivr_api.modules.knowledge.service.knowledge_exceptions import (
|
||||
KnowledgeDeleteError,
|
||||
KnowledgeForbiddenAccess,
|
||||
KnowledgeNotFoundException,
|
||||
UploadError,
|
||||
)
|
||||
from quivr_api.modules.knowledge.service.knowledge_service import KnowledgeService
|
||||
from quivr_api.modules.upload.service.generate_file_signed_url import (
|
||||
generate_file_signed_url,
|
||||
@ -21,9 +29,8 @@ from quivr_api.modules.user.entity.user_identity import UserIdentity
|
||||
knowledge_router = APIRouter()
|
||||
logger = get_logger(__name__)
|
||||
|
||||
KnowledgeServiceDep = Annotated[
|
||||
KnowledgeService, Depends(get_service(KnowledgeService))
|
||||
]
|
||||
get_km_service = get_service(KnowledgeService)
|
||||
KnowledgeServiceDep = Annotated[KnowledgeService, Depends(get_km_service)]
|
||||
|
||||
|
||||
@knowledge_router.get(
|
||||
@ -53,7 +60,7 @@ async def list_knowledge_in_brain_endpoint(
|
||||
],
|
||||
tags=["Knowledge"],
|
||||
)
|
||||
async def delete_endpoint(
|
||||
async def delete_knowledge_brain(
|
||||
knowledge_id: UUID,
|
||||
knowledge_service: KnowledgeServiceDep,
|
||||
current_user: UserIdentity = Depends(get_current_user),
|
||||
@ -65,7 +72,7 @@ async def delete_endpoint(
|
||||
|
||||
knowledge = await knowledge_service.get_knowledge(knowledge_id)
|
||||
file_name = knowledge.file_name if knowledge.file_name else knowledge.url
|
||||
await knowledge_service.remove_knowledge(brain_id, knowledge_id)
|
||||
await knowledge_service.remove_knowledge_brain(brain_id, knowledge_id)
|
||||
|
||||
return {
|
||||
"message": f"{file_name} of brain {brain_id} has been deleted by user {current_user.email}."
|
||||
@ -88,13 +95,13 @@ async def generate_signed_url_endpoint(
|
||||
|
||||
knowledge = await knowledge_service.get_knowledge(knowledge_id)
|
||||
|
||||
if len(knowledge.brain_ids) == 0:
|
||||
if len(knowledge.brains) == 0:
|
||||
raise HTTPException(
|
||||
status_code=HTTPStatus.NOT_FOUND,
|
||||
detail="knowledge not associated with brains yet.",
|
||||
)
|
||||
|
||||
brain_id = knowledge.brain_ids[0]
|
||||
brain_id = knowledge.brains[0]["brain_id"]
|
||||
|
||||
validate_brain_authorization(brain_id=brain_id, user_id=current_user.id)
|
||||
|
||||
@ -108,3 +115,153 @@ async def generate_signed_url_endpoint(
|
||||
file_signed_url = generate_file_signed_url(file_path_in_storage)
|
||||
|
||||
return file_signed_url
|
||||
|
||||
|
||||
@knowledge_router.post(
|
||||
"/knowledge/",
|
||||
tags=["Knowledge"],
|
||||
response_model=Knowledge,
|
||||
)
|
||||
async def create_knowledge(
|
||||
knowledge_data: str = File(...),
|
||||
file: Optional[UploadFile] = None,
|
||||
knowledge_service: KnowledgeService = Depends(get_km_service),
|
||||
current_user: UserIdentity = Depends(get_current_user),
|
||||
):
|
||||
knowledge = AddKnowledge.model_validate_json(knowledge_data)
|
||||
if not knowledge.file_name and not knowledge.url:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Either file_name or url must be provided",
|
||||
)
|
||||
try:
|
||||
km = await knowledge_service.create_knowledge(
|
||||
knowledge_to_add=knowledge, upload_file=file, user_id=current_user.id
|
||||
)
|
||||
km_dto = await km.to_dto()
|
||||
return km_dto
|
||||
except ValueError:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
|
||||
detail="Unprocessable knowledge ",
|
||||
)
|
||||
except FileExistsError:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_409_CONFLICT, detail="Existing knowledge"
|
||||
)
|
||||
except UploadError:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Error occured uploading knowledge",
|
||||
)
|
||||
except Exception:
|
||||
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR)
|
||||
|
||||
|
||||
@knowledge_router.get(
|
||||
"/knowledge/children",
|
||||
response_model=List[Knowledge] | None,
|
||||
tags=["Knowledge"],
|
||||
)
|
||||
async def list_knowledge(
|
||||
parent_id: UUID | None = None,
|
||||
knowledge_service: KnowledgeService = Depends(get_km_service),
|
||||
current_user: UserIdentity = Depends(get_current_user),
|
||||
):
|
||||
try:
|
||||
# TODO: Returns one level of children
|
||||
children = await knowledge_service.list_knowledge(parent_id, current_user.id)
|
||||
return [await c.to_dto(get_children=False) for c in children]
|
||||
except KnowledgeNotFoundException as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN, detail=f"{e.message}"
|
||||
)
|
||||
except KnowledgeForbiddenAccess as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail=f"{e.message}"
|
||||
)
|
||||
except Exception:
|
||||
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR)
|
||||
|
||||
|
||||
@knowledge_router.get(
|
||||
"/knowledge/{knowledge_id}",
|
||||
response_model=Knowledge,
|
||||
tags=["Knowledge"],
|
||||
)
|
||||
async def get_knowledge(
|
||||
knowledge_id: UUID,
|
||||
knowledge_service: KnowledgeService = Depends(get_km_service),
|
||||
current_user: UserIdentity = Depends(get_current_user),
|
||||
):
|
||||
try:
|
||||
km = await knowledge_service.get_knowledge(knowledge_id)
|
||||
if km.user_id != current_user.id:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="You do not have permission to access this knowledge.",
|
||||
)
|
||||
return await km.to_dto()
|
||||
except KnowledgeNotFoundException as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail=f"{e.message}"
|
||||
)
|
||||
except Exception:
|
||||
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR)
|
||||
|
||||
|
||||
@knowledge_router.patch(
|
||||
"/knowledge/{knowledge_id}",
|
||||
status_code=status.HTTP_202_ACCEPTED,
|
||||
response_model=Knowledge,
|
||||
tags=["Knowledge"],
|
||||
)
|
||||
async def update_knowledge(
|
||||
knowledge_id: UUID,
|
||||
payload: KnowledgeUpdate,
|
||||
knowledge_service: KnowledgeService = Depends(get_km_service),
|
||||
current_user: UserIdentity = Depends(get_current_user),
|
||||
):
|
||||
try:
|
||||
km = await knowledge_service.get_knowledge(knowledge_id)
|
||||
if km.user_id != current_user.id:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="You do not have permission to access this knowledge.",
|
||||
)
|
||||
km = await knowledge_service.update_knowledge(km, payload)
|
||||
return km
|
||||
except KnowledgeNotFoundException as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail=f"{e.message}"
|
||||
)
|
||||
except Exception:
|
||||
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR)
|
||||
|
||||
|
||||
@knowledge_router.delete(
|
||||
"/knowledge/{knowledge_id}",
|
||||
status_code=status.HTTP_202_ACCEPTED,
|
||||
tags=["Knowledge"],
|
||||
)
|
||||
async def delete_knowledge(
|
||||
knowledge_id: UUID,
|
||||
knowledge_service: KnowledgeService = Depends(get_km_service),
|
||||
current_user: UserIdentity = Depends(get_current_user),
|
||||
):
|
||||
try:
|
||||
km = await knowledge_service.get_knowledge(knowledge_id)
|
||||
|
||||
if km.user_id != current_user.id:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="You do not have permission to remove this knowledge.",
|
||||
)
|
||||
delete_response = await knowledge_service.remove_knowledge(km)
|
||||
return delete_response
|
||||
except KnowledgeNotFoundException as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail=f"{e.message}"
|
||||
)
|
||||
except KnowledgeDeleteError:
|
||||
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR)
|
||||
|
@ -16,8 +16,16 @@ class CreateKnowledgeProperties(BaseModel):
|
||||
file_size: Optional[int] = None
|
||||
file_sha1: Optional[str] = None
|
||||
metadata: Optional[Dict[str, str]] = None
|
||||
is_folder: bool = False
|
||||
parent_id: Optional[UUID] = None
|
||||
|
||||
def dict(self, *args, **kwargs):
|
||||
knowledge_dict = super().dict(*args, **kwargs)
|
||||
knowledge_dict["brain_id"] = str(knowledge_dict.get("brain_id"))
|
||||
return knowledge_dict
|
||||
|
||||
class AddKnowledge(BaseModel):
|
||||
file_name: Optional[str] = None
|
||||
url: Optional[str] = None
|
||||
extension: str = ".txt"
|
||||
source: str = "local"
|
||||
source_link: Optional[str] = None
|
||||
metadata: Optional[Dict[str, str]] = None
|
||||
is_folder: bool = False
|
||||
parent_id: Optional[UUID] = None
|
||||
|
@ -4,6 +4,6 @@ from pydantic import BaseModel
|
||||
|
||||
|
||||
class DeleteKnowledgeResponse(BaseModel):
|
||||
file_name: str
|
||||
status: str = "delete"
|
||||
file_name: str | None = None
|
||||
status: str = "DELETED"
|
||||
knowledge_id: UUID
|
||||
|
@ -1,5 +1,6 @@
|
||||
from datetime import datetime
|
||||
from typing import Dict, List, Optional
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Optional
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel
|
||||
@ -12,20 +13,44 @@ from sqlmodel import Field, Relationship, SQLModel
|
||||
from quivr_api.modules.knowledge.entity.knowledge_brain import KnowledgeBrain
|
||||
|
||||
|
||||
class KnowledgeSource(str, Enum):
|
||||
LOCAL = "local"
|
||||
WEB = "web"
|
||||
GDRIVE = "google drive"
|
||||
DROPBOX = "dropbox"
|
||||
SHAREPOINT = "sharepoint"
|
||||
|
||||
|
||||
class Knowledge(BaseModel):
|
||||
id: UUID
|
||||
file_size: int = 0
|
||||
status: KnowledgeStatus
|
||||
file_name: Optional[str] = None
|
||||
url: Optional[str] = None
|
||||
extension: str = ".txt"
|
||||
status: str
|
||||
is_folder: bool = False
|
||||
updated_at: datetime
|
||||
created_at: datetime
|
||||
source: Optional[str] = None
|
||||
source_link: Optional[str] = None
|
||||
file_size: Optional[int] = None
|
||||
file_sha1: Optional[str] = None
|
||||
updated_at: Optional[datetime] = None
|
||||
created_at: Optional[datetime] = None
|
||||
metadata: Optional[Dict[str, str]] = None
|
||||
brain_ids: list[UUID]
|
||||
user_id: UUID
|
||||
brains: List[Dict[str, Any]]
|
||||
parent: Optional["Knowledge"]
|
||||
children: Optional[list["Knowledge"]]
|
||||
|
||||
|
||||
class KnowledgeUpdate(BaseModel):
|
||||
file_name: Optional[str] = None
|
||||
status: Optional[KnowledgeStatus] = None
|
||||
url: Optional[str] = None
|
||||
file_sha1: Optional[str] = None
|
||||
extension: Optional[str] = None
|
||||
parent_id: Optional[UUID] = None
|
||||
source: Optional[str] = None
|
||||
source_link: Optional[str] = None
|
||||
metadata: Optional[Dict[str, str]] = None
|
||||
|
||||
|
||||
class KnowledgeDB(AsyncAttrs, SQLModel, table=True):
|
||||
@ -49,13 +74,6 @@ class KnowledgeDB(AsyncAttrs, SQLModel, table=True):
|
||||
file_sha1: Optional[str] = Field(
|
||||
max_length=40
|
||||
) # FIXME: Should not be optional @chloedia
|
||||
updated_at: datetime | None = Field(
|
||||
default=None,
|
||||
sa_column=Column(
|
||||
TIMESTAMP(timezone=False),
|
||||
server_default=text("CURRENT_TIMESTAMP"),
|
||||
),
|
||||
)
|
||||
created_at: datetime | None = Field(
|
||||
default=None,
|
||||
sa_column=Column(
|
||||
@ -63,9 +81,18 @@ class KnowledgeDB(AsyncAttrs, SQLModel, table=True):
|
||||
server_default=text("CURRENT_TIMESTAMP"),
|
||||
),
|
||||
)
|
||||
updated_at: datetime | None = Field(
|
||||
default=None,
|
||||
sa_column=Column(
|
||||
TIMESTAMP(timezone=False),
|
||||
server_default=text("CURRENT_TIMESTAMP"),
|
||||
onupdate=datetime.utcnow,
|
||||
),
|
||||
)
|
||||
metadata_: Optional[Dict[str, str]] = Field(
|
||||
default=None, sa_column=Column("metadata", JSON)
|
||||
)
|
||||
is_folder: bool = Field(default=False)
|
||||
user_id: UUID = Field(foreign_key="users.id", nullable=False)
|
||||
brains: List["Brain"] = Relationship(
|
||||
back_populates="knowledges",
|
||||
@ -73,10 +100,35 @@ class KnowledgeDB(AsyncAttrs, SQLModel, table=True):
|
||||
sa_relationship_kwargs={"lazy": "select"},
|
||||
)
|
||||
|
||||
async def to_dto(self) -> Knowledge:
|
||||
parent_id: UUID | None = Field(
|
||||
default=None, foreign_key="knowledge.id", ondelete="CASCADE"
|
||||
)
|
||||
parent: Optional["KnowledgeDB"] = Relationship(
|
||||
back_populates="children",
|
||||
sa_relationship_kwargs={"remote_side": "KnowledgeDB.id"},
|
||||
)
|
||||
children: list["KnowledgeDB"] = Relationship(
|
||||
back_populates="parent",
|
||||
sa_relationship_kwargs={
|
||||
"cascade": "all, delete-orphan",
|
||||
},
|
||||
)
|
||||
|
||||
# TODO: nested folder search
|
||||
async def to_dto(self, get_children: bool = True) -> Knowledge:
|
||||
assert (
|
||||
self.updated_at
|
||||
), "knowledge should be inserted before transforming to dto"
|
||||
assert (
|
||||
self.created_at
|
||||
), "knowledge should be inserted before transforming to dto"
|
||||
brains = await self.awaitable_attrs.brains
|
||||
size = self.file_size if self.file_size else 0
|
||||
sha1 = self.file_sha1 if self.file_sha1 else ""
|
||||
children: list[KnowledgeDB] = (
|
||||
await self.awaitable_attrs.children if get_children else []
|
||||
)
|
||||
parent = await self.awaitable_attrs.parent
|
||||
parent = await parent.to_dto(get_children=False) if parent else None
|
||||
|
||||
return Knowledge(
|
||||
id=self.id, # type: ignore
|
||||
file_name=self.file_name,
|
||||
@ -85,10 +137,14 @@ class KnowledgeDB(AsyncAttrs, SQLModel, table=True):
|
||||
status=KnowledgeStatus(self.status),
|
||||
source=self.source,
|
||||
source_link=self.source_link,
|
||||
file_size=size,
|
||||
file_sha1=sha1,
|
||||
is_folder=self.is_folder,
|
||||
file_size=self.file_size or 0,
|
||||
file_sha1=self.file_sha1,
|
||||
updated_at=self.updated_at,
|
||||
created_at=self.created_at,
|
||||
metadata=self.metadata_, # type: ignore
|
||||
brain_ids=[brain.brain_id for brain in brains],
|
||||
brains=[b.model_dump() for b in brains],
|
||||
parent=parent,
|
||||
children=[await c.to_dto(get_children=False) for c in children],
|
||||
user_id=self.user_id,
|
||||
)
|
||||
|
@ -1,9 +1,10 @@
|
||||
from typing import Sequence
|
||||
from typing import Any, Sequence
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import HTTPException
|
||||
from quivr_core.models import KnowledgeStatus
|
||||
from sqlalchemy.exc import IntegrityError, NoResultFound
|
||||
from sqlalchemy.orm import joinedload
|
||||
from sqlmodel import select, text
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
@ -11,7 +12,15 @@ from quivr_api.logger import get_logger
|
||||
from quivr_api.modules.brain.entity.brain_entity import Brain
|
||||
from quivr_api.modules.dependencies import BaseRepository, get_supabase_client
|
||||
from quivr_api.modules.knowledge.dto.outputs import DeleteKnowledgeResponse
|
||||
from quivr_api.modules.knowledge.entity.knowledge import KnowledgeDB
|
||||
from quivr_api.modules.knowledge.entity.knowledge import (
|
||||
Knowledge,
|
||||
KnowledgeDB,
|
||||
KnowledgeUpdate,
|
||||
)
|
||||
from quivr_api.modules.knowledge.service.knowledge_exceptions import (
|
||||
KnowledgeNotFoundException,
|
||||
KnowledgeUpdateError,
|
||||
)
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@ -22,7 +31,43 @@ class KnowledgeRepository(BaseRepository):
|
||||
supabase_client = get_supabase_client()
|
||||
self.db = supabase_client
|
||||
|
||||
async def insert_knowledge(
|
||||
async def create_knowledge(self, knowledge: KnowledgeDB) -> KnowledgeDB:
|
||||
try:
|
||||
self.session.add(knowledge)
|
||||
await self.session.commit()
|
||||
await self.session.refresh(knowledge)
|
||||
except IntegrityError:
|
||||
await self.session.rollback()
|
||||
raise
|
||||
except Exception:
|
||||
await self.session.rollback()
|
||||
raise
|
||||
return knowledge
|
||||
|
||||
async def update_knowledge(
|
||||
self,
|
||||
knowledge: KnowledgeDB,
|
||||
payload: Knowledge | KnowledgeUpdate | dict[str, Any],
|
||||
) -> KnowledgeDB:
|
||||
try:
|
||||
logger.debug(f"updating {knowledge.id} with payload {payload}")
|
||||
if isinstance(payload, dict):
|
||||
update_data = payload
|
||||
else:
|
||||
update_data = payload.model_dump(exclude_unset=True)
|
||||
for field in update_data:
|
||||
setattr(knowledge, field, update_data[field])
|
||||
|
||||
self.session.add(knowledge)
|
||||
await self.session.commit()
|
||||
await self.session.refresh(knowledge)
|
||||
return knowledge
|
||||
except IntegrityError as e:
|
||||
await self.session.rollback()
|
||||
logger.error(f"Error updating knowledge {e}")
|
||||
raise KnowledgeUpdateError
|
||||
|
||||
async def insert_knowledge_brain(
|
||||
self, knowledge: KnowledgeDB, brain_id: UUID
|
||||
) -> KnowledgeDB:
|
||||
logger.debug(f"Inserting knowledge {knowledge}")
|
||||
@ -69,6 +114,14 @@ class KnowledgeRepository(BaseRepository):
|
||||
await self.session.refresh(knowledge)
|
||||
return knowledge
|
||||
|
||||
async def remove_knowledge(self, knowledge: KnowledgeDB) -> DeleteKnowledgeResponse:
|
||||
assert knowledge.id
|
||||
await self.session.delete(knowledge)
|
||||
await self.session.commit()
|
||||
return DeleteKnowledgeResponse(
|
||||
status="deleted", knowledge_id=knowledge.id, file_name=knowledge.file_name
|
||||
)
|
||||
|
||||
async def remove_knowledge_by_id(
|
||||
self, knowledge_id: UUID
|
||||
) -> DeleteKnowledgeResponse:
|
||||
@ -126,14 +179,70 @@ class KnowledgeRepository(BaseRepository):
|
||||
|
||||
return knowledge
|
||||
|
||||
async def get_knowledge_by_id(self, knowledge_id: UUID) -> KnowledgeDB:
|
||||
query = select(KnowledgeDB).where(KnowledgeDB.id == knowledge_id)
|
||||
async def get_all_children(self, parent_id: UUID) -> list[KnowledgeDB]:
|
||||
query = text("""
|
||||
WITH RECURSIVE knowledge_tree AS (
|
||||
SELECT *
|
||||
FROM knowledge
|
||||
WHERE parent_id = :parent_id
|
||||
UNION ALL
|
||||
SELECT k.*
|
||||
FROM knowledge k
|
||||
JOIN knowledge_tree kt ON k.parent_id = kt.id
|
||||
)
|
||||
SELECT * FROM knowledge_tree
|
||||
""")
|
||||
|
||||
result = await self.session.execute(query, params={"parent_id": parent_id})
|
||||
rows = result.fetchall()
|
||||
knowledge_list = []
|
||||
for row in rows:
|
||||
knowledge = KnowledgeDB(
|
||||
id=row.id,
|
||||
parent_id=row.parent_id,
|
||||
file_name=row.file_name,
|
||||
url=row.url,
|
||||
extension=row.extension,
|
||||
status=row.status,
|
||||
source=row.source,
|
||||
source_link=row.source_link,
|
||||
file_size=row.file_size,
|
||||
file_sha1=row.file_sha1,
|
||||
created_at=row.created_at,
|
||||
updated_at=row.updated_at,
|
||||
metadata_=row.metadata,
|
||||
is_folder=row.is_folder,
|
||||
user_id=row.user_id,
|
||||
)
|
||||
knowledge_list.append(knowledge)
|
||||
|
||||
return knowledge_list
|
||||
|
||||
async def get_root_knowledge_user(self, user_id: UUID) -> list[KnowledgeDB]:
|
||||
query = (
|
||||
select(KnowledgeDB)
|
||||
.where(KnowledgeDB.parent_id.is_(None)) # type: ignore
|
||||
.where(KnowledgeDB.user_id == user_id)
|
||||
.options(joinedload(KnowledgeDB.parent), joinedload(KnowledgeDB.children)) # type: ignore
|
||||
)
|
||||
result = await self.session.exec(query)
|
||||
kms = result.unique().all()
|
||||
return list(kms)
|
||||
|
||||
async def get_knowledge_by_id(
|
||||
self, knowledge_id: UUID, user_id: UUID | None = None
|
||||
) -> KnowledgeDB:
|
||||
query = (
|
||||
select(KnowledgeDB)
|
||||
.where(KnowledgeDB.id == knowledge_id)
|
||||
.options(joinedload(KnowledgeDB.parent), joinedload(KnowledgeDB.children)) # type: ignore
|
||||
)
|
||||
if user_id:
|
||||
query = query.where(KnowledgeDB.user_id == user_id)
|
||||
result = await self.session.exec(query)
|
||||
knowledge = result.first()
|
||||
|
||||
if not knowledge:
|
||||
raise NoResultFound("Knowledge not found")
|
||||
|
||||
raise KnowledgeNotFoundException("Knowledge not found")
|
||||
return knowledge
|
||||
|
||||
async def get_brain_by_id(self, brain_id: UUID) -> Brain:
|
||||
|
@ -1,29 +1,87 @@
|
||||
import mimetypes
|
||||
from io import BufferedReader, FileIO
|
||||
|
||||
from quivr_api.logger import get_logger
|
||||
from quivr_api.modules.dependencies import get_supabase_client
|
||||
from quivr_api.modules.dependencies import get_supabase_async_client
|
||||
from quivr_api.modules.knowledge.entity.knowledge import KnowledgeDB
|
||||
from quivr_api.modules.knowledge.repository.storage_interface import StorageInterface
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class Storage(StorageInterface):
|
||||
class SupabaseS3Storage(StorageInterface):
|
||||
def __init__(self):
|
||||
supabase_client = get_supabase_client()
|
||||
self.db = supabase_client
|
||||
self.client = None
|
||||
|
||||
def upload_file(self, file_name: str):
|
||||
"""
|
||||
Upload file to storage
|
||||
"""
|
||||
self.db.storage.from_("quivr").download(file_name)
|
||||
async def _set_client(self):
|
||||
if self.client is None:
|
||||
self.client = await get_supabase_async_client()
|
||||
|
||||
def remove_file(self, file_name: str):
|
||||
def get_storage_path(
|
||||
self,
|
||||
knowledge: KnowledgeDB,
|
||||
) -> str:
|
||||
if knowledge.id is None:
|
||||
raise ValueError("knowledge should have a valid id")
|
||||
return str(knowledge.id)
|
||||
|
||||
async def upload_file_storage(
|
||||
self,
|
||||
knowledge: KnowledgeDB,
|
||||
knowledge_data: FileIO | BufferedReader | bytes,
|
||||
upsert: bool = False,
|
||||
):
|
||||
await self._set_client()
|
||||
assert self.client
|
||||
|
||||
mime_type = "application/html"
|
||||
if knowledge.file_name:
|
||||
guessed_mime_type, _ = mimetypes.guess_type(knowledge.file_name)
|
||||
mime_type = guessed_mime_type or mime_type
|
||||
|
||||
storage_path = self.get_storage_path(knowledge)
|
||||
logger.info(
|
||||
f"Uploading file to s3://quivr/{storage_path} using supabase. upsert={upsert}, mimetype={mime_type}"
|
||||
)
|
||||
|
||||
if upsert:
|
||||
_ = await self.client.storage.from_("quivr").update(
|
||||
storage_path,
|
||||
knowledge_data,
|
||||
file_options={
|
||||
"content-type": mime_type,
|
||||
"upsert": "true",
|
||||
"cache-control": "3600",
|
||||
},
|
||||
)
|
||||
return storage_path
|
||||
else:
|
||||
# check if file sha1 is already in storage
|
||||
try:
|
||||
_ = await self.client.storage.from_("quivr").upload(
|
||||
storage_path,
|
||||
knowledge_data,
|
||||
file_options={
|
||||
"content-type": mime_type,
|
||||
"upsert": "false",
|
||||
"cache-control": "3600",
|
||||
},
|
||||
)
|
||||
return storage_path
|
||||
|
||||
except Exception as e:
|
||||
if "The resource already exists" in str(e) and not upsert:
|
||||
raise FileExistsError(f"File {storage_path} already exists")
|
||||
raise e
|
||||
|
||||
async def remove_file(self, storage_path: str):
|
||||
"""
|
||||
Remove file from storage
|
||||
"""
|
||||
await self._set_client()
|
||||
assert self.client
|
||||
try:
|
||||
response = self.db.storage.from_("quivr").remove([file_name])
|
||||
response = await self.client.storage.from_("quivr").remove([storage_path])
|
||||
return response
|
||||
except Exception as e:
|
||||
logger.error(e)
|
||||
# raise e
|
||||
|
||||
|
@ -1,10 +1,26 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from io import BufferedReader, FileIO
|
||||
|
||||
from quivr_api.modules.knowledge.entity.knowledge import KnowledgeDB
|
||||
|
||||
|
||||
class StorageInterface(ABC):
|
||||
@abstractmethod
|
||||
def remove_file(self, file_name: str):
|
||||
"""
|
||||
Remove file from storage
|
||||
"""
|
||||
def get_storage_path(
|
||||
self,
|
||||
knowledge: KnowledgeDB,
|
||||
) -> str:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def upload_file_storage(
|
||||
self,
|
||||
knowledge: KnowledgeDB,
|
||||
knowledge_data: FileIO | BufferedReader | bytes,
|
||||
upsert: bool = False,
|
||||
):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def remove_file(self, storage_path: str):
|
||||
pass
|
||||
|
@ -0,0 +1,34 @@
|
||||
class KnowledgeException(Exception):
|
||||
def __init__(self, message="A knowledge-related error occurred"):
|
||||
self.message = message
|
||||
super().__init__(self.message)
|
||||
|
||||
|
||||
class UploadError(KnowledgeException):
|
||||
def __init__(self, message="An error occurred while uploading"):
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class KnowledgeCreationError(KnowledgeException):
|
||||
def __init__(self, message="An error occurred while creating"):
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class KnowledgeUpdateError(KnowledgeException):
|
||||
def __init__(self, message="An error occurred while updating"):
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class KnowledgeDeleteError(KnowledgeException):
|
||||
def __init__(self, message="An error occurred while deleting"):
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class KnowledgeForbiddenAccess(KnowledgeException):
|
||||
def __init__(self, message="You do not have permission to access this knowledge."):
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class KnowledgeNotFoundException(KnowledgeException):
|
||||
def __init__(self, message="The requested knowledge was not found"):
|
||||
super().__init__(message)
|
@ -1,18 +1,33 @@
|
||||
from typing import List
|
||||
import asyncio
|
||||
import io
|
||||
from typing import Any, List
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import UploadFile
|
||||
from quivr_core.models import KnowledgeStatus
|
||||
from sqlalchemy.exc import NoResultFound
|
||||
|
||||
from quivr_api.logger import get_logger
|
||||
from quivr_api.modules.dependencies import BaseService
|
||||
from quivr_api.modules.knowledge.dto.inputs import (
|
||||
AddKnowledge,
|
||||
CreateKnowledgeProperties,
|
||||
)
|
||||
from quivr_api.modules.knowledge.dto.outputs import DeleteKnowledgeResponse
|
||||
from quivr_api.modules.knowledge.entity.knowledge import Knowledge, KnowledgeDB
|
||||
from quivr_api.modules.knowledge.entity.knowledge import (
|
||||
Knowledge,
|
||||
KnowledgeDB,
|
||||
KnowledgeSource,
|
||||
KnowledgeUpdate,
|
||||
)
|
||||
from quivr_api.modules.knowledge.repository.knowledges import KnowledgeRepository
|
||||
from quivr_api.modules.knowledge.repository.storage import Storage
|
||||
from quivr_api.modules.knowledge.repository.storage import SupabaseS3Storage
|
||||
from quivr_api.modules.knowledge.repository.storage_interface import StorageInterface
|
||||
from quivr_api.modules.knowledge.service.knowledge_exceptions import (
|
||||
KnowledgeDeleteError,
|
||||
KnowledgeForbiddenAccess,
|
||||
UploadError,
|
||||
)
|
||||
from quivr_api.modules.sync.entity.sync_models import (
|
||||
DBSyncFile,
|
||||
DownloadedSyncFile,
|
||||
@ -26,9 +41,13 @@ logger = get_logger(__name__)
|
||||
class KnowledgeService(BaseService[KnowledgeRepository]):
|
||||
repository_cls = KnowledgeRepository
|
||||
|
||||
def __init__(self, repository: KnowledgeRepository):
|
||||
def __init__(
|
||||
self,
|
||||
repository: KnowledgeRepository,
|
||||
storage: StorageInterface = SupabaseS3Storage(),
|
||||
):
|
||||
self.repository = repository
|
||||
self.storage = Storage()
|
||||
self.storage = storage
|
||||
|
||||
async def get_knowledge_sync(self, sync_id: int) -> Knowledge:
|
||||
km = await self.repository.get_knowledge_by_sync_id(sync_id)
|
||||
@ -54,19 +73,37 @@ class KnowledgeService(BaseService[KnowledgeRepository]):
|
||||
except NoResultFound:
|
||||
raise FileNotFoundError(f"No knowledge for file_name: {file_name}")
|
||||
|
||||
async def get_knowledge(self, knowledge_id: UUID) -> Knowledge:
|
||||
inserted_knowledge_db_instance = await self.repository.get_knowledge_by_id(
|
||||
knowledge_id
|
||||
async def list_knowledge(
|
||||
self, knowledge_id: UUID | None, user_id: UUID | None = None
|
||||
) -> list[KnowledgeDB]:
|
||||
if knowledge_id is not None:
|
||||
km = await self.repository.get_knowledge_by_id(knowledge_id, user_id)
|
||||
return km.children
|
||||
else:
|
||||
if user_id is None:
|
||||
raise KnowledgeForbiddenAccess(
|
||||
"can't get root knowledges without user_id"
|
||||
)
|
||||
assert inserted_knowledge_db_instance.id, "Knowledge ID not generated"
|
||||
km = await inserted_knowledge_db_instance.to_dto()
|
||||
return km
|
||||
return await self.repository.get_root_knowledge_user(user_id)
|
||||
|
||||
async def get_knowledge(
|
||||
self, knowledge_id: UUID, user_id: UUID | None = None
|
||||
) -> KnowledgeDB:
|
||||
return await self.repository.get_knowledge_by_id(knowledge_id, user_id)
|
||||
|
||||
async def update_knowledge(
|
||||
self,
|
||||
knowledge: KnowledgeDB,
|
||||
payload: Knowledge | KnowledgeUpdate | dict[str, Any],
|
||||
):
|
||||
return await self.repository.update_knowledge(knowledge, payload)
|
||||
|
||||
# TODO: Remove all of this
|
||||
# TODO (@aminediro): Replace with ON CONFLICT smarter query...
|
||||
# there is a chance of race condition but for now we let it crash in worker
|
||||
# the tasks will be dealt with on retry
|
||||
async def update_sha1_conflict(
|
||||
self, knowledge: Knowledge, brain_id: UUID, file_sha1: str
|
||||
self, knowledge: KnowledgeDB, brain_id: UUID, file_sha1: str
|
||||
) -> bool:
|
||||
assert knowledge.id
|
||||
knowledge.file_sha1 = file_sha1
|
||||
@ -89,12 +126,12 @@ class KnowledgeService(BaseService[KnowledgeRepository]):
|
||||
)
|
||||
else:
|
||||
await self.repository.link_to_brain(existing_knowledge, brain_id)
|
||||
await self.remove_knowledge(brain_id, knowledge.id)
|
||||
await self.remove_knowledge_brain(brain_id, knowledge.id)
|
||||
return False
|
||||
else:
|
||||
logger.debug(f"Removing previous errored file {existing_knowledge.id}")
|
||||
assert existing_knowledge.id
|
||||
await self.remove_knowledge(brain_id, existing_knowledge.id)
|
||||
await self.remove_knowledge_brain(brain_id, existing_knowledge.id)
|
||||
await self.update_file_sha1_knowledge(knowledge.id, knowledge.file_sha1)
|
||||
return True
|
||||
except NoResultFound:
|
||||
@ -104,7 +141,47 @@ class KnowledgeService(BaseService[KnowledgeRepository]):
|
||||
await self.update_file_sha1_knowledge(knowledge.id, knowledge.file_sha1)
|
||||
return True
|
||||
|
||||
async def insert_knowledge(
|
||||
async def create_knowledge(
|
||||
self,
|
||||
user_id: UUID,
|
||||
knowledge_to_add: AddKnowledge,
|
||||
upload_file: UploadFile | None = None,
|
||||
) -> KnowledgeDB:
|
||||
knowledgedb = KnowledgeDB(
|
||||
user_id=user_id,
|
||||
file_name=knowledge_to_add.file_name,
|
||||
is_folder=knowledge_to_add.is_folder,
|
||||
url=knowledge_to_add.url,
|
||||
extension=knowledge_to_add.extension,
|
||||
source=knowledge_to_add.source,
|
||||
source_link=knowledge_to_add.source_link,
|
||||
file_size=upload_file.size if upload_file else 0,
|
||||
metadata_=knowledge_to_add.metadata, # type: ignore
|
||||
status=KnowledgeStatus.RESERVED,
|
||||
parent_id=knowledge_to_add.parent_id,
|
||||
)
|
||||
knowledge_db = await self.repository.create_knowledge(knowledgedb)
|
||||
try:
|
||||
if knowledgedb.source == KnowledgeSource.LOCAL and upload_file:
|
||||
# NOTE(@aminediro): Unnecessary mem buffer because supabase doesnt accept FileIO..
|
||||
buff_reader = io.BufferedReader(upload_file.file) # type: ignore
|
||||
storage_path = await self.storage.upload_file_storage(
|
||||
knowledgedb, buff_reader
|
||||
)
|
||||
knowledgedb.source_link = storage_path
|
||||
knowledge_db = await self.repository.update_knowledge(
|
||||
knowledge_db,
|
||||
KnowledgeUpdate(status=KnowledgeStatus.UPLOADED), # type: ignore
|
||||
)
|
||||
return knowledge_db
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
f"Error uploading knowledge {knowledgedb.id} to storage : {e}"
|
||||
)
|
||||
await self.repository.remove_knowledge(knowledge=knowledge_db)
|
||||
raise UploadError()
|
||||
|
||||
async def insert_knowledge_brain(
|
||||
self,
|
||||
user_id: UUID,
|
||||
knowledge_to_add: CreateKnowledgeProperties, # FIXME: (later) @Amine brain id should not be in CreateKnowledgeProperties but since storage is brain_id/file_name
|
||||
@ -122,7 +199,7 @@ class KnowledgeService(BaseService[KnowledgeRepository]):
|
||||
user_id=user_id,
|
||||
)
|
||||
|
||||
knowledge_db = await self.repository.insert_knowledge(
|
||||
knowledge_db = await self.repository.insert_knowledge_brain(
|
||||
knowledge, brain_id=knowledge_to_add.brain_id
|
||||
)
|
||||
|
||||
@ -150,7 +227,7 @@ class KnowledgeService(BaseService[KnowledgeRepository]):
|
||||
assert isinstance(knowledge.file_name, str), "file_name should be a string"
|
||||
file_name_with_brain_id = f"{brain_id}/{knowledge.file_name}"
|
||||
try:
|
||||
self.storage.remove_file(file_name_with_brain_id)
|
||||
await self.storage.remove_file(file_name_with_brain_id)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error while removing file {file_name_with_brain_id}: {e}"
|
||||
@ -161,29 +238,52 @@ class KnowledgeService(BaseService[KnowledgeRepository]):
|
||||
async def update_file_sha1_knowledge(self, knowledge_id: UUID, file_sha1: str):
|
||||
return await self.repository.update_file_sha1_knowledge(knowledge_id, file_sha1)
|
||||
|
||||
async def remove_knowledge(
|
||||
async def remove_knowledge(self, knowledge: KnowledgeDB) -> DeleteKnowledgeResponse:
|
||||
assert knowledge.id
|
||||
|
||||
try:
|
||||
# TODO:
|
||||
# - Notion folders are special, they are themselves files and should be removed from storage
|
||||
children = await self.repository.get_all_children(knowledge.id)
|
||||
km_paths = [
|
||||
self.storage.get_storage_path(k) for k in children if not k.is_folder
|
||||
]
|
||||
if not knowledge.is_folder:
|
||||
km_paths.append(self.storage.get_storage_path(knowledge))
|
||||
|
||||
# recursively deletes files
|
||||
deleted_km = await self.repository.remove_knowledge(knowledge)
|
||||
await asyncio.gather(*[self.storage.remove_file(p) for p in km_paths])
|
||||
|
||||
return deleted_km
|
||||
except Exception as e:
|
||||
logger.error(f"Error while remove knowledge : {e}")
|
||||
raise KnowledgeDeleteError
|
||||
|
||||
async def remove_knowledge_brain(
|
||||
self,
|
||||
brain_id: UUID,
|
||||
knowledge_id: UUID, # FIXME: @amine when name in storage change no need for brain id
|
||||
) -> DeleteKnowledgeResponse:
|
||||
# TODO: fix KMS
|
||||
# REDO ALL THIS
|
||||
knowledge = await self.get_knowledge(knowledge_id)
|
||||
if len(knowledge.brain_ids) > 1:
|
||||
knowledge = await self.repository.get_knowledge_by_id(knowledge_id)
|
||||
km_brains = await knowledge.awaitable_attrs.brains
|
||||
if len(km_brains) > 1:
|
||||
km = await self.repository.remove_knowledge_from_brain(
|
||||
knowledge_id, brain_id
|
||||
)
|
||||
assert km.id
|
||||
return DeleteKnowledgeResponse(file_name=km.file_name, knowledge_id=km.id)
|
||||
else:
|
||||
message = await self.repository.remove_knowledge_by_id(knowledge_id)
|
||||
file_name_with_brain_id = f"{brain_id}/{message.file_name}"
|
||||
try:
|
||||
self.storage.remove_file(file_name_with_brain_id)
|
||||
await self.storage.remove_file(file_name_with_brain_id)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error while removing file {file_name_with_brain_id}: {e}"
|
||||
)
|
||||
|
||||
return message
|
||||
|
||||
async def remove_all_knowledges_from_brain(self, brain_id: UUID) -> None:
|
||||
@ -210,7 +310,7 @@ class KnowledgeService(BaseService[KnowledgeRepository]):
|
||||
# TODO: THIS IS A HACK!! Remove all of this
|
||||
if prev_sync_file:
|
||||
prev_knowledge = await self.get_knowledge_sync(sync_id=prev_sync_file.id)
|
||||
if len(prev_knowledge.brain_ids) > 1:
|
||||
if len(prev_knowledge.brains) > 1:
|
||||
await self.repository.remove_knowledge_from_brain(
|
||||
prev_knowledge.id, brain_id
|
||||
)
|
||||
@ -231,7 +331,7 @@ class KnowledgeService(BaseService[KnowledgeRepository]):
|
||||
file_sha1=None,
|
||||
metadata={"sync_file_id": str(sync_id)},
|
||||
)
|
||||
added_knowledge = await self.insert_knowledge(
|
||||
added_knowledge = await self.insert_knowledge_brain(
|
||||
knowledge_to_add=knowledge_to_add, user_id=user_id
|
||||
)
|
||||
return added_knowledge
|
||||
|
67
backend/api/quivr_api/modules/knowledge/tests/conftest.py
Normal file
67
backend/api/quivr_api/modules/knowledge/tests/conftest.py
Normal file
@ -0,0 +1,67 @@
|
||||
from io import BufferedReader, FileIO
|
||||
|
||||
from quivr_api.modules.knowledge.entity.knowledge import Knowledge, KnowledgeDB
|
||||
from quivr_api.modules.knowledge.repository.storage_interface import StorageInterface
|
||||
|
||||
|
||||
class ErrorStorage(StorageInterface):
|
||||
async def upload_file_storage(
|
||||
self,
|
||||
knowledge: KnowledgeDB,
|
||||
knowledge_data: FileIO | BufferedReader | bytes,
|
||||
upsert: bool = False,
|
||||
):
|
||||
raise SystemError
|
||||
|
||||
def get_storage_path(
|
||||
self,
|
||||
knowledge: KnowledgeDB | Knowledge,
|
||||
) -> str:
|
||||
if knowledge.id is None:
|
||||
raise ValueError("knowledge should have a valid id")
|
||||
return str(knowledge.id)
|
||||
|
||||
async def remove_file(self, storage_path: str):
|
||||
raise SystemError
|
||||
|
||||
|
||||
class FakeStorage(StorageInterface):
|
||||
def __init__(self):
|
||||
self.storage = {}
|
||||
|
||||
def get_storage_path(
|
||||
self,
|
||||
knowledge: KnowledgeDB | Knowledge,
|
||||
) -> str:
|
||||
if knowledge.id is None:
|
||||
raise ValueError("knowledge should have a valid id")
|
||||
return str(knowledge.id)
|
||||
|
||||
async def upload_file_storage(
|
||||
self,
|
||||
knowledge: KnowledgeDB,
|
||||
knowledge_data: FileIO | BufferedReader | bytes,
|
||||
upsert: bool = False,
|
||||
):
|
||||
storage_path = f"{knowledge.id}"
|
||||
if not upsert and storage_path in self.storage:
|
||||
raise ValueError(f"File already exists at {storage_path}")
|
||||
self.storage[storage_path] = knowledge_data
|
||||
return storage_path
|
||||
|
||||
async def remove_file(self, storage_path: str):
|
||||
if storage_path not in self.storage:
|
||||
raise FileNotFoundError(f"File not found at {storage_path}")
|
||||
del self.storage[storage_path]
|
||||
|
||||
# Additional helper methods for testing
|
||||
def get_file(self, storage_path: str) -> FileIO | BufferedReader | bytes:
|
||||
if storage_path not in self.storage:
|
||||
raise FileNotFoundError(f"File not found at {storage_path}")
|
||||
return self.storage[storage_path]
|
||||
|
||||
def knowledge_exists(self, knowledge: KnowledgeDB | Knowledge) -> bool:
|
||||
return self.get_storage_path(knowledge) in self.storage
|
||||
|
||||
def clear_storage(self):
|
||||
self.storage.clear()
|
@ -0,0 +1,74 @@
|
||||
import json
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
from httpx import ASGITransport, AsyncClient
|
||||
from sqlmodel import select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from quivr_api.main import app
|
||||
from quivr_api.middlewares.auth.auth_bearer import get_current_user
|
||||
from quivr_api.modules.knowledge.controller.knowledge_routes import get_km_service
|
||||
from quivr_api.modules.knowledge.repository.knowledges import KnowledgeRepository
|
||||
from quivr_api.modules.knowledge.service.knowledge_service import KnowledgeService
|
||||
from quivr_api.modules.knowledge.tests.conftest import FakeStorage
|
||||
from quivr_api.modules.user.entity.user_identity import User, UserIdentity
|
||||
|
||||
|
||||
@pytest_asyncio.fixture(scope="function")
|
||||
async def user(session: AsyncSession) -> User:
|
||||
user_1 = (
|
||||
await session.exec(select(User).where(User.email == "admin@quivr.app"))
|
||||
).one()
|
||||
assert user_1.id
|
||||
return user_1
|
||||
|
||||
|
||||
@pytest_asyncio.fixture(scope="function")
|
||||
async def test_client(session: AsyncSession, user: User):
|
||||
def default_current_user() -> UserIdentity:
|
||||
assert user.id
|
||||
return UserIdentity(email=user.email, id=user.id)
|
||||
|
||||
async def test_service():
|
||||
storage = FakeStorage()
|
||||
repository = KnowledgeRepository(session)
|
||||
return KnowledgeService(repository, storage)
|
||||
|
||||
app.dependency_overrides[get_current_user] = default_current_user
|
||||
app.dependency_overrides[get_km_service] = test_service
|
||||
# app.dependency_overrides[get_async_session] = lambda: session
|
||||
|
||||
async with AsyncClient(
|
||||
transport=ASGITransport(app=app), base_url="http://test"
|
||||
) as ac:
|
||||
yield ac
|
||||
app.dependency_overrides = {}
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_post_knowledge(test_client: AsyncClient):
|
||||
km_data = {
|
||||
"file_name": "test_file.txt",
|
||||
"source": "local",
|
||||
"is_folder": False,
|
||||
"parent_id": None,
|
||||
}
|
||||
|
||||
multipart_data = {
|
||||
"knowledge_data": (None, json.dumps(km_data), "application/json"),
|
||||
"file": ("test_file.txt", b"Test file content", "application/octet-stream"),
|
||||
}
|
||||
|
||||
response = await test_client.post(
|
||||
"/knowledge/",
|
||||
files=multipart_data,
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_add_knowledge_invalid_input(test_client):
|
||||
response = await test_client.post("/knowledge/", files={})
|
||||
assert response.status_code == 422
|
@ -0,0 +1,229 @@
|
||||
from typing import List, Tuple
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
from quivr_core.models import KnowledgeStatus
|
||||
from sqlmodel import select, text
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from quivr_api.modules.brain.entity.brain_entity import Brain, BrainType
|
||||
from quivr_api.modules.knowledge.entity.knowledge import KnowledgeDB
|
||||
from quivr_api.modules.user.entity.user_identity import User
|
||||
|
||||
TestData = Tuple[Brain, List[KnowledgeDB]]
|
||||
|
||||
|
||||
@pytest_asyncio.fixture(scope="function")
|
||||
async def other_user(session: AsyncSession):
|
||||
sql = text(
|
||||
"""
|
||||
INSERT INTO "auth"."users" ("instance_id", "id", "aud", "role", "email", "encrypted_password", "email_confirmed_at", "invited_at", "confirmation_token", "confirmation_sent_at", "recovery_token", "recovery_sent_at", "email_change_token_new", "email_change", "email_change_sent_at", "last_sign_in_at", "raw_app_meta_data", "raw_user_meta_data", "is_super_admin", "created_at", "updated_at", "phone", "phone_confirmed_at", "phone_change", "phone_change_token", "phone_change_sent_at", "email_change_token_current", "email_change_confirm_status", "banned_until", "reauthentication_token", "reauthentication_sent_at", "is_sso_user", "deleted_at") VALUES
|
||||
('00000000-0000-0000-0000-000000000000', :id , 'authenticated', 'authenticated', 'other@quivr.app', '$2a$10$vwKX0eMLlrOZvxQEA3Vl4e5V4/hOuxPjGYn9QK1yqeaZxa.42Uhze', '2024-01-22 22:27:00.166861+00', NULL, '', NULL, 'e91d41043ca2c83c3be5a6ee7a4abc8a4f4fb1afc0a8453c502af931', '2024-03-05 16:22:13.780421+00', '', '', NULL, '2024-03-30 23:21:12.077887+00', '{"provider": "email", "providers": ["email"]}', '{}', NULL, '2024-01-22 22:27:00.158026+00', '2024-04-01 17:40:15.332205+00', NULL, NULL, '', '', NULL, '', 0, NULL, '', NULL, false, NULL);
|
||||
"""
|
||||
)
|
||||
await session.execute(sql, params={"id": uuid4()})
|
||||
|
||||
other_user = (
|
||||
await session.exec(select(User).where(User.email == "other@quivr.app"))
|
||||
).one()
|
||||
return other_user
|
||||
|
||||
|
||||
@pytest_asyncio.fixture(scope="function")
|
||||
async def user(session):
|
||||
user_1 = (
|
||||
await session.exec(select(User).where(User.email == "admin@quivr.app"))
|
||||
).one()
|
||||
return user_1
|
||||
|
||||
|
||||
@pytest_asyncio.fixture(scope="function")
|
||||
async def brain(session):
|
||||
brain_1 = Brain(
|
||||
name="test_brain",
|
||||
description="this is a test brain",
|
||||
brain_type=BrainType.integration,
|
||||
)
|
||||
session.add(brain_1)
|
||||
await session.commit()
|
||||
return brain_1
|
||||
|
||||
|
||||
@pytest_asyncio.fixture(scope="function")
|
||||
async def folder(session, user):
|
||||
folder = KnowledgeDB(
|
||||
file_name="folder_1",
|
||||
extension="",
|
||||
status="UPLOADED",
|
||||
source="local",
|
||||
source_link="local",
|
||||
file_size=4,
|
||||
file_sha1=None,
|
||||
brains=[],
|
||||
children=[],
|
||||
user_id=user.id,
|
||||
is_folder=True,
|
||||
)
|
||||
|
||||
session.add(folder)
|
||||
await session.commit()
|
||||
await session.refresh(folder)
|
||||
return folder
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_knowledge_default_file(session, folder, user):
|
||||
km = KnowledgeDB(
|
||||
file_name="test_file_1.txt",
|
||||
extension=".txt",
|
||||
status="UPLOADED",
|
||||
source="test_source",
|
||||
source_link="test_source_link",
|
||||
file_size=100,
|
||||
file_sha1="test_sha1",
|
||||
brains=[],
|
||||
user_id=user.id,
|
||||
parent_id=folder.id,
|
||||
)
|
||||
session.add(km)
|
||||
await session.commit()
|
||||
await session.refresh(km)
|
||||
|
||||
assert not km.is_folder
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_knowledge_parent(session: AsyncSession, user: User):
|
||||
assert user.id
|
||||
|
||||
km = KnowledgeDB(
|
||||
file_name="test_file_1.txt",
|
||||
extension=".txt",
|
||||
status="UPLOADED",
|
||||
source="test_source",
|
||||
source_link="test_source_link",
|
||||
file_size=100,
|
||||
file_sha1="test_sha1",
|
||||
brains=[],
|
||||
user_id=user.id,
|
||||
)
|
||||
|
||||
folder = KnowledgeDB(
|
||||
file_name="folder_1",
|
||||
extension="",
|
||||
is_folder=True,
|
||||
status="UPLOADED",
|
||||
source="local",
|
||||
source_link="local",
|
||||
file_size=-1,
|
||||
file_sha1=None,
|
||||
brains=[],
|
||||
children=[km],
|
||||
user_id=user.id,
|
||||
)
|
||||
|
||||
session.add(folder)
|
||||
await session.commit()
|
||||
await session.refresh(folder)
|
||||
await session.refresh(km)
|
||||
|
||||
parent = await km.awaitable_attrs.parent
|
||||
assert km.parent_id == folder.id, "parent_id isn't set to folder id"
|
||||
assert parent.id == folder.id, "parent_id isn't set to folder id"
|
||||
assert parent.is_folder
|
||||
|
||||
query = select(KnowledgeDB).where(KnowledgeDB.id == folder.id)
|
||||
folder = (await session.exec(query)).first()
|
||||
assert folder
|
||||
|
||||
children = await folder.awaitable_attrs.children
|
||||
assert len(children) > 0
|
||||
|
||||
assert children[0].id == km.id
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_knowledge_remove_folder_cascade(
|
||||
session: AsyncSession,
|
||||
folder: KnowledgeDB,
|
||||
user,
|
||||
):
|
||||
km = KnowledgeDB(
|
||||
file_name="test_file_1.txt",
|
||||
extension=".txt",
|
||||
status="UPLOADED",
|
||||
source="test_source",
|
||||
source_link="test_source_link",
|
||||
file_size=100,
|
||||
file_sha1="test_sha1",
|
||||
brains=[],
|
||||
user_id=user.id,
|
||||
parent_id=folder.id,
|
||||
)
|
||||
session.add(km)
|
||||
await session.commit()
|
||||
await session.refresh(km)
|
||||
|
||||
# Check all removed
|
||||
await session.delete(folder)
|
||||
await session.commit()
|
||||
|
||||
statement = select(KnowledgeDB)
|
||||
results = (await session.exec(statement)).all()
|
||||
assert results == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_knowledge_dto(session, user, brain):
|
||||
# add folder in brain
|
||||
folder = KnowledgeDB(
|
||||
file_name="folder_1",
|
||||
extension="",
|
||||
status="UPLOADED",
|
||||
source="local",
|
||||
source_link="local",
|
||||
file_size=4,
|
||||
file_sha1=None,
|
||||
brains=[brain],
|
||||
children=[],
|
||||
user_id=user.id,
|
||||
is_folder=True,
|
||||
)
|
||||
km = KnowledgeDB(
|
||||
file_name="test_file_1.txt",
|
||||
extension=".txt",
|
||||
status="UPLOADED",
|
||||
source="test_source",
|
||||
source_link="test_source_link",
|
||||
file_size=100,
|
||||
file_sha1="test_sha1",
|
||||
user_id=user.id,
|
||||
brains=[brain],
|
||||
parent=folder,
|
||||
)
|
||||
session.add(km)
|
||||
session.add(km)
|
||||
await session.commit()
|
||||
await session.refresh(km)
|
||||
|
||||
km_dto = await km.to_dto()
|
||||
|
||||
assert km_dto.file_name == km.file_name
|
||||
assert km_dto.url == km.url
|
||||
assert km_dto.extension == km.extension
|
||||
assert km_dto.status == KnowledgeStatus(km.status)
|
||||
assert km_dto.source == km.source
|
||||
assert km_dto.source_link == km.source_link
|
||||
assert km_dto.is_folder == km.is_folder
|
||||
assert km_dto.file_size == km.file_size
|
||||
assert km_dto.file_sha1 == km.file_sha1
|
||||
assert km_dto.updated_at == km.updated_at
|
||||
assert km_dto.created_at == km.created_at
|
||||
assert km_dto.metadata == km.metadata_ # type: ignor
|
||||
assert km_dto.parent
|
||||
assert km_dto.parent.id == folder.id
|
||||
|
||||
folder_dto = await folder.to_dto()
|
||||
assert folder_dto.brains[0] == brain.model_dump()
|
||||
assert folder_dto.children == [await km.to_dto()]
|
File diff suppressed because it is too large
Load Diff
@ -1,450 +0,0 @@
|
||||
import os
|
||||
from typing import List, Tuple
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
from sqlalchemy.exc import IntegrityError, NoResultFound
|
||||
from sqlmodel import select, text
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from quivr_api.modules.brain.entity.brain_entity import Brain, BrainType
|
||||
from quivr_api.modules.knowledge.dto.inputs import KnowledgeStatus
|
||||
from quivr_api.modules.knowledge.entity.knowledge import KnowledgeDB
|
||||
from quivr_api.modules.knowledge.entity.knowledge_brain import KnowledgeBrain
|
||||
from quivr_api.modules.knowledge.repository.knowledges import KnowledgeRepository
|
||||
from quivr_api.modules.knowledge.service.knowledge_service import KnowledgeService
|
||||
from quivr_api.modules.upload.service.upload_file import upload_file_storage
|
||||
from quivr_api.modules.user.entity.user_identity import User
|
||||
from quivr_api.modules.vector.entity.vector import Vector
|
||||
|
||||
pg_database_base_url = "postgres:postgres@localhost:54322/postgres"
|
||||
|
||||
TestData = Tuple[Brain, List[KnowledgeDB]]
|
||||
|
||||
|
||||
@pytest_asyncio.fixture(scope="function")
|
||||
async def other_user(session: AsyncSession):
|
||||
sql = text(
|
||||
"""
|
||||
INSERT INTO "auth"."users" ("instance_id", "id", "aud", "role", "email", "encrypted_password", "email_confirmed_at", "invited_at", "confirmation_token", "confirmation_sent_at", "recovery_token", "recovery_sent_at", "email_change_token_new", "email_change", "email_change_sent_at", "last_sign_in_at", "raw_app_meta_data", "raw_user_meta_data", "is_super_admin", "created_at", "updated_at", "phone", "phone_confirmed_at", "phone_change", "phone_change_token", "phone_change_sent_at", "email_change_token_current", "email_change_confirm_status", "banned_until", "reauthentication_token", "reauthentication_sent_at", "is_sso_user", "deleted_at") VALUES
|
||||
('00000000-0000-0000-0000-000000000000', :id , 'authenticated', 'authenticated', 'other@quivr.app', '$2a$10$vwKX0eMLlrOZvxQEA3Vl4e5V4/hOuxPjGYn9QK1yqeaZxa.42Uhze', '2024-01-22 22:27:00.166861+00', NULL, '', NULL, 'e91d41043ca2c83c3be5a6ee7a4abc8a4f4fb1afc0a8453c502af931', '2024-03-05 16:22:13.780421+00', '', '', NULL, '2024-03-30 23:21:12.077887+00', '{"provider": "email", "providers": ["email"]}', '{}', NULL, '2024-01-22 22:27:00.158026+00', '2024-04-01 17:40:15.332205+00', NULL, NULL, '', '', NULL, '', 0, NULL, '', NULL, false, NULL);
|
||||
"""
|
||||
)
|
||||
await session.execute(sql, params={"id": uuid4()})
|
||||
|
||||
other_user = (
|
||||
await session.exec(select(User).where(User.email == "other@quivr.app"))
|
||||
).one()
|
||||
return other_user
|
||||
|
||||
|
||||
@pytest_asyncio.fixture(scope="function")
|
||||
async def test_data(session: AsyncSession) -> TestData:
|
||||
user_1 = (
|
||||
await session.exec(select(User).where(User.email == "admin@quivr.app"))
|
||||
).one()
|
||||
assert user_1.id
|
||||
# Brain data
|
||||
brain_1 = Brain(
|
||||
name="test_brain",
|
||||
description="this is a test brain",
|
||||
brain_type=BrainType.integration,
|
||||
)
|
||||
|
||||
knowledge_brain_1 = KnowledgeDB(
|
||||
file_name="test_file_1.txt",
|
||||
extension=".txt",
|
||||
status="UPLOADED",
|
||||
source="test_source",
|
||||
source_link="test_source_link",
|
||||
file_size=100,
|
||||
file_sha1="test_sha1",
|
||||
brains=[brain_1],
|
||||
user_id=user_1.id,
|
||||
)
|
||||
|
||||
knowledge_brain_2 = KnowledgeDB(
|
||||
file_name="test_file_2.txt",
|
||||
extension=".txt",
|
||||
status="UPLOADED",
|
||||
source="test_source",
|
||||
source_link="test_source_link",
|
||||
file_size=100,
|
||||
file_sha1="test_sha2",
|
||||
brains=[],
|
||||
user_id=user_1.id,
|
||||
)
|
||||
|
||||
session.add(brain_1)
|
||||
session.add(knowledge_brain_1)
|
||||
session.add(knowledge_brain_2)
|
||||
await session.commit()
|
||||
return brain_1, [knowledge_brain_1, knowledge_brain_2]
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_updates_knowledge_status(session: AsyncSession, test_data: TestData):
|
||||
brain, knowledges = test_data
|
||||
assert brain.brain_id
|
||||
assert knowledges[0].id
|
||||
repo = KnowledgeRepository(session)
|
||||
await repo.update_status_knowledge(knowledges[0].id, KnowledgeStatus.ERROR)
|
||||
knowledge = await repo.get_knowledge_by_id(knowledges[0].id)
|
||||
assert knowledge.status == KnowledgeStatus.ERROR
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_updates_knowledge_status_no_knowledge(
|
||||
session: AsyncSession, test_data: TestData
|
||||
):
|
||||
brain, knowledges = test_data
|
||||
assert brain.brain_id
|
||||
assert knowledges[0].id
|
||||
repo = KnowledgeRepository(session)
|
||||
with pytest.raises(NoResultFound):
|
||||
await repo.update_status_knowledge(uuid4(), KnowledgeStatus.UPLOADED)
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_update_knowledge_source_link(session: AsyncSession, test_data: TestData):
|
||||
brain, knowledges = test_data
|
||||
assert brain.brain_id
|
||||
assert knowledges[0].id
|
||||
repo = KnowledgeRepository(session)
|
||||
await repo.update_source_link_knowledge(knowledges[0].id, "new_source_link")
|
||||
knowledge = await repo.get_knowledge_by_id(knowledges[0].id)
|
||||
assert knowledge.source_link == "new_source_link"
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_remove_knowledge_from_brain(session: AsyncSession, test_data: TestData):
|
||||
brain, knowledges = test_data
|
||||
assert brain.brain_id
|
||||
assert knowledges[0].id
|
||||
repo = KnowledgeRepository(session)
|
||||
knowledge = await repo.remove_knowledge_from_brain(knowledges[0].id, brain.brain_id)
|
||||
assert brain.brain_id not in [
|
||||
b.brain_id for b in await knowledge.awaitable_attrs.brains
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_cascade_remove_knowledge_by_id(
|
||||
session: AsyncSession, test_data: TestData
|
||||
):
|
||||
brain, knowledges = test_data
|
||||
assert brain.brain_id
|
||||
assert knowledges[0].id
|
||||
repo = KnowledgeRepository(session)
|
||||
await repo.remove_knowledge_by_id(knowledges[0].id)
|
||||
with pytest.raises(NoResultFound):
|
||||
await repo.get_knowledge_by_id(knowledges[0].id)
|
||||
|
||||
query = select(KnowledgeBrain).where(
|
||||
KnowledgeBrain.knowledge_id == knowledges[0].id
|
||||
)
|
||||
result = await session.exec(query)
|
||||
knowledge_brain = result.first()
|
||||
assert knowledge_brain is None
|
||||
|
||||
query = select(Vector).where(Vector.knowledge_id == knowledges[0].id)
|
||||
result = await session.exec(query)
|
||||
vector = result.first()
|
||||
assert vector is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_remove_all_knowledges_from_brain(
|
||||
session: AsyncSession, test_data: TestData
|
||||
):
|
||||
brain, knowledges = test_data
|
||||
assert brain.brain_id
|
||||
|
||||
# supabase_client = get_supabase_client()
|
||||
# db = supabase_client
|
||||
# storage = db.storage.from_("quivr")
|
||||
|
||||
# storage.upload(f"{brain.brain_id}/test_file_1", b"test_content")
|
||||
|
||||
repo = KnowledgeRepository(session)
|
||||
service = KnowledgeService(repo)
|
||||
await repo.remove_all_knowledges_from_brain(brain.brain_id)
|
||||
knowledges = await service.get_all_knowledge_in_brain(brain.brain_id)
|
||||
assert len(knowledges) == 0
|
||||
|
||||
# response = storage.list(path=f"{brain.brain_id}")
|
||||
# assert response == []
|
||||
# FIXME @aminediro &chloedia raise an error when trying to interact with storage UnboundLocalError: cannot access local variable 'response' where it is not associated with a value
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_duplicate_sha1_knowledge_same_user(
|
||||
session: AsyncSession, test_data: TestData
|
||||
):
|
||||
brain, [existing_knowledge, _] = test_data
|
||||
assert brain.brain_id
|
||||
assert existing_knowledge.id
|
||||
assert existing_knowledge.file_sha1
|
||||
repo = KnowledgeRepository(session)
|
||||
knowledge = KnowledgeDB(
|
||||
file_name="test_file_2",
|
||||
extension="txt",
|
||||
status="UPLOADED",
|
||||
source="test_source",
|
||||
source_link="test_source_link",
|
||||
file_size=100,
|
||||
file_sha1=existing_knowledge.file_sha1,
|
||||
brains=[brain],
|
||||
user_id=existing_knowledge.user_id,
|
||||
)
|
||||
|
||||
with pytest.raises(IntegrityError): # FIXME: Should raise IntegrityError
|
||||
await repo.insert_knowledge(knowledge, brain.brain_id)
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_duplicate_sha1_knowledge_diff_user(
|
||||
session: AsyncSession, test_data: TestData, other_user: User
|
||||
):
|
||||
brain, knowledges = test_data
|
||||
assert other_user.id
|
||||
assert brain.brain_id
|
||||
assert knowledges[0].id
|
||||
repo = KnowledgeRepository(session)
|
||||
knowledge = KnowledgeDB(
|
||||
file_name="test_file_2",
|
||||
extension="txt",
|
||||
status="UPLOADED",
|
||||
source="test_source",
|
||||
source_link="test_source_link",
|
||||
file_size=100,
|
||||
file_sha1=knowledges[0].file_sha1,
|
||||
brains=[brain],
|
||||
user_id=other_user.id, # random user id
|
||||
)
|
||||
|
||||
result = await repo.insert_knowledge(knowledge, brain.brain_id)
|
||||
assert result
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_add_knowledge_to_brain(session: AsyncSession, test_data: TestData):
|
||||
brain, knowledges = test_data
|
||||
assert brain.brain_id
|
||||
assert knowledges[1].id
|
||||
repo = KnowledgeRepository(session)
|
||||
await repo.link_to_brain(knowledges[1], brain.brain_id)
|
||||
knowledge = await repo.get_knowledge_by_id(knowledges[1].id)
|
||||
brains_of_knowledge = [b.brain_id for b in await knowledge.awaitable_attrs.brains]
|
||||
assert brain.brain_id in brains_of_knowledge
|
||||
|
||||
query = select(KnowledgeBrain).where(
|
||||
KnowledgeBrain.knowledge_id == knowledges[0].id
|
||||
and KnowledgeBrain.brain_id == brain.brain_id
|
||||
)
|
||||
result = await session.exec(query)
|
||||
knowledge_brain = result.first()
|
||||
assert knowledge_brain
|
||||
|
||||
|
||||
# Knowledge Service
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_get_knowledge_in_brain(session: AsyncSession, test_data: TestData):
|
||||
brain, knowledges = test_data
|
||||
assert brain.brain_id
|
||||
repo = KnowledgeRepository(session)
|
||||
service = KnowledgeService(repo)
|
||||
list_knowledge = await service.get_all_knowledge_in_brain(brain.brain_id)
|
||||
assert len(list_knowledge) == 1
|
||||
brains_of_knowledge = [
|
||||
b.brain_id for b in await knowledges[0].awaitable_attrs.brains
|
||||
]
|
||||
assert list_knowledge[0].id == knowledges[0].id
|
||||
assert list_knowledge[0].file_name == knowledges[0].file_name
|
||||
assert brain.brain_id in brains_of_knowledge
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_should_process_knowledge_exists(
|
||||
session: AsyncSession, test_data: TestData
|
||||
):
|
||||
brain, [existing_knowledge, _] = test_data
|
||||
assert brain.brain_id
|
||||
new = KnowledgeDB(
|
||||
file_name="new",
|
||||
extension="txt",
|
||||
status="PROCESSING",
|
||||
source="test_source",
|
||||
source_link="test_source_link",
|
||||
file_size=100,
|
||||
file_sha1=None,
|
||||
brains=[brain],
|
||||
user_id=existing_knowledge.user_id,
|
||||
)
|
||||
session.add(new)
|
||||
await session.commit()
|
||||
await session.refresh(new)
|
||||
incoming_knowledge = await new.to_dto()
|
||||
repo = KnowledgeRepository(session)
|
||||
service = KnowledgeService(repo)
|
||||
assert existing_knowledge.file_sha1
|
||||
with pytest.raises(FileExistsError):
|
||||
await service.update_sha1_conflict(
|
||||
incoming_knowledge, brain.brain_id, file_sha1=existing_knowledge.file_sha1
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_should_process_knowledge_link_brain(
|
||||
session: AsyncSession, test_data: TestData
|
||||
):
|
||||
repo = KnowledgeRepository(session)
|
||||
service = KnowledgeService(repo)
|
||||
brain, [existing_knowledge, _] = test_data
|
||||
user_id = existing_knowledge.user_id
|
||||
assert brain.brain_id
|
||||
prev = KnowledgeDB(
|
||||
file_name="prev",
|
||||
extension=".txt",
|
||||
status=KnowledgeStatus.UPLOADED,
|
||||
source="test_source",
|
||||
source_link="test_source_link",
|
||||
file_size=100,
|
||||
file_sha1="test1",
|
||||
brains=[brain],
|
||||
user_id=user_id,
|
||||
)
|
||||
brain_2 = Brain(
|
||||
name="test_brain",
|
||||
description="this is a test brain",
|
||||
brain_type=BrainType.integration,
|
||||
)
|
||||
session.add(brain_2)
|
||||
session.add(prev)
|
||||
await session.commit()
|
||||
await session.refresh(prev)
|
||||
await session.refresh(brain_2)
|
||||
|
||||
assert prev.id
|
||||
assert brain_2.brain_id
|
||||
|
||||
new = KnowledgeDB(
|
||||
file_name="new",
|
||||
extension="txt",
|
||||
status="PROCESSING",
|
||||
source="test_source",
|
||||
source_link="test_source_link",
|
||||
file_size=100,
|
||||
file_sha1=None,
|
||||
brains=[brain_2],
|
||||
user_id=user_id,
|
||||
)
|
||||
session.add(new)
|
||||
await session.commit()
|
||||
await session.refresh(new)
|
||||
|
||||
incoming_knowledge = await new.to_dto()
|
||||
assert prev.file_sha1
|
||||
|
||||
should_process = await service.update_sha1_conflict(
|
||||
incoming_knowledge, brain_2.brain_id, file_sha1=prev.file_sha1
|
||||
)
|
||||
assert not should_process
|
||||
|
||||
# Check prev knowledge was linked
|
||||
assert incoming_knowledge.file_sha1
|
||||
prev_knowledge = await service.repository.get_knowledge_by_id(prev.id)
|
||||
prev_brains = await prev_knowledge.awaitable_attrs.brains
|
||||
assert {b.brain_id for b in prev_brains} == {
|
||||
brain.brain_id,
|
||||
brain_2.brain_id,
|
||||
}
|
||||
# Check new knowledge was removed
|
||||
assert new.id
|
||||
with pytest.raises(NoResultFound):
|
||||
await service.repository.get_knowledge_by_id(new.id)
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_should_process_knowledge_prev_error(
|
||||
session: AsyncSession, test_data: TestData
|
||||
):
|
||||
repo = KnowledgeRepository(session)
|
||||
service = KnowledgeService(repo)
|
||||
brain, [existing_knowledge, _] = test_data
|
||||
user_id = existing_knowledge.user_id
|
||||
assert brain.brain_id
|
||||
prev = KnowledgeDB(
|
||||
file_name="prev",
|
||||
extension="txt",
|
||||
status=KnowledgeStatus.ERROR,
|
||||
source="test_source",
|
||||
source_link="test_source_link",
|
||||
file_size=100,
|
||||
file_sha1="test1",
|
||||
brains=[brain],
|
||||
user_id=user_id,
|
||||
)
|
||||
session.add(prev)
|
||||
await session.commit()
|
||||
await session.refresh(prev)
|
||||
|
||||
assert prev.id
|
||||
|
||||
new = KnowledgeDB(
|
||||
file_name="new",
|
||||
extension="txt",
|
||||
status="PROCESSING",
|
||||
source="test_source",
|
||||
source_link="test_source_link",
|
||||
file_size=100,
|
||||
file_sha1=None,
|
||||
brains=[brain],
|
||||
user_id=user_id,
|
||||
)
|
||||
session.add(new)
|
||||
await session.commit()
|
||||
await session.refresh(new)
|
||||
|
||||
incoming_knowledge = await new.to_dto()
|
||||
assert prev.file_sha1
|
||||
should_process = await service.update_sha1_conflict(
|
||||
incoming_knowledge, brain.brain_id, file_sha1=prev.file_sha1
|
||||
)
|
||||
|
||||
# Checks we should process this file
|
||||
assert should_process
|
||||
# Previous errored file is cleaned up
|
||||
with pytest.raises(NoResultFound):
|
||||
await service.repository.get_knowledge_by_id(prev.id)
|
||||
|
||||
assert new.id
|
||||
new = await service.repository.get_knowledge_by_id(new.id)
|
||||
assert new.file_sha1
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_get_knowledge_storage_path(session: AsyncSession, test_data: TestData):
|
||||
brain, [knowledge, _] = test_data
|
||||
assert knowledge.file_name
|
||||
repository = KnowledgeRepository(session)
|
||||
service = KnowledgeService(repository)
|
||||
brain_2 = Brain(
|
||||
name="test_brain",
|
||||
description="this is a test brain",
|
||||
brain_type=BrainType.integration,
|
||||
)
|
||||
session.add(brain_2)
|
||||
await session.commit()
|
||||
await session.refresh(brain_2)
|
||||
assert brain_2.brain_id
|
||||
km_data = os.urandom(128)
|
||||
km_path = f"{str(knowledge.brains[0].brain_id)}/{knowledge.file_name}"
|
||||
await upload_file_storage(km_data, km_path)
|
||||
# Link knowledge to two brains
|
||||
await repository.link_to_brain(knowledge, brain_2.brain_id)
|
||||
storage_path = await service.get_knowledge_storage_path(
|
||||
knowledge.file_name, brain_2.brain_id
|
||||
)
|
||||
assert storage_path == km_path
|
@ -317,10 +317,11 @@ async def test_process_sync_file_noprev(
|
||||
created_km = all_km[0]
|
||||
assert created_km.file_name == sync_file.name
|
||||
assert created_km.extension == ".txt"
|
||||
assert created_km.file_sha1 is not None
|
||||
assert created_km.file_sha1 is None
|
||||
assert created_km.created_at is not None
|
||||
assert created_km.metadata == {"sync_file_id": "1"}
|
||||
assert created_km.brain_ids == [brain_1.brain_id]
|
||||
assert len(created_km.brains)> 0
|
||||
assert created_km.brains[0]["brain_id"]== brain_1.brain_id
|
||||
|
||||
# Assert celery task in correct
|
||||
assert task["args"] == ("process_file_task",)
|
||||
@ -409,12 +410,12 @@ async def test_process_sync_file_with_prev(
|
||||
created_km = all_km[0]
|
||||
assert created_km.file_name == sync_file.name
|
||||
assert created_km.extension == ".txt"
|
||||
assert created_km.file_sha1 is not None
|
||||
assert created_km.file_sha1 is None
|
||||
assert created_km.updated_at
|
||||
assert created_km.created_at
|
||||
assert created_km.updated_at == created_km.created_at # new line
|
||||
assert created_km.metadata == {"sync_file_id": str(dbfiles[0].id)}
|
||||
assert created_km.brain_ids == [brain_1.brain_id]
|
||||
assert created_km.brains[0]["brain_id"]== brain_1.brain_id
|
||||
|
||||
# Check file content changed
|
||||
assert check_file_exists(str(brain_1.brain_id), sync_file.name)
|
||||
|
@ -53,12 +53,10 @@ AsyncClientDep = Annotated[AsyncClient, Depends(get_supabase_async_client)]
|
||||
@upload_router.post("/upload", dependencies=[Depends(AuthBearer())], tags=["Upload"])
|
||||
async def upload_file(
|
||||
uploadFile: UploadFile,
|
||||
client: AsyncClientDep,
|
||||
background_tasks: BackgroundTasks,
|
||||
knowledge_service: KnowledgeServiceDep,
|
||||
background_tasks: BackgroundTasks,
|
||||
bulk_id: Optional[UUID] = Query(None, description="The ID of the bulk upload"),
|
||||
brain_id: UUID = Query(..., description="The ID of the brain"),
|
||||
chat_id: Optional[UUID] = Query(None, description="The ID of the chat"),
|
||||
current_user: UserIdentity = Depends(get_current_user),
|
||||
integration: Optional[str] = None,
|
||||
integration_link: Optional[str] = None,
|
||||
@ -121,7 +119,7 @@ async def upload_file(
|
||||
file_size=uploadFile.size,
|
||||
file_sha1=None,
|
||||
)
|
||||
knowledge = await knowledge_service.insert_knowledge(
|
||||
knowledge = await knowledge_service.insert_knowledge_brain(
|
||||
user_id=current_user.id, knowledge_to_add=knowledge_to_add
|
||||
) # type: ignore
|
||||
|
||||
|
@ -87,7 +87,7 @@ async def crawl_endpoint(
|
||||
source_link=crawl_website.url,
|
||||
)
|
||||
|
||||
added_knowledge = await knowledge_service.insert_knowledge(
|
||||
added_knowledge = await knowledge_service.insert_knowledge_brain(
|
||||
knowledge_to_add=knowledge_to_add, user_id=current_user.id
|
||||
)
|
||||
logger.info(f"Knowledge {added_knowledge} added successfully")
|
||||
|
50
backend/api/quivr_api/utils/partial.py
Normal file
50
backend/api/quivr_api/utils/partial.py
Normal file
@ -0,0 +1,50 @@
|
||||
from copy import deepcopy
|
||||
from typing import Any, Callable, Optional, Type, TypeVar
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel, create_model
|
||||
from pydantic.fields import FieldInfo
|
||||
|
||||
Model = TypeVar("Model", bound=Type[BaseModel])
|
||||
|
||||
|
||||
def all_optional(without_fields: list[str] | None = None) -> Callable[[Model], Model]:
|
||||
if without_fields is None:
|
||||
without_fields = []
|
||||
|
||||
def wrapper(model: Type[Model]) -> Type[Model]:
|
||||
base_model: Type[Model] = model
|
||||
|
||||
def make_field_optional(
|
||||
field: FieldInfo, default: Any = None
|
||||
) -> tuple[Any, FieldInfo]:
|
||||
new = deepcopy(field)
|
||||
new.default = default
|
||||
new.annotation = Optional[field.annotation]
|
||||
return new.annotation, new
|
||||
|
||||
if without_fields:
|
||||
base_model = BaseModel
|
||||
|
||||
return create_model(
|
||||
model.__name__,
|
||||
__base__=base_model,
|
||||
__module__=model.__module__,
|
||||
**{
|
||||
field_name: make_field_optional(field_info)
|
||||
for field_name, field_info in model.model_fields.items()
|
||||
if field_name not in without_fields
|
||||
},
|
||||
)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
class Test(BaseModel):
|
||||
id: UUID
|
||||
name: Optional[str] = None
|
||||
|
||||
|
||||
@all_optional()
|
||||
class TestUpdate(Test):
|
||||
pass
|
@ -42,6 +42,7 @@ class KnowledgeStatus(str, Enum):
|
||||
PROCESSING = "PROCESSING"
|
||||
UPLOADED = "UPLOADED"
|
||||
ERROR = "ERROR"
|
||||
RESERVED = "RESERVED"
|
||||
|
||||
|
||||
class Source(BaseModel):
|
||||
|
@ -0,0 +1,31 @@
|
||||
ALTER USER postgres
|
||||
SET idle_session_timeout = '3min';
|
||||
ALTER USER postgres
|
||||
SET idle_in_transaction_session_timeout = '3min';
|
||||
-- Drop previous contraint
|
||||
alter table "public"."knowledge" drop constraint "unique_file_sha1_user_id";
|
||||
alter table "public"."knowledge"
|
||||
add column "is_folder" boolean default false;
|
||||
-- Update the knowledge to backfill knowledge to is_folder = false
|
||||
UPDATE "public"."knowledge"
|
||||
SET is_folder = false;
|
||||
-- Add parent_id -> folder
|
||||
alter table "public"."knowledge"
|
||||
add column "parent_id" uuid;
|
||||
alter table "public"."knowledge"
|
||||
add constraint "public_knowledge_parent_id_fkey" FOREIGN KEY (parent_id) REFERENCES knowledge(id) ON DELETE CASCADE;
|
||||
-- Add constraint must be folder for parent_id
|
||||
CREATE FUNCTION is_parent_folder(folder_id uuid) RETURNS boolean AS $$ BEGIN RETURN (
|
||||
SELECT k.is_folder
|
||||
FROM public.knowledge k
|
||||
WHERE k.id = folder_id
|
||||
);
|
||||
END;
|
||||
$$ LANGUAGE plpgsql;
|
||||
ALTER TABLE public.knowledge
|
||||
ADD CONSTRAINT check_parent_is_folder CHECK (
|
||||
parent_id IS NULL
|
||||
OR is_parent_folder(parent_id)
|
||||
);
|
||||
-- Index on parent_id
|
||||
CREATE INDEX knowledge_parent_id_idx ON public.knowledge USING btree (parent_id);
|
@ -13,7 +13,7 @@ from quivr_api.modules.brain.repository.brains_vectors import BrainsVectors
|
||||
from quivr_api.modules.brain.service.brain_service import BrainService
|
||||
from quivr_api.modules.dependencies import get_supabase_client
|
||||
from quivr_api.modules.knowledge.repository.knowledges import KnowledgeRepository
|
||||
from quivr_api.modules.knowledge.repository.storage import Storage
|
||||
from quivr_api.modules.knowledge.repository.storage import SupabaseS3Storage
|
||||
from quivr_api.modules.knowledge.service.knowledge_service import KnowledgeService
|
||||
from quivr_api.modules.notification.service.notification_service import (
|
||||
NotificationService,
|
||||
@ -58,7 +58,7 @@ sync_user_service = SyncUserService()
|
||||
sync_files_repo_service = SyncFilesRepository()
|
||||
brain_service = BrainService()
|
||||
brain_vectors = BrainsVectors()
|
||||
storage = Storage()
|
||||
storage = SupabaseS3Storage()
|
||||
notion_service: SyncNotionService | None = None
|
||||
async_engine: AsyncEngine | None = None
|
||||
engine: Engine | None = None
|
||||
@ -170,6 +170,8 @@ async def aprocess_file_task(
|
||||
integration_link=source_link,
|
||||
delete_file=delete_file,
|
||||
)
|
||||
session.commit()
|
||||
await async_session.commit()
|
||||
except Exception as e:
|
||||
session.rollback()
|
||||
await async_session.rollback()
|
||||
@ -196,7 +198,11 @@ def process_crawl_task(
|
||||
)
|
||||
global engine
|
||||
assert engine
|
||||
try:
|
||||
with Session(engine, expire_on_commit=False, autoflush=False) as session:
|
||||
session.execute(
|
||||
text("SET SESSION idle_in_transaction_session_timeout = '5min';")
|
||||
)
|
||||
vector_repository = VectorRepository(session)
|
||||
vector_service = VectorService(vector_repository)
|
||||
loop = asyncio.get_event_loop()
|
||||
@ -209,6 +215,12 @@ def process_crawl_task(
|
||||
vector_service=vector_service,
|
||||
)
|
||||
)
|
||||
session.commit()
|
||||
except Exception as e:
|
||||
session.rollback()
|
||||
raise e
|
||||
finally:
|
||||
session.close()
|
||||
|
||||
|
||||
@celery.task(name="NotionConnectorLoad")
|
||||
|
@ -2,6 +2,7 @@ from uuid import UUID
|
||||
|
||||
from quivr_api.logger import get_logger
|
||||
from quivr_api.modules.brain.service.brain_service import BrainService
|
||||
from quivr_api.modules.knowledge.entity.knowledge import KnowledgeUpdate
|
||||
from quivr_api.modules.knowledge.service.knowledge_service import KnowledgeService
|
||||
from quivr_api.modules.vector.service.vector_service import VectorService
|
||||
|
||||
@ -41,12 +42,10 @@ async def process_uploaded_file(
|
||||
# If we have some knowledge with error
|
||||
with build_file(file_data, knowledge_id, file_name) as file_instance:
|
||||
knowledge = await knowledge_service.get_knowledge(knowledge_id=knowledge_id)
|
||||
should_process = await knowledge_service.update_sha1_conflict(
|
||||
knowledge=knowledge,
|
||||
brain_id=brain.brain_id,
|
||||
file_sha1=file_instance.file_sha1,
|
||||
await knowledge_service.update_knowledge(
|
||||
knowledge,
|
||||
KnowledgeUpdate(file_sha1=file_instance.file_sha1), # type: ignore
|
||||
)
|
||||
if should_process:
|
||||
await process_file(
|
||||
file_instance=file_instance,
|
||||
brain=brain,
|
||||
|
@ -141,7 +141,7 @@ async def process_notion_sync(
|
||||
UUID(user_id),
|
||||
notion_client, # type: ignore
|
||||
)
|
||||
|
||||
await session.commit()
|
||||
except Exception as e:
|
||||
await session.rollback()
|
||||
raise e
|
||||
|
@ -40,6 +40,8 @@ async def fetch_and_store_notion_files_async(
|
||||
else:
|
||||
logger.warn("No notion page fetched")
|
||||
|
||||
# Commit all before exiting
|
||||
await session.commit()
|
||||
except Exception as e:
|
||||
await session.rollback()
|
||||
raise e
|
||||
|
@ -6,7 +6,7 @@ from quivr_api.celery_config import celery
|
||||
from quivr_api.logger import get_logger
|
||||
from quivr_api.modules.brain.repository.brains_vectors import BrainsVectors
|
||||
from quivr_api.modules.knowledge.repository.knowledges import KnowledgeRepository
|
||||
from quivr_api.modules.knowledge.repository.storage import Storage
|
||||
from quivr_api.modules.knowledge.repository.storage import SupabaseS3Storage
|
||||
from quivr_api.modules.knowledge.service.knowledge_service import KnowledgeService
|
||||
from quivr_api.modules.notification.service.notification_service import (
|
||||
NotificationService,
|
||||
@ -42,7 +42,7 @@ class SyncServices:
|
||||
sync_files_repo_service: SyncFilesRepository
|
||||
notification_service: NotificationService
|
||||
brain_vectors: BrainsVectors
|
||||
storage: Storage
|
||||
storage: SupabaseS3Storage
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
@ -56,7 +56,6 @@ async def build_syncs_utils(
|
||||
await session.execute(
|
||||
text("SET SESSION idle_in_transaction_session_timeout = '5min';")
|
||||
)
|
||||
# TODO pass services from celery_worker
|
||||
notion_repository = NotionRepository(session)
|
||||
notion_service = SyncNotionService(notion_repository)
|
||||
knowledge_service = KnowledgeService(KnowledgeRepository(session))
|
||||
@ -84,7 +83,7 @@ async def build_syncs_utils(
|
||||
mapping_sync_utils[provider_name] = provider_sync_util
|
||||
|
||||
yield mapping_sync_utils
|
||||
|
||||
await session.commit()
|
||||
except Exception as e:
|
||||
await session.rollback()
|
||||
raise e
|
||||
|
Loading…
Reference in New Issue
Block a user