diff --git a/backend/api/quivr_api/modules/knowledge/controller/knowledge_routes.py b/backend/api/quivr_api/modules/knowledge/controller/knowledge_routes.py index 544530b19..3b8f1e89f 100644 --- a/backend/api/quivr_api/modules/knowledge/controller/knowledge_routes.py +++ b/backend/api/quivr_api/modules/knowledge/controller/knowledge_routes.py @@ -95,13 +95,13 @@ async def generate_signed_url_endpoint( knowledge = await knowledge_service.get_knowledge(knowledge_id) - if len(knowledge.brains_ids) == 0: + if len(knowledge.brain_ids) == 0: raise HTTPException( status_code=HTTPStatus.NOT_FOUND, detail="knowledge not associated with brains yet.", ) - brain_id = knowledge.brains_ids[0] + brain_id = knowledge.brain_ids[0] validate_brain_authorization(brain_id=brain_id, user_id=current_user.id) diff --git a/backend/api/quivr_api/modules/knowledge/repository/knowledges.py b/backend/api/quivr_api/modules/knowledge/repository/knowledges.py index 3ef42f9bc..436e24061 100644 --- a/backend/api/quivr_api/modules/knowledge/repository/knowledges.py +++ b/backend/api/quivr_api/modules/knowledge/repository/knowledges.py @@ -100,6 +100,22 @@ class KnowledgeRepository(BaseRepository): return knowledge + async def get_knowledge_by_file_name_brain_id( + self, file_name: str, brain_id: UUID + ) -> KnowledgeDB: + query = ( + select(KnowledgeDB) + .where(KnowledgeDB.file_name == file_name) + .where(KnowledgeDB.brains.any(brain_id=brain_id)) # type: ignore + ) + + result = await self.session.exec(query) + knowledge = result.first() + if not knowledge: + raise NoResultFound("Knowledge not found") + + return knowledge + async def get_knowledge_by_sha1(self, sha1: str) -> KnowledgeDB: query = select(KnowledgeDB).where(KnowledgeDB.file_sha1 == sha1) result = await self.session.exec(query) diff --git a/backend/api/quivr_api/modules/knowledge/service/knowledge_service.py b/backend/api/quivr_api/modules/knowledge/service/knowledge_service.py index 702f1f423..8ebcde479 100644 --- a/backend/api/quivr_api/modules/knowledge/service/knowledge_service.py +++ b/backend/api/quivr_api/modules/knowledge/service/knowledge_service.py @@ -18,6 +18,7 @@ from quivr_api.modules.sync.entity.sync_models import ( DownloadedSyncFile, SyncFile, ) +from quivr_api.modules.upload.service.upload_file import check_file_exists logger = get_logger(__name__) @@ -35,6 +36,24 @@ class KnowledgeService(BaseService[KnowledgeRepository]): km = await km.to_dto() return km + # TODO: this is temporary fix for getting knowledge path. + # KM storage path should be unrelated to brain + async def get_knowledge_storage_path( + self, file_name: str, brain_id: UUID + ) -> str | None: + try: + km = await self.repository.get_knowledge_by_file_name_brain_id( + file_name, brain_id + ) + brains = await km.awaitable_attrs.brains + return next( + f"{b.brain_id}/{file_name}" + for b in brains + if check_file_exists(b.brain_id, file_name) + ) + except NoResultFound: + raise FileNotFoundError(f"No knowledge for file_name: {file_name}") + async def get_knowledge(self, knowledge_id: UUID) -> Knowledge: inserted_knowledge_db_instance = await self.repository.get_knowledge_by_id( knowledge_id diff --git a/backend/api/quivr_api/modules/knowledge/tests/test_knowledges.py b/backend/api/quivr_api/modules/knowledge/tests/test_knowledges.py index afb7ba8d0..ad6b00a2c 100644 --- a/backend/api/quivr_api/modules/knowledge/tests/test_knowledges.py +++ b/backend/api/quivr_api/modules/knowledge/tests/test_knowledges.py @@ -12,6 +12,7 @@ from quivr_api.modules.knowledge.entity.knowledge import KnowledgeDB from quivr_api.modules.knowledge.entity.knowledge_brain import KnowledgeBrain from quivr_api.modules.knowledge.repository.knowledges import KnowledgeRepository from quivr_api.modules.knowledge.service.knowledge_service import KnowledgeService +from quivr_api.modules.upload.service.upload_file import upload_file_storage from quivr_api.modules.user.entity.user_identity import User from quivr_api.vector.entity.vector import Vector from sqlalchemy.exc import IntegrityError, NoResultFound @@ -96,8 +97,8 @@ async def test_data(session: AsyncSession) -> TestData: ) knowledge_brain_1 = KnowledgeDB( - file_name="test_file_1", - extension="txt", + file_name="test_file_1.txt", + extension=".txt", status="UPLOADED", source="test_source", source_link="test_source_link", @@ -108,8 +109,8 @@ async def test_data(session: AsyncSession) -> TestData: ) knowledge_brain_2 = KnowledgeDB( - file_name="test_file_2", - extension="txt", + file_name="test_file_2.txt", + extension=".txt", status="UPLOADED", source="test_source", source_link="test_source_link", @@ -349,7 +350,7 @@ async def test_should_process_knowledge_link_brain( assert brain.brain_id prev = KnowledgeDB( file_name="prev", - extension="txt", + extension=".txt", status=KnowledgeStatus.UPLOADED, source="test_source", source_link="test_source_link", @@ -465,3 +466,29 @@ async def test_should_process_knowledge_prev_error( assert new.id new = await service.repository.get_knowledge_by_id(new.id) assert new.file_sha1 + + +@pytest.mark.asyncio +async def test_get_knowledge_storage_path(session: AsyncSession, test_data: TestData): + brain, [knowledge, _] = test_data + assert knowledge.file_name + repository = KnowledgeRepository(session) + service = KnowledgeService(repository) + brain_2 = Brain( + name="test_brain", + description="this is a test brain", + brain_type=BrainType.integration, + ) + session.add(brain_2) + await session.commit() + await session.refresh(brain_2) + assert brain_2.brain_id + km_data = os.urandom(128) + km_path = f"{str(knowledge.brains[0].brain_id)}/{knowledge.file_name}" + await upload_file_storage(km_data, km_path) + # Link knowledge to two brains + await repository.link_to_brain(knowledge, brain_2.brain_id) + storage_path = await service.get_knowledge_storage_path( + knowledge.file_name, brain_2.brain_id + ) + assert storage_path == km_path diff --git a/backend/api/quivr_api/modules/rag_service/rag_service.py b/backend/api/quivr_api/modules/rag_service/rag_service.py index 56424ca39..a46faf253 100644 --- a/backend/api/quivr_api/modules/rag_service/rag_service.py +++ b/backend/api/quivr_api/modules/rag_service/rag_service.py @@ -263,10 +263,11 @@ class RAGService: streamed_chat_history.metadata["snippet_emoji"] = ( self.brain.snippet_emoji if self.brain else None ) - sources_urls = generate_source( - response.metadata.sources, - self.brain.brain_id, - ( + sources_urls = await generate_source( + knowledge_service=self.knowledge_service, + brain_id=self.brain.brain_id, + source_documents=response.metadata.sources, + citations=( streamed_chat_history.metadata["citations"] if streamed_chat_history.metadata else None diff --git a/backend/api/quivr_api/modules/rag_service/utils.py b/backend/api/quivr_api/modules/rag_service/utils.py index d151c9ff1..3b64bc7c9 100644 --- a/backend/api/quivr_api/modules/rag_service/utils.py +++ b/backend/api/quivr_api/modules/rag_service/utils.py @@ -3,6 +3,7 @@ from typing import Any, List from uuid import UUID from quivr_api.modules.chat.dto.chats import Sources +from quivr_api.modules.knowledge.service.knowledge_service import KnowledgeService from quivr_api.modules.upload.service.generate_file_signed_url import ( generate_file_signed_url, ) @@ -11,9 +12,10 @@ logger = logging.getLogger(__name__) # TODO: REFACTOR THIS, it does call the DB , so maybe in a service -def generate_source( - source_documents: List[Any] | None, +async def generate_source( + knowledge_service: KnowledgeService, brain_id: UUID, + source_documents: List[Any] | None, citations: List[int] | None = None, ) -> List[Sources]: """ @@ -62,8 +64,11 @@ def generate_source( if is_url: source_url = doc.metadata["original_file_name"] else: - file_path = f"{brain_id}/{doc.metadata['file_name']}" # Check if the URL has already been generated + file_name = doc.metadata["file_name"] + file_path = await knowledge_service.get_knowledge_storage_path( + file_name=file_name, brain_id=brain_id + ) if file_path in generated_urls: source_url = generated_urls[file_path] else: diff --git a/backend/api/quivr_api/modules/sync/tests/conftest.py b/backend/api/quivr_api/modules/sync/tests/conftest.py index 2e84900f4..4775d53ed 100644 --- a/backend/api/quivr_api/modules/sync/tests/conftest.py +++ b/backend/api/quivr_api/modules/sync/tests/conftest.py @@ -77,6 +77,16 @@ from quivr_api.modules.user.entity.user_identity import User pg_database_base_url = "postgres:postgres@localhost:54322/postgres" +@pytest.fixture(scope="module") +def event_loop(): + try: + loop = asyncio.get_running_loop() + except RuntimeError: + loop = asyncio.new_event_loop() + yield loop + loop.close() + + @pytest.fixture(scope="function") def page_response() -> dict[str, Any]: json_path = ( @@ -182,17 +192,7 @@ def fetch_response(): @pytest.fixture(scope="session") -def event_loop(): - try: - loop = asyncio.get_running_loop() - except RuntimeError: - loop = asyncio.new_event_loop() - yield loop - loop.close() - - -@pytest_asyncio.fixture(scope="session") -async def sync_engine(): +def sync_engine(): engine = create_engine( "postgresql://" + pg_database_base_url, echo=True if os.getenv("ORM_DEBUG") else False, @@ -204,8 +204,8 @@ async def sync_engine(): yield engine -@pytest_asyncio.fixture() -async def sync_session(sync_engine): +@pytest.fixture +def sync_session(sync_engine): with sync_engine.connect() as conn: conn.begin() conn.begin_nested() @@ -273,7 +273,9 @@ def search_result(): ] -@pytest_asyncio.fixture(scope="session") +@pytest_asyncio.fixture( + scope="session", +) async def async_engine(): engine = create_async_engine( "postgresql+asyncpg://" + pg_database_base_url,