Test that list synonym samplers respect size constraints

This commit is contained in:
Maciej Bendkowski 2022-03-30 18:36:16 +02:00
parent 2572c62fca
commit 50cd6352d6
3 changed files with 55 additions and 29 deletions

View File

@ -100,20 +100,6 @@ getWeight constr = do
weightResolver <- asks constructorWeight weightResolver <- asks constructorWeight
lift $ weightResolver `unWeightResolver` constr lift $ weightResolver `unWeightResolver` constr
mkCoerce :: TypeVariant -> SamplerGen Exp
mkCoerce tv = do
typSynonym <- findTypeSyn tv
let fromType = convert tv
toType = typSynonym
coerce' <- lift [|coerce|]
pure $ AppTypeE (AppTypeE coerce' fromType) toType
where
convert :: TypeVariant -> Type
convert = \case
Plain tn -> ConT $ coerce tn
List tn -> AppT ListT (ConT $ coerce tn)
toTypeVariant :: Type -> SamplerGen TypeVariant toTypeVariant :: Type -> SamplerGen TypeVariant
toTypeVariant (ConT tn) = pure . Plain $ coerce tn toTypeVariant (ConT tn) = pure . Plain $ coerce tn
toTypeVariant (AppT ListT (ConT tn)) = pure . List $ coerce tn toTypeVariant (AppT ListT (ConT tn)) = pure . List $ coerce tn
@ -160,17 +146,25 @@ mkCaseConstr = \case
caseMatches <- mapM (mkCaseMatch tv) constrGroup caseMatches <- mapM (mkCaseMatch tv) constrGroup
pure $ LamCaseE caseMatches pure $ LamCaseE caseMatches
tv -> do tv@(List tn) ->
coerceExp <- mkCoerce tv do
lift typSynonym <- findTypeSyn (Plain tn)
[| listTypSynonym <- findTypeSyn tv
\case
0 -> pure ([], 0) lift
1 -> do [|
(x, w) <- sample ub \case
(xs, ws) <- sample (ub - w) 0 -> pure ([], 0)
pure ($(pure coerceExp) (x : xs), w + ws) 1 -> do
|] (x, w) <- $(sampleExp typSynonym) ub
(xs, ws) <- $(sampleExp listTypSynonym) (ub - w)
pure ((x : xs), w + ws)
|]
sampleExp :: Type -> Q Exp
sampleExp t = do
sample' <- [|sample|]
pure $ AppTypeE sample' t
mkCaseMatch :: TypeVariant -> (ConstructorInfo, Integer) -> SamplerGen Match mkCaseMatch :: TypeVariant -> (ConstructorInfo, Integer) -> SamplerGen Match
mkCaseMatch tv (constr, idx) = do mkCaseMatch tv (constr, idx) = do

View File

@ -10,7 +10,7 @@ import Data.Boltzmann (
import Test.Samplers.BinTree (BinTree) import Test.Samplers.BinTree (BinTree)
import Test.Samplers.Lambda (BinLambda, Lambda, abstractions) import Test.Samplers.Lambda (BinLambda, Lambda, abstractions)
import Test.Samplers.Tree (Tree) import Test.Samplers.Tree (Tree, Tree')
import System.Random.SplitMix (SMGen) import System.Random.SplitMix (SMGen)
@ -35,6 +35,10 @@ tests =
\tree -> \tree ->
let s = size @Tree tree let s = size @Tree tree
in 1600 <= s && s <= 2400 in 1600 <= s && s <= 2400
, QC.testProperty "Tree' sampler respects size constraints" $
\tree ->
let s = size @Tree' tree
in 8500 <= s && s <= 11_150
, QC.testProperty "Lambda sampler respects size constraints" $ , QC.testProperty "Lambda sampler respects size constraints" $
\term -> \term ->
let s = size @Lambda term let s = size @Lambda term

View File

@ -1,18 +1,23 @@
{-# LANGUAGE DeriveGeneric #-} {-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE TemplateHaskell #-} {-# LANGUAGE TemplateHaskell #-}
module Test.Samplers.Tree (Tree (..)) where module Test.Samplers.Tree (Tree (..), Tree' (..)) where
import Data.Boltzmann ( import Data.Boltzmann (
BoltzmannSampler (..), BoltzmannSampler (..),
Constructable (..),
LowerBound (MkLowerBound), LowerBound (MkLowerBound),
System (..),
UpperBound (MkUpperBound), UpperBound (MkUpperBound),
hoistRejectionSampler, hoistRejectionSampler,
mkBoltzmannSampler,
mkDefBoltzmannSampler, mkDefBoltzmannSampler,
mkDefWeights,
) )
import Data.Default (def)
import GHC.Generics (Generic) import GHC.Generics (Generic)
import Test.QuickCheck (Arbitrary (arbitrary, shrink)) import Test.QuickCheck (Arbitrary (arbitrary, shrink))
import Test.Utils (Size(size)) import Test.Utils (Size (size))
data Tree = T [Tree] data Tree = T [Tree]
deriving (Generic, Show) deriving (Generic, Show)
@ -21,10 +26,33 @@ mkDefBoltzmannSampler ''Tree 2000
instance Size Tree where instance Size Tree where
size = \case size = \case
T ts -> 1 + sum (map size ts) T ts -> 1 + sum (size <$> ts)
instance Arbitrary Tree where instance Arbitrary Tree where
arbitrary = arbitrary =
hoistRejectionSampler $ hoistRejectionSampler $
const (MkLowerBound 1600, MkUpperBound 2400) const (MkLowerBound 1600, MkUpperBound 2400)
shrink = const [] shrink = const []
newtype Tree' = MkTree' Tree
deriving (Generic, Show)
instance Size Tree' where
size = \case
(MkTree' (T ts)) -> 2 + sum (size . MkTree' <$> ts)
mkBoltzmannSampler
System
{ targetType = ''Tree'
, meanSize = 10_000
, frequencies = def
, weights =
('T, 2)
<:> $(mkDefWeights ''Tree)
}
instance Arbitrary Tree' where
arbitrary =
hoistRejectionSampler $
const (MkLowerBound 8500, MkUpperBound 11_150)
shrink = const []