mirror of
https://github.com/StanGirard/quivr.git
synced 2024-11-23 12:26:03 +03:00
feat(prompt): add prompt table, entity and repository (#823)
* feat: add prompts table * feat: add Prompt entity * feat: add prompt router * refactor(promptRepository): use common reposority
This commit is contained in:
parent
23f50ec3a3
commit
e3b6114248
@ -13,6 +13,7 @@ from routes.chat_routes import chat_router
|
||||
from routes.crawl_routes import crawl_router
|
||||
from routes.explore_routes import explore_router
|
||||
from routes.misc_routes import misc_router
|
||||
from routes.prompt_routes import prompt_router
|
||||
from routes.subscription_routes import subscription_router
|
||||
from routes.upload_routes import upload_router
|
||||
from routes.user_routes import user_router
|
||||
@ -46,6 +47,7 @@ app.include_router(upload_router)
|
||||
app.include_router(user_router)
|
||||
app.include_router(api_key_router)
|
||||
app.include_router(subscription_router)
|
||||
app.include_router(prompt_router)
|
||||
|
||||
|
||||
@app.exception_handler(HTTPException)
|
||||
|
@ -199,3 +199,23 @@ class Repository(ABC):
|
||||
@abstractmethod
|
||||
def get_vectors_by_file_sha1(self, file_sha1):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def create_prompt(self, new_prompt):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_prompt_by_id(self, prompt_id: UUID):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def delete_prompt_by_id(self, prompt_id: UUID):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def update_prompt_by_id(self, prompt_id: UUID, updates):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_public_prompts(self):
|
||||
pass
|
||||
|
@ -1,7 +1,8 @@
|
||||
from models.databases.supabase.brains import Brain
|
||||
from models.databases.supabase.users import User
|
||||
from models.databases.supabase.files import File
|
||||
from models.databases.supabase.brains_subscription_invitations import BrainSubscription
|
||||
from models.databases.supabase.api_key_handler import ApiKeyHandler
|
||||
from models.databases.supabase.brains import Brain
|
||||
from models.databases.supabase.brains_subscription_invitations import BrainSubscription
|
||||
from models.databases.supabase.chats import Chats
|
||||
from models.databases.supabase.files import File
|
||||
from models.databases.supabase.prompts import Prompts
|
||||
from models.databases.supabase.users import User
|
||||
from models.databases.supabase.vectors import Vector
|
||||
|
103
backend/core/models/databases/supabase/prompts.py
Normal file
103
backend/core/models/databases/supabase/prompts.py
Normal file
@ -0,0 +1,103 @@
|
||||
from typing import Optional
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import HTTPException
|
||||
from models.databases.repository import Repository
|
||||
from models.prompt import Prompt
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class CreatePromptProperties(BaseModel):
|
||||
"""Properties that can be received on prompt creation"""
|
||||
|
||||
title: str
|
||||
content: str
|
||||
status: str = "private"
|
||||
|
||||
|
||||
class PromptUpdatableProperties(BaseModel):
|
||||
"""Properties that can be received on prompt update"""
|
||||
|
||||
title: Optional[str]
|
||||
content: Optional[str]
|
||||
status: Optional[str]
|
||||
|
||||
|
||||
class Prompts(Repository):
|
||||
def __init__(self, supabase_client):
|
||||
self.db = supabase_client
|
||||
|
||||
def create_prompt(self, prompt: CreatePromptProperties) -> Prompt:
|
||||
"""Create a prompt by id"""
|
||||
|
||||
response = (self.db.from_("prompts").insert(prompt.dict()).execute()).data
|
||||
|
||||
return Prompt(**response[0])
|
||||
|
||||
def delete_prompt_by_id(self, prompt_id: UUID) -> Prompt | None:
|
||||
"""
|
||||
Delete a prompt by id
|
||||
Args:
|
||||
prompt_id (UUID): The id of the prompt
|
||||
|
||||
Returns:
|
||||
Prompt: The prompt
|
||||
"""
|
||||
response = (
|
||||
self.db.from_("prompts")
|
||||
.delete()
|
||||
.filter("id", "eq", prompt_id)
|
||||
.execute()
|
||||
.data
|
||||
)
|
||||
if response == []:
|
||||
raise HTTPException(404, "Prompt not found")
|
||||
return Prompt(**response[0])
|
||||
|
||||
def get_prompt_by_id(self, prompt_id: UUID) -> Prompt | None:
|
||||
"""
|
||||
Get a prompt by its id
|
||||
|
||||
Args:
|
||||
prompt_id (UUID): The id of the prompt
|
||||
|
||||
Returns:
|
||||
Prompt: The prompt
|
||||
"""
|
||||
|
||||
response = (
|
||||
self.db.from_("prompts").select("*").filter("id", "eq", prompt_id).execute()
|
||||
).data
|
||||
|
||||
if response == []:
|
||||
return None
|
||||
return Prompt(**response[0])
|
||||
|
||||
def get_public_prompts(self) -> list[Prompt]:
|
||||
"""
|
||||
List all public prompts
|
||||
"""
|
||||
|
||||
return (
|
||||
self.db.from_("prompts")
|
||||
.select("*")
|
||||
.filter("status", "eq", "public")
|
||||
.execute()
|
||||
).data
|
||||
|
||||
def update_prompt_by_id(
|
||||
self, prompt_id: UUID, prompt: PromptUpdatableProperties
|
||||
) -> Prompt:
|
||||
"""Update a prompt by id"""
|
||||
|
||||
response = (
|
||||
self.db.from_("prompts")
|
||||
.update(prompt.dict(exclude_unset=True))
|
||||
.filter("id", "eq", prompt_id)
|
||||
.execute()
|
||||
).data
|
||||
|
||||
if response == []:
|
||||
raise HTTPException(404, "Prompt not found")
|
||||
|
||||
return Prompt(**response[0])
|
@ -1,19 +1,21 @@
|
||||
from logger import get_logger
|
||||
from models.databases.supabase import (
|
||||
Brain,
|
||||
User,
|
||||
File,
|
||||
BrainSubscription,
|
||||
ApiKeyHandler,
|
||||
Brain,
|
||||
BrainSubscription,
|
||||
Chats,
|
||||
File,
|
||||
Prompts,
|
||||
User,
|
||||
Vector,
|
||||
)
|
||||
from logger import get_logger
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class SupabaseDB(Brain, User, File, BrainSubscription, ApiKeyHandler, Chats, Vector):
|
||||
class SupabaseDB(
|
||||
Brain, User, File, BrainSubscription, ApiKeyHandler, Chats, Vector, Prompts
|
||||
):
|
||||
def __init__(self, supabase_client):
|
||||
self.db = supabase_client
|
||||
Brain.__init__(self, supabase_client)
|
||||
@ -23,3 +25,4 @@ class SupabaseDB(Brain, User, File, BrainSubscription, ApiKeyHandler, Chats, Vec
|
||||
ApiKeyHandler.__init__(self, supabase_client)
|
||||
Chats.__init__(self, supabase_client)
|
||||
Vector.__init__(self, supabase_client)
|
||||
Prompts.__init__(self, supabase_client)
|
||||
|
10
backend/core/models/prompt.py
Normal file
10
backend/core/models/prompt.py
Normal file
@ -0,0 +1,10 @@
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class Prompt(BaseModel):
|
||||
title: str
|
||||
content: str
|
||||
status: str = "private"
|
||||
id: UUID
|
10
backend/core/repository/prompt/create_prompt.py
Normal file
10
backend/core/repository/prompt/create_prompt.py
Normal file
@ -0,0 +1,10 @@
|
||||
from models.databases.supabase.prompts import CreatePromptProperties
|
||||
from models.prompt import Prompt
|
||||
from models.settings import common_dependencies
|
||||
|
||||
|
||||
def create_prompt(prompt: CreatePromptProperties) -> Prompt:
|
||||
"""Create a prompt by id"""
|
||||
commons = common_dependencies()
|
||||
|
||||
return commons["db"].create_prompt(prompt)
|
17
backend/core/repository/prompt/delete_prompt_py_id.py
Normal file
17
backend/core/repository/prompt/delete_prompt_py_id.py
Normal file
@ -0,0 +1,17 @@
|
||||
from uuid import UUID
|
||||
|
||||
from models.prompt import Prompt
|
||||
from models.settings import common_dependencies
|
||||
|
||||
|
||||
def delete_prompt_by_id(prompt_id: UUID) -> Prompt | None:
|
||||
"""
|
||||
Delete a prompt by id
|
||||
Args:
|
||||
prompt_id (UUID): The id of the prompt
|
||||
|
||||
Returns:
|
||||
Prompt: The prompt
|
||||
"""
|
||||
commons = common_dependencies()
|
||||
return commons["db"].delete_prompt_by_id(prompt_id)
|
18
backend/core/repository/prompt/get_prompt_by_id.py
Normal file
18
backend/core/repository/prompt/get_prompt_by_id.py
Normal file
@ -0,0 +1,18 @@
|
||||
from uuid import UUID
|
||||
|
||||
from models.prompt import Prompt
|
||||
from models.settings import common_dependencies
|
||||
|
||||
|
||||
def get_prompt_by_id(prompt_id: UUID) -> Prompt | None:
|
||||
"""
|
||||
Get a prompt by its id
|
||||
|
||||
Args:
|
||||
prompt_id (UUID): The id of the prompt
|
||||
|
||||
Returns:
|
||||
Prompt: The prompt
|
||||
"""
|
||||
commons = common_dependencies()
|
||||
return commons["db"].get_prompt_by_id(prompt_id)
|
10
backend/core/repository/prompt/get_public_prompts.py
Normal file
10
backend/core/repository/prompt/get_public_prompts.py
Normal file
@ -0,0 +1,10 @@
|
||||
from models.prompt import Prompt
|
||||
from models.settings import common_dependencies
|
||||
|
||||
|
||||
def get_public_prompts() -> list[Prompt]:
|
||||
"""
|
||||
List all public prompts
|
||||
"""
|
||||
commons = common_dependencies()
|
||||
return commons["db"].get_public_prompts()
|
12
backend/core/repository/prompt/update_prompt_by_id.py
Normal file
12
backend/core/repository/prompt/update_prompt_by_id.py
Normal file
@ -0,0 +1,12 @@
|
||||
from uuid import UUID
|
||||
|
||||
from models.databases.supabase.prompts import PromptUpdatableProperties
|
||||
from models.prompt import Prompt
|
||||
from models.settings import common_dependencies
|
||||
|
||||
|
||||
def update_prompt_by_id(prompt_id: UUID, prompt: PromptUpdatableProperties) -> Prompt:
|
||||
"""Update a prompt by id"""
|
||||
commons = common_dependencies()
|
||||
|
||||
return commons["db"].update_prompt_by_id(prompt_id, prompt)
|
15
backend/core/routes/prompt_routes.py
Normal file
15
backend/core/routes/prompt_routes.py
Normal file
@ -0,0 +1,15 @@
|
||||
from auth import AuthBearer
|
||||
from fastapi import APIRouter, Depends
|
||||
from models.prompt import Prompt
|
||||
from repository.prompt.get_public_prompts import get_public_prompts
|
||||
|
||||
prompt_router = APIRouter()
|
||||
|
||||
|
||||
@prompt_router.get("/prompts", dependencies=[Depends(AuthBearer())], tags=["Prompt"])
|
||||
async def get_prompts() -> list[Prompt]:
|
||||
"""
|
||||
Retrieve all public prompt
|
||||
"""
|
||||
|
||||
return get_public_prompts()
|
19
scripts/20230701180101_add_prompts_table.sql
Normal file
19
scripts/20230701180101_add_prompts_table.sql
Normal file
@ -0,0 +1,19 @@
|
||||
BEGIN;
|
||||
|
||||
-- Create user_identity table if it doesn't exist
|
||||
CREATE TABLE IF NOT EXISTS prompts (
|
||||
id UUID DEFAULT uuid_generate_v4() PRIMARY KEY,
|
||||
title VARCHAR(255),
|
||||
content TEXT,
|
||||
status VARCHAR(255) DEFAULT 'private'
|
||||
);
|
||||
|
||||
|
||||
-- Insert migration record if it doesn't exist
|
||||
INSERT INTO migrations (name)
|
||||
SELECT '20230701180101_add_prompts_table'
|
||||
WHERE NOT EXISTS (
|
||||
SELECT 1 FROM migrations WHERE name = '20230701180101_add_prompts_table'
|
||||
);
|
||||
|
||||
COMMIT;
|
@ -173,6 +173,16 @@ CREATE TABLE IF NOT EXISTS user_identity (
|
||||
openai_api_key VARCHAR(255)
|
||||
);
|
||||
|
||||
|
||||
--- Create prompts table
|
||||
CREATE TABLE IF NOT EXISTS prompts (
|
||||
id UUID DEFAULT uuid_generate_v4() PRIMARY KEY,
|
||||
title VARCHAR(255),
|
||||
content TEXT,
|
||||
status VARCHAR(255) DEFAULT 'private'
|
||||
);
|
||||
|
||||
|
||||
CREATE OR REPLACE FUNCTION public.get_user_email_by_user_id(user_id uuid)
|
||||
RETURNS TABLE (email text)
|
||||
SECURITY definer
|
||||
@ -200,7 +210,7 @@ CREATE TABLE IF NOT EXISTS migrations (
|
||||
);
|
||||
|
||||
INSERT INTO migrations (name)
|
||||
SELECT '20230731172400_add_user_identity_table'
|
||||
SELECT '20230701180101_add_prompts_table'
|
||||
WHERE NOT EXISTS (
|
||||
SELECT 1 FROM migrations WHERE name = '20230731172400_add_user_identity_table'
|
||||
SELECT 1 FROM migrations WHERE name = '20230701180101_add_prompts_table'
|
||||
);
|
||||
|
Loading…
Reference in New Issue
Block a user