mirror of
https://github.com/HuwCampbell/grenade.git
synced 2024-11-22 06:55:13 +03:00
Remove Typeable, fix tests
This commit is contained in:
parent
114dab4103
commit
cd4598bee8
@ -16,7 +16,6 @@ module Grenade.Core.Network (
|
||||
, LearningParameters (..)
|
||||
) where
|
||||
|
||||
import Data.Typeable
|
||||
import Grenade.Core.Shape
|
||||
|
||||
data LearningParameters = LearningParameters {
|
||||
@ -51,10 +50,10 @@ class UpdateLayer m x => Layer (m :: * -> *) x (i :: Shape) (o :: Shape) where
|
||||
-- Could be considered to be a heterogeneous list of layers which are able to
|
||||
-- transform the data shapes of the network.
|
||||
data Network :: (* -> *) -> [Shape] -> * where
|
||||
O :: (Typeable x, Show x, Layer m x i o, KnownShape o, KnownShape i)
|
||||
O :: (Show x, Layer m x i o, KnownShape o, KnownShape i)
|
||||
=> !x
|
||||
-> Network m '[i, o]
|
||||
(:~>) :: (Typeable x, Show x, Layer m x i h, KnownShape h, KnownShape i)
|
||||
(:~>) :: (Show x, Layer m x i h, KnownShape h, KnownShape i)
|
||||
=> !x
|
||||
-> !(Network m (h ': hs))
|
||||
-> Network m (i ': h ': hs)
|
||||
|
@ -14,6 +14,7 @@
|
||||
|
||||
module Grenade.Layers.Convolution (
|
||||
Convolution (..)
|
||||
, Convolution' (..)
|
||||
, randomConvolution
|
||||
, im2col
|
||||
, vid2col
|
||||
|
@ -93,11 +93,12 @@ prop_simple_conv_forwards = once $
|
||||
, 0.0, 0.0, 0.0, 0.0
|
||||
, 0.0, 0.0, 0.0, 0.0
|
||||
, 0.0, 0.0, 0.0, 0.0 ] :: HStatic.L 4 4)
|
||||
--expectedKernel = (HStatic.matrix
|
||||
-- [ 0.0, 0.0, 0.0, -2.0
|
||||
-- ,-2.0, 1.0, 1.0, -5.0
|
||||
-- ,-3.0, -1.0, 1.0, -5.0
|
||||
-- ,-5.0, 0.0, 0.0, -7.0 ] :: HStatic.L 4 4)
|
||||
|
||||
expectedGradient = (HStatic.matrix
|
||||
[ 1.0, 0.0, 0.0, 2.0
|
||||
, 2.0, 0.0, 0.0, 5.0
|
||||
, 3.0, 0.0, 0.0, 4.0
|
||||
, 4.0, 0.0, 0.0, 6.0 ] :: HStatic.L 4 4)
|
||||
|
||||
convLayer = Convolution myKernel zeroKernel :: Convolution 1 4 2 2 1 1
|
||||
|
||||
@ -128,12 +129,13 @@ prop_simple_conv_forwards = once $
|
||||
expectBack = (HStatic.matrix
|
||||
[ 1.0, 0.0, 0.0
|
||||
, 0.0, -2.0,-1.0] :: HStatic.L 2 3)
|
||||
(nc, inX) = runIdentity $ runBackards 1 convLayer input grad :: ( Convolution 1 4 2 2 1 1 , S' ('D2 2 3))
|
||||
(nc, inX) = runIdentity $ runBackards convLayer input grad
|
||||
|
||||
in case (out, inX, nc) of
|
||||
(S3D' out' , S2D' inX', Convolution _ _)
|
||||
(S3D' out' , S2D' inX', Convolution' backGrad)
|
||||
-> ((HStatic.extract <$> expect) === (HStatic.extract <$> vecToList out'))
|
||||
.&&. ((HStatic.extract expectBack) === (HStatic.extract inX'))
|
||||
.&&. ((HStatic.extract expectedGradient) === (HStatic.extract backGrad))
|
||||
-- Temporarily disabled, as l2 adjustment puts in off 5%
|
||||
-- .&&. HStatic.extract expectedKernel === HStatic.extract kernel'
|
||||
|
||||
@ -203,11 +205,12 @@ prop_single_conv_forwards = once $
|
||||
, 0.0, 0.0, 0.0, 0.0
|
||||
, 0.0, 0.0, 0.0, 0.0
|
||||
, 0.0, 0.0, 0.0, 0.0 ] :: HStatic.L 4 4)
|
||||
--expectedKernel = (HStatic.matrix
|
||||
-- [ 0.0, 0.0, 0.0, -2.0
|
||||
-- ,-2.0, 1.0, 1.0, -5.0
|
||||
-- ,-3.0, -1.0, 1.0, -5.0
|
||||
-- ,-5.0, 0.0, 0.0, -7.0 ] :: HStatic.L 4 4)
|
||||
|
||||
expectedGradient = (HStatic.matrix
|
||||
[ 1.0, 0.0, 0.0, 2.0
|
||||
, 2.0, 0.0, 0.0, 5.0
|
||||
, 3.0, 0.0, 0.0, 4.0
|
||||
, 4.0, 0.0, 0.0, 6.0 ] :: HStatic.L 4 4)
|
||||
|
||||
convLayer = Convolution myKernel zeroKernel :: Convolution 1 4 2 2 1 1
|
||||
|
||||
@ -238,13 +241,13 @@ prop_single_conv_forwards = once $
|
||||
expectBack = (HStatic.matrix
|
||||
[ 1.0, 0.0, 0.0
|
||||
, 0.0, -2.0,-1.0] :: HStatic.L 2 3)
|
||||
(nc, inX) = runIdentity $ runBackards 1 convLayer input grad :: ( Convolution 1 4 2 2 1 1 , S' ('D3 2 3 1))
|
||||
(nc, inX) = runIdentity $ runBackards convLayer input grad
|
||||
|
||||
in case (out, inX, nc) of
|
||||
(S3D' out' , S3D' inX', Convolution _ _)
|
||||
(S3D' out' , S3D' inX', Convolution' backGrad)
|
||||
-> ((HStatic.extract <$> expect) === (HStatic.extract <$> vecToList out'))
|
||||
.&&. ([HStatic.extract expectBack] === (HStatic.extract <$> vecToList inX'))
|
||||
-- .&&. HStatic.extract expectedKernel === HStatic.extract kernel'
|
||||
.&&. ((HStatic.extract expectedGradient) === (HStatic.extract backGrad))
|
||||
|
||||
return []
|
||||
tests :: IO Bool
|
||||
|
Loading…
Reference in New Issue
Block a user