Add simple TConSet implementation.

This commit is contained in:
Kei Hibino 2013-05-15 14:50:16 +09:00
parent 0681a2ccbf
commit 83651702ac

View File

@ -6,7 +6,8 @@ module Database.HDBC.Record.InternalTH (
) where ) where
import Data.Maybe (catMaybes) import Data.Maybe (catMaybes)
import Data.List (intersect, find) import Data.Set (Set)
import qualified Data.Set as Set
import Language.Haskell.TH import Language.Haskell.TH
(Q, Dec (InstanceD), Type(AppT, ConT), (Q, Dec (InstanceD), Type(AppT, ConT),
@ -22,6 +23,24 @@ import qualified Database.Record.Persistable as Persistable
import Database.HDBC.Record.TH (derivePersistableInstanceFromValue) import Database.HDBC.Record.TH (derivePersistableInstanceFromValue)
newtype TypeCon = TypeCon { unTypeCon :: Type } deriving Eq
instance Ord TypeCon where
TypeCon (ConT an) `compare` TypeCon (ConT bn) = an `compare` bn
TypeCon (ConT _) `compare` TypeCon _ = LT
TypeCon _ `compare` TypeCon (ConT _) = GT
a `compare` b | a == b = EQ
| otherwise = EQ
type TConSet = Set TypeCon
fromList :: [Type] -> TConSet
fromList = Set.fromList . map TypeCon
toList :: TConSet -> [Type]
toList = map unTypeCon . Set.toList
sqlValueType :: Q Type sqlValueType :: Q Type
sqlValueType = [t| SqlValue |] sqlValueType = [t| SqlValue |]
@ -43,20 +62,20 @@ convertibleSqlValues' = cvInfo >>= d0 where
= unknownDeclaration $ show decl = unknownDeclaration $ show decl
d0 cls = unknownDeclaration $ show cls d0 cls = unknownDeclaration $ show cls
convertibleSqlValues :: Q [Type] convertibleSqlValues :: Q TConSet
convertibleSqlValues = do convertibleSqlValues = do
qvt <- sqlValueType qvt <- sqlValueType
vs <- convertibleSqlValues' vs <- convertibleSqlValues'
let from = map snd . filter ((== qvt) . fst) $ vs let from = fromList . map snd . filter ((== qvt) . fst) $ vs
to = map fst . filter ((== qvt) . snd) $ vs to = fromList . map fst . filter ((== qvt) . snd) $ vs
return $ intersect from to return $ Set.intersection from to
persistableWidthValues :: Q [Type] persistableWidthValues :: Q TConSet
persistableWidthValues = cvInfo >>= d0 where persistableWidthValues = cvInfo >>= d0 where
cvInfo = reify ''PersistableWidth cvInfo = reify ''PersistableWidth
unknownDeclaration = compileError unknownDeclaration = compileError
. ("persistableWidthValues: Unknown declaration pattern: " ++) . ("persistableWidthValues: Unknown declaration pattern: " ++)
d0 (ClassI _ is) = sequence . map d1 $ is where d0 (ClassI _ is) = fmap fromList . sequence . map d1 $ is where
d1 (InstanceD _cxt (AppT (ConT _n) a) _ds) = return a d1 (InstanceD _cxt (AppT (ConT _n) a) _ds) = return a
d1 decl = unknownDeclaration $ show decl d1 decl = unknownDeclaration $ show decl
d0 cls = unknownDeclaration $ show cls d0 cls = unknownDeclaration $ show cls
@ -72,13 +91,8 @@ mapInstanceD fD = fmap concat . mapM (fD . return)
derivePersistableInstancesFromConvertibleSqlValues :: Q [Dec] derivePersistableInstancesFromConvertibleSqlValues :: Q [Dec]
derivePersistableInstancesFromConvertibleSqlValues = do derivePersistableInstancesFromConvertibleSqlValues = do
ds <- persistableWidthValues wds <- persistableWidthValues
ts <- convertibleSqlValues svs <- convertibleSqlValues
let defineNotDefined qt = do ws <- mapInstanceD derivePersistableWidth (toList $ Set.difference svs wds)
t <- qt ps <- mapInstanceD derivePersistableInstanceFromValue (toList svs)
case find (== t) ds of
Nothing -> derivePersistableWidth qt
Just _ -> return []
ws <- mapInstanceD defineNotDefined ts
ps <- mapInstanceD derivePersistableInstanceFromValue ts
return $ ws ++ ps return $ ws ++ ps