Mirror of https://github.com/QuivrHQ/quivr.git (synced 2024-12-15 01:21:48 +03:00)
fix: url knowledge multiple brain (#3145)
# Description

Find the knowledge path in storage.
Parent: 784c131441
Commit: 9a4ee1506b
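At a high level: a knowledge item can be linked to several brains, but its file is physically stored under only one brain's folder (the brain it was uploaded through), so path construction can no longer assume `{current_brain_id}/{file_name}`. The diff below adds a lookup that finds the knowledge row and probes each linked brain's folder for the file. A rough sketch of that resolution flow, reusing the repository method and the `check_file_exists(brain_id, file_name)` helper shown below (the helper function here is hypothetical, not the committed code):

async def resolve_storage_path(repository, file_name, brain_id, check_file_exists):
    # 1. Find the knowledge row by file name, scoped to a brain it is linked to.
    km = await repository.get_knowledge_by_file_name_brain_id(file_name, brain_id)
    # 2. Probe every linked brain's folder; the file is stored under exactly one.
    for brain in await km.awaitable_attrs.brains:
        if check_file_exists(brain.brain_id, file_name):
            return f"{brain.brain_id}/{file_name}"
    raise FileNotFoundError(f"No stored copy of {file_name} in any linked brain folder")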
@@ -95,13 +95,13 @@ async def generate_signed_url_endpoint(
     knowledge = await knowledge_service.get_knowledge(knowledge_id)

-    if len(knowledge.brains_ids) == 0:
+    if len(knowledge.brain_ids) == 0:
         raise HTTPException(
             status_code=HTTPStatus.NOT_FOUND,
             detail="knowledge not associated with brains yet.",
         )

-    brain_id = knowledge.brains_ids[0]
+    brain_id = knowledge.brain_ids[0]

     validate_brain_authorization(brain_id=brain_id, user_id=current_user.id)

@@ -100,6 +100,22 @@ class KnowledgeRepository(BaseRepository):
         return knowledge

+    async def get_knowledge_by_file_name_brain_id(
+        self, file_name: str, brain_id: UUID
+    ) -> KnowledgeDB:
+        query = (
+            select(KnowledgeDB)
+            .where(KnowledgeDB.file_name == file_name)
+            .where(KnowledgeDB.brains.any(brain_id=brain_id))  # type: ignore
+        )
+
+        result = await self.session.exec(query)
+        knowledge = result.first()
+        if not knowledge:
+            raise NoResultFound("Knowledge not found")
+
+        return knowledge
+
     async def get_knowledge_by_sha1(self, sha1: str) -> KnowledgeDB:
         query = select(KnowledgeDB).where(KnowledgeDB.file_sha1 == sha1)
         result = await self.session.exec(query)
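The `.where(KnowledgeDB.brains.any(brain_id=brain_id))` clause uses SQLAlchemy's relationship `any()` comparator, which renders as an EXISTS subquery over the knowledge-brain association. A self-contained sketch of the same pattern with throwaway models and SQLite (illustrative only; the project itself uses SQLModel with an async session):

from sqlalchemy import Column, ForeignKey, Integer, String, Table, create_engine, select
from sqlalchemy.orm import Session, declarative_base, relationship

Base = declarative_base()

# Toy association table linking knowledge rows to brain rows.
knowledge_brain = Table(
    "knowledge_brain",
    Base.metadata,
    Column("knowledge_id", ForeignKey("knowledge.id"), primary_key=True),
    Column("brain_id", ForeignKey("brain.brain_id"), primary_key=True),
)

class Brain(Base):
    __tablename__ = "brain"
    brain_id = Column(Integer, primary_key=True)

class Knowledge(Base):
    __tablename__ = "knowledge"
    id = Column(Integer, primary_key=True)
    file_name = Column(String)
    brains = relationship(Brain, secondary=knowledge_brain)

engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    b1, b2 = Brain(brain_id=1), Brain(brain_id=2)
    session.add(Knowledge(file_name="report.pdf", brains=[b1, b2]))
    session.commit()

    # relationship.any(brain_id=...) becomes an EXISTS against knowledge_brain.
    stmt = (
        select(Knowledge)
        .where(Knowledge.file_name == "report.pdf")
        .where(Knowledge.brains.any(brain_id=2))
    )
    print(session.scalars(stmt).first())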
@@ -18,6 +18,7 @@ from quivr_api.modules.sync.entity.sync_models import (
     DownloadedSyncFile,
     SyncFile,
 )
+from quivr_api.modules.upload.service.upload_file import check_file_exists

 logger = get_logger(__name__)

@@ -35,6 +36,24 @@ class KnowledgeService(BaseService[KnowledgeRepository]):
             km = await km.to_dto()
         return km

+    # TODO: this is temporary fix for getting knowledge path.
+    # KM storage path should be unrelated to brain
+    async def get_knowledge_storage_path(
+        self, file_name: str, brain_id: UUID
+    ) -> str | None:
+        try:
+            km = await self.repository.get_knowledge_by_file_name_brain_id(
+                file_name, brain_id
+            )
+            brains = await km.awaitable_attrs.brains
+            return next(
+                f"{b.brain_id}/{file_name}"
+                for b in brains
+                if check_file_exists(b.brain_id, file_name)
+            )
+        except NoResultFound:
+            raise FileNotFoundError(f"No knowledge for file_name: {file_name}")
+
     async def get_knowledge(self, knowledge_id: UUID) -> Knowledge:
         inserted_knowledge_db_instance = await self.repository.get_knowledge_by_id(
             knowledge_id
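One edge case worth noting in `get_knowledge_storage_path`: the `except NoResultFound` only covers the case where no knowledge row matches; if the row exists but none of the linked brain folders actually contains the file, the bare `next()` raises `StopIteration` instead. A hedged variant with a `next()` default makes that case explicit (hypothetical helper, not the committed code):

def first_existing_path(brains, file_name, check_file_exists):
    # Same search as the committed generator expression, but with an explicit
    # default so a missing file surfaces as FileNotFoundError, not StopIteration.
    path = next(
        (
            f"{b.brain_id}/{file_name}"
            for b in brains
            if check_file_exists(b.brain_id, file_name)
        ),
        None,
    )
    if path is None:
        raise FileNotFoundError(f"No stored copy of {file_name} in any linked brain")
    return path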
@@ -12,6 +12,7 @@ from quivr_api.modules.knowledge.entity.knowledge import KnowledgeDB
 from quivr_api.modules.knowledge.entity.knowledge_brain import KnowledgeBrain
 from quivr_api.modules.knowledge.repository.knowledges import KnowledgeRepository
 from quivr_api.modules.knowledge.service.knowledge_service import KnowledgeService
+from quivr_api.modules.upload.service.upload_file import upload_file_storage
 from quivr_api.modules.user.entity.user_identity import User
 from quivr_api.vector.entity.vector import Vector
 from sqlalchemy.exc import IntegrityError, NoResultFound
@@ -96,8 +97,8 @@ async def test_data(session: AsyncSession) -> TestData:
     )

     knowledge_brain_1 = KnowledgeDB(
-        file_name="test_file_1",
-        extension="txt",
+        file_name="test_file_1.txt",
+        extension=".txt",
         status="UPLOADED",
         source="test_source",
         source_link="test_source_link",
@@ -108,8 +109,8 @@ async def test_data(session: AsyncSession) -> TestData:
     )

     knowledge_brain_2 = KnowledgeDB(
-        file_name="test_file_2",
-        extension="txt",
+        file_name="test_file_2.txt",
+        extension=".txt",
         status="UPLOADED",
         source="test_source",
         source_link="test_source_link",
@@ -349,7 +350,7 @@ async def test_should_process_knowledge_link_brain(
     assert brain.brain_id
     prev = KnowledgeDB(
         file_name="prev",
-        extension="txt",
+        extension=".txt",
         status=KnowledgeStatus.UPLOADED,
         source="test_source",
         source_link="test_source_link",
@@ -465,3 +466,29 @@ async def test_should_process_knowledge_prev_error(
     assert new.id
     new = await service.repository.get_knowledge_by_id(new.id)
     assert new.file_sha1
+
+
+@pytest.mark.asyncio
+async def test_get_knowledge_storage_path(session: AsyncSession, test_data: TestData):
+    brain, [knowledge, _] = test_data
+    assert knowledge.file_name
+    repository = KnowledgeRepository(session)
+    service = KnowledgeService(repository)
+    brain_2 = Brain(
+        name="test_brain",
+        description="this is a test brain",
+        brain_type=BrainType.integration,
+    )
+    session.add(brain_2)
+    await session.commit()
+    await session.refresh(brain_2)
+    assert brain_2.brain_id
+    km_data = os.urandom(128)
+    km_path = f"{str(knowledge.brains[0].brain_id)}/{knowledge.file_name}"
+    await upload_file_storage(km_data, km_path)
+    # Link knowledge to two brains
+    await repository.link_to_brain(knowledge, brain_2.brain_id)
+    storage_path = await service.get_knowledge_storage_path(
+        knowledge.file_name, brain_2.brain_id
+    )
+    assert storage_path == km_path

@@ -263,10 +263,11 @@ class RAGService:
         streamed_chat_history.metadata["snippet_emoji"] = (
             self.brain.snippet_emoji if self.brain else None
         )
-        sources_urls = generate_source(
-            response.metadata.sources,
-            self.brain.brain_id,
-            (
+        sources_urls = await generate_source(
+            knowledge_service=self.knowledge_service,
+            brain_id=self.brain.brain_id,
+            source_documents=response.metadata.sources,
+            citations=(
                 streamed_chat_history.metadata["citations"]
                 if streamed_chat_history.metadata
                 else None
@@ -3,6 +3,7 @@ from typing import Any, List
 from uuid import UUID

 from quivr_api.modules.chat.dto.chats import Sources
+from quivr_api.modules.knowledge.service.knowledge_service import KnowledgeService
 from quivr_api.modules.upload.service.generate_file_signed_url import (
     generate_file_signed_url,
 )
@@ -11,9 +12,10 @@ logger = logging.getLogger(__name__)


 # TODO: REFACTOR THIS, it does call the DB , so maybe in a service
-def generate_source(
-    source_documents: List[Any] | None,
+async def generate_source(
+    knowledge_service: KnowledgeService,
     brain_id: UUID,
+    source_documents: List[Any] | None,
     citations: List[int] | None = None,
 ) -> List[Sources]:
     """
@@ -62,8 +64,11 @@ def generate_source(
         if is_url:
             source_url = doc.metadata["original_file_name"]
         else:
-            file_path = f"{brain_id}/{doc.metadata['file_name']}"
             # Check if the URL has already been generated
+            file_name = doc.metadata["file_name"]
+            file_path = await knowledge_service.get_knowledge_storage_path(
+                file_name=file_name, brain_id=brain_id
+            )
             if file_path in generated_urls:
                 source_url = generated_urls[file_path]
             else:

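For context on the change above: the old `file_path = f"{brain_id}/{doc.metadata['file_name']}"` assumed the source file sits under the folder of the brain currently answering the chat. Once a knowledge item is shared across brains, the object is stored only under the folder of the brain it was originally uploaded through, so that hard-coded path can point at a key that does not exist. A toy illustration of the failure mode (all names invented):

# Toy object store: the file was uploaded through brain A, so that is the only key.
storage = {"brain-A/report.pdf": b"..."}

current_brain_id = "brain-B"                 # the chat is running against brain B
old_path = f"{current_brain_id}/report.pdf"  # "brain-B/report.pdf"
assert old_path not in storage               # hard-coded path misses the stored object
# The lookup added in this commit returns "brain-A/report.pdf" instead,
# the folder where the file actually exists.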
@@ -77,6 +77,16 @@ from quivr_api.modules.user.entity.user_identity import User
 pg_database_base_url = "postgres:postgres@localhost:54322/postgres"


+@pytest.fixture(scope="module")
+def event_loop():
+    try:
+        loop = asyncio.get_running_loop()
+    except RuntimeError:
+        loop = asyncio.new_event_loop()
+    yield loop
+    loop.close()
+
+
 @pytest.fixture(scope="function")
 def page_response() -> dict[str, Any]:
     json_path = (
@@ -182,17 +192,7 @@ def fetch_response():


 @pytest.fixture(scope="session")
-def event_loop():
-    try:
-        loop = asyncio.get_running_loop()
-    except RuntimeError:
-        loop = asyncio.new_event_loop()
-    yield loop
-    loop.close()
-
-
-@pytest_asyncio.fixture(scope="session")
-async def sync_engine():
+def sync_engine():
     engine = create_engine(
         "postgresql://" + pg_database_base_url,
         echo=True if os.getenv("ORM_DEBUG") else False,
@@ -204,8 +204,8 @@ async def sync_engine():
     yield engine


-@pytest_asyncio.fixture()
-async def sync_session(sync_engine):
+@pytest.fixture
+def sync_session(sync_engine):
     with sync_engine.connect() as conn:
         conn.begin()
         conn.begin_nested()
@@ -273,7 +273,9 @@ def search_result():
     ]


-@pytest_asyncio.fixture(scope="session")
+@pytest_asyncio.fixture(
+    scope="session",
+)
 async def async_engine():
     engine = create_async_engine(
         "postgresql+asyncpg://" + pg_database_base_url,