mirror of
https://github.com/maciej-bendkowski/generic-boltzmann-brain.git
synced 2024-08-16 16:10:27 +03:00
Use UpperBound
in sample
instead of Int
This commit is contained in:
parent
f8672b62e5
commit
dcaed6a163
@ -47,7 +47,7 @@ class BoltzmannSampler a where
|
||||
-- |
|
||||
-- Samples a random object of type @a@. If the object size is larger than
|
||||
-- the given upper bound parameter, @Nothing@ is returned instead.
|
||||
sample :: RandomGen g => Int -> MaybeT (BuffonMachine g) (a, Int)
|
||||
sample :: RandomGen g => MeanSize -> MaybeT (BuffonMachine g) (a, Int)
|
||||
```
|
||||
|
||||
The so created `sample` function implements a Boltzmann sampler for `BinTree`.
|
||||
|
@ -25,6 +25,8 @@ import Data.Boltzmann.BuffonMachine (BuffonMachine, eval)
|
||||
import Data.Coerce (coerce)
|
||||
import System.Random (RandomGen)
|
||||
|
||||
import Data.Boltzmann.System (MeanSize)
|
||||
import Data.Boltzmann.System.TH (LowerBound (..), UpperBound (..))
|
||||
import qualified Test.QuickCheck as QuickCheck (Gen)
|
||||
import qualified Test.QuickCheck.Gen as QuickCheck (Gen (MkGen))
|
||||
import qualified Test.QuickCheck.Random as QuickCheck (QCGen (QCGen))
|
||||
@ -34,15 +36,7 @@ class BoltzmannSampler a where
|
||||
-- |
|
||||
-- Samples a random object of type @a@. If the object size is larger than
|
||||
-- the given upper bound parameter, @Nothing@ is returned instead.
|
||||
sample :: RandomGen g => Int -> MaybeT (BuffonMachine g) (a, Int)
|
||||
|
||||
-- | Lower bound for rejection samplers.
|
||||
newtype LowerBound = MkLowerBound Int
|
||||
deriving (Show)
|
||||
|
||||
-- | Upper bound for rejection samplers.
|
||||
newtype UpperBound = MkUpperBound Int
|
||||
deriving (Show)
|
||||
sample :: RandomGen g => UpperBound -> MaybeT (BuffonMachine g) (a, Int)
|
||||
|
||||
-- |
|
||||
-- Rejection sampler for type @a@. Given lower and upper bound @lb@ and @ub@
|
||||
@ -67,7 +61,7 @@ rejectionSampler lb ub = do
|
||||
-- determine the admissible size window @[(1-eps) n, (1+eps) n]@ centered
|
||||
-- around the given size @n@.
|
||||
toleranceRejectionSampler ::
|
||||
(RandomGen g, BoltzmannSampler a) => Int -> Double -> BuffonMachine g a
|
||||
(RandomGen g, BoltzmannSampler a) => MeanSize -> Double -> BuffonMachine g a
|
||||
toleranceRejectionSampler n eps = rejectionSampler lb ub
|
||||
where
|
||||
lb = MkLowerBound $ floor $ (1 - eps) * fromIntegral n
|
||||
|
@ -21,6 +21,7 @@ module Data.Boltzmann.System (
|
||||
hasProperFrequencies,
|
||||
hasNonNegativeEntries,
|
||||
Constructable (..),
|
||||
MeanSize,
|
||||
) where
|
||||
|
||||
import Language.Haskell.TH.Syntax (
|
||||
@ -119,13 +120,15 @@ instance Monoid ConstructorFrequencies where
|
||||
instance Constructable ConstructorFrequencies where
|
||||
x <:> xs = MkConstructorFrequencies [x] <> xs
|
||||
|
||||
type MeanSize = Int
|
||||
|
||||
-- |
|
||||
-- System of algebraic data types.
|
||||
data System = System
|
||||
{ -- | Target type of the system.
|
||||
targetType :: Name
|
||||
, -- | Target mean size of the target types.
|
||||
meanSize :: Int
|
||||
meanSize :: MeanSize
|
||||
, -- | Weights of all constructors in the system.
|
||||
weights :: ConstructorWeights
|
||||
, -- | Frequencies of selected constructors in the system.
|
||||
|
@ -1,9 +1,12 @@
|
||||
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
|
||||
{-# LANGUAGE RecordWildCards #-}
|
||||
{-# LANGUAGE TemplateHaskell #-}
|
||||
|
||||
module Data.Boltzmann.System.TH (
|
||||
mkBoltzmannSampler,
|
||||
mkDefBoltzmannSampler,
|
||||
LowerBound (..),
|
||||
UpperBound (..),
|
||||
) where
|
||||
|
||||
import qualified Data.Map.Strict as Map
|
||||
@ -33,6 +36,7 @@ import Data.Boltzmann.Sampler.TH (
|
||||
targetTypeSynonym,
|
||||
)
|
||||
import Data.Boltzmann.System (
|
||||
MeanSize,
|
||||
System (
|
||||
System,
|
||||
frequencies,
|
||||
@ -46,7 +50,13 @@ import Data.Boltzmann.System (
|
||||
import Data.Coerce (coerce)
|
||||
import Data.Default (def)
|
||||
import Data.Functor ((<&>))
|
||||
import Language.Haskell.TH (Exp (LamCaseE), Name, Q, Type (ArrowT, ListT))
|
||||
import Language.Haskell.TH (
|
||||
Exp (LamCaseE),
|
||||
Name,
|
||||
Pat (ConP),
|
||||
Q,
|
||||
Type (ArrowT, ListT),
|
||||
)
|
||||
import Language.Haskell.TH.Datatype (
|
||||
ConstructorInfo (constructorFields, constructorName),
|
||||
DatatypeInfo (datatypeCons),
|
||||
@ -71,6 +81,16 @@ import Language.Haskell.TH.Syntax (
|
||||
newName,
|
||||
)
|
||||
|
||||
-- | Lower bound for rejection samplers.
|
||||
newtype LowerBound = MkLowerBound Int
|
||||
deriving stock (Show)
|
||||
deriving newtype (Ord, Eq, Num)
|
||||
|
||||
-- | Upper bound for rejection samplers.
|
||||
newtype UpperBound = MkUpperBound Int
|
||||
deriving stock (Show)
|
||||
deriving newtype (Ord, Eq, Num)
|
||||
|
||||
type SamplerGen a = ReaderT (SamplerCtx ()) Q a
|
||||
|
||||
data TypeVariant = Plain TypeName | List TypeName
|
||||
@ -108,7 +128,7 @@ toTypeVariant typ = fail $ "Unsupported type " ++ show typ
|
||||
mkConstrCoerce :: TypeVariant -> ConstructorInfo -> SamplerGen Exp
|
||||
mkConstrCoerce tv info = do
|
||||
typeVariants <- mapM toTypeVariant (constructorFields info)
|
||||
let constrType = foldr arr (convert tv) (map convert typeVariants)
|
||||
let constrType = foldr (arr . convert) (convert tv) typeVariants
|
||||
|
||||
typSynonym <- findTypeSyn tv
|
||||
synonyms <- mapM findTypeSyn typeVariants
|
||||
@ -156,9 +176,9 @@ mkCaseConstr = \case
|
||||
\case
|
||||
0 -> pure ([], 0)
|
||||
1 -> do
|
||||
(x, w) <- $(sampleExp typSynonym) ub
|
||||
(xs, ws) <- $(sampleExp listTypSynonym) (ub - w)
|
||||
pure ((x : xs), w + ws)
|
||||
(x, w) <- $(sampleExp typSynonym) (coerce ub)
|
||||
(xs, ws) <- $(sampleExp listTypSynonym) (coerce $ ub - w)
|
||||
pure (x : xs, w + ws)
|
||||
|]
|
||||
|
||||
sampleExp :: Type -> Q Exp
|
||||
@ -189,7 +209,7 @@ mkArgExpr constr = do
|
||||
(patX, expX) <- mkPatExp "x"
|
||||
(patW, expW) <- mkPatExp "w"
|
||||
|
||||
sampleExp <- lift [|sample (ub - $(pure weight))|]
|
||||
sampleExp <- lift [|sample $ coerce (ub - $(pure weight))|]
|
||||
weightExp <- lift [|$(pure weight) + $(pure expW)|]
|
||||
let stmt = BindS (TupP [patX, patW]) sampleExp
|
||||
|
||||
@ -260,12 +280,13 @@ mkSamplerExp typ = do
|
||||
choiceExp <- mkChoice typ
|
||||
caseExp <- mkCaseConstr typ
|
||||
|
||||
ub <- lift $ mkPat "ub"
|
||||
ub' <- lift $ mkPat "ub"
|
||||
ub <- lift $ pure $ ConP 'MkUpperBound [BangP ub']
|
||||
exp <- lift [|$(pure choiceExp) >>= ($(pure caseExp))|]
|
||||
|
||||
pure $
|
||||
LamE
|
||||
[BangP ub]
|
||||
[ub]
|
||||
( DoE
|
||||
Nothing
|
||||
[ NoBindS guardExp
|
||||
@ -321,7 +342,7 @@ mkBoltzmannSampler sys = do
|
||||
-- the corresponding system using @mkBoltzmannSampler@. Default constructor
|
||||
-- weights are used (see @mkDefWeights@). No custom constructor frequencies are
|
||||
-- assumed.
|
||||
mkDefBoltzmannSampler :: Name -> Int -> Q [Dec]
|
||||
mkDefBoltzmannSampler :: Name -> MeanSize -> Q [Dec]
|
||||
mkDefBoltzmannSampler typ meanSize = do
|
||||
defWeights <- mkDefWeights' typ
|
||||
mkBoltzmannSampler $
|
||||
|
Loading…
Reference in New Issue
Block a user