Use UpperBound in sample instead of Int

This commit is contained in:
Maciej Bendkowski 2022-07-25 20:45:56 +02:00
parent f8672b62e5
commit dcaed6a163
4 changed files with 39 additions and 21 deletions

View File

@ -47,7 +47,7 @@ class BoltzmannSampler a where
-- |
-- Samples a random object of type @a@. If the object size is larger than
-- the given upper bound parameter, @Nothing@ is returned instead.
sample :: RandomGen g => Int -> MaybeT (BuffonMachine g) (a, Int)
sample :: RandomGen g => MeanSize -> MaybeT (BuffonMachine g) (a, Int)
```
The so created `sample` function implements a Boltzmann sampler for `BinTree`.

View File

@ -25,6 +25,8 @@ import Data.Boltzmann.BuffonMachine (BuffonMachine, eval)
import Data.Coerce (coerce)
import System.Random (RandomGen)
import Data.Boltzmann.System (MeanSize)
import Data.Boltzmann.System.TH (LowerBound (..), UpperBound (..))
import qualified Test.QuickCheck as QuickCheck (Gen)
import qualified Test.QuickCheck.Gen as QuickCheck (Gen (MkGen))
import qualified Test.QuickCheck.Random as QuickCheck (QCGen (QCGen))
@ -34,15 +36,7 @@ class BoltzmannSampler a where
-- |
-- Samples a random object of type @a@. If the object size is larger than
-- the given upper bound parameter, @Nothing@ is returned instead.
sample :: RandomGen g => Int -> MaybeT (BuffonMachine g) (a, Int)
-- | Lower bound for rejection samplers.
newtype LowerBound = MkLowerBound Int
deriving (Show)
-- | Upper bound for rejection samplers.
newtype UpperBound = MkUpperBound Int
deriving (Show)
sample :: RandomGen g => UpperBound -> MaybeT (BuffonMachine g) (a, Int)
-- |
-- Rejection sampler for type @a@. Given lower and upper bound @lb@ and @ub@
@ -67,7 +61,7 @@ rejectionSampler lb ub = do
-- determine the admissible size window @[(1-eps) n, (1+eps) n]@ centered
-- around the given size @n@.
toleranceRejectionSampler ::
(RandomGen g, BoltzmannSampler a) => Int -> Double -> BuffonMachine g a
(RandomGen g, BoltzmannSampler a) => MeanSize -> Double -> BuffonMachine g a
toleranceRejectionSampler n eps = rejectionSampler lb ub
where
lb = MkLowerBound $ floor $ (1 - eps) * fromIntegral n

View File

@ -21,6 +21,7 @@ module Data.Boltzmann.System (
hasProperFrequencies,
hasNonNegativeEntries,
Constructable (..),
MeanSize,
) where
import Language.Haskell.TH.Syntax (
@ -119,13 +120,15 @@ instance Monoid ConstructorFrequencies where
instance Constructable ConstructorFrequencies where
x <:> xs = MkConstructorFrequencies [x] <> xs
type MeanSize = Int
-- |
-- System of algebraic data types.
data System = System
{ -- | Target type of the system.
targetType :: Name
, -- | Target mean size of the target types.
meanSize :: Int
meanSize :: MeanSize
, -- | Weights of all constructors in the system.
weights :: ConstructorWeights
, -- | Frequencies of selected constructors in the system.

View File

@ -1,9 +1,12 @@
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE TemplateHaskell #-}
module Data.Boltzmann.System.TH (
mkBoltzmannSampler,
mkDefBoltzmannSampler,
LowerBound (..),
UpperBound (..),
) where
import qualified Data.Map.Strict as Map
@ -33,6 +36,7 @@ import Data.Boltzmann.Sampler.TH (
targetTypeSynonym,
)
import Data.Boltzmann.System (
MeanSize,
System (
System,
frequencies,
@ -46,7 +50,13 @@ import Data.Boltzmann.System (
import Data.Coerce (coerce)
import Data.Default (def)
import Data.Functor ((<&>))
import Language.Haskell.TH (Exp (LamCaseE), Name, Q, Type (ArrowT, ListT))
import Language.Haskell.TH (
Exp (LamCaseE),
Name,
Pat (ConP),
Q,
Type (ArrowT, ListT),
)
import Language.Haskell.TH.Datatype (
ConstructorInfo (constructorFields, constructorName),
DatatypeInfo (datatypeCons),
@ -71,6 +81,16 @@ import Language.Haskell.TH.Syntax (
newName,
)
-- | Lower bound for rejection samplers.
newtype LowerBound = MkLowerBound Int
deriving stock (Show)
deriving newtype (Ord, Eq, Num)
-- | Upper bound for rejection samplers.
newtype UpperBound = MkUpperBound Int
deriving stock (Show)
deriving newtype (Ord, Eq, Num)
type SamplerGen a = ReaderT (SamplerCtx ()) Q a
data TypeVariant = Plain TypeName | List TypeName
@ -108,7 +128,7 @@ toTypeVariant typ = fail $ "Unsupported type " ++ show typ
mkConstrCoerce :: TypeVariant -> ConstructorInfo -> SamplerGen Exp
mkConstrCoerce tv info = do
typeVariants <- mapM toTypeVariant (constructorFields info)
let constrType = foldr arr (convert tv) (map convert typeVariants)
let constrType = foldr (arr . convert) (convert tv) typeVariants
typSynonym <- findTypeSyn tv
synonyms <- mapM findTypeSyn typeVariants
@ -156,9 +176,9 @@ mkCaseConstr = \case
\case
0 -> pure ([], 0)
1 -> do
(x, w) <- $(sampleExp typSynonym) ub
(xs, ws) <- $(sampleExp listTypSynonym) (ub - w)
pure ((x : xs), w + ws)
(x, w) <- $(sampleExp typSynonym) (coerce ub)
(xs, ws) <- $(sampleExp listTypSynonym) (coerce $ ub - w)
pure (x : xs, w + ws)
|]
sampleExp :: Type -> Q Exp
@ -189,7 +209,7 @@ mkArgExpr constr = do
(patX, expX) <- mkPatExp "x"
(patW, expW) <- mkPatExp "w"
sampleExp <- lift [|sample (ub - $(pure weight))|]
sampleExp <- lift [|sample $ coerce (ub - $(pure weight))|]
weightExp <- lift [|$(pure weight) + $(pure expW)|]
let stmt = BindS (TupP [patX, patW]) sampleExp
@ -260,12 +280,13 @@ mkSamplerExp typ = do
choiceExp <- mkChoice typ
caseExp <- mkCaseConstr typ
ub <- lift $ mkPat "ub"
ub' <- lift $ mkPat "ub"
ub <- lift $ pure $ ConP 'MkUpperBound [BangP ub']
exp <- lift [|$(pure choiceExp) >>= ($(pure caseExp))|]
pure $
LamE
[BangP ub]
[ub]
( DoE
Nothing
[ NoBindS guardExp
@ -321,7 +342,7 @@ mkBoltzmannSampler sys = do
-- the corresponding system using @mkBoltzmannSampler@. Default constructor
-- weights are used (see @mkDefWeights@). No custom constructor frequencies are
-- assumed.
mkDefBoltzmannSampler :: Name -> Int -> Q [Dec]
mkDefBoltzmannSampler :: Name -> MeanSize -> Q [Dec]
mkDefBoltzmannSampler typ meanSize = do
defWeights <- mkDefWeights' typ
mkBoltzmannSampler $