Add TH function to specify NotNull constraint for singleton values.

This commit is contained in:
Kei Hibino 2013-05-15 16:29:17 +09:00
parent 222a94a254
commit 96d14cbbaf
4 changed files with 33 additions and 30 deletions

View File

@ -1,5 +1,6 @@
{-# OPTIONS_GHC -fno-warn-orphans #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE MultiParamTypeClasses, FlexibleInstances #-}
-- |
-- Module : Database.Record.Instances
@ -14,21 +15,10 @@
module Database.Record.Instances () where
import Data.Int (Int16, Int32, Int64)
import Database.Record.Persistable
(PersistableWidth(persistableWidth), valueWidth)
import Database.Record.TH (deriveNotNullValue)
instance PersistableWidth String where
persistableWidth = valueWidth
instance PersistableWidth Int where
persistableWidth = valueWidth
instance PersistableWidth Int16 where
persistableWidth = valueWidth
instance PersistableWidth Int32 where
persistableWidth = valueWidth
instance PersistableWidth Int64 where
persistableWidth = valueWidth
$(deriveNotNullValue [t| String |])
$(deriveNotNullValue [t| Int |])
$(deriveNotNullValue [t| Int16 |])
$(deriveNotNullValue [t| Int32 |])
$(deriveNotNullValue [t| Int64 |])

View File

@ -24,7 +24,9 @@ module Database.Record.KeyConstraint (
HasKeyConstraint (keyConstraint),
derivedUniqueConstraint,
derivedNotNullConstraint
derivedNotNullConstraint,
specifyNotNullValue
) where
newtype KeyConstraint c r = KeyConstraint { index :: Int }
@ -61,3 +63,7 @@ derivedUniqueConstraint = unique keyConstraint
derivedNotNullConstraint :: HasKeyConstraint Primary r => NotNullConstraint r
derivedNotNullConstraint = notNull keyConstraint
specifyNotNullValue :: KeyConstraint NotNull a
specifyNotNullValue = specifyKeyConstraint 0

View File

@ -27,7 +27,9 @@ module Database.Record.TH (
defineRecordWithSqlTypeDefaultFromDefined,
defineRecord,
defineRecordDefault
defineRecordDefault,
deriveNotNullValue
) where
@ -54,9 +56,10 @@ import Database.Record
ToSql(recordToSql), recordToSql')
import Database.Record.KeyConstraint
(specifyKeyConstraint)
(specifyKeyConstraint, specifyNotNullValue)
import Database.Record.Persistable
(persistableRecord, persistableRecordWidth)
import qualified Database.Record.Persistable as Persistable
defineHasKeyConstraintInstance :: TypeQ -> TypeQ -> Int -> Q [Dec]
@ -264,3 +267,13 @@ defineRecordDefault sqlValueType table columns derives = do
typ <- defineRecordTypeDefault table columns derives
withSql <- defineRecordWithSqlTypeDefault sqlValueType table columns
return $ typ : withSql
deriveNotNullValue :: TypeQ -> Q [Dec]
deriveNotNullValue typeCon =
[d| instance PersistableWidth $typeCon where
persistableWidth = Persistable.valueWidth
instance HasKeyConstraint NotNull $typeCon where
keyConstraint = specifyNotNullValue
|]

View File

@ -16,9 +16,9 @@ import Language.Haskell.TH.Name.Extra (compileError)
import Data.Convertible (Convertible)
import Database.HDBC (SqlValue)
import Database.HDBC.SqlValueExtra ()
import Database.Record (PersistableWidth(persistableWidth))
import Database.Record (PersistableWidth)
import Database.Record.TH (deriveNotNullValue)
import Database.Record.Instances ()
import qualified Database.Record.Persistable as Persistable
import Database.HDBC.Record.TH (derivePersistableInstanceFromValue)
@ -80,12 +80,6 @@ persistableWidthValues = cvInfo >>= d0 where
d1 decl = unknownDeclaration $ show decl
d0 cls = unknownDeclaration $ show cls
derivePersistableWidth :: Q Type -> Q [Dec]
derivePersistableWidth typ =
[d| instance PersistableWidth $(typ) where
persistableWidth = Persistable.valueWidth
|]
mapInstanceD :: (Q Type -> Q [Dec]) -> [Type] -> Q [Dec]
mapInstanceD fD = fmap concat . mapM (fD . return)
@ -93,6 +87,6 @@ derivePersistableInstancesFromConvertibleSqlValues :: Q [Dec]
derivePersistableInstancesFromConvertibleSqlValues = do
wds <- persistableWidthValues
svs <- convertibleSqlValues
ws <- mapInstanceD derivePersistableWidth (toList $ Set.difference svs wds)
ws <- mapInstanceD deriveNotNullValue (toList $ Set.difference svs wds)
ps <- mapInstanceD derivePersistableInstanceFromValue (toList svs)
return $ ws ++ ps