mirror of
https://github.com/QuivrHQ/quivr.git
synced 2024-11-10 04:55:33 +03:00
refactor: Prompt module (#1688)
# Description Prompt module with Service ## Checklist before requesting a review Please delete options that are not relevant. - [ ] My code follows the style guidelines of this project - [ ] I have performed a self-review of my code - [ ] I have commented hard-to-understand areas - [ ] I have ideally added tests that prove my fix is effective or that my feature works - [ ] New and existing unit tests pass locally with my changes - [ ] Any dependent changes have been merged ## Screenshots (if appropriate):
This commit is contained in:
parent
18421ca272
commit
1bf67e3640
@ -13,7 +13,7 @@ from llm.utils.get_prompt_to_use_id import get_prompt_to_use_id
|
||||
from logger import get_logger
|
||||
from models.chats import ChatQuestion
|
||||
from models.databases.supabase.chats import CreateChatHistory
|
||||
from models.prompt import Prompt
|
||||
from modules.prompt.entity.prompt import Prompt
|
||||
from pydantic import BaseModel
|
||||
from repository.chat import (
|
||||
GetChatHistoryOutput,
|
||||
|
@ -2,8 +2,10 @@ from typing import Optional
|
||||
from uuid import UUID
|
||||
|
||||
from llm.utils.get_prompt_to_use_id import get_prompt_to_use_id
|
||||
from models.prompt import Prompt
|
||||
from repository.prompt import get_prompt_by_id
|
||||
from modules.prompt.entity.prompt import Prompt
|
||||
from modules.prompt.service import PromptService
|
||||
|
||||
promptService = PromptService()
|
||||
|
||||
|
||||
def get_prompt_to_use(
|
||||
@ -13,4 +15,4 @@ def get_prompt_to_use(
|
||||
if prompt_to_use_id is None:
|
||||
return None
|
||||
|
||||
return get_prompt_by_id(prompt_to_use_id)
|
||||
return promptService.get_prompt_by_id(prompt_to_use_id)
|
||||
|
@ -13,6 +13,7 @@ from fastapi import FastAPI, HTTPException
|
||||
from fastapi.responses import JSONResponse
|
||||
from logger import get_logger
|
||||
from middlewares.cors import add_cors_middleware
|
||||
from modules.prompt.controller.prompt_routes import prompt_router
|
||||
from modules.user.controller.user_controller import user_router
|
||||
from routes.api_key_routes import api_key_router
|
||||
from routes.brain_routes import brain_router
|
||||
@ -24,7 +25,6 @@ from routes.knowledge_routes import knowledge_router
|
||||
from routes.misc_routes import misc_router
|
||||
from routes.notification_routes import notification_router
|
||||
from routes.onboarding_routes import onboarding_router
|
||||
from routes.prompt_routes import prompt_router
|
||||
from routes.subscription_routes import subscription_router
|
||||
from routes.upload_routes import upload_router
|
||||
|
||||
|
@ -4,11 +4,9 @@ from .brains_subscription_invitations import BrainSubscription
|
||||
from .chat import Chat, ChatHistory
|
||||
from .chats import ChatMessage, ChatQuestion
|
||||
from .files import File
|
||||
from .prompt import Prompt, PromptStatusEnum
|
||||
from .settings import (BrainRateLimiting, BrainSettings, ContactsSettings,
|
||||
ResendSettings, get_embeddings,
|
||||
get_documents_vector_store, get_embeddings,
|
||||
get_supabase_client, get_supabase_db)
|
||||
ResendSettings, get_documents_vector_store,
|
||||
get_embeddings, get_supabase_client, get_supabase_db)
|
||||
from .user_usage import UserUsage
|
||||
|
||||
# TODO uncomment the below import when start using SQLalchemy
|
||||
|
@ -204,26 +204,6 @@ class Repository(ABC):
|
||||
def get_vectors_by_file_sha1(self, file_sha1):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def create_prompt(self, new_prompt):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_prompt_by_id(self, prompt_id: UUID):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def delete_prompt_by_id(self, prompt_id: UUID):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def update_prompt_by_id(self, prompt_id: UUID, updates):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_public_prompts(self):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def add_notification(self, notification):
|
||||
pass
|
||||
|
@ -1,13 +1,12 @@
|
||||
from models.databases.supabase.api_brain_definition import ApiBrainDefinitions
|
||||
from models.databases.supabase.api_key_handler import ApiKeyHandler
|
||||
from models.databases.supabase.brains import Brain
|
||||
from models.databases.supabase.brains_subscription_invitations import BrainSubscription
|
||||
from models.databases.supabase.brains_subscription_invitations import \
|
||||
BrainSubscription
|
||||
from models.databases.supabase.chats import Chats
|
||||
from models.databases.supabase.files import File
|
||||
from models.databases.supabase.knowledge import Knowledges
|
||||
from models.databases.supabase.notifications import Notifications
|
||||
from models.databases.supabase.onboarding import Onboarding
|
||||
from models.databases.supabase.prompts import Prompts
|
||||
from models.databases.supabase.user_usage import UserUsage
|
||||
from models.databases.supabase.vectors import Vector
|
||||
|
||||
from models.databases.supabase.api_brain_definition import ApiBrainDefinitions
|
||||
|
@ -9,7 +9,6 @@ from models.databases.supabase import (
|
||||
Knowledges,
|
||||
Notifications,
|
||||
Onboarding,
|
||||
Prompts,
|
||||
UserUsage,
|
||||
Vector,
|
||||
)
|
||||
@ -26,7 +25,6 @@ class SupabaseDB(
|
||||
Chats,
|
||||
Vector,
|
||||
Onboarding,
|
||||
Prompts,
|
||||
Notifications,
|
||||
Knowledges,
|
||||
ApiBrainDefinitions,
|
||||
@ -40,7 +38,6 @@ class SupabaseDB(
|
||||
ApiKeyHandler.__init__(self, supabase_client)
|
||||
Chats.__init__(self, supabase_client)
|
||||
Vector.__init__(self, supabase_client)
|
||||
Prompts.__init__(self, supabase_client)
|
||||
Notifications.__init__(self, supabase_client)
|
||||
Knowledges.__init__(self, supabase_client)
|
||||
Onboarding.__init__(self, supabase_client)
|
||||
|
@ -1,16 +0,0 @@
|
||||
from enum import Enum
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class PromptStatusEnum(str, Enum):
|
||||
private = "private"
|
||||
public = "public"
|
||||
|
||||
|
||||
class Prompt(BaseModel):
|
||||
title: str
|
||||
content: str
|
||||
status: PromptStatusEnum = PromptStatusEnum.private
|
||||
id: UUID
|
0
backend/modules/prompt/__init__.py
Normal file
0
backend/modules/prompt/__init__.py
Normal file
0
backend/modules/prompt/controller/__init__.py
Normal file
0
backend/modules/prompt/controller/__init__.py
Normal file
@ -2,28 +2,24 @@ from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends
|
||||
from middlewares.auth import AuthBearer
|
||||
from models import Prompt
|
||||
from models.databases.supabase.prompts import (
|
||||
from modules.prompt.entity.prompt import (
|
||||
CreatePromptProperties,
|
||||
Prompt,
|
||||
PromptUpdatableProperties,
|
||||
)
|
||||
from repository.prompt import (
|
||||
create_prompt,
|
||||
get_prompt_by_id,
|
||||
get_public_prompts,
|
||||
update_prompt_by_id,
|
||||
)
|
||||
from modules.prompt.service import PromptService
|
||||
|
||||
prompt_router = APIRouter()
|
||||
|
||||
promptService = PromptService()
|
||||
|
||||
|
||||
@prompt_router.get("/prompts", dependencies=[Depends(AuthBearer())], tags=["Prompt"])
|
||||
async def get_prompts() -> list[Prompt]:
|
||||
"""
|
||||
Retrieve all public prompt
|
||||
"""
|
||||
|
||||
return get_public_prompts()
|
||||
return promptService.get_public_prompts()
|
||||
|
||||
|
||||
@prompt_router.get(
|
||||
@ -34,7 +30,7 @@ async def get_prompt(prompt_id: UUID) -> Prompt | None:
|
||||
Retrieve a prompt by its id
|
||||
"""
|
||||
|
||||
return get_prompt_by_id(prompt_id)
|
||||
return promptService.get_prompt_by_id(prompt_id)
|
||||
|
||||
|
||||
@prompt_router.put(
|
||||
@ -47,7 +43,7 @@ async def update_prompt(
|
||||
Update a prompt by its id
|
||||
"""
|
||||
|
||||
return update_prompt_by_id(prompt_id, prompt)
|
||||
return promptService.update_prompt_by_id(prompt_id, prompt)
|
||||
|
||||
|
||||
@prompt_router.post("/prompts", dependencies=[Depends(AuthBearer())], tags=["Prompt"])
|
||||
@ -56,4 +52,4 @@ async def create_prompt_route(prompt: CreatePromptProperties) -> Prompt | None:
|
||||
Create a prompt by its id
|
||||
"""
|
||||
|
||||
return create_prompt(prompt)
|
||||
return promptService.create_prompt(prompt)
|
1
backend/modules/prompt/entity/__init__.py
Normal file
1
backend/modules/prompt/entity/__init__.py
Normal file
@ -0,0 +1 @@
|
||||
from .prompt import Prompt, PromptStatusEnum, CreatePromptProperties, PromptUpdatableProperties, DeletePromptResponse
|
40
backend/modules/prompt/entity/prompt.py
Normal file
40
backend/modules/prompt/entity/prompt.py
Normal file
@ -0,0 +1,40 @@
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class PromptStatusEnum(str, Enum):
|
||||
private = "private"
|
||||
public = "public"
|
||||
|
||||
|
||||
class Prompt(BaseModel):
|
||||
title: str
|
||||
content: str
|
||||
status: PromptStatusEnum = PromptStatusEnum.private
|
||||
id: UUID
|
||||
|
||||
|
||||
class CreatePromptProperties(BaseModel):
|
||||
"""Properties that can be received on prompt creation"""
|
||||
|
||||
title: str
|
||||
content: str
|
||||
status: PromptStatusEnum = PromptStatusEnum.private
|
||||
|
||||
|
||||
class PromptUpdatableProperties(BaseModel):
|
||||
"""Properties that can be received on prompt update"""
|
||||
|
||||
title: Optional[str]
|
||||
content: Optional[str]
|
||||
status: Optional[PromptStatusEnum]
|
||||
|
||||
|
||||
class DeletePromptResponse(BaseModel):
|
||||
"""Response when deleting a prompt"""
|
||||
|
||||
status: str = "delete"
|
||||
prompt_id: UUID
|
0
backend/modules/prompt/repository/__init__.py
Normal file
0
backend/modules/prompt/repository/__init__.py
Normal file
@ -1,40 +1,16 @@
|
||||
from typing import Optional
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import HTTPException
|
||||
from models.databases.repository import Repository
|
||||
from models.prompt import Prompt, PromptStatusEnum
|
||||
from pydantic import BaseModel
|
||||
from modules.prompt.entity.prompt import Prompt
|
||||
from modules.prompt.repository.prompts_interface import (
|
||||
DeletePromptResponse,
|
||||
PromptsInterface,
|
||||
)
|
||||
|
||||
|
||||
class CreatePromptProperties(BaseModel):
|
||||
"""Properties that can be received on prompt creation"""
|
||||
|
||||
title: str
|
||||
content: str
|
||||
status: PromptStatusEnum = PromptStatusEnum.private
|
||||
|
||||
|
||||
class PromptUpdatableProperties(BaseModel):
|
||||
"""Properties that can be received on prompt update"""
|
||||
|
||||
title: Optional[str]
|
||||
content: Optional[str]
|
||||
status: Optional[PromptStatusEnum]
|
||||
|
||||
|
||||
class DeletePromptResponse(BaseModel):
|
||||
"""Response when deleting a prompt"""
|
||||
|
||||
status: str = "delete"
|
||||
prompt_id: UUID
|
||||
|
||||
|
||||
class Prompts(Repository):
|
||||
class Prompts(PromptsInterface):
|
||||
def __init__(self, supabase_client):
|
||||
self.db = supabase_client
|
||||
|
||||
def create_prompt(self, prompt: CreatePromptProperties) -> Prompt:
|
||||
def create_prompt(self, prompt):
|
||||
"""
|
||||
Create a prompt
|
||||
"""
|
||||
@ -43,7 +19,7 @@ class Prompts(Repository):
|
||||
|
||||
return Prompt(**response[0])
|
||||
|
||||
def delete_prompt_by_id(self, prompt_id: UUID) -> DeletePromptResponse:
|
||||
def delete_prompt_by_id(self, prompt_id):
|
||||
"""
|
||||
Delete a prompt by id
|
||||
Args:
|
||||
@ -65,7 +41,7 @@ class Prompts(Repository):
|
||||
|
||||
return DeletePromptResponse(status="deleted", prompt_id=prompt_id)
|
||||
|
||||
def get_prompt_by_id(self, prompt_id: UUID) -> Prompt | None:
|
||||
def get_prompt_by_id(self, prompt_id):
|
||||
"""
|
||||
Get a prompt by its id
|
||||
|
||||
@ -84,7 +60,7 @@ class Prompts(Repository):
|
||||
return None
|
||||
return Prompt(**response[0])
|
||||
|
||||
def get_public_prompts(self) -> list[Prompt]:
|
||||
def get_public_prompts(self):
|
||||
"""
|
||||
List all public prompts
|
||||
"""
|
||||
@ -96,9 +72,7 @@ class Prompts(Repository):
|
||||
.execute()
|
||||
).data
|
||||
|
||||
def update_prompt_by_id(
|
||||
self, prompt_id: UUID, prompt: PromptUpdatableProperties
|
||||
) -> Prompt:
|
||||
def update_prompt_by_id(self, prompt_id, prompt):
|
||||
"""Update a prompt by id"""
|
||||
|
||||
response = (
|
57
backend/modules/prompt/repository/prompts_interface.py
Normal file
57
backend/modules/prompt/repository/prompts_interface.py
Normal file
@ -0,0 +1,57 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from uuid import UUID
|
||||
|
||||
from modules.prompt.entity import (
|
||||
CreatePromptProperties,
|
||||
DeletePromptResponse,
|
||||
Prompt,
|
||||
PromptUpdatableProperties,
|
||||
)
|
||||
|
||||
|
||||
class PromptsInterface(ABC):
|
||||
@abstractmethod
|
||||
def create_prompt(self, prompt: CreatePromptProperties) -> Prompt:
|
||||
"""
|
||||
Create a prompt
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def delete_prompt_by_id(self, prompt_id: UUID) -> DeletePromptResponse:
|
||||
"""
|
||||
Delete a prompt by id
|
||||
Args:
|
||||
prompt_id (UUID): The id of the prompt
|
||||
|
||||
Returns:
|
||||
A dictionary containing the status of the delete and prompt_id of the deleted prompt
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_prompt_by_id(self, prompt_id: UUID) -> Prompt | None:
|
||||
"""
|
||||
Get a prompt by its id
|
||||
|
||||
Args:
|
||||
prompt_id (UUID): The id of the prompt
|
||||
|
||||
Returns:
|
||||
Prompt: The prompt
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_public_prompts(self) -> list[Prompt]:
|
||||
"""
|
||||
List all public prompts
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def update_prompt_by_id(
|
||||
self, prompt_id: UUID, prompt: PromptUpdatableProperties
|
||||
) -> Prompt:
|
||||
"""Update a prompt by id"""
|
||||
pass
|
1
backend/modules/prompt/service/__init__.py
Normal file
1
backend/modules/prompt/service/__init__.py
Normal file
@ -0,0 +1 @@
|
||||
from .prompt_service import PromptService
|
59
backend/modules/prompt/service/prompt_service.py
Normal file
59
backend/modules/prompt/service/prompt_service.py
Normal file
@ -0,0 +1,59 @@
|
||||
from typing import List
|
||||
from uuid import UUID
|
||||
|
||||
from models.settings import get_supabase_client
|
||||
from modules.prompt.entity.prompt import (
|
||||
CreatePromptProperties,
|
||||
DeletePromptResponse,
|
||||
Prompt,
|
||||
PromptUpdatableProperties,
|
||||
)
|
||||
from modules.prompt.repository.prompts import Prompts
|
||||
|
||||
|
||||
class PromptService:
|
||||
repository: Prompts
|
||||
|
||||
def __init__(self):
|
||||
supabase_client = get_supabase_client()
|
||||
self.repository = Prompts(supabase_client)
|
||||
|
||||
def create_prompt(self, prompt: CreatePromptProperties) -> Prompt:
|
||||
return self.repository.create_prompt(prompt)
|
||||
|
||||
def delete_prompt_by_id(self, prompt_id: UUID) -> DeletePromptResponse:
|
||||
"""
|
||||
Delete a prompt by id
|
||||
Args:
|
||||
prompt_id (UUID): The id of the prompt
|
||||
|
||||
Returns:
|
||||
Prompt: The prompt
|
||||
"""
|
||||
return self.repository.delete_prompt_by_id(prompt_id)
|
||||
|
||||
def get_prompt_by_id(self, prompt_id: UUID) -> Prompt | None:
|
||||
"""
|
||||
Get a prompt by its id
|
||||
|
||||
Args:
|
||||
prompt_id (UUID): The id of the prompt
|
||||
|
||||
Returns:
|
||||
Prompt: The prompt
|
||||
"""
|
||||
return self.repository.get_prompt_by_id(prompt_id)
|
||||
|
||||
def get_public_prompts(self) -> List[Prompt]:
|
||||
"""
|
||||
List all public prompts
|
||||
"""
|
||||
|
||||
return self.repository.get_public_prompts()
|
||||
|
||||
def update_prompt_by_id(
|
||||
self, prompt_id: UUID, prompt: PromptUpdatableProperties
|
||||
) -> Prompt:
|
||||
"""Update a prompt by id"""
|
||||
|
||||
return self.repository.update_prompt_by_id(prompt_id, prompt)
|
@ -2,10 +2,11 @@ from typing import List, Optional
|
||||
from uuid import UUID
|
||||
|
||||
from models import ChatHistory, get_supabase_db
|
||||
from modules.prompt.service.prompt_service import PromptService
|
||||
from pydantic import BaseModel
|
||||
|
||||
from repository.brain import get_brain_by_id
|
||||
from repository.prompt import get_prompt_by_id
|
||||
|
||||
prompt_service = PromptService()
|
||||
|
||||
|
||||
class GetChatHistoryOutput(BaseModel):
|
||||
@ -40,7 +41,7 @@ def get_chat_history(chat_id: str) -> List[GetChatHistoryOutput]:
|
||||
|
||||
prompt = None
|
||||
if message.prompt_id:
|
||||
prompt = get_prompt_by_id(message.prompt_id)
|
||||
prompt = prompt_service.get_prompt_by_id(message.prompt_id)
|
||||
|
||||
enriched_history.append(
|
||||
GetChatHistoryOutput(
|
||||
|
@ -1,5 +0,0 @@
|
||||
from .create_prompt import create_prompt
|
||||
from .get_prompt_by_id import get_prompt_by_id
|
||||
from .get_public_prompts import get_public_prompts
|
||||
from .update_prompt_by_id import update_prompt_by_id
|
||||
from .delete_prompt_py_id import delete_prompt_by_id
|
@ -1,8 +0,0 @@
|
||||
from models.databases.supabase.prompts import CreatePromptProperties
|
||||
from models import Prompt, get_supabase_db
|
||||
|
||||
|
||||
def create_prompt(prompt: CreatePromptProperties) -> Prompt:
|
||||
supabase_db = get_supabase_db()
|
||||
|
||||
return supabase_db.create_prompt(prompt)
|
@ -1,17 +0,0 @@
|
||||
from uuid import UUID
|
||||
|
||||
from models.databases.supabase.prompts import DeletePromptResponse
|
||||
from models import get_supabase_db
|
||||
|
||||
|
||||
def delete_prompt_by_id(prompt_id: UUID) -> DeletePromptResponse:
|
||||
"""
|
||||
Delete a prompt by id
|
||||
Args:
|
||||
prompt_id (UUID): The id of the prompt
|
||||
|
||||
Returns:
|
||||
Prompt: The prompt
|
||||
"""
|
||||
supabase_db = get_supabase_db()
|
||||
return supabase_db.delete_prompt_by_id(prompt_id)
|
@ -1,17 +0,0 @@
|
||||
from uuid import UUID
|
||||
|
||||
from models import Prompt, get_supabase_db
|
||||
|
||||
|
||||
def get_prompt_by_id(prompt_id: UUID) -> Prompt | None:
|
||||
"""
|
||||
Get a prompt by its id
|
||||
|
||||
Args:
|
||||
prompt_id (UUID): The id of the prompt
|
||||
|
||||
Returns:
|
||||
Prompt: The prompt
|
||||
"""
|
||||
supabase_db = get_supabase_db()
|
||||
return supabase_db.get_prompt_by_id(prompt_id)
|
@ -1,9 +0,0 @@
|
||||
from models import Prompt, get_supabase_db
|
||||
|
||||
|
||||
def get_public_prompts() -> list[Prompt]:
|
||||
"""
|
||||
List all public prompts
|
||||
"""
|
||||
supabase_db = get_supabase_db()
|
||||
return supabase_db.get_public_prompts()
|
@ -1,11 +0,0 @@
|
||||
from uuid import UUID
|
||||
|
||||
from models.databases.supabase.prompts import PromptUpdatableProperties
|
||||
from models import Prompt, get_supabase_db
|
||||
|
||||
|
||||
def update_prompt_by_id(prompt_id: UUID, prompt: PromptUpdatableProperties) -> Prompt:
|
||||
"""Update a prompt by id"""
|
||||
supabase_db = get_supabase_db()
|
||||
|
||||
return supabase_db.update_prompt_by_id(prompt_id, prompt)
|
@ -11,6 +11,7 @@ from models.databases.supabase.brains import (
|
||||
BrainUpdatableProperties,
|
||||
CreateBrainProperties,
|
||||
)
|
||||
from modules.prompt.service.prompt_service import PromptService
|
||||
from modules.user.entity.user_identity import UserIdentity
|
||||
from repository.brain import (
|
||||
create_brain,
|
||||
@ -26,17 +27,15 @@ from repository.brain import (
|
||||
update_brain_by_id,
|
||||
)
|
||||
from repository.brain.get_brain_for_user import get_brain_for_user
|
||||
from repository.external_api_secret.update_secret_value import (
|
||||
update_secret_value,
|
||||
)
|
||||
from repository.prompt import delete_prompt_by_id, get_prompt_by_id
|
||||
|
||||
from repository.external_api_secret.update_secret_value import update_secret_value
|
||||
from routes.authorizations.brain_authorization import has_brain_authorization
|
||||
from routes.authorizations.types import RoleEnum
|
||||
|
||||
logger = get_logger(__name__)
|
||||
brain_router = APIRouter()
|
||||
|
||||
prompt_service = PromptService()
|
||||
|
||||
|
||||
@brain_router.get("/brains/", dependencies=[Depends(AuthBearer())], tags=["Brain"])
|
||||
async def retrieve_all_brains_for_user(
|
||||
@ -147,9 +146,9 @@ async def update_existing_brain(
|
||||
update_brain_by_id(brain_id, brain_update_data)
|
||||
|
||||
if brain_update_data.prompt_id is None and existing_brain.prompt_id:
|
||||
prompt = get_prompt_by_id(existing_brain.prompt_id)
|
||||
prompt = prompt_service.get_prompt_by_id(existing_brain.prompt_id)
|
||||
if prompt and prompt.status == "private":
|
||||
delete_prompt_by_id(existing_brain.prompt_id)
|
||||
prompt_service.delete_prompt_by_id(existing_brain.prompt_id)
|
||||
|
||||
if brain_update_data.status == "private" and existing_brain.status == "public":
|
||||
delete_brain_users(brain_id)
|
||||
|
@ -3,7 +3,9 @@ from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from middlewares.auth.auth_bearer import AuthBearer, get_current_user
|
||||
from models import BrainSubscription, PromptStatusEnum
|
||||
from models import BrainSubscription
|
||||
from modules.prompt.entity.prompt import PromptStatusEnum
|
||||
from modules.prompt.service.prompt_service import PromptService
|
||||
from modules.user.entity.user_identity import UserIdentity
|
||||
from modules.user.service import get_user_email_by_user_id
|
||||
from modules.user.service.get_user_id_by_email import get_user_id_by_email
|
||||
@ -26,8 +28,6 @@ from repository.brain_subscription import (
|
||||
resend_invitation_email,
|
||||
)
|
||||
from repository.external_api_secret.create_secret import create_secret
|
||||
from repository.prompt import delete_prompt_by_id, get_prompt_by_id
|
||||
|
||||
from routes.authorizations.brain_authorization import (
|
||||
RoleEnum,
|
||||
has_brain_authorization,
|
||||
@ -38,6 +38,8 @@ from routes.headers.get_origin_header import get_origin_header
|
||||
subscription_router = APIRouter()
|
||||
subscription_service = SubscriptionInvitationService()
|
||||
|
||||
prompt_service = PromptService()
|
||||
|
||||
|
||||
@subscription_router.post(
|
||||
"/brains/{brain_id}/subscription",
|
||||
@ -168,11 +170,13 @@ async def remove_user_subscription(
|
||||
brain_id=brain_id,
|
||||
)
|
||||
if targeted_brain.prompt_id:
|
||||
brain_to_delete_prompt = get_prompt_by_id(targeted_brain.prompt_id)
|
||||
brain_to_delete_prompt = prompt_service.get_prompt_by_id(
|
||||
targeted_brain.prompt_id
|
||||
)
|
||||
if brain_to_delete_prompt is not None and (
|
||||
brain_to_delete_prompt.status == PromptStatusEnum.private
|
||||
):
|
||||
delete_prompt_by_id(targeted_brain.prompt_id)
|
||||
prompt_service.delete_prompt_by_id(targeted_brain.prompt_id)
|
||||
|
||||
else:
|
||||
delete_brain_user(
|
||||
|
Loading…
Reference in New Issue
Block a user