mirror of
https://github.com/HuwCampbell/grenade.git
synced 2024-11-22 06:55:13 +03:00
Tests
This commit is contained in:
parent
b090b5f073
commit
670f2d952f
@ -40,7 +40,6 @@ library
|
|||||||
Grenade.Core.Shape
|
Grenade.Core.Shape
|
||||||
Grenade.Layers.Crop
|
Grenade.Layers.Crop
|
||||||
Grenade.Layers.Convolution
|
Grenade.Layers.Convolution
|
||||||
Grenade.Layers.Convolution.Internal
|
|
||||||
Grenade.Layers.Dropout
|
Grenade.Layers.Dropout
|
||||||
Grenade.Layers.FullyConnected
|
Grenade.Layers.FullyConnected
|
||||||
Grenade.Layers.Flatten
|
Grenade.Layers.Flatten
|
||||||
@ -51,6 +50,9 @@ library
|
|||||||
Grenade.Layers.Pad
|
Grenade.Layers.Pad
|
||||||
Grenade.Layers.Pooling
|
Grenade.Layers.Pooling
|
||||||
|
|
||||||
|
Grenade.Layers.Internal.Convolution
|
||||||
|
Grenade.Layers.Internal.Pooling
|
||||||
|
|
||||||
|
|
||||||
executable feedforward
|
executable feedforward
|
||||||
ghc-options: -Wall -threaded -O2
|
ghc-options: -Wall -threaded -O2
|
||||||
|
@ -31,7 +31,7 @@ import Numeric.LinearAlgebra.Static hiding ((|||), build, toRows)
|
|||||||
import Grenade.Core.Network
|
import Grenade.Core.Network
|
||||||
import Grenade.Core.Shape
|
import Grenade.Core.Shape
|
||||||
import Grenade.Core.Vector
|
import Grenade.Core.Vector
|
||||||
import Grenade.Layers.Convolution.Internal
|
import Grenade.Layers.Internal.Convolution
|
||||||
|
|
||||||
-- | A convolution layer for a neural network.
|
-- | A convolution layer for a neural network.
|
||||||
-- This uses the im2col convolution trick popularised by Caffe, which essentially turns the
|
-- This uses the im2col convolution trick popularised by Caffe, which essentially turns the
|
||||||
@ -149,15 +149,13 @@ instance ( KnownNat kernelRows
|
|||||||
runForwards (Convolution kernel _) (S2D' input) =
|
runForwards (Convolution kernel _) (S2D' input) =
|
||||||
let ex = extract input
|
let ex = extract input
|
||||||
ek = extract kernel
|
ek = extract kernel
|
||||||
ix = fromIntegral $ natVal (Proxy :: Proxy inputRows)
|
|
||||||
iy = fromIntegral $ natVal (Proxy :: Proxy inputCols)
|
|
||||||
kx = fromIntegral $ natVal (Proxy :: Proxy kernelRows)
|
kx = fromIntegral $ natVal (Proxy :: Proxy kernelRows)
|
||||||
ky = fromIntegral $ natVal (Proxy :: Proxy kernelCols)
|
ky = fromIntegral $ natVal (Proxy :: Proxy kernelCols)
|
||||||
sx = fromIntegral $ natVal (Proxy :: Proxy strideRows)
|
sx = fromIntegral $ natVal (Proxy :: Proxy strideRows)
|
||||||
sy = fromIntegral $ natVal (Proxy :: Proxy strideCols)
|
sy = fromIntegral $ natVal (Proxy :: Proxy strideCols)
|
||||||
ox = fromIntegral $ natVal (Proxy :: Proxy outputRows)
|
ox = fromIntegral $ natVal (Proxy :: Proxy outputRows)
|
||||||
oy = fromIntegral $ natVal (Proxy :: Proxy outputCols)
|
oy = fromIntegral $ natVal (Proxy :: Proxy outputCols)
|
||||||
c = im2colUnsafe kx ky sx sy ix iy ex
|
c = im2colUnsafe kx ky sx sy ex
|
||||||
mt = c LA.<> ek
|
mt = c LA.<> ek
|
||||||
r = col2vidUnsafe 1 1 1 1 ox oy mt
|
r = col2vidUnsafe 1 1 1 1 ox oy mt
|
||||||
rs = fmap (fromJust . create) r
|
rs = fmap (fromJust . create) r
|
||||||
@ -173,14 +171,13 @@ instance ( KnownNat kernelRows
|
|||||||
sy = fromIntegral $ natVal (Proxy :: Proxy strideCols)
|
sy = fromIntegral $ natVal (Proxy :: Proxy strideCols)
|
||||||
ox = fromIntegral $ natVal (Proxy :: Proxy outputRows)
|
ox = fromIntegral $ natVal (Proxy :: Proxy outputRows)
|
||||||
oy = fromIntegral $ natVal (Proxy :: Proxy outputCols)
|
oy = fromIntegral $ natVal (Proxy :: Proxy outputCols)
|
||||||
fl = fromIntegral $ natVal (Proxy :: Proxy filters)
|
|
||||||
|
|
||||||
c = im2colUnsafe kx ky sx sy ix iy ex
|
c = im2colUnsafe kx ky sx sy ex
|
||||||
|
|
||||||
eo = vecToList $ fmap extract dEdy
|
eo = vecToList $ fmap extract dEdy
|
||||||
ek = extract kernel
|
ek = extract kernel
|
||||||
|
|
||||||
vs = vid2colUnsafe fl 1 1 1 1 ox oy eo
|
vs = vid2colUnsafe 1 1 1 1 ox oy eo
|
||||||
|
|
||||||
kN = fromJust . create $ tr c LA.<> vs
|
kN = fromJust . create $ tr c LA.<> vs
|
||||||
dW = vs LA.<> tr ek
|
dW = vs LA.<> tr ek
|
||||||
@ -216,8 +213,8 @@ instance ( KnownNat kernelRows
|
|||||||
sy = fromIntegral $ natVal (Proxy :: Proxy strideCols)
|
sy = fromIntegral $ natVal (Proxy :: Proxy strideCols)
|
||||||
ox = fromIntegral $ natVal (Proxy :: Proxy outputRows)
|
ox = fromIntegral $ natVal (Proxy :: Proxy outputRows)
|
||||||
oy = fromIntegral $ natVal (Proxy :: Proxy outputCols)
|
oy = fromIntegral $ natVal (Proxy :: Proxy outputCols)
|
||||||
ch = fromIntegral $ natVal (Proxy :: Proxy channels)
|
|
||||||
c = vid2colUnsafe ch kx ky sx sy ix iy ex
|
c = vid2colUnsafe kx ky sx sy ix iy ex
|
||||||
mt = c LA.<> ek
|
mt = c LA.<> ek
|
||||||
r = col2vidUnsafe 1 1 1 1 ox oy mt
|
r = col2vidUnsafe 1 1 1 1 ox oy mt
|
||||||
rs = fmap (fromJust . create) r
|
rs = fmap (fromJust . create) r
|
||||||
@ -232,14 +229,13 @@ instance ( KnownNat kernelRows
|
|||||||
sy = fromIntegral $ natVal (Proxy :: Proxy strideCols)
|
sy = fromIntegral $ natVal (Proxy :: Proxy strideCols)
|
||||||
ox = fromIntegral $ natVal (Proxy :: Proxy outputRows)
|
ox = fromIntegral $ natVal (Proxy :: Proxy outputRows)
|
||||||
oy = fromIntegral $ natVal (Proxy :: Proxy outputCols)
|
oy = fromIntegral $ natVal (Proxy :: Proxy outputCols)
|
||||||
ch = fromIntegral $ natVal (Proxy :: Proxy channels)
|
|
||||||
fl = fromIntegral $ natVal (Proxy :: Proxy filters)
|
c = vid2colUnsafe kx ky sx sy ix iy ex
|
||||||
c = vid2colUnsafe ch kx ky sx sy ix iy ex
|
|
||||||
|
|
||||||
eo = vecToList $ fmap extract dEdy
|
eo = vecToList $ fmap extract dEdy
|
||||||
ek = extract kernel
|
ek = extract kernel
|
||||||
|
|
||||||
vs = vid2colUnsafe fl 1 1 1 1 ox oy eo
|
vs = vid2colUnsafe 1 1 1 1 ox oy eo
|
||||||
|
|
||||||
kN = fromJust . create $ tr c LA.<> vs
|
kN = fromJust . create $ tr c LA.<> vs
|
||||||
|
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
module Grenade.Layers.Convolution.Internal (
|
module Grenade.Layers.Internal.Convolution (
|
||||||
im2col
|
im2col
|
||||||
-- , im2colUnsafe
|
-- , im2colUnsafe
|
||||||
, vid2col
|
, vid2col
|
||||||
@ -72,7 +72,7 @@ fittingStart width kernel steps =
|
|||||||
| left + kernel == width
|
| left + kernel == width
|
||||||
= [left]
|
= [left]
|
||||||
| otherwise
|
| otherwise
|
||||||
= error "Kernel and step do not fit in matrix."
|
= [] -- error "Kernel and step do not fit in matrix."
|
||||||
in go 0
|
in go 0
|
||||||
|
|
||||||
col2imFit :: [(Int,Int)] -> Int -> Int -> Int -> Int -> Matrix Double -> Matrix Double
|
col2imFit :: [(Int,Int)] -> Int -> Int -> Int -> Int -> Matrix Double -> Matrix Double
|
||||||
@ -119,14 +119,9 @@ col2imFit starts krows kcols drows dcols m =
|
|||||||
-- }
|
-- }
|
||||||
-- }
|
-- }
|
||||||
|
|
||||||
|
|
||||||
-- let starts = fittingStart (cols m) (krows * kcols) (krows * kcols)
|
|
||||||
-- r = rows m
|
|
||||||
-- mats = fmap (\s -> subMatrix (0,s) (r, krows * kcols) m) starts
|
|
||||||
-- in parMap rseq (col2imUnsafe krows kcols srows scols drows dcols) mats
|
|
||||||
|
|
||||||
col2imUnsafe :: Int -> Int -> Int -> Int -> Int -> Int -> Matrix Double -> Matrix Double
|
col2imUnsafe :: Int -> Int -> Int -> Int -> Int -> Int -> Matrix Double -> Matrix Double
|
||||||
col2imUnsafe kernelRows kernelColumns strideRows strideColumns destinationRows destinationCols columnMatrix = U.runSTMatrix $ do
|
col2imUnsafe kernelRows kernelColumns strideRows strideColumns destinationRows destinationCols columnMatrix = U.runSTMatrix $ do
|
||||||
|
|
||||||
let columnMatrixRows = rows columnMatrix
|
let columnMatrixRows = rows columnMatrix
|
||||||
|
|
||||||
dataIm <- U.newMatrix 0 destinationRows destinationCols
|
dataIm <- U.newMatrix 0 destinationRows destinationCols
|
||||||
@ -134,15 +129,15 @@ col2imUnsafe kernelRows kernelColumns strideRows strideColumns destinationRows d
|
|||||||
offsetR <- newSTRef 0
|
offsetR <- newSTRef 0
|
||||||
offsetC <- newSTRef 0
|
offsetC <- newSTRef 0
|
||||||
|
|
||||||
forM_ [0 .. columnMatrixRows - 1] $ \ir -> do
|
forM_ [0 .. columnMatrixRows - 1] $ \inputRow -> do
|
||||||
inputColumn <- newSTRef 0
|
inputColumnRef <- newSTRef 0
|
||||||
forM_ [0 .. kernelRows -1] $ \kr ->
|
forM_ [0 .. kernelRows -1] $ \kr ->
|
||||||
forM_ [0 .. kernelColumns -1] $ \kc -> do
|
forM_ [0 .. kernelColumns -1] $ \kc -> do
|
||||||
ic <- readSTRef inputColumn
|
inputColumn <- readSTRef inputColumnRef
|
||||||
offsetR' <- readSTRef offsetR
|
offsetR' <- readSTRef offsetR
|
||||||
offsetC' <- readSTRef offsetC
|
offsetC' <- readSTRef offsetC
|
||||||
U.modifyMatrix dataIm (kr + offsetR') (kc + offsetC') (+ atIndex columnMatrix (ir,ic))
|
U.modifyMatrix dataIm (kr + offsetR') (kc + offsetC') (+ U.atM' columnMatrix inputRow inputColumn)
|
||||||
modifySTRef inputColumn (+1)
|
modifySTRef inputColumnRef (+1)
|
||||||
|
|
||||||
offsetC' <- readSTRef offsetC
|
offsetC' <- readSTRef offsetC
|
||||||
if offsetC' + kernelColumns < destinationCols
|
if offsetC' + kernelColumns < destinationCols
|
||||||
@ -158,12 +153,12 @@ col2vidUnsafe kernelRows kernelColumns strideRows strideColumns destinationRows
|
|||||||
|
|
||||||
dataIms <- traverse (\_ -> U.newMatrix 0 destinationRows destinationCols) [0 .. filters-1]
|
dataIms <- traverse (\_ -> U.newMatrix 0 destinationRows destinationCols) [0 .. filters-1]
|
||||||
|
|
||||||
offsetR <- newSTRef 0
|
|
||||||
offsetC <- newSTRef 0
|
|
||||||
offsetM <- newSTRef 0
|
offsetM <- newSTRef 0
|
||||||
|
|
||||||
forM_ dataIms $ \dataIm -> do
|
forM_ dataIms $ \dataIm -> do
|
||||||
offsetM' <- readSTRef offsetM
|
offsetR <- newSTRef 0
|
||||||
|
offsetC <- newSTRef 0
|
||||||
|
offsetM' <- readSTRef offsetM
|
||||||
forM_ [0 .. columnMatrixRows - 1] $ \ir -> do
|
forM_ [0 .. columnMatrixRows - 1] $ \ir -> do
|
||||||
inputColumn <- newSTRef 0
|
inputColumn <- newSTRef 0
|
||||||
forM_ [0 .. kernelRows -1] $ \kr ->
|
forM_ [0 .. kernelRows -1] $ \kr ->
|
||||||
@ -171,7 +166,7 @@ col2vidUnsafe kernelRows kernelColumns strideRows strideColumns destinationRows
|
|||||||
ic <- readSTRef inputColumn
|
ic <- readSTRef inputColumn
|
||||||
offsetR' <- readSTRef offsetR
|
offsetR' <- readSTRef offsetR
|
||||||
offsetC' <- readSTRef offsetC
|
offsetC' <- readSTRef offsetC
|
||||||
U.modifyMatrix dataIm (kr + offsetR') (kc + offsetC') (+ atIndex columnMatrix (ir, ic + offsetM'))
|
U.modifyMatrix dataIm (kr + offsetR') (kc + offsetC') (+ U.atM' columnMatrix ir (ic + offsetM'))
|
||||||
modifySTRef inputColumn (+1)
|
modifySTRef inputColumn (+1)
|
||||||
|
|
||||||
offsetC' <- readSTRef offsetC
|
offsetC' <- readSTRef offsetC
|
||||||
@ -179,60 +174,55 @@ col2vidUnsafe kernelRows kernelColumns strideRows strideColumns destinationRows
|
|||||||
then modifySTRef offsetC (+ strideColumns)
|
then modifySTRef offsetC (+ strideColumns)
|
||||||
else writeSTRef offsetC 0 >> modifySTRef offsetR (+ strideRows)
|
else writeSTRef offsetC 0 >> modifySTRef offsetR (+ strideRows)
|
||||||
|
|
||||||
writeSTRef offsetR 0
|
|
||||||
writeSTRef offsetC 0
|
|
||||||
modifySTRef offsetM (+ (kernelRows * kernelColumns))
|
modifySTRef offsetM (+ (kernelRows * kernelColumns))
|
||||||
|
|
||||||
traverse U.freezeMatrix dataIms
|
traverse U.unsafeFreezeMatrix dataIms
|
||||||
|
|
||||||
vid2colUnsafe :: Int -> Int -> Int -> Int -> Int -> Int -> Int -> [Matrix Double] -> Matrix Double
|
vid2colUnsafe :: Int -> Int -> Int -> Int -> Int -> Int -> [Matrix Double] -> Matrix Double
|
||||||
vid2colUnsafe channels kernelRows kernelColumns striderows stridecols vidrows vidcols dataVid = U.runSTMatrix $ do
|
vid2colUnsafe kernelRows kernelColumns striderows stridecols vidrows vidcols dataVid = U.runSTMatrix $ do
|
||||||
let starts = fittingStarts vidrows kernelRows striderows vidcols kernelColumns stridecols
|
let starts = fittingStarts vidrows kernelRows striderows vidcols kernelColumns stridecols
|
||||||
matWidth = kernelRows * kernelColumns
|
kernelSize = kernelRows * kernelColumns
|
||||||
destinationRows = 1 + (vidrows - kernelRows) `div` striderows
|
numberOfPatches = length starts
|
||||||
destinationCols = 1 + (vidcols - kernelColumns) `div` stridecols
|
channels = length dataVid
|
||||||
destinationSize = destinationRows * destinationCols
|
|
||||||
|
|
||||||
dataCol <- U.newMatrix 0 destinationSize (channels * matWidth)
|
dataCol <- U.newMatrix 0 numberOfPatches (channels * kernelSize)
|
||||||
|
|
||||||
offsetC <- newSTRef 0
|
offsetC <- newSTRef 0
|
||||||
|
|
||||||
forM_ dataVid $ \dataIm -> do
|
forM_ dataVid $ \dataIm -> do
|
||||||
inputRow <- newSTRef 0
|
inputRowRef <- newSTRef 0
|
||||||
offsetC' <- readSTRef offsetC
|
offsetC' <- readSTRef offsetC
|
||||||
forM_ starts $ \(startRow, startCol) -> do
|
forM_ starts $ \(startRow, startCol) -> do
|
||||||
inputColumn <- newSTRef 0
|
inputColumnRef <- newSTRef 0
|
||||||
inputRow' <- readSTRef inputRow
|
inputRow <- readSTRef inputRowRef
|
||||||
forM_ [0 .. kernelRows -1] $ \kr ->
|
forM_ [0 .. kernelRows -1] $ \kr ->
|
||||||
forM_ [0 .. kernelColumns -1] $ \kc -> do
|
forM_ [0 .. kernelColumns -1] $ \kc -> do
|
||||||
inputColumn' <- readSTRef inputColumn
|
inputColumn <- readSTRef inputColumnRef
|
||||||
U.modifyMatrix dataCol inputRow' (inputColumn' + offsetC') (+ atIndex dataIm (kr + startRow, kc + startCol))
|
U.modifyMatrix dataCol inputRow (inputColumn + offsetC') (+ U.atM' dataIm (kr + startRow) (kc + startCol))
|
||||||
modifySTRef inputColumn (+1)
|
modifySTRef inputColumnRef (+1)
|
||||||
modifySTRef inputRow (+1)
|
modifySTRef inputRowRef (+1)
|
||||||
|
|
||||||
modifySTRef offsetC (+ matWidth)
|
modifySTRef offsetC (+ kernelSize)
|
||||||
|
|
||||||
return dataCol
|
return dataCol
|
||||||
|
|
||||||
im2colUnsafe :: Int -> Int -> Int -> Int -> Int -> Int -> Matrix Double -> Matrix Double
|
im2colUnsafe :: Int -> Int -> Int -> Int -> Matrix Double -> Matrix Double
|
||||||
im2colUnsafe kernelRows kernelColumns striderows stridecols vidrows vidcols dataIm = U.runSTMatrix $ do
|
im2colUnsafe kernelRows kernelColumns striderows stridecols dataIm = U.runSTMatrix $ do
|
||||||
let starts = fittingStarts vidrows kernelRows striderows vidcols kernelColumns stridecols
|
let starts = fittingStarts (rows dataIm) kernelRows striderows (cols dataIm) kernelColumns stridecols
|
||||||
matWidth = kernelRows * kernelColumns
|
kernelSize = kernelRows * kernelColumns
|
||||||
destinationRows = 1 + (vidrows - kernelRows) `div` striderows
|
numberOfPatches = length starts
|
||||||
destinationCols = 1 + (vidcols - kernelColumns) `div` stridecols
|
|
||||||
destinationSize = destinationRows * destinationCols
|
|
||||||
|
|
||||||
dataCol <- U.newMatrix 0 destinationSize matWidth
|
dataCol <- U.newMatrix 0 numberOfPatches kernelSize
|
||||||
|
|
||||||
inputRow <- newSTRef 0
|
inputRowRef <- newSTRef 0
|
||||||
forM_ starts $ \(startRow, startCol) -> do
|
forM_ starts $ \(startRow, startCol) -> do
|
||||||
inputColumn <- newSTRef 0
|
inputColumnRef <- newSTRef 0
|
||||||
inputRow' <- readSTRef inputRow
|
inputRow <- readSTRef inputRowRef
|
||||||
forM_ [0 .. kernelRows -1] $ \kr ->
|
forM_ [0 .. kernelRows -1] $ \kr ->
|
||||||
forM_ [0 .. kernelColumns -1] $ \kc -> do
|
forM_ [0 .. kernelColumns -1] $ \kc -> do
|
||||||
inputColumn' <- readSTRef inputColumn
|
inputColumn <- readSTRef inputColumnRef
|
||||||
U.modifyMatrix dataCol inputRow' inputColumn' (+ atIndex dataIm (kr + startRow, kc + startCol))
|
U.modifyMatrix dataCol inputRow inputColumn (+ U.atM' dataIm (kr + startRow) (kc + startCol))
|
||||||
modifySTRef inputColumn (+1)
|
modifySTRef inputColumnRef (+1)
|
||||||
modifySTRef inputRow (+1)
|
modifySTRef inputRowRef (+1)
|
||||||
|
|
||||||
return dataCol
|
return dataCol
|
48
src/Grenade/Layers/Internal/Pooling.hs
Normal file
48
src/Grenade/Layers/Internal/Pooling.hs
Normal file
@ -0,0 +1,48 @@
|
|||||||
|
module Grenade.Layers.Internal.Pooling (
|
||||||
|
poolForward
|
||||||
|
, poolBackward
|
||||||
|
, poolForwardList
|
||||||
|
, poolBackwardList
|
||||||
|
) where
|
||||||
|
|
||||||
|
import Numeric.LinearAlgebra hiding ( uniformSample, konst )
|
||||||
|
import qualified Numeric.LinearAlgebra as LA
|
||||||
|
|
||||||
|
import Grenade.Layers.Internal.Convolution
|
||||||
|
|
||||||
|
poolForward :: Int -> Int -> Int -> Int -> Int -> Int -> Matrix Double -> Matrix Double
|
||||||
|
poolForward nrows ncols srows scols outputRows outputCols m =
|
||||||
|
let starts = fittingStarts (rows m) nrows srows (cols m) ncols scols
|
||||||
|
in poolForwardFit starts nrows ncols outputRows outputCols m
|
||||||
|
|
||||||
|
poolForwardList :: Functor f => Int -> Int -> Int -> Int -> Int -> Int -> Int -> Int -> f (Matrix Double) -> f (Matrix Double)
|
||||||
|
poolForwardList nrows ncols srows scols inRows inCols outputRows outputCols ms =
|
||||||
|
let starts = fittingStarts inRows nrows srows inCols ncols scols
|
||||||
|
in poolForwardFit starts nrows ncols outputRows outputCols <$> ms
|
||||||
|
|
||||||
|
poolForwardFit :: [(Int,Int)] -> Int -> Int -> Int -> Int -> Matrix Double -> Matrix Double
|
||||||
|
poolForwardFit starts nrows ncols _ outputCols m =
|
||||||
|
let els = fmap (\start -> maxElement $ subMatrix start (nrows, ncols) m) starts
|
||||||
|
in LA.matrix outputCols els
|
||||||
|
|
||||||
|
poolBackward :: Int -> Int -> Int -> Int -> Matrix Double -> Matrix Double -> Matrix Double
|
||||||
|
poolBackward krows kcols srows scols inputMatrix gradientMatrix =
|
||||||
|
let inRows = rows inputMatrix
|
||||||
|
inCols = cols inputMatrix
|
||||||
|
starts = fittingStarts inRows krows srows inCols kcols scols
|
||||||
|
in poolBackwardFit starts krows kcols inputMatrix gradientMatrix
|
||||||
|
|
||||||
|
poolBackwardList :: Functor f => Int -> Int -> Int -> Int -> Int -> Int -> f (Matrix Double, Matrix Double) -> f (Matrix Double)
|
||||||
|
poolBackwardList krows kcols srows scols inRows inCols inputMatrices =
|
||||||
|
let starts = fittingStarts inRows krows srows inCols kcols scols
|
||||||
|
in uncurry (poolBackwardFit starts krows kcols) <$> inputMatrices
|
||||||
|
|
||||||
|
poolBackwardFit :: [(Int,Int)] -> Int -> Int -> Matrix Double -> Matrix Double -> Matrix Double
|
||||||
|
poolBackwardFit starts krows kcols inputMatrix gradientMatrix =
|
||||||
|
let inRows = rows inputMatrix
|
||||||
|
inCols = cols inputMatrix
|
||||||
|
inds = fmap (\start -> maxIndex $ subMatrix start (krows, kcols) inputMatrix) starts
|
||||||
|
grads = toList $ flatten gradientMatrix
|
||||||
|
grads' = zip3 starts grads inds
|
||||||
|
accums = fmap (\((stx',sty'),grad,(inx, iny)) -> ((stx' + inx, sty' + iny), grad)) grads'
|
||||||
|
in accum (LA.konst 0 (inRows, inCols)) (+) accums
|
@ -24,10 +24,8 @@ import GHC.TypeLits
|
|||||||
import Grenade.Core.Network
|
import Grenade.Core.Network
|
||||||
import Grenade.Core.Shape
|
import Grenade.Core.Shape
|
||||||
import Grenade.Core.Vector
|
import Grenade.Core.Vector
|
||||||
import Grenade.Layers.Convolution.Internal
|
import Grenade.Layers.Internal.Pooling
|
||||||
|
|
||||||
import Numeric.LinearAlgebra hiding (uniformSample)
|
|
||||||
import qualified Numeric.LinearAlgebra as LA
|
|
||||||
import Numeric.LinearAlgebra.Static as LAS hiding ((|||), build, toRows)
|
import Numeric.LinearAlgebra.Static as LAS hiding ((|||), build, toRows)
|
||||||
|
|
||||||
-- | A pooling layer for a neural network.
|
-- | A pooling layer for a neural network.
|
||||||
@ -37,16 +35,12 @@ import Numeric.LinearAlgebra.Static as LAS hiding ((|||), build, toRow
|
|||||||
-- The kernel size dictates which input and output sizes will "fit". Fitting the equation:
|
-- The kernel size dictates which input and output sizes will "fit". Fitting the equation:
|
||||||
-- `out = (in - kernel) / stride + 1` for both dimensions.
|
-- `out = (in - kernel) / stride + 1` for both dimensions.
|
||||||
--
|
--
|
||||||
data Pooling :: Nat
|
data Pooling :: Nat -> Nat -> Nat -> Nat -> * where
|
||||||
-> Nat
|
|
||||||
-> Nat
|
|
||||||
-> Nat -> * where
|
|
||||||
Pooling :: Pooling kernelRows kernelColumns strideRows strideColumns
|
Pooling :: Pooling kernelRows kernelColumns strideRows strideColumns
|
||||||
|
|
||||||
instance Show (Pooling k k' s s') where
|
instance Show (Pooling k k' s s') where
|
||||||
show Pooling = "Pooling"
|
show Pooling = "Pooling"
|
||||||
|
|
||||||
|
|
||||||
instance UpdateLayer (Pooling kernelRows kernelColumns strideRows strideColumns) where
|
instance UpdateLayer (Pooling kernelRows kernelColumns strideRows strideColumns) where
|
||||||
type Gradient (Pooling kr kc sr sc) = ()
|
type Gradient (Pooling kr kc sr sc) = ()
|
||||||
runUpdate _ Pooling _ = Pooling
|
runUpdate _ Pooling _ = Pooling
|
||||||
@ -123,40 +117,3 @@ instance ( KnownNat kernelRows
|
|||||||
ez = vectorZip (,) ex eo
|
ez = vectorZip (,) ex eo
|
||||||
vs = poolBackwardList kx ky sx sy ix iy ez
|
vs = poolBackwardList kx ky sx sy ix iy ez
|
||||||
in ((), S3D' . fmap (fromJust . create) $ vs)
|
in ((), S3D' . fmap (fromJust . create) $ vs)
|
||||||
|
|
||||||
poolForward :: Int -> Int -> Int -> Int -> Int -> Int -> Matrix Double -> Matrix Double
|
|
||||||
poolForward nrows ncols srows scols outputRows outputCols m =
|
|
||||||
let starts = fittingStarts (rows m) nrows srows (cols m) ncols scols
|
|
||||||
in poolForwardFit starts nrows ncols outputRows outputCols m
|
|
||||||
|
|
||||||
poolForwardList :: Functor f => Int -> Int -> Int -> Int -> Int -> Int -> Int -> Int -> f (Matrix Double) -> f (Matrix Double)
|
|
||||||
poolForwardList nrows ncols srows scols inRows inCols outputRows outputCols ms =
|
|
||||||
let starts = fittingStarts inRows nrows srows inCols ncols scols
|
|
||||||
in poolForwardFit starts nrows ncols outputRows outputCols <$> ms
|
|
||||||
|
|
||||||
poolForwardFit :: [(Int,Int)] -> Int -> Int -> Int -> Int -> Matrix Double -> Matrix Double
|
|
||||||
poolForwardFit starts nrows ncols _ outputCols m =
|
|
||||||
let els = fmap (\start -> maxElement $ subMatrix start (nrows, ncols) m) starts
|
|
||||||
in LA.matrix outputCols els
|
|
||||||
|
|
||||||
poolBackward :: Int -> Int -> Int -> Int -> Matrix Double -> Matrix Double -> Matrix Double
|
|
||||||
poolBackward krows kcols srows scols inputMatrix gradientMatrix =
|
|
||||||
let inRows = (rows inputMatrix)
|
|
||||||
inCols = (cols inputMatrix)
|
|
||||||
starts = fittingStarts inRows krows srows inCols kcols scols
|
|
||||||
in poolBackwardFit starts krows kcols inputMatrix gradientMatrix
|
|
||||||
|
|
||||||
poolBackwardList :: Functor f => Int -> Int -> Int -> Int -> Int -> Int -> f (Matrix Double, Matrix Double) -> f (Matrix Double)
|
|
||||||
poolBackwardList krows kcols srows scols inRows inCols inputMatrices =
|
|
||||||
let starts = fittingStarts inRows krows srows inCols kcols scols
|
|
||||||
in (uncurry $ poolBackwardFit starts krows kcols) <$> inputMatrices
|
|
||||||
|
|
||||||
poolBackwardFit :: [(Int,Int)] -> Int -> Int -> Matrix Double -> Matrix Double -> Matrix Double
|
|
||||||
poolBackwardFit starts krows kcols inputMatrix gradientMatrix =
|
|
||||||
let inRows = (rows inputMatrix)
|
|
||||||
inCols = (cols inputMatrix)
|
|
||||||
inds = fmap (\start -> maxIndex $ subMatrix start (krows, kcols) inputMatrix) starts
|
|
||||||
grads = toList $ flatten gradientMatrix
|
|
||||||
grads' = zip3 starts grads inds
|
|
||||||
accums = fmap (\((stx',sty'),grad,(inx, iny)) -> ((stx' + inx, sty' + iny), grad)) grads'
|
|
||||||
in accum (LA.konst 0 (inRows, inCols)) (+) accums
|
|
||||||
|
@ -8,7 +8,7 @@ import Grenade.Core.Shape
|
|||||||
import Grenade.Core.Vector as Grenade
|
import Grenade.Core.Vector as Grenade
|
||||||
import Grenade.Core.Network
|
import Grenade.Core.Network
|
||||||
import Grenade.Layers.Convolution
|
import Grenade.Layers.Convolution
|
||||||
import Grenade.Layers.Convolution.Internal
|
import Grenade.Layers.Internal.Convolution
|
||||||
|
|
||||||
import Numeric.LinearAlgebra hiding (uniformSample, konst, (===))
|
import Numeric.LinearAlgebra hiding (uniformSample, konst, (===))
|
||||||
import qualified Numeric.LinearAlgebra.Static as HStatic
|
import qualified Numeric.LinearAlgebra.Static as HStatic
|
||||||
@ -51,7 +51,7 @@ prop_im2col_other = once $
|
|||||||
expected = (2><6)
|
expected = (2><6)
|
||||||
[ 1.0, 2.0, 5.0, 6.0 , 9.0, 10.0
|
[ 1.0, 2.0, 5.0, 6.0 , 9.0, 10.0
|
||||||
, 3.0, 4.0, 7.0, 8.0 , 11.0 ,12.0 ]
|
, 3.0, 4.0, 7.0, 8.0 , 11.0 ,12.0 ]
|
||||||
out = im2col 3 2 1 2 input
|
out = im2colUnsafe 3 2 1 2 input
|
||||||
in expected === out
|
in expected === out
|
||||||
|
|
||||||
-- If there's no overlap (stride is the same size as the kernel)
|
-- If there's no overlap (stride is the same size as the kernel)
|
||||||
@ -71,7 +71,7 @@ prop_im2colunsafe_sym_on_same_stride = once $
|
|||||||
[ 1.0, 2.0, 3.0, 4.0
|
[ 1.0, 2.0, 3.0, 4.0
|
||||||
, 5.0, 6.0, 7.0, 8.0
|
, 5.0, 6.0, 7.0, 8.0
|
||||||
, 9.0, 10.0, 11.0, 12.0 ]
|
, 9.0, 10.0, 11.0, 12.0 ]
|
||||||
out = col2imUnsafe 3 2 3 2 3 4 . im2colUnsafe 3 2 3 2 3 4 $ input
|
out = col2imUnsafe 3 2 3 2 3 4 . im2colUnsafe 3 2 3 2 $ input
|
||||||
in input === out
|
in input === out
|
||||||
|
|
||||||
|
|
||||||
@ -86,7 +86,7 @@ prop_im2col_col2im_additive = once $
|
|||||||
[ 1.0, 2.0, 2.0, 1.0
|
[ 1.0, 2.0, 2.0, 1.0
|
||||||
, 2.0, 4.0, 4.0, 2.0
|
, 2.0, 4.0, 4.0, 2.0
|
||||||
, 1.0, 2.0, 2.0, 1.0 ]
|
, 1.0, 2.0, 2.0, 1.0 ]
|
||||||
out = col2im 2 2 1 1 3 4 . im2col 2 2 1 1 $ input
|
out = col2imUnsafe 2 2 1 1 3 4 . im2colUnsafe 2 2 1 1 $ input
|
||||||
in expected === out
|
in expected === out
|
||||||
|
|
||||||
prop_simple_conv_forwards = once $
|
prop_simple_conv_forwards = once $
|
||||||
@ -166,7 +166,7 @@ prop_vid2col_no_stride = once $
|
|||||||
, 5.0, 6.0, 9.0, 10.0 , 25.0, 26.0, 29.0, 30.0
|
, 5.0, 6.0, 9.0, 10.0 , 25.0, 26.0, 29.0, 30.0
|
||||||
, 6.0, 7.0, 10.0, 11.0 , 26.0, 27.0, 30.0, 31.0
|
, 6.0, 7.0, 10.0, 11.0 , 26.0, 27.0, 30.0, 31.0
|
||||||
, 7.0, 8.0, 11.0, 12.0 , 27.0, 28.0, 31.0, 32.0 ]
|
, 7.0, 8.0, 11.0, 12.0 , 27.0, 28.0, 31.0, 32.0 ]
|
||||||
out = vid2col 2 2 1 1 3 4 input
|
out = vid2colUnsafe 2 2 1 1 3 4 input
|
||||||
in expected === out
|
in expected === out
|
||||||
|
|
||||||
prop_vid2col_stride = once $
|
prop_vid2col_stride = once $
|
||||||
@ -183,7 +183,7 @@ prop_vid2col_stride = once $
|
|||||||
, 3.0, 4.0, 7.0, 8.0 , 23.0, 24.0, 27.0, 28.0
|
, 3.0, 4.0, 7.0, 8.0 , 23.0, 24.0, 27.0, 28.0
|
||||||
, 5.0, 6.0, 9.0, 10.0 , 25.0, 26.0, 29.0, 30.0
|
, 5.0, 6.0, 9.0, 10.0 , 25.0, 26.0, 29.0, 30.0
|
||||||
, 7.0, 8.0, 11.0, 12.0 , 27.0, 28.0, 31.0, 32.0 ]
|
, 7.0, 8.0, 11.0, 12.0 , 27.0, 28.0, 31.0, 32.0 ]
|
||||||
out = vid2col 2 2 1 2 3 4 input
|
out = vid2colUnsafe 2 2 1 2 3 4 input
|
||||||
in expected === out
|
in expected === out
|
||||||
|
|
||||||
|
|
||||||
@ -208,7 +208,7 @@ prop_vid2col_invert_unsafe = once $
|
|||||||
[ 21.0, 22.0, 23.0, 24.0
|
[ 21.0, 22.0, 23.0, 24.0
|
||||||
, 25.0, 26.0, 27.0, 28.0
|
, 25.0, 26.0, 27.0, 28.0
|
||||||
, 29.0, 30.0, 31.0, 32.0 ] ]
|
, 29.0, 30.0, 31.0, 32.0 ] ]
|
||||||
out = col2vidUnsafe 3 2 3 2 3 4 . vid2colUnsafe 2 3 2 3 2 3 4 $ input
|
out = col2vidUnsafe 3 2 3 2 3 4 . vid2colUnsafe 3 2 3 2 3 4 $ input
|
||||||
in input === out
|
in input === out
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user