mirror of
https://github.com/maciej-bendkowski/generic-boltzmann-brain.git
synced 2024-10-26 13:22:32 +03:00
Use UpperBound
in sample
instead of Int
This commit is contained in:
parent
f8672b62e5
commit
dcaed6a163
@ -47,7 +47,7 @@ class BoltzmannSampler a where
|
|||||||
-- |
|
-- |
|
||||||
-- Samples a random object of type @a@. If the object size is larger than
|
-- Samples a random object of type @a@. If the object size is larger than
|
||||||
-- the given upper bound parameter, @Nothing@ is returned instead.
|
-- the given upper bound parameter, @Nothing@ is returned instead.
|
||||||
sample :: RandomGen g => Int -> MaybeT (BuffonMachine g) (a, Int)
|
sample :: RandomGen g => MeanSize -> MaybeT (BuffonMachine g) (a, Int)
|
||||||
```
|
```
|
||||||
|
|
||||||
The so created `sample` function implements a Boltzmann sampler for `BinTree`.
|
The so created `sample` function implements a Boltzmann sampler for `BinTree`.
|
||||||
|
@ -25,6 +25,8 @@ import Data.Boltzmann.BuffonMachine (BuffonMachine, eval)
|
|||||||
import Data.Coerce (coerce)
|
import Data.Coerce (coerce)
|
||||||
import System.Random (RandomGen)
|
import System.Random (RandomGen)
|
||||||
|
|
||||||
|
import Data.Boltzmann.System (MeanSize)
|
||||||
|
import Data.Boltzmann.System.TH (LowerBound (..), UpperBound (..))
|
||||||
import qualified Test.QuickCheck as QuickCheck (Gen)
|
import qualified Test.QuickCheck as QuickCheck (Gen)
|
||||||
import qualified Test.QuickCheck.Gen as QuickCheck (Gen (MkGen))
|
import qualified Test.QuickCheck.Gen as QuickCheck (Gen (MkGen))
|
||||||
import qualified Test.QuickCheck.Random as QuickCheck (QCGen (QCGen))
|
import qualified Test.QuickCheck.Random as QuickCheck (QCGen (QCGen))
|
||||||
@ -34,15 +36,7 @@ class BoltzmannSampler a where
|
|||||||
-- |
|
-- |
|
||||||
-- Samples a random object of type @a@. If the object size is larger than
|
-- Samples a random object of type @a@. If the object size is larger than
|
||||||
-- the given upper bound parameter, @Nothing@ is returned instead.
|
-- the given upper bound parameter, @Nothing@ is returned instead.
|
||||||
sample :: RandomGen g => Int -> MaybeT (BuffonMachine g) (a, Int)
|
sample :: RandomGen g => UpperBound -> MaybeT (BuffonMachine g) (a, Int)
|
||||||
|
|
||||||
-- | Lower bound for rejection samplers.
|
|
||||||
newtype LowerBound = MkLowerBound Int
|
|
||||||
deriving (Show)
|
|
||||||
|
|
||||||
-- | Upper bound for rejection samplers.
|
|
||||||
newtype UpperBound = MkUpperBound Int
|
|
||||||
deriving (Show)
|
|
||||||
|
|
||||||
-- |
|
-- |
|
||||||
-- Rejection sampler for type @a@. Given lower and upper bound @lb@ and @ub@
|
-- Rejection sampler for type @a@. Given lower and upper bound @lb@ and @ub@
|
||||||
@ -67,7 +61,7 @@ rejectionSampler lb ub = do
|
|||||||
-- determine the admissible size window @[(1-eps) n, (1+eps) n]@ centered
|
-- determine the admissible size window @[(1-eps) n, (1+eps) n]@ centered
|
||||||
-- around the given size @n@.
|
-- around the given size @n@.
|
||||||
toleranceRejectionSampler ::
|
toleranceRejectionSampler ::
|
||||||
(RandomGen g, BoltzmannSampler a) => Int -> Double -> BuffonMachine g a
|
(RandomGen g, BoltzmannSampler a) => MeanSize -> Double -> BuffonMachine g a
|
||||||
toleranceRejectionSampler n eps = rejectionSampler lb ub
|
toleranceRejectionSampler n eps = rejectionSampler lb ub
|
||||||
where
|
where
|
||||||
lb = MkLowerBound $ floor $ (1 - eps) * fromIntegral n
|
lb = MkLowerBound $ floor $ (1 - eps) * fromIntegral n
|
||||||
|
@ -21,6 +21,7 @@ module Data.Boltzmann.System (
|
|||||||
hasProperFrequencies,
|
hasProperFrequencies,
|
||||||
hasNonNegativeEntries,
|
hasNonNegativeEntries,
|
||||||
Constructable (..),
|
Constructable (..),
|
||||||
|
MeanSize,
|
||||||
) where
|
) where
|
||||||
|
|
||||||
import Language.Haskell.TH.Syntax (
|
import Language.Haskell.TH.Syntax (
|
||||||
@ -119,13 +120,15 @@ instance Monoid ConstructorFrequencies where
|
|||||||
instance Constructable ConstructorFrequencies where
|
instance Constructable ConstructorFrequencies where
|
||||||
x <:> xs = MkConstructorFrequencies [x] <> xs
|
x <:> xs = MkConstructorFrequencies [x] <> xs
|
||||||
|
|
||||||
|
type MeanSize = Int
|
||||||
|
|
||||||
-- |
|
-- |
|
||||||
-- System of algebraic data types.
|
-- System of algebraic data types.
|
||||||
data System = System
|
data System = System
|
||||||
{ -- | Target type of the system.
|
{ -- | Target type of the system.
|
||||||
targetType :: Name
|
targetType :: Name
|
||||||
, -- | Target mean size of the target types.
|
, -- | Target mean size of the target types.
|
||||||
meanSize :: Int
|
meanSize :: MeanSize
|
||||||
, -- | Weights of all constructors in the system.
|
, -- | Weights of all constructors in the system.
|
||||||
weights :: ConstructorWeights
|
weights :: ConstructorWeights
|
||||||
, -- | Frequencies of selected constructors in the system.
|
, -- | Frequencies of selected constructors in the system.
|
||||||
|
@ -1,9 +1,12 @@
|
|||||||
|
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
|
||||||
{-# LANGUAGE RecordWildCards #-}
|
{-# LANGUAGE RecordWildCards #-}
|
||||||
{-# LANGUAGE TemplateHaskell #-}
|
{-# LANGUAGE TemplateHaskell #-}
|
||||||
|
|
||||||
module Data.Boltzmann.System.TH (
|
module Data.Boltzmann.System.TH (
|
||||||
mkBoltzmannSampler,
|
mkBoltzmannSampler,
|
||||||
mkDefBoltzmannSampler,
|
mkDefBoltzmannSampler,
|
||||||
|
LowerBound (..),
|
||||||
|
UpperBound (..),
|
||||||
) where
|
) where
|
||||||
|
|
||||||
import qualified Data.Map.Strict as Map
|
import qualified Data.Map.Strict as Map
|
||||||
@ -33,6 +36,7 @@ import Data.Boltzmann.Sampler.TH (
|
|||||||
targetTypeSynonym,
|
targetTypeSynonym,
|
||||||
)
|
)
|
||||||
import Data.Boltzmann.System (
|
import Data.Boltzmann.System (
|
||||||
|
MeanSize,
|
||||||
System (
|
System (
|
||||||
System,
|
System,
|
||||||
frequencies,
|
frequencies,
|
||||||
@ -46,7 +50,13 @@ import Data.Boltzmann.System (
|
|||||||
import Data.Coerce (coerce)
|
import Data.Coerce (coerce)
|
||||||
import Data.Default (def)
|
import Data.Default (def)
|
||||||
import Data.Functor ((<&>))
|
import Data.Functor ((<&>))
|
||||||
import Language.Haskell.TH (Exp (LamCaseE), Name, Q, Type (ArrowT, ListT))
|
import Language.Haskell.TH (
|
||||||
|
Exp (LamCaseE),
|
||||||
|
Name,
|
||||||
|
Pat (ConP),
|
||||||
|
Q,
|
||||||
|
Type (ArrowT, ListT),
|
||||||
|
)
|
||||||
import Language.Haskell.TH.Datatype (
|
import Language.Haskell.TH.Datatype (
|
||||||
ConstructorInfo (constructorFields, constructorName),
|
ConstructorInfo (constructorFields, constructorName),
|
||||||
DatatypeInfo (datatypeCons),
|
DatatypeInfo (datatypeCons),
|
||||||
@ -71,6 +81,16 @@ import Language.Haskell.TH.Syntax (
|
|||||||
newName,
|
newName,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
-- | Lower bound for rejection samplers.
|
||||||
|
newtype LowerBound = MkLowerBound Int
|
||||||
|
deriving stock (Show)
|
||||||
|
deriving newtype (Ord, Eq, Num)
|
||||||
|
|
||||||
|
-- | Upper bound for rejection samplers.
|
||||||
|
newtype UpperBound = MkUpperBound Int
|
||||||
|
deriving stock (Show)
|
||||||
|
deriving newtype (Ord, Eq, Num)
|
||||||
|
|
||||||
type SamplerGen a = ReaderT (SamplerCtx ()) Q a
|
type SamplerGen a = ReaderT (SamplerCtx ()) Q a
|
||||||
|
|
||||||
data TypeVariant = Plain TypeName | List TypeName
|
data TypeVariant = Plain TypeName | List TypeName
|
||||||
@ -108,7 +128,7 @@ toTypeVariant typ = fail $ "Unsupported type " ++ show typ
|
|||||||
mkConstrCoerce :: TypeVariant -> ConstructorInfo -> SamplerGen Exp
|
mkConstrCoerce :: TypeVariant -> ConstructorInfo -> SamplerGen Exp
|
||||||
mkConstrCoerce tv info = do
|
mkConstrCoerce tv info = do
|
||||||
typeVariants <- mapM toTypeVariant (constructorFields info)
|
typeVariants <- mapM toTypeVariant (constructorFields info)
|
||||||
let constrType = foldr arr (convert tv) (map convert typeVariants)
|
let constrType = foldr (arr . convert) (convert tv) typeVariants
|
||||||
|
|
||||||
typSynonym <- findTypeSyn tv
|
typSynonym <- findTypeSyn tv
|
||||||
synonyms <- mapM findTypeSyn typeVariants
|
synonyms <- mapM findTypeSyn typeVariants
|
||||||
@ -156,9 +176,9 @@ mkCaseConstr = \case
|
|||||||
\case
|
\case
|
||||||
0 -> pure ([], 0)
|
0 -> pure ([], 0)
|
||||||
1 -> do
|
1 -> do
|
||||||
(x, w) <- $(sampleExp typSynonym) ub
|
(x, w) <- $(sampleExp typSynonym) (coerce ub)
|
||||||
(xs, ws) <- $(sampleExp listTypSynonym) (ub - w)
|
(xs, ws) <- $(sampleExp listTypSynonym) (coerce $ ub - w)
|
||||||
pure ((x : xs), w + ws)
|
pure (x : xs, w + ws)
|
||||||
|]
|
|]
|
||||||
|
|
||||||
sampleExp :: Type -> Q Exp
|
sampleExp :: Type -> Q Exp
|
||||||
@ -189,7 +209,7 @@ mkArgExpr constr = do
|
|||||||
(patX, expX) <- mkPatExp "x"
|
(patX, expX) <- mkPatExp "x"
|
||||||
(patW, expW) <- mkPatExp "w"
|
(patW, expW) <- mkPatExp "w"
|
||||||
|
|
||||||
sampleExp <- lift [|sample (ub - $(pure weight))|]
|
sampleExp <- lift [|sample $ coerce (ub - $(pure weight))|]
|
||||||
weightExp <- lift [|$(pure weight) + $(pure expW)|]
|
weightExp <- lift [|$(pure weight) + $(pure expW)|]
|
||||||
let stmt = BindS (TupP [patX, patW]) sampleExp
|
let stmt = BindS (TupP [patX, patW]) sampleExp
|
||||||
|
|
||||||
@ -260,12 +280,13 @@ mkSamplerExp typ = do
|
|||||||
choiceExp <- mkChoice typ
|
choiceExp <- mkChoice typ
|
||||||
caseExp <- mkCaseConstr typ
|
caseExp <- mkCaseConstr typ
|
||||||
|
|
||||||
ub <- lift $ mkPat "ub"
|
ub' <- lift $ mkPat "ub"
|
||||||
|
ub <- lift $ pure $ ConP 'MkUpperBound [BangP ub']
|
||||||
exp <- lift [|$(pure choiceExp) >>= ($(pure caseExp))|]
|
exp <- lift [|$(pure choiceExp) >>= ($(pure caseExp))|]
|
||||||
|
|
||||||
pure $
|
pure $
|
||||||
LamE
|
LamE
|
||||||
[BangP ub]
|
[ub]
|
||||||
( DoE
|
( DoE
|
||||||
Nothing
|
Nothing
|
||||||
[ NoBindS guardExp
|
[ NoBindS guardExp
|
||||||
@ -321,7 +342,7 @@ mkBoltzmannSampler sys = do
|
|||||||
-- the corresponding system using @mkBoltzmannSampler@. Default constructor
|
-- the corresponding system using @mkBoltzmannSampler@. Default constructor
|
||||||
-- weights are used (see @mkDefWeights@). No custom constructor frequencies are
|
-- weights are used (see @mkDefWeights@). No custom constructor frequencies are
|
||||||
-- assumed.
|
-- assumed.
|
||||||
mkDefBoltzmannSampler :: Name -> Int -> Q [Dec]
|
mkDefBoltzmannSampler :: Name -> MeanSize -> Q [Dec]
|
||||||
mkDefBoltzmannSampler typ meanSize = do
|
mkDefBoltzmannSampler typ meanSize = do
|
||||||
defWeights <- mkDefWeights' typ
|
defWeights <- mkDefWeights' typ
|
||||||
mkBoltzmannSampler $
|
mkBoltzmannSampler $
|
||||||
|
Loading…
Reference in New Issue
Block a user