mirror of
https://github.com/QuivrHQ/quivr.git
synced 2024-12-15 01:21:48 +03:00
fix: knowledge user (#3124)
# Description - Add `user_id` foreign key to knowledge table - Updates knowledge service - Updates sha1 constraint to be on (`user_id`,`file_sha1`) on knowledge - Fix event_loop fixture - Adds knowledge tests --------- Co-authored-by: chloedia <chloedaems0@gmail.com> Co-authored-by: Stan Girard <stan@quivr.app> Co-authored-by: aminediro <aminedirhoussi@gmail.com> Co-authored-by: Antoine Dewez <44063631+Zewed@users.noreply.github.com> Co-authored-by: Stan Girard <girard.stanislas@gmail.com> Co-authored-by: Chloé Daems <73901882+chloedia@users.noreply.github.com> Co-authored-by: porter-deployment-app[bot] <87230664+porter-deployment-app[bot]@users.noreply.github.com> Co-authored-by: Zewed <dewez.antoine2@gmail.com>
This commit is contained in:
parent
077d470602
commit
b59fc4c0c9
@ -25,8 +25,11 @@ n_seed_chats_history = 3
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
def event_loop():
    """Provide a single asyncio event loop for the whole test session.

    Reuses the loop already running in this thread when there is one;
    otherwise creates a new loop. Only a loop created here is closed on
    teardown.

    Yields:
        The event loop to be used by the session's async tests.
    """
    try:
        loop = asyncio.get_running_loop()
        owns_loop = False
    except RuntimeError:
        # No loop is running in this thread, so create one for the session.
        loop = asyncio.new_event_loop()
        owns_loop = True
    yield loop
    # A borrowed loop is still running and is not ours to close; calling
    # close() on a running loop raises RuntimeError.
    if owns_loop:
        loop.close()
|
||||
|
||||
|
@ -66,6 +66,7 @@ class KnowledgeDB(AsyncAttrs, SQLModel, table=True):
|
||||
metadata_: Optional[Dict[str, str]] = Field(
|
||||
default=None, sa_column=Column("metadata", JSON)
|
||||
)
|
||||
user_id: UUID = Field(foreign_key="users.id", nullable=False)
|
||||
brains: List["Brain"] = Relationship(
|
||||
back_populates="knowledges",
|
||||
link_model=KnowledgeBrain,
|
||||
|
@ -45,6 +45,7 @@ class KnowledgeService(BaseService[KnowledgeRepository]):
|
||||
|
||||
async def insert_knowledge(
|
||||
self,
|
||||
user_id: UUID,
|
||||
knowledge_to_add: CreateKnowledgeProperties, # FIXME: (later) @Amine brain id should not be in CreateKnowledgeProperties but since storage is brain_id/file_name
|
||||
) -> Knowledge:
|
||||
knowledge = KnowledgeDB(
|
||||
@ -57,6 +58,7 @@ class KnowledgeService(BaseService[KnowledgeRepository]):
|
||||
file_size=knowledge_to_add.file_size,
|
||||
file_sha1=knowledge_to_add.file_sha1,
|
||||
metadata_=knowledge_to_add.metadata, # type: ignore
|
||||
user_id=user_id,
|
||||
)
|
||||
|
||||
knowledge_db = await self.repository.insert_knowledge(
|
||||
@ -133,6 +135,7 @@ class KnowledgeService(BaseService[KnowledgeRepository]):
|
||||
async def update_or_create_knowledge_sync(
|
||||
self,
|
||||
brain_id: UUID,
|
||||
user_id: UUID,
|
||||
file: SyncFile,
|
||||
new_sync_file: DBSyncFile | None,
|
||||
prev_sync_file: DBSyncFile | None,
|
||||
@ -160,5 +163,7 @@ class KnowledgeService(BaseService[KnowledgeRepository]):
|
||||
# FIXME (@aminediro): This is a temporary fix, redo in KMS
|
||||
metadata={"sync_file_id": str(sync_id)},
|
||||
)
|
||||
added_knowledge = await self.insert_knowledge(knowledge_to_add)
|
||||
added_knowledge = await self.insert_knowledge(
|
||||
knowledge_to_add=knowledge_to_add, user_id=user_id
|
||||
)
|
||||
return added_knowledge
|
||||
|
@ -1,6 +1,7 @@
|
||||
import asyncio
|
||||
import os
|
||||
from typing import List, Tuple
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
@ -11,10 +12,11 @@ from quivr_api.modules.knowledge.entity.knowledge import KnowledgeDB
|
||||
from quivr_api.modules.knowledge.entity.knowledge_brain import KnowledgeBrain
|
||||
from quivr_api.modules.knowledge.repository.knowledges import KnowledgeRepository
|
||||
from quivr_api.modules.knowledge.service.knowledge_service import KnowledgeService
|
||||
from quivr_api.modules.user.entity.user_identity import User
|
||||
from quivr_api.vector.entity.vector import Vector
|
||||
from sqlalchemy.exc import IntegrityError, NoResultFound
|
||||
from sqlalchemy.ext.asyncio import create_async_engine
|
||||
from sqlmodel import select
|
||||
from sqlmodel import select, text
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
pg_database_base_url = "postgres:postgres@localhost:54322/postgres"
|
||||
@ -23,8 +25,11 @@ TestData = Tuple[Brain, List[KnowledgeDB]]
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
def event_loop():
    """Provide a single asyncio event loop for the whole test session.

    Reuses the loop already running in this thread when there is one;
    otherwise creates a new loop. Only a loop created here is closed on
    teardown.

    Yields:
        The event loop to be used by the session's async tests.
    """
    try:
        loop = asyncio.get_running_loop()
        owns_loop = False
    except RuntimeError:
        # No loop is running in this thread, so create one for the session.
        loop = asyncio.new_event_loop()
        owns_loop = True
    yield loop
    # A borrowed loop is still running and is not ours to close; calling
    # close() on a running loop raises RuntimeError.
    if owns_loop:
        loop.close()
|
||||
|
||||
@ -61,8 +66,28 @@ async def session(async_engine):
|
||||
yield async_session
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
async def other_user(session: AsyncSession):
    """Insert a second auth user and return the corresponding `User`.

    A row with a freshly generated id and the email ``other@quivr.app`` is
    inserted into ``auth.users`` through the given session; the `User` with
    that email is then selected through the same session and returned.
    """
    insert_stmt = text(
        """
        INSERT INTO "auth"."users" ("instance_id", "id", "aud", "role", "email", "encrypted_password", "email_confirmed_at", "invited_at", "confirmation_token", "confirmation_sent_at", "recovery_token", "recovery_sent_at", "email_change_token_new", "email_change", "email_change_sent_at", "last_sign_in_at", "raw_app_meta_data", "raw_user_meta_data", "is_super_admin", "created_at", "updated_at", "phone", "phone_confirmed_at", "phone_change", "phone_change_token", "phone_change_sent_at", "email_change_token_current", "email_change_confirm_status", "banned_until", "reauthentication_token", "reauthentication_sent_at", "is_sso_user", "deleted_at") VALUES
        ('00000000-0000-0000-0000-000000000000', :id , 'authenticated', 'authenticated', 'other@quivr.app', '$2a$10$vwKX0eMLlrOZvxQEA3Vl4e5V4/hOuxPjGYn9QK1yqeaZxa.42Uhze', '2024-01-22 22:27:00.166861+00', NULL, '', NULL, 'e91d41043ca2c83c3be5a6ee7a4abc8a4f4fb1afc0a8453c502af931', '2024-03-05 16:22:13.780421+00', '', '', NULL, '2024-03-30 23:21:12.077887+00', '{"provider": "email", "providers": ["email"]}', '{}', NULL, '2024-01-22 22:27:00.158026+00', '2024-04-01 17:40:15.332205+00', NULL, NULL, '', '', NULL, '', 0, NULL, '', NULL, false, NULL);
        """
    )
    # The id is bound as a parameter; every other column value is fixed.
    new_user_id = uuid4()
    await session.execute(insert_stmt, params={"id": new_user_id})

    by_email = select(User).where(User.email == "other@quivr.app")
    matches = await session.exec(by_email)
    # .one() raises unless exactly one row matches the email.
    return matches.one()
|
||||
|
||||
|
||||
@pytest_asyncio.fixture()
|
||||
async def test_data(session: AsyncSession) -> TestData:
|
||||
user_1 = (
|
||||
await session.exec(select(User).where(User.email == "admin@quivr.app"))
|
||||
).one()
|
||||
assert user_1.id
|
||||
# Brain data
|
||||
brain_1 = Brain(
|
||||
name="test_brain",
|
||||
@ -79,6 +104,7 @@ async def test_data(session: AsyncSession) -> TestData:
|
||||
file_size=100,
|
||||
file_sha1="test_sha1",
|
||||
brains=[brain_1],
|
||||
user_id=user_1.id,
|
||||
)
|
||||
|
||||
knowledge_brain_2 = KnowledgeDB(
|
||||
@ -90,6 +116,7 @@ async def test_data(session: AsyncSession) -> TestData:
|
||||
file_size=100,
|
||||
file_sha1="test_sha2",
|
||||
brains=[],
|
||||
user_id=user_1.id,
|
||||
)
|
||||
|
||||
session.add(brain_1)
|
||||
@ -183,7 +210,9 @@ async def test_remove_all_knowledges_from_brain(
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_duplicate_sha1_knowledge(session: AsyncSession, test_data: TestData):
|
||||
async def test_duplicate_sha1_knowledge_same_user(
|
||||
session: AsyncSession, test_data: TestData
|
||||
):
|
||||
brain, knowledges = test_data
|
||||
assert brain.brain_id
|
||||
assert knowledges[0].id
|
||||
@ -197,12 +226,38 @@ async def test_duplicate_sha1_knowledge(session: AsyncSession, test_data: TestDa
|
||||
file_size=100,
|
||||
file_sha1="test_sha1",
|
||||
brains=[brain],
|
||||
user_id=knowledges[0].user_id,
|
||||
)
|
||||
|
||||
with pytest.raises(IntegrityError): # FIXME: Should raise IntegrityError
|
||||
await repo.insert_knowledge(knowledge, brain.brain_id)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_duplicate_sha1_knowledge_diff_user(
    session: AsyncSession, test_data: TestData, other_user: User
):
    """A knowledge reusing an existing sha1 can be inserted for another user."""
    brain, seeded_knowledges = test_data
    first_knowledge = seeded_knowledges[0]
    assert other_user.id
    assert brain.brain_id
    assert first_knowledge.id

    duplicate_sha1_knowledge = KnowledgeDB(
        file_name="test_file_2",
        extension="txt",
        status="UPLOADED",
        source="test_source",
        source_link="test_source_link",
        file_size=100,
        file_sha1=first_knowledge.file_sha1,
        brains=[brain],
        user_id=other_user.id,  # id of the user created by the other_user fixture
    )

    repository = KnowledgeRepository(session)
    inserted = await repository.insert_knowledge(
        duplicate_sha1_knowledge, brain.brain_id
    )
    assert inserted
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_add_knowledge_to_brain(session: AsyncSession, test_data: TestData):
|
||||
brain, knowledges = test_data
|
||||
|
@ -5,21 +5,23 @@ from typing import Tuple
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
import sqlalchemy
|
||||
from quivr_api.modules.models.entity.model import Model
|
||||
from quivr_api.modules.user.entity.user_identity import User
|
||||
from sqlalchemy.ext.asyncio import create_async_engine
|
||||
from sqlmodel import select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from quivr_api.modules.models.entity.model import Model
|
||||
from quivr_api.modules.user.entity.user_identity import User
|
||||
|
||||
pg_database_base_url = "postgres:postgres@localhost:54322/postgres"
|
||||
|
||||
TestData = Tuple[Model, Model, User]
|
||||
|
||||
|
||||
@pytest.fixture(scope="session", autouse=True)
def event_loop():
    """Provide a single asyncio event loop for the whole test session.

    Reuses the loop already running in this thread when there is one;
    otherwise creates a new loop. Only a loop created here is closed on
    teardown.

    Yields:
        The event loop to be used by the session's async tests.
    """
    try:
        loop = asyncio.get_running_loop()
        owns_loop = False
    except RuntimeError:
        # No loop is running in this thread, so create one for the session.
        loop = asyncio.new_event_loop()
        owns_loop = True
    yield loop
    # A borrowed loop is still running and is not ours to close; calling
    # close() on a running loop raises RuntimeError.
    if owns_loop:
        loop.close()
|
||||
|
||||
|
@ -182,8 +182,11 @@ def fetch_response():
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
def event_loop():
    """Provide a single asyncio event loop for the whole test session.

    Reuses the loop already running in this thread when there is one;
    otherwise creates a new loop. Only a loop created here is closed on
    teardown.

    Yields:
        The event loop to be used by the session's async tests.
    """
    try:
        loop = asyncio.get_running_loop()
        owns_loop = False
    except RuntimeError:
        # No loop is running in this thread, so create one for the session.
        loop = asyncio.new_event_loop()
        owns_loop = True
    yield loop
    # A borrowed loop is still running and is not ours to close; calling
    # close() on a running loop raises RuntimeError.
    if owns_loop:
        loop.close()
|
||||
|
||||
|
@ -196,6 +196,7 @@ class SyncUtils:
|
||||
downloaded_file=downloaded_file,
|
||||
source=source,
|
||||
source_link=source_link,
|
||||
user_id=current_user.user_id,
|
||||
)
|
||||
|
||||
# Send file for processing
|
||||
|
@ -121,7 +121,9 @@ async def upload_file(
|
||||
file_size=uploadFile.size,
|
||||
file_sha1=None,
|
||||
)
|
||||
knowledge = await knowledge_service.insert_knowledge(knowledge_to_add) # type: ignore
|
||||
knowledge = await knowledge_service.insert_knowledge(
|
||||
user_id=current_user.id, knowledge_to_add=knowledge_to_add
|
||||
) # type: ignore
|
||||
|
||||
celery.send_task(
|
||||
"process_file_task",
|
||||
|
@ -87,7 +87,9 @@ async def crawl_endpoint(
|
||||
source_link=crawl_website.url,
|
||||
)
|
||||
|
||||
added_knowledge = await knowledge_service.insert_knowledge(knowledge_to_add)
|
||||
added_knowledge = await knowledge_service.insert_knowledge(
|
||||
knowledge_to_add=knowledge_to_add, user_id=current_user.id
|
||||
)
|
||||
logger.info(f"Knowledge {added_knowledge} added successfully")
|
||||
|
||||
celery.send_task(
|
||||
|
@ -64,7 +64,7 @@ def test_data(session: Session, embedder) -> TestData:
|
||||
)
|
||||
knowledge_1 = KnowledgeDB(
|
||||
file_name="test_file_1",
|
||||
mime_type="txt",
|
||||
extension=".txt",
|
||||
status="UPLOADED",
|
||||
source="test_source",
|
||||
source_link="test_source_link",
|
||||
|
@ -0,0 +1,10 @@
|
||||
-- Add a user_id column to knowledge (no NOT NULL or default is set here).
ALTER TABLE "public"."knowledge"
    ADD COLUMN "user_id" uuid;

-- Reference users(id), cascading on update and delete. The constraint is
-- created NOT VALID and validated in a separate statement below.
ALTER TABLE "public"."knowledge"
    ADD CONSTRAINT "public_knowledge_user_id_fkey"
    FOREIGN KEY (user_id) REFERENCES users(id)
    ON UPDATE CASCADE ON DELETE CASCADE
    NOT VALID;

ALTER TABLE "public"."knowledge"
    VALIDATE CONSTRAINT "public_knowledge_user_id_fkey";

-- Replace the previous "unique_file_sha1" constraint with a uniqueness rule
-- on the (file_sha1, user_id) pair.
ALTER TABLE "public"."knowledge"
    DROP CONSTRAINT "unique_file_sha1";

ALTER TABLE "public"."knowledge"
    ADD CONSTRAINT "unique_file_sha1_user_id" UNIQUE ("file_sha1", "user_id");
|
Loading…
Reference in New Issue
Block a user