2023-09-28 17:09:52 +03:00
package mysql
import (
"context"
"database/sql"
"fmt"
"strings"
"github.com/pkg/errors"
2023-10-26 04:02:50 +03:00
"github.com/usememos/memos/internal/util"
2023-09-28 17:09:52 +03:00
"github.com/usememos/memos/store"
)
2023-10-05 18:11:29 +03:00
func ( d * DB ) CreateMemo ( ctx context . Context , create * store . Memo ) ( * store . Memo , error ) {
2023-10-08 13:29:12 +03:00
fields := [ ] string { "`creator_id`" , "`content`" , "`visibility`" }
placeholder := [ ] string { "?" , "?" , "?" }
args := [ ] any { create . CreatorID , create . Content , create . Visibility }
if create . ID != 0 {
fields = append ( fields , "`id`" )
placeholder = append ( placeholder , "?" )
args = append ( args , create . ID )
}
if create . CreatedTs != 0 {
fields = append ( fields , "`created_ts`" )
placeholder = append ( placeholder , "FROM_UNIXTIME(?)" )
args = append ( args , create . CreatedTs )
}
if create . UpdatedTs != 0 {
fields = append ( fields , "`updated_ts`" )
placeholder = append ( placeholder , "FROM_UNIXTIME(?)" )
args = append ( args , create . UpdatedTs )
}
if create . RowStatus != "" {
fields = append ( fields , "`row_status`" )
placeholder = append ( placeholder , "?" )
args = append ( args , create . RowStatus )
}
stmt := "INSERT INTO memo (" + strings . Join ( fields , ", " ) + ") VALUES (" + strings . Join ( placeholder , ", " ) + ")"
result , err := d . db . ExecContext ( ctx , stmt , args ... )
2023-09-28 17:09:52 +03:00
if err != nil {
return nil , err
}
2023-09-29 04:15:54 +03:00
rawID , err := result . LastInsertId ( )
2023-09-28 17:09:52 +03:00
if err != nil {
return nil , err
}
2023-09-29 04:15:54 +03:00
id := int32 ( rawID )
memo , err := d . GetMemo ( ctx , & store . FindMemo { ID : & id } )
if err != nil {
2023-09-28 17:09:52 +03:00
return nil , err
}
2023-09-29 04:15:54 +03:00
if memo == nil {
return nil , errors . Errorf ( "failed to create memo" )
}
return memo , nil
2023-09-28 17:09:52 +03:00
}
2023-10-05 18:11:29 +03:00
func ( d * DB ) ListMemos ( ctx context . Context , find * store . FindMemo ) ( [ ] * store . Memo , error ) {
2023-09-28 17:09:52 +03:00
where , args := [ ] string { "1 = 1" } , [ ] any { }
if v := find . ID ; v != nil {
2023-10-07 17:56:12 +03:00
where , args = append ( where , "`memo`.`id` = ?" ) , append ( args , * v )
2023-09-28 17:09:52 +03:00
}
if v := find . CreatorID ; v != nil {
2023-10-07 17:56:12 +03:00
where , args = append ( where , "`memo`.`creator_id` = ?" ) , append ( args , * v )
2023-09-28 17:09:52 +03:00
}
if v := find . RowStatus ; v != nil {
2023-10-07 17:56:12 +03:00
where , args = append ( where , "`memo`.`row_status` = ?" ) , append ( args , * v )
2023-09-28 17:09:52 +03:00
}
if v := find . CreatedTsBefore ; v != nil {
2023-10-07 17:56:12 +03:00
where , args = append ( where , "UNIX_TIMESTAMP(`memo`.`created_ts`) < ?" ) , append ( args , * v )
2023-09-28 17:09:52 +03:00
}
if v := find . CreatedTsAfter ; v != nil {
2023-10-07 17:56:12 +03:00
where , args = append ( where , "UNIX_TIMESTAMP(`memo`.`created_ts`) > ?" ) , append ( args , * v )
2023-09-28 17:09:52 +03:00
}
if v := find . Pinned ; v != nil {
2023-10-07 17:56:12 +03:00
where = append ( where , "`memo_organizer`.`pinned` = 1" )
2023-09-28 17:09:52 +03:00
}
if v := find . ContentSearch ; len ( v ) != 0 {
for _ , s := range v {
2023-10-07 17:56:12 +03:00
where , args = append ( where , "`memo`.`content` LIKE ?" ) , append ( args , "%" + s + "%" )
2023-09-28 17:09:52 +03:00
}
}
if v := find . VisibilityList ; len ( v ) != 0 {
list := [ ] string { }
for _ , visibility := range v {
list = append ( list , "?" )
args = append ( args , visibility )
}
2023-10-07 17:56:12 +03:00
where = append ( where , fmt . Sprintf ( "`memo`.`visibility` in (%s)" , strings . Join ( list , "," ) ) )
2023-09-28 17:09:52 +03:00
}
2023-10-07 17:56:12 +03:00
orders := [ ] string { "`pinned` DESC" }
2023-09-28 17:09:52 +03:00
if find . OrderByUpdatedTs {
2023-10-07 17:56:12 +03:00
orders = append ( orders , "`updated_ts` DESC" )
2023-09-28 17:09:52 +03:00
} else {
2023-10-07 17:56:12 +03:00
orders = append ( orders , "`created_ts` DESC" )
2023-09-28 17:09:52 +03:00
}
2023-10-07 17:56:12 +03:00
orders = append ( orders , "`id` DESC" )
2023-09-28 17:09:52 +03:00
2023-10-07 17:56:12 +03:00
query := "SELECT `memo`.`id` AS `id`, `memo`.`creator_id` AS `creator_id`, UNIX_TIMESTAMP(`memo`.`created_ts`) AS `created_ts`, UNIX_TIMESTAMP(`memo`.`updated_ts`) AS `updated_ts`, `memo`.`row_status` AS `row_status`, `memo`.`content` AS `content`, `memo`.`visibility` AS `visibility`, MAX(CASE WHEN `memo_organizer`.`pinned` = 1 THEN 1 ELSE 0 END) AS `pinned`, GROUP_CONCAT(`resource`.`id`) AS `resource_id_list`, (SELECT GROUP_CONCAT(`memo_id`,':',`related_memo_id`,':',`type`) FROM `memo_relation` WHERE `memo_relation`.`memo_id` = `memo`.`id` OR `memo_relation`.`related_memo_id` = `memo`.`id` ) AS `relation_list` FROM `memo` LEFT JOIN `memo_organizer` ON `memo`.`id` = `memo_organizer`.`memo_id` LEFT JOIN `resource` ON `memo`.`id` = `resource`.`memo_id` WHERE " + strings . Join ( where , " AND " ) + " GROUP BY `memo`.`id` ORDER BY " + strings . Join ( orders , ", " )
2023-09-28 17:09:52 +03:00
if find . Limit != nil {
query = fmt . Sprintf ( "%s LIMIT %d" , query , * find . Limit )
if find . Offset != nil {
query = fmt . Sprintf ( "%s OFFSET %d" , query , * find . Offset )
}
}
rows , err := d . db . QueryContext ( ctx , query , args ... )
if err != nil {
return nil , err
}
defer rows . Close ( )
list := make ( [ ] * store . Memo , 0 )
for rows . Next ( ) {
var memo store . Memo
var memoResourceIDList sql . NullString
var memoRelationList sql . NullString
if err := rows . Scan (
& memo . ID ,
& memo . CreatorID ,
& memo . CreatedTs ,
& memo . UpdatedTs ,
& memo . RowStatus ,
& memo . Content ,
& memo . Visibility ,
& memo . Pinned ,
& memoResourceIDList ,
& memoRelationList ,
) ; err != nil {
return nil , err
}
if memoResourceIDList . Valid {
idStringList := strings . Split ( memoResourceIDList . String , "," )
memo . ResourceIDList = make ( [ ] int32 , 0 , len ( idStringList ) )
for _ , idString := range idStringList {
id , err := util . ConvertStringToInt32 ( idString )
if err != nil {
return nil , err
}
memo . ResourceIDList = append ( memo . ResourceIDList , id )
}
}
if memoRelationList . Valid {
memo . RelationList = make ( [ ] * store . MemoRelation , 0 )
relatedMemoTypeList := strings . Split ( memoRelationList . String , "," )
for _ , relatedMemoType := range relatedMemoTypeList {
relatedMemoTypeList := strings . Split ( relatedMemoType , ":" )
2023-10-01 16:35:17 +03:00
if len ( relatedMemoTypeList ) != 3 {
2023-09-28 17:09:52 +03:00
return nil , errors . Errorf ( "invalid relation format" )
}
2023-10-01 16:35:17 +03:00
memoID , err := util . ConvertStringToInt32 ( relatedMemoTypeList [ 0 ] )
2023-09-28 17:09:52 +03:00
if err != nil {
return nil , err
}
2023-10-01 16:35:17 +03:00
relatedMemoID , err := util . ConvertStringToInt32 ( relatedMemoTypeList [ 1 ] )
if err != nil {
return nil , err
}
relationType := store . MemoRelationType ( relatedMemoTypeList [ 2 ] )
2023-09-28 17:09:52 +03:00
memo . RelationList = append ( memo . RelationList , & store . MemoRelation {
2023-10-01 16:35:17 +03:00
MemoID : memoID ,
2023-09-28 17:09:52 +03:00
RelatedMemoID : relatedMemoID ,
2023-10-01 11:27:40 +03:00
Type : relationType ,
2023-09-28 17:09:52 +03:00
} )
2023-10-01 11:27:40 +03:00
// Set the first parent ID if relation type is comment.
2023-10-01 16:35:17 +03:00
if memo . ParentID == nil && memoID == memo . ID && relationType == store . MemoRelationComment {
2023-10-01 11:27:40 +03:00
memo . ParentID = & relatedMemoID
}
2023-09-28 17:09:52 +03:00
}
}
list = append ( list , & memo )
}
if err := rows . Err ( ) ; err != nil {
return nil , err
}
return list , nil
}
2023-10-05 18:11:29 +03:00
func ( d * DB ) GetMemo ( ctx context . Context , find * store . FindMemo ) ( * store . Memo , error ) {
2023-09-29 04:15:54 +03:00
list , err := d . ListMemos ( ctx , find )
if err != nil {
return nil , err
}
if len ( list ) == 0 {
return nil , nil
}
memo := list [ 0 ]
return memo , nil
}
2023-10-05 18:11:29 +03:00
func ( d * DB ) UpdateMemo ( ctx context . Context , update * store . UpdateMemo ) error {
2023-09-28 17:09:52 +03:00
set , args := [ ] string { } , [ ] any { }
if v := update . CreatedTs ; v != nil {
2023-10-07 17:56:12 +03:00
set , args = append ( set , "`created_ts` = FROM_UNIXTIME(?)" ) , append ( args , * v )
2023-09-28 17:09:52 +03:00
}
if v := update . UpdatedTs ; v != nil {
2023-10-07 17:56:12 +03:00
set , args = append ( set , "`updated_ts` = FROM_UNIXTIME(?)" ) , append ( args , * v )
2023-09-28 17:09:52 +03:00
}
if v := update . RowStatus ; v != nil {
2023-10-07 17:56:12 +03:00
set , args = append ( set , "`row_status` = ?" ) , append ( args , * v )
2023-09-28 17:09:52 +03:00
}
if v := update . Content ; v != nil {
2023-10-07 17:56:12 +03:00
set , args = append ( set , "`content` = ?" ) , append ( args , * v )
2023-09-28 17:09:52 +03:00
}
if v := update . Visibility ; v != nil {
2023-10-07 17:56:12 +03:00
set , args = append ( set , "`visibility` = ?" ) , append ( args , * v )
2023-09-28 17:09:52 +03:00
}
args = append ( args , update . ID )
2023-10-07 17:56:12 +03:00
stmt := "UPDATE `memo` SET " + strings . Join ( set , ", " ) + " WHERE `id` = ?"
2023-09-28 17:09:52 +03:00
if _ , err := d . db . ExecContext ( ctx , stmt , args ... ) ; err != nil {
return err
}
return nil
}
2023-10-05 18:11:29 +03:00
func ( d * DB ) DeleteMemo ( ctx context . Context , delete * store . DeleteMemo ) error {
2023-10-07 17:56:12 +03:00
where , args := [ ] string { "`id` = ?" } , [ ] any { delete . ID }
stmt := "DELETE FROM `memo` WHERE " + strings . Join ( where , " AND " )
2023-09-28 17:09:52 +03:00
result , err := d . db . ExecContext ( ctx , stmt , args ... )
if err != nil {
return err
}
if _ , err := result . RowsAffected ( ) ; err != nil {
return err
}
if err := d . Vacuum ( ctx ) ; err != nil {
// Prevent linter warning.
return err
}
return nil
}
2023-10-05 18:11:29 +03:00
func ( d * DB ) FindMemosVisibilityList ( ctx context . Context , memoIDs [ ] int32 ) ( [ ] store . Visibility , error ) {
2023-09-28 17:09:52 +03:00
args := make ( [ ] any , 0 , len ( memoIDs ) )
list := make ( [ ] string , 0 , len ( memoIDs ) )
for _ , memoID := range memoIDs {
args = append ( args , memoID )
list = append ( list , "?" )
}
2023-10-07 17:56:12 +03:00
where := fmt . Sprintf ( "`id` in (%s)" , strings . Join ( list , "," ) )
query := "SELECT DISTINCT(`visibility`) FROM `memo` WHERE " + where
2023-09-28 17:09:52 +03:00
rows , err := d . db . QueryContext ( ctx , query , args ... )
if err != nil {
return nil , err
}
defer rows . Close ( )
visibilityList := make ( [ ] store . Visibility , 0 )
for rows . Next ( ) {
var visibility store . Visibility
if err := rows . Scan ( & visibility ) ; err != nil {
return nil , err
}
visibilityList = append ( visibilityList , visibility )
}
if err := rows . Err ( ) ; err != nil {
return nil , err
}
return visibilityList , nil
}
// vacuumMemo deletes, within tx, every memo whose `creator_id` no longer
// matches an `id` in the `user` table. It returns the ExecContext error, if any.
func vacuumMemo(ctx context.Context, tx *sql.Tx) error {
	const stmt = "DELETE FROM `memo` WHERE `creator_id` NOT IN (SELECT `id` FROM `user`)"
	if _, err := tx.ExecContext(ctx, stmt); err != nil {
		return err
	}
	return nil
}