refactor: Prompt module (#1688)

# Description

Prompt module with Service

## Checklist before requesting a review

Please delete options that are not relevant.

- [ ] My code follows the style guidelines of this project
- [ ] I have performed a self-review of my code
- [ ] I have commented hard-to-understand areas
- [ ] I have ideally added tests that prove my fix is effective or that
my feature works
- [ ] New and existing unit tests pass locally with my changes
- [ ] Any dependent changes have been merged

## Screenshots (if appropriate):
This commit is contained in:
Zineb El Bachiri 2023-11-23 14:13:21 +01:00 committed by GitHub
parent 18421ca272
commit 1bf67e3640
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
27 changed files with 209 additions and 184 deletions

View File

@ -13,7 +13,7 @@ from llm.utils.get_prompt_to_use_id import get_prompt_to_use_id
from logger import get_logger
from models.chats import ChatQuestion
from models.databases.supabase.chats import CreateChatHistory
from models.prompt import Prompt
from modules.prompt.entity.prompt import Prompt
from pydantic import BaseModel
from repository.chat import (
GetChatHistoryOutput,

View File

@ -2,8 +2,10 @@ from typing import Optional
from uuid import UUID
from llm.utils.get_prompt_to_use_id import get_prompt_to_use_id
from models.prompt import Prompt
from repository.prompt import get_prompt_by_id
from modules.prompt.entity.prompt import Prompt
from modules.prompt.service import PromptService
promptService = PromptService()
def get_prompt_to_use(
@ -13,4 +15,4 @@ def get_prompt_to_use(
if prompt_to_use_id is None:
return None
return get_prompt_by_id(prompt_to_use_id)
return promptService.get_prompt_by_id(prompt_to_use_id)

View File

@ -13,6 +13,7 @@ from fastapi import FastAPI, HTTPException
from fastapi.responses import JSONResponse
from logger import get_logger
from middlewares.cors import add_cors_middleware
from modules.prompt.controller.prompt_routes import prompt_router
from modules.user.controller.user_controller import user_router
from routes.api_key_routes import api_key_router
from routes.brain_routes import brain_router
@ -24,7 +25,6 @@ from routes.knowledge_routes import knowledge_router
from routes.misc_routes import misc_router
from routes.notification_routes import notification_router
from routes.onboarding_routes import onboarding_router
from routes.prompt_routes import prompt_router
from routes.subscription_routes import subscription_router
from routes.upload_routes import upload_router

View File

@ -4,11 +4,9 @@ from .brains_subscription_invitations import BrainSubscription
from .chat import Chat, ChatHistory
from .chats import ChatMessage, ChatQuestion
from .files import File
from .prompt import Prompt, PromptStatusEnum
from .settings import (BrainRateLimiting, BrainSettings, ContactsSettings,
ResendSettings, get_embeddings,
get_documents_vector_store, get_embeddings,
get_supabase_client, get_supabase_db)
ResendSettings, get_documents_vector_store,
get_embeddings, get_supabase_client, get_supabase_db)
from .user_usage import UserUsage
# TODO uncomment the below import when start using SQLalchemy

View File

@ -204,26 +204,6 @@ class Repository(ABC):
def get_vectors_by_file_sha1(self, file_sha1):
pass
@abstractmethod
def create_prompt(self, new_prompt):
pass
@abstractmethod
def get_prompt_by_id(self, prompt_id: UUID):
pass
@abstractmethod
def delete_prompt_by_id(self, prompt_id: UUID):
pass
@abstractmethod
def update_prompt_by_id(self, prompt_id: UUID, updates):
pass
@abstractmethod
def get_public_prompts(self):
pass
@abstractmethod
def add_notification(self, notification):
pass

View File

@ -1,13 +1,12 @@
from models.databases.supabase.api_brain_definition import ApiBrainDefinitions
from models.databases.supabase.api_key_handler import ApiKeyHandler
from models.databases.supabase.brains import Brain
from models.databases.supabase.brains_subscription_invitations import BrainSubscription
from models.databases.supabase.brains_subscription_invitations import \
BrainSubscription
from models.databases.supabase.chats import Chats
from models.databases.supabase.files import File
from models.databases.supabase.knowledge import Knowledges
from models.databases.supabase.notifications import Notifications
from models.databases.supabase.onboarding import Onboarding
from models.databases.supabase.prompts import Prompts
from models.databases.supabase.user_usage import UserUsage
from models.databases.supabase.vectors import Vector
from models.databases.supabase.api_brain_definition import ApiBrainDefinitions

View File

@ -9,7 +9,6 @@ from models.databases.supabase import (
Knowledges,
Notifications,
Onboarding,
Prompts,
UserUsage,
Vector,
)
@ -26,7 +25,6 @@ class SupabaseDB(
Chats,
Vector,
Onboarding,
Prompts,
Notifications,
Knowledges,
ApiBrainDefinitions,
@ -40,7 +38,6 @@ class SupabaseDB(
ApiKeyHandler.__init__(self, supabase_client)
Chats.__init__(self, supabase_client)
Vector.__init__(self, supabase_client)
Prompts.__init__(self, supabase_client)
Notifications.__init__(self, supabase_client)
Knowledges.__init__(self, supabase_client)
Onboarding.__init__(self, supabase_client)

View File

@ -1,16 +0,0 @@
from enum import Enum
from uuid import UUID
from pydantic import BaseModel
class PromptStatusEnum(str, Enum):
private = "private"
public = "public"
class Prompt(BaseModel):
title: str
content: str
status: PromptStatusEnum = PromptStatusEnum.private
id: UUID

View File

View File

@ -2,28 +2,24 @@ from uuid import UUID
from fastapi import APIRouter, Depends
from middlewares.auth import AuthBearer
from models import Prompt
from models.databases.supabase.prompts import (
from modules.prompt.entity.prompt import (
CreatePromptProperties,
Prompt,
PromptUpdatableProperties,
)
from repository.prompt import (
create_prompt,
get_prompt_by_id,
get_public_prompts,
update_prompt_by_id,
)
from modules.prompt.service import PromptService
prompt_router = APIRouter()
promptService = PromptService()
@prompt_router.get("/prompts", dependencies=[Depends(AuthBearer())], tags=["Prompt"])
async def get_prompts() -> list[Prompt]:
"""
Retrieve all public prompt
"""
return get_public_prompts()
return promptService.get_public_prompts()
@prompt_router.get(
@ -34,7 +30,7 @@ async def get_prompt(prompt_id: UUID) -> Prompt | None:
Retrieve a prompt by its id
"""
return get_prompt_by_id(prompt_id)
return promptService.get_prompt_by_id(prompt_id)
@prompt_router.put(
@ -47,7 +43,7 @@ async def update_prompt(
Update a prompt by its id
"""
return update_prompt_by_id(prompt_id, prompt)
return promptService.update_prompt_by_id(prompt_id, prompt)
@prompt_router.post("/prompts", dependencies=[Depends(AuthBearer())], tags=["Prompt"])
@ -56,4 +52,4 @@ async def create_prompt_route(prompt: CreatePromptProperties) -> Prompt | None:
Create a prompt by its id
"""
return create_prompt(prompt)
return promptService.create_prompt(prompt)

View File

@ -0,0 +1 @@
from .prompt import Prompt, PromptStatusEnum, CreatePromptProperties, PromptUpdatableProperties, DeletePromptResponse

View File

@ -0,0 +1,40 @@
from enum import Enum
from typing import Optional
from uuid import UUID
from pydantic import BaseModel
class PromptStatusEnum(str, Enum):
private = "private"
public = "public"
class Prompt(BaseModel):
title: str
content: str
status: PromptStatusEnum = PromptStatusEnum.private
id: UUID
class CreatePromptProperties(BaseModel):
"""Properties that can be received on prompt creation"""
title: str
content: str
status: PromptStatusEnum = PromptStatusEnum.private
class PromptUpdatableProperties(BaseModel):
"""Properties that can be received on prompt update"""
title: Optional[str]
content: Optional[str]
status: Optional[PromptStatusEnum]
class DeletePromptResponse(BaseModel):
"""Response when deleting a prompt"""
status: str = "delete"
prompt_id: UUID

View File

@ -1,40 +1,16 @@
from typing import Optional
from uuid import UUID
from fastapi import HTTPException
from models.databases.repository import Repository
from models.prompt import Prompt, PromptStatusEnum
from pydantic import BaseModel
from modules.prompt.entity.prompt import Prompt
from modules.prompt.repository.prompts_interface import (
DeletePromptResponse,
PromptsInterface,
)
class CreatePromptProperties(BaseModel):
"""Properties that can be received on prompt creation"""
title: str
content: str
status: PromptStatusEnum = PromptStatusEnum.private
class PromptUpdatableProperties(BaseModel):
"""Properties that can be received on prompt update"""
title: Optional[str]
content: Optional[str]
status: Optional[PromptStatusEnum]
class DeletePromptResponse(BaseModel):
"""Response when deleting a prompt"""
status: str = "delete"
prompt_id: UUID
class Prompts(Repository):
class Prompts(PromptsInterface):
def __init__(self, supabase_client):
self.db = supabase_client
def create_prompt(self, prompt: CreatePromptProperties) -> Prompt:
def create_prompt(self, prompt):
"""
Create a prompt
"""
@ -43,7 +19,7 @@ class Prompts(Repository):
return Prompt(**response[0])
def delete_prompt_by_id(self, prompt_id: UUID) -> DeletePromptResponse:
def delete_prompt_by_id(self, prompt_id):
"""
Delete a prompt by id
Args:
@ -65,7 +41,7 @@ class Prompts(Repository):
return DeletePromptResponse(status="deleted", prompt_id=prompt_id)
def get_prompt_by_id(self, prompt_id: UUID) -> Prompt | None:
def get_prompt_by_id(self, prompt_id):
"""
Get a prompt by its id
@ -84,7 +60,7 @@ class Prompts(Repository):
return None
return Prompt(**response[0])
def get_public_prompts(self) -> list[Prompt]:
def get_public_prompts(self):
"""
List all public prompts
"""
@ -96,9 +72,7 @@ class Prompts(Repository):
.execute()
).data
def update_prompt_by_id(
self, prompt_id: UUID, prompt: PromptUpdatableProperties
) -> Prompt:
def update_prompt_by_id(self, prompt_id, prompt):
"""Update a prompt by id"""
response = (

View File

@ -0,0 +1,57 @@
from abc import ABC, abstractmethod
from uuid import UUID
from modules.prompt.entity import (
CreatePromptProperties,
DeletePromptResponse,
Prompt,
PromptUpdatableProperties,
)
class PromptsInterface(ABC):
@abstractmethod
def create_prompt(self, prompt: CreatePromptProperties) -> Prompt:
"""
Create a prompt
"""
pass
@abstractmethod
def delete_prompt_by_id(self, prompt_id: UUID) -> DeletePromptResponse:
"""
Delete a prompt by id
Args:
prompt_id (UUID): The id of the prompt
Returns:
A dictionary containing the status of the delete and prompt_id of the deleted prompt
"""
pass
@abstractmethod
def get_prompt_by_id(self, prompt_id: UUID) -> Prompt | None:
"""
Get a prompt by its id
Args:
prompt_id (UUID): The id of the prompt
Returns:
Prompt: The prompt
"""
pass
@abstractmethod
def get_public_prompts(self) -> list[Prompt]:
"""
List all public prompts
"""
pass
@abstractmethod
def update_prompt_by_id(
self, prompt_id: UUID, prompt: PromptUpdatableProperties
) -> Prompt:
"""Update a prompt by id"""
pass

View File

@ -0,0 +1 @@
from .prompt_service import PromptService

View File

@ -0,0 +1,59 @@
from typing import List
from uuid import UUID
from models.settings import get_supabase_client
from modules.prompt.entity.prompt import (
CreatePromptProperties,
DeletePromptResponse,
Prompt,
PromptUpdatableProperties,
)
from modules.prompt.repository.prompts import Prompts
class PromptService:
repository: Prompts
def __init__(self):
supabase_client = get_supabase_client()
self.repository = Prompts(supabase_client)
def create_prompt(self, prompt: CreatePromptProperties) -> Prompt:
return self.repository.create_prompt(prompt)
def delete_prompt_by_id(self, prompt_id: UUID) -> DeletePromptResponse:
"""
Delete a prompt by id
Args:
prompt_id (UUID): The id of the prompt
Returns:
Prompt: The prompt
"""
return self.repository.delete_prompt_by_id(prompt_id)
def get_prompt_by_id(self, prompt_id: UUID) -> Prompt | None:
"""
Get a prompt by its id
Args:
prompt_id (UUID): The id of the prompt
Returns:
Prompt: The prompt
"""
return self.repository.get_prompt_by_id(prompt_id)
def get_public_prompts(self) -> List[Prompt]:
"""
List all public prompts
"""
return self.repository.get_public_prompts()
def update_prompt_by_id(
self, prompt_id: UUID, prompt: PromptUpdatableProperties
) -> Prompt:
"""Update a prompt by id"""
return self.repository.update_prompt_by_id(prompt_id, prompt)

View File

@ -2,10 +2,11 @@ from typing import List, Optional
from uuid import UUID
from models import ChatHistory, get_supabase_db
from modules.prompt.service.prompt_service import PromptService
from pydantic import BaseModel
from repository.brain import get_brain_by_id
from repository.prompt import get_prompt_by_id
prompt_service = PromptService()
class GetChatHistoryOutput(BaseModel):
@ -40,7 +41,7 @@ def get_chat_history(chat_id: str) -> List[GetChatHistoryOutput]:
prompt = None
if message.prompt_id:
prompt = get_prompt_by_id(message.prompt_id)
prompt = prompt_service.get_prompt_by_id(message.prompt_id)
enriched_history.append(
GetChatHistoryOutput(

View File

@ -1,5 +0,0 @@
from .create_prompt import create_prompt
from .get_prompt_by_id import get_prompt_by_id
from .get_public_prompts import get_public_prompts
from .update_prompt_by_id import update_prompt_by_id
from .delete_prompt_py_id import delete_prompt_by_id

View File

@ -1,8 +0,0 @@
from models.databases.supabase.prompts import CreatePromptProperties
from models import Prompt, get_supabase_db
def create_prompt(prompt: CreatePromptProperties) -> Prompt:
supabase_db = get_supabase_db()
return supabase_db.create_prompt(prompt)

View File

@ -1,17 +0,0 @@
from uuid import UUID
from models.databases.supabase.prompts import DeletePromptResponse
from models import get_supabase_db
def delete_prompt_by_id(prompt_id: UUID) -> DeletePromptResponse:
"""
Delete a prompt by id
Args:
prompt_id (UUID): The id of the prompt
Returns:
Prompt: The prompt
"""
supabase_db = get_supabase_db()
return supabase_db.delete_prompt_by_id(prompt_id)

View File

@ -1,17 +0,0 @@
from uuid import UUID
from models import Prompt, get_supabase_db
def get_prompt_by_id(prompt_id: UUID) -> Prompt | None:
"""
Get a prompt by its id
Args:
prompt_id (UUID): The id of the prompt
Returns:
Prompt: The prompt
"""
supabase_db = get_supabase_db()
return supabase_db.get_prompt_by_id(prompt_id)

View File

@ -1,9 +0,0 @@
from models import Prompt, get_supabase_db
def get_public_prompts() -> list[Prompt]:
"""
List all public prompts
"""
supabase_db = get_supabase_db()
return supabase_db.get_public_prompts()

View File

@ -1,11 +0,0 @@
from uuid import UUID
from models.databases.supabase.prompts import PromptUpdatableProperties
from models import Prompt, get_supabase_db
def update_prompt_by_id(prompt_id: UUID, prompt: PromptUpdatableProperties) -> Prompt:
"""Update a prompt by id"""
supabase_db = get_supabase_db()
return supabase_db.update_prompt_by_id(prompt_id, prompt)

View File

@ -11,6 +11,7 @@ from models.databases.supabase.brains import (
BrainUpdatableProperties,
CreateBrainProperties,
)
from modules.prompt.service.prompt_service import PromptService
from modules.user.entity.user_identity import UserIdentity
from repository.brain import (
create_brain,
@ -26,17 +27,15 @@ from repository.brain import (
update_brain_by_id,
)
from repository.brain.get_brain_for_user import get_brain_for_user
from repository.external_api_secret.update_secret_value import (
update_secret_value,
)
from repository.prompt import delete_prompt_by_id, get_prompt_by_id
from repository.external_api_secret.update_secret_value import update_secret_value
from routes.authorizations.brain_authorization import has_brain_authorization
from routes.authorizations.types import RoleEnum
logger = get_logger(__name__)
brain_router = APIRouter()
prompt_service = PromptService()
@brain_router.get("/brains/", dependencies=[Depends(AuthBearer())], tags=["Brain"])
async def retrieve_all_brains_for_user(
@ -147,9 +146,9 @@ async def update_existing_brain(
update_brain_by_id(brain_id, brain_update_data)
if brain_update_data.prompt_id is None and existing_brain.prompt_id:
prompt = get_prompt_by_id(existing_brain.prompt_id)
prompt = prompt_service.get_prompt_by_id(existing_brain.prompt_id)
if prompt and prompt.status == "private":
delete_prompt_by_id(existing_brain.prompt_id)
prompt_service.delete_prompt_by_id(existing_brain.prompt_id)
if brain_update_data.status == "private" and existing_brain.status == "public":
delete_brain_users(brain_id)

View File

@ -3,7 +3,9 @@ from uuid import UUID
from fastapi import APIRouter, Depends, HTTPException
from middlewares.auth.auth_bearer import AuthBearer, get_current_user
from models import BrainSubscription, PromptStatusEnum
from models import BrainSubscription
from modules.prompt.entity.prompt import PromptStatusEnum
from modules.prompt.service.prompt_service import PromptService
from modules.user.entity.user_identity import UserIdentity
from modules.user.service import get_user_email_by_user_id
from modules.user.service.get_user_id_by_email import get_user_id_by_email
@ -26,8 +28,6 @@ from repository.brain_subscription import (
resend_invitation_email,
)
from repository.external_api_secret.create_secret import create_secret
from repository.prompt import delete_prompt_by_id, get_prompt_by_id
from routes.authorizations.brain_authorization import (
RoleEnum,
has_brain_authorization,
@ -38,6 +38,8 @@ from routes.headers.get_origin_header import get_origin_header
subscription_router = APIRouter()
subscription_service = SubscriptionInvitationService()
prompt_service = PromptService()
@subscription_router.post(
"/brains/{brain_id}/subscription",
@ -168,11 +170,13 @@ async def remove_user_subscription(
brain_id=brain_id,
)
if targeted_brain.prompt_id:
brain_to_delete_prompt = get_prompt_by_id(targeted_brain.prompt_id)
brain_to_delete_prompt = prompt_service.get_prompt_by_id(
targeted_brain.prompt_id
)
if brain_to_delete_prompt is not None and (
brain_to_delete_prompt.status == PromptStatusEnum.private
):
delete_prompt_by_id(targeted_brain.prompt_id)
prompt_service.delete_prompt_by_id(targeted_brain.prompt_id)
else:
delete_brain_user(