Get rid of Samplable

This commit is contained in:
Maciej Bendkowski 2022-03-20 19:12:53 +01:00
parent 20607c1167
commit b74fb6e5fe
7 changed files with 29 additions and 34 deletions

View File

@ -26,7 +26,7 @@ source-repository head
library library
exposed-modules: exposed-modules:
Data.Boltzmann.Samplable Data.Boltzmann.Distribution
Data.Boltzmann.Sampler Data.Boltzmann.Sampler
Data.Boltzmann.Sampler.TH Data.Boltzmann.Sampler.TH
Data.Boltzmann.System Data.Boltzmann.System
@ -140,7 +140,7 @@ test-suite generic-boltzmann-brain-test
type: exitcode-stdio-1.0 type: exitcode-stdio-1.0
main-is: Spec.hs main-is: Spec.hs
other-modules: other-modules:
Test.Unit.Samplable Test.Unit.Distribution
Paths_generic_boltzmann_brain Paths_generic_boltzmann_brain
hs-source-dirs: hs-source-dirs:
test test

View File

@ -1,7 +1,6 @@
{-# LANGUAGE TemplateHaskell #-} {-# LANGUAGE TemplateHaskell #-}
module Data.Boltzmann.Samplable ( module Data.Boltzmann.Distribution (
Samplable (..),
Distribution (..), Distribution (..),
choice, choice,
) where ) where
@ -13,11 +12,7 @@ import System.Random (RandomGen, StdGen)
import System.Random.SplitMix (SMGen) import System.Random.SplitMix (SMGen)
import Prelude hiding (null) import Prelude hiding (null)
class Samplable a where newtype Distribution = Distribution {unDistribution :: Vector Int}
distribution :: Distribution a
weight :: a -> Int
newtype Distribution a = Distribution {unDistribution :: Vector Int}
deriving stock (Show) deriving stock (Show)
deriveLift ''Distribution deriveLift ''Distribution
@ -25,18 +20,18 @@ deriveLift ''Distribution
-- | -- |
-- Given a compact discrete distribution generating tree (in vector form) -- Given a compact discrete distribution generating tree (in vector form)
-- computes a discrete random variable following that distribution. -- computes a discrete random variable following that distribution.
choice :: RandomGen g => Distribution a -> Discrete g choice :: RandomGen g => Distribution -> Discrete g
choice enc choice enc
| null (unDistribution enc) = return 0 | null (unDistribution enc) = pure 0
| otherwise = choice' enc 0 | otherwise = choice' enc 0
{-# SPECIALIZE choice :: Distribution a -> Discrete SMGen #-} {-# SPECIALIZE choice :: Distribution -> Discrete SMGen #-}
{-# SPECIALIZE choice :: Distribution a -> Discrete StdGen #-} {-# SPECIALIZE choice :: Distribution -> Discrete StdGen #-}
choice' :: RandomGen g => Distribution a -> Int -> Discrete g choice' :: RandomGen g => Distribution -> Int -> Discrete g
choice' enc c = do choice' enc c = do
h <- getBit h <- getBit
let b = fromEnum h let b = fromEnum h
let c' = unDistribution enc ! (c + b) let c' = unDistribution enc ! (c + b)
if unDistribution enc ! c' < 0 if unDistribution enc ! c' < 0
then return $ -(1 + unDistribution enc ! c') then pure $ -(1 + unDistribution enc ! c')
else choice' enc c' else choice' enc c'

View File

@ -19,7 +19,7 @@ import Data.Map (Map)
import qualified Data.Map.Strict as Map import qualified Data.Map.Strict as Map
import Control.Monad (forM) import Control.Monad (forM)
import Data.Boltzmann.Samplable (Distribution) import Data.Boltzmann.Distribution (Distribution)
import Data.Boltzmann.System ( import Data.Boltzmann.System (
Distributions (Distributions, listTypeDdgs, regTypeDdgs), Distributions (Distributions, listTypeDdgs, regTypeDdgs),
System (targetType, weights), System (targetType, weights),
@ -72,7 +72,7 @@ idResolver :: SynonymResolver
idResolver = MkSynonymResolver $ pure . coerce idResolver = MkSynonymResolver $ pure . coerce
newtype TypeDistributions a = MkTypeDistributions newtype TypeDistributions a = MkTypeDistributions
{ unTypeDistributions :: TypeName -> Q (Distribution a) { unTypeDistributions :: TypeName -> Q Distribution
} }
mkTypeDistributions :: Distributions a -> TypeDistributions a mkTypeDistributions :: Distributions a -> TypeDistributions a
@ -85,7 +85,7 @@ mkTypeDistributions Distributions {regTypeDdgs} =
"Missing type constructor distribution for " ++ show n "Missing type constructor distribution for " ++ show n
newtype ListTypeDistributions a = MkListTypeDistributions newtype ListTypeDistributions a = MkListTypeDistributions
{ unListTypeDistributions :: TypeName -> Q (Distribution a) { unListTypeDistributions :: TypeName -> Q Distribution
} }
mkListTypeDistributions :: Distributions a -> ListTypeDistributions a mkListTypeDistributions :: Distributions a -> ListTypeDistributions a

View File

@ -10,7 +10,7 @@ module Data.Boltzmann.System (
import Language.Haskell.TH.Syntax (Name, Type (AppT, ConT, ListT)) import Language.Haskell.TH.Syntax (Name, Type (AppT, ConT, ListT))
import Control.Monad (foldM, forM, replicateM) import Control.Monad (foldM, forM, replicateM)
import Data.Boltzmann.Samplable (Distribution (Distribution)) import Data.Boltzmann.Distribution (Distribution (Distribution))
import qualified Data.Map as Map import qualified Data.Map as Map
import Data.Map.Strict (Map) import Data.Map.Strict (Map)
import Data.Maybe (fromJust, fromMaybe) import Data.Maybe (fromJust, fromMaybe)
@ -122,7 +122,7 @@ mkMarkingVariables sys = do
mapM mapM
( \(cons, freq) -> do ( \(cons, freq) -> do
x <- variable' $ fromIntegral freq x <- variable' $ fromIntegral freq
return (cons, x) pure (cons, x)
) )
(frequencies sys) (frequencies sys)
@ -172,8 +172,8 @@ defaults Nothing = 1
defaults (Just (Let x)) = x defaults (Just (Let x)) = x
data Distributions a = Distributions data Distributions a = Distributions
{ regTypeDdgs :: Map Name (Distribution a) { regTypeDdgs :: Map Name Distribution
, listTypeDdgs :: Map Name (Distribution a) , listTypeDdgs :: Map Name Distribution
} }
deriving stock (Show) deriving stock (Show)

View File

@ -9,7 +9,7 @@ import qualified Data.Set as Set
import Control.Monad (forM, guard) import Control.Monad (forM, guard)
import Control.Monad.Trans (MonadTrans (lift)) import Control.Monad.Trans (MonadTrans (lift))
import Control.Monad.Trans.Reader (ReaderT (runReaderT), asks) import Control.Monad.Trans.Reader (ReaderT (runReaderT), asks)
import Data.Boltzmann.Samplable (Distribution, choice) import Data.Boltzmann.Distribution (Distribution, choice)
import Data.Boltzmann.Sampler.TH ( import Data.Boltzmann.Sampler.TH (
ConstructorName (MkConstructorName), ConstructorName (MkConstructorName),
ListTypeDistributions (unListTypeDistributions), ListTypeDistributions (unListTypeDistributions),
@ -66,7 +66,7 @@ findTypeSyn = \case
Plain tn -> getSynonym tn <&> ConT . coerce Plain tn -> getSynonym tn <&> ConT . coerce
List tn -> getSynonym tn <&> AppT ListT . ConT . coerce List tn -> getSynonym tn <&> AppT ListT . ConT . coerce
getDistribution :: TypeVariant -> SamplerGen (Distribution ()) getDistribution :: TypeVariant -> SamplerGen Distribution
getDistribution = \case getDistribution = \case
Plain tn -> do Plain tn -> do
distributions <- asks typeDistributions distributions <- asks typeDistributions

View File

@ -1,5 +1,5 @@
import Test.Tasty (TestTree, defaultMain, testGroup) import Test.Tasty (TestTree, defaultMain, testGroup)
import qualified Test.Unit.Samplable as Samplable import qualified Test.Unit.Distribution as Distribution
main :: IO () main :: IO ()
main = defaultMain tests main = defaultMain tests
@ -9,4 +9,4 @@ tests = testGroup "Unit tests" unitTests
unitTests :: [TestTree] unitTests :: [TestTree]
unitTests = unitTests =
[Samplable.unitTests] [Distribution.unitTests]

View File

@ -1,7 +1,7 @@
module Test.Unit.Samplable (unitTests) where module Test.Unit.Distribution (unitTests) where
import Control.Monad (replicateM) import Control.Monad (replicateM)
import Data.Boltzmann.Samplable (Distribution (..), choice) import Data.Boltzmann.Distribution (Distribution (..), choice)
import Data.BuffonMachine (evalIO) import Data.BuffonMachine (evalIO)
import qualified Data.Map as Map import qualified Data.Map as Map
import Data.Vector (fromList) import Data.Vector (fromList)
@ -46,25 +46,25 @@ choiceTests =
] ]
-- [1/2, 1/2] -- [1/2, 1/2]
distributionA :: Distribution a distributionA :: Distribution
distributionA = Distribution $ fromList [2, 3, -2, -1] distributionA = Distribution $ fromList [2, 3, -2, -1]
-- [1/3, 1/3, 1/3] -- [1/3, 1/3, 1/3]
distributionB :: Distribution a distributionB :: Distribution
distributionB = Distribution $ fromList [2, 138, 4, 137, 6, 133, 8, 132, 10, 128, 12, 127, 14, 123, 16, 122, 18, 118, 20, 117, 22, 113, 24, 112, 26, 108, 28, 107, 30, 103, 32, 102, 34, 98, 36, 97, 38, 93, 40, 92, 42, 88, 44, 87, 46, 83, 48, 82, 50, 78, 52, 77, 54, 73, 56, 72, 58, 68, 60, 67, 62, 66, 64, 65, -2, -1, -3, -3, 70, 71, -2, -1, -3, 75, 76, -2, -1, -3, 80, 81, -2, -1, -3, 85, 86, -2, -1, -3, 90, 91, -2, -1, -3, 95, 96, -2, -1, -3, 100, 101, -2, -1, -3, 105, 106, -2, -1, -3, 110, 111, -2, -1, -3, 115, 116, -2, -1, -3, 120, 121, -2, -1, -3, 125, 126, -2, -1, -3, 130, 131, -2, -1, -3, 135, 136, -2, -1, -3, 140, 141, -2, -1] distributionB = Distribution $ fromList [2, 138, 4, 137, 6, 133, 8, 132, 10, 128, 12, 127, 14, 123, 16, 122, 18, 118, 20, 117, 22, 113, 24, 112, 26, 108, 28, 107, 30, 103, 32, 102, 34, 98, 36, 97, 38, 93, 40, 92, 42, 88, 44, 87, 46, 83, 48, 82, 50, 78, 52, 77, 54, 73, 56, 72, 58, 68, 60, 67, 62, 66, 64, 65, -2, -1, -3, -3, 70, 71, -2, -1, -3, 75, 76, -2, -1, -3, 80, 81, -2, -1, -3, 85, 86, -2, -1, -3, 90, 91, -2, -1, -3, 95, 96, -2, -1, -3, 100, 101, -2, -1, -3, 105, 106, -2, -1, -3, 110, 111, -2, -1, -3, 115, 116, -2, -1, -3, 120, 121, -2, -1, -3, 125, 126, -2, -1, -3, 130, 131, -2, -1, -3, 135, 136, -2, -1, -3, 140, 141, -2, -1]
-- [1/7, 4/7, 2/7] -- [1/7, 4/7, 2/7]
distributionC :: Distribution a distributionC :: Distribution
distributionC = Distribution $ fromList [2, 96, 4, 95, 6, 94, 8, 93, 10, 92, 12, 91, 14, 90, 16, 89, 18, 88, 20, 87, 22, 86, 24, 85, 26, 84, 28, 83, 30, 82, 32, 81, 34, 80, 36, 79, 38, 78, 40, 77, 42, 76, 44, 75, 46, 74, 48, 73, 50, 72, 52, 71, 54, 70, 56, 69, 58, 68, 60, 67, 62, 66, 64, 65, -3, -1, -2, -1, -3, -2, -1, -3, -2, -1, -3, -2, -1, -3, -2, -1, -3, -2, -1, -3, -2, -1, -3, -2, -1, -3, -2, -1, -3, -2, -1, -3, -2] distributionC = Distribution $ fromList [2, 96, 4, 95, 6, 94, 8, 93, 10, 92, 12, 91, 14, 90, 16, 89, 18, 88, 20, 87, 22, 86, 24, 85, 26, 84, 28, 83, 30, 82, 32, 81, 34, 80, 36, 79, 38, 78, 40, 77, 42, 76, 44, 75, 46, 74, 48, 73, 50, 72, 52, 71, 54, 70, 56, 69, 58, 68, 60, 67, 62, 66, 64, 65, -3, -1, -2, -1, -3, -2, -1, -3, -2, -1, -3, -2, -1, -3, -2, -1, -3, -2, -1, -3, -2, -1, -3, -2, -1, -3, -2, -1, -3, -2, -1, -3, -2]
distributionD :: Distribution a distributionD :: Distribution
distributionD = Distribution $ fromList [] distributionD = Distribution $ fromList []
choiceTest :: Distribution a -> Int -> IO [(Int, Double)] choiceTest :: Distribution -> Int -> IO [(Int, Double)]
choiceTest dist n = evalIO $ do choiceTest dist n = evalIO $ do
sam <- replicateM n (choice @SMGen dist) sam <- replicateM n (choice @SMGen dist)
let groups = frequency sam let groups = frequency sam
return $ map (\(k, s) -> (k, fromIntegral s / fromIntegral n)) groups pure $ map (\(k, s) -> (k, fromIntegral s / fromIntegral n)) groups
frequency :: (Ord a) => [a] -> [(a, Int)] frequency :: (Ord a) => [a] -> [(a, Int)]
frequency xs = Map.toList (Map.fromListWith (+) [(x, 1) | x <- xs]) frequency xs = Map.toList (Map.fromListWith (+) [(x, 1) | x <- xs])