Use a blinded forAll for proxy like generators.

This is a bit ironic, as the Typable constraint only exists for
7.10 printing in hedgehog, but the 7.10 proxies don't have typeable.
This commit is contained in:
Huw Campbell 2017-04-10 12:37:45 +10:00
parent adf731218c
commit 6c6e706e66
4 changed files with 15 additions and 18 deletions

View File

@ -61,13 +61,13 @@ genOpaqueOpaqueConvolution = do
Dict -> OpaqueConvolution <$> (genConvolution :: Jack (Convolution ch fl kr kc sr sc))
prop_conv_net_witness = property $
forAll genOpaqueOpaqueConvolution >>= \onet ->
blindForAll genOpaqueOpaqueConvolution >>= \onet ->
case onet of
(OpaqueConvolution ((Convolution _ _) :: Convolution channels filters kernelRows kernelCols strideRows strideCols)) -> success
prop_conv_net = property $
forAll genOpaqueOpaqueConvolution >>= \onet ->
blindForAll genOpaqueOpaqueConvolution >>= \onet ->
case onet of
(OpaqueConvolution (convLayer@(Convolution _ _) :: Convolution channels filters kernelRows kernelCols strideRows strideCols)) ->
let ok stride kernel = [extent | extent <- [(kernel + 1) .. 30 ], (extent - kernel) `mod` stride == 0]
@ -92,7 +92,7 @@ prop_conv_net = property $
, (unsafeCoerce (Dict :: Dict ()) :: Dict (((outRows - 1) * strideRows) ~ (inRows - kernelRows)))
, (unsafeCoerce (Dict :: Dict ()) :: Dict (((outCols - 1) * strideCols) ~ (inCols - kernelCols)))) of
(Dict, Dict, Dict, Dict) ->
forAll (S3D <$> uniformSample) >>= \(input :: S ('D3 inRows inCols channels)) ->
blindForAll (S3D <$> uniformSample) >>= \(input :: S ('D3 inRows inCols channels)) ->
let (tape, output :: S ('D3 outRows outCols filters)) = runForwards convLayer input
backed :: (Gradient (Convolution channels filters kernelRows kernelCols strideRows strideCols), S ('D3 inRows inCols channels))
= runBackwards convLayer tape output

View File

@ -17,9 +17,8 @@ import Grenade.Layers.FullyConnected
import Hedgehog
import Test.Jack.Hmatrix
import Test.Jack.Compat
import Test.Jack.Hmatrix
data OpaqueFullyConnected :: * where
OpaqueFullyConnected :: (KnownNat i, KnownNat o) => FullyConnected i o -> OpaqueFullyConnected
@ -43,8 +42,8 @@ genOpaqueFullyConnected = do
prop_fully_connected_forwards :: Property
prop_fully_connected_forwards = property $ do
OpaqueFullyConnected (fclayer :: FullyConnected i o) <- forAll genOpaqueFullyConnected
input :: S ('D1 i) <- forAll (S1D <$> randomVector)
OpaqueFullyConnected (fclayer :: FullyConnected i o) <- blindForAll genOpaqueFullyConnected
input :: S ('D1 i) <- blindForAll (S1D <$> randomVector)
let (tape, output :: S ('D1 o)) = runForwards fclayer input
backed :: (Gradient (FullyConnected i o), S ('D1 i))
= runBackwards fclayer tape output

View File

@ -9,9 +9,6 @@
{-# OPTIONS_GHC -fno-warn-missing-signatures #-}
module Test.Grenade.Layers.Nonlinear where
import Control.Monad.Trans.Class (MonadTrans(..))
import Data.Singletons
#if __GLASGOW_HASKELL__ < 800
@ -22,18 +19,13 @@ import Grenade
import GHC.TypeLits
import Hedgehog
import Hedgehog.Internal.Property ( Test (..) )
import Test.Jack.Compat
import Test.Jack.Hmatrix
import Test.Jack.TypeLits
import Numeric.LinearAlgebra.Static ( norm_Inf )
-- | Generates a random input for the test by running the provided generator.
--
blindForAll :: Monad m => Gen m a -> Test m a
blindForAll = Test . lift . lift
prop_sigmoid_grad :: Property
prop_sigmoid_grad = property $
blindForAll genShape >>= \case
@ -66,9 +58,9 @@ prop_tanh_grad = property $
prop_softmax_grad :: Property
prop_softmax_grad = property $
forAll genNat >>= \case
blindForAll genNat >>= \case
(SomeNat (_ :: Proxy s)) ->
forAll genOfShape >>= \(ds :: S ('D1 s)) ->
blindForAll genOfShape >>= \(ds :: S ('D1 s)) ->
let (tape, f :: S ('D1 s)) = runForwards Relu ds
((), ret :: S ('D1 s)) = runBackwards Relu tape (1 :: S ('D1 s))
(_, numer :: S ('D1 s)) = runForwards Relu (ds + 0.0001)

View File

@ -1,9 +1,12 @@
{-# LANGUAGE RankNTypes #-}
module Test.Jack.Compat where
import Control.Monad.Trans.Class (MonadTrans(..))
import Hedgehog
import qualified Hedgehog.Gen as Gen
import qualified Hedgehog.Range as Range
import Hedgehog.Internal.Property ( Test (..) )
type Jack x = forall m. Monad m => Gen m x
@ -14,3 +17,6 @@ type Jack x = forall m. Monad m => Gen m x
choose :: Integral a => a -> a -> Jack a
choose = Gen.integral ... Range.constant
-- | Generates a random input for the test by running the provided generator.
blindForAll :: Monad m => Gen m a -> Test m a
blindForAll = Test . lift . lift