From 024f008c2eaec15105b0e3eb760f8e8f9ee7c6d3 Mon Sep 17 00:00:00 2001 From: Maciej Bendkowski Date: Wed, 23 Mar 2022 19:50:02 +0100 Subject: [PATCH] Use template Haskell to create default weights --- profile/Tree/Tree.hs | 12 +++--------- src/Data/Boltzmann/Sampler/TH.hs | 29 ++++++++++++++++++++++------- src/Data/Boltzmann/System.hs | 18 +++++++++++++++--- src/Data/Boltzmann/System/TH.hs | 2 +- 4 files changed, 41 insertions(+), 20 deletions(-) diff --git a/profile/Tree/Tree.hs b/profile/Tree/Tree.hs index aba7f5e..f85ff99 100644 --- a/profile/Tree/Tree.hs +++ b/profile/Tree/Tree.hs @@ -5,8 +5,8 @@ module Tree (Tree (..), randomTreeListIO) where import Control.Monad (replicateM) import Data.Boltzmann.BitOracle (evalIO) import Data.Boltzmann.Sampler (BoltzmannSampler (..), rejectionSampler') +import Data.Boltzmann.Sampler.TH (mkDefWeights) import Data.Boltzmann.System ( - ConstructorWeights (MkConstructorWeights), System (..), ) import Data.Boltzmann.System.TH (mkBoltzmannSampler) @@ -21,10 +21,7 @@ mkBoltzmannSampler { targetType = ''Tree , meanSize = 1000 , frequencies = def - , weights = - MkConstructorWeights - [ ('T, 1) - ] + , weights = $(mkDefWeights ''Tree) } randomTreeListIO :: Int -> IO [Tree] @@ -39,8 +36,5 @@ mkBoltzmannSampler { targetType = ''Tree' , meanSize = 2000 , frequencies = def - , weights = - MkConstructorWeights - [ ('T, 1) - ] + , weights = $(mkDefWeights ''Tree') } diff --git a/src/Data/Boltzmann/Sampler/TH.hs b/src/Data/Boltzmann/Sampler/TH.hs index 2e7aaf2..a4b8621 100644 --- a/src/Data/Boltzmann/Sampler/TH.hs +++ b/src/Data/Boltzmann/Sampler/TH.hs @@ -11,6 +11,7 @@ module Data.Boltzmann.Sampler.TH ( SamplerCtx (..), mkSystemCtx, targetTypeSynonym, + mkDefWeights, ) where import Data.Coerce (coerce) @@ -21,22 +22,25 @@ import qualified Data.Map.Strict as Map import Control.Monad (forM) import Data.Boltzmann.Distribution (Distribution) import Data.Boltzmann.System ( - ConstructorWeights (unConstructorWeights), + ConstructorWeights (MkConstructorWeights, unConstructorWeights), Distributions (Distributions, listTypeDdgs, regTypeDdgs), System (targetType, weights), Types (Types, regTypes), collectTypes, + collectTypes', hasProperConstructors, hasProperFrequencies, paganiniSpecIO, ) -import Language.Haskell.TH (Q, runIO) + +import Language.Haskell.TH (Exp, Q, runIO) import Language.Haskell.TH.Datatype ( - ConstructorInfo (constructorFields), + ConstructorInfo (constructorFields, constructorName), DatatypeInfo (datatypeCons, datatypeVariant), DatatypeVariant (Datatype, Newtype), reifyDatatype, ) +import qualified Language.Haskell.TH.Lift as Lift import Language.Haskell.TH.Syntax ( Bang (Bang), Con (NormalC), @@ -119,10 +123,10 @@ data SamplerCtx a = SamplerCtx , typeDeclarations :: [Dec] } -targetTypeSynonym :: System -> DatatypeInfo -> Q Synonym -targetTypeSynonym sys info = do +targetTypeSynonym :: Name -> DatatypeInfo -> Q Synonym +targetTypeSynonym targetType info = do case datatypeVariant info of - Datatype -> pure $ MkSynonym (targetType sys) + Datatype -> pure $ MkSynonym targetType Newtype -> case datatypeCons info of [consInfo] -> @@ -174,7 +178,7 @@ mkSystemCtx sys = do let target = targetType sys info <- reifyDatatype target - targetSyn <- targetTypeSynonym sys info + targetSyn <- targetTypeSynonym target info let sys' = sys {targetType = coerce targetSyn} hasProperConstructors sys' @@ -205,3 +209,14 @@ mkSystemCtx sys = do , constructorWeight = mkWeightResolver sys , typeDeclarations = decs } + +mkDefWeights :: Name -> Q Exp +mkDefWeights targetType = do + info <- reifyDatatype targetType + targetSyn <- targetTypeSynonym targetType info + + types <- collectTypes' (coerce targetSyn) + let infos = Map.elems $ regTypes types + names = concatMap (map constructorName . datatypeCons) infos + + Lift.lift (MkConstructorWeights $ names `zip` repeat (1 :: Int)) diff --git a/src/Data/Boltzmann/System.hs b/src/Data/Boltzmann/System.hs index 6d5f59f..2c49ee9 100644 --- a/src/Data/Boltzmann/System.hs +++ b/src/Data/Boltzmann/System.hs @@ -1,9 +1,12 @@ +{-# LANGUAGE TemplateHaskell #-} + module Data.Boltzmann.System ( Types (..), Distributions (..), ConstructorWeights (..), ConstructorFrequencies (..), collectTypes, + collectTypes', System (..), getWeight, paganiniSpecIO, @@ -11,7 +14,10 @@ module Data.Boltzmann.System ( hasProperFrequencies, ) where -import Language.Haskell.TH.Syntax (Name, Type (AppT, ConT, ListT)) +import Language.Haskell.TH.Syntax ( + Name, + Type (AppT, ConT, ListT), + ) import Control.Monad (foldM, forM, replicateM, unless) import Data.Boltzmann.Distribution (Distribution (Distribution)) @@ -42,6 +48,7 @@ import Language.Haskell.TH.Datatype ( DatatypeInfo (datatypeCons, datatypeName), reifyDatatype, ) +import qualified Language.Haskell.TH.Lift as Lift import Data.Coerce (coerce) import Data.Default (Default (def)) @@ -51,6 +58,8 @@ newtype ConstructorWeights = MkConstructorWeights {unConstructorWeights :: [(Name, Int)]} deriving (Show) via [(Name, Int)] +Lift.deriveLift ''ConstructorWeights + newtype ConstructorFrequencies = MkConstructorFrequencies {unConstructorFrequencies :: [(Name, Int)]} deriving (Show) via [(Name, Int)] @@ -80,8 +89,11 @@ initTypes :: Types initTypes = Types Map.empty Set.empty collectTypes :: System -> Q Types -collectTypes sys = do - info <- reifyDatatype $ targetType sys +collectTypes = collectTypes' . targetType + +collectTypes' :: Name -> Q Types +collectTypes' targetType = do + info <- reifyDatatype targetType collectFromDataTypeInfo initTypes info collectFromDataTypeInfo :: diff --git a/src/Data/Boltzmann/System/TH.hs b/src/Data/Boltzmann/System/TH.hs index 47bc861..03f0319 100644 --- a/src/Data/Boltzmann/System/TH.hs +++ b/src/Data/Boltzmann/System/TH.hs @@ -271,7 +271,7 @@ mkBoltzmannSampler sys = do let target = targetType sys info <- reifyDatatype target - targetSyn <- targetTypeSynonym sys info + targetSyn <- targetTypeSynonym target info let sys' = sys {targetType = coerce targetSyn} runReaderT (mkBoltzmannSampler' sys') ctx