From dcaed6a16300d6d8b19963e7e151869a3e2d7625 Mon Sep 17 00:00:00 2001 From: Maciej Bendkowski Date: Mon, 25 Jul 2022 20:45:56 +0200 Subject: [PATCH] Use `UpperBound` in `sample` instead of Int --- README.md | 2 +- internal/Data/Boltzmann/Sampler.hs | 14 +++------- internal/Data/Boltzmann/System.hs | 5 +++- internal/Data/Boltzmann/System/TH.hs | 39 +++++++++++++++++++++------- 4 files changed, 39 insertions(+), 21 deletions(-) diff --git a/README.md b/README.md index 8ed78c6..2f9aa59 100644 --- a/README.md +++ b/README.md @@ -47,7 +47,7 @@ class BoltzmannSampler a where -- | -- Samples a random object of type @a@. If the object size is larger than -- the given upper bound parameter, @Nothing@ is returned instead. - sample :: RandomGen g => Int -> MaybeT (BuffonMachine g) (a, Int) + sample :: RandomGen g => MeanSize -> MaybeT (BuffonMachine g) (a, Int) ``` The so created `sample` function implements a Boltzmann sampler for `BinTree`. diff --git a/internal/Data/Boltzmann/Sampler.hs b/internal/Data/Boltzmann/Sampler.hs index fcc6d59..dbf398a 100644 --- a/internal/Data/Boltzmann/Sampler.hs +++ b/internal/Data/Boltzmann/Sampler.hs @@ -25,6 +25,8 @@ import Data.Boltzmann.BuffonMachine (BuffonMachine, eval) import Data.Coerce (coerce) import System.Random (RandomGen) +import Data.Boltzmann.System (MeanSize) +import Data.Boltzmann.System.TH (LowerBound (..), UpperBound (..)) import qualified Test.QuickCheck as QuickCheck (Gen) import qualified Test.QuickCheck.Gen as QuickCheck (Gen (MkGen)) import qualified Test.QuickCheck.Random as QuickCheck (QCGen (QCGen)) @@ -34,15 +36,7 @@ class BoltzmannSampler a where -- | -- Samples a random object of type @a@. If the object size is larger than -- the given upper bound parameter, @Nothing@ is returned instead. - sample :: RandomGen g => Int -> MaybeT (BuffonMachine g) (a, Int) - --- | Lower bound for rejection samplers. -newtype LowerBound = MkLowerBound Int - deriving (Show) - --- | Upper bound for rejection samplers. -newtype UpperBound = MkUpperBound Int - deriving (Show) + sample :: RandomGen g => UpperBound -> MaybeT (BuffonMachine g) (a, Int) -- | -- Rejection sampler for type @a@. Given lower and upper bound @lb@ and @ub@ @@ -67,7 +61,7 @@ rejectionSampler lb ub = do -- determine the admissible size window @[(1-eps) n, (1+eps) n]@ centered -- around the given size @n@. toleranceRejectionSampler :: - (RandomGen g, BoltzmannSampler a) => Int -> Double -> BuffonMachine g a + (RandomGen g, BoltzmannSampler a) => MeanSize -> Double -> BuffonMachine g a toleranceRejectionSampler n eps = rejectionSampler lb ub where lb = MkLowerBound $ floor $ (1 - eps) * fromIntegral n diff --git a/internal/Data/Boltzmann/System.hs b/internal/Data/Boltzmann/System.hs index be6999a..b7cefb3 100644 --- a/internal/Data/Boltzmann/System.hs +++ b/internal/Data/Boltzmann/System.hs @@ -21,6 +21,7 @@ module Data.Boltzmann.System ( hasProperFrequencies, hasNonNegativeEntries, Constructable (..), + MeanSize, ) where import Language.Haskell.TH.Syntax ( @@ -119,13 +120,15 @@ instance Monoid ConstructorFrequencies where instance Constructable ConstructorFrequencies where x <:> xs = MkConstructorFrequencies [x] <> xs +type MeanSize = Int + -- | -- System of algebraic data types. data System = System { -- | Target type of the system. targetType :: Name , -- | Target mean size of the target types. - meanSize :: Int + meanSize :: MeanSize , -- | Weights of all constructors in the system. weights :: ConstructorWeights , -- | Frequencies of selected constructors in the system. diff --git a/internal/Data/Boltzmann/System/TH.hs b/internal/Data/Boltzmann/System/TH.hs index e84c964..27c5f38 100644 --- a/internal/Data/Boltzmann/System/TH.hs +++ b/internal/Data/Boltzmann/System/TH.hs @@ -1,9 +1,12 @@ +{-# LANGUAGE GeneralizedNewtypeDeriving #-} {-# LANGUAGE RecordWildCards #-} {-# LANGUAGE TemplateHaskell #-} module Data.Boltzmann.System.TH ( mkBoltzmannSampler, mkDefBoltzmannSampler, + LowerBound (..), + UpperBound (..), ) where import qualified Data.Map.Strict as Map @@ -33,6 +36,7 @@ import Data.Boltzmann.Sampler.TH ( targetTypeSynonym, ) import Data.Boltzmann.System ( + MeanSize, System ( System, frequencies, @@ -46,7 +50,13 @@ import Data.Boltzmann.System ( import Data.Coerce (coerce) import Data.Default (def) import Data.Functor ((<&>)) -import Language.Haskell.TH (Exp (LamCaseE), Name, Q, Type (ArrowT, ListT)) +import Language.Haskell.TH ( + Exp (LamCaseE), + Name, + Pat (ConP), + Q, + Type (ArrowT, ListT), + ) import Language.Haskell.TH.Datatype ( ConstructorInfo (constructorFields, constructorName), DatatypeInfo (datatypeCons), @@ -71,6 +81,16 @@ import Language.Haskell.TH.Syntax ( newName, ) +-- | Lower bound for rejection samplers. +newtype LowerBound = MkLowerBound Int + deriving stock (Show) + deriving newtype (Ord, Eq, Num) + +-- | Upper bound for rejection samplers. +newtype UpperBound = MkUpperBound Int + deriving stock (Show) + deriving newtype (Ord, Eq, Num) + type SamplerGen a = ReaderT (SamplerCtx ()) Q a data TypeVariant = Plain TypeName | List TypeName @@ -108,7 +128,7 @@ toTypeVariant typ = fail $ "Unsupported type " ++ show typ mkConstrCoerce :: TypeVariant -> ConstructorInfo -> SamplerGen Exp mkConstrCoerce tv info = do typeVariants <- mapM toTypeVariant (constructorFields info) - let constrType = foldr arr (convert tv) (map convert typeVariants) + let constrType = foldr (arr . convert) (convert tv) typeVariants typSynonym <- findTypeSyn tv synonyms <- mapM findTypeSyn typeVariants @@ -156,9 +176,9 @@ mkCaseConstr = \case \case 0 -> pure ([], 0) 1 -> do - (x, w) <- $(sampleExp typSynonym) ub - (xs, ws) <- $(sampleExp listTypSynonym) (ub - w) - pure ((x : xs), w + ws) + (x, w) <- $(sampleExp typSynonym) (coerce ub) + (xs, ws) <- $(sampleExp listTypSynonym) (coerce $ ub - w) + pure (x : xs, w + ws) |] sampleExp :: Type -> Q Exp @@ -189,7 +209,7 @@ mkArgExpr constr = do (patX, expX) <- mkPatExp "x" (patW, expW) <- mkPatExp "w" - sampleExp <- lift [|sample (ub - $(pure weight))|] + sampleExp <- lift [|sample $ coerce (ub - $(pure weight))|] weightExp <- lift [|$(pure weight) + $(pure expW)|] let stmt = BindS (TupP [patX, patW]) sampleExp @@ -260,12 +280,13 @@ mkSamplerExp typ = do choiceExp <- mkChoice typ caseExp <- mkCaseConstr typ - ub <- lift $ mkPat "ub" + ub' <- lift $ mkPat "ub" + ub <- lift $ pure $ ConP 'MkUpperBound [BangP ub'] exp <- lift [|$(pure choiceExp) >>= ($(pure caseExp))|] pure $ LamE - [BangP ub] + [ub] ( DoE Nothing [ NoBindS guardExp @@ -321,7 +342,7 @@ mkBoltzmannSampler sys = do -- the corresponding system using @mkBoltzmannSampler@. Default constructor -- weights are used (see @mkDefWeights@). No custom constructor frequencies are -- assumed. -mkDefBoltzmannSampler :: Name -> Int -> Q [Dec] +mkDefBoltzmannSampler :: Name -> MeanSize -> Q [Dec] mkDefBoltzmannSampler typ meanSize = do defWeights <- mkDefWeights' typ mkBoltzmannSampler $