Use UpperBound in sample instead of Int

This commit is contained in:
Maciej Bendkowski 2022-07-25 20:45:56 +02:00
parent f8672b62e5
commit dcaed6a163
4 changed files with 39 additions and 21 deletions

View File

@ -47,7 +47,7 @@ class BoltzmannSampler a where
-- | -- |
-- Samples a random object of type @a@. If the object size is larger than -- Samples a random object of type @a@. If the object size is larger than
-- the given upper bound parameter, @Nothing@ is returned instead. -- the given upper bound parameter, @Nothing@ is returned instead.
sample :: RandomGen g => Int -> MaybeT (BuffonMachine g) (a, Int) sample :: RandomGen g => MeanSize -> MaybeT (BuffonMachine g) (a, Int)
``` ```
The so created `sample` function implements a Boltzmann sampler for `BinTree`. The so created `sample` function implements a Boltzmann sampler for `BinTree`.

View File

@ -25,6 +25,8 @@ import Data.Boltzmann.BuffonMachine (BuffonMachine, eval)
import Data.Coerce (coerce) import Data.Coerce (coerce)
import System.Random (RandomGen) import System.Random (RandomGen)
import Data.Boltzmann.System (MeanSize)
import Data.Boltzmann.System.TH (LowerBound (..), UpperBound (..))
import qualified Test.QuickCheck as QuickCheck (Gen) import qualified Test.QuickCheck as QuickCheck (Gen)
import qualified Test.QuickCheck.Gen as QuickCheck (Gen (MkGen)) import qualified Test.QuickCheck.Gen as QuickCheck (Gen (MkGen))
import qualified Test.QuickCheck.Random as QuickCheck (QCGen (QCGen)) import qualified Test.QuickCheck.Random as QuickCheck (QCGen (QCGen))
@ -34,15 +36,7 @@ class BoltzmannSampler a where
-- | -- |
-- Samples a random object of type @a@. If the object size is larger than -- Samples a random object of type @a@. If the object size is larger than
-- the given upper bound parameter, @Nothing@ is returned instead. -- the given upper bound parameter, @Nothing@ is returned instead.
sample :: RandomGen g => Int -> MaybeT (BuffonMachine g) (a, Int) sample :: RandomGen g => UpperBound -> MaybeT (BuffonMachine g) (a, Int)
-- | Lower bound for rejection samplers.
newtype LowerBound = MkLowerBound Int
deriving (Show)
-- | Upper bound for rejection samplers.
newtype UpperBound = MkUpperBound Int
deriving (Show)
-- | -- |
-- Rejection sampler for type @a@. Given lower and upper bound @lb@ and @ub@ -- Rejection sampler for type @a@. Given lower and upper bound @lb@ and @ub@
@ -67,7 +61,7 @@ rejectionSampler lb ub = do
-- determine the admissible size window @[(1-eps) n, (1+eps) n]@ centered -- determine the admissible size window @[(1-eps) n, (1+eps) n]@ centered
-- around the given size @n@. -- around the given size @n@.
toleranceRejectionSampler :: toleranceRejectionSampler ::
(RandomGen g, BoltzmannSampler a) => Int -> Double -> BuffonMachine g a (RandomGen g, BoltzmannSampler a) => MeanSize -> Double -> BuffonMachine g a
toleranceRejectionSampler n eps = rejectionSampler lb ub toleranceRejectionSampler n eps = rejectionSampler lb ub
where where
lb = MkLowerBound $ floor $ (1 - eps) * fromIntegral n lb = MkLowerBound $ floor $ (1 - eps) * fromIntegral n

View File

@ -21,6 +21,7 @@ module Data.Boltzmann.System (
hasProperFrequencies, hasProperFrequencies,
hasNonNegativeEntries, hasNonNegativeEntries,
Constructable (..), Constructable (..),
MeanSize,
) where ) where
import Language.Haskell.TH.Syntax ( import Language.Haskell.TH.Syntax (
@ -119,13 +120,15 @@ instance Monoid ConstructorFrequencies where
instance Constructable ConstructorFrequencies where instance Constructable ConstructorFrequencies where
x <:> xs = MkConstructorFrequencies [x] <> xs x <:> xs = MkConstructorFrequencies [x] <> xs
type MeanSize = Int
-- | -- |
-- System of algebraic data types. -- System of algebraic data types.
data System = System data System = System
{ -- | Target type of the system. { -- | Target type of the system.
targetType :: Name targetType :: Name
, -- | Target mean size of the target types. , -- | Target mean size of the target types.
meanSize :: Int meanSize :: MeanSize
, -- | Weights of all constructors in the system. , -- | Weights of all constructors in the system.
weights :: ConstructorWeights weights :: ConstructorWeights
, -- | Frequencies of selected constructors in the system. , -- | Frequencies of selected constructors in the system.

View File

@ -1,9 +1,12 @@
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE RecordWildCards #-} {-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE TemplateHaskell #-} {-# LANGUAGE TemplateHaskell #-}
module Data.Boltzmann.System.TH ( module Data.Boltzmann.System.TH (
mkBoltzmannSampler, mkBoltzmannSampler,
mkDefBoltzmannSampler, mkDefBoltzmannSampler,
LowerBound (..),
UpperBound (..),
) where ) where
import qualified Data.Map.Strict as Map import qualified Data.Map.Strict as Map
@ -33,6 +36,7 @@ import Data.Boltzmann.Sampler.TH (
targetTypeSynonym, targetTypeSynonym,
) )
import Data.Boltzmann.System ( import Data.Boltzmann.System (
MeanSize,
System ( System (
System, System,
frequencies, frequencies,
@ -46,7 +50,13 @@ import Data.Boltzmann.System (
import Data.Coerce (coerce) import Data.Coerce (coerce)
import Data.Default (def) import Data.Default (def)
import Data.Functor ((<&>)) import Data.Functor ((<&>))
import Language.Haskell.TH (Exp (LamCaseE), Name, Q, Type (ArrowT, ListT)) import Language.Haskell.TH (
Exp (LamCaseE),
Name,
Pat (ConP),
Q,
Type (ArrowT, ListT),
)
import Language.Haskell.TH.Datatype ( import Language.Haskell.TH.Datatype (
ConstructorInfo (constructorFields, constructorName), ConstructorInfo (constructorFields, constructorName),
DatatypeInfo (datatypeCons), DatatypeInfo (datatypeCons),
@ -71,6 +81,16 @@ import Language.Haskell.TH.Syntax (
newName, newName,
) )
-- | Lower bound for rejection samplers.
newtype LowerBound = MkLowerBound Int
deriving stock (Show)
deriving newtype (Ord, Eq, Num)
-- | Upper bound for rejection samplers.
newtype UpperBound = MkUpperBound Int
deriving stock (Show)
deriving newtype (Ord, Eq, Num)
type SamplerGen a = ReaderT (SamplerCtx ()) Q a type SamplerGen a = ReaderT (SamplerCtx ()) Q a
data TypeVariant = Plain TypeName | List TypeName data TypeVariant = Plain TypeName | List TypeName
@ -108,7 +128,7 @@ toTypeVariant typ = fail $ "Unsupported type " ++ show typ
mkConstrCoerce :: TypeVariant -> ConstructorInfo -> SamplerGen Exp mkConstrCoerce :: TypeVariant -> ConstructorInfo -> SamplerGen Exp
mkConstrCoerce tv info = do mkConstrCoerce tv info = do
typeVariants <- mapM toTypeVariant (constructorFields info) typeVariants <- mapM toTypeVariant (constructorFields info)
let constrType = foldr arr (convert tv) (map convert typeVariants) let constrType = foldr (arr . convert) (convert tv) typeVariants
typSynonym <- findTypeSyn tv typSynonym <- findTypeSyn tv
synonyms <- mapM findTypeSyn typeVariants synonyms <- mapM findTypeSyn typeVariants
@ -156,9 +176,9 @@ mkCaseConstr = \case
\case \case
0 -> pure ([], 0) 0 -> pure ([], 0)
1 -> do 1 -> do
(x, w) <- $(sampleExp typSynonym) ub (x, w) <- $(sampleExp typSynonym) (coerce ub)
(xs, ws) <- $(sampleExp listTypSynonym) (ub - w) (xs, ws) <- $(sampleExp listTypSynonym) (coerce $ ub - w)
pure ((x : xs), w + ws) pure (x : xs, w + ws)
|] |]
sampleExp :: Type -> Q Exp sampleExp :: Type -> Q Exp
@ -189,7 +209,7 @@ mkArgExpr constr = do
(patX, expX) <- mkPatExp "x" (patX, expX) <- mkPatExp "x"
(patW, expW) <- mkPatExp "w" (patW, expW) <- mkPatExp "w"
sampleExp <- lift [|sample (ub - $(pure weight))|] sampleExp <- lift [|sample $ coerce (ub - $(pure weight))|]
weightExp <- lift [|$(pure weight) + $(pure expW)|] weightExp <- lift [|$(pure weight) + $(pure expW)|]
let stmt = BindS (TupP [patX, patW]) sampleExp let stmt = BindS (TupP [patX, patW]) sampleExp
@ -260,12 +280,13 @@ mkSamplerExp typ = do
choiceExp <- mkChoice typ choiceExp <- mkChoice typ
caseExp <- mkCaseConstr typ caseExp <- mkCaseConstr typ
ub <- lift $ mkPat "ub" ub' <- lift $ mkPat "ub"
ub <- lift $ pure $ ConP 'MkUpperBound [BangP ub']
exp <- lift [|$(pure choiceExp) >>= ($(pure caseExp))|] exp <- lift [|$(pure choiceExp) >>= ($(pure caseExp))|]
pure $ pure $
LamE LamE
[BangP ub] [ub]
( DoE ( DoE
Nothing Nothing
[ NoBindS guardExp [ NoBindS guardExp
@ -321,7 +342,7 @@ mkBoltzmannSampler sys = do
-- the corresponding system using @mkBoltzmannSampler@. Default constructor -- the corresponding system using @mkBoltzmannSampler@. Default constructor
-- weights are used (see @mkDefWeights@). No custom constructor frequencies are -- weights are used (see @mkDefWeights@). No custom constructor frequencies are
-- assumed. -- assumed.
mkDefBoltzmannSampler :: Name -> Int -> Q [Dec] mkDefBoltzmannSampler :: Name -> MeanSize -> Q [Dec]
mkDefBoltzmannSampler typ meanSize = do mkDefBoltzmannSampler typ meanSize = do
defWeights <- mkDefWeights' typ defWeights <- mkDefWeights' typ
mkBoltzmannSampler $ mkBoltzmannSampler $