Remove Show constraint on UpdateLayer

This commit is contained in:
Huw Campbell 2017-01-24 13:47:47 +11:00
parent 148c7778ab
commit d82a121797
8 changed files with 25 additions and 16 deletions

2
mafia
View File

@@ -120,4 +120,4 @@ case "$MODE" in
upgrade) shift; run_upgrade "$@" ;;
*) exec_mafia "$@"
esac
# Version: 3044e63eb472fb9e16926d4ab2ca9dd9e255829c
# Version: 360716306a06db842ec022a7b9f161d2208483f0

View File

@@ -33,8 +33,8 @@ import Grenade.Utils.OneHot
-- With the mnist data from Kaggle normalised to doubles between 0 and 1, learning rate of 0.01 and 15 iterations,
-- this network should get down to about a 1.3% error rate.
type MNIST = Network '[ Convolution 1 10 5 5 1 1, Pooling 2 2 2 2, Relu, Convolution 10 16 5 5 1 1, Pooling 2 2 2 2, FlattenLayer, Relu, FullyConnected 256 80, Logit, FullyConnected 80 10, Logit]
'[ 'D2 28 28, 'D3 24 24 10, 'D3 12 12 10, 'D3 12 12 10, 'D3 8 8 16, 'D3 4 4 16, 'D1 256, 'D1 256, 'D1 80, 'D1 80, 'D1 10, 'D1 10]
type MNIST = Network '[ Convolution 1 10 5 5 1 1, Pooling 2 2 2 2, Relu, Convolution 10 16 5 5 1 1, Pooling 2 2 2 2, Relu, FlattenLayer, FullyConnected 256 80, Logit, FullyConnected 80 10, Logit]
'[ 'D2 28 28, 'D3 24 24 10, 'D3 12 12 10, 'D3 12 12 10, 'D3 8 8 16, 'D3 4 4 16, 'D3 4 4 16, 'D1 256, 'D1 80, 'D1 80, 'D1 10, 'D1 10]
randomMnist :: MonadRandom m => m MNIST
randomMnist = randomNetwork

View File

@@ -13,6 +13,7 @@ import Grenade.Layers.Flatten as X
import Grenade.Layers.Fuse as X
import Grenade.Layers.FullyConnected as X
import Grenade.Layers.Logit as X
import Grenade.Layers.Convolution as X
import Grenade.Layers.Relu as X
import Grenade.Layers.Tanh as X
import Grenade.Layers.Convolution as X
import Grenade.Layers.Relu as X
import Grenade.Layers.Tanh as X
import Grenade.Layers.Softmax as X

View File

@@ -42,7 +42,7 @@ data LearningParameters = LearningParameters {
-- | Class for updating a layer. All layers implement this, and it is
-- shape independent.
class Show x => UpdateLayer x where
class UpdateLayer x where
{-# MINIMAL runUpdate, createRandom #-}
-- | The type for the gradient for this layer.
-- Unit if there isn't a gradient to pass back.
@@ -86,9 +86,11 @@ data Network :: [*] -> [Shape] -> * where
(:~>) :: (SingI i, SingI h, Layer x i h) => !x -> !(Network xs (h ': hs)) -> Network (x ': xs) (i ': h ': hs)
infixr 5 :~>
instance Show (Network l h) where
instance Show (Network '[] '[i]) where
show NNil = "NNil"
show (i :~> o) = show i ++ "\n:~>\n" ++ show o
instance (Show x, Show (Network xs rs)) => Show (Network (x ': xs) (i ': rs)) where
show (x :~> xs) = show x ++ "\n~>\n" ++ show xs
-- | Gradients of a network.
-- Parameterised on the layers of a Network.

View File

@@ -14,6 +14,9 @@ Stability : experimental
This module defines simple back propagation and training functions
for a network.
-}
-- GHC 7.10 doesn't think that go is complete
{-# OPTIONS_GHC -fno-warn-incomplete-patterns #-}
module Grenade.Core.Runner (
train
, backPropagate
@@ -48,7 +51,6 @@ backPropagate network input target =
(n', dWs') = go y n
-- calculate the gradient for this layer to pass down,
(layer', dWs) = runBackwards layer tape dWs'
in (layer' :/> n', dWs)
-- Bouncing the derivatives back down.

View File

@@ -31,7 +31,7 @@ instance UpdateLayer FlattenLayer where
runUpdate _ _ _ = FlattenLayer
createRandom = return FlattenLayer
instance (KnownNat a, KnownNat x, KnownNat y, a ~ (x * z)) => Layer FlattenLayer ('D2 x y) ('D1 a) where
instance (KnownNat a, KnownNat x, KnownNat y, a ~ (x * y)) => Layer FlattenLayer ('D2 x y) ('D1 a) where
type Tape FlattenLayer ('D2 x y) ('D1 a) = ()
runForwards _ (S2D y) = ((), fromJust' . fromStorable . flatten . extract $ y)
runBackwards _ _ (S1D y) = ((), fromJust' . fromStorable . extract $ y)

View File

@@ -59,11 +59,12 @@ data RecurrentNetwork :: [*] -> [Shape] -> * where
infixr 5 :~~>
infixr 5 :~@>
instance Show (RecurrentNetwork l h) where
show RNil = "RNil"
show (i :~~> o) = show i ++ "\n:~~>\n" ++ show o
show (i :~@> o) = show i ++ "\n:~@>\n" ++ show o
instance Show (RecurrentNetwork '[] '[i]) where
show RNil = "NNil"
instance (Show x, Show (RecurrentNetwork xs rs)) => Show (RecurrentNetwork (FeedForward x ': xs) (i ': rs)) where
show (x :~~> xs) = show x ++ "\n~~>\n" ++ show xs
instance (Show x, Show (RecurrentNetwork xs rs)) => Show (RecurrentNetwork (Recurrent x ': xs) (i ': rs)) where
show (x :~@> xs) = show x ++ "\n~~>\n" ++ show xs
-- | Recurrent inputs (sideways shapes on an imaginary unrolled graph)
-- Parameterised on the layers of a Network.

View File

@@ -7,6 +7,9 @@
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE RecordWildCards #-}
-- GHC 7.10 doesn't think that go is complete
{-# OPTIONS_GHC -fno-warn-incomplete-patterns #-}
module Grenade.Recurrent.Core.Runner (
trainRecurrent
, runRecurrent