diff --git a/test/Test/Sampler.hs b/test/Test/Sampler.hs index cdc3bb7..6d61ce1 100644 --- a/test/Test/Sampler.hs +++ b/test/Test/Sampler.hs @@ -51,16 +51,28 @@ tests = (obsSize, obsAbs) <- runLambdaSampler 1_000 close obsSize 10_000 0.2 -- just to be sure close obsAbs 4_000 0.2 -- just to be sure + , testCase "Binary lambda sampler has the correct output distribution" $ do + (obsSize, obsAbs) <- runBinLambdaSampler 1_000 + close obsSize 6_000 0.2 -- just to be sure + close obsAbs 2_340 0.2 -- just to be sure ] lambdaSampler :: BuffonMachine SMGen Lambda lambdaSampler = rejectionSampler (MkLowerBound 8_000) (MkUpperBound 12_000) +binLambdaSampler :: BuffonMachine SMGen Lambda +binLambdaSampler = rejectionSampler (MkLowerBound 5_000) (MkUpperBound 6_400) + runLambdaSampler :: Int -> IO (Double, Double) runLambdaSampler n = evalIO $ do sam <- replicateM n lambdaSampler pure $ statistics $ (\t -> (size t, abstractions t)) <$> sam +runBinLambdaSampler :: Int -> IO (Double, Double) +runBinLambdaSampler n = evalIO $ do + sam <- replicateM n binLambdaSampler + pure $ statistics $ (\t -> (size t, abstractions t)) <$> sam + statistics :: [(Int, Int)] -> (Double, Double) statistics xs = (average $ fst <$> xs, average $ snd <$> xs) diff --git a/test/Test/Samplers/Lambda.hs b/test/Test/Samplers/Lambda.hs index dece87a..b1bc88d 100644 --- a/test/Test/Samplers/Lambda.hs +++ b/test/Test/Samplers/Lambda.hs @@ -85,7 +85,7 @@ mkBoltzmannSampler System { targetType = ''BinLambda , meanSize = 6_000 - , frequencies = def + , frequencies = ('Abs, 2340) <:> def , weights = ('Index, 0) <:> ('App, 2)