From 901693b7406588758d7e423b14389aeeffc4859c Mon Sep 17 00:00:00 2001 From: Maciej Bendkowski Date: Tue, 22 Mar 2022 21:09:09 +0100 Subject: [PATCH] Check that systems have proper weights and frequencies --- profile/Lambda/Lambda.hs | 2 +- src/Data/Boltzmann/Sampler/TH.hs | 5 ++++ src/Data/Boltzmann/System.hs | 48 +++++++++++++++++++++++++++++++- 3 files changed, 53 insertions(+), 2 deletions(-) diff --git a/profile/Lambda/Lambda.hs b/profile/Lambda/Lambda.hs index 8ff7afb..f5be3a6 100644 --- a/profile/Lambda/Lambda.hs +++ b/profile/Lambda/Lambda.hs @@ -6,8 +6,8 @@ import Control.Monad (replicateM) import Data.Boltzmann.Sampler (BoltzmannSampler (..), rejectionSampler') import Data.Boltzmann.System (System (..)) -import Data.Boltzmann.System.TH (mkBoltzmannSampler) import Data.Boltzmann.BitOracle (evalIO) +import Data.Boltzmann.System.TH (mkBoltzmannSampler) import System.Random.SplitMix (SMGen) data DeBruijn diff --git a/src/Data/Boltzmann/Sampler/TH.hs b/src/Data/Boltzmann/Sampler/TH.hs index 3e56b69..a4231fe 100644 --- a/src/Data/Boltzmann/Sampler/TH.hs +++ b/src/Data/Boltzmann/Sampler/TH.hs @@ -25,6 +25,8 @@ import Data.Boltzmann.System ( System (targetType, weights), Types (Types, regTypes), collectTypes, + hasProperConstructors, + hasProperFrequencies, paganiniSpecIO, ) import Language.Haskell.TH (Q, runIO) @@ -174,6 +176,9 @@ mkSystemCtx sys = do targetSyn <- targetTypeSynonym sys info let sys' = sys {targetType = coerce targetSyn} + hasProperConstructors sys' + hasProperFrequencies sys' + types <- collectTypes sys' distributions <- runIO $ do spec <- paganiniSpecIO sys' types diff --git a/src/Data/Boltzmann/System.hs b/src/Data/Boltzmann/System.hs index 6498159..551472c 100644 --- a/src/Data/Boltzmann/System.hs +++ b/src/Data/Boltzmann/System.hs @@ -5,11 +5,13 @@ module Data.Boltzmann.System ( System (..), getWeight, paganiniSpecIO, + hasProperConstructors, + hasProperFrequencies, ) where import Language.Haskell.TH.Syntax (Name, Type (AppT, ConT, ListT)) -import Control.Monad (foldM, forM, replicateM) +import Control.Monad (foldM, forM, replicateM, unless) import Data.Boltzmann.Distribution (Distribution (Distribution)) import qualified Data.Map as Map import Data.Map.Strict (Map) @@ -100,6 +102,50 @@ collectFromType types typ = collectFromDataTypeInfo types' info _ -> fail $ "Unsupported type " ++ show typ +format :: Show a => Set a -> String +format = formatList . Set.toList . Set.map show + +formatList :: Show a => [a] -> String +formatList = \case + [] -> "" + [a] -> show a + a : xs@(_ : _) -> show a ++ ", " ++ formatList xs + +constructors :: System -> Q (Set Name) +constructors sys = do + types <- collectTypes sys + let infos = Map.elems $ regTypes types + names = concatMap (map constructorName . datatypeCons) infos + pure $ Set.fromList names + +hasProperConstructors :: System -> Q () +hasProperConstructors sys = do + sysConstrs <- constructors sys + let weightConstrs = map fst (weights sys) + missingConstrs = sysConstrs `Set.difference` Set.fromList weightConstrs + additionalConstrs = Set.fromList weightConstrs `Set.difference` sysConstrs + + unless (Set.null missingConstrs) $ do + fail $ + "Missing weight for constructors: " + ++ format missingConstrs + + unless (Set.null additionalConstrs) $ do + fail $ + "Weight definied for non-system constructors: " + ++ format additionalConstrs + +hasProperFrequencies :: System -> Q () +hasProperFrequencies sys = do + sysConstrs <- constructors sys + let freqConstrs = map fst (frequencies sys) + additionalConstrs = Set.fromList freqConstrs `Set.difference` sysConstrs + + unless (Set.null additionalConstrs) $ do + fail $ + "Frequencies definied for non-system constructors: " + ++ format additionalConstrs + mkVariables :: Set Name -> Spec (Map Name Let) mkVariables sys = do let n = Set.size sys