fix: knowledge user (#3124)

# Description

- Add `user_id` foreign key to knowledge table
- Updates knowledge service
- Updates sha1 constraint to be on (`user_id`,`file_sha1`) on knowledge
- Fix event_loop fixture
- Adds knowledge tests

---------

Co-authored-by: chloedia <chloedaems0@gmail.com>
Co-authored-by: Stan Girard <stan@quivr.app>
Co-authored-by: aminediro <aminedirhoussi@gmail.com>
Co-authored-by: Antoine Dewez <44063631+Zewed@users.noreply.github.com>
Co-authored-by: Stan Girard <girard.stanislas@gmail.com>
Co-authored-by: Chloé Daems <73901882+chloedia@users.noreply.github.com>
Co-authored-by: porter-deployment-app[bot] <87230664+porter-deployment-app[bot]@users.noreply.github.com>
Co-authored-by: Zewed <dewez.antoine2@gmail.com>
This commit is contained in:
AmineDiro 2024-09-02 15:07:30 +02:00 committed by GitHub
parent 077d470602
commit b59fc4c0c9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
11 changed files with 102 additions and 18 deletions

View File

@ -25,8 +25,11 @@ n_seed_chats_history = 3
@pytest.fixture(scope="session")
def event_loop(request: pytest.FixtureRequest):
loop = asyncio.get_event_loop_policy().new_event_loop()
def event_loop():
try:
loop = asyncio.get_running_loop()
except RuntimeError:
loop = asyncio.new_event_loop()
yield loop
loop.close()

View File

@ -66,6 +66,7 @@ class KnowledgeDB(AsyncAttrs, SQLModel, table=True):
metadata_: Optional[Dict[str, str]] = Field(
default=None, sa_column=Column("metadata", JSON)
)
user_id: UUID = Field(foreign_key="users.id", nullable=False)
brains: List["Brain"] = Relationship(
back_populates="knowledges",
link_model=KnowledgeBrain,

View File

@ -45,6 +45,7 @@ class KnowledgeService(BaseService[KnowledgeRepository]):
async def insert_knowledge(
self,
user_id: UUID,
knowledge_to_add: CreateKnowledgeProperties, # FIXME: (later) @Amine brain id should not be in CreateKnowledgeProperties but since storage is brain_id/file_name
) -> Knowledge:
knowledge = KnowledgeDB(
@ -57,6 +58,7 @@ class KnowledgeService(BaseService[KnowledgeRepository]):
file_size=knowledge_to_add.file_size,
file_sha1=knowledge_to_add.file_sha1,
metadata_=knowledge_to_add.metadata, # type: ignore
user_id=user_id,
)
knowledge_db = await self.repository.insert_knowledge(
@ -133,6 +135,7 @@ class KnowledgeService(BaseService[KnowledgeRepository]):
async def update_or_create_knowledge_sync(
self,
brain_id: UUID,
user_id: UUID,
file: SyncFile,
new_sync_file: DBSyncFile | None,
prev_sync_file: DBSyncFile | None,
@ -160,5 +163,7 @@ class KnowledgeService(BaseService[KnowledgeRepository]):
# FIXME (@aminediro): This is a temporary fix, redo in KMS
metadata={"sync_file_id": str(sync_id)},
)
added_knowledge = await self.insert_knowledge(knowledge_to_add)
added_knowledge = await self.insert_knowledge(
knowledge_to_add=knowledge_to_add, user_id=user_id
)
return added_knowledge

View File

@ -1,6 +1,7 @@
import asyncio
import os
from typing import List, Tuple
from uuid import uuid4
import pytest
import pytest_asyncio
@ -11,10 +12,11 @@ from quivr_api.modules.knowledge.entity.knowledge import KnowledgeDB
from quivr_api.modules.knowledge.entity.knowledge_brain import KnowledgeBrain
from quivr_api.modules.knowledge.repository.knowledges import KnowledgeRepository
from quivr_api.modules.knowledge.service.knowledge_service import KnowledgeService
from quivr_api.modules.user.entity.user_identity import User
from quivr_api.vector.entity.vector import Vector
from sqlalchemy.exc import IntegrityError, NoResultFound
from sqlalchemy.ext.asyncio import create_async_engine
from sqlmodel import select
from sqlmodel import select, text
from sqlmodel.ext.asyncio.session import AsyncSession
pg_database_base_url = "postgres:postgres@localhost:54322/postgres"
@ -23,8 +25,11 @@ TestData = Tuple[Brain, List[KnowledgeDB]]
@pytest.fixture(scope="session")
def event_loop(request: pytest.FixtureRequest):
loop = asyncio.get_event_loop_policy().new_event_loop()
def event_loop():
try:
loop = asyncio.get_running_loop()
except RuntimeError:
loop = asyncio.new_event_loop()
yield loop
loop.close()
@ -61,8 +66,28 @@ async def session(async_engine):
yield async_session
@pytest_asyncio.fixture
async def other_user(session: AsyncSession):
    """Insert a second auth user and return the ``User`` found by its email.

    A row with a freshly generated UUID and the email ``other@quivr.app`` is
    inserted into ``"auth"."users"`` with raw SQL. The ``User`` whose email is
    ``other@quivr.app`` is then selected through the same session and returned.
    """
    # NOTE(review): the select below presumably relies on the database
    # exposing auth.users rows through the User table — confirm.
    insert_auth_user = text(
        """
INSERT INTO "auth"."users" ("instance_id", "id", "aud", "role", "email", "encrypted_password", "email_confirmed_at", "invited_at", "confirmation_token", "confirmation_sent_at", "recovery_token", "recovery_sent_at", "email_change_token_new", "email_change", "email_change_sent_at", "last_sign_in_at", "raw_app_meta_data", "raw_user_meta_data", "is_super_admin", "created_at", "updated_at", "phone", "phone_confirmed_at", "phone_change", "phone_change_token", "phone_change_sent_at", "email_change_token_current", "email_change_confirm_status", "banned_until", "reauthentication_token", "reauthentication_sent_at", "is_sso_user", "deleted_at") VALUES
('00000000-0000-0000-0000-000000000000', :id , 'authenticated', 'authenticated', 'other@quivr.app', '$2a$10$vwKX0eMLlrOZvxQEA3Vl4e5V4/hOuxPjGYn9QK1yqeaZxa.42Uhze', '2024-01-22 22:27:00.166861+00', NULL, '', NULL, 'e91d41043ca2c83c3be5a6ee7a4abc8a4f4fb1afc0a8453c502af931', '2024-03-05 16:22:13.780421+00', '', '', NULL, '2024-03-30 23:21:12.077887+00', '{"provider": "email", "providers": ["email"]}', '{}', NULL, '2024-01-22 22:27:00.158026+00', '2024-04-01 17:40:15.332205+00', NULL, NULL, '', '', NULL, '', 0, NULL, '', NULL, false, NULL);
"""
    )
    # The only bound parameter is the new user's id.
    await session.execute(insert_auth_user, params={"id": uuid4()})

    # Read the user back by email; .one() raises unless exactly one row matches.
    by_email = select(User).where(User.email == "other@quivr.app")
    matches = await session.exec(by_email)
    return matches.one()
@pytest_asyncio.fixture()
async def test_data(session: AsyncSession) -> TestData:
user_1 = (
await session.exec(select(User).where(User.email == "admin@quivr.app"))
).one()
assert user_1.id
# Brain data
brain_1 = Brain(
name="test_brain",
@ -79,6 +104,7 @@ async def test_data(session: AsyncSession) -> TestData:
file_size=100,
file_sha1="test_sha1",
brains=[brain_1],
user_id=user_1.id,
)
knowledge_brain_2 = KnowledgeDB(
@ -90,6 +116,7 @@ async def test_data(session: AsyncSession) -> TestData:
file_size=100,
file_sha1="test_sha2",
brains=[],
user_id=user_1.id,
)
session.add(brain_1)
@ -183,7 +210,9 @@ async def test_remove_all_knowledges_from_brain(
@pytest.mark.asyncio
async def test_duplicate_sha1_knowledge(session: AsyncSession, test_data: TestData):
async def test_duplicate_sha1_knowledge_same_user(
session: AsyncSession, test_data: TestData
):
brain, knowledges = test_data
assert brain.brain_id
assert knowledges[0].id
@ -197,12 +226,38 @@ async def test_duplicate_sha1_knowledge(session: AsyncSession, test_data: TestDa
file_size=100,
file_sha1="test_sha1",
brains=[brain],
user_id=knowledges[0].user_id,
)
with pytest.raises(IntegrityError): # FIXME: Should raise IntegrityError
await repo.insert_knowledge(knowledge, brain.brain_id)
@pytest.mark.asyncio
async def test_duplicate_sha1_knowledge_diff_user(
    session: AsyncSession, test_data: TestData, other_user: User
):
    """Inserting an already-stored sha1 under a different user must succeed."""
    brain, knowledges = test_data
    seeded = knowledges[0]
    assert other_user.id
    assert brain.brain_id
    assert seeded.id

    repository = KnowledgeRepository(session)
    duplicate = KnowledgeDB(
        file_name="test_file_2",
        extension="txt",
        status="UPLOADED",
        source="test_source",
        source_link="test_source_link",
        file_size=100,
        # Same sha1 as the seeded knowledge...
        file_sha1=seeded.file_sha1,
        brains=[brain],
        # ...but owned by the user created by the other_user fixture.
        user_id=other_user.id,
    )

    inserted = await repository.insert_knowledge(duplicate, brain.brain_id)
    assert inserted
@pytest.mark.asyncio
async def test_add_knowledge_to_brain(session: AsyncSession, test_data: TestData):
brain, knowledges = test_data

View File

@ -5,21 +5,23 @@ from typing import Tuple
import pytest
import pytest_asyncio
import sqlalchemy
from quivr_api.modules.models.entity.model import Model
from quivr_api.modules.user.entity.user_identity import User
from sqlalchemy.ext.asyncio import create_async_engine
from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession
from quivr_api.modules.models.entity.model import Model
from quivr_api.modules.user.entity.user_identity import User
pg_database_base_url = "postgres:postgres@localhost:54322/postgres"
TestData = Tuple[Model, Model, User]
@pytest.fixture(scope="session")
def event_loop(request: pytest.FixtureRequest):
loop = asyncio.get_event_loop_policy().new_event_loop()
@pytest.fixture(scope="session", autouse=True)
def event_loop():
try:
loop = asyncio.get_running_loop()
except RuntimeError:
loop = asyncio.new_event_loop()
yield loop
loop.close()

View File

@ -182,8 +182,11 @@ def fetch_response():
@pytest.fixture(scope="session")
def event_loop(request: pytest.FixtureRequest):
loop = asyncio.get_event_loop_policy().new_event_loop()
def event_loop():
try:
loop = asyncio.get_running_loop()
except RuntimeError:
loop = asyncio.new_event_loop()
yield loop
loop.close()

View File

@ -196,6 +196,7 @@ class SyncUtils:
downloaded_file=downloaded_file,
source=source,
source_link=source_link,
user_id=current_user.user_id,
)
# Send file for processing

View File

@ -121,7 +121,9 @@ async def upload_file(
file_size=uploadFile.size,
file_sha1=None,
)
knowledge = await knowledge_service.insert_knowledge(knowledge_to_add) # type: ignore
knowledge = await knowledge_service.insert_knowledge(
user_id=current_user.id, knowledge_to_add=knowledge_to_add
) # type: ignore
celery.send_task(
"process_file_task",

View File

@ -87,7 +87,9 @@ async def crawl_endpoint(
source_link=crawl_website.url,
)
added_knowledge = await knowledge_service.insert_knowledge(knowledge_to_add)
added_knowledge = await knowledge_service.insert_knowledge(
knowledge_to_add=knowledge_to_add, user_id=current_user.id
)
logger.info(f"Knowledge {added_knowledge} added successfully")
celery.send_task(

View File

@ -64,7 +64,7 @@ def test_data(session: Session, embedder) -> TestData:
)
knowledge_1 = KnowledgeDB(
file_name="test_file_1",
mime_type="txt",
extension=".txt",
status="UPLOADED",
source="test_source",
source_link="test_source_link",

View File

@ -0,0 +1,10 @@
-- Add an owning user to every knowledge row.
-- The column is added without NOT NULL or a default, so rows that already
-- exist keep a NULL user_id.
-- NOTE(review): the KnowledgeDB model declares user_id with nullable=False
-- while this column stays nullable — confirm whether a backfill + SET NOT NULL
-- is expected in a follow-up migration.
alter table "public"."knowledge" add column "user_id" uuid;
-- Create the foreign key as NOT VALID first, then validate it in a separate
-- statement. Updates and deletes on the referenced users row cascade.
alter table "public"."knowledge" add constraint "public_knowledge_user_id_fkey" FOREIGN KEY (user_id) REFERENCES users(id) ON UPDATE CASCADE ON DELETE CASCADE not valid;
alter table "public"."knowledge" validate constraint "public_knowledge_user_id_fkey";
-- Replace the previous "unique_file_sha1" constraint (presumably on file_sha1
-- alone) with uniqueness on the (file_sha1, user_id) pair.
-- NOTE(review): a default UNIQUE constraint treats NULLs as distinct, so rows
-- with a NULL user_id would not be deduplicated by it — confirm this is acceptable.
ALTER TABLE "public"."knowledge"
DROP CONSTRAINT "unique_file_sha1";
ALTER TABLE "public"."knowledge"
ADD CONSTRAINT "unique_file_sha1_user_id" UNIQUE ("file_sha1", "user_id");