Check that systems have proper weights and frequencies

This commit is contained in:
Maciej Bendkowski 2022-03-22 21:09:09 +01:00
parent 7adf382942
commit 901693b740
3 changed files with 53 additions and 2 deletions

View File

@ -6,8 +6,8 @@ import Control.Monad (replicateM)
import Data.Boltzmann.Sampler (BoltzmannSampler (..), rejectionSampler')
import Data.Boltzmann.System (System (..))
import Data.Boltzmann.System.TH (mkBoltzmannSampler)
import Data.Boltzmann.BitOracle (evalIO)
import Data.Boltzmann.System.TH (mkBoltzmannSampler)
import System.Random.SplitMix (SMGen)
data DeBruijn

View File

@ -25,6 +25,8 @@ import Data.Boltzmann.System (
System (targetType, weights),
Types (Types, regTypes),
collectTypes,
hasProperConstructors,
hasProperFrequencies,
paganiniSpecIO,
)
import Language.Haskell.TH (Q, runIO)
@ -174,6 +176,9 @@ mkSystemCtx sys = do
targetSyn <- targetTypeSynonym sys info
let sys' = sys {targetType = coerce targetSyn}
hasProperConstructors sys'
hasProperFrequencies sys'
types <- collectTypes sys'
distributions <- runIO $ do
spec <- paganiniSpecIO sys' types

View File

@ -5,11 +5,13 @@ module Data.Boltzmann.System (
System (..),
getWeight,
paganiniSpecIO,
hasProperConstructors,
hasProperFrequencies,
) where
import Language.Haskell.TH.Syntax (Name, Type (AppT, ConT, ListT))
import Control.Monad (foldM, forM, replicateM)
import Control.Monad (foldM, forM, replicateM, unless)
import Data.Boltzmann.Distribution (Distribution (Distribution))
import qualified Data.Map as Map
import Data.Map.Strict (Map)
@ -100,6 +102,50 @@ collectFromType types typ =
collectFromDataTypeInfo types' info
_ -> fail $ "Unsupported type " ++ show typ
format :: Show a => Set a -> String
format = formatList . Set.toList . Set.map show
formatList :: Show a => [a] -> String
formatList = \case
[] -> ""
[a] -> show a
a : xs@(_ : _) -> show a ++ ", " ++ formatList xs
constructors :: System -> Q (Set Name)
constructors sys = do
types <- collectTypes sys
let infos = Map.elems $ regTypes types
names = concatMap (map constructorName . datatypeCons) infos
pure $ Set.fromList names
hasProperConstructors :: System -> Q ()
hasProperConstructors sys = do
sysConstrs <- constructors sys
let weightConstrs = map fst (weights sys)
missingConstrs = sysConstrs `Set.difference` Set.fromList weightConstrs
additionalConstrs = Set.fromList weightConstrs `Set.difference` sysConstrs
unless (Set.null missingConstrs) $ do
fail $
"Missing weight for constructors: "
++ format missingConstrs
unless (Set.null additionalConstrs) $ do
fail $
"Weight definied for non-system constructors: "
++ format additionalConstrs
hasProperFrequencies :: System -> Q ()
hasProperFrequencies sys = do
sysConstrs <- constructors sys
let freqConstrs = map fst (frequencies sys)
additionalConstrs = Set.fromList freqConstrs `Set.difference` sysConstrs
unless (Set.null additionalConstrs) $ do
fail $
"Frequencies definied for non-system constructors: "
++ format additionalConstrs
mkVariables :: Set Name -> Spec (Map Name Let)
mkVariables sys = do
let n = Set.size sys