Better testing

This commit is contained in:
Huw Campbell 2016-12-15 16:36:22 +11:00
parent 424d791a30
commit 88021fbde7
5 changed files with 49 additions and 105 deletions

View File

@ -20,8 +20,6 @@ import Grenade.Core.Shape
import Grenade.Core.Network
import Grenade.Layers.Convolution
import qualified Numeric.LinearAlgebra.Static as HStatic
import Disorder.Jack
import Test.Jack.Hmatrix
@ -61,7 +59,6 @@ prop_conv_net =
(case onet of
(OpaqueConvolution (convLayer@(Convolution _ _) :: Convolution channels filters kernelRows kernelCols strideRows strideCols)) ->
let ok stride kernel = [extent | extent <- [(kernel + 1) .. 30 ], (extent - kernel) `mod` stride == 0]
ch = fromIntegral $ natVal (Proxy :: Proxy channels)
kr = fromIntegral $ natVal (Proxy :: Proxy kernelRows)
kc = fromIntegral $ natVal (Proxy :: Proxy kernelCols)
sr = fromIntegral $ natVal (Proxy :: Proxy strideRows)
@ -69,26 +66,26 @@ prop_conv_net =
in gamble (elements (ok sr kr)) $ \er ->
gamble (elements (ok sc kc)) $ \ec ->
let i = fromIntegral (er * ec * ch)
rr = ((er - kr) `div` sr) + 1
let rr = ((er - kr) `div` sr) + 1
rc = ((ec - kc) `div` sc) + 1
er' = someNatVal er
ec' = someNatVal ec
rr' = someNatVal rr
rc' = someNatVal rc
in gamble (vectorOf i sizedRealFrac) $ \(input :: [Double]) ->
case (er', ec', rr', rc') of
(Just (SomeNat (pinr :: Proxy inRows)), Just (SomeNat (_ :: Proxy inCols)), Just (SomeNat (pour :: Proxy outRows)), Just (SomeNat (_ :: Proxy outCols))) ->
let p1 = natDict pinr
p2 = natDict pour
in case ( p1 %* natDict (Proxy :: Proxy channels)
, p2 %* natDict (Proxy :: Proxy filters)
-- Fake it till you make it.
, (unsafeCoerce (Dict :: Dict ()) :: Dict (((outRows - 1) * strideRows) ~ (inRows - kernelRows)))
, (unsafeCoerce (Dict :: Dict ()) :: Dict (((outCols - 1) * strideCols) ~ (inCols - kernelCols)))) of
(Dict, Dict, Dict, Dict) -> let x :: S' ('D3 outRows outCols filters) = runForwards convLayer ((S3D' (HStatic.matrix input)) :: S' ('D3 inRows inCols channels))
in x `seq` True
_ -> False
Just er' = someNatVal er
Just ec' = someNatVal ec
Just rr' = someNatVal rr
Just rc' = someNatVal rc
in (case (er', ec', rr', rc') of
( SomeNat (pinr :: Proxy inRows), SomeNat (_ :: Proxy inCols), SomeNat (pour :: Proxy outRows), SomeNat (_ :: Proxy outCols)) ->
case ( natDict pinr %* natDict (Proxy :: Proxy channels)
, natDict pour %* natDict (Proxy :: Proxy filters)
-- Fake it till you make it.
, (unsafeCoerce (Dict :: Dict ()) :: Dict (((outRows - 1) * strideRows) ~ (inRows - kernelRows)))
, (unsafeCoerce (Dict :: Dict ()) :: Dict (((outCols - 1) * strideCols) ~ (inCols - kernelCols)))) of
(Dict, Dict, Dict, Dict) ->
gamble (S3D' <$> uniformSample) $ \(input :: S' ('D3 inRows inCols channels)) ->
let output :: S' ('D3 outRows outCols filters) = runForwards convLayer input
backed :: (Gradient (Convolution channels filters kernelRows kernelCols strideRows strideCols), S' ('D3 inRows inCols channels))
= runBackwards convLayer input output
in backed `seq` True
) :: Property
) :: Property
return []

View File

@ -16,8 +16,6 @@ import Grenade.Core.Shape
import Grenade.Core.Network
import Grenade.Layers.FullyConnected
import qualified Numeric.LinearAlgebra.Static as HStatic
import Disorder.Jack
import Test.Jack.Hmatrix
@ -32,7 +30,7 @@ instance Show OpaqueFullyConnected where
genOpaqueFullyConnected :: Jack OpaqueFullyConnected
genOpaqueFullyConnected = do
input :: Integer <- choose (2, 100)
output :: Integer <- choose (2, 100)
output :: Integer <- choose (1, 100)
let Just input' = someNatVal input
let Just output' = someNatVal output
case (input', output') of
@ -46,11 +44,11 @@ genOpaqueFullyConnected = do
prop_fully_connected_forwards :: Property
prop_fully_connected_forwards =
gamble genOpaqueFullyConnected $ \(OpaqueFullyConnected (fclayer :: FullyConnected i o)) ->
let i = fromIntegral $ natVal (Proxy :: Proxy i)
in gamble (vectorOf i sizedRealFrac) $ \input ->
let x :: S' ('D1 o) = runForwards fclayer (S1D' (HStatic.vector input :: HStatic.R i))
in x `seq` True
gamble (S1D' <$> randomVector) $ \(input :: S' ('D1 i)) ->
let output :: S' ('D1 o) = runForwards fclayer input
backed :: (Gradient (FullyConnected i o), S' ('D1 i))
= runBackwards fclayer input output
in backed `seq` True
return []
tests :: IO Bool

View File

@ -1,15 +1,13 @@
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE TypeOperators #-}
{-# OPTIONS_GHC -fno-warn-missing-signatures #-}
module Test.Grenade.Layers.Internal.Convolution where
-- import Control.Monad.Random
import Grenade.Layers.Internal.Convolution
import Numeric.LinearAlgebra hiding (uniformSample, konst, (===))

View File

@ -1,9 +1,9 @@
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE TypeOperators #-}
{-# OPTIONS_GHC -fno-warn-missing-signatures #-}
module Test.Grenade.Layers.Internal.Pooling where
@ -21,38 +21,18 @@ prop_poolForwards_poolBackwards_behaves_as_reference =
output extent kernel stride = (extent - kernel) `div` stride + 1
in gamble (choose (2, 100)) $ \height ->
gamble (choose (2, 100)) $ \width ->
gamble (choose (2, height - 1)) $ \kernel_h ->
gamble (choose (2, width - 1)) $ \kernel_w ->
gamble (choose (1, height - 1)) $ \kernel_h ->
gamble (choose (1, width - 1)) $ \kernel_w ->
gamble (elements (ok height kernel_h)) $ \stride_h ->
gamble (elements (ok width kernel_w)) $ \stride_w ->
gamble (listOfN (height * width) (height * width) sizedRealFrac) $ \input ->
let input' = (height >< width) input
outFast = poolForward 1 height width kernel_h kernel_w stride_h stride_w input'
-- retFast = poolBackward 1 height width kernel_h kernel_w stride_h stride_w input' outFast
retFast = poolBackward 1 height width kernel_h kernel_w stride_h stride_w input' outFast
outReference = Reference.poolForward kernel_h kernel_w stride_h stride_w (output height kernel_h stride_h) (output width kernel_w stride_w) input'
-- retReference = Reference.poolBackward kernel_h kernel_w stride_h stride_w input' outReference
in outFast === outReference -- .&&. retFast === retReference
prop_poolForwards_poolBackwards_symmetry =
let factors n = [x | x <- [1..n], n `mod` x == 0]
output extent kernel stride = (extent - kernel) `div` stride + 1
in gamble (choose (2, 100)) $ \height ->
gamble (choose (2, 100)) $ \width ->
gamble ((height `div`) <$> elements (factors height)) $ \kernel_h ->
gamble ((width `div`) <$> elements (factors width)) $ \kernel_w ->
gamble (listOfN (height * width) (height * width) sizedRealFrac) $ \input ->
let input' = (height >< width) input
stride_h = kernel_h
stride_w = kernel_w
outFast = poolForward 1 height width kernel_h kernel_w stride_h stride_w input'
retFast = poolBackward 1 height width kernel_h kernel_w stride_h stride_w input' outFast
outReference = Reference.poolForward kernel_h kernel_w stride_h stride_w (output height kernel_h stride_h) (output width kernel_w stride_w) input'
retReference = Reference.poolBackward kernel_h kernel_w stride_h stride_w input' outReference
in outFast === outReference .&&. retFast === retReference
retReference = Reference.poolBackward kernel_h kernel_w stride_h stride_w input' outReference
in outFast === outReference .&&. retFast === retReference
return []
tests :: IO Bool

View File

@ -1,14 +1,6 @@
module Test.Grenade.Layers.Internal.Reference where
import Control.Monad.ST ( ST )
import Data.Foldable ( forM_ )
import Data.Function ( on )
import Data.List ( maximumBy )
import Numeric.LinearAlgebra
import qualified Numeric.LinearAlgebra.Devel as U
im2col :: Int -> Int -> Int -> Int -> Matrix Double -> Matrix Double
im2col nrows ncols srows scols m =
@ -47,7 +39,6 @@ col2imfit starts krows kcols drows dcols m =
accums = concatMap (\(conv',(stx',sty')) -> fmap (\((ix,iy), val) -> ((ix + stx', iy + sty'), val)) conv') pairs
in accum (konst 0 (drows, dcols)) (+) accums
poolForward :: Int -> Int -> Int -> Int -> Int -> Int -> Matrix Double -> Matrix Double
poolForward nrows ncols srows scols outputRows outputCols m =
let starts = fittingStarts (rows m) nrows srows (cols m) ncols scols
@ -60,7 +51,7 @@ poolForwardList nrows ncols srows scols inRows inCols outputRows outputCols ms =
poolForwardFit :: [(Int,Int)] -> Int -> Int -> Int -> Int -> Matrix Double -> Matrix Double
poolForwardFit starts nrows ncols _ outputCols m =
let els = fmap (\start -> unsafeMaxElementSubmatrix start (nrows, ncols) m) starts
let els = fmap (\start -> maxElement $ subMatrix start (nrows, ncols) m) starts
in matrix outputCols els
poolBackward :: Int -> Int -> Int -> Int -> Matrix Double -> Matrix Double -> Matrix Double
@ -76,34 +67,14 @@ poolBackwardList krows kcols srows scols inRows inCols inputMatrices =
in uncurry (poolBackwardFit starts krows kcols) <$> inputMatrices
poolBackwardFit :: [(Int,Int)] -> Int -> Int -> Matrix Double -> Matrix Double -> Matrix Double
poolBackwardFit starts krows kcols inputMatrix gradientMatrix = U.runSTMatrix $ do
poolBackwardFit starts krows kcols inputMatrix gradientMatrix =
let inRows = rows inputMatrix
inCols = cols inputMatrix
gradCol = cols gradientMatrix
extent = (krows, kcols)
retM <- U.newMatrix 0 inRows inCols
forM_ (zip [0..] starts) $ \(ix, start) -> do
let loc = unsafeMaxIndexSubMatrix start extent inputMatrix
uncurry (unsafeModifyMatrix retM) loc ((+) $ uncurry (U.atM' gradientMatrix) $ divMod ix gradCol)
return retM
unsafeMaxElementSubmatrix :: (Int,Int) -> (Int,Int) -> Matrix Double -> Double
unsafeMaxElementSubmatrix starts extent m = uncurry (U.atM' m) $ unsafeMaxIndexSubMatrix starts extent m
unsafeMaxIndexSubMatrix :: (Int,Int) -> (Int,Int) -> Matrix Double -> (Int, Int)
unsafeMaxIndexSubMatrix (startRow, startCol) (extentRow, extentCold) m =
let mrows = [startRow .. startRow + extentRow - 1]
mcols = [startCol .. startCol + extentCold - 1]
pairs = concatMap ( \r -> fmap (\c -> (r , c)) mcols ) mrows
in maximumBy (compare `on` uncurry (U.atM' m)) pairs
unsafeModifyMatrix :: U.STMatrix s Double -> Int -> Int -> (Double -> Double) -> ST s ()
unsafeModifyMatrix x r c f = U.unsafeReadMatrix x r c >>= U.unsafeWriteMatrix x r c . f
inds = fmap (\start -> maxIndex $ subMatrix start (krows, kcols) inputMatrix) starts
grads = toList $ flatten gradientMatrix
grads' = zip3 starts grads inds
accums = fmap (\((stx',sty'),grad,(inx, iny)) -> ((stx' + inx, sty' + iny), grad)) grads'
in accum (konst 0 (inRows, inCols)) (+) accums
-- | These functions are not even remotely safe, but it's only called from the statically typed
-- commands, so we should be good ?!?!?