Use Template Haskell to define default weights.

This commit is contained in:
Maciej Bendkowski 2022-03-25 19:36:39 +01:00
parent df72dc4525
commit b0c2e4a7af
4 changed files with 49 additions and 32 deletions

View File

@ -5,8 +5,9 @@ module BinTree (BinTree (..), randomBinTreeListIO) where
import Control.Monad (replicateM)
import Data.Boltzmann.BitOracle (evalIO)
import Data.Boltzmann.Sampler (BoltzmannSampler (..), rejectionSampler')
import Data.Boltzmann.Sampler.TH (mkDefWeights)
import Data.Boltzmann.System (
ConstructorWeights (MkConstructorWeights),
Constructable ((<:>)),
System (..),
)
import Data.Boltzmann.System.TH (mkBoltzmannSampler)
@ -24,10 +25,8 @@ mkBoltzmannSampler
, meanSize = 1000
, frequencies = def
, weights =
MkConstructorWeights
[ ('Leaf, 0)
, ('Node, 1)
]
('Leaf, 0)
<:> $(mkDefWeights ''BinTree)
}
randomBinTreeListIO :: Int -> IO [BinTree]

View File

@ -5,9 +5,9 @@ module Lambda (Lambda (..), randomLambdaListIO) where
import Control.Monad (replicateM)
import Data.Boltzmann.BitOracle (evalIO)
import Data.Boltzmann.Sampler (BoltzmannSampler (..), rejectionSampler')
import Data.Boltzmann.Sampler.TH (mkDefWeights)
import Data.Boltzmann.System (
ConstructorFrequencies (MkConstructorFrequencies),
ConstructorWeights (MkConstructorWeights),
Constructable ((<:>)),
System (..),
)
import Data.Boltzmann.System.TH (mkBoltzmannSampler)
@ -31,15 +31,8 @@ mkBoltzmannSampler
, meanSize = 10_000
, frequencies = def
, weights =
MkConstructorWeights
[ -- De Bruijn
('S, 1)
, ('Z, 1)
, -- Lambda
('Index, 0)
, ('App, 1)
, ('Abs, 1)
]
('Index, 0)
<:> $(mkDefWeights ''Lambda)
}
newtype BinLambda = MkBinLambda Lambda
@ -49,18 +42,16 @@ mkBoltzmannSampler
System
{ targetType = ''BinLambda
, meanSize = 10_000
, frequencies = MkConstructorFrequencies [('Abs, 4500)]
, frequencies =
('Abs, 4500) <:> def
, weights =
MkConstructorWeights
[ -- De Bruijn
('S, 1)
, ('Z, 1)
, -- Lambda
('Index, 0)
, ('App, 2)
, ('Abs, 2)
]
('Index, 0)
<:> ('App, 2)
<:> ('Abs, 2)
<:> $(mkDefWeights ''Lambda)
}
randomLambdaListIO :: Int -> IO [BinLambda]
randomLambdaListIO n = evalIO $ replicateM n (rejectionSampler' @SMGen 10_000 0.2)
randomLambdaListIO n =
evalIO $
replicateM n (rejectionSampler' @SMGen 10_000 0.2)

View File

@ -5,12 +5,7 @@ module Tree (Tree (..), randomTreeListIO) where
import Control.Monad (replicateM)
import Data.Boltzmann.BitOracle (evalIO)
import Data.Boltzmann.Sampler (BoltzmannSampler (..), rejectionSampler')
import Data.Boltzmann.Sampler.TH (mkDefWeights)
import Data.Boltzmann.System (
System (..),
)
import Data.Boltzmann.System.TH (mkDefBoltzmannSampler)
import Data.Default (Default (def))
import System.Random.SplitMix (SMGen)
data Tree = T [Tree]

View File

@ -12,6 +12,7 @@ module Data.Boltzmann.System (
paganiniSpecIO,
hasProperConstructors,
hasProperFrequencies,
Constructable (..),
) where
import Language.Haskell.TH.Syntax (
@ -54,10 +55,28 @@ import Data.Coerce (coerce)
import Data.Default (Default (def))
import Prelude hiding (seq)
class Constructable a where
(<:>) :: (Name, Int) -> a -> a
infixr 6 <:>
newtype ConstructorWeights = MkConstructorWeights
{unConstructorWeights :: [(Name, Int)]}
deriving (Show) via [(Name, Int)]
instance Semigroup ConstructorWeights where
-- left-biased union
xs <> ys = MkConstructorWeights (Map.toList $ xs' <> ys')
where
xs' = Map.fromList (unConstructorWeights xs)
ys' = Map.fromList (unConstructorWeights ys)
instance Monoid ConstructorWeights where
mempty = MkConstructorWeights []
instance Constructable ConstructorWeights where
x <:> xs = MkConstructorWeights [x] <> xs
Lift.deriveLift ''ConstructorWeights
newtype ConstructorFrequencies = MkConstructorFrequencies
@ -67,6 +86,19 @@ newtype ConstructorFrequencies = MkConstructorFrequencies
instance Default ConstructorFrequencies where
def = MkConstructorFrequencies []
instance Semigroup ConstructorFrequencies where
-- left-biased union
xs <> ys = MkConstructorFrequencies (Map.toList $ xs' <> ys')
where
xs' = Map.fromList (unConstructorFrequencies xs)
ys' = Map.fromList (unConstructorFrequencies ys)
instance Monoid ConstructorFrequencies where
mempty = def
instance Constructable ConstructorFrequencies where
x <:> xs = MkConstructorFrequencies [x] <> xs
data System = System
{ targetType :: Name
, meanSize :: Int