mirror of
https://github.com/maciej-bendkowski/generic-boltzmann-brain.git
synced 2024-09-11 12:48:09 +03:00
Check that systems have proper weights and frequencies
This commit is contained in:
parent
7adf382942
commit
901693b740
@ -6,8 +6,8 @@ import Control.Monad (replicateM)
|
||||
import Data.Boltzmann.Sampler (BoltzmannSampler (..), rejectionSampler')
|
||||
import Data.Boltzmann.System (System (..))
|
||||
|
||||
import Data.Boltzmann.System.TH (mkBoltzmannSampler)
|
||||
import Data.Boltzmann.BitOracle (evalIO)
|
||||
import Data.Boltzmann.System.TH (mkBoltzmannSampler)
|
||||
import System.Random.SplitMix (SMGen)
|
||||
|
||||
data DeBruijn
|
||||
|
@ -25,6 +25,8 @@ import Data.Boltzmann.System (
|
||||
System (targetType, weights),
|
||||
Types (Types, regTypes),
|
||||
collectTypes,
|
||||
hasProperConstructors,
|
||||
hasProperFrequencies,
|
||||
paganiniSpecIO,
|
||||
)
|
||||
import Language.Haskell.TH (Q, runIO)
|
||||
@ -174,6 +176,9 @@ mkSystemCtx sys = do
|
||||
targetSyn <- targetTypeSynonym sys info
|
||||
let sys' = sys {targetType = coerce targetSyn}
|
||||
|
||||
hasProperConstructors sys'
|
||||
hasProperFrequencies sys'
|
||||
|
||||
types <- collectTypes sys'
|
||||
distributions <- runIO $ do
|
||||
spec <- paganiniSpecIO sys' types
|
||||
|
@ -5,11 +5,13 @@ module Data.Boltzmann.System (
|
||||
System (..),
|
||||
getWeight,
|
||||
paganiniSpecIO,
|
||||
hasProperConstructors,
|
||||
hasProperFrequencies,
|
||||
) where
|
||||
|
||||
import Language.Haskell.TH.Syntax (Name, Type (AppT, ConT, ListT))
|
||||
|
||||
import Control.Monad (foldM, forM, replicateM)
|
||||
import Control.Monad (foldM, forM, replicateM, unless)
|
||||
import Data.Boltzmann.Distribution (Distribution (Distribution))
|
||||
import qualified Data.Map as Map
|
||||
import Data.Map.Strict (Map)
|
||||
@ -100,6 +102,50 @@ collectFromType types typ =
|
||||
collectFromDataTypeInfo types' info
|
||||
_ -> fail $ "Unsupported type " ++ show typ
|
||||
|
||||
format :: Show a => Set a -> String
|
||||
format = formatList . Set.toList . Set.map show
|
||||
|
||||
formatList :: Show a => [a] -> String
|
||||
formatList = \case
|
||||
[] -> ""
|
||||
[a] -> show a
|
||||
a : xs@(_ : _) -> show a ++ ", " ++ formatList xs
|
||||
|
||||
constructors :: System -> Q (Set Name)
|
||||
constructors sys = do
|
||||
types <- collectTypes sys
|
||||
let infos = Map.elems $ regTypes types
|
||||
names = concatMap (map constructorName . datatypeCons) infos
|
||||
pure $ Set.fromList names
|
||||
|
||||
hasProperConstructors :: System -> Q ()
|
||||
hasProperConstructors sys = do
|
||||
sysConstrs <- constructors sys
|
||||
let weightConstrs = map fst (weights sys)
|
||||
missingConstrs = sysConstrs `Set.difference` Set.fromList weightConstrs
|
||||
additionalConstrs = Set.fromList weightConstrs `Set.difference` sysConstrs
|
||||
|
||||
unless (Set.null missingConstrs) $ do
|
||||
fail $
|
||||
"Missing weight for constructors: "
|
||||
++ format missingConstrs
|
||||
|
||||
unless (Set.null additionalConstrs) $ do
|
||||
fail $
|
||||
"Weight definied for non-system constructors: "
|
||||
++ format additionalConstrs
|
||||
|
||||
hasProperFrequencies :: System -> Q ()
|
||||
hasProperFrequencies sys = do
|
||||
sysConstrs <- constructors sys
|
||||
let freqConstrs = map fst (frequencies sys)
|
||||
additionalConstrs = Set.fromList freqConstrs `Set.difference` sysConstrs
|
||||
|
||||
unless (Set.null additionalConstrs) $ do
|
||||
fail $
|
||||
"Frequencies definied for non-system constructors: "
|
||||
++ format additionalConstrs
|
||||
|
||||
mkVariables :: Set Name -> Spec (Map Name Let)
|
||||
mkVariables sys = do
|
||||
let n = Set.size sys
|
||||
|
Loading…
Reference in New Issue
Block a user