mirror of
https://github.com/QuivrHQ/quivr.git
synced 2025-01-05 23:03:53 +03:00
feat: CRUD KMS (no syncs) (#3162)
# Description

closes #3056. closes #3198

- Create knowledge route
- Get knowledge route
- List knowledge route: accepts `knowledge_id | None`. Pass `None` to list the root knowledge for the user
- Update (patch) knowledge to rename and move knowledge
- Remove knowledge: cascade if `parent_id` in knowledge and clean up storage
- Link storage upload to knowledge_service
- Relax sha1 file constraint
- Tests for all repository / service

---------

Co-authored-by: Stan Girard <girard.stanislas@gmail.com>
This commit is contained in:
parent
edc4118ba1
commit
71edca572f
6
.github/workflows/backend-tests.yml
vendored
6
.github/workflows/backend-tests.yml
vendored
@ -9,7 +9,9 @@ on:
|
|||||||
jobs:
|
jobs:
|
||||||
test:
|
test:
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
|
strategy:
|
||||||
|
matrix:
|
||||||
|
project: [quivr-api, quivr-worker]
|
||||||
steps:
|
steps:
|
||||||
- name: 👀 Checkout code
|
- name: 👀 Checkout code
|
||||||
uses: actions/checkout@v2
|
uses: actions/checkout@v2
|
||||||
@ -65,4 +67,4 @@ jobs:
|
|||||||
supabase start
|
supabase start
|
||||||
rye run python -c "from unstructured.nlp.tokenize import download_nltk_packages; download_nltk_packages()"
|
rye run python -c "from unstructured.nlp.tokenize import download_nltk_packages; download_nltk_packages()"
|
||||||
rye run python -c "import nltk;nltk.download('punkt_tab'); nltk.download('averaged_perceptron_tagger_eng')"
|
rye run python -c "import nltk;nltk.download('punkt_tab'); nltk.download('averaged_perceptron_tagger_eng')"
|
||||||
rye test -p quivr-api -p quivr-worker
|
rye test -p ${{ matrix.project }}
|
||||||
|
@ -32,13 +32,6 @@ repos:
|
|||||||
- id: mypy
|
- id: mypy
|
||||||
name: mypy
|
name: mypy
|
||||||
additional_dependencies: ["types-aiofiles"]
|
additional_dependencies: ["types-aiofiles"]
|
||||||
- repo: https://github.com/python-poetry/poetry
|
|
||||||
rev: "1.8.0"
|
|
||||||
hooks:
|
|
||||||
- id: poetry-check
|
|
||||||
args: ["-C", "./backend/core"]
|
|
||||||
- id: poetry-lock
|
|
||||||
args: ["-C", "./backend/core"]
|
|
||||||
ci:
|
ci:
|
||||||
autofix_commit_msg: |
|
autofix_commit_msg: |
|
||||||
[pre-commit.ci] auto fixes from pre-commit.com hooks
|
[pre-commit.ci] auto fixes from pre-commit.com hooks
|
||||||
|
@ -3,6 +3,7 @@ from typing import Optional
|
|||||||
|
|
||||||
from fastapi import Depends, HTTPException, Request
|
from fastapi import Depends, HTTPException, Request
|
||||||
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
|
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
|
||||||
|
|
||||||
from quivr_api.middlewares.auth.jwt_token_handler import (
|
from quivr_api.middlewares.auth.jwt_token_handler import (
|
||||||
decode_access_token,
|
decode_access_token,
|
||||||
verify_token,
|
verify_token,
|
||||||
@ -57,9 +58,13 @@ class AuthBearer(HTTPBearer):
|
|||||||
|
|
||||||
def get_test_user(self) -> UserIdentity:
|
def get_test_user(self) -> UserIdentity:
|
||||||
return UserIdentity(
|
return UserIdentity(
|
||||||
email="admin@quivr.app", id="39418e3b-0258-4452-af60-7acfcc1263ff" # type: ignore
|
email="admin@quivr.app",
|
||||||
|
id="39418e3b-0258-4452-af60-7acfcc1263ff", # type: ignore
|
||||||
) # replace with test user information
|
) # replace with test user information
|
||||||
|
|
||||||
|
|
||||||
def get_current_user(user: UserIdentity = Depends(AuthBearer())) -> UserIdentity:
|
auth_bearer = AuthBearer()
|
||||||
|
|
||||||
|
|
||||||
|
def get_current_user(user: UserIdentity = Depends(auth_bearer)) -> UserIdentity:
|
||||||
return user
|
return user
|
||||||
|
@ -69,6 +69,7 @@ class Brain(AsyncAttrs, SQLModel, table=True):
|
|||||||
back_populates="brains", link_model=KnowledgeBrain
|
back_populates="brains", link_model=KnowledgeBrain
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
# TODO : add
|
# TODO : add
|
||||||
# "meaning" "public"."vector",
|
# "meaning" "public"."vector",
|
||||||
# "tags" "public"."tags"[]
|
# "tags" "public"."tags"[]
|
||||||
|
@ -2,7 +2,7 @@ from uuid import UUID
|
|||||||
|
|
||||||
from quivr_api.logger import get_logger
|
from quivr_api.logger import get_logger
|
||||||
from quivr_api.modules.brain.repository.brains_vectors import BrainsVectors
|
from quivr_api.modules.brain.repository.brains_vectors import BrainsVectors
|
||||||
from quivr_api.modules.knowledge.repository.storage import Storage
|
from quivr_api.modules.knowledge.repository.storage import SupabaseS3Storage
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
@ -11,7 +11,7 @@ class BrainVectorService:
|
|||||||
def __init__(self, brain_id: UUID):
|
def __init__(self, brain_id: UUID):
|
||||||
self.repository = BrainsVectors()
|
self.repository = BrainsVectors()
|
||||||
self.brain_id = brain_id
|
self.brain_id = brain_id
|
||||||
self.storage = Storage()
|
self.storage = SupabaseS3Storage()
|
||||||
|
|
||||||
def create_brain_vector(self, vector_id: str, file_sha1: str):
|
def create_brain_vector(self, vector_id: str, file_sha1: str):
|
||||||
return self.repository.create_brain_vector(self.brain_id, vector_id, file_sha1) # type: ignore
|
return self.repository.create_brain_vector(self.brain_id, vector_id, file_sha1) # type: ignore
|
||||||
@ -26,10 +26,10 @@ class BrainVectorService:
|
|||||||
for vector_id in vector_ids:
|
for vector_id in vector_ids:
|
||||||
self.create_brain_vector(vector_id, file_sha1)
|
self.create_brain_vector(vector_id, file_sha1)
|
||||||
|
|
||||||
def delete_file_from_brain(self, file_name: str, only_vectors: bool = False):
|
async def delete_file_from_brain(self, file_name: str, only_vectors: bool = False):
|
||||||
file_name_with_brain_id = f"{self.brain_id}/{file_name}"
|
file_name_with_brain_id = f"{self.brain_id}/{file_name}"
|
||||||
if not only_vectors:
|
if not only_vectors:
|
||||||
self.storage.remove_file(file_name_with_brain_id)
|
await self.storage.remove_file(file_name_with_brain_id)
|
||||||
return self.repository.delete_file_from_brain(self.brain_id, file_name) # type: ignore
|
return self.repository.delete_file_from_brain(self.brain_id, file_name) # type: ignore
|
||||||
|
|
||||||
def delete_file_url_from_brain(self, file_name: str):
|
def delete_file_url_from_brain(self, file_name: str):
|
||||||
|
@ -24,9 +24,6 @@ async_engine = create_async_engine(
|
|||||||
"postgresql+asyncpg://" + pg_database_base_url,
|
"postgresql+asyncpg://" + pg_database_base_url,
|
||||||
echo=True if os.getenv("ORM_DEBUG") else False,
|
echo=True if os.getenv("ORM_DEBUG") else False,
|
||||||
future=True,
|
future=True,
|
||||||
pool_pre_ping=True,
|
|
||||||
pool_size=10,
|
|
||||||
pool_recycle=0.1,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -7,8 +7,7 @@ from langchain.embeddings.base import Embeddings
|
|||||||
from langchain_community.embeddings.ollama import OllamaEmbeddings
|
from langchain_community.embeddings.ollama import OllamaEmbeddings
|
||||||
|
|
||||||
# from langchain_community.vectorstores.supabase import SupabaseVectorStore
|
# from langchain_community.vectorstores.supabase import SupabaseVectorStore
|
||||||
from langchain_openai import OpenAIEmbeddings
|
from langchain_openai import AzureOpenAIEmbeddings, OpenAIEmbeddings
|
||||||
from langchain_openai import AzureOpenAIEmbeddings
|
|
||||||
|
|
||||||
# from quivr_api.modules.vector.service.vector_service import VectorService
|
# from quivr_api.modules.vector.service.vector_service import VectorService
|
||||||
# from quivr_api.modules.vectorstore.supabase import CustomSupabaseVectorStore
|
# from quivr_api.modules.vectorstore.supabase import CustomSupabaseVectorStore
|
||||||
@ -22,7 +21,6 @@ from quivr_api.models.databases.supabase.supabase import SupabaseDB
|
|||||||
from quivr_api.models.settings import BrainSettings
|
from quivr_api.models.settings import BrainSettings
|
||||||
from supabase.client import AsyncClient, Client, create_async_client, create_client
|
from supabase.client import AsyncClient, Client, create_async_client, create_client
|
||||||
|
|
||||||
|
|
||||||
# Global variables to store the Supabase client and database instances
|
# Global variables to store the Supabase client and database instances
|
||||||
_supabase_client: Optional[Client] = None
|
_supabase_client: Optional[Client] = None
|
||||||
_supabase_async_client: Optional[AsyncClient] = None
|
_supabase_async_client: Optional[AsyncClient] = None
|
||||||
|
@ -1,8 +1,8 @@
|
|||||||
from http import HTTPStatus
|
from http import HTTPStatus
|
||||||
from typing import Annotated
|
from typing import Annotated, List, Optional
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
from fastapi import APIRouter, Depends, File, HTTPException, Query, UploadFile, status
|
||||||
|
|
||||||
from quivr_api.logger import get_logger
|
from quivr_api.logger import get_logger
|
||||||
from quivr_api.middlewares.auth import AuthBearer, get_current_user
|
from quivr_api.middlewares.auth import AuthBearer, get_current_user
|
||||||
@ -12,6 +12,14 @@ from quivr_api.modules.brain.service.brain_authorization_service import (
|
|||||||
validate_brain_authorization,
|
validate_brain_authorization,
|
||||||
)
|
)
|
||||||
from quivr_api.modules.dependencies import get_service
|
from quivr_api.modules.dependencies import get_service
|
||||||
|
from quivr_api.modules.knowledge.dto.inputs import AddKnowledge
|
||||||
|
from quivr_api.modules.knowledge.entity.knowledge import Knowledge, KnowledgeUpdate
|
||||||
|
from quivr_api.modules.knowledge.service.knowledge_exceptions import (
|
||||||
|
KnowledgeDeleteError,
|
||||||
|
KnowledgeForbiddenAccess,
|
||||||
|
KnowledgeNotFoundException,
|
||||||
|
UploadError,
|
||||||
|
)
|
||||||
from quivr_api.modules.knowledge.service.knowledge_service import KnowledgeService
|
from quivr_api.modules.knowledge.service.knowledge_service import KnowledgeService
|
||||||
from quivr_api.modules.upload.service.generate_file_signed_url import (
|
from quivr_api.modules.upload.service.generate_file_signed_url import (
|
||||||
generate_file_signed_url,
|
generate_file_signed_url,
|
||||||
@ -21,9 +29,8 @@ from quivr_api.modules.user.entity.user_identity import UserIdentity
|
|||||||
knowledge_router = APIRouter()
|
knowledge_router = APIRouter()
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
KnowledgeServiceDep = Annotated[
|
get_km_service = get_service(KnowledgeService)
|
||||||
KnowledgeService, Depends(get_service(KnowledgeService))
|
KnowledgeServiceDep = Annotated[KnowledgeService, Depends(get_km_service)]
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
@knowledge_router.get(
|
@knowledge_router.get(
|
||||||
@ -53,7 +60,7 @@ async def list_knowledge_in_brain_endpoint(
|
|||||||
],
|
],
|
||||||
tags=["Knowledge"],
|
tags=["Knowledge"],
|
||||||
)
|
)
|
||||||
async def delete_endpoint(
|
async def delete_knowledge_brain(
|
||||||
knowledge_id: UUID,
|
knowledge_id: UUID,
|
||||||
knowledge_service: KnowledgeServiceDep,
|
knowledge_service: KnowledgeServiceDep,
|
||||||
current_user: UserIdentity = Depends(get_current_user),
|
current_user: UserIdentity = Depends(get_current_user),
|
||||||
@ -65,7 +72,7 @@ async def delete_endpoint(
|
|||||||
|
|
||||||
knowledge = await knowledge_service.get_knowledge(knowledge_id)
|
knowledge = await knowledge_service.get_knowledge(knowledge_id)
|
||||||
file_name = knowledge.file_name if knowledge.file_name else knowledge.url
|
file_name = knowledge.file_name if knowledge.file_name else knowledge.url
|
||||||
await knowledge_service.remove_knowledge(brain_id, knowledge_id)
|
await knowledge_service.remove_knowledge_brain(brain_id, knowledge_id)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"message": f"{file_name} of brain {brain_id} has been deleted by user {current_user.email}."
|
"message": f"{file_name} of brain {brain_id} has been deleted by user {current_user.email}."
|
||||||
@ -88,13 +95,13 @@ async def generate_signed_url_endpoint(
|
|||||||
|
|
||||||
knowledge = await knowledge_service.get_knowledge(knowledge_id)
|
knowledge = await knowledge_service.get_knowledge(knowledge_id)
|
||||||
|
|
||||||
if len(knowledge.brain_ids) == 0:
|
if len(knowledge.brains) == 0:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=HTTPStatus.NOT_FOUND,
|
status_code=HTTPStatus.NOT_FOUND,
|
||||||
detail="knowledge not associated with brains yet.",
|
detail="knowledge not associated with brains yet.",
|
||||||
)
|
)
|
||||||
|
|
||||||
brain_id = knowledge.brain_ids[0]
|
brain_id = knowledge.brains[0]["brain_id"]
|
||||||
|
|
||||||
validate_brain_authorization(brain_id=brain_id, user_id=current_user.id)
|
validate_brain_authorization(brain_id=brain_id, user_id=current_user.id)
|
||||||
|
|
||||||
@ -108,3 +115,153 @@ async def generate_signed_url_endpoint(
|
|||||||
file_signed_url = generate_file_signed_url(file_path_in_storage)
|
file_signed_url = generate_file_signed_url(file_path_in_storage)
|
||||||
|
|
||||||
return file_signed_url
|
return file_signed_url
|
||||||
|
|
||||||
|
|
||||||
|
@knowledge_router.post(
|
||||||
|
"/knowledge/",
|
||||||
|
tags=["Knowledge"],
|
||||||
|
response_model=Knowledge,
|
||||||
|
)
|
||||||
|
async def create_knowledge(
|
||||||
|
knowledge_data: str = File(...),
|
||||||
|
file: Optional[UploadFile] = None,
|
||||||
|
knowledge_service: KnowledgeService = Depends(get_km_service),
|
||||||
|
current_user: UserIdentity = Depends(get_current_user),
|
||||||
|
):
|
||||||
|
knowledge = AddKnowledge.model_validate_json(knowledge_data)
|
||||||
|
if not knowledge.file_name and not knowledge.url:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
detail="Either file_name or url must be provided",
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
km = await knowledge_service.create_knowledge(
|
||||||
|
knowledge_to_add=knowledge, upload_file=file, user_id=current_user.id
|
||||||
|
)
|
||||||
|
km_dto = await km.to_dto()
|
||||||
|
return km_dto
|
||||||
|
except ValueError:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
|
||||||
|
detail="Unprocessable knowledge ",
|
||||||
|
)
|
||||||
|
except FileExistsError:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_409_CONFLICT, detail="Existing knowledge"
|
||||||
|
)
|
||||||
|
except UploadError:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
|
detail="Error occured uploading knowledge",
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR)
|
||||||
|
|
||||||
|
|
||||||
|
@knowledge_router.get(
|
||||||
|
"/knowledge/children",
|
||||||
|
response_model=List[Knowledge] | None,
|
||||||
|
tags=["Knowledge"],
|
||||||
|
)
|
||||||
|
async def list_knowledge(
|
||||||
|
parent_id: UUID | None = None,
|
||||||
|
knowledge_service: KnowledgeService = Depends(get_km_service),
|
||||||
|
current_user: UserIdentity = Depends(get_current_user),
|
||||||
|
):
|
||||||
|
try:
|
||||||
|
# TODO: Returns one level of children
|
||||||
|
children = await knowledge_service.list_knowledge(parent_id, current_user.id)
|
||||||
|
return [await c.to_dto(get_children=False) for c in children]
|
||||||
|
except KnowledgeNotFoundException as e:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_403_FORBIDDEN, detail=f"{e.message}"
|
||||||
|
)
|
||||||
|
except KnowledgeForbiddenAccess as e:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_404_NOT_FOUND, detail=f"{e.message}"
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR)
|
||||||
|
|
||||||
|
|
||||||
|
@knowledge_router.get(
|
||||||
|
"/knowledge/{knowledge_id}",
|
||||||
|
response_model=Knowledge,
|
||||||
|
tags=["Knowledge"],
|
||||||
|
)
|
||||||
|
async def get_knowledge(
|
||||||
|
knowledge_id: UUID,
|
||||||
|
knowledge_service: KnowledgeService = Depends(get_km_service),
|
||||||
|
current_user: UserIdentity = Depends(get_current_user),
|
||||||
|
):
|
||||||
|
try:
|
||||||
|
km = await knowledge_service.get_knowledge(knowledge_id)
|
||||||
|
if km.user_id != current_user.id:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_403_FORBIDDEN,
|
||||||
|
detail="You do not have permission to access this knowledge.",
|
||||||
|
)
|
||||||
|
return await km.to_dto()
|
||||||
|
except KnowledgeNotFoundException as e:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_404_NOT_FOUND, detail=f"{e.message}"
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR)
|
||||||
|
|
||||||
|
|
||||||
|
@knowledge_router.patch(
|
||||||
|
"/knowledge/{knowledge_id}",
|
||||||
|
status_code=status.HTTP_202_ACCEPTED,
|
||||||
|
response_model=Knowledge,
|
||||||
|
tags=["Knowledge"],
|
||||||
|
)
|
||||||
|
async def update_knowledge(
|
||||||
|
knowledge_id: UUID,
|
||||||
|
payload: KnowledgeUpdate,
|
||||||
|
knowledge_service: KnowledgeService = Depends(get_km_service),
|
||||||
|
current_user: UserIdentity = Depends(get_current_user),
|
||||||
|
):
|
||||||
|
try:
|
||||||
|
km = await knowledge_service.get_knowledge(knowledge_id)
|
||||||
|
if km.user_id != current_user.id:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_403_FORBIDDEN,
|
||||||
|
detail="You do not have permission to access this knowledge.",
|
||||||
|
)
|
||||||
|
km = await knowledge_service.update_knowledge(km, payload)
|
||||||
|
return km
|
||||||
|
except KnowledgeNotFoundException as e:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_404_NOT_FOUND, detail=f"{e.message}"
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR)
|
||||||
|
|
||||||
|
|
||||||
|
@knowledge_router.delete(
|
||||||
|
"/knowledge/{knowledge_id}",
|
||||||
|
status_code=status.HTTP_202_ACCEPTED,
|
||||||
|
tags=["Knowledge"],
|
||||||
|
)
|
||||||
|
async def delete_knowledge(
|
||||||
|
knowledge_id: UUID,
|
||||||
|
knowledge_service: KnowledgeService = Depends(get_km_service),
|
||||||
|
current_user: UserIdentity = Depends(get_current_user),
|
||||||
|
):
|
||||||
|
try:
|
||||||
|
km = await knowledge_service.get_knowledge(knowledge_id)
|
||||||
|
|
||||||
|
if km.user_id != current_user.id:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_403_FORBIDDEN,
|
||||||
|
detail="You do not have permission to remove this knowledge.",
|
||||||
|
)
|
||||||
|
delete_response = await knowledge_service.remove_knowledge(km)
|
||||||
|
return delete_response
|
||||||
|
except KnowledgeNotFoundException as e:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_404_NOT_FOUND, detail=f"{e.message}"
|
||||||
|
)
|
||||||
|
except KnowledgeDeleteError:
|
||||||
|
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR)
|
||||||
|
@ -16,8 +16,16 @@ class CreateKnowledgeProperties(BaseModel):
|
|||||||
file_size: Optional[int] = None
|
file_size: Optional[int] = None
|
||||||
file_sha1: Optional[str] = None
|
file_sha1: Optional[str] = None
|
||||||
metadata: Optional[Dict[str, str]] = None
|
metadata: Optional[Dict[str, str]] = None
|
||||||
|
is_folder: bool = False
|
||||||
|
parent_id: Optional[UUID] = None
|
||||||
|
|
||||||
def dict(self, *args, **kwargs):
|
|
||||||
knowledge_dict = super().dict(*args, **kwargs)
|
class AddKnowledge(BaseModel):
|
||||||
knowledge_dict["brain_id"] = str(knowledge_dict.get("brain_id"))
|
file_name: Optional[str] = None
|
||||||
return knowledge_dict
|
url: Optional[str] = None
|
||||||
|
extension: str = ".txt"
|
||||||
|
source: str = "local"
|
||||||
|
source_link: Optional[str] = None
|
||||||
|
metadata: Optional[Dict[str, str]] = None
|
||||||
|
is_folder: bool = False
|
||||||
|
parent_id: Optional[UUID] = None
|
||||||
|
@ -4,6 +4,6 @@ from pydantic import BaseModel
|
|||||||
|
|
||||||
|
|
||||||
class DeleteKnowledgeResponse(BaseModel):
|
class DeleteKnowledgeResponse(BaseModel):
|
||||||
file_name: str
|
file_name: str | None = None
|
||||||
status: str = "delete"
|
status: str = "DELETED"
|
||||||
knowledge_id: UUID
|
knowledge_id: UUID
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Dict, List, Optional
|
from enum import Enum
|
||||||
|
from typing import Any, Dict, List, Optional
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
@ -12,20 +13,44 @@ from sqlmodel import Field, Relationship, SQLModel
|
|||||||
from quivr_api.modules.knowledge.entity.knowledge_brain import KnowledgeBrain
|
from quivr_api.modules.knowledge.entity.knowledge_brain import KnowledgeBrain
|
||||||
|
|
||||||
|
|
||||||
|
class KnowledgeSource(str, Enum):
|
||||||
|
LOCAL = "local"
|
||||||
|
WEB = "web"
|
||||||
|
GDRIVE = "google drive"
|
||||||
|
DROPBOX = "dropbox"
|
||||||
|
SHAREPOINT = "sharepoint"
|
||||||
|
|
||||||
|
|
||||||
class Knowledge(BaseModel):
|
class Knowledge(BaseModel):
|
||||||
id: UUID
|
id: UUID
|
||||||
|
file_size: int = 0
|
||||||
|
status: KnowledgeStatus
|
||||||
file_name: Optional[str] = None
|
file_name: Optional[str] = None
|
||||||
url: Optional[str] = None
|
url: Optional[str] = None
|
||||||
extension: str = ".txt"
|
extension: str = ".txt"
|
||||||
status: str
|
is_folder: bool = False
|
||||||
|
updated_at: datetime
|
||||||
|
created_at: datetime
|
||||||
source: Optional[str] = None
|
source: Optional[str] = None
|
||||||
source_link: Optional[str] = None
|
source_link: Optional[str] = None
|
||||||
file_size: Optional[int] = None
|
|
||||||
file_sha1: Optional[str] = None
|
file_sha1: Optional[str] = None
|
||||||
updated_at: Optional[datetime] = None
|
|
||||||
created_at: Optional[datetime] = None
|
|
||||||
metadata: Optional[Dict[str, str]] = None
|
metadata: Optional[Dict[str, str]] = None
|
||||||
brain_ids: list[UUID]
|
user_id: UUID
|
||||||
|
brains: List[Dict[str, Any]]
|
||||||
|
parent: Optional["Knowledge"]
|
||||||
|
children: Optional[list["Knowledge"]]
|
||||||
|
|
||||||
|
|
||||||
|
class KnowledgeUpdate(BaseModel):
|
||||||
|
file_name: Optional[str] = None
|
||||||
|
status: Optional[KnowledgeStatus] = None
|
||||||
|
url: Optional[str] = None
|
||||||
|
file_sha1: Optional[str] = None
|
||||||
|
extension: Optional[str] = None
|
||||||
|
parent_id: Optional[UUID] = None
|
||||||
|
source: Optional[str] = None
|
||||||
|
source_link: Optional[str] = None
|
||||||
|
metadata: Optional[Dict[str, str]] = None
|
||||||
|
|
||||||
|
|
||||||
class KnowledgeDB(AsyncAttrs, SQLModel, table=True):
|
class KnowledgeDB(AsyncAttrs, SQLModel, table=True):
|
||||||
@ -49,13 +74,6 @@ class KnowledgeDB(AsyncAttrs, SQLModel, table=True):
|
|||||||
file_sha1: Optional[str] = Field(
|
file_sha1: Optional[str] = Field(
|
||||||
max_length=40
|
max_length=40
|
||||||
) # FIXME: Should not be optional @chloedia
|
) # FIXME: Should not be optional @chloedia
|
||||||
updated_at: datetime | None = Field(
|
|
||||||
default=None,
|
|
||||||
sa_column=Column(
|
|
||||||
TIMESTAMP(timezone=False),
|
|
||||||
server_default=text("CURRENT_TIMESTAMP"),
|
|
||||||
),
|
|
||||||
)
|
|
||||||
created_at: datetime | None = Field(
|
created_at: datetime | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
sa_column=Column(
|
sa_column=Column(
|
||||||
@ -63,9 +81,18 @@ class KnowledgeDB(AsyncAttrs, SQLModel, table=True):
|
|||||||
server_default=text("CURRENT_TIMESTAMP"),
|
server_default=text("CURRENT_TIMESTAMP"),
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
updated_at: datetime | None = Field(
|
||||||
|
default=None,
|
||||||
|
sa_column=Column(
|
||||||
|
TIMESTAMP(timezone=False),
|
||||||
|
server_default=text("CURRENT_TIMESTAMP"),
|
||||||
|
onupdate=datetime.utcnow,
|
||||||
|
),
|
||||||
|
)
|
||||||
metadata_: Optional[Dict[str, str]] = Field(
|
metadata_: Optional[Dict[str, str]] = Field(
|
||||||
default=None, sa_column=Column("metadata", JSON)
|
default=None, sa_column=Column("metadata", JSON)
|
||||||
)
|
)
|
||||||
|
is_folder: bool = Field(default=False)
|
||||||
user_id: UUID = Field(foreign_key="users.id", nullable=False)
|
user_id: UUID = Field(foreign_key="users.id", nullable=False)
|
||||||
brains: List["Brain"] = Relationship(
|
brains: List["Brain"] = Relationship(
|
||||||
back_populates="knowledges",
|
back_populates="knowledges",
|
||||||
@ -73,10 +100,35 @@ class KnowledgeDB(AsyncAttrs, SQLModel, table=True):
|
|||||||
sa_relationship_kwargs={"lazy": "select"},
|
sa_relationship_kwargs={"lazy": "select"},
|
||||||
)
|
)
|
||||||
|
|
||||||
async def to_dto(self) -> Knowledge:
|
parent_id: UUID | None = Field(
|
||||||
|
default=None, foreign_key="knowledge.id", ondelete="CASCADE"
|
||||||
|
)
|
||||||
|
parent: Optional["KnowledgeDB"] = Relationship(
|
||||||
|
back_populates="children",
|
||||||
|
sa_relationship_kwargs={"remote_side": "KnowledgeDB.id"},
|
||||||
|
)
|
||||||
|
children: list["KnowledgeDB"] = Relationship(
|
||||||
|
back_populates="parent",
|
||||||
|
sa_relationship_kwargs={
|
||||||
|
"cascade": "all, delete-orphan",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
# TODO: nested folder search
|
||||||
|
async def to_dto(self, get_children: bool = True) -> Knowledge:
|
||||||
|
assert (
|
||||||
|
self.updated_at
|
||||||
|
), "knowledge should be inserted before transforming to dto"
|
||||||
|
assert (
|
||||||
|
self.created_at
|
||||||
|
), "knowledge should be inserted before transforming to dto"
|
||||||
brains = await self.awaitable_attrs.brains
|
brains = await self.awaitable_attrs.brains
|
||||||
size = self.file_size if self.file_size else 0
|
children: list[KnowledgeDB] = (
|
||||||
sha1 = self.file_sha1 if self.file_sha1 else ""
|
await self.awaitable_attrs.children if get_children else []
|
||||||
|
)
|
||||||
|
parent = await self.awaitable_attrs.parent
|
||||||
|
parent = await parent.to_dto(get_children=False) if parent else None
|
||||||
|
|
||||||
return Knowledge(
|
return Knowledge(
|
||||||
id=self.id, # type: ignore
|
id=self.id, # type: ignore
|
||||||
file_name=self.file_name,
|
file_name=self.file_name,
|
||||||
@ -85,10 +137,14 @@ class KnowledgeDB(AsyncAttrs, SQLModel, table=True):
|
|||||||
status=KnowledgeStatus(self.status),
|
status=KnowledgeStatus(self.status),
|
||||||
source=self.source,
|
source=self.source,
|
||||||
source_link=self.source_link,
|
source_link=self.source_link,
|
||||||
file_size=size,
|
is_folder=self.is_folder,
|
||||||
file_sha1=sha1,
|
file_size=self.file_size or 0,
|
||||||
|
file_sha1=self.file_sha1,
|
||||||
updated_at=self.updated_at,
|
updated_at=self.updated_at,
|
||||||
created_at=self.created_at,
|
created_at=self.created_at,
|
||||||
metadata=self.metadata_, # type: ignore
|
metadata=self.metadata_, # type: ignore
|
||||||
brain_ids=[brain.brain_id for brain in brains],
|
brains=[b.model_dump() for b in brains],
|
||||||
|
parent=parent,
|
||||||
|
children=[await c.to_dto(get_children=False) for c in children],
|
||||||
|
user_id=self.user_id,
|
||||||
)
|
)
|
||||||
|
@ -1,9 +1,10 @@
|
|||||||
from typing import Sequence
|
from typing import Any, Sequence
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
|
|
||||||
from fastapi import HTTPException
|
from fastapi import HTTPException
|
||||||
from quivr_core.models import KnowledgeStatus
|
from quivr_core.models import KnowledgeStatus
|
||||||
from sqlalchemy.exc import IntegrityError, NoResultFound
|
from sqlalchemy.exc import IntegrityError, NoResultFound
|
||||||
|
from sqlalchemy.orm import joinedload
|
||||||
from sqlmodel import select, text
|
from sqlmodel import select, text
|
||||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||||
|
|
||||||
@ -11,7 +12,15 @@ from quivr_api.logger import get_logger
|
|||||||
from quivr_api.modules.brain.entity.brain_entity import Brain
|
from quivr_api.modules.brain.entity.brain_entity import Brain
|
||||||
from quivr_api.modules.dependencies import BaseRepository, get_supabase_client
|
from quivr_api.modules.dependencies import BaseRepository, get_supabase_client
|
||||||
from quivr_api.modules.knowledge.dto.outputs import DeleteKnowledgeResponse
|
from quivr_api.modules.knowledge.dto.outputs import DeleteKnowledgeResponse
|
||||||
from quivr_api.modules.knowledge.entity.knowledge import KnowledgeDB
|
from quivr_api.modules.knowledge.entity.knowledge import (
|
||||||
|
Knowledge,
|
||||||
|
KnowledgeDB,
|
||||||
|
KnowledgeUpdate,
|
||||||
|
)
|
||||||
|
from quivr_api.modules.knowledge.service.knowledge_exceptions import (
|
||||||
|
KnowledgeNotFoundException,
|
||||||
|
KnowledgeUpdateError,
|
||||||
|
)
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
@ -22,7 +31,43 @@ class KnowledgeRepository(BaseRepository):
|
|||||||
supabase_client = get_supabase_client()
|
supabase_client = get_supabase_client()
|
||||||
self.db = supabase_client
|
self.db = supabase_client
|
||||||
|
|
||||||
async def insert_knowledge(
|
async def create_knowledge(self, knowledge: KnowledgeDB) -> KnowledgeDB:
|
||||||
|
try:
|
||||||
|
self.session.add(knowledge)
|
||||||
|
await self.session.commit()
|
||||||
|
await self.session.refresh(knowledge)
|
||||||
|
except IntegrityError:
|
||||||
|
await self.session.rollback()
|
||||||
|
raise
|
||||||
|
except Exception:
|
||||||
|
await self.session.rollback()
|
||||||
|
raise
|
||||||
|
return knowledge
|
||||||
|
|
||||||
|
async def update_knowledge(
|
||||||
|
self,
|
||||||
|
knowledge: KnowledgeDB,
|
||||||
|
payload: Knowledge | KnowledgeUpdate | dict[str, Any],
|
||||||
|
) -> KnowledgeDB:
|
||||||
|
try:
|
||||||
|
logger.debug(f"updating {knowledge.id} with payload {payload}")
|
||||||
|
if isinstance(payload, dict):
|
||||||
|
update_data = payload
|
||||||
|
else:
|
||||||
|
update_data = payload.model_dump(exclude_unset=True)
|
||||||
|
for field in update_data:
|
||||||
|
setattr(knowledge, field, update_data[field])
|
||||||
|
|
||||||
|
self.session.add(knowledge)
|
||||||
|
await self.session.commit()
|
||||||
|
await self.session.refresh(knowledge)
|
||||||
|
return knowledge
|
||||||
|
except IntegrityError as e:
|
||||||
|
await self.session.rollback()
|
||||||
|
logger.error(f"Error updating knowledge {e}")
|
||||||
|
raise KnowledgeUpdateError
|
||||||
|
|
||||||
|
async def insert_knowledge_brain(
|
||||||
self, knowledge: KnowledgeDB, brain_id: UUID
|
self, knowledge: KnowledgeDB, brain_id: UUID
|
||||||
) -> KnowledgeDB:
|
) -> KnowledgeDB:
|
||||||
logger.debug(f"Inserting knowledge {knowledge}")
|
logger.debug(f"Inserting knowledge {knowledge}")
|
||||||
@ -69,6 +114,14 @@ class KnowledgeRepository(BaseRepository):
|
|||||||
await self.session.refresh(knowledge)
|
await self.session.refresh(knowledge)
|
||||||
return knowledge
|
return knowledge
|
||||||
|
|
||||||
|
async def remove_knowledge(self, knowledge: KnowledgeDB) -> DeleteKnowledgeResponse:
|
||||||
|
assert knowledge.id
|
||||||
|
await self.session.delete(knowledge)
|
||||||
|
await self.session.commit()
|
||||||
|
return DeleteKnowledgeResponse(
|
||||||
|
status="deleted", knowledge_id=knowledge.id, file_name=knowledge.file_name
|
||||||
|
)
|
||||||
|
|
||||||
async def remove_knowledge_by_id(
|
async def remove_knowledge_by_id(
|
||||||
self, knowledge_id: UUID
|
self, knowledge_id: UUID
|
||||||
) -> DeleteKnowledgeResponse:
|
) -> DeleteKnowledgeResponse:
|
||||||
@ -126,14 +179,70 @@ class KnowledgeRepository(BaseRepository):
|
|||||||
|
|
||||||
return knowledge
|
return knowledge
|
||||||
|
|
||||||
async def get_knowledge_by_id(self, knowledge_id: UUID) -> KnowledgeDB:
|
async def get_all_children(self, parent_id: UUID) -> list[KnowledgeDB]:
|
||||||
query = select(KnowledgeDB).where(KnowledgeDB.id == knowledge_id)
|
query = text("""
|
||||||
|
WITH RECURSIVE knowledge_tree AS (
|
||||||
|
SELECT *
|
||||||
|
FROM knowledge
|
||||||
|
WHERE parent_id = :parent_id
|
||||||
|
UNION ALL
|
||||||
|
SELECT k.*
|
||||||
|
FROM knowledge k
|
||||||
|
JOIN knowledge_tree kt ON k.parent_id = kt.id
|
||||||
|
)
|
||||||
|
SELECT * FROM knowledge_tree
|
||||||
|
""")
|
||||||
|
|
||||||
|
result = await self.session.execute(query, params={"parent_id": parent_id})
|
||||||
|
rows = result.fetchall()
|
||||||
|
knowledge_list = []
|
||||||
|
for row in rows:
|
||||||
|
knowledge = KnowledgeDB(
|
||||||
|
id=row.id,
|
||||||
|
parent_id=row.parent_id,
|
||||||
|
file_name=row.file_name,
|
||||||
|
url=row.url,
|
||||||
|
extension=row.extension,
|
||||||
|
status=row.status,
|
||||||
|
source=row.source,
|
||||||
|
source_link=row.source_link,
|
||||||
|
file_size=row.file_size,
|
||||||
|
file_sha1=row.file_sha1,
|
||||||
|
created_at=row.created_at,
|
||||||
|
updated_at=row.updated_at,
|
||||||
|
metadata_=row.metadata,
|
||||||
|
is_folder=row.is_folder,
|
||||||
|
user_id=row.user_id,
|
||||||
|
)
|
||||||
|
knowledge_list.append(knowledge)
|
||||||
|
|
||||||
|
return knowledge_list
|
||||||
|
|
||||||
|
async def get_root_knowledge_user(self, user_id: UUID) -> list[KnowledgeDB]:
    """Return the knowledge rows of ``user_id`` that have no parent.

    ``parent`` and ``children`` are loaded in the same query via
    ``joinedload``; ``unique()`` is applied to the result before it is
    collected into a list.
    """
    eager = (joinedload(KnowledgeDB.parent), joinedload(KnowledgeDB.children))  # type: ignore
    statement = select(KnowledgeDB).where(KnowledgeDB.parent_id.is_(None))  # type: ignore
    statement = statement.where(KnowledgeDB.user_id == user_id)
    statement = statement.options(*eager)
    rows = await self.session.exec(statement)
    return list(rows.unique().all())
|
||||||
|
|
||||||
|
async def get_knowledge_by_id(
|
||||||
|
self, knowledge_id: UUID, user_id: UUID | None = None
|
||||||
|
) -> KnowledgeDB:
|
||||||
|
query = (
|
||||||
|
select(KnowledgeDB)
|
||||||
|
.where(KnowledgeDB.id == knowledge_id)
|
||||||
|
.options(joinedload(KnowledgeDB.parent), joinedload(KnowledgeDB.children)) # type: ignore
|
||||||
|
)
|
||||||
|
if user_id:
|
||||||
|
query = query.where(KnowledgeDB.user_id == user_id)
|
||||||
result = await self.session.exec(query)
|
result = await self.session.exec(query)
|
||||||
knowledge = result.first()
|
knowledge = result.first()
|
||||||
|
|
||||||
if not knowledge:
|
if not knowledge:
|
||||||
raise NoResultFound("Knowledge not found")
|
raise KnowledgeNotFoundException("Knowledge not found")
|
||||||
|
|
||||||
return knowledge
|
return knowledge
|
||||||
|
|
||||||
async def get_brain_by_id(self, brain_id: UUID) -> Brain:
|
async def get_brain_by_id(self, brain_id: UUID) -> Brain:
|
||||||
|
@ -1,29 +1,87 @@
|
|||||||
|
import mimetypes
|
||||||
|
from io import BufferedReader, FileIO
|
||||||
|
|
||||||
from quivr_api.logger import get_logger
|
from quivr_api.logger import get_logger
|
||||||
from quivr_api.modules.dependencies import get_supabase_client
|
from quivr_api.modules.dependencies import get_supabase_async_client
|
||||||
|
from quivr_api.modules.knowledge.entity.knowledge import KnowledgeDB
|
||||||
from quivr_api.modules.knowledge.repository.storage_interface import StorageInterface
|
from quivr_api.modules.knowledge.repository.storage_interface import StorageInterface
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class Storage(StorageInterface):
|
class SupabaseS3Storage(StorageInterface):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
supabase_client = get_supabase_client()
|
self.client = None
|
||||||
self.db = supabase_client
|
|
||||||
|
|
||||||
def upload_file(self, file_name: str):
|
async def _set_client(self):
|
||||||
"""
|
if self.client is None:
|
||||||
Upload file to storage
|
self.client = await get_supabase_async_client()
|
||||||
"""
|
|
||||||
self.db.storage.from_("quivr").download(file_name)
|
|
||||||
|
|
||||||
def remove_file(self, file_name: str):
|
def get_storage_path(self, knowledge: KnowledgeDB) -> str:
    """Return the storage key for ``knowledge``: its id rendered as a string.

    Raises:
        ValueError: if ``knowledge.id`` is ``None``.
    """
    knowledge_id = knowledge.id
    if knowledge_id is not None:
        return str(knowledge_id)
    raise ValueError("knowledge should have a valid id")
|
||||||
|
|
||||||
|
async def upload_file_storage(
    self,
    knowledge: KnowledgeDB,
    knowledge_data: FileIO | BufferedReader | bytes,
    upsert: bool = False,
):
    """Upload ``knowledge_data`` to the ``quivr`` bucket and return its path.

    The object key comes from ``get_storage_path(knowledge)``. The content
    type is guessed from ``knowledge.file_name`` and falls back to
    ``"application/html"`` when there is no file name or no guess.

    Args:
        knowledge: knowledge whose id determines the storage key.
        knowledge_data: payload handed to the storage client.
        upsert: when true, call the bucket's ``update``; otherwise call
            ``upload`` with ``"upsert": "false"``.

    Returns:
        The storage path that was written.

    Raises:
        FileExistsError: when a non-upsert upload fails with a message
            containing "The resource already exists".
        Exception: any other storage error is re-raised unchanged.
    """
    await self._set_client()
    assert self.client

    # NOTE(review): "application/html" is not a registered media type;
    # kept as-is because existing objects may rely on it. Confirm intent.
    mime_type = "application/html"
    if knowledge.file_name:
        guessed_mime_type, _ = mimetypes.guess_type(knowledge.file_name)
        mime_type = guessed_mime_type or mime_type

    storage_path = self.get_storage_path(knowledge)
    logger.info(
        f"Uploading file to s3://quivr/{storage_path} using supabase. upsert={upsert}, mimetype={mime_type}"
    )

    bucket = self.client.storage.from_("quivr")
    if upsert:
        await bucket.update(
            storage_path,
            knowledge_data,
            file_options={
                "content-type": mime_type,
                "upsert": "true",
                "cache-control": "3600",
            },
        )
        return storage_path

    try:
        await bucket.upload(
            storage_path,
            knowledge_data,
            file_options={
                "content-type": mime_type,
                "upsert": "false",
                "cache-control": "3600",
            },
        )
    except Exception as e:
        # upsert is always False on this path, so only the message matters.
        if "The resource already exists" in str(e):
            raise FileExistsError(f"File {storage_path} already exists") from e
        raise
    return storage_path
|
||||||
|
|
||||||
|
async def remove_file(self, storage_path: str):
|
||||||
"""
|
"""
|
||||||
Remove file from storage
|
Remove file from storage
|
||||||
"""
|
"""
|
||||||
|
await self._set_client()
|
||||||
|
assert self.client
|
||||||
try:
|
try:
|
||||||
response = self.db.storage.from_("quivr").remove([file_name])
|
response = await self.client.storage.from_("quivr").remove([storage_path])
|
||||||
return response
|
return response
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(e)
|
logger.error(e)
|
||||||
# raise e
|
|
||||||
|
|
||||||
|
@ -1,10 +1,26 @@
|
|||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
|
from io import BufferedReader, FileIO
|
||||||
|
|
||||||
|
from quivr_api.modules.knowledge.entity.knowledge import KnowledgeDB
|
||||||
|
|
||||||
|
|
||||||
class StorageInterface(ABC):
|
class StorageInterface(ABC):
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def remove_file(self, file_name: str):
|
def get_storage_path(
|
||||||
"""
|
self,
|
||||||
Remove file from storage
|
knowledge: KnowledgeDB,
|
||||||
"""
|
) -> str:
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def upload_file_storage(
|
||||||
|
self,
|
||||||
|
knowledge: KnowledgeDB,
|
||||||
|
knowledge_data: FileIO | BufferedReader | bytes,
|
||||||
|
upsert: bool = False,
|
||||||
|
):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def remove_file(self, storage_path: str):
|
||||||
pass
|
pass
|
||||||
|
@ -0,0 +1,34 @@
|
|||||||
|
class KnowledgeException(Exception):
    """Base class for knowledge errors; keeps the text on ``message``."""

    def __init__(self, message="A knowledge-related error occurred"):
        super().__init__(message)
        self.message = message


class UploadError(KnowledgeException):
    """Knowledge error with an upload-specific default message."""

    def __init__(self, message="An error occurred while uploading"):
        super().__init__(message)


class KnowledgeCreationError(KnowledgeException):
    """Knowledge error with a creation-specific default message."""

    def __init__(self, message="An error occurred while creating"):
        super().__init__(message)


class KnowledgeUpdateError(KnowledgeException):
    """Knowledge error with an update-specific default message."""

    def __init__(self, message="An error occurred while updating"):
        super().__init__(message)


class KnowledgeDeleteError(KnowledgeException):
    """Knowledge error with a deletion-specific default message."""

    def __init__(self, message="An error occurred while deleting"):
        super().__init__(message)


class KnowledgeForbiddenAccess(KnowledgeException):
    """Knowledge error with a permission-specific default message."""

    def __init__(self, message="You do not have permission to access this knowledge."):
        super().__init__(message)


class KnowledgeNotFoundException(KnowledgeException):
    """Knowledge error with a not-found default message."""

    def __init__(self, message="The requested knowledge was not found"):
        super().__init__(message)
|
@ -1,18 +1,33 @@
|
|||||||
from typing import List
|
import asyncio
|
||||||
|
import io
|
||||||
|
from typing import Any, List
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
|
|
||||||
|
from fastapi import UploadFile
|
||||||
from quivr_core.models import KnowledgeStatus
|
from quivr_core.models import KnowledgeStatus
|
||||||
from sqlalchemy.exc import NoResultFound
|
from sqlalchemy.exc import NoResultFound
|
||||||
|
|
||||||
from quivr_api.logger import get_logger
|
from quivr_api.logger import get_logger
|
||||||
from quivr_api.modules.dependencies import BaseService
|
from quivr_api.modules.dependencies import BaseService
|
||||||
from quivr_api.modules.knowledge.dto.inputs import (
|
from quivr_api.modules.knowledge.dto.inputs import (
|
||||||
|
AddKnowledge,
|
||||||
CreateKnowledgeProperties,
|
CreateKnowledgeProperties,
|
||||||
)
|
)
|
||||||
from quivr_api.modules.knowledge.dto.outputs import DeleteKnowledgeResponse
|
from quivr_api.modules.knowledge.dto.outputs import DeleteKnowledgeResponse
|
||||||
from quivr_api.modules.knowledge.entity.knowledge import Knowledge, KnowledgeDB
|
from quivr_api.modules.knowledge.entity.knowledge import (
|
||||||
|
Knowledge,
|
||||||
|
KnowledgeDB,
|
||||||
|
KnowledgeSource,
|
||||||
|
KnowledgeUpdate,
|
||||||
|
)
|
||||||
from quivr_api.modules.knowledge.repository.knowledges import KnowledgeRepository
|
from quivr_api.modules.knowledge.repository.knowledges import KnowledgeRepository
|
||||||
from quivr_api.modules.knowledge.repository.storage import Storage
|
from quivr_api.modules.knowledge.repository.storage import SupabaseS3Storage
|
||||||
|
from quivr_api.modules.knowledge.repository.storage_interface import StorageInterface
|
||||||
|
from quivr_api.modules.knowledge.service.knowledge_exceptions import (
|
||||||
|
KnowledgeDeleteError,
|
||||||
|
KnowledgeForbiddenAccess,
|
||||||
|
UploadError,
|
||||||
|
)
|
||||||
from quivr_api.modules.sync.entity.sync_models import (
|
from quivr_api.modules.sync.entity.sync_models import (
|
||||||
DBSyncFile,
|
DBSyncFile,
|
||||||
DownloadedSyncFile,
|
DownloadedSyncFile,
|
||||||
@ -26,9 +41,13 @@ logger = get_logger(__name__)
|
|||||||
class KnowledgeService(BaseService[KnowledgeRepository]):
|
class KnowledgeService(BaseService[KnowledgeRepository]):
|
||||||
repository_cls = KnowledgeRepository
|
repository_cls = KnowledgeRepository
|
||||||
|
|
||||||
def __init__(self, repository: KnowledgeRepository):
|
def __init__(
|
||||||
|
self,
|
||||||
|
repository: KnowledgeRepository,
|
||||||
|
storage: StorageInterface = SupabaseS3Storage(),
|
||||||
|
):
|
||||||
self.repository = repository
|
self.repository = repository
|
||||||
self.storage = Storage()
|
self.storage = storage
|
||||||
|
|
||||||
async def get_knowledge_sync(self, sync_id: int) -> Knowledge:
|
async def get_knowledge_sync(self, sync_id: int) -> Knowledge:
|
||||||
km = await self.repository.get_knowledge_by_sync_id(sync_id)
|
km = await self.repository.get_knowledge_by_sync_id(sync_id)
|
||||||
@ -54,19 +73,37 @@ class KnowledgeService(BaseService[KnowledgeRepository]):
|
|||||||
except NoResultFound:
|
except NoResultFound:
|
||||||
raise FileNotFoundError(f"No knowledge for file_name: {file_name}")
|
raise FileNotFoundError(f"No knowledge for file_name: {file_name}")
|
||||||
|
|
||||||
async def get_knowledge(self, knowledge_id: UUID) -> Knowledge:
|
async def list_knowledge(
|
||||||
inserted_knowledge_db_instance = await self.repository.get_knowledge_by_id(
|
self, knowledge_id: UUID | None, user_id: UUID | None = None
|
||||||
knowledge_id
|
) -> list[KnowledgeDB]:
|
||||||
|
if knowledge_id is not None:
|
||||||
|
km = await self.repository.get_knowledge_by_id(knowledge_id, user_id)
|
||||||
|
return km.children
|
||||||
|
else:
|
||||||
|
if user_id is None:
|
||||||
|
raise KnowledgeForbiddenAccess(
|
||||||
|
"can't get root knowledges without user_id"
|
||||||
)
|
)
|
||||||
assert inserted_knowledge_db_instance.id, "Knowledge ID not generated"
|
return await self.repository.get_root_knowledge_user(user_id)
|
||||||
km = await inserted_knowledge_db_instance.to_dto()
|
|
||||||
return km
|
|
||||||
|
|
||||||
|
async def get_knowledge(
    self, knowledge_id: UUID, user_id: UUID | None = None
) -> KnowledgeDB:
    """Fetch one knowledge by delegating to the repository's ``get_knowledge_by_id``.

    ``user_id`` is forwarded unchanged as the second argument.
    """
    found = await self.repository.get_knowledge_by_id(knowledge_id, user_id)
    return found
|
||||||
|
|
||||||
|
async def update_knowledge(
    self,
    knowledge: KnowledgeDB,
    payload: Knowledge | KnowledgeUpdate | dict[str, Any],
):
    """Forward an update of ``knowledge`` with ``payload`` to the repository.

    Returns whatever the repository's ``update_knowledge`` returns.
    """
    updated = await self.repository.update_knowledge(knowledge, payload)
    return updated
|
||||||
|
|
||||||
|
# TODO: Remove all of this
|
||||||
# TODO (@aminediro): Replace with ON CONFLICT smarter query...
|
# TODO (@aminediro): Replace with ON CONFLICT smarter query...
|
||||||
# there is a chance of race condition but for now we let it crash in worker
|
# there is a chance of race condition but for now we let it crash in worker
|
||||||
# the tasks will be dealt with on retry
|
# the tasks will be dealt with on retry
|
||||||
async def update_sha1_conflict(
|
async def update_sha1_conflict(
|
||||||
self, knowledge: Knowledge, brain_id: UUID, file_sha1: str
|
self, knowledge: KnowledgeDB, brain_id: UUID, file_sha1: str
|
||||||
) -> bool:
|
) -> bool:
|
||||||
assert knowledge.id
|
assert knowledge.id
|
||||||
knowledge.file_sha1 = file_sha1
|
knowledge.file_sha1 = file_sha1
|
||||||
@ -89,12 +126,12 @@ class KnowledgeService(BaseService[KnowledgeRepository]):
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
await self.repository.link_to_brain(existing_knowledge, brain_id)
|
await self.repository.link_to_brain(existing_knowledge, brain_id)
|
||||||
await self.remove_knowledge(brain_id, knowledge.id)
|
await self.remove_knowledge_brain(brain_id, knowledge.id)
|
||||||
return False
|
return False
|
||||||
else:
|
else:
|
||||||
logger.debug(f"Removing previous errored file {existing_knowledge.id}")
|
logger.debug(f"Removing previous errored file {existing_knowledge.id}")
|
||||||
assert existing_knowledge.id
|
assert existing_knowledge.id
|
||||||
await self.remove_knowledge(brain_id, existing_knowledge.id)
|
await self.remove_knowledge_brain(brain_id, existing_knowledge.id)
|
||||||
await self.update_file_sha1_knowledge(knowledge.id, knowledge.file_sha1)
|
await self.update_file_sha1_knowledge(knowledge.id, knowledge.file_sha1)
|
||||||
return True
|
return True
|
||||||
except NoResultFound:
|
except NoResultFound:
|
||||||
@ -104,7 +141,47 @@ class KnowledgeService(BaseService[KnowledgeRepository]):
|
|||||||
await self.update_file_sha1_knowledge(knowledge.id, knowledge.file_sha1)
|
await self.update_file_sha1_knowledge(knowledge.id, knowledge.file_sha1)
|
||||||
return True
|
return True
|
||||||
|
|
||||||
async def insert_knowledge(
|
async def create_knowledge(
    self,
    user_id: UUID,
    knowledge_to_add: AddKnowledge,
    upload_file: UploadFile | None = None,
) -> KnowledgeDB:
    """Create a knowledge row and, for local uploads, store its file.

    The row is first created with status ``RESERVED``. When the source is
    ``KnowledgeSource.LOCAL`` and ``upload_file`` is given, the file is sent
    to storage and the returned path is written to ``source_link``. The row
    is then updated to status ``UPLOADED`` and returned.

    Args:
        user_id: owner of the new knowledge.
        knowledge_to_add: fields copied onto the new row.
        upload_file: optional uploaded file; its ``size`` becomes
            ``file_size`` (0 when absent).

    Raises:
        UploadError: if the upload or the status update fails; the created
            row is removed through the repository first.
    """
    knowledgedb = KnowledgeDB(
        user_id=user_id,
        file_name=knowledge_to_add.file_name,
        is_folder=knowledge_to_add.is_folder,
        url=knowledge_to_add.url,
        extension=knowledge_to_add.extension,
        source=knowledge_to_add.source,
        source_link=knowledge_to_add.source_link,
        file_size=upload_file.size if upload_file else 0,
        metadata_=knowledge_to_add.metadata,  # type: ignore
        status=KnowledgeStatus.RESERVED,
        parent_id=knowledge_to_add.parent_id,
    )
    knowledge_db = await self.repository.create_knowledge(knowledgedb)
    try:
        if knowledgedb.source == KnowledgeSource.LOCAL and upload_file:
            # NOTE(@aminediro): Unnecessary mem buffer because supabase doesnt accept FileIO..
            buff_reader = io.BufferedReader(upload_file.file)  # type: ignore
            storage_path = await self.storage.upload_file_storage(
                knowledgedb, buff_reader
            )
            knowledgedb.source_link = storage_path
        # NOTE(review): indentation is lost in the reviewed view; the status
        # update is assumed to run for every knowledge, not only local
        # uploads -- confirm against the original file.
        knowledge_db = await self.repository.update_knowledge(
            knowledge_db,
            KnowledgeUpdate(status=KnowledgeStatus.UPLOADED),  # type: ignore
        )
        return knowledge_db
    except Exception as e:
        logger.exception(
            f"Error uploading knowledge {knowledgedb.id} to storage : {e}"
        )
        await self.repository.remove_knowledge(knowledge=knowledge_db)
        # Keep the original failure attached as the cause of the UploadError.
        raise UploadError() from e
|
||||||
|
|
||||||
|
async def insert_knowledge_brain(
|
||||||
self,
|
self,
|
||||||
user_id: UUID,
|
user_id: UUID,
|
||||||
knowledge_to_add: CreateKnowledgeProperties, # FIXME: (later) @Amine brain id should not be in CreateKnowledgeProperties but since storage is brain_id/file_name
|
knowledge_to_add: CreateKnowledgeProperties, # FIXME: (later) @Amine brain id should not be in CreateKnowledgeProperties but since storage is brain_id/file_name
|
||||||
@ -122,7 +199,7 @@ class KnowledgeService(BaseService[KnowledgeRepository]):
|
|||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
knowledge_db = await self.repository.insert_knowledge(
|
knowledge_db = await self.repository.insert_knowledge_brain(
|
||||||
knowledge, brain_id=knowledge_to_add.brain_id
|
knowledge, brain_id=knowledge_to_add.brain_id
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -150,7 +227,7 @@ class KnowledgeService(BaseService[KnowledgeRepository]):
|
|||||||
assert isinstance(knowledge.file_name, str), "file_name should be a string"
|
assert isinstance(knowledge.file_name, str), "file_name should be a string"
|
||||||
file_name_with_brain_id = f"{brain_id}/{knowledge.file_name}"
|
file_name_with_brain_id = f"{brain_id}/{knowledge.file_name}"
|
||||||
try:
|
try:
|
||||||
self.storage.remove_file(file_name_with_brain_id)
|
await self.storage.remove_file(file_name_with_brain_id)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(
|
logger.error(
|
||||||
f"Error while removing file {file_name_with_brain_id}: {e}"
|
f"Error while removing file {file_name_with_brain_id}: {e}"
|
||||||
@ -161,29 +238,52 @@ class KnowledgeService(BaseService[KnowledgeRepository]):
|
|||||||
async def update_file_sha1_knowledge(self, knowledge_id: UUID, file_sha1: str):
|
async def update_file_sha1_knowledge(self, knowledge_id: UUID, file_sha1: str):
|
||||||
return await self.repository.update_file_sha1_knowledge(knowledge_id, file_sha1)
|
return await self.repository.update_file_sha1_knowledge(knowledge_id, file_sha1)
|
||||||
|
|
||||||
async def remove_knowledge(
|
async def remove_knowledge(self, knowledge: KnowledgeDB) -> DeleteKnowledgeResponse:
    """Delete ``knowledge`` and remove the stored files beneath it.

    Storage paths are collected for every non-folder entry returned by the
    repository's ``get_all_children`` and, if ``knowledge`` is not a folder,
    for ``knowledge`` itself. The row is removed through the repository
    first, then all collected paths are removed from storage concurrently.

    Returns:
        The value returned by the repository's ``remove_knowledge``.

    Raises:
        KnowledgeDeleteError: if any step fails; the original exception is
            attached as the cause.
    """
    assert knowledge.id

    try:
        # TODO:
        # - Notion folders are special, they are themselves files and should be removed from storage
        children = await self.repository.get_all_children(knowledge.id)
        km_paths = [
            self.storage.get_storage_path(k) for k in children if not k.is_folder
        ]
        if not knowledge.is_folder:
            km_paths.append(self.storage.get_storage_path(knowledge))

        # Paths are gathered before the row is removed, while the
        # descendants can still be listed.
        deleted_km = await self.repository.remove_knowledge(knowledge)
        await asyncio.gather(*(self.storage.remove_file(p) for p in km_paths))

        return deleted_km
    except Exception as e:
        logger.error(f"Error while remove knowledge : {e}")
        raise KnowledgeDeleteError from e
|
||||||
|
|
||||||
|
async def remove_knowledge_brain(
|
||||||
self,
|
self,
|
||||||
brain_id: UUID,
|
brain_id: UUID,
|
||||||
knowledge_id: UUID, # FIXME: @amine when name in storage change no need for brain id
|
knowledge_id: UUID, # FIXME: @amine when name in storage change no need for brain id
|
||||||
) -> DeleteKnowledgeResponse:
|
) -> DeleteKnowledgeResponse:
|
||||||
# TODO: fix KMS
|
# TODO: fix KMS
|
||||||
# REDO ALL THIS
|
# REDO ALL THIS
|
||||||
knowledge = await self.get_knowledge(knowledge_id)
|
knowledge = await self.repository.get_knowledge_by_id(knowledge_id)
|
||||||
if len(knowledge.brain_ids) > 1:
|
km_brains = await knowledge.awaitable_attrs.brains
|
||||||
|
if len(km_brains) > 1:
|
||||||
km = await self.repository.remove_knowledge_from_brain(
|
km = await self.repository.remove_knowledge_from_brain(
|
||||||
knowledge_id, brain_id
|
knowledge_id, brain_id
|
||||||
)
|
)
|
||||||
|
assert km.id
|
||||||
return DeleteKnowledgeResponse(file_name=km.file_name, knowledge_id=km.id)
|
return DeleteKnowledgeResponse(file_name=km.file_name, knowledge_id=km.id)
|
||||||
else:
|
else:
|
||||||
message = await self.repository.remove_knowledge_by_id(knowledge_id)
|
message = await self.repository.remove_knowledge_by_id(knowledge_id)
|
||||||
file_name_with_brain_id = f"{brain_id}/{message.file_name}"
|
file_name_with_brain_id = f"{brain_id}/{message.file_name}"
|
||||||
try:
|
try:
|
||||||
self.storage.remove_file(file_name_with_brain_id)
|
await self.storage.remove_file(file_name_with_brain_id)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(
|
logger.error(
|
||||||
f"Error while removing file {file_name_with_brain_id}: {e}"
|
f"Error while removing file {file_name_with_brain_id}: {e}"
|
||||||
)
|
)
|
||||||
|
|
||||||
return message
|
return message
|
||||||
|
|
||||||
async def remove_all_knowledges_from_brain(self, brain_id: UUID) -> None:
|
async def remove_all_knowledges_from_brain(self, brain_id: UUID) -> None:
|
||||||
@ -210,7 +310,7 @@ class KnowledgeService(BaseService[KnowledgeRepository]):
|
|||||||
# TODO: THIS IS A HACK!! Remove all of this
|
# TODO: THIS IS A HACK!! Remove all of this
|
||||||
if prev_sync_file:
|
if prev_sync_file:
|
||||||
prev_knowledge = await self.get_knowledge_sync(sync_id=prev_sync_file.id)
|
prev_knowledge = await self.get_knowledge_sync(sync_id=prev_sync_file.id)
|
||||||
if len(prev_knowledge.brain_ids) > 1:
|
if len(prev_knowledge.brains) > 1:
|
||||||
await self.repository.remove_knowledge_from_brain(
|
await self.repository.remove_knowledge_from_brain(
|
||||||
prev_knowledge.id, brain_id
|
prev_knowledge.id, brain_id
|
||||||
)
|
)
|
||||||
@ -231,7 +331,7 @@ class KnowledgeService(BaseService[KnowledgeRepository]):
|
|||||||
file_sha1=None,
|
file_sha1=None,
|
||||||
metadata={"sync_file_id": str(sync_id)},
|
metadata={"sync_file_id": str(sync_id)},
|
||||||
)
|
)
|
||||||
added_knowledge = await self.insert_knowledge(
|
added_knowledge = await self.insert_knowledge_brain(
|
||||||
knowledge_to_add=knowledge_to_add, user_id=user_id
|
knowledge_to_add=knowledge_to_add, user_id=user_id
|
||||||
)
|
)
|
||||||
return added_knowledge
|
return added_knowledge
|
||||||
|
67
backend/api/quivr_api/modules/knowledge/tests/conftest.py
Normal file
67
backend/api/quivr_api/modules/knowledge/tests/conftest.py
Normal file
@ -0,0 +1,67 @@
|
|||||||
|
from io import BufferedReader, FileIO
|
||||||
|
|
||||||
|
from quivr_api.modules.knowledge.entity.knowledge import Knowledge, KnowledgeDB
|
||||||
|
from quivr_api.modules.knowledge.repository.storage_interface import StorageInterface
|
||||||
|
|
||||||
|
|
||||||
|
class ErrorStorage(StorageInterface):
    """Storage double whose upload and removal always raise ``SystemError``."""

    async def upload_file_storage(
        self,
        knowledge: KnowledgeDB,
        knowledge_data: FileIO | BufferedReader | bytes,
        upsert: bool = False,
    ):
        # Every upload attempt fails.
        raise SystemError

    def get_storage_path(
        self,
        knowledge: KnowledgeDB | Knowledge,
    ) -> str:
        # The storage key is the knowledge id rendered as a string.
        knowledge_id = knowledge.id
        if knowledge_id is not None:
            return str(knowledge_id)
        raise ValueError("knowledge should have a valid id")

    async def remove_file(self, storage_path: str):
        # Every removal attempt fails.
        raise SystemError
|
||||||
|
|
||||||
|
|
||||||
|
class FakeStorage(StorageInterface):
    """In-memory storage double backed by a dict keyed by storage path."""

    def __init__(self):
        self.storage = {}

    def get_storage_path(
        self,
        knowledge: KnowledgeDB | Knowledge,
    ) -> str:
        # The storage key is the knowledge id rendered as a string.
        knowledge_id = knowledge.id
        if knowledge_id is not None:
            return str(knowledge_id)
        raise ValueError("knowledge should have a valid id")

    async def upload_file_storage(
        self,
        knowledge: KnowledgeDB,
        knowledge_data: FileIO | BufferedReader | bytes,
        upsert: bool = False,
    ):
        # Refuse to overwrite an existing entry unless upsert is requested.
        storage_path = f"{knowledge.id}"
        already_stored = storage_path in self.storage
        if already_stored and not upsert:
            raise ValueError(f"File already exists at {storage_path}")
        self.storage[storage_path] = knowledge_data
        return storage_path

    async def remove_file(self, storage_path: str):
        if storage_path in self.storage:
            del self.storage[storage_path]
            return
        raise FileNotFoundError(f"File not found at {storage_path}")

    # Additional helper methods for testing
    def get_file(self, storage_path: str) -> FileIO | BufferedReader | bytes:
        if storage_path in self.storage:
            return self.storage[storage_path]
        raise FileNotFoundError(f"File not found at {storage_path}")

    def knowledge_exists(self, knowledge: KnowledgeDB | Knowledge) -> bool:
        path = self.get_storage_path(knowledge)
        return path in self.storage

    def clear_storage(self):
        self.storage.clear()
|
@ -0,0 +1,74 @@
|
|||||||
|
import json
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import pytest_asyncio
|
||||||
|
from httpx import ASGITransport, AsyncClient
|
||||||
|
from sqlmodel import select
|
||||||
|
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||||
|
|
||||||
|
from quivr_api.main import app
|
||||||
|
from quivr_api.middlewares.auth.auth_bearer import get_current_user
|
||||||
|
from quivr_api.modules.knowledge.controller.knowledge_routes import get_km_service
|
||||||
|
from quivr_api.modules.knowledge.repository.knowledges import KnowledgeRepository
|
||||||
|
from quivr_api.modules.knowledge.service.knowledge_service import KnowledgeService
|
||||||
|
from quivr_api.modules.knowledge.tests.conftest import FakeStorage
|
||||||
|
from quivr_api.modules.user.entity.user_identity import User, UserIdentity
|
||||||
|
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture(scope="function")
async def user(session: AsyncSession) -> User:
    """Return the single user whose email is ``admin@quivr.app``."""
    statement = select(User).where(User.email == "admin@quivr.app")
    rows = await session.exec(statement)
    admin = rows.one()
    assert admin.id
    return admin
|
||||||
|
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture(scope="function")
async def test_client(session: AsyncSession, user: User):
    """Yield an HTTP client bound to ``app`` with auth and service overridden.

    The current-user dependency resolves to ``user`` and the knowledge
    service dependency to a service built on ``session`` with a
    ``FakeStorage``. Overrides are reset after the client has been used.
    """

    def default_current_user() -> UserIdentity:
        assert user.id
        return UserIdentity(email=user.email, id=user.id)

    async def test_service():
        return KnowledgeService(KnowledgeRepository(session), FakeStorage())

    app.dependency_overrides[get_current_user] = default_current_user
    app.dependency_overrides[get_km_service] = test_service
    # app.dependency_overrides[get_async_session] = lambda: session

    transport = ASGITransport(app=app)
    async with AsyncClient(transport=transport, base_url="http://test") as ac:
        yield ac
        app.dependency_overrides = {}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio(loop_scope="session")
async def test_post_knowledge(test_client: AsyncClient):
    """POSTing a JSON knowledge part plus a file part to /knowledge/ returns 200."""
    payload = json.dumps(
        {
            "file_name": "test_file.txt",
            "source": "local",
            "is_folder": False,
            "parent_id": None,
        }
    )
    parts = {
        "knowledge_data": (None, payload, "application/json"),
        "file": ("test_file.txt", b"Test file content", "application/octet-stream"),
    }

    response = await test_client.post("/knowledge/", files=parts)

    assert response.status_code == 200
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio(loop_scope="session")
async def test_add_knowledge_invalid_input(test_client):
    """POSTing to /knowledge/ with no multipart parts is rejected with 422."""
    empty_parts = {}
    response = await test_client.post("/knowledge/", files=empty_parts)
    assert response.status_code == 422
|
@ -0,0 +1,229 @@
|
|||||||
|
from typing import List, Tuple
|
||||||
|
from uuid import uuid4
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import pytest_asyncio
|
||||||
|
from quivr_core.models import KnowledgeStatus
|
||||||
|
from sqlmodel import select, text
|
||||||
|
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||||
|
|
||||||
|
from quivr_api.modules.brain.entity.brain_entity import Brain, BrainType
|
||||||
|
from quivr_api.modules.knowledge.entity.knowledge import KnowledgeDB
|
||||||
|
from quivr_api.modules.user.entity.user_identity import User
|
||||||
|
|
||||||
|
TestData = Tuple[Brain, List[KnowledgeDB]]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture(scope="function")
async def other_user(session: AsyncSession):
    """Insert an ``other@quivr.app`` row into ``auth.users`` and return its ``User``.

    The row gets a freshly generated id; the user is then read back by email.
    """
    insert_stmt = text(
        """
        INSERT INTO "auth"."users" ("instance_id", "id", "aud", "role", "email", "encrypted_password", "email_confirmed_at", "invited_at", "confirmation_token", "confirmation_sent_at", "recovery_token", "recovery_sent_at", "email_change_token_new", "email_change", "email_change_sent_at", "last_sign_in_at", "raw_app_meta_data", "raw_user_meta_data", "is_super_admin", "created_at", "updated_at", "phone", "phone_confirmed_at", "phone_change", "phone_change_token", "phone_change_sent_at", "email_change_token_current", "email_change_confirm_status", "banned_until", "reauthentication_token", "reauthentication_sent_at", "is_sso_user", "deleted_at") VALUES
        ('00000000-0000-0000-0000-000000000000', :id , 'authenticated', 'authenticated', 'other@quivr.app', '$2a$10$vwKX0eMLlrOZvxQEA3Vl4e5V4/hOuxPjGYn9QK1yqeaZxa.42Uhze', '2024-01-22 22:27:00.166861+00', NULL, '', NULL, 'e91d41043ca2c83c3be5a6ee7a4abc8a4f4fb1afc0a8453c502af931', '2024-03-05 16:22:13.780421+00', '', '', NULL, '2024-03-30 23:21:12.077887+00', '{"provider": "email", "providers": ["email"]}', '{}', NULL, '2024-01-22 22:27:00.158026+00', '2024-04-01 17:40:15.332205+00', NULL, NULL, '', '', NULL, '', 0, NULL, '', NULL, false, NULL);
        """
    )
    new_id = uuid4()
    await session.execute(insert_stmt, params={"id": new_id})

    lookup = select(User).where(User.email == "other@quivr.app")
    rows = await session.exec(lookup)
    return rows.one()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture(scope="function")
async def user(session):
    """Return the ``User`` whose email is ``admin@quivr.app``."""
    admin_query = select(User).where(User.email == "admin@quivr.app")
    rows = await session.exec(admin_query)
    # ``one()`` raises unless exactly one row matches.
    return rows.one()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture(scope="function")
async def brain(session):
    """Create, commit and return an integration-type ``Brain`` named ``test_brain``."""
    test_brain = Brain(
        name="test_brain",
        description="this is a test brain",
        brain_type=BrainType.integration,
    )
    session.add(test_brain)
    await session.commit()
    return test_brain
|
||||||
|
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture(scope="function")
async def folder(session, user):
    """Persist an empty folder knowledge owned by ``user`` and return it refreshed.

    The folder has no brains, no children and no ``file_sha1``.
    """
    empty_folder = KnowledgeDB(
        file_name="folder_1",
        extension="",
        status="UPLOADED",
        source="local",
        source_link="local",
        file_size=4,
        file_sha1=None,
        brains=[],
        children=[],
        user_id=user.id,
        is_folder=True,
    )

    session.add(empty_folder)
    await session.commit()
    # Reload the row so values populated on insert are available to the tests.
    await session.refresh(empty_folder)
    return empty_folder
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio(loop_scope="session")
async def test_knowledge_default_file(session, folder, user):
    """A knowledge created without ``is_folder`` is not a folder after a round trip."""
    file_knowledge = KnowledgeDB(
        file_name="test_file_1.txt",
        extension=".txt",
        status="UPLOADED",
        source="test_source",
        source_link="test_source_link",
        file_size=100,
        file_sha1="test_sha1",
        brains=[],
        user_id=user.id,
        parent_id=folder.id,
    )
    session.add(file_knowledge)
    await session.commit()
    await session.refresh(file_knowledge)

    # ``is_folder`` was never passed, so a falsy value must come back.
    assert not file_knowledge.is_folder
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio(loop_scope="session")
async def test_knowledge_parent(session: AsyncSession, user: User):
    """Attaching a file through ``children`` wires up both sides of the relation.

    Only the folder is added to the session; the file is passed as one of
    its ``children``. The test then checks ``parent_id`` / ``parent`` on
    the file and ``children`` on the re-queried folder.
    """
    assert user.id

    child_file = KnowledgeDB(
        file_name="test_file_1.txt",
        extension=".txt",
        status="UPLOADED",
        source="test_source",
        source_link="test_source_link",
        file_size=100,
        file_sha1="test_sha1",
        brains=[],
        user_id=user.id,
    )

    parent_folder = KnowledgeDB(
        file_name="folder_1",
        extension="",
        is_folder=True,
        status="UPLOADED",
        source="local",
        source_link="local",
        file_size=-1,
        file_sha1=None,
        brains=[],
        children=[child_file],
        user_id=user.id,
    )

    session.add(parent_folder)
    await session.commit()
    await session.refresh(parent_folder)
    await session.refresh(child_file)

    # Child -> parent direction.
    loaded_parent = await child_file.awaitable_attrs.parent
    assert child_file.parent_id == parent_folder.id, "parent_id isn't set to folder id"
    assert loaded_parent.id == parent_folder.id, "parent_id isn't set to folder id"
    assert loaded_parent.is_folder

    # Parent -> children direction, starting from a fresh query on the folder id.
    folder_query = select(KnowledgeDB).where(KnowledgeDB.id == parent_folder.id)
    queried_folder = (await session.exec(folder_query)).first()
    assert queried_folder

    loaded_children = await queried_folder.awaitable_attrs.children
    assert len(loaded_children) > 0

    assert loaded_children[0].id == child_file.id
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio(loop_scope="session")
async def test_knowledge_remove_folder_cascade(
    session: AsyncSession,
    folder: KnowledgeDB,
    user,
):
    """Deleting a folder leaves no ``KnowledgeDB`` rows behind, its child included."""
    child_file = KnowledgeDB(
        file_name="test_file_1.txt",
        extension=".txt",
        status="UPLOADED",
        source="test_source",
        source_link="test_source_link",
        file_size=100,
        file_sha1="test_sha1",
        brains=[],
        user_id=user.id,
        parent_id=folder.id,
    )
    session.add(child_file)
    await session.commit()
    await session.refresh(child_file)

    # Check all removed
    await session.delete(folder)
    await session.commit()

    remaining = (await session.exec(select(KnowledgeDB))).all()
    assert remaining == []
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio(loop_scope="session")
async def test_knowledge_dto(session, user, brain):
    """Check that ``KnowledgeDB.to_dto()`` mirrors the stored file and its folder.

    A folder and a file inside it are both attached to ``brain``. The
    file's DTO is compared field by field with the entity, and the
    folder's DTO must expose the brain and the file as its only child.
    """
    # add folder in brain
    folder = KnowledgeDB(
        file_name="folder_1",
        extension="",
        status="UPLOADED",
        source="local",
        source_link="local",
        file_size=4,
        file_sha1=None,
        brains=[brain],
        children=[],
        user_id=user.id,
        is_folder=True,
    )
    km = KnowledgeDB(
        file_name="test_file_1.txt",
        extension=".txt",
        status="UPLOADED",
        source="test_source",
        source_link="test_source_link",
        file_size=100,
        file_sha1="test_sha1",
        user_id=user.id,
        brains=[brain],
        parent=folder,
    )
    # One add is enough; a second ``session.add(km)`` is a no-op.
    # NOTE(review): the folder is presumably persisted through the
    # ``parent`` relationship of ``km`` -- confirm the cascade settings.
    session.add(km)
    await session.commit()
    await session.refresh(km)

    km_dto = await km.to_dto()

    assert km_dto.file_name == km.file_name
    assert km_dto.url == km.url
    assert km_dto.extension == km.extension
    assert km_dto.status == KnowledgeStatus(km.status)
    assert km_dto.source == km.source
    assert km_dto.source_link == km.source_link
    assert km_dto.is_folder == km.is_folder
    assert km_dto.file_size == km.file_size
    assert km_dto.file_sha1 == km.file_sha1
    assert km_dto.updated_at == km.updated_at
    assert km_dto.created_at == km.created_at
    assert km_dto.metadata == km.metadata_  # type: ignore
    assert km_dto.parent
    assert km_dto.parent.id == folder.id

    folder_dto = await folder.to_dto()
    assert folder_dto.brains[0] == brain.model_dump()
    assert folder_dto.children == [await km.to_dto()]
|
File diff suppressed because it is too large
Load Diff
@ -1,450 +0,0 @@
|
|||||||
import os
|
|
||||||
from typing import List, Tuple
|
|
||||||
from uuid import uuid4
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
import pytest_asyncio
|
|
||||||
from sqlalchemy.exc import IntegrityError, NoResultFound
|
|
||||||
from sqlmodel import select, text
|
|
||||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
|
||||||
|
|
||||||
from quivr_api.modules.brain.entity.brain_entity import Brain, BrainType
|
|
||||||
from quivr_api.modules.knowledge.dto.inputs import KnowledgeStatus
|
|
||||||
from quivr_api.modules.knowledge.entity.knowledge import KnowledgeDB
|
|
||||||
from quivr_api.modules.knowledge.entity.knowledge_brain import KnowledgeBrain
|
|
||||||
from quivr_api.modules.knowledge.repository.knowledges import KnowledgeRepository
|
|
||||||
from quivr_api.modules.knowledge.service.knowledge_service import KnowledgeService
|
|
||||||
from quivr_api.modules.upload.service.upload_file import upload_file_storage
|
|
||||||
from quivr_api.modules.user.entity.user_identity import User
|
|
||||||
from quivr_api.modules.vector.entity.vector import Vector
|
|
||||||
|
|
||||||
pg_database_base_url = "postgres:postgres@localhost:54322/postgres"
|
|
||||||
|
|
||||||
TestData = Tuple[Brain, List[KnowledgeDB]]
|
|
||||||
|
|
||||||
|
|
||||||
@pytest_asyncio.fixture(scope="function")
async def other_user(session: AsyncSession):
    """Insert a second auth user (``other@quivr.app``) and return its ``User`` row.

    The row is written into ``auth.users`` with raw SQL, using a freshly
    generated UUID as its id, and is then read back with a ``select(User)``.
    """
    insert_stmt = text(
        """
INSERT INTO "auth"."users" ("instance_id", "id", "aud", "role", "email", "encrypted_password", "email_confirmed_at", "invited_at", "confirmation_token", "confirmation_sent_at", "recovery_token", "recovery_sent_at", "email_change_token_new", "email_change", "email_change_sent_at", "last_sign_in_at", "raw_app_meta_data", "raw_user_meta_data", "is_super_admin", "created_at", "updated_at", "phone", "phone_confirmed_at", "phone_change", "phone_change_token", "phone_change_sent_at", "email_change_token_current", "email_change_confirm_status", "banned_until", "reauthentication_token", "reauthentication_sent_at", "is_sso_user", "deleted_at") VALUES
('00000000-0000-0000-0000-000000000000', :id , 'authenticated', 'authenticated', 'other@quivr.app', '$2a$10$vwKX0eMLlrOZvxQEA3Vl4e5V4/hOuxPjGYn9QK1yqeaZxa.42Uhze', '2024-01-22 22:27:00.166861+00', NULL, '', NULL, 'e91d41043ca2c83c3be5a6ee7a4abc8a4f4fb1afc0a8453c502af931', '2024-03-05 16:22:13.780421+00', '', '', NULL, '2024-03-30 23:21:12.077887+00', '{"provider": "email", "providers": ["email"]}', '{}', NULL, '2024-01-22 22:27:00.158026+00', '2024-04-01 17:40:15.332205+00', NULL, NULL, '', '', NULL, '', 0, NULL, '', NULL, false, NULL);
"""
    )
    await session.execute(insert_stmt, params={"id": uuid4()})

    lookup = select(User).where(User.email == "other@quivr.app")
    # ``one()`` raises unless exactly one row matches.
    return (await session.exec(lookup)).one()
|
|
||||||
|
|
||||||
|
|
||||||
@pytest_asyncio.fixture(scope="function")
async def test_data(session: AsyncSession) -> TestData:
    """Persist one brain and two knowledge files owned by the admin user.

    Returns ``(brain, [first, second])``. Only the first knowledge is
    attached to the brain; the second is created with no brains.
    """
    admin_query = select(User).where(User.email == "admin@quivr.app")
    admin = (await session.exec(admin_query)).one()
    assert admin.id

    # Brain data
    test_brain = Brain(
        name="test_brain",
        description="this is a test brain",
        brain_type=BrainType.integration,
    )

    linked_knowledge = KnowledgeDB(
        file_name="test_file_1.txt",
        extension=".txt",
        status="UPLOADED",
        source="test_source",
        source_link="test_source_link",
        file_size=100,
        file_sha1="test_sha1",
        brains=[test_brain],
        user_id=admin.id,
    )

    unlinked_knowledge = KnowledgeDB(
        file_name="test_file_2.txt",
        extension=".txt",
        status="UPLOADED",
        source="test_source",
        source_link="test_source_link",
        file_size=100,
        file_sha1="test_sha2",
        brains=[],
        user_id=admin.id,
    )

    for entity in (test_brain, linked_knowledge, unlinked_knowledge):
        session.add(entity)
    await session.commit()
    return test_brain, [linked_knowledge, unlinked_knowledge]
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio(loop_scope="session")
async def test_updates_knowledge_status(session: AsyncSession, test_data: TestData):
    """A status set to ERROR through the repository is seen when fetching again."""
    brain, knowledges = test_data
    assert brain.brain_id
    target_id = knowledges[0].id
    assert target_id

    repository = KnowledgeRepository(session)
    await repository.update_status_knowledge(target_id, KnowledgeStatus.ERROR)

    fetched = await repository.get_knowledge_by_id(target_id)
    assert fetched.status == KnowledgeStatus.ERROR
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio(loop_scope="session")
async def test_updates_knowledge_status_no_knowledge(
    session: AsyncSession, test_data: TestData
):
    """Updating the status of a random, unknown id is expected to raise ``NoResultFound``."""
    brain, knowledges = test_data
    assert brain.brain_id
    assert knowledges[0].id

    repository = KnowledgeRepository(session)
    unknown_id = uuid4()
    with pytest.raises(NoResultFound):
        await repository.update_status_knowledge(unknown_id, KnowledgeStatus.UPLOADED)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio(loop_scope="session")
async def test_update_knowledge_source_link(session: AsyncSession, test_data: TestData):
    """A source link changed through the repository is seen when fetching again."""
    brain, knowledges = test_data
    assert brain.brain_id
    target_id = knowledges[0].id
    assert target_id

    repository = KnowledgeRepository(session)
    await repository.update_source_link_knowledge(target_id, "new_source_link")

    fetched = await repository.get_knowledge_by_id(target_id)
    assert fetched.source_link == "new_source_link"
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio(loop_scope="session")
async def test_remove_knowledge_from_brain(session: AsyncSession, test_data: TestData):
    """After unlinking, the brain no longer appears among the knowledge's brains."""
    brain, knowledges = test_data
    assert brain.brain_id
    assert knowledges[0].id

    repository = KnowledgeRepository(session)
    unlinked = await repository.remove_knowledge_from_brain(
        knowledges[0].id, brain.brain_id
    )

    remaining_brain_ids = [b.brain_id for b in await unlinked.awaitable_attrs.brains]
    assert brain.brain_id not in remaining_brain_ids
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio(loop_scope="session")
async def test_cascade_remove_knowledge_by_id(
    session: AsyncSession, test_data: TestData
):
    """Removing a knowledge also leaves no brain link and no vector for it."""
    brain, knowledges = test_data
    assert brain.brain_id
    removed_id = knowledges[0].id
    assert removed_id

    repository = KnowledgeRepository(session)
    await repository.remove_knowledge_by_id(removed_id)

    # The knowledge itself can no longer be fetched.
    with pytest.raises(NoResultFound):
        await repository.get_knowledge_by_id(removed_id)

    # No association row remains for it.
    link_query = select(KnowledgeBrain).where(KnowledgeBrain.knowledge_id == removed_id)
    leftover_link = (await session.exec(link_query)).first()
    assert leftover_link is None

    # No vector remains for it.
    vector_query = select(Vector).where(Vector.knowledge_id == removed_id)
    leftover_vector = (await session.exec(vector_query)).first()
    assert leftover_vector is None
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio(loop_scope="session")
async def test_remove_all_knowledges_from_brain(
    session: AsyncSession, test_data: TestData
):
    """After removing all knowledges from a brain, listing that brain returns nothing.

    The storage side of the check is disabled; see the FIXME at the end.
    """
    brain, knowledges = test_data
    assert brain.brain_id

    # supabase_client = get_supabase_client()
    # db = supabase_client
    # storage = db.storage.from_("quivr")

    # storage.upload(f"{brain.brain_id}/test_file_1", b"test_content")

    repository = KnowledgeRepository(session)
    knowledge_service = KnowledgeService(repository)
    await repository.remove_all_knowledges_from_brain(brain.brain_id)

    remaining = await knowledge_service.get_all_knowledge_in_brain(brain.brain_id)
    assert len(remaining) == 0

    # response = storage.list(path=f"{brain.brain_id}")
    # assert response == []
    # FIXME @aminediro &chloedia raise an error when trying to interact with storage UnboundLocalError: cannot access local variable 'response' where it is not associated with a value
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio(loop_scope="session")
async def test_duplicate_sha1_knowledge_same_user(
    session: AsyncSession, test_data: TestData
):
    """Inserting a knowledge with the same sha1 and user as an existing one must fail."""
    brain, [existing_knowledge, _] = test_data
    assert brain.brain_id
    assert existing_knowledge.id
    assert existing_knowledge.file_sha1

    repository = KnowledgeRepository(session)
    # Same sha1 and same owner as the existing knowledge.
    duplicate = KnowledgeDB(
        file_name="test_file_2",
        extension="txt",
        status="UPLOADED",
        source="test_source",
        source_link="test_source_link",
        file_size=100,
        file_sha1=existing_knowledge.file_sha1,
        brains=[brain],
        user_id=existing_knowledge.user_id,
    )

    with pytest.raises(IntegrityError):  # FIXME: Should raise IntegrityError
        await repository.insert_knowledge(duplicate, brain.brain_id)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio(loop_scope="session")
async def test_duplicate_sha1_knowledge_diff_user(
    session: AsyncSession, test_data: TestData, other_user: User
):
    """A sha1 already used by one user can still be inserted for another user."""
    brain, knowledges = test_data
    assert other_user.id
    assert brain.brain_id
    assert knowledges[0].id

    repository = KnowledgeRepository(session)
    same_sha1_other_owner = KnowledgeDB(
        file_name="test_file_2",
        extension="txt",
        status="UPLOADED",
        source="test_source",
        source_link="test_source_link",
        file_size=100,
        file_sha1=knowledges[0].file_sha1,
        brains=[brain],
        user_id=other_user.id,  # random user id
    )

    inserted = await repository.insert_knowledge(same_sha1_other_owner, brain.brain_id)
    assert inserted
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio(loop_scope="session")
async def test_add_knowledge_to_brain(session: AsyncSession, test_data: TestData):
    """Link the second knowledge to the brain and verify both views of the link.

    The second knowledge of ``test_data`` is the one passed to
    ``link_to_brain``, so it is the one checked: first through its
    ``brains`` relationship, then through the ``KnowledgeBrain`` row.
    """
    brain, knowledges = test_data
    assert brain.brain_id
    assert knowledges[1].id
    repo = KnowledgeRepository(session)
    await repo.link_to_brain(knowledges[1], brain.brain_id)
    knowledge = await repo.get_knowledge_by_id(knowledges[1].id)
    brains_of_knowledge = [b.brain_id for b in await knowledge.awaitable_attrs.brains]
    assert brain.brain_id in brains_of_knowledge

    # Both criteria are passed as separate ``where`` arguments so they are
    # AND-ed in SQL. Python's ``and`` does not combine SQL expressions; it
    # evaluates the first one for truthiness and keeps only one operand.
    # The row checked is the one for the knowledge that was just linked.
    query = select(KnowledgeBrain).where(
        KnowledgeBrain.knowledge_id == knowledges[1].id,
        KnowledgeBrain.brain_id == brain.brain_id,
    )
    result = await session.exec(query)
    knowledge_brain = result.first()
    assert knowledge_brain
|
|
||||||
|
|
||||||
|
|
||||||
# Knowledge Service
|
|
||||||
@pytest.mark.asyncio(loop_scope="session")
async def test_get_knowledge_in_brain(session: AsyncSession, test_data: TestData):
    """Listing a brain's knowledge returns exactly the first knowledge of ``test_data``."""
    brain, knowledges = test_data
    assert brain.brain_id

    knowledge_service = KnowledgeService(KnowledgeRepository(session))
    listed = await knowledge_service.get_all_knowledge_in_brain(brain.brain_id)
    assert len(listed) == 1

    expected = knowledges[0]
    expected_brain_ids = [b.brain_id for b in await expected.awaitable_attrs.brains]
    assert listed[0].id == expected.id
    assert listed[0].file_name == expected.file_name
    assert brain.brain_id in expected_brain_ids
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio(loop_scope="session")
async def test_should_process_knowledge_exists(
    session: AsyncSession, test_data: TestData
):
    """``update_sha1_conflict`` is expected to raise ``FileExistsError`` here.

    The incoming knowledge is in the same brain and has the same owner as
    the existing one, and the sha1 passed in is the existing one's sha1.
    """
    brain, [existing_knowledge, _] = test_data
    assert brain.brain_id

    incoming = KnowledgeDB(
        file_name="new",
        extension="txt",
        status="PROCESSING",
        source="test_source",
        source_link="test_source_link",
        file_size=100,
        file_sha1=None,
        brains=[brain],
        user_id=existing_knowledge.user_id,
    )
    session.add(incoming)
    await session.commit()
    await session.refresh(incoming)
    incoming_knowledge = await incoming.to_dto()

    knowledge_service = KnowledgeService(KnowledgeRepository(session))
    assert existing_knowledge.file_sha1
    with pytest.raises(FileExistsError):
        await knowledge_service.update_sha1_conflict(
            incoming_knowledge, brain.brain_id, file_sha1=existing_knowledge.file_sha1
        )
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio(loop_scope="session")
async def test_should_process_knowledge_link_brain(
    session: AsyncSession, test_data: TestData
):
    """A sha1 match in another brain links the old knowledge instead of reprocessing.

    ``previous`` (status UPLOADED) lives in the first brain and the new
    knowledge in a second brain. The test expects ``update_sha1_conflict``
    to return a falsy value, ``previous`` to end up in both brains, and
    the new knowledge to be gone.
    """
    repository = KnowledgeRepository(session)
    knowledge_service = KnowledgeService(repository)
    brain, [existing_knowledge, _] = test_data
    owner_id = existing_knowledge.user_id
    assert brain.brain_id

    previous = KnowledgeDB(
        file_name="prev",
        extension=".txt",
        status=KnowledgeStatus.UPLOADED,
        source="test_source",
        source_link="test_source_link",
        file_size=100,
        file_sha1="test1",
        brains=[brain],
        user_id=owner_id,
    )
    second_brain = Brain(
        name="test_brain",
        description="this is a test brain",
        brain_type=BrainType.integration,
    )
    session.add(second_brain)
    session.add(previous)
    await session.commit()
    await session.refresh(previous)
    await session.refresh(second_brain)

    assert previous.id
    assert second_brain.brain_id

    incoming = KnowledgeDB(
        file_name="new",
        extension="txt",
        status="PROCESSING",
        source="test_source",
        source_link="test_source_link",
        file_size=100,
        file_sha1=None,
        brains=[second_brain],
        user_id=owner_id,
    )
    session.add(incoming)
    await session.commit()
    await session.refresh(incoming)

    incoming_knowledge = await incoming.to_dto()
    assert previous.file_sha1

    should_process = await knowledge_service.update_sha1_conflict(
        incoming_knowledge, second_brain.brain_id, file_sha1=previous.file_sha1
    )
    assert not should_process

    # Check prev knowledge was linked
    assert incoming_knowledge.file_sha1
    reloaded_previous = await knowledge_service.repository.get_knowledge_by_id(
        previous.id
    )
    linked_brains = await reloaded_previous.awaitable_attrs.brains
    assert {b.brain_id for b in linked_brains} == {
        brain.brain_id,
        second_brain.brain_id,
    }
    # Check new knowledge was removed
    assert incoming.id
    with pytest.raises(NoResultFound):
        await knowledge_service.repository.get_knowledge_by_id(incoming.id)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio(loop_scope="session")
async def test_should_process_knowledge_prev_error(
    session: AsyncSession, test_data: TestData
):
    """A sha1 match with an errored knowledge means the new file is processed.

    The test expects ``update_sha1_conflict`` to return a truthy value,
    the errored knowledge to be gone, and the new one to carry a sha1.
    """
    repository = KnowledgeRepository(session)
    knowledge_service = KnowledgeService(repository)
    brain, [existing_knowledge, _] = test_data
    owner_id = existing_knowledge.user_id
    assert brain.brain_id

    errored = KnowledgeDB(
        file_name="prev",
        extension="txt",
        status=KnowledgeStatus.ERROR,
        source="test_source",
        source_link="test_source_link",
        file_size=100,
        file_sha1="test1",
        brains=[brain],
        user_id=owner_id,
    )
    session.add(errored)
    await session.commit()
    await session.refresh(errored)

    assert errored.id

    incoming = KnowledgeDB(
        file_name="new",
        extension="txt",
        status="PROCESSING",
        source="test_source",
        source_link="test_source_link",
        file_size=100,
        file_sha1=None,
        brains=[brain],
        user_id=owner_id,
    )
    session.add(incoming)
    await session.commit()
    await session.refresh(incoming)

    incoming_knowledge = await incoming.to_dto()
    assert errored.file_sha1
    should_process = await knowledge_service.update_sha1_conflict(
        incoming_knowledge, brain.brain_id, file_sha1=errored.file_sha1
    )

    # Checks we should process this file
    assert should_process
    # Previous errored file is cleaned up
    with pytest.raises(NoResultFound):
        await knowledge_service.repository.get_knowledge_by_id(errored.id)

    assert incoming.id
    reloaded = await knowledge_service.repository.get_knowledge_by_id(incoming.id)
    assert reloaded.file_sha1
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio(loop_scope="session")
async def test_get_knowledge_storage_path(session: AsyncSession, test_data: TestData):
    """The storage path looked up via a second brain is the path of the first upload.

    The file is uploaded under the first brain's id, the knowledge is then
    linked to a second brain, and the path is requested with that second
    brain's id.
    """
    brain, [knowledge, _] = test_data
    assert knowledge.file_name
    repository = KnowledgeRepository(session)
    knowledge_service = KnowledgeService(repository)

    second_brain = Brain(
        name="test_brain",
        description="this is a test brain",
        brain_type=BrainType.integration,
    )
    session.add(second_brain)
    await session.commit()
    await session.refresh(second_brain)
    assert second_brain.brain_id

    # 128 random bytes uploaded under "<first brain id>/<file name>".
    payload = os.urandom(128)
    uploaded_path = f"{str(knowledge.brains[0].brain_id)}/{knowledge.file_name}"
    await upload_file_storage(payload, uploaded_path)

    # Link knowledge to two brains
    await repository.link_to_brain(knowledge, second_brain.brain_id)
    resolved_path = await knowledge_service.get_knowledge_storage_path(
        knowledge.file_name, second_brain.brain_id
    )
    assert resolved_path == uploaded_path
|
|
@ -317,10 +317,11 @@ async def test_process_sync_file_noprev(
|
|||||||
created_km = all_km[0]
|
created_km = all_km[0]
|
||||||
assert created_km.file_name == sync_file.name
|
assert created_km.file_name == sync_file.name
|
||||||
assert created_km.extension == ".txt"
|
assert created_km.extension == ".txt"
|
||||||
assert created_km.file_sha1 is not None
|
assert created_km.file_sha1 is None
|
||||||
assert created_km.created_at is not None
|
assert created_km.created_at is not None
|
||||||
assert created_km.metadata == {"sync_file_id": "1"}
|
assert created_km.metadata == {"sync_file_id": "1"}
|
||||||
assert created_km.brain_ids == [brain_1.brain_id]
|
assert len(created_km.brains)> 0
|
||||||
|
assert created_km.brains[0]["brain_id"]== brain_1.brain_id
|
||||||
|
|
||||||
# Assert celery task in correct
|
# Assert celery task in correct
|
||||||
assert task["args"] == ("process_file_task",)
|
assert task["args"] == ("process_file_task",)
|
||||||
@ -409,12 +410,12 @@ async def test_process_sync_file_with_prev(
|
|||||||
created_km = all_km[0]
|
created_km = all_km[0]
|
||||||
assert created_km.file_name == sync_file.name
|
assert created_km.file_name == sync_file.name
|
||||||
assert created_km.extension == ".txt"
|
assert created_km.extension == ".txt"
|
||||||
assert created_km.file_sha1 is not None
|
assert created_km.file_sha1 is None
|
||||||
assert created_km.updated_at
|
assert created_km.updated_at
|
||||||
assert created_km.created_at
|
assert created_km.created_at
|
||||||
assert created_km.updated_at == created_km.created_at # new line
|
assert created_km.updated_at == created_km.created_at # new line
|
||||||
assert created_km.metadata == {"sync_file_id": str(dbfiles[0].id)}
|
assert created_km.metadata == {"sync_file_id": str(dbfiles[0].id)}
|
||||||
assert created_km.brain_ids == [brain_1.brain_id]
|
assert created_km.brains[0]["brain_id"]== brain_1.brain_id
|
||||||
|
|
||||||
# Check file content changed
|
# Check file content changed
|
||||||
assert check_file_exists(str(brain_1.brain_id), sync_file.name)
|
assert check_file_exists(str(brain_1.brain_id), sync_file.name)
|
||||||
|
@ -53,12 +53,10 @@ AsyncClientDep = Annotated[AsyncClient, Depends(get_supabase_async_client)]
|
|||||||
@upload_router.post("/upload", dependencies=[Depends(AuthBearer())], tags=["Upload"])
|
@upload_router.post("/upload", dependencies=[Depends(AuthBearer())], tags=["Upload"])
|
||||||
async def upload_file(
|
async def upload_file(
|
||||||
uploadFile: UploadFile,
|
uploadFile: UploadFile,
|
||||||
client: AsyncClientDep,
|
|
||||||
background_tasks: BackgroundTasks,
|
|
||||||
knowledge_service: KnowledgeServiceDep,
|
knowledge_service: KnowledgeServiceDep,
|
||||||
|
background_tasks: BackgroundTasks,
|
||||||
bulk_id: Optional[UUID] = Query(None, description="The ID of the bulk upload"),
|
bulk_id: Optional[UUID] = Query(None, description="The ID of the bulk upload"),
|
||||||
brain_id: UUID = Query(..., description="The ID of the brain"),
|
brain_id: UUID = Query(..., description="The ID of the brain"),
|
||||||
chat_id: Optional[UUID] = Query(None, description="The ID of the chat"),
|
|
||||||
current_user: UserIdentity = Depends(get_current_user),
|
current_user: UserIdentity = Depends(get_current_user),
|
||||||
integration: Optional[str] = None,
|
integration: Optional[str] = None,
|
||||||
integration_link: Optional[str] = None,
|
integration_link: Optional[str] = None,
|
||||||
@ -121,7 +119,7 @@ async def upload_file(
|
|||||||
file_size=uploadFile.size,
|
file_size=uploadFile.size,
|
||||||
file_sha1=None,
|
file_sha1=None,
|
||||||
)
|
)
|
||||||
knowledge = await knowledge_service.insert_knowledge(
|
knowledge = await knowledge_service.insert_knowledge_brain(
|
||||||
user_id=current_user.id, knowledge_to_add=knowledge_to_add
|
user_id=current_user.id, knowledge_to_add=knowledge_to_add
|
||||||
) # type: ignore
|
) # type: ignore
|
||||||
|
|
||||||
|
@ -87,7 +87,7 @@ async def crawl_endpoint(
|
|||||||
source_link=crawl_website.url,
|
source_link=crawl_website.url,
|
||||||
)
|
)
|
||||||
|
|
||||||
added_knowledge = await knowledge_service.insert_knowledge(
|
added_knowledge = await knowledge_service.insert_knowledge_brain(
|
||||||
knowledge_to_add=knowledge_to_add, user_id=current_user.id
|
knowledge_to_add=knowledge_to_add, user_id=current_user.id
|
||||||
)
|
)
|
||||||
logger.info(f"Knowledge {added_knowledge} added successfully")
|
logger.info(f"Knowledge {added_knowledge} added successfully")
|
||||||
|
50
backend/api/quivr_api/utils/partial.py
Normal file
50
backend/api/quivr_api/utils/partial.py
Normal file
@ -0,0 +1,50 @@
|
|||||||
|
from copy import deepcopy
|
||||||
|
from typing import Any, Callable, Optional, Type, TypeVar
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
|
from pydantic import BaseModel, create_model
|
||||||
|
from pydantic.fields import FieldInfo
|
||||||
|
|
||||||
|
Model = TypeVar("Model", bound=Type[BaseModel])
|
||||||
|
|
||||||
|
|
||||||
|
def all_optional(without_fields: list[str] | None = None) -> Callable[[Model], Model]:
    """Return a class decorator that rebuilds a pydantic model with optional fields.

    Every field of the decorated model that is not named in
    ``without_fields`` is copied, given a default of ``None`` and an
    ``Optional[...]`` annotation, and passed to ``create_model`` under the
    original model's name and module.

    Args:
        without_fields: Names of fields to leave out of the rebuilt model.
            ``None`` (the default) leaves nothing out.

    Returns:
        A decorator taking a model class and returning the rebuilt class.
    """
    excluded: list[str] = [] if without_fields is None else without_fields

    def _as_optional(field: FieldInfo, default: Any = None) -> tuple[Any, FieldInfo]:
        # Work on a copy so the source model's FieldInfo is left untouched.
        relaxed = deepcopy(field)
        relaxed.default = default
        relaxed.annotation = Optional[field.annotation]
        return relaxed.annotation, relaxed

    def wrapper(model: Type[Model]) -> Type[Model]:
        # With excluded fields the new model derives from plain BaseModel
        # rather than from ``model``.
        # NOTE(review): presumably so the excluded fields are not inherited -- confirm.
        base_model: Type[Model] = BaseModel if excluded else model

        optional_fields = {}
        for field_name, field_info in model.model_fields.items():
            if field_name in excluded:
                continue
            optional_fields[field_name] = _as_optional(field_info)

        return create_model(
            model.__name__,
            __base__=base_model,
            __module__=model.__module__,
            **optional_fields,
        )

    return wrapper
|
||||||
|
|
||||||
|
|
||||||
|
class Test(BaseModel):
    """Sample model with a required ``id`` and an optional ``name``."""

    # Required: no default is given.
    id: UUID
    # Optional: defaults to None.
    name: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
|
@all_optional()
class TestUpdate(Test):
    """``Test`` passed through ``all_optional()`` with no excluded fields."""
|
@ -42,6 +42,7 @@ class KnowledgeStatus(str, Enum):
|
|||||||
PROCESSING = "PROCESSING"
|
PROCESSING = "PROCESSING"
|
||||||
UPLOADED = "UPLOADED"
|
UPLOADED = "UPLOADED"
|
||||||
ERROR = "ERROR"
|
ERROR = "ERROR"
|
||||||
|
RESERVED = "RESERVED"
|
||||||
|
|
||||||
|
|
||||||
class Source(BaseModel):
|
class Source(BaseModel):
|
||||||
|
@ -0,0 +1,31 @@
|
|||||||
|
-- Set a 3 minute idle_session_timeout for the postgres role.
ALTER USER postgres
SET idle_session_timeout = '3min';
-- Set a 3 minute idle_in_transaction_session_timeout for the postgres role.
ALTER USER postgres
SET idle_in_transaction_session_timeout = '3min';
-- Drop the previous constraint
alter table "public"."knowledge" drop constraint "unique_file_sha1_user_id";
-- Add the is_folder flag; new rows default to false.
alter table "public"."knowledge"
add column "is_folder" boolean default false;
-- Backfill every existing knowledge row with is_folder = false
UPDATE "public"."knowledge"
SET is_folder = false;
-- Add parent_id -> folder
alter table "public"."knowledge"
add column "parent_id" uuid;
-- Self-referencing foreign key: deleting a knowledge row also deletes the
-- rows that point to it through parent_id (ON DELETE CASCADE).
alter table "public"."knowledge"
add constraint "public_knowledge_parent_id_fkey" FOREIGN KEY (parent_id) REFERENCES knowledge(id) ON DELETE CASCADE;
|
||||||
|
-- Add constraint must be folder for parent_id
|
||||||
|
-- Return the is_folder flag of the knowledge row whose id equals folder_id.
-- The result is NULL when no row has that id.
CREATE FUNCTION is_parent_folder(folder_id uuid) RETURNS boolean
LANGUAGE plpgsql
AS $$
DECLARE
    v_is_folder boolean;
BEGIN
    SELECT k.is_folder INTO v_is_folder
    FROM public.knowledge AS k
    WHERE k.id = folder_id;

    RETURN v_is_folder;
END;
$$;
|
||||||
|
-- A row may only reference a parent when is_parent_folder() reports that
-- the parent is a folder; rows without a parent are always accepted.
-- NOTE(review): this CHECK reads another row of the same table. PostgreSQL
-- evaluates CHECK constraints only when the constrained row is inserted or
-- updated, so flipping is_folder on a parent afterwards is presumably not
-- re-validated for its children; confirm this is acceptable.
ALTER TABLE public.knowledge
ADD CONSTRAINT check_parent_is_folder CHECK (
    parent_id IS NULL
    OR is_parent_folder(parent_id)
);
-- Index on parent_id
CREATE INDEX knowledge_parent_id_idx ON public.knowledge USING btree (parent_id);
|
@ -13,7 +13,7 @@ from quivr_api.modules.brain.repository.brains_vectors import BrainsVectors
|
|||||||
from quivr_api.modules.brain.service.brain_service import BrainService
|
from quivr_api.modules.brain.service.brain_service import BrainService
|
||||||
from quivr_api.modules.dependencies import get_supabase_client
|
from quivr_api.modules.dependencies import get_supabase_client
|
||||||
from quivr_api.modules.knowledge.repository.knowledges import KnowledgeRepository
|
from quivr_api.modules.knowledge.repository.knowledges import KnowledgeRepository
|
||||||
from quivr_api.modules.knowledge.repository.storage import Storage
|
from quivr_api.modules.knowledge.repository.storage import SupabaseS3Storage
|
||||||
from quivr_api.modules.knowledge.service.knowledge_service import KnowledgeService
|
from quivr_api.modules.knowledge.service.knowledge_service import KnowledgeService
|
||||||
from quivr_api.modules.notification.service.notification_service import (
|
from quivr_api.modules.notification.service.notification_service import (
|
||||||
NotificationService,
|
NotificationService,
|
||||||
@ -58,7 +58,7 @@ sync_user_service = SyncUserService()
|
|||||||
sync_files_repo_service = SyncFilesRepository()
|
sync_files_repo_service = SyncFilesRepository()
|
||||||
brain_service = BrainService()
|
brain_service = BrainService()
|
||||||
brain_vectors = BrainsVectors()
|
brain_vectors = BrainsVectors()
|
||||||
storage = Storage()
|
storage = SupabaseS3Storage()
|
||||||
notion_service: SyncNotionService | None = None
|
notion_service: SyncNotionService | None = None
|
||||||
async_engine: AsyncEngine | None = None
|
async_engine: AsyncEngine | None = None
|
||||||
engine: Engine | None = None
|
engine: Engine | None = None
|
||||||
@ -170,6 +170,8 @@ async def aprocess_file_task(
|
|||||||
integration_link=source_link,
|
integration_link=source_link,
|
||||||
delete_file=delete_file,
|
delete_file=delete_file,
|
||||||
)
|
)
|
||||||
|
session.commit()
|
||||||
|
await async_session.commit()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
session.rollback()
|
session.rollback()
|
||||||
await async_session.rollback()
|
await async_session.rollback()
|
||||||
@ -196,7 +198,11 @@ def process_crawl_task(
|
|||||||
)
|
)
|
||||||
global engine
|
global engine
|
||||||
assert engine
|
assert engine
|
||||||
|
try:
|
||||||
with Session(engine, expire_on_commit=False, autoflush=False) as session:
|
with Session(engine, expire_on_commit=False, autoflush=False) as session:
|
||||||
|
session.execute(
|
||||||
|
text("SET SESSION idle_in_transaction_session_timeout = '5min';")
|
||||||
|
)
|
||||||
vector_repository = VectorRepository(session)
|
vector_repository = VectorRepository(session)
|
||||||
vector_service = VectorService(vector_repository)
|
vector_service = VectorService(vector_repository)
|
||||||
loop = asyncio.get_event_loop()
|
loop = asyncio.get_event_loop()
|
||||||
@ -209,6 +215,12 @@ def process_crawl_task(
|
|||||||
vector_service=vector_service,
|
vector_service=vector_service,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
session.commit()
|
||||||
|
except Exception as e:
|
||||||
|
session.rollback()
|
||||||
|
raise e
|
||||||
|
finally:
|
||||||
|
session.close()
|
||||||
|
|
||||||
|
|
||||||
@celery.task(name="NotionConnectorLoad")
|
@celery.task(name="NotionConnectorLoad")
|
||||||
|
@ -2,6 +2,7 @@ from uuid import UUID
|
|||||||
|
|
||||||
from quivr_api.logger import get_logger
|
from quivr_api.logger import get_logger
|
||||||
from quivr_api.modules.brain.service.brain_service import BrainService
|
from quivr_api.modules.brain.service.brain_service import BrainService
|
||||||
|
from quivr_api.modules.knowledge.entity.knowledge import KnowledgeUpdate
|
||||||
from quivr_api.modules.knowledge.service.knowledge_service import KnowledgeService
|
from quivr_api.modules.knowledge.service.knowledge_service import KnowledgeService
|
||||||
from quivr_api.modules.vector.service.vector_service import VectorService
|
from quivr_api.modules.vector.service.vector_service import VectorService
|
||||||
|
|
||||||
@ -41,12 +42,10 @@ async def process_uploaded_file(
|
|||||||
# If we have some knowledge with error
|
# If we have some knowledge with error
|
||||||
with build_file(file_data, knowledge_id, file_name) as file_instance:
|
with build_file(file_data, knowledge_id, file_name) as file_instance:
|
||||||
knowledge = await knowledge_service.get_knowledge(knowledge_id=knowledge_id)
|
knowledge = await knowledge_service.get_knowledge(knowledge_id=knowledge_id)
|
||||||
should_process = await knowledge_service.update_sha1_conflict(
|
await knowledge_service.update_knowledge(
|
||||||
knowledge=knowledge,
|
knowledge,
|
||||||
brain_id=brain.brain_id,
|
KnowledgeUpdate(file_sha1=file_instance.file_sha1), # type: ignore
|
||||||
file_sha1=file_instance.file_sha1,
|
|
||||||
)
|
)
|
||||||
if should_process:
|
|
||||||
await process_file(
|
await process_file(
|
||||||
file_instance=file_instance,
|
file_instance=file_instance,
|
||||||
brain=brain,
|
brain=brain,
|
||||||
|
@ -141,7 +141,7 @@ async def process_notion_sync(
|
|||||||
UUID(user_id),
|
UUID(user_id),
|
||||||
notion_client, # type: ignore
|
notion_client, # type: ignore
|
||||||
)
|
)
|
||||||
|
await session.commit()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
await session.rollback()
|
await session.rollback()
|
||||||
raise e
|
raise e
|
||||||
|
@ -40,6 +40,8 @@ async def fetch_and_store_notion_files_async(
|
|||||||
else:
|
else:
|
||||||
logger.warn("No notion page fetched")
|
logger.warn("No notion page fetched")
|
||||||
|
|
||||||
|
# Commit all before exiting
|
||||||
|
await session.commit()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
await session.rollback()
|
await session.rollback()
|
||||||
raise e
|
raise e
|
||||||
|
@ -6,7 +6,7 @@ from quivr_api.celery_config import celery
|
|||||||
from quivr_api.logger import get_logger
|
from quivr_api.logger import get_logger
|
||||||
from quivr_api.modules.brain.repository.brains_vectors import BrainsVectors
|
from quivr_api.modules.brain.repository.brains_vectors import BrainsVectors
|
||||||
from quivr_api.modules.knowledge.repository.knowledges import KnowledgeRepository
|
from quivr_api.modules.knowledge.repository.knowledges import KnowledgeRepository
|
||||||
from quivr_api.modules.knowledge.repository.storage import Storage
|
from quivr_api.modules.knowledge.repository.storage import SupabaseS3Storage
|
||||||
from quivr_api.modules.knowledge.service.knowledge_service import KnowledgeService
|
from quivr_api.modules.knowledge.service.knowledge_service import KnowledgeService
|
||||||
from quivr_api.modules.notification.service.notification_service import (
|
from quivr_api.modules.notification.service.notification_service import (
|
||||||
NotificationService,
|
NotificationService,
|
||||||
@ -42,7 +42,7 @@ class SyncServices:
|
|||||||
sync_files_repo_service: SyncFilesRepository
|
sync_files_repo_service: SyncFilesRepository
|
||||||
notification_service: NotificationService
|
notification_service: NotificationService
|
||||||
brain_vectors: BrainsVectors
|
brain_vectors: BrainsVectors
|
||||||
storage: Storage
|
storage: SupabaseS3Storage
|
||||||
|
|
||||||
|
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
@ -56,7 +56,6 @@ async def build_syncs_utils(
|
|||||||
await session.execute(
|
await session.execute(
|
||||||
text("SET SESSION idle_in_transaction_session_timeout = '5min';")
|
text("SET SESSION idle_in_transaction_session_timeout = '5min';")
|
||||||
)
|
)
|
||||||
# TODO pass services from celery_worker
|
|
||||||
notion_repository = NotionRepository(session)
|
notion_repository = NotionRepository(session)
|
||||||
notion_service = SyncNotionService(notion_repository)
|
notion_service = SyncNotionService(notion_repository)
|
||||||
knowledge_service = KnowledgeService(KnowledgeRepository(session))
|
knowledge_service = KnowledgeService(KnowledgeRepository(session))
|
||||||
@ -84,7 +83,7 @@ async def build_syncs_utils(
|
|||||||
mapping_sync_utils[provider_name] = provider_sync_util
|
mapping_sync_utils[provider_name] = provider_sync_util
|
||||||
|
|
||||||
yield mapping_sync_utils
|
yield mapping_sync_utils
|
||||||
|
await session.commit()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
await session.rollback()
|
await session.rollback()
|
||||||
raise e
|
raise e
|
||||||
|
Loading…
Reference in New Issue
Block a user