Move tests to hedgehog.

This commit is contained in:
Huw Campbell 2017-04-10 08:49:18 +10:00
parent 8ebbea6a1c
commit da810e6f4e
13 changed files with 201 additions and 173 deletions

3
.gitmodules vendored
View File

@ -1,3 +0,0 @@
[submodule "lib/haskell-hedgehog"]
path = lib/haskell-hedgehog
url = git@github.com:hedgehogqa/haskell-hedgehog.git

View File

@ -125,16 +125,14 @@ test-suite test
build-depends:
base >= 4.8 && < 5
, grenade
, ambiata-disorder-core
, ambiata-disorder-jack
, hedgehog >= 0.1 && < 0.2
, hmatrix
, mtl
, singletons
, text == 1.2.*
, typelits-witnesses
, transformers
, constraints
, QuickCheck >= 2.7 && < 2.9
, quickcheck-instances == 0.3.*
, MonadRandom
, random
, ad

@ -1 +0,0 @@
Subproject commit d2a02fe40b621db5fb7b9ee8a4daef1949e95ef2

View File

@ -18,10 +18,12 @@ import GHC.TypeLits.Witnesses
import Grenade.Core
import Grenade.Layers.Convolution
import Disorder.Jack
import Hedgehog
import qualified Hedgehog.Gen as Gen
import Test.Jack.Hmatrix
import Test.Jack.TypeLits
import Test.Jack.Compat
data OpaqueConvolution :: * where
OpaqueConvolution :: Convolution channels filters kernelRows kernelColumns strideRows strideColumns -> OpaqueConvolution
@ -58,15 +60,15 @@ genOpaqueOpaqueConvolution = do
in case p1 %* p2 %* p3 of
Dict -> OpaqueConvolution <$> (genConvolution :: Jack (Convolution ch fl kr kc sr sc))
prop_conv_net_witness =
gamble genOpaqueOpaqueConvolution $ \onet ->
(case onet of
(OpaqueConvolution ((Convolution _ _) :: Convolution channels filters kernelRows kernelCols strideRows strideCols)) -> True
) :: Bool
prop_conv_net_witness = property $
forAll genOpaqueOpaqueConvolution >>= \onet ->
case onet of
(OpaqueConvolution ((Convolution _ _) :: Convolution channels filters kernelRows kernelCols strideRows strideCols)) -> success
prop_conv_net =
gamble genOpaqueOpaqueConvolution $ \onet ->
(case onet of
prop_conv_net = property $
forAll genOpaqueOpaqueConvolution >>= \onet ->
case onet of
(OpaqueConvolution (convLayer@(Convolution _ _) :: Convolution channels filters kernelRows kernelCols strideRows strideCols)) ->
let ok stride kernel = [extent | extent <- [(kernel + 1) .. 30 ], (extent - kernel) `mod` stride == 0]
kr = fromIntegral $ natVal (Proxy :: Proxy kernelRows)
@ -74,15 +76,15 @@ prop_conv_net =
sr = fromIntegral $ natVal (Proxy :: Proxy strideRows)
sc = fromIntegral $ natVal (Proxy :: Proxy strideCols)
in gamble (elements (ok sr kr)) $ \er ->
gamble (elements (ok sc kc)) $ \ec ->
in forAll (Gen.element (ok sr kr)) >>= \er ->
forAll (Gen.element (ok sc kc)) >>= \ec ->
let rr = ((er - kr) `div` sr) + 1
rc = ((ec - kc) `div` sc) + 1
Just er' = someNatVal er
Just ec' = someNatVal ec
Just rr' = someNatVal rr
Just rc' = someNatVal rc
in (case (er', ec', rr', rc') of
in case (er', ec', rr', rc') of
( SomeNat (pinr :: Proxy inRows), SomeNat (_ :: Proxy inCols), SomeNat (pour :: Proxy outRows), SomeNat (_ :: Proxy outCols)) ->
case ( natDict pinr %* natDict (Proxy :: Proxy channels)
, natDict pour %* natDict (Proxy :: Proxy filters)
@ -90,14 +92,12 @@ prop_conv_net =
, (unsafeCoerce (Dict :: Dict ()) :: Dict (((outRows - 1) * strideRows) ~ (inRows - kernelRows)))
, (unsafeCoerce (Dict :: Dict ()) :: Dict (((outCols - 1) * strideCols) ~ (inCols - kernelCols)))) of
(Dict, Dict, Dict, Dict) ->
gamble (S3D <$> uniformSample) $ \(input :: S ('D3 inRows inCols channels)) ->
forAll (S3D <$> uniformSample) >>= \(input :: S ('D3 inRows inCols channels)) ->
let (tape, output :: S ('D3 outRows outCols filters)) = runForwards convLayer input
backed :: (Gradient (Convolution channels filters kernelRows kernelCols strideRows strideCols), S ('D3 inRows inCols channels))
= runBackwards convLayer tape output
in backed `seq` True
) :: Property
) :: Property
in backed `seq` success
return []
tests :: IO Bool
tests = $quickCheckAll
tests = $$(checkConcurrent)

View File

@ -15,9 +15,10 @@ import GHC.TypeLits
import Grenade.Core
import Grenade.Layers.FullyConnected
import Disorder.Jack
import Hedgehog
import Test.Jack.Hmatrix
import Test.Jack.Compat
data OpaqueFullyConnected :: * where
@ -28,8 +29,8 @@ instance Show OpaqueFullyConnected where
genOpaqueFullyConnected :: Jack OpaqueFullyConnected
genOpaqueFullyConnected = do
input :: Integer <- choose (2, 100)
output :: Integer <- choose (1, 100)
input :: Integer <- choose 2 100
output :: Integer <- choose 1 100
let Just input' = someNatVal input
let Just output' = someNatVal output
case (input', output') of
@ -41,14 +42,13 @@ genOpaqueFullyConnected = do
return . OpaqueFullyConnected $ (FullyConnected (FullyConnected' wB wN) (FullyConnected' bM kM) :: FullyConnected i' o')
prop_fully_connected_forwards :: Property
prop_fully_connected_forwards =
gamble genOpaqueFullyConnected $ \(OpaqueFullyConnected (fclayer :: FullyConnected i o)) ->
gamble (S1D <$> randomVector) $ \(input :: S ('D1 i)) ->
let (tape, output :: S ('D1 o)) = runForwards fclayer input
backed :: (Gradient (FullyConnected i o), S ('D1 i))
= runBackwards fclayer tape output
in backed `seq` True
prop_fully_connected_forwards = property $ do
OpaqueFullyConnected (fclayer :: FullyConnected i o) <- forAll genOpaqueFullyConnected
input :: S ('D1 i) <- forAll (S1D <$> randomVector)
let (tape, output :: S ('D1 o)) = runForwards fclayer input
backed :: (Gradient (FullyConnected i o), S ('D1 i))
= runBackwards fclayer tape output
backed `seq` success
return []
tests :: IO Bool
tests = $quickCheckAll
tests = $$(checkConcurrent)

View File

@ -12,41 +12,48 @@ import Grenade.Layers.Internal.Convolution
import Numeric.LinearAlgebra hiding (uniformSample, konst, (===))
import Disorder.Jack
import Hedgehog
import qualified Hedgehog.Gen as Gen
import qualified Hedgehog.Range as Range
import qualified Test.Grenade.Layers.Internal.Reference as Reference
import Test.Jack.Compat
prop_im2col_col2im_symmetrical_with_kernel_stride =
let factors n = [x | x <- [1..n], n `mod` x == 0]
in gamble (choose (2, 100)) $ \height ->
gamble (choose (2, 100)) $ \width ->
gamble ((height `div`) <$> elements (factors height)) $ \kernel_h ->
gamble ((width `div`) <$> elements (factors width)) $ \kernel_w ->
gamble (listOfN (height * width) (height * width) sizedRealFrac) $ \input ->
let input' = (height >< width) input
stride_h = kernel_h
stride_w = kernel_w
out = col2im kernel_h kernel_w stride_h stride_w height width . im2col kernel_h kernel_w stride_h stride_w $ input'
in input' === out
in property $ do
height <- forAll $ choose 2 100
width <- forAll $ choose 2 100
kernel_h <- forAll $ (height `div`) <$> Gen.element (factors height)
kernel_w <- forAll $ (width `div`) <$> Gen.element (factors width)
input <- forAll $ Gen.list (Range.singleton $ height * width) (Gen.realFloat $ Range.linearFracFrom 0 (-100) 100)
let input' = (height >< width) input
let stride_h = kernel_h
let stride_w = kernel_w
let out = col2im kernel_h kernel_w stride_h stride_w height width . im2col kernel_h kernel_w stride_h stride_w $ input'
input' === out
prop_im2col_col2im_behaves_as_reference =
let ok extent kernel = [stride | stride <- [1..extent], (extent - kernel) `mod` stride == 0]
in gamble (choose (2, 100)) $ \height ->
gamble (choose (2, 100)) $ \width ->
gamble (choose (2, height - 1)) $ \kernel_h ->
gamble (choose (2, width - 1)) $ \kernel_w ->
gamble (elements (ok height kernel_h)) $ \stride_h ->
gamble (elements (ok width kernel_w)) $ \stride_w ->
gamble (listOfN (height * width) (height * width) sizedRealFrac) $ \input ->
let input' = (height >< width) input
outFast = im2col kernel_h kernel_w stride_h stride_w input'
retFast = col2im kernel_h kernel_w stride_h stride_w height width outFast
in property $ do
height <- forAll (choose 2 100)
width <- forAll (choose 2 100)
kernel_h <- forAll (choose 2 (height - 1))
kernel_w <- forAll (choose 2 (width - 1))
stride_h <- forAll (Gen.element (ok height kernel_h))
stride_w <- forAll (Gen.element (ok width kernel_w))
input <- forAll ( Gen.list (Range.singleton $ height * width) (Gen.realFloat $ Range.linearFracFrom 0 (-100) 100))
let input' = (height >< width) input
let outFast = im2col kernel_h kernel_w stride_h stride_w input'
let retFast = col2im kernel_h kernel_w stride_h stride_w height width outFast
outReference = Reference.im2col kernel_h kernel_w stride_h stride_w input'
retReference = Reference.col2im kernel_h kernel_w stride_h stride_w height width outReference
in outFast === outReference .&&. retFast === retReference
let outReference = Reference.im2col kernel_h kernel_w stride_h stride_w input'
let retReference = Reference.col2im kernel_h kernel_w stride_h stride_w height width outReference
outFast === outReference
retFast === retReference
return []
tests :: IO Bool
tests = $quickCheckAll
tests = $$(checkConcurrent)

View File

@ -12,28 +12,35 @@ import Grenade.Layers.Internal.Pooling
import Numeric.LinearAlgebra hiding (uniformSample, konst, (===))
import Disorder.Jack
import Hedgehog
import qualified Hedgehog.Gen as Gen
import qualified Hedgehog.Range as Range
import qualified Test.Grenade.Layers.Internal.Reference as Reference
import Test.Jack.Compat
prop_poolForwards_poolBackwards_behaves_as_reference =
let ok extent kernel = [stride | stride <- [1..extent], (extent - kernel) `mod` stride == 0]
output extent kernel stride = (extent - kernel) `div` stride + 1
in gamble (choose (2, 100)) $ \height ->
gamble (choose (2, 100)) $ \width ->
gamble (choose (1, height - 1)) $ \kernel_h ->
gamble (choose (1, width - 1)) $ \kernel_w ->
gamble (elements (ok height kernel_h)) $ \stride_h ->
gamble (elements (ok width kernel_w)) $ \stride_w ->
gamble (listOfN (height * width) (height * width) sizedRealFrac) $ \input ->
let input' = (height >< width) input
outFast = poolForward 1 height width kernel_h kernel_w stride_h stride_w input'
retFast = poolBackward 1 height width kernel_h kernel_w stride_h stride_w input' outFast
in property $ do
height <- forAll $ choose 2 100
width <- forAll $ choose 2 100
kernel_h <- forAll $ choose 1 (height - 1)
kernel_w <- forAll $ choose 1 (width - 1)
stride_h <- forAll $ Gen.element (ok height kernel_h)
stride_w <- forAll $ Gen.element (ok width kernel_w)
input <- forAll $ Gen.list (Range.singleton $ height * width) (Gen.realFloat $ Range.linearFracFrom 0 (-100) 100)
let input' = (height >< width) input
let outFast = poolForward 1 height width kernel_h kernel_w stride_h stride_w input'
let retFast = poolBackward 1 height width kernel_h kernel_w stride_h stride_w input' outFast
let outReference = Reference.poolForward kernel_h kernel_w stride_h stride_w (output height kernel_h stride_h) (output width kernel_w stride_w) input'
let retReference = Reference.poolBackward kernel_h kernel_w stride_h stride_w input' outReference
outFast === outReference
retFast === retReference
outReference = Reference.poolForward kernel_h kernel_w stride_h stride_w (output height kernel_h stride_h) (output width kernel_w stride_w) input'
retReference = Reference.poolBackward kernel_h kernel_w stride_h stride_w input' outReference
in outFast === outReference .&&. retFast === retReference
return []
tests :: IO Bool
tests = $quickCheckAll
tests = $$(checkConcurrent)

View File

@ -9,6 +9,9 @@
{-# OPTIONS_GHC -fno-warn-missing-signatures #-}
module Test.Grenade.Layers.Nonlinear where
import Control.Monad.Trans.Class (MonadTrans(..))
import Data.Singletons
#if __GLASGOW_HASKELL__ < 800
@ -18,60 +21,63 @@ import Data.Proxy
import Grenade
import GHC.TypeLits
import Disorder.Jack
import Hedgehog
import Hedgehog.Internal.Property ( Test (..) )
import Test.Jack.Hmatrix
import Test.Jack.TypeLits
import Numeric.LinearAlgebra.Static ( norm_Inf )
-- | Generates a random input for the test by running the provided generator.
--
blindForAll :: Monad m => Gen m a -> Test m a
blindForAll = Test . lift . lift
prop_sigmoid_grad :: Property
prop_sigmoid_grad =
gambleDisplay (const "Shape") genShape $ \case
prop_sigmoid_grad = property $
blindForAll genShape >>= \case
(SomeSing (r :: Sing s)) ->
withSingI r $
gamble genOfShape $ \(ds :: S s) ->
blindForAll genOfShape >>= \(ds :: S s) ->
let (tape, f :: S s) = runForwards Logit ds
((), ret :: S s) = runBackwards Logit tape (1 :: S s)
(_, numer :: S s) = runForwards Logit (ds + 0.0001)
numericalGradient = (numer - f) * 10000
in counterexample (show numericalGradient ++ show ret)
((case numericalGradient - ret of
in assert ((case numericalGradient - ret of
(S1D x) -> norm_Inf x < 0.0001
(S2D x) -> norm_Inf x < 0.0001
(S3D x) -> norm_Inf x < 0.0001) :: Bool)
prop_tanh_grad :: Property
prop_tanh_grad =
gambleDisplay (const "Shape") genShape $ \case
prop_tanh_grad = property $
blindForAll genShape >>= \case
(SomeSing (r :: Sing s)) ->
withSingI r $
gamble genOfShape $ \(ds :: S s) ->
blindForAll genOfShape >>= \(ds :: S s) ->
let (tape, f :: S s) = runForwards Tanh ds
((), ret :: S s) = runBackwards Tanh tape (1 :: S s)
(_, numer :: S s) = runForwards Tanh (ds + 0.0001)
numericalGradient = (numer - f) * 10000
in counterexample (show numericalGradient ++ show ret)
((case numericalGradient - ret of
in assert ((case numericalGradient - ret of
(S1D x) -> norm_Inf x < 0.001
(S2D x) -> norm_Inf x < 0.001
(S3D x) -> norm_Inf x < 0.001) :: Bool)
prop_softmax_grad :: Property
prop_softmax_grad =
gamble genNat $ \case
prop_softmax_grad = property $
forAll genNat >>= \case
(SomeNat (_ :: Proxy s)) ->
gamble genOfShape $ \(ds :: S ('D1 s)) ->
forAll genOfShape >>= \(ds :: S ('D1 s)) ->
let (tape, f :: S ('D1 s)) = runForwards Relu ds
((), ret :: S ('D1 s)) = runBackwards Relu tape (1 :: S ('D1 s))
(_, numer :: S ('D1 s)) = runForwards Relu (ds + 0.0001)
numericalGradient = (numer - f) * 10000
in counterexample (show numericalGradient ++ show ret)
((case numericalGradient - ret of
in assert ((case numericalGradient - ret of
(S1D x) -> norm_Inf x < 0.0001) :: Bool)
return []
tests :: IO Bool
tests = $quickCheckAll
tests :: IO Bool
tests = $$(checkConcurrent)

View File

@ -15,7 +15,7 @@ module Test.Grenade.Layers.PadCrop where
import Grenade
import Disorder.Jack
import Hedgehog
import Numeric.LinearAlgebra.Static ( norm_Inf )
@ -25,25 +25,29 @@ prop_pad_crop :: Property
prop_pad_crop =
let net :: Network '[Pad 2 3 4 6, Crop 2 3 4 6] '[ 'D3 7 9 5, 'D3 16 15 5, 'D3 7 9 5 ]
net = Pad :~> Crop :~> NNil
in gamble genOfShape $ \(d :: S ('D3 7 9 5)) ->
in property $
forAll genOfShape >>= \(d :: S ('D3 7 9 5)) ->
let (tapes, res) = runForwards net d
(_ , grad) = runBackwards net tapes d
in d ~~~ res .&&. grad ~~~ d
in do assert $ d ~~~ res
assert $ grad ~~~ d
prop_pad_crop_2d :: Property
prop_pad_crop_2d =
let net :: Network '[Pad 2 3 4 6, Crop 2 3 4 6] '[ 'D2 7 9, 'D2 16 15, 'D2 7 9 ]
net = Pad :~> Crop :~> NNil
in gamble genOfShape $ \(d :: S ('D2 7 9)) ->
in property $
forAll genOfShape >>= \(d :: S ('D2 7 9)) ->
let (tapes, res) = runForwards net d
(_ , grad) = runBackwards net tapes d
in d ~~~ res .&&. grad ~~~ d
in do assert $ d ~~~ res
assert $ grad ~~~ d
(~~~) :: S x -> S x -> Bool
(S1D x) ~~~ (S1D y) = norm_Inf (x - y) < 0.00001
(S2D x) ~~~ (S2D y) = norm_Inf (x - y) < 0.00001
(S3D x) ~~~ (S3D y) = norm_Inf (x - y) < 0.00001
return []
tests :: IO Bool
tests = $quickCheckAll
tests = $$(checkConcurrent)

View File

@ -12,8 +12,8 @@ import Data.Singletons ()
import GHC.TypeLits
import Grenade.Layers.Pooling
import Disorder.Jack
import Hedgehog
import Test.Jack.Compat
data OpaquePooling :: * where
OpaquePooling :: (KnownNat kh, KnownNat kw, KnownNat sh, KnownNat sw) => Pooling kh kw sh sw -> OpaquePooling
@ -23,22 +23,21 @@ instance Show OpaquePooling where
genOpaquePooling :: Jack OpaquePooling
genOpaquePooling = do
Just kernelHeight <- someNatVal <$> choose (2, 15)
Just kernelWidth <- someNatVal <$> choose (2, 15)
Just strideHeight <- someNatVal <$> choose (2, 15)
Just strideWidth <- someNatVal <$> choose (2, 15)
Just kernelHeight <- someNatVal <$> choose 2 15
Just kernelWidth <- someNatVal <$> choose 2 15
Just strideHeight <- someNatVal <$> choose 2 15
Just strideWidth <- someNatVal <$> choose 2 15
case (kernelHeight, kernelWidth, strideHeight, strideWidth) of
(SomeNat (_ :: Proxy kh), SomeNat (_ :: Proxy kw), SomeNat (_ :: Proxy sh), SomeNat (_ :: Proxy sw)) ->
return $ OpaquePooling (Pooling :: Pooling kh kw sh sw)
prop_pool_layer_witness =
gamble genOpaquePooling $ \onet ->
(case onet of
(OpaquePooling (Pooling :: Pooling kernelRows kernelCols strideRows strideCols)) -> True
) :: Bool
property $ do
onet <- forAll genOpaquePooling
case onet of
(OpaquePooling (Pooling :: Pooling kernelRows kernelCols strideRows strideCols)) ->
assert True
return []
tests :: IO Bool
tests = $quickCheckAll
tests = $$(checkConcurrent)

View File

@ -10,7 +10,7 @@
{-# OPTIONS_GHC -fno-warn-missing-signatures #-}
module Test.Grenade.Recurrent.Layers.LSTM where
import Disorder.Jack
import Hedgehog
import Data.Foldable ( toList )
import Data.Singletons.TypeLits
@ -24,6 +24,7 @@ import qualified Numeric.LinearAlgebra.Static as S
import qualified Test.Grenade.Recurrent.Layers.LSTM.Reference as Reference
import Test.Jack.Hmatrix
import Test.Jack.Compat
genLSTM :: forall i o. (KnownNat i, KnownNat o) => Jack (LSTM i o)
genLSTM = do
@ -39,63 +40,69 @@ genLSTM = do
<*> pure (LSTMWeights w0 u0 v0 w0 u0 v0 w0 u0 v0 w0 v0)
prop_lstm_reference_forwards =
gamble randomVector $ \(input :: S.R 3) ->
gamble randomVector $ \(cell :: S.R 2) ->
gamble genLSTM $ \(net@(LSTM lstmWeights _) :: LSTM 3 2) ->
let actual = runRecurrentForwards net (S1D cell) (S1D input)
in case actual of
(_, (S1D cellOut) :: S ('D1 2), (S1D output) :: S ('D1 2)) ->
let cellOut' = Reference.Vector . H.toList . S.extract $ cellOut
output' = Reference.Vector . H.toList . S.extract $ output
refNet = Reference.lstmToReference lstmWeights
refCell = Reference.Vector . H.toList . S.extract $ cell
refInput = Reference.Vector . H.toList . S.extract $ input
(refCO, refO) = Reference.runLSTM refNet refCell refInput
in toList refCO ~~~ toList cellOut' .&&. toList refO ~~~ toList output'
property $ do
input :: S.R 3 <- forAll randomVector
cell :: S.R 2 <- forAll randomVector
net@(LSTM lstmWeights _) :: LSTM 3 2 <- forAll genLSTM
let actual = runRecurrentForwards net (S1D cell) (S1D input)
case actual of
(_, (S1D cellOut) :: S ('D1 2), (S1D output) :: S ('D1 2)) ->
let cellOut' = Reference.Vector . H.toList . S.extract $ cellOut
output' = Reference.Vector . H.toList . S.extract $ output
refNet = Reference.lstmToReference lstmWeights
refCell = Reference.Vector . H.toList . S.extract $ cell
refInput = Reference.Vector . H.toList . S.extract $ input
(refCO, refO) = Reference.runLSTM refNet refCell refInput
in do assert (toList refCO ~~~ toList cellOut')
assert (toList refO ~~~ toList output')
prop_lstm_reference_backwards =
gamble randomVector $ \(input :: S.R 3) ->
gamble randomVector $ \(cell :: S.R 2) ->
gamble genLSTM $ \(net@(LSTM lstmWeights _) :: LSTM 3 2) ->
let actualBacks = runRecurrentBackwards net (S1D cell, S1D input) (S1D (S.konst 1) :: S ('D1 2)) (S1D (S.konst 1) :: S ('D1 2))
in case actualBacks of
(actualGradients, _, _ :: S ('D1 3)) ->
let refNet = Reference.lstmToReference lstmWeights
refCell = Reference.Vector . H.toList . S.extract $ cell
refInput = Reference.Vector . H.toList . S.extract $ input
refGradients = Reference.runLSTMback refCell refInput refNet
in toList refGradients ~~~ toList (Reference.lstmToReference actualGradients)
property $ do
input :: S.R 3 <- forAll randomVector
cell :: S.R 2 <- forAll randomVector
net@(LSTM lstmWeights _) :: LSTM 3 2 <- forAll genLSTM
let actualBacks = runRecurrentBackwards net (S1D cell, S1D input) (S1D (S.konst 1) :: S ('D1 2)) (S1D (S.konst 1) :: S ('D1 2))
case actualBacks of
(actualGradients, _, _ :: S ('D1 3)) ->
let refNet = Reference.lstmToReference lstmWeights
refCell = Reference.Vector . H.toList . S.extract $ cell
refInput = Reference.Vector . H.toList . S.extract $ input
refGradients = Reference.runLSTMback refCell refInput refNet
in assert $ toList refGradients ~~~ toList (Reference.lstmToReference actualGradients)
prop_lstm_reference_backwards_input =
gamble randomVector $ \(input :: S.R 3) ->
gamble randomVector $ \(cell :: S.R 2) ->
gamble genLSTM $ \(net@(LSTM lstmWeights _) :: LSTM 3 2) ->
let actualBacks = runRecurrentBackwards net (S1D cell, S1D input) (S1D (S.konst 1) :: S ('D1 2)) (S1D (S.konst 1) :: S ('D1 2))
in case actualBacks of
(_, _, S1D actualGradients :: S ('D1 3)) ->
let refNet = Reference.lstmToReference lstmWeights
refCell = Reference.Vector . H.toList . S.extract $ cell
refInput = Reference.Vector . H.toList . S.extract $ input
refGradients = Reference.runLSTMbackOnInput refCell refNet refInput
in toList refGradients ~~~ H.toList (S.extract actualGradients)
property $ do
input :: S.R 3 <- forAll randomVector
cell :: S.R 2 <- forAll randomVector
net@(LSTM lstmWeights _) :: LSTM 3 2 <- forAll genLSTM
let actualBacks = runRecurrentBackwards net (S1D cell, S1D input) (S1D (S.konst 1) :: S ('D1 2)) (S1D (S.konst 1) :: S ('D1 2))
case actualBacks of
(_, _, S1D actualGradients :: S ('D1 3)) ->
let refNet = Reference.lstmToReference lstmWeights
refCell = Reference.Vector . H.toList . S.extract $ cell
refInput = Reference.Vector . H.toList . S.extract $ input
refGradients = Reference.runLSTMbackOnInput refCell refNet refInput
in assert $ toList refGradients ~~~ H.toList (S.extract actualGradients)
prop_lstm_reference_backwards_cell =
gamble randomVector $ \(input :: S.R 3) ->
gamble randomVector $ \(cell :: S.R 2) ->
gamble genLSTM $ \(net@(LSTM lstmWeights _) :: LSTM 3 2) ->
let actualBacks = runRecurrentBackwards net (S1D cell, S1D input) (S1D (S.konst 1) :: S ('D1 2)) (S1D (S.konst 1) :: S ('D1 2))
in case actualBacks of
(_, S1D actualGradients, _ :: S ('D1 3)) ->
let refNet = Reference.lstmToReference lstmWeights
refCell = Reference.Vector . H.toList . S.extract $ cell
refInput = Reference.Vector . H.toList . S.extract $ input
refGradients = Reference.runLSTMbackOnCell refInput refNet refCell
in toList refGradients ~~~ H.toList (S.extract actualGradients)
property $ do
input :: S.R 3 <- forAll randomVector
cell :: S.R 2 <- forAll randomVector
net@(LSTM lstmWeights _) :: LSTM 3 2 <- forAll genLSTM
let actualBacks = runRecurrentBackwards net (S1D cell, S1D input) (S1D (S.konst 1) :: S ('D1 2)) (S1D (S.konst 1) :: S ('D1 2))
case actualBacks of
(_, S1D actualGradients, _ :: S ('D1 3)) ->
let refNet = Reference.lstmToReference lstmWeights
refCell = Reference.Vector . H.toList . S.extract $ cell
refInput = Reference.Vector . H.toList . S.extract $ input
refGradients = Reference.runLSTMbackOnCell refInput refNet refCell
in assert $ toList refGradients ~~~ H.toList (S.extract actualGradients)
(~~~) as bs = all (< 1e-8) (zipWith (-) as bs)
infix 4 ~~~
return []
tests :: IO Bool
tests = $quickCheckAll
tests = $$(checkConcurrent)

View File

@ -8,17 +8,19 @@ module Test.Jack.Hmatrix where
import Grenade
import Data.Singletons
import Disorder.Jack
import qualified Hedgehog.Gen as Gen
import qualified Hedgehog.Range as Range
import GHC.TypeLits
import qualified Numeric.LinearAlgebra.Static as HStatic
import Test.Jack.Compat
randomVector :: forall n. KnownNat n => Jack (HStatic.R n)
randomVector = (\s -> HStatic.randomVector s HStatic.Uniform * 2 - 1) <$> sizedNat
randomVector = (\s -> HStatic.randomVector s HStatic.Uniform * 2 - 1) <$> Gen.int Range.linearBounded
uniformSample :: forall m n. (KnownNat m, KnownNat n) => Jack (HStatic.L m n)
uniformSample = (\s -> HStatic.uniformSample s (-1) 1 ) <$> sizedNat
uniformSample = (\s -> HStatic.uniformSample s (-1) 1 ) <$> Gen.int Range.linearBounded
-- | Generate random data of the desired shape
genOfShape :: forall x. ( SingI x ) => Jack (S x)

View File

@ -10,16 +10,18 @@ import Data.Constraint
import Data.Proxy
#endif
import Data.Singletons
import Disorder.Jack
import qualified Hedgehog.Gen as Gen
import Grenade
import GHC.TypeLits
import GHC.TypeLits.Witnesses
import Test.Jack.Compat
genNat :: Jack SomeNat
genNat = do
Just n <- someNatVal <$> choose (1, 10)
Just n <- someNatVal <$> choose 1 10
return n
#if __GLASGOW_HASKELL__ < 800
@ -30,7 +32,7 @@ type Shape' = Shape
genShape :: Jack (SomeSing Shape')
genShape
= oneOf [
= Gen.choice [
genD1
, genD2
, genD3