Include Distribution in the BitOracle module

This commit is contained in:
Maciej Bendkowski 2022-03-26 19:11:25 +01:00
parent 882b55f40b
commit 64cc3147e0
8 changed files with 39 additions and 47 deletions

View File

@ -55,7 +55,6 @@ library
library generic-boltzmann-brain-internal
exposed-modules:
Data.Boltzmann.BitOracle
Data.Boltzmann.Distribution
Data.Boltzmann.Sampler
Data.Boltzmann.Sampler.TH
Data.Boltzmann.System
@ -172,7 +171,7 @@ test-suite generic-boltzmann-brain-test
type: exitcode-stdio-1.0
main-is: Spec.hs
other-modules:
Test.Unit.Distribution
Test.Unit.BitOracle
Paths_generic_boltzmann_brain
hs-source-dirs:
test

View File

@ -1,3 +1,5 @@
{-# LANGUAGE TemplateHaskell #-}
-- |
-- Module : Data.Boltzmann.BitOracle
-- Description :
@ -12,6 +14,8 @@ module Data.Boltzmann.BitOracle (
EvalIO (..),
eval,
getBit,
Distribution (..),
choice,
) where
import Control.Monad.Trans.State.Strict (
@ -23,10 +27,13 @@ import Control.Monad.Trans.State.Strict (
)
import Data.Bits (Bits (testBit))
import Data.Vector (Vector, null, (!))
import Data.Word (Word32)
import Instances.TH.Lift ()
import Language.Haskell.TH.Lift (deriveLift)
import System.Random (Random (random), RandomGen, StdGen, getStdGen)
import System.Random.SplitMix (SMGen, initSMGen)
import Prelude hiding (null)
-- | Buffered random bit oracle.
data Oracle g = Oracle
@ -94,3 +101,27 @@ instance EvalIO SMGen where
instance EvalIO StdGen where
{-# INLINE evalIO #-}
evalIO m = eval m <$> getStdGen
newtype Distribution = Distribution {unDistribution :: Vector Int}
deriving stock (Show)
deriveLift ''Distribution
-- |
-- Given a compact discrete distribution generating tree (in vector form)
-- computes a discrete random variable following that distribution.
choice :: RandomGen g => Distribution -> Discrete g
choice enc
| null (unDistribution enc) = pure 0
| otherwise = choice' enc 0
{-# SPECIALIZE choice :: Distribution -> Discrete SMGen #-}
{-# SPECIALIZE choice :: Distribution -> Discrete StdGen #-}
choice' :: RandomGen g => Distribution -> Int -> Discrete g
choice' enc c = do
h <- getBit
let b = fromEnum h
let c' = unDistribution enc ! (c + b)
if unDistribution enc ! c' < 0
then pure $ -(1 + unDistribution enc ! c')
else choice' enc c'

View File

@ -1,37 +0,0 @@
{-# LANGUAGE TemplateHaskell #-}
module Data.Boltzmann.Distribution (
Distribution (..),
choice,
) where
import Data.Boltzmann.BitOracle (Discrete, getBit)
import Data.Vector (Vector, null, (!))
import Language.Haskell.TH.Lift (deriveLift)
import System.Random (RandomGen, StdGen)
import System.Random.SplitMix (SMGen)
import Prelude hiding (null)
newtype Distribution = Distribution {unDistribution :: Vector Int}
deriving stock (Show)
deriveLift ''Distribution
-- |
-- Given a compact discrete distribution generating tree (in vector form)
-- computes a discrete random variable following that distribution.
choice :: RandomGen g => Distribution -> Discrete g
choice enc
| null (unDistribution enc) = pure 0
| otherwise = choice' enc 0
{-# SPECIALIZE choice :: Distribution -> Discrete SMGen #-}
{-# SPECIALIZE choice :: Distribution -> Discrete StdGen #-}
choice' :: RandomGen g => Distribution -> Int -> Discrete g
choice' enc c = do
h <- getBit
let b = fromEnum h
let c' = unDistribution enc ! (c + b)
if unDistribution enc ! c' < 0
then pure $ -(1 + unDistribution enc ! c')
else choice' enc c'

View File

@ -21,7 +21,7 @@ import Data.Map (Map)
import qualified Data.Map.Strict as Map
import Control.Monad (forM)
import Data.Boltzmann.Distribution (Distribution)
import Data.Boltzmann.BitOracle (Distribution)
import Data.Boltzmann.System (
ConstructorWeights (MkConstructorWeights, unConstructorWeights),
Distributions (Distributions, listTypeDdgs, regTypeDdgs),

View File

@ -21,7 +21,7 @@ import Language.Haskell.TH.Syntax (
)
import Control.Monad (foldM, forM, replicateM, unless)
import Data.Boltzmann.Distribution (Distribution (Distribution))
import Data.Boltzmann.BitOracle (Distribution (Distribution))
import qualified Data.Map as Map
import Data.Map.Strict (Map)
import Data.Maybe (fromJust, fromMaybe)

View File

@ -12,7 +12,7 @@ import qualified Data.Set as Set
import Control.Monad (forM, guard)
import Control.Monad.Trans (MonadTrans (lift))
import Control.Monad.Trans.Reader (ReaderT (runReaderT), asks)
import Data.Boltzmann.Distribution (Distribution, choice)
import Data.Boltzmann.BitOracle (Distribution, choice)
import Data.Boltzmann.Sampler.TH (
ConstructorName (MkConstructorName),
ListTypeDistributions (unListTypeDistributions),

View File

@ -1,5 +1,5 @@
import Test.Tasty (TestTree, defaultMain, testGroup)
import qualified Test.Unit.Distribution as Distribution
import qualified Test.Unit.BitOracle as BitOracle
main :: IO ()
main = defaultMain tests
@ -9,4 +9,4 @@ tests = testGroup "Unit tests" unitTests
unitTests :: [TestTree]
unitTests =
[Distribution.unitTests]
[BitOracle.unitTests]

View File

@ -1,8 +1,7 @@
module Test.Unit.Distribution (unitTests) where
module Test.Unit.BitOracle (unitTests) where
import Control.Monad (replicateM)
import Data.Boltzmann.Distribution (Distribution (..), choice)
import Data.Boltzmann.BitOracle (evalIO)
import Data.Boltzmann.BitOracle (Distribution (..), choice, evalIO)
import qualified Data.Map as Map
import Data.Vector (fromList)
import System.Random.SplitMix (SMGen)