Simplify, start working toward randomNet again

Huw Campbell 2016-12-05 13:46:24 +11:00
parent ae4de42556
commit a87fde1a90
16 changed files with 115 additions and 93 deletions

View File

@ -17,7 +17,6 @@ import Options.Applicative
import Grenade
-- The definition of our simple feed forward network.
-- The type level list represents the shapes passed through the layers. One can see that for this demonstration
-- we are using relu, tanh and logit non-linear units, which can be easily substituted in and out for each other.
@ -26,12 +25,10 @@ import Grenade
-- between the shapes, so inference can't do it all for us.
-- With around 100000 examples, this should show two clear circles which have been learned by the network.
randomNet :: MonadRandom m => m (Network '[ 'D1 2, 'D1 40, 'D1 40, 'D1 10, 'D1 10, 'D1 1, 'D1 1])
randomNet = do
a :: FullyConnected 2 40 <- randomFullyConnected
b :: FullyConnected 40 10 <- randomFullyConnected
c :: FullyConnected 10 1 <- randomFullyConnected
return $ a :~> Tanh :~> b :~> Relu :~> c :~> O Logit
randomNet :: MonadRandom m
=> m (Network '[ FullyConnected 2 40, Tanh, FullyConnected 40 10, Relu, FullyConnected 10 1, Logit ]
'[ 'D1 2, 'D1 40, 'D1 40, 'D1 10, 'D1 10, 'D1 1, 'D1 1])
randomNet = randomNetwork
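A minimal sketch of how this new randomNet might be exercised (hypothetical helper name, assuming the same imports as the rest of this example file):

smokeTest :: IO ()
smokeTest = do
  -- Draw a random network and push a single normalised point through it.
  net <- evalRandIO randomNet
  print $ runNet net (S1D' $ SA.vector [0.1, -0.3])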
netTest :: MonadRandom m => LearningParameters -> Int -> m String
netTest rate n = do
@ -46,7 +43,7 @@ netTest rate n = do
let trained = foldl trainEach net0 (zip inps outs)
let testIns = [ [ (x,y) | x <- [0..50] ]
| y <- [0..20] ]
| y <- [0..20] ]
let outMat = fmap (fmap (\(x,y) -> (render . normx) $ runNet trained (S1D' $ SA.vector [x / 25 - 1,y / 10 - 1]))) testIns
return $ unlines outMat
@ -54,7 +51,7 @@ netTest rate n = do
where
inCircle :: KnownNat n => SA.R n -> (SA.R n, Double) -> Bool
v `inCircle` (o, r) = SA.norm_2 (v - o) <= r
trainEach !nt !(i, o) = train rate i o nt
trainEach !nt !(i, o) = train rate nt i o
render n' | n' <= 0.2 = ' '
| n' <= 0.4 = '.'
@ -65,7 +62,6 @@ netTest rate n = do
normx :: S' ('D1 1) -> Double
normx (S1D' r) = SA.mean r
data FeedForwardOpts = FeedForwardOpts Int LearningParameters
feedForward' :: Parser FeedForwardOpts

View File

@ -30,19 +30,15 @@ import Grenade
-- With the mnist data from Kaggle normalised to doubles between 0 and 1, a learning rate of 0.01, and 15 iterations,
-- this network should get down to about a 1.3% error rate.
randomMnistNet :: MonadRandom m => m (Network '[ 'D2 28 28, 'D3 24 24 10, 'D3 12 12 10, 'D3 12 12 10, 'D3 8 8 16, 'D3 4 4 16, 'D1 256, 'D1 256, 'D1 80, 'D1 80, 'D1 10, 'D1 10])
randomMnistNet = do
a :: Convolution 1 10 5 5 1 1 <- randomConvolution
let b :: Pooling 2 2 2 2 = Pooling
c :: Convolution 10 16 5 5 1 1 <- randomConvolution
let d :: Pooling 2 2 2 2 = Pooling
e :: FullyConnected 256 80 <- randomFullyConnected
f :: FullyConnected 80 10 <- randomFullyConnected
return $ a :~> b :~> Relu :~> c :~> d :~> FlattenLayer :~> Relu :~> e :~> Logit :~> f :~> O Logit
randomMnist :: MonadRandom m
=> m (Network '[ Convolution 1 10 5 5 1 1, Pooling 2 2 2 2, Relu, Convolution 10 16 5 5 1 1, Pooling 2 2 2 2, FlattenLayer, Relu, FullyConnected 256 80, Logit, FullyConnected 80 10, Logit]
'[ 'D2 28 28, 'D3 24 24 10, 'D3 12 12 10, 'D3 12 12 10, 'D3 8 8 16, 'D3 4 4 16, 'D1 256, 'D1 256, 'D1 80, 'D1 80, 'D1 10, 'D1 10])
randomMnist = randomNetwork
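As a sanity check on the shape list above, the spatial sizes follow the usual valid convolution and pooling arithmetic; a small sketch, not part of this commit:

-- Output side length of a valid convolution or pooling step.
convOutput :: Int -> Int -> Int -> Int
convOutput input kernel stride = (input - kernel) `div` stride + 1
-- 28 -> convOutput 28 5 1 = 24   (Convolution 1 10 5 5 1 1)
-- 24 -> convOutput 24 2 2 = 12   (Pooling 2 2 2 2)
-- 12 -> convOutput 12 5 1 =  8   (Convolution 10 16 5 5 1 1)
--  8 -> convOutput  8 2 2 =  4   (Pooling 2 2 2 2)
-- and 4 * 4 * 16 = 256, matching the flattened 'D1 256 shape.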
convTest :: Int -> FilePath -> FilePath -> LearningParameters -> IO ()
convTest iterations trainFile validateFile rate = do
net0 <- evalRandIO randomMnistNet
net0 <- evalRandIO randomMnist
fT <- T.readFile trainFile
fV <- T.readFile validateFile
let trainRows = traverse (A.parseOnly p) (T.lines fT)
@ -52,7 +48,7 @@ convTest iterations trainFile validateFile rate = do
err -> print err
where
trainEach !rate' !nt !(i, o) = train rate' i o nt
trainEach !rate' !nt !(i, o) = train rate' nt i o
p :: A.Parser (S' ('D2 28 28), S' ('D1 10))
p = do

View File

@ -8,14 +8,20 @@
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE LambdaCase #-}
module Grenade.Core.Network (
Layer (..)
, Network (..)
, UpdateLayer (..)
, LearningParameters (..)
, Gradients (..)
, CreatableNetwork (..)
) where
import Control.Monad.Random (MonadRandom)
import Grenade.Core.Shape
data LearningParameters = LearningParameters {
@ -26,12 +32,14 @@ data LearningParameters = LearningParameters {
-- | Class for updating a layer. All layers implement this, and it is
-- shape independent.
class UpdateLayer x where
class Show x => UpdateLayer x where
-- | The type for the gradient for this layer.
-- Unit if there isn't a gradient to pass back.
type Gradient x :: *
-- | Update a layer with its gradient and learning parameters
runUpdate :: LearningParameters -> x -> Gradient x -> x
-- | Create a random layer; many layers will use pure
createRandom :: MonadRandom m => m x
-- | Class for a layer. All layers implement this; however, they don't
-- need to implement it for all shapes, only ones which are appropriate.
@ -46,19 +54,35 @@ class UpdateLayer x => Layer x (i :: Shape) (o :: Shape) where
runBackards :: x -> S' i -> S' o -> (Gradient x, S' i)
-- | Type of a network.
-- The [*] type specifies the types of the layers. This is needed for parallel
-- running and bringing all the gradients back together.
-- The [Shape] type specifies the shapes of data passed between the layers.
-- Could be considered to be a heterogeneous list of layers which are able to
-- transform the data shapes of the network.
data Network :: [Shape] -> * where
O :: (Show x, Layer x i o, KnownShape o, KnownShape i)
=> !x
-> Network '[i, o]
(:~>) :: (Show x, Layer x i h, KnownShape h, KnownShape i)
=> !x
-> !(Network (h ': hs))
-> Network (i ': h ': hs)
data Network :: [*] -> [Shape] -> * where
O :: Layer x i o => !x -> Network '[x] '[i, o]
(:~>) :: Layer x i h => !x -> !(Network xs (h ': hs)) -> Network (x ': xs) (i ': h ': hs)
infixr 5 :~>
instance Show (Network h) where
instance Show (Network l h) where
show (O a) = "O " ++ show a
show (i :~> o) = show i ++ "\n:~>\n" ++ show o
-- | Gradients of a network.
-- Parameterised on the layers of a Network.
data Gradients :: [*] -> * where
OG :: UpdateLayer x => Gradient x -> Gradients '[x]
(:/>) :: UpdateLayer x => Gradient x -> Gradients xs -> Gradients (x ': xs)
-- | A network can easily be created by hand with (:~>), but an easy way to initialise a random
-- network is with randomNetwork.
class CreatableNetwork (xs :: [*]) (ss :: [Shape]) where
-- | Create a network of the types requested
randomNetwork :: MonadRandom m => m (Network xs ss)
instance Layer x i o => CreatableNetwork (x ': '[]) (i ': o ': '[]) where
randomNetwork = O <$> createRandom
instance (Layer x i o, CreatableNetwork xs (o ': r ': rs)) => CreatableNetwork (x ': xs) (i ': o ': r ': rs) where
randomNetwork = (:~>) <$> createRandom <*> randomNetwork
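For illustration, a tiny network can be spelled out by hand with (:~>) and O, or initialised through the new class; the names below are hypothetical and assume DataKinds plus the FullyConnected and Logit layers in scope:

type Tiny = Network '[ FullyConnected 2 3, Logit ] '[ 'D1 2, 'D1 3, 'D1 3 ]

randomTiny :: MonadRandom m => m Tiny
randomTiny = randomNetwork

handTiny :: FullyConnected 2 3 -> Tiny
handTiny fc = fc :~> O Logit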

View File

@ -7,52 +7,64 @@
module Grenade.Core.Runner (
train
, backPropagate
, runNet
, applyUpdate
) where
import Data.Singletons.Prelude
import Grenade.Core.Network
import Grenade.Core.Shape
-- | Update a network with new weights after training with an instance.
train :: forall i o hs. (Head hs ~ i, Last hs ~ o, KnownShape i, KnownShape o)
=> LearningParameters -- ^ learning rate
-> S' i -- ^ input vector
-> S' o -- ^ target vector
-> Network hs -- ^ network to train
-> Network hs
train rate x0 target = fst . go x0
-- | Drive a network and collect its back propagated gradients.
backPropagate :: forall input output shapes layers. (Head shapes ~ input, Last shapes ~ output)
=> Network layers shapes -> S' input -> S' output -> Gradients layers
backPropagate network input target =
fst $ go input network
where
go :: forall j js. (Head js ~ j, Last js ~ o, KnownShape j)
=> S' j -- ^ input vector
-> Network js -- ^ network to train
-> (Network js, S' j)
go :: forall j js sublayers. (Head js ~ j, Last js ~ output)
=> S' j -- ^ input vector
-> Network sublayers js -- ^ network to train
-> (Gradients sublayers, S' j)
-- handle input from the beginning, feeding upwards.
go !x (layer :~> n)
= let y = runForwards layer x
-- run the rest of the network, and get the layer from above.
-- recursively run the rest of the network, and get the layer from above.
(n', dWs') = go y n
-- calculate the gradient for this layer to pass down,
(layer', dWs) = runBackards layer x dWs'
-- Update this layer using the gradient
newLayer = runUpdate rate layer layer'
in (newLayer :~> n', dWs)
in (layer' :/> n', dWs)
-- handle the output layer, bouncing the derivatives back down.
go !x (O layer)
= let y = runForwards layer x
-- the gradient (how much y affects the error)
(layer', dWs) = runBackards layer x (y - target)
newLayer = runUpdate rate layer layer'
in (O newLayer, dWs)
in (OG layer', dWs)
-- | Update a network with new weights after training with an instance.
train :: forall input output shapes layers. (Head shapes ~ input, Last shapes ~ output)
=> LearningParameters -- ^ learning rate
-> Network layers shapes -- ^ network to train
-> S' input -> S' output -- ^ target vector
-> Network layers shapes
train rate network input output =
let grads = backPropagate network input output
in applyUpdate rate network grads
applyUpdate :: LearningParameters -> Network ls ss -> Gradients ls -> Network ls ss
applyUpdate rate (O layer) (OG gradient)
= O (runUpdate rate layer gradient)
applyUpdate rate (layer :~> rest) (gradient :/> grest)
= runUpdate rate layer gradient :~> applyUpdate rate rest grest
applyUpdate _ _ _
= error "Impossible for the gradients of a network to have a different length to the network"
-- | Just forwards propagation with no training.
runNet :: Network hs
runNet :: Network layers hs
-> S' (Head hs) -- ^ input vector
-> S' (Last hs) -- ^ output vector
runNet (layer :~> n) !x = let y = runForwards layer x
in runNet n y
runNet (layer :~> n) !x = let y = runForwards layer x in runNet n y
runNet (O layer) !x = runForwards layer x

View File

@ -1,9 +1,7 @@
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE PolyKinds #-}
@ -18,11 +16,9 @@
module Grenade.Core.Shape (
Shape (..)
, S' (..)
, KnownShape (..)
) where
import Data.Singletons.TypeLits
import Data.Proxy
import Numeric.LinearAlgebra.Static
@ -36,7 +32,7 @@ data Shape =
| D2 Nat Nat
| D3 Nat Nat Nat
instance KnownShape x => Num (S' x) where
instance Num (S' x) where
(+) (S1D' x) (S1D' y) = S1D' (x + y)
(+) (S2D' x) (S2D' y) = S2D' (x + y)
(+) (S3D' x) (S3D' y) = S3D' (vectorZip (+) x y)
@ -72,16 +68,3 @@ instance Show (S' n) where
show (S1D' a) = "S1D' " ++ show a
show (S2D' a) = "S2D' " ++ show a
show (S3D' a) = "S3D' " ++ show a
-- | Singleton for Shape
class KnownShape (n :: Shape) where
shapeSing :: Proxy n
instance KnownShape ('D1 n) where
shapeSing = Proxy
instance KnownShape ('D2 n m) where
shapeSing = Proxy
instance KnownShape ('D3 l n m) where
shapeSing = Proxy

View File

@ -119,7 +119,14 @@ randomConvolution = do
mm = konst 0
return $ Convolution wN mm
instance UpdateLayer (Convolution channels filters kernelRows kernelCols strideRows strideCols) where
instance ( KnownNat channels
, KnownNat filters
, KnownNat kernelRows
, KnownNat kernelColumns
, KnownNat strideRows
, KnownNat strideColumns
, KnownNat (kernelRows * kernelColumns * channels)
) => UpdateLayer (Convolution channels filters kernelRows kernelColumns strideRows strideColumns) where
type Gradient (Convolution channels filters kernelRows kernelCols strideRows strideCols) = (Convolution' channels filters kernelRows kernelCols strideRows strideCols)
runUpdate LearningParameters {..} (Convolution oldKernel oldMomentum) (Convolution' kernelGradient) =
let newMomentum = konst learningMomentum * oldMomentum - konst learningRate * kernelGradient
@ -127,6 +134,8 @@ instance UpdateLayer (Convolution channels filters kernelRows kernelCols strideR
newKernel = oldKernel + newMomentum - regulariser
in Convolution newKernel newMomentum
createRandom = randomConvolution
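The kernel update above is stochastic gradient descent with momentum and L2 weight decay; a scalar sketch of the same arithmetic (the decay term is an assumption, since the regulariser definition falls outside this hunk):

sgdMomentumStep :: Double -> Double -> Double  -- rate, momentum coefficient, decay coefficient
                -> Double -> Double -> Double  -- weight, momentum, gradient
                -> (Double, Double)            -- (new weight, new momentum)
sgdMomentumStep rate mom reg w m g =
  let m'  = mom * m - rate * g   -- momentum accumulates the negative gradient
      dec = rate * reg * w       -- assumed L2 decay on the old weight
  in (w + m' - dec, m')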
-- | A two dimensional image may have a convolution filter applied to it
instance ( KnownNat kernelRows
, KnownNat kernelCols
@ -139,6 +148,7 @@ instance ( KnownNat kernelRows
, KnownNat outputCols
, ((outputRows - 1) * strideRows) ~ (inputRows - kernelRows)
, ((outputCols - 1) * strideCols) ~ (inputCols - kernelCols)
, KnownNat (kernelRows * kernelCols * 1)
) => Layer (Convolution 1 filters kernelRows kernelCols strideRows strideCols) ('D2 inputRows inputCols) ('D3 outputRows outputCols filters) where
runForwards (Convolution kernel _) (S2D' input) =
let ex = extract input
@ -192,6 +202,7 @@ instance ( KnownNat kernelRows
, KnownNat channels
, ((outputRows - 1) * strideRows) ~ (inputRows - kernelRows)
, ((outputCols - 1) * strideCols) ~ (inputCols - kernelCols)
, KnownNat (kernelRows * kernelCols * channels)
) => Layer (Convolution channels filters kernelRows kernelCols strideRows strideCols) ('D3 inputRows inputCols channels) ('D3 outputRows outputCols filters) where
runForwards (Convolution kernel _) (S3D' input) =
let ex = vecToList $ fmap extract input

View File

@ -28,11 +28,7 @@ data Crop :: Nat
-> Nat
-> Nat
-> Nat -> * where
Crop :: ( KnownNat cropLeft
, KnownNat cropTop
, KnownNat cropRight
, KnownNat cropBottom
) => Crop cropLeft cropTop cropRight cropBottom
Crop :: Crop cropLeft cropTop cropRight cropBottom
instance Show (Crop cropLeft cropTop cropRight cropBottom) where
show Crop = "Crop"
@ -40,6 +36,7 @@ instance Show (Crop cropLeft cropTop cropRight cropBottom) where
instance UpdateLayer (Crop l t r b) where
type Gradient (Crop l t r b) = ()
runUpdate _ x _ = x
createRandom = return Crop
-- | A two dimensional image can be cropped.
instance ( KnownNat cropLeft

View File

@ -34,6 +34,7 @@ data Dropout o =
instance (KnownNat i) => UpdateLayer (Dropout i) where
type Gradient (Dropout i) = ()
runUpdate _ x _ = x
createRandom = randomDropout 0.95
randomDropout :: (MonadRandom m, KnownNat i)
=> Double -> m (Dropout i)

View File

@ -28,6 +28,8 @@ data FlattenLayer = FlattenLayer
instance UpdateLayer FlattenLayer where
type Gradient FlattenLayer = ()
runUpdate _ _ _ = FlattenLayer
createRandom = return FlattenLayer
instance (KnownNat a, KnownNat x, KnownNat y, a ~ (x * y)) => Layer FlattenLayer ('D2 x y) ('D1 a) where
runForwards _ (S2D' y) = S1D' . fromList . toList . flatten . extract $ y

View File

@ -45,6 +45,8 @@ instance (KnownNat i, KnownNat o) => UpdateLayer (FullyConnected i o) where
newActivations = oldActivations + newMomentum - regulariser
in FullyConnected newBias newBiasMomentum newActivations newMomentum
createRandom = randomFullyConnected
instance (KnownNat i, KnownNat o) => Layer (FullyConnected i o) ('D1 i) ('D1 o) where
-- Do a matrix vector multiplication and return the result.
runForwards (FullyConnected wB _ wN _) (S1D' v) = S1D' (wB + wN #> v)

View File

@ -23,23 +23,24 @@ import Grenade.Core.Shape
-- This does however have a trade-off: internal incremental states in the Wengert tape are
-- not retained during reverse accumulation. So less RAM is used, but more compute is required.
data Fuse :: * -> * -> Shape -> Shape -> Shape -> * where
(:$$) :: (Show x, Show y, Layer x i h, Layer y h o, KnownShape h, KnownShape i, KnownShape o)
(:$$) :: (Layer x i h, Layer y h o)
=> !x
-> !y
-> Fuse x y i h o
infixr 5 :$$
instance Show (Fuse x y i h o) where
instance (Show x, Show y) => Show (Fuse x y i h o) where
show (x :$$ y) = "(" ++ show x ++ " :$$ " ++ show y ++ ")"
instance (KnownShape i, KnownShape h, KnownShape o) => UpdateLayer (Fuse x y i h o) where
instance (Layer x i h, Layer y h o) => UpdateLayer (Fuse x y i h o) where
type Gradient (Fuse x y i h o) = (Gradient x, Gradient y)
runUpdate lr (x :$$ y) (x', y') =
let newX = runUpdate lr x x'
newY = runUpdate lr y y'
in (newX :$$ newY)
createRandom = (:$$) <$> createRandom <*> createRandom
instance (KnownShape i, KnownShape h, KnownShape o) => Layer (Fuse x y i h o) i o where
instance (Layer x i h, Layer y h o) => Layer (Fuse x y i h o) i o where
runForwards (x :$$ y) input =
let yInput :: S' h = runForwards x input
in runForwards y yInput
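A hypothetical use of Fuse, pairing a fully connected layer with a logistic activation so the two update and train as a single layer (assumes MonadRandom and those layer modules are in scope):

type DenseLogit = Fuse (FullyConnected 4 2) Logit ('D1 4) ('D1 2) ('D1 2)

randomDenseLogit :: MonadRandom m => m DenseLogit
randomDenseLogit = createRandom  -- builds both halves via their own createRandom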

View File

@ -25,6 +25,7 @@ data Logit = Logit
instance UpdateLayer Logit where
type Gradient Logit = ()
runUpdate _ _ _ = Logit
createRandom = return Logit
instance (KnownNat i) => Layer Logit ('D1 i) ('D1 i) where
runForwards _ (S1D' y) = S1D' (logistic y)

View File

@ -28,11 +28,7 @@ data Pad :: Nat
-> Nat
-> Nat
-> Nat -> * where
Pad :: ( KnownNat padLeft
, KnownNat padTop
, KnownNat padRight
, KnownNat padBottom
) => Pad padLeft padTop padRight padBottom
Pad :: Pad padLeft padTop padRight padBottom
instance Show (Pad padLeft padTop padRight padBottom) where
show Pad = "Pad"
@ -40,6 +36,7 @@ instance Show (Pad padLeft padTop padRight padBottom) where
instance UpdateLayer (Pad l t r b) where
type Gradient (Pad l t r b) = ()
runUpdate _ x _ = x
createRandom = return Pad
-- | A two dimensional image can be padded.
instance ( KnownNat padLeft

View File

@ -41,11 +41,7 @@ data Pooling :: Nat
-> Nat
-> Nat
-> Nat -> * where
Pooling :: ( KnownNat kernelRows
, KnownNat kernelColumns
, KnownNat strideRows
, KnownNat strideColumns
) => Pooling kernelRows kernelColumns strideRows strideColumns
Pooling :: Pooling kernelRows kernelColumns strideRows strideColumns
instance Show (Pooling k k' s s') where
show Pooling = "Pooling"
@ -54,6 +50,7 @@ instance Show (Pooling k k' s s') where
instance UpdateLayer (Pooling kernelRows kernelColumns strideRows strideColumns) where
type Gradient (Pooling kr kc sr sc) = ()
runUpdate _ Pooling _ = Pooling
createRandom = return Pooling
-- | A two dimensional image can be pooled.
instance ( KnownNat kernelRows

View File

@ -25,6 +25,7 @@ data Relu = Relu
instance UpdateLayer Relu where
type Gradient Relu = ()
runUpdate _ _ _ = Relu
createRandom = return Relu
instance ( KnownNat i) => Layer Relu ('D1 i) ('D1 i) where
runForwards _ (S1D' y) = S1D' (relu y)

View File

@ -15,13 +15,14 @@ import Grenade.Core.Network
import Grenade.Core.Shape
-- | A Tanh layer.
-- A layer which can act between any shape of the same dimension, perfoming an tanh function.s
-- A layer which can act between any shape of the same dimension, performing a tanh function.
data Tanh = Tanh
deriving Show
instance UpdateLayer Tanh where
type Gradient Tanh = ()
runUpdate _ _ _ = Tanh
createRandom = return Tanh
instance KnownNat i => Layer Tanh ('D1 i) ('D1 i) where
runForwards _ (S1D' y) = S1D' (tanh y)