mirror of
https://github.com/QuivrHQ/quivr.git
synced 2024-12-14 07:59:00 +03:00
fix: url knowledge multiple brain (#3145)
# Description - Find knowledge path in storage
This commit is contained in:
parent
784c131441
commit
9a4ee1506b
@ -95,13 +95,13 @@ async def generate_signed_url_endpoint(
|
||||
|
||||
knowledge = await knowledge_service.get_knowledge(knowledge_id)
|
||||
|
||||
if len(knowledge.brains_ids) == 0:
|
||||
if len(knowledge.brain_ids) == 0:
|
||||
raise HTTPException(
|
||||
status_code=HTTPStatus.NOT_FOUND,
|
||||
detail="knowledge not associated with brains yet.",
|
||||
)
|
||||
|
||||
brain_id = knowledge.brains_ids[0]
|
||||
brain_id = knowledge.brain_ids[0]
|
||||
|
||||
validate_brain_authorization(brain_id=brain_id, user_id=current_user.id)
|
||||
|
||||
|
@ -100,6 +100,22 @@ class KnowledgeRepository(BaseRepository):
|
||||
|
||||
return knowledge
|
||||
|
||||
async def get_knowledge_by_file_name_brain_id(
|
||||
self, file_name: str, brain_id: UUID
|
||||
) -> KnowledgeDB:
|
||||
query = (
|
||||
select(KnowledgeDB)
|
||||
.where(KnowledgeDB.file_name == file_name)
|
||||
.where(KnowledgeDB.brains.any(brain_id=brain_id)) # type: ignore
|
||||
)
|
||||
|
||||
result = await self.session.exec(query)
|
||||
knowledge = result.first()
|
||||
if not knowledge:
|
||||
raise NoResultFound("Knowledge not found")
|
||||
|
||||
return knowledge
|
||||
|
||||
async def get_knowledge_by_sha1(self, sha1: str) -> KnowledgeDB:
|
||||
query = select(KnowledgeDB).where(KnowledgeDB.file_sha1 == sha1)
|
||||
result = await self.session.exec(query)
|
||||
|
@ -18,6 +18,7 @@ from quivr_api.modules.sync.entity.sync_models import (
|
||||
DownloadedSyncFile,
|
||||
SyncFile,
|
||||
)
|
||||
from quivr_api.modules.upload.service.upload_file import check_file_exists
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@ -35,6 +36,24 @@ class KnowledgeService(BaseService[KnowledgeRepository]):
|
||||
km = await km.to_dto()
|
||||
return km
|
||||
|
||||
# TODO: this is temporary fix for getting knowledge path.
|
||||
# KM storage path should be unrelated to brain
|
||||
async def get_knowledge_storage_path(
|
||||
self, file_name: str, brain_id: UUID
|
||||
) -> str | None:
|
||||
try:
|
||||
km = await self.repository.get_knowledge_by_file_name_brain_id(
|
||||
file_name, brain_id
|
||||
)
|
||||
brains = await km.awaitable_attrs.brains
|
||||
return next(
|
||||
f"{b.brain_id}/{file_name}"
|
||||
for b in brains
|
||||
if check_file_exists(b.brain_id, file_name)
|
||||
)
|
||||
except NoResultFound:
|
||||
raise FileNotFoundError(f"No knowledge for file_name: {file_name}")
|
||||
|
||||
async def get_knowledge(self, knowledge_id: UUID) -> Knowledge:
|
||||
inserted_knowledge_db_instance = await self.repository.get_knowledge_by_id(
|
||||
knowledge_id
|
||||
|
@ -12,6 +12,7 @@ from quivr_api.modules.knowledge.entity.knowledge import KnowledgeDB
|
||||
from quivr_api.modules.knowledge.entity.knowledge_brain import KnowledgeBrain
|
||||
from quivr_api.modules.knowledge.repository.knowledges import KnowledgeRepository
|
||||
from quivr_api.modules.knowledge.service.knowledge_service import KnowledgeService
|
||||
from quivr_api.modules.upload.service.upload_file import upload_file_storage
|
||||
from quivr_api.modules.user.entity.user_identity import User
|
||||
from quivr_api.vector.entity.vector import Vector
|
||||
from sqlalchemy.exc import IntegrityError, NoResultFound
|
||||
@ -96,8 +97,8 @@ async def test_data(session: AsyncSession) -> TestData:
|
||||
)
|
||||
|
||||
knowledge_brain_1 = KnowledgeDB(
|
||||
file_name="test_file_1",
|
||||
extension="txt",
|
||||
file_name="test_file_1.txt",
|
||||
extension=".txt",
|
||||
status="UPLOADED",
|
||||
source="test_source",
|
||||
source_link="test_source_link",
|
||||
@ -108,8 +109,8 @@ async def test_data(session: AsyncSession) -> TestData:
|
||||
)
|
||||
|
||||
knowledge_brain_2 = KnowledgeDB(
|
||||
file_name="test_file_2",
|
||||
extension="txt",
|
||||
file_name="test_file_2.txt",
|
||||
extension=".txt",
|
||||
status="UPLOADED",
|
||||
source="test_source",
|
||||
source_link="test_source_link",
|
||||
@ -349,7 +350,7 @@ async def test_should_process_knowledge_link_brain(
|
||||
assert brain.brain_id
|
||||
prev = KnowledgeDB(
|
||||
file_name="prev",
|
||||
extension="txt",
|
||||
extension=".txt",
|
||||
status=KnowledgeStatus.UPLOADED,
|
||||
source="test_source",
|
||||
source_link="test_source_link",
|
||||
@ -465,3 +466,29 @@ async def test_should_process_knowledge_prev_error(
|
||||
assert new.id
|
||||
new = await service.repository.get_knowledge_by_id(new.id)
|
||||
assert new.file_sha1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_knowledge_storage_path(session: AsyncSession, test_data: TestData):
|
||||
brain, [knowledge, _] = test_data
|
||||
assert knowledge.file_name
|
||||
repository = KnowledgeRepository(session)
|
||||
service = KnowledgeService(repository)
|
||||
brain_2 = Brain(
|
||||
name="test_brain",
|
||||
description="this is a test brain",
|
||||
brain_type=BrainType.integration,
|
||||
)
|
||||
session.add(brain_2)
|
||||
await session.commit()
|
||||
await session.refresh(brain_2)
|
||||
assert brain_2.brain_id
|
||||
km_data = os.urandom(128)
|
||||
km_path = f"{str(knowledge.brains[0].brain_id)}/{knowledge.file_name}"
|
||||
await upload_file_storage(km_data, km_path)
|
||||
# Link knowledge to two brains
|
||||
await repository.link_to_brain(knowledge, brain_2.brain_id)
|
||||
storage_path = await service.get_knowledge_storage_path(
|
||||
knowledge.file_name, brain_2.brain_id
|
||||
)
|
||||
assert storage_path == km_path
|
||||
|
@ -263,10 +263,11 @@ class RAGService:
|
||||
streamed_chat_history.metadata["snippet_emoji"] = (
|
||||
self.brain.snippet_emoji if self.brain else None
|
||||
)
|
||||
sources_urls = generate_source(
|
||||
response.metadata.sources,
|
||||
self.brain.brain_id,
|
||||
(
|
||||
sources_urls = await generate_source(
|
||||
knowledge_service=self.knowledge_service,
|
||||
brain_id=self.brain.brain_id,
|
||||
source_documents=response.metadata.sources,
|
||||
citations=(
|
||||
streamed_chat_history.metadata["citations"]
|
||||
if streamed_chat_history.metadata
|
||||
else None
|
||||
|
@ -3,6 +3,7 @@ from typing import Any, List
|
||||
from uuid import UUID
|
||||
|
||||
from quivr_api.modules.chat.dto.chats import Sources
|
||||
from quivr_api.modules.knowledge.service.knowledge_service import KnowledgeService
|
||||
from quivr_api.modules.upload.service.generate_file_signed_url import (
|
||||
generate_file_signed_url,
|
||||
)
|
||||
@ -11,9 +12,10 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# TODO: REFACTOR THIS, it does call the DB , so maybe in a service
|
||||
def generate_source(
|
||||
source_documents: List[Any] | None,
|
||||
async def generate_source(
|
||||
knowledge_service: KnowledgeService,
|
||||
brain_id: UUID,
|
||||
source_documents: List[Any] | None,
|
||||
citations: List[int] | None = None,
|
||||
) -> List[Sources]:
|
||||
"""
|
||||
@ -62,8 +64,11 @@ def generate_source(
|
||||
if is_url:
|
||||
source_url = doc.metadata["original_file_name"]
|
||||
else:
|
||||
file_path = f"{brain_id}/{doc.metadata['file_name']}"
|
||||
# Check if the URL has already been generated
|
||||
file_name = doc.metadata["file_name"]
|
||||
file_path = await knowledge_service.get_knowledge_storage_path(
|
||||
file_name=file_name, brain_id=brain_id
|
||||
)
|
||||
if file_path in generated_urls:
|
||||
source_url = generated_urls[file_path]
|
||||
else:
|
||||
|
@ -77,6 +77,16 @@ from quivr_api.modules.user.entity.user_identity import User
|
||||
pg_database_base_url = "postgres:postgres@localhost:54322/postgres"
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def event_loop():
|
||||
try:
|
||||
loop = asyncio.get_running_loop()
|
||||
except RuntimeError:
|
||||
loop = asyncio.new_event_loop()
|
||||
yield loop
|
||||
loop.close()
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def page_response() -> dict[str, Any]:
|
||||
json_path = (
|
||||
@ -182,17 +192,7 @@ def fetch_response():
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def event_loop():
|
||||
try:
|
||||
loop = asyncio.get_running_loop()
|
||||
except RuntimeError:
|
||||
loop = asyncio.new_event_loop()
|
||||
yield loop
|
||||
loop.close()
|
||||
|
||||
|
||||
@pytest_asyncio.fixture(scope="session")
|
||||
async def sync_engine():
|
||||
def sync_engine():
|
||||
engine = create_engine(
|
||||
"postgresql://" + pg_database_base_url,
|
||||
echo=True if os.getenv("ORM_DEBUG") else False,
|
||||
@ -204,8 +204,8 @@ async def sync_engine():
|
||||
yield engine
|
||||
|
||||
|
||||
@pytest_asyncio.fixture()
|
||||
async def sync_session(sync_engine):
|
||||
@pytest.fixture
|
||||
def sync_session(sync_engine):
|
||||
with sync_engine.connect() as conn:
|
||||
conn.begin()
|
||||
conn.begin_nested()
|
||||
@ -273,7 +273,9 @@ def search_result():
|
||||
]
|
||||
|
||||
|
||||
@pytest_asyncio.fixture(scope="session")
|
||||
@pytest_asyncio.fixture(
|
||||
scope="session",
|
||||
)
|
||||
async def async_engine():
|
||||
engine = create_async_engine(
|
||||
"postgresql+asyncpg://" + pg_database_base_url,
|
||||
|
Loading…
Reference in New Issue
Block a user