mirror of
https://github.com/HuwCampbell/grenade.git
synced 2024-11-22 06:55:13 +03:00
Add pad and crop layers, add pad to mnist example
This commit is contained in:
parent
047ee6a08c
commit
8b288cca9d
@ -38,6 +38,7 @@ library
|
||||
Grenade.Core.Runner
|
||||
Grenade.Core.Shape
|
||||
Grenade.Core.Phase
|
||||
Grenade.Layers.Crop
|
||||
Grenade.Layers.Convolution
|
||||
Grenade.Layers.Dropout
|
||||
Grenade.Layers.FullyConnected
|
||||
@ -46,6 +47,7 @@ library
|
||||
Grenade.Layers.Logit
|
||||
Grenade.Layers.Relu
|
||||
Grenade.Layers.Tanh
|
||||
Grenade.Layers.Pad
|
||||
Grenade.Layers.Pooling
|
||||
|
||||
|
||||
|
@ -32,15 +32,16 @@ import Grenade
|
||||
|
||||
-- With the mnist data from Kaggle normalised to doubles between 0 and 1, learning rate of 0.01 and 15 iterations,
|
||||
-- this network should get down to about a 1.3% error rate.
|
||||
randomMnistNet :: (MonadRandom m) => m (Network Identity '[('D2 28 28), ('D3 24 24 10), ('D3 12 12 10), ('D3 12 12 10), ('D3 8 8 16), ('D3 4 4 16), ('D1 256), ('D1 256), ('D1 80), ('D1 80), ('D1 10), ('D1 10)])
|
||||
randomMnistNet :: (MonadRandom m) => m (Network Identity '[('D2 28 28), ('D2 32 32), ('D3 28 28 10), ('D3 14 14 10), ('D3 14 14 10), ('D3 10 10 16), ('D3 5 5 16), ('D1 400), ('D1 400), ('D1 80), ('D1 80), ('D1 10), ('D1 10)])
|
||||
randomMnistNet = do
|
||||
let pad :: Pad 2 2 2 2 = Pad
|
||||
a :: Convolution 1 10 5 5 1 1 <- randomConvolution
|
||||
let b :: Pooling 2 2 2 2 = Pooling
|
||||
c :: Convolution 10 16 5 5 1 1 <- randomConvolution
|
||||
let d :: Pooling 2 2 2 2 = Pooling
|
||||
e :: FullyConnected 256 80 <- randomFullyConnected
|
||||
e :: FullyConnected 400 80 <- randomFullyConnected
|
||||
f :: FullyConnected 80 10 <- randomFullyConnected
|
||||
return $ a :~> b :~> Relu :~> c :~> d :~> FlattenLayer :~> Relu :~> e :~> Logit :~> f :~> O Logit
|
||||
return $ pad :~> a :~> b :~> Relu :~> c :~> d :~> FlattenLayer :~> Relu :~> e :~> Logit :~> f :~> O Logit
|
||||
|
||||
convTest :: Int -> FilePath -> FilePath -> Double -> IO ()
|
||||
convTest iterations trainFile validateFile rate = do
|
||||
|
@ -20,7 +20,9 @@ import Grenade.Core.Network as X
|
||||
import Grenade.Core.Runner as X
|
||||
import Grenade.Core.Shape as X
|
||||
import Grenade.Core.Phase as X
|
||||
import Grenade.Layers.Crop as X
|
||||
import Grenade.Layers.Dropout as X
|
||||
import Grenade.Layers.Pad as X
|
||||
import Grenade.Layers.Pooling as X
|
||||
import Grenade.Layers.Flatten as X
|
||||
import Grenade.Layers.Fuse as X
|
||||
|
70
src/Grenade/Layers/Crop.hs
Normal file
70
src/Grenade/Layers/Crop.hs
Normal file
@ -0,0 +1,70 @@
|
||||
{-# LANGUAGE BangPatterns #-}
|
||||
{-# LANGUAGE DataKinds #-}
|
||||
{-# LANGUAGE ScopedTypeVariables #-}
|
||||
{-# LANGUAGE StandaloneDeriving #-}
|
||||
{-# LANGUAGE GADTs #-}
|
||||
{-# LANGUAGE TypeOperators #-}
|
||||
{-# LANGUAGE TypeFamilies #-}
|
||||
{-# LANGUAGE MultiParamTypeClasses #-}
|
||||
{-# LANGUAGE FlexibleInstances #-}
|
||||
{-# LANGUAGE FlexibleContexts #-}
|
||||
{-# LANGUAGE PolyKinds #-}
|
||||
|
||||
module Grenade.Layers.Crop (
|
||||
Crop (..)
|
||||
) where
|
||||
|
||||
import Data.Maybe
|
||||
import Data.Proxy
|
||||
import Data.Singletons.TypeLits
|
||||
import GHC.TypeLits
|
||||
|
||||
import Grenade.Core.Network
|
||||
import Grenade.Core.Shape
|
||||
|
||||
import Numeric.LinearAlgebra (konst, subMatrix, diagBlock)
|
||||
import Numeric.LinearAlgebra.Static (extract, create)
|
||||
|
||||
-- | A cropping layer for a neural network.
|
||||
data Crop :: Nat
|
||||
-> Nat
|
||||
-> Nat
|
||||
-> Nat -> * where
|
||||
Crop :: ( KnownNat cropLeft
|
||||
, KnownNat cropTop
|
||||
, KnownNat cropRight
|
||||
, KnownNat cropBottom
|
||||
) => Crop cropLeft cropTop cropRight cropBottom
|
||||
|
||||
instance Show (Crop cropLeft cropTop cropRight cropBottom) where
|
||||
show Crop = "Crop"
|
||||
|
||||
-- | A two dimentional image can be cropped.
|
||||
instance ( Monad m
|
||||
, KnownNat cropLeft
|
||||
, KnownNat cropTop
|
||||
, KnownNat cropRight
|
||||
, KnownNat cropBottom
|
||||
, KnownNat inputRows
|
||||
, KnownNat inputColumns
|
||||
, KnownNat outputRows
|
||||
, KnownNat outputColumns
|
||||
, (inputRows - cropTop - cropBottom) ~ outputRows
|
||||
, (inputColumns - cropLeft - cropRight) ~ outputColumns
|
||||
) => Layer m (Crop cropLeft cropTop cropRight cropBottom) ('D2 inputRows inputColumns) ('D2 outputRows outputColumns) where
|
||||
runForwards Crop (S2D' input) =
|
||||
let cropl = fromIntegral $ natVal (Proxy :: Proxy cropLeft)
|
||||
cropt = fromIntegral $ natVal (Proxy :: Proxy cropTop)
|
||||
nrows = fromIntegral $ natVal (Proxy :: Proxy outputRows)
|
||||
ncols = fromIntegral $ natVal (Proxy :: Proxy outputColumns)
|
||||
m = extract input
|
||||
r = subMatrix (cropt, cropl) (nrows, ncols) m
|
||||
in return . S2D' . fromJust . create $ r
|
||||
runBackards _ crop _ (S2D' dEdy) =
|
||||
let cropl = fromIntegral $ natVal (Proxy :: Proxy cropLeft)
|
||||
cropt = fromIntegral $ natVal (Proxy :: Proxy cropTop)
|
||||
cropr = fromIntegral $ natVal (Proxy :: Proxy cropRight)
|
||||
cropb = fromIntegral $ natVal (Proxy :: Proxy cropBottom)
|
||||
eo = extract dEdy
|
||||
vs = diagBlock [konst 0 (cropt,cropl), eo, konst 0 (cropb,cropr)]
|
||||
in return (crop, S2D' . fromJust . create $ vs)
|
70
src/Grenade/Layers/Pad.hs
Normal file
70
src/Grenade/Layers/Pad.hs
Normal file
@ -0,0 +1,70 @@
|
||||
{-# LANGUAGE BangPatterns #-}
|
||||
{-# LANGUAGE DataKinds #-}
|
||||
{-# LANGUAGE ScopedTypeVariables #-}
|
||||
{-# LANGUAGE StandaloneDeriving #-}
|
||||
{-# LANGUAGE GADTs #-}
|
||||
{-# LANGUAGE TypeOperators #-}
|
||||
{-# LANGUAGE TypeFamilies #-}
|
||||
{-# LANGUAGE MultiParamTypeClasses #-}
|
||||
{-# LANGUAGE FlexibleInstances #-}
|
||||
{-# LANGUAGE FlexibleContexts #-}
|
||||
{-# LANGUAGE PolyKinds #-}
|
||||
|
||||
module Grenade.Layers.Pad (
|
||||
Pad (..)
|
||||
) where
|
||||
|
||||
import Data.Maybe
|
||||
import Data.Proxy
|
||||
import Data.Singletons.TypeLits
|
||||
import GHC.TypeLits
|
||||
|
||||
import Grenade.Core.Network
|
||||
import Grenade.Core.Shape
|
||||
|
||||
import Numeric.LinearAlgebra (konst, subMatrix, diagBlock)
|
||||
import Numeric.LinearAlgebra.Static (extract, create)
|
||||
|
||||
-- | A padding layer for a neural network.
|
||||
data Pad :: Nat
|
||||
-> Nat
|
||||
-> Nat
|
||||
-> Nat -> * where
|
||||
Pad :: ( KnownNat padLeft
|
||||
, KnownNat padTop
|
||||
, KnownNat padRight
|
||||
, KnownNat padBottom
|
||||
) => Pad padLeft padTop padRight padBottom
|
||||
|
||||
instance Show (Pad padLeft padTop padRight padBottom) where
|
||||
show Pad = "Pad"
|
||||
|
||||
-- | A two dimentional image can be padped.
|
||||
instance ( Monad m
|
||||
, KnownNat padLeft
|
||||
, KnownNat padTop
|
||||
, KnownNat padRight
|
||||
, KnownNat padBottom
|
||||
, KnownNat inputRows
|
||||
, KnownNat inputColumns
|
||||
, KnownNat outputRows
|
||||
, KnownNat outputColumns
|
||||
, (inputRows + padTop + padBottom) ~ outputRows
|
||||
, (inputColumns + padLeft + padRight) ~ outputColumns
|
||||
) => Layer m (Pad padLeft padTop padRight padBottom) ('D2 inputRows inputColumns) ('D2 outputRows outputColumns) where
|
||||
runForwards Pad (S2D' input) =
|
||||
let padl = fromIntegral $ natVal (Proxy :: Proxy padLeft)
|
||||
padt = fromIntegral $ natVal (Proxy :: Proxy padTop)
|
||||
padr = fromIntegral $ natVal (Proxy :: Proxy padRight)
|
||||
padb = fromIntegral $ natVal (Proxy :: Proxy padBottom)
|
||||
m = extract input
|
||||
r = diagBlock [konst 0 (padt,padl), m, konst 0 (padb,padr)]
|
||||
in return . S2D' . fromJust . create $ r
|
||||
runBackards _ pad _ (S2D' dEdy) =
|
||||
let padl = fromIntegral $ natVal (Proxy :: Proxy padLeft)
|
||||
padt = fromIntegral $ natVal (Proxy :: Proxy padTop)
|
||||
nrows = fromIntegral $ natVal (Proxy :: Proxy inputRows)
|
||||
ncols = fromIntegral $ natVal (Proxy :: Proxy inputColumns)
|
||||
m = extract dEdy
|
||||
vs = subMatrix (padt, padl) (nrows, ncols) m
|
||||
in return (pad, S2D' . fromJust . create $ vs)
|
Loading…
Reference in New Issue
Block a user