Add simple TConSet implementation.

This commit is contained in:
Kei Hibino 2013-05-15 14:50:16 +09:00
parent 0681a2ccbf
commit 83651702ac

View File

@ -6,7 +6,8 @@ module Database.HDBC.Record.InternalTH (
) where
import Data.Maybe (catMaybes)
import Data.List (intersect, find)
import Data.Set (Set)
import qualified Data.Set as Set
import Language.Haskell.TH
(Q, Dec (InstanceD), Type(AppT, ConT),
@ -22,6 +23,24 @@ import qualified Database.Record.Persistable as Persistable
import Database.HDBC.Record.TH (derivePersistableInstanceFromValue)
newtype TypeCon = TypeCon { unTypeCon :: Type } deriving Eq
instance Ord TypeCon where
TypeCon (ConT an) `compare` TypeCon (ConT bn) = an `compare` bn
TypeCon (ConT _) `compare` TypeCon _ = LT
TypeCon _ `compare` TypeCon (ConT _) = GT
a `compare` b | a == b = EQ
| otherwise = EQ
type TConSet = Set TypeCon
fromList :: [Type] -> TConSet
fromList = Set.fromList . map TypeCon
toList :: TConSet -> [Type]
toList = map unTypeCon . Set.toList
sqlValueType :: Q Type
sqlValueType = [t| SqlValue |]
@ -43,20 +62,20 @@ convertibleSqlValues' = cvInfo >>= d0 where
= unknownDeclaration $ show decl
d0 cls = unknownDeclaration $ show cls
convertibleSqlValues :: Q [Type]
convertibleSqlValues :: Q TConSet
convertibleSqlValues = do
qvt <- sqlValueType
vs <- convertibleSqlValues'
let from = map snd . filter ((== qvt) . fst) $ vs
to = map fst . filter ((== qvt) . snd) $ vs
return $ intersect from to
let from = fromList . map snd . filter ((== qvt) . fst) $ vs
to = fromList . map fst . filter ((== qvt) . snd) $ vs
return $ Set.intersection from to
persistableWidthValues :: Q [Type]
persistableWidthValues :: Q TConSet
persistableWidthValues = cvInfo >>= d0 where
cvInfo = reify ''PersistableWidth
unknownDeclaration = compileError
. ("persistableWidthValues: Unknown declaration pattern: " ++)
d0 (ClassI _ is) = sequence . map d1 $ is where
d0 (ClassI _ is) = fmap fromList . sequence . map d1 $ is where
d1 (InstanceD _cxt (AppT (ConT _n) a) _ds) = return a
d1 decl = unknownDeclaration $ show decl
d0 cls = unknownDeclaration $ show cls
@ -72,13 +91,8 @@ mapInstanceD fD = fmap concat . mapM (fD . return)
derivePersistableInstancesFromConvertibleSqlValues :: Q [Dec]
derivePersistableInstancesFromConvertibleSqlValues = do
ds <- persistableWidthValues
ts <- convertibleSqlValues
let defineNotDefined qt = do
t <- qt
case find (== t) ds of
Nothing -> derivePersistableWidth qt
Just _ -> return []
ws <- mapInstanceD defineNotDefined ts
ps <- mapInstanceD derivePersistableInstanceFromValue ts
wds <- persistableWidthValues
svs <- convertibleSqlValues
ws <- mapInstanceD derivePersistableWidth (toList $ Set.difference svs wds)
ps <- mapInstanceD derivePersistableInstanceFromValue (toList svs)
return $ ws ++ ps