From 64cc3147e03abc4eec72439f938879a97c556433 Mon Sep 17 00:00:00 2001 From: Maciej Bendkowski Date: Sat, 26 Mar 2022 19:11:25 +0100 Subject: [PATCH] Include `Distribution` in the `BitOracle` module --- generic-boltzmann-brain.cabal | 3 +- internal/Data/Boltzmann/BitOracle.hs | 31 ++++++++++++++++ internal/Data/Boltzmann/Distribution.hs | 37 ------------------- internal/Data/Boltzmann/Sampler/TH.hs | 2 +- internal/Data/Boltzmann/System.hs | 2 +- internal/Data/Boltzmann/System/TH.hs | 2 +- test/Spec.hs | 4 +- .../Unit/{Distribution.hs => BitOracle.hs} | 5 +-- 8 files changed, 39 insertions(+), 47 deletions(-) delete mode 100644 internal/Data/Boltzmann/Distribution.hs rename test/Test/Unit/{Distribution.hs => BitOracle.hs} (95%) diff --git a/generic-boltzmann-brain.cabal b/generic-boltzmann-brain.cabal index 5388772..641e0fe 100644 --- a/generic-boltzmann-brain.cabal +++ b/generic-boltzmann-brain.cabal @@ -55,7 +55,6 @@ library library generic-boltzmann-brain-internal exposed-modules: Data.Boltzmann.BitOracle - Data.Boltzmann.Distribution Data.Boltzmann.Sampler Data.Boltzmann.Sampler.TH Data.Boltzmann.System @@ -172,7 +171,7 @@ test-suite generic-boltzmann-brain-test type: exitcode-stdio-1.0 main-is: Spec.hs other-modules: - Test.Unit.Distribution + Test.Unit.BitOracle Paths_generic_boltzmann_brain hs-source-dirs: test diff --git a/internal/Data/Boltzmann/BitOracle.hs b/internal/Data/Boltzmann/BitOracle.hs index fc2b080..c26a247 100644 --- a/internal/Data/Boltzmann/BitOracle.hs +++ b/internal/Data/Boltzmann/BitOracle.hs @@ -1,3 +1,5 @@ +{-# LANGUAGE TemplateHaskell #-} + -- | -- Module : Data.Boltzmann.BitOracle -- Description : @@ -12,6 +14,8 @@ module Data.Boltzmann.BitOracle ( EvalIO (..), eval, getBit, + Distribution (..), + choice, ) where import Control.Monad.Trans.State.Strict ( @@ -23,10 +27,13 @@ import Control.Monad.Trans.State.Strict ( ) import Data.Bits (Bits (testBit)) +import Data.Vector (Vector, null, (!)) import Data.Word (Word32) import Instances.TH.Lift () +import Language.Haskell.TH.Lift (deriveLift) import System.Random (Random (random), RandomGen, StdGen, getStdGen) import System.Random.SplitMix (SMGen, initSMGen) +import Prelude hiding (null) -- | Buffered random bit oracle. data Oracle g = Oracle @@ -94,3 +101,27 @@ instance EvalIO SMGen where instance EvalIO StdGen where {-# INLINE evalIO #-} evalIO m = eval m <$> getStdGen + +newtype Distribution = Distribution {unDistribution :: Vector Int} + deriving stock (Show) + +deriveLift ''Distribution + +-- | +-- Given a compact discrete distribution generating tree (in vector form) +-- computes a discrete random variable following that distribution. +choice :: RandomGen g => Distribution -> Discrete g +choice enc + | null (unDistribution enc) = pure 0 + | otherwise = choice' enc 0 +{-# SPECIALIZE choice :: Distribution -> Discrete SMGen #-} +{-# SPECIALIZE choice :: Distribution -> Discrete StdGen #-} + +choice' :: RandomGen g => Distribution -> Int -> Discrete g +choice' enc c = do + h <- getBit + let b = fromEnum h + let c' = unDistribution enc ! (c + b) + if unDistribution enc ! c' < 0 + then pure $ -(1 + unDistribution enc ! c') + else choice' enc c' diff --git a/internal/Data/Boltzmann/Distribution.hs b/internal/Data/Boltzmann/Distribution.hs deleted file mode 100644 index 87a7fc6..0000000 --- a/internal/Data/Boltzmann/Distribution.hs +++ /dev/null @@ -1,37 +0,0 @@ -{-# LANGUAGE TemplateHaskell #-} - -module Data.Boltzmann.Distribution ( - Distribution (..), - choice, -) where - -import Data.Boltzmann.BitOracle (Discrete, getBit) -import Data.Vector (Vector, null, (!)) -import Language.Haskell.TH.Lift (deriveLift) -import System.Random (RandomGen, StdGen) -import System.Random.SplitMix (SMGen) -import Prelude hiding (null) - -newtype Distribution = Distribution {unDistribution :: Vector Int} - deriving stock (Show) - -deriveLift ''Distribution - --- | --- Given a compact discrete distribution generating tree (in vector form) --- computes a discrete random variable following that distribution. -choice :: RandomGen g => Distribution -> Discrete g -choice enc - | null (unDistribution enc) = pure 0 - | otherwise = choice' enc 0 -{-# SPECIALIZE choice :: Distribution -> Discrete SMGen #-} -{-# SPECIALIZE choice :: Distribution -> Discrete StdGen #-} - -choice' :: RandomGen g => Distribution -> Int -> Discrete g -choice' enc c = do - h <- getBit - let b = fromEnum h - let c' = unDistribution enc ! (c + b) - if unDistribution enc ! c' < 0 - then pure $ -(1 + unDistribution enc ! c') - else choice' enc c' diff --git a/internal/Data/Boltzmann/Sampler/TH.hs b/internal/Data/Boltzmann/Sampler/TH.hs index f1645de..1ccd89f 100644 --- a/internal/Data/Boltzmann/Sampler/TH.hs +++ b/internal/Data/Boltzmann/Sampler/TH.hs @@ -21,7 +21,7 @@ import Data.Map (Map) import qualified Data.Map.Strict as Map import Control.Monad (forM) -import Data.Boltzmann.Distribution (Distribution) +import Data.Boltzmann.BitOracle (Distribution) import Data.Boltzmann.System ( ConstructorWeights (MkConstructorWeights, unConstructorWeights), Distributions (Distributions, listTypeDdgs, regTypeDdgs), diff --git a/internal/Data/Boltzmann/System.hs b/internal/Data/Boltzmann/System.hs index d918d7e..e5d9941 100644 --- a/internal/Data/Boltzmann/System.hs +++ b/internal/Data/Boltzmann/System.hs @@ -21,7 +21,7 @@ import Language.Haskell.TH.Syntax ( ) import Control.Monad (foldM, forM, replicateM, unless) -import Data.Boltzmann.Distribution (Distribution (Distribution)) +import Data.Boltzmann.BitOracle (Distribution (Distribution)) import qualified Data.Map as Map import Data.Map.Strict (Map) import Data.Maybe (fromJust, fromMaybe) diff --git a/internal/Data/Boltzmann/System/TH.hs b/internal/Data/Boltzmann/System/TH.hs index 3384525..902f079 100644 --- a/internal/Data/Boltzmann/System/TH.hs +++ b/internal/Data/Boltzmann/System/TH.hs @@ -12,7 +12,7 @@ import qualified Data.Set as Set import Control.Monad (forM, guard) import Control.Monad.Trans (MonadTrans (lift)) import Control.Monad.Trans.Reader (ReaderT (runReaderT), asks) -import Data.Boltzmann.Distribution (Distribution, choice) +import Data.Boltzmann.BitOracle (Distribution, choice) import Data.Boltzmann.Sampler.TH ( ConstructorName (MkConstructorName), ListTypeDistributions (unListTypeDistributions), diff --git a/test/Spec.hs b/test/Spec.hs index 098e32c..625f4e4 100644 --- a/test/Spec.hs +++ b/test/Spec.hs @@ -1,5 +1,5 @@ import Test.Tasty (TestTree, defaultMain, testGroup) -import qualified Test.Unit.Distribution as Distribution +import qualified Test.Unit.BitOracle as BitOracle main :: IO () main = defaultMain tests @@ -9,4 +9,4 @@ tests = testGroup "Unit tests" unitTests unitTests :: [TestTree] unitTests = - [Distribution.unitTests] + [BitOracle.unitTests] diff --git a/test/Test/Unit/Distribution.hs b/test/Test/Unit/BitOracle.hs similarity index 95% rename from test/Test/Unit/Distribution.hs rename to test/Test/Unit/BitOracle.hs index 0c684fc..873ea22 100644 --- a/test/Test/Unit/Distribution.hs +++ b/test/Test/Unit/BitOracle.hs @@ -1,8 +1,7 @@ -module Test.Unit.Distribution (unitTests) where +module Test.Unit.BitOracle (unitTests) where import Control.Monad (replicateM) -import Data.Boltzmann.Distribution (Distribution (..), choice) -import Data.Boltzmann.BitOracle (evalIO) +import Data.Boltzmann.BitOracle (Distribution (..), choice, evalIO) import qualified Data.Map as Map import Data.Vector (fromList) import System.Random.SplitMix (SMGen)