Default Boltzmann sampler construction

This commit is contained in:
Maciej Bendkowski 2022-03-24 20:17:25 +01:00
parent 024f008c2e
commit df72dc4525
3 changed files with 40 additions and 21 deletions

View File

@ -9,20 +9,14 @@ import Data.Boltzmann.Sampler.TH (mkDefWeights)
import Data.Boltzmann.System (
System (..),
)
import Data.Boltzmann.System.TH (mkBoltzmannSampler)
import Data.Boltzmann.System.TH (mkDefBoltzmannSampler)
import Data.Default (Default (def))
import System.Random.SplitMix (SMGen)
data Tree = T [Tree]
deriving (Show)
mkBoltzmannSampler
System
{ targetType = ''Tree
, meanSize = 1000
, frequencies = def
, weights = $(mkDefWeights ''Tree)
}
mkDefBoltzmannSampler ''Tree 100
randomTreeListIO :: Int -> IO [Tree]
randomTreeListIO n =
@ -31,10 +25,4 @@ randomTreeListIO n =
newtype Tree' = MkTree' Tree
deriving (Show)
mkBoltzmannSampler
System
{ targetType = ''Tree'
, meanSize = 2000
, frequencies = def
, weights = $(mkDefWeights ''Tree')
}
mkDefBoltzmannSampler ''Tree' 2000

View File

@ -12,6 +12,7 @@ module Data.Boltzmann.Sampler.TH (
mkSystemCtx,
targetTypeSynonym,
mkDefWeights,
mkDefWeights',
) where
import Data.Coerce (coerce)
@ -210,8 +211,8 @@ mkSystemCtx sys = do
, typeDeclarations = decs
}
mkDefWeights :: Name -> Q Exp
mkDefWeights targetType = do
mkDefWeights' :: Name -> Q ConstructorWeights
mkDefWeights' targetType = do
info <- reifyDatatype targetType
targetSyn <- targetTypeSynonym targetType info
@ -219,4 +220,8 @@ mkDefWeights targetType = do
let infos = Map.elems $ regTypes types
names = concatMap (map constructorName . datatypeCons) infos
Lift.lift (MkConstructorWeights $ names `zip` repeat (1 :: Int))
pure $ MkConstructorWeights $ names `zip` repeat (1 :: Int)
mkDefWeights :: Name -> Q Exp
mkDefWeights targetType =
mkDefWeights' targetType >>= Lift.lift

View File

@ -1,7 +1,10 @@
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE TemplateHaskell #-}
module Data.Boltzmann.System.TH (mkBoltzmannSampler) where
module Data.Boltzmann.System.TH (
mkBoltzmannSampler,
mkDefBoltzmannSampler,
) where
import qualified Data.Map.Strict as Map
import qualified Data.Set as Set
@ -25,13 +28,25 @@ import Data.Boltzmann.Sampler.TH (
TypeDistributions (unTypeDistributions),
TypeName (MkTypeName),
WeightResolver (unWeightResolver),
mkDefWeights',
mkSystemCtx,
targetTypeSynonym,
)
import Data.Boltzmann.System (System (targetType), Types (Types), collectTypes)
import Data.Boltzmann.System (
System (
System,
frequencies,
meanSize,
targetType,
weights
),
Types (Types),
collectTypes,
)
import Data.Coerce (coerce)
import Data.Default (def)
import Data.Functor ((<&>))
import Language.Haskell.TH (Exp (LamCaseE), Q, Type (ListT))
import Language.Haskell.TH (Exp (LamCaseE), Name, Q, Type (ListT))
import Language.Haskell.TH.Datatype (
ConstructorInfo (constructorFields, constructorName),
DatatypeInfo (datatypeCons),
@ -275,3 +290,14 @@ mkBoltzmannSampler sys = do
let sys' = sys {targetType = coerce targetSyn}
runReaderT (mkBoltzmannSampler' sys') ctx
mkDefBoltzmannSampler :: Name -> Int -> Q [Dec]
mkDefBoltzmannSampler typ meanSize = do
defWeights <- mkDefWeights' typ
mkBoltzmannSampler $
System
{ targetType = typ
, meanSize = meanSize
, frequencies = def
, weights = defWeights
}