Constructor weights and frequencies

This commit is contained in:
Maciej Bendkowski 2022-03-22 21:29:12 +01:00
parent 901693b740
commit 72c8ec1508
5 changed files with 72 additions and 44 deletions

View File

@ -3,10 +3,14 @@
module BinTree (BinTree (..), randomBinTreeListIO) where
import Control.Monad (replicateM)
import Data.Boltzmann.Sampler (BoltzmannSampler (..), rejectionSampler')
import Data.Boltzmann.System (System (..))
import Data.Boltzmann.System.TH (mkBoltzmannSampler)
import Data.Boltzmann.BitOracle (evalIO)
import Data.Boltzmann.Sampler (BoltzmannSampler (..), rejectionSampler')
import Data.Boltzmann.System (
ConstructorFrequencies (MkConstructorFrequencies),
ConstructorWeights (MkConstructorWeights),
System (..),
)
import Data.Boltzmann.System.TH (mkBoltzmannSampler)
import System.Random.SplitMix (SMGen)
data BinTree
@ -18,11 +22,12 @@ mkBoltzmannSampler
System
{ targetType = ''BinTree
, meanSize = 1000
, frequencies = []
, frequencies = MkConstructorFrequencies []
, weights =
[ ('Leaf, 0)
, ('Node, 1)
]
MkConstructorWeights
[ ('Leaf, 0)
, ('Node, 1)
]
}
randomBinTreeListIO :: Int -> IO [BinTree]

View File

@ -3,10 +3,13 @@
module Lambda (Lambda (..), randomLambdaListIO) where
import Control.Monad (replicateM)
import Data.Boltzmann.Sampler (BoltzmannSampler (..), rejectionSampler')
import Data.Boltzmann.System (System (..))
import Data.Boltzmann.BitOracle (evalIO)
import Data.Boltzmann.Sampler (BoltzmannSampler (..), rejectionSampler')
import Data.Boltzmann.System (
ConstructorFrequencies (MkConstructorFrequencies),
ConstructorWeights (MkConstructorWeights),
System (..),
)
import Data.Boltzmann.System.TH (mkBoltzmannSampler)
import System.Random.SplitMix (SMGen)
@ -25,16 +28,17 @@ mkBoltzmannSampler
System
{ targetType = ''Lambda
, meanSize = 10_000
, frequencies = []
, frequencies = MkConstructorFrequencies []
, weights =
[ -- De Bruijn
('S, 1)
, ('Z, 1)
, -- Lambda
('Index, 0)
, ('App, 1)
, ('Abs, 1)
]
MkConstructorWeights
[ -- De Bruijn
('S, 1)
, ('Z, 1)
, -- Lambda
('Index, 0)
, ('App, 1)
, ('Abs, 1)
]
}
newtype BinLambda = MkBinLambda Lambda
@ -44,16 +48,17 @@ mkBoltzmannSampler
System
{ targetType = ''BinLambda
, meanSize = 10_000
, frequencies = [('Abs, 4500)]
, frequencies = MkConstructorFrequencies [('Abs, 4500)]
, weights =
[ -- De Bruijn
('S, 1)
, ('Z, 1)
, -- Lambda
('Index, 0)
, ('App, 2)
, ('Abs, 2)
]
MkConstructorWeights
[ -- De Bruijn
('S, 1)
, ('Z, 1)
, -- Lambda
('Index, 0)
, ('App, 2)
, ('Abs, 2)
]
}
randomLambdaListIO :: Int -> IO [BinLambda]

View File

@ -3,10 +3,14 @@
module Tree (Tree (..), randomTreeListIO) where
import Control.Monad (replicateM)
import Data.Boltzmann.Sampler (BoltzmannSampler (..), rejectionSampler')
import Data.Boltzmann.System (System (..))
import Data.Boltzmann.System.TH (mkBoltzmannSampler)
import Data.Boltzmann.BitOracle (evalIO)
import Data.Boltzmann.Sampler (BoltzmannSampler (..), rejectionSampler')
import Data.Boltzmann.System (
ConstructorFrequencies (MkConstructorFrequencies),
ConstructorWeights (MkConstructorWeights),
System (..),
)
import Data.Boltzmann.System.TH (mkBoltzmannSampler)
import System.Random.SplitMix (SMGen)
data Tree = T [Tree]
@ -16,10 +20,11 @@ mkBoltzmannSampler
System
{ targetType = ''Tree
, meanSize = 1000
, frequencies = []
, frequencies = MkConstructorFrequencies []
, weights =
[ ('T, 1)
]
MkConstructorWeights
[ ('T, 1)
]
}
randomTreeListIO :: Int -> IO [Tree]
@ -33,8 +38,9 @@ mkBoltzmannSampler
System
{ targetType = ''Tree'
, meanSize = 2000
, frequencies = []
, frequencies = MkConstructorFrequencies []
, weights =
[ ('T, 1)
]
MkConstructorWeights
[ ('T, 1)
]
}

View File

@ -21,6 +21,7 @@ import qualified Data.Map.Strict as Map
import Control.Monad (forM)
import Data.Boltzmann.Distribution (Distribution)
import Data.Boltzmann.System (
ConstructorWeights (unConstructorWeights),
Distributions (Distributions, listTypeDdgs, regTypeDdgs),
System (targetType, weights),
Types (Types, regTypes),
@ -106,7 +107,7 @@ newtype WeightResolver = MkWeightResolver
mkWeightResolver :: System -> WeightResolver
mkWeightResolver sys =
MkWeightResolver $ \n ->
case coerce n `lookup` weights sys of
case coerce n `lookup` unConstructorWeights (weights sys) of
Just w -> pure w
Nothing -> fail $ "Missing constructor weight for " ++ show n

View File

@ -1,6 +1,8 @@
module Data.Boltzmann.System (
Types (..),
Distributions (..),
ConstructorWeights (..),
ConstructorFrequencies (..),
collectTypes,
System (..),
getWeight,
@ -41,19 +43,28 @@ import Language.Haskell.TH.Datatype (
reifyDatatype,
)
import Data.Coerce (coerce)
import Prelude hiding (seq)
newtype ConstructorWeights = MkConstructorWeights
{unConstructorWeights :: [(Name, Int)]}
deriving (Show) via [(Name, Int)]
newtype ConstructorFrequencies = MkConstructorFrequencies
{unConstructorFrequencies :: [(Name, Int)]}
deriving (Show) via [(Name, Int)]
data System = System
{ targetType :: Name
, meanSize :: Int
, weights :: [(Name, Int)]
, frequencies :: [(Name, Int)]
, weights :: ConstructorWeights
, frequencies :: ConstructorFrequencies
}
deriving (Show)
getWeight :: System -> Name -> Int
getWeight sys name =
fromMaybe 1 $ lookup name (weights sys)
fromMaybe 1 $ lookup name (coerce $ weights sys)
data Types = Types
{ regTypes :: Map Name DatatypeInfo
@ -121,7 +132,7 @@ constructors sys = do
hasProperConstructors :: System -> Q ()
hasProperConstructors sys = do
sysConstrs <- constructors sys
let weightConstrs = map fst (weights sys)
let weightConstrs = map fst (unConstructorWeights $ weights sys)
missingConstrs = sysConstrs `Set.difference` Set.fromList weightConstrs
additionalConstrs = Set.fromList weightConstrs `Set.difference` sysConstrs
@ -138,7 +149,7 @@ hasProperConstructors sys = do
hasProperFrequencies :: System -> Q ()
hasProperFrequencies sys = do
sysConstrs <- constructors sys
let freqConstrs = map fst (frequencies sys)
let freqConstrs = map fst (unConstructorFrequencies $ frequencies sys)
additionalConstrs = Set.fromList freqConstrs `Set.difference` sysConstrs
unless (Set.null additionalConstrs) $ do
@ -170,7 +181,7 @@ mkMarkingVariables sys = do
x <- variable' $ fromIntegral freq
pure (cons, x)
)
(frequencies sys)
(unConstructorFrequencies $ frequencies sys)
pure $ Map.fromList xs