This commit is contained in:
Huw Campbell 2016-12-08 09:03:29 +11:00
parent b090b5f073
commit 670f2d952f
6 changed files with 111 additions and 118 deletions

View File

@ -40,7 +40,6 @@ library
Grenade.Core.Shape Grenade.Core.Shape
Grenade.Layers.Crop Grenade.Layers.Crop
Grenade.Layers.Convolution Grenade.Layers.Convolution
Grenade.Layers.Convolution.Internal
Grenade.Layers.Dropout Grenade.Layers.Dropout
Grenade.Layers.FullyConnected Grenade.Layers.FullyConnected
Grenade.Layers.Flatten Grenade.Layers.Flatten
@ -51,6 +50,9 @@ library
Grenade.Layers.Pad Grenade.Layers.Pad
Grenade.Layers.Pooling Grenade.Layers.Pooling
Grenade.Layers.Internal.Convolution
Grenade.Layers.Internal.Pooling
executable feedforward executable feedforward
ghc-options: -Wall -threaded -O2 ghc-options: -Wall -threaded -O2

View File

@ -31,7 +31,7 @@ import Numeric.LinearAlgebra.Static hiding ((|||), build, toRows)
import Grenade.Core.Network import Grenade.Core.Network
import Grenade.Core.Shape import Grenade.Core.Shape
import Grenade.Core.Vector import Grenade.Core.Vector
import Grenade.Layers.Convolution.Internal import Grenade.Layers.Internal.Convolution
-- | A convolution layer for a neural network. -- | A convolution layer for a neural network.
-- This uses the im2col convolution trick popularised by Caffe, which essentially turns the -- This uses the im2col convolution trick popularised by Caffe, which essentially turns the
@ -149,15 +149,13 @@ instance ( KnownNat kernelRows
runForwards (Convolution kernel _) (S2D' input) = runForwards (Convolution kernel _) (S2D' input) =
let ex = extract input let ex = extract input
ek = extract kernel ek = extract kernel
ix = fromIntegral $ natVal (Proxy :: Proxy inputRows)
iy = fromIntegral $ natVal (Proxy :: Proxy inputCols)
kx = fromIntegral $ natVal (Proxy :: Proxy kernelRows) kx = fromIntegral $ natVal (Proxy :: Proxy kernelRows)
ky = fromIntegral $ natVal (Proxy :: Proxy kernelCols) ky = fromIntegral $ natVal (Proxy :: Proxy kernelCols)
sx = fromIntegral $ natVal (Proxy :: Proxy strideRows) sx = fromIntegral $ natVal (Proxy :: Proxy strideRows)
sy = fromIntegral $ natVal (Proxy :: Proxy strideCols) sy = fromIntegral $ natVal (Proxy :: Proxy strideCols)
ox = fromIntegral $ natVal (Proxy :: Proxy outputRows) ox = fromIntegral $ natVal (Proxy :: Proxy outputRows)
oy = fromIntegral $ natVal (Proxy :: Proxy outputCols) oy = fromIntegral $ natVal (Proxy :: Proxy outputCols)
c = im2colUnsafe kx ky sx sy ix iy ex c = im2colUnsafe kx ky sx sy ex
mt = c LA.<> ek mt = c LA.<> ek
r = col2vidUnsafe 1 1 1 1 ox oy mt r = col2vidUnsafe 1 1 1 1 ox oy mt
rs = fmap (fromJust . create) r rs = fmap (fromJust . create) r
@ -173,14 +171,13 @@ instance ( KnownNat kernelRows
sy = fromIntegral $ natVal (Proxy :: Proxy strideCols) sy = fromIntegral $ natVal (Proxy :: Proxy strideCols)
ox = fromIntegral $ natVal (Proxy :: Proxy outputRows) ox = fromIntegral $ natVal (Proxy :: Proxy outputRows)
oy = fromIntegral $ natVal (Proxy :: Proxy outputCols) oy = fromIntegral $ natVal (Proxy :: Proxy outputCols)
fl = fromIntegral $ natVal (Proxy :: Proxy filters)
c = im2colUnsafe kx ky sx sy ix iy ex c = im2colUnsafe kx ky sx sy ex
eo = vecToList $ fmap extract dEdy eo = vecToList $ fmap extract dEdy
ek = extract kernel ek = extract kernel
vs = vid2colUnsafe fl 1 1 1 1 ox oy eo vs = vid2colUnsafe 1 1 1 1 ox oy eo
kN = fromJust . create $ tr c LA.<> vs kN = fromJust . create $ tr c LA.<> vs
dW = vs LA.<> tr ek dW = vs LA.<> tr ek
@ -216,8 +213,8 @@ instance ( KnownNat kernelRows
sy = fromIntegral $ natVal (Proxy :: Proxy strideCols) sy = fromIntegral $ natVal (Proxy :: Proxy strideCols)
ox = fromIntegral $ natVal (Proxy :: Proxy outputRows) ox = fromIntegral $ natVal (Proxy :: Proxy outputRows)
oy = fromIntegral $ natVal (Proxy :: Proxy outputCols) oy = fromIntegral $ natVal (Proxy :: Proxy outputCols)
ch = fromIntegral $ natVal (Proxy :: Proxy channels)
c = vid2colUnsafe ch kx ky sx sy ix iy ex c = vid2colUnsafe kx ky sx sy ix iy ex
mt = c LA.<> ek mt = c LA.<> ek
r = col2vidUnsafe 1 1 1 1 ox oy mt r = col2vidUnsafe 1 1 1 1 ox oy mt
rs = fmap (fromJust . create) r rs = fmap (fromJust . create) r
@ -232,14 +229,13 @@ instance ( KnownNat kernelRows
sy = fromIntegral $ natVal (Proxy :: Proxy strideCols) sy = fromIntegral $ natVal (Proxy :: Proxy strideCols)
ox = fromIntegral $ natVal (Proxy :: Proxy outputRows) ox = fromIntegral $ natVal (Proxy :: Proxy outputRows)
oy = fromIntegral $ natVal (Proxy :: Proxy outputCols) oy = fromIntegral $ natVal (Proxy :: Proxy outputCols)
ch = fromIntegral $ natVal (Proxy :: Proxy channels)
fl = fromIntegral $ natVal (Proxy :: Proxy filters) c = vid2colUnsafe kx ky sx sy ix iy ex
c = vid2colUnsafe ch kx ky sx sy ix iy ex
eo = vecToList $ fmap extract dEdy eo = vecToList $ fmap extract dEdy
ek = extract kernel ek = extract kernel
vs = vid2colUnsafe fl 1 1 1 1 ox oy eo vs = vid2colUnsafe 1 1 1 1 ox oy eo
kN = fromJust . create $ tr c LA.<> vs kN = fromJust . create $ tr c LA.<> vs

View File

@ -1,4 +1,4 @@
module Grenade.Layers.Convolution.Internal ( module Grenade.Layers.Internal.Convolution (
im2col im2col
-- , im2colUnsafe -- , im2colUnsafe
, vid2col , vid2col
@ -72,7 +72,7 @@ fittingStart width kernel steps =
| left + kernel == width | left + kernel == width
= [left] = [left]
| otherwise | otherwise
= error "Kernel and step do not fit in matrix." = [] -- error "Kernel and step do not fit in matrix."
in go 0 in go 0
col2imFit :: [(Int,Int)] -> Int -> Int -> Int -> Int -> Matrix Double -> Matrix Double col2imFit :: [(Int,Int)] -> Int -> Int -> Int -> Int -> Matrix Double -> Matrix Double
@ -119,14 +119,9 @@ col2imFit starts krows kcols drows dcols m =
-- } -- }
-- } -- }
-- let starts = fittingStart (cols m) (krows * kcols) (krows * kcols)
-- r = rows m
-- mats = fmap (\s -> subMatrix (0,s) (r, krows * kcols) m) starts
-- in parMap rseq (col2imUnsafe krows kcols srows scols drows dcols) mats
col2imUnsafe :: Int -> Int -> Int -> Int -> Int -> Int -> Matrix Double -> Matrix Double col2imUnsafe :: Int -> Int -> Int -> Int -> Int -> Int -> Matrix Double -> Matrix Double
col2imUnsafe kernelRows kernelColumns strideRows strideColumns destinationRows destinationCols columnMatrix = U.runSTMatrix $ do col2imUnsafe kernelRows kernelColumns strideRows strideColumns destinationRows destinationCols columnMatrix = U.runSTMatrix $ do
let columnMatrixRows = rows columnMatrix let columnMatrixRows = rows columnMatrix
dataIm <- U.newMatrix 0 destinationRows destinationCols dataIm <- U.newMatrix 0 destinationRows destinationCols
@ -134,15 +129,15 @@ col2imUnsafe kernelRows kernelColumns strideRows strideColumns destinationRows d
offsetR <- newSTRef 0 offsetR <- newSTRef 0
offsetC <- newSTRef 0 offsetC <- newSTRef 0
forM_ [0 .. columnMatrixRows - 1] $ \ir -> do forM_ [0 .. columnMatrixRows - 1] $ \inputRow -> do
inputColumn <- newSTRef 0 inputColumnRef <- newSTRef 0
forM_ [0 .. kernelRows -1] $ \kr -> forM_ [0 .. kernelRows -1] $ \kr ->
forM_ [0 .. kernelColumns -1] $ \kc -> do forM_ [0 .. kernelColumns -1] $ \kc -> do
ic <- readSTRef inputColumn inputColumn <- readSTRef inputColumnRef
offsetR' <- readSTRef offsetR offsetR' <- readSTRef offsetR
offsetC' <- readSTRef offsetC offsetC' <- readSTRef offsetC
U.modifyMatrix dataIm (kr + offsetR') (kc + offsetC') (+ atIndex columnMatrix (ir,ic)) U.modifyMatrix dataIm (kr + offsetR') (kc + offsetC') (+ U.atM' columnMatrix inputRow inputColumn)
modifySTRef inputColumn (+1) modifySTRef inputColumnRef (+1)
offsetC' <- readSTRef offsetC offsetC' <- readSTRef offsetC
if offsetC' + kernelColumns < destinationCols if offsetC' + kernelColumns < destinationCols
@ -158,11 +153,11 @@ col2vidUnsafe kernelRows kernelColumns strideRows strideColumns destinationRows
dataIms <- traverse (\_ -> U.newMatrix 0 destinationRows destinationCols) [0 .. filters-1] dataIms <- traverse (\_ -> U.newMatrix 0 destinationRows destinationCols) [0 .. filters-1]
offsetR <- newSTRef 0
offsetC <- newSTRef 0
offsetM <- newSTRef 0 offsetM <- newSTRef 0
forM_ dataIms $ \dataIm -> do forM_ dataIms $ \dataIm -> do
offsetR <- newSTRef 0
offsetC <- newSTRef 0
offsetM' <- readSTRef offsetM offsetM' <- readSTRef offsetM
forM_ [0 .. columnMatrixRows - 1] $ \ir -> do forM_ [0 .. columnMatrixRows - 1] $ \ir -> do
inputColumn <- newSTRef 0 inputColumn <- newSTRef 0
@ -171,7 +166,7 @@ col2vidUnsafe kernelRows kernelColumns strideRows strideColumns destinationRows
ic <- readSTRef inputColumn ic <- readSTRef inputColumn
offsetR' <- readSTRef offsetR offsetR' <- readSTRef offsetR
offsetC' <- readSTRef offsetC offsetC' <- readSTRef offsetC
U.modifyMatrix dataIm (kr + offsetR') (kc + offsetC') (+ atIndex columnMatrix (ir, ic + offsetM')) U.modifyMatrix dataIm (kr + offsetR') (kc + offsetC') (+ U.atM' columnMatrix ir (ic + offsetM'))
modifySTRef inputColumn (+1) modifySTRef inputColumn (+1)
offsetC' <- readSTRef offsetC offsetC' <- readSTRef offsetC
@ -179,60 +174,55 @@ col2vidUnsafe kernelRows kernelColumns strideRows strideColumns destinationRows
then modifySTRef offsetC (+ strideColumns) then modifySTRef offsetC (+ strideColumns)
else writeSTRef offsetC 0 >> modifySTRef offsetR (+ strideRows) else writeSTRef offsetC 0 >> modifySTRef offsetR (+ strideRows)
writeSTRef offsetR 0
writeSTRef offsetC 0
modifySTRef offsetM (+ (kernelRows * kernelColumns)) modifySTRef offsetM (+ (kernelRows * kernelColumns))
traverse U.freezeMatrix dataIms traverse U.unsafeFreezeMatrix dataIms
vid2colUnsafe :: Int -> Int -> Int -> Int -> Int -> Int -> Int -> [Matrix Double] -> Matrix Double vid2colUnsafe :: Int -> Int -> Int -> Int -> Int -> Int -> [Matrix Double] -> Matrix Double
vid2colUnsafe channels kernelRows kernelColumns striderows stridecols vidrows vidcols dataVid = U.runSTMatrix $ do vid2colUnsafe kernelRows kernelColumns striderows stridecols vidrows vidcols dataVid = U.runSTMatrix $ do
let starts = fittingStarts vidrows kernelRows striderows vidcols kernelColumns stridecols let starts = fittingStarts vidrows kernelRows striderows vidcols kernelColumns stridecols
matWidth = kernelRows * kernelColumns kernelSize = kernelRows * kernelColumns
destinationRows = 1 + (vidrows - kernelRows) `div` striderows numberOfPatches = length starts
destinationCols = 1 + (vidcols - kernelColumns) `div` stridecols channels = length dataVid
destinationSize = destinationRows * destinationCols
dataCol <- U.newMatrix 0 destinationSize (channels * matWidth) dataCol <- U.newMatrix 0 numberOfPatches (channels * kernelSize)
offsetC <- newSTRef 0 offsetC <- newSTRef 0
forM_ dataVid $ \dataIm -> do forM_ dataVid $ \dataIm -> do
inputRow <- newSTRef 0 inputRowRef <- newSTRef 0
offsetC' <- readSTRef offsetC offsetC' <- readSTRef offsetC
forM_ starts $ \(startRow, startCol) -> do forM_ starts $ \(startRow, startCol) -> do
inputColumn <- newSTRef 0 inputColumnRef <- newSTRef 0
inputRow' <- readSTRef inputRow inputRow <- readSTRef inputRowRef
forM_ [0 .. kernelRows -1] $ \kr -> forM_ [0 .. kernelRows -1] $ \kr ->
forM_ [0 .. kernelColumns -1] $ \kc -> do forM_ [0 .. kernelColumns -1] $ \kc -> do
inputColumn' <- readSTRef inputColumn inputColumn <- readSTRef inputColumnRef
U.modifyMatrix dataCol inputRow' (inputColumn' + offsetC') (+ atIndex dataIm (kr + startRow, kc + startCol)) U.modifyMatrix dataCol inputRow (inputColumn + offsetC') (+ U.atM' dataIm (kr + startRow) (kc + startCol))
modifySTRef inputColumn (+1) modifySTRef inputColumnRef (+1)
modifySTRef inputRow (+1) modifySTRef inputRowRef (+1)
modifySTRef offsetC (+ matWidth) modifySTRef offsetC (+ kernelSize)
return dataCol return dataCol
im2colUnsafe :: Int -> Int -> Int -> Int -> Int -> Int -> Matrix Double -> Matrix Double im2colUnsafe :: Int -> Int -> Int -> Int -> Matrix Double -> Matrix Double
im2colUnsafe kernelRows kernelColumns striderows stridecols vidrows vidcols dataIm = U.runSTMatrix $ do im2colUnsafe kernelRows kernelColumns striderows stridecols dataIm = U.runSTMatrix $ do
let starts = fittingStarts vidrows kernelRows striderows vidcols kernelColumns stridecols let starts = fittingStarts (rows dataIm) kernelRows striderows (cols dataIm) kernelColumns stridecols
matWidth = kernelRows * kernelColumns kernelSize = kernelRows * kernelColumns
destinationRows = 1 + (vidrows - kernelRows) `div` striderows numberOfPatches = length starts
destinationCols = 1 + (vidcols - kernelColumns) `div` stridecols
destinationSize = destinationRows * destinationCols
dataCol <- U.newMatrix 0 destinationSize matWidth dataCol <- U.newMatrix 0 numberOfPatches kernelSize
inputRow <- newSTRef 0 inputRowRef <- newSTRef 0
forM_ starts $ \(startRow, startCol) -> do forM_ starts $ \(startRow, startCol) -> do
inputColumn <- newSTRef 0 inputColumnRef <- newSTRef 0
inputRow' <- readSTRef inputRow inputRow <- readSTRef inputRowRef
forM_ [0 .. kernelRows -1] $ \kr -> forM_ [0 .. kernelRows -1] $ \kr ->
forM_ [0 .. kernelColumns -1] $ \kc -> do forM_ [0 .. kernelColumns -1] $ \kc -> do
inputColumn' <- readSTRef inputColumn inputColumn <- readSTRef inputColumnRef
U.modifyMatrix dataCol inputRow' inputColumn' (+ atIndex dataIm (kr + startRow, kc + startCol)) U.modifyMatrix dataCol inputRow inputColumn (+ U.atM' dataIm (kr + startRow) (kc + startCol))
modifySTRef inputColumn (+1) modifySTRef inputColumnRef (+1)
modifySTRef inputRow (+1) modifySTRef inputRowRef (+1)
return dataCol return dataCol

View File

@ -0,0 +1,48 @@
module Grenade.Layers.Internal.Pooling (
poolForward
, poolBackward
, poolForwardList
, poolBackwardList
) where
import Numeric.LinearAlgebra hiding ( uniformSample, konst )
import qualified Numeric.LinearAlgebra as LA
import Grenade.Layers.Internal.Convolution
poolForward :: Int -> Int -> Int -> Int -> Int -> Int -> Matrix Double -> Matrix Double
poolForward nrows ncols srows scols outputRows outputCols m =
let starts = fittingStarts (rows m) nrows srows (cols m) ncols scols
in poolForwardFit starts nrows ncols outputRows outputCols m
poolForwardList :: Functor f => Int -> Int -> Int -> Int -> Int -> Int -> Int -> Int -> f (Matrix Double) -> f (Matrix Double)
poolForwardList nrows ncols srows scols inRows inCols outputRows outputCols ms =
let starts = fittingStarts inRows nrows srows inCols ncols scols
in poolForwardFit starts nrows ncols outputRows outputCols <$> ms
poolForwardFit :: [(Int,Int)] -> Int -> Int -> Int -> Int -> Matrix Double -> Matrix Double
poolForwardFit starts nrows ncols _ outputCols m =
let els = fmap (\start -> maxElement $ subMatrix start (nrows, ncols) m) starts
in LA.matrix outputCols els
poolBackward :: Int -> Int -> Int -> Int -> Matrix Double -> Matrix Double -> Matrix Double
poolBackward krows kcols srows scols inputMatrix gradientMatrix =
let inRows = rows inputMatrix
inCols = cols inputMatrix
starts = fittingStarts inRows krows srows inCols kcols scols
in poolBackwardFit starts krows kcols inputMatrix gradientMatrix
poolBackwardList :: Functor f => Int -> Int -> Int -> Int -> Int -> Int -> f (Matrix Double, Matrix Double) -> f (Matrix Double)
poolBackwardList krows kcols srows scols inRows inCols inputMatrices =
let starts = fittingStarts inRows krows srows inCols kcols scols
in uncurry (poolBackwardFit starts krows kcols) <$> inputMatrices
poolBackwardFit :: [(Int,Int)] -> Int -> Int -> Matrix Double -> Matrix Double -> Matrix Double
poolBackwardFit starts krows kcols inputMatrix gradientMatrix =
let inRows = rows inputMatrix
inCols = cols inputMatrix
inds = fmap (\start -> maxIndex $ subMatrix start (krows, kcols) inputMatrix) starts
grads = toList $ flatten gradientMatrix
grads' = zip3 starts grads inds
accums = fmap (\((stx',sty'),grad,(inx, iny)) -> ((stx' + inx, sty' + iny), grad)) grads'
in accum (LA.konst 0 (inRows, inCols)) (+) accums

View File

@ -24,10 +24,8 @@ import GHC.TypeLits
import Grenade.Core.Network import Grenade.Core.Network
import Grenade.Core.Shape import Grenade.Core.Shape
import Grenade.Core.Vector import Grenade.Core.Vector
import Grenade.Layers.Convolution.Internal import Grenade.Layers.Internal.Pooling
import Numeric.LinearAlgebra hiding (uniformSample)
import qualified Numeric.LinearAlgebra as LA
import Numeric.LinearAlgebra.Static as LAS hiding ((|||), build, toRows) import Numeric.LinearAlgebra.Static as LAS hiding ((|||), build, toRows)
-- | A pooling layer for a neural network. -- | A pooling layer for a neural network.
@ -37,16 +35,12 @@ import Numeric.LinearAlgebra.Static as LAS hiding ((|||), build, toRow
-- The kernel size dictates which input and output sizes will "fit". Fitting the equation: -- The kernel size dictates which input and output sizes will "fit". Fitting the equation:
-- `out = (in - kernel) / stride + 1` for both dimensions. -- `out = (in - kernel) / stride + 1` for both dimensions.
-- --
data Pooling :: Nat data Pooling :: Nat -> Nat -> Nat -> Nat -> * where
-> Nat
-> Nat
-> Nat -> * where
Pooling :: Pooling kernelRows kernelColumns strideRows strideColumns Pooling :: Pooling kernelRows kernelColumns strideRows strideColumns
instance Show (Pooling k k' s s') where instance Show (Pooling k k' s s') where
show Pooling = "Pooling" show Pooling = "Pooling"
instance UpdateLayer (Pooling kernelRows kernelColumns strideRows strideColumns) where instance UpdateLayer (Pooling kernelRows kernelColumns strideRows strideColumns) where
type Gradient (Pooling kr kc sr sc) = () type Gradient (Pooling kr kc sr sc) = ()
runUpdate _ Pooling _ = Pooling runUpdate _ Pooling _ = Pooling
@ -123,40 +117,3 @@ instance ( KnownNat kernelRows
ez = vectorZip (,) ex eo ez = vectorZip (,) ex eo
vs = poolBackwardList kx ky sx sy ix iy ez vs = poolBackwardList kx ky sx sy ix iy ez
in ((), S3D' . fmap (fromJust . create) $ vs) in ((), S3D' . fmap (fromJust . create) $ vs)
poolForward :: Int -> Int -> Int -> Int -> Int -> Int -> Matrix Double -> Matrix Double
poolForward nrows ncols srows scols outputRows outputCols m =
let starts = fittingStarts (rows m) nrows srows (cols m) ncols scols
in poolForwardFit starts nrows ncols outputRows outputCols m
poolForwardList :: Functor f => Int -> Int -> Int -> Int -> Int -> Int -> Int -> Int -> f (Matrix Double) -> f (Matrix Double)
poolForwardList nrows ncols srows scols inRows inCols outputRows outputCols ms =
let starts = fittingStarts inRows nrows srows inCols ncols scols
in poolForwardFit starts nrows ncols outputRows outputCols <$> ms
poolForwardFit :: [(Int,Int)] -> Int -> Int -> Int -> Int -> Matrix Double -> Matrix Double
poolForwardFit starts nrows ncols _ outputCols m =
let els = fmap (\start -> maxElement $ subMatrix start (nrows, ncols) m) starts
in LA.matrix outputCols els
poolBackward :: Int -> Int -> Int -> Int -> Matrix Double -> Matrix Double -> Matrix Double
poolBackward krows kcols srows scols inputMatrix gradientMatrix =
let inRows = (rows inputMatrix)
inCols = (cols inputMatrix)
starts = fittingStarts inRows krows srows inCols kcols scols
in poolBackwardFit starts krows kcols inputMatrix gradientMatrix
poolBackwardList :: Functor f => Int -> Int -> Int -> Int -> Int -> Int -> f (Matrix Double, Matrix Double) -> f (Matrix Double)
poolBackwardList krows kcols srows scols inRows inCols inputMatrices =
let starts = fittingStarts inRows krows srows inCols kcols scols
in (uncurry $ poolBackwardFit starts krows kcols) <$> inputMatrices
poolBackwardFit :: [(Int,Int)] -> Int -> Int -> Matrix Double -> Matrix Double -> Matrix Double
poolBackwardFit starts krows kcols inputMatrix gradientMatrix =
let inRows = (rows inputMatrix)
inCols = (cols inputMatrix)
inds = fmap (\start -> maxIndex $ subMatrix start (krows, kcols) inputMatrix) starts
grads = toList $ flatten gradientMatrix
grads' = zip3 starts grads inds
accums = fmap (\((stx',sty'),grad,(inx, iny)) -> ((stx' + inx, sty' + iny), grad)) grads'
in accum (LA.konst 0 (inRows, inCols)) (+) accums

View File

@ -8,7 +8,7 @@ import Grenade.Core.Shape
import Grenade.Core.Vector as Grenade import Grenade.Core.Vector as Grenade
import Grenade.Core.Network import Grenade.Core.Network
import Grenade.Layers.Convolution import Grenade.Layers.Convolution
import Grenade.Layers.Convolution.Internal import Grenade.Layers.Internal.Convolution
import Numeric.LinearAlgebra hiding (uniformSample, konst, (===)) import Numeric.LinearAlgebra hiding (uniformSample, konst, (===))
import qualified Numeric.LinearAlgebra.Static as HStatic import qualified Numeric.LinearAlgebra.Static as HStatic
@ -51,7 +51,7 @@ prop_im2col_other = once $
expected = (2><6) expected = (2><6)
[ 1.0, 2.0, 5.0, 6.0 , 9.0, 10.0 [ 1.0, 2.0, 5.0, 6.0 , 9.0, 10.0
, 3.0, 4.0, 7.0, 8.0 , 11.0 ,12.0 ] , 3.0, 4.0, 7.0, 8.0 , 11.0 ,12.0 ]
out = im2col 3 2 1 2 input out = im2colUnsafe 3 2 1 2 input
in expected === out in expected === out
-- If there's no overlap (stride is the same size as the kernel) -- If there's no overlap (stride is the same size as the kernel)
@ -71,7 +71,7 @@ prop_im2colunsafe_sym_on_same_stride = once $
[ 1.0, 2.0, 3.0, 4.0 [ 1.0, 2.0, 3.0, 4.0
, 5.0, 6.0, 7.0, 8.0 , 5.0, 6.0, 7.0, 8.0
, 9.0, 10.0, 11.0, 12.0 ] , 9.0, 10.0, 11.0, 12.0 ]
out = col2imUnsafe 3 2 3 2 3 4 . im2colUnsafe 3 2 3 2 3 4 $ input out = col2imUnsafe 3 2 3 2 3 4 . im2colUnsafe 3 2 3 2 $ input
in input === out in input === out
@ -86,7 +86,7 @@ prop_im2col_col2im_additive = once $
[ 1.0, 2.0, 2.0, 1.0 [ 1.0, 2.0, 2.0, 1.0
, 2.0, 4.0, 4.0, 2.0 , 2.0, 4.0, 4.0, 2.0
, 1.0, 2.0, 2.0, 1.0 ] , 1.0, 2.0, 2.0, 1.0 ]
out = col2im 2 2 1 1 3 4 . im2col 2 2 1 1 $ input out = col2imUnsafe 2 2 1 1 3 4 . im2colUnsafe 2 2 1 1 $ input
in expected === out in expected === out
prop_simple_conv_forwards = once $ prop_simple_conv_forwards = once $
@ -166,7 +166,7 @@ prop_vid2col_no_stride = once $
, 5.0, 6.0, 9.0, 10.0 , 25.0, 26.0, 29.0, 30.0 , 5.0, 6.0, 9.0, 10.0 , 25.0, 26.0, 29.0, 30.0
, 6.0, 7.0, 10.0, 11.0 , 26.0, 27.0, 30.0, 31.0 , 6.0, 7.0, 10.0, 11.0 , 26.0, 27.0, 30.0, 31.0
, 7.0, 8.0, 11.0, 12.0 , 27.0, 28.0, 31.0, 32.0 ] , 7.0, 8.0, 11.0, 12.0 , 27.0, 28.0, 31.0, 32.0 ]
out = vid2col 2 2 1 1 3 4 input out = vid2colUnsafe 2 2 1 1 3 4 input
in expected === out in expected === out
prop_vid2col_stride = once $ prop_vid2col_stride = once $
@ -183,7 +183,7 @@ prop_vid2col_stride = once $
, 3.0, 4.0, 7.0, 8.0 , 23.0, 24.0, 27.0, 28.0 , 3.0, 4.0, 7.0, 8.0 , 23.0, 24.0, 27.0, 28.0
, 5.0, 6.0, 9.0, 10.0 , 25.0, 26.0, 29.0, 30.0 , 5.0, 6.0, 9.0, 10.0 , 25.0, 26.0, 29.0, 30.0
, 7.0, 8.0, 11.0, 12.0 , 27.0, 28.0, 31.0, 32.0 ] , 7.0, 8.0, 11.0, 12.0 , 27.0, 28.0, 31.0, 32.0 ]
out = vid2col 2 2 1 2 3 4 input out = vid2colUnsafe 2 2 1 2 3 4 input
in expected === out in expected === out
@ -208,7 +208,7 @@ prop_vid2col_invert_unsafe = once $
[ 21.0, 22.0, 23.0, 24.0 [ 21.0, 22.0, 23.0, 24.0
, 25.0, 26.0, 27.0, 28.0 , 25.0, 26.0, 27.0, 28.0
, 29.0, 30.0, 31.0, 32.0 ] ] , 29.0, 30.0, 31.0, 32.0 ] ]
out = col2vidUnsafe 3 2 3 2 3 4 . vid2colUnsafe 2 3 2 3 2 3 4 $ input out = col2vidUnsafe 3 2 3 2 3 4 . vid2colUnsafe 3 2 3 2 3 4 $ input
in input === out in input === out