Merge pull request #13 from HuwCampbell/topic/performance
Topic/performance
This commit is contained in:
commit 8b3ca1e0b6
@ -50,6 +50,9 @@ library
Grenade.Layers.Pad
Grenade.Layers.Pooling
Grenade.Layers.Internal.Convolution
Grenade.Layers.Internal.Pooling
executable feedforward
ghc-options: -Wall -threaded -O2

@ -76,7 +79,7 @@ executable mnist
, optparse-applicative == 0.12.*
, text == 1.2.*
, mtl >= 2.2.1 && < 2.3
, hmatrix
, hmatrix >= 0.18 && < 0.19
, transformers
, singletons
, MonadRandom

@ -1 +0,0 @@
Subproject commit 9aade51bd0bb6339cfa8aca014bd96f801d9b19e

@ -8,7 +8,6 @@
import Control.Monad
import Control.Monad.Random
import GHC.TypeLits
import qualified Numeric.LinearAlgebra.Static as SA

@ -51,7 +50,7 @@ netTest rate n = do
where
inCircle :: KnownNat n => SA.R n -> (SA.R n, Double) -> Bool
v `inCircle` (o, r) = SA.norm_2 (v - o) <= r
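A quick sanity check of inCircle (my example, not part of the commit; vec2 builds a static 2-vector in hmatrix):

-- >>> let v = SA.vec2 0.6 0.8
-- >>> v `inCircle` (SA.vec2 0 0, 1.0)
-- True -- SA.norm_2 (v - o) is exactly 1.0, and 1.0 <= 1.0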
trainEach !nt !(i, o) = train rate nt i o
trainEach !network (i,o) = train rate network i o

render n' | n' <= 0.2 = ' '
          | n' <= 0.4 = '.'

@ -9,6 +9,8 @@
import Control.Applicative
import Control.Monad
import Control.Monad.Random
import Control.Monad.Trans.Class
import Control.Monad.Trans.Except
import qualified Data.Attoparsec.Text as A
import qualified Data.Text as T

@ -35,35 +37,23 @@ randomMnist :: MonadRandom m
'[ 'D2 28 28, 'D3 24 24 10, 'D3 12 12 10, 'D3 12 12 10, 'D3 8 8 16, 'D3 4 4 16, 'D1 256, 'D1 256, 'D1 80, 'D1 80, 'D1 10, 'D1 10])
randomMnist = randomNetwork

convTest :: Int -> FilePath -> FilePath -> LearningParameters -> IO ()
convTest :: Int -> FilePath -> FilePath -> LearningParameters -> ExceptT String IO ()
convTest iterations trainFile validateFile rate = do
net0 <- evalRandIO randomMnist
fT <- T.readFile trainFile
fV <- T.readFile validateFile
let trainRows = traverse (A.parseOnly p) (T.lines fT)
let validateRows = traverse (A.parseOnly p) (T.lines fV)
case (trainRows, validateRows) of
(Right tr', Right vr') -> foldM_ (runIteration tr' vr') net0 [1..iterations]
err -> print err
net0 <- lift randomMnist
trainData <- readMNIST trainFile
validateData <- readMNIST validateFile
lift $ foldM_ (runIteration trainData validateData) net0 [1..iterations]

where
trainEach !rate' !nt !(i, o) = train rate' nt i o
where
trainEach rate' !network (i, o) = train rate' network i o

p :: A.Parser (S' ('D2 28 28), S' ('D1 10))
p = do
lab <- A.decimal
pixels <- many (A.char ',' >> A.double)
let lab' = replicate lab 0 ++ [1] ++ replicate (9 - lab) 0
return (S2D' $ SA.fromList pixels, S1D' $ SA.fromList lab')

runIteration trainRows validateRows net i = do
let trained' = foldl (trainEach rate) net trainRows
let res = fmap (\(rowP,rowL) -> (rowL,) $ runNet trained' rowP) validateRows
let res' = fmap (\(S1D' label, S1D' prediction) -> (maxIndex (SA.extract label), maxIndex (SA.extract prediction))) res
print trained'
putStrLn $ "Iteration " ++ show i ++ ": " ++ show (length (filter ((==) <$> fst <*> snd) res')) ++ " of " ++ show (length res')
return trained'
runIteration trainRows validateRows net i = do
let trained' = foldl (trainEach ( rate { learningRate = learningRate rate * 0.9 ^ i} )) net trainRows
let res = fmap (\(rowP,rowL) -> (rowL,) $ runNet trained' rowP) validateRows
let res' = fmap (\(S1D' label, S1D' prediction) -> (maxIndex (SA.extract label), maxIndex (SA.extract prediction))) res
print trained'
putStrLn $ "Iteration " ++ show i ++ ": " ++ show (length (filter ((==) <$> fst <*> snd) res')) ++ " of " ++ show (length res')
return trained'
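Note the only change in the new runIteration: the learning rate now decays geometrically with the iteration number. As a worked example (my numbers, not the commit's): with an initial learningRate of 0.01, iteration 5 trains with 0.01 * 0.9 ^ 5 ≈ 0.0059, so later passes over the data take progressively smaller steps.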
data MnistOpts = MnistOpts FilePath FilePath Int LearningParameters

@ -81,4 +71,20 @@ main :: IO ()
main = do
MnistOpts mnist vali iter rate <- execParser (info (mnist' <**> helper) idm)
putStrLn "Training convolutional neural network..."
convTest iter mnist vali rate

res <- runExceptT $ convTest iter mnist vali rate
case res of
Right () -> pure ()
Left err -> putStrLn err

readMNIST :: FilePath -> ExceptT String IO [(S' ('D2 28 28), S' ('D1 10))]
readMNIST mnist = ExceptT $ do
mnistdata <- T.readFile mnist
return $ traverse (A.parseOnly parseMNIST) (T.lines mnistdata)

parseMNIST :: A.Parser (S' ('D2 28 28), S' ('D1 10))
parseMNIST = do
lab <- A.decimal
pixels <- many (A.char ',' >> A.double)
let lab' = replicate lab 0 ++ [1] ++ replicate (9 - lab) 0
return (S2D' $ SA.fromList pixels, S1D' $ SA.fromList lab')
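The parser one-hot encodes the digit label; a quick check of that expression (my example):

-- replicate lab 0 ++ [1] ++ replicate (9 - lab) 0, with lab = 3:
-- [0,0,0,1,0,0,0,0,0,0] -- ten entries, the 1 at index 3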

@ -51,7 +51,7 @@ class UpdateLayer x => Layer x (i :: Shape) (o :: Shape) where
-- layer gave from the input and the back propagated derivatives from
-- the layer above.
-- Returns the gradient layer and the derivatives to push back further.
runBackards :: x -> S' i -> S' o -> (Gradient x, S' i)
runBackwards :: x -> S' i -> S' o -> (Gradient x, S' i)

-- | Type of a network.
-- The [*] type specifies the types of the layers. This is needed for parallel

@ -16,7 +16,7 @@ import Data.Singletons.Prelude
import Grenade.Core.Network
import Grenade.Core.Shape

-- | Drive and network and collect it's back propogated gradients.
-- | Drive and network and collect its back propogated gradients.
backPropagate :: forall input output shapes layers. (Head shapes ~ input, Last shapes ~ output)
=> Network layers shapes -> S' input -> S' output -> Gradients layers
backPropagate network input target =

@ -29,10 +29,10 @@ backPropagate network input target =
-- handle input from the beginning, feeding upwards.
go !x (layer :~> n)
= let y = runForwards layer x
-- recursively run the rest of the network, and get the layer from above.
-- recursively run the rest of the network, and get the gradients from above.
(n', dWs') = go y n
-- calculate the gradient for this layer to pass down,
(layer', dWs) = runBackards layer x dWs'
(layer', dWs) = runBackwards layer x dWs'

in (layer' :/> n', dWs)

@ -40,7 +40,7 @@ backPropagate network input target =
go !x (O layer)
= let y = runForwards layer x
-- the gradient (how much y affects the error)
(layer', dWs) = runBackards layer x (y - target)
(layer', dWs) = runBackwards layer x (y - target)

in (OG layer', dWs)

@ -16,26 +16,22 @@ module Grenade.Layers.Convolution (
Convolution (..)
, Convolution' (..)
, randomConvolution
, im2col
, vid2col
, col2im
, col2vid
, fittingStarts
) where

import Control.Monad.Random hiding (fromList)
import Control.Monad.Random hiding ( fromList )
import Data.Maybe
import Data.Proxy
import Data.Singletons.TypeLits
import GHC.TypeLits

import Numeric.LinearAlgebra hiding (uniformSample, konst)
import Numeric.LinearAlgebra hiding ( uniformSample, konst )
import qualified Numeric.LinearAlgebra as LA
import Numeric.LinearAlgebra.Static hiding ((|||), build, toRows)

import Grenade.Core.Network
import Grenade.Core.Shape
import Grenade.Core.Vector
import Grenade.Layers.Internal.Convolution

-- | A convolution layer for a neural network.
-- This uses the im2col convolution trick popularised by Caffe, which essentially turns the

@ -159,12 +155,13 @@ instance ( KnownNat kernelRows
sy = fromIntegral $ natVal (Proxy :: Proxy strideCols)
ox = fromIntegral $ natVal (Proxy :: Proxy outputRows)
oy = fromIntegral $ natVal (Proxy :: Proxy outputCols)
c = im2col kx ky sx sy ex
c = im2colUnsafe kx ky sx sy ex
mt = c LA.<> ek
r = col2vid 1 1 1 1 ox oy mt
r = col2vidUnsafe 1 1 1 1 ox oy mt
rs = fmap (fromJust . create) r
in S3D' $ mkVector rs
runBackards (Convolution kernel _) (S2D' input) (S3D' dEdy) =

runBackwards (Convolution kernel _) (S2D' input) (S3D' dEdy) =
let ex = extract input
ix = fromIntegral $ natVal (Proxy :: Proxy inputRows)
iy = fromIntegral $ natVal (Proxy :: Proxy inputCols)

@ -174,17 +171,18 @@ instance ( KnownNat kernelRows
sy = fromIntegral $ natVal (Proxy :: Proxy strideCols)
ox = fromIntegral $ natVal (Proxy :: Proxy outputRows)
oy = fromIntegral $ natVal (Proxy :: Proxy outputCols)
c = im2col kx ky sx sy ex

c = im2colUnsafe kx ky sx sy ex

eo = vecToList $ fmap extract dEdy
ek = extract kernel

vs = vid2col 1 1 1 1 ox oy eo
vs = vid2colUnsafe 1 1 1 1 ox oy eo

kN = fromJust . create $ tr c LA.<> vs
dW = vs LA.<> tr ek

xW = col2im kx ky sx sy ix iy dW
xW = col2imUnsafe kx ky sx sy ix iy dW
in (Convolution' kN, S2D' . fromJust . create $ xW)

@ -215,12 +213,13 @@ instance ( KnownNat kernelRows
sy = fromIntegral $ natVal (Proxy :: Proxy strideCols)
ox = fromIntegral $ natVal (Proxy :: Proxy outputRows)
oy = fromIntegral $ natVal (Proxy :: Proxy outputCols)
c = vid2col kx ky sx sy ix iy ex

c = vid2colUnsafe kx ky sx sy ix iy ex
mt = c LA.<> ek
r = col2vid 1 1 1 1 ox oy mt
r = col2vidUnsafe 1 1 1 1 ox oy mt
rs = fmap (fromJust . create) r
in S3D' $ mkVector rs
runBackards (Convolution kernel _) (S3D' input) (S3D' dEdy) =
runBackwards (Convolution kernel _) (S3D' input) (S3D' dEdy) =
let ex = vecToList $ fmap extract input
ix = fromIntegral $ natVal (Proxy :: Proxy inputRows)
iy = fromIntegral $ natVal (Proxy :: Proxy inputCols)

@ -230,77 +229,17 @@ instance ( KnownNat kernelRows
sy = fromIntegral $ natVal (Proxy :: Proxy strideCols)
ox = fromIntegral $ natVal (Proxy :: Proxy outputRows)
oy = fromIntegral $ natVal (Proxy :: Proxy outputCols)
c = vid2col kx ky sx sy ix iy ex

c = vid2colUnsafe kx ky sx sy ix iy ex

eo = vecToList $ fmap extract dEdy
ek = extract kernel

vs = vid2col 1 1 1 1 ox oy eo
vs = vid2colUnsafe 1 1 1 1 ox oy eo

kN = fromJust . create $ tr c LA.<> vs

dW = vs LA.<> tr ek

xW = col2vid kx ky sx sy ix iy dW
xW = col2vidUnsafe kx ky sx sy ix iy dW
in (Convolution' kN, S3D' . mkVector . fmap (fromJust . create) $ xW)

im2col :: Int -> Int -> Int -> Int -> Matrix Double -> Matrix Double
im2col nrows ncols srows scols m =
let starts = fittingStarts (rows m) nrows srows (cols m) ncols scols
in im2colFit starts nrows ncols m

im2colFit :: [(Int,Int)] -> Int -> Int -> Matrix Double -> Matrix Double
im2colFit starts nrows ncols m =
let imRows = fmap (\start -> flatten $ subMatrix start (nrows, ncols) m) starts
in fromRows imRows

vid2col :: Int -> Int -> Int -> Int -> Int -> Int -> [Matrix Double] -> Matrix Double
vid2col nrows ncols srows scols inputrows inputcols ms =
let starts = fittingStarts inputrows nrows srows inputcols ncols scols
subs = fmap (im2colFit starts nrows ncols) ms
in foldl1 (|||) subs

col2vid :: Int -> Int -> Int -> Int -> Int -> Int -> Matrix Double -> [Matrix Double]
col2vid nrows ncols srows scols drows dcols m =
let starts = fittingStart (cols m) (nrows * ncols) (nrows * ncols)
r = rows m
mats = fmap (\s -> subMatrix (0,s) (r, nrows * ncols) m) starts
colSts = fittingStarts drows nrows srows dcols ncols scols
in fmap (col2imfit colSts nrows ncols drows dcols) mats

col2im :: Int -> Int -> Int -> Int -> Int -> Int -> Matrix Double -> Matrix Double
col2im krows kcols srows scols drows dcols m =
let starts = fittingStarts drows krows srows dcols kcols scols
in col2imfit starts krows kcols drows dcols m

col2imfit :: [(Int,Int)] -> Int -> Int -> Int -> Int -> Matrix Double -> Matrix Double
col2imfit starts krows kcols drows dcols m =
let indicies = fmap (\[a,b] -> (a,b)) $ sequence [[0..(krows-1)], [0..(kcols-1)]]
convs = fmap (zip indicies . toList) . toRows $ m
pairs = zip convs starts
accums = concat $ fmap (\(conv',(stx',sty')) -> fmap (\((ix,iy), val) -> ((ix + stx', iy + sty'), val)) conv') pairs
in accum (LA.konst 0 (drows, dcols)) (+) accums

-- | These functions are not even remotely safe, but it's only called from the statically typed
-- commands, so we should be good ?!?!?
-- Returns the starting sub matrix locations which fit inside the larger matrix for the
-- convolution. Takes into account the stride and kernel size.
fittingStarts :: Int -> Int -> Int -> Int -> Int -> Int -> [(Int,Int)]
fittingStarts nrows kernelrows steprows ncols kernelcols stepcolsh =
let rs = fittingStart nrows kernelrows steprows
cs = fittingStart ncols kernelcols stepcolsh
ls = sequence [rs, cs]
in fmap (\[a,b] -> (a,b)) ls

-- | Returns the starting sub vector which fit inside the larger vector for the
-- convolution. Takes into account the stride and kernel size.
fittingStart :: Int -> Int -> Int -> [Int]
fittingStart width kernel steps =
let go left | left + kernel < width
= left : go (left + steps)
| left + kernel == width
= left : []
| otherwise
= error "Kernel and step do not fit in matrix."
in go 0
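To make the im2col trick concrete, a small example against the functions above (my numbers, as a hypothetical GHCi session):

-- >>> im2col 2 2 1 1 ((3><4) [1..12])
-- a (6><4) matrix: one flattened 2x2 patch per row, starting with
-- [ 1.0, 2.0, 5.0, 6.0 ] for the patch whose top-left corner is (0,0).
-- Convolution then reduces to a single dense multiply of this matrix with the kernel matrix.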

@ -58,7 +58,7 @@ instance ( KnownNat cropLeft
m = extract input
r = subMatrix (cropt, cropl) (nrows, ncols) m
in S2D' . fromJust . create $ r
runBackards _ _ (S2D' dEdy) =
runBackwards _ _ (S2D' dEdy) =
let cropl = fromIntegral $ natVal (Proxy :: Proxy cropLeft)
cropt = fromIntegral $ natVal (Proxy :: Proxy cropTop)
cropr = fromIntegral $ natVal (Proxy :: Proxy cropRight)

@ -47,5 +47,5 @@ randomDropout rate = do
instance (KnownNat i) => Layer (Dropout i) ('D1 i) ('D1 i) where
runForwards (Dropout drops) (S1D' x) = S1D' $ x * drops
runForwards (Pass rate) (S1D' x)= S1D' $ dvmap (* (1 - rate)) x
runBackards (Dropout drops) _ (S1D' x) = ((), S1D' $ x * drops)
runBackards (Pass rate) _ (S1D' x) = ((), S1D' $ dvmap (* (1 - rate)) x)
runBackwards (Dropout drops) _ (S1D' x) = ((), S1D' $ x * drops)
runBackwards (Pass rate) _ (S1D' x) = ((), S1D' $ dvmap (* (1 - rate)) x)

@ -33,11 +33,11 @@ instance UpdateLayer FlattenLayer where

instance (KnownNat a, KnownNat x, KnownNat y, a ~ (x * y)) => Layer FlattenLayer ('D2 x y) ('D1 a) where
runForwards _ (S2D' y) = S1D' . fromList . toList . flatten . extract $ y
runBackards _ _ (S1D' y) = ((), S2D' . fromList . toList . unwrap $ y)
runBackwards _ _ (S1D' y) = ((), S2D' . fromList . toList . unwrap $ y)

instance (KnownNat a, KnownNat x, KnownNat y, KnownNat z, a ~ (x * y * z)) => Layer FlattenLayer ('D3 x y z) ('D1 a) where
runForwards _ (S3D' y) = S1D' . raiseShapeError . create . vjoin . vecToList . fmap (flatten . extract) $ y
runBackards _ _ (S1D' o) =
runBackwards _ _ (S1D' o) =
let x' = fromIntegral $ natVal (Proxy :: Proxy x)
y' = fromIntegral $ natVal (Proxy :: Proxy y)
z' = fromIntegral $ natVal (Proxy :: Proxy z)

@ -52,7 +52,7 @@ instance (KnownNat i, KnownNat o) => Layer (FullyConnected i o) ('D1 i) ('D1 o)
runForwards (FullyConnected wB _ wN _) (S1D' v) = S1D' (wB + wN #> v)

-- Run a backpropogation step for a full connected layer.
runBackards (FullyConnected _ _ wN _) (S1D' x) (S1D' dEdy) =
runBackwards (FullyConnected _ _ wN _) (S1D' x) (S1D' dEdy) =
let wB' = dEdy
mm' = dEdy `outer` x
-- calcluate derivatives for next step

@ -45,8 +45,8 @@ instance (Layer x i h, Layer y h o) => Layer (Fuse x y i h o) i o where
let yInput :: S' h = runForwards x input
in runForwards y yInput

runBackards (x :$$ y) input backGradient =
runBackwards (x :$$ y) input backGradient =
let yInput :: S' h = runForwards x input
(y', yGrad) = runBackards y yInput backGradient
(x', xGrad) = runBackards x input yGrad
(y', yGrad) = runBackwards y yInput backGradient
(x', xGrad) = runBackwards x input yGrad
in ((x', y'), xGrad)
src/Grenade/Layers/Internal/Convolution.hs (new file, 190 lines)
@ -0,0 +1,190 @@
module Grenade.Layers.Internal.Convolution (
    col2vidUnsafe
  , col2imUnsafe
  , vid2colUnsafe
  , im2colUnsafe
  , fittingStarts
  ) where

import Control.Monad.ST ( runST )

import Data.STRef ( newSTRef, modifySTRef, writeSTRef, readSTRef )
import Data.Foldable ( forM_ )
import Data.Traversable ( forM )

import Numeric.LinearAlgebra hiding ( uniformSample, konst )
import qualified Numeric.LinearAlgebra.Devel as U

-- This module provides the im2col function and friends, ala Caffe.
--
-- /* From Caffe */
-- @
-- void col2im_cpu(const Dtype* data_col, const int channels,
--     const int height, const int width, const int kernel_h, const int kernel_w,
--     const int pad_h, const int pad_w,
--     const int stride_h, const int stride_w,
--     const int dilation_h, const int dilation_w,
--     Dtype* data_im) {
--   caffe_set(height * width * channels, Dtype(0), data_im);
--   const int output_h = (height + 2 * pad_h -
--     (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1;
--   const int output_w = (width + 2 * pad_w -
--     (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1;
--   const int channel_size = height * width;
--   for (int channel = channels; channel--; data_im += channel_size) {
--     for (int kernel_row = 0; kernel_row < kernel_h; kernel_row++) {
--       for (int kernel_col = 0; kernel_col < kernel_w; kernel_col++) {
--         int input_row = -pad_h + kernel_row * dilation_h;
--         for (int output_rows = output_h; output_rows; output_rows--) {
--           if (!is_a_ge_zero_and_a_lt_b(input_row, height)) {
--             data_col += output_w;
--           } else {
--             int input_col = -pad_w + kernel_col * dilation_w;
--             for (int output_col = output_w; output_col; output_col--) {
--               if (is_a_ge_zero_and_a_lt_b(input_col, width)) {
--                 data_im[input_row * width + input_col] += *data_col;
--               }
--               data_col++;
--               input_col += stride_w;
--             }
--           }
--           input_row += stride_h;
--         }
--       }
--     }
--   }
-- }
-- @
--

-- | col2im function.
--
-- Takes a column patch, and reconstitutes it into a normal image.
-- Does not do any bounds checking on the matrix, so should only
-- be called once the sizes are ensured correct.
col2imUnsafe :: Int -> Int -> Int -> Int -> Int -> Int -> Matrix Double -> Matrix Double
col2imUnsafe kernelRows kernelColumns strideRows strideColumns destinationRows destinationCols columnMatrix = U.runSTMatrix $ do
  let columnMatrixRows = rows columnMatrix

  dataIm <- U.newMatrix 0 destinationRows destinationCols
  offsetR <- newSTRef 0
  offsetC <- newSTRef 0

  forM_ [0 .. columnMatrixRows - 1] $ \inputRow -> do
    inputColumnRef <- newSTRef 0
    forM_ [0 .. kernelRows -1] $ \kr ->
      forM_ [0 .. kernelColumns -1] $ \kc -> do
        inputColumn <- readSTRef inputColumnRef
        offsetR' <- readSTRef offsetR
        offsetC' <- readSTRef offsetC
        U.modifyMatrix dataIm (kr + offsetR') (kc + offsetC') (+ U.atM' columnMatrix inputRow inputColumn)
        modifySTRef inputColumnRef (+1)

    offsetC' <- readSTRef offsetC
    if offsetC' + kernelColumns < destinationCols
      then modifySTRef offsetC (+ strideColumns)
      else writeSTRef offsetC 0 >> modifySTRef offsetR (+ strideRows)

  return dataIm

-- | col2vid function.
--
-- Takes a column patch image, and reconstitutes it into a normal image with multiple channels.
-- Does not do any bounds checking on the matrix, so should only
-- be called once the sizes are ensured correct.
col2vidUnsafe :: Int -> Int -> Int -> Int -> Int -> Int -> Matrix Double -> [Matrix Double]
col2vidUnsafe kernelRows kernelColumns strideRows strideColumns destinationRows destinationCols columnMatrix = runST $ do
  let columnMatrixRows = rows columnMatrix
  let filters = cols columnMatrix `div` (kernelRows * kernelColumns)

  forM [0 .. filters - 1] $ \iter -> do
    let offsetM = iter * (kernelRows * kernelColumns)
    dataIm <- U.newMatrix 0 destinationRows destinationCols
    offsetR <- newSTRef 0
    offsetC <- newSTRef 0
    forM_ [0 .. columnMatrixRows - 1] $ \ir -> do
      inputColumn <- newSTRef 0
      forM_ [0 .. kernelRows -1] $ \kr ->
        forM_ [0 .. kernelColumns -1] $ \kc -> do
          ic <- readSTRef inputColumn
          offsetR' <- readSTRef offsetR
          offsetC' <- readSTRef offsetC
          U.modifyMatrix dataIm (kr + offsetR') (kc + offsetC') (+ U.atM' columnMatrix ir (ic + offsetM))
          modifySTRef inputColumn (+1)

      offsetC' <- readSTRef offsetC
      if offsetC' + kernelColumns < destinationCols
        then modifySTRef offsetC (+ strideColumns)
        else writeSTRef offsetC 0 >> modifySTRef offsetR (+ strideRows)

    U.unsafeFreezeMatrix dataIm

-- | vid2col function: builds the column patch matrix for a multi-channel
-- image, one flattened patch per row with the channels laid side by side.
-- No bounds checking is done.
vid2colUnsafe :: Int -> Int -> Int -> Int -> Int -> Int -> [Matrix Double] -> Matrix Double
vid2colUnsafe kernelRows kernelColumns striderows stridecols vidrows vidcols dataVid = U.runSTMatrix $ do
  let starts = fittingStarts vidrows kernelRows striderows vidcols kernelColumns stridecols
      kernelSize = kernelRows * kernelColumns
      numberOfPatches = length starts
      channels = length dataVid

  dataCol <- U.newMatrix 0 numberOfPatches (channels * kernelSize)

  offsetC <- newSTRef 0

  forM_ dataVid $ \dataIm -> do
    inputRowRef <- newSTRef 0
    offsetC' <- readSTRef offsetC
    forM_ starts $ \(startRow, startCol) -> do
      inputColumnRef <- newSTRef 0
      inputRow <- readSTRef inputRowRef
      forM_ [0 .. kernelRows -1] $ \kr ->
        forM_ [0 .. kernelColumns -1] $ \kc -> do
          inputColumn <- readSTRef inputColumnRef
          U.modifyMatrix dataCol inputRow (inputColumn + offsetC') (+ U.atM' dataIm (kr + startRow) (kc + startCol))
          modifySTRef inputColumnRef (+1)
      modifySTRef inputRowRef (+1)

    modifySTRef offsetC (+ kernelSize)

  return dataCol

-- | im2col function: turns an image into a column patch matrix, with one
-- flattened kernel-sized patch per row. No bounds checking is done.
im2colUnsafe :: Int -> Int -> Int -> Int -> Matrix Double -> Matrix Double
im2colUnsafe kernelRows kernelColumns striderows stridecols dataIm = U.runSTMatrix $ do
  let starts = fittingStarts (rows dataIm) kernelRows striderows (cols dataIm) kernelColumns stridecols
      kernelSize = kernelRows * kernelColumns
      numberOfPatches = length starts

  dataCol <- U.newMatrix 0 numberOfPatches kernelSize

  inputRowRef <- newSTRef 0
  forM_ starts $ \(startRow, startCol) -> do
    inputColumnRef <- newSTRef 0
    inputRow <- readSTRef inputRowRef
    forM_ [0 .. kernelRows -1] $ \kr ->
      forM_ [0 .. kernelColumns -1] $ \kc -> do
        inputColumn <- readSTRef inputColumnRef
        U.modifyMatrix dataCol inputRow inputColumn (+ U.atM' dataIm (kr + startRow) (kc + startCol))
        modifySTRef inputColumnRef (+1)
    modifySTRef inputRowRef (+1)

  return dataCol

-- | Returns the starting sub matrix locations which fit inside the larger matrix for the
-- convolution. Takes into account the stride and kernel size.
fittingStarts :: Int -> Int -> Int -> Int -> Int -> Int -> [(Int,Int)]
fittingStarts nrows kernelrows steprows ncols kernelcols stepcolsh =
  let rs = fittingStart nrows kernelrows steprows
      cs = fittingStart ncols kernelcols stepcolsh
  in concatMap ( \r -> fmap (\c -> (r , c)) cs ) rs

-- | Returns the starting sub vector which fit inside the larger vector for the
-- convolution. Takes into account the stride and kernel size.
fittingStart :: Int -> Int -> Int -> [Int]
fittingStart width kernel steps =
  let go left | left + kernel < width
              = left : go (left + steps)
              | left + kernel == width
              = [left]
              | otherwise
              = []
  in go 0
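A quick check of fittingStarts (my example): for a 3><4 matrix with a 2><2 kernel and stride 1 in both dimensions,

-- >>> fittingStarts 3 2 1 4 2 1
-- [(0,0),(0,1),(0,2),(1,0),(1,1),(1,2)]
-- six valid top-left corners, one per patch row produced by im2colUnsafe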
src/Grenade/Layers/Internal/Pooling.hs (new file, 63 lines)
@ -0,0 +1,63 @@
module Grenade.Layers.Internal.Pooling (
    poolForward
  , poolBackward
  , poolForwardList
  , poolBackwardList
  ) where

import Data.Function ( on )
import Data.List ( maximumBy )

import Numeric.LinearAlgebra hiding ( uniformSample, konst )
import qualified Numeric.LinearAlgebra as LA
import qualified Numeric.LinearAlgebra.Devel as U

import Grenade.Layers.Internal.Convolution

poolForward :: Int -> Int -> Int -> Int -> Int -> Int -> Matrix Double -> Matrix Double
poolForward nrows ncols srows scols outputRows outputCols m =
  let starts = fittingStarts (rows m) nrows srows (cols m) ncols scols
  in poolForwardFit starts nrows ncols outputRows outputCols m

poolForwardList :: Functor f => Int -> Int -> Int -> Int -> Int -> Int -> Int -> Int -> f (Matrix Double) -> f (Matrix Double)
poolForwardList nrows ncols srows scols inRows inCols outputRows outputCols ms =
  let starts = fittingStarts inRows nrows srows inCols ncols scols
  in poolForwardFit starts nrows ncols outputRows outputCols <$> ms

poolForwardFit :: [(Int,Int)] -> Int -> Int -> Int -> Int -> Matrix Double -> Matrix Double
poolForwardFit starts nrows ncols _ outputCols m =
  let els = fmap (\start -> unsafeMaxElement $ subMatrix start (nrows, ncols) m) starts
  in LA.matrix outputCols els

poolBackward :: Int -> Int -> Int -> Int -> Matrix Double -> Matrix Double -> Matrix Double
poolBackward krows kcols srows scols inputMatrix gradientMatrix =
  let inRows = rows inputMatrix
      inCols = cols inputMatrix
      starts = fittingStarts inRows krows srows inCols kcols scols
  in poolBackwardFit starts krows kcols inputMatrix gradientMatrix

poolBackwardList :: Functor f => Int -> Int -> Int -> Int -> Int -> Int -> f (Matrix Double, Matrix Double) -> f (Matrix Double)
poolBackwardList krows kcols srows scols inRows inCols inputMatrices =
  let starts = fittingStarts inRows krows srows inCols kcols scols
  in uncurry (poolBackwardFit starts krows kcols) <$> inputMatrices

poolBackwardFit :: [(Int,Int)] -> Int -> Int -> Matrix Double -> Matrix Double -> Matrix Double
poolBackwardFit starts krows kcols inputMatrix gradientMatrix =
  let inRows = rows inputMatrix
      inCols = cols inputMatrix
      inds = fmap (\start -> unsafeMaxIndex $ subMatrix start (krows, kcols) inputMatrix) starts
      grads = toList $ flatten gradientMatrix
      grads' = zip3 starts grads inds
      accums = fmap (\((stx',sty'),grad,(inx, iny)) -> ((stx' + inx, sty' + iny), grad)) grads'
  in accum (LA.konst 0 (inRows, inCols)) (+) accums

-- | Maximum element of a matrix, using unchecked indexing.
unsafeMaxElement :: Matrix Double -> Double
unsafeMaxElement m = uncurry (U.atM' m) $ unsafeMaxIndex m

-- | Index of the maximum element, using unchecked indexing.
unsafeMaxIndex :: Matrix Double -> (Int, Int)
unsafeMaxIndex m =
  let mrows = [0 .. rows m - 1]
      mcols = [0 .. cols m - 1]
      pairs = concatMap ( \r -> fmap (\c -> (r , c)) mcols ) mrows
  in maximumBy (compare `on` uncurry (U.atM' m)) pairs
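The pooling behaviour in one small example (my numbers; it agrees with the original prop_pool expectation shown in the test diff further down):

-- >>> poolForward 2 2 1 1 2 3 ((3><4) [1..12])
-- (2><3) [  6.0,  7.0,  8.0
--        , 10.0, 11.0, 12.0 ]
-- each output cell is the maximum of one 2><2 window of the input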

@ -29,15 +29,15 @@ instance UpdateLayer Logit where

instance (KnownNat i) => Layer Logit ('D1 i) ('D1 i) where
runForwards _ (S1D' y) = S1D' (logistic y)
runBackards _ (S1D' y) (S1D' dEdy) = ((), S1D' (logistic' y * dEdy))
runBackwards _ (S1D' y) (S1D' dEdy) = ((), S1D' (logistic' y * dEdy))

instance (KnownNat i, KnownNat j) => Layer Logit ('D2 i j) ('D2 i j) where
runForwards _ (S2D' y) = S2D' (logistic y)
runBackards _ (S2D' y) (S2D' dEdy) = ((), S2D' (logistic' y * dEdy))
runBackwards _ (S2D' y) (S2D' dEdy) = ((), S2D' (logistic' y * dEdy))

instance (KnownNat i, KnownNat j, KnownNat k) => Layer Logit ('D3 i j k) ('D3 i j k) where
runForwards _ (S3D' y) = S3D' (fmap logistic y)
runBackards _ (S3D' y) (S3D' dEdy) = ((), S3D' (vectorZip (\y' dEdy' -> logistic' y' * dEdy') y dEdy))
runBackwards _ (S3D' y) (S3D' dEdy) = ((), S3D' (vectorZip (\y' dEdy' -> logistic' y' * dEdy') y dEdy))

logistic :: Floating a => a -> a

@ -58,7 +58,7 @@ instance ( KnownNat padLeft
m = extract input
r = diagBlock [konst 0 (padt,padl), m, konst 0 (padb,padr)]
in S2D' . fromJust . create $ r
runBackards Pad _ (S2D' dEdy) =
runBackwards Pad _ (S2D' dEdy) =
let padl = fromIntegral $ natVal (Proxy :: Proxy padLeft)
padt = fromIntegral $ natVal (Proxy :: Proxy padTop)
nrows = fromIntegral $ natVal (Proxy :: Proxy inputRows)

@ -12,8 +12,6 @@

module Grenade.Layers.Pooling (
Pooling (..)
, poolForward
, poolBackward
) where

import Data.Maybe

@ -24,10 +22,8 @@ import GHC.TypeLits
import Grenade.Core.Network
import Grenade.Core.Shape
import Grenade.Core.Vector
import Grenade.Layers.Convolution
import Grenade.Layers.Internal.Pooling

import Numeric.LinearAlgebra hiding (uniformSample)
import qualified Numeric.LinearAlgebra as LA
import Numeric.LinearAlgebra.Static as LAS hiding ((|||), build, toRows)

-- | A pooling layer for a neural network.

@ -37,16 +33,12 @@ import Numeric.LinearAlgebra.Static as LAS hiding ((|||), build, toRow
-- The kernel size dictates which input and output sizes will "fit". Fitting the equation:
-- `out = (in - kernel) / stride + 1` for both dimensions.
--
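As a worked instance of that equation (my numbers): a 3><4 input with a 2><2 kernel and stride 1 fits, since out = (3 - 2) / 1 + 1 = 2 rows and (4 - 2) / 1 + 1 = 3 columns, giving a 2><3 output.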
data Pooling :: Nat
-> Nat
-> Nat
-> Nat -> * where
data Pooling :: Nat -> Nat -> Nat -> Nat -> * where
Pooling :: Pooling kernelRows kernelColumns strideRows strideColumns

instance Show (Pooling k k' s s') where
show Pooling = "Pooling"

instance UpdateLayer (Pooling kernelRows kernelColumns strideRows strideColumns) where
type Gradient (Pooling kr kc sr sc) = ()
runUpdate _ Pooling _ = Pooling

@ -75,7 +67,7 @@ instance ( KnownNat kernelRows
r = poolForward kx ky sx sy ox oy $ ex
rs = fromJust . create $ r
in S2D' $ rs
runBackards Pooling (S2D' input) (S2D' dEdy) =
runBackwards Pooling (S2D' input) (S2D' dEdy) =
let kx = fromIntegral $ natVal (Proxy :: Proxy kernelRows)
ky = fromIntegral $ natVal (Proxy :: Proxy kernelColumns)
sx = fromIntegral $ natVal (Proxy :: Proxy strideRows)

@ -111,7 +103,7 @@ instance ( KnownNat kernelRows
r = poolForwardList kx ky sx sy ix iy ox oy ex
rs = fmap (fromJust . create) r
in S3D' rs
runBackards Pooling (S3D' input) (S3D' dEdy) =
runBackwards Pooling (S3D' input) (S3D' dEdy) =
let ix = fromIntegral $ natVal (Proxy :: Proxy inputRows)
iy = fromIntegral $ natVal (Proxy :: Proxy inputColumns)
kx = fromIntegral $ natVal (Proxy :: Proxy kernelRows)

@ -123,40 +115,3 @@ instance ( KnownNat kernelRows
ez = vectorZip (,) ex eo
vs = poolBackwardList kx ky sx sy ix iy ez
in ((), S3D' . fmap (fromJust . create) $ vs)

poolForward :: Int -> Int -> Int -> Int -> Int -> Int -> Matrix Double -> Matrix Double
poolForward nrows ncols srows scols outputRows outputCols m =
let starts = fittingStarts (rows m) nrows srows (cols m) ncols scols
in poolForwardFit starts nrows ncols outputRows outputCols m

poolForwardList :: Functor f => Int -> Int -> Int -> Int -> Int -> Int -> Int -> Int -> f (Matrix Double) -> f (Matrix Double)
poolForwardList nrows ncols srows scols inRows inCols outputRows outputCols ms =
let starts = fittingStarts inRows nrows srows inCols ncols scols
in poolForwardFit starts nrows ncols outputRows outputCols <$> ms

poolForwardFit :: [(Int,Int)] -> Int -> Int -> Int -> Int -> Matrix Double -> Matrix Double
poolForwardFit starts nrows ncols _ outputCols m =
let els = fmap (\start -> maxElement $ subMatrix start (nrows, ncols) m) starts
in LA.matrix outputCols els

poolBackward :: Int -> Int -> Int -> Int -> Matrix Double -> Matrix Double -> Matrix Double
poolBackward krows kcols srows scols inputMatrix gradientMatrix =
let inRows = (rows inputMatrix)
inCols = (cols inputMatrix)
starts = fittingStarts inRows krows srows inCols kcols scols
in poolBackwardFit starts krows kcols inputMatrix gradientMatrix

poolBackwardList :: Functor f => Int -> Int -> Int -> Int -> Int -> Int -> f (Matrix Double, Matrix Double) -> f (Matrix Double)
poolBackwardList krows kcols srows scols inRows inCols inputMatrices =
let starts = fittingStarts inRows krows srows inCols kcols scols
in (uncurry $ poolBackwardFit starts krows kcols) <$> inputMatrices

poolBackwardFit :: [(Int,Int)] -> Int -> Int -> Matrix Double -> Matrix Double -> Matrix Double
poolBackwardFit starts krows kcols inputMatrix gradientMatrix =
let inRows = (rows inputMatrix)
inCols = (cols inputMatrix)
inds = fmap (\start -> maxIndex $ subMatrix start (krows, kcols) inputMatrix) starts
grads = toList $ flatten gradientMatrix
grads' = zip3 starts grads inds
accums = fmap (\((stx',sty'),grad,(inx, iny)) -> ((stx' + inx, sty' + iny), grad)) grads'
in accum (LA.konst 0 (inRows, inCols)) (+) accums
@ -31,7 +31,7 @@ instance ( KnownNat i) => Layer Relu ('D1 i) ('D1 i) where
runForwards _ (S1D' y) = S1D' (relu y)
where
relu = LAS.dvmap (\a -> if a <= 0 then 0 else a)
runBackards _ (S1D' y) (S1D' dEdy) = ((), S1D' (relu' y * dEdy))
runBackwards _ (S1D' y) (S1D' dEdy) = ((), S1D' (relu' y * dEdy))
where
relu' = LAS.dvmap (\a -> if a <= 0 then 0 else 1)

@ -39,7 +39,7 @@ instance (KnownNat i, KnownNat j) => Layer Relu ('D2 i j) ('D2 i j) where
runForwards _ (S2D' y) = S2D' (relu y)
where
relu = LAS.dmmap (\a -> if a <= 0 then 0 else a)
runBackards _ (S2D' y) (S2D' dEdy) = ((), S2D' (relu' y * dEdy))
runBackwards _ (S2D' y) (S2D' dEdy) = ((), S2D' (relu' y * dEdy))
where
relu' = LAS.dmmap (\a -> if a <= 0 then 0 else 1)

@ -47,6 +47,6 @@ instance (KnownNat i, KnownNat j, KnownNat k) => Layer Relu ('D3 i j k) ('D3 i j
runForwards _ (S3D' y) = S3D' (fmap relu y)
where
relu = LAS.dmmap (\a -> if a <= 0 then 0 else a)
runBackards _ (S3D' y) (S3D' dEdy) = ((), S3D' (vectorZip (\y' dEdy' -> relu' y' * dEdy') y dEdy))
runBackwards _ (S3D' y) (S3D' dEdy) = ((), S3D' (vectorZip (\y' dEdy' -> relu' y' * dEdy') y dEdy))
where
relu' = LAS.dmmap (\a -> if a <= 0 then 0 else 1)

@ -26,15 +26,15 @@ instance UpdateLayer Tanh where

instance KnownNat i => Layer Tanh ('D1 i) ('D1 i) where
runForwards _ (S1D' y) = S1D' (tanh y)
runBackards _ (S1D' y) (S1D' dEdy) = ((), S1D' (tanh' y * dEdy))
runBackwards _ (S1D' y) (S1D' dEdy) = ((), S1D' (tanh' y * dEdy))

instance (KnownNat i, KnownNat j) => Layer Tanh ('D2 i j) ('D2 i j) where
runForwards _ (S2D' y) = S2D' (tanh y)
runBackards _ (S2D' y) (S2D' dEdy) = ((), S2D' (tanh' y * dEdy))
runBackwards _ (S2D' y) (S2D' dEdy) = ((), S2D' (tanh' y * dEdy))

instance (KnownNat i, KnownNat j, KnownNat k) => Layer Tanh ('D3 i j k) ('D3 i j k) where
runForwards _ (S3D' y) = S3D' (fmap tanh y)
runBackards _ (S3D' y) (S3D' dEdy) = ((), S3D' (vectorZip (\y' dEdy' -> tanh' y' * dEdy') y dEdy))
runBackwards _ (S3D' y) (S3D' dEdy) = ((), S3D' (vectorZip (\y' dEdy' -> tanh' y' * dEdy') y dEdy))

tanh' :: (Floating a) => a -> a
tanh' t = 1 - s ^ (2 :: Int) where s = tanh t
@ -8,6 +8,7 @@ import Grenade.Core.Shape
import Grenade.Core.Vector as Grenade
import Grenade.Core.Network
import Grenade.Layers.Convolution
import Grenade.Layers.Internal.Convolution

import Numeric.LinearAlgebra hiding (uniformSample, konst, (===))
import qualified Numeric.LinearAlgebra.Static as HStatic

@ -26,7 +27,7 @@ prop_im2col_no_stride = once $
, 5.0, 6.0, 9.0, 10.0
, 6.0, 7.0, 10.0, 11.0
, 7.0, 8.0, 11.0, 12.0 ]
out = im2col 2 2 1 1 input
out = im2colUnsafe 2 2 1 1 input
in expected === out

prop_im2col_stride = once $

@ -39,7 +40,7 @@ prop_im2col_stride = once $
, 3.0, 4.0, 7.0, 8.0
, 5.0, 6.0, 9.0, 10.0
, 7.0, 8.0, 11.0, 12.0 ]
out = im2col 2 2 1 2 input
out = im2colUnsafe 2 2 1 2 input
in expected === out

prop_im2col_other = once $

@ -50,7 +51,7 @@ prop_im2col_other = once $
expected = (2><6)
[ 1.0, 2.0, 5.0, 6.0 , 9.0, 10.0
, 3.0, 4.0, 7.0, 8.0 , 11.0 ,12.0 ]
out = im2col 3 2 1 2 input
out = im2colUnsafe 3 2 1 2 input
in expected === out

-- If there's no overlap (stride is the same size as the kernel)

@ -60,9 +61,20 @@ prop_im2col_sym_on_same_stride = once $
[ 1.0, 2.0, 3.0, 4.0
, 5.0, 6.0, 7.0, 8.0
, 9.0, 10.0, 11.0, 12.0 ]
out = col2im 3 2 3 2 3 4 . im2col 3 2 3 2 $ input
out = col2imUnsafe 3 2 3 2 3 4 . im2colUnsafe 3 2 3 2 $ input
in input === out

-- If there's no overlap (stride is the same size as the kernel)
-- then col2im . im2col should be symmetric.
prop_im2colunsafe_sym_on_same_stride = once $
let input = (3><4)
[ 1.0, 2.0, 3.0, 4.0
, 5.0, 6.0, 7.0, 8.0
, 9.0, 10.0, 11.0, 12.0 ]
out = col2imUnsafe 3 2 3 2 3 4 . im2colUnsafe 3 2 3 2 $ input
in input === out

-- If there is an overlap, then the gradient passed back should be
-- the sum of the gradients across the filters.
prop_im2col_col2im_additive = once $

@ -74,29 +86,29 @@ prop_im2col_col2im_additive = once $
[ 1.0, 2.0, 2.0, 1.0
, 2.0, 4.0, 4.0, 2.0
, 1.0, 2.0, 2.0, 1.0 ]
out = col2im 2 2 1 1 3 4 . im2col 2 2 1 1 $ input
out = col2imUnsafe 2 2 1 1 3 4 . im2colUnsafe 2 2 1 1 $ input
in expected === out

prop_simple_conv_forwards = once $
-- Create a convolution kernel with 4 filters.
-- [ 1, 0 [ 0, 1 [ 0, 1 [ 0, 0
-- , 0,-1 ] ,-1, 0 ] , 1, 0 ] ,-1,-1 ]
let myKernel = (HStatic.matrix
let myKernel = HStatic.matrix
[ 1.0, 0.0, 0.0, 0.0
, 0.0, 1.0, 1.0, 0.0
, 0.0, -1.0, 1.0, -1.0
,-1.0, 0.0, 0.0, -1.0 ] :: HStatic.L 4 4)
zeroKernel = (HStatic.matrix
,-1.0, 0.0, 0.0, -1.0 ] :: HStatic.L 4 4
zeroKernel = HStatic.matrix
[ 0.0, 0.0, 0.0, 0.0
, 0.0, 0.0, 0.0, 0.0
, 0.0, 0.0, 0.0, 0.0
, 0.0, 0.0, 0.0, 0.0 ] :: HStatic.L 4 4)
, 0.0, 0.0, 0.0, 0.0 ] :: HStatic.L 4 4

expectedGradient = (HStatic.matrix
expectedGradient = HStatic.matrix
[ 1.0, 0.0, 0.0, 2.0
, 2.0, 0.0, 0.0, 5.0
, 3.0, 0.0, 0.0, 4.0
, 4.0, 0.0, 0.0, 6.0 ] :: HStatic.L 4 4)
, 4.0, 0.0, 0.0, 6.0 ] :: HStatic.L 4 4

convLayer = Convolution myKernel zeroKernel :: Convolution 1 4 2 2 1 1

@ -104,39 +116,36 @@ prop_simple_conv_forwards = once $
[ 1.0, 2.0, 5.0
, 3.0, 4.0, 6.0] :: HStatic.L 2 3)

expect = ([(HStatic.matrix
[ -3.0 , -4.0 ] :: HStatic.L 1 2)
,(HStatic.matrix
[ -1.0 , 1.0 ] :: HStatic.L 1 2)
,(HStatic.matrix
[ 5.0 , 9.0 ] :: HStatic.L 1 2)
,(HStatic.matrix
[ -7.0 , -10.0 ] :: HStatic.L 1 2)]) :: [HStatic.L 1 2]
expect = [ HStatic.matrix
[ -3.0 , -4.0 ] :: HStatic.L 1 2
, HStatic.matrix
[ -1.0 , 1.0 ] :: HStatic.L 1 2
, HStatic.matrix
[ 5.0 , 9.0 ] :: HStatic.L 1 2
, HStatic.matrix
[ -7.0 , -10.0 ] :: HStatic.L 1 2] :: [HStatic.L 1 2]
out = runForwards convLayer input :: S' ('D3 1 2 4)

grad = S3D' ( mkVector
[(HStatic.matrix
[ 1 , 0 ] :: HStatic.L 1 2)
,(HStatic.matrix
[ 0 , 0 ] :: HStatic.L 1 2)
,(HStatic.matrix
[ 0 , 0 ] :: HStatic.L 1 2)
,(HStatic.matrix
[ 0 , 1 ] :: HStatic.L 1 2)] ) :: S' ('D3 1 2 4)
[ HStatic.matrix
[ 1 , 0 ] :: HStatic.L 1 2
, HStatic.matrix
[ 0 , 0 ] :: HStatic.L 1 2
, HStatic.matrix
[ 0 , 0 ] :: HStatic.L 1 2
, HStatic.matrix
[ 0 , 1 ] :: HStatic.L 1 2] ) :: S' ('D3 1 2 4)

expectBack = (HStatic.matrix
[ 1.0, 0.0, 0.0
, 0.0, -2.0,-1.0] :: HStatic.L 2 3)
(nc, inX) = runBackards convLayer input grad
(nc, inX) = runBackwards convLayer input grad

in case (out, inX, nc) of
(S3D' out' , S2D' inX', Convolution' backGrad)
-> ((HStatic.extract <$> expect) === (HStatic.extract <$> vecToList out'))
.&&. ((HStatic.extract expectBack) === (HStatic.extract inX'))
.&&. ((HStatic.extract expectedGradient) === (HStatic.extract backGrad))
-- Temporarily disabled, as l2 adjustment puts in off 5%
-- .&&. HStatic.extract expectedKernel === HStatic.extract kernel'

.&&. (HStatic.extract expectBack === HStatic.extract inX')
.&&. (HStatic.extract expectedGradient === HStatic.extract backGrad)

prop_vid2col_no_stride = once $
let input = [(3><4)

@ -154,7 +163,7 @@ prop_vid2col_no_stride = once $
, 5.0, 6.0, 9.0, 10.0 , 25.0, 26.0, 29.0, 30.0
, 6.0, 7.0, 10.0, 11.0 , 26.0, 27.0, 30.0, 31.0
, 7.0, 8.0, 11.0, 12.0 , 27.0, 28.0, 31.0, 32.0 ]
out = vid2col 2 2 1 1 3 4 input
out = vid2colUnsafe 2 2 1 1 3 4 input
in expected === out

prop_vid2col_stride = once $

@ -171,10 +180,9 @@ prop_vid2col_stride = once $
, 3.0, 4.0, 7.0, 8.0 , 23.0, 24.0, 27.0, 28.0
, 5.0, 6.0, 9.0, 10.0 , 25.0, 26.0, 29.0, 30.0
, 7.0, 8.0, 11.0, 12.0 , 27.0, 28.0, 31.0, 32.0 ]
out = vid2col 2 2 1 2 3 4 input
out = vid2colUnsafe 2 2 1 2 3 4 input
in expected === out

prop_vid2col_invert = once $
let input = [(3><4)
[ 1.0, 2.0, 3.0, 4.0

@ -184,9 +192,10 @@ prop_vid2col_invert = once $
[ 21.0, 22.0, 23.0, 24.0
, 25.0, 26.0, 27.0, 28.0
, 29.0, 30.0, 31.0, 32.0 ] ]
out = col2vid 3 2 3 2 3 4 . vid2col 3 2 3 2 3 4 $ input
out = col2vidUnsafe 3 2 3 2 3 4 . vid2colUnsafe 3 2 3 2 3 4 $ input
in input === out

-- This test show that 2D convs act the same
-- 3D convs with one layer
prop_single_conv_forwards = once $

@ -216,36 +225,36 @@ prop_single_conv_forwards = once $
[ 1.0, 2.0, 5.0
, 3.0, 4.0, 6.0] :: HStatic.L 2 3] ) :: S' ('D3 2 3 1)

expect = ([(HStatic.matrix
[ -3.0 , -4.0 ] :: HStatic.L 1 2)
,(HStatic.matrix
[ -1.0 , 1.0 ] :: HStatic.L 1 2)
,(HStatic.matrix
[ 5.0 , 9.0 ] :: HStatic.L 1 2)
,(HStatic.matrix
[ -7.0 , -10.0 ] :: HStatic.L 1 2)]) :: [HStatic.L 1 2]
expect = [HStatic.matrix
[ -3.0 , -4.0 ] :: HStatic.L 1 2
,HStatic.matrix
[ -1.0 , 1.0 ] :: HStatic.L 1 2
,HStatic.matrix
[ 5.0 , 9.0 ] :: HStatic.L 1 2
,HStatic.matrix
[ -7.0 , -10.0 ] :: HStatic.L 1 2] :: [HStatic.L 1 2]
out = runForwards convLayer input :: S' ('D3 1 2 4)

grad = S3D' ( mkVector
[(HStatic.matrix
[ 1 , 0 ] :: HStatic.L 1 2)
,(HStatic.matrix
[ 0 , 0 ] :: HStatic.L 1 2)
,(HStatic.matrix
[ 0 , 0 ] :: HStatic.L 1 2)
,(HStatic.matrix
[ 0 , 1 ] :: HStatic.L 1 2)] ) :: S' ('D3 1 2 4)
[HStatic.matrix
[ 1 , 0 ] :: HStatic.L 1 2
,HStatic.matrix
[ 0 , 0 ] :: HStatic.L 1 2
,HStatic.matrix
[ 0 , 0 ] :: HStatic.L 1 2
,HStatic.matrix
[ 0 , 1 ] :: HStatic.L 1 2] ) :: S' ('D3 1 2 4)

expectBack = (HStatic.matrix
[ 1.0, 0.0, 0.0
, 0.0, -2.0,-1.0] :: HStatic.L 2 3)
(nc, inX) = runBackards convLayer input grad
(nc, inX) = runBackwards convLayer input grad

in case (out, inX, nc) of
(S3D' out' , S3D' inX', Convolution' backGrad)
-> ((HStatic.extract <$> expect) === (HStatic.extract <$> vecToList out'))
.&&. ([HStatic.extract expectBack] === (HStatic.extract <$> vecToList inX'))
.&&. ((HStatic.extract expectedGradient) === (HStatic.extract backGrad))
.&&. (HStatic.extract expectedGradient === HStatic.extract backGrad)

return []
tests :: IO Bool

@ -4,7 +4,7 @@
{-# OPTIONS_GHC -fno-warn-missing-signatures #-}
module Test.Grenade.Layers.Pooling where

import Grenade.Layers.Pooling
import Grenade.Layers.Internal.Pooling

import Numeric.LinearAlgebra hiding (uniformSample, konst, (===))

@ -12,15 +12,26 @@ import Test.QuickCheck hiding ((><))

prop_pool = once $
let input = (3><4)
[ 1.0, 2.0, 3.0, 4.0
[ 1.0, 14.0, 3.0, 4.0
, 5.0, 6.0, 7.0, 8.0
, 9.0, 10.0, 11.0, 12.0 ]
expected = (2><3)
[ 6.0, 7.0, 8.0
[ 14.0, 14.0, 8.0
, 10.0, 11.0, 12.0 ]
out = poolForward 2 2 1 1 2 3 input
in expected === out

prop_pool_rectangular = once $
let input = (3><4)
[ 1.0, 14.0, 3.0, 4.0
, 5.0, 6.0, 7.0, 8.0
, 9.0, 10.0, 11.0, 12.0 ]
expected = (2><2)
[ 14.0, 14.0
, 11.0, 12.0 ]
out = poolForward 2 3 1 1 2 2 input
in expected === out

prop_pool_backwards = once $
let input = (3><4)
[ 1.0, 2.0, 3.0, 4.0