mirror of
https://github.com/HuwCampbell/grenade.git
synced 2024-08-16 05:10:26 +03:00
Better testing
This commit is contained in:
parent
424d791a30
commit
88021fbde7
@ -20,8 +20,6 @@ import Grenade.Core.Shape
|
||||
import Grenade.Core.Network
|
||||
import Grenade.Layers.Convolution
|
||||
|
||||
import qualified Numeric.LinearAlgebra.Static as HStatic
|
||||
|
||||
import Disorder.Jack
|
||||
|
||||
import Test.Jack.Hmatrix
|
||||
@ -61,7 +59,6 @@ prop_conv_net =
|
||||
(case onet of
|
||||
(OpaqueConvolution (convLayer@(Convolution _ _) :: Convolution channels filters kernelRows kernelCols strideRows strideCols)) ->
|
||||
let ok stride kernel = [extent | extent <- [(kernel + 1) .. 30 ], (extent - kernel) `mod` stride == 0]
|
||||
ch = fromIntegral $ natVal (Proxy :: Proxy channels)
|
||||
kr = fromIntegral $ natVal (Proxy :: Proxy kernelRows)
|
||||
kc = fromIntegral $ natVal (Proxy :: Proxy kernelCols)
|
||||
sr = fromIntegral $ natVal (Proxy :: Proxy strideRows)
|
||||
@ -69,26 +66,26 @@ prop_conv_net =
|
||||
|
||||
in gamble (elements (ok sr kr)) $ \er ->
|
||||
gamble (elements (ok sc kc)) $ \ec ->
|
||||
let i = fromIntegral (er * ec * ch)
|
||||
rr = ((er - kr) `div` sr) + 1
|
||||
let rr = ((er - kr) `div` sr) + 1
|
||||
rc = ((ec - kc) `div` sc) + 1
|
||||
er' = someNatVal er
|
||||
ec' = someNatVal ec
|
||||
rr' = someNatVal rr
|
||||
rc' = someNatVal rc
|
||||
in gamble (vectorOf i sizedRealFrac) $ \(input :: [Double]) ->
|
||||
case (er', ec', rr', rc') of
|
||||
(Just (SomeNat (pinr :: Proxy inRows)), Just (SomeNat (_ :: Proxy inCols)), Just (SomeNat (pour :: Proxy outRows)), Just (SomeNat (_ :: Proxy outCols))) ->
|
||||
let p1 = natDict pinr
|
||||
p2 = natDict pour
|
||||
in case ( p1 %* natDict (Proxy :: Proxy channels)
|
||||
, p2 %* natDict (Proxy :: Proxy filters)
|
||||
-- Fake it till you make it.
|
||||
, (unsafeCoerce (Dict :: Dict ()) :: Dict (((outRows - 1) * strideRows) ~ (inRows - kernelRows)))
|
||||
, (unsafeCoerce (Dict :: Dict ()) :: Dict (((outCols - 1) * strideCols) ~ (inCols - kernelCols)))) of
|
||||
(Dict, Dict, Dict, Dict) -> let x :: S' ('D3 outRows outCols filters) = runForwards convLayer ((S3D' (HStatic.matrix input)) :: S' ('D3 inRows inCols channels))
|
||||
in x `seq` True
|
||||
_ -> False
|
||||
Just er' = someNatVal er
|
||||
Just ec' = someNatVal ec
|
||||
Just rr' = someNatVal rr
|
||||
Just rc' = someNatVal rc
|
||||
in (case (er', ec', rr', rc') of
|
||||
( SomeNat (pinr :: Proxy inRows), SomeNat (_ :: Proxy inCols), SomeNat (pour :: Proxy outRows), SomeNat (_ :: Proxy outCols)) ->
|
||||
case ( natDict pinr %* natDict (Proxy :: Proxy channels)
|
||||
, natDict pour %* natDict (Proxy :: Proxy filters)
|
||||
-- Fake it till you make it.
|
||||
, (unsafeCoerce (Dict :: Dict ()) :: Dict (((outRows - 1) * strideRows) ~ (inRows - kernelRows)))
|
||||
, (unsafeCoerce (Dict :: Dict ()) :: Dict (((outCols - 1) * strideCols) ~ (inCols - kernelCols)))) of
|
||||
(Dict, Dict, Dict, Dict) ->
|
||||
gamble (S3D' <$> uniformSample) $ \(input :: S' ('D3 inRows inCols channels)) ->
|
||||
let output :: S' ('D3 outRows outCols filters) = runForwards convLayer input
|
||||
backed :: (Gradient (Convolution channels filters kernelRows kernelCols strideRows strideCols), S' ('D3 inRows inCols channels))
|
||||
= runBackwards convLayer input output
|
||||
in backed `seq` True
|
||||
) :: Property
|
||||
) :: Property
|
||||
|
||||
return []
|
||||
|
@ -16,8 +16,6 @@ import Grenade.Core.Shape
|
||||
import Grenade.Core.Network
|
||||
import Grenade.Layers.FullyConnected
|
||||
|
||||
import qualified Numeric.LinearAlgebra.Static as HStatic
|
||||
|
||||
import Disorder.Jack
|
||||
|
||||
import Test.Jack.Hmatrix
|
||||
@ -32,7 +30,7 @@ instance Show OpaqueFullyConnected where
|
||||
genOpaqueFullyConnected :: Jack OpaqueFullyConnected
|
||||
genOpaqueFullyConnected = do
|
||||
input :: Integer <- choose (2, 100)
|
||||
output :: Integer <- choose (2, 100)
|
||||
output :: Integer <- choose (1, 100)
|
||||
let Just input' = someNatVal input
|
||||
let Just output' = someNatVal output
|
||||
case (input', output') of
|
||||
@ -46,11 +44,11 @@ genOpaqueFullyConnected = do
|
||||
prop_fully_connected_forwards :: Property
|
||||
prop_fully_connected_forwards =
|
||||
gamble genOpaqueFullyConnected $ \(OpaqueFullyConnected (fclayer :: FullyConnected i o)) ->
|
||||
let i = fromIntegral $ natVal (Proxy :: Proxy i)
|
||||
in gamble (vectorOf i sizedRealFrac) $ \input ->
|
||||
let x :: S' ('D1 o) = runForwards fclayer (S1D' (HStatic.vector input :: HStatic.R i))
|
||||
in x `seq` True
|
||||
|
||||
gamble (S1D' <$> randomVector) $ \(input :: S' ('D1 i)) ->
|
||||
let output :: S' ('D1 o) = runForwards fclayer input
|
||||
backed :: (Gradient (FullyConnected i o), S' ('D1 i))
|
||||
= runBackwards fclayer input output
|
||||
in backed `seq` True
|
||||
|
||||
return []
|
||||
tests :: IO Bool
|
||||
|
@ -1,15 +1,13 @@
|
||||
{-# LANGUAGE TemplateHaskell #-}
|
||||
{-# LANGUAGE DataKinds #-}
|
||||
{-# LANGUAGE GADTs #-}
|
||||
{-# LANGUAGE ScopedTypeVariables #-}
|
||||
{-# LANGUAGE ConstraintKinds #-}
|
||||
{-# LANGUAGE TypeOperators #-}
|
||||
{-# LANGUAGE TemplateHaskell #-}
|
||||
{-# LANGUAGE DataKinds #-}
|
||||
{-# LANGUAGE GADTs #-}
|
||||
{-# LANGUAGE ScopedTypeVariables #-}
|
||||
{-# LANGUAGE ConstraintKinds #-}
|
||||
{-# LANGUAGE TypeOperators #-}
|
||||
|
||||
{-# OPTIONS_GHC -fno-warn-missing-signatures #-}
|
||||
module Test.Grenade.Layers.Internal.Convolution where
|
||||
|
||||
-- import Control.Monad.Random
|
||||
|
||||
import Grenade.Layers.Internal.Convolution
|
||||
|
||||
import Numeric.LinearAlgebra hiding (uniformSample, konst, (===))
|
||||
|
@ -1,9 +1,9 @@
|
||||
{-# LANGUAGE TemplateHaskell #-}
|
||||
{-# LANGUAGE DataKinds #-}
|
||||
{-# LANGUAGE GADTs #-}
|
||||
{-# LANGUAGE ScopedTypeVariables #-}
|
||||
{-# LANGUAGE ConstraintKinds #-}
|
||||
{-# LANGUAGE TypeOperators #-}
|
||||
{-# LANGUAGE TemplateHaskell #-}
|
||||
{-# LANGUAGE DataKinds #-}
|
||||
{-# LANGUAGE GADTs #-}
|
||||
{-# LANGUAGE ScopedTypeVariables #-}
|
||||
{-# LANGUAGE ConstraintKinds #-}
|
||||
{-# LANGUAGE TypeOperators #-}
|
||||
|
||||
{-# OPTIONS_GHC -fno-warn-missing-signatures #-}
|
||||
module Test.Grenade.Layers.Internal.Pooling where
|
||||
@ -21,38 +21,18 @@ prop_poolForwards_poolBackwards_behaves_as_reference =
|
||||
output extent kernel stride = (extent - kernel) `div` stride + 1
|
||||
in gamble (choose (2, 100)) $ \height ->
|
||||
gamble (choose (2, 100)) $ \width ->
|
||||
gamble (choose (2, height - 1)) $ \kernel_h ->
|
||||
gamble (choose (2, width - 1)) $ \kernel_w ->
|
||||
gamble (choose (1, height - 1)) $ \kernel_h ->
|
||||
gamble (choose (1, width - 1)) $ \kernel_w ->
|
||||
gamble (elements (ok height kernel_h)) $ \stride_h ->
|
||||
gamble (elements (ok width kernel_w)) $ \stride_w ->
|
||||
gamble (listOfN (height * width) (height * width) sizedRealFrac) $ \input ->
|
||||
let input' = (height >< width) input
|
||||
outFast = poolForward 1 height width kernel_h kernel_w stride_h stride_w input'
|
||||
-- retFast = poolBackward 1 height width kernel_h kernel_w stride_h stride_w input' outFast
|
||||
retFast = poolBackward 1 height width kernel_h kernel_w stride_h stride_w input' outFast
|
||||
|
||||
outReference = Reference.poolForward kernel_h kernel_w stride_h stride_w (output height kernel_h stride_h) (output width kernel_w stride_w) input'
|
||||
-- retReference = Reference.poolBackward kernel_h kernel_w stride_h stride_w input' outReference
|
||||
in outFast === outReference -- .&&. retFast === retReference
|
||||
|
||||
|
||||
prop_poolForwards_poolBackwards_symmetry =
|
||||
let factors n = [x | x <- [1..n], n `mod` x == 0]
|
||||
output extent kernel stride = (extent - kernel) `div` stride + 1
|
||||
in gamble (choose (2, 100)) $ \height ->
|
||||
gamble (choose (2, 100)) $ \width ->
|
||||
gamble ((height `div`) <$> elements (factors height)) $ \kernel_h ->
|
||||
gamble ((width `div`) <$> elements (factors width)) $ \kernel_w ->
|
||||
gamble (listOfN (height * width) (height * width) sizedRealFrac) $ \input ->
|
||||
let input' = (height >< width) input
|
||||
stride_h = kernel_h
|
||||
stride_w = kernel_w
|
||||
outFast = poolForward 1 height width kernel_h kernel_w stride_h stride_w input'
|
||||
retFast = poolBackward 1 height width kernel_h kernel_w stride_h stride_w input' outFast
|
||||
|
||||
outReference = Reference.poolForward kernel_h kernel_w stride_h stride_w (output height kernel_h stride_h) (output width kernel_w stride_w) input'
|
||||
retReference = Reference.poolBackward kernel_h kernel_w stride_h stride_w input' outReference
|
||||
in outFast === outReference .&&. retFast === retReference
|
||||
|
||||
retReference = Reference.poolBackward kernel_h kernel_w stride_h stride_w input' outReference
|
||||
in outFast === outReference .&&. retFast === retReference
|
||||
|
||||
return []
|
||||
tests :: IO Bool
|
||||
|
@ -1,14 +1,6 @@
|
||||
module Test.Grenade.Layers.Internal.Reference where
|
||||
|
||||
|
||||
import Control.Monad.ST ( ST )
|
||||
|
||||
import Data.Foldable ( forM_ )
|
||||
import Data.Function ( on )
|
||||
import Data.List ( maximumBy )
|
||||
|
||||
import Numeric.LinearAlgebra
|
||||
import qualified Numeric.LinearAlgebra.Devel as U
|
||||
|
||||
im2col :: Int -> Int -> Int -> Int -> Matrix Double -> Matrix Double
|
||||
im2col nrows ncols srows scols m =
|
||||
@ -47,7 +39,6 @@ col2imfit starts krows kcols drows dcols m =
|
||||
accums = concatMap (\(conv',(stx',sty')) -> fmap (\((ix,iy), val) -> ((ix + stx', iy + sty'), val)) conv') pairs
|
||||
in accum (konst 0 (drows, dcols)) (+) accums
|
||||
|
||||
|
||||
poolForward :: Int -> Int -> Int -> Int -> Int -> Int -> Matrix Double -> Matrix Double
|
||||
poolForward nrows ncols srows scols outputRows outputCols m =
|
||||
let starts = fittingStarts (rows m) nrows srows (cols m) ncols scols
|
||||
@ -60,7 +51,7 @@ poolForwardList nrows ncols srows scols inRows inCols outputRows outputCols ms =
|
||||
|
||||
poolForwardFit :: [(Int,Int)] -> Int -> Int -> Int -> Int -> Matrix Double -> Matrix Double
|
||||
poolForwardFit starts nrows ncols _ outputCols m =
|
||||
let els = fmap (\start -> unsafeMaxElementSubmatrix start (nrows, ncols) m) starts
|
||||
let els = fmap (\start -> maxElement $ subMatrix start (nrows, ncols) m) starts
|
||||
in matrix outputCols els
|
||||
|
||||
poolBackward :: Int -> Int -> Int -> Int -> Matrix Double -> Matrix Double -> Matrix Double
|
||||
@ -76,34 +67,14 @@ poolBackwardList krows kcols srows scols inRows inCols inputMatrices =
|
||||
in uncurry (poolBackwardFit starts krows kcols) <$> inputMatrices
|
||||
|
||||
poolBackwardFit :: [(Int,Int)] -> Int -> Int -> Matrix Double -> Matrix Double -> Matrix Double
|
||||
poolBackwardFit starts krows kcols inputMatrix gradientMatrix = U.runSTMatrix $ do
|
||||
poolBackwardFit starts krows kcols inputMatrix gradientMatrix =
|
||||
let inRows = rows inputMatrix
|
||||
inCols = cols inputMatrix
|
||||
gradCol = cols gradientMatrix
|
||||
extent = (krows, kcols)
|
||||
|
||||
retM <- U.newMatrix 0 inRows inCols
|
||||
|
||||
forM_ (zip [0..] starts) $ \(ix, start) -> do
|
||||
let loc = unsafeMaxIndexSubMatrix start extent inputMatrix
|
||||
uncurry (unsafeModifyMatrix retM) loc ((+) $ uncurry (U.atM' gradientMatrix) $ divMod ix gradCol)
|
||||
|
||||
return retM
|
||||
|
||||
unsafeMaxElementSubmatrix :: (Int,Int) -> (Int,Int) -> Matrix Double -> Double
|
||||
unsafeMaxElementSubmatrix starts extent m = uncurry (U.atM' m) $ unsafeMaxIndexSubMatrix starts extent m
|
||||
|
||||
unsafeMaxIndexSubMatrix :: (Int,Int) -> (Int,Int) -> Matrix Double -> (Int, Int)
|
||||
unsafeMaxIndexSubMatrix (startRow, startCol) (extentRow, extentCold) m =
|
||||
let mrows = [startRow .. startRow + extentRow - 1]
|
||||
mcols = [startCol .. startCol + extentCold - 1]
|
||||
pairs = concatMap ( \r -> fmap (\c -> (r , c)) mcols ) mrows
|
||||
in maximumBy (compare `on` uncurry (U.atM' m)) pairs
|
||||
|
||||
|
||||
unsafeModifyMatrix :: U.STMatrix s Double -> Int -> Int -> (Double -> Double) -> ST s ()
|
||||
unsafeModifyMatrix x r c f = U.unsafeReadMatrix x r c >>= U.unsafeWriteMatrix x r c . f
|
||||
|
||||
inds = fmap (\start -> maxIndex $ subMatrix start (krows, kcols) inputMatrix) starts
|
||||
grads = toList $ flatten gradientMatrix
|
||||
grads' = zip3 starts grads inds
|
||||
accums = fmap (\((stx',sty'),grad,(inx, iny)) -> ((stx' + inx, sty' + iny), grad)) grads'
|
||||
in accum (konst 0 (inRows, inCols)) (+) accums
|
||||
|
||||
-- | These functions are not even remotely safe, but it's only called from the statically typed
|
||||
-- commands, so we should be good ?!?!?
|
||||
|
Loading…
Reference in New Issue
Block a user