mirror of
https://github.com/HuwCampbell/grenade.git
synced 2024-11-22 06:55:13 +03:00
Use a blinded forAll for proxy like generators.
This is a bit ironic, as the Typable constraint only exists for 7.10 printing in hedgehog, but the 7.10 proxies don't have typeable.
This commit is contained in:
parent
adf731218c
commit
6c6e706e66
@ -61,13 +61,13 @@ genOpaqueOpaqueConvolution = do
|
||||
Dict -> OpaqueConvolution <$> (genConvolution :: Jack (Convolution ch fl kr kc sr sc))
|
||||
|
||||
prop_conv_net_witness = property $
|
||||
forAll genOpaqueOpaqueConvolution >>= \onet ->
|
||||
blindForAll genOpaqueOpaqueConvolution >>= \onet ->
|
||||
case onet of
|
||||
(OpaqueConvolution ((Convolution _ _) :: Convolution channels filters kernelRows kernelCols strideRows strideCols)) -> success
|
||||
|
||||
|
||||
prop_conv_net = property $
|
||||
forAll genOpaqueOpaqueConvolution >>= \onet ->
|
||||
blindForAll genOpaqueOpaqueConvolution >>= \onet ->
|
||||
case onet of
|
||||
(OpaqueConvolution (convLayer@(Convolution _ _) :: Convolution channels filters kernelRows kernelCols strideRows strideCols)) ->
|
||||
let ok stride kernel = [extent | extent <- [(kernel + 1) .. 30 ], (extent - kernel) `mod` stride == 0]
|
||||
@ -92,7 +92,7 @@ prop_conv_net = property $
|
||||
, (unsafeCoerce (Dict :: Dict ()) :: Dict (((outRows - 1) * strideRows) ~ (inRows - kernelRows)))
|
||||
, (unsafeCoerce (Dict :: Dict ()) :: Dict (((outCols - 1) * strideCols) ~ (inCols - kernelCols)))) of
|
||||
(Dict, Dict, Dict, Dict) ->
|
||||
forAll (S3D <$> uniformSample) >>= \(input :: S ('D3 inRows inCols channels)) ->
|
||||
blindForAll (S3D <$> uniformSample) >>= \(input :: S ('D3 inRows inCols channels)) ->
|
||||
let (tape, output :: S ('D3 outRows outCols filters)) = runForwards convLayer input
|
||||
backed :: (Gradient (Convolution channels filters kernelRows kernelCols strideRows strideCols), S ('D3 inRows inCols channels))
|
||||
= runBackwards convLayer tape output
|
||||
|
@ -17,9 +17,8 @@ import Grenade.Layers.FullyConnected
|
||||
|
||||
import Hedgehog
|
||||
|
||||
import Test.Jack.Hmatrix
|
||||
import Test.Jack.Compat
|
||||
|
||||
import Test.Jack.Hmatrix
|
||||
|
||||
data OpaqueFullyConnected :: * where
|
||||
OpaqueFullyConnected :: (KnownNat i, KnownNat o) => FullyConnected i o -> OpaqueFullyConnected
|
||||
@ -43,8 +42,8 @@ genOpaqueFullyConnected = do
|
||||
|
||||
prop_fully_connected_forwards :: Property
|
||||
prop_fully_connected_forwards = property $ do
|
||||
OpaqueFullyConnected (fclayer :: FullyConnected i o) <- forAll genOpaqueFullyConnected
|
||||
input :: S ('D1 i) <- forAll (S1D <$> randomVector)
|
||||
OpaqueFullyConnected (fclayer :: FullyConnected i o) <- blindForAll genOpaqueFullyConnected
|
||||
input :: S ('D1 i) <- blindForAll (S1D <$> randomVector)
|
||||
let (tape, output :: S ('D1 o)) = runForwards fclayer input
|
||||
backed :: (Gradient (FullyConnected i o), S ('D1 i))
|
||||
= runBackwards fclayer tape output
|
||||
|
@ -9,9 +9,6 @@
|
||||
{-# OPTIONS_GHC -fno-warn-missing-signatures #-}
|
||||
module Test.Grenade.Layers.Nonlinear where
|
||||
|
||||
|
||||
import Control.Monad.Trans.Class (MonadTrans(..))
|
||||
|
||||
import Data.Singletons
|
||||
|
||||
#if __GLASGOW_HASKELL__ < 800
|
||||
@ -22,18 +19,13 @@ import Grenade
|
||||
import GHC.TypeLits
|
||||
|
||||
import Hedgehog
|
||||
import Hedgehog.Internal.Property ( Test (..) )
|
||||
|
||||
import Test.Jack.Compat
|
||||
import Test.Jack.Hmatrix
|
||||
import Test.Jack.TypeLits
|
||||
|
||||
import Numeric.LinearAlgebra.Static ( norm_Inf )
|
||||
|
||||
-- | Generates a random input for the test by running the provided generator.
|
||||
--
|
||||
blindForAll :: Monad m => Gen m a -> Test m a
|
||||
blindForAll = Test . lift . lift
|
||||
|
||||
prop_sigmoid_grad :: Property
|
||||
prop_sigmoid_grad = property $
|
||||
blindForAll genShape >>= \case
|
||||
@ -66,9 +58,9 @@ prop_tanh_grad = property $
|
||||
|
||||
prop_softmax_grad :: Property
|
||||
prop_softmax_grad = property $
|
||||
forAll genNat >>= \case
|
||||
blindForAll genNat >>= \case
|
||||
(SomeNat (_ :: Proxy s)) ->
|
||||
forAll genOfShape >>= \(ds :: S ('D1 s)) ->
|
||||
blindForAll genOfShape >>= \(ds :: S ('D1 s)) ->
|
||||
let (tape, f :: S ('D1 s)) = runForwards Relu ds
|
||||
((), ret :: S ('D1 s)) = runBackwards Relu tape (1 :: S ('D1 s))
|
||||
(_, numer :: S ('D1 s)) = runForwards Relu (ds + 0.0001)
|
||||
|
@ -1,9 +1,12 @@
|
||||
{-# LANGUAGE RankNTypes #-}
|
||||
module Test.Jack.Compat where
|
||||
|
||||
import Control.Monad.Trans.Class (MonadTrans(..))
|
||||
|
||||
import Hedgehog
|
||||
import qualified Hedgehog.Gen as Gen
|
||||
import qualified Hedgehog.Range as Range
|
||||
import Hedgehog.Internal.Property ( Test (..) )
|
||||
|
||||
type Jack x = forall m. Monad m => Gen m x
|
||||
|
||||
@ -14,3 +17,6 @@ type Jack x = forall m. Monad m => Gen m x
|
||||
choose :: Integral a => a -> a -> Jack a
|
||||
choose = Gen.integral ... Range.constant
|
||||
|
||||
-- | Generates a random input for the test by running the provided generator.
|
||||
blindForAll :: Monad m => Gen m a -> Test m a
|
||||
blindForAll = Test . lift . lift
|
||||
|
Loading…
Reference in New Issue
Block a user