fix: url knowledge multiple brain (#3145)

# Description

- Find knowledge path in storage
This commit is contained in:
AmineDiro 2024-09-05 10:15:51 +02:00 committed by GitHub
parent 784c131441
commit 9a4ee1506b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 98 additions and 28 deletions

View File

@ -95,13 +95,13 @@ async def generate_signed_url_endpoint(
knowledge = await knowledge_service.get_knowledge(knowledge_id) knowledge = await knowledge_service.get_knowledge(knowledge_id)
if len(knowledge.brains_ids) == 0: if len(knowledge.brain_ids) == 0:
raise HTTPException( raise HTTPException(
status_code=HTTPStatus.NOT_FOUND, status_code=HTTPStatus.NOT_FOUND,
detail="knowledge not associated with brains yet.", detail="knowledge not associated with brains yet.",
) )
brain_id = knowledge.brains_ids[0] brain_id = knowledge.brain_ids[0]
validate_brain_authorization(brain_id=brain_id, user_id=current_user.id) validate_brain_authorization(brain_id=brain_id, user_id=current_user.id)

View File

@ -100,6 +100,22 @@ class KnowledgeRepository(BaseRepository):
return knowledge return knowledge
async def get_knowledge_by_file_name_brain_id(
    self, file_name: str, brain_id: UUID
) -> KnowledgeDB:
    """Fetch the knowledge entry named ``file_name`` that is linked to ``brain_id``.

    Args:
        file_name: Value compared against ``KnowledgeDB.file_name``.
        brain_id: Id of a brain that must appear among the entry's ``brains``.

    Returns:
        The first row produced by the query.

    Raises:
        NoResultFound: If the query produces no row.
    """
    name_matches = KnowledgeDB.file_name == file_name
    linked_to_brain = KnowledgeDB.brains.any(brain_id=brain_id)  # type: ignore
    stmt = select(KnowledgeDB).where(name_matches).where(linked_to_brain)

    rows = await self.session.exec(stmt)
    found = rows.first()
    if found:
        return found
    raise NoResultFound("Knowledge not found")
async def get_knowledge_by_sha1(self, sha1: str) -> KnowledgeDB: async def get_knowledge_by_sha1(self, sha1: str) -> KnowledgeDB:
query = select(KnowledgeDB).where(KnowledgeDB.file_sha1 == sha1) query = select(KnowledgeDB).where(KnowledgeDB.file_sha1 == sha1)
result = await self.session.exec(query) result = await self.session.exec(query)

View File

@ -18,6 +18,7 @@ from quivr_api.modules.sync.entity.sync_models import (
DownloadedSyncFile, DownloadedSyncFile,
SyncFile, SyncFile,
) )
from quivr_api.modules.upload.service.upload_file import check_file_exists
logger = get_logger(__name__) logger = get_logger(__name__)
@ -35,6 +36,24 @@ class KnowledgeService(BaseService[KnowledgeRepository]):
km = await km.to_dto() km = await km.to_dto()
return km return km
# TODO: this is a temporary fix for getting the knowledge path.
# KM storage path should be unrelated to brain
async def get_knowledge_storage_path(
self, file_name: str, brain_id: UUID
) -> str | None:
try:
km = await self.repository.get_knowledge_by_file_name_brain_id(
file_name, brain_id
)
brains = await km.awaitable_attrs.brains
return next(
f"{b.brain_id}/{file_name}"
for b in brains
if check_file_exists(b.brain_id, file_name)
)
except NoResultFound:
raise FileNotFoundError(f"No knowledge for file_name: {file_name}")
async def get_knowledge(self, knowledge_id: UUID) -> Knowledge: async def get_knowledge(self, knowledge_id: UUID) -> Knowledge:
inserted_knowledge_db_instance = await self.repository.get_knowledge_by_id( inserted_knowledge_db_instance = await self.repository.get_knowledge_by_id(
knowledge_id knowledge_id

View File

@ -12,6 +12,7 @@ from quivr_api.modules.knowledge.entity.knowledge import KnowledgeDB
from quivr_api.modules.knowledge.entity.knowledge_brain import KnowledgeBrain from quivr_api.modules.knowledge.entity.knowledge_brain import KnowledgeBrain
from quivr_api.modules.knowledge.repository.knowledges import KnowledgeRepository from quivr_api.modules.knowledge.repository.knowledges import KnowledgeRepository
from quivr_api.modules.knowledge.service.knowledge_service import KnowledgeService from quivr_api.modules.knowledge.service.knowledge_service import KnowledgeService
from quivr_api.modules.upload.service.upload_file import upload_file_storage
from quivr_api.modules.user.entity.user_identity import User from quivr_api.modules.user.entity.user_identity import User
from quivr_api.vector.entity.vector import Vector from quivr_api.vector.entity.vector import Vector
from sqlalchemy.exc import IntegrityError, NoResultFound from sqlalchemy.exc import IntegrityError, NoResultFound
@ -96,8 +97,8 @@ async def test_data(session: AsyncSession) -> TestData:
) )
knowledge_brain_1 = KnowledgeDB( knowledge_brain_1 = KnowledgeDB(
file_name="test_file_1", file_name="test_file_1.txt",
extension="txt", extension=".txt",
status="UPLOADED", status="UPLOADED",
source="test_source", source="test_source",
source_link="test_source_link", source_link="test_source_link",
@ -108,8 +109,8 @@ async def test_data(session: AsyncSession) -> TestData:
) )
knowledge_brain_2 = KnowledgeDB( knowledge_brain_2 = KnowledgeDB(
file_name="test_file_2", file_name="test_file_2.txt",
extension="txt", extension=".txt",
status="UPLOADED", status="UPLOADED",
source="test_source", source="test_source",
source_link="test_source_link", source_link="test_source_link",
@ -349,7 +350,7 @@ async def test_should_process_knowledge_link_brain(
assert brain.brain_id assert brain.brain_id
prev = KnowledgeDB( prev = KnowledgeDB(
file_name="prev", file_name="prev",
extension="txt", extension=".txt",
status=KnowledgeStatus.UPLOADED, status=KnowledgeStatus.UPLOADED,
source="test_source", source="test_source",
source_link="test_source_link", source_link="test_source_link",
@ -465,3 +466,29 @@ async def test_should_process_knowledge_prev_error(
assert new.id assert new.id
new = await service.repository.get_knowledge_by_id(new.id) new = await service.repository.get_knowledge_by_id(new.id)
assert new.file_sha1 assert new.file_sha1
@pytest.mark.asyncio
async def test_get_knowledge_storage_path(session: AsyncSession, test_data: TestData):
    """The storage path resolves to the uploaded file when queried via a second brain.

    The file is uploaded under the path of the knowledge's first linked brain,
    the knowledge is then linked to a newly created brain, and the lookup is
    made through that new brain.
    """
    _, (knowledge, _) = test_data
    assert knowledge.file_name
    repo = KnowledgeRepository(session)
    knowledge_service = KnowledgeService(repo)

    # Create and persist a second brain for the knowledge to be linked to.
    second_brain = Brain(
        name="test_brain",
        description="this is a test brain",
        brain_type=BrainType.integration,
    )
    session.add(second_brain)
    await session.commit()
    await session.refresh(second_brain)
    assert second_brain.brain_id

    # Upload random content under the first linked brain's path.
    payload = os.urandom(128)
    first_brain_id = knowledge.brains[0].brain_id
    expected_path = f"{first_brain_id}/{knowledge.file_name}"
    await upload_file_storage(payload, expected_path)

    # Link knowledge to two brains
    await repo.link_to_brain(knowledge, second_brain.brain_id)
    resolved_path = await knowledge_service.get_knowledge_storage_path(
        knowledge.file_name, second_brain.brain_id
    )
    assert resolved_path == expected_path

View File

@ -263,10 +263,11 @@ class RAGService:
streamed_chat_history.metadata["snippet_emoji"] = ( streamed_chat_history.metadata["snippet_emoji"] = (
self.brain.snippet_emoji if self.brain else None self.brain.snippet_emoji if self.brain else None
) )
sources_urls = generate_source( sources_urls = await generate_source(
response.metadata.sources, knowledge_service=self.knowledge_service,
self.brain.brain_id, brain_id=self.brain.brain_id,
( source_documents=response.metadata.sources,
citations=(
streamed_chat_history.metadata["citations"] streamed_chat_history.metadata["citations"]
if streamed_chat_history.metadata if streamed_chat_history.metadata
else None else None

View File

@ -3,6 +3,7 @@ from typing import Any, List
from uuid import UUID from uuid import UUID
from quivr_api.modules.chat.dto.chats import Sources from quivr_api.modules.chat.dto.chats import Sources
from quivr_api.modules.knowledge.service.knowledge_service import KnowledgeService
from quivr_api.modules.upload.service.generate_file_signed_url import ( from quivr_api.modules.upload.service.generate_file_signed_url import (
generate_file_signed_url, generate_file_signed_url,
) )
@ -11,9 +12,10 @@ logger = logging.getLogger(__name__)
# TODO: REFACTOR THIS, it does call the DB , so maybe in a service # TODO: REFACTOR THIS, it does call the DB , so maybe in a service
def generate_source( async def generate_source(
source_documents: List[Any] | None, knowledge_service: KnowledgeService,
brain_id: UUID, brain_id: UUID,
source_documents: List[Any] | None,
citations: List[int] | None = None, citations: List[int] | None = None,
) -> List[Sources]: ) -> List[Sources]:
""" """
@ -62,8 +64,11 @@ def generate_source(
if is_url: if is_url:
source_url = doc.metadata["original_file_name"] source_url = doc.metadata["original_file_name"]
else: else:
file_path = f"{brain_id}/{doc.metadata['file_name']}"
# Check if the URL has already been generated # Check if the URL has already been generated
file_name = doc.metadata["file_name"]
file_path = await knowledge_service.get_knowledge_storage_path(
file_name=file_name, brain_id=brain_id
)
if file_path in generated_urls: if file_path in generated_urls:
source_url = generated_urls[file_path] source_url = generated_urls[file_path]
else: else:

View File

@ -77,6 +77,16 @@ from quivr_api.modules.user.entity.user_identity import User
pg_database_base_url = "postgres:postgres@localhost:54322/postgres" pg_database_base_url = "postgres:postgres@localhost:54322/postgres"
@pytest.fixture(scope="module")
def event_loop():
    """Yield an asyncio event loop and close it on teardown.

    Uses the currently running loop when there is one; otherwise creates a
    new loop.
    """
    try:
        active_loop = asyncio.get_running_loop()
    except RuntimeError:
        # get_running_loop() raises RuntimeError when no loop is running.
        active_loop = asyncio.new_event_loop()
    yield active_loop
    active_loop.close()
@pytest.fixture(scope="function") @pytest.fixture(scope="function")
def page_response() -> dict[str, Any]: def page_response() -> dict[str, Any]:
json_path = ( json_path = (
@ -182,17 +192,7 @@ def fetch_response():
@pytest.fixture(scope="session") @pytest.fixture(scope="session")
def event_loop(): def sync_engine():
try:
loop = asyncio.get_running_loop()
except RuntimeError:
loop = asyncio.new_event_loop()
yield loop
loop.close()
@pytest_asyncio.fixture(scope="session")
async def sync_engine():
engine = create_engine( engine = create_engine(
"postgresql://" + pg_database_base_url, "postgresql://" + pg_database_base_url,
echo=True if os.getenv("ORM_DEBUG") else False, echo=True if os.getenv("ORM_DEBUG") else False,
@ -204,8 +204,8 @@ async def sync_engine():
yield engine yield engine
@pytest_asyncio.fixture() @pytest.fixture
async def sync_session(sync_engine): def sync_session(sync_engine):
with sync_engine.connect() as conn: with sync_engine.connect() as conn:
conn.begin() conn.begin()
conn.begin_nested() conn.begin_nested()
@ -273,7 +273,9 @@ def search_result():
] ]
@pytest_asyncio.fixture(scope="session") @pytest_asyncio.fixture(
scope="session",
)
async def async_engine(): async def async_engine():
engine = create_async_engine( engine = create_async_engine(
"postgresql+asyncpg://" + pg_database_base_url, "postgresql+asyncpg://" + pg_database_base_url,