Use template Haskell to create default weights

This commit is contained in:
Maciej Bendkowski 2022-03-23 19:50:02 +01:00
parent 11dcdc4920
commit 024f008c2e
4 changed files with 41 additions and 20 deletions

View File

@ -5,8 +5,8 @@ module Tree (Tree (..), randomTreeListIO) where
import Control.Monad (replicateM)
import Data.Boltzmann.BitOracle (evalIO)
import Data.Boltzmann.Sampler (BoltzmannSampler (..), rejectionSampler')
import Data.Boltzmann.Sampler.TH (mkDefWeights)
import Data.Boltzmann.System (
ConstructorWeights (MkConstructorWeights),
System (..),
)
import Data.Boltzmann.System.TH (mkBoltzmannSampler)
@ -21,10 +21,7 @@ mkBoltzmannSampler
{ targetType = ''Tree
, meanSize = 1000
, frequencies = def
, weights =
MkConstructorWeights
[ ('T, 1)
]
, weights = $(mkDefWeights ''Tree)
}
randomTreeListIO :: Int -> IO [Tree]
@ -39,8 +36,5 @@ mkBoltzmannSampler
{ targetType = ''Tree'
, meanSize = 2000
, frequencies = def
, weights =
MkConstructorWeights
[ ('T, 1)
]
, weights = $(mkDefWeights ''Tree')
}

View File

@ -11,6 +11,7 @@ module Data.Boltzmann.Sampler.TH (
SamplerCtx (..),
mkSystemCtx,
targetTypeSynonym,
mkDefWeights,
) where
import Data.Coerce (coerce)
@ -21,22 +22,25 @@ import qualified Data.Map.Strict as Map
import Control.Monad (forM)
import Data.Boltzmann.Distribution (Distribution)
import Data.Boltzmann.System (
ConstructorWeights (unConstructorWeights),
ConstructorWeights (MkConstructorWeights, unConstructorWeights),
Distributions (Distributions, listTypeDdgs, regTypeDdgs),
System (targetType, weights),
Types (Types, regTypes),
collectTypes,
collectTypes',
hasProperConstructors,
hasProperFrequencies,
paganiniSpecIO,
)
import Language.Haskell.TH (Q, runIO)
import Language.Haskell.TH (Exp, Q, runIO)
import Language.Haskell.TH.Datatype (
ConstructorInfo (constructorFields),
ConstructorInfo (constructorFields, constructorName),
DatatypeInfo (datatypeCons, datatypeVariant),
DatatypeVariant (Datatype, Newtype),
reifyDatatype,
)
import qualified Language.Haskell.TH.Lift as Lift
import Language.Haskell.TH.Syntax (
Bang (Bang),
Con (NormalC),
@ -119,10 +123,10 @@ data SamplerCtx a = SamplerCtx
, typeDeclarations :: [Dec]
}
targetTypeSynonym :: System -> DatatypeInfo -> Q Synonym
targetTypeSynonym sys info = do
targetTypeSynonym :: Name -> DatatypeInfo -> Q Synonym
targetTypeSynonym targetType info = do
case datatypeVariant info of
Datatype -> pure $ MkSynonym (targetType sys)
Datatype -> pure $ MkSynonym targetType
Newtype ->
case datatypeCons info of
[consInfo] ->
@ -174,7 +178,7 @@ mkSystemCtx sys = do
let target = targetType sys
info <- reifyDatatype target
targetSyn <- targetTypeSynonym sys info
targetSyn <- targetTypeSynonym target info
let sys' = sys {targetType = coerce targetSyn}
hasProperConstructors sys'
@ -205,3 +209,14 @@ mkSystemCtx sys = do
, constructorWeight = mkWeightResolver sys
, typeDeclarations = decs
}
mkDefWeights :: Name -> Q Exp
mkDefWeights targetType = do
info <- reifyDatatype targetType
targetSyn <- targetTypeSynonym targetType info
types <- collectTypes' (coerce targetSyn)
let infos = Map.elems $ regTypes types
names = concatMap (map constructorName . datatypeCons) infos
Lift.lift (MkConstructorWeights $ names `zip` repeat (1 :: Int))

View File

@ -1,9 +1,12 @@
{-# LANGUAGE TemplateHaskell #-}
module Data.Boltzmann.System (
Types (..),
Distributions (..),
ConstructorWeights (..),
ConstructorFrequencies (..),
collectTypes,
collectTypes',
System (..),
getWeight,
paganiniSpecIO,
@ -11,7 +14,10 @@ module Data.Boltzmann.System (
hasProperFrequencies,
) where
import Language.Haskell.TH.Syntax (Name, Type (AppT, ConT, ListT))
import Language.Haskell.TH.Syntax (
Name,
Type (AppT, ConT, ListT),
)
import Control.Monad (foldM, forM, replicateM, unless)
import Data.Boltzmann.Distribution (Distribution (Distribution))
@ -42,6 +48,7 @@ import Language.Haskell.TH.Datatype (
DatatypeInfo (datatypeCons, datatypeName),
reifyDatatype,
)
import qualified Language.Haskell.TH.Lift as Lift
import Data.Coerce (coerce)
import Data.Default (Default (def))
@ -51,6 +58,8 @@ newtype ConstructorWeights = MkConstructorWeights
{unConstructorWeights :: [(Name, Int)]}
deriving (Show) via [(Name, Int)]
Lift.deriveLift ''ConstructorWeights
newtype ConstructorFrequencies = MkConstructorFrequencies
{unConstructorFrequencies :: [(Name, Int)]}
deriving (Show) via [(Name, Int)]
@ -80,8 +89,11 @@ initTypes :: Types
initTypes = Types Map.empty Set.empty
collectTypes :: System -> Q Types
collectTypes sys = do
info <- reifyDatatype $ targetType sys
collectTypes = collectTypes' . targetType
collectTypes' :: Name -> Q Types
collectTypes' targetType = do
info <- reifyDatatype targetType
collectFromDataTypeInfo initTypes info
collectFromDataTypeInfo ::

View File

@ -271,7 +271,7 @@ mkBoltzmannSampler sys = do
let target = targetType sys
info <- reifyDatatype target
targetSyn <- targetTypeSynonym sys info
targetSyn <- targetTypeSynonym target info
let sys' = sys {targetType = coerce targetSyn}
runReaderT (mkBoltzmannSampler' sys') ctx