Test sampler outcome distribution

This commit is contained in:
Maciej Bendkowski 2022-03-29 18:22:12 +02:00
parent 1ea9de8136
commit 2572c62fca
4 changed files with 59 additions and 12 deletions

View File

@ -10,10 +10,9 @@ import Test.Tasty (
testGroup,
)
import Test.Tasty.HUnit (
Assertion,
assertBool,
testCase,
)
import Test.Utils (almostEqual)
tests :: TestTree
tests =
@ -67,7 +66,3 @@ choiceTest dist n = evalIO $ do
frequency :: (Ord a) => [a] -> [(a, Int)]
frequency xs = Map.toList (Map.fromListWith (+) [(x, 1) | x <- xs])
almostEqual :: (Show a, Ord a, Fractional a) => a -> a -> Assertion
almostEqual a b =
assertBool ("Was " ++ show a ++ " " ++ show b) $ abs (a - b) < 0.01

View File

@ -1,12 +1,27 @@
module Test.Sampler (tests) where
import Test.Samplers.BinTree (BinTree)
import Test.Samplers.Lambda (BinLambda, Lambda)
import Test.Samplers.Tree (Tree)
import Test.Utils (Size (size))
import Data.Boltzmann (
BuffonMachine,
EvalIO (evalIO),
LowerBound (..),
UpperBound (..),
rejectionSampler,
)
import Test.Samplers.BinTree (BinTree)
import Test.Samplers.Lambda (BinLambda, Lambda, abstractions)
import Test.Samplers.Tree (Tree)
import System.Random.SplitMix (SMGen)
import Control.Monad (replicateM)
import Data.List (genericLength)
import Test.Tasty (TestTree, testGroup)
import Test.Tasty.HUnit (
testCase,
)
import Test.Tasty.QuickCheck as QC (testProperty)
import Test.Utils (Size (size), close)
tests :: TestTree
tests =
@ -28,4 +43,22 @@ tests =
\term ->
let s = size @BinLambda term
in 5_000 <= s && s <= 6_400
, testCase "Lambda sampler has the correct output distribution" $ do
(obsSize, obsAbs) <- runLambdaSampler 1_000
close obsSize 10_000 0.2 -- just to be sure
close obsAbs 4_000 0.2 -- just to be sure
]
lambdaSampler :: BuffonMachine SMGen Lambda
lambdaSampler = rejectionSampler (MkLowerBound 8_000) (MkUpperBound 12_000)
runLambdaSampler :: Int -> IO (Double, Double)
runLambdaSampler n = evalIO $ do
sam <- replicateM n lambdaSampler
pure $ statistics $ (\t -> (size t, abstractions t)) <$> sam
statistics :: [(Int, Int)] -> (Double, Double)
statistics xs = (average $ fst <$> xs, average $ snd <$> xs)
average :: (Real a, Fractional b) => [a] -> b
average xs = realToFrac (sum xs) / genericLength xs

View File

@ -5,6 +5,7 @@ module Test.Samplers.Lambda (
DeBruijn (..),
Lambda (..),
BinLambda (..),
abstractions,
) where
import Data.Boltzmann (
@ -20,7 +21,7 @@ import Data.Boltzmann (
import Data.Default (def)
import GHC.Generics (Generic)
import Test.QuickCheck (Arbitrary (arbitrary, shrink))
import Test.Utils (Size(size))
import Test.Utils (Size (size))
data DeBruijn
= Z
@ -37,6 +38,12 @@ data Lambda
| Abs Lambda
deriving (Generic, Show)
abstractions :: Lambda -> Int
abstractions = \case
Index _ -> 0
App lt rt -> abstractions lt + abstractions rt
Abs t -> 1 + abstractions t
mkBoltzmannSampler
System
{ targetType = ''Lambda

View File

@ -1,5 +1,17 @@
module Test.Utils (Size(..)) where
module Test.Utils (Size (..), almostEqual, close) where
import Test.Tasty.HUnit (Assertion, assertBool)
-- | Objects with size.
class Size a where
size :: a -> Int
almostEqual :: (Show a, Ord a, Fractional a) => a -> a -> Assertion
almostEqual a b =
assertBool ("Was " ++ show a ++ " expected almost " ++ show b) $
abs (a - b) < 0.01
close :: (Show a, Ord a, Num a) => a -> a -> a -> Assertion
close a b eps =
assertBool ("Was " ++ show a ++ " expected close to " ++ show b) $
(1 - eps) * b <= a && a <= (1 + eps) * b